452 lines
17 KiB
Python
452 lines
17 KiB
Python
import csv
|
|
import datetime
|
|
import json
|
|
import os
|
|
import re
|
|
from distutils import dir_util
|
|
from importlib import import_module
|
|
from json import JSONEncoder
|
|
from uuid import uuid4
|
|
|
|
from flask import request
|
|
from flask_migrate import Migrate
|
|
from flask_migrate import init as migrate_init
|
|
from flask_migrate import upgrade as migrate_upgrade
|
|
from flask_security import RoleMixin, UserMixin
|
|
from flask_security import Security, SQLAlchemyUserDatastore
|
|
from flask_security import current_user
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
from flask_wtf import CSRFProtect
|
|
from markupsafe import escape, Markup
|
|
from sqlalchemy import Boolean
|
|
from sqlalchemy import TypeDecorator
|
|
from sqlalchemy.ext.declarative import declared_attr, DeclarativeMeta
|
|
from sqlalchemy.orm.collections import InstrumentedList
|
|
from sqlalchemy_utils import Choice
|
|
from tww.lib import solve_query, resolve_timezone, dt_tz_translation, time_ago
|
|
from whooshalchemy import IndexService
|
|
|
|
from config import SQLALCHEMY_DATABASE_URI, MAKEDIRS, DATABASE_FILE, SEARCH_INDEX_PATH, STATIC_DATA_DIR, MEDIA_DIR, \
|
|
basepath, SECURITY_ENABLED
|
|
from oshipka.util.strings import camel_case_to_snake_case
|
|
from vm_gen.vm_gen import order_from_process_order
|
|
|
|
# Flask extension singletons. They are created unbound at import time and are
# attached to the application later through their ``init_app`` methods.
db = SQLAlchemy()
migrate = Migrate()
csrf = CSRFProtect()
|
|
|
|
# Codes identifying each sharing type.
SHARING_TYPE_TYPES_TYPE_PUBLIC = "PUBLIC"
SHARING_TYPE_TYPES_TYPE_AUTHZ = "AUTHZ"
SHARING_TYPE_TYPES_TYPE_AUTHN = "AUTHN"

# (code, human-readable label) pairs, in display order.
# NOTE(review): the AUTHZ / AUTHN labels look like they may be swapped relative
# to the usual meaning of authorization vs. authentication - confirm.
SHARING_TYPE_TYPES = [
    (SHARING_TYPE_TYPES_TYPE_PUBLIC, 'public'),
    (SHARING_TYPE_TYPES_TYPE_AUTHZ, 'all logged in'),
    (SHARING_TYPE_TYPES_TYPE_AUTHN, 'some authenticated users'),
]
|
|
|
|
|
|
class Ownable:
    """Mixin that links a model to a ``User``.

    Both attributes are ``declared_attr`` factories, so the column and the
    relationship are built per model class rather than shared by the mixin.
    """

    @declared_attr
    def user_id(cls):
        """Integer column holding a foreign key to ``user.id``."""
        owner_fk = db.ForeignKey('user.id')
        return db.Column(db.Integer, owner_fk)

    @declared_attr
    def user(cls):
        """Relationship to the ``User`` model."""
        return db.relationship("User")
|
|
|
|
|
|
class Datable:
    """Mixin adding ``created_dt`` / ``updated_dt`` columns to a model."""

    # Both values are declared as UnicodeText columns, i.e. kept as text
    # rather than as a native date/time column type.
    created_dt = db.Column(db.UnicodeText())
    updated_dt = db.Column(db.UnicodeText())
|
|
|
|
|
|
class ModelJsonEncoder(JSONEncoder):
    """JSON encoder that also handles datetimes and objects exposing ``id``.

    ``datetime.datetime`` values are encoded as ``str(value)``; any other
    object the stock encoder cannot handle is encoded as its ``id`` attribute.
    """

    def default(self, o):
        """Return a JSON-encodable stand-in for ``o``.

        :param o: the object the standard encoder could not serialize
        :return: ``str(o)`` for datetimes, otherwise ``o.id``
        :raises TypeError: if ``o`` is not a datetime and has no ``id``
        """
        if isinstance(o, datetime.datetime):
            return str(o)
        try:
            return o.id
        except AttributeError:
            # Defer to the base class so callers get the TypeError that the
            # json module documents for unsupported objects, instead of an
            # unrelated AttributeError leaking out of json.dumps().
            return super().default(o)
|
|
|
|
|
|
class LiberalBoolean(TypeDecorator):
    """Boolean column type that also accepts digit strings as bound values."""

    impl = Boolean

    def process_bind_param(self, value, dialect):
        """Coerce ``value`` to ``bool`` before it is bound; ``None`` passes through.

        A value made only of digits (for example ``"0"`` or ``"1"``) is turned
        into an ``int`` first, so ``"0"`` becomes ``False``. Everything else
        falls back to ordinary Python truthiness.
        """
        if value is None:
            return value
        is_digit_string = hasattr(value, 'isdigit') and value.isdigit()
        if is_digit_string:
            return bool(int(value))
        return bool(value)
|
|
|
|
|
|
class ModelController(ModelJsonEncoder):
    """
    This interface is the parent of all models in our database.

    It contributes a table name computed from the class name, an integer
    primary key, a generated ``uuid`` column and dict serialization.
    """

    @declared_attr
    def __tablename__(self):
        return camel_case_to_snake_case(self.__name__)  # pylint: disable=E1101

    _sa_declared_attr_reg = {'__tablename__': True}
    __mapper_args__ = {'always_refresh': True}

    id = db.Column(db.Integer, primary_key=True)
    uuid = db.Column(db.Unicode, index=True, default=lambda: str(uuid4()))

    # Attribute names that serialize() must leave out; subclasses may override.
    _excluded_serialization = []

    def serialize(self, with_=None, depth=0, withs_used=None):
        """
        Serializes object to dict

        Attributes that are not JSON-encodable are left out of the result.

        Relations named in ``with_`` are expanded recursively. Once ``depth``
        reaches MAX_DEPTH, related objects are serialized without further
        expansion. A relation name already expanded at a shallower depth is
        skipped, which stops back-references from recursing forever.

        :param with_: names of relationship attributes to expand
        :param depth: current recursion depth (callers normally leave it at 0)
        :param withs_used: per-depth lists of relation names already expanded;
            shared with the recursive calls
        :return: a json-encodable dict, or ``None`` if this is not an
            SQLAlchemy declarative class
        """
        MAX_DEPTH = 3

        with_ = with_ or []
        withs_used = withs_used or []

        def expand(related):
            """Serialize a related object, or each member of a related list."""
            if related is None:
                # An unset relationship has nothing to expand. Previously this
                # raised AttributeError, which the TypeError handler below
                # does not catch.
                return None
            if depth >= MAX_DEPTH:
                nested_kwargs = {}
            else:
                nested_kwargs = dict(with_=with_, depth=depth + 1, withs_used=withs_used)
            if isinstance(related, InstrumentedList):
                return [child.serialize(**nested_kwargs) for child in related]
            return related.serialize(**nested_kwargs)

        if not isinstance(self.__class__, DeclarativeMeta):
            return None

        # an SQLAlchemy class
        always_skipped = ('metadata', 'item_separator', 'key_separator')
        candidate_fields = [
            name for name in dir(self)
            if not name.startswith('_')
            and name not in always_skipped
            and name.islower()
            and name not in self._excluded_serialization
        ]

        fields = {}
        for field in candidate_fields:
            data = getattr(self, field)
            try:
                if field in with_:
                    # this handles withs nested inside other models
                    # Pad up to the current depth so the index below is valid
                    # even when the caller starts at depth > 0.
                    while len(withs_used) < depth + 1:
                        withs_used.append([])
                    if any(field in used for used in withs_used[:depth]):
                        continue
                    withs_used[depth].append(field)
                    data = expand(data)
                if isinstance(data, datetime.datetime):
                    data = str(data)
                if isinstance(data, Choice):
                    data = data.code
                json.dumps(data)  # this will fail on non-encodable values, like other classes
                if isinstance(data, InstrumentedList):
                    # A relationship list that was not expanded but still
                    # encoded (e.g. an empty one) is left out.
                    continue  # pragma: no cover
                fields[field] = data
            except TypeError:
                pass  # Don't assign anything
        # a json-encodable dict
        return fields
|
|
|
|
|
|
# NOTE(review): the original indentation was lost. Everything below is nested
# under the flag because other code in this module only touches `security` /
# `user_datastore` / `User` / `Role` when SECURITY_ENABLED is true - confirm.
if SECURITY_ENABLED:
    # Association table for the many-to-many link between users and roles.
    roles_users = db.Table(
        'roles_users',
        db.Column('user_id', db.Integer(), db.ForeignKey('user.id')),
        db.Column('role_id', db.Integer(), db.ForeignKey('role.id')),
    )


    class Role(db.Model, ModelController, RoleMixin):
        """A named role that users can be assigned to."""

        name = db.Column(db.Unicode, unique=True)
        description = db.Column(db.Unicode)


    class User(db.Model, ModelController, UserMixin):
        """An application user account."""

        username = db.Column(db.Unicode, unique=True)
        token = db.Column(db.Unicode)

        active = db.Column(db.Boolean(), default=True)

        timezone = db.Column(db.String, default='UTC')
        tz_offset_seconds = db.Column(db.Integer, default=0)
        locale = db.Column(db.String(4), default='en')

        name = db.Column(db.Unicode)
        profile_image_url = db.Column(db.String)

        # Name of the model behind the `roles` many-to-many relation; read via
        # getattr(instance, "_m_n_table_<relation>") in update_m_ns().
        _m_n_table_roles = 'Role'

        roles = db.relationship(
            'Role',
            secondary=roles_users,
            backref=db.backref('users', lazy='dynamic'),
        )


    # Created unbound; init_db() binds it with security.init_app(app, user_datastore).
    security = Security()
    user_datastore = SQLAlchemyUserDatastore(db, User, Role)
|
|
|
|
|
|
def register_filters(app):
    """Register the Jinja template filters and globals on ``app``.

    Filters added: ``nl2br``, ``sp2nbsp``, ``format_dt``, ``to_dt``, ``to_tz``,
    ``timeago``, ``timediff`` and ``bool``. ``has_permission`` is exposed as a
    template global.

    :param app: the Flask application to register on
    """
    # Two or more consecutive line breaks separate paragraphs.
    _paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}')

    @app.template_filter('nl2br')
    def nl2br(text):
        """HTML-escape ``text`` and turn each newline into ``<br>``."""
        text = escape(text)
        result = u'<p>'.join(u'%s' % p.replace('\n', '<br>\n') for p in _paragraph_re.split(text))
        return Markup(result)

    @app.template_filter('sp2nbsp')
    def sp2nbsp(text):
        """HTML-escape ``text`` and turn each space into a non-breaking space."""
        text = escape(text)
        # The replacement is written as an escape sequence so it cannot be
        # confused with (or mangled into) an ordinary space, which would turn
        # this filter into a no-op.
        result = u'<p>'.join(u'%s' % p.replace(' ', u'\u00a0') for p in _paragraph_re.split(text))
        return Markup(result)

    @app.template_filter('format_dt')
    def format_datetime(dt, formatting="%a, %d %b %Y"):
        """
        Formats ``dt`` with any format supported by python strftime
        http://strftime.org/

        :param dt: a datetime object, or a string that ``solve_query`` can
            turn into one; falsy values are returned unchanged
        :param formatting: the strftime format to apply
        :return: the formatted string (or the falsy input itself)
        """
        if not dt:
            return dt
        if isinstance(dt, str):
            dt = solve_query(dt)
        return dt.strftime(formatting)

    @app.template_filter('to_dt')
    def to_dt(dt_str, formatting="%Y%m%d"):
        """Parse ``dt_str`` into a datetime using the strptime ``formatting``."""
        return datetime.datetime.strptime(dt_str, formatting)

    @app.template_filter('to_tz')
    def to_tz(dt, human_tz="utc", formatting='%H:%M', include_offset=True):
        """Translate ``dt`` to the timezone named by ``human_tz`` and format it.

        :param dt: a datetime, or a string that ``solve_query`` can resolve
        :param human_tz: timezone description handed to ``resolve_timezone``
        :param formatting: strftime format of the returned time
        :param include_offset: when true, append the resolved ``tz_offset``
            in parentheses
        """
        if isinstance(dt, str):
            dt = solve_query(dt)
        tz = resolve_timezone(human_tz)
        dt = dt_tz_translation(dt, to_tz_offset=tz.get('tz_offset'))
        base = dt.strftime(formatting)
        if not include_offset:
            return base
        return "{} ({})".format(base, tz.get('tz_offset'))

    @app.template_filter('timeago')
    def timeago(dt):
        """Delegate to ``time_ago(dt)``."""
        return time_ago(dt)

    @app.template_filter('timediff')
    def timediff(diff):
        """Delegate to ``time_ago(None, diff)``."""
        return time_ago(None, diff)

    @app.template_filter('bool')
    def bool_filter(v):
        """Return the truthiness of ``v``."""
        return bool(v)

    # Imported here rather than at module level, as in the original code.
    from oshipka.webapp.views import has_permission
    app.jinja_env.globals.update(has_permission=has_permission)
|
|
|
|
|
|
class Proxy:
    """Holder for an object that may be supplied later, plus a list of searchables."""

    def __init__(self, proxied):
        """Store ``proxied`` and start with an empty ``searchables`` list."""
        self.searchables = []
        self.proxied = proxied
|
|
|
|
|
|
# Module-wide holder for the search index service. It starts empty; init_db()
# assigns the real IndexService to ``index_service.proxied``.
index_service = Proxy(None)
|
|
|
|
|
|
def register_index_svc():
    """Register every class in ``index_service.searchables`` with the proxied service."""
    for model_cls in index_service.searchables:
        service = index_service.proxied
        service.register_class(model_cls)
|
|
|
|
|
|
def _init_translations(app):
    """Attach Flask-BabelEx to ``app`` and install locale / timezone selectors.

    :param app: the Flask application
    """
    from flask_babelex import Babel, Domain

    babel = Babel(app, default_domain=Domain(dirname='translations'))

    @babel.localeselector
    def get_locale():
        # if a user is logged in, use the locale from the user settings
        if current_user.is_authenticated:
            return current_user.locale
        # otherwise try to guess the language from the user accept
        # header the browser transmits
        return request.accept_languages.best_match(app.config.get('TRANSLATION_LANGUAGES', ['en']))

    @babel.timezoneselector
    def get_timezone():
        # `current_user` is a proxy object, so comparing it with `is not None`
        # is always true and anonymous visitors were asked for a `.timezone`.
        # Use the same check as get_locale() and fall through to None for them.
        if current_user.is_authenticated:
            return current_user.timezone
        return None
|
|
|
|
|
|
def init_db(app):
    """Configure ``app`` and bind the persistence-related extensions to it.

    Sets the SQLAlchemy / Whoosh / Babel config keys, registers the oshipka
    blueprints, initialises the db, CSRF, migration, security and translation
    extensions, registers the template filters, creates the configured
    directories, runs migration init / upgrade when the migrations directory
    or the database file is missing, and wires up the search index service.

    :param app: the Flask application
    :return: ``True`` if the database file was missing and ``migrate_upgrade``
        was run, ``False`` otherwise
    """
    app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    app.config["WHOOSH_BASE"] = SEARCH_INDEX_PATH
    app.config["BABEL_TRANSLATION_DIRECTORIES"] = '/translations'
    try:
        from config import TRANSLATION_LANGUAGES
    except ImportError:
        # The setting is optional in the project config; only a missing name
        # should trigger the fallback, not arbitrary errors.
        TRANSLATION_LANGUAGES = ['en']
    app.config["TRANSLATION_LANGUAGES"] = TRANSLATION_LANGUAGES

    from oshipka.webapp import test_bp, oshipka_bp
    app.register_blueprint(test_bp)
    app.register_blueprint(oshipka_bp)

    db.init_app(app)
    csrf.init_app(app)
    migrate.init_app(app, db)
    if SECURITY_ENABLED:
        security.init_app(app, user_datastore)
    # NOTE(review): the original indentation was lost, so it is unclear whether
    # this call was nested under the SECURITY_ENABLED check above. It is kept
    # unconditional here - confirm against the upstream file.
    _init_translations(app)

    register_filters(app)

    for directory in MAKEDIRS:
        os.makedirs(directory, exist_ok=True)
    if not os.path.exists(os.path.join(basepath, 'migrations')):
        with app.app_context():
            migrate_init()
    database_created = False
    if not os.path.exists(DATABASE_FILE):
        with app.app_context():
            migrate_upgrade()
        database_created = True
    # Only an attribute of the module-level proxy is assigned, so no `global`
    # statement is needed.
    index_service.proxied = IndexService(config=app.config, session=db.session)
    register_index_svc()
    return database_created
|
|
|
|
|
|
# CSV values starting with this prefix are looked up in the `sensitive` module
# by populate_static() instead of being used literally.
SENSITIVE_PREFIX = "__SENSITIVE__."

# Action names written to Permission.csv by generate_permissions().
DEFAULT_PERMISSION_PERMISSIONS = [
    'get',
    'add_user',
    'add_role',
    'remove_user',
    'remove_role',
]
DEFAULT_MODEL_PERMISSIONS = [
    'get',
    'list',
    'table',
    'search',
    'create',
    'update',
    'delete',
]
DEFAULT_COLUMN_PERMISSIONS = [
    'read',
    'write',
]
# (is_allowed value for column rows, subject name) pairs.
DEFAULT_SUBJECTS = [
    ('0', 'public'),
    ('1', 'logged'),
]
|
|
|
|
|
|
def generate_permissions():
    """(Re)write ``Permission.csv`` in STATIC_DATA_DIR from the model views.

    The file starts with a header row, followed by the admin-role rows for
    the permission actions and, for every model view except ``permission``,
    per-subject rows for permission, model and column actions.

    NOTE(review): the original indentation was lost; the loop nesting below
    was reconstructed from which variables each statement uses - confirm
    against the upstream file.
    """
    from oshipka.webapp.views import MODEL_VIEWS
    with open(os.path.join(STATIC_DATA_DIR, "Permission.csv"), 'w') as f:
        f.write("subject,subject_id,action,object,object_id,is_allowed\n")
        for permission in DEFAULT_PERMISSION_PERMISSIONS:
            f.write("role,1,permission.{},admin.permissions,,1\n".format(permission))
        for model, model_view in MODEL_VIEWS.items():
            if model in ['permission']:
                continue
            is_ownable = 'Ownable' in model_view.definitions.get('inherits', [])
            # Ownable models get an extra `owner` subject.
            subjects = DEFAULT_SUBJECTS + [('1', 'owner')] if is_ownable else DEFAULT_SUBJECTS
            f.write("role,1,permission.update,models.{},,1\n".format(model))
            f.write("role,1,permission.remove_user_self,models.{},,1\n".format(model))
            model_acls = model_view.definitions['acls']
            for perm, subject in subjects:
                for permission in DEFAULT_PERMISSION_PERMISSIONS:
                    f.write("{},,permission.{},models.{},,0\n".format(subject, permission, model))
                    f.write("role,1,permission.{},models.{},,1\n".format(permission, model))
                f.write("{},,permission.update,models.{},,0\n".format(subject, model))
                f.write("{},,permission.remove_user_self,models.{},,0\n".format(subject, model))
                if is_ownable:
                    if subject in ['owner']:
                        f.write("{},,permission.change_owner,models.{},,1\n".format(subject, model))
                    else:
                        f.write("{},,permission.change_owner,models.{},,0\n".format(subject, model))
                for permission in DEFAULT_MODEL_PERMISSIONS:
                    # TODO: TEST AND FIX THIS - VERY NAIVE RIGHT NOW!!!
                    this_perm = int(model_acls.get(permission)['authn'])
                    f.write("{},,model.{},models.{},,{}\n".format(subject, permission, model, this_perm))
                # A view without a `columns` definition contributes no column rows.
                for column in model_view.definitions.get('columns') or []:
                    column_name = column.get('name')
                    for permission in DEFAULT_COLUMN_PERMISSIONS:
                        f.write("{},,column.{}.{},columns.{},,{}\n".format(subject, column_name, permission, model, perm))
                        # This template has four placeholders. The original call
                        # also passed `subject` first, which shifted every field
                        # by one (subject as column name, model as is_allowed).
                        # NOTE(review): confirm whether the admin role row should
                        # carry `perm` or a constant 1 like the other role rows.
                        f.write("role,1,column.{}.{},columns.{},,{}\n".format(column_name, permission, model, perm))
|
|
|
|
|
|
def populate_static(app):
    """Load the static CSV fixtures from STATIC_DATA_DIR into the database.

    Copies ``STATIC_DATA_DIR/media`` into MEDIA_DIR when it exists, then, for
    each model name returned by ``order_from_process_order``, reads
    ``<ModelName>.csv`` and adds one instance per row via ``create_model``.
    Cell values starting with SENSITIVE_PREFIX are replaced by the attribute
    of that name from the ``sensitive`` module.

    NOTE(review): `dir_util` comes from `distutils`, which was removed from the
    standard library in Python 3.12.

    :param app: the Flask application whose context is used for DB access
    """
    print("populating...")
    static_media_dir = os.path.join(STATIC_DATA_DIR, "media")
    if os.path.exists(static_media_dir):
        dir_util.copy_tree(static_media_dir, MEDIA_DIR)
    with app.app_context():
        models = import_module("webapp.models")
        sensitive = import_module("sensitive")
        ordered_model_names = order_from_process_order('csv', STATIC_DATA_DIR)
        for model_name in ordered_model_names:
            if SECURITY_ENABLED and model_name in ('User', 'Role'):
                # These two models live in this module, not in webapp.models.
                # An explicit mapping replaces the former eval(model_name).
                model = {'User': User, 'Role': Role}[model_name]
            else:
                if SECURITY_ENABLED and model_name == 'Permission':
                    # Regenerate Permission.csv before it is read below.
                    generate_permissions()
                model = getattr(models, model_name)
            csv_path = os.path.join(STATIC_DATA_DIR, "{}.csv".format(model_name))
            # newline='' is what the csv module asks for so that newlines
            # embedded in quoted fields are read correctly.
            with open(csv_path, newline='') as f:
                # NOTE(review): create_model() currently ignores its `user`
                # argument, so this lookup has no effect on created rows.
                user = User.query.first() if issubclass(model, Ownable) else None
                for row in csv.DictReader(f):
                    # The text after the prefix names an attribute of the
                    # `sensitive` module that holds the real value.
                    resolved = {
                        key: getattr(sensitive, value[len(SENSITIVE_PREFIX):])
                        for key, value in row.items()
                        if value and value.startswith(SENSITIVE_PREFIX)
                    }
                    if resolved:
                        row.update(resolved)
                    instance = create_model(model, user, row)
                    db.session.add(instance)
            # NOTE(review): the original indentation was lost; committing once
            # per model file is an assumption - confirm.
            db.session.commit()
    print("Finished populating")
|
|
|
|
|
|
def filter_m_n(serialized_args):
    """Pick the many-to-many entries out of ``serialized_args``.

    Keys starting with ``_m_n_`` are treated as many-to-many fields; the rest
    of the key after that prefix is the relation name.

    :param serialized_args: mapping of field name to value (not modified)
    :return: ``(m_ns, to_delete)`` where ``m_ns`` maps relation name to the
        original value and ``to_delete`` lists the prefixed keys that matched
    """
    prefix = '_m_n_'
    m_ns, to_delete = {}, []
    for k in serialized_args:
        if k.startswith(prefix):
            # Strip only the leading prefix. Splitting on it would truncate a
            # relation name that happens to contain the marker again.
            m_ns[k[len(prefix):]] = serialized_args[k]
            to_delete.append(k)
    return m_ns, to_delete
|
|
|
|
|
|
def update_m_ns(instance, m_ns):
    """Set the many-to-many relations of ``instance`` from lists of ids.

    For each relation name in ``m_ns`` the target model is taken from
    ``instance._m_n_table_<relation>`` (looked up on ``webapp_models``), except
    for ``roles``, which always maps to ``Role``. The rows whose ``id`` is in
    the given ids are queried and assigned to the relation attribute.

    :param instance: the model instance to update
    :param m_ns: mapping of relation name to an iterable of ids
    """
    from oshipka.webapp.views import webapp_models

    for relation_name, ids in m_ns.items():
        target_name = getattr(instance, "_m_n_table_{}".format(relation_name))
        if relation_name == 'roles':
            target_model = Role
        else:
            target_model = getattr(webapp_models, target_name)
        query = db.session.query(target_model).filter(target_model.id.in_(ids))
        setattr(instance, relation_name, query.all())
|
|
|
|
|
|
def create_model(model, user, serialized_args):
    """Build an instance of ``model`` from ``serialized_args``.

    Many-to-many entries (keys prefixed ``_m_n_``) are removed from
    ``serialized_args`` in place; their comma-separated id strings are split
    and applied through ``update_m_ns``. Every remaining key is set as an
    attribute on the new instance.

    :param model: the model class to instantiate (called without arguments)
    :param user: accepted for the callers' sake; not used by this function
    :param serialized_args: mutable mapping of field name to value
    :return: the new, not yet persisted, instance
    """
    relations, relation_keys = filter_m_n(serialized_args)
    for relation_key in relation_keys:
        del serialized_args[relation_key]

    instance = model()
    for attr_name, attr_value in serialized_args.items():
        setattr(instance, attr_name, attr_value)

    # "1,2,3" -> ["1", "2", "3"] for every relation.
    relations = {name: ids.split(',') for name, ids in relations.items()}
    update_m_ns(instance, relations)
    return instance
|