예제 #1
0
def test_model_name_maps():
    """Check the names get_model_name_map assigns to a mix of models.

    ``Foo``/``Bar``/``Baz`` are expected under their bare class names. The
    submodule imports are all classes called ``Model``. The ones listed in the
    expected map are keyed by their dotted module path joined with ``__``.
    """
    create_testing_submodules()
    from pydantic_schema_test.modulea.modela import Model as ModelA
    from pydantic_schema_test.moduleb.modelb import Model as ModelB
    from pydantic_schema_test.modulec.modelc import Model as ModelC
    from pydantic_schema_test.moduled.modeld import Model as ModelD

    class Foo(BaseModel):
        a: str

    class Bar(BaseModel):
        b: Foo

    class Baz(BaseModel):
        c: Bar

    # Only Baz is passed from the local trio; Bar and Foo have to be
    # discovered through its field types.
    roots = [Baz, ModelA, ModelB, ModelC, ModelD]
    name_map = get_model_name_map(get_flat_models_from_models(roots))

    short_names = {Foo: 'Foo', Bar: 'Bar', Baz: 'Baz'}
    # NOTE(review): ModelD is imported and passed in but has no entry of its
    # own here; presumably create_testing_submodules makes moduled re-export
    # one of the other Model classes — confirm there.
    qualified_names = {
        ModelA: 'pydantic_schema_test__modulea__modela__Model',
        ModelB: 'pydantic_schema_test__moduleb__modelb__Model',
        ModelC: 'pydantic_schema_test__modulec__modelc__Model',
    }
    assert name_map == {**short_names, **qualified_names}
예제 #2
0
def get_model_mapper(model, stoppage=None, full=True):
    """Get a dictionary of name: class for all the objects in model.

    Args:
        model: Passed to ``get_model``. Every model reachable from the result
            (per ``get_flat_models_from_model``) is included.
        stoppage: Collection of class names at which each MRO walk stops. The
            first class whose ``__name__`` is in it, and everything after it
            in the MRO, is skipped. ``None`` selects the default
            ``{'NoExtraBaseModel', 'ModelMetaclass', 'BaseModel', 'object'}``.
            An explicitly empty collection now means "no stop names" instead
            of silently falling back to the default. Only used when ``full``
            is true.
        full: When true, also add the base classes found in each model's MRO,
            up to the first stop name.

    Returns:
        dict mapping name -> class. Names for the referenced models come from
        ``get_model_name_map``. Base classes added by the MRO walk are keyed
        by their plain ``__name__`` and never overwrite an existing key.
    """
    root = get_model(model)
    flat_models = get_flat_models_from_model(root)

    # get_model_name_map returns {class: name}; flip it so each class can be
    # looked up by its name.
    name_to_class = {
        name: cls for cls, name in get_model_name_map(flat_models).items()
    }

    if full:
        if stoppage is None:
            stoppage = {'NoExtraBaseModel', 'ModelMetaclass', 'BaseModel', 'object'}
        # Pydantic does not necessarily add all the base classes to the OpenAPI
        # documentation. Check all of them and add any that are not already
        # present. Iterate over a snapshot because the dict grows in the loop.
        for flat_model in list(name_to_class.values()):
            for base in type.mro(flat_model):
                if base.__name__ in stoppage:
                    break
                name_to_class.setdefault(base.__name__, base)

    return name_to_class
예제 #3
0
파일: utils.py 프로젝트: xaviml/fastapi
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build the OpenAPI document for ``routes`` as JSON-compatible data.

    ``description``, ``terms_of_service``, ``contact`` and ``license_info`` go
    into the ``info`` object only when truthy. ``servers`` and ``tags`` are
    likewise only emitted when truthy. Only ``routing.APIRoute`` instances
    contribute paths.

    Returns:
        ``jsonable_encoder`` output for an ``OpenAPI`` model built from the
        document, with aliases applied and ``None`` values dropped.
    """
    # (OpenAPI key, value) pairs that are copied into ``info`` when truthy.
    optional_info = (
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    )
    info: Dict[str, Any] = {"title": title, "version": version}
    info.update((key, value) for key, value in optional_info if value)

    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    # Shared across routes so get_openapi_path can track operation ids.
    operation_ids: Set[str] = set()

    flat_models = get_flat_models_from_routes(routes)
    model_name_map = get_model_name_map(flat_models)
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )

    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        result = get_openapi_path(
            route=route,
            model_name_map=model_name_map,
            operation_ids=operation_ids,
        )
        if not result:
            continue
        path, security_schemes, path_definitions = result
        if path:
            paths.setdefault(route.path_format, {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)

    if definitions:
        # Emit schemas in key order for a stable document.
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output),
                            by_alias=True,
                            exclude_none=True)  # type: ignore
예제 #4
0
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict:
    """Build the OpenAPI document for ``routes`` as JSON-compatible data.

    Operations registered for the same path are combined with
    ``conservative_merger.merge``. The ``geojson`` schemas are added to the
    component schemas.

    Returns:
        ``jsonable_encoder`` output for an ``OpenAPI`` model built from the
        document, with aliases applied and ``None`` values dropped.
    """
    info = {"title": title, "version": version}
    if description:
        info["description"] = description
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    flat_models = get_flat_models_from_routes(routes)
    # ignore mypy error until enum schemas are released
    model_name_map = get_model_name_map(flat_models)  # type: ignore
    # ignore mypy error until enum schemas are released
    definitions = get_model_definitions(
        flat_models=flat_models,
        model_name_map=model_name_map  # type: ignore
    )

    components: Dict[str, Dict] = {}
    paths: Dict[str, Dict] = {}
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        result = get_openapi_path(route=route, model_name_map=model_name_map)
        if not result:
            continue
        path, security_schemes, path_definitions = result
        if path:
            # Combine with whatever is already registered for this path via
            # conservative_merger instead of a plain dict.update.
            # NOTE(review): presumably this merges nested keys rather than
            # overwriting them — confirm the merger's strategy.
            path_key = route.path_format
            paths[path_key] = conservative_merger.merge(
                paths.get(path_key, {}), path)
        if security_schemes:
            components.setdefault("securitySchemes",
                                  {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)

    if definitions:
        # NOTE(review): geojson is only merged in when at least one other
        # definition exists — confirm that is intended.
        definitions.update(geojson)
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output),
                            by_alias=True,
                            exclude_none=True)
예제 #5
0
def get_openapi(*,
                title: str,
                version: str,
                openapi_version: str = "3.0.2",
                description: str = None,
                routes: Sequence[BaseRoute],
                openapi_prefix: str = "") -> Dict:
    """Build a raw OpenAPI dict for ``routes``.

    Compared with a plain schema build, this version:

    * passes every path and query parameter field of each ``APIRoute`` to
      ``replace_field_type`` before generating its path schema (intended to
      show ``ObjectId`` values as ``str``);
    * removes parameters named ``auth_token`` or ``authorization`` from every
      operation;
    * prefixes each path with ``openapi_prefix``.

    The dict is returned as built; it is not validated through ``OpenAPI``
    nor passed to ``jsonable_encoder``.
    """
    hidden_param_names = ['auth_token', 'authorization']

    info = {"title": title, "version": version}
    if description:
        info["description"] = description
    output = {"openapi": openapi_version, "info": info}

    components: Dict[str, Dict] = {}
    paths: Dict[str, Dict] = {}
    flat_models = get_flat_models_from_routes(routes)
    model_name_map = get_model_name_map(flat_models)
    definitions = get_model_definitions(flat_models=flat_models,
                                        model_name_map=model_name_map)

    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        # Show ObjectId values in path and query parameters as str.
        for field in chain(route.dependant.path_params,
                           route.dependant.query_params):
            replace_field_type(field)

        result = get_openapi_path(route=route, model_name_map=model_name_map)
        if not result:
            continue
        path, security_schemes, path_definitions = result
        if path:
            # Remove the auth_token / authorization header parameters from the
            # generated schema.
            for operation in path.values():
                if operation.get('parameters'):
                    operation['parameters'] = [
                        param for param in operation['parameters']
                        if param['name'] not in hidden_param_names
                    ]
            paths.setdefault(openapi_prefix + route.path_format,
                             {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes",
                                  {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)

    if definitions:
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    return output
예제 #6
0
def get_openapi(*,
                title: str,
                version: str,
                openapi_version: str = "3.0.2",
                description: str = None,
                routes: Sequence[BaseRoute],
                openapi_prefix: str = "") -> Dict:
    """Build the OpenAPI document for ``routes`` as JSON-compatible data.

    Model definitions are generated per ``APIRoute`` so that each route's
    ``response_model_by_alias`` setting is applied. Schema names come from a
    single name map built over all routes. Each path is prefixed with
    ``openapi_prefix``.

    Returns:
        ``jsonable_encoder`` output for an ``OpenAPI`` model built from the
        document, with aliases applied and ``None`` values dropped.
    """
    info = {"title": title, "version": version}
    if description:
        info["description"] = description
    output = {"openapi": openapi_version, "info": info}
    components: Dict[str, Dict] = {}
    paths: Dict[str, Dict] = {}
    definitions: Dict[str, Any] = {}

    # One name map for the whole API. A map built per route only sees that
    # route's models, so two different models sharing a class name (in
    # different routes) would both get the short name. Their definitions
    # would then overwrite each other and $refs would point at the wrong
    # schema.
    model_name_map = get_model_name_map(get_flat_models_from_routes(routes))

    for route in routes:
        if isinstance(route, routing.APIRoute):
            # Definitions are still generated per route to honour its
            # response_model_by_alias. NOTE(review): a model shared by routes
            # with different by_alias settings is still last-writer-wins.
            definitions.update(
                get_model_definitions(
                    by_alias=route.response_model_by_alias,
                    flat_models=get_flat_models_from_routes([route]),
                    model_name_map=model_name_map,
                ))
            result = get_openapi_path(route=route,
                                      model_name_map=model_name_map)
            if result:
                path, security_schemes, path_definitions = result
                if path:
                    paths.setdefault(openapi_prefix + route.path_format,
                                     {}).update(path)
                if security_schemes:
                    components.setdefault("securitySchemes",
                                          {}).update(security_schemes)
                if path_definitions:
                    definitions.update(path_definitions)
    if definitions:
        # Emit schemas in key order for a stable document.
        components["schemas"] = {
            k: definitions[k]
            for k in sorted(definitions)
        }
    if components:
        output["components"] = components
    output["paths"] = paths
    return jsonable_encoder(OpenAPI(**output),
                            by_alias=True,
                            exclude_none=True)
예제 #7
0
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence,
    tags: Optional[List[Dict[str, Any]]] = None,
) -> Dict:
    """Build the OpenAPI document for the ``ViewRoute`` entries of ``routes``.

    Only component schemas are emitted; security schemes are not handled yet.

    Returns:
        ``jsonable_encoder`` output for an ``OpenAPI`` model built from the
        document, with aliases applied and ``None`` values dropped.
    """
    info = {"title": title, "version": version}
    if description:
        info["description"] = description

    flat_models = get_flat_models_from_routes(routes)
    # ignore mypy error until enum schemas are released
    model_name_map = get_model_name_map(flat_models)  # type: ignore
    # ignore mypy error until enum schemas are released
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map  # type: ignore
    )

    paths: Dict[str, Dict] = {}
    # todo: security
    view_routes = (route for route in routes if isinstance(route, ViewRoute))
    for route in view_routes:
        result = get_openapi_path(route=route, model_name_map=model_name_map)
        if not result:
            continue
        path, path_definitions = result
        if path:
            paths.setdefault(route.path_format, {}).update(path)
        if path_definitions:
            definitions.update(path_definitions)

    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if definitions:
        # Schemas are the only component produced, sorted for stable output.
        output["components"] = {"schemas": dict(sorted(definitions.items()))}
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
예제 #8
0
    def get_models_name(self):
        """Return ``get_model_name_map`` for the models used by the app's handlers.

        For every route in ``self.app.routes`` whose endpoint passes
        ``self.is_pydantic_endpoint``, the handler methods named in
        ``OpenApiPath.operations`` are inspected. Every annotation (including
        ``return``) that is a ``BaseModel`` subclass is collected. For
        ``BaseForm`` subclasses, their ``model_cls`` is collected.
        """
        models = set()
        for route in self.app.routes:
            # Route objects are not guaranteed to expose ``endpoint``
            # (presumably e.g. mounted sub-applications — confirm against the
            # router in use); such routes have no handlers to scan.
            endpoint = getattr(route, 'endpoint', None)
            if endpoint is None or not self.is_pydantic_endpoint(endpoint):
                continue

            for method in OpenApiPath.operations:
                handler = getattr(endpoint, method, None)
                if handler is None:
                    continue

                for ann in handler.__annotations__.values():
                    # issubclass needs a class; skip typing constructs etc.
                    if not isinstance(ann, type):
                        continue
                    if issubclass(ann, BaseModel):
                        models.add(ann)
                    if issubclass(ann, BaseForm):
                        models.add(ann.model_cls)

        return get_model_name_map(models)