예제 #1
0
def create_session_db(ec,
                      token_handler_args,
                      db=None,
                      sso_db=None,
                      sub_func=None):
    """Assemble a ``SessionDB`` from a token handler and its backing stores.

    :param ec: endpoint context, passed as the first argument to
        ``token_handler.factory``.
    :param token_handler_args: keyword arguments forwarded to
        ``token_handler.factory``.
    :param db: session storage; any falsy value is replaced by a new
        ``InMemoryDataBase``.
    :param sso_db: SSO storage; any falsy value is replaced by a new ``SSODb``.
    :param sub_func: subject-identifier minters, forwarded to ``SessionDB``.
    :return: the newly created ``SessionDB``.
    """
    handler = token_handler.factory(ec, **token_handler_args)
    if not db:
        db = InMemoryDataBase()
    if not sso_db:
        sso_db = SSODb()
    return SessionDB(db, handler, sso_db, sub_func=sub_func)
예제 #2
0
    def set_session_db(self, sso_db=None, db=None):
        """Build the session database, then attach the userinfo database.

        Explicit ``sso_db`` / ``db`` arguments win. Otherwise the ``"sso_db"``
        and ``"session_db"`` entries of ``self.conf`` (each a mapping with an
        importable ``"class"`` path and optional ``"kwargs"``) are used to
        instantiate the backing stores. When neither an argument nor a config
        entry provides an SSO store, a plain ``SSODb()`` is used.
        """
        sso_spec = self.conf.get("sso_db") if sso_db is None else None
        if sso_spec:
            sso_kwargs = sso_spec.get("kwargs", {})
            sso_backend = importer(sso_spec["class"])(**sso_kwargs)
            sso_db = SSODb(sso_backend)
        elif not sso_db:
            sso_db = SSODb()

        if db is None:
            db_spec = self.conf.get("session_db")
            if db_spec:
                db_kwargs = db_spec.get("kwargs", {})
                db = importer(db_spec["class"])(**db_kwargs)

        self.do_session_db(sso_db, db)
        # Hook the userinfo database onto the freshly created session DB.
        self.do_userinfo()
        logger.debug("Session DB: {}".format(self.sdb.__dict__))
예제 #3
0
def create_session_db(password, token_expires_in=3600,
                      grant_expires_in=600, refresh_token_expires_in=86400,
                      db=None, sso_db=None):
    """Create a ``SessionDB`` with a password-based token handler.

    :param password: secret passed to ``token_handler.factory``.
    :param token_expires_in: access token lifetime, forwarded to the factory.
    :param grant_expires_in: grant (code) lifetime, forwarded to the factory.
    :param refresh_token_expires_in: refresh token lifetime, forwarded to the
        factory.
    :param db: session storage; a falsy value is replaced by a new
        ``InMemoryDataBase``.
    :param sso_db: SSO storage; when ``None`` a new ``SSODb`` is created for
        this call.
    :return: the new ``SessionDB``.
    """
    _token_handler = token_handler.factory(
        password, token_expires_in, grant_expires_in, refresh_token_expires_in)

    if not db:
        db = InMemoryDataBase()

    # Create the SSO store per call. A ``SSODb()`` default in the signature is
    # evaluated only once, so every session DB would share the same store.
    if sso_db is None:
        sso_db = SSODb()

    return SessionDB(db, _token_handler, sso_db)
예제 #4
0
    def __init__(self, db, handler, sso_db=None, userinfo=None, sub_func=None):
        """Initialise the session database.

        :param db: session storage; must implement the InMemoryDataBase
            interface.
        :param handler: token handler stored as ``self.handler``.
        :param sso_db: SSO store; when ``None`` a new ``SSODb`` is created for
            this instance.
        :param userinfo: optional userinfo source stored as ``self.userinfo``.
        :param sub_func: optional mapping of subject-type name to minter. The
            ``"public"`` and ``"pairwise"`` entries fall back to ``public_id``
            and ``pairwise_id`` when absent.
        """
        # db must implement the InMemoryDataBase interface
        self._db = db
        self.handler = handler
        # One SSO store per instance. A ``SSODb()`` default in the signature
        # would be created once and shared by every SessionDB.
        self.sso_db = SSODb() if sso_db is None else sso_db
        self.userinfo = userinfo

        # This allows the subject identifier minters to be defined by someone
        # other than me.
        if sub_func is None:
            self.sub_func = {"public": public_id, "pairwise": pairwise_id}
        else:
            # Copy so that filling in the defaults does not mutate the
            # caller's mapping.
            self.sub_func = dict(sub_func)
            if "public" not in self.sub_func:
                self.sub_func["public"] = public_id
            if "pairwise" not in self.sub_func:
                self.sub_func["pairwise"] = pairwise_id
예제 #5
0
    def create_sdb(self):
        """Populate ``self.sdb`` with an in-memory session DB and file-based users.

        All three token types (code, token, refresh) share one random 24-char
        password and get lifetimes of 600, 3600 and 86400 respectively.
        """
        sso_store = SSODb()
        secret = rndstr(24)
        lifetimes = (("code", 600), ("token", 3600), ("refresh", 86400))
        handler_args = {
            kind: {"lifetime": seconds, "password": secret}
            for kind, seconds in lifetimes
        }

        handler = token_handler.factory(None, **handler_args)
        users = UserInfo(db_file=full_path("users.json"))
        self.sdb = SessionDB(InMemoryDataBase(), handler, sso_store, users)
예제 #6
0
 def create_sdb(self):
     """Set ``self.sdb`` to an in-memory session DB with a fresh SSO store."""
     sso_store = SSODb()
     handler = token_handler.factory('losenord')
     storage = InMemoryDataBase()
     self.sdb = SessionDB(storage, handler, sso_store)
예제 #7
0
    def __init__(
        self,
        conf,
        keyjar=None,
        client_db=None,
        session_db=None,
        cwd="",
        cookie_dealer=None,
        httpc=None,
        cookie_name=None,
        jwks_uri_path=None,
    ):
        """Set up the endpoint context from a configuration mapping.

        :param conf: configuration dictionary. Keys read here include
            ``seed``, ``cookie_name``, ``verify_ssl``, ``issuer``,
            ``sso_ttl``, ``symkey``, ``client_authn``, ``id_token_schema``,
            ``template_handler`` / ``template_loader`` / ``template_dir``,
            ``jwks``, ``cookie_dealer``, ``sub_func``,
            ``token_handler_args``, ``endpoint``, ``capabilities``,
            ``authz``, ``authentication``, ``id_token``, ``userinfo``,
            ``login_hint_lookup`` and ``login_hint2acrs``.
        :param keyjar: key jar; a new ``KeyJar`` is used when falsy.
        :param client_db: client database; defaults to an empty dict.
        :param session_db: pre-built session database; when falsy one is
            created with ``create_session_db``.
        :param cwd: working directory handed to ``init_user_info``.
        :param cookie_dealer: cookie dealer instance; must not be combined
            with a ``cookie_dealer`` configuration entry.
        :param httpc: HTTP client; defaults to the ``requests`` module.
        :param cookie_name: cookie name mapping; overrides ``conf``.
        :param jwks_uri_path: path of the JWKS document relative to the
            issuer; falls back to ``conf["jwks"]["uri_path"]``.
        :raises ValueError: if a cookie dealer is both passed in and
            configured.
        """
        self.conf = conf
        self.keyjar = keyjar or KeyJar()
        self.cwd = cwd

        # client database
        self.cdb = client_db or {}

        try:
            self.seed = bytes(conf["seed"], "utf-8")
        except KeyError:
            self.seed = bytes(rndstr(16), "utf-8")

        # Default values, to be changed below depending on configuration
        self.endpoint = {}
        self.issuer = ""
        self.httpc = httpc or requests
        self.verify_ssl = True
        self.jwks_uri = None
        self.sso_ttl = 14400  # 4h
        self.symkey = rndstr(24)
        self.id_token_schema = IdToken
        self.endpoint_to_authn_method = {}
        self.cookie_dealer = cookie_dealer
        self.login_hint_lookup = None

        # Precedence: explicit argument, then configuration, then built-ins.
        if cookie_name:
            self.cookie_name = cookie_name
        elif "cookie_name" in conf:
            self.cookie_name = conf["cookie_name"]
        else:
            self.cookie_name = {
                "session": "oidcop",
                "register": "oidc_op_rp",
                "session_management": "sman",
            }

        # Optional settings copied verbatim from the configuration.
        for param in [
                "verify_ssl",
                "issuer",
                "sso_ttl",
                "symkey",
                "client_authn",
                "id_token_schema",
        ]:
            try:
                setattr(self, param, conf[param])
            except KeyError:
                pass

        # Template handler: a ready-made handler, or else a Jinja2 handler
        # built from a given loader or from ``template_dir``.
        try:
            self.template_handler = conf["template_handler"]
        except KeyError:
            try:
                loader = conf["template_loader"]
            except KeyError:
                template_dir = conf["template_dir"]
                loader = Environment(loader=FileSystemLoader(template_dir),
                                     autoescape=True)
            self.template_handler = Jinja2TemplateHandler(loader)

        self.setup = {}
        if not jwks_uri_path:
            try:
                jwks_uri_path = conf["jwks"]["uri_path"]
            except KeyError:
                pass

        # Only publish a JWKS URI when a path is known. Formatting a missing
        # path would produce "<issuer>/None".
        if jwks_uri_path:
            if self.issuer.endswith("/"):
                self.jwks_uri = "{}{}".format(self.issuer, jwks_uri_path)
            else:
                self.jwks_uri = "{}/{}".format(self.issuer, jwks_uri_path)
        else:
            self.jwks_uri = ""

        # self.keyjar is always set above (``keyjar or KeyJar()``). A key jar
        # without owners is populated from the "jwks" configuration.
        if self.keyjar is None or self.keyjar.owners() == []:
            args = {k: v for k, v in conf["jwks"].items() if k != "uri_path"}
            self.keyjar = init_key_jar(**args)

        try:
            _conf = conf["cookie_dealer"]
        except KeyError:
            pass
        else:
            if self.cookie_dealer:  # already defined
                raise ValueError("Cookie Dealer already defined")
            self.cookie_dealer = init_service(_conf)

        # Optional subject-identifier minters. Each entry is either a service
        # spec ("class") or a callable / import path ("function").
        try:
            _conf = conf["sub_func"]
        except KeyError:
            sub_func = None
        else:
            sub_func = {}
            for key, args in _conf.items():
                if "class" in args:
                    sub_func[key] = init_service(args)
                elif "function" in args:
                    if isinstance(args["function"], str):
                        sub_func[key] = util.importer(args["function"])
                    else:
                        sub_func[key] = args["function"]

        if session_db:
            self.sdb = session_db
        else:
            try:
                _th_args = conf["token_handler_args"]
            except KeyError:
                # create 3 keys
                keydef = [
                    {
                        "type": "oct",
                        "bytes": "24",
                        "use": ["enc"],
                        "kid": "code"
                    },
                    {
                        "type": "oct",
                        "bytes": "24",
                        "use": ["enc"],
                        "kid": "token"
                    },
                    {
                        "type": "oct",
                        "bytes": "24",
                        "use": ["enc"],
                        "kid": "refresh"
                    },
                ]

                jwks_def = {
                    "private_path": "private/token_jwks.json",
                    "key_defs": keydef,
                    "read_only": False,
                }

                _th_args = {"jwks_def": jwks_def}
                for typ, tid in [("code", 600), ("token", 3600),
                                 ("refresh", 86400)]:
                    _th_args[typ] = {"lifetime": tid}

            self.sdb = create_session_db(self,
                                         _th_args,
                                         db=None,
                                         sso_db=SSODb(),
                                         sub_func=sub_func)

        self.endpoint = build_endpoints(
            conf["endpoint"],
            endpoint_context=self,
            client_authn_method=CLIENT_AUTHN_METHOD,
            issuer=conf["issuer"],
        )
        try:
            _cap = conf["capabilities"]
        except KeyError:
            _cap = {}

        # Merge each endpoint's provider info into the capabilities and
        # advertise its path (except for webfinger/provider_info).
        for endpoint, endpoint_instance in self.endpoint.items():
            if endpoint_instance.provider_info:
                _cap.update(endpoint_instance.provider_info)

            if endpoint in ["webfinger", "provider_info"]:
                continue

            _cap[endpoint_instance.endpoint_name] = "{}".format(
                endpoint_instance.endpoint_path)

        try:
            authz_spec = conf["authz"]
        except KeyError:
            self.authz = authz.Implicit(self)
        else:
            self.authz = init_service(authz_spec, self)

        try:
            _authn = conf["authentication"]
        except KeyError:
            self.authn_broker = None
        else:
            self.authn_broker = populate_authn_broker(_authn, self,
                                                      self.template_handler)

        try:
            _conf = conf["id_token"]
        except KeyError:
            self.idtoken = IDToken(self)
        else:
            self.idtoken = init_service(_conf, self)

        try:
            _conf = conf["userinfo"]
        except KeyError:
            pass
        else:
            self.userinfo = init_user_info(_conf, self.cwd)
            self.sdb.userinfo = self.userinfo

        try:
            _conf = conf["login_hint_lookup"]
        except KeyError:
            pass
        else:
            self.login_hint_lookup = init_service(_conf)
            # self.userinfo is only assigned when a "userinfo" section exists,
            # so read it defensively to avoid an AttributeError.
            _userinfo = getattr(self, "userinfo", None)
            if _userinfo:
                self.login_hint_lookup.user_info = _userinfo

        try:
            _conf = conf["login_hint2acrs"]
        except KeyError:
            self.login_hint2acrs = None
        else:
            self.login_hint2acrs = init_service(_conf)

        self.provider_info = self.create_providerinfo(_cap)

        # which signing/encryption algorithms to use in what context
        self.jwx_def = {}

        # special type of logging
        self.events = None

        # client registration access tokens
        self.registration_access_token = {}
예제 #8
0
 def create_sdb(self):
     """Give the test a brand-new SSO database."""
     fresh_db = SSODb()
     self.sso_db = fresh_db
예제 #9
0
class TestSessionDB(object):
    """Tests for the sid<->uid and sid<->sub mappings kept by ``SSODb``.

    In this API version, lookups for unknown keys return ``None``.
    """

    @pytest.fixture(autouse=True)
    def create_sdb(self):
        self.sso_db = SSODb()

    def test_map_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"]

    def test_missing_map(self):
        assert self.sso_db.get_sids_by_uid("Lizz") is None

    def test_multiple_map_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")
        assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
            "session id 1",
            "session id 2",
        }

    def test_map_unmap_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")
        assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.remove_sid2uid("session id 1", "Lizz")
        assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"]

    def test_get_uid_by_sid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz"
        assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz"

    def test_remove_uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Diana")

        self.sso_db.remove_uid("Lizz")
        assert self.sso_db.get_uid_by_sid("session id 1") is None
        assert self.sso_db.get_sids_by_uid("Lizz") is None

    def test_map_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"]

    def test_missing_sid2sub_map(self):
        assert self.sso_db.get_sids_by_sub("abcdefgh") is None

    def test_multiple_map_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
            "session id 1",
            "session id 2",
        }

    def test_map_unmap_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.remove_sid2sub("session id 1", "abcdefgh")
        assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"]

    def test_get_sub_by_sid(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")

        # Exercise the sid -> sub direction that this test is named after
        # (the sub -> sids direction is covered by test_multiple_map_sid2sub).
        assert self.sso_db.get_sub_by_sid("session id 1") == "abcdefgh"
        assert self.sso_db.get_sub_by_sid("session id 2") == "abcdefgh"

    def test_remove_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "012346789")

        self.sso_db.remove_sub("abcdefgh")
        assert self.sso_db.get_sub_by_sid("session id 1") is None
        assert self.sso_db.get_sids_by_sub("abcdefgh") is None
        # have not touched the others
        assert self.sso_db.get_sub_by_sid("session id 2") == "012346789"
        assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"]

    def test_get_sub_by_uid_same_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")

        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        res = self.sso_db.get_subs_by_uid("Lizz")

        assert set(res) == {"abcdefgh"}

    def test_get_sub_by_uid_different_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "012346789")

        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        res = self.sso_db.get_subs_by_uid("Lizz")

        assert set(res) == {"abcdefgh", "012346789"}
예제 #10
0
 def create_sdb(self):
     """Back ``self.sso_db`` with a freshly created shelve-based store."""
     # flag='n' is meant to start from a new, empty file each run
     # (presumably forwarded to shelve/dbm -- confirm in ShelveDataBase).
     shelf = ShelveDataBase(filename='shelf', flag='n', writeback=True)
     self.sso_db = SSODb(shelf)
예제 #11
0
class TestSessionShelveDB(object):
    """``SSODb`` mapping tests backed by a ``ShelveDataBase`` file named 'shelf'.

    The autouse fixture opens a new shelf before each test and calls
    ``_reset`` in its teardown. Doing the cleanup there, rather than as the
    last statement of every test, means it also runs when an assertion fails.
    """

    @pytest.fixture(autouse=True)
    def create_sdb(self):
        # Create fresh database each time
        _db = ShelveDataBase(filename='shelf', flag='n', writeback=True)
        self.sso_db = SSODb(_db)
        yield
        # Teardown: runs after the test body whether it passed or failed.
        self._reset()

    def _reset(self):
        """Clear and close ``self.sso_db``."""
        self.sso_db.clear()
        self.sso_db.close()

    def test_map_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"]

    def test_missing_map(self):
        assert self.sso_db.get_sids_by_uid("Lizz") is None

    def test_multiple_map_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")
        assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
            "session id 1",
            "session id 2",
        }

    def test_map_unmap_sid2uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")
        assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.remove_sid2uid("session id 1", "Lizz")
        assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"]

    def test_get_uid_by_sid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz"
        assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz"

    def test_remove_uid(self):
        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Diana")

        self.sso_db.remove_uid("Lizz")
        assert self.sso_db.get_uid_by_sid("session id 1") is None
        assert self.sso_db.get_sids_by_uid("Lizz") is None

    def test_map_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"]

    def test_missing_sid2sub_map(self):
        assert self.sso_db.get_sids_by_sub("abcdefgh") is None

    def test_multiple_map_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
            "session id 1",
            "session id 2",
        }

    def test_map_unmap_sid2sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.remove_sid2sub("session id 1", "abcdefgh")
        assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"]

    def test_get_sub_by_sid(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")

        # Exercise the sid -> sub direction that this test is named after.
        assert self.sso_db.get_sub_by_sid("session id 1") == "abcdefgh"
        assert self.sso_db.get_sub_by_sid("session id 2") == "abcdefgh"

    def test_remove_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "012346789")

        self.sso_db.remove_sub("abcdefgh")
        assert self.sso_db.get_sub_by_sid("session id 1") is None
        assert self.sso_db.get_sids_by_sub("abcdefgh") is None
        # have not touched the others
        assert self.sso_db.get_sub_by_sid("session id 2") == "012346789"
        assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"]

    def test_get_sub_by_uid_same_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")

        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        res = self.sso_db.get_subs_by_uid("Lizz")

        assert set(res) == {"abcdefgh"}

    def test_get_sub_by_uid_different_sub(self):
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "012346789")

        self.sso_db.map_sid2uid("session id 1", "Lizz")
        self.sso_db.map_sid2uid("session id 2", "Lizz")

        res = self.sso_db.get_subs_by_uid("Lizz")

        assert set(res) == {"abcdefgh", "012346789"}
예제 #12
0
class TestSessionDB(object):
    """Tests for ``SSODb`` sid<->uid and sid<->sub bookkeeping.

    In this API version, a lookup for an unknown uid/sub raises ``KeyError``.
    Removals are checked by inspecting the ``uid2sid`` / ``sub2sid`` tables
    and their ``*_rev`` counterparts directly.
    """

    @pytest.fixture(autouse=True)
    def create_sdb(self):
        fresh = SSODb()
        self.sso_db = fresh

    def test_map_sid2uid(self):
        sso = self.sso_db
        sso.map_sid2uid('session id 1', 'Lizz')
        found = sso.get_sids_by_uid('Lizz')
        assert found == ['session id 1']

    def test_missing_map(self):
        sso = self.sso_db
        with pytest.raises(KeyError):
            sso.get_sids_by_uid('Lizz')

    def test_multiple_map_sid2uid(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2uid(sid, 'Lizz')
        assert set(sids) == set(sso.get_sids_by_uid('Lizz'))

    def test_map_unmap_sid2uid(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2uid(sid, 'Lizz')
        assert set(sids) == set(sso.get_sids_by_uid('Lizz'))

        sso.unmap_sid2uid('session id 1', 'Lizz')
        remaining = sso.get_sids_by_uid('Lizz')
        assert remaining == ['session id 2']

    def test_get_uid_by_sid(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2uid(sid, 'Lizz')

        for sid in sids:
            assert sso.get_uid_by_sid(sid) == 'Lizz'

    def test_remove_uid(self):
        sso = self.sso_db
        for sid, uid in (('session id 1', 'Lizz'), ('session id 2', 'Diana')):
            sso.map_sid2uid(sid, uid)

        sso.remove_uid('Lizz')
        assert {'Diana'} == set(sso.uid2sid.keys())
        assert {'session id 2'} == set(sso.uid2sid_rev.keys())

    def test_map_sid2sub(self):
        sso = self.sso_db
        sso.map_sid2sub('session id 1', 'abcdefgh')
        found = sso.get_sids_by_sub('abcdefgh')
        assert found == ['session id 1']

    def test_missing_sid2sub_map(self):
        sso = self.sso_db
        with pytest.raises(KeyError):
            sso.get_sids_by_sub('abcdefgh')

    def test_multiple_map_sid2sub(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2sub(sid, 'abcdefgh')
        assert set(sids) == set(sso.get_sids_by_sub('abcdefgh'))

    def test_map_unmap_sid2sub(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2sub(sid, 'abcdefgh')
        assert set(sids) == set(sso.get_sids_by_sub('abcdefgh'))

        sso.unmap_sid2sub('session id 1', 'abcdefgh')
        remaining = sso.get_sids_by_sub('abcdefgh')
        assert remaining == ['session id 2']

    def test_get_sub_by_sid(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2sub(sid, 'abcdefgh')

        assert set(sids) == set(sso.get_sids_by_sub('abcdefgh'))

    def test_remove_sub(self):
        sso = self.sso_db
        for sid, sub in (('session id 1', 'abcdefgh'),
                         ('session id 2', '012346789')):
            sso.map_sid2sub(sid, sub)

        sso.remove_sub('012346789')
        assert {'abcdefgh'} == set(sso.sub2sid.keys())
        assert {'session id 1'} == set(sso.sub2sid_rev.keys())

    def test_get_sub_by_uid_same_sub(self):
        sso = self.sso_db
        sids = ('session id 1', 'session id 2')
        for sid in sids:
            sso.map_sid2sub(sid, 'abcdefgh')
        for sid in sids:
            sso.map_sid2uid(sid, 'Lizz')

        assert sso.get_sub_by_uid('Lizz') == {'abcdefgh'}

    def test_get_sub_by_uid_different_sub(self):
        sso = self.sso_db
        pairs = (('session id 1', 'abcdefgh'), ('session id 2', '012346789'))
        for sid, sub in pairs:
            sso.map_sid2sub(sid, sub)
        for sid, _ in pairs:
            sso.map_sid2uid(sid, 'Lizz')

        assert sso.get_sub_by_uid('Lizz') == {'abcdefgh', '012346789'}
예제 #13
0
    def __init__(self,
                 conf,
                 keyjar=None,
                 client_db=None,
                 session_db=None,
                 cwd='',
                 cookie_dealer=None):
        """Set up the endpoint context from a configuration mapping.

        :param conf: configuration dictionary. Keys read here include
            ``password``, ``token_expires_in``, ``grant_expires_in`` and
            ``refresh_token_expires_in`` (only when ``session_db`` is not
            given), plus ``seed``, ``verify_ssl``, ``issuer``, ``sso_ttl``,
            ``symkey``, ``client_authn``, ``id_token_schema``,
            ``template_dir``, ``jwks``, ``endpoint``, ``capabilities``,
            ``authz``, ``authentication`` and ``userinfo``. The mapping is
            not modified.
        :param keyjar: key jar; a new ``KeyJar`` is used when falsy.
        :param client_db: client database; defaults to an empty dict.
        :param session_db: pre-built session database; when falsy one is
            created with ``create_session_db``.
        :param cwd: directory that a userinfo ``db_file`` is joined to.
        :param cookie_dealer: cookie dealer instance.
        """
        self.conf = conf
        self.keyjar = keyjar or KeyJar()
        self.cwd = cwd

        if session_db:
            self.sdb = session_db
        else:
            self.sdb = create_session_db(
                conf['password'],
                db=None,
                token_expires_in=conf['token_expires_in'],
                grant_expires_in=conf['grant_expires_in'],
                refresh_token_expires_in=conf['refresh_token_expires_in'],
                sso_db=SSODb())

        # client database
        self.cdb = client_db or {}

        try:
            self.seed = bytes(conf['seed'], 'utf-8')
        except KeyError:
            self.seed = bytes(rndstr(16), 'utf-8')

        # Default values, to be changed below depending on configuration
        self.endpoint = {}
        self.issuer = ''
        self.verify_ssl = True
        self.jwks_uri = None
        self.sso_ttl = 14400  # 4h
        self.symkey = rndstr(24)
        self.id_token_schema = IdToken
        self.endpoint_to_authn_method = {}
        self.cookie_dealer = cookie_dealer

        # Optional settings copied verbatim from the configuration.
        for param in [
                'verify_ssl', 'issuer', 'sso_ttl', 'symkey', 'client_authn',
                'id_token_schema'
        ]:
            try:
                setattr(self, param, conf[param])
            except KeyError:
                pass

        template_dir = conf["template_dir"]
        # HTML-escape values interpolated into templates by default, as a
        # guard against XSS in rendered authentication pages.
        jinja_env = Environment(loader=FileSystemLoader(template_dir),
                                autoescape=True)

        self.setup = {}
        try:
            self.jwks_uri = '{}/{}'.format(self.issuer,
                                           conf['jwks']['public_path'])
        except KeyError:
            self.jwks_uri = ''

        self.endpoint = build_endpoints(
            conf['endpoint'],
            endpoint_context=self,
            client_authn_method=CLIENT_AUTHN_METHOD,
            issuer=conf['issuer'])
        try:
            _cap = conf['capabilities']
        except KeyError:
            _cap = {}

        # Advertise the path of each of these endpoints when configured.
        for endpoint in ['authorization', 'token', 'userinfo', 'registration']:
            try:
                endpoint_spec = self.endpoint[endpoint]
            except KeyError:
                pass
            else:
                _cap[endpoint_spec.endpoint_name] = '{}'.format(
                    self.endpoint[endpoint].endpoint_path)

        try:
            authz_spec = conf['authz']
        except KeyError:
            self.authz = authz.Implicit(self)
        else:
            # NOTE(review): the two branches call ``authz.factory`` with
            # different argument shapes (name first vs. ``self`` first).
            # Confirm which one matches the factory's signature.
            if 'args' in authz_spec:
                self.authz = authz.factory(authz_spec['name'],
                                           **authz_spec['args'])
            else:
                self.authz = authz.factory(self, authz_spec['name'])

        try:
            _authn = conf['authentication']
        except KeyError:
            self.authn_broker = None
        else:
            self.authn_broker = AuthnBroker()

            for authn_spec in _authn:
                # Copy so the entries added below (template_env,
                # endpoint_context) do not leak into the caller's conf.
                try:
                    _args = dict(authn_spec['kwargs'])
                except KeyError:
                    _args = {}

                if 'template' in _args:
                    _args['template_env'] = jinja_env

                _args['endpoint_context'] = self
                authn_method = user.factory(authn_spec['name'], **_args)
                args = {
                    k: authn_spec[k]
                    for k in ['acr', 'level', 'authn_authority']
                    if k in authn_spec
                }

                self.authn_broker.add(method=authn_method, **args)
                self.endpoint_to_authn_method[
                    authn_method.url_endpoint] = authn_method

        try:
            _conf = conf['userinfo']
        except KeyError:
            pass
        else:
            # Copy: joining cwd into the conf in place would prefix it again
            # every time the same configuration is reused.
            try:
                kwargs = dict(_conf['kwargs'])
            except KeyError:
                kwargs = {}

            if 'db_file' in kwargs:
                kwargs['db_file'] = os.path.join(self.cwd, kwargs['db_file'])
            self.userinfo = _conf['class'](**kwargs)

        self.provider_info = self.create_providerinfo(_cap)

        # which signing/encryption algorithms to use in what context
        self.jwx_def = {}

        # special type of logging
        self.events = None
예제 #14
0
 def set_session_db(self, sso_db=None, db=None):
     """Create the session database and attach the userinfo database to it.

     :param sso_db: SSO store; any falsy value is replaced by a new ``SSODb``.
     :param db: session storage, passed unchanged to ``do_session_db``.
     """
     if not sso_db:
         sso_db = SSODb()
     self.do_session_db(sso_db, db)
     # Append the userinfo db to the session db.
     self.do_userinfo()
     logger.debug("Session DB: {}".format(self.sdb.__dict__))
예제 #15
0
class TestSessionDB(object):
    """Tests for the sid<->uid and sid<->sub mappings kept by ``SSODb``.

    In this API version, lookups for unknown keys return ``None``.
    """

    @pytest.fixture(autouse=True)
    def create_sdb(self):
        self.sso_db = SSODb()

    def test_map_sid2uid(self):
        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        assert self.sso_db.get_sids_by_uid('Lizz') == ['session id 1']

    def test_missing_map(self):
        assert self.sso_db.get_sids_by_uid('Lizz') is None

    def test_multiple_map_sid2uid(self):
        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Lizz')
        assert set(self.sso_db.get_sids_by_uid('Lizz')) == {
            'session id 1', 'session id 2'
        }

    def test_map_unmap_sid2uid(self):
        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Lizz')
        assert set(self.sso_db.get_sids_by_uid('Lizz')) == {
            'session id 1', 'session id 2'
        }

        self.sso_db.remove_sid2uid('session id 1', 'Lizz')
        assert self.sso_db.get_sids_by_uid('Lizz') == ['session id 2']

    def test_get_uid_by_sid(self):
        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Lizz')

        assert self.sso_db.get_uid_by_sid('session id 1') == 'Lizz'
        assert self.sso_db.get_uid_by_sid('session id 2') == 'Lizz'

    def test_remove_uid(self):
        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Diana')

        self.sso_db.remove_uid('Lizz')
        assert self.sso_db.get_uid_by_sid('session id 1') is None
        assert self.sso_db.get_sids_by_uid('Lizz') is None

    def test_map_sid2sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        assert self.sso_db.get_sids_by_sub('abcdefgh') == ['session id 1']

    def test_missing_sid2sub_map(self):
        assert self.sso_db.get_sids_by_sub('abcdefgh') is None

    def test_multiple_map_sid2sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', 'abcdefgh')
        assert set(self.sso_db.get_sids_by_sub('abcdefgh')) == {
            'session id 1', 'session id 2'
        }

    def test_map_unmap_sid2sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', 'abcdefgh')
        assert set(self.sso_db.get_sids_by_sub('abcdefgh')) == {
            'session id 1', 'session id 2'
        }

        self.sso_db.remove_sid2sub('session id 1', 'abcdefgh')
        assert self.sso_db.get_sids_by_sub('abcdefgh') == ['session id 2']

    def test_get_sub_by_sid(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', 'abcdefgh')

        # Exercise the sid -> sub direction that this test is named after
        # (the sub -> sids direction is covered by test_multiple_map_sid2sub).
        assert self.sso_db.get_sub_by_sid('session id 1') == 'abcdefgh'
        assert self.sso_db.get_sub_by_sid('session id 2') == 'abcdefgh'

    def test_remove_sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', '012346789')

        self.sso_db.remove_sub('abcdefgh')
        assert self.sso_db.get_sub_by_sid('session id 1') is None
        assert self.sso_db.get_sids_by_sub('abcdefgh') is None
        # have not touched the others
        assert self.sso_db.get_sub_by_sid('session id 2') == '012346789'
        assert self.sso_db.get_sids_by_sub('012346789') == ['session id 2']

    def test_get_sub_by_uid_same_sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', 'abcdefgh')

        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Lizz')

        res = self.sso_db.get_subs_by_uid('Lizz')

        assert set(res) == {'abcdefgh'}

    def test_get_sub_by_uid_different_sub(self):
        self.sso_db.map_sid2sub('session id 1', 'abcdefgh')
        self.sso_db.map_sid2sub('session id 2', '012346789')

        self.sso_db.map_sid2uid('session id 1', 'Lizz')
        self.sso_db.map_sid2uid('session id 2', 'Lizz')

        res = self.sso_db.get_subs_by_uid('Lizz')

        assert set(res) == {'abcdefgh', '012346789'}
예제 #16
0
class TestSessionDB(object):
    @pytest.fixture(autouse=True)
    def create_sdb(self):
        """Autouse fixture: attach a fresh, empty ``SSODb`` to every test."""
        fresh_db = SSODb()
        self.sso_db = fresh_db

    def test_map_sid2uid(self):
        """A single sid->uid mapping comes back as a one-element list."""
        sid, uid = "session id 1", "Lizz"
        self.sso_db.map_sid2uid(sid, uid)
        assert self.sso_db.get_sids_by_uid(uid) == [sid]

    def test_missing_map(self):
        """Looking up sessions for an unknown user id raises KeyError."""
        unknown_uid = "Lizz"
        with pytest.raises(KeyError):
            self.sso_db.get_sids_by_uid(unknown_uid)

    def test_multiple_map_sid2uid(self):
        """Every session mapped to the same user is returned."""
        expected = {"session id 1", "session id 2"}
        for sid in ("session id 1", "session id 2"):
            self.sso_db.map_sid2uid(sid, "Lizz")
        assert set(self.sso_db.get_sids_by_uid("Lizz")) == expected

    def test_map_unmap_sid2uid(self):
        """Unmapping one session leaves the user's other session in place."""
        uid = "Lizz"
        for sid in ("session id 1", "session id 2"):
            self.sso_db.map_sid2uid(sid, uid)
        assert set(self.sso_db.get_sids_by_uid(uid)) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.unmap_sid2uid("session id 1", uid)
        remaining = self.sso_db.get_sids_by_uid(uid)
        assert remaining == ["session id 2"]

    def test_get_uid_by_sid(self):
        """Each mapped session id resolves back to its user id."""
        sids = ("session id 1", "session id 2")
        for sid in sids:
            self.sso_db.map_sid2uid(sid, "Lizz")

        for sid in sids:
            assert self.sso_db.get_uid_by_sid(sid) == "Lizz"

    def test_remove_uid(self):
        """Removing a user clears that user's forward and reverse entries only."""
        mappings = (("session id 1", "Lizz"), ("session id 2", "Diana"))
        for sid, uid in mappings:
            self.sso_db.map_sid2uid(sid, uid)

        self.sso_db.remove_uid("Lizz")
        # only Diana's mapping should survive, in both directions
        assert set(self.sso_db.uid2sid.keys()) == {"Diana"}
        assert set(self.sso_db.uid2sid_rev.keys()) == {"session id 2"}

    def test_map_sid2sub(self):
        """A single sid->sub mapping comes back as a one-element list."""
        sid, sub = "session id 1", "abcdefgh"
        self.sso_db.map_sid2sub(sid, sub)
        assert self.sso_db.get_sids_by_sub(sub) == [sid]

    def test_missing_sid2sub_map(self):
        """Looking up sessions for an unknown subject raises KeyError."""
        unknown_sub = "abcdefgh"
        with pytest.raises(KeyError):
            self.sso_db.get_sids_by_sub(unknown_sub)

    def test_multiple_map_sid2sub(self):
        """Every session mapped to the same subject is returned."""
        expected = {"session id 1", "session id 2"}
        for sid in ("session id 1", "session id 2"):
            self.sso_db.map_sid2sub(sid, "abcdefgh")
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == expected

    def test_map_unmap_sid2sub(self):
        """Unmapping one session leaves the subject's other session in place."""
        sub = "abcdefgh"
        for sid in ("session id 1", "session id 2"):
            self.sso_db.map_sid2sub(sid, sub)
        assert set(self.sso_db.get_sids_by_sub(sub)) == {
            "session id 1",
            "session id 2",
        }

        self.sso_db.unmap_sid2sub("session id 1", sub)
        remaining = self.sso_db.get_sids_by_sub(sub)
        assert remaining == ["session id 2"]

    def test_get_sub_by_sid(self):
        """Each mapped session id resolves back to the subject it was mapped to."""
        self.sso_db.map_sid2sub("session id 1", "abcdefgh")
        self.sso_db.map_sid2sub("session id 2", "abcdefgh")

        # Reverse lookup (sid -> sub): the method this test is named after,
        # mirroring test_get_uid_by_sid for the uid side.
        assert self.sso_db.get_sub_by_sid("session id 1") == "abcdefgh"
        assert self.sso_db.get_sub_by_sid("session id 2") == "abcdefgh"
        # The forward mapping (sub -> sids) must still list both sessions.
        assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
            "session id 1",
            "session id 2",
        }

    def test_remove_sub(self):
        """Removing a subject clears that subject's forward and reverse entries only."""
        mappings = (("session id 1", "abcdefgh"), ("session id 2", "012346789"))
        for sid, sub in mappings:
            self.sso_db.map_sid2sub(sid, sub)

        self.sso_db.remove_sub("012346789")
        # only the "abcdefgh" mapping should survive, in both directions
        assert set(self.sso_db.sub2sid.keys()) == {"abcdefgh"}
        assert set(self.sso_db.sub2sid_rev.keys()) == {"session id 1"}

    def test_get_sub_by_uid_same_sub(self):
        """Two sessions of one user sharing a subject yield that one subject."""
        sids = ("session id 1", "session id 2")
        for sid in sids:
            self.sso_db.map_sid2sub(sid, "abcdefgh")
        for sid in sids:
            self.sso_db.map_sid2uid(sid, "Lizz")

        subs = self.sso_db.get_sub_by_uid("Lizz")
        assert subs == {"abcdefgh"}

    def test_get_sub_by_uid_different_sub(self):
        """Sessions of one user with distinct subjects yield every subject."""
        pairs = [("session id 1", "abcdefgh"), ("session id 2", "012346789")]
        for sid, sub in pairs:
            self.sso_db.map_sid2sub(sid, sub)
        for sid, _ in pairs:
            self.sso_db.map_sid2uid(sid, "Lizz")

        subs = self.sso_db.get_sub_by_uid("Lizz")
        assert subs == {sub for _, sub in pairs}