Ejemplo n.º 1
0
    def _get_request_body(self):
        """Build the OpenAPI ``requestBody`` object for the current operation.

        Returns ``None`` when no body should be documented: for any method
        other than PUT/PATCH/POST, for a serializer whose resolved component
        has an empty schema, and for a basic type that maps to an empty
        schema. Otherwise returns a dict with one ``content`` entry per parser
        media type (each built by ``build_media_type_object`` with that media
        type's examples), plus ``'required': True`` when the client has to
        send a body.
        """
        if self.method not in ('PUT', 'PATCH', 'POST'):
            # safe methods never carry a request body
            return None

        serializer = force_instance(self.get_request_serializer())
        body_is_required = False

        if is_list_serializer(serializer):
            child = serializer.child
            if is_serializer(child):
                item_schema = self.resolve_serializer(child, 'request').ref
            else:
                item_schema = self._map_serializer_field(child, 'request')
            schema = build_array_type(item_schema)
            # a list payload is always documented as mandatory
            body_is_required = True
        elif is_serializer(serializer):
            if self.method == 'PATCH':
                # presumably lets component resolution treat fields as optional
                # for partial updates — confirm in resolve_serializer
                serializer.partial = True
            component = self.resolve_serializer(serializer, 'request')
            if not component.schema:
                # empty serializer: there is no content to enumerate
                return None
            schema = component.ref
            read_only = {
                name
                for name, prop in component.schema.get('properties', {}).items()
                if prop.get('readOnly')
            }
            # the body is mandatory only if at least one required property is
            # something the client actually sends (i.e. not read-only)
            body_is_required = any(
                name not in read_only
                for name in component.schema.get('required', [])
            )
        elif is_basic_type(serializer):
            schema = build_basic_type(serializer)
            if not schema:
                return None
        else:
            warn(
                f'could not resolve request body for {self.method} {self.path}. defaulting to generic '
                'free-form object. (maybe annotate a Serializer class?)')
            schema = build_object_type(
                additionalProperties={},
                description='Unspecified request body',
            )

        content = {}
        for media_type in self.map_parsers():
            examples = self._get_examples(serializer, 'request', media_type)
            content[media_type] = build_media_type_object(schema, examples)

        request_body = {'content': content}
        if body_is_required:
            request_body['required'] = True
        return request_body
Ejemplo n.º 2
0
    def resolve_filter_field(self, auto_schema, model, filterset_class,
                             field_name, filter_field):
        """Build the OpenAPI query parameter describing one django-filter field.

        The base schema is chosen in this order:

        1. ``OrderingFilter`` -> plain string;
        2. filter with a custom ``method`` -> the method's ``value`` type hint
           if it is a basic type, else ``self.map_filter_field``;
        3. otherwise the model field reached by following
           ``filter_field.field_name`` (``__``-separated path), falling back to
           ``self.map_filter_field`` when that is not a model field.

        The filter's ``choices`` override any enum, and ``help_text`` (or
        ``label``) override the description. CSV filters are wrapped into an
        array with ``style='form'`` and ``explode=False``.

        :returns: the parameter dict produced by ``build_parameter_type``.
        """
        from django_filters.rest_framework import filters

        if isinstance(filter_field, filters.OrderingFilter):
            # only here filter_field.field_name is not the model field name/path
            schema = build_basic_type(OpenApiTypes.STR)
        elif filter_field.method:
            if callable(filter_field.method):
                filter_method = filter_field.method
            else:
                # method given by name: look it up on the FilterSet class
                filter_method = getattr(filterset_class, filter_field.method)

            try:
                filter_method_hints = typing.get_type_hints(filter_method)
            except Exception:
                # Best effort: annotations that cannot be resolved simply mean
                # "no hint". Catching Exception (not a bare except) keeps
                # KeyboardInterrupt/SystemExit propagating.
                filter_method_hints = {}

            value_hint = filter_method_hints.get('value')
            if 'value' in filter_method_hints and is_basic_type(value_hint):
                schema = build_basic_type(value_hint)
            else:
                schema = self.map_filter_field(filter_field)
        else:
            path = filter_field.field_name.split('__')
            model_field = follow_field_source(model, path)

            if isinstance(model_field, models.Field):
                schema = auto_schema._map_model_field(model_field,
                                                      direction=None)
            else:
                schema = self.map_filter_field(filter_field)

        # explicit filter choices take precedence over an enum from the schema
        enum = schema.pop('enum', None)
        if 'choices' in filter_field.extra:
            enum = [c for c, _ in filter_field.extra['choices']]
        if enum:
            # key=str allows sorting choices of mixed/non-comparable types
            schema['enum'] = sorted(enum, key=str)

        # description priority: help_text > label > description from the schema
        description = schema.pop('description', None)
        if filter_field.extra.get('help_text', None):
            description = filter_field.extra['help_text']
        elif filter_field.label is not None:
            description = filter_field.label

        if isinstance(filter_field, filters.BaseCSVFilter):
            schema = build_array_type(schema)
            explode = False
            style = 'form'
        else:
            explode = None
            style = None

        return build_parameter_type(name=field_name,
                                    required=filter_field.extra['required'],
                                    location=OpenApiParameter.QUERY,
                                    description=description,
                                    schema=schema,
                                    explode=explode,
                                    style=style)
Ejemplo n.º 3
0
    def _get_response_for_code(self, serializer):
        """Build the OpenAPI response object for one status code.

        :param serializer: the declared response. It may be a serializer, a
            basic type, a raw schema ``dict`` (used verbatim), or something
            falsy meaning "no body". ``force_instance`` is applied first.
        :returns: ``{'description': _('No response body')}`` when there is no
            body. Otherwise a dict with one ``content`` entry per renderer
            media type (all sharing the same schema) and an empty
            ``description``. On list views the schema is wrapped in an array
            and, if a paginator is configured, in the paginator's envelope.
            For serializers, that envelope is registered as a
            ``Paginated<Name>List`` component.
        """
        serializer = force_instance(serializer)

        if not serializer:
            return {'description': _('No response body')}
        elif isinstance(serializer, serializers.ListSerializer):
            if is_serializer(serializer.child):
                schema = self.resolve_serializer(serializer.child,
                                                 'response').ref
            else:
                schema = self._map_serializer_field(serializer.child,
                                                    'response')
        elif is_serializer(serializer):
            component = self.resolve_serializer(serializer, 'response')
            if not component.schema:
                return {'description': _('No response body')}
            schema = component.ref
        elif is_basic_type(serializer):
            schema = build_basic_type(serializer)
        elif isinstance(serializer, dict):
            # bypass processing and use given schema directly
            schema = serializer
        else:
            warn(
                f'could not resolve "{serializer}" for {self.method} {self.path}. Expected either '
                f'a serializer or some supported override mechanism. defaulting to '
                f'generic free-form object.')
            schema = build_basic_type(OpenApiTypes.OBJECT)
            schema['description'] = _('Unspecified response body')

        # An explicit many=False override opts out of list wrapping. Spelled
        # `is not False` rather than the ambiguous `not x is False` (E714).
        if self._is_list_view(serializer) and get_override(
                serializer, 'many') is not False:
            schema = build_array_type(schema)
            paginator = self._get_paginator()

            if paginator and is_serializer(serializer):
                # register the paginated envelope as a reusable component
                paginated_name = f'Paginated{self._get_serializer_name(serializer, "response")}List'
                component = ResolvedComponent(
                    name=paginated_name,
                    type=ResolvedComponent.SCHEMA,
                    schema=paginator.get_paginated_response_schema(schema),
                    object=paginated_name,
                )
                self.registry.register(component)
                schema = component.ref
            elif paginator:
                schema = paginator.get_paginated_response_schema(schema)

        return {
            'content': {
                mt: {
                    'schema': schema
                }
                for mt in self.map_renderers('media_type')
            },
            # Description is required by spec, but descriptions for each response code don't really
            # fit into our model. Description is therefore put into the higher level slots.
            # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#responseObject
            'description': ''
        }
Ejemplo n.º 4
0
    def _get_response_for_code(self, path, method, serializer):
        """Build the OpenAPI response object for one status code.

        :param path: the operation's URL path. It is used for list-view
            detection, renderer lookup and the fallback warning.
        :param method: the operation's HTTP method.
        :param serializer: the declared response. It may be a serializer, a
            basic type, a raw schema ``dict`` (used verbatim), or something
            falsy meaning "no body". ``force_instance`` is applied first.
        :returns: ``{'description': 'No response body'}`` when there is no
            body. Otherwise one ``content`` entry per renderer media type (all
            sharing one schema) and an empty ``description``.
        """
        serializer = force_instance(serializer)
        if not serializer:
            return {'description': 'No response body'}

        many = isinstance(serializer, serializers.ListSerializer)

        if many:
            schema = self.resolve_serializer(method, serializer.child).ref
        elif is_serializer(serializer):
            component = self.resolve_serializer(method, serializer)
            if not component:
                return {'description': 'No response body'}
            schema = component.ref
        elif is_basic_type(serializer):
            schema = build_basic_type(serializer)
        elif isinstance(serializer, dict):
            # a raw schema was supplied: take it as-is
            schema = serializer
        else:
            warn(
                f'could not resolve "{serializer}" for {method} {path}. Expected either '
                f'a serializer or some supported override mechanism. defaulting to '
                f'generic free-form object.')
            schema = build_basic_type(OpenApiTypes.OBJECT)
            schema['description'] = 'Unspecified response body'

        # TODO i fear is_list_view is not covering all the cases
        if many or is_list_view(path, method, self.view):
            schema = build_array_type(schema)
            paginator = self._get_paginator()
            if paginator:
                schema = paginator.get_paginated_response_schema(schema)

        renderer_media_types = self.map_renderers(path, method)
        return {
            'content': {mt: {'schema': schema} for mt in renderer_media_types},
            # OpenAPI requires a description per response, but per-code
            # descriptions don't fit this model; the real description lives
            # in the higher-level operation slots.
            # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#responseObject
            'description': '',
        }
Ejemplo n.º 5
0
    def _get_request_body(self):
        """Build the OpenAPI ``requestBody`` object for the current operation.

        Returns ``None`` in three cases: for methods other than PUT/PATCH/POST,
        for a serializer that resolves to an empty component, and for a basic
        type that maps to an empty schema. Otherwise returns
        ``{'content': {media_type: {'schema': ...}}}`` with one entry per
        parser, plus ``'required': True`` when the client must send a body.
        """
        # only unsafe methods can have a body
        if self.method not in ('PUT', 'PATCH', 'POST'):
            return None

        serializer = force_instance(self.get_request_serializer())

        request_body_required = False
        if isinstance(serializer, serializers.ListSerializer):
            if is_serializer(serializer.child):
                component = self.resolve_serializer(serializer.child, 'request')
                schema = build_array_type(component.ref)
            else:
                # list of plain fields: map the child field directly instead
                # of trying to resolve it as a serializer component
                schema = build_array_type(
                    self._map_serializer_field(serializer.child, 'request'))
            request_body_required = True
        elif is_serializer(serializer):
            component = self.resolve_serializer(serializer, 'request')
            if not component:
                # serializer is empty so skip content enumeration
                return None
            schema = component.ref
            # The body is only required if the client has to send at least one
            # required property. Read-only properties are not sent by clients,
            # so a body whose required properties are all read-only is optional.
            read_only_props = {
                name
                for name, prop in component.schema.get('properties', {}).items()
                if prop.get('readOnly')
            }
            request_body_required = any(
                name not in read_only_props
                for name in component.schema.get('required', [])
            )
        elif is_basic_type(serializer):
            schema = build_basic_type(serializer)
            if not schema:
                return None
        else:
            warn(
                f'could not resolve request body for {self.method} {self.path}. defaulting to generic '
                'free-form object. (maybe annotate a Serializer class?)'
            )
            schema = {
                'type': 'object',
                'additionalProperties': {},  # https://github.com/swagger-api/swagger-codegen/issues/1318
                'description': 'Unspecified request body',
            }

        request_body = {
            'content': {
                request_media_types: {'schema': schema} for request_media_types in self.map_parsers()
            }
        }
        if request_body_required:
            request_body['required'] = request_body_required

        return request_body
Ejemplo n.º 6
0
# API endpoint for the authenticated user's groups: GET lists their group
# memberships, POST creates a new group for them.
class GroupsView(APIView):
    permission_classes = (IsAuthenticated,)

    @extend_schema(
        tags=["Groups"],
        operation_id="list_groups",
        description=(
            "Lists all the groups of the authorized user. A group can contain "
            "multiple applications like a database. Multiple users can have "
            "access to a group. For example each company could have their own group "
            "containing databases related to that company. The order of the groups "
            "are custom for each user. The order is configurable via the "
            "**order_groups** endpoint."
        ),
        responses={200: build_array_type(group_user_schema)},
    )
    def get(self, request):
        """Return every group membership of the requesting user, serialized."""

        memberships = GroupUser.objects.filter(user=request.user)
        # fetch the related group in the same query
        memberships = memberships.select_related("group")
        return Response(GroupUserGroupSerializer(memberships, many=True).data)

    @extend_schema(
        tags=["Groups"],
        operation_id="create_group",
        description=(
            "Creates a new group where only the authorized user has access to. No "
            "initial data like database applications are added, they have to be "
            "created via other endpoints."
        ),
        request=GroupSerializer,
        responses={200: group_user_schema},
    )
    @transaction.atomic
    @validate_body(GroupSerializer)
    def post(self, request, data):
        """Create a group for the requesting user.

        ``data`` is the request body as handed over by
        ``validate_body(GroupSerializer)``; only its ``name`` key is used.
        """

        created_membership = CoreHandler().create_group(
            request.user, name=data["name"]
        )
        payload = GroupUserGroupSerializer(created_membership).data
        return Response(payload)
Ejemplo n.º 7
0
    def _map_serializer_field(self, field, direction):
        """Map a DRF serializer field (or nested serializer) to an OpenAPI schema.

        :param field: the serializer field or nested serializer to map.
        :param direction: ``'request'`` or ``'response'``. It is forwarded to
            component resolution and model-field mapping, and decides how
            ``FileField`` is rendered.
        :returns: the schema dict, with the field's meta (from
            ``_get_serializer_field_meta``) merged in via ``append_meta``.
            Returns ``None`` when a nested serializer resolves to an empty
            component.
        :raises AssertionError: for a ``ReadOnlyField`` without a usable
            source, or whose source is neither callable nor a model field.

        The ``isinstance`` checks are order-sensitive. Several DRF field
        classes subclass one another (e.g. ``HyperlinkedIdentityField`` ->
        ``HyperlinkedRelatedField``, ``MultipleChoiceField`` ->
        ``ChoiceField``, ``EmailField``/``URLField`` -> ``CharField``), so the
        more specific class must be tested first.
        """
        # an explicit per-field override beats every automatic mapping
        if has_override(field, 'field'):
            override = get_override(field, 'field')
            if is_basic_type(override):
                return build_basic_type(override)
            else:
                return self._map_serializer_field(override, direction)

        meta = self._get_serializer_field_meta(field)

        # a registered extension for this field takes precedence over built-ins
        serializer_field_extension = OpenApiSerializerFieldExtension.get_match(
            field)
        if serializer_field_extension:
            schema = serializer_field_extension.map_serializer_field(
                self, direction)
            return append_meta(schema, meta)

        # nested serializer
        if isinstance(field, serializers.Serializer):
            component = self.resolve_serializer(field, direction)
            return append_meta(component.ref, meta) if component else None

        # nested serializer with many=True gets automatically replaced with ListSerializer
        if isinstance(field, serializers.ListSerializer):
            if is_serializer(field.child):
                component = self.resolve_serializer(field.child, direction)
                return append_meta(build_array_type(component.ref),
                                   meta) if component else None
            else:
                schema = self._map_serializer_field(field.child, direction)
                return append_meta(build_array_type(schema), meta)

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            schema = self._map_serializer_field(field.child_relation,
                                                direction)
            # remove hand-over initkwargs applying only to outer scope
            schema.pop('description', None)
            schema.pop('readOnly', None)
            return append_meta(build_array_type(schema), meta)

        if isinstance(field, serializers.PrimaryKeyRelatedField):
            # read_only fields do not have a Manager by design. go around and get field
            # from parent. also avoid calling Manager.__bool__ as it might be customized
            # to hit the database.
            if getattr(field, 'queryset', None) is not None:
                model_field = field.queryset.model._meta.pk
            else:
                if isinstance(field.parent, serializers.ManyRelatedField):
                    # wrapped by ManyRelatedField: the owning serializer is one level up
                    model_field = field.parent.parent.Meta.model._meta.pk
                else:
                    model_field = field.parent.Meta.model._meta.pk

            # primary keys are usually non-editable (readOnly=True) and map_model_field correctly
            # signals that attribute. however this does not apply in the context of relations.
            schema = self._map_model_field(model_field, direction)
            schema.pop('readOnly', None)
            return append_meta(schema, meta)

        if isinstance(field, serializers.StringRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.SlugRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.HyperlinkedIdentityField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.HyperlinkedRelatedField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.MultipleChoiceField):
            return append_meta(
                build_array_type(build_choice_field(field.choices)), meta)

        if isinstance(field, serializers.ChoiceField):
            return append_meta(build_choice_field(field.choices), meta)

        if isinstance(field, serializers.ListField):
            if isinstance(field.child, _UnvalidatedField):
                # no child field declared: items are unconstrained
                return append_meta(build_array_type({}), meta)
            # Use the child's complete schema. Copying only 'type'/'format'
            # would turn a nested-serializer child ('$ref' only) into
            # {"type": null} and would drop 'enum', nested 'items', etc.
            child_schema = self._map_serializer_field(field.child, direction)
            if child_schema is None:
                # child serializer resolved to an empty component; behave
                # like the nested-serializer branches above
                return None
            return append_meta(build_array_type(child_schema), meta)

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return append_meta(build_basic_type(OpenApiTypes.DATE), meta)

        if isinstance(field, serializers.DateTimeField):
            return append_meta(build_basic_type(OpenApiTypes.DATETIME), meta)

        if isinstance(field, serializers.TimeField):
            return append_meta(build_basic_type(OpenApiTypes.TIME), meta)

        if isinstance(field, serializers.EmailField):
            return append_meta(build_basic_type(OpenApiTypes.EMAIL), meta)

        if isinstance(field, serializers.URLField):
            return append_meta(build_basic_type(OpenApiTypes.URI), meta)

        if isinstance(field, serializers.UUIDField):
            return append_meta(build_basic_type(OpenApiTypes.UUID), meta)

        if isinstance(field, serializers.DurationField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.IPAddressField):
            # TODO this might be a DRF bug. protocol is not propagated to serializer although it
            #  should have been. results in always 'both' (thus no format)
            if 'ipv4' == field.protocol.lower():
                schema = build_basic_type(OpenApiTypes.IP4)
            elif 'ipv6' == field.protocol.lower():
                schema = build_basic_type(OpenApiTypes.IP6)
            else:
                schema = build_basic_type(OpenApiTypes.STR)
            return append_meta(schema, meta)

        # DecimalField: string with format 'decimal' when coerced to string,
        # otherwise a number; bounded by max_whole_digits
        if isinstance(field, serializers.DecimalField):
            if getattr(field, 'coerce_to_string',
                       api_settings.COERCE_DECIMAL_TO_STRING):
                content = {
                    **build_basic_type(OpenApiTypes.STR), 'format': 'decimal'
                }
            else:
                content = build_basic_type(OpenApiTypes.DECIMAL)

            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return append_meta(content, meta)

        if isinstance(field, serializers.FloatField):
            content = build_basic_type(OpenApiTypes.FLOAT)
            self._map_min_max(field, content)
            return append_meta(content, meta)

        if isinstance(field, serializers.IntegerField):
            content = build_basic_type(OpenApiTypes.INT)
            self._map_min_max(field, content)
            # int32 spans [-2147483648, 2147483647]. If any declared bound lies
            # outside that range (above OR below), the values need int64.
            if any(
                not -2147483648 <= int(content[bound]) <= 2147483647
                for bound in ('maximum', 'minimum') if bound in content
            ):
                content['format'] = 'int64'
            return append_meta(content, meta)

        if isinstance(field, serializers.FileField):
            if spectacular_settings.COMPONENT_SPLIT_REQUEST and direction == 'request':
                content = build_basic_type(OpenApiTypes.BINARY)
            else:
                use_url = getattr(field, 'use_url',
                                  api_settings.UPLOADED_FILES_USE_URL)
                content = build_basic_type(
                    OpenApiTypes.URI if use_url else OpenApiTypes.STR)
            return append_meta(content, meta)

        if isinstance(field, serializers.SerializerMethodField):
            method = getattr(field.parent, field.method_name)
            return append_meta(self._map_response_type_hint(method), meta)

        # NullBooleanField was removed in DRF 3.14; only reference it if present
        boolean_field_classes = [serializers.BooleanField]
        if hasattr(serializers, 'NullBooleanField'):
            boolean_field_classes.append(serializers.NullBooleanField)
        if anyisinstance(field, boolean_field_classes):
            return append_meta(build_basic_type(OpenApiTypes.BOOL), meta)

        if isinstance(field, serializers.JSONField):
            return append_meta(build_basic_type(OpenApiTypes.OBJECT), meta)

        if anyisinstance(field,
                         [serializers.DictField, serializers.HStoreField]):
            content = build_basic_type(OpenApiTypes.OBJECT)
            if not isinstance(field.child, _UnvalidatedField):
                content['additionalProperties'] = self._map_serializer_field(
                    field.child, direction)
            return append_meta(content, meta)

        if isinstance(field, serializers.CharField):
            return append_meta(build_basic_type(OpenApiTypes.STR), meta)

        if isinstance(field, serializers.ReadOnlyField):
            # direct source from the serializer. Explicit raises (not `assert`)
            # keep these checks active under `python -O`, where an `assert False`
            # would otherwise leave `schema` unbound.
            if not field.source_attrs:
                raise AssertionError(f'ReadOnlyField "{field}" needs a proper source')
            target = follow_field_source(field.parent.Meta.model,
                                         field.source_attrs)

            if callable(target):
                schema = self._map_response_type_hint(target)
            elif isinstance(target, models.Field):
                schema = self._map_model_field(target, direction)
            else:
                raise AssertionError(
                    f'ReadOnlyField target "{field}" must be property or model field')
            return append_meta(schema, meta)

        # DRF was not able to match the model field to an explicit SerializerField and therefore
        # used its generic fallback serializer field that simply wraps the model field.
        if isinstance(field, serializers.ModelField):
            schema = self._map_model_field(field.model_field, direction)
            return append_meta(schema, meta)

        warn(
            f'could not resolve serializer field "{field}". defaulting to "string"'
        )
        return append_meta(build_basic_type(OpenApiTypes.STR), meta)
Ejemplo n.º 8
0
    def _map_serializer_field(self, method, field):
        """Map a DRF serializer field (or nested serializer) to an OpenAPI schema.

        :param method: the HTTP method of the operation. It is forwarded to
            component resolution and recursive calls.
        :param field: the serializer field or nested serializer to map.
        :returns: the schema dict. Unresolvable fields fall back to a string
            schema after emitting a warning.
        :raises AssertionError: for a ``ReadOnlyField`` without a usable source.

        The ``isinstance`` checks are order-sensitive. Several DRF field
        classes subclass one another (e.g. ``HyperlinkedIdentityField`` ->
        ``HyperlinkedRelatedField``, ``MultipleChoiceField`` ->
        ``ChoiceField``, ``EmailField``/``URLField`` -> ``CharField``), so the
        more specific class must be tested first.
        """
        # an explicit per-field annotation beats every automatic mapping
        if hasattr(field, '_spectacular_annotation'):
            if is_basic_type(field._spectacular_annotation):
                return build_basic_type(field._spectacular_annotation)
            else:
                return self._map_serializer_field(
                    method, field._spectacular_annotation)

        # nested serializer
        if isinstance(field, serializers.Serializer):
            return self.resolve_serializer(method, field).ref

        # nested serializer with many=True gets automatically replaced with ListSerializer
        if isinstance(field, serializers.ListSerializer):
            return build_array_type(
                self.resolve_serializer(method, field.child).ref)

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return build_array_type(
                self._map_serializer_field(method, field.child_relation))

        if isinstance(field, serializers.PrimaryKeyRelatedField):
            # read_only fields do not have a Manager by design. go around and get field
            # from parent. also avoid calling Manager.__bool__ as it might be customized
            # to hit the database.
            if getattr(field, 'queryset', None) is not None:
                return self._map_model_field(field.queryset.model._meta.pk)
            # Use the model's actual primary key rather than assuming it is
            # named `id`. When wrapped by ManyRelatedField (many=True) the
            # direct parent is that wrapper, so the serializer is one level up.
            parent = field.parent
            if isinstance(parent, serializers.ManyRelatedField):
                parent = parent.parent
            return self._map_model_field(parent.Meta.model._meta.pk)

        if isinstance(field, serializers.StringRelatedField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.SlugRelatedField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.HyperlinkedIdentityField):
            return build_basic_type(OpenApiTypes.URI)

        if isinstance(field, serializers.HyperlinkedRelatedField):
            return build_basic_type(OpenApiTypes.URI)

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return build_array_type(self._map_choicefield(field))

        if isinstance(field, serializers.ChoiceField):
            return self._map_choicefield(field)

        if isinstance(field, serializers.ListField):
            if isinstance(field.child, _UnvalidatedField):
                # no child field declared: items are unconstrained
                return build_array_type({})
            # Use the child's complete schema. Copying only 'type'/'format'
            # would turn a nested-serializer child ('$ref' only) into
            # {"type": null} and would drop 'enum', nested 'items', etc.
            return build_array_type(
                self._map_serializer_field(method, field.child))

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return build_basic_type(OpenApiTypes.DATE)

        if isinstance(field, serializers.DateTimeField):
            return build_basic_type(OpenApiTypes.DATETIME)

        if isinstance(field, serializers.EmailField):
            return build_basic_type(OpenApiTypes.EMAIL)

        if isinstance(field, serializers.URLField):
            return build_basic_type(OpenApiTypes.URI)

        if isinstance(field, serializers.UUIDField):
            return build_basic_type(OpenApiTypes.UUID)

        if isinstance(field, serializers.IPAddressField):
            # TODO this might be a DRF bug. protocol is not propagated to serializer although it
            #  should have been. results in always 'both' (thus no format)
            if 'ipv4' == field.protocol.lower():
                return build_basic_type(OpenApiTypes.IP4)
            elif 'ipv6' == field.protocol.lower():
                return build_basic_type(OpenApiTypes.IP6)
            else:
                return build_basic_type(OpenApiTypes.STR)

        # DecimalField has multipleOf based on decimal_places
        if isinstance(field, serializers.DecimalField):
            content = {'type': 'number'}
            if field.decimal_places:
                # e.g. decimal_places=2 -> float('.01') == 0.01
                content['multipleOf'] = float('.' +
                                              (field.decimal_places - 1) *
                                              '0' + '1')
            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = build_basic_type(OpenApiTypes.FLOAT)
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = build_basic_type(OpenApiTypes.INT)
            self._map_min_max(field, content)
            # int32 spans [-2147483648, 2147483647]. If any declared bound lies
            # outside that range (above OR below), the values need int64.
            if any(
                not -2147483648 <= int(content[bound]) <= 2147483647
                for bound in ('maximum', 'minimum') if bound in content
            ):
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            # TODO returns filename. but does it accept binary data on upload?
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.SerializerMethodField):
            # distinct name so the `method` parameter is not shadowed
            method_callable = getattr(field.parent, field.method_name)
            return self._map_type_hint(method_callable)

        if isinstance(field, serializers.BooleanField):
            return build_basic_type(OpenApiTypes.BOOL)

        if anyisinstance(field, [
                serializers.JSONField, serializers.DictField,
                serializers.HStoreField
        ]):
            return build_basic_type(OpenApiTypes.OBJECT)

        if isinstance(field, serializers.CharField):
            return build_basic_type(OpenApiTypes.STR)

        if isinstance(field, serializers.ReadOnlyField):
            # direct source from the serializer. An explicit raise (not
            # `assert`) keeps the check active under `python -O`.
            if not field.source_attrs:
                raise AssertionError('ReadOnlyField needs a proper source')
            target = follow_field_source(field.parent.Meta.model,
                                         field.source_attrs)

            if callable(target):
                return self._map_type_hint(target)
            elif isinstance(target, models.Field):
                return self._map_model_field(target)
            # any other target falls through to the string fallback below

        warn(
            f'could not resolve serializer field {field}. defaulting to "string"'
        )
        return build_basic_type(OpenApiTypes.STR)
Ejemplo n.º 9
0
    def resolve_filter_field(self, auto_schema, model, filterset_class,
                             field_name, filter_field):
        """Build the OpenAPI query parameter(s) for one django-filter filter.

        The schema type is derived, in order of preference, from:

        1. a mapping of filter classes with an unambiguous type (including
           subclasses of those filters);
        2. number-filter heuristics;
        3. the filter's custom ``method``;
        4. the underlying model field, defaulting to string.

        :param auto_schema: schema generator used to map model fields.
        :param model: model class the filterset filters on.
        :param filterset_class: FilterSet class declaring ``filter_field``.
        :param field_name: name under which the filter is exposed.
        :param filter_field: the django-filter ``Filter`` instance.
        :returns: list of query parameters built by ``build_parameter_type``;
            range filters yield one parameter per widget suffix.
        """
        from django_filters.rest_framework import filters

        unambiguous_mapping = {
            filters.CharFilter: OpenApiTypes.STR,
            filters.BooleanFilter: OpenApiTypes.BOOL,
            filters.DateFilter: OpenApiTypes.DATE,
            filters.DateTimeFilter: OpenApiTypes.DATETIME,
            filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
            filters.TimeFilter: OpenApiTypes.TIME,
            filters.UUIDFilter: OpenApiTypes.UUID,
            filters.DurationFilter: OpenApiTypes.DURATION,
            filters.OrderingFilter: OpenApiTypes.STR,
            filters.TimeRangeFilter: OpenApiTypes.TIME,
            filters.DateFromToRangeFilter: OpenApiTypes.DATE,
            filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
            filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
        }
        if isinstance(filter_field, tuple(unambiguous_mapping)):
            # Walk the MRO so that subclasses of a mapped filter use the type
            # of their closest mapped ancestor. Requiring an exact class
            # match would let them fall through to the less precise branches.
            schema = next(
                build_basic_type(unambiguous_mapping[cls])
                for cls in filter_field.__class__.__mro__
                if cls in unambiguous_mapping
            )
        elif isinstance(filter_field,
                        (filters.NumberFilter, filters.NumericRangeFilter)):
            # NumberField is underspecified by itself. try to find the
            # type that makes the most sense or default to generic NUMBER
            if filter_field.method:
                schema = self._build_filter_method_type(
                    filterset_class, filter_field)
                # The method-derived schema may lack a 'type' key (presumably
                # for untyped methods), so avoid a KeyError here.
                if schema.get('type') not in ('integer', 'number'):
                    schema = build_basic_type(OpenApiTypes.NUMBER)
            else:
                model_field = self._get_model_field(filter_field, model)
                if isinstance(model_field,
                              (models.IntegerField, models.AutoField)):
                    schema = build_basic_type(OpenApiTypes.INT)
                elif isinstance(model_field, models.FloatField):
                    schema = build_basic_type(OpenApiTypes.FLOAT)
                elif isinstance(model_field, models.DecimalField):
                    schema = build_basic_type(
                        OpenApiTypes.NUMBER)  # TODO may be improved
                else:
                    schema = build_basic_type(OpenApiTypes.NUMBER)
        elif filter_field.method:
            # try to make best effort on the given method
            schema = self._build_filter_method_type(filterset_class,
                                                    filter_field)
        else:
            # last resort is to lookup the type via the model field.
            model_field = self._get_model_field(filter_field, model)
            if isinstance(model_field, models.Field):
                schema = auto_schema._map_model_field(model_field,
                                                      direction=None)
            else:
                # default to string if nothing else works
                schema = build_basic_type(OpenApiTypes.STR)

        # Model field mapping may mark non-editable fields (e.g. primary keys)
        # as readOnly, which does not apply to a query parameter.
        schema.pop('readOnly', None)
        # enrich schema with additional info from filter_field
        enum = schema.pop('enum', None)
        if 'choices' in filter_field.extra:
            enum = [c for c, _ in filter_field.extra['choices']]
        if enum:
            schema['enum'] = sorted(enum, key=str)

        description = schema.pop('description', None)
        if filter_field.extra.get('help_text', None):
            description = filter_field.extra['help_text']
        elif filter_field.label is not None:
            description = filter_field.label

        # parameter style variations based on filter base class
        if isinstance(filter_field, filters.BaseCSVFilter):
            schema = build_array_type(schema)
            field_names = [field_name]
            explode = False
            style = 'form'
        elif isinstance(filter_field, filters.MultipleChoiceFilter):
            schema = build_array_type(schema)
            field_names = [field_name]
            explode = True
            style = 'form'
        elif isinstance(filter_field,
                        (filters.RangeFilter, filters.NumericRangeFilter)):
            # The actual query parameter names are derived from the form
            # widget's suffixes; fall back to min/max when none are exposed.
            try:
                suffixes = filter_field.field_class.widget.suffixes
            except AttributeError:
                suffixes = ['min', 'max']
            field_names = [
                f'{field_name}_{suffix}' if suffix else field_name
                for suffix in suffixes
            ]
            explode = None
            style = None
        else:
            field_names = [field_name]
            explode = None
            style = None

        return [
            build_parameter_type(name=name,
                                 required=filter_field.extra['required'],
                                 location=OpenApiParameter.QUERY,
                                 description=description,
                                 schema=schema,
                                 explode=explode,
                                 style=style) for name in field_names
        ]
    def resolve_filter_field(self, auto_schema, model, filterset_class,
                             field_name, filter_field):
        """Build the OpenAPI query parameter(s) for one django-filter filter.

        The schema type is derived, in order of preference, from:

        1. a mapping of filter classes with an unambiguous type (matched via
           the MRO, so subclasses are covered);
        2. number-filter heuristics;
        3. the filter's custom ``method``;
        4. the underlying model field, defaulting to string and emitting a
           warning if mapping the field raises.

        :param auto_schema: schema generator used to map model fields.
        :param model: model class the filterset filters on.
        :param filterset_class: FilterSet class declaring ``filter_field``.
        :param field_name: name under which the filter is exposed.
        :param filter_field: the django-filter ``Filter`` instance.
        :returns: list of query parameters built by ``build_parameter_type``;
            range filters yield one parameter per widget suffix.
        """
        from django_filters.rest_framework import filters

        unambiguous_mapping = {
            filters.CharFilter: OpenApiTypes.STR,
            filters.BooleanFilter: OpenApiTypes.BOOL,
            filters.DateFilter: OpenApiTypes.DATE,
            filters.DateTimeFilter: OpenApiTypes.DATETIME,
            filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
            filters.TimeFilter: OpenApiTypes.TIME,
            filters.UUIDFilter: OpenApiTypes.UUID,
            filters.DurationFilter: OpenApiTypes.DURATION,
            filters.OrderingFilter: OpenApiTypes.STR,
            filters.TimeRangeFilter: OpenApiTypes.TIME,
            filters.DateFromToRangeFilter: OpenApiTypes.DATE,
            filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
            filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
        }
        if isinstance(filter_field, tuple(unambiguous_mapping)):
            # the isinstance check guarantees some class in the MRO matches;
            # the closest mapped ancestor wins
            for cls in filter_field.__class__.__mro__:
                if cls in unambiguous_mapping:
                    schema = build_basic_type(unambiguous_mapping[cls])
                    break
        elif isinstance(filter_field,
                        (filters.NumberFilter, filters.NumericRangeFilter)):
            # NumberField is underspecified by itself. try to find the
            # type that makes the most sense or default to generic NUMBER
            if filter_field.method:
                schema = self._build_filter_method_type(
                    filterset_class, filter_field)
                # The method-derived schema may lack a 'type' key (presumably
                # for untyped methods), so avoid a KeyError here.
                if schema.get('type') not in ('integer', 'number'):
                    schema = build_basic_type(OpenApiTypes.NUMBER)
            else:
                model_field = self._get_model_field(filter_field, model)
                if isinstance(model_field,
                              (models.IntegerField, models.AutoField)):
                    schema = build_basic_type(OpenApiTypes.INT)
                elif isinstance(model_field, models.FloatField):
                    schema = build_basic_type(OpenApiTypes.FLOAT)
                elif isinstance(model_field, models.DecimalField):
                    schema = build_basic_type(
                        OpenApiTypes.NUMBER)  # TODO may be improved
                else:
                    schema = build_basic_type(OpenApiTypes.NUMBER)
        elif filter_field.method:
            # try to make best effort on the given method
            schema = self._build_filter_method_type(filterset_class,
                                                    filter_field)
        else:
            # last resort is to lookup the type via the model field.
            model_field = self._get_model_field(filter_field, model)
            if isinstance(model_field, models.Field):
                # deliberate best effort: any mapping failure degrades to a
                # string parameter with a warning instead of aborting
                try:
                    schema = auto_schema._map_model_field(model_field,
                                                          direction=None)
                except Exception as exc:
                    warn(
                        f'Exception raised while trying resolve model field for django-filter '
                        f'field "{field_name}". Defaulting to string (Exception: {exc})'
                    )
                    schema = build_basic_type(OpenApiTypes.STR)
            else:
                # default to string if nothing else works
                schema = build_basic_type(OpenApiTypes.STR)

        # primary keys are usually non-editable (readOnly=True) and map_model_field correctly
        # signals that attribute. however this does not apply in this context.
        schema.pop('readOnly', None)
        # enrich schema with additional info from filter_field
        enum = schema.pop('enum', None)
        if 'choices' in filter_field.extra:
            enum = [c for c, _ in filter_field.extra['choices']]
        if enum:
            schema['enum'] = sorted(enum, key=str)

        description = schema.pop('description', None)
        if filter_field.extra.get('help_text', None):
            description = filter_field.extra['help_text']
        elif filter_field.label is not None:
            description = filter_field.label

        # parameter style variations based on filter base class
        if isinstance(filter_field, filters.BaseCSVFilter):
            schema = build_array_type(schema)
            field_names = [field_name]
            explode = False
            style = 'form'
        elif isinstance(filter_field, filters.MultipleChoiceFilter):
            schema = build_array_type(schema)
            field_names = [field_name]
            explode = True
            style = 'form'
        elif isinstance(filter_field,
                        (filters.RangeFilter, filters.NumericRangeFilter)):
            # query parameter names follow the form widget's suffixes; an
            # empty suffix maps to the bare field name
            try:
                suffixes = filter_field.field_class.widget.suffixes
            except AttributeError:
                suffixes = ['min', 'max']
            field_names = [
                f'{field_name}_{suffix}' if suffix else field_name
                for suffix in suffixes
            ]
            explode = None
            style = None
        else:
            field_names = [field_name]
            explode = None
            style = None

        return [
            build_parameter_type(name=name,
                                 required=filter_field.extra['required'],
                                 location=OpenApiParameter.QUERY,
                                 description=description,
                                 schema=schema,
                                 explode=explode,
                                 style=style) for name in field_names
        ]
Ejemplo n.º 11
0
    return schema


# Shared schema for a generic API error: a ``detail`` string (the only
# property listed as required) and a ``code`` string.
GENERIC_ERROR = build_object_type(
    required=["detail"],
    description=_("Generic API Error"),
    properties=dict(
        detail=build_standard_type(OpenApiTypes.STR),
        code=build_standard_type(OpenApiTypes.STR),
    ),
)
# Shared schema for validation errors: a ``non_field_errors`` list of strings
# and a ``code`` string, with no property listed as required and additional
# (arbitrary) keys permitted.
VALIDATION_ERROR = build_object_type(
    additionalProperties={},
    required=[],
    description=_("Validation Error"),
    properties=dict(
        non_field_errors=build_array_type(
            build_standard_type(OpenApiTypes.STR)
        ),
        code=build_standard_type(OpenApiTypes.STR),
    ),
)


def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613
    """Workaround to set a default response for endpoints.
    Workaround suggested at
    <https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357>
    for the missing drf-spectacular feature discussed in
    <https://github.com/tfranzel/drf-spectacular/issues/101>.
    """
    def create_component(name, schema, type_=ResolvedComponent.SCHEMA):
Ejemplo n.º 12
0
def resolve_type_hint(hint) -> Any:
    """Translate a Python type hint into an OpenAPI schema dict.

    drf-spectacular library method, modified so that TypedDict hints honour
    the ``exclude_fields`` override (read via ``get_override``): excluded keys
    are dropped from both ``properties`` and ``required``.

    Containers (``list``, ``set``, ``frozenset``, ``Iterable``, ``tuple``)
    given without type arguments get items of ``OpenApiTypes.ANY``.

    :param hint: the type hint to translate (plain type, typing construct,
        Enum subclass, NamedTuple or TypedDict).
    :returns: the schema dict built from the ``build_*`` helpers.
    :raises UnableToProceedError: when the hint cannot be translated.
    """
    origin, args = _get_type_hint_origin(hint)
    excluded_fields = get_override(hint, "exclude_fields", [])

    def first_arg_schema():
        # Element schema taken from the first type argument; bare (untyped)
        # containers fall back to "any" instead of raising IndexError.
        if args:
            return resolve_type_hint(args[0])
        return build_basic_type(OpenApiTypes.ANY)

    if origin is None and is_basic_type(hint, allow_none=False):
        return build_basic_type(hint)
    elif (origin is None and inspect.isclass(hint)
          and issubclass(hint, tuple) and hasattr(hint, "_fields")):
        # a convoluted way to catch NamedTuple: tuple subclasses exposing
        # ``_fields``. Plain ``tuple`` is handled by the tuple branch below.
        type_hints = get_type_hints(hint)
        if type_hints:
            properties = {
                k: resolve_type_hint(v)
                for k, v in type_hints.items()
            }
        else:
            properties = {
                k: build_basic_type(OpenApiTypes.ANY)
                for k in hint._fields
            }
        return build_object_type(properties=properties,
                                 required=properties.keys())
    elif origin is list or hint is list:
        return build_array_type(first_arg_schema())
    elif origin is tuple or hint is tuple:
        if not args or args == ((),):
            # bare ``tuple``/``Tuple`` or ``Tuple[()]``: element types unknown
            return build_array_type(build_basic_type(OpenApiTypes.ANY))
        if len(args) == 2 and args[1] is Ellipsis:
            # ``Tuple[T, ...]`` is a homogeneous tuple of arbitrary length
            return build_array_type(resolve_type_hint(args[0]))
        # Fixed-length tuple: resolve every element type (not just the first,
        # and not only basic types), collapsing identical schemas; mixed
        # element types become a oneOf, consistent with the Union branch.
        item_schemas = []
        for arg in args:
            item_schema = resolve_type_hint(arg)
            if item_schema not in item_schemas:
                item_schemas.append(item_schema)
        if len(item_schemas) == 1:
            items = item_schemas[0]
        else:
            items = {"oneOf": item_schemas}
        return build_array_type(
            schema=items,
            max_length=len(args),
            min_length=len(args),
        )
    elif origin is dict or origin is defaultdict or origin is OrderedDict:
        schema = build_basic_type(OpenApiTypes.OBJECT)
        if args and args[1] is not typing.Any:
            schema["additionalProperties"] = resolve_type_hint(args[1])
        return schema
    elif (origin is set or origin is frozenset
          or hint is set or hint is frozenset):
        return build_array_type(first_arg_schema())
    elif origin is Literal:
        # Literal only works for python >= 3.8 despite typing_extensions, because it
        # behaves slightly different w.r.t. __origin__
        schema = {"enum": list(args)}
        if all(type(args[0]) is type(choice) for choice in args):
            schema.update(build_basic_type(type(args[0])))
        return schema
    elif inspect.isclass(hint) and issubclass(hint, Enum):
        schema = {"enum": [item.value for item in hint]}
        mixin_base_types = [t for t in hint.__mro__ if is_basic_type(t)]
        if mixin_base_types:
            schema.update(build_basic_type(mixin_base_types[0]))
        return schema
    elif isinstance(hint, _TypedDictMeta):
        return build_object_type(
            properties={
                k: resolve_type_hint(v)
                for k, v in get_type_hints(hint).items()
                if k not in excluded_fields
            },
            description=inspect.cleandoc(hint.__doc__ or ""),
            required=[
                h for h in hint.__required_keys__ if h not in excluded_fields
            ],
        )
    elif origin is Union:
        type_args = [arg for arg in args
                     if arg is not type(None)]  # noqa: E721
        if len(type_args) > 1:
            schema = {"oneOf": [resolve_type_hint(arg) for arg in type_args]}
        else:
            schema = resolve_type_hint(type_args[0])
        if type(None) in args:
            schema["nullable"] = True
        return schema
    elif origin is collections.abc.Iterable:
        return build_array_type(first_arg_schema())
    elif isinstance(hint, typing._TypedDictMeta):
        raise UnableToProceedError(
            "Wrong TypedDict class, please use typing_extensions.TypedDict")
    else:
        raise UnableToProceedError()
Ejemplo n.º 13
0
    schema.update(kwargs)
    return schema


# Shared schema for a generic API error: a ``detail`` string (the only
# property listed as required) and a ``code`` string.
GENERIC_ERROR = build_object_type(
    required=["detail"],
    description=_("Generic API Error"),
    properties=dict(
        detail=build_standard_type(OpenApiTypes.STR),
        code=build_standard_type(OpenApiTypes.STR),
    ),
)
# Shared schema for validation errors: a ``non_field_errors`` list of strings
# and a ``code`` string, with additional (arbitrary) keys permitted.
# No property is listed as required. In particular, ``detail`` is not a
# declared property of this schema, so it must not appear in ``required``.
VALIDATION_ERROR = build_object_type(
    description=_("Validation Error"),
    properties={
        "non_field_errors": build_array_type(build_standard_type(OpenApiTypes.STR)),
        "code": build_standard_type(OpenApiTypes.STR),
    },
    required=[],
    additionalProperties={},
)


def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613
    """Workaround to set a default response for endpoints.
    Workaround suggested at
    <https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357>
    for the missing drf-spectacular feature discussed in
    <https://github.com/tfranzel/drf-spectacular/issues/101>.
    """
Ejemplo n.º 14
0
        },
        additionalProperties=build_object_type(
            properties={
                "meta": build_object_type(
                    properties={
                        "param": build_basic_type(str),
                        "missing": build_basic_type(str),
                        "missing_param": build_basic_type(str),
                    }
                ),
                "doc_count_error_upper_bound": build_basic_type(int),
                "sum_other_doc_count": build_basic_type(int),
                "buckets": build_array_type(
                    build_object_type(
                        properties={
                            "key": build_basic_type(str),
                            "doc_count": build_basic_type(int),
                        },
                    ),
                ),
            },
        ),
    ),
)


class ToolhubFilterExtension(OpenApiFilterExtension):
    """Describe django_elasticsearch_dsl_drf filters."""

    target_class = "toolhub.apps.search.views.QueryStringFilterBackend"

    def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
Ejemplo n.º 15
0
    schema = build_basic_type(obj)
    schema.update(kwargs)
    return schema


# Shared schema for an API error. ``code``, ``message`` and ``errors`` are
# listed as required; each ``errors`` entry carries a numeric ``code`` plus
# the required ``field`` and ``message`` strings.
GENERIC_ERROR = build_object_type(
    required=["code", "message", "errors"],
    description=_("API error"),
    properties=dict(
        code=build_standard_type(OpenApiTypes.NUMBER),
        message=build_standard_type(OpenApiTypes.STR),
        status_code=build_standard_type(OpenApiTypes.NUMBER),
        errors=build_array_type(
            build_object_type(
                required=["field", "message"],
                properties=dict(
                    code=build_standard_type(OpenApiTypes.NUMBER),
                    field=build_standard_type(OpenApiTypes.STR),
                    message=build_standard_type(OpenApiTypes.STR),
                ),
            ),
        ),
    ),
)


def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613
    """Workaround to set a default response for endpoints.

    Workaround suggested at
    <https://github.com/tfranzel/drf-spectacular/issues/119#issuecomment-656970357>
    for the missing drf-spectacular feature discussed in