Esempio n. 1
0
    def __validate_cls_namespace__(name: str, namespace: Dict) -> None:  # noqa C901
        """Validate the class namespace in place.

        The annotations are resolved and their types substituted through
        ``validate_type``, an ODM field descriptor is built for every touched
        field, and the results are written back into ``namespace`` under
        ``__annotations__``, ``__odm_fields__``, ``__references__``,
        ``__bson_serialized_fields__``, ``__mutable_fields__`` and ``Config``.

        Args:
            name: name of the class being defined; passed to
                ``validate_config`` and used in error messages.
            namespace: class namespace, modified in place.

        Raises:
            TypeError: if a field is defined without a type annotation, if a
                pydantic ``Field`` is used as a default, if a ``Model`` typed
                field has no reference assigned to it, if a default value does
                not validate against its annotated type, or if a key name is
                duplicated.
        """
        annotations = resolve_annotations(
            namespace.get("__annotations__", {}), namespace.get("__module__")
        )
        config = validate_config(namespace.get("Config", BaseODMConfig), name)
        odm_fields: Dict[str, ODMBaseField] = {}
        references: List[str] = []
        bson_serialized_fields: Set[str] = set()
        mutable_fields: Set[str] = set()

        # Make sure all fields are defined with type annotation
        for field_name, value in namespace.items():
            if (
                should_touch_field(value=value)
                and not is_dunder(field_name)
                and field_name not in annotations
            ):
                raise TypeError(
                    f"field {field_name} is defined without type annotation"
                )

        # Validate fields types and substitute bson fields
        for (field_name, field_type) in annotations.items():
            if not is_dunder(field_name) and should_touch_field(type_=field_type):
                substituted_type = validate_type(field_type)
                # Handle BSON serialized fields after substitution to allow some
                # builtin substitution
                bson_serialization_method = getattr(substituted_type, "__bson__", None)
                if bson_serialization_method is not None:
                    bson_serialized_fields.add(field_name)
                annotations[field_name] = substituted_type

        # Validate fields
        for (field_name, field_type) in annotations.items():
            value = namespace.get(field_name, Undefined)

            if is_dunder(field_name) or not should_touch_field(value, field_type):
                continue  # pragma: no cover
                # https://github.com/nedbat/coveragepy/issues/198

            if isinstance(value, PDFieldInfo):
                raise TypeError("please use odmantic.Field instead of pydantic.Field")

            if is_type_mutable(field_type):
                mutable_fields.add(field_name)

            if lenient_issubclass(field_type, EmbeddedModel):
                # Embedded document: the pydantic field info (if any) replaces
                # the ODM field info in the namespace.
                if isinstance(value, ODMFieldInfo):
                    namespace[field_name] = value.pydantic_field_info
                    key_name = (
                        value.key_name if value.key_name is not None else field_name
                    )
                    primary_field = value.primary_field
                else:
                    key_name = field_name
                    primary_field = False

                odm_fields[field_name] = ODMEmbedded(
                    primary_field=primary_field,
                    model=field_type,
                    key_name=key_name,
                    model_config=config,
                )
            elif lenient_issubclass(field_type, Model):
                # Reference to another model: a reference descriptor is required.
                if not isinstance(value, ODMReferenceInfo):
                    # The f prefix is required here so that the field and class
                    # names are interpolated into the message.
                    raise TypeError(
                        f"cannot define a reference {field_name} (in {name}) without"
                        " a Reference assigned to it"
                    )
                key_name = value.key_name if value.key_name is not None else field_name
                raise_on_invalid_key_name(key_name)
                odm_fields[field_name] = ODMReference(
                    model=field_type, key_name=key_name, model_config=config
                )
                references.append(field_name)
                del namespace[field_name]  # Remove default ODMReferenceInfo value
            else:
                if isinstance(value, ODMFieldInfo):
                    key_name = (
                        value.key_name if value.key_name is not None else field_name
                    )
                    raise_on_invalid_key_name(key_name)
                    odm_fields[field_name] = ODMField(
                        primary_field=value.primary_field,
                        key_name=key_name,
                        model_config=config,
                    )
                    namespace[field_name] = value.pydantic_field_info

                elif value is Undefined:
                    odm_fields[field_name] = ODMField(
                        primary_field=False, key_name=field_name, model_config=config
                    )

                else:
                    # Plain default value: it has to validate against the
                    # annotated type.
                    try:
                        parse_obj_as(field_type, value)
                    except ValidationError:
                        raise TypeError(
                            f"Unhandled field definition {name}: {repr(field_type)}"
                            f" = {repr(value)}"
                        )
                    odm_fields[field_name] = ODMField(
                        primary_field=False, key_name=field_name, model_config=config
                    )

        duplicate_key = find_duplicate_key(odm_fields.values())
        if duplicate_key is not None:
            raise TypeError(f"Duplicated key_name: {duplicate_key} in {name}")
        # NOTE: Duplicate key detection makes sure that at most one primary key
        # is defined
        namespace["__annotations__"] = annotations
        namespace["__odm_fields__"] = odm_fields
        namespace["__references__"] = tuple(references)
        namespace["__bson_serialized_fields__"] = frozenset(bson_serialized_fields)
        namespace["__mutable_fields__"] = frozenset(mutable_fields)
        namespace["Config"] = config
Esempio n. 2
0
def should_touch_field(value: Any = None, type_: Optional[Type] = None) -> bool:
    """Tell whether a field should be processed rather than left untouched.

    The answer is ``False`` when ``type_`` is a subclass of one of
    ``UNTOUCHED_TYPES``, when ``value`` is an instance of one of them, or when
    a type is given and ``is_classvar`` reports it as a class variable.
    Otherwise the answer is ``True``.
    """
    if lenient_issubclass(type_, UNTOUCHED_TYPES):
        return False
    if isinstance(value, UNTOUCHED_TYPES):
        return False
    if type_ is not None and is_classvar(type_):
        return False
    return True
Esempio n. 3
0
    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[params.Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        name: str = None,
        methods: Optional[Union[Set[str], List[str]]] = None,
        operation_id: str = None,
        response_model_include: Set[str] = None,
        response_model_exclude: Set[str] = set(),
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        dependency_overrides_provider: Any = None,
    ) -> None:
        """Configure a route for ``endpoint`` served under ``path``.

        The constructor stores the route options, compiles the path, derives a
        unique id from the name, path format and first method, builds the
        response fields (main and additional ones), resolves the endpoint
        dependencies and finally stores the request handler in ``self.app``.

        Raises:
            AssertionError: if ``path`` does not start with ``/``, if a
                response model is declared with a non JSON response class, if
                an additional response is not a dict or its model is not a
                Pydantic model, or if ``endpoint`` is neither a function nor a
                method.
        """
        assert path.startswith("/"), "Routed paths must always start with '/'"
        self.path = path
        self.endpoint = endpoint
        if name is None:
            self.name = get_name(endpoint)
        else:
            self.name = name
        compiled = compile_path(path)
        self.path_regex, self.path_format, self.param_convertors = compiled
        if methods is None:
            methods = ["GET"]
        self.methods = {method.upper() for method in methods}
        # The unique id is built from the first method as given by the caller
        # (not upper-cased).
        first_method = list(methods)[0]
        self.unique_id = generate_operation_id_for_path(
            name=self.name, path=self.path_format, method=first_method
        )
        self.response_model = response_model
        if not self.response_model:
            self.response_field: Optional[Field] = None
            self.secure_cloned_response_field: Optional[Field] = None
        else:
            assert lenient_issubclass(
                response_class, JSONResponse
            ), "To declare a type the response must be a JSON response"
            self.response_field = Field(
                name="Response_" + self.unique_id,
                type_=self.response_model,
                class_validators={},
                default=None,
                required=False,
                model_config=BaseConfig,
                schema=Schema(None),
            )
            # Create a clone of the field, so that a Pydantic submodel is not
            # returned as is just because it's an instance of a subclass of a
            # more limited class, e.g. UserInDB (containing hashed_password)
            # could be a subclass of User that doesn't have the
            # hashed_password. But because it's a subclass, it would pass the
            # validation and be returned as is.
            # By being a new field, no inheritance will be passed as is. A new
            # model will be always created.
            self.secure_cloned_response_field = create_cloned_field(
                self.response_field
            )
        self.status_code = status_code
        self.tags = tags if tags else []
        self.dependencies = list(dependencies) if dependencies else []
        self.summary = summary
        if not description:
            # Fall back to the endpoint docstring when no description is given.
            description = inspect.cleandoc(self.endpoint.__doc__ or "")
        self.description = description
        self.response_description = response_description
        self.responses = responses if responses else {}

        additional_fields = {}
        for extra_status, extra_response in self.responses.items():
            assert isinstance(
                extra_response, dict
            ), "An additional response must be a dict"
            extra_model = extra_response.get("model")
            if not extra_model:
                continue
            assert lenient_issubclass(
                extra_model, BaseModel
            ), "A response model must be a Pydantic model"
            additional_fields[extra_status] = Field(
                name=f"Response_{extra_status}_{self.unique_id}",
                type_=extra_model,
                class_validators=None,
                default=None,
                required=False,
                model_config=BaseConfig,
                schema=Schema(None),
            )
        self.response_fields: Dict[Union[int, str], Field] = additional_fields

        self.deprecated = deprecated
        self.operation_id = operation_id
        self.response_model_include = response_model_include
        self.response_model_exclude = response_model_exclude
        self.response_model_by_alias = response_model_by_alias
        self.response_model_skip_defaults = response_model_skip_defaults
        self.include_in_schema = include_in_schema
        self.response_class = response_class

        is_callable_endpoint = inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint
        )
        assert is_callable_endpoint, "An endpoint must be a function or method"
        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
        # Route-level dependencies are inserted at the front, walking them in
        # reverse so that their declared order is kept.
        for depends in reversed(self.dependencies):
            sub_dependant = get_parameterless_sub_dependant(
                depends=depends, path=self.path_format
            )
            self.dependant.dependencies.insert(0, sub_dependant)
        self.body_field = get_body_field(
            dependant=self.dependant, name=self.unique_id
        )
        self.dependency_overrides_provider = dependency_overrides_provider
        handler = get_app(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            response_class=self.response_class,
            response_field=self.secure_cloned_response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_skip_defaults=self.response_model_skip_defaults,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )
        self.app = request_response(handler)
Esempio n. 4
0
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate the body parameters from the received body.

    Returns a ``(values, errors)`` pair: ``values`` maps field names to their
    validated (or default) values, ``errors`` collects the validation and
    missing-field errors.

    When there is a single parameter that is not embedded, the whole body is
    treated as the value of that parameter. The field info of the first
    parameter is the one used for the form/file checks of every parameter.
    """
    values = {}
    errors = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    field_info = get_field_info(first_field)
    embed = getattr(field_info, "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        # Wrap the body so that it can be looked up by alias like the others.
        received_body = {first_field.alias: received_body}

    info_is_form = isinstance(field_info, params.Form)
    info_is_file = isinstance(field_info, params.File)

    for field in required_params:
        loc: Tuple[str, ...] = (
            ("body",) if field_alias_omitted else ("body", field.alias)
        )

        value: Any = None
        if received_body is not None:
            expects_sequence = (
                field.shape in sequence_shapes or field.type_ in sequence_types
            )
            if expects_sequence and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body has no ``get``: report the field as missing.
                    errors.append(get_missing_field_error(loc))
                    continue

        is_missing = (
            value is None
            or (info_is_form and value == "")
            or (
                info_is_form
                and field.shape in sequence_shapes
                and len(value) == 0
            )
        )
        if is_missing:
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        # Uploaded files declared as bytes are read before validation.
        if info_is_file and lenient_issubclass(field.type_, bytes):
            if isinstance(value, UploadFile):
                value = await value.read()
            elif field.shape in sequence_shapes and isinstance(
                value, sequence_types
            ):
                pending_reads = [upload.read() for upload in value]
                contents = await asyncio.gather(*pending_reads)
                value = sequence_shape_to_type[field.shape](contents)

        validated, field_errors = field.validate(value, values, loc=loc)

        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(field_errors)
        else:
            values[field.name] = validated
    return values, errors
Esempio n. 5
0
def create_cloned_field(field: ModelField) -> ModelField:
    """Return a new ``ModelField`` mirroring ``field``.

    When the field type is a ``BaseModel`` subclass (or a dataclass exposing
    ``__pydantic_model__``), a new model derived from it is created and each of
    its fields is cloned recursively. Sub fields and the key field are cloned
    recursively as well; the remaining attributes are copied from ``field``
    before the validators of the new field are populated.
    """
    source_type = field.type_
    if is_dataclass(source_type) and hasattr(source_type, "__pydantic_model__"):
        source_type = source_type.__pydantic_model__  # type: ignore
    cloned_type = source_type
    if lenient_issubclass(source_type, BaseModel):
        source_type = cast(Type[BaseModel], source_type)
        cloned_type = create_model(source_type.__name__, __base__=source_type)
        for model_field in source_type.__fields__.values():
            cloned_type.__fields__[model_field.name] = create_cloned_field(
                model_field
            )

    # The field info is passed and stored under a different name depending on
    # the Pydantic version flag.
    info_attr = "field_info" if PYDANTIC_1 else "schema"
    new_field = ModelField(  # type: ignore
        name=field.name,
        type_=cloned_type,
        class_validators={},
        default=None,
        required=False,
        model_config=BaseConfig,
        **{info_attr: FieldInfo(None)},
    )

    def copy_attrs(*attr_names: str) -> None:
        # Copy the listed attributes, in order, from the original field.
        for attr_name in attr_names:
            setattr(new_field, attr_name, getattr(field, attr_name))

    copy_attrs(
        "has_alias",
        "alias",
        "class_validators",
        "default",
        "required",
        "model_config",
        info_attr,
        "allow_none",
        "validate_always",
    )
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(nested) for nested in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(field.key_field)
    copy_attrs("validators")
    if PYDANTIC_1:
        copy_attrs("pre_validators", "post_validators")
    else:  # pragma: nocover
        copy_attrs("whole_pre_validators", "whole_post_validators")
    copy_attrs("parse_json", "shape")
    try:
        new_field.populate_validators()
    except AttributeError:  # pragma: nocover
        # TODO: remove when removing support for Pydantic < 1.0.0
        new_field._populate_validators()  # type: ignore
    return new_field
Esempio n. 6
0
def get_openapi_path(
    *, route: ViewRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict]:
    """Build the OpenAPI path item and definitions for ``route``.

    Returns a ``(path, definitions)`` pair. ``path`` maps each lower-cased
    method to its operation; ``definitions`` holds the validation error
    schemas when a validation error response was added. Both are empty when
    the route is not included in the schema.

    Raises:
        AssertionError: if the route has no methods or no response class.
    """
    path = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    assert route.response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type = "application/json"
    if not route.include_in_schema:
        return path, definitions

    for method in route.methods:
        operation = get_openapi_operation_metadata(route=route, method=method)
        all_route_params = route.dependant.get_flat_params()
        parameters: List[Dict] = list(
            get_openapi_operation_parameters(
                all_route_params=all_route_params, model_name_map=model_name_map
            )
        )
        if parameters:
            # Keep a single entry per parameter name (the last one wins).
            by_name = {}
            for param in parameters:
                by_name[param["name"]] = param
            operation["parameters"] = list(by_name.values())

        if method in METHODS_WITH_BODY:
            request_body_oai = get_openapi_operation_request_body(
                body_field=route.body_field, model_name_map=model_name_map
            )
            if request_body_oai:
                operation["requestBody"] = request_body_oai

        status_code = str(route.status_code)
        response_description = route.response_description
        status_entry = operation.setdefault("responses", {}).setdefault(
            status_code, {}
        )
        status_entry["description"] = response_description

        if route.status_code not in STATUS_CODES_WITH_NO_BODY:
            if not lenient_issubclass(route.response_class, JsonResponse):
                response_schema = {"type": "string"}
            elif route.response_field:
                response_schema, _, _ = field_schema(
                    route.response_field,
                    model_name_map=model_name_map,
                    ref_prefix=REF_PREFIX,
                )
            else:
                response_schema = {}
            media_entry = status_entry.setdefault("content", {}).setdefault(
                route_response_media_type, {}
            )
            media_entry["schema"] = response_schema

        http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
        # Add the validation error response only when the operation takes
        # input and none of the covering statuses is already declared.
        takes_input = all_route_params or route.body_field
        if takes_input and not any(
            status in operation["responses"]
            for status in (http422, "4XX", "default")
        ):
            operation["responses"][http422] = {
                "description": "Validation Error",
                "content": {
                    "application/json": {
                        "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                    }
                },
            }
            if "ValidationError" not in definitions:
                definitions["ValidationError"] = validation_error_definition
                definitions[
                    "HTTPValidationError"
                ] = validation_error_response_definition
        path[method.lower()] = operation

    return path, definitions
Esempio n. 7
0
    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: List[params.Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        name: str = None,
        methods: List[str] = None,
        operation_id: str = None,
        response_model_include: Set[str] = None,
        response_model_exclude: Set[str] = set(),
        response_model_by_alias: bool = True,
        response_model_skip_defaults: bool = False,
        include_in_schema: bool = True,
        response_class: Type[Response] = JSONResponse,
        dependency_overrides_provider: Any = None,
    ) -> None:
        """Configure a route for ``endpoint`` served under ``path``.

        The constructor stores the route options, builds the response fields
        (main and additional ones) named after the route name, compiles the
        path, resolves the endpoint dependencies and finally stores the
        request handler in ``self.app``.

        Raises:
            AssertionError: if ``path`` does not start with ``/``, if a
                response model is declared with a non JSON response class, if
                an additional response is not a dict or its model is not a
                Pydantic model, or if ``endpoint`` is neither a function nor a
                method.
        """
        assert path.startswith("/"), "Routed paths must always start with '/'"
        self.path = path
        self.endpoint = endpoint
        if name is None:
            self.name = get_name(endpoint)
        else:
            self.name = name
        self.response_model = response_model
        if not self.response_model:
            self.response_field: Optional[Field] = None
        else:
            assert lenient_issubclass(
                response_class, JSONResponse
            ), "To declare a type the response must be a JSON response"
            self.response_field = Field(
                name="Response_" + self.name,
                type_=self.response_model,
                class_validators={},
                default=None,
                required=False,
                model_config=BaseConfig,
                schema=Schema(None),
            )
        self.status_code = status_code
        self.tags = tags if tags else []
        self.dependencies = dependencies if dependencies else []
        self.summary = summary
        if not description:
            # Fall back to the endpoint docstring when no description is given.
            description = inspect.cleandoc(self.endpoint.__doc__ or "")
        self.description = description
        self.response_description = response_description
        self.responses = responses if responses else {}

        additional_fields = {}
        for extra_status, extra_response in self.responses.items():
            assert isinstance(
                extra_response, dict
            ), "An additional response must be a dict"
            extra_model = extra_response.get("model")
            if not extra_model:
                continue
            assert lenient_issubclass(
                extra_model, BaseModel
            ), "A response model must be a Pydantic model"
            additional_fields[extra_status] = Field(
                name=f"Response_{extra_status}_{self.name}",
                type_=extra_model,
                class_validators=None,
                default=None,
                required=False,
                model_config=BaseConfig,
                schema=Schema(None),
            )
        self.response_fields: Dict[Union[int, str], Field] = additional_fields

        self.deprecated = deprecated
        self.methods = ["GET"] if methods is None else methods
        self.operation_id = operation_id
        self.response_model_include = response_model_include
        self.response_model_exclude = response_model_exclude
        self.response_model_by_alias = response_model_by_alias
        self.response_model_skip_defaults = response_model_skip_defaults
        self.include_in_schema = include_in_schema
        self.response_class = response_class

        compiled = compile_path(path)
        self.path_regex, self.path_format, self.param_convertors = compiled
        is_callable_endpoint = inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint
        )
        assert is_callable_endpoint, "An endpoint must be a function or method"
        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
        # Route-level dependencies are inserted at the front, walking them in
        # reverse so that their declared order is kept.
        for depends in reversed(self.dependencies):
            sub_dependant = get_parameterless_sub_dependant(
                depends=depends, path=self.path_format
            )
            self.dependant.dependencies.insert(0, sub_dependant)
        self.body_field = get_body_field(dependant=self.dependant, name=self.name)
        self.dependency_overrides_provider = dependency_overrides_provider
        handler = get_app(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            response_class=self.response_class,
            response_field=self.response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_skip_defaults=self.response_model_skip_defaults,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )
        self.app = request_response(handler)
Esempio n. 8
0
def test_lenient_issubclass_is_lenient():
    """Passing strings instead of classes is expected to give ``False``."""
    outcome = lenient_issubclass('a', 'a')
    assert outcome is False
Esempio n. 9
0
def get_dependant(*,
                  path: str,
                  call: Callable,
                  name: str = None,
                  security_scopes: List[str] = None) -> Dependant:
    """Build the ``Dependant`` describing the parameters of ``call``.

    Parameters whose default is a ``params.Depends`` are registered as sub
    dependencies first. Every parameter is then classified, in this order: path
    parameter, simple query parameter, explicit ``params.Param``, one of the
    special injected types (request, websocket, background tasks, security
    scopes) or, failing all that and unless it is a dependency, body field.

    Raises:
        AssertionError: if a path parameter or an explicit ``params.Param`` is
            annotated with an unsupported type.
    """
    path_param_names = get_path_param_names(path)
    signature_params = inspect.signature(call).parameters
    dependant = Dependant(call=call, name=name)

    for param in signature_params.values():
        if not isinstance(param.default, params.Depends):
            continue
        sub_dependant = get_param_sub_dependant(
            param=param, path=path, security_scopes=security_scopes)
        dependant.dependencies.append(sub_dependant)

    # Annotation type -> name of the dependant attribute receiving the
    # parameter name; checked in this order.
    injected_types = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (BackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )

    for param_name, param in signature_params.items():
        default = param.default
        annotation = param.annotation

        # Path parameter: named in the path, without default or with a Path.
        if (default == param.empty or isinstance(default, params.Path)) and (
                param_name in path_param_names):
            assert (
                lenient_issubclass(annotation, param_supported_types)
                or annotation == param.empty
            ), "Path params must be of one of the supported types"
            add_param_to_fields(
                param=param,
                dependant=dependant,
                default_schema=params.Path,
                force_type=params.ParamTypes.path,
            )
            continue

        # Simple query parameter: plain default and plain (or no) annotation.
        if (default == param.empty or default is None
                or isinstance(default, param_supported_types)) and (
                    annotation == param.empty or lenient_issubclass(
                        annotation, param_supported_types)):
            add_param_to_fields(param=param,
                                dependant=dependant,
                                default_schema=params.Query)
            continue

        # Explicit parameter declaration: check the annotation if there is one.
        if isinstance(default, params.Param):
            if annotation != param.empty:
                origin = getattr(annotation, "__origin__", None)
                param_all_types = param_supported_types + (list, tuple, set)
                if isinstance(default, (params.Query, params.Header)):
                    assert lenient_issubclass(
                        annotation, param_all_types
                    ) or lenient_issubclass(
                        origin, param_all_types
                    ), f"Parameters for Query and Header must be of type str, int, float, bool, list, tuple or set: {param}"
                else:
                    assert lenient_issubclass(
                        annotation, param_supported_types
                    ), f"Parameters for Path and Cookies must be of type str, int, float, bool: {param}"
            add_param_to_fields(param=param,
                                dependant=dependant,
                                default_schema=params.Query)
            continue

        for injected_type, attribute in injected_types:
            if lenient_issubclass(annotation, injected_type):
                setattr(dependant, attribute, param_name)
                break
        else:
            # Anything left that is not a dependency is read from the body.
            if not isinstance(default, params.Depends):
                add_param_to_body_fields(param=param, dependant=dependant)
    return dependant
Esempio n. 10
0
def test_lenient_issubclass_with_generic_aliases():
    """A generic alias such as ``list[str]`` is expected to give ``False``."""
    from collections.abc import Mapping

    generic_alias = list[str]
    # should not raise an error here:
    outcome = lenient_issubclass(generic_alias, Mapping)
    assert outcome is False
Esempio n. 11
0
def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
    """Build the OpenAPI path item for a single route.

    Args:
        route: The route to describe.
        model_name_map: Mapping from model types to schema names, passed
            through to the schema-building helpers.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple. ``path`` maps each
        lower-cased HTTP method of the route to its operation object,
        ``security_schemes`` collects the security definitions reported for
        the route's dependencies, and ``definitions`` collects the schema
        definitions gathered while building the operations. All three are
        empty when ``route.include_in_schema`` is false.

    Raises:
        AssertionError: If ``route.methods`` is ``None`` or an additional
            response is not a dict.
    """
    path = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(route=route, method=method)
            parameters: List[Dict] = []
            flat_dependant = get_flat_dependant(route.dependant)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_openapi_params(route.dependant)
            validation_definitions, operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params
            )
            definitions.update(validation_definitions)
            parameters.extend(operation_parameters)
            if parameters:
                operation["parameters"] = parameters
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
                    if "ValidationError" not in definitions:
                        definitions["ValidationError"] = validation_error_definition
                        definitions[
                            "HTTPValidationError"
                        ] = validation_error_response_definition
            if route.responses:
                for (additional_status_code, response) in route.responses.items():
                    assert isinstance(
                        response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    if field:
                        response_schema, _ = field_schema(
                            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                        )
                        response.setdefault("content", {}).setdefault(
                            "application/json", {}
                        )["schema"] = response_schema
                    try:
                        status_text = http.client.responses.get(
                            int(additional_status_code)
                        )
                    except ValueError:
                        # Response keys may be non-numeric strings (for example
                        # "default" or "4XX"); those have no standard reason
                        # phrase, so fall back to the generic description.
                        status_text = None
                    response.setdefault(
                        "description", status_text or "Additional Response"
                    )
                    operation.setdefault("responses", {})[
                        str(additional_status_code)
                    ] = response
            status_code = str(route.status_code)
            # Non-JSON response classes are documented as a plain string body.
            response_schema = {"type": "string"}
            if lenient_issubclass(route.response_class, JSONResponse):
                if route.response_field:
                    response_schema, _ = field_schema(
                        route.response_field,
                        model_name_map=model_name_map,
                        ref_prefix=REF_PREFIX,
                    )
                else:
                    response_schema = {}
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            operation.setdefault("responses", {}).setdefault(
                status_code, {}
            ).setdefault("content", {}).setdefault(route.response_class.media_type, {})[
                "schema"
            ] = response_schema
            # Any operation that takes parameters or a body also documents the
            # validation error response.
            if all_route_params or route.body_field:
                operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
            path[method.lower()] = operation
    return path, security_schemes, definitions
Esempio n. 12
0
def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[type, str],
    operation_ids: Set[str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI path item for a single route.

    Args:
        route: The route to describe.
        model_name_map: Mapping from model types to schema names, passed
            through to the schema-building helpers.
        operation_ids: Passed through to ``get_openapi_operation_metadata``
            and to the recursive calls made for callback routes.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple. ``path`` maps each
        lower-cased HTTP method of the route to its operation object. All
        three dicts are empty when ``route.include_in_schema`` is false.

    Raises:
        AssertionError: If ``route.methods`` is ``None``, no response class
            can be resolved, or an additional response is not a dict.
    """
    path = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    # Unwrap the placeholder to get at the actual response class.
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[
        str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids)
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant,
                                                skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant)
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                model_name_map=model_name_map)
            parameters.extend(operation_parameters)
            if parameters:
                # De-duplicate parameters by name: a later entry with the same
                # name replaces the earlier one.
                operation["parameters"] = list(
                    {param["name"]: param
                     for param in parameters}.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map)
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    # Only APIRoute callbacks are documented; each one is
                    # rendered by a recursive call. The security schemes and
                    # definitions returned for the callback are not used here.
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            model_name_map=model_name_map,
                            operation_ids=operation_ids,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                # NOTE(review): if the response class __init__ has no
                # ``status_code`` parameter with an int default, nothing below
                # assigns ``status_code`` for this iteration - confirm that
                # every supported response class provides one.
                response_signature = inspect.signature(
                    current_response_class.__init__)
                status_code_param = response_signature.parameters.get(
                    "status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(
                status_code, {})["description"] = route.response_description
            # Only document a response body when the response class declares a
            # media type and the status code is allowed to carry a body.
            if (route_response_media_type
                    and route.status_code not in STATUS_CODES_WITH_NO_BODY):
                # Non-JSON response classes are documented as a plain string.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema, _, _ = field_schema(
                            route.response_field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}).setdefault("content", {}).setdefault(
                        route_response_media_type,
                        {})["schema"] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                        additional_status_code,
                        additional_response,
                ) in route.responses.items():
                    # Work on a shallow copy so that dropping "model" does not
                    # remove the key from the route's own response dict.
                    process_response = additional_response.copy()
                    process_response.pop("model", None)
                    # Keys are upper-cased (e.g. "4xx" -> "4XX"), except for
                    # "default", which is kept lower-case.
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {})
                    assert isinstance(
                        process_response,
                        dict), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema, _, _ = field_schema(
                            field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX)
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (process_response.setdefault(
                            "content",
                            {}).setdefault(media_type,
                                           {}).setdefault("schema", {}))
                        deep_dict_update(additional_schema,
                                         additional_field_schema)
                    # The int() conversion is only evaluated when the range
                    # lookup returns a falsy value.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    # Description precedence: the additional response's own
                    # text, then one already present on the operation, then
                    # the status text, then a generic fallback.
                    description = (process_response.get("description")
                                   or openapi_response.get("description")
                                   or status_text or "Additional Response")
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Add the validation error response unless the operation already
            # documents 422, "4XX" or "default".
            if (all_route_params or route.body_field) and not any([
                    status in operation["responses"]
                    for status in [http422, "4XX", "default"]
            ]):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {
                                "$ref": REF_PREFIX + "HTTPValidationError"
                            }
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update({
                        "ValidationError":
                        validation_error_definition,
                        "HTTPValidationError":
                        validation_error_response_definition,
                    })
            # Extra OpenAPI data on the route is merged in last.
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
Esempio n. 13
0
    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Type[Any] = None,
        status_code: int = 200,
        tags: List[str] = None,
        dependencies: Sequence[params.Depends] = None,
        summary: str = None,
        description: str = None,
        response_description: str = "Successful Response",
        responses: Dict[Union[int, str], Dict[str, Any]] = None,
        deprecated: bool = None,
        name: str = None,
        methods: Optional[Union[Set[str], List[str]]] = None,
        operation_id: str = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        include_in_schema: bool = True,
        response_class: Optional[Type[Response]] = None,
        dependency_overrides_provider: Any = None,
    ) -> None:
        """Set up a route for ``endpoint`` served at ``path``.

        Args:
            path: The route path; it is compiled into a regex, a format
                string and parameter convertors.
            endpoint: The function or method handling the request.
            response_model: Optional model for the main response.
            status_code: Status code of the main response.
            tags: Tags stored on the route (empty list when omitted).
            dependencies: Extra dependencies, inserted ahead of the ones
                declared by the endpoint, preserving their order.
            summary: Summary stored on the route.
            description: Description; defaults to the cleaned endpoint
                docstring. Anything after the first form feed is dropped.
            response_description: Description of the main response.
            responses: Additional responses keyed by status code; an entry
                may carry a ``"model"`` key.
            deprecated: Deprecation flag stored on the route.
            name: Route name; defaults to the name derived from the endpoint.
            methods: HTTP methods; defaults to ``["GET"]``.
            operation_id: Operation id stored on the route.
            response_model_include: Stored on the route as is.
            response_model_exclude: Stored on the route; a new empty set is
                used when omitted.
            response_model_by_alias: Stored on the route as is.
            response_model_exclude_unset: Stored on the route as is.
            include_in_schema: Stored on the route as is.
            response_class: Response class stored on the route.
            dependency_overrides_provider: Stored on the route as is.

        Raises:
            AssertionError: If a response model is combined with a status
                code that must not have a body, an additional response is
                not a dict, an additional model is not a Pydantic model, or
                the endpoint is neither a function nor a method.
        """
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            path)
        if methods is None:
            methods = ["GET"]
        self.methods = {method.upper() for method in methods}
        self.unique_id = generate_operation_id_for_path(
            name=self.name, path=self.path_format, method=list(methods)[0])
        self.response_model = response_model
        if self.response_model:
            assert (
                status_code not in STATUS_CODES_WITH_NO_BODY
            ), f"Status code {status_code} must not have a response body"
            response_name = "Response_" + self.unique_id
            if PYDANTIC_1:
                self.response_field: Optional[ModelField] = ModelField(
                    name=response_name,
                    type_=self.response_model,
                    class_validators={},
                    default=None,
                    required=False,
                    model_config=BaseConfig,
                    field_info=FieldInfo(None),
                )
            else:
                self.response_field: Optional[
                    ModelField] = ModelField(  # type: ignore  # pragma: nocover
                        name=response_name,
                        type_=self.response_model,
                        class_validators={},
                        default=None,
                        required=False,
                        model_config=BaseConfig,
                        schema=FieldInfo(None),
                    )
            # Create a clone of the field, so that a Pydantic submodel is not returned
            # as is just because it's an instance of a subclass of a more limited class
            # e.g. UserInDB (containing hashed_password) could be a subclass of User
            # that doesn't have the hashed_password. But because it's a subclass, it
            # would pass the validation and be returned as is.
            # By being a new field, no inheritance will be passed as is. A new model
            # will be always created.
            self.secure_cloned_response_field: Optional[
                ModelField] = create_cloned_field(self.response_field)
        else:
            self.response_field = None
            self.secure_cloned_response_field = None
        self.status_code = status_code
        self.tags = tags or []
        self.dependencies = list(dependencies) if dependencies else []
        self.summary = summary
        self.description = description or inspect.cleandoc(
            self.endpoint.__doc__ or "")
        # if a "form feed" character (page break) is found in the description text,
        # truncate description text to the content preceding the first "form feed"
        self.description = self.description.split("\f")[0]
        self.response_description = response_description
        self.responses = responses or {}
        # Build one field per additional response that declares a model.
        response_fields: Dict[Union[int, str], ModelField] = {}
        for additional_status_code, response in self.responses.items():
            assert isinstance(response,
                              dict), "An additional response must be a dict"
            model = response.get("model")
            if model:
                assert (
                    additional_status_code not in STATUS_CODES_WITH_NO_BODY
                ), f"Status code {additional_status_code} must not have a response body"
                assert lenient_issubclass(
                    model,
                    BaseModel), "A response model must be a Pydantic model"
                response_name = f"Response_{additional_status_code}_{self.unique_id}"
                if PYDANTIC_1:
                    response_field = ModelField(
                        name=response_name,
                        type_=model,
                        class_validators=None,
                        default=None,
                        required=False,
                        model_config=BaseConfig,
                        field_info=FieldInfo(None),
                    )
                else:
                    response_field = ModelField(  # type: ignore  # pragma: nocover
                        name=response_name,
                        type_=model,
                        class_validators=None,
                        default=None,
                        required=False,
                        model_config=BaseConfig,
                        schema=FieldInfo(None),
                    )
                response_fields[additional_status_code] = response_field
        self.response_fields: Dict[Union[int, str],
                                   ModelField] = response_fields
        self.deprecated = deprecated
        self.operation_id = operation_id
        self.response_model_include = response_model_include
        # A fresh set per route: a ``set()`` default in the signature would be
        # one object shared by every route created without this argument.
        self.response_model_exclude = (set() if response_model_exclude is None
                                       else response_model_exclude)
        self.response_model_by_alias = response_model_by_alias
        self.response_model_exclude_unset = response_model_exclude_unset
        self.include_in_schema = include_in_schema
        self.response_class = response_class

        assert inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint), "An endpoint must be a function or method"
        self.dependant = get_dependant(path=self.path_format,
                                       call=self.endpoint)
        # Insert in reverse at position 0 so the route-level dependencies end
        # up first, in their declared order.
        for depends in self.dependencies[::-1]:
            self.dependant.dependencies.insert(
                0,
                get_parameterless_sub_dependant(depends=depends,
                                                path=self.path_format),
            )
        self.body_field = get_body_field(dependant=self.dependant,
                                         name=self.unique_id)
        self.dependency_overrides_provider = dependency_overrides_provider
        self.app = request_response(self.get_route_handler())
Esempio n. 14
0
def is_scalar_field(field: Field) -> bool:
    """Tell whether ``field`` holds a single, non-model, non-container value.

    A field qualifies only when its shape is ``Shape.SINGLETON``, its type is
    neither a ``BaseModel`` subclass nor one of the sequence types or ``dict``,
    and its schema is not a ``params.Body``.
    """
    if field.shape != Shape.SINGLETON:
        return False
    if lenient_issubclass(field.type_, BaseModel):
        return False
    container_types = sequence_types + (dict, )
    if lenient_issubclass(field.type_, container_types):
        return False
    return not isinstance(field.schema, params.Body)
Esempio n. 15
0
        def __new__(mcs, name, bases, namespace, **kwargs):
            """Create the class, collecting its pydantic config, fields and validators.

            Config, fields and validators are inherited from the bases that are
            strict subclasses of ``AbstractCheckedSession`` and then completed
            with what the class namespace declares. The results are stored in
            the new class namespace under ``__config__``, ``__fields__``,
            ``__field_defaults__`` and ``__validators__``; namespace entries
            that became fields are left out of the namespace.

            Raises:
                TypeError: If a non-annotated value redefines an existing field
                    with a different inferred type.
            """
            from pydantic.fields import Undefined
            from pydantic.class_validators import extract_validators, inherit_validators
            from pydantic.types import PyObject
            from pydantic.typing import is_classvar, resolve_annotations
            from pydantic.utils import lenient_issubclass, validate_field_name
            from pydantic.main import inherit_config, prepare_config, UNTOUCHED_TYPES

            fields: Dict[str, ModelField] = {}
            config = BaseConfig
            validators: Dict[str, List[Validator]] = {}

            # Walk the bases in reverse so that earlier bases are applied last.
            for base in reversed(bases):
                if issubclass(base, AbstractCheckedSession) and base != AbstractCheckedSession:
                    config = inherit_config(base.__config__, config)
                    # deep copy so the fields of the base class are not shared
                    fields.update(deepcopy(base.__fields__))
                    validators = inherit_validators(base.__validators__, validators)

            # Settings declared on the class itself are applied on top of the inherited ones.
            config = inherit_config(namespace.get('Config'), config)
            validators = inherit_validators(extract_validators(namespace), validators)

            # update fields inherited from base classes
            for field in fields.values():
                field.set_config(config)
                extra_validators = validators.get(field.name, [])
                if extra_validators:
                    field.class_validators.update(extra_validators)
                    # re-run prepare to add extra validators
                    field.populate_validators()

            prepare_config(config, name)

            # extract and build fields
            class_vars = set()
            # no fields are extracted for the CheckedSession class itself
            if (namespace.get('__module__'), namespace.get('__qualname__')) != \
                    ('larray.core.checked', 'CheckedSession'):
                untouched_types = UNTOUCHED_TYPES + config.keep_untouched

                # annotation only fields need to come first in fields
                annotations = resolve_annotations(namespace.get('__annotations__', {}),
                                                  namespace.get('__module__', None))
                for ann_name, ann_type in annotations.items():
                    if is_classvar(ann_type):
                        class_vars.add(ann_name)
                    elif not ann_name.startswith('_'):
                        validate_field_name(bases, ann_name)
                        value = namespace.get(ann_name, Undefined)
                        # values of an untouched type are not turned into fields, unless the
                        # annotation is PyObject or a parametrized Type
                        if (isinstance(value, untouched_types) and ann_type != PyObject
                                and not lenient_issubclass(getattr(ann_type, '__origin__', None), Type)):
                            continue
                        fields[ann_name] = ModelField.infer(name=ann_name, value=value, annotation=ann_type,
                                                            class_validators=validators.get(ann_name, []),
                                                            config=config)

                for var_name, value in namespace.items():
                    # 'var_name not in annotations' because namespace.items() contains annotated fields
                    # with default values
                    # 'var_name not in class_vars' to avoid to update a field if it was redeclared (by mistake)
                    if (var_name not in annotations and not var_name.startswith('_')
                            and not isinstance(value, untouched_types) and var_name not in class_vars):
                        validate_field_name(bases, var_name)
                        # the method ModelField.infer() fails to infer the type of Group objects
                        # (which are interpreted as ndarray objects)
                        annotation = type(value) if isinstance(value, Group) else annotations.get(var_name)
                        inferred = ModelField.infer(name=var_name, value=value, annotation=annotation,
                                                    class_validators=validators.get(var_name, []), config=config)
                        # a new default value may not silently change the type of an existing field
                        if var_name in fields and inferred.type_ != fields[var_name].type_:
                            raise TypeError(f'The type of {name}.{var_name} differs from the new default value; '
                                            f'if you wish to change the type of this field, please use a type '
                                            f'annotation')
                        fields[var_name] = inferred

            # names that became fields are removed from the class namespace
            new_namespace = {
                '__config__': config,
                '__fields__': fields,
                '__field_defaults__': {n: f.default for n, f in fields.items() if not f.required},
                '__validators__': validators,
                **{n: v for n, v in namespace.items() if n not in fields},
            }
            return super().__new__(mcs, name, bases, new_namespace, **kwargs)
Esempio n. 16
0
def test_lenient_issubclass():
    """A direct subclass of ``str`` is reported as a subclass of ``str``."""

    class StrChild(str):
        pass

    outcome = lenient_issubclass(StrChild, str)
    assert outcome is True
Esempio n. 17
0
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Read the body parameters out of a received body and validate them.

    Args:
        required_params: The body fields to extract.
        received_body: The received body (a mapping or form data), or
            ``None`` when there is none.

    Returns:
        A ``(values, errors)`` tuple. ``values`` maps field names to their
        validated value, or to a deep copy of the field default when the
        value is missing and the field is not required. ``errors`` collects
        the missing-field and validation errors.
    """
    values = {}
    errors = []
    if required_params:
        field = required_params[0]
        # NOTE(review): field_info is taken from the first parameter only and
        # is reused for every field in the loop below - confirm this is the
        # intended behavior when parameters of different kinds are mixed.
        field_info = field.field_info
        embed = getattr(field_info, "embed", None)
        # A single, non-embedded parameter receives the whole body as its
        # value, so the body is wrapped under that parameter's alias.
        field_alias_omitted = len(required_params) == 1 and not embed
        if field_alias_omitted:
            received_body = {field.alias: received_body}

        for field in required_params:
            loc: Tuple[str, ...]
            if field_alias_omitted:
                loc = ("body", )
            else:
                loc = ("body", field.alias)

            value: Optional[Any] = None
            if received_body is not None:
                if (field.shape in sequence_shapes
                        or field.type_ in sequence_types) and isinstance(
                            received_body, FormData):
                    value = received_body.getlist(field.alias)
                else:
                    try:
                        value = received_body.get(field.alias)
                    except AttributeError:
                        # The body has no ``get`` method, so the field
                        # cannot be looked up in it.
                        errors.append(get_missing_field_error(loc))
                        continue
            if (value is None
                    or (isinstance(field_info, params.Form) and value == "") or
                (isinstance(field_info, params.Form)
                 and field.shape in sequence_shapes and len(value) == 0)):
                if field.required:
                    errors.append(get_missing_field_error(loc))
                else:
                    values[field.name] = deepcopy(field.default)
                continue
            if (isinstance(field_info, params.File)
                    and lenient_issubclass(field.type_, bytes)
                    and isinstance(value, UploadFile)):
                value = await value.read()
            elif (field.shape in sequence_shapes
                  and isinstance(field_info, params.File)
                  and lenient_issubclass(field.type_, bytes)
                  and isinstance(value, sequence_types)):
                # One slot per upload: the reads run concurrently and may
                # finish in any order, so each result is stored at the index
                # of its upload instead of being appended on completion.
                results: List[Union[bytes, str]] = [b""] * len(value)

                async def process_fn(
                        index: int,
                        fn: Callable[[], Coroutine[Any, Any, Any]]) -> None:
                    results[index] = await fn()

                async with anyio.create_task_group() as tg:
                    for index, sub_value in enumerate(value):
                        tg.start_soon(process_fn, index, sub_value.read)
                value = sequence_shape_to_type[field.shape](results)

            v_, errors_ = field.validate(value, values, loc=loc)

            if isinstance(errors_, ErrorWrapper):
                errors.append(errors_)
            elif isinstance(errors_, list):
                errors.extend(errors_)
            else:
                values[field.name] = v_
    return values, errors
Esempio n. 18
0
def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
    """Build the OpenAPI path item for a single route.

    One operation object is generated per HTTP method in ``route.methods`` and
    stored under the lower-cased method name. When ``route.include_in_schema``
    is false nothing is generated and all three returned dicts are empty.

    Args:
        route: The route to describe. ``route.methods`` must not be ``None``
            and ``route.response_class`` must be set.
        model_name_map: Mapping from model type to schema name; it is passed
            through to the schema helpers.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple where ``path`` maps
        lower-cased method names to operation dicts, ``security_schemes``
        collects the security definitions found on the route's dependencies,
        and ``definitions`` holds the validation-error schemas when a 422
        response was added.

    Raises:
        AssertionError: If ``route.methods`` is ``None``, the route has no
            response class, or an additional response is not a ``dict``.
    """
    # Imported locally so the block does not rely on a module-level import.
    from copy import deepcopy

    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    assert route.response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = route.response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(route=route, method=method)
            parameters: List[Dict] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_openapi_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(all_route_params)
            parameters.extend(operation_parameters)
            if parameters:
                operation["parameters"] = parameters
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    # Only the callback's path item is embedded here; its
                    # security schemes and definitions are not used.
                    cb_path, _, _ = get_openapi_path(
                        route=callback, model_name_map=model_name_map
                    )
                    callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.responses:
                for (additional_status_code, response) in route.responses.items():
                    assert isinstance(
                        response, dict
                    ), "An additional response must be a dict"
                    # Work on a private copy: the dict is filled in below and
                    # again by the main-response handling, and the caller's
                    # ``route.responses`` entries must not be modified by
                    # schema generation.
                    process_response = deepcopy(response)
                    field = route.response_fields.get(additional_status_code)
                    if field:
                        response_schema, _, _ = field_schema(
                            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                        )
                        process_response.setdefault("content", {}).setdefault(
                            route_response_media_type or "application/json", {}
                        )["schema"] = response_schema
                    status_code_key = str(additional_status_code).upper()
                    # Range/default keys are looked up first, so ``int()`` is
                    # only reached when the key is not in status_code_ranges.
                    status_text: Optional[str] = status_code_ranges.get(
                        status_code_key
                    ) or http.client.responses.get(int(additional_status_code))
                    process_response.setdefault(
                        "description", status_text or "Additional Response"
                    )
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    operation.setdefault("responses", {})[
                        status_code_key
                    ] = process_response
            status_code = str(route.status_code)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if (
                route_response_media_type
                and route.status_code not in STATUS_CODES_WITH_NO_BODY
            ):
                response_schema = {"type": "string"}
                if lenient_issubclass(route.response_class, JSONResponse):
                    if route.response_field:
                        response_schema, _, _ = field_schema(
                            route.response_field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema

            # Add the default validation-error response unless a 422, "4XX"
            # or "default" response is already present on the operation.
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in (http422, "4XX", "default")
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            path[method.lower()] = operation
    return path, security_schemes, definitions
Esempio n. 19
0
    def __init__(
        self,
        path: str,
        endpoint: Callable,
        *,
        response_model: Optional[Type[BaseModel]] = None,
        status_code: int = 200,
        tags: Optional[List[str]] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        response_description: str = "Successful Response",
        deprecated: Optional[bool] = None,
        name: Optional[str] = None,
        methods: Optional[List[str]] = None,
        operation_id: Optional[str] = None,
        include_in_schema: bool = True,
        content_type: Type[Response] = JSONResponse,
    ) -> None:
        """Store the route configuration and build ``self.app`` for it.

        Args:
            path: Route path; must start with ``"/"``.
            endpoint: The callable handling the route; must be a function or
                a method.
            response_model: Optional model for the response. When given,
                ``content_type`` must be a ``JSONResponse`` subclass and a
                response field named ``"Response_" + name`` is created.
            status_code: Status code passed on to ``get_app``.
            tags: Stored as given, or as an empty list when falsy.
            summary: Stored as given.
            description: Stored as given; falls back to the endpoint's
                ``__doc__`` when falsy.
            response_description: Stored as given.
            deprecated: Stored as given.
            name: Route name; defaults to ``get_name(endpoint)``.
            methods: HTTP methods; defaults to ``["GET"]``.
            operation_id: Stored as given.
            include_in_schema: Stored as given.
            content_type: Response class passed on to ``get_app``.

        Raises:
            AssertionError: If ``path`` does not start with ``"/"``, if a
                response model is combined with a non-JSON response class, or
                if ``endpoint`` is neither a function nor a method.
        """
        assert path.startswith("/"), "Routed paths must always start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.response_model = response_model
        if self.response_model:
            assert lenient_issubclass(
                content_type, JSONResponse
            ), "To declare a type the response must be a JSON response"
            response_name = "Response_" + self.name
            self.response_field: Optional[Field] = Field(
                name=response_name,
                type_=self.response_model,
                class_validators=[],
                default=None,
                required=False,
                model_config=UnconstrainedConfig,
                schema=Schema(None),
            )
        else:
            self.response_field = None
        self.status_code = status_code
        self.tags = tags or []
        self.summary = summary
        self.description = description or self.endpoint.__doc__
        self.response_description = response_description
        self.deprecated = deprecated
        self.methods = ["GET"] if methods is None else methods
        self.operation_id = operation_id
        self.include_in_schema = include_in_schema
        self.content_type = content_type

        self.path_regex, self.path_format, self.param_convertors = self.compile_path(
            path
        )
        assert inspect.isfunction(endpoint) or inspect.ismethod(
            endpoint
        ), "An endpoint must be a function or method"
        # The dependency tree and body field are derived from the endpoint and
        # then handed to get_app together with the response settings.
        self.dependant = get_dependant(path=path, call=self.endpoint)
        self.body_field = get_body_field(dependant=self.dependant, name=self.name)
        self.app = request_response(
            get_app(
                dependant=self.dependant,
                body_field=self.body_field,
                status_code=self.status_code,
                content_type=self.content_type,
                response_field=self.response_field,
            )
        )
Esempio n. 20
0
async def request_body_to_args(
    required_params: List[Field],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Extract and validate body parameters from a received request body.

    When there is exactly one parameter and its schema does not set ``embed``,
    the whole body is treated as that parameter's value. Otherwise each
    parameter is looked up in the body by its alias.

    A value counts as missing when it is ``None``, or, for ``Form`` fields,
    when it is an empty string or an empty sequence. Missing required fields
    produce a ``MissingError``; missing optional fields get a deep copy of
    their default. ``File`` fields typed as ``bytes`` have their uploaded
    file(s) read before validation.

    Args:
        required_params: The body fields to extract.
        received_body: The decoded body, or ``None`` when there is none.

    Returns:
        A ``(values, errors)`` tuple: validated values keyed by field name,
        and the list of collected validation errors.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.schema, "embed", None)
    if len(required_params) == 1 and not embed:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        alias = field.alias
        loc = ("body", alias)
        value: Any = None
        if received_body is not None:
            if field.shape in sequence_shapes and isinstance(
                received_body, FormData
            ):
                value = received_body.getlist(alias)
            else:
                try:
                    value = received_body.get(alias)
                except AttributeError:
                    # The body is not a mapping (e.g. a JSON array or scalar
                    # was sent where several body fields are expected), so
                    # this field cannot be present: report it as missing
                    # instead of letting the AttributeError propagate.
                    errors.append(
                        ErrorWrapper(MissingError(), loc=loc, config=BaseConfig)
                    )
                    continue
        if (
            value is None
            or (isinstance(field.schema, params.Form) and value == "")
            or (
                isinstance(field.schema, params.Form)
                and field.shape in sequence_shapes
                and len(value) == 0
            )
        ):
            if field.required:
                errors.append(
                    ErrorWrapper(MissingError(), loc=loc, config=BaseConfig)
                )
            else:
                # Deep copy so a mutable default is never shared between calls.
                values[field.name] = deepcopy(field.default)
            continue
        if (
            isinstance(field.schema, params.File)
            and lenient_issubclass(field.type_, bytes)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            field.shape in sequence_shapes
            and isinstance(field.schema, params.File)
            and lenient_issubclass(field.type_, bytes)
            and isinstance(value, sequence_types)
        ):
            # Read all uploads concurrently; gather keeps the input order.
            awaitables = [sub_value.read() for sub_value in value]
            contents = await asyncio.gather(*awaitables)
            value = sequence_shape_to_type[field.shape](contents)
        v_, errors_ = field.validate(value, values, loc=loc)
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[field.name] = v_
    return values, errors
Esempio n. 21
0
def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
    """Build the OpenAPI path item for a single route.

    An operation dict is created for every method in ``route.methods`` and
    stored under the lower-cased method name.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple.

    Raises:
        AssertionError: If ``route.methods`` is ``None``.
    """
    assert route.methods is not None, "Methods must be a list"
    path = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    for method in route.methods:
        operation = get_openapi_operation_metadata(route=route, method=method)

        # Security requirements declared by the route's dependencies.
        security_definitions, operation_security = get_openapi_security_definitions(
            flat_dependant=get_flat_dependant(route.dependant)
        )
        if operation_security:
            operation.setdefault("security", []).extend(operation_security)
        if security_definitions:
            security_schemes.update(security_definitions)

        # Non-body parameters and the definitions that come with them.
        all_route_params = get_openapi_params(route.dependant)
        validation_definitions, operation_parameters = get_openapi_operation_parameters(
            all_route_params=all_route_params
        )
        definitions.update(validation_definitions)
        parameters: List[Dict] = list(operation_parameters)
        if parameters:
            operation["parameters"] = parameters

        # Request body, only for methods that can carry one.
        if method in METHODS_WITH_BODY:
            request_body_oai = get_openapi_operation_request_body(
                body_field=route.body_field, model_name_map=model_name_map
            )
            if request_body_oai:
                operation["requestBody"] = request_body_oai
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )

        # Schema of the main response: a plain string for non-JSON response
        # classes, the response field's schema when one is declared, and an
        # empty schema otherwise.
        if not lenient_issubclass(route.content_type, JSONResponse):
            response_schema = {"type": "string"}
        elif route.response_field:
            response_schema, _ = field_schema(
                route.response_field,
                model_name_map=model_name_map,
                ref_prefix=REF_PREFIX,
            )
        else:
            response_schema = {}

        responses = {
            str(route.status_code): {
                "description": route.response_description,
                "content": {
                    route.content_type.media_type: {"schema": response_schema}
                },
            }
        }
        if all_route_params or route.body_field:
            responses[str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
                "description": "Validation Error",
                "content": {
                    "application/json": {
                        "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                    }
                },
            }
        operation["responses"] = responses
        path[method.lower()] = operation
    return path, security_schemes, definitions