async def test_remove_handler():
    """Check that a removed handler is not invoked for a new stream.

    A handler is registered on the multiplexer and removed straight away.
    Feeding a NEW_STREAM message for that name must leave the event unset
    (the wait times out), and removing the handler a second time must raise
    ``KeyError``.

    NOTE(review): another test named ``test_remove_handler`` is visible in
    this source; if both live in the same module only the later definition
    is kept - confirm they belong to different files.
    """
    host = "127.0.0.1"
    tcp_port = 7777
    name = "stream.1"
    was_handled = asyncio.Event()

    async def handler(stream: Stream):
        # Only reached if the removed handler is (wrongly) still dispatched.
        assert stream.name == name
        assert stream.ip == host
        assert stream.port == tcp_port
        was_handled.set()

    reader, writer = get_connection_mock(host, tcp_port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(host, tcp_port) as mux:
            mux.set_handler(name, handler)
            mux.remove_handler(name)

            frame = get_encoded_message(
                name, MplexFlag.NEW_STREAM, name.encode()
            )
            reader.feed_data(frame)

            # The event must stay unset, so waiting for it has to time out.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(was_handled.wait(), timeout=0.05)

            # Nothing is registered under this name any more.
            with pytest.raises(KeyError):
                mux.remove_handler(name)
async def test_remove_handler():
    """Check that a listener does not invoke a handler that was removed.

    The handler is set and removed on the listener before a mocked remote
    connection sends a NEW_STREAM message. The event must remain unset (the
    wait times out) and a second removal must raise ``KeyError``.

    NOTE(review): another test named ``test_remove_handler`` is visible in
    this source; if both live in the same module the earlier definition is
    shadowed by this one - confirm they belong to different files.
    """
    remote_ip = "127.0.0.2"
    remote_port = 7777
    name = "stream.1"
    was_handled = asyncio.Event()

    async def handler(stream: Stream):
        # Only reached if the removed handler is (wrongly) still dispatched.
        assert stream.name == name
        assert stream.ip == remote_ip
        assert stream.port == remote_port
        was_handled.set()

    start_server_mock, _server_mock, remote_connections = get_start_sever_mock()
    with patch("asyncio.start_server", start_server_mock):
        async with bind_multiplex_listener_context("127.0.0.1", 7777) as listener:
            listener.set_handler(name, handler)
            listener.remove_handler(name)

            reader, _writer = remote_connections.new_mock_connection(
                remote_ip, remote_port
            )
            frame = get_encoded_message(
                name, MplexFlag.NEW_STREAM, name.encode()
            )
            reader.feed_data(frame)

            # The event must stay unset, so waiting for it has to time out.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(was_handled.wait(), timeout=0.05)

            # Nothing is registered under this name any more.
            with pytest.raises(KeyError):
                listener.remove_handler(name)
async def test_handle_new_stream(stream_names):
    """Check that each registered handler fires for its own NEW_STREAM.

    One handler per stream name is registered on the multiplexer. For every
    name a NEW_STREAM message is fed to the mocked reader, and the matching
    event must be set within 0.05 seconds.
    """
    ip = "127.0.0.1"
    port = 7777
    events_by_name = {}
    for name in stream_names:
        events_by_name[name] = asyncio.Event()

    async def handler(stream_name: StreamName, stream: Stream):
        assert stream.name == stream_name
        assert stream.ip == ip
        assert stream.port == port
        events_by_name[stream_name].set()

    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            # Register everything first, then drive the streams one by one.
            for name in stream_names:
                mux.set_handler(name, partial(handler, name))

            for name in stream_names:
                frame = get_encoded_message(
                    name, MplexFlag.NEW_STREAM, name.encode()
                )
                reader.feed_data(frame)
                await asyncio.wait_for(
                    events_by_name[name].wait(), timeout=0.05
                )
async def test_iterate_read(data: bytes):
    """Check line-by-line async iteration over a stream.

    ``data`` is repeated a random number of times, joined with newlines and
    fed as a single MESSAGE. Each iterated line must equal the corresponding
    newline-split chunk plus the trailing newline. Iterating the stream
    afterwards is expected to raise ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            repeats = random.randint(1, len(data))
            joined = b"\n".join([data] * repeats)
            expected_chunks = joined.split(b"\n")
            reader.feed_data(
                get_encoded_message(name, MplexFlag.MESSAGE, joined)
            )

            seen = 0
            async for line in stream:
                print(line, seen, len(expected_chunks))
                assert line == expected_chunks[seen] + b"\n"
                seen += 1
                # NOTE(review): this guard looks unreachable - indexing
                # expected_chunks[seen] raises IndexError before ``seen`` can
                # exceed the list length. Confirm how the loop is meant to end.
                if seen > len(expected_chunks):
                    break

        # NOTE(review): placed after the multiplexer context exits, on the
        # assumption that the RuntimeError comes from a closed stream.
        with pytest.raises(RuntimeError):
            async for line in stream:
                print(line)
async def test_handle_streams_after_connected(remote_address, stream_names):
    """Check that handlers set after a connection is open are still invoked.

    Mock remote connections are opened first. Only afterwards is a handler
    registered for each stream name, followed by a NEW_STREAM message on the
    most recently opened connection; each handler's event must be set within
    0.05 seconds.
    """
    remote_ip, remote_port = remote_address
    events_by_name = {}
    for name in stream_names:
        events_by_name[name] = asyncio.Event()

    async def handler(stream_name: StreamName, stream: Stream):
        assert stream.name == stream_name
        assert stream.ip == remote_ip
        assert stream.port == remote_port
        events_by_name[stream_name].set()

    start_server_mock, _server_mock, remote_connections = get_start_sever_mock()
    with patch("asyncio.start_server", start_server_mock):
        async with bind_multiplex_listener_context("127.0.0.1", 7777) as listener:
            # Connections first, handlers second. One connection is opened
            # per stream name, but only the last reader is used below.
            for _ in stream_names:
                reader, _writer = remote_connections.new_mock_connection(
                    remote_ip, remote_port
                )
                # Hand control back to the event loop after each connection.
                await asyncio.sleep(0)

            for name in stream_names:
                listener.set_handler(name, partial(handler, name))
                frame = get_encoded_message(
                    name, MplexFlag.NEW_STREAM, name.encode()
                )
                reader.feed_data(frame)
                await asyncio.wait_for(
                    events_by_name[name].wait(), timeout=0.05
                )
async def test_read_until_close():
    """Check that ``read()`` with no size returns the data sent before CLOSE.

    A MESSAGE carrying ``b"data"`` and then a CLOSE message are fed to the
    mocked reader; an unbounded ``read()`` must return exactly the message
    payload within 0.05 seconds.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)
            payload = b"data"

            # Payload first, then the close marker for the same stream.
            reader.feed_data(
                get_encoded_message(name, MplexFlag.MESSAGE, payload)
            )
            reader.feed_data(get_encoded_message(name, MplexFlag.CLOSE, b""))

            received = await asyncio.wait_for(stream.read(), timeout=0.05)
            assert received == payload
async def test_write_to_stream(data: bytes):
    """Check that ``Stream.write`` sends an encoded MESSAGE on the connection.

    After writing ``data`` to the stream, the most recent call to the mocked
    writer's ``write`` must carry the MESSAGE encoding of ``data`` for that
    stream name.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            await stream.write(data)

            expected_frame = get_encoded_message(name, MplexFlag.MESSAGE, data)
            writer.write.assert_called_with(expected_frame)
async def test_read_invalid_flag():
    """Check that a message with flag value 4 delivers no data to the stream.

    The message is encoded with the raw integer ``4`` instead of an
    ``MplexFlag`` member; a subsequent one-byte read must time out.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            # Raw integer used in place of an MplexFlag member.
            bad_flag = 4
            reader.feed_data(get_encoded_message(name, bad_flag, b"data"))

            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_read_zero_from_stream(mock: AsyncMock):
    """Check that reading zero bytes returns ``b""`` even when data is queued.

    ``mock`` is configured to return the mocked reader/writer pair. A
    MESSAGE carrying ``b"data"`` is fed, and ``read(0)`` must return an
    empty bytes object within 0.05 seconds.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    mock.return_value = (reader, writer)

    async with open_multiplexer_context(ip, port) as mux:
        name = "stream.1"
        stream = await mux.multiplex(name)

        reader.feed_data(get_encoded_message(name, MplexFlag.MESSAGE, b"data"))

        received = await asyncio.wait_for(stream.read(0), timeout=0.05)
        assert received == b""
async def test_multiplex_one(stream_name: StreamName):
    """Check the stream returned by ``multiplex`` and the message it sends.

    The result must be an open ``Stream`` carrying the connection's ip, port
    and the requested name, and the most recent write on the mocked writer
    must be the encoded NEW_STREAM message for that name.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            stream = await mux.multiplex(stream_name)

            assert isinstance(stream, Stream)
            assert stream.ip == ip
            assert stream.port == port
            assert stream.name == stream_name
            assert not stream.is_closed()

            expected_frame = get_encoded_message(
                stream_name, MplexFlag.NEW_STREAM, stream_name.encode()
            )
            writer.write.assert_called_with(expected_frame)
async def test_readline(data: bytes):
    """Check that ``readline`` returns everything up to the first newline.

    The fed payload is ``data``, a newline, then ``data`` again. The first
    ``readline`` must return the bytes up to and including the first
    newline; a further ``readline`` is expected to raise ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            payload = data + b"\n" + data
            reader.feed_data(
                get_encoded_message(name, MplexFlag.MESSAGE, payload)
            )

            # payload always contains a newline, so the separator is b"\n".
            first_line, newline, _rest = payload.partition(b"\n")
            assert await stream.readline() == first_line + newline

        # NOTE(review): placed after the multiplexer context exits, on the
        # assumption that the RuntimeError comes from a closed stream.
        with pytest.raises(RuntimeError):
            await stream.readline()
async def test_readexactly(data: bytes):
    """Check that ``readexactly(n)`` returns the first ``n`` fed bytes.

    ``n`` is chosen at random between 0 and ``len(data) - 1``. After the
    first read succeeds, a second ``readexactly`` is expected to raise
    ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            count = random.randint(0, len(data) - 1)
            reader.feed_data(get_encoded_message(name, MplexFlag.MESSAGE, data))

            received = await stream.readexactly(count)
            assert received == data[:count]

        # NOTE(review): placed after the multiplexer context exits, on the
        # assumption that the RuntimeError comes from a closed stream.
        with pytest.raises(RuntimeError):
            await stream.readexactly(count)
async def test_read_from_one_stream(data: StreamData):
    """Check that fixed-size reads return consecutive slices of the payload.

    ``data`` is fed as one MESSAGE and read back in chunks of a random size
    (1 to 10, capped at ``len(data)``). Only whole chunks are compared; any
    trailing remainder shorter than the chunk size is not read.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            reader.feed_data(get_encoded_message(name, MplexFlag.MESSAGE, data))

            chunk_size = min(random.randint(1, 10), len(data))
            whole_chunks = len(data) // chunk_size
            for index in range(whole_chunks):
                start = index * chunk_size
                chunk = await asyncio.wait_for(
                    stream.read(chunk_size), timeout=1
                )
                assert chunk == data[start:start + chunk_size]
async def test_read_close_flag():
    """Check that a stream rejects operations after a CLOSE message arrives.

    A CLOSE message for the stream is fed to the mocked reader. After a
    short sleep, ``close``, ``write`` and ``read`` on the stream must each
    raise ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            reader.feed_data(get_encoded_message(name, MplexFlag.CLOSE, b""))

            # Give control back to the event loop before asserting
            # (presumably so the fed CLOSE message gets processed).
            await asyncio.sleep(0.05)

            with pytest.raises(RuntimeError):
                await stream.close()
            with pytest.raises(RuntimeError):
                await stream.write(b"data")
            with pytest.raises(RuntimeError):
                await asyncio.wait_for(stream.read(1), timeout=0.05)
async def test_close_stream(mock: AsyncMock):
    """Check that closing a stream sends CLOSE and blocks further use.

    After ``close`` the most recent write on the mocked writer must be the
    CLOSE message (encoded with the stream name as payload), the stream must
    report itself closed, and ``close``, ``read`` and ``write`` must each
    raise ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    mock.return_value = (reader, writer)

    async with open_multiplexer_context(ip, port) as mux:
        name = "stream.1"
        stream = await mux.multiplex(name)

        await stream.close()

        expected_frame = get_encoded_message(
            name, MplexFlag.CLOSE, name.encode()
        )
        writer.write.assert_called_with(expected_frame)
        assert stream.is_closed()

        # Every operation on the closed stream must be rejected.
        with pytest.raises(RuntimeError):
            await stream.close()
        with pytest.raises(RuntimeError):
            await asyncio.wait_for(stream.read(1), timeout=0.05)
        with pytest.raises(RuntimeError):
            await stream.write(b"data")
async def test_readuntil(data: bytes):
    """Check that ``readuntil`` returns bytes up to and including a separator.

    The separator is a single byte taken from a random position in ``data``.
    The first ``readuntil`` must return the prefix before the separator's
    first occurrence plus the separator; a second call is expected to raise
    ``RuntimeError``.
    """
    ip = "127.0.0.1"
    port = 7777
    reader, writer = get_connection_mock(ip, port)
    with patch("asyncio.open_connection", return_value=(reader, writer)):
        async with open_multiplexer_context(ip, port) as mux:
            name = "stream.1"
            stream = await mux.multiplex(name)

            position = random.randint(0, len(data) - 1)
            separator = data[position:position + 1]
            reader.feed_data(get_encoded_message(name, MplexFlag.MESSAGE, data))

            # Everything before the first occurrence of the separator.
            prefix = data.partition(separator)[0]
            received = await stream.readuntil(separator)
            assert received == prefix + separator

        # NOTE(review): placed after the multiplexer context exits, on the
        # assumption that the RuntimeError comes from a closed stream.
        with pytest.raises(RuntimeError):
            await stream.readuntil(separator)
async def test_read_from_multiple_streams(stream_names: List[StreamName],
                                          data: StreamData):
    """Check that a message reaches only the stream it is addressed to.

    A stream is opened for every name, then two distinct names are sampled:
    a receiver and a non-receiver. A MESSAGE for the receiver is fed to the
    mocked reader. Reading from the non-receiver must time out, while the
    receiver must yield exactly ``data``.

    The final read is bounded with ``asyncio.wait_for`` so that a regression
    fails the test instead of hanging the run.
    """
    ip, port = ("127.0.0.1", 7777)
    reader_mock, writer_mock = get_connection_mock(ip, port)
    with patch("asyncio.open_connection",
               return_value=(reader_mock, writer_mock)):
        async with open_multiplexer_context(ip, port) as multiplexer:
            streams = {}
            for stream_name in stream_names:
                streams[stream_name] = await multiplexer.multiplex(stream_name)

            # random.sample needs at least two names to pick from.
            receiver, non_receiver = random.sample(stream_names, k=2)
            encoded_message = get_encoded_message(receiver,
                                                  MplexFlag.MESSAGE, data)
            reader_mock.feed_data(encoded_message)

            # Nothing was sent to the other stream, so its read must block.
            with pytest.raises(asyncio.TimeoutError):
                await asyncio.wait_for(streams[non_receiver].read(1),
                                       timeout=0.05)

            read_data = await asyncio.wait_for(
                streams[receiver].read(len(data)), timeout=1)
            assert read_data == data