diff --git a/oshipka/persistance/__init__.py b/oshipka/persistance/__init__.py index 226abd4..3ace074 100644 --- a/oshipka/persistance/__init__.py +++ b/oshipka/persistance/__init__.py @@ -194,18 +194,18 @@ class User(db.Model, ModelController, UserMixin): backref=db.backref('users', lazy='dynamic')) -class U2FCredential(db.Model, ModelController): +class Credential(db.Model, ModelController): name = db.Column(db.Unicode) date_added = db.Column(db.DateTime) device = db.Column(db.Unicode) - user_sso_id = db.Column(db.Integer, db.ForeignKey('user_sso.id')) - user_sso = db.relationship('UserSSO', - backref=db.backref("u2f_credentials"), + user_sso_id = db.Column(db.Integer, db.ForeignKey('user_s.id')) + user_sso = db.relationship('UserS', + backref=db.backref("credentials"), ) -class UserSSO(db.Model, ModelController): +class UserS(db.Model, ModelController): username = db.Column(db.Unicode, unique=True) email = db.Column(db.Unicode, unique=True) password_hash = db.Column(db.Unicode) @@ -213,7 +213,7 @@ class UserSSO(db.Model, ModelController): otp_secret = db.Column(db.String(16)) def __init__(self, **kwargs): - super(UserSSO, self).__init__(**kwargs) + super(UserS, self).__init__(**kwargs) if self.otp_secret is None: # generate a random secret self.otp_secret = base64.b32encode(os.urandom(10)).decode('utf-8') diff --git a/oshipka/util/strings.py b/oshipka/util/strings.py index e0a63ad..22cdb5f 100644 --- a/oshipka/util/strings.py +++ b/oshipka/util/strings.py @@ -1,4 +1,6 @@ import re +import random +import string def camel_case_to_snake_case(name): @@ -17,4 +19,11 @@ def snake_case_to_camel_case(name): :param name: the name to be converted :return: """ - return ''.join(x.title() for x in name.split('_')) \ No newline at end of file + return ''.join(x.title() for x in name.split('_')) + + +def random_string_generator(str_size=30, allowed_chars=None): + if not allowed_chars: + allowed_chars = string.ascii_letters + string.digits + return ''.join(random.choice(allowed_chars) for _ in range(str_size)) + diff --git a/oshipka/webapp/default_routes.py b/oshipka/webapp/default_routes.py index fde7544..92a0689 100644 --- a/oshipka/webapp/default_routes.py +++ b/oshipka/webapp/default_routes.py @@ -1,8 +1,9 @@ import urllib import requests -from flask import send_from_directory, redirect, request, url_for +from flask import send_from_directory, redirect, request, url_for, session, jsonify +from oshipka.util.strings import random_string_generator from oshipka.webapp import oshipka_bp from config import MEDIA_DIR, APP_BASE_URL from sensitive import SSO_CLIENT_ID, SSO_CLIENT_SECRET @@ -15,37 +16,56 @@ def get_media(filepath): SSO_BASE_URL = 'http://localhost:5008' +SSO_AUTH_URL = '/oidc/auth' +SSO_TOKEN_URL = '/oidc/token' +SSO_USERINFO_URL = "/endpoints/userinfo" @oshipka_bp.route('/sso') def sso(): - callback_url = APP_BASE_URL + url_for('oshipka_bp.oidc_code') - return redirect(SSO_BASE_URL + '/authenticate?callback={}&client_id={}'.format( - urllib.parse.quote(callback_url), - SSO_CLIENT_ID, - )) + callback_url = APP_BASE_URL + url_for('oshipka_bp.oidc_callback') + state = random_string_generator() + session['oidc_state'] = state + params = urllib.parse.urlencode({ + 'redirect_uri': callback_url, + 'client_id': SSO_CLIENT_ID, + 'state': state, + 'scope': 'openid', + 'response_type': 'code', + 'nonce': random_string_generator(), + }) + return redirect(SSO_BASE_URL + SSO_AUTH_URL + '?' + params) -@oshipka_bp.route('/oidc/code') -def oidc_code(): +@oshipka_bp.route('/oidc/callback') +def oidc_callback(): + error = request.args.get('error') + if error: + return jsonify({"error": "from auth server: {}".format(error)}), 400 + state = request.args.get('state') + session_state = session['oidc_state'] + if state != session_state: + return jsonify({"error": "state is different from session state"}), 400 code = request.args.get('code') - # TODO : client_id and client_secret are passed in Authorization header - # https://connect2id.com/learn/openid-connect - response = requests.get( - SSO_BASE_URL + "/oidc/token", - params={ + response = requests.post( + SSO_BASE_URL + SSO_TOKEN_URL, + data={ 'code': code, 'client_id': SSO_CLIENT_ID, 'client_secret': SSO_CLIENT_SECRET, + 'grant_type': 'authorization_code' }, ) if response.status_code == 200: response_json = response.json() access_token = response_json.get('access_token') response = requests.get( - SSO_BASE_URL + "/endpoints/user", + SSO_BASE_URL + SSO_USERINFO_URL, headers={ 'Authorization': "Bearer {}".format(access_token) }, ) + if response.status_code == 200: + return response.json() + return 'got code for userinfo: {}'.format(response.status_code) return 'got response for token: {}'.format(response.status_code)