diff --git a/oshipka/persistance/__init__.py b/oshipka/persistance/__init__.py
index 4b9e3e4..0ae08d7 100644
--- a/oshipka/persistance/__init__.py
+++ b/oshipka/persistance/__init__.py
@@ -1,4 +1,5 @@
import datetime
+import json
import os
import re
from json import JSONEncoder
@@ -6,10 +7,14 @@ from uuid import uuid4
from flask_sqlalchemy import SQLAlchemy
from flask_security import Security, SQLAlchemyUserDatastore
-from sqlalchemy.ext.declarative import declared_attr
+from markupsafe import escape, Markup
+from sqlalchemy.ext.declarative import declared_attr, DeclarativeMeta
from flask_security import RoleMixin, UserMixin
from config import SQLALCHEMY_DATABASE_URI, MAKEDIRS, DATABASE_FILE
+from sqlalchemy.orm.collections import InstrumentedList
+from sqlalchemy_utils import Choice
+from tww.tww import solve_query
db = SQLAlchemy()
@@ -61,6 +66,74 @@ class ModelController(ModelJsonEncoder):
id = db.Column(db.Integer, primary_key=True)
uuid = db.Column(db.Unicode, default=str(uuid4()))
+ _excluded_serialization = []
+
+ def serialize(self, with_=None, depth=0, withs_used=None):
+ """
+ Serializes object to dict
+ It will ignore fields that are not encodable (set them to 'None').
+
+ It expands relations mentioned in with_ recursively up to MAX_DEPTH and tries to smartly ignore
+ recursions by mentioning which with elements have already been used in previous depths
+ :return:
+ """
+
+ MAX_DEPTH = 3
+
+ def with_used_in_prev_depth(field, previous_depths):
+ for previous_depth in previous_depths:
+ if field in previous_depth:
+ return True
+ return False
+
+ def handle_withs(data, with_, depth, withs_used):
+ if isinstance(data, InstrumentedList):
+ if depth >= MAX_DEPTH:
+ return [e.serialize_flat() for e in data]
+ else:
+ return [e.serialize_flat(with_=with_, depth=depth + 1, withs_used=withs_used) for e in data]
+ else:
+ if depth >= MAX_DEPTH:
+ return data.serialize_flat()
+ else:
+ return data.serialize_flat(with_=with_, depth=depth + 1, withs_used=withs_used)
+
+ if not with_:
+ with_ = []
+ if not withs_used:
+ withs_used = []
+ if isinstance(self.__class__, DeclarativeMeta):
+ # an SQLAlchemy class
+ fields = {}
+ iterable_fields = [x for x in dir(self) if not x.startswith('_') and x not in ['metadata',
+ 'item_separator',
+ 'key_separator'] and x.islower()
+ and x not in self._excluded_serialization]
+ for field in iterable_fields:
+ data = self.__getattribute__(field)
+ try:
+ if field in with_:
+ # this hanldes withs nested inside other models
+ if len(withs_used) < depth + 1:
+ withs_used.append([])
+ previous_depths = withs_used[:depth]
+ if with_used_in_prev_depth(field, previous_depths):
+ continue
+ withs_used[depth].append(field)
+ data = handle_withs(data, with_, depth, withs_used)
+ if isinstance(data, datetime.datetime):
+ data = str(data)
+ if isinstance(data, Choice):
+ data = data.code
+ json.dumps(data) # this will fail on non-encodable values, like other classes
+ if isinstance(data, InstrumentedList):
+ continue # pragma: no cover
+ fields[field] = data
+ except TypeError:
+ pass # Don't assign anything
+ # a json-encodable dict
+ return fields
+
class Role(db.Model, ModelController, RoleMixin):
name = db.Column(db.Unicode, unique=True)
@@ -89,6 +162,29 @@ security = Security()
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
+def register_filters(app):
+ _paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}')
+
+ @app.template_filter('nl2br')
+ def nl2br(text):
+ text = escape(text)
+ result = u'
'.join(u'%s' % p.replace('\n', '
\n') for p in _paragraph_re.split(text))
+ return Markup(result)
+
+ @app.template_filter('format_dt')
+ def format_datetime(dt, formatting="%A, %d %b %Y"):
+ """
+ Formats the datetime string provided in value into whatever format you want that is supported by python strftime
+ http://strftime.org/
+ :param formatting The specific format of the datetime
+ :param dt a datetime object
+ :return:
+ """
+ if type(dt) is str:
+ dt = solve_query(dt)
+ return dt.strftime(formatting)
+
+
def init_db(app):
app.config["SQLALCHEMY_DATABASE_URI"] = SQLALCHEMY_DATABASE_URI
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
@@ -100,9 +196,11 @@ def init_db(app):
db.init_app(app)
security.init_app(app, user_datastore)
+ register_filters(app)
+
for dir in MAKEDIRS:
os.makedirs(dir, exist_ok=True)
if not os.path.exists(DATABASE_FILE):
with app.app_context():
db.create_all()
- return True
\ No newline at end of file
+ return True
diff --git a/oshipka/webapp/views.py b/oshipka/webapp/views.py
index 4be3449..338b203 100644
--- a/oshipka/webapp/views.py
+++ b/oshipka/webapp/views.py
@@ -14,14 +14,14 @@ def list_view(model_view, template):
return inner
-def get_view(model_view, template):
+def get_view(model_view, template, template_ctx):
def inner(uuid):
model = model_view.model
instance = model.query.filter_by(uuid=uuid).first()
if not instance:
flash("No {}:{}".format(model_view.model_name, uuid))
return redirect(request.referrer or url_for('home'))
- return render_template(template, instance=instance)
+ return render_template(template, instance=instance, **template_ctx)
return inner
@@ -113,11 +113,11 @@ class ModelView(object):
'list_{}'.format(self.model_name),
list_view(self, list_template))
- def register_get(self, retrieve_template):
+ def register_get(self, retrieve_template, template_ctx=None):
url = '/{}/'.format(self.model_name_pl)
self.app.add_url_rule(url,
'get_{}'.format(self.model_name),
- get_view(self, retrieve_template))
+ get_view(self, retrieve_template, template_ctx))
def register_update(self, update_template):
url = '/{}//edit'.format(self.model_name_pl)