277 lines
9.3 KiB
Python
277 lines
9.3 KiB
Python
import csv
|
|
import datetime
|
|
import json
|
|
import os
|
|
import re
|
|
from importlib import import_module
|
|
from json import JSONEncoder
|
|
from uuid import uuid4
|
|
|
|
from config import SQLALCHEMY_DATABASE_URI, MAKEDIRS, DATABASE_FILE, SEARCH_INDEX_PATH, STATIC_DATA_DIR, basepath
|
|
from flask_migrate import Migrate
|
|
from flask_migrate import upgrade as migrate_upgrade
|
|
from flask_migrate import init as migrate_init
|
|
from flask_security import RoleMixin, UserMixin
|
|
from flask_security import Security, SQLAlchemyUserDatastore
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
from markupsafe import escape, Markup
|
|
from sqlalchemy import Boolean
|
|
from sqlalchemy import TypeDecorator
|
|
from sqlalchemy.ext.declarative import declared_attr, DeclarativeMeta
|
|
from sqlalchemy.orm.collections import InstrumentedList
|
|
from sqlalchemy_utils import Choice
|
|
from tww.lib import solve_query, resolve_timezone, dt_tz_translation, time_ago
|
|
from whooshalchemy import IndexService
|
|
from oshipka.util.strings import camel_case_to_snake_case
|
|
from vm_gen.vm_gen import order_from_process_order
|
|
|
|
# Shared Flask-SQLAlchemy handle; it is bound to a concrete app later in
# init_db() via db.init_app(app).
db = SQLAlchemy()

# Flask-Migrate extension; bound to the app and ``db`` in init_db() via
# migrate.init_app(...).
migrate = Migrate()
|
|
|
|
|
|
class Ownable:
    """
    Mixin that attaches an owning ``User`` to a model.

    Mixing this in adds a ``user_id`` integer column with a foreign key to
    ``user.id`` and a ``user`` relationship to the ``User`` model. Both are
    built through ``declared_attr`` so every model class that mixes this in
    gets its own column/relationship objects.
    """

    @declared_attr
    def user_id(cls):
        # Foreign key column pointing at the owner's primary key.
        return db.Column(db.Integer, db.ForeignKey('user.id'))

    @declared_attr
    def user(cls):
        # ORM-level access to the owning User row.
        return db.relationship("User")
|
|
|
|
|
|
# Association table for the many-to-many link between users and roles; it is
# used as ``secondary=roles_users`` by the User.roles relationship.
roles_users = db.Table('roles_users',
                       db.Column('user_id', db.Integer(), db.ForeignKey('user.id')),
                       db.Column('role_id', db.Integer(), db.ForeignKey('role.id')))
|
|
|
|
|
|
class ModelJsonEncoder(JSONEncoder):
    """
    JSON encoder aware of datetimes and model instances.

    ``datetime.datetime`` values are encoded as ``str(value)``; any other
    object that exposes an ``id`` attribute (e.g. a database model) is encoded
    as that id. Everything else is handed to ``JSONEncoder.default``, which
    raises ``TypeError`` as the ``json`` module expects.
    """

    def default(self, o):
        if isinstance(o, datetime.datetime):
            return str(o)
        try:
            return o.id
        except AttributeError:
            pass
        # Not something we know how to encode: let the base class raise the
        # standard TypeError instead of leaking an AttributeError out of
        # json.dumps().
        return super().default(o)
|
|
|
|
|
|
class LiberalBoolean(TypeDecorator):
    """
    Boolean column type that is lenient about the Python values it accepts.

    On bind:

    * ``None`` is passed through unchanged, so the column can stay NULL.
    * Strings (and bytes, decoded as UTF-8) are stripped, then interpreted as
      an integer when they parse as one (``"0"`` -> False, ``"1"`` -> True).
      Otherwise the common "false" spellings ``""``, ``"false"``, ``"f"``,
      ``"no"``, ``"n"`` and ``"off"`` (case-insensitive) map to False and any
      other text maps to True.
    * Every other value is passed through ``bool()``.
    """

    impl = Boolean

    # The type carries no per-instance state, so SQLAlchemy (>= 1.4) may use
    # it in cached statements without emitting a cache_ok warning.
    cache_ok = True

    _FALSE_STRINGS = frozenset(('', 'false', 'f', 'no', 'n', 'off'))

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        if isinstance(value, bytes):
            value = value.decode('utf-8', 'replace')
        if isinstance(value, str):
            text = value.strip()
            try:
                # Numeric strings: "0" -> False, any other integer -> True.
                return bool(int(text))
            except ValueError:
                # bool("false") would be True, so textual false values must be
                # recognised explicitly (e.g. CSV or HTML-form input).
                return text.lower() not in self._FALSE_STRINGS
        return bool(value)
|
|
|
|
|
|
class ModelController(ModelJsonEncoder):
    """
    This interface is the parent of all models in our database.

    It provides a table name derived from the class name, an integer ``id``
    primary key, an indexed ``uuid`` column defaulting to a random UUID4
    string, and a generic ``serialize`` method that turns an instance into a
    JSON-encodable dict.
    """

    @declared_attr
    def __tablename__(cls):
        # Presumably ``FooBar`` -> ``foo_bar``, per camel_case_to_snake_case.
        return camel_case_to_snake_case(cls.__name__)  # pylint: disable=E1101

    # NOTE(review): relies on SQLAlchemy declarative internals for the
    # ``__tablename__`` declared_attr above; purpose not evident here, keep as is.
    _sa_declared_attr_reg = {'__tablename__': True}

    __mapper_args__ = {'always_refresh': True}

    id = db.Column(db.Integer, primary_key=True)

    uuid = db.Column(db.Unicode, index=True, default=lambda: str(uuid4()))

    # Subclasses list attribute names here to keep them out of serialize().
    _excluded_serialization = []

    def serialize(self, with_=None, depth=0, withs_used=None):
        """
        Serializes object to dict

        Fields whose values are not JSON-encodable are left out of the result.

        Relations named in ``with_`` are expanded recursively, up to MAX_DEPTH.
        The method tries to avoid infinite recursion by remembering which
        ``with_`` elements were already expanded at previous depths. A related
        value that is ``None`` is emitted as ``None``. A related value without
        a ``serialize`` method is left as is, so it is usually dropped as
        non-encodable.

        :param with_: names of relation attributes to expand into nested dicts.
        :param depth: current recursion depth (0 for the top-level call).
        :param withs_used: per-depth lists of relation names already expanded;
                           shared (and mutated) across the recursion.
        :return: dict of attribute name -> JSON-encodable value, or None when
                 the instance's class is not an SQLAlchemy declarative class.
        """
        MAX_DEPTH = 3

        def handle_withs(data, with_, depth, withs_used):
            # Expand a related object (or list of them) one level deeper;
            # beyond MAX_DEPTH, nested objects are serialized without expansion.
            if data is None:
                # Unset to-one relation: emit null instead of crashing.
                return None
            if isinstance(data, InstrumentedList):
                if depth >= MAX_DEPTH:
                    return [e.serialize() for e in data]
                return [e.serialize(with_=with_, depth=depth + 1, withs_used=withs_used)
                        for e in data]
            if not callable(getattr(data, 'serialize', None)):
                # e.g. a dynamic relationship query: leave it to the
                # encodability check below instead of raising AttributeError.
                return data
            if depth >= MAX_DEPTH:
                return data.serialize()
            return data.serialize(with_=with_, depth=depth + 1, withs_used=withs_used)

        if not with_:
            with_ = []
        if not withs_used:
            withs_used = []
        if not isinstance(self.__class__, DeclarativeMeta):
            return None

        # an SQLAlchemy class
        fields = {}
        # JSONEncoder attributes (this class inherits from one) and SQLAlchemy
        # metadata are never data fields.
        always_excluded = ('metadata', 'item_separator', 'key_separator')
        iterable_fields = [x for x in dir(self)
                           if not x.startswith('_')
                           and x not in always_excluded
                           and x.islower()
                           and x not in self._excluded_serialization]
        for field in iterable_fields:
            data = getattr(self, field)
            try:
                if field in with_:
                    # This handles withs nested inside other models: skip a
                    # relation an ancestor level already expanded.
                    while len(withs_used) <= depth:
                        withs_used.append([])
                    if any(field in used for used in withs_used[:depth]):
                        continue
                    withs_used[depth].append(field)
                    data = handle_withs(data, with_, depth, withs_used)
                if isinstance(data, datetime.datetime):
                    data = str(data)
                if isinstance(data, Choice):
                    data = data.code
                json.dumps(data)  # raises TypeError on non-encodable values, like other classes
                if isinstance(data, InstrumentedList):
                    continue  # pragma: no cover
                fields[field] = data
            except TypeError:
                pass  # Not JSON-encodable: leave the field out
        # a json-encodable dict
        return fields
|
|
|
|
|
|
class Role(db.Model, ModelController, RoleMixin):
    """Security role; linked to users many-to-many through the roles_users table."""
    # Unique role name.
    name = db.Column(db.Unicode, unique=True)
    # Free-text description of the role.
    description = db.Column(db.Unicode)
|
|
|
|
|
|
class User(db.Model, ModelController, UserMixin):
    """Application user account; the user model handed to Flask-Security via user_datastore."""
    # Unique email address (presumably the login identifier — confirm).
    email = db.Column(db.Unicode, unique=True)
    # NOTE(review): presumably holds the hash produced by Flask-Security, not plaintext — confirm.
    password = db.Column(db.Unicode)

    # Account flag, True by default; presumably consulted by Flask-Security to allow login — confirm.
    active = db.Column(db.Boolean(), default=True)
    confirmed_at = db.Column(db.DateTime())

    # Per-user timezone preferences; default to UTC with a zero offset.
    timezone = db.Column(db.String, default='UTC')
    tz_offset_seconds = db.Column(db.Integer, default=0)
    # Locale code of at most 4 characters, default 'en'.
    locale = db.Column(db.String(4), default='en')

    name = db.Column(db.Unicode)
    profile_image_url = db.Column(db.String)

    # Many-to-many through roles_users; each Role gets a dynamic ``users`` backref.
    roles = db.relationship('Role', secondary=roles_users,
                            backref=db.backref('users', lazy='dynamic'))
|
|
|
|
|
|
# Flask-Security extension; bound to the app in init_db() via
# security.init_app(app, user_datastore).
security = Security()

# Datastore adapter connecting Flask-Security to the User and Role models
# through ``db``.
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
|
|
|
|
|
|
def register_filters(app):
    """
    Register this project's custom Jinja template filters on ``app``.

    Adds ``nl2br``, ``format_dt``, ``to_tz``, ``timeago`` and ``timediff``
    through ``app.template_filter``.
    """
    paragraph_split = re.compile(r'(?:\r\n|\r|\n){2,}')

    def nl2br(text):
        # Escape first so user text cannot inject markup. Then separate
        # blank-line chunks with <p> and turn single newlines into <br>.
        escaped = escape(text)
        pieces = (u'%s' % chunk.replace('\n', '<br>\n')
                  for chunk in paragraph_split.split(escaped))
        return Markup(u'<p>'.join(pieces))

    def format_datetime(dt, formatting="%a, %d %b %Y"):
        """
        Render ``dt`` with ``strftime(formatting)`` (see http://strftime.org/).

        :param dt: a datetime object, or a string that solve_query can parse
        :param formatting: strftime format string
        :return: the formatted string; falsy ``dt`` values are returned unchanged
        """
        if not dt:
            return dt
        parsed = solve_query(dt) if type(dt) is str else dt
        return parsed.strftime(formatting)

    def to_tz(dt, human_tz="utc", formatting='%H:%M', include_offset=True):
        # Shift dt into the timezone resolved from human_tz, format it, and
        # optionally append the resolved offset in parentheses.
        if type(dt) is str:
            dt = solve_query(dt)
        tz = resolve_timezone(human_tz)
        shifted = dt_tz_translation(dt, to_tz_offset=tz.get('tz_offset'))
        rendered = shifted.strftime(formatting)
        if include_offset:
            return "{} ({})".format(rendered, tz.get('tz_offset'))
        return rendered

    def timeago(dt):
        return time_ago(dt)

    def timediff(diff):
        return time_ago(None, diff)

    for filter_name, filter_fn in (('nl2br', nl2br),
                                   ('format_dt', format_datetime),
                                   ('to_tz', to_tz),
                                   ('timeago', timeago),
                                   ('timediff', timediff)):
        app.template_filter(filter_name)(filter_fn)
|
|
|
|
|
|
class Proxy:
    """
    Mutable holder for a lazily created service.

    ``proxied`` holds the real service once it exists. ``searchables`` starts
    as a fresh empty list per instance.
    """

    def __init__(self, proxied):
        self.proxied = proxied
        self.searchables = list()
|
|
|
|
|
|
# Module-level placeholder for the search index service. init_db() later sets
# its ``proxied`` attribute to a real whooshalchemy IndexService, so other
# modules can reference this object before the app is initialised.
index_service = Proxy(None)
|
|
|
|
|
|
def register_index_svc():
    """
    Register every class queued in ``index_service.searchables`` with the
    service currently held in ``index_service.proxied``, in queue order.
    """
    pending = index_service.searchables
    for model_cls in pending:
        index_service.proxied.register_class(model_cls)
|
|
|
|
|
|
def init_db(app):
    """
    Wire the database, migrations, security, template filters and search index into ``app``.

    Creates every directory in MAKEDIRS. On first run it initialises the
    migrations folder (``<basepath>/migrations``). When DATABASE_FILE does not
    exist yet, it creates the database by running all migrations.

    :param app: the Flask application to configure.
    :return: True if the database file was missing and was just created by
             running migrations, False otherwise.
    """
    created_db = False
    # Single source of truth for the migrations folder. Previously the
    # existence check used basepath while Flask-Migrate used a CWD-relative
    # "migrations" folder, so the two could disagree.
    migrations_dir = os.path.join(basepath, 'migrations')

    app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    app.config["WHOOSH_BASE"] = SEARCH_INDEX_PATH

    # Imported here rather than at module level, presumably to avoid a circular import.
    from oshipka.webapp import test_bp, oshipka_bp
    app.register_blueprint(test_bp)
    app.register_blueprint(oshipka_bp)

    db.init_app(app)
    migrate.init_app(app, db, directory=migrations_dir)
    security.init_app(app, user_datastore)

    register_filters(app)

    for directory in MAKEDIRS:
        os.makedirs(directory, exist_ok=True)

    if not os.path.exists(migrations_dir):
        with app.app_context():
            migrate_init(directory=migrations_dir)
    if not os.path.exists(DATABASE_FILE):
        with app.app_context():
            migrate_upgrade(directory=migrations_dir)
        created_db = True

    # Only an attribute of the module-level proxy is replaced, so no
    # ``global`` statement is needed.
    index_service.proxied = IndexService(config=app.config, session=db.session)
    register_index_svc()

    return created_db
|
|
|
|
|
|
def populate_static(app):
    """
    Load static seed data from CSV files into the database.

    Model names come from ``order_from_process_order('csv', STATIC_DATA_DIR)``.
    For each name, ``STATIC_DATA_DIR/<name>.csv`` is read with
    ``csv.DictReader``. Each row is passed as keyword arguments to the
    same-named class in ``webapp.models`` and added to the session. One commit
    is issued after all files are processed.

    :param app: Flask application whose app context is used for the session.
    """
    with app.app_context():
        models = import_module("webapp.models")
        for model_name in order_from_process_order('csv', STATIC_DATA_DIR):
            model = getattr(models, model_name)
            csv_path = os.path.join(STATIC_DATA_DIR, "{}.csv".format(model_name))
            # newline='' is required by the csv module so quoted fields that
            # contain newlines parse correctly. utf-8-sig reads plain UTF-8 and
            # also strips the BOM spreadsheet exports often add, which would
            # otherwise corrupt the first column name.
            with open(csv_path, newline='', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    db.session.add(model(**row))
        db.session.commit()
|