class AsgiTestBase(SpanTestBase):
    """Base test case that drives an ASGI application in-process.

    ``setUp`` builds a fresh ASGI ``scope``; :meth:`seed_app` wraps an
    application in an ``ApplicationCommunicator``.  The remaining helpers
    send messages to, and collect messages from, that application by
    running the communicator's coroutines on the current event loop.
    """

    def setUp(self):
        super().setUp()

        # Fresh scope per test, populated by setup_testing_defaults().
        self.scope = {}
        setup_testing_defaults(self.scope)
        # Set by seed_app(); None means no application has been started.
        self.communicator = None

    def tearDown(self):
        """Wait for the seeded application to finish, then run base teardown."""
        try:
            if self.communicator is not None:
                self._run_until_complete(self.communicator.wait())
        finally:
            # Mirror setUp(): always run the parent's cleanup, even when the
            # application raised while being awaited above.
            super().tearDown()

    def _run_until_complete(self, awaitable):
        """Run *awaitable* on the current event loop and return its result."""
        return asyncio.get_event_loop().run_until_complete(awaitable)

    def seed_app(self, app):
        """Start *app* with ``self.scope`` and keep its communicator."""
        self.communicator = ApplicationCommunicator(app, self.scope)

    def send_input(self, message):
        """Deliver one ASGI *message* to the seeded application."""
        self._run_until_complete(self.communicator.send_input(message))

    def send_default_request(self):
        """Send an ``http.request`` message with an empty body."""
        self.send_input({"type": "http.request", "body": b""})

    def get_output(self):
        """Return the next output message without waiting (timeout 0).

        Raises ``asyncio.TimeoutError`` when no message is available.
        """
        return self._run_until_complete(self.communicator.receive_output(0))

    def get_all_output(self):
        """Drain and return every output message currently available."""
        outputs = []
        while True:
            try:
                outputs.append(self.get_output())
            except asyncio.TimeoutError:
                # No more messages queued: the output stream is drained.
                break
        return outputs
Example #2
0
class ASGITest(TestCase):
    """Tests for the ASGI app returned by ``make_asgi_app``.

    Each test drives the app in-process through an ``ApplicationCommunicator``
    and checks the exposition output produced for a per-test registry.
    """

    # skipUnless on setUp raises SkipTest from setUp, which unittest reports
    # as a skipped test, so every test here is skipped without asyncio/asgiref.
    @skipUnless(HAVE_ASYNCIO_AND_ASGI, "Don't have asyncio/asgi installed.")
    def setUp(self):
        # Per-test registry so tests do not see each other's metrics.
        self.registry = CollectorRegistry()
        self.captured_status = None
        self.captured_headers = None
        # Setup ASGI scope
        self.scope = {}
        setup_testing_defaults(self.scope)
        # Set by seed_app(); None means no application has been started.
        self.communicator = None

    def tearDown(self):
        """Wait for the seeded application to finish, then run base teardown."""
        try:
            if self.communicator is not None:
                self._run_until_complete(self.communicator.wait())
        finally:
            # Always run the parent's cleanup, even if the app raised above.
            super().tearDown()

    def _run_until_complete(self, awaitable):
        """Run *awaitable* on the current event loop and return its result."""
        return asyncio.get_event_loop().run_until_complete(awaitable)

    def seed_app(self, app):
        """Start *app* with ``self.scope`` and keep its communicator."""
        self.communicator = ApplicationCommunicator(app, self.scope)

    def send_input(self, payload):
        """Deliver one ASGI message *payload* to the seeded application."""
        self._run_until_complete(self.communicator.send_input(payload))

    def send_default_request(self):
        """Send an ``http.request`` message with an empty body."""
        self.send_input({"type": "http.request", "body": b""})

    def get_output(self):
        """Return the next output message without waiting (timeout 0).

        Raises ``asyncio.TimeoutError`` when no message is available.
        """
        return self._run_until_complete(self.communicator.receive_output(0))

    def get_all_output(self):
        """Drain and return every output message currently available."""
        outputs = []
        while True:
            try:
                outputs.append(self.get_output())
            except asyncio.TimeoutError:
                # No more messages queued: the output stream is drained.
                break
        return outputs

    def validate_metrics(self, metric_name, help_text, increments):
        """Check that the ASGI app serves a counter from ``self.registry``.

        Registers counter *metric_name* with *help_text*, increments it
        *increments* times, sends one HTTP request through
        ``make_asgi_app(self.registry)``, and asserts on the response-start
        message (status, headers) and on the exposition body.
        """
        c = Counter(metric_name, help_text, registry=self.registry)
        for _ in range(increments):
            c.inc()
        # Create and run ASGI app
        app = make_asgi_app(self.registry)
        self.seed_app(app)
        self.send_default_request()
        # Expect exactly one response-start and one response-body message.
        outputs = self.get_all_output()
        self.assertEqual(len(outputs), 2, outputs)
        response_start, response_body = outputs
        self.assertEqual(response_start['type'], 'http.response.start')
        self.assertEqual(response_body['type'], 'http.response.body')
        # Status code
        self.assertEqual(response_start['status'], 200)
        # Headers
        self.assertEqual(len(response_start['headers']), 1)
        self.assertEqual(response_start['headers'][0], (b"Content-Type", CONTENT_TYPE_LATEST.encode('utf8')))
        # Body: the samples are asserted under the "_total" name.
        output = response_body['body'].decode('utf8')
        sample_name = f"{metric_name}_total"
        self.assertIn(f"# HELP {sample_name} {help_text}\n", output)
        self.assertIn(f"# TYPE {sample_name} counter\n", output)
        self.assertIn(f"{sample_name} {increments}.0\n", output)

    def test_report_metrics_1(self):
        self.validate_metrics("counter", "A counter", 2)

    def test_report_metrics_2(self):
        self.validate_metrics("counter", "Another counter", 3)

    def test_report_metrics_3(self):
        self.validate_metrics("requests", "Number of requests", 5)

    def test_report_metrics_4(self):
        self.validate_metrics("failed_requests", "Number of failed requests", 7)
Example #3
0
class _BaseWebSocket:
    """In-process WebSocket test connection to an ASGI application.

    Builds an ASGI ``websocket`` scope from the configured path, headers,
    query parameters and cookies, and exchanges messages with the
    application through an ``ApplicationCommunicator``.  As an async context
    manager, entering performs the handshake and exiting sends a disconnect.
    """

    _app_ref: AppRef
    host: str
    path: str
    headers: MutableHeaders
    queries: QueryParams
    timeout: Optional[float]

    def __init__(
        self,
        http: _BaseClient,
        path: str = "",
        headers: Optional[MutableHeaders] = None,
        queries: Optional[Params] = None,
        cookies: Optional[Mapping] = None,
        timeout: Optional[float] = None,
    ):
        """
        :param http: client whose app reference, host and client address are reused.
        :param path: request path; :meth:`_connect` may override it.
        :param headers: request headers. NOTE(review): a non-empty object passed
            here is used directly, so the cookie header is appended to it.
        :param queries: query parameters sent in the handshake's query string.
        :param cookies: cookies, sent together in a single ``cookie`` header.
        :param timeout: passed to ``receive_output`` for every awaited message.
        """
        self._app_ref = AppRef(app=http._app_ref["app"])
        self.host = http.host
        self._client = http._client
        self.path = path
        self.headers = headers or MutableHeaders()
        if cookies:
            # A client sends cookies in one ``Cookie`` request header
            # (RFC 6265 section 5.4); SimpleCookie.output() renders
            # server-side ``Set-Cookie`` lines, which servers do not read.
            jar = SimpleCookie(cookies)
            cookie_value = "; ".join(
                f"{morsel.key}={morsel.coded_value}" for morsel in jar.values()
            )
            if cookie_value:
                # ASGI requires lower-cased header names.
                self.headers.append("cookie", cookie_value)
        self.queries = QueryParams(queries or [])
        self.timeout = timeout

    async def _connect(self, path: Optional[str] = None):
        """Send ``websocket.connect`` and require ``websocket.accept`` back.

        :param path: if given, replaces ``self.path`` before connecting.
        :raises RuntimeError: if the first application message is not
            ``websocket.accept``.
        """
        from urllib.parse import quote

        self.path = path or self.path
        queries = self.queries
        pairs = queries.items() if isinstance(queries, Mapping) else queries
        qs = "&".join(
            f"{quote_plus(k)}={quote_plus(v)}" for k, v in pairs
        ).encode("ascii")
        headers = [
            (k.encode("latin-1"), v.encode("latin-1")) for k, v in self.headers.items()
        ]
        # The subprotocol header is a comma-separated list; accept any spacing.
        subprotocols = [
            proto.strip()
            for proto in self.headers.get("sec-websocket-protocol", "").split(",")
            if proto.strip()
        ]
        scope = {
            "type": "websocket",
            "asgi": {"spec_version": "2.1"},
            "scheme": "ws",
            "http_version": "1.1",
            "path": self.path,
            # raw_path is the percent-encoded path as sent on the wire. quote()
            # keeps "/" separators; quote_plus() would turn them into "%2F"
            # and spaces into "+".
            "raw_path": quote(self.path).encode("ascii"),
            "query_string": qs,
            "root_path": "",
            "headers": headers,
            "client": self._client,
            "subprotocols": subprotocols,
        }

        self._connection = ApplicationCommunicator(self._app_ref["app"], scope)
        await self._connection.send_input({"type": "websocket.connect"})
        message = await self._connection.receive_output(self.timeout)
        if message["type"] != "websocket.accept":
            raise RuntimeError("Connection refused.")

    async def _close(self, status_code=1000):
        """Send ``websocket.disconnect`` with *status_code* and drop the connection."""
        message = {"type": "websocket.disconnect", "code": status_code}
        await self._connection.send_input(message)
        await self._connection.receive_nothing()
        del self._connection

    async def _receive(self, mode: type[AnyStr]) -> AnyStr:
        """Await the next application message and return its text or bytes.

        :param mode: ``str`` to read the ``text`` field, ``bytes`` for ``bytes``.
        :raises TypeError: for any other *mode*, or if the message lacks the
            requested field.
        :raises RuntimeError: if the message is ``websocket.close``.
        """
        # Validate before awaiting so a bad *mode* does not consume a message.
        if mode is str:
            type_key = "text"
        elif mode is bytes:
            type_key = "bytes"
        else:
            raise TypeError(f"`str` or `bytes` are allowed, not `{mode}` .")
        message = await self._connection.receive_output(self.timeout)
        result = message.get(type_key, None)
        if result is None:
            if message["type"] == "websocket.close":
                raise RuntimeError("Connection already closed.")
            raise TypeError(f"Server did not send `{type_key}` content.")
        return result

    async def _send(self, data: AnyStr):
        """Send *data* to the app as ``websocket.receive`` (``text`` or ``bytes``).

        :raises TypeError: if *data* is neither ``str`` nor ``bytes``.
        """
        message: dict = {"type": "websocket.receive"}
        if isinstance(data, str):
            type_key = "text"
        elif isinstance(data, bytes):
            type_key = "bytes"
        else:
            raise TypeError(
                f"use `str` or `bytes`. data type `{type(data)}` is not acceptable."
            )
        message[type_key] = data
        await self._connection.send_input(message)

    async def __aenter__(self) -> "_BaseWebSocket":
        await self._connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self._close()