# oshipka/oshipka/persistance/__init__.py
import csv
import datetime
import json
import os
import re
from importlib import import_module
from json import JSONEncoder
from uuid import uuid4
from config import SQLALCHEMY_DATABASE_URI, MAKEDIRS, DATABASE_FILE, SEARCH_INDEX_PATH, STATIC_DATA_DIR, basepath
from flask_migrate import Migrate
from flask_migrate import upgrade as migrate_upgrade
from flask_migrate import init as migrate_init
from flask_security import RoleMixin, UserMixin
from flask_security import Security, SQLAlchemyUserDatastore
from flask_sqlalchemy import SQLAlchemy
from markupsafe import escape, Markup
from sqlalchemy import Boolean
from sqlalchemy import TypeDecorator
from sqlalchemy.ext.declarative import declared_attr, DeclarativeMeta
from sqlalchemy.orm.collections import InstrumentedList
from sqlalchemy_utils import Choice
from tww.lib import solve_query, resolve_timezone, dt_tz_translation, time_ago
from whooshalchemy import IndexService
from oshipka.util.strings import camel_case_to_snake_case
from vm_gen.vm_gen import order_from_process_order
# Module-level extension singletons; bound to the application in init_db().
db = SQLAlchemy()
migrate = Migrate()
class Ownable(object):
    """Mixin that gives a model a foreign key and relationship to User.

    ``declared_attr`` defers column/relationship construction until the
    mixin is composed into a concrete declarative model class.
    """

    @declared_attr
    def user_id(self):
        # NOTE(review): declared_attr passes the declaring *class* here,
        # despite the conventional ``self`` name.
        return db.Column(db.Integer, db.ForeignKey('user.id'))

    @declared_attr
    def user(self):
        return db.relationship("User")
# Association table for the many-to-many User <-> Role relationship
# (used by Flask-Security via User.roles below).
roles_users = db.Table('roles_users',
                       db.Column('user_id', db.Integer(), db.ForeignKey('user.id')),
                       db.Column('role_id', db.Integer(), db.ForeignKey('role.id')))
class ModelJsonEncoder(JSONEncoder):
    """JSON encoder for model objects.

    Datetimes are rendered via ``str()``; any other non-serializable
    object is reduced to its ``id`` attribute (a model row's primary
    key).  Objects with no ``id`` are handed to the base class, which
    raises ``TypeError`` as the JSONEncoder protocol expects (the
    original code raised ``AttributeError`` instead).
    """

    def default(self, o):
        if isinstance(o, datetime.datetime):
            return str(o)
        if hasattr(o, "id"):
            # Model rows serialize as their primary key.
            return o.id
        return super(ModelJsonEncoder, self).default(o)
class LiberalBoolean(TypeDecorator):
    """Boolean column type that tolerates loosely-typed inputs.

    Digit strings (e.g. ``"1"``, ``"0"``) are converted through ``int``
    first; every other non-None value is coerced with ``bool()``.
    """

    impl = Boolean

    def process_bind_param(self, value, dialect):
        # None passes through untouched so NULLs stay NULL.
        if value is None:
            return None
        if hasattr(value, 'isdigit') and value.isdigit():
            value = int(value)
        return bool(value)
class ModelController(ModelJsonEncoder):
    """
    This interface is the parent of all models in our database.

    Provides an integer primary key, a random UUID column, an
    auto-derived snake_case table name, and dict serialization.
    """

    @declared_attr
    def __tablename__(self):
        # Table name derives from the class name, e.g. UserProfile -> user_profile.
        return camel_case_to_snake_case(self.__name__)  # pylint: disable=E1101

    # Marks __tablename__ as a declared_attr for the declarative registry.
    _sa_declared_attr_reg = {'__tablename__': True}
    # Always re-fetch instance state from the database on query.
    __mapper_args__ = {'always_refresh': True}

    id = db.Column(db.Integer, primary_key=True)
    # Random UUID assigned at insert time; indexed for lookups.
    uuid = db.Column(db.Unicode, index=True, default=lambda: str(uuid4()))

    # Subclasses may list field names here to skip them in serialize().
    _excluded_serialization = []

    def serialize(self, with_=None, depth=0, withs_used=None):
        """
        Serializes object to dict

        It will ignore fields that are not encodable (they are omitted).
        It expands relations mentioned in with_ recursively up to MAX_DEPTH and tries to smartly ignore
        recursions by mentioning which with elements have already been used in previous depths

        :param with_: list of relation field names to expand recursively
        :param depth: current recursion depth (internal)
        :param withs_used: per-depth record of already-expanded fields (internal;
            shared and mutated across the recursive calls)
        :return: a json-encodable dict of the instance's fields
        """
        MAX_DEPTH = 3

        def with_used_in_prev_depth(field, previous_depths):
            # True if ``field`` was already expanded at a shallower depth
            # (prevents re-expanding cyclic relationships).
            for previous_depth in previous_depths:
                if field in previous_depth:
                    return True
            return False

        def handle_withs(data, with_, depth, withs_used):
            # Serialize a related object or collection; once MAX_DEPTH is
            # reached, recurse without further with_-expansion.
            if isinstance(data, InstrumentedList):
                if depth >= MAX_DEPTH:
                    return [e.serialize() for e in data]
                else:
                    return [e.serialize(with_=with_, depth=depth + 1, withs_used=withs_used) for e in data]
            else:
                if depth >= MAX_DEPTH:
                    return data.serialize()
                else:
                    return data.serialize(with_=with_, depth=depth + 1, withs_used=withs_used)

        if not with_:
            with_ = []
        if not withs_used:
            withs_used = []
        if isinstance(self.__class__, DeclarativeMeta):
            # an SQLAlchemy class
            fields = {}
            # Public, lowercase attributes that are not explicitly excluded.
            # ('item_separator'/'key_separator' are inherited from JSONEncoder.)
            iterable_fields = [x for x in dir(self) if not x.startswith('_') and x not in ['metadata',
                                                                                          'item_separator',
                                                                                          'key_separator'] and x.islower()
                               and x not in self._excluded_serialization]
            for field in iterable_fields:
                data = self.__getattribute__(field)
                try:
                    if field in with_:
                        # this handles withs nested inside other models
                        if len(withs_used) < depth + 1:
                            withs_used.append([])
                        previous_depths = withs_used[:depth]
                        if with_used_in_prev_depth(field, previous_depths):
                            continue
                        withs_used[depth].append(field)
                        data = handle_withs(data, with_, depth, withs_used)
                    if isinstance(data, datetime.datetime):
                        data = str(data)
                    if isinstance(data, Choice):
                        data = data.code
                    json.dumps(data)  # this will fail on non-encodable values, like other classes
                    if isinstance(data, InstrumentedList):
                        continue  # pragma: no cover
                    fields[field] = data
                except TypeError as e:
                    # Non-encodable field: silently omitted from the result.
                    pass  # Don't assign anything
            # a json-encodable dict
            return fields
class Role(db.Model, ModelController, RoleMixin):
    """Authorization role assignable to users (Flask-Security)."""
    # Unique role name.
    name = db.Column(db.Unicode, unique=True)
    # Free-text description of the role.
    description = db.Column(db.Unicode)
class User(db.Model, ModelController, UserMixin):
    """Application user account (Flask-Security compatible)."""
    email = db.Column(db.Unicode, unique=True)
    # presumably stores the password hash managed by Flask-Security — verify
    password = db.Column(db.Unicode)
    active = db.Column(db.Boolean(), default=True)
    confirmed_at = db.Column(db.DateTime())
    # Preferred timezone name and its offset from UTC in seconds.
    timezone = db.Column(db.String, default='UTC')
    tz_offset_seconds = db.Column(db.Integer, default=0)
    # Short locale code, e.g. 'en'.
    locale = db.Column(db.String(4), default='en')
    name = db.Column(db.Unicode)
    profile_image_url = db.Column(db.String)
    # Many-to-many via the roles_users association table.
    roles = db.relationship('Role', secondary=roles_users,
                            backref=db.backref('users', lazy='dynamic'))
# Flask-Security setup: the datastore binds the User/Role models to the db.
security = Security()
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
def register_filters(app):
    """Register Jinja2 template filters on *app*.

    Filters added: ``nl2br``, ``format_dt``, ``to_tz``, ``timeago``,
    ``timediff``.

    :param app: the Flask application to register the filters on
    """
    # Two or more consecutive newlines delimit a paragraph.
    _paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}')

    @app.template_filter('nl2br')
    def nl2br(text):
        # Escape first so user input cannot inject HTML; the markup added
        # below is then declared safe via Markup.
        text = escape(text)
        # NOTE(review): paragraphs are *joined* with '<p>' rather than
        # wrapped in <p>...</p> pairs, producing unbalanced tags.  Also,
        # escape() returns a Markup, and Markup.replace escapes its
        # arguments, which may turn '<br>' into '&lt;br&gt;' — confirm
        # the rendered output before changing this.
        result = u'<p>'.join(u'%s' % p.replace('\n', '<br>\n') for p in _paragraph_re.split(text))
        return Markup(result)

    @app.template_filter('format_dt')
    def format_datetime(dt, formatting="%a, %d %b %Y"):
        """
        Formats the datetime string provided in value into whatever format you want that is supported by python strftime
        http://strftime.org/
        :param formatting The specific format of the datetime
        :param dt a datetime object (or a string resolved via tww.solve_query)
        :return:
        """
        if not dt:
            return dt
        if type(dt) is str:
            # presumably resolves natural-language date strings via tww — verify
            dt = solve_query(dt)
        return dt.strftime(formatting)

    @app.template_filter('to_tz')
    def to_tz(dt, human_tz="utc", formatting='%H:%M', include_offset=True):
        # Translate dt into the timezone described by ``human_tz`` and
        # format it, optionally appending the offset in parentheses.
        if type(dt) is str:
            dt = solve_query(dt)
        tz = resolve_timezone(human_tz)
        dt = dt_tz_translation(dt, to_tz_offset=tz.get('tz_offset'))
        base = dt.strftime(formatting)
        if not include_offset:
            return base
        return "{} ({})".format(base, tz.get('tz_offset'))

    @app.template_filter('timeago')
    def timeago(dt):
        # Human-friendly relative time, e.g. "5 minutes ago".
        return time_ago(dt)

    @app.template_filter('timediff')
    def timediff(diff):
        # Render an already-computed time difference via tww.time_ago.
        return time_ago(None, diff)
class Proxy(object):
    """Late-binding placeholder around an object created after import time.

    ``proxied`` holds the wrapped object (may start as None) and
    ``searchables`` accumulates model classes awaiting registration.
    """

    def __init__(self, proxied):
        """Wrap *proxied* and start with an empty searchables list."""
        self.proxied = proxied
        self.searchables = []
# Placeholder until init_db() installs the real IndexService as .proxied.
index_service = Proxy(None)
def register_index_svc():
    """Register every queued searchable model with the live index service."""
    svc = index_service.proxied
    for model_cls in index_service.searchables:
        svc.register_class(model_cls)
def init_db(app):
    """Wire the Flask app to the database, auth, search index and filters.

    Creates required directories, initializes Alembic migrations on first
    run, and upgrades the schema when the database file does not yet exist.

    :param app: the Flask application to configure
    :return: True if a fresh database was created by this call, else False
    """
    rv = False
    app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    app.config["WHOOSH_BASE"] = SEARCH_INDEX_PATH

    # Imported here to avoid a circular import at module load time.
    from oshipka.webapp import test_bp, oshipka_bp
    app.register_blueprint(test_bp)
    app.register_blueprint(oshipka_bp)

    db.init_app(app)
    migrate.init_app(app, db)
    security.init_app(app, user_datastore)
    register_filters(app)

    # Renamed loop variable: the original shadowed the builtin ``dir``.
    for dir_name in MAKEDIRS:
        os.makedirs(dir_name, exist_ok=True)

    if not os.path.exists(os.path.join(basepath, 'migrations')):
        with app.app_context():
            migrate_init()
    if not os.path.exists(DATABASE_FILE):
        with app.app_context():
            migrate_upgrade()
            rv = True

    # Only an attribute of the module-level Proxy is mutated, so the
    # original ``global index_service`` statement was unnecessary.
    index_service.proxied = IndexService(config=app.config, session=db.session)
    register_index_svc()
    return rv
def populate_static(app):
    """Load static CSV fixtures into the database.

    Resolves the model load order from the CSV directory, then inserts
    one row per CSV record, committing after each model's file.

    :param app: Flask application providing the db context
    """
    with app.app_context():
        models = import_module("webapp.models")
        ordered_model_names = order_from_process_order('csv', STATIC_DATA_DIR)
        for model_name in ordered_model_names:
            model = getattr(models, model_name)
            csv_path = os.path.join(STATIC_DATA_DIR, "{}.csv".format(model_name))
            # newline="" is required by the csv module so quoted embedded
            # newlines inside fields are parsed correctly.
            with open(csv_path, newline="") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    instance = model(**row)
                    db.session.add(instance)
            db.session.commit()