def setUp(self):
    """Create the fixtures shared by the authenticator tests.

    Builds a key set holding one EC key (curve ``ecc.P256``) and one 1024-bit
    RSA key, tagged with ``self._ec_kid`` and ``self._rsa_kid``. It then wires
    a mocked JWKS supplier that returns this set into a ``tokens.Authenticator``
    with an empty issuer-to-provider map. Finally it prepares default JWT claims
    that expire ten seconds from now.
    """
    ec_key = jwk.ECKey(use=u"sig").load_key(ecc.P256)
    ec_key.kid = self._ec_kid
    rsa_jwk = jwk.RSAKey(use=u"sig").load_key(PublicKey.RSA.generate(1024))
    rsa_jwk.kid = self._rsa_kid

    key_set = jwk.KEYS()
    for signing_key in (ec_key, rsa_jwk):
        key_set.append(signing_key)

    self._issuers_to_provider_ids = {}
    self._jwks_supplier = mock.MagicMock()
    self._authenticator = tokens.Authenticator(
        self._issuers_to_provider_ids, self._jwks_supplier)
    self._jwks = key_set
    # Every supply() call on the mocked supplier hands back the key set above.
    self._jwks_supplier.supply.return_value = self._jwks

    self._method_info = mock.MagicMock()
    self._service_name = u"service.name.com"
    self._jwt_claims = {
        u"aud": [u"first.com", u"second.com"],
        u"email": u"*****@*****.**",
        u"exp": int(time.time()) + 10,
        u"iss": u"https://issuer.com",
        u"sub": u"subject-id",
    }
def _output_public_keys(self, jwk_key, add_previous, strip_prefix):
    """
    Serialize the public JWT signing key set, log it, and print it as YAML.

    Args:
        jwk_key: key appended (last) to the published key set.
        add_previous: when truthy, ``self._add_previous_public_keys`` is
            called first to seed the set with earlier keys.
        strip_prefix: when truthy, the setting is named
            ``JWT_PUBLIC_SIGNING_JWK_SET``; otherwise ``COMMON_`` is prepended.

    Returns:
        A one-entry dict mapping the setting name to the serialized JWK set.
    """
    key_set = jwk.KEYS()
    if add_previous:
        self._add_previous_public_keys(key_set)
    key_set.append(jwk_key)
    jwks_json = key_set.dump_jwks()

    setting_name = '{}JWT_PUBLIC_SIGNING_JWK_SET'.format(
        '' if strip_prefix else 'COMMON_')
    log.info('New JWT_PUBLIC_SIGNING_JWK_SET: %s.', jwks_json)

    yaml_preamble = [
        " ",
        " ",
        " *** YAML to share with ALL IDAs ***",
        " ",
        " # The following is the string representation of a JSON Web Key Set (JWK set)",
        " # containing all active public keys for verifying JWT signatures.",
        " # See https://github.com/edx/edx-platform/blob/master/openedx/core/djangoapps/oauth_dispatch/"
        "docs/decisions/0008-use-asymmetric-jwts.rst",
        " ",
    ]
    for preamble_line in yaml_preamble:
        print(preamble_line)
    # The serialized set is wrapped in single quotes so YAML reads it as one string.
    print(" {}: '{}'".format(setting_name, jwks_json))
    return {setting_name: jwks_json}
def _retrieve_jwks():
    """Retrieve the JWKS from the given jwks_uri when cache misses.

    Resolves the ``jwks_uri`` for ``issuer`` (a name from the enclosing scope)
    through ``self._key_uri_supplier`` and fetches it over HTTP.

    Returns:
      The list of key objects parsed from the response. This is either the
      entries of a JWKS document (a JSON object with a ``keys`` member) or
      whatever ``_extract_x509_certificates`` builds from a mapping of key id
      to X.509 certificate.

    Raises:
      UnauthenticatedException: if no ``jwks_uri`` is known for the issuer,
        or if fetching it fails. Fetch failures include connection errors,
        timeouts, non-2xx statuses and non-JSON bodies.
    """
    # Upper bound (seconds) for the key fetch. requests waits forever by
    # default, which would let an unresponsive key server hang authentication.
    timeout_seconds = 30

    jwks_uri = self._key_uri_supplier.supply(issuer)
    if not jwks_uri:
        raise UnauthenticatedException(u"Cannot find the `jwks_uri` for issuer "
                                       u"%s: either the issuer is unknown or "
                                       u"the OpenID discovery failed" % issuer)
    try:
        response = requests.get(jwks_uri, timeout=timeout_seconds)
        # An error page (4xx/5xx) must not be mistaken for key material, even
        # when its body happens to be JSON.
        response.raise_for_status()
        json_response = response.json()
    except Exception as exception:
        message = u"Cannot retrieve valid verification keys from the `jwks_uri`"
        raise UnauthenticatedException(message, exception)

    if u"keys" in json_response:
        # De-serialize the JSON as a JWKS object.
        jwks_keys = jwk.KEYS()
        jwks_keys.load_jwks(response.text)
        return jwks_keys._keys

    # The JSON is a dictionary mapping from key id to X.509 certificates.
    # Thus we extract the public key from the X.509 certificates and
    # construct a JWKS object.
    return _extract_x509_certificates(json_response)
def test_supply_cached_jwks(self):
    """A fetched JWKS is served from cache until its entry expires.

    The HTTP endpoint starts serving a second key after the first fetch. The
    supplier must keep returning the cached one-key set one second later, and
    the refreshed two-key set once five more minutes have passed.
    """
    served_jwks = jwk.KEYS()
    served_jwks.wrap_add(PublicKey.RSA.generate(2048))
    scheme, issuer = "https", "issuer.com"
    self._key_uri_supplier.supply.return_value = "{}://{}".format(scheme, issuer)

    @httmock.urlmatch(scheme=scheme, netloc=issuer)
    def _serve_jwks(url, request):  # pylint: disable=unused-argument
        # httmock calls handlers with (url, request); the body is the JWKS.
        return served_jwks.dump_jwks()

    timer = JwksSupplierTest._mock_timer
    supply = self._jwks_uri_supplier.supply

    with httmock.HTTMock(_serve_jwks):
        timer.return_value = 10
        self.assertEqual(1, len(supply(issuer)))

        # From here on the endpoint serves a two-key set.
        served_jwks.wrap_add(PublicKey.RSA.generate(2048))

        # One second later the cached one-key set is still returned.
        timer.return_value += 1
        supply(issuer)
        self.assertEqual(1, len(supply(issuer)))

        # Five minutes later the cache entry has expired, so the refreshed
        # two-key set comes back.
        timer.return_value += 5 * 60
        supply(issuer)
        self.assertEqual(2, len(supply(issuer)))
def get_public_jwk(self):
    """
    Export the public JWK set for ``self.key`` as a parsed JSON object.

    The set holds ``self.key`` when it is truthy and is empty otherwise. It is
    serialized with ``KEYS.dump_jwks()``, presumably omitting private members
    (TODO confirm against pyjwkest), then parsed back with ``json.loads``.
    """
    key_set = jwk.KEYS()
    # A missing key yields an empty set rather than an error.
    if self.key:
        key_set.append(self.key)
    serialized_set = key_set.dump_jwks()
    return json.loads(serialized_set)
def _generate_key_pair(self):
    """
    Create a new 2048-bit RSA key and serialize it for JWT signing.

    Returns:
        A 2-tuple of JSON strings:
        * the JWK set (from ``KEYS.dump_jwks()``) holding the new key under
          kid ``"key_id"``;
        * the key itself serialized with ``private=True``, i.e. the full key pair.
    """
    signing_jwk = jwk.RSAKey(kid="key_id", key=RSA.generate(2048))

    key_set = jwk.KEYS()
    key_set.append(signing_jwk)

    public_set_json = key_set.dump_jwks()
    keypair_json = json.dumps(signing_jwk.serialize(private=True))
    return public_set_json, keypair_json
def test_verify_fails(self):
    """Claim extraction fails when the supplied keys did not sign the token."""
    auth_token = token_utils.generate_auth_token(
        self._jwt_claims, self._jwks._keys, kid=self._ec_kid)

    # Make the supplier hand out a freshly generated EC key instead of the
    # one that signed ``auth_token``.
    unrelated_jwks = jwk.KEYS()
    unrelated_jwks.append(jwk.ECKey(use=u"sig").load_key(ecc.P256))
    self._jwks_supplier.supply.return_value = unrelated_jwks

    with self.assertRaises(suppliers.UnauthenticatedException):
        self._authenticator.get_jwt_claims(auth_token)
def _encode_and_sign(payload, use_asymmetric_key, secret):
    """Serialize ``payload`` to JSON and sign it as a compact JWS.

    Args:
        payload: JSON-serializable object to sign.
        use_asymmetric_key: when truthy, sign with the private JWK from
            ``settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK']`` using
            ``JWT_SIGNING_ALGORITHM``. Otherwise sign with a symmetric
            (``kty: oct``) key using ``JWT_ALGORITHM``.
        secret: symmetric key. When falsy, ``JWT_SECRET_KEY`` is used instead.
            Unused for asymmetric signing.

    Returns:
        The compact-serialized JWS produced by ``JWS.sign_compact``.
    """
    set_custom_metric('jwt_is_asymmetric', use_asymmetric_key)
    signing_keys = jwk.KEYS()
    if not use_asymmetric_key:
        signing_keys.add({'key': secret or settings.JWT_AUTH['JWT_SECRET_KEY'], 'kty': 'oct'})
        algorithm = settings.JWT_AUTH['JWT_ALGORITHM']
    else:
        # The private signing key is stored in settings as a JSON string.
        private_jwk = json.loads(settings.JWT_AUTH['JWT_PRIVATE_SIGNING_JWK'])
        signing_keys.add(private_jwk)
        algorithm = settings.JWT_AUTH['JWT_SIGNING_ALGORITHM']

    signer = JWS(json.dumps(payload), alg=algorithm)
    return signer.sign_compact(keys=signing_keys)
def test_authenticate_auth_token_with_bad_signature(self):
    """A token signed by an unknown key that reuses a known kid is rejected.

    The token carries the kid of ``IntegrationTest._rsa_key`` but is signed
    with a newly generated RSA key. Verification against the keys behind
    ``_JWKS_PATH`` must therefore fail.
    """
    new_rsa_key = jwk.RSAKey(use=u"sig").load_key(
        PublicKey.RSA.generate(2048))
    kid = IntegrationTest._rsa_key.kid
    new_rsa_key.kid = kid
    new_jwks = jwk.KEYS()
    new_jwks._keys.append(new_rsa_key)
    auth_token = token_utils.generate_auth_token(
        IntegrationTest._JWT_CLAIMS, new_jwks._keys, alg=u"RS256", kid=kid)
    url = get_url(IntegrationTest._JWKS_PATH)
    self._provider_ids[self._ISSUER] = self._PROVIDER_ID
    self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
        False, url)
    message = u"Signature verification failed"
    # assertRaisesRegexp was a deprecated alias and was removed in Python 3.12.
    with self.assertRaisesRegex(suppliers.UnauthenticatedException, message):
        self._authenticator.authenticate(auth_token, self._auth_info,
                                         IntegrationTest._SERVICE_NAME)
def encode(self, payload):
    """Encode the provided payload as a compact-serialized JWS.

    Signing settings come from ``self.jwt_auth``:
    * when ``self.asymmetric`` is truthy, the private JWK in
      ``JWT_PRIVATE_SIGNING_JWK`` signs with ``JWT_SIGNING_ALGORITHM``;
    * otherwise a symmetric ``oct`` key signs with ``JWT_ALGORITHM``. The key
      is ``self.secret``, or ``JWT_SECRET_KEY`` when that is falsy.
    """
    signing_keys = jwk.KEYS()
    if not self.asymmetric:
        symmetric_key = self.secret or self.jwt_auth['JWT_SECRET_KEY']
        signing_keys.add({'key': symmetric_key, 'kty': 'oct'})
        algorithm = self.jwt_auth['JWT_ALGORITHM']
    else:
        # The private signing key is configured as a JSON string.
        private_jwk = json.loads(self.jwt_auth['JWT_PRIVATE_SIGNING_JWK'])
        signing_keys.add(private_jwk)
        algorithm = self.jwt_auth['JWT_SIGNING_ALGORITHM']

    signer = JWS(json.dumps(payload), alg=algorithm)
    return signer.sign_compact(keys=signing_keys)
def test_supply_jwks(self):
    """Keys served as a JWKS document at the issuer's jwks_uri are returned.

    The supplier must yield exactly one key whose RSA modulus and exponent
    match the generated key.
    """
    rsa_key = PublicKey.RSA.generate(2048)
    jwks = jwk.KEYS()
    jwks.wrap_add(rsa_key)

    scheme = "https"
    issuer = "issuer.com"
    self._key_uri_supplier.supply.return_value = scheme + "://" + issuer

    @httmock.urlmatch(scheme=scheme, netloc=issuer)
    def _mock_response_with_jwks(url, request):  # pylint: disable=unused-argument
        # httmock passes (url, request) positionally; the return value is the body.
        return jwks.dump_jwks()

    with httmock.HTTMock(_mock_response_with_jwks):
        actual_jwks = self._jwks_uri_supplier.supply(issuer)
        # assertEquals was a deprecated alias and was removed in Python 3.12.
        self.assertEqual(1, len(actual_jwks))
        actual_key = actual_jwks[0].key
        self.assertEqual(rsa_key.n, actual_key.n)
        self.assertEqual(rsa_key.e, actual_key.e)
def get_jwk_key_pair(self):
    """
    Returns the asymmetric JWT signing keys required.

    Wraps ``self.rsa_key`` in an RSA JWK with kid ``"opencraft"``.

    Returns:
        A ``JWK_KEY_PAIR`` named-tuple instance with two fields:
        * ``public``: JSON string of the JWK set (``KEYS.dump_jwks()``)
          holding the key;
        * ``private``: JSON string of the key serialized with ``private=True``.
    """
    rsa_jwk = jwk.RSAKey(kid="opencraft", key=self.rsa_key)

    # Serialize public JWT signing keys
    public_keys = jwk.KEYS()
    public_keys.append(rsa_jwk)
    serialized_public_keys_json = public_keys.dump_jwks()

    # Serialize private JWT signing keys
    serialized_keypair_json = json.dumps(rsa_jwk.serialize(private=True))

    # Return an *instance* of the named tuple. Assigning to attributes of the
    # namedtuple class would overwrite its field descriptors and hand callers
    # a type instead of a value, which cannot be unpacked or indexed.
    jwk_key_pair_type = namedtuple('JWK_KEY_PAIR', ['public', 'private'])
    return jwk_key_pair_type(public=serialized_public_keys_json,
                             private=serialized_keypair_json)
def test_auth_token_cache_capacity(self):
    """A cache limited to two entries evicts the first token once a third arrives.

    Cached tokens keep decoding after the supplier switches to an unrelated
    key. The first token is expected to fail once a third token has pushed it
    out of the cache.
    """
    authenticator = tokens.Authenticator({}, self._jwks_supplier, cache_capacity=2)
    claims = self._jwt_claims

    # NOTE(review): the three email literals below look redacted and are now
    # identical; the test presumably relied on distinct claims per token.
    # Confirm against the original values.
    claims[u"email"] = u"*****@*****.**"
    first_token = token_utils.generate_auth_token(claims, self._jwks._keys)
    claims[u"email"] = u"*****@*****.**"
    second_token = token_utils.generate_auth_token(claims, self._jwks._keys)

    # Decode both tokens with the genuine keys, filling the two cache slots.
    for token in (first_token, second_token):
        authenticator.get_jwt_claims(token)

    # From now on the supplier only offers an unrelated EC key that reuses
    # the EC kid, so any token that has to be re-verified fails.
    replacement_key = jwk.ECKey(use=u"sig").load_key(ecc.P256)
    replacement_key.kid = self._ec_kid
    replacement_jwks = jwk.KEYS()
    replacement_jwks.append(replacement_key)
    self._jwks_supplier.supply.return_value = replacement_jwks

    # Both tokens are still answered from the cache.
    for token in (first_token, second_token):
        authenticator.get_jwt_claims(token)

    # A third token, signed with the replacement key, takes a cache slot...
    claims[u"email"] = u"*****@*****.**"
    third_token = token_utils.generate_auth_token(claims, replacement_jwks._keys)
    authenticator.get_jwt_claims(third_token)

    # ...so the first token has been evicted and re-verification now fails.
    with self.assertRaises(suppliers.UnauthenticatedException):
        authenticator.get_jwt_claims(first_token)
def test_get_jwt_claims_via_caching(self):
    """Decoded claims are served from cache until the cache entry expires."""
    timer = AuthenticatorTest._mock_timer
    timer.return_value = 10
    auth_token = token_utils.generate_auth_token(self._jwt_claims, self._jwks._keys)

    # The first call verifies with the real keys and caches the result.
    self._authenticator.get_jwt_claims(auth_token)

    # Supply an empty key set from now on: any fresh verification must fail.
    self._jwks_supplier.supply.return_value = jwk.KEYS()

    # Forward time by 10 seconds: the cached result is still used.
    timer.return_value += 10
    self._authenticator.get_jwt_claims(auth_token)

    # Forward time by 5 more minutes: the entry has expired, so the token is
    # re-verified against the empty key set and rejected.
    timer.return_value += 5 * 60
    with self.assertRaises(suppliers.UnauthenticatedException):
        self._authenticator.get_jwt_claims(auth_token)
# Author: Virgo Darth
# Time: Sep 4, 2020
"""Generate an RSA signing key and print it as JWT signing settings.

Prints two values: the private key (for JWT_PRIVATE_SIGNING_JWK) and the
public JWK set (for JWT_PUBLIC_SIGNING_JWK_SET). Both are printed as JSON
strings.
"""
import getpass
import json

from Cryptodome.PublicKey import RSA
from jwkest import jwk

KEY_SIZE = 2048  # recommended


def main():
    """Prompt for a key id, generate a key pair and print both settings."""
    # The entered value becomes the JWK "kid" of the generated key.
    key_phase = getpass.getpass(prompt='Enter your key phase: ')
    print()

    rsa_alg = RSA.generate(KEY_SIZE)  # generate private key
    rsa_jwk = jwk.RSAKey(kid=key_phase, key=rsa_alg)

    public_keys = jwk.KEYS()
    public_keys.append(rsa_jwk)

    # Emit real JSON for the private key. In pyjwkest, str(key) renders a
    # Python dict repr (single-quoted), which json.loads() rejects when the
    # setting is read back. This matches the json.dumps(serialize(private=True))
    # form used elsewhere for this setting.
    serialized_private_keys = json.dumps(rsa_jwk.serialize(private=True))
    serialized_public_keys = public_keys.dump_jwks()

    print("=====Start JWT_PRIVATE_SIGNING_JWK=====\n", serialized_private_keys, "\n=====End JWT_PRIVATE_SIGNING_JWK=====\n")
    print("=====Start JWT_PUBLIC_SIGNING_JWK_SET=====\n", serialized_public_keys, "\n=====End JWT_PUBLIC_SIGNING_JWK_SET=====")


if __name__ == "__main__":
    main()
class IntegrationTest(unittest.TestCase):
    """End-to-end authentication tests against a local HTTPS key server.

    ``setUpClass`` starts a Flask app (``_RestServer``) on ``_PORT`` in a
    daemon thread. The app serves:
    * the class-level JWKS;
    * the RSA public key PEM keyed by kid;
    * an invalid certificate;
    * an OpenID configuration pointing at the JWKS URL.

    Each test points an issuer config at one of these endpoints and calls
    ``authenticate`` on a token.
    """

    _CURRENT_TIME = int(time.time())
    _PORT = 8080
    # NOTE(review): this literal looks mangled by credential redaction. The
    # class references _PROVIDER_ID, _SERVICE_NAME, _JWKS_PATH and
    # _INVALID_X509_PATH, none of which is defined in this class body. Their
    # definitions (and the real issuer URL) appear to have been swallowed
    # here; confirm against the original file before running these tests.
    _ISSUER = u"https://*****:*****@name.com"
    _X509_PATH = u"x509"
    _JWT_CLAIMS = {
        u"aud": [u"https://aud1.local.host", u"https://aud2.local.host"],
        u"exp": _CURRENT_TIME + 60,
        u"email": u"*****@*****.**",
        u"iss": _ISSUER,
        u"sub": u"subject-id"
    }
    _ec_jwk = jwk.ECKey(use=u"sig").load_key(ecc.P256)
    _rsa_key = jwk.RSAKey(use=u"sig").load_key(PublicKey.RSA.generate(1024))
    _ec_jwk.kid = u"ec-key-id"
    _rsa_key.kid = u"rsa-key-id"
    _mock_timer = mock.MagicMock()
    _jwks = jwk.KEYS()
    _jwks._keys.append(_ec_jwk)
    _jwks._keys.append(_rsa_key)
    # Default token, signed with the RSA key and valid for 60 seconds after
    # class creation.
    _AUTH_TOKEN = token_utils.generate_auth_token(_JWT_CLAIMS,
                                                  _jwks._keys,
                                                  alg=u"RS256",
                                                  kid=_rsa_key.kid)

    @classmethod
    def setUpClass(cls):
        dirname = os.path.dirname(os.path.realpath(__file__))
        cls._cert_file = os.path.join(dirname, u"ssl.cert")
        cls._key_file = os.path.join(dirname, u"ssl.key")
        # Make requests trust the certificate the test server presents.
        os.environ[u"REQUESTS_CA_BUNDLE"] = cls._cert_file
        # NOTE(review): start() returns as soon as the server thread is
        # spawned, without waiting for it to accept connections.
        rest_server = cls._RestServer()
        rest_server.start()

    def setUp(self):
        self._provider_ids = {}
        self._configs = {}
        self._authenticator = auth.create_authenticator(
            self._provider_ids, self._configs)
        self._auth_info = mock.MagicMock()
        self._auth_info.is_provider_allowed.return_value = True
        self._auth_info.get_allowed_audiences.return_value = [
            u"https://aud1.local.host"
        ]

    def test_verify_auth_token_with_jwks(self):
        url = get_url(IntegrationTest._JWKS_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        user_info = self._authenticator.authenticate(
            IntegrationTest._AUTH_TOKEN, self._auth_info,
            IntegrationTest._SERVICE_NAME)
        self._assert_user_info_equals(
            tokens.UserInfo(IntegrationTest._JWT_CLAIMS), user_info)

    def test_authenticate_auth_token_with_bad_signature(self):
        new_rsa_key = jwk.RSAKey(use=u"sig").load_key(
            PublicKey.RSA.generate(2048))
        kid = IntegrationTest._rsa_key.kid
        new_rsa_key.kid = kid
        new_jwks = jwk.KEYS()
        new_jwks._keys.append(new_rsa_key)
        auth_token = token_utils.generate_auth_token(
            IntegrationTest._JWT_CLAIMS, new_jwks._keys, alg=u"RS256", kid=kid)
        url = get_url(IntegrationTest._JWKS_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        message = u"Signature verification failed"
        # assertRaisesRegex replaces assertRaisesRegexp, removed in Python 3.12.
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(auth_token, self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_verify_auth_token_with_x509(self):
        url = get_url(IntegrationTest._X509_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        user_info = self._authenticator.authenticate(
            IntegrationTest._AUTH_TOKEN, self._auth_info,
            IntegrationTest._SERVICE_NAME)
        self._assert_user_info_equals(
            tokens.UserInfo(IntegrationTest._JWT_CLAIMS), user_info)

    def test_verify_auth_token_with_invalid_x509(self):
        url = get_url(IntegrationTest._INVALID_X509_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        message = u"Cannot load X.509 certificate"
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_openid_discovery(self):
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            True, None)
        user_info = self._authenticator.authenticate(
            IntegrationTest._AUTH_TOKEN, self._auth_info,
            IntegrationTest._SERVICE_NAME)
        self._assert_user_info_equals(
            tokens.UserInfo(IntegrationTest._JWT_CLAIMS), user_info)

    def test_openid_discovery_failed(self):
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, None)
        message = (u"Cannot find the `jwks_uri` for issuer %s" %
                   IntegrationTest._ISSUER)
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_authenticate_with_malformed_auth_code(self):
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    u"Cannot decode the auth token"):
            self._authenticator.authenticate(u"invalid-auth-code",
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_authenticate_with_disallowed_issuer(self):
        url = get_url(IntegrationTest._JWKS_PATH)
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        message = u"Unknown issuer: " + self._ISSUER
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_authenticate_with_unknown_issuer(self):
        message = (u"Cannot find the `jwks_uri` for issuer %s: "
                   u"either the issuer is unknown") % IntegrationTest._ISSUER
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_authenticate_with_invalid_audience(self):
        url = get_url(IntegrationTest._JWKS_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        self._auth_info.get_allowed_audiences.return_value = []
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    u"Audiences not allowed"):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    @mock.patch(u"time.time", _mock_timer)
    def test_authenticate_with_expired_auth_token(self):
        url = get_url(IntegrationTest._JWKS_PATH)
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)

        IntegrationTest._mock_timer.return_value = 0

        # Create an auth token that expires in 10 seconds.
        jwt_claims = copy.deepcopy(IntegrationTest._JWT_CLAIMS)
        jwt_claims[u"exp"] = time.time() + 10
        auth_token = token_utils.generate_auth_token(
            jwt_claims, IntegrationTest._jwks._keys, alg=u"RS256",
            kid=IntegrationTest._rsa_key.kid)

        # Verify that this same token authenticates before it expires, so the
        # failure below can only be due to expiry.
        self._authenticator.authenticate(auth_token,
                                         self._auth_info,
                                         IntegrationTest._SERVICE_NAME)

        # Advance the timer by 20 seconds and make sure the token is expired.
        IntegrationTest._mock_timer.return_value += 20
        message = u"The auth token has already expired"
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(auth_token, self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_invalid_openid_discovery_url(self):
        issuer = u"https://invalid.issuer"
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[issuer] = suppliers.IssuerUriConfig(True, None)

        jwt_claims = copy.deepcopy(IntegrationTest._JWT_CLAIMS)
        jwt_claims[u"iss"] = issuer
        auth_token = token_utils.generate_auth_token(
            jwt_claims, IntegrationTest._jwks._keys, alg=u"RS256",
            kid=IntegrationTest._rsa_key.kid)
        message = u"Cannot discover the jwks uri"
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(auth_token, self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def test_invalid_jwks_uri(self):
        url = u"https://invalid.jwks.uri"
        self._provider_ids[self._ISSUER] = self._PROVIDER_ID
        self._configs[IntegrationTest._ISSUER] = suppliers.IssuerUriConfig(
            False, url)
        message = u"Cannot retrieve valid verification keys from the `jwks_uri`"
        with self.assertRaisesRegex(suppliers.UnauthenticatedException,
                                    message):
            self._authenticator.authenticate(IntegrationTest._AUTH_TOKEN,
                                             self._auth_info,
                                             IntegrationTest._SERVICE_NAME)

    def _assert_user_info_equals(self, expected, actual):
        """Compare the user-info fields that authentication populates."""
        self.assertEqual(expected.audiences, actual.audiences)
        self.assertEqual(expected.email, actual.email)
        self.assertEqual(expected.subject_id, actual.subject_id)
        self.assertEqual(expected.issuer, actual.issuer)

    class _RestServer(object):
        """Flask app serving the key material the tests authenticate against."""

        def __init__(self):
            app = flask.Flask(u"integration-test-server")

            @app.route(u"/" + IntegrationTest._JWKS_PATH)
            def get_json_web_key_set():  # pylint: disable=unused-variable
                return IntegrationTest._jwks.dump_jwks()

            @app.route(u"/" + IntegrationTest._X509_PATH)
            def get_x509_certificates():  # pylint: disable=unused-variable
                cert = IntegrationTest._rsa_key.key.publickey().exportKey(
                    u"PEM")
                return flask.jsonify(
                    {IntegrationTest._rsa_key.kid: cert.decode('ascii')})

            @app.route(u"/" + IntegrationTest._INVALID_X509_PATH)
            def get_invalid_x509_certificates():  # pylint: disable=unused-variable
                return flask.jsonify(
                    {IntegrationTest._rsa_key.kid: u"invalid cert"})

            @app.route(u"/.well-known/openid-configuration")
            def get_openid_configuration():  # pylint: disable=unused-variable
                return flask.jsonify(
                    {u"jwks_uri": get_url(IntegrationTest._JWKS_PATH)})

            self._application = app

        def start(self):
            """Run the app over TLS in a background daemon thread."""
            def run_app():
                ssl_context = (IntegrationTest._cert_file,
                               IntegrationTest._key_file)
                self._application.run(port=IntegrationTest._PORT,
                                      ssl_context=ssl_context)

            # Daemon thread so the server does not keep the process alive.
            thread = threading.Thread(target=run_app, args=())
            thread.daemon = True
            thread.start()