예제 #1
0
 def test_custom_scope(self):
     """A token issued for the research_and_scholarship scope yields exactly its claims."""
     context = self.endpoint.endpoint_context
     request = AUTH_REQ.copy()
     request["scope"] = ["openid", "research_and_scholarship"]
     authn_event = {
         "authn_info": "loa1",
         "uid": "diana",
         "authn_time": utc_time_sans_frac(),
         "valid_until": utc_time_sans_frac() + 3600,
     }
     sid = setup_session(context, request, uid="userID", authn_event=authn_event)
     tokens = context.sdb.upgrade_to_token(key=sid)
     bearer = "Bearer {}".format(tokens["access_token"])
     parsed = self.endpoint.parse_request({}, auth=bearer)
     result = self.endpoint.process_request(parsed)

     expected_claims = {
         "sub",
         "name",
         "given_name",
         "family_name",
         "email",
         "email_verified",
         "eduperson_scoped_affiliation",
     }
     assert set(result["response_args"].keys()) == expected_claims
예제 #2
0
    def test_setup_auth_session_revoked(self):
        """setup_auth answers with an ``args``/``function`` pair for a revoked session."""
        cb = "https://rp.example.com/cb"
        authz_request = AuthorizationRequest(
            client_id="client_id",
            redirect_uri=cb,
            response_type=["id_token"],
            state="state",
            nonce="nonce",
            scope="openid",
        )
        client_info = {
            "client_id": "client_id",
            "redirect_uris": [(cb, {})],
            "id_token_signed_response_alg": "RS256",
        }
        context = self.endpoint.endpoint_context
        context.sdb["session_id"] = SessionInfo(
            authn_req=authz_request,
            uid="diana",
            sub="abcdefghijkl",
            authn_event={
                "authn_info": "loa1",
                "uid": "diana",
                "authn_time": utc_time_sans_frac(),
            },
            revoked=True,
        )

        # Make the "anon" authentication method report a user whose sid is the
        # revoked session stored above.
        user_blob = json.dumps({"uid": "krall", "sid": "session_id"})
        context.authn_broker.db["anon"]["method"].user = b64e(as_bytes(user_blob))

        result = self.endpoint.setup_auth(
            authz_request, authz_request["redirect_uri"], client_info, None
        )
        assert set(result.keys()) == {"args", "function"}
예제 #3
0
 def is_expired(self):
     """Return True when ``self.exp`` is strictly earlier than the current UTC time."""
     current = utc_time_sans_frac()
     if not (self.exp < current):
         return False
     logger.debug('is_expired: {} < {}'.format(self.exp, current))
     return True
예제 #4
0
    def test_create_authn_response(self):
        """With the client registered for ES256 id_tokens, the authn response is an error."""
        context = self.endpoint.endpoint_context
        authz_request = AuthorizationRequest(
            client_id="client_id",
            redirect_uri="https://rp.example.com/cb",
            response_type=["id_token"],
            state="state",
            nonce="nonce",
            scope="openid",
        )

        event = {
            "authn_info": "loa1",
            "uid": "diana",
            "authn_time": utc_time_sans_frac(),
        }
        context.sdb["session_id"] = SessionInfo(
            authn_req=authz_request,
            uid="diana",
            sub="abcdefghijkl",
            authn_event=event,
        )
        context.cdb["client_id"] = {
            "client_id": "client_id",
            "redirect_uris": [("https://rp.example.com/cb", {})],
            "id_token_signed_response_alg": "ES256",
        }

        response = create_authn_response(self.endpoint, authz_request, "session_id")
        assert isinstance(response["response_args"], AuthorizationErrorResponse)
예제 #5
0
    def process_request(self, req: Union[Message, dict], **kwargs):
        """Handle a ``refresh_token`` grant: mint a new access token from a refresh token.

        :param req: The parsed token request. ``req["grant_type"]`` must be
            ``"refresh_token"`` and ``req["refresh_token"]`` holds the token value.
            An optional ``scope`` replaces the scope derived from the refresh token;
            an optional ``dpop_jkt`` (only honoured when DPoP is advertised in the
            provider info) is stored on the grant and switches the token type to DPoP.
        :param kwargs: ``issue_refresh`` (default False) - also mint a new refresh
            token when the refresh token's usage rules allow it.
        :return: An instance of ``self.error_cls`` on a wrong grant_type, otherwise a
            dict with ``access_token``, ``token_type``, ``scope`` and, when
            applicable, ``expires_in`` and ``refresh_token``.
        """
        _context = self.endpoint.server_get("endpoint_context")
        _mngr = _context.session_manager

        if req["grant_type"] != "refresh_token":
            return self.error_cls(error="invalid_request", error_description="Wrong grant_type")

        token_value = req["refresh_token"]
        _session_info = _mngr.get_session_info_by_token(token_value, grant=True)
        _grant = _session_info["grant"]

        token_type = "Bearer"

        # Is DPoP supported? If so, bind the grant to the key thumbprint given.
        if "dpop_signing_alg_values_supported" in _context.provider_info:
            _dpop_jkt = req.get("dpop_jkt")
            if _dpop_jkt:
                _grant.extra["dpop_jkt"] = _dpop_jkt
                token_type = "DPoP"

        token = _grant.get_token(token_value)
        # Default to the scope the refresh token was based on; an explicit scope in
        # the request takes precedence.
        scope = _grant.find_scope(token.based_on)
        if "scope" in req:
            scope = req["scope"]
        access_token = self._mint_token(
            token_class="access_token",
            grant=_grant,
            session_id=_session_info["session_id"],
            client_id=_session_info["client_id"],
            based_on=token,
            scope=scope,
            token_type=token_type,
        )

        _resp = {
            "access_token": access_token.value,
            "token_type": access_token.token_type,
            # Report the scope the access token was actually minted with; it may
            # differ from the grant's full scope (RFC 6749 section 5.1).
            "scope": scope,
        }

        if access_token.expires_at:
            _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac()

        # usage_rules may lack (or null) "supports_minting"; treat that as "nothing".
        _mints = token.usage_rules.get("supports_minting") or []
        issue_refresh = kwargs.get("issue_refresh", False)
        if "refresh_token" in _mints and issue_refresh:
            refresh_token = self._mint_token(
                token_class="refresh_token",
                grant=_grant,
                session_id=_session_info["session_id"],
                client_id=_session_info["client_id"],
                based_on=token,
                scope=scope,
            )
            refresh_token.usage_rules = token.usage_rules.copy()
            _resp["refresh_token"] = refresh_token.value

        token.register_usage()

        return _resp
예제 #6
0
def test_policy_language_crit_not_supported():
    """A critical policy-language extension verifies only when declared as known."""
    issued = utc_time_sans_frac()
    statement = EntityStatement(iat=issued, exp=issued + 3600, **MSG)

    # Accepted once "regexp" is declared a known policy extension ...
    statement.verify(known_policy_extensions=["regexp"])

    # ... and rejected when nothing is declared.
    with pytest.raises(UnknownCriticalExtension):
        statement.verify()
예제 #7
0
    def process_request(self, request=None, **kwargs):
        """Return userinfo claims for the access token carried in *request*.

        :param request: The parsed userinfo request; ``request["access_token"]``
            holds the bearer token value.
        :return: An instance of ``self.error_cls`` (``invalid_token``) if the token is
            not an active access token. Otherwise
            ``{"response_args": info, "client_id": ...}`` where ``info`` is the
            user's claims plus ``sub``, or an ``invalid_request`` error dict once the
            authentication event's ``valid_until`` has passed.
        """
        _context = self.server_get("endpoint_context")
        _mngr = _context.session_manager
        _session_info = _mngr.get_session_info_by_token(
            request["access_token"], grant=True)
        _grant = _session_info["grant"]
        token = _grant.get_token(request["access_token"])
        # should be an access token
        if token.token_class != "access_token":
            return self.error_cls(error="invalid_token",
                                  error_description="Wrong type of token")

        # And it should be valid
        if token.is_active() is False:
            return self.error_cls(error="invalid_token",
                                  error_description="Invalid Token")

        # Read the clock once so the check and the log message agree.
        _now = utc_time_sans_frac()
        _valid_until = _grant.authentication_event["valid_until"]

        # Claims are released only while the authentication is still valid.
        # TODO: this has to be made more fine grained, e.g. granting access
        # when "offline_access" is in the authorization request's scope.
        if _valid_until <= _now:
            logger.debug("authentication not valid: {} <= {}".format(
                _valid_until, _now))
            info = {
                "error": "invalid_request",
                "error_description": "Access not granted",
            }
        else:
            _claims = _grant.claims.get("userinfo")
            info = _context.claims_interface.get_user_claims(
                user_id=_session_info["user_id"],
                claims_restriction=_claims)
            info["sub"] = _grant.sub

        return {"response_args": info, "client_id": _session_info["client_id"]}
예제 #8
0
def assertion_jwt(cli, keys, audience, algorithm, lifetime=600):
    """Build an ``AuthnToken`` for *cli* and serialize it with ``to_jwt``.

    :param cli: Client object; its ``client_id`` is used as both ``iss`` and ``sub``.
    :param keys: Key(s) passed to ``to_jwt``.
    :param audience: Value for the ``aud`` claim.
    :param algorithm: Algorithm passed to ``to_jwt``.
    :param lifetime: Added to the current time to form ``exp`` (default 600).
    :return: Whatever ``AuthnToken.to_jwt`` returns.
    """
    issued_at = utc_time_sans_frac()
    claims = {
        "iss": cli.client_id,
        "sub": cli.client_id,
        "aud": audience,
        "jti": rndstr(32),
        "exp": issued_at + lifetime,
        "iat": issued_at,
    }
    token = AuthnToken(**claims)
    return token.to_jwt(key=keys, algorithm=algorithm)
예제 #9
0
 def test_process_request(self):
     """process_request returns a truthy result for a freshly issued bearer token."""
     context = self.endpoint.endpoint_context
     event = {
         "authn_info": "loa1",
         "uid": "diana",
         "authn_time": utc_time_sans_frac(),
         "valid_until": utc_time_sans_frac() + 3600,
     }
     sid = setup_session(context, AUTH_REQ, uid="userID", authn_event=event)
     access_token = context.sdb.upgrade_to_token(key=sid)["access_token"]
     parsed = self.endpoint.parse_request(
         {}, auth="Bearer {}".format(access_token))
     assert self.endpoint.process_request(parsed)
예제 #10
0
 def decrypt(self, token: str) -> dict:
     """Unpack *token* with ``self._jwt`` and return its ``data`` entry.

     When ``self._jwt.lifetime`` is set, the unpacked token must contain both
     ``iat`` and ``exp`` and the current time must satisfy ``iat <= now <= exp``.

     :raises MissingExpirationHeader: lifetime is set but ``iat``/``exp`` is absent.
     :raises NotYet: the token's ``iat`` lies in the future.
     :raises Expired: the token's ``exp`` has passed.
     """
     payload = self._jwt.unpack(token)
     if not self._jwt.lifetime:
         return payload.get('data')

     current = utc_time_sans_frac()
     if 'iat' not in payload or 'exp' not in payload:
         raise MissingExpirationHeader()
     if current < payload['iat']:
         raise NotYet()
     if current > payload['exp']:
         raise Expired()
     return payload.get('data')
예제 #11
0
    def create_cookie(self, value, typ, **kwargs):
        """Return a SimpleCookie named ``self.name`` carrying
        ``"<value>::<timestamp>::<typ>|<timestamp>"``.
        """
        jar = SimpleCookie()
        stamp = str(utc_time_sans_frac())
        payload = "::".join([value, stamp, typ])
        # Round-trip through UTF-8 exactly as the value is stored.
        encoded = b"|".join(part.encode("utf-8") for part in (payload, stamp))
        jar[self.name] = encoded.decode('utf-8')
        return jar
예제 #12
0
 def test_process_request_not_allowed(self):
     """An authentication whose valid_until has passed yields an error response."""
     context = self.endpoint.endpoint_context
     event = {
         "authn_info": "loa1",
         "uid": "diana",
         "authn_time": utc_time_sans_frac() - 7200,
         "valid_until": utc_time_sans_frac() - 3600,
     }
     sid = setup_session(context, AUTH_REQ, uid="userID", authn_event=event)
     access_token = context.sdb.upgrade_to_token(key=sid)["access_token"]
     parsed = self.endpoint.parse_request(
         {}, auth="Bearer {}".format(access_token))
     result = self.endpoint.process_request(parsed)
     returned_keys = set(result["response_args"].keys())
     assert returned_keys == {"error", "error_description"}
예제 #13
0
def test_crit_known_unknown_not_critical():
    """Extension claims not listed in ``crit`` never make verification fail."""
    issuer = "https://ent.example.org"
    issued = utc_time_sans_frac()
    statement = EntityStatement(
        sub=issuer, iss=issuer, iat=issued, exp=issued + 3600, foo="bar"
    )

    for known in (["foo"], ["foo", "xyz"]):
        statement.verify(known_extensions=known)
    statement.verify()
예제 #14
0
 def test_process_request_offline_access(self):
     """With offline_access and an expired authentication, only ``sub`` is returned."""
     context = self.endpoint.endpoint_context
     request = AUTH_REQ.copy()
     request["scope"] = ["openid", "offline_access"]
     event = {
         "authn_info": "loa1",
         "uid": "diana",
         "authn_time": utc_time_sans_frac() - 7200,
         "valid_until": utc_time_sans_frac() - 3600,
     }
     sid = setup_session(context, request, uid="userID", authn_event=event)
     access_token = context.sdb.upgrade_to_token(key=sid)["access_token"]
     parsed = self.endpoint.parse_request(
         {}, auth="Bearer {}".format(access_token))
     result = self.endpoint.process_request(parsed)
     assert set(result["response_args"].keys()) == {"sub"}
예제 #15
0
    def verify(self, request, key_type, **kwargs):
        """Verify the ``client_assertion`` JWT in *request* (JWT client authentication).

        :param request: The request; ``request["client_assertion"]`` holds the JWT.
            On success the unpacked JWT is stored back into *request* under
            ``verified_claim_name("client_assertion")``.
        :param key_type: ``"private_key"`` or ``"client_secret"``. An HS* (HMAC)
            signed assertion is rejected for ``"private_key"``; any other algorithm
            is rejected for ``"client_secret"``.
        :param kwargs: ``endpoint`` - if given and truthy, the audience must
            intersect ``endpoint.allowed_target_uris()``; otherwise the issuer must
            be in the audience. ``client_id`` - overrides the assertion's ``iss``
            as the returned client id.
        :return: ``{"client_id": <client id>, "jwt": <unpacked assertion>}``
        :raises AuthnFailure: if the JWT cannot be unpacked/verified.
        :raises AttributeError: on a key-type mismatch, or when an HMAC assertion's
            oct key is not the registered client secret.
        :raises NotForMe: if the audience check fails.
        :raises MultipleUsage: if the ``jti`` has been seen before for this issuer.
        """
        import hmac

        _context = self.server_get("endpoint_context")
        _jwt = JWT(_context.keyjar, msg_cls=JsonWebToken)
        try:
            ca_jwt = _jwt.unpack(request["client_assertion"])
        except (Invalid, MissingKey, BadSignature) as err:
            logger.info("%s" % sanitize(err))
            raise AuthnFailure("Could not verify client_assertion.")

        _sign_alg = ca_jwt.jws_header.get("alg")
        if _sign_alg and _sign_alg.startswith("HS"):
            if key_type == "private_key":
                raise AttributeError("Wrong key type")
            keys = _context.keyjar.get("sig", "oct", ca_jwt["iss"],
                                       ca_jwt.jws_header.get("kid"))
            _secret = _context.cdb[ca_jwt["iss"]].get("client_secret")
            # The oct key found for the assertion's issuer/kid must be the registered
            # client secret. Compare in constant time, and treat "no matching key"
            # as a mismatch instead of crashing with IndexError.
            if _secret and (
                not keys
                or not hmac.compare_digest(as_bytes(keys[0].key), as_bytes(_secret))
            ):
                raise AttributeError(
                    "Oct key used for signing not client_secret")
        elif key_type == "client_secret":
            raise AttributeError("Wrong key type")

        authtoken = sanitize(ca_jwt.to_dict())
        logger.debug("authntoken: {}".format(authtoken))

        _endpoint = kwargs.get("endpoint")
        if not _endpoint:
            if _context.issuer not in ca_jwt["aud"]:
                raise NotForMe("Not for me!")
        elif not set(ca_jwt["aud"]).intersection(_endpoint.allowed_target_uris()):
            raise NotForMe("Not for me!")

        # If there is a jti, use it to make sure the assertion is used only once.
        _jti = ca_jwt.get("jti")
        if _jti:
            _key = "{}:{}".format(ca_jwt["iss"], _jti)
            if _key in _context.jti_db:
                raise MultipleUsage("Have seen this token once before")
            _context.jti_db[_key] = utc_time_sans_frac()

        request[verified_claim_name("client_assertion")] = ca_jwt
        client_id = kwargs.get("client_id") or ca_jwt["iss"]

        return {"client_id": client_id, "jwt": ca_jwt}
예제 #16
0
    def test_is_expired(self):
        """A fresh access token is not expired now, but is 4000 seconds later."""
        context = self.endpoint.endpoint_context
        sid = setup_session(context, AUTH_REQ, uid="diana")
        access_token = context.sdb.upgrade_to_token(key=sid)["access_token"]

        handler = context.sdb.handler.handler["access_token"]
        assert handler.is_expired(access_token) is False

        later = utc_time_sans_frac() + 4000
        assert handler.is_expired(access_token, later) is True
예제 #17
0
파일: __init__.py 프로젝트: rohe/fedservice
    def create_trust_mark(self, id, sub):
        """Create a trust mark of type *id* about *sub* and return it packed as a JWT.

        The content is a copy of ``self.trust_marks[id]`` extended with ``iat``,
        ``id`` and ``sub`` (and ``exp`` when ``self.tm_lifetime`` has a truthy
        lifetime for *id*). It is recorded via ``self.issued.add`` and packed
        with ``iss`` set to this entity's ``entity_id``.
        """
        issued_at = utc_time_sans_frac()
        extra = {'iat': issued_at, 'id': id, 'sub': sub}
        ttl = self.tm_lifetime.get(id)
        if ttl:
            extra['exp'] = issued_at + ttl

        content = self.trust_marks[id].copy()
        content.update(extra)
        self.issued.add(id, content)

        context = self.server_get("context")
        return JWT(key_jar=context.keyjar, iss=context.entity_id).pack(payload=content)
예제 #18
0
def test_crit_critical_not_supported():
    """A critical extension that is not declared known makes verify() raise."""
    issuer = "https://ent.example.org"
    issued = utc_time_sans_frac()
    statement = EntityStatement(
        sub=issuer,
        iss=issuer,
        iat=issued,
        exp=issued + 3600,
        foo="bar",
        crit=["foo"],
    )

    for verify_kwargs in ({"known_extensions": ["xyz"]}, {}):
        with pytest.raises(UnknownCriticalExtension):
            statement.verify(**verify_kwargs)
예제 #19
0
    def test_do_signed_response(self):
        """A client registered for ES256-signed userinfo gets a response built."""
        context = self.endpoint.endpoint_context
        context.cdb["client_1"]["userinfo_signed_response_alg"] = "ES256"

        event = {
            "authn_info": "loa1",
            "uid": "diana",
            "authn_time": utc_time_sans_frac(),
            "valid_until": utc_time_sans_frac() + 3600,
        }
        sid = setup_session(context, AUTH_REQ, uid="userID", authn_event=event)
        access_token = context.sdb.upgrade_to_token(key=sid)["access_token"]
        parsed = self.endpoint.parse_request(
            {}, auth="Bearer {}".format(access_token))

        result = self.endpoint.process_request(parsed)
        assert result
        assert self.endpoint.do_response(request=parsed, **result)
예제 #20
0
 def __getitem__(self, item):
     """Return the cached statement for *item*, or None.

     Dict entries must be recent enough: they are returned only while
     ``now < entry["exp"] - self.allowed_delta``; otherwise they are removed from
     ``self._db`` and None is returned. Non-dict entries are returned as is.
     """
     try:
         cached = self._db[item]
     except KeyError:
         return None

     if not isinstance(cached, dict):
         return cached

     if utc_time_sans_frac() < cached["exp"] - self.allowed_delta:
         return cached

     # Too old: evict it and report a miss.
     del self._db[item]
     return None
예제 #21
0
    def verify(self, **kwargs):
        """Verify the trust mark; return True or raise.

        Runs the parent class verification, then checks that ``kwargs["entity_id"]``
        (when given) equals this trust mark's ``sub``, and that the time given as
        ``kwargs["exp"]`` (when truthy) has not passed.

        :raises WrongSubject: if ``entity_id`` does not match ``sub``.
        :raises Expired: if the current time is after ``kwargs["exp"]``.
        """
        super(TrustMark, self).verify(**kwargs)

        expected_subject = kwargs.get("entity_id")
        if expected_subject is not None and expected_subject != self["sub"]:
            raise WrongSubject(
                "Mismatch between subject in trust mark and entity_id of entity"
            )

        # NOTE(review): the expiry checked here comes from the caller's kwargs,
        # not from this trust mark's own "exp" claim - confirm that is intended.
        deadline = kwargs.get("exp", 0)
        if deadline and utc_time_sans_frac() > deadline:
            raise Expired()

        return True
예제 #22
0
    def create_cookie(self, value, typ, **kwargs):
        """Return a SimpleCookie named ``self.name`` carrying
        ``"<value>::<timestamp>::<typ>|<timestamp>"``.

        :param kwargs: ``ttl`` - if present, the cookie's ``expires`` attribute is set
            to ``in_a_while(seconds=ttl)``.
        """
        jar = SimpleCookie()
        stamp = str(utc_time_sans_frac())
        payload = "::".join([value, stamp, typ])

        # Round-trip through UTF-8 exactly as the value is stored.
        encoded = b"|".join(part.encode("utf-8") for part in (payload, stamp))
        jar[self.name] = encoded.decode('utf-8')

        if 'ttl' in kwargs:
            jar[self.name]["expires"] = in_a_while(seconds=kwargs['ttl'])
        return jar
예제 #23
0
    def verify(self, request=None, **kwargs):
        """Unpack and verify the JWT in ``request["request"]``.

        The unpacked JWT is stored in *request* under
        ``verified_claim_name("client_assertion")``.

        :param kwargs: ``client_id`` - overrides the JWT's ``iss`` as the returned
            client id.
        :return: ``{"client_id": <client id>, "jwt": <unpacked JWT>}``
        :raises AuthnFailure: if the JWT cannot be unpacked/verified.
        :raises MultipleUsage: if the JWT's ``jti`` was seen before for its issuer.
        """
        _context = self.server_get("endpoint_context")
        unpacker = JWT(_context.keyjar, msg_cls=JsonWebToken)
        try:
            verified = unpacker.unpack(request["request"])
        except (Invalid, MissingKey, BadSignature) as err:
            logger.info("%s" % sanitize(err))
            raise AuthnFailure("Could not verify client_assertion.")

        # A jti makes the JWT single-use: remember every (issuer, jti) pair seen.
        jti = verified.get("jti")
        if jti:
            seen_key = "{}:{}".format(verified["iss"], jti)
            if seen_key in _context.jti_db:
                raise MultipleUsage("Have seen this token once before")
            _context.jti_db[seen_key] = utc_time_sans_frac()

        request[verified_claim_name("client_assertion")] = verified
        return {
            "client_id": kwargs.get("client_id") or verified["iss"],
            "jwt": verified,
        }
예제 #24
0
    def verify(self, request, **kwargs):
        """Verify the ``client_assertion`` JWT in *request*.

        :param request: The request; ``request["client_assertion"]`` holds the JWT.
            On success the unpacked JWT is stored back into *request* under
            ``verified_claim_name("client_assertion")``.
        :param kwargs: ``endpoint`` - name of an endpoint in
            ``self.endpoint_context.endpoint``; if given and truthy, the audience
            must intersect that endpoint's ``allowed_target_uris()``, otherwise the
            issuer must be in the audience. ``client_id`` - overrides the
            assertion's ``iss`` as the returned client id.
        :return: ``{"client_id": <client id>, "jwt": <unpacked assertion>}``
        :raises AuthnFailure: if the JWT cannot be unpacked/verified.
        :raises NotForMe: if the audience check fails (including a missing ``aud``).
        :raises MultipleUsage: if the ``jti`` has been seen before for this issuer.
        """
        _jwt = JWT(self.endpoint_context.keyjar)
        try:
            ca_jwt = _jwt.unpack(request["client_assertion"])
        except (Invalid, MissingKey, BadSignature) as err:
            logger.info("%s" % sanitize(err))
            raise AuthnFailure("Could not verify client_assertion.")

        # callable() takes exactly one argument; the previous
        # callable(ca_jwt, "to_dict") raised TypeError whenever to_dict existed.
        if callable(getattr(ca_jwt, "to_dict", None)):
            authtoken = sanitize(ca_jwt.to_dict())
        else:
            authtoken = sanitize(ca_jwt)
        logger.debug("authntoken: {}".format(authtoken))

        # RFC 7519 section 4.1.3: "aud" may be a single string or a list. Normalize
        # to a list so membership is an exact match (not a substring test) and
        # set() does not split a string into characters.
        _aud = ca_jwt.get("aud") or []
        if isinstance(_aud, str):
            _aud = [_aud]

        _endpoint = kwargs.get("endpoint")
        if not _endpoint:
            if self.endpoint_context.issuer not in _aud:
                raise NotForMe("Not for me!")
        else:
            _allowed = self.endpoint_context.endpoint[_endpoint].allowed_target_uris()
            if not set(_aud).intersection(_allowed):
                raise NotForMe("Not for me!")

        # If there is a jti, use it to make sure the assertion is used only once.
        _jti = ca_jwt.get('jti')
        if _jti:
            _key = "{}:{}".format(ca_jwt['iss'], _jti)
            if _key in self.endpoint_context.jti_db:
                raise MultipleUsage("Have seen this token once before")
            self.endpoint_context.jti_db.set(_key, utc_time_sans_frac())

        request[verified_claim_name("client_assertion")] = ca_jwt
        client_id = kwargs.get("client_id") or ca_jwt["iss"]

        return {"client_id": client_id, "jwt": ca_jwt}
예제 #25
0
    def test_process_request_not_allowed(self):
        """After self._dump_restore(2, 1), an expired access token presented to
        endpoint 1 is rejected with ``invalid_token``."""
        src, dst = 2, 1
        session_id = self._create_session(AUTH_REQ, index=src)
        context = self.endpoint[src].server_get("endpoint_context")
        grant = context.authz(session_id, AUTH_REQ)
        code = self._mint_code(grant, session_id, index=src)
        access_token = self._mint_access_token(grant, session_id, code, src)

        # Backdate the token so it is already expired, then store the grant.
        access_token.expires_at = utc_time_sans_frac() - 60
        grant_path = [self.user_id, AUTH_REQ["client_id"], grant.id]
        self.session_manager[src].set(grant_path, grant)

        self._dump_restore(src, dst)

        bearer = "Bearer {}".format(access_token.value)
        http_info = {"headers": {"authorization": bearer}}
        parsed = self.endpoint[dst].parse_request({}, http_info=http_info)

        result = self.endpoint[dst].process_request(parsed)
        assert set(result.keys()) == {"error", "error_description"}
        assert result["error"] == "invalid_token"
예제 #26
0
from cryptojwt.jwt import utc_time_sans_frac
from fedservice import apply_policy
from fedservice import combine_policy
from fedservice.message import EntityStatement
from fedservice.message import FederationEntity
from fedservice.message import MetadataPolicy
from fedservice.message import TrustMark
from oidcmsg.oidc import ProviderConfigurationResponse
from oidcmsg.oidc import RegistrationRequest
from oidcmsg.oidc import RegistrationResponse

# 2.1 - entity statement, re-dated so that it is currently valid.
# Files are read with an explicit UTF-8 encoding (JSON text, RFC 8259) and
# closed deterministically via a context manager.
with open("2.1.json", encoding="utf-8") as fp:
    txt = fp.read()

es = EntityStatement().from_json(txt)

now = utc_time_sans_frac()
es['iat'] = now
es['exp'] = now + 3600

print("2.1", es.verify())

# 3.6

with open("3.6.json", encoding="utf-8") as fp:
    txt = fp.read()

fe = FederationEntity().from_json(txt)

print("3.6", fe.verify())

# 4.1.3.1
예제 #27
0
    def collect_intermediate(self,
                             entity_id,
                             intermediate,
                             seen=None,
                             max_superiors=10):
        """
        Collect information about an entity by another entity, the intermediate.
        This consists of first finding the fed_api_endpoint URL for the intermediate
        and then asking the intermediate for its view of the entity.

        Statements are cached in ``self.entity_statement_cache`` under the key
        ``"<intermediate>!!<entity_id>"``, with their ``exp`` value stored under
        ``"<intermediate>!exp!<entity_id>"``. A cached statement is discarded once
        the current time is past ``exp - self.allowed_delta``.

        :param entity_id: The ID of the entity
        :param intermediate: The immediate superior
        :param seen: A list of intermediates already visited on this path. A copy
            with ``intermediate`` appended is passed on to ``collect_superiors``.
            NOTE(review): no loop or depth check is performed in this method (the
            max_superiors check below is commented out) - confirm that
            collect_superiors guards against cycles.
        :param max_superiors: The maximum number of superiors; passed through to
            ``collect_superiors``.
        :return: None if ``entity_id`` equals ``intermediate`` and is a trusted
            anchor, or if no entity statement was obtained; otherwise a tuple of
            (the entity statement as a signed JWT, the result of
            ``self.collect_superiors`` for the intermediate).
        :raises SystemError: if no federation_api endpoint can be found for the
            intermediate.
        """
        logger.debug('Collect intermediate "%s"', intermediate)
        # Should I stop when I reach the first trust anchor ?
        if entity_id == intermediate and entity_id in self.trusted_anchors:
            return None

        # Copy so the caller's list is not mutated.
        if seen is None:
            _seen = []
        else:
            _seen = seen[:]

        _seen.append(intermediate)
        # if len(_seen) > max_superiors:
        #     logger.warning("Reached max superiors. The path here was {}".format(_seen))
        #     return None

        # Try to get the entity statement from the cache
        # (the cache presumably returns None on a miss, given the check below).
        cache_key = "{}!!{}".format(intermediate, entity_id)
        entity_statement = self.entity_statement_cache[cache_key]

        if entity_statement is not None:
            _now = utc_time_sans_frac()
            time_key = "{}!exp!{}".format(intermediate, entity_id)
            # NOTE(review): assumes the exp entry is present whenever the
            # statement is; a missing (None) entry would make the subtraction
            # below fail - confirm the cache keeps the pair together.
            _exp = self.entity_statement_cache[time_key]
            if _now > (_exp - self.allowed_delta):
                logger.debug("Cached entity statement timed out")
                del self.entity_statement_cache[cache_key]
                del self.entity_statement_cache[time_key]
                entity_statement = None

        if entity_statement is None:
            fed_api_endpoint = self.get_federation_api_endpoint(intermediate)
            if fed_api_endpoint is None:
                raise SystemError('Could not find federation_api endpoint')
            logger.debug("Federation API endpoint: '{}' for '{}'".format(
                fed_api_endpoint, intermediate))
            entity_statement = self.get_entity_statement(
                fed_api_endpoint, intermediate, entity_id)
            # entity_statement is a signed JWT
            statement = unverified_entity_statement(entity_statement)
            logger.debug(
                "Unverified entity statement from {} about {}: {}".format(
                    fed_api_endpoint, entity_id, statement))
            # Cache the JWT together with its (unverified) expiry time.
            self.entity_statement_cache[cache_key] = entity_statement
            time_key = "{}!exp!{}".format(intermediate, entity_id)
            self.entity_statement_cache[time_key] = statement["exp"]

        if entity_statement:
            # Continue upwards from the intermediate's own configuration.
            intermediate_statement = self.config_cache[intermediate]
            return entity_statement, self.collect_superiors(
                intermediate,
                intermediate_statement,
                seen=_seen,
                max_superiors=max_superiors)
        else:
            return None
예제 #28
0
def valid_client_info(cinfo):
    """Return False if the client's secret has expired, True otherwise.

    ``client_secret_expires_at`` of 0 means the secret never expires
    (RFC 7591 section 3.2.1); an absent or null value is treated the same way
    instead of raising TypeError on the comparison.

    :param cinfo: Client information mapping.
    :return: bool
    """
    eta = cinfo.get("client_secret_expires_at", 0)
    # Only a real (non-zero, non-null) timestamp in the past makes the info invalid.
    if eta and eta < utc_time_sans_frac():
        return False
    return True
예제 #29
0
    def collect_intermediate(self,
                             entity_id,
                             authority,
                             seen=None,
                             max_superiors=10):
        """
        Collect information about an entity by another entity, the authority.
        This consists of first finding the fed_fetch_endpoint URL for the authority
        and then asking the authority for its view of the entity.

        Statements are cached in ``self.entity_statement_cache`` under the key
        ``"<authority>!!<entity_id>"``, with their ``exp`` value stored under
        ``"<authority>!exp!<entity_id>"``. A cached statement is discarded once the
        current time is past ``exp - self.allowed_delta``.

        :param entity_id: The ID of the entity
        :param authority: The immediate superior
        :param seen: A list of authorities already visited on this path. A copy
            with ``authority`` appended is passed on to ``collect_superiors``.
            NOTE(review): no loop or depth check is performed in this method (the
            max_superiors check below is commented out) - confirm that
            collect_superiors guards against cycles.
        :param max_superiors: The maximum number of superiors; passed through to
            ``collect_superiors``.
        :return: None if ``entity_id`` equals ``authority`` and is a trusted anchor,
            if no federation fetch endpoint is found for the authority, or if no
            entity statement was obtained; otherwise a tuple of (the entity
            statement as a signed JWT, the result of ``self.collect_superiors``
            for the authority).
        """
        logger.debug(f'Get authority "{authority}" for "{entity_id}"')
        # Should I stop when I reach the first trust anchor ?
        if entity_id == authority and entity_id in self.trusted_anchors:
            return None

        # Copy so the caller's list is not mutated.
        if seen is None:
            _seen = []
        else:
            _seen = seen[:]

        _seen.append(authority)
        # if len(_seen) > max_superiors:
        #     logger.warning("Reached max superiors. The path here was {}".format(_seen))
        #     return None

        # Try to get the entity statement from the cache
        # (the cache presumably returns None on a miss, given the check below).
        cache_key = "{}!!{}".format(authority, entity_id)
        entity_statement = self.entity_statement_cache[cache_key]

        if entity_statement is not None:
            logger.debug("Have cached statement")
            _now = utc_time_sans_frac()
            time_key = "{}!exp!{}".format(authority, entity_id)
            # NOTE(review): assumes the exp entry is present whenever the
            # statement is; a missing (None) entry would make the subtraction
            # below fail - confirm the cache keeps the pair together.
            _exp = self.entity_statement_cache[time_key]
            if _now > (_exp - self.allowed_delta):
                logger.debug("Cached entity statement timed out")
                del self.entity_statement_cache[cache_key]
                del self.entity_statement_cache[time_key]
                entity_statement = None

        if entity_statement is None:
            logger.debug(f"Have not seen '{authority}' before")
            fed_fetch_endpoint = self.get_federation_fetch_endpoint(authority)
            # Unlike a missing statement, an unknown fetch endpoint is not an error.
            if fed_fetch_endpoint is None:
                return None
            logger.debug(
                f"Federation fetch endpoint: '{fed_fetch_endpoint}' for '{authority}'"
            )
            entity_statement = self.get_entity_statement(
                fed_fetch_endpoint, authority, entity_id)
            # entity_statement is a signed JWT
            statement = unverified_entity_statement(entity_statement)
            logger.debug(
                f"Unverified entity statement from {fed_fetch_endpoint} about {entity_id}: {statement}"
            )
            # Cache the JWT together with its (unverified) expiry time.
            self.entity_statement_cache[cache_key] = entity_statement
            time_key = "{}!exp!{}".format(authority, entity_id)
            self.entity_statement_cache[time_key] = statement["exp"]

        if entity_statement:
            # Continue upwards from the authority's own configuration.
            authority_statement = self.config_cache[authority]
            return entity_statement, self.collect_superiors(
                authority,
                authority_statement,
                seen=_seen,
                max_superiors=max_superiors)
        else:
            return None
예제 #30
0
    def process_request(self, req: Union[Message, dict], **kwargs):
        """Exchange an authorization code for tokens (``authorization_code`` grant).

        :param req: The parsed token request. ``grant_type`` must be
            ``"authorization_code"`` and ``code`` must be present. If the original
            authorization request carried a ``redirect_uri``, ``req`` must carry
            the identical value (RFC 6749 section 4.1.3).
        :param kwargs: ``issue_refresh`` - if present, decides whether a refresh
            token is minted; if absent, one is minted iff the grant's scope contains
            ``offline_access``.
        :return: An instance of ``self.error_cls`` on failure, otherwise a dict with
            ``token_type`` and ``scope`` plus, where the code's usage rules allow
            minting them, ``access_token``/``expires_in``, ``refresh_token`` and
            ``id_token``.
        """
        _context = self.endpoint.server_get("endpoint_context")

        _mngr = _context.session_manager
        _log_debug = logger.debug

        if req["grant_type"] != "authorization_code":
            return self.error_cls(error="invalid_request", error_description="Unknown grant_type")

        try:
            # Presumably undoes '+' -> ' ' conversion from form decoding.
            _access_code = req["code"].replace(" ", "+")
        except KeyError:  # Missing code parameter - absolutely fatal
            return self.error_cls(error="invalid_request", error_description="Missing code")

        _session_info = _mngr.get_session_info_by_token(_access_code, grant=True)
        grant = _session_info["grant"]

        token_type = "Bearer"

        # Is DPoP supported?
        try:
            _dpop_enabled = _context.dpop_enabled
        except AttributeError:
            _dpop_enabled = False

        if _dpop_enabled:
            _dpop_jkt = req.get("dpop_jkt")
            if _dpop_jkt:
                grant.extra["dpop_jkt"] = _dpop_jkt
                token_type = "DPoP"

        _based_on = grant.get_token(_access_code)
        _supports_minting = _based_on.usage_rules.get("supports_minting", [])

        _authn_req = grant.authorization_request

        # If redirect_uri was in the initial authorization request it is required
        # here and must be identical. A missing value is an error response, not a
        # KeyError.
        if "redirect_uri" in _authn_req:
            if req.get("redirect_uri") != _authn_req["redirect_uri"]:
                return self.error_cls(
                    error="invalid_request", error_description="redirect_uri mismatch"
                )

        _log_debug("All checks OK")

        if "issue_refresh" in kwargs:
            issue_refresh = kwargs["issue_refresh"]
        else:
            issue_refresh = "offline_access" in grant.scope

        _response = {
            "token_type": token_type,
            "scope": grant.scope,
        }

        if "access_token" in _supports_minting:
            try:
                token = self._mint_token(
                    token_class="access_token",
                    grant=grant,
                    session_id=_session_info["session_id"],
                    client_id=_session_info["client_id"],
                    based_on=_based_on,
                    token_type=token_type,
                )
            except MintingNotAllowed as err:
                logger.warning(err)
            else:
                _response["access_token"] = token.value
                if token.expires_at:
                    _response["expires_in"] = token.expires_at - utc_time_sans_frac()

        if issue_refresh and "refresh_token" in _supports_minting:
            try:
                refresh_token = self._mint_token(
                    token_class="refresh_token",
                    grant=grant,
                    session_id=_session_info["session_id"],
                    client_id=_session_info["client_id"],
                    based_on=_based_on,
                )
            except MintingNotAllowed as err:
                logger.warning(err)
            else:
                _response["refresh_token"] = refresh_token.value

        # Since the grant content has changed, make sure it's stored.
        _mngr[_session_info["session_id"]] = grant

        # A plain OAuth2 authorization request may carry no scope at all.
        if "openid" in _authn_req.get("scope", []) and "id_token" in _supports_minting:
            try:
                _idtoken = self._mint_token(
                    token_class="id_token",
                    grant=grant,
                    session_id=_session_info["session_id"],
                    client_id=_session_info["client_id"],
                    based_on=_based_on,
                )
            except (JWEException, NoSuitableSigningKeys) as err:
                logger.warning(str(err))
                resp = self.error_cls(
                    error="invalid_request",
                    error_description="Could not sign/encrypt id_token",
                )
                return resp

            _response["id_token"] = _idtoken.value

        _based_on.register_usage()

        return _response