def test_decode_with_invalid_sig(self):
    """Every backend rejects a token whose signature belongs to another payload.

    Two tokens are signed over payloads that differ only in "foo"; the
    header+payload of the second is joined to the signature of the first,
    and decoding that spliced token must raise TokenBackendError.
    """
    for backend in self.backends:
        with self.subTest(
                f"Test decode with invalid sig for {backend.algorithm}"):
            payload = self.payload.copy()
            # Unexpired, so the failure can only come from the bad signature.
            payload["exp"] = aware_utcnow() + timedelta(days=1)

            token_1 = jwt.encode(payload, backend.signing_key,
                                 algorithm=backend.algorithm)
            payload["foo"] = "baz"
            token_2 = jwt.encode(payload, backend.signing_key,
                                 algorithm=backend.algorithm)

            if IS_OLD_JWT:
                # Older PyJWT returns bytes from encode().
                token_1 = token_1.decode("utf-8")
                token_2 = token_2.decode("utf-8")

            token_2_payload = token_2.rsplit(".", 1)[0]
            token_1_sig = token_1.rsplit(".", 1)[-1]
            invalid_token = token_2_payload + "." + token_1_sig

            with self.assertRaises(TokenBackendError):
                backend.decode(invalid_token)
def test_decode(self):
    """Exercise token_backend.decode on unexpiring, expired, spliced and valid tokens."""
    claims = {'foo': 'bar'}

    # A token with no "exp" claim decodes without raising.
    self.token_backend.decode(
        jwt.encode(claims, self.secret, algorithm='HS256'))

    # A token whose "exp" is already in the past is rejected.
    claims['exp'] = aware_utcnow() - timedelta(seconds=1)
    stale = jwt.encode(claims, self.secret, algorithm='HS256')
    with self.assertRaises(TokenBackendError):
        self.token_backend.decode(stale)

    # Joining one token's header+payload to another token's signature must
    # fail verification.
    claims['exp'] = aware_utcnow() + timedelta(days=1)
    signed_bar = jwt.encode(claims, self.secret, algorithm='HS256')
    claims['foo'] = 'baz'
    signed_baz = jwt.encode(claims, self.secret, algorithm='HS256')
    spliced = '.'.join((
        signed_baz.rsplit('.', 1)[0],
        signed_bar.rsplit('.', 1)[-1],
    ))
    with self.assertRaises(TokenBackendError):
        self.token_backend.decode(spliced)

    # An untouched token decodes back to its payload.
    # NOTE(review): ``claims['exp']`` was assigned a datetime, so this equality
    # presumably relies on jwt.encode rewriting "exp" in place — confirm
    # against the pinned PyJWT version.
    self.assertEqual(self.token_backend.decode(signed_baz), claims)
def test_init_token_given(self):
    """A MyToken rebuilt from its encoded string exposes the original claims."""
    clock_target = 'rest_framework_simplejwt.tokens.aware_utcnow'

    issued_at = aware_utcnow()
    with patch(clock_target) as frozen_clock:
        frozen_clock.return_value = issued_at
        source_token = MyToken()

    source_token['some_value'] = 'arst'
    encoded = str(source_token)

    decoded_at = aware_utcnow()
    with patch(clock_target) as frozen_clock:
        frozen_clock.return_value = decoded_at
        # Constructing from a valid encoded token must not raise.
        rebuilt = MyToken(encoded)

    expected_exp = datetime_to_epoch(issued_at + MyToken.lifetime)

    self.assertEqual(rebuilt.current_time, decoded_at)
    self.assertEqual(rebuilt.token, encoded)
    self.assertEqual(len(rebuilt.payload), 4)
    self.assertEqual(rebuilt['some_value'], 'arst')
    self.assertEqual(rebuilt['exp'], expected_exp)
    self.assertEqual(rebuilt[api_settings.TOKEN_TYPE_CLAIM], MyToken.token_type)
    self.assertIn('jti', rebuilt.payload)
def test_it_should_return_the_correct_value(self):
    """aware_utcnow() returns aware UTC when USE_TZ is on and naive UTC when off.

    The utils module's ``datetime`` is patched so that ``utcnow()`` returns a
    fixed value, which makes the result directly comparable.
    """
    # django.utils.timezone.utc was removed in Django 5.0; the stdlib UTC
    # tzinfo is accepted by make_aware on every Django version.
    from datetime import timezone as dt_timezone

    now = datetime.utcnow()

    with patch('rest_framework_simplejwt.utils.datetime') as fake_datetime:
        fake_datetime.utcnow.return_value = now

        # Should return aware utcnow if USE_TZ == True
        with self.settings(USE_TZ=True):
            self.assertEqual(
                timezone.make_aware(now, timezone=dt_timezone.utc),
                aware_utcnow(),
            )

        # Should return naive utcnow if USE_TZ == False
        with self.settings(USE_TZ=False):
            self.assertEqual(now, aware_utcnow())
def test_it_should_blacklist_refresh_token_if_everything_ok(self):
    """Validating TokenBlacklistSerializer records and blacklists the submitted token."""
    for model in (OutstandingToken, BlacklistedToken):
        self.assertEqual(model.objects.count(), 0)

    refresh = RefreshToken()
    refresh['test_claim'] = 'arst'
    submitted_jti = refresh['jti']

    serializer = TokenBlacklistSerializer(data={'refresh': str(refresh)})
    backdated_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

    with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
        fake_aware_utcnow.return_value = backdated_now
        self.assertTrue(serializer.is_valid())

    for model in (OutstandingToken, BlacklistedToken):
        self.assertEqual(model.objects.count(), 1)

    # The blacklist entry must point at the token that was submitted.
    self.assertEqual(BlacklistedToken.objects.first().token.jti, submitted_jti)
def test_decode_rsa_aud_iss_jwk_success(self):
    """A token signed with PRIVATE_KEY_2 decodes via a (mocked) JWK client.

    The backend itself is configured with the PRIVATE_KEY/PUBLIC_KEY pair,
    which does not match the signing key. Decoding can therefore only
    succeed if the PUBLIC_KEY_2 served by the mocked PyJWKClient is used.
    """
    self.payload["exp"] = aware_utcnow() + timedelta(days=1)
    self.payload["foo"] = "baz"
    self.payload["aud"] = AUDIENCE
    self.payload["iss"] = ISSUER

    token = jwt.encode(
        self.payload,
        PRIVATE_KEY_2,
        algorithm="RS256",
        headers={"kid": "230498151c214b788dd97f22b85410a5"},
    )

    # Payload copied
    self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

    with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
        mock_jwk_client = mock.MagicMock()
        mock_signing_key = mock.MagicMock()

        mock_jwk_module.return_value = mock_jwk_client
        mock_jwk_client.get_signing_key_from_jwt.return_value = mock_signing_key

        type(mock_signing_key).key = mock.PropertyMock(return_value=PUBLIC_KEY_2)

        # The PRIVATE_KEY/PUBLIC_KEY pair is intentionally the original
        # pairing, i.e. not the key pair the token was signed with.
        jwk_token_backend = TokenBackend(
            "RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
        )

        self.assertEqual(jwk_token_backend.decode(token), self.payload)
def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
    """With rotation on and no blacklisting, the serializer issues a new refresh token.

    The new refresh token keeps the custom claim but gets a fresh "jti" and
    "exp". Both issued tokens expire relative to the patched validation time.
    """
    refresh = RefreshToken()
    refresh['test_claim'] = 'arst'
    old_jti, old_exp = refresh['jti'], refresh['exp']

    serializer = TokenRefreshSerializer(data={'refresh': str(refresh)})
    validation_time = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

    with override_api_settings(ROTATE_REFRESH_TOKENS=True,
                               BLACKLIST_AFTER_ROTATION=False):
        with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_clock:
            fake_clock.return_value = validation_time
            self.assertTrue(serializer.is_valid())

    issued = serializer.validated_data
    access = AccessToken(issued['access'])
    new_refresh = RefreshToken(issued['refresh'])

    for issued_token in (access, new_refresh):
        self.assertEqual(refresh['test_claim'], issued_token['test_claim'])

    self.assertNotEqual(old_jti, new_refresh['jti'])
    self.assertNotEqual(old_exp, new_refresh['exp'])

    self.assertEqual(
        access['exp'],
        datetime_to_epoch(validation_time + api_settings.ACCESS_TOKEN_LIFETIME),
    )
    self.assertEqual(
        new_refresh['exp'],
        datetime_to_epoch(validation_time + api_settings.REFRESH_TOKEN_LIFETIME),
    )
def test_decode_hmac_with_expiry(self):
    """The HMAC backend rejects a token whose "exp" is one second in the past."""
    one_second_ago = aware_utcnow() - timedelta(seconds=1)
    self.payload['exp'] = one_second_ago

    stale_token = jwt.encode(self.payload, SECRET, algorithm='HS256')

    with self.assertRaises(TokenBackendError):
        self.hmac_token_backend.decode(stale_token)
def test_decode_hmac_success(self):
    """A valid HS256 token decodes back to the payload it was built from."""
    self.payload.update(exp=aware_utcnow() + timedelta(days=1), foo='baz')

    encoded = jwt.encode(self.payload, SECRET, algorithm='HS256')
    # The PyJWT version this test targets returns bytes from encode().
    encoded = encoded.decode('utf-8')

    decoded = self.hmac_token_backend.decode(encoded)
    self.assertEqual(decoded, self.payload)
def test_decode_rsa_with_expiry(self):
    """The RSA backend rejects a token whose "exp" is one second in the past."""
    one_second_ago = aware_utcnow() - timedelta(seconds=1)
    self.payload['exp'] = one_second_ago

    stale_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

    with self.assertRaises(TokenBackendError):
        self.rsa_token_backend.decode(stale_token)
def __init__(self, token=None, verify=True):
    """Validate and record an encoded token, then defer to the parent initializer.

    When ``token`` is given, the checks run in this order:
    1. The token decodes.
    2. Its ``user_id`` claim names an existing ``User``.
    3. Its ``jti`` is not blacklisted.
    4. If ``verify`` is true, ``self.verify()`` passes.

    The token is then stored through
    ``tokens.OutstandingToken.objects.get_or_create``. In every case
    ``super().__init__(token, verify)`` runs last.

    :param token: encoded token string, or ``None``.
    :param verify: passed to the decoder and used to decide whether
        ``self.verify()`` runs.
    :raises TokenError: if the token cannot be decoded, lacks a ``user_id``
        or ``jti`` claim, refers to a missing user, or is blacklisted.
    """
    self.token = token
    self.current_time = aware_utcnow()

    if token is not None:
        # An encoded token was provided
        from rest_framework_simplejwt.state import token_backend

        # Decode token
        try:
            self.payload = token_backend.decode(token, verify=verify)
        except TokenBackendError:
            raise TokenError(_('Token is invalid or expired'))

        user_id = self.payload.get('user_id')
        jti = self.payload.get('jti')
        # A missing claim used to surface as KeyError; callers expect TokenError.
        if user_id is None or jti is None or not get_or_none(User, pk=user_id):
            raise TokenError(_('Token is invalid or expired'))

        if tokens.BlacklistedToken.objects.filter(token__jti=jti).exists():
            raise TokenError(_('Token is blacklisted'))

        if verify:
            self.verify()

        # Record the token; the (object, created) result is not needed.
        tokens.OutstandingToken.objects.get_or_create(
            jti=jti,
            user_id=user_id,
            defaults={
                'token': str(self.token),
                # NOTE(review): with verify=False an "exp"-less token would
                # still raise KeyError here — confirm whether that can occur.
                'expires_at': tokens.datetime_from_epoch(self.payload['exp']),
            },
        )

    # NOTE(review): presumably the simplejwt Token initializer, which decodes
    # and verifies the token again — confirm the double work is intended.
    super().__init__(token, verify)
def test_decode_leeway_hmac_fail(self):
    """A token expired by twice LEEWAY is rejected even by the leeway backend."""
    expired_at = aware_utcnow() - timedelta(seconds=LEEWAY * 2)
    self.payload["exp"] = datetime_to_epoch(expired_at)

    stale_token = jwt.encode(self.payload, SECRET, algorithm='HS256')

    with self.assertRaises(TokenBackendError):
        self.hmac_leeway_token_backend.decode(stale_token)
def test_decode_rsa_success(self):
    """A valid RS256 token decodes back to the payload it was built from."""
    self.payload.update(exp=aware_utcnow() + timedelta(days=1), foo='baz')

    encoded = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
    # The PyJWT version this test targets returns bytes from encode().
    encoded = encoded.decode('utf-8')

    decoded = self.rsa_token_backend.decode(encoded)
    self.assertEqual(decoded, self.payload)
def __init__(self, token=None, verify=True):
    """
    Build a token: either mint a new one or decode the one given.

    !!!! IMPORTANT !!!! MUST raise a TokenError with a user-facing error
    message if the given token is invalid, expired, or otherwise not safe to
    use.

    :param token: an encoded token to decode, or ``None`` to mint a new one.
    :param verify: when decoding, forwarded to the backend and used to decide
        whether ``self.verify()`` runs.
    :raises TokenError: if the class has no type or lifetime, or the encoded
        token cannot be decoded or verified.
    """
    has_type_and_lifetime = (
        self.token_type is not None and self.lifetime is not None
    )
    if not has_type_and_lifetime:
        raise TokenError(_('Cannot create token with no type or lifetime'))

    self.token = token
    self.current_time = aware_utcnow()

    if token is None:
        # Brand-new token: nothing to verify; seed type, "exp" and "jti".
        self.payload = {api_settings.TOKEN_TYPE_CLAIM: self.token_type}
        self.set_exp(from_time=self.current_time, lifetime=self.lifetime)
        self.set_jti()
        return

    # Imported at call time rather than module level (presumably to avoid an
    # import cycle — confirm).
    from .state import token_backend

    try:
        self.payload = token_backend.decode(token, verify=verify)
    except TokenBackendError:
        raise TokenError(_('Token is invalid or expired'))

    if verify:
        self.verify()
def test_token_user_cache_fallback_life():
    """With "iat" cleared, the permission cache life falls back to 300 seconds."""
    issued = datetime_to_epoch(aware_utcnow())
    # Named raw_token so the `jwt` module name is not shadowed.
    raw_token = get_jwt(exp=issued + 15, iat=issued)

    untyped = UntypedToken(raw_token)
    untyped.payload['iat'] = None

    assert PermissionedTokenUser(untyped)._get_permission_cache_life() == 300
def test_decode_with_invalid_sig_no_verify(self):
    """With verify=False, every backend decodes a token with a spliced signature.

    The spliced token is the second token's header+payload joined to the
    first token's signature. Without verification it decodes to the second
    payload.
    """
    self.payload["exp"] = aware_utcnow() + timedelta(days=1)
    for backend in self.backends:
        # The f-prefix must sit outside the quotes for the algorithm name to
        # appear in the subtest label.
        with self.subTest(
                f"Test decode with invalid sig for {backend.algorithm}"):
            payload = self.payload.copy()
            token_1 = jwt.encode(payload, backend.signing_key,
                                 algorithm=backend.algorithm)
            payload["foo"] = "baz"
            token_2 = jwt.encode(payload, backend.signing_key,
                                 algorithm=backend.algorithm)

            if IS_OLD_JWT:
                # Older PyJWT returns bytes from encode().
                token_1 = token_1.decode("utf-8")
                token_2 = token_2.decode("utf-8")
            else:
                # Payload copied
                payload["exp"] = datetime_to_epoch(payload["exp"])

            token_2_payload = token_2.rsplit(".", 1)[0]
            token_1_sig = token_1.rsplit(".", 1)[-1]
            invalid_token = token_2_payload + "." + token_1_sig

            self.assertEqual(
                backend.decode(invalid_token, verify=False),
                payload,
            )
def test_decode_rsa_success(self):
    """A valid RS256 token decodes to its payload, with "exp" as epoch seconds."""
    self.payload.update(exp=aware_utcnow() + timedelta(days=1), foo='baz')

    encoded = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

    # Payload copied: jwt.encode left self.payload's "exp" as a datetime,
    # so convert it to match what the backend returns.
    self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

    decoded = self.rsa_token_backend.decode(encoded)
    self.assertEqual(decoded, self.payload)
def test_decode_aud_iss_success(self):
    """An RS256 token with matching "aud" and "iss" decodes to its payload."""
    self.payload.update(
        exp=aware_utcnow() + timedelta(days=1),
        foo='baz',
        aud=AUDIENCE,
        iss=ISSUER,
    )

    encoded = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
    # The PyJWT version this test targets returns bytes from encode().
    encoded = encoded.decode('utf-8')

    decoded = self.aud_iss_token_backend.decode(encoded)
    self.assertEqual(decoded, self.payload)
def test_decode_leeway_hmac_success(self):
    """A token expired by only half of LEEWAY is still accepted by the leeway backend."""
    expired_at = aware_utcnow() - timedelta(seconds=LEEWAY / 2)
    self.payload["exp"] = datetime_to_epoch(expired_at)

    recently_expired = jwt.encode(self.payload, SECRET, algorithm='HS256')

    decoded = self.hmac_leeway_token_backend.decode(recently_expired)
    self.assertEqual(decoded, self.payload)
def flush_expired_tokens():
    """Flush expired tokens.

    Deletes every ``OutstandingToken`` whose ``expires_at`` is at or before
    ``aware_utcnow()``, logging how many are being removed.

    Adapted from DRF simplejwt flushexpiredtokens command
    """
    from rest_framework_simplejwt.token_blacklist.models import OutstandingToken

    expired = OutstandingToken.objects.filter(expires_at__lte=aware_utcnow())
    # count() runs a COUNT query; len() would fetch every expired row.
    logger.info('Flushing %d expired tokens', expired.count())
    expired.delete()
def test_decode_with_expiry(self):
    """Every backend rejects a token whose "exp" is one second in the past."""
    self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)
    for backend in self.backends:
        # The f-prefix must sit outside the quotes for the algorithm name to
        # appear in the subtest label.
        with self.subTest(f"Test decode with expiry for {backend.algorithm}"):
            expired_token = jwt.encode(
                self.payload, backend.signing_key, algorithm=backend.algorithm
            )
            with self.assertRaises(TokenBackendError):
                backend.decode(expired_token)
def test_decode_rsa_with_invalid_sig(self):
    """The RSA backend rejects a token carrying another token's signature."""
    self.payload['exp'] = aware_utcnow() + timedelta(days=1)
    genuine = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')

    self.payload['foo'] = 'baz'
    altered = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')

    # Header+payload of the altered token, signature of the genuine one.
    forged = '.'.join([
        altered.rsplit('.', 1)[0],
        genuine.rsplit('.', 1)[-1],
    ])

    with self.assertRaises(TokenBackendError):
        self.rsa_token_backend.decode(forged)
def post(self, request):
    """Blacklist the submitted refresh token and purge expired outstanding tokens.

    Reads ``request.data['refresh_token']``. Returns 205 on success. Any
    exception raised along the way is caught and answered with a 400.
    """
    try:
        submitted = request.data['refresh_token']
        RefreshToken(submitted).blacklist()

        # Cleanup of every expired OutstandingToken, not only this user's.
        expired = OutstandingToken.objects.filter(expires_at__lte=aware_utcnow())
        expired.delete()

        return Response(status=status.HTTP_205_RESET_CONTENT)
    except Exception:
        # NOTE(review): this catch-all also turns server-side failures (e.g.
        # database errors) into a 400, and the body says "Logged out" even
        # though the request failed — confirm this is intended.
        return Response({"error": "Logged out"}, status=status.HTTP_400_BAD_REQUEST)
def post(self, request):
    """Blacklist the refresh token in the request body.

    Returns 400 with a field error when ``refresh`` is absent. Otherwise it
    deletes the authenticated user's already-expired ``OutstandingToken``
    rows, blacklists the submitted token, and returns 200 with ``{}``.
    """
    if 'refresh' not in request.data:
        field_errors = {"refresh": ["This field is required."]}
        return Response(data=field_errors, status=status.HTTP_400_BAD_REQUEST)

    # Element 0 of authenticate()'s result is used as the user
    # (presumably a (user, token) pair — TODO confirm).
    login_user = self.authenticate(request)[0]

    stale_for_user = OutstandingToken.objects.filter(
        expires_at__lte=aware_utcnow(), user=login_user
    )
    stale_for_user.delete()

    RefreshToken(request.data['refresh']).blacklist()

    return Response(data={}, status=status.HTTP_200_OK)
def test_decode_aud_iss_success(self):
    """An RS256 token with matching "aud" and "iss" decodes to its payload."""
    self.payload.update(
        exp=aware_utcnow() + timedelta(days=1),
        foo="baz",
        aud=AUDIENCE,
        iss=ISSUER,
    )

    encoded = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")

    if not IS_OLD_JWT:
        # Payload copied
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
    else:
        # Older PyJWT returns bytes from encode().
        encoded = encoded.decode("utf-8")

    decoded = self.aud_iss_token_backend.decode(encoded)
    self.assertEqual(decoded, self.payload)
def test_init_bad_sig_token_given(self):
    """MyToken raises TokenError when a token carries another token's signature."""
    claims = {'foo': 'bar', 'exp': aware_utcnow() + timedelta(days=1)}
    signed_first = jwt.encode(claims, api_settings.SIGNING_KEY, algorithm='HS256')

    claims['foo'] = 'baz'
    signed_second = jwt.encode(claims, api_settings.SIGNING_KEY, algorithm='HS256')

    # Header+payload of the second token with the first token's signature.
    forged = '.'.join([
        signed_second.rsplit('.', 1)[0],
        signed_first.rsplit('.', 1)[-1],
    ])

    with self.assertRaises(TokenError):
        MyToken(forged)
def test_it_should_return_given_token_if_everything_ok(self):
    """TokenVerifySerializer accepts a valid token and produces no validated data."""
    refresh = RefreshToken()
    refresh['test_claim'] = 'arst'

    serializer = TokenVerifySerializer(data={'token': str(refresh)})
    backdated_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

    with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_clock:
        fake_clock.return_value = backdated_now
        self.assertTrue(serializer.is_valid())

    self.assertEqual(len(serializer.validated_data), 0)
def test_decode_rsa_with_invalid_sig_no_verify(self):
    """With verify=False the RSA backend decodes a token with a spliced signature.

    The spliced token is the second token's header+payload joined to the
    first token's signature. Without verification it decodes to the
    (mutated) ``self.payload``.
    """
    self.payload['exp'] = aware_utcnow() + timedelta(days=1)
    token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
    self.payload['foo'] = 'baz'
    token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')

    token_2_payload = token_2.rsplit('.', 1)[0]
    token_1_sig = token_1.rsplit('.', 1)[-1]
    invalid_token = token_2_payload + '.' + token_1_sig

    # These are RS256 tokens, so exercise the RSA backend (the HMAC backend
    # was used here before, inconsistent with the other RSA tests).
    self.assertEqual(
        self.rsa_token_backend.decode(invalid_token, verify=False),
        self.payload,
    )
def test_it_should_delete_any_expired_tokens(self):
    """flushexpiredtokens removes expired outstanding tokens and their blacklist rows."""
    # Tokens that remain valid for the whole test.
    fresh_user_a = RefreshToken.for_user(self.user)
    fresh_user_b = RefreshToken.for_user(self.user)
    fresh_anon = RefreshToken()

    fresh_user_b.blacklist()
    fresh_anon.blacklist()

    # Backdate the token clock by a full refresh lifetime, so these tokens
    # have expired by the time the flush runs.
    backdated = aware_utcnow() - api_settings.REFRESH_TOKEN_LIFETIME
    with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
        clock.return_value = backdated
        stale_user = RefreshToken.for_user(self.user)
        stale_anon = RefreshToken()

    for stale in (stale_user, stale_anon):
        stale.blacklist()

    fresh_user_c = RefreshToken.for_user(self.user)

    self.assertEqual(OutstandingToken.objects.count(), 6)
    self.assertEqual(BlacklistedToken.objects.count(), 4)

    call_command("flushexpiredtokens")

    # Expired outstanding *and* blacklisted tokens should be gone.
    self.assertEqual(OutstandingToken.objects.count(), 4)
    self.assertEqual(BlacklistedToken.objects.count(), 2)

    surviving_outstanding = [row.jti for row in OutstandingToken.objects.order_by("id")]
    expected_outstanding = [
        tok["jti"] for tok in (fresh_user_a, fresh_user_b, fresh_anon, fresh_user_c)
    ]
    self.assertEqual(surviving_outstanding, expected_outstanding)

    surviving_blacklist = [
        row.token.jti for row in BlacklistedToken.objects.order_by("id")
    ]
    self.assertEqual(surviving_blacklist, [fresh_user_b["jti"], fresh_anon["jti"]])
def test_it_should_return_nothing_if_everything_ok(self):
    """A valid refresh token passes TokenBlacklistSerializer with empty validated_data."""
    token = RefreshToken()
    token["test_claim"] = "arst"

    serializer = TokenBlacklistSerializer(data={"refresh": str(token)})
    backdated_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

    with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_clock:
        fake_clock.return_value = backdated_now
        self.assertTrue(serializer.is_valid())

    self.assertDictEqual(serializer.validated_data, {})