Ejemplo n.º 1
0
    def get_id_fields(self):
        """
        Return the names of the ID columns to load for this serializer's model.

        The list always begins with the primary key's attribute name and is
        followed by the attribute name of every ``ForeignKey`` on the model.
        It is used to build a ``Prefetch`` with a ``.only()`` queryset when
        this field is not sideloaded but its IDs must still be returned.
        """
        model_meta = Meta(self.get_model())
        pk_attname = model_meta.get_pk_field().attname

        # Reaching this method means the field is a many-relation to its
        # parent. Django needs the FK pointing back to that parent, but
        # inferring exactly which FK that is is not trivial, so every FK
        # column is included instead.
        # TODO: We also might need to return all non-nullable fields,
        #    or else it is possible Django will issue another request.
        fk_attnames = [
            candidate.attname
            for candidate in model_meta.get_fields()
            if isinstance(candidate, models.ForeignKey)
        ]
        return [pk_attname] + fk_attnames
Ejemplo n.º 2
0
    def update(self, instance, validated_data):
        """Apply ``validated_data`` to ``instance`` and save it.

        Plain attributes and normal relationship values are assigned with
        ``setattr``. A ``dict`` value for a relational field is treated as a
        nested write on a has-one relationship: ``nested_update`` applies it
        to the currently related object. Every object it returns is saved
        together with ``instance``. All assignments and saves run inside a
        single ``transaction.atomic()`` block.

        Unlike ``.create()``, many-to-many relationships need no special
        handling here, because during updates the instance already has a
        pk for the relationships to be associated with.

        Returns:
            The updated ``instance``.

        Raises:
            exceptions.ValidationError: if any assignment or save fails.
                A ``ValidationError`` raised along the way is propagated
                unchanged instead of being wrapped a second time.
        """
        meta = Meta(instance)
        to_save = [instance]
        try:
            with transaction.atomic():
                for attr, value in validated_data.items():
                    # Only the field lookup is expected to fail, for
                    # attributes that are not model fields. The try stays
                    # narrow so that AttributeErrors raised by
                    # nested_update/setattr are not silently swallowed.
                    try:
                        field = meta.get_field(attr)
                    except AttributeError:
                        field = None

                    is_relation = bool(
                        field is not None and
                        getattr(field, 'related_model', None))
                    if is_relation and isinstance(value, dict):
                        # nested dictionary on a has-one relationship:
                        # take the current related value and apply the
                        # updates to it
                        to_save.extend(nested_update(instance, attr, value))
                    else:
                        # plain attribute or normal relationship update
                        setattr(instance, attr, value)

                for obj in to_save:
                    obj.save()
        except exceptions.ValidationError:
            # Already an API validation error. Wrapping it again would
            # flatten its structured detail into a string.
            raise
        except Exception as e:
            raise exceptions.ValidationError(e)

        return instance
Ejemplo n.º 3
0
    def _build_requested_prefetches(self, prefetches, requirements, model,
                                    fields, filters, is_root_level):
        """Build a prefetch dictionary based on request requirements.

        For each serializer field that wraps a ModelSerializer, adds a
        ``Prefetch`` keyed by the field's source to ``prefetches``. The
        prefetch uses a queryset built recursively by ``_build_queryset``.

        Arguments:
            prefetches: dict of source name -> ``Prefetch``; updated in place.
                Sources already present are skipped.
            requirements: requirements tree; the entry for every source that
                receives an explicit prefetch here is popped from it.
            model: the model the serializer ``fields`` belong to.
            fields: mapping of serializer field name -> field.
            filters: nested filters; the sub-filters for each relation are
                looked up by the relation's query name.
            is_root_level: True when building the top-level queryset.

        Returns:
            The same ``prefetches`` dict.

        Raises:
            ValidationError: if a relational field's source contains a dot.
        """
        meta = Meta(model)
        for name, field in six.iteritems(fields):
            original_field = field
            # Unwrap relation fields and list serializers down to the
            # underlying ModelSerializer; any other field needs no prefetch.
            if isinstance(field, dfields.DynamicRelationField):
                field = field.serializer
            if isinstance(field, serializers.ListSerializer):
                field = field.child
            if not isinstance(field, serializers.ModelSerializer):
                continue

            source = field.source or name
            if '.' in source:
                raise ValidationError('Nested relationship values '
                                      'are not supported')
            if source == '*':
                # ignore custom getter/setter
                continue

            if source in prefetches:
                # ignore duplicated sources
                continue

            # A custom queryset may be given on the (wrapping) field; when it
            # is callable it is invoked with the unwrapped serializer.
            related_queryset = getattr(original_field, 'queryset', None)
            if callable(related_queryset):
                related_queryset = related_queryset(field)

            is_id_only = getattr(field, 'id_only', lambda: False)()
            is_remote = meta.is_field_remote(source)
            is_gui_root = self.view.get_format() == 'admin' and is_root_level
            if (related_queryset is None and is_id_only and not is_remote
                    and not is_gui_root):
                # Skip ID-only local relations without a custom queryset.
                # Full representations, remote fields, custom querysets
                # and the admin root level all trigger prefetching.
                continue

            # Popping the source here (during explicit prefetch construction)
            # guarantees that implicitly required prefetches that follow will
            # not conflict.
            required = requirements.pop(source, None)

            # Recursively build the related queryset, passing down only the
            # filters nested under this relation's query name.
            query_name = Meta.get_query_name(original_field.model_field)
            prefetch_queryset = self._build_queryset(serializer=field,
                                                     filters=filters.get(
                                                         query_name, {}),
                                                     queryset=related_queryset,
                                                     requirements=required)

            # There can only be one prefetch per source, even
            # though there can be multiple fields pointing to
            # the same source. This could break in some cases,
            # but is mostly an issue on writes when we use all
            # fields by default.
            prefetches[source] = Prefetch(source, queryset=prefetch_queryset)

        return prefetches
Ejemplo n.º 4
0
 def prepare_value(self, instance):
     """Return the display label for ``instance``'s current choice value.

     The model field named by this field's source (or, failing that, its
     field name) on ``self.parent_model`` supplies the choices. The
     instance's value for that attribute is looked up in them.

     Returns:
         The human-readable label, or ``None`` when the value is not one of
         the choices or the model field defines no choices.
     """
     model = self.parent_model
     source = self.source or self.field_name
     model_field = Meta(model).get_field(source)
     # ``flatchoices`` flattens Django's grouped (optgroup-style) choices.
     # For those, ``dict(choices)`` would map each group name to a list of
     # options. Fall back to raw ``choices`` for objects without
     # ``flatchoices``, and treat ``None`` ("no choices") as empty instead
     # of letting ``dict(None)`` raise TypeError.
     choices = getattr(model_field, 'flatchoices', None)
     if choices is None:
         choices = model_field.choices or ()
     value = getattr(instance, source)
     return dict(choices).get(value)
Ejemplo n.º 5
0
    def resolve(self, serializer, query, view=None):
        """Rewrite an API ordering term as a Django ORM lookup path.

        Arguments:
            serializer: the serializer instance to resolve against,
                e.g. UserSerializer
            query: an API field path in dot notation, e.g. "location.name"
            view: an optional view instance, e.g. UserViewSet

        Returns:
            The resolved model field path joined with double underscores,
            e.g. "location__real_name".

        Raises:
            ValidationError: if the query is not an allowed sort option
                or cannot be rewritten.
        """
        allowed = self._is_allowed_query(query, view)
        if not allowed:
            raise ValidationError('Invalid sort option: %s' % query)

        resolved_fields, _api_fields = serializer.resolve(query, sort=True)
        query_names = [Meta.get_query_name(mf) for mf in resolved_fields]
        return '__'.join(query_names)
Ejemplo n.º 6
0
    def _build_implicit_prefetches(self, model, prefetches, requirements):
        """Add prefetches derived from internal requirements.

        Only requirements whose remainder is truthy and not a string (a
        nested requirement tree) get a ``Prefetch`` on their source name.
        When a related model can be determined, the prefetch queryset comes
        from ``_build_implicit_queryset``; otherwise ``queryset=None`` is
        passed. Any existing entry for the same source is overwritten.

        Returns:
            The same ``prefetches`` dict, updated in place.
        """
        model_meta = Meta(model)
        for source, remainder in six.iteritems(requirements):
            if remainder and not isinstance(remainder, six.string_types):
                related_model = get_related_model(model_meta.get_field(source))
                if related_model:
                    queryset = self._build_implicit_queryset(
                        related_model, remainder)
                else:
                    queryset = None
                prefetches[source] = Prefetch(source, queryset=queryset)
        return prefetches
Ejemplo n.º 7
0
    def test_delete(self):
        """DELETE an instance through every renderer and check it is gone.

        Returns early when no view is configured or the view does not allow
        DELETE. For each renderer, a fresh instance is created and deleted
        via ``?format=<renderer format>``. The test then asserts that the
        response status is below 400 and that the row no longer exists.
        """
        view = self.view
        if view is None or 'delete' not in view.http_method_names:
            return

        for renderer in view.get_renderers():
            instance = self.create_instance()
            url = '%s?format=%s' % (self.get_url(instance.pk), renderer.format)
            response = self.api_client.delete(url)
            status = response.status_code
            self.assertTrue(
                status < 400,
                'DELETE %s failed with %d: %s' %
                (url, status, response.content.decode('utf-8')))

            model = self.get_model()
            model_name = Meta(model).get_name()
            pk = instance.pk
            still_exists = model.objects.filter(pk=pk).exists()
            self.assertFalse(
                still_exists,
                'DELETE %s succeeded but instance "%s.%s" still exists' %
                (url, model_name, pk))
Ejemplo n.º 8
0
    def test_create(self):
        """POST new data through every renderer and check it was created.

        Returns early when no view is configured or the view does not allow
        POST. For each renderer, ``get_post_params()`` is posted as JSON to
        ``?format=<renderer format>`` and the status is asserted to be below
        400. For the ``json`` format, the response is also parsed. When the
        serializer exposes a pk field, the created instance's pk is read from
        ``content[<serializer name>]`` and checked to exist in the database.
        """
        view = self.view
        if view is None or 'post' not in view.http_method_names:
            return

        model = self.get_model()
        for renderer in view.get_renderers():
            # `fmt`, not `format`, so the builtin is not shadowed
            fmt = renderer.format
            url = '%s?format=%s' % (self.get_url(), fmt)
            data = self.get_post_params()
            response = self.api_client.post(url,
                                            content_type='application/json',
                                            data=json.dumps(data))
            content = response.content.decode('utf-8')
            self.assertTrue(
                response.status_code < 400, 'POST %s failed with %d:\n%s' %
                (url, response.status_code, content))
            if fmt != 'json':
                # only JSON responses are parsed to verify the created row
                continue

            content = json.loads(content)
            model_name = Meta(model).get_name()
            serializer = self.serializer_class()
            name = serializer.get_name()
            pk_field = serializer.get_field('pk')
            if pk_field:
                pk_field = pk_field.field_name
                self.assertIn(name, content, content)
                pk = content[name][pk_field]
                self.assertTrue(
                    model.objects.filter(pk=pk).exists(),
                    'POST %s succeeded but instance '
                    '"%s.%s" does not exist' % (url, model_name, pk))
Ejemplo n.º 9
0
    def resolve(self, query):
        """Resolves a query into model and serializer fields.

        Arguments:
            query: an API field path in dot-notation, given either as a
                string (e.g. "creator.location_name") or as a sequence of
                path parts (e.g. ["creator", "location_name"])

        Returns:
            (model_fields, api_fields)
                e.g:
                    [
                        Blog._meta.fields.user,
                        User._meta.fields.location,
                        Location._meta.fields.name
                    ],
                    [
                        DynamicRelationField(source="user"),
                        DynamicCharField(source="location.name")
                    ]

        Raises:
            ValidationError if the query is invalid,
                e.g. traverses a field that is not an API/model relation
                or references an undefined field

        Note that the lists do not necessarily contain the
        same number of elements because API fields can reference nested model fields.
        In addition, a leaf API field whose source is "*" (e.g. a method field)
        contributes no model field, and "pk" contributes the model's primary key
        even when the serializer has no matching API field.
        """  # noqa
        if not isinstance(query, six.string_types):
            parts = query
            query = '.'.join(query)
        else:
            parts = query.split('.')

        model_fields = []
        api_fields = []

        model = self.get_model()
        resource_name = self.get_name()
        meta = Meta(model)
        api_name = parts[0]
        other = parts[1:]

        try:
            api_field = self.get_field(api_name)
        except Exception:
            # Unknown names are reported below with a query-specific
            # message. (A bare except would also swallow
            # KeyboardInterrupt/SystemExit.)
            api_field = None

        if other:
            # Traversal: the first part must be an API relation; the rest
            # of the path is resolved by the related serializer.
            if not (api_field and isinstance(api_field, DynamicRelationField)):
                raise ValidationError({
                    api_name:
                    'Could not resolve "%s": '
                    '"%s.%s" is not an API relation' %
                    (query, resource_name, api_name)
                })

            source = api_field.source or api_name
            related = api_field.serializer_class()
            other = '.'.join(other)
            model_fields, api_fields = related.resolve(other)

            try:
                model_field = meta.get_field(source)
            except AttributeError:
                raise ValidationError({
                    api_name:
                    'Could not resolve "%s": '
                    '"%s.%s" is not a model relation' %
                    (query, meta.get_name(), source)
                })

            model_fields.insert(0, model_field)
            api_fields.insert(0, api_field)
        else:
            if api_name == 'pk':
                # pk is an alias for the id field
                model_field = meta.get_pk_field()
                model_fields.append(model_field)
                if api_field:
                    # the pk field may not exist
                    # on the serializer
                    api_fields.append(api_field)
            else:
                if not api_field:
                    raise ValidationError({
                        api_name:
                        'Could not resolve "%s": '
                        '"%s.%s" is not an API field' %
                        (query, resource_name, api_name)
                    })

                api_fields.append(api_field)

                if api_field.source == '*':
                    # a method field was requested, model field is unknown
                    return (model_fields, api_fields)

                source = api_field.source or api_name
                if '.' in source:
                    # A dotted source walks model relations: every part but
                    # the last must be a relation on the current model.
                    fields = source.split('.')
                    for field in fields[:-1]:
                        related_model = None
                        try:
                            model_field = meta.get_field(field)
                            related_model = model_field.related_model
                        except Exception:
                            # reported below as "not a model relation"
                            pass

                        if not related_model:
                            raise ValidationError({
                                api_name:
                                'Could not resolve "%s": '
                                '"%s.%s" is not a model relation' %
                                (query, meta.get_name(), field)
                            })
                        meta = Meta(related_model)
                        model_fields.append(model_field)
                    field = fields[-1]
                    try:
                        model_field = meta.get_field(field)
                    except Exception:
                        raise ValidationError({
                            api_name:
                            'Could not resolve: "%s", '
                            '"%s.%s" is not a model field' %
                            (query, meta.get_name(), field)
                        })
                    model_fields.append(model_field)
                else:
                    try:
                        model_field = meta.get_field(source)
                    except Exception:
                        raise ValidationError({
                            api_name:
                            'Could not resolve "%s": '
                            '"%s.%s" is not a model field' %
                            (query, meta.get_name(), source)
                        })
                    model_fields.append(model_field)

        return (model_fields, api_fields)
Ejemplo n.º 10
0
    def _build_queryset(self,
                        serializer=None,
                        filters=None,
                        queryset=None,
                        requirements=None):
        """Build a queryset that pulls in all data required by this request.

        Handles nested prefetching of related data and deferring fields
        at the queryset level.

        Arguments:
            serializer: An optional serializer to use a base for the queryset.
                If no serializer is passed, the `get_serializer` method will
                be used to initialize the base serializer for the viewset.
            filters: An optional TreeMap of nested filters. Defaults to the
                filters requested by the current request.
            queryset: An optional base queryset. If a serializer is passed
                and this is omitted, the serializer model's ``objects``
                manager is used. When no serializer is passed, nothing here
                creates a queryset, so the caller presumably supplies one.
            requirements: An optional TreeMap of nested requirements.

        Returns:
            The queryset with, where applicable: field limiting via
            ``.only()``, request filters, the serializer's optional
            ``filter_queryset`` hook, prefetches and ``.distinct()``. It is
            returned unchanged when the serializer has no model.

        Raises:
            ValidationError: if the request filters cannot be applied.
        """

        is_root_level = False
        if serializer:
            if queryset is None:
                queryset = serializer.Meta.model.objects
        else:
            serializer = self.view.get_serializer()
            is_root_level = True

        model = serializer.get_model()

        if not model:
            return queryset

        meta = Meta(model)

        prefetches = {}

        # build a nested Prefetch queryset
        # based on request parameters and serializer fields
        fields = serializer.fields

        if requirements is None:
            requirements = TreeMap()

        self._get_implicit_requirements(fields, requirements)

        if filters is None:
            filters = self._get_requested_filters()

        # build nested Prefetch queryset
        self._build_requested_prefetches(prefetches, requirements, model,
                                         fields, filters, is_root_level)

        # build remaining prefetches out of internal requirements
        # that are not already covered by request requirements
        self._build_implicit_prefetches(model, prefetches, requirements)

        # use requirements at this level to limit fields selected
        # only do this for GET requests where we are not requesting the
        # entire fieldset
        is_gui = self.view.get_format() == 'admin'
        if ('*' not in requirements and not self.view.is_update()
                and not self.view.is_delete() and not is_gui):
            id_fields = getattr(serializer, 'get_id_fields', lambda: [])()
            # only include local model fields
            only = [
                field for field in set(id_fields + list(requirements.keys()))
                if meta.is_field(field) and not meta.is_field_remote(field)
            ]
            queryset = queryset.only(*only)

        # add request filters
        query = self._filters_to_query(filters)

        if query:
            # Convert internal django ValidationError to
            # APIException-based one in order to resolve validation error
            # from 500 status code to 400.
            try:
                queryset = queryset.filter(query)
            except InternalValidationError as e:
                raise ValidationError(
                    dict(e) if hasattr(e, 'error_dict') else list(e))
            except Exception as e:
                # Some other Django error in parsing the filter.
                # Very likely a bad query, so throw a ValidationError.
                # Python 3 exceptions have no ``message`` attribute, so
                # fall back to the exception text instead of sending an
                # empty error message.
                err_msg = getattr(e, 'message', None) or str(e)
                raise ValidationError(err_msg)

        # A serializer can have this optional function
        # to dynamically apply additional filters on
        # any queries that will use that serializer
        # You could use this to have (for example) different
        # serializers for different subsets of a model or to
        # implement permissions which work even in sideloads
        if hasattr(serializer, 'filter_queryset'):
            queryset = serializer.filter_queryset(queryset)

        # add prefetches and remove duplicates if necessary
        prefetch = prefetches.values()
        queryset = queryset.prefetch_related(*prefetch)
        if has_joins(queryset) or not is_root_level:
            queryset = queryset.distinct()

        if self.DEBUG:
            # expose the prefetches used, for inspection while debugging
            queryset._using_prefetches = prefetches
        return queryset
Ejemplo n.º 11
0
    def _get_requested_filters(self, **kwargs):
        """
        Convert 'filters' query params into a TreeMap that can be passed
        to Q. Each filter is stored under an '_include' or '_exclude' key,
        nested under its relation path for relational ("rel|field") filters,
        which can be used like:

            result = self._get_requested_filters()
            q = Q(**result['_include']) & ~Q(**result['_exclude'])

        Keyword arguments:
            filters_map: optional mapping of filter key -> list of values.
                When omitted or empty, it is read from the view's request
                (if there is a view).

        Filter keys prefixed with '-' are exclusions. A trailing operator
        part (e.g. "events.capacity.gte") is turned into a Django lookup
        suffix; 'eq' means plain equality.
        """

        filters_map = kwargs.get('filters_map')

        view = getattr(self, 'view', None)
        if view:
            serializer_class = view.get_serializer_class()
            serializer = serializer_class()
            if not filters_map:
                filters_map = view.get_request_feature(view.FILTER)
        else:
            serializer = None

        out = TreeMap()

        # With neither a view nor an explicit map there is nothing to
        # convert; do not try to iterate over None.
        for key, value in six.iteritems(filters_map or {}):

            # Inclusion or exclusion? (startswith also copes with an
            # empty key, where key[0] would raise IndexError)
            if key.startswith('-'):
                key = key[1:]
                category = '_exclude'
            else:
                category = '_include'

            # for relational filters, separate out relation path part
            if '|' in key:
                rel, key = key.split('|')
                rel = rel.split('.')
            else:
                rel = None

            terms = key.split('.')
            # Last part could be operator, e.g. "events.capacity.gte"
            if len(terms) > 1 and terms[-1] in self.VALID_FILTER_OPERATORS:
                operator = terms.pop()
            else:
                operator = None

            # All operators except 'range' and 'in' should have one value
            if operator == 'range':
                value = value[:2]
                # an empty bound turns the range into a one-sided filter
                if value[0] == '':
                    operator = 'lte'
                    value = value[1]
                elif value[1] == '':
                    operator = 'gte'
                    value = value[0]
            elif operator == 'in':
                # no-op: i.e. accept `value` as an arbitrarily long list
                pass
            elif operator in self.VALID_FILTER_OPERATORS:
                value = value[0]
                if (operator == 'isnull'
                        and isinstance(value, six.string_types)):
                    value = is_truthy(value)
                elif operator == 'eq':
                    operator = None

            if serializer:
                s = serializer

                if rel:
                    # get related serializer
                    model_fields, serializer_fields = serializer.resolve(rel)
                    s = serializer_fields[-1]
                    s = getattr(s, 'serializer', s)
                    rel = [Meta.get_query_name(f) for f in model_fields]

                # perform model-field resolution
                model_fields, serializer_fields = s.resolve(terms)
                field = serializer_fields[-1] if serializer_fields else None
                # if the field is a boolean,
                # coerce the value
                if field and isinstance(
                        field,
                    (serializers.BooleanField, serializers.NullBooleanField)):
                    value = is_truthy(value)
                key = '__'.join([Meta.get_query_name(f) for f in model_fields])

            else:
                key = '__'.join(terms)

            if operator:
                key += '__%s' % operator

            # insert into output tree
            path = rel if rel else []
            path += [category, key]
            out.insert(path, value)
        return out