Exemplo n.º 1
0
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Collapse the body parameters of ``dependant`` into one ``body`` field.

    The parameters are read from ``get_flat_dependant(dependant).body_params``.

    Returns:
        ``None`` when there are no body parameters; the lone parameter
        (passed through ``get_schema_compatible_field``) when there is exactly
        one and its field info has no truthy ``embed`` attribute; otherwise a
        new ``ModelField`` named ``"body"`` whose type is a model called
        ``"Body_" + name`` containing every body parameter.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    lone_param = body_params[0]
    embed = getattr(get_field_info(lone_param), "embed", None)
    if len(body_params) == 1 and not embed:
        return get_schema_compatible_field(field=lone_param)

    # Several (or embedded) parameters: wrap them all in a generated model.
    BodyModel = create_model("Body_" + name)
    for body_param in body_params:
        BodyModel.__fields__[body_param.name] = get_schema_compatible_field(
            field=body_param)
    required = any(body_param.required for body_param in body_params)

    infos = [get_field_info(body_param) for body_param in body_params]
    info_kwargs: Dict[str, Any] = {"default": None}
    info_cls: Type[params.Body] = params.Body
    if any(isinstance(info, params.File) for info in infos):
        info_cls = params.File
    elif any(isinstance(info, params.Form) for info in infos):
        info_cls = params.Form
    else:
        # Forward the media type only when every Body parameter agrees on it.
        media_types = [
            info.media_type for info in infos
            if isinstance(info, params.Body)
        ]
        if len(set(media_types)) == 1:
            info_kwargs["media_type"] = media_types[0]

    # Both construction variants share everything except the keyword that
    # carries the field info object.
    shared_kwargs: Dict[str, Any] = dict(
        name="body",
        type_=BodyModel,
        default=None,
        required=required,
        model_config=BaseConfig,
        class_validators={},
        alias="body",
    )
    body_info = info_cls(**info_kwargs)
    if PYDANTIC_1:
        return ModelField(field_info=body_info, **shared_kwargs)
    return ModelField(  # type: ignore  # pragma: nocover
        schema=body_info, **shared_kwargs)
Exemplo n.º 2
0
def get_openapi_operation_parameters(
    *, all_route_params: Sequence[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]:
    """Describe each route parameter as an OpenAPI parameter dictionary.

    Every entry carries ``name``, ``in``, ``required`` and ``schema``;
    ``description`` and ``deprecated`` are added only when the parameter's
    field info has a truthy value for them.
    """
    described = []
    for route_param in all_route_params:
        info = cast(Param, get_field_info(route_param))
        entry = {
            "name": route_param.alias,
            "in": info.in_.value,
            "required": route_param.required,
        }
        # Only the first element of the field_schema() result is used here.
        schema_result = field_schema(
            route_param,
            model_name_map=model_name_map,
            ref_prefix=REF_PREFIX  # type: ignore
        )
        entry["schema"] = schema_result[0]
        if info.description:
            entry["description"] = info.description
        if info.deprecated:
            entry["deprecated"] = info.deprecated
        described.append(entry)
    return described
Exemplo n.º 3
0
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate body parameters from a received request body.

    Args:
        required_params: The body fields to extract. When there is exactly
            one and its field info has no truthy ``embed`` attribute, the
            whole ``received_body`` is treated as that field's value.
        received_body: The parsed body (mapping or ``FormData``), or ``None``.

    Returns:
        A ``(values, errors)`` tuple: validated values keyed by field name,
        and the list of ``ErrorWrapper`` objects collected along the way.
    """
    values = {}
    errors = []
    if not required_params:
        return values, errors

    def missing_error(alias):
        # Build the "missing" error for the body entry called ``alias``.
        if PYDANTIC_1:
            return ErrorWrapper(MissingError(), loc=("body", alias))
        return ErrorWrapper(  # type: ignore  # pragma: nocover
            MissingError(),
            loc=("body", alias),
            config=BaseConfig,
        )

    first_field = required_params[0]
    embed = getattr(get_field_info(first_field), "embed", None)
    if len(required_params) == 1 and not embed:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Fix: read the field info of *this* parameter. Previously the info of
        # the first parameter was reused for all of them, so the Form/File
        # checks below used the wrong metadata when parameter kinds differed.
        field_info = get_field_info(field)
        is_form = isinstance(field_info, params.Form)
        is_sequence = field.shape in sequence_shapes

        value: Any = None
        if received_body is not None:
            if is_sequence and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping (it has no ``.get``), so the
                    # field cannot be looked up: report it as missing instead
                    # of letting the AttributeError escape.
                    errors.append(missing_error(field.alias))
                    continue

        if (value is None or (is_form and value == "")
                or (is_form and is_sequence and len(value) == 0)):
            if field.required:
                errors.append(missing_error(field.alias))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        # Uploaded files declared as ``bytes`` are read into memory first.
        if (isinstance(field_info, params.File)
                and lenient_issubclass(field.type_, bytes)
                and isinstance(value, UploadFile)):
            value = await value.read()
        elif (is_sequence and isinstance(field_info, params.File)
              and lenient_issubclass(field.type_, bytes)
              and isinstance(value, sequence_types)):
            awaitables = [sub_value.read() for sub_value in value]
            contents = await asyncio.gather(*awaitables)
            value = sequence_shape_to_type[field.shape](contents)

        v_, errors_ = field.validate(value,
                                     values,
                                     loc=("body", field.alias))
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
Exemplo n.º 4
0
def get_dependant(
    *,
    path: str,
    call: Callable,
    name: str = None,
    security_scopes: List[str] = None,
    use_cache: bool = True,
) -> Dependant:
    """Build the ``Dependant`` describing the parameters of ``call``.

    Parameters whose default is a ``params.Depends`` become sub-dependencies.
    Every other parameter is either consumed by
    ``add_non_field_param_to_dependency``, registered through
    ``add_param_to_fields``, or appended to ``dependant.body_params``.

    Raises:
        AssertionError: if a path parameter is not a scalar field, or if a
            parameter that falls through to the body case does not carry
            ``params.Body`` field info.
    """
    path_param_names = get_path_param_names(path)
    signature_params = get_typed_signature(call).parameters
    if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
        check_dependency_contextmanagers()
    dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)

    # First pass: collect sub-dependencies declared with Depends.
    for param in signature_params.values():
        if isinstance(param.default, params.Depends):
            dependant.dependencies.append(
                get_param_sub_dependant(param=param,
                                        path=path,
                                        security_scopes=security_scopes))

    # Second pass: classify everything that is not a sub-dependency.
    for param_name, param in signature_params.items():
        if isinstance(param.default, params.Depends):
            continue
        if add_non_field_param_to_dependency(param=param, dependant=dependant):
            continue

        candidate = get_param_field(param=param,
                                    default_field_info=params.Query,
                                    param_name=param_name)

        if param_name in path_param_names:
            assert is_scalar_field(
                field=candidate
            ), "Path params must be of one of the supported types"
            # The default is ignored unless it was declared with params.Path.
            path_field = get_param_field(
                param=param,
                param_name=param_name,
                default_field_info=params.Path,
                force_type=params.ParamTypes.path,
                ignore_default=not isinstance(param.default, params.Path),
            )
            add_param_to_fields(field=path_field, dependant=dependant)
            continue

        is_plain_param = is_scalar_field(field=candidate) or (
            isinstance(param.default, (params.Query, params.Header))
            and is_scalar_sequence_field(candidate))
        if is_plain_param:
            add_param_to_fields(field=candidate, dependant=dependant)
            continue

        assert isinstance(
            get_field_info(candidate), params.Body
        ), f"Param: {candidate.name} can only be a request body, using Body(...)"
        dependant.body_params.append(candidate)
    return dependant
Exemplo n.º 5
0
def is_scalar_field(field: ModelField) -> bool:
    """Tell whether ``field`` is a scalar field.

    The field must have singleton shape, a type that is neither a
    ``BaseModel`` subclass nor one of ``sequence_types``/``dict``, and field
    info that is not a ``params.Body``. If it has sub-fields, each of them
    must satisfy the same conditions.
    """
    info = get_field_info(field)
    if field.shape != SHAPE_SINGLETON:
        return False
    if lenient_issubclass(field.type_, BaseModel):
        return False
    if lenient_issubclass(field.type_, sequence_types + (dict, )):
        return False
    if isinstance(info, params.Body):
        return False
    return not field.sub_fields or all(
        is_scalar_field(sub_field) for sub_field in field.sub_fields)
Exemplo n.º 6
0
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the ``dependant`` list matching its ``in_`` value.

    Raises:
        AssertionError: if ``in_`` is not path, query, header or cookie.
    """
    location = cast(params.Param, get_field_info(field)).in_
    if location == params.ParamTypes.path:
        bucket = dependant.path_params
    elif location == params.ParamTypes.query:
        bucket = dependant.query_params
    elif location == params.ParamTypes.header:
        bucket = dependant.header_params
    else:
        assert (
            location == params.ParamTypes.cookie
        ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
        bucket = dependant.cookie_params
    bucket.append(field)
Exemplo n.º 7
0
def get_openapi_operation_request_body(
    *, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
) -> Optional[Dict]:
    """Build the request-body dictionary for an operation.

    Returns ``None`` when ``body_field`` is falsy. Otherwise the result maps
    the field info's media type to the field schema under ``"content"``, and
    carries ``"required"`` only when the field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema_for_body, _, _ = field_schema(
        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    media_type = cast(Body, get_field_info(body_field)).media_type
    is_required = body_field.required
    # "required" is inserted first so it precedes "content" in the result.
    result: Dict[str, Any] = {"required": is_required} if is_required else {}
    result["content"] = {media_type: {"schema": schema_for_body}}
    return result
Exemplo n.º 8
0
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate non-body parameters from ``received_params``.

    Returns:
        A ``(values, errors)`` tuple: validated values keyed by field name,
        and the ``ErrorWrapper`` objects for missing or invalid parameters.

    Raises:
        AssertionError: if a field's info is not a ``params.Param``.
    """
    values = {}
    errors = []
    supports_getlist = isinstance(received_params, (QueryParams, Headers))
    for field in required_params:
        if is_scalar_sequence_field(field) and supports_getlist:
            # An empty list falls back to the field default.
            value = received_params.getlist(field.alias) or field.default
        else:
            value = received_params.get(field.alias)

        field_info = get_field_info(field)
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"

        if value is not None:
            validated, problems = field.validate(
                value, values, loc=(field_info.in_.value, field.alias)
            )
            if isinstance(problems, ErrorWrapper):
                errors.append(problems)
            elif isinstance(problems, list):
                errors.extend(problems)
            else:
                values[field.name] = validated
            continue

        if not field.required:
            values[field.name] = deepcopy(field.default)
            continue

        if PYDANTIC_1:
            missing = ErrorWrapper(
                MissingError(), loc=(field_info.in_.value, field.alias)
            )
        else:  # pragma: nocover
            missing = ErrorWrapper(  # type: ignore
                MissingError(),
                loc=(field_info.in_.value, field.alias),
                config=BaseConfig,
            )
        errors.append(missing)
    return values, errors
Exemplo n.º 9
0
def get_schema_compatible_field(*, field: ModelField) -> ModelField:
    """Return ``field``, or a ``bytes``-typed stand-in for ``UploadFile`` fields.

    A field whose type is an ``UploadFile`` subclass is replaced by a new
    field typed ``bytes`` (``List[bytes]`` for sequence shapes) that keeps the
    original name, alias, default, required flag, validators, config and
    field info. Any other field is returned unchanged.
    """
    if not lenient_issubclass(field.type_, UploadFile):
        return field
    schema_type: type = List[bytes] if field.shape in sequence_shapes else bytes
    return create_response_field(
        name=field.name,
        type_=schema_type,
        class_validators=field.class_validators,
        model_config=field.model_config,
        default=field.default,
        required=field.required,
        alias=field.alias,
        field_info=get_field_info(field),
    )
Exemplo n.º 10
0
def get_openapi_operation_parameters(
    all_route_params: Sequence[ModelField], ) -> List[Dict[str, Any]]:
    """Describe each route parameter as an OpenAPI parameter dictionary.

    Every entry has ``name``, ``in``, ``required`` and ``schema``; the
    optional ``description`` and ``deprecated`` keys are copied from the
    field info only when their value is truthy.
    """
    result = []
    for route_param in all_route_params:
        info = cast(Param, get_field_info(route_param))
        entry = {
            "name": route_param.alias,
            "in": info.in_.value,
            "required": route_param.required,
            "schema": field_schema(route_param, model_name_map={})[0],
        }
        for optional_key in ("description", "deprecated"):
            optional_value = getattr(info, optional_key)
            if optional_value:
                entry[optional_key] = optional_value
        result.append(entry)
    return result
Exemplo n.º 11
0
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Collapse the body parameters of ``dependant`` into one ``body`` field.

    The parameters come from ``get_flat_dependant(dependant).body_params``.

    Returns:
        ``None`` when there are no body parameters. When all parameters share
        one name and the first has no truthy ``embed``, the first parameter
        (made schema compatible). Otherwise a field named ``"body"`` wrapping
        a generated ``"Body_" + name`` model; in that case every parameter's
        field info gets ``embed`` set to ``True``.

    Each returned field is first passed to ``check_file_field``.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None

    head = body_params[0]
    embed = getattr(get_field_info(head), "embed", None)
    distinct_names = {body_param.name for body_param in body_params}
    if len(distinct_names) == 1 and not embed:
        single_field = get_schema_compatible_field(field=head)
        check_file_field(single_field)
        return single_field

    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for body_param in body_params:
        get_field_info(body_param).embed = True

    BodyModel = create_model("Body_" + name)
    for body_param in body_params:
        BodyModel.__fields__[body_param.name] = get_schema_compatible_field(
            field=body_param)
    required = any(body_param.required for body_param in body_params)

    def any_info_is(kind):
        # True when at least one body parameter carries ``kind`` field info.
        return any(
            isinstance(get_field_info(body_param), kind)
            for body_param in body_params)

    info_kwargs: Dict[str, Any] = {"default": None}
    info_cls: Type[params.Body]
    if any_info_is(params.File):
        info_cls = params.File
    elif any_info_is(params.Form):
        info_cls = params.Form
    else:
        info_cls = params.Body
        # Forward the media type only when every Body parameter agrees on it.
        media_types = [
            get_field_info(body_param).media_type
            for body_param in body_params
            if isinstance(get_field_info(body_param), params.Body)
        ]
        if len(set(media_types)) == 1:
            info_kwargs["media_type"] = media_types[0]

    combined_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=required,
        alias="body",
        field_info=info_cls(**info_kwargs),
    )
    check_file_field(combined_field)
    return combined_field
Exemplo n.º 12
0
def check_file_field(field: ModelField) -> None:
    """Verify the ``multipart`` package is importable for form fields.

    Nothing is checked unless the field info is a ``params.Form``.

    Raises:
        RuntimeError: if ``multipart`` cannot be imported, or if
            ``multipart.multipart.parse_options_header`` cannot be imported.
            The matching message is logged as an error first.
    """
    if not isinstance(get_field_info(field), params.Form):
        return

    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__

        assert __version__
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error)

    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header

        assert parse_options_header
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error)
Exemplo n.º 13
0
def get_request_handler(
    dependant: Dependant,
    body_field: ModelField = None,
    status_code: int = 200,
    response_class: Type[Response] = JSONResponse,
    response_field: ModelField = None,
    response_model_include: Union[SetIntStr, DictIntStrAny] = None,
    response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
    response_model_by_alias: bool = True,
    response_model_exclude_unset: bool = False,
    dependency_overrides_provider: Any = None,
) -> Callable:
    """Create the coroutine that handles a request for ``dependant.call``.

    The returned ``app(request)`` coroutine reads the body (only when
    ``body_field`` is set), solves the dependencies, runs the endpoint
    function and turns its result into a ``response_class`` instance. A
    ``Response`` returned by the endpoint is passed through as is, with the
    background tasks attached when it has none.

    Raises (from the returned coroutine):
        HTTPException: status 400 when reading or parsing the body fails.
        RequestValidationError: when dependency solving reports errors.
    """
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(get_field_info(body_field), params.Form)

    async def app(request: Request) -> Response:
        try:
            body = None
            if body_field:
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    body = body_bytes
                    if body_bytes:
                        # Fix: compare only the media type, case-insensitively
                        # and without parameters, so that values such as
                        # "application/json; charset=utf-8" are parsed as JSON
                        # instead of being forwarded as raw bytes.
                        # A missing header still raises KeyError here, which
                        # the handler below turns into a 400 as before.
                        content_type = request.headers["Content-Type"]
                        media_type = content_type.split(";", 1)[0].strip().lower()
                        if media_type == "application/json":
                            body = await request.json()
        except Exception as e:
            logger.error(f"Error getting request body: {e}")
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        values, errors, background_tasks, sub_response, _ = solved_result
        if errors:
            raise RequestValidationError(errors, body=body)
        else:
            raw_response = await run_endpoint_function(
                dependant=dependant, values=values, is_coroutine=is_coroutine
            )

            if isinstance(raw_response, Response):
                # Endpoint built its own response: only attach background
                # tasks if it did not set any itself.
                if raw_response.background is None:
                    raw_response.background = background_tasks
                return raw_response
            response_data = await serialize_response(
                field=response_field,
                response_content=raw_response,
                include=response_model_include,
                exclude=response_model_exclude,
                by_alias=response_model_by_alias,
                exclude_unset=response_model_exclude_unset,
                is_coroutine=is_coroutine,
            )
            response = response_class(
                content=response_data,
                status_code=status_code,
                background=background_tasks,
            )
            # Headers and status code set on the sub-response while solving
            # dependencies are carried over to the final response.
            response.headers.raw.extend(sub_response.headers.raw)
            if sub_response.status_code:
                response.status_code = sub_response.status_code
            return response

    return app
Exemplo n.º 14
0
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate body parameters from a received request body.

    Args:
        required_params: The body fields to extract. When there is exactly
            one and its field info has no truthy ``embed`` attribute, the
            whole ``received_body`` is that field's value and errors are
            located at ``("body",)`` instead of ``("body", alias)``.
        received_body: The parsed body (mapping or ``FormData``), or ``None``.

    Returns:
        A ``(values, errors)`` tuple: validated values keyed by field name,
        and the list of errors collected along the way.
    """
    values = {}
    errors = []
    if required_params:
        field = required_params[0]
        embed = getattr(get_field_info(field), "embed", None)
        field_alias_omitted = len(required_params) == 1 and not embed
        if field_alias_omitted:
            received_body = {field.alias: received_body}

        for field in required_params:
            # Fix: read the field info of *this* parameter. Previously the
            # info of the first parameter was reused for all of them, so the
            # Form/File checks below used the wrong metadata when parameter
            # kinds differed.
            field_info = get_field_info(field)

            loc: Tuple[str, ...]
            if field_alias_omitted:
                loc = ("body",)
            else:
                loc = ("body", field.alias)

            value: Any = None
            if received_body is not None:
                if (
                    field.shape in sequence_shapes or field.type_ in sequence_types
                ) and isinstance(received_body, FormData):
                    value = received_body.getlist(field.alias)
                else:
                    try:
                        value = received_body.get(field.alias)
                    except AttributeError:
                        # The body has no ``.get`` (it is not a mapping), so
                        # the field is reported as missing.
                        errors.append(get_missing_field_error(loc))
                        continue
            if (
                value is None
                or (isinstance(field_info, params.Form) and value == "")
                or (
                    isinstance(field_info, params.Form)
                    and field.shape in sequence_shapes
                    and len(value) == 0
                )
            ):
                if field.required:
                    errors.append(get_missing_field_error(loc))
                else:
                    values[field.name] = deepcopy(field.default)
                continue
            # Uploaded files declared as ``bytes`` are read into memory first.
            if (
                isinstance(field_info, params.File)
                and lenient_issubclass(field.type_, bytes)
                and isinstance(value, UploadFile)
            ):
                value = await value.read()
            elif (
                field.shape in sequence_shapes
                and isinstance(field_info, params.File)
                and lenient_issubclass(field.type_, bytes)
                and isinstance(value, sequence_types)
            ):
                awaitables = [sub_value.read() for sub_value in value]
                contents = await asyncio.gather(*awaitables)
                value = sequence_shape_to_type[field.shape](contents)

            v_, errors_ = field.validate(value, values, loc=loc)

            if isinstance(errors_, ErrorWrapper):
                errors.append(errors_)
            elif isinstance(errors_, list):
                errors.extend(errors_)
            else:
                values[field.name] = v_
    return values, errors