Example #1
0
def test_async_func():
    """A coroutine function is detected as async; a plain function is not."""

    def plain():
        pass  # pragma: no cover

    async def coro():
        pass  # pragma: no cover

    assert is_async_callable(coro)
    assert not is_async_callable(plain)
Example #2
0
def test_async_method():
    """Bound methods are classified by how the method itself was defined."""

    class Sync:
        def method(self):
            pass  # pragma: no cover

    class Async:
        async def method(self):
            pass  # pragma: no cover

    async_bound = Async().method
    assert is_async_callable(async_bound)

    sync_bound = Sync().method
    assert not is_async_callable(sync_bound)
Example #3
0
def test_async_object_call():
    """Instances are classified by whether their `__call__` is a coroutine."""

    class Sync:
        def __call__(self):
            pass  # pragma: no cover

    class Async:
        async def __call__(self):
            pass  # pragma: no cover

    async_instance = Async()
    assert is_async_callable(async_instance)

    sync_instance = Sync()
    assert not is_async_callable(sync_instance)
Example #4
0
def test_async_partial():
    """A `functools.partial` is classified by the function it wraps."""

    def func(a, b):
        pass  # pragma: no cover

    async def async_func(a, b):
        pass  # pragma: no cover

    bound_async = functools.partial(async_func, 1)
    assert is_async_callable(bound_async)

    bound_sync = functools.partial(func, 1)
    assert not is_async_callable(bound_sync)
Example #5
0
def test_async_partial_object_call():
    """A partial over a callable instance follows that instance's `__call__`."""

    class Sync:
        def __call__(self, a, b):
            pass  # pragma: no cover

    class Async:
        async def __call__(self, a, b):
            pass  # pragma: no cover

    bound_async = functools.partial(Async(), 1)
    assert is_async_callable(bound_async)

    bound_sync = functools.partial(Sync(), 1)
    assert not is_async_callable(bound_sync)
Example #6
0
def test_async_nested_partial():
    """A partial wrapping another partial of an async function is still async."""

    async def async_func(a, b):
        pass  # pragma: no cover

    inner = functools.partial(async_func, b=2)
    outer = functools.partial(inner, a=1)
    assert is_async_callable(outer)
Example #7
0
 async def shutdown(self) -> None:
     """
     Invoke every handler in `.on_shutdown`, in iteration order.

     Handlers that `is_async_callable` reports as async are awaited;
     all other handlers are called synchronously.
     """
     for callback in self.on_shutdown:
         if not is_async_callable(callback):
             callback()
             continue
         await callback()
Example #8
0
 async def startup(self) -> None:
     """
     Invoke every handler in `.on_startup`, in iteration order.

     Handlers that `is_async_callable` reports as async are awaited;
     all other handlers are called synchronously.
     """
     for callback in self.on_startup:
         if not is_async_callable(callback):
             callback()
             continue
         await callback()
Example #9
0
    async def dispatch(self) -> None:
        """
        Route the request to the method named after its lowercase HTTP verb.

        If no such attribute exists, `self.method_not_allowed` is used.
        Async handlers are awaited directly; other handlers are run through
        `run_in_threadpool`. The resulting response is then invoked with this
        instance's scope, receive and send.
        """
        request = Request(self.scope, receive=self.receive)

        # A HEAD request is served by the "get" handler when no "head"
        # attribute is defined on this instance.
        if request.method == "HEAD" and not hasattr(self, "head"):
            handler_name = "get"
        else:
            handler_name = request.method.lower()

        handler: typing.Callable[[Request], typing.Any] = getattr(
            self, handler_name, self.method_not_allowed)

        if is_async_callable(handler):
            response = await handler(request)
        else:
            response = await run_in_threadpool(handler, request)
        await response(self.scope, self.receive, self.send)
Example #10
0
def request_response(func: typing.Callable) -> ASGIApp:
    """
    Adapt a `func(request) -> response` callable into an ASGI application.

    Whether `func` is async is determined once, when this adapter is built.
    Async callables are awaited directly; other callables are executed
    through `run_in_threadpool`.
    """
    awaitable_endpoint = is_async_callable(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        if not awaitable_endpoint:
            response = await run_in_threadpool(func, request)
        else:
            response = await func(request)
        await response(scope, receive, send)

    return app
Example #11
0
    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """
        Run the wrapped app and turn any raised exception into a response.

        Non-"http" scopes are passed straight through to the wrapped app.
        For "http" scopes, an exception from the app produces a response
        from, in priority order: the debug response (when `self.debug` is
        set), the default error response (when no handler is installed), or
        the installed handler. That response is only sent if no
        "http.response.start" message has gone out yet. The exception is
        re-raised in every case.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def tracking_send(message: Message) -> None:
            # Remember that the response has begun before forwarding.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, tracking_send)
        except Exception as exc:
            request = Request(scope)
            if self.debug:
                # In debug mode, return traceback responses.
                response = self.debug_response(request, exc)
            elif self.handler is None:
                # Use our default 500 error handler.
                response = self.error_response(request, exc)
            elif is_async_callable(self.handler):
                # Use an installed async 500 error handler.
                response = await self.handler(request, exc)
            else:
                # Use an installed sync 500 error handler.
                response = await run_in_threadpool(self.handler, request,
                                                   exc)

            if not response_started:
                await response(scope, receive, send)

            # We always continue to raise the exception.
            # This allows servers to log the error, or allows test clients
            # to optionally raise the error within the test case.
            raise exc
Example #12
0
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Run the wrapped app and dispatch raised exceptions to handlers.

        Non-"http" scopes are passed straight through to the wrapped app.
        For an `HTTPException`, a handler registered for its status code is
        tried first; otherwise `_lookup_exception_handler` is consulted.
        With no handler the exception is re-raised. If a handler exists but
        the response has already started, a `RuntimeError` chained from the
        original exception is raised instead.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def tracking_send(message: Message) -> None:
            # Remember that the response has begun before forwarding.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, tracking_send)
        except Exception as exc:
            # Status-code handlers take precedence for HTTP exceptions.
            handler = (
                self._status_handlers.get(exc.status_code)
                if isinstance(exc, HTTPException)
                else None
            )
            if handler is None:
                handler = self._lookup_exception_handler(exc)
                if handler is None:
                    raise exc

            if response_started:
                raise RuntimeError(
                    "Caught handled exception, but response already started."
                ) from exc

            request = Request(scope, receive=receive)
            if not is_async_callable(handler):
                response = await run_in_threadpool(handler, request, exc)
            else:
                response = await handler(request, exc)
            await response(scope, receive, tracking_send)
Example #13
0
    def decorator(func: typing.Callable) -> typing.Callable:
        """
        Wrap `func` so it only runs when the connection has the required scopes.

        `func` must declare a parameter named "request" or "websocket"; the
        first such parameter decides which wrapper is built. WebSocket
        wrappers close the socket when the scope check fails. Request
        wrappers return a 303 redirect when `redirect` is set, and otherwise
        raise `HTTPException` with the configured status code.
        """
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"')

        def _connection_from(args: typing.Any, kwargs: typing.Any,
                             key: str) -> typing.Any:
            # Prefer the keyword argument; fall back to the positional slot.
            positional = args[idx] if idx < len(args) else None
            return kwargs.get(key, positional)

        def _reject(request: Request) -> Response:
            # Redirect when a target is configured, otherwise raise.
            if redirect is None:
                raise HTTPException(status_code=status_code)
            orig_request_qparam = urlencode({"next": str(request.url)})
            next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
            return RedirectResponse(url=next_url, status_code=303)

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(*args: typing.Any,
                                        **kwargs: typing.Any) -> None:
                websocket = _connection_from(args, kwargs, "websocket")
                assert isinstance(websocket, WebSocket)

                if has_required_scope(websocket, scopes_list):
                    await func(*args, **kwargs)
                else:
                    await websocket.close()

            return websocket_wrapper

        if is_async_callable(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(*args: typing.Any,
                                    **kwargs: typing.Any) -> Response:
                request = _connection_from(args, kwargs, "request")
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _reject(request)
                return await func(*args, **kwargs)

            return async_wrapper

        # Handle sync request/response functions.
        @functools.wraps(func)
        def sync_wrapper(*args: typing.Any,
                         **kwargs: typing.Any) -> Response:
            request = _connection_from(args, kwargs, "request")
            assert isinstance(request, Request)

            if not has_required_scope(request, scopes_list):
                return _reject(request)
            return func(*args, **kwargs)

        return sync_wrapper
Example #14
0
 def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args,
              **kwargs: P.kwargs) -> None:
     """
     Store the callable with its positional and keyword arguments.

     Also records in `is_async` whether `is_async_callable` reports
     `func` as async.
     """
     self.func, self.args, self.kwargs = func, args, kwargs
     self.is_async = is_async_callable(func)
Example #15
0
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
    """
    Report whether `app` is treated as an ASGI3 application.

    A class is checked for an `__await__` attribute; anything else is
    checked with `is_async_callable`.
    """
    if not inspect.isclass(app):
        return is_async_callable(app)
    return hasattr(app, "__await__")