if args.rsa_file:
        keys.append(RSAKey(key=import_rsa_key(args.rsa_file), kid=_kid))
    if args.hmac_key:
        keys.append(SYMKey(key=args.hmac_key, kid=_kid))

    if args.jwk:
        _key = key_from_jwk_dict(open(args.jwk).read())
        keys.append(_key)

    if args.jwks:
        _k = KeyJar()
        _k.import_jwks(open(args.jwks).read(), "")
        keys.extend(_k.issuer_keys(""))

    if args.jwks_url:
        _kb = KeyBundle(source=args.jwks_url)
        keys.extend(_kb.get())

    if not args.msg:  # If nothing specified assume stdin
        message = sys.stdin.read()
    elif args.msg == "-":
        message = sys.stdin.read()
    else:
        if os.path.isfile(args.msg):
            message = open(args.msg).read().strip("\n")
        else:
            message = args.msg

    message = message.strip()
    message = message.strip('"')
    main(message, keys, args.quiet)
    def test_get_audience_and_algorithm_default_alg(self, entity):
        """Check which algorithm ClientSecretJWT signs the assertion with.

        With an RSA key in the key jar the assertion is expected to be signed
        with RS256 by default, with RS512 once the client preferences ask for
        it, and with RS256 when only the provider's supported list
        (ES256, RS256) is available and no EC key exists.
        """
        context = entity.client_get("service_context")
        context.token_endpoint = "https://example.com/token"

        context.provider_info = {
            'issuer': 'https://example.com/',
            'token_endpoint': "https://example.com/token"
        }

        context.registration_response = {
            'token_endpoint_auth_signing_alg': "HS256"
        }

        assertion_builder = ClientSecretJWT()
        first_request = AccessTokenRequest()

        context.registration_response = {}

        token_service = entity.client_get("service", 'accesstoken')

        # Register an RSA key so the default algorithm can be handled.
        rsa_bundle = KeyBundle()
        rsa_key = new_rsa_key()
        rsa_bundle.append(rsa_key)
        context.keyjar.add_kb("", rsa_bundle)

        def check_signature(req, expected_alg):
            # Build the client assertion into ``req`` and verify its header.
            assertion_builder.construct(req,
                                        service=token_service,
                                        authn_endpoint='token_endpoint')
            header = factory(req["client_assertion"]).jwt.headers
            assert header["alg"] == expected_alg
            assert header["kid"] == rsa_key.kid

        # Default: an RSA key is available, so construction succeeds.
        check_signature(first_request, "RS256")

        # Client preferences decide when present.
        context.client_preferences = {
            "token_endpoint_auth_signing_alg": "RS512"
        }
        check_signature(AccessTokenRequest(), "RS512")

        # With no preferences, fall back to the provider information.
        context.client_preferences = {}
        context.provider_info[
            "token_endpoint_auth_signing_alg_values_supported"] = [
                "ES256", "RS256"
            ]
        # RS256 rather than ES256, since there is no EC key.
        check_signature(AccessTokenRequest(), "RS256")
def test_local_jwk():
    """A KeyBundle can be sourced from a local JWK file via a file:// URL."""
    jwk_file = full_path('jwk_private_key.json')
    bundle = KeyBundle(source=f'file://{jwk_file}')
    assert bundle
Ejemplo n.º 4
0
def test_thumbprint():
    """Every key of JWKS_DICT has a SHA-256 thumbprint listed in EXPECTED."""
    bundle = KeyBundle(JWKS_DICT)
    for jwk in bundle:
        assert jwk.thumbprint("SHA-256") in EXPECTED
Ejemplo n.º 5
0
@pytest.mark.parametrize("keytype,alg,enc", [
    ('RSA', 'RSA1_5', 'A128CBC-HS256'),
    ('EC', 'ECDH-ES', 'A128GCM'),
])
def test_to_jwe(keytype, alg, enc):
    """Reading a JWE with a mismatching encalg or encenc raises HeaderError."""
    message = Message(a='foo', b='bar', c='tjoho')
    encryption_keys = KEYJAR.get_encrypt_key(keytype, '')
    token = message.to_jwe(encryption_keys, alg=alg, enc=enc)

    # Wrong key-management algorithm.
    with pytest.raises(HeaderError):
        Message().from_jwt(token, KEYJAR, encalg="RSA-OAEP", encenc=enc)

    # Wrong content-encryption algorithm.
    with pytest.raises(HeaderError):
        Message().from_jwt(token, KEYJAR, encenc="A256CBC-HS512", encalg=alg)


# A copy of KEYJAR with one extra, freshly generated RSA key that is added
# only to the copy; NEW_KID is that key's kid.
NEW_KEYJAR = KEYJAR.copy()
k = new_rsa_key()
kb = KeyBundle()
kb.append(k)
NEW_KEYJAR.add_kb('', kb)
NEW_KID = k.kid


def test_no_suitable_keys():
    """A JWT signed with a key that only NEW_KEYJAR holds cannot be verified
    against KEYJAR: NoSuitableSigningKeys is raised."""
    signing_key = NEW_KEYJAR.get_signing_key('RSA', '', kid=NEW_KID)
    token = Message(a='foo', b='bar', c='tjoho').to_jwt(signing_key, 'RS256')
    with pytest.raises(NoSuitableSigningKeys):
        Message().from_jwt(token, KEYJAR)
Ejemplo n.º 6
0
def test_with_sym_key():
    """A bundle built from one symmetric JWK has one oct key, no RSA keys,
    is not remote and has no source."""
    sym_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    bundle = KeyBundle(sym_jwk)

    assert len(bundle.get("oct")) == 1
    assert len(bundle.get("rsa")) == 0
    assert bundle.remote is False
    assert bundle.source is None
Ejemplo n.º 7
0
def test_local_jwk_copy():
    """Copying a file-backed KeyBundle keeps the same source."""
    original = KeyBundle(source="file://" + full_path("jwk_private_key.json"))
    duplicate = original.copy()
    assert duplicate.source == original.source
from oidcmsg.oauth2 import RefreshAccessTokenRequest
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oidc import IdToken
from oidcmsg.time_util import utc_time_sans_frac
from oidcservice.exception import OidcServiceError
from oidcservice.exception import ParseError

from oidcrp.oauth2 import Client

# Put the current working directory first on the import path.
sys.path.insert(0, '.')

# Key material for these tests lives in <this module's dir>/data/keys.
_dirname = os.path.dirname(os.path.abspath(__file__))
BASE_PATH = os.path.join(_dirname, "data", "keys")

# Private RSA key read from data/keys/rsa.key, wrapped in a signing KeyBundle.
_key = import_private_rsa_key_from_file(os.path.join(BASE_PATH, "rsa.key"))
KC_RSA = KeyBundle({"priv_key": _key, "kty": "RSA", "use": "sig"})

CLIENT_ID = "client_1"
# ID token for CLIENT_ID expiring 86400 seconds (one day) after import time.
# NOTE(review): iat comes from time.time() (a float) while exp is built from
# utc_time_sans_frac(); confirm the mixed representation is intended.
IDTOKEN = IdToken(iss="http://oidc.example.org/",
                  sub="sub",
                  aud=CLIENT_ID,
                  exp=utc_time_sans_frac() + 86400,
                  nonce="N0nce",
                  iat=time.time())


class MockResponse:
    """Minimal stand-in for an HTTP response object.

    :param status_code: HTTP status code to report.
    :param text: Response body.
    :param headers: Response headers; any falsy value becomes a new empty dict.
    """

    def __init__(self, status_code, text, headers=None):
        self.status_code = status_code
        self.text = text
        self.headers = headers if headers else {}
from oidcmsg.oidc.session import EndSessionRequest
from oidcmsg.oidc.session import EndSessionResponse
from oidcmsg.time_util import utc_time_sans_frac

CLIENT_ID = "client_1"
ISS = 'https://example.com'

# ID token issued by ISS for CLIENT_ID, expiring 300 seconds after import.
IDTOKEN = IdToken(iss=ISS,
                  sub="sub",
                  aud=CLIENT_ID,
                  exp=utc_time_sans_frac() + 300,
                  nonce="N0nce",
                  iat=time.time())

# Symmetric HS256 signing key (the key value is given as bytes).
KC_SYM_S = KeyBundle({
    "kty": "oct",
    "key": b"abcdefghijklmnop",
    "use": "sig",
    "alg": "HS256",
})

NOW = utc_time_sans_frac()

# Key definitions: one P-256 EC key for signing and one for encryption.
KEYDEF = [
    {"type": "EC", "crv": "P-256", "use": ["sig"]},
    {"type": "EC", "crv": "P-256", "use": ["enc"]},
]
Ejemplo n.º 10
0
 def test_no_use(self):
     """An issuer loaded with JWK0's keys returns a non-empty result when
     asked for an RSA encryption key (presumably those keys carry no "use")."""
     issuer = KeyIssuer()
     issuer.add_kb(KeyBundle(JWK0["keys"]))
     assert issuer.get("enc", "RSA") != []
Ejemplo n.º 11
0
def test_contains():
    """Every key reported by all_keys() is a member of the issuer."""
    issuer = KeyIssuer()
    issuer.add_kb(KeyBundle(JWK1["keys"]))
    assert all(key in issuer for key in issuer.all_keys())
Ejemplo n.º 12
0
__author__ = 'Roland Hedberg'

ALICE = 'https://example.org/alice'
BOB = 'https://example.com/bob'
BASEDIR = os.path.abspath(os.path.dirname(__file__))


def full_path(local_file):
    return os.path.join(BASEDIR, local_file)


# k1 = import_private_rsa_key_from_file(full_path('rsa.key'))
# k2 = import_private_rsa_key_from_file(full_path('size2048.key'))

# File-backed key bundles: kb1 (rsa.key, keyusage 'sig', kid '1') and
# kb2 (size2048.key, keyusage 'enc', kid '2').
# NOTE(review): the .key files are read with fileformat='der'; confirm they
# really are DER-encoded.
kb1 = KeyBundle(source='file://{}'.format(full_path('rsa.key')),
                fileformat='der',
                keyusage='sig',
                kid='1')
kb2 = KeyBundle(source='file://{}'.format(full_path('size2048.key')),
                fileformat='der',
                keyusage='enc',
                kid='2')

# Alice's key jar holds both bundles under her issuer id (ALICE).
ALICE_KEY_JAR = KeyJar()
ALICE_KEY_JAR.add_kb(ALICE, kb1)
ALICE_KEY_JAR.add_kb(ALICE, kb2)

# Another encryption bundle (server.key, kid '3'); this section does not add
# it to ALICE_KEY_JAR.
kb3 = KeyBundle(source='file://{}'.format(full_path('server.key')),
                fileformat='der',
                keyusage='enc',
                kid='3')
def test_remote_not_modified():
    """A 200 fetch followed by a 304 fetch of a remote JWKS.

    The 304 makes do_remote() return False and changes time_out. The bundle
    must also survive a dump()/load() round trip with its keys and settings.
    """
    source = "https://example.com/keys.json"
    # NOTE(review): this header dict is replaced by an empty one right away,
    # so Last-Modified handling is not actually exercised by this test.
    headers = {
        "Date": "Fri, 15 Mar 2019 10:14:25 GMT",
        "Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT",
    }
    headers = {}

    httpc_params = {"timeout": (2, 2)}  # (connect, read) timeouts in seconds
    bundle = KeyBundle(source=source,
                       httpc=requests.request,
                       httpc_params=httpc_params)

    # First fetch: the mocked endpoint serves the full JWKS.
    with responses.RequestsMock() as mocked:
        mocked.add(method="GET",
                   url=source,
                   json=JWKS_DICT,
                   status=200,
                   headers=headers)
        assert bundle.do_remote()
        assert bundle.last_remote == headers.get("Last-Modified")
        first_timeout = bundle.time_out

    # Second fetch: the endpoint answers 304 Not Modified.
    with responses.RequestsMock() as mocked:
        mocked.add(method="GET", url=source, status=304, headers=headers)
        assert not bundle.do_remote()
        assert bundle.last_remote == headers.get("Last-Modified")
        second_timeout = bundle.time_out

    assert first_timeout != second_timeout

    restored = KeyBundle().load(bundle.dump())
    assert restored.source == source
    assert len(restored.keys()) == 3
    assert len(restored.active_keys()) == 3
    for key_type in ("rsa", "oct", "ec"):
        assert len(restored.get(key_type)) == 1
    assert restored.httpc_params == {"timeout": (2, 2)}
    assert restored.imp_jwks
    assert restored.last_updated
Ejemplo n.º 14
0
def test_export_inactive():
    """dump() exposes the expected fields, and load() restores both keys
    with only the newly added one still active."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    bundle = KeyBundle([sig_jwk])
    assert len(bundle.keys()) == 1
    for key in bundle.keys():
        bundle.mark_as_inactive(key.kid)

    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}
    bundle.add_jwk_dicts([enc_jwk])

    dumped = bundle.dump()
    expected_fields = {
        "cache_time",
        "etag",
        "fileformat",
        "httpc_params",
        "ignore_errors_until",
        "ignore_errors_period",
        "ignore_invalid_keys",
        "imp_jwks",
        "keys",
        "keytype",
        "keyusage",
        "last_updated",
        "last_remote",
        "last_local",
        "remote",
        "local",
        "source",
        "time_out",
    }
    assert set(dumped) == expected_fields

    reloaded = KeyBundle().load(dumped)
    assert len(reloaded.keys()) == 2
    assert len(reloaded.active_keys()) == 1
Ejemplo n.º 15
0
def test_key_bundle_difference_none():
    """Two bundles holding the same keys have an empty difference."""
    source_bundle = build_key_bundle(key_conf=KEYSPEC_6)
    mirror = KeyBundle()
    mirror.extend(source_bundle.keys())

    assert source_bundle.difference(mirror) == []
Ejemplo n.º 16
0
        "dEtpjbEvbhfgwUI-bdK5xAU_9UQ",
        "kty":
        "RSA",
        "n":
        "x7HNcD9ZxTFRaAgZ7-gdYLkgQua3zvQseqBJIt8Uq3MimInMZoE9QGQeSML7qZPlowb5BUakdLI70ayM4vN36--0ht8-oCHhl8YjGFQkU-Iv2yahWHEP-1EK6eOEYu6INQP9Lk0HMk3QViLwshwb-KXVD02jdmX2HNdYJdPyc0c",
        "use":
        "sig",
        "x5c": [
            "MIICWzCCAcSgAwIBAgIJAL3MzqqEFMYjMA0GCSqGSIb3DQEBBQUAMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleTAeFw0xMzExMTExOTA1MDJaFw0xOTExMTAxOTA1MDJaMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAx7HNcD9ZxTFRaAgZ7+gdYLkgQua3zvQseqBJIt8Uq3MimInMZoE9QGQeSML7qZPlowb5BUakdLI70ayM4vN36++0ht8+oCHhl8YjGFQkU+Iv2yahWHEP+1EK6eOEYu6INQP9Lk0HMk3QViLwshwb+KXVD02jdmX2HNdYJdPyc0cCAwEAAaOBijCBhzAdBgNVHQ4EFgQULR0aj9AtiNMgqIY8ZyXZGsHcJ5gwWQYDVR0jBFIwUIAULR0aj9AtiNMgqIY8ZyXZGsHcJ5ihLaQrMCkxJzAlBgNVBAMTHkxpdmUgSUQgU1RTIFNpZ25pbmcgUHVibGljIEtleYIJAL3MzqqEFMYjMAsGA1UdDwQEAwIBxjANBgkqhkiG9w0BAQUFAAOBgQBshrsF9yls4ArxOKqXdQPDgHrbynZL8m1iinLI4TeSfmTCDevXVBJrQ6SgDkihl3aCj74IEte2MWN78sHvLLTWTAkiQSlGf1Zb0durw+OvlunQ2AKbK79Qv0Q+wwGuK+oymWc3GSdP1wZqk9dhrQxb3FtdU2tMke01QTut6wr7ig=="
        ],
        "x5t":
        "dEtpjbEvbhfgwUI-bdK5xAU_9UQ"
    }]
}

# KeyBundle built from the JWKS_b key set defined above.
SIGJWKS = KeyBundle(JWKS_b)


def P256():
    """Generate a new EC private key on the SECP256R1 (P-256) curve."""
    curve = ec.SECP256R1()
    return ec.generate_private_key(curve, default_backend())


def test_1():
    """Produce a compact JWS of a JWT claim set with alg 'none'."""
    claims = {
        "iss": "joe",
        "exp": 1300819380,
        "http://example.com/is_root": True,
    }

    unsecured = JWS(claims, cty="JWT", alg='none')
    compact = unsecured.sign_compact()
Ejemplo n.º 17
0
def test_export_inactive():
    """dump() exposes the expected fields, and load() restores both keys
    with only the newly added one still active."""
    bundle = KeyBundle([{"kty": "oct", "key": "highestsupersecret", "use": "sig"}])
    assert len(bundle.keys()) == 1
    for key in bundle.keys():
        bundle.mark_as_inactive(key.kid)
    bundle.do_keys([{"kty": "oct", "key": "highestsupersecret", "use": "enc"}])

    dumped = bundle.dump()
    assert set(dumped) == {
        "cache_time",
        "fileformat",
        "httpc_params",
        "imp_jwks",
        "keys",
        "last_updated",
        "last_remote",
        "last_local",
        "remote",
        "local",
        "time_out",
    }

    reloaded = KeyBundle().load(dumped)
    assert len(reloaded.keys()) == 2
    assert len(reloaded.active_keys()) == 1
Ejemplo n.º 18
0
def test_pick_use():
    """pick_keys returns exactly one JWK_b key for this kid and use 'sig'."""
    bundle = KeyBundle(JWK_b)
    signer = JWS("foobar", alg="RS256", kid="MnC_VZcATfM5pOYiJHMba9goEKY")
    picked = signer.pick_keys(bundle, use="sig")
    assert len(picked) == 1
Ejemplo n.º 19
0
def test_unknown_source():
    """A source string that is not a recognised location is rejected."""
    bogus_source = "foobar"
    with pytest.raises(ImportError):
        KeyBundle(source=bogus_source)
Ejemplo n.º 20
0
def test_pick_wrong_alg():
    """pick_keys raises ValueError for the (invalid) algorithm name EC256."""
    bundle = KeyBundle(JWKS_b)
    signer = JWS("foobar", alg="EC256", kid="rsa1")
    with pytest.raises(ValueError):
        signer.pick_keys(bundle, use="sig")
Ejemplo n.º 21
0
def test_loads_0():
    """Loading JWK0 yields a single RSA key whose kid is 'abc'."""
    bundle = KeyBundle(JWK0)
    assert len(bundle) == 1

    rsa_key = bundle.get("rsa")[0]
    assert rsa_key.kid == "abc"
    assert rsa_key.kty == "RSA"
Ejemplo n.º 22
0
class JWx:
    """A basic class with the commonalities between the JWS and JWE classes.

    :param alg: The signing algorithm
    :param jku: a URI that refers to a resource for a set of JSON-encoded
        public keys, one of which corresponds to the key used to digitally
        sign the JWS
    :param jwk: A JSON Web Key that corresponds to the key used to
        digitally sign the JWS
    :param x5u: a URI that refers to a resource for the X.509 public key
        certificate or certificate chain [RFC5280] corresponding to the key
        used to digitally sign the JWS.
    :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
        DER encoding of the X.509 certificate [RFC5280] corresponding to
        the key used to digitally sign the JWS.
    :param x5c: the X.509 public key certificate or certificate chain
        corresponding to the key used to digitally sign the JWS.
    :param kid: a hint indicating which key was used to secure the JWS.
    :param typ: the type of this object. 'JWS' == JWS Compact Serialization
        'JWS+JSON' == JWS JSON Serialization
    :param cty: the type of the secured content
    :param crit: indicates which extensions that are being used and MUST
        be understood and processed.
    :param kwargs: Extra header parameters
    :return: A class instance
    """

    # Header parameter names picked out of the constructor's kwargs.
    args = [
        "alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit"
    ]

    def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
        """Store the message and the recognised header parameters.

        :param msg: The message to be signed/encrypted.
        :param with_digest: Stored on the instance as ``with_digest``.
        :param httpc: HTTP client callable; defaults to ``requests.request``.
            Used when fetching ``jku`` key sets and ``x5u`` certificates.
        :param kwargs: Header parameters; only names listed in ``args`` are
            used.
        :raises ValueError: if an ``x5u`` certificate cannot be loaded, or a
            ``jwk`` value is of an unsupported type.
        """
        self.msg = msg

        self._dict = {}
        self.with_digest = with_digest
        if httpc:
            self.httpc = httpc
        else:
            self.httpc = requests.request

        self.jwt = None
        self._jwk = None
        self._jwks = None
        self._header = {}

        if kwargs:
            for key in self.args:
                try:
                    _val = kwargs[key]
                except KeyError:
                    continue

                if key == "jwk":
                    self._set_jwk(_val)
                    self._jwk = self._dict["jwk"]
                elif key == "x5c":
                    self._dict["x5c"] = _val
                    _pub_key = import_rsa_key(_val)
                    self._jwk = RSAKey(pub_key=_pub_key).to_dict()
                elif key == "jku":
                    self._jwks = KeyBundle(source=_val, httpc=self.httpc)
                    self._dict["jku"] = _val
                elif key == "x5u":
                    # Compare the parameter *name*: testing `"x5u" in self`
                    # would skip storing x5u here and then swallow every later
                    # header parameter (kid, typ, cty, ...).
                    self._dict["x5u"] = _val
                    try:
                        _spec = load_x509_cert(_val, self.httpc, {})
                        self._jwk = RSAKey(pub_key=_spec["rsa"]).to_dict()
                    except Exception as err:
                        # ca_chain = load_x509_cert_chain(self["x5u"])
                        raise ValueError("x5u") from err
                else:
                    self._dict[key] = _val

    def _set_jwk(self, val):
        """Validate ``val`` as a JWK and store its dict form under "jwk".

        :param val: A dict, a JSON string, or a JWK instance.
        :raises ValueError: if ``val`` is none of those types.
        """
        if isinstance(val, dict):
            # Called only to validate the dict as a JWK.
            key_from_jwk_dict(val)
            self._dict["jwk"] = val
        elif isinstance(val, str):
            # verify that it's a real JWK
            _val = json.loads(val)
            key_from_jwk_dict(_val)
            self._dict["jwk"] = _val
        elif isinstance(val, JWK):
            self._dict["jwk"] = val.to_dict()
        else:
            raise ValueError(
                "JWK must be a string a JSON object or a JWK instance")

    def __contains__(self, item):
        return item in self._dict

    def __getitem__(self, item):
        return self._dict[item]

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getattr__(self, item):
        """Expose stored header parameters as attributes.

        Only called when normal attribute lookup fails. ``_dict`` is read
        through the instance ``__dict__`` so that an instance created without
        ``__init__`` (as copy/pickle do) raises AttributeError instead of
        recursing forever.
        """
        _dict = self.__dict__.get("_dict")
        if _dict is None:
            raise AttributeError(item)
        try:
            return _dict[item]
        except KeyError:
            raise AttributeError(item) from None

    def keys(self):
        """Return all keys."""
        return list(self._dict.keys())

    def _set_header_jwk(self, header, **kwargs):
        """Put a "jwk" entry into ``header``.

        A stored "jwk" wins; otherwise ``kwargs["jwk"]`` is used. It may be a
        JWK instance (serialized), a dict, or a JSON string (parsed and
        validated).
        """
        if "jwk" in self:
            header["jwk"] = self["jwk"]
        else:
            try:
                _jwk = kwargs["jwk"]
            except KeyError:
                pass
            else:
                try:
                    header["jwk"] = _jwk.serialize()  # JWK instance
                except AttributeError:
                    if isinstance(_jwk, dict):
                        header["jwk"] = _jwk  # dictionary
                    else:
                        _d = json.loads(_jwk)  # JSON
                        # Verify that it's a valid JWK
                        key_from_jwk_dict(_d)
                        header["jwk"] = _d

    def headers(self, **kwargs):
        """Return the JWE/JWS header.

        For each name in ``args``, a value in ``kwargs`` wins over a truthy
        stored value.

        :raises HeaderError: if the stored "kid" is not a string.
        """
        _header = self._header.copy()
        for param in self.args:
            try:
                _header[param] = kwargs[param]
            except KeyError:
                try:
                    if self._dict[param]:
                        _header[param] = self._dict[param]
                except KeyError:
                    pass

        self._set_header_jwk(_header, **kwargs)

        if "kid" in self:
            if not isinstance(self["kid"], str):
                raise HeaderError("kid of wrong value type")

        return _header

    def _get_keys(self):
        """Return the keys from the "jwk" and "jku" header parameters."""
        _keys = []
        if self._jwk:
            _keys.append(key_from_jwk_dict(self._jwk))
        if self._jwks is not None:
            _keys.extend(self._jwks.keys())
        return _keys

    def alg2keytype(self, alg):
        """Convert an algorithm identifier to a key type identifier."""
        raise NotImplementedError()

    def pick_keys(self, keys, use="", alg=""):
        """
        The assumption is that upper layer has made certain you only get
        keys you can use.

        :param alg: The crypto algorithm
        :param use: What the key should be used for
        :param keys: A list of JWK instances
        :return: A list of JWK instances that fulfill the requirements
        :raises ValueError: if ``alg`` maps to no known key type.
        """
        if not alg:
            alg = self["alg"]

        if alg == "none":
            return []

        _k = self.alg2keytype(alg)
        if _k is None:
            LOGGER.error("Unknown algorithm '%s'", alg)
            raise ValueError("Unknown cryptography algorithm")

        LOGGER.debug("Picking key by key type=%s", _k)
        # Accept the key type in either case, as str or bytes.
        _kty = [
            _k.lower(),
            _k.upper(),
            _k.lower().encode("utf-8"),
            _k.upper().encode("utf-8"),
        ]
        _keys = [k for k in keys if k.kty in _kty]
        try:
            _kid = self["kid"]
        except KeyError:
            try:
                _kid = self.jwt.headers["kid"]
            except (AttributeError, KeyError):
                _kid = None

        LOGGER.debug("Picking key based on alg=%s, kid=%s and use=%s", alg,
                     _kid, use)

        pkey = []
        for _key in _keys:
            LOGGER.debug("Picked: kid:%s, use:%s, kty:%s", _key.kid, _key.use,
                         _key.kty)
            if _kid:
                if _kid != _key.kid:
                    continue

            # A key without "use"/"alg" is not excluded by that criterion.
            if use and _key.use and _key.use != use:
                continue

            if alg and _key.alg and _key.alg != alg:
                continue

            pkey.append(_key)

        return pkey

    def _pick_alg(self, keys):
        """Return the algorithm to use, storing it under "alg".

        Falls back to the alg of a single supplied key, then to "none".
        """
        alg = None
        try:
            alg = self["alg"]
        except KeyError:
            # try to get alg from key if there is only one
            if keys is not None and len(keys) == 1:
                key = next(
                    iter(keys))  # first element from either list or dict
                if key.alg:
                    self["alg"] = alg = key.alg

        if not alg:
            self["alg"] = alg = "none"

        return alg

    def _decode(self, payload):
        """base64url-decode ``payload``; JSON-parse it when cty is "JWT"."""
        _msg = b64d(as_bytes(payload))
        if "cty" in self:
            if self["cty"] == "JWT":
                _msg = json.loads(as_unicode(_msg))
        return _msg

    def dump_header(self):
        """Return all attributes with values."""
        return {x: self._dict[x] for x in self.args if x in self._dict}
Ejemplo n.º 23
0
def test_jwks_url():
    """Fetch a live JWKS from a remote URL (needs network access)."""
    remote_bundle = KeyBundle(source="https://login.salesforce.com/id/keys")
    remote_bundle.update()  # forces a read from the network
    assert len(remote_bundle)
Ejemplo n.º 24
0
 def test_no_use(self):
     """A key jar holding JWK0's keys under "abcdefgh" returns a non-empty
     result for an RSA encryption key (presumably the keys have no "use")."""
     issuer_id = "abcdefgh"
     key_jar = KeyJar()
     key_jar.add_kb(issuer_id, KeyBundle(JWK0["keys"]))
     assert key_jar.get_encrypt_key("RSA", issuer_id) != []