def test_init_token_given(self):
        """A token re-created from its encoded string keeps the original claims.

        The token is minted with the clock frozen at ``issued_at`` and then
        decoded with the clock frozen at a subsequent ``decode_time``. The
        decoded instance reports ``decode_time`` as its ``current_time``,
        while ``exp``/``iat`` still reflect ``issued_at``.
        """
        issued_at = aware_utcnow()

        # Mint the source token with a frozen clock.
        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
            clock.return_value = issued_at
            source = MyToken()

        source["some_value"] = "arst"
        encoded = str(source)

        decode_time = aware_utcnow()

        # Re-instantiating from the encoded string must not raise.
        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
            clock.return_value = decode_time
            decoded = MyToken(encoded)

        self.assertEqual(decoded.current_time, decode_time)
        self.assertEqual(decoded.token, encoded)

        claims = decoded.payload
        expected_exp = datetime_to_epoch(issued_at + MyToken.lifetime)
        expected_iat = datetime_to_epoch(issued_at)

        self.assertEqual(len(claims), 5)
        self.assertEqual(decoded["some_value"], "arst")
        self.assertEqual(decoded["exp"], expected_exp)
        self.assertEqual(decoded["iat"], expected_iat)
        self.assertEqual(decoded[api_settings.TOKEN_TYPE_CLAIM], MyToken.token_type)
        self.assertIn("jti", claims)
Example #2
0
    def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
        """With ROTATE_REFRESH_TOKENS on, validation also yields a new refresh token.

        The rotated refresh token carries the custom claim forward but gets a
        new ``jti`` and ``exp``. Both the access and the rotated refresh
        expiries are computed from the patched validation time.
        """
        original = RefreshToken()
        original['test_claim'] = 'arst'

        previous_jti, previous_exp = original['jti'], original['exp']

        serializer = TokenRefreshSerializer(data={'refresh': str(original)})

        # Validate as if half an access-token lifetime had already passed.
        validated_at = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with override_api_settings(ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False):
            with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
                clock.return_value = validated_at
                self.assertTrue(serializer.is_valid())

        results = serializer.validated_data
        access = AccessToken(results['access'])
        rotated = RefreshToken(results['refresh'])

        for issued in (access, rotated):
            self.assertEqual(original['test_claim'], issued['test_claim'])

        self.assertNotEqual(previous_jti, rotated['jti'])
        self.assertNotEqual(previous_exp, rotated['exp'])

        expected_access_exp = datetime_to_epoch(validated_at + api_settings.ACCESS_TOKEN_LIFETIME)
        expected_refresh_exp = datetime_to_epoch(validated_at + api_settings.REFRESH_TOKEN_LIFETIME)
        self.assertEqual(access['exp'], expected_access_exp)
        self.assertEqual(rotated['exp'], expected_refresh_exp)
 def test_it_should_return_the_correct_values(self):
     """datetime_to_epoch maps known naive datetimes to their Unix timestamps."""
     cases = [
         (datetime(year=1970, month=1, day=1), 0),
         (datetime(year=1970, month=1, day=1, second=1), 1),
         (datetime(year=2000, month=1, day=1), 946684800),
     ]
     for moment, expected in cases:
         self.assertEqual(datetime_to_epoch(moment), expected)
Example #4
0
    def test_set_exp(self):
        """set_exp writes an epoch expiry claim and honours optional overrides."""
        fixed_now = make_utc(datetime(year=2000, month=1, day=1))

        token = MyToken()
        token.current_time = fixed_now

        # Defaults: claim 'exp', starting at token.current_time, MyToken.lifetime long.
        token.set_exp()
        self.assertEqual(token['exp'], datetime_to_epoch(fixed_now + MyToken.lifetime))

        # Explicit claim name, start time and lifetime.
        one_day = timedelta(days=1)
        token.set_exp(claim='refresh_exp', from_time=fixed_now, lifetime=one_day)
        self.assertIn('refresh_exp', token)
        self.assertEqual(token['refresh_exp'], datetime_to_epoch(fixed_now + one_day))
    def test_set_iat(self):
        """set_iat writes an epoch issued-at claim and honours optional overrides."""
        fixed_now = make_utc(datetime(year=2000, month=1, day=1))

        token = MyToken()
        token.current_time = fixed_now

        # Defaults: claim 'iat', stamped with token.current_time.
        token.set_iat()
        self.assertEqual(token["iat"], datetime_to_epoch(fixed_now))

        # Explicit claim name and timestamp.
        next_day = fixed_now + timedelta(days=1)
        token.set_iat(claim="refresh_iat", at_time=next_day)
        self.assertIn("refresh_iat", token)
        self.assertEqual(token["refresh_iat"], datetime_to_epoch(next_day))
Example #6
0
    def validate(self, attrs):
        """Refresh an access token, giving it a lifetime chosen per user.

        The user is looked up from the access token's
        ``api_settings.USER_ID_CLAIM``. The lifetime comes from
        ``self.get_lifetime(user)``, and the access token's ``exp`` is set to
        ``access_token.current_time + lifetime``.

        Returns:
            The parent serializer's data, with ``access`` replaced by the
            re-timed token and ``access_token_expiry`` set to the lifetime in
            whole seconds.

        Raises:
            ValidationError: if no user matches the token's user id.
        """
        data = super().validate(attrs)

        # Generate the new access token
        refresh_token = RefreshToken(attrs['refresh'])
        access_token = refresh_token.access_token

        user_id = access_token[api_settings.USER_ID_CLAIM]
        user = get_user_model().objects.filter(pk=user_id).first()

        if not user:
            raise ValidationError({'user': '******'})

        # Give the expiry time by base-role
        token_lifetime = self.get_lifetime(user)

        # Single write of 'exp', anchored to the access token's own clock. An
        # extra set_exp() anchored on the refresh token would always be
        # overwritten by this line, so it is not performed.
        access_token.payload['exp'] = datetime_to_epoch(
            access_token.current_time + token_lifetime)

        return {
            **data, 'access': str(access_token),
            'access_token_expiry': int(token_lifetime.total_seconds())
        }
Example #7
0
    def validate(self, attrs):
        """Issue a token pair whose access lifetime is chosen per user.

        Returns:
            A dict with ``refresh`` and ``access`` (encoded tokens),
            ``access_token_expiry`` (the access lifetime in whole seconds) and
            ``user_data`` (``UserSerializer(self.user).data``).
        """
        try:
            if self.user:
                if api_settings.UPDATE_LAST_LOGIN:
                    update_last_login(None, self.user)
        except AttributeError:
            # NOTE(review): this path runs when ``self.user`` is not set yet.
            # The parent validate() presumably authenticates and sets it; its
            # return value is discarded. This handler also catches any
            # AttributeError raised inside update_last_login — confirm intended.
            super().validate(attrs)

        refresh_token = self.get_token(self.user)
        access_token = refresh_token.access_token

        # Give the expiry time by base-role
        token_lifetime = self.get_lifetime(self.user)

        # Single write of 'exp', anchored to the access token's own clock. An
        # extra set_exp() anchored on the refresh token would always be
        # overwritten by this line, so it is not performed.
        access_token.payload['exp'] = datetime_to_epoch(
            access_token.current_time + token_lifetime)

        return {
            'refresh': str(refresh_token),
            'access': str(access_token),
            'access_token_expiry': int(token_lifetime.total_seconds()),
            'user_data': UserSerializer(self.user).data
        }
Example #8
0
    def test_decode_with_invalid_sig_no_verify(self):
        """Each backend decodes a token with a mismatched signature when verify=False.

        Two tokens are signed whose payloads differ only in ``foo``. The
        header and payload of the second are then joined with the signature
        of the first.
        """
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        for backend in self.backends:
            # Real f-string so the subTest label names the algorithm under test.
            with self.subTest(
                    f"Test decode with invalid sig for {backend.algorithm}"):
                payload = self.payload.copy()
                token_1 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)
                payload["foo"] = "baz"
                token_2 = jwt.encode(payload,
                                     backend.signing_key,
                                     algorithm=backend.algorithm)
                if IS_OLD_JWT:
                    # Old PyJWT returns bytes.
                    token_1 = token_1.decode("utf-8")
                    token_2 = token_2.decode("utf-8")
                else:
                    # Payload copied: newer PyJWT leaves exp as a datetime here,
                    # so convert it for the comparison below.
                    payload["exp"] = datetime_to_epoch(payload["exp"])

                # Splice token_2's header.payload with token_1's signature.
                token_2_payload = token_2.rsplit(".", 1)[0]
                token_1_sig = token_1.rsplit(".", 1)[-1]
                invalid_token = token_2_payload + "." + token_1_sig

                self.assertEqual(
                    backend.decode(invalid_token, verify=False),
                    payload,
                )
    def test_decode_rsa_aud_iss_jwk_success(self):
        """A backend configured with a JWK URL verifies using the JWK-supplied key.

        The token is signed with PRIVATE_KEY_2, while the backend is built
        with the PRIVATE_KEY/PUBLIC_KEY pair. The patched PyJWKClient returns
        PUBLIC_KEY_2, so a successful decode shows the JWK key was used.
        """
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        self.payload["foo"] = "baz"
        self.payload["aud"] = AUDIENCE
        self.payload["iss"] = ISSUER

        token = jwt.encode(
            self.payload,
            PRIVATE_KEY_2,
            algorithm="RS256",
            headers={"kid": "230498151c214b788dd97f22b85410a5"},
        )
        # Payload copied: exp is still a datetime here; convert for comparison.
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

        with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
            mock_jwk_client = mock.MagicMock()
            mock_signing_key = mock.MagicMock()

            mock_jwk_module.return_value = mock_jwk_client
            mock_jwk_client.get_signing_key_from_jwt.return_value = mock_signing_key
            type(mock_signing_key).key = mock.PropertyMock(return_value=PUBLIC_KEY_2)

            # Note: the PRIVATE_KEY/PUBLIC_KEY pair passed here is intentionally
            # the original pairing, not the key pair the token was signed with.
            jwk_token_backend = TokenBackend(
                "RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
            )

            self.assertEqual(jwk_token_backend.decode(token), self.payload)
Example #10
0
def test_token_user_cache_fallback_life():
    """With the ``iat`` claim cleared, the permission cache life is 300."""
    issued_at = datetime_to_epoch(aware_utcnow())
    # Named 'encoded' rather than 'jwt' to avoid shadowing the jwt module.
    encoded = get_jwt(exp=issued_at + 15, iat=issued_at)

    token = UntypedToken(encoded)
    # Clear iat to exercise the fallback path.
    token.payload['iat'] = None

    assert PermissionedTokenUser(token)._get_permission_cache_life() == 300
    def test_decode_leeway_hmac_fail(self):
        """A token expired by twice the leeway is rejected by the leeway backend."""
        expired_at = aware_utcnow() - timedelta(seconds=LEEWAY * 2)
        self.payload["exp"] = datetime_to_epoch(expired_at)

        token = jwt.encode(self.payload, SECRET, algorithm='HS256')

        with self.assertRaises(TokenBackendError):
            self.hmac_leeway_token_backend.decode(token)
Example #12
0
    def test_init_with_timedelta(self):
        """An AccessToken built with an explicit lifetime uses it in set_exp()."""
        start = make_utc(datetime(year=2000, month=1, day=1))
        ten_minutes = timedelta(minutes=10)

        token = AccessToken(lifetime=ten_minutes)
        token.current_time = start
        token.set_exp()

        self.assertEqual(token['exp'], datetime_to_epoch(start + ten_minutes))
Example #13
0
    def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(self):
        """With rotation and blacklisting enabled, the old refresh token is blacklisted."""
        # Start from an empty token database.
        self.assertEqual(OutstandingToken.objects.count(), 0)
        self.assertEqual(BlacklistedToken.objects.count(), 0)

        original = RefreshToken()
        original["test_claim"] = "arst"

        previous_jti = original["jti"]
        previous_exp = original["exp"]

        serializer = TokenRefreshSerializer(data={"refresh": str(original)})

        # Validate as if half an access-token lifetime had already passed.
        validated_at = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with override_api_settings(ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True):
            with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
                clock.return_value = validated_at
                self.assertTrue(serializer.is_valid())

        access = AccessToken(serializer.validated_data["access"])
        rotated = RefreshToken(serializer.validated_data["refresh"])

        for issued in (access, rotated):
            self.assertEqual(original["test_claim"], issued["test_claim"])

        self.assertNotEqual(previous_jti, rotated["jti"])
        self.assertNotEqual(previous_exp, rotated["exp"])

        expected_access_exp = datetime_to_epoch(validated_at + api_settings.ACCESS_TOKEN_LIFETIME)
        expected_refresh_exp = datetime_to_epoch(validated_at + api_settings.REFRESH_TOKEN_LIFETIME)
        self.assertEqual(access["exp"], expected_access_exp)
        self.assertEqual(rotated["exp"], expected_refresh_exp)

        # One outstanding and one blacklisted token exist; the blacklisted one
        # is the old refresh token.
        self.assertEqual(OutstandingToken.objects.count(), 1)
        self.assertEqual(BlacklistedToken.objects.count(), 1)
        self.assertEqual(BlacklistedToken.objects.first().token.jti, previous_jti)
Example #14
0
    def get_token(cls, user):
        """Build a token for ``user`` with username/email claims.

        When ``user`` is truthy, ``exp`` is reset to
        ``token.current_time + SUPERUSER_LIFETIME``.

        Returns:
            The token. It is now returned on every path; previously the falsy
            branch fell through to an implicit ``None``.
        """
        token = super(TokenObtainTokenSerializer, cls).get_token(user)
        token['username'] = user.username
        token['email'] = user.email

        # NOTE(review): ``user`` is already dereferenced above, so this guard
        # is effectively always true, and SUPERUSER_LIFETIME is applied to
        # every user. A sibling serializer checks ``user.is_superuser``
        # instead; confirm which is intended.
        if user:
            token.payload['exp'] = datetime_to_epoch(token.current_time +
                                                     SUPERUSER_LIFETIME)

        return token
    def test_decode_leeway_hmac_success(self):
        """A token expired by half the leeway is still accepted."""
        expired_at = aware_utcnow() - timedelta(seconds=LEEWAY / 2)
        self.payload["exp"] = datetime_to_epoch(expired_at)

        token = jwt.encode(self.payload, SECRET, algorithm='HS256')
        decoded = self.hmac_leeway_token_backend.decode(token)

        self.assertEqual(decoded, self.payload)
    def test_decode_rsa_success(self):
        """An RS256 token signed with PRIVATE_KEY decodes back to its payload."""
        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        self.payload['foo'] = 'baz'

        encoded = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
        # Payload copied: exp is still a datetime in self.payload, so convert it
        # to an epoch value for the comparison.
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

        decoded = self.rsa_token_backend.decode(encoded)
        self.assertEqual(decoded, self.payload)
    def test_init_no_token_given(self):
        """A freshly minted token gets exp/iat/jti/type claims from the frozen clock."""
        frozen = make_utc(datetime(year=2000, month=1, day=1))

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as clock:
            clock.return_value = frozen
            token = MyToken()

        self.assertEqual(token.current_time, frozen)
        self.assertIsNone(token.token)

        claims = token.payload
        self.assertEqual(len(claims), 4)
        self.assertEqual(claims["exp"], datetime_to_epoch(frozen + MyToken.lifetime))
        self.assertEqual(claims["iat"], datetime_to_epoch(frozen))
        self.assertIn("jti", claims)
        self.assertEqual(claims[api_settings.TOKEN_TYPE_CLAIM], MyToken.token_type)
Example #18
0
    def test_init(self):
        """SlidingToken sets its sliding-refresh expiry claim and the 'sliding' type."""
        token = SlidingToken()

        refresh_exp_claim = api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM
        expected_refresh_exp = datetime_to_epoch(
            token.current_time + api_settings.SLIDING_TOKEN_REFRESH_LIFETIME
        )

        self.assertEqual(token[refresh_exp_claim], expected_refresh_exp)
        self.assertEqual(token[api_settings.TOKEN_TYPE_CLAIM], 'sliding')
Example #19
0
 def get_token(cls, user):
     """Build a token for ``user`` carrying profile claims.

     Adds ``user_id``, ``name`` (the username), ``email``, ``first_name``,
     ``last_name`` and ``is_active`` claims. When ``user`` is truthy, ``exp``
     is reset to ``token.current_time + SUPERUSER_LIFETIME``.

     Returns:
         The token. It is now returned on every path; previously the falsy
         branch fell through to an implicit ``None``.
     """
     token = super(CustomTokenObtainPairSerializer, cls).get_token(user)
     token['user_id'] = user.id
     token['name'] = user.username
     token['email'] = user.email
     token['first_name'] = user.first_name
     token['last_name'] = user.last_name
     token['is_active'] = user.is_active
     # NOTE(review): ``user`` is already dereferenced above, so this guard is
     # effectively always true and SUPERUSER_LIFETIME applies to every user.
     # Confirm whether ``user.is_superuser`` was intended.
     if user:
         token.payload['exp'] = datetime_to_epoch(token.current_time + SUPERUSER_LIFETIME)
     return token
Example #20
0
    def get_token(cls, user):
        """Build a token for ``user`` with ``name`` (email) and ``user_id`` claims.

        For superusers, ``exp`` is reset to
        ``token.current_time + SUPERUSER_LIFETIME``.
        """
        token = super(MyTokenObtainPairSerializer, cls).get_token(user)
        token['name'] = user.email
        token['user_id'] = user.id

        if not user.is_superuser:
            return token

        superuser_expiry = token.current_time + SUPERUSER_LIFETIME
        token.payload['exp'] = datetime_to_epoch(superuser_expiry)
        return token
Example #21
0
    def set_exp(self, claim='exp', from_time=None, lifetime=None):
        """
        Updates the expiration time of a token.

        Stores ``datetime_to_epoch(from_time + lifetime)`` under ``claim`` in
        the payload. ``from_time`` defaults to ``self.current_time`` and
        ``lifetime`` defaults to ``self.lifetime``.
        """
        start = self.current_time if from_time is None else from_time
        duration = self.lifetime if lifetime is None else lifetime

        self.payload[claim] = datetime_to_epoch(start + duration)
Example #22
0
 def long_lived_access_token(self):
     """
     Return an access token that expires CONTRIBUTOR_LONG_TOKEN_LIFETIME after self.current_time.

     After a short_lived_access_token is redeemed, we set this long_lived_access_token in an
     HTTP-only cookie. This completes the contract expected by the front-end and established
     in the muggle authentication pattern used in this app. The token's "ctx" claim is set
     to LONG_TOKEN.
     """
     token = self.access_token
     expires_at = self.current_time + settings.CONTRIBUTOR_LONG_TOKEN_LIFETIME
     token["exp"] = datetime_to_epoch(expires_at)
     token["ctx"] = LONG_TOKEN
     return token
Example #23
0
 def short_lived_access_token(self):
     """
     Return a short-lived access token derived from this refresh token.

     The token is provided as a parameter in a magic link. It expires
     CONTRIBUTOR_SHORT_TOKEN_LIFETIME after self.current_time and is exchanged for a more
     secure, slightly longer-lived token when it is redeemed. Its "ctx" claim is set to
     SHORT_TOKEN.
     """
     token = self.access_token
     expires_at = self.current_time + settings.CONTRIBUTOR_SHORT_TOKEN_LIFETIME
     token["exp"] = datetime_to_epoch(expires_at)
     token["ctx"] = SHORT_TOKEN
     return token
Example #24
0
    def test_decode_aud_iss_success(self):
        """An RS256 token with matching aud/iss decodes via the aud/iss backend."""
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        self.payload["foo"] = "baz"
        self.payload["aud"] = AUDIENCE
        self.payload["iss"] = ISSUER

        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
        if not IS_OLD_JWT:
            # Payload copied: exp is still a datetime in self.payload, so convert
            # it to an epoch value for the comparison.
            self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
        else:
            # Old PyJWT returns bytes.
            token = token.decode("utf-8")

        decoded = self.aud_iss_token_backend.decode(token)
        self.assertEqual(decoded, self.payload)
Example #25
0
def get_jwt(exp=None,
            iat=None,
            sub='00000000-0000-0000-0000-000000000000',
            username='******',
            **kwargs):
    """Build an RS256-signed test JWT and return it as a ``str``.

    Args:
        exp: ``exp`` claim (epoch int). Defaults to five minutes from now.
        iat: ``iat`` claim (epoch int). Defaults to now.
        sub: ``sub`` claim.
        username: Used for both the ``email`` and ``username`` claims.
        **kwargs: Extra claims. Entries whose value is ``None`` are skipped;
            the rest are added to, or override, the base payload.

    The token is signed with ``settings.SIMPLE_JWT['PRIVATE_KEY']`` and carries
    a fixed ``kid`` header.
    """
    payload = {
        'email': username,
        'username': username,
        'sub': sub,
        'aud': '892633',
        'jti': uuid4().hex
    }

    for prop_name, prop_value in kwargs.items():
        if prop_value is not None:
            # we test 'prop_value' against None to avoid being caught
            # with a boolean 'prop_value = False' which will be false
            # but needs to be set
            payload[prop_name] = prop_value

    now = aware_utcnow()
    # Same None test as above: an explicit 0 must be honoured, not replaced
    # by the default.
    if exp is not None:
        payload['exp'] = exp
    else:
        payload['exp'] = datetime_to_epoch(now + timedelta(minutes=5))

    if iat is not None:
        payload['iat'] = iat
    else:
        payload['iat'] = datetime_to_epoch(now)

    token = jwt.encode(payload=payload,
                       key=settings.SIMPLE_JWT['PRIVATE_KEY'],
                       algorithm='RS256',
                       headers={
                           'kid': '230498151c214b788dd97f22b85410a5'
                       })
    # PyJWT < 2 returns bytes and PyJWT >= 2 returns str; normalise to str.
    if isinstance(token, bytes):
        token = token.decode('utf-8')
    return token
Example #26
0
    def validate(self, attrs):
        """Issue a refresh/access pair whose access token expires after 15 minutes.

        The access token also gets a ``start_datetime`` claim holding its
        ``current_time`` as an integer epoch (computed with ``timegm``).
        """
        data = super(TokenForUserSerializer, self).validate(attrs)

        refresh = self.get_token(self.user)
        data['refresh'] = str(refresh)

        access = refresh.access_token
        issued_at = access.current_time
        # Change the token lifetime: expire 15 minutes after issue.
        access.payload['exp'] = datetime_to_epoch(issued_at + timedelta(minutes=15))
        # Record the current time (a claim used for verification).
        access['start_datetime'] = timegm(issued_at.utctimetuple())
        data['access'] = str(access)

        return data
    def test_decode_rsa_with_invalid_sig_no_verify(self):
        """A spliced RS256 token decodes when signature verification is disabled."""
        self.payload['exp'] = aware_utcnow() + timedelta(days=1)
        genuine = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
        self.payload['foo'] = 'baz'
        tampered = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

        # Join tampered's header.payload with genuine's signature.
        header_and_body = tampered.rsplit('.', 1)[0]
        signature = genuine.rsplit('.', 1)[-1]
        forged = f"{header_and_body}.{signature}"

        # Payload copied: exp is still a datetime; convert for the comparison.
        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

        # NOTE(review): an RS256 token is decoded with the HMAC backend; this
        # only works because verify=False. Confirm rsa_token_backend was not meant.
        decoded = self.hmac_token_backend.decode(forged, verify=False)
        self.assertEqual(decoded, self.payload)
Example #28
0
    def test_it_should_return_access_token_if_everything_ok(self):
        """A valid refresh token validates and yields an access token with its claims."""
        refresh = RefreshToken()
        refresh['test_claim'] = 'arst'

        serializer = TokenRefreshSerializer(data={'refresh': str(refresh)})

        # Validate as if half an access-token lifetime had already passed.
        validated_at = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
            clock.return_value = validated_at
            self.assertTrue(serializer.is_valid())

        access = AccessToken(serializer.validated_data['access'])
        expected_exp = datetime_to_epoch(validated_at + api_settings.ACCESS_TOKEN_LIFETIME)

        self.assertEqual(refresh['test_claim'], access['test_claim'])
        self.assertEqual(access['exp'], expected_exp)
Example #29
0
    def test_decode_success(self):
        """Every configured backend decodes a token it signed itself."""
        self.payload["exp"] = aware_utcnow() + timedelta(days=1)
        self.payload["foo"] = "baz"
        for backend in self.backends:
            # Real f-string so the subTest label names the algorithm under test.
            with self.subTest(f"Test decode success for {backend.algorithm}"):

                token = jwt.encode(self.payload,
                                   backend.signing_key,
                                   algorithm=backend.algorithm)
                if IS_OLD_JWT:
                    # Old PyJWT returns bytes.
                    token = token.decode("utf-8")
                    payload = self.payload
                else:
                    # Payload copied: exp is still a datetime in self.payload,
                    # so compare against an epoch-converted copy.
                    payload = self.payload.copy()
                    payload["exp"] = datetime_to_epoch(self.payload["exp"])

                self.assertEqual(backend.decode(token), payload)
Example #30
0
    def test_it_should_return_access_token_if_everything_ok(self):
        """POSTing a valid refresh token returns 200 and an access token with its claims."""
        refresh = RefreshToken()
        refresh['test_claim'] = 'arst'

        # View returns 200; the token clock is frozen half an access lifetime ago.
        frozen = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch('rest_framework_simplejwt.tokens.aware_utcnow') as clock:
            clock.return_value = frozen
            response = self.view_post(data={'refresh': str(refresh)})

        self.assertEqual(response.status_code, 200)

        access = AccessToken(response.data['access'])
        expected_exp = datetime_to_epoch(frozen + api_settings.ACCESS_TOKEN_LIFETIME)

        self.assertEqual(refresh['test_claim'], access['test_claim'])
        self.assertEqual(access['exp'], expected_exp)