Example #1
0
    def modify_signature(
        endpoint: t.Callable,
        model: t.Type[PydanticBaseModel],
        http_method: HTTPMethod,
        allow_pagination: bool = False,
        allow_ordering: bool = False,
    ):
        """
        Modify the endpoint's signature, so FastAPI can correctly extract the
        schema from it. GET endpoints are given more filters.

        The new signature always starts with a ``request`` parameter
        annotated as ``Request``. Each field in ``model.__fields__`` then
        becomes an optional query parameter (default ``None``) annotated with
        the field's ``outer_type_``. Fields whose type is ``int``, ``float``
        or ``Decimal`` also get a ``<field>__operator`` parameter, and
        ``str`` fields get a ``<field>__match`` parameter.

        When ``http_method`` is ``HTTPMethod.get``:

        * ``__order`` is added if ``allow_ordering`` is set.
        * ``__page_size``, ``__page`` and ``__visible_fields`` are added if
          ``allow_pagination`` is set.
        * ``__range_header`` and ``__range_header_name`` are always added.

        :param endpoint:
            The callable whose ``__signature__`` is replaced in place.
        :param model:
            The Pydantic model whose fields become filter parameters.
        :param http_method:
            The endpoint's HTTP method; only GET receives the ordering,
            pagination and range parameters.
        :param allow_pagination:
            Add the pagination parameters (GET only).
        :param allow_ordering:
            Add the ``__order`` parameter (GET only).
        """

        def query_param(
            name: str,
            description: str,
            annotation: t.Any = Parameter.empty,
            default: t.Any = None,
        ) -> Parameter:
            # Build an optional query parameter. Leaving ``annotation`` as
            # ``Parameter.empty`` is the same as not passing one at all.
            return Parameter(
                name=name,
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                default=Query(default=default, description=description),
                annotation=annotation,
            )

        parameters = [
            Parameter(
                name="request",
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                annotation=Request,
            ),
        ]

        for field_name, _field in model.__fields__.items():
            type_ = _field.outer_type_
            parameters.append(
                query_param(
                    name=field_name,
                    annotation=type_,
                    description=f"Filter by the `{field_name}` column.",
                )
            )

            if type_ in (int, float, Decimal):
                parameters.append(
                    query_param(
                        name=f"{field_name}__operator",
                        description=(
                            f"Which operator to use for `{field_name}`. "
                            "The options are `e` (equals - default), `lt`, "
                            "`lte`, `gt`, and `gte`."
                        ),
                    )
                )

            if type_ is str:
                parameters.append(
                    query_param(
                        name=f"{field_name}__match",
                        description=(
                            f"Specifies how `{field_name}` should be "
                            "matched - `contains` (default), `exact`, "
                            "`starts`, `ends`."
                        ),
                    )
                )

        if http_method == HTTPMethod.get:
            if allow_ordering:
                parameters.append(
                    query_param(
                        name="__order",
                        annotation=str,
                        description=(
                            "Specifies which field to sort the "
                            "results by. For example `id` to sort by "
                            "id, and `-id` for descending."
                        ),
                    )
                )

            if allow_pagination:
                parameters.extend(
                    [
                        query_param(
                            name="__page_size",
                            annotation=int,
                            description="The number of results to return.",
                        ),
                        query_param(
                            name="__page",
                            annotation=int,
                            description=(
                                "Which page of results to return (default "
                                "1)."
                            ),
                        ),
                        query_param(
                            name="__visible_fields",
                            annotation=str,
                            description=(
                                "The fields to return. It's a comma "
                                "separated list - for example "
                                "'name,address'. By default all fields "
                                "are returned."
                            ),
                        ),
                    ]
                )

            # The Content-Range options are offered on every GET endpoint,
            # independently of ``allow_pagination``.
            parameters.extend(
                [
                    query_param(
                        name="__range_header",
                        annotation=bool,
                        default=False,
                        description=(
                            "Set to 'true' to add the "
                            "Content-Range response header"
                        ),
                    ),
                    query_param(
                        name="__range_header_name",
                        annotation=str,
                        description=(
                            "Specify the object name in the Content-Range "
                            "response header (defaults to the table name)."
                        ),
                    ),
                ]
            )

        endpoint.__signature__ = Signature(  # type: ignore
            parameters=parameters
        )
Example #2
0
def update_wrapper(
    wrapper: T.Callable,
    wrapped: T.Callable,
    signature: T.Union[_FullerSig, None, bool] = True,  # not in functools
    docstring: T.Union[str, bool] = True,  # not in functools
    assigned: T.Sequence[str] = WRAPPER_ASSIGNMENTS,
    updated: T.Sequence[str] = WRAPPER_UPDATES,
    # docstring options
    _doc_fmt: T.Optional[dict] = None,  # not in functools
    _doc_style: T.Union[str, T.Callable, None] = None,
):
    """Update a wrapper function to look like the wrapped function.

    Parameters
    ----------
    wrapper : Callable
        the function to be updated
    wrapped : Callable
        the original function
    signature : Signature or None or bool, optional
        signature to impose on `wrapper`.
        None and False default to `wrapped`'s signature.
        True merges `wrapper` and `wrapped` kwdefaults & annotations
    docstring : str or bool, optional
        docstring to impose on `wrapper`.
        False ignores `wrapper`'s docstring, using only `wrapped`'s docstring.
        None merges the `wrapper` and `wrapped` docstring
    assigned : tuple, optional
        tuple naming the attributes assigned directly
        from the wrapped function to the wrapper function (defaults to
        ``functools.WRAPPER_ASSIGNMENTS``)
    updated : tuple, optional
        is a tuple naming the attributes of the wrapper that
        are updated with the corresponding attribute from the wrapped
        function (defaults to ``functools.WRAPPER_UPDATES``)
    _doc_fmt : dict, optional
        dictionary to format wrapper docstring. It is copied before use, so
        the caller's dict is never modified.
    _doc_style: str or Callable, optional
        the style of the docstring
        if None (default), appends `wrapper` docstring
        if str or Callable, merges the docstring

    Returns
    -------
    wrapper : Callable
        `wrapper` function updated by the `wrapped` function's attributes and
        also the provided `signature` and `docstring`.

    Raises
    ------
    ValueError
        if docstring is True

    """
    # ---------------------------------------
    # preamble

    # Normalise ``signature``; ``_update_sig`` selects the merge branch below.
    signature, _update_sig = __parse_sig_for_update_wrapper(signature, wrapped)

    # need to get wrapper properties now, before `wrapped`'s attributes
    # (e.g. ``__doc__``) are copied onto `wrapper` below
    wrapper_sig = _FullerSig.from_callable(wrapper)

    wrapper_doc = _nspct.getdoc(wrapper) or ""
    wrapper_doc = "\n".join(wrapper_doc.split("\n")[1:])  # drop title

    # Work on a private copy: the signature handling below adds keys to this
    # dict (keyword-only defaults, and whatever the merge helper records),
    # and those additions must not leak back into the caller's dict.
    _doc_fmt = {} if _doc_fmt is None else dict(_doc_fmt)

    # ---------------------------------------
    # update wrapper (same as functools.update_wrapper)

    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            pass
        else:
            setattr(wrapper, attr, value)

    for attr in updated:  # update whole dictionary
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))

    # ---------------------------------------

    # deal with signature
    if signature in (None, False):
        pass

    elif _update_sig:  # merge wrapped and wrapper signature

        signature = __update_wrapper_update_sig(
            signature, wrapper_sig, _doc_fmt
        )

        for attr in SIGNATURE_ASSIGNMENTS:
            value = getattr(signature, attr)
            setattr(wrapper, attr, value)

        wrapper.__signature__ = signature.signature

    else:  # a signature object
        for attr in SIGNATURE_ASSIGNMENTS:
            _value = getattr(signature, attr)
            setattr(wrapper, attr, _value)

        # for docstring
        for param in wrapper_sig.parameters.values():
            # can only merge keyword-only
            if param.kind == _nspct.KEYWORD_ONLY:
                _doc_fmt[param.name] = param.default

        wrapper.__signature__ = signature.signature

    # ---------------------------------------
    # docstring

    if _doc_fmt:  # (not empty dict)
        wrapper_doc = _FormatTemplate(wrapper_doc).safe_substitute(**_doc_fmt)

    # NOTE(review): the docstring above says ``docstring=True`` raises
    # ValueError, yet True is the default — confirm the actual contract in
    # ``__update_wrapper_docstring``.
    wrapper.__doc__ = __update_wrapper_docstring(
        wrapped,
        docstring=docstring,
        wrapper_doc=wrapper_doc,
        _doc_style=_doc_style,
    )

    # Issue #17482: set __wrapped__ last so we don't inadvertently copy it
    # from the wrapped function when updating __dict__
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper
Example #3
0
def inject_loop(func: typing.Callable) -> typing.Callable:
    """
    Add the main event loop to the decorated function.

    Requires a parameter: ``loop`` to be existing in the function. Will ensure
    that this parameter has the asyncio event loop injected into it. The
    exposed signature gives ``loop`` a default of ``None``. Whenever a call
    leaves ``loop`` as ``None``, the value returned by ``_get_loop()`` is
    passed instead.

    Arguments are re-bound against the signature and forwarded with their
    original kinds. This means positional-only parameters, ``*args`` and
    ``**kwargs`` work as they would when calling ``func`` directly.

    Args:
        func (Callable): The callable being decorated. It must have a ``loop``
        argument to be decorated; without one, calls raise ``KeyError``.

    Returns:
        Callable: The decorated callable. Async generator functions,
        coroutine functions and generator functions get a wrapper of the
        same kind.

    Raises:
        ValueError: from ``Signature.replace`` if ``loop`` is positional and
        a later positional parameter has no default.
    """
    sig = inspect.signature(func)
    sig = sig.replace(
        parameters=[
            value.replace(default=None) if key == 'loop' else value
            for key, value in sig.parameters.items()
        ]
    )
    # Also expose the relaxed signature on ``func``; functools.wraps below
    # copies ``func.__dict__`` (including this attribute) onto the wrapper.
    func.__signature__ = sig  # type: ignore

    def add_loop(
        args: typing.Tuple[typing.Any, ...],
        kwargs: typing.Dict[str, typing.Any]
    ) -> inspect.BoundArguments:
        bargs = sig.bind(*args, **kwargs)
        bargs.apply_defaults()
        if bargs.arguments['loop'] is None:
            bargs.arguments['loop'] = _get_loop()

        # Return the BoundArguments itself: callers forward ``.args`` and
        # ``.kwargs`` instead of ``**arguments``. The latter would pass
        # positional-only, ``*args`` and ``**kwargs`` parameters as
        # keywords named after the parameter.
        return bargs

    if inspect.isasyncgenfunction(func):  # type: ignore
        async def async_gen_loop_wrapper(
            *args: typing.Any,
            **kwargs: typing.Any
        ) -> typing.AsyncGenerator:
            bound = add_loop(args, kwargs)
            async for elem in func(*bound.args, **bound.kwargs):
                yield elem
        ret = async_gen_loop_wrapper

    elif inspect.iscoroutinefunction(func):
        async def async_loop_wrapper(
            *args: typing.Any,
            **kwargs: typing.Any
        ) -> typing.Any:
            bound = add_loop(args, kwargs)
            return await func(*bound.args, **bound.kwargs)
        ret = async_loop_wrapper  # type: ignore

    elif inspect.isgeneratorfunction(func):
        def gen_loop_wrapper(
            *args: typing.Any,
            **kwargs: typing.Any
        ) -> typing.Generator:
            bound = add_loop(args, kwargs)
            yield from func(*bound.args, **bound.kwargs)
        ret = gen_loop_wrapper  # type: ignore

    else:
        def func_loop_wrapper(
            *args: typing.Any,
            **kwargs: typing.Any
        ) -> typing.Any:
            bound = add_loop(args, kwargs)
            return func(*bound.args, **bound.kwargs)
        ret = func_loop_wrapper

    ret.__signature__ = sig  # type: ignore

    return functools.wraps(func)(ret)
Example #4
0
def logged(func: typing.Callable) -> typing.Callable:
    """
    Wrap callable with logging entries.

    Each call is routed through ``_log_wrapper``, which produces two log
    entries:
        * one when the callable starts, including all of its parameters.
        * one when the callable ends, including the return value and the
          time it took to execute.

    The callable is identified as a FUNCTION, GENERATOR, COROUTINE, or ASYNC
    GENERATOR, and the returned wrapper is of the same kind. Generator
    wrappers append every yielded value to the ``results`` list supplied by
    ``_log_wrapper``.

    If ``func`` has a ``log`` parameter, the exposed signature (also stored
    on ``func.__signature__``) gives it a default of ``None``.

    Warning:
        Use this decorator sparingly, as it can make VERY VERBOSE
        LogLevel.GENERATED entries in the log. and makes it somewhat
        difficult to dig through logs at that level.

    Args:
        func (Callable): The callable to log.

    Returns:
        Callable: The decorated callable.
    """
    signature = inspect.signature(func)
    if 'log' in signature.parameters:
        relaxed = []
        for name, param in signature.parameters.items():
            if name == 'log':
                param = param.replace(default=None)
            relaxed.append(param)
        signature = signature.replace(parameters=relaxed)
        func.__signature__ = signature

    def finish(wrapper: typing.Callable) -> typing.Callable:
        # Publish the (possibly relaxed) signature, then copy func's metadata.
        wrapper.__signature__ = signature
        return functools.wraps(func)(wrapper)

    if inspect.isasyncgenfunction(func):

        async def logged_async_gen(*args, **kwargs):
            with _log_wrapper(
                func, '---- ASYNC GENERATOR ----', signature, args, kwargs,
                list_=True,
            ) as (collected, call_kwargs):
                async for item in func(**call_kwargs):
                    yield item
                    # NOTE(review): recorded only once the consumer resumes
                    # iteration, so the last value is missed if iteration
                    # stops right after it was yielded.
                    collected.append(item)

        return finish(logged_async_gen)

    if inspect.iscoroutinefunction(func):

        async def logged_coro(*args, **kwargs):
            with _log_wrapper(
                func, '<<<< COROUTINE >>>>', signature, args, kwargs,
                list_=False,
            ) as (collected, call_kwargs):
                collected.append(await func(**call_kwargs))
                return collected[-1]

        return finish(logged_coro)

    if inspect.isgeneratorfunction(func):

        def logged_gen(*args, **kwargs):
            with _log_wrapper(
                func, '|||| GENERATOR ||||', signature, args, kwargs,
                list_=True,
            ) as (collected, call_kwargs):
                for item in func(**call_kwargs):
                    yield item
                    # NOTE(review): same append-after-yield ordering as the
                    # async generator wrapper above.
                    collected.append(item)

        return finish(logged_gen)

    def logged_func(*args, **kwargs):
        with _log_wrapper(
            func, ':::: FUNCTION ::::', signature, args, kwargs,
            list_=False,
        ) as (collected, call_kwargs):
            collected.append(func(**call_kwargs))
            return collected[-1]

    return finish(logged_func)