Exemplo n.º 1
0
    def test_decode_with_invalid_sig(self):
        """Each backend must reject a token whose signature belongs to another payload.

        Two unexpired tokens are signed with the same key. The header+payload
        of the second is joined to the signature of the first, so the
        signature no longer matches and ``decode`` must raise
        ``TokenBackendError``.
        """
        for backend in self.backends:
            with self.subTest(
                    f"Test decode with invalid sig for {backend.algorithm}"):
                payload = self.payload.copy()
                # Unexpired, so the only reason for rejection is the signature.
                payload["exp"] = aware_utcnow() + timedelta(days=1)
                token_1 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)
                payload["foo"] = "baz"
                token_2 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)

                if IS_OLD_JWT:
                    token_1 = token_1.decode("utf-8")
                    token_2 = token_2.decode("utf-8")

                token_2_payload = token_2.rsplit(".", 1)[0]
                token_1_sig = token_1.rsplit(".", 1)[-1]
                invalid_token = token_2_payload + "." + token_1_sig

                with self.assertRaises(TokenBackendError):
                    backend.decode(invalid_token)
    def test_decode(self):
        """Exercise ``token_backend.decode`` with HS256 tokens.

        Covers four cases:

        - Tokens without ``exp`` decode fine.
        - Expired tokens raise.
        - Tokens with a mismatched signature raise.
        - A valid token decodes back to its payload.
        """
        from calendar import timegm

        def encode(data):
            token = jwt.encode(data, self.secret, algorithm='HS256')
            # PyJWT < 2 returns bytes, PyJWT >= 2 returns str. The signature
            # splicing below needs str.
            if isinstance(token, bytes):
                token = token.decode('utf-8')
            return token

        # No expiry tokens cause no exception
        payload = {'foo': 'bar'}
        no_exp_token = encode(payload)
        self.token_backend.decode(no_exp_token)

        # Expired tokens should cause exception
        payload['exp'] = aware_utcnow() - timedelta(seconds=1)
        expired_token = encode(payload)
        with self.assertRaises(TokenBackendError):
            self.token_backend.decode(expired_token)

        # Token with invalid signature should cause exception
        payload['exp'] = aware_utcnow() + timedelta(days=1)
        token = encode(payload)
        payload['foo'] = 'baz'
        other_token = encode(payload)

        incorrect_payload = other_token.rsplit('.', 1)[0]
        correct_sig = token.rsplit('.', 1)[-1]
        invalid_token = incorrect_payload + '.' + correct_sig

        with self.assertRaises(TokenBackendError):
            self.token_backend.decode(invalid_token)

        # PyJWT < 2 rewrites datetime claims in ``payload`` to integer epochs
        # in place. PyJWT >= 2 encodes a copy and leaves the datetime. The
        # decoded token always carries the integer form, so normalize.
        if not isinstance(payload['exp'], int):
            payload['exp'] = timegm(payload['exp'].utctimetuple())

        # Otherwise, should return data payload for token
        self.assertEqual(self.token_backend.decode(other_token), payload)
Exemplo n.º 3
0
    def test_init_token_given(self):
        """Re-instantiating a token from its encoded form restores its claims.

        The source token is minted under a frozen clock (``minted_at``). The
        decoded copy is built under a later frozen clock (``decoded_at``).
        The copy must keep the original ``exp`` while recording the new
        current time.
        """
        clock_path = 'rest_framework_simplejwt.tokens.aware_utcnow'

        # Mint a token at a known instant
        minted_at = aware_utcnow()
        with patch(clock_path) as frozen_clock:
            frozen_clock.return_value = minted_at
            source_token = MyToken()

        source_token['some_value'] = 'arst'
        encoded = str(source_token)

        decoded_at = aware_utcnow()

        # Building a token from a valid encoded string must not raise
        with patch(clock_path) as frozen_clock:
            frozen_clock.return_value = decoded_at
            t = MyToken(encoded)

        self.assertEqual(t.current_time, decoded_at)
        self.assertEqual(t.token, encoded)

        claims = t.payload
        self.assertEqual(len(claims), 4)
        self.assertEqual(t['some_value'], 'arst')
        expected_exp = datetime_to_epoch(minted_at + MyToken.lifetime)
        self.assertEqual(t['exp'], expected_exp)
        self.assertEqual(t[api_settings.TOKEN_TYPE_CLAIM], MyToken.token_type)
        self.assertIn('jti', claims)
    def test_it_should_return_the_correct_value(self):
        """``aware_utcnow`` honours the ``USE_TZ`` setting.

        ``datetime`` inside ``rest_framework_simplejwt.utils`` is patched so
        that ``utcnow()`` returns a fixed naive instant. With ``USE_TZ=True``
        the result must be that instant made aware in UTC. With
        ``USE_TZ=False`` it must be the naive instant itself.
        """
        # ``django.utils.timezone.utc`` was deprecated in Django 4.1 and
        # removed in 5.0. The stdlib UTC tzinfo is the supported replacement.
        from datetime import timezone as dt_timezone

        now = datetime.utcnow()

        with patch('rest_framework_simplejwt.utils.datetime') as fake_datetime:
            fake_datetime.utcnow.return_value = now

            # Should return aware utcnow if USE_TZ == True
            with self.settings(USE_TZ=True):
                self.assertEqual(
                    timezone.make_aware(now, timezone=dt_timezone.utc),
                    aware_utcnow())

            # Should return naive utcnow if USE_TZ == False
            with self.settings(USE_TZ=False):
                self.assertEqual(now, aware_utcnow())
    def test_it_should_blacklist_refresh_token_if_everything_ok(self):
        """Validating a TokenBlacklistSerializer blacklists the submitted refresh token.

        Starts from empty outstanding/blacklisted tables. After one
        successful validation there must be exactly one row in each, and the
        blacklisted row must point at the submitted token's ``jti``.
        """
        for model in (OutstandingToken, BlacklistedToken):
            self.assertEqual(model.objects.count(), 0)

        refresh = RefreshToken()
        refresh['test_claim'] = 'arst'
        jti_before = refresh['jti']

        serializer = TokenBlacklistSerializer(data={'refresh': str(refresh)})

        # Validate with the token clock set half an access-token lifetime back
        shifted_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
            clock.return_value = shifted_now
            self.assertTrue(serializer.is_valid())

        for model in (OutstandingToken, BlacklistedToken):
            self.assertEqual(model.objects.count(), 1)

        # The blacklisted entry must reference the submitted token
        blacklisted = BlacklistedToken.objects.first()
        self.assertEqual(blacklisted.token.jti, jti_before)
    def test_decode_rsa_aud_iss_jwk_success(self):
        """An RS256 token with aud/iss is verified with a key fetched via JWK.

        The token is signed with ``PRIVATE_KEY_2``. ``PyJWKClient`` is mocked
        so that the signing key resolved for the token exposes
        ``PUBLIC_KEY_2``. Meanwhile the backend itself is configured with the
        unrelated ``PRIVATE_KEY``/``PUBLIC_KEY`` pair, so decoding can only
        succeed through the JWK lookup.
        """
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        self.payload["foo"] = "baz"
        self.payload["aud"] = AUDIENCE
        self.payload["iss"] = ISSUER

        token = jwt.encode(
            self.payload,
            PRIVATE_KEY_2,
            algorithm="RS256",
            headers={"kid": "230498151c214b788dd97f22b85410a5"},
        )
        # Payload copied: encode() did not convert self.payload's exp, so
        # turn it into the integer epoch that decode() returns.
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

        with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
            mock_jwk_client = mock.MagicMock()
            mock_signing_key = mock.MagicMock()

            mock_jwk_module.return_value = mock_jwk_client
            mock_jwk_client.get_signing_key_from_jwt.return_value = mock_signing_key
            type(mock_signing_key).key = mock.PropertyMock(return_value=PUBLIC_KEY_2)

            # Note the PRIVATE_KEY/PUBLIC_KEY pair is intentionally the
            # original pairing; the correct key must come from the JWK client.
            jwk_token_backend = TokenBackend(
                "RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
            )

            self.assertEqual(jwk_token_backend.decode(token), self.payload)
Exemplo n.º 7
0
    def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
        """With ROTATE_REFRESH_TOKENS on, refreshing also issues a new refresh token.

        The new access and refresh tokens must:

        - carry over the custom claim;
        - give the rotated refresh token a fresh ``jti`` and ``exp``;
        - measure both lifetimes from the patched clock.
        """
        refresh = RefreshToken()
        refresh['test_claim'] = 'arst'
        previous_jti, previous_exp = refresh['jti'], refresh['exp']

        serializer = TokenRefreshSerializer(data={'refresh': str(refresh)})

        clock_time = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with override_api_settings(ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False):
            with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
                clock.return_value = clock_time
                self.assertTrue(serializer.is_valid())

        issued = serializer.validated_data
        access = AccessToken(issued['access'])
        rotated = RefreshToken(issued['refresh'])

        # Custom claims survive into both new tokens
        for new_token in (access, rotated):
            self.assertEqual(refresh['test_claim'], new_token['test_claim'])

        self.assertNotEqual(previous_jti, rotated['jti'])
        self.assertNotEqual(previous_exp, rotated['exp'])

        access_exp = datetime_to_epoch(clock_time + api_settings.ACCESS_TOKEN_LIFETIME)
        refresh_exp = datetime_to_epoch(clock_time + api_settings.REFRESH_TOKEN_LIFETIME)
        self.assertEqual(access['exp'], access_exp)
        self.assertEqual(rotated['exp'], refresh_exp)
    def test_decode_hmac_with_expiry(self):
        """An HS256 token whose ``exp`` is already in the past is rejected."""
        one_second_ago = aware_utcnow() - timedelta(seconds=1)
        self.payload['exp'] = one_second_ago

        stale = jwt.encode(self.payload, SECRET, algorithm='HS256')

        with self.assertRaises(TokenBackendError):
            self.hmac_token_backend.decode(stale)
    def test_decode_hmac_success(self):
        """A valid HS256 token decodes back to its payload.

        Works with both PyJWT < 2 and PyJWT >= 2:

        - PyJWT < 2: ``encode`` returns bytes and converts datetime claims
          in the passed dict.
        - PyJWT >= 2: ``encode`` returns str and leaves the dict untouched.
        """
        from calendar import timegm

        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        self.payload['foo'] = 'baz'

        token = jwt.encode(self.payload, SECRET, algorithm='HS256')
        # Only bytes have .decode(); calling it on a PyJWT >= 2 str fails.
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        # The decoded exp claim is always an integer epoch.
        if not isinstance(self.payload['exp'], int):
            self.payload['exp'] = timegm(self.payload['exp'].utctimetuple())

        self.assertEqual(self.hmac_token_backend.decode(token), self.payload)
    def test_decode_rsa_with_expiry(self):
        """An RS256 token that expired one second ago is rejected."""
        past = aware_utcnow() - timedelta(seconds=1)
        self.payload['exp'] = past

        stale = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

        with self.assertRaises(TokenBackendError):
            self.rsa_token_backend.decode(stale)
Exemplo n.º 11
0
    def __init__(self, token=None, verify=True):
        """Build a token, additionally requiring an existing, unblacklisted user.

        When an encoded ``token`` is given, it is decoded and rejected with
        ``TokenError`` in any of these cases:

        - it cannot be decoded;
        - its ``user_id`` claim is missing or does not resolve to a ``User``
          via ``get_or_none``;
        - its ``jti`` appears in ``BlacklistedToken``.

        When ``verify`` is true, the token is then verified and an
        ``OutstandingToken`` row is fetched or created for it. Finally the
        parent initializer runs with the same arguments.
        """
        self.token = token
        self.current_time = aware_utcnow()
        if token is not None:
            # An encoded token was provided
            from rest_framework_simplejwt.state import token_backend

            # Decode token
            try:
                self.payload = token_backend.decode(token, verify=verify)
            except TokenBackendError as e:
                raise TokenError(_('Token is invalid or expired')) from e

            # A missing ``user_id`` claim used to escape as a bare KeyError;
            # treat it the same as an unknown user.
            user_id = self.payload.get('user_id')
            if user_id is None or not get_or_none(User, pk=user_id):
                raise TokenError(_('Token is invalid or expired'))

            if tokens.BlacklistedToken.objects.filter(
                    token__jti=self.payload['jti']).exists():
                raise TokenError(_('Token is blacklisted'))
            if verify:
                self.verify()
                # Record the token as outstanding; the (obj, created) result
                # is not needed here.
                tokens.OutstandingToken.objects.get_or_create(
                    jti=self.payload['jti'],
                    user_id=user_id,
                    defaults={
                        'token': str(self.token),
                        'expires_at': tokens.datetime_from_epoch(
                            self.payload['exp']
                        ),
                    },
                )
        # NOTE(review): the parent initializer presumably decodes/verifies
        # ``token`` again; confirm whether that duplicate work is intended.
        super().__init__(token, verify)
    def test_decode_leeway_hmac_fail(self):
        """A token expired by more than the leeway is still rejected.

        ``exp`` is placed ``2 * LEEWAY`` seconds in the past, which should
        fall outside what the leeway backend tolerates.
        """
        too_old = aware_utcnow() - timedelta(seconds=LEEWAY * 2)
        self.payload["exp"] = datetime_to_epoch(too_old)

        stale = jwt.encode(self.payload, SECRET, algorithm='HS256')

        with self.assertRaises(TokenBackendError):
            self.hmac_leeway_token_backend.decode(stale)
    def test_decode_rsa_success(self):
        """A valid RS256 token decodes back to its payload.

        Handles both PyJWT < 2 (bytes token, datetime claims converted in
        place) and PyJWT >= 2 (str token, payload dict left untouched).
        """
        from calendar import timegm

        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        self.payload['foo'] = 'baz'

        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
        # Only bytes have .decode(); calling it on a PyJWT >= 2 str fails.
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        # The decoded exp claim is always an integer epoch.
        if not isinstance(self.payload['exp'], int):
            self.payload['exp'] = timegm(self.payload['exp'].utctimetuple())

        self.assertEqual(self.rsa_token_backend.decode(token), self.payload)
Exemplo n.º 14
0
    def __init__(self, token=None, verify=True):
        """Create a brand-new token, or load one from its encoded form.

        With ``token=None``, a fresh payload is built. It holds the token
        type claim, an ``exp`` claim derived from ``lifetime``, and a
        ``jti``. Otherwise ``token`` is decoded via the configured backend
        and, when ``verify`` is true, checked with ``self.verify()``.

        !!!! IMPORTANT !!!! Any invalid, expired or otherwise unsafe token
        MUST surface as a ``TokenError`` carrying a user-facing message.
        """
        if self.token_type is None or self.lifetime is None:
            raise TokenError(_('Cannot create token with no type or lifetime'))

        self.token = token
        self.current_time = aware_utcnow()

        if token is None:
            # Fresh token: nothing to verify, just populate the default claims.
            self.payload = {api_settings.TOKEN_TYPE_CLAIM: self.token_type}
            self.set_exp(from_time=self.current_time, lifetime=self.lifetime)
            self.set_jti()
            return

        # Imported lazily, as in the original; presumably avoids an import cycle.
        from .state import token_backend

        try:
            self.payload = token_backend.decode(token, verify=verify)
        except TokenBackendError:
            raise TokenError(_('Token is invalid or expired'))

        if verify:
            self.verify()
Exemplo n.º 15
0
def test_token_user_cache_fallback_life():
    """With ``iat`` cleared, the permission cache life falls back to 300."""
    issued_at = datetime_to_epoch(aware_utcnow())
    # Named ``encoded`` rather than ``jwt`` to avoid shadowing the jwt module.
    encoded = get_jwt(exp=issued_at + 15, iat=issued_at)
    untyped = UntypedToken(encoded)
    # Clear the issued-at claim to exercise the fallback path
    untyped.payload['iat'] = None
    user = PermissionedTokenUser(untyped)
    assert user._get_permission_cache_life() == 300
Exemplo n.º 16
0
    def test_decode_with_invalid_sig_no_verify(self):
        """With ``verify=False``, a token with a mismatched signature still decodes.

        For every backend, the header+payload of one token is joined to the
        signature of another token signed with the same key. Decoding
        without verification must return the second payload as-is.
        """
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        for backend in self.backends:
            # f-string so the failing algorithm appears in the subtest label
            with self.subTest(
                    f"Test decode with invalid sig for {backend.algorithm}"):
                payload = self.payload.copy()
                token_1 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)
                payload["foo"] = "baz"
                token_2 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)
                if IS_OLD_JWT:
                    token_1 = token_1.decode("utf-8")
                    token_2 = token_2.decode("utf-8")
                else:
                    # Payload copied
                    payload["exp"] = datetime_to_epoch(payload["exp"])

                token_2_payload = token_2.rsplit(".", 1)[0]
                token_1_sig = token_1.rsplit(".", 1)[-1]
                invalid_token = token_2_payload + "." + token_1_sig

                self.assertEqual(
                    backend.decode(invalid_token, verify=False),
                    payload,
                )
    def test_decode_rsa_success(self):
        """A valid RS256 token decodes back to its payload."""
        self.payload.update(exp=aware_utcnow() + timedelta(days=1), foo='baz')

        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
        # Payload copied: encode() left exp as a datetime here, so convert
        # the expectation to the integer epoch that decode() returns.
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

        decoded = self.rsa_token_backend.decode(token)
        self.assertEqual(decoded, self.payload)
    def test_decode_aud_iss_success(self):
        """An RS256 token with matching ``aud``/``iss`` decodes to its payload.

        Handles both PyJWT < 2 (bytes token, datetime claims converted in
        place) and PyJWT >= 2 (str token, payload dict left untouched).
        """
        from calendar import timegm

        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        self.payload['foo'] = 'baz'
        self.payload['aud'] = AUDIENCE
        self.payload['iss'] = ISSUER

        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
        # Only bytes have .decode(); calling it on a PyJWT >= 2 str fails.
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        # The decoded exp claim is always an integer epoch.
        if not isinstance(self.payload['exp'], int):
            self.payload['exp'] = timegm(self.payload['exp'].utctimetuple())

        self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)
    def test_decode_leeway_hmac_success(self):
        """A token expired by less than the leeway still decodes.

        ``exp`` is ``LEEWAY / 2`` seconds in the past. The leeway backend
        should accept it and return the payload unchanged.
        """
        half_leeway_ago = aware_utcnow() - timedelta(seconds=LEEWAY / 2)
        self.payload["exp"] = datetime_to_epoch(half_leeway_ago)

        token = jwt.encode(self.payload, SECRET, algorithm='HS256')
        decoded = self.hmac_leeway_token_backend.decode(token)

        self.assertEqual(decoded, self.payload)
Exemplo n.º 20
0
def flush_expired_tokens():
    """Delete every ``OutstandingToken`` whose ``expires_at`` is not after now.

    Adapted from DRF simplejwt flushexpiredtokens command
    """
    from rest_framework_simplejwt.token_blacklist.models import OutstandingToken
    tokens = OutstandingToken.objects.filter(expires_at__lte=aware_utcnow())
    # count() runs a COUNT query; len() would fetch every row just to count.
    logger.info('Flushing %d expired tokens', tokens.count())
    tokens.delete()
Exemplo n.º 21
0
    def test_decode_with_expiry(self):
        """Every backend rejects a token whose ``exp`` is in the past."""
        self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)
        for backend in self.backends:
            # f-string so the failing algorithm appears in the subtest label
            with self.subTest(f"Test decode with expiry for {backend.algorithm}"):

                expired_token = jwt.encode(
                    self.payload, backend.signing_key, algorithm=backend.algorithm
                )

                with self.assertRaises(TokenBackendError):
                    backend.decode(expired_token)
    def test_decode_rsa_with_invalid_sig(self):
        """An RS256 token whose signature belongs to another payload is rejected."""
        def encode(data):
            token = jwt.encode(data, PRIVATE_KEY, algorithm='RS256')
            # PyJWT < 2 returns bytes; PyJWT >= 2 already returns str.
            return token.decode('utf-8') if isinstance(token, bytes) else token

        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        token_1 = encode(self.payload)
        self.payload['foo'] = 'baz'
        token_2 = encode(self.payload)

        token_2_payload = token_2.rsplit('.', 1)[0]
        token_1_sig = token_1.rsplit('.', 1)[-1]
        invalid_token = token_2_payload + '.' + token_1_sig

        with self.assertRaises(TokenBackendError):
            self.rsa_token_backend.decode(invalid_token)
Exemplo n.º 23
0
 def post(self, request):
     """Blacklist the submitted refresh token and purge expired ones.

     Expects ``refresh_token`` in ``request.data``. On success, three things
     happen:

     - the token is blacklisted;
     - every ``OutstandingToken`` whose ``expires_at`` has passed is deleted;
     - 205 Reset Content is returned.

     A missing field, a non-mapping body, or a token rejected by
     ``RefreshToken`` yields 400.
     """
     from rest_framework_simplejwt.exceptions import TokenError

     # Only client-side failures are mapped to 400. Unexpected errors
     # (e.g. database failures) are no longer hidden behind a 400.
     try:
         refresh_token = request.data['refresh_token']
         token = RefreshToken(refresh_token)
         token.blacklist()
     except (KeyError, TypeError, TokenError):
         # NOTE(review): "Logged out" as a 400 error text looks odd; confirm.
         return Response({"error": "Logged out"},
                         status=status.HTTP_400_BAD_REQUEST)
     OutstandingToken.objects.filter(
         expires_at__lte=aware_utcnow()).delete()
     return Response(status=status.HTTP_205_RESET_CONTENT)
Exemplo n.º 24
0
    def post(self, request):
        """Blacklist the refresh token given in ``request.data['refresh']``.

        Returns 400 if the ``refresh`` field is missing. Otherwise it:

        - deletes the authenticated user's expired ``OutstandingToken`` rows;
        - returns 400 if ``RefreshToken`` rejects the token;
        - else blacklists the token and returns 200 with an empty body.
        """
        from rest_framework_simplejwt.exceptions import TokenError

        if 'refresh' not in request.data:
            return Response(data={"refresh": ["This field is required."]},
                            status=status.HTTP_400_BAD_REQUEST)

        # NOTE(review): assumes authenticate() returns a (user, token) pair and
        # never None for this endpoint; confirm.
        login_user = self.authenticate(request)[0]
        OutstandingToken.objects.filter(expires_at__lte=aware_utcnow(),
                                        user=login_user).delete()

        try:
            token = RefreshToken(request.data['refresh'])
        except TokenError as e:
            # TokenError is not a DRF APIException, so left unhandled it
            # surfaces as an HTTP 500 instead of a client error.
            return Response(data={"refresh": [str(e)]},
                            status=status.HTTP_400_BAD_REQUEST)
        token.blacklist()
        return Response(data={}, status=status.HTTP_200_OK)
Exemplo n.º 25
0
    def test_decode_aud_iss_success(self):
        """An RS256 token with matching ``aud``/``iss`` claims decodes to its payload."""
        self.payload.update({
            "exp": aware_utcnow() + timedelta(days=1),
            "foo": "baz",
            "aud": AUDIENCE,
            "iss": ISSUER,
        })

        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
        if not IS_OLD_JWT:
            # Payload copied: exp is still a datetime in self.payload
            self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
        else:
            token = token.decode("utf-8")

        decoded = self.aud_iss_token_backend.decode(token)
        self.assertEqual(decoded, self.payload)
    def test_init_bad_sig_token_given(self):
        """MyToken refuses an encoded token whose signature belongs to another payload."""
        def encode(data):
            token = jwt.encode(data, api_settings.SIGNING_KEY, algorithm='HS256')
            # PyJWT < 2 returns bytes, which rsplit('.') below cannot handle.
            return token.decode('utf-8') if isinstance(token, bytes) else token

        payload = {'foo': 'bar'}
        payload['exp'] = aware_utcnow() + timedelta(days=1)
        token_1 = encode(payload)
        payload['foo'] = 'baz'
        token_2 = encode(payload)

        token_2_payload = token_2.rsplit('.', 1)[0]
        token_1_sig = token_1.rsplit('.', 1)[-1]
        invalid_token = token_2_payload + '.' + token_1_sig

        with self.assertRaises(TokenError):
            MyToken(invalid_token)
Exemplo n.º 27
0
    def test_it_should_return_given_token_if_everything_ok(self):
        """A valid token passes TokenVerifySerializer and yields no validated data."""
        token = RefreshToken()
        token['test_claim'] = 'arst'

        verifier = TokenVerifySerializer(data={'token': str(token)})

        shifted_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
            clock.return_value = shifted_now
            self.assertTrue(verifier.is_valid())

        self.assertEqual(len(verifier.validated_data), 0)
    def test_decode_rsa_with_invalid_sig_no_verify(self):
        """With ``verify=False``, an RS256 token with a mismatched signature decodes.

        Uses the RSA backend to match the test's name. The previous version
        decoded with the HMAC backend.
        """
        from calendar import timegm

        def encode(data):
            token = jwt.encode(data, PRIVATE_KEY, algorithm='RS256')
            # PyJWT < 2 returns bytes; PyJWT >= 2 already returns str.
            return token.decode('utf-8') if isinstance(token, bytes) else token

        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        token_1 = encode(self.payload)
        self.payload['foo'] = 'baz'
        token_2 = encode(self.payload)
        # The decoded exp claim is always an integer epoch.
        if not isinstance(self.payload['exp'], int):
            self.payload['exp'] = timegm(self.payload['exp'].utctimetuple())

        token_2_payload = token_2.rsplit('.', 1)[0]
        token_1_sig = token_1.rsplit('.', 1)[-1]
        invalid_token = token_2_payload + '.' + token_1_sig

        self.assertEqual(
            self.rsa_token_backend.decode(invalid_token, verify=False),
            self.payload,
        )
Exemplo n.º 29
0
    def test_it_should_delete_any_expired_tokens(self):
        """``flushexpiredtokens`` removes expired outstanding and blacklisted rows.

        Fresh and expired tokens are created, with and without a user, and
        some of each are blacklisted. After the command runs, only rows for
        the fresh tokens remain, in their original insertion order.
        """
        user = self.user

        # Tokens that will not expire soon
        fresh_1 = RefreshToken.for_user(user)
        fresh_2 = RefreshToken.for_user(user)
        fresh_3 = RefreshToken()

        for tok in (fresh_2, fresh_3):
            tok.blacklist()

        # Tokens minted with the clock moved back one refresh lifetime
        shifted_now = aware_utcnow() - api_settings.REFRESH_TOKEN_LIFETIME
        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
            clock.return_value = shifted_now
            stale_1 = RefreshToken.for_user(user)
            stale_2 = RefreshToken()

        for tok in (stale_1, stale_2):
            tok.blacklist()

        # One more fresh token, created after the expired ones
        fresh_4 = RefreshToken.for_user(user)

        self.assertEqual(OutstandingToken.objects.count(), 6)
        self.assertEqual(BlacklistedToken.objects.count(), 4)

        call_command("flushexpiredtokens")

        # Expired outstanding *and* blacklisted rows are gone
        self.assertEqual(OutstandingToken.objects.count(), 4)
        self.assertEqual(BlacklistedToken.objects.count(), 2)

        surviving_outstanding = [
            row.jti for row in OutstandingToken.objects.order_by("id")
        ]
        self.assertEqual(
            surviving_outstanding,
            [tok["jti"] for tok in (fresh_1, fresh_2, fresh_3, fresh_4)],
        )
        surviving_blacklisted = [
            row.token.jti for row in BlacklistedToken.objects.order_by("id")
        ]
        self.assertEqual(
            surviving_blacklisted,
            [fresh_2["jti"], fresh_3["jti"]],
        )
Exemplo n.º 30
0
    def test_it_should_return_nothing_if_everything_ok(self):
        """A valid refresh token validates with an empty ``validated_data``."""
        token = RefreshToken()
        token["test_claim"] = "arst"

        serializer = TokenBlacklistSerializer(data={"refresh": str(token)})

        shifted_now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
            clock.return_value = shifted_now
            self.assertTrue(serializer.is_valid())

        self.assertDictEqual(serializer.validated_data, {})