Beispiel #1
0
 def __init__(
     self,
     subscription_manager: SubscriptionManager[abc.MessageT],
 ) -> None:
     """Keep the subscription manager and open an unbounded stream for its messages.

     :param subscription_manager: manager whose ``topic_type`` is used as the
         stream's item type.
     """
     self._subscription_manager = subscription_manager
     # max_buffer_size=math.inf: the buffer is unbounded.
     streams = anyio.create_memory_object_stream(
         max_buffer_size=math.inf,
         item_type=subscription_manager.topic_type,
     )
     self._send_stream, self._receive_stream = streams
Beispiel #2
0
        async def call_next(request: Request) -> Response:
            """Run the wrapped ASGI app for *request* and wrap its output in a StreamingResponse.

            The app runs as a background task in ``task_group`` (from the enclosing
            scope) and writes its ASGI send-messages into an in-memory stream. The
            first message (``http.response.start``) is awaited here; the body
            messages are consumed lazily by the returned response.

            :raises Exception: the app's own exception if the app failed before
                sending any message.
            :raises RuntimeError: if the app finished without sending any message.
            """
            # Exception raised by the app, captured by ``coro`` below.
            app_exc: typing.Optional[Exception] = None
            # No max_buffer_size given, so anyio's default (0) applies: each send
            # waits until the reader receives it.
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc

                # Leaving this block closes send_stream, which makes the reader
                # see EndOfStream once all sent messages are consumed.
                async with send_stream:
                    try:
                        await self.app(scope, request.receive,
                                       send_stream.send)
                    except Exception as exc:
                        # Stored instead of raised so the caller can re-raise it
                        # if no response was produced.
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                # Stream closed without any message: the app either crashed or
                # returned without responding.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                # Drains the remaining messages; ``message`` here is local to this
                # generator and does not affect the outer ``message``.
                # NOTE(review): if the app raises after the start message, the
                # exception only lands in app_exc and the body silently ends here.
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        yield message.get("body", b"")

            response = StreamingResponse(status_code=message["status"],
                                         content=body_stream())
            # Headers are copied verbatim from the app's start message.
            response.raw_headers = message["headers"]
            return response
Beispiel #3
0
    def __init__(self,
                 tg: anyio.abc.TaskGroup,
                 config=None,
                 plugin_namespace=None):
        """Initialise broker state, the broadcast queue and the plugin manager.

        :param tg: task group stored as ``self._tg`` and handed to the plugin manager.
        :param config: optional mapping merged over a deep copy of ``_defaults``.
        :param plugin_namespace: plugin namespace; ``"distmqtt.broker.plugins"``
            is used when it is falsy.
        """
        self.logger = logging.getLogger(__name__)

        merged = deepcopy(_defaults)
        if config is not None:
            merged.update(config)
        self.config = merged
        self._build_listeners_config(self.config)

        self._servers = {}
        self._init_states()
        self._sessions = {}
        self._subscriptions = {}

        # Bounded buffer of 100 items for the broadcast queue.
        send_side, receive_side = anyio.create_memory_object_stream(100)
        self._broadcast_queue_s = send_side
        self._broadcast_queue_r = receive_side
        self._tg = tg

        self._do_retain = self.config.get("retain", True)
        if self._do_retain:
            # Only created when retention is enabled.
            self._retained_messages = {}

        context = BrokerContext(self, self.config)
        namespace = plugin_namespace or "distmqtt.broker.plugins"
        self.plugins_manager = PluginManager(tg, namespace, context)
Beispiel #4
0
def test_websocket_concurrency_pattern(test_client_factory):
    """A reader task and a writer coroutine can share one WebSocket via a memory stream.

    JSON received from the client is pushed through the stream and echoed back.
    """
    outbox, inbox = anyio.create_memory_object_stream()

    async def reader(ws):
        # Closing ``outbox`` after the client stops sending ends the writer loop.
        async with outbox:
            async for payload in ws.iter_json():
                await outbox.send(payload)

    async def writer(ws):
        async with inbox:
            async for item in inbox:
                await ws.send_json(item)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        async with anyio.create_task_group() as tg:
            tg.start_soon(reader, ws)
            await writer(ws)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as conn:
        conn.send_json({"hello": "world"})
        echoed = conn.receive_json()
        assert echoed == {"hello": "world"}
Beispiel #5
0
async def test_receive_exactly_incomplete():
    """receive_exactly() raises IncompleteRead when the stream ends too early."""
    sender, receiver = create_memory_object_stream(1)
    buffered = BufferedByteReceiveStream(receiver)
    # Only 4 of the requested 8 bytes are sent before the send side closes.
    await sender.send(b'abcd')
    await sender.aclose()
    with pytest.raises(IncompleteRead):
        await buffered.receive_exactly(8)
Beispiel #6
0
    async def __run(self, ws):
        """Drains all messages from a WebSocket, sending them to the client's
        listeners.

        A background task running ``self._check_runtime`` reads a memory stream
        fed from here. For every valid event dict it receives the dict just
        before ``self.process_ws`` handles it, then ``None`` once handling ends
        (even if it raised). After the socket is drained it receives ``False``.
        Presumably this lets the checker time each event; confirm against
        ``_check_runtime``.

        :param ws: WebSocket to drain.
        """

        # No max_buffer_size given, so anyio's default (0) applies: every send
        # below waits until the checker task receives it.
        send_stream, receive_stream = anyio.create_memory_object_stream()

        # Start the consumer of receive_stream in the background.
        await self.taskgroup.spawn(self._check_runtime, receive_stream)

        async for msg in ws:
            if isinstance(msg, CloseConnection):
                break
            elif not isinstance(msg, TextMessage):
                log.warning("Unknown JSON message type: %s", repr(msg))
                continue  # ignore
            # NOTE(review): malformed JSON raises here and aborts the whole
            # loop, unlike non-dict events, which are logged and skipped.
            msg_json = json.loads(msg.data)
            if not isinstance(msg_json, dict) or 'type' not in msg_json:
                log.error("Invalid event: %s", msg)
                continue
            try:
                await send_stream.send(msg_json)
                await self.process_ws(msg_json)
            finally:
                # Marks the end of processing for this event, even on error.
                await send_stream.send(None)
        # NOTE(review): not reached if iterating ws raises; send_stream is also
        # never closed here.
        await send_stream.send(False)
Beispiel #7
0
async def test_close_send_while_receiving():
    """Closing the send side wakes a pending receive() with EndOfStream."""
    send, receive = create_memory_object_stream(1)
    with pytest.raises(EndOfStream):
        async with create_task_group() as tg:
            # TaskGroup.spawn() is deprecated since anyio 3.0; start_soon() is
            # the supported equivalent.
            tg.start_soon(receive.receive)
            await wait_all_tasks_blocked()
            await send.aclose()
Beispiel #8
0
 async def __aenter__(self):
     """Register ``self._data`` as consumer on the channel and create a queue.

     :returns: ``self``.
     """
     await self.channel.basic_consume(
         self._data, consumer_tag=self.consumer_tag, **self.kwargs
     )
     # Bounded buffer of 30 items.  TODO: 2 + possible prefetch
     writer, reader = anyio.create_memory_object_stream(30)
     self._q_w = writer
     self._q_r = reader
     return self
Beispiel #9
0
async def test_close_receive_while_sending():
    """Closing the receive side makes a blocked send() fail with BrokenResourceError."""
    send, receive = create_memory_object_stream(0)
    with pytest.raises(BrokenResourceError):
        async with create_task_group() as tg:
            # TaskGroup.spawn() is deprecated since anyio 3.0; use start_soon().
            tg.start_soon(send.send, 'hello')
            await wait_all_tasks_blocked()
            await receive.aclose()
Beispiel #10
0
async def test_cancel_during_receive() -> None:
    """An item sent to a blocked receiver must not be lost when that receiver is cancelled.

    ``send_nowait`` delivers "hello" while the receiver waits, and the receiver's
    scope is cancelled immediately afterwards; the item must still be recorded.
    """
    received: List[str] = []
    send, receive = create_memory_object_stream()
    receiver_scope = None

    async def scoped_receiver() -> None:
        nonlocal receiver_scope
        with CancelScope() as receiver_scope:
            item = await receive.receive()
            received.append(item)

        assert receiver_scope.cancel_called

    async with create_task_group() as tg:
        tg.start_soon(scoped_receiver)
        await wait_all_tasks_blocked()
        send.send_nowait("hello")
        assert receiver_scope is not None
        receiver_scope.cancel()

    assert received == ["hello"]
Beispiel #11
0
    def __init__(self, watermark: int, direction="lt"):
        """Store the watermark settings and create an event and an unbounded stream.

        :param watermark: threshold value stored as ``self.watermark``.
        :param direction: stored as ``self.direction``; presumably the comparison
            direction (default ``"lt"``) — confirm against its users.
        """
        self.watermark, self.direction = watermark, direction
        self._event = anyio.create_event()
        self._send_stream, self._receive_stream = anyio.create_memory_object_stream(
            max_buffer_size=math.inf
        )
Beispiel #12
0
 def __init__(self, conn, command, params, seq, expect_body):
     """Record the request details and create a bounded item queue.

     NOTE(review): ``expect_body`` is stored negated; the meaning of the sign is
     not visible here — confirm against the readers of ``self.expect_body``.
     """
     self._conn, self._command, self._params = conn, command, params
     self.seq = seq
     # Bounded buffer of 10000 items.
     self._q_w, self._q_r = anyio.create_memory_object_stream(max_buffer_size=10000)
     self.expect_body = -expect_body
Beispiel #13
0
 async def read_with_attachments(self, id, **opts):
     """Yield a receive stream filled by ``self._read_to_stream`` in a background task.

     The producer task gets the send side, the values produced by
     ``self._start_read(id, **opts)`` and the task group itself. The receive side
     is closed when the consumer is done with it.
     """
     sender, receiver = anyio.create_memory_object_stream()
     async with self._start_read(id, **opts) as args, anyio.create_task_group() as tg:
         tg.start_soon(self._read_to_stream, sender, *args, tg)
         async with receiver:
             yield receiver
Beispiel #14
0
    async def handle_tcpros(
        self,
        protocol: str,
        header: abc.Header,
        client: SocketStream,
    ) -> None:
        """Handle topic subscription from external. We are publisher, client is
        subscriber.

        Validates the incoming header and replies with our own header. It then
        registers a ``ConnectedSubscriber`` backed by an in-memory stream and
        forwards every chunk sent into that stream to ``client``. Forwarding stops
        when the stream ends or the client disconnects. The subscriber is always
        unregistered on exit.

        :param protocol: protocol name passed to ``ConnectedSubscriber``.
        :param header: connection header from the client; must contain
            ``topic``, ``md5sum`` and ``callerid``.
        :param client: socket stream the data is forwarded to.
        """
        require_fields(header, "topic", "md5sum", "callerid")

        check_md5sum(header, getattr(self.topic_type, "_md5sum"))

        await client.send(encode_header(self.header))

        # TODO fix capacity
        send_stream, receive_stream = anyio.create_memory_object_stream(math.inf, bytes)

        async with receive_stream:
            subscriber = ConnectedSubscriber(
                protocol,
                header["callerid"],
                send_stream,
            )
            self._subscribers.add(subscriber)
            try:
                await self._on_new_subscriber(subscriber)
                async for data in receive_stream:
                    try:
                        await client.send(data)
                    except anyio.BrokenResourceError:
                        # Client disconnected: stop forwarding.
                        break
            finally:
                # Unregister on every exit path (normal end, disconnect,
                # cancellation or an unexpected error) so the set never keeps a
                # subscriber whose stream nobody reads any more.
                self._subscribers.discard(subscriber)
    async def _handle_lifespan(
        self,
        scope: asgitypes.Scope,
        receive: asgitypes.ASGIReceiveCallable,
        send: asgitypes.ASGISendCallable,
    ) -> None:
        """Fan ASGI lifespan messages out to every mounted app.

        Each mount gets its own queue bounded by ``MAX_QUEUE_SIZE``. Every mounted
        app is started through ``_invoke_asgi`` and reads from its queue. This
        coroutine copies each incoming lifespan message into all queues and stops
        after forwarding ``lifespan.shutdown``.
        """
        queues = {}
        for mount_path in self.mounts:
            queues[mount_path] = anyio.create_memory_object_stream(MAX_QUEUE_SIZE)
        self.app_queues = queues
        self.startup_complete = dict.fromkeys(self.mounts, False)
        self.shutdown_complete = dict.fromkeys(self.mounts, False)

        async with anyio.create_task_group() as tg:
            for mount_path, mounted_app in self.mounts.items():
                _, app_receive = self.app_queues[mount_path]
                await tg.spawn(
                    _invoke_asgi,
                    mounted_app,
                    scope,
                    app_receive.receive,
                    functools.partial(self.send, mount_path, send),
                )

            while True:
                message = await receive()
                for app_send, _ in self.app_queues.values():
                    await app_send.send(message)
                if message["type"] == "lifespan.shutdown":
                    break
Beispiel #16
0
async def test_clone():
    """Cloned endpoints keep working after the original endpoints are closed."""
    orig_send, orig_receive = create_memory_object_stream(1)
    cloned_send, cloned_receive = orig_send.clone(), orig_receive.clone()
    for endpoint in (orig_send, orig_receive):
        await endpoint.aclose()
    cloned_send.send_nowait('hello')
    assert cloned_receive.receive_nowait() == 'hello'
Beispiel #17
0
async def test_receive_send_closed_send_stream():
    """Once the send side is closed, receive hits EndOfStream and send is refused."""
    tx, rx = create_memory_object_stream()
    await tx.aclose()

    with pytest.raises(EndOfStream):
        rx.receive_nowait()
    with pytest.raises(ClosedResourceError):
        await tx.send(None)
Beispiel #18
0
async def test_receive_exactly() -> None:
    """receive_exactly() joins buffered chunks into one bytes object of the requested size."""
    writer, reader = create_memory_object_stream(2)
    buffered = BufferedByteReceiveStream(reader)
    for chunk in (b"abcd", b"efgh"):
        await writer.send(chunk)
    data = await buffered.receive_exactly(8)
    assert data == b"abcdefgh"
    assert isinstance(data, bytes)
Beispiel #19
0
async def test_receive_exactly():
    """Two 4-byte chunks come back from receive_exactly(8) as a single bytes object."""
    writer, reader = create_memory_object_stream(2)
    buffered = BufferedByteReceiveStream(reader)
    for chunk in (b'abcd', b'efgh'):
        await writer.send(chunk)
    data = await buffered.receive_exactly(8)
    assert data == b'abcdefgh'
    assert isinstance(data, bytes)
Beispiel #20
0
async def test_cancel_send():
    """A cancelled send() must not leave its item behind in the stream."""
    send, receive = create_memory_object_stream()
    async with create_task_group() as tg:
        # TaskGroup.spawn() is deprecated since anyio 3.0; use start_soon().
        tg.start_soon(send.send, 'hello')
        await wait_all_tasks_blocked()
        tg.cancel_scope.cancel()

    with pytest.raises(WouldBlock):
        receive.receive_nowait()
Beispiel #21
0
 def __init__(self, app: typing.Callable, scope: Scope) -> None:
     """Store the app and scope, open an unbounded stream and reset response state."""
     self.app, self.scope = app, scope
     self.status = self.response_headers = None
     send_side, receive_side = anyio.create_memory_object_stream(math.inf)
     self.stream_send, self.stream_receive = send_side, receive_side
     self.response_started = False
     self.exc_info: typing.Any = None
Beispiel #22
0
async def test_receive() -> None:
    """TextReceiveStream holds back a split UTF-8 sequence until it is complete."""
    sender, receiver = create_memory_object_stream(1)
    decoded = TextReceiveStream(receiver)

    # The chunk ends with only the lead byte (0xc3) of "ö".
    await sender.send(b"\xc3\xa5\xc3\xa4\xc3")
    assert await decoded.receive() == "åä"

    # The continuation byte completes "ö".
    await sender.send(b"\xb6")
    assert await decoded.receive() == "ö"
Beispiel #23
0
 def __enter__(self) -> "TestClient":
     """Start a blocking portal, launch the lifespan task and wait for startup.

     If startup fails, everything entered so far is unwound before re-raising.
     """
     self.exit_stack = contextlib.ExitStack()
     portal_cm = anyio.start_blocking_portal(**self.async_backend)
     self.portal = self.exit_stack.enter_context(portal_cm)
     # Two independent unbounded channels, each stapled into one object stream.
     self.stream_send = StapledObjectStream(*anyio.create_memory_object_stream(math.inf))
     self.stream_receive = StapledObjectStream(*anyio.create_memory_object_stream(math.inf))
     try:
         self.task = self.portal.start_task_soon(self.lifespan)
         self.portal.call(self.wait_startup)
     except Exception:
         self.exit_stack.close()
         raise
     return self
Beispiel #24
0
async def test_cancel_receive() -> None:
    """After a pending receive() is cancelled, send_nowait() finds no waiting receiver."""
    tx, rx = create_memory_object_stream()
    async with create_task_group() as tg:
        tg.start_soon(rx.receive)
        await wait_all_tasks_blocked()
        tg.cancel_scope.cancel()

    with pytest.raises(WouldBlock):
        tx.send_nowait("hello")
Beispiel #25
0
async def test_receive_until_incomplete():
    """receive_until() raises IncompleteRead when the stream ends before the delimiter.

    The bytes already read must remain in the buffer afterwards.
    """
    send_stream, receive_stream = create_memory_object_stream(1)
    buffered_stream = BufferedByteReceiveStream(receive_stream)
    await send_stream.send(b'abcd')
    await send_stream.aclose()
    with pytest.raises(IncompleteRead):
        # The call itself must raise; asserting on its return value was dead code.
        await buffered_stream.receive_until(b'de', 10)

    assert buffered_stream.buffer == b'abcd'
Beispiel #26
0
async def test_receive():
    """A multi-byte character split across chunks is decoded once it is complete."""
    sender, receiver = create_memory_object_stream(1)
    decoded = TextReceiveStream(receiver)

    # The chunk ends with only the lead byte (0xc3) of "ö".
    await sender.send(b'\xc3\xa5\xc3\xa4\xc3')
    assert await decoded.receive() == 'åä'

    # The continuation byte completes "ö".
    await sender.send(b'\xb6')
    assert await decoded.receive() == 'ö'
Beispiel #27
0
async def test_sync_close() -> None:
    """Using both endpoints as plain (sync) context managers closes them on exit."""
    send_stream, receive_stream = create_memory_object_stream(1)
    with send_stream:
        with receive_stream:
            pass

    with pytest.raises(ClosedResourceError):
        send_stream.send_nowait(None)
    with pytest.raises(ClosedResourceError):
        receive_stream.receive_nowait()
Beispiel #28
0
async def test_bidirectional_stream():
    """TextStream over a stapled stream encodes outgoing and decodes incoming text."""
    tx, rx = create_memory_object_stream(1)
    text_stream = TextStream(StapledObjectStream(tx, rx))

    await text_stream.send('åäö')
    assert await rx.receive() == b'\xc3\xa5\xc3\xa4\xc3\xb6'

    await tx.send(b'\xc3\xa6\xc3\xb8')
    assert await text_stream.receive() == 'æø'
Beispiel #29
0
async def test_send_is_unblocked_after_receive_nowait():
    """Freeing the single buffer slot with receive_nowait() lets a blocked send() finish."""
    send, receive = create_memory_object_stream(1)
    send.send_nowait('hello')

    with fail_after(1):
        async with create_task_group() as tg:
            # TaskGroup.spawn() is deprecated since anyio 3.0; use start_soon().
            tg.start_soon(send.send, 'anyio')
            await wait_all_tasks_blocked()
            assert receive.receive_nowait() == 'hello'

    assert receive.receive_nowait() == 'anyio'
Beispiel #30
0
async def test_send_is_unblocked_after_receive_nowait() -> None:
    """Draining the full buffer with receive_nowait() lets a blocked send() complete."""
    sender, receiver = create_memory_object_stream(1)
    sender.send_nowait("hello")  # fills the single buffer slot

    with fail_after(1):
        async with create_task_group() as tg:
            tg.start_soon(sender.send, "anyio")
            await wait_all_tasks_blocked()
            first = receiver.receive_nowait()
            assert first == "hello"

    second = receiver.receive_nowait()
    assert second == "anyio"