class PermissionsTestCase(APITestCase):
    """Shared helpers for exercising a DRF permission object.

    ``setUp`` leaves ``self.perm`` as ``None``; presumably subclasses assign
    the permission instance under test before calling the ``check_*``
    helpers — TODO confirm.
    """

    def setUp(self):
        self.factory = APIRequestFactory()
        # Permission under test; must be assigned before any check_* call.
        self.perm = None

    def check_model_perms(self, checks, user):
        """Assert ``self.perm.has_permission`` for each HTTP method.

        :param checks: mapping of HTTP method name (e.g. ``'GET'``) to the
            value ``has_permission`` is expected to return.
        :param user: user attached to every generated request.
        """
        for method, check in checks.items():
            request = self.factory.generic(method, '/url/')
            request.user = user
            # The string 'view' stands in for a real view object; assumes the
            # permission does not inspect it — confirm.
            self.assertEqual(
                self.perm.has_permission(request, 'view'),
                check,
                msg="Invalid perms for user {} and method {}".format(
                    user, method))

    def check_obj_perms(self, checks, user, obj):
        """Assert ``self.perm.has_object_permission`` on ``obj``.

        :param checks: mapping of HTTP method name to either the expected
            boolean result, or an exception class the call must raise.
        :param user: user attached to every generated request.
        :param obj: object passed to ``has_object_permission``.
        """
        for method, check in checks.items():
            request = self.factory.generic(method, '/url/')
            request.user = user
            fail_msg = "Invalid obj perms for user {} and method {}".format(
                user, method)
            if isinstance(check, bool):
                self.assertEqual(
                    self.perm.has_object_permission(request, 'view', obj),
                    check,
                    msg=fail_msg)
            else:
                # Any non-bool expectation is treated as an exception type.
                with self.assertRaises(check, msg=fail_msg):
                    self.perm.has_object_permission(request, 'view', obj)

    def check_box_obj_perms(self,
                            checks,
                            user,
                            box_owner=None,
                            share_with=None,
                            share_perm=None):
        """Run ``check_obj_perms`` against one private and one public box.

        :param checks: same format as for ``check_obj_perms``.
        :param user: user whose permissions are checked.
        :param box_owner: owner of the created boxes; a new ``UserFactory()``
            user when omitted.
        :param share_with: when truthy, each box is shared with it using
            ``share_perm`` before the checks run.
        :param share_perm: permission passed to ``Box.share_with``.
        """
        if box_owner is None:
            box_owner = UserFactory()
        for visibility in [Box.PRIVATE, Box.PUBLIC]:
            box = BoxFactory(visibility=visibility, owner=box_owner)
            if share_with:
                box.share_with(share_with, share_perm)
            self.check_obj_perms(checks, user, box)
# Example #2
# 0
class TestSerialize:
    """Tests that push hand-built multipart requests through DRF.

    NOTE(review): ``test_csv_with_excel_content_type`` uses
    ``self.assertEqual``/``self.assertIn``/``self.assertFalse``,
    ``self.url``, ``self.data_engineer_1_user`` and
    ``self.verify_inferred_data``, none of which this plain class defines.
    It appears to have been merged in from a ``TestCase`` subclass; confirm
    where it belongs before relying on it here.
    """

    def setup_method(self, method=None):
        """Build the request factory and a serializer with one JSONField.

        pytest runs ``setup_method`` before every test in the class. The
        nose-style ``setup`` hook this class previously relied on is
        deprecated since pytest 7.2 and no longer invoked by pytest 8.
        """
        from rest_framework.test import APIRequestFactory

        class DummyJSONSerializer(serializers.Serializer):
            json_field = fields.JSONField()

        self.factory = APIRequestFactory()
        self.Serializer = DummyJSONSerializer

    # Alias kept so that explicit ``self.setup()`` calls keep working.
    setup = setup_method

    def test_request_POST_with_form_content_JSONField(self):
        """A multipart part holding ``{"a": true}`` deserializes to a dict."""
        data = b'-----------------------------20775962551482475149231161538\r\nContent-Disposition: form-data; name="json_field"\r\n\r\n{"a": true}\r\n-----------------------------20775962551482475149231161538--\r\n'
        content_type = 'multipart/form-data; boundary=---------------------------20775962551482475149231161538'
        wsgi_request = self.factory.generic('POST', '/', data, content_type)
        request = Request(wsgi_request)
        # Parse the body with the form/multipart parsers only.
        request.parsers = (FormParser(), MultiPartParser())
        serializer = self.Serializer(data=request.data)
        assert serializer.is_valid(raise_exception=True)
        assert dict(serializer.data)['json_field'] == {'a': True}

    def test_csv_with_excel_content_type(self):
        """Infer a dataset from a CSV uploaded with an Excel content type.

        On Windows a CSV file often comes with an Excel content type
        (e.g. 'application/vnd.ms-excel'); the view must still handle it.
        """
        view = InferDatasetView.as_view()
        columns = ['Name', 'Age', 'Weight', 'Comments']
        rows = [
            columns,
            ['Frederic', '56', '80.5', 'a comment'],
            ['Hilda', '24', '56', ''],
        ]
        # Inferred schema type expected for each column, in column order.
        expected_types = [
            ('Name', 'string'),
            ('Age', 'integer'),
            ('Weight', 'number'),
            ('Comments', 'string'),
        ]
        file_ = helpers.rows_to_csv_file(rows)
        factory = APIRequestFactory()
        with open(file_, 'rb') as fp:
            payload = {
                'file': fp,
            }
            # The classic API client cannot override the Content-Type of a
            # multipart part, so encode the payload with the
            # APIRequestFactory, patch the part header by hand, and call the
            # view directly.
            data, content_type = factory._encode_data(payload,
                                                      format='multipart')
            if six.PY3:
                data = data.decode('utf-8')
            data = data.replace('Content-Type: text/csv',
                                'Content-Type: application/vnd.ms-excel')
            if six.PY3:
                data = data.encode('utf-8')
            request = factory.generic('POST',
                                      self.url,
                                      data,
                                      content_type=content_type)
            user = self.data_engineer_1_user
            token, _ = Token.objects.get_or_create(user=user)
            force_authenticate(request, user=user, token=token)
            resp = view(request).render()
            self.assertEqual(status.HTTP_200_OK, resp.status_code)
            # should be json
            self.assertEqual(resp.get('content-type'), 'application/json')
            if six.PY3:
                content = resp.content.decode('utf-8')
            else:
                content = resp.content
            received = json.loads(content)

            # name should be set with the file name (without extension)
            self.assertIn('name', received)
            file_name = path.splitext(path.basename(fp.name))[0]
            self.assertEqual(file_name, received.get('name'))
            # type should be 'generic'
            self.assertIn('type', received)
            self.assertEqual('generic', received.get('type'))

            # data_package verification
            self.assertIn('data_package', received)
            self.verify_inferred_data(received)

            # verify schema
            schema_descriptor = Package(
                received.get('data_package')).resources[0].descriptor['schema']
            schema = utils_data_package.GenericSchema(schema_descriptor)
            self.assertEqual(len(schema.fields), len(columns))
            self.assertEqual(schema.field_names, columns)

            # Every column must have the expected type, be optional and use
            # the 'default' format.
            for field_name, expected_type in expected_types:
                field = schema.get_field_by_name(field_name)
                self.assertEqual(field.type, expected_type)
                self.assertFalse(field.required)
                self.assertEqual(field.format, 'default')
# Example #4
# 0
class RecordViewSetsPermissions(TestCase):
    """Check the action surface and anonymous access of the record viewsets.

    Each entry of ``views`` is a viewset class followed by the actions it is
    expected to expose. Every listed action must answer an unauthenticated
    request with 401. Every other action known to ``action_mapper`` must fail
    with ``AttributeError`` when its view is invoked.
    """

    views = [
        # template
        (RecordTemplateViewSet, 'create', 'partial_update', 'update',
         'destroy', 'list', 'retrieve'),
        # fields
        (RecordEncryptedStandardFieldViewSet, 'create', 'partial_update',
         'update', 'destroy'),
        (RecordStandardFieldViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordEncryptedFileFieldViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordSelectFieldViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordEncryptedSelectFieldViewSet, 'create', 'partial_update',
         'update', 'destroy'),
        (RecordStateFieldViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordUsersFieldViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        # record
        (RecordViewSet, 'create', 'retrieve', 'destroy', 'list'),
        # entry
        (RecordEncryptedSelectEntryViewSet, 'create', 'partial_update',
         'update', 'destroy'),
        (RecordSelectEntryViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordEncryptedFileEntryViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordStandardEntryViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordEncryptedStandardEntryViewSet, 'create', 'partial_update',
         'update', 'destroy'),
        (RecordUsersEntryViewSet, 'create', 'partial_update', 'update',
         'destroy'),
        (RecordStateEntryViewSet, 'create', 'partial_update', 'update',
         'destroy')
    ]

    # Viewset action name -> HTTP method used to issue the request.
    action_mapper = {
        'create': 'POST',
        'update': 'PUT',
        'partial_update': 'PATCH',
        'destroy': 'DELETE',
        'list': 'GET',
        'retrieve': 'GET'
    }

    def setUp(self):
        self.factory = APIRequestFactory()
        self.rlc = Rlc.objects.create(name="Test RLC")
        self.user = UserProfile.objects.create(email='*****@*****.**',
                                               name='Dummy 1',
                                               rlc=self.rlc)
        self.user.set_password(settings.DUMMY_USER_PASSWORD)
        self.user.save()
        self.rlc_user = RlcUser.objects.create(user=self.user,
                                               email_confirmed=True,
                                               accepted=True)

    def _build_view(self, view_class, action):
        """Return a view that routes the action's HTTP method to ``action``."""
        # DRF's viewset ``actions`` mapping is keyed by lowercase HTTP method
        # names (e.g. ``{'get': 'list'}``); uppercase keys never match the
        # incoming request, so the action would not be bound to it.
        method = self.action_mapper[action].lower()
        return view_class.as_view(actions={method: action})

    def check_forbidden(self, view_class, actions):
        """Assert each of ``actions`` answers an unauthenticated request with 401."""
        for action in actions:
            with self.subTest(view=view_class.__name__, action=action):
                view = self._build_view(view_class, action)
                # No user is attached, so the request is unauthenticated.
                request = self.factory.generic(self.action_mapper[action], '')
                response = view(request)
                self.assertEqual(response.status_code, 401)

    def check_not_existent(self, view_class, actions):
        """Assert invoking a view for each of ``actions`` raises AttributeError.

        Presumably this is because the viewset does not define that action
        method — confirm.
        """
        for action in actions:
            with self.subTest(view=view_class.__name__, action=action):
                view = self._build_view(view_class, action)
                request = self.factory.generic(self.action_mapper[action], '')
                with self.assertRaises(AttributeError):
                    view(request)

    def test_permissions(self):
        for entry in self.views:
            view_class, exposed = entry[0], entry[1:]
            self.check_forbidden(view_class, exposed)
            # Every known action the viewset does not list must be missing.
            non_existent = [action for action in self.action_mapper
                            if action not in exposed]
            self.check_not_existent(view_class, non_existent)
# Example #5
# 0
class UserPermissionsTestCase(APITestCase):
    """Exercise ``UserPermissions`` for staff, regular, owner and anonymous users."""

    def setUp(self):
        self.factory = APIRequestFactory()
        self.perm = UserPermissions()
        # Target object for every object-level permission check.
        self.user_obj = UserFactory()

    def _request_for(self, http_method, user):
        """Build a generic request for ``http_method`` on behalf of ``user``."""
        req = self.factory.generic(http_method, '/url/')
        req.user = user
        return req

    def check_perms(self, checks, user):
        """Assert ``has_permission`` for ``user`` against ``checks``.

        :param checks: mapping of HTTP method name to the expected result.
        """
        for http_method, expected in checks.items():
            req = self._request_for(http_method, user)
            actual = self.perm.has_permission(req, 'view')
            self.assertEqual(
                actual,
                expected,
                msg="Invalid perms for user {} and method {}".format(
                    user, http_method))

    def check_obj_perms(self, checks, user):
        """Assert ``has_object_permission`` for ``user`` on ``self.user_obj``.

        :param checks: mapping of HTTP method name to the expected result.
        """
        for http_method, expected in checks.items():
            req = self._request_for(http_method, user)
            actual = self.perm.has_object_permission(req, 'view',
                                                     self.user_obj)
            self.assertEqual(
                actual,
                expected,
                msg="Invalid obj perms for user {} and method {}".format(
                    user, http_method))

    def test_staff_has_perms(self):
        staff = StaffFactory()
        everything_allowed = {
            'GET': True,
            'PATCH': True,
            'POST': True,
            'DELETE': True,
        }
        self.check_perms(everything_allowed, staff)
        self.check_obj_perms(everything_allowed, staff)

    def test_authenticated_has_perms(self):
        someone = UserFactory()
        model_level = {
            'GET': True,
            'PATCH': True,
            'POST': False,
            'DELETE': False,
        }
        object_level = {
            'GET': True,
            'PATCH': False,
            'POST': False,
            'DELETE': False,
        }
        self.check_perms(model_level, someone)
        self.check_obj_perms(object_level, someone)

    def test_user_owner_has_perms(self):
        owner = self.user_obj
        # The owner gets the same answers at model and object level.
        read_and_patch = {
            'GET': True,
            'PATCH': True,
            'POST': False,
            'DELETE': False,
        }
        self.check_perms(read_and_patch, owner)
        self.check_obj_perms(read_and_patch, owner)

    def test_anonymous_doesnt_has_perms(self):
        anonymous = AnonymousUser()
        # Each request is built with the method-specific factory helper.
        for make_request in (self.factory.post,
                             self.factory.patch,
                             self.factory.get):
            req = make_request('/url/')
            req.user = anonymous
            with self.assertRaises(Http404):
                self.perm.has_permission(req, 'view')