class TestTransportMock(unittest.TestCase): def setUp(self): self.protocol = make_protocol_mock() self.loop = asyncio.get_event_loop() self.t = TransportMock(self, self.protocol, loop=self.loop) def _run_test(self, t, *args, **kwargs): return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop) def test_run_test(self): self._run_test(self.t, []) self.assertSequenceEqual(self.protocol.mock_calls, [ unittest.mock.call.connection_made(self.t), unittest.mock.call.connection_lost(None), ]) def test_stimulus(self): self._run_test(self.t, [], stimulus=b"foo") self.assertSequenceEqual(self.protocol.mock_calls, [ unittest.mock.call.connection_made(self.t), unittest.mock.call.data_received(b"foo"), unittest.mock.call.connection_lost(None), ]) def test_request_response(self): def data_received(data): assert data in {b"foo", b"baz"} if data == b"foo": self.t.write(b"bar") elif data == b"baz": self.t.close() self.protocol.data_received = data_received self._run_test(self.t, [ TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")), TransportMock.Close() ], stimulus=b"foo") def test_request_multiresponse(self): def data_received(data): assert data in {b"foo", b"bar", b"baz"} if data == b"foo": self.t.write(b"bar") elif data == b"bar": self.t.write(b"baric") elif data == b"baz": self.t.close() self.protocol.data_received = data_received self._run_test(self.t, [ TransportMock.Write(b"bar", response=[ TransportMock.Receive(b"bar"), TransportMock.Receive(b"baz") ]), TransportMock.Write(b"baric"), TransportMock.Close() ], stimulus=b"foo") def test_catch_asynchronous_invalid_action(self): def connection_made(transport): self.loop.call_soon(transport.close) self.protocol.connection_made = connection_made with self.assertRaises(AssertionError): self._run_test(self.t, [TransportMock.Write(b"foo")]) def test_catch_invalid_write(self): def connection_made(transport): transport.write(b"fnord") self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "mismatch of expected and written data"): self._run_test(self.t, [TransportMock.Write(b"foo")]) def test_catch_surplus_write(self): def connection_made(transport): transport.write(b"fnord") self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected write"): self._run_test(self.t, []) def test_catch_unexpected_close(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected close"): self._run_test(self.t, [TransportMock.Write(b"foo")]) def test_catch_surplus_close(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected close"): self._run_test(self.t, []) def test_allow_asynchronous_partial_write(self): def connection_made(transport): self.loop.call_soon(transport.write, b"f") self.loop.call_soon(transport.write, b"o") self.loop.call_soon(transport.write, b"o") self.protocol.connection_made = connection_made self._run_test(self.t, [TransportMock.Write(b"foo")]) def test_asynchronous_request_response(self): def data_received(data): self.assertIn(data, {b"foo", b"baz"}) if data == b"foo": self.loop.call_soon(self.t.write, b"bar") elif data == b"baz": self.loop.call_soon(self.t.close) self.protocol.data_received = data_received self._run_test(self.t, [ TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")), TransportMock.Close() ], stimulus=b"foo") def test_response_eof_received(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made self._run_test( self.t, [TransportMock.Close(response=TransportMock.ReceiveEof())]) self.assertSequenceEqual(self.protocol.mock_calls, [ unittest.mock.call.eof_received(), unittest.mock.call.connection_lost(None) ]) def test_response_lose_connection(self): def connection_made(transport): transport.close() obj = object() self.protocol.connection_made = connection_made self._run_test( self.t, [TransportMock.Close(response=TransportMock.LoseConnection(obj))]) self.assertSequenceEqual(self.protocol.mock_calls, [unittest.mock.call.connection_lost(obj)]) def test_writelines(self): def connection_made(transport): transport.writelines([b"foo", b"bar"]) self.protocol.connection_made = connection_made self._run_test(self.t, [TransportMock.Write(b"foobar")]) def test_can_write_eof(self): self.assertTrue(self.t.can_write_eof()) def test_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made self._run_test(self.t, [TransportMock.Abort()]) def test_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made self._run_test(self.t, [TransportMock.WriteEof()]) def test_catch_surplus_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected write_eof"): self._run_test(self.t, []) def test_catch_unexpected_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected write_eof"): self._run_test(self.t, [TransportMock.Abort()]) def test_catch_surplus_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected abort"): self._run_test(self.t, []) def test_catch_unexpected_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected abort"): self._run_test(self.t, [TransportMock.WriteEof()]) def test_invalid_response(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made with self.assertRaisesRegex(RuntimeError, "test specification incorrect"): self._run_test(self.t, [TransportMock.Write(b"foo", response=1)]) def test_response_sequence(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made self._run_test(self.t, [ TransportMock.Write(b"foo", response=[ TransportMock.Receive(b"foo"), TransportMock.ReceiveEof() ]) ]) self.assertSequenceEqual(self.protocol.mock_calls, [ unittest.mock.call.data_received(b"foo"), unittest.mock.call.eof_received(), unittest.mock.call.connection_lost(None), ]) def test_clear_error_message(self): def connection_made(transport): transport.write(b"foo") transport.write(b"bar") self.protocol.connection_made = connection_made with self.assertRaises(AssertionError): self._run_test(self.t, [TransportMock.Write(b"baz")]) def test_detached_response(self): data = [] def data_received(blob): data.append(blob) def connection_made(transport): transport.write(b"foo") self.assertFalse(data) self.protocol.connection_made = connection_made self.protocol.data_received = data_received self._run_test(self.t, [ TransportMock.Write(b"foo", response=TransportMock.Receive(b"bar")) ]) def test_no_response_conflict(self): data = [] def data_received(blob): data.append(blob) def connection_made(transport): transport.write(b"foo") self.assertFalse(data) transport.write(b"bar") self.protocol.connection_made = connection_made self.protocol.data_received = data_received self._run_test(self.t, [ TransportMock.Write( b"foo", response=TransportMock.Receive(b"baz"), ), TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")) ]) def test_partial(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made self._run_test(self.t, [ TransportMock.Write(b"foo", ), ], partial=True) self.t.write_eof() self.t.close() self._run_test(self.t, [TransportMock.WriteEof(), TransportMock.Close()]) def test_no_starttls_by_default(self): self.assertFalse(self.t.can_starttls()) with self.assertRaises(RuntimeError): run_coroutine(self.t.starttls()) def test_starttls(self): self.t = TransportMock(self, self.protocol, with_starttls=True, loop=self.loop) self.assertTrue(self.t.can_starttls()) fut = asyncio.Future() def connection_made(transport): fut.set_result(None) self.protocol.connection_made = connection_made ssl_context = unittest.mock.Mock() post_handshake_callback = CoroutineMock() post_handshake_callback.return_value = None async def late_starttls(): await fut await self.t.starttls(ssl_context, post_handshake_callback) run_coroutine_with_peer( late_starttls(), self.t.run_test( [TransportMock.STARTTLS(ssl_context, post_handshake_callback)])) post_handshake_callback.assert_called_once_with(self.t) def test_starttls_without_callback(self): self.t = TransportMock(self, self.protocol, with_starttls=True, loop=self.loop) self.assertTrue(self.t.can_starttls()) fut = asyncio.Future() def connection_made(transport): fut.set_result(None) self.protocol.connection_made = connection_made ssl_context = unittest.mock.Mock() async def late_starttls(): await fut await self.t.starttls(ssl_context) run_coroutine_with_peer( late_starttls(), self.t.run_test([TransportMock.STARTTLS(ssl_context, None)])) def test_exception_from_stimulus_bubbles_up(self): exc = ConnectionError("foobar") def data_received(data): raise exc self.protocol.data_received = data_received with self.assertRaises(ConnectionError) as ctx: run_coroutine( self.t.run_test([], stimulus=TransportMock.Receive(b"foobar"))) self.assertIs(exc, ctx.exception) def tearDown(self): del self.t del self.loop del self.protocol
class TestTransportMock(unittest.TestCase): def setUp(self): self.protocol = make_protocol_mock() self.loop = asyncio.get_event_loop() self.t = TransportMock(self, self.protocol, loop=self.loop) def _run_test(self, t, *args, **kwargs): return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop) def test_run_test(self): self._run_test(self.t, []) self.assertSequenceEqual( self.protocol.mock_calls, [ unittest.mock.call.connection_made(self.t), unittest.mock.call.connection_lost(None), ]) def test_stimulus(self): self._run_test(self.t, [], stimulus=b"foo") self.assertSequenceEqual( self.protocol.mock_calls, [ unittest.mock.call.connection_made(self.t), unittest.mock.call.data_received(b"foo"), unittest.mock.call.connection_lost(None), ]) def test_request_response(self): def data_received(data): assert data in {b"foo", b"baz"} if data == b"foo": self.t.write(b"bar") elif data == b"baz": self.t.close() self.protocol.data_received = data_received self._run_test( self.t, [ TransportMock.Write( b"bar", response=TransportMock.Receive(b"baz")), TransportMock.Close() ], stimulus=b"foo") def test_request_multiresponse(self): def data_received(data): assert data in {b"foo", b"bar", b"baz"} if data == b"foo": self.t.write(b"bar") elif data == b"bar": self.t.write(b"baric") elif data == b"baz": self.t.close() self.protocol.data_received = data_received self._run_test( self.t, [ TransportMock.Write( b"bar", response=[ TransportMock.Receive(b"bar"), TransportMock.Receive(b"baz") ]), TransportMock.Write(b"baric"), TransportMock.Close() ], stimulus=b"foo") def test_catch_asynchronous_invalid_action(self): def connection_made(transport): self.loop.call_soon(transport.close) self.protocol.connection_made = connection_made with self.assertRaises(AssertionError): self._run_test( self.t, [ TransportMock.Write(b"foo") ]) def test_catch_invalid_write(self): def connection_made(transport): transport.write(b"fnord") self.protocol.connection_made = connection_made with self.assertRaisesRegex( AssertionError, "mismatch of expected and written data"): self._run_test( self.t, [ TransportMock.Write(b"foo") ]) def test_catch_surplus_write(self): def connection_made(transport): transport.write(b"fnord") self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected write"): self._run_test( self.t, [ ]) def test_catch_unexpected_close(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected close"): self._run_test( self.t, [ TransportMock.Write(b"foo") ]) def test_catch_surplus_close(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made with self.assertRaisesRegex(AssertionError, "unexpected close"): self._run_test( self.t, [ ]) def test_allow_asynchronous_partial_write(self): def connection_made(transport): self.loop.call_soon(transport.write, b"f") self.loop.call_soon(transport.write, b"o") self.loop.call_soon(transport.write, b"o") self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Write(b"foo") ]) def test_asynchronous_request_response(self): def data_received(data): self.assertIn(data, {b"foo", b"baz"}) if data == b"foo": self.loop.call_soon(self.t.write, b"bar") elif data == b"baz": self.loop.call_soon(self.t.close) self.protocol.data_received = data_received self._run_test( self.t, [ TransportMock.Write( b"bar", response=TransportMock.Receive(b"baz")), TransportMock.Close() ], stimulus=b"foo") def test_response_eof_received(self): def connection_made(transport): transport.close() self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Close( response=TransportMock.ReceiveEof() ) ]) self.assertSequenceEqual( self.protocol.mock_calls, [ unittest.mock.call.eof_received(), unittest.mock.call.connection_lost(None) ]) def test_response_lose_connection(self): def connection_made(transport): transport.close() obj = object() self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Close( response=TransportMock.LoseConnection(obj) ) ]) self.assertSequenceEqual( self.protocol.mock_calls, [ unittest.mock.call.connection_lost(obj) ]) def test_writelines(self): def connection_made(transport): transport.writelines([b"foo", b"bar"]) self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Write(b"foobar") ]) def test_can_write_eof(self): self.assertTrue(self.t.can_write_eof()) def test_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Abort() ]) def test_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.WriteEof() ]) def test_catch_surplus_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made with self.assertRaisesRegex( AssertionError, "unexpected write_eof"): self._run_test( self.t, [ ]) def test_catch_unexpected_write_eof(self): def connection_made(transport): transport.write_eof() self.protocol.connection_made = connection_made with self.assertRaisesRegex( AssertionError, "unexpected write_eof"): self._run_test( self.t, [ TransportMock.Abort() ]) def test_catch_surplus_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made with self.assertRaisesRegex( AssertionError, "unexpected abort"): self._run_test( self.t, [ ]) def test_catch_unexpected_abort(self): def connection_made(transport): transport.abort() self.protocol.connection_made = connection_made with self.assertRaisesRegex( AssertionError, "unexpected abort"): self._run_test( self.t, [ TransportMock.WriteEof() ]) def test_invalid_response(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made with self.assertRaisesRegex( RuntimeError, "test specification incorrect"): self._run_test( self.t, [ TransportMock.Write( b"foo", response=1) ]) def test_response_sequence(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Write( b"foo", response=[ TransportMock.Receive(b"foo"), TransportMock.ReceiveEof() ]) ]) self.assertSequenceEqual( self.protocol.mock_calls, [ unittest.mock.call.data_received(b"foo"), unittest.mock.call.eof_received(), unittest.mock.call.connection_lost(None), ]) def test_clear_error_message(self): def connection_made(transport): transport.write(b"foo") transport.write(b"bar") self.protocol.connection_made = connection_made with self.assertRaises(AssertionError): self._run_test( self.t, [ TransportMock.Write(b"baz") ]) def test_detached_response(self): data = [] def data_received(blob): data.append(blob) def connection_made(transport): transport.write(b"foo") self.assertFalse(data) self.protocol.connection_made = connection_made self.protocol.data_received = data_received self._run_test( self.t, [ TransportMock.Write( b"foo", response=TransportMock.Receive(b"bar") ) ]) def test_no_response_conflict(self): data = [] def data_received(blob): data.append(blob) def connection_made(transport): transport.write(b"foo") self.assertFalse(data) transport.write(b"bar") self.protocol.connection_made = connection_made self.protocol.data_received = data_received self._run_test( self.t, [ TransportMock.Write( b"foo", response=TransportMock.Receive(b"baz"), ), TransportMock.Write( b"bar", response=TransportMock.Receive(b"baz") ) ]) def test_partial(self): def connection_made(transport): transport.write(b"foo") self.protocol.connection_made = connection_made self._run_test( self.t, [ TransportMock.Write( b"foo", ), ], partial=True ) self.t.write_eof() self.t.close() self._run_test( self.t, [ TransportMock.WriteEof(), TransportMock.Close() ] ) def test_no_starttls_by_default(self): self.assertFalse(self.t.can_starttls()) with self.assertRaises(RuntimeError): run_coroutine(self.t.starttls()) def test_starttls(self): self.t = TransportMock(self, self.protocol, with_starttls=True, loop=self.loop) self.assertTrue(self.t.can_starttls()) fut = asyncio.Future() def connection_made(transport): fut.set_result(None) self.protocol.connection_made = connection_made ssl_context = unittest.mock.Mock() post_handshake_callback = unittest.mock.Mock() post_handshake_callback.return_value = [] @asyncio.coroutine def late_starttls(): yield from fut yield from self.t.starttls(ssl_context, post_handshake_callback) run_coroutine_with_peer( late_starttls(), self.t.run_test( [ TransportMock.STARTTLS(ssl_context, post_handshake_callback) ] ) ) post_handshake_callback.assert_called_once_with(self.t) def test_starttls_without_callback(self): self.t = TransportMock(self, self.protocol, with_starttls=True, loop=self.loop) self.assertTrue(self.t.can_starttls()) fut = asyncio.Future() def connection_made(transport): fut.set_result(None) self.protocol.connection_made = connection_made ssl_context = unittest.mock.Mock() @asyncio.coroutine def late_starttls(): yield from fut yield from self.t.starttls(ssl_context) run_coroutine_with_peer( late_starttls(), self.t.run_test( [ TransportMock.STARTTLS(ssl_context, None) ] ) ) def test_exception_from_stimulus_bubbles_up(self): exc = ConnectionError("foobar") def data_received(data): raise exc self.protocol.data_received = data_received with self.assertRaises(ConnectionError) as ctx: run_coroutine( self.t.run_test( [ ], stimulus=TransportMock.Receive(b"foobar") ) ) self.assertIs( exc, ctx.exception ) def tearDown(self): del self.t del self.loop del self.protocol