def test_function_key(self):
    """A callable key is resolved per signature from each header's ``kid``."""
    protected = {'alg': 'HS256'}
    headers = [
        {'protected': protected, 'header': {'kid': 'a'}},
        {'protected': protected, 'header': {'kid': 'b'}},
    ]

    def load_key(header, payload):
        # The payload handed to the key loader must be the raw bytes.
        self.assertEqual(payload, b'hello')
        return 'secret-a' if header.get('kid') == 'a' else 'secret-b'

    jws = JWS(algorithms=JWS_ALGORITHMS)

    serialized = jws.serialize(headers, b'hello', load_key)
    self.assertIsInstance(serialized, dict)
    self.assertIn('signatures', serialized)

    result = jws.deserialize(json.dumps(serialized), load_key)
    result_header = result['header']
    result_payload = result['payload']
    self.assertEqual(result_payload, b'hello')
    self.assertEqual(result_header[0]['alg'], 'HS256')
    self.assertNotIn('signature', result)
def test_rsa_encode_decode(self):
    """RS256 round-trip through encode/decode using PEM key files."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    token = jws.encode(
        {'alg': 'RS256'}, 'hello', read_file_path('rsa_private.pem'))
    decoded_header, decoded_payload = jws.decode(
        token, read_file_path('rsa_public.pem'))
    self.assertEqual(decoded_payload, b'hello')
    self.assertEqual(decoded_header['alg'], 'RS256')
def test_compact_jws(self):
    """HS256 compact serialization round-trips and exposes no signature."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    compact = jws.serialize({'alg': 'HS256'}, 'hello', 'secret')

    result = jws.deserialize(compact, 'secret')
    self.assertEqual(result['payload'], b'hello')
    self.assertEqual(result['header']['alg'], 'HS256')
    self.assertNotIn('signature', result)
def test_compact_rsa(self):
    """RS256 compact serialization round-trips with PEM key files."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    compact = jws.serialize(
        {'alg': 'RS256'}, 'hello', read_file_path('rsa_private.pem'))

    result = jws.deserialize(compact, read_file_path('rsa_public.pem'))
    self.assertEqual(result['payload'], b'hello')
    self.assertEqual(result['header']['alg'], 'RS256')
def test_validate_header(self):
    """Unknown header names are rejected unless declared as private headers."""
    header = {
        'protected': {'alg': 'HS256', 'invalid': 'k'},
        'header': {'kid': 'a'},
    }

    strict = JWS(algorithms=JWS_ALGORITHMS)
    with self.assertRaises(errors.InvalidHeaderParameterName):
        strict.serialize(header, b'hello', 'secret')

    # Declaring 'invalid' as a private header makes the same input acceptable.
    permissive = JWS(algorithms=JWS_ALGORITHMS, private_headers=['invalid'])
    self.assertIsInstance(permissive.serialize(header, b'hello', 'secret'), dict)
def test_flattened_json_jws(self):
    """A single protected+unprotected header yields flattened JSON output."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    flattened_header = {'protected': {'alg': 'HS256'}, 'header': {'kid': 'a'}}

    serialized = jws.serialize(flattened_header, 'hello', 'secret')
    self.assertIsInstance(serialized, dict)

    result = jws.deserialize(serialized, 'secret')
    self.assertEqual(result['payload'], b'hello')
    self.assertEqual(result['header']['alg'], 'HS256')
    self.assertNotIn('protected', result)
def __init__(self, algorithms=None, private_headers=None):
    """Build the JWS and JWE serializers.

    :param algorithms: ``None`` enables every algorithm in ``JWS_ALGORITHMS``
        and ``JWE_ALGORITHMS``. Otherwise both serializers start empty and
        only the given algorithm name, or list/tuple of algorithms, is
        registered. Values of any other type register nothing.
    :param private_headers: forwarded unchanged to both JWS and JWE
    """
    use_defaults = algorithms is None
    self._jws = JWS(JWS_ALGORITHMS if use_defaults else None, private_headers)
    self._jwe = JWE(JWE_ALGORITHMS if use_defaults else None, private_headers)
    if use_defaults:
        return

    # Normalise the accepted shapes to a sequence before registering.
    if isinstance(algorithms, (tuple, list)):
        to_register = algorithms
    elif isinstance(algorithms, text_types):
        to_register = [algorithms]
    else:
        to_register = []
    for algorithm in to_register:
        self.register_algorithm(algorithm)
def test_register_invalid_algorithms(self):
    """A JWS serializer refuses to register a JWE algorithm."""
    empty_jws = JWS(algorithms=[])
    jwe_algorithm = JWE_ALGORITHMS[0]
    with self.assertRaises(ValueError):
        empty_jws.register_algorithm(jwe_algorithm)
def test_nested_json_jws(self):
    """A list of headers yields general JSON serialization with 'signatures'."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    entry = {'protected': {'alg': 'HS256'}, 'header': {'kid': 'a'}}

    serialized = jws.serialize([entry], 'hello', 'secret')
    self.assertIsInstance(serialized, dict)
    self.assertIn('signatures', serialized)

    result = jws.deserialize(serialized, 'secret')
    self.assertEqual(result['payload'], b'hello')
    self.assertEqual(result['header'][0]['alg'], 'HS256')
    self.assertNotIn('signatures', result)

    # Verifying with the wrong key must fail.
    with self.assertRaises(errors.BadSignatureError):
        jws.deserialize(serialized, 'f')
def test_invalid_alg(self):
    """Missing or unknown 'alg' values are rejected on encode and decode."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    # Header segment decodes to {"alg":"s"}, an unregistered algorithm.
    with self.assertRaises(errors.UnsupportedAlgorithmError):
        jws.decode('eyJhbGciOiJzIn0.YQ.YQ', 'k')
    with self.assertRaises(errors.MissingAlgorithmError):
        jws.encode({}, '', 'k')
    with self.assertRaises(errors.UnsupportedAlgorithmError):
        jws.encode({'alg': 's'}, '', 'k')
def test_fail_deserialize_json(self):
    """Malformed JSON serializations raise DecodeError."""
    jws = JWS(algorithms=JWS_ALGORITHMS)

    # Not a JSON object with the required members at all.
    for bad_input in (None, '[]', '{}'):
        self.assertRaises(errors.DecodeError, jws.deserialize_json, bad_input, '')

    # missing protected
    missing_protected = json.dumps({'payload': 'YQ'})
    self.assertRaises(errors.DecodeError, jws.deserialize_json, missing_protected, '')

    # missing signature
    missing_signature = json.dumps({'payload': 'YQ', 'protected': 'YQ'})
    self.assertRaises(errors.DecodeError, jws.deserialize_json, missing_signature, '')
def test_invalid_input(self):
    """Compact tokens with bad segment counts or encodings raise DecodeError."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    malformed_tokens = (
        'a',
        'a.b.c',
        'YQ.YQ.YQ',  # a
        'W10.a.YQ',  # []
        'e30.a.YQ',  # {}
        'eyJhbGciOiJzIn0.a.YQ',
        'eyJhbGciOiJzIn0.YQ.a',
    )
    for token in malformed_tokens:
        self.assertRaises(errors.DecodeError, jws.decode, token, 'k')
def test_bad_signature(self):
    """An HS256 token with a bogus signature fails verification."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    # Header is {"alg":"HS256"}; the signature segment is garbage.
    forged = 'eyJhbGciOiJIUzI1NiJ9.YQ.YQ'
    with self.assertRaises(errors.BadSignatureError):
        jws.deserialize(forged, 'k')
def test_success_encode_decode(self):
    """HS256 round-trip through encode/decode."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    token = jws.encode({'alg': 'HS256'}, 'hello', 'secret')
    decoded_header, decoded_payload = jws.decode(token, 'secret')
    self.assertEqual(decoded_payload, b'hello')
    self.assertEqual(decoded_header['alg'], 'HS256')
def test_rsa_encode_decode(self):
    """RS256 round-trip through encode/decode using the key helpers."""
    jws = JWS(algorithms=JWS_ALGORITHMS)
    token = jws.encode({'alg': 'RS256'}, 'hello', get_rsa_private_key())
    decoded_header, decoded_payload = jws.decode(token, get_rsa_public_key())
    self.assertEqual(decoded_payload, b'hello')
    self.assertEqual(decoded_header['alg'], 'RS256')
class JWT(object):
    """Encode and decode JSON Web Tokens.

    Compact JWS tokens (three segments) and compact JWE tokens (five
    segments) are supported. Signing and encryption are delegated to the
    ``JWS`` / ``JWE`` instances built in ``__init__``. This class adds the
    JWT-specific steps: the ``typ`` header, datetime claim conversion, the
    sensitive-data check, and wrapping decoded claims in a claims class.
    """

    # Claim names that check_sensitive_data refuses to put into a token.
    SENSITIVE_NAMES = ('password', 'token', 'secret', 'secret_key')

    # Thanks to sentry SensitiveDataFilter
    # NOTE: compiled without re.MULTILINE, so the leading '^' on the SSN
    # alternative only matches at the very start of the claim value.
    SENSITIVE_VALUES = re.compile(r'|'.join([
        # http://www.richardsramblings.com/regex/credit-card-numbers/
        r'\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b',
        # various private keys
        r'-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----',
        # social security numbers (US)
        r'^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b',
    ]), re.DOTALL)

    def __init__(self, algorithms=None, private_headers=None):
        """Create the underlying JWS and JWE serializers.

        :param algorithms: ``None`` enables every algorithm in
            ``JWS_ALGORITHMS`` and ``JWE_ALGORITHMS``. Otherwise both
            serializers start empty and only the given algorithm name, or
            list/tuple of algorithms, is registered.
        :param private_headers: forwarded unchanged to both JWS and JWE
        """
        if algorithms is None:
            self._jws = JWS(JWS_ALGORITHMS, private_headers)
            self._jwe = JWE(JWE_ALGORITHMS, private_headers)
        else:
            self._jws = JWS(None, private_headers)
            self._jwe = JWE(None, private_headers)
            if isinstance(algorithms, (tuple, list)):
                for algorithm in algorithms:
                    self.register_algorithm(algorithm)
            elif isinstance(algorithms, text_types):
                self.register_algorithm(algorithms)

    def register_algorithm(self, algorithm):
        """Register an algorithm with the JWS or JWE serializer.

        :param algorithm: an algorithm object with a ``TYPE`` of ``'JWS'``
            or ``'JWE'``, or the name of one in ``_AVAILABLE_ALGORITHMS``
        :raise: ValueError if a name is not a known algorithm
        """
        if isinstance(algorithm, text_types):
            name = algorithm
            algorithm = _AVAILABLE_ALGORITHMS.get(name)
            if algorithm is None:
                # Previously this surfaced as "'NoneType' has no attribute TYPE".
                raise ValueError('Unknown algorithm: {!r}'.format(name))

        if algorithm.TYPE == 'JWS':
            self._jws.register_algorithm(algorithm)
        elif algorithm.TYPE == 'JWE':
            self._jwe.register_algorithm(algorithm)

    def check_sensitive_data(self, payload):
        """Check if payload contains sensitive information.

        :param payload: mapping of claims to inspect
        :raise: InsecureClaimError naming the first offending claim, when
            its name is in SENSITIVE_NAMES or its text value matches
            SENSITIVE_VALUES
        """
        for name, value in payload.items():
            # check claims key name
            if name in self.SENSITIVE_NAMES:
                raise InsecureClaimError(name)
            # check claims values
            if isinstance(value, text_types) and self.SENSITIVE_VALUES.search(value):
                raise InsecureClaimError(name)

    def encode(self, header, payload, key, check=True):
        """Encode a JWT with the given header, payload and key.

        The caller's ``header`` and ``payload`` are not modified; the
        ``typ`` header and timestamp conversions are applied to copies.

        :param header: A dict of JWS header; a JWE is produced instead when
            it contains ``enc``
        :param payload: A dict to be encoded
        :param key: key used to sign the signature, resolved via ``load_key``
        :param check: check if sensitive data in payload
        :return: JWT
        """
        header = dict(header)
        header['typ'] = 'JWT'

        payload = dict(payload)
        for claim_name in ('exp', 'iat', 'nbf'):
            # convert datetime into timestamp
            claim = payload.get(claim_name)
            if isinstance(claim, datetime.datetime):
                payload[claim_name] = calendar.timegm(claim.utctimetuple())

        if check:
            self.check_sensitive_data(payload)

        key = load_key(key, header, payload)
        # Compact separators keep the serialized claims as small as possible.
        text = to_bytes(json.dumps(payload, separators=(',', ':')))
        if 'enc' in header:
            return self._jwe.serialize_compact(header, text, key)
        return self._jws.serialize_compact(header, text, key)

    def decode(self, s, key, claims_cls=None,
               claims_options=None, claims_params=None):
        """Decode a compact JWS (verify) or compact JWE (decrypt) token.

        The serialization is chosen by segment count: two dots mean JWS,
        four dots mean JWE.

        :param s: text of JWT
        :param key: key used to verify the signature or decrypt, wrapped
            with ``create_key_func``
        :param claims_cls: class to be used for JWT claims (default JWTClaims)
        :param claims_options: `options` parameters for claims_cls
        :param claims_params: `params` parameters for claims_cls
        :return: claims_cls instance
        :raise: DecodeError when ``s`` has neither 2 nor 4 dots;
            BadSignatureError when the signature doesn't match
        """
        if claims_cls is None:
            claims_cls = JWTClaims

        key_func = create_key_func(key)

        s = to_bytes(s)
        dot_count = s.count(b'.')
        if dot_count == 2:
            data = self._jws.deserialize_compact(s, key_func, decode_payload)
        elif dot_count == 4:
            data = self._jwe.deserialize_compact(s, key_func, decode_payload)
        else:
            raise DecodeError('Invalid input segments length')
        return claims_cls(
            data['payload'], data['header'],
            options=claims_options,
            params=claims_params,
        )