Example #1
0
 def get_page_size(self, request):
     """Return ``max_page_size`` when the client explicitly asks for ``max``.

     The maximum is returned only when all of the following hold:

     * the ``max_paginate_by`` feature flag is set;
     * ``page_size_query_param`` and ``max_page_size`` are both truthy;
     * the request's ``page_size_query_param`` query value equals ``'max'``.

     In every other case the parent class's ``get_page_size`` decides.
     """
     features = get_rest_framework_features()
     wants_max = (features['max_paginate_by']
                  and self.page_size_query_param
                  and self.max_page_size
                  and request.query_params.get(self.page_size_query_param) == 'max')
     if not wants_max:
         return super(PaginateByMaxMixin, self).get_page_size(request)
     return self.max_page_size
Example #2
0
def get_lookup_allowed_symbols(kwarg_name='pk', force_dot=False):
    """Build the named-group regex fragment matching a URL lookup value.

    :param kwarg_name: name of the regex group (the view kwarg), ``'pk'`` by default.
    :param force_dot: when true, also exclude ``.`` from the allowed symbols.
    :returns: ``(?P<kwarg_name>[^/.]+)`` when the
        ``use_dot_in_lookup_regex_by_default`` feature flag is set or
        ``force_dot`` is true, otherwise ``(?P<kwarg_name>[^/]+)``.
    """
    # TODO: test me
    exclude_dot = (
        get_rest_framework_features()['use_dot_in_lookup_regex_by_default']
        or force_dot
    )
    template = '(?P<{0}>[^/.]+)' if exclude_dot else '(?P<{0}>[^/]+)'
    return template.format(kwarg_name)
Example #3
0
    def test_router_trailing_slash(self):
        """``router_trailing_slash`` must follow the patched DRF version.

        Expected per the table below: off for 2.3 and 2.3.5, on from 2.3.6.
        Each case patches ``get_rest_framework_version``. The version goes
        into the assertion message so a failure names the offending case.
        """
        experiments = [
            ((2, 3), False),
            ((2, 3, 5), False),
            ((2, 3, 6), True),
            ((2, 3, 7), True),
            ((2, 4), True),
        ]

        for version, expected in experiments:
            with patch('rest_framework_extensions.utils.get_rest_framework_version',
                       Mock(return_value=version)):
                self.assertEqual(
                    get_rest_framework_features()['router_trailing_slash'],
                    expected,
                    msg='router_trailing_slash mismatch for DRF version %s' % (version,))
Example #4
0
    def test_should_save_only_fields_from_data_for_partial_update(self):
        """Three partial updates to different fields must not overwrite each other.

        Each serializer changes a single field (title, text, attachment) of
        its own freshly loaded Comment. After all three are saved, the
        reloaded comment must carry all three changes.
        """
        # Every serializer gets its own Comment instance, because a
        # serializer's save() mutates the instance it was given.
        title_serializer = CommentSerializer(
            instance=self.get_comment(), data={'title': 'goodbye'}, partial=True)
        text_serializer = CommentSerializer(
            instance=self.get_comment(), data={'text': 'moon'}, partial=True)

        attachment_kwargs = dict(instance=self.get_comment(), partial=True)
        # Depending on the DRF version, uploaded files travel either inside
        # ``data`` or through the separate ``files`` argument.
        if get_rest_framework_features()['uses_single_request_data_in_serializers']:
            attachment_kwargs['data'] = {'attachment': self.files[1]}
        else:
            attachment_kwargs['data'] = {}
            attachment_kwargs['files'] = {'attachment': self.files[1]}

        serializers = [
            title_serializer,
            text_serializer,
            CommentSerializer(**attachment_kwargs),
        ]
        for serializer in serializers:
            self.assertTrue(serializer.is_valid())

        # Save all three, in order; none may clobber fields changed by another.
        for serializer in serializers:
            serializer.save()

        reloaded = self.get_comment()
        self.assertEqual(reloaded.attachment.read(), u'file two'.encode('utf-8'))
        self.assertEqual(reloaded.text, 'moon')
        self.assertEqual(reloaded.title, 'goodbye')
Example #5
0
 def get_paginate_by(self, *args, **kwargs):
     """Return ``max_paginate_by`` when the request asks for ``<paginate_by_param>=max``.

     This requires the ``max_paginate_by`` feature flag and truthy
     ``paginate_by_param`` / ``max_paginate_by`` attributes. Otherwise the
     call, with its original arguments, is delegated to the parent class.
     """
     use_max = (get_rest_framework_features()['max_paginate_by']
                and self.paginate_by_param
                and self.max_paginate_by
                and self.request.QUERY_PARAMS.get(self.paginate_by_param) == 'max')
     if use_max:
         return self.max_paginate_by
     return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)
Example #6
0
 def get_paginate_by(self, *args, **kwargs):
     """Honour ``?<paginate_by_param>=max`` by returning ``max_paginate_by``.

     Only active when the ``max_paginate_by`` feature flag is set and both
     ``paginate_by_param`` and ``max_paginate_by`` are configured (truthy).
     Any other request falls through to the parent implementation.
     """
     if get_rest_framework_features()['max_paginate_by']:
         param_name = self.paginate_by_param
         if param_name and self.max_paginate_by:
             if self.request.QUERY_PARAMS.get(param_name) == 'max':
                 return self.max_paginate_by
     return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)
Example #7
0
def get_lookup_allowed_symbols(kwarg_name='pk', force_dot=False):
    """Return a ``(?P<name>...)`` regex group for a URL lookup kwarg.

    Both ``/`` and ``.`` are excluded from the match when the
    ``use_dot_in_lookup_regex_by_default`` feature flag is set or
    ``force_dot`` is true; otherwise only ``/`` is excluded.
    """
    # TODO: test me
    features = get_rest_framework_features()
    if features['use_dot_in_lookup_regex_by_default'] or force_dot:
        group = '(?P<{0}>[^/.]+)'
    else:
        group = '(?P<{0}>[^/]+)'
    return group.format(kwarg_name)
Example #8
0
 def test_simple_response(self):
     """GET /users/ returns a list holding the single fixture user.

     ``password`` is expected in the payload only when the
     ``write_only_fields`` feature flag is false.
     """
     response = self.client.get('/users/')
     user_payload = dict(id=1, age=24, name='Gennady', surname='Chibisov')
     if not get_rest_framework_features()['write_only_fields']:
         user_payload['password'] = self.user.password
     self.assertEqual(response.data, [user_payload])
Example #9
0
 def test_simple_response(self):
     """The /users/ listing serializes exactly the fixture user.

     When the ``write_only_fields`` feature flag is false, the serialized
     user is also expected to expose ``password``.
     """
     listing = self.client.get('/users/')
     expected_user = {
         'id': 1,
         'age': 24,
         'name': 'Gennady',
         'surname': 'Chibisov',
     }
     expected_listing = [expected_user]
     if not get_rest_framework_features()['write_only_fields']:
         expected_user['password'] = self.user.password
     self.assertEqual(listing.data, expected_listing)
Example #10
0
    def test_urls_can_have_trailing_slash_removed(self):
        """With ``trailing_slash=False`` no generated pattern ends in ``/$``.

        The list, detail, list-controller and detail-controller routes must
        each be found. The lookup group may force ``.`` exclusion, depending
        on the ``allow_dot_in_lookup_regex_without_trailing_slash`` flag.
        """
        router = ExtendedSimpleRouter(trailing_slash=False)
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        features = get_rest_framework_features()
        lookup = get_lookup_allowed_symbols(
            force_dot=features['allow_dot_in_lookup_regex_without_trailing_slash'])

        expected_patterns = (
            '^router-viewset$',
            '^router-viewset/{0}$'.format(lookup),
            '^router-viewset/list_controller$',
            '^router-viewset/{0}/detail_controller$'.format(lookup),
        )
        for pattern in expected_patterns:
            self.assertIsNotNone(
                get_url_pattern_by_regex_pattern(urls, pattern),
                msg='Should find url pattern with regexp %s' % pattern)
Example #11
0
    def test_urls_can_have_trailing_slash_removed(self):
        """Routes built with ``trailing_slash=False`` lack the final slash.

        The ``pk`` lookup group also excludes ``.`` when the
        ``allow_dot_in_lookup_regex_without_trailing_slash`` flag is set.
        """
        router = ExtendedSimpleRouter(trailing_slash=False)
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        dot_excluded = get_rest_framework_features()['allow_dot_in_lookup_regex_without_trailing_slash']
        pk_group = '(?P<pk>[^/.]+)' if dot_excluded else '(?P<pk>[^/]+)'

        candidates = [
            '^router-viewset$',
            '^router-viewset/' + pk_group + r'$',
            '^router-viewset/list_controller$',
            '^router-viewset/' + pk_group + '/detail_controller$',
        ]
        for regex in candidates:
            found = get_url_pattern_by_regex_pattern(urls, regex)
            self.assertIsNotNone(found, msg='Should find url pattern with regexp %s' % regex)
Example #12
0
    def test_max_paginate_by(self):
        """``max_paginate_by`` must follow the patched DRF version.

        Expected per the table below: off up to 2.3.7, on from 2.3.8.
        Each case patches ``get_rest_framework_version``. The version goes
        into the assertion message so a failure names the offending case.
        """
        experiments = [
            ((2, 3), False),
            ((2, 3, 5), False),
            ((2, 3, 6), False),
            ((2, 3, 7), False),
            ((2, 3, 8), True),
            ((2, 3, 9), True),
            ((2, 3, 10), True),
            ((2, 4), True),
        ]

        for version, expected in experiments:
            with patch('rest_framework_extensions.utils.get_rest_framework_version',
                       Mock(return_value=version)):
                self.assertEqual(
                    get_rest_framework_features()['max_paginate_by'],
                    expected,
                    msg='max_paginate_by mismatch for DRF version %s' % (version,))
Example #13
0
    def test_django_object_permissions_class(self):
        """``django_object_permissions_class`` must follow the patched DRF version.

        Expected per the table below: off up to 2.3.7, on from 2.3.8.
        Each case patches ``get_rest_framework_version``. The version goes
        into the assertion message so a failure names the offending case.
        """
        experiments = [
            ((2, 3), False),
            ((2, 3, 5), False),
            ((2, 3, 6), False),
            ((2, 3, 7), False),
            ((2, 3, 8), True),
            ((2, 3, 9), True),
            ((2, 3, 10), True),
            ((2, 4), True),
        ]

        for version, expected in experiments:
            with patch('rest_framework_extensions.utils.get_rest_framework_version',
                       Mock(return_value=version)):
                self.assertEqual(
                    get_rest_framework_features()['django_object_permissions_class'],
                    expected,
                    msg='django_object_permissions_class mismatch for DRF version %s' % (version,))
Example #14
0
    def test_should_not_use_update_fields_when_related_objects_are_saving(self):
        """Partial save through a serializer with an expanded ``user`` must succeed.

        A ``ValueError`` from ``save()`` is reported as the serializer having
        used ``update_fields`` while saving the related object. The nested
        user's name is checked only when the ``save_related_serializers``
        feature flag is set.
        """
        payload = {
            'title': 'goodbye',
            'user': {'id': self.user.pk, 'name': 'oleg'},
        }
        serializer = CommentSerializerWithExpandedUsersLiked(
            instance=self.get_comment(), data=payload, partial=True)
        self.assertTrue(serializer.is_valid())

        try:
            serializer.save()
        except ValueError:
            self.fail('If serializer has expanded related serializer, then it should not use update_fields while '
                      'saving related object')

        comment = self.get_comment()
        self.assertEqual(comment.title, 'goodbye')
        if get_rest_framework_features()['save_related_serializers']:
            self.assertEqual(comment.user.name, 'oleg')
Example #15
0
    def test_should_not_use_update_fields_when_related_objects_are_saving(
            self):
        """Saving a comment together with its expanded user must not raise.

        ``ValueError`` from ``save()`` fails the test. Afterwards the title
        must be updated, and the related user's name too when the
        ``save_related_serializers`` feature flag is set.
        """
        nested_user = dict(id=self.user.pk, name='oleg')
        payload = dict(title='goodbye', user=nested_user)
        serializer = CommentSerializerWithExpandedUsersLiked(
            instance=self.get_comment(), data=payload, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If serializer has expanded related serializer, then it should not use update_fields while '
                'saving related object')
        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')

        related_saved = get_rest_framework_features()['save_related_serializers']
        if related_saved:
            self.assertEqual(reloaded.user.name, 'oleg')
Example #16
0
def add_trailing_slash_if_needed(regexp_string):
    """Replace a route regex's final ``/$`` with ``{trailing_slash}$`` when supported.

    If the ``router_trailing_slash`` feature flag is set, the last two
    characters of ``regexp_string`` are dropped and ``'{trailing_slash}$'``
    is appended. Otherwise the string is returned unchanged.

    NOTE(review): the last two characters are dropped unconditionally, so
    callers must pass a pattern ending in ``/$`` (as the visible callers do).
    """
    # TODO: test me
    if not get_rest_framework_features()['router_trailing_slash']:
        return regexp_string
    return regexp_string[:-2] + '{trailing_slash}$'
Example #17
0
class ListUpdateModelMixinTestBehaviour__serializer_fields(APITestCase):
    """Bulk PATCH on ``/users/``: how serializer field options affect updates.

    Every PATCH carries the bulk-operation header built in ``setUp``. The
    tests cover ``source`` mapping, write-only, read-only and hidden fields,
    and the error response for a value the ``age`` field cannot accept.
    """
    urls = urlpatterns

    def setUp(self):
        self.user = User.objects.create(
            id=1,
            name='Gennady',
            age=24,
            last_name='Chibisov',
            email='*****@*****.**',
            password='******'
        )
        bulk_header = utils.prepare_header_name(
            extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME)
        self.headers = {bulk_header: 'true'}

    def get_fresh_user(self):
        """Reload the fixture user from the database."""
        return User.objects.get(pk=self.user.pk)

    def _bulk_patch(self, payload):
        # JSON-encode ``payload`` and PATCH it to /users/ with the bulk header.
        return self.client.patch('/users/', data=json.dumps(payload),
                                 content_type='application/json', **self.headers)

    def test_simple_response(self):
        response = self.client.get('/users/')
        user_repr = {
            'id': 1,
            'age': 24,
            'name': 'Gennady',
            'surname': 'Chibisov'
        }
        # ``password`` is also expected when the write_only_fields flag is false.
        if not get_rest_framework_features()['write_only_fields']:
            user_repr['password'] = self.user.password
        self.assertEqual(response.data, [user_repr])

    def test_invalid_for_db_data(self):
        payload = {'age': 'Not integer value'}
        try:
            response = self._bulk_patch(payload)
        except ValueError:
            self.fail('Errors with invalid for DB data should be caught')
        else:
            self.assertEqual(response.status_code, 400)
            self.assertEqual(response.data, {
                'detail': "invalid literal for int() with base 10: 'Not integer value'"
            })

    def test_should_use_source_if_it_set_in_serializer(self):
        payload = {'surname': 'Ivanov'}
        response = self._bulk_patch(payload)
        self.assertEqual(response.status_code, 204)
        # The serializer's ``surname`` field is expected to land in ``last_name``.
        self.assertEqual(self.get_fresh_user().last_name, payload['surname'])

    @unittest.skipIf(
        not get_rest_framework_features()['write_only_fields'],
        "Current DRF version doesn't support write_only_fields"
    )
    def test_should_update_write_only_fields(self):
        payload = {'password': '******'}
        response = self._bulk_patch(payload)
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().password, payload['password'])

    def test_should_not_update_read_only_fields(self):
        payload = {'name': 'Ivan'}
        response = self._bulk_patch(payload)
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().name, self.user.name)

    def test_should_not_update_hidden_fields(self):
        payload = {'email': '*****@*****.**'}
        response = self._bulk_patch(payload)
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().email, self.user.email)
Example #18
0
# -*- coding: utf-8 -*-
import datetime
import unittest

from django.test import TestCase

from rest_framework_extensions.utils import get_rest_framework_features

from .urls import urlpatterns
from .models import CommentForPaginateByMaxMixin


@unittest.skipIf(
    not get_rest_framework_features()['max_paginate_by'],
    "Current DRF version doesn't support max_paginate_by parameter"
)
class PaginateByMaxMixinTest(TestCase):
    """Page-size behaviour of the ``/comments/`` endpoint.

    The whole class is skipped when the ``max_paginate_by`` feature flag is
    false.
    """
    urls = urlpatterns

    def setUp(self):
        # 30 comments: more than the default page of 10 results.
        for _ in range(30):
            CommentForPaginateByMaxMixin.objects.create(
                email='*****@*****.**',
                content='Hello world',
                created=datetime.datetime.now()
            )

    def test_default_page_size(self):
        response = self.client.get('/comments/')
        results = response.data['results']
        self.assertEqual(len(results), 10)
Example #19
0
# -*- coding: utf-8 -*-
import datetime
import unittest

from django.test import TestCase

from rest_framework_extensions.utils import get_rest_framework_features

from .urls import urlpatterns
from .models import CommentForPaginateByMaxMixin


@unittest.skipIf(
    not get_rest_framework_features()['max_paginate_by'],
    "Current DRF version doesn't support max_paginate_by parameter")
class PaginateByMaxMixinTest(TestCase):
    """Page-size behaviour of the ``/comments/`` endpoint.

    Skipped entirely when the ``max_paginate_by`` feature flag is false.
    """
    urls = urlpatterns

    def setUp(self):
        # Create 30 comments so that more than one page of results exists.
        for i in range(30):
            CommentForPaginateByMaxMixin.objects.create(
                email='*****@*****.**',
                content='Hello world',
                created=datetime.datetime.now())

    def test_default_page_size(self):
        # Without any page-size parameter, 10 results are expected.
        resp = self.client.get('/comments/')
        self.assertEqual(len(resp.data['results']), 10)

    def test_custom_page_size__less_then_maximum(self):
        # Requests 15 items via ``limit``.
        # NOTE(review): this test's assertions are not visible in this
        # excerpt; as shown it only issues the request.
        resp = self.client.get('/comments/?limit=15')
Example #20
0
        assign_perm(perms['view'], readers, model)
        assign_perm(perms['change'], writers, model)
        assign_perm(perms['delete'], deleters, model)

        readers.user_set.add(users['fullaccess'], users['readonly'])
        writers.user_set.add(users['fullaccess'], users['writeonly'])
        deleters.user_set.add(users['fullaccess'], users['deleteonly'])

        self.credentials = {}
        for user in users.values():
            self.credentials[user.username] = basic_auth_header(user.username, 'password')


@unittest.skipIf(
    not get_rest_framework_features()['django_object_permissions_class'],
    "Current DRF version doesn't support DjangoObjectPermissions"
)
class ExtendedDjangoObjectPermissionsTest_should_inherit_standard(ExtendedDjangoObjectPermissionTestMixin,
                                                                  APITestCase):
    """DELETE permissions on ``/comments/1/`` for users from the mixin's fixtures.

    Skipped when the ``django_object_permissions_class`` feature flag is false.
    """
    urls = urlpatterns

    def _delete_comment_as(self, username):
        # DELETE /comments/1/ authenticated with ``username``'s prepared credentials.
        return self.client.delete('/comments/1/',
                                  HTTP_AUTHORIZATION=self.credentials[username])

    # Delete
    def test_can_delete_permissions(self):
        response = self._delete_comment_as('deleteonly')
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

    def test_cannot_delete_permissions(self):
        response = self._delete_comment_as('readonly')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Example #21
0
    concrete_field_names = []
    for field in get_model_opts_concrete_fields(opts):
        if not field.primary_key:
            concrete_field_names.append(field.name)
            if field.name != field.attname:
                concrete_field_names.append(field.attname)
    update_fields = []
    for field_name in partial_fields:
        if field_name in fields:
            model_field_name = getattr(fields[field_name], 'source') or field_name
            if model_field_name in concrete_field_names:
                update_fields.append(model_field_name)
    return update_fields


if get_rest_framework_features()['single_step_object_creation_in_serializers']:
    class PartialUpdateSerializerMixin(object):
        def save(self, **kwargs):
            self._update_fields = kwargs.get('update_fields', None)
            return super(PartialUpdateSerializerMixin, self).save(**kwargs)

        def update(self, instance, validated_attrs):
            for attr, value in validated_attrs.items():
                setattr(instance, attr, value)
            if self.partial and isinstance(instance, self.Meta.model):
                instance.save(
                    update_fields=getattr(self, '_update_fields') or get_fields_for_partial_update(
                        opts=self.Meta,
                        init_data=self.get_initial(),
                        fields=self.fields.fields
                    )
Example #22
0
class ExtendedDefaultRouterTest(TestCase):
    """Routes generated by ``ExtendedDefaultRouter`` for ``@action``/``@link`` methods.

    Each test declares a throwaway ``BasicViewSet`` and inspects the routes
    returned by ``router.get_routes``: their order, HTTP-method mapping,
    URL regex and name.
    """

    def setUp(self):
        self.router = ExtendedDefaultRouter()

    def get_routes_names(self, routes):
        """Return the ``name`` of every route, preserving order."""
        return [route.name for route in routes]

    def get_dynamic_route_by_def_name(self, def_name, routes):
        """Return the first route whose mapping targets ``def_name``, else ``None``."""
        matching = [route for route in routes
                    if def_name in route.mapping.values()]
        return matching[0] if matching else None

    def test_dynamic_routes_should_be_first_in_order(self):
        class BasicViewSet(viewsets.ViewSet):
            def list(self, request, *args, **kwargs):
                return Response({'method': 'list'})

            @action()
            def action1(self, request, *args, **kwargs):
                return Response({'method': 'action1'})

            @link()
            def link1(self, request, *args, **kwargs):
                return Response({'method': 'link1'})

        generated = self.router.get_routes(BasicViewSet)
        self.assertEqual(
            self.get_routes_names(generated),
            ['{basename}-action1', '{basename}-link1', '{basename}-list',
             '{basename}-detail'],
            '@action and @link methods should come first in routes order')

    def test_action_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        route = self.get_dynamic_route_by_def_name('action1', generated)

        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with endpoint route should map methods to endpoint if it is specified')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/{lookup}/action-one/$'),
            '@action with endpoint route should use url with detail lookup')

    def test_link_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @link(endpoint='link-one')
            def link1(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        route = self.get_dynamic_route_by_def_name('link1', generated)

        self.assertEqual(
            route.mapping, {'get': 'link1'},
            '@link with endpoint route should map methods to endpoint if it is specified')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/{lookup}/link-one/$'),
            '@link with endpoint route should use url with detail lookup')

    def test_action__for_list(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True)
            def action1(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        route = self.get_dynamic_route_by_def_name('action1', generated)

        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with is_for_list=True route should map methods to def name')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/action1/$'),
            '@action with is_for_list=True route should use url in list scope')

    def test_action__for_list__and__with_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        route = self.get_dynamic_route_by_def_name('action1', generated)

        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with is_for_list=True and endpoint route should map methods to "endpoint"')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/action-one/$'),
            '@action with is_for_list=True and endpoint route should use url in list scope with "endpoint" value')

    def test_actions__for_list_and_detail_with_exact_names(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

            @action(endpoint='action-one')
            def action1_detail(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        list_route = self.get_dynamic_route_by_def_name('action1', generated)
        detail_route = self.get_dynamic_route_by_def_name('action1_detail', generated)

        # Same endpoint name: the list route lives in list scope...
        self.assertEqual(list_route.mapping, {'post': 'action1'})
        self.assertEqual(list_route.url,
                         add_trailing_slash_if_needed(u'^{prefix}/action-one/$'))

        # ...while the detail route includes the lookup segment.
        self.assertEqual(detail_route.mapping, {'post': 'action1_detail'})
        self.assertEqual(detail_route.url,
                         add_trailing_slash_if_needed(u'^{prefix}/{lookup}/action-one/$'))

    def test_action_names(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True)
            def action1(self, request, *args, **kwargs):
                pass

            @action()
            def action2(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        expected_names = (
            ('action1', u'{basename}-action1-list'),
            ('action2', u'{basename}-action2'),
        )
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, generated)
            self.assertEqual(route.name, expected_name)

    def test_action_names__with_endpoints(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action_one')
            def action1(self, request, *args, **kwargs):
                pass

            @action(endpoint='action-two')
            def action2(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        expected_names = (
            ('action1', u'{basename}-action-one-list'),
            ('action2', u'{basename}-action-two'),
        )
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, generated)
            self.assertEqual(route.name, expected_name)

    @unittest.skipIf(
        not get_rest_framework_features()['has_action_and_link_decorators'],
        "Current DRF version has removed 'action' and 'link' decorators")
    def test_with_default_controllers(self):
        # Method order matters: ``def link`` / ``def action`` shadow the
        # decorator names inside the class body, so the later methods use
        # ``decorators.link`` / ``decorators.action`` explicitly.
        class BasicViewSet(viewsets.ViewSet):
            @link()
            def link(self, request, *args, **kwargs):
                pass

            @decorators.link()
            def link_default(self, request, *args, **kwargs):
                pass

            @action()
            def action(self, request, *args, **kwargs):
                pass

            @decorators.action()
            def action_default(self, request, *args, **kwargs):
                pass

        generated = self.router.get_routes(BasicViewSet)
        expected_names = (
            ('link', u'{basename}-link'),
            ('link_default', u'{basename}-link-default'),
            ('action', u'{basename}-action'),
            ('action_default', u'{basename}-action-default'),
        )
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, generated)
            self.assertEqual(route.name, expected_name)
Example #23
0
 def get_page_size(self, request):
     """Return ``max_page_size`` for requests carrying ``<page_size_query_param>=max``.

     Applies only when the ``max_paginate_by`` feature flag is set and both
     ``page_size_query_param`` and ``max_page_size`` are truthy. Otherwise
     the parent class's ``get_page_size`` is used.
     """
     if get_rest_framework_features()['max_paginate_by']:
         param_name = self.page_size_query_param
         if param_name and self.max_page_size:
             if request.query_params.get(param_name) == 'max':
                 return self.max_page_size
     return super(PaginateByMaxMixin, self).get_page_size(request)
Example #24
0
    for field in get_model_opts_concrete_fields(opts):
        if not field.primary_key:
            concrete_field_names.append(field.name)
            if field.name != field.attname:
                concrete_field_names.append(field.attname)
    update_fields = []
    for field_name in partial_fields:
        if field_name in fields:
            model_field_name = getattr(fields[field_name],
                                       'source') or field_name
            if model_field_name in concrete_field_names:
                update_fields.append(model_field_name)
    return update_fields


if get_rest_framework_features()['single_step_object_creation_in_serializers']:
    # This definition of the mixin exists only when the
    # ``single_step_object_creation_in_serializers`` feature flag is true.

    class PartialUpdateSerializerMixin(object):
        """Serializer mixin that restricts which fields a partial update saves."""
        def save(self, **kwargs):
            # Remember an explicit ``update_fields`` kwarg (or None) so that
            # ``update`` can prefer it over the computed field list.
            self._update_fields = kwargs.get('update_fields', None)
            return super(PartialUpdateSerializerMixin, self).save(**kwargs)

        def update(self, instance, validated_attrs):
            """Copy ``validated_attrs`` onto ``instance`` and, for partial updates, save it.

            When the serializer is partial and ``instance`` is a ``Meta.model``
            instance, ``instance.save`` receives ``update_fields``. That is the
            value remembered by ``save()`` if truthy, otherwise the result of
            ``get_fields_for_partial_update`` for the initial data.
            """
            for attr, value in validated_attrs.items():
                setattr(instance, attr, value)
            if self.partial and isinstance(instance, self.Meta.model):
                # NOTE(review): getattr has no default here, so calling update()
                # without a preceding save() raises AttributeError — confirm intended.
                # NOTE(review): no ``return instance`` is visible in this excerpt;
                # DRF serializers expect update() to return the instance.
                instance.save(
                    update_fields=getattr(self, '_update_fields') or
                    get_fields_for_partial_update(opts=self.Meta,
                                                  init_data=self.get_initial(),
                                                  fields=self.fields.fields))
Example #25
0
        router = ExtendedSimpleRouter()
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        lookup_allowed_symbols = get_lookup_allowed_symbols()

        for exp in ['^router-viewset/$',
                    '^router-viewset/{0}/$'.format(lookup_allowed_symbols),
                    '^router-viewset/list_controller/$',
                    '^router-viewset/{0}/detail_controller/$'.format(lookup_allowed_symbols)]:
            msg = 'Should find url pattern with regexp %s' % exp
            self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg)


@unittest.skipUnless(
    get_rest_framework_features()['router_trailing_slash'],
    "Current DRF version doesn't support Router trailing_slash"
)
class TestTrailingSlashRemoved(TestCase):
    def test_urls_can_have_trailing_slash_removed(self):
        router = ExtendedSimpleRouter(trailing_slash=False)
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        lookup_allowed_symbols = get_lookup_allowed_symbols(
            force_dot=get_rest_framework_features()['allow_dot_in_lookup_regex_without_trailing_slash']
        )

        for exp in ['^router-viewset$',
                    '^router-viewset/{0}$'.format(lookup_allowed_symbols),
                    '^router-viewset/list_controller$',
Example #26
0
def add_trailing_slash_if_needed(regexp_string):
    """Return ``regexp_string`` adapted for routers with a configurable trailing slash.

    When the ``router_trailing_slash`` feature flag is set, the final two
    characters (expected to be ``/$``) are replaced by
    ``'{trailing_slash}$'``. Otherwise the input is returned as-is.
    """
    # TODO: test me
    supports_trailing_slash = get_rest_framework_features()['router_trailing_slash']
    if supports_trailing_slash:
        without_tail = regexp_string[:-2]
        return without_tail + '{trailing_slash}$'
    return regexp_string
Example #27
0
# -*- coding: utf-8 -*-
from rest_framework_extensions.utils import get_rest_framework_features


if get_rest_framework_features()['django_object_permissions_class']:
    from .extended_django_object_permissions import ExtendedDjangoObjectPermissions
Example #28
0
class PartialUpdateSerializerMixinTest(TestCase):
    def setUp(self):
        """Create the fixtures shared by every test in this case.

        Two in-memory ``File`` objects are built (8 bytes each), followed by
        a user named 'gena' and a comment ('hello' / 'world') owned by that
        user, with the first file as its attachment.
        """
        sources = (
            (u'file one'.encode('utf-8'), 'file1.txt'),
            (u'file two'.encode('utf-8'), 'file2.txt'),
        )
        self.files = [File(BytesIO(content), name=filename)
                      for content, filename in sources]
        for uploaded in self.files:
            # Size is pinned explicitly: 8 == len(b'file one') == len(b'file two').
            uploaded._set_size(8)
        self.user = UserModel.objects.create(name='gena')
        self.comment = CommentModel.objects.create(
            user=self.user,
            title='hello',
            text='world',
            attachment=self.files[0],
        )

    def get_comment(self):
        """Fetch ``self.comment`` again from the database.

        Every call runs a new ``CommentModel.objects.get`` lookup, so callers
        receive distinct instances rather than sharing one in-memory object.
        """
        comment_pk = self.comment.pk
        return CommentModel.objects.get(pk=comment_pk)

    @unittest.skipIf(
        sys.version_info[0] == 3,
        "Skipped for python3 because of https://github.com/tomchristie/django-rest-framework/issues/1642"
    )
    def test_should_use_default_saving_without_partial(self):
        """A non-partial serializer creates a comment carrying every submitted field."""
        payload = {
            'user': self.user.id,
            'title': 'hola',
            'text': 'amigos',
        }
        serializer = CommentSerializer(data=payload)

        # Under python3 the upstream bug surfaces at validation time.
        self.assertTrue(serializer.is_valid())

        created = serializer.save()
        expectations = (
            ('user', self.user),
            ('title', 'hola'),
            ('text', 'amigos'),
        )
        for attr_name, expected in expectations:
            self.assertEqual(getattr(created, attr_name), expected)

    def test_should_save_partial(self):
        """A partial update changes only the submitted field and keeps the rest."""
        serializer = CommentSerializer(
            instance=self.comment, data={'title': 'hola'}, partial=True)
        self.assertTrue(serializer.is_valid())

        updated = serializer.save()
        self.assertEqual(updated.user, self.user)
        self.assertEqual(updated.title, 'hola')
        # 'text' was not in the payload, so it keeps the value from setUp.
        self.assertEqual(updated.text, 'world')

    def test_should_save_only_fields_from_data_for_partial_update(self):
        """Three partial saves on separate instances must not clobber each other.

        Each serializer touches a different field (title, text, attachment).
        After all three are saved, the stored comment must reflect all three
        changes.
        """
        # Every serializer gets its own freshly fetched Comment, because
        # save() mutates the instance it was handed.
        title_serializer = CommentSerializer(
            instance=self.get_comment(), data={'title': 'goodbye'}, partial=True)
        text_serializer = CommentSerializer(
            instance=self.get_comment(), data={'text': 'moon'}, partial=True)

        attachment_kwargs = {'instance': self.get_comment(), 'partial': True}
        features = get_rest_framework_features()
        if features['uses_single_request_data_in_serializers']:
            # Uploaded files are passed inside ``data``.
            attachment_kwargs['data'] = {'attachment': self.files[1]}
        else:
            # Uploaded files are passed separately through ``files``.
            attachment_kwargs['data'] = {}
            attachment_kwargs['files'] = {'attachment': self.files[1]}
        attachment_serializer = CommentSerializer(**attachment_kwargs)

        serializers = (title_serializer, text_serializer, attachment_serializer)
        for candidate in serializers:
            self.assertTrue(candidate.is_valid())
        for candidate in serializers:
            candidate.save()

        reloaded = self.get_comment()
        self.assertEqual(reloaded.attachment.read(),
                         u'file two'.encode('utf-8'))
        self.assertEqual(reloaded.text, 'moon')
        self.assertEqual(reloaded.title, 'goodbye')

    def test_should_use_related_field_name_for_update_field_list(self):
        """A FK sent under its field name ('user') is persisted by the partial save."""
        new_owner = UserModel.objects.create(name='vova')
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={'title': 'goodbye', 'user': new_owner.pk},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        serializer.save()

        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')
        self.assertEqual(reloaded.user, new_owner)

    def test_should_use_field_source_value_for_searching_model_concrete_fields(
            self):
        """A value posted as 'title_from_source' ends up in the model's 'title'."""
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={'title_from_source': 'goodbye'},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        serializer.save()

        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')

    def test_should_not_use_m2m_field_name_for_update_field_list(self):
        """Posting an m2m field must still save it, without it reaching ``update_fields``."""
        liker = UserModel.objects.create(name='vova')
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={
                'title': 'goodbye',
                'users_liked': [self.user.pk, liker.pk],
            },
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            # A ValueError from save() is taken to mean the m2m name leaked
            # into update_fields.
            self.fail(
                'If m2m field used in partial update then it should not be used in update_fields list'
            )

        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')
        liked_pks = set(reloaded.users_liked.all().values_list('pk', flat=True))
        self.assertEqual(liked_pks, {self.user.pk, liker.pk})

    def test_should_not_use_related_set_field_name_for_update_field_list(self):
        """A reverse-relation field ('comments') must not end up in ``update_fields``.

        ``save()`` is called exactly once, inside the ``try``. A ValueError is
        therefore reported through ``self.fail`` with a descriptive message
        rather than escaping from an unguarded earlier ``save()`` call.
        """
        another_user = UserModel.objects.create(name='vova')
        another_comment = CommentModel.objects.create(
            user=another_user,
            title='goodbye',
            text='moon',
        )
        data = {'name': 'vova', 'comments': [another_comment.pk]}
        serializer = UserSerializer(instance=another_user,
                                    data=data,
                                    partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If related set field used in partial update then it should not be used in update_fields list'
            )
        fresh_comment = CommentModel.objects.get(pk=another_comment.pk)
        fresh_user = UserModel.objects.get(pk=another_user.pk)
        self.assertEqual(fresh_comment.user, another_user)
        self.assertEqual(fresh_user.name, 'vova')

    def test_should_not_try_to_update_fields_that_are_not_in_model(self):
        """Payload keys unknown to the model must not be forwarded to ``update_fields``."""
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={'title': 'goodbye', 'not_existing_field': 'moon'},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail('Should not pass values to update_fields from data, if they are not in model')

        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')
        # 'text' was never submitted, so it keeps its setUp value.
        self.assertEqual(reloaded.text, 'world')

    def test_should_not_try_to_update_fields_that_are_not_allowed_from_serializer(
            self):
        """A posted model field that the serializer does not accept is ignored.

        'hidden_text' is sent in the payload but must remain unset on the
        stored comment. The test name implies CommentSerializer does not
        expose it. The submitted 'title' is still saved.
        """
        data = {'title': 'goodbye', 'hidden_text': 'do not change me'}
        serializer = CommentSerializer(instance=self.get_comment(),
                                       data=data,
                                       partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')
        # assertIsNone is the idiomatic check and gives a clearer failure message.
        self.assertIsNone(fresh_instance.hidden_text)

    def test_should_use_list_of_fields_to_update_from_arguments_if_it_passed(
            self):
        """An explicit ``update_fields`` passed to save() takes precedence over the payload."""
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={'title': 'goodbye', 'text': 'moon'},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        serializer.save(update_fields=['title'])

        reloaded = self.get_comment()
        self.assertEqual(reloaded.title, 'goodbye')
        # 'text' was validated but left out of update_fields, so it is unchanged.
        self.assertEqual(reloaded.text, 'world')

    @unittest.skipUnless(
        get_rest_framework_features()['has_auto_writable_nested_serialization'],
        "This version of DRF doesn't have automatic writable nested serialization"
    )
    def test_should_not_use_update_fields_when_related_objects_are_saving(
            self):
        """Saving with a nested related payload must not raise from ``update_fields``.

        The comment's title is always expected to be updated. The nested
        user's new name is only checked when DRF reports the
        ``save_related_serializers`` feature.
        """
        data = {
            'title': 'goodbye',
            'user': {
                'id': self.user.pk,
                'name': 'oleg'
            }
        }
        serializer = CommentSerializerWithExpandedUsersLiked(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If serializer has expanded related serializer, then it should not use update_fields while '
                'saving related object')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')

        # Only DRF versions that save related serializers persist the nested change.
        if get_rest_framework_features()['save_related_serializers']:
            self.assertEqual(fresh_instance.user.name, 'oleg')

    def test_should_not_use_field_attname_for_update_fields__if_attname_not_allowed_in_serializer_fields(
            self):
        """A posted 'user_id' is ignored by CommentSerializer; the FK is unchanged."""
        new_owner = UserModel.objects.create(name='vova')
        serializer = CommentSerializer(
            instance=self.get_comment(),
            data={'title': 'goodbye', 'user_id': new_owner.id},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        serializer.save()

        # The FK column still points at the original owner from setUp.
        reloaded = self.get_comment()
        self.assertEqual(reloaded.user_id, self.user.id)

    def test_should_use_field_attname_for_update_fields__if_attname_allowed_in_serializer_fields(
            self):
        """A posted 'user_id' updates the FK when the serializer variant accepts it."""
        new_owner = UserModel.objects.create(name='vova')
        serializer = CommentSerializerWithAllowedUserId(
            instance=self.get_comment(),
            data={'title': 'goodbye', 'user_id': new_owner.id},
            partial=True,
        )
        self.assertTrue(serializer.is_valid())
        serializer.save()

        reloaded = self.get_comment()
        self.assertEqual(reloaded.user_id, new_owner.id)

    def test_should_not_use_pk_field_for_update_fields(self):
        """Posting a different 'id' in a partial update must not change the primary key.

        The save must not raise ValueError (the pk is expected to be kept out
        of ``update_fields``). The stored comment keeps its original pk, and
        the 'title' from the same payload is still persisted.
        """
        old_pk = self.get_comment().pk
        # Request a different pk so a leak into the save would be detectable.
        data = {'id': old_pk + 1, 'title': 'goodbye'}
        serializer = CommentSerializer(instance=self.get_comment(),
                                       data=data,
                                       partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'Primary key field should be excluded from update_fields list')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.pk, old_pk)
        self.assertEqual(fresh_instance.title, u'goodbye')