def test_4605_regression(self):
    """Regression check #4605: paths sharing no leading segment yield the root prefix."""
    unrelated_paths = ['/api/v1/items/', '/auth/convert-token/']
    computed_prefix = SchemaGenerator().determine_path_prefix(unrelated_paths)
    assert computed_prefix == '/'
예제 #2
0
 def test_4605_regression(self):
     """Ticket 4605: '/api/...' and '/auth/...' only share the root, so the prefix is '/'."""
     gen = SchemaGenerator()
     expected = '/'
     assert gen.determine_path_prefix(['/api/v1/items/', '/auth/convert-token/']) == expected
예제 #3
0
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
    """
    endpoint_enumerator_class = EndpointEnumerator
    reference_resolver_class = ReferenceResolver

    def __init__(self,
                 info,
                 version='',
                 url=None,
                 patterns=None,
                 urlconf=None,
                 swagger_settings=_swagger_settings):
        """

        :param openapi.Info info: information about the API
        :param str version: API version string; if omitted, `info.default_version` will be used
        :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
            will be inferred from the request made against the schema view, so you should generally not need to set
            this parameter explicitly; if the empty string is passed, no host and scheme will be emitted

            If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
            or https://), and any path component is ignored;

            See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        :param swagger_settings: if given global swagger_settings are overridden with local settings
        :raises SwaggerGenerationError: if the effective url is not an absolute http(s) url
        """
        self.info = info
        self.swagger_settings = swagger_settings
        self.version = version
        self.consumes = []
        self.produces = []

        # The DEFAULT_API_URL fallback must be resolved *before* the wrapped DRF generator is
        # built: the ``url`` property reads ``self._gen.url``, so resolving it afterwards would
        # validate the default and then silently drop it.
        if url is None and swagger_settings.DEFAULT_API_URL is not None:
            url = swagger_settings.DEFAULT_API_URL

        if url:
            parsed_url = urlparse.urlparse(url)
            if parsed_url.scheme not in ('http',
                                         'https') or not parsed_url.netloc:
                raise SwaggerGenerationError(
                    "`url` must be an absolute HTTP(S) url")
            if parsed_url.path:
                logger.warning(
                    "path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead",
                    url)

        self._gen = SchemaGenerator(info.title, url,
                                    info.get('description',
                                             ''), patterns, urlconf)

    @property
    def url(self):
        """Base API url held by the wrapped DRF generator; ``None`` when it must be inferred from the request."""
        return self._gen.url

    def get_security_definitions(self):
        """Get the security schemes for this API. This determines what is usable in security requirements,
        and helps clients configure their authorization credentials.

        :return: the security schemes usable with this API
        :rtype: dict[str,dict] or None
        """
        security_definitions = self.swagger_settings.SECURITY_DEFINITIONS
        if security_definitions is not None:
            security_definitions = SwaggerDict._as_odict(
                security_definitions, {})

        return security_definitions

    def get_security_requirements(self, security_definitions):
        """Get the base (global) security requirements of the API. This is never called if
        :meth:`.get_security_definitions` returns `None`.

        :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
        :return: the security schemes accepted by default
        :rtype: list[dict[str,list[str]]] or None
        """
        security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS
        if security_requirements is None:
            # no explicit requirements: accept any one of the defined schemes, with no scopes
            security_requirements = [{
                security_scheme: []
            } for security_scheme in security_definitions]

        security_requirements = [
            SwaggerDict._as_odict(sr, {}) for sr in security_requirements
        ]
        # sort by each requirement's list of scheme names for a deterministic output order
        security_requirements = sorted(security_requirements, key=list)
        return security_requirements

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param request: the request used for filtering accessible endpoints and finding the spec URI
        :type request: rest_framework.request.Request or None
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(request)
        components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS,
                                                   force_init=True)
        self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
        self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES,
                                     self.swagger_settings)
        paths, prefix = self.get_paths(endpoints, components, request, public)

        security_definitions = self.get_security_definitions()
        if security_definitions:
            security_requirements = self.get_security_requirements(
                security_definitions)
        else:
            security_requirements = None

        # fall back to the url of the incoming request when no base url was configured
        url = self.url
        if url is None and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(info=self.info,
                               paths=paths,
                               consumes=self.consumes or None,
                               produces=self.produces or None,
                               security_definitions=security_definitions,
                               security=security_requirements,
                               _url=url,
                               _prefix=prefix,
                               _version=self.version,
                               **dict(components))

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for overridden_method, _ in overrides.items():
                view_method = getattr(view, overridden_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema',
                            overrides)

        setattr(view, 'swagger_fake_view', True)
        return view

    def coerce_path(self, path, view):
        """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
        external representation (i.e. "this is an identifier", not "this is a database primary key").

        :param str path: the path
        :param rest_framework.views.APIView view: associated view
        :rtype: str
        """
        if '{pk}' not in path:
            return path

        model = getattr(get_queryset_from_view(view), 'model', None)
        if model:
            field_name = get_pk_name(model)
        else:
            field_name = 'id'
        return path.replace('{pk}', '{%s}' % field_name)

    def get_endpoints(self, request):
        """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.

        :param request: request to bind to the endpoint views
        :type request: rest_framework.request.Request or None
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns,
                                                    self._gen.urlconf,
                                                    request=request)
        endpoints = enumerator.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {
            path: (view_cls[path], methods)
            for path, methods in view_paths.items()
        }

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: list[str]
        """
        return self._gen.get_keys(subpath, method, view)

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def should_include_endpoint(self, path, method, view, public):
        """Check if a given endpoint should be included in the resulting schema.

        :param str path: request path
        :param str method: http request method
        :param view: instantiated view callback
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: true if the view should be included (always when `public` is set, otherwise only when the
            request has permission to access it)
        :rtype: bool
        """
        return public or self._gen.has_view_permissions(path, method, view)

    def get_paths_object(self, paths):
        """Construct the Swagger Paths object.

        :param OrderedDict[str,openapi.PathItem] paths: mapping of paths to :class:`.PathItem` objects
        :returns: the :class:`.Paths` object
        :rtype: openapi.Paths
        """
        return openapi.Paths(paths=paths)

    def get_paths(self, endpoints, components, request, public):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :param Request request: the request made against the schema view; can be None
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
        :rtype: tuple[openapi.Paths,str]
        """
        if not endpoints:
            return openapi.Paths(paths={}), ''

        prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
        assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

        paths = OrderedDict()
        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not self.should_include_endpoint(path, method, view,
                                                    public):
                    continue

                operation = self.get_operation(view, path, prefix, method,
                                               components, request)
                if operation is not None:
                    operations[method.lower()] = operation

            if operations:
                # since the common prefix is used as the API basePath, it must be stripped
                # from individual paths when writing them into the swagger document
                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = self.get_path_item(
                    path, view_cls, operations)

        return self.get_paths_object(paths), prefix

    def get_operation(self, view, path, prefix, method, components, request):
        """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
        :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
        according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.

        :param view: the view associated with this endpoint
        :param str path: the path component of the operation URL
        :param str prefix: common path prefix among all endpoints
        :param str method: the http method of the operation
        :param openapi.ReferenceResolver components: referenceable components
        :param Request request: the request made against the schema view; can be None
        :rtype: openapi.Operation
        """
        operation_keys = self.get_operation_keys(path[len(prefix):], method,
                                                 view)
        overrides = self.get_overrides(view, method)

        # the inspector class can be specified, in decreasing order of priority,
        #   1. globally via DEFAULT_AUTO_SCHEMA_CLASS
        view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
        #   2. on the view/viewset class
        view_inspector_cls = getattr(view, 'swagger_schema',
                                     view_inspector_cls)
        #   3. on the swagger_auto_schema decorator
        view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

        # an explicit None at any level excludes the operation from the schema
        if view_inspector_cls is None:
            return None

        view_inspector = view_inspector_cls(view, path, method, components,
                                            request, overrides, operation_keys,
                                            self.swagger_settings)
        operation = view_inspector.get_operation(operation_keys)
        if operation is None:
            return None

        # drop per-operation consumes/produces that merely repeat the global defaults
        if 'consumes' in operation and set(operation.consumes) == set(
                self.consumes):
            del operation.consumes
        if 'produces' in operation and set(operation.produces) == set(
                self.produces):
            del operation.produces
        return operation

    def get_path_item(self, path, view_cls, operations):
        """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
        API.

        :param str path: the path
        :param type view_cls: the view that was bound to this path in urlpatterns
        :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
        :rtype: openapi.PathItem
        """
        path_parameters = self.get_path_parameters(path, view_cls)
        return openapi.PathItem(parameters=path_parameters, **operations)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, '_swagger_auto_schema', {})
        # function-based views store overrides keyed by http method
        if method in overrides:
            overrides = overrides[method]

        # deep copy so that inspectors can mutate the result without touching the decorator's data
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        queryset = get_queryset_from_view(view_cls)

        for variable in sorted(uritemplate.variables(path)):
            model, model_field = get_queryset_field(queryset, variable)
            attrs = get_basic_type_info(model_field) or {
                'type': openapi.TYPE_STRING
            }
            # a string lookup field can advertise the view's lookup_value_regex as its pattern
            if getattr(
                    view_cls, 'lookup_field',
                    None) == variable and attrs['type'] == openapi.TYPE_STRING:
                attrs['pattern'] = getattr(view_cls, 'lookup_value_regex',
                                           attrs.get('pattern', None))

            if model_field and getattr(model_field, 'help_text', False):
                description = model_field.help_text
            elif model_field and getattr(model_field, 'primary_key', False):
                description = get_pk_description(model, model_field)
            else:
                description = None

            field = openapi.Parameter(name=variable,
                                      description=force_real_str(description),
                                      required=True,
                                      in_=openapi.IN_PATH,
                                      **attrs)
            parameters.append(field)

        return parameters
예제 #4
0
class OpenAPISchemaGenerator(object):
    """Enumerate the registered API endpoints and flatten them into records for an API registry.

    Adapted from the drf-yasg generator. Instead of a Swagger document, :meth:`get_schema` returns a
    list of dicts (matching path, original path, method, description, name, regex flag), one per
    endpoint operation.
    """
    endpoint_enumerator_class = EndpointEnumerator

    def __init__(self, name='API', version=''):
        """
        :param str name: API title passed to the wrapped DRF ``SchemaGenerator``
        :param str version: API version string, stored on the instance
        """
        # positional args: title, url, description, patterns, urlconf
        self._gen = SchemaGenerator(name, None, '', None, None)
        self.version = version
        self.consumes = []
        self.produces = []

    @property
    def url(self):
        """Base url held by the wrapped DRF generator (constructed with ``None`` here)."""
        return self._gen.url

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for overridden_method, _ in overrides.items():
                view_method = getattr(view, overridden_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema', overrides)

        setattr(view, 'swagger_fake_view', True)
        return view

    def get_schema(self, request=None, public=False):
        """Build the flat list of registry entries for every enumerated endpoint operation.

        :param request: request used to enumerate and instantiate the views; can be None
        :param bool public: forwarded to :meth:`get_register_api`, which currently does not filter on it
        :return: entries as produced by :meth:`handle_api_path` (an empty list if nothing is registered)
        :rtype: list[dict]
        """
        endpoints = self.get_endpoints(request)
        register_api = self.get_register_api(endpoints, None, request, public=public)
        return self.handle_api_path(register_api)

    def get_endpoints(self, request):
        """Enumerate the endpoints, grouped as reported by the endpoint enumerator.

        :param request: request to bind to the endpoint views; can be None
        :return: one ``{"prefix": ..., "points": {path: (view_class, [(http_method, view), ...])}}``
            dict per group
        :rtype: list[dict]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
        groups = []
        # NOTE(review): the project-specific enumerator fills ``final_arrays`` in place; each group is
        # read below as a dict with 'prefix' and 'points' ((path, method, callback) triples) — confirm
        enumerator.get_api_endpoints(final_arrays=groups)

        ret = []
        for group in groups:
            view_paths = defaultdict(list)
            view_cls = {}
            for path, method, callback in group['points']:
                view = self.create_view(callback, method, request)
                view_paths[path].append((method, view))
                view_cls[path] = callback.cls
            # each group becomes a JSON-style dict
            ret.append({
                "prefix": group['prefix'],
                "points": {path: (view_cls[path], methods) for path, methods in view_paths.items()},
            })
        return ret

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.
        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.
        For example: ::
            /api/v1/users/
            /api/v1/users/{pk}/
        The path prefix is ``/api/v1/``.
        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def _get_description_section(self, view, header, description):
        """Return the docstring section for ``header``, falling back to the unlabelled text.

        Lines matching ``header_regex`` start a new section named by the text before the first ``:``;
        the text after it begins that section's content.

        :param view: unused; kept for signature compatibility
        :param str header: section to look up (e.g. a lowercase http method or action name)
        :param str description: full docstring text
        :rtype: str
        """
        current_section = ''
        sections = {'': ''}

        for line in description.splitlines():
            if header_regex.match(line):
                current_section, _, lead = line.partition(':')
                sections[current_section] = lead.strip()
            else:
                sections[current_section] += '\n' + line

        # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
        coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
        if header in sections:
            return sections[header].strip()
        if header in coerce_method_names:
            if coerce_method_names[header] in sections:
                return sections[coerce_method_names[header]].strip()
        return sections[''].strip()

    def split_summary_from_description(self, description):
        """Decide if and how to split a summary out of the given description. The default implementation
        uses the first line of the description as a summary if, once stripped, it is shorter than 120
        characters.

        :param description: the full description to be analyzed
        :return: summary and description; the summary is None when no split was made
        :rtype: (str,str)
        """
        # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
        summary = None
        summary_max_len = 120  # OpenAPI 2.0 spec says summary should be under 120 characters
        sections = description.split('\n', 1)
        if len(sections) == 2:
            sections[0] = sections[0].strip()
            if len(sections[0]) < summary_max_len:
                summary, description = sections
                description = description.strip()

        return summary, description

    def _get_operation_summary(self, view, method):
        """Return ``(name, desc)`` for one operation, derived from the view's docstrings.

        Prefers the docstring of the handler named by ``view.action`` (or the lowercase http method),
        otherwise the view-level description; ``"_"`` is used when both are empty. When no summary line
        can be split off, the description doubles as the name.
        """
        method_lower = method.lower()
        method_name = getattr(view, 'action', method_lower)
        handler = getattr(view, method_name, None)
        # guard explicitly: ``None.__doc__`` is NoneType's docstring, not necessarily None
        method_docstring = handler.__doc__ if handler is not None else None
        if method_docstring:
            description = self._get_description_section(
                view, method_lower, formatting.dedent(smart_text(method_docstring)))
        else:
            description = self._get_description_section(
                view, method_name, view.get_view_description())
        if not description:
            description = "_"

        name, desc = self.split_summary_from_description(description)
        if name is None:
            name = desc
        return name, desc

    def handle_api_path(self, api_list):
        """Convert registry entries into path-matching entries.

        Every ``{param}`` path template is replaced by ``[^/]*``; paths that end up containing such a
        wildcard are anchored with ``^``/``$`` and flagged with ``is_regex``.

        :param api_list: entries as returned by :meth:`get_register_api`; None is treated as empty
        :return: dicts with keys path, origin_path, method, desc (max 500 chars), api_name (max 30
            chars) and is_regex
        :rtype: list[dict]
        """
        result = []
        if not api_list:
            return result

        # a templated path parameter such as ``{pk}``; the lookahead keeps a match inside one segment
        pattern_param = re.compile(r'\{((?!\/).)*\}')
        pattern_str = '[^/]*'
        for api in api_list:
            origin_path = api['origin_path']
            # NOTE(review): literal path text is not regex-escaped (e.g. '.' stays a metacharacter);
            # left as-is because registry consumers may depend on the current string form
            regex_path = pattern_param.sub(pattern_str, origin_path)
            is_regex = pattern_str in regex_path
            if is_regex:
                regex_path = "^" + regex_path + "$"

            method = api['method']
            result.append({
                "path": regex_path,
                "origin_path": origin_path,
                "method": method,
                "desc": api["desc"][:500],
                "api_name": api["name"][:30],
                "is_regex": is_regex,
            })
            logger.info("path=%s, method=%s, regx=%s", origin_path, method, regex_path)
        return result

    def get_register_api(self, top, components, request, public):
        """Rebuild the endpoint groups into the registry's flat data format.

        :param list[dict] top: groups as returned by :meth:`get_endpoints`
        :param components: unused
        :param request: unused
        :param public: unused
        :return: one ``{"name", "method", "origin_path", "desc"}`` dict per operation, or None when
            ``top`` is empty
        :rtype: list[dict] or None
        """
        if not top:
            return None
        result = []
        for category in top:
            endpoints = category['points']
            if not endpoints:
                continue
            for path, (view_cls, methods) in sorted(endpoints.items()):
                for method, view in methods:
                    name, desc = self._get_operation_summary(view, method)
                    result.append({
                        "name": name,
                        "method": method.lower(),
                        "origin_path": path,
                        "desc": desc,
                    })
        return result

    def get_paths(self, top, components, request, public):
        """Generate Swagger-like path listings, one per endpoint group.

        :param list[dict] top: groups as returned by :meth:`get_endpoints`
        :return: ``{"basePath", "basePathArray", "paths"}`` dicts, where ``paths`` maps each path
            (relative to the group's common prefix) to ``{method: {"name", "description"}}``; None
            when ``top`` is empty
        :rtype: list[dict] or None
        """
        if not top:
            return None

        ret = []
        for category in top:
            endpoints = category['points']
            if not endpoints:
                continue
            prefix = self.determine_path_prefix(list(endpoints.keys())) or ''

            assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

            paths = {}
            for path, (view_cls, methods) in sorted(endpoints.items()):
                operations = {}
                for method, view in methods:
                    name, desc = self._get_operation_summary(view, method)
                    operations[method.lower()] = {"name": name, "description": desc}

                # the common prefix becomes basePath, so strip it from the individual paths
                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = operations

            ret.append({"basePath": prefix, "basePathArray": prefix.strip("/").split("/"), "paths": paths})

        return ret