def test_provider(self):
    """Keys published at a provider's jwks_uri are loaded into the KeyJar.

    The JWKS document is served by a ``responses`` mock. Both
    ``load_keys`` and the key lookup run inside the mock context so that
    any HTTP request KeyJar makes for the jwks_uri, wherever it happens,
    is answered by the mock rather than the real network.
    """
    jwks_url = "https://connect-op.herokuapp.com/jwks.json"
    issuer = "https://connect-op.heroku.com"
    provider_info = {"jwks_uri": jwks_url}

    ks = KeyJar()
    with responses.RequestsMock() as rsps:
        rsps.add(
            responses.GET,
            jwks_url,
            json={
                "keys": [
                    {
                        "kty": "RSA",
                        "e": "AQAB",
                        "n": "pKybs0WaHU_y4cHxWbm8Wzj66HtcyFn7Fh3n-99qTXu5yNa30MRYIYfSDwe9JVc1JUoGw41yq2StdGBJ"
                             "40HxichjE-Yopfu3B58QlgJvToUbWD4gmTDGgMGxQxtv1En2yedaynQ73sDpIK-12JJDY55pvf-PCiSQ9OjxZ"
                             "LiVGKlClDus44_uv2370b9IN2JiEOF-a7JBqaTEYLPpXaoKWDSnJNonr79tL0T7iuJmO1l705oO3Y0TQ-INLY"
                             "6jnKG_RpsvyvGNnwP9pMvcP1phKsWZ10ofuuhJGRp8IxQL9RfzT87OvF0RBSO1U73h09YP-corWDsnKIi6Tbz"
                             "RpN5YDw",
                        "use": "sig",
                        "kid": "default",
                    }
                ]
            },
        )
        # Load and inspect inside the mock so the JWKS fetch is intercepted.
        ks.load_keys(provider_info, issuer)
        assert ks[issuer][0].keys()
def test_provider(self):
    """Keys referenced by a provider's ``jwks_uri`` end up in the KeyJar.

    NOTE(review): no HTTP mocking is set up here, so the JWKS document is
    presumably fetched from the live URL; consider mocking the request.
    """
    issuer = "https://connect-op.heroku.com"
    jwks_location = "https://connect-op.herokuapp.com/jwks.json"

    key_jar = KeyJar()
    key_jar.load_keys({"jwks_uri": jwks_location}, issuer)

    first_bundle = key_jar[issuer][0]
    assert first_bundle.keys()
def test_provider(self):
    """Loading keys via ``jwks_uri`` yields a non-empty first key bundle.

    NOTE(review): nothing intercepts HTTP here, so this presumably
    depends on the remote JWKS URL being reachable.
    """
    owner = "https://connect-op.heroku.com"
    jar = KeyJar()
    jar.load_keys(dict(jwks_uri="https://connect-op.herokuapp.com/jwks.json"),
                  owner)
    bundles = jar[owner]
    assert bundles[0].keys()
class Provider(provider.Provider):
    """
    An OAuth2 provider (authorization server) that adds the OAuth2
    extensions implemented here: dynamic client registration and client
    info management, token revocation (RFC 7009), token introspection
    (RFC 7662) and PKCE (RFC 7636) checks at the token endpoint.
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed=b"", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400, jwks_uri='', keyjar=None,
                 capabilities=None, verify_ssl=True, baseurl='', hostname='',
                 config=None, behavior=None):
        """
        :param name: Issuer name; a trailing "/" is appended if missing
        :param client_authn_methods: Dictionary of client authentication
            methods, keyed by method name
        :param authn_at_registration: Name of the client authentication
            method required at the registration endpoint ('' means none);
            must be a key of client_authn_methods
        :param secret_lifetime: Seconds until an issued client_secret expires
        :param capabilities: Provider configuration values that override the
            defaults produced by provider_features()
        :param behavior: Dictionary of behavior settings; its
            'client_registration' entry is checked with verify_correct()
            when clients register
        The remaining parameters are stored as attributes or passed on to
        the base provider.
        :raises UnknownAuthnMethod: if authn_at_registration is not one of
            client_authn_methods
        """
        if not name.endswith("/"):
            name += "/"
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle)

        # Build a new list instead of extending in place: if the base class
        # keeps ``endp`` as a class-level list, extend() would mutate it for
        # every instance and keep growing it on each instantiation.
        self.endp = list(self.endp) + [RegistrationEndpoint,
                                       ClientInfoEndpoint,
                                       RevocationEndpoint,
                                       IntrospectionEndpoint]

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        # ``client_authn_methods`` defaults to None; treat that as "no
        # methods" rather than failing with a TypeError on the lookup.
        if authn_at_registration and \
                authn_at_registration not in (client_authn_methods or {}):
            raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config
        self.behavior = behavior or {}

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a (base, query) tuple; query is "" if absent."""
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            tup.append((base, query or ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        """Rejoin (base, query) tuples, as built by _uris_to_tuples, into URIs."""
        return ["%s?%s" % (url, query) if query else url
                for url, query in items]

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys into the keyjar and add its client_secret as a
        symmetric key (usable for both "ver" and "sig").

        :param request: The registration request the keys are loaded from
        :param client_id: The client the keys belong to
        :param client_secret: The client's secret
        :return: None on success, or a 400 Response carrying a
            ClientRegistrationError if the keys could not be loaded
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug("keys for %s: [%s]" % (
                    client_id,
                    ",".join(["%s" % x for x in self.keyjar[client_id]])))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            logger.error("%s", err)
            err = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        _kc = KeyBundle([{"kty": "oct", "key": client_secret, "use": "ver"},
                         {"kty": "oct", "key": client_secret, "use": "sig"}])
        try:
            self.keyjar[client_id].append(_kc)
        except KeyError:
            self.keyjar[client_id] = [_kc]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Apply each restriction to the client info.

        :param cinfo: Client information dictionary
        :param restrictions: Dictionary mapping restriction names (resolved
            with restrict.factory) to their arguments
        :raises RestrictionError: with the first non-empty result returned
            by a restriction function
        """
        for fname, arg in restrictions.items():
            func = restrict.factory(fname)
            res = func(arg, cinfo)
            if res:
                raise RestrictionError(res)

    def create_new_client(self, request, restrictions):
        """
        Register a new client: check the request against this provider's
        capabilities and configured registration behavior, mint a client_id
        and client_secret, load the client's keys and store the client info.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client
        :return: The client_id
        :raises CapabilitiesMisMatch: if the request asks for something this
            provider does not support
        :raises RestrictionError: if a 'client_registration' behavior rule
            is violated
        """
        # NOTE(review): ``restrictions`` is not used by this implementation.
        _cinfo = request.to_dict()

        self.match_client_request(_cinfo)

        # create new id and secret
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        _now = utc_time_sans_frac()
        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = _now
        _cinfo["client_secret_expires_at"] = _now + self.secret_lifetime

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % (
                self.name, self.client_info_url, ClientInfoEndpoint.etype,
                _id)

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        # Check the registration behavior rules before loading keys so a
        # rejected request does not leave keys for an unregistered client_id
        # in the keyjar.
        try:
            _behav = self.behavior['client_registration']
        except KeyError:
            pass
        else:
            self.verify_correct(_cinfo, _behav)

        # NOTE(review): load_keys() returns an error Response if the client's
        # keys cannot be loaded; that result is ignored here, so the client
        # is still registered.
        self.load_keys(request, _id, _cinfo["client_secret"])

        self.cdb[_id] = _cinfo

        return _id

    def match_client_request(self, request):
        """
        Check that every preference in a registration request is among this
        provider's capabilities (mapped through PREFERENCE2PROVIDER).

        response_types are compared as sets of space-separated values, so
        "code id_token" matches "id_token code".

        :param request: Client registration request as a dictionary
        :raises CapabilitiesMisMatch: on the first unsupported preference
        """
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            if _pref not in request:
                continue

            if _pref == "response_types":
                for val in request[_pref]:
                    p = set(val.split(" "))
                    if not any(p == set(cv.split(' '))
                               for cv in self.capabilities[_prov]):
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(_pref))
            elif isinstance(request[_pref], six.string_types):
                if request[_pref] not in self.capabilities[_prov]:
                    raise CapabilitiesMisMatch(
                        'Not allowed {}'.format(_pref))
            elif not set(request[_pref]).issubset(
                    set(self.capabilities[_prov])):
                raise CapabilitiesMisMatch('Not allowed {}'.format(_pref))

    def client_info(self, client_id):
        """
        Return the stored information about a client as a JSON
        ClientInfoResponse, with redirect_uris rejoined into plain URIs.
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        msg = ClientInfoResponse(**_cinfo)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace a client's stored metadata with the content of an update
        request.

        client_id and client_secret cannot be changed: if present in the
        request they must equal the stored values. The client credentials
        and server-managed fields are kept even when the request omits
        them. Every other stored field that is absent from the request is
        removed.

        :param client_id: The client being updated
        :param request: The ClientUpdateRequest
        :raises ModificationForbidden: if the request tries to change
            client_id or client_secret
        """
        # Fields the client cannot drop by omitting them from the update.
        _keep = ("client_id", "client_secret", "client_id_issued_at",
                 "client_secret_expires_at", "registration_access_token",
                 "registration_client_uri")

        _cinfo = self.cdb[client_id].copy()

        for key, value in request.items():
            if key in ("client_secret", "client_id"):
                # assure it's the same; explicit check (not assert) so it
                # is still enforced under ``python -O``
                if key not in _cinfo or value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in [k for k in _cinfo if k not in _keep and k not in request]:
            del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Authenticate a client with one of the configured methods.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method, a key of
            self.client_authn_methods
        :param client_id: The client; derived from the request and the
            Authorization header if not given
        :return: Whatever the authentication method's verify() returns
        :raises UnSupported: if authn_method is not configured
        """
        if not client_id:
            client_id = get_client_id(self.cdb, areq,
                                      environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """
        Turn a parsed software statement into client restrictions.

        This implementation returns no restrictions.
        """
        return {}

    def registration_endpoint(self, **kwargs):
        """
        Handle a client registration request.

        :param kwargs: 'request' holds the JSON registration request;
            'environ' the WSGI environ (used when client authentication
            at registration is required)
        :return: A Response instance
        """
        _request = RegistrationRequest().deserialize(kwargs['request'], "json")
        try:
            _request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            msg = ClientRegistrationError(error="invalid_redirect_uri",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")
        except (MissingPage, VerificationError) as err:
            msg = ClientRegistrationError(error="invalid_client_metadata",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        # authenticated client
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], _request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}
        if 'parsed_software_statement' in _request:
            for ss in _request['parsed_software_statement']:
                client_restrictions.update(self.consume_software_statement(ss))
            del _request['software_statement']
            del _request['parsed_software_statement']

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            msg = ClientRegistrationError(error="invalid_client_metadata",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of
        different HTTP methods: GET reads, PUT updates and DELETE removes
        the client identified by the 'client_id' query parameter.

        :param method: HTTP method used for the request
        :param kwargs: 'query' (the query string), 'environ' (WSGI environ)
            and, for PUT, 'request' (the JSON update request)
        :return: A Response instance
        """
        _query = parse_qs(kwargs['query'])
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # Explicit check (not assert) so it is still enforced under -O.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'], kwargs['request'],
                               "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)

        if method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()

        if method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

        # Any other HTTP method used to fall through and return None.
        return BadRequest("Unsupported HTTP method")

    def provider_features(self, pcr_class=ASConfigurationResponse,
                          provider_config=None):
        """
        Specifies what the server capabilities are.

        :param pcr_class: Class of the returned configuration, built from
            CAPABILITIES
        :param provider_config: Optional values overriding the defaults
        :return: pcr_class instance
        """
        _provider_info = pcr_class(**CAPABILITIES)
        _provider_info["scopes_supported"] = list(SCOPE2CLAIMS.keys())

        # 'none' is not allowed for
        # token_endpoint_auth_signing_alg_values_supported; filter it out
        # (instead of list.remove, which fails if it is absent).
        _provider_info["token_endpoint_auth_signing_alg_values_supported"] = [
            alg for alg in jws.SIGNER_ALGS.keys() if alg != 'none']

        if provider_config:
            _provider_info.update(provider_config)

        return _provider_info

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually can be
        done by this implementation.

        Only string-valued capabilities are checked; values of other types
        are accepted as-is.

        :param capabilities: The asked for capabilities as a dictionary
            or a ProviderConfigurationResponse instance. The later can be
            treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if not isinstance(val, six.string_types):
                continue
            try:
                if val not in _pinfo[key]:
                    return False
            except KeyError:
                return False

        return True

    def create_providerinfo(self, pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response

        :param pcr_class: Class whose c_param keys select which entries of
            setup are copied in
        :param setup: Optional dictionary of extra provider info values
        :return: The provider info (self.capabilities, updated in place)
        """
        # NOTE(review): this updates self.capabilities itself rather than a
        # copy, so the endpoint/issuer entries persist between calls.
        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            # NOTE(review): os.path.join is used to build a URL here; it
            # drops baseurl if endp.url starts with "/".
            _provider_info['{}_endpoint'.format(endp.etype)] = os.path.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider configuration as JSON.

        :param kwargs: may contain 'handle', a (key, timestamp) pair; when
            key is wrapped in STR a cookie header is added
        :return: A Response instance
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                (key, _timestamp) = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            # NOTE(review): this sends the formatted traceback to the
            # client, and "html/text" is not a registered MIME type;
            # consider a generic error response instead.
            resp = Response(message, content="html/text")

        return resp

    @staticmethod
    def verify_code_challenge(code_verifier, code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        The challenge is recomputed as b64e of the hexdigest() of
        CC_METHOD[code_challenge_method] applied to the verifier.
        NOTE(review): RFC 7636 S256 encodes the raw digest, not its hex
        form; the client-side add_code_challenge uses the same hex
        variant, so both sides must change together.

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into CC_METHOD
        :return: True if the check passes, otherwise an error Response
        """
        try:
            _transform = CC_METHOD[code_challenge_method]
        except KeyError:
            # An unknown method is a client error, not a server crash.
            logger.error('Unsupported PKCE code_challenge_method: %s',
                         code_challenge_method)
            err = TokenErrorResponse(
                error="invalid_request",
                error_description="Unsupported code_challenge_method")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        _h = _transform(code_verifier.encode()).hexdigest()
        _cc = b64e(_h.encode())
        if _cc.decode() != code_challenge:
            logger.error('PCKE Code Challenge check failed')
            err = TokenErrorResponse(error="invalid_request",
                                     error_description="PCKE check failed")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")
        return True

    @staticmethod
    def _token_error_response(error, description, status="401 Unauthorized"):
        """Wrap a TokenErrorResponse in a JSON Response."""
        err = TokenErrorResponse(error=error, error_description=description)
        return Response(err.to_json(), content="application/json",
                        status=status)

    def _check_pkce(self, areq, authzreq):
        """
        Apply the PKCE (RFC 7636) check to a token request.

        :param areq: The AccessTokenRequest
        :param authzreq: The original authorization request, as a dict
        :return: An error Response, or None if the request passes
        """
        if 'code_challenge' in authzreq:
            # RFC 7636 section 4.6: a challenge sent with the authorization
            # request must be answered with a verifier here; otherwise PKCE
            # could be bypassed by simply omitting code_verifier.
            if 'code_verifier' not in areq:
                return self._token_error_response("invalid_request",
                                                  "Missing code_verifier")
            resp = self.verify_code_challenge(
                areq['code_verifier'], authzreq['code_challenge'],
                authzreq.get('code_challenge_method', 'S256'))
            # verify_code_challenge returns True on success and a Response
            # on failure, so test for a Response rather than truthiness.
            if isinstance(resp, Response):
                return resp
        elif 'code_verifier' in areq:
            return self._token_error_response(
                "invalid_request", "code_verifier without code_challenge")
        return None

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens.

        Handles the authorization_code grant. It authenticates the client
        and looks up the grant. It then checks PKCE, the grant type, an
        optional state_hash, the scope and the redirect_uri, and finally
        upgrades the code to a token.

        :param authn: Client authentication information
        :param kwargs: 'request' holds the urlencoded request body
        :return: A Response instance
        """
        _sdb = self.sdb

        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        # NOTE(review): the authenticated client is not compared with the
        # client the grant was issued to.
        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as err:
            return self._token_error_response("unauthorized_client",
                                              "%s" % err)

        logger.debug("AccessTokenRequest: %s" % areq)

        # assert that the code is valid
        try:
            _info = _sdb[areq["code"]]
        except KeyError:
            return self._token_error_response("invalid_grant",
                                              "Unknown access grant")

        authzreq = json.loads(_info['authzreq'])

        resp = self._check_pkce(areq, authzreq)
        if resp is not None:
            return resp

        # Explicit checks (not assert) so they still apply under -O.
        if "grant_type" not in areq or \
                areq["grant_type"] != "authorization_code":
            return self._token_error_response("invalid_request",
                                              "Wrong grant type")

        if 'state_hash' in areq:
            # have to get the token to get at the state
            code = areq['code']
            shash = base64.urlsafe_b64encode(
                hashlib.sha256(
                    self.sdb[code]['state'].encode('utf8')).digest())

            if shash.decode('ascii') != areq['state_hash']:
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one.
        if "redirect_uri" in _info:
            if "redirect_uri" not in areq or \
                    areq["redirect_uri"] != _info["redirect_uri"]:
                return self._token_error_response("invalid_grant",
                                                  "Redirect URI mismatch")

        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = _sdb.upgrade_to_token(areq["code"],
                                           issue_refresh=issue_refresh)
        except AccessCodeUsed:
            return self._token_error_response("invalid_grant",
                                              "Access grant used")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json")

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        my keys

        Exports the keys with key_export and stores the resulting URL in
        self.jwks_uri.

        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: A URL the RP can use to download the key.
        """
        self.jwks_uri = key_export(self.baseurl, local_path, vault,
                                   self.keyjar, fqdn=self.hostname, sig=sig,
                                   enc=enc)
        return self.jwks_uri

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        """
        Decide whether a client may use a token at the revocation or
        introspection endpoint.

        A client named in the token's 'azr' entry may use either endpoint.
        For introspection, a client listed in 'aud' is also allowed.
        NOTE(review): 'azr' rather than the JWT 'azp' claim is checked;
        confirm it matches what the token factory records.

        :return: True if access is allowed, otherwise False
        """
        if 'azr' in token_info and client_id == token_info['azr']:
            return True
        if endpoint == 'revocation_endpoint':
            return False
        # has to be introspection endpoint
        return 'aud' in token_info and client_id in token_info['aud']

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token named in a revocation
        or introspection request.

        Without a token_type_hint the token is tried as an access token
        first, then as a refresh token.

        :param authn: Client authentication information
        :param req: The revocation/introspection request
        :param endpoint: 'revocation_endpoint' or 'introspection_endpoint'
        :return: (client_id, token_type, token info) on success, otherwise
            an error Response
        :raises KeyError: propagated from the token lookup, e.g. for an
            unknown token_type_hint
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            return self._token_error_response("unauthorized_client",
                                              "%s" % err)

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        try:
            token_type = req['token_type_hint']
        except KeyError:
            # Use get_info for both factories, like the hinted path below.
            try:
                _info = self.sdb.token_factory['access_token'].get_info(
                    req['token'])
            except KeyError:
                _info = self.sdb.token_factory['refresh_token'].get_info(
                    req['token'])
                token_type = 'refresh_token'
            else:
                token_type = 'access_token'
        else:
            _info = self.sdb.token_factory[token_type].get_info(req['token'])

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7009 allows a client to invalidate an access or
        refresh token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs:
        :return: A Response instance
        """
        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, trr, 'revocation_endpoint')

        if isinstance(resp, Response):
            return resp

        client_id, token_type, _info = resp

        logger.info('{} token revocation: {}'.format(client_id,
                                                     trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()
        else:
            return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7662

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs:
        :return: A Response instance
        """
        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, tir, 'introspection_endpoint')

        if isinstance(resp, Response):
            return resp

        client_id, token_type, _info = resp

        logger.info('{} token introspection: {}'.format(client_id,
                                                        tir.to_dict()))

        ir = TokenIntrospectionResponse(
            active=self.sdb.token_factory[token_type].is_valid(_info),
            **_info.to_dict())

        ir.weed()

        return Response(ir.to_json(), content="application/json")
class Client(PBase): _endpoints = ENDPOINTS def __init__(self, client_id=None, ca_certs=None, client_authn_method=None, keyjar=None, verify_ssl=True, config=None, client_cert=None): """ :param client_id: The client identifier :param ca_certs: Certificates used to verify HTTPS certificates :param client_authn_method: Methods that this client can use to authenticate itself. It's a dictionary with method names as keys and method classes as values. :param verify_ssl: Whether the SSL certificate should be verified. :return: Client instance """ PBase.__init__(self, ca_certs, verify_ssl=verify_ssl, client_cert=client_cert, keyjar=keyjar) self.client_id = client_id self.client_authn_method = client_authn_method self.verify_ssl = verify_ssl # self.secret_type = "basic " # self.state = None self.nonce = None self.grant = {} self.state2nonce = {} # own endpoint self.redirect_uris = [None] # service endpoints self.authorization_endpoint = None self.token_endpoint = None self.token_revocation_endpoint = None self.request2endpoint = REQUEST2ENDPOINT self.response2error = RESPONSE2ERROR self.grant_class = Grant self.token_class = Token self.provider_info = {} self._c_secret = None self.kid = {"sig": {}, "enc": {}} self.authz_req = None # the OAuth issuer is the URL of the authorization server's # configuration information location self.config = config or {} try: self.issuer = self.config['issuer'] except KeyError: self.issuer = '' self.allow = {} self.provider_info = {} def store_response(self, clinst, text): pass def get_client_secret(self): return self._c_secret def set_client_secret(self, val): if not val: self._c_secret = "" else: self._c_secret = val # client uses it for signing # Server might also use it for signing which means the # client uses it for verifying server signatures if self.keyjar is None: self.keyjar = KeyJar() self.keyjar.add_symmetric("", str(val)) client_secret = property(get_client_secret, set_client_secret) def reset(self): # self.state = None self.nonce 
= None self.grant = {} self.authorization_endpoint = None self.token_endpoint = None self.redirect_uris = None def grant_from_state(self, state): for key, grant in self.grant.items(): if key == state: return grant return None def _parse_args(self, request, **kwargs): ar_args = kwargs.copy() for prop in request.c_param.keys(): if prop in ar_args: continue else: if prop == "redirect_uri": _val = getattr(self, "redirect_uris", [None])[0] if _val: ar_args[prop] = _val else: _val = getattr(self, prop, None) if _val: ar_args[prop] = _val return ar_args def _endpoint(self, endpoint, **kwargs): try: uri = kwargs[endpoint] if uri: del kwargs[endpoint] except KeyError: uri = "" if not uri: try: uri = getattr(self, endpoint) except Exception: raise MissingEndpoint("No '%s' specified" % endpoint) if not uri: raise MissingEndpoint("No '%s' specified" % endpoint) return uri def get_grant(self, state, **kwargs): # try: # _state = kwargs["state"] # if not _state: # _state = self.state # except KeyError: # _state = self.state try: return self.grant[state] except KeyError: raise GrantError("No grant found for state:'%s'" % state) def get_token(self, also_expired=False, **kwargs): try: return kwargs["token"] except KeyError: grant = self.get_grant(**kwargs) try: token = grant.get_token(kwargs["scope"]) except KeyError: token = grant.get_token("") if not token: try: token = self.grant[kwargs["state"]].get_token("") except KeyError: raise TokenError("No token found for scope") if token is None: raise TokenError("No suitable token found") if also_expired: return token elif token.is_valid(): return token else: raise TokenError("Token has expired") def construct_request(self, request, request_args=None, extra_args=None): if request_args is None: request_args = {} # logger.debug("request_args: %s" % sanitize(request_args)) kwargs = self._parse_args(request, **request_args) if extra_args: kwargs.update(extra_args) # logger.debug("kwargs: %s" % sanitize(kwargs)) # logger.debug("request: %s" 
% sanitize(request)) return request(**kwargs) def construct_Message(self, request=Message, request_args=None, extra_args=None, **kwargs): return self.construct_request(request, request_args, extra_args) def construct_AuthorizationRequest(self, request=AuthorizationRequest, request_args=None, extra_args=None, **kwargs): if request_args is not None: try: # change default new = request_args["redirect_uri"] if new: self.redirect_uris = [new] except KeyError: pass else: request_args = {} if "client_id" not in request_args: request_args["client_id"] = self.client_id elif not request_args["client_id"]: request_args["client_id"] = self.client_id return self.construct_request(request, request_args, extra_args) def construct_AccessTokenRequest(self, request=AccessTokenRequest, request_args=None, extra_args=None, **kwargs): grant = self.get_grant(**kwargs) if not grant.is_valid(): raise GrantExpired( "Authorization Code to old %s > %s" % (utc_time_sans_frac(), grant.grant_expiration_time)) if request_args is None: request_args = {} request_args["code"] = grant.code try: request_args['state'] = kwargs['state'] except KeyError: pass if "grant_type" not in request_args: request_args["grant_type"] = "authorization_code" if "client_id" not in request_args: request_args["client_id"] = self.client_id elif not request_args["client_id"]: request_args["client_id"] = self.client_id return self.construct_request(request, request_args, extra_args) def construct_RefreshAccessTokenRequest(self, request=RefreshAccessTokenRequest, request_args=None, extra_args=None, **kwargs): if request_args is None: request_args = {} token = self.get_token(also_expired=True, **kwargs) request_args["refresh_token"] = token.refresh_token try: request_args["scope"] = token.scope except AttributeError: pass return self.construct_request(request, request_args, extra_args) # def construct_TokenRevocationRequest(self, # request=TokenRevocationRequest, # request_args=None, extra_args=None, # **kwargs): # # if 
request_args is None: # request_args = {} # # token = self.get_token(**kwargs) # # request_args["token"] = token.access_token # return self.construct_request(request, request_args, extra_args) def construct_ResourceRequest(self, request=ResourceRequest, request_args=None, extra_args=None, **kwargs): if request_args is None: request_args = {} token = self.get_token(**kwargs) request_args["access_token"] = token.access_token return self.construct_request(request, request_args, extra_args) def uri_and_body(self, reqmsg, cis, method="POST", request_args=None, **kwargs): if "endpoint" in kwargs and kwargs["endpoint"]: uri = kwargs["endpoint"] else: uri = self._endpoint(self.request2endpoint[reqmsg.__name__], **request_args) uri, body, kwargs = get_or_post(uri, method, cis, **kwargs) try: h_args = {"headers": kwargs["headers"]} except KeyError: h_args = {} return uri, body, h_args, cis def request_info(self, request, method="POST", request_args=None, extra_args=None, lax=False, **kwargs): if request_args is None: request_args = {} try: cls = getattr(self, "construct_%s" % request.__name__) cis = cls(request_args=request_args, extra_args=extra_args, **kwargs) except AttributeError: cis = self.construct_request(request, request_args, extra_args) if self.events: self.events.store('Protocol request', cis) if 'nonce' in cis and 'state' in cis: self.state2nonce[cis['state']] = cis['nonce'] cis.lax = lax if "authn_method" in kwargs: h_arg = self.init_authentication_method(cis, request_args=request_args, **kwargs) else: h_arg = None if h_arg: if "headers" in kwargs.keys(): kwargs["headers"].update(h_arg["headers"]) else: kwargs["headers"] = h_arg["headers"] return self.uri_and_body(request, cis, method, request_args, **kwargs) def authorization_request_info(self, request_args=None, extra_args=None, **kwargs): return self.request_info(AuthorizationRequest, "GET", request_args, extra_args, **kwargs) def get_urlinfo(self, info): if '?' 
in info or '#' in info: parts = urlparse(info) scheme, netloc, path, params, query, fragment = parts[:6] # either query of fragment if query: info = query else: info = fragment return info def parse_response(self, response, info="", sformat="json", state="", **kwargs): """ Parse a response :param response: Response type :param info: The response, can be either in a JSON or an urlencoded format :param sformat: Which serialization that was used :param state: The state :param kwargs: Extra key word arguments :return: The parsed and to some extend verified response """ _r2e = self.response2error if sformat == "urlencoded": info = self.get_urlinfo(info) # if self.events: # self.events.store('Response', info) resp = response().deserialize(info, sformat, **kwargs) msg = 'Initial response parsing => "{}"' logger.debug(msg.format(sanitize(resp.to_dict()))) if self.events: self.events.store('Response', resp.to_dict()) if "error" in resp and not isinstance(resp, ErrorResponse): resp = None try: errmsgs = _r2e[response.__name__] except KeyError: errmsgs = [ErrorResponse] try: for errmsg in errmsgs: try: resp = errmsg().deserialize(info, sformat) resp.verify() break except Exception: resp = None except KeyError: pass elif resp.only_extras(): resp = None else: kwargs["client_id"] = self.client_id try: kwargs['iss'] = self.provider_info['issuer'] except (KeyError, AttributeError): if self.issuer: kwargs['iss'] = self.issuer if "key" not in kwargs and "keyjar" not in kwargs: kwargs["keyjar"] = self.keyjar logger.debug("Verify response with {}".format(sanitize(kwargs))) verf = resp.verify(**kwargs) if not verf: logger.error('Verification of the response failed') raise PyoidcError("Verification of the response failed") if resp.type() == "AuthorizationResponse" and "scope" not in resp: try: resp["scope"] = kwargs["scope"] except KeyError: pass if not resp: logger.error('Missing or faulty response') raise ResponseError("Missing or faulty response") self.store_response(resp, info) if 
resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]: try: _state = resp["state"] except (AttributeError, KeyError): _state = "" if not _state: _state = state try: self.grant[_state].update(resp) except KeyError: self.grant[_state] = self.grant_class(resp=resp) return resp def init_authentication_method(self, cis, authn_method, request_args=None, http_args=None, **kwargs): if http_args is None: http_args = {} if request_args is None: request_args = {} if authn_method: return self.client_authn_method[authn_method](self).construct( cis, request_args, http_args, **kwargs) else: return http_args def parse_request_response(self, reqresp, response, body_type, state="", **kwargs): if reqresp.status_code in SUCCESSFUL: body_type = verify_header(reqresp, body_type) elif reqresp.status_code in [302, 303]: # redirect return reqresp elif reqresp.status_code == 500: logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) raise ParseError("ERROR: Something went wrong: %s" % reqresp.text) elif reqresp.status_code in [400, 401]: # expecting an error response if issubclass(response, ErrorResponse): pass else: logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) raise HttpError("HTTP ERROR: %s [%s] on %s" % (reqresp.text, reqresp.status_code, reqresp.url)) if response: if body_type == 'txt': # no meaning trying to parse unstructured text return reqresp.text return self.parse_response(response, reqresp.text, body_type, state, **kwargs) # could be an error response if reqresp.status_code in [200, 400, 401]: if body_type == 'txt': body_type = 'urlencoded' try: err = ErrorResponse().deserialize(reqresp.message, method=body_type) try: err.verify() except PyoidcError: pass else: return err except Exception: pass return reqresp def request_and_return(self, url, response=None, method="GET", body=None, body_type="json", state="", http_args=None, **kwargs): """ :param url: The URL to which the request should be sent :param response: Response type 
:param method: Which HTTP method to use :param body: A message body if any :param body_type: The format of the body of the return message :param http_args: Arguments for the HTTP client :return: A cls or ErrorResponse instance or the HTTP response instance if no response body was expected. """ if http_args is None: http_args = {} try: resp = self.http_request(url, method, data=body, **http_args) except Exception: raise if "keyjar" not in kwargs: kwargs["keyjar"] = self.keyjar return self.parse_request_response(resp, response, body_type, state, **kwargs) def do_authorization_request(self, request=AuthorizationRequest, state="", body_type="", method="GET", request_args=None, extra_args=None, http_args=None, response_cls=AuthorizationResponse, **kwargs): if state: try: request_args["state"] = state except TypeError: request_args = {"state": state} kwargs['authn_endpoint'] = 'authorization' url, body, ht_args, csi = self.request_info(request, method, request_args, extra_args, **kwargs) try: self.authz_req[request_args["state"]] = csi except TypeError: pass if http_args is None: http_args = ht_args else: http_args.update(ht_args) try: algs = kwargs["algs"] except KeyError: algs = {} resp = self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args, algs=algs) if isinstance(resp, Message): if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]: resp.state = csi.state return resp def do_access_token_request(self, request=AccessTokenRequest, scope="", state="", body_type="json", method="POST", request_args=None, extra_args=None, http_args=None, response_cls=AccessTokenResponse, authn_method="", **kwargs): kwargs['authn_endpoint'] = 'token' # method is default POST url, body, ht_args, csi = self.request_info(request, method=method, request_args=request_args, extra_args=extra_args, scope=scope, state=state, authn_method=authn_method, **kwargs) if http_args is None: http_args = ht_args else: http_args.update(ht_args) if self.events 
is not None: self.events.store('request_url', url) self.events.store('request_http_args', http_args) self.events.store('Request', body) logger.debug("<do_access_token> URL: %s, Body: %s" % (url, sanitize(body))) logger.debug("<do_access_token> response_cls: %s" % response_cls) return self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args, **kwargs) def do_access_token_refresh(self, request=RefreshAccessTokenRequest, state="", body_type="json", method="POST", request_args=None, extra_args=None, http_args=None, response_cls=AccessTokenResponse, authn_method="", **kwargs): token = self.get_token(also_expired=True, state=state, **kwargs) kwargs['authn_endpoint'] = 'refresh' url, body, ht_args, csi = self.request_info(request, method=method, request_args=request_args, extra_args=extra_args, token=token, authn_method=authn_method) if http_args is None: http_args = ht_args else: http_args.update(ht_args) return self.request_and_return(url, response_cls, method, body, body_type, state=state, http_args=http_args) # def do_revocate_token(self, request=TokenRevocationRequest, # scope="", state="", body_type="json", method="POST", # request_args=None, extra_args=None, http_args=None, # response_cls=None, authn_method=""): # # url, body, ht_args, csi = self.request_info(request, method=method, # request_args=request_args, # extra_args=extra_args, # scope=scope, state=state, # authn_method=authn_method) # # if http_args is None: # http_args = ht_args # else: # http_args.update(ht_args) # # return self.request_and_return(url, response_cls, method, body, # body_type, state=state, # http_args=http_args) def do_any(self, request, endpoint="", scope="", state="", body_type="json", method="POST", request_args=None, extra_args=None, http_args=None, response=None, authn_method=""): url, body, ht_args, csi = self.request_info(request, method=method, request_args=request_args, extra_args=extra_args, scope=scope, state=state, 
authn_method=authn_method, endpoint=endpoint) if http_args is None: http_args = ht_args else: http_args.update(ht_args) return self.request_and_return(url, response, method, body, body_type, state=state, http_args=http_args) def fetch_protected_resource(self, uri, method="GET", headers=None, state="", **kwargs): if "token" in kwargs and kwargs["token"]: token = kwargs["token"] request_args = {"access_token": token} else: try: token = self.get_token(state=state, **kwargs) except ExpiredToken: # The token is to old, refresh self.do_access_token_refresh() token = self.get_token(state=state, **kwargs) request_args = {"access_token": token.access_token} if headers is None: headers = {} if "authn_method" in kwargs: http_args = self.init_authentication_method( request_args=request_args, **kwargs) else: # If nothing defined this is the default http_args = self.client_authn_method["bearer_header"]( self).construct(request_args=request_args) headers.update(http_args["headers"]) logger.debug("Fetch URI: %s" % uri) return self.http_request(uri, method, headers=headers) def add_code_challenge(self): """ PKCE RFC 7636 support :return: """ try: cv_len = self.config['code_challenge']['length'] except KeyError: cv_len = 64 # Use default code_verifier = unreserved(cv_len) _cv = code_verifier.encode() try: _method = self.config['code_challenge']['method'] except KeyError: _method = 'S256' try: _h = CC_METHOD[_method](_cv).hexdigest() code_challenge = b64e(_h.encode()).decode() except KeyError: raise Unsupported('PKCE Transformation method:{}'.format(_method)) # TODO store code_verifier return { "code_challenge": code_challenge, "code_challenge_method": _method }, code_verifier def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True): """ Deal with Provider Config Response :param pcr: The ProviderConfigResponse instance :param issuer: The one I thought should be the issuer of the config :param keys: Should I deal with keys :param endpoints: Should I deal with 
endpoints, that is store them as attributes in self. """ if "issuer" in pcr: _pcr_issuer = pcr["issuer"] if pcr["issuer"].endswith("/"): if issuer.endswith("/"): _issuer = issuer else: _issuer = issuer + "/" else: if issuer.endswith("/"): _issuer = issuer[:-1] else: _issuer = issuer try: self.allow["issuer_mismatch"] except KeyError: try: assert _issuer == _pcr_issuer except AssertionError: raise PyoidcError( "provider info issuer mismatch '%s' != '%s'" % (_issuer, _pcr_issuer)) self.provider_info = pcr else: _pcr_issuer = issuer self.issuer = _pcr_issuer if endpoints: for key, val in pcr.items(): if key.endswith("_endpoint"): setattr(self, key, val) if keys: if self.keyjar is None: self.keyjar = KeyJar() self.keyjar.load_keys(pcr, _pcr_issuer) def provider_config(self, issuer, keys=True, endpoints=True, response_cls=ASConfigurationResponse, serv_pattern=OIDCONF_PATTERN): if issuer.endswith("/"): _issuer = issuer[:-1] else: _issuer = issuer url = serv_pattern % _issuer pcr = None r = self.http_request(url) if r.status_code == 200: pcr = response_cls().from_json(r.text) elif r.status_code == 302: while r.status_code == 302: r = self.http_request(r.headers["location"]) if r.status_code == 200: pcr = response_cls().from_json(r.text) break if pcr is None: raise PyoidcError("Trying '%s', status %s" % (url, r.status_code)) self.handle_provider_config(pcr, issuer, keys, endpoints) return pcr
class Client(oauth2.Client):
    """OAuth2 client that adds dynamic client registration support.

    Extends the base client with registration / client-management requests
    and with handling of the authorization server's provider configuration.
    """

    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None):
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar)
        # Policy switches; a truthy "issuer_mismatch" entry disables the
        # issuer consistency check in handle_provider_config().
        self.allow = {}
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint"
        })
        # Last registration response stored by store_registration_info().
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Build a registration request message from *request_args*."""
        if request_args is None:
            request_args = {}

        return self.construct_request(request, request_args, extra_args)

    def _do_client_op(self, request, body_type, method, request_args,
                      extra_args, http_args, response_cls, **kwargs):
        """Build, send and parse one client-management request.

        Shared implementation of the ``do_client_*`` methods.

        :return: Whatever ``request_and_return`` returns (parsed response,
            ErrorResponse or the raw HTTP response).
        """
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge in the arguments (e.g. headers) computed by request_info.
            # Previously the dict was updated with itself, which silently
            # dropped them.
            http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Send a client registration request and parse the answer."""
        return self._do_client_op(request, body_type, method, request_args,
                                  extra_args, http_args, response_cls,
                                  **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Read the registered client information."""
        return self._do_client_op(request, body_type, method, request_args,
                                  extra_args, http_args, response_cls,
                                  **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Update the registered client information."""
        return self._do_client_op(request, body_type, method, request_args,
                                  extra_args, http_args, response_cls,
                                  **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Delete the client registration."""
        return self._do_client_op(request, body_type, method, request_args,
                                  extra_args, http_args, response_cls,
                                  **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: if the issuer in *pcr* does not match *issuer*
            (modulo a trailing slash) and mismatches are not allowed.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Normalise the trailing slash of the expected issuer so it
            # matches the convention used by the provider.
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            # Explicit check instead of `assert`, which is stripped under -O.
            if (not self.allow.get("issuer_mismatch", False)
                    and _issuer != _pcr_issuer):
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """Fetch, parse and apply the provider configuration of *issuer*.

        Follows HTTP 302 redirects manually.

        :raises PyoidcError: if no configuration document could be fetched.
        :return: The parsed configuration response.
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """Remember the registration response and the credentials in it."""
        self.registration_response = reginfo
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """Parse an HTTP registration response.

        :raises PyoidcError: if the status code is not 200.
        :return: The parsed ClientInfoResponse.
        """
        if response.status_code == 200:
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_registration_info(resp)
        else:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The parsed registration response
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """Parse an urlencoded authorization response.

        :raises AuthzError: if the response is an error response.
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s" % aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s" % aresp)

        return aresp
class Client(PBase):
    """OAuth2 client: builds requests, sends them and parses the responses."""

    _endpoints = ENDPOINTS

    def __init__(
        self,
        client_id=None,
        client_authn_method=None,
        keyjar=None,
        verify_ssl=True,
        config=None,
        client_cert=None,
        timeout=5,
        message_factory: Type[MessageFactory] = OauthMessageFactory,
    ):
        """
        Initialize the instance.

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either as
            a single integer or as a tuple of integers. For more details, refer to
            ``requests`` documentation.
        :param: message_factory: Factory for message classes, should inherit from OauthMessageFactory
        :return: Client instance
        """
        PBase.__init__(
            self,
            verify_ssl=verify_ssl,
            keyjar=keyjar,
            client_cert=client_cert,
            timeout=timeout,
        )

        self.client_id = client_id
        self.client_authn_method = client_authn_method

        self.nonce = None

        self.message_factory = message_factory
        self.grant = {}  # type: Dict[str, Grant]
        self.state2nonce = {}  # type: Dict[str, str]
        # own endpoint
        self.redirect_uris = []  # type: List[str]

        # service endpoints
        self.authorization_endpoint = None  # type: Optional[str]
        self.token_endpoint = None  # type: Optional[str]
        self.token_revocation_endpoint = None  # type: Optional[str]

        # Copy the module-level mapping: subclasses call
        # self.request2endpoint.update(...), which would otherwise leak their
        # entries into every other Client instance.
        self.request2endpoint = dict(REQUEST2ENDPOINT)
        self.response2error = RESPONSE2ERROR  # type: Dict[str, List]
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = ASConfigurationResponse()  # type: Message
        self._c_secret = ""  # type: str
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict]
        self.authz_req = {}  # type: Dict[str, Message]

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config["issuer"]
        except KeyError:
            self.issuer = ""
        self.allow = {}  # type: Dict[str, Any]

    def store_response(self, clinst, text):
        """Hook called with every parsed response; a no-op by default."""
        pass

    def get_client_secret(self) -> str:
        return self._c_secret

    def set_client_secret(self, val: str):
        """Set the client secret and register it as a symmetric key."""
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self) -> None:
        """Forget nonce, grants, endpoints and redirect URIs."""
        self.nonce = None
        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = []

    def grant_from_state(self, state: str) -> Optional[Grant]:
        """Return the grant stored under *state*, or None."""
        return self.grant.get(state)

    def _parse_args(self, request: Type[Message], **kwargs) -> Dict:
        """Fill the request parameters not given in *kwargs* from self.

        ``redirect_uri`` comes from the first entry of ``self.redirect_uris``.
        Every other parameter comes from the attribute of the same name.
        Falsy values are skipped.
        """
        ar_args = kwargs.copy()

        for prop in request.c_param:
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                # redirect_uris defaults to []; guard against IndexError.
                _uris = getattr(self, "redirect_uris", None)
                _val = _uris[0] if _uris else None
            else:
                _val = getattr(self, prop, None)
            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint: str, **kwargs) -> str:
        """Resolve *endpoint* from kwargs first, then from the attribute on self.

        :raises MissingEndpoint: if no non-empty URL is found.
        """
        uri = kwargs.get(endpoint, "")

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state: str, **kwargs) -> Grant:
        """Return the grant for *state*.

        :raises GrantError: if no grant is stored for *state*.
        """
        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired: bool = False, **kwargs) -> Token:
        """Return an explicit ``token`` kwarg or look one up via the grant.

        :param also_expired: Return the token even if it is no longer valid.
        :raises TokenError: if no token is found, or if it expired and
            *also_expired* is false.
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                token = grant.get_token("")
                if not token:
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            # NOTE(review): fetch_protected_resource() catches ExpiredToken,
            # not TokenError. Confirm the exception hierarchy so that expired
            # tokens are actually refreshed there.
            raise TokenError("Token has expired")

    def clean_tokens(self) -> None:
        """Clean replaced and invalid tokens."""
        for state in self.grant:
            grant = self.get_grant(state)
            # Iterate over a snapshot. delete_token() presumably removes the
            # token from grant.tokens, and removing items from a list while
            # iterating over it skips elements.
            for token in list(grant.tokens):
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(
        self, request: Type[Message], request_args=None, extra_args=None
    ):
        """Instantiate *request* from request_args, self defaults and extra_args."""
        if request_args is None:
            request_args = {}

        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
        logger.debug("request: %s" % sanitize(request))
        return request(**kwargs)

    def construct_Message(
        self,
        request: Type[Message] = Message,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> Message:

        return self.construct_request(request, request_args, extra_args)

    def construct_AuthorizationRequest(
        self,
        request: Type[AuthorizationRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AuthorizationRequest:
        """Build an authorization request.

        A non-empty ``redirect_uri`` in *request_args* becomes the client's
        default redirect URI.
        """
        if request is None:
            request = self.message_factory.get_request_type("authorization_endpoint")
        if request_args is not None:
            try:  # change default
                new = request_args["redirect_uri"]
                if new:
                    self.redirect_uris = [new]
            except KeyError:
                pass
        else:
            request_args = {}

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(
        self,
        request: Type[AccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> AccessTokenRequest:
        """Build a token request.

        For the authorization-code flow, the request is built from the grant
        stored under ``kwargs["state"]``.

        :raises GrantExpired: if the stored grant is no longer valid.
        """
        if request is None:
            request = self.message_factory.get_request_type("token_endpoint")
        if request_args is None:
            request_args = {}
        if request is not ROPCAccessTokenRequest:
            grant = self.get_grant(**kwargs)

            if not grant.is_valid():
                raise GrantExpired(
                    "Authorization Code to old %s > %s"
                    % (utc_time_sans_frac(), grant.grant_expiration_time)
                )

            request_args["code"] = grant.code

        try:
            request_args["state"] = kwargs["state"]
        except KeyError:
            pass

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id
        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(
        self,
        request: Type[RefreshAccessTokenRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> RefreshAccessTokenRequest:
        """Build a refresh request from the (possibly expired) current token."""
        if request is None:
            request = self.message_factory.get_request_type("refresh_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(also_expired=True, **kwargs)

        request_args["refresh_token"] = token.refresh_token

        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(
        self,
        request: Type[ResourceRequest] = None,
        request_args=None,
        extra_args=None,
        **kwargs
    ) -> ResourceRequest:
        """Build a resource request carrying the current access token."""
        if request is None:
            request = self.message_factory.get_request_type("resource_endpoint")
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["access_token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(
        self,
        reqmsg: Type[Message],
        cis: Message,
        method="POST",
        request_args=None,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """Resolve the endpoint URL and serialize *cis* for *method*.

        :return: (url, body, http_args, cis)
        """
        if kwargs.get("endpoint"):
            uri = kwargs["endpoint"]
        else:
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__], **request_args)

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(
        self,
        request: Type[Message],
        method="POST",
        request_args=None,
        extra_args=None,
        lax=False,
        **kwargs
    ) -> Tuple[str, str, Dict, Message]:
        """Construct the request message and everything needed to send it.

        Uses a ``construct_<RequestName>`` method when one exists, and applies
        client authentication when ``authn_method`` is given.

        :return: (url, body, http_args, request message)
        """
        if request_args is None:
            request_args = {}

        try:
            cls = getattr(self, "construct_%s" % request.__name__)
            cis = cls(request_args=request_args, extra_args=extra_args, **kwargs)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)

        if self.events:
            self.events.store("Protocol request", cis)

        if "nonce" in cis and "state" in cis:
            self.state2nonce[cis["state"]] = cis["nonce"]

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(
                cis, request_args=request_args, **kwargs
            )
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs.keys():
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args, **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None, **kwargs):
        """Shortcut for request_info() of an authorization request via GET."""
        return self.request_info(
            self.message_factory.get_request_type("authorization_endpoint"),
            "GET",
            request_args,
            extra_args,
            **kwargs
        )

    @staticmethod
    def get_urlinfo(info: str) -> str:
        """Return the query (or, if empty, the fragment) part of a URL.

        Strings without '?' or '#' are returned unchanged.
        """
        if "?" in info or "#" in info:
            parts = urlparse(info)
            # either query or fragment
            info = parts.query or parts.fragment
        return info

    def parse_response(
        self,
        response: Type[Message],
        info: str = "",
        sformat: ENCODINGS = "json",
        state: str = "",
        **kwargs
    ) -> Message:
        """
        Parse a response.

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state
        :param kwargs: Extra key word arguments
        :return: The parsed and to some extend verified response
        """
        _r2e = self.response2error

        if sformat == "urlencoded":
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store("Response", resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # Re-parse with the error types registered for this response.
            resp = None
            errmsgs = []  # type: List[Any]
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            for errmsg in errmsgs:
                try:
                    resp = errmsg().deserialize(info, sformat)
                    resp.verify()
                    break
                except Exception:
                    resp = None
        elif resp.only_extras():
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            try:
                kwargs["iss"] = self.provider_info["issuer"]
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs["iss"] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error("Verification of the response failed")
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error("Missing or faulty response")
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if isinstance(resp, (AuthorizationResponse, AccessTokenResponse)):
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(
        self, cis, authn_method, request_args=None, http_args=None, **kwargs
    ):
        """Apply client authentication *authn_method* and return the HTTP args.

        When *authn_method* is falsy, *http_args* (or {}) is returned unchanged.
        """
        if http_args is None:
            http_args = {}
        if request_args is None:
            request_args = {}

        if authn_method:
            return self.client_authn_method[authn_method](self).construct(
                cis, request_args, http_args, **kwargs
            )
        else:
            return http_args

    def parse_request_response(self, reqresp, response, body_type, state="", **kwargs):
        """Turn an HTTP response into a Message, an error or the raw response.

        :param reqresp: The HTTP response object
        :param response: Expected Message subclass, or None
        :param body_type: Expected body encoding
        :raises ParseError: on HTTP 500
        :raises HttpError: on HTTP 400/401 when *response* is not an
            ErrorResponse type
        """
        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code in [302, 303]:  # redirect
            return reqresp
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # expecting an error response. `response` may be None (e.g.
            # do_any() without a response type). issubclass(None, ...) would
            # raise TypeError, so let that case fall through to the generic
            # error-response parsing below.
            if response is not None and not issubclass(response, ErrorResponse):
                logger.error(
                    "(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))
                )
                raise HttpError(
                    "HTTP ERROR: %s [%s] on %s"
                    % (reqresp.text, reqresp.status_code, reqresp.url)
                )

        if response:
            if body_type == "txt":
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(
                response, reqresp.text, body_type, state, **kwargs
            )

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == "txt":
                body_type = "urlencoded"
            try:
                # The body lives in `.text`. The previous `.message` attribute
                # does not exist on the HTTP response, so this branch never
                # parsed anything.
                err = ErrorResponse().deserialize(reqresp.text, method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                pass

        return reqresp

    def request_and_return(
        self,
        url: str,
        response: Type[Message] = None,
        method="GET",
        body=None,
        body_type: ENCODINGS = "json",
        state: str = "",
        http_args=None,
        **kwargs
    ):
        """
        Perform a request and return the response.

        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param http_args: Arguments for the HTTP client
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """
        # FIXME: Cannot annotate return value as Message since it disrupts all other methods
        if http_args is None:
            http_args = {}

        resp = self.http_request(url, method, data=body, **http_args)

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(resp, response, body_type, state, **kwargs)

    def do_authorization_request(
        self,
        state="",
        body_type="",
        method="GET",
        request_args=None,
        extra_args=None,
        http_args=None,
        **kwargs
    ) -> AuthorizationResponse:
        """Send an authorization request and parse the response.

        The constructed request is remembered in ``self.authz_req`` under its
        state, when a state is known.
        """
        request = self.message_factory.get_request_type("authorization_endpoint")
        response_cls = self.message_factory.get_response_type("authorization_endpoint")

        if state:
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        kwargs["authn_endpoint"] = "authorization"
        url, body, ht_args, csi = self.request_info(
            request, method, request_args, extra_args, **kwargs
        )

        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            # No request_args (TypeError) or no state in them (KeyError).
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        try:
            algs = kwargs["algs"]
        except KeyError:
            algs = {}

        resp = self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            algs=algs,
        )

        if isinstance(resp, Message):
            # FIXME: The Message classes do not have classical attrs
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:  # type: ignore
                resp.state = csi.state  # type: ignore

        return resp

    def do_access_token_request(
        self,
        scope: str = "",
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:
        """Send a token request and parse the response."""
        request = self.message_factory.get_request_type("token_endpoint")
        response_cls = self.message_factory.get_response_type("token_endpoint")

        kwargs["authn_endpoint"] = "token"
        # method is default POST
        url, body, ht_args, csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            **kwargs
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        # Never forward a password as an HTTP-client argument.
        http_args.pop("password", None)

        if self.events is not None:
            self.events.store("request_url", url)
            self.events.store("request_http_args", http_args)
            self.events.store("Request", body)

        logger.debug("<do_access_token> URL: %s, Body: %s" % (url, sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(
            url,
            response_cls,
            method,
            body,
            body_type,
            state=state,
            http_args=http_args,
            **kwargs
        )

    def do_access_token_refresh(
        self,
        state: str = "",
        body_type: ENCODINGS = "json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        authn_method="",
        **kwargs
    ) -> AccessTokenResponse:
        """Refresh the token for *state*.

        If the old token is marked as replaced afterwards, it is deleted from
        the grant.
        """
        request = self.message_factory.get_request_type("refresh_endpoint")
        response_cls = self.message_factory.get_response_type("refresh_endpoint")

        token = self.get_token(also_expired=True, state=state, **kwargs)
        kwargs["authn_endpoint"] = "refresh"
        url, body, ht_args, csi = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            token=token,
            authn_method=authn_method,
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        response = self.request_and_return(
            url, response_cls, method, body, body_type, state=state, http_args=http_args
        )
        if token.replaced:
            grant = self.get_grant(state)
            grant.delete_token(token)
        return response

    def do_any(
        self,
        request: Type[Message],
        endpoint="",
        scope="",
        state="",
        body_type="json",
        method="POST",
        request_args=None,
        extra_args=None,
        http_args=None,
        response: Type[Message] = None,
        authn_method="",
    ) -> Message:
        """Send an arbitrary *request* and parse it as *response*."""
        url, body, ht_args, _ = self.request_info(
            request,
            method=method,
            request_args=request_args,
            extra_args=extra_args,
            scope=scope,
            state=state,
            authn_method=authn_method,
            endpoint=endpoint,
        )

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        return self.request_and_return(
            url, response, method, body, body_type, state=state, http_args=http_args
        )

    def fetch_protected_resource(
        self, uri, method="GET", headers=None, state="", **kwargs
    ):
        """Request *uri* with an access token attached.

        The token comes from ``kwargs["token"]``, or otherwise from the grant
        for *state*. Without ``authn_method`` the token is sent as a bearer
        header. The caller's *headers* dict is not modified.
        """
        if kwargs.get("token"):
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            try:
                token = self.get_token(state=state, **kwargs)
            except ExpiredToken:
                # The token is too old, refresh
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        # Work on a copy so the caller's dict is not mutated.
        headers = dict(headers) if headers else {}

        if "authn_method" in kwargs:
            # init_authentication_method() requires `cis` positionally. There
            # is no request message here, so pass None (as the default branch
            # below does implicitly).
            http_args = self.init_authentication_method(
                None, request_args=request_args, **kwargs
            )
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method["bearer_header"](self).construct(
                request_args=request_args
            )

        headers.update(http_args["headers"])

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support.

        :return: ({"code_challenge": ..., "code_challenge_method": ...},
            code_verifier)
        :raises Unsupported: for an unknown transformation method.
        """
        try:
            cv_len = self.config["code_challenge"]["length"]
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode("ascii")

        try:
            _method = self.config["code_challenge"]["method"]
        except KeyError:
            _method = "S256"

        try:
            _hash = CC_METHOD[_method]
        except KeyError:
            raise Unsupported("PKCE Transformation method:{}".format(_method))
        # RFC 7636: challenge is the encoded raw digest of the verifier.
        code_challenge = b64e(_hash(_cv).digest()).decode("ascii")

        # TODO store code_verifier

        return (
            {"code_challenge": code_challenge, "code_challenge_method": _method},
            code_verifier,
        )

    def handle_provider_config(
        self,
        pcr: ASConfigurationResponse,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
    ) -> None:
        """
        Deal with Provider Config Response.

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them as
            attributes in self.
        :raises PyoidcError: on an issuer mismatch, unless allowed.
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Align the trailing slash with the provider's convention.
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            if not self.allow.get("issuer_mismatch", False) and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (_issuer, _pcr_issuer)
                )

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        self.issuer = _pcr_issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(
        self,
        issuer: str,
        keys: bool = True,
        endpoints: bool = True,
        serv_pattern: str = OIDCONF_PATTERN,
    ) -> ASConfigurationResponse:
        """Fetch, parse and apply the configuration document of *issuer*.

        :raises ParseError: if the document cannot be parsed.
        :raises CommunicationError: on a non-200 status.
        """
        response_cls = self.message_factory.get_response_type("configuration_endpoint")
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url, allow_redirects=True)
        if r.status_code == 200:
            try:
                pcr = response_cls().from_json(r.text)
            except Exception as e:
                # FIXME: This should catch specific exception from `from_json()`
                _err_txt = "Faulty provider config response: {}".format(e)
                logger.error(sanitize(_err_txt))
                raise ParseError(_err_txt)
        else:
            raise CommunicationError("Trying '%s', status %s" % (url, r.status_code))

        self.store_response(pcr, r.text)
        self.handle_provider_config(pcr, issuer, keys, endpoints)
        return pcr
class Client(oauth2.Client):
    """OAuth2 client adding dynamic client registration and management,
    token introspection/revocation and PKCE (code challenge) support."""

    def __init__(self, client_id=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True, config=None):
        oauth2.Client.__init__(self, client_id=client_id,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl,
                               config=config)
        # Presence of a key relaxes a check; "issuer_mismatch" disables the
        # issuer comparison in handle_provider_config().
        self.allow = {}
        # Map the extension request classes to the attribute that holds the
        # endpoint URL they are sent to.
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint",
            'TokenIntrospectionRequest': 'introspection_endpoint',
            'TokenRevocationRequest': 'revocation_endpoint'
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Build a client registration request message."""
        if request_args is None:
            request_args = {}
        return self.construct_request(request, request_args, extra_args)

    def construct_ClientUpdateRequest(self, request=ClientUpdateRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Build a client read/update/delete request message."""
        if request_args is None:
            request_args = {}
        return self.construct_request(request, request_args, extra_args)

    def _token_interaction_setup(self, request_args=None, **kwargs):
        """Prepare request arguments for token introspection/revocation.

        If ``request_args`` has no 'token', a token is looked up with
        ``get_token(**kwargs)`` and its attribute named by the
        ``token_type_hint`` keyword (default 'access_token') is used; in that
        case the given ``request_args`` are replaced.  'client_id' defaults to
        ``self.client_id`` when missing or empty.
        """
        if request_args is None or 'token' not in request_args:
            token = self.get_token(**kwargs)
            try:
                _token_type_hint = kwargs['token_type_hint']
            except KeyError:
                _token_type_hint = 'access_token'

            request_args = {
                'token_type_hint': _token_type_hint,
                'token': getattr(token, _token_type_hint)
            }

        if not request_args.get("client_id"):
            request_args["client_id"] = self.client_id

        return request_args

    def construct_TokenIntrospectionRequest(self,
                                            request=TokenIntrospectionRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):
        """Build a token introspection request message."""
        request_args = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, request_args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None, extra_args=None,
                                         **kwargs):
        """Build a token revocation request message."""
        request_args = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, request_args, extra_args)

    def do_op(self, request, body_type='', method='GET', request_args=None,
              extra_args=None, http_args=None, response_cls=None, **kwargs):
        """Construct, send and parse a request.

        :param http_args: Extra HTTP arguments.  They are combined with the
            arguments computed by request_info(); on a key clash the computed
            value wins.  The caller's dict is not modified.
        :return: Whatever request_and_return() returns
        """
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Previously this was ``http_args.update(http_args)`` (a no-op),
            # which silently dropped the computed arguments.
            _args = dict(http_args)
            _args.update(ht_args or {})
            http_args = _args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse, **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse, **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse, **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse, **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def do_token_introspection(self, request=TokenIntrospectionRequest,
                               body_type="json", method="POST",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=TokenIntrospectionResponse,
                               **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def do_token_revocation(self, request=TokenRevocationRequest,
                            body_type="", method="POST",
                            request_args=None, extra_args=None,
                            http_args=None, response_cls=None, **kwargs):
        return self.do_op(request=request, body_type=body_type,
                          method=method, request_args=request_args,
                          extra_args=extra_args, http_args=http_args,
                          response_cls=response_cls, **kwargs)

    def add_code_challenge(self):
        """PKCE: create a code verifier and the matching code challenge.

        Length and transformation method are read from
        ``self.config['code_challenge']`` ('length', default 64; 'method',
        default 'S256').  RFC 7636 section 4.2 defines the challenge as the
        base64url encoding of the *binary* digest of the verifier, so the raw
        digest is encoded (encoding the hex digest is not interoperable).
        NOTE(review): assumes b64e is the unpadded base64url encoder.

        :return: ({"code_challenge": ..., "code_challenge_method": ...},
            code_verifier)
        :raises Unsupported: if the method is not a key of CC_METHOD
        """
        try:
            cv_len = self.config['code_challenge']['length']
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode("ascii")

        try:
            _method = self.config['code_challenge']['method']
        except KeyError:
            _method = 'S256'

        try:
            _h = CC_METHOD[_method](_cv).digest()
        except KeyError:
            raise Unsupported('PKCE Transformation method:{}'.format(_method))
        code_challenge = b64e(_h).decode("ascii")

        # TODO store code_verifier

        return {"code_challenge": code_challenge,
                "code_challenge_method": _method}, code_verifier

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """Send an authorization request, adding PKCE arguments when
        ``self.config['code_challenge']`` is set.

        :return: The result of oauth2.Client.do_authorization_request
        """
        if 'code_challenge' in self.config and self.config['code_challenge']:
            _args, code_verifier = self.add_code_challenge()
            if request_args is None:
                request_args = {}
            request_args.update(_args)

        return oauth2.Client.do_authorization_request(
            self, request=request, state=state, body_type=body_type,
            method=method, request_args=request_args,
            extra_args=extra_args, http_args=http_args,
            response_cls=response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: on issuer mismatch unless "issuer_mismatch" is a
            key of self.allow
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Make the expected issuer's trailing slash match the response.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # Explicit test rather than ``assert`` so the check survives -O.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """Fetch and process the provider configuration.

        302 responses are followed manually until a non-302 status.

        :raises PyoidcError: if no 200 response was obtained
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if self.events:
            self.events.store('HTTP response header', r.headers)

        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """Remember the registration response and the credentials in it."""
        self.registration_response = reginfo
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """Parse a registration HTTP response.

        :return: ClientInfoResponse on a successful status, otherwise the
            verified ErrorResponse
        :raises PyoidcError: if an error response cannot be verified
        """
        if response.status_code in SUCCESSFUL:
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_response(resp, response.text)
            self.store_registration_info(resp)
        else:
            resp = ErrorResponse().deserialize(response.text, "json")
            try:
                resp.verify()
                self.store_response(resp, response.text)
            except Exception:
                raise PyoidcError('Registration failed: {}'.format(
                    response.text))

        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return:
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """Parse a URL-encoded authorization response.

        :raises AuthzError: if the response is an ErrorResponse
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s" % sanitize(aresp))
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s" % sanitize(aresp))

        return aresp
class Client(oauth2.Client):
    """OpenID Connect client built on top of the OAuth2 client."""

    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, ca_certs=None, client_prefs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True):

        oauth2.Client.__init__(self, client_id, ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)

        self.file_store = "./file/"
        self.file_uri = "http://localhost/"

        # OpenID connect specific endpoints
        for endpoint in ENDPOINTS:
            setattr(self, endpoint, "")

        self.id_token = None
        self.log = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token
        self.provider_info = None
        self.registration_response = None
        self.client_prefs = client_prefs or {}
        self.behaviour = {
            "request_object_signing_alg":
                DEF_SIGN_ALG["openid_request_object"]}

        self.wf = WebFinger(OIC_ISSUER)
        self.wf.httpd = self
        self.allow = {}
        self.post_logout_redirect_uris = []
        self.registration_expires = 0
        self.registration_access_token = None

        # Default key by kid for different key types
        # For instance {"RSA":"abc"}
        self.kid = {"sig": {}, "enc": {}}

    def _get_id_token(self, **kwargs):
        """Return kwargs["id_token"] or an ID token found in the grant.

        When a 'scope' keyword is given, the search stops at the first token
        whose scope does not contain every requested item (original
        behaviour, kept as-is).
        """
        try:
            return kwargs["id_token"]
        except KeyError:
            pass

        grant = self.get_grant(**kwargs)
        if not grant:
            return None

        _scope = kwargs.get("scope")
        for token in grant.tokens:
            if token.scope and _scope:
                # Explicit test instead of ``assert`` so it survives -O.
                if not all(item in token.scope for item in _scope):
                    break
            if token.id_token:
                return token.id_token

        return None

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       request_param=None, **kwargs):
        """Build an authorization request, adding a nonce when needed and
        optionally a signed request object ("request" or "request_uri")."""
        if request_args is not None:
            if "nonce" not in request_args:
                _rt = request_args["response_type"]
                if "token" in _rt or "id_token" in _rt:
                    request_args["nonce"] = rndstr(12)
        elif "response_type" in kwargs:
            if "token" in kwargs["response_type"]:
                request_args = {"nonce": rndstr(12)}
        else:  # Never wrong to specify a nonce
            request_args = {"nonce": rndstr(12)}

        if "request_method" in kwargs:
            if kwargs["request_method"] == "file":
                request_param = "request_uri"
            # Control keyword only; never forward it as a request parameter.
            del kwargs["request_method"]

        areq = oauth2.Client.construct_AuthorizationRequest(self, request,
                                                            request_args,
                                                            extra_args,
                                                            **kwargs)

        if request_param:
            # An explicitly given algorithm wins; pick the signing key type
            # from the algorithm that will actually be used.
            alg = kwargs.setdefault(
                "algorithm", self.behaviour["request_object_signing_alg"])

            if "keys" not in kwargs and alg:
                _kty = alg2keytype(alg)
                try:
                    kwargs["keys"] = self.keyjar.get_signing_key(
                        _kty, kid=self.kid["sig"][_kty])
                except KeyError:
                    kwargs["keys"] = self.keyjar.get_signing_key(_kty)

            _req = make_openid_request(areq, **kwargs)

            if request_param == "request":
                areq["request"] = _req
            else:
                _filedir = kwargs["local_dir"]
                _webpath = kwargs["base_path"]
                # Pick a random file name that does not exist yet.
                _name = rndstr(10)
                filename = os.path.join(_filedir, _name)
                while os.path.exists(filename):
                    _name = rndstr(10)
                    filename = os.path.join(_filedir, _name)
                with open(filename, mode="w") as fid:
                    fid.write(_req)
                areq["request_uri"] = "%s%s" % (_webpath, _name)

        return areq

    # noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self, request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):

        return oauth2.Client.construct_AccessTokenRequest(self, request,
                                                          request_args,
                                                          extra_args, **kwargs)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None, extra_args=None,
                                            **kwargs):

        return oauth2.Client.construct_RefreshAccessTokenRequest(self, request,
                                                                 request_args,
                                                                 extra_args,
                                                                 **kwargs)

    def construct_UserInfoRequest(self, request=UserInfoRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """Build a userinfo request, fetching an access token (scope
        'openid' by default) when none is given.

        :raises PyoidcError: if no valid token is available
        """
        if request_args is None:
            request_args = {}

        if "access_token" not in request_args:
            if "scope" not in kwargs:
                kwargs["scope"] = "openid"
            token = self.get_token(**kwargs)
            if token is None:
                raise PyoidcError("No valid token available")

            request_args["access_token"] = token.access_token

        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):

        return self.construct_request(request, request_args, extra_args)

    # noinspection PyUnusedLocal
    def construct_RefreshSessionRequest(self,
                                        request=RefreshSessionRequest,
                                        request_args=None, extra_args=None,
                                        **kwargs):

        return self.construct_request(request, request_args, extra_args)

    def _id_token_based(self, request, request_args=None, extra_args=None,
                        **kwargs):
        """Build a request carrying an ID token in the parameter named by
        the 'prop' keyword (default 'id_token').

        :raises PyoidcError: if no ID token is available
        """
        if request_args is None:
            request_args = {}

        _prop = kwargs.get("prop", "id_token")

        if _prop not in request_args:
            id_token = self._get_id_token(**kwargs)
            if id_token is None:
                raise PyoidcError("No valid id token available")

            request_args[_prop] = id_token

        return self.construct_request(request, request_args, extra_args)

    def construct_CheckSessionRequest(self, request=CheckSessionRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):

        return self._id_token_based(request, request_args, extra_args,
                                    **kwargs)

    def construct_CheckIDRequest(self, request=CheckIDRequest,
                                 request_args=None,
                                 extra_args=None, **kwargs):

        # access_token is where the id_token will be placed
        return self._id_token_based(request, request_args, extra_args,
                                    prop="access_token", **kwargs)

    def construct_EndSessionRequest(self, request=EndSessionRequest,
                                    request_args=None, extra_args=None,
                                    **kwargs):
        """Build an end-session request; 'state' is synchronised between
        kwargs and request_args."""
        if request_args is None:
            request_args = {}

        if "state" in kwargs:
            request_args["state"] = kwargs["state"]
        elif "state" in request_args:
            kwargs["state"] = request_args["state"]

        return self._id_token_based(request, request_args, extra_args,
                                    **kwargs)

    # ------------------------------------------------------------------------
    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        return self.request_info(AuthorizationRequest, "GET",
                                 request_args, extra_args, **kwargs)

    # ------------------------------------------------------------------------
    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse):

        return oauth2.Client.do_authorization_request(self, request, state,
                                                      body_type, method,
                                                      request_args,
                                                      extra_args, http_args,
                                                      response_cls)

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):

        return oauth2.Client.do_access_token_request(self, request, scope,
                                                     state, body_type, method,
                                                     request_args, extra_args,
                                                     http_args, response_cls,
                                                     authn_method, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                **kwargs):

        return oauth2.Client.do_access_token_refresh(self, request, state,
                                                     body_type, method,
                                                     request_args,
                                                     extra_args, http_args,
                                                     response_cls, **kwargs)

    def do_registration_request(self, request=RegistrationRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=None):
        """Send a registration request; response_cls defaults to
        RegistrationResponse."""
        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge computed args into a copy (was a self-update no-op).
            _args = dict(http_args)
            _args.update(ht_args or {})
            http_args = _args

        if response_cls is None:
            response_cls = RegistrationResponse

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_check_session_request(self, request=CheckSessionRequest,
                                 scope="",
                                 state="", body_type="json", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=IdToken):

        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            _args = dict(http_args)
            _args.update(ht_args or {})
            http_args = _args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_check_id_request(self, request=CheckIDRequest, scope="",
                            state="", body_type="json", method="GET",
                            request_args=None, extra_args=None,
                            http_args=None,
                            response_cls=IdToken):

        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            _args = dict(http_args)
            _args.update(ht_args or {})
            http_args = _args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_end_session_request(self, request=EndSessionRequest, scope="",
                               state="", body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None, response_cls=None):

        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            _args = dict(http_args)
            _args.update(ht_args or {})
            http_args = _args

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def user_info_request(self, method="GET", state="", scope="", **kwargs):
        """Prepare a userinfo request.

        The access token comes from the 'token' or 'access_token' keyword,
        or from the grant for ``state`` (refreshed if no longer valid).

        :return: (path, body, method, http_args)
        :raises MissingParameter: if the token type cannot be determined
        """
        uir = UserInfoRequest()
        logger.debug("[user_info_request]: kwargs:%s" % (kwargs,))
        if "token" in kwargs:
            if kwargs["token"]:
                uir["access_token"] = kwargs["token"]
                token = Token()
                token.token_type = "Bearer"
                token.access_token = kwargs["token"]
                kwargs["behavior"] = "use_authorization_header"
            else:
                # What to do ? Need a callback
                token = None
        elif "access_token" in kwargs and kwargs["access_token"]:
            uir["access_token"] = kwargs["access_token"]
            del kwargs["access_token"]
            token = None
        else:
            token = self.grant[state].get_token(scope)

            if token.is_valid():
                uir["access_token"] = token.access_token
                if token.token_type == "Bearer" and method == "GET":
                    kwargs["behavior"] = "use_authorization_header"
            else:
                # raise oauth2.OldAccessToken
                if self.log:
                    self.log.info("do access token refresh")
                self.do_access_token_refresh(token=token)
                token = self.grant[state].get_token(scope)
                uir["access_token"] = token.access_token

        uri = self._endpoint("userinfo_endpoint", **kwargs)
        # If access token is a bearer token it might be sent in the
        # authorization header
        # 3-ways of sending the access_token:
        # - POST with token in authorization header
        # - POST with token in message body
        # - GET with token in authorization header
        if "behavior" in kwargs:
            _behav = kwargs["behavior"]
            _token = uir["access_token"]
            try:
                _ttype = kwargs["token_type"]
            except KeyError:
                try:
                    _ttype = token.token_type
                except AttributeError:
                    raise MissingParameter("Unspecified token type")

            # use_authorization_header, token_in_message_body
            if "use_authorization_header" in _behav and _ttype == "Bearer":
                bh = "Bearer %s" % _token
                if "headers" in kwargs:
                    kwargs["headers"].update({"Authorization": bh})
                else:
                    kwargs["headers"] = {"Authorization": bh}

            if "token_in_message_body" not in _behav:
                # remove the token from the request
                del uir["access_token"]

        path, body, kwargs = self.get_or_post(uri, method, uir, **kwargs)

        h_args = dict((k, v) for k, v in kwargs.items() if k in HTTP_ARGS)

        return path, body, method, h_args

    def do_user_info_request(self, method="POST", state="", scope="openid",
                             request="openid", **kwargs):
        """Send a userinfo request and parse the JSON or JWT response.

        :raises PyoidcError: on a non-200 status or an unexpected
            content-type
        """
        kwargs["request"] = request
        path, body, method, h_args = self.user_info_request(method, state,
                                                            scope, **kwargs)

        logger.debug("[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" % (
            path, body, h_args))

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit checks instead of ``assert`` so -O cannot change the
            # chosen format.
            _ctype = resp.headers["content-type"]
            if "application/json" in _ctype:
                sformat = "json"
            elif "application/jwt" in _ctype:
                sformat = "jwt"
            else:
                raise PyoidcError(
                    "ERROR: Unexpected content-type: %s" % _ctype)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]: %s" % (
                resp.status_code, resp.text))

        try:
            _schema = kwargs["user_info_schema"]
        except KeyError:
            _schema = OpenIDSchema

        logger.debug("Reponse text: '%s'" % resp.text)

        if sformat == "json":
            return _schema().from_json(txt=resp.text)

        algo = self.client_prefs["userinfo_signed_response_alg"]
        _kty = alg2keytype(algo)
        # Keys of the OP ?
        try:
            keys = self.keyjar.get_signing_key(_kty, self.kid["sig"][_kty])
        except KeyError:
            keys = self.keyjar.get_signing_key(_kty)

        return _schema().from_jwt(resp.text, keys)

    def get_userinfo_claims(self, access_token, endpoint, method="POST",
                            schema_class=OpenIDSchema, **kwargs):
        """Fetch userinfo claims from ``endpoint`` with ``access_token``.

        :raises PyoidcError: on a non-200 status or a non-JSON response
        """
        uir = UserInfoRequest(access_token=access_token)

        h_args = dict((k, v) for k, v in kwargs.items() if k in HTTP_ARGS)

        if "authn_method" in kwargs:
            # The request must be passed positionally, as in the default
            # branch below; kwargs supplies authn_method.
            http_args = self.init_authentication_method(uir, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.init_authentication_method(uir, "bearer_header",
                                                        **kwargs)

        h_args.update(http_args)
        path, body, kwargs = self.get_or_post(endpoint, method, uir, **kwargs)

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            if "application/json" not in resp.headers["content-type"]:
                raise PyoidcError("ERROR: Unexpected content-type: %s" %
                                  resp.headers["content-type"])
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError(
                "ERROR: Something went wrong [%s]" % resp.status_code)

        return schema_class().from_json(txt=resp.text)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises IssuerMismatch: unless "issuer_mismatch" is a key of
            self.allow
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Make the expected issuer's trailing slash match the response.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # Explicit test rather than ``assert`` so the check survives -O.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise IssuerMismatch("'%s' != '%s'" % (_issuer, _pcr_issuer),
                                     pcr)

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """Fetch and process the provider configuration, following 302s.

        :raises PyoidcError: if no 200 response was obtained
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def unpack_aggregated_claims(self, userinfo):
        """Verify aggregated (JWT) claim sources and merge their claims.

        :raises PyoidcError: if the claims in a JWT differ from the claim
            names listed for that source
        """
        if userinfo._claim_sources:
            for csrc, spec in userinfo._claim_sources.items():
                if "JWT" in spec:
                    if csrc not in self.keyjar:
                        self.provider_config(csrc, endpoints=False)

                    keycol = self.keyjar.get_verify_key(owner=csrc)
                    for typ, keyl in self.keyjar.get_verify_key().items():
                        try:
                            keycol[typ].extend(keyl)
                        except KeyError:
                            keycol[typ] = keyl

                    info = json.loads(JWS().verify(str(spec["JWT"]), keycol))
                    attr = [n for n, s in userinfo._claim_names.items()
                            if s == csrc]
                    # Order-independent comparison; ``attr == info.keys()``
                    # depended on dict ordering (and is always False on Py3).
                    if set(attr) != set(info.keys()):
                        raise PyoidcError(
                            "Claim names mismatch for source '%s'" % csrc)

                    for key, vals in info.items():
                        userinfo[key] = vals

        return userinfo

    def fetch_distributed_claims(self, userinfo, callback=None):
        """Fetch distributed claims from their endpoints and merge them.

        :param callback: Called with the source name to get an access token
            when the spec does not carry one
        :raises PyoidcError: if the returned claims differ from the claim
            names listed for that source
        """
        for csrc, spec in userinfo._claim_sources.items():
            if "endpoint" in spec:
                if "access_token" in spec:
                    _uinfo = self.do_user_info_request(
                        token=spec["access_token"],
                        userinfo_endpoint=spec["endpoint"])
                else:
                    _uinfo = self.do_user_info_request(
                        token=callback(csrc),
                        userinfo_endpoint=spec["endpoint"])

                attr = [n for n, s in userinfo._claim_names.items()
                        if s == csrc]
                if set(attr) != set(_uinfo.keys()):
                    raise PyoidcError(
                        "Claim names mismatch for source '%s'" % csrc)

                for key, vals in _uinfo.items():
                    userinfo[key] = vals

        return userinfo

    def verify_alg_support(self, alg, usage, other):
        """
        Verifies that the algorithm to be used are supported by the other side.

        :param alg: The algorithm specification
        :param usage: In which context the 'alg' will be used.
            The following values are supported:
            - userinfo
            - id_token
            - request_object
            - token_endpoint_auth
        :param other: The identifier for the other side
        :return: True or False
        """
        try:
            _pcr = self.provider_info
            supported = _pcr["%s_algs_supported" % usage]
        except KeyError:
            try:
                supported = getattr(self, "%s_algs_supported" % usage)
            except AttributeError:
                supported = None

        if supported is None:
            return True
        return alg in supported

    def match_preferences(self, pcr=None, issuer=None):
        """
        Match the clients preferences against what the provider can do.

        :param pcr: Provider configuration response if available
        :param issuer: The issuer identifier
        :raises ConfigurationError: if a preference cannot be matched
        """
        if not pcr:
            pcr = self.provider_info

        regreq = RegistrationRequest

        for _pref, _prov in PREFERENCE2PROVIDER.items():
            try:
                vals = self.client_prefs[_pref]
            except KeyError:
                continue

            try:
                _pvals = pcr[_prov]
            except KeyError:
                try:
                    self.behaviour[_pref] = PROVIDER_DEFAULT[_pref]
                except KeyError:
                    if isinstance(pcr.c_param[_prov][0], list):
                        self.behaviour[_pref] = []
                    else:
                        self.behaviour[_pref] = None
                continue

            if isinstance(vals, basestring):
                if vals in _pvals:
                    self.behaviour[_pref] = vals
            else:
                vtyp = regreq.c_param[_pref]

                _list = isinstance(vtyp[0], list)

                for val in vals:
                    if val in _pvals:
                        if not _list:
                            self.behaviour[_pref] = val
                            break
                        else:
                            try:
                                self.behaviour[_pref].append(val)
                            except KeyError:
                                self.behaviour[_pref] = [val]

            if _pref not in self.behaviour:
                raise ConfigurationError(
                    "OP couldn't match preference:%s" % _pref, pcr)

        for key, val in self.client_prefs.items():
            if key in self.behaviour:
                continue

            try:
                vtyp = regreq.c_param[key]
                if isinstance(vtyp[0], list):
                    pass
                elif isinstance(val, list) and not isinstance(val,
                                                              basestring):
                    val = val[0]
            except KeyError:
                pass
            if key not in PREFERENCE2PROVIDER:
                self.behaviour[key] = val

    def store_registration_info(self, reginfo):
        """Remember the registration response and the credentials in it."""
        self.registration_response = reginfo
        if "token_endpoint_auth_method" not in self.registration_response:
            self.registration_response["token_endpoint_auth_method"] = \
                "client_secret_post"
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        try:
            self.registration_expires = reginfo["client_secret_expires_at"]
        except KeyError:
            pass
        try:
            self.registration_access_token = reginfo[
                "registration_access_token"]
        except KeyError:
            pass

    def handle_registration_info(self, response):
        """Parse a registration (or registration read) HTTP response.

        Registration returns 201 Created, a read returns 200 OK; both are
        accepted.

        :raises PyoidcError: on any other status
        """
        if response.status_code in (200, 201):
            resp = RegistrationResponse().deserialize(response.text, "json")
            self.store_registration_info(resp)
        else:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        return resp

    def registration_read(self, url="", registration_access_token=None):
        """Read the current registration using the registration access
        token."""
        if not url:
            url = self.registration_response["registration_client_uri"]

        if not registration_access_token:
            registration_access_token = self.registration_access_token

        # Same headers shape (a dict) as register() uses.
        headers = {"Authorization": "Bearer %s" % registration_access_token}
        rsp = self.http_request(url, "GET", headers=headers)

        return self.handle_registration_info(rsp)

    def create_registration_request(self, **kwargs):
        """
        Create a registration request

        :param kwargs: parameters to the registration request
        :return:
        """
        req = RegistrationRequest()

        for prop in req.parameters():
            try:
                req[prop] = kwargs[prop]
            except KeyError:
                try:
                    req[prop] = self.behaviour[prop]
                except KeyError:
                    pass

        if "post_logout_redirect_uris" not in req:
            try:
                req["post_logout_redirect_uris"] = \
                    self.post_logout_redirect_uris
            except AttributeError:
                pass

        if "redirect_uris" not in req:
            try:
                req["redirect_uris"] = self.redirect_uris
            except AttributeError:
                raise MissingRequiredAttribute("redirect_uris", req)

        return req

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return:
        """
        req = self.create_registration_request(**kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def normalization(self, principal, idtype="mail"):
        """Split a principal into (subject, domain) for discovery."""
        if idtype == "mail":
            (_local, domain) = principal.split("@")
            subject = "acct:%s" % principal
        elif idtype == "url":
            p = urlparse.urlparse(principal)
            domain = p.netloc
            subject = principal
        else:
            domain = ""
            subject = principal

        return subject, domain

    def discover(self, principal):
        """Run a WebFinger discovery query for ``principal``."""
        return self.wf.discovery_query(principal)
class Client(oauth2.Client):
    """OpenID Connect client built on top of the OAuth2 client.

    Adds ID Token handling, UserInfo requests, provider configuration
    retrieval, dynamic client registration and WebFinger discovery to
    ``oauth2.Client``.
    """

    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, ca_certs=None, client_prefs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True):
        oauth2.Client.__init__(self, client_id, ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)

        self.file_store = "./file/"
        self.file_uri = "http://localhost/"

        # OpenID connect specific endpoints
        for endpoint in ENDPOINTS:
            setattr(self, endpoint, "")

        self.id_token = None
        self.log = None

        self.request2endpoint = REQUEST2ENDPOINT
        self.response2error = RESPONSE2ERROR
        self.grant_class = Grant
        self.token_class = Token
        self.provider_info = None
        self.registration_response = None
        self.client_prefs = client_prefs or {}
        self.behaviour = {
            "request_object_signing_alg": DEF_SIGN_ALG["openid_request_object"]
        }

        self.wf = WebFinger(OIC_ISSUER)
        self.wf.httpd = self
        self.allow = {}
        self.post_logout_redirect_uris = []
        self.registration_expires = 0
        self.registration_access_token = None

        # Default key by kid for different key types
        # For instance {"RSA":"abc"}
        self.kid = {"sig": {}, "enc": {}}

    def _get_id_token(self, **kwargs):
        """Return an ID Token usable for an id_token based request.

        An explicit ``id_token`` keyword wins. Otherwise the grant selected by
        ``get_grant(**kwargs)`` is searched for the first token that has an ID
        Token and, when both a requested ``scope`` and a token scope are
        known, whose scope contains every requested scope item.

        :return: The ID Token or None if none was found.
        """
        try:
            return kwargs["id_token"]
        except KeyError:
            pass

        grant = self.get_grant(**kwargs)
        if not grant:
            return None

        _scope = kwargs.get("scope")
        if isinstance(_scope, basestring):
            # A space separated scope string must be compared item by item,
            # not character by character.
            _scope = _scope.split()

        for token in grant.tokens:
            if token.scope and _scope:
                if not all(item in token.scope for item in _scope):
                    # This token doesn't cover the requested scope; a later
                    # token in the grant still might.
                    continue
            if token.id_token:
                return token.id_token
        return None

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       request_param=None, **kwargs):
        """Construct an authorization request, adding a nonce and optionally
        a signed request object.

        :param request_param: "request" to embed the signed request object
            in the request, anything else truthy to write it to a file under
            ``kwargs["local_dir"]`` and reference it via ``request_uri``
            (prefixed by ``kwargs["base_path"]``).
        :param kwargs: ``request_method="file"`` selects the request_uri
            variant; ``algorithm``/``keys`` override the signing setup.
        """
        if request_args is not None:
            if "nonce" not in request_args:
                _rt = request_args["response_type"]
                if "token" in _rt or "id_token" in _rt:
                    request_args["nonce"] = rndstr(12)
        elif "response_type" in kwargs:
            if "token" in kwargs["response_type"]:
                request_args = {"nonce": rndstr(12)}
        else:  # Never wrong to specify a nonce
            request_args = {"nonce": rndstr(12)}

        if "request_method" in kwargs:
            if kwargs["request_method"] == "file":
                request_param = "request_uri"
            del kwargs["request_method"]

        areq = oauth2.Client.construct_AuthorizationRequest(
            self, request, request_args, extra_args, **kwargs)

        if not request_param:
            return areq

        alg = self.behaviour["request_object_signing_alg"]
        if "algorithm" not in kwargs:
            kwargs["algorithm"] = alg

        if "keys" not in kwargs and alg:
            _kty = alg2keytype(alg)
            try:
                kwargs["keys"] = self.keyjar.get_signing_key(
                    _kty, kid=self.kid["sig"][_kty])
            except KeyError:
                kwargs["keys"] = self.keyjar.get_signing_key(_kty)

        _req = make_openid_request(areq, **kwargs)

        if request_param == "request":
            areq["request"] = _req
        else:
            _filedir = kwargs["local_dir"]
            _webpath = kwargs["base_path"]
            # Pick a random file name that doesn't exist yet
            _name = rndstr(10)
            filename = os.path.join(_filedir, _name)
            while os.path.exists(filename):
                _name = rndstr(10)
                filename = os.path.join(_filedir, _name)
            # 'with' guarantees the file is closed even if the write fails
            with open(filename, mode="w") as fid:
                fid.write(_req)
            areq["request_uri"] = "%s%s" % (_webpath, _name)

        return areq

    #noinspection PyUnusedLocal
    def construct_AccessTokenRequest(self, request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        return oauth2.Client.construct_AccessTokenRequest(
            self, request, request_args, extra_args, **kwargs)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None, **kwargs):
        return oauth2.Client.construct_RefreshAccessTokenRequest(
            self, request, request_args, extra_args, **kwargs)

    def construct_UserInfoRequest(self, request=UserInfoRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """Construct a UserInfo request, looking up an access token if the
        caller didn't supply one.

        :raises PyoidcError: if no valid access token is available.
        """
        if request_args is None:
            request_args = {}

        if "access_token" not in request_args:
            if "scope" not in kwargs:
                kwargs["scope"] = "openid"
            token = self.get_token(**kwargs)
            if token is None:
                raise PyoidcError("No valid token available")

            request_args["access_token"] = token.access_token

        return self.construct_request(request, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        return self.construct_request(request, request_args, extra_args)

    #noinspection PyUnusedLocal
    def construct_RefreshSessionRequest(self,
                                        request=RefreshSessionRequest,
                                        request_args=None, extra_args=None,
                                        **kwargs):
        return self.construct_request(request, request_args, extra_args)

    def _id_token_based(self, request, request_args=None, extra_args=None,
                        **kwargs):
        """Construct a request that carries an ID Token.

        The token goes into the parameter named by ``kwargs["prop"]``
        (default "id_token") unless the caller already supplied it.

        :raises PyoidcError: if no ID Token can be found.
        """
        if request_args is None:
            request_args = {}

        _prop = kwargs.get("prop", "id_token")

        if _prop not in request_args:
            id_token = self._get_id_token(**kwargs)
            if id_token is None:
                raise PyoidcError("No valid id token available")

            request_args[_prop] = id_token

        return self.construct_request(request, request_args, extra_args)

    def construct_CheckSessionRequest(self, request=CheckSessionRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        return self._id_token_based(request, request_args, extra_args,
                                    **kwargs)

    def construct_CheckIDRequest(self, request=CheckIDRequest,
                                 request_args=None,
                                 extra_args=None, **kwargs):
        # access_token is where the id_token will be placed
        return self._id_token_based(request, request_args, extra_args,
                                    prop="access_token", **kwargs)

    def construct_EndSessionRequest(self, request=EndSessionRequest,
                                    request_args=None, extra_args=None,
                                    **kwargs):
        """Construct an end session request, keeping ``state`` in sync
        between ``kwargs`` and ``request_args``."""
        if request_args is None:
            request_args = {}

        if "state" in kwargs:
            request_args["state"] = kwargs["state"]
        elif "state" in request_args:
            kwargs["state"] = request_args["state"]

        return self._id_token_based(request, request_args, extra_args,
                                    **kwargs)

    # ------------------------------------------------------------------------
    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        return self.request_info(AuthorizationRequest, "GET",
                                 request_args, extra_args, **kwargs)

    # ------------------------------------------------------------------------
    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse):
        return oauth2.Client.do_authorization_request(self, request, state,
                                                      body_type, method,
                                                      request_args,
                                                      extra_args, http_args,
                                                      response_cls)

    def do_access_token_request(self, request=AccessTokenRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        return oauth2.Client.do_access_token_request(self, request, scope,
                                                     state, body_type,
                                                     method, request_args,
                                                     extra_args, http_args,
                                                     response_cls,
                                                     authn_method, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                **kwargs):
        return oauth2.Client.do_access_token_refresh(self, request, state,
                                                     body_type, method,
                                                     request_args,
                                                     extra_args, http_args,
                                                     response_cls, **kwargs)

    def _send_constructed_request(self, request, scope, state, body_type,
                                  method, request_args, extra_args,
                                  http_args, response_cls):
        """Build *request*, send it and parse the reply as *response_cls*.

        Shared by the do_*_request methods below that only need request
        construction, HTTP argument merging and response parsing.
        """
        url, body, ht_args, _ = self.request_info(
            request, method=method, request_args=request_args,
            extra_args=extra_args, scope=scope, state=state)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge the arguments derived while building the request (e.g.
            # client authentication headers) into the caller's ones. Work on
            # a copy so the caller's dict isn't mutated.
            _merged = dict(http_args)
            _merged.update(ht_args or {})
            http_args = _merged

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def do_registration_request(self, request=RegistrationRequest,
                                scope="", state="", body_type="json",
                                method="POST", request_args=None,
                                extra_args=None, http_args=None,
                                response_cls=None):
        if response_cls is None:
            response_cls = RegistrationResponse

        return self._send_constructed_request(request, scope, state,
                                              body_type, method, request_args,
                                              extra_args, http_args,
                                              response_cls)

    def do_check_session_request(self, request=CheckSessionRequest,
                                 scope="",
                                 state="", body_type="json", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=IdToken):
        return self._send_constructed_request(request, scope, state,
                                              body_type, method, request_args,
                                              extra_args, http_args,
                                              response_cls)

    def do_check_id_request(self, request=CheckIDRequest, scope="",
                            state="", body_type="json", method="GET",
                            request_args=None, extra_args=None,
                            http_args=None,
                            response_cls=IdToken):
        return self._send_constructed_request(request, scope, state,
                                              body_type, method, request_args,
                                              extra_args, http_args,
                                              response_cls)

    def do_end_session_request(self, request=EndSessionRequest, scope="",
                               state="", body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None, response_cls=None):
        return self._send_constructed_request(request, scope, state,
                                              body_type, method, request_args,
                                              extra_args, http_args,
                                              response_cls)

    def user_info_request(self, method="GET", state="", scope="", **kwargs):
        """Prepare a UserInfo request.

        The access token comes from ``kwargs["token"]``,
        ``kwargs["access_token"]`` or the grant for *state* (refreshed
        first if it is no longer valid).

        :return: Tuple (path, body, method, http arguments)
        :raises MissingParameter: if the token type can't be determined when
            a sending behavior is requested.
        """
        uir = UserInfoRequest()
        logger.debug("[user_info_request]: kwargs:%s" % (kwargs, ))

        if "token" in kwargs:
            if kwargs["token"]:
                uir["access_token"] = kwargs["token"]
                token = Token()
                token.token_type = "Bearer"
                token.access_token = kwargs["token"]
                kwargs["behavior"] = "use_authorization_header"
            else:
                # What to do ? Need a callback
                token = None
        elif "access_token" in kwargs and kwargs["access_token"]:
            uir["access_token"] = kwargs["access_token"]
            del kwargs["access_token"]
            token = None
        else:
            token = self.grant[state].get_token(scope)

            if token.is_valid():
                uir["access_token"] = token.access_token
                if token.token_type == "Bearer" and method == "GET":
                    kwargs["behavior"] = "use_authorization_header"
            else:
                # raise oauth2.OldAccessToken
                if self.log:
                    self.log.info("do access token refresh")
                self.do_access_token_refresh(token=token)
                token = self.grant[state].get_token(scope)
                uir["access_token"] = token.access_token

        uri = self._endpoint("userinfo_endpoint", **kwargs)
        # If access token is a bearer token it might be sent in the
        # authorization header
        # 3-ways of sending the access_token:
        # - POST with token in authorization header
        # - POST with token in message body
        # - GET with token in authorization header
        if "behavior" in kwargs:
            _behav = kwargs["behavior"]
            _token = uir["access_token"]
            try:
                _ttype = kwargs["token_type"]
            except KeyError:
                try:
                    _ttype = token.token_type
                except AttributeError:
                    raise MissingParameter("Unspecified token type")

            # use_authorization_header, token_in_message_body
            if "use_authorization_header" in _behav and _ttype == "Bearer":
                bh = "Bearer %s" % _token
                if "headers" in kwargs:
                    kwargs["headers"].update({"Authorization": bh})
                else:
                    kwargs["headers"] = {"Authorization": bh}

            if "token_in_message_body" not in _behav:
                # remove the token from the request
                del uir["access_token"]

        path, body, kwargs = self.get_or_post(uri, method, uir, **kwargs)

        h_args = {k: v for k, v in kwargs.items() if k in HTTP_ARGS}

        return path, body, method, h_args

    def do_user_info_request(self, method="POST", state="", scope="openid",
                             request="openid", **kwargs):
        """Send a UserInfo request and parse the (JSON or JWT) response.

        :raises PyoidcError: on a non-200 response or an unexpected
            content-type.
        """
        kwargs["request"] = request
        path, body, method, h_args = self.user_info_request(
            method, state, scope, **kwargs)

        logger.debug("[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" % (
            path, body, h_args))

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            # Explicit checks: 'assert' is stripped under python -O, which
            # would make every response be treated as JSON.
            _ctype = resp.headers["content-type"]
            if "application/json" in _ctype:
                sformat = "json"
            elif "application/jwt" in _ctype:
                sformat = "jwt"
            else:
                raise PyoidcError(
                    "Unexpected content-type in userinfo response: %s" %
                    _ctype)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]: %s" % (
                resp.status_code, resp.text))

        _schema = kwargs.get("user_info_schema", OpenIDSchema)

        logger.debug("Reponse text: '%s'" % resp.text)

        if sformat == "json":
            return _schema().from_json(txt=resp.text)

        algo = self.client_prefs["userinfo_signed_response_alg"]
        _kty = alg2keytype(algo)
        # Keys of the OP ?
        # NOTE(review): the kid value is passed as the second *positional*
        # argument here, whereas construct_AuthorizationRequest passes it as
        # kid=...; confirm which get_signing_key parameter is intended.
        try:
            keys = self.keyjar.get_signing_key(_kty, self.kid["sig"][_kty])
        except KeyError:
            keys = self.keyjar.get_signing_key(_kty)

        return _schema().from_jwt(resp.text, keys)

    def get_userinfo_claims(self, access_token, endpoint, method="POST",
                            schema_class=OpenIDSchema, **kwargs):
        """Fetch claims from *endpoint* using *access_token*.

        :raises PyoidcError: on a non-200 response or a non-JSON reply.
        """
        uir = UserInfoRequest(access_token=access_token)

        h_args = {k: v for k, v in kwargs.items() if k in HTTP_ARGS}

        if "authn_method" in kwargs:
            # The request to authenticate must be passed along; the
            # authn_method itself travels in kwargs.
            http_args = self.init_authentication_method(uir, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.init_authentication_method(
                uir, "bearer_header", **kwargs)

        h_args.update(http_args)
        path, body, kwargs = self.get_or_post(endpoint, method, uir, **kwargs)

        resp = self.http_request(path, method, data=body, **h_args)

        if resp.status_code == 200:
            _ctype = resp.headers["content-type"]
            if "application/json" not in _ctype:
                raise PyoidcError(
                    "Unexpected content-type in userinfo response: %s" %
                    _ctype)
        elif resp.status_code == 500:
            raise PyoidcError("ERROR: Something went wrong: %s" % resp.text)
        else:
            raise PyoidcError("ERROR: Something went wrong [%s]" %
                              resp.status_code)

        return schema_class().from_json(txt=resp.text)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response
        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
        as attributes in self.
        :raises IssuerMismatch: if the issuer in *pcr* differs from *issuer*
            (ignoring one trailing slash) and issuer_mismatch isn't allowed.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Normalize the trailing slash of the expected issuer to match
            # the one in the configuration before comparing.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # An explicit check rather than 'assert': issuer validation must
            # not disappear when running under python -O.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise IssuerMismatch(
                    "'%s' != '%s'" % (_issuer, _pcr_issuer), pcr)

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """Fetch and process the provider configuration for *issuer*,
        following 302 redirects.

        :raises PyoidcError: if no configuration could be retrieved.
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    @staticmethod
    def _verify_claim_names(userinfo, csrc, claims):
        """Raise PyoidcError unless *claims* holds exactly the claim names
        that *userinfo* attributes to claim source *csrc*."""
        attr = [n for n, s in userinfo._claim_names.items() if s == csrc]
        # Compare as sets: dict key order is not meaningful.
        if set(attr) != set(claims.keys()):
            raise PyoidcError(
                "Claims from '%s' don't match the announced claim names" %
                csrc)

    def unpack_aggregated_claims(self, userinfo):
        """Verify and merge aggregated (JWT) claims into *userinfo*."""
        if userinfo._claim_sources:
            for csrc, spec in userinfo._claim_sources.items():
                if "JWT" not in spec:
                    continue

                if csrc not in self.keyjar:
                    self.provider_config(csrc, endpoints=False)

                keycol = self.keyjar.get_verify_key(owner=csrc)
                for typ, keyl in self.keyjar.get_verify_key().items():
                    try:
                        keycol[typ].extend(keyl)
                    except KeyError:
                        keycol[typ] = keyl

                info = json.loads(JWS().verify(str(spec["JWT"]), keycol))
                self._verify_claim_names(userinfo, csrc, info)

                for key, vals in info.items():
                    userinfo[key] = vals

        return userinfo

    def fetch_distributed_claims(self, userinfo, callback=None):
        """Fetch distributed claims and merge them into *userinfo*.

        :param callback: Called with the claim source to get an access token
            when the claim source spec doesn't carry one.
        """
        for csrc, spec in userinfo._claim_sources.items():
            if "endpoint" not in spec:
                continue

            if "access_token" in spec:
                _token = spec["access_token"]
            else:
                _token = callback(csrc)

            _uinfo = self.do_user_info_request(
                token=_token, userinfo_endpoint=spec["endpoint"])

            self._verify_claim_names(userinfo, csrc, _uinfo)

            for key, vals in _uinfo.items():
                userinfo[key] = vals

        return userinfo

    def verify_alg_support(self, alg, usage, other):
        """
        Verifies that the algorithm to be used are supported by the other side.

        :param alg: The algorithm specification
        :param usage: In which context the 'alg' will be used.
            The following values are supported:
            - userinfo
            - id_token
            - request_object
            - token_endpoint_auth
        :param other: The identifier for the other side
        :return: True or False
        """
        _attr = "%s_algs_supported" % usage
        try:
            if self.provider_info is None:
                # No provider configuration yet; fall back to local settings
                raise KeyError(_attr)
            supported = self.provider_info[_attr]
        except KeyError:
            supported = getattr(self, _attr, None)

        if supported is None:
            return True
        return alg in supported

    def match_preferences(self, pcr=None, issuer=None):
        """
        Match the clients preferences against what the provider can do.

        :param pcr: Provider configuration response if available
        :param issuer: The issuer identifier
        :raises ConfigurationError: if a preference can't be matched.
        """
        if not pcr:
            pcr = self.provider_info

        regreq = RegistrationRequest

        for _pref, _prov in PREFERENCE2PROVIDER.items():
            try:
                vals = self.client_prefs[_pref]
            except KeyError:
                continue

            try:
                _pvals = pcr[_prov]
            except KeyError:
                # The provider doesn't state what it supports: use the
                # protocol default or an empty value of the right type.
                try:
                    self.behaviour[_pref] = PROVIDER_DEFAULT[_pref]
                except KeyError:
                    if isinstance(pcr.c_param[_prov][0], list):
                        self.behaviour[_pref] = []
                    else:
                        self.behaviour[_pref] = None
                continue

            if isinstance(vals, basestring):
                if vals in _pvals:
                    self.behaviour[_pref] = vals
            else:
                _list = isinstance(regreq.c_param[_pref][0], list)

                for val in vals:
                    if val not in _pvals:
                        continue
                    if not _list:
                        # Single valued: first supported preference wins
                        self.behaviour[_pref] = val
                        break
                    try:
                        self.behaviour[_pref].append(val)
                    except KeyError:
                        self.behaviour[_pref] = [val]

            if _pref not in self.behaviour:
                raise ConfigurationError(
                    "OP couldn't match preference:%s" % _pref, pcr)

        for key, val in self.client_prefs.items():
            if key in self.behaviour:
                continue

            try:
                vtyp = regreq.c_param[key]
                if isinstance(vtyp[0], list):
                    pass
                elif isinstance(val, list) and not isinstance(val, basestring):
                    val = val[0]
            except KeyError:
                pass
            if key not in PREFERENCE2PROVIDER:
                self.behaviour[key] = val

    def store_registration_info(self, reginfo):
        """Remember the registration response and the credentials in it."""
        self.registration_response = reginfo
        if "token_endpoint_auth_method" not in self.registration_response:
            self.registration_response[
                "token_endpoint_auth_method"] = "client_secret_post"
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        try:
            self.registration_expires = reginfo["client_secret_expires_at"]
        except KeyError:
            pass
        try:
            self.registration_access_token = reginfo[
                "registration_access_token"]
        except KeyError:
            pass

    def handle_registration_info(self, response):
        """Parse and store a registration response.

        :raises PyoidcError: if the response status isn't 200.
        """
        if response.status_code != 200:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        resp = RegistrationResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)
        return resp

    def registration_read(self, url="", registration_access_token=None):
        """Read the registration information back from the OP."""
        if not url:
            url = self.registration_response["registration_client_uri"]

        if not registration_access_token:
            registration_access_token = self.registration_access_token

        # Same header format (a dict) as used by register()
        headers = {"Authorization": "Bearer %s" % registration_access_token}
        rsp = self.http_request(url, "GET", headers=headers)

        return self.handle_registration_info(rsp)

    def create_registration_request(self, **kwargs):
        """
        Create a registration request

        :param kwargs: parameters to the registration request
        :return: A RegistrationRequest; values come from kwargs first, then
            from self.behaviour.
        :raises MissingRequiredAttribute: if no redirect_uris are known.
        """
        req = RegistrationRequest()

        for prop in req.parameters():
            try:
                req[prop] = kwargs[prop]
            except KeyError:
                try:
                    req[prop] = self.behaviour[prop]
                except KeyError:
                    pass

        if "post_logout_redirect_uris" not in req:
            try:
                req["post_logout_redirect_uris"] = \
                    self.post_logout_redirect_uris
            except AttributeError:
                pass

        if "redirect_uris" not in req:
            try:
                req["redirect_uris"] = self.redirect_uris
            except AttributeError:
                raise MissingRequiredAttribute("redirect_uris", req)

        return req

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The parsed registration response
        """
        req = self.create_registration_request(**kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def normalization(self, principal, idtype="mail"):
        """Split *principal* into a (subject, domain) tuple.

        :param idtype: "mail" (user@host, subject becomes an acct: URI),
            "url", or anything else (no domain).
        """
        if idtype == "mail":
            # Tuple unpacking keeps the ValueError for malformed addresses
            _local, domain = principal.split("@")
            subject = "acct:%s" % principal
        elif idtype == "url":
            p = urlparse.urlparse(principal)
            domain = p.netloc
            subject = principal
        else:
            domain = ""
            subject = principal

        return subject, domain

    def discover(self, principal):
        """Run a WebFinger discovery query for *principal*."""
        return self.wf.discovery_query(principal)
class Provider(provider.Provider):
    """OAuth2 provider with dynamic client registration.

    Serves the endpoint classes listed in ``self.endpoints``: dynamic
    client registration, token, authorization, user, resource set
    registration, introspection, RPT and permission registration.
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey, urlmap=None, keyjar=None, hostname="",
                 configuration=None):
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey=symkey,
                                   urlmap=urlmap)

        self.baseurl = ""
        if keyjar:
            self.keyjar = keyjar
        else:
            self.keyjar = KeyJar()

        # Call gethostname(): the attribute must hold the name, not the
        # function object.
        self.hostname = hostname or socket.gethostname()
        self.jwks_uri = []
        self.endpoints = [DynamicClientEndpoint, TokenEndpoint,
                          AuthorizationEndpoint, UserEndpoint,
                          ResourceSetRegistrationEndpoint,
                          IntrospectionEndpoint, RPTEndpoint,
                          PermissionRegistrationEndpoint]

    def set_authn_broker(self, authn_broker):
        """Install *authn_broker* and derive the cookie function from it.

        *authn_broker* may be None/empty, in which case no cookie function
        is set.
        """
        self.authn_broker = authn_broker
        # Guard the iteration: the None case is handled right below and must
        # not crash here first.
        for meth in authn_broker or []:
            meth.srv = self
        if authn_broker:
            self.cookie_func = authn_broker[0].create_cookie
        else:
            self.cookie_func = None

    @staticmethod
    def _verify_url(url, urlset):
        """Return True if *url* has the same scheme and netloc as one of the
        (base, query) pairs in *urlset*."""
        part = urllib.parse.urlparse(url)

        for reg, _query in urlset:
            _part = urllib.parse.urlparse(reg)
            if part.scheme == _part.scheme and part.netloc == _part.netloc:
                return True

        return False

    @staticmethod
    def _split_query(uri):
        """Split *uri* at its last '?' into (base, query).

        query is None when there is no '?' (same contract as the deprecated
        urllib.parse.splitquery).
        """
        base, sep, query = uri.rpartition("?")
        if not sep:
            return uri, None
        return base, query

    @staticmethod
    def _client_registration_error(description):
        """Build a 400 JSON response for a client registration error."""
        err = ClientRegistrationErrorResponse(
            error="invalid_configuration_parameter",
            error_description=description)
        return Response(err.to_json(), content="application/json",
                        status="400 Bad Request")

    def do_client_registration(self, request, client_id, ignore=None):
        """Validate a registration request and merge it into the stored
        client information.

        :param request: The parsed RegistrationRequest
        :param client_id: Key of the client in self.cdb
        :param ignore: Request parameters not copied verbatim
        :return: The updated client info dict, or a Response describing the
            error.
        """
        if ignore is None:
            ignore = []

        _cinfo = self.cdb[client_id].copy()
        logger.debug("_cinfo: %s" % _cinfo)

        for key, val in list(request.items()):
            if key not in ignore:
                _cinfo[key] = val

        if "redirect_uris" in request:
            ruri = []
            for uri in request["redirect_uris"]:
                if urllib.parse.urlparse(uri).fragment:
                    return self._client_registration_error(
                        "redirect_uri contains fragment")
                base, query = self._split_query(uri)
                if query:
                    ruri.append((base, urllib.parse.parse_qs(query)))
                else:
                    ruri.append((base, query))
            _cinfo["redirect_uris"] = ruri

        # Explicit checks below instead of 'assert': these validations must
        # still run under python -O.
        if "sector_identifier_uri" in request:
            si_url = request["sector_identifier_uri"]
            try:
                res = self.server.http_request(si_url)
            except ConnectionError as err:
                logger.error("%s" % err)
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            if not res:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            logger.debug("sector_identifier_uri => %s" % res.text)

            try:
                si_redirects = json.loads(res.text)
            except ValueError:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Error deserializing sector_identifier_uri content")

            if "redirect_uris" in request:
                logger.debug("redirect_uris: %s" % request["redirect_uris"])
                for uri in request["redirect_uris"]:
                    if uri not in si_redirects:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr="redirect_uri missing from sector_identifiers"
                        )

            _cinfo["si_redirects"] = si_redirects
            _cinfo["sector_id"] = si_url
        elif "redirect_uris" in request:
            if len(request["redirect_uris"]) > 1:
                # check that the hostnames are the same
                host = ""
                for url in request["redirect_uris"]:
                    part = urllib.parse.urlparse(url)
                    _host = part.netloc.split(":")[0]
                    if not host:
                        host = _host
                    elif host != _host:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr=
                            "'sector_identifier_uri' must be registered")

        # With no registered redirect URIs there is nothing policy_url or
        # logo_url could legitimately point to.
        _registered = _cinfo.get("redirect_uris", [])
        for item in ["policy_url", "logo_url"]:
            if item in request:
                if self._verify_url(request[item], _registered):
                    _cinfo[item] = request[item]
                else:
                    return self._error_response(
                        "invalid_configuration_parameter",
                        descr="%s pointed to illegal URL" % item)

        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug("keys for %s: [%s]" % (
                    client_id,
                    ",".join(["%s" % x for x in self.keyjar[client_id]])))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            return self._client_registration_error("%s" % err)

        return _cinfo

    @staticmethod
    def comb_redirect_uris(args):
        """Turn the stored (base, query) redirect URI pairs back into URLs.

        The query part may be a raw query string or a dict as produced by
        urllib.parse.parse_qs (name -> list of values).
        """
        if "redirect_uris" not in args:
            return

        val = []
        for base, query in args["redirect_uris"]:
            if query:
                if isinstance(query, dict):
                    # parse_qs output must be re-encoded, not str()-formatted
                    query = urllib.parse.urlencode(query, doseq=True)
                val.append("%s?%s" % (base, query))
            else:
                val.append(base)

        args["redirect_uris"] = val

    #noinspection PyUnusedLocal
    def l_registration_endpoint(self, request, authn=None, **kwargs):
        """Handle a dynamic client registration request (JSON body).

        :return: A Response with the registration response, or an error
            Response.
        """
        _log_debug = logger.debug
        _log_info = logger.info

        _log_debug("@registration_endpoint")

        request = RegistrationRequest().deserialize(request, "json")

        _log_info("registration_request:%s" % request.to_dict())

        try:
            request.verify()
        except MessageException as err:
            if "type" not in request:
                return self._error(error="invalid_type",
                                   descr="%s" % err)
            else:
                return self._error(error="invalid_configuration_parameter",
                                   descr="%s" % err)

        _keyjar = self.server.keyjar

        # create new id och secret
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_secret = secret(self.seed, client_id)

        _rat = rndstr(32)
        reg_enp = ""
        for endp in self.endpoints:
            # self.endpoints holds endpoint *classes*, so isinstance() alone
            # never matches; accept both classes and instances.
            if (isinstance(endp, type) and issubclass(endp, DynamicClientEndpoint)) \
                    or isinstance(endp, DynamicClientEndpoint):
                reg_enp = "%s%s" % (self.baseurl, endp.etype)

        self.cdb[client_id] = {
            "client_id": client_id,
            "client_secret": client_secret,
            "registration_access_token": _rat,
            "registration_client_uri": "%s?client_id=%s" % (reg_enp,
                                                            client_id),
            "client_secret_expires_at": utc_time_sans_frac() + 86400,
            "client_id_issued_at": utc_time_sans_frac()}

        self.cdb[_rat] = client_id

        # do_client_registration also loads the client's keys into
        # self.keyjar; loading them a second time would duplicate key
        # bundles (and jwks_uri fetches).
        _cinfo = self.do_client_registration(request, client_id,
                                             ignore=["redirect_uris",
                                                     "policy_url",
                                                     "logo_url"])
        if isinstance(_cinfo, Response):
            return _cinfo

        args = {k: v for k, v in list(_cinfo.items())
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        # Add the key to the keyjar
        if client_secret:
            _kc = KeyBundle([{"kty": "oct", "key": client_secret,
                              "use": "ver"},
                             {"kty": "oct", "key": client_secret,
                              "use": "sig"}])
            try:
                _keyjar[client_id].append(_kc)
            except KeyError:
                _keyjar[client_id] = [_kc]

        self.cdb[client_id] = _cinfo
        _log_info("Client info: %s" % _cinfo)

        logger.debug("registration_response: %s" % response.to_dict())

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    def registration_endpoint(self, request, authn=None, **kwargs):
        return self.l_registration_endpoint(request, authn, **kwargs)

    def read_registration(self, authn, request, **kwargs):
        """
        Read all information this server has on a client.
        Authorization is done by using the access token that was return as
        part of the client registration result.

        :param authn: The Authorization HTTP header
        :param request: The query part of the URL
        :param kwargs: Any other arguments
        :return: A Response carrying the client's registration information
        :raises AssertionError: if *authn* is not a Bearer token or the
            query's client_id doesn't belong to the token
        """

        logger.debug("authn: %s, request: %s" % (authn, request))

        # Explicit checks instead of 'assert' statements, which python -O
        # removes and would thereby disable this authorization check. The
        # exception type is kept for existing callers.
        if not authn.startswith("Bearer "):
            raise AssertionError("Authorization is not a Bearer token")
        token = authn[len("Bearer "):]

        # verify the access token, has to be key into the client information
        # database.
        client_id = self.cdb[token]

        # extra check
        _info = urllib.parse.parse_qs(request)
        if _info["client_id"][0] != client_id:
            raise AssertionError("client_id doesn't match the access token")

        logger.debug("Client '%s' reads client info" % client_id)
        args = {k: v for k, v in list(self.cdb[client_id].items())
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    #noinspection PyUnusedLocal
    def providerinfo_endpoint(self, handle="", **kwargs):
        """Return the provider configuration (self.conf_info) as JSON,
        optionally with a cookie derived from *handle*."""
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.conf_info
            _log_info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if handle:
                key, _timestamp = handle
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            # NOTE(review): this sends the formatted traceback to the HTTP
            # client, which discloses internals; consider a generic message.
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            resp = Response(message, content="html/text")

        return resp
class Provider(provider.Provider):
    """
    OAuth2 provider with the OAuth2 extensions implemented in this module.

    On top of :class:`provider.Provider` it registers the registration,
    client-info, revocation (RFC 7009) and introspection (RFC 7662)
    endpoints, verifies PKCE (RFC 7636) in the authorization_code grant and
    implements the client_credentials, password and refresh_token grants.
    """

    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey=None,
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
        config=None,
        behavior=None,
        lifetime_policy=None,
        message_factory=ExtensionMessageFactory,
        **kwargs
    ):
        """
        Set up the provider.

        :param name: Issuer name; a trailing "/" is added when missing.
        :param client_authn_methods: Mapping of authentication method name to
            a class that is constructed with this provider and provides
            ``verify()``.
        :param authn_at_registration: Key into ``client_authn_methods`` of the
            method required at registration; falsy disables the check.
        :param seed: Seed passed to ``secret()`` when issuing client secrets.
        :param secret_lifetime: Added to the issue time to compute
            ``client_secret_expires_at``.
        :param baseurl: Defaults to ``name`` when empty.
        :param hostname: Defaults to ``socket.gethostname()`` when empty.
        :param lifetime_policy: ``{token_type: {response/grant type: lifetime}}``;
            a built-in default is used when ``None``.
        :param kwargs: ``server_cls`` is forwarded to the base class and
            ``scopes`` extends ``self.scopes``; other keys are ignored.
        :raises UnknownAuthnMethod: If ``authn_at_registration`` is set but is
            not a key of ``client_authn_methods``.
        """
        if not name.endswith("/"):
            name += "/"

        # Only "server_cls" is forwarded to the base class.
        try:
            args = {"server_cls": kwargs["server_cls"]}
        except KeyError:
            args = {}

        super().__init__(
            name,
            sdb,
            cdb,
            authn_broker,
            authz,
            client_authn,
            symkey,
            urlmap,
            iv,
            default_scope,
            ca_bundle,
            message_factory=message_factory,
            **args
        )

        self.endp.extend(
            [
                RegistrationEndpoint,
                ClientInfoEndpoint,
                RevocationEndpoint,
                IntrospectionEndpoint,
            ]
        )

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # Checking for None first reports the misconfiguration as
            # UnknownAuthnMethod instead of failing with a TypeError.
            if (
                client_authn_methods is None
                or authn_at_registration not in client_authn_methods
            ):
                raise UnknownAuthnMethod(authn_at_registration)
        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        self.scopes.extend(kwargs.get("scopes", []))
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}  # type: Dict[str, Dict[str, str]]
        self.config = config or {}
        self.behavior = behavior or {}
        # Filled per client by set_token_policy().
        self.token_policy = {
            "access_token": {},
            "refresh_token": {},
        }  # type: Dict[str, Dict[str, str]]
        if lifetime_policy is None:
            self.lifetime_policy = {
                "access_token": {
                    "code": 600,
                    "token": 120,
                    "implicit": 120,
                    "authorization_code": 600,
                    "client_credentials": 600,
                    "password": 600,
                },
                "refresh_token": {
                    "code": 3600,
                    "token": 3600,
                    "implicit": 3600,
                    "authorization_code": 3600,
                    "client_credentials": 3600,
                    "password": 3600,
                },
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(
            self.baseurl, self.token_policy, keyjar=self.keyjar
        )

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a ``(base, query)`` tuple; query is "" when absent."""
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            if query:
                tup.append((base, query))
            else:
                tup.append((base, ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        """Join ``(base, query)`` tuples back into URI strings."""
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys and add its secret as symmetric keys.

        :param request: Registration request handed to ``self.keyjar.load_keys``.
        :param client_id: Identifier the keys are stored under in the keyjar.
        :param client_secret: Added as "oct" keys for "ver" and "sig" use.
        :return: A 400 error Response if loading the keys failed (the secret
            is then not added), otherwise None.
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                n_keys = len(self.keyjar[client_id])
                msg = "Found {} keys for client_id={}"
                logger.debug(msg.format(n_keys, client_id))
            except KeyError:
                pass
        except Exception as err:
            msg = "Failed to load client keys: {}"
            logger.error(msg.format(sanitize(request.to_dict())))
            logger.error("%s", err)
            error = ClientRegistrationError(
                error="invalid_configuration_parameter", error_description="%s" % err
            )
            # BadRequest carries the 400 status itself, instead of relying on a
            # "400 Bad Request" string passed as status_code.
            return BadRequest(error.to_json(), content="application/json")

        # Add the client_secret as a symmetric key to the keyjar
        _kc = KeyBundle(
            [
                {"kty": "oct", "key": client_secret, "use": "ver"},
                {"kty": "oct", "key": client_secret, "use": "sig"},
            ]
        )
        try:
            self.keyjar[client_id].append(_kc)
        except KeyError:
            self.keyjar[client_id] = [_kc]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Apply each named restriction to the client info.

        :param cinfo: Client information dict.
        :param restrictions: Mapping of restriction name (resolved through
            ``restrict.factory``) to its argument.
        :raises RestrictionError: With the first non-empty restriction result.
        """
        for fname, arg in restrictions.items():
            func = restrict.factory(fname)
            res = func(arg, cinfo)
            if res:
                raise RestrictionError(res)

    def set_token_policy(self, cid, cinfo):
        """
        Record the token lifetimes that apply to client ``cid``.

        For each token type, the ``self.lifetime_policy`` entries matching the
        values under ``cinfo["response_type"]``/``cinfo["grant_type"]`` are
        stored in ``self.token_policy[token_type][cid]``.
        """
        # NOTE(review): registration metadata is usually named
        # "response_types"/"grant_types" (plural); confirm the singular keys
        # read here are intended.
        for ttyp in ["access_token", "refresh_token"]:
            pol = {}
            for rgtyp in ["response_type", "grant_type"]:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    pass
                else:
                    for typ in rtyp:
                        try:
                            pol[typ] = self.lifetime_policy[ttyp][typ]
                        except KeyError:
                            pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Create new client based on request and restrictions.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client
        :return: The client_id
        :raises CapabilitiesMisMatch: If the request asks for unsupported features.
        :raises RestrictionError: If a configured registration behavior rejects it.
        """
        _cinfo = request.to_dict()

        self.match_client_request(_cinfo)

        # create new id and secret
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = utc_time_sans_frac()
        _cinfo["client_secret_expires_at"] = utc_time_sans_frac() + self.secret_lifetime

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % (
                self.name,
                self.client_info_url,
                ClientInfoEndpoint.etype,
                _id,
            )

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        # NOTE(review): load_keys returns an error Response on failure, which
        # is discarded here, so registration continues without the client keys.
        self.load_keys(request, _id, _cinfo["client_secret"])

        try:
            _behav = self.behavior["client_registration"]
        except KeyError:
            pass
        else:
            self.verify_correct(_cinfo, _behav)

        self.set_token_policy(_id, _cinfo)
        self.cdb[_id] = _cinfo

        return _id

    def match_client_request(self, request):
        """
        Check the client's preferences against this provider's capabilities.

        "response_types" values are compared as unordered sets of space
        separated parts; other preferences must be a supported value (string)
        or a subset of the supported values (list).

        :raises CapabilitiesMisMatch: On the first unsupported preference.
        """
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            if _pref in request:
                if _pref == "response_types":
                    for val in request[_pref]:
                        match = False
                        p = set(val.split(" "))
                        for cv in self.capabilities[_prov]:
                            if p == set(cv.split(" ")):
                                match = True
                                break
                        if not match:
                            raise CapabilitiesMisMatch("Not allowed {}".format(_pref))
                else:
                    if isinstance(request[_pref], str):
                        if request[_pref] not in self.capabilities[_prov]:
                            raise CapabilitiesMisMatch("Not allowed {}".format(_pref))
                    else:
                        if not set(request[_pref]).issubset(
                            set(self.capabilities[_prov])
                        ):
                            raise CapabilitiesMisMatch("Not allowed {}".format(_pref))

    def client_info(self, client_id):
        """
        Return the stored information about a client as a JSON Response.

        :return: A BadRequest when ``valid_client_info`` rejects the stored
            info, otherwise the client info with redirect_uris re-joined.
        """
        _cinfo = self.cdb[client_id].copy()
        if not valid_client_info(_cinfo):
            err = ErrorResponse(
                error="invalid_client", error_description="Invalid client secret"
            )
            return BadRequest(err.to_json(), content="application/json")

        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        msg = self.server.message_factory.get_response_type("update_endpoint")(**_cinfo)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace a client's metadata with the content of an update request.

        client_id and client_secret may be sent but must be unchanged. Keys
        missing from the request are dropped, except the issue/expiry times
        and registration access token/URI.

        :raises ModificationForbidden: If client_id or client_secret differ.
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                "client_id_issued_at",
                "client_secret_expires_at",
                "registration_access_token",
                "registration_client_uri",
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Verify the client based on credentials.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: Client to verify; derived from the request and the
            Authorization header when empty.
        :return: Whatever the chosen method's ``verify()`` returns.
        :raises UnSupported: If ``authn_method`` is not configured.
        """
        if not client_id:
            client_id = get_client_id(self.cdb, areq, environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """Return client restrictions derived from a software statement (none here)."""
        return {}

    def registration_endpoint(self, **kwargs):
        """
        Perform dynamic client registration.

        :param request: The request
        :param authn: Client authentication information
        :param kwargs: extra keyword arguments
        :return: A Response instance
        """
        _request = self.server.message_factory.get_request_type(
            "registration_endpoint"
        )().deserialize(kwargs["request"], "json")
        try:
            _request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            msg = ClientRegistrationError(
                error="invalid_redirect_uri", error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")
        except (MissingPage, VerificationError) as err:
            msg = ClientRegistrationError(
                error="invalid_client_metadata", error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(
                    kwargs["environ"], _request, self.authn_at_registration
                )
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}  # type: ignore
        if "parsed_software_statement" in _request:
            for ss in _request["parsed_software_statement"]:
                client_restrictions.update(self.consume_software_statement(ss))
            del _request["software_statement"]
            del _request["parsed_software_statement"]

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except CapabilitiesMisMatch as err:
            msg = ClientRegistrationError(
                error="invalid_client_metadata", error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")
        except RestrictionError as err:
            msg = ClientRegistrationError(
                error="invalid_client_metadata", error_description="%s" % err
            )
            return BadRequest(msg.to_json(), content="application/json")

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different HTTP methods.

        :param method: HTTP method used for the request
        :param kwargs: keyword arguments
        :return: A Response instance
        """
        _query = compact(parse_qs(kwargs["query"]))
        try:
            _id = _query["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(
                kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id
            )
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = self.server.message_factory.get_request_type(
                    "update_endpoint"
                )().from_json(kwargs["request"])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(
                    error="invalid_redirect_uri", error_description="%s" % err
                )
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(
                    error="invalid_client_metadata", error_description="%s" % err
                )
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

    @staticmethod
    def verify_code_challenge(
        code_verifier, code_challenge, code_challenge_method="S256"
    ):
        """
        Verify a PKCE (RFC7636) code challenge.

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into ``CC_METHOD`` selecting the hash.
        :return: True when the challenge matches, otherwise a 401 error Response.
        """
        _h = CC_METHOD[code_challenge_method](code_verifier.encode("ascii")).digest()
        _cc = b64e(_h)
        if _cc.decode("ascii") != code_challenge:
            logger.error("PCKE Code Challenge check failed")
            err = TokenErrorResponse(
                error="invalid_request", error_description="PCKE check failed"
            )
            return Unauthorized(err.to_json(), content="application/json")
        return True

    def do_access_token_response(self, access_token, atinfo, state, refresh_token=None):
        """Build the token endpoint response message from token info."""
        # NOTE(review): expires_in is taken from atinfo["exp"]; if that is an
        # absolute timestamp rather than a lifetime, confirm this is intended.
        _tinfo = {
            "access_token": access_token,
            "expires_in": atinfo["exp"],
            "token_type": "bearer",
            "state": state,
        }
        try:
            _tinfo["scope"] = atinfo["scope"]
        except KeyError:
            pass

        if refresh_token:
            _tinfo["refresh_token"] = refresh_token

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        return atr_class(**by_schema(atr_class, **_tinfo))

    def code_grant_type(self, areq):
        """
        Exchange an authorization code for tokens.

        Checks, in order: the code is known, the PKCE verifier (when sent),
        the state (when sent), the scope, and that redirect_uri equals the
        one from the authorization request.

        :param areq: The token request.
        :return: A Response with the token response or an error.
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Unknown access grant"
            )
            return Unauthorized(err.to_json(), content="application/json")

        authzreq = json.loads(_info["authzreq"])
        if "code_verifier" in areq:
            try:
                _method = authzreq["code_challenge_method"]
            except KeyError:
                _method = "S256"

            resp = self.verify_code_challenge(
                areq["code_verifier"], authzreq["code_challenge"], _method
            )
            # verify_code_challenge returns True on success; only an error
            # Response may short-circuit the grant.
            if isinstance(resp, Response):
                return resp

        if "state" in areq:
            if self.sdb[areq["code"]]["state"] != areq["state"]:
                logger.error("State value mismatch")
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one; a missing one
        # is a mismatch rather than a KeyError.
        if "redirect_uri" in _info and (
            "redirect_uri" not in areq or areq["redirect_uri"] != _info["redirect_uri"]
        ):
            logger.error("Redirect_uri mismatch")
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        issue_refresh = False
        if "scope" in authzreq and "offline_access" in authzreq["scope"]:
            if authzreq["response_type"] == "code":
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(
                areq["code"], issue_refresh=issue_refresh
            )
        except AccessCodeUsed:
            err = TokenErrorResponse(
                error="invalid_grant", error_description="Access grant used"
            )
            return Unauthorized(err.to_json(), content="application/json")

        logger.debug("_tinfo: %s" % _tinfo)

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    def client_credentials_grant_type(self, areq):
        """Issue an access token (and a refresh token when allowed) to the client."""
        _at = self.token_handler.get_access_token(
            areq["client_id"], scope=areq["scope"], grant_type="client_credentials"
        )
        _info = self.token_handler.token_factory.get_info(_at)
        try:
            _rt = self.token_handler.get_refresh_token(
                self.baseurl, _info["access_token"], "client_credentials"
            )
        except NotAllowed:
            atr = self.do_access_token_response(_at, _info, areq["state"])
        else:
            atr = self.do_access_token_response(_at, _info, areq["state"], _rt)

        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """
        Token authorization using Resource owner password credentials.

        RFC6749 section 4.3
        """
        # `Any` comparison tries a first broker, so we either hit an IndexError or get a method
        try:
            authn, authn_class_ref = self.pick_auth(areq, "any")
        except IndexError:
            err = TokenErrorResponse(error="invalid_grant")
            return Unauthorized(err.to_json(), content="application/json")
        identity, _ts = authn.authenticated_as(
            username=areq["username"], password=areq["password"]
        )
        if identity is None:
            err = TokenErrorResponse(error="invalid_grant")
            return Unauthorized(err.to_json(), content="application/json")
        # We are returning a token
        areq["response_type"] = ["token"]
        authn_event = AuthnEvent(
            identity["uid"],
            identity.get("salt", ""),
            authn_info=authn_class_ref,
            time_stamp=_ts,
        )
        sid = self.setup_session(areq, authn_event, self.cdb[areq["client_id"]])
        _at = self.sdb.upgrade_to_token(self.sdb[sid]["code"], issue_refresh=True)
        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **_at))
        return Response(
            atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS
        )

    def refresh_token_grant_type(self, areq):
        """Issue a new access token through the token handler."""
        # NOTE(review): RFC 6749 section 6 refresh requests carry
        # "refresh_token", but this reads areq["access_token"]; confirm.
        at = self.token_handler.refresh_access_token(
            self.baseurl, areq["access_token"], "refresh_token"
        )

        atr_class = self.server.message_factory.get_response_type("token_endpoint")
        atr = atr_class(**by_schema(atr_class, **at))
        return Response(atr.to_json(), content="application/json")

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        """
        Decide whether ``client_id`` may act on a token at ``endpoint``.

        Allowed when the client equals the token's "azr". Otherwise, for
        revocation the client must be the sole audience; for introspection
        it must be one of the audiences. Tokens without "aud" are refused.
        """
        # NOTE(review): an earlier comment mentioned "azp", but "azr" is the
        # claim checked; confirm which one the token factory emits.
        if "azr" in token_info and client_id == token_info["azr"]:
            return True
        if "aud" not in token_info:
            return False
        if endpoint == "revocation_endpoint":
            return token_info["aud"] == [client_id]
        # has to be introspection endpoint
        return client_id in token_info["aud"]

    def get_token_info(self, authn, req, endpoint):
        """
        Parse token for information.

        :param authn: Client authentication information.
        :param req: Request holding "token" and optionally "token_type_hint".
        :param endpoint: Name of the endpoint asking, used for access control.
        :return: ``(client_id, token_type, info)`` on success, otherwise an
            error Response (401, inactive-token or 400).
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            logger.error(err)
            error = TokenErrorResponse(
                error="unauthorized_client", error_description="%s" % err
            )
            return Unauthorized(error.to_json(), content="application/json")

        logger.debug("{}: {} requesting {}".format(endpoint, client_id, req.to_dict()))

        try:
            token_type = req["token_type_hint"]
        except KeyError:
            # No hint: try access token first, then refresh token.
            try:
                _info = self.sdb.token_factory["access_token"].get_info(req["token"])
            except Exception:
                try:
                    _info = self.sdb.token_factory["refresh_token"].get_info(
                        req["token"]
                    )
                except Exception:
                    return self._return_inactive()
                else:
                    token_type = "refresh_token"
            else:
                token_type = "access_token"
        else:
            try:
                _info = self.sdb.token_factory[token_type].get_info(req["token"])
            except Exception:
                return self._return_inactive()

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    def _return_inactive(self):
        """Return an introspection response stating the token is not active."""
        ir = self.server.message_factory.get_response_type("introspection_endpoint")(
            active=False
        )
        return Response(ir.to_json(), content="application/json")

    def revocation_endpoint(self, authn="", request=None, **kwargs):
        """
        Implement RFC7009 allows a client to invalidate an access or refresh token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs:
        :return:
        """
        trr = self.server.message_factory.get_request_type(
            "revocation_endpoint"
        )().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, trr, "revocation_endpoint")

        if isinstance(resp, Response):
            return resp
        else:
            client_id, token_type, _info = resp

        logger.info("{} token revocation: {}".format(client_id, trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr["token"])
        except KeyError:
            return BadRequest()
        else:
            return Response("OK")

    def introspection_endpoint(self, authn="", request=None, **kwargs):
        """
        Implement RFC7662.

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs:
        :return:
        """
        tir = self.server.message_factory.get_request_type(
            "introspection_endpoint"
        )().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, tir, "introspection_endpoint")

        if isinstance(resp, Response):
            return resp
        else:
            client_id, token_type, _info = resp

        logger.info("{} token introspection: {}".format(client_id, tir.to_dict()))

        ir = self.server.message_factory.get_response_type("introspection_endpoint")(
            active=self.sdb.token_factory[token_type].is_valid(_info), **_info.to_dict()
        )

        ir.weed()

        return Response(ir.to_json(), content="application/json")
def http_request(url, **kwargs):
    """Issue an HTTP GET for ``url``; ``kwargs`` are passed through to ``request``."""
    return request("GET", url, **kwargs)


info = {"registration_endpoint": "https://openidconnect.info/connect/register",
        "userinfo_endpoint": "https://openidconnect.info/connect/userinfo",
        "token_endpoint_auth_types_supported": ["client_secret_basic",
                                                "client_secret_post",
                                                "client_secret_jwt",
                                                "private_key_jwt"],
        "jwk_url": "https://openidconnect.info/jwk/jwk.json",
        "userinfo_algs_supported": ["HS256", "HS384", "HS512", "RS256",
                                    "RS384", "RS512"],
        "user_id_types_supported": ["pairwise", "public"],
        "scopes_supported": ["openid", "profile", "email", "address", "phone"],
        "token_endpoint": "https://openidconnect.info/connect/token",
        "id_token_algs_supported": ["HS256", "HS384", "HS512", "RS256",
                                    "RS384", "RS512"],
        "version": "3.0",
        "token_endpoint_auth_algs_supported": ["RS256", "RS384", "RS512"],
        "request_object_algs_supported": ["HS256", "HS384", "HS512", "RS256",
                                          "RS384", "RS512"],
        "response_types_supported": ["code", "token", "id_token",
                                     "id_token token", "code token",
                                     "code id_token", "code id_token token"],
        "authorization_endpoint": "https://openidconnect.info/connect/authorize",
        "acrs_supported": ["1"],
        "check_id_endpoint": "https://openidconnect.info/connect/check_session",
        "x509_url": "https://openidconnect.info/x509/cert.pem",
        "issuer": "https://openidconnect.info"}

# Build a provider configuration from the static info and load its keys.
pi = ProviderConfigurationResponse(**info)

kj = KeyJar()
kj.load_keys(pi, pi["issuer"])
keys = kj.get("ver", "rsa", "https://openidconnect.info")
# Function-call form: the Python 2 ``print keys`` statement is a
# SyntaxError under Python 3.
print(keys)
def test_provider():
    """Loading the keys referenced by PROVIDER_INFO fills the KeyJar for the issuer."""
    issuer = "https://connect-op.heroku.com"
    key_jar = KeyJar()
    provider_conf = ProviderConfigurationResponse().from_dict(PROVIDER_INFO)
    key_jar.load_keys(provider_conf, issuer)
    assert key_jar[issuer]
class Provider(provider.Provider):
    """
    Provider offering dynamic client registration.

    Registration stores client info in ``self.cdb`` under both the client_id
    and the issued registration access token.
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey, urlmap=None, keyjar=None, hostname="",
                 configuration=None):
        """
        Set up the provider.

        :param keyjar: Key store to use; a new KeyJar when falsy.
        :param hostname: Defaults to the result of ``socket.gethostname()``.
        :param configuration: Accepted but not used here.
        """
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey=symkey, urlmap=urlmap)

        self.baseurl = ""
        if keyjar:
            self.keyjar = keyjar
        else:
            self.keyjar = KeyJar()

        # gethostname must be called; storing the function itself left
        # self.hostname as a callable instead of a string.
        self.hostname = hostname or socket.gethostname()
        self.jwks_uri = []
        self.endpoints = [
            DynamicClientEndpoint,
            TokenEndpoint,
            AuthorizationEndpoint,
            UserEndpoint,
            ResourceSetRegistrationEndpoint,
            IntrospectionEndpoint,
            RPTEndpoint,
            PermissionRegistrationEndpoint
        ]

    def set_authn_broker(self, authn_broker):
        """
        Install the authentication broker.

        Every method in the broker gets ``srv`` set to this provider, and
        ``cookie_func`` is taken from the first method (None for an empty broker).
        """
        self.authn_broker = authn_broker
        for meth in self.authn_broker:
            meth.srv = self
        if authn_broker:
            self.cookie_func = authn_broker[0].create_cookie
        else:
            self.cookie_func = None

    @staticmethod
    def _verify_url(url, urlset):
        """Return True if ``url`` shares scheme and netloc with any ``(base, query)`` in urlset."""
        part = urllib.parse.urlparse(url)

        for reg, qp in urlset:
            _part = urllib.parse.urlparse(reg)
            if part.scheme == _part.scheme and part.netloc == _part.netloc:
                return True

        return False

    def do_client_registration(self, request, client_id, ignore=None):
        """
        Merge a registration request into the stored info for ``client_id``.

        redirect_uris are stored as ``(base, parse_qs(query) or query)``
        tuples. sector_identifier_uri, policy_url and logo_url are validated,
        and the client's keys are loaded into ``self.keyjar``.

        :param request: The verified registration request.
        :param client_id: Client whose ``self.cdb`` entry is the starting point.
        :param ignore: Request keys not copied verbatim into the client info.
        :return: The updated client info dict, or an error Response.
        """
        if ignore is None:
            ignore = []

        _cinfo = self.cdb[client_id].copy()
        logger.debug("_cinfo: %s" % _cinfo)

        for key, val in list(request.items()):
            if key not in ignore:
                _cinfo[key] = val

        if "redirect_uris" in request:
            ruri = []
            for uri in request["redirect_uris"]:
                if urllib.parse.urlparse(uri).fragment:
                    err = ClientRegistrationErrorResponse(
                        error="invalid_configuration_parameter",
                        error_description="redirect_uri contains fragment")
                    return Response(err.to_json(),
                                    content="application/json",
                                    status="400 Bad Request")
                # Same split as the deprecated urllib.parse.splitquery:
                # on the last "?", query is None when there is none.
                base, sep, query = uri.rpartition("?")
                if not sep:
                    base, query = uri, None
                if query:
                    ruri.append((base, urllib.parse.parse_qs(query)))
                else:
                    ruri.append((base, query))
            _cinfo["redirect_uris"] = ruri

        if "sector_identifier_uri" in request:
            si_url = request["sector_identifier_uri"]
            try:
                res = self.server.http_request(si_url)
            except ConnectionError as err:
                logger.error("%s" % err)
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            if not res:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Couldn't open sector_identifier_uri")

            logger.debug("sector_identifier_uri => %s" % res.text)

            try:
                si_redirects = json.loads(res.text)
            except ValueError:
                return self._error_response(
                    "invalid_configuration_parameter",
                    descr="Error deserializing sector_identifier_uri content")

            if "redirect_uris" in request:
                logger.debug("redirect_uris: %s" % request["redirect_uris"])
                for uri in request["redirect_uris"]:
                    # Explicit check: an assert would be stripped under -O.
                    if uri not in si_redirects:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr="redirect_uri missing from sector_identifiers"
                        )

            _cinfo["si_redirects"] = si_redirects
            _cinfo["sector_id"] = si_url
        elif "redirect_uris" in request:
            if len(request["redirect_uris"]) > 1:
                # check that the hostnames are the same
                host = ""
                for url in request["redirect_uris"]:
                    part = urllib.parse.urlparse(url)
                    _host = part.netloc.split(":")[0]
                    if not host:
                        host = _host
                    elif host != _host:
                        return self._error_response(
                            "invalid_configuration_parameter",
                            descr="'sector_identifier_uri' must be registered")

        for item in ["policy_url", "logo_url"]:
            if item in request:
                # Without registered redirect_uris no URL can match, so this
                # returns the error below instead of raising KeyError.
                if self._verify_url(request[item],
                                    _cinfo.get("redirect_uris", [])):
                    _cinfo[item] = request[item]
                else:
                    return self._error_response(
                        "invalid_configuration_parameter",
                        descr="%s pointed to illegal URL" % item)

        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug("keys for %s: [%s]" % (client_id, ",".join(
                    ["%s" % x for x in self.keyjar[client_id]])))
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            error = ClientRegistrationErrorResponse(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(error.to_json(), content="application/json",
                            status="400 Bad Request")

        return _cinfo

    @staticmethod
    def comb_redirect_uris(args):
        """
        Turn stored ``(base, query)`` redirect_uris back into URL strings, in place.

        ``query`` may be a query string or the ``parse_qs`` dict stored by
        ``do_client_registration``; a dict is re-encoded with ``urlencode``
        instead of being formatted with ``%s``.
        """
        if "redirect_uris" not in args:
            return

        val = []
        for base, query in args["redirect_uris"]:
            if query:
                if isinstance(query, dict):
                    query = urllib.parse.urlencode(query, doseq=True)
                val.append("%s?%s" % (base, query))
            else:
                val.append(base)

        args["redirect_uris"] = val

    #noinspection PyUnusedLocal
    def l_registration_endpoint(self, request, authn=None, **kwargs):
        """
        Register a new client from a JSON registration request.

        :param request: JSON-serialized RegistrationRequest.
        :param authn: Not used here.
        :return: A Response with the RegistrationResponse, or an error.
        """
        _log_info = logger.info

        logger.debug("@registration_endpoint")

        request = RegistrationRequest().deserialize(request, "json")

        _log_info("registration_request:%s" % request.to_dict())

        try:
            request.verify()
        except MessageException as err:
            if "type" not in request:
                return self._error(error="invalid_type",
                                   descr="%s" % err)
            else:
                return self._error(error="invalid_configuration_parameter",
                                   descr="%s" % err)

        _keyjar = self.server.keyjar

        # create new id och secret
        client_id = rndstr(12)
        while client_id in self.cdb:
            client_id = rndstr(12)

        client_secret = secret(self.seed, client_id)

        _rat = rndstr(32)
        reg_enp = ""
        for endp in self.endpoints:
            # self.endpoints holds endpoint classes, which isinstance() never
            # matches; accept both classes and instances.
            _cls = endp if isinstance(endp, type) else type(endp)
            if issubclass(_cls, DynamicClientEndpoint):
                reg_enp = "%s%s" % (self.baseurl, endp.etype)

        self.cdb[client_id] = {
            "client_id": client_id,
            "client_secret": client_secret,
            "registration_access_token": _rat,
            "registration_client_uri": "%s?client_id=%s" % (reg_enp,
                                                            client_id),
            "client_secret_expires_at": utc_time_sans_frac() + 86400,
            "client_id_issued_at": utc_time_sans_frac()
        }

        # Lets read_registration map the access token back to the client.
        self.cdb[_rat] = client_id

        _cinfo = self.do_client_registration(
            request, client_id,
            ignore=["redirect_uris", "policy_url", "logo_url"])
        if isinstance(_cinfo, Response):
            return _cinfo

        args = {k: v for k, v in _cinfo.items()
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        self.keyjar.load_keys(request, client_id)

        # Add the key to the keyjar
        if client_secret:
            _kc = KeyBundle([{"kty": "oct", "key": client_secret,
                              "use": "ver"},
                             {"kty": "oct", "key": client_secret,
                              "use": "sig"}])
            try:
                _keyjar[client_id].append(_kc)
            except KeyError:
                _keyjar[client_id] = [_kc]

        self.cdb[client_id] = _cinfo
        _log_info("Client info: %s" % _cinfo)

        logger.debug("registration_response: %s" % response.to_dict())

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    def registration_endpoint(self, request, authn=None, **kwargs):
        """Delegate to ``l_registration_endpoint``."""
        return self.l_registration_endpoint(request, authn, **kwargs)

    def read_registration(self, authn, request, **kwargs):
        """
        Read all information this server has on a client.

        Authorization is done by using the access token that was return as
        part of the client registration result.

        :param authn: The Authorization HTTP header
        :param request: The query part of the URL
        :param kwargs: Any other arguments
        :return: A Response with the client's RegistrationResponse.
        :raises AssertionError: If the header is not a Bearer token or the
            query's client_id does not belong to the token.
        """
        logger.debug("authn: %s, request: %s" % (authn, request))

        # verify the access token, has to be key into the client information
        # database. Explicit raises instead of ``assert`` keep these checks
        # under ``python -O`` while raising the same exception type.
        if not authn.startswith("Bearer "):
            raise AssertionError("Authorization is not a Bearer token")
        token = authn[len("Bearer "):]

        client_id = self.cdb[token]

        # extra check
        _info = urllib.parse.parse_qs(request)
        if _info["client_id"][0] != client_id:
            raise AssertionError("client_id does not match the access token")

        logger.debug("Client '%s' reads client info" % client_id)
        args = {k: v for k, v in self.cdb[client_id].items()
                if k in RegistrationResponse.c_param}

        self.comb_redirect_uris(args)
        response = RegistrationResponse(**args)

        return Response(response.to_json(), content="application/json",
                        headers=[("Cache-Control", "no-store")])

    #noinspection PyUnusedLocal
    def providerinfo_endpoint(self, handle="", **kwargs):
        """
        Return the provider configuration (``self.conf_info``) as JSON.

        :param handle: Optional ``(key, timestamp)``; a cookie is added when
            the key starts and ends with ``STR``.
        :return: A Response; on any exception, a Response with the traceback.
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.conf_info
            _log_info("provider_info_response: %s" % (_response.to_dict(), ))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if handle:
                (key, timestamp) = handle
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            # NOTE(review): this sends the traceback to the caller with a
            # "html/text" content type; confirm both are intended.
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            resp = Response(message, content="html/text")

        return resp
class Client(oauth2.Client):
    """OAuth2 client with dynamic registration and client-info requests."""

    def __init__(self, client_id=None, ca_certs=None, client_authn_method=None,
                 keyjar=None, verify_ssl=True):
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)
        # Relaxed checks; an "issuer_mismatch" key disables the issuer
        # comparison in handle_provider_config.
        self.allow = {}
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint"
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Construct a registration request from ``request_args``."""
        if request_args is None:
            request_args = {}

        return self.construct_request(request, request_args, extra_args)

    def _do_op(self, request, body_type, method, request_args, extra_args,
               http_args, response_cls, **kwargs):
        """
        Build, send and parse one request.

        The HTTP arguments computed by ``request_info`` (e.g. headers) are
        merged with the caller's ``http_args``, the caller's keys taking
        precedence. This way the computed arguments are kept when the caller
        also passes some. The caller's dict is not mutated.
        """
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            merged = dict(ht_args or {})
            merged.update(http_args)
            http_args = merged

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Send a registration request and return the parsed response."""
        return self._do_op(request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Read the client's registration info and return the parsed response."""
        return self._do_op(request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Update the client's registration info and return the parsed response."""
        return self._do_op(request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Delete the client's registration and return the parsed response."""
        return self._do_op(request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: If the issuers differ (ignoring a trailing "/")
            and ``self.allow`` has no "issuer_mismatch" entry.
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Give the expected issuer the same trailing "/" as pcr's.
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            # Explicit comparison: an assert would be stripped under -O and
            # silently accept a foreign issuer.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and process the provider configuration for ``issuer``.

        302 responses are followed until a non-302 status.

        :raises PyoidcError: If no 200 response was obtained.
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """Keep the registration response and copy out the client credentials."""
        self.registration_response = reginfo
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """
        Parse a registration HTTP response and store it on success.

        :raises PyoidcError: If the status code is not 200.
        """
        if response.status_code == 200:
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_registration_info(resp)
        else:
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.get_json())

        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The stored ClientInfoResponse.
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :raises AuthzError: If the response is an ErrorResponse.
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s" % aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s" % aresp)

        return aresp
class Client(oauth2.Client):
    """
    OAuth2 client that adds dynamic client registration, token
    introspection, token revocation and PKCE support on top of
    ``oauth2.Client``.
    """

    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True,
                 config=None):
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl,
                               config=config)
        self.allow = {}
        # Copy before extending: the inherited mapping may be a module-level
        # default, and updating it in place would leak these endpoints into
        # every other client instance.
        self.request2endpoint = dict(self.request2endpoint)
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint",
            'TokenIntrospectionRequest': 'introspection_endpoint',
            'TokenRevocationRequest': 'revocation_endpoint'
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Build a client registration request message."""
        if request_args is None:
            request_args = {}
        return self.construct_request(request, request_args, extra_args)

    def construct_ClientUpdateRequest(self, request=ClientUpdateRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Build a client configuration (read/update/delete) request."""
        if request_args is None:
            request_args = {}
        return self.construct_request(request, request_args, extra_args)

    def _token_interaction_setup(self, request_args=None, **kwargs):
        """
        Prepare request arguments for introspection/revocation requests.

        If no ``token`` is supplied in request_args, a token is looked up
        with ``get_token(**kwargs)`` and the attribute named by
        ``token_type_hint`` (default ``'access_token'``) is sent. In that
        case request_args is replaced entirely.

        :param request_args: Explicit request arguments, may be None.
        :param kwargs: Passed on to ``get_token``; may carry
            ``token_type_hint``.
        :return: request arguments with ``client_id`` filled in.
        """
        if request_args is None or 'token' not in request_args:
            token = self.get_token(**kwargs)
            try:
                _token_type_hint = kwargs['token_type_hint']
            except KeyError:
                _token_type_hint = 'access_token'

            request_args = {'token_type_hint': _token_type_hint,
                            'token': getattr(token, _token_type_hint)}

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return request_args

    def construct_TokenIntrospectionRequest(self,
                                            request=TokenIntrospectionRequest,
                                            request_args=None,
                                            extra_args=None, **kwargs):
        """Build a token introspection request message."""
        request_args = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, request_args, extra_args)

    def construct_TokenRevocationRequest(self,
                                         request=TokenRevocationRequest,
                                         request_args=None, extra_args=None,
                                         **kwargs):
        """Build a token revocation request message."""
        request_args = self._token_interaction_setup(request_args, **kwargs)
        return self.construct_request(request, request_args, extra_args)

    def do_op(self, request, body_type='', method='GET', request_args=None,
              extra_args=None, http_args=None, response_cls=None, **kwargs):
        """
        Construct, send and parse a request.

        :param request: Request message class.
        :param body_type: Expected format of the response body.
        :param method: HTTP method.
        :param request_args: Arguments for the request message.
        :param extra_args: Extra arguments added to the request message.
        :param http_args: Extra HTTP arguments; merged with the ones derived
            from the request (e.g. authentication headers).
        :param response_cls: Response message class, or None.
        :return: Whatever ``request_and_return`` returns.
        """
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            # Merge the request-derived HTTP arguments into the caller's;
            # previously the dict was updated with itself, dropping ht_args.
            http_args.update(ht_args)

        resp = self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)
        return resp

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET", request_args=None,
                               extra_args=None, http_args=None,
                               response_cls=ClientInfoResponse, **kwargs):
        """Send a client registration request."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET", request_args=None,
                               extra_args=None, http_args=None,
                               response_cls=ClientInfoResponse, **kwargs):
        """Read the client's registration information."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse, **kwargs):
        """Update the client's registration information."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse, **kwargs):
        """Delete the client's registration."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def do_token_introspection(self, request=TokenIntrospectionRequest,
                               body_type="json", method="POST",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=TokenIntrospectionResponse,
                               **kwargs):
        """Send a token introspection request."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def do_token_revocation(self, request=TokenRevocationRequest,
                            body_type="", method="POST", request_args=None,
                            extra_args=None, http_args=None,
                            response_cls=None, **kwargs):
        """Send a token revocation request."""
        return self.do_op(request=request, body_type=body_type, method=method,
                          request_args=request_args, extra_args=extra_args,
                          http_args=http_args, response_cls=response_cls,
                          **kwargs)

    def add_code_challenge(self):
        """
        Create a PKCE (RFC 7636) code challenge.

        Length and method are read from ``self.config['code_challenge']``
        (keys ``length`` and ``method``), defaulting to 64 and ``'S256'``.

        :return: A tuple (request arguments holding ``code_challenge`` and
            ``code_challenge_method``, the plain code verifier).
        :raises Unsupported: If the method is not a key of CC_METHOD.
        """
        try:
            cv_len = self.config['code_challenge']['length']
        except (KeyError, TypeError):
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode('ascii')

        try:
            _method = self.config['code_challenge']['method']
        except (KeyError, TypeError):
            _method = 'S256'

        try:
            _hash_func = CC_METHOD[_method]
        except KeyError:
            raise Unsupported(
                'PKCE Transformation method:{}'.format(_method))

        # RFC 7636 4.2: the challenge is BASE64URL(digest bytes), not the
        # base64 of the hex representation of the digest.
        _h = _hash_func(_cv).digest()
        code_challenge = b64e(_h).decode('ascii')

        # TODO store code_verifier
        return {"code_challenge": code_challenge,
                "code_challenge_method": _method}, code_verifier

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """
        Send an authorization request, adding a PKCE challenge when
        ``self.config['code_challenge']`` is set.

        :return: The result of ``oauth2.Client.do_authorization_request``.
        """
        if self.config and self.config.get('code_challenge'):
            _args, code_verifier = self.add_code_challenge()
            if request_args is None:
                request_args = {}
            request_args.update(_args)

        return oauth2.Client.do_authorization_request(
            self, request=request, state=state, body_type=body_type,
            method=method, request_args=request_args, extra_args=extra_args,
            http_args=http_args, response_cls=response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response
        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
        as attributes in self.
        :raises PyoidcError: If the advertised issuer does not match and
            ``self.allow`` has no ``issuer_mismatch`` entry.
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Give the expected issuer a trailing "/" exactly when the
            # advertised issuer has one, so only the rest is compared.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            # Explicit check instead of ``assert`` so it also runs under -O.
            if "issuer_mismatch" not in self.allow and _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and process the authorization server's configuration.

        302 redirects are followed until a non-302 status is returned.

        :raises PyoidcError: If no 200 response was obtained.
        :return: The parsed configuration response.
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """Keep the registration response and copy out the client creds."""
        self.registration_response = reginfo
        self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """
        Parse a registration HTTP response.

        :return: A ClientInfoResponse on success, or a verified
            ErrorResponse.
        :raises PyoidcError: If the error body cannot be verified/stored.
        """
        if response.status_code in SUCCESSFUL:
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_response(resp, response.text)
            self.store_registration_info(resp)
        else:
            resp = ErrorResponse().deserialize(response.text, "json")
            try:
                resp.verify()
                self.store_response(resp, response.text)
            except Exception:
                raise PyoidcError(
                    'Registration failed: {}'.format(response.text))

        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return:
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :raises AuthzError: If the response is an ErrorResponse.
        """
        aresp = self.parse_response(AuthorizationResponse, info=query,
                                    sformat="urlencoded", keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s" % aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s" % aresp)

        return aresp
class Client(PBase):
    """Base OAuth2 client: builds requests, sends them, parses responses."""

    _endpoints = ENDPOINTS

    def __init__(self, client_id=None, client_authn_method=None, keyjar=None,
                 verify_ssl=True, config=None, client_cert=None, timeout=5):
        """

        :param client_id: The client identifier
        :param client_authn_method: Methods that this client can use to
            authenticate itself. It's a dictionary with method names as
            keys and method classes as values.
        :param keyjar: The keyjar for this client.
        :param verify_ssl: Whether the SSL certificate should be verified.
        :param client_cert: A client certificate to use.
        :param timeout: Timeout for requests library. Can be specified either
            as a single integer or as a tuple of integers. For more details,
            refer to ``requests`` documentation.
        :return: Client instance
        """

        PBase.__init__(self, verify_ssl=verify_ssl, keyjar=keyjar,
                       client_cert=client_cert, timeout=timeout)

        self.client_id = client_id
        self.client_authn_method = client_authn_method

        self.nonce = None

        self.grant = {}
        self.state2nonce = {}
        # own endpoint
        self.redirect_uris = [None]

        # service endpoints
        self.authorization_endpoint = None
        self.token_endpoint = None
        self.token_revocation_endpoint = None

        # Per-instance copies so that code extending these mappings in place
        # does not modify the module-level defaults shared by all clients.
        self.request2endpoint = dict(REQUEST2ENDPOINT)
        self.response2error = dict(RESPONSE2ERROR)
        self.grant_class = Grant
        self.token_class = Token

        self.provider_info = {}
        self._c_secret = None
        self.kid = {"sig": {}, "enc": {}}
        self.authz_req = None

        # the OAuth issuer is the URL of the authorization server's
        # configuration information location
        self.config = config or {}
        try:
            self.issuer = self.config['issuer']
        except KeyError:
            self.issuer = ''
        self.allow = {}

    def store_response(self, clinst, text):
        """Hook called with each parsed response; does nothing here."""
        pass

    def get_client_secret(self):
        return self._c_secret

    def set_client_secret(self, val):
        """
        Set the client secret. A non-empty secret is also added to the
        keyjar as a symmetric key.
        """
        if not val:
            self._c_secret = ""
        else:
            self._c_secret = val
            # client uses it for signing
            # Server might also use it for signing which means the
            # client uses it for verifying server signatures
            if self.keyjar is None:
                self.keyjar = KeyJar()
            self.keyjar.add_symmetric("", str(val))

    client_secret = property(get_client_secret, set_client_secret)

    def reset(self):
        """Forget nonce, grants and endpoint/redirect configuration."""
        self.nonce = None

        self.grant = {}

        self.authorization_endpoint = None
        self.token_endpoint = None
        self.redirect_uris = None

    def grant_from_state(self, state):
        """Return the grant stored under ``state``, or None."""
        return self.grant.get(state)

    def _parse_args(self, request, **kwargs):
        """
        Fill in request parameters the caller did not give from attributes
        of the same name on self (``redirect_uri`` comes from the first
        entry of ``self.redirect_uris``). Falsy attribute values are skipped.
        """
        ar_args = kwargs.copy()

        for prop in request.c_param.keys():
            if prop in ar_args:
                continue
            if prop == "redirect_uri":
                # reset() sets redirect_uris to None; treat that like "unset"
                # instead of failing on None[0].
                _val = (getattr(self, "redirect_uris", None) or [None])[0]
            else:
                _val = getattr(self, prop, None)
            if _val:
                ar_args[prop] = _val

        return ar_args

    def _endpoint(self, endpoint, **kwargs):
        """
        Resolve an endpoint URL: first from kwargs[endpoint], otherwise from
        the attribute of that name on self.

        :raises MissingEndpoint: If neither yields a non-empty value.
        """
        try:
            uri = kwargs[endpoint]
            if uri:
                del kwargs[endpoint]
        except KeyError:
            uri = ""

        if not uri:
            try:
                uri = getattr(self, endpoint)
            except Exception:
                raise MissingEndpoint("No '%s' specified" % endpoint)

        if not uri:
            raise MissingEndpoint("No '%s' specified" % endpoint)

        return uri

    def get_grant(self, state, **kwargs):
        """
        :raises GrantError: If no grant is stored for ``state``.
        """
        try:
            return self.grant[state]
        except KeyError:
            raise GrantError("No grant found for state:'%s'" % state)

    def get_token(self, also_expired=False, **kwargs):
        """
        Return ``kwargs['token']`` if given, otherwise look up a token in the
        grant for ``kwargs['state']`` (by ``scope`` when given).

        :param also_expired: Return the token even if it is no longer valid.
        :raises TokenError: If no token is found or it has expired.
        """
        try:
            return kwargs["token"]
        except KeyError:
            grant = self.get_grant(**kwargs)

            try:
                token = grant.get_token(kwargs["scope"])
            except KeyError:
                token = grant.get_token("")
                if not token:
                    try:
                        token = self.grant[kwargs["state"]].get_token("")
                    except KeyError:
                        raise TokenError("No token found for scope")

        if token is None:
            raise TokenError("No suitable token found")

        if also_expired:
            return token
        elif token.is_valid():
            return token
        else:
            raise TokenError("Token has expired")

    def clean_tokens(self):
        """Clean replaced and invalid tokens."""
        for state in list(self.grant):
            grant = self.get_grant(state)
            # Iterate over a snapshot so that removing tokens from
            # grant.tokens (as delete_token presumably does) cannot make the
            # loop skip entries.
            for token in list(grant.tokens):
                if token.replaced or not token.is_valid():
                    grant.delete_token(token)

    def construct_request(self, request, request_args=None, extra_args=None):
        """Instantiate ``request`` from request_args, self and extra_args."""
        if request_args is None:
            request_args = {}

        kwargs = self._parse_args(request, **request_args)

        if extra_args:
            kwargs.update(extra_args)
        logger.debug("request: %s" % sanitize(request))
        return request(**kwargs)

    def construct_Message(self, request=Message, request_args=None,
                          extra_args=None, **kwargs):
        return self.construct_request(request, request_args, extra_args)

    def construct_AuthorizationRequest(self, request=AuthorizationRequest,
                                       request_args=None, extra_args=None,
                                       **kwargs):
        """
        Build an authorization request. A ``redirect_uri`` in request_args
        becomes the client's only redirect URI.
        """
        if request_args is not None:
            try:  # change default
                new = request_args["redirect_uri"]
                if new:
                    self.redirect_uris = [new]
            except KeyError:
                pass
        else:
            request_args = {}

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id

        return self.construct_request(request, request_args, extra_args)

    def construct_AccessTokenRequest(self, request=AccessTokenRequest,
                                     request_args=None, extra_args=None,
                                     **kwargs):
        """
        Build an access token request. Except for ROPC requests, the code is
        taken from the grant for ``kwargs['state']``.

        :raises GrantExpired: If that grant is no longer valid.
        """
        if request_args is None:
            request_args = {}
        if request is not ROPCAccessTokenRequest:
            grant = self.get_grant(**kwargs)

            if not grant.is_valid():
                raise GrantExpired("Authorization Code to old %s > %s" % (
                    utc_time_sans_frac(),
                    grant.grant_expiration_time))

            request_args["code"] = grant.code

        try:
            request_args['state'] = kwargs['state']
        except KeyError:
            pass

        if "grant_type" not in request_args:
            request_args["grant_type"] = "authorization_code"

        if "client_id" not in request_args:
            request_args["client_id"] = self.client_id
        elif not request_args["client_id"]:
            request_args["client_id"] = self.client_id
        return self.construct_request(request, request_args, extra_args)

    def construct_RefreshAccessTokenRequest(self,
                                            request=RefreshAccessTokenRequest,
                                            request_args=None,
                                            extra_args=None, **kwargs):
        """Build a refresh request using the (possibly expired) token."""
        if request_args is None:
            request_args = {}

        token = self.get_token(also_expired=True, **kwargs)

        request_args["refresh_token"] = token.refresh_token

        try:
            request_args["scope"] = token.scope
        except AttributeError:
            pass

        return self.construct_request(request, request_args, extra_args)

    def construct_ResourceRequest(self, request=ResourceRequest,
                                  request_args=None, extra_args=None,
                                  **kwargs):
        """Build a resource request carrying the current access token."""
        if request_args is None:
            request_args = {}

        token = self.get_token(**kwargs)

        request_args["access_token"] = token.access_token
        return self.construct_request(request, request_args, extra_args)

    def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
                     **kwargs):
        """
        Determine the target URL and body for a request message.

        :return: (uri, body, http args holding any headers, cis)
        """
        if "endpoint" in kwargs and kwargs["endpoint"]:
            uri = kwargs["endpoint"]
        else:
            uri = self._endpoint(self.request2endpoint[reqmsg.__name__],
                                 **(request_args or {}))

        uri, body, kwargs = get_or_post(uri, method, cis, **kwargs)
        try:
            h_args = {"headers": kwargs["headers"]}
        except KeyError:
            h_args = {}

        return uri, body, h_args, cis

    def request_info(self, request, method="POST", request_args=None,
                     extra_args=None, lax=False, **kwargs):
        """
        Construct a request message (via ``construct_<RequestName>`` when
        defined), apply client authentication and compute URL/body.

        :return: (uri, body, http args, request message)
        """
        if request_args is None:
            request_args = {}

        # Only the method lookup is guarded: an AttributeError raised
        # *inside* a construct_* method must not silently fall back to the
        # generic constructor.
        try:
            cls = getattr(self, "construct_%s" % request.__name__)
        except AttributeError:
            cis = self.construct_request(request, request_args, extra_args)
        else:
            cis = cls(request_args=request_args, extra_args=extra_args,
                      **kwargs)

        if self.events:
            self.events.store('Protocol request', cis)

        if 'nonce' in cis and 'state' in cis:
            self.state2nonce[cis['state']] = cis['nonce']

        cis.lax = lax

        if "authn_method" in kwargs:
            h_arg = self.init_authentication_method(cis,
                                                    request_args=request_args,
                                                    **kwargs)
        else:
            h_arg = None

        if h_arg:
            if "headers" in kwargs.keys():
                kwargs["headers"].update(h_arg["headers"])
            else:
                kwargs["headers"] = h_arg["headers"]

        return self.uri_and_body(request, cis, method, request_args,
                                 **kwargs)

    def authorization_request_info(self, request_args=None, extra_args=None,
                                   **kwargs):
        return self.request_info(AuthorizationRequest, "GET",
                                 request_args, extra_args, **kwargs)

    def get_urlinfo(self, info):
        """
        If ``info`` is a URL, return its query (or, if there is none, its
        fragment); otherwise return ``info`` unchanged.
        """
        if '?' in info or '#' in info:
            parts = urlparse(info)
            scheme, netloc, path, params, query, fragment = parts[:6]
            # either query of fragment
            if query:
                info = query
            else:
                info = fragment
        return info

    def parse_response(self, response, info="", sformat="json", state="",
                       **kwargs):
        """
        Parse a response

        :param response: Response type
        :param info: The response, can be either in a JSON or an urlencoded
            format
        :param sformat: Which serialization that was used
        :param state: The state
        :param kwargs: Extra key word arguments
        :return: The parsed and to some extend verified response
        :raises PyoidcError: If verification returns a falsy value.
        :raises ResponseError: If no usable response could be parsed.
        """

        _r2e = self.response2error

        if sformat == "urlencoded":
            info = self.get_urlinfo(info)

        resp = response().deserialize(info, sformat, **kwargs)
        msg = 'Initial response parsing => "{}"'
        logger.debug(msg.format(sanitize(resp.to_dict())))
        if self.events:
            self.events.store('Response', resp.to_dict())

        if "error" in resp and not isinstance(resp, ErrorResponse):
            # Re-parse with each error class registered for this response
            # type; the first one that verifies wins.
            resp = None
            try:
                errmsgs = _r2e[response.__name__]
            except KeyError:
                errmsgs = [ErrorResponse]

            for errmsg in errmsgs:
                try:
                    resp = errmsg().deserialize(info, sformat)
                    resp.verify()
                    break
                except Exception:
                    resp = None
        elif resp.only_extras():
            resp = None
        else:
            kwargs["client_id"] = self.client_id
            try:
                kwargs['iss'] = self.provider_info['issuer']
            except (KeyError, AttributeError):
                if self.issuer:
                    kwargs['iss'] = self.issuer

            if "key" not in kwargs and "keyjar" not in kwargs:
                kwargs["keyjar"] = self.keyjar

            logger.debug("Verify response with {}".format(sanitize(kwargs)))
            verf = resp.verify(**kwargs)

            if not verf:
                logger.error('Verification of the response failed')
                raise PyoidcError("Verification of the response failed")
            if resp.type() == "AuthorizationResponse" and "scope" not in resp:
                try:
                    resp["scope"] = kwargs["scope"]
                except KeyError:
                    pass

        if not resp:
            logger.error('Missing or faulty response')
            raise ResponseError("Missing or faulty response")

        self.store_response(resp, info)

        if resp.type() in ["AuthorizationResponse", "AccessTokenResponse"]:
            try:
                _state = resp["state"]
            except (AttributeError, KeyError):
                _state = ""

            if not _state:
                _state = state

            try:
                self.grant[_state].update(resp)
            except KeyError:
                self.grant[_state] = self.grant_class(resp=resp)

        return resp

    def init_authentication_method(self, cis, authn_method, request_args=None,
                                   http_args=None, **kwargs):
        """
        Apply the client authentication method named ``authn_method``.

        :return: The HTTP arguments produced by the method's ``construct``,
            or ``http_args`` (default {}) when authn_method is falsy.
        """
        if http_args is None:
            http_args = {}
        if request_args is None:
            request_args = {}

        if authn_method:
            return self.client_authn_method[authn_method](self).construct(
                cis, request_args, http_args, **kwargs)
        else:
            return http_args

    def parse_request_response(self, reqresp, response, body_type, state="",
                               **kwargs):
        """
        Interpret an HTTP response.

        :return: The redirect response for 302/303; the text for a 'txt'
            body; the parsed ``response`` message; an ErrorResponse found in
            the body; or the raw HTTP response.
        :raises ParseError: On status 500.
        :raises HttpError: On any other unexpected status.
        """
        if reqresp.status_code in SUCCESSFUL:
            body_type = verify_header(reqresp, body_type)
        elif reqresp.status_code in [302, 303]:  # redirect
            return reqresp
        elif reqresp.status_code == 500:
            logger.error("(%d) %s" % (reqresp.status_code,
                                      sanitize(reqresp.text)))
            raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
        elif reqresp.status_code in [400, 401]:
            # Expecting an error response; it is parsed below. (The former
            # ``issubclass(response, ErrorResponse)`` check had no effect
            # and raised TypeError when ``response`` was None.)
            pass
        else:
            logger.error("(%d) %s" % (reqresp.status_code,
                                      sanitize(reqresp.text)))
            raise HttpError("HTTP ERROR: %s [%s] on %s" % (
                reqresp.text, reqresp.status_code, reqresp.url))

        if response:
            if body_type == 'txt':
                # no meaning trying to parse unstructured text
                return reqresp.text
            return self.parse_response(response, reqresp.text, body_type,
                                       state, **kwargs)

        # could be an error response
        if reqresp.status_code in [200, 400, 401]:
            if body_type == 'txt':
                body_type = 'urlencoded'
            try:
                # The body lives in ``text``; ``reqresp.message`` does not
                # exist and made this branch always fail silently.
                err = ErrorResponse().deserialize(reqresp.text,
                                                  method=body_type)
                try:
                    err.verify()
                except PyoidcError:
                    pass
                else:
                    return err
            except Exception:
                pass

        return reqresp

    def request_and_return(self, url, response=None, method="GET", body=None,
                           body_type="json", state="", http_args=None,
                           **kwargs):
        """
        :param url: The URL to which the request should be sent
        :param response: Response type
        :param method: Which HTTP method to use
        :param body: A message body if any
        :param body_type: The format of the body of the return message
        :param http_args: Arguments for the HTTP client
        :return: A cls or ErrorResponse instance or the HTTP response
            instance if no response body was expected.
        """

        if http_args is None:
            http_args = {}

        resp = self.http_request(url, method, data=body, **http_args)

        if "keyjar" not in kwargs:
            kwargs["keyjar"] = self.keyjar

        return self.parse_request_response(resp, response, body_type, state,
                                           **kwargs)

    def do_authorization_request(self, request=AuthorizationRequest,
                                 state="", body_type="", method="GET",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=AuthorizationResponse,
                                 **kwargs):
        """Construct and send an authorization request."""
        if state:
            try:
                request_args["state"] = state
            except TypeError:
                request_args = {"state": state}

        kwargs['authn_endpoint'] = 'authorization'
        url, body, ht_args, csi = self.request_info(request, method,
                                                    request_args, extra_args,
                                                    **kwargs)

        # Remember the request per state when possible; request_args may be
        # None or lack "state", and authz_req may be None.
        try:
            self.authz_req[request_args["state"]] = csi
        except (TypeError, KeyError):
            pass

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        try:
            algs = kwargs["algs"]
        except KeyError:
            algs = {}

        resp = self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, algs=algs)

        if isinstance(resp, Message):
            # NOTE(review): resp.type() yields a class-name string, while the
            # RESPONSE2ERROR values appear to be sequences of message classes
            # (parse_response instantiates them), so this membership test
            # looks like it can never succeed -- confirm.
            if resp.type() in RESPONSE2ERROR["AuthorizationResponse"]:
                resp.state = csi.state

        return resp

    def do_access_token_request(self, request=AccessTokenRequest, scope="",
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """Construct and send an access token request."""
        kwargs['authn_endpoint'] = 'token'
        # method is default POST
        url, body, ht_args, csi = self.request_info(request, method=method,
                                                    request_args=request_args,
                                                    extra_args=extra_args,
                                                    scope=scope, state=state,
                                                    authn_method=authn_method,
                                                    **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        if self.events is not None:
            self.events.store('request_url', url)
            self.events.store('request_http_args', http_args)
            self.events.store('Request', body)

        logger.debug("<do_access_token> URL: %s, Body: %s" % (url,
                                                              sanitize(body)))
        logger.debug("<do_access_token> response_cls: %s" % response_cls)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, state=state,
                                       http_args=http_args, **kwargs)

    def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
                                state="", body_type="json", method="POST",
                                request_args=None, extra_args=None,
                                http_args=None,
                                response_cls=AccessTokenResponse,
                                authn_method="", **kwargs):
        """
        Refresh the token for ``state``; a token that is marked replaced
        afterwards is removed from its grant.
        """
        token = self.get_token(also_expired=True, state=state, **kwargs)
        kwargs['authn_endpoint'] = 'refresh'
        url, body, ht_args, csi = self.request_info(request, method=method,
                                                    request_args=request_args,
                                                    extra_args=extra_args,
                                                    token=token,
                                                    authn_method=authn_method)

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        response = self.request_and_return(url, response_cls, method, body,
                                           body_type, state=state,
                                           http_args=http_args)
        if token.replaced:
            grant = self.get_grant(state)
            grant.delete_token(token)
        return response

    def do_any(self, request, endpoint="", scope="", state="",
               body_type="json", method="POST", request_args=None,
               extra_args=None, http_args=None, response=None,
               authn_method=""):
        """Construct and send an arbitrary request to ``endpoint``."""
        url, body, ht_args, _ = self.request_info(request, method=method,
                                                  request_args=request_args,
                                                  extra_args=extra_args,
                                                  scope=scope, state=state,
                                                  authn_method=authn_method,
                                                  endpoint=endpoint)

        if http_args is None:
            http_args = ht_args
        else:
            http_args.update(ht_args)

        return self.request_and_return(url, response, method, body,
                                       body_type, state=state,
                                       http_args=http_args)

    def fetch_protected_resource(self, uri, method="GET", headers=None,
                                 state="", **kwargs):
        """
        Fetch a protected resource with an access token.

        Uses ``kwargs['token']`` if given; otherwise the token for ``state``
        is used, refreshing it first if it is no longer valid.

        :return: The raw HTTP response.
        """
        if "token" in kwargs and kwargs["token"]:
            token = kwargs["token"]
            request_args = {"access_token": token}
        else:
            # get_token would return a falsy token as-is; drop it so the
            # token is looked up from the grant instead.
            kwargs.pop("token", None)
            # get_token signals expiry with TokenError, never ExpiredToken,
            # so fetch the token regardless of validity and check it here.
            try:
                token = self.get_token(also_expired=True, state=state,
                                       **kwargs)
                expired = not token.is_valid()
            except ExpiredToken:
                expired = True
            if expired:
                # The token is too old, refresh
                self.do_access_token_refresh(state=state)
                token = self.get_token(state=state, **kwargs)
            request_args = {"access_token": token.access_token}

        if headers is None:
            headers = {}

        if "authn_method" in kwargs:
            # There is no request message for a resource fetch, so cis is
            # None (init_authentication_method requires it positionally).
            http_args = self.init_authentication_method(
                None, request_args=request_args, **kwargs)
        else:
            # If nothing defined this is the default
            http_args = self.client_authn_method[
                "bearer_header"](self).construct(request_args=request_args)

        headers.update(http_args.get("headers", {}))

        logger.debug("Fetch URI: %s" % uri)
        return self.http_request(uri, method, headers=headers)

    def add_code_challenge(self):
        """
        PKCE RFC 7636 support

        :return: A tuple (request arguments holding ``code_challenge`` and
            ``code_challenge_method``, the plain code verifier).
        :raises Unsupported: If the configured method is not in CC_METHOD.
        """
        try:
            cv_len = self.config['code_challenge']['length']
        except KeyError:
            cv_len = 64  # Use default

        code_verifier = unreserved(cv_len)
        _cv = code_verifier.encode('ascii')

        try:
            _method = self.config['code_challenge']['method']
        except KeyError:
            _method = 'S256'

        try:
            _h = CC_METHOD[_method](_cv).digest()
            code_challenge = b64e(_h).decode('ascii')
        except KeyError:
            raise Unsupported(
                'PKCE Transformation method:{}'.format(_method))

        # TODO store code_verifier

        return {"code_challenge": code_challenge,
                "code_challenge_method": _method}, code_verifier

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response
        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
        as attributes in self.
        :raises PyoidcError: On issuer mismatch unless
            ``self.allow['issuer_mismatch']`` is truthy.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            if pcr["issuer"].endswith("/"):
                if issuer.endswith("/"):
                    _issuer = issuer
                else:
                    _issuer = issuer + "/"
            else:
                if issuer.endswith("/"):
                    _issuer = issuer[:-1]
                else:
                    _issuer = issuer

            if not self.allow.get("issuer_mismatch", False) and \
                    _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info = pcr
        else:
            _pcr_issuer = issuer

        self.issuer = _pcr_issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ASConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and process the authorization server's configuration,
        following 302 redirects.

        :raises PyoidcError: If no 200 response was obtained.
        :return: The parsed configuration response.
        """
        if issuer.endswith("/"):
            _issuer = issuer[:-1]
        else:
            _issuer = issuer

        url = serv_pattern % _issuer

        pcr = None
        r = self.http_request(url)
        if r.status_code == 200:
            pcr = response_cls().from_json(r.text)
        elif r.status_code == 302:
            while r.status_code == 302:
                r = self.http_request(r.headers["location"])
                if r.status_code == 200:
                    pcr = response_cls().from_json(r.text)
                    break

        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr
class Provider(provider.Provider):
    """
    An OAuth2 provider (authorization server) with the OAuth2 extensions
    implemented here: dynamic client registration, a client information
    endpoint, token revocation (RFC 7009), token introspection (RFC 7662),
    PKCE (RFC 7636) and the client_credentials, password and refresh_token
    grant types.
    """

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey=None, urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed=b"", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400, jwks_uri='', keyjar=None,
                 capabilities=None, verify_ssl=True, baseurl='', hostname='',
                 config=None, behavior=None, lifetime_policy=None, **kwargs):
        """
        :param name: Provider name/URL; a trailing "/" is appended if missing.
        :param sdb: Session database.
        :param cdb: Client database keyed by client_id.
        :param authn_broker: Passed through to the base provider.
        :param authz: Passed through to the base provider.
        :param client_authn: Client authentication callable, also used by
            token_endpoint and get_token_info.
        :param symkey: Passed through to the base provider.
        :param urlmap: Passed through to the base provider.
        :param iv: Passed through to the base provider.
        :param default_scope: Passed through to the base provider.
        :param ca_bundle: Passed through to the base provider.
        :param seed: Seed used when deriving client secrets.
        :param client_authn_methods: Mapping from authentication method name
            to a verifier class, used by verify_client.
        :param authn_at_registration: If set, the client authentication
            method required at the registration endpoint.
        :param client_info_url: Path component used when building a
            client's registration_client_uri.
        :param secret_lifetime: Seconds until an issued client secret expires.
        :param jwks_uri: Published in the provider info when a keyjar exists.
        :param keyjar: Key storage; a new KeyJar is created when None.
        :param capabilities: Overrides merged into the default provider
            features.
        :param verify_ssl: Passed to the KeyJar created when keyjar is None.
        :param baseurl: Issuer/base URL; defaults to name.
        :param hostname: Used as fqdn by key_setup; defaults to
            socket.gethostname().
        :param config: Extra configuration; defaults to {}.
        :param behavior: Behaviour settings; the 'client_registration' entry
            holds restrictions checked by create_new_client. Defaults to {}.
        :param lifetime_policy: Token lifetimes in seconds, keyed by token
            type and then response/grant type. A default policy is used
            when None.
        :param kwargs: Optional 'server_cls' (passed to the base provider)
            and 'scopes' (defaults to ['offline_access']).
        :raises UnknownAuthnMethod: if authn_at_registration is not one of
            client_authn_methods.
        """
        if not name.endswith("/"):
            name += "/"

        try:
            args = {'server_cls': kwargs['server_cls']}
        except KeyError:
            args = {}

        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle, **args)

        self.endp.extend([
            RegistrationEndpoint, ClientInfoEndpoint, RevocationEndpoint,
            IntrospectionEndpoint
        ])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # A missing method table means the method is unknown; report it
            # as such instead of failing with TypeError on the `in` test.
            if (not client_authn_methods
                    or authn_at_registration not in client_authn_methods):
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        # Must be set before provider_features(), which reads self.scopes.
        self.scopes = kwargs.get('scopes', ['offline_access'])
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            self.capabilities = self.provider_features(
                provider_config=capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl or name
        self.hostname = hostname or socket.gethostname()
        self.kid = {"sig": {}, "enc": {}}
        self.config = config or {}
        self.behavior = behavior or {}
        # Per-client lifetimes, filled in by set_token_policy().
        self.token_policy = {'access_token': {}, 'refresh_token': {}}
        if lifetime_policy is None:
            self.lifetime_policy = {
                'access_token': {
                    'code': 600,
                    'token': 120,
                    'implicit': 120,
                    'authorization_code': 600,
                    'client_credentials': 600,
                    'password': 600
                },
                'refresh_token': {
                    'code': 3600,
                    'token': 3600,
                    'implicit': 3600,
                    'authorization_code': 3600,
                    'client_credentials': 3600,
                    'password': 3600
                }
            }
        else:
            self.lifetime_policy = lifetime_policy

        self.token_handler = TokenHandler(self.baseurl, self.token_policy,
                                          keyjar=self.keyjar)

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a (base, query) tuple; query is "" if absent."""
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            tup.append((base, query or ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        """Inverse of _uris_to_tuples: join (url, query) pairs into URIs."""
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys and add its secret as symmetric keys.

        :param request: The registration request; keys are read from it by
            KeyJar.load_keys.
        :param client_id: The client the keys belong to.
        :param client_secret: Added to the keyjar as an 'oct' key for both
            'ver' and 'sig' use.
        :return: None on success, or a 400 Response if loading the client's
            keys failed.
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                n_keys = len(self.keyjar[client_id])
                msg = "Found {} keys for client_id={}"
                logger.debug(msg.format(n_keys, client_id))
            except KeyError:
                pass
        except Exception as err:
            msg = "Failed to load client keys: {}"
            logger.error(msg.format(sanitize(request.to_dict())))
            logger.error("%s", err)
            err = ClientRegistrationError(
                error="invalid_configuration_parameter",
                error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status_code="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        _kc = KeyBundle([{
            "kty": "oct",
            "key": client_secret,
            "use": "ver"
        }, {
            "kty": "oct",
            "key": client_secret,
            "use": "sig"
        }])
        try:
            self.keyjar[client_id].append(_kc)
        except KeyError:
            self.keyjar[client_id] = [_kc]

    @staticmethod
    def verify_correct(cinfo, restrictions):
        """
        Apply each named restriction to the client info.

        :param cinfo: Client information dictionary.
        :param restrictions: Mapping of restriction name to argument; each
            name is resolved through restrict.factory.
        :raises RestrictionError: with the first non-empty restriction result.
        """
        for fname, arg in restrictions.items():
            func = restrict.factory(fname)
            res = func(arg, cinfo)
            if res:
                raise RestrictionError(res)

    def set_token_policy(self, cid, cinfo):
        """
        Record the token lifetimes that apply to client `cid`.

        For each token type, collect the lifetimes from self.lifetime_policy
        for every response_type/grant_type value in `cinfo`.
        """
        # NOTE(review): registration metadata normally uses the plural keys
        # 'response_types'/'grant_types'; with only those present the policy
        # stays empty. Confirm which keys callers actually store.
        for ttyp in ['access_token', 'refresh_token']:
            pol = {}
            for rgtyp in ['response_type', 'grant_type']:
                try:
                    rtyp = cinfo[rgtyp]
                except KeyError:
                    continue
                for typ in rtyp:
                    try:
                        pol[typ] = self.lifetime_policy[ttyp][typ]
                    except KeyError:
                        pass

            self.token_policy[ttyp][cid] = pol

    def create_new_client(self, request, restrictions):
        """
        Register a new client and store it in the client database.

        :param request: The Client registration request
        :param restrictions: Restrictions on the client
        :return: The client_id
        :raises CapabilitiesMisMatch: if the request asks for something the
            provider does not support.
        :raises RestrictionError: if a 'client_registration' behaviour
            restriction is violated.
        """
        _cinfo = request.to_dict()

        self.match_client_request(_cinfo)

        # create new id and secret
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = utc_time_sans_frac()
        _cinfo["client_secret_expires_at"] = (utc_time_sans_frac() +
                                              self.secret_lifetime)

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % (
                self.name, self.client_info_url, ClientInfoEndpoint.etype,
                _id)

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        # NOTE(review): load_keys returns an error Response on failure, but
        # the result is ignored here, so registration proceeds anyway.
        self.load_keys(request, _id, _cinfo["client_secret"])

        # NOTE(review): the `restrictions` argument (from software
        # statements) is not applied; only self.behavior is checked.
        try:
            _behav = self.behavior['client_registration']
        except KeyError:
            pass
        else:
            self.verify_correct(_cinfo, _behav)

        self.set_token_policy(_id, _cinfo)
        self.cdb[_id] = _cinfo

        return _id

    def match_client_request(self, request):
        """
        Check that the requested client preferences are supported.

        :param request: Client registration request as a dictionary.
        :raises CapabilitiesMisMatch: on the first unsupported preference.
        """
        for _pref, _prov in PREFERENCE2PROVIDER.items():
            if _pref not in request:
                continue

            if _pref == "response_types":
                # Response types are space-separated sets, order-insensitive.
                for val in request[_pref]:
                    p = set(val.split(" "))
                    if not any(p == set(cv.split(' '))
                               for cv in self.capabilities[_prov]):
                        raise CapabilitiesMisMatch(
                            'Not allowed {}'.format(_pref))
            elif isinstance(request[_pref], str):
                if request[_pref] not in self.capabilities[_prov]:
                    raise CapabilitiesMisMatch(
                        'Not allowed {}'.format(_pref))
            elif not set(request[_pref]).issubset(
                    set(self.capabilities[_prov])):
                raise CapabilitiesMisMatch('Not allowed {}'.format(_pref))

    def client_info(self, client_id):
        """
        Build the client information response for `client_id`.

        :return: A Response with the client info as JSON, or BadRequest if
            valid_client_info rejects the stored client information.
        """
        _cinfo = self.cdb[client_id].copy()
        if not valid_client_info(_cinfo):
            err = ErrorResponse(error="invalid_client",
                                error_description="Invalid client secret")
            return BadRequest(err.to_json(), content="application/json")

        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        msg = ClientInfoResponse(**_cinfo)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace a client's metadata with the content of an update request.

        Keys absent from the request are dropped, except the
        server-issued ones listed below.

        :raises ModificationForbidden: if the request tries to change
            client_id or client_secret.
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                    "client_id_issued_at", "client_secret_expires_at",
                    "registration_access_token", "registration_client_uri"
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: Client to verify; derived from the request and the
            HTTP_AUTHORIZATION header when empty.
        :return: The result of the authentication method's verify()
        :raises UnSupported: if authn_method is not configured.
        """
        if not client_id:
            client_id = get_client_id(self.cdb, areq,
                                      environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def consume_software_statement(self, software_statement):
        """Hook for subclasses; returns client restrictions (none here)."""
        return {}

    def registration_endpoint(self, **kwargs):
        """
        Handle a dynamic client registration request.

        :param kwargs: 'request' (JSON registration request) and, when
            authentication at registration is required, 'environ'.
        :return: A Response instance
        """
        _request = RegistrationRequest().deserialize(kwargs['request'],
                                                     "json")
        try:
            _request.verify(keyjar=self.keyjar)
        except InvalidRedirectUri as err:
            msg = ClientRegistrationError(error="invalid_redirect_uri",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")
        except (MissingPage, VerificationError) as err:
            msg = ClientRegistrationError(error="invalid_client_metadata",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        # If authentication is necessary at registration
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs['environ'], _request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_restrictions = {}
        if 'parsed_software_statement' in _request:
            for ss in _request['parsed_software_statement']:
                client_restrictions.update(
                    self.consume_software_statement(ss))
            del _request['software_statement']
            del _request['parsed_software_statement']

        try:
            client_id = self.create_new_client(_request, client_restrictions)
        except (CapabilitiesMisMatch, RestrictionError) as err:
            msg = ClientRegistrationError(error="invalid_client_metadata",
                                          error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods

        :param method: HTTP method used for the request
        :param kwargs: 'query' (must carry client_id), 'environ' and
            'request'
        :return: A Response instance
        """
        _query = compact(parse_qs(kwargs['query']))
        try:
            _id = _query["client_id"]
        except KeyError:
            return BadRequest("Missing query component")

        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs['environ'], kwargs['request'],
                               "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs['request'])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata",
                                              error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

    def provider_features(self, pcr_class=ServerMetadata,
                          provider_config=None):
        """
        Specifies what the server capabilities are.

        :param pcr_class: Class used to build the capabilities object.
        :param provider_config: Optional overrides merged in last.
        :return: pcr_class instance
        """
        _provider_info = pcr_class(**CAPABILITIES)
        _provider_info["scopes_supported"] = self.scopes

        # 'none' is excluded from client-authentication signing algorithms.
        sign_algs = list(jws.SIGNER_ALGS.keys())
        sign_algs.remove('none')
        sign_algs = sorted(sign_algs, key=cmp_to_key(sort_sign_alg))

        _pat1 = "{}_endpoint_auth_signing_alg_values_supported"
        _pat2 = "{}_endpoint_auth_methods_supported"
        for typ in ["token", "revocation", "introspection"]:
            _provider_info[_pat1.format(typ)] = sign_algs
            _provider_info[_pat2.format(typ)] = AUTH_METHODS_SUPPORTED

        if provider_config:
            _provider_info.update(provider_config)

        return _provider_info

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually can be
        done by this implementation.

        Only string-valued capabilities are checked; other values pass.

        :param capabilities: The asked for capabilities as a dictionary
            or a ProviderConfigurationResponse instance. The later can be
            treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if isinstance(val, str):
                try:
                    if val not in _pinfo[key]:
                        return False
                except KeyError:
                    return False

        return True

    def create_providerinfo(self, pcr_class=ASConfigurationResponse,
                            setup=None):
        """
        Dynamically create the provider info response

        Note: this updates self.capabilities in place and returns it.

        :param pcr_class: Class whose c_param keys select entries from setup.
        :param setup: Optional dictionary of extra provider info values.
        :return: The provider info
        """
        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            _provider_info['{}_endpoint'.format(endp.etype)] = os.path.join(
                self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider configuration as JSON.

        Any exception is logged and returned as a text traceback response.
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(), ))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if 'handle' in kwargs:
                (key, _) = kwargs['handle']
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo",
                                              self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json",
                            headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            resp = Response(message, content="html/text")

        return resp

    @staticmethod
    def verify_code_challenge(code_verifier, code_challenge,
                              code_challenge_method='S256'):
        """
        Verify a PKCE (RFC7636) code challenge

        :param code_verifier: The origin
        :param code_challenge: The transformed verifier used as challenge
        :param code_challenge_method: Key into CC_METHOD; an unknown method
            raises KeyError.
        :return: True if the challenge matches, otherwise a 401 Response
        """
        _h = CC_METHOD[code_challenge_method](
            code_verifier.encode('ascii')).digest()
        _cc = b64e(_h)
        if _cc.decode('ascii') != code_challenge:
            logger.error('PCKE Code Challenge check failed')
            err = TokenErrorResponse(error="invalid_request",
                                     error_description="PCKE check failed")
            return Response(err.to_json(), content="application/json",
                            status_code=401)
        return True

    def do_access_token_response(self, access_token, atinfo, state,
                                 refresh_token=None):
        """
        Build an AccessTokenResponse from token info.

        :param access_token: The access token string.
        :param atinfo: Token info; 'exp' is required and 'scope' is optional.
        :param state: Copied into the response.
        :param refresh_token: Included when given.
        """
        # NOTE(review): expires_in is documented as a lifetime in seconds
        # (RFC 6749 §5.1); confirm that atinfo['exp'] is relative and not an
        # absolute timestamp.
        _tinfo = {
            'access_token': access_token,
            'expires_in': atinfo['exp'],
            'token_type': 'bearer',
            'state': state
        }
        try:
            _tinfo['scope'] = atinfo['scope']
        except KeyError:
            pass

        if refresh_token:
            _tinfo['refresh_token'] = refresh_token

        return AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

    def code_grant_type(self, areq):
        """
        Exchange an authorization code for tokens (RFC 6749 section 4.1.3).

        :param areq: The parsed access token request.
        :return: A Response with the token response or an error.
        """
        # assert that the code is valid
        try:
            _info = self.sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Unknown access grant")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        authzreq = json.loads(_info['authzreq'])
        # NOTE(review): RFC 7636 §4.6 expects a verifier whenever the
        # authorization request carried a code_challenge; here the check is
        # skipped if the client simply omits code_verifier.
        if 'code_verifier' in areq:
            try:
                _method = authzreq['code_challenge_method']
            except KeyError:
                _method = 'S256'

            resp = self.verify_code_challenge(areq['code_verifier'],
                                              authzreq['code_challenge'],
                                              _method)
            # verify_code_challenge returns True on success and a Response
            # on failure; only a failure must short-circuit token issuance.
            if isinstance(resp, Response):
                return resp

        if 'state' in areq:
            if self.sdb[areq['code']]['state'] != areq['state']:
                logger.error('State value mismatch')
                err = TokenErrorResponse(error="unauthorized_client")
                return Unauthorized(err.to_json(), content="application/json")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one; a missing one
        # is also a mismatch.
        if "redirect_uri" in _info and (
                "redirect_uri" not in areq
                or areq["redirect_uri"] != _info["redirect_uri"]):
            logger.error('Redirect_uri mismatch')
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        # A refresh token is only issued for the pure code flow with the
        # offline_access scope.
        issue_refresh = False
        if 'scope' in authzreq and 'offline_access' in authzreq['scope']:
            if authzreq['response_type'] == 'code':
                issue_refresh = True

        try:
            _tinfo = self.sdb.upgrade_to_token(areq["code"],
                                               issue_refresh=issue_refresh)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant",
                                     error_description="Access grant used")
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json",
                        headers=OAUTH2_NOCACHE_HEADERS)

    def client_credentials_grant_type(self, areq):
        """Issue tokens for the client_credentials grant."""
        _at = self.token_handler.get_access_token(
            areq['client_id'], scope=areq['scope'],
            grant_type='client_credentials')
        _info = self.token_handler.token_factory.get_info(_at)
        try:
            _rt = self.token_handler.get_refresh_token(
                self.baseurl, _info['access_token'], 'client_credentials')
        except NotAllowed:
            atr = self.do_access_token_response(_at, _info, areq['state'])
        else:
            atr = self.do_access_token_response(_at, _info, areq['state'],
                                                _rt)

        return Response(atr.to_json(), content="application/json")

    def password_grant_type(self, areq):
        """Issue tokens for the password grant."""
        _at = self.token_handler.get_access_token(areq['client_id'],
                                                  scope=areq['scope'],
                                                  grant_type='password')
        _info = self.token_handler.token_factory.get_info(_at)
        try:
            _rt = self.token_handler.get_refresh_token(
                self.baseurl, _info['access_token'], 'password')
        except NotAllowed:
            atr = self.do_access_token_response(_at, _info, areq['state'])
        else:
            atr = self.do_access_token_response(_at, _info, areq['state'],
                                                _rt)

        return Response(atr.to_json(), content="application/json")

    def refresh_token_grant_type(self, areq):
        """Issue a new access token from a refresh token."""
        # RFC 6749 section 6 names the parameter "refresh_token"; fall back
        # to "access_token" so callers that send that key keep working.
        if 'refresh_token' in areq:
            _token = areq['refresh_token']
        else:
            _token = areq['access_token']

        at = self.token_handler.refresh_access_token(self.baseurl, _token,
                                                     'refresh_token')

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **at))
        return Response(atr.to_json(), content="application/json")

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens

        :param authn: Client authentication information
        :param kwargs: 'request' holds the urlencoded request body.
        :raises UnSupported: for an unknown grant_type.
        """
        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as err:
            logger.error(err)
            err = TokenErrorResponse(error="unauthorized_client",
                                     error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status_code=401)

        logger.debug("AccessTokenRequest: %s" % areq)

        _grant_type = areq["grant_type"]
        _handlers = {
            "authorization_code": self.code_grant_type,
            'client_credentials': self.client_credentials_grant_type,
            'password': self.password_grant_type,
            'refresh_token': self.refresh_token_grant_type,
        }
        _handler = _handlers.get(_grant_type)
        if _handler is None:
            raise UnSupported('grant_type: {}'.format(_grant_type))
        return _handler(areq)

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        my keys
        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: A URL the RP can use to download the key.
        """
        self.jwks_uri = key_export(self.baseurl, local_path, vault,
                                   self.keyjar, fqdn=self.hostname, sig=sig,
                                   enc=enc)

    @staticmethod
    def token_access(endpoint, client_id, token_info):
        """
        Decide whether `client_id` may use the token at `endpoint`.

        Revocation: the client must be 'azr' or the sole 'aud'.
        Introspection: the client must be 'azr' or one of 'aud'.
        """
        # NOTE(review): the original comment speaks of "azp" for
        # introspection while the code checks 'azr'; confirm the claim name.
        if 'azr' in token_info and client_id == token_info['azr']:
            return True

        if endpoint == 'revocation_endpoint':
            # A token without 'aud' simply isn't revocable by audience.
            return 'aud' in token_info and token_info['aud'] == [client_id]

        # has to be introspection endpoint
        return 'aud' in token_info and client_id in token_info['aud']

    def get_token_info(self, authn, req, endpoint):
        """
        Authenticate the client and look up the token in `req`.

        :param authn: Client authentication information
        :param req: Revocation or introspection request
        :param endpoint: 'revocation_endpoint' or 'introspection_endpoint'
        :return: (client_id, token_type, token_info), or a Response on
            failure.
        """
        try:
            client_id = self.client_authn(self, req, authn)
        except FailedAuthentication as err:
            logger.error(err)
            err = TokenErrorResponse(error="unauthorized_client",
                                     error_description="%s" % err)
            return Response(err.to_json(), content="application/json",
                            status="401 Unauthorized")

        logger.debug('{}: {} requesting {}'.format(endpoint, client_id,
                                                   req.to_dict()))

        try:
            token_type = req['token_type_hint']
        except KeyError:
            # No hint: try access token first, then refresh token.
            try:
                _info = self.sdb.token_factory['access_token'].get_info(
                    req['token'])
            except Exception:
                try:
                    _info = self.sdb.token_factory['refresh_token'].get_info(
                        req['token'])
                except Exception:
                    return self._return_inactive()
                else:
                    token_type = 'refresh_token'
            else:
                token_type = 'access_token'
        else:
            try:
                _info = self.sdb.token_factory[token_type].get_info(
                    req['token'])
            except Exception:
                return self._return_inactive()

        if not self.token_access(endpoint, client_id, _info):
            return BadRequest()

        return client_id, token_type, _info

    @staticmethod
    def _return_inactive():
        """Return an introspection response saying the token is inactive."""
        ir = TokenIntrospectionResponse(active=False)
        return Response(ir.to_json(), content="application/json")

    def revocation_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7009 allows a client to invalidate an access or
        refresh token.

        :param authn: Client Authentication information
        :param request: The revocation request
        :param kwargs:
        :return:
        """
        trr = TokenRevocationRequest().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, trr, 'revocation_endpoint')

        if isinstance(resp, Response):
            return resp

        client_id, token_type, _info = resp

        logger.info('{} token revocation: {}'.format(client_id,
                                                     trr.to_dict()))

        try:
            self.sdb.token_factory[token_type].invalidate(trr['token'])
        except KeyError:
            return BadRequest()
        else:
            return Response('OK')

    def introspection_endpoint(self, authn='', request=None, **kwargs):
        """
        Implements RFC7662

        :param authn: Client Authentication information
        :param request: The introspection request
        :param kwargs:
        :return:
        """
        tir = TokenIntrospectionRequest().deserialize(request, "urlencoded")

        resp = self.get_token_info(authn, tir, 'introspection_endpoint')

        if isinstance(resp, Response):
            return resp

        client_id, token_type, _info = resp

        logger.info('{} token introspection: {}'.format(
            client_id, tir.to_dict()))

        ir = TokenIntrospectionResponse(
            active=self.sdb.token_factory[token_type].is_valid(_info),
            **_info.to_dict())

        ir.weed()

        return Response(ir.to_json(), content="application/json")
class Provider(provider.Provider):
    """
    An OAuth2 provider (authorization server) with dynamic client
    registration and a client information endpoint.
    """

    def __init__(
        self,
        name,
        sdb,
        cdb,
        authn_broker,
        authz,
        client_authn,
        symkey="",
        urlmap=None,
        iv=0,
        default_scope="",
        ca_bundle=None,
        seed=b"",
        client_authn_methods=None,
        authn_at_registration="",
        client_info_url="",
        secret_lifetime=86400,
        jwks_uri="",
        keyjar=None,
        capabilities=None,
        verify_ssl=True,
        baseurl="",
        hostname="",
    ):
        """
        :param name: Provider name/URL; a trailing "/" is appended if missing.
        :param seed: Seed used when deriving client secrets.
        :param client_authn_methods: Mapping from authentication method name
            to a verifier class, used by verify_client.
        :param authn_at_registration: If set, the client authentication
            method required at the registration endpoint.
        :param client_info_url: Path component used when building a
            client's registration_client_uri.
        :param secret_lifetime: Seconds until an issued client secret expires.
        :param jwks_uri: Published in the provider info when a keyjar exists.
        :param keyjar: Key storage; a new KeyJar is created when None.
        :param capabilities: Provider configuration to use instead of
            provider_features().
        :param hostname: Used as fqdn by key_setup; defaults to gethostname().
        The remaining parameters are passed to the base provider.
        :raises UnknownAuthnMethod: if authn_at_registration is not one of
            client_authn_methods.
        """
        if not name.endswith("/"):
            name += "/"
        provider.Provider.__init__(
            self, name, sdb, cdb, authn_broker, authz, client_authn, symkey, urlmap, iv, default_scope, ca_bundle
        )

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # A missing method table means the method is unknown; report it
            # as such instead of failing with TypeError on the `in` test.
            if not client_authn_methods or authn_at_registration not in client_authn_methods:
                raise UnknownAuthnMethod(authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime
        self.jwks_uri = jwks_uri
        self.verify_ssl = verify_ssl
        self.keyjar = keyjar
        if self.keyjar is None:
            self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

        if capabilities:
            # NOTE(review): verify_capabilities() returns False for
            # unsupported values, but the result is ignored here.
            self.verify_capabilities(capabilities)
            self.capabilities = ProviderConfigurationResponse(**capabilities)
        else:
            self.capabilities = self.provider_features()
        self.baseurl = baseurl
        self.hostname = hostname or gethostname()
        self.kid = {"sig": {}, "enc": {}}

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a (base, query) tuple; query is "" if absent."""
        tup = []
        for uri in uris:
            base, query = splitquery(uri)
            tup.append((base, query or ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        """Inverse of _uris_to_tuples: join (url, query) pairs into URIs."""
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    def load_keys(self, request, client_id, client_secret):
        """
        Load a client's keys and add its secret as symmetric keys.

        :param request: The registration request; keys are read from it by
            KeyJar.load_keys.
        :param client_id: The client the keys belong to.
        :param client_secret: Added to the keyjar as an 'oct' key for both
            'ver' and 'sig' use.
        :return: None on success, or a 400 Response if loading the client's
            keys failed.
        """
        try:
            self.keyjar.load_keys(request, client_id)
            try:
                logger.debug(
                    "keys for %s: [%s]" % (client_id, ",".join(["%s" % x for x in self.keyjar[client_id]]))
                )
            except KeyError:
                pass
        except Exception as err:
            logger.error("Failed to load client keys: %s" % request.to_dict())
            logger.error("%s", err)
            err = ClientRegistrationError(error="invalid_configuration_parameter", error_description="%s" % err)
            return Response(err.to_json(), content="application/json", status="400 Bad Request")

        # Add the client_secret as a symmetric key to the keyjar
        _kc = KeyBundle(
            [
                {"kty": "oct", "key": client_secret, "use": "ver"},
                {"kty": "oct", "key": client_secret, "use": "sig"},
            ]
        )
        try:
            self.keyjar[client_id].append(_kc)
        except KeyError:
            self.keyjar[client_id] = [_kc]

    def create_new_client(self, request):
        """
        Register a new client and store it in the client database.

        :param request: The Client registration request
        :return: The client_id
        """
        _cinfo = request.to_dict()

        # create new id and secret
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = utc_time_sans_frac()
        _cinfo["client_secret_expires_at"] = utc_time_sans_frac() + self.secret_lifetime

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % (
                self.name,
                self.client_info_url,
                ClientInfoEndpoint.etype,
                _id,
            )

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        # NOTE(review): load_keys returns an error Response on failure, but
        # the result is ignored here, so registration proceeds anyway.
        self.load_keys(request, _id, _cinfo["client_secret"])

        self.cdb[_id] = _cinfo

        return _id

    def client_info(self, client_id):
        """Return the stored information for `client_id` as a JSON Response."""
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        msg = ClientInfoResponse(**_cinfo)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace a client's metadata with the content of an update request.

        Keys absent from the request are dropped, except the server-issued
        ones listed below.

        :raises ModificationForbidden: if the request tries to change
            client_id or client_secret.
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"])
        except KeyError:
            pass

        for key, value in request.items():
            if key in ["client_secret", "client_id"]:
                # assure it's the same; an explicit check (not `assert`)
                # so the protection survives `python -O`.
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        for key in list(_cinfo.keys()):
            if key in [
                "client_id_issued_at",
                "client_secret_expires_at",
                "registration_access_token",
                "registration_client_uri",
            ]:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: Client to verify; derived from the request and the
            HTTP_AUTHORIZATION header when empty.
        :return: The result of the authentication method's verify()
        :raises UnSupported: if authn_method is not configured.
        """
        if not client_id:
            client_id = get_client_id(self.cdb, areq, environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def registration_endpoint(self, **kwargs):
        """
        Handle a dynamic client registration request.

        :param kwargs: 'request' (JSON registration request) and, when
            authentication at registration is required, 'environ'.
        :return: A Response instance
        """
        _request = RegistrationRequest().deserialize(kwargs["request"], "json")
        try:
            _request.verify()
        except InvalidRedirectUri as err:
            msg = ClientRegistrationError(error="invalid_redirect_uri", error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")
        except (MissingPage, VerificationError) as err:
            msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err)
            return BadRequest(msg.to_json(), content="application/json")

        # authenticated client
        if self.authn_at_registration:
            try:
                self.verify_client(kwargs["environ"], _request, self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_id = self.create_new_client(_request)

        return self.client_info(client_id)

    def client_info_endpoint(self, method="GET", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods

        :param method: HTTP method used for the request
        :param kwargs: 'query' (must carry client_id), 'environ' and 'request'
        :return: A Response instance
        """
        _query = parse_qs(kwargs["query"])
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # Explicit membership test (not `assert`) so unknown clients are
        # rejected even under `python -O`.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(kwargs["request"])
            except ValueError as err:
                return BadRequest(str(err))

            try:
                _request.verify()
            except InvalidRedirectUri as err:
                msg = ClientRegistrationError(error="invalid_redirect_uri", error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")
            except (MissingPage, VerificationError) as err:
                msg = ClientRegistrationError(error="invalid_client_metadata", error_description="%s" % err)
                return BadRequest(msg.to_json(), content="application/json")

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

    def provider_features(self, pcr_class=ProviderConfigurationResponse):
        """
        Specifies what the server capabilities are.

        :param pcr_class: Class used to build the capabilities object.
        :return: ProviderConfigurationResponse instance
        """
        _provider_info = pcr_class(**CAPABILITIES)

        _claims = []
        for _cl in SCOPE2CLAIMS.values():
            _claims.extend(_cl)
        _provider_info["claims_supported"] = list(set(_claims))

        _scopes = list(SCOPE2CLAIMS.keys())
        _scopes.append("openid")
        _provider_info["scopes_supported"] = _scopes

        sign_algs = list(jws.SIGNER_ALGS.keys())
        for typ in ["userinfo", "id_token", "request_object"]:
            _provider_info["%s_signing_alg_values_supported" % typ] = sign_algs

        # Remove 'none' for token_endpoint_auth_signing_alg_values_supported
        # since it is not allowed. Copy first: the list above is shared by
        # the three entries just set.
        sign_algs = sign_algs[:]
        sign_algs.remove("none")
        _provider_info["token_endpoint_auth_signing_alg_values_supported"] = sign_algs

        algs = jwe.SUPPORTED["alg"]
        for typ in ["userinfo", "id_token", "request_object"]:
            _provider_info["%s_encryption_alg_values_supported" % typ] = algs

        encs = jwe.SUPPORTED["enc"]
        for typ in ["userinfo", "id_token", "request_object"]:
            _provider_info["%s_encryption_enc_values_supported" % typ] = encs

        # acr_values
        if self.authn_broker:
            acr_values = self.authn_broker.getAcrValuesString()
            if acr_values is not None:
                _provider_info["acr_values_supported"] = acr_values

        return _provider_info

    def verify_capabilities(self, capabilities):
        """
        Verify that what the admin wants the server to do actually can be
        done by this implementation.

        Only string-valued capabilities are checked; other values pass.

        :param capabilities: The asked for capabilities as a dictionary
            or a ProviderConfigurationResponse instance. The later can be
            treated as a dictionary.
        :return: True or False
        """
        _pinfo = self.provider_features()
        for key, val in capabilities.items():
            if isinstance(val, six.string_types):
                try:
                    if val not in _pinfo[key]:
                        return False
                except KeyError:
                    return False

        return True

    def create_providerinfo(self, pcr_class=ProviderConfigurationResponse, setup=None):
        """
        Dynamically create the provider info response

        Note: this updates self.capabilities in place and returns it.

        :param pcr_class: Class whose c_param keys select entries from setup.
        :param setup: Optional dictionary of extra provider info values.
        :return: The provider info
        """
        _provider_info = self.capabilities

        if self.jwks_uri and self.keyjar:
            _provider_info["jwks_uri"] = self.jwks_uri

        for endp in self.endp:
            _provider_info["{}_endpoint".format(endp.etype)] = os.path.join(self.baseurl, endp.url)

        if setup and isinstance(setup, dict):
            for key in pcr_class.c_param.keys():
                if key in setup:
                    _provider_info[key] = setup[key]

        _provider_info["issuer"] = self.baseurl
        _provider_info["version"] = "3.0"

        return _provider_info

    def providerinfo_endpoint(self, **kwargs):
        """
        Serve the provider configuration as JSON.

        Any exception is logged and returned as a text traceback response.
        """
        _log_info = logger.info

        _log_info("@providerinfo_endpoint")
        try:
            _response = self.create_providerinfo()
            _log_info("provider_info_response: %s" % (_response.to_dict(),))

            headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")]
            if "handle" in kwargs:
                (key, _) = kwargs["handle"]
                if key.startswith(STR) and key.endswith(STR):
                    cookie = self.cookie_func(key, self.cookie_name, "pinfo", self.sso_ttl)
                    headers.append(cookie)

            resp = Response(_response.to_json(), content="application/json", headers=headers)
        except Exception:
            message = traceback.format_exception(*sys.exc_info())
            logger.error(message)
            resp = Response(message, content="html/text")

        return resp

    def token_endpoint(self, authn="", **kwargs):
        """
        This is where clients come to get their access tokens

        Only the authorization_code grant is supported here.

        :param authn: Client authentication information
        :param kwargs: 'request' holds the urlencoded request body.
        """
        _sdb = self.sdb

        logger.debug("- token -")
        body = kwargs["request"]
        logger.debug("body: %s" % body)

        areq = AccessTokenRequest().deserialize(body, "urlencoded")

        try:
            self.client_authn(self, areq, authn)
        except FailedAuthentication as err:
            err = TokenErrorResponse(error="unauthorized_client", error_description="%s" % err)
            return Response(err.to_json(), content="application/json", status="401 Unauthorized")

        logger.debug("AccessTokenRequest: %s" % areq)

        # Explicit checks below (not `assert`) so validation survives
        # `python -O` and failures become protocol errors, not exceptions.
        if areq["grant_type"] != "authorization_code":
            err = TokenErrorResponse(error="invalid_request", error_description="Wrong grant type")
            return Response(err.to_json(), content="application/json", status="401 Unauthorized")

        # assert that the code is valid
        try:
            _info = _sdb[areq["code"]]
        except KeyError:
            err = TokenErrorResponse(error="invalid_grant", error_description="Unknown access grant")
            return Response(err.to_json(), content="application/json", status="401 Unauthorized")

        resp = self.token_scope_check(areq, _info)
        if resp:
            return resp

        # If redirect_uri was in the initial authorization request
        # verify that the one given here is the correct one; a missing one
        # is also a mismatch.
        if "redirect_uri" in _info and (
            "redirect_uri" not in areq or areq["redirect_uri"] != _info["redirect_uri"]
        ):
            logger.error("Redirect_uri mismatch")
            err = TokenErrorResponse(error="unauthorized_client")
            return Unauthorized(err.to_json(), content="application/json")

        try:
            _tinfo = _sdb.upgrade_to_token(areq["code"], issue_refresh=True)
        except AccessCodeUsed:
            err = TokenErrorResponse(error="invalid_grant", error_description="Access grant used")
            return Response(err.to_json(), content="application/json", status="401 Unauthorized")

        logger.debug("_tinfo: %s" % _tinfo)

        atr = AccessTokenResponse(**by_schema(AccessTokenResponse, **_tinfo))

        logger.debug("AccessTokenResponse: %s" % atr)

        return Response(atr.to_json(), content="application/json")

    def key_setup(self, local_path, vault="keys", sig=None, enc=None):
        """
        my keys
        :param local_path: The path to where the JWKs should be stored
        :param vault: Where the private key will be stored
        :param sig: Key for signature
        :param enc: Key for encryption
        :return: A URL the RP can use to download the key.
        """
        self.jwks_uri = key_export(self.baseurl, local_path, vault, self.keyjar, fqdn=self.hostname, sig=sig, enc=enc)