def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Collapse every body parameter of ``dependant`` into one body field.

    Returns ``None`` when the flattened dependant declares no body
    parameters. A lone parameter whose field info has no truthy ``embed``
    attribute is returned on its own, after passing through
    ``get_schema_compatible_field``. Otherwise a model named
    ``"Body_" + name`` is created with one field per body parameter and
    wrapped in a ``ModelField`` called ``"body"``.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    leading_param = body_params[0]
    wants_embed = getattr(get_field_info(leading_param), "embed", None)
    if len(body_params) == 1 and not wants_embed:
        return get_schema_compatible_field(field=leading_param)

    BodyModel = create_model("Body_" + name)
    for body_param in body_params:
        BodyModel.__fields__[body_param.name] = get_schema_compatible_field(
            field=body_param
        )
    # The wrapper is required as soon as any single member is required.
    is_required = any(body_param.required for body_param in body_params)

    info_kwargs: Dict[str, Any] = dict(default=None)
    param_infos = [get_field_info(body_param) for body_param in body_params]
    # File takes precedence over Form, which takes precedence over Body.
    if any(isinstance(info, params.File) for info in param_infos):
        info_cls: Type[params.Body] = params.File
    elif any(isinstance(info, params.Form) for info in param_infos):
        info_cls = params.Form
    else:
        info_cls = params.Body
        # Only forward a media type when all body params agree on it.
        media_types = [
            getattr(info, "media_type")
            for info in param_infos
            if isinstance(info, params.Body)
        ]
        if len(set(media_types)) == 1:
            info_kwargs["media_type"] = media_types[0]

    shared_kwargs: Dict[str, Any] = dict(
        name="body",
        type_=BodyModel,
        default=None,
        required=is_required,
        model_config=BaseConfig,
        class_validators={},
        alias="body",
    )
    body_info = info_cls(**info_kwargs)
    if PYDANTIC_1:
        return ModelField(field_info=body_info, **shared_kwargs)
    # Older pydantic takes the field info through the ``schema`` keyword.
    return ModelField(schema=body_info, **shared_kwargs)  # type: ignore  # pragma: nocover
def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]:
    """Describe each route parameter as an OpenAPI parameter dict.

    Every entry carries ``name``, ``in``, ``required`` and ``schema``;
    ``description`` and ``deprecated`` are added only when the parameter's
    field info provides a truthy value for them.
    """
    described: List[Dict[str, Any]] = []
    for route_param in all_route_params:
        info = cast(Param, get_field_info(route_param))
        # ignore mypy error until enum schemas are released
        entry = {
            "name": route_param.alias,
            "in": info.in_.value,
            "required": route_param.required,
            "schema": field_schema(
                route_param, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
            )[0],
        }
        if info.description:
            entry["description"] = info.description
        if info.deprecated:
            entry["deprecated"] = info.deprecated
        described.append(entry)
    return described
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate the declared body parameters from a request body.

    Args:
        required_params: The body fields to populate. When there is exactly
            one and its field info has no truthy ``embed`` attribute, the
            whole body is treated as that field's value.
        received_body: The parsed body (a mapping or ``FormData``), or
            ``None`` when nothing was received.

    Returns:
        A ``(values, errors)`` tuple: ``values`` maps field names to
        validated values (or deep-copied defaults for absent optional
        fields); ``errors`` collects one ``ErrorWrapper`` per missing or
        invalid field.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    if not required_params:
        return values, errors

    def missing_error(alias: str) -> ErrorWrapper:
        """Build the "missing" error located at ``("body", alias)``."""
        if PYDANTIC_1:
            return ErrorWrapper(MissingError(), loc=("body", alias))
        return ErrorWrapper(  # type: ignore  # pragma: nocover
            MissingError(), loc=("body", alias), config=BaseConfig
        )

    first_field = required_params[0]
    embed = getattr(get_field_info(first_field), "embed", None)
    if len(required_params) == 1 and not embed:
        # A single non-embedded parameter owns the whole body; wrap it under
        # the parameter's alias so the per-field lookup below is uniform.
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Resolve the info per field. Reusing the first parameter's info for
        # every field mis-classifies mixed Form/File declarations (e.g. a
        # ``bytes`` File declared after a Form would never be read).
        field_info = get_field_info(field)
        is_form = isinstance(field_info, params.Form)
        is_file = isinstance(field_info, params.File)
        is_sequence = field.shape in sequence_shapes

        value: Any = None
        if received_body is not None:
            if is_sequence and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping (e.g. a JSON array), so the
                    # field cannot be present: report it instead of crashing.
                    errors.append(missing_error(field.alias))
                    continue

        # Empty strings and empty lists from forms count as "not provided".
        if value is None or (
            is_form and (value == "" or (is_sequence and len(value) == 0))
        ):
            if field.required:
                errors.append(missing_error(field.alias))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        # ``bytes``-typed file parameters receive the uploaded content itself.
        if is_file and lenient_issubclass(field.type_, bytes):
            if isinstance(value, UploadFile):
                value = await value.read()
            elif is_sequence and isinstance(value, sequence_types):
                contents = await asyncio.gather(
                    *[sub_value.read() for sub_value in value]
                )
                value = sequence_shape_to_type[field.shape](contents)

        v_, errors_ = field.validate(value, values, loc=("body", field.alias))
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
def get_dependant(
    *,
    path: str,
    call: Callable,
    name: str = None,
    security_scopes: List[str] = None,
    use_cache: bool = True,
) -> Dependant:
    """Inspect ``call``'s signature and build the matching ``Dependant``.

    Parameters whose default is a ``params.Depends`` become sub-dependencies.
    Every other parameter is either consumed by
    ``add_non_field_param_to_dependency`` or classified as a path, query,
    header, cookie or body parameter and stored on the returned dependant.

    Raises:
        AssertionError: if a path parameter is not a scalar field, or if a
            non-scalar parameter's field info is not a ``params.Body``.
    """
    path_param_names = get_path_param_names(path)
    signature_params = get_typed_signature(call).parameters
    if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
        check_dependency_contextmanagers()
    dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)

    # First pass: register every sub-dependency.
    for param in signature_params.values():
        if isinstance(param.default, params.Depends):
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param=param, path=path, security_scopes=security_scopes
                )
            )

    # Second pass: classify the remaining parameters.
    for param_name, param in signature_params.items():
        if isinstance(param.default, params.Depends):
            continue
        if add_non_field_param_to_dependency(param=param, dependant=dependant):
            continue
        param_field = get_param_field(
            param=param, default_field_info=params.Query, param_name=param_name
        )

        if param_name in path_param_names:
            assert is_scalar_field(
                field=param_field
            ), f"Path params must be of one of the supported types"
            # Rebuild the field as a path parameter; a default is only kept
            # when the parameter was explicitly declared with ``Path``.
            param_field = get_param_field(
                param=param,
                param_name=param_name,
                default_field_info=params.Path,
                force_type=params.ParamTypes.path,
                ignore_default=not isinstance(param.default, params.Path),
            )
            add_param_to_fields(field=param_field, dependant=dependant)
            continue

        if is_scalar_field(field=param_field) or (
            isinstance(param.default, (params.Query, params.Header))
            and is_scalar_sequence_field(param_field)
        ):
            add_param_to_fields(field=param_field, dependant=dependant)
            continue

        # Anything left over has to be a request body parameter.
        field_info = get_field_info(param_field)
        assert isinstance(
            field_info, params.Body
        ), f"Param: {param_field.name} can only be a request body, using Body(...)"
        dependant.body_params.append(param_field)
    return dependant
def is_scalar_field(field: ModelField) -> bool:
    """Return whether ``field`` is a singleton, non-model, non-body value.

    The field is rejected when its shape is not ``SHAPE_SINGLETON``, when
    its type is a ``BaseModel``, one of ``sequence_types`` or ``dict``, or
    when its field info is a ``params.Body``. If it has sub-fields, each of
    them must be scalar as well.
    """
    field_info = get_field_info(field)
    if field.shape != SHAPE_SINGLETON:
        return False
    if lenient_issubclass(field.type_, BaseModel):
        return False
    if lenient_issubclass(field.type_, sequence_types + (dict,)):
        return False
    if isinstance(field_info, params.Body):
        return False

    sub_fields = field.sub_fields
    if not sub_fields:
        return True
    return all(is_scalar_field(sub_field) for sub_field in sub_fields)
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the ``dependant`` list matching its ``in_`` value.

    Raises:
        AssertionError: if the location is not path, query, header or cookie.
    """
    location = cast(params.Param, get_field_info(field)).in_
    if location == params.ParamTypes.path:
        bucket = dependant.path_params
    elif location == params.ParamTypes.query:
        bucket = dependant.query_params
    elif location == params.ParamTypes.header:
        bucket = dependant.header_params
    else:
        assert (
            location == params.ParamTypes.cookie
        ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
        bucket = dependant.cookie_params
    bucket.append(field)
def get_openapi_operation_request_body(
    *, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
) -> Optional[Dict]:
    """Build the OpenAPI ``requestBody`` dict for ``body_field``.

    Returns ``None`` when there is no body field. Otherwise the result maps
    the field info's media type to the field's schema under ``"content"``,
    preceded by a ``"required"`` entry when the field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)

    body_schema, _, _ = field_schema(
        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    media_type = cast(Body, get_field_info(body_field)).media_type
    is_required = body_field.required

    request_body: Dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required
    request_body["content"] = {media_type: {"schema": body_schema}}
    return request_body
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Validate non-body parameters against the values that were received.

    Returns a ``(values, errors)`` tuple. ``values`` maps field names to
    validated values, or to a deep copy of the default when an optional
    parameter is absent. ``errors`` holds one ``ErrorWrapper`` per missing
    required parameter or failed validation.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    for field in required_params:
        reads_many = is_scalar_sequence_field(field) and isinstance(
            received_params, (QueryParams, Headers)
        )
        if reads_many:
            # An empty list of occurrences falls back to the default.
            value = received_params.getlist(field.alias) or field.default
        else:
            value = received_params.get(field.alias)

        field_info = get_field_info(field)
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"

        if value is None:
            if not field.required:
                values[field.name] = deepcopy(field.default)
            elif PYDANTIC_1:
                errors.append(
                    ErrorWrapper(
                        MissingError(), loc=(field_info.in_.value, field.alias)
                    )
                )
            else:  # pragma: nocover
                errors.append(
                    ErrorWrapper(  # type: ignore
                        MissingError(),
                        loc=(field_info.in_.value, field.alias),
                        config=BaseConfig,
                    )
                )
            continue

        validated, problems = field.validate(
            value, values, loc=(field_info.in_.value, field.alias)
        )
        if isinstance(problems, ErrorWrapper):
            errors.append(problems)
        elif isinstance(problems, list):
            errors.extend(problems)
        else:
            values[field.name] = validated
    return values, errors
def get_schema_compatible_field(*, field: ModelField) -> ModelField:
    """Return ``field``, or a ``bytes``-typed copy if it is an ``UploadFile``.

    Fields whose type is not an ``UploadFile`` subclass are returned as-is.
    Otherwise a new field is created with the same name, validators, config,
    default, required flag, alias and field info, typed as ``bytes`` (or
    ``List[bytes]`` when the original shape is one of ``sequence_shapes``).
    """
    if not lenient_issubclass(field.type_, UploadFile):
        return field

    use_type: type = List[bytes] if field.shape in sequence_shapes else bytes
    return create_response_field(
        name=field.name,
        type_=use_type,
        class_validators=field.class_validators,
        model_config=field.model_config,
        default=field.default,
        required=field.required,
        alias=field.alias,
        field_info=get_field_info(field),
    )
def get_openapi_operation_parameters(
    all_route_params: Sequence[ModelField],
) -> List[Dict[str, Any]]:
    """Describe each route parameter as an OpenAPI parameter dict.

    Each entry has ``name``, ``in``, ``required`` and ``schema`` keys, plus
    ``description`` / ``deprecated`` when the field info sets a truthy value.
    Schemas are generated with an empty ``model_name_map``.
    """
    described: List[Dict[str, Any]] = []
    for route_param in all_route_params:
        info = cast(Param, get_field_info(route_param))
        entry = {
            "name": route_param.alias,
            "in": info.in_.value,
            "required": route_param.required,
            "schema": field_schema(route_param, model_name_map={})[0],
        }
        # Optional metadata is only emitted when actually provided.
        if info.description:
            entry["description"] = info.description
        if info.deprecated:
            entry["deprecated"] = info.deprecated
        described.append(entry)
    return described
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Collapse every body parameter of ``dependant`` into one body field.

    Returns ``None`` when the flattened dependant declares no body
    parameters. When all body parameters share one name and the first one's
    field info has no truthy ``embed`` attribute, that first parameter is
    returned on its own. Otherwise every parameter is marked ``embed=True``
    and they are combined into a ``"Body_" + name`` model wrapped in a field
    called ``"body"``. The returned field is passed to ``check_file_field``.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    leading_param = body_params[0]
    wants_embed = getattr(get_field_info(leading_param), "embed", None)
    distinct_names = {body_param.name for body_param in body_params}
    if len(distinct_names) == 1 and not wants_embed:
        single_field = get_schema_compatible_field(field=leading_param)
        check_file_field(single_field)
        return single_field

    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for body_param in body_params:
        setattr(get_field_info(body_param), "embed", True)

    BodyModel = create_model("Body_" + name)
    for body_param in body_params:
        BodyModel.__fields__[body_param.name] = get_schema_compatible_field(
            field=body_param
        )
    # The wrapper is required as soon as any single member is required.
    is_required = any(body_param.required for body_param in body_params)

    info_kwargs: Dict[str, Any] = dict(default=None)
    param_infos = [get_field_info(body_param) for body_param in body_params]
    # File takes precedence over Form, which takes precedence over Body.
    if any(isinstance(info, params.File) for info in param_infos):
        info_cls: Type[params.Body] = params.File
    elif any(isinstance(info, params.Form) for info in param_infos):
        info_cls = params.Form
    else:
        info_cls = params.Body
        # Only forward a media type when all body params agree on it.
        media_types = [
            getattr(info, "media_type")
            for info in param_infos
            if isinstance(info, params.Body)
        ]
        if len(set(media_types)) == 1:
            info_kwargs["media_type"] = media_types[0]

    combined_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=is_required,
        alias="body",
        field_info=info_cls(**info_kwargs),
    )
    check_file_field(combined_field)
    return combined_field
def check_file_field(field: ModelField) -> None:
    """Ensure the expected ``multipart`` package is importable for form fields.

    Does nothing unless the field info is a ``params.Form``. Otherwise the
    relevant module-level error message is logged and raised as a
    ``RuntimeError`` when ``multipart`` cannot be imported, or when it can
    but does not provide ``multipart.multipart.parse_options_header``.
    """
    if not isinstance(get_field_info(field), params.Form):
        return

    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error)
    assert __version__

    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error)
    assert parse_options_header
def get_request_handler(
    dependant: Dependant,
    body_field: ModelField = None,
    status_code: int = 200,
    response_class: Type[Response] = JSONResponse,
    response_field: ModelField = None,
    response_model_include: Union[SetIntStr, DictIntStrAny] = None,
    response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    dependency_overrides_provider: Any = None,
) -> Callable:
    """Create the async request handler for the endpoint in ``dependant``.

    The returned ``app(request)`` coroutine reads the request body (only
    when ``body_field`` is given), solves the dependencies, runs the
    endpoint and turns its result into a ``response_class`` instance. An
    endpoint result that is already a ``Response`` is returned unchanged,
    apart from receiving the background tasks when it has none.

    Raises (from the returned handler):
        HTTPException: status 400 when reading or parsing the body fails.
        RequestValidationError: when dependency solving reports errors.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(get_field_info(body_field), params.Form)

    async def app(request: Request) -> Response:
        try:
            body = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    # Compare only the media type: the header may be absent
                    # or carry parameters ("application/json; charset=utf-8"),
                    # and media types are case-insensitive.
                    content_type = request.headers.get("Content-Type", "")
                    media_type = content_type.split(";", 1)[0].strip().lower()
                    if body_bytes and media_type == "application/json":
                        body = await request.json()
                    else:
                        body = body_bytes
        except Exception as e:
            logger.error(f"Error getting request body: {e}")
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e

        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)

        raw_response = await run_endpoint_function(
            dependant=dependant, values=values, is_coroutine=is_coroutine
        )
        if isinstance(raw_response, Response):
            # Keep a background task the endpoint attached itself.
            if raw_response.background is None:
                raw_response.background = background_tasks
            return raw_response

        response_data = await serialize_response(
            field=response_field,
            response_content=raw_response,
            include=response_model_include,
            exclude=response_model_exclude,
            by_alias=response_model_by_alias,
            exclude_unset=response_model_exclude_unset,
            is_coroutine=is_coroutine,
        )
        response = response_class(
            content=response_data,
            status_code=status_code,
            background=background_tasks,
        )
        # Headers and status code set on the dependency-provided response
        # are carried over onto the final response.
        response.headers.raw.extend(sub_response.headers.raw)
        if sub_response.status_code:
            response.status_code = sub_response.status_code
        return response

    return app
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate the declared body parameters from a request body.

    Args:
        required_params: The body fields to populate. When there is exactly
            one and its field info has no truthy ``embed`` attribute, the
            whole body is treated as that field's value and errors are
            located at ``("body",)`` instead of ``("body", alias)``.
        received_body: The parsed body (a mapping or ``FormData``), or
            ``None`` when nothing was received.

    Returns:
        A ``(values, errors)`` tuple: ``values`` maps field names to
        validated values (or deep-copied defaults for absent optional
        fields); ``errors`` collects one ``ErrorWrapper`` per missing or
        invalid field.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(get_field_info(first_field), "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        # Wrap the whole body under the single parameter's alias so the
        # per-field lookup below is uniform.
        received_body = {first_field.alias: received_body}

    for field in required_params:
        loc: Tuple[str, ...] = ("body",) if field_alias_omitted else ("body", field.alias)
        # Resolve the info per field. Reusing the first parameter's info for
        # every field mis-classifies mixed Form/File declarations (e.g. a
        # ``bytes`` File declared after a Form would never be read).
        field_info = get_field_info(field)
        is_form = isinstance(field_info, params.Form)
        is_file = isinstance(field_info, params.File)
        is_sequence = field.shape in sequence_shapes

        value: Any = None
        if received_body is not None:
            if (is_sequence or field.type_ in sequence_types) and isinstance(
                received_body, FormData
            ):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping, so the field cannot be there.
                    errors.append(get_missing_field_error(loc))
                    continue

        # Empty strings and empty lists from forms count as "not provided".
        if value is None or (
            is_form and (value == "" or (is_sequence and len(value) == 0))
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        # ``bytes``-typed file parameters receive the uploaded content itself.
        if is_file and lenient_issubclass(field.type_, bytes):
            if isinstance(value, UploadFile):
                value = await value.read()
            elif is_sequence and isinstance(value, sequence_types):
                contents = await asyncio.gather(
                    *[sub_value.read() for sub_value in value]
                )
                value = sequence_shape_to_type[field.shape](contents)

        v_, errors_ = field.validate(value, values, loc=loc)
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors