Example #1
0
    def _build_implicit_prefetches(self, model, prefetches, requirements):
        """Add a prefetch entry for every requirement that has nested parts.

        Each item of ``requirements`` pairs a source name with its remaining
        (nested) requirements. Items whose remainder is falsy or a plain
        string need nothing further and are skipped. For every other item the
        related model is looked up. When one is found, a queryset for the
        remainder is built with ``self._build_implicit_queryset``; otherwise
        the queryset is ``None``. The result of ``self._create_prefetch`` is
        stored in ``prefetches`` under the source name.

        Args:
            model: model that source names are resolved against via
                ``get_model_field``.
            prefetches: mapping that is updated in place.
            requirements: mapping of source name -> remaining requirements.

        Returns:
            The same ``prefetches`` object that was passed in.
        """
        for field_name, nested in six.iteritems(requirements):
            has_nested = nested and not isinstance(nested, six.string_types)
            if not has_nested:
                # nothing below this level needs prefetching
                continue

            related = get_model_field(model, field_name)
            target_model = get_related_model(related)
            if target_model:
                qs = self._build_implicit_queryset(target_model, nested)
            else:
                qs = None

            prefetches[field_name] = self._create_prefetch(field_name, qs)
        return prefetches
Example #2
0
    def bind(self, *args, **kwargs):
        """Bind to the parent serializer.

        Binding happens at most once; later calls return immediately. After
        the parent class binds, the model field named by ``self.source`` is
        looked up on the parent serializer's ``Meta.model``. That lookup is
        best effort and yields ``None`` on failure. The result is used to
        infer ``required`` and ``allow_null`` unless those were passed
        explicitly (present in ``self.kwargs``), and is stored on
        ``self.model_field``.
        """
        if self.bound:  # Prevent double-binding
            return
        super(DynamicRelationField, self).bind(*args, **kwargs)
        self.bound = True
        parent_model = getattr(self.parent.Meta, 'model', None)

        remote = is_field_remote(parent_model, self.source)

        try:
            model_field = get_model_field(parent_model, self.source)
        except Exception:
            # model field may not be available for m2o fields with no
            # related_name. `Exception` rather than a bare `except:` so that
            # KeyboardInterrupt / SystemExit still propagate.
            model_field = None

        # Infer `required`: remote fields, and model fields that have a
        # default or are nullable, are marked not required.
        if 'required' not in self.kwargs and (
                remote or (model_field and
                           (model_field.has_default() or model_field.null))):
            self.required = False
        # Infer `allow_null` from the model field's `null` flag (if any).
        if 'allow_null' not in self.kwargs and getattr(model_field, 'null',
                                                       False):
            self.allow_null = True

        self.model_field = model_field
Example #3
0
def nested_update(instance, key, value, objects=None):
    """Apply ``value`` to the related object stored at ``instance.<key>``.

    If ``instance.<key>`` is missing or falsy, a related object is created
    with ``related_model.objects.create(**value)`` and assigned to
    ``instance.<key>``. Otherwise each item of ``value`` is set on the
    existing object, recursing into dict values, and the updated object is
    appended to ``objects``.

    Values whose string form is "true"/"false" (case-insensitive) are
    normalized to the strings "True"/"False" first.

    Arguments:
        instance: object that owns the relation.
        key: attribute name of the relation on ``instance``.
        value: mapping of attribute names to new values; dict values are
            applied recursively to the corresponding related object.
        objects: list collecting updated (not newly created) objects. A new
            list is used when omitted; a passed-in list is extended in place.

    Returns:
        ``objects``, including objects updated by the recursive calls.

    Raises:
        exceptions.ValidationError: if the relation does not exist and
            ``key`` cannot be resolved to a related model.
    """
    # Test for None rather than truthiness: the recursive call below passes
    # the (possibly still empty) list down, and it must be reused. Otherwise
    # deeper updates land in a throwaway list and are lost.
    if objects is None:
        objects = []
    nested = getattr(instance, key, None)

    def fix(x):
        s = str(x).lower()
        if s == "true":
            return "True"
        if s == "false":
            return "False"
        return x

    value = {k: fix(v) for k, v in value.items()}
    if not nested:
        # object does not exist, try to create it
        try:
            field = get_model_field(instance, key)
            related_model = get_related_model(field)
        except Exception:
            raise exceptions.ValidationError('Invalid relationship: %s' % key)
        else:
            nested = related_model.objects.create(**value)
            setattr(instance, key, nested)
    else:
        # object exists, perform a nested update
        for k, v in value.items():
            if isinstance(v, dict):
                nested_update(nested, k, v, objects)
            else:
                setattr(nested, k, v)
        objects.append(nested)
    return objects
Example #4
0
 def model_field(self):
     """Return the model field backing this field, or None.

     The field is looked up on ``self.parent_model`` using ``self.source``,
     falling back to ``self.field_name``. The result is cached on
     ``self._model_field`` after the first call. A failed lookup is cached
     as None.
     """
     if not hasattr(self, '_model_field'):
         try:
             source = self.source or self.field_name
             self._model_field = get_model_field(self.parent_model, source)
         except Exception:
             # Best effort: any lookup failure means "no model field".
             # Not a bare `except:`, so KeyboardInterrupt / SystemExit
             # still propagate.
             self._model_field = None
     return self._model_field
Example #5
0
 def sort_field(self):
     """Return the model field used when sorting on this field.

     Without ``self.sort_by`` this is simply ``self.model_field``. Otherwise
     ``self.sort_by`` is resolved against ``self.parent_model`` once, and the
     result is cached on ``self._sort_field``. It is cached as None if the
     lookup raised AttributeError.
     """
     if not self.sort_by:
         return self.model_field
     if hasattr(self, '_sort_field'):
         return self._sort_field
     try:
         found = get_model_field(self.parent_model, self.sort_by)
     except AttributeError:
         found = None
     self._sort_field = found
     return self._sort_field
Example #6
0
    def get_field(self, field_name):
        """Return the API field registered under ``field_name``.

        Lookups go through ``self.get_all_fields()`` so that deferred fields
        are found as well. The special name ``'pk'`` resolves to the primary
        key field:

        1. a value previously cached on ``meta._pk`` is returned directly;
        2. if ``meta.primary_key`` is set, the field with that name is used;
        3. otherwise the first field that either has a truthy
           ``primary_key`` attribute itself, or whose model field (looked up
           via ``f.source`` or the field name) has ``primary_key`` set;
        4. if none of the above produced a field, a field named ``'id'``.

        A resolved ``'pk'`` field is cached on ``meta._pk``.

        Raises:
            ValidationError: if ``field_name`` is not an API field (for
                ``'pk'``: if no primary key field could be found).
        """
        # it might be deferred
        fields = self.get_all_fields()
        if field_name == 'pk':
            meta = self.get_meta()
            if hasattr(meta, '_pk'):
                return meta._pk

            field = None
            model = self.get_model()
            primary_key = getattr(meta, 'primary_key', None)

            if primary_key:
                field = fields.get(primary_key)
            else:
                for n, f in fields.items():
                    # try to use model fields
                    try:
                        # The candidate field may declare itself as the
                        # primary key directly.
                        if getattr(f, 'primary_key', False):
                            field = f
                            break

                        model_field = get_model_field(model, f.source or n)

                        if model_field.primary_key:
                            field = f
                            break
                    except Exception:
                        # Best effort: fields whose model field cannot be
                        # resolved (or lacks `primary_key`) are skipped.
                        pass

            if not field:
                # fall back to a field called ID
                if 'id' in fields:
                    field = fields['id']

            if field:
                meta._pk = field
                return field
        else:
            if field_name in fields:
                field = fields[field_name]
                return field

        raise ValidationError(
            {field_name: '"%s" is not an API field' % field_name})
Example #7
0
    def generate_query_key(self, serializer):
        """Get the key that can be passed to Django's filter method.

        To account for serializer field name rewrites, this method
        translates serializer field names to model field names
        by inspecting `serializer`.

        For example, a query like `filter{users.events}` would be
        returned as `users__events`.

        Arguments:
            serializer: A DRF serializer

        Returns:
            A ``(filter_key, field)`` tuple. ``filter_key`` is the
            ``__``-joined path, with ``self.operator`` appended when set.
            ``field`` is the serializer field resolved for the last
            non-``pk`` path component, or None if there was none.

        Raises:
            ValidationError: if a path component is not a field of the
                serializer being inspected, or if a non-terminal, non-``pk``
                component has no nested serializer to recurse into. Errors
                from ``get_model_field`` propagate unchanged.
        """
        rewritten = []
        last = len(self.field) - 1
        s = serializer
        field = None
        for i, field_name in enumerate(self.field):
            # `pk` needs no field lookup, so handle it before touching
            # `s.fields` or the expensive get_all_fields() fallback below.
            if field_name == 'pk':
                rewritten.append('pk')
                continue

            # Note: .fields can be empty for related serializers that aren't
            # sideloaded. Fields that are deferred also won't be present.
            # If field name isn't in serializer.fields, get full list from
            # get_all_fields() method. This is somewhat expensive, so only do
            # this if we have to.
            fields = s.fields
            if field_name not in fields:
                fields = getattr(s, 'get_all_fields', lambda: {})()

            if field_name not in fields:
                raise ValidationError("Invalid filter field: %s" % field_name)

            field = fields[field_name]

            # If get_all_fields() was used above, field could be unbound,
            # and field.source would be None; fall back to the field name.
            model_field_name = field.source or field_name
            model_field = get_model_field(s.get_model(), model_field_name)
            # For remote fields, strip off '_set' for filtering. This is a
            # weird Django inconsistency.
            if isinstance(model_field, RelatedObject):
                model_field_name = model_field.field.related_query_name()

            rewritten.append(model_field_name)

            if i == last:
                break

            # Recurse into nested field
            s = getattr(field, 'serializer', None)
            if isinstance(s, serializers.ListSerializer):
                s = s.child
            if not s:
                raise ValidationError("Invalid nested filter field: %s" %
                                      field_name)

        if self.operator:
            rewritten.append(self.operator)

        return ('__'.join(rewritten), field)
Example #8
0
    def generate_query_key(self, serializer):
        """Get the key that can be passed to Django's filter method.

        To account for serializer field name rewrites, this method
        translates serializer field names to model field names
        by inspecting `serializer`.

        For example, a query like `filter{users.events}` would be
        returned as `users__events`.

        Arguments:
            serializer: A DRF serializer

        Returns:
            A filter key: the ``__``-joined path, with ``self.operator``
            appended when set.

        Raises:
            ValidationError: if a path component is not a field of the
                serializer being inspected, or if a non-terminal, non-``pk``
                component has no nested serializer to recurse into. Errors
                from ``get_model_field`` propagate unchanged.
        """
        rewritten = []
        last = len(self.field) - 1
        s = serializer
        for i, field_name in enumerate(self.field):
            # `pk` needs no field lookup, so handle it before touching
            # `s.fields` or the expensive get_all_fields() fallback below.
            if field_name == 'pk':
                rewritten.append('pk')
                continue

            # Note: .fields can be empty for related serializers that aren't
            # sideloaded. Fields that are deferred also won't be present.
            # If field name isn't in serializer.fields, get full list from
            # get_all_fields() method. This is somewhat expensive, so only do
            # this if we have to.
            fields = s.fields
            if field_name not in fields:
                fields = getattr(s, 'get_all_fields', lambda: {})()

            if field_name not in fields:
                raise ValidationError(
                    "Invalid filter field: %s" % field_name
                )

            field = fields[field_name]

            # If get_all_fields() was used above, field could be unbound,
            # and field.source would be None; fall back to the field name.
            model_field_name = field.source or field_name
            model_field = get_model_field(s.get_model(), model_field_name)
            # For remote fields, strip off '_set' for filtering. This is a
            # weird Django inconsistency.
            if isinstance(model_field, RelatedObject):
                model_field_name = model_field.field.related_query_name()

            rewritten.append(model_field_name)

            if i == last:
                break

            # Recurse into nested field
            s = getattr(field, 'serializer', None)
            if isinstance(s, serializers.ListSerializer):
                s = s.child
            if not s:
                raise ValidationError(
                    "Invalid nested filter field: %s" % field_name
                )

        if self.operator:
            rewritten.append(self.operator)

        return '__'.join(rewritten)