class MCData:
    """Data stored on the monotonic counter server.

    Each counter is identified by a random 32-bit integer key and holds a
    ``[value, pubkey]`` pair. Every public key registered through
    :meth:`add_new` is also collected in ``known_keys``.
    """

    def __init__(self):
        # counter key -> [current value, public key of the counter's owner]
        self.ctr = {}
        self.known_keys = JWKSet()

    def add_new(self, v, pubkey):
        """Create a counter with initial value *v* owned by *pubkey*.

        A fresh random key in ``[0, 2**32)`` is drawn, re-drawing until it
        does not collide with an existing counter.

        Returns:
            The key of the new counter.
        """
        k = random.randrange(2**32)
        while k in self.ctr:
            # Re-draw on collision (this previously called the misspelled
            # ``random.randramge`` and crashed with AttributeError).
            k = random.randrange(2**32)
        self.ctr[k] = [v, pubkey]
        self.known_keys.add(pubkey)
        return k

    def increment(self, key, inc=1):
        """Add *inc* to the counter stored under *key* and return the new value.

        Raises:
            ValueError: if *inc* is negative (counters must never decrease).
            KeyError: if *key* is not a known counter.
        """
        if inc < 0:
            raise ValueError('tried to subtract')
        if key not in self.ctr:
            raise KeyError('no key')
        entry = self.ctr[key]
        entry[0] = entry[0] + inc
        return entry[0]

    def pubkey(self, key):
        """Return the public key registered for counter *key* (KeyError if unknown)."""
        return self.ctr[key][1]
def gen_keys(key_size):
    """Generate an RSA key pair of *key_size* bits and write it to disk.

    Four files are written to the current working directory:
    ``private.pem`` and ``public.pem`` (PEM-encoded private/public key), and
    ``private.json`` and ``public.json`` (the key wrapped in a JWK set,
    pretty-printed with 2-space indentation, with and without private
    parameters).

    Raises:
        ImportError: if the optional ``jwcrypto`` package is not installed.
    """
    try:
        from jwcrypto.jwk import JWK, JWKSet
    except ImportError as e:
        msg = "You have to install jwcrypto to use this function"
        print(msg)
        raise ImportError(msg) from e

    key = JWK()
    key.generate_key(generate="RSA", size=key_size)

    # PEM files: private key first, then the public half.
    pem_targets = (("private.pem", True), ("public.pem", False))
    for target, with_private in pem_targets:
        pem_bytes = key.export_to_pem(private_key=with_private, password=None)
        with open(target, "w") as out:
            out.write(pem_bytes.decode("utf8"))

    keyset = JWKSet()
    keyset.add(key)

    # JWKS files: re-serialise the export with indentation for readability.
    jwks_targets = (("private.json", True), ("public.json", False))
    for target, with_private in jwks_targets:
        exported = keyset.export(private_keys=with_private)
        pretty = json.dumps(json.loads(exported), indent=2)
        with open(target, "w") as out:
            out.write(pretty)
def dump_pem_to_jwks(in_private):
    """Convert a PEM private-key file into JWK-set JSON files.

    The key at path *in_private* is read (no password is supplied) and written
    to ``private.json`` (with private parameters) and ``public.json``
    (public parameters only) in the current working directory, each
    pretty-printed with 2-space indentation.

    Raises:
        ImportError: if the optional ``jwcrypto`` package is not installed.
    """
    try:
        from jwcrypto.jwk import JWK, JWKSet
    except ImportError as e:
        msg = "You have to install jwcrypto to use this function"
        print(msg)
        raise ImportError(msg) from e

    with open(in_private, "rb") as pem_source:
        pem_bytes = pem_source.read()

    key = JWK()
    key.import_from_pem(pem_bytes)
    keyset = JWKSet()
    keyset.add(key)

    for target, with_private in (("private.json", True), ("public.json", False)):
        exported = keyset.export(private_keys=with_private)
        pretty = json.dumps(json.loads(exported), indent=2)
        with open(target, "w") as out:
            out.write(pretty)
def listAll():
    """Return every stored public key as a JWK-set dictionary.

    Each key document's base64-encoded PEM public key is turned into a JWK
    whose ``kid`` is the document id (as a string) and whose ``alg`` is the
    document's ``algorithm``. The set is exported without private parameters
    as a dict.
    """
    repo = KeyRepository(getDb())
    jwkset = JWKSet()
    for keyDoc in repo.fetchAll():
        pem = base64.b64decode(keyDoc.publicKey)
        # Round-trip through the public JWK parameters so that ``kid`` and
        # ``alg`` go through the public constructor instead of the private
        # ``_params`` attribute, which is an implementation detail of jwcrypto
        # and not guaranteed to exist across versions.
        params = JWK.from_pem(pem).export_public(as_dict=True)
        params['kid'] = str(keyDoc.id)
        params['alg'] = keyDoc.algorithm
        jwkset.add(JWK(**params))
    return jwkset.export(private_keys=False, as_dict=True)
def test_validate_token():
    """A freshly signed RS256 token verifies and passes ``validate_token``."""
    claims_in = {'typ': 'Bearer', 'foo': 'bar', 'baz': 42}
    extra_headers = {'kid': key.key_id}
    lifetime = datetime.timedelta(minutes=5)
    token = generate_jwt(claims_in, key, 'RS256', lifetime,
                         other_headers=extra_headers)

    # Direct signature check against the single key.
    header, claims = verify_jwt(token, key, ['RS256'])
    assert header is not None
    assert claims is not None

    # The same token must also validate against a set holding that key.
    keyset = JWKSet()
    keyset.add(key)
    assert validate_token(keyset, token, clock_skew_seconds=60)
def configure(self, opts, changes):
    """Install-time setup of the OpenID Connect provider.

    Does nothing unless ``opts["openidc"]`` is ``"yes"``. Otherwise it:

    * creates ``<data_dir>/openidc`` with mode 0700 when it does not exist;
    * writes a JWK set with one RSA-2048 signing key (kid ``<ts>-sig``) and
      one RSA-2048 encryption key (kid ``<ts>-enc``) to ``openidc.key`` in
      that directory, where ``<ts>`` is the current Unix time;
    * wipes and rewrites the ``openidc`` plugin data/config (endpoint URL,
      database URL, enabled extensions, key file, signing key id, subject
      salt) and marks the plugin enabled.

    ``changes`` is not used by this method.
    """
    if opts["openidc"] != "yes":
        return

    path = os.path.join(opts["data_dir"], "openidc")
    if not os.path.exists(path):
        # ``0o700`` is valid on Python 2.6+ and Python 3; the legacy ``0700``
        # literal is a SyntaxError on Python 3.
        os.makedirs(path, 0o700)
    keyfile = os.path.join(path, "openidc.key")

    keyid = int(time.time())
    keyset = JWKSet()
    # We generate one RSA2048 signing key
    rsasig = JWK(generate="RSA", size=2048, use="sig", kid="%s-sig" % keyid)
    keyset.add(rsasig)
    # We generate one RSA2048 encryption key
    rsaenc = JWK(generate="RSA", size=2048, use="enc", kid="%s-enc" % keyid)
    keyset.add(rsaenc)
    # export() with default arguments; the provider later loads this file to
    # sign with, so it is expected to hold private key material.
    with open(keyfile, "w") as m:
        m.write(keyset.export())

    proto = "https"
    url = "%s://%s/%s/openidc/" % (proto, opts["hostname"], opts["instance"])

    # A random salt is used unless the installer supplied one.
    subject_salt = uuid.uuid4().hex
    if opts["openidc_subject_salt"]:
        subject_salt = opts["openidc_subject_salt"]

    # Add configuration data to database
    po = PluginObject(*self.pargs)
    po.name = "openidc"
    po.wipe_data()
    po.wipe_config_values()
    config = {
        "endpoint url": url,
        "database url": opts["openidc_dburi"] or opts["database_url"] % {
            "datadir": opts["data_dir"], "dbname": "openidc"},
        "enabled extensions": opts["openidc_extensions"],
        "idp key file": keyfile,
        "idp sig key id": "%s-sig" % keyid,
        "idp subject salt": subject_salt,
    }
    po.save_plugin_config(config)

    # Update global config to add login plugin
    po.is_enabled = True
    po.save_enabled_state()
def init(workdir):
    """Bootstrap SAML2 and OpenID Connect material for a quick-run instance.

    Generates a SAML2 certificate in ``<workdir>/saml2``, writes IdP metadata
    for ``http://localhost:8080/`` (validity 365 * 5 days) to
    ``<workdir>/saml2/metadata.xml``, and writes a JWK set holding one
    RSA-2048 signing key (kid ``quickstart``) to ``<workdir>/openidc.key``.
    """
    saml2_dir = os.path.join(workdir, 'saml2')

    # SAML2 is the tricky part to get right, so it is set up first.
    cert = Certificate(saml2_dir)
    cert.generate('certificate', 'ipsilon-quickrun')
    metadata_validity = timedelta(365 * 5)
    IdpMetadataGenerator('http://localhost:8080/', cert,
                         metadata_validity).output(
        os.path.join(saml2_dir, 'metadata.xml'))

    # Then OpenID Connect: a single RSA2048 signing key.
    key_path = os.path.join(workdir, 'openidc.key')
    signing_keys = JWKSet()
    signing_keys.add(JWK(generate='RSA', size=2048, use='sig',
                         kid='quickstart'))
    with open(key_path, 'w') as key_out:
        key_out.write(signing_keys.export())
def init(workdir):
    """Bootstrap SAML2 and OpenID Connect material for a quick-run instance.

    Generates a SAML2 certificate in ``<workdir>/saml2``, writes IdP metadata
    for ``http://localhost:8080/idp`` (validity 365 * 5 days) to
    ``<workdir>/saml2/metadata.xml``, and writes a JWK set holding one
    RSA-2048 signing key (kid ``quickstart``) to ``<workdir>/openidc.key``.
    """
    saml2_dir = os.path.join(workdir, 'saml2')

    # SAML2 is the tricky part to get right, so it is set up first.
    cert = Certificate(saml2_dir)
    cert.generate('certificate', 'ipsilon-quickrun')
    metadata_validity = timedelta(365 * 5)
    IdpMetadataGenerator('http://localhost:8080/idp', cert,
                         metadata_validity).output(
        os.path.join(saml2_dir, 'metadata.xml'))

    # Then OpenID Connect: a single RSA2048 signing key.
    key_path = os.path.join(workdir, 'openidc.key')
    signing_keys = JWKSet()
    signing_keys.add(JWK(generate='RSA', size=2048, use='sig',
                         kid='quickstart'))
    with open(key_path, 'w') as key_out:
        key_out.write(signing_keys.export())
#!/usr/bin/python
"""Generate the OpenID Connect key set for the Ipsilon IdP.

Writes one RSA-2048 signing key (kid ``<ts>-sig``) and one RSA-2048
encryption key (kid ``<ts>-enc``), where ``<ts>`` is the current Unix time,
as a JWK set to /var/lib/ipsilon/idp/openidc/openidc.key. The directory is
created when missing.
"""
import time
import os.path

from jwcrypto.jwk import JWK, JWKSet

KEY_DIR = '/var/lib/ipsilon/idp/openidc'

keyid = int(time.time())
keyset = JWKSet()
# Signing key first, then encryption key; both share the timestamp prefix.
for key_use in ('sig', 'enc'):
    keyset.add(JWK(generate='RSA', size=2048, use=key_use,
                   kid='%s-%s' % (keyid, key_use)))

if not os.path.exists(KEY_DIR):
    os.makedirs(KEY_DIR)
with open(os.path.join(KEY_DIR, 'openidc.key'), 'w') as m:
    m.write(keyset.export())
def jwks(self) -> str:
    """Return ``self._jwk`` wrapped in a JWK set, serialised without private parameters."""
    single_key_set = JWKSet()
    single_key_set.add(self._jwk)
    public_json = single_key_set.export(private_keys=False)
    return public_json
def configure(self, opts, changes):
    """Install-time setup of the OpenID Connect provider.

    Skipped unless ``opts['openidc']`` is ``'yes'``. Otherwise it ensures
    ``<data_dir>/openidc`` exists (created with mode 0700), writes a JWK set
    with an RSA-2048 signing key and an RSA-2048 encryption key to
    ``openidc.key``, then replaces the ``openidc`` plugin data/config
    (endpoint URL, dynamic and static database URLs, extensions, key file,
    signing kid, subject salt) and enables the plugin.

    ``changes`` is not used by this method.
    """
    if opts['openidc'] != 'yes':
        return

    path = os.path.join(opts['data_dir'], 'openidc')
    if not os.path.exists(path):
        os.makedirs(path, 0o700)
    keyfile = os.path.join(path, 'openidc.key')

    # One RSA2048 signing key and one RSA2048 encryption key; both kids are
    # prefixed with the current Unix time.
    keyid = int(time.time())
    sig_kid = '%s-sig' % keyid
    keyset = JWKSet()
    for key_use, kid in (('sig', sig_kid), ('enc', '%s-enc' % keyid)):
        keyset.add(JWK(generate='RSA', size=2048, use=key_use, kid=kid))
    with open(keyfile, 'w') as m:
        m.write(keyset.export())

    proto = 'https'
    url = '%s://%s%s/openidc/' % (proto, opts['hostname'],
                                  opts['instanceurl'])

    # Installer-provided salt wins; otherwise a random one is used.
    generated_salt = uuid.uuid4().hex
    subject_salt = opts['openidc_subject_salt'] or generated_salt

    def db_url(explicit_opt, dbname):
        # Explicit URI if given, else the templated global database URL.
        return opts[explicit_opt] or opts['database_url'] % {
            'datadir': opts['data_dir'],
            'dbname': dbname
        }

    # Add configuration data to database
    po = PluginObject(*self.pargs)
    po.name = 'openidc'
    po.wipe_data()
    po.wipe_config_values()
    config = {
        'endpoint url': url,
        'database url': db_url('openidc_dburi', 'openidc'),
        'static database url': db_url('openidc_static_dburi',
                                      'openidc.static'),
        'enabled extensions': opts['openidc_extensions'],
        'idp key file': keyfile,
        'idp sig key id': sig_kid,
        'idp subject salt': subject_salt
    }
    po.save_plugin_config(config)

    # Update global config to add login plugin
    po.is_enabled = True
    po.save_enabled_state()
class IdpProvider(ProviderBase):
    """OpenID Connect identity provider plugin.

    Declares the provider's configuration options. ``init_idp`` loads the
    JWK key set and the (dynamic + static) client datastores. The
    OpenID Connect issuer relation is registered with WebFinger on enable
    and unregistered on disable.
    """

    def __init__(self, *pargs):
        super(IdpProvider, self).__init__('openidc', 'OpenID Connect',
                                          'openidc', *pargs)
        self.mapping = InfoMapping()
        self.keyset = None
        self.admin = None
        self.page = None
        self.datastore = None
        self.server = None
        self.basepath = None
        self.extensions = LoadExtensions()
        self.description = """ Provides OpenID Connect authentication infrastructure. """

        self.new_config(
            self.name,
            pconfig.String(
                'database url',
                'Database URL for OpenID Connect storage',
                'openidc.sqlite'),
            pconfig.String(
                'static database url',
                'Database URL for OpenID Connect static client configuration',
                'openidc.static.sqlite'),
            pconfig.Choice(
                'enabled extensions',
                'Choose the extensions to enable',
                self.extensions.available().keys()),
            pconfig.String(
                'endpoint url',
                'The Absolute URL of the OpenID Connect provider',
                'http://localhost:8080/openidc/'),
            pconfig.String(
                'documentation url',
                'The Absolute URL of the OpenID Connect documentation',
                'https://ipsilonproject.org/doc/openidc/'),
            pconfig.String(
                'policy url',
                'The Absolute URL of the OpenID Connect policy',
                'http://www.example.com/'),
            pconfig.String(
                'tos url',
                'The Absolute URL of the OpenID Connect terms of service',
                'http://www.example.com/'),
            pconfig.String(
                'idp key file',
                'The file where the OpenIDC keyset is stored.',
                'openidc.key'),
            pconfig.String(
                'idp sig key id',
                'The key to use for signing.',
                ''),
            pconfig.String(
                'idp subject salt',
                'The salt used for pairwise subjects.',
                None),
            pconfig.Condition(
                'allow dynamic client registration',
                'Allow Dynamic Client registrations for Relying Parties',
                True),
            pconfig.MappingList(
                'default attribute mapping',
                'Defines how to map attributes',
                [['*', '*']]),
            pconfig.ComplexList(
                'default allowed attributes',
                'Defines a list of allowed attributes, applied after mapping',
                ['*']),
        )

    def _url_config(self, name):
        """Return the config value *name*, guaranteed to end with '/'."""
        url = self.get_config_value(name)
        if url.endswith('/'):
            return url
        return url + '/'

    @property
    def endpoint_url(self):
        return self._url_config('endpoint url')

    @property
    def documentation_url(self):
        return self._url_config('documentation url')

    @property
    def policy_url(self):
        return self._url_config('policy url')

    @property
    def tos_url(self):
        return self._url_config('tos url')

    @property
    def enabled_extensions(self):
        return self.get_config_value('enabled extensions')

    @property
    def idp_key_file(self):
        return self.get_config_value('idp key file')

    @property
    def idp_sig_key_id(self):
        return self.get_config_value('idp sig key id')

    @property
    def idp_subject_salt(self):
        return self.get_config_value('idp subject salt')

    @property
    def allow_dynamic_client_registration(self):
        return self.get_config_value('allow dynamic client registration')

    @property
    def default_attribute_mapping(self):
        return self.get_config_value('default attribute mapping')

    @property
    def default_allowed_attributes(self):
        return self.get_config_value('default allowed attributes')

    @property
    def supported_scopes(self):
        """'openid', the standard OIDC claim scopes, then extension scopes."""
        supported = ['openid']
        # Default scopes used in OpenID Connect claims
        supported.extend(['profile', 'email', 'address', 'phone'])
        for ext in self.extensions.available().values():
            supported.extend(ext.get_scopes())
        return supported

    def get_tree(self, site):
        self.page = OpenIDC(site, self)
        self.admin = OpenIDCAdminPage(site, self)
        return self.page

    def used_datastores(self):
        return [self.datastore, self.datastore.static_store]

    def init_idp(self):
        """Load the JWK key set from the key file and open the datastores."""
        self.keyset = JWKSet()
        with open(self.idp_key_file, 'r') as keyfile:
            loaded_keys = json.load(keyfile)
        for key in loaded_keys['keys']:
            self.keyset.add(JWK(**key))

        static_store = OpenIDCStaticStore(
            self.get_config_value('static database url'))
        self.datastore = OpenIDCStore(self.get_config_value('database url'),
                                      static_store)

    def openid_connect_issuer_wf_rel(self, resource):
        """WebFinger handler advertising this provider as the OIDC issuer."""
        link = {
            'rel': 'http://openid.net/specs/connect/1.0/issuer',
            'href': self.endpoint_url
        }
        return {'links': [link]}

    def on_enable(self):
        super(IdpProvider, self).on_enable()
        self.init_idp()
        self.extensions.enable(self._config['enabled extensions'].get_value(),
                               self)
        self._root.webfinger.register_rel(
            'http://openid.net/specs/connect/1.0/issuer',
            self.openid_connect_issuer_wf_rel)

    def on_disable(self):
        # Chain to the parent's *disable* hook (this previously called
        # on_enable by mistake).
        super(IdpProvider, self).on_disable()
        self._root.webfinger.unregister_rel(
            'http://openid.net/specs/connect/1.0/issuer')

    def get_client_display_name(self, clientid):
        return self.datastore.getClient(clientid)['client_name']

    def consent_to_display(self, consentdata):
        """Build human-readable 'Scopes: ...' / 'Claims: ...' consent lines."""
        d = []
        if len(consentdata['scopes']) > 0:
            scopes = []
            for e in self.extensions.available().values():
                data = e.get_display_data(consentdata['scopes'])
                if len(data) > 0:
                    scopes.append(e.get_display_name())
            d.append('Scopes: %s' % ', '.join(sorted(scopes)))
        if len(consentdata['claims']) > 0:
            d.append('Claims: %s' % ', '.join(
                [self.mapping.display_name(x) for x in consentdata['claims']]))
        return d

    def revoke_consent(self, user, clientid):
        return self.datastore.revokeConsent(user, clientid)

    def on_reconfigure(self):
        super(IdpProvider, self).on_reconfigure()
        self.init_idp()
        self.extensions.enable(self._config['enabled extensions'].get_value(),
                               self)
def start_authz(self, arguments):
    """Begin processing an OpenID Connect authorization request.

    Collects the authorization parameters from *arguments*. They may be
    overridden by a JWT request object given inline as ``request`` or
    fetched from ``request_uri``. The method then validates them against the
    registered client, expands requested scopes (and extension claims) into
    userinfo claim requests, and stores the request in the transaction. It
    finally redirects the user either to the login page (authentication
    needed or ``prompt=login``) or straight to the ``Continue`` endpoint.

    Returns:
        The value of ``self._respond_error`` (or of the authz stack check)
        when the request is rejected.

    Raises:
        InvalidRequest: if the redirect_uri is not registered for the client.
        cherrypy.HTTPRedirect: on every path that proceeds with the request.
    """
    request_data = {
        'scope': [],
        'response_type': [],
        'client_id': None,
        'redirect_uri': None,
        'state': None,
        'response_mode': None,
        'nonce': None,
        'display': None,
        'prompt': [],
        'max_age': None,
        'ui_locales': None,
        'id_token_hint': None,
        'login_hint': None,
        'acr_values': None,
        'claims': '{}'
    }

    # Get the request
    # Step 1: get the get query arguments
    for data in request_data.keys():
        if arguments.get(data, None):
            request_data[data] = arguments[data]

    # This is a workaround for python not understanding the splits we
    # do later
    if request_data['prompt'] == []:
        request_data['prompt'] = None

    for required_arg in ['scope', 'response_type', 'client_id']:
        if request_data[required_arg] is None or \
                len(request_data[required_arg]) == 0:
            return self._respond_error(
                request_data, 'invalid_request',
                'missing required argument %s' % required_arg)

    client = self.cfg.datastore.getClient(request_data['client_id'])
    if not client:
        return self._respond_error(request_data, 'unauthorized_client',
                                   'Unknown client ID')

    request_data['response_type'] = request_data.get(
        'response_type', '').split(' ')
    for rtype in request_data['response_type']:
        if rtype not in ['id_token', 'token', 'code']:
            return self._respond_error(
                request_data, 'unsupported_response_type',
                'response type %s is not supported' % rtype)

    if request_data['response_type'] != ['code'] and \
            not request_data['nonce']:
        return self._respond_error(request_data, 'invalid_request',
                                   'nonce missing in non-code flow')

    # Step 2: get any provided request or request_uri
    if 'request' in arguments or 'request_uri' in arguments:
        # This is a JWT-encoded request
        if 'request' in arguments and 'request_uri' in arguments:
            return self._respond_error(
                request_data, 'invalid_request',
                'both request and request_uri ' + 'provided')

        if 'request' in arguments:
            jwt_object = arguments['request']
        else:
            try:
                # FIXME: MAY cache this at client registration time and
                # cache permanently until client registration is changed.
                # NOTE(review): this fetches a client-supplied URL with no
                # timeout or host restriction (potential SSRF / hang).
                jwt_object = requests.get(arguments['request_uri']).text
            except Exception as ex:  # pylint: disable=broad-except
                self.debug('Unable to get request: %s' % ex)
                return self._respond_error(request_data, 'invalid_request',
                                           'unable to parse request_uri')

        jwt_request = None
        try:
            # FIXME: Implement decryption
            decoded = JWT(jwt=jwt_object)
            signing_alg = client['request_object_signing_alg']
            if signing_alg != 'none':
                # Client told us we need to check signature
                if decoded.token.jose_header['alg'] != signing_alg:
                    raise Exception('Invalid algorithm used: %s' %
                                    decoded.token.jose_header['alg'])

            if signing_alg == 'none':
                jwt_request = json.loads(decoded.token.objects['payload'])
            else:
                if client['jwks']:
                    # Inline key set registered with the client (this used
                    # to read the misspelled 'jkws' key and always fail).
                    keys = json.loads(client['jwks'])
                else:
                    keys = requests.get(client['jwks_uri']).json()
                keyset = JWKSet()
                for key in keys['keys']:
                    keyset.add(JWK(**key))
                key = keyset.get_key(decoded.token.jose_header['kid'])
                decoded = JWT(jwt=jwt_object, key=key)
                jwt_request = json.loads(decoded.claims)
        except Exception as ex:  # pylint: disable=broad-except
            self.debug('Unable to parse request: %s' % ex)
            return self._respond_error(request_data, 'invalid_request',
                                       'unable to parse request')

        # Values in the request object must agree with the query string.
        if 'response_type' in jwt_request:
            jwt_request['response_type'] = \
                jwt_request['response_type'].split(' ')
            if jwt_request['response_type'] != \
                    request_data['response_type']:
                return self._respond_error(request_data, 'invalid_request',
                                           'response_type does not match')

        if 'client_id' in jwt_request:
            if jwt_request['client_id'] != request_data['client_id']:
                return self._respond_error(request_data, 'invalid_request',
                                           'client_id does not match')

        for data in request_data.keys():
            if data in jwt_request:
                request_data[data] = jwt_request[data]

    # Split these options since they are space-separated lists
    for to_split in ['prompt', 'ui_locales', 'acr_values', 'scope']:
        if request_data[to_split] is not None:
            # We know better than pylint in this regard
            # pylint: disable=no-member
            request_data[to_split] = request_data[to_split].split(' ')
        else:
            request_data[to_split] = []

    # Start checking the request
    if request_data['redirect_uri'] is None:
        if len(client['redirect_uris']) != 1:
            return self._respond_error(request_data, 'invalid_request',
                                       'missing redirect_uri')
        else:
            request_data['redirect_uri'] = client['redirect_uris'][0]

    for scope in request_data['scope']:
        if scope not in self.cfg.supported_scopes:
            return self._respond_error(
                request_data, 'invalid_scope',
                'unknown scope %s requested' % scope)

    for response_type in request_data['response_type']:
        if response_type not in ['code', 'id_token', 'token']:
            return self._respond_error(
                request_data, 'unsupported_response_type',
                'response_type %s is unknown' % response_type)

    if request_data['redirect_uri'] not in client['redirect_uris']:
        raise InvalidRequest('Invalid redirect_uri')

    # Build the "claims" values from scopes
    # From the query string "claims" is a JSON string, but a request object
    # may already carry it as a decoded JSON object.
    if not isinstance(request_data['claims'], dict):
        try:
            request_data['claims'] = json.loads(request_data['claims'])
        except Exception as ex:  # pylint: disable=broad-except
            return self._respond_error(request_data, 'invalid_request',
                                       'claims malformed: %s' % ex)
    if not isinstance(request_data['claims'], dict):
        return self._respond_error(request_data, 'invalid_request',
                                   'claims malformed: not a JSON object')

    claims = request_data['claims']
    if 'userinfo' not in claims:
        claims['userinfo'] = {}
    if 'id_token' not in claims:
        claims['id_token'] = {}
    userinfo_claims = claims['userinfo']

    scopes_to_claim = {
        'profile': [
            'name', 'family_name', 'given_name', 'middle_name', 'nickname',
            'preferred_username', 'profile', 'picture', 'website', 'gender',
            'birthdate', 'zoneinfo', 'locale', 'updated_at'
        ],
        'email': ['email', 'email_verified'],
        'address': ['address'],
        'phone': ['phone_number', 'phone_number_verified']
    }
    for scope, scope_claims in scopes_to_claim.items():
        if scope in request_data['scope']:
            for claim in scope_claims:
                # Keep any explicit request for this claim (e.g.
                # {"essential": true}); only add a plain request if absent.
                userinfo_claims.setdefault(claim, None)

    # Add claims from extensions
    for n, e in self.cfg.extensions.available().items():
        data = e.get_claims(request_data['scope'])
        self.debug('%s returned %s' % (n, repr(data)))
        for claim in data:
            userinfo_claims.setdefault(claim, None)

    # Store data so we can continue with the request
    us = UserSession()
    user = us.get_user()

    returl = '%s/%s/Continue?%s' % (self.basepath, URLROOT,
                                    self.trans.get_GET_arg())
    data = {
        'login_target': client.get('client_name', None),
        'login_return': returl,
        'openidc_stage': 'continue',
        'openidc_request': json.dumps(request_data)
    }

    if request_data['login_hint']:
        data['login_username'] = request_data['login_hint']

    if not data['login_target']:
        data['login_target'] = get_url_hostpart(
            request_data['redirect_uri'])

    # Decide what to do with the request
    if request_data['max_age'] is None:
        request_data['max_age'] = client.get('default_max_age', None)

    needs_auth = True
    if not user.is_anonymous:
        if request_data['max_age'] in [None, 0]:
            needs_auth = False
        else:
            # Re-authenticate once the last authentication is older than
            # max_age seconds.
            auth_time = us.get_user_attrs()['_auth_time']
            needs_auth = ((int(auth_time) + int(request_data['max_age'])) <=
                          int(time.time()))

    if needs_auth or 'login' in request_data['prompt']:
        if 'none' in request_data['prompt']:
            # We were asked not to provide a UI. Answer with false.
            return self._respond_error(request_data, 'login_required',
                                       'user interface required')

        # Either the user wasn't logged in, or we were explicitly
        # asked to re-auth them. Let's do so!
        us.logout(user)

        # Let the user go to auth
        self.trans.store(data)
        redirect = '%s/login?%s' % (self.basepath,
                                    self.trans.get_GET_arg())
        self.debug('Redirecting: %s' % redirect)
        raise cherrypy.HTTPRedirect(redirect)

    # Return error if authz check fails
    authz_check_res = self._authz_stack_check(request_data, client,
                                              user.name,
                                              us.get_user_attrs())
    if authz_check_res:
        return authz_check_res

    self.trans.store(data)
    # The user was already signed on, and no request to re-assert its
    # identity. Let's forward directly to /Continue/
    self.debug('Redirecting: %s' % returl)
    raise cherrypy.HTTPRedirect(returl)
class IdpProvider(ProviderBase):
    """OpenID Connect identity provider plugin.

    Declares the provider's configuration options. ``init_idp`` loads the
    JWK key set and the client datastore. The OpenID Connect issuer relation
    is registered with WebFinger on enable and unregistered on disable.
    """

    def __init__(self, *pargs):
        super(IdpProvider, self).__init__("openidc", "openidc", *pargs)
        self.mapping = InfoMapping()
        self.keyset = None
        self.page = None
        self.datastore = None
        self.server = None
        self.basepath = None
        self.extensions = LoadExtensions()
        self.description = """ Provides OpenID Connect authentication infrastructure. """

        self.new_config(
            self.name,
            pconfig.String("database url", "Database URL for OpenID Connect storage", "openidc.sqlite"),
            pconfig.Choice("enabled extensions", "Choose the extensions to enable", self.extensions.available().keys()),
            pconfig.String(
                "endpoint url", "The Absolute URL of the OpenID Connect provider", "http://localhost:8080/idp/openidc/"
            ),
            pconfig.String(
                "documentation url",
                "The Absolute URL of the OpenID Connect documentation",
                "https://ipsilonproject.org/doc/openidc/",
            ),
            pconfig.String("policy url", "The Absolute URL of the OpenID Connect policy", "http://www.example.com/"),
            pconfig.String(
                "tos url", "The Absolute URL of the OpenID Connect terms of service", "http://www.example.com/"
            ),
            pconfig.String("idp key file", "The file where the OpenIDC keyset is stored.", "openidc.key"),
            pconfig.String("idp sig key id", "The key to use for signing.", ""),
            pconfig.String("idp subject salt", "The salt used for pairwise subjects.", None),
            pconfig.MappingList("default attribute mapping", "Defines how to map attributes", [["*", "*"]]),
            pconfig.ComplexList(
                "default allowed attributes", "Defines a list of allowed attributes, applied after mapping", ["*"]
            ),
        )

    def _url_config(self, name):
        """Return the config value *name*, guaranteed to end with '/'."""
        url = self.get_config_value(name)
        if url.endswith("/"):
            return url
        return url + "/"

    @property
    def endpoint_url(self):
        return self._url_config("endpoint url")

    @property
    def documentation_url(self):
        return self._url_config("documentation url")

    @property
    def policy_url(self):
        return self._url_config("policy url")

    @property
    def tos_url(self):
        return self._url_config("tos url")

    @property
    def enabled_extensions(self):
        return self.get_config_value("enabled extensions")

    @property
    def idp_key_file(self):
        return self.get_config_value("idp key file")

    @property
    def idp_sig_key_id(self):
        return self.get_config_value("idp sig key id")

    @property
    def idp_subject_salt(self):
        return self.get_config_value("idp subject salt")

    @property
    def default_attribute_mapping(self):
        return self.get_config_value("default attribute mapping")

    @property
    def default_allowed_attributes(self):
        return self.get_config_value("default allowed attributes")

    @property
    def supported_scopes(self):
        """'openid', the standard OIDC claim scopes, then extension scopes."""
        supported = ["openid"]
        # Default scopes used in OpenID Connect claims
        supported.extend(["profile", "email", "address", "phone"])
        for ext in self.extensions.available().values():
            supported.extend(ext.get_scopes())
        return supported

    def get_tree(self, site):
        self.page = OpenIDC(site, self)
        # self.admin = AdminPage(site, self)
        return self.page

    def used_datastores(self):
        return [self.datastore]

    def init_idp(self):
        """Load the JWK key set from the key file and open the datastore."""
        self.keyset = JWKSet()
        with open(self.idp_key_file, "r") as keyfile:
            loaded_keys = json.load(keyfile)
        for key in loaded_keys["keys"]:
            self.keyset.add(JWK(**key))

        self.datastore = OpenIDCStore(self.get_config_value("database url"))

    def openid_connect_issuer_wf_rel(self, resource):
        """WebFinger handler advertising this provider as the OIDC issuer."""
        link = {"rel": "http://openid.net/specs/connect/1.0/issuer", "href": self.endpoint_url}
        return {"links": [link]}

    def on_enable(self):
        super(IdpProvider, self).on_enable()
        self.init_idp()
        self.extensions.enable(self._config["enabled extensions"].get_value(), self)
        self._root.webfinger.register_rel(
            "http://openid.net/specs/connect/1.0/issuer", self.openid_connect_issuer_wf_rel
        )

    def on_disable(self):
        # Chain to the parent's *disable* hook (this previously called
        # on_enable by mistake).
        super(IdpProvider, self).on_disable()
        self._root.webfinger.unregister_rel("http://openid.net/specs/connect/1.0/issuer")
def oidc_provider_jwkset():
    """Return a JWK set holding one freshly generated 512-bit RSA key.

    NOTE(review): 512 bits is far too weak for real use; presumably acceptable
    only because this looks like test scaffolding — confirm.
    """
    rsa_key = JWK.generate(kty='RSA', size=512)
    single_key_set = JWKSet()
    single_key_set.add(rsa_key)
    return single_key_set
class JWTValidator:
    """Verify JWTs against keys fetched from one or more JWKS endpoints.

    Keys are cached in a ``JWKSet``. An unknown ``kid`` triggers a refresh
    from the configured URLs, and :meth:`poll` refreshes periodically. Use
    as an async context manager so the lazily-created HTTP session is closed.
    """

    def __init__(self, jwks_urls: "Optional[Union[str, Collection[str]]]" = None):
        from jwcrypto.jwk import JWKSet

        # A single URL string, a collection of URLs, or None (no endpoints).
        self.jwks_urls = jwks_urls
        self.keys = JWKSet()
        self.session = None  # type: Optional[ClientSession]

    async def poll(self, poll_interval: timedelta = DEFAULT_POLL_INTERVAL):
        """
        Periodically check for new keys. This coroutine will NEVER terminate naturally, so it should
        not be awaited.
        """
        while True:
            await sleep(poll_interval.total_seconds())
            LOG.debug("JWKS poller polling for new keys...")
            await shield(self._load_new_keys())
            LOG.debug("...JWKS poll complete.")

    async def decode_claims(self, token: str) -> "JWTClaims":
        """Verify *token* against the cached keys and return its decoded claims.

        The token header's ``kid`` is looked up first so that an unknown key
        triggers a JWKS refresh before verification.
        """
        from jwcrypto.jwt import JWT

        LOG.debug("Verifying token: %r", token)
        unverified = JWT(jwt=token)
        kid = unverified.token.jose_header["kid"]
        await self.get_key(kid)
        verified = JWT(jwt=token, key=self.keys)
        return json.loads(verified.claims)

    async def get_key(self, kid: str) -> "Optional[JWK]":
        """
        Retrieve a key for a given ``kid``. If the key could not be found, the keystore is
        refreshed, and if a key is discovered through that process, it is returned.

        May return ``None`` if even after refreshing keys from the remote JWKS source, a key for
        the given ``kid`` could not be found.
        """
        key = self.keys.get_key(kid)
        if key is None:
            await shield(self._load_new_keys())
            key = self.keys.get_key(kid)
        return key

    def export_all_keys(self) -> "str":
        """
        Return a JSON string that formats all public keys currently in the store as JWKS.
        """
        return self.keys.export(private_keys=False)

    def _configured_urls(self):
        """Normalise ``jwks_urls`` to an iterable of URLs.

        A bare string is one URL (iterating it directly would walk its
        characters) and ``None`` means no endpoints.
        """
        urls = self.jwks_urls
        if urls is None:
            return ()
        if isinstance(urls, str):
            return (urls,)
        return urls

    async def _load_new_keys(self):
        from jwcrypto.jwk import JWK

        for url in self._configured_urls():
            try:
                if self.session is None:
                    self.session = ClientSession()
                async with self.session.get(url, allow_redirects=False) as response:
                    jwks_json = await response.json()

                # ``JWKSet.import_keyset`` suffers from a few critical flaws that make it
                # unusable for us:
                # 1) ``import_keyset`` internally adds keys to a set, which is semantically
                #    correct. However, because JWK has no __eq__ or __hash__ implementation,
                #    EVERY key is repeatedly appended to the set rather than duplicates
                #    getting filtered out.
                # 2) The previous point necessitates that we pre-process the data to filter out
                #    keys that we do not wish to add, thereby requiring us to parse the JSON
                #    and read the payload. ``import_keyset`` expects its argument as a serialized
                #    JSON string, which it promptly parses back into a data structure.
                #
                # We process JWKS endpoints ourselves and selectively add keys directly to the
                # implementation. We are NOT reimplementing ``import_keyset``'s functionality of
                # carrying additional non-``keys`` fields into the ``JWKSet`` object. We are
                # generating JWKS data ourselves, and always generate data that contains only the
                # single top-level property of ``keys`` so this has no impact on us.
                jwks_keys = jwks_json.get("keys")
                if jwks_keys is None:
                    LOG.warning(
                        'The JWKS endpoint did not return a "keys" property, so no new '
                        "keys were added. This will be retried")
                    continue

                existing_kids = {k.key_id for k in self.keys}
                for jwk_dict in jwks_keys:
                    kid = jwk_dict.get("kid")
                    if kid is None:
                        LOG.warning(
                            'The JWKS endpoint contained a key without a "kid" field. '
                            "It will be dropped.")
                    elif kid in existing_kids:
                        LOG.debug(
                            "We already know about kid %s, so the new value will be "
                            "ignored.",
                            kid,
                        )
                    else:
                        jwk = None
                        try:
                            jwk = JWK(**jwk_dict)
                        except Exception:  # noqa
                            LOG.exception("The JWK identified by %s could not be parsed.", kid)
                        if jwk is not None:
                            try:
                                self.keys.add(jwk)
                            except Exception:  # noqa
                                LOG.exception("The JWK identified by %s could not be added.", kid)
                            else:
                                # Track it so a repeated kid later in the same
                                # response is skipped as well.
                                existing_kids.add(kid)
            except Exception as ex:  # noqa
                # Do NOT log these with full stack traces because they're actually fairly common,
                # particularly at startup when user-service has yet to start. Merely logging the
                # text of the exception without a scary stack trace is sufficient.
                LOG.warning("Error when checking url %r for new keys: %s", url, ex)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session is not None:
            await self.session.close()
            # Drop the closed session so a later refresh opens a fresh one.
            self.session = None