Example #1
0
 async def app(scope, receive, send):
     """ASGI test app reporting the request body read in two ways.

     The body is read once with ``request.body()`` and once more by
     iterating ``request.stream()``. Both results are decoded and returned
     as JSON under ``"body"`` and ``"stream"``.
     """
     request = Request(scope, receive)
     body = await request.body()
     # Join the chunks once; ``bytes +=`` in a loop copies the accumulated
     # buffer on every iteration (quadratic in the body size).
     chunks = b"".join([chunk async for chunk in request.stream()])
     response = JSONResponse({
         "body": body.decode(),
         "stream": chunks.decode()
     })
     await response(scope, receive, send)
Example #2
0
 async def __call__(self, scope: Scope, receive: Receive, send: Send):
     """Dispatch a request through ``self._route_tree``.

     The URL path selects a mapping of HTTP method -> handler. Unknown paths
     are answered by ``_not_found``. Known paths with no handler for the
     request method are answered by ``_not_allowed``.
     """
     req = Request(scope, receive, send)
     methods = self._route_tree.get(req.url.path)
     handler = methods.get(req.method) if methods else None
     if not methods:
         await _not_found(scope, receive, send)
     elif not handler:
         await _not_allowed(scope, receive, send)
     else:
         await handler(scope, receive, send)
Example #3
0
 def setUp(self):
     """Build a minimal HTTP scope and a ``Request`` wrapping it for each test."""
     # The ASGI HTTP spec makes ``headers`` an iterable of [name, value]
     # byte pairs and ``query_string`` a byte string, so use the empty values
     # of those types rather than ``""``.
     self.mock_scope = {
         "headers": [],
         "path": "",
         "query_string": b"",
         "root_path": "",
         "type": "http",
     }
     self.mock_request = Request(scope=self.mock_scope)
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the ``before_request``/``after_request`` hooks around the app.

        Non-HTTP scopes go straight to ``self.app``. For HTTP, ``send`` is
        wrapped as ``self.send(..., send=send, request=request)``. If
        ``before_request`` returns something truthy, that ASGI callable is
        served instead of ``self.app``. ``after_request`` is awaited once the
        served callable finishes.
        """
        if scope["type"] == "http":
            request = Request(scope, receive=receive)
            wrapped_send = functools.partial(self.send, send=send, request=request)
            override = await self.before_request(request)
            target = override if override else self.app
            await target(request.scope, request.receive, wrapped_send)
            await self.after_request(request)
            return
        await self.app(scope, receive, send)
Example #5
0
 async def asgi(receive, send):
     """ASGI (two-callable style) test app: consume the stream, then the body.

     ``scope`` comes from the enclosing function. If ``request.body()``
     raises ``RuntimeError`` after the stream has been consumed, the
     placeholder ``b"<stream consumed>"`` is reported instead.
     """
     request = Request(scope, receive)
     # Join the chunks once instead of ``bytes +=`` per chunk (quadratic).
     chunks = b"".join([chunk async for chunk in request.stream()])
     try:
         body = await request.body()
     except RuntimeError:
         body = b"<stream consumed>"
     response = JSONResponse({"body": body.decode(), "stream": chunks.decode()})
     await response(receive, send)
def starlette_request() -> Request:
    """Yield a ``Request`` over a fixed GET scope carrying two custom headers.

    NOTE(review): this function yields, so it is a generator and the
    ``-> Request`` annotation is inaccurate (``Iterator[Request]`` would
    describe it). The ``path`` has no leading ``/`` and the ``query_string``
    includes ``?``. Both are kept as-is in case tests rely on these values.
    """
    header_pairs = [
        (b"test_header", b"test_val"),
        (b"test_header_1", b"test_val_2"),
    ]
    scope = {
        "method": "GET",
        "type": "http",
        "path": "test/v1/test",
        "headers": header_pairs,
        "query_string": b"?b=1&a=2",
    }
    yield Request(scope)
Example #7
0
async def test_get_from_cache_vary(cache: Cache) -> None:
    """A response cached with ``Vary: Accept-Encoding`` must not be returned
    for a request that sends a different ``Accept-Encoding`` value."""
    scope: Scope = {
        "type": "http",
        "method": "GET",
        "path": "/path",
        "headers": [[b"accept-encoding", b"gzip, deflate"]],
    }
    original_request = Request(scope)
    # The response declares that its content at this URL depends on the
    # Accept-Encoding request header.
    varying_response = PlainTextResponse(
        "Hello, world!", headers={"Vary": "Accept-Encoding"}
    )
    await store_in_cache(varying_response, request=original_request, cache=cache)

    # Same URL, different Accept-Encoding: the cache lookup must miss.
    identity_scope = dict(scope, headers=[[b"accept-encoding", b"identity"]])
    identity_request = Request(identity_scope)
    hit = await get_from_cache(identity_request, cache=cache)
    assert hit is None
Example #8
0
    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """Run ``self.dispatch_func`` around the wrapped app for HTTP scopes.

        Non-HTTP scopes are forwarded unchanged to ``self.app``. For HTTP,
        ``dispatch_func(request, call_next)`` is awaited. ``call_next`` runs
        the downstream app in a task of the task group opened below and
        exposes its output as a ``StreamingResponse``. The response returned
        by ``dispatch_func`` is then sent, after which the task group is
        cancelled.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            """Invoke the downstream app and wrap its messages in a response.

            Re-raises the app's exception if the message stream ended before
            any message was sent. Raises ``RuntimeError`` if the stream ended
            with no message and no exception.
            """
            # Exception raised by the downstream app, captured in ``coro`` so
            # it can be re-raised on the consuming side.
            app_exc: typing.Optional[Exception] = None
            # In-memory channel carrying the app's ASGI send() messages.
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc

                # Leaving this block closes send_stream, which is what makes
                # the receiving side see end-of-stream.
                async with send_stream:
                    try:
                        await self.app(scope, request.receive,
                                       send_stream.send)
                    except Exception as exc:
                        app_exc = exc

            # ``task_group`` is the group opened at the bottom of __call__;
            # it is already bound when dispatch_func invokes call_next.
            task_group.start_soon(coro)

            try:
                # The first message from the app must start the response.
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                # Relay non-empty body chunks until a message arrives without
                # ``more_body`` set.
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        body = message.get("body", b"")
                        if body:
                            yield body
                        if not message.get("more_body", False):
                            break

                # Surface an exception the app raised after the response had
                # already started.
                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(status_code=message["status"],
                                         content=body_stream())
            # Reuse the app's raw header list verbatim.
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            # Cancel any downstream task still running once the response
            # has been sent.
            task_group.cancel_scope.cancel()
Example #9
0
 async def asgi(receive, send):
     """ASGI (two-callable style) test app echoing parsed form fields as JSON.

     ``scope`` comes from the enclosing function. The request is closed in
     a ``finally`` block, so it is released even if reading or parsing the
     form raises.
     """
     request = Request(scope, receive)
     try:
         # Read the raw bytes first so that the form parser, which reads via
         # request.stream(), receives the already-read body.
         await request.body()
         data = await request.form()
         output = {key: value for key, value in data.items()}
     finally:
         await request.close()
     response = JSONResponse(output)
     await response(receive, send)
Example #10
0
    async def my_app(scope: Scope, receive: Receive, send: Send):
        """Forward to ``_app`` only for a user holding one of ``permissions``.

        The bearer token is extracted with ``oauth2_scheme`` and resolved
        through ``get_current_user``. If no user is found or the user lacks
        every permission, ``HTTPException`` 401 is raised with a
        ``WWW-Authenticate: Bearer`` header.
        NOTE(review): an authenticated user without permission would more
        conventionally receive 403; confirm 401 is intended.
        """
        request: Request = Request(scope, receive=receive)
        token = await oauth2_scheme(request)
        user: UserModel = await get_current_user(token)

        if not (user and has_one_permission(user, permissions)):
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return await _app(scope, receive, send)
async def test_oauth2_authorize_code_challenge():
    """S256 code-challenge flow.

    ``authorize_redirect`` must put the challenge in the redirect URL and
    store state plus verifier in the session. ``authorize_access_token``
    then obtains the mocked token from the ``/token`` dispatch app.
    """
    token_app = AsyncPathMapDispatch({'/token': {'body': get_bearer_token()}})
    oauth = OAuth()
    client = oauth.register(
        'dev',
        client_id='dev',
        api_base_url='https://i.b/api',
        access_token_url='https://i.b/token',
        authorize_url='https://i.b/authorize',
        client_kwargs={
            'code_challenge_method': 'S256',
            'app': token_app,
        },
    )

    scope = {'type': 'http', 'session': {}}
    request = Request(scope)

    redirect = await client.authorize_redirect(
        request, redirect_uri='https://b.com/bar')
    assert redirect.status_code == 302

    location = redirect.headers.get('Location')
    for fragment in ('code_challenge=', 'code_challenge_method=S256'):
        assert fragment in location

    session = request.session
    state = session['_dev_authlib_state_']
    assert state is not None
    verifier = session['_dev_authlib_code_verifier_']
    assert verifier is not None

    scope.update(
        path='/',
        query_string=f'code=a&state={state}'.encode(),
        session=session,
    )
    token = await client.authorize_access_token(Request(scope))
    assert token['access_token'] == 'a'
Example #12
0
def test_get_settings_from_request_true() -> None:
    """Both NSFW cookies set to ``"1"`` must produce ``True`` settings."""
    request = Request(scope={"type": "http", "headers": {}})
    for cookie_name in ("nsfw_popular_all", "nsfw_thumbnails"):
        request.cookies[cookie_name] = "1"

    settings = get_settings_from_request(request)

    assert settings.nsfw_popular_all is True
    assert settings.nsfw_thumbnails is True
Example #13
0
async def test_non_cachable_status_codes(cache: Cache,
                                         status_code: int) -> None:
    """Storing a response with a non-cachable status code must raise
    ``ResponseNotCachable``."""
    request = Request({
        "type": "http",
        "method": "GET",
        "path": "/path",
        "headers": [],
    })
    response = PlainTextResponse("Hello, world!", status_code=status_code)

    with pytest.raises(ResponseNotCachable):
        await store_in_cache(response, request=request, cache=cache)
    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """Record request count and latency metrics around the wrapped app.

        Non-HTTP scopes are forwarded unchanged. For HTTP requests, the
        response status and the time of the last body message are captured
        from ``send``. Metrics are recorded in a ``finally`` block, so they
        are emitted even when the app raises (the status then stays 500).
        Any exception from the app still propagates.

        With ``filter_unhandled_paths``, requests whose path has no router
        match are not recorded. With ``group_paths``, the router's templated
        path is used as the label instead of the concrete URL path.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope)

        method = request.method
        path = request.url.path
        begin = time.perf_counter()
        end = None

        # Default status code used when the application does not return a
        # valid response or an unhandled exception occurs.
        status_code = 500

        async def wrapped_send(message: Message) -> None:
            nonlocal status_code, end

            if message['type'] == 'http.response.start':
                status_code = message['status']

            if message['type'] == 'http.response.body':
                # Overwritten on every body chunk, so this ends up holding
                # the time of the last chunk sent.
                end = time.perf_counter()

            await send(message)

        try:
            await self.app(scope, receive, wrapped_send)
        finally:
            # No ``return`` in this ``finally``: returning here would
            # silently discard any exception raised by the app.
            grouped_path = None
            if self.filter_unhandled_paths or self.group_paths:
                grouped_path = self._get_router_path(scope)

            # filter_unhandled_paths removes any requests without a mapped
            # endpoint from the metrics.
            skip = self.filter_unhandled_paths and grouped_path is None

            if not skip:
                # group_paths enables returning the original router path (with
                # url param names); for example, requests to /api/product/1
                # and /api/product/3 are both grouped under
                # /api/product/{product_id}. See the README for more info.
                if self.group_paths and grouped_path is not None:
                    path = grouped_path

                labels = [method, path, status_code, self.app_name]

                # If end was not set when the response body was written,
                # set it now.
                if end is None:
                    end = time.perf_counter()

                self.request_count.labels(*labels).inc()
                self.request_time.labels(*labels).observe(end - begin)
Example #15
0
    async def __call__(self, scope, receive, send):
        """Serve a canned response, optionally checking the request first.

        When ``self.assert_func`` is set, it is awaited with the incoming
        ``Request`` before a ``Response`` built from ``self.status_code``,
        ``self.body`` and ``self.headers`` is sent.
        """
        request = Request(scope, receive=receive)
        if self.assert_func:
            await self.assert_func(request)

        canned = Response(
            status_code=self.status_code,
            content=self.body,
            headers=self.headers,
        )
        await canned(scope, receive, send)
Example #16
0
    async def app(scope, receive, send):
        """Test app that tries a server push and reports whether it worked.

        The scope is marked as supporting ``http.response.push``, then a push
        of ``/style.css`` is attempted. The JSON reply is ``"OK"``, or
        ``"Send channel not available"`` if ``send_push_promise`` raises
        ``RuntimeError``.
        """
        # Advertise the push extension so the request may use it.
        scope["extensions"]["http.response.push"] = {}

        request = Request(scope)
        try:
            await request.send_push_promise("/style.css")
        except RuntimeError:
            outcome = "Send channel not available"
        else:
            outcome = "OK"
        await JSONResponse({"json": outcome})(scope, receive, send)
Example #17
0
    async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
        """Handle one request inside ``context.enter()`` and send the response.

        The caller's ``send`` is kept on ``self._orig_send`` as a workaround
        before the request is passed to ``self.call_next``.
        """
        global context

        # TODO: Temporary fix for https://github.com/encode/starlette/issues/472
        self._orig_send = send

        with context.enter():
            response = await self.call_next(Request(scope, receive=receive))
            await response(receive, send)
Example #18
0
 async def dispatch(self) -> None:
     """Answer browser (HTML) requests directly, otherwise use the parent.

     When the ``Accept`` header contains ``text/html``, the GraphiQL endpoint
     is served if the config's ``graphiql`` is truthy and its ``path`` is
     ``None``. Otherwise a plain-text 404 is sent. All other requests go to
     ``super().dispatch()``.
     """
     request = Request(self.scope, self.receive)
     graphiql = get_graphql_config(request).graphiql
     wants_html = "text/html" in request.headers.get("Accept", "")
     if not wants_html:
         await super().dispatch()
         return
     app: ASGIApp = (
         GraphiQLEndpoint
         if graphiql and graphiql.path is None
         else PlainTextResponse("Not Found", 404)
     )
     await app(self.scope, self.receive, self.send)
    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """Route HTTP requests through the login/logout flow of ``self.kc``.

        HTTP requests are checked in this order:
        - the callback URL is handled by ``Callback``;
        - ``self.logout_uri`` is handled by ``Logout``;
        - ``self.logout_redirect_uri`` goes to the wrapped app;
        - a session without ``"user"`` is sent to ``Login``;
        - anything else goes to the wrapped app.
        Non-HTTP scopes go to the wrapped app unchanged.
        """
        if not self.is_http(scope):
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)

        if self.get_url(request) == self.callback_url:
            await Callback(scope, receive, send, kc=self.kc,
                           redirect_uri=self.login_redirect_uri)
        elif request.url.path == self.logout_uri:
            await Logout(scope, receive, send, kc=self.kc,
                         redirect_uri=self.logout_redirect_uri)
        elif request.url.path == self.logout_redirect_uri:
            await self.app(scope, receive, send)
        elif "user" not in request.session:
            # NOTE(review): Login is given logout_redirect_uri while Callback
            # is given login_redirect_uri; confirm this is intended.
            await Login(scope, receive, send, kc=self.kc,
                        redirect_uri=self.logout_redirect_uri)
        else:
            await self.app(scope, receive, send)
Example #20
0
def store_current_request(
        request_or_scope: Union[Request, Scope],
        receive: Optional[Receive] = None) -> Optional[Request]:
    """Store the current request in the ``_current_request`` context variable.

    When ``receive`` is given, ``request_or_scope`` is treated as a scope and
    wrapped in a new ``Request``. Otherwise it is stored as-is. Returns the
    stored object. If ``ContextVar`` is ``None``, nothing is stored and
    ``None`` is returned.
    """
    if ContextVar is None:
        return None

    request = (
        request_or_scope if receive is None
        else Request(request_or_scope, receive)
    )
    _current_request.set(request)
    return request
Example #21
0
    def server_request_hook(self, span, req_data, body):
        """Capture request attributes on ``span`` and run the block filters.

        The span is renamed to ``"<METHOD> <old name>"`` and the headers and
        body are handed to ``generic_request_handler``. Returns ``False``
        when the filter registry decides the request should be blocked
        (logged as an abort with 403), and ``True`` otherwise.
        """
        span.update_name(f"{req_data['method']} {span.name}")
        headers = dict(Headers(raw=req_data['headers']))
        request_url = str(Request(req_data).url)
        self.generic_request_handler(headers, body, span)

        should_block = Registry().apply_filters(
            span, request_url, headers, body, TYPE_HTTP)
        if not should_block:
            return True
        logger.debug('should block evaluated to true, aborting with 403')
        return False
 def _prepare_request(db_session: AsyncSession, user: User,
                      user_session: UserSession):
     """Build a ``Request`` authenticated as ``user`` for ``user_session``.

     A JWT carrying the user id and the session's public id is sent in a
     ``Bearer`` authorization header. ``db_session`` is attached to the
     request as ``request.db_session``.
     """
     claims = {
         "user_id": user.id,
         "session_id": user_session.public_id,
     }
     jwt, _ = encode_jwt(claims)
     auth_header = (b"authorization", f"Bearer {jwt}".encode("latin-1"))
     request = Request({"type": "http", "headers": [auth_header]})
     request.db_session = db_session
     return request
Example #23
0
 async def dispatch(self) -> None:
     """
     Overriding of the existing method.

     Builds the ``Request`` and passes it to ``self.perform_action``. A
     coroutine-function handler is awaited directly. A synchronous override
     runs via ``run_in_threadpool``. The returned response is then sent.
     """
     import inspect

     request = Request(self.scope, receive=self.receive)
     handler = self.perform_action
     # In case the `perform_action` method is overridden with a sync one.
     # inspect.iscoroutinefunction is used because
     # asyncio.iscoroutinefunction is deprecated since Python 3.14.
     if inspect.iscoroutinefunction(handler):
         response = await handler(request)
     else:
         response = await run_in_threadpool(handler, request)
     await response(self.scope, self.receive, self.send)
Example #24
0
async def hook(scope, receive, send):
    """Validate an ``X-Hub-Signature`` (HMAC-SHA1 of the body) webhook.

    Replies with ``"True"`` or ``"False"`` as plain text. When the signature
    matches, ``rebuild_deploy`` is scheduled as a background task with
    ``body=`` set to the decoded JSON payload.
    """
    assert scope['type'] == 'http'
    request = Request(scope, receive)
    mac = hmac.new(key, digestmod='sha1')
    payload = await request.body()
    mac.update(payload)

    provided = request.headers.get('X-Hub-Signature')
    # Constant-time comparison to avoid leaking signature bytes via timing.
    valid = provided is not None and hmac.compare_digest(
        'sha1={}'.format(mac.hexdigest()), provided)

    task = None
    if valid:
        task = BackgroundTask(rebuild_deploy, body=await request.json())
    response = PlainTextResponse(str(valid), background=task)
    await response(scope, receive, send)
Example #25
0
 def __call__(self,
              method='GET',
              path='/',
              headers=None,
              client_addr='testclient',
              **extra_scope):
     """Build a ``Request`` over a synthetic HTTP/1.1 scope for tests.

     ``headers`` is merged into default ``host``/``user-agent``/``connection``
     headers (keys compared as given). Names are then lower-cased and names
     and values encoded with ``str.encode()``. Entries in ``extra_scope``
     override the generated scope keys. ``app`` and ``endpoint`` are taken
     from the enclosing scope. The request is created without
     receive/send channels.
     """
     merged_headers = {
         'host': 'testserver',
         'user-agent': 'testclient',
         'connection': 'keep-alive',
     }
     if headers:
         merged_headers.update(headers)
     raw_headers = [(name.lower().encode(), value.encode())
                    for name, value in merged_headers.items()]
     scope = {
         'type': 'http',
         'http_version': '1.1',
         'method': method,
         'path': path,
         'root_path': '',
         'scheme': 'http',
         'query_string': b'',
         'headers': raw_headers,
         'client': [client_addr, 50000],
         'server': ['testserver', 80],
         'extensions': {'http.response.template': {}},
         'app': app,
         'state': {},
         'session': {},
         'endpoint': endpoint,
         'path_params': {},
         **extra_scope,
     }
     return Request(scope, None, None)
Example #26
0
 def __call__(self, scope: Scope) -> ASGIInstance:
     """Load the signed session cookie into ``scope["session"]``.

     For HTTP and WebSocket scopes, the cookie named ``self.session_cookie``
     is unsigned with ``self.signer`` (``max_age=self.max_age``). Its
     base64-encoded JSON payload becomes the session. A missing, expired or
     otherwise badly signed cookie yields an empty session. Returns
     ``self.asgi`` bound to the scope. Other scope types are delegated to
     ``self.app``.
     """
     # Already a dependency of this module (BadTimeSignature/SignatureExpired
     # come from itsdangerous); BadSignature is their common base class.
     from itsdangerous import BadSignature

     if scope["type"] in ("http", "websocket"):
         request = Request(scope)
         if self.session_cookie in request.cookies:
             data = request.cookies[self.session_cookie].encode("utf-8")
             try:
                 data = self.signer.unsign(data, max_age=self.max_age)
                 scope["session"] = json.loads(b64decode(data))
             except BadSignature:
                 # Covers BadTimeSignature and SignatureExpired. It also
                 # covers a plain BadSignature, which a timestamp signer can
                 # raise for a malformed cookie (e.g. one with no separator);
                 # that case previously escaped and failed the request.
                 scope["session"] = {}
         else:
             scope["session"] = {}
         return functools.partial(self.asgi, scope=scope)
     return self.app(scope)  # pragma: no cover
Example #27
0
async def application(scope, receive, send):
    """Minimal ASGI app that serves GET requests through ``fetch_content``.

    Non-HTTP scopes are ignored. Non-GET methods get a 405 response. For GET,
    ``fetch_content`` fills in the response using a fresh
    ``aiohttp.ClientSession``.
    """
    if scope["type"] != "http":
        return

    request = Request(scope, receive)
    response = Response()

    if request.method != "GET":
        # The response sends ``status_code``; assigning ``response.status``
        # only created an unused attribute, so the 405 was never returned.
        # (Assumes Starlette's Response, as used with Request here.)
        response.status_code = 405
    else:
        async with aiohttp.ClientSession() as client:
            await fetch_content(client, request, response)

    await response(scope, receive, send)
Example #28
0
 async def dispatch(self):
     """Decode the request, run the matching handler and send its response.

     An ``AWSServerException`` raised while decoding the payload or running
     the handler is encoded as an AWS-style error payload with HTTP 200. One
     raised by ``get_handler`` propagates unchanged, because no handler
     metadata exists yet to encode it with.
     """
     request = Request(self.scope, self.receive)
     # Resolved outside the try: previously a failure here reached the except
     # clause with ``metadata`` unbound and surfaced as UnboundLocalError.
     metadata, handler = self.get_handler(request)
     try:
         request_payload = await self.get_request_payload(request, metadata)
         self.request = request
         self.metadata = metadata
         response: Response
         response_payload = await handler(request_payload)
         response = self.encode_response(metadata, response_payload)
     except AWSServerException as e:
         response = self.encode_response(metadata,
                                         (HTTP_200_OK, e.as_aws_payload()))
     await response(self.scope, self.receive, self.send)
Example #29
0
    async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
        """Authenticate the request and store the result in the scope.

        ``self.backend.authenticate`` returns a ``(credentials, user)`` pair,
        or ``None``, which is treated as unauthenticated. If it raises
        ``AuthenticationError``, a 400 plain-text response with the error
        message is sent instead. On success, ``scope["auth"]`` and
        ``scope["user"]`` are set before the wrapped app runs.
        """
        request = Request(scope, receive=receive)
        try:
            auth_result = await self.backend.authenticate(request)
        except AuthenticationError as exc:
            await PlainTextResponse(str(exc), status_code=400)(receive, send)
            return

        credentials, user = (
            auth_result if auth_result is not None
            else (AuthCredentials(), UnauthenticatedUser())
        )
        scope["auth"] = credentials
        scope["user"] = user
        await self.app(scope)(receive, send)
Example #30
0
async def test_exception_handler_pydantic_validationerror():
    """``validation_exception_handler`` must answer this ``ValidationError``
    with a JSON body carrying code 400, the empty-body detail message and no
    field errors."""

    async def test_receive():
        return {"type": "http.request"}

    scope = {"type": "http", "method": "GET", "path": "/"}
    request = Request(scope, receive=test_receive)
    error = ValidationError([{"hello": "world"}])

    raw_response = await validation_exception_handler(request, error)
    payload = json.loads(raw_response.body.decode("utf-8"))

    expected = {
        "code": 400,
        "detail": "Empty body for this request is not valid.",
        "fields": [],
    }
    for field, value in expected.items():
        assert payload[field] == value