Example #1
0
    async def build_response(
        self,
        message: Any,
        response: Response,
        cache_backend: Union[BaseCacheBackend, BaseAsyncCacheBackend],
    ) -> Any:
        """
        Builds the response object to return to the caller.

        Stamps a ``Cache-Control`` header on ``response`` and, when the request
        state carries a truthy ``update_cache`` flag, stores ``message`` in the
        cache backend under the key produced by ``self.key_func``.

        :param message: The response from the ASGI application
        :param response: the response object
        :param cache_backend: The cache backend used to store responses.
        :return: The message returned from the ASGI application.
        """
        # A bare number is not a valid Cache-Control directive, so the TTL is
        # emitted as ``max-age``.
        # NOTE(review): assumes self.ttl is expressed in seconds - confirm.
        response.headers.update({"Cache-Control": "max-age=" + str(self.ttl)})

        try:
            should_update_cache = self.request.state.update_cache
        except AttributeError:
            # No ``update_cache`` flag could be read: nothing to store.
            return message

        if should_update_cache:
            cache_key = self.key_func(self.request)

            # Async backends expose an awaitable ``set``; sync ones do not.
            if isinstance(cache_backend, BaseAsyncCacheBackend):
                await cache_backend.set(cache_key, message, self.ttl)
            else:
                cache_backend.set(cache_key, message, self.ttl)

        return message
Example #2
0
def test_mutable_headers_merge():
    """``|`` between two MutableHeaders yields a MutableHeaders holding the merged entry."""
    merged = MutableHeaders() | MutableHeaders({"a": "1"})

    assert isinstance(merged, MutableHeaders)
    # The same single entry must be visible through every accessor.
    assert dict(merged) == {"a": "1"}
    assert merged.items() == [("a", "1")]
    assert merged.raw == [(b"a", b"1")]
Example #3
0
def test_mutable_headers_update_dict():
    """In-place ``|=`` with a plain dict keeps the MutableHeaders type and adds the entry."""
    target = MutableHeaders()
    target |= {"a": "1"}

    assert isinstance(target, MutableHeaders)
    # The same single entry must be visible through every accessor.
    assert dict(target) == {"a": "1"}
    assert target.items() == [("a", "1")]
    assert target.raw == [(b"a", b"1")]
Example #4
0
        def send_wrapper(message: Message):
            """Append the traced call counts as a response header, then forward ``message``.

            Only ``http.response.start`` messages are touched, and only when the
            traced counter is a truthy ``Counter`` instance.
            """
            if message["type"] != "http.response.start":
                return send(message)

            call_counter = traced_counter.get()
            if call_counter and isinstance(call_counter, Counter):
                response_headers = MutableHeaders(scope=message)
                response_headers.append(
                    "x-dagster-call-counts", json.dumps(call_counter.counts())
                )

            return send(message)
 async def send_wrapper(message: Message) -> None:
     """Attach the signed session as a ``session`` response header, then forward.

     The header is only added to ``http.response.start`` messages and only
     when ``scope['session']`` is truthy.
     """
     if message['type'] == 'http.response.start' and scope['session']:
         # JSON-serialise, base64-encode and sign the session payload.
         serialized = json.dumps(scope['session']).encode('utf-8')
         signed = self.signer.sign(base64.b64encode(serialized))
         response_headers = MutableHeaders(scope=message)
         response_headers.append('session', signed.decode('utf-8'))
     await send(message)
 async def send(self, message, send, request):
     """Copy the B3 trace headers of the request onto the response start message.

     Each header is appended with the request's value, or an empty string
     when the request did not carry it. Other message types pass through.
     """
     if message["type"] == "http.response.start":
         response_headers = MutableHeaders(scope=message)
         # Lower-case the request header names so the lookups are case-insensitive.
         incoming = {
             name.lower(): value
             for name, value in dict(request.headers).items()
         }
         for trace_header in ("x-b3-traceid", "x-b3-sampled"):
             response_headers.append(trace_header, incoming.get(trace_header, ""))
     await send(message)
Example #7
0
 async def __call__(self, scope: Scope, receive: Receive,
                    send: Send) -> None:
     """Promote the ``user`` cookie to an ``Authorization`` header, then call the app.

     Applies to ``http`` and ``websocket`` scopes only; an ``Authorization``
     header that is already present is left as it is (``setdefault``).
     """
     handles_scope = scope["type"] in ("http", "websocket")
     if handles_scope:
         cookies = HTTPConnection(scope).cookies
         if "user" in cookies:
             request_headers = MutableHeaders(scope=scope)
             request_headers.setdefault("Authorization", cookies["user"])
     await self.app(scope, receive, send)
Example #8
0
def headers():
    """Return a MutableHeaders pre-populated with the dummy header values."""
    dummy_values = {
        HeaderKeys.correlation_id: dummy_correlation_id,
        HeaderKeys.request_id: dummy_request_id,
        HeaderKeys.date: dummy_date,
        HeaderKeys.user_agent: dummy_user_agent,
        HeaderKeys.forwarded_for: dummy_forwarded_for,
    }
    populated = MutableHeaders()
    populated.update(dummy_values)
    return populated
Example #9
0
    async def send(self, message: Message, send: Send, request_headers: Headers):
        """Merge the configured headers into the response start message and forward it.

        Messages of any other type are forwarded untouched.
        """
        if message["type"] == "http.response.start":
            # Make sure there is a header list to wrap before updating it.
            message.setdefault("headers", [])
            response_headers = MutableHeaders(scope=message)
            response_headers.update(self.__headers)

        await send(message)
Example #10
0
    async def enrich_response(self, arg) -> None:
        """Write the context value stored under ``self.key`` into the response headers.

        ``arg`` is either a ``Response`` object (ContextMiddleware) or a raw
        ASGI message dict (ContextPureMiddleware).
        """
        header_value = str(context.get(self.key))

        # for ContextMiddleware
        if isinstance(arg, Response):
            arg.headers[self.key] = header_value
            return

        # for ContextPureMiddleware
        if arg["type"] == "http.response.start":
            MutableHeaders(scope=arg).append(self.key, header_value)
Example #11
0
 async def send_cached(message):
     """Add ``Cache-Control`` and ``Expires`` headers to the response start message.

     Any other message type is forwarded unchanged.
     """
     # Borrowing from:
     #   https://github.com/encode/starlette/blob/master/starlette/middleware/cors.py
     if message["type"] == "http.response.start":
         message.setdefault("headers", [])
         response_headers = MutableHeaders(scope=message)
         cache_control = f'max-age={max_age}'
         expires = http_date(time.time() + max_age)
         response_headers.update({
             'Cache-Control': cache_control,
             'Expires': expires,
         })
         await send(message)
         return
     return await send(message)
Example #12
0
 async def send_wrapper(message: Message) -> None:
     """Add the session ``Set-Cookie`` header to the response start, then forward.

     A non-empty session is serialised, base64-encoded and signed into the
     cookie; a session that was non-empty initially but is empty now gets an
     already-expired cookie instead. The cookie path is the scope's
     ``root_path``, falling back to ``/``.
     """
     if message["type"] != "http.response.start":
         await send(message)
         return

     cookie_path = scope.get("root_path", "") or "/"
     if scope["session"]:
         # We have session data to persist.
         serialized = json.dumps(scope["session"]).encode("utf-8")
         signed = self.signer.sign(b64encode(serialized))
         response_headers = MutableHeaders(scope=message)
         cookie = "%s=%s; path=%s; Max-Age=%d; %s" % (
             self.session_cookie,
             signed.decode("utf-8"),
             cookie_path,
             self.max_age,
             self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     elif not initial_session_was_empty:
         # The session has been cleared.
         response_headers = MutableHeaders(scope=message)
         cookie = "{}={}; {}".format(
             self.session_cookie,
             f"null; path={cookie_path}; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
             self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     await send(message)
Example #13
0
    def __init__(self) -> None:
        """Set up the default state. Do not use manually."""
        self._redirect_to = None
        self._starlette_resp = StarletteResponse
        # All four start out as None.
        self._body = self._text = self._content = self._json = None

        # Fresh, empty containers and default status for every instance.
        self.headers = MutableHeaders()
        self.cookies = SimpleCookie()
        self.status_code = HTTPStatus.OK
        self.streaming: Optional[AsyncGenerator] = None
        self.reraise = False
Example #14
0
 async def sender(message: Message) -> None:
     """Add the session ``Set-Cookie`` header to the response start, then forward.

     A non-empty session is serialised, base64-encoded and signed into the
     cookie; a session that was not empty to begin with but is empty now gets
     an already-expired cookie instead.
     """
     if message["type"] != "http.response.start":
         await send(message)
         return

     if scope["session"]:
         # We have session data to persist.
         serialized = json.dumps(scope["session"]).encode("utf-8")
         signed = self.signer.sign(b64encode(serialized))
         response_headers = MutableHeaders(scope=message)
         cookie = "%s=%s; path=/; Max-Age=%d; %s" % (
             self.session_cookie,
             signed.decode("utf-8"),
             self.max_age,
             self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     elif not was_empty_session:
         # The session has been cleared.
         response_headers = MutableHeaders(scope=message)
         cookie = "%s=%s; %s" % (
             self.session_cookie,
             "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
             self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     await send(message)
Example #15
0
        async def send_wrapper(message: Message) -> None:
            """Add the JWT session ``Set-Cookie`` header to the response start, then forward.

            A non-empty session gets an ``exp`` entry (now + ``self.max_age``)
            when it has none, and is encoded with ``jwt.encode`` into the cookie.
            A session that was non-empty initially but is empty now gets an
            already-expired cookie. ``self.domain`` is appended when set.
            """
            if message["type"] != "http.response.start":
                await send(message)
                return

            session = scope["session"]
            if session:
                if "exp" not in session:
                    session["exp"] = int(time.time()) + self.max_age
                token = jwt.encode(self.jwt_header, session,
                                   str(self.jwt_secret.encode))

                response_headers = MutableHeaders(scope=message)
                cookie = "%s=%s; path=/; Max-Age=%d; %s" % (
                    self.session_cookie,
                    token.decode("utf-8"),
                    self.max_age,
                    self.security_flags,
                )
                if self.domain:  # pragma: no cover
                    cookie += f"; domain={self.domain}"
                response_headers.append("Set-Cookie", cookie)
            elif not initial_session_was_empty:
                # The session has been cleared.
                response_headers = MutableHeaders(scope=message)
                cookie = "%s=%s; %s" % (
                    self.session_cookie,
                    "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
                    self.security_flags,
                )
                if self.domain:  # pragma: no cover
                    cookie += f"; domain={self.domain}"
                response_headers.append("Set-Cookie", cookie)
            await send(message)
Example #16
0
        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                session: Session = scope["session"]
                if session.is_modified and not session.is_empty:
                    # We have session data to persist (data was changed, cleared, etc).
                    nonlocal session_id
                    session_id = await scope["session"].persist()

                    headers = MutableHeaders(scope=message)
                    header_value = "%s=%s; path=/; Max-Age=%d; %s" % (
                        self.session_cookie,
                        session_id,
                        self.max_age,
                        self.security_flags,
                    )
                    headers.append("Set-Cookie", header_value)
                elif session.is_loaded and session.is_empty:
                    # no interactions to session were done
                    headers = MutableHeaders(scope=message)
                    header_value = "%s=%s; %s" % (
                        self.session_cookie,
                        "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;",
                        self.security_flags,
                    )
                    headers.append("Set-Cookie", header_value)
            await send(message)
Example #17
0
 async def send_wrapper(message: Message) -> None:
     """Add the session ``Set-Cookie`` header to the response start, then forward.

     A non-empty session is serialised, base64-encoded and signed into the
     cookie; ``Max-Age`` is only included when ``self.max_age`` is truthy.
     A session that was non-empty initially but is empty now gets an
     already-expired cookie. Both cookies use ``self.path``.
     """
     if message["type"] != "http.response.start":
         await send(message)
         return

     if scope["session"]:
         # We have session data to persist.
         serialized = json.dumps(scope["session"]).encode("utf-8")
         signed = self.signer.sign(b64encode(serialized))
         response_headers = MutableHeaders(scope=message)
         max_age_part = f"Max-Age={self.max_age}; " if self.max_age else ""
         cookie = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(  # noqa E501
             session_cookie=self.session_cookie,
             data=signed.decode("utf-8"),
             path=self.path,
             max_age=max_age_part,
             security_flags=self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     elif not initial_session_was_empty:
         # The session has been cleared.
         response_headers = MutableHeaders(scope=message)
         cookie = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(  # noqa E501
             session_cookie=self.session_cookie,
             data="null",
             path=self.path,
             expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
             security_flags=self.security_flags,
         )
         response_headers.append("Set-Cookie", cookie)
     await send(message)
Example #18
0
    async def send(self, message, send=None, origin=None):
        """Add the CORS headers to the response start message and forward it.

        When only specific origins are allowed and ``origin`` is one of them,
        it is mirrored back in ``Access-Control-Allow-Origin`` before
        ``self.simple_headers`` are merged in. Other message types pass through.
        """
        if message["type"] == "http.response.start":
            message.setdefault("headers", [])
            # NOTE(review): the header list is passed positionally here, while
            # other code uses ``raw=`` / ``scope=`` - confirm this matches the
            # MutableHeaders signature of the Starlette version in use.
            response_headers = MutableHeaders(message["headers"])

            # If we only allow specific origins, then we have to mirror back
            # the Origin header in the response.
            mirror_origin = (
                not self.allow_all_origins
                and self.is_allowed_origin(origin=origin)
            )
            if mirror_origin:
                response_headers["Access-Control-Allow-Origin"] = origin
            response_headers.update(self.simple_headers)

        await send(message)
Example #19
0
                async def wrapped_send_bfyaml(response):
                    nonlocal body, agent_response

                    try:
                        if agent_response:  # send response if signature is validated
                            if response.get("type") == "http.response.start":
                                _add_header(
                                    response, agent_response[0],
                                    agent_response[1]
                                )

                                # We can add headers as many as we want but it was
                                # not possible to mutate an existing header without
                                # using below approach.
                                # override the Content-Length received from the original
                                # Response. Note: MutableHeaders is present in the minimum
                                # Starlette version used in minimum FastAPI version (0.51.0)
                                from starlette.datastructures import MutableHeaders
                                headers = MutableHeaders(
                                    raw=response["headers"]
                                )
                                headers['Content-Length'] = str(len(body))
                            elif response.get("type") == "http.response.body":
                                response["body"] = body

                        await send(response)
                    except Exception as e:
                        log.exception(e)
Example #20
0
    async def send_with_gzip(self, message: Message) -> None:
        """Forward response messages, GZip-compressing the body where applicable.

        The start message is held back until the first body message shows
        whether the response is small (sent as is), complete (compressed in
        one go) or streamed (compressed chunk by chunk).
        """
        kind = message["type"]

        if kind == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
            return

        if kind != "http.response.body":
            return

        chunk = message.get("body", b"")
        has_more = message.get("more_body", False)

        if self.started:
            # Remaining body in streaming GZip response.
            self.gzip_file.write(chunk)
            if not has_more:
                self.gzip_file.close()

            # Hand over what has been compressed so far and reset the buffer.
            message["body"] = self.gzip_buffer.getvalue()
            self.gzip_buffer.seek(0)
            self.gzip_buffer.truncate()

            await self.send(message)
            return

        self.started = True

        if len(chunk) < self.minimum_size and not has_more:
            # Don't apply GZip to small outgoing responses.
            await self.send(self.initial_message)
            await self.send(message)
            return

        if not has_more:
            # Standard GZip response.
            self.gzip_file.write(chunk)
            self.gzip_file.close()
            compressed = self.gzip_buffer.getvalue()

            response_headers = MutableHeaders(raw=self.initial_message["headers"])
            response_headers["Content-Encoding"] = "gzip"
            response_headers["Content-Length"] = str(len(compressed))
            response_headers.add_vary_header("Accept-Encoding")
            message["body"] = compressed
        else:
            # Initial body in streaming GZip response.
            response_headers = MutableHeaders(raw=self.initial_message["headers"])
            response_headers["Content-Encoding"] = "gzip"
            response_headers.add_vary_header("Accept-Encoding")
            del response_headers["Content-Length"]

            self.gzip_file.write(chunk)
            message["body"] = self.gzip_buffer.getvalue()
            self.gzip_buffer.seek(0)
            self.gzip_buffer.truncate()

        await self.send(self.initial_message)
        await self.send(message)
 async def send_wrapper(self, msg: Message):
     """Append the rate-limit headers to the response start message and forward it."""
     if msg["type"] == "http.response.start":
         rate_limit = self.rl_res
         response_headers = MutableHeaders(scope=msg)
         # NOTE(review): X-Rate-Limit-Limit is filled from consumed_points -
         # confirm this is the intended value.
         response_headers.append(
             "X-Rate-Limit-Limit", str(rate_limit.consumed_points))
         response_headers.append(
             "X-Rate-Limit-Remaining", str(rate_limit.remaining_points))
         response_headers.append(
             "X-Rate-Limit-Reset", str(rate_limit.ms_before_next))
     await self.send(msg)
Example #22
0
    async def send_with_brotli(self, message: Message) -> None:
        """Apply compression using brotli.

        The start message is held back until the first body message shows
        whether the response is small (sent as is), complete (compressed in
        one go) or streamed (compressed chunk by chunk).
        """
        kind = message["type"]

        if kind == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
            return

        if kind != "http.response.body":
            return

        chunk = message.get("body", b"")
        has_more = message.get("more_body", False)

        if self.started:
            # Remaining body in streaming Brotli response.
            self.br_buffer.write(
                self.br_file.process(chunk) + self.br_file.flush())
            if not has_more:
                # Last chunk: finish the stream and close the buffer.
                self.br_buffer.write(self.br_file.finish())
                message["body"] = self.br_buffer.getvalue()
                self.br_buffer.close()
            else:
                # Hand over what has been compressed so far and reset the buffer.
                message["body"] = self.br_buffer.getvalue()
                self.br_buffer.seek(0)
                self.br_buffer.truncate()
            await self.send(message)
            return

        self.started = True

        if len(chunk) < self.minimum_size and not has_more:
            # Don't apply Brotli to small outgoing responses.
            await self.send(self.initial_message)
            await self.send(message)
            return

        if not has_more:
            # Standard Brotli response.
            compressed = self.br_file.process(chunk) + self.br_file.finish()
            response_headers = MutableHeaders(raw=self.initial_message["headers"])
            response_headers["Content-Encoding"] = "br"
            response_headers["Content-Length"] = str(len(compressed))
            response_headers.add_vary_header("Accept-Encoding")
            message["body"] = compressed
        else:
            # Initial body in streaming Brotli response.
            response_headers = MutableHeaders(raw=self.initial_message["headers"])
            response_headers["Content-Encoding"] = "br"
            response_headers.add_vary_header("Accept-Encoding")
            del response_headers["Content-Length"]
            self.br_buffer.write(
                self.br_file.process(chunk) + self.br_file.flush())

            message["body"] = self.br_buffer.getvalue()
            self.br_buffer.seek(0)
            self.br_buffer.truncate()

        await self.send(self.initial_message)
        await self.send(message)
Example #23
0
    async def send_with_caching(self, message: Message) -> None:
        """Patch ``Cache-Control`` on the response start message, then forward it.

        The headers are patched on a copy of the message's header list, which
        then replaces ``message["headers"]``. Other message types pass through.
        """
        if message["type"] != "http.response.start":
            await self.send(message)
            return

        logger.trace(f"patch_cache_control {kvformat(**self.kwargs)}")
        patched = MutableHeaders(raw=list(message["headers"]))
        patch_cache_control(patched, **self.kwargs)
        message["headers"] = patched.raw

        await self.send(message)
Example #24
0
 async def body(self) -> bytes:
     """Return the request body, zlib-decompressed when possible.

     The result is cached on ``self._body``. After a successful
     decompression ``Content-Length`` in the scope headers is set to the
     decompressed size; on ``zlib.error`` the body is kept as received.
     """
     if hasattr(self, "_body"):
         return self._body

     received = await super().body()
     try:
         inflated = zlib.decompress(received)
     except zlib.error:
         self._body = received
     else:
         self._body = inflated
         request_headers = MutableHeaders(raw=self.scope["headers"])
         request_headers["Content-Length"] = str(len(inflated))
     return self._body
Example #25
0
    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        """Record the msgpack conversions needed and run the app with the wrappers.

        Decoding is enabled when the request ``content-type`` mentions
        ``application/x-msgpack``; encoding when that value is among the
        ``accept`` header values.
        """
        msgpack_type = "application/x-msgpack"
        request_headers = MutableHeaders(scope=scope)

        content_type = request_headers.get("content-type", "")
        self.should_decode_from_msgpack_to_json = msgpack_type in content_type

        # Take an initial guess, although we eventually may not
        # be able to do the conversion.
        accepted = request_headers.getlist("accept")
        self.should_encode_from_json_to_msgpack = msgpack_type in accepted

        self.receive = receive
        self.send = send

        if self.should_decode_from_msgpack_to_json:
            # We're going to present JSON content to the application,
            # so rewrite `Content-Type` for consistency and compliance
            # with possible downstream security checks in some frameworks.
            # See: https://github.com/florimondmanca/msgpack-asgi/issues/23
            request_headers["content-type"] = "application/json"

        await self.app(
            scope, self.receive_with_msgpack, self.send_with_msgpack)
Example #26
0
    async def send(self, message: Message, send: Send,
                   request_headers: Headers) -> None:
        """Add the CORS headers to the response start message and forward it.

        ``self.simple_headers`` are merged in first; the request origin is
        then allowed explicitly when required. Other message types pass through.
        """
        if message["type"] != "http.response.start":
            await send(message)
            return

        message.setdefault("headers", [])
        response_headers = MutableHeaders(scope=message)
        response_headers.update(self.simple_headers)
        origin = request_headers["Origin"]
        has_cookie = "cookie" in request_headers

        if self.allow_all_origins:
            # If request includes any cookie headers, then we must respond
            # with the specific origin instead of '*'.
            needs_explicit_origin = has_cookie
        else:
            # If we only allow specific origins, then we have to mirror back
            # the Origin header in the response.
            needs_explicit_origin = self.is_allowed_origin(origin=origin)

        if needs_explicit_origin:
            self.allow_explicit_origin(response_headers, origin)

        await send(message)
Example #27
0
    async def send(self, message, send=None, request_headers=None):
        """Add the CORS headers to the response start message and forward it.

        Messages of any other type are forwarded untouched. The headers from
        ``self.simple_headers`` are merged into the response, with the request
        origin substituted where a specific origin has to be returned.

        :param message: the ASGI message being sent
        :param send: the downstream ASGI send callable
        :param request_headers: headers of the request; must contain ``Origin``
        """
        if message["type"] != "http.response.start":
            await send(message)
            return

        message.setdefault("headers", [])
        headers = MutableHeaders(message["headers"])
        origin = request_headers["Origin"]
        has_cookie = "cookie" in request_headers

        # Work on a per-response copy: writing this request's origin (or a
        # derived Vary value) into self.simple_headers would leak it into
        # every later response served by this middleware instance.
        response_headers = dict(self.simple_headers)

        # If request includes any cookie headers, then we must respond
        # with the specific origin instead of '*'.
        if self.allow_all_origins and has_cookie:
            response_headers["Access-Control-Allow-Origin"] = origin

        # If we only allow specific origins, then we have to mirror back
        # the Origin header in the response.
        elif not self.allow_all_origins and self.is_allowed_origin(
                origin=origin):
            headers["Access-Control-Allow-Origin"] = origin
            if "vary" in headers:
                response_headers["Vary"] = f"{headers.get('vary')}, Origin"
        headers.update(response_headers)
        await send(message)
Example #28
0
def patch_cache_control(headers: MutableHeaders, **kwargs: typing.Any) -> None:
    """
    Merge the keyword arguments into the existing Cache-Control header.

    Underscores in keyword names become dashes. A value of ``True`` emits the
    bare directive, ``False`` drops it, anything else is emitted as
    ``name=value``. When both the header and the keywords carry a max-age,
    the smaller one wins. ``public`` and ``private`` raise
    ``NotImplementedError``. If nothing is left, the header is deleted.
    """
    directives: typing.Dict[str, typing.Any] = {}
    for item in parse_http_list(headers.get("Cache-Control", "")):
        pieces = item.split("=")
        if len(pieces) == 2:
            directives[pieces[0]] = pieces[1]
        else:
            # Not exactly one "=": keep the whole item as a bare flag.
            directives[item] = True

    if "max-age" in directives and "max_age" in kwargs:
        kwargs["max_age"] = min(int(directives["max-age"]),
                                kwargs["max_age"])

    for unsupported in ("public", "private"):
        if unsupported in kwargs:
            raise NotImplementedError(
                f"The '{unsupported}' cache control directive isn't supported yet.")

    directives.update(
        (name.replace("_", "-"), setting) for name, setting in kwargs.items())

    rendered = ", ".join(
        name if setting is True else f"{name}={setting}"
        for name, setting in directives.items()
        if setting is not False
    )

    if rendered:
        headers["Cache-Control"] = rendered
    else:
        del headers["Cache-Control"]
Example #29
0
def test_mutable_headers():
    """Item assignment, ``setdefault`` and deletion behave like a mutable mapping."""
    headers = MutableHeaders()
    assert dict(headers) == {}

    # Assignment creates the entry, a second assignment overwrites it.
    headers["a"] = "1"
    assert dict(headers) == {"a": "1"}
    headers["a"] = "2"
    assert dict(headers) == {"a": "2"}

    # setdefault leaves an existing entry alone and adds a missing one.
    headers.setdefault("a", "3")
    assert dict(headers) == {"a": "2"}
    headers.setdefault("b", "4")
    assert dict(headers) == {"a": "2", "b": "4"}

    del headers["a"]
    assert dict(headers) == {"b": "4"}
 async def send_wrapper(message: Message) -> None:
     """Add the session ``Set-Cookie`` header to the response start, then forward.

     A non-empty session is encoded by ``self.backend`` into the cookie; a
     session that was non-empty initially but is empty now gets an
     already-expired cookie instead.
     """
     if message['type'] != 'http.response.start':
         await send(message)
         return

     if scope['session']:
         # We have session data to persist.
         response_headers = MutableHeaders(scope=message)
         cookie = '%s=%s; path=/; Max-Age=%d; %s' % (
             self.session_cookie,
             self.backend.encode(scope['session']),
             self.max_age,
             self.security_flags,
         )
         response_headers.append('Set-Cookie', cookie)
     elif not initial_session_was_empty:
         # The session has been cleared.
         response_headers = MutableHeaders(scope=message)
         cookie = '%s=%s; %s' % (
             self.session_cookie,
             'null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;',
             self.security_flags,
         )
         response_headers.append('Set-Cookie', cookie)
     await send(message)