async def test_handle_new_stream(stream_names):
    """Verify that a registered handler is invoked for each new stream.

    One handler per entry of ``stream_names`` is registered on the
    multiplexer. A NEW_STREAM message is then fed to the mocked reader for
    every name, and the test waits up to 50 ms for the matching handler to
    set its event.
    """
    host, port = "127.0.0.1", 7777
    fired = {}
    for name in stream_names:
        fired[name] = asyncio.Event()

    async def on_new_stream(stream_name: StreamName, stream: Stream):
        # ``stream_name`` is pre-bound with functools.partial at registration.
        assert stream.name == stream_name
        assert stream.ip == host
        assert stream.port == port
        fired[stream_name].set()

    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            for name in stream_names:
                multiplexer.set_handler(name, partial(on_new_stream, name))

            for name in stream_names:
                new_stream_message = get_encoded_message(
                    name, MplexFlag.NEW_STREAM, name.encode())
                reader.feed_data(new_stream_message)

                # Fails with asyncio.TimeoutError if the handler never runs.
                handled = fired[name].wait()
                await asyncio.wait_for(handled, timeout=0.05)
async def test_remove_handler():
    """Verify that a removed handler is not called and cannot be removed twice.

    A handler is registered for "stream.1" and removed straight away. A
    NEW_STREAM message for that name is then fed to the mocked reader; the
    handler's event must stay unset, and removing the handler a second time
    must raise ``KeyError``.
    """
    host, port = "127.0.0.1", 7777
    name = "stream.1"
    was_handled = asyncio.Event()

    async def removed_handler(stream: Stream):
        assert stream.name == name
        assert stream.ip == host
        assert stream.port == port
        was_handled.set()

    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            multiplexer.set_handler(name, removed_handler)
            multiplexer.remove_handler(name)

            new_stream_message = get_encoded_message(
                name, MplexFlag.NEW_STREAM, name.encode())
            reader.feed_data(new_stream_message)

            # The event must remain unset, so waiting on it has to time out.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(was_handled.wait(), timeout=0.05)
            # The handler is already gone; a second removal must fail.
            with pytest.raises(KeyError):
                multiplexer.remove_handler(name)
async def test_iterate_read(data: bytes):
    """Exercise ``async for`` iteration over a multiplexed stream.

    ``data`` is repeated a random number of times, joined by newlines, and
    fed to the mocked reader as one MESSAGE for "stream.1". Each iterated
    line is compared with the corresponding newline-terminated chunk. After
    the multiplexer context has exited, iterating the stream again is
    expected to raise ``RuntimeError``.
    """
    ip, port = ("127.0.0.1", 7777)
    reader_mock, writer_mock = get_connection_mock(ip, port)
    with patch("asyncio.open_connection",
               return_value=(reader_mock, writer_mock)):
        async with open_multiplexer_context(ip, port) as multiplexer:
            stream_name = "stream.1"
            stream = await multiplexer.multiplex(stream_name)

            # Between 1 and len(data) copies of `data`, separated by b"\n".
            # Assumes `data` is non-empty: random.randint(1, 0) would raise.
            line_amount = random.randint(1, len(data))
            data_with_lines = b"\n".join([data for _ in range(line_amount)])
            # Expected lines, without their trailing newline.
            chunked_data = data_with_lines.split(b"\n")

            encoded_message = get_encoded_message(stream_name,
                                                  MplexFlag.MESSAGE,
                                                  data_with_lines)
            reader_mock.feed_data(encoded_message)

            i = 0
            async for line in stream:
                print(line, i, len(chunked_data))
                assert line == chunked_data[i] + b"\n"
                i += 1
                # NOTE(review): `i` can be at most len(chunked_data) here (a
                # larger index would already have raised IndexError in the
                # assertion above), so this break looks unreachable and the
                # loop ends only when the iterator itself stops or raises —
                # confirm the intended bound.
                if i > len(chunked_data):
                    break

        # The multiplexer context has exited; iteration must now fail.
        with pytest.raises(RuntimeError):
            async for line in stream:
                print(line)
async def test_multiplex_empty_name():
    """Verify that opening a stream with an empty name raises ValueError."""
    ip, port = ("127.0.0.1", 7777)
    reader_mock, writer_mock = get_connection_mock(ip, port)
    with patch("asyncio.open_connection",
               return_value=(reader_mock, writer_mock)):
        async with open_multiplexer_context(ip, port) as multiplexer:
            with pytest.raises(ValueError):
                # The call is expected to raise, so its result is not bound.
                await multiplexer.multiplex("")
async def test_close_multiplexer(mock: AsyncMock):
    """Verify that closing the multiplexer shuts down the mocked writer.

    ``mock`` is configured to return the mocked reader/writer pair; after
    ``close()`` the writer must have had ``write_eof`` called and
    ``wait_closed`` awaited.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    mock.return_value = (reader, writer)

    opened = await open_multiplexer(host, port)
    await opened.close()

    writer.write_eof.assert_called()
    writer.wait_closed.assert_awaited()
async def test_open_multiplexer(mock: AsyncMock):
    """Verify that ``open_multiplexer`` connects and returns a Multiplexer.

    ``mock`` must have been awaited with the target address, and the
    returned object must expose that same address as ``ip`` and ``port``.
    """
    address = ("127.0.0.1", 7777)
    host, port = address
    mock.return_value = get_connection_mock(host, port)

    opened = await open_multiplexer(host, port)

    mock.assert_awaited_with(*address)
    assert isinstance(opened, Multiplexer)
    assert opened.ip == host
    assert opened.port == port
async def test_multiplexer_contextmanager(mock: AsyncMock):
    """Verify that the multiplexer context manager opens and closes cleanly.

    Inside the context, ``mock`` must have been awaited with the address and
    a ``Multiplexer`` must be yielded. After the context exits, the mocked
    writer must have had ``write_eof`` called and ``wait_closed`` awaited.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    mock.return_value = (reader, writer)

    async with open_multiplexer_context(host, port) as opened:
        mock.assert_awaited_with(host, port)
        assert isinstance(opened, Multiplexer)

    # These checks run only after the context manager has exited.
    writer.write_eof.assert_called()
    writer.wait_closed.assert_awaited()
async def test_open_multiplex_twice(mock: AsyncMock):
    """Verify a stream name is reusable only after its stream is closed."""
    ip, port = ("127.0.0.1", 7777)
    mock.return_value = get_connection_mock(ip, port)
    async with open_multiplexer_context(ip, port) as multiplexer:
        first_stream = await multiplexer.multiplex("stream.1")
        # The name is still in use, so opening it again must be rejected.
        with pytest.raises(ValueError):
            await multiplexer.multiplex("stream.1")

        # Once the first stream is closed the same name must open without
        # raising; the resulting stream itself is not inspected.
        await first_stream.close()
        await multiplexer.multiplex("stream.1")
async def test_errors_after_close_multiplexer(mock: AsyncMock):
    """Verify that a closed multiplexer rejects further operations.

    After ``close()``, both opening a stream and closing again must raise
    ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    mock.return_value = (reader, writer)

    closed_multiplexer = await open_multiplexer(host, port)
    await closed_multiplexer.close()

    with pytest.raises(RuntimeError):
        await closed_multiplexer.multiplex("stream.1")
    with pytest.raises(RuntimeError):
        await closed_multiplexer.close()
Beispiel #10
0
async def test_read_uvarint_overflow():
    """Verify that an over-long uvarint makes ``read_message`` overflow.

    A valid uvarint header is followed by ``UVARINT_MAX_BYTES + 1`` bytes
    that all have their high bit set; reading that message must raise
    ``OverflowError``.
    """
    # One byte more than UVARINT_MAX_BYTES, each with only the high bit set.
    uvarint_overflow = bytearray([0b10000000]) * (UVARINT_MAX_BYTES + 1)
    reader_mock, writer_mock = get_connection_mock("127.0.0.1", 7777)
    stream_id, flag = 12, MplexFlag.NEW_STREAM
    encoded_message = uvarint.encode(stream_id << 3 | flag) + uvarint_overflow
    reader_mock.feed_data(encoded_message)

    mplex_protocol = MplexProtocol(reader_mock, writer_mock)
    with pytest.raises(OverflowError):
        # The call is expected to raise, so its result is not bound.
        await mplex_protocol.read_message()
async def test_close_from_multiplexer():
    """Verify that closing the multiplexer also invalidates its streams.

    A stream is opened, the multiplexer is closed, and closing the stream
    afterwards must raise ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        opened = await open_multiplexer(host, port)
        orphaned_stream = await opened.multiplex("stream.1")

        await opened.close()
        with pytest.raises(RuntimeError):
            await orphaned_stream.close()
async def test_write_to_stream(data: bytes):
    """Verify that writing to a stream emits one encoded MESSAGE.

    The last call to the mocked writer's ``write`` must carry the encoding
    of ``data`` as a MESSAGE for "stream.1".
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            outgoing_stream = await multiplexer.multiplex(name)
            await outgoing_stream.write(data)

            expected_bytes = get_encoded_message(name, MplexFlag.MESSAGE,
                                                 data)
            writer.write.assert_called_with(expected_bytes)
Beispiel #13
0
async def test_write_message(fragmented_message: Tuple[StreamID, MplexFlag,
                                                       StreamData]):
    """Verify the wire format produced by ``MplexProtocol.write_message``.

    The expected bytes are the uvarint of ``stream_id << 3 | flag``, then
    the uvarint of the payload length, then the payload itself. The mocked
    writer must receive exactly those bytes and have ``drain`` awaited.
    """
    reader, writer = get_connection_mock("127.0.0.1", 7777)
    stream_id, flag, data = fragmented_message

    protocol = MplexProtocol(reader, writer)
    outgoing = MplexMessage(stream_id=stream_id, flag=flag, data=data)
    await protocol.write_message(outgoing)

    header = uvarint.encode(stream_id << 3 | flag)
    length_prefix = uvarint.encode(len(data))
    writer.write.assert_called_with(header + length_prefix + data)
    writer.drain.assert_awaited()
async def test_read_invalid_flag():
    """Verify that a message carrying flag value 4 delivers no stream data.

    The message is fed to the mocked reader for "stream.1"; a subsequent
    one-byte read on that stream must time out.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            # Going by the test name, 4 is presumably not a valid MplexFlag.
            invalid_flag = 4
            reader.feed_data(get_encoded_message(name, invalid_flag, b"data"))

            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_read_zero_from_stream(mock: AsyncMock):
    """Verify that reading zero bytes from a stream returns ``b""``.

    A MESSAGE with payload ``b"data"`` is fed first, so data is pending
    when ``read(0)`` is issued.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    mock.return_value = (reader, writer)

    async with open_multiplexer_context(host, port) as multiplexer:
        name = "stream.1"
        stream = await multiplexer.multiplex(name)

        incoming = get_encoded_message(name, MplexFlag.MESSAGE, b"data")
        reader.feed_data(incoming)

        nothing = await asyncio.wait_for(stream.read(0), timeout=0.05)
        assert nothing == b""
Beispiel #16
0
async def test_read_message(fragmented_message: Tuple[StreamID, MplexFlag,
                                                      StreamData]):
    """Verify that ``MplexProtocol.read_message`` decodes a framed message.

    The bytes fed to the mocked reader are the uvarint of
    ``stream_id << 3 | flag``, the uvarint of the payload length, and the
    payload. The decoded ``MplexMessage`` must carry the same three fields.
    """
    reader, writer = get_connection_mock("127.0.0.1", 7777)
    stream_id, flag, data = fragmented_message

    protocol = MplexProtocol(reader, writer)
    header = uvarint.encode(stream_id << 3 | flag)
    length_prefix = uvarint.encode(len(data))
    reader.feed_data(header + length_prefix + data)

    decoded = await protocol.read_message()
    assert isinstance(decoded, MplexMessage)
    assert decoded.stream_id == stream_id
    assert decoded.flag == flag
    assert decoded.data == data
async def test_multiplex_one(stream_name: StreamName):
    """Verify that opening a stream yields an open Stream and announces it.

    The returned stream must expose the connection address and its name and
    must not be closed. The mocked writer must have been given the encoded
    NEW_STREAM message whose payload is the encoded stream name.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            opened = await multiplexer.multiplex(stream_name)
            assert isinstance(opened, Stream)
            assert opened.ip == host
            assert opened.port == port
            assert opened.name == stream_name
            assert not opened.is_closed()

            announcement = get_encoded_message(
                stream_name, MplexFlag.NEW_STREAM, stream_name.encode())
            writer.write.assert_called_with(announcement)
async def test_readline(data: bytes):
    """Verify ``Stream.readline`` before and after the multiplexer closes.

    Two copies of ``data`` separated by a newline are fed as one MESSAGE;
    ``readline`` must return everything up to and including the first
    newline. Once the multiplexer context has exited, ``readline`` must
    raise ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            payload = data + b"\n" + data
            reader.feed_data(
                get_encoded_message(name, MplexFlag.MESSAGE, payload))

            # The payload always contains a newline, so the separator
            # returned by partition() is never empty.
            head, newline, _ = payload.partition(b"\n")
            expected_line = head + newline
            assert await stream.readline() == expected_line

        with pytest.raises(RuntimeError):
            await stream.readline()
async def test_readexactly(data: bytes):
    """Verify ``Stream.readexactly`` before and after the multiplexer closes.

    A random byte count in ``[0, len(data) - 1]`` is requested after
    ``data`` has been fed as one MESSAGE; the result must be the matching
    prefix of ``data``. Once the multiplexer context has exited, the same
    call must raise ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            wanted = random.randint(0, len(data) - 1)

            incoming = get_encoded_message(name, MplexFlag.MESSAGE, data)
            reader.feed_data(incoming)

            expected_prefix = data[:wanted]
            assert await stream.readexactly(wanted) == expected_prefix

        with pytest.raises(RuntimeError):
            await stream.readexactly(wanted)
async def test_read_until_close():
    """Verify that an unbounded ``read()`` returns the data sent before CLOSE.

    A MESSAGE carrying ``b"data"`` and then a CLOSE with an empty payload
    are fed for "stream.1"; ``read()`` with no size must return exactly
    ``b"data"`` within 50 ms.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            payload = b"data"
            data_message = get_encoded_message(name, MplexFlag.MESSAGE,
                                               payload)
            reader.feed_data(data_message)
            close_message = get_encoded_message(name, MplexFlag.CLOSE, b"")
            reader.feed_data(close_message)

            received = await asyncio.wait_for(stream.read(), timeout=0.05)
            assert received == payload
async def test_read_from_one_stream(data: StreamData):
    """Verify that fixed-size reads return consecutive slices of the payload.

    ``data`` is fed as one MESSAGE and then read back in chunks of a random
    size (1 to 10 bytes, capped at ``len(data)``). Only whole chunks are
    read; a trailing partial chunk is not checked.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            incoming = get_encoded_message(name, MplexFlag.MESSAGE, data)
            reader.feed_data(incoming)

            chunk_size = min(random.randint(1, 10), len(data))
            whole_chunks = len(data) // chunk_size
            # Step through the start offset of every whole chunk.
            for start in range(0, whole_chunks * chunk_size, chunk_size):
                received = await asyncio.wait_for(stream.read(chunk_size),
                                                  timeout=1)
                assert received == data[start:start + chunk_size]
async def test_readuntil(data: bytes):
    """Verify ``Stream.readuntil`` before and after the multiplexer closes.

    A random single byte of ``data`` is used as the separator. After
    ``data`` is fed as one MESSAGE, ``readuntil`` must return everything up
    to and including the first occurrence of that byte. Once the
    multiplexer context has exited, the same call must raise
    ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            position = random.randint(0, len(data) - 1)
            separator = data[position:position + 1]
            incoming = get_encoded_message(name, MplexFlag.MESSAGE, data)
            reader.feed_data(incoming)

            # The separator is taken from `data`, so it is always found.
            head, found, _ = data.partition(separator)
            expected = head + found
            assert await stream.readuntil(separator) == expected

        with pytest.raises(RuntimeError):
            await stream.readuntil(separator)
async def test_close_stream(mock: AsyncMock):
    """Verify that closing a stream sends CLOSE and disables the stream.

    After ``close()`` the mocked writer must have received the encoded CLOSE
    message and ``is_closed()`` must be true. A second close, a read and a
    write must each raise ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    mock.return_value = (reader, writer)

    async with open_multiplexer_context(host, port) as multiplexer:
        name = "stream.1"
        stream = await multiplexer.multiplex(name)
        await stream.close()

        close_message = get_encoded_message(name, MplexFlag.CLOSE,
                                            name.encode())
        writer.write.assert_called_with(close_message)
        assert stream.is_closed()

        with pytest.raises(RuntimeError):
            await stream.close()
        with pytest.raises(RuntimeError):
            await asyncio.wait_for(stream.read(1), timeout=0.05)
        with pytest.raises(RuntimeError):
            await stream.write(b"data")
async def test_read_close_flag():
    """Verify that a stream rejects operations after a CLOSE is received.

    A CLOSE message with an empty payload is fed for "stream.1". After a
    50 ms pause, closing, writing to and reading from the stream must each
    raise ``RuntimeError``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            name = "stream.1"
            stream = await multiplexer.multiplex(name)

            close_message = get_encoded_message(name, MplexFlag.CLOSE, b"")
            reader.feed_data(close_message)

            # Hand control back to the event loop before asserting;
            # presumably this lets the fed CLOSE message be processed.
            await asyncio.sleep(0.05)
            with pytest.raises(RuntimeError):
                await stream.close()
            with pytest.raises(RuntimeError):
                await stream.write(b"data")
            with pytest.raises(RuntimeError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_read_from_multiple_streams(stream_names: List[StreamName],
                                          data: StreamData):
    """Verify that a message reaches only the stream it is addressed to.

    A stream is opened for every name and two distinct names are picked at
    random. ``data`` is fed as a MESSAGE for the first one: reading from the
    second must time out, while the first must return ``data``.
    """
    host, port = "127.0.0.1", 7777
    reader, writer = get_connection_mock(host, port)
    connection_patch = patch("asyncio.open_connection",
                             return_value=(reader, writer))
    with connection_patch:
        async with open_multiplexer_context(host, port) as multiplexer:
            opened = {}
            for name in stream_names:
                opened[name] = await multiplexer.multiplex(name)

            receiver, non_receiver = random.sample(stream_names, k=2)
            incoming = get_encoded_message(receiver, MplexFlag.MESSAGE, data)
            reader.feed_data(incoming)

            # Nothing was addressed to this stream, so the read must block.
            idle_read = opened[non_receiver].read(1)
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(idle_read, timeout=0.05)

            assert await opened[receiver].read(len(data)) == data
Beispiel #26
0
 def new_mock_connection(self, remote_ip, remote_port):
     """Create a mocked connection and schedule ``self.callback`` on it.

     Returns the ``(reader, writer)`` mock pair built for the given remote
     address. Must be called while an event loop is running, because the
     callback is scheduled with ``asyncio.create_task``.
     """
     reader, writer = get_connection_mock(remote_ip, remote_port)
     pending_callback = self.callback(reader, writer)
     # NOTE(review): the created task is not stored anywhere; asyncio keeps
     # only weak references to tasks — confirm this is acceptable here.
     asyncio.create_task(pending_callback)
     return reader, writer
Beispiel #27
0
def test_create_mplex_protocol():
    """Verify that ``MplexProtocol`` can be built from a mocked connection."""
    reader, writer = get_connection_mock("127.0.0.1", 7777)
    protocol = MplexProtocol(reader, writer)
    assert isinstance(protocol, MplexProtocol)