From 3a539d698dcb36979e16357862cb5172682c049d Mon Sep 17 00:00:00 2001 From: Daniel Tsvetkov Date: Sun, 9 May 2021 17:27:22 +0200 Subject: [PATCH] test permissions after sso --- oshipka/persistance/__init__.py | 12 +++++++++--- oshipka/webapp/views.py | 9 ++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/oshipka/persistance/__init__.py b/oshipka/persistance/__init__.py index fd9ee9b..334e6ff 100644 --- a/oshipka/persistance/__init__.py +++ b/oshipka/persistance/__init__.py @@ -252,7 +252,7 @@ def register_filters(app): model_acl = model_view.model_acl # Anonymous user -> check public ACL if current_user.is_anonymous: - instance_acl = model_acl.query.filter_by(user=current_user, instance=instance, + instance_acl = model_acl.query.filter_by(instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_PUBLIC).first() else: # Logged in user -> find (user, instance) pair @@ -374,9 +374,12 @@ def populate_static(app): for model_name in ordered_model_names: if SECURITY_ENABLED and model_name in ['User', 'Role']: model = eval(model_name) + model_acl = None else: model = getattr(models, model_name) + model_acl = getattr(models, model_name + 'Acl') with open(os.path.join(STATIC_DATA_DIR, "{}.csv".format(model_name))) as f: + user = User.query.first() reader = csv.DictReader(f) for row in reader: row_updates = dict() @@ -387,7 +390,7 @@ def populate_static(app): row_updates[key] = sensitive_value if row_updates: row.update(row_updates) - instance = create_model(model, row) + instance = create_model(model, model_acl, user, row) db.session.add(instance) db.session.commit() print("Finished populating") @@ -413,7 +416,8 @@ def update_m_ns(instance, m_ns): setattr(instance, key, children) -def create_model(model, serialized_args): +def create_model(model, model_acl, user, serialized_args): + from oshipka.webapp.views import create_acls m_ns, to_delete = filter_m_n(serialized_args) for key in to_delete: del serialized_args[key] @@ -423,4 +427,6 @@ def create_model(model, serialized_args): for key, ids in m_ns.items(): m_ns[key] = ids.split(',') update_m_ns(instance, m_ns) + if model_acl and user: + create_acls(model_acl, instance, user) return instance diff --git a/oshipka/webapp/views.py b/oshipka/webapp/views.py index 610a145..097753e 100644 --- a/oshipka/webapp/views.py +++ b/oshipka/webapp/views.py @@ -129,10 +129,13 @@ def default_create_func(vc): instance = vc.instances or vc.model_view.model() vc.instances = [instance] default_update_func(vc) + create_acls(vc.model_view.model_acl, instance, current_user) - instance_public_acl = vc.model_view.model_acl(user=current_user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_PUBLIC) - instance_authn_acl = vc.model_view.model_acl(user=current_user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_AUTHN) - instance_authz_acl = vc.model_view.model_acl(user=current_user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_AUTHZ) + +def create_acls(model_acl, instance, user): + instance_public_acl = model_acl(user=user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_PUBLIC) + instance_authn_acl = model_acl(user=user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_AUTHN) + instance_authz_acl = model_acl(user=user, instance=instance, acl_type=SHARING_TYPE_TYPES_TYPE_AUTHZ) db.session.add(instance_public_acl) db.session.add(instance_authn_acl) db.session.add(instance_authz_acl)