From 5e0bd4c1281a0c1cef7c1e8c3319cf03adf722ae Mon Sep 17 00:00:00 2001 From: Daniel Tsvetkov Date: Sat, 25 Apr 2020 12:39:38 +0200 Subject: [PATCH] util changes --- oshipka/persistance/__init__.py | 102 +++++++++++++++++++++++++++++++- oshipka/webapp/views.py | 8 +-- 2 files changed, 104 insertions(+), 6 deletions(-) diff --git a/oshipka/persistance/__init__.py b/oshipka/persistance/__init__.py index 4b9e3e4..0ae08d7 100644 --- a/oshipka/persistance/__init__.py +++ b/oshipka/persistance/__init__.py @@ -1,4 +1,5 @@ import datetime +import json import os import re from json import JSONEncoder @@ -6,10 +7,14 @@ from uuid import uuid4 from flask_sqlalchemy import SQLAlchemy from flask_security import Security, SQLAlchemyUserDatastore -from sqlalchemy.ext.declarative import declared_attr +from markupsafe import escape, Markup +from sqlalchemy.ext.declarative import declared_attr, DeclarativeMeta from flask_security import RoleMixin, UserMixin from config import SQLALCHEMY_DATABASE_URI, MAKEDIRS, DATABASE_FILE +from sqlalchemy.orm.collections import InstrumentedList +from sqlalchemy_utils import Choice +from tww.tww import solve_query db = SQLAlchemy() @@ -61,6 +66,74 @@ class ModelController(ModelJsonEncoder): id = db.Column(db.Integer, primary_key=True) uuid = db.Column(db.Unicode, default=str(uuid4())) + _excluded_serialization = [] + + def serialize(self, with_=None, depth=0, withs_used=None): + """ + Serializes object to dict + It will ignore fields that are not encodable (set them to 'None'). + + It expands relations mentioned in with_ recursively up to MAX_DEPTH and tries to smartly ignore + recursions by mentioning which with elements have already been used in previous depths + :return: + """ + + MAX_DEPTH = 3 + + def with_used_in_prev_depth(field, previous_depths): + for previous_depth in previous_depths: + if field in previous_depth: + return True + return False + + def handle_withs(data, with_, depth, withs_used): + if isinstance(data, InstrumentedList): + if depth >= MAX_DEPTH: + return [e.serialize_flat() for e in data] + else: + return [e.serialize_flat(with_=with_, depth=depth + 1, withs_used=withs_used) for e in data] + else: + if depth >= MAX_DEPTH: + return data.serialize_flat() + else: + return data.serialize_flat(with_=with_, depth=depth + 1, withs_used=withs_used) + + if not with_: + with_ = [] + if not withs_used: + withs_used = [] + if isinstance(self.__class__, DeclarativeMeta): + # an SQLAlchemy class + fields = {} + iterable_fields = [x for x in dir(self) if not x.startswith('_') and x not in ['metadata', + 'item_separator', + 'key_separator'] and x.islower() + and x not in self._excluded_serialization] + for field in iterable_fields: + data = self.__getattribute__(field) + try: + if field in with_: + # this hanldes withs nested inside other models + if len(withs_used) < depth + 1: + withs_used.append([]) + previous_depths = withs_used[:depth] + if with_used_in_prev_depth(field, previous_depths): + continue + withs_used[depth].append(field) + data = handle_withs(data, with_, depth, withs_used) + if isinstance(data, datetime.datetime): + data = str(data) + if isinstance(data, Choice): + data = data.code + json.dumps(data) # this will fail on non-encodable values, like other classes + if isinstance(data, InstrumentedList): + continue # pragma: no cover + fields[field] = data + except TypeError: + pass # Don't assign anything + # a json-encodable dict + return fields + class Role(db.Model, ModelController, RoleMixin): name = db.Column(db.Unicode, unique=True) @@ -89,6 +162,29 @@ security = Security() user_datastore = SQLAlchemyUserDatastore(db, User, Role) +def register_filters(app): + _paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}') + + @app.template_filter('nl2br') + def nl2br(text): + text = escape(text) + result = u'
'.join(u'%s' % p.replace('\n', '
\n') for p in _paragraph_re.split(text)) + return Markup(result) + + @app.template_filter('format_dt') + def format_datetime(dt, formatting="%A, %d %b %Y"): + """ + Formats the datetime string provided in value into whatever format you want that is supported by python strftime + http://strftime.org/ + :param formatting The specific format of the datetime + :param dt a datetime object + :return: + """ + if type(dt) is str: + dt = solve_query(dt) + return dt.strftime(formatting) + + def init_db(app): app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False @@ -100,9 +196,11 @@ def init_db(app): db.init_app(app) security.init_app(app, user_datastore) + register_filters(app) + for dir in MAKEDIRS: os.makedirs(dir, exist_ok=True) if not os.path.exists(DATABASE_FILE): with app.app_context(): db.create_all() - return True \ No newline at end of file + return True diff --git a/oshipka/webapp/views.py b/oshipka/webapp/views.py index 4be3449..338b203 100644 --- a/oshipka/webapp/views.py +++ b/oshipka/webapp/views.py @@ -14,14 +14,14 @@ def list_view(model_view, template): return inner -def get_view(model_view, template): +def get_view(model_view, template, template_ctx): def inner(uuid): model = model_view.model instance = model.query.filter_by(uuid=uuid).first() if not instance: flash("No {}:{}".format(model_view.model_name, uuid)) return redirect(request.referrer or url_for('home')) - return render_template(template, instance=instance) + return render_template(template, instance=instance, **template_ctx) return inner @@ -113,11 +113,11 @@ class ModelView(object): 'list_{}'.format(self.model_name), list_view(self, list_template)) - def register_get(self, retrieve_template): + def register_get(self, retrieve_template, template_ctx=None): url = '/{}/'.format(self.model_name_pl) self.app.add_url_rule(url, 'get_{}'.format(self.model_name), - get_view(self, retrieve_template)) + get_view(self, retrieve_template, template_ctx)) def register_update(self, update_template): url = '/{}//edit'.format(self.model_name_pl)