예제 #1
0
def test_mutable_headers_update_dict():
    """In-place ``|=`` with a plain dict adds the entry and keeps the type."""
    headers = MutableHeaders()
    headers |= {"a": "1"}

    # The augmented assignment must not degrade the object to a plain dict.
    assert isinstance(headers, MutableHeaders)

    # The same single entry is visible through every view of the headers.
    assert headers.raw == [(b"a", b"1")]
    assert headers.items() == [("a", "1")]
    assert dict(headers) == {"a": "1"}
예제 #2
0
def test_mutable_headers_merge():
    """Binary ``|`` of two MutableHeaders yields a MutableHeaders with the merged entry."""
    merged = MutableHeaders() | MutableHeaders({"a": "1"})

    # The result of the merge must still be a MutableHeaders instance.
    assert isinstance(merged, MutableHeaders)

    # The same single entry is visible through every view of the headers.
    assert merged.raw == [(b"a", b"1")]
    assert merged.items() == [("a", "1")]
    assert dict(merged) == {"a": "1"}
예제 #3
0
    async def build_response(
        self,
        message: Any,
        response: Response,
        cache_backend: Union[BaseCacheBackend, BaseAsyncCacheBackend],
    ) -> Any:
        """
        Set the ``Cache-Control`` header on the response and, when the request
        asks for it, store the message in the cache backend.

        :param message: The response from the ASGI application
        :param response: the response object whose headers are updated
        :param cache_backend: The cache backend used to store responses;
            ``set`` is awaited for ``BaseAsyncCacheBackend`` instances and
            called synchronously otherwise.
        :return: The message returned from the ASGI application, unchanged.
        """
        cache_headers = MutableHeaders()
        # NOTE(review): the header value is the bare TTL, not a
        # "max-age=<ttl>" directive -- confirm this is intended.
        cache_headers["Cache-Control"] = str(self.ttl)
        response.headers.update(dict(cache_headers.items()))

        try:
            update_requested = self.request.state.update_cache
        except AttributeError:
            # No ``update_cache`` flag on the request state: nothing to store.
            update_requested = False

        if update_requested:
            key = self.key_func(self.request)
            if isinstance(cache_backend, BaseAsyncCacheBackend):
                await cache_backend.set(key, message, self.ttl)
            else:
                cache_backend.set(key, message, self.ttl)

        return message
예제 #4
0
    async def _request(
        self,
        method: str,
        path: str,
        queries: Optional[Params] = None,
        headers: Optional[Headers] = None,
        cookies: Optional[dict[str, str]] = None,
        json: Optional[Mapping] = None,
        files: Optional[Mapping] = None,
        form: Optional[Mapping] = None,
        content: Optional[bytes] = None,
        timeout: Optional[float] = None,
        allow_redirects=True,
    ) -> HttpTestResponse:
        """
        Send a request through the underlying client and wrap the result.

        The request is first sent with redirects disabled. If the response
        has a 3xx status *and* carries a ``Location`` header, and
        ``allow_redirects`` is truthy, the same request (method, headers,
        body, params, cookies) is re-sent to that location with redirect
        following enabled.

        :param method: HTTP method name; upper-cased before sending.
        :param path: Target path/URL of the initial request.
        :param queries: Query parameters, passed through ``QueryParams``.
        :param headers: Either a mapping or an iterable of ``(name, value)``
            string pairs (encoded as latin-1).
        :param cookies: Cookies sent with the request. When not ``None`` the
            client's cookie jar is replaced with an empty ``Cookies()`` first.
        :param json: JSON body forwarded to the client.
        :param files: Mapping encoded as multipart form data; takes
            precedence over ``form`` and ``content`` and sets ``content-type``.
        :param form: Form data used as the body when ``files`` is empty.
        :param content: Raw body used when neither ``files`` nor ``form`` is set.
        :param timeout: Per-request timeout; falls back to ``self.timeout``
            when falsy. When the effective value is not ``None`` the whole
            exchange (including the redirect follow-up) runs inside
            ``timeout_ctx``.
        :param allow_redirects: Whether a redirect response is followed.
        :return: The final response wrapped in ``HttpTestResponse``.
        """
        # TODO: max_redirects?
        if isinstance(headers, Mapping):
            ci_headers = MutableHeaders(headers=headers)
        else:
            ci_headers = MutableHeaders(
                raw=[
                    (k.encode("latin-1"), v.encode("latin-1")) for k, v in headers or []
                ]
            )
        if files:
            file_list = [
                RequestField.from_tuples(key, value) for key, value in files.items()
            ]
            data, content_type = encode_multipart_formdata(file_list)
            ci_headers["content-type"] = content_type
        elif form:
            data = form
        else:
            data = content
        timeout = timeout or self.timeout
        if cookies is not None:
            self._client.cookies = Cookies()

        http_method = method.upper()
        # Arguments shared by the initial request and the redirect follow-up.
        request_kwargs = dict(
            headers=list(ci_headers.items()),
            data=data,
            params=QueryParams(queries or []),
            json=json,
            cookies=cookies,
            timeout=timeout,
        )

        async def send():
            """Send the request and follow a redirect response if allowed."""
            result = await self._client.request(
                http_method, path, follow_redirects=False, **request_kwargs
            )
            # Not every 3xx carries a Location header (e.g. 304 Not Modified);
            # only follow when there is actually somewhere to go, instead of
            # raising KeyError on the header lookup.
            if (
                allow_redirects
                and 300 <= result.status_code < 400
                and "location" in result.headers
            ):
                result = await self._client.request(
                    http_method,
                    result.headers["location"],
                    follow_redirects=allow_redirects,
                    **request_kwargs,
                )
            return result

        if timeout is not None:
            async with timeout_ctx(timeout):
                response = await send()
        else:
            response = await send()
        return HttpTestResponse(response)