예제 #1
0
    def test_provider(self):
        """Check that keys registered through a ``jwks_uri`` can be listed.

        ``load_keys`` is called before the HTTP mock is installed; only the
        final ``keys()`` lookup runs while the mocked JWKS document is being
        served.
        """
        issuer = "https://connect-op.heroku.com"
        jwks_url = "https://connect-op.herokuapp.com/jwks.json"

        key_jar = KeyJar()
        key_jar.load_keys({"jwks_uri": jwks_url}, issuer)

        # The single RSA signing key published by the mocked endpoint.
        rsa_jwk = {
            "kty": "RSA",
            "e": "AQAB",
            "n": (
                "pKybs0WaHU_y4cHxWbm8Wzj66HtcyFn7Fh3n-99qTXu5yNa30MRYIYfSDwe9JVc1JUoGw41yq2StdGBJ"
                "40HxichjE-Yopfu3B58QlgJvToUbWD4gmTDGgMGxQxtv1En2yedaynQ73sDpIK-12JJDY55pvf-PCiSQ9OjxZ"
                "LiVGKlClDus44_uv2370b9IN2JiEOF-a7JBqaTEYLPpXaoKWDSnJNonr79tL0T7iuJmO1l705oO3Y0TQ-INLY"
                "6jnKG_RpsvyvGNnwP9pMvcP1phKsWZ10ofuuhJGRp8IxQL9RfzT87OvF0RBSO1U73h09YP-corWDsnKIi6Tbz"
                "RpN5YDw"
            ),
            "use": "sig",
            "kid": "default",
        }

        with responses.RequestsMock() as mocked:
            mocked.add(responses.GET, jwks_url, json={"keys": [rsa_jwk]})
            assert key_jar[issuer][0].keys()
예제 #2
0
    def test_provider(self):
        """Check that keys registered through a ``jwks_uri`` can be listed.

        No HTTP mock is installed here.
        NOTE(review): presumably this reaches the real JWKS endpoint -- confirm.
        """
        issuer = "https://connect-op.heroku.com"

        key_jar = KeyJar()
        key_jar.load_keys(
            {"jwks_uri": "https://connect-op.herokuapp.com/jwks.json"},
            issuer)

        assert key_jar[issuer][0].keys()
예제 #3
0
    def test_provider(self):
        """Check that keys registered through a ``jwks_uri`` can be listed.

        No HTTP mock is installed here.
        NOTE(review): presumably this reaches the real JWKS endpoint -- confirm.
        """
        issuer = "https://connect-op.heroku.com"

        key_jar = KeyJar()
        key_jar.load_keys(
            {"jwks_uri": "https://connect-op.herokuapp.com/jwks.json"},
            issuer)

        assert key_jar[issuer][0].keys()
예제 #4
0
class Provider(provider.Provider):
    """
    A OAuth2 RP that knows all the OAuth2 extensions I've implemented
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed=b"", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400, jwks_uri='', keyjar=None,
                 capabilities=None, verify_ssl=True, baseurl='', hostname='',
                 config=None, behavior=None):
        """
        Set up the provider and register the extension endpoints.

        ``name`` (with a trailing "/" appended when missing), ``sdb``,
        ``cdb``, ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope`` and ``ca_bundle`` are passed
        on unchanged to ``provider.Provider.__init__``.

        :param seed: Seed used when deriving a new client secret
        :param client_authn_methods: Dictionary of client authentication
            methods, keyed by method name
        :param authn_at_registration: Name of the client authentication
            method required at registration; empty means none is required
        :param client_info_url: Path component used when building the
            ``registration_client_uri`` of a new client
        :param secret_lifetime: Value added to the issue time to get
            ``client_secret_expires_at``
        :param jwks_uri: URL advertised as ``jwks_uri`` in the provider info
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Overrides applied on top of the default
            provider features
        :param verify_ssl: Passed to the ``KeyJar`` created when ``keyjar``
            is None
        :param baseurl: Base URL used for endpoint URLs and as issuer
        :param hostname: Host name; defaults to ``socket.gethostname()``
        :param config: Stored as-is on the instance
        :param behavior: Behavior settings; an empty dict when falsy
        :raises UnknownAuthnMethod: if ``authn_at_registration`` is set but
            is not among ``client_authn_methods``
        """

        if not name.endswith("/"):
            name += "/"
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle)

        # NOTE(review): if the base class keeps ``endp`` as a class-level
        # list, this extends the shared list -- confirm.
        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint,
                          RevocationEndpoint, IntrospectionEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # Also covers client_authn_methods being None/empty, which
            # would otherwise fail with a TypeError on the membership test.
            if (not client_authn_methods or
                    authn_at_registration not in client_authn_methods):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl

        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config
        self.behavior = behavior or {}

    @staticmethod
    def _uris_to_tuples(uris):
        """
        Split every URI into a (base, query) pair.

        :param uris: Iterable of URI strings
        :return: List of 2-tuples; the query part is "" when the URI has
            no (or an empty) query component
        """
        split_uris = (splitquery(uri) for uri in uris)
        return [(base, query if query else "") for base, query in split_uris]

    @staticmethod
    def _tuples_to_uris(items):
        """
        Rebuild URI strings from (url, query) pairs.

        :param items: Iterable of (url, query) 2-tuples
        :return: List of URIs; "?query" is only appended when the query
            part is truthy
        """
        return ["%s?%s" % (url, query) if query else url
                for url, query in items]

    def load_keys(self, request, client_id, client_secret):
        """
        Load the client's keys and add its secret as symmetric keys.

        :param request: The registration request the keys are loaded from
        :param client_id: Identifier the keys are stored under in the key jar
        :param client_secret: Secret added as "oct" keys for "ver" and "sig"
        :return: A "400 Bad Request" Response if loading the keys failed,
            otherwise None
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                key_list = ",".join(
                    ["%s" % x for x in self.keyjar[client_id]])
                logger.debug("keys for %s: [%s]" % (client_id, key_list))
            except KeyError:
                # nothing stored for this client, so nothing to log
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            logger.error("%s", err)
            reg_error = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(reg_error.to_json(), content="application/json",
                            status="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        sym_bundle = KeyBundle(
            [{"kty": "oct", "key": client_secret, "use": usage}
             for usage in ("ver", "sig")])
        try:
            self.keyjar[client_id].append(sym_bundle)
        except KeyError:
            self.keyjar[client_id] = [sym_bundle]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Apply every restriction to the client information.

        :param cinfo: The client information
        :param restrictions: Mapping from restriction function name to the
            argument that function is called with
        :raises RestrictionError: on the first restriction function that
            returns a truthy result; the result is carried by the exception
        """
        for fname, arg in restrictions.items():
            complaint = restrict.factory(fname)(arg, cinfo)
            if complaint:
                raise RestrictionError(complaint)

    def create_new_client(self, request, restrictions):
        """
        Register a new client and store its information in the client
        database.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client.
            NOTE(review): not used here; the checks applied come from
            ``self.behavior['client_registration']`` -- confirm intent.
        :return: The client_id
        """

        client_record = request.to_dict()

        self.match_client_request(client_record)

        # draw identifiers until one is found that is not in use
        while True:
            new_id = rndstr(12)
            if new_id not in self.cdb:
                break

        client_record.update({
            "client_id": new_id,
            "client_secret": secret(self.seed, new_id),
            "client_id_issued_at": utc_time_sans_frac(),
            "client_secret_expires_at": (utc_time_sans_frac() +
                                         self.secret_lifetime),
        })

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            client_record["registration_access_token"] = rndstr(32)
            client_record["registration_client_uri"] = \
                "%s%s%s?client_id=%s" % (
                    self.name, self.client_info_url,
                    ClientInfoEndpoint.etype, new_id)

        if "redirect_uris" in request:
            client_record["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        # NOTE(review): load_keys returns an error Response on failure;
        # that return value is discarded here.
        self.load_keys(request, new_id, client_record["client_secret"])

        try:
            registration_rules = self.behavior['client_registration']
        except KeyError:
            pass
        else:
            self.verify_correct(client_record, registration_rules)

        self.cdb[new_id] = client_record

        return new_id

    def match_client_request(self, request):
        """
        Check the client's preferences against the provider capabilities.

        For "response_types" every requested value must equal one of the
        supported values when both are compared as sets of space separated
        words. Any other preference must be contained in (string value) or
        be a subset of (other values) the matching capability.

        :param request: The registration request as a dictionary
        :raises CapabilitiesMisMatch: if a preference is not supported
        """
        for pref, prov in PREFERENCE2PROVIDER.items():
            if pref not in request:
                continue

            if pref == "response_types":
                for requested in request[pref]:
                    wanted = set(requested.split(" "))
                    if not any(wanted == set(offered.split(' '))
                               for offered in self.capabilities[prov]):
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(pref))
            elif isinstance(request[pref], six.string_types):
                if request[pref] not in self.capabilities[prov]:
                    raise CapabilitiesMisMatch(
                        'Not allowed {}'.format(pref))
            elif not set(request[pref]).issubset(
                    set(self.capabilities[prov])):
                raise CapabilitiesMisMatch(
                    'Not allowed {}'.format(pref))

    def client_info(self, client_id):
        """
        Build the client information response for a registered client.

        :param client_id: The client identifier
        :return: A JSON Response carrying a ClientInfoResponse
        """
        record = self.cdb[client_id].copy()
        if "redirect_uris" in record:
            # stored as (url, query) pairs; hand them out as plain URIs
            record["redirect_uris"] = self._tuples_to_uris(
                record["redirect_uris"])

        info = ClientInfoResponse(**record)
        return Response(info.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace the stored information about a client with the content of
        an update request.

        ``client_id`` and ``client_secret`` may appear in the request but
        must equal the stored values. Apart from the registration
        bookkeeping fields, stored keys that are absent from the request
        are removed.

        :param client_id: The client identifier
        :param request: The update request
        :raises ModificationForbidden: if the request tries to change
            ``client_id`` or ``client_secret``
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ("client_secret", "client_id"):
                # Explicit comparison rather than ``assert``: assert
                # statements are removed under ``python -O``, which would
                # silently disable this check.
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        # bookkeeping fields that survive even when not in the request
        keep = ("client_id_issued_at", "client_secret_expires_at",
                "registration_access_token", "registration_client_uri")
        for key in list(_cinfo.keys()):
            if key in keep:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Authenticate a client with the named authentication method.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: The client identifier; when empty it is looked
            up from the request and the HTTP_AUTHORIZATION entry of environ
        :return: Whatever the method's ``verify`` returns
        :raises UnSupported: if the authentication method is not known
        """

        if not client_id:
            client_id = get_client_id(self.cdb, areq,
                                      environ["HTTP_AUTHORIZATION"])

        try:
            authn_class = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()

        authenticator = authn_class(self)
        return authenticator.verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """
        Turn a software statement into client restrictions.

        This implementation ignores the statement.

        :param software_statement: A parsed software statement
        :return: An empty dictionary
        """
        return dict()

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: keyword arguments; 'request' holds the JSON
            registration request and 'environ' is used when client
            authentication at registration is configured
        :return: A Response instance
        """

        def rejection(error_code, cause):
            # All refusals share the same JSON error document shape.
            reg_error = ClientRegistrationError(
                error=error_code, error_description="%s" % cause)
            return BadRequest(reg_error.to_json(),
                              content="application/json")

        reg_request = RegistrationRequest().deserialize(kwargs['request'],
                                                        "json")
        try:
            reg_request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            return rejection("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return rejection("invalid_client_metadata", err)

        # authenticated client
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], reg_request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}
        if 'parsed_software_statement' in reg_request:
            for statement in reg_request['parsed_software_statement']:
                client_restrictions.update(
                    self.consume_software_statement(statement))
            # the statements are not kept in the stored client information
            del reg_request['software_statement']
            del reg_request['parsed_software_statement']

        try:
            client_id = self.create_new_client(reg_request,
                                               client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            return rejection("invalid_client_metadata", err)

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods: GET reads, PUT updates and DELETE removes the client
        information.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; 'query' holds the query string
            with the client_id, 'environ' and 'request' are used for the
            bearer authentication and the update
        :return: A Response instance, or None for any other HTTP method
        """

        _query = parse_qs(kwargs['query'])
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # Explicit membership test rather than ``assert``: assert
        # statements are removed under ``python -O``, which would let an
        # unknown client id through to the authentication step.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'], kwargs['request'],
                               "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()
        # NOTE(review): any other HTTP method falls through and yields None.

    def provider_features(self, pcr_class=ASConfigurationResponse,
                          provider_config=None):
        """
        Specifies what the server capabilities are.

        Starts from CAPABILITIES, adds the supported scopes and the signing
        algorithms usable for token endpoint authentication, then applies
        the overrides in ``provider_config``.

        :param pcr_class: Class the provider information is an instance of
        :param provider_config: Optional overrides applied last
        :return: ProviderConfigurationResponse instance
        """

        _provider_info = pcr_class(**CAPABILITIES)

        _provider_info["scopes_supported"] = list(SCOPE2CLAIMS.keys())

        # 'none' is not allowed for
        # token_endpoint_auth_signing_alg_values_supported. Filtering it out
        # (instead of list.remove) also works when 'none' is not among the
        # signer algorithms.
        _provider_info[
            "token_endpoint_auth_signing_alg_values_supported"] = [
                alg for alg in jws.SIGNER_ALGS.keys() if alg != 'none']

        if provider_config:
            _provider_info.update(provider_config)

        return _provider_info

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually
        can be done by this implementation.

        Only string values are checked; each must be contained in the
        matching entry of the default provider features. Values of any
        other type are accepted without inspection.

        :param capabilities: The asked for capabilities as a dictionary
            or a ProviderConfigurationResponse instance. The latter can be
            treated as a dictionary.
        :return: True or False
        """
        supported = self.provider_features()
        for key, wanted in capabilities.items():
            if not isinstance(wanted, six.string_types):
                continue
            try:
                if wanted not in supported[key]:
                    return False
            except KeyError:
                # the capability is not known at all
                return False

        return True

    def create_providerinfo(self, pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response.

        The result is ``self.capabilities`` itself, extended in place with
        the jwks_uri, one ``<etype>_endpoint`` URL per endpoint, the
        recognised entries of ``setup``, the issuer and the version.

        :param pcr_class: Class whose ``c_param`` keys decide which entries
            of ``setup`` are copied
        :param setup: Optional dictionary with extra provider information
        :return: The provider information
        """
        # Endpoint locations are URLs, so they must be joined with "/"
        # on every platform; os.path.join would use "\\" on Windows.
        import posixpath

        # NOTE(review): this is an alias, not a copy, so every call
        # modifies self.capabilities.
        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            _provider_info['{}_endpoint'.format(endp.etype)] = posixpath.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider information as JSON.

        :param kwargs: keyword arguments; an optional 'handle' entry is a
            (key, timestamp) pair used to decide whether a cookie header
            is added
        :return: A Response instance; if anything fails, the formatted
            traceback is logged and returned as the response body
        """
        logger.info("@providerinfo_endpoint")
        try:
            provider_info = self.create_providerinfo()
            logger.info(
                "provider_info_response: %s" % (provider_info.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                key, _timestamp = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    headers.append(
                        self.cookie_func(key, self.cookie_name, "pinfo",
                                         self.sso_ttl))

            return Response(provider_info.to_json(),
                            content="application/json", headers=headers)
        except Exception:
            # NOTE(review): the traceback is sent to the caller and the
            # content type reads "html/text" -- confirm both are intended.
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            return Response(message, content="html/text")

    @staticmethod
    def verify_code_challenge(code_verifier, code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        The verifier is hashed with the function registered for the
        method, the hex digest is base64 encoded and the result is
        compared with the challenge.

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into CC_METHOD selecting the
            hash function
        :return: True if the challenge matches, otherwise a
            "401 Unauthorized" Response
        """
        hasher = CC_METHOD[code_challenge_method]
        hex_digest = hasher(code_verifier.encode()).hexdigest()
        computed = b64e(hex_digest.encode())
        if computed.decode() == code_challenge:
            return True

        logger.error('PCKE Code Challenge check failed')
        failure = TokenErrorResponse(error="invalid_request",
                                     error_description="PCKE check failed")
        return Response(failure.to_json(), content="application/json",
                        status="401 Unauthorized")

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens

        The client is authenticated, the access code is looked up, and the
        PKCE verifier (when sent), the grant type, the state hash (when
        sent), the scope and the redirect URI are checked before the code
        is upgraded to a token.

        :param authn: Client authentication information
        :param kwargs: keyword arguments; 'request' holds the urlencoded
            access token request
        :return: A Response instance; a JSON AccessTokenResponse on
            success, an error response otherwise
        """

        _sdb = self.sdb

        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        try:
            # Called for its authentication effect only.
            # NOTE(review): the returned client id is not compared with
            # the client the code was issued to -- confirm that is handled
            # elsewhere.
            self.client_authn(self, areq, authn)
        except FailedAuthentication as err:
            err = TokenErrorResponse(error="unauthorized_client",
                                     error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug("AccessTokenRequest: %s" % areq)

        # assert that the code is valid
        try:
            _info = _sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Unknown access grant")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        authzreq = json.loads(_info['authzreq'])
        if 'code_verifier' in areq:
            try:
                _method = authzreq['code_challenge_method']
            except KeyError:
                _method = 'S256'

            resp = self.verify_code_challenge(areq['code_verifier'],
                                              authzreq['code_challenge'],
                                              _method)
            # verify_code_challenge returns True on success and an error
            # Response on failure, so only a non-True result ends the
            # request here. Testing plain truthiness would hand the bare
            # True back to the caller instead of issuing a token.
            if resp is not True:
                return resp

        # Explicit comparison rather than ``assert``: assert statements
        # are removed under ``python -O``.
        if areq["grant_type"] != "authorization_code":
            err = TokenErrorResponse(error="invalid_request",
                                     error_description="Wrong grant type")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        if 'state_hash' in areq:
            # have to get the token to get at the state
            code = areq['code']
            shash = base64.urlsafe_b64encode(
                hashlib.sha256(
                    self.sdb[code]['state'].encode('utf8')).digest())
            if shash.decode('ascii') != areq['state_hash']:
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. A missing or
        # different value is answered with an error response instead of
        # an AssertionError (or no check at all under ``python -O``).
        if "redirect_uri" in _info:
            if ("redirect_uri" not in areq or
                    areq["redirect_uri"] != _info["redirect_uri"]):
                err = TokenErrorResponse(
                    error="invalid_grant",
                    error_description="redirect_uri mismatch")
                return Response(err.to_json(), content="application/json",
                                status="400 Bad Request")

        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = _sdb.upgrade_to_token(areq["code"],
                                           issue_refresh=issue_refresh)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Access grant used")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json")

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        Export my keys and remember where they can be fetched.

        The URL produced by ``key_export`` is stored in ``self.jwks_uri``;
        nothing is returned.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        """
        exported_uri = key_export(self.baseurl, local_path, vault,
                                  self.keyjar, fqdn=self.hostname,
                                  sig=sig, enc=enc)
        self.jwks_uri = exported_uri

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        """
        Decide whether a client may use an endpoint on a token.

        Revocation requires the client to be the token's 'azr'. On any
        other endpoint (introspection) being the 'azr' or being listed in
        'aud' is enough.

        :param endpoint: Name of the endpoint being accessed
        :param client_id: The client identifier
        :param token_info: The token information
        :return: True if access is allowed, otherwise False
        """
        is_azr = bool('azr' in token_info and
                      client_id == token_info['azr'])

        if endpoint == 'revocation_endpoint':
            return is_azr

        # has to be introspection endpoint
        if is_azr:
            return True
        return bool('aud' in token_info and client_id in token_info['aud'])

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token named in a request.

        Without a 'token_type_hint' the token is first tried as an access
        token and then, on KeyError, as a refresh token.

        :param authn: Client authentication information
        :param req: The revocation or introspection request
        :param endpoint: Name of the endpoint, used for logging and for
            the access decision
        :return: A (client_id, token_type, token information) tuple, or a
            Response when authentication fails or access is not allowed
        :raises KeyError: if the token can not be found
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            failure = TokenErrorResponse(error="unauthorized_client",
                                         error_description="%s" % err)
            return Response(failure.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        factories = self.sdb.token_factory
        try:
            token_type = req['token_type_hint']
        except KeyError:
            # no hint: guess the type by trying the factories in turn
            # NOTE(review): the access token lookup goes through ``info``
            # while the others use ``get_info`` -- confirm.
            try:
                token_info = factories['access_token'].info(req['token'])
            except KeyError:
                token_info = factories['refresh_token'].get_info(
                    req['token'])
                token_type = 'refresh_token'
            else:
                token_type = 'access_token'
        else:
            token_info = factories[token_type].get_info(req['token'])

        if not self.token_access(endpoint, client_id, token_info):
            return BadRequest()

        return client_id, token_type, token_info

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7009 allows a client to invalidate an access or refresh
        token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs: extra keyword arguments, not used
        :return: The error Response from the token lookup, BadRequest if
            invalidating raises KeyError, otherwise an 'OK' Response
        """

        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, 'revocation_endpoint')
        if isinstance(outcome, Response):
            return outcome
        client_id, token_type, _info = outcome

        logger.info('{} token revocation: {}'.format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()
        return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7662

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs: extra keyword arguments, not used
        :return: The error Response from the token lookup, otherwise a
            JSON Response carrying a TokenIntrospectionResponse
        """

        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, 'introspection_endpoint')
        if isinstance(outcome, Response):
            return outcome
        client_id, token_type, token_info = outcome

        logger.info('{} token introspection: {}'.format(client_id,
                                                        tir.to_dict()))

        factory = self.sdb.token_factory[token_type]
        ir = TokenIntrospectionResponse(active=factory.is_valid(token_info),
                                        **token_info.to_dict())

        ir.weed()

        return Response(ir.to_json(), content="application/json")
예제 #5
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True,
                 config=None,
                 client_cert=None):
        """
        Set up an OAuth2 client with empty protocol state.

        :param client_id: The client identifier
        :param ca_certs: Certificates used to verify HTTPS certificates
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: Passed on to ``PBase.__init__``
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Optional configuration dictionary; its 'issuer'
            entry, when present, becomes ``self.issuer``
        :param client_cert: Passed on to ``PBase.__init__``
        :return: Client instance
        """
        PBase.__init__(self, ca_certs, verify_ssl=verify_ssl,
                       client_cert=client_cert, keyjar=keyjar)

        # identity and transport settings
        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self.verify_ssl = verify_ssl

        # per-session protocol state
        self.nonce = None
        self.grant = {}
        self.state2nonce = {}

        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        # lookup tables and factories used when building/parsing messages
        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        self.authz_req = None

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config['issuer']
        except KeyError:
            self.issuer = ''

        self.allow = {}
        self.provider_info = {}

    def store_response(self, clinst, text):
        """
        Hook called by ``parse_response`` with the parsed response and the
        raw text; this implementation does nothing.

        :param clinst: The parsed response instance
        :param text: The text the response was parsed from
        """
        return None

    def get_client_secret(self):
        return self._c_secret

    def set_client_secret(self, val):
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        # self.state = None
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = None

    def grant_from_state(self, state):
        for key, grant in self.grant.items():
            if key == state:
                return grant

        return None

    def _parse_args(self, request, **kwargs):
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            else:
                if prop == "redirect_uri":
                    _val = getattr(self, "redirect_uris", [None])[0]
                    if _val:
                        ar_args[prop] = _val
                else:
                    _val = getattr(self, prop, None)
                    if _val:
                        ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        # try:
        # _state = kwargs["state"]
        # if not _state:
        #         _state = self.state
        # except KeyError:
        #     _state = self.state

        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired=False, **kwargs):
        """
        Find a token, either one handed in directly or one stored in a grant.

        :param also_expired: When True, a token found through a grant is
            returned even if it is no longer valid
        :param kwargs: 'token' short-circuits the lookup; otherwise the
            keyword arguments are passed on to ``get_grant`` (which needs
            'state') and 'scope', if present, selects the token
        :return: The token
        :raises TokenError: If no token is found, or the token found is not
            valid and ``also_expired`` is False
        """
        try:
            # An explicitly supplied token is returned as is; it is not
            # checked for validity.
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                # No scope given (or the lookup itself raised KeyError):
                # fall back to the token stored for the empty scope.
                token = grant.get_token("")
                if not token:
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def construct_request(self, request, request_args=None, extra_args=None):
        if request_args is None:
            request_args = {}

        # logger.debug("request_args: %s" % sanitize(request_args))
        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
            # logger.debug("kwargs: %s" % sanitize(kwargs))
        # logger.debug("request: %s" % sanitize(request))
        return request(**kwargs)

    def construct_Message(self,
                          request=Message,
                          request_args=None,
                          extra_args=None,
                          **kwargs):
        """
        Build a plain message; a thin wrapper around ``construct_request``.

        :param request: The message class
        :param request_args: Explicit message arguments
        :param extra_args: Additional arguments applied last
        :param kwargs: Ignored
        :return: An instance of ``request``
        """
        built = self.construct_request(request, request_args, extra_args)
        return built

    def construct_AuthorizationRequest(self,
                                       request=AuthorizationRequest,
                                       request_args=None,
                                       extra_args=None,
                                       **kwargs):
        """
        Build an authorization request.

        A truthy 'redirect_uri' in ``request_args`` replaces the client's
        stored redirect URIs. A missing or falsy 'client_id' is filled in
        from ``self.client_id``. ``request_args`` is modified in place.

        :param request: The request class
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param kwargs: Ignored
        :return: An instance of ``request``
        """
        if request_args is None:
            request_args = {}
        else:
            try:
                new_uri = request_args["redirect_uri"]
            except KeyError:
                pass
            else:
                if new_uri:  # change default
                    self.redirect_uris = [new_uri]

        has_client_id = ("client_id" in request_args and
                         request_args["client_id"])
        if not has_client_id:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None,
                                     extra_args=None,
                                     **kwargs):
        """
        Build an access token request from a stored grant.

        The grant is looked up with ``get_grant(**kwargs)``. Its code is
        always placed in the request; 'state' is copied from ``kwargs`` when
        given, 'grant_type' defaults to "authorization_code" and a missing
        or falsy 'client_id' is filled in from ``self.client_id``.
        ``request_args`` is modified in place.

        :param request: The request class
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param kwargs: Passed to ``get_grant``; must contain 'state'
        :return: An instance of ``request``
        :raises GrantExpired: If the grant is no longer valid
        """
        grant = self.get_grant(**kwargs)

        if not grant.is_valid():
            now = utc_time_sans_frac()
            raise GrantExpired(
                "Authorization Code to old %s > %s" %
                (now, grant.grant_expiration_time))

        if request_args is None:
            request_args = {}

        request_args["code"] = grant.code

        if 'state' in kwargs:
            request_args['state'] = kwargs['state']

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        has_client_id = ("client_id" in request_args and
                         request_args["client_id"])
        if not has_client_id:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None,
                                            **kwargs):
        """
        Build a refresh request for a token, expired or not.

        The token is found with ``get_token(also_expired=True, **kwargs)``;
        its refresh token and, when it has one, its scope are copied into
        the request arguments. ``request_args`` is modified in place.

        :param request: The request class
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param kwargs: Passed to ``get_token``
        :return: An instance of ``request``
        """
        token = self.get_token(also_expired=True, **kwargs)

        if request_args is None:
            request_args = {}
        request_args["refresh_token"] = token.refresh_token

        # The scope is optional on the token
        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    # def construct_TokenRevocationRequest(self,
    #                                      request=TokenRevocationRequest,
    #                                      request_args=None, extra_args=None,
    #                                      **kwargs):
    #
    #     if request_args is None:
    #         request_args = {}
    #
    #     token = self.get_token(**kwargs)
    #
    #     request_args["token"] = token.access_token
    #     return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(self,
                                  request=ResourceRequest,
                                  request_args=None,
                                  extra_args=None,
                                  **kwargs):
        """
        Build a resource request carrying an access token.

        :param request: The request class
        :param request_args: Explicit request arguments; 'access_token' is
            overwritten in place
        :param extra_args: Additional arguments applied last
        :param kwargs: Passed to ``get_token`` to locate the token
        :return: An instance of ``request``
        """
        token = self.get_token(**kwargs)

        if request_args is None:
            request_args = {}
        request_args["access_token"] = token.access_token

        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(self,
                     reqmsg,
                     cis,
                     method="POST",
                     request_args=None,
                     **kwargs):
        """
        Work out where and how to send a request.

        :param reqmsg: The request class; its name selects the endpoint
            through ``self.request2endpoint``
        :param cis: The request instance
        :param method: The HTTP method
        :param request_args: The request arguments; may override the
            endpoint. Optional: None is treated as no arguments
        :param kwargs: A truthy 'endpoint' is used as the URL directly;
            everything is passed on to ``get_or_post``
        :return: Tuple of URL, body, HTTP arguments (only the headers, if
            any) and the request instance
        :raises MissingEndpoint: If no endpoint could be found
        """
        if kwargs.get("endpoint"):
            uri = kwargs["endpoint"]
        else:
            # request_args defaults to None, which cannot be unpacked with **
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self,
                     request,
                     method="POST",
                     request_args=None,
                     extra_args=None,
                     lax=False,
                     **kwargs):
        """
        Build a request and everything needed to send it.

        The request is built by the ``construct_<RequestName>`` method when
        this client has one, else by ``construct_request``. If the request
        has both a nonce and a state, the pair is recorded in
        ``self.state2nonce``. When 'authn_method' is among the keyword
        arguments, client authentication is initialised and any headers it
        yields are merged into the 'headers' keyword argument.

        :param request: The request class
        :param method: The HTTP method
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param lax: Stored on the request instance as ``lax``
        :param kwargs: Passed to the constructor method, the authentication
            method and ``uri_and_body``
        :return: What ``uri_and_body`` returns: URL, body, HTTP arguments
            and the request instance
        """
        if request_args is None:
            request_args = {}

        # NOTE(review): an AttributeError raised inside the constructor
        # method also ends up in the fallback branch - confirm that this is
        # intended.
        try:
            builder = getattr(self, "construct_%s" % request.__name__)
            cis = builder(request_args=request_args,
                          extra_args=extra_args,
                          **kwargs)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)

        if self.events:
            self.events.store('Protocol request', cis)

        if 'nonce' in cis and 'state' in cis:
            self.state2nonce[cis['state']] = cis['nonce']

        cis.lax = lax

        h_arg = None
        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(
                cis, request_args=request_args, **kwargs)

        if h_arg:
            auth_headers = h_arg["headers"]
            if "headers" in kwargs:
                kwargs["headers"].update(auth_headers)
            else:
                kwargs["headers"] = auth_headers

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self,
                                   request_args=None,
                                   extra_args=None,
                                   **kwargs):
        """
        Shortcut for ``request_info`` with an AuthorizationRequest sent
        using GET.

        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param kwargs: Passed on to ``request_info``
        :return: What ``request_info`` returns
        """
        info = self.request_info(AuthorizationRequest, "GET", request_args,
                                 extra_args, **kwargs)
        return info

    def get_urlinfo(self, info):
        """
        Extract the parameter part of a URL.

        :param info: A URL, or an already extracted parameter string
        :return: The query part if the URL has a non-empty one, otherwise
            the fragment. Input without '?' and '#' is returned unchanged.
        """
        if '?' not in info and '#' not in info:
            return info

        parts = urlparse(info)
        # either query or fragment
        return parts.query if parts.query else parts.fragment

    def parse_response(self,
                       response,
                       info="",
                       sformat="json",
                       state="",
                       **kwargs):
        """
        Parse a response, verify it and record it in the matching grant.

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state; used to pick the grant when the response
            itself carries no state
        :param kwargs: Extra key word arguments, passed to the
            deserialization and the verification of the response
        :return: The parsed and to some extent verified response
        :raises PyoidcError: If verification of the response fails
        :raises ResponseError: If no usable response could be produced
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            # Strip everything but the query/fragment part of a full URL
            info = self.get_urlinfo(info)

        # if self.events:
        #    self.events.store('Response', info)
        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store('Response', resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # The message carries an error: re-parse it as one of the error
            # classes registered for this response type and keep the first
            # one that both deserializes and verifies.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception:
                        resp = None
            except KeyError:
                pass
        elif resp.only_extras():
            # Treated as a missing response further down
            resp = None
        else:
            # Regular response: collect what the verification needs
            kwargs["client_id"] = self.client_id
            try:
                kwargs['iss'] = self.provider_info['issuer']
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs['iss'] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error('Verification of the response failed')
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                # Fall back to the scope the caller passed in, if any
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error('Missing or faulty response')
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            # Prefer the state from the response over the one passed in
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            # Update the grant stored for this state, creating it if needed
            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(self,
                                   cis,
                                   authn_method,
                                   request_args=None,
                                   http_args=None,
                                   **kwargs):
        """
        Apply a client authentication method to a request.

        :param cis: The request instance
        :param authn_method: Key into ``self.client_authn_method``; when
            falsy no authentication is applied
        :param request_args: The request arguments
        :param http_args: HTTP arguments collected so far
        :param kwargs: Passed on to the method's ``construct``
        :return: The result of ``construct``, or ``http_args`` (an empty
            dictionary if None was given) when no method was named
        """
        if http_args is None:
            http_args = {}
        if request_args is None:
            request_args = {}

        if not authn_method:
            return http_args

        authenticator = self.client_authn_method[authn_method](self)
        return authenticator.construct(cis, request_args, http_args, **kwargs)

    def parse_request_response(self,
                               reqresp,
                               response,
                               body_type,
                               state="",
                               **kwargs):
        """
        Turn an HTTP response into a protocol response.

        :param reqresp: The HTTP response
        :param response: The expected response class; when falsy the body is
            only probed for an error response
        :param body_type: The expected format of the body
        :param state: The state, passed on to ``parse_response``
        :param kwargs: Passed on to ``parse_response``
        :return: The parsed response, the body text for body type 'txt', an
            ErrorResponse, or the HTTP response itself (redirects and
            bodies that could not be interpreted)
        :raises ParseError: On HTTP status 500
        :raises HttpError: On any status code not handled here
        """

        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code in [302, 303]:  # redirect
            return reqresp
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" %
                         (reqresp.status_code, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # expecting an error response
            # (this branch has no effect of its own; processing simply
            # continues below)
            if issubclass(response, ErrorResponse):
                pass
        else:
            logger.error("(%d) %s" %
                         (reqresp.status_code, sanitize(reqresp.text)))
            raise HttpError("HTTP ERROR: %s [%s] on %s" %
                            (reqresp.text, reqresp.status_code, reqresp.url))

        if response:
            if body_type == 'txt':
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == 'txt':
                body_type = 'urlencoded'
            # NOTE(review): this reads reqresp.message while the rest of the
            # method reads reqresp.text; any exception raised here is
            # swallowed below - confirm which attribute is intended.
            try:
                err = ErrorResponse().deserialize(reqresp.message,
                                                  method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                pass

        return reqresp

    def request_and_return(self,
                           url,
                           response=None,
                           method="GET",
                           body=None,
                           body_type="json",
                           state="",
                           http_args=None,
                           **kwargs):
        """
        Send a request and parse what comes back.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: The state, passed on to the response parsing
        :param http_args: Arguments for the HTTP client
        :param kwargs: Passed on to the response parsing; 'keyjar' defaults
            to this client's key jar
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        client_args = {} if http_args is None else http_args
        http_resp = self.http_request(url, method, data=body, **client_args)

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(http_resp, response, body_type,
                                           state, **kwargs)

    def do_authorization_request(self,
                                 request=AuthorizationRequest,
                                 state="",
                                 body_type="",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Build, send and parse an authorization request.

        :param request: The request class
        :param state: The state; when truthy it is put into the request
            arguments
        :param body_type: The expected format of the response body
        :param method: The HTTP method
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param http_args: Arguments for the HTTP client; updated in place
            with what the request construction yields
        :param response_cls: The expected response class
        :param kwargs: Passed on to ``request_info``; 'algs' is also
            forwarded to the response parsing
        :return: What ``request_and_return`` returns
        """
        if state:
            try:
                request_args["state"] = state
            except TypeError:
                # request_args was not something supporting item assignment
                request_args = {"state": state}

        kwargs['authn_endpoint'] = 'authorization'
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request per state. A TypeError (for example when
        # self.authz_req is still None) is silently ignored.
        try:
            self.authz_req[request_args["state"]] = csi
        except TypeError:
            pass

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args,
                                       algs=algs)

        # Error responses get the state of the request that caused them
        if (isinstance(resp, Message) and
                resp.type() in RESPONSE2ERROR["AuthorizationResponse"]):
            resp.state = csi.state

        return resp

    def do_access_token_request(self,
                                request=AccessTokenRequest,
                                scope="",
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Build, send and parse an access token request.

        :param request: The request class
        :param scope: Passed on to the request construction
        :param state: The state identifying the grant to use
        :param body_type: The expected format of the response body
        :param method: The HTTP method, POST by default
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param http_args: Arguments for the HTTP client; updated in place
            with what the request construction yields
        :param response_cls: The expected response class
        :param authn_method: The client authentication method to use
        :param kwargs: Passed on to ``request_info`` and
            ``request_and_return``
        :return: What ``request_and_return`` returns
        """
        kwargs['authn_endpoint'] = 'token'
        url, body, ht_args, _ = self.request_info(request,
                                                  method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope,
                                                  state=state,
                                                  authn_method=authn_method,
                                                  **kwargs)

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        if self.events is not None:
            for label, value in (('request_url', url),
                                 ('request_http_args', http_args),
                                 ('Request', body)):
                self.events.store(label, value)

        logger.debug("<do_access_token> URL: %s, Body: %s" %
                     (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args,
                                       **kwargs)

    def do_access_token_refresh(self,
                                request=RefreshAccessTokenRequest,
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Build, send and parse a refresh request for a stored token.

        :param request: The request class
        :param state: The state identifying the grant holding the token
        :param body_type: The expected format of the response body
        :param method: The HTTP method, POST by default
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param http_args: Arguments for the HTTP client; updated in place
            with what the request construction yields
        :param response_cls: The expected response class
        :param authn_method: The client authentication method to use
        :param kwargs: Passed on to ``get_token``
        :return: What ``request_and_return`` returns
        """
        token = self.get_token(also_expired=True, state=state, **kwargs)

        # NOTE(review): kwargs is not passed to request_info below, so this
        # value is never seen by it - confirm whether it should be.
        kwargs['authn_endpoint'] = 'refresh'

        url, body, ht_args, _ = self.request_info(request,
                                                  method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  token=token,
                                                  authn_method=authn_method)

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    # def do_revocate_token(self, request=TokenRevocationRequest,
    #                       scope="", state="", body_type="json", method="POST",
    #                       request_args=None, extra_args=None, http_args=None,
    #                       response_cls=None, authn_method=""):
    #
    #     url, body, ht_args, csi = self.request_info(request, method=method,
    #                                                 request_args=request_args,
    #                                                 extra_args=extra_args,
    #                                                 scope=scope, state=state,
    #                                                 authn_method=authn_method)
    #
    #     if http_args is None:
    #         http_args = ht_args
    #     else:
    #         http_args.update(ht_args)
    #
    #     return self.request_and_return(url, response_cls, method, body,
    #                                    body_type, state=state,
    #                                    http_args=http_args)

    def do_any(self,
               request,
               endpoint="",
               scope="",
               state="",
               body_type="json",
               method="POST",
               request_args=None,
               extra_args=None,
               http_args=None,
               response=None,
               authn_method=""):
        """
        Build, send and parse an arbitrary request.

        :param request: The request class
        :param endpoint: The URL to send the request to; when falsy the
            endpoint is derived from the request class
        :param scope: Passed on to the request construction
        :param state: Passed on to the request construction and the
            response parsing
        :param body_type: The expected format of the response body
        :param method: The HTTP method, POST by default
        :param request_args: Explicit request arguments
        :param extra_args: Additional arguments applied last
        :param http_args: Arguments for the HTTP client; updated in place
            with what the request construction yields
        :param response: The expected response class, if any
        :param authn_method: The client authentication method to use
        :return: What ``request_and_return`` returns
        """
        url, body, ht_args, _ = self.request_info(request,
                                                  method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope,
                                                  state=state,
                                                  authn_method=authn_method,
                                                  endpoint=endpoint)

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        return self.request_and_return(url,
                                       response,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    def fetch_protected_resource(self,
                                 uri,
                                 method="GET",
                                 headers=None,
                                 state="",
                                 **kwargs):
        """
        Fetch a resource that requires an access token.

        :param uri: The URL of the resource
        :param method: The HTTP method
        :param headers: Extra HTTP headers; updated in place with the
            headers produced by the authentication method
        :param state: The state identifying the grant holding the token
        :param kwargs: A truthy 'token' is used as the access token
            directly; 'authn_method' selects how the token is presented
            (the default is "bearer_header"); everything is passed on to
            ``get_token`` and the authentication method
        :return: What ``self.http_request`` returns
        """
        if "token" in kwargs and kwargs["token"]:
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old, refresh it. The state has to be
                # passed along, otherwise the refresh looks for a grant
                # stored under the empty state.
                # NOTE(review): get_token in this class raises TokenError
                # for an expired token - confirm that ExpiredToken is what
                # actually arrives here.
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            # There is no request instance at this point.
            # init_authentication_method requires the argument, so pass
            # None, mirroring the default branch below which calls
            # construct() without one.
            http_args = self.init_authentication_method(
                cis=None, request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method["bearer_header"](
                self).construct(request_args=request_args)

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support: create a code verifier and the matching
        code challenge.

        The verifier length and the transformation method are read from
        ``self.config['code_challenge']`` ('length' and 'method'), falling
        back to 64 and 'S256'.

        :return: Tuple of a dictionary with 'code_challenge' and
            'code_challenge_method', and the code verifier
        :raises Unsupported: If the transformation method is not known
        """
        try:
            verifier_len = self.config['code_challenge']['length']
        except KeyError:
            verifier_len = 64  # Use default

        try:
            _method = self.config['code_challenge']['method']
        except KeyError:
            _method = 'S256'

        code_verifier = unreserved(verifier_len)

        # NOTE(review): the challenge is the base64 encoding of the hex
        # digest text, not of the raw digest bytes - confirm this matches
        # what the server side expects.
        try:
            digest_hex = CC_METHOD[_method](code_verifier.encode()).hexdigest()
            code_challenge = b64e(digest_hex.encode()).decode()
        except KeyError:
            raise Unsupported('PKCE Transformation method:{}'.format(_method))

        # TODO store code_verifier

        challenge_args = {
            "code_challenge": code_challenge,
            "code_challenge_method": _method
        }
        return challenge_args, code_verifier

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response
        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
        as attributes in self.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            try:
                self.allow["issuer_mismatch"]
            except KeyError:
                try:
                    assert _issuer == _pcr_issuer
                except AssertionError:
                    raise PyoidcError(
                        "provider info issuer mismatch '%s' != '%s'" %
                        (_issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        self.issuer = _pcr_issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self,
                        issuer,
                        keys=True,
                        endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the provider configuration and apply it to this client.

        :param issuer: The issuer whose configuration should be fetched
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param response_cls: The class used to parse the configuration
        :param serv_pattern: Pattern the issuer (without trailing slash) is
            inserted into to get the configuration URL
        :return: The parsed provider configuration
        :raises PyoidcError: If the request, after following any 302
            redirects, does not end with status 200
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # NOTE(review): the number of redirects followed is not bounded.
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        if r.status_code != 200:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        pcr = response_cls().from_json(r.text)
        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr
예제 #6
0
파일: dynreg.py 프로젝트: zofuthan/pyoidc
class Client(oauth2.Client):
    """
    OAuth2 client extended with dynamic client registration requests and
    provider configuration handling.
    """

    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None):
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar)
        # Tolerated deviations; handle_provider_config skips the issuer
        # comparison when the key "issuer_mismatch" is present.
        self.allow = {}
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint"
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Build a registration request message.

        :param request: The message class to instantiate
        :param request_args: Arguments for the message
        :param extra_args: Additional arguments for the message
        :param kwargs: Accepted but not used
        :return: Whatever ``construct_request`` returns
        """
        if request_args is None:
            request_args = {}

        return self.construct_request(request, request_args, extra_args)

    def _do_client_request(self, request, method, body_type, request_args,
                           extra_args, http_args, response_cls, **kwargs):
        """Build one client-management request, send it, return the result."""
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge what request_info produced into the caller's arguments.
            # The previous code updated http_args with itself, which
            # silently dropped ht_args.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a client registration request.

        :param request: The request message class
        :param body_type: Expected format of the response body
        :param method: HTTP method to use
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Arguments for the HTTP client; updated in place
            with the arguments computed by ``request_info``
        :param response_cls: Class used to parse the response
        :param kwargs: Passed on to ``request_info``
        :return: Whatever ``request_and_return`` returns
        """
        return self._do_client_request(request, method, body_type,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a request that reads the client information.

        Parameters and return value are the same as for
        ``do_client_registration``; only the defaults differ.
        """
        return self._do_client_request(request, method, body_type,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a request that updates the client information.

        Parameters and return value are the same as for
        ``do_client_registration``; only the defaults differ.
        """
        return self._do_client_request(request, method, body_type,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a request that deletes the client information.

        Parameters and return value are the same as for
        ``do_client_registration``; only the defaults differ.
        """
        return self._do_client_request(request, method, body_type,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: If the issuer in the response differs from the
            expected one and "issuer_mismatch" is not a key of ``self.allow``
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Give the expected issuer the same trailing-slash form as the
            # one in the response before comparing the two.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # An explicit comparison instead of ``assert``: assert statements
            # are removed when Python runs with -O, which would disable the
            # check altogether.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the provider configuration of an issuer and process it.

        :param issuer: The issuer identifier. One trailing "/" is dropped
            before it is interpolated into ``serv_pattern``.
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param response_cls: Class whose ``from_json`` parses the body
        :param serv_pattern: ``%``-style pattern producing the URL to fetch
        :return: The parsed provider configuration response
        :raises PyoidcError: If no 200 response was obtained
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # Keep following 302 redirects until any other status comes back.
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        pcr = response_cls().from_json(r.text) if r.status_code == 200 else None
        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """
        Remember a registration response and copy its ``client_secret``,
        ``client_id`` and ``redirect_uris`` entries onto this client.

        :param reginfo: The registration response
        """
        self.registration_response = reginfo
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """
        Interpret the HTTP response to a registration request.

        :param response: The HTTP response
        :return: The parsed ClientInfoResponse, after it has been stored
        :raises PyoidcError: If the status code is not 200
        """
        if response.status_code != 200:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        resp = ClientInfoResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)
        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The parsed registration response
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :param query: The urlencoded response
        :return: The parsed response
        :raises AuthzError: If the parsed message is an ErrorResponse
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s", aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s", aresp)

        return aresp
예제 #7
0
파일: __init__.py 프로젝트: lacoski/pyoidc
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(
        self,
        client_id=None,
        client_authn_method=None,
        keyjar=None,
        verify_ssl=True,
        config=None,
        client_cert=None,
        timeout=5,
        message_factory: Type[MessageFactory] = OauthMessageFactory,
    ):
        """
        Set up the client state.

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Optional configuration mapping. Its "issuer" entry,
            when present, becomes ``self.issuer``.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either as
            a single integer or as a tuple of integers. For more details, refer to
            ``requests`` documentation.
        :param message_factory: Factory for message classes, should inherit from OauthMessageFactory
        """
        PBase.__init__(
            self,
            verify_ssl=verify_ssl,
            keyjar=keyjar,
            client_cert=client_cert,
            timeout=timeout,
        )

        # Who we are and how we authenticate
        self.client_id = client_id
        self.client_authn_method = client_authn_method
        self._c_secret = ""  # type: str
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict]

        # Message handling
        self.message_factory = message_factory
        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR  # type: Dict[str, List]
        self.grant_class = Grant
        self.token_class = Token

        # Per-flow bookkeeping
        self.nonce = None
        self.grant = {}  # type: Dict[str, Grant]
        self.state2nonce = {}  # type: Dict[str, str]
        self.authz_req = {}  # type: Dict[str, Message]

        # own endpoint
        self.redirect_uris = []  # type: List[str]

        # service endpoints
        self.authorization_endpoint = None  # type: Optional[str]
        self.token_endpoint = None  # type: Optional[str]
        self.token_revocation_endpoint = None  # type: Optional[str]

        self.provider_info = ASConfigurationResponse()  # type: Message

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            issuer = self.config["issuer"]
        except KeyError:
            issuer = ""
        self.issuer = issuer
        self.allow = {}  # type: Dict[str, Any]

    def store_response(self, clinst, text):
        """
        Hook called by ``parse_response`` for every accepted response.

        This implementation ignores both arguments and stores nothing.

        :param clinst: The parsed response instance
        :param text: The raw text the response was parsed from
        """
        return None

    def get_client_secret(self) -> str:
        """Return the stored client secret; getter of the ``client_secret`` property."""
        return self._c_secret

    def set_client_secret(self, val: str):
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self) -> None:
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = []

    def grant_from_state(self, state: str) -> Optional[Grant]:
        for key, grant in self.grant.items():
            if key == state:
                return grant

        return None

    def _parse_args(self, request: Type[Message], **kwargs) -> Dict:
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            else:
                if prop == "redirect_uri":
                    _val = getattr(self, "redirect_uris", [None])[0]
                    if _val:
                        ar_args[prop] = _val
                else:
                    _val = getattr(self, prop, None)
                    if _val:
                        ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint: str, **kwargs) -> str:
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state: str, **kwargs) -> Grant:
        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired: bool = False, **kwargs) -> Token:
        """
        Find a token, either the one passed in or one held by a grant.

        :param also_expired: When true, the token found is returned without
            consulting ``is_valid()``
        :param kwargs: "token" short-circuits the lookup. Otherwise the
            arguments are handed to ``get_grant`` (which needs "state") and
            an optional "scope" selects the token within the grant.
        :return: The token
        :raises GrantError: From ``get_grant`` when no grant matches
        :raises TokenError: If no token is found or the token is not valid
        """
        try:
            # A caller-supplied token is returned as is; none of the checks
            # further down are applied to it.
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                # No "scope" given (or the grant lookup raised KeyError):
                # ask for the token registered under the empty scope.
                token = grant.get_token("")
                if not token:
                    try:
                        # Second attempt through the grant stored under the
                        # given state.
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def clean_tokens(self) -> None:
        """Clean replaced and invalid tokens."""
        for state in self.grant:
            grant = self.get_grant(state)
            for token in grant.tokens:
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(
        self, request: Type[Message], request_args=None, extra_args=None
    ):
        """
        Instantiate a request message.

        :param request: The message class
        :param request_args: Arguments for the message; completed with
            client attributes by ``_parse_args``
        :param extra_args: Additional arguments that override the completed
            ones
        :return: An instance of ``request``
        """
        supplied = {} if request_args is None else request_args
        params = self._parse_args(request, **supplied)

        if extra_args:
            params.update(extra_args)

        logger.debug("request: %s" % sanitize(request))
        return request(**params)

    def construct_Message(
        self,
        request: Type[Message] = Message,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> Message:
        """
        Build a generic message by delegating to ``construct_request``.

        :param request: The message class
        :param request_args: Arguments for the message
        :param extra_args: Additional arguments for the message
        :param kwargs: Accepted but not used
        :return: Whatever ``construct_request`` returns
        """

        return self.construct_request(request, request_args, extra_args)

    def construct_AuthorizationRequest(
        self,
        request: Type[AuthorizationRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AuthorizationRequest:

        if request is None:
            request = self.message_factory.get_request_type("authorization_endpoint")
        if request_args is not None:
            try:  # change default
                new = request_args["redirect_uri"]
                if new:
                    self.redirect_uris = [new]
            except KeyError:
                pass
        else:
            request_args = {}

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(
        self,
        request: Type[AccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AccessTokenRequest:
        """
        Build an access token request.

        Unless the request class is ROPCAccessTokenRequest, the grant found
        through ``get_grant(**kwargs)`` supplies the "code" argument.
        "grant_type" defaults to "authorization_code" and a missing or falsy
        "client_id" is filled in from the client.

        :param request: The message class; defaults to the message factory's
            request type for "token_endpoint"
        :param request_args: Arguments for the message; modified in place
        :param extra_args: Additional arguments for the message
        :param kwargs: Arguments for ``get_grant``; "state" is also copied
            into the request arguments
        :return: Whatever ``construct_request`` returns
        :raises GrantExpired: If the grant is no longer valid
        """
        if request is None:
            request = self.message_factory.get_request_type("token_endpoint")
        if request_args is None:
            request_args = {}

        needs_code = request is not ROPCAccessTokenRequest
        if needs_code:
            grant = self.get_grant(**kwargs)

            if not grant.is_valid():
                raise GrantExpired(
                    "Authorization Code to old %s > %s"
                    % (utc_time_sans_frac(), grant.grant_expiration_time)
                )

            request_args["code"] = grant.code

        if "state" in kwargs:
            request_args["state"] = kwargs["state"]

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(
        self,
        request: Type[RefreshAccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> RefreshAccessTokenRequest:

        if request is None:
            request = self.message_factory.get_request_type("refresh_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(also_expired=True, **kwargs)

        request_args["refresh_token"] = token.refresh_token

        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(
        self,
        request: Type[ResourceRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> ResourceRequest:

        if request is None:
            request = self.message_factory.get_request_type("resource_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["access_token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(
        self,
        reqmsg: Type[Message],
        cis: Message,
        method="POST",
        request_args=None,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """
        Work out the URL, body and HTTP arguments for a request.

        :param reqmsg: The request message class; its name selects the
            endpoint through ``self.request2endpoint``
        :param cis: The request message instance
        :param method: HTTP method
        :param request_args: Request arguments, searched for an endpoint
            override by ``_endpoint``
        :param kwargs: A truthy "endpoint" is used as the URL directly; all
            keyword arguments are passed on to ``get_or_post``
        :return: Tuple of URL, body, HTTP arguments and ``cis``
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            # request_args defaults to None, which cannot be unpacked with
            # ``**``; fall back to an empty mapping.
            uri = self._endpoint(
                self.request2endpoint[reqmsg.__name__], **(request_args or {})
            )

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(
        self,
        request: Type[Message],
        method="POST",
        request_args=None,
        extra_args=None,
        lax=False,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """
        Construct a request message and everything needed to send it.

        The message is built by the ``construct_<request name>`` method when
        the client has one, otherwise by ``construct_request``.

        :param request: The request message class
        :param method: HTTP method
        :param request_args: Arguments for the message
        :param extra_args: Additional arguments for the message
        :param lax: Stored on the message as its ``lax`` attribute
        :param kwargs: Passed to the construct method and ``uri_and_body``;
            "authn_method" triggers ``init_authentication_method``
        :return: Tuple of URL, body, HTTP arguments and the message
        """
        if request_args is None:
            request_args = {}

        try:
            builder = getattr(self, "construct_%s" % request.__name__)
            cis = builder(request_args=request_args, extra_args=extra_args, **kwargs)
        except AttributeError:
            # NOTE(review): this also catches an AttributeError raised inside
            # the construct method itself, not only a missing method.
            cis = self.construct_request(request, request_args, extra_args)

        if self.events:
            self.events.store("Protocol request", cis)

        # Remember which nonce belongs to which state.
        if "nonce" in cis and "state" in cis:
            self.state2nonce[cis["state"]] = cis["nonce"]

        cis.lax = lax

        h_arg = (
            self.init_authentication_method(cis, request_args=request_args, **kwargs)
            if "authn_method" in kwargs
            else None
        )

        if h_arg:
            auth_headers = h_arg["headers"]
            if "headers" in kwargs:
                kwargs["headers"].update(auth_headers)
            else:
                kwargs["headers"] = auth_headers

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None, **kwargs):
        return self.request_info(
            self.message_factory.get_request_type("authorization_endpoint"),
            "GET",
            request_args,
            extra_args,
            **kwargs
        )

    @staticmethod
    def get_urlinfo(info: str) -> str:
        if "?" in info or "#" in info:
            parts = urlparse(info)
            scheme, netloc, path, params, query, fragment = parts[:6]
            # either query of fragment
            if query:
                info = query
            else:
                info = fragment
        return info

    def parse_response(
        self,
        response: Type[Message],
        info: str = "",
        sformat: ENCODINGS = "json",
        state: str = "",
        **kwargs
    ) -> Message:
        """
        Parse a response.

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state, used to file the grant when the response
            itself carries none
        :param kwargs: Extra keyword arguments, passed to ``deserialize``
            and ``verify``
        :return: The parsed and to some extent verified response
        :raises PyoidcError: If verification of the response fails
        :raises ResponseError: If no usable response could be obtained
        """
        _r2e = self.response2error

        if sformat == "urlencoded":
            # Strip a full URL down to its query or fragment part.
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store("Response", resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # The message carries an error but was parsed as a normal
            # response: re-parse it with the error classes registered for
            # this response type, keeping the first one that verifies.
            resp = None
            errmsgs = []  # type: List[Any]
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception:
                        resp = None
            except KeyError:
                pass
        elif resp.only_extras():
            resp = None
        else:
            # Normal response: verify it with the client id, issuer and
            # keys known to this client.
            kwargs["client_id"] = self.client_id
            try:
                kwargs["iss"] = self.provider_info["issuer"]
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs["iss"] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error("Verification of the response failed")
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                # Fall back to the scope supplied by the caller, if any.
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error("Missing or faulty response")
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if isinstance(resp, (AuthorizationResponse, AccessTokenResponse)):
            # File the response under its state: update the existing grant
            # or create a new one.
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(
        self, cis, authn_method, request_args=None, http_args=None, **kwargs
    ):

        if http_args is None:
            http_args = {}
        if request_args is None:
            request_args = {}

        if authn_method:
            return self.client_authn_method[authn_method](self).construct(
                cis, request_args, http_args, **kwargs
            )
        else:
            return http_args

    def parse_request_response(self, reqresp, response, body_type, state="", **kwargs):
        """
        Turn an HTTP response into a protocol message where possible.

        :param reqresp: The HTTP response
        :param response: Expected response message class, may be None
        :param body_type: Expected format of the body
        :param state: The state, passed on to ``parse_response``
        :param kwargs: Passed on to ``parse_response``
        :return: The parsed message, the body text when ``body_type`` is
            "txt", or ``reqresp`` itself for redirects and for bodies that
            could not be interpreted
        :raises ParseError: On status 500
        :raises HttpError: On any status that is not successful, a
            redirect, 400, 401 or 500
        """
        status = reqresp.status_code

        if status in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif status in [302, 303]:  # redirect
            return reqresp
        elif status == 500:
            logger.error("(%d) %s" % (status, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif status not in [400, 401]:
            logger.error("(%d) %s" % (status, sanitize(reqresp.text)))
            raise HttpError(
                "HTTP ERROR: %s [%s] on %s"
                % (reqresp.text, reqresp.status_code, reqresp.url)
            )
        # 400 and 401 fall through: an error response is expected in the
        # body. The former ``issubclass(response, ErrorResponse)`` test had
        # an empty body, so its only effect was a TypeError when
        # ``response`` was None; it has been removed.

        if response:
            if body_type == "txt":
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(
                response, reqresp.text, body_type, state, **kwargs
            )

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == "txt":
                body_type = "urlencoded"
            try:
                # NOTE(review): every other access in this method reads
                # ``reqresp.text``; confirm that ``message`` exists on the
                # response type, otherwise the AttributeError is swallowed
                # below and the raw response is always returned.
                err = ErrorResponse().deserialize(reqresp.message, method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                pass

        return reqresp

    def request_and_return(
        self,
        url: str,
        response: Type[Message] = None,
        method="GET",
        body=None,
        body_type: ENCODINGS = "json",
        state: str = "",
        http_args=None,
        **kwargs
    ):
        """
        Perform a request and return the response.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param http_args: Arguments for the HTTP client
        :return: A cls or ErrorResponse instance or the HTTP response instance if no response body was expected.
        """
        # FIXME: Cannot annotate return value as Message since it disrupts all other methods
        if http_args is None:
            http_args = {}

        try:
            resp = self.http_request(url, method, data=body, **http_args)
        except Exception:
            raise

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(resp, response, body_type, state, **kwargs)

    def do_authorization_request(
        self,
        state="",
        body_type="",
        method="GET",
        request_args=None,
        extra_args=None,
        http_args=None,
        **kwargs
    ) -> AuthorizationResponse:
        """
        Build and send an authorization request and parse the answer.

        :param state: State value; when truthy it is written into
            ``request_args``
        :param body_type: Expected format of the response body
        :param method: HTTP method
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Arguments for the HTTP client; updated in place
            with the arguments computed by ``request_info``
        :param kwargs: Passed on to ``request_info``; "algs" is also passed
            to ``request_and_return``
        :return: Whatever ``request_and_return`` returns
        """
        request = self.message_factory.get_request_type("authorization_endpoint")
        response_cls = self.message_factory.get_response_type("authorization_endpoint")

        if state:
            try:
                request_args["state"] = state
            except TypeError:
                # request_args is None
                request_args = {"state": state}

        kwargs["authn_endpoint"] = "authorization"
        url, body, ht_args, csi = self.request_info(
            request, method, request_args, extra_args, **kwargs
        )

        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            # No request arguments at all (TypeError) or none that carry a
            # "state" (KeyError): there is no key to remember the request
            # under. The KeyError case used to escape and abort the request.
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            algs=algs,
        )

        if isinstance(resp, Message):
            # FIXME: The Message classes do not have classical attrs
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:  # type: ignore
                resp.state = csi.state  # type: ignore

        return resp

    def do_access_token_request(
        self,
        scope: str = "",
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:
        """
        Build and send an access token request and parse the answer.

        :param scope: Scope, passed on to ``request_info``
        :param state: State, passed on to ``request_info`` and
            ``request_and_return``
        :param body_type: Expected format of the response body
        :param method: HTTP method
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Arguments for the HTTP client; updated in place
            with the computed arguments, after which "password" is removed
        :param authn_method: Client authentication method
        :param kwargs: Passed on to ``request_info`` and
            ``request_and_return``
        :return: Whatever ``request_and_return`` returns
        """
        factory = self.message_factory
        request = factory.get_request_type("token_endpoint")
        response_cls = factory.get_response_type("token_endpoint")

        kwargs["authn_endpoint"] = "token"
        # method is default POST
        url, body, ht_args, _ = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            **kwargs
        )

        if http_args is not None:
            http_args.update(ht_args)
            http_args.pop("password", None)
        else:
            http_args = ht_args

        events = self.events
        if events is not None:
            for label, payload in (
                ("request_url", url),
                ("request_http_args", http_args),
                ("Request", body),
            ):
                events.store(label, payload)

        logger.debug("<do_access_token> URL: %s, Body: %s" % (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            **kwargs
        )

    def do_access_token_refresh(
        self,
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:

        request = self.message_factory.get_request_type("refresh_endpoint")
        response_cls = self.message_factory.get_response_type("refresh_endpoint")

        token = self.get_token(also_expired=True, state=state, **kwargs)
        kwargs["authn_endpoint"] = "refresh"
        url, body, ht_args, csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            token=token,
            authn_method=authn_method,
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        response = self.request_and_return(
            url, response_cls, method, body, body_type, state=state, http_args=http_args
        )
        if token.replaced:
            grant = self.get_grant(state)
            grant.delete_token(token)
        return response

    def do_any(
        self,
        request: Type[Message],
        endpoint="",
        scope="",
        state="",
        body_type="json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        response: Type[Message] = None,
        authn_method="",
    ) -> Message:
        """
        Build an arbitrary request, send it and return the parsed response.

        :param request: Message class of the request to construct
        :param endpoint: Endpoint handed to ``request_info``
        :param scope: Scope handed to ``request_info``
        :param state: State handed to ``request_info`` and ``request_and_return``
        :param body_type: Expected encoding of the response body
        :param method: HTTP method used for the request
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Caller supplied HTTP arguments; updated in place with
            the arguments produced by ``request_info``
        :param response: Message class used for the response
        :param authn_method: Client authentication method to use
        :return: Whatever ``request_and_return`` produced for the request
        """
        info = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            endpoint=endpoint,
        )
        # The fourth element of the tuple is not needed here.
        url, body, ht_args, _ = info

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        return self.request_and_return(
            url, response, method, body, body_type, state=state, http_args=http_args
        )

    def fetch_protected_resource(
        self, uri, method="GET", headers=None, state="", **kwargs
    ):
        """
        Request a protected resource, authenticating with an access token.

        :param uri: Location of the resource
        :param method: HTTP method used for the request
        :param headers: Optional headers; updated in place with the
            authentication headers
        :param state: Used to look up a stored token when none is supplied
        :param kwargs: May contain ``token`` (an access token value to use as
            is) and ``authn_method``; also forwarded to ``get_token`` and
            ``init_authentication_method``
        :return: The result of ``http_request``
        """
        supplied_token = kwargs.get("token")
        if supplied_token:
            request_args = {"access_token": supplied_token}
        else:
            try:
                stored = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old: refresh it and look it up again.
                self.do_access_token_refresh(state=state)
                stored = self.get_token(state=state, **kwargs)
            request_args = {"access_token": stored.access_token}

        if headers is None:
            headers = {}

        if "authn_method" not in kwargs:
            # Nothing requested explicitly, fall back on the bearer header.
            authenticator = self.client_authn_method["bearer_header"](self)
            http_args = authenticator.construct(request_args=request_args)
        else:
            http_args = self.init_authentication_method(
                request_args=request_args, **kwargs
            )

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support.

        Reads the optional ``length`` and ``method`` entries of
        ``self.config["code_challenge"]``; 64 and ``"S256"`` are used when the
        respective key is missing.

        :return: A tuple of the request arguments (``code_challenge`` and
            ``code_challenge_method``) and the code verifier
        :raises Unsupported: When the configured method is not in ``CC_METHOD``
        """
        try:
            verifier_length = self.config["code_challenge"]["length"]
        except KeyError:
            verifier_length = 64  # Use default

        code_verifier = unreserved(verifier_length)
        verifier_bytes = code_verifier.encode("ascii")

        try:
            transformation = self.config["code_challenge"]["method"]
        except KeyError:
            transformation = "S256"

        try:
            hashed = CC_METHOD[transformation](verifier_bytes).digest()
            code_challenge = b64e(hashed).decode("ascii")
        except KeyError:
            raise Unsupported("PKCE Transformation method:{}".format(transformation))

        # TODO store code_verifier

        challenge_args = {
            "code_challenge": code_challenge,
            "code_challenge_method": transformation,
        }
        return challenge_args, code_verifier

    def handle_provider_config(
        self,
        pcr: ASConfigurationResponse,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
    ) -> None:
        """
        Deal with Provider Config Response.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them as attributes in self.
        :raises PyoidcError: When the advertised issuer differs from the
            expected one and ``self.allow`` does not permit ``issuer_mismatch``
        """
        if "issuer" not in pcr:
            _pcr_issuer = issuer
        else:
            _pcr_issuer = pcr["issuer"]

            # Compare the two issuers with the same trailing-slash form.
            advertised_has_slash = _pcr_issuer.endswith("/")
            expected_has_slash = issuer.endswith("/")
            if advertised_has_slash and not expected_has_slash:
                _issuer = issuer + "/"
            elif expected_has_slash and not advertised_has_slash:
                _issuer = issuer[:-1]
            else:
                _issuer = issuer

            mismatch_allowed = self.allow.get("issuer_mismatch", False)
            if not mismatch_allowed and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'"
                    % (_issuer, _pcr_issuer)
                )

            self.provider_info = pcr

        self.issuer = _pcr_issuer

        if endpoints:
            for name, value in pcr.items():
                if not name.endswith("_endpoint"):
                    continue
                setattr(self, name, value)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(
        self,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
        serv_pattern: str = OIDCONF_PATTERN,
    ) -> ASConfigurationResponse:
        """
        Fetch the provider configuration of ``issuer`` and act on it.

        :param issuer: The issuer whose configuration should be fetched
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param serv_pattern: ``%`` pattern turning the issuer (without a
            trailing slash) into the configuration URL
        :return: The parsed provider configuration response
        :raises CommunicationError: When the HTTP status is not 200
        :raises ParseError: When the response body can not be parsed
        """
        response_cls = self.message_factory.get_response_type("configuration_endpoint")
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url, allow_redirects=True)
        if r.status_code != 200:
            raise CommunicationError("Trying '%s', status %s" % (url, r.status_code))

        try:
            pcr = response_cls().from_json(r.text)
        except Exception as e:
            # FIXME: This should catch specific exception from `from_json()`
            _err_txt = "Faulty provider config response: {}".format(e)
            logger.error(sanitize(_err_txt))
            raise ParseError(_err_txt)

        self.store_response(pcr, r.text)
        # The issuer is handed over as given, trailing slash included.
        self.handle_provider_config(pcr, issuer, keys, endpoints)
        return pcr
예제 #8
0
파일: client.py 프로젝트: webleydev/pyoidc
class Client(oauth2.Client):
    def __init__(self,
                 client_id=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True,
                 config=None):
        """
        Set up the OAuth2 client and register the extension endpoints.

        :param client_id: Passed on to ``oauth2.Client``
        :param client_authn_method: Passed on to ``oauth2.Client``
        :param keyjar: Passed on to ``oauth2.Client``
        :param verify_ssl: Passed on to ``oauth2.Client``
        :param config: Passed on to ``oauth2.Client``
        """
        oauth2.Client.__init__(
            self,
            client_id=client_id,
            client_authn_method=client_authn_method,
            keyjar=keyjar,
            verify_ssl=verify_ssl,
            config=config)
        self.allow = {}

        # Map the extension request types to the endpoint attribute they use.
        extension_endpoints = {
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint",
            'TokenIntrospectionRequest': 'introspection_endpoint',
            'TokenRevocationRequest': 'revocation_endpoint',
        }
        self.request2endpoint.update(extension_endpoints)
        self.registration_response = None

    def construct_RegistrationRequest(self,
                                      request=RegistrationRequest,
                                      request_args=None,
                                      extra_args=None,
                                      **kwargs):
        """
        Construct a client registration request.

        :param request: Message class of the request
        :param request_args: Arguments for the request; an empty dictionary is
            used when None
        :param extra_args: Additional arguments for the request
        :param kwargs: Accepted but not used
        :return: The result of ``construct_request``
        """
        args = {} if request_args is None else request_args
        return self.construct_request(request, args, extra_args)

    def construct_ClientUpdateRequest(self,
                                      request=ClientUpdateRequest,
                                      request_args=None,
                                      extra_args=None,
                                      **kwargs):
        """
        Construct a client update request.

        :param request: Message class of the request
        :param request_args: Arguments for the request; an empty dictionary is
            used when None
        :param extra_args: Additional arguments for the request
        :param kwargs: Accepted but not used
        :return: The result of ``construct_request``
        """
        args = {} if request_args is None else request_args
        return self.construct_request(request, args, extra_args)

    def _token_interaction_setup(self, request_args=None, **kwargs):
        if request_args is None or 'token' not in request_args:
            token = self.get_token(**kwargs)
            try:
                _token_type_hint = kwargs['token_type_hint']
            except KeyError:
                _token_type_hint = 'access_token'

            request_args = {
                'token_type_hint': _token_type_hint,
                'token': getattr(token, _token_type_hint)
            }

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return request_args

    def construct_TokenIntrospectionRequest(self,
                                            request=TokenIntrospectionRequest,
                                            request_args=None,
                                            extra_args=None,
                                            **kwargs):
        """
        Construct a token introspection request.

        :param request: Message class of the request
        :param request_args: Arguments for the request, completed by
            ``_token_interaction_setup``
        :param extra_args: Additional arguments for the request
        :param kwargs: Forwarded to ``_token_interaction_setup``
        :return: The result of ``construct_request``
        """
        prepared = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, prepared, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None,
                                         extra_args=None,
                                         **kwargs):
        """
        Construct a token revocation request.

        :param request: Message class of the request
        :param request_args: Arguments for the request, completed by
            ``_token_interaction_setup``
        :param extra_args: Additional arguments for the request
        :param kwargs: Forwarded to ``_token_interaction_setup``
        :return: The result of ``construct_request``
        """
        prepared = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, prepared, extra_args)

    def do_op(self,
              request,
              body_type='',
              method='GET',
              request_args=None,
              extra_args=None,
              http_args=None,
              response_cls=None,
              **kwargs):
        """
        Construct a request, send it and return the parsed response.

        :param request: Message class of the request to construct
        :param body_type: Expected encoding of the response body
        :param method: HTTP method used for the request
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Caller supplied HTTP arguments; updated in place with
            the arguments produced by ``request_info``
        :param response_cls: Message class used for the response
        :param kwargs: Forwarded to ``request_info``
        :return: Whatever ``request_and_return`` produced for the request
        """
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info produced; updating the dictionary
            # with itself would silently drop those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       http_args=http_args)

    def do_client_registration(self,
                               request=RegistrationRequest,
                               body_type="",
                               method="GET",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Perform a client registration through ``do_op``.

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def do_client_read_request(self,
                               request=ClientUpdateRequest,
                               body_type="",
                               method="GET",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Read the client information through ``do_op`` (``GET`` by default).

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def do_client_update_request(self,
                                 request=ClientUpdateRequest,
                                 body_type="",
                                 method="PUT",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Update the client information through ``do_op`` (``PUT`` by default).

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def do_client_delete_request(self,
                                 request=ClientUpdateRequest,
                                 body_type="",
                                 method="DELETE",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Delete the client through ``do_op`` (``DELETE`` by default).

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def do_token_introspection(self,
                               request=TokenIntrospectionRequest,
                               body_type="json",
                               method="POST",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=TokenIntrospectionResponse,
                               **kwargs):
        """
        Introspect a token through ``do_op`` (``POST`` by default).

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def do_token_revocation(self,
                            request=TokenRevocationRequest,
                            body_type="",
                            method="POST",
                            request_args=None,
                            extra_args=None,
                            http_args=None,
                            response_cls=None,
                            **kwargs):
        """
        Revoke a token through ``do_op`` (``POST`` by default).

        All parameters are handed to ``do_op`` unchanged; see there for their
        meaning. No response class is set by default.

        :return: The result of ``do_op``
        """
        op_args = dict(request=request,
                       body_type=body_type,
                       method=method,
                       request_args=request_args,
                       extra_args=extra_args,
                       http_args=http_args,
                       response_cls=response_cls)
        op_args.update(kwargs)
        return self.do_op(**op_args)

    def add_code_challenge(self):
        """
        PKCE (RFC 7636) support: create a code verifier and its challenge.

        Reads the optional ``length`` and ``method`` entries of
        ``self.config['code_challenge']``; 64 and ``'S256'`` are used when the
        respective key is missing.

        :return: A tuple of the request arguments (``code_challenge`` and
            ``code_challenge_method``) and the code verifier
        :raises Unsupported: When the configured method is not in ``CC_METHOD``
        """
        try:
            cv_len = self.config['code_challenge']['length']
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode('ascii')

        try:
            _method = self.config['code_challenge']['method']
        except KeyError:
            _method = 'S256'

        try:
            # RFC 7636 section 4.2 derives the challenge from the raw hash
            # octets; encoding the hexadecimal text of the hash instead gives
            # a value a conforming server can not verify.
            # NOTE(review): assumes b64e is a base64url encoder - confirm.
            _h = CC_METHOD[_method](_cv).digest()
            code_challenge = b64e(_h).decode('ascii')
        except KeyError:
            raise Unsupported('PKCE Transformation method:{}'.format(_method))

        # TODO store code_verifier

        return {
            "code_challenge": code_challenge,
            "code_challenge_method": _method
        }, code_verifier

    def do_authorization_request(self,
                                 request=AuthorizationRequest,
                                 state="",
                                 body_type="",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Send an authorization request, adding a PKCE challenge if configured.

        When ``self.config`` holds a non-empty ``code_challenge`` entry, the
        arguments from ``add_code_challenge`` are merged into
        ``request_args`` before delegating to ``oauth2.Client``.

        :param request: Message class of the request
        :param state: Passed on to ``oauth2.Client.do_authorization_request``
        :param body_type: Passed on to ``oauth2.Client.do_authorization_request``
        :param method: HTTP method used for the request
        :param request_args: Arguments for the request; updated in place with
            the PKCE arguments when a challenge is configured
        :param extra_args: Additional arguments for the request
        :param http_args: Passed on to ``oauth2.Client.do_authorization_request``
        :param response_cls: Message class used for the response
        :param kwargs: Passed on to ``oauth2.Client.do_authorization_request``
        :return: The result of ``oauth2.Client.do_authorization_request``
        """
        if 'code_challenge' in self.config and self.config['code_challenge']:
            # TODO store code_verifier; it is needed for the token request.
            _args, code_verifier = self.add_code_challenge()
            if request_args is None:
                # Nothing to merge into yet; None has no update().
                request_args = {}
            request_args.update(_args)

        # Hand the response back to the caller instead of dropping it.
        return oauth2.Client.do_authorization_request(
            self,
            request=request,
            state=state,
            body_type=body_type,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            http_args=http_args,
            response_cls=response_cls,
            **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: When the advertised issuer differs from the
            expected one and ``self.allow`` does not hold a true
            ``issuer_mismatch`` value
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Compare the two issuers with the same trailing-slash form.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # Explicit comparison: an assert would be stripped under -O and
            # silently disable the check. Only a true "issuer_mismatch" value
            # permits a mismatch; the mere presence of the key does not.
            if (not self.allow.get("issuer_mismatch", False)
                    and _issuer != _pcr_issuer):
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" %
                    (_issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self,
                        issuer,
                        keys=True,
                        endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the provider configuration of ``issuer`` and act on it.

        ``302`` responses are followed by hand through their ``location``
        header for as long as they keep coming; the loop has no upper bound.

        :param issuer: The issuer whose configuration should be fetched
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param response_cls: Message class used to parse the response
        :param serv_pattern: ``%`` pattern turning the issuer (without a
            trailing slash) into the configuration URL
        :return: The parsed provider configuration response
        :raises PyoidcError: When no configuration could be obtained
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)

        # Only the headers of the first response are recorded.
        if self.events:
            self.events.store('HTTP response header', r.headers)

        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        pcr = None
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """
        Remember a registration response and the client data it carries.

        :param reginfo: Registration response; must provide ``client_secret``,
            ``client_id`` and ``redirect_uris``
        """
        self.registration_response = reginfo
        for attribute in ("client_secret", "client_id", "redirect_uris"):
            setattr(self, attribute, reginfo[attribute])

    def handle_registration_info(self, response):
        """
        Parse the HTTP response of a registration request.

        A response with a status code in ``SUCCESSFUL`` is parsed as client
        information, stored and used to update the client; any other response
        is parsed as an error response.

        :param response: HTTP response with ``status_code`` and ``text``
        :return: The parsed client information or error response
        :raises PyoidcError: When an unsuccessful response can not be verified
            or stored as an error response
        """
        if response.status_code not in SUCCESSFUL:
            error = ErrorResponse().deserialize(response.text, "json")
            try:
                error.verify()
                self.store_response(error, response.text)
            except Exception:
                raise PyoidcError('Registration failed: {}'.format(
                    response.text))
            return error

        info = ClientInfoResponse().deserialize(response.text, "json")
        self.store_response(info, response.text)
        self.store_registration_info(info)
        return info

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The result of ``handle_registration_info``
        """
        registration = self.construct_RegistrationRequest(request_args=kwargs)

        # The request is POSTed as a JSON document.
        http_response = self.http_request(
            url,
            "POST",
            data=registration.to_json(),
            headers={"content-type": "application/json"})

        return self.handle_registration_info(http_response)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :param query: The urlencoded response
        :return: The parsed authorization response
        :raises AuthzError: When the response is an error response
        """
        aresp = self.parse_response(
            AuthorizationResponse,
            info=query,
            sformat="urlencoded",
            keyjar=self.keyjar)

        if aresp.type() != "ErrorResponse":
            logger.info("Aresp: %s" % sanitize(aresp))
            return aresp

        logger.info("ErrorResponse: %s" % sanitize(aresp))
        raise AuthzError(aresp.error)
예제 #9
0
class Client(oauth2.Client):
    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, ca_certs=None,
                 client_prefs=None, client_authn_method=None, keyjar=None,
                 verify_ssl=True):
        """
        Set up the OAuth2 client and the OpenID Connect specific state.

        :param client_id: Passed on to ``oauth2.Client``
        :param ca_certs: Passed on to ``oauth2.Client``
        :param client_prefs: Client preferences; an empty dictionary when
            nothing (or something false) is given
        :param client_authn_method: Passed on to ``oauth2.Client``
        :param keyjar: Passed on to ``oauth2.Client``
        :param verify_ssl: Passed on to ``oauth2.Client``
        """
        oauth2.Client.__init__(self, client_id, ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)

        self.file_store = "./file/"
        self.file_uri = "http://localhost/"

        # OpenID connect specific endpoints, all unset to begin with
        for endpoint_name in ENDPOINTS:
            setattr(self, endpoint_name, "")

        self.id_token = None
        self.log = None

        # Module level tables, used as they are (not copied)
        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR

        self.grant_class = Grant
        self.token_class = Token
        self.provider_info = None
        self.registration_response = None
        self.client_prefs = client_prefs or {}

        default_signing_alg = DEF_SIGN_ALG["openid_request_object"]
        self.behaviour = {"request_object_signing_alg": default_signing_alg}

        # The WebFinger instance uses this client for its HTTP requests
        self.wf = WebFinger(OIC_ISSUER)
        self.wf.httpd = self
        self.allow = {}
        self.post_logout_redirect_uris = []
        self.registration_expires = 0
        self.registration_access_token = None

        # Default key by kid for different key types
        # For instance {"RSA":"abc"}
        self.kid = {"sig": {}, "enc": {}}

    def _get_id_token(self, **kwargs):
        try:
            return kwargs["id_token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

        if grant:
            try:
                _scope = kwargs["scope"]
            except KeyError:
                _scope = None

            for token in grant.tokens:
                if token.scope and _scope:
                    flag = True
                    for item in _scope:
                        try:
                            assert item in token.scope
                        except AssertionError:
                            flag = False
                            break
                    if not flag:
                        break
                if token.id_token:
                    return token.id_token

        return None

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       request_param=None, **kwargs):
        """
        Construct an authorization request, optionally with a request object.

        A nonce is added when ``request_args`` has none and its response type
        contains ``token`` or ``id_token``; when no ``request_args`` is given
        at all, a nonce is added unless ``kwargs`` carries a response type
        without ``token``.

        :param request: Message class of the request
        :param request_args: Arguments for the request; a nonce may be added
            in place. When given it must contain ``response_type`` unless it
            already has a ``nonce``.
        :param extra_args: Additional arguments for the request
        :param request_param: ``"request"`` to pass the request object by
            value; any other true value passes it by reference through a file
        :param kwargs: ``request_method="file"`` selects passing by reference;
            ``local_dir`` and ``base_path`` are then required. ``algorithm``
            and ``keys`` are filled in when missing. Everything is forwarded
            to the OAuth2 client and to ``make_openid_request``.
        :return: The authorization request
        """
        if request_args is not None:
            if "nonce" not in request_args:
                _rt = request_args["response_type"]
                if "token" in _rt or "id_token" in _rt:
                    request_args["nonce"] = rndstr(12)
        elif "response_type" in kwargs:
            if "token" in kwargs["response_type"]:
                request_args = {"nonce": rndstr(12)}
        else:  # Never wrong to specify a nonce
            request_args = {"nonce": rndstr(12)}

        if kwargs.get("request_method") == "file":
            request_param = "request_uri"
            del kwargs["request_method"]

        areq = oauth2.Client.construct_AuthorizationRequest(self, request,
                                                            request_args,
                                                            extra_args,
                                                            **kwargs)

        if request_param:
            alg = self.behaviour["request_object_signing_alg"]
            if "algorithm" not in kwargs:
                kwargs["algorithm"] = alg

            if "keys" not in kwargs and alg:
                _kty = alg2keytype(alg)
                try:
                    # Prefer the configured default key for this key type
                    kwargs["keys"] = self.keyjar.get_signing_key(
                        _kty, kid=self.kid["sig"][_kty])
                except KeyError:
                    kwargs["keys"] = self.keyjar.get_signing_key(_kty)

            _req = make_openid_request(areq, **kwargs)

            if request_param == "request":
                areq["request"] = _req
            else:
                _filedir = kwargs["local_dir"]
                _webpath = kwargs["base_path"]
                # Pick a random file name that is not taken yet
                _name = rndstr(10)
                filename = os.path.join(_filedir, _name)
                while os.path.exists(filename):
                    _name = rndstr(10)
                    filename = os.path.join(_filedir, _name)
                # The context manager closes the file even if writing fails
                with open(filename, mode="w") as fid:
                    fid.write(_req)
                _webname = "%s%s" % (_webpath, _name)
                areq["request_uri"] = _webname

        return areq

    #noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self, request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """
        Construct an access token request using the OAuth2 implementation.

        All arguments are handed to
        ``oauth2.Client.construct_AccessTokenRequest`` unchanged.

        :return: The constructed request
        """
        parent_construct = oauth2.Client.construct_AccessTokenRequest
        return parent_construct(self, request, request_args, extra_args,
                                **kwargs)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """
        Construct a refresh token request using the OAuth2 implementation.

        All arguments are handed to
        ``oauth2.Client.construct_RefreshAccessTokenRequest`` unchanged.

        :return: The constructed request
        """
        parent_construct = oauth2.Client.construct_RefreshAccessTokenRequest
        return parent_construct(self, request, request_args, extra_args,
                                **kwargs)

    def construct_UserInfoRequest(self, request=UserInfoRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """
        Construct a user info request, looking up an access token if needed.

        :param request: Message class of the request
        :param request_args: Arguments for the request; an ``access_token``
            is added in place when missing
        :param extra_args: Additional arguments for the request
        :param kwargs: Forwarded to ``get_token``; ``scope`` defaults to
            ``"openid"``
        :return: The result of ``construct_request``
        :raises PyoidcError: When ``get_token`` returns None
        """
        if request_args is None:
            request_args = {}

        if "access_token" not in request_args:
            kwargs.setdefault("scope", "openid")
            token = self.get_token(**kwargs)
            if token is None:
                raise PyoidcError("No valid token available")

            request_args["access_token"] = token.access_token

        return self.construct_request(request, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Construct a registration request.

        :param request: Message class of the request
        :param request_args: Arguments for the request
        :param extra_args: Additional arguments for the request
        :param kwargs: Accepted but not used
        :return: The result of ``construct_request``
        """
        constructed = self.construct_request(request, request_args, extra_args)
        return constructed

    #noinspection PyUnusedLocal
    def construct_RefreshSessionRequest(self,
                                        request=RefreshSessionRequest,
                                        request_args=None, extra_args=None,
                                        **kwargs):
        """
        Construct a refresh session request.

        :param request: Message class of the request
        :param request_args: Arguments for the request
        :param extra_args: Additional arguments for the request
        :param kwargs: Accepted but not used
        :return: The result of ``construct_request``
        """
        constructed = self.construct_request(request, request_args, extra_args)
        return constructed

    def _id_token_based(self, request, request_args=None, extra_args=None,
                        **kwargs):

        if request_args is None:
            request_args = {}

        try:
            _prop = kwargs["prop"]
        except KeyError:
            _prop = "id_token"

        if _prop in request_args:
            pass
        else:
            id_token = self._get_id_token(**kwargs)
            if id_token is None:
                raise PyoidcError("No valid id token available")

            request_args[_prop] = id_token

        return self.construct_request(request, request_args, extra_args)

    def construct_CheckSessionRequest(self, request=CheckSessionRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Construct a check session request carrying an ID token.

        All arguments are handed to ``_id_token_based`` unchanged.

        :return: The constructed request
        """
        constructed = self._id_token_based(request, request_args, extra_args,
                                           **kwargs)
        return constructed

    def construct_CheckIDRequest(self, request=CheckIDRequest,
                                 request_args=None,
                                 extra_args=None, **kwargs):
        """
        Construct a check ID request.

        The ID token is placed in the ``access_token`` request argument; all
        other arguments are handed to ``_id_token_based`` unchanged.

        :return: The constructed request
        """
        # access_token is where the id_token will be placed
        constructed = self._id_token_based(request, request_args, extra_args,
                                           prop="access_token", **kwargs)
        return constructed

    def construct_EndSessionRequest(self, request=EndSessionRequest,
                                    request_args=None, extra_args=None,
                                    **kwargs):
        """
        Construct an end session request carrying an ID token.

        The ``state`` is kept in sync between ``kwargs`` and ``request_args``:
        a value in ``kwargs`` wins, otherwise one in ``request_args`` is
        copied to ``kwargs``.

        :param request: Message class of the request
        :param request_args: Arguments for the request; may be updated in place
        :param extra_args: Additional arguments for the request
        :param kwargs: Forwarded to ``_id_token_based``
        :return: The constructed request
        """
        request_args = {} if request_args is None else request_args

        if "state" in kwargs:
            request_args["state"] = kwargs["state"]
        elif "state" in request_args:
            kwargs["state"] = request_args["state"]

        constructed = self._id_token_based(request, request_args, extra_args,
                                           **kwargs)
        return constructed

    # ------------------------------------------------------------------------
    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """
        Return the request information for a ``GET`` authorization request.

        :param request_args: Arguments for the request
        :param extra_args: Additional arguments for the request
        :param kwargs: Forwarded to ``request_info``
        :return: The result of ``request_info``
        """
        info = self.request_info(AuthorizationRequest, "GET",
                                 request_args, extra_args, **kwargs)
        return info

    # ------------------------------------------------------------------------
    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse):
        """
        Send an authorization request using the OAuth2 implementation.

        All arguments are handed, positionally and unchanged, to
        ``oauth2.Client.do_authorization_request``.

        :return: The result of the OAuth2 implementation
        """
        parent_call = oauth2.Client.do_authorization_request
        return parent_call(self, request, state,
                           body_type, method,
                           request_args,
                           extra_args, http_args,
                           response_cls)

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Send an access token request using the OAuth2 implementation.

        All arguments are handed, positionally and unchanged, to
        ``oauth2.Client.do_access_token_request``.

        :return: The result of the OAuth2 implementation
        """
        parent_call = oauth2.Client.do_access_token_request
        return parent_call(self, request, scope,
                           state, body_type, method,
                           request_args, extra_args,
                           http_args, response_cls,
                           authn_method, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                **kwargs):
        """
        Refresh an access token.

        Delegates to ``oauth2.Client.do_access_token_refresh``; the only
        thing this override contributes is its own default values.

        :param request: The request message class
        :param state: The state value
        :param body_type: Expected body type of the response
        :param method: The HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param http_args: Arguments for the HTTP layer
        :param response_cls: The response message class
        :param kwargs: Further keyword arguments, forwarded unchanged
        :return: Whatever the parent implementation returns
        """
        forwarded = (request, state, body_type, method, request_args,
                     extra_args, http_args, response_cls)
        return oauth2.Client.do_access_token_refresh(self, *forwarded,
                                                     **kwargs)

    def do_registration_request(self, request=RegistrationRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=None):
        """
        Build and send a registration request.

        :param request: The request message class
        :param scope: Scope passed on to ``self.request_info``
        :param state: State passed on to ``self.request_info`` and
            ``self.request_and_return``
        :param body_type: Expected body type of the response
        :param method: The HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param http_args: Caller supplied HTTP arguments. When given, the
            HTTP arguments computed by ``self.request_info`` are merged
            into this dictionary (it is modified in place).
        :param response_cls: The response message class, defaults to
            ``RegistrationResponse``
        :return: Whatever ``self.request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info() computed; updating the dict with
            # itself would silently drop those arguments.
            http_args.update(ht_args)

        if response_cls is None:
            response_cls = RegistrationResponse

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_check_session_request(self, request=CheckSessionRequest,
                                 scope="",
                                 state="", body_type="json", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=IdToken):
        """
        Build and send a check session request.

        :param request: The request message class
        :param scope: Scope passed on to ``self.request_info``
        :param state: State passed on to ``self.request_info`` and
            ``self.request_and_return``
        :param body_type: Expected body type of the response
        :param method: The HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param http_args: Caller supplied HTTP arguments. When given, the
            HTTP arguments computed by ``self.request_info`` are merged
            into this dictionary (it is modified in place).
        :param response_cls: The response message class
        :return: Whatever ``self.request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info() computed; updating the dict with
            # itself would silently drop those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_check_id_request(self, request=CheckIDRequest, scope="",
                            state="", body_type="json", method="GET",
                            request_args=None, extra_args=None,
                            http_args=None,
                            response_cls=IdToken):
        """
        Build and send a check ID request.

        :param request: The request message class
        :param scope: Scope passed on to ``self.request_info``
        :param state: State passed on to ``self.request_info`` and
            ``self.request_and_return``
        :param body_type: Expected body type of the response
        :param method: The HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param http_args: Caller supplied HTTP arguments. When given, the
            HTTP arguments computed by ``self.request_info`` are merged
            into this dictionary (it is modified in place).
        :param response_cls: The response message class
        :return: Whatever ``self.request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info() computed; updating the dict with
            # itself would silently drop those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_end_session_request(self, request=EndSessionRequest, scope="",
                               state="", body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None, response_cls=None):
        """
        Build and send an end session request.

        :param request: The request message class
        :param scope: Scope passed on to ``self.request_info``
        :param state: State passed on to ``self.request_info`` and
            ``self.request_and_return``
        :param body_type: Expected body type of the response
        :param method: The HTTP method
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param http_args: Caller supplied HTTP arguments. When given, the
            HTTP arguments computed by ``self.request_info`` are merged
            into this dictionary (it is modified in place).
        :param response_cls: The response message class, passed on as is
        :return: Whatever ``self.request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info() computed; updating the dict with
            # itself would silently drop those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def user_info_request(self, method="GET", state="", scope="", **kwargs):
        """
        Assemble everything needed to send a UserInfo request.

        Where the access token comes from:

        - a "token" keyword argument, if present (a false value leaves the
          request without a token),
        - otherwise a true "access_token" keyword argument,
        - otherwise the grant stored under `state`; if its token is not
          valid an access token refresh is done first.

        :param method: The HTTP method
        :param state: Key into ``self.grant``, used when no token is passed
        :param scope: Scope handed to the grant's ``get_token``
        :param kwargs: May carry "token", "access_token", "behavior",
            "token_type" and "headers"; also forwarded to
            ``self._endpoint`` and ``self.get_or_post``
        :return: A tuple (path, body, method, h_args), h_args being the
            keyword arguments returned by ``self.get_or_post`` whose names
            are listed in HTTP_ARGS
        :raises MissingParameter: If a behavior is in effect and no token
            type can be determined
        """
        uir = UserInfoRequest()
        logger.debug("[user_info_request]: kwargs:%s" % (kwargs,))

        if "token" in kwargs:
            given_token = kwargs["token"]
            if given_token:
                uir["access_token"] = given_token
                token = Token()
                token.token_type = "Bearer"
                token.access_token = given_token
                kwargs["behavior"] = "use_authorization_header"
            else:
                # What to do ? Need a callback
                token = None
        elif kwargs.get("access_token"):
            uir["access_token"] = kwargs.pop("access_token")
            token = None
        else:
            token = self.grant[state].get_token(scope)

            if not token.is_valid():
                if self.log:
                    self.log.info("do access token refresh")
                self.do_access_token_refresh(token=token)
                token = self.grant[state].get_token(scope)
                uir["access_token"] = token.access_token
            else:
                uir["access_token"] = token.access_token
                if token.token_type == "Bearer" and method == "GET":
                    kwargs["behavior"] = "use_authorization_header"

        uri = self._endpoint("userinfo_endpoint", **kwargs)

        # The access token can travel in three ways:
        # - POST with the token in the authorization header
        # - POST with the token in the message body
        # - GET with the token in the authorization header
        if "behavior" in kwargs:
            behavior = kwargs["behavior"]
            access_token = uir["access_token"]

            if "token_type" in kwargs:
                token_type = kwargs["token_type"]
            else:
                try:
                    token_type = token.token_type
                except AttributeError:
                    raise MissingParameter("Unspecified token type")

            # use_authorization_header, token_in_message_body
            if "use_authorization_header" in behavior and \
                    token_type == "Bearer":
                auth_header = {"Authorization": "Bearer %s" % access_token}
                if "headers" in kwargs:
                    kwargs["headers"].update(auth_header)
                else:
                    kwargs["headers"] = auth_header

            if "token_in_message_body" not in behavior:
                # the token must not be part of the request itself
                del uir["access_token"]

        path, body, kwargs = self.get_or_post(uri, method, uir, **kwargs)

        h_args = dict((k, v) for k, v in kwargs.items() if k in HTTP_ARGS)

        return path, body, method, h_args

    def do_user_info_request(self, method="POST", state="", scope="openid",
                             request="openid", **kwargs):
        """
        Send a UserInfo request and parse the response.

        :param method: The HTTP method
        :param state: State handed to ``self.user_info_request``
        :param scope: Scope handed to ``self.user_info_request``
        :param request: Stored as kwargs["request"] before the request is
            constructed
        :param kwargs: Forwarded to ``self.user_info_request``;
            "user_info_schema" selects the class used to parse the response
            (default ``OpenIDSchema``)
        :return: The schema instance's ``from_json`` result for an
            "application/json" response, its ``from_jwt`` result for an
            "application/jwt" response
        :raises PyoidcError: If the HTTP status is not 200
        :raises AssertionError: If the content type is neither
            application/json nor application/jwt
        """
        kwargs["request"] = request
        path, body, method, h_args = self.user_info_request(method, state,
                                                            scope, **kwargs)

        logger.debug("[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" % (
            path, body, h_args))

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit tests instead of assert statements: asserts are
            # removed under "python -O", which made every response look
            # like JSON.
            content_type = resp.headers["content-type"]
            if "application/json" in content_type:
                sformat = "json"
            elif "application/jwt" in content_type:
                sformat = "jwt"
            else:
                # AssertionError is what this method has always raised here
                raise AssertionError(
                    "Unexpected content-type: %s" % content_type)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]: %s" % (
                resp.status_code, resp.text))

        try:
            _schema = kwargs["user_info_schema"]
        except KeyError:
            _schema = OpenIDSchema

        logger.debug("Reponse text: '%s'" % resp.text)

        if sformat == "json":
            return _schema().from_json(txt=resp.text)

        algo = self.client_prefs["userinfo_signed_response_alg"]
        _kty = alg2keytype(algo)
        # Keys of the OP ?
        try:
            # Prefer the default key id configured for this key type
            keys = self.keyjar.get_signing_key(_kty, self.kid["sig"][_kty])
        except KeyError:
            keys = self.keyjar.get_signing_key(_kty)

        return _schema().from_jwt(resp.text, keys)

    def get_userinfo_claims(self, access_token, endpoint, method="POST",
                            schema_class=OpenIDSchema, **kwargs):
        """
        Fetch claims from a userinfo endpoint with a given access token.

        :param access_token: The access token to place in the request
        :param endpoint: The URL of the userinfo endpoint
        :param method: The HTTP method
        :param schema_class: Class used to parse the JSON response
        :param kwargs: Entries named in HTTP_ARGS are passed to the HTTP
            request. If "authn_method" is present the keyword arguments
            alone are given to ``self.init_authentication_method``,
            otherwise "bearer_header" is used.
        :return: ``schema_class().from_json(...)`` of the response text
        :raises PyoidcError: If the HTTP status is not 200
        :raises AssertionError: If the response is not application/json
        """
        uir = UserInfoRequest(access_token=access_token)

        h_args = dict((k, v) for k, v in kwargs.items() if k in HTTP_ARGS)

        if "authn_method" in kwargs:
            http_args = self.init_authentication_method(**kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.init_authentication_method(uir, "bearer_header",
                                                        **kwargs)

        h_args.update(http_args)
        path, body, kwargs = self.get_or_post(endpoint, method, uir, **kwargs)

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit test instead of an assert statement, which would be
            # removed under "python -O". The exception type is unchanged.
            content_type = resp.headers["content-type"]
            if "application/json" not in content_type:
                raise AssertionError(
                    "Unexpected content-type: %s" % content_type)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError(
                "ERROR: Something went wrong [%s]" % resp.status_code)

        return schema_class().from_json(txt=resp.text)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        If the response names an issuer it is compared with `issuer`,
        ignoring a difference in one trailing slash, and the response is
        stored as ``self.provider_info``. The comparison is skipped when
        "issuer_mismatch" is a key in ``self.allow``.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises IssuerMismatch: If the issuers differ and mismatches are
            not allowed
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]

            # Give `issuer` the same trailing slash form as the issuer the
            # provider announced, so only real differences are detected.
            pcr_has_slash = _pcr_issuer.endswith("/")
            if issuer.endswith("/") == pcr_has_slash:
                _issuer = issuer
            elif pcr_has_slash:
                _issuer = issuer + "/"
            else:
                _issuer = issuer[:-1]

            # Explicit comparison: an assert statement would be removed
            # under "python -O" and the mismatch would go unnoticed.
            if "issuer_mismatch" not in self.allow:
                if _issuer != _pcr_issuer:
                    raise IssuerMismatch("'%s' != '%s'" % (_issuer,
                                                           _pcr_issuer),
                                         pcr)

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the provider configuration of an issuer and process it.

        The URL is `serv_pattern` filled in with the issuer, less one
        trailing slash. Responses with status 302 are followed through
        their "location" header until another status is returned.

        :param issuer: The issuer identifier
        :param keys: Passed on to ``self.handle_provider_config``
        :param endpoints: Passed on to ``self.handle_provider_config``
        :param response_cls: Class used to parse the JSON response
        :param serv_pattern: %-pattern that gives the configuration URL
        :return: The parsed provider configuration response
        :raises PyoidcError: If no configuration could be obtained
        """
        base = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % base

        r = self.http_request(url)
        # Redirects are followed by hand, for as long as they keep coming
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        else:
            pcr = None

        #logger.debug("Provider info: %s" % pcr)
        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def unpack_aggregated_claims(self, userinfo):
        """
        Unpack aggregated claims, that is claims delivered as signed JWTs.

        For every claim source that carries a "JWT", the JWT is verified
        with the keys of that source together with the keys returned by
        ``self.keyjar.get_verify_key()``, and its content is copied into
        `userinfo`. If the source has no keys in the key jar its provider
        configuration is fetched first.

        :param userinfo: The user info, with ``_claim_sources`` and
            ``_claim_names`` attributes
        :return: The same `userinfo` instance, updated in place
        :raises AssertionError: If the claims in a JWT are not exactly the
            ones announced for that source in ``_claim_names``
        """
        if userinfo._claim_sources:
            for csrc, spec in userinfo._claim_sources.items():
                if "JWT" not in spec:
                    continue

                if csrc not in self.keyjar:
                    self.provider_config(csrc, endpoints=False)

                keycol = self.keyjar.get_verify_key(owner=csrc)
                for typ, keyl in self.keyjar.get_verify_key().items():
                    try:
                        keycol[typ].extend(keyl)
                    except KeyError:
                        keycol[typ] = keyl

                info = json.loads(JWS().verify(str(spec["JWT"]), keycol))
                attr = [n for n, s in
                        userinfo._claim_names.items() if s == csrc]
                # Compare as sets: dictionary key order is arbitrary, and a
                # list never equals a Python 3 key view. Raised explicitly
                # so the check also runs under "python -O".
                if set(attr) != set(info.keys()):
                    raise AssertionError(
                        "Claims from '%s' do not match the announced "
                        "claim names" % (csrc,))

                for key, vals in info.items():
                    userinfo[key] = vals

        return userinfo

    def fetch_distributed_claims(self, userinfo, callback=None):
        """
        Fetch distributed claims, that is claims kept at other endpoints.

        For every claim source that names an "endpoint" a user info request
        is sent to that endpoint and the result is copied into `userinfo`.

        :param userinfo: The user info, with ``_claim_sources`` and
            ``_claim_names`` attributes
        :param callback: Called with the claim source identifier to get an
            access token when the source specification has no
            "access_token" of its own
        :return: The same `userinfo` instance, updated in place
        :raises AssertionError: If the claims returned by an endpoint are
            not exactly the ones announced for that source in
            ``_claim_names``
        """
        for csrc, spec in userinfo._claim_sources.items():
            if "endpoint" not in spec:
                continue

            #pcr = self.provider_config(csrc, keys=False, endpoints=False)

            if "access_token" in spec:
                access_token = spec["access_token"]
            else:
                access_token = callback(csrc)

            _uinfo = self.do_user_info_request(
                token=access_token, userinfo_endpoint=spec["endpoint"])

            attr = [n for n, s in
                    userinfo._claim_names.items() if s == csrc]
            # Compare as sets: dictionary key order is arbitrary, and a
            # list never equals a Python 3 key view. Raised explicitly so
            # the check also runs under "python -O".
            if set(attr) != set(_uinfo.keys()):
                raise AssertionError(
                    "Claims from '%s' do not match the announced "
                    "claim names" % (csrc,))

            for key, vals in _uinfo.items():
                userinfo[key] = vals

        return userinfo

    def verify_alg_support(self, alg, usage, other):
        """
        Verifies that the algorithm to be used are supported by the other side.

        The list of supported algorithms is looked up as
        "<usage>_algs_supported", first in ``self.provider_info`` and then
        as an attribute of this instance. If it is found in neither place
        the algorithm is taken to be supported.

        :param alg: The algorithm specification
        :param usage: In which context the 'alg' will be used.
            The following values are supported:
            - userinfo
            - id_token
            - request_object
            - token_endpoint_auth
        :param other: The identifier for the other side (not used by this
            implementation)
        :return: True or False
        """
        attr_name = "%s_algs_supported" % usage

        try:
            supported = self.provider_info[attr_name]
        except (KeyError, TypeError):
            # TypeError: provider_info is still None, no provider
            # configuration has been stored yet
            supported = getattr(self, attr_name, None)

        if supported is None:
            return True

        return alg in supported

    def match_preferences(self, pcr=None, issuer=None):
        """
        Match the clients preferences against what the provider can do.

        The outcome is stored in ``self.behaviour``, keyed by preference
        name.

        :param pcr: Provider configuration response if available; when it
            is missing or false ``self.provider_info`` is used
        :param issuer: The issuer identifier (not used by this
            implementation)
        :raises ConfigurationError: If a preference that the provider has
            a value for could not be matched
        """
        if not pcr:
            pcr = self.provider_info

        regreq = RegistrationRequest

        # First pass: the preferences that correspond to a provider
        # configuration parameter
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            try:
                vals = self.client_prefs[_pref]
            except KeyError:
                # The client has no opinion on this one
                continue

            try:
                _pvals = pcr[_prov]
            except KeyError:
                # The provider says nothing: use the provider default, or
                # failing that an empty value of the kind the parameter
                # definition asks for
                try:
                    self.behaviour[_pref] = PROVIDER_DEFAULT[_pref]
                except KeyError:
                    #self.behaviour[_pref]= vals[0]
                    if isinstance(pcr.c_param[_prov][0], list):
                        self.behaviour[_pref] = []
                    else:
                        self.behaviour[_pref] = None
                continue

            if isinstance(vals, basestring):
                # A single preferred value, used only if the provider
                # lists it
                if vals in _pvals:
                    self.behaviour[_pref] = vals
            else:
                # Several preferred values; the registration request
                # definition tells whether the result is a list
                vtyp = regreq.c_param[_pref]
                if isinstance(vtyp[0], list):
                    _list = True
                else:
                    _list = False

                for val in vals:
                    if val in _pvals:
                        if not _list:
                            # Single valued: the first match wins
                            self.behaviour[_pref] = val
                            break
                        else:
                            # List valued: collect every match; this
                            # appends to a list already present in
                            # self.behaviour
                            try:
                                self.behaviour[_pref].append(val)
                            except KeyError:
                                self.behaviour[_pref] = [val]


            if _pref not in self.behaviour:
                raise ConfigurationError(
                    "OP couldn't match preference:%s" % _pref, pcr)

        # Second pass: copy the remaining preferences as they are
        for key, val in self.client_prefs.items():
            if key in self.behaviour:
                continue

            try:
                vtyp = regreq.c_param[key]
                if isinstance(vtyp[0], list):
                    pass
                elif isinstance(val, list) and not isinstance(val, basestring):
                    # Single valued parameter given as a list: use the
                    # first item
                    val = val[0]
            except KeyError:
                # Not a registration request parameter, keep the value
                pass
            if key not in PREFERENCE2PROVIDER:
                self.behaviour[key] = val

    def store_registration_info(self, reginfo):
        """
        Remember the outcome of a client registration.

        Keeps `reginfo` as ``self.registration_response``, adding
        "client_secret_post" as "token_endpoint_auth_method" if none is
        present, and copies the client secret and client id, plus the
        secret expiration time and the registration access token when they
        are present, to attributes of this instance.

        :param reginfo: The registration response
        :raises KeyError: If "client_secret" or "client_id" is missing
        """
        self.registration_response = reginfo
        if "token_endpoint_auth_method" not in reginfo:
            reginfo["token_endpoint_auth_method"] = "client_secret_post"

        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]

        # Optional parts of the response: (attribute, response parameter)
        optional = (
            ("registration_expires", "client_secret_expires_at"),
            ("registration_access_token", "registration_access_token"),
        )
        for attribute, parameter in optional:
            try:
                setattr(self, attribute, reginfo[parameter])
            except KeyError:
                pass

    def handle_registration_info(self, response):
        """
        Parse the HTTP response to a registration request and store it.

        Status 200 and 201 both count as success: a successful dynamic
        client registration is answered with "201 Created", while reading
        an existing registration is answered with "200 OK".

        :param response: The HTTP response, with ``status_code`` and
            ``text`` attributes
        :return: The parsed ``RegistrationResponse``
        :raises PyoidcError: If the registration failed; the message holds
            the error response as JSON
        """
        if response.status_code not in (200, 201):
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        resp = RegistrationResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)

        return resp

    def registration_read(self, url="", registration_access_token=None):
        """
        Read the current registration of this client from the provider.

        :param url: Where to read from, defaults to the
            "registration_client_uri" of the stored registration response
        :param registration_access_token: The token to authenticate with,
            defaults to ``self.registration_access_token``
        :return: What ``self.handle_registration_info`` returns for the
            HTTP response
        """
        if not url:
            url = self.registration_response["registration_client_uri"]

        if not registration_access_token:
            registration_access_token = self.registration_access_token

        # Headers go in as a dictionary, the same shape register() hands
        # to http_request()
        headers = {"Authorization": "Bearer %s" % registration_access_token}
        rsp = self.http_request(url, "GET", headers=headers)

        return self.handle_registration_info(rsp)

    def create_registration_request(self, **kwargs):
        """
        Create a registration request

        Each parameter of the request is taken from the keyword arguments
        or, if it is not there, from ``self.behaviour``. The post logout
        redirect URIs and the redirect URIs fall back on the attributes of
        this instance with the same names.

        :param kwargs: parameters to the registration request
        :return: The ``RegistrationRequest`` instance
        :raises MissingRequiredAttribute: If no redirect URIs are available
        """
        reg_req = RegistrationRequest()

        for param in reg_req.parameters():
            try:
                reg_req[param] = kwargs[param]
            except KeyError:
                # Not given explicitly, maybe the matched preferences
                # know it
                try:
                    reg_req[param] = self.behaviour[param]
                except KeyError:
                    pass

        if "post_logout_redirect_uris" not in reg_req:
            try:
                logout_uris = self.post_logout_redirect_uris
                reg_req["post_logout_redirect_uris"] = logout_uris
            except AttributeError:
                pass

        if "redirect_uris" not in reg_req:
            try:
                reg_req["redirect_uris"] = self.redirect_uris
            except AttributeError:
                raise MissingRequiredAttribute("redirect_uris", reg_req)

        return reg_req

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        The registration request is sent as JSON in an HTTP POST.

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: What ``self.handle_registration_info`` returns for the
            HTTP response
        """
        payload = self.create_registration_request(**kwargs).to_json()

        response = self.http_request(
            url, "POST", data=payload,
            headers={"content-type": "application/json"})

        return self.handle_registration_info(response)

    def normalization(self, principal, idtype="mail"):
        """
        Split a principal into a subject and a domain.

        :param principal: The identifier of the principal
        :param idtype: "mail" for identifiers of the form local@domain,
            "url" for URLs; for anything else the domain is empty
        :return: A tuple (subject, domain). For "mail" the subject is the
            principal prefixed with "acct:", otherwise the principal itself.
        :raises ValueError: For "mail", if the principal does not contain
            exactly one "@"
        """
        if idtype == "mail":
            _local, domain = principal.split("@")
            return "acct:%s" % principal, domain

        if idtype == "url":
            return principal, urlparse.urlparse(principal).netloc

        return principal, ""

    def discover(self, principal):
        """
        Run a discovery query for a principal.

        :param principal: The identifier of the principal, handed to
            ``self.wf.discovery_query`` as it is
        :return: Whatever ``self.wf.discovery_query`` returns
        """
        #subject, host = self.normalization(principal)
        finger = self.wf
        return finger.discovery_query(principal)
예제 #10
0
파일: __init__.py 프로젝트: ghedin/pyoidc
class Client(oauth2.Client):
    _endpoints = ENDPOINTS

    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_prefs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True):
        """
        Set up the client.

        :param client_id: Passed on to ``oauth2.Client.__init__``
        :param ca_certs: Passed on to ``oauth2.Client.__init__``
        :param client_prefs: The preferences of the client, an empty
            dictionary if none are given
        :param client_authn_method: Passed on to ``oauth2.Client.__init__``
        :param keyjar: Passed on to ``oauth2.Client.__init__``
        :param verify_ssl: Passed on to ``oauth2.Client.__init__``
        """
        oauth2.Client.__init__(self, client_id, ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)

        self.file_store = "./file/"
        self.file_uri = "http://localhost/"

        # OpenID connect specific endpoints, all empty to begin with
        for endpoint_name in ENDPOINTS:
            setattr(self, endpoint_name, "")

        self.id_token = None
        self.log = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR

        self.grant_class = Grant
        self.token_class = Token
        self.provider_info = None
        self.registration_response = None
        self.client_prefs = client_prefs if client_prefs else {}

        self.behaviour = dict(
            request_object_signing_alg=DEF_SIGN_ALG["openid_request_object"])

        # The WebFinger instance gets this client as its "httpd"
        web_finger = WebFinger(OIC_ISSUER)
        web_finger.httpd = self
        self.wf = web_finger

        self.allow = {}
        self.post_logout_redirect_uris = []
        self.registration_expires = 0
        self.registration_access_token = None

        # Default key by kid for different key types
        # For instance {"RSA":"abc"}
        self.kid = dict(sig={}, enc={})

    def _get_id_token(self, **kwargs):
        """
        Find an ID token.

        An "id_token" keyword argument is returned as it is. Otherwise the
        tokens of the grant returned by ``self.get_grant(**kwargs)`` are
        searched for one that carries an ID token.

        :param kwargs: May carry "id_token" and "scope"; all of them are
            passed on to ``self.get_grant``
        :return: The ID token, or None if there is none
        """
        try:
            return kwargs["id_token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

        if grant:
            _scope = kwargs.get("scope")

            for token in grant.tokens:
                if token.scope and _scope:
                    # A plain membership test instead of an assert
                    # statement: asserts are removed under "python -O",
                    # which disabled the scope check altogether.
                    if not all(item in token.scope for item in _scope):
                        # NOTE(review): a token with the wrong scope ends
                        # the search instead of moving on to the next
                        # token - kept as it was, confirm this is intended
                        break
                if token.id_token:
                    return token.id_token

        return None

    def construct_AuthorizationRequest(self,
                                       request=AuthorizationRequest,
                                       request_args=None,
                                       extra_args=None,
                                       request_param=None,
                                       **kwargs):
        """
        Construct an authorization request, possibly with a request object.

        A random nonce is added to the request arguments when

        - `request_args` is given without a nonce and its "response_type"
          contains "token" or "id_token",
        - `request_args` is not given and the "response_type" keyword
          argument contains "token",
        - neither `request_args` nor a "response_type" keyword argument is
          given.

        :param request: The request message class
        :param request_args: Request arguments; a nonce may be added to
            this dictionary
        :param extra_args: Extra request arguments
        :param request_param: "request" places the request object in the
            request itself; any other true value writes it to a file and
            places a reference to it in "request_uri"
        :param kwargs: "request_method" set to "file" has the same effect
            as request_param="request_uri". "algorithm" and "keys" are
            given defaults when a request object is made. "local_dir" and
            "base_path" are required when the request object is written to
            a file: the directory to write to and the prefix of the
            resulting "request_uri".
        :return: The authorization request
        """

        if request_args is not None:
            if "nonce" not in request_args:
                _rt = request_args["response_type"]
                if "token" in _rt or "id_token" in _rt:
                    request_args["nonce"] = rndstr(12)
        elif "response_type" in kwargs:
            if "token" in kwargs["response_type"]:
                request_args = {"nonce": rndstr(12)}
        else:  # Never wrong to specify a nonce
            request_args = {"nonce": rndstr(12)}

        if "request_method" in kwargs:
            if kwargs["request_method"] == "file":
                request_param = "request_uri"
                del kwargs["request_method"]

        areq = oauth2.Client.construct_AuthorizationRequest(
            self, request, request_args, extra_args, **kwargs)

        if request_param:
            alg = self.behaviour["request_object_signing_alg"]
            if "algorithm" not in kwargs:
                kwargs["algorithm"] = alg

            if "keys" not in kwargs and alg:
                _kty = alg2keytype(alg)
                try:
                    # Prefer the default key id configured for the key type
                    kwargs["keys"] = self.keyjar.get_signing_key(
                        _kty, kid=self.kid["sig"][_kty])
                except KeyError:
                    kwargs["keys"] = self.keyjar.get_signing_key(_kty)

            _req = make_openid_request(areq, **kwargs)

            if request_param == "request":
                areq["request"] = _req
            else:
                _filedir = kwargs["local_dir"]
                _webpath = kwargs["base_path"]
                # Pick a random file name that is not in use yet
                _name = rndstr(10)
                filename = os.path.join(_filedir, _name)
                while os.path.exists(filename):
                    _name = rndstr(10)
                    filename = os.path.join(_filedir, _name)
                # The with statement closes the file even if writing fails
                with open(filename, mode="w") as fid:
                    fid.write(_req)
                _webname = "%s%s" % (_webpath, _name)
                areq["request_uri"] = _webname

        return areq

    #noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None,
                                     extra_args=None,
                                     **kwargs):
        """
        Construct an access token request.

        Delegates to ``oauth2.Client.construct_AccessTokenRequest``; the
        only thing this override contributes is its default request class.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Further keyword arguments, forwarded unchanged
        :return: Whatever the parent implementation returns
        """
        construct = oauth2.Client.construct_AccessTokenRequest
        return construct(self, request, request_args, extra_args, **kwargs)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None,
                                            **kwargs):
        """
        Construct a refresh access token request.

        Delegates to ``oauth2.Client.construct_RefreshAccessTokenRequest``;
        the only thing this override contributes is its default request
        class.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Further keyword arguments, forwarded unchanged
        :return: Whatever the parent implementation returns
        """
        construct = oauth2.Client.construct_RefreshAccessTokenRequest
        return construct(self, request, request_args, extra_args, **kwargs)

    def construct_UserInfoRequest(self,
                                  request=UserInfoRequest,
                                  request_args=None,
                                  extra_args=None,
                                  **kwargs):
        """
        Construct a user info request.

        If the request arguments hold no access token, one is looked up
        with ``self.get_token``, using the scope "openid" unless another
        scope is given.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Keyword arguments for ``self.get_token``
        :return: Whatever ``self.construct_request`` returns
        :raises PyoidcError: If no token could be found
        """
        if request_args is None:
            request_args = {}

        if "access_token" not in request_args:
            kwargs.setdefault("scope", "openid")
            token = self.get_token(**kwargs)
            if token is None:
                raise PyoidcError("No valid token available")

            request_args["access_token"] = token.access_token

        return self.construct_request(request, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_RegistrationRequest(self,
                                      request=RegistrationRequest,
                                      request_args=None,
                                      extra_args=None,
                                      **kwargs):
        """
        Construct a registration request.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Accepted but not used
        :return: Whatever ``self.construct_request`` returns
        """
        constructed = self.construct_request(request, request_args,
                                             extra_args)
        return constructed

    #noinspection PyUnusedLocal
    def construct_RefreshSessionRequest(self,
                                        request=RefreshSessionRequest,
                                        request_args=None,
                                        extra_args=None,
                                        **kwargs):
        """
        Construct a refresh session request.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Accepted but not used
        :return: Whatever ``self.construct_request`` returns
        """
        constructed = self.construct_request(request, request_args,
                                             extra_args)
        return constructed

    def _id_token_based(self,
                        request,
                        request_args=None,
                        extra_args=None,
                        **kwargs):
        """
        Construct a request that has to carry an ID token.

        :param request: The request message class
        :param request_args: Request arguments; the ID token is added to
            this dictionary if it is not already there
        :param extra_args: Extra request arguments
        :param kwargs: "prop" names the request argument that holds the ID
            token, by default "id_token"; all keyword arguments are passed
            on to ``self._get_id_token``
        :return: Whatever ``self.construct_request`` returns
        :raises PyoidcError: If no ID token could be found
        """
        if request_args is None:
            request_args = {}

        target = kwargs.get("prop", "id_token")

        if target not in request_args:
            id_token = self._get_id_token(**kwargs)
            if id_token is None:
                raise PyoidcError("No valid id token available")

            request_args[target] = id_token

        return self.construct_request(request, request_args, extra_args)

    def construct_CheckSessionRequest(self,
                                      request=CheckSessionRequest,
                                      request_args=None,
                                      extra_args=None,
                                      **kwargs):
        """
        Construct a check session request, which carries an ID token.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Passed on to ``self._id_token_based``
        :return: Whatever ``self._id_token_based`` returns
        """
        build = self._id_token_based
        return build(request, request_args, extra_args, **kwargs)

    def construct_CheckIDRequest(self,
                                 request=CheckIDRequest,
                                 request_args=None,
                                 extra_args=None,
                                 **kwargs):
        """
        Construct a check ID request.

        The ID token is placed in the "access_token" request argument.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Passed on to ``self._id_token_based``
        :return: Whatever ``self._id_token_based`` returns
        """
        # access_token is where the id_token will be placed
        build = self._id_token_based
        return build(request, request_args, extra_args, prop="access_token",
                     **kwargs)

    def construct_EndSessionRequest(self,
                                    request=EndSessionRequest,
                                    request_args=None,
                                    extra_args=None,
                                    **kwargs):
        """
        Construct an end session request, which carries an ID token.

        The state is kept in step between the keyword arguments and the
        request arguments; a "state" keyword argument takes precedence.

        :param request: The request message class
        :param request_args: Request arguments
        :param extra_args: Extra request arguments
        :param kwargs: Passed on to ``self._id_token_based``
        :return: Whatever ``self._id_token_based`` returns
        """
        if request_args is None:
            request_args = {}

        if "state" not in kwargs:
            if "state" in request_args:
                kwargs["state"] = request_args["state"]
        else:
            request_args["state"] = kwargs["state"]

        #        if "redirect_url" not in request_args:
        #            request_args["redirect_url"] = self.redirect_url

        build = self._id_token_based
        return build(request, request_args, extra_args, **kwargs)

    # ------------------------------------------------------------------------
    def authorization_request_info(self,
                                   request_args=None,
                                   extra_args=None,
                                   **kwargs):
        """
        Return the request information for an authorization request.

        Calls ``request_info`` with ``AuthorizationRequest`` as the request
        class and "GET" as the method.

        :param request_args: Passed on unchanged
        :param extra_args: Passed on unchanged
        :param kwargs: Passed on unchanged
        :return: Whatever ``request_info`` returns
        """
        return self.request_info(AuthorizationRequest, "GET", request_args,
                                 extra_args, **kwargs)

    # ------------------------------------------------------------------------
    def do_authorization_request(self,
                                 request=AuthorizationRequest,
                                 state="",
                                 body_type="",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse):
        """
        Perform an authorization request.

        Only supplies this class' default request and response classes;
        the work is done by ``oauth2.Client.do_authorization_request``,
        which receives all arguments positionally in the order given here.

        :return: Whatever ``oauth2.Client.do_authorization_request`` returns
        """

        return oauth2.Client.do_authorization_request(self, request, state,
                                                      body_type, method,
                                                      request_args, extra_args,
                                                      http_args, response_cls)

    def do_access_token_request(self,
                                request=AccessTokenRequest,
                                scope="",
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="",
                                **kwargs):
        """
        Perform an access token request.

        Only supplies this class' default request and response classes;
        the work is done by ``oauth2.Client.do_access_token_request``.

        :return: Whatever ``oauth2.Client.do_access_token_request`` returns
        """

        return oauth2.Client.do_access_token_request(self, request, scope,
                                                     state, body_type, method,
                                                     request_args, extra_args,
                                                     http_args, response_cls,
                                                     authn_method, **kwargs)

    def do_access_token_refresh(self,
                                request=RefreshAccessTokenRequest,
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                **kwargs):
        """
        Perform an access token refresh.

        Only supplies this class' default request and response classes;
        the work is done by ``oauth2.Client.do_access_token_refresh``.

        :return: Whatever ``oauth2.Client.do_access_token_refresh`` returns
        """

        return oauth2.Client.do_access_token_refresh(self, request, state,
                                                     body_type, method,
                                                     request_args, extra_args,
                                                     http_args, response_cls,
                                                     **kwargs)

    def do_registration_request(self,
                                request=RegistrationRequest,
                                scope="",
                                state="",
                                body_type="json",
                                method="POST",
                                request_args=None,
                                extra_args=None,
                                http_args=None,
                                response_cls=None):
        """
        Build and send a client registration request.

        :param request: The request class
        :param scope: Passed to ``request_info``
        :param state: Passed to ``request_info`` and ``request_and_return``
        :param body_type: Passed to ``request_and_return``
        :param method: The HTTP method
        :param request_args: Request parameters
        :param extra_args: Extra request parameters
        :param http_args: Caller supplied HTTP arguments; when given, the
            arguments produced by ``request_info`` are merged into this
            dictionary (in place)
        :param response_cls: The response class; RegistrationResponse when
            None
        :return: Whatever ``request_and_return`` returns
        """

        url, body, ht_args, _csi = self.request_info(request,
                                                     method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope,
                                                     state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge what request_info() computed; updating the dictionary
            # with itself would silently lose those arguments.
            http_args.update(ht_args)

        if response_cls is None:
            response_cls = RegistrationResponse

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    def do_check_session_request(self,
                                 request=CheckSessionRequest,
                                 scope="",
                                 state="",
                                 body_type="json",
                                 method="GET",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=IdToken):
        """
        Build and send a check-session request.

        :param request: The request class
        :param scope: Passed to ``request_info``
        :param state: Passed to ``request_info`` and ``request_and_return``
        :param body_type: Passed to ``request_and_return``
        :param method: The HTTP method
        :param request_args: Request parameters
        :param extra_args: Extra request parameters
        :param http_args: Caller supplied HTTP arguments; when given, the
            arguments produced by ``request_info`` are merged into this
            dictionary (in place)
        :param response_cls: The response class
        :return: Whatever ``request_and_return`` returns
        """

        url, body, ht_args, _csi = self.request_info(request,
                                                     method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope,
                                                     state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge what request_info() computed; updating the dictionary
            # with itself would silently lose those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    def do_check_id_request(self,
                            request=CheckIDRequest,
                            scope="",
                            state="",
                            body_type="json",
                            method="GET",
                            request_args=None,
                            extra_args=None,
                            http_args=None,
                            response_cls=IdToken):
        """
        Build and send a check-id request.

        :param request: The request class
        :param scope: Passed to ``request_info``
        :param state: Passed to ``request_info`` and ``request_and_return``
        :param body_type: Passed to ``request_and_return``
        :param method: The HTTP method
        :param request_args: Request parameters
        :param extra_args: Extra request parameters
        :param http_args: Caller supplied HTTP arguments; when given, the
            arguments produced by ``request_info`` are merged into this
            dictionary (in place)
        :param response_cls: The response class
        :return: Whatever ``request_and_return`` returns
        """

        url, body, ht_args, _csi = self.request_info(request,
                                                     method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope,
                                                     state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge what request_info() computed; updating the dictionary
            # with itself would silently lose those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    def do_end_session_request(self,
                               request=EndSessionRequest,
                               scope="",
                               state="",
                               body_type="",
                               method="GET",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=None):
        """
        Build and send an end-session request.

        :param request: The request class
        :param scope: Passed to ``request_info``
        :param state: Passed to ``request_info`` and ``request_and_return``
        :param body_type: Passed to ``request_and_return``
        :param method: The HTTP method
        :param request_args: Request parameters
        :param extra_args: Extra request parameters
        :param http_args: Caller supplied HTTP arguments; when given, the
            arguments produced by ``request_info`` are merged into this
            dictionary (in place)
        :param response_cls: The response class, handed on as is (also
            when None)
        :return: Whatever ``request_and_return`` returns
        """

        url, body, ht_args, _csi = self.request_info(request,
                                                     method=method,
                                                     request_args=request_args,
                                                     extra_args=extra_args,
                                                     scope=scope,
                                                     state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge what request_info() computed; updating the dictionary
            # with itself would silently lose those arguments.
            http_args.update(ht_args)

        return self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       state=state,
                                       http_args=http_args)

    def user_info_request(self, method="GET", state="", scope="", **kwargs):
        """
        Assemble the pieces of a UserInfo request without sending it.

        The access token is taken from, in order: the ``token`` keyword,
        the ``access_token`` keyword, or the grant stored under *state*
        (refreshing the token first when it is no longer valid).

        :param method: The HTTP method
        :param state: Key into ``self.grant``
        :param scope: Scope used to pick the token from the grant
        :param kwargs: Extra arguments; ``behavior``, ``token_type`` and
            ``headers`` are interpreted here, the rest is passed on
        :return: A tuple (path, body, method, http_args)
        :raises MissingParameter: When a behavior is requested but no
            token type can be determined
        """
        user_req = UserInfoRequest()
        logger.debug("[user_info_request]: kwargs:%s" % (kwargs, ))

        if "token" in kwargs:
            supplied = kwargs["token"]
            if supplied:
                user_req["access_token"] = supplied
                token = Token()
                token.token_type = "Bearer"
                token.access_token = supplied
                kwargs["behavior"] = "use_authorization_header"
            else:
                # What to do ? Need a callback
                token = None
        elif kwargs.get("access_token"):
            user_req["access_token"] = kwargs.pop("access_token")
            token = None
        else:
            token = self.grant[state].get_token(scope)

            if token.is_valid():
                user_req["access_token"] = token.access_token
                if token.token_type == "Bearer" and method == "GET":
                    kwargs["behavior"] = "use_authorization_header"
            else:
                if self.log:
                    self.log.info("do access token refresh")
                # Any failure while refreshing propagates to the caller.
                self.do_access_token_refresh(token=token)
                token = self.grant[state].get_token(scope)
                user_req["access_token"] = token.access_token

        uri = self._endpoint("userinfo_endpoint", **kwargs)

        # If the access token is a bearer token it might be sent in the
        # authorization header. There are 3 ways of sending the token:
        # - POST with token in authorization header
        # - POST with token in message body
        # - GET with token in authorization header
        if "behavior" in kwargs:
            behavior = kwargs["behavior"]
            access_token = user_req["access_token"]

            if "token_type" in kwargs:
                token_type = kwargs["token_type"]
            else:
                try:
                    token_type = token.token_type
                except AttributeError:
                    raise MissingParameter("Unspecified token type")

            # use_authorization_header, token_in_message_body
            if "use_authorization_header" in behavior and \
                    token_type == "Bearer":
                bearer = "Bearer %s" % access_token
                kwargs.setdefault("headers", {}).update(
                    {"Authorization": bearer})

            if "token_in_message_body" not in behavior:
                # remove the token from the request
                del user_req["access_token"]

        path, body, kwargs = self.get_or_post(uri, method, user_req, **kwargs)

        h_args = {k: v for k, v in kwargs.items() if k in HTTP_ARGS}

        return path, body, method, h_args

    def do_user_info_request(self,
                             method="POST",
                             state="",
                             scope="openid",
                             request="openid",
                             **kwargs):
        """
        Send a UserInfo request and parse the response.

        :param method: The HTTP method
        :param state: Passed to ``user_info_request``
        :param scope: Passed to ``user_info_request``
        :param request: Stored in *kwargs* under "request" before the
            request is built
        :param kwargs: Extra arguments; ``user_info_schema`` selects the
            class used to parse the response (OpenIDSchema when absent)
        :return: An instance of the schema class filled from the response
        :raises PyoidcError: On a non-200 response or when the response
            content type is neither JSON nor JWT
        """

        kwargs["request"] = request
        path, body, method, h_args = self.user_info_request(
            method, state, scope, **kwargs)

        logger.debug("[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" %
                     (path, body, h_args))

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit checks instead of assert statements: asserts are
            # removed when running with -O, which would make every
            # response be treated as JSON.
            content_type = resp.headers["content-type"]
            if "application/json" in content_type:
                sformat = "json"
            elif "application/jwt" in content_type:
                sformat = "jwt"
            else:
                raise PyoidcError(
                    "ERROR: Unexpected content-type: %s" % content_type)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]: %s" %
                              (resp.status_code, resp.text))

        try:
            _schema = kwargs["user_info_schema"]
        except KeyError:
            _schema = OpenIDSchema

        logger.debug("Reponse text: '%s'" % resp.text)

        if sformat == "json":
            return _schema().from_json(txt=resp.text)

        algo = self.client_prefs["userinfo_signed_response_alg"]
        _kty = alg2keytype(algo)
        # Keys of the OP ?
        try:
            keys = self.keyjar.get_signing_key(_kty, self.kid["sig"][_kty])
        except KeyError:
            keys = self.keyjar.get_signing_key(_kty)

        return _schema().from_jwt(resp.text, keys)

    def get_userinfo_claims(self,
                            access_token,
                            endpoint,
                            method="POST",
                            schema_class=OpenIDSchema,
                            **kwargs):
        """
        Fetch user info claims from an explicitly given endpoint.

        :param access_token: The access token placed in the request
        :param endpoint: The URL to send the request to
        :param method: The HTTP method
        :param schema_class: Class used to parse the JSON response
        :param kwargs: Extra arguments; those named in HTTP_ARGS are used
            as HTTP arguments, ``authn_method`` selects the client
            authentication method
        :return: An instance of *schema_class* filled from the response
        :raises PyoidcError: On a non-200 response or when the response is
            not JSON
        """

        uir = UserInfoRequest(access_token=access_token)

        h_args = {k: v for k, v in kwargs.items() if k in HTTP_ARGS}

        if "authn_method" in kwargs:
            # NOTE(review): unlike the default branch, the request (uir)
            # is not passed here - confirm against the signature of
            # init_authentication_method.
            http_args = self.init_authentication_method(**kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.init_authentication_method(
                uir, "bearer_header", **kwargs)

        h_args.update(http_args)
        path, body, kwargs = self.get_or_post(endpoint, method, uir, **kwargs)

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit check instead of an assert statement: asserts are
            # removed when running with -O.
            content_type = resp.headers["content-type"]
            if "application/json" not in content_type:
                raise PyoidcError(
                    "ERROR: Unexpected content-type: %s" % content_type)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]" %
                              resp.status_code)

        return schema_class().from_json(txt=resp.text)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises IssuerMismatch: When the issuer in the response differs
            from the expected one (ignoring a single trailing slash) and
            ``self.allow`` has no "issuer_mismatch" entry
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]

            # Give the expected issuer the same trailing-slash form as the
            # one in the response before comparing the two.
            if _pcr_issuer.endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            try:
                self.allow["issuer_mismatch"]
            except KeyError:
                # Explicit comparison instead of an assert statement:
                # asserts are removed when running with -O, which would
                # disable this check altogether.
                if _issuer != _pcr_issuer:
                    raise IssuerMismatch(
                        "'%s' != '%s'" % (_issuer, _pcr_issuer), pcr)

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self,
                        issuer,
                        keys=True,
                        endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the provider configuration for *issuer* and process it.

        :param issuer: The issuer identifier; one trailing slash is
            removed before it is inserted into *serv_pattern*
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param response_cls: Class used to parse the JSON response
        :param serv_pattern: Format string giving the configuration URL
        :return: The parsed provider configuration response
        :raises PyoidcError: When no configuration could be fetched
        """
        base = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % base

        r = self.http_request(url)
        # Follow redirects (302 only) until some other status shows up.
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        pcr = response_cls().from_json(r.text) if r.status_code == 200 \
            else None

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def unpack_aggregated_claims(self, userinfo):
        """
        Verify and merge aggregated claims into *userinfo*.

        For every claim source that carries a "JWT", the JWT is verified
        with the keys of that source plus the keys returned by
        ``get_verify_key()`` without an owner, and the claims it contains
        are copied into *userinfo*.

        :param userinfo: Object with ``_claim_sources`` and
            ``_claim_names`` attributes that supports item assignment
        :return: The same *userinfo* object
        :raises PyoidcError: When the claims in a JWT are not exactly the
            ones *userinfo* names for that source
        """
        if not userinfo._claim_sources:
            return userinfo

        for csrc, spec in userinfo._claim_sources.items():
            if "JWT" not in spec:
                continue

            if csrc not in self.keyjar:
                self.provider_config(csrc, endpoints=False)

            keycol = self.keyjar.get_verify_key(owner=csrc)
            for typ, keyl in self.keyjar.get_verify_key().items():
                try:
                    keycol[typ].extend(keyl)
                except KeyError:
                    keycol[typ] = keyl

            info = json.loads(JWS().verify(str(spec["JWT"]), keycol))
            expected = set(
                n for n, s in userinfo._claim_names.items() if s == csrc)

            # Compare as sets: neither mapping guarantees an ordering, so
            # comparing a list with keys() could fail for equal contents.
            # An explicit raise also survives running with -O.
            if expected != set(info):
                raise PyoidcError(
                    "Claims from '%s' do not match the announced claim "
                    "names" % csrc)

            for key, vals in info.items():
                userinfo[key] = vals

        return userinfo

    def fetch_distributed_claims(self, userinfo, callback=None):
        """
        Fetch distributed claims and merge them into *userinfo*.

        For every claim source that has an "endpoint", a user info request
        is sent to that endpoint and the returned claims are copied into
        *userinfo*.

        :param userinfo: Object with ``_claim_sources`` and
            ``_claim_names`` attributes that supports item assignment
        :param callback: Called with the claim source name to obtain an
            access token when the source does not specify one
        :return: The same *userinfo* object
        :raises PyoidcError: When no access token can be obtained for a
            source, or when the returned claims are not exactly the ones
            *userinfo* names for that source
        """
        # Same guard as unpack_aggregated_claims(): nothing to do when
        # there are no claim sources.
        if not userinfo._claim_sources:
            return userinfo

        for csrc, spec in userinfo._claim_sources.items():
            if "endpoint" not in spec:
                continue

            if "access_token" in spec:
                token = spec["access_token"]
            elif callback is None:
                # Fail with a clear message rather than by calling None.
                raise PyoidcError(
                    "No access token for claim source '%s' and no "
                    "callback to obtain one" % csrc)
            else:
                token = callback(csrc)

            _uinfo = self.do_user_info_request(
                token=token, userinfo_endpoint=spec["endpoint"])

            expected = set(
                n for n, s in userinfo._claim_names.items() if s == csrc)

            # Compare as sets: neither mapping guarantees an ordering, so
            # comparing a list with keys() could fail for equal contents.
            # An explicit raise also survives running with -O.
            if expected != set(_uinfo.keys()):
                raise PyoidcError(
                    "Claims from '%s' do not match the announced claim "
                    "names" % csrc)

            for key, vals in _uinfo.items():
                userinfo[key] = vals

        return userinfo

    def verify_alg_support(self, alg, usage, other):
        """
        Check whether the other side supports an algorithm.

        The list of supported algorithms is looked up under
        "<usage>_algs_supported", first in ``self.provider_info`` and then
        as an attribute of this instance. When no such list is found the
        algorithm is regarded as supported.

        :param alg: The algorithm specification
        :param usage: In which context the 'alg' will be used.
            The following values are supported:
            - userinfo
            - id_token
            - request_object
            - token_endpoint_auth
        :param other: The identifier for the other side (not used by the
            lookup)
        :return: True or False
        """
        key = "%s_algs_supported" % usage

        try:
            supported = self.provider_info[key]
        except KeyError:
            supported = getattr(self, key, None)

        return supported is None or alg in supported

    def match_preferences(self, pcr=None, issuer=None):
        """
        Match the clients preferences against what the provider can do.

        The outcome is stored in ``self.behaviour``, keyed by preference
        name.

        :param pcr: Provider configuration response if available,
            ``self.provider_info`` is used when it is not given
        :param issuer: The issuer identifier (not used in this method)
        :raises ConfigurationError: When the provider lists values for a
            preference but none of the client's preferred values is among
            them
        """
        if not pcr:
            pcr = self.provider_info

        regreq = RegistrationRequest

        # First pass: preferences that have a counterpart in the provider
        # configuration.
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            try:
                vals = self.client_prefs[_pref]
            except KeyError:
                # The client has no opinion about this one.
                continue

            try:
                _pvals = pcr[_prov]
            except KeyError:
                # The provider says nothing about it: use the documented
                # default, or an empty value of the right kind.
                try:
                    self.behaviour[_pref] = PROVIDER_DEFAULT[_pref]
                except KeyError:
                    #self.behaviour[_pref]= vals[0]
                    if isinstance(pcr.c_param[_prov][0], list):
                        self.behaviour[_pref] = []
                    else:
                        self.behaviour[_pref] = None
                # Skips the "couldn't match" check at the end of the loop.
                continue

            if isinstance(vals, basestring):
                # A single preferred value.
                if vals in _pvals:
                    self.behaviour[_pref] = vals
            else:
                # Several preferred values; whether one or all matching
                # values are kept depends on the parameter specification
                # in the registration request class.
                vtyp = regreq.c_param[_pref]
                if isinstance(vtyp[0], list):
                    _list = True
                else:
                    _list = False

                for val in vals:
                    if val in _pvals:
                        if not _list:
                            # Single valued: the first match wins.
                            self.behaviour[_pref] = val
                            break
                        else:
                            # Multi valued: collect every match.
                            try:
                                self.behaviour[_pref].append(val)
                            except KeyError:
                                self.behaviour[_pref] = [val]

            if _pref not in self.behaviour:
                raise ConfigurationError(
                    "OP couldn't match preference:%s" % _pref, pcr)

        # Second pass: copy the remaining client preferences that have no
        # provider counterpart.
        for key, val in self.client_prefs.items():
            if key in self.behaviour:
                continue

            try:
                vtyp = regreq.c_param[key]
                if isinstance(vtyp[0], list):
                    pass
                elif isinstance(val, list) and not isinstance(val, basestring):
                    # Single valued parameter given as a list: keep the
                    # first element only.
                    val = val[0]
            except KeyError:
                pass
            if key not in PREFERENCE2PROVIDER:
                self.behaviour[key] = val

    def store_registration_info(self, reginfo):
        """
        Remember the outcome of a client registration.

        Stores *reginfo* as ``self.registration_response`` (adding
        "client_secret_post" as "token_endpoint_auth_method" when none is
        present) and copies the client credentials, plus the optional
        expiry time and registration access token, to attributes.

        :param reginfo: The registration response; must contain
            "client_secret" and "client_id"
        """
        self.registration_response = reginfo
        stored = self.registration_response
        if "token_endpoint_auth_method" not in stored:
            stored["token_endpoint_auth_method"] = "client_secret_post"

        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]

        # Optional items, as (attribute name, key in the response).
        optional = (
            ("registration_expires", "client_secret_expires_at"),
            ("registration_access_token", "registration_access_token"),
        )
        for attr, key in optional:
            try:
                value = reginfo[key]
            except KeyError:
                continue
            setattr(self, attr, value)

    def handle_registration_info(self, response):
        """
        Parse and store the response to a registration request.

        :param response: The HTTP response; ``status_code`` and ``text``
            are read
        :return: The parsed RegistrationResponse
        :raises PyoidcError: When the status code is not 200
        """
        if response.status_code != 200:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        resp = RegistrationResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)
        return resp

    def registration_read(self, url="", registration_access_token=None):
        """
        Read the current client registration from the provider.

        :param url: Where to send the request; defaults to the
            "registration_client_uri" of the stored registration response
        :param registration_access_token: Bearer token for the request;
            defaults to ``self.registration_access_token``
        :return: Whatever ``handle_registration_info`` returns
        """
        target = url or self.registration_response["registration_client_uri"]
        token = registration_access_token or self.registration_access_token

        rsp = self.http_request(
            target, "GET",
            headers=[("Authorization", "Bearer %s" % token)])

        return self.handle_registration_info(rsp)

    def create_registration_request(self, **kwargs):
        """
        Create a registration request

        Every parameter the request class knows about is filled in from
        *kwargs* or, failing that, from ``self.behaviour``.

        :param kwargs: parameters to the registration request
        :return: The RegistrationRequest instance
        :raises MissingRequiredAttribute: When no redirect URIs are given
            and this instance has no ``redirect_uris`` attribute
        """
        req = RegistrationRequest()

        for prop in req.parameters():
            try:
                req[prop] = kwargs[prop]
            except KeyError:
                # Not given explicitly; fall back on the matched behaviour.
                try:
                    req[prop] = self.behaviour[prop]
                except KeyError:
                    pass

        # Optional: only added when this instance has the attribute.
        if "post_logout_redirect_uris" not in req:
            try:
                req["post_logout_redirect_uris"] = self.post_logout_redirect_uris
            except AttributeError:
                pass

        # Required: fail when it can not be found anywhere.
        if "redirect_uris" not in req:
            try:
                req["redirect_uris"] = self.redirect_uris
            except AttributeError:
                raise MissingRequiredAttribute("redirect_uris", req)

        return req

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        The registration request is POSTed as JSON and the response is
        handed to ``handle_registration_info``.

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: Whatever ``handle_registration_info`` returns
        """
        payload = self.create_registration_request(**kwargs).to_json()

        rsp = self.http_request(
            url, "POST", data=payload,
            headers={"content-type": "application/json"})

        return self.handle_registration_info(rsp)

    def normalization(self, principal, idtype="mail"):
        """
        Split a principal into a subject and a domain.

        :param principal: The identifier to normalize
        :param idtype: "mail" for an e-mail like identifier, "url" for a
            URL; any other value yields an empty domain
        :return: A tuple (subject, domain)
        :raises ValueError: When *idtype* is "mail" and *principal*
            contains no "@"
        """
        if idtype == "mail":
            # Split on the last "@" only, so that a local part that itself
            # contains an "@" does not break the unpacking.
            _local, domain = principal.rsplit("@", 1)
            subject = "acct:%s" % principal
        elif idtype == "url":
            domain = urlparse.urlparse(principal).netloc
            subject = principal
        else:
            domain = ""
            subject = principal

        return subject, domain

    def discover(self, principal):
        """
        Run a discovery query for *principal*.

        :param principal: Handed unchanged to ``self.wf.discovery_query``
        :return: Whatever ``self.wf.discovery_query`` returns
        """
        #subject, host = self.normalization(principal)
        return self.wf.discovery_query(principal)
예제 #11
0
파일: x_provider.py 프로젝트: rohe/pyuma
class Provider(provider.Provider):
    def __init__(self, name, sdb, cdb, authn_broker, authz,
                 client_authn, symkey, urlmap=None, keyjar=None,
                 hostname="", configuration=None):
        """
        Set up the provider and the list of endpoints it serves.

        ``name``, ``sdb``, ``cdb``, ``authn_broker``, ``authz``,
        ``client_authn``, ``symkey`` and ``urlmap`` are forwarded to the
        parent constructor.

        :param keyjar: Key jar to use; a new, empty KeyJar is created when
            none (or a falsy one) is given
        :param hostname: Host name of this provider; defaults to the local
            host name as reported by ``socket.gethostname()``
        :param configuration: Accepted but not used by this constructor
        """
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey=symkey, urlmap=urlmap)

        self.baseurl = ""
        if keyjar:
            self.keyjar = keyjar
        else:
            self.keyjar = KeyJar()

        # gethostname must be *called*; storing the bare function would make
        # self.hostname a callable instead of a host name string.
        self.hostname = hostname or socket.gethostname()
        self.jwks_uri = []
        self.endpoints = [DynamicClientEndpoint, TokenEndpoint,
                          AuthorizationEndpoint, UserEndpoint,
                          ResourceSetRegistrationEndpoint,
                          IntrospectionEndpoint, RPTEndpoint,
                          PermissionRegistrationEndpoint]


    def set_authn_broker(self, authn_broker):
        """
        Install an authentication broker on this provider.

        Every method in the broker gets its ``srv`` attribute pointed at this
        provider, and ``cookie_func`` is taken from the broker's first
        method. With no broker (``None`` or empty) ``cookie_func`` is None.

        :param authn_broker: The broker, an indexable iterable of methods
        """
        self.authn_broker = authn_broker
        if not authn_broker:
            # Covers None too: iterating None would raise before the
            # "no broker" case could be handled.
            self.cookie_func = None
            return

        for meth in authn_broker:
            meth.srv = self
        self.cookie_func = authn_broker[0].create_cookie

    @staticmethod
    def _verify_url(url, urlset):
        part = urllib.parse.urlparse(url)

        for reg, qp in urlset:
            _part = urllib.parse.urlparse(reg)
            if part.scheme == _part.scheme and part.netloc == _part.netloc:
                return True

        return False

    def do_client_registration(self, request, client_id, ignore=None):
        """
        Merge a registration request into the stored information of a client.

        Starts from a copy of ``self.cdb[client_id]``, copies over every
        request parameter not listed in ``ignore``, then checks and
        normalizes ``redirect_uris``, ``sector_identifier_uri``,
        ``policy_url`` and ``logo_url``, and finally loads the client's keys
        into ``self.keyjar``.

        :param request: The registration request
        :param client_id: Identifier of the client being registered
        :param ignore: Names of request parameters that must not be copied
            verbatim into the client information
        :return: The updated client information dictionary, or an error
            response when one of the checks fails
        """
        if ignore is None:
            ignore = []

        _cinfo = self.cdb[client_id].copy()
        logger.debug("_cinfo: %s" % _cinfo)

        for key, val in list(request.items()):
            if key not in ignore:
                _cinfo[key] = val

        if "redirect_uris" in request:
            ruri = []
            for uri in request["redirect_uris"]:
                if urllib.parse.urlparse(uri).fragment:
                    err = ClientRegistrationErrorResponse(
                        error="invalid_configuration_parameter",
                        error_description="redirect_uri contains fragment")
                    return Response(err.to_json(),
                                    content="application/json",
                                    status="400 Bad Request")
                # Split at the last "?" (what the deprecated
                # urllib.parse.splitquery did): query is None when there is
                # no "?" at all and "" when the "?" is trailing.
                base, sep, query = uri.rpartition("?")
                if not sep:
                    base, query = uri, None
                if query:
                    ruri.append((base, urllib.parse.parse_qs(query)))
                else:
                    ruri.append((base, query))
            _cinfo["redirect_uris"] = ruri

        if "sector_identifier_uri" in request:
            si_url = request["sector_identifier_uri"]
            try:
                res = self.server.http_request(si_url)
            except ConnectionError as err:
                logger.error("%s" % err)
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            if not res:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            logger.debug("sector_identifier_uri => %s" % res.text)

            try:
                si_redirects = json.loads(res.text)
            except ValueError:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Error deserializing sector_identifier_uri content")

            if "redirect_uris" in request:
                logger.debug("redirect_uris: %s" % request["redirect_uris"])
                for uri in request["redirect_uris"]:
                    # Explicit test rather than `assert`: assertions are
                    # removed under `python -O`, which would silently skip
                    # this validation.
                    if uri not in si_redirects:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr="redirect_uri missing from sector_identifiers"
                        )

            _cinfo["si_redirects"] = si_redirects
            _cinfo["sector_id"] = si_url
        elif "redirect_uris" in request:
            if len(request["redirect_uris"]) > 1:
                # check that the hostnames are the same
                host = ""
                for url in request["redirect_uris"]:
                    part = urllib.parse.urlparse(url)
                    _host = part.netloc.split(":")[0]
                    if not host:
                        host = _host
                    elif host != _host:
                        # Explicit test rather than `assert` (see above)
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr=
                            "'sector_identifier_uri' must be registered")

        for item in ["policy_url", "logo_url"]:
            if item in request:
                if self._verify_url(request[item], _cinfo["redirect_uris"]):
                    _cinfo[item] = request[item]
                else:
                    return self._error_response(
                        "invalid_configuration_parameter",
                        descr="%s pointed to illegal URL" % item)

        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug("keys for %s: [%s]" % (
                    client_id,
                    ",".join(["%s" % x for x in self.keyjar[client_id]])))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            err = ClientRegistrationErrorResponse(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status="400 Bad Request")

        return _cinfo

    @staticmethod
    def comb_redirect_uris(args):
        if "redirect_uris" not in args:
            return

        val = []
        for base, query in args["redirect_uris"]:
            if query:
                val.append("%s?%s" % (base, query))
            else:
                val.append(base)

        args["redirect_uris"] = val

    #noinspection PyUnusedLocal
    def l_registration_endpoint(self, request, authn=None, **kwargs):
        """
        Handle a dynamic client registration request.

        Deserializes and verifies the JSON request, allocates a client id,
        a client secret and a registration access token, stores the client
        in ``self.cdb`` and answers with the registration response.

        :param request: The JSON-serialized registration request
        :param authn: Client authentication information (not used here)
        :param kwargs: Extra keyword arguments (not used here)
        :return: A Response carrying the registration response, or an error
            response when verification or registration fails
        """
        _log_debug = logger.debug
        _log_info = logger.info

        _log_debug("@registration_endpoint")

        request = RegistrationRequest().deserialize(request, "json")

        _log_info("registration_request:%s" % request.to_dict())

        try:
            request.verify()
        except MessageException as err:
            if "type" not in request:
                return self._error(error="invalid_type",
                                   descr="%s" % err)
            else:
                return self._error(error="invalid_configuration_parameter",
                                   descr="%s" % err)

        _keyjar = self.server.keyjar

        # create new id and secret
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_secret = secret(self.seed, client_id)

        _rat = rndstr(32)
        reg_enp = ""
        for endp in self.endpoints:
            # The endpoint list is filled with endpoint *classes* by the
            # constructor; a plain isinstance() test never matches a class,
            # so accept both classes and instances here.
            endp_cls = endp if isinstance(endp, type) else type(endp)
            if issubclass(endp_cls, DynamicClientEndpoint):
                reg_enp = "%s%s" % (self.baseurl, endp.etype)

        self.cdb[client_id] = {
            "client_id": client_id,
            "client_secret": client_secret,
            "registration_access_token": _rat,
            "registration_client_uri": "%s?client_id=%s" % (reg_enp, client_id),
            "client_secret_expires_at": utc_time_sans_frac() + 86400,
            "client_id_issued_at": utc_time_sans_frac()}

        self.cdb[_rat] = client_id

        _cinfo = self.do_client_registration(request, client_id,
                                             ignore=["redirect_uris",
                                                     "policy_url",
                                                     "logo_url"])
        if isinstance(_cinfo, Response):
            # NOTE(review): the provisional entries written to self.cdb above
            # are left in place when registration is rejected.
            return _cinfo

        args = {k: v for k, v in list(_cinfo.items())
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        # NOTE(review): do_client_registration has already loaded these keys
        # into self.keyjar; this second load looks redundant - confirm.
        self.keyjar.load_keys(request, client_id)

        # Add the key to the keyjar
        if client_secret:
            _kc = KeyBundle([{"kty": "oct", "key": client_secret,
                              "use": "ver"},
                             {"kty": "oct", "key": client_secret,
                              "use": "sig"}])
            try:
                _keyjar[client_id].append(_kc)
            except KeyError:
                _keyjar[client_id] = [_kc]

        self.cdb[client_id] = _cinfo
        _log_info("Client info: %s" % _cinfo)

        logger.debug("registration_response: %s" % response.to_dict())

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    def registration_endpoint(self, request, authn=None, **kwargs):
        """
        Public entry point for client registration.

        Delegates straight to ``l_registration_endpoint`` with the same
        arguments.

        :param request: The registration request
        :param authn: Client authentication information
        :param kwargs: Extra keyword arguments, passed through
        :return: Whatever ``l_registration_endpoint`` returns
        """
        handler = self.l_registration_endpoint
        return handler(request, authn, **kwargs)

    def read_registration(self, authn, request, **kwargs):
        """
        Read all information this server has on a client.
        Authorization is done by using the access token that was returned as
        part of the client registration result.

        :param authn: The Authorization HTTP header
        :param request: The query part of the URL
        :param kwargs: Any other arguments
        :return: A Response carrying the client's registration information
        :raises AssertionError: if the header is not a Bearer token or the
            client_id in the query does not belong to the token
        :raises KeyError: if the token, or the client_id query parameter, is
            unknown
        """

        logger.debug("authn: %s, request: %s" % (authn, request))

        # verify the access token, has to be key into the client information
        # database.
        # The checks raise AssertionError explicitly instead of using
        # `assert`: assert statements are removed under `python -O`, which
        # would disable these authorization checks entirely.
        if not authn.startswith("Bearer "):
            raise AssertionError("Authorization header is not a Bearer token")
        token = authn[len("Bearer "):]

        client_id = self.cdb[token]

        # extra check
        _info = urllib.parse.parse_qs(request)
        if _info["client_id"][0] != client_id:
            raise AssertionError("client_id does not match the access token")

        logger.debug("Client '%s' reads client info" % client_id)
        args = {k: v for k, v in list(self.cdb[client_id].items())
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    #noinspection PyUnusedLocal
    def providerinfo_endpoint(self, handle="", **kwargs):
        """
        Return the provider configuration information as JSON.

        :param handle: Optional (key, timestamp) pair; when the key starts
            and ends with ``STR`` a cookie built by ``self.cookie_func`` is
            added to the response headers
        :param kwargs: Extra keyword arguments (not used here)
        :return: A Response with the provider information, or a Response
            carrying the formatted traceback if anything goes wrong
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.conf_info

            _log_info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if handle:
                key, _timestamp = handle
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # "text/html" is the registered media type; the previous
            # "html/text" had type and subtype transposed.
            resp = Response(message, content="text/html")

        return resp
예제 #12
0
class Provider(provider.Provider):
    """A OAuth2 RP that knows all the OAuth2 extensions I've implemented."""

    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey=None,
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
        config=None,
        behavior=None,
        lifetime_policy=None,
        message_factory=ExtensionMessageFactory,
        **kwargs
    ):
        """
        Set up the provider and register its extension endpoints.

        ``name`` (with a trailing "/" appended if missing), ``sdb``, ``cdb``,
        ``authn_broker``, ``authz``, ``client_authn``, ``symkey``,
        ``urlmap``, ``iv``, ``default_scope``, ``ca_bundle`` and
        ``message_factory`` are forwarded to the parent constructor, as is a
        ``server_cls`` entry in ``kwargs`` when present.

        :param seed: Seed used when generating client secrets
        :param client_authn_methods: Dictionary of client authentication
            methods, keyed by method name
        :param authn_at_registration: Name of the client authentication
            method required at registration; empty for none
        :param client_info_url: Path component used when building a client's
            registration_client_uri
        :param secret_lifetime: Lifetime of a client secret in seconds
        :param jwks_uri: Stored as ``self.jwks_uri``
        :param keyjar: Key jar to use; a new one is created when None
        :param capabilities: Provider configuration passed to
            ``provider_features``
        :param verify_ssl: Stored as ``self.verify_ssl`` and passed to a
            newly created key jar
        :param baseurl: Base URL of the provider; defaults to ``name``
        :param hostname: Host name; defaults to ``socket.gethostname()``
        :param config: Configuration dictionary; defaults to empty
        :param behavior: Behavior dictionary; defaults to empty
        :param lifetime_policy: Token lifetimes per token type and grant or
            response type; a built-in default is used when None
        :param kwargs: May carry ``server_cls`` and ``scopes``
        :raises UnknownAuthnMethod: if ``authn_at_registration`` is given but
            is not one of ``client_authn_methods``
        """
        if not name.endswith("/"):
            name += "/"

        try:
            args = {"server_cls": kwargs["server_cls"]}
        except KeyError:
            args = {}

        super().__init__(
            name,
            sdb,
            cdb,
            authn_broker,
            authz,
            client_authn,
            symkey,
            urlmap,
            iv,
            default_scope,
            ca_bundle,
            message_factory=message_factory,
            **args
        )

        self.endp.extend(
            [
                RegistrationEndpoint,
                ClientInfoEndpoint,
                RevocationEndpoint,
                IntrospectionEndpoint,
            ]
        )

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # client_authn_methods defaults to None; treat that as "no known
            # methods" so the caller gets UnknownAuthnMethod rather than a
            # TypeError from the membership test.
            if (
                not client_authn_methods
                or authn_at_registration not in client_authn_methods
            ):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        self.scopes.extend(kwargs.get("scopes", []))
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict[str, str]]
        self.config = config or {}
        self.behavior = behavior or {}
        # Per token type: client id -> lifetime policy (see set_token_policy)
        self.token_policy = {
            "access_token": {},
            "refresh_token": {},
        }  # type: Dict[str, Dict[str, str]]
        if lifetime_policy is None:
            self.lifetime_policy = {
                "access_token": {
                    "code": 600,
                    "token": 120,
                    "implicit": 120,
                    "authorization_code": 600,
                    "client_credentials": 600,
                    "password": 600,
                },
                "refresh_token": {
                    "code": 3600,
                    "token": 3600,
                    "implicit": 3600,
                    "authorization_code": 3600,
                    "client_credentials": 3600,
                    "password": 3600,
                },
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(
            self.baseurl, self.token_policy, keyjar=self.keyjar
        )

    @staticmethod
    def _uris_to_tuples(uris):
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys into the key jar.

        The keys referenced by the request are loaded first; after that the
        client secret is added as a symmetric key for both verification and
        signing.

        :param request: The registration request
        :param client_id: Identifier the keys are stored under
        :param client_secret: The client's secret
        :return: None on success, or a 400 response describing the problem
            when the keys could not be loaded
        """
        jar = self.keyjar
        try:
            jar.load_keys(request, client_id)
            try:
                key_count = len(jar[client_id])
                logger.debug(
                    "Found {} keys for client_id={}".format(key_count, client_id)
                )
            except KeyError:
                pass
        except Exception as err:
            failure = "Failed to load client keys: {}"
            logger.error(failure.format(sanitize(request.to_dict())))
            logger.error("%s", err)
            error = ClientRegistrationError(
                error="invalid_configuration_parameter", error_description="%s" % err
            )
            return Response(
                error.to_json(),
                content="application/json",
                status_code="400 Bad Request",
            )

        # Add the client_secret as a symmetric key to the keyjar
        symmetric = KeyBundle(
            [
                {"kty": "oct", "key": client_secret, "use": usage}
                for usage in ("ver", "sig")
            ]
        )
        try:
            jar[client_id].append(symmetric)
        except KeyError:
            jar[client_id] = [symmetric]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Check client information against a set of restrictions.

        :param cinfo: The client information
        :param restrictions: Mapping of restriction name to its argument;
            each name is resolved to a check through ``restrict.factory``
        :raises RestrictionError: with the check's result, for the first
            check that returns something truthy
        """
        for fname, arg in restrictions.items():
            violation = restrict.factory(fname)(arg, cinfo)
            if violation:
                raise RestrictionError(violation)

    def set_token_policy(self, cid, cinfo):
        """
        Record the token lifetimes that apply to a client.

        For each token type, every entry listed under "response_type" or
        "grant_type" in the client information that has a lifetime in
        ``self.lifetime_policy`` is copied into
        ``self.token_policy[<token type>][cid]``.

        :param cid: The client identifier
        :param cinfo: The client information
        """
        for token_type in ("access_token", "refresh_token"):
            policy = {}
            for key in ("response_type", "grant_type"):
                try:
                    requested = cinfo[key]
                except KeyError:
                    continue
                for typ in requested:
                    try:
                        lifetime = self.lifetime_policy[token_type][typ]
                    except KeyError:
                        continue
                    policy[typ] = lifetime

            self.token_policy[token_type][cid] = policy

    def create_new_client(self, request, restrictions):
        """
        Create and store a new client from a registration request.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client. NOTE(review): this
            argument is not read here; the checks applied come from
            ``self.behavior["client_registration"]`` - confirm intent.
        :return: The client_id
        """
        client = request.to_dict()

        self.match_client_request(client)

        # Pick an identifier that is not already taken
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_secret = secret(self.seed, client_id)
        client["client_id"] = client_id
        client["client_secret"] = client_secret
        client["client_id_issued_at"] = utc_time_sans_frac()
        client["client_secret_expires_at"] = utc_time_sans_frac() + self.secret_lifetime

        # Only hand out management credentials if the client info endpoint
        # is served
        if ClientInfoEndpoint in self.endp:
            client["registration_access_token"] = rndstr(32)
            uri_parts = (
                self.name,
                self.client_info_url,
                ClientInfoEndpoint.etype,
                client_id,
            )
            client["registration_client_uri"] = "%s%s%s?client_id=%s" % uri_parts

        if "redirect_uris" in request:
            client["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        # NOTE(review): load_keys may return an error response; it is not
        # inspected here.
        self.load_keys(request, client_id, client["client_secret"])

        try:
            registration_rules = self.behavior["client_registration"]
        except KeyError:
            pass
        else:
            self.verify_correct(client, registration_rules)

        self.set_token_policy(client_id, client)
        self.cdb[client_id] = client

        return client_id

    def match_client_request(self, request):
        """
        Check a client's preferences against this provider's capabilities.

        For "response_types" every requested value must match one supported
        value as a set of space-separated words. For any other preference a
        string must be among the supported values and a collection must be a
        subset of them.

        :param request: The client's registration parameters
        :raises CapabilitiesMisMatch: if a preference is not supported
        """
        for pref, prov in PREFERENCE2PROVIDER.items():
            if pref not in request:
                continue

            if pref == "response_types":
                for wanted in request[pref]:
                    wanted_parts = set(wanted.split(" "))
                    supported = any(
                        wanted_parts == set(offered.split(" "))
                        for offered in self.capabilities[prov]
                    )
                    if not supported:
                        raise CapabilitiesMisMatch("Not allowed {}".format(pref))
            elif isinstance(request[pref], str):
                if request[pref] not in self.capabilities[prov]:
                    raise CapabilitiesMisMatch("Not allowed {}".format(pref))
            elif not set(request[pref]).issubset(set(self.capabilities[prov])):
                raise CapabilitiesMisMatch("Not allowed {}".format(pref))

    def client_info(self, client_id):
        """
        Build the response describing a registered client.

        :param client_id: The client identifier
        :return: A Response with the client information as JSON, or a
            BadRequest when the stored information is not valid
        """
        stored = self.cdb[client_id].copy()
        if not valid_client_info(stored):
            err = ErrorResponse(
                error="invalid_client", error_description="Invalid client secret"
            )
            return BadRequest(err.to_json(), content="application/json")

        # Redirect URIs are stored as (base, query) pairs; present them as
        # plain URIs again.
        try:
            stored["redirect_uris"] = self._tuples_to_uris(stored["redirect_uris"])
        except KeyError:
            pass

        response_cls = self.server.message_factory.get_response_type("update_endpoint")
        msg = response_cls(**stored)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace a client's stored information with an update request.

        Values from the request overwrite the stored ones; stored entries
        missing from the request are dropped, except for the issuing
        timestamps and the registration access token and URI.

        :param client_id: The client identifier
        :param request: The update request
        :raises ModificationForbidden: if the request carries a different
            client_id or client_secret than the stored one
        """
        info = self.cdb[client_id].copy()
        try:
            info["redirect_uris"] = self._tuples_to_uris(info["redirect_uris"])
        except KeyError:
            pass

        immutable = ("client_secret", "client_id")
        for key, value in request.items():
            if key not in immutable:
                info[key] = value
            elif value != info[key]:
                # assure it's the same
                raise ModificationForbidden("Not allowed to change")

        server_managed = (
            "client_id_issued_at",
            "client_secret_expires_at",
            "registration_access_token",
            "registration_client_uri",
        )
        stale = [
            key
            for key in list(info.keys())
            if key not in server_managed and key not in request
        ]
        for key in stale:
            del info[key]

        if "redirect_uris" in request:
            info["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.cdb[client_id] = info

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Verify the client based on credentials.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: The client identifier; when empty it is worked out
            from the request and the HTTP_AUTHORIZATION entry of ``environ``
        :return: The result of the authentication method's ``verify``
        :raises UnSupported: if ``authn_method`` is not a known method
        """
        if not client_id:
            authorization = environ["HTTP_AUTHORIZATION"]
            client_id = get_client_id(self.cdb, areq, authorization)

        try:
            verifier_cls = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()

        verifier = verifier_cls(self)
        return verifier.verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """
        Derive client restrictions from a software statement.

        This implementation ignores the statement and imposes no
        restrictions.

        :param software_statement: The parsed software statement
        :return: An empty dictionary
        """
        return dict()

    def registration_endpoint(self, **kwargs):
        """
        Perform dynamic client registration.

        :param kwargs: Keyword arguments; "request" holds the JSON
            registration request, and "environ" is read when authentication
            is required at registration
        :return: A Response instance
        """

        def _rejection(error_code, cause):
            # 400 response carrying a client registration error
            msg = ClientRegistrationError(
                error=error_code, error_description="%s" % cause
            )
            return BadRequest(msg.to_json(), content="application/json")

        request_cls = self.server.message_factory.get_request_type(
            "registration_endpoint"
        )
        _request = request_cls().deserialize(kwargs["request"], "json")
        try:
            _request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            return _rejection("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return _rejection("invalid_client_metadata", err)

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(
                    kwargs["environ"], _request, self.authn_at_registration
                )
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}  # type: ignore
        if "parsed_software_statement" in _request:
            for statement in _request["parsed_software_statement"]:
                client_restrictions.update(self.consume_software_statement(statement))
            del _request["software_statement"]
            del _request["parsed_software_statement"]

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            return _rejection("invalid_client_metadata", err)

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Read, update or delete a client's registration.

        The operation is selected by the HTTP method: GET returns the client
        information, PUT replaces it and DELETE removes the client.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; "query", "environ" and "request"
            are read
        :return: A Response instance; None for any other HTTP method
        """

        def _rejection(error_code, cause):
            # 400 response carrying a client registration error
            msg = ClientRegistrationError(
                error=error_code, error_description="%s" % cause
            )
            return BadRequest(msg.to_json(), content="application/json")

        _query = compact(parse_qs(kwargs["query"]))
        try:
            _id = _query["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(
                kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id
            )
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)

        if method == "PUT":
            request_cls = self.server.message_factory.get_request_type(
                "update_endpoint"
            )
            try:
                _request = request_cls().from_json(kwargs["request"])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                return _rejection("invalid_redirect_uri", err)
            except (MissingPage, VerificationError) as err:
                return _rejection("invalid_client_metadata", err)

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()

        if method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            return NoContent()

        return None

    @staticmethod
    def verify_code_challenge(
        code_verifier, code_challenge, code_challenge_method="S256"
    ):
        """
        Verify a PKCE (RFC7636) code challenge.

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into ``CC_METHOD`` naming the
            transformation that was applied to the verifier
        :return: True when the verifier matches the challenge, otherwise a
            401 response carrying a token error. Callers must test for
            ``True`` explicitly - the error response may be truthy as well.
        """
        try:
            transform = CC_METHOD[code_challenge_method]
            verifier = code_verifier.encode("ascii")
        except (KeyError, UnicodeEncodeError):
            # An unknown method or a non-ASCII verifier cannot match any
            # challenge; report it as a failed check rather than crashing.
            matches = False
        else:
            _cc = b64e(transform(verifier).digest())
            matches = _cc.decode("ascii") == code_challenge

        if not matches:
            logger.error("PCKE Code Challenge check failed")
            err = TokenErrorResponse(
                error="invalid_request", error_description="PCKE check failed"
            )
            return Response(err.to_json(), content="application/json", status_code=401)
        return True

    def do_access_token_response(self, access_token, atinfo, state, refresh_token=None):
        """
        Build the token endpoint response message.

        :param access_token: The access token
        :param atinfo: Information about the access token; "exp" is read and
            "scope" is used when present
        :param state: The state value to echo back
        :param refresh_token: Refresh token to include, if any
        :return: An instance of the message factory's token endpoint
            response type
        """
        token_info = {
            "access_token": access_token,
            # NOTE(review): "expires_in" is filled straight from "exp";
            # confirm that "exp" holds a lifetime, not a timestamp.
            "expires_in": atinfo["exp"],
            "token_type": "bearer",
            "state": state,
        }
        try:
            token_info["scope"] = atinfo["scope"]
        except KeyError:
            pass

        if refresh_token:
            token_info["refresh_token"] = refresh_token

        response_cls = self.server.message_factory.get_response_type("token_endpoint")
        accepted = by_schema(response_cls, **token_info)
        return response_cls(**accepted)

    def code_grant_type(self, areq):
        """
        Exchange an authorization code for tokens.

        Looks up the code in the session database, checks the PKCE code
        verifier (when one is sent), the state, the scope and the redirect
        URI against the original authorization request, then upgrades the
        code to a token.

        :param areq: The access token request
        :return: A Response with the token response, or an error response
            when any of the checks fail
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Unknown access grant"
            )
            return Response(
                err.to_json(), content="application/json", status="401 Unauthorized"
            )

        authzreq = json.loads(_info["authzreq"])
        # NOTE(review): the verifier is only checked when the client sends
        # one; a request that omits it skips the PKCE check even if the
        # authorization request carried a code_challenge.
        if "code_verifier" in areq:
            try:
                _method = authzreq["code_challenge_method"]
            except KeyError:
                _method = "S256"

            resp = self.verify_code_challenge(
                areq["code_verifier"], authzreq["code_challenge"], _method
            )
            # verify_code_challenge returns True on success and an error
            # response on failure. Testing plain truthiness would return
            # True to the caller on every successful check.
            if resp is not True:
                return resp

        if "state" in areq:
            if self.sdb[areq["code"]]["state"] != areq["state"]:
                logger.error("State value mismatch")
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. A request that
        # leaves it out is rejected the same way instead of failing on the
        # lookup.
        if "redirect_uri" in _info and (
            "redirect_uri" not in areq
            or areq["redirect_uri"] != _info["redirect_uri"]
        ):
            logger.error("Redirect_uri mismatch")
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        issue_refresh = False
        if "scope" in authzreq and "offline_access" in authzreq["scope"]:
            if authzreq["response_type"] == "code":
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(
                areq["code"], issue_refresh=issue_refresh
            )
        except AccessCodeUsed:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Access grant used"
            )
            return Response(
                err.to_json(), content="application/json", status="401 Unauthorized"
            )

        logger.debug("_tinfo: %s" % _tinfo)

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    def client_credentials_grant_type(self, areq):
        """
        Issue tokens for the client credentials grant.

        An access token is always issued; a refresh token is added unless
        the token handler raises NotAllowed for it.

        :param areq: The access token request; "client_id", "scope" and
            "state" are read
        :return: A Response with the token response as JSON
        """
        handler = self.token_handler
        access_token = handler.get_access_token(
            areq["client_id"], scope=areq["scope"], grant_type="client_credentials"
        )
        token_info = handler.token_factory.get_info(access_token)
        try:
            refresh_token = handler.get_refresh_token(
                self.baseurl, token_info["access_token"], "client_credentials"
            )
        except NotAllowed:
            response_args = (access_token, token_info, areq["state"])
        else:
            response_args = (access_token, token_info, areq["state"], refresh_token)

        atr = self.do_access_token_response(*response_args)
        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """
        Handle the resource owner password credentials grant.

        RFC6749 section 4.3

        :param areq: The token request; ``username``, ``password`` and
            ``client_id`` are read from it and ``response_type`` is overwritten.
        :return: A Response with the token endpoint message, or an
            Unauthorized ``invalid_grant`` response.
        """

        def _invalid_grant():
            failure = TokenErrorResponse(error="invalid_grant")
            return Unauthorized(failure.to_json(), content="application/json")

        # Picking with "any" goes for the first broker entry, so the outcome
        # is either an authentication method or an IndexError.
        try:
            method, class_ref = self.pick_auth(areq, "any")
        except IndexError:
            return _invalid_grant()

        user, when = method.authenticated_as(
            username=areq["username"], password=areq["password"]
        )
        if user is None:
            return _invalid_grant()

        # From here on a token is going to be returned.
        areq["response_type"] = ["token"]
        event = AuthnEvent(
            user["uid"],
            user.get("salt", ""),
            authn_info=class_ref,
            time_stamp=when,
        )
        client_info = self.cdb[areq["client_id"]]
        session_id = self.setup_session(areq, event, client_info)
        grant_code = self.sdb[session_id]["code"]
        token_info = self.sdb.upgrade_to_token(grant_code, issue_refresh=True)

        response_cls = self.server.message_factory.get_response_type("token_endpoint")
        message = response_cls(**by_schema(response_cls, **token_info))
        return Response(
            message.to_json(),
            content="application/json",
            headers=OAUTH2_NOCACHE_HEADERS,
        )

    def refresh_token_grant_type(self, areq):
        """
        Exchange a refresh request for a new access token response.

        :param areq: The token request; ``access_token`` is read from it.
        :return: A JSON Response carrying the token endpoint message.
        """
        at = self.token_handler.refresh_access_token(
            self.baseurl, areq["access_token"], "refresh_token"
        )

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **at))
        # Responses that carry tokens must not be cached (RFC6749 section 5.1);
        # the other token responses of this class already send these headers.
        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        # simple rules: if client_id in azp or aud it's allow to introspect
        # to revoke it has to be in azr
        allow = False
        if endpoint == "revocation_endpoint":
            if "azr" in token_info and client_id == token_info["azr"]:
                allow = True
            elif len(token_info["aud"]) == 1 and token_info["aud"] == [client_id]:
                allow = True
        else:  # has to be introspection endpoint
            if "azr" in token_info and client_id == token_info["azr"]:
                allow = True
            elif "aud" in token_info:
                if client_id in token_info["aud"]:
                    allow = True
        return allow

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up information about a token.

        :param authn: Client authentication information, handed to
            ``self.client_authn``
        :param req: The request message; ``token`` and, optionally,
            ``token_type_hint`` are read from it
        :param endpoint: Name of the endpoint being served; used for logging
            and for the access check in ``token_access``
        :return: A ``(client_id, token_type, token_info)`` tuple on success.
            Otherwise a response object: a 401 Response when client
            authentication fails, the "inactive" response when the token
            cannot be read, or BadRequest when the client may not access
            the token.
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            logger.error(err)
            error = TokenErrorResponse(
                error="unauthorized_client", error_description="%s" % err
            )
            return Response(
                error.to_json(), content="application/json", status="401 Unauthorized"
            )

        logger.debug("{}: {} requesting {}".format(endpoint, client_id, req.to_dict()))

        try:
            token_type = req["token_type_hint"]
        except KeyError:
            # No hint given: try to read the token as an access token first
            # and fall back to treating it as a refresh token.
            try:
                _info = self.sdb.token_factory["access_token"].get_info(req["token"])
            except Exception:
                try:
                    _info = self.sdb.token_factory["refresh_token"].get_info(
                        req["token"]
                    )
                except Exception:
                    # Neither factory could read the token.
                    return self._return_inactive()
                else:
                    token_type = "refresh_token"
            else:
                token_type = "access_token"
        else:
            # A hint was given: only the hinted factory is consulted.
            try:
                _info = self.sdb.token_factory[token_type].get_info(req["token"])
            except Exception:
                return self._return_inactive()

        # NOTE(review): callers tell failures apart with
        # isinstance(resp, Response); presumably BadRequest is a Response
        # subclass - confirm.
        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    def _return_inactive(self):
        """Build the JSON response that reports a token as not active."""
        response_cls = self.server.message_factory.get_response_type(
            "introspection_endpoint"
        )
        inactive = response_cls(active=False)
        return Response(inactive.to_json(), content="application/json")

    def revocation_endpoint(self, authn="", request=None, **kwargs):
        """
        Token revocation (RFC7009): invalidate an access or refresh token.

        :param authn: Client Authentication information
        :param request: The urlencoded revocation request
        :param kwargs: Extra arguments, not used here
        :return: Response "OK" when the token was invalidated, BadRequest on
            a KeyError during invalidation, or the failure response produced
            by ``get_token_info``
        """
        request_cls = self.server.message_factory.get_request_type(
            "revocation_endpoint"
        )
        trr = request_cls().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, "revocation_endpoint")
        if isinstance(outcome, Response):
            # Authentication, lookup or access check failed.
            return outcome
        client_id, token_type, _info = outcome

        logger.info("{} token revocation: {}".format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr["token"])
        except KeyError:
            return BadRequest()
        return Response("OK")

    def introspection_endpoint(self, authn="", request=None, **kwargs):
        """
        Token introspection (RFC7662).

        :param authn: Client Authentication information
        :param request: The urlencoded introspection request
        :param kwargs: Extra arguments, not used here
        :return: A JSON Response with the introspection message, or the
            failure response produced by ``get_token_info``
        """
        request_cls = self.server.message_factory.get_request_type(
            "introspection_endpoint"
        )
        tir = request_cls().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, "introspection_endpoint")
        if isinstance(outcome, Response):
            # Authentication, lookup or access check failed.
            return outcome
        client_id, token_type, _info = outcome

        logger.info("{} token introspection: {}".format(client_id, tir.to_dict()))

        response_cls = self.server.message_factory.get_response_type(
            "introspection_endpoint"
        )
        is_active = self.sdb.token_factory[token_type].is_valid(_info)
        claims = _info.to_dict()
        ir = response_cls(active=is_active, **claims)
        ir.weed()

        return Response(ir.to_json(), content="application/json")
예제 #13
0
def http_request(url, **kwargs):
    """Call ``request`` with the GET method for ``url``, forwarding ``kwargs``."""
    method = "GET"
    return request(method, url, **kwargs)

# Provider metadata used to build the configuration response below.
info = {
    "registration_endpoint": "https://openidconnect.info/connect/register",
    "userinfo_endpoint": "https://openidconnect.info/connect/userinfo",
    "token_endpoint_auth_types_supported": [
        "client_secret_basic",
        "client_secret_post",
        "client_secret_jwt",
        "private_key_jwt",
    ],
    "jwk_url": "https://openidconnect.info/jwk/jwk.json",
    "userinfo_algs_supported": ["HS256", "HS384", "HS512", "RS256", "RS384", "RS512"],
    "user_id_types_supported": ["pairwise", "public"],
    "scopes_supported": ["openid", "profile", "email", "address", "phone"],
    "token_endpoint": "https://openidconnect.info/connect/token",
    "id_token_algs_supported": ["HS256", "HS384", "HS512", "RS256", "RS384", "RS512"],
    "version": "3.0",
    "token_endpoint_auth_algs_supported": ["RS256", "RS384", "RS512"],
    "request_object_algs_supported": [
        "HS256",
        "HS384",
        "HS512",
        "RS256",
        "RS384",
        "RS512",
    ],
    "response_types_supported": [
        "code",
        "token",
        "id_token",
        "id_token token",
        "code token",
        "code id_token",
        "code id_token token",
    ],
    "authorization_endpoint": "https://openidconnect.info/connect/authorize",
    "acrs_supported": ["1"],
    "check_id_endpoint": "https://openidconnect.info/connect/check_session",
    "x509_url": "https://openidconnect.info/x509/cert.pem",
    "issuer": "https://openidconnect.info",
}

pi = ProviderConfigurationResponse(**info)

kj = KeyJar()

# File the provider's keys in the key jar under its issuer identifier.
kj.load_keys(pi, pi["issuer"])

keys = kj.get("ver", "rsa", "https://openidconnect.info")

# Call form of print: parses on Python 3 and, with a single argument,
# prints the same thing on Python 2.
print(keys)
예제 #14
0
def test_provider():
    """Loading keys from the provider configuration fills the issuer's entry."""
    issuer = "https://connect-op.heroku.com"

    jar = KeyJar()
    provider_conf = ProviderConfigurationResponse().from_dict(PROVIDER_INFO)
    jar.load_keys(provider_conf, issuer)

    assert jar[issuer]
예제 #15
0
파일: x_provider.py 프로젝트: rohe/pyuma
class Provider(provider.Provider):
    def __init__(self,
                 name,
                 sdb,
                 cdb,
                 authn_broker,
                 authz,
                 client_authn,
                 symkey,
                 urlmap=None,
                 keyjar=None,
                 hostname="",
                 configuration=None):
        """
        Set up the provider.

        ``name``, ``sdb``, ``cdb``, ``authn_broker``, ``authz``,
        ``client_authn``, ``symkey`` and ``urlmap`` are handed unchanged to
        ``provider.Provider.__init__``.

        :param keyjar: Key jar to use; a new KeyJar is created when none
            (or an empty one) is given
        :param hostname: Host name of this server; defaults to the name
            reported by ``socket.gethostname()``
        :param configuration: Accepted but not used by this constructor
        """
        provider.Provider.__init__(self,
                                   name,
                                   sdb,
                                   cdb,
                                   authn_broker,
                                   authz,
                                   client_authn,
                                   symkey=symkey,
                                   urlmap=urlmap)

        self.baseurl = ""
        self.keyjar = keyjar if keyjar else KeyJar()

        # gethostname has to be called: storing the function itself left
        # self.hostname as a callable rather than a host name string.
        self.hostname = hostname or socket.gethostname()
        self.jwks_uri = []
        # Endpoint classes (not instances) served by this provider.
        self.endpoints = [
            DynamicClientEndpoint, TokenEndpoint, AuthorizationEndpoint,
            UserEndpoint, ResourceSetRegistrationEndpoint,
            IntrospectionEndpoint, RPTEndpoint, PermissionRegistrationEndpoint
        ]

    def set_authn_broker(self, authn_broker):
        """
        Install the authentication broker on this server.

        Every method held by the broker gets this server as its ``srv``
        attribute, and the first method supplies the cookie factory.

        :param authn_broker: Sequence of authentication methods; may be
            empty or None, in which case no cookie factory is set
        """
        self.authn_broker = authn_broker
        if authn_broker:
            for meth in authn_broker:
                meth.srv = self
            self.cookie_func = authn_broker[0].create_cookie
        else:
            # Nothing to iterate over; previously a None broker raised
            # TypeError in the loop before this branch could be reached.
            self.cookie_func = None

    @staticmethod
    def _verify_url(url, urlset):
        part = urllib.parse.urlparse(url)

        for reg, qp in urlset:
            _part = urllib.parse.urlparse(reg)
            if part.scheme == _part.scheme and part.netloc == _part.netloc:
                return True

        return False

    def do_client_registration(self, request, client_id, ignore=None):
        """
        Merge a registration request into the stored client information.

        :param request: The registration request
        :param client_id: Identifier of a client already present in
            ``self.cdb``
        :param ignore: Request keys that must not be copied verbatim into
            the client information
        :return: The updated client information dictionary, or an error
            response when the request is rejected
        """
        if ignore is None:
            ignore = []

        _cinfo = self.cdb[client_id].copy()
        logger.debug("_cinfo: %s" % _cinfo)

        for key, val in list(request.items()):
            if key not in ignore:
                _cinfo[key] = val

        if "redirect_uris" in request:
            # Store every redirect URI as a (base, parsed query) pair.
            ruri = []
            for uri in request["redirect_uris"]:
                if urllib.parse.urlparse(uri).fragment:
                    error = ClientRegistrationErrorResponse(
                        error="invalid_configuration_parameter",
                        error_description="redirect_uri contains fragment")
                    return Response(error.to_json(),
                                    content="application/json",
                                    status="400 Bad Request")
                base, query = urllib.parse.splitquery(uri)
                if query:
                    ruri.append((base, urllib.parse.parse_qs(query)))
                else:
                    ruri.append((base, query))
            _cinfo["redirect_uris"] = ruri

        if "sector_identifier_uri" in request:
            si_url = request["sector_identifier_uri"]
            try:
                res = self.server.http_request(si_url)
            except ConnectionError as err:
                logger.error("%s" % err)
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            if not res:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            logger.debug("sector_identifier_uri => %s" % res.text)

            try:
                si_redirects = json.loads(res.text)
            except ValueError:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Error deserializing sector_identifier_uri content")

            if "redirect_uris" in request:
                logger.debug("redirect_uris: %s" % request["redirect_uris"])
                for uri in request["redirect_uris"]:
                    # Explicit test instead of assert: assert statements are
                    # removed under "python -O", which would skip the check.
                    if uri not in si_redirects:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr="redirect_uri missing from sector_identifiers"
                        )

            _cinfo["si_redirects"] = si_redirects
            _cinfo["sector_id"] = si_url
        elif "redirect_uris" in request:
            if len(request["redirect_uris"]) > 1:
                # check that the hostnames are the same
                host = ""
                for url in request["redirect_uris"]:
                    part = urllib.parse.urlparse(url)
                    _host = part.netloc.split(":")[0]
                    if not host:
                        host = _host
                    elif host != _host:
                        # Explicit test instead of assert, see above.
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr=
                            "'sector_identifier_uri' must be registered")

        for item in ["policy_url", "logo_url"]:
            if item in request:
                if self._verify_url(request[item], _cinfo["redirect_uris"]):
                    _cinfo[item] = request[item]
                else:
                    return self._error_response(
                        "invalid_configuration_parameter",
                        descr="%s pointed to illegal URL" % item)

        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug("keys for %s: [%s]" % (client_id, ",".join(
                    ["%s" % x for x in self.keyjar[client_id]])))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            error = ClientRegistrationErrorResponse(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(error.to_json(),
                            content="application/json",
                            status="400 Bad Request")

        return _cinfo

    @staticmethod
    def comb_redirect_uris(args):
        if "redirect_uris" not in args:
            return

        val = []
        for base, query in args["redirect_uris"]:
            if query:
                val.append("%s?%s" % (base, query))
            else:
                val.append(base)

        args["redirect_uris"] = val

    #noinspection PyUnusedLocal
    def l_registration_endpoint(self, request, authn=None, **kwargs):
        """
        Register a new client.

        A fresh client id, client secret and registration access token are
        created and stored in ``self.cdb`` before the request is merged in
        by ``do_client_registration``.

        :param request: The JSON serialized registration request
        :param authn: Not used
        :param kwargs: Not used
        :return: A JSON Response with the registration response, or an
            error response when the request is rejected
        """
        logger.debug("@registration_endpoint")

        request = RegistrationRequest().deserialize(request, "json")

        logger.info("registration_request:%s" % request.to_dict())

        try:
            request.verify()
        except MessageException as err:
            if "type" not in request:
                return self._error(error="invalid_type", descr="%s" % err)
            else:
                return self._error(error="invalid_configuration_parameter",
                                   descr="%s" % err)

        _keyjar = self.server.keyjar

        # create new id and secret
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_secret = secret(self.seed, client_id)

        _rat = rndstr(32)
        reg_enp = ""
        for endp in self.endpoints:
            # __init__ fills self.endpoints with endpoint classes, which a
            # plain isinstance() test never matches; accept both classes
            # and instances.
            if isinstance(endp, type):
                is_registration = issubclass(endp, DynamicClientEndpoint)
            else:
                is_registration = isinstance(endp, DynamicClientEndpoint)
            if is_registration:
                reg_enp = "%s%s" % (self.baseurl, endp.etype)

        self.cdb[client_id] = {
            "client_id": client_id,
            "client_secret": client_secret,
            "registration_access_token": _rat,
            "registration_client_uri":
            "%s?client_id=%s" % (reg_enp, client_id),
            "client_secret_expires_at": utc_time_sans_frac() + 86400,
            "client_id_issued_at": utc_time_sans_frac()
        }

        # Lets read_registration map the access token back to the client.
        self.cdb[_rat] = client_id

        _cinfo = self.do_client_registration(
            request,
            client_id,
            ignore=["redirect_uris", "policy_url", "logo_url"])
        if isinstance(_cinfo, Response):
            return _cinfo

        args = {k: v for k, v in _cinfo.items()
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        self.keyjar.load_keys(request, client_id)

        # Add the key to the keyjar
        if client_secret:
            _kc = KeyBundle([{
                "kty": "oct",
                "key": client_secret,
                "use": "ver"
            }, {
                "kty": "oct",
                "key": client_secret,
                "use": "sig"
            }])
            try:
                _keyjar[client_id].append(_kc)
            except KeyError:
                _keyjar[client_id] = [_kc]

        self.cdb[client_id] = _cinfo
        logger.info("Client info: %s" % _cinfo)

        logger.debug("registration_response: %s" % response.to_dict())

        return Response(response.to_json(),
                        content="application/json",
                        headers=[("Cache-Control", "no-store")])

    def registration_endpoint(self, request, authn=None, **kwargs):
        """Handle a registration request by delegating to ``l_registration_endpoint``."""
        result = self.l_registration_endpoint(request, authn, **kwargs)
        return result

    def read_registration(self, authn, request, **kwargs):
        """
        Read all information this server has on a client.

        Authorization is done by using the access token that was returned as
        part of the client registration result.

        :param authn: The Authorization HTTP header
        :param request: The query part of the URL
        :param kwargs: Any other arguments
        :return: A JSON Response with the stored registration information
        :raises AssertionError: if the header is not a Bearer token or the
            ``client_id`` in the query does not belong to the token
        :raises KeyError: if the token or the ``client_id`` query parameter
            is unknown
        """

        logger.debug("authn: %s, request: %s" % (authn, request))

        # verify the access token, has to be key into the client information
        # database.
        # Raised explicitly instead of using assert statements, which are
        # removed under "python -O" and would disable these checks. The
        # exception type is kept for existing callers.
        if not authn.startswith("Bearer "):
            raise AssertionError("Authorization header is not a Bearer token")
        token = authn[len("Bearer "):]

        client_id = self.cdb[token]

        # extra check
        _info = urllib.parse.parse_qs(request)
        if _info["client_id"][0] != client_id:
            raise AssertionError(
                "client_id does not match the registration access token")

        logger.debug("Client '%s' reads client info" % client_id)
        args = {k: v for k, v in self.cdb[client_id].items()
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        return Response(response.to_json(),
                        content="application/json",
                        headers=[("Cache-Control", "no-store")])

    #noinspection PyUnusedLocal
    def providerinfo_endpoint(self, handle="", **kwargs):
        """
        Return this provider's configuration information.

        :param handle: Optional ``(key, timestamp)`` pair; when the key is
            wrapped in ``STR`` a cookie header is added to the response
        :param kwargs: Not used
        :return: A JSON Response with ``self.conf_info``, or a plain-text
            Response carrying the formatted traceback if anything fails
        """
        logger.info("@providerinfo_endpoint")
        try:
            _response = self.conf_info

            logger.info("provider_info_response: %s" % (_response.to_dict(), ))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if handle:
                key, _timestamp = handle
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(),
                            content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # "html/text" is not a valid media type; the traceback is plain
            # text, so it is labelled as such.
            # NOTE(review): this hands the server traceback to the caller -
            # consider a generic error body instead.
            resp = Response(message, content="text/plain")

        return resp
예제 #16
0
파일: dynreg.py 프로젝트: programDEV/pyoidc
class Client(oauth2.Client):
    def __init__(self,
                 client_id=None,
                 ca_certs=None,
                 client_authn_method=None,
                 keyjar=None,
                 verify_ssl=True):
        """
        Set up an OAuth2 client that also knows the dynamic registration
        requests.

        All arguments are passed on to ``oauth2.Client.__init__``.
        """
        base_args = {
            "client_id": client_id,
            "ca_certs": ca_certs,
            "client_authn_method": client_authn_method,
            "keyjar": keyjar,
            "verify_ssl": verify_ssl,
        }
        oauth2.Client.__init__(self, **base_args)

        self.allow = {}
        # Endpoints that the registration related requests are sent to.
        registration_endpoints = {
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint",
        }
        self.request2endpoint.update(registration_endpoints)
        self.registration_response = None

    def construct_RegistrationRequest(self,
                                      request=RegistrationRequest,
                                      request_args=None,
                                      extra_args=None,
                                      **kwargs):
        """
        Build a registration request.

        :param request: The request class to instantiate
        :param request_args: Arguments for the request; defaults to none
        :param extra_args: Extra arguments, passed on as they are
        :param kwargs: Accepted but not used
        :return: Whatever ``construct_request`` returns
        """
        args = {} if request_args is None else request_args
        return self.construct_request(request, args, extra_args)

    def do_client_registration(self,
                               request=RegistrationRequest,
                               body_type="",
                               method="GET",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a client registration request and return the parsed response.

        :param request: The request class
        :param body_type: Expected format of the response body
        :param method: HTTP method to use
        :param request_args: Arguments for the request
        :param extra_args: Extra arguments for the request
        :param http_args: HTTP arguments supplied by the caller; the ones
            computed by ``request_info`` are merged into this dictionary
        :param response_cls: Class used for the response
        :param kwargs: Passed on to ``request_info``
        :return: Whatever ``request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method,
                                                     request_args, extra_args,
                                                     **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in the arguments computed for this request; updating
            # http_args with itself was a no-op that dropped them.
            http_args.update(ht_args)

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       http_args=http_args)

        return resp

    def do_client_read_request(self,
                               request=ClientUpdateRequest,
                               body_type="",
                               method="GET",
                               request_args=None,
                               extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a client read request and return the parsed response.

        :param request: The request class
        :param body_type: Expected format of the response body
        :param method: HTTP method to use
        :param request_args: Arguments for the request
        :param extra_args: Extra arguments for the request
        :param http_args: HTTP arguments supplied by the caller; the ones
            computed by ``request_info`` are merged into this dictionary
        :param response_cls: Class used for the response
        :param kwargs: Passed on to ``request_info``
        :return: Whatever ``request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method,
                                                     request_args, extra_args,
                                                     **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in the arguments computed for this request; updating
            # http_args with itself was a no-op that dropped them.
            http_args.update(ht_args)

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       http_args=http_args)

        return resp

    def do_client_update_request(self,
                                 request=ClientUpdateRequest,
                                 body_type="",
                                 method="PUT",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a client update request and return the parsed response.

        :param request: The request class
        :param body_type: Expected format of the response body
        :param method: HTTP method to use
        :param request_args: Arguments for the request
        :param extra_args: Extra arguments for the request
        :param http_args: HTTP arguments supplied by the caller; the ones
            computed by ``request_info`` are merged into this dictionary
        :param response_cls: Class used for the response
        :param kwargs: Passed on to ``request_info``
        :return: Whatever ``request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method,
                                                     request_args, extra_args,
                                                     **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in the arguments computed for this request; updating
            # http_args with itself was a no-op that dropped them.
            http_args.update(ht_args)

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       http_args=http_args)

        return resp

    def do_client_delete_request(self,
                                 request=ClientUpdateRequest,
                                 body_type="",
                                 method="DELETE",
                                 request_args=None,
                                 extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a client delete request and return the parsed response.

        :param request: The request class
        :param body_type: Expected format of the response body
        :param method: HTTP method to use
        :param request_args: Arguments for the request
        :param extra_args: Extra arguments for the request
        :param http_args: HTTP arguments supplied by the caller; the ones
            computed by ``request_info`` are merged into this dictionary
        :param response_cls: Class used for the response
        :param kwargs: Passed on to ``request_info``
        :return: Whatever ``request_and_return`` returns
        """
        url, body, ht_args, _csi = self.request_info(request, method,
                                                     request_args, extra_args,
                                                     **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in the arguments computed for this request; updating
            # http_args with itself was a no-op that dropped them.
            http_args.update(ht_args)

        resp = self.request_and_return(url,
                                       response_cls,
                                       method,
                                       body,
                                       body_type,
                                       http_args=http_args)

        return resp

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: if the issuer in the response differs from the
            expected one and ``self.allow`` has no "issuer_mismatch" entry
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Compare modulo one trailing slash: give the expected issuer
            # the same ending as the advertised one.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # Explicit comparison instead of an assert statement: asserts
            # are removed under "python -O", which would silently accept a
            # configuration published by a different issuer.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" %
                    (_issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self,
                        issuer,
                        keys=True,
                        endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and process the configuration of a provider.

        :param issuer: The issuer identifier; a trailing slash is dropped
            before it is put into ``serv_pattern``
        :param keys: Passed on to ``handle_provider_config``
        :param endpoints: Passed on to ``handle_provider_config``
        :param response_cls: Class used to parse the JSON response
        :param serv_pattern: Format string giving the configuration URL
        :return: The parsed provider configuration response
        :raises PyoidcError: if no configuration could be fetched
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # Follow 302 redirects until some other status comes back.
        while r.status_code == 302:
            r = self.http_request(r.headers["location"])

        pcr = None
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """
        Remember a registration response and copy its ``client_secret``,
        ``client_id`` and ``redirect_uris`` onto this client.

        :param reginfo: The registration response
        """
        self.registration_response = reginfo
        for attribute in ("client_secret", "client_id", "redirect_uris"):
            setattr(self, attribute, reginfo[attribute])

    def handle_registration_info(self, response):
        """
        Process the HTTP response to a registration request.

        :param response: The HTTP response
        :return: The parsed ClientInfoResponse, which is also stored on
            this client
        :raises PyoidcError: if the status code is anything but 200
        """
        if response.status_code != 200:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        resp = ClientInfoResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)
        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The result of ``handle_registration_info``
        """
        registration = self.construct_RegistrationRequest(request_args=kwargs)
        payload = registration.to_json()

        # The request is POSTed as a JSON document.
        reply = self.http_request(
            url,
            "POST",
            data=payload,
            headers={"content-type": "application/json"},
        )
        return self.handle_registration_info(reply)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :param query: The urlencoded response
        :return: The parsed response
        :raises AuthzError: if the response is an ErrorResponse
        """
        parse_args = {
            "info": query,
            "sformat": "urlencoded",
            "keyjar": self.keyjar,
        }
        aresp = self.parse_response(AuthorizationResponse, **parse_args)

        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s" % aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s" % aresp)
        return aresp
예제 #17
0
class Client(oauth2.Client):
    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True,
                 config=None):
        """
        Set up an OAuth2 client that also knows the extension requests.

        All arguments are handed unchanged to ``oauth2.Client.__init__``.
        On top of that the request-to-endpoint mapping is extended with the
        registration, client update, token introspection and token
        revocation requests.
        """
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl,
                               config=config)
        self.allow = {}
        # Extend a private copy: the mapping set up by the base class may be
        # an object shared between clients, and updating that in place would
        # leak these entries into every other client.
        self.request2endpoint = dict(self.request2endpoint)
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint",
            'TokenIntrospectionRequest': 'introspection_endpoint',
            'TokenRevocationRequest': 'revocation_endpoint'
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Build a registration request message.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; None means none
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Accepted but not used
        :return: The result of construct_request
        """
        arguments = {} if request_args is None else request_args
        return self.construct_request(request, arguments, extra_args)

    def construct_ClientUpdateRequest(self, request=ClientUpdateRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Build a client update request message.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; None means none
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Accepted but not used
        :return: The result of construct_request
        """
        arguments = {} if request_args is None else request_args
        return self.construct_request(request, arguments, extra_args)

    def _token_interaction_setup(self, request_args=None, **kwargs):
        if request_args is None or 'token' not in request_args:
            token = self.get_token(**kwargs)
            try:
                _token_type_hint = kwargs['token_type_hint']
            except KeyError:
                _token_type_hint = 'access_token'

            request_args = {'token_type_hint': _token_type_hint,
                            'token': getattr(token, _token_type_hint)}

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return request_args

    def construct_TokenIntrospectionRequest(self,
                                            request=TokenIntrospectionRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """
        Build a token introspection request message.

        :param request: The message class to instantiate
        :param request_args: Request arguments; completed by
            _token_interaction_setup
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Passed on to _token_interaction_setup
        :return: The result of construct_request
        """
        prepared = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, prepared, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None, extra_args=None,
                                         **kwargs):
        """
        Build a token revocation request message.

        :param request: The message class to instantiate
        :param request_args: Request arguments; completed by
            _token_interaction_setup
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Passed on to _token_interaction_setup
        :return: The result of construct_request
        """
        prepared = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, prepared, extra_args)

    def do_op(self, request, body_type='', method='GET', request_args=None,
              extra_args=None, http_args=None, response_cls=None, **kwargs):
        """
        Build a request, send it and return the handled response.

        :param request: The request message class
        :param body_type: Expected format of the response body
        :param method: The HTTP method to use
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Arguments for the HTTP client; when given they are
            updated in place with the ones request_info produces
        :param response_cls: The response message class
        :param kwargs: Passed on to request_info
        :return: Whatever request_and_return returns
        """
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in what request_info produced (for instance
            # authentication headers) so it is not lost when the caller
            # supplies HTTP arguments of its own.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Run a client registration request through do_op.

        Every argument is forwarded by keyword to do_op; see there for the
        meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Run a client read request (HTTP GET by default) through do_op.

        Every argument is forwarded by keyword to do_op; see there for the
        meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Run a client update request (HTTP PUT by default) through do_op.

        Every argument is forwarded by keyword to do_op; see there for the
        meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Run a client delete request (HTTP DELETE by default) through do_op.

        Every argument is forwarded by keyword to do_op; see there for the
        meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def do_token_introspection(
            self, request=TokenIntrospectionRequest, body_type="json",
            method="POST", request_args=None, extra_args=None,
            http_args=None, response_cls=TokenIntrospectionResponse, **kwargs):
        """
        Run a token introspection request (HTTP POST, JSON response body by
        default) through do_op.

        Every argument is forwarded by keyword to do_op; see there for the
        meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def do_token_revocation(
            self, request=TokenRevocationRequest, body_type="",
            method="POST", request_args=None, extra_args=None,
            http_args=None, response_cls=None, **kwargs):
        """
        Run a token revocation request (HTTP POST by default) through do_op.

        No response class is set by default. Every argument is forwarded by
        keyword to do_op; see there for the meaning of the parameters.

        :return: Whatever do_op returns
        """
        forwarded = dict(kwargs, request=request, body_type=body_type,
                         method=method, request_args=request_args,
                         extra_args=extra_args, http_args=http_args,
                         response_cls=response_cls)
        return self.do_op(**forwarded)

    def add_code_challenge(self):
        """
        Create a code verifier and the code challenge derived from it.

        The verifier length and the transformation method are read from
        ``self.config['code_challenge']`` (keys 'length' and 'method');
        missing entries default to 64 and 'S256'.

        :return: A tuple of the request arguments ("code_challenge" and
            "code_challenge_method") and the code verifier
        :raises Unsupported: If a KeyError occurs while transforming, e.g.
            because the method is not a key of CC_METHOD
        """
        try:
            verifier_length = self.config['code_challenge']['length']
        except KeyError:
            verifier_length = 64  # Use default

        code_verifier = unreserved(verifier_length)
        verifier_bytes = code_verifier.encode()

        try:
            transformation = self.config['code_challenge']['method']
        except KeyError:
            transformation = 'S256'

        # NOTE(review): the challenge is the base64 encoding of the *hex*
        # digest; RFC 7636 defines S256 over the raw digest bytes. Confirm
        # what the verifying side expects before changing this.
        try:
            hex_digest = CC_METHOD[transformation](verifier_bytes).hexdigest()
            code_challenge = b64e(hex_digest.encode()).decode()
        except KeyError:
            raise Unsupported(
                'PKCE Transformation method:{}'.format(transformation))

        # TODO store code_verifier

        challenge_args = {"code_challenge": code_challenge,
                          "code_challenge_method": transformation}
        return challenge_args, code_verifier

    def do_authorization_request(
            self, request=AuthorizationRequest, state="", body_type="",
            method="GET", request_args=None, extra_args=None, http_args=None,
            response_cls=AuthorizationResponse, **kwargs):
        """
        Send an authorization request, adding a code challenge if configured.

        When ``self.config['code_challenge']`` is set, the arguments from
        add_code_challenge are added to the request arguments before the
        request is handed to ``oauth2.Client.do_authorization_request``.

        :param request_args: Arguments for the request; may be None
        :return: Whatever the base class implementation returns
        """
        if 'code_challenge' in self.config and self.config['code_challenge']:
            # The code verifier is not kept anywhere yet (see the TODO in
            # add_code_challenge).
            _args, code_verifier = self.add_code_challenge()
            if request_args is None:
                # request_args defaults to None; start from an empty
                # mapping instead of failing on None.update().
                request_args = {}
            request_args.update(_args)

        # Hand the response back to the caller instead of dropping it.
        return oauth2.Client.do_authorization_request(
            self,
            request=request, state=state,
            body_type=body_type,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            http_args=http_args,
            response_cls=response_cls,
            **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: If the issuer in the response differs from the
            expected one (ignoring one trailing slash) and "issuer_mismatch"
            is not a key of self.allow.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Compare modulo a trailing slash: give the expected issuer the
            # same ending as the advertised one.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # The mere presence of the key switches the check off, whatever
            # its value.
            if "issuer_mismatch" not in self.allow:
                # An explicit comparison rather than an assert statement, so
                # the check is still performed when Python runs with -O.
                if _issuer != _pcr_issuer:
                    raise PyoidcError(
                        "provider info issuer mismatch '%s' != '%s'" % (
                            _issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch the configuration of a provider and act on it.

        :param issuer: The issuer; one trailing slash is removed before it
            is inserted into serv_pattern
        :param keys: Passed on to handle_provider_config
        :param endpoints: Passed on to handle_provider_config
        :param response_cls: The message class used to parse the response
        :param serv_pattern: %-pattern giving the configuration URL
        :return: The parsed configuration
        :raises PyoidcError: If no response with status 200 was obtained
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)

        # Follow 302 redirects by hand, but give up after a fixed number of
        # hops so that a redirect loop cannot keep the client busy forever.
        redirects_left = 10
        while r.status_code == 302 and redirects_left > 0:
            redirects_left -= 1
            r = self.http_request(r.headers["location"])

        pcr = None
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """
        Remember the outcome of a client registration on this instance.

        :param reginfo: The parsed registration response. It must contain
            "client_id" and "redirect_uris"; "client_secret" is optional
            because a server need not issue one.
        """
        self.registration_response = reginfo
        # Only touch the secret when the server actually returned one,
        # instead of failing with a KeyError half-way through.
        if "client_secret" in reginfo:
            self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """
        Turn the HTTP response from a registration endpoint into a message.

        :param response: The HTTP response; its ``status_code`` and ``text``
            attributes are read.
        :return: On a successful status the parsed ClientInfoResponse (also
            stored on this instance); otherwise the ErrorResponse, provided
            the body parses and verifies as one.
        :raises PyoidcError: If an unsuccessful response does not carry a
            usable error message.
        """
        if response.status_code in SUCCESSFUL:
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_response(resp, response.text)
            self.store_registration_info(resp)
            return resp

        # Deserialize inside the try block as well: an error body that is
        # not even valid JSON must surface as PyoidcError too, not as a raw
        # parser exception.
        try:
            resp = ErrorResponse().deserialize(response.text, "json")
            resp.verify()
            self.store_response(resp, response.text)
        except Exception:
            raise PyoidcError(
                'Registration failed: {}'.format(response.text))

        return resp

    def register(self, url, **kwargs):
        """
        Register this client at an OP.

        The keyword arguments become the registration request, which is
        POSTed as JSON to the given endpoint.

        :param url: The OP's registration endpoint
        :param kwargs: Parameters of the registration request
        :return: Whatever handle_registration_info makes of the HTTP response
        """
        json_body = self.construct_RegistrationRequest(
            request_args=kwargs).to_json()
        json_header = {"content-type": "application/json"}

        return self.handle_registration_info(
            self.http_request(url, "POST", data=json_body,
                              headers=json_header))

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :param query: The urlencoded response
        :return: The parsed response
        :raises AuthzError: If the parsed message is of type "ErrorResponse"
        """
        parsed = self.parse_response(AuthorizationResponse, info=query,
                                     sformat="urlencoded",
                                     keyjar=self.keyjar)

        is_error = parsed.type() == "ErrorResponse"
        if not is_error:
            logger.info("Aresp: %s" % parsed)
            return parsed

        logger.info("ErrorResponse: %s" % parsed)
        raise AuthzError(parsed.error)
예제 #18
0
class Client(PBase):
    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True, config=None, client_cert=None,
                 timeout=5):
        """

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param config: Configuration of the client, a dictionary. Its
            'issuer' entry, if any, becomes ``self.issuer``.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either as
            a single integer or as a tuple of integers. For more details, refer to
            ``requests`` documentation.
        :return: Client instance
        """

        PBase.__init__(self, verify_ssl=verify_ssl, keyjar=keyjar,
                       client_cert=client_cert, timeout=timeout)

        self.client_id = client_id
        self.client_authn_method = client_authn_method

        self.nonce = None

        self.grant = {}
        self.state2nonce = {}
        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        # A per-instance copy, so that extending the mapping on one client
        # does not change the module-level table seen by all the others.
        self.request2endpoint = dict(REQUEST2ENDPOINT)
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        # Authorization requests by state. This has to be a mapping:
        # do_authorization_request stores into it by subscript.
        self.authz_req = {}

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config['issuer']
        except KeyError:
            self.issuer = ''
        self.allow = {}

    def store_response(self, clinst, text):
        """
        Hook called with a response; this implementation does nothing.

        :param clinst: The parsed response
        :param text: The text the response was parsed from
        """
        return None

    def get_client_secret(self):
        """Return the client secret held in ``self._c_secret``."""
        secret = self._c_secret
        return secret

    def set_client_secret(self, val):
        """
        Set the client secret.

        A false value is stored as the empty string. Any other value is
        stored as it is and also added, as a string, to the keyjar as a
        symmetric key under the owner "".

        :param val: The new client secret
        """
        if not val:
            self._c_secret = ""
            return

        self._c_secret = val
        # client uses it for signing
        # Server might also use it for signing which means the
        # client uses it for verifying server signatures
        if self.keyjar is None:
            self.keyjar = KeyJar()
        symmetric_key = str(val)
        self.keyjar.add_symmetric("", symmetric_key)

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """
        Forget the nonce, the grants, the authorization and token endpoints
        and the redirect URIs.
        """
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        # Same "nothing registered" value as in __init__. _parse_args reads
        # redirect_uris[0], which would fail on None.
        self.redirect_uris = [None]

    def grant_from_state(self, state):
        """
        Find the grant stored under a state.

        :param state: The state value to look for
        :return: The first grant whose key equals state, None if there is
            none
        """
        matches = (grant for key, grant in self.grant.items() if key == state)
        return next(matches, None)

    def _parse_args(self, request, **kwargs):
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            else:
                if prop == "redirect_uri":
                    _val = getattr(self, "redirect_uris", [None])[0]
                    if _val:
                        ar_args[prop] = _val
                else:
                    _val = getattr(self, prop, None)
                    if _val:
                        ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        """
        Return the grant stored under a state.

        :param state: The state value
        :param kwargs: Accepted but not used
        :return: The grant
        :raises GrantError: If nothing is stored under that state
        """
        try:
            found = self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)
        else:
            return found

    def get_token(self, also_expired=False, **kwargs):
        """
        Return a token, either the one given or one held by a grant.

        A 'token' keyword argument is returned as it is, without any check.
        Otherwise the grant is fetched with get_grant(**kwargs) and asked
        for a token, for kwargs['scope'] if given and else for "".

        :param also_expired: Return the token even if it is not valid
        :param kwargs: 'token', or what is needed to find the grant
        :return: The token
        :raises TokenError: If no token is found, or the token is not valid
            and also_expired is false
        """
        if "token" in kwargs:
            return kwargs["token"]

        grant = self.get_grant(**kwargs)

        try:
            token = grant.get_token(kwargs["scope"])
        except KeyError:
            token = grant.get_token("")
            if not token:
                try:
                    token = self.grant[kwargs["state"]].get_token("")
                except KeyError:
                    raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired or token.is_valid():
            return token

        raise TokenError("Token has expired")

    def clean_tokens(self):
        """Clean replaced and invalid tokens."""
        for state in self.grant:
            grant = self.get_grant(state)
            # Walk over a snapshot: deleting from the very list that is
            # being iterated would make the loop skip the token following
            # each deleted one.
            for token in list(grant.tokens):
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(self, request, request_args=None, extra_args=None):
        """
        Instantiate a request message.

        The request arguments are completed by _parse_args and then
        overlaid with the extra arguments.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; None means none
        :param extra_args: Arguments that override the completed ones
        :return: An instance of request
        """
        given = {} if request_args is None else request_args
        arguments = self._parse_args(request, **given)

        if extra_args:
            arguments.update(extra_args)

        logger.debug("request: %s" % sanitize(request))
        return request(**arguments)

    def construct_Message(self, request=Message, request_args=None,
                          extra_args=None, **kwargs):
        """
        Build a generic message; a plain pass-through to construct_request.

        :param request: The message class to instantiate
        :param request_args: Arguments for the message
        :param extra_args: Additional arguments
        :param kwargs: Accepted but not used
        :return: The result of construct_request
        """
        message = self.construct_request(request, request_args, extra_args)
        return message

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       **kwargs):
        """
        Build an authorization request message.

        A true "redirect_uri" among the request arguments replaces
        ``self.redirect_uris``. A missing or empty "client_id" is filled in
        from ``self.client_id``.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; updated in place
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Accepted but not used
        :return: The result of construct_request
        """
        if request_args is None:
            request_args = {}
        else:
            try:
                new_default = request_args["redirect_uri"]
            except KeyError:
                pass
            else:
                if new_default:  # change default
                    self.redirect_uris = [new_default]

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(self,
                                     request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """
        Build an access token request message.

        Unless request is ROPCAccessTokenRequest, the grant is fetched with
        get_grant(**kwargs) and its code becomes the "code" argument.
        kwargs['state'], if present, becomes the "state" argument;
        "grant_type" defaults to "authorization_code" and a missing or empty
        "client_id" is filled in from ``self.client_id``.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; updated in place
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: What is needed to find the grant, and the state
        :return: The result of construct_request
        :raises GrantExpired: If the grant is no longer valid
        """
        if request_args is None:
            request_args = {}

        needs_code = request is not ROPCAccessTokenRequest
        if needs_code:
            grant = self.get_grant(**kwargs)
            if not grant.is_valid():
                raise GrantExpired("Authorization Code to old %s > %s" % (
                    utc_time_sans_frac(),
                    grant.grant_expiration_time))
            request_args["code"] = grant.code

        if 'state' in kwargs:
            request_args['state'] = kwargs['state']

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if "client_id" not in request_args or not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """
        Build a refresh access token request message.

        The token is fetched with get_token, expired ones included; its
        refresh_token becomes the "refresh_token" argument and its scope,
        if it has one, the "scope" argument.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; updated in place
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Passed on to get_token
        :return: The result of construct_request
        """
        arguments = {} if request_args is None else request_args

        token = self.get_token(also_expired=True, **kwargs)
        arguments["refresh_token"] = token.refresh_token

        try:
            arguments["scope"] = token.scope
        except AttributeError:
            # a token without scope; leave the argument out
            pass

        return self.construct_request(request, arguments, extra_args)

    def construct_ResourceRequest(self, request=ResourceRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """
        Build a resource request message.

        The access_token of the token found by get_token(**kwargs) becomes
        the "access_token" argument.

        :param request: The message class to instantiate
        :param request_args: Arguments for the request; updated in place
        :param extra_args: Additional arguments, passed on as they are
        :param kwargs: Passed on to get_token
        :return: The result of construct_request
        """
        arguments = {} if request_args is None else request_args
        arguments["access_token"] = self.get_token(**kwargs).access_token
        return self.construct_request(request, arguments, extra_args)

    def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
                     **kwargs):
        """
        Work out where to send a request and what to send.

        :param reqmsg: The request message class; its name selects the
            endpoint through ``self.request2endpoint``
        :param cis: The request message instance
        :param method: The HTTP method
        :param request_args: Request arguments; an endpoint URI in here
            takes precedence over the attribute on this client. May be None.
        :param kwargs: A true 'endpoint' entry gives the URI directly;
            everything is passed on to get_or_post
        :return: A tuple of URI, body, HTTP arguments (only "headers", if
            any) and the request instance
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            # request_args defaults to None, which cannot be unpacked with
            # ** - fall back to no arguments.
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self, request, method="POST", request_args=None,
                     extra_args=None, lax=False, **kwargs):
        """
        Construct a request and everything needed to send it.

        The request is built by the ``construct_<request name>`` method if
        there is one, otherwise by construct_request. If 'authn_method' is
        among the keyword arguments, init_authentication_method is called
        and the headers it returns are merged into kwargs["headers"].

        :param request: The request message class
        :param method: The HTTP method
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param lax: Stored on the request instance as its ``lax`` attribute
        :param kwargs: Passed on to the constructor and to uri_and_body
        :return: Whatever uri_and_body returns
        """
        if request_args is None:
            request_args = {}

        # NOTE(review): an AttributeError raised inside the constructor is
        # caught here as well and leads to the generic fallback.
        try:
            constructor = getattr(self, "construct_%s" % request.__name__)
            cis = constructor(request_args=request_args,
                              extra_args=extra_args, **kwargs)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)

        if self.events:
            self.events.store('Protocol request', cis)

        if 'nonce' in cis and 'state' in cis:
            self.state2nonce[cis['state']] = cis['nonce']

        cis.lax = lax

        h_arg = None
        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(
                cis, request_args=request_args, **kwargs)

        if h_arg:
            if "headers" in kwargs:
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args,
                                 **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        """
        Shorthand for request_info with an AuthorizationRequest sent by GET.

        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param kwargs: Passed on to request_info
        :return: Whatever request_info returns
        """
        info = self.request_info(AuthorizationRequest, "GET",
                                 request_args, extra_args, **kwargs)
        return info

    def get_urlinfo(self, info):
        """
        Extract the part of a URL that carries the response parameters.

        :param info: A URL, or an already extracted parameter string
        :return: The query part if there is one, else the fragment; a
            string containing neither '?' nor '#' is returned unchanged
        """
        if '?' not in info and '#' not in info:
            return info

        parsed = urlparse(info)
        # either query or fragment
        return parsed.query or parsed.fragment

    def parse_response(self, response, info="", sformat="json", state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state; used as key for the grant when the
            response itself carries no state
        :param kwargs: Extra key word arguments; passed on to the
            deserialization and, extended with client_id, iss and keyjar,
            to the verification of the response
        :return: The parsed and to some extent verified response
        :raises PyoidcError: If the verification of the response returns a
            false value
        :raises ResponseError: If no usable response could be made out of
            info
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            # info may be a complete URL; keep only its query/fragment part
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store('Response', resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # An error arrived where the expected response type was parsed.
            # Try the error classes registered for that type, in order, and
            # keep the first one that deserializes and verifies.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            try:
                for errmsg in errmsgs:
                    try:
                        resp = errmsg().deserialize(info, sformat)
                        resp.verify()
                        break
                    except Exception:
                        resp = None
            except KeyError:
                pass
        elif resp.only_extras():
            # treated as no response at all; ends in ResponseError below
            resp = None
        else:
            # Supply what the verification needs, on top of the caller's
            # keyword arguments.
            kwargs["client_id"] = self.client_id
            try:
                kwargs['iss'] = self.provider_info['issuer']
            except (KeyError, AttributeError):
                # no issuer in the provider info; use the configured one
                if self.issuer:
                    kwargs['iss'] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error('Verification of the response failed')
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                # take the scope over from the keyword arguments, if given
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error('Missing or faulty response')
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            # Record the response in the grant kept for its state; the
            # state argument is the fallback when the response has none.
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                # first response for this state: create the grant
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(self, cis, authn_method, request_args=None,
                                   http_args=None, **kwargs):
        """
        Let a client authentication method prepare a request.

        :param cis: The request message instance
        :param authn_method: Name of the method, a key of
            ``self.client_authn_method``; a false value means none
        :param request_args: Request arguments; None means none
        :param http_args: HTTP arguments; None means none
        :param kwargs: Passed on to the method's construct
        :return: What the method's construct returns, or the HTTP arguments
            unchanged when no method is named
        """
        http_args = {} if http_args is None else http_args
        if not authn_method:
            return http_args

        request_args = {} if request_args is None else request_args
        method_class = self.client_authn_method[authn_method]
        return method_class(self).construct(cis, request_args, http_args,
                                            **kwargs)

    def parse_request_response(self, reqresp, response, body_type, state="",
                               **kwargs):
        """
        Turn an HTTP response into a message, where possible.

        :param reqresp: The HTTP response
        :param response: The expected response message class, if any
        :param body_type: The expected format of the body
        :param state: The state, passed on to parse_response
        :param kwargs: Passed on to parse_response
        :return: The HTTP response itself on a 302/303 redirect; the body
            text if a response class is given and the body is unstructured
            text; the parsed message if a response class is given;
            otherwise an ErrorResponse if the body turns out to be a valid
            one, else the HTTP response
        :raises ParseError: On status 500
        :raises HttpError: On any status other than the successful ones,
            302, 303, 400, 401 and 500
        """

        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code in [302, 303]:  # redirect
            return reqresp
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code,
                                      sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # expecting an error response; the body is handled below
            pass
        else:
            logger.error("(%d) %s" % (reqresp.status_code,
                                      sanitize(reqresp.text)))
            raise HttpError("HTTP ERROR: %s [%s] on %s" % (
                reqresp.text, reqresp.status_code, reqresp.url))

        if response:
            if body_type == 'txt':
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == 'txt':
                body_type = 'urlencoded'
            try:
                # The body is read from .text, as everywhere else in this
                # method.
                err = ErrorResponse().deserialize(reqresp.text,
                                                  method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                # Best effort only: a body that cannot be parsed as an
                # error message is simply not one.
                pass

        return reqresp

    def request_and_return(self, url, response=None, method="GET", body=None,
                           body_type="json", state="", http_args=None,
                           **kwargs):
        """
        Send a request and hand the response to parse_request_response.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param state: The state, passed on to parse_request_response
        :param http_args: Arguments for the HTTP client
        :param kwargs: Passed on to parse_request_response; "keyjar"
            defaults to the keyjar of this client
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        transport_args = {} if http_args is None else http_args

        # Any exception raised by the HTTP layer reaches the caller as is.
        http_response = self.http_request(url, method, data=body,
                                          **transport_args)

        kwargs.setdefault("keyjar", self.keyjar)

        return self.parse_request_response(http_response, response,
                                           body_type, state, **kwargs)

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Send an authorization request and return the handled response.

        :param request: The request message class
        :param state: The state; if true it is added to the request
            arguments
        :param body_type: Expected format of the response body
        :param method: The HTTP method to use
        :param request_args: Arguments for the request message
        :param extra_args: Additional arguments for the request message
        :param http_args: Arguments for the HTTP client; when given they are
            updated in place with the ones request_info produces
        :param response_cls: The response message class
        :param kwargs: Passed on to request_info; 'algs' is also passed on
            to request_and_return
        :return: Whatever request_and_return returns; an error response to
            an authorization request gets the state of the request
        """

        if state:
            try:
                request_args["state"] = state
            except TypeError:
                # request_args is None
                request_args = {"state": state}

        kwargs['authn_endpoint'] = 'authorization'
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request under its state, when that is possible.
        # TypeError: there are no request arguments or nothing to store
        # into; KeyError: the request arguments carry no state. Neither may
        # stop the request from being sent.
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        algs = kwargs.get("algs", {})

        resp = self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, algs=algs)

        if isinstance(resp, Message):
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:
                resp.state = csi.state

        return resp

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Build an access token request, send it and parse the response.

        :param request: The request class handed to ``request_info``
        :param scope: Scope handed to ``request_info``
        :param state: State handed to ``request_info`` and
            ``request_and_return``
        :param body_type: The format of the body of the return message
        :param method: Which HTTP method to use, POST by default
        :param request_args: Request arguments
        :param extra_args: Extra arguments handed to ``request_info``
        :param http_args: Arguments for the HTTP client. If given, the
            dictionary is updated in place with the arguments produced by
            ``request_info``
        :param response_cls: The class used to parse the response
        :param authn_method: Client authentication method
        :param kwargs: Extra keyword arguments, forwarded to both
            ``request_info`` and ``request_and_return``
        :return: Whatever ``request_and_return`` returns
        """

        kwargs['authn_endpoint'] = 'token'
        request_parts = self.request_info(request, method=method,
                                          request_args=request_args,
                                          extra_args=extra_args,
                                          scope=scope, state=state,
                                          authn_method=authn_method,
                                          **kwargs)
        url, body, ht_args, _ = request_parts

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        if self.events is not None:
            recorded = (('request_url', url),
                        ('request_http_args', http_args),
                        ('Request', body))
            for label, item in recorded:
                self.events.store(label, item)

        logger.debug(
            "<do_access_token> URL: %s, Body: %s" % (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Send a refresh request for the token that belongs to ``state``.

        :param request: The request class handed to ``request_info``
        :param state: State used to look up the token and the grant
        :param body_type: The format of the body of the return message
        :param method: Which HTTP method to use
        :param request_args: Request arguments
        :param extra_args: Extra arguments handed to ``request_info``
        :param http_args: Arguments for the HTTP client. If given, the
            dictionary is updated in place with the arguments produced by
            ``request_info``
        :param response_cls: The class used to parse the response
        :param authn_method: Client authentication method
        :param kwargs: Extra keyword arguments for ``get_token``
        :return: Whatever ``request_and_return`` returns
        """

        token = self.get_token(also_expired=True, state=state, **kwargs)
        # NOTE(review): this entry is set but kwargs is not forwarded to
        # request_info below - confirm whether that is intended.
        kwargs['authn_endpoint'] = 'refresh'
        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  token=token,
                                                  authn_method=authn_method)

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        response = self.request_and_return(url, response_cls, method, body,
                                           body_type, state=state,
                                           http_args=http_args)
        if not token.replaced:
            return response

        # The old token has been replaced, drop it from the grant
        self.get_grant(state).delete_token(token)
        return response

    def do_any(self, request, endpoint="", scope="", state="", body_type="json",
               method="POST", request_args=None, extra_args=None,
               http_args=None, response=None, authn_method=""):
        """
        Build a request of any kind, send it and parse the response.

        :param request: The request class handed to ``request_info``
        :param endpoint: Endpoint handed to ``request_info``
        :param scope: Scope handed to ``request_info``
        :param state: State handed to ``request_info`` and
            ``request_and_return``
        :param body_type: The format of the body of the return message
        :param method: Which HTTP method to use
        :param request_args: Request arguments
        :param extra_args: Extra arguments handed to ``request_info``
        :param http_args: Arguments for the HTTP client. If given, the
            dictionary is updated in place with the arguments produced by
            ``request_info``
        :param response: The class used to parse the response
        :param authn_method: Client authentication method
        :return: Whatever ``request_and_return`` returns
        """

        request_parts = self.request_info(request, method=method,
                                          request_args=request_args,
                                          extra_args=extra_args,
                                          scope=scope, state=state,
                                          authn_method=authn_method,
                                          endpoint=endpoint)
        url, body, ht_args, _ = request_parts

        if http_args is not None:
            http_args.update(ht_args)
        else:
            http_args = ht_args

        return self.request_and_return(url, response, method, body, body_type,
                                       state=state, http_args=http_args)

    def fetch_protected_resource(self, uri, method="GET", headers=None,
                                 state="", **kwargs):
        """
        Fetch a resource, authenticating with an access token.

        :param uri: The URI of the resource
        :param method: Which HTTP method to use
        :param headers: HTTP headers; if given, the dictionary is updated in
            place with the authentication headers
        :param state: State used to look up a token when none is given
        :param kwargs: Extra keyword arguments. A non-empty ``token`` is
            used as the access token as is. If ``authn_method`` is present
            the arguments are handed to ``init_authentication_method``
        :return: Whatever ``http_request`` returns
        """

        given_token = kwargs.get("token")
        if given_token:
            request_args = {"access_token": given_token}
        else:
            try:
                stored_token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old, refresh and look it up again
                self.do_access_token_refresh(state=state)
                stored_token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": stored_token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" not in kwargs:
            # If nothing is defined this is the default
            bearer = self.client_authn_method["bearer_header"](self)
            http_args = bearer.construct(request_args=request_args)
        else:
            http_args = self.init_authentication_method(
                request_args=request_args, **kwargs)

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support

        The verifier length and the transformation method are read from
        ``self.config['code_challenge']`` and default to 64 and 'S256'.

        :return: A tuple of a dictionary with the ``code_challenge`` and
            ``code_challenge_method`` request arguments, and the code
            verifier
        :raises Unsupported: If the transformation method is not a key in
            ``CC_METHOD``
        """
        try:
            verifier_length = self.config['code_challenge']['length']
        except KeyError:
            verifier_length = 64  # Use default

        code_verifier = unreserved(verifier_length)
        verifier_bytes = code_verifier.encode('ascii')

        try:
            transform = self.config['code_challenge']['method']
        except KeyError:
            transform = 'S256'

        try:
            digest = CC_METHOD[transform](verifier_bytes).digest()
            code_challenge = b64e(digest).decode('ascii')
        except KeyError:
            raise Unsupported(
                'PKCE Transformation method:{}'.format(transform))

        # TODO store code_verifier

        challenge_args = {"code_challenge": code_challenge,
                          "code_challenge_method": transform}
        return challenge_args, code_verifier

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with a Provider Config Response.

        The issuer the response claims is compared with the expected one
        after the expected one has been given the same trailing slash
        convention. The response is stored as ``self.provider_info`` only
        when it carries an issuer.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: If the issuers differ and
            ``self.allow["issuer_mismatch"]`` is not set
        """

        if "issuer" not in pcr:
            effective_issuer = issuer
        else:
            effective_issuer = pcr["issuer"]

            # Give the expected issuer the same trailing slash convention
            if effective_issuer.endswith("/"):
                expected = issuer if issuer.endswith("/") else issuer + "/"
            else:
                expected = issuer[:-1] if issuer.endswith("/") else issuer

            mismatch_allowed = self.allow.get("issuer_mismatch", False)
            if not mismatch_allowed and expected != effective_issuer:
                raise PyoidcError("provider info issuer mismatch '%s' != '%s'" % (expected, effective_issuer))

            self.provider_info = pcr

        self.issuer = effective_issuer

        if endpoints:
            for attr, value in pcr.items():
                if attr.endswith("_endpoint"):
                    setattr(self, attr, value)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, effective_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN, max_redirects=10):
        """
        Fetch the provider configuration of an issuer and act on it.

        :param issuer: The issuer whose configuration should be fetched
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints
        :param response_cls: The class used to parse the response
        :param serv_pattern: Pattern the issuer, without trailing slash, is
            inserted into to get the configuration URL
        :param max_redirects: The maximum number of 302 redirects that are
            followed; guards against cyclic redirects
        :return: The parsed provider configuration response
        :raises PyoidcError: If no 200 response could be obtained
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer

        url = serv_pattern % _issuer

        r = self.http_request(url)
        redirects = 0
        # Follow redirects by hand, but never without bound
        while r.status_code == 302 and redirects < max_redirects:
            r = self.http_request(r.headers["location"])
            redirects += 1

        pcr = None
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr
예제 #19
0
class Provider(provider.Provider):
    """
    A OAuth2 RP that knows all the OAuth2 extensions I've implemented
    """
    def __init__(self,
                 name,
                 sdb,
                 cdb,
                 authn_broker,
                 authz,
                 client_authn,
                 symkey=None,
                 urlmap=None,
                 iv=0,
                 default_scope="",
                 ca_bundle=None,
                 seed=b"",
                 client_authn_methods=None,
                 authn_at_registration="",
                 client_info_url="",
                 secret_lifetime=86400,
                 jwks_uri='',
                 keyjar=None,
                 capabilities=None,
                 verify_ssl=True,
                 baseurl='',
                 hostname='',
                 config=None,
                 behavior=None,
                 lifetime_policy=None,
                 **kwargs):
        """
        Set up the provider.

        The first eleven arguments are handed to ``provider.Provider``
        unchanged, apart from ``name`` which gets a trailing slash if it
        lacks one.

        :param seed: Seed used when creating client secrets
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: Name of the client authentication
            method to use at registration, if any. Must be a key in
            client_authn_methods
        :param client_info_url: Part of the registration client URI
        :param secret_lifetime: Added to the current time to get the
            expiration time of a client secret
        :param jwks_uri: Published as ``jwks_uri`` in the provider info
        :param keyjar: A KeyJar instance; a new one is created if None
        :param capabilities: Overrides for the default provider features
        :param verify_ssl: Handed to the KeyJar created here
        :param baseurl: Defaults to name
        :param hostname: Defaults to the host name of this machine
        :param config: Configuration dictionary
        :param behavior: Behavior dictionary
        :param lifetime_policy: Token lifetimes per token type and
            response/grant type; a built-in default is used if None
        :param kwargs: ``server_cls`` is forwarded to the base class and
            ``scopes`` replaces the default ['offline_access']
        :raises UnknownAuthnMethod: If authn_at_registration is not one of
            client_authn_methods
        """

        if not name.endswith("/"):
            name += "/"

        # 'server_cls' is the only extra argument the base class gets
        args = {}
        if 'server_cls' in kwargs:
            args['server_cls'] = kwargs['server_cls']

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **args)

        self.endp.extend([
            RegistrationEndpoint, ClientInfoEndpoint, RevocationEndpoint,
            IntrospectionEndpoint
        ])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # Without any methods (None or empty) the requested one is
            # unknown too; report that instead of failing on the lookup.
            if (not client_authn_methods
                    or authn_at_registration not in client_authn_methods):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # provider_features() below reads self.scopes
        self.scopes = kwargs.get('scopes', ['offline_access'])
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        self.token_policy = {'access_token': {}, 'refresh_token': {}}
        if lifetime_policy is None:
            self.lifetime_policy = {
                'access_token': {
                    'code': 600,
                    'token': 120,
                    'implicit': 120,
                    'authorization_code': 600,
                    'client_credentials': 600,
                    'password': 600
                },
                'refresh_token': {
                    'code': 3600,
                    'token': 3600,
                    'implicit': 3600,
                    'authorization_code': 3600,
                    'client_credentials': 3600,
                    'password': 3600
                }
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(self.baseurl,
                                          self.token_policy,
                                          keyjar=self.keyjar)

    @staticmethod
    def _uris_to_tuples(uris):
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load the keys of a client into the key jar.

        The keys referred to by the request are loaded first, then the
        client secret is added as symmetric keys for "ver" and "sig" use.

        :param request: The client registration request
        :param client_id: The client identifier the keys are stored under
        :param client_secret: The client secret
        :return: A Response describing the error if loading the keys
            failed, otherwise None
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                found = len(self.keyjar[client_id])
                logger.debug(
                    "Found {} keys for client_id={}".format(found, client_id))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: {}".format(
                sanitize(request.to_dict())))
            logger.error("%s", err)
            reg_error = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            # NOTE(review): code_grant_type in this class passes the status
            # as 'status=', not 'status_code=' - confirm which keyword
            # Response honours.
            return Response(reg_error.to_json(),
                            content="application/json",
                            status_code="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        secret_keys = [
            {"kty": "oct", "key": client_secret, "use": usage}
            for usage in ("ver", "sig")
        ]
        bundle = KeyBundle(secret_keys)
        try:
            self.keyjar[client_id].append(bundle)
        except KeyError:
            self.keyjar[client_id] = [bundle]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Check client information against a set of restrictions.

        :param cinfo: The client information
        :param restrictions: Dictionary mapping the name of a restriction,
            as known to ``restrict.factory``, to its argument
        :raises RestrictionError: With the result of the first restriction
            function that returns something true
        """
        for fname, arg in restrictions.items():
            violation = restrict.factory(fname)(arg, cinfo)
            if violation:
                raise RestrictionError(violation)

    def set_token_policy(self, cid, cinfo):
        for ttyp in ['access_token', 'refresh_token']:
            pol = {}
            for rgtyp in ['response_type', 'grant_type']:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    pass
                else:
                    for typ in rtyp:
                        try:
                            pol[typ] = self.lifetime_policy[ttyp][typ]
                        except KeyError:
                            pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Register a new client in the client database.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client.
            NOTE(review): not used here, the restrictions that are applied
            come from ``self.behavior['client_registration']`` - confirm
        :return: The client_id
        """

        client = request.to_dict()

        self.match_client_request(client)

        # create new id and secret, retry until the id is unused
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client["client_id"] = client_id
        client["client_secret"] = secret(self.seed, client_id)
        client["client_id_issued_at"] = utc_time_sans_frac()
        client["client_secret_expires_at"] = (
            utc_time_sans_frac() + self.secret_lifetime)

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            client["registration_access_token"] = rndstr(32)
            client["registration_client_uri"] = "%s%s%s?client_id=%s" % (
                self.name, self.client_info_url, ClientInfoEndpoint.etype,
                client_id)

        if "redirect_uris" in request:
            client["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        # NOTE(review): load_keys returns a Response on failure, which is
        # dropped here - confirm
        self.load_keys(request, client_id, client["client_secret"])

        try:
            registration_rules = self.behavior['client_registration']
        except KeyError:
            pass
        else:
            self.verify_correct(client, registration_rules)

        self.set_token_policy(client_id, client)
        self.cdb[client_id] = client

        return client_id

    def match_client_request(self, request):
        """
        Verify that what the client asks for is within my capabilities.

        Every response type must consist of the same set of space separated
        values as one of the supported response types. Other preferences
        must be among, or be a subset of, the supported values.

        :param request: The client registration request as a dictionary
        :raises CapabilitiesMisMatch: If a preference is not supported
        """
        for pref, prov in PREFERENCE2PROVIDER.items():
            if pref not in request:
                continue

            if pref == "response_types":
                for val in request[pref]:
                    wanted = set(val.split(" "))
                    supported = any(wanted == set(cv.split(' '))
                                    for cv in self.capabilities[prov])
                    if not supported:
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(pref))
            elif isinstance(request[pref], str):
                if request[pref] not in self.capabilities[prov]:
                    raise CapabilitiesMisMatch(
                        'Not allowed {}'.format(pref))
            elif not set(request[pref]).issubset(
                    set(self.capabilities[prov])):
                raise CapabilitiesMisMatch(
                    'Not allowed {}'.format(pref))

    def client_info(self, client_id):
        """
        Create the client information response for a registered client.

        :param client_id: The client identifier
        :return: A Response with the client information as JSON, or a
            BadRequest if ``valid_client_info`` rejects the stored
            information
        """
        info = self.cdb[client_id].copy()
        if not valid_client_info(info):
            err = ErrorResponse(error="invalid_client",
                                error_description="Invalid client secret")
            return BadRequest(err.to_json(), content="application/json")

        # Redirect URIs are stored as (base, query) tuples
        try:
            stored_uris = info["redirect_uris"]
        except KeyError:
            pass
        else:
            info["redirect_uris"] = self._tuples_to_uris(stored_uris)

        return Response(ClientInfoResponse(**info).to_json(),
                        content="application/json")

    def client_info_update(self, client_id, request):
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                    "client_id_issued_at", "client_secret_expires_at",
                    "registration_access_token", "registration_client_uri"
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :return:
        """

        if not client_id:
            client_id = get_client_id(self.cdb, areq,
                                      environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        return {}

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: keyword arguments; ``request`` holds the JSON
            serialized registration request and ``environ`` is used when
            clients must authenticate at registration
        :return: A Response instance
        """

        def rejected(error, err):
            # The error response shared by all the failure cases below
            msg = ClientRegistrationError(error=error,
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        reg_req = RegistrationRequest().deserialize(kwargs['request'], "json")
        try:
            reg_req.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            return rejected("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return rejected("invalid_client_metadata", err)

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], reg_req,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        restrictions = {}
        if 'parsed_software_statement' in reg_req:
            for statement in reg_req['parsed_software_statement']:
                restrictions.update(
                    self.consume_software_statement(statement))
            del reg_req['software_statement']
            del reg_req['parsed_software_statement']

        try:
            client_id = self.create_new_client(reg_req, restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            return rejected("invalid_client_metadata", err)

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods: GET reads, PUT updates and DELETE removes the client
        information.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; ``query``, ``environ`` and
            ``request`` are used
        :return: A Response instance, or None for any other HTTP method
        """

        def rejected(error, err):
            # The error response for a request that does not verify
            msg = ClientRegistrationError(error=error,
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        query = compact(parse_qs(kwargs['query']))
        try:
            client_id = query["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if client_id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'],
                               kwargs['request'],
                               "bearer_header",
                               client_id=client_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(client_id)

        if method == "PUT":
            try:
                update = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                update.verify()
            except InvalidRedirectUri as err:
                return rejected("invalid_redirect_uri", err)
            except (MissingPage, VerificationError) as err:
                return rejected("invalid_client_metadata", err)

            try:
                self.client_info_update(client_id, update)
                return self.client_info(client_id)
            except ModificationForbidden:
                return Forbidden()

        if method == "DELETE":
            try:
                del self.cdb[client_id]
            except KeyError:
                return Unauthorized()
            return NoContent()

        return None

    def provider_features(self,
                          pcr_class=ServerMetadata,
                          provider_config=None):
        """
        Specifies what the server capabilities are.

        Starts from ``CAPABILITIES``, adds the supported scopes and, for
        the token, revocation and introspection endpoints, the supported
        signing algorithms (all known ones except 'none') and
        authentication methods.

        :param pcr_class: The class of the returned instance
        :param provider_config: Values that override the defaults
        :return: ProviderConfigurationResponse instance
        """

        features = pcr_class(**CAPABILITIES)
        features["scopes_supported"] = self.scopes

        sign_algs = list(jws.SIGNER_ALGS.keys())
        sign_algs.remove('none')
        sign_algs.sort(key=cmp_to_key(sort_sign_alg))

        for endpoint in ("token", "revocation", "introspection"):
            alg_key = "{}_endpoint_auth_signing_alg_values_supported".format(
                endpoint)
            method_key = "{}_endpoint_auth_methods_supported".format(endpoint)
            features[alg_key] = sign_algs
            features[method_key] = AUTH_METHODS_SUPPORTED

        if provider_config:
            features.update(provider_config)

        return features

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually
        can be done by this implementation.

        :param capabilities: The asked for capabilities as a dictionary
        or a ProviderConfigurationResponse instance. The later can be
        treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if isinstance(val, str):
                try:
                    if val in _pinfo[key]:
                        continue
                    else:
                        return False
                except KeyError:
                    return False

        return True

    def create_providerinfo(self,
                            pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response
        :param pcr_class:
        :param setup:
        :return:
        """

        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            _provider_info['{}_endpoint'.format(endp.etype)] = os.path.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider information.

        :param kwargs: keyword arguments; if ``handle`` is present its
            first element may lead to a cookie being added to the response
        :return: A Response with the provider information as JSON. If
            anything fails, a Response with the formatted traceback
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(), ))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                (key, _) = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(),
                            content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # The media type is 'text/html'; 'html/text' does not exist.
            # NOTE(review): this hands the traceback to the caller -
            # confirm that is wanted outside of debugging.
            resp = Response(message, content="text/html")

        return resp

    @staticmethod
    def verify_code_challenge(code_verifier,
                              code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into ``CC_METHOD`` naming the
            transformation
        :return: True if the transformed verifier equals the challenge,
            otherwise a Response carrying a TokenErrorResponse
        """
        transform = CC_METHOD[code_challenge_method]
        digest = transform(code_verifier.encode('ascii')).digest()
        computed = b64e(digest).decode('ascii')
        if computed == code_challenge:
            return True

        logger.error('PCKE Code Challenge check failed')
        err = TokenErrorResponse(error="invalid_request",
                                 error_description="PCKE check failed")
        return Response(err.to_json(),
                        content="application/json",
                        status_code=401)

    def do_access_token_response(self,
                                 access_token,
                                 atinfo,
                                 state,
                                 refresh_token=None):
        """
        Construct an access token response.

        :param access_token: The access token
        :param atinfo: Information about the access token; ``exp`` is used
            as ``expires_in`` and ``scope`` is copied if present
        :param state: The state value
        :param refresh_token: A refresh token, left out if empty
        :return: An AccessTokenResponse instance
        """
        token_info = dict(access_token=access_token,
                          expires_in=atinfo['exp'],
                          token_type='bearer',
                          state=state)

        try:
            scope = atinfo['scope']
        except KeyError:
            pass
        else:
            token_info['scope'] = scope

        if refresh_token:
            token_info['refresh_token'] = refresh_token

        filtered = by_schema(AccessTokenResponse, **token_info)
        return AccessTokenResponse(**filtered)

    def code_grant_type(self, areq):
        """
        Handle a token request that uses the authorization code grant.

        The code must be known, a PKCE code verifier (if any) must match
        the challenge of the authorization request, state and redirect_uri
        must match what was used there and the scope check must pass.

        :param areq: The access token request
        :return: A Response with the access token response as JSON, or an
            error response
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Unknown access grant")
            return Response(err.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        authzreq = json.loads(_info['authzreq'])
        if 'code_verifier' in areq:
            try:
                _method = authzreq['code_challenge_method']
            except KeyError:
                _method = 'S256'

            resp = self.verify_code_challenge(areq['code_verifier'],
                                              authzreq['code_challenge'],
                                              _method)
            # verify_code_challenge returns True on success and an error
            # response otherwise; only the latter ends the processing.
            if resp is not True:
                return resp

        if 'state' in areq:
            if self.sdb[areq['code']]['state'] != areq['state']:
                logger.error('State value mismatch')
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one. Leaving it
        # out is a mismatch too.
        if "redirect_uri" in _info and (
                "redirect_uri" not in areq
                or areq["redirect_uri"] != _info["redirect_uri"]):
            logger.error('Redirect_uri mismatch')
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(areq["code"],
                                               issue_refresh=issue_refresh)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Access grant used")
            return Response(err.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(),
                        content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def client_credentials_grant_type(self, areq):
        """
        Handle a token request that uses the ``client_credentials`` grant.

        An access token is obtained from the token handler. A refresh token
        is requested as well; if the handler raises ``NotAllowed`` the
        response is built without one.

        :param areq: The access token request; ``client_id``, ``scope`` and
            ``state`` are read from it
        :return: A Response instance carrying the JSON access token response
        """
        _at = self.token_handler.get_access_token(
            areq['client_id'],
            scope=areq['scope'],
            grant_type='client_credentials')
        _info = self.token_handler.token_factory.get_info(_at)
        try:
            _rt = self.token_handler.get_refresh_token(self.baseurl,
                                                       _info['access_token'],
                                                       'client_credentials')
        except NotAllowed:
            # No refresh token for this grant: answer with the access token
            # only.
            atr = self.do_access_token_response(_at, _info, areq['state'])
        else:
            atr = self.do_access_token_response(_at, _info, areq['state'], _rt)

        # Responses carrying tokens must not be cached (RFC 6749, section
        # 5.1); the authorization code grant handler sends the same headers.
        return Response(atr.to_json(),
                        content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def password_grant_type(self, areq):
        """
        Handle a token request that uses the ``password`` grant.

        An access token is obtained from the token handler. A refresh token
        is requested as well; if the handler raises ``NotAllowed`` the
        response is built without one.

        :param areq: The access token request; ``client_id``, ``scope`` and
            ``state`` are read from it
        :return: A Response instance carrying the JSON access token response
        """
        _at = self.token_handler.get_access_token(areq['client_id'],
                                                  scope=areq['scope'],
                                                  grant_type='password')
        _info = self.token_handler.token_factory.get_info(_at)
        try:
            _rt = self.token_handler.get_refresh_token(self.baseurl,
                                                       _info['access_token'],
                                                       'password')
        except NotAllowed:
            # No refresh token for this grant: answer with the access token
            # only.
            atr = self.do_access_token_response(_at, _info, areq['state'])
        else:
            atr = self.do_access_token_response(_at, _info, areq['state'], _rt)

        # Responses carrying tokens must not be cached (RFC 6749, section
        # 5.1); the authorization code grant handler sends the same headers.
        return Response(atr.to_json(),
                        content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def refresh_token_grant_type(self, areq):
        """
        Handle a token request that uses the ``refresh_token`` grant.

        :param areq: The access token request
        :return: A Response instance carrying the JSON access token response
        """
        # NOTE(review): the token handed to refresh_access_token is read from
        # areq['access_token'], not areq['refresh_token'] -- confirm this is
        # what the token handler expects.
        at = self.token_handler.refresh_access_token(self.baseurl,
                                                     areq['access_token'],
                                                     'refresh_token')

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **at))

        # Responses carrying tokens must not be cached (RFC 6749, section
        # 5.1); the authorization code grant handler sends the same headers.
        return Response(atr.to_json(),
                        content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def token_endpoint(self, authn="", **kwargs):
        """
        Entry point for access token requests; dispatches on the grant type.

        :param authn: Client authentication information
        :param kwargs: Keyword arguments; ``request`` holds the urlencoded
            request body
        :return: The Response produced by the grant type handler, or an
            error Response when client authentication fails
        :raises UnSupported: If the grant type has no handler
        """
        logger.debug("- token -")
        raw_request = kwargs["request"]
        logger.debug("body: %s" % raw_request)

        areq = AccessTokenRequest().deserialize(raw_request, "urlencoded")

        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as failure:
            logger.error(failure)
            error_msg = TokenErrorResponse(error="unauthorized_client",
                                           error_description="%s" % failure)
            return Response(error_msg.to_json(),
                            content="application/json",
                            status_code=401)

        logger.debug("AccessTokenRequest: %s" % areq)

        # Grant type -> name of the method that serves it, checked in order.
        handlers = (
            ("authorization_code", "code_grant_type"),
            ('client_credentials', "client_credentials_grant_type"),
            ('password', "password_grant_type"),
            ('refresh_token', "refresh_token_grant_type"),
        )
        requested = areq["grant_type"]
        for grant_name, method_name in handlers:
            if requested == grant_name:
                return getattr(self, method_name)(areq)

        raise UnSupported('grant_type: {}'.format(requested))

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        Export this provider's keys and remember where they are published.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: Nothing; the value returned by ``key_export`` (described
            by the previous docstring as a URL the RP can use to download
            the key) is stored in ``self.jwks_uri``
        """
        export_options = {"fqdn": self.hostname, "sig": sig, "enc": enc}
        published_at = key_export(self.baseurl, local_path, vault,
                                  self.keyjar, **export_options)
        self.jwks_uri = published_at

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        # simple rules: if client_id in azp or aud it's allow to introspect
        # to revoke it has to be in azr
        allow = False
        if endpoint == 'revocation_endpoint':
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif len(token_info['aud']) == 1 and token_info['aud'] == [
                    client_id
            ]:
                allow = True
        else:  # has to be introspection endpoint
            if 'azr' in token_info and client_id == token_info['azr']:
                allow = True
            elif 'aud' in token_info:
                if client_id in token_info['aud']:
                    allow = True
        return allow

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token named in a request.

        When the request has no ``token_type_hint`` the token is tried
        first as an access token and then as a refresh token.

        :param authn: Client authentication information
        :param req: The request; ``token`` and optionally
            ``token_type_hint`` are read from it
        :param endpoint: Name of the endpoint on whose behalf this is done
        :return: Either a Response (authentication failed, token lookup
            failed, or the client may not access the token) or a tuple
            ``(client_id, token_type, token info)``
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as failure:
            logger.error(failure)
            error_msg = TokenErrorResponse(error="unauthorized_client",
                                           error_description="%s" % failure)
            return Response(error_msg.to_json(),
                            content="application/json",
                            status="401 Unauthorized")

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        try:
            token_type = req['token_type_hint']
        except KeyError:
            hinted = False
        else:
            hinted = True

        if hinted:
            try:
                _info = self.sdb.token_factory[token_type].get_info(
                    req['token'])
            except Exception:
                return self._return_inactive()
        else:
            # No hint given: probe the known token types in order.
            for candidate in ('access_token', 'refresh_token'):
                try:
                    _info = self.sdb.token_factory[candidate].get_info(
                        req['token'])
                except Exception:
                    continue
                token_type = candidate
                break
            else:
                return self._return_inactive()

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    @staticmethod
    def _return_inactive():
        """
        Build the JSON response that reports a token as not active.

        :return: A Response instance
        """
        payload = TokenIntrospectionResponse(active=False).to_json()
        return Response(payload, content="application/json")

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Token revocation (RFC 7009): lets a client invalidate an access or
        refresh token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs: Extra keyword arguments (not used here)
        :return: A Response instance
        """
        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, trr, 'revocation_endpoint')
        if isinstance(outcome, Response):
            # The lookup step already produced the reply to send.
            return outcome

        client_id, token_type, _info = outcome

        logger.info('{} token revocation: {}'.format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()

        return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Token introspection (RFC 7662).

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs: Extra keyword arguments (not used here)
        :return: A Response instance
        """
        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        outcome = self.get_token_info(authn, tir, 'introspection_endpoint')
        if isinstance(outcome, Response):
            # The lookup step already produced the reply to send.
            return outcome

        client_id, token_type, _info = outcome

        logger.info('{} token introspection: {}'.format(
            client_id, tir.to_dict()))

        is_active = self.sdb.token_factory[token_type].is_valid(_info)
        ir = TokenIntrospectionResponse(active=is_active, **_info.to_dict())

        ir.weed()

        return Response(ir.to_json(), content="application/json")
예제 #20
0
class Provider(provider.Provider):
    """
    A OAuth2 RP that knows all the OAuth2 extensions I've implemented
    """

    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey="",
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
    ):
        """
        Set up the provider.

        ``name`` through ``ca_bundle`` are handed to the base class
        initializer unchanged, except that a trailing "/" is appended to
        ``name`` when missing.

        :param seed: Passed to ``secret()`` when client secrets are created
        :param client_authn_methods: Dictionary of client authentication
            methods
        :param authn_at_registration: Key in ``client_authn_methods`` of the
            method used to authenticate registration requests; empty means
            no authentication at registration
        :param client_info_url: Used when building a client's
            ``registration_client_uri``
        :param secret_lifetime: Added to the issue time to compute
            ``client_secret_expires_at``
        :param jwks_uri: Initial value of ``self.jwks_uri``
        :param keyjar: Key jar to use; a new ``KeyJar`` is created when None
        :param capabilities: Requested capabilities; when empty the result
            of ``provider_features()`` is used
        :param verify_ssl: Passed to the ``KeyJar`` created when ``keyjar``
            is None
        :param baseurl: Base URL of this provider
        :param hostname: Host name; defaults to ``gethostname()``
        :raises UnknownAuthnMethod: If ``authn_at_registration`` is not one
            of the known client authentication methods
        """

        if not name.endswith("/"):
            name += "/"
        provider.Provider.__init__(
            self, name, sdb, cdb, authn_broker, authz, client_authn, symkey, urlmap, iv, default_scope, ca_bundle
        )

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # client_authn_methods defaults to None; treat that as "no
            # methods known" so the caller gets UnknownAuthnMethod instead
            # of a TypeError from the membership test.
            if authn_at_registration not in (client_authn_methods or {}):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl

        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            # NOTE(review): verify_capabilities() reports problems through
            # its return value, which is not acted upon here.
            self.verify_capabilities(capabilities)
            self.capabilities = ProviderConfigurationResponse(**capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl
        self.hostname = hostname or gethostname()
        self.kid = {"sig": {}, "enc": {}}

    @staticmethod
    def _uris_to_tuples(uris):
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys into the key jar and add its secret as
        symmetric keys.

        :param request: The client's request, passed on to the key jar's
            ``load_keys``
        :param client_id: The client identifier the keys are stored under
        :param client_secret: The client secret; added as ``oct`` keys with
            use "ver" and "sig"
        :return: A "400 Bad Request" Response if loading the keys failed,
            otherwise None
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                listing = ",".join(["%s" % bundle for bundle in self.keyjar[client_id]])
                logger.debug("keys for %s: [%s]" % (client_id, listing))
            except KeyError:
                # Nothing stored for this client; there is nothing to log.
                pass
        except Exception as failure:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            logger.error("%s", failure)
            reply = ClientRegistrationError(
                error="invalid_configuration_parameter", error_description="%s" % failure
            )
            return Response(reply.to_json(), content="application/json", status="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        secret_bundle = KeyBundle(
            [{"kty": "oct", "key": client_secret, "use": usage} for usage in ("ver", "sig")]
        )
        try:
            self.keyjar[client_id].append(secret_bundle)
        except KeyError:
            self.keyjar[client_id] = [secret_bundle]

    def create_new_client(self, request):
        """
        Register a new client and store its information in the client
        database.

        :param request: The Client registration request
        :return: The client_id
        """

        _cinfo = request.to_dict()

        # create new id and secret; keep drawing until the id is unused
        while True:
            new_id = rndstr(12)
            if new_id not in self.cdb:
                break

        _cinfo.update(
            {
                "client_id": new_id,
                "client_secret": secret(self.seed, new_id),
                "client_id_issued_at": utc_time_sans_frac(),
                "client_secret_expires_at": utc_time_sans_frac() + self.secret_lifetime,
            }
        )

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            uri_parts = (self.name, self.client_info_url, ClientInfoEndpoint.etype, new_id)
            _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % uri_parts

        if "redirect_uris" in request:
            # Stored as (base, query) tuples rather than as plain strings.
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.load_keys(request, new_id, _cinfo["client_secret"])
        self.cdb[new_id] = _cinfo

        return new_id

    def client_info(self, client_id):
        """
        Return the stored information about a client.

        :param client_id: The client identifier
        :return: A Response instance with the client information as JSON
        """
        stored = self.cdb[client_id].copy()
        if "redirect_uris" in stored:
            # Stored as (base, query) tuples; present them as URIs again.
            stored["redirect_uris"] = self._tuples_to_uris(stored["redirect_uris"])

        body = ClientInfoResponse(**stored).to_json()
        return Response(body, content="application/json")

    def client_info_update(self, client_id, request):
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                try:
                    assert value == _cinfo[key]
                except AssertionError:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                "client_id_issued_at",
                "client_secret_expires_at",
                "registration_access_token",
                "registration_client_uri",
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :return:
        """

        if not client_id:
            client_id = get_client_id(self.cdb, areq, environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: Keyword arguments; ``request`` holds the JSON
            registration request and ``environ`` the WSGI environ (only
            read when authentication at registration is configured)
        :return: A Response instance
        """

        reg_request = RegistrationRequest().deserialize(kwargs["request"], "json")
        try:
            reg_request.verify()
        except (InvalidRedirectUri, MissingPage, VerificationError) as err:
            if isinstance(err, InvalidRedirectUri):
                error_code = "invalid_redirect_uri"
            else:
                error_code = "invalid_client_metadata"
            problem = ClientRegistrationError(error=error_code, error_description="%s" % err)
            return BadRequest(problem.to_json(), content="application/json")

        # authenticated client
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs["environ"], reg_request, self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        new_client_id = self.create_new_client(reg_request)
        return self.client_info(new_client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods: GET reads, PUT updates and DELETE removes the client
        information.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments; ``query`` (the query string with
            the ``client_id``), ``environ`` and ``request`` are read
        :return: A Response instance; None for any other HTTP method
        """

        _query = parse_qs(kwargs["query"])
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # Explicit membership test rather than ``assert``: assertions are
        # removed when Python runs with -O, which would let requests for
        # unknown clients through.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs["request"])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri", error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()
        # NOTE(review): other HTTP methods fall through and yield None;
        # confirm the caller copes with that.

    def provider_features(self, pcr_class=ProviderConfigurationResponse):
        """
        Describe what this server implementation is capable of.

        Starts from ``CAPABILITIES`` and adds supported claims and scopes,
        the signing and encryption algorithms, and, when an authentication
        broker is configured and reports any, the supported ACR values.

        :param pcr_class: Class used to build the description
        :return: ProviderConfigurationResponse instance
        """

        features = pcr_class(**CAPABILITIES)

        all_claims = [claim for claims in SCOPE2CLAIMS.values() for claim in claims]
        features["claims_supported"] = list(set(all_claims))

        features["scopes_supported"] = list(SCOPE2CLAIMS.keys()) + ["openid"]

        usages = ("userinfo", "id_token", "request_object")

        signing_algs = list(jws.SIGNER_ALGS.keys())
        for usage in usages:
            features["%s_signing_alg_values_supported" % usage] = signing_algs

        # Remove 'none' for token_endpoint_auth_signing_alg_values_supported
        # since it is not allowed
        client_auth_algs = list(signing_algs)
        client_auth_algs.remove("none")
        features["token_endpoint_auth_signing_alg_values_supported"] = client_auth_algs

        key_algs = jwe.SUPPORTED["alg"]
        for usage in usages:
            features["%s_encryption_alg_values_supported" % usage] = key_algs

        content_encs = jwe.SUPPORTED["enc"]
        for usage in usages:
            features["%s_encryption_enc_values_supported" % usage] = content_encs

        # acr_values
        if self.authn_broker:
            acr_values = self.authn_broker.getAcrValuesString()
            if acr_values is not None:
                features["acr_values_supported"] = acr_values

        return features

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually
        can be done by this implementation.

        :param capabilities: The asked for capabilities as a dictionary
        or a ProviderConfigurationResponse instance. The later can be
        treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if isinstance(val, six.string_types):
                try:
                    if val in _pinfo[key]:
                        continue
                    else:
                        return False
                except KeyError:
                    return False

        return True

    def create_providerinfo(self, pcr_class=ProviderConfigurationResponse, setup=None):
        """
        Dynamically create the provider info response
        :param pcr_class:
        :param setup:
        :return:
        """

        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            # _log_info("# %s, %s" % (endp, endp.name))
            _provider_info["{}_endpoint".format(endp.etype)] = os.path.join(self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider configuration information.

        :param kwargs: Keyword arguments; when ``handle`` (a
            ``(key, timestamp)`` pair) is present and the key is wrapped in
            ``STR``, a cookie is added to the response
        :return: A Response instance with the provider info as JSON, or,
            if anything fails, a Response carrying the formatted traceback
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if "handle" in kwargs:
                (key, timestamp) = kwargs["handle"]
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo", self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json", headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # "html/text" is not a valid media type; type and subtype were
            # swapped.
            # NOTE(review): this hands the traceback to the requester --
            # consider a generic error message instead.
            resp = Response(message, content="text/html")

        return resp

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens.

        Only the ``authorization_code`` grant type is served. All request
        checks are explicit tests that answer with an error response; they
        used to be ``assert`` statements, which are removed when Python
        runs with -O and otherwise surfaced as an AssertionError.

        :param authn: Client authentication information
        :param kwargs: Keyword arguments; ``request`` holds the urlencoded
            request body
        :return: A Response instance: the access token response, or a
            "401 Unauthorized" error response
        """

        def _error(error, description):
            # Every failure in this endpoint is reported the same way.
            msg = TokenErrorResponse(error=error, error_description=description)
            return Response(msg.to_json(), content="application/json", status="401 Unauthorized")

        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as err:
            return _error("unauthorized_client", "%s" % err)

        logger.debug("AccessTokenRequest: %s" % areq)

        if areq["grant_type"] != "authorization_code":
            return _error("invalid_request", "Wrong grant type")

        # The code must be known to the session database.
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            return _error("invalid_grant", "Unknown access grant")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one
        # (RFC 6749, section 5.2: a mismatch is an invalid grant).
        if "redirect_uri" in _info:
            if "redirect_uri" not in areq or areq["redirect_uri"] != _info["redirect_uri"]:
                logger.error("Redirect_uri mismatch")
                return _error("invalid_grant", "Redirect_uri mismatch")

        try:
            _tinfo = self.sdb.upgrade_to_token(areq["code"], issue_refresh=True)
        except AccessCodeUsed:
            return _error("invalid_grant", "Access grant used")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json")

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        Export this provider's keys and remember where they are published.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: Nothing; the value returned by ``key_export`` (described
            by the previous docstring as a URL the RP can use to download
            the key) is stored in ``self.jwks_uri``
        """
        export_options = {"fqdn": self.hostname, "sig": sig, "enc": enc}
        published_at = key_export(self.baseurl, local_path, vault, self.keyjar, **export_options)
        self.jwks_uri = published_at