def get_id_fields(self):
    """Return the column names needed to build an ID-only relation list.

    The list starts with the primary-key column name (``attname``) of this
    serializer's model, followed by the ``attname`` of every
    ``models.ForeignKey`` on that model. The result is used to build a
    ``Prefetch`` with a ``.only()`` queryset when the relation is not
    sideloaded but a list of IDs must still be returned.
    """
    meta = Meta(self.get_model())
    pk_column = meta.get_pk_field().attname
    # Reaching this method means the relation is a many-relation back to
    # its parent. Django wants the FK pointing at the parent, but inferring
    # exactly which FK that is is non-trivial, so every FK column is
    # included instead.
    # TODO: all non-nullable fields may also be needed, otherwise Django
    # may issue another query.
    fk_columns = [
        candidate.attname
        for candidate in meta.get_fields()
        if isinstance(candidate, models.ForeignKey)
    ]
    return [pk_column] + fk_columns
def update(self, instance, validated_data):
    """Apply ``validated_data`` to ``instance`` and save, supporting nested writes.

    Every key in ``validated_data`` is assigned onto ``instance``. When the
    key names a model relation (its field has a truthy ``related_model``)
    and the value is a dict, the dict is applied to the currently related
    object via ``nested_update``, and whatever objects that returns are
    saved too. All saves run inside one ``transaction.atomic()`` block.

    Returns:
        The updated ``instance``.

    Raises:
        exceptions.ValidationError: if any assignment or save fails. A
            ``exceptions.ValidationError`` raised inside is re-raised as-is
            (re-wrapping it would flatten its detail into a string); any
            other exception is wrapped in one.
    """
    meta = Meta(instance)
    to_save = [instance]
    # Unlike `.create()`, many-to-many relationships need no special
    # handling here: during an update the instance already has a pk for
    # the relationships to be associated with.
    try:
        with transaction.atomic():
            for attr, value in validated_data.items():
                # Only the field lookup is guarded: `meta.get_field` raising
                # AttributeError means `attr` is not a model field, so it is
                # assigned directly. Errors from the nested write itself must
                # not be mistaken for that case.
                try:
                    related_model = meta.get_field(attr).related_model
                except AttributeError:
                    related_model = None

                if related_model and isinstance(value, dict):
                    # nested dictionary on a has-one relationship: apply the
                    # updates to the current related value
                    to_save.extend(nested_update(instance, attr, value))
                else:
                    # plain attribute or normal relationship update
                    setattr(instance, attr, value)
            for obj in to_save:
                obj.save()
    except exceptions.ValidationError:
        raise
    except Exception as e:
        raise exceptions.ValidationError(e)
    return instance
def _build_requested_prefetches(self, prefetches, requirements, model,
                                fields, filters, is_root_level):
    """Populate ``prefetches`` with one ``Prefetch`` per requested relation.

    Walks ``fields``. For every field that resolves to a
    ``serializers.ModelSerializer`` (directly, through a
    ``DynamicRelationField``'s ``serializer``, or through a
    ``ListSerializer``'s ``child``), a nested queryset is built with
    ``self._build_queryset`` and stored in ``prefetches`` under the
    field's source.

    Arguments:
        prefetches: dict of source -> ``Prefetch``; mutated in place.
        requirements: nested requirements; the entry for every source that
            gets prefetched here is popped.
        model: model class that owns ``fields``.
        fields: mapping of field name -> serializer field.
        filters: nested filters; the sub-tree keyed by each relation's
            query name is handed to the nested queryset.
        is_root_level: whether ``fields`` belong to the root serializer;
            combined with the ``'admin'`` format it forces a prefetch.

    Returns:
        The same ``prefetches`` dict.

    Raises:
        ValidationError: if a relation's source contains a dot.
    """
    meta = Meta(model)
    for field_name, declared in six.iteritems(fields):
        target = declared
        if isinstance(target, dfields.DynamicRelationField):
            target = target.serializer
        if isinstance(target, serializers.ListSerializer):
            target = target.child
        if not isinstance(target, serializers.ModelSerializer):
            continue

        source = target.source or field_name
        if '.' in source:
            raise ValidationError('Nested relationship values '
                                  'are not supported')
        # '*' is a custom getter/setter; a source already in `prefetches`
        # is shared with an earlier field. Neither needs a new prefetch.
        if source == '*' or source in prefetches:
            continue

        base_queryset = getattr(declared, 'queryset', None)
        if callable(base_queryset):
            base_queryset = base_queryset(target)

        id_only = getattr(target, 'id_only', lambda: False)()
        remote = meta.is_field_remote(source)
        admin_root = self.view.get_format() == 'admin' and is_root_level

        # A custom queryset, a full representation, a remote field or the
        # admin root each require an explicit prefetch; a local id-only
        # relation without a custom queryset does not.
        needs_prefetch = (base_queryset is not None or not id_only
                          or remote or admin_root)
        if not needs_prefetch:
            continue

        # Popping the source during explicit prefetch construction keeps
        # the implicit prefetches built afterwards from conflicting.
        nested_requirements = requirements.pop(source, None)
        relation_query_name = Meta.get_query_name(declared.model_field)
        nested_queryset = self._build_queryset(
            serializer=target,
            filters=filters.get(relation_query_name, {}),
            queryset=base_queryset,
            requirements=nested_requirements,
        )
        # Only one prefetch per source, even when several fields point at
        # it. This can break in some cases, mostly on writes where all
        # fields are used by default.
        prefetches[source] = Prefetch(source, queryset=nested_queryset)
    return prefetches
def prepare_value(self, instance):
    """Return the human-readable choice label for ``instance``'s value.

    The model field is looked up on ``self.parent_model`` under
    ``self.source`` (or ``self.field_name`` when no source is set). The
    raw value read from ``instance`` is then mapped to its label.

    Returns:
        The label for the stored value, or ``None`` when the value is not
        among the field's choices.
    """
    source = self.source or self.field_name
    model_field = Meta(self.parent_model).get_field(source)
    value = getattr(instance, source)
    # `flatchoices` expands grouped choices ("group", [(value, label), ...])
    # into plain (value, label) pairs. `dict(choices)` would map the group
    # names to lists and never find the values nested inside a group.
    return dict(model_field.flatchoices).get(value)
def resolve(self, serializer, query, view=None):
    """Translate an API ordering term into an ORM lookup path.

    Arguments:
        serializer: a serializer instance, e.g. UserSerializer
        query: an API field path, e.g. "location.name"
        view: an optional view instance, e.g. UserViewSet

    Returns:
        The model field path joined with double underscores,
        e.g. "location__real_name".

    Raises:
        ValidationError: if the ordering is not allowed for ``view``, or
            (via ``serializer.resolve``) if it cannot be rewritten.
    """
    if self._is_allowed_query(query, view):
        resolved_fields = serializer.resolve(query, sort=True)[0]
        return '__'.join(
            [Meta.get_query_name(model_field)
             for model_field in resolved_fields])
    raise ValidationError('Invalid sort option: %s' % query)
def _build_implicit_prefetches(self, model, prefetches, requirements):
    """Add prefetches for relations required internally but not requested.

    For every entry in ``requirements`` whose value is a non-empty,
    non-string sub-tree, a ``Prefetch`` is stored in ``prefetches`` under
    that source. Its queryset is built by
    ``self._build_implicit_queryset`` when the related model can be
    determined, and is ``None`` otherwise.

    Returns:
        The same ``prefetches`` dict, mutated in place.
    """
    meta = Meta(model)
    for source, remainder in six.iteritems(requirements):
        # Empty sub-trees and plain strings carry no nested requirements,
        # so there is nothing further to prefetch for them.
        if remainder and not isinstance(remainder, six.string_types):
            related_model = get_related_model(meta.get_field(source))
            if related_model:
                queryset = self._build_implicit_queryset(
                    related_model, remainder)
            else:
                queryset = None
            prefetches[source] = Prefetch(source, queryset=queryset)
    return prefetches
def test_delete(self):
    """Check that DELETE removes the instance, once per renderer format.

    Returns early when ``self.view`` is ``None`` or the view does not list
    ``'delete'`` in ``http_method_names``. For each renderer a fresh
    instance is created and deleted through ``?format=<renderer.format>``.
    The response status must be below 400, and no row with that pk may
    remain.
    """
    view = self.view
    if view is None or 'delete' not in view.http_method_names:
        return

    for renderer in view.get_renderers():
        created = self.create_instance()
        url = '%s?format=%s' % (self.get_url(created.pk), renderer.format)
        response = self.api_client.delete(url)
        status = response.status_code
        self.assertTrue(
            status < 400,
            'DELETE %s failed with %d: %s' % (
                url, status, response.content.decode('utf-8')))

        model = self.get_model()
        model_name = Meta(model).get_name()
        deleted_pk = created.pk
        still_there = model.objects.filter(pk=deleted_pk).exists()
        self.assertFalse(
            still_there,
            'DELETE %s succeeded but instance "%s.%s" still exists' % (
                url, model_name, deleted_pk))
def test_create(self):
    """Check that POST creates an instance, once per renderer format.

    Returns early when ``self.view`` is ``None`` or the view does not list
    ``'post'`` in ``http_method_names``. For each renderer,
    ``self.get_post_params()`` is POSTed as JSON to
    ``?format=<renderer.format>``, and the response status must be below
    400. For the ``json`` format the body is decoded, the new pk is read
    from ``content[<serializer name>][<pk field name>]``, and a row with
    that pk must exist.
    """
    view = self.view
    if view is None:
        return

    if 'post' not in view.http_method_names:
        return

    # These lookups do not depend on the renderer, so do them once.
    model = self.get_model()
    model_name = Meta(model).get_name()
    serializer = self.serializer_class()
    name = serializer.get_name()
    pk_field = serializer.get_field('pk')
    if pk_field:
        pk_field = pk_field.field_name

    for renderer in view.get_renderers():
        fmt = renderer.format
        url = '%s?format=%s' % (self.get_url(), fmt)
        data = self.get_post_params()
        response = self.api_client.post(url,
                                        content_type='application/json',
                                        data=json.dumps(data))
        content = response.content.decode('utf-8')
        self.assertTrue(
            response.status_code < 400,
            'POST %s failed with %d:\n%s' %
            (url, response.status_code, content))

        if fmt != 'json':
            # Only a JSON body can be indexed to find the new pk; other
            # formats return text, and indexing a str with a str key would
            # raise TypeError.
            continue

        content = json.loads(content)
        self.assertTrue(name in content, content)
        pk = content[name][pk_field]
        self.assertTrue(
            model.objects.filter(pk=pk).exists(),
            'POST %s succeeded but instance '
            '"%s.%s" does not exist' %
            (url, model_name, pk))
def resolve(self, query):
    """Resolve an API field path into model fields and serializer fields.

    Arguments:
        query: an API field path in dot-notation, e.g.
            "creator.location_name", or a sequence of path parts.

    Returns:
        A ``(model_fields, api_fields)`` tuple, e.g.::

            (
                [Blog._meta.fields.user,
                 User._meta.fields.location,
                 Location._meta.fields.name],
                [DynamicRelationField(source="user"),
                 DynamicCharField(source="location.name")]
            )

        The lists need not have the same length, because one API field
        can reference a nested model field path through its ``source``.
        When the last API field is a method field (``source == '*'``), its
        model field is unknown and is omitted.

    Raises:
        ValidationError: if the query is invalid, e.g. it references an
            undefined field or traverses something that is not an API or
            model relation.
    """  # noqa
    if not isinstance(query, six.string_types):
        parts = query
        query = '.'.join(query)
    else:
        parts = query.split('.')

    model_fields = []
    api_fields = []
    serializer = self

    model = serializer.get_model()
    resource_name = serializer.get_name()
    meta = Meta(model)
    api_name = parts[0]
    other = parts[1:]

    try:
        api_field = serializer.get_field(api_name)
    except Exception:
        # unknown API field; reported below where it matters
        api_field = None

    if other:
        # Traversal: the head must be an API relation, and the rest of the
        # path is resolved against the related serializer.
        if not (api_field and isinstance(api_field, DynamicRelationField)):
            raise ValidationError({
                api_name:
                'Could not resolve "%s": '
                '"%s.%s" is not an API relation' % (
                    query,
                    resource_name,
                    api_name
                )
            })

        source = api_field.source or api_name
        related = api_field.serializer_class()
        other = '.'.join(other)
        model_fields, api_fields = related.resolve(other)

        try:
            model_field = meta.get_field(source)
        except AttributeError:
            raise ValidationError({
                api_name:
                'Could not resolve "%s": '
                '"%s.%s" is not a model relation' % (
                    query,
                    meta.get_name(),
                    source
                )
            })

        model_fields.insert(0, model_field)
        api_fields.insert(0, api_field)
    else:
        if api_name == 'pk':
            # pk is an alias for the id field
            model_field = meta.get_pk_field()
            model_fields.append(model_field)
            if api_field:
                # the pk field may not exist
                # on the serializer
                api_fields.append(api_field)
        else:
            if not api_field:
                raise ValidationError({
                    api_name:
                    'Could not resolve "%s": '
                    '"%s.%s" is not an API field' % (
                        query,
                        resource_name,
                        api_name
                    )
                })

            api_fields.append(api_field)

            if api_field.source == '*':
                # a method field was requested, model field is unknown
                return (model_fields, api_fields)

            source = api_field.source or api_name
            if '.' in source:
                # Dotted source: every part but the last must be a model
                # relation; follow each to its related model.
                fields = source.split('.')
                for field in fields[:-1]:
                    related_model = None
                    try:
                        model_field = meta.get_field(field)
                        related_model = model_field.related_model
                    except Exception:
                        # reported by the `not related_model` check below
                        pass

                    if not related_model:
                        raise ValidationError({
                            api_name:
                            'Could not resolve "%s": '
                            '"%s.%s" is not a model relation' % (
                                query,
                                meta.get_name(),
                                field
                            )
                        })
                    model = related_model
                    meta = Meta(model)
                    model_fields.append(model_field)
                field = fields[-1]
                try:
                    model_field = meta.get_field(field)
                except Exception:
                    raise ValidationError({
                        api_name:
                        'Could not resolve: "%s", '
                        '"%s.%s" is not a model field' % (
                            query,
                            meta.get_name(),
                            field
                        )
                    })
                model_fields.append(model_field)
            else:
                try:
                    model_field = meta.get_field(source)
                except Exception:
                    raise ValidationError({
                        api_name:
                        'Could not resolve "%s": '
                        '"%s.%s" is not a model field' % (
                            query,
                            meta.get_name(),
                            source
                        )
                    })
                model_fields.append(model_field)

    return (model_fields, api_fields)
def _build_queryset(self, serializer=None, filters=None, queryset=None,
                    requirements=None):
    """Build a queryset that pulls in all data required by this request.

    Handles nested prefetching of related data and deferring fields
    at the queryset level.

    Arguments:
        serializer: An optional serializer to use a base for the queryset.
            If no serializer is passed, the `get_serializer` method will
            be used to initialize the base serializer for the viewset.
        filters: An optional TreeMap of nested filters.
        queryset: An optional base queryset.
        requirements: An optional TreeMap of nested requirements.

    Returns:
        The queryset with field limiting, request filters, the
        serializer's optional ``filter_queryset``, and prefetches applied.
        If ``serializer.get_model()`` is falsy, ``queryset`` is returned
        as-is.

    Raises:
        ValidationError: if the requested filters cannot be applied.
    """
    is_root_level = False
    if serializer:
        if queryset is None:
            queryset = serializer.Meta.model.objects
    else:
        serializer = self.view.get_serializer()
        is_root_level = True

    model = serializer.get_model()

    if not model:
        return queryset

    meta = Meta(model)

    prefetches = {}

    # build a nested Prefetch queryset
    # based on request parameters and serializer fields
    fields = serializer.fields

    if requirements is None:
        requirements = TreeMap()

    self._get_implicit_requirements(fields, requirements)

    if filters is None:
        filters = self._get_requested_filters()

    # build nested Prefetch queryset
    # (this pops the prefetched sources out of `requirements`)
    self._build_requested_prefetches(prefetches, requirements, model,
                                     fields, filters, is_root_level)

    # build remaining prefetches out of internal requirements
    # that are not already covered by request requirements
    self._build_implicit_prefetches(model, prefetches, requirements)

    # Use the requirements at this level to limit the selected fields.
    # Only done when '*' (the entire fieldset) is not required, the
    # request is neither an update nor a delete, and the format is not
    # 'admin'.
    is_gui = self.view.get_format() == 'admin'
    if ('*' not in requirements and not self.view.is_update() and
            not self.view.is_delete() and not is_gui):
        id_fields = getattr(serializer, 'get_id_fields', lambda: [])()
        # only include local model fields
        only = [
            field for field in set(id_fields + list(requirements.keys()))
            if meta.is_field(field) and not meta.is_field_remote(field)
        ]
        queryset = queryset.only(*only)

    # add request filters
    query = self._filters_to_query(filters)
    if query:
        # Convert internal django ValidationError to
        # APIException-based one in order to resolve validation error
        # from 500 status code to 400.
        try:
            queryset = queryset.filter(query)
        except InternalValidationError as e:
            raise ValidationError(
                dict(e) if hasattr(e, 'error_dict') else list(e))
        except Exception as e:
            # Some other Django error in parsing the filter.
            # Very likely a bad query, so throw a ValidationError.
            # Python 3 exceptions have no `.message`, which left the
            # error detail empty; fall back to the exception text.
            err_msg = getattr(e, 'message', None) or six.text_type(e)
            raise ValidationError(err_msg)

    # A serializer can have this optional function
    # to dynamically apply additional filters on
    # any queries that will use that serializer
    # You could use this to have (for example) different
    # serializers for different subsets of a model or to
    # implement permissions which work even in sideloads
    if hasattr(serializer, 'filter_queryset'):
        queryset = serializer.filter_queryset(queryset)

    # add prefetches and remove duplicates if necessary
    prefetch = prefetches.values()
    queryset = queryset.prefetch_related(*prefetch)
    if has_joins(queryset) or not is_root_level:
        queryset = queryset.distinct()

    if self.DEBUG:
        queryset._using_prefetches = prefetches
    return queryset
def _get_requested_filters(self, **kwargs):
    """Convert 'filters' query params into a TreeMap for building Q objects.

    Each filter key has the form ``[-][rel.path|]field.path[.operator]``:

    * a leading ``-`` marks an exclusion;
    * an optional ``|``-separated prefix names a relation path;
    * a trailing part found in ``self.VALID_FILTER_OPERATORS`` selects the
      lookup operator.

    When a view is available, keys are resolved through the view's
    serializer into ORM lookup paths.

    Keyword Arguments:
        filters_map: optional mapping of filter key -> sequence of values.
            When omitted or empty, it is read from the view's FILTER
            request feature.

    Returns:
        A TreeMap whose leaves sit under ``'_include'`` and ``'_exclude'``
        (nested under the relation path, when one was given), which can
        be used like::

            result = self._get_requested_filters()
            q = Q(**result['_include']) & ~Q(**result['_exclude'])
    """
    filters_map = kwargs.get('filters_map')

    view = getattr(self, 'view', None)
    if view:
        serializer_class = view.get_serializer_class()
        serializer = serializer_class()
        if not filters_map:
            filters_map = view.get_request_feature(view.FILTER)
    else:
        serializer = None

    # NullBooleanField was removed from recent DRF releases; fall back to
    # BooleanField so the isinstance check below keeps working there.
    boolean_field_types = (
        serializers.BooleanField,
        getattr(serializers, 'NullBooleanField', serializers.BooleanField),
    )

    out = TreeMap()

    # With no view and no explicit map there is nothing to parse.
    for key, value in six.iteritems(filters_map or {}):

        # Inclusion or exclusion? (startswith also tolerates an empty key)
        if key.startswith('-'):
            key = key[1:]
            category = '_exclude'
        else:
            category = '_include'

        # for relational filters, separate out relation path part
        if '|' in key:
            rel, key = key.split('|')
            rel = rel.split('.')
        else:
            rel = None

        terms = key.split('.')
        # Last part could be operator, e.g. "events.capacity.gte"
        if len(terms) > 1 and terms[-1] in self.VALID_FILTER_OPERATORS:
            operator = terms.pop()
        else:
            operator = None

        # All operators except 'range' and 'in' should have one value
        if operator == 'range':
            value = value[:2]
            # an empty bound turns the range into a one-sided comparison
            if value[0] == '':
                operator = 'lte'
                value = value[1]
            elif value[1] == '':
                operator = 'gte'
                value = value[0]
        elif operator == 'in':
            # no-op: i.e. accept `value` as an arbitrarily long list
            pass
        elif operator in self.VALID_FILTER_OPERATORS:
            value = value[0]
            if (operator == 'isnull' and
                    isinstance(value, six.string_types)):
                value = is_truthy(value)
            elif operator == 'eq':
                operator = None

        if serializer:
            s = serializer

            if rel:
                # get related serializer
                model_fields, serializer_fields = serializer.resolve(rel)
                s = serializer_fields[-1]
                s = getattr(s, 'serializer', s)
                rel = [Meta.get_query_name(f) for f in model_fields]

            # perform model-field resolution
            model_fields, serializer_fields = s.resolve(terms)
            field = serializer_fields[-1] if serializer_fields else None
            # if the field is a boolean,
            # coerce the value
            if field and isinstance(field, boolean_field_types):
                value = is_truthy(value)
            key = '__'.join(
                [Meta.get_query_name(f) for f in model_fields]
            )

        else:
            key = '__'.join(terms)

        if operator:
            key += '__%s' % operator

        # insert into output tree (build a new list rather than extending
        # `rel` in place)
        path = (rel or []) + [category, key]
        out.insert(path, value)
    return out