Exemple #1
0
    def test_function_key(self):
        """Sign and verify with a callable that picks the key per header.

        Two signature entries share one protected header and differ only in
        their unprotected ``kid``; the callable returns a secret per ``kid``.
        """
        shared_protected = {'alg': 'HS256'}
        entries = [
            {'protected': shared_protected, 'header': {'kid': kid}}
            for kid in ('a', 'b')
        ]

        def pick_key(header, payload):
            # The callable must be handed the payload being processed.
            self.assertEqual(payload, b'hello')
            return 'secret-a' if header.get('kid') == 'a' else 'secret-b'

        jws = JWS(algorithms=JWS_ALGORITHMS)
        serialized = jws.serialize(entries, b'hello', pick_key)
        self.assertIsInstance(serialized, dict)
        self.assertIn('signatures', serialized)

        result = jws.deserialize(json.dumps(serialized), pick_key)
        out_header, out_payload = result['header'], result['payload']
        self.assertEqual(out_payload, b'hello')
        self.assertEqual(out_header[0]['alg'], 'HS256')
        self.assertNotIn('signature', result)
Exemple #2
0
 def test_rsa_encode_decode(self):
     """Round-trip an RS256 token using keys loaded via read_file_path."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     private_key = read_file_path('rsa_private.pem')
     token = jws.encode({'alg': 'RS256'}, 'hello', private_key)

     public_key = read_file_path('rsa_public.pem')
     out_header, out_payload = jws.decode(token, public_key)
     self.assertEqual(out_payload, b'hello')
     self.assertEqual(out_header['alg'], 'RS256')
Exemple #3
0
 def test_compact_jws(self):
     """Round-trip an HS256 token serialized from a plain header dict."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     token = jws.serialize({'alg': 'HS256'}, 'hello', 'secret')

     result = jws.deserialize(token, 'secret')
     out_header, out_payload = result['header'], result['payload']
     self.assertEqual(out_payload, b'hello')
     self.assertEqual(out_header['alg'], 'HS256')
     # The deserialized result must not carry a 'signature' entry.
     self.assertNotIn('signature', result)
Exemple #4
0
 def test_compact_rsa(self):
     """Round-trip an RS256 token through serialize/deserialize."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     private_key = read_file_path('rsa_private.pem')
     token = jws.serialize({'alg': 'RS256'}, 'hello', private_key)

     public_key = read_file_path('rsa_public.pem')
     result = jws.deserialize(token, public_key)
     out_header, out_payload = result['header'], result['payload']
     self.assertEqual(out_payload, b'hello')
     self.assertEqual(out_header['alg'], 'RS256')
Exemple #5
0
 def test_validate_header(self):
     """Unknown protected header names fail unless listed as private."""
     strict_jws = JWS(algorithms=JWS_ALGORITHMS)
     header = {
         'protected': {'alg': 'HS256', 'invalid': 'k'},
         'header': {'kid': 'a'},
     }
     with self.assertRaises(errors.InvalidHeaderParameterName):
         strict_jws.serialize(header, b'hello', 'secret')

     # Declaring the name as a private header makes the same input valid.
     lenient_jws = JWS(
         algorithms=JWS_ALGORITHMS,
         private_headers=['invalid'],
     )
     serialized = lenient_jws.serialize(header, b'hello', 'secret')
     self.assertIsInstance(serialized, dict)
Exemple #6
0
    def test_flattened_json_jws(self):
        """Round-trip a single protected/unprotected header pair as a dict."""
        jws = JWS(algorithms=JWS_ALGORITHMS)
        entry = {'protected': {'alg': 'HS256'}, 'header': {'kid': 'a'}}
        serialized = jws.serialize(entry, 'hello', 'secret')
        self.assertIsInstance(serialized, dict)

        result = jws.deserialize(serialized, 'secret')
        out_header, out_payload = result['header'], result['payload']
        self.assertEqual(out_payload, b'hello')
        self.assertEqual(out_header['alg'], 'HS256')
        # The deserialized result must not carry a 'protected' entry.
        self.assertNotIn('protected', result)
Exemple #7
0
    def __init__(self, algorithms=None, private_headers=None):
        """Create the JWS and JWE backends and register algorithms.

        :param algorithms: ``None`` to use the default JWS/JWE algorithm
            sets, a list/tuple of algorithms, or a single algorithm given
            as text; any other value registers nothing
        :param private_headers: passed through to both ``JWS`` and ``JWE``
        """
        if algorithms is None:
            jws_defaults, jwe_defaults = JWS_ALGORITHMS, JWE_ALGORITHMS
        else:
            jws_defaults = jwe_defaults = None

        self._jws = JWS(jws_defaults, private_headers)
        self._jwe = JWE(jwe_defaults, private_headers)

        # Normalise the explicit selection to an iterable of algorithms.
        if isinstance(algorithms, (tuple, list)):
            selected = algorithms
        elif algorithms is not None and isinstance(algorithms, text_types):
            selected = [algorithms]
        else:
            selected = ()

        for algorithm in selected:
            self.register_algorithm(algorithm)
Exemple #8
0
 def test_register_invalid_algorithms(self):
     """Registering a JWE algorithm on a JWS instance raises ValueError."""
     jws = JWS(algorithms=[])
     jwe_algorithm = JWE_ALGORITHMS[0]
     with self.assertRaises(ValueError):
         jws.register_algorithm(jwe_algorithm)
Exemple #9
0
    def test_nested_json_jws(self):
        """Round-trip a list of header entries and reject a wrong key."""
        jws = JWS(algorithms=JWS_ALGORITHMS)
        entry = {'protected': {'alg': 'HS256'}, 'header': {'kid': 'a'}}
        serialized = jws.serialize([entry], 'hello', 'secret')
        self.assertIsInstance(serialized, dict)
        self.assertIn('signatures', serialized)

        result = jws.deserialize(serialized, 'secret')
        out_header, out_payload = result['header'], result['payload']
        self.assertEqual(out_payload, b'hello')
        self.assertEqual(out_header[0]['alg'], 'HS256')
        self.assertNotIn('signatures', result)

        # test bad signature
        with self.assertRaises(errors.BadSignatureError):
            jws.deserialize(serialized, 'f')
Exemple #10
0
 def test_invalid_alg(self):
     """Unknown or missing ``alg`` values are rejected by decode/encode."""
     jws = JWS(algorithms=JWS_ALGORITHMS)

     # The header segment below is base64url for {"alg":"s"}.
     with self.assertRaises(errors.UnsupportedAlgorithmError):
         jws.decode('eyJhbGciOiJzIn0.YQ.YQ', 'k')

     with self.assertRaises(errors.MissingAlgorithmError):
         jws.encode({}, '', 'k')

     with self.assertRaises(errors.UnsupportedAlgorithmError):
         jws.encode({'alg': 's'}, '', 'k')
Exemple #11
0
    def test_fail_deserialize_json(self):
        """Invalid inputs to deserialize_json raise DecodeError."""
        jws = JWS(algorithms=JWS_ALGORITHMS)

        def expect_decode_error(serialized):
            with self.assertRaises(errors.DecodeError):
                jws.deserialize_json(serialized, '')

        for bad_input in (None, '[]', '{}'):
            expect_decode_error(bad_input)

        # missing protected
        expect_decode_error(json.dumps({'payload': 'YQ'}))

        # missing signature
        expect_decode_error(json.dumps({'payload': 'YQ', 'protected': 'YQ'}))
Exemple #12
0
 def test_invalid_input(self):
     """Malformed compact tokens raise DecodeError on decode."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     malformed_tokens = (
         'a',
         'a.b.c',
         'YQ.YQ.YQ',  # header segment decodes to: a
         'W10.a.YQ',  # header segment decodes to: []
         'e30.a.YQ',  # header segment decodes to: {}
         'eyJhbGciOiJzIn0.a.YQ',
         'eyJhbGciOiJzIn0.YQ.a',
     )
     for token in malformed_tokens:
         with self.assertRaises(errors.DecodeError):
             jws.decode(token, 'k')
Exemple #13
0
 def test_bad_signature(self):
     """A token whose signature does not verify raises BadSignatureError."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     # Header segment is base64url for {"alg":"HS256"}.
     forged_token = 'eyJhbGciOiJIUzI1NiJ9.YQ.YQ'
     with self.assertRaises(errors.BadSignatureError):
         jws.deserialize(forged_token, 'k')
Exemple #14
0
 def test_success_encode_decode(self):
     """Round-trip an HS256 token through encode/decode."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     token = jws.encode({'alg': 'HS256'}, 'hello', 'secret')

     out_header, out_payload = jws.decode(token, 'secret')
     self.assertEqual(out_payload, b'hello')
     self.assertEqual(out_header['alg'], 'HS256')
Exemple #15
0
 def test_rsa_encode_decode(self):
     """Round-trip an RS256 token using the RSA key helper functions."""
     jws = JWS(algorithms=JWS_ALGORITHMS)
     private_key = get_rsa_private_key()
     token = jws.encode({'alg': 'RS256'}, 'hello', private_key)

     public_key = get_rsa_public_key()
     out_header, out_payload = jws.decode(token, public_key)
     self.assertEqual(out_payload, b'hello')
     self.assertEqual(out_header['alg'], 'RS256')
Exemple #16
0
class JWT(object):
    """Encode and decode JSON Web Tokens using JWS and JWE backends.

    ``encode`` produces a compact token (JWE when the header contains
    ``enc``, JWS otherwise); ``decode`` picks the backend from the number
    of dot-separated segments in the token.
    """

    # Claim names that are refused outright by ``check_sensitive_data``.
    SENSITIVE_NAMES = ('password', 'token', 'secret', 'secret_key')
    # Thanks to sentry SensitiveDataFilter
    SENSITIVE_VALUES = re.compile(r'|'.join([
        # http://www.richardsramblings.com/regex/credit-card-numbers/
        r'\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b',
        # various private keys
        r'-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----',
        # social security numbers (US)
        r'^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b',
    ]), re.DOTALL)

    def __init__(self, algorithms=None, private_headers=None):
        """Create the JWS and JWE backends and register algorithms.

        :param algorithms: ``None`` to use the default JWS/JWE algorithm
            sets, a list/tuple of algorithms, or a single algorithm given
            as text
        :param private_headers: passed through to both ``JWS`` and ``JWE``
        """
        if algorithms is None:
            self._jws = JWS(JWS_ALGORITHMS, private_headers)
            self._jwe = JWE(JWE_ALGORITHMS, private_headers)
        else:
            # Start with empty backends and register only what was asked.
            self._jws = JWS(None, private_headers)
            self._jwe = JWE(None, private_headers)

            if isinstance(algorithms, (tuple, list)):
                for algorithm in algorithms:
                    self.register_algorithm(algorithm)
            elif isinstance(algorithms, text_types):
                self.register_algorithm(algorithms)

    def register_algorithm(self, algorithm):
        """Register an algorithm with the backend matching its ``TYPE``.

        :param algorithm: an algorithm object exposing ``TYPE``, or the
            text name of an entry in ``_AVAILABLE_ALGORITHMS``
        :raise: ValueError if a text name is not a known algorithm
        """
        if isinstance(algorithm, text_types):
            name = algorithm
            algorithm = _AVAILABLE_ALGORITHMS.get(name)
            if algorithm is None:
                # Fail clearly instead of with AttributeError on None.TYPE.
                raise ValueError('Unknown algorithm: {!r}'.format(name))

        if algorithm.TYPE == 'JWS':
            self._jws.register_algorithm(algorithm)
        elif algorithm.TYPE == 'JWE':
            self._jwe.register_algorithm(algorithm)

    def check_sensitive_data(self, payload):
        """Check if payload contains sensitive information.

        :param payload: mapping of claim names to values
        :raise: InsecureClaimError if a claim name is in
            ``SENSITIVE_NAMES`` or a text value matches ``SENSITIVE_VALUES``
        """
        for k in payload:
            # check claims key name
            if k in self.SENSITIVE_NAMES:
                raise InsecureClaimError(k)

            # check claims values
            v = payload[k]
            if isinstance(v, text_types) and self.SENSITIVE_VALUES.search(v):
                raise InsecureClaimError(k)

    def encode(self, header, payload, key, check=True):
        """Encode a JWT with the given header, payload and key.

        The ``header`` and ``payload`` arguments are not modified; the
        ``typ`` header and timestamp conversions are applied to copies.

        :param header: A dict of JWS header
        :param payload: A dict to be encoded
        :param key: key used to sign the signature
        :param check: check if sensitive data in payload
        :return: JWT
        """
        # Work on shallow copies so the caller's dicts are left untouched.
        header = dict(header)
        payload = dict(payload)

        header['typ'] = 'JWT'

        for k in ['exp', 'iat', 'nbf']:
            # convert datetime into timestamp
            claim = payload.get(k)
            if isinstance(claim, datetime.datetime):
                payload[k] = calendar.timegm(claim.utctimetuple())

        if check:
            self.check_sensitive_data(payload)

        key = load_key(key, header, payload)
        text = to_bytes(json.dumps(payload, separators=(',', ':')))
        # An ``enc`` header selects encryption (JWE) over signing (JWS).
        if 'enc' in header:
            return self._jwe.serialize_compact(header, text, key)
        else:
            return self._jws.serialize_compact(header, text, key)

    def decode(self, s, key, claims_cls=None,
               claims_options=None, claims_params=None):
        """Decode a compact JWT with the given key.

        A token with two dots is handed to the JWS backend and one with
        four dots to the JWE backend.

        :param s: text of JWT
        :param key: key used to verify the signature
        :param claims_cls: class to be used for JWT claims
            (``JWTClaims`` when omitted)
        :param claims_options: `options` parameters for claims_cls
        :param claims_params: `params` parameters for claims_cls
        :return: claims_cls instance
        :raise: DecodeError if the token has neither two nor four dots
        :raise: BadSignatureError (per the original documentation, when
            the signature doesn't match)
        """
        if claims_cls is None:
            claims_cls = JWTClaims

        key_func = create_key_func(key)

        s = to_bytes(s)
        dot_count = s.count(b'.')
        if dot_count == 2:
            data = self._jws.deserialize_compact(s, key_func, decode_payload)
        elif dot_count == 4:
            data = self._jwe.deserialize_compact(s, key_func, decode_payload)
        else:
            raise DecodeError('Invalid input segments length')
        return claims_cls(
            data['payload'], data['header'],
            options=claims_options,
            params=claims_params,
        )