import json

# Collect every verification key the user pointed us at.
if args.rsa_file:
    keys.append(RSAKey(key=import_rsa_key(args.rsa_file), kid=_kid))
if args.hmac_key:
    keys.append(SYMKey(key=args.hmac_key, kid=_kid))
if args.jwk:
    # key_from_jwk_dict() takes a dict, so the file's JSON text must be
    # parsed first; handing it the raw string fails when it is indexed.
    with open(args.jwk) as _fp:
        _key = key_from_jwk_dict(json.load(_fp))
    keys.append(_key)
if args.jwks:
    _k = KeyJar()
    # import_jwks() expects the JWKS as a dict, not as JSON text.
    with open(args.jwks) as _fp:
        _k.import_jwks(json.load(_fp), "")
    keys.extend(_k.issuer_keys(""))
if args.jwks_url:
    _kb = KeyBundle(source=args.jwks_url)
    keys.extend(_kb.get())

# Work out where the message comes from.
if not args.msg or args.msg == "-":
    # Nothing specified, or "-" given explicitly: read from stdin.
    message = sys.stdin.read()
elif os.path.isfile(args.msg):
    with open(args.msg) as _fp:
        message = _fp.read().strip("\n")
else:
    # Otherwise the argument is the message itself.
    message = args.msg

message = message.strip()
message = message.strip('"')
main(message, keys, args.quiet)
def test_get_audience_and_algorithm_default_alg(self, entity):
    """Check ClientSecretJWT's choice of signing alg when registration gives none.

    With only an RSA key in the key jar, this checks three cases in turn:
    the default (RS256), an explicit client preference (RS512), and the
    provider's supported list. From that list ES256 has to be skipped
    because there is no key for it.
    """
    _service_context = entity.client_get("service_context")
    _service_context.token_endpoint = "https://example.com/token"
    _service_context.provider_info = {
        'issuer': 'https://example.com/',
        'token_endpoint': "https://example.com/token"
    }
    # Registration supplies no token_endpoint_auth_signing_alg at all.
    _service_context.registration_response = {}

    csj = ClientSecretJWT()
    request = AccessTokenRequest()
    token_service = entity.client_get("service", 'accesstoken')

    # Add an RSA key to be able to handle the default algorithm.
    _kb = KeyBundle()
    _rsa_key = new_rsa_key()
    _kb.append(_rsa_key)
    _service_context.keyjar.add_kb("", _kb)

    # Since there is an RSA key this doesn't fail.
    csj.construct(request, service=token_service,
                  authn_endpoint='token_endpoint')
    _jws = factory(request["client_assertion"])
    assert _jws.jwt.headers["alg"] == "RS256"
    assert _jws.jwt.headers["kid"] == _rsa_key.kid

    # By client preferences.
    request = AccessTokenRequest()
    _service_context.client_preferences = {
        "token_endpoint_auth_signing_alg": "RS512"
    }
    csj.construct(request, service=token_service,
                  authn_endpoint='token_endpoint')
    _jws = factory(request["client_assertion"])
    assert _jws.jwt.headers["alg"] == "RS512"
    assert _jws.jwt.headers["kid"] == _rsa_key.kid

    # Use provider information if everything else fails.
    request = AccessTokenRequest()
    _service_context.client_preferences = {}
    _service_context.provider_info[
        "token_endpoint_auth_signing_alg_values_supported"] = [
        "ES256", "RS256"
    ]
    csj.construct(request, service=token_service,
                  authn_endpoint='token_endpoint')
    _jws = factory(request["client_assertion"])
    # Should be RS256 since there is no key for ES256.
    assert _jws.jwt.headers["alg"] == "RS256"
    assert _jws.jwt.headers["kid"] == _rsa_key.kid
def test_local_jwk():
    """A KeyBundle can be sourced from a local JWK file via a file:// URL."""
    jwk_file = full_path('jwk_private_key.json')
    bundle = KeyBundle(source='file://' + jwk_file)
    assert bundle
def test_thumbprint():
    """Every key in the bundle has a SHA-256 thumbprint listed in EXPECTED."""
    bundle = KeyBundle(JWKS_DICT)
    for jwk in bundle:
        assert jwk.thumbprint("SHA-256") in EXPECTED
@pytest.mark.parametrize(
    "keytype,alg,enc",
    [
        ('RSA', 'RSA1_5', 'A128CBC-HS256'),
        ('EC', 'ECDH-ES', 'A128GCM'),
    ],
)
def test_to_jwe(keytype, alg, enc):
    """Decrypting fails with HeaderError when encalg/encenc don't match the JWE."""
    payload = Message(a='foo', b='bar', c='tjoho')
    enc_keys = KEYJAR.get_encrypt_key(keytype, '')
    jwe_token = payload.to_jwe(enc_keys, alg=alg, enc=enc)

    # Key-management algorithm differs from the one used to encrypt.
    with pytest.raises(HeaderError):
        Message().from_jwt(jwe_token, KEYJAR, encalg="RSA-OAEP", encenc=enc)

    # Content-encryption algorithm differs from the one used to encrypt.
    with pytest.raises(HeaderError):
        Message().from_jwt(jwe_token, KEYJAR, encenc="A256CBC-HS512",
                           encalg=alg)


# A copy of KEYJAR extended with one freshly generated RSA key (NEW_KID).
NEW_KEYJAR = KEYJAR.copy()
kb = KeyBundle()
k = new_rsa_key()
NEW_KID = k.kid
kb.append(k)
NEW_KEYJAR.add_kb('', kb)


def test_no_suitable_keys():
    """A JWT signed with the fresh NEW_KID key is rejected by KEYJAR."""
    claims = Message(a='foo', b='bar', c='tjoho')
    signing_key = NEW_KEYJAR.get_signing_key('RSA', '', kid=NEW_KID)
    signed = claims.to_jwt(signing_key, 'RS256')
    with pytest.raises(NoSuitableSigningKeys):
        Message().from_jwt(signed, KEYJAR)
def test_with_sym_key():
    """A bundle built from one inline 'oct' JWK holds just that key and is not remote."""
    sym_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    bundle = KeyBundle(sym_jwk)
    assert len(bundle.get("oct")) == 1
    assert len(bundle.get("rsa")) == 0
    assert bundle.remote is False
    assert bundle.source is None
def test_local_jwk_copy():
    """copy() of a file-sourced KeyBundle keeps the same source."""
    source_url = "file://" + full_path("jwk_private_key.json")
    original = KeyBundle(source=source_url)
    duplicate = original.copy()
    assert duplicate.source == original.source
from oidcmsg.oauth2 import RefreshAccessTokenRequest
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oidc import IdToken
from oidcmsg.time_util import utc_time_sans_frac
from oidcservice.exception import OidcServiceError
from oidcservice.exception import ParseError

from oidcrp.oauth2 import Client

sys.path.insert(0, '.')

_dirname = os.path.dirname(os.path.abspath(__file__))
BASE_PATH = os.path.join(_dirname, "data", "keys")

# Signing key bundle built from the RSA private key in data/keys/rsa.key.
_key = import_private_rsa_key_from_file(os.path.join(BASE_PATH, "rsa.key"))
KC_RSA = KeyBundle({"priv_key": _key, "kty": "RSA", "use": "sig"})

CLIENT_ID = "client_1"
IDTOKEN = IdToken(
    iss="http://oidc.example.org/",
    sub="sub",
    aud=CLIENT_ID,
    exp=utc_time_sans_frac() + 86400,  # one day after module import
    nonce="N0nce",
    iat=time.time(),
)


class MockResponse:
    """Test double carrying status_code, text and headers.

    Presumably stands in for an HTTP response object in the tests.
    """

    def __init__(self, status_code, text, headers=None):
        self.status_code = status_code
        self.text = text
        # Any falsy headers argument (None or empty) becomes a fresh dict.
        if not headers:
            headers = {}
        self.headers = headers
from oidcmsg.oidc.session import EndSessionRequest
from oidcmsg.oidc.session import EndSessionResponse
from oidcmsg.time_util import utc_time_sans_frac

CLIENT_ID = "client_1"
ISS = 'https://example.com'

# ID Token whose exp is 300 seconds after module import.
IDTOKEN = IdToken(
    iss=ISS,
    sub="sub",
    aud=CLIENT_ID,
    exp=utc_time_sans_frac() + 300,
    nonce="N0nce",
    iat=time.time(),
)

# Symmetric HS256 signing key; the secret is given as bytes.
KC_SYM_S = KeyBundle(
    {
        "kty": "oct",
        "key": b"abcdefghijklmnop",
        "use": "sig",
        "alg": "HS256",
    }
)

NOW = utc_time_sans_frac()

# Key definitions: one P-256 EC key for signing and one for encryption.
KEYDEF = [
    {"type": "EC", "crv": "P-256", "use": ["sig"]},
    {"type": "EC", "crv": "P-256", "use": ["enc"]},
]
def test_no_use(self):
    """KeyIssuer.get("enc", "RSA") still returns keys for the JWK0 bundle.

    Presumably JWK0's keys carry no explicit 'use' (per the test name) --
    confirm against the JWK0 fixture.
    """
    bundle = KeyBundle(JWK0["keys"])
    key_issuer = KeyIssuer()
    key_issuer.add_kb(bundle)
    assert key_issuer.get("enc", "RSA") != []
def test_contains():
    """Every key reported by all_keys() is found with the `in` operator."""
    key_issuer = KeyIssuer()
    key_issuer.add_kb(KeyBundle(JWK1["keys"]))
    assert all(jwk in key_issuer for jwk in key_issuer.all_keys())
__author__ = 'Roland Hedberg'

ALICE = 'https://example.org/alice'
BOB = 'https://example.com/bob'

BASEDIR = os.path.abspath(os.path.dirname(__file__))


def full_path(local_file):
    """Return *local_file* joined onto this test module's directory."""
    return os.path.join(BASEDIR, local_file)


def _local_key_bundle(filename, usage, kid):
    """Build a file-sourced KeyBundle (fileformat='der') for a key next to this module."""
    return KeyBundle(source='file://{}'.format(full_path(filename)),
                     fileformat='der', keyusage=usage, kid=kid)


kb1 = _local_key_bundle('rsa.key', 'sig', '1')
kb2 = _local_key_bundle('size2048.key', 'enc', '2')

# Alice owns a signing key (kid '1') and an encryption key (kid '2').
ALICE_KEY_JAR = KeyJar()
ALICE_KEY_JAR.add_kb(ALICE, kb1)
ALICE_KEY_JAR.add_kb(ALICE, kb2)

kb3 = _local_key_bundle('server.key', 'enc', '3')
def test_remote_not_modified():
    """A 304 reply makes do_remote() return False; the bundle then dumps/loads intact."""
    source = "https://example.com/keys.json"
    # The mocked replies carry no headers, so last_remote is expected to stay
    # equal to headers.get("Last-Modified"), i.e. None.
    # NOTE(review): presumably no Last-Modified is sent so that nothing like
    # If-Modified-Since is added to httpc_params; the final assertion expects
    # httpc_params to be exactly {"timeout": (2, 2)}. Confirm before adding headers.
    headers = {}

    httpc_params = {"timeout": (2, 2)}  # connect, read timeouts in seconds
    kb = KeyBundle(source=source, httpc=requests.request,
                   httpc_params=httpc_params)

    # First fetch: 200 with the JWKS body.
    with responses.RequestsMock() as rsps:
        rsps.add(method="GET", url=source, json=JWKS_DICT, status=200,
                 headers=headers)
        assert kb.do_remote()
        assert kb.last_remote == headers.get("Last-Modified")
        timeout1 = kb.time_out

    # Second fetch: 304 Not Modified, so do_remote() reports nothing new.
    with responses.RequestsMock() as rsps:
        rsps.add(method="GET", url=source, status=304, headers=headers)
        assert not kb.do_remote()
        assert kb.last_remote == headers.get("Last-Modified")
        timeout2 = kb.time_out

    # The time_out value still changes after the 304.
    assert timeout1 != timeout2

    # Round-trip the bundle state through dump()/load().
    exp = kb.dump()
    kb2 = KeyBundle().load(exp)
    assert kb2.source == source
    assert len(kb2.keys()) == 3
    assert len(kb2.active_keys()) == 3
    assert len(kb2.get("rsa")) == 1
    assert len(kb2.get("oct")) == 1
    assert len(kb2.get("ec")) == 1
    assert kb2.httpc_params == {"timeout": (2, 2)}
    assert kb2.imp_jwks
    assert kb2.last_updated
def test_export_inactive():
    """dump() exposes the expected fields; inactive keys stay inactive after load()."""
    sig_desc = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}

    bundle = KeyBundle([sig_desc])
    assert len(bundle.keys()) == 1
    for jwk in bundle.keys():
        bundle.mark_as_inactive(jwk.kid)

    bundle.add_jwk_dicts([enc_desc])
    state = bundle.dump()
    expected_fields = {
        "cache_time",
        "etag",
        "fileformat",
        "httpc_params",
        "ignore_errors_until",
        "ignore_errors_period",
        "ignore_invalid_keys",
        "imp_jwks",
        "keys",
        "keytype",
        "keyusage",
        "last_updated",
        "last_remote",
        "last_local",
        "remote",
        "local",
        "source",
        "time_out",
    }
    assert set(state.keys()) == expected_fields

    restored = KeyBundle().load(state)
    assert len(restored.keys()) == 2
    assert len(restored.active_keys()) == 1
def test_key_bundle_difference_none():
    """Two bundles holding the same keys have an empty difference."""
    original = build_key_bundle(key_conf=KEYSPEC_6)
    clone = KeyBundle()
    clone.extend(original.keys())
    assert original.difference(clone) == []
"dEtpjbEvbhfgwUI-bdK5xAU_9UQ", "kty": "RSA", "n": "x7HNcD9ZxTFRaAgZ7-gdYLkgQua3zvQseqBJIt8Uq3MimInMZoE9QGQeSML7qZPlowb5BUakdLI70ayM4vN36--0ht8-oCHhl8YjGFQkU-Iv2yahWHEP-1EK6eOEYu6INQP9Lk0HMk3QViLwshwb-KXVD02jdmX2HNdYJdPyc0c", "use": "sig", "x5c": [ "MIICWzCCAcSgAwIBAgIJAL3MzqqEFMYjMA0GCSqGSIb3DQEBBQUAMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleTAeFw0xMzExMTExOTA1MDJaFw0xOTExMTAxOTA1MDJaMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAx7HNcD9ZxTFRaAgZ7+gdYLkgQua3zvQseqBJIt8Uq3MimInMZoE9QGQeSML7qZPlowb5BUakdLI70ayM4vN36++0ht8+oCHhl8YjGFQkU+Iv2yahWHEP+1EK6eOEYu6INQP9Lk0HMk3QViLwshwb+KXVD02jdmX2HNdYJdPyc0cCAwEAAaOBijCBhzAdBgNVHQ4EFgQULR0aj9AtiNMgqIY8ZyXZGsHcJ5gwWQYDVR0jBFIwUIAULR0aj9AtiNMgqIY8ZyXZGsHcJ5ihLaQrMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleYIJAL3MzqqEFMYjMAsGA1UdDwQEAwIBxjANBgkqhkiG9w0BAQUFAAOBgQBshrsF9yls4ArxOKqXdQPDgHrbynZL8m1iinLI4TeSfmTCDevXVBJrQ6SgDkihl3aCj74IEte2MWN78sHvLLTWTAkiQSlGf1Zb0durw+OvlunQ2AKbK79Qv0Q+wwGuK+oymWc3GSdP1wZqk9dhrQxb3FtdU2tMke01QTut6wr7ig==" ], "x5t": "dEtpjbEvbhfgwUI-bdK5xAU_9UQ" }] } SIGJWKS = KeyBundle(JWKS_b) def P256(): return ec.generate_private_key(ec.SECP256R1(), default_backend()) def test_1(): claimset = { "iss": "joe", "exp": 1300819380, "http://example.com/is_root": True } _jws = JWS(claimset, cty="JWT", alg='none') _jwt = _jws.sign_compact()
def test_export_inactive():
    """dump() exposes the expected fields; inactive keys stay inactive after load()."""
    sig_desc = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}

    bundle = KeyBundle([sig_desc])
    assert len(bundle.keys()) == 1
    for jwk in bundle.keys():
        bundle.mark_as_inactive(jwk.kid)

    bundle.do_keys([enc_desc])
    state = bundle.dump()
    expected_fields = {
        "cache_time",
        "fileformat",
        "httpc_params",
        "imp_jwks",
        "keys",
        "last_updated",
        "last_remote",
        "last_local",
        "remote",
        "local",
        "time_out",
    }
    assert set(state.keys()) == expected_fields

    restored = KeyBundle().load(state)
    assert len(restored.keys()) == 2
    assert len(restored.active_keys()) == 1
def test_pick_use():
    """pick_keys(use="sig") returns exactly one key for the JWS's kid."""
    bundle = KeyBundle(JWK_b)
    signer = JWS("foobar", alg="RS256", kid="MnC_VZcATfM5pOYiJHMba9goEKY")
    picked = signer.pick_keys(bundle, use="sig")
    assert len(picked) == 1
def test_unknown_source():
    """KeyBundle raises ImportError for the unrecognized source string "foobar"."""
    bad_source = "foobar"
    with pytest.raises(ImportError):
        KeyBundle(source=bad_source)
def test_pick_wrong_alg():
    """pick_keys() raises ValueError for the algorithm name "EC256".

    "EC256" is not a JWA identifier (ES256 is); JWx.pick_keys raises
    ValueError when alg2keytype() cannot map the algorithm to a key type.
    """
    keys = KeyBundle(JWKS_b)
    _jws = JWS("foobar", alg="EC256", kid="rsa1")
    with pytest.raises(ValueError):
        # Only the raise matters; there is no result to keep.
        _jws.pick_keys(keys, use="sig")
def test_loads_0():
    """A bundle built from JWK0 holds one RSA key whose kid is 'abc'."""
    bundle = KeyBundle(JWK0)
    assert len(bundle) == 1
    rsa_key = bundle.get("rsa")[0]
    assert rsa_key.kid == "abc"
    assert rsa_key.kty == "RSA"
class JWx:
    """A basic class with the commonalities between the JWS and JWE classes.

    :param alg: The signing algorithm
    :param jku: a URI that refers to a resource for a set of JSON-encoded
        public keys, one of which corresponds to the key used to digitally
        sign the JWS
    :param jwk: A JSON Web Key that corresponds to the key used to
        digitally sign the JWS
    :param x5u: a URI that refers to a resource for the X.509 public key
        certificate or certificate chain [RFC5280] corresponding to the key
        used to digitally sign the JWS.
    :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
        DER encoding of the X.509 certificate [RFC5280] corresponding to the
        key used to digitally sign the JWS.
    :param x5c: the X.509 public key certificate or certificate chain
        corresponding to the key used to digitally sign the JWS.
    :param kid: a hint indicating which key was used to secure the JWS.
    :param typ: the type of this object. 'JWS' == JWS Compact Serialization
        'JWS+JSON' == JWS JSON Serialization
    :param cty: the type of the secured content
    :param crit: indicates which extensions are being used and MUST be
        understood and processed.
    :param msg: Stored as ``self.msg``.
    :param with_digest: Stored as ``self.with_digest``.
    :param httpc: Callable used for HTTP fetches (jku, x5u); defaults to
        ``requests.request``.
    :param kwargs: Extra header parameters
    :return: A class instance
    """

    # Header parameters recognised by the constructor and headers().
    args = [
        "alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit"
    ]

    def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
        self.msg = msg
        self._dict = {}
        self.with_digest = with_digest
        if httpc:
            self.httpc = httpc
        else:
            self.httpc = requests.request
        self.jwt = None
        self._jwk = None
        self._jwks = None
        self._header = {}

        if kwargs:
            for key in self.args:
                try:
                    _val = kwargs[key]
                except KeyError:
                    continue

                if key == "jwk":
                    self._set_jwk(_val)
                    self._jwk = self._dict["jwk"]
                elif key == "x5c":
                    self._dict["x5c"] = _val
                    _pub_key = import_rsa_key(_val)
                    self._jwk = RSAKey(pub_key=_pub_key).to_dict()
                elif key == "jku":
                    self._jwks = KeyBundle(source=_val, httpc=self.httpc)
                    self._dict["jku"] = _val
                elif key == "x5u":
                    # Match on the key name. Testing ``"x5u" in self``
                    # instead would skip x5u itself and route every later
                    # header (x5t, kid, typ, ...) here, dropping its value.
                    self._dict["x5u"] = _val
                    try:
                        _spec = load_x509_cert(_val, self.httpc, {})
                        self._jwk = RSAKey(pub_key=_spec["rsa"]).to_dict()
                    except Exception as err:
                        raise ValueError("x5u") from err
                else:
                    self._dict[key] = _val

    def _set_jwk(self, val):
        """Validate *val* as a JWK and store its dict form under "jwk".

        :param val: A dict, a JSON string or a JWK instance.
        :raises ValueError: If *val* is of any other type.
        """
        if isinstance(val, dict):
            # Parsed only to verify it is a real JWK; the dict is stored.
            key_from_jwk_dict(val)
            self._dict["jwk"] = val
        elif isinstance(val, str):
            # verify that it's a real JWK
            _val = json.loads(val)
            key_from_jwk_dict(_val)
            self._dict["jwk"] = _val
        elif isinstance(val, JWK):
            self._dict["jwk"] = val.to_dict()
        else:
            raise ValueError(
                "JWK must be a string a JSON object or a JWK instance")

    def __contains__(self, item):
        return item in self._dict

    def __getitem__(self, item):
        return self._dict[item]

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getattr__(self, item):
        """Expose stored header values as attributes.

        ``_dict`` is read via ``self.__dict__`` so that an instance that
        lacks it (e.g. one created by ``copy.copy`` before its state is
        restored) raises AttributeError instead of recursing forever.
        """
        try:
            return self.__dict__["_dict"][item]
        except KeyError:
            raise AttributeError(item) from None

    def keys(self):
        """Return all keys."""
        return list(self._dict.keys())

    def _set_header_jwk(self, header, **kwargs):
        """Put a "jwk" entry into *header*, from the instance or kwargs["jwk"].

        kwargs["jwk"] may be a JWK instance (serialized), a dict (used
        as-is) or a JSON string (parsed and validated).
        """
        if "jwk" in self:
            header["jwk"] = self["jwk"]
        else:
            try:
                _jwk = kwargs["jwk"]
            except KeyError:
                pass
            else:
                try:
                    header["jwk"] = _jwk.serialize()  # JWK instance
                except AttributeError:
                    if isinstance(_jwk, dict):
                        header["jwk"] = _jwk  # dictionary
                    else:
                        _d = json.loads(_jwk)  # JSON
                        # Verify that it's a valid JWK
                        key_from_jwk_dict(_d)
                        header["jwk"] = _d

    def headers(self, **kwargs):
        """Return the JWE/JWS header.

        Values passed in *kwargs* take precedence over truthy values stored
        on the instance.

        :raises HeaderError: If the stored kid is not a string.
        """
        _header = self._header.copy()
        for param in self.args:
            try:
                _header[param] = kwargs[param]
            except KeyError:
                try:
                    if self._dict[param]:
                        _header[param] = self._dict[param]
                except KeyError:
                    pass

        self._set_header_jwk(_header, **kwargs)

        if "kid" in self:
            if not isinstance(self["kid"], str):
                raise HeaderError("kid of wrong value type")

        return _header

    def _get_keys(self):
        """Return keys known from the jwk/x5c/x5u header and the jku bundle."""
        _keys = []
        if self._jwk:
            _keys.append(key_from_jwk_dict(self._jwk))
        if self._jwks is not None:
            _keys.extend(self._jwks.keys())
        return _keys

    def alg2keytype(self, alg):
        """Convert an algorithm identifier to a key type identifier."""
        raise NotImplementedError()

    def pick_keys(self, keys, use="", alg=""):
        """
        The assumption is that upper layer has made certain you only get
        keys you can use.

        :param alg: The crypto algorithm
        :param use: What the key should be used for
        :param keys: A list of JWK instances
        :return: A list of JWK instances that fulfill the requirements
        :raises ValueError: If alg2keytype() does not recognise the algorithm.
        """
        if not alg:
            alg = self["alg"]

        if alg == "none":
            return []

        _k = self.alg2keytype(alg)
        if _k is None:
            LOGGER.error("Unknown algorithm '%s'", alg)
            raise ValueError("Unknown cryptography algorithm")

        LOGGER.debug("Picking key by key type=%s", _k)
        # Key types may be stored in either case, as str or bytes.
        _kty = [
            _k.lower(),
            _k.upper(),
            _k.lower().encode("utf-8"),
            _k.upper().encode("utf-8"),
        ]
        _keys = [k for k in keys if k.kty in _kty]
        try:
            _kid = self["kid"]
        except KeyError:
            try:
                _kid = self.jwt.headers["kid"]
            except (AttributeError, KeyError):
                _kid = None

        LOGGER.debug("Picking key based on alg=%s, kid=%s and use=%s",
                     alg, _kid, use)

        pkey = []
        for _key in _keys:
            LOGGER.debug("Picked: kid:%s, use:%s, kty:%s",
                         _key.kid, _key.use, _key.kty)
            if _kid:
                if _kid != _key.kid:
                    continue

            # A key without use/alg is not excluded by that criterion.
            if use and _key.use and _key.use != use:
                continue

            if alg and _key.alg and _key.alg != alg:
                continue

            pkey.append(_key)

        return pkey

    def _pick_alg(self, keys):
        """Return the alg header, filling it in when it is missing.

        Falls back to the alg of the single key in *keys* (if exactly one
        key is given and it has an alg), and finally to "none".
        """
        alg = None
        try:
            alg = self["alg"]
        except KeyError:
            # try to get alg from key if there is only one
            if keys is not None and len(keys) == 1:
                key = next(
                    iter(keys))  # first element from either list or dict
                if key.alg:
                    self["alg"] = alg = key.alg

        if not alg:
            self["alg"] = alg = "none"

        return alg

    def _decode(self, payload):
        """Decode *payload* with b64d; JSON-parse it when cty is "JWT"."""
        _msg = b64d(as_bytes(payload))
        if "cty" in self:
            if self["cty"] == "JWT":
                _msg = json.loads(as_unicode(_msg))
        return _msg

    def dump_header(self):
        """Return all attributes with values."""
        return {x: self._dict[x] for x in self.args if x in self._dict}
def test_jwks_url():
    """A live remote JWKS (network required) yields a non-empty bundle."""
    remote_source = "https://login.salesforce.com/id/keys"
    bundle = KeyBundle(source=remote_source)
    bundle.update()  # forces a read from the network
    assert len(bundle)
def test_no_use(self):
    """KeyJar.get_encrypt_key("RSA", owner) returns keys for the JWK0 bundle.

    Presumably JWK0's keys carry no explicit 'use' (per the test name) --
    confirm against the JWK0 fixture.
    """
    owner = "abcdefgh"
    bundle = KeyBundle(JWK0["keys"])
    key_jar = KeyJar()
    key_jar.add_kb(owner, bundle)
    assert key_jar.get_encrypt_key("RSA", owner) != []