async def test_remove_handler():
    """A handler removed before its NEW_STREAM frame arrives is not invoked.

    The handler is registered and immediately removed; after the NEW_STREAM
    frame is fed, the handler's event must stay unset for the 0.05 s window.
    Removing the same handler a second time raises KeyError.
    """
    host, tcp_port = "127.0.0.1", 7777
    target = "stream.1"
    called = asyncio.Event()

    async def handler(stream: Stream):
        assert stream.name == target
        assert stream.ip == host
        assert stream.port == tcp_port
        called.set()

    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            # Register, then immediately unregister, the handler.
            mux.set_handler(target, handler)
            mux.remove_handler(target)

            frame = get_encoded_message(target, MplexFlag.NEW_STREAM,
                                        target.encode())
            conn_reader.feed_data(frame)

            # Nobody should handle the announced stream.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(called.wait(), timeout=0.05)
            # There is nothing left to remove.
            with pytest.raises(KeyError):
                mux.remove_handler(target)
Example #2
0
async def test_remove_handler():
    """Listener variant: a removed handler is not invoked for a new stream.

    A mock remote connection announces NEW_STREAM for a stream whose handler
    was already removed; the handler's event must stay unset for the 0.05 s
    window. A second removal of the same handler raises KeyError.
    """
    peer_ip, peer_port = "127.0.0.2", 7777
    target = "stream.1"
    called = asyncio.Event()

    async def handler(stream: Stream):
        assert stream.name == target
        assert stream.ip == peer_ip
        assert stream.port == peer_port
        called.set()

    start_server_mock, _server, remote_conn_mock = get_start_sever_mock()
    with patch("asyncio.start_server", start_server_mock):
        listener_ctx = bind_multiplex_listener_context("127.0.0.1", 7777)
        async with listener_ctx as listener:
            # Register, then immediately unregister, the handler.
            listener.set_handler(target, handler)
            listener.remove_handler(target)

            peer_reader, _peer_writer = remote_conn_mock.new_mock_connection(
                peer_ip, peer_port)
            frame = get_encoded_message(target, MplexFlag.NEW_STREAM,
                                        target.encode())
            peer_reader.feed_data(frame)

            # Nobody should handle the announced stream.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(called.wait(), timeout=0.05)
            # There is nothing left to remove.
            with pytest.raises(KeyError):
                listener.remove_handler(target)
async def test_handle_new_stream(stream_names):
    """Each registered handler fires for the NEW_STREAM frame of its stream.

    All handlers are registered first; the streams are then announced one at
    a time, and the test waits (up to 0.05 s each) for the matching handler.
    """
    host, tcp_port = "127.0.0.1", 7777
    seen = {}
    for name in stream_names:
        seen[name] = asyncio.Event()

    async def handler(stream_name: StreamName, stream: Stream):
        assert stream.name == stream_name
        assert stream.ip == host
        assert stream.port == tcp_port
        seen[stream_name].set()

    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            # Bind each handler to its own stream name.
            for name in stream_names:
                mux.set_handler(name, partial(handler, name))

            for name in stream_names:
                frame = get_encoded_message(name, MplexFlag.NEW_STREAM,
                                            name.encode())
                conn_reader.feed_data(frame)
                await asyncio.wait_for(seen[name].wait(), timeout=0.05)
async def test_iterate_read(data: bytes):
    """Iterating a stream yields its newline-terminated lines in order.

    `data` is repeated a random number of times joined by b"\\n" and sent in
    one MESSAGE frame; each line produced by `async for` is compared with the
    corresponding chunk. After the multiplexer context has exited, iterating
    the stream raises RuntimeError.
    """
    ip, port = ("127.0.0.1", 7777)
    reader_mock, writer_mock = get_connection_mock(ip, port)
    with patch("asyncio.open_connection",
               return_value=(reader_mock, writer_mock)):
        async with open_multiplexer_context(ip, port) as multiplexer:
            stream_name = "stream.1"
            stream = await multiplexer.multiplex(stream_name)

            line_amount = random.randint(1, len(data))
            data_with_lines = b"\n".join([data for _ in range(line_amount)])
            # `data` may itself contain b"\n", so the expected lines come from
            # splitting the joined payload, not from `line_amount`.
            chunked_data = data_with_lines.split(b"\n")

            encoded_message = get_encoded_message(stream_name,
                                                  MplexFlag.MESSAGE,
                                                  data_with_lines)
            reader_mock.feed_data(encoded_message)

            # NOTE(review): the payload does not end with b"\n", so the final
            # chunk may never be produced as a complete line -- confirm how
            # Stream iteration treats a trailing partial line.
            i = 0
            async for line in stream:
                assert line == chunked_data[i] + b"\n"
                i += 1
                # Stop once every expected chunk has been checked. With `>`
                # this guard was unreachable: the next iteration would index
                # past the end of `chunked_data` first.
                if i >= len(chunked_data):
                    break

        with pytest.raises(RuntimeError):
            async for line in stream:
                pass
Example #5
0
async def test_handle_streams_after_connected(remote_address, stream_names):
    """Handlers registered after a remote connection exists still fire.

    For every stream name, a new mock remote connection is opened first.
    Only then is that stream's handler registered and its NEW_STREAM frame
    fed. The test waits up to 0.05 s for that stream's own handler to run.
    """
    remote_ip, remote_port = remote_address
    handled_events = {name: asyncio.Event() for name in stream_names}

    async def handler(stream_name: StreamName, stream: Stream):
        assert stream.name == stream_name
        assert stream.ip == remote_ip
        assert stream.port == remote_port
        handled_events[stream_name].set()

    start_server_mock, server_mock, remote_conn_mock = get_start_sever_mock()
    with patch("asyncio.start_server", start_server_mock):
        async with bind_multiplex_listener_context("127.0.0.1",
                                                   7777) as multiplex_listener:
            for stream_name in stream_names:
                # First open the connection...
                reader_mock, writer_mock = remote_conn_mock.new_mock_connection(
                    remote_ip, remote_port)
                # ...yield once to the event loop (presumably so the listener
                # can pick up the new connection)...
                await asyncio.sleep(0)
                # ...then set the handler for this stream only. A nested
                # `for stream_name in ...` loop here would rebind
                # `stream_name` to the last name, so every iteration would
                # exercise that same stream.
                multiplex_listener.set_handler(stream_name,
                                               partial(handler, stream_name))
                encoded_message = get_encoded_message(stream_name,
                                                      MplexFlag.NEW_STREAM,
                                                      stream_name.encode())
                reader_mock.feed_data(encoded_message)
                await asyncio.wait_for(handled_events[stream_name].wait(),
                                       timeout=0.05)
async def test_read_until_close():
    """read() with no size returns the data sent before a CLOSE frame.

    One MESSAGE frame and then a CLOSE frame are fed. read() must complete
    within 0.05 s and return exactly the MESSAGE body.
    """
    host, tcp_port = "127.0.0.1", 7777
    name = "stream.1"
    payload = b"data"
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            stream = await mux.multiplex(name)

            # Deliver the MESSAGE frame first, then the CLOSE frame.
            frames = ((MplexFlag.MESSAGE, payload), (MplexFlag.CLOSE, b""))
            for flag, body in frames:
                conn_reader.feed_data(get_encoded_message(name, flag, body))

            received = await asyncio.wait_for(stream.read(), timeout=0.05)
            assert received == payload
async def test_write_to_stream(data: bytes):
    """Stream.write() emits a MESSAGE frame on the underlying writer.

    The check is on the most recent call to the connection writer's
    `write`, which must receive the encoded MESSAGE frame for `data`.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)
            await stream.write(data)

            expected = get_encoded_message(name, MplexFlag.MESSAGE, data)
            conn_writer.write.assert_called_with(expected)
async def test_read_invalid_flag():
    """A frame carrying an invalid flag (4) delivers no data to the stream.

    After the frame is fed, read(1) must not complete within 0.05 s.
    """
    host, tcp_port = "127.0.0.1", 7777
    invalid_flag = 4
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            frame = get_encoded_message(name, invalid_flag, b"data")
            conn_reader.feed_data(frame)

            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_read_zero_from_stream(mock: AsyncMock):
    """read(0) returns b"" within 0.05 s even though data is buffered.

    `mock` is presumably the patched connection opener, injected by a
    decorator outside this view. It is configured here to return the mock
    reader/writer pair.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    mock.return_value = (conn_reader, conn_writer)
    async with open_multiplexer_context(host, tcp_port) as mux:
        name = "stream.1"
        stream = await mux.multiplex(name)

        frame = get_encoded_message(name, MplexFlag.MESSAGE, b"data")
        conn_reader.feed_data(frame)

        nothing = await asyncio.wait_for(stream.read(0), timeout=0.05)
        assert nothing == b""
async def test_multiplex_one(stream_name: StreamName):
    """multiplex() returns an open Stream and announces it to the peer.

    The returned Stream must carry the connection's ip/port and the requested
    name, and must not be closed. The most recent frame written must be
    NEW_STREAM with the encoded stream name as its body.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            stream = await mux.multiplex(stream_name)

            assert isinstance(stream, Stream)
            assert stream.ip == host
            assert stream.port == tcp_port
            assert stream.name == stream_name
            assert not stream.is_closed()

            announcement = get_encoded_message(stream_name,
                                               MplexFlag.NEW_STREAM,
                                               stream_name.encode())
            conn_writer.write.assert_called_with(announcement)
async def test_readline(data: bytes):
    """readline() returns bytes up to and including the first b"\\n".

    The payload is `data + b"\\n" + data`. Once the multiplexer context has
    exited, readline() raises RuntimeError.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            payload = b"\n".join((data, data))
            frame = get_encoded_message(name, MplexFlag.MESSAGE, payload)
            conn_reader.feed_data(frame)

            first_line, _sep, _rest = payload.partition(b"\n")
            assert await stream.readline() == first_line + b"\n"

        with pytest.raises(RuntimeError):
            await stream.readline()
async def test_readexactly(data: bytes):
    """readexactly(n) returns the first n bytes of the buffered payload.

    n is drawn at random from [0, len(data) - 1]. Once the multiplexer
    context has exited, readexactly() raises RuntimeError.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            # Strictly shorter than the payload; may be zero.
            count = random.randint(0, len(data) - 1)

            frame = get_encoded_message(name, MplexFlag.MESSAGE, data)
            conn_reader.feed_data(frame)

            prefix = await stream.readexactly(count)
            assert prefix == data[:count]

        with pytest.raises(RuntimeError):
            await stream.readexactly(count)
async def test_read_from_one_stream(data: StreamData):
    """Fixed-size read() calls return consecutive slices of one payload.

    The chunk size is random in [1, 10], capped at len(data). Only whole
    chunks are read, so a trailing partial chunk is left unread. Each read
    has a 1 s timeout.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            frame = get_encoded_message(name, MplexFlag.MESSAGE, data)
            conn_reader.feed_data(frame)

            step = min(random.randint(1, 10), len(data))
            whole_chunks_end = (len(data) // step) * step
            for offset in range(0, whole_chunks_end, step):
                chunk = await asyncio.wait_for(stream.read(step), timeout=1)
                assert chunk == data[offset:offset + step]
async def test_read_close_flag():
    """After a CLOSE frame is received, the stream rejects close/write/read.

    Once the CLOSE frame has been fed and the event loop has had a chance to
    run, close(), write() and read() must each raise RuntimeError.
    """
    ip, port = ("127.0.0.1", 7777)
    reader_mock, writer_mock = get_connection_mock(ip, port)
    with patch("asyncio.open_connection",
               return_value=(reader_mock, writer_mock)):
        async with open_multiplexer_context(ip, port) as multiplexer:
            stream_name = "stream.1"
            stream = await multiplexer.multiplex(stream_name)

            encoded_message = get_encoded_message(stream_name, MplexFlag.CLOSE,
                                                  b"")
            reader_mock.feed_data(encoded_message)

            # Hand control back to the event loop for 0.05 s, presumably so
            # the multiplexer can consume the CLOSE frame before the stream
            # is used below. This sleep is what yields to the loop; the order
            # of the checks that follow is not relied on for that.
            await asyncio.sleep(0.05)
            with pytest.raises(RuntimeError):
                await stream.close()
            with pytest.raises(RuntimeError):
                await stream.write(b"data")
            with pytest.raises(RuntimeError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_close_stream(mock: AsyncMock):
    """Closing a stream sends CLOSE and makes further operations fail.

    After close(), the most recent frame written is CLOSE with the encoded
    stream name as its body, and is_closed() is true. close(), read(1) and
    write() then each raise RuntimeError. `mock` is presumably the patched
    connection opener, injected by a decorator outside this view.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    mock.return_value = (conn_reader, conn_writer)

    async with open_multiplexer_context(host, tcp_port) as mux:
        name = "stream.1"
        stream = await mux.multiplex(name)
        await stream.close()

        close_frame = get_encoded_message(name, MplexFlag.CLOSE,
                                          name.encode())
        conn_writer.write.assert_called_with(close_frame)
        assert stream.is_closed()

        # Every subsequent operation on the closed stream is rejected, in
        # this order: close, read, write.
        rejected_operations = (
            stream.close,
            lambda: asyncio.wait_for(stream.read(1), timeout=0.05),
            lambda: stream.write(b"data"),
        )
        for operation in rejected_operations:
            with pytest.raises(RuntimeError):
                await operation()
async def test_readuntil(data: bytes):
    """readuntil(sep) returns bytes up to and including the first `sep`.

    The separator is a randomly chosen single byte taken from `data`, so it
    always occurs in the payload. Once the multiplexer context has exited,
    readuntil() raises RuntimeError.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            pos = random.randint(0, len(data) - 1)
            separator = data[pos:pos + 1]
            frame = get_encoded_message(name, MplexFlag.MESSAGE, data)
            conn_reader.feed_data(frame)

            head, _sep, _rest = data.partition(separator)
            assert await stream.readuntil(separator) == head + separator

        with pytest.raises(RuntimeError):
            await stream.readuntil(separator)
async def test_read_from_multiple_streams(stream_names: List[StreamName],
                                          data: StreamData):
    """Data addressed to one stream does not show up on another.

    Two different entries of `stream_names` are sampled. A MESSAGE frame is
    sent to the first; reading from the second must time out (0.05 s),
    while the first returns the full payload.
    """
    host, tcp_port = "127.0.0.1", 7777
    conn_reader, conn_writer = get_connection_mock(host, tcp_port)
    fake_connection = (conn_reader, conn_writer)
    with patch("asyncio.open_connection", return_value=fake_connection):
        async with open_multiplexer_context(host, tcp_port) as mux:
            streams = {}
            for name in stream_names:
                streams[name] = await mux.multiplex(name)

            receiver, non_receiver = random.sample(stream_names, k=2)
            frame = get_encoded_message(receiver, MplexFlag.MESSAGE, data)
            conn_reader.feed_data(frame)

            bystander = streams[non_receiver]
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(bystander.read(1), timeout=0.05)

            assert await streams[receiver].read(len(data)) == data