def get_page_size(self, request):
    """Return the page size for ``request``, honouring a ``'max'`` shortcut.

    When the ``max_paginate_by`` feature is available, both
    ``page_size_query_param`` and ``max_page_size`` are configured, and the
    client sent the literal value ``'max'`` for that query parameter, the
    configured ``max_page_size`` is returned. In every other case the parent
    class decides.
    """
    wants_max = (
        get_rest_framework_features()['max_paginate_by']
        and self.page_size_query_param
        and self.max_page_size
        and request.query_params.get(self.page_size_query_param) == 'max'
    )
    if not wants_max:
        return super(PaginateByMaxMixin, self).get_page_size(request)
    return self.max_page_size
def get_lookup_allowed_symbols(kwarg_name='pk', force_dot=False):
    """Build the named regex group used for the lookup segment of a URL.

    The group is named ``kwarg_name`` and matches one or more characters that
    are not ``/``. ``.`` is excluded as well when the DRF feature flag
    ``use_dot_in_lookup_regex_by_default`` is set or ``force_dot`` is true.
    """
    # TODO: add a dedicated unit test for this helper.
    exclude_dot = (
        get_rest_framework_features()['use_dot_in_lookup_regex_by_default']
        or force_dot
    )
    forbidden_chars = '/.' if exclude_dot else '/'
    return '(?P<{0}>[^{1}]+)'.format(kwarg_name, forbidden_chars)
def test_router_trailing_slash(self):
    """``router_trailing_slash`` is reported from DRF 2.3.6 onwards."""
    # (reported DRF version, expected value of the feature flag)
    cases = [
        ((2, 3), False),
        ((2, 3, 5), False),
        ((2, 3, 6), True),
        ((2, 3, 7), True),
        ((2, 4), True),
    ]
    for version, expected in cases:
        fake_version = Mock(return_value=version)
        with patch('rest_framework_extensions.utils.get_rest_framework_version', fake_version):
            actual = get_rest_framework_features()['router_trailing_slash']
            self.assertEqual(actual, expected)
def test_should_save_only_fields_from_data_for_partial_update(self):
    """Three partial updates of the same row must not overwrite each other."""
    # Each serializer gets its own freshly loaded Comment, because a
    # serializer's save() mutates the instance it was given.
    title_serializer = CommentSerializer(
        instance=self.get_comment(), data={'title': 'goodbye'}, partial=True)
    text_serializer = CommentSerializer(
        instance=self.get_comment(), data={'text': 'moon'}, partial=True)

    attachment_kwargs = dict(instance=self.get_comment(), partial=True)
    if get_rest_framework_features()['uses_single_request_data_in_serializers']:
        attachment_kwargs['data'] = {'attachment': self.files[1]}
    else:
        # Without that feature the upload is passed via a separate ``files`` kwarg.
        attachment_kwargs['data'] = {}
        attachment_kwargs['files'] = {'attachment': self.files[1]}
    attachment_serializer = CommentSerializer(**attachment_kwargs)

    serializers = (title_serializer, text_serializer, attachment_serializer)
    for serializer in serializers:
        self.assertTrue(serializer.is_valid())
    # Save all three; none of them should clobber the others' fields.
    for serializer in serializers:
        serializer.save()

    fresh = self.get_comment()
    self.assertEqual(fresh.attachment.read(), u'file two'.encode('utf-8'))
    self.assertEqual(fresh.text, 'moon')
    self.assertEqual(fresh.title, 'goodbye')
def get_paginate_by(self, *args, **kwargs):
    """Return ``max_paginate_by`` when the client asks for the maximum page size.

    When the ``max_paginate_by`` feature is available, ``paginate_by_param``
    and ``max_paginate_by`` are both configured, and the request carries the
    literal value ``'max'`` for ``paginate_by_param``, the configured maximum
    is returned. Otherwise the parent class's ``get_paginate_by`` decides.
    """
    if not (get_rest_framework_features()['max_paginate_by'] and
            self.paginate_by_param and
            self.max_paginate_by):
        return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)

    # DRF 3.0 renamed ``QUERY_PARAMS`` to ``query_params``; the old name was
    # deprecated and later removed. Prefer the new attribute and fall back to
    # the old one for DRF 2.x requests.
    query_params = getattr(self.request, 'query_params', None)
    if query_params is None:
        query_params = self.request.QUERY_PARAMS

    if query_params.get(self.paginate_by_param) == 'max':
        return self.max_paginate_by
    return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)
def get_paginate_by(self, *args, **kwargs):
    """Return ``max_paginate_by`` when the client asks for the maximum page size.

    When the ``max_paginate_by`` feature is available, ``paginate_by_param``
    and ``max_paginate_by`` are both configured, and the request carries the
    literal value ``'max'`` for ``paginate_by_param``, the configured maximum
    is returned. Otherwise the parent class's ``get_paginate_by`` decides.
    """
    if not (get_rest_framework_features()['max_paginate_by'] and
            self.paginate_by_param and
            self.max_paginate_by):
        return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)

    # DRF 3.0 renamed ``QUERY_PARAMS`` to ``query_params``; the old name was
    # deprecated and later removed. Prefer the new attribute and fall back to
    # the old one for DRF 2.x requests.
    query_params = getattr(self.request, 'query_params', None)
    if query_params is None:
        query_params = self.request.QUERY_PARAMS

    if query_params.get(self.paginate_by_param) == 'max':
        return self.max_paginate_by
    return super(PaginateByMaxMixin, self).get_paginate_by(*args, **kwargs)
def get_lookup_allowed_symbols(kwarg_name='pk', force_dot=False):
    """Return a ``(?P<kwarg_name>...)`` regex group for a URL lookup segment.

    By default the group matches any run of characters other than ``/``.
    ``.`` is also excluded when the ``use_dot_in_lookup_regex_by_default``
    feature is set or ``force_dot`` is true.
    """
    # TODO: add a dedicated unit test for this helper.
    dot_by_default = get_rest_framework_features()['use_dot_in_lookup_regex_by_default']
    if not (dot_by_default or force_dot):
        return '(?P<{0}>[^/]+)'.format(kwarg_name)
    return '(?P<{0}>[^/.]+)'.format(kwarg_name)
def test_simple_response(self):
    """GET /users/ lists the fixture user with its serialized fields."""
    response = self.client.get('/users/')
    user_data = {'id': 1, 'age': 24, 'name': 'Gennady', 'surname': 'Chibisov'}
    if not get_rest_framework_features()['write_only_fields']:
        # Without write-only field support the password is serialized too.
        user_data['password'] = self.user.password
    self.assertEqual(response.data, [user_data])
def test_simple_response(self):
    """GET /users/ returns a one-element list describing the fixture user."""
    resp = self.client.get('/users/')
    expected_user = dict(id=1, age=24, name='Gennady', surname='Chibisov')
    password_is_serialized = not get_rest_framework_features()['write_only_fields']
    if password_is_serialized:
        expected_user['password'] = self.user.password
    self.assertEqual(resp.data, [expected_user])
def test_urls_can_have_trailing_slash_removed(self):
    """With ``trailing_slash=False`` no generated URL regex ends in ``/``."""
    router = ExtendedSimpleRouter(trailing_slash=False)
    router.register(r'router-viewset', RouterViewSet)
    urls = router.urls

    allow_dot = get_rest_framework_features()['allow_dot_in_lookup_regex_without_trailing_slash']
    lookup = get_lookup_allowed_symbols(force_dot=allow_dot)
    expected_patterns = [
        '^router-viewset$',
        '^router-viewset/{0}$'.format(lookup),
        '^router-viewset/list_controller$',
        '^router-viewset/{0}/detail_controller$'.format(lookup),
    ]
    for pattern in expected_patterns:
        self.assertIsNotNone(
            get_url_pattern_by_regex_pattern(urls, pattern),
            msg='Should find url pattern with regexp %s' % pattern)
def test_urls_can_have_trailing_slash_removed(self):
    """Routes registered with ``trailing_slash=False`` carry no trailing ``/``."""
    router = ExtendedSimpleRouter(trailing_slash=False)
    router.register(r'router-viewset', RouterViewSet)
    urls = router.urls

    dot_allowed = get_rest_framework_features()['allow_dot_in_lookup_regex_without_trailing_slash']
    lookup = '(?P<pk>[^/.]+)' if dot_allowed else '(?P<pk>[^/]+)'

    for pattern in ('^router-viewset$',
                    '^router-viewset/' + lookup + '$',
                    '^router-viewset/list_controller$',
                    '^router-viewset/' + lookup + '/detail_controller$'):
        failure_msg = 'Should find url pattern with regexp %s' % pattern
        self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, pattern), msg=failure_msg)
def test_max_paginate_by(self):
    """``max_paginate_by`` is reported from DRF 2.3.8 onwards."""
    # (reported DRF version, expected value of the feature flag)
    cases = [
        ((2, 3), False),
        ((2, 3, 5), False),
        ((2, 3, 6), False),
        ((2, 3, 7), False),
        ((2, 3, 8), True),
        ((2, 3, 9), True),
        ((2, 3, 10), True),
        ((2, 4), True),
    ]
    for version, expected in cases:
        fake_version = Mock(return_value=version)
        with patch('rest_framework_extensions.utils.get_rest_framework_version', fake_version):
            actual = get_rest_framework_features()['max_paginate_by']
            self.assertEqual(actual, expected)
def test_django_object_permissions_class(self):
    """``django_object_permissions_class`` is reported from DRF 2.3.8 onwards."""
    flag = 'django_object_permissions_class'
    unsupported = [(2, 3), (2, 3, 5), (2, 3, 6), (2, 3, 7)]
    supported = [(2, 3, 8), (2, 3, 9), (2, 3, 10), (2, 4)]
    cases = [(v, False) for v in unsupported] + [(v, True) for v in supported]

    for version, expected in cases:
        with patch('rest_framework_extensions.utils.get_rest_framework_version',
                   Mock(return_value=version)):
            self.assertEqual(get_rest_framework_features()[flag], expected)
def test_should_not_use_update_fields_when_related_objects_are_saving(self):
    """Saving with a nested ``user`` payload must not raise ValueError."""
    payload = {'title': 'goodbye', 'user': {'id': self.user.pk, 'name': 'oleg'}}
    serializer = CommentSerializerWithExpandedUsersLiked(
        instance=self.get_comment(), data=payload, partial=True)
    self.assertTrue(serializer.is_valid())

    try:
        serializer.save()
    except ValueError:
        self.fail('If serializer has expanded related serializer, then it should not use update_fields while '
                  'saving related object')

    comment = self.get_comment()
    self.assertEqual(comment.title, 'goodbye')
    if get_rest_framework_features()['save_related_serializers']:
        # The nested user is only persisted when DRF saves related serializers.
        self.assertEqual(comment.user.name, 'oleg')
def test_should_not_use_update_fields_when_related_objects_are_saving(
        self):
    """A partial save carrying a nested related object must succeed."""
    user_update = {'id': self.user.pk, 'name': 'oleg'}
    serializer = CommentSerializerWithExpandedUsersLiked(
        instance=self.get_comment(),
        data={'title': 'goodbye', 'user': user_update},
        partial=True)
    self.assertTrue(serializer.is_valid())

    failure = ('If serializer has expanded related serializer, then it should not use update_fields while '
               'saving related object')
    try:
        serializer.save()
    except ValueError:
        self.fail(failure)

    saved = self.get_comment()
    self.assertEqual(saved.title, 'goodbye')
    if get_rest_framework_features()['save_related_serializers']:
        self.assertEqual(saved.user.name, 'oleg')
def add_trailing_slash_if_needed(regexp_string):
    """Make a route regex's trailing slash configurable when DRF supports it.

    When the ``router_trailing_slash`` feature is available, the last two
    characters of ``regexp_string`` are replaced by ``'{trailing_slash}$'``.
    Otherwise the string is returned unchanged. Callers are expected to pass
    patterns ending in ``'/$'``; this is not checked.
    """
    # TODO: add a dedicated unit test for this helper.
    if not get_rest_framework_features()['router_trailing_slash']:
        return regexp_string
    without_slash_and_anchor = regexp_string[:-2]
    return without_slash_and_anchor + '{trailing_slash}$'
class ListUpdateModelMixinTestBehaviour__serializer_fields(APITestCase):
    """Bulk ``PATCH /users/`` must respect the serializer's field declarations.

    Covers ``source`` remapping, write-only, read-only and hidden fields, and
    input that is invalid for the database.
    """
    urls = urlpatterns

    def setUp(self):
        self.user = User.objects.create(
            id=1,
            name='Gennady',
            age=24,
            last_name='Chibisov',
            email='*****@*****.**',
            password='******',
        )
        bulk_header = utils.prepare_header_name(
            extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME)
        # Extra request headers sent with every bulk PATCH below.
        self.headers = {bulk_header: 'true'}

    def get_fresh_user(self):
        """Reload the fixture user from the database."""
        return User.objects.get(pk=self.user.pk)

    def _bulk_patch(self, data):
        """Send ``data`` as a JSON bulk PATCH to ``/users/``."""
        return self.client.patch(
            '/users/',
            data=json.dumps(data),
            content_type='application/json',
            **self.headers)

    def test_simple_response(self):
        response = self.client.get('/users/')
        user_data = {'id': 1, 'age': 24, 'name': 'Gennady', 'surname': 'Chibisov'}
        if not get_rest_framework_features()['write_only_fields']:
            user_data['password'] = self.user.password
        self.assertEqual(response.data, [user_data])

    def test_invalid_for_db_data(self):
        try:
            response = self._bulk_patch({'age': 'Not integer value'})
        except ValueError:
            self.fail('Errors with invalid for DB data should be caught')
        else:
            self.assertEqual(response.status_code, 400)
            self.assertEqual(response.data, {
                'detail': "invalid literal for int() with base 10: 'Not integer value'"
            })

    def test_should_use_source_if_it_set_in_serializer(self):
        new_surname = 'Ivanov'
        response = self._bulk_patch({'surname': new_surname})
        self.assertEqual(response.status_code, 204)
        # ``surname`` in the payload is stored on the model's ``last_name``.
        self.assertEqual(self.get_fresh_user().last_name, new_surname)

    @unittest.skipIf(
        not get_rest_framework_features()['write_only_fields'],
        "Current DRF version doesn't support write_only_fields"
    )
    def test_should_update_write_only_fields(self):
        new_password = '******'
        response = self._bulk_patch({'password': new_password})
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().password, new_password)

    def test_should_not_update_read_only_fields(self):
        response = self._bulk_patch({'name': 'Ivan'})
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().name, self.user.name)

    def test_should_not_update_hidden_fields(self):
        response = self._bulk_patch({'email': '*****@*****.**'})
        self.assertEqual(response.status_code, 204)
        self.assertEqual(self.get_fresh_user().email, self.user.email)
# -*- coding: utf-8 -*-
import datetime
import unittest

from django.test import TestCase

from rest_framework_extensions.utils import get_rest_framework_features

from .urls import urlpatterns
from .models import CommentForPaginateByMaxMixin


@unittest.skipIf(
    not get_rest_framework_features()['max_paginate_by'],
    "Current DRF version doesn't support max_paginate_by parameter"
)
class PaginateByMaxMixinTest(TestCase):
    """Pagination tests for PaginateByMaxMixin, run only when DRF supports it."""
    urls = urlpatterns

    def setUp(self):
        """Create 30 comments to paginate over."""
        for _ in range(30):
            CommentForPaginateByMaxMixin.objects.create(
                email='*****@*****.**',
                content='Hello world',
                created=datetime.datetime.now(),
            )

    def test_default_page_size(self):
        results = self.client.get('/comments/').data['results']
        self.assertEqual(len(results), 10)
# -*- coding: utf-8 -*- import datetime import unittest from django.test import TestCase from rest_framework_extensions.utils import get_rest_framework_features from .urls import urlpatterns from .models import CommentForPaginateByMaxMixin @unittest.skipIf( not get_rest_framework_features()['max_paginate_by'], "Current DRF version doesn't support max_paginate_by parameter") class PaginateByMaxMixinTest(TestCase): urls = urlpatterns def setUp(self): for i in range(30): CommentForPaginateByMaxMixin.objects.create( email='*****@*****.**', content='Hello world', created=datetime.datetime.now()) def test_default_page_size(self): resp = self.client.get('/comments/') self.assertEqual(len(resp.data['results']), 10) def test_custom_page_size__less_then_maximum(self): resp = self.client.get('/comments/?limit=15')
assign_perm(perms['view'], readers, model) assign_perm(perms['change'], writers, model) assign_perm(perms['delete'], deleters, model) readers.user_set.add(users['fullaccess'], users['readonly']) writers.user_set.add(users['fullaccess'], users['writeonly']) deleters.user_set.add(users['fullaccess'], users['deleteonly']) self.credentials = {} for user in users.values(): self.credentials[user.username] = basic_auth_header(user.username, 'password') @unittest.skipIf( not get_rest_framework_features()['django_object_permissions_class'], "Current DRF version doesn't support DjangoObjectPermissions" ) class ExtendedDjangoObjectPermissionsTest_should_inherit_standard(ExtendedDjangoObjectPermissionTestMixin, APITestCase): urls = urlpatterns # Delete def test_can_delete_permissions(self): response = self.client.delete('/comments/1/', **{'HTTP_AUTHORIZATION': self.credentials['deleteonly']}) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_cannot_delete_permissions(self): response = self.client.delete('/comments/1/', **{'HTTP_AUTHORIZATION': self.credentials['readonly']}) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
concrete_field_names = [] for field in get_model_opts_concrete_fields(opts): if not field.primary_key: concrete_field_names.append(field.name) if field.name != field.attname: concrete_field_names.append(field.attname) update_fields = [] for field_name in partial_fields: if field_name in fields: model_field_name = getattr(fields[field_name], 'source') or field_name if model_field_name in concrete_field_names: update_fields.append(model_field_name) return update_fields if get_rest_framework_features()['single_step_object_creation_in_serializers']: class PartialUpdateSerializerMixin(object): def save(self, **kwargs): self._update_fields = kwargs.get('update_fields', None) return super(PartialUpdateSerializerMixin, self).save(**kwargs) def update(self, instance, validated_attrs): for attr, value in validated_attrs.items(): setattr(instance, attr, value) if self.partial and isinstance(instance, self.Meta.model): instance.save( update_fields=getattr(self, '_update_fields') or get_fields_for_partial_update( opts=self.Meta, init_data=self.get_initial(), fields=self.fields.fields )
class ExtendedDefaultRouterTest(TestCase):
    """Routes that ExtendedDefaultRouter generates for ``@action``/``@link`` methods."""

    def setUp(self):
        self.router = ExtendedDefaultRouter()

    def get_routes_names(self, routes):
        """Return the ``name`` of each route, in order."""
        return [route.name for route in routes]

    def get_dynamic_route_by_def_name(self, def_name, routes):
        """Return the first route whose method mapping targets ``def_name``, or None."""
        matches = (route for route in routes if def_name in route.mapping.values())
        return next(matches, None)

    def test_dynamic_routes_should_be_first_in_order(self):
        class BasicViewSet(viewsets.ViewSet):
            def list(self, request, *args, **kwargs):
                return Response({'method': 'list'})

            @action()
            def action1(self, request, *args, **kwargs):
                return Response({'method': 'action1'})

            @link()
            def link1(self, request, *args, **kwargs):
                return Response({'method': 'link1'})

        route_names = self.get_routes_names(self.router.get_routes(BasicViewSet))
        self.assertEqual(
            route_names,
            ['{basename}-action1', '{basename}-link1', '{basename}-list', '{basename}-detail'],
            '@action and @link methods should come first in routes order')

    def test_action_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

        route = self.get_dynamic_route_by_def_name(
            'action1', self.router.get_routes(BasicViewSet))
        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with endpoint route should map methods to endpoint if it is specified')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/{lookup}/action-one/$'),
            '@action with endpoint route should use url with detail lookup')

    def test_link_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @link(endpoint='link-one')
            def link1(self, request, *args, **kwargs):
                pass

        route = self.get_dynamic_route_by_def_name(
            'link1', self.router.get_routes(BasicViewSet))
        self.assertEqual(
            route.mapping, {'get': 'link1'},
            '@link with endpoint route should map methods to endpoint if it is specified')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/{lookup}/link-one/$'),
            '@link with endpoint route should use url with detail lookup')

    def test_action__for_list(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True)
            def action1(self, request, *args, **kwargs):
                pass

        route = self.get_dynamic_route_by_def_name(
            'action1', self.router.get_routes(BasicViewSet))
        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with is_for_list=True route should map methods to def name')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/action1/$'),
            '@action with is_for_list=True route should use url in list scope')

    def test_action__for_list__and__with_endpoint(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

        route = self.get_dynamic_route_by_def_name(
            'action1', self.router.get_routes(BasicViewSet))
        self.assertEqual(
            route.mapping, {'post': 'action1'},
            '@action with is_for_list=True and endpoint route should map methods to "endpoint"')
        self.assertEqual(
            route.url,
            add_trailing_slash_if_needed(u'^{prefix}/action-one/$'),
            '@action with is_for_list=True and endpoint route should use url in list scope with "endpoint" value')

    def test_actions__for_list_and_detail_with_exact_names(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action-one')
            def action1(self, request, *args, **kwargs):
                pass

            @action(endpoint='action-one')
            def action1_detail(self, request, *args, **kwargs):
                pass

        routes = self.router.get_routes(BasicViewSet)
        list_route = self.get_dynamic_route_by_def_name('action1', routes)
        detail_route = self.get_dynamic_route_by_def_name('action1_detail', routes)

        # Same endpoint name, but one route is list-scoped and one is detail-scoped.
        self.assertEqual(list_route.mapping, {'post': 'action1'})
        self.assertEqual(
            list_route.url,
            add_trailing_slash_if_needed(u'^{prefix}/action-one/$'))
        self.assertEqual(detail_route.mapping, {'post': 'action1_detail'})
        self.assertEqual(
            detail_route.url,
            add_trailing_slash_if_needed(u'^{prefix}/{lookup}/action-one/$'))

    def test_action_names(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True)
            def action1(self, request, *args, **kwargs):
                pass

            @action()
            def action2(self, request, *args, **kwargs):
                pass

        routes = self.router.get_routes(BasicViewSet)
        expected_names = [
            ('action1', u'{basename}-action1-list'),
            ('action2', u'{basename}-action2'),
        ]
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, routes)
            self.assertEqual(route.name, expected_name)

    def test_action_names__with_endpoints(self):
        class BasicViewSet(viewsets.ViewSet):
            @action(is_for_list=True, endpoint='action_one')
            def action1(self, request, *args, **kwargs):
                pass

            @action(endpoint='action-two')
            def action2(self, request, *args, **kwargs):
                pass

        routes = self.router.get_routes(BasicViewSet)
        expected_names = [
            ('action1', u'{basename}-action-one-list'),
            ('action2', u'{basename}-action-two'),
        ]
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, routes)
            self.assertEqual(route.name, expected_name)

    @unittest.skipIf(
        not get_rest_framework_features()['has_action_and_link_decorators'],
        "Current DRF version has removed 'action' and 'link' decorators")
    def test_with_default_controllers(self):
        class BasicViewSet(viewsets.ViewSet):
            @link()
            def link(self, request, *args, **kwargs):
                pass

            @decorators.link()
            def link_default(self, request, *args, **kwargs):
                pass

            @action()
            def action(self, request, *args, **kwargs):
                pass

            @decorators.action()
            def action_default(self, request, *args, **kwargs):
                pass

        routes = self.router.get_routes(BasicViewSet)
        expected_names = [
            ('link', u'{basename}-link'),
            ('link_default', u'{basename}-link-default'),
            ('action', u'{basename}-action'),
            ('action_default', u'{basename}-action-default'),
        ]
        for def_name, expected_name in expected_names:
            route = self.get_dynamic_route_by_def_name(def_name, routes)
            self.assertEqual(route.name, expected_name)
def get_page_size(self, request):
    """Page size for ``request``; ``?<page_size_query_param>=max`` selects ``max_page_size``.

    The shortcut applies only when the ``max_paginate_by`` feature is
    available and both ``page_size_query_param`` and ``max_page_size`` are
    set. Otherwise the parent implementation is used.
    """
    if get_rest_framework_features()['max_paginate_by']:
        param = self.page_size_query_param
        if param and self.max_page_size and request.query_params.get(param) == 'max':
            return self.max_page_size
    return super(PaginateByMaxMixin, self).get_page_size(request)
for field in get_model_opts_concrete_fields(opts): if not field.primary_key: concrete_field_names.append(field.name) if field.name != field.attname: concrete_field_names.append(field.attname) update_fields = [] for field_name in partial_fields: if field_name in fields: model_field_name = getattr(fields[field_name], 'source') or field_name if model_field_name in concrete_field_names: update_fields.append(model_field_name) return update_fields if get_rest_framework_features()['single_step_object_creation_in_serializers']: class PartialUpdateSerializerMixin(object): def save(self, **kwargs): self._update_fields = kwargs.get('update_fields', None) return super(PartialUpdateSerializerMixin, self).save(**kwargs) def update(self, instance, validated_attrs): for attr, value in validated_attrs.items(): setattr(instance, attr, value) if self.partial and isinstance(instance, self.Meta.model): instance.save( update_fields=getattr(self, '_update_fields') or get_fields_for_partial_update(opts=self.Meta, init_data=self.get_initial(), fields=self.fields.fields))
router = ExtendedSimpleRouter() router.register(r'router-viewset', RouterViewSet) urls = router.urls lookup_allowed_symbols = get_lookup_allowed_symbols() for exp in ['^router-viewset/$', '^router-viewset/{0}/$'.format(lookup_allowed_symbols), '^router-viewset/list_controller/$', '^router-viewset/{0}/detail_controller/$'.format(lookup_allowed_symbols)]: msg = 'Should find url pattern with regexp %s' % exp self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg) @unittest.skipUnless( get_rest_framework_features()['router_trailing_slash'], "Current DRF version doesn't support Router trailing_slash" ) class TestTrailingSlashRemoved(TestCase): def test_urls_can_have_trailing_slash_removed(self): router = ExtendedSimpleRouter(trailing_slash=False) router.register(r'router-viewset', RouterViewSet) urls = router.urls lookup_allowed_symbols = get_lookup_allowed_symbols( force_dot=get_rest_framework_features()['allow_dot_in_lookup_regex_without_trailing_slash'] ) for exp in ['^router-viewset$', '^router-viewset/{0}$'.format(lookup_allowed_symbols), '^router-viewset/list_controller$',
def add_trailing_slash_if_needed(regexp_string):
    """Swap a pattern's final two characters for ``'{trailing_slash}$'`` if supported.

    Applies only when the ``router_trailing_slash`` feature is available;
    otherwise ``regexp_string`` is returned as-is. The input is assumed to
    end in ``'/$'`` (not validated).
    """
    # TODO: add a dedicated unit test for this helper.
    supports_trailing_slash = get_rest_framework_features()['router_trailing_slash']
    return (regexp_string[:-2] + '{trailing_slash}$') if supports_trailing_slash else regexp_string
# -*- coding: utf-8 -*- from rest_framework_extensions.utils import get_rest_framework_features if get_rest_framework_features()['django_object_permissions_class']: from .extended_django_object_permissions import ExtendedDjangoObjectPermissions
class PartialUpdateSerializerMixinTest(TestCase):
    """Partial-update saving of serializers (see the test-case name).

    A partial save should write only the fields present in the submitted data.
    Names that must never reach ``update_fields`` must be left out of it:
    m2m fields, reverse relations, primary keys, and names not on the model.
    The ``self.fail`` messages below state each expectation.
    """

    def setUp(self):
        self.files = [
            File(BytesIO(u'file one'.encode('utf-8')), name='file1.txt'),
            File(BytesIO(u'file two'.encode('utf-8')), name='file2.txt'),
        ]
        # Both payloads ("file one" / "file two") are 8 bytes long.
        self.files[0]._set_size(8)
        self.files[1]._set_size(8)
        self.user = UserModel.objects.create(name='gena')
        self.comment = CommentModel.objects.create(
            user=self.user, title='hello', text='world',
            attachment=self.files[0])

    def get_comment(self):
        """Reload the fixture comment from the database."""
        return CommentModel.objects.get(pk=self.comment.pk)

    @unittest.skipIf(
        sys.version_info[0] == 3,
        "Skipped for python3 because of https://github.com/tomchristie/django-rest-framework/issues/1642"
    )
    def test_should_use_default_saving_without_partial(self):
        serializer = CommentSerializer(data={
            'user': self.user.id,
            'title': 'hola',
            'text': 'amigos',
        })
        self.assertTrue(serializer.is_valid())  # bug for python3 comes from here
        saved_object = serializer.save()
        self.assertEqual(saved_object.user, self.user)
        self.assertEqual(saved_object.title, 'hola')
        self.assertEqual(saved_object.text, 'amigos')

    def test_should_save_partial(self):
        serializer = CommentSerializer(
            instance=self.comment, data={'title': 'hola'}, partial=True)
        self.assertTrue(serializer.is_valid())
        saved_object = serializer.save()
        self.assertEqual(saved_object.user, self.user)
        self.assertEqual(saved_object.title, 'hola')
        self.assertEqual(saved_object.text, 'world')

    def test_should_save_only_fields_from_data_for_partial_update(self):
        # It's important to use different instances for Comment, because a
        # serializer's save method affects the instance passed to it.
        serializer_one = CommentSerializer(
            instance=self.get_comment(), data={'title': 'goodbye'}, partial=True)
        serializer_two = CommentSerializer(
            instance=self.get_comment(), data={'text': 'moon'}, partial=True)
        serializer_three_kwargs = {
            'instance': self.get_comment(),
            'partial': True
        }
        if get_rest_framework_features()['uses_single_request_data_in_serializers']:
            serializer_three_kwargs['data'] = {'attachment': self.files[1]}
        else:
            # Without that feature the upload goes through a separate ``files`` kwarg.
            serializer_three_kwargs.update({
                'data': {},
                'files': {'attachment': self.files[1]}
            })
        serializer_three = CommentSerializer(**serializer_three_kwargs)
        self.assertTrue(serializer_one.is_valid())
        self.assertTrue(serializer_two.is_valid())
        self.assertTrue(serializer_three.is_valid())

        # Save all three serializers, expecting they don't affect each other.
        serializer_one.save()
        serializer_two.save()
        serializer_three.save()

        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.attachment.read(), u'file two'.encode('utf-8'))
        self.assertEqual(fresh_instance.text, 'moon')
        self.assertEqual(fresh_instance.title, 'goodbye')

    def test_should_use_related_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        data = {'title': 'goodbye', 'user': another_user.pk}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.user, another_user)

    def test_should_use_field_source_value_for_searching_model_concrete_fields(self):
        data = {'title_from_source': 'goodbye'}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')

    def test_should_not_use_m2m_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        data = {
            'title': 'goodbye',
            'users_liked': [self.user.pk, another_user.pk]
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If m2m field used in partial update then it should not be used in update_fields list'
            )
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        users_liked = set(
            fresh_instance.users_liked.all().values_list('pk', flat=True))
        self.assertEqual(users_liked, {self.user.pk, another_user.pk})

    def test_should_not_use_related_set_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        another_comment = CommentModel.objects.create(
            user=another_user,
            title='goodbye',
            text='moon',
        )
        data = {'name': 'vova', 'comments': [another_comment.pk]}
        serializer = UserSerializer(
            instance=another_user, data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        # Save exactly once and only inside the try. An unguarded save() before
        # it would surface the ValueError this test checks for as a test error
        # instead of the intended failure message.
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If related set field used in partial update then it should not be used in update_fields list'
            )
        fresh_comment = CommentModel.objects.get(pk=another_comment.pk)
        fresh_user = UserModel.objects.get(pk=another_user.pk)
        self.assertEqual(fresh_comment.user, another_user)
        self.assertEqual(fresh_user.name, 'vova')

    def test_should_not_try_to_update_fields_that_are_not_in_model(self):
        data = {'title': 'goodbye', 'not_existing_field': 'moon'}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            msg = 'Should not pass values to update_fields from data, if they are not in model'
            self.fail(msg)
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')

    def test_should_not_try_to_update_fields_that_are_not_allowed_from_serializer(self):
        data = {'title': 'goodbye', 'hidden_text': 'do not change me'}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')
        self.assertIsNone(fresh_instance.hidden_text)

    def test_should_use_list_of_fields_to_update_from_arguments_if_it_passed(self):
        data = {'title': 'goodbye', 'text': 'moon'}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        # An explicit update_fields argument overrides the fields taken from data.
        serializer.save(update_fields=['title'])
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')

    @unittest.skipUnless(
        get_rest_framework_features()['has_auto_writable_nested_serialization'],
        "This version of DRF doesn't have automatic writable nested serialization"
    )
    def test_should_not_use_update_fields_when_related_objects_are_saving(self):
        data = {
            'title': 'goodbye',
            'user': {
                'id': self.user.pk,
                'name': 'oleg'
            }
        }
        serializer = CommentSerializerWithExpandedUsersLiked(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If serializer has expanded related serializer, then it should not use update_fields while '
                'saving related object')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        if get_rest_framework_features()['save_related_serializers']:
            self.assertEqual(fresh_instance.user.name, 'oleg')

    def test_should_not_use_field_attname_for_update_fields__if_attname_not_allowed_in_serializer_fields(
            self):
        another_user = UserModel.objects.create(name='vova')
        data = {'title': 'goodbye', 'user_id': another_user.id}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.user_id, self.user.id)

    def test_should_use_field_attname_for_update_fields__if_attname_allowed_in_serializer_fields(
            self):
        another_user = UserModel.objects.create(name='vova')
        data = {'title': 'goodbye', 'user_id': another_user.id}
        serializer = CommentSerializerWithAllowedUserId(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.user_id, another_user.id)

    def test_should_not_use_pk_field_for_update_fields(self):
        old_pk = self.get_comment().pk
        data = {'id': old_pk + 1, 'title': 'goodbye'}
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'Primary key field should be excluded from update_fields list')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.pk, old_pk)
        self.assertEqual(fresh_instance.title, u'goodbye')