Exemplo n.º 1
0
 def create_handler(self):
     """Build the ``TokenHandler`` fixture and keep it on ``self.th``.

     The handler is created for ``https://example.com/as`` with a policy
     mapping token type -> client -> grant type -> number.
     """
     # presumably the numbers are lifetimes in seconds -- confirm against
     # TokenHandler
     client = 'https://example.org/rp'
     policy = {
         'access_token': {client: {'client_credentials': 1200}},
         'refresh_token': {client: {'client_credentials': 86400}},
     }
     self.th = TokenHandler('https://example.com/as', policy, keyjar=KEYJAR)
Exemplo n.º 2
0
 def create_handler(self):
     """Create ``self.th``, a ``TokenHandler`` for ``https://example.com/as``.

     A single client, ``https://example.org/rp``, is given one entry per
     token type for the ``client_credentials`` grant.
     """
     # NOTE(review): 1200 / 86400 look like lifetimes in seconds -- confirm
     access_policy = {"https://example.org/rp": {"client_credentials": 1200}}
     refresh_policy = {"https://example.org/rp": {"client_credentials": 86400}}
     self.th = TokenHandler(
         "https://example.com/as",
         {"access_token": access_policy, "refresh_token": refresh_policy},
         keyjar=KEYJAR,
     )
Exemplo n.º 3
0
 def create_handler(self):
     """Instantiate the ``TokenHandler`` used by the tests as ``self.th``.

     The policy argument is keyed by token type, then by client, then by
     grant type.
     """
     # presumably the values are lifetimes in seconds -- verify against
     # TokenHandler
     per_token_type = (('access_token', 1200), ('refresh_token', 86400))
     policy = {
         token_type: {'https://example.org/rp': {'client_credentials': value}}
         for token_type, value in per_token_type
     }
     self.th = TokenHandler('https://example.com/as', policy, keyjar=KEYJAR)
Exemplo n.º 4
0
    def __init__(self,
                 name,
                 sdb,
                 cdb,
                 authn_broker,
                 authz,
                 client_authn,
                 symkey=None,
                 urlmap=None,
                 iv=0,
                 default_scope="",
                 ca_bundle=None,
                 seed=b"",
                 client_authn_methods=None,
                 authn_at_registration="",
                 client_info_url="",
                 secret_lifetime=86400,
                 jwks_uri='',
                 keyjar=None,
                 capabilities=None,
                 verify_ssl=True,
                 baseurl='',
                 hostname='',
                 config=None,
                 behavior=None,
                 lifetime_policy=None,
                 **kwargs):
        """
        Initialise the provider and its token handler.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope`` and ``ca_bundle`` are passed
        straight to ``provider.Provider.__init__``.

        :param seed: Stored as ``self.seed``
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: If non-empty, must be a key in
            ``client_authn_methods``
        :param client_info_url: Stored as ``self.client_info_url``
        :param secret_lifetime: Stored as ``self.secret_lifetime``
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Optional overrides passed to
            ``provider_features`` as ``provider_config``
        :param verify_ssl: Stored and passed to a newly created ``KeyJar``
        :param baseurl: Base URL; defaults to ``name``
        :param hostname: Host name; defaults to ``socket.gethostname()``
        :param config: Stored as ``self.config`` (``{}`` when falsy)
        :param behavior: Stored as ``self.behavior`` (``{}`` when falsy)
        :param lifetime_policy: Per token type / grant type numbers; a
            built-in table is used when None
        :param kwargs: ``server_cls`` is forwarded to the base class,
            ``scopes`` overrides the default ``['offline_access']``
        :raises UnknownAuthnMethod: If ``authn_at_registration`` is set but
            is not among ``client_authn_methods``
        """

        if not name.endswith("/"):
            name += "/"

        # Only 'server_cls' is forwarded to the base class.
        if 'server_cls' in kwargs:
            args = {'server_cls': kwargs['server_cls']}
        else:
            args = {}

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **args)

        self.endp.extend([
            RegistrationEndpoint, ClientInfoEndpoint, RevocationEndpoint,
            IntrospectionEndpoint
        ])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # With the default client_authn_methods=None a bare membership
            # test raises TypeError; report it as an unknown method instead.
            if (client_authn_methods is None
                    or authn_at_registration not in client_authn_methods):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # Must be set before provider_features() is called below, which
        # reads self.scopes.
        self.scopes = kwargs.get('scopes', ['offline_access'])
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        self.token_policy = {'access_token': {}, 'refresh_token': {}}
        if lifetime_policy is None:
            # presumably lifetimes in seconds -- confirm against TokenHandler
            self.lifetime_policy = {
                'access_token': {
                    'code': 600,
                    'token': 120,
                    'implicit': 120,
                    'authorization_code': 600,
                    'client_credentials': 600,
                    'password': 600
                },
                'refresh_token': {
                    'code': 3600,
                    'token': 3600,
                    'implicit': 3600,
                    'authorization_code': 3600,
                    'client_credentials': 3600,
                    'password': 3600
                }
            }
        else:
            self.lifetime_policy = lifetime_policy

        # The handler keeps a reference to self.token_policy, which starts
        # out empty here.
        self.token_handler = TokenHandler(self.baseurl,
                                          self.token_policy,
                                          keyjar=self.keyjar)
Exemplo n.º 5
0
class Provider(provider.Provider):
    """
    A OAuth2 RP that knows all the OAuth2 extensions I've implemented
    """
    def __init__(self,
                 name,
                 sdb,
                 cdb,
                 authn_broker,
                 authz,
                 client_authn,
                 symkey=None,
                 urlmap=None,
                 iv=0,
                 default_scope="",
                 ca_bundle=None,
                 seed=b"",
                 client_authn_methods=None,
                 authn_at_registration="",
                 client_info_url="",
                 secret_lifetime=86400,
                 jwks_uri='',
                 keyjar=None,
                 capabilities=None,
                 verify_ssl=True,
                 baseurl='',
                 hostname='',
                 config=None,
                 behavior=None,
                 lifetime_policy=None,
                 **kwargs):
        """
        Initialise the provider and its token handler.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope`` and ``ca_bundle`` are passed
        straight to ``provider.Provider.__init__``.

        :param seed: Seed used when generating client secrets
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: If non-empty, the method clients must
            use when registering; must be a key in ``client_authn_methods``
        :param client_info_url: Path component used when building the
            ``registration_client_uri``
        :param secret_lifetime: Added to the issue time to get
            ``client_secret_expires_at``
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Optional overrides passed to
            ``provider_features`` as ``provider_config``
        :param verify_ssl: Stored and passed to a newly created ``KeyJar``
        :param baseurl: Base URL; defaults to ``name``
        :param hostname: Host name; defaults to ``socket.gethostname()``
        :param config: Stored as ``self.config`` (``{}`` when falsy)
        :param behavior: Stored as ``self.behavior`` (``{}`` when falsy)
        :param lifetime_policy: Per token type / grant type numbers; a
            built-in table is used when None
        :param kwargs: ``server_cls`` is forwarded to the base class,
            ``scopes`` overrides the default ``['offline_access']``
        :raises UnknownAuthnMethod: If ``authn_at_registration`` is set but
            is not among ``client_authn_methods``
        """

        if not name.endswith("/"):
            name += "/"

        # Only 'server_cls' is forwarded to the base class.
        if 'server_cls' in kwargs:
            args = {'server_cls': kwargs['server_cls']}
        else:
            args = {}

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **args)

        self.endp.extend([
            RegistrationEndpoint, ClientInfoEndpoint, RevocationEndpoint,
            IntrospectionEndpoint
        ])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # With the default client_authn_methods=None a bare membership
            # test raises TypeError; report it as an unknown method instead.
            if (client_authn_methods is None
                    or authn_at_registration not in client_authn_methods):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # Must be set before provider_features() is called below, which
        # reads self.scopes.
        self.scopes = kwargs.get('scopes', ['offline_access'])
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        # Filled per client by set_token_policy().
        self.token_policy = {'access_token': {}, 'refresh_token': {}}
        if lifetime_policy is None:
            # presumably lifetimes in seconds -- confirm against TokenHandler
            self.lifetime_policy = {
                'access_token': {
                    'code': 600,
                    'token': 120,
                    'implicit': 120,
                    'authorization_code': 600,
                    'client_credentials': 600,
                    'password': 600
                },
                'refresh_token': {
                    'code': 3600,
                    'token': 3600,
                    'implicit': 3600,
                    'authorization_code': 3600,
                    'client_credentials': 3600,
                    'password': 3600
                }
            }
        else:
            self.lifetime_policy = lifetime_policy

        # The handler keeps a reference to self.token_policy, so later
        # per-client additions are visible to it.
        self.token_handler = TokenHandler(self.baseurl,
                                          self.token_policy,
                                          keyjar=self.keyjar)

    @staticmethod
    def _uris_to_tuples(uris):
        """
        Split every URI into a ``(base, query)`` pair.

        :param uris: Iterable of URI strings
        :return: List of 2-tuples; the query part is ``""`` whenever
            ``splitquery`` yields a falsy query
        """
        split = (splitquery(uri) for uri in uris)
        return [(base, query or "") for base, query in split]

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load the client's keys into the key jar and add its secret as
        symmetric keys.

        :param request: The registration request; handed to
            ``self.keyjar.load_keys``
        :param client_id: Key jar entry to populate
        :param client_secret: Added as two "oct" keys, one with use "ver"
            and one with use "sig"
        :return: A 400 ``Response`` when loading the keys raised, otherwise
            None
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                key_count = len(self.keyjar[client_id])
                logger.debug("Found {} keys for client_id={}".format(
                    key_count, client_id))
            except KeyError:
                # No entry for this client; nothing to report.
                pass
        except Exception as err:
            logger.error("Failed to load client keys: {}".format(
                sanitize(request.to_dict())))
            logger.error("%s", err)
            failure = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(failure.to_json(),
                            content="application/json",
                            status_code="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        secret_bundle = KeyBundle([
            {"kty": "oct", "key": client_secret, "use": use}
            for use in ("ver", "sig")
        ])
        try:
            self.keyjar[client_id].append(secret_bundle)
        except KeyError:
            self.keyjar[client_id] = [secret_bundle]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Run every restriction check against the client information.

        :param cinfo: The client information
        :param restrictions: Mapping from restriction name to its argument;
            each name is resolved through ``restrict.factory``
        :raises RestrictionError: With the first truthy result a check
            returns
        """
        for fname, arg in restrictions.items():
            violation = restrict.factory(fname)(arg, cinfo)
            if violation:
                raise RestrictionError(violation)

    def set_token_policy(self, cid, cinfo):
        for ttyp in ['access_token', 'refresh_token']:
            pol = {}
            for rgtyp in ['response_type', 'grant_type']:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    pass
                else:
                    for typ in rtyp:
                        try:
                            pol[typ] = self.lifetime_policy[ttyp][typ]
                        except KeyError:
                            pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Register a new client and store its information in ``self.cdb``.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client
        :return: The client_id
        """
        # NOTE(review): ``restrictions`` is not used in this method, and the
        # value returned by load_keys() is ignored -- confirm both are
        # intended.
        info = request.to_dict()

        self.match_client_request(info)

        # Draw random ids until one is not already in the client database.
        new_id = rndstr(12)
        while new_id in self.cdb:
            new_id = rndstr(12)

        info["client_id"] = new_id
        info["client_secret"] = secret(self.seed, new_id)
        info["client_id_issued_at"] = utc_time_sans_frac()
        info["client_secret_expires_at"] = (
            utc_time_sans_frac() + self.secret_lifetime)

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            info["registration_access_token"] = rndstr(32)
            uri_parts = (self.name, self.client_info_url,
                         ClientInfoEndpoint.etype, new_id)
            info["registration_client_uri"] = (
                "%s%s%s?client_id=%s" % uri_parts)

        if "redirect_uris" in request:
            # Stored as (base, query) tuples.
            info["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.load_keys(request, new_id, info["client_secret"])

        try:
            registration_rules = self.behavior['client_registration']
        except KeyError:
            pass
        else:
            self.verify_correct(info, registration_rules)

        self.set_token_policy(new_id, info)
        self.cdb[new_id] = info

        return new_id

    def match_client_request(self, request):
        """
        Check the requested client preferences against ``self.capabilities``.

        ``response_types`` entries are compared as sets of space-separated
        words; a string preference must be contained in the capability
        value; any other preference must be a subset of it.

        :param request: The client registration information
        :raises CapabilitiesMisMatch: On the first preference the provider
            does not support
        """
        for pref, prov in PREFERENCE2PROVIDER.items():
            if pref not in request:
                continue

            if pref == "response_types":
                for val in request[pref]:
                    wanted = set(val.split(" "))
                    supported = any(
                        wanted == set(cv.split(' '))
                        for cv in self.capabilities[prov])
                    if not supported:
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(pref))
            elif isinstance(request[pref], str):
                if request[pref] not in self.capabilities[prov]:
                    raise CapabilitiesMisMatch('Not allowed {}'.format(pref))
            elif not set(request[pref]).issubset(
                    set(self.capabilities[prov])):
                raise CapabilitiesMisMatch('Not allowed {}'.format(pref))

    def client_info(self, client_id):
        """
        Build the client information response for a registered client.

        :param client_id: Key into ``self.cdb``
        :return: A ``BadRequest`` carrying an ``invalid_client`` error when
            ``valid_client_info`` rejects the stored information, otherwise
            a JSON ``Response`` with the ``ClientInfoResponse``
        """
        stored = self.cdb[client_id].copy()
        if not valid_client_info(stored):
            problem = ErrorResponse(error="invalid_client",
                                    error_description="Invalid client secret")
            return BadRequest(problem.to_json(), content="application/json")

        # Redirect URIs are stored as (base, query) tuples; turn them back
        # into strings for the response.
        try:
            stored["redirect_uris"] = self._tuples_to_uris(
                stored["redirect_uris"])
        except KeyError:
            pass

        reply = ClientInfoResponse(**stored)
        return Response(reply.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                    "client_id_issued_at", "client_secret_expires_at",
                    "registration_access_token", "registration_client_uri"
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :return:
        """

        if not client_id:
            client_id = get_client_id(self.cdb, areq,
                                      environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        return {}

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: Must contain ``request`` (the JSON request body);
            ``environ`` is also read when authentication at registration is
            configured
        :return: A Response instance: the client information on success, a
            ``BadRequest`` with a ``ClientRegistrationError`` body when the
            request is rejected, or ``Unauthorized`` when client
            authentication fails
        """

        def reject(code, exc):
            # Wrap the exception text in a registration error response.
            body = ClientRegistrationError(error=code,
                                           error_description="%s" % exc)
            return BadRequest(body.to_json(), content="application/json")

        reg_request = RegistrationRequest().deserialize(kwargs['request'],
                                                        "json")
        try:
            reg_request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            return reject("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return reject("invalid_client_metadata", err)

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], reg_request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        restrictions = {}
        if 'parsed_software_statement' in reg_request:
            for statement in reg_request['parsed_software_statement']:
                restrictions.update(
                    self.consume_software_statement(statement))
            # The statements are consumed; drop them from the request.
            del reg_request['software_statement']
            del reg_request['parsed_software_statement']

        try:
            new_client_id = self.create_new_client(reg_request, restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            return reject("invalid_client_metadata", err)

        return self.client_info(new_client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Read, update or delete a client registration.

        The operation is selected by the HTTP method: GET returns the
        client information, PUT updates it and returns the result, DELETE
        removes the client. Any other method yields None.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; ``query``, ``environ`` and
            ``request`` are read
        :return: A Response instance
        """

        def reject(code, exc):
            # Wrap the exception text in a registration error response.
            body = ClientRegistrationError(error=code,
                                           error_description="%s" % exc)
            return BadRequest(body.to_json(), content="application/json")

        params = compact(parse_qs(kwargs['query']))
        try:
            cid = params["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if cid not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'],
                               kwargs['request'],
                               "bearer_header",
                               client_id=cid)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(cid)

        if method == "PUT":
            try:
                update = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                update.verify()
            except InvalidRedirectUri as err:
                return reject("invalid_redirect_uri", err)
            except (MissingPage, VerificationError) as err:
                return reject("invalid_client_metadata", err)

            try:
                self.client_info_update(cid, update)
                return self.client_info(cid)
            except ModificationForbidden:
                return Forbidden()

        if method == "DELETE":
            try:
                del self.cdb[cid]
            except KeyError:
                return Unauthorized()
            return NoContent()

        # Any other HTTP method yields None.
        return None

    def provider_features(self,
                          pcr_class=ServerMetadata,
                          provider_config=None):
        """
        Specifies what the server capabilities are.

        Starts from ``CAPABILITIES``, adds the supported scopes and, for the
        token, revocation and introspection endpoints, the signing
        algorithms (every ``jws.SIGNER_ALGS`` key except ``none``) and the
        supported authentication methods.

        :param pcr_class: Class instantiated to hold the information
        :param provider_config: Optional mapping applied last with
            ``update``
        :return: ``pcr_class`` instance
        """
        features = pcr_class(**CAPABILITIES)
        features["scopes_supported"] = self.scopes

        algs = list(jws.SIGNER_ALGS.keys())
        algs.remove('none')
        algs.sort(key=cmp_to_key(sort_sign_alg))

        for typ in ("token", "revocation", "introspection"):
            alg_key = "{}_endpoint_auth_signing_alg_values_supported".format(
                typ)
            method_key = "{}_endpoint_auth_methods_supported".format(typ)
            features[alg_key] = algs
            features[method_key] = AUTH_METHODS_SUPPORTED

        if provider_config:
            features.update(provider_config)

        return features

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually
        can be done by this implementation.

        :param capabilities: The asked for capabilities as a dictionary
        or a ProviderConfigurationResponse instance. The later can be
        treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if isinstance(val, str):
                try:
                    if val in _pinfo[key]:
                        continue
                    else:
                        return False
                except KeyError:
                    return False

        return True

    def create_providerinfo(self,
                            pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response
        :param pcr_class:
        :param setup:
        :return:
        """

        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            _provider_info['{}_endpoint'.format(endp.etype)] = os.path.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider information as JSON.

        :param kwargs: If ``handle`` is present, its first element is
            checked against ``STR`` and may lead to a cookie header being
            added
        :return: A Response instance; when anything fails, the response
            carries the formatted traceback instead
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(), ))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                (key, _) = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(),
                            content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # NOTE(review): this sends the traceback to the caller with the
            # default status -- consider a generic error body instead.
            # "html/text" is not a valid media type; the intended value is
            # "text/html".
            resp = Response(message, content="text/html")

        return resp

    @staticmethod
    def verify_code_challenge(code_verifier,
                              code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into ``CC_METHOD`` selecting the
            hash applied to the verifier
        :return: True when the transformed verifier equals the challenge,
            otherwise a 401 ``Response`` carrying an ``invalid_request``
            error
        """
        hasher = CC_METHOD[code_challenge_method]
        digest = hasher(code_verifier.encode('ascii')).digest()
        computed = b64e(digest).decode('ascii')
        if computed == code_challenge:
            return True

        logger.error('PCKE Code Challenge check failed')
        failure = TokenErrorResponse(error="invalid_request",
                                     error_description="PCKE check failed")
        return Response(failure.to_json(),
                        content="application/json",
                        status_code=401)

    def do_access_token_response(self,
                                 access_token,
                                 atinfo,
                                 state,
                                 refresh_token=None):
        """
        Build an ``AccessTokenResponse`` for a freshly issued token.

        :param access_token: The access token
        :param atinfo: Token information; ``exp`` is required and used as
            ``expires_in``, ``scope`` is copied when present
        :param state: The state value to echo back
        :param refresh_token: Optional refresh token; included when truthy
        :return: An ``AccessTokenResponse`` holding the fields that
            ``by_schema`` lets through
        """
        fields = dict(access_token=access_token,
                      expires_in=atinfo['exp'],
                      token_type='bearer',
                      state=state)
        try:
            fields['scope'] = atinfo['scope']
        except KeyError:
            pass

        if refresh_token:
            fields['refresh_token'] = refresh_token

        allowed = by_schema(AccessTokenResponse, **fields)
        return AccessTokenResponse(**allowed)

    def code_grant_type(self, areq):
        """
        Exchange an authorization code for tokens.

        The code must be known to ``self.sdb``. The PKCE verifier, ``state``
        and ``redirect_uri`` are checked against what was stored with the
        authorization request before the code is upgraded to a token.

        :param areq: The access token request
        :return: A Response instance: the token response on success, an
            error response otherwise
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Unknown access grant")
            return Response(err.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        authzreq = json.loads(_info['authzreq'])
        if 'code_verifier' in areq:
            _method = authzreq.get('code_challenge_method', 'S256')

            resp = self.verify_code_challenge(areq['code_verifier'],
                                              authzreq['code_challenge'],
                                              _method)
            # verify_code_challenge returns True on success and an error
            # Response on failure, so only a non-True result ends the
            # request here. A plain truthiness test would also return on
            # success.
            if resp is not True:
                return resp
        elif 'code_challenge' in authzreq:
            # A challenge was registered with the authorization request, so
            # the verifier is mandatory; accepting the code without it
            # would bypass PKCE.
            logger.error('Missing PKCE code_verifier')
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Missing code_verifier")
            return Response(err.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        if 'state' in areq:
            if self.sdb[areq['code']]['state'] != areq['state']:
                logger.error('State value mismatch')
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. A request
        # that omits it is treated as a mismatch rather than a KeyError.
        if "redirect_uri" in _info and (
                "redirect_uri" not in areq
                or areq["redirect_uri"] != _info["redirect_uri"]):
            logger.error('Redirect_uri mismatch')
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        # A refresh token is only issued for the 'code' response type when
        # 'offline_access' was among the requested scopes.
        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(areq["code"],
                                               issue_refresh=issue_refresh)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Access grant used")
            return Response(err.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(),
                        content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def client_credentials_grant_type(self, areq):
        """
        Issue tokens for the ``client_credentials`` grant.

        :param areq: The access token request; ``client_id``, ``scope`` and
            ``state`` are read
        :return: A JSON ``Response`` with the access token response; it
            includes a refresh token unless the token handler raises
            ``NotAllowed``
        """
        handler = self.token_handler
        token = handler.get_access_token(areq['client_id'],
                                         scope=areq['scope'],
                                         grant_type='client_credentials')
        token_info = handler.token_factory.get_info(token)
        try:
            refresh = handler.get_refresh_token(self.baseurl,
                                                token_info['access_token'],
                                                'client_credentials')
        except NotAllowed:
            # No refresh token for this client/grant combination.
            refresh = None

        atr = self.do_access_token_response(token, token_info, areq['state'],
                                            refresh)
        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """
        Issue tokens for the ``password`` grant.

        :param areq: The access token request; ``client_id``, ``scope`` and
            ``state`` are read
        :return: A JSON ``Response`` with the access token response; it
            includes a refresh token unless the token handler raises
            ``NotAllowed``
        """
        handler = self.token_handler
        token = handler.get_access_token(areq['client_id'],
                                         scope=areq['scope'],
                                         grant_type='password')
        token_info = handler.token_factory.get_info(token)
        try:
            refresh = handler.get_refresh_token(self.baseurl,
                                                token_info['access_token'],
                                                'password')
        except NotAllowed:
            # No refresh token for this client/grant combination.
            refresh = None

        atr = self.do_access_token_response(token, token_info, areq['state'],
                                            refresh)
        return Response(atr.to_json(), content="application/json")

    def refresh_token_grant_type(self, areq):
        """
        Handle the ``refresh_token`` grant.

        :param areq: The access token request; its ``access_token`` value
            is passed to the token handler's ``refresh_access_token``
        :return: A JSON ``Response`` with the access token response
        """
        refreshed = self.token_handler.refresh_access_token(
            self.baseurl, areq['access_token'], 'refresh_token')
        fields = by_schema(AccessTokenResponse, **refreshed)
        reply = AccessTokenResponse(**fields)
        return Response(reply.to_json(), content="application/json")

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens

        :param authn: Client authentication information, passed to
            ``self.client_authn``
        :param kwargs: Must contain ``request``, the urlencoded request body
        :return: A Response instance
        :raises UnSupported: For a grant type other than
            ``authorization_code``, ``client_credentials``, ``password`` or
            ``refresh_token``
        """

        logger.debug("- token -")
        raw_request = kwargs["request"]
        logger.debug("body: %s" % raw_request)

        areq = AccessTokenRequest().deserialize(raw_request, "urlencoded")

        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as exc:
            logger.error(exc)
            failure = TokenErrorResponse(error="unauthorized_client",
                                         error_description="%s" % exc)
            return Response(failure.to_json(),
                            content="application/json",
                            status_code=401)

        logger.debug("AccessTokenRequest: %s" % areq)

        grant = areq["grant_type"]
        handlers = {
            "authorization_code": self.code_grant_type,
            'client_credentials': self.client_credentials_grant_type,
            'password': self.password_grant_type,
            'refresh_token': self.refresh_token_grant_type,
        }
        handle = handlers.get(grant)
        if handle is None:
            raise UnSupported('grant_type: {}'.format(grant))
        return handle(areq)

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        Export my keys and remember where they can be fetched.

        The value returned by ``key_export`` is stored in ``self.jwks_uri``;
        the method itself returns nothing.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        """
        export_options = dict(fqdn=self.hostname, sig=sig, enc=enc)
        self.jwks_uri = key_export(self.baseurl, local_path, vault,
                                   self.keyjar, **export_options)

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        # simple rules: if client_id in azp or aud it's allow to introspect
        # to revoke it has to be in azr
        allow = False
        if endpoint == 'revocation_endpoint':
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif len(token_info['aud']) == 1 and token_info['aud'] == [
                    client_id
            ]:
                allow = True
        else:  # has to be introspection endpoint
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif 'aud' in token_info:
                if client_id in token_info['aud']:
                    allow = True
        return allow

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token named in a request.

        The token is searched for under the type given by
        'token_type_hint' first (when present) and then under the
        remaining known token types: RFC 7009 and RFC 7662 define the hint
        as an optimisation only, so a wrong hint must not hide a token.

        :param authn: Client authentication information
        :param req: The revocation or introspection request; 'token' is
            read from it and 'token_type_hint' is optional
        :param endpoint: Name of the endpoint being served, passed on to
            token_access()
        :return: A (client_id, token_type, token info) tuple, or a Response
            when authentication fails, the token cannot be found or the
            client is not allowed to act on it
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as failure:
            logger.error(failure)
            error_msg = TokenErrorResponse(error="unauthorized_client",
                                           error_description="%s" % failure)
            return Response(error_msg.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        known_types = ['access_token', 'refresh_token']
        try:
            hint = req['token_type_hint']
        except KeyError:
            candidates = known_types
        else:
            # Hinted type first, then every other type we know about.
            candidates = [hint] + [typ for typ in known_types if typ != hint]

        for token_type in candidates:
            try:
                _info = self.sdb.token_factory[token_type].get_info(
                    req['token'])
            except Exception:
                # No such factory, or the token is not of this type.
                continue
            break
        else:
            return self._return_inactive()

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    @staticmethod
    def _return_inactive():
        """
        Build the answer used when a token cannot be found.

        :return: A JSON Response whose TokenIntrospectionResponse has
            active=False
        """
        inactive = TokenIntrospectionResponse(active=False)
        payload = inactive.to_json()
        return Response(payload, content="application/json")

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7009 allows a client to invalidate an access or refresh
        token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs: Not used
        :return: Response 'OK' once the token is invalidated, BadRequest if
            invalidating raises KeyError, or whatever Response
            get_token_info() produced
        """

        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, 'revocation_endpoint')
        if isinstance(outcome, Response):
            # The lookup already produced the answer for the client.
            return outcome

        client_id, token_type, _ = outcome

        logger.info('{} token revocation: {}'.format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()

        return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7662

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs: Not used
        :return: A JSON Response with the introspection result, or whatever
            Response get_token_info() produced
        """

        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, 'introspection_endpoint')
        if isinstance(outcome, Response):
            return outcome

        client_id, token_type, _info = outcome

        logger.info('{} token introspection: {}'.format(
            client_id, tir.to_dict()))

        factory = self.sdb.token_factory[token_type]
        still_valid = factory.is_valid(_info)
        ir = TokenIntrospectionResponse(active=still_valid,
                                        **_info.to_dict())

        # Drop what does not belong in the response.
        ir.weed()

        return Response(ir.to_json(), content="application/json")
Exemplo n.º 6
0
class TestTokenHandler(object):
    """Tests for TokenHandler with one client allowed 'client_credentials'."""

    @pytest.fixture(autouse=True)
    def create_handler(self):
        """Create a fresh handler for every test."""
        token_policy = {
            'access_token': {
                'https://example.org/rp': {
                    'client_credentials': 1200
                }
            },
            'refresh_token': {
                'https://example.org/rp': {
                    'client_credentials': 86400
                }
            }
        }
        self.th = TokenHandler('https://example.com/as', token_policy,
                               keyjar=KEYJAR)

    def test_construct_access_token(self):
        """An access token can be issued and carries the expected claims."""
        token = self.th.get_access_token('https://example.org/rp', 'foo bar',
                                         'client_credentials')

        assert token

        info = self.th.token_factory.get_info(token)

        assert _eq(list(info.keys()),
                   ['jti', 'scope', 'exp', 'iss', 'aud', 'iat', 'kid', 'azp'])

    def test_construct_access_token_fail(self):
        """Unknown clients and disallowed grant types are refused."""
        # pytest.raises makes the test fail when NO exception is raised;
        # a bare try/except/pass would let that case slip through.

        # Unknown client
        with pytest.raises(NotAllowed):
            self.th.get_access_token('https://example.com/rp', 'foo bar',
                                     'client_credentials')
        # wrong grant_type
        with pytest.raises(NotAllowed):
            self.th.get_access_token('https://example.org/rp', 'foo bar',
                                     'implicit')

    def test_from_access_to_refresh_token(self):
        """An access token can be exchanged through refresh_access_token."""
        token = self.th.get_access_token('https://example.org/rp', 'foo bar',
                                         'client_credentials')

        refresh_token = self.th.refresh_access_token('https://example.org/rp',
                                                     token,
                                                     'client_credentials')

        assert refresh_token

    def test_construct_refresh_token(self):
        """A refresh token carries the expected claims and maps to its sid."""
        sid = '1234'
        rtoken = self.th.get_refresh_token('https://example.org/rp',
                                           grant_type='client_credentials',
                                           sid=sid)

        info = self.th.token_factory.get_info(rtoken)

        assert _eq(list(info.keys()),
                   ['jti', 'exp', 'iss', 'aud', 'iat', 'kid', 'azp'])

        assert self.th.refresh_token_factory.db[info['jti']] == sid
Exemplo n.º 7
0
class Provider(provider.Provider):
    """
    A OAuth2 RP that knows all the OAuth2 extensions I've implemented
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed=b"", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400, jwks_uri='', keyjar=None,
                 capabilities=None, verify_ssl=True, baseurl='', hostname='',
                 config=None, behavior=None, lifetime_policy=None, **kwargs):
        """
        Set up the provider.

        The first eleven arguments are passed on to provider.Provider; the
        rest configure what this class adds. Of the extra keyword
        arguments only 'server_cls' (forwarded to the base class) and
        'scopes' (default ['offline_access']) are read.
        """
        # The provider name always ends with a slash.
        if not name.endswith("/"):
            name += "/"

        # Forward 'server_cls' to the base class only when it was given.
        base_kwargs = {}
        if 'server_cls' in kwargs:
            base_kwargs['server_cls'] = kwargs['server_cls']

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **base_kwargs)

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint,
                          RevocationEndpoint, IntrospectionEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration and \
                authn_at_registration not in client_authn_methods:
            raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl

        # self.scopes must be in place before provider_features() runs.
        self.scopes = kwargs.get('scopes', ['offline_access'])

        if keyjar is None:
            keyjar = KeyJar(verify_ssl=self.verify_ssl)
        self.keyjar = keyjar

        feature_kwargs = {}
        if capabilities:
            feature_kwargs['provider_config'] = capabilities
        self.capabilities = self.provider_features(**feature_kwargs)

        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        self.token_policy = {'access_token': {}, 'refresh_token': {}}

        if lifetime_policy is not None:
            self.lifetime_policy = lifetime_policy
        else:
            # Default token lifetimes per response/grant type.
            grant_kinds = ('code', 'token', 'implicit', 'authorization_code',
                           'client_credentials', 'password')
            access_lifetimes = dict.fromkeys(grant_kinds, 600)
            access_lifetimes['token'] = 120
            access_lifetimes['implicit'] = 120
            self.lifetime_policy = {
                'access_token': access_lifetimes,
                'refresh_token': dict.fromkeys(grant_kinds, 3600),
            }

        # The handler gets the same token_policy dictionary that
        # set_token_policy() fills in later.
        self.token_handler = TokenHandler(self.baseurl, self.token_policy,
                                          keyjar=self.keyjar)

    @staticmethod
    def _uris_to_tuples(uris):
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys into the key jar and add its secret as
        symmetric keys.

        :param request: The request the keys are loaded from
        :param client_id: The client the keys belong to
        :param client_secret: Secret stored as 'oct' keys with use 'ver'
            and 'sig'
        :return: A 400 JSON Response if loading the keys fails, otherwise
            None
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                loaded = ",".join("%s" % key for key in self.keyjar[client_id])
                logger.debug("keys for %s: [%s]" % (client_id, loaded))
            except KeyError:
                pass
        except Exception as failure:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            logger.error("%s", failure)
            error_msg = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % failure)
            return Response(error_msg.to_json(), content="application/json",
                            status="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        secret_keys = [{"kty": "oct", "key": client_secret, "use": usage}
                       for usage in ("ver", "sig")]
        bundle = KeyBundle(secret_keys)
        try:
            self.keyjar[client_id].append(bundle)
        except KeyError:
            self.keyjar[client_id] = [bundle]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Check client information against a set of restrictions.

        :param cinfo: The client information
        :param restrictions: Mapping of restriction name to its argument;
            each name is resolved through restrict.factory()
        :raises RestrictionError: With the first truthy result a
            restriction function returns
        """
        for restriction_name, argument in restrictions.items():
            violation = restrict.factory(restriction_name)(argument, cinfo)
            if violation:
                raise RestrictionError(violation)

    def set_token_policy(self, cid, cinfo):
        for ttyp in ['access_token', 'refresh_token']:
            pol = {}
            for rgtyp in ['response_type', 'grant_type']:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    pass
                else:
                    for typ in rtyp:
                        try:
                            pol[typ] = self.lifetime_policy[ttyp][typ]
                        except KeyError:
                            pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Register a new client and store its information in the client
        database.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client (not consulted in
            this method; the checks applied come from
            self.behavior['client_registration'])
        :return: The client_id
        """

        client_data = request.to_dict()

        self.match_client_request(client_data)

        # create new id and secret; redraw while the id is already taken
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_data["client_id"] = client_id
        client_data["client_secret"] = secret(self.seed, client_id)
        client_data["client_id_issued_at"] = utc_time_sans_frac()
        expires_at = utc_time_sans_frac() + self.secret_lifetime
        client_data["client_secret_expires_at"] = expires_at

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            client_data["registration_access_token"] = rndstr(32)
            client_uri = "%s%s%s?client_id=%s" % (
                self.name, self.client_info_url, ClientInfoEndpoint.etype,
                client_id)
            client_data["registration_client_uri"] = client_uri

        if "redirect_uris" in request:
            client_data["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.load_keys(request, client_id, client_data["client_secret"])

        try:
            registration_rules = self.behavior['client_registration']
        except KeyError:
            registration_rules = None
        else:
            self.verify_correct(client_data, registration_rules)

        self.set_token_policy(client_id, client_data)
        self.cdb[client_id] = client_data

        return client_id

    def match_client_request(self, request):
        """
        Check that what a client asks for is within my capabilities.

        :param request: The client's registration information
        :raises CapabilitiesMisMatch: If a requested preference is not
            covered by the corresponding capability
        """
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            if _pref not in request:
                continue

            if _pref == "response_types":
                # Each requested combination must equal a supported one,
                # regardless of the order of its space separated parts.
                for val in request[_pref]:
                    wanted = set(val.split(" "))
                    supported = any(wanted == set(cv.split(' '))
                                    for cv in self.capabilities[_prov])
                    if not supported:
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(_pref))
            elif isinstance(request[_pref], six.string_types):
                if request[_pref] not in self.capabilities[_prov]:
                    raise CapabilitiesMisMatch(
                        'Not allowed {}'.format(_pref))
            elif not set(request[_pref]).issubset(
                    set(self.capabilities[_prov])):
                raise CapabilitiesMisMatch(
                    'Not allowed {}'.format(_pref))

    def client_info(self, client_id):
        """
        Produce the client information response for a registered client.

        :param client_id: The client identifier
        :return: A JSON Response with the ClientInfoResponse, or a
            BadRequest when valid_client_info() rejects the stored data
        """
        stored = self.cdb[client_id].copy()
        if not valid_client_info(stored):
            err = ErrorResponse(
                error="invalid_client",
                error_description="Invalid client secret")
            return BadRequest(err.to_json(), content="application/json")

        # redirect_uris are kept as (url, query) tuples; turn them back
        # into plain URIs for the response.
        if "redirect_uris" in stored:
            stored["redirect_uris"] = self._tuples_to_uris(
                stored["redirect_uris"])

        msg = ClientInfoResponse(**stored)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                try:
                    assert value == _cinfo[key]
                except AssertionError:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in ["client_id_issued_at", "client_secret_expires_at",
                       "registration_access_token", "registration_client_uri"]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Authenticate a client with the named authentication method.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: The client identifier; when empty it is derived
            from the request and the HTTP_AUTHORIZATION header
        :return: What the method's verify() returns
        :raises UnSupported: If the method is not among
            self.client_authn_methods
        """

        client_id = client_id or get_client_id(
            self.cdb, areq, environ["HTTP_AUTHORIZATION"])

        try:
            authn_cls = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()

        verifier = authn_cls(self)
        return verifier.verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """
        Turn a software statement into client restrictions.

        This implementation ignores the statement.

        :param software_statement: The parsed software statement
        :return: An empty dictionary
        """
        return dict()

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: extra keyword arguments; 'request' holds the JSON
            request and 'environ' is read when authentication at
            registration is configured
        :return: A Response instance
        """

        def rejected(error_code, reason):
            # The 400 answer shared by every registration failure.
            msg = ClientRegistrationError(error=error_code,
                                          error_description="%s" % reason)
            return BadRequest(msg.to_json(), content="application/json")

        _request = RegistrationRequest().deserialize(kwargs['request'], "json")
        try:
            _request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            return rejected("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return rejected("invalid_client_metadata", err)

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], _request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}
        if 'parsed_software_statement' in _request:
            for statement in _request['parsed_software_statement']:
                client_restrictions.update(
                    self.consume_software_statement(statement))
            # The statements are consumed; keep them out of the stored info.
            del _request['software_statement']
            del _request['parsed_software_statement']

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            return rejected("invalid_client_metadata", err)

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods

        GET returns the client information, PUT replaces it and DELETE
        removes the client.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; 'query', 'environ' and 'request'
            are read
        :return: A Response instance
        """

        _query = parse_qs(kwargs['query'])
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # Explicit check: an assert here would vanish under "python -O"
        # and let requests for unknown clients through.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'], kwargs['request'],
                               "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

        # Any other HTTP method used to fall through and return None,
        # which is not a usable response.
        return BadRequest("Unsupported method")

    def provider_features(self, pcr_class=ServerMetadata, provider_config=None):
        """
        Specifies what the server capabilities are.

        :param pcr_class: Class instantiated from CAPABILITIES to hold the
            result
        :param provider_config: Optional values applied last, on top of the
            computed ones
        :return: The populated pcr_class instance
        """

        features = pcr_class(**CAPABILITIES)
        features["scopes_supported"] = self.scopes

        # Every known signing algorithm except 'none', in preference order.
        signing_algs = list(jws.SIGNER_ALGS.keys())
        signing_algs.remove('none')
        signing_algs.sort(key=cmp_to_key(sort_sign_alg))

        for endpoint_kind in ("token", "revocation", "introspection"):
            alg_key = "{}_endpoint_auth_signing_alg_values_supported".format(
                endpoint_kind)
            method_key = "{}_endpoint_auth_methods_supported".format(
                endpoint_kind)
            features[alg_key] = signing_algs
            features[method_key] = AUTH_METHODS_SUPPORTED

        if provider_config:
            features.update(provider_config)

        return features

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually
        can be done by this implementation.

        Only string values are checked; values of any other type are
        accepted as they are.

        :param capabilities: The asked for capabilities as a dictionary
        or a ProviderConfigurationResponse instance. The later can be
        treated as a dictionary.
        :return: True or False
        """
        supported = self.provider_features()
        for key, val in capabilities.items():
            if not isinstance(val, six.string_types):
                continue

            try:
                offered = val in supported[key]
            except KeyError:
                # Nothing is known about this capability at all.
                return False

            if not offered:
                return False

        return True

    def create_providerinfo(self, pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response

        The returned object is self.capabilities itself, updated in place
        with the jwks_uri, one '<etype>_endpoint' URL per endpoint, the
        issuer and the version.

        :param pcr_class: Class whose c_param names the keys that may be
            taken over from setup
        :param setup: Optional dictionary with values overriding the
            computed ones
        :return: The provider information
        """
        # Endpoint addresses are URLs, so they must be joined with '/'
        # on every platform; os.path.join would use '\\' on Windows.
        import posixpath

        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            # _log_info("# %s, %s" % (endp, endp.name))
            _provider_info['{}_endpoint'.format(endp.etype)] = posixpath.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider information.

        :param kwargs: When 'handle' is present it is unpacked as
            (key, timestamp) and may lead to a cookie header being added
        :return: A JSON Response with the provider information, or a
            Response carrying the formatted traceback if anything fails
        """
        logger.info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            logger.info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                key, _timestamp = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    headers.append(self.cookie_func(key, self.cookie_name,
                                                    "pinfo", self.sso_ttl))

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            resp = Response(message, content="html/text")

        return resp

    @staticmethod
    def verify_code_challenge(code_verifier, code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into CC_METHOD selecting the hash
        :return: True when the challenge matches, otherwise a JSON error
            Response with status "401 Unauthorized"
        """
        hasher = CC_METHOD[code_challenge_method]
        # NOTE(review): the hex digest (not the raw digest) is what gets
        # base64 encoded here - confirm this matches what clients send.
        hex_digest = hasher(code_verifier.encode()).hexdigest()
        expected = b64e(hex_digest.encode()).decode()
        if expected == code_challenge:
            return True

        logger.error('PCKE Code Challenge check failed')
        err = TokenErrorResponse(error="invalid_request",
                                 error_description="PCKE check failed")
        return Response(err.to_json(), content="application/json",
                        status="401 Unauthorized")

    def do_access_token_response(self, access_token, atinfo, state,
                                 refresh_token=None):
        """
        Assemble the access token response message.

        :param access_token: The access token
        :param atinfo: Information about the access token; 'exp' is
            required and 'scope' is used when present
        :param state: The state value to echo
        :param refresh_token: Optional refresh token, included when truthy
        :return: An AccessTokenResponse instance
        """
        # NOTE(review): 'expires_in' is filled from the token's 'exp'
        # value as it stands - confirm that is a lifetime in seconds.
        fields = dict(access_token=access_token, expires_in=atinfo['exp'],
                      token_type='bearer', state=state)

        try:
            scope = atinfo['scope']
        except KeyError:
            pass
        else:
            fields['scope'] = scope

        if refresh_token:
            fields['refresh_token'] = refresh_token

        return AccessTokenResponse(**by_schema(AccessTokenResponse, **fields))

    def code_grant_type(self, areq):
        """
        Handle a token request that uses the 'authorization_code' grant.

        :param areq: The token request
        :return: A Response: the access token response on success, or an
            error response when the code is unknown or already used, when
            the PKCE, state or redirect_uri check fails, or when
            token_scope_check() returns a response
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Unknown access grant")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        authzreq = json.loads(_info['authzreq'])
        # NOTE(review): a request without 'code_verifier' skips the PKCE
        # check even if the authorization request carried a challenge -
        # confirm whether that is intended.
        if 'code_verifier' in areq:
            try:
                _method = authzreq['code_challenge_method']
            except KeyError:
                _method = 'S256'

            verdict = self.verify_code_challenge(areq['code_verifier'],
                                                 authzreq['code_challenge'],
                                                 _method)
            # verify_code_challenge() returns True on success and an error
            # Response on failure. Both are truthy, so only stop here when
            # the result is not the success marker.
            if verdict is not True:
                return verdict

        if 'state' in areq:
            if self.sdb[areq['code']]['state'] != areq['state']:
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. This is an
        # explicit check because assert statements vanish under
        # "python -O", which would silently disable it.
        if "redirect_uri" in _info:
            try:
                given_uri = areq["redirect_uri"]
            except KeyError:
                given_uri = None
            if given_uri != _info["redirect_uri"]:
                logger.error('redirect_uri mismatch')
                err = TokenErrorResponse(
                    error="invalid_grant",
                    error_description="redirect_uri mismatch")
                return Response(err.to_json(), content="application/json",
                                status="401 Unauthorized")

        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(areq["code"],
                                               issue_refresh=issue_refresh)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Access grant used")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json")

    def client_credentials_grant_type(self, areq):
        """
        Handle a token request that uses the 'client_credentials' grant.

        An access token is always issued; a refresh token is added unless
        the token handler raises NotAllowed for it.

        :param areq: The token request; 'client_id', 'scope' and 'state'
            are read
        :return: A JSON Response carrying the access token response
        """
        grant = 'client_credentials'
        handler = self.token_handler

        _at = handler.get_access_token(areq['client_id'],
                                       scope=areq['scope'],
                                       grant_type=grant)
        _info = handler.token_factory.get_info(_at)

        # NOTE(review): the arguments handed to get_refresh_token() look
        # different from the keyword form (grant_type=, sid=) used
        # elsewhere in this file - confirm against TokenHandler.
        try:
            _rt = handler.get_refresh_token(
                self.baseurl, _info['access_token'], grant)
        except NotAllowed:
            optional = []
        else:
            optional = [_rt]

        atr = self.do_access_token_response(_at, _info, areq['state'],
                                            *optional)
        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """
        Handle a token request that uses the 'password' grant.

        An access token is always issued; a refresh token is added unless
        the token handler raises NotAllowed for it.

        :param areq: The token request; 'client_id', 'scope' and 'state'
            are read
        :return: A JSON Response carrying the access token response
        """
        grant = 'password'
        handler = self.token_handler

        _at = handler.get_access_token(areq['client_id'],
                                       scope=areq['scope'],
                                       grant_type=grant)
        _info = handler.token_factory.get_info(_at)

        try:
            _rt = handler.get_refresh_token(
                self.baseurl, _info['access_token'], grant)
        except NotAllowed:
            # No refresh token for this client: answer without one.
            optional = []
        else:
            optional = [_rt]

        atr = self.do_access_token_response(_at, _info, areq['state'],
                                            *optional)
        return Response(atr.to_json(), content="application/json")

    def refresh_token_grant_type(self, areq):
        """
        Handle a token request that uses the 'refresh_token' grant type.

        :param areq: The token request; its 'access_token' value is handed
            to the token handler
        :return: A JSON Response carrying the AccessTokenResponse
        """
        refreshed = self.token_handler.refresh_access_token(
            self.baseurl,
            areq['access_token'],
            'refresh_token')

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse,
                                              **refreshed))
        payload = atr.to_json()
        return Response(payload, content="application/json")

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens

        :param authn: Client authentication information
        :param kwargs: Must contain 'request', the urlencoded request body
        :return: The Response produced by the handler for the grant type,
            or a 401 JSON error Response if client authentication fails
        :raises UnSupported: If no handler exists for the grant type
        """

        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        try:
            # Only the outcome matters here; the returned client id is
            # not needed by any of the grant handlers.
            self.client_authn(self, areq, authn)
        except FailedAuthentication as failure:
            # Leave a trace of rejected clients in the log.
            logger.error(failure)
            error_msg = TokenErrorResponse(error="unauthorized_client",
                                           error_description="%s" % failure)
            return Response(error_msg.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug("AccessTokenRequest: %s" % areq)

        _grant_type = areq["grant_type"]
        if _grant_type == "authorization_code":
            return self.code_grant_type(areq)
        elif _grant_type == 'client_credentials':
            return self.client_credentials_grant_type(areq)
        elif _grant_type == 'password':
            return self.password_grant_type(areq)
        elif _grant_type == 'refresh_token':
            return self.refresh_token_grant_type(areq)
        else:
            raise UnSupported('grant_type: {}'.format(_grant_type))

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        Export my keys and remember where they were published.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: Nothing. The value returned by ``key_export`` (presumably
            the URL an RP can use to download the keys) is stored in
            ``self.jwks_uri``.
        """
        exported = key_export(self.baseurl, local_path, vault, self.keyjar,
                              fqdn=self.hostname, sig=sig, enc=enc)
        self.jwks_uri = exported

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        # simple rules: if client_id in azp or aud it's allow to introspect
        # to revoke it has to be in azr
        allow = False
        if endpoint == 'revocation_endpoint':
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif len(token_info['aud']) == 1 and token_info['aud'] == [
                client_id]:
                allow = True
        else:  # has to be introspection endpoint
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif 'aud' in token_info:
                if client_id in token_info['aud']:
                    allow = True
        return allow

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token named in the request.

        When the request carries no ``token_type_hint`` the token is first
        tried as an access token and, if that lookup raises ``KeyError``,
        as a refresh token.

        :param authn: Client authentication information
        :param req: The request; ``token`` and, optionally,
            ``token_type_hint`` are read from it
        :param endpoint: Name of the calling endpoint; used for logging and
            for the ``token_access`` check
        :return: A ``(client_id, token_type, token_info)`` tuple on success,
            otherwise a Response instance describing the failure
        :raises KeyError: When the token lookup fails
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            failure = TokenErrorResponse(error="unauthorized_client",
                                         error_description="%s" % err)
            return Response(failure.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        try:
            token_type = req['token_type_hint']
        except KeyError:
            hinted = False
        else:
            hinted = True

        if hinted:
            _info = self.sdb.token_factory[token_type].get_info(req['token'])
        else:
            try:
                # NOTE(review): this is the only lookup that calls .info()
                # instead of .get_info() - confirm that is intended.
                _info = self.sdb.token_factory['access_token'].info(
                    req['token'])
            except KeyError:
                _info = self.sdb.token_factory['refresh_token'].get_info(
                    req['token'])
                token_type = 'refresh_token'
            else:
                token_type = 'access_token'

        if self.token_access(endpoint, client_id, _info):
            return client_id, token_type, _info

        return BadRequest()

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7009: lets a client invalidate an access or refresh
        token.

        :param authn: Client Authentication information
        :param request: The urlencoded revocation request
        :param kwargs: Extra keyword arguments; not used here
        :return: A Response instance
        """
        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, 'revocation_endpoint')
        if isinstance(outcome, Response):
            # Authentication or authorization failed; pass the error on.
            return outcome

        client_id, token_type, _ = outcome
        logger.info('{} token revocation: {}'.format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()

        return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7662: reports the state and claims of a token.

        :param authn: Client Authentication information
        :param request: The urlencoded introspection request
        :param kwargs: Extra keyword arguments; not used here
        :return: A Response instance
        """
        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, 'introspection_endpoint')
        if isinstance(outcome, Response):
            # Authentication or authorization failed; pass the error on.
            return outcome

        client_id, token_type, token_info = outcome
        logger.info('{} token introspection: {}'.format(client_id,
                                                        tir.to_dict()))

        factory = self.sdb.token_factory[token_type]
        ir = TokenIntrospectionResponse(active=factory.is_valid(token_info),
                                        **token_info.to_dict())
        ir.weed()

        return Response(ir.to_json(), content="application/json")
Exemplo n.º 8
0
    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed=b"", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400, jwks_uri='', keyjar=None,
                 capabilities=None, verify_ssl=True, baseurl='', hostname='',
                 config=None, behavior=None, lifetime_policy=None, **kwargs):
        """
        Set up the provider, its extra endpoints and its token policies.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope`` and ``ca_bundle`` are passed
        on to ``provider.Provider.__init__``.

        :param seed: Stored as ``self.seed``
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: Name of the authentication method
            required at registration; empty means none
        :param client_info_url: Stored as ``self.client_info_url``
        :param secret_lifetime: Stored as ``self.secret_lifetime``
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Optional provider configuration handed to
            ``self.provider_features``
        :param verify_ssl: Stored and used when creating the default key jar
        :param baseurl: Base URL; falls back to ``name``
        :param hostname: Host name; falls back to ``socket.gethostname()``
        :param config: Stored as ``self.config`` (``{}`` when falsy)
        :param behavior: Stored as ``self.behavior`` (``{}`` when falsy)
        :param lifetime_policy: Token lifetimes per token type and grant or
            response type; a built-in default is used when None
        :param kwargs: ``server_cls`` is forwarded to the base class and
            ``scopes`` replaces the default scope list
        :raises UnknownAuthnMethod: When ``authn_at_registration`` is set
            but is not among ``client_authn_methods``
        """
        if not name.endswith("/"):
            name += "/"

        # Of the extra keyword arguments only server_cls goes to the base.
        try:
            args = {'server_cls': kwargs['server_cls']}
        except KeyError:
            args = {}

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **args)

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint,
                          RevocationEndpoint, IntrospectionEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        # client_authn_methods defaults to None. Treat that as "no methods
        # known" so a bad configuration is reported as UnknownAuthnMethod
        # instead of a TypeError from the membership test.
        if authn_at_registration and \
                authn_at_registration not in (client_authn_methods or {}):
            raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        try:
            self.scopes = kwargs['scopes']
        except KeyError:
            self.scopes = ['offline_access']
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        # Filled per client; shared with the token handler created below.
        self.token_policy = {'access_token': {}, 'refresh_token': {}}
        if lifetime_policy is None:
            self.lifetime_policy = {
                'access_token': {
                    'code': 600,
                    'token': 120,
                    'implicit': 120,
                    'authorization_code': 600,
                    'client_credentials': 600,
                    'password': 600
                },
                'refresh_token': {
                    'code': 3600,
                    'token': 3600,
                    'implicit': 3600,
                    'authorization_code': 3600,
                    'client_credentials': 3600,
                    'password': 3600
                }
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(self.baseurl, self.token_policy,
                                          keyjar=self.keyjar)
Exemplo n.º 9
0
    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey=None,
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
        config=None,
        behavior=None,
        lifetime_policy=None,
        message_factory=ExtensionMessageFactory,
        **kwargs
    ):
        """
        Set up the provider, its extra endpoints and its token policies.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope``, ``ca_bundle`` and
        ``message_factory`` are passed on to the base class initializer.

        :param seed: Stored as ``self.seed``
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: Name of the authentication method
            required at registration; empty means none
        :param client_info_url: Stored as ``self.client_info_url``
        :param secret_lifetime: Stored as ``self.secret_lifetime``
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Optional provider configuration handed to
            ``self.provider_features``
        :param verify_ssl: Stored and used when creating the default key jar
        :param baseurl: Base URL; falls back to ``name``
        :param hostname: Host name; falls back to ``socket.gethostname()``
        :param config: Stored as ``self.config`` (``{}`` when falsy)
        :param behavior: Stored as ``self.behavior`` (``{}`` when falsy)
        :param lifetime_policy: Token lifetimes per token type and grant or
            response type; a built-in default is used when None
        :param kwargs: ``server_cls`` is forwarded to the base class and
            ``scopes`` is appended to ``self.scopes``
        :raises UnknownAuthnMethod: When ``authn_at_registration`` is set
            but is not among ``client_authn_methods``
        """
        if not name.endswith("/"):
            name += "/"

        # Of the extra keyword arguments only server_cls goes to the base.
        try:
            args = {"server_cls": kwargs["server_cls"]}
        except KeyError:
            args = {}

        super().__init__(
            name,
            sdb,
            cdb,
            authn_broker,
            authz,
            client_authn,
            symkey,
            urlmap,
            iv,
            default_scope,
            ca_bundle,
            message_factory=message_factory,
            **args
        )

        self.endp.extend(
            [
                RegistrationEndpoint,
                ClientInfoEndpoint,
                RevocationEndpoint,
                IntrospectionEndpoint,
            ]
        )

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        # client_authn_methods defaults to None. Treat that as "no methods
        # known" so a bad configuration is reported as UnknownAuthnMethod
        # instead of a TypeError from the membership test.
        if authn_at_registration and authn_at_registration not in (
            client_authn_methods or {}
        ):
            raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # NOTE(review): self.scopes is presumably set by the base class
        # initializer - confirm it is a per-instance list before extending.
        self.scopes.extend(kwargs.get("scopes", []))
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict[str, str]]
        self.config = config or {}
        self.behavior = behavior or {}
        # Filled per client; shared with the token handler created below.
        self.token_policy = {
            "access_token": {},
            "refresh_token": {},
        }  # type: Dict[str, Dict[str, str]]
        if lifetime_policy is None:
            self.lifetime_policy = {
                "access_token": {
                    "code": 600,
                    "token": 120,
                    "implicit": 120,
                    "authorization_code": 600,
                    "client_credentials": 600,
                    "password": 600,
                },
                "refresh_token": {
                    "code": 3600,
                    "token": 3600,
                    "implicit": 3600,
                    "authorization_code": 3600,
                    "client_credentials": 3600,
                    "password": 3600,
                },
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(
            self.baseurl, self.token_policy, keyjar=self.keyjar
        )
Exemplo n.º 10
0
class Provider(provider.Provider):
    """A OAuth2 RP that knows all the OAuth2 extensions I've implemented."""

    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey=None,
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
        config=None,
        behavior=None,
        lifetime_policy=None,
        message_factory=ExtensionMessageFactory,
        **kwargs
    ):
        """
        Set up the provider, its extra endpoints and its token policies.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope``, ``ca_bundle`` and
        ``message_factory`` are passed on to the base class initializer.

        :param seed: Seed used when generating client secrets
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: Name of the authentication method
            required at registration; empty means none
        :param client_info_url: Path component used when building a
            client's ``registration_client_uri``
        :param secret_lifetime: Number added to the issue time to obtain
            ``client_secret_expires_at``
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Optional provider configuration handed to
            ``self.provider_features``
        :param verify_ssl: Stored and used when creating the default key jar
        :param baseurl: Base URL; falls back to ``name``
        :param hostname: Host name; falls back to ``socket.gethostname()``
        :param config: Stored as ``self.config`` (``{}`` when falsy)
        :param behavior: Behavior settings; ``client_registration`` is read
            when a client registers
        :param lifetime_policy: Token lifetimes per token type and grant or
            response type; a built-in default is used when None
        :param kwargs: ``server_cls`` is forwarded to the base class and
            ``scopes`` is appended to ``self.scopes``
        :raises UnknownAuthnMethod: When ``authn_at_registration`` is set
            but is not among ``client_authn_methods``
        """
        if not name.endswith("/"):
            name += "/"

        # Of the extra keyword arguments only server_cls goes to the base.
        try:
            args = {"server_cls": kwargs["server_cls"]}
        except KeyError:
            args = {}

        super().__init__(
            name,
            sdb,
            cdb,
            authn_broker,
            authz,
            client_authn,
            symkey,
            urlmap,
            iv,
            default_scope,
            ca_bundle,
            message_factory=message_factory,
            **args
        )

        self.endp.extend(
            [
                RegistrationEndpoint,
                ClientInfoEndpoint,
                RevocationEndpoint,
                IntrospectionEndpoint,
            ]
        )

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        # client_authn_methods defaults to None. Treat that as "no methods
        # known" so a bad configuration is reported as UnknownAuthnMethod
        # instead of a TypeError from the membership test.
        if authn_at_registration and authn_at_registration not in (
            client_authn_methods or {}
        ):
            raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # NOTE(review): self.scopes is presumably set by the base class
        # initializer - confirm it is a per-instance list before extending.
        self.scopes.extend(kwargs.get("scopes", []))
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict[str, str]]
        self.config = config or {}
        self.behavior = behavior or {}
        # Filled per client by set_token_policy; shared with the token
        # handler created below.
        self.token_policy = {
            "access_token": {},
            "refresh_token": {},
        }  # type: Dict[str, Dict[str, str]]
        if lifetime_policy is None:
            self.lifetime_policy = {
                "access_token": {
                    "code": 600,
                    "token": 120,
                    "implicit": 120,
                    "authorization_code": 600,
                    "client_credentials": 600,
                    "password": 600,
                },
                "refresh_token": {
                    "code": 3600,
                    "token": 3600,
                    "implicit": 3600,
                    "authorization_code": 3600,
                    "client_credentials": 3600,
                    "password": 3600,
                },
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(
            self.baseurl, self.token_policy, keyjar=self.keyjar
        )

    @staticmethod
    def _uris_to_tuples(uris):
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys into the key jar and add its secret to them.

        :param request: The registration request the keys are loaded from
        :param client_id: Identifier of the client owning the keys
        :param client_secret: Added to the key jar as symmetric "ver" and
            "sig" keys
        :return: None on success; a "400 Bad Request" Response instance
            when loading the keys fails
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                key_count = len(self.keyjar[client_id])
                found = "Found {} keys for client_id={}"
                logger.debug(found.format(key_count, client_id))
            except KeyError:
                # Nothing stored for this client; there is nothing to report.
                pass
        except Exception as err:
            failed = "Failed to load client keys: {}"
            logger.error(failed.format(sanitize(request.to_dict())))
            logger.error("%s", err)
            error = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err,
            )
            return Response(
                error.to_json(),
                content="application/json",
                status_code="400 Bad Request",
            )

        # Add the client_secret as a symmetric key to the keyjar
        secret_keys = KeyBundle(
            [{"kty": "oct", "key": client_secret, "use": use} for use in ("ver", "sig")]
        )
        try:
            self.keyjar[client_id].append(secret_keys)
        except KeyError:
            self.keyjar[client_id] = [secret_keys]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        for fname, arg in restrictions.items():
            func = restrict.factory(fname)
            res = func(arg, cinfo)
            if res:
                raise RestrictionError(res)

    def set_token_policy(self, cid, cinfo):
        for ttyp in ["access_token", "refresh_token"]:
            pol = {}
            for rgtyp in ["response_type", "grant_type"]:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    pass
                else:
                    for typ in rtyp:
                        try:
                            pol[typ] = self.lifetime_policy[ttyp][typ]
                        except KeyError:
                            pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Create a new client based on a registration request.

        The request is matched against the provider's capabilities, a
        fresh id and secret are generated, the client's keys are loaded and
        the resulting client information is stored in ``self.cdb``.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client
        :return: The client_id
        """
        # NOTE(review): 'restrictions' is never read below; only
        # self.behavior["client_registration"] is enforced - confirm.
        client = request.to_dict()

        self.match_client_request(client)

        # create new id and secret
        new_id = rndstr(12)
        while new_id in self.cdb:
            new_id = rndstr(12)

        client.update(
            {
                "client_id": new_id,
                "client_secret": secret(self.seed, new_id),
                "client_id_issued_at": utc_time_sans_frac(),
                "client_secret_expires_at": utc_time_sans_frac()
                + self.secret_lifetime,
            }
        )

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            client["registration_access_token"] = rndstr(32)
            uri_parts = (
                self.name,
                self.client_info_url,
                ClientInfoEndpoint.etype,
                new_id,
            )
            client["registration_client_uri"] = "%s%s%s?client_id=%s" % uri_parts

        if "redirect_uris" in request:
            client["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        # NOTE(review): load_keys returns an error Response on failure,
        # which is dropped here - confirm that is intended.
        self.load_keys(request, new_id, client["client_secret"])

        try:
            registration_rules = self.behavior["client_registration"]
        except KeyError:
            registration_rules = None
            enforce = False
        else:
            enforce = True
        if enforce:
            self.verify_correct(client, registration_rules)

        self.set_token_policy(new_id, client)
        self.cdb[new_id] = client

        return new_id

    def match_client_request(self, request):
        """
        Verify that what a client asks for is within my capabilities.

        Response types are compared as sets of space separated words, so
        word order does not matter. Any other preference must be, or only
        contain, values the provider lists for the matching capability.

        :param request: Mapping with the client's registration preferences
        :raises CapabilitiesMisMatch: On the first preference that is not
            supported
        """
        for preference, capability in PREFERENCE2PROVIDER.items():
            if preference not in request:
                continue

            wanted = request[preference]
            if preference == "response_types":
                for value in wanted:
                    words = set(value.split(" "))
                    supported = any(
                        words == set(offered.split(" "))
                        for offered in self.capabilities[capability]
                    )
                    if not supported:
                        raise CapabilitiesMisMatch("Not allowed {}".format(preference))
            elif isinstance(wanted, str):
                if wanted not in self.capabilities[capability]:
                    raise CapabilitiesMisMatch("Not allowed {}".format(preference))
            elif not set(wanted).issubset(set(self.capabilities[capability])):
                raise CapabilitiesMisMatch("Not allowed {}".format(preference))

    def client_info(self, client_id):
        """
        Build the response describing a registered client.

        :param client_id: The client id
        :return: A Response instance with the client information as JSON,
            or a BadRequest instance when ``valid_client_info`` rejects the
            stored information
        """
        record = self.cdb[client_id].copy()
        if not valid_client_info(record):
            err = ErrorResponse(
                error="invalid_client", error_description="Invalid client secret"
            )
            return BadRequest(err.to_json(), content="application/json")

        # Redirect URIs are stored as (base, query) tuples.
        try:
            record["redirect_uris"] = self._tuples_to_uris(record["redirect_uris"])
        except KeyError:
            pass

        response_cls = self.server.message_factory.get_response_type("update_endpoint")
        return Response(response_cls(**record).to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                "client_id_issued_at",
                "client_secret_expires_at",
                "registration_access_token",
                "registration_client_uri",
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Verify the client based on credentials.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: Name of the client authentication method
        :param client_id: The client id; when empty it is looked up using
            the request and the ``HTTP_AUTHORIZATION`` value of the environ
        :return: Whatever the ``verify`` method of the selected
            authentication method returns
        :raises UnSupported: When the authentication method is unknown
        """
        client_id = client_id or get_client_id(
            self.cdb, areq, environ["HTTP_AUTHORIZATION"]
        )

        try:
            authn_cls = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()

        verifier = authn_cls(self)
        return verifier.verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """
        Turn a software statement into client restrictions.

        This implementation ignores the statement.

        :param software_statement: A parsed software statement
        :return: An empty dictionary
        """
        no_restrictions = dict()
        return no_restrictions

    def registration_endpoint(self, **kwargs):
        """
        Perform dynamic client registration.

        :param kwargs: Must contain ``request``, the JSON registration
            request; ``environ`` is also read when authentication is
            required at registration
        :return: A Response instance
        """
        request_cls = self.server.message_factory.get_request_type(
            "registration_endpoint"
        )
        _request = request_cls().deserialize(kwargs["request"], "json")

        try:
            _request.verify(keyjar=self.keyjar)
        except (InvalidRedirectUri, MissingPage, VerificationError) as err:
            if isinstance(err, InvalidRedirectUri):
                error_code = "invalid_redirect_uri"
            else:
                error_code = "invalid_client_metadata"
            msg = ClientRegistrationError(
                error=error_code, error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(
                    kwargs["environ"], _request, self.authn_at_registration
                )
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}  # type: ignore
        if "parsed_software_statement" in _request:
            for statement in _request["parsed_software_statement"]:
                client_restrictions.update(self.consume_software_statement(statement))
            # The statements have been consumed; keep them out of the
            # stored client information.
            del _request["software_statement"]
            del _request["parsed_software_statement"]

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            msg = ClientRegistrationError(
                error="invalid_client_metadata", error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Read, update or delete a client's registration.

        The operation is selected by the HTTP method: GET returns the
        client information, PUT replaces it and DELETE removes it.

        :param method: HTTP method used for the request
        :param kwargs: Must contain ``query`` (the query string naming the
            ``client_id``), ``environ`` and ``request``
        :return: A Response instance; None for any other HTTP method
        """
        _query = compact(parse_qs(kwargs["query"]))
        try:
            _id = _query["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(
                kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id
            )
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)

        if method == "PUT":
            request_cls = self.server.message_factory.get_request_type(
                "update_endpoint"
            )
            try:
                _request = request_cls().from_json(kwargs["request"])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(
                    error="invalid_redirect_uri", error_description="%s" % err
                )
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(
                    error="invalid_client_metadata", error_description="%s" % err
                )
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()

        if method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            return NoContent()

        # NOTE(review): any other HTTP method ends up here and yields None
        # rather than an error response - confirm callers cope with that.
        return None

    @staticmethod
    def verify_code_challenge(
        code_verifier, code_challenge, code_challenge_method="S256"
    ):
        """
        Verify a PKCE (RFC7636) code challenge.

        The verifier is hashed with the function ``CC_METHOD`` holds for the
        given method, passed through ``b64e`` and compared with the
        challenge.

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into ``CC_METHOD``
        :return: True when the challenge matches, otherwise a Response
            instance with status code 401. Both are truthy, so callers must
            not tell them apart by truth value.
        """
        hasher = CC_METHOD[code_challenge_method]
        digest = hasher(code_verifier.encode("ascii")).digest()
        computed = b64e(digest).decode("ascii")
        if computed == code_challenge:
            return True

        logger.error("PCKE Code Challenge check failed")
        err = TokenErrorResponse(
            error="invalid_request", error_description="PCKE check failed"
        )
        return Response(err.to_json(), content="application/json", status_code=401)

    def do_access_token_response(self, access_token, atinfo, state, refresh_token=None):
        """
        Build the token endpoint response message for an access token.

        :param access_token: The access token
        :param atinfo: Information about the access token; ``exp`` is
            required and ``scope`` is used when present
        :param state: The state value to return to the client
        :param refresh_token: Optional refresh token; left out when falsy
        :return: An instance of the message class registered for the token
            endpoint response
        """
        # NOTE(review): "expires_in" is filled straight from atinfo["exp"] -
        # confirm that value is a lifetime and not an absolute timestamp.
        response_args = dict(
            access_token=access_token,
            expires_in=atinfo["exp"],
            token_type="bearer",
            state=state,
        )
        try:
            response_args["scope"] = atinfo["scope"]
        except KeyError:
            pass

        if refresh_token:
            response_args["refresh_token"] = refresh_token

        response_cls = self.server.message_factory.get_response_type("token_endpoint")
        # Only arguments the response schema knows about are passed on.
        return response_cls(**by_schema(response_cls, **response_args))

    def code_grant_type(self, areq):
        """
        Exchange an authorization code for tokens.

        The code must be known. When the request carries a
        ``code_verifier`` it must match the code challenge of the
        authorization request. ``state`` and ``redirect_uri`` must match
        the values bound to the code, and the scope check must pass. A
        refresh token is requested when the authorization request asked for
        the ``offline_access`` scope with response type ``code``.

        :param areq: The access token request
        :return: A Response instance: the token response on success,
            otherwise an error response
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Unknown access grant"
            )
            return Response(
                err.to_json(), content="application/json", status="401 Unauthorized"
            )

        authzreq = json.loads(_info["authzreq"])
        # NOTE(review): a request without code_verifier skips the PKCE
        # check even if the authorization request carried a code_challenge.
        if "code_verifier" in areq:
            _method = authzreq.get("code_challenge_method", "S256")

            resp = self.verify_code_challenge(
                areq["code_verifier"], authzreq["code_challenge"], _method
            )
            # verify_code_challenge returns True on success and an error
            # Response on failure. Both are truthy, so a plain truth test
            # would end every PKCE request here, returning the bare True
            # on success.
            if resp is not True:
                return resp

        if "state" in areq:
            if self.sdb[areq["code"]]["state"] != areq["state"]:
                logger.error("State value mismatch")
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. A request
        # that leaves it out is rejected as a mismatch instead of failing
        # with a KeyError.
        if "redirect_uri" in _info and (
            "redirect_uri" not in areq
            or areq["redirect_uri"] != _info["redirect_uri"]
        ):
            logger.error("Redirect_uri mismatch")
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        issue_refresh = (
            "scope" in authzreq
            and "offline_access" in authzreq["scope"]
            and authzreq["response_type"] == "code"
        )

        try:
            _tinfo = self.sdb.upgrade_to_token(
                areq["code"], issue_refresh=issue_refresh
            )
        except AccessCodeUsed:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Access grant used"
            )
            return Response(
                err.to_json(), content="application/json", status="401 Unauthorized"
            )

        logger.debug("_tinfo: %s", _tinfo)

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **_tinfo))

        logger.debug("AccessTokenResponse: %s", atr)

        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    def client_credentials_grant_type(self, areq):
        """
        Answer a token request made with the client credentials grant.

        An access token is requested from the token handler for the client
        named in the request. A refresh token is added to the response
        unless the token handler raises NotAllowed when asked for one.

        :param areq: The access token request
        :return: A JSON Response carrying the access token response
        """
        handler = self.token_handler
        access_token = handler.get_access_token(
            areq["client_id"], scope=areq["scope"], grant_type="client_credentials"
        )
        token_info = handler.token_factory.get_info(access_token)

        # Extra positional argument for do_access_token_response; it stays
        # empty when the handler refuses to hand out a refresh token.
        # NOTE(review): confirm the positional order passed to
        # get_refresh_token matches that method's signature.
        optional = []
        try:
            optional.append(
                handler.get_refresh_token(
                    self.baseurl, token_info["access_token"], "client_credentials"
                )
            )
        except NotAllowed:
            pass

        atr = self.do_access_token_response(
            access_token, token_info, areq["state"], *optional
        )
        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """
        Issue tokens for the resource owner password credentials grant.

        See RFC6749 section 4.3.

        :param areq: The access token request; its username and password
            are handed to the selected authentication method
        :return: A Response with the token response, or an Unauthorized
            response (error "invalid_grant") when no authentication method
            can be picked or the credentials yield no identity
        """

        def _invalid_grant():
            # Both failure paths answer with the very same error payload.
            failure = TokenErrorResponse(error="invalid_grant")
            return Unauthorized(failure.to_json(), content="application/json")

        # Picking with "any" takes the first broker entry: that either gives
        # us a method or ends in an IndexError.
        try:
            authn, authn_class_ref = self.pick_auth(areq, "any")
        except IndexError:
            return _invalid_grant()

        identity, _ts = authn.authenticated_as(
            username=areq["username"], password=areq["password"]
        )
        if identity is None:
            return _invalid_grant()

        # A token is what gets returned, so record that on the request.
        areq["response_type"] = ["token"]
        event = AuthnEvent(
            identity["uid"],
            identity.get("salt", ""),
            authn_info=authn_class_ref,
            time_stamp=_ts,
        )
        session_id = self.setup_session(areq, event, self.cdb[areq["client_id"]])
        token_info = self.sdb.upgrade_to_token(
            self.sdb[session_id]["code"], issue_refresh=True
        )

        response_class = self.server.message_factory.get_response_type(
            "token_endpoint"
        )
        atr = response_class(**by_schema(response_class, **token_info))
        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    def refresh_token_grant_type(self, areq):
        """
        Answer a token request made with the refresh token grant.

        :param areq: The token request; its "access_token" value is passed
            to the token handler's refresh_access_token
        :return: A JSON Response built from what the token handler returned
        """
        refreshed = self.token_handler.refresh_access_token(
            self.baseurl, areq["access_token"], "refresh_token"
        )

        response_class = self.server.message_factory.get_response_type(
            "token_endpoint"
        )
        # Keep only the fields the response class knows about.
        fields = by_schema(response_class, **refreshed)
        atr = response_class(**fields)
        return Response(atr.to_json(), content="application/json")

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        # simple rules: if client_id in azp or aud it's allow to introspect
        # to revoke it has to be in azr
        allow = False
        if endpoint == "revocation_endpoint":
            if "azr" in token_info and client_id == token_info["azr"]:
                allow = True
            elif len(token_info["aud"]) == 1 and token_info["aud"] == [client_id]:
                allow = True
        else:  # has to be introspection endpoint
            if "azr" in token_info and client_id == token_info["azr"]:
                allow = True
            elif "aud" in token_info:
                if client_id in token_info["aud"]:
                    allow = True
        return allow

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up information about a token.

        When the request carries a "token_type_hint" only the token factory
        for that type is consulted. Without a hint the access token factory
        is tried first and the refresh token factory second.

        :param authn: Client Authentication information
        :param req: The request; "token" and optionally "token_type_hint"
            are read from it
        :param endpoint: Name of the endpoint on whose behalf this is done
        :return: A (client_id, token_type, token info) tuple on success,
            otherwise a response object: 401 when client authentication
            fails, the inactive response when no factory can produce
            information about the token, BadRequest when the client is not
            allowed to use the endpoint on this token
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            logger.error(err)
            error = TokenErrorResponse(
                error="unauthorized_client", error_description="%s" % err
            )
            return Response(
                error.to_json(), content="application/json", status="401 Unauthorized"
            )

        logger.debug("{}: {} requesting {}".format(endpoint, client_id, req.to_dict()))

        # A hint narrows the search to one factory; otherwise probe both.
        try:
            candidates = (req["token_type_hint"],)
        except KeyError:
            candidates = ("access_token", "refresh_token")

        for token_type in candidates:
            try:
                _info = self.sdb.token_factory[token_type].get_info(req["token"])
            except Exception:
                # This factory could not handle the token; try the next one.
                continue
            break
        else:
            # No candidate factory produced any information.
            return self._return_inactive()

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    def _return_inactive(self):
        """Build the JSON Response that reports a token as not active."""
        response_class = self.server.message_factory.get_response_type(
            "introspection_endpoint"
        )
        inactive = response_class(active=False)
        return Response(inactive.to_json(), content="application/json")

    def revocation_endpoint(self, authn="", request=None, **kwargs):
        """
        Implement RFC7009: let a client invalidate an access or refresh token.

        :param authn: Client Authentication information
        :param request: The revocation request, urlencoded
        :param kwargs: Accepted but not used here
        :return: Response "OK" once the token is invalidated, BadRequest when
            invalidating raises KeyError, or whatever response the token
            lookup produced when it did not succeed
        """
        request_class = self.server.message_factory.get_request_type(
            "revocation_endpoint"
        )
        trr = request_class().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, "revocation_endpoint")
        if isinstance(outcome, Response):
            # The lookup already produced the answer to send back.
            return outcome
        client_id, token_type, _info = outcome

        logger.info("{} token revocation: {}".format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr["token"])
        except KeyError:
            return BadRequest()
        return Response("OK")

    def introspection_endpoint(self, authn="", request=None, **kwargs):
        """
        Implement RFC7662: report information about a token to a client.

        :param authn: Client Authentication information
        :param request: The introspection request, urlencoded
        :param kwargs: Accepted but not used here
        :return: A JSON Response with the introspection result, or whatever
            response the token lookup produced when it did not succeed
        """
        request_class = self.server.message_factory.get_request_type(
            "introspection_endpoint"
        )
        tir = request_class().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, "introspection_endpoint")
        if isinstance(outcome, Response):
            # The lookup already produced the answer to send back.
            return outcome
        client_id, token_type, _info = outcome

        logger.info("{} token introspection: {}".format(client_id, tir.to_dict()))

        response_class = self.server.message_factory.get_response_type(
            "introspection_endpoint"
        )
        still_valid = self.sdb.token_factory[token_type].is_valid(_info)
        ir = response_class(active=still_valid, **_info.to_dict())

        ir.weed()

        return Response(ir.to_json(), content="application/json")
Exemplo n.º 11
0
class TestTokenHandler(object):
    """Tests for TokenHandler access and refresh token construction."""

    @pytest.fixture(autouse=True)
    def create_handler(self):
        """Create a handler configured for one client and one grant type."""
        self.th = TokenHandler(
            'https://example.com/as',
            {
                'access_token': {
                    'https://example.org/rp': {
                        'client_credentials': 1200
                    }
                },
                'refresh_token': {
                    'https://example.org/rp': {
                        'client_credentials': 86400
                    }
                }
            },
            keyjar=KEYJAR
        )

    def test_construct_access_token(self):
        """An access token for the configured client carries the expected keys."""
        token = self.th.get_access_token('https://example.org/rp', 'foo bar',
                                         'client_credentials')

        assert token

        info = self.th.token_factory.get_info(token)

        assert _eq(list(info.keys()),
                   ['jti', 'scope', 'exp', 'iss', 'aud', 'iat', 'kid', 'azp'])

    def test_construct_access_token_fail(self):
        """Requests outside the configuration must raise NotAllowed."""
        # pytest.raises makes the test fail when no exception is raised; a
        # bare try/except/pass would pass silently in that case.
        # Unknown client
        with pytest.raises(NotAllowed):
            self.th.get_access_token('https://example.com/rp', 'foo bar',
                                     'client_credentials')
        # wrong grant_type
        with pytest.raises(NotAllowed):
            self.th.get_access_token('https://example.org/rp', 'foo bar',
                                     'implicit')

    def test_from_access_to_refresh_token(self):
        """refresh_access_token returns something truthy for a fresh token."""
        token = self.th.get_access_token('https://example.org/rp', 'foo bar',
                                         'client_credentials')

        refresh_token = self.th.refresh_access_token(
            'https://example.org/rp', token, 'client_credentials')

        assert refresh_token

    def test_construct_refresh_token(self):
        """A refresh token carries the expected keys and is stored under its jti."""
        sid = '1234'
        rtoken = self.th.get_refresh_token('https://example.org/rp',
                                           grant_type='client_credentials',
                                           sid=sid)

        info = self.th.token_factory.get_info(rtoken)

        assert _eq(list(info.keys()),
                   ['jti', 'exp', 'iss', 'aud', 'iat', 'kid', 'azp'])

        assert self.th.refresh_token_factory.db[info['jti']] == sid
Exemplo n.º 12
0
class TestTokenHandler(object):
    """Tests for TokenHandler access and refresh token construction."""

    @pytest.fixture(autouse=True)
    def create_handler(self):
        """Create a handler configured for one client and one grant type."""
        self.th = TokenHandler(
            "https://example.com/as",
            {
                "access_token": {
                    "https://example.org/rp": {
                        "client_credentials": 1200
                    }
                },
                "refresh_token": {
                    "https://example.org/rp": {
                        "client_credentials": 86400
                    }
                },
            },
            keyjar=KEYJAR,
        )

    def test_construct_access_token(self):
        """An access token for the configured client carries the expected keys."""
        token = self.th.get_access_token("https://example.org/rp", "foo bar",
                                         "client_credentials")

        assert token

        info = self.th.token_factory.get_info(token)

        assert _eq(
            list(info.keys()),
            ["jti", "scope", "exp", "iss", "aud", "iat", "kid", "azp"],
        )

    def test_construct_access_token_fail(self):
        """Requests outside the configuration must raise NotAllowed."""
        # pytest.raises makes the test fail when no exception is raised; a
        # bare try/except/pass would pass silently in that case.
        # Unknown client
        with pytest.raises(NotAllowed):
            self.th.get_access_token("https://example.com/rp", "foo bar",
                                     "client_credentials")
        # wrong grant_type
        with pytest.raises(NotAllowed):
            self.th.get_access_token("https://example.org/rp", "foo bar",
                                     "implicit")

    def test_from_access_to_refresh_token(self):
        """refresh_access_token returns something truthy for a fresh token."""
        token = self.th.get_access_token("https://example.org/rp", "foo bar",
                                         "client_credentials")

        refresh_token = self.th.refresh_access_token("https://example.org/rp",
                                                     token,
                                                     "client_credentials")

        assert refresh_token

    def test_construct_refresh_token(self):
        """A refresh token carries the expected keys and is stored under its jti."""
        sid = "1234"
        rtoken = self.th.get_refresh_token("https://example.org/rp",
                                           grant_type="client_credentials",
                                           sid=sid)

        info = self.th.token_factory.get_info(rtoken)

        assert _eq(list(info.keys()),
                   ["jti", "exp", "iss", "aud", "iat", "kid", "azp"])

        assert self.th.refresh_token_factory.db[info["jti"]] == sid