class TestSASLXMPPInterface(xmltestutils.XMLTestCase):
    """Tests for ``sasl.SASLXMPPInterface`` driven by an ``XMLStreamMock``.

    In this variant, exchanges started by ``initiate``/``respond`` are
    expected to be bracketed by ``Mute``/``Unmute`` actions on the mock
    stream; ``abort`` exchanges are not.

    NOTE(review): another class named ``TestSASLXMPPInterface`` appears in
    the same source; if both live in one module, only the later definition
    is collected by unittest -- confirm which variant is intended.
    """

    def setUp(self):
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)
        self.sm = sasl.SASLXMPPInterface(self.xmlstream)

    def _run_test(self, coro, actions=None, stimulus=None):
        """Run *coro* concurrently with the mock stream's scripted test.

        :param coro: Coroutine under test.
        :param actions: Sequence of expected ``XMLStreamMock`` actions.
            Defaults to a fresh empty list on every call (a ``[]`` default
            would be one list object shared across all calls).
        :param stimulus: Optional action injected into the mock stream.
        :return: The result of ``run_coroutine_with_peer``; presumably the
            result of *coro*, since callers unpack it as ``(state, payload)``.
        """
        if actions is None:
            actions = []
        return run_coroutine_with_peer(
            coro,
            self.xmlstream.run_test(actions, stimulus=stimulus),
            loop=self.loop,
        )

    def test_setup(self):
        self.assertIsNone(self.sm.timeout)
        self.assertIs(self.xmlstream, self.sm.xmlstream)

    def test_initiate_success(self):
        state, payload = self._run_test(
            self.sm.initiate("foo", b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLSuccess()
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )

        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_initiate_failure(self):
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.initiate("foo", b"bar"),
                [
                    XMLStreamMock.Mute(),
                    XMLStreamMock.Send(
                        nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl, "not-authorized")
                            )
                        )
                    ),
                    XMLStreamMock.Unmute(),
                ]
            )

        self.assertEqual(
            "not-authorized",
            ctx.exception.opaque_error
        )

    def test_initiate_challenge(self):
        state, payload = self._run_test(
            self.sm.initiate("foo", b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLChallenge(payload=b"baz")
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )

        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_response_success(self):
        # respond() is only valid mid-exchange; force the state directly.
        self.sm._state = "challenge"
        state, payload = self._run_test(
            self.sm.respond(b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLSuccess()
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )

        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_response_failure(self):
        self.sm._state = "challenge"
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.respond(b"bar"),
                [
                    XMLStreamMock.Mute(),
                    XMLStreamMock.Send(
                        nonza.SASLResponse(payload=b"bar"),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl,
                                           "credentials-expired")
                            )
                        )
                    ),
                    XMLStreamMock.Unmute(),
                ]
            )

        self.assertEqual(
            "credentials-expired",
            ctx.exception.opaque_error
        )

    def test_response_challenge(self):
        self.sm._state = "challenge"
        state, payload = self._run_test(
            self.sm.respond(b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLChallenge(payload=b"baz")
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )

        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_abort_reject_non_failure(self):
        self.sm._state = "challenge"
        with self.assertRaisesRegex(
                aiosasl.SASLFailure,
                "unexpected non-failure"
        ) as ctx:
            self._run_test(
                self.sm.abort(),
                [
                    XMLStreamMock.Send(
                        nonza.SASLAbort(),
                        response=XMLStreamMock.Receive(
                            nonza.SASLSuccess()
                        )
                    )
                ]
            )

        self.assertEqual(
            "aborted",
            ctx.exception.opaque_error
        )

    def test_abort_return_on_aborted_error(self):
        self.sm._state = "challenge"
        state, payload = self._run_test(
            self.sm.abort(),
            [
                XMLStreamMock.Send(
                    nonza.SASLAbort(),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(
                            condition=(namespaces.sasl, "aborted")
                        )
                    )
                )
            ]
        )

        self.assertEqual(state, "failure")
        self.assertIsNone(payload)

    def test_abort_re_raise_other_errors(self):
        self.sm._state = "challenge"
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.abort(),
                [
                    XMLStreamMock.Send(
                        nonza.SASLAbort(),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl,
                                           "mechanism-too-weak")
                            )
                        )
                    )
                ]
            )

        self.assertEqual(
            "mechanism-too-weak",
            ctx.exception.opaque_error
        )

    def tearDown(self):
        # self.sm keeps a reference to the mock stream (see test_setup), so
        # it must be dropped too for the stream to actually be released.
        del self.sm
        del self.xmlstream
        del self.loop
class TestSASLXMPPInterface(xmltestutils.XMLTestCase):
    """Tests for ``sasl.SASLXMPPInterface`` driven by an ``XMLStreamMock``.

    In this variant, SASL exchanges are expected as plain ``Send`` actions
    with a scripted ``Receive`` response (no ``Mute``/``Unmute``).

    NOTE(review): another class named ``TestSASLXMPPInterface`` appears in
    the same source; if both live in one module, only the later definition
    is collected by unittest -- confirm which variant is intended.
    """

    def setUp(self):
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)
        self.sm = sasl.SASLXMPPInterface(self.xmlstream)

    def _run_test(self, coro, actions=None, stimulus=None):
        """Run *coro* concurrently with the mock stream's scripted test.

        :param coro: Coroutine under test.
        :param actions: Sequence of expected ``XMLStreamMock`` actions.
            Defaults to a fresh empty list on every call (a ``[]`` default
            would be one list object shared across all calls).
        :param stimulus: Optional action injected into the mock stream.
        :return: The result of ``run_coroutine_with_peer``; presumably the
            result of *coro*, since callers unpack it as ``(state, payload)``.
        """
        if actions is None:
            actions = []
        return run_coroutine_with_peer(
            coro,
            self.xmlstream.run_test(actions, stimulus=stimulus),
            loop=self.loop,
        )

    def test_setup(self):
        self.assertIsNone(self.sm.timeout)
        self.assertIs(self.xmlstream, self.sm.xmlstream)

    def test_initiate_success(self):
        state, payload = self._run_test(self.sm.initiate("foo", b"bar"), [
            XMLStreamMock.Send(
                nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                response=XMLStreamMock.Receive(nonza.SASLSuccess())),
        ])

        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_initiate_failure(self):
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.initiate("foo", b"bar"), [
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(
                            condition=(namespaces.sasl, "not-authorized")))),
            ])

        self.assertEqual("not-authorized", ctx.exception.opaque_error)

    def test_initiate_challenge(self):
        state, payload = self._run_test(self.sm.initiate("foo", b"bar"), [
            XMLStreamMock.Send(
                nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                response=XMLStreamMock.Receive(
                    nonza.SASLChallenge(payload=b"baz"))),
        ])

        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_response_success(self):
        # respond() is only valid mid-exchange; force the state directly.
        self.sm._state = "challenge"
        state, payload = self._run_test(self.sm.respond(b"bar"), [
            XMLStreamMock.Send(
                nonza.SASLResponse(payload=b"bar"),
                response=XMLStreamMock.Receive(nonza.SASLSuccess())),
        ])

        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_response_failure(self):
        self.sm._state = "challenge"
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.respond(b"bar"), [
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(
                            condition=(namespaces.sasl,
                                       "credentials-expired")))),
            ])

        self.assertEqual("credentials-expired", ctx.exception.opaque_error)

    def test_response_challenge(self):
        self.sm._state = "challenge"
        state, payload = self._run_test(self.sm.respond(b"bar"), [
            XMLStreamMock.Send(
                nonza.SASLResponse(payload=b"bar"),
                response=XMLStreamMock.Receive(
                    nonza.SASLChallenge(payload=b"baz"))),
        ])

        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_abort_reject_non_failure(self):
        self.sm._state = "challenge"
        with self.assertRaisesRegex(aiosasl.SASLFailure,
                                    "unexpected non-failure") as ctx:
            self._run_test(self.sm.abort(), [
                XMLStreamMock.Send(
                    nonza.SASLAbort(),
                    response=XMLStreamMock.Receive(nonza.SASLSuccess())),
            ])

        self.assertEqual("aborted", ctx.exception.opaque_error)

    def test_abort_return_on_aborted_error(self):
        self.sm._state = "challenge"
        state, payload = self._run_test(self.sm.abort(), [
            XMLStreamMock.Send(
                nonza.SASLAbort(),
                response=XMLStreamMock.Receive(
                    nonza.SASLFailure(
                        condition=(namespaces.sasl, "aborted")))),
        ])

        self.assertEqual(state, "failure")
        self.assertIsNone(payload)

    def test_abort_re_raise_other_errors(self):
        self.sm._state = "challenge"
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.abort(), [
                XMLStreamMock.Send(
                    nonza.SASLAbort(),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(
                            condition=(namespaces.sasl,
                                       "mechanism-too-weak")))),
            ])

        self.assertEqual("mechanism-too-weak", ctx.exception.opaque_error)

    def tearDown(self):
        # self.sm keeps a reference to the mock stream (see test_setup), so
        # it must be dropped too for the stream to actually be released.
        del self.sm
        del self.xmlstream
        del self.loop
class TestXMLStreamMock(XMLTestCase):
    """Self-tests for the ``XMLStreamMock`` test helper.

    Each test scripts the actions the mock should observe (via
    ``run_test``) and checks that matching actions pass while surplus or
    mismatched ones fail with an ``AssertionError``.

    NOTE(review): another class named ``TestXMLStreamMock`` appears in the
    same source; if both live in one module, only the later definition is
    collected by unittest -- confirm which variant is intended.
    """

    def setUp(self):
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        self.Cls = Cls
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)

    def test_register_stanza_handler(self):
        received = []

        def handler(obj):
            # No ``nonlocal`` needed: the list is mutated, never rebound.
            received.append(obj)

        obj = self.Cls()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Receive(obj)))

        self.assertSequenceEqual([obj], received)

    def test_send_xso(self):
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.send_xso(obj)

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([
                XMLStreamMock.Send(obj),
            ], stimulus=XMLStreamMock.Receive(obj)))

    def test_catch_missing_stanza_handler(self):
        obj = self.Cls()

        with self.assertRaisesRegex(AssertionError, "no handler registered"):
            run_coroutine(
                self.xmlstream.run_test(
                    [], stimulus=XMLStreamMock.Receive(obj)))

    def test_receive_stream_features_into_future(self):
        fut = self.xmlstream.features_future()
        obj = nonza.StreamFeatures()

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Receive(obj)))

        self.assertTrue(fut.done())
        self.assertIs(fut.result(), obj)

    def test_no_termination_on_missing_action(self):
        obj = self.Cls()

        # An expected Send that never happens must make run_test hang,
        # which surfaces here as a timeout.
        with self.assertRaises(asyncio.TimeoutError):
            run_coroutine(self.xmlstream.run_test([
                XMLStreamMock.Send(obj),
            ], ), timeout=0.05)

    def test_catch_surplus_send(self):
        self.xmlstream.send_xso(self.Cls())

        # Dots are escaped so they match the literal separators in the
        # object's repr rather than any character.
        with self.assertRaisesRegex(
                AssertionError,
                r"unexpected send_xso\(<tests\.test_testutils\.TestXMLStreamMock"
                r"\.setUp\.<locals>\.Cls object at 0x[a-fA-F0-9]+>\)"):
            run_coroutine(self.xmlstream.run_test([], ))

    def test_reset(self):
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.reset()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([
                XMLStreamMock.Reset(),
            ], stimulus=XMLStreamMock.Receive(obj)))

    def test_catch_surplus_reset(self):
        self.xmlstream.reset()

        with self.assertRaisesRegex(AssertionError, "unexpected reset"):
            run_coroutine(self.xmlstream.run_test([], ))

    def test_close(self):
        closing_handler = unittest.mock.Mock()
        fut = self.xmlstream.error_future()
        obj = self.Cls()

        self.xmlstream.on_closing.connect(closing_handler)

        def handler(obj):
            self.xmlstream.close()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([
                XMLStreamMock.Close(),
            ], stimulus=XMLStreamMock.Receive(obj)))

        self.assertSequenceEqual([
            unittest.mock.call(None),
        ], closing_handler.mock_calls)
        self.assertTrue(fut.done())
        self.assertIsInstance(fut.exception(), ConnectionError)

    def test_catch_surplus_close(self):
        self.xmlstream.close()

        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            run_coroutine(self.xmlstream.run_test([], ))

    def test_mute_unmute_cycle(self):
        with self.xmlstream.mute():
            run_coroutine(self.xmlstream.run_test([
                XMLStreamMock.Mute(),
            ], ))
        # Leaving the mute() context records the matching Unmute.
        run_coroutine(self.xmlstream.run_test([
            XMLStreamMock.Unmute(),
        ], ))

    def test_catch_surplus_mute(self):
        with self.xmlstream.mute():
            with self.assertRaisesRegex(AssertionError, "unexpected mute"):
                run_coroutine(self.xmlstream.run_test([], ))

    def test_catch_surplus_unmute(self):
        with self.xmlstream.mute():
            pass

        with self.assertRaisesRegex(AssertionError, "unexpected unmute"):
            run_coroutine(self.xmlstream.run_test([
                XMLStreamMock.Mute(),
            ], ))

    def test_starttls(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test([
                    XMLStreamMock.STARTTLS(ssl_context,
                                           post_handshake_callback)
                ], )))

        post_handshake_callback.assert_called_once_with(
            self.xmlstream.transport)

    def test_starttls_without_callback(self):
        ssl_context = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, None),
                self.xmlstream.run_test(
                    [XMLStreamMock.STARTTLS(ssl_context, None)], )))

    def test_starttls_reject_incorrect_arguments(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        # Wrong SSL context.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(object(), post_handshake_callback),
                    self.xmlstream.run_test([
                        XMLStreamMock.STARTTLS(ssl_context,
                                               post_handshake_callback)
                    ], )))

        # Wrong post-handshake callback.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(ssl_context, object()),
                    self.xmlstream.run_test([
                        XMLStreamMock.STARTTLS(ssl_context,
                                               post_handshake_callback)
                    ], )))

    def test_starttls_propagates_exception_from_callback(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        exc = ValueError()
        post_handshake_callback.side_effect = exc

        caught_exception, other_result = run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test([
                    XMLStreamMock.STARTTLS(ssl_context,
                                           post_handshake_callback)
                ], ),
                return_exceptions=True))

        self.assertIs(caught_exception, exc)
        self.assertIs(other_result, None)

    def test_fail(self):
        exc = ValueError()
        fun = unittest.mock.MagicMock()
        fun.return_value = None

        ec_future = asyncio.ensure_future(self.xmlstream.error_future())
        features_future = self.xmlstream.features_future()

        self.xmlstream.on_closing.connect(fun)

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Fail(exc=exc)))

        self.assertTrue(ec_future.done())
        self.assertIs(exc, ec_future.exception())
        self.assertTrue(features_future.done())
        self.assertIs(exc, features_future.exception())

        fun.assert_called_once_with(exc)

        # After failure, every stream operation re-raises the failure.
        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.reset()
        self.assertIs(exc, ctx.exception)

        with self.assertRaises(ValueError) as ctx:
            run_coroutine(self.xmlstream.starttls(object()))
        self.assertIs(exc, ctx.exception)

        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.send_xso(object())
        self.assertIs(exc, ctx.exception)

        with self.assertRaisesRegex(TypeError, "clear_exception"):
            run_coroutine(self.xmlstream.run_test([], clear_exception=True))

    def test_close_and_wait(self):
        task = asyncio.ensure_future(self.xmlstream.close_and_wait())

        run_coroutine(self.xmlstream.run_test([
            XMLStreamMock.Close(),
        ]))

        self.assertTrue(task.done())

    def test_abort(self):
        fut = self.xmlstream.error_future()
        ffut = self.xmlstream.features_future()
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.abort()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([
                XMLStreamMock.Abort(),
            ], stimulus=XMLStreamMock.Receive(obj)))

        self.assertTrue(fut.done())
        self.assertIsInstance(fut.exception(), ConnectionError)
        self.assertTrue(ffut.done())
        self.assertIsInstance(ffut.exception(), ConnectionError)

    def test_catch_surplus_abort(self):
        self.xmlstream.abort()

        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            run_coroutine(self.xmlstream.run_test([], ))

    def tearDown(self):
        del self.xmlstream
        del self.loop
class TestXMLStreamMock(XMLTestCase):
    """Self-tests for the ``XMLStreamMock`` test helper.

    Each test scripts the actions the mock should observe (via
    ``run_test``) and checks that matching actions pass while surplus or
    mismatched ones fail with an ``AssertionError``.

    NOTE(review): another class named ``TestXMLStreamMock`` appears in the
    same source; if both live in one module, only the later definition is
    collected by unittest -- confirm which variant is intended.
    """

    def setUp(self):
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        self.Cls = Cls
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)

    def test_register_stanza_handler(self):
        received = []

        def handler(obj):
            # No ``nonlocal`` needed: the list is mutated, never rebound.
            received.append(obj)

        obj = self.Cls()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertSequenceEqual(
            [
                obj
            ],
            received
        )

    def test_send_xso(self):
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.send_xso(obj)

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Send(obj),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

    def test_catch_missing_stanza_handler(self):
        obj = self.Cls()

        with self.assertRaisesRegex(AssertionError, "no handler registered"):
            run_coroutine(self.xmlstream.run_test(
                [],
                stimulus=XMLStreamMock.Receive(obj)
            ))

    def test_no_termination_on_missing_action(self):
        obj = self.Cls()

        # An expected Send that never happens must make run_test hang,
        # which surfaces here as a timeout.
        with self.assertRaises(asyncio.TimeoutError):
            run_coroutine(
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.Send(obj),
                    ],
                ),
                timeout=0.05)

    def test_catch_surplus_send(self):
        self.xmlstream.send_xso(self.Cls())

        # Dots are escaped so they match the literal separators in the
        # object's repr rather than any character.
        with self.assertRaisesRegex(
                AssertionError,
                r"unexpected send_xso\(<tests\.test_testutils\.TestXMLStreamMock"
                r"\.setUp\.<locals>\.Cls object at 0x[a-fA-F0-9]+>\)"):
            run_coroutine(self.xmlstream.run_test(
                [],
            ))

    def test_reset(self):
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.reset()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Reset(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

    def test_catch_surplus_reset(self):
        self.xmlstream.reset()

        with self.assertRaisesRegex(AssertionError, "unexpected reset"):
            run_coroutine(self.xmlstream.run_test(
                [],
            ))

    def test_close(self):
        closing_handler = unittest.mock.Mock()
        fut = self.xmlstream.error_future()
        obj = self.Cls()

        self.xmlstream.on_closing.connect(closing_handler)

        def handler(obj):
            self.xmlstream.close()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Close(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertSequenceEqual(
            [
                unittest.mock.call(None),
            ],
            closing_handler.mock_calls
        )
        self.assertTrue(fut.done())
        self.assertIsInstance(
            fut.exception(),
            ConnectionError
        )

    def test_catch_surplus_close(self):
        self.xmlstream.close()

        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            run_coroutine(self.xmlstream.run_test(
                [],
            ))

    def test_mute_unmute_cycle(self):
        with self.xmlstream.mute():
            run_coroutine(self.xmlstream.run_test(
                [
                    XMLStreamMock.Mute(),
                ],
            ))
        # Leaving the mute() context records the matching Unmute.
        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Unmute(),
            ],
        ))

    def test_catch_surplus_mute(self):
        with self.xmlstream.mute():
            with self.assertRaisesRegex(AssertionError, "unexpected mute"):
                run_coroutine(self.xmlstream.run_test(
                    [],
                ))

    def test_catch_surplus_unmute(self):
        with self.xmlstream.mute():
            pass

        with self.assertRaisesRegex(AssertionError, "unexpected unmute"):
            run_coroutine(self.xmlstream.run_test(
                [
                    XMLStreamMock.Mute(),
                ],
            ))

    def test_starttls(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(
                            ssl_context,
                            post_handshake_callback)
                    ],
                )
            )
        )

        post_handshake_callback.assert_called_once_with(
            self.xmlstream.transport)

    def test_starttls_without_callback(self):
        ssl_context = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, None),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(ssl_context, None)
                    ],
                )
            )
        )

    def test_starttls_reject_incorrect_arguments(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        # Wrong SSL context.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(object(), post_handshake_callback),
                    self.xmlstream.run_test(
                        [
                            XMLStreamMock.STARTTLS(
                                ssl_context,
                                post_handshake_callback)
                        ],
                    )
                )
            )

        # Wrong post-handshake callback.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(ssl_context, object()),
                    self.xmlstream.run_test(
                        [
                            XMLStreamMock.STARTTLS(
                                ssl_context,
                                post_handshake_callback)
                        ],
                    )
                )
            )

    def test_starttls_propagates_exception_from_callback(self):
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        exc = ValueError()
        post_handshake_callback.side_effect = exc

        caught_exception, other_result = run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(
                            ssl_context,
                            post_handshake_callback)
                    ],
                ),
                return_exceptions=True
            )
        )

        self.assertIs(caught_exception, exc)
        self.assertIs(other_result, None)

    def test_fail(self):
        exc = ValueError()
        fun = unittest.mock.MagicMock()
        fun.return_value = None

        ec_future = asyncio.ensure_future(self.xmlstream.error_future())

        self.xmlstream.on_closing.connect(fun)

        run_coroutine(self.xmlstream.run_test(
            [],
            stimulus=XMLStreamMock.Fail(exc=exc)
        ))

        self.assertTrue(ec_future.done())
        self.assertIs(exc, ec_future.exception())

        fun.assert_called_once_with(exc)

        # After failure, every stream operation re-raises the failure.
        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.reset()
        self.assertIs(exc, ctx.exception)

        with self.assertRaises(ValueError) as ctx:
            run_coroutine(self.xmlstream.starttls(object()))
        self.assertIs(exc, ctx.exception)

        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.send_xso(object())
        self.assertIs(exc, ctx.exception)

        with self.assertRaisesRegex(TypeError, "clear_exception"):
            run_coroutine(self.xmlstream.run_test(
                [],
                clear_exception=True
            ))

    def test_close_and_wait(self):
        task = asyncio.ensure_future(self.xmlstream.close_and_wait())

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Close(),
            ]
        ))

        self.assertTrue(task.done())

    def test_abort(self):
        fut = self.xmlstream.error_future()
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.abort()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Abort(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertTrue(fut.done())
        self.assertIsInstance(
            fut.exception(),
            ConnectionError
        )

    def test_catch_surplus_abort(self):
        self.xmlstream.abort()

        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            run_coroutine(self.xmlstream.run_test(
                [],
            ))

    def tearDown(self):
        del self.xmlstream
        del self.loop