def test_starttls_without_callback(self):
    """STARTTLS issued without a post-handshake callback must match a
    STARTTLS expectation whose callback is ``None``."""
    self.t = TransportMock(self, self.protocol, with_starttls=True,
                           loop=self.loop)
    self.assertTrue(self.t.can_starttls())

    fut = asyncio.Future()

    def connection_made(transport):
        fut.set_result(None)

    self.protocol.connection_made = connection_made

    ssl_context = unittest.mock.Mock()

    # ``@asyncio.coroutine`` generator-based coroutines were removed in
    # Python 3.11; a native coroutine is required there.
    async def late_starttls():
        # only start TLS once the transport has reported connection_made
        await fut
        await self.t.starttls(ssl_context)

    run_coroutine_with_peer(
        late_starttls(),
        self.t.run_test(
            [
                TransportMock.STARTTLS(ssl_context, None)
            ]
        )
    )
def test_starttls(self):
    """STARTTLS with a post-handshake callback matches the STARTTLS
    expectation, and the callback is called exactly once with the
    transport."""
    self.t = TransportMock(
        self,
        self.protocol,
        with_starttls=True,
        loop=self.loop,
    )
    self.assertTrue(self.t.can_starttls())

    connected = asyncio.Future()

    def on_connection_made(transport):
        connected.set_result(None)

    self.protocol.connection_made = on_connection_made

    tls_context = unittest.mock.Mock()
    callback = CoroutineMock()
    callback.return_value = None

    async def late_starttls():
        # wait for connection_made before negotiating TLS
        await connected
        await self.t.starttls(tls_context, callback)

    run_coroutine_with_peer(
        late_starttls(),
        self.t.run_test([TransportMock.STARTTLS(tls_context, callback)]),
    )

    callback.assert_called_once_with(self.t)
def test_partial(self):
    """A run with ``partial=True`` checks only the write from
    connection_made; a second run then checks the later write_eof and
    close."""
    def write_greeting(transport):
        transport.write(b"foo")

    self.protocol.connection_made = write_greeting

    first_round = [TransportMock.Write(b"foo")]
    self._run_test(self.t, first_round, partial=True)

    self.t.write_eof()
    self.t.close()

    second_round = [TransportMock.WriteEof(), TransportMock.Close()]
    self._run_test(self.t, second_round)
def test_no_response_conflict(self):
    """Two back-to-back writes, each with its own response, are both
    accepted; the first response is not delivered while connection_made
    is still running."""
    received = []

    def on_data(blob):
        received.append(blob)

    def on_connect(transport):
        transport.write(b"foo")
        # the response to the first write must not have arrived yet
        self.assertFalse(received)
        transport.write(b"bar")

    self.protocol.connection_made = on_connect
    self.protocol.data_received = on_data

    expectations = [
        TransportMock.Write(b"foo", response=TransportMock.Receive(b"baz")),
        TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")),
    ]
    self._run_test(self.t, expectations)
def test_iq_results_are_not_replied_to(self):
    """Receiving an IQ of type "result" (no request was sent in this
    test) must not cause the stanza stream to write anything."""
    import aioxmpp.protocol
    import aioxmpp.stream

    major, minor = 1, 0
    features = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features,
    )
    t = TransportMock(self, xmlstream)
    stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

    peer_header = PEER_STREAM_HEADER_TEMPLATE.format(
        minor=minor, major=major).encode("utf-8")

    # stream header exchange
    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                STREAM_HEADER,
                response=[TransportMock.Receive(peer_header)],
            ),
        ],
        partial=True,
    ))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)

    stream.start(xmlstream)

    run_coroutine(t.run_test(
        [],
        stimulus=[
            TransportMock.Receive(
                b'<iq type="result" id="foo">'
                b'<payload xmlns="fnord"/>'
                b'</iq>'
            ),
        ],
        partial=True,
    ))

    stream.flush_incoming()
    run_coroutine(asyncio.sleep(0))

    # an empty expectation list: any reply write would fail here
    run_coroutine(t.run_test([]))

    stream.stop()
def test_response_eof_received(self):
    """A ReceiveEof response to close() results in eof_received followed
    by connection_lost(None) on the protocol."""
    def close_immediately(transport):
        transport.close()

    self.protocol.connection_made = close_immediately

    self._run_test(self.t, [
        TransportMock.Close(response=TransportMock.ReceiveEof()),
    ])

    expected_calls = [
        unittest.mock.call.eof_received(),
        unittest.mock.call.connection_lost(None),
    ]
    self.assertSequenceEqual(self.protocol.mock_calls, expected_calls)
def test_response_lose_connection(self):
    """A LoseConnection response passes its argument to
    connection_lost."""
    def close_immediately(transport):
        transport.close()

    marker = object()
    self.protocol.connection_made = close_immediately

    self._run_test(self.t, [
        TransportMock.Close(
            response=TransportMock.LoseConnection(marker)),
    ])

    self.assertSequenceEqual(
        self.protocol.mock_calls,
        [unittest.mock.call.connection_lost(marker)],
    )
def test_request_response(self):
    """A write triggered by a stimulus gets its scripted response, and the
    protocol reacts to that response by closing the transport."""
    def data_received(data):
        # TestCase assertion instead of a bare ``assert``: the latter is
        # stripped under ``python -O`` and would silently skip the check
        self.assertIn(data, {b"foo", b"baz"})
        if data == b"foo":
            self.t.write(b"bar")
        elif data == b"baz":
            self.t.close()

    self.protocol.data_received = data_received
    self._run_test(self.t, [
        TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")),
        TransportMock.Close()
    ], stimulus=b"foo")
def test_asynchronous_request_response(self):
    """A write and close scheduled via ``loop.call_soon`` from
    data_received are matched against the expectations."""
    def on_data(data):
        self.assertIn(data, {b"foo", b"baz"})
        if data == b"foo":
            self.loop.call_soon(self.t.write, b"bar")
        elif data == b"baz":
            self.loop.call_soon(self.t.close)

    self.protocol.data_received = on_data

    expected = [
        TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz")),
        TransportMock.Close(),
    ]
    self._run_test(self.t, expected, stimulus=b"foo")
def test_catch_unexpected_close(self):
    """close() where a write is expected fails with "unexpected close"."""
    def close_immediately(transport):
        transport.close()

    self.protocol.connection_made = close_immediately

    expected = [TransportMock.Write(b"foo")]
    with self.assertRaisesRegex(AssertionError, "unexpected close"):
        self._run_test(self.t, expected)
def test_catch_asynchronous_invalid_action(self):
    """A close() scheduled via ``loop.call_soon`` where a write is
    expected raises AssertionError."""
    def schedule_close(transport):
        self.loop.call_soon(transport.close)

    self.protocol.connection_made = schedule_close

    expected = [TransportMock.Write(b"foo")]
    with self.assertRaises(AssertionError):
        self._run_test(self.t, expected)
def test_abort(self):
    """abort() from connection_made matches an Abort expectation."""
    def abort_immediately(transport):
        transport.abort()

    self.protocol.connection_made = abort_immediately
    expected = [TransportMock.Abort()]
    self._run_test(self.t, expected)
def test_write_eof(self):
    """write_eof() from connection_made matches a WriteEof expectation."""
    def send_eof(transport):
        transport.write_eof()

    self.protocol.connection_made = send_eof
    expected = [TransportMock.WriteEof()]
    self._run_test(self.t, expected)
def test_writelines(self):
    """writelines() of several chunks matches one Write of their
    concatenation."""
    def write_chunks(transport):
        transport.writelines([b"foo", b"bar"])

    self.protocol.connection_made = write_chunks
    expected = [TransportMock.Write(b"foobar")]
    self._run_test(self.t, expected)
def test_catch_invalid_write(self):
    """Writing data other than the expected bytes is reported as a
    mismatch."""
    def write_wrong_data(transport):
        transport.write(b"fnord")

    self.protocol.connection_made = write_wrong_data

    expected = [TransportMock.Write(b"foo")]
    with self.assertRaisesRegex(
            AssertionError,
            "mismatch of expected and written data"):
        self._run_test(self.t, expected)
def test_catch_unexpected_abort(self):
    """abort() where write_eof is expected fails with "unexpected
    abort"."""
    def abort_immediately(transport):
        transport.abort()

    self.protocol.connection_made = abort_immediately

    expected = [TransportMock.WriteEof()]
    with self.assertRaisesRegex(AssertionError, "unexpected abort"):
        self._run_test(self.t, expected)
def test_allow_asynchronous_partial_write(self):
    """Several small writes scheduled via ``loop.call_soon`` are combined
    to match a single Write expectation."""
    def schedule_writes(transport):
        for piece in (b"f", b"o", b"o"):
            self.loop.call_soon(transport.write, piece)

    self.protocol.connection_made = schedule_writes
    self._run_test(self.t, [TransportMock.Write(b"foo")])
def test_clear_error_message(self):
    """Two writes that do not match the single expected Write raise
    AssertionError."""
    def write_twice(transport):
        transport.write(b"foo")
        transport.write(b"bar")

    self.protocol.connection_made = write_twice

    expected = [TransportMock.Write(b"baz")]
    with self.assertRaises(AssertionError):
        self._run_test(self.t, expected)
def test_invalid_response(self):
    """A response that is not a valid response object (here ``1``) is
    reported as an incorrect test specification."""
    def write_foo(transport):
        transport.write(b"foo")

    self.protocol.connection_made = write_foo

    expected = [TransportMock.Write(b"foo", response=1)]
    with self.assertRaisesRegex(RuntimeError,
                                "test specification incorrect"):
        self._run_test(self.t, expected)
def test_response_sequence(self):
    """A list of responses is delivered in order: data, then EOF, then
    connection_lost(None)."""
    def write_foo(transport):
        transport.write(b"foo")

    self.protocol.connection_made = write_foo

    responses = [
        TransportMock.Receive(b"foo"),
        TransportMock.ReceiveEof(),
    ]
    self._run_test(self.t, [TransportMock.Write(b"foo", response=responses)])

    self.assertSequenceEqual(
        self.protocol.mock_calls,
        [
            unittest.mock.call.data_received(b"foo"),
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None),
        ],
    )
def test_exception_from_stimulus_bubbles_up(self):
    """An exception raised by data_received while processing a stimulus
    propagates out of run_test unchanged (same object)."""
    error = ConnectionError("foobar")

    def raise_error(data):
        raise error

    self.protocol.data_received = raise_error

    with self.assertRaises(ConnectionError) as caught:
        run_coroutine(self.t.run_test(
            [],
            stimulus=TransportMock.Receive(b"foobar"),
        ))

    self.assertIs(error, caught.exception)
def test_iq_results_are_not_replied_to(self):
    """An IQ of type "result" received by the stanza stream must not
    cause any reply to be written."""
    import aioxmpp.protocol
    import aioxmpp.stream

    version = (1, 0)
    features_future = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features_future,
    )
    t = TransportMock(self, xmlstream)
    stanza_stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

    header_exchange = [
        TransportMock.Write(
            STREAM_HEADER,
            response=[
                TransportMock.Receive(
                    PEER_STREAM_HEADER_TEMPLATE.format(
                        minor=version[1],
                        major=version[0],
                    ).encode("utf-8")
                ),
            ],
        ),
    ]
    run_coroutine(t.run_test(header_exchange, partial=True))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)

    stanza_stream.start(xmlstream)

    iq_result = (
        b'<iq type="result" id="foo">'
        b'<payload xmlns="fnord"/>'
        b'</iq>'
    )
    run_coroutine(t.run_test(
        [],
        stimulus=[TransportMock.Receive(iq_result)],
        partial=True,
    ))

    stanza_stream.flush_incoming()
    run_coroutine(asyncio.sleep(0))

    # no writes are expected: the result must not be answered
    run_coroutine(t.run_test([]))

    stanza_stream.stop()
def test_sm_bootstrap_race(self):
    """The peer's <enabled/> and an <r/> request arrive in the same chunk
    right after stream management is enabled; the stream must then write
    <a h="0"/><r/> and answer the peer's next <r/> with <a h="0"/>."""
    import aioxmpp.protocol
    import aioxmpp.stream

    sm_req = b'<r xmlns="urn:xmpp:sm:3"/>'
    sm_ack0 = b'<a xmlns="urn:xmpp:sm:3" h="0"/>'

    version = (1, 0)
    features_future = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features_future,
    )
    t = TransportMock(self, xmlstream)
    stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
    stream.soft_timeout = timedelta(seconds=0.25)

    # stream header exchange with SM advertised in the features
    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                STREAM_HEADER,
                response=[
                    TransportMock.Receive(
                        PEER_STREAM_HEADER_TEMPLATE.format(
                            minor=version[1],
                            major=version[0],
                        ).encode("utf-8")
                    ),
                    TransportMock.Receive(
                        b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                        b"</stream:features>"
                    ),
                ],
            ),
        ],
        partial=True,
    ))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)
    self.assertTrue(features_future.done())

    stream.start(xmlstream)

    # <enabled/> and <r/> delivered together
    run_coroutine_with_peer(
        stream.start_sm(),
        t.run_test(
            [
                TransportMock.Write(
                    b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                    response=[
                        TransportMock.Receive(
                            b'<enabled xmlns="urn:xmpp:sm:3" '
                            b'resume="true" id="foo"/>'
                            + sm_req
                        ),
                    ],
                ),
            ],
            partial=True,
        ),
    )

    self.assertTrue(stream.sm_enabled)
    self.assertEqual(stream.sm_id, "foo")
    self.assertTrue(stream.sm_resumable)

    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                sm_ack0 + sm_req,
                response=[
                    TransportMock.Receive(sm_ack0),
                    TransportMock.Receive(sm_req),
                ],
            ),
            TransportMock.Write(sm_ack0),
        ],
        partial=True,
    ))
class TestTransportMock(unittest.TestCase):
    """Tests for the scripted ``TransportMock`` helper."""

    def setUp(self):
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def _run_test(self, t, *args, **kwargs):
        # run ``t.run_test(...)`` to completion on this test's loop
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        self._run_test(self.t, [])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.connection_lost(None),
            ])

    def test_stimulus(self):
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.connection_lost(None),
            ])

    def test_request_response(self):
        def data_received(data):
            # TestCase assertion instead of a bare ``assert``, which is
            # stripped under ``python -O`` and would skip the check
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_request_multiresponse(self):
        def data_received(data):
            # see test_request_response: avoid bare ``assert``
            self.assertIn(data, {b"foo", b"bar", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=[
                        TransportMock.Receive(b"bar"),
                        TransportMock.Receive(b"baz")
                    ]),
                TransportMock.Write(b"baric"),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_invalid_write(self):
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                AssertionError,
                "mismatch of expected and written data"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_write(self):
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(self.t, [])

    def test_catch_unexpected_close(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_close(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [])

    def test_allow_asynchronous_partial_write(self):
        def connection_made(transport):
            self.loop.call_soon(transport.write, b"f")
            self.loop.call_soon(transport.write, b"o")
            self.loop.call_soon(transport.write, b"o")

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_asynchronous_request_response(self):
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_response_eof_received(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Close(
                    response=TransportMock.ReceiveEof()
                )
            ])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None)
            ])

    def test_response_lose_connection(self):
        def connection_made(transport):
            transport.close()

        obj = object()
        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Close(
                    response=TransportMock.LoseConnection(obj)
                )
            ])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_lost(obj)
            ])

    def test_writelines(self):
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foobar")])

    def test_can_write_eof(self):
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Abort()])

    def test_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.WriteEof()])

    def test_catch_surplus_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [])

    def test_catch_unexpected_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [TransportMock.Abort()])

    def test_catch_surplus_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [])

    def test_catch_unexpected_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [TransportMock.WriteEof()])

    def test_invalid_response(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                RuntimeError,
                "test specification incorrect"):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(b"foo", response=1)
                ])

    def test_response_sequence(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=[
                        TransportMock.Receive(b"foo"),
                        TransportMock.ReceiveEof()
                    ])
            ])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ])

    def test_clear_error_message(self):
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"baz")])

    def test_detached_response(self):
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # the response must not be delivered synchronously
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=TransportMock.Receive(b"bar")
                )
            ])

    def test_no_response_conflict(self):
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=TransportMock.Receive(b"baz"),
                ),
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")
                )
            ])

    def test_partial(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo"),
            ],
            partial=True
        )
        self.t.write_eof()
        self.t.close()
        self._run_test(
            self.t,
            [
                TransportMock.WriteEof(),
                TransportMock.Close()
            ]
        )

    def test_no_starttls_by_default(self):
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            run_coroutine(self.t.starttls())

    def test_starttls(self):
        self.t = TransportMock(self, self.protocol, with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        # the callback must be awaitable for a native-coroutine starttls;
        # CoroutineMock still records the call for the assertion below
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        # ``@asyncio.coroutine`` was removed in Python 3.11
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context, post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [
                    TransportMock.STARTTLS(ssl_context,
                                           post_handshake_callback)
                ]
            )
        )

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        self.t = TransportMock(self, self.protocol, with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        # ``@asyncio.coroutine`` was removed in Python 3.11
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [
                    TransportMock.STARTTLS(ssl_context, None)
                ]
            )
        )

    def test_exception_from_stimulus_bubbles_up(self):
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            run_coroutine(
                self.t.run_test(
                    [],
                    stimulus=TransportMock.Receive(b"foobar")
                )
            )

        self.assertIs(exc, ctx.exception)

    def tearDown(self):
        del self.t
        del self.loop
        del self.protocol
def setUp(self):
    """Create a protocol mock and a TransportMock driven by the event
    loop returned from ``asyncio.get_event_loop()``."""
    protocol = make_protocol_mock()
    loop = asyncio.get_event_loop()
    self.protocol = protocol
    self.loop = loop
    self.t = TransportMock(self, protocol, loop=loop)
class TestTransportMock(unittest.TestCase):
    """Tests for the scripted ``TransportMock`` helper."""

    def setUp(self):
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def _run_test(self, t, *args, **kwargs):
        # run ``t.run_test(...)`` to completion on this test's loop
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        self._run_test(self.t, [])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.connection_lost(None),
            ])

    def test_stimulus(self):
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.connection_lost(None),
            ])

    def test_request_response(self):
        def data_received(data):
            # TestCase assertion instead of a bare ``assert``, which is
            # stripped under ``python -O`` and would skip the check
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_request_multiresponse(self):
        def data_received(data):
            # see test_request_response: avoid bare ``assert``
            self.assertIn(data, {b"foo", b"bar", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=[
                        TransportMock.Receive(b"bar"),
                        TransportMock.Receive(b"baz")
                    ]),
                TransportMock.Write(b"baric"),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_invalid_write(self):
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                AssertionError,
                "mismatch of expected and written data"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_write(self):
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(self.t, [])

    def test_catch_unexpected_close(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_close(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [])

    def test_allow_asynchronous_partial_write(self):
        def connection_made(transport):
            self.loop.call_soon(transport.write, b"f")
            self.loop.call_soon(transport.write, b"o")
            self.loop.call_soon(transport.write, b"o")

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_asynchronous_request_response(self):
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_response_eof_received(self):
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.ReceiveEof())])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None)
            ])

    def test_response_lose_connection(self):
        def connection_made(transport):
            transport.close()

        obj = object()
        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.LoseConnection(obj))])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [unittest.mock.call.connection_lost(obj)])

    def test_writelines(self):
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foobar")])

    def test_can_write_eof(self):
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Abort()])

    def test_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.WriteEof()])

    def test_catch_surplus_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [])

    def test_catch_unexpected_write_eof(self):
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [TransportMock.Abort()])

    def test_catch_surplus_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [])

    def test_catch_unexpected_abort(self):
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [TransportMock.WriteEof()])

    def test_invalid_response(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(RuntimeError,
                                    "test specification incorrect"):
            self._run_test(self.t, [TransportMock.Write(b"foo", response=1)])

    def test_response_sequence(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [
            TransportMock.Write(b"foo", response=[
                TransportMock.Receive(b"foo"),
                TransportMock.ReceiveEof()
            ])
        ])
        self.assertSequenceEqual(self.protocol.mock_calls, [
            unittest.mock.call.data_received(b"foo"),
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None),
        ])

    def test_clear_error_message(self):
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"baz")])

    def test_detached_response(self):
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # the response must not be delivered synchronously
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(b"foo",
                                response=TransportMock.Receive(b"bar"))
        ])

    def test_no_response_conflict(self):
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(
                b"foo",
                response=TransportMock.Receive(b"baz"),
            ),
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz"))
        ])

    def test_partial(self):
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [
            TransportMock.Write(b"foo"),
        ], partial=True)
        self.t.write_eof()
        self.t.close()
        self._run_test(self.t,
                       [TransportMock.WriteEof(), TransportMock.Close()])

    def test_no_starttls_by_default(self):
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            run_coroutine(self.t.starttls())

    def test_starttls(self):
        self.t = TransportMock(self, self.protocol, with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context, post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [TransportMock.STARTTLS(ssl_context,
                                        post_handshake_callback)]))

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        self.t = TransportMock(self, self.protocol, with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test([TransportMock.STARTTLS(ssl_context, None)]))

    def test_exception_from_stimulus_bubbles_up(self):
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            run_coroutine(
                self.t.run_test([],
                                stimulus=TransportMock.Receive(b"foobar")))

        self.assertIs(exc, ctx.exception)

    def tearDown(self):
        del self.t
        del self.loop
        del self.protocol
def test_sm_works_correctly_with_invalid_payload(self):
    """With stream management enabled, an IQ "get" carrying an unknown
    payload is answered with a feature-not-implemented error, and the
    following <r/>/<a/> exchange counts h="1"."""
    import aioxmpp.protocol
    import aioxmpp.stream

    sm_req = b'<r xmlns="urn:xmpp:sm:3"/>'
    sm_ack0 = b'<a xmlns="urn:xmpp:sm:3" h="0"/>'
    sm_ack1 = b'<a xmlns="urn:xmpp:sm:3" h="1"/>'

    version = (1, 0)
    features_future = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features_future,
    )
    t = TransportMock(self, xmlstream)
    stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

    # stream header exchange with SM advertised in the features
    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                STREAM_HEADER,
                response=[
                    TransportMock.Receive(
                        PEER_STREAM_HEADER_TEMPLATE.format(
                            minor=version[1],
                            major=version[0],
                        ).encode("utf-8")
                    ),
                    TransportMock.Receive(
                        b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                        b"</stream:features>"
                    ),
                ],
            ),
        ],
        partial=True,
    ))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)
    self.assertTrue(features_future.done())

    stream.ping_interval = timedelta(seconds=0.25)
    stream.ping_opportunistic_interval = timedelta(seconds=0.25)
    stream.start(xmlstream)

    run_coroutine_with_peer(
        stream.start_sm(),
        t.run_test(
            [
                TransportMock.Write(
                    b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                    response=[
                        TransportMock.Receive(
                            b'<enabled xmlns="urn:xmpp:sm:3" '
                            b'resume="true" id="foo"/>'
                        ),
                    ],
                ),
            ],
            partial=True,
        ),
    )

    self.assertTrue(stream.sm_enabled)
    self.assertEqual(stream.sm_id, "foo")
    self.assertTrue(stream.sm_resumable)

    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                sm_req,
                response=[
                    TransportMock.Receive(sm_ack0),
                    TransportMock.Receive(sm_req),
                ],
            ),
            TransportMock.Write(sm_ack0),
        ],
        partial=True,
    ))

    error_reply = (
        b'<iq id="foo" type="error"><error type="cancel">'
        b'<feature-not-implemented'
        b' xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"/>'
        b'</error></iq>'
    )
    run_coroutine(t.run_test(
        [
            TransportMock.Write(error_reply),
            TransportMock.Write(
                sm_req,
                response=[
                    TransportMock.Receive(sm_ack1),
                    TransportMock.Receive(sm_req),
                ],
            ),
            TransportMock.Write(sm_ack1),
        ],
        stimulus=[
            TransportMock.Receive(
                b'<iq type="get" id="foo">'
                b'<payload xmlns="fnord"/>'
                b'</iq>'
            ),
        ],
        partial=True,
    ))
def test_hard_deadtime_kills_stream(self):
    """With soft_timeout and round_trip_time at 0.1 s, an unanswered ping
    leads to the transport being aborted, the stream stopping, and
    on_failure resolving with a ConnectionError mentioning "timeout"."""
    import aioxmpp.protocol
    import aioxmpp.stream

    version = (1, 0)
    features_future = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features_future,
    )
    t = TransportMock(self, xmlstream)
    stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
    stream.soft_timeout = timedelta(seconds=0.1)
    stream.round_trip_time = timedelta(seconds=0.1)
    failure = stream.on_failure.future()

    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                STREAM_HEADER,
                response=[
                    TransportMock.Receive(
                        PEER_STREAM_HEADER_TEMPLATE.format(
                            minor=version[1],
                            major=version[0],
                        ).encode("utf-8")
                    ),
                ],
            ),
        ],
        partial=True,
    ))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)

    stream.start(xmlstream)

    real_iq_cls = aioxmpp.IQ

    def iq_with_fixed_id(*args, **kwargs):
        # force a predictable id so the expected ping bytes can be spelled out
        iq = real_iq_cls(*args, **kwargs)
        iq.id_ = "ping"
        return iq

    with unittest.mock.patch("aioxmpp.stanza.IQ") as patched_iq:
        patched_iq.side_effect = iq_with_fixed_id
        run_coroutine(t.run_test(
            [
                TransportMock.Write(
                    b'<iq id="ping" type="get">'
                    b'<ping xmlns="urn:xmpp:ping"/></iq>'
                ),
            ],
            partial=True,
        ))

    run_coroutine(t.run_test([TransportMock.Abort()]))
    run_coroutine(asyncio.sleep(0))

    self.assertFalse(stream.running)
    self.assertTrue(failure.done())
    error = failure.exception()
    self.assertIsInstance(error, ConnectionError)
    self.assertIn("timeout", str(error))
def test_sm_works_correctly_with_invalid_payload(self):
    """With stream management enabled and soft_timeout at 0.25 s, an IQ
    "get" carrying an unknown payload is answered with a
    service-unavailable error, and the following <r/>/<a/> exchange
    counts h="1"."""
    import aioxmpp.protocol
    import aioxmpp.stream

    sm_req = b'<r xmlns="urn:xmpp:sm:3"/>'
    sm_ack0 = b'<a xmlns="urn:xmpp:sm:3" h="0"/>'
    sm_ack1 = b'<a xmlns="urn:xmpp:sm:3" h="1"/>'

    version = (1, 0)
    features_future = asyncio.Future()
    xmlstream = aioxmpp.protocol.XMLStream(
        to=TEST_PEER,
        sorted_attributes=True,
        features_future=features_future,
    )
    t = TransportMock(self, xmlstream)
    stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
    stream.soft_timeout = timedelta(seconds=0.25)

    # stream header exchange with SM advertised in the features
    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                STREAM_HEADER,
                response=[
                    TransportMock.Receive(
                        PEER_STREAM_HEADER_TEMPLATE.format(
                            minor=version[1],
                            major=version[0],
                        ).encode("utf-8")
                    ),
                    TransportMock.Receive(
                        b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                        b"</stream:features>"
                    ),
                ],
            ),
        ],
        partial=True,
    ))
    self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)
    self.assertTrue(features_future.done())

    stream.start(xmlstream)

    run_coroutine_with_peer(
        stream.start_sm(),
        t.run_test(
            [
                TransportMock.Write(
                    b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                    response=[
                        TransportMock.Receive(
                            b'<enabled xmlns="urn:xmpp:sm:3" '
                            b'resume="true" id="foo"/>'
                        ),
                    ],
                ),
            ],
            partial=True,
        ),
    )

    self.assertTrue(stream.sm_enabled)
    self.assertEqual(stream.sm_id, "foo")
    self.assertTrue(stream.sm_resumable)

    run_coroutine(t.run_test(
        [
            TransportMock.Write(
                sm_req,
                response=[
                    TransportMock.Receive(sm_ack0),
                    TransportMock.Receive(sm_req),
                ],
            ),
            TransportMock.Write(sm_ack0),
        ],
        partial=True,
    ))

    error_reply = (
        b'<iq id="foo" type="error"><error type="cancel">'
        b'<service-unavailable'
        b' xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"/>'
        b'</error></iq>'
    )
    run_coroutine(t.run_test(
        [
            TransportMock.Write(error_reply),
            TransportMock.Write(
                sm_req,
                response=[
                    TransportMock.Receive(sm_ack1),
                    TransportMock.Receive(sm_req),
                ],
            ),
            TransportMock.Write(sm_ack1),
        ],
        stimulus=[
            TransportMock.Receive(
                b'<iq type="get" id="foo">'
                b'<payload xmlns="fnord"/>'
                b'</iq>'
            ),
        ],
        partial=True,
    ))