コード例 #1
0
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
    """
    endpoint_enumerator_class = EndpointEnumerator
    reference_resolver_class = ReferenceResolver

    def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
        """

        :param openapi.Info info: information about the API
        :param str version: API version string; if omitted, `info.default_version` will be used
        :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
            will be inferred from the request made against the schema view, so you should generally not need to set
            this parameter explicitly; if the empty string is passed, no host and scheme will be emitted

            If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
            or https://), and any path component is ignored;

            See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        :raises SwaggerGenerationError: if the effective url is non-empty but not an absolute HTTP(S) url
        """
        # Resolve the effective base URL *before* building the wrapped DRF generator, so that the
        # ``DEFAULT_API_URL`` fallback actually becomes ``self.url`` instead of only being validated.
        # ``None`` (infer from the request) and ``''`` (emit no host/scheme) are intentionally kept distinct.
        if url is None and swagger_settings.DEFAULT_API_URL is not None:
            url = swagger_settings.DEFAULT_API_URL

        if url:
            parsed_url = urlparse.urlparse(url)
            if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
                raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
            if parsed_url.path:
                logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead", url)

        self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
        self.info = info
        self.version = version
        # global media types; recomputed from the DRF parser/renderer settings in get_schema
        self.consumes = []
        self.produces = []

    @property
    def url(self):
        """Base URL held by the wrapped DRF generator; ``None`` means it is inferred from the request."""
        return self._gen.url

    def get_security_definitions(self):
        """Get the security schemes for this API. This determines what is usable in security requirements,
        and helps clients configure their authorization credentials.

        :return: the security schemes usable with this API
        :rtype: dict[str,dict] or None
        """
        security_definitions = swagger_settings.SECURITY_DEFINITIONS
        if security_definitions is not None:
            security_definitions = SwaggerDict._as_odict(security_definitions, {})

        return security_definitions

    def get_security_requirements(self, security_definitions):
        """Get the base (global) security requirements of the API. This is never called if
        :meth:`.get_security_definitions` returns `None`.

        :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
        :return: the security schemes accepted by default
        :rtype: list[dict[str,list[str]]] or None
        """
        security_requirements = swagger_settings.SECURITY_REQUIREMENTS
        if security_requirements is None:
            # default: every defined scheme is accepted on its own, with no scopes
            security_requirements = [{security_scheme: []} for security_scheme in security_definitions]

        security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements]
        # sort by scheme names for a deterministic output order
        security_requirements = sorted(security_requirements, key=list)
        return security_requirements

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param request: the request used for filtering accessible endpoints and finding the spec URI
        :type request: rest_framework.request.Request or None
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(request)
        components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True)
        self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
        self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
        paths, prefix = self.get_paths(endpoints, components, request, public)

        security_definitions = self.get_security_definitions()
        if security_definitions:
            security_requirements = self.get_security_requirements(security_definitions)
        else:
            security_requirements = None

        url = self.url
        # only ``None`` triggers inference; ``''`` explicitly suppresses host and scheme
        if url is None and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(
            info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
            security_definitions=security_definitions, security=security_requirements,
            _url=url, _prefix=prefix, _version=self.version, **dict(components)
        )

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view;
            # a separate loop name is used so the ``method`` argument is not shadowed
            for override_method in overrides:
                view_method = getattr(view, override_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema', overrides)

        setattr(view, 'swagger_fake_view', True)
        return view

    def coerce_path(self, path, view):
        """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
        external representation (i.e. "this is an identifier", not "this is a database primary key").

        :param str path: the path
        :param rest_framework.views.APIView view: associated view
        :rtype: str
        """
        if '{pk}' not in path:
            return path

        model = getattr(get_queryset_from_view(view), 'model', None)
        if model:
            field_name = get_pk_name(model)
        else:
            field_name = 'id'
        return path.replace('{pk}', '{%s}' % field_name)

    def get_endpoints(self, request):
        """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.

        :param request: request to bind to the endpoint views
        :type request: rest_framework.request.Request or None
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
        endpoints = enumerator.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {path: (view_cls[path], methods) for path, methods in view_paths.items()}

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: list[str]
        """
        return self._gen.get_keys(subpath, method, view)

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def should_include_endpoint(self, path, method, view, public):
        """Check if a given endpoint should be included in the resulting schema.

        :param str path: request path
        :param str method: http request method
        :param view: instantiated view callback
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: true if the view should be included
        :rtype: bool
        """
        return public or self._gen.has_view_permissions(path, method, view)

    def get_paths_object(self, paths):
        """Construct the Swagger Paths object.

        :param OrderedDict[str,openapi.PathItem] paths: mapping of paths to :class:`.PathItem` objects
        :returns: the :class:`.Paths` object
        :rtype: openapi.Paths
        """
        return openapi.Paths(paths=paths)

    def get_paths(self, endpoints, components, request, public):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :param Request request: the request made against the schema view; can be None
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
        :rtype: tuple[openapi.Paths,str]
        """
        if not endpoints:
            return openapi.Paths(paths={}), ''

        prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
        assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

        paths = OrderedDict()
        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not self.should_include_endpoint(path, method, view, public):
                    continue

                operation = self.get_operation(view, path, prefix, method, components, request)
                if operation is not None:
                    operations[method.lower()] = operation

            if operations:
                # since the common prefix is used as the API basePath, it must be stripped
                # from individual paths when writing them into the swagger document
                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = self.get_path_item(path, view_cls, operations)

        return self.get_paths_object(paths), prefix

    def get_operation(self, view, path, prefix, method, components, request):
        """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
        :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
        according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.

        :param view: the view associated with this endpoint
        :param str path: the path component of the operation URL
        :param str prefix: common path prefix among all endpoints
        :param str method: the http method of the operation
        :param openapi.ReferenceResolver components: referenceable components
        :param Request request: the request made against the schema view; can be None
        :return: the operation, or ``None`` if the endpoint is excluded from the schema
        :rtype: openapi.Operation or None
        """
        operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
        overrides = self.get_overrides(view, method)

        # the inspector class can be specified, in increasing order of priority (later entries win),
        #   1. globally via DEFAULT_AUTO_SCHEMA_CLASS
        view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
        #   2. on the view/viewset class
        view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
        #   3. on the swagger_auto_schema decorator
        view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

        # an explicit ``None`` at any level excludes the endpoint
        if view_inspector_cls is None:
            return None

        view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys)
        operation = view_inspector.get_operation(operation_keys)
        if operation is None:
            return None

        # drop per-operation media types that merely repeat the global defaults
        if 'consumes' in operation and set(operation.consumes) == set(self.consumes):
            del operation.consumes
        if 'produces' in operation and set(operation.produces) == set(self.produces):
            del operation.produces
        return operation

    def get_path_item(self, path, view_cls, operations):
        """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
        API.

        :param str path: the path
        :param type view_cls: the view that was bound to this path in urlpatterns
        :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
        :rtype: openapi.PathItem
        """
        path_parameters = self.get_path_parameters(path, view_cls)
        return openapi.PathItem(parameters=path_parameters, **operations)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, '_swagger_auto_schema', {})
        if method in overrides:
            overrides = overrides[method]

        # copy so inspectors cannot mutate the dict stored on the decorated function
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        queryset = get_queryset_from_view(view_cls)

        # sorted for a deterministic parameter order
        for variable in sorted(uritemplate.variables(path)):
            model, model_field = get_queryset_field(queryset, variable)
            attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
            # ``pattern`` only applies to string parameters
            if getattr(view_cls, 'lookup_field', None) == variable and attrs['type'] == openapi.TYPE_STRING:
                attrs['pattern'] = getattr(view_cls, 'lookup_value_regex', attrs.get('pattern', None))

            if model_field and getattr(model_field, 'help_text', False):
                description = model_field.help_text
            elif model_field and getattr(model_field, 'primary_key', False):
                description = get_pk_description(model, model_field)
            else:
                description = None

            field = openapi.Parameter(
                name=variable,
                description=force_real_str(description),
                required=True,
                in_=openapi.IN_PATH,
                **attrs
            )
            parameters.append(field)

        return parameters
コード例 #2
0
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
    """
    endpoint_enumerator_class = EndpointEnumerator

    def __init__(self, info, version, url=None, patterns=None, urlconf=None):
        """

        :param .Info info: information about the API
        :param str version: API version string, takes precedence over the version in `info`
        :param str url: API base URL (scheme, host and port); if falsy, it is inferred in
            :meth:`.get_schema` from the request made against the schema view
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        """
        self._gen = SchemaGenerator(info.title, url,
                                    info.get('description', ''),
                                    patterns, urlconf)
        self.info = info
        self.version = version

    @property
    def url(self):
        """Base URL held by the wrapped DRF generator."""
        return self._gen.url

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param Request request: the request used for filtering
            accessible endpoints and finding the spec URI
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(request)
        endpoints = self.replace_version(endpoints, request)
        components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
        paths = self.get_paths(endpoints, components, request, public)

        url = self.url
        if not url and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(info=self.info,
                               paths=paths,
                               _url=url,
                               _version=self.version,
                               **dict(components))

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callable callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param rest_framework.request.Request request: request to bind to the view
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the
            # re-instantiated view; a separate loop name keeps the ``method`` argument unshadowed
            for override_method in overrides:
                view_method = getattr(view, override_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema',
                            overrides)
        return view

    def replace_version(self, endpoints, request):
        """If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using
        ``URLPathVersioning`` as a versioning class.

        :param dict endpoints: endpoints as returned by :meth:`.get_endpoints`
        :param Request request: the request made against the schema view
        :return: endpoints with modified paths
        """
        version = getattr(request, 'version', None)
        if version is None:
            return endpoints

        new_endpoints = {}
        for path, endpoint in endpoints.items():
            view_cls = endpoint[0]
            versioning_class = getattr(view_cls, 'versioning_class', None)
            version_param = getattr(versioning_class, 'version_param',
                                    'version')
            # only URL path versioning embeds the version as a template variable in the path
            if versioning_class is not None and issubclass(
                    versioning_class, versioning.URLPathVersioning):
                path = path.replace('{%s}' % version_param, version)

            new_endpoints[path] = endpoint

        return new_endpoints

    def get_endpoints(self, request):
        """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.

        :param rest_framework.request.Request request: request to bind to the endpoint views
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns,
                                                    self._gen.urlconf)
        endpoints = enumerator.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self._gen.coerce_path(path, method, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {
            path: (view_cls[path], methods)
            for path, methods in view_paths.items()
        }

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: tuple
        """
        return self._gen.get_keys(subpath, method, view)

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def get_paths(self, endpoints, components, request, public):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :param Request request: the request made against the schema view; can be None
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :rtype: openapi.Paths
        """
        if not endpoints:
            return openapi.Paths(paths={})

        # ``or ''`` guards the ``path[len(prefix):]`` slicing in get_operation
        prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
        paths = OrderedDict()

        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not public and not self._gen.has_view_permissions(
                        path, method, view):
                    continue

                operation = self.get_operation(view, path, prefix, method,
                                               components, request)
                # endpoints excluded from the schema produce no operation; skip
                # them instead of writing ``None`` into the path item
                if operation is not None:
                    operations[method.lower()] = operation

            if operations:
                paths[path] = self.get_path_item(path, view_cls, operations)

        return openapi.Paths(paths=paths)

    def get_operation(self, view, path, prefix, method, components, request):
        """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
        :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
        according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.

        :param view: the view associated with this endpoint
        :param str path: the path component of the operation URL
        :param str prefix: common path prefix among all endpoints
        :param str method: the http method of the operation
        :param openapi.ReferenceResolver components: referenceable components
        :param Request request: the request made against the schema view; can be None
        :return: the operation, or ``None`` if the endpoint is excluded from the schema
        :rtype: openapi.Operation or None
        """

        operation_keys = self.get_operation_keys(path[len(prefix):], method,
                                                 view)
        overrides = self.get_overrides(view, method)

        # the inspector class can be specified, in increasing order of priority (later entries win),
        #   1. globally via DEFAULT_AUTO_SCHEMA_CLASS
        view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
        #   2. on the view/viewset class
        view_inspector_cls = getattr(view, 'swagger_schema',
                                     view_inspector_cls)
        #   3. on the swagger_auto_schema decorator
        view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

        # an explicit ``None`` at any level excludes the endpoint instead of
        # failing with "'NoneType' object is not callable"
        if view_inspector_cls is None:
            return None

        view_inspector = view_inspector_cls(view, path, method, components,
                                            request, overrides)
        return view_inspector.get_operation(operation_keys)

    def get_path_item(self, path, view_cls, operations):
        """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
        API.

        :param str path: the path
        :param type view_cls: the view that was bound to this path in urlpatterns
        :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
        :rtype: openapi.PathItem
        """
        path_parameters = self.get_path_parameters(path, view_cls)
        return openapi.PathItem(parameters=path_parameters, **operations)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a copy of any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        import copy

        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, '_swagger_auto_schema', {})
        if method in overrides:
            overrides = overrides[method]

        # the dict lives on the decorated function and is shared by every schema
        # generation; hand out a copy so inspectors cannot mutate it
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        queryset = getattr(view_cls, 'queryset', None)

        # sorted for a deterministic parameter order
        for variable in sorted(uritemplate.variables(path)):
            model, model_field = get_queryset_field(queryset, variable)
            attrs = get_basic_type_info(model_field) or {
                'type': openapi.TYPE_STRING
            }
            # ``pattern`` only applies to string parameters
            if (hasattr(view_cls, 'lookup_value_regex')
                    and getattr(view_cls, 'lookup_field', None) == variable
                    and attrs['type'] == openapi.TYPE_STRING):
                attrs['pattern'] = view_cls.lookup_value_regex

            # getattr guards against field-like objects lacking these attributes
            if model_field and getattr(model_field, 'help_text', False):
                description = force_text(model_field.help_text)
            elif model_field and getattr(model_field, 'primary_key', False):
                description = get_pk_description(model, model_field)
            else:
                description = None

            field = openapi.Parameter(name=variable,
                                      description=description,
                                      required=True,
                                      in_=openapi.IN_PATH,
                                      **attrs)
            parameters.append(field)

        return parameters
コード例 #3
0
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator.
    """
    def __init__(self, info, version, url=None, patterns=None, urlconf=None):
        """Create a generator wrapping a rest_framework ``SchemaGenerator``.

        :param .Info info: information about the API
        :param str version: API version string; takes precedence over the version in `info`
        :param str url: API base URL, passed through to the wrapped ``SchemaGenerator``; when it is falsy,
            :meth:`get_schema` falls back to the absolute URI of the request it is given
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        """
        self._gen = SchemaGenerator(info.title, url,
                                    info.get('description',
                                             ''), patterns, urlconf)
        self.info = info
        self.version = version

    def get_schema(self, request=None, public=False):
        """Generate an :class:`.Swagger` representing the API schema.

        :param rest_framework.request.Request request: the request used for filtering
            accessible endpoints and finding the spec URI
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(None if public else request)
        components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
        paths = self.get_paths(endpoints, components)

        url = self._gen.url
        if not url and request is not None:
            # no explicit base URL configured: derive it from the request that asked for the schema
            url = request.build_absolute_uri()

        return openapi.Swagger(info=self.info,
                               paths=paths,
                               _url=url,
                               _version=self.version,
                               **dict(components))

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callable callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param rest_framework.request.Request request: request to bind to the view
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, 'swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            # (a separate loop name keeps the `method` argument from being rebound)
            for method_name in overrides:
                view_method = getattr(view, method_name, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, 'swagger_auto_schema',
                            overrides)
        return view

    def get_endpoints(self, request=None):
        """Iterate over all the registered endpoints in the API.

        :param rest_framework.request.Request request: used for returning only endpoints available to the given request
        :return: {path: (view_class, list[(http_method, view_instance)])}
        :rtype: dict
        """
        inspector = self._gen.endpoint_inspector_cls(self._gen.patterns,
                                                     self._gen.urlconf)
        endpoints = inspector.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self._gen.coerce_path(path, method, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {
            path: (view_cls[path], methods)
            for path, methods in view_paths.items()
        }

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification.

        ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: tuple
        """
        return self._gen.get_keys(subpath, method, view)

    def get_paths(self, endpoints, components):
        """Generate the Swagger Paths for the API from the given endpoints.

        Operations the wrapped generator reports as not permitted for the view are skipped, and paths
        left without any operation are omitted.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :rtype: openapi.Paths
        """
        if not endpoints:
            return openapi.Paths(paths={})

        prefix = self._gen.determine_path_prefix(endpoints.keys())
        paths = OrderedDict()

        default_schema_cls = SwaggerAutoSchema
        for path, (view_cls, methods) in sorted(endpoints.items()):
            path_parameters = self.get_path_parameters(path, view_cls)
            operations = {}
            for method, view in methods:
                if not self._gen.has_view_permissions(path, method, view):
                    continue

                operation_keys = self.get_operation_keys(
                    path[len(prefix):], method, view)
                overrides = self.get_overrides(view, method)
                # @swagger_auto_schema(auto_schema=...) may replace the inspector class per operation
                auto_schema_cls = overrides.get('auto_schema',
                                                default_schema_cls)
                schema = auto_schema_cls(view, path, method, overrides,
                                         components)
                operations[method.lower()] = schema.get_operation(
                    operation_keys)

            if operations:
                paths[path] = openapi.PathItem(parameters=path_parameters,
                                               **operations)

        return openapi.Paths(paths=paths)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, 'swagger_auto_schema', {})
        # decorators applied to function based views store per-method overrides keyed by lowercase method
        if method in overrides:
            overrides = overrides[method]

        return overrides

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        Each variable becomes a required string parameter. When the view's queryset has a model, the matching
        field's help_text (or a primary-key description) is used as description, and an ``AutoField`` makes the
        parameter an integer. If the variable is the view's ``lookup_field`` and the view defines
        ``lookup_value_regex``, that regex is used as the parameter pattern (and the type stays string).

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        model = getattr(getattr(view_cls, 'queryset', None), 'model', None)

        for variable in uritemplate.variables(path):
            pattern = None
            param_type = openapi.TYPE_STRING
            description = None
            model_field = None

            if model is not None:
                # Attempt to infer a field description if possible.
                try:
                    model_field = model._meta.get_field(variable)
                except Exception:  # pragma: no cover
                    model_field = None

                if model_field is not None and model_field.help_text:
                    description = force_text(model_field.help_text)
                elif model_field is not None and model_field.primary_key:
                    description = get_pk_description(model, model_field)

            # the lookup regex describes the URL itself, so it applies even when no model could be resolved
            if hasattr(view_cls, 'lookup_value_regex') and getattr(
                    view_cls, 'lookup_field', None) == variable:
                pattern = view_cls.lookup_value_regex
            elif isinstance(model_field, django.db.models.AutoField):
                param_type = openapi.TYPE_INTEGER

            field = openapi.Parameter(
                name=variable,
                required=True,
                in_=openapi.IN_PATH,
                type=param_type,
                pattern=pattern,
                description=description,
            )
            parameters.append(field)

        return parameters
コード例 #4
0
class OpenAPISchemaGenerator(object):
    """Collect registered API endpoints, grouped by prefix, and summarize each operation's docstring.

    :meth:`get_schema` returns a list of plain dicts (one per endpoint group) that map each path to
    ``{http_method: {"name": ..., "description": ...}}``.
    """
    endpoint_enumerator_class = EndpointEnumerator

    def __init__(self, name='API', version=''):
        """
        :param str name: API title passed to the wrapped rest_framework ``SchemaGenerator``
        :param str version: API version string (stored on the instance)
        """
        self._gen = SchemaGenerator(name, None, '', None, None)
        self.version = version
        self.consumes = []
        self.produces = []

    @property
    def url(self):
        """Base URL of the wrapped ``SchemaGenerator``."""
        return self._gen.url

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance, marked with ``swagger_fake_view = True``
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            # (a separate loop name keeps the `method` argument from being rebound)
            for method_name in overrides:
                view_method = getattr(view, method_name, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema',
                            overrides)

        # lets view code detect that it was instantiated only for schema generation
        setattr(view, 'swagger_fake_view', True)
        return view

    def get_schema(self, request=None, public=False):
        """Return the grouped path summaries for all endpoints (see :meth:`get_paths`).

        :param request: request bound to the created views
        :param bool public: forwarded to :meth:`get_paths`
        """
        endpoints = self.get_endpoints(request)
        return self.get_paths(endpoints, None, request, public=public)

    def get_endpoints(self, request):
        """Enumerate endpoints, grouped by prefix.

        :param request: request passed to the enumerator and bound to each created view
        :return: list of ``{"prefix": ..., "points": {path: (view_class, [(http_method, view_instance), ...])}}``
        :rtype: list[dict]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns,
                                                    self._gen.urlconf,
                                                    request=request)
        # NOTE(review): the enumerator is expected to fill ``final_arrays`` in place with dicts
        # carrying 'prefix' and 'points' keys, as consumed below -- confirm against EndpointEnumerator
        endpoints = []
        enumerator.get_api_endpoints(final_arrays=endpoints)
        ret = []
        for group in endpoints:
            view_paths = defaultdict(list)
            view_cls = {}
            for path, method, callback in group['points']:
                view = self.create_view(callback, method, request)
                # paths are used as enumerated; coerce_path is currently not applied
                view_paths[path].append((method, view))
                view_cls[path] = callback.cls
            # each group is returned as a JSON-style dict
            ret.append({
                "prefix": group['prefix'],
                "points": {
                    path: (view_cls[path], methods)
                    for path, methods in view_paths.items()
                },
            })
        return ret

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def _get_description_section(self, view, header, description):
        """Return the docstring section for ``header``, falling back to the un-headed text.

        Lines matching ``header_regex`` start a new section named by the text before the first ``:``;
        ``SCHEMA_COERCE_METHOD_NAMES`` is consulted when ``header`` itself has no section.

        :param view: the view (not used by this implementation)
        :param str header: section name to look up (an action or lowercase HTTP method)
        :param str description: full description text
        :rtype: str
        """
        current_section = ''
        sections = {'': ''}

        for line in description.splitlines():
            if header_regex.match(line):
                current_section, _, lead = line.partition(':')
                sections[current_section] = lead.strip()
            else:
                sections[current_section] += '\n' + line

        # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
        coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
        if header in sections:
            return sections[header].strip()
        if header in coerce_method_names:
            if coerce_method_names[header] in sections:
                return sections[coerce_method_names[header]].strip()
        return sections[''].strip()

    def split_summary_from_description(self, description):
        """Decide if and how to split a summary out of the given description. The default implementation
        uses the first line of the description as a summary if it is shorter than 120 characters and
        more lines follow it.

        :param description: the full description to be analyzed
        :return: summary and description
        :rtype: (str,str)
        """
        # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
        summary = None
        summary_max_len = 120  # OpenAPI 2.0 spec says summary should be under 120 characters
        sections = description.split('\n', 1)
        if len(sections) == 2:
            sections[0] = sections[0].strip()
            if len(sections[0]) < summary_max_len:
                summary, description = sections
                description = description.strip()

        return summary, description

    def get_paths(self, top, components, request, public):
        """
        Generate per-group path summaries from the endpoints returned by :meth:`get_endpoints`.

        :param list top: endpoint groups as returned by get_endpoints
        :param components: not used by this implementation
        :param request: not used by this implementation
        :param bool public: not used by this implementation (endpoint filtering is currently disabled)
        :return: ``None`` when ``top`` is empty; otherwise a list with one
            ``{"basePath", "basePathArray", "paths"}`` dict per non-empty group
        """
        if not top:
            return None

        ret = []
        for category in top:
            endpoints = category['points']
            if not endpoints:
                continue
            prefix = self.determine_path_prefix(list(endpoints.keys())) or ''

            assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

            paths = {}
            for path, (view_cls, methods) in sorted(endpoints.items()):
                operations = {}
                for method, view in methods:
                    method_lower = method.lower()
                    method_name = getattr(view, 'action', method_lower)
                    handler = getattr(view, method_name, None)
                    # check for a missing handler explicitly instead of reading None.__doc__
                    method_docstring = handler.__doc__ if handler is not None else None
                    if method_docstring:
                        method_docstring = self._get_description_section(
                            view, method_lower,
                            formatting.dedent(smart_text(method_docstring)))
                    else:
                        method_docstring = self._get_description_section(
                            view, method_name, view.get_view_description())
                    if not method_docstring:
                        method_docstring = "_"

                    summary, desc = self.split_summary_from_description(method_docstring)
                    if summary is None:
                        summary = desc
                    operations[method_lower] = {"name": summary, "description": desc}

                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = operations

            ret.append({
                "basePath": prefix,
                "basePathArray": prefix.strip("/").split("/"),
                "paths": paths,
            })

        return ret