示例#1
0
def test_export_inactive():
    """dump() includes inactive keys and load() restores them as inactive."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}
    expected_fields = {
        "cache_time",
        "fileformat",
        "httpc_params",
        "imp_jwks",
        "keys",
        "last_local",
        "last_remote",
        "last_updated",
        "local",
        "remote",
        "time_out",
    }

    bundle = KeyBundle([sig_jwk])
    assert len(bundle.keys()) == 1
    for jwk in bundle.keys():
        bundle.mark_as_inactive(jwk.kid)

    # Add a second key; it should end up as the only active one.
    bundle.do_keys([enc_jwk])

    dumped = bundle.dump()
    assert set(dumped.keys()) == expected_fields

    restored = KeyBundle().load(dumped)
    assert len(restored.keys()) == 2
    assert len(restored.active_keys()) == 1
示例#2
0
def test_export_inactive():
    """dump() includes inactive keys and load() restores them as inactive."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}
    expected_fields = {
        "cache_time",
        "etag",
        "fileformat",
        "httpc_params",
        "ignore_errors_period",
        "ignore_errors_until",
        "ignore_invalid_keys",
        "imp_jwks",
        "keys",
        "keytype",
        "keyusage",
        "last_local",
        "last_remote",
        "last_updated",
        "local",
        "remote",
        "source",
        "time_out",
    }

    bundle = KeyBundle([sig_jwk])
    assert len(bundle.keys()) == 1
    for jwk in bundle.keys():
        bundle.mark_as_inactive(jwk.kid)

    # Add a second key; it should end up as the only active one.
    bundle.add_jwk_dicts([enc_jwk])

    dumped = bundle.dump()
    assert set(dumped.keys()) == expected_fields

    restored = KeyBundle().load(dumped)
    assert len(restored.keys()) == 2
    assert len(restored.active_keys()) == 1
示例#3
0
def test_mark_as_inactive():
    """An inactive key stays in keys() but is not counted by active_keys()."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}

    bundle = KeyBundle([sig_jwk])
    assert len(bundle.keys()) == 1
    for jwk in bundle.keys():
        bundle.mark_as_inactive(jwk.kid)

    bundle.do_keys([enc_jwk])

    assert len(bundle.keys()) == 2
    assert len(bundle.active_keys()) == 1
示例#4
0
def test_copy():
    """copy() carries over both keys and which of them are active."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}

    original = KeyBundle([sig_jwk])
    assert len(original.keys()) == 1
    for jwk in original.keys():
        original.mark_as_inactive(jwk.kid)
    original.add_jwk_dicts([enc_jwk])

    duplicate = original.copy()
    assert len(duplicate.keys()) == 2
    assert len(duplicate.active_keys()) == 1
示例#5
0
def test_local_jwk_update():
    """last_local stays put while the key file is untouched and moves after a touch."""
    cache_time = 0.1
    key_file = full_path("jwk_private_key.json")
    bundle = KeyBundle(source="file://{}".format(key_file), cache_time=cache_time)
    assert bundle

    bundle.keys()
    first_seen = bundle.last_local
    bundle.keys()
    second_seen = bundle.last_local
    assert first_seen == second_seen  # file not changed

    # Sleep past cache_time (presumably so the bundle re-checks the file),
    # then bump the file's modification time.
    time.sleep(cache_time + 0.1)
    Path(key_file).touch()

    bundle.keys()
    third_seen = bundle.last_local
    assert second_seen != third_seen  # file changed
示例#6
0
def test_outdated():
    """A key inactive for 60 s is removed by remove_outdated(30)."""
    sig_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
    enc_jwk = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}
    bundle = KeyBundle([sig_jwk, enc_jwk])

    stale_key = bundle.keys()[0]
    stale_key.inactive_since = time.time() - 60

    bundle.remove_outdated(30)
    assert len(bundle) == 1
示例#7
0
def test_unique_keys_2():
    """unique_keys() collapses keys that appear in both bundles."""
    _kb0 = build_key_bundle(key_conf=KEYSPEC_6)
    _kb1 = KeyBundle()
    _kb1.extend(_kb0.keys())

    # Build a fresh list instead of extending the one returned by
    # _kb0.keys(): if keys() hands back the bundle's own list (NOTE(review):
    # verify), extending it in place would silently grow _kb0.
    keys = list(_kb0.keys()) + list(_kb1.keys())

    # 6 keys collected, 3 of them unique.
    assert len(unique_keys(keys)) == 3
示例#8
0
def test_jws_verifier_with_multiple_keys():
    """verify_compact accepts each test-vector token and rejects its variants.

    The variants are produced by modify_token() for the listed algorithms;
    each must make verify_compact raise JWKESTException.
    """
    # Set up phase: parse the keys once and share them between verifiers.
    bundle = KeyBundle(json.loads(test_vector.json_pub_keys))
    keys = bundle.keys()

    cases = (
        ("RS256", test_vector.rsa_token,
         ['RS384', 'RS512', 'PS256', 'PS384', 'PS512']),
        ("ES256", test_vector.es256_ecdsa_token,
         ['ES384', 'ES512']),
    )
    for alg, token, other_algs in cases:
        verifier = JWS(alg=alg)
        assert verifier.verify_compact(token, keys)
        for tampered in modify_token(token, other_algs):
            with pytest.raises(JWKESTException):
                verifier.verify_compact(tampered, keys)


def test_remote_not_modified():
    """A 304 reply reports no update, keeps the keys, but still changes time_out.

    The bundle state is then dumped and reloaded to check that the keys and
    the HTTP settings survive the round trip.
    """
    source = "https://example.com/keys.json"
    # No response headers are sent, so headers.get("Last-Modified") is None
    # and last_remote is expected to stay at that value.
    headers = {}

    httpc_params = {"timeout": (2, 2)}  # connect, read timeouts in seconds
    kb = KeyBundle(source=source,
                   httpc=requests.request,
                   httpc_params=httpc_params)

    # First fetch: mocked 200 response carrying the JWKS.
    with responses.RequestsMock() as rsps:
        rsps.add(method="GET",
                 url=source,
                 json=JWKS_DICT,
                 status=200,
                 headers=headers)
        assert kb.do_remote()
        assert kb.last_remote == headers.get("Last-Modified")
        timeout1 = kb.time_out

    # Second fetch: mocked 304 Not Modified; do_remote() reports no update.
    with responses.RequestsMock() as rsps:
        rsps.add(method="GET", url=source, status=304, headers=headers)
        assert not kb.do_remote()
        assert kb.last_remote == headers.get("Last-Modified")
        timeout2 = kb.time_out

    # time_out must still change after the 304 reply.
    assert timeout1 != timeout2

    exp = kb.dump()
    kb2 = KeyBundle().load(exp)
    assert kb2.source == source
    assert len(kb2.keys()) == 3
    assert len(kb2.active_keys()) == 3
    assert len(kb2.get("rsa")) == 1
    assert len(kb2.get("oct")) == 1
    assert len(kb2.get("ec")) == 1
    assert kb2.httpc_params == {"timeout": (2, 2)}
    assert kb2.imp_jwks
    assert kb2.last_updated
示例#10
0
def test_remote():
    """Keys fetched from a mocked remote JWKS survive a dump()/load() round trip."""
    source = "https://example.com/keys.json"
    timeouts = {"timeout": (2, 2)}  # connect, read timeouts in seconds

    # Serve JWKS_DICT from the mocked endpoint while the bundle fetches it.
    with responses.RequestsMock() as rsps:
        rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
        kb = KeyBundle(source=source, httpc=requests.request, httpc_params=timeouts)
        kb._do_remote()

    restored = KeyBundle().load(kb.dump())
    assert restored.source == source
    assert len(restored.keys()) == 3
    for kty in ("rsa", "oct", "ec"):
        assert len(restored.get(kty)) == 1
    assert restored.httpc_params == {"timeout": (2, 2)}
    assert restored.imp_jwks
    assert restored.last_updated
示例#11
0
class JWx:
    """A basic class with the commonalities between the JWS and JWE classes.

    Header parameters named in ``args`` are kept in an internal dict and can
    be read with mapping syntax (``"kid" in jwx``, ``jwx["kid"]``) or as
    attributes (``jwx.kid``).

    :param msg: The message to work on; stored as ``self.msg``.
    :param with_digest: Flag stored as ``self.with_digest``.
    :param httpc: Callable used for HTTP access (handed to ``KeyBundle`` for
        ``jku`` and to ``load_x509_cert`` for ``x5u``). Defaults to
        ``requests.request``.
    :param alg: The signing algorithm
    :param jku: a URI that refers to a resource for a set of JSON-encoded
        public keys, one of which corresponds to the key used to digitally
        sign the JWS
    :param jwk: A JSON Web Key that corresponds to the key used to
        digitally sign the JWS
    :param x5u: a URI that refers to a resource for the X.509 public key
        certificate or certificate chain [RFC5280] corresponding to the key
        used to digitally sign the JWS.
    :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
        DER encoding of the X.509 certificate [RFC5280] corresponding to
        the key used to digitally sign the JWS.
    :param x5c: the X.509 public key certificate or certificate chain
        corresponding to the key used to digitally sign the JWS.
    :param kid: a hint indicating which key was used to secure the JWS.
    :param typ: the type of this object. 'JWS' == JWS Compact Serialization
        'JWS+JSON' == JWS JSON Serialization
    :param cty: the type of the secured content
    :param crit: indicates which extensions that are being used and MUST
        be understood and processed.
    :param kwargs: Extra header parameters; only the names listed in
        ``args`` are picked up by this constructor.
    :raises ValueError: If an ``x5u`` certificate cannot be loaded, or a
        ``jwk`` argument is of an unsupported type.
    """
    args = [
        "alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit"
    ]

    def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
        self.msg = msg

        self._dict = {}
        self.with_digest = with_digest
        if httpc:
            self.httpc = httpc
        else:
            self.httpc = requests.request

        self.jwt = None
        self._jwk = None  # dict form of a key given via jwk / x5c / x5u
        self._jwks = None  # KeyBundle built from jku
        self._header = {}

        # Walk the known header names in ``args`` order.
        for key in self.args:
            try:
                _val = kwargs[key]
            except KeyError:
                continue

            if key == "jwk":
                self._set_jwk(_val)
                self._jwk = self._dict["jwk"]
            elif key == "x5c":
                self._dict["x5c"] = _val
                _pub_key = import_rsa_key(_val)
                self._jwk = RSAKey(pub_key=_pub_key).to_dict()
            elif key == "jku":
                self._jwks = KeyBundle(source=_val, httpc=self.httpc)
                self._dict["jku"] = _val
            elif key == "x5u":
                # Dispatch on the key being processed. Testing
                # ``"x5u" in self`` here would skip the load for x5u itself
                # and fire it while handling the *next* header instead,
                # silently dropping that header's value.
                self._dict["x5u"] = _val
                try:
                    _spec = load_x509_cert(_val, self.httpc, {})
                    self._jwk = RSAKey(pub_key=_spec["rsa"]).to_dict()
                except Exception as err:
                    # ca_chain = load_x509_cert_chain(self["x5u"])
                    raise ValueError("x5u") from err
            else:
                self._dict[key] = _val

    def _set_jwk(self, val):
        """Validate *val* as a JWK and store its dict form under ``"jwk"``.

        :param val: A JWK as a dict, a JSON string, or a ``JWK`` instance.
        :raises ValueError: If *val* is none of the accepted types.
        """
        if isinstance(val, dict):
            key_from_jwk_dict(val)  # verify that it's a real JWK
            self._dict["jwk"] = val
        elif isinstance(val, str):
            _val = json.loads(val)
            key_from_jwk_dict(_val)  # verify that it's a real JWK
            self._dict["jwk"] = _val
        elif isinstance(val, JWK):
            self._dict["jwk"] = val.to_dict()
        else:
            raise ValueError(
                'JWK must be a string a JSON object or a JWK instance')

    def __contains__(self, item):
        """Return True if header parameter *item* is set."""
        return item in self._dict

    def __getitem__(self, item):
        """Return header parameter *item*; raises KeyError if unset."""
        return self._dict[item]

    def __setitem__(self, key, value):
        """Set header parameter *key* to *value*."""
        self._dict[key] = value

    def __getattr__(self, item):
        """Expose header parameters as attributes.

        Only reached when normal attribute lookup fails.

        :raises AttributeError: If *item* is not a stored header parameter.
        """
        # Go through ``__dict__``: on an instance whose __init__ never ran
        # (``__new__``, copy, pickle) ``self._dict`` would re-enter
        # __getattr__ and recurse until RecursionError.
        try:
            return self.__dict__["_dict"][item]
        except KeyError:
            raise AttributeError(item) from None

    def keys(self):
        """Return all keys."""
        return list(self._dict.keys())

    def _set_header_jwk(self, header, **kwargs):
        """Put a ``jwk`` entry into *header* when one is available.

        The instance's own ``jwk`` takes precedence; otherwise
        ``kwargs["jwk"]`` is used, given as a JWK instance, a dict or a JSON
        string.
        """
        if "jwk" in self:
            header["jwk"] = self["jwk"]
            return

        try:
            _jwk = kwargs["jwk"]
        except KeyError:
            return

        try:
            header["jwk"] = _jwk.serialize()  # JWK instance
        except AttributeError:
            if isinstance(_jwk, dict):
                header["jwk"] = _jwk  # dictionary
            else:
                _d = json.loads(_jwk)  # JSON
                key_from_jwk_dict(_d)  # Verify that it's a valid JWK
                header["jwk"] = _d

    def headers(self, **kwargs):
        """Return the JWE/JWS header.

        Values in *kwargs* override the instance's header parameters;
        instance values that are falsy are left out.

        :raises HeaderError: If the instance's ``kid`` is not a string.
        """
        _header = self._header.copy()
        for param in self.args:
            if param in kwargs:
                _header[param] = kwargs[param]
            elif self._dict.get(param):
                _header[param] = self._dict[param]

        self._set_header_jwk(_header, **kwargs)

        if "kid" in self and not isinstance(self["kid"], str):
            raise HeaderError("kid of wrong value type")

        return _header

    def _get_keys(self):
        """Return the keys attached to this instance (jwk/x5c/x5u and jku)."""
        _keys = []
        if self._jwk:
            _keys.append(key_from_jwk_dict(self._jwk))
        if self._jwks is not None:
            _keys.extend(self._jwks.keys())
        return _keys

    def alg2keytype(self, alg):
        """Convert an algorithm identifier to a key type identifier."""
        raise NotImplementedError()

    def pick_keys(self, keys, use="", alg=""):
        """
        The assumption is that upper layer has made certain you only get
        keys you can use.

        :param alg: The crypto algorithm; defaults to this instance's alg
        :param use: What the key should be used for
        :param keys: A list of JWK instances
        :return: A list of JWK instances that fulfill the requirements
        :raises ValueError: If the algorithm maps to no known key type
        """
        if not alg:
            alg = self["alg"]

        if alg == "none":
            return []

        _k = self.alg2keytype(alg)
        if _k is None:
            LOGGER.error("Unknown algorithm '%s'", alg)
            raise ValueError('Unknown cryptography algorithm')

        LOGGER.debug("Picking key by key type=%s", _k)
        # kty may be stored as str or bytes, in either case.
        _kty = [
            _k.lower(),
            _k.upper(),
            _k.lower().encode("utf-8"),
            _k.upper().encode("utf-8")
        ]
        _keys = [k for k in keys if k.kty in _kty]
        try:
            _kid = self["kid"]
        except KeyError:
            try:
                _kid = self.jwt.headers["kid"]
            except (AttributeError, KeyError):
                _kid = None

        LOGGER.debug("Picking key based on alg=%s, kid=%s and use=%s", alg,
                     _kid, use)

        pkey = []
        for _key in _keys:
            LOGGER.debug("Picked: kid:%s, use:%s, kty:%s", _key.kid, _key.use,
                         _key.kty)
            if _kid and _kid != _key.kid:
                continue

            # A key without use/alg is treated as compatible with anything.
            if use and _key.use and _key.use != use:
                continue

            if alg and _key.alg and _key.alg != alg:
                continue

            pkey.append(_key)

        return pkey

    def _pick_alg(self, keys):
        """Return the algorithm to use, recording it under ``"alg"``.

        Falls back to the alg of the only key in *keys*, then to ``"none"``.
        """
        alg = None
        try:
            alg = self["alg"]
        except KeyError:
            # try to get alg from key if there is only one
            if keys is not None and len(keys) == 1:
                key = next(iter(keys))  # first element from either list or dict
                if key.alg:
                    self["alg"] = alg = key.alg

        if not alg:
            self["alg"] = alg = "none"

        return alg

    def _decode(self, payload):
        """Base64url-decode *payload*; JSON-parse it when ``cty`` is ``JWT``."""
        _msg = b64d(as_bytes(payload))
        if "cty" in self:
            if self["cty"] == "JWT":
                _msg = json.loads(as_unicode(_msg))
        return _msg

    def dump_header(self):
        """Return all attributes with values."""
        return {x: self._dict[x] for x in self.args if x in self._dict}