class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
    """
    endpoint_enumerator_class = EndpointEnumerator
    reference_resolver_class = ReferenceResolver

    def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
        """
        :param openapi.Info info: information about the API
        :param str version: API version string; if omitted, `info.default_version` will be used
        :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
            will be inferred from the request made against the schema view, so you should generally not need to set
            this parameter explicitly; if the empty string is passed, no host and scheme will be emitted

            If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with
            http:// or https://), and any path component is ignored;

            See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        :raises SwaggerGenerationError: if the effective url (explicit or ``DEFAULT_API_URL``) is not an absolute
            HTTP(S) url
        """
        # Apply the DEFAULT_API_URL fallback *before* building the wrapped DRF generator: the `url` property reads
        # ``self._gen.url``, so a fallback applied afterwards would be validated but never actually used.
        if url is None and swagger_settings.DEFAULT_API_URL is not None:
            url = swagger_settings.DEFAULT_API_URL

        if url:
            parsed_url = urlparse.urlparse(url)
            if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
                raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
            if parsed_url.path:
                logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead", url)

        self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
        self.info = info
        self.version = version
        # filled in by get_schema from the DRF parser/renderer settings
        self.consumes = []
        self.produces = []

    @property
    def url(self):
        """The API base url held by the wrapped DRF ``SchemaGenerator``."""
        return self._gen.url

    def get_security_definitions(self):
        """Get the security schemes for this API. This determines what is usable in security requirements,
        and helps clients configure their authorization credentials.

        :return: the security schemes usable with this API
        :rtype: dict[str,dict] or None
        """
        security_definitions = swagger_settings.SECURITY_DEFINITIONS
        if security_definitions is not None:
            security_definitions = SwaggerDict._as_odict(security_definitions, {})

        return security_definitions

    def get_security_requirements(self, security_definitions):
        """Get the base (global) security requirements of the API. This is never called if
        :meth:`.get_security_definitions` returns `None`.

        :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
        :return: the security schemes accepted by default
        :rtype: list[dict[str,list[str]]] or None
        """
        security_requirements = swagger_settings.SECURITY_REQUIREMENTS
        if security_requirements is None:
            # default: every defined scheme is accepted, with no scopes
            security_requirements = [{security_scheme: []} for security_scheme in security_definitions]

        security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements]
        # sort by the list of scheme names so the output order is stable
        security_requirements = sorted(security_requirements, key=list)
        return security_requirements

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param request: the request used for filtering accessible endpoints and finding the spec URI
        :type request: rest_framework.request.Request or None
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(request)
        components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True)
        self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
        self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
        paths, prefix = self.get_paths(endpoints, components, request, public)

        security_definitions = self.get_security_definitions()
        if security_definitions:
            security_requirements = self.get_security_requirements(security_definitions)
        else:
            security_requirements = None

        url = self.url
        if url is None and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(
            info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
            security_definitions=security_definitions, security=security_requirements,
            _url=url, _prefix=prefix, _version=self.version, **dict(components)
        )

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for overridden_method in overrides:
                view_method = getattr(view, overridden_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema', overrides)

        # mark the instance as a schema-generation view
        setattr(view, 'swagger_fake_view', True)
        return view

    def coerce_path(self, path, view):
        """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
        external representation (i.e. "this is an identifier", not "this is a database primary key").

        :param str path: the path
        :param rest_framework.views.APIView view: associated view
        :rtype: str
        """
        if '{pk}' not in path:
            return path

        model = getattr(get_queryset_from_view(view), 'model', None)
        if model:
            field_name = get_pk_name(model)
        else:
            field_name = 'id'
        return path.replace('{pk}', '{%s}' % field_name)

    def get_endpoints(self, request):
        """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.

        :param request: request to bind to the endpoint views
        :type request: rest_framework.request.Request or None
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
        endpoints = enumerator.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {path: (view_cls[path], methods) for path, methods in view_paths.items()}

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: list[str]
        """
        return self._gen.get_keys(subpath, method, view)

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def should_include_endpoint(self, path, method, view, public):
        """Check if a given endpoint should be included in the resulting schema.

        :param str path: request path
        :param str method: http request method
        :param view: instantiated view callback
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: true if the view should be included, false if it should be excluded
        :rtype: bool
        """
        return public or self._gen.has_view_permissions(path, method, view)

    def get_paths_object(self, paths):
        """Construct the Swagger Paths object.

        :param OrderedDict[str,openapi.PathItem] paths: mapping of paths to :class:`.PathItem` objects
        :returns: the :class:`.Paths` object
        :rtype: openapi.Paths
        """
        return openapi.Paths(paths=paths)

    def get_paths(self, endpoints, components, request, public):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :param Request request: the request made against the schema view; can be None
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
        :rtype: tuple[openapi.Paths,str]
        """
        if not endpoints:
            # go through the overridable hook so subclasses see the empty case too
            return self.get_paths_object(OrderedDict()), ''

        prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
        assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

        paths = OrderedDict()
        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not self.should_include_endpoint(path, method, view, public):
                    continue

                operation = self.get_operation(view, path, prefix, method, components, request)
                if operation is not None:
                    operations[method.lower()] = operation

            if operations:
                # since the common prefix is used as the API basePath, it must be stripped
                # from individual paths when writing them into the swagger document
                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = self.get_path_item(path, view_cls, operations)

        return self.get_paths_object(paths), prefix

    def get_operation(self, view, path, prefix, method, components, request):
        """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
        :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
        according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.

        :param view: the view associated with this endpoint
        :param str path: the path component of the operation URL
        :param str prefix: common path prefix among all endpoints
        :param str method: the http method of the operation
        :param openapi.ReferenceResolver components: referenceable components
        :param Request request: the request made against the schema view; can be None
        :rtype: openapi.Operation or None
        """
        operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
        overrides = self.get_overrides(view, method)

        # the inspector class can be specified, in decreasing order of priority,
        #   1. globally via DEFAULT_AUTO_SCHEMA_CLASS
        view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
        #   2. on the view/viewset class
        view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
        #   3. on the swagger_auto_schema decorator
        view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

        if view_inspector_cls is None:
            # an inspector class of None excludes the operation from the schema
            return None

        view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys)
        operation = view_inspector.get_operation(operation_keys)
        if operation is None:
            return None

        # drop per-operation media types that merely repeat the global ones
        if 'consumes' in operation and set(operation.consumes) == set(self.consumes):
            del operation.consumes
        if 'produces' in operation and set(operation.produces) == set(self.produces):
            del operation.produces
        return operation

    def get_path_item(self, path, view_cls, operations):
        """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
        API.

        :param str path: the path
        :param type view_cls: the view that was bound to this path in urlpatterns
        :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
        :rtype: openapi.PathItem
        """
        path_parameters = self.get_path_parameters(path, view_cls)
        return openapi.PathItem(parameters=path_parameters, **operations)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, '_swagger_auto_schema', {})
        if method in overrides:
            overrides = overrides[method]

        # deep copy so that callers cannot mutate the decorator's shared data
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        queryset = get_queryset_from_view(view_cls)

        # sorted: the variables come back unordered; sorting keeps the output deterministic
        for variable in sorted(uritemplate.variables(path)):
            model, model_field = get_queryset_field(queryset, variable)
            attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
            if getattr(view_cls, 'lookup_field', None) == variable and attrs['type'] == openapi.TYPE_STRING:
                attrs['pattern'] = getattr(view_cls, 'lookup_value_regex', attrs.get('pattern', None))

            if model_field and getattr(model_field, 'help_text', False):
                description = model_field.help_text
            elif model_field and getattr(model_field, 'primary_key', False):
                description = get_pk_description(model, model_field)
            else:
                description = None

            field = openapi.Parameter(
                name=variable,
                description=force_real_str(description),
                required=True,
                in_=openapi.IN_PATH,
                **attrs
            )
            parameters.append(field)

        return parameters
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
    """
    endpoint_enumerator_class = EndpointEnumerator

    def __init__(self, info, version, url=None, patterns=None, urlconf=None):
        """
        :param .Info info: information about the API
        :param str version: API version string, takes precedence over the version in `info`
        :param str url: API base url; if not given (or empty), :meth:`.get_schema` builds it from the request made
            against the schema view, when there is one
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        """
        self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
        self.info = info
        self.version = version

    @property
    def url(self):
        """The API base url held by the wrapped DRF ``SchemaGenerator``."""
        return self._gen.url

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param Request request: the request used for filtering accessible endpoints and finding the spec URI
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        endpoints = self.get_endpoints(request)
        endpoints = self.replace_version(endpoints, request)
        components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
        paths = self.get_paths(endpoints, components, request, public)

        url = self.url
        if not url and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(info=self.info, paths=paths, _url=url, _version=self.version, **dict(components))

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callable callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param rest_framework.request.Request request: request to bind to the view
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for overridden_method in overrides:
                view_method = getattr(view, overridden_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema', overrides)
        return view

    def replace_version(self, endpoints, request):
        """If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using
        ``URLPathVersioning`` as a versioning class.

        :param dict endpoints: endpoints as returned by :meth:`.get_endpoints`
        :param Request request: the request made against the schema view
        :return: endpoints with modified paths
        """
        version = getattr(request, 'version', None)
        if version is None:
            return endpoints

        new_endpoints = {}
        for path, endpoint in endpoints.items():
            view_cls = endpoint[0]
            versioning_class = getattr(view_cls, 'versioning_class', None)
            version_param = getattr(versioning_class, 'version_param', 'version')
            if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
                path = path.replace('{%s}' % version_param, version)

            new_endpoints[path] = endpoint

        return new_endpoints

    def get_endpoints(self, request):
        """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.

        :param rest_framework.request.Request request: request to bind to the endpoint views
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf)
        endpoints = enumerator.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self._gen.coerce_path(path, method, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {path: (view_cls[path], methods) for path, methods in view_paths.items()}

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: tuple
        """
        return self._gen.get_keys(subpath, method, view)

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def get_paths(self, endpoints, components, request, public):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :param Request request: the request made against the schema view; can be None
        :param bool public: if True, all endpoints are included regardless of access through `request`
        :rtype: openapi.Paths
        """
        if not endpoints:
            return openapi.Paths(paths={})

        # `or ''` guards the path slicing in get_operation against a missing prefix
        prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
        paths = OrderedDict()
        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not public and not self._gen.has_view_permissions(path, method, view):
                    continue

                operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request)

            if operations:
                paths[path] = self.get_path_item(path, view_cls, operations)

        return openapi.Paths(paths=paths)

    def get_operation(self, view, path, prefix, method, components, request):
        """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
        :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
        according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.

        :param view: the view associated with this endpoint
        :param str path: the path component of the operation URL
        :param str prefix: common path prefix among all endpoints
        :param str method: the http method of the operation
        :param openapi.ReferenceResolver components: referenceable components
        :param Request request: the request made against the schema view; can be None
        :rtype: openapi.Operation
        """
        operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
        overrides = self.get_overrides(view, method)

        # the inspector class can be specified, in decreasing order of priority,
        #   1. globally via DEFAULT_AUTO_SCHEMA_CLASS
        view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
        #   2. on the view/viewset class
        view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
        #   3. on the swagger_auto_schema decorator
        view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

        view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
        return view_inspector.get_operation(operation_keys)

    def get_path_item(self, path, view_cls, operations):
        """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
        API.

        :param str path: the path
        :param type view_cls: the view that was bound to this path in urlpatterns
        :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
        :rtype: openapi.PathItem
        """
        path_parameters = self.get_path_parameters(path, view_cls)
        return openapi.PathItem(parameters=path_parameters, **operations)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a copy of any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        import copy

        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, '_swagger_auto_schema', {})
        if method in overrides:
            overrides = overrides[method]

        # deep copy so that inspectors cannot mutate the decorator's shared data between operations
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        queryset = getattr(view_cls, 'queryset', None)

        # sorted: the variables come back unordered; sorting keeps the output deterministic
        for variable in sorted(uritemplate.variables(path)):
            model, model_field = get_queryset_field(queryset, variable)
            attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
            if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
                attrs['pattern'] = view_cls.lookup_value_regex

            if model_field and getattr(model_field, 'help_text', False):
                description = force_text(model_field.help_text)
            elif model_field and getattr(model_field, 'primary_key', False):
                description = get_pk_description(model, model_field)
            else:
                description = None

            field = openapi.Parameter(
                name=variable,
                description=description,
                required=True,
                in_=openapi.IN_PATH,
                **attrs
            )
            parameters.append(field)

        return parameters
class OpenAPISchemaGenerator(object):
    """
    This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
    Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator.
    """

    def __init__(self, info, version, url=None, patterns=None, urlconf=None):
        """
        :param .Info info: information about the API
        :param str version: API version string, takes precedence over the version in `info`
        :param str url: API base url; if not given (or empty), :meth:`.get_schema` builds it from the request made
            against the schema view, when there is one
        :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
        :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
            if not given, the default urlconf is used
        """
        self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
        self.info = info
        self.version = version

    def get_schema(self, request=None, public=False):
        """Generate a :class:`.Swagger` object representing the API schema.

        :param rest_framework.request.Request request: the request used for filtering accessible endpoints and
            finding the spec URI
        :param bool public: if True, all endpoints are included regardless of access through `request`

        :return: the generated Swagger specification
        :rtype: openapi.Swagger
        """
        # For public schemas the views are built without a request; presumably this makes the permission check
        # in get_paths accept every view (verify against DRF's has_view_permissions).
        endpoints = self.get_endpoints(None if public else request)
        components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
        paths = self.get_paths(endpoints, components)

        url = self._gen.url
        if not url and request is not None:
            url = request.build_absolute_uri()

        return openapi.Swagger(info=self.info, paths=paths, _url=url, _version=self.version, **dict(components))

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callable callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param rest_framework.request.Request request: request to bind to the view
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, 'swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for overridden_method in overrides:
                view_method = getattr(view, overridden_method, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, 'swagger_auto_schema', overrides)
        return view

    def get_endpoints(self, request=None):
        """Iterate over all the registered endpoints in the API.

        :param rest_framework.request.Request request: used for returning only endpoints available to the given request
        :return: {path: (view_class, list[(http_method, view_instance)])
        :rtype: dict
        """
        inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf)
        endpoints = inspector.get_api_endpoints()

        view_paths = defaultdict(list)
        view_cls = {}
        for path, method, callback in endpoints:
            view = self.create_view(callback, method, request)
            path = self._gen.coerce_path(path, method, view)
            view_paths[path].append((method, view))
            view_cls[path] = callback.cls
        return {path: (view_cls[path], methods) for path, methods in view_paths.items()}

    def get_operation_keys(self, subpath, method, view):
        """Return a list of keys that should be used to group an operation within the specification. ::

          /users/                   ("users", "list"), ("users", "create")
          /users/{pk}/              ("users", "read"), ("users", "update"), ("users", "delete")
          /users/enabled/           ("users", "enabled")  # custom viewset list action
          /users/{pk}/star/         ("users", "star")     # custom viewset detail action
          /users/{pk}/groups/       ("users", "groups", "list"), ("users", "groups", "create")
          /users/{pk}/groups/{pk}/  ("users", "groups", "read"), ("users", "groups", "update")

        :param str subpath: path to the operation with any common prefix/base path removed
        :param str method: HTTP method
        :param view: the view associated with the operation
        :rtype: tuple
        """
        return self._gen.get_keys(subpath, method, view)

    def get_paths(self, endpoints, components):
        """Generate the Swagger Paths for the API from the given endpoints.

        :param dict endpoints: endpoints as returned by get_endpoints
        :param ReferenceResolver components: resolver/container for Swagger References
        :rtype: openapi.Paths
        """
        if not endpoints:
            return openapi.Paths(paths={})

        prefix = self._gen.determine_path_prefix(endpoints.keys())
        paths = OrderedDict()

        default_schema_cls = SwaggerAutoSchema
        for path, (view_cls, methods) in sorted(endpoints.items()):
            operations = {}
            for method, view in methods:
                if not self._gen.has_view_permissions(path, method, view):
                    continue

                operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
                overrides = self.get_overrides(view, method)
                auto_schema_cls = overrides.get('auto_schema', default_schema_cls)
                schema = auto_schema_cls(view, path, method, overrides, components)
                operations[method.lower()] = schema.get_operation(operation_keys)

            if operations:
                # path parameters are only needed for paths that actually end up in the document
                path_parameters = self.get_path_parameters(path, view_cls)
                paths[path] = openapi.PathItem(parameters=path_parameters, **operations)

        return openapi.Paths(paths=paths)

    def get_overrides(self, view, method):
        """Get overrides specified for a given operation.

        :param view: the view associated with the operation
        :param str method: HTTP method
        :return: a copy of any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
        :rtype: dict
        """
        import copy

        method = method.lower()
        action = getattr(view, 'action', method)
        action_method = getattr(view, action, None)
        overrides = getattr(action_method, 'swagger_auto_schema', {})
        if method in overrides:
            overrides = overrides[method]

        # deep copy so that schema classes cannot mutate the decorator's shared data between operations
        return copy.deepcopy(overrides)

    def get_path_parameters(self, path, view_cls):
        """Return a list of Parameter instances corresponding to any templated path variables.

        :param str path: templated request path
        :param type view_cls: the view class associated with the path
        :return: path parameters
        :rtype: list[openapi.Parameter]
        """
        parameters = []
        model = getattr(getattr(view_cls, 'queryset', None), 'model', None)
        lookup_field = getattr(view_cls, 'lookup_field', None)

        # sorted: the variables come back unordered; sorting keeps the output deterministic
        for variable in sorted(uritemplate.variables(path)):
            pattern = None
            param_type = openapi.TYPE_STRING
            description = None
            model_field = None
            if model is not None:
                # Attempt to infer a field description if possible.
                try:
                    model_field = model._meta.get_field(variable)
                except Exception:  # pragma: no cover
                    model_field = None

                if model_field is not None and model_field.help_text:
                    description = force_text(model_field.help_text)
                elif model_field is not None and model_field.primary_key:
                    description = get_pk_description(model, model_field)

            # the lookup regex applies even when no model could be determined for the view
            if hasattr(view_cls, 'lookup_value_regex') and lookup_field == variable:
                pattern = view_cls.lookup_value_regex
            elif isinstance(model_field, django.db.models.AutoField):
                param_type = openapi.TYPE_INTEGER

            field = openapi.Parameter(
                name=variable,
                required=True,
                in_=openapi.IN_PATH,
                type=param_type,
                pattern=pattern,
                description=description,
            )
            parameters.append(field)

        return parameters
class OpenAPISchemaGenerator(object):
    """Produce grouped path/summary listings for the registered API endpoints.

    URL enumeration and view instantiation are delegated to a rest-framework
    ``SchemaGenerator``. For each endpoint group, the result is a plain dict with a
    ``basePath``, its split ``basePathArray`` and a ``paths`` mapping of
    per-method ``name``/``description`` entries; no openapi objects are built.
    """
    endpoint_enumerator_class = EndpointEnumerator

    def __init__(self, name='API', version=''):
        """
        :param str name: API title passed to the wrapped ``SchemaGenerator``
        :param str version: API version string, stored as ``self.version``
        """
        # The wrapped generator is built with no url, description, patterns or urlconf.
        self._gen = SchemaGenerator(name, None, '', None, None)
        self.version = version
        self.consumes = []
        self.produces = []

    @property
    def url(self):
        """The ``url`` of the wrapped ``SchemaGenerator``."""
        return self._gen.url

    def create_view(self, callback, method, request=None):
        """Create a view instance from a view callback as registered in urlpatterns.

        :param callback: view callback registered in urlpatterns
        :param str method: HTTP method
        :param request: request to bind to the view
        :type request: rest_framework.request.Request or None
        :return: the view instance
        """
        view = self._gen.create_view(callback, method, request)
        overrides = getattr(callback, '_swagger_auto_schema', None)
        if overrides is not None:
            # decorated function based view must have its decorator information passed on to the re-instantiated view
            for method_name in overrides:
                view_method = getattr(view, method_name, None)
                if view_method is not None:  # pragma: no cover
                    setattr(view_method.__func__, '_swagger_auto_schema', overrides)

        setattr(view, 'swagger_fake_view', True)
        return view

    def get_schema(self, request=None, public=False):
        """Enumerate all endpoints and return their grouped path listings.

        :param request: request used when enumerating and instantiating views
        :param bool public: forwarded to :meth:`get_paths`
        :return: the value of :meth:`get_paths` (a list of groups, or None if there are none)
        """
        endpoints = self.get_endpoints(request)
        return self.get_paths(endpoints, None, request, public=public)

    def get_endpoints(self, request):
        """Enumerate endpoints group by group and instantiate a view per (path, method).

        ``get_api_endpoints(final_arrays=...)`` is expected to append groups shaped like
        ``{'prefix': ..., 'points': [(path, method, callback), ...]}`` to the list it is given.

        :param request: request bound to each created view
        :return: list of ``{"prefix": ..., "points": {path: (view_cls, [(method, view), ...])}}``
        :rtype: list[dict]
        """
        enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
        groups = []
        enumerator.get_api_endpoints(final_arrays=groups)

        ret = []
        for group in groups:
            view_paths = defaultdict(list)
            view_cls = {}
            for path, method, callback in group['points']:
                view = self.create_view(callback, method, request)
                view_paths[path].append((method, view))
                view_cls[path] = callback.cls

            # Each group is returned as a plain JSON-like dict.
            ret.append({
                "prefix": group['prefix'],
                "points": {
                    path: (view_cls[path], methods)
                    for path, methods in view_paths.items()
                },
            })
        return ret

    def determine_path_prefix(self, paths):
        """
        Given a list of all paths, return the common prefix which should be
        discounted when generating a schema structure.

        This will be the longest common string that does not include that last
        component of the URL, or the last component before a path parameter.

        For example: ::

            /api/v1/users/
            /api/v1/users/{pk}/

        The path prefix is ``/api/v1/``.

        :param list[str] paths: list of paths
        :rtype: str
        """
        return self._gen.determine_path_prefix(paths)

    def _get_description_section(self, view, header, description):
        """Return the part of ``description`` introduced by a ``header:`` line.

        Lines matching ``header_regex`` start a new section named by the text before
        the first ``:``. Lookup order: ``header`` itself, then its entry in
        ``SCHEMA_COERCE_METHOD_NAMES``, then the untitled leading text.

        :param view: unused; kept for interface compatibility
        :param str header: section name to look up (e.g. an HTTP method or action)
        :param str description: full docstring text
        :rtype: str
        """
        current_section = ''
        sections = {'': ''}
        for line in description.splitlines():
            if header_regex.match(line):
                current_section, _, lead = line.partition(':')
                sections[current_section] = lead.strip()
            else:
                sections[current_section] += '\n' + line

        # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
        coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
        if header in sections:
            return sections[header].strip()
        if header in coerce_method_names:
            if coerce_method_names[header] in sections:
                return sections[coerce_method_names[header]].strip()
        return sections[''].strip()

    def split_summary_from_description(self, description):
        """Decide if and how to split a summary out of the given description.

        The default implementation uses the first line of the description as a summary
        if the description has more than one line and that first line, stripped, is
        shorter than 120 characters.

        :param description: the full description to be analyzed
        :return: summary and description
        :rtype: (str,str)
        """
        # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
        summary = None
        summary_max_len = 120  # OpenAPI 2.0 spec says summary should be under 120 characters
        sections = description.split('\n', 1)
        if len(sections) == 2:
            sections[0] = sections[0].strip()
            if len(sections[0]) < summary_max_len:
                summary, description = sections
                description = description.strip()

        return summary, description

    def get_paths(self, top, components, request, public):
        """Build per-group path listings with a name/description for each operation.

        :param list top: endpoint groups as returned by :meth:`get_endpoints`
        :param components: unused
        :param request: unused
        :param bool public: currently unused (no endpoint filtering is applied)
        :return: one ``{"basePath", "basePathArray", "paths"}`` dict per non-empty group,
            or None when ``top`` is empty
        """
        if not top:
            return None

        ret = []
        for category in top:
            endpoints = category['points']
            if not endpoints:
                continue

            prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
            assert '{' not in prefix, "base path cannot be templated in swagger 2.0"

            paths = {}
            for path, (view_cls, methods) in sorted(endpoints.items()):
                operations = {}
                for method, view in methods:
                    method_lower = method.lower()
                    handler_name = getattr(view, 'action', method_lower)
                    handler = getattr(view, handler_name, None)
                    # A missing handler counts as undocumented; don't read ``__doc__`` off None.
                    method_docstring = handler.__doc__ if handler is not None else None
                    if method_docstring:
                        method_docstring = self._get_description_section(
                            view, method_lower, formatting.dedent(smart_text(method_docstring)))
                    else:
                        method_docstring = self._get_description_section(
                            view, handler_name, view.get_view_description())
                    if not method_docstring:
                        # Placeholder so name/description are never empty.
                        method_docstring = "_"

                    name, description = self.split_summary_from_description(method_docstring)
                    if name is None:
                        name = description
                    operations[method_lower] = {"name": name, "description": description}

                # Paths are keyed relative to the group's base path, always with a leading '/'.
                path_suffix = path[len(prefix):]
                if not path_suffix.startswith('/'):
                    path_suffix = '/' + path_suffix
                paths[path_suffix] = operations

            ret.append({
                "basePath": prefix,
                "basePathArray": prefix.strip("/").split("/"),
                "paths": paths,
            })
        return ret