Example #1
0
 def test_async_context_manager_success(self, anyio_backend_name,
                                        anyio_backend_options):
     """Check that a wrapped async context manager binds 'test' in the with-block."""
     portal_cm = start_blocking_portal(anyio_backend_name,
                                       anyio_backend_options)
     with portal_cm as portal:
         wrapped_cm = portal.wrap_async_context_manager(
             TestBlockingPortal.AsyncCM(False))
         with wrapped_cm as entered_value:
             assert entered_value == 'test'
Example #2
0
    def test_start_crash_before_started_call(self, anyio_backend_name, anyio_backend_options):
        """Check that ``start_task`` raises when the task fails without calling ``started()``."""
        def taskfunc(*, task_status):
            raise Exception('foo')

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal, pytest.raises(Exception, match='foo'):
            portal.start_task(taskfunc)
Example #3
0
    def __enter__(self):
        """
        Enter the wrapped scheduler's async context through a blocking portal.

        When no portal is set yet, one is started here and ``_shutdown_portal`` is
        set to True. Returns ``self``.
        """
        portal = self._portal
        if not portal:
            portal = self._portal = start_blocking_portal()
            self._shutdown_portal = True

        portal.call(self._scheduler.__aenter__)
        return self
Example #4
0
    def test_start_no_started_call(self, anyio_backend_name, anyio_backend_options):
        """Check that ``start_task`` raises RuntimeError when the task never calls ``started()``."""
        def taskfunc(*, task_status):
            pass

        expected_failure = pytest.raises(RuntimeError, match='Task exited')
        with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
            with expected_failure:
                portal.start_task(taskfunc)
Example #5
0
 def test_call_stopped_portal(self, anyio_backend_name,
                              anyio_backend_options):
     """Check that calling through a portal after stopping it raises RuntimeError."""
     portal = start_blocking_portal(anyio_backend_name,
                                    anyio_backend_options)
     portal.call(portal.stop)
     with pytest.raises(RuntimeError) as excinfo:
         portal.call(threading.get_ident)

     excinfo.match('This portal is not running')
Example #6
0
    def test_start_with_name(self, anyio_backend_name, anyio_backend_options):
        """Check that the ``name`` given to ``start_task`` becomes the running task's name."""
        def taskfunc(*, task_status):
            current_name = get_current_task().name
            task_status.started(current_name)

        with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
            _, reported_name = portal.start_task(taskfunc, name='testname')
            assert reported_name == 'testname'
Example #7
0
    def test_start_with_new_event_loop(self, anyio_backend_name, anyio_backend_options,
                                       use_contextmanager):
        """
        Check that a coroutine called through the portal runs in a different thread.

        ``use_contextmanager`` selects between the ``with`` form and the manual
        start / ``portal.stop`` form.
        """
        async def async_get_thread_id():
            return threading.get_ident()

        if not use_contextmanager:
            portal = start_blocking_portal(anyio_backend_name, anyio_backend_options)
            try:
                portal_thread_id = portal.call(async_get_thread_id)
            finally:
                portal.call(portal.stop)
        else:
            with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
                portal_thread_id = portal.call(async_get_thread_id)

        assert isinstance(portal_thread_id, int)
        caller_thread_id = threading.get_ident()
        assert portal_thread_id != caller_thread_id
Example #8
0
    def test_async_context_manager_error(self, anyio_backend_name, anyio_backend_options):
        """
        Check that an exception raised inside the wrapped ``AsyncCM(False)`` context
        propagates out of the ``with`` block with its message intact.
        """
        with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
            with pytest.raises(Exception) as exc:
                with portal.wrap_async_context_manager(TestBlockingPortal.AsyncCM(False)) as cm:
                    assert cm == 'test'
                    raise Exception('should NOT be ignored')

            # The message check has to run after the ``pytest.raises`` block has exited.
            # Placed inside that block, it sits behind the statement that raises and is
            # never executed, so any exception at all would make the test pass.
            exc.match('should NOT be ignored')
Example #9
0
    def test_start_with_value(self, anyio_backend_name, anyio_backend_options):
        """Check that ``start_task`` returns the value passed to ``started()`` and a future."""
        def taskfunc(*, task_status):
            task_status.started('foo')

        with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
            task_future, start_value = portal.start_task(taskfunc)
            assert start_value == 'foo'
            task_result = task_future.result()
            assert task_result is None
Example #10
0
 def test_async_context_manager_error_ignore(self, anyio_backend_name,
                                             anyio_backend_options):
     """Check that an exception raised inside the wrapped ``AsyncCM(True)`` does not escape."""
     portal_cm = start_blocking_portal(anyio_backend_name,
                                       anyio_backend_options)
     with portal_cm as portal:
         wrapped_cm = portal.wrap_async_context_manager(
             TestBlockingPortal.AsyncCM(True))
         with wrapped_cm as entered_value:
             assert entered_value == 'test'
             raise Exception('should be ignored')
Example #11
0
    def test_start_with_new_event_loop(self, anyio_backend_name, anyio_backend_options):
        """Check that a coroutine called through the portal runs in a different thread."""
        async def async_get_thread_id():
            return threading.get_ident()

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal:
            portal_thread_id = portal.call(async_get_thread_id)

        assert isinstance(portal_thread_id, int)
        caller_thread_id = threading.get_ident()
        assert portal_thread_id != caller_thread_id
Example #12
0
    def test_start_crash_after_started_call(self, anyio_backend_name, anyio_backend_options):
        """Check that a failure after ``started()`` is reported through the returned future."""
        def taskfunc(*, task_status):
            task_status.started(2)
            raise Exception('foo')

        with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
            task_future, start_value = portal.start_task(taskfunc)
            assert start_value == 2
            with pytest.raises(Exception) as excinfo:
                task_future.result()

            excinfo.match('foo')
Example #13
0
    def test_spawn_task_with_name(self, anyio_backend_name, anyio_backend_options):
        """Check that the ``name`` given to ``spawn_task`` becomes the running task's name."""
        observed_name = None

        async def taskfunc():
            nonlocal observed_name
            observed_name = get_current_task().name

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal:
            portal.spawn_task(taskfunc, name='testname')

        assert observed_name == 'testname'
Example #14
0
    def test_spawn_task_cancel_later(self, anyio_backend_name, anyio_backend_options):
        """Check that a spawned task can be cancelled via its future once it is blocked."""
        async def noop():
            await sleep(2)

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal:
            sleeper_future = portal.spawn_task(noop)
            portal.call(wait_all_tasks_blocked)
            sleeper_future.cancel()

        was_cancelled = sleeper_future.cancelled()
        assert was_cancelled
Example #15
0
    def connect(self) -> None:
        """
        Connect to the SMTP server through a newly started blocking portal.

        The portal and its context manager are stored on the instance only after the
        async client's ``connect`` has succeeded. If it raises, the portal's context
        manager is exited with that exception and the exception is re-raised.
        """
        cm = start_blocking_portal(self._async_backend, self._async_backend_options)
        blocking_portal = cm.__enter__()
        try:
            blocking_portal.call(self._async_client.connect)
        except BaseException as exc:
            cm.__exit__(type(exc), exc, exc.__traceback__)
            raise
        else:
            self._portal_cm = cm
            self._portal = blocking_portal
Example #16
0
    def test_spawn_task(self, anyio_backend_name, anyio_backend_options):
        """Check that a spawned task runs alongside portal calls and its future yields the result."""
        async def event_waiter():
            await start_event.wait()
            done_event.set()
            return 'test'

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal:
            start_event = portal.call(create_event)
            done_event = portal.call(create_event)
            waiter_future = portal.spawn_task(event_waiter)
            portal.call(start_event.set)
            portal.call(done_event.wait)
            waiter_result = waiter_future.result()
            assert waiter_result == 'test'
Example #17
0
    def test_spawn_task_cancel_immediately(self, anyio_backend_name, anyio_backend_options):
        """Check that cancelling the future right after spawning cancels the task itself."""
        saw_cancellation = False

        async def event_waiter():
            nonlocal saw_cancellation
            try:
                await sleep(3)
            except get_cancelled_exc_class():
                saw_cancellation = True

        portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
        with portal_cm as portal:
            waiter_future = portal.spawn_task(event_waiter)
            waiter_future.cancel()

        assert saw_cancellation
Example #18
0
    def __enter__(self) -> "WebSocketTestSession":
        """
        Start a blocking portal, schedule ``self._run`` on it and perform the
        websocket connect exchange.

        The exit stack is closed again if any step of the exchange raises an
        ``Exception``. On success the ``subprotocol`` entry of the first received
        message (or None) is stored as ``accepted_subprotocol``.
        """
        stack = contextlib.ExitStack()
        self.exit_stack = stack
        portal_cm = anyio.start_blocking_portal(**self.async_backend)
        self.portal = stack.enter_context(portal_cm)

        try:
            _: "Future[None]" = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            first_message = self.receive()
            self._raise_on_close(first_message)
        except Exception:
            self.exit_stack.close()
            raise
        else:
            self.accepted_subprotocol = first_message.get("subprotocol", None)
            return self
Example #19
0
 def __enter__(self) -> "TestClient":
     """
     Start a blocking portal, create the two memory-stream pairs, schedule
     ``self.lifespan`` on the portal and block until ``self.wait_startup`` returns.

     The exit stack is closed again if scheduling or the startup wait raises an
     ``Exception``.
     """
     stack = contextlib.ExitStack()
     self.exit_stack = stack
     portal_cm = anyio.start_blocking_portal(**self.async_backend)
     self.portal = stack.enter_context(portal_cm)

     send_pair = anyio.create_memory_object_stream(math.inf)
     self.stream_send = StapledObjectStream(*send_pair)
     receive_pair = anyio.create_memory_object_stream(math.inf)
     self.stream_receive = StapledObjectStream(*receive_pair)

     try:
         self.task = self.portal.start_task_soon(self.lifespan)
         self.portal.call(self.wait_startup)
     except Exception:
         self.exit_stack.close()
         raise
     else:
         return self
Example #20
0
 def __enter__(self):
     """
     Enter the exit stack; when no portal is set, start one on that stack and
     subscribe ``_forward_async_event`` to the async event source.

     NOTE(review): nothing is returned here, so ``with obj as x`` binds ``x`` to
     None -- confirm that this is intended.
     """
     stack = self._exit_stack
     stack.__enter__()
     if not self.portal:
         self.portal = stack.enter_context(start_blocking_portal())
         self._async_event_source.subscribe(self._forward_async_event)
Example #21
0
def portal():
    """Yield a running blocking portal; its context is exited when the generator is resumed."""
    portal_cm = start_blocking_portal()
    with portal_cm as blocking_portal:
        yield blocking_portal
Example #22
0
    def send(
        self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
    ) -> requests.Response:
        """
        Dispatch ``request`` to ``self.app`` in-process and build a response from
        the messages the app sends back.

        For ``ws``/``wss`` URLs no response is produced: a ``WebSocketTestSession``
        is created and delivered by raising ``_Upgrade``.

        For ``http``/``https`` URLs the app is called once inside a blocking portal
        with an ``"http"`` scope. The nested ``receive``/``send`` callables feed it
        the request body and record the response.

        ``*args`` and ``**kwargs`` are accepted but not used by this method.

        Returns the object produced by ``self.build_response``. If the app started
        no response and ``raise_server_exceptions`` is false, a 500 response is
        synthesized instead.

        Raises ``_Upgrade`` for websocket requests, ``KeyError`` for a URL scheme
        other than http/https/ws/wss, ``AssertionError`` when no response was
        started while ``raise_server_exceptions`` is true, and whatever the app
        raised when ``raise_server_exceptions`` is true.
        """
        scheme, netloc, path, query, fragment = (
            str(item) for item in urlsplit(request.url)
        )

        # Indexing fails with KeyError for any scheme not listed here.
        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # Use the explicit port from the URL when present, else the scheme default.
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        # When the request already has one, start empty: it is copied over with
        # the rest of the request headers just below.
        if "host" in request.headers:
            headers: typing.List[typing.Tuple[bytes, bytes]] = []
        elif port == default_port:
            headers = [(b"host", host.encode())]
        else:
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.items()
        ]

        if scheme in {"ws", "wss"}:
            # Split the comma-separated subprotocol header into a list of names.
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
            }
            session = WebSocketTestSession(self.app, scope, self.async_backend)
            # The session is delivered by raising rather than by returning a response.
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.template": {}},
        }

        # State shared with the receive()/send() closures below.
        request_complete = False
        response_started = False
        # Declared here only; the event itself is created inside the portal block
        # further down, before the app is called.
        response_complete: anyio.Event
        raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            """
            Return the next ``http.request`` message carrying the request body.

            Once the body has been delivered, wait for the response to complete
            and return ``http.disconnect``.
            """
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.body
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")
            elif body is None:
                body_bytes = b""
            elif isinstance(body, types.GeneratorType):
                # Generator bodies are delivered one chunk per call; exhaustion is
                # signalled with a final empty message.
                try:
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """
            Record a message sent by the app.

            ``http.response.start`` fills in the response metadata,
            ``http.response.body`` appends to the body buffer, and
            ``http.response.template`` captures the template and context.
            Other message types are ignored.
            """
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["version"] = 11
                raw_kwargs["status"] = message["status"]
                raw_kwargs["reason"] = _get_reason_phrase(message["status"])
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                raw_kwargs["preload_content"] = False
                raw_kwargs["original_response"] = _MockOriginalResponse(
                    raw_kwargs["headers"]
                )
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # Body bytes are not recorded for HEAD requests.
                if request.method != "HEAD":
                    raw_kwargs["body"].write(body)
                if not more_body:
                    # Last chunk: rewind the buffer so it is read from the start.
                    raw_kwargs["body"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.template":
                template = message["template"]
                context = message["context"]

        try:
            with anyio.start_blocking_portal(**self.async_backend) as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException as exc:
            # Anything raised while running the app is dropped unless the client
            # is configured to re-raise it.
            if self.raise_server_exceptions:
                raise exc

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            # No response was started: substitute an empty 500 response.
            raw_kwargs = {
                "version": 11,
                "status": 500,
                "reason": "Internal Server Error",
                "headers": [],
                "preload_content": False,
                "original_response": _MockOriginalResponse([]),
                "body": io.BytesIO(),
            }

        raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
        response = self.build_response(request, raw)
        # Attach the captured template and context, if the app sent any.
        if template is not None:
            response.template = template
            response.context = context
        return response
Example #23
0
    def test_start_with_nonexistent_backend(self):
        """Check that starting a portal with an unknown backend name raises LookupError."""
        with pytest.raises(LookupError, match='No such backend: foo'):
            with start_blocking_portal('foo'):
                pass