Ejemplo n.º 1
0
 def test_request_factory_url_arguments(self):
     """
     Non-regression test for #1461.

     ``APIRequestFactory.get`` must populate ``request.GET`` identically
     whether the query string is written into the path or passed as a
     data dict.
     """
     expected = {'demo': ['test']}
     factory = APIRequestFactory()
     call_variants = (
         ('/view/?demo=test',),
         ('/view/', {'demo': 'test'}),
     )
     for get_args in call_variants:
         request = factory.get(*get_args)
         self.assertEqual(dict(request.GET), expected)
Ejemplo n.º 2
0
 def test_force_authenticate(self):
     """
     Check that ``force_authenticate()`` attaches the given user to the
     request, so the view reports that user's username.
     """
     forced_user = User.objects.create_user('example', '*****@*****.**')
     request = APIRequestFactory().get('/view')
     force_authenticate(request, user=forced_user)
     reported_username = view(request).data['user']
     self.assertEqual(reported_username, 'example')
 def _resolve_urlpatterns(self, urlpatterns, test_paths):
     """
     Run ``format_suffix_patterns`` over *urlpatterns* and check that each
     entry of *test_paths* resolves to the expected positional and keyword
     arguments.

     Each item of *test_paths* must provide ``path``, ``args`` and
     ``kwargs`` attributes. A failure to build the suffixed patterns or to
     resolve a path is reported via ``self.fail`` (a test failure) rather
     than surfacing as an error.
     """
     factory = APIRequestFactory()
     try:
         suffixed_patterns = format_suffix_patterns(urlpatterns)
     except Exception:
         self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns")
     # Root the resolver at '^/' because request paths start with a slash.
     resolver = urlresolvers.RegexURLResolver(r'^/', suffixed_patterns)
     for case in test_paths:
         path_info = factory.get(case.path).path_info
         try:
             _view, resolved_args, resolved_kwargs = resolver.resolve(path_info)
         except Exception:
             self.fail("Failed to resolve URL: %s" % path_info)
         self.assertEqual(resolved_args, case.args)
         self.assertEqual(resolved_kwargs, case.kwargs)
class ThrottlingTests(TestCase):
    """
    Tests for the rate throttles attached to the ``MockView`` family of views.
    """

    def setUp(self):
        """
        Reset the cache so that no throttles will be active.
        """
        cache.clear()
        self.factory = APIRequestFactory()

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited.
        """
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        self.assertEqual(429, response.status_code)

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer of ``view``'s first throttle class,
        overriding time.time(), for the remainder of the current test.

        The override is stored on the throttle *class*, which every test
        shares. A cleanup is therefore registered to put back whatever
        ``timer`` the class defined itself, or to drop the override when the
        class only inherited one. Without it, the fake clock would leak into
        later tests.
        """
        throttle_class = view.throttle_classes[0]
        missing = object()
        previous = throttle_class.__dict__.get('timer', missing)

        def restore_timer():
            if previous is missing:
                del throttle_class.timer
            else:
                throttle_class.timer = previous

        # Cleanups run last-in first-out, so repeated calls within one test
        # unwind back to the state captured by the first call.
        self.addCleanup(restore_timer)
        throttle_class.timer = lambda self: value

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only.
        """
        self.set_throttle_timer(MockView, 0)

        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        self.assertEqual(429, response.status_code)

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)

        response = MockView.as_view()(request)
        self.assertEqual(200, response.status_code)

    def ensure_is_throttled(self, view, expect):
        """
        Make three requests as one user, then one request as a different
        user, and assert that the last response has status code ``expect``.
        """
        request = self.factory.get('/')
        # The two users need distinct usernames: Django's auth
        # ``User.username`` is unique, so creating the same name twice would
        # fail before the throttle is ever exercised.
        request.user = User.objects.create(username='a')
        for _ in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        self.assertEqual(expect, response.status_code)

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles.
        """
        self.ensure_is_throttled(MockView, 200)

    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
        """
        Process each ``(timer, expect)`` pair in turn. Set the throttle timer
        to ``timer``, issue a request, and check the ``Retry-After`` header:
        it must equal ``expect``, or be absent when ``expect`` is None.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                self.assertEqual(response['Retry-After'], expect)
            else:
                self.assertNotIn('Retry-After', response)

    def test_seconds_fields(self):
        """
        Ensure the Retry-After header is correct for second based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, (
                (0, None),
                (0, None),
                (0, None),
                (0, '1')
            )
        )

    def test_minutes_fields(self):
        """
        Ensure the Retry-After header is correct for minute based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (0, None),
                (0, None),
                (0, '60')
            )
        )

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (20, None),
                (40, None),
                (60, None),
                (80, None)
            )
        )

    def test_non_time_throttle(self):
        """
        Ensure the throttle on ``MockView_NonTimeThrottling`` is consulted
        (it gains a ``called`` attribute) and that responses never carry a
        ``Retry-After`` header.
        """
        request = self.factory.get('/')

        self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)

        self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)
class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """

    def setUp(self):
        """
        Clear any throttle history left in the cache by earlier tests, then
        build views whose throttle reads a manually advanced clock.
        """
        # Throttle state is kept in the cache, and every test restarts the
        # fake clock at 0. Stale history from an earlier run could otherwise
        # make the first requests look throttled.
        cache.clear()

        class XYScopedRateThrottle(ScopedRateThrottle):
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle,)

            def get(self, request):
                return Response('y')

        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        """Advance the fake clock read by every XYScopedRateThrottle instance."""
        self.throttle_class.TIMER_SECONDS += seconds

    def _assert_status(self, view, request, expected_status):
        """Call ``view`` with ``request`` and check the response status code."""
        response = view(request)
        self.assertEqual(expected_status, response.status_code)

    def test_scoped_rate_throttle(self):
        request = self.factory.get('/')

        # Should be able to hit x view 3 times per minute.
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 429)

        # Should be able to hit y view 1 time per minute.
        self.increment_timer()
        self._assert_status(self.y_view, request, 200)
        self.increment_timer()
        self._assert_status(self.y_view, request, 429)

        # Ensure throttles properly reset by advancing the rest of the minute
        self.increment_timer(55)

        # Should still be able to hit x view 3 times per minute.
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 429)

        # Should still be able to hit y view 1 time per minute.
        self.increment_timer()
        self._assert_status(self.y_view, request, 200)
        self.increment_timer()
        self._assert_status(self.y_view, request, 429)

    def test_unscoped_view_not_throttled(self):
        """A view without a ``throttle_scope`` is never throttled."""
        request = self.factory.get('/')

        for _ in range(10):
            self.increment_timer()
            self._assert_status(self.unscoped_view, request, 200)
Ejemplo n.º 6
0
class DecoratorTestCase(TestCase):
    """
    Tests for ``@api_view`` and the per-view policy decorators
    (renderer, parser, authentication, permission and throttle classes).
    """

    def setUp(self):
        self.factory = APIRequestFactory()

    def _finalize_response(self, request, response, *args, **kwargs):
        """
        Attach ``request`` to ``response`` and delegate to
        ``APIView.finalize_response``.

        NOTE(review): not referenced by the tests visible here; presumably
        meant to be patched in as a view's ``finalize_response``. Confirm.
        """
        response.request = request
        return APIView.finalize_response(self, request, response, *args, **kwargs)

    def test_api_view_incorrect(self):
        """
        If @api_view is not applied correctly, we should raise an assertion.
        """

        @api_view
        def view(request):
            return Response()

        request = self.factory.get('/')
        self.assertRaises(AssertionError, view, request)

    def test_api_view_incorrect_arguments(self):
        """
        If @api_view is missing arguments, we should raise an assertion.
        """

        with self.assertRaises(AssertionError):
            @api_view('GET')
            def view(request):
                return Response()

    def test_calling_method(self):
        """A GET-only view answers GET with 200 and POST with 405."""

        @api_view(['GET'])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_calling_put_method(self):
        """A GET/PUT view answers PUT with 200 and POST with 405."""

        @api_view(['GET', 'PUT'])
        def view(request):
            return Response({})

        request = self.factory.put('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_calling_patch_method(self):
        """A GET/PATCH view answers PATCH with 200 and POST with 405."""

        @api_view(['GET', 'PATCH'])
        def view(request):
            return Response({})

        request = self.factory.patch('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        request = self.factory.post('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def test_renderer_classes(self):
        """@renderer_classes selects the renderer used for the response."""

        @api_view(['GET'])
        @renderer_classes([JSONRenderer])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertIsInstance(response.accepted_renderer, JSONRenderer)

    def test_parser_classes(self):
        """@parser_classes replaces the request's parsers."""

        @api_view(['GET'])
        @parser_classes([JSONParser])
        def view(request):
            self.assertEqual(len(request.parsers), 1)
            self.assertIsInstance(request.parsers[0], JSONParser)
            return Response({})

        request = self.factory.get('/')
        view(request)

    def test_authentication_classes(self):
        """@authentication_classes replaces the request's authenticators."""

        @api_view(['GET'])
        @authentication_classes([BasicAuthentication])
        def view(request):
            self.assertEqual(len(request.authenticators), 1)
            self.assertIsInstance(request.authenticators[0], BasicAuthentication)
            return Response({})

        request = self.factory.get('/')
        view(request)

    def test_permission_classes(self):
        """An unauthenticated request to an IsAuthenticated view gets 403."""

        @api_view(['GET'])
        @permission_classes([IsAuthenticated])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_throttle_classes(self):
        """A 1/day throttle allows the first request and rejects the second."""
        # Throttle history is kept in the default cache, which persists
        # between tests. Clear it so history recorded elsewhere cannot make
        # the very first request look throttled.
        from django.core.cache import cache
        cache.clear()

        class OncePerDayUserThrottle(UserRateThrottle):
            rate = '1/day'

        @api_view(['GET'])
        @throttle_classes([OncePerDayUserThrottle])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        response = view(request)
        self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
class ThrottlingTests(TestCase):
    """
    Tests for the rate throttles attached to the ``MockView`` family of views.
    """

    def setUp(self):
        """
        Reset the cache so that no throttles will be active.
        """
        cache.clear()
        self.factory = APIRequestFactory()

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited.
        """
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        self.assertEqual(429, response.status_code)

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer of ``view``'s first throttle class,
        overriding time.time(), for the remainder of the current test.

        The override is stored on the throttle *class*, which every test
        shares. A cleanup is therefore registered to put back whatever
        ``timer`` the class defined itself, or to drop the override when the
        class only inherited one. Without it, the fake clock would leak into
        later tests.
        """
        throttle_class = view.throttle_classes[0]
        missing = object()
        previous = throttle_class.__dict__.get('timer', missing)

        def restore_timer():
            if previous is missing:
                del throttle_class.timer
            else:
                throttle_class.timer = previous

        # Cleanups run last-in first-out, so repeated calls within one test
        # unwind back to the state captured by the first call.
        self.addCleanup(restore_timer)
        throttle_class.timer = lambda self: value

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only.
        """
        self.set_throttle_timer(MockView, 0)

        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        self.assertEqual(429, response.status_code)

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)

        response = MockView.as_view()(request)
        self.assertEqual(200, response.status_code)

    def ensure_is_throttled(self, view, expect):
        """
        Make three requests as one user, then one request as a different
        user, and assert that the last response has status code ``expect``.
        """
        request = self.factory.get('/')
        # The two users need distinct usernames: Django's auth
        # ``User.username`` is unique, so creating the same name twice would
        # fail before the throttle is ever exercised.
        request.user = User.objects.create(username='a')
        for _ in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        self.assertEqual(expect, response.status_code)

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles.
        """
        self.ensure_is_throttled(MockView, 200)

    def ensure_response_header_contains_proper_throttle_field(
            self, view, expected_headers):
        """
        Process each ``(timer, expect)`` pair in turn. Set the throttle timer
        to ``timer``, issue a request, and check the ``Retry-After`` header:
        it must equal ``expect``, or be absent when ``expect`` is None.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                self.assertEqual(response['Retry-After'], expect)
            else:
                self.assertNotIn('Retry-After', response)

    def test_seconds_fields(self):
        """
        Ensure the Retry-After header is correct for second based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, ((0, None), (0, None), (0, None), (0, '1')))

    def test_minutes_fields(self):
        """
        Ensure the Retry-After header is correct for minute based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling,
            ((0, None), (0, None), (0, None), (0, '60')))

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling,
            ((0, None), (20, None), (40, None), (60, None), (80, None)))

    def test_non_time_throttle(self):
        """
        Ensure the throttle on ``MockView_NonTimeThrottling`` is consulted
        (it gains a ``called`` attribute) and that responses never carry a
        ``Retry-After`` header.
        """
        request = self.factory.get('/')

        self.assertFalse(
            hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)

        self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)
class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """
    def setUp(self):
        """
        Clear any throttle history left in the cache by earlier tests, then
        build views whose throttle reads a manually advanced clock.
        """
        # Throttle state is kept in the cache, and every test restarts the
        # fake clock at 0. Stale history from an earlier run could otherwise
        # make the first requests look throttled.
        cache.clear()

        class XYScopedRateThrottle(ScopedRateThrottle):
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle, )
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle, )
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle, )

            def get(self, request):
                return Response('y')

        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        """Advance the fake clock read by every XYScopedRateThrottle instance."""
        self.throttle_class.TIMER_SECONDS += seconds

    def _assert_status(self, view, request, expected_status):
        """Call ``view`` with ``request`` and check the response status code."""
        response = view(request)
        self.assertEqual(expected_status, response.status_code)

    def test_scoped_rate_throttle(self):
        request = self.factory.get('/')

        # Should be able to hit x view 3 times per minute.
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 429)

        # Should be able to hit y view 1 time per minute.
        self.increment_timer()
        self._assert_status(self.y_view, request, 200)
        self.increment_timer()
        self._assert_status(self.y_view, request, 429)

        # Ensure throttles properly reset by advancing the rest of the minute
        self.increment_timer(55)

        # Should still be able to hit x view 3 times per minute.
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 200)
        self.increment_timer()
        self._assert_status(self.x_view, request, 429)

        # Should still be able to hit y view 1 time per minute.
        self.increment_timer()
        self._assert_status(self.y_view, request, 200)
        self.increment_timer()
        self._assert_status(self.y_view, request, 429)

    def test_unscoped_view_not_throttled(self):
        """A view without a ``throttle_scope`` is never throttled."""
        request = self.factory.get('/')

        for _ in range(10):
            self.increment_timer()
            self._assert_status(self.unscoped_view, request, 200)
from __future__ import unicode_literals
from django.conf.urls import url
from django.test import TestCase
from rest_framework_3 import serializers
from rest_framework_3.test import APIRequestFactory
from tests.models import (
    ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)

def dummy_view(request, pk):
    """Placeholder view: the URL patterns below only need a target to route to."""
    return None


# A module-level request, built once for use in serializer contexts.
factory = APIRequestFactory()
request = factory.get('/')  # Just to ensure we have a request in the serializer context

# Every route points at ``dummy_view``; only the regex and the route name
# differ, so they are listed as (regex, name) pairs. A generator expression
# is used (rather than a list comprehension) so the loop names never leak
# into the module namespace on Python 2.
urlpatterns = list(
    url(pattern, dummy_view, name=route_name)
    for pattern, route_name in (
        (r'^dummyurl/(?P<pk>[0-9]+)/$', 'dummy-url'),
        (r'^manytomanysource/(?P<pk>[0-9]+)/$', 'manytomanysource-detail'),
        (r'^manytomanytarget/(?P<pk>[0-9]+)/$', 'manytomanytarget-detail'),
        (r'^foreignkeysource/(?P<pk>[0-9]+)/$', 'foreignkeysource-detail'),
        (r'^foreignkeytarget/(?P<pk>[0-9]+)/$', 'foreignkeytarget-detail'),
        (r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', 'nullableforeignkeysource-detail'),
        (r'^onetoonetarget/(?P<pk>[0-9]+)/$', 'onetoonetarget-detail'),
        (r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', 'nullableonetoonesource-detail'),
    )
)


# ManyToMany