def decode(token, secret=None, privkey=None, pubkey=None, algorithms=['RS256','HS256']):
    """Decode a token that is either a compact JWT (version 1) or a JSON envelope (version 2).

    If ``token`` is not valid JSON it is handed to ``jwt.decode`` together with
    ``secret``; signature verification is switched off for that path while
    ``exp`` and ``iat`` checks stay on.

    Otherwise the JSON envelope is checked in three steps: the hash of the
    envelope without its ``sigs`` member must equal ``sigs._.hash``, the JWS in
    ``sigs._.sig`` must verify against ``pubkey`` and carry that same hash, and
    a non-null ``priv`` member is replaced by its decoded JSON payload.

    Returns the result of ``jwt.decode`` (version 1) or the envelope dict
    (version 2). Raises ``Exception('hash mismatch')`` / ``Exception('bad sig')``
    when the version 2 checks fail.
    """
    try:
        envelope = json.loads(token)
    except ValueError:
        # Version 1
        jwt_options = {
            # Original author's note: verification is skipped because the
            # scheme is symmetric.
            'verify_signature': False,
            'verify_exp': True,
            'verify_nbf': False,
            'verify_iat': True,
            'verify_aud': False,
        }
        return jwt.decode(token, secret, algorithms=algorithms, options=jwt_options)

    # Version 2
    unsigned = json.loads(token)
    del unsigned['sigs']
    canonical = json.dumps(unsigned, separators=(',', ':'))
    if hm(canonical, pubkey) != envelope['sigs']['_']['hash']:
        raise Exception('hash mismatch')

    # The private key dict is never used below; the call is kept so that any
    # error it raises for a bad ``privkey`` still surfaces.
    jwk.construct(privkey, "RS256").to_dict()
    verification_key = jwk.construct(pubkey, "RS256").to_dict()

    signed_hash = jws.verify(envelope['sigs']['_']['sig'], verification_key, algorithms, verify=True)
    if envelope['priv'] is not None:
        # The payload is only extracted here (verify=False), not authenticated.
        private_payload = jws.verify(envelope['priv'], verification_key, algorithms, verify=False)
        envelope['priv'] = json.loads(private_payload)

    if str(signed_hash, 'utf-8').replace('"', "") != envelope['sigs']['_']['hash']:
        raise Exception('bad sig')
    return envelope
def get_PEM_JWK(private_key):
    """
    Build both the PEM and the JWK representation of a private key.

    Args:
        private_key (str or dict): Either a dict representing a JWK, a str
            holding the textual form of such a dict, a filepath to a PEM
            file, or a PEM string.

    Returns:
        tuple: ``(pem, jwk)`` on success. If a PEM import yields a falsy
        result, ``(None, ValueError)`` is returned instead (the error is
        returned, not raised).
    """
    my_jwk = None
    my_pem = None

    # A JWK arrives either as a dict or as the string form of one.
    is_jwk_string = isinstance(private_key, str) and private_key.startswith("{")
    if is_jwk_string or isinstance(private_key, dict):
        if is_jwk_string:
            # NOTE(review): literal_eval parses Python literals, so JSON-only
            # tokens such as true/false/null would be rejected here.
            private_key = literal_eval(private_key)
        my_jwk = jwk.construct(private_key, JWT.HASH_ALGORITHM)
    else:
        # Otherwise treat the value as a PEM: a path to a file or the PEM text.
        if os.path.exists(private_key):
            # The context manager closes the file even if import_key raises.
            with open(private_key, 'r') as pem_file:
                my_pem = RSA.import_key(pem_file.read())
        else:
            private_key_bytes = bytes(private_key, 'ascii')
            my_pem = RSA.import_key(private_key_bytes)

        if not my_pem:
            # Import produced nothing usable.
            return (None, ValueError("RSA Private Key given is of the wrong type"))

    if my_jwk:
        # A JWK was provided: derive the PEM from it.
        pem_bytes = my_jwk.to_pem(JWT.PEM_FORMAT)
        my_pem = RSA.import_key(pem_bytes)
    else:
        # A PEM was provided: derive the JWK from it.
        my_jwk = jwk.construct(my_pem.export_key(), JWT.HASH_ALGORITHM)

    return (my_pem, my_jwk)
def test_construct_from_jwk_missing_alg(self):
    """jwk.construct must raise JWKError without an ``alg`` or with an unknown one."""
    jwk_without_alg = {
        "kty": "oct",
        "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037",
        "use": "sig",
        "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg",
    }

    # No "alg" member in the JWK and none passed explicitly.
    with pytest.raises(JWKError):
        jwk.construct(jwk_without_alg)

    # Explicit algorithm name that does not exist.
    with pytest.raises(JWKError):
        jwk.construct("key", algorithm="NONEXISTENT")
def key(self, key_name):
    """Fetch ``key_name`` from SSM (decrypted) and build an RS256 key from it.

    The parameter value is first tried as JSON (a JWK dict); if it is not
    valid JSON the raw value is passed to ``jwk.construct`` unchanged.
    """
    namespace = self.config("secret_manager_ssm_path", namespace="cis", default="/iam")
    parameter_name = "{}/{}".format(namespace, key_name)
    response = self.ssm_client.get_parameter(Name=parameter_name, WithDecryption=True)
    raw_value = response.get("Parameter").get("Value")

    try:
        constructed = jwk.construct(json.loads(raw_value), "RS256")
    except json.decoder.JSONDecodeError:
        constructed = jwk.construct(raw_value, "RS256")
    return constructed
def test_response_values(app, client):
    """
    Check every JWK served by the JWKS endpoint in detail.

    Each entry must carry the fixed RSA signing metadata (``alg``, ``kty``,
    ``use``, ``key_ops``), a ``kid`` known to the app, and enough material
    that the public key rebuilt from it matches one of the app's public keys.
    """
    served_keys = client.get("/.well-known/jwks").json["keys"]
    known_kids = [pair.kid for pair in app.keypairs]
    known_public_keys = [pair.public_key for pair in app.keypairs]

    fixed_fields = (
        ("alg", "RS256"),
        ("kty", "RSA"),
        ("use", "sig"),
        ("key_ops", "verify"),
    )
    for served in served_keys:
        for field, expected in fixed_fields:
            assert served[field] == expected
        assert served["kid"] in known_kids
        # Rebuild the public key from the JWK and compare its PEM form.
        rebuilt_pem = jwk.construct(served).to_pem()
        assert rebuilt_pem in known_public_keys
def validate_and_return_id_token(self, id_token, access_token):
    """
    Validate the id_token and return its claims.

    Follows the steps at
    http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

    Returns the claims decoded from ``id_token`` once ``self.validate_claims``
    has accepted them.

    Raises AuthTokenError when no key verifies the signature, when the
    signature has expired, when a claim is rejected, or when decoding fails.
    """
    client_id, client_secret = self.get_key_and_secret()

    key = self.find_valid_key(id_token)
    if not key:
        raise AuthTokenError(self, 'Signature verification failed')

    rsakey = jwk.construct(key, algorithm=ALGORITHMS.RS256)

    try:
        claims = jwt.decode(
            id_token,
            rsakey.to_pem().decode('utf-8'),
            algorithms=[ALGORITHMS.HS256, ALGORITHMS.RS256, ALGORITHMS.ES256],
            audience=client_id,
            issuer=self.id_token_issuer(),
            access_token=access_token,
            options=self.JWT_DECODE_OPTIONS,
        )
    except ExpiredSignatureError:
        raise AuthTokenError(self, 'Signature has expired')
    except JWTClaimsError as error:
        raise AuthTokenError(self, str(error))
    except JWTError:
        raise AuthTokenError(self, 'Invalid signature')

    self.validate_claims(claims)

    # The function previously fell off the end and returned None.
    return claims
def __init__(self, wallet, **kwargs):
    """Initialise a transaction for ``wallet`` from optional keyword settings.

    Keyword arguments read here: ``id``, ``format``, ``data``,
    ``file_handler``, ``file_path`` (required when ``file_handler`` is
    truthy), ``transaction``, ``target``, ``to``, ``quantity`` and ``reward``.

    Raises ArweaveTransactionException when a positive ``quantity`` is given
    without a target address.
    """
    self.jwk_data = wallet.jwk_data
    self.jwk = jwk.construct(self.jwk_data, algorithm="RS256")
    self.wallet = wallet
    self.id = kwargs.get('id', '')
    self.last_tx = wallet.get_last_transaction_id()
    self.owner = self.jwk_data.get('n')
    self.tags = []
    self.format = kwargs.get('format', 2)
    self.api_url = API_URL
    self.chunks = None

    data = kwargs.get('data', '')
    # Length of the data as passed in, before base64url encoding.
    self.data_size = len(data)
    if type(data) is bytes:
        self.data = base64url_encode(data)
    else:
        # Anything that is not exactly ``bytes`` is encoded as UTF-8 first.
        self.data = base64url_encode(data.encode('utf-8'))
    if self.data is None:
        self.data = ''

    self.file_handler = kwargs.get('file_handler', None)
    if self.file_handler:
        self.uses_uploader = True
        # With a file handler the size comes from the file on disk instead.
        self.data_size = os.stat(kwargs['file_path']).st_size
    else:
        self.uses_uploader = False

    if kwargs.get('transaction'):
        # A serialized transaction was supplied; populate from it.
        self.from_serialized_transaction(kwargs.get('transaction'))
    else:
        self.data_root = ""
        self.data_tree = []
        self.target = kwargs.get('target', '')
        self.to = kwargs.get('to', '')
        # ``to`` acts as a fallback when ``target`` was not given.
        if self.target == '' and self.to != '':
            self.target = self.to
        self.quantity = kwargs.get('quantity', '0')
        if float(self.quantity) > 0:
            if self.target == '':
                raise ArweaveTransactionException(
                    "Unable to send {} AR without specifying a target address".format(self.quantity))
            # convert to winston
            self.quantity = ar_to_winston(float(self.quantity))
        reward = kwargs.get('reward', None)
        # ``self.reward`` is only assigned here when a reward was passed in.
        if reward is not None:
            self.reward = reward
        self.signature = ''
        self.status = None
def lambda_handler(event, context):
    """Verify the signature of ``event['token']`` and return its claims.

    The token's ``kid`` header selects a public key from the module-level
    ``keys`` list. Returns the (unverified-parse) claims when the signature
    verifies, or ``False`` when the key is unknown or the signature is bad.
    """
    token = event['token']

    # The kid has to be read from the headers before any verification.
    headers = jwt.get_unverified_headers(token)
    kid = headers['kid']

    # Find the downloaded public key with the same kid.
    matching_key = next((entry for entry in keys if entry['kid'] == kid), None)
    if matching_key is None:
        print('Public key not found in jwks.json')
        return False

    public_key = jwk.construct(matching_key)

    # Split into the signed part ("header.payload") and the signature.
    message, encoded_signature = str(token).rsplit('.', 1)

    # Both the base64url decoder and verify() are given bytes, as the
    # Python 3 code paths require.
    decoded_signature = base64url_decode(encoded_signature.encode('utf-8'))

    if public_key.verify(message.encode('utf-8'), decoded_signature):
        print('Signature successfully verified')
        # The signature checked out, so the claims can be read without
        # verifying again.
        claims = jwt.get_unverified_claims(token)
        print(claims)
        return claims

    print('Signature verification failed')
    return False
def lambda_handler(event, context):
    """Verify the JWT in the ``v-cognito-user-jwt`` request header and return its claims.

    On failure a descriptive string is returned instead of a claims dict.
    On success the claims are returned with ``cognito_username`` and
    ``cognito_groups`` copied from their colon-named counterparts.
    """
    print(event)
    token = event['request']['headers']['v-cognito-user-jwt']
    kid = jwt.get_unverified_headers(token)['kid']

    # First entry of KEYS whose kid matches the token header.
    matching_key = next((entry for entry in KEYS if entry['kid'] == kid), None)
    if matching_key is None:
        print('Public key not found in jwks.json')
        return 'jwks failed validation - return Public key'

    verifier = jwk.construct(matching_key)
    signed_part, signature_b64 = str(token).rsplit('.', 1)
    signature = base64url_decode(signature_b64.encode('utf-8'))
    signature_ok = verifier.verify(signed_part.encode('utf8'), signature)
    if not signature_ok:
        print('Signature verification failed')
        return 'jwks failed validation - return Signature'

    claims = jwt.get_unverified_claims(token)
    print(claims)
    print(claims['email'])
    # Expose the colon-named Cognito claims under underscore names as well.
    for alias, source in (('cognito_username', 'cognito:username'),
                          ('cognito_groups', 'cognito:groups')):
        claims[alias] = claims[source]
    return claims
def get_username_from_token(self, token):
    """Return the ``username`` claim of ``token`` if its signature verifies.

    The user-pool keys are loaded on first use. ``None`` is returned when the
    header cannot be parsed, has no ``kid``, names an unknown key, or the
    signature does not verify.
    """
    if self.keys is None:
        (self.keys, self.keys_iss) = self.__get_userpool_keys()

    try:
        headers = jwt.get_unverified_header(token)
    except Exception:
        # An unparsable token is treated as "no user". Unlike a bare
        # ``except:`` this lets KeyboardInterrupt/SystemExit propagate.
        return None

    kid = headers.get('kid')
    if not kid:
        return None

    matching_key = next((entry for entry in self.keys if entry['kid'] == kid), None)
    if matching_key is None:
        return None

    public_key = jwk.construct(matching_key)
    message, encoded_signature = str(token).rsplit('.', 1)
    decoded_signature = base64url_decode(encoded_signature.encode('utf-8'))
    if not public_key.verify(message.encode("utf8"), decoded_signature):
        return None

    claims = jwt.get_unverified_claims(token)
    return claims["username"]
def _make_event(
    authorization="jwt",
    groups=[],
    is_expired=None,
    is_kid_match=True,
    is_signature_valid=True,
    user_id=str(uuid4()),
):
    """Build a request event whose Authorization header carries a test JWT.

    With ``authorization=None`` the event is returned with an empty header.
    Otherwise a token for ``user_id`` is signed with ``TEST_KEY``; the flags
    allow an expiry in the past, a non-matching ``kid`` header, or a corrupted
    signature. ``groups`` is added as ``cognito:groups`` when non-empty.
    """
    event = {"headers": {"Authorization": None}}
    if authorization is None:
        return event

    # One hour into the past for an expired token, otherwise into the future.
    lifetime = -3600 if is_expired else 3600
    claims = {"sub": user_id, "exp": time.time() + lifetime}
    if len(groups):
        claims["cognito:groups"] = groups

    signing_pem = jwk.construct(TEST_KEY).to_pem()
    algorithm = TEST_KEY["alg"]
    kid_header = TEST_KEY["kid"] if is_kid_match else "nope"
    token = jwt.encode(claims, signing_pem, algorithm=algorithm, headers={"kid": kid_header})

    if not is_signature_valid:
        # Appending to the last segment corrupts the signature.
        token += "nope"

    event["headers"]["Authorization"] = token
    return event
def verify_signature(self, token, pkey_data):
    """
    Check, by means of the signature, that the token has not been modified.

    Args:
        token (string): the JWT to check
        pkey_data: key data handed to ``jwk.construct``

    Raises:
        CustomException: (401, type 'signature') when the key cannot be
            constructed or the signature does not verify.
    """
    try:
        verifier = jwk.construct(pkey_data)
    except JOSEError:
        raise CustomException(status_code=401, type='signature', detail='Token inválido')

    # Split off the signature segment and decode it.
    signed_part, signature_b64 = str(token).rsplit(".", 1)
    signature = base64url_decode(signature_b64.encode("utf-8"))

    # Verify the signature over "header.payload".
    signature_ok = verifier.verify(signed_part.encode("utf8"), signature)
    if signature_ok:
        return
    raise CustomException(
        status_code=401,
        type='signature',
        detail='Token a sido Modificado, Signature inválido!')
def validate_and_return_id_token(self, id_token, access_token):
    """
    Validate the id_token and return its claims.

    Follows the steps at
    http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

    Raises AuthTokenError when no key verifies the signature, when the
    signature has expired, when the claims are invalid, or when decoding
    fails.
    """
    client_id, client_secret = self.get_key_and_secret()

    key = self.find_valid_key(id_token)
    if not key:
        # Without this guard a missing key surfaced as a TypeError from
        # ``key['alg']`` rather than as an authentication failure.
        raise AuthTokenError(self, 'Signature verification failed')

    alg = key['alg']
    rsakey = jwk.construct(key)

    try:
        claims = jwt.decode(
            id_token,
            rsakey.to_pem().decode('utf-8'),
            algorithms=[alg],
            audience=client_id,
            issuer=self.id_token_issuer(),
            access_token=access_token
        )
    except ExpiredSignatureError:
        raise AuthTokenError(self, 'Signature has expired')
    except JWTClaimsError:
        raise AuthTokenError(self, 'Invalid claims')
    except JWTError:
        raise AuthTokenError(self, 'Invalid signature')

    self.validate_claims(claims)

    return claims
def validate_token(access_token, issuer, audience, client_ids):
    """Validate ``access_token`` and return its payload.

    ``client_ids`` may be a list or a comma-separated string. The inputs are
    checked for presence, the claims are checked against ``issuer``,
    ``audience`` and the client ids, and finally the signature is verified
    with the JWK fetched for the token.

    Raises ``Exception('Invalid Token')`` when the signature does not verify;
    the helper checks may raise their own errors.
    """
    # Normalise the client ids to a list.
    if isinstance(client_ids, list):
        cid_list = client_ids
    else:
        cid_list = client_ids.split(',')

    check_presence_of(access_token, issuer, audience, cid_list)

    # Header and payload are read without verification at this point.
    header = jwt.get_unverified_header(access_token)
    payload = jwt.get_unverified_claims(access_token)

    verify_claims(payload, issuer, audience, cid_list)

    # Verify the signature over "header.payload".
    jwks_key = fetch_jwk_for(header, payload)
    key = jwk.construct(jwks_key)
    message, encoded_sig = access_token.rsplit('.', 1)
    decoded_sig = base64url_decode(encoded_sig.encode('utf-8'))
    if not key.verify(message.encode(), decoded_sig):
        raise Exception('Invalid Token')

    return payload
def parse_incomming_data(self, data):
    """Decode ``data`` as a JWT using the key that ``find_valid_key`` selects for it."""
    matching_key = self.find_valid_key(data)
    pem_text = jwk.construct(matching_key).to_pem().decode('utf-8')
    # NOTE(review): "verify" is passed through as given; confirm it is an
    # option key the JWT library actually honours.
    decode_options = {"verify": False}
    return jwt.decode(
        data,
        pem_text,
        algorithms=matching_key['alg'],
        options=decode_options,
    )
def token_decoder(token):
    """Verify the signature and expiry of ``token`` and return its claims.

    Raises a bare ``Exception`` when the signature does not verify or the
    token has expired; both cases are logged first.
    """
    set_cognito_data_global()
    kid = jwt.get_unverified_headers(token)['kid']

    # Index the pool keys by kid and pick the one named in the header.
    keys_by_kid = {entry['kid']: entry for entry in POOL_KEYS}
    verifier = jwk.construct(keys_by_kid.get(kid))

    signed_part, signature_b64 = str(token).rsplit('.', 1)
    signature = base64url_decode(signature_b64.encode('utf-8'))
    signature_ok = verifier.verify(signed_part.encode("utf8"), signature)
    if not signature_ok:
        app.logger.error('Signature verification failed')
        raise Exception
    app.logger.debug('Signature successfully verified')

    claims = jwt.get_unverified_claims(token)
    now = time.time()
    if now > claims['exp']:
        app.logger.error('Token is expired')
        raise Exception
    app.logger.debug(claims)
    return claims
def make_jwk(key, kid):
    """Wrap ``key`` as a single-entry JWKS document with the given ``kid``.

    The key is converted to its RS256 JWK dict, the ``e`` and ``n`` members
    are decoded from bytes to text, and ``use``/``kid`` are added.
    """
    entry = jwk.construct(key, "RS256").to_dict()
    for member in ("e", "n"):
        entry[member] = entry[member].decode("utf-8")
    entry["use"] = "sig"
    entry["kid"] = kid
    return {"keys": [entry]}
def _get_token(self):
    """Authenticate against the server and return ``(token, auth)``.

    The signing key is fetched from ``/drone/keys.json``, a token is requested
    with the drone id and password, and the token's signature is verified
    with the fetched key. ``auth`` is the decoded JSON payload of the token.

    Raises AuthException when the keys cannot be fetched, when the token
    request fails, or when the token's signature does not verify.
    """
    # Fetch key from server
    http_protocol = "http" if DMC_INSECURE else "https"
    self._logger.info(
        "attempting to authenticate on %s://%s" % (http_protocol, DMC_URI)
    )
    key_response = requests.get('{}://{}/drone/keys.json'.format(http_protocol, DMC_URI))
    # Same threshold as the token request below; ``>= 299`` would have
    # rejected a 299 response, which is still in the 2xx range.
    if key_response.status_code > 299:
        self._logger.error("could not fetch keys, server error")
        raise AuthException('could not access keys: {}'.format(key_response.status_code))
    key = key_response.json()

    # Authenticate
    anip_uri = DMC_ANIP_URI or "{}/drone/auth".format(DMC_URI)
    self._logger.info("fetching token from %s" % anip_uri)
    ret = requests.post("{}://{}".format(http_protocol, anip_uri), json={
        'id': self.drone_id,
        'password': self.password
    })
    if ret.status_code >= 500:
        raise AuthException('could not fetch token, server error')
    elif ret.status_code > 299:
        raise AuthException('unauthorized: {}'.format(ret))

    # Verify received token with key from server
    token = ret.json()['token']
    rsa_key = jwk.construct(key)
    message, encoded_sig = token.rsplit('.', 1)
    # Encode explicitly so the decoder and verify() receive bytes; ``str()``
    # leaves text as text on Python 3.
    decoded_sig = base64url_decode(encoded_sig.encode('utf-8'))
    if not rsa_key.verify(message.encode('utf-8'), decoded_sig):
        raise AuthException('invalid token')

    payload = message.split('.')[1]
    auth = json.loads(base64url_decode(payload.encode('utf-8')).decode('utf-8'))
    return token, auth
def validateJWTToken(token):
    """Validate an ID token's issuer, audience, expiry and signature.

    Returns ``False`` when the token does not have exactly three segments,
    when ``iss`` differs from ``settings.ID_TOKEN_ISSUER``, when the first
    ``aud`` entry differs from ``settings.CLIENT_ID``, or when ``exp`` is in
    the past. Otherwise returns the result of verifying the signature with
    the key looked up by the header's ``kid``.
    """
    token_parts = token.split('.')
    if len(token_parts) != 3:
        # Not "header.payload.signature"; nothing to validate.
        return False
    header_b64, payload_b64, signature_b64 = token_parts

    # JWT segments use the URL-safe base64 alphabet and omit padding, so every
    # segment goes through the padding helper and the URL-safe decoder.
    # Assumes incorrect_padding() restores the padding - see its definition.
    idTokenHeader = json.loads(
        base64.urlsafe_b64decode(incorrect_padding(header_b64)).decode('utf-8'))
    idTokenPayload = json.loads(
        base64.urlsafe_b64decode(incorrect_padding(payload_b64)).decode('utf-8'))

    current_time = (datetime.utcnow() - datetime(1970, 1, 1)).total_seconds()
    if idTokenPayload['iss'] != settings.ID_TOKEN_ISSUER:
        return False
    elif idTokenPayload['aud'][0] != settings.CLIENT_ID:
        return False
    elif idTokenPayload['exp'] < current_time:
        return False

    # The signature covers the first two segments exactly as transmitted.
    message = header_b64 + '.' + payload_b64
    idTokenSignature = base64.urlsafe_b64decode(incorrect_padding(signature_b64))

    keys = getKeyFromJWKUrl(idTokenHeader['kid'])
    publicKey = jwk.construct(keys)
    return publicKey.verify(message.encode('utf-8'), idTokenSignature)
def verified_claims(token: str, app_client_id: str = os.environ.get('COGNITO_CLIENT_ID'), public_keys: dict = None) -> dict:
    """
    Verify the token's signature, expiry and client id, then return its claims.

    ``public_keys`` maps ``kid`` to JWK data; when omitted the keys come from
    ``get_cognito_public_keys()``. The client id check is skipped when
    ``app_client_id`` is falsy.

    Raises RuntimeError when the signature does not verify, the token has
    expired, or its ``client_id`` claim differs from ``app_client_id``.
    """
    headers = jwt.get_unverified_headers(token)
    kid = headers['kid']

    public_keys = public_keys or get_cognito_public_keys()
    public_key = jwk.construct(public_keys[kid])

    message, encoded_signature = str(token).rsplit('.', 1)
    decoded_signature = base64url_decode(encoded_signature.encode('utf-8'))
    if not public_key.verify(message.encode("utf8"), decoded_signature):
        raise RuntimeError('Signature verification failed')

    claims = jwt.get_unverified_claims(token)

    if time.time() > claims['exp']:
        expiry = datetime.fromtimestamp(claims['exp'])
        raise RuntimeError(
            f"Token expired at {expiry}, it is now {datetime.now()}")

    if app_client_id:
        # Report the same claim that is compared. The message used to read
        # claims['aud'], so a token without ``aud`` produced a KeyError
        # instead of this error.
        token_client_id = claims.get('client_id')
        if token_client_id != app_client_id:
            raise RuntimeError(
                f"Token client_id {token_client_id} was not issued for this audience {app_client_id}"
            )

    return claims
def validate_and_return_id_token(self, id_token, access_token):
    """
    Validate the id_token and return its claims.

    Follows the steps at
    http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

    Raises AuthTokenError when no key verifies the signature, when the
    signature has expired, when the claims are invalid, or when decoding
    fails.
    """
    client_id, client_secret = self.get_key_and_secret()

    key = self.find_valid_key(id_token)
    if not key:
        # Without this guard a missing key surfaced as a TypeError from
        # ``key['alg']`` rather than as an authentication failure.
        raise AuthTokenError(self, 'Signature verification failed')

    alg = key['alg']
    rsakey = jwk.construct(key)

    try:
        claims = jwt.decode(
            id_token,
            rsakey.to_pem().decode('utf-8'),
            algorithms=[alg],
            audience=client_id,
            issuer=self.id_token_issuer(),
            access_token=access_token,
            options=self.JWT_DECODE_OPTIONS,
        )
    except ExpiredSignatureError:
        raise AuthTokenError(self, 'Signature has expired')
    except JWTClaimsError:
        raise AuthTokenError(self, 'Invalid claims')
    except JWTError:
        raise AuthTokenError(self, 'Invalid signature')

    self.validate_claims(claims)

    return claims
def validate_id_token(id_token, client_id, intuit_issuer, jwk_uri):
    """Validates ID Token returned by Intuit

    :param id_token: ID Token
    :param client_id: Client ID
    :param intuit_issuer: Intuit Issuer
    :param jwk_uri: JWK URI
    :return: True/False
    """
    id_token_parts = id_token.split('.')
    if len(id_token_parts) < 3:
        return False

    # JWT segments are base64url encoded, so the URL-safe decoder is used for
    # all three; the standard-alphabet decoder drops '-' and '_' and yields
    # corrupted bytes. JSON text is decoded as UTF-8.
    id_token_header = json.loads(
        urlsafe_b64decode(_correct_padding(id_token_parts[0])).decode('utf-8'))
    id_token_payload = json.loads(
        urlsafe_b64decode(_correct_padding(id_token_parts[1])).decode('utf-8'))
    id_token_signature = urlsafe_b64decode(
        (_correct_padding(id_token_parts[2])).encode('ascii'))

    if id_token_payload['iss'] != intuit_issuer:
        return False
    elif id_token_payload['aud'][0] != client_id:
        return False

    current_time = (datetime.utcnow() - datetime(1970, 1, 1)).total_seconds()
    if id_token_payload['exp'] < current_time:
        return False

    # The signature covers "header.payload" exactly as transmitted.
    message = id_token_parts[0] + '.' + id_token_parts[1]
    keys_dict = get_jwk(id_token_header['kid'], jwk_uri)

    public_key = jwk.construct(keys_dict)
    is_signature_valid = public_key.verify(message.encode('utf-8'), id_token_signature)
    return is_signature_valid
def token_decoder(token):
    """
    Verify the signature and expiry of a bearer token and return its claims.

    The ID token expires one hour after the user authenticates and should not
    be processed after that:
    https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-with-identity-providers.html

    :param token: token as received, passed through ``remove_barer`` first
    :return: the token's claims
    :raises UnauthorizedError: on a bad signature or an expired token
    """
    set_cognito_data_global()
    token = remove_barer(token)
    kid = jwt.get_unverified_headers(token)['kid']

    # Index the pool keys by kid and pick the one named in the header.
    keys_by_kid = {entry['kid']: entry for entry in POOL_KEYS}
    verifier = jwk.construct(keys_by_kid.get(kid))

    signed_part, signature_b64 = str(token).rsplit('.', 1)
    signature = base64url_decode(signature_b64.encode('utf-8'))
    signature_ok = verifier.verify(signed_part.encode("utf8"), signature)
    if not signature_ok:
        raise UnauthorizedError('Invalid Token')

    claims = jwt.get_unverified_claims(token)
    now = time.time()
    if now > claims['exp']:
        raise UnauthorizedError('Token in expired!')
    return claims
def __init__(self, wallet, **kwargs):
    """Initialise a transaction for ``wallet``.

    Keyword arguments read here: ``id``, ``quantity``, ``data`` (bytes or
    ASCII text), ``target``, ``to`` and ``reward``. When no reward is passed
    it is obtained from ``self.get_reward`` for the encoded data.
    """
    self.jwk_data = wallet.jwk_data
    self.jwk = jwk.construct(self.jwk_data, algorithm="RS256")

    self.id = kwargs.get('id', '')
    self.last_tx = wallet.get_last_transaction_id()
    self.owner = self.jwk_data.get('n')
    self.tags = []
    self.quantity = kwargs.get('quantity', '0')

    data = kwargs.get('data', '')
    if isinstance(data, bytes):
        self.data = base64url_encode(data)
    else:
        # Text is encoded as ASCII before the base64url step.
        self.data = base64url_encode(data.encode('ascii'))

    self.target = kwargs.get('target', '')
    self.to = kwargs.get('to', '')
    self.api_url = "https://arweave.net"

    reward = kwargs.get('reward', None)
    if reward is not None:
        self.reward = reward
    else:
        self.reward = self.get_reward(self.data)

    self.signature = ''
    self.status = None
def __validateSignatureJWS(self, token):
    """Check the signature of ``token`` against the user pool's published JWKS.

    The key set is downloaded from the pool's ``jwks.json`` on every call.
    Returns ``(True, None)`` on success, otherwise ``(False, reason)``.
    """
    pool_id = settings.AWS_CONFIG['USER_POOL_ID']
    aws_region = settings.AWS_CONFIG['REGION']
    jwks_url = 'https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json'.format(
        aws_region, pool_id)
    with urllib.request.urlopen(jwks_url) as handle:
        raw_body = handle.read()
    published_keys = json.loads(raw_body.decode('utf-8'))['keys']

    kid = jwt.get_unverified_headers(token)['kid']
    # First published key whose kid matches the token header.
    matching_key = next((entry for entry in published_keys if entry['kid'] == kid), None)
    if matching_key is None:
        return False, 'Public key not found in jwks.json'

    verifier = jwk.construct(matching_key)
    signed_part, signature_b64 = str(token).rsplit('.', 1)
    signature_bytes = base64url_decode(signature_b64.encode('utf-8'))
    if verifier.verify(signed_part.encode("utf8"), signature_bytes):
        return True, None
    return False, 'Signature verification failed'
def get_and_verify_claims(token):
    """Return the claims of ``token`` after checking signature, expiry and audience.

    Uses the module-level ``keys`` (downloaded public keys) and
    ``app_client_id``. Raises ``Exception`` with a descriptive message when
    any check fails.
    """
    # The kid must be read from the headers before verification is possible.
    kid = jwt.get_unverified_headers(token)['kid']

    # Pick the first downloaded public key with that kid.
    matching_key = next((entry for entry in keys if entry['kid'] == kid), None)
    if matching_key is None:
        raise Exception("Public key not found in jwks.json")
    verifier = jwk.construct(matching_key)

    # "header.payload" is the signed message; the last segment is the
    # base64url-encoded signature.
    signed_part, signature_b64 = str(token).rsplit('.', 1)
    signature = base64url_decode(signature_b64.encode('utf-8'))
    signature_ok = verifier.verify(signed_part.encode("utf8"), signature)
    if not signature_ok:
        raise Exception('Signature verification failed')

    # The signature checked out, so the claims can be read without verifying
    # again.
    claims = jwt.get_unverified_claims(token)

    now = time.time()
    if now > claims['exp']:
        raise Exception('Token is expired')

    # Audience check (claims['client_id'] would be the field for access tokens).
    if claims['aud'] != app_client_id:
        raise Exception('Token was not issued for this audience')

    return claims
def get_claims(logger: logging.Logger, token: str) -> Dict[str, Union[str, int]]:
    """Return the claims of ``token`` after verifying its signature and expiry.

    Raises ValueError when no key matches the token's ``kid`` or the token has
    expired, and RuntimeError when the signature does not verify.
    """
    # client_id = os.environ["COGNITO_APP_CLIENT_ID"]
    # The kid has to come from the headers before anything is verified.
    unverified_headers: Dict[str, str] = cast(Dict[str, str], jwt.get_unverified_headers(token=token))
    kid: str = unverified_headers["kid"]

    # First downloaded key whose kid matches.
    matching_key = next((entry for entry in _get_keys(logger) if kid == entry["kid"]), None)
    if matching_key is None:
        raise ValueError("Public key not found in JWK.")
    verifier = jwk.construct(key_data=matching_key)

    # "header.payload" is the signed message; the final segment is the
    # base64url-encoded signature.
    signed_part, signature_b64 = token.rsplit(".", 1)
    signature: bytes = base64url_decode(signature_b64.encode("utf-8"))

    verification_result = verifier.verify(msg=signed_part.encode("utf8"), sig=signature)
    if verification_result is False:
        raise RuntimeError("Signature verification failed.")
    logger.debug("Signature validaded.")

    # The signature checked out, so the claims can be read without verifying
    # again.
    claims = jwt.get_unverified_claims(token)

    expires_at = int(claims["exp"])
    if time.time() > expires_at:
        raise ValueError("Token expired.")
    logger.debug("Token not expired.")

    logger.debug("claims: %s", claims)
    return cast(Dict[str, Union[str, int]], claims)
def test_ssm_provider(self):
    """A key stored in SSM as a JSON JWK can be fetched through the aws-ssm provider."""
    from cis_crypto import secret

    os.environ["CIS_SECRET_MANAGER_SSM_PATH"] = "/baz"
    key_dir = "tests/fixture/"
    key_name = "fake-access-file-key"
    file_name = "{}.priv.pem".format(key_name)

    # Read the fixture through a context manager so the handle is closed.
    with open(os.path.join(key_dir, file_name), "rb") as fh:
        key_content = fh.read()

    key_construct = jwk.construct(key_content, "RS256")
    key_dict = key_construct.to_dict()

    # json.dumps cannot serialise bytes, so decode any bytes members first.
    for k, v in key_dict.items():
        if isinstance(v, bytes):
            key_dict[k] = v.decode()

    deserialized_key_dict = json.dumps(key_dict)

    client = boto3.client("ssm", region_name="us-west-2")
    client.put_parameter(
        Name="/baz/{}".format(key_name),
        Description="A secure test parameter",
        Value=deserialized_key_dict,
        Type="SecureString",
        KeyId="alias/aws/ssm",
    )

    manager = secret.Manager(provider_type="aws-ssm")
    key_material = manager.get_key("fake-access-file-key")
    assert key_material is not None
def find_key(token, keys):
    """
    Find a key from the configured keys based on the kid header of the token

    Parameters
    ----------
    token : token to search for the kid from
    keys: list of keys

    Raises
    ------
    KeyNotFoundError: raised if the token does not have a kid, or if none of
        the configured keys has that kid

    Returns
    ------
    key: found key object
    """
    unverified = jwt.get_unverified_header(token)
    kid = unverified.get("kid")
    if not kid:
        raise KeyNotFoundError("No 'kid' in token")

    for key in keys:
        if key["kid"] == kid:
            return jwk.construct(key)

    # This used to *return* the exception object, which callers would then
    # have mistaken for a key.
    raise KeyNotFoundError(
        f"Token specifies {kid} but we have {[k['kid'] for k in keys]}")
def find_valid_key(self, id_token):
    """Return the first JWKS entry whose key verifies ``id_token``'s signature.

    Falls through and returns ``None`` when no entry verifies the signature.
    """
    for candidate in self.get_jwks_keys():
        verifier = jwk.construct(candidate)
        signed_part, signature_b64 = id_token.rsplit('.', 1)
        signature = base64url_decode(signature_b64.encode('utf-8'))
        if not verifier.verify(signed_part.encode('utf-8'), signature):
            continue
        return candidate
    return None
def get_claims(token: str) -> Dict[str, Union[str, int]]:
    """Return the claims of ``token`` after checking signature, expiry and audience.

    Raises ValueError when no key matches the token's ``kid``, the token has
    expired, or the audience is not ``COGNITO_USER_POOL_CLIENT_ID``; raises
    RuntimeError when the signature does not verify.
    """
    # The kid has to come from the headers before anything is verified.
    unverified_headers: Dict[str, str] = cast(Dict[str, str], jwt.get_unverified_headers(token=token))
    kid: str = unverified_headers["kid"]

    # Walk the downloaded keys until one carries the same kid.
    verifier = None
    key_found = False
    for candidate in _get_keys():
        if kid == candidate["kid"]:
            verifier = jwk.construct(key_data=candidate)
            key_found = True
            break
    if not key_found:
        raise ValueError("Public key not found in JWK.")

    # "header.payload" is the signed message; the final segment is the
    # base64url-encoded signature.
    signed_part, signature_b64 = token.rsplit(".", 1)
    signature: bytes = base64url_decode(signature_b64.encode("utf-8"))

    if verifier.verify(msg=signed_part.encode("utf8"), sig=signature) is False:
        raise RuntimeError("Signature verification failed.")
    print("Signature validaded.")

    # The signature checked out, so the claims can be read without verifying
    # again.
    claims: Dict[str, Union[str, int]] = cast(Dict[str, Union[str, int]], jwt.get_unverified_claims(token))

    expires_at = int(claims["exp"])
    if time.time() > expires_at:
        raise ValueError("Token expired.")
    print("Token not expired.")

    # Audience check (claims['client_id'] would be the field for access tokens).
    if claims["aud"] != COGNITO_USER_POOL_CLIENT_ID:
        raise ValueError("Token was not issued for this audience.")

    return claims
def find_valid_key(self, id_token):
    """Return the first JWKS entry whose key verifies ``id_token``'s signature.

    Returns ``False`` when none of the entries verifies the signature.
    """
    for jwks_entry in self.get_jwks_keys():
        candidate_key = jwk.construct(jwks_entry)
        signing_input, sig_segment = id_token.rsplit('.', 1)
        raw_signature = base64url_decode(sig_segment.encode('utf-8'))
        matches = candidate_key.verify(signing_input.encode('utf-8'), raw_signature)
        if matches:
            return jwks_entry
    return False
def _sig_matches_keys(keys, signing_input, signature, alg):
    """Return True if any of ``keys`` verifies ``signature`` over ``signing_input``.

    Each entry is turned into a key for ``alg``. An exception raised while
    verifying with one key is ignored and the next key is tried; an exception
    raised while constructing a key propagates.
    """
    for key_data in keys:
        candidate = jwk.construct(key_data, alg)
        try:
            if candidate.verify(signing_input, signature):
                return True
        except Exception:
            # Best effort: a key that cannot verify is simply skipped.
            continue
    return False
def test_construct_from_jwk(self):
    """A complete HMAC JWK dict is turned into a ``jwk.Key`` instance."""
    complete_hmac_jwk = {
        "kty": "oct",
        "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037",
        "use": "sig",
        "alg": "HS256",
        "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg",
    }

    constructed = jwk.construct(complete_hmac_jwk)

    assert isinstance(constructed, jwk.Key)
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key_data):
    """Sign ``header.claims`` and return the complete three-segment token as text.

    Both inputs are already-encoded byte segments. Any error raised while
    constructing the key or signing is re-raised as JWSError.
    """
    signing_input = b'.'.join([encoded_header, encoded_claims])
    try:
        signer = jwk.construct(key_data, algorithm)
        raw_signature = signer.sign(signing_input)
    except Exception as exc:
        raise JWSError(exc)

    segments = [encoded_header, encoded_claims, base64url_encode(raw_signature)]
    return b'.'.join(segments).decode('utf-8')
def _sig_matches_keys(keys, signing_input, signature, alg):
    """Return True if any of ``keys`` verifies ``signature`` over ``signing_input``.

    Keys are tried in order and the search stops at the first match. Errors
    from constructing a key or from verifying propagate to the caller.
    """
    return any(
        jwk.construct(key_data, alg).verify(signing_input, signature)
        for key_data in keys
    )