Exemple #1
0
    def test_close(self):
        """Closing from a stanza handler fires on_closing(None) and fails the
        error future with a ConnectionError."""
        on_closing = unittest.mock.Mock()
        error_fut = self.xmlstream.error_future()

        stanza = self.Cls()

        self.xmlstream.on_closing.connect(on_closing)

        def close_on_receive(received):
            # Invoked for the Receive stimulus below.
            self.xmlstream.close()

        self.xmlstream.stanza_parser.add_class(self.Cls, close_on_receive)

        expected_actions = [XMLStreamMock.Close()]
        run_coroutine(self.xmlstream.run_test(
            expected_actions,
            stimulus=XMLStreamMock.Receive(stanza),
        ))

        self.assertSequenceEqual(
            [unittest.mock.call(None)],
            on_closing.mock_calls,
        )

        self.assertTrue(error_fut.done())
        self.assertIsInstance(error_fut.exception(), ConnectionError)
Exemple #2
0
    def test_starttls_reject_incorrect_arguments(self):
        """STARTTLS with a wrong ssl context or a wrong callback is rejected
        by the mock with an AssertionError."""
        expected_ctx = unittest.mock.MagicMock()
        expected_cb = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        # First the ssl context mismatches, then the callback does.
        mismatches = (
            (object(), expected_cb),
            (expected_ctx, object()),
        )
        for ctx_arg, cb_arg in mismatches:
            with self.assertRaisesRegex(AssertionError,
                                        "mismatched starttls argument"):
                run_coroutine(
                    asyncio.gather(
                        self.xmlstream.starttls(ctx_arg, cb_arg),
                        self.xmlstream.run_test([
                            XMLStreamMock.STARTTLS(expected_ctx, expected_cb),
                        ]),
                    )
                )
Exemple #3
0
    def setUp(self):
        # Minimal XSO type for the tests. It stays nested here under the name
        # ``Cls`` because its qualified name (setUp.<locals>.Cls) appears in
        # object reprs that assertions may match against.
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        loop = asyncio.get_event_loop()
        self.Cls = Cls
        self.loop = loop
        self.xmlstream = XMLStreamMock(self, loop=loop)
Exemple #4
0
 def test_initiate_challenge(self):
     """A SASLChallenge reply to initiate() yields ("challenge", b"baz")."""
     coro = self.sm.initiate("foo", b"bar")
     sent = nonza.SASLAuth(mechanism="foo", payload=b"bar")
     reply = XMLStreamMock.Receive(nonza.SASLChallenge(payload=b"baz"))
     state, payload = self._run_test(coro, [
         XMLStreamMock.Send(sent, response=reply),
     ])
     self.assertEqual(state, "challenge")
     self.assertEqual(payload, b"baz")
Exemple #5
0
 def test_initiate_success(self):
     """A SASLSuccess reply to initiate() yields ("success", None)."""
     coro = self.sm.initiate("foo", b"bar")
     sent = nonza.SASLAuth(mechanism="foo", payload=b"bar")
     reply = XMLStreamMock.Receive(nonza.SASLSuccess())
     state, payload = self._run_test(coro, [
         XMLStreamMock.Send(sent, response=reply),
     ])
     self.assertEqual(state, "success")
     self.assertIsNone(payload)
Exemple #6
0
    def test_mute_unmute_cycle(self):
        """Entering mute() matches a Mute action; leaving it matches Unmute."""
        with self.xmlstream.mute():
            # Still inside the context: only the mute is pending.
            run_coroutine(self.xmlstream.run_test([XMLStreamMock.Mute()]))

        # Context exited: now the unmute is pending.
        run_coroutine(self.xmlstream.run_test([XMLStreamMock.Unmute()]))
Exemple #7
0
    def test_response_success(self):
        """respond() in the challenge state yields ("success", None) on a
        SASLSuccess reply."""
        self.sm._state = "challenge"

        coro = self.sm.respond(b"bar")
        sent = nonza.SASLResponse(payload=b"bar")
        reply = XMLStreamMock.Receive(nonza.SASLSuccess())
        state, payload = self._run_test(coro, [
            XMLStreamMock.Send(sent, response=reply),
        ])
        self.assertEqual(state, "success")
        self.assertIsNone(payload)
Exemple #8
0
    def test_response_challenge(self):
        """respond() in the challenge state yields ("challenge", b"baz") on a
        further SASLChallenge reply."""
        self.sm._state = "challenge"

        coro = self.sm.respond(b"bar")
        sent = nonza.SASLResponse(payload=b"bar")
        reply = XMLStreamMock.Receive(nonza.SASLChallenge(payload=b"baz"))
        state, payload = self._run_test(coro, [
            XMLStreamMock.Send(sent, response=reply),
        ])
        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")
Exemple #9
0
    def test_initiate_failure(self):
        """A SASLFailure reply to initiate() raises SASLFailure whose
        opaque_error is the condition name."""
        with self.assertRaises(aiosasl.SASLFailure) as cm:
            coro = self.sm.initiate("foo", b"bar")
            sent = nonza.SASLAuth(mechanism="foo", payload=b"bar")
            failure = nonza.SASLFailure(
                condition=(namespaces.sasl, "not-authorized"))
            self._run_test(coro, [
                XMLStreamMock.Send(sent,
                                   response=XMLStreamMock.Receive(failure)),
            ])

        self.assertEqual("not-authorized", cm.exception.opaque_error)
Exemple #10
0
    def test_abort_reject_non_failure(self):
        """A non-failure reply (SASLSuccess) to abort() raises SASLFailure
        with opaque_error "aborted"."""
        self.sm._state = "challenge"

        expected_message = "unexpected non-failure"
        with self.assertRaisesRegex(aiosasl.SASLFailure,
                                    expected_message) as cm:
            coro = self.sm.abort()
            sent = nonza.SASLAbort()
            reply = XMLStreamMock.Receive(nonza.SASLSuccess())
            self._run_test(coro, [
                XMLStreamMock.Send(sent, response=reply),
            ])

        self.assertEqual("aborted", cm.exception.opaque_error)
Exemple #11
0
    def test_reset(self):
        """reset() called from a stanza handler matches an expected Reset."""
        stanza = self.Cls()

        def reset_on_receive(received):
            self.xmlstream.reset()

        self.xmlstream.stanza_parser.add_class(self.Cls, reset_on_receive)

        expected_actions = [XMLStreamMock.Reset()]
        run_coroutine(self.xmlstream.run_test(
            expected_actions,
            stimulus=XMLStreamMock.Receive(stanza),
        ))
Exemple #12
0
    def test_abort_return_on_aborted_error(self):
        """An "aborted" SASLFailure reply to abort() is not raised; abort()
        yields ("failure", None) instead."""
        self.sm._state = "challenge"

        coro = self.sm.abort()
        sent = nonza.SASLAbort()
        failure = nonza.SASLFailure(condition=(namespaces.sasl, "aborted"))
        state, payload = self._run_test(coro, [
            XMLStreamMock.Send(sent, response=XMLStreamMock.Receive(failure)),
        ])

        self.assertEqual(state, "failure")
        self.assertIsNone(payload)
Exemple #13
0
    def test_response_failure(self):
        """A SASLFailure reply to respond() raises SASLFailure whose
        opaque_error is the condition name."""
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as cm:
            coro = self.sm.respond(b"bar")
            sent = nonza.SASLResponse(payload=b"bar")
            failure = nonza.SASLFailure(
                condition=(namespaces.sasl, "credentials-expired"))
            self._run_test(coro, [
                XMLStreamMock.Send(sent,
                                   response=XMLStreamMock.Receive(failure)),
            ])

        self.assertEqual("credentials-expired", cm.exception.opaque_error)
Exemple #14
0
    def test_abort_re_raise_other_errors(self):
        """A SASLFailure reply to abort() with a condition other than
        "aborted" is raised to the caller."""
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as cm:
            coro = self.sm.abort()
            sent = nonza.SASLAbort()
            failure = nonza.SASLFailure(
                condition=(namespaces.sasl, "mechanism-too-weak"))
            self._run_test(coro, [
                XMLStreamMock.Send(sent,
                                   response=XMLStreamMock.Receive(failure)),
            ])

        self.assertEqual("mechanism-too-weak", cm.exception.opaque_error)
Exemple #15
0
    def test_fail(self):
        """A Fail stimulus propagates its exception to the error and features
        futures, to on_closing handlers, and to later stream operations."""
        failure = ValueError()
        closing_cb = unittest.mock.MagicMock()
        closing_cb.return_value = None

        error_fut = asyncio.ensure_future(self.xmlstream.error_future())
        features_fut = self.xmlstream.features_future()

        self.xmlstream.on_closing.connect(closing_cb)

        run_coroutine(self.xmlstream.run_test(
            [],
            stimulus=XMLStreamMock.Fail(exc=failure),
        ))

        for fut in (error_fut, features_fut):
            self.assertTrue(fut.done())
            self.assertIs(failure, fut.exception())

        closing_cb.assert_called_once_with(failure)

        # After the failure, each of these operations re-raises the very
        # same exception object.
        operations = (
            lambda: self.xmlstream.reset(),
            lambda: run_coroutine(self.xmlstream.starttls(object())),
            lambda: self.xmlstream.send_xso(object()),
        )
        for operation in operations:
            with self.assertRaises(ValueError) as cm:
                operation()
            self.assertIs(failure, cm.exception)

        # run_test is expected to reject a clear_exception keyword.
        with self.assertRaisesRegex(TypeError, "clear_exception"):
            run_coroutine(self.xmlstream.run_test([], clear_exception=True))
Exemple #16
0
    def test_catch_missing_stanza_handler(self):
        """Receiving an XSO with no registered handler fails run_test."""
        stanza = self.Cls()
        stimulus = XMLStreamMock.Receive(stanza)

        with self.assertRaisesRegex(AssertionError, "no handler registered"):
            run_coroutine(self.xmlstream.run_test([], stimulus=stimulus))
Exemple #17
0
    def test_starttls_propagates_exception_from_callback(self):
        """An exception from the post-handshake callback surfaces from
        starttls(), while run_test itself returns None."""
        tls_ctx = unittest.mock.MagicMock()
        on_handshake = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        boom = ValueError()
        on_handshake.side_effect = boom

        expected_actions = [XMLStreamMock.STARTTLS(tls_ctx, on_handshake)]
        results = run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(tls_ctx, on_handshake),
                self.xmlstream.run_test(expected_actions),
                return_exceptions=True,
            )
        )
        starttls_result, run_test_result = results

        self.assertIs(starttls_result, boom)
        self.assertIs(run_test_result, None)
Exemple #18
0
    def setUp(self):
        # The nested class keeps the name ``Cls``; its qualified name
        # (setUp.<locals>.Cls) shows up in object reprs.
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        self.Cls = Cls
        event_loop = asyncio.get_event_loop()
        self.loop = event_loop
        self.xmlstream = XMLStreamMock(self, loop=event_loop)
Exemple #19
0
    def test_abort(self):
        """abort() from a stanza handler matches an expected Abort and fails
        the error future with a ConnectionError."""
        error_fut = self.xmlstream.error_future()

        stanza = self.Cls()

        def abort_on_receive(received):
            self.xmlstream.abort()

        self.xmlstream.stanza_parser.add_class(self.Cls, abort_on_receive)

        expected_actions = [XMLStreamMock.Abort()]
        run_coroutine(self.xmlstream.run_test(
            expected_actions,
            stimulus=XMLStreamMock.Receive(stanza),
        ))

        self.assertTrue(error_fut.done())
        self.assertIsInstance(error_fut.exception(), ConnectionError)
Exemple #20
0
    def test_no_termination_on_missing_action(self):
        """run_test keeps waiting for an expected Send that never happens,
        which shows up as a timeout."""
        stanza = self.Cls()
        pending = [XMLStreamMock.Send(stanza)]

        with self.assertRaises(asyncio.TimeoutError):
            run_coroutine(self.xmlstream.run_test(pending), timeout=0.05)
Exemple #21
0
    def test_close_and_wait(self):
        """close_and_wait() has completed once the Close action is matched."""
        closer = asyncio.ensure_future(self.xmlstream.close_and_wait())

        expected_actions = [XMLStreamMock.Close()]
        run_coroutine(self.xmlstream.run_test(expected_actions))

        self.assertTrue(closer.done())
Exemple #22
0
    def test_catch_surplus_unmute(self):
        """Leaving mute() when only a Mute is expected fails with
        "unexpected unmute"."""
        with self.xmlstream.mute():
            pass

        only_mute = [XMLStreamMock.Mute()]
        with self.assertRaisesRegex(AssertionError, "unexpected unmute"):
            run_coroutine(self.xmlstream.run_test(only_mute))
Exemple #23
0
    def test_receive_stream_features_into_future(self):
        """A received StreamFeatures object resolves features_future()."""
        features_fut = self.xmlstream.features_future()
        features = nonza.StreamFeatures()

        stimulus = XMLStreamMock.Receive(features)
        run_coroutine(self.xmlstream.run_test([], stimulus=stimulus))

        self.assertTrue(features_fut.done())
        self.assertIs(features_fut.result(), features)
Exemple #24
0
    def test_starttls_without_callback(self):
        """starttls() with a None callback matches STARTTLS(ctx, None)."""
        tls_ctx = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        expected_actions = [XMLStreamMock.STARTTLS(tls_ctx, None)]
        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(tls_ctx, None),
                self.xmlstream.run_test(expected_actions),
            )
        )
Exemple #25
0
    def test_register_stanza_handler(self):
        """A handler registered on the stanza parser receives the XSO fed in
        through a Receive stimulus."""
        received = []

        def handler(obj):
            # ``received`` is only mutated, never rebound, so no ``nonlocal``
            # declaration is needed.
            received.append(obj)

        obj = self.Cls()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Receive(obj)))

        self.assertSequenceEqual([obj], received)
Exemple #26
0
    def test_starttls(self):
        """A matching STARTTLS invokes the post-handshake callback once with
        the stream's transport."""
        tls_ctx = unittest.mock.MagicMock()
        on_handshake = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        expected_actions = [XMLStreamMock.STARTTLS(tls_ctx, on_handshake)]
        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(tls_ctx, on_handshake),
                self.xmlstream.run_test(expected_actions),
            )
        )

        on_handshake.assert_called_once_with(self.xmlstream.transport)
Exemple #27
0
 def setUp(self):
     # Fresh mock stream per test, wrapped by the SASL interface under test.
     loop = asyncio.get_event_loop()
     self.loop = loop
     self.xmlstream = XMLStreamMock(self, loop=loop)
     self.sm = sasl.SASLXMPPInterface(self.xmlstream)
Exemple #28
0
class TestSASLXMPPInterface(xmltestutils.XMLTestCase):
    """Tests for ``sasl.SASLXMPPInterface`` driven through an XMLStreamMock.

    Each test scripts the nonzas the interface must send and the replies the
    mock stream delivers. It then checks either the ``(state, payload)``
    result or the raised ``aiosasl.SASLFailure``.
    """

    def setUp(self):
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)
        self.sm = sasl.SASLXMPPInterface(self.xmlstream)

    def _run_test(self, coro, actions=None, stimulus=None):
        """Run *coro* concurrently with the mock stream replaying *actions*.

        :param coro: the interface coroutine under test.
        :param actions: expected stream actions; ``None`` (the default) means
            that no actions are expected.
        :param stimulus: optional stimulus forwarded to ``run_test``.
        :return: the value returned by ``run_coroutine_with_peer``; the tests
            unpack it as the ``(state, payload)`` result of *coro*.
        """
        # Build a fresh list per call rather than sharing one mutable default
        # object between every invocation.
        if actions is None:
            actions = []
        return run_coroutine_with_peer(
            coro,
            self.xmlstream.run_test(actions, stimulus=stimulus),
            loop=self.loop)

    def test_setup(self):
        """The interface starts without a timeout and wraps the stream."""
        self.assertIsNone(self.sm.timeout)
        self.assertIs(self.xmlstream, self.sm.xmlstream)

    def test_initiate_success(self):
        """initiate() sends SASLAuth while muted; SASLSuccess yields
        ("success", None)."""
        state, payload = self._run_test(
            self.sm.initiate("foo", b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo",
                                   payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLSuccess()
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )
        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_initiate_failure(self):
        """A SASLFailure reply to initiate() raises SASLFailure carrying the
        condition name as opaque_error."""
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.initiate("foo", b"bar"),
                [
                    XMLStreamMock.Mute(),
                    XMLStreamMock.Send(
                        nonza.SASLAuth(mechanism="foo",
                                       payload=b"bar"),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl, "not-authorized")
                            )
                        )
                    ),
                    XMLStreamMock.Unmute(),
                ]
            )

        self.assertEqual(
            "not-authorized",
            ctx.exception.opaque_error
        )

    def test_initiate_challenge(self):
        """A SASLChallenge reply to initiate() yields ("challenge", b"baz")."""
        state, payload = self._run_test(
            self.sm.initiate("foo", b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo",
                                   payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLChallenge(payload=b"baz")
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )
        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_response_success(self):
        """respond() sends SASLResponse while muted; SASLSuccess yields
        ("success", None)."""
        self.sm._state = "challenge"

        state, payload = self._run_test(
            self.sm.respond(b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLSuccess()
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )
        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_response_failure(self):
        """A SASLFailure reply to respond() raises SASLFailure carrying the
        condition name as opaque_error."""
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.respond(b"bar"),
                [
                    XMLStreamMock.Mute(),
                    XMLStreamMock.Send(
                        nonza.SASLResponse(payload=b"bar"),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl,
                                           "credentials-expired")
                            )
                        )
                    ),
                    XMLStreamMock.Unmute(),
                ]
            )

        self.assertEqual(
            "credentials-expired",
            ctx.exception.opaque_error
        )

    def test_response_challenge(self):
        """A further SASLChallenge reply to respond() yields
        ("challenge", b"baz")."""
        self.sm._state = "challenge"

        state, payload = self._run_test(
            self.sm.respond(b"bar"),
            [
                XMLStreamMock.Mute(),
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLChallenge(payload=b"baz")
                    )
                ),
                XMLStreamMock.Unmute(),
            ]
        )
        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_abort_reject_non_failure(self):
        """A non-failure reply (SASLSuccess) to abort() raises SASLFailure
        with opaque_error "aborted"."""
        self.sm._state = "challenge"

        with self.assertRaisesRegex(
            aiosasl.SASLFailure,
            "unexpected non-failure"
        ) as ctx:
            self._run_test(
                self.sm.abort(),
                [
                    XMLStreamMock.Send(
                        nonza.SASLAbort(),
                        response=XMLStreamMock.Receive(
                            nonza.SASLSuccess()
                        )
                    )
                ]
            )

        self.assertEqual(
            "aborted",
            ctx.exception.opaque_error
        )

    def test_abort_return_on_aborted_error(self):
        """An "aborted" SASLFailure reply to abort() yields
        ("failure", None) instead of raising."""
        self.sm._state = "challenge"

        state, payload = self._run_test(
            self.sm.abort(),
            [
                XMLStreamMock.Send(
                    nonza.SASLAbort(),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(
                            condition=(namespaces.sasl, "aborted")
                        )
                    )
                )
            ]
        )

        self.assertEqual(state, "failure")
        self.assertIsNone(payload)

    def test_abort_re_raise_other_errors(self):
        """A SASLFailure reply to abort() with a condition other than
        "aborted" is raised."""
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(
                self.sm.abort(),
                [
                    XMLStreamMock.Send(
                        nonza.SASLAbort(),
                        response=XMLStreamMock.Receive(
                            nonza.SASLFailure(
                                condition=(namespaces.sasl,
                                           "mechanism-too-weak")
                            )
                        )
                    )
                ]
            )

        self.assertEqual(
            "mechanism-too-weak",
            ctx.exception.opaque_error
        )

    def tearDown(self):
        # self.sm keeps a reference to the stream (see test_setup), so drop it
        # too; otherwise deleting self.xmlstream would not release the mock.
        del self.sm
        del self.xmlstream
        del self.loop
Exemple #29
0
class TestXMLStreamMock(XMLTestCase):
    """Tests for XMLStreamMock itself.

    Calls made on the mock must match the scripted actions passed to
    ``run_test``. Surplus or mismatched calls make ``run_test`` fail with an
    AssertionError.
    """

    def setUp(self):
        # The nested class's qualified name (setUp.<locals>.Cls) appears in
        # the repr matched by test_catch_surplus_send.
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        self.Cls = Cls
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)

    def test_register_stanza_handler(self):
        """A handler registered on the stanza parser receives the XSO fed in
        through a Receive stimulus."""
        received = []

        def handler(obj):
            # ``received`` is only mutated, never rebound: no ``nonlocal``.
            received.append(obj)

        obj = self.Cls()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(self.xmlstream.run_test(
            [],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertSequenceEqual(
            [
                obj
            ],
            received
        )

    def test_send_xso(self):
        """send_xso() from a stanza handler matches an expected Send."""
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.send_xso(obj)

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Send(obj),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

    def test_catch_missing_stanza_handler(self):
        """Receiving an XSO with no registered handler fails run_test."""
        obj = self.Cls()

        with self.assertRaisesRegex(AssertionError, "no handler registered"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
                stimulus=XMLStreamMock.Receive(obj)
            ))

    def test_no_termination_on_missing_action(self):
        """run_test keeps waiting for an expected action that never happens,
        which shows up as a timeout."""
        obj = self.Cls()

        with self.assertRaises(asyncio.TimeoutError):
            run_coroutine(
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.Send(obj),
                    ],
                ),
                timeout=0.05)

    def test_catch_surplus_send(self):
        """An unscripted send_xso() fails run_test; the message includes the
        repr of the sent object."""
        self.xmlstream.send_xso(self.Cls())

        # Dots are escaped so they match literally rather than any character.
        with self.assertRaisesRegex(
                AssertionError,
                r"unexpected send_xso\(<tests\.test_testutils\.TestXMLStreamMock"
                r"\.setUp\.<locals>\.Cls object at 0x[a-fA-F0-9]+>\)"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
            ))

    def test_reset(self):
        """reset() from a stanza handler matches an expected Reset."""
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.reset()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Reset(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

    def test_catch_surplus_reset(self):
        """An unscripted reset() fails run_test with "unexpected reset"."""
        self.xmlstream.reset()

        with self.assertRaisesRegex(AssertionError,
                                    "unexpected reset"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
            ))

    def test_close(self):
        """close() from a stanza handler fires on_closing(None) and fails the
        error future with a ConnectionError."""
        closing_handler = unittest.mock.Mock()
        fut = self.xmlstream.error_future()

        obj = self.Cls()

        self.xmlstream.on_closing.connect(closing_handler)

        def handler(obj):
            self.xmlstream.close()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Close(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertSequenceEqual(
            [
                unittest.mock.call(None),
            ],
            closing_handler.mock_calls
        )

        self.assertTrue(fut.done())
        self.assertIsInstance(
            fut.exception(),
            ConnectionError
        )

    def test_catch_surplus_close(self):
        """An unscripted close() fails run_test with "unexpected close"."""
        self.xmlstream.close()

        with self.assertRaisesRegex(AssertionError,
                                    "unexpected close"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
            ))

    def test_mute_unmute_cycle(self):
        """Entering mute() matches a Mute action; leaving it matches Unmute."""
        with self.xmlstream.mute():
            run_coroutine(self.xmlstream.run_test(
                [
                    XMLStreamMock.Mute(),
                ],
            ))

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Unmute(),
            ],
        ))

    def test_catch_surplus_mute(self):
        """An unscripted mute() fails run_test with "unexpected mute"."""
        with self.xmlstream.mute():
            with self.assertRaisesRegex(AssertionError,
                                        "unexpected mute"):
                run_coroutine(self.xmlstream.run_test(
                    [
                    ],
                ))

    def test_catch_surplus_unmute(self):
        """Leaving mute() when only a Mute is scripted fails with
        "unexpected unmute"."""
        with self.xmlstream.mute():
            pass

        with self.assertRaisesRegex(AssertionError,
                                    "unexpected unmute"):
            run_coroutine(self.xmlstream.run_test(
                [
                    XMLStreamMock.Mute(),
                ],
            ))

    def test_starttls(self):
        """A matching STARTTLS invokes the post-handshake callback once with
        the stream's transport."""
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(
                            ssl_context,
                            post_handshake_callback)
                    ],
                )
            )
        )

        post_handshake_callback.assert_called_once_with(
            self.xmlstream.transport)

    def test_starttls_without_callback(self):
        """starttls() with a None callback matches STARTTLS(ctx, None)."""
        ssl_context = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, None),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(ssl_context, None)
                    ],
                )
            )
        )

    def test_starttls_reject_incorrect_arguments(self):
        """A wrong ssl context or a wrong callback is rejected with an
        AssertionError."""
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(object(), post_handshake_callback),
                    self.xmlstream.run_test(
                        [
                            XMLStreamMock.STARTTLS(
                                ssl_context,
                                post_handshake_callback)
                        ],
                    )
                )
            )

        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(ssl_context, object()),
                    self.xmlstream.run_test(
                        [
                            XMLStreamMock.STARTTLS(
                                ssl_context,
                                post_handshake_callback)
                        ],
                    )
                )
            )

    def test_starttls_propagates_exception_from_callback(self):
        """An exception from the post-handshake callback surfaces from
        starttls(); run_test itself returns None."""
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        exc = ValueError()
        post_handshake_callback.side_effect = exc

        caught_exception, other_result = run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test(
                    [
                        XMLStreamMock.STARTTLS(
                            ssl_context,
                            post_handshake_callback)
                    ],
                ),
                return_exceptions=True
            )
        )

        self.assertIs(caught_exception, exc)
        self.assertIs(other_result, None)

    def test_fail(self):
        """A Fail stimulus propagates its exception to the error future, to
        on_closing handlers, and to later stream operations."""
        exc = ValueError()
        fun = unittest.mock.MagicMock()
        fun.return_value = None

        ec_future = asyncio.ensure_future(self.xmlstream.error_future())

        self.xmlstream.on_closing.connect(fun)

        run_coroutine(self.xmlstream.run_test(
            [
            ],
            stimulus=XMLStreamMock.Fail(exc=exc)
        ))

        self.assertTrue(ec_future.done())
        self.assertIs(exc, ec_future.exception())

        fun.assert_called_once_with(exc)

        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.reset()
        self.assertIs(exc, ctx.exception)
        with self.assertRaises(ValueError) as ctx:
            run_coroutine(self.xmlstream.starttls(object()))
        self.assertIs(exc, ctx.exception)
        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.send_xso(object())
        self.assertIs(exc, ctx.exception)

        # run_test is expected to reject a clear_exception keyword.
        with self.assertRaisesRegex(TypeError,
                                    "clear_exception"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
                clear_exception=True
            ))

    def test_close_and_wait(self):
        """close_and_wait() has completed once the Close action is matched."""
        task = asyncio.ensure_future(self.xmlstream.close_and_wait())

        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Close(),
            ]
        ))

        self.assertTrue(task.done())

    def test_abort(self):
        """abort() from a stanza handler matches an expected Abort and fails
        the error future with a ConnectionError."""
        fut = self.xmlstream.error_future()

        obj = self.Cls()

        def handler(obj):
            self.xmlstream.abort()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(self.xmlstream.run_test(
            [
                XMLStreamMock.Abort(),
            ],
            stimulus=XMLStreamMock.Receive(obj)
        ))

        self.assertTrue(fut.done())
        self.assertIsInstance(
            fut.exception(),
            ConnectionError
        )

    def test_catch_surplus_abort(self):
        """An unscripted abort() fails run_test with "unexpected abort"."""
        self.xmlstream.abort()

        with self.assertRaisesRegex(AssertionError,
                                    "unexpected abort"):
            run_coroutine(self.xmlstream.run_test(
                [
                ],
            ))

    def tearDown(self):
        del self.xmlstream
        del self.loop
Exemple #30
0
class TestSASLXMPPInterface(xmltestutils.XMLTestCase):
    """Tests for :class:`sasl.SASLXMPPInterface` against a scripted stream.

    Each test drives one SASL step (initiate / respond / abort). The
    XMLStreamMock expects a specific nonza to be sent and answers it with a
    scripted reply.
    """

    def setUp(self):
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)
        self.sm = sasl.SASLXMPPInterface(self.xmlstream)

    def _run_test(self, coro, actions=None, stimulus=None):
        """Run *coro* concurrently with the mock's scripted *actions*.

        :param coro: coroutine under test.
        :param actions: expected XMLStreamMock actions; ``None`` means none.
        :param stimulus: optional stimulus to feed into the mock.
        :return: the result of *coro*.
        """
        # A fresh list per call; a shared ``[]`` default would leak state
        # between calls/tests if the mock consumes its action list.
        if actions is None:
            actions = []
        return run_coroutine_with_peer(coro,
                                       self.xmlstream.run_test(
                                           actions, stimulus=stimulus),
                                       loop=self.loop)

    def test_setup(self):
        self.assertIsNone(self.sm.timeout)
        self.assertIs(self.xmlstream, self.sm.xmlstream)

    def test_initiate_success(self):
        state, payload = self._run_test(self.sm.initiate("foo", b"bar"), [
            XMLStreamMock.Send(nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                               response=XMLStreamMock.Receive(
                                   nonza.SASLSuccess()))
        ])
        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_initiate_failure(self):
        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.initiate("foo", b"bar"), [
                XMLStreamMock.Send(
                    nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(condition=(namespaces.sasl,
                                                     "not-authorized"))))
            ])

        self.assertEqual("not-authorized", ctx.exception.opaque_error)

    def test_initiate_challenge(self):
        state, payload = self._run_test(self.sm.initiate("foo", b"bar"), [
            XMLStreamMock.Send(nonza.SASLAuth(mechanism="foo", payload=b"bar"),
                               response=XMLStreamMock.Receive(
                                   nonza.SASLChallenge(payload=b"baz")))
        ])
        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_response_success(self):
        # Put the interface mid-exchange so that respond() is permitted.
        self.sm._state = "challenge"

        state, payload = self._run_test(self.sm.respond(b"bar"), [
            XMLStreamMock.Send(nonza.SASLResponse(payload=b"bar"),
                               response=XMLStreamMock.Receive(
                                   nonza.SASLSuccess()))
        ])
        self.assertEqual(state, "success")
        self.assertIsNone(payload)

    def test_response_failure(self):
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.respond(b"bar"), [
                XMLStreamMock.Send(
                    nonza.SASLResponse(payload=b"bar"),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(condition=(namespaces.sasl,
                                                     "credentials-expired"))))
            ])

        self.assertEqual("credentials-expired", ctx.exception.opaque_error)

    def test_response_challenge(self):
        self.sm._state = "challenge"

        state, payload = self._run_test(self.sm.respond(b"bar"), [
            XMLStreamMock.Send(nonza.SASLResponse(payload=b"bar"),
                               response=XMLStreamMock.Receive(
                                   nonza.SASLChallenge(payload=b"baz")))
        ])
        self.assertEqual(state, "challenge")
        self.assertEqual(payload, b"baz")

    def test_abort_reject_non_failure(self):
        # Answering an abort with success is a protocol violation.
        self.sm._state = "challenge"

        with self.assertRaisesRegex(aiosasl.SASLFailure,
                                    "unexpected non-failure") as ctx:
            self._run_test(self.sm.abort(), [
                XMLStreamMock.Send(nonza.SASLAbort(),
                                   response=XMLStreamMock.Receive(
                                       nonza.SASLSuccess()))
            ])

        self.assertEqual("aborted", ctx.exception.opaque_error)

    def test_abort_return_on_aborted_error(self):
        # An "aborted" failure is the expected answer and is not raised.
        self.sm._state = "challenge"

        state, payload = self._run_test(self.sm.abort(), [
            XMLStreamMock.Send(
                nonza.SASLAbort(),
                response=XMLStreamMock.Receive(
                    nonza.SASLFailure(condition=(namespaces.sasl, "aborted"))))
        ])

        self.assertEqual(state, "failure")
        self.assertIsNone(payload)

    def test_abort_re_raise_other_errors(self):
        self.sm._state = "challenge"

        with self.assertRaises(aiosasl.SASLFailure) as ctx:
            self._run_test(self.sm.abort(), [
                XMLStreamMock.Send(
                    nonza.SASLAbort(),
                    response=XMLStreamMock.Receive(
                        nonza.SASLFailure(condition=(namespaces.sasl,
                                                     "mechanism-too-weak"))))
            ])

        self.assertEqual("mechanism-too-weak", ctx.exception.opaque_error)

    def tearDown(self):
        del self.xmlstream
        del self.loop
Exemple #31
0
 def setUp(self):
     """Create a fresh stream mock and SASL interface for each test."""
     loop = asyncio.get_event_loop()
     self.loop = loop
     stream = XMLStreamMock(self, loop=loop)
     self.xmlstream = stream
     self.sm = sasl.SASLXMPPInterface(stream)
Exemple #32
0
class TestXMLStreamMock(XMLTestCase):
    """Tests for the scripted :class:`XMLStreamMock` test double.

    Each test configures the mock (stanza handlers, expected actions) and
    drives it with ``run_test``. Surplus or mismatched calls on the mock are
    expected to surface as :class:`AssertionError`.
    """

    def setUp(self):
        class Cls(xso.XSO):
            TAG = ("uri:foo", "foo")

        self.Cls = Cls
        self.loop = asyncio.get_event_loop()
        self.xmlstream = XMLStreamMock(self, loop=self.loop)

    def test_register_stanza_handler(self):
        """A Receive stimulus is dispatched to the registered handler."""
        received = []

        def handler(obj):
            # Only mutates the enclosing list, so no ``nonlocal`` is needed.
            received.append(obj)

        obj = self.Cls()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Receive(obj)))

        self.assertSequenceEqual([obj], received)

    def test_send_xso(self):
        """send_xso() from a handler satisfies an expected Send action."""
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.send_xso(obj)

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(
            self.xmlstream.run_test(
                [XMLStreamMock.Send(obj)],
                stimulus=XMLStreamMock.Receive(obj)))

    def test_catch_missing_stanza_handler(self):
        """Receiving an XSO without a registered handler is reported."""
        obj = self.Cls()

        with self.assertRaisesRegex(AssertionError, "no handler registered"):
            run_coroutine(
                self.xmlstream.run_test([],
                                        stimulus=XMLStreamMock.Receive(obj)))

    def test_receive_stream_features_into_future(self):
        """Received StreamFeatures resolve the features_future()."""
        fut = self.xmlstream.features_future()
        obj = nonza.StreamFeatures()

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Receive(obj)))

        self.assertTrue(fut.done())
        self.assertIs(fut.result(), obj)

    def test_no_termination_on_missing_action(self):
        """run_test keeps waiting while expected actions are outstanding."""
        obj = self.Cls()

        with self.assertRaises(asyncio.TimeoutError):
            run_coroutine(
                self.xmlstream.run_test([XMLStreamMock.Send(obj)]),
                timeout=0.05)

    def test_catch_surplus_send(self):
        """An unexpected send_xso() is reported with the object's repr."""
        import re

        self.xmlstream.send_xso(self.Cls())

        # The message embeds the object's default repr(). Derive the dotted
        # class path from the class itself so that the expectation neither
        # hard-codes this module's location nor lets '.' match arbitrary
        # characters.
        cls_path = re.escape(
            "{}.{}".format(self.Cls.__module__, self.Cls.__qualname__))
        with self.assertRaisesRegex(
                AssertionError,
                r"unexpected send_xso\(<" + cls_path +
                r" object at 0x[a-fA-F0-9]+>\)"):
            run_coroutine(self.xmlstream.run_test([]))

    def test_reset(self):
        """reset() from a handler satisfies an expected Reset action."""
        obj = self.Cls()

        def handler(obj):
            self.xmlstream.reset()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(
            self.xmlstream.run_test(
                [XMLStreamMock.Reset()],
                stimulus=XMLStreamMock.Receive(obj)))

    def test_catch_surplus_reset(self):
        """An unexpected reset() is reported."""
        self.xmlstream.reset()

        with self.assertRaisesRegex(AssertionError, "unexpected reset"):
            run_coroutine(self.xmlstream.run_test([]))

    def test_close(self):
        """close() from a handler fires on_closing and fails error_future."""
        closing_handler = unittest.mock.Mock()
        fut = self.xmlstream.error_future()

        obj = self.Cls()

        self.xmlstream.on_closing.connect(closing_handler)

        def handler(obj):
            self.xmlstream.close()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(
            self.xmlstream.run_test(
                [XMLStreamMock.Close()],
                stimulus=XMLStreamMock.Receive(obj)))

        self.assertSequenceEqual([
            unittest.mock.call(None),
        ], closing_handler.mock_calls)

        self.assertTrue(fut.done())
        self.assertIsInstance(fut.exception(), ConnectionError)

    def test_catch_surplus_close(self):
        """An unexpected close() is reported."""
        self.xmlstream.close()

        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            run_coroutine(self.xmlstream.run_test([]))

    def test_mute_unmute_cycle(self):
        """Entering and leaving mute() map to Mute and Unmute actions."""
        with self.xmlstream.mute():
            run_coroutine(self.xmlstream.run_test([XMLStreamMock.Mute()]))

        run_coroutine(self.xmlstream.run_test([XMLStreamMock.Unmute()]))

    def test_catch_surplus_mute(self):
        """Entering mute() without an expected Mute action is reported."""
        with self.xmlstream.mute():
            with self.assertRaisesRegex(AssertionError, "unexpected mute"):
                run_coroutine(self.xmlstream.run_test([]))

    def test_catch_surplus_unmute(self):
        """Leaving mute() without an expected Unmute action is reported."""
        with self.xmlstream.mute():
            pass

        with self.assertRaisesRegex(AssertionError, "unexpected unmute"):
            run_coroutine(self.xmlstream.run_test([XMLStreamMock.Mute()]))

    def test_starttls(self):
        """starttls() matches STARTTLS and runs the post-handshake callback.

        The callback is expected to be invoked with the stream's transport.
        """
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, post_handshake_callback),
                self.xmlstream.run_test([
                    XMLStreamMock.STARTTLS(ssl_context,
                                           post_handshake_callback)
                ])))

        post_handshake_callback.assert_called_once_with(
            self.xmlstream.transport)

    def test_starttls_without_callback(self):
        """starttls() works when no post-handshake callback is given."""
        ssl_context = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        run_coroutine(
            asyncio.gather(
                self.xmlstream.starttls(ssl_context, None),
                self.xmlstream.run_test(
                    [XMLStreamMock.STARTTLS(ssl_context, None)])))

    def test_starttls_reject_incorrect_arguments(self):
        """A mismatched SSL context or callback in starttls() is reported."""
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        # Wrong SSL context.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(object(), post_handshake_callback),
                    self.xmlstream.run_test([
                        XMLStreamMock.STARTTLS(ssl_context,
                                               post_handshake_callback)
                    ])))

        # Wrong post-handshake callback.
        with self.assertRaisesRegex(AssertionError,
                                    "mismatched starttls argument"):
            run_coroutine(
                asyncio.gather(
                    self.xmlstream.starttls(ssl_context, object()),
                    self.xmlstream.run_test([
                        XMLStreamMock.STARTTLS(ssl_context,
                                               post_handshake_callback)
                    ])))

    def test_starttls_propagates_exception_from_callback(self):
        """An exception raised by the callback is propagated by starttls()."""
        ssl_context = unittest.mock.MagicMock()
        post_handshake_callback = unittest.mock.MagicMock()

        self.xmlstream.transport = object()

        exc = ValueError()
        post_handshake_callback.side_effect = exc

        caught_exception, other_result = run_coroutine(
            asyncio.gather(self.xmlstream.starttls(ssl_context,
                                                   post_handshake_callback),
                           self.xmlstream.run_test([
                               XMLStreamMock.STARTTLS(ssl_context,
                                                      post_handshake_callback)
                           ]),
                           return_exceptions=True))

        self.assertIs(caught_exception, exc)
        self.assertIs(other_result, None)

    def test_fail(self):
        """A Fail stimulus propagates its exception to the stream's users.

        The exception is delivered to the pending futures and the on_closing
        handlers. It is also raised by every later stream operation.
        """
        exc = ValueError()
        fun = unittest.mock.MagicMock()
        fun.return_value = None

        ec_future = asyncio.ensure_future(self.xmlstream.error_future())
        features_future = self.xmlstream.features_future()

        self.xmlstream.on_closing.connect(fun)

        run_coroutine(
            self.xmlstream.run_test([], stimulus=XMLStreamMock.Fail(exc=exc)))

        self.assertTrue(ec_future.done())
        self.assertIs(exc, ec_future.exception())
        self.assertTrue(features_future.done())
        self.assertIs(exc, features_future.exception())

        fun.assert_called_once_with(exc)

        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.reset()
        self.assertIs(exc, ctx.exception)
        with self.assertRaises(ValueError) as ctx:
            run_coroutine(self.xmlstream.starttls(object()))
        self.assertIs(exc, ctx.exception)
        with self.assertRaises(ValueError) as ctx:
            self.xmlstream.send_xso(object())
        self.assertIs(exc, ctx.exception)

        with self.assertRaisesRegex(TypeError, "clear_exception"):
            run_coroutine(self.xmlstream.run_test([], clear_exception=True))

    def test_close_and_wait(self):
        """close_and_wait() is done once the mock has consumed a Close."""
        task = asyncio.ensure_future(self.xmlstream.close_and_wait())

        run_coroutine(self.xmlstream.run_test([XMLStreamMock.Close()]))

        self.assertTrue(task.done())

    def test_abort(self):
        """abort() from a handler fails both error and features futures.

        Both futures are expected to fail with a ConnectionError.
        """
        fut = self.xmlstream.error_future()
        ffut = self.xmlstream.features_future()

        obj = self.Cls()

        def handler(obj):
            self.xmlstream.abort()

        self.xmlstream.stanza_parser.add_class(self.Cls, handler)
        run_coroutine(
            self.xmlstream.run_test(
                [XMLStreamMock.Abort()],
                stimulus=XMLStreamMock.Receive(obj)))

        self.assertTrue(fut.done())
        self.assertIsInstance(fut.exception(), ConnectionError)

        self.assertTrue(ffut.done())
        self.assertIsInstance(ffut.exception(), ConnectionError)

    def test_catch_surplus_abort(self):
        """An unexpected abort() is reported."""
        self.xmlstream.abort()

        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            run_coroutine(self.xmlstream.run_test([]))

    def tearDown(self):
        del self.xmlstream
        del self.loop