Exemplo n.º 1
0
def test_headers():
    """Exercise ``Headers`` built from raw byte pairs and from a mapping."""
    headers = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")])

    # Membership checks ignore case.
    for name in ("a", "A", "b", "B"):
        assert name in headers
    assert "c" not in headers

    # Scalar lookups return the first value of a repeated key.
    assert headers["a"] == "123"
    assert headers.get("a") == "123"
    assert headers.get("nope", default=None) is None
    assert headers.getlist("a") == ["123", "456"]

    # Views keep duplicates in their original order.
    assert headers.keys() == ["a", "a", "b"]
    assert headers.values() == ["123", "456", "789"]
    assert headers.items() == [("a", "123"), ("a", "456"), ("b", "789")]
    assert list(headers) == ["a", "a", "b"]
    assert dict(headers) == {"a": "123", "b": "789"}

    expected_repr = "Headers(raw=[(b'a', b'123'), (b'a', b'456'), (b'b', b'789')])"
    assert repr(headers) == expected_repr

    # Equality does not depend on pair order; a plain list is not equal.
    assert headers == Headers(raw=[(b"a", b"123"), (b"b", b"789"), (b"a", b"456")])
    assert headers != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")]

    # Built from a mapping: case-insensitive lookup, raw holds encoded pairs.
    from_mapping = Headers({"a": "123", "b": "789"})
    assert from_mapping["A"] == "123"
    assert from_mapping["B"] == "789"
    assert from_mapping.raw == [(b"a", b"123"), (b"b", b"789")]
    assert repr(from_mapping) == "Headers({'a': '123', 'b': '789'})"
Exemplo n.º 2
0
    async def __call__(self, scope, receive, send) -> None:
        """Apply CORS handling to HTTP requests that carry an ``Origin`` header.

        Non-HTTP scopes and requests without an origin go to the wrapped app.
        Preflight requests (``OPTIONS`` plus ``Access-Control-Request-Method``)
        are answered by ``self.preflight_response``; every other cross-origin
        request is handled by ``self.simple_response``.
        """
        async def pass_through() -> None:
            # ``self.app`` is awaited to obtain a handler object, which is then
            # driven with the receive/send channels.
            handler = await self.app(scope, receive, send)
            await handler.__call__(receive, send)

        if scope["type"] != "http":  # pragma: no cover
            await pass_through()
            return

        method = scope["method"]
        request_headers = Headers(scope=scope)
        if request_headers.get("origin") is None:
            await pass_through()
            return

        is_preflight = (
            method == "OPTIONS"
            and "access-control-request-method" in request_headers
        )
        if is_preflight:
            preflight = self.preflight_response(request_headers=request_headers)
            await preflight(scope, receive, send)
        else:
            await self.simple_response(
                scope, receive, send, request_headers=request_headers
            )
Exemplo n.º 3
0
    def client_response_hook(self, span, resp_data):
        """Record response status, headers and (optionally) body on *span*.

        Called once per response phase, selected by ``resp_data['type']``:
        ``http.response.start`` records the status and headers and stores
        whether the body should be captured; ``http.response.body`` records a
        truncated body only if that flag was set.
        """
        phase = resp_data['type']

        if phase == "http.response.start":
            set_status_code(span, resp_data["status"])
            response_headers = dict(Headers(raw=resp_data['headers']))
            capture_body = self._capture_headers(
                self._process_response_headers,
                self.HTTP_RESPONSE_HEADER_PREFIX,
                span,
                response_headers,
                self._process_response_body,
            )
            span.set_attribute('hypertrace.capture', capture_body)
            return

        if phase != 'http.response.body':
            return
        if not span.attributes.get('hypertrace.capture'):
            return

        raw_body = resp_data['body']
        # Bytes are decoded leniently; undecodable bytes become \x escapes.
        body_text = (
            raw_body.decode('UTF8', 'backslashreplace')
            if isinstance(raw_body, bytes)
            else raw_body
        )
        span.set_attribute('http.response.body',
                           self.grab_first_n_bytes(body_text))
Exemplo n.º 4
0
async def get_sequence(checksum: str,
                       start: int = 0,
                       end: int = 0,
                       accept: str = ""):
    """
    Return Refget sequence based on checksum value.

    int start: Start point of the sequence defined in checksum.
    int end: End point of the sequence defined in checksum.
    str accept: Forwarded as the ``accept`` query parameter.

    ``start``/``end`` are only forwarded when ``start < end``.
    Raises HTTPException (404) when the upstream lookup returns an empty
    string. If the upstream request itself raises, the error is logged at
    DEBUG level and ``None`` is returned (best-effort behaviour).
    """
    headers = Headers()
    params = {"accept": accept}
    if start < end:
        params["start"] = start
        params["end"] = end
    url_path = "sequence/" + checksum

    try:
        result = await create_request_coroutine(
            url_list=metadata_url_list(checksum),
            url_path=url_path,
            headers=headers,
            params=params,
        )
    except Exception as e:
        logger.log("DEBUG", "Unhandled exception in get_sequence " + str(e))
        return None

    # The exception must be *raised* (returning it hands the exception object
    # back as a normal result) and must live outside the try above, otherwise
    # the broad handler would swallow it.
    if result == "":
        raise HTTPException(status_code=HTTP_404_NOT_FOUND,
                            detail="Not Found")
    return result
Exemplo n.º 5
0
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Reject requests whose ``Host`` header is not allowed.

        HTTP/WebSocket scopes are checked against ``self.allowed_hosts``
        (exact names, or ``*``-prefixed suffix patterns) unless
        ``self.allow_any`` is set; the port is ignored. A rejected host whose
        ``www.`` form is allowed is redirected there when
        ``self.www_redirect`` is on; otherwise a 400 plain-text response is
        sent.
        """
        if self.allow_any or scope["type"] not in ("http", "websocket"):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        hostname = Headers(scope=scope).get("host", "").split(":")[0]
        host_allowed = False
        www_variant_allowed = False
        for allowed in self.allowed_hosts:
            if hostname == allowed or (
                allowed.startswith("*") and hostname.endswith(allowed[1:])
            ):
                host_allowed = True
                break
            if "www." + hostname == allowed:
                www_variant_allowed = True

        if host_allowed:
            await self.app(scope, receive, send)
            return

        response: Response
        if found_www := (www_variant_allowed and self.www_redirect):
            pass
        if www_variant_allowed and self.www_redirect:
            current = URL(scope=scope)
            target = current.replace(netloc="www." + current.netloc)
            response = RedirectResponse(url=str(target))
        else:
            response = PlainTextResponse("Invalid host header", status_code=400)
        await response(scope, receive, send)
Exemplo n.º 6
0
 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
     """Pass the request URL and headers to ``log_debug``, then call the app."""
     request_headers = Headers(scope=scope)
     request_url = URL(scope=scope)
     log_debug(
         message=f'[API Server] {request_url}',
         data={'headers': f'{request_headers}'},
     )
     await self.app(scope, receive, send)
Exemplo n.º 7
0
def test_player_json(service, steam_id, mod_date):
    """Player JSON, the rendered HTML context, and conditional GET agree."""
    json_url = "/player/{0}.json".format(steam_id)

    # The JSON endpoint returns exactly the stored sample.
    json_resp = service.get(json_url)
    actual = json_resp.json()
    expected = read_json_sample("player_{}".format(steam_id))
    assert actual == expected

    # The HTML page is rendered from a context holding the same data plus
    # the request and steam_id entries.
    html_resp = service.get("/player/{0}".format(steam_id))
    assert html_resp.template.name == "player_stats.html"
    context = html_resp.context
    for key in ("request", "steam_id"):
        assert key in context
    assert context["steam_id"] == steam_id
    for key in ("request", "steam_id"):
        del context[key]
    assert context == expected

    assert html_resp.headers["last-modified"] == mod_date

    # Revalidation with the Last-Modified date; the positional 304 appears to
    # be the expected status for service.get (confirm against the fixture).
    service.get(json_url, 304,
                headers=Headers({"If-Modified-Since": mod_date}))
Exemplo n.º 8
0
 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
     """Send HTTP requests that accept gzip through ``GZipResponder``.

     The check is a plain substring test on ``Accept-Encoding``; all other
     requests and non-HTTP scopes go straight to the wrapped app.
     """
     wants_gzip = (
         scope["type"] == "http"
         and "gzip" in Headers(scope=scope).get("Accept-Encoding", "")
     )
     if wants_gzip:
         gzip_responder = GZipResponder(self.app, self.minimum_size)
         await gzip_responder(scope, receive, send)
     else:
         await self.app(scope, receive, send)
Exemplo n.º 9
0
 async def get_response(self, path: str, scope: Scope) -> Response:
     """Add a content-hash ``etag`` to file responses and honour revalidation.

     Non-file responses are returned untouched. For a ``FileResponse`` the
     etag comes from ``_hash_file``; if ``is_not_modified`` accepts the
     request headers, a ``NotModifiedResponse`` is returned instead.
     """
     response = await super().get_response(path, scope)
     if not isinstance(response, FileResponse):
         return response

     response.headers["etag"] = await _hash_file(response.path)
     if not self.is_not_modified(response.headers, Headers(scope=scope)):
         return response
     return NotModifiedResponse(response.headers)
Exemplo n.º 10
0
 def __init__(self, function: Callable[..., Any],
              callback: Union[Dict[str, Any], bool]) -> None:
     """Store *function* and set up per-instance lookup state.

     ``callback`` is passed to ``__prepare_and_validate_finders`` and the
     result kept as the attribute finders; headers, params and the
     invalid-callback marker start out empty.
     """
     self.__function = function
     self.__attribute_finders = self.__prepare_and_validate_finders(callback)
     # Empty containers, presumably populated by other methods of the class.
     self.__headers = Headers()
     self.__params: Dict[str, str] = dict()
     self.__invalid_callback_object = ""
Exemplo n.º 11
0
 def __init__(self, request: HTTPConnection):
     """Initialise from an ``HTTPConnection`` and keep a reference to it.

     The parent is given the scope's method, the full URL string, ``None``
     and the request headers, in that order.
     """
     scope = request.scope
     method = scope["method"]
     url = str(URL(scope=scope))
     headers = Headers(scope=scope)
     super().__init__(method, url, None, headers)
     self._orig_request = request
Exemplo n.º 12
0
    def __call__(self, scope: Scope) -> ASGIInstance:
        """Return a 400 response when the ``Host`` header is not allowed.

        Only HTTP/WebSocket scopes are checked, and only when
        ``self.allow_any`` is false; the raw header value must appear
        verbatim in ``self.allowed_hosts``.
        """
        must_check = scope["type"] in ("http", "websocket") and not self.allow_any
        if must_check and Headers(scope=scope).get("host") not in self.allowed_hosts:
            return PlainTextResponse("Invalid host header", status_code=400)
        return self.app(scope)
Exemplo n.º 13
0
def test_partial():
    """A model built from ``Headers`` keeps only its declared fields."""
    from starlette.datastructures import Headers

    class Header(Model):
        token: str

    source = Headers({"token": "123", "test": "12345g"})
    header = Header(**source)

    assert header.token == "123"
    # The undeclared "test" entry does not appear in the model's dict.
    assert header.dict() == {"token": "123"}
Exemplo n.º 14
0
    async def asgi(self, receive: Receive, send: Send, scope: Scope, path: str) -> None:
        """Serve *path*: run the config check once, then send ``get_response``'s result."""
        if not self.config_checked:
            # The configuration check is deferred until the first request.
            await self.check_config()
            self.config_checked = True

        response = await self.get_response(path, scope["method"],
                                           Headers(scope=scope))
        await response(receive, send)
Exemplo n.º 15
0
    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Hand HTTP requests to ``self.response`` along with their headers.

        Non-HTTP scopes are forwarded to the wrapped app unchanged.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request_headers = Headers(scope=scope)
        await self.response(scope, receive, send,
                            request_headers=request_headers)
Exemplo n.º 16
0
 def is_not_modified(self, stat_headers: typing.Dict[str, str]) -> bool:
     """Return True when the request's conditional headers match *stat_headers*.

     An ``If-None-Match`` equal to the stored ``etag`` wins first. Otherwise
     ``If-Modified-Since`` is compared with ``last-modified``. If either date
     cannot be parsed by ``parsedate``, the resource is treated as modified
     (False) instead of raising.
     """
     etag = stat_headers["etag"]
     last_modified = stat_headers["last-modified"]
     req_headers = Headers(scope=self.scope)
     if etag == req_headers.get("if-none-match"):
         return True
     if "if-modified-since" not in req_headers:
         return False
     # parsedate() yields None for malformed input; the header comes from the
     # client, so a bad value must not reach the tuple comparison (TypeError).
     if_modified_since = parsedate(req_headers["if-modified-since"])
     last_modified_time = parsedate(last_modified)
     if if_modified_since is None or last_modified_time is None:
         return False
     return if_modified_since >= last_modified_time
def test_check_auth_error_auth_error():
    """Test: check_auth(request: Request) -> None
    Error: Auth error."""
    request = Request({'type': 'http'})
    # Inject the headers directly, with a key the test expects to be rejected.
    request._headers = Headers(headers={'X-Key': 'fake-key'})

    with pytest.raises(HTTPException) as exc_info:
        check_auth(request)

    error = exc_info.value
    assert error.status_code == 400
    assert error.detail == 'Auth error'
Exemplo n.º 18
0
    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Forward requests carrying any header in ``self.required_headers``.

        Scopes other than HTTP/WebSocket are always forwarded. Requests
        without a required header are logged and answered with 403 Forbidden.
        """
        if scope["type"] not in ("http", "websocket"):
            return await self.app(scope, receive, send)

        headers = Headers(scope=scope)
        if any(name in self.required_headers for name in headers):
            return await self.app(scope, receive, send)

        logging.error("No App Engine task header found.  Headers: %s", headers)
        forbidden = Response(None, HTTPStatus.FORBIDDEN)
        return await forbidden(scope, receive, send)
Exemplo n.º 19
0
def test_deprecated_player_json(service, steam_id, mod_date):
    """Deprecated player JSON endpoint: body, Last-Modified and revalidation."""
    url = "/deprecated/player/{0}.json".format(steam_id)

    resp = service.get(url)
    assert resp.json() == read_json_sample(
        "deprecated_player_{0}".format(steam_id))

    assert resp.headers["last-modified"] == mod_date

    # Revalidate the same deprecated URL whose Last-Modified was just
    # asserted. The positional 304 appears to be the expected status for
    # service.get (confirm against the fixture).
    service.get(
        url,
        304,
        headers=Headers({"If-Modified-Since": mod_date}),
    )
Exemplo n.º 20
0
        async def send_with_tracing(message: Message) -> None:
            """Tag the active span from a response-start message, then send it.

            Only when a span exists and the message type is
            ``http.response.start`` are the status code and response headers
            recorded, each only if present in the message.
            """
            span = self.tracer.current_span()
            if not span or message.get("type") != "http.response.start":
                await send(message)
                return

            if "status" in message:
                code: int = message["status"]
                span.set_tag(http_tags.STATUS_CODE, str(code))
            if "headers" in message:
                store_response_headers(
                    Headers(raw=message["headers"]), span, config.asgi
                )
            await send(message)
Exemplo n.º 21
0
 def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
     """Match the request's host (port stripped) against ``self.host_regex``.

     On a match, named groups are converted with ``self.param_convertors``
     and merged over any existing ``path_params`` in the child scope.
     """
     hostname = Headers(scope=scope).get("host", "").split(":")[0]
     host_match = self.host_regex.match(hostname)
     if not host_match:
         return Match.NONE, {}

     converted = {
         name: self.param_convertors[name].convert(raw_value)
         for name, raw_value in host_match.groupdict().items()
     }
     path_params = dict(scope.get("path_params", {}))
     path_params.update(converted)
     return Match.FULL, {"path_params": path_params, "endpoint": self.app}
Exemplo n.º 22
0
    def server_request_hook(self, span, req_data, body):
        """Capture request attributes on *span* and evaluate blocking filters.

        Returns False when ``Registry().apply_filters`` reports a truthy block
        result, True otherwise.
        """
        span.update_name(f"{req_data['method']} {span.name}")
        headers = dict(Headers(raw=req_data['headers']))
        request_url = str(Request(req_data).url)
        self.generic_request_handler(headers, body, span)

        should_block = Registry().apply_filters(
            span, request_url, headers, body, TYPE_HTTP
        )
        if should_block:
            logger.debug('should block evaluated to true, aborting with 403')
        return not should_block
Exemplo n.º 23
0
 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
     """Decide the msgpack/JSON conversions for this request, then call the app.

     Request decoding is enabled when Content-Type mentions
     ``application/x-msgpack``; response encoding is initially enabled when
     an Accept header value equals it exactly.
     """
     headers = Headers(scope=scope)
     content_type = headers.get("content-type", "")
     self.should_decode_from_msgpack_to_json = (
         "application/x-msgpack" in content_type
     )
     # Take an initial guess, although we eventually may not
     # be able to do the conversion.
     accept_values = headers.getlist("accept")
     self.should_encode_from_json_to_msgpack = (
         "application/x-msgpack" in accept_values
     )
     # NOTE(review): per-request flags and channels are stored on self; if one
     # instance serves concurrent requests they could interleave — confirm.
     self.receive, self.send = receive, send
     await self.app(scope, self.receive_with_msgpack, self.send_with_msgpack)
Exemplo n.º 24
0
    def __call__(self, scope: Scope) -> ASGIInstance:
        """Route cross-origin HTTP requests to the CORS handlers.

        Requests with an ``Origin`` header are sent to the preflight handler
        (``OPTIONS`` plus ``Access-Control-Request-Method``) or to a partial
        of ``simple_response``; everything else goes to the wrapped app.
        """
        if scope["type"] != "http":
            return self.app(scope)

        method = scope["method"]
        headers = Headers(scope=scope)
        if headers.get("origin") is None:
            return self.app(scope)

        if method == "OPTIONS" and "access-control-request-method" in headers:
            return self.preflight_response(request_headers=headers)
        return functools.partial(
            self.simple_response, scope=scope, request_headers=headers
        )
Exemplo n.º 25
0
    def __call__(self, scope: Scope) -> ASGIInstance:
        """Return a 400 response when the ``Host`` header is not allowed.

        For HTTP/WebSocket scopes (unless ``self.allow_any``), the host with
        its port stripped must equal an allowed pattern or end with the
        suffix of a ``*``-prefixed one.
        """
        if scope["type"] in ("http", "websocket") and not self.allow_any:
            hostname = Headers(scope=scope).get("host", "").split(":")[0]
            host_ok = any(
                hostname == pattern
                or (pattern.startswith("*") and hostname.endswith(pattern[1:]))
                for pattern in self.allowed_hosts
            )
            if not host_ok:
                return PlainTextResponse("Invalid host header", status_code=400)

        return self.app(scope)
Exemplo n.º 26
0
        async def send_with_tracing(message: Message) -> None:
            """Record status and headers of ``http.response.start`` on the span.

            Messages are always forwarded; tagging happens only when a current
            span exists.
            """
            span = self.tracer.current_span()
            if span and message["type"] == "http.response.start":
                status_code: int = message["status"]
                store_response_headers(
                    Headers(raw=message["headers"]), span, config.asgi
                )
                span.set_tag(http_tags.STATUS_CODE, str(status_code))
            await send(message)
Exemplo n.º 27
0
 async def __call__(self, scope: Scope, receive: Receive,
                    send: Send) -> None:
     """Wrap HTTP requests in a ``CompressionResponder``.

     The comma-separated ``Accept-Encoding`` tokens (whitespace-stripped,
     blanks dropped) are passed to the responder as a set. Non-HTTP scopes
     go straight to the wrapped app.
     """
     if scope["type"] != "http":
         await self.app(scope, receive, send)
         return

     headers = Headers(scope=scope)
     # Strip before filtering so inputs like "gzip, " or "gzip,,br" do not
     # leave an empty-string token in the set.
     accepted = {
         token
         for token in (
             part.strip()
             for part in headers.get("accept-encoding", "").split(",")
         )
         if token
     }
     responder = CompressionResponder(self.app, self.minimum_size,
                                      accepted,
                                      self.compression_registry)
     await responder(scope, receive, send)
Exemplo n.º 28
0
    def file_response(
        self,
        full_path: str,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        """Build a ``FileResponse`` for *full_path*.

        A ``NotModifiedResponse`` carrying the file response's headers is
        returned instead when ``is_not_modified`` accepts the request headers.
        """
        method = scope["method"]
        request_headers = Headers(scope=scope)

        file_resp = FileResponse(
            full_path,
            status_code=status_code,
            stat_result=stat_result,
            method=method,
        )
        if not self.is_not_modified(file_resp.headers, request_headers):
            return file_resp
        return NotModifiedResponse(file_resp.headers)
Exemplo n.º 29
0
 async def __call__(self, scope: Scope, receive: Receive,
                    send: Send) -> None:
     """Send HTTP requests that accept ``br`` through ``BrotliResponder``.

     The check is a substring test on ``Accept-Encoding``; other requests and
     non-HTTP scopes go straight to the wrapped app.
     """
     if scope["type"] != "http" or "br" not in Headers(scope=scope).get(
             "Accept-Encoding", ""):
         await self.app(scope, receive, send)
         return

     brotli_responder = BrotliResponder(
         self.app,
         self.quality,
         self.mode,
         self.lgwin,
         self.lgblock,
         self.minimum_size,
     )
     await brotli_responder(scope, receive, send)
Exemplo n.º 30
0
    def __init__(
        self,
        slices: Sequence[slice],
        *,
        headers: dict = None,
        content: Any = None,
        media_type: str = None,
    ) -> None:
        """Constructs a ByteRangesResponse.

        The response is always created with status 206 (Partial Content).

        Args:
            slices: The byte ranges of content to send. The stop attribute of
                each slice must not be negative. If Content-Length is not
                supplied, then the stop attribute of each slice must not be
                None.
            headers: The headers that would have been sent in a 200 (OK)
                response.
            content: The overall content to be split into parts. Must be
                convertible to bytes.
            media_type: The Content-Type of the overall content to be split
                into parts. Only required if Content-Type does not occur in
                headers.

        Raises:
            AssertionError: if ``slices`` is empty (only while assertions are
                enabled; the check is an ``assert``).
        """
        super().__init__(
            status_code=206,
            headers=headers,
            content=content,
            media_type=media_type,
        )
        # At least one range is required; note this is stripped under -O.
        assert slices
        self.slices = slices
        # Headers object built from the 200-style headers before the
        # range-specific overrides below; the helpers read from this one.
        self.headers_200 = Headers(self.headers)
        # Random multipart boundary and its latin-1 encoded byte form.
        self.boundary = str(uuid.uuid4())
        self.raw_boundary = self.boundary.encode("latin-1")
        if len(self.slices) == 1:
            # Single range: Content-Range for the normalized slice and a
            # Content-Length equal to its size.
            s = _normalize(self.headers_200, self.slices[0])
            self.headers["content-range"] = _get_content_range(
                self.headers_200, s)
            self.headers["content-length"] = str(s.stop - s.start)
        else:
            # Several ranges: multipart/byteranges body delimited by the
            # boundary; the total length is computed by _get_multipart_length.
            self.headers[
                "content-type"] = f'multipart/byteranges; boundary="{self.boundary}"'
            self.headers["content-length"] = str(
                _get_multipart_length(self.headers_200, self.raw_boundary,
                                      self.slices))