예제 #1
0
 def trace(
     self,
     path: str,
     *,
     response_model: Optional[Type[Any]] = None,
     status_code: Optional[int] = None,
     tags: Optional[List[Union[str, Enum]]] = None,
     dependencies: Optional[Sequence[Depends]] = None,
     summary: Optional[str] = None,
     description: Optional[str] = None,
     response_description: str = "Successful Response",
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     deprecated: Optional[bool] = None,
     operation_id: Optional[str] = None,
     response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_by_alias: bool = True,
     response_model_exclude_unset: bool = False,
     response_model_exclude_defaults: bool = False,
     response_model_exclude_none: bool = False,
     include_in_schema: bool = True,
     response_class: Type[Response] = Default(JSONResponse),
     name: Optional[str] = None,
     callbacks: Optional[List[BaseRoute]] = None,
     openapi_extra: Optional[Dict[str, Any]] = None,
 ) -> Callable[[DecoratedCallable], DecoratedCallable]:
     """Delegate to ``self.router.trace`` for ``path``.

     Thin wrapper: every option is passed through unchanged, by keyword, and
     whatever decorator ``self.router.trace`` returns is returned as-is.
     """
     # Options that describe the operation in the generated schema.
     schema_options = dict(
         tags=tags,
         summary=summary,
         description=description,
         response_description=response_description,
         responses=responses,
         deprecated=deprecated,
         operation_id=operation_id,
         include_in_schema=include_in_schema,
         callbacks=callbacks,
         openapi_extra=openapi_extra,
     )
     # Options that shape how the endpoint's return value is serialised.
     output_options = dict(
         response_model=response_model,
         status_code=status_code,
         response_model_include=response_model_include,
         response_model_exclude=response_model_exclude,
         response_model_by_alias=response_model_by_alias,
         response_model_exclude_unset=response_model_exclude_unset,
         response_model_exclude_defaults=response_model_exclude_defaults,
         response_model_exclude_none=response_model_exclude_none,
         response_class=response_class,
     )
     return self.router.trace(
         path,
         dependencies=dependencies,
         name=name,
         **schema_options,
         **output_options,
     )
예제 #2
0
 def patch(
     self,
     path: str,
     *,
     response_model: Optional[Type[Any]] = None,
     status_code: int = 200,
     tags: Optional[List[str]] = None,
     dependencies: Optional[Sequence[params.Depends]] = None,
     summary: Optional[str] = None,
     description: Optional[str] = None,
     response_description: str = "Successful Response",
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     deprecated: Optional[bool] = None,
     operation_id: Optional[str] = None,
     response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_by_alias: bool = True,
     response_model_exclude_unset: bool = False,
     response_model_exclude_defaults: bool = False,
     response_model_exclude_none: bool = False,
     include_in_schema: bool = True,
     response_class: Type[Response] = Default(JSONResponse),
     name: Optional[str] = None,
     callbacks: Optional[List[APIRoute]] = None,
 ) -> Callable:
     """Return the ``self.api_route`` decorator for ``path`` with method PATCH.

     Every option is forwarded unchanged, by keyword; the only value this
     wrapper supplies itself is ``methods=["PATCH"]``.
     """
     forwarded = dict(
         path=path,
         response_model=response_model,
         status_code=status_code,
         response_model_include=response_model_include,
         response_model_exclude=response_model_exclude,
         response_model_by_alias=response_model_by_alias,
         response_model_exclude_unset=response_model_exclude_unset,
         response_model_exclude_defaults=response_model_exclude_defaults,
         response_model_exclude_none=response_model_exclude_none,
         response_class=response_class,
         tags=tags,
         dependencies=dependencies,
         summary=summary,
         description=description,
         response_description=response_description,
         responses=responses,
         deprecated=deprecated,
         operation_id=operation_id,
         include_in_schema=include_in_schema,
         name=name,
         callbacks=callbacks,
     )
     return self.api_route(methods=["PATCH"], **forwarded)
예제 #3
0
def route(
    path: str = "",
    *,
    response_model: Optional[Type[Any]] = None,
    status_code: int = 200,
    tags: Optional[List[str]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    summary: Optional[str] = None,
    description: Optional[str] = None,
    response_description: str = "Successful Response",
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    deprecated: Optional[bool] = None,
    methods: Optional[List[str]] = None,
    operation_id: Optional[str] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    include_in_schema: bool = True,
    response_class: Type[Response] = Default(JSONResponse),
    name: Optional[str] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    response_model_infer: bool = True,
) -> Callable[[DecoratedMember], DecoratedMember]:
    """Return a decorator that registers a routine as an API route entry.

    All arguments (``path`` included, ``response_model_infer`` excluded) are
    stored in an ``APIRouteEntry`` that is registered for the decorated member
    through ``RouteEntryManager.add``. The member itself is returned unchanged.

    Args:
        response_model_infer: When true and ``response_model`` is ``None``,
            the member's ``return`` type hint is used as the response model,
            unless that hint is a ``Response`` subclass.

    Returns:
        The decorator. It may be applied to several members; each one gets
        its own independent set of arguments.

    Raises:
        TypeError: (raised by the decorator) if the decorated object is not a
            routine.
    """
    # Snapshot of the parameters only; taken before any other local exists.
    route_args = dict(locals())

    def decorator(member: MemberType) -> MemberType:
        if not inspect.isroutine(member):
            raise TypeError("Decorator should be applied to routine")
        # Per-member copy: mutating the shared snapshot made a reused
        # decorator fail with KeyError on ``pop`` for the second member and
        # leaked one member's inferred response_model into the next.
        entry_args = dict(route_args)
        infer = entry_args.pop("response_model_infer")
        if infer and entry_args["response_model"] is None:
            # infer response_model from function return type hint
            return_type = get_type_hints(desc_unwrap(member)).get("return")
            returns_response_class = isinstance(return_type, type) and issubclass(
                return_type, Response
            )
            # A Response subclass is a response class, not a model: skip it.
            if not returns_response_class:
                entry_args["response_model"] = return_type
        RouteEntryManager.add(member, APIRouteEntry(entry_args))
        return member

    return decorator
예제 #4
0
 def __init__(
     self,
     *,
     prefix: str = "",
     tags: Optional[List[str]] = None,
     dependencies: Optional[Sequence[params.Depends]] = None,
     default_response_class: Type[Response] = Default(JSONResponse),
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     callbacks: Optional[List[APIRoute]] = None,
     routes: Optional[List[routing.BaseRoute]] = None,
     redirect_slashes: bool = True,
     default: Optional[ASGIApp] = None,
     dependency_overrides_provider: Optional[Any] = None,
     route_class: Type[APIRoute] = APIRoute,
     on_startup: Optional[Sequence[Callable]] = None,
     on_shutdown: Optional[Sequence[Callable]] = None,
     deprecated: Optional[bool] = None,
     include_in_schema: bool = True,
 ) -> None:
     """Initialise the router and store its router-wide defaults.

     ``routes``, ``redirect_slashes``, ``default``, ``on_startup`` and
     ``on_shutdown`` go to the base class. The remaining options are kept as
     attributes; ``None`` collections are normalised to empty ones.

     Raises:
         AssertionError: ``prefix`` is non-empty and does not start with
             ``"/"`` or ends with ``"/"``.
     """
     super().__init__(
         routes=routes,
         redirect_slashes=redirect_slashes,
         default=default,
         on_startup=on_startup,
         on_shutdown=on_shutdown,
     )
     if prefix:
         assert prefix.startswith("/"), "A path prefix must start with '/'"
         assert not prefix.endswith(
             "/"
         ), "A path prefix must not end with '/', as the routes will start with '/'"
     self.prefix = prefix
     self.tags: List[str] = tags or []
     # Always a fresh list so later mutation can't touch the caller's sequence.
     self.dependencies = list(dependencies or [])
     self.deprecated = deprecated
     self.include_in_schema = include_in_schema
     self.responses = responses or {}
     self.callbacks = callbacks or []
     self.dependency_overrides_provider = dependency_overrides_provider
     self.route_class = route_class
     self.default_response_class = default_response_class
예제 #5
0
 def include_router(
     self,
     router: routing.APIRouter,
     *,
     prefix: str = "",
     tags: Optional[List[str]] = None,
     dependencies: Optional[Sequence[Depends]] = None,
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     deprecated: Optional[bool] = None,
     include_in_schema: bool = True,
     default_response_class: Type[Response] = Default(JSONResponse),
     callbacks: Optional[List[routing.APIRoute]] = None,
 ) -> None:
     """Include ``router``'s routes in this object's ``self.router``.

     Thin delegate to ``self.router.include_router``; every option is
     forwarded unchanged, by keyword. ``deprecated`` defaults to ``None``,
     so it is annotated ``Optional[bool]``.
     """
     self.router.include_router(
         router,
         prefix=prefix,
         tags=tags,
         dependencies=dependencies,
         responses=responses,
         deprecated=deprecated,
         include_in_schema=include_in_schema,
         default_response_class=default_response_class,
         callbacks=callbacks,
     )
예제 #6
0
def api(
    prefix: str = "",
    *,
    tags: Optional[List[str]] = None,
    dependencies: Optional[Sequence[params.Depends]] = None,
    default_response_class: Type[Response] = Default(JSONResponse),
    responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
    callbacks: Optional[List[BaseRoute]] = None,
    routes: Optional[List[routing.BaseRoute]] = None,
    redirect_slashes: bool = True,
    default: Optional[ASGIApp] = None,
    dependency_overrides_provider: Optional[Any] = None,
    route_class: Type[APIRoute] = APIRoute,
    on_startup: Optional[Sequence[Callable[[], Any]]] = None,
    on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
    deprecated: Optional[bool] = None,
    include_in_schema: bool = True,
):
    """Placeholder API decorator factory; its body does no real work.

    The returned decorator always raises
    ``AssertionError("api decorator not wrapped")`` when applied. Judging by
    that message, this function is meant to be wrapped or replaced elsewhere,
    with the signature above documenting the accepted options — TODO confirm
    against the wrapping code.
    """
    # Reference every parameter so linters do not flag them as unused.
    _ = locals()
    failure_message = "api decorator not wrapped"

    def decorator(target):
        raise AssertionError(failure_message)

    return decorator
예제 #7
0
 def include_router(
     self,
     router: "APIRouter",
     *,
     prefix: str = "",
     tags: Optional[List[str]] = None,
     dependencies: Optional[Sequence[params.Depends]] = None,
     default_response_class: Type[Response] = Default(JSONResponse),
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     callbacks: Optional[List[APIRoute]] = None,
     deprecated: Optional[bool] = None,
     include_in_schema: bool = True,
 ) -> None:
     """Copy the routes of ``router`` into this router under ``prefix``.

     ``APIRoute`` entries are re-added through ``add_api_route``. The
     include-level ``tags``, ``dependencies`` and ``callbacks`` come before the
     route's own values. The route's ``responses`` override the include-level
     ones on key clashes. Plain ``routing.Route``, ``APIWebSocketRoute`` and
     ``routing.WebSocketRoute`` entries are re-added with the prefixed path;
     other route types are not copied. The included router's startup and
     shutdown handlers are registered on this router as well.

     Raises:
         AssertionError: ``prefix`` is non-empty and does not start with
             ``"/"`` or ends with ``"/"``.
         Exception: ``prefix`` is empty and an included route has an empty
             ``path``.
     """
     if prefix:
         assert prefix.startswith("/"), "A path prefix must start with '/'"
         assert not prefix.endswith(
             "/"
         ), "A path prefix must not end with '/', as the routes will start with '/'"
     else:
         for r in router.routes:
             # Routes without a ``path`` attribute cannot have an empty path;
             # skip them rather than failing with AttributeError.
             path = getattr(r, "path", None)
             name = getattr(r, "name", "unknown")
             if path is not None and not path:
                 raise Exception(
                     f"Prefix and path cannot be both empty (path operation: {name})"
                 )
     if responses is None:
         responses = {}
     for route in router.routes:
         if isinstance(route, APIRoute):
             # Include-level responses first so the route's own entries win.
             combined_responses = {**responses, **route.responses}
             # First non-default value wins: route, included router, this
             # call, then this router.
             use_response_class = get_value_or_default(
                 route.response_class,
                 router.default_response_class,
                 default_response_class,
                 self.default_response_class,
             )
             current_tags = []
             if tags:
                 current_tags.extend(tags)
             if route.tags:
                 current_tags.extend(route.tags)
             current_dependencies: List[params.Depends] = []
             if dependencies:
                 current_dependencies.extend(dependencies)
             if route.dependencies:
                 current_dependencies.extend(route.dependencies)
             current_callbacks = []
             if callbacks:
                 current_callbacks.extend(callbacks)
             if route.callbacks:
                 current_callbacks.extend(route.callbacks)
             self.add_api_route(
                 prefix + route.path,
                 route.endpoint,
                 response_model=route.response_model,
                 status_code=route.status_code,
                 tags=current_tags,
                 dependencies=current_dependencies,
                 summary=route.summary,
                 description=route.description,
                 response_description=route.response_description,
                 responses=combined_responses,
                 deprecated=route.deprecated or deprecated or self.deprecated,
                 methods=route.methods,
                 operation_id=route.operation_id,
                 response_model_include=route.response_model_include,
                 response_model_exclude=route.response_model_exclude,
                 response_model_by_alias=route.response_model_by_alias,
                 response_model_exclude_unset=route.response_model_exclude_unset,
                 response_model_exclude_defaults=route.response_model_exclude_defaults,
                 response_model_exclude_none=route.response_model_exclude_none,
                 include_in_schema=route.include_in_schema
                 and self.include_in_schema
                 and include_in_schema,
                 response_class=use_response_class,
                 name=route.name,
                 # Keep the included route's concrete class.
                 route_class_override=type(route),
                 callbacks=current_callbacks,
             )
         elif isinstance(route, routing.Route):
             self.add_route(
                 prefix + route.path,
                 route.endpoint,
                 methods=list(route.methods or []),
                 include_in_schema=route.include_in_schema,
                 name=route.name,
             )
         elif isinstance(route, APIWebSocketRoute):
             self.add_api_websocket_route(
                 prefix + route.path, route.endpoint, name=route.name
             )
         elif isinstance(route, routing.WebSocketRoute):
             self.add_websocket_route(
                 prefix + route.path, route.endpoint, name=route.name
             )
     for handler in router.on_startup:
         self.add_event_handler("startup", handler)
     for handler in router.on_shutdown:
         self.add_event_handler("shutdown", handler)
예제 #8
0
 def add_api_route(
     self,
     path: str,
     endpoint: Callable,
     *,
     response_model: Optional[Type[Any]] = None,
     status_code: int = 200,
     tags: Optional[List[str]] = None,
     dependencies: Optional[Sequence[params.Depends]] = None,
     summary: Optional[str] = None,
     description: Optional[str] = None,
     response_description: str = "Successful Response",
     responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
     deprecated: Optional[bool] = None,
     methods: Optional[Union[Set[str], List[str]]] = None,
     operation_id: Optional[str] = None,
     response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
     response_model_by_alias: bool = True,
     response_model_exclude_unset: bool = False,
     response_model_exclude_defaults: bool = False,
     response_model_exclude_none: bool = False,
     include_in_schema: bool = True,
     response_class: Union[Type[Response], DefaultPlaceholder] = Default(
         JSONResponse
     ),
     name: Optional[str] = None,
     route_class_override: Optional[Type[APIRoute]] = None,
     callbacks: Optional[List[APIRoute]] = None,
 ) -> None:
     """Build a route with this router's defaults merged in and append it.

     The route class is ``route_class_override`` if given, else
     ``self.route_class``. The path is prefixed with ``self.prefix``. This
     router's tags, dependencies and callbacks come before the per-route
     ones. Per-route ``responses`` override this router's on key clashes.
     ``deprecated`` is OR-ed and ``include_in_schema`` AND-ed with the
     router-level values.
     """
     chosen_route_class = route_class_override or self.route_class
     merged_responses = {**self.responses, **(responses or {})}
     resolved_response_class = get_value_or_default(
         response_class, self.default_response_class
     )
     # Fresh lists: the router-level lists must never be mutated here.
     merged_tags = [*self.tags, *(tags or [])]
     merged_dependencies = [*self.dependencies, *(dependencies or [])]
     merged_callbacks = [*self.callbacks, *(callbacks or [])]
     new_route = chosen_route_class(
         self.prefix + path,
         endpoint=endpoint,
         methods=methods,
         name=name,
         operation_id=operation_id,
         response_model=response_model,
         status_code=status_code,
         response_class=resolved_response_class,
         response_model_include=response_model_include,
         response_model_exclude=response_model_exclude,
         response_model_by_alias=response_model_by_alias,
         response_model_exclude_unset=response_model_exclude_unset,
         response_model_exclude_defaults=response_model_exclude_defaults,
         response_model_exclude_none=response_model_exclude_none,
         tags=merged_tags,
         dependencies=merged_dependencies,
         callbacks=merged_callbacks,
         summary=summary,
         description=description,
         response_description=response_description,
         responses=merged_responses,
         deprecated=deprecated or self.deprecated,
         include_in_schema=include_in_schema and self.include_in_schema,
         dependency_overrides_provider=self.dependency_overrides_provider,
     )
     self.routes.append(new_route)
예제 #9
0
    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Optional[Type[Any]] = None,
        status_code: int = 200,
        tags: Optional[List[str]] = None,
        dependencies: Optional[Sequence[params.Depends]] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        response_description: str = "Successful Response",
        responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
        deprecated: Optional[bool] = None,
        name: Optional[str] = None,
        methods: Optional[Union[Set[str], List[str]]] = None,
        operation_id: Optional[str] = None,
        response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        response_model_exclude_defaults: bool = False,
        response_model_exclude_none: bool = False,
        include_in_schema: bool = True,
        response_class: Union[Type[Response], DefaultPlaceholder] = Default(
            JSONResponse
        ),
        dependency_overrides_provider: Optional[Any] = None,
        callbacks: Optional[List["APIRoute"]] = None,
    ) -> None:
        """Build an API route for ``endpoint`` at ``path``.

        Compiles the path, derives the route's unique id, and prepares the
        response fields (main and per-status-code), the dependency tree and
        the body field. It finally builds ``self.app`` from
        ``self.get_route_handler()``. ``methods`` defaults to ``["GET"]`` and is
        stored upper-cased in ``self.methods``.

        Raises:
            AssertionError: a ``response_model`` is combined with a status code
                in ``STATUS_CODES_WITH_NO_BODY``; an entry of ``responses`` is
                not a dict or attaches a model to a no-body status code; or
                ``endpoint`` is not callable.
        """
        # normalise enums e.g. http.HTTPStatus
        if isinstance(status_code, enum.IntEnum):
            status_code = int(status_code)
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.path_regex, self.path_format, self.param_convertors = compile_path(path)
        if methods is None:
            methods = ["GET"]
        self.methods = {method.upper() for method in methods}
        # The unique id is derived from one "first" method. A set has no stable
        # order (str hashing is randomised per process), so ``list(set)[0]``
        # could change ids and schema names between runs; use the sorted-first
        # method for sets. Sequences keep their caller-given first element.
        if isinstance(methods, (set, frozenset)):
            id_method = sorted(methods)[0]
        else:
            id_method = list(methods)[0]
        self.unique_id = generate_operation_id_for_path(
            name=self.name, path=self.path_format, method=id_method
        )
        self.response_model = response_model
        if self.response_model:
            assert (
                status_code not in STATUS_CODES_WITH_NO_BODY
            ), f"Status code {status_code} must not have a response body"
            response_name = "Response_" + self.unique_id
            self.response_field = create_response_field(
                name=response_name, type_=self.response_model
            )
            # Create a clone of the field, so that a Pydantic submodel is not returned
            # as is just because it's an instance of a subclass of a more limited class
            # e.g. UserInDB (containing hashed_password) could be a subclass of User
            # that doesn't have the hashed_password. But because it's a subclass, it
            # would pass the validation and be returned as is.
            # By being a new field, no inheritance will be passed as is. A new model
            # will be always created.
            self.secure_cloned_response_field: Optional[
                ModelField
            ] = create_cloned_field(self.response_field)
        else:
            self.response_field = None  # type: ignore
            self.secure_cloned_response_field = None
        self.status_code = status_code
        self.tags = tags or []
        if dependencies:
            self.dependencies = list(dependencies)
        else:
            self.dependencies = []
        self.summary = summary
        self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
        # if a "form feed" character (page break) is found in the description text,
        # truncate description text to the content preceding the first "form feed"
        self.description = self.description.split("\f")[0]
        self.response_description = response_description
        self.responses = responses or {}
        # One response field per additional status code that declares a "model".
        response_fields = {}
        for additional_status_code, response in self.responses.items():
            assert isinstance(response, dict), "An additional response must be a dict"
            model = response.get("model")
            if model:
                assert (
                    additional_status_code not in STATUS_CODES_WITH_NO_BODY
                ), f"Status code {additional_status_code} must not have a response body"
                response_name = f"Response_{additional_status_code}_{self.unique_id}"
                response_field = create_response_field(name=response_name, type_=model)
                response_fields[additional_status_code] = response_field
        self.response_fields: Dict[Union[int, str], ModelField] = response_fields
        self.deprecated = deprecated
        self.operation_id = operation_id
        self.response_model_include = response_model_include
        self.response_model_exclude = response_model_exclude
        self.response_model_by_alias = response_model_by_alias
        self.response_model_exclude_unset = response_model_exclude_unset
        self.response_model_exclude_defaults = response_model_exclude_defaults
        self.response_model_exclude_none = response_model_exclude_none
        self.include_in_schema = include_in_schema
        self.response_class = response_class

        assert callable(endpoint), "An endpoint must be a callable"
        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
        # Route-level dependencies run before the endpoint's own; inserting in
        # reverse at index 0 preserves their declared order.
        for depends in self.dependencies[::-1]:
            self.dependant.dependencies.insert(
                0,
                get_parameterless_sub_dependant(depends=depends, path=self.path_format),
            )
        self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
        self.dependency_overrides_provider = dependency_overrides_provider
        self.callbacks = callbacks
        self.app = request_response(self.get_route_handler())
예제 #10
0
def get_request_handler(
    dependant: Dependant,
    body_field: Optional[ModelField] = None,
    status_code: int = 200,
    response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
    response_field: Optional[ModelField] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    dependency_overrides_provider: Optional[Any] = None,
) -> Callable:
    """Build the per-request handler coroutine for one path operation.

    The returned ``app(request)`` coroutine does the following:

    * reads the body (form data when ``body_field`` is a ``params.Form``
      field, otherwise JSON when any bytes were sent);
    * solves the dependencies and calls the endpoint;
    * returns an endpoint-built ``Response`` directly, or serialises the
      result into ``response_class``.

    Raises (from the handler):
        RequestValidationError: the body is not valid JSON, or dependency
            solving reported errors.
        HTTPException: any other failure while reading the body (status 400).
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.field_info, params.Form)
    # Resolve the ``Default(...)`` placeholder to a concrete class once.
    actual_response_class: Type[Response] = (
        response_class.value
        if isinstance(response_class, DefaultPlaceholder)
        else response_class
    )

    async def app(request: Request) -> Response:
        body = None
        try:
            if body_field:
                if is_body_form:
                    body = await request.form()
                elif await request.body():
                    # Only attempt JSON decoding when bytes were actually sent.
                    body = await request.json()
        except json.JSONDecodeError as e:
            raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e

        values, errors, background_tasks, sub_response, _ = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        if errors:
            raise RequestValidationError(errors, body=body)

        raw_response = await run_endpoint_function(
            dependant=dependant, values=values, is_coroutine=is_coroutine
        )
        # The endpoint built its own Response: pass it through, attaching the
        # dependency background tasks only if it has none of its own.
        if isinstance(raw_response, Response):
            if raw_response.background is None:
                raw_response.background = background_tasks
            return raw_response

        content = await serialize_response(
            field=response_field,
            response_content=raw_response,
            include=response_model_include,
            exclude=response_model_exclude,
            by_alias=response_model_by_alias,
            exclude_unset=response_model_exclude_unset,
            exclude_defaults=response_model_exclude_defaults,
            exclude_none=response_model_exclude_none,
            is_coroutine=is_coroutine,
        )
        response = actual_response_class(
            content=content,
            status_code=status_code,
            background=background_tasks,
        )
        # Headers/status set on the injected sub-response carry over.
        response.headers.raw.extend(sub_response.headers.raw)
        if sub_response.status_code:
            response.status_code = sub_response.status_code
        return response

    return app
예제 #11
0
    def __init__(
        self,
        *,
        debug: bool = False,
        routes: Optional[List[BaseRoute]] = None,
        title: str = "FastAPI",
        description: str = "",
        version: str = "0.1.0",
        openapi_url: Optional[str] = "/openapi.json",
        openapi_tags: Optional[List[Dict[str, Any]]] = None,
        servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
        dependencies: Optional[Sequence[Depends]] = None,
        default_response_class: Type[Response] = Default(JSONResponse),
        docs_url: Optional[str] = "/docs",
        redoc_url: Optional[str] = "/redoc",
        swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
        swagger_ui_init_oauth: Optional[dict] = None,
        middleware: Optional[Sequence[Middleware]] = None,
        exception_handlers: Optional[
            Dict[Union[int, Type[Exception]], Callable]
        ] = None,
        on_startup: Optional[Sequence[Callable]] = None,
        on_shutdown: Optional[Sequence[Callable]] = None,
        openapi_prefix: str = "",
        root_path: str = "",
        root_path_in_servers: bool = True,
        responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
        callbacks: Optional[List[routing.APIRoute]] = None,
        deprecated: Optional[bool] = None,
        include_in_schema: bool = True,
        **extra: Any,
    ) -> None:
        """Create the application.

        Routing options (routes, dependencies, callbacks, responses,
        deprecated, include_in_schema, startup/shutdown handlers and the
        default response class) go to an internal ``routing.APIRouter``.
        Default handlers for ``HTTPException`` and ``RequestValidationError``
        are added unless the caller supplied their own. The middleware stack is
        built, OpenAPI/docs settings are stored, and ``self.setup()`` runs last.
        ``openapi_prefix`` is a deprecated alias for ``root_path``; it logs a
        warning and is used only when ``root_path`` is empty. Unknown keyword
        arguments are kept in ``self.extra``.

        Raises:
            AssertionError: ``openapi_url`` is set but ``title`` or
                ``version`` is empty.
        """
        self._debug = debug
        self.state = State()
        self.router: routing.APIRouter = routing.APIRouter(
            routes=routes,
            dependency_overrides_provider=self,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
            default_response_class=default_response_class,
            dependencies=dependencies,
            callbacks=callbacks,
            deprecated=deprecated,
            include_in_schema=include_in_schema,
            responses=responses,
        )
        # Copy so the caller's mapping is never mutated by setdefault below.
        self.exception_handlers = (
            {} if exception_handlers is None else dict(exception_handlers)
        )
        self.exception_handlers.setdefault(HTTPException, http_exception_handler)
        self.exception_handlers.setdefault(
            RequestValidationError, request_validation_exception_handler
        )

        self.user_middleware = [] if middleware is None else list(middleware)
        self.middleware_stack = self.build_middleware_stack()

        self.title = title
        self.description = description
        self.version = version
        self.servers = servers or []
        self.openapi_url = openapi_url
        self.openapi_tags = openapi_tags
        # TODO: remove when discarding the openapi_prefix parameter
        if openapi_prefix:
            logger.warning(
                '"openapi_prefix" has been deprecated in favor of "root_path", which '
                "follows more closely the ASGI standard, is simpler, and more "
                "automatic. Check the docs at "
                "https://fastapi.tiangolo.com/advanced/sub-applications/"
            )
        self.root_path = root_path or openapi_prefix
        self.root_path_in_servers = root_path_in_servers
        self.docs_url = docs_url
        self.redoc_url = redoc_url
        self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
        self.swagger_ui_init_oauth = swagger_ui_init_oauth
        self.extra = extra
        self.dependency_overrides: Dict[Callable, Callable] = {}

        self.openapi_version = "3.0.2"

        if self.openapi_url:
            assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
            assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
        # Generated lazily; None until first requested.
        self.openapi_schema: Optional[Dict[str, Any]] = None
        self.setup()
예제 #12
0
def get_request_handler(
    dependant: Dependant,
    body_field: Optional[ModelField] = None,
    status_code: Optional[int] = None,
    response_class: Union[Type[Response],
                          DefaultPlaceholder] = Default(JSONResponse),
    response_field: Optional[ModelField] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    dependency_overrides_provider: Optional[Any] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """Build the async handler that serves one API route.

    The returned coroutine function reads and decodes the request body,
    resolves the route's dependencies, calls the endpoint and turns its
    return value into a ``Response``.

    Body decoding: when ``body_field`` is a ``params.Form`` the body is read
    with ``request.form()``. Otherwise a non-empty body is parsed as JSON if
    there is no ``Content-Type`` header or its media type is
    ``application/json`` / ``application/*+json``; any other content type is
    passed on as the raw bytes.

    Args:
        dependant: Dependency tree of the endpoint; ``dependant.call`` is the
            endpoint callable and must not be ``None`` (asserted).
        body_field: Field describing the request body, if the route has one.
        status_code: Status code for the generated response; when ``None``
            the response class's own default is used.
        response_class: Class used to build the response; a
            ``DefaultPlaceholder`` is unwrapped to its ``value``.
        response_field: Field passed to ``serialize_response`` for the
            endpoint's return value.
        response_model_include, response_model_exclude,
        response_model_by_alias, response_model_exclude_unset,
        response_model_exclude_defaults, response_model_exclude_none:
            Forwarded to ``serialize_response``.
        dependency_overrides_provider: Forwarded to ``solve_dependencies``.

    Returns:
        An async callable that takes a ``Request`` and returns a ``Response``.

    The returned handler raises ``RequestValidationError`` for malformed JSON
    or failed dependency validation. It raises ``HTTPException`` (400) when
    the body cannot be read for any other reason. HTTP errors raised while
    reading the body are propagated unchanged.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.field_info,
                                             params.Form)
    if isinstance(response_class, DefaultPlaceholder):
        actual_response_class: Type[Response] = response_class.value
    else:
        actual_response_class = response_class

    async def app(request: Request) -> Response:
        try:
            body: Any = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    if body_bytes:
                        json_body: Any = Undefined
                        content_type_value = request.headers.get(
                            "content-type")
                        if not content_type_value:
                            # No declared type: try JSON, the API default.
                            json_body = await request.json()
                        else:
                            # Let the stdlib parse the header so parameters
                            # such as "; charset=utf-8" are handled.
                            message = email.message.Message()
                            message["content-type"] = content_type_value
                            if message.get_content_maintype() == "application":
                                subtype = message.get_content_subtype()
                                if subtype == "json" or subtype.endswith(
                                        "+json"):
                                    json_body = await request.json()
                        # Non-JSON content types are handed on as raw bytes.
                        if json_body is not Undefined:
                            body = json_body
                        else:
                            body = body_bytes
        except json.JSONDecodeError as e:
            raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))],
                                         body=e.doc) from e
        except HTTPException:
            # Keep HTTP errors raised while reading the body (e.g. by form
            # parsing) instead of masking them with the generic 400 below.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="There was an error parsing the body") from e
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)

        raw_response = await run_endpoint_function(
            dependant=dependant, values=values, is_coroutine=is_coroutine)

        # Endpoints returning a Response are passed through as-is; only
        # attach the collected background tasks if none were set.
        if isinstance(raw_response, Response):
            if raw_response.background is None:
                raw_response.background = background_tasks
            return raw_response
        response_data = await serialize_response(
            field=response_field,
            response_content=raw_response,
            include=response_model_include,
            exclude=response_model_exclude,
            by_alias=response_model_by_alias,
            exclude_unset=response_model_exclude_unset,
            exclude_defaults=response_model_exclude_defaults,
            exclude_none=response_model_exclude_none,
            is_coroutine=is_coroutine,
        )
        response_args: Dict[str, Any] = {"background": background_tasks}
        # If status_code was set, use it, otherwise use the default from the
        # response class, in the case of redirect it's 307
        if status_code is not None:
            response_args["status_code"] = status_code
        response = actual_response_class(response_data, **response_args)
        # Merge headers / status set by dependencies on the injected Response.
        response.headers.raw.extend(sub_response.headers.raw)
        if sub_response.status_code:
            response.status_code = sub_response.status_code
        return response

    return app
예제 #13
0
def test_default_placeholder_bool():
    """A Default placeholder's truthiness follows the value it wraps."""
    truthy = Default("a")
    falsy = Default("")
    assert bool(truthy) is True
    assert bool(falsy) is False
예제 #14
0
def test_default_placeholder_equals():
    """Placeholders wrapping equal values compare equal, as do their values."""
    left, right = Default("a"), Default("a")
    assert left == right
    assert left.value == right.value
예제 #15
0
def get_request_handler(
    dependant: Dependant,
    content_type_mappings: Dict[str, Type[AbstractContentType]],
    body_field: Optional[ModelField] = None,
    status_code: int = 200,
    response_class: Union[Type[Response],
                          DefaultPlaceholder] = Default(JSONResponse),
    response_field: Optional[ModelField] = None,
    response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    response_model_exclude_defaults: bool = False,
    response_model_exclude_none: bool = False,
    dependency_overrides_provider: Optional[Any] = None,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """Build the async handler for a route that formats output by ``Accept``.

    The returned coroutine function reads the request body (form data for
    ``params.Form`` bodies, otherwise JSON), resolves dependencies, calls the
    endpoint and then picks the output format from the request's ``Accept``
    header. The header is matched exactly, without q-values or lists:

    * ``application/json``: the return value goes through
      ``serialize_response`` and ``response_class``.
    * a missing/empty header or ``*/*`` (when ``*/*`` is not itself a key of
      ``content_type_mappings``): handled like ``application/json``.
    * any key of ``content_type_mappings``: if the return value is an
      instance of the mapped class, its ``_to_method()`` converter produces
      the body of a plain ``Response`` with that media type.
    * anything else: ``HTTPException`` 406.

    In every generated response the collected background tasks are attached.
    Headers and a non-zero status code set on the dependency-injected
    ``Response`` are merged in.

    Args:
        dependant: Dependency tree of the endpoint; ``dependant.call`` must
            not be ``None`` (asserted).
        content_type_mappings: Media type -> content class used for
            non-JSON ``Accept`` values.
        body_field: Field describing the request body, if any.
        status_code: Status code for generated responses.
        response_class: Response class for the JSON path; a
            ``DefaultPlaceholder`` is unwrapped to its ``value``.
        response_field, response_model_*: Forwarded to ``serialize_response``.
        dependency_overrides_provider: Forwarded to ``solve_dependencies``.

    Returns:
        An async callable taking a ``Request`` and returning a ``Response``.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.field_info,
                                             params.Form)
    if isinstance(response_class, DefaultPlaceholder):
        actual_response_class: Type[Response] = response_class.value
    else:
        actual_response_class = response_class

    async def app(request: Request) -> Response:
        try:
            body = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    if body_bytes:
                        body = await request.json()
        except json.JSONDecodeError as e:
            raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))],
                                         body=e.doc) from e
        except HTTPException:
            # Keep HTTP errors raised while reading the body (e.g. by form
            # parsing) instead of masking them with the generic 400 below.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="There was an error parsing the body") from e
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)

        raw_response = await run_endpoint_function(
            dependant=dependant, values=values, is_coroutine=is_coroutine)

        if isinstance(raw_response, Response):
            if raw_response.background is None:
                raw_response.background = background_tasks
            return raw_response

        # A request without an Accept header accepts any media type
        # (RFC 7231 section 5.3.2), so treat a missing header like "*/*".
        accept = request.headers.get('accept') or '*/*'
        wants_json = accept == 'application/json' or (
            accept == '*/*' and accept not in content_type_mappings)

        if wants_json:
            # Default functionality, application/json via serialize_response
            response_data = await serialize_response(
                field=response_field,
                response_content=raw_response,
                include=response_model_include,
                exclude=response_model_exclude,
                by_alias=response_model_by_alias,
                exclude_unset=response_model_exclude_unset,
                exclude_defaults=response_model_exclude_defaults,
                exclude_none=response_model_exclude_none,
                is_coroutine=is_coroutine,
            )
            response = actual_response_class(
                content=response_data,
                status_code=status_code,
                background=background_tasks,  # type: ignore # in Starlette
            )
        else:
            content_type_mapping: Optional[
                Type[AbstractContentType]] = content_type_mappings.get(accept)
            if (content_type_mapping is None
                    or not isinstance(raw_response, content_type_mapping)):
                raise HTTPException(
                    status_code=406,
                    detail=f"Unable to format content for Accept: {accept}")
            # Errors raised by the converter itself are left to propagate:
            # they are server bugs, not content-negotiation failures.
            response = Response(
                media_type=accept,
                content=content_type_mapping._to_method()(raw_response),
                status_code=status_code,
                background=background_tasks,
            )

        # Merge headers / status set by dependencies on the injected Response,
        # for both the JSON and the custom content-type paths.
        response.headers.raw.extend(sub_response.headers.raw)
        if sub_response.status_code:
            response.status_code = sub_response.status_code
        return response

    return app
예제 #16
0
    def add_api_route(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        response_model: Optional[Type[Any]] = None,
        status_code: Optional[int] = None,
        response_description: Optional[str] = None,
        summary: Optional[str] = None,
        tags: Optional[List[str]] = None,
        dependencies: Optional[Sequence[params.Depends]] = None,
        description: Optional[str] = None,
        responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
        deprecated: Optional[bool] = None,
        methods: Optional[Union[Set[str], List[str]]] = None,
        operation_id: Optional[str] = None,
        response_model_include: Optional[Union[SetIntStr,
                                               DictIntStrAny]] = None,  # noqa
        response_model_exclude: Optional[Union[SetIntStr,
                                               DictIntStrAny]] = None,  # noqa
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        response_model_exclude_defaults: bool = False,
        response_model_exclude_none: bool = False,
        include_in_schema: bool = True,
        response_class: Union[Type[Response],
                              DefaultPlaceholder] = Default(JSONResponse),
        name: Optional[str] = None,
        route_class_override: Optional[Type[APIRoute]] = None,
        callbacks: Optional[List[BaseRoute]] = None,
    ) -> None:
        """Register an API route, filling in defaults derived from ``self.model``.

        Before delegating to the parent ``add_api_route`` this fills in any of
        the following that the caller left falsy:

        * ``summary``: the endpoint's ``__name__`` with surrounding
          underscores stripped, inner underscores turned into spaces and the
          result capitalized. The model's class name is appended unless it
          already occurs in that text (compared case-insensitively).
        * ``response_description``: ``"<Model> Successful Response"``.
        * ``status_code``: 201 when exactly one method is given and it is
          POST, otherwise 200.

        Every other argument is forwarded unchanged.
        """
        # NOTE(review): this tests a path string for membership in the route
        # list; unless the route objects compare equal to strings it is never
        # true and the removal never runs -- confirm the intended check.
        if path in self.routes:
            self.remove_api_route(path, methods)

        if not summary:
            endpoint_name = endpoint.__name__.strip('_').\
                replace('_', ' ').capitalize()
            model_name = self.model.__name__
            # capitalize() upper-cases the first letter, so compare both sides
            # lower-cased; otherwise e.g. ``item_list`` for model ``Item``
            # would produce "Item list Item".
            if model_name.lower() in endpoint_name.lower():
                summary = endpoint_name
            else:
                summary = f"{endpoint_name} {model_name}"

        if not response_description:
            response_description = f"{self.model.__name__} Successful Response"

        if not status_code:
            # ``methods`` may be a set (see the annotation), which cannot be
            # indexed, so materialize it before inspecting its single element.
            method_list = list(methods) if methods else []
            if len(method_list) == 1 and method_list[0].upper() == 'POST':
                status_code = 201
            else:
                status_code = 200

        super().add_api_route(
            path,
            endpoint,
            response_model=response_model,
            status_code=status_code,
            tags=tags,
            dependencies=dependencies,
            summary=summary,
            description=description,
            response_description=response_description,
            responses=responses,
            deprecated=deprecated,
            methods=methods,
            operation_id=operation_id,
            response_model_include=response_model_include,
            response_model_exclude=response_model_exclude,
            response_model_by_alias=response_model_by_alias,
            response_model_exclude_unset=response_model_exclude_unset,
            response_model_exclude_defaults=response_model_exclude_defaults,
            response_model_exclude_none=response_model_exclude_none,
            include_in_schema=include_in_schema,
            response_class=response_class,
            name=name,
            route_class_override=route_class_override,
            callbacks=callbacks)