示例#1
0
def test_follow_field_source_forward_reverse(no_warnings):
    """Resolve field paths across a FK chain in both directions.

    Builds the chain FFS3 -> FFS2 -> FFS1 and follows paths forward (from FFS3)
    and in reverse (from FFS1). Paths end either on a concrete field or on a
    relation. For a relation, the expected result is the target model's
    primary-key field. Each resolved field is then mapped through AutoSchema
    and its OpenAPI ``type`` is checked.
    """
    class FFS1(models.Model):
        id = models.UUIDField(primary_key=True)
        field_bool = models.BooleanField()

    class FFS2(models.Model):
        ffs1 = models.ForeignKey(FFS1, on_delete=models.PROTECT)

    class FFS3(models.Model):
        id = models.CharField(primary_key=True, max_length=3)
        ffs2 = models.ForeignKey(FFS2, on_delete=models.PROTECT)
        field_float = models.FloatField()

    # (start model, traversal path, expected model field class, expected OpenAPI type)
    cases = [
        (FFS3, ['ffs2', 'ffs1', 'field_bool'], models.BooleanField, 'boolean'),
        (FFS1, ['ffs2', 'ffs3', 'field_float'], models.FloatField, 'number'),
        (FFS3, ['ffs2', 'ffs1'], models.UUIDField, 'string'),
        (FFS1, ['ffs2', 'ffs3'], models.CharField, 'string'),
    ]

    resolved = [follow_field_source(start, path) for start, path, _, _ in cases]

    for target, (_, _, field_cls, _) in zip(resolved, cases):
        assert isinstance(target, field_cls)

    auto_schema = AutoSchema()
    for target, (_, _, _, openapi_type) in zip(resolved, cases):
        assert auto_schema._map_model_field(target, None)['type'] == openapi_type
示例#2
0
    def resolve_filter_field(self, auto_schema, model, filterset_class,
                             filter_field):
        """Derive schema, description and enum for one filterset filter.

        The schema comes from, in order of preference:

        - a basic type annotated on the filter method's ``value`` parameter;
        - the model field reached by the filter's ``__``-separated
          ``field_name`` path;
        - ``self.map_filter_field`` as a fallback.

        Args:
            auto_schema: schema generator whose ``_map_model_field`` maps a
                resolved model field.
            model: model class the ``field_name`` path starts from.
            filterset_class: FilterSet class. The filter method is looked up
                on it when ``filter_field.method`` is given by name.
            filter_field: the filter instance to describe.

        Returns:
            ``(schema, description, enum)``. ``enum`` and ``description`` are
            popped from ``schema`` and returned separately. Explicit
            ``choices`` / ``help_text`` / ``label`` on the filter take
            precedence. Either may be ``None``.
        """
        if filter_field.method:
            # ``method`` may be a callable or the name of a filterset method;
            # getattr() with a callable would raise TypeError.
            if callable(filter_field.method):
                filter_method = filter_field.method
            else:
                filter_method = getattr(filterset_class, filter_field.method)

            # get_type_hints raises when an annotation cannot be resolved
            # (e.g. an undefined forward reference); treat that as "no hints".
            try:
                filter_method_hints = typing.get_type_hints(filter_method)
            except Exception:
                filter_method_hints = {}

            if 'value' in filter_method_hints and is_basic_type(
                    filter_method_hints['value']):
                schema = build_basic_type(filter_method_hints['value'])
            else:
                schema = self.map_filter_field(filter_field)
        else:
            model_field = None
            # a filter without field_name has no model path to follow
            if filter_field.field_name:
                path = filter_field.field_name.split('__')
                model_field = follow_field_source(model, path)

            if isinstance(model_field, models.Field):
                schema = auto_schema._map_model_field(model_field,
                                                      direction=None)
            else:
                schema = self.map_filter_field(filter_field)

        enum = schema.pop('enum', None)
        if 'choices' in filter_field.extra:
            enum = [c for c, _ in filter_field.extra['choices']]

        description = schema.pop('description', None)
        if filter_field.extra.get('help_text', None):
            description = filter_field.extra['help_text']
        elif filter_field.label is not None:
            description = filter_field.label

        return schema, description, enum
示例#3
0
    def resolve_filter_field(self, auto_schema, model, filterset_class,
                             field_name, filter_field):
        """Build the OpenAPI query parameter describing one filterset filter.

        Schema resolution:

        - ``OrderingFilter`` is always a string;
        - filters with a ``method`` use a basic type annotated on the method's
          ``value`` parameter, if present;
        - other filters follow the ``__``-separated ``field_name`` path to a
          model field;
        - anything unresolved falls back to ``self.map_filter_field``.

        CSV filters (``BaseCSVFilter``) are wrapped in an array with
        ``style='form'`` and ``explode=False``.

        Args:
            auto_schema: schema generator used to map resolved model fields.
            model: model class the ``field_name`` path starts from.
            filterset_class: FilterSet class; the filter method is looked up
                on it when ``filter_field.method`` is given by name.
            field_name: name of the query parameter to emit.
            filter_field: the filter instance to describe.

        Returns:
            the result of ``build_parameter_type`` for a QUERY parameter.
        """
        from django_filters.rest_framework import filters

        if isinstance(filter_field, filters.OrderingFilter):
            # only here filter_field.field_name is not the model field name/path
            schema = build_basic_type(OpenApiTypes.STR)
        elif filter_field.method:
            if callable(filter_field.method):
                filter_method = filter_field.method
            else:
                filter_method = getattr(filterset_class, filter_field.method)

            # unresolvable annotations make get_type_hints raise; treat that as
            # "no hints". Catch Exception rather than using a bare except so
            # KeyboardInterrupt/SystemExit still propagate.
            try:
                filter_method_hints = typing.get_type_hints(filter_method)
            except Exception:
                filter_method_hints = {}

            if 'value' in filter_method_hints and is_basic_type(
                    filter_method_hints['value']):
                schema = build_basic_type(filter_method_hints['value'])
            else:
                schema = self.map_filter_field(filter_field)
        else:
            model_field = None
            # a filter without field_name has no model path to follow
            if filter_field.field_name:
                path = filter_field.field_name.split('__')
                model_field = follow_field_source(model, path)

            if isinstance(model_field, models.Field):
                schema = auto_schema._map_model_field(model_field,
                                                      direction=None)
            else:
                schema = self.map_filter_field(filter_field)

        # explicit choices override any enum derived from the field;
        # sorting by str keeps mixed-type choices orderable
        enum = schema.pop('enum', None)
        if 'choices' in filter_field.extra:
            enum = [c for c, _ in filter_field.extra['choices']]
        if enum:
            schema['enum'] = sorted(enum, key=str)

        description = schema.pop('description', None)
        if filter_field.extra.get('help_text', None):
            description = filter_field.extra['help_text']
        elif filter_field.label is not None:
            description = filter_field.label

        if isinstance(filter_field, filters.BaseCSVFilter):
            schema = build_array_type(schema)
            explode = False
            style = 'form'
        else:
            explode = None
            style = None

        return build_parameter_type(name=field_name,
                                    required=filter_field.extra['required'],
                                    location=OpenApiParameter.QUERY,
                                    description=description,
                                    schema=schema,
                                    explode=explode,
                                    style=style)
示例#4
0
    def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
        """Describe each filter of the view's filterset as an OpenAPI query parameter.

        A deprecation warning is emitted first if the wrapped backend class
        still derives from ``SpectacularDjangoFilterBackendMixin``. Returns an
        empty list when the view has no model or no filterset class.
        """
        if issubclass(self.target_class, SpectacularDjangoFilterBackendMixin):
            warn(
                'DEPRECATED - Spectacular\'s DjangoFilterBackend is superseded by extension. you '
                'can simply restore this to the original class, extensions will take care of the '
                'rest.')

        view = auto_schema.view
        model = get_view_model(view)
        if not model:
            return []

        filterset_class = self.target.get_filterset_class(view, model.objects.none())
        if not filterset_class:
            return []

        result = []
        for param_name, filter_obj in filterset_class.base_filters.items():
            resolved_field = follow_field_source(model, filter_obj.field_name.split('__'))
            result.append(build_parameter_type(
                name=param_name,
                required=filter_obj.extra['required'],
                location=OpenApiParameter.QUERY,
                # fall back to the parameter name when the filter has no label
                description=param_name if filter_obj.label is None else filter_obj.label,
                schema=auto_schema._map_model_field(resolved_field, direction=None),
                enum=[value for value, _ in filter_obj.extra.get('choices', [])],
            ))
        return result
示例#5
0
    def get_schema_operation_parameters(self, view):
        """Describe each filter of the view's filterset as an OpenAPI query parameter.

        Returns an empty list when the view has no model or no filterset class.
        """
        model = get_view_model(view)
        if not model:
            return []

        filterset_class = self.get_filterset_class(view, model.objects.none())
        if not filterset_class:
            return []

        result = []
        for param_name, filter_obj in filterset_class.base_filters.items():
            resolved_field = follow_field_source(model, filter_obj.field_name.split('__'))
            result.append(build_parameter_type(
                name=param_name,
                required=filter_obj.extra['required'],
                location=OpenApiParameter.QUERY,
                # fall back to the parameter name when the filter has no label
                description=param_name if filter_obj.label is None else filter_obj.label,
                schema=view.schema._map_model_field(resolved_field, direction=None),
                enum=[value for value, _ in filter_obj.extra.get('choices', [])],
            ))
        return result
示例#6
0
    def _map_serializer_field(self, field, direction):
        """Map a DRF serializer field to an OpenAPI schema dict.

        Resolution order:

        1. an explicit override;
        2. a matching ``OpenApiSerializerFieldExtension``;
        3. a cascade of ``isinstance`` checks over DRF field classes.

        The cascade is order-sensitive: a more specific field class must be
        tested before any class it derives from (e.g. ``MultipleChoiceField``
        before ``ChoiceField``).

        Args:
            field: the serializer field instance to describe.
            direction: passed through to nested resolution. It is compared
                against ``'request'`` for ``FileField`` when
                ``COMPONENT_SPLIT_REQUEST`` is enabled.

        Returns:
            the schema dict with field metadata merged in via ``append_meta``,
            or ``None`` when a nested serializer could not be resolved to a
            component.
        """
        if has_override(field, 'field'):
            override = get_override(field, 'field')
            if is_basic_type(override):
                return build_basic_type(override)
            else:
                return self._map_serializer_field(override, direction)

        meta = self._get_serializer_field_meta(field)

        serializer_field_extension = OpenApiSerializerFieldExtension.get_match(
            field)
        if serializer_field_extension:
            schema = serializer_field_extension.map_serializer_field(
                self, direction)
            return append_meta(schema, meta)

        # nested serializer
        if isinstance(field, serializers.Serializer):
            component = self.resolve_serializer(field, direction)
            return append_meta(component.ref, meta) if component else None

        # nested serializer with many=True gets automatically replaced with ListSerializer
        if isinstance(field, serializers.ListSerializer):
            if is_serializer(field.child):
                component = self.resolve_serializer(field.child, direction)
                return append_meta(build_array_type(component.ref),
                                   meta) if component else None
            else:
                schema = self._map_serializer_field(field.child, direction)
                return append_meta(build_array_type(schema), meta)

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            schema = self._map_serializer_field(field.child_relation,
                                                direction)
            # remove hand-over initkwargs applying only to outer scope
            schema.pop('description', None)
            schema.pop('readOnly', None)
            return append_meta(build_array_type(schema), meta)

        if isinstance(field, serializers.PrimaryKeyRelatedField):
            # read_only fields do not have a Manager by design. go around and get field
            # from parent. also avoid calling Manager. __bool__ as it might be customized
            # to hit the database.
            if getattr(field, 'queryset', None) is not None:
                model_field = field.queryset.model._meta.pk
            else:
                # NOTE(review): this uses the pk of the serializer's own model,
                # not of the related model; the two pk types may differ — confirm.
                if isinstance(field.parent, serializers.ManyRelatedField):
                    model_field = field.parent.parent.Meta.model._meta.pk
                else:
                    model_field = field.parent.Meta.model._meta.pk

            # primary keys are usually non-editable (readOnly=True) and map_model_field correctly
            # signals that attribute. however this does not apply in the context of relations.
            schema = self._map_model_field(model_field, direction)
            schema.pop('readOnly', None)
            return append_meta(schema, meta)

        if isinstance(field, serializers.StringRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.SlugRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.HyperlinkedIdentityField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.HyperlinkedRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.MultipleChoiceField):
            return append_meta(
                build_array_type(build_choice_field(field.choices)), meta)

        if isinstance(field, serializers.ChoiceField):
            return append_meta(build_choice_field(field.choices), meta)

        if isinstance(field, serializers.ListField):
            # an _UnvalidatedField child carries no type information,
            # so the items stay unconstrained
            if isinstance(field.child, _UnvalidatedField):
                return append_meta(build_array_type({}), meta)
            # use the child's complete schema: copying only 'type'/'format'
            # dropped '$ref' for serializer children (giving {"type": None})
            # as well as enum, bounds and other constraints
            items = self._map_serializer_field(field.child, direction)
            if items is None:
                # unresolvable nested serializer; mirrors the ListSerializer branch
                return None
            return append_meta(build_array_type(items), meta)

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return append_meta(build_basic_type(OpenApiTypes.DATE), meta)

        if isinstance(field, serializers.DateTimeField):
            return append_meta(build_basic_type(OpenApiTypes.DATETIME), meta)

        if isinstance(field, serializers.TimeField):
            return append_meta(build_basic_type(OpenApiTypes.TIME), meta)

        if isinstance(field, serializers.EmailField):
            return append_meta(build_basic_type(OpenApiTypes.EMAIL), meta)

        if isinstance(field, serializers.URLField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.UUIDField):
            return append_meta(build_basic_type(OpenApiTypes.UUID), meta)

        if isinstance(field, serializers.DurationField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.IPAddressField):
            # TODO this might be a DRF bug. protocol is not propagated to serializer although it
            #  should have been. results in always 'both' (thus no format)
            if 'ipv4' == field.protocol.lower():
                schema = build_basic_type(OpenApiTypes.IP4)
            elif 'ipv6' == field.protocol.lower():
                schema = build_basic_type(OpenApiTypes.IP6)
            else:
                schema = build_basic_type(OpenApiTypes.STR)
            return append_meta(schema, meta)

        # DecimalField has multipleOf based on decimal_places
        if isinstance(field, serializers.DecimalField):
            if getattr(field, 'coerce_to_string',
                       api_settings.COERCE_DECIMAL_TO_STRING):
                content = {
                    **build_basic_type(OpenApiTypes.STR), 'format': 'decimal'
                }
            else:
                content = build_basic_type(OpenApiTypes.DECIMAL)

            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return append_meta(content, meta)

        if isinstance(field, serializers.FloatField):
            content = build_basic_type(OpenApiTypes.FLOAT)
            self._map_min_max(field, content)
            return append_meta(content, meta)

        if isinstance(field, serializers.IntegerField):
            content = build_basic_type(OpenApiTypes.INT)
            self._map_min_max(field, content)
            # int32 spans [-2147483648, 2147483647]; a bound outside that
            # range (in either direction) requires the int64 format
            bounds = (content.get('maximum', 0), content.get('minimum', 0))
            if any(not -2147483648 <= int(bound) <= 2147483647 for bound in bounds):
                content['format'] = 'int64'
            return append_meta(content, meta)

        if isinstance(field, serializers.FileField):
            if spectacular_settings.COMPONENT_SPLIT_REQUEST and direction == 'request':
                content = build_basic_type(OpenApiTypes.BINARY)
            else:
                use_url = getattr(field, 'use_url',
                                  api_settings.UPLOADED_FILES_USE_URL)
                content = build_basic_type(
                    OpenApiTypes.URI if use_url else OpenApiTypes.STR)
            return append_meta(content, meta)

        if isinstance(field, serializers.SerializerMethodField):
            method = getattr(field.parent, field.method_name)
            return append_meta(self._map_response_type_hint(method), meta)

        if anyisinstance(
                field,
            [serializers.BooleanField, serializers.NullBooleanField]):
            return append_meta(build_basic_type(OpenApiTypes.BOOL), meta)

        if isinstance(field, serializers.JSONField):
            return append_meta(build_basic_type(OpenApiTypes.OBJECT), meta)

        if anyisinstance(field,
                         [serializers.DictField, serializers.HStoreField]):
            content = build_basic_type(OpenApiTypes.OBJECT)
            if not isinstance(field.child, _UnvalidatedField):
                content['additionalProperties'] = self._map_serializer_field(
                    field.child, direction)
            return append_meta(content, meta)

        if isinstance(field, serializers.CharField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.ReadOnlyField):
            # direct source from the serializer
            assert field.source_attrs, f'ReadOnlyField "{field}" needs a proper source'
            target = follow_field_source(field.parent.Meta.model,
                                         field.source_attrs)

            if callable(target):
                schema = self._map_response_type_hint(target)
            elif isinstance(target, models.Field):
                schema = self._map_model_field(target, direction)
            else:
                assert False, f'ReadOnlyField target "{field}" must be property or model field'
            return append_meta(schema, meta)

        # DRF was not able to match the model field to an explicit SerializerField and therefore
        # used its generic fallback serializer field that simply wraps the model field.
        if isinstance(field, serializers.ModelField):
            schema = self._map_model_field(field.model_field, direction)
            return append_meta(schema, meta)

        warn(
            f'could not resolve serializer field "{field}". defaulting to "string"'
        )
        return append_meta(build_basic_type(OpenApiTypes.STR), meta)
示例#7
0
    def _map_serializer_field(self, method, field):
        """Map a DRF serializer field to an OpenAPI schema dict.

        A ``_spectacular_annotation`` attached to the field takes precedence.
        Otherwise a cascade of ``isinstance`` checks over DRF field classes
        decides. The cascade is order-sensitive: a more specific class must be
        tested before any class it derives from (e.g. ``MultipleChoiceField``
        before ``ChoiceField``).

        Args:
            method: passed through to ``resolve_serializer`` and recursive
                calls.
            field: the serializer field instance to describe.

        Returns:
            the schema dict. Unknown fields fall back to a string schema after
            a warning.
        """
        if hasattr(field, '_spectacular_annotation'):
            if is_basic_type(field._spectacular_annotation):
                return build_basic_type(field._spectacular_annotation)
            else:
                return self._map_serializer_field(
                    method, field._spectacular_annotation)

        # nested serializer
        if isinstance(field, serializers.Serializer):
            return self.resolve_serializer(method, field).ref

        # nested serializer with many=True gets automatically replaced with ListSerializer
        if isinstance(field, serializers.ListSerializer):
            return build_array_type(
                self.resolve_serializer(method, field.child).ref)

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return build_array_type(
                self._map_serializer_field(method, field.child_relation))

        if isinstance(field, serializers.PrimaryKeyRelatedField):
            # read_only fields do not have a Manager by design. go around and get field
            # from parent. also avoid calling Manager. __bool__ as it might be customized
            # to hit the database.
            if getattr(field, 'queryset', None) is not None:
                return self._map_model_field(field.queryset.model._meta.pk)
            parent = field.parent
            if isinstance(parent, serializers.ManyRelatedField):
                # the child of a many=True relation is bound to the
                # ManyRelatedField, which has no Meta; its parent is the serializer
                parent = parent.parent
            # use the model's actual pk field instead of assuming it is named "id"
            # NOTE(review): this is the pk of the serializer's own model, not of
            # the related model; the two pk types may differ — confirm.
            return self._map_model_field(parent.Meta.model._meta.pk)

        if isinstance(field, serializers.StringRelatedField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.SlugRelatedField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.HyperlinkedIdentityField):
            return build_basic_type(OpenApiTypes.URI)

        if isinstance(field, serializers.HyperlinkedRelatedField):
            return build_basic_type(OpenApiTypes.URI)

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return build_array_type(self._map_choicefield(field))

        if isinstance(field, serializers.ChoiceField):
            return self._map_choicefield(field)

        if isinstance(field, serializers.ListField):
            # an _UnvalidatedField child carries no type information,
            # so the items stay unconstrained
            if isinstance(field.child, _UnvalidatedField):
                return build_array_type({})
            # use the child's complete schema: copying only 'type'/'format'
            # dropped '$ref' for serializer children as well as other constraints
            return build_array_type(
                self._map_serializer_field(method, field.child))

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return build_basic_type(OpenApiTypes.DATE)

        if isinstance(field, serializers.DateTimeField):
            return build_basic_type(OpenApiTypes.DATETIME)

        if isinstance(field, serializers.EmailField):
            return build_basic_type(OpenApiTypes.EMAIL)

        if isinstance(field, serializers.URLField):
            return build_basic_type(OpenApiTypes.URI)

        if isinstance(field, serializers.UUIDField):
            return build_basic_type(OpenApiTypes.UUID)

        if isinstance(field, serializers.IPAddressField):
            # TODO this might be a DRF bug. protocol is not propagated to serializer although it
            #  should have been. results in always 'both' (thus no format)
            if 'ipv4' == field.protocol.lower():
                return build_basic_type(OpenApiTypes.IP4)
            elif 'ipv6' == field.protocol.lower():
                return build_basic_type(OpenApiTypes.IP6)
            else:
                return build_basic_type(OpenApiTypes.STR)

        # DecimalField has multipleOf based on decimal_places
        if isinstance(field, serializers.DecimalField):
            content = {'type': 'number'}
            if field.decimal_places:
                content['multipleOf'] = float('.' +
                                              (field.decimal_places - 1) *
                                              '0' + '1')
            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = build_basic_type(OpenApiTypes.FLOAT)
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = build_basic_type(OpenApiTypes.INT)
            self._map_min_max(field, content)
            # int32 spans [-2147483648, 2147483647]; a bound outside that
            # range (in either direction) requires the int64 format
            bounds = (content.get('maximum', 0), content.get('minimum', 0))
            if any(not -2147483648 <= int(bound) <= 2147483647 for bound in bounds):
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            # TODO returns filename. but does it accept binary data on upload?
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.SerializerMethodField):
            method = getattr(field.parent, field.method_name)
            return self._map_type_hint(method)

        if isinstance(field, serializers.BooleanField):
            return build_basic_type(OpenApiTypes.BOOL)

        if anyisinstance(field, [
                serializers.JSONField, serializers.DictField,
                serializers.HStoreField
        ]):
            return build_basic_type(OpenApiTypes.OBJECT)

        if isinstance(field, serializers.CharField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.ReadOnlyField):
            # direct source from the serializer
            assert field.source_attrs, 'ReadOnlyField needs a proper source'
            target = follow_field_source(field.parent.Meta.model,
                                         field.source_attrs)

            if callable(target):
                return self._map_type_hint(target)
            elif isinstance(target, models.Field):
                return self._map_model_field(target)
            # otherwise fall through to the generic string fallback below

        warn(
            f'could not resolve serializer field {field}. defaulting to "string"'
        )
        return build_basic_type(OpenApiTypes.STR)
示例#8
0
 def _get_model_field(self, filter_field, model):
     path = filter_field.field_name.split('__')
     return follow_field_source(model, path)
示例#9
0
 def _get_model_field(self, filter_field, model):
     if not filter_field.field_name:
         return None
     path = filter_field.field_name.split('__')
     return follow_field_source(model, path, emit_warnings=False)