Пример #1
0
    def test_starttls_without_callback(self):
        """STARTTLS issued without a post-handshake callback matches a
        STARTTLS action whose expected callback is ``None``."""
        self.t = TransportMock(self, self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        # resolved from connection_made, so starttls is only issued once the
        # mock transport has connected the protocol
        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        # native coroutine: the @asyncio.coroutine decorator used previously
        # was removed in Python 3.11
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [
                    TransportMock.STARTTLS(ssl_context, None)
                ]
            )
        )
Пример #2
0
    def test_starttls(self):
        """STARTTLS with a post-handshake callback is matched against the
        script, and the callback is called exactly once with the transport."""
        self.t = TransportMock(self,
                               self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        # set by connection_made; gates the STARTTLS call below
        connected = asyncio.Future()

        def on_connection_made(transport):
            connected.set_result(None)

        self.protocol.connection_made = on_connection_made

        tls_context = unittest.mock.Mock()
        callback = CoroutineMock()
        callback.return_value = None

        async def starttls_after_connect():
            await connected
            await self.t.starttls(tls_context, callback)

        expected_actions = [TransportMock.STARTTLS(tls_context, callback)]
        run_coroutine_with_peer(
            starttls_after_connect(),
            self.t.run_test(expected_actions),
        )

        callback.assert_called_once_with(self.t)
Пример #3
0
    def test_partial(self):
        """After a ``partial=True`` run, further actions on the same
        transport can be checked by a second run."""
        def on_connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = on_connection_made

        first_phase = [TransportMock.Write(b"foo")]
        self._run_test(self.t, first_phase, partial=True)

        self.t.write_eof()
        self.t.close()

        second_phase = [TransportMock.WriteEof(), TransportMock.Close()]
        self._run_test(self.t, second_phase)
Пример #4
0
    def test_no_response_conflict(self):
        """Two back-to-back writes that each carry a response are both
        accepted; no data arrives between the writes in connection_made."""
        received = []

        def on_data_received(blob):
            received.append(blob)

        def on_connection_made(transport):
            transport.write(b"foo")
            # the response to the first write must not be delivered yet
            self.assertFalse(received)
            transport.write(b"bar")

        self.protocol.connection_made = on_connection_made
        self.protocol.data_received = on_data_received

        expected = [
            TransportMock.Write(b"foo",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
        ]
        self._run_test(self.t, expected)
Пример #5
0
    def test_iq_results_are_not_replied_to(self):
        """An incoming IQ of type "result" must not cause the stanza stream
        to write anything back to the transport."""
        import aioxmpp.protocol
        import aioxmpp.stream

        version = (1, 0)

        fut = asyncio.Future()
        p = aioxmpp.protocol.XMLStream(
            to=TEST_PEER,
            sorted_attributes=True,
            features_future=fut)
        t = TransportMock(self, p)
        s = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

        # stream header exchange; partial=True keeps the mock transport open
        run_coroutine(t.run_test(
            [
                TransportMock.Write(
                    STREAM_HEADER,
                    response=[
                        TransportMock.Receive(
                            PEER_STREAM_HEADER_TEMPLATE.format(
                                minor=version[1],
                                major=version[0]).encode("utf-8")),
                    ]
                ),
            ],
            partial=True
        ))

        self.assertEqual(p.state, aioxmpp.protocol.State.OPEN)

        s.start(p)
        # stop the stanza stream even if a check below fails, so the started
        # stream does not outlive this test
        try:
            run_coroutine(
                t.run_test(
                    [
                    ],
                    stimulus=[
                        TransportMock.Receive(
                            b'<iq type="result" id="foo">'
                            b'<payload xmlns="fnord"/>'
                            b'</iq>'
                        )
                    ],
                    partial=True,
                )
            )

            s.flush_incoming()

            # give the stream one loop iteration to (not) react
            run_coroutine(asyncio.sleep(0))

            # empty expectation list: any write at this point is a failure
            run_coroutine(
                t.run_test(
                    [
                    ],
                )
            )
        finally:
            s.stop()
Пример #6
0
    def test_response_eof_received(self):
        """A ReceiveEof response to a close makes the protocol see
        eof_received followed by connection_lost(None)."""
        def on_connection_made(transport):
            transport.close()

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Close(response=TransportMock.ReceiveEof())]
        self._run_test(self.t, script)

        expected_calls = [
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None),
        ]
        self.assertSequenceEqual(self.protocol.mock_calls, expected_calls)
Пример #7
0
    def test_response_lose_connection(self):
        """A LoseConnection response passes its argument straight to the
        protocol's connection_lost."""
        def on_connection_made(transport):
            transport.close()

        sentinel = object()

        self.protocol.connection_made = on_connection_made

        script = [
            TransportMock.Close(
                response=TransportMock.LoseConnection(sentinel)),
        ]
        self._run_test(self.t, script)

        expected_calls = [unittest.mock.call.connection_lost(sentinel)]
        self.assertSequenceEqual(self.protocol.mock_calls, expected_calls)
Пример #8
0
    def test_request_response(self):
        """A write can carry a scripted response, which in turn drives the
        protocol to close the transport."""
        def data_received(data):
            # self.assertIn rather than a bare assert: assert statements are
            # stripped under ``python -O`` and would silently stop checking
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close()
        ],
                       stimulus=b"foo")
Пример #9
0
    def test_asynchronous_request_response(self):
        """Writes and closes scheduled with call_soon from data_received
        are still matched against the script."""
        def on_data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
                return
            if data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = on_data_received

        script = [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")
Пример #10
0
    def test_catch_unexpected_close(self):
        """Closing where a write was expected fails with "unexpected
        close"."""
        def on_connection_made(transport):
            transport.close()

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foo")]
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, script)
Пример #11
0
    def test_catch_asynchronous_invalid_action(self):
        """A close scheduled via call_soon where a write was expected is
        reported as an AssertionError."""
        def on_connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foo")]
        with self.assertRaises(AssertionError):
            self._run_test(self.t, script)
Пример #12
0
    def test_abort(self):
        """An abort is accepted when the script expects one."""
        def on_connection_made(transport):
            transport.abort()

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Abort()]
        self._run_test(self.t, script)
Пример #13
0
    def test_write_eof(self):
        """A write_eof is accepted when the script expects one."""
        def on_connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.WriteEof()]
        self._run_test(self.t, script)
Пример #14
0
    def test_writelines(self):
        """writelines output is matched as one concatenated write."""
        def on_connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foobar")]
        self._run_test(self.t, script)
Пример #15
0
    def test_catch_invalid_write(self):
        """Writing data other than expected fails with a mismatch
        message."""
        def on_connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foo")]
        expected_message = "mismatch of expected and written data"
        with self.assertRaisesRegex(AssertionError, expected_message):
            self._run_test(self.t, script)
Пример #16
0
    def test_catch_unexpected_abort(self):
        """Aborting where a write_eof was expected fails with "unexpected
        abort"."""
        def on_connection_made(transport):
            transport.abort()

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.WriteEof()]
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, script)
Пример #17
0
    def test_allow_asynchronous_partial_write(self):
        """Several small writes scheduled via call_soon are matched
        together against a single expected write."""
        def on_connection_made(transport):
            for chunk in (b"f", b"o", b"o"):
                self.loop.call_soon(transport.write, chunk)

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foo")]
        self._run_test(self.t, script)
Пример #18
0
    def test_clear_error_message(self):
        """Writes that do not match the expected data fail with an
        AssertionError."""
        def on_connection_made(transport):
            for blob in (b"foo", b"bar"):
                transport.write(blob)

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"baz")]
        with self.assertRaises(AssertionError):
            self._run_test(self.t, script)
Пример #19
0
    def test_invalid_response(self):
        """A response that is not a TransportMock action is reported as an
        incorrect test specification."""
        def on_connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = on_connection_made

        script = [TransportMock.Write(b"foo", response=1)]
        with self.assertRaisesRegex(RuntimeError,
                                    "test specification incorrect"):
            self._run_test(self.t, script)
Пример #20
0
    def test_response_sequence(self):
        """A list of responses is delivered to the protocol in order."""
        def on_connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = on_connection_made

        responses = [
            TransportMock.Receive(b"foo"),
            TransportMock.ReceiveEof(),
        ]
        self._run_test(self.t, [TransportMock.Write(b"foo",
                                                    response=responses)])

        expected_calls = [
            unittest.mock.call.data_received(b"foo"),
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None),
        ]
        self.assertSequenceEqual(self.protocol.mock_calls, expected_calls)
Пример #21
0
    def test_exception_from_stimulus_bubbles_up(self):
        """An exception raised by the protocol while handling a stimulus
        propagates unchanged out of run_test."""
        error = ConnectionError("foobar")

        def on_data_received(data):
            raise error

        self.protocol.data_received = on_data_received

        stimulus = TransportMock.Receive(b"foobar")
        with self.assertRaises(ConnectionError) as ctx:
            run_coroutine(self.t.run_test([], stimulus=stimulus))

        self.assertIs(error, ctx.exception)
Пример #22
0
    def test_iq_results_are_not_replied_to(self):
        """An incoming IQ of type "result" must not cause the stanza stream
        to write anything back to the transport."""
        import aioxmpp.protocol
        import aioxmpp.stream

        version = (1, 0)

        fut = asyncio.Future()
        p = aioxmpp.protocol.XMLStream(to=TEST_PEER,
                                       sorted_attributes=True,
                                       features_future=fut)
        t = TransportMock(self, p)
        s = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

        # stream header exchange; partial=True keeps the mock transport open
        run_coroutine(
            t.run_test([
                TransportMock.Write(
                    STREAM_HEADER,
                    response=[
                        TransportMock.Receive(
                            PEER_STREAM_HEADER_TEMPLATE.format(
                                minor=version[1],
                                major=version[0]).encode("utf-8")),
                    ]),
            ],
                       partial=True))

        self.assertEqual(p.state, aioxmpp.protocol.State.OPEN)

        s.start(p)
        # stop the stanza stream even if a check below fails, so the started
        # stream does not outlive this test
        try:
            run_coroutine(
                t.run_test(
                    [],
                    stimulus=[
                        TransportMock.Receive(b'<iq type="result" id="foo">'
                                              b'<payload xmlns="fnord"/>'
                                              b'</iq>')
                    ],
                    partial=True,
                ))

            s.flush_incoming()

            # give the stream one loop iteration to (not) react
            run_coroutine(asyncio.sleep(0))

            # empty expectation list: any write at this point is a failure
            run_coroutine(t.run_test([], ))
        finally:
            s.stop()
Пример #23
0
    def test_sm_bootstrap_race(self):
        """Enable stream management while the peer's ``<enabled/>`` reply is
        immediately followed by an ``<r/>`` request, then check the
        ``h="0"`` acks exchanged afterwards."""
        import aioxmpp.protocol
        import aioxmpp.stream

        major, minor = 1, 0

        features = asyncio.Future()
        xmlstream = aioxmpp.protocol.XMLStream(to=TEST_PEER,
                                               sorted_attributes=True,
                                               features_future=features)
        transport = TransportMock(self, xmlstream)
        stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
        stream.soft_timeout = timedelta(seconds=0.25)

        peer_header = PEER_STREAM_HEADER_TEMPLATE.format(
            minor=minor, major=major).encode("utf-8")

        # stream header exchange plus a features element advertising SM
        handshake = [
            TransportMock.Write(
                STREAM_HEADER,
                response=[
                    TransportMock.Receive(peer_header),
                    TransportMock.Receive(
                        b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                        b"</stream:features>"),
                ]),
        ]
        run_coroutine(transport.run_test(handshake, partial=True))

        self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)
        self.assertTrue(features.done())

        stream.start(xmlstream)

        # the <enabled/> reply arrives in the same chunk as an <r/> request
        enable_exchange = [
            TransportMock.Write(
                b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                response=[
                    TransportMock.Receive(
                        b'<enabled xmlns="urn:xmpp:sm:3" '
                        b'resume="true" id="foo"/>'
                        b'<r xmlns="urn:xmpp:sm:3"/>'),
                ]),
        ]
        run_coroutine_with_peer(
            stream.start_sm(),
            transport.run_test(enable_exchange, partial=True))

        self.assertTrue(stream.sm_enabled)
        self.assertEqual(stream.sm_id, "foo")
        self.assertTrue(stream.sm_resumable)

        ack_exchange = [
            TransportMock.Write(
                b'<a xmlns="urn:xmpp:sm:3" h="0"/>'
                b'<r xmlns="urn:xmpp:sm:3"/>',
                response=[
                    TransportMock.Receive(
                        b'<a xmlns="urn:xmpp:sm:3" h="0"/>'),
                    TransportMock.Receive(b'<r xmlns="urn:xmpp:sm:3"/>'),
                ]),
            TransportMock.Write(b'<a xmlns="urn:xmpp:sm:3" h="0"/>'),
        ]
        run_coroutine(transport.run_test(ack_exchange, partial=True))
Пример #24
0
class TestTransportMock(unittest.TestCase):
    """Tests for ``TransportMock``, the scripted transport that checks a
    protocol's writes/closes against expected actions and feeds it scripted
    responses and stimuli."""

    def setUp(self):
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def _run_test(self, t, *args, **kwargs):
        """Run ``t.run_test(*args, **kwargs)`` to completion on
        ``self.loop`` and return its result."""
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        """An empty script only produces connection_made and
        connection_lost(None) on the protocol."""
        self._run_test(self.t, [])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.connection_lost(None),
            ])

    def test_stimulus(self):
        """A bytes stimulus is delivered via data_received between
        connection_made and connection_lost."""
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.connection_lost(None),
            ])

    def test_request_response(self):
        """A write can carry a scripted response, which in turn drives the
        protocol to close the transport."""
        def data_received(data):
            # self.assertIn rather than a bare assert: assert statements are
            # stripped under ``python -O`` and would silently stop checking
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_request_multiresponse(self):
        """A write can carry several responses, each of which may trigger
        further actions."""
        def data_received(data):
            # see test_request_response for why this is not a bare assert
            self.assertIn(data, {b"foo", b"bar", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=[
                        TransportMock.Receive(b"bar"),
                        TransportMock.Receive(b"baz")
                    ]),
                TransportMock.Write(b"baric"),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        """A close scheduled via call_soon where a write was expected is
        reported as an AssertionError."""
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(b"foo")
                ])

    def test_catch_invalid_write(self):
        """Writing data other than expected fails with a mismatch
        message."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                AssertionError,
                "mismatch of expected and written data"):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(b"foo")
                ])

    def test_catch_surplus_write(self):
        """A write when no further action is expected fails with
        "unexpected write"."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(
                self.t,
                [
                ])

    def test_catch_unexpected_close(self):
        """Closing where a write was expected fails with "unexpected
        close"."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(b"foo")
                ])

    def test_catch_surplus_close(self):
        """A close when no further action is expected fails with
        "unexpected close"."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(
                self.t,
                [
                ])

    def test_allow_asynchronous_partial_write(self):
        """Several small writes scheduled via call_soon are matched
        together against a single expected write."""
        def connection_made(transport):
            self.loop.call_soon(transport.write, b"f")
            self.loop.call_soon(transport.write, b"o")
            self.loop.call_soon(transport.write, b"o")

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo")
            ])

    def test_asynchronous_request_response(self):
        """Writes and closes scheduled with call_soon from data_received
        are still matched against the script."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")),
                TransportMock.Close()
            ],
            stimulus=b"foo")

    def test_response_eof_received(self):
        """A ReceiveEof response to a close makes the protocol see
        eof_received followed by connection_lost(None)."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Close(
                    response=TransportMock.ReceiveEof()
                )
            ])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None)
            ])

    def test_response_lose_connection(self):
        """A LoseConnection response passes its argument straight to the
        protocol's connection_lost."""
        def connection_made(transport):
            transport.close()

        obj = object()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [
                TransportMock.Close(
                    response=TransportMock.LoseConnection(obj)
                )
            ])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_lost(obj)
            ])

    def test_writelines(self):
        """writelines output is matched as one concatenated write."""
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made

        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foobar")
            ])

    def test_can_write_eof(self):
        """The mock transport reports write_eof support."""
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        """An abort is accepted when the script expects one."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        self._run_test(
            self.t,
            [
                TransportMock.Abort()
            ])

    def test_write_eof(self):
        """A write_eof is accepted when the script expects one."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        self._run_test(
            self.t,
            [
                TransportMock.WriteEof()
            ])

    def test_catch_surplus_write_eof(self):
        """A write_eof when no further action is expected fails with
        "unexpected write_eof"."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(
                AssertionError,
                "unexpected write_eof"):
            self._run_test(
                self.t,
                [
                ])

    def test_catch_unexpected_write_eof(self):
        """A write_eof where an abort was expected fails with
        "unexpected write_eof"."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(
                AssertionError,
                "unexpected write_eof"):
            self._run_test(
                self.t,
                [
                    TransportMock.Abort()
                ])

    def test_catch_surplus_abort(self):
        """An abort when no further action is expected fails with
        "unexpected abort"."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(
                AssertionError,
                "unexpected abort"):
            self._run_test(
                self.t,
                [
                ])

    def test_catch_unexpected_abort(self):
        """Aborting where a write_eof was expected fails with "unexpected
        abort"."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(
                AssertionError,
                "unexpected abort"):
            self._run_test(
                self.t,
                [
                    TransportMock.WriteEof()
                ])

    def test_invalid_response(self):
        """A response that is not a TransportMock action is reported as an
        incorrect test specification."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(
                RuntimeError,
                "test specification incorrect"):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(
                        b"foo",
                        response=1)
                ])

    def test_response_sequence(self):
        """A list of responses is delivered to the protocol in order."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=[
                        TransportMock.Receive(b"foo"),
                        TransportMock.ReceiveEof()
                    ])
            ])

        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ])

    def test_clear_error_message(self):
        """Writes that do not match the expected data fail with an
        AssertionError."""
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made

        with self.assertRaises(AssertionError):
            self._run_test(
                self.t,
                [
                    TransportMock.Write(b"baz")
                ])

    def test_detached_response(self):
        """The response to a write is not delivered synchronously from
        inside the write call."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=TransportMock.Receive(b"bar")
                )
            ])

    def test_no_response_conflict(self):
        """Two back-to-back writes that each carry a response are both
        accepted; no data arrives between the writes in connection_made."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                    response=TransportMock.Receive(b"baz"),
                ),
                TransportMock.Write(
                    b"bar",
                    response=TransportMock.Receive(b"baz")
                )
            ])

    def test_partial(self):
        """After a ``partial=True`` run, further actions on the same
        transport can be checked by a second run."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(
            self.t,
            [
                TransportMock.Write(
                    b"foo",
                ),
            ],
            partial=True
        )

        self.t.write_eof()
        self.t.close()

        self._run_test(
            self.t,
            [
                TransportMock.WriteEof(),
                TransportMock.Close()
            ]
        )

    def test_no_starttls_by_default(self):
        """Without ``with_starttls=True`` the mock reports no STARTTLS
        support and starttls raises RuntimeError."""
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            run_coroutine(self.t.starttls())

    def test_starttls(self):
        """STARTTLS with a post-handshake callback is matched against the
        script, and the callback is called exactly once with the transport."""
        self.t = TransportMock(self, self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        # resolved from connection_made; gates the starttls call below
        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        post_handshake_callback = unittest.mock.Mock()
        post_handshake_callback.return_value = []

        # native coroutine: the @asyncio.coroutine decorator used previously
        # was removed in Python 3.11
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context,
                                  post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [
                    TransportMock.STARTTLS(ssl_context,
                                           post_handshake_callback)
                ]
            )
        )

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        """STARTTLS issued without a post-handshake callback matches a
        STARTTLS action whose expected callback is ``None``."""
        self.t = TransportMock(self, self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        # native coroutine (see test_starttls)
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [
                    TransportMock.STARTTLS(ssl_context, None)
                ]
            )
        )

    def test_exception_from_stimulus_bubbles_up(self):
        """An exception raised by the protocol while handling a stimulus
        propagates unchanged out of run_test."""
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            run_coroutine(
                self.t.run_test(
                    [
                    ],
                    stimulus=TransportMock.Receive(b"foobar")
                )
            )

        self.assertIs(
            exc,
            ctx.exception
        )

    def tearDown(self):
        del self.t
        del self.loop
        del self.protocol
Пример #25
0
 def setUp(self):
     """Create a protocol mock and a TransportMock bound to the current
     event loop."""
     protocol = make_protocol_mock()
     loop = asyncio.get_event_loop()
     self.protocol = protocol
     self.loop = loop
     self.t = TransportMock(self, protocol, loop=loop)
Пример #26
0
class TestTransportMock(unittest.TestCase):
    """Tests for the scripted ``TransportMock`` test helper.

    Each test installs callbacks on a protocol mock, runs the transport
    against a list of expected actions (``Write``, ``Close``, ``Abort``,
    ``WriteEof``, ``STARTTLS``), and then checks the resulting protocol
    calls or the error raised when the transport's use deviates from the
    script.
    """

    def setUp(self):
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def _run_test(self, t, *args, **kwargs):
        """Run ``t.run_test(*args, **kwargs)`` to completion on ``self.loop``.

        Returns whatever ``run_coroutine`` returns for that coroutine.
        """
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        """An empty script still yields connection_made and
        connection_lost(None)."""
        self._run_test(self.t, [])
        self.assertSequenceEqual(self.protocol.mock_calls, [
            unittest.mock.call.connection_made(self.t),
            unittest.mock.call.connection_lost(None),
        ])

    def test_stimulus(self):
        """A bytes stimulus is delivered via data_received between connect
        and disconnect."""
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(self.protocol.mock_calls, [
            unittest.mock.call.connection_made(self.t),
            unittest.mock.call.data_received(b"foo"),
            unittest.mock.call.connection_lost(None),
        ])

    def test_request_response(self):
        """The response attached to an expected Write is fed back to the
        protocol."""
        def data_received(data):
            assert data in {b"foo", b"baz"}
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close()
        ], stimulus=b"foo")

    def test_request_multiresponse(self):
        """A list response is delivered to the protocol element by
        element."""
        def data_received(data):
            assert data in {b"foo", b"bar", b"baz"}
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(b"bar",
                                response=[
                                    TransportMock.Receive(b"bar"),
                                    TransportMock.Receive(b"baz")
                                ]),
            TransportMock.Write(b"baric"),
            TransportMock.Close()
        ], stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        """A close() scheduled with call_soon while a Write is expected fails
        the run."""
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_invalid_write(self):
        """Written bytes that differ from the expected Write are reported."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError,
                                    "mismatch of expected and written data"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_write(self):
        """A write with no expected action left is reported."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(self.t, [])

    def test_catch_unexpected_close(self):
        """A close() where a Write is expected is reported."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_close(self):
        """A close() with no expected action left is reported."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [])

    def test_allow_asynchronous_partial_write(self):
        """One expected Write may be satisfied by several smaller writes
        scheduled with call_soon."""
        def connection_made(transport):
            self.loop.call_soon(transport.write, b"f")
            self.loop.call_soon(transport.write, b"o")
            self.loop.call_soon(transport.write, b"o")

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_asynchronous_request_response(self):
        """Request/response also works when the protocol reacts via
        call_soon."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        self._run_test(self.t, [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close()
        ], stimulus=b"foo")

    def test_response_eof_received(self):
        """Close(response=ReceiveEof()) calls eof_received before
        connection_lost(None)."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t, [TransportMock.Close(response=TransportMock.ReceiveEof())])
        self.assertSequenceEqual(self.protocol.mock_calls, [
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None)
        ])

    def test_response_lose_connection(self):
        """LoseConnection(obj) passes obj to connection_lost."""
        def connection_made(transport):
            transport.close()

        obj = object()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.LoseConnection(obj))])
        self.assertSequenceEqual(self.protocol.mock_calls,
                                 [unittest.mock.call.connection_lost(obj)])

    def test_writelines(self):
        """writelines() output is matched as one concatenated Write."""
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [TransportMock.Write(b"foobar")])

    def test_can_write_eof(self):
        """The mock reports that it supports write_eof."""
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        """abort() matches an expected Abort action."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [TransportMock.Abort()])

    def test_write_eof(self):
        """write_eof() matches an expected WriteEof action."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [TransportMock.WriteEof()])

    def test_catch_surplus_write_eof(self):
        """A write_eof() with no expected action left is reported."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [])

    def test_catch_unexpected_write_eof(self):
        """A write_eof() where an Abort is expected is reported."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [TransportMock.Abort()])

    def test_catch_surplus_abort(self):
        """An abort() with no expected action left is reported."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [])

    def test_catch_unexpected_abort(self):
        """An abort() where a WriteEof is expected is reported."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [TransportMock.WriteEof()])

    def test_invalid_response(self):
        """A response that is not a mock action (here the int 1) is rejected
        as an incorrect test specification."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        with self.assertRaisesRegex(RuntimeError,
                                    "test specification incorrect"):
            self._run_test(self.t, [TransportMock.Write(b"foo", response=1)])

    def test_response_sequence(self):
        """A list response may mix Receive and ReceiveEof actions."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [
            TransportMock.Write(b"foo",
                                response=[
                                    TransportMock.Receive(b"foo"),
                                    TransportMock.ReceiveEof()
                                ])
        ])

        self.assertSequenceEqual(self.protocol.mock_calls, [
            unittest.mock.call.data_received(b"foo"),
            unittest.mock.call.eof_received(),
            unittest.mock.call.connection_lost(None),
        ])

    def test_clear_error_message(self):
        """Two writes that do not add up to the expected data fail with
        AssertionError."""
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made

        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"baz")])

    def test_detached_response(self):
        """A write's response is not delivered synchronously inside
        write()."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(self.t, [
            TransportMock.Write(b"foo", response=TransportMock.Receive(b"bar"))
        ])

    def test_no_response_conflict(self):
        """A second write in the same callback is accepted while the first
        write's response has not reached data_received yet."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(self.t, [
            TransportMock.Write(
                b"foo",
                response=TransportMock.Receive(b"baz"),
            ),
            TransportMock.Write(b"bar", response=TransportMock.Receive(b"baz"))
        ])

    def test_partial(self):
        """partial=True ends a run with the transport still usable; calls
        made afterwards are checked by the next run_test."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [
            TransportMock.Write(b"foo"),
        ], partial=True)

        self.t.write_eof()
        self.t.close()

        self._run_test(self.t,
                       [TransportMock.WriteEof(),
                        TransportMock.Close()])

    def test_no_starttls_by_default(self):
        """Without with_starttls=True, starttls is unsupported and raises
        RuntimeError."""
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            # Run on the same explicit loop as every other run in this class.
            run_coroutine(self.t.starttls(), loop=self.loop)

    def test_starttls(self):
        """starttls() with a callback matches STARTTLS(ssl_context, callback)
        and the callback is called once with the transport."""
        self.t = TransportMock(self,
                               self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        # Resolved from connection_made so starttls() waits for the
        # transport to be handed to the protocol.
        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context, post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [TransportMock.STARTTLS(ssl_context,
                                        post_handshake_callback)]))

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        """starttls() without a callback matches STARTTLS(ssl_context,
        None)."""
        self.t = TransportMock(self,
                               self.protocol,
                               with_starttls=True,
                               loop=self.loop)
        self.assertTrue(self.t.can_starttls())

        fut = asyncio.Future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test([TransportMock.STARTTLS(ssl_context, None)]))

    def test_exception_from_stimulus_bubbles_up(self):
        """An error raised while handling the stimulus escapes run_test as
        the same exception object."""
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            # Go through the shared helper so this run uses self.loop too.
            self._run_test(self.t, [],
                           stimulus=TransportMock.Receive(b"foobar"))

        self.assertIs(exc, ctx.exception)

    def tearDown(self):
        del self.t
        del self.loop
        del self.protocol
Пример #27
0
    def test_sm_works_correctly_with_invalid_payload(self):
        """An IQ with an unknown payload is answered with an error under SM.

        Opens an XMLStream over a TransportMock and receives stream features
        advertising ``urn:xmpp:sm:3``. It then enables resumable Stream
        Management and completes an ``<r/>``/``<a h="0"/>`` exchange. Next it
        injects an ``<iq type="get">`` carrying an unknown ``fnord`` payload.
        The expected result is a ``feature-not-implemented`` error reply,
        followed by an ``<r/>``/``<a h="1"/>`` exchange.
        """
        import aioxmpp.protocol
        import aioxmpp.stream

        version = (1, 0)

        # Passed as features_future; asserted to be done once the peer's
        # <stream:features/> has been received.
        fut = asyncio.Future()
        p = aioxmpp.protocol.XMLStream(
            to=TEST_PEER,
            sorted_attributes=True,
            features_future=fut)
        t = TransportMock(self, p)
        s = aioxmpp.stream.StanzaStream(TEST_FROM.bare())

        # Stream header exchange; the peer's features advertise SM v3.
        run_coroutine(t.run_test(
            [
                TransportMock.Write(
                    STREAM_HEADER,
                    response=[
                        TransportMock.Receive(
                            PEER_STREAM_HEADER_TEMPLATE.format(
                                minor=version[1],
                                major=version[0]).encode("utf-8")),
                        TransportMock.Receive(
                            b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                            b"</stream:features>"
                        )
                    ]
                ),
            ],
            partial=True
        ))

        self.assertEqual(p.state, aioxmpp.protocol.State.OPEN)

        self.assertTrue(fut.done())

        # NOTE(review): presumably these short intervals are what make the
        # client emit the unsolicited <r/> requests expected below — confirm.
        s.ping_interval = timedelta(seconds=0.25)
        s.ping_opportunistic_interval = timedelta(seconds=0.25)

        s.start(p)
        # Enable resumable SM; the peer answers <enabled/> with id "foo".
        run_coroutine_with_peer(
            s.start_sm(),
            t.run_test(
                [
                    TransportMock.Write(
                        b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                        response=[
                            TransportMock.Receive(
                                b'<enabled xmlns="urn:xmpp:sm:3" '
                                b'resume="true" id="foo"/>'
                            )
                        ]
                    )
                ],
                partial=True
            )
        )

        self.assertTrue(s.sm_enabled)
        self.assertEqual(s.sm_id, "foo")
        self.assertTrue(s.sm_resumable)

        # Client sends <r/>; the peer acks with h="0" and sends its own <r/>,
        # which the client must answer with <a h="0"/>.
        run_coroutine(
            t.run_test(
                [
                    TransportMock.Write(
                        b'<r xmlns="urn:xmpp:sm:3"/>',
                        response=[
                            TransportMock.Receive(
                                b'<a xmlns="urn:xmpp:sm:3" h="0"/>',
                            ),
                            TransportMock.Receive(
                                b'<r xmlns="urn:xmpp:sm:3"/>',
                            )
                        ]
                    ),
                    TransportMock.Write(
                        b'<a xmlns="urn:xmpp:sm:3" h="0"/>'
                    )
                ],
                partial=True
            )
        )

        # The stimulus IQ with an unknown payload must be answered with a
        # feature-not-implemented error, after which the exchanged SM
        # counters are h="1".
        run_coroutine(
            t.run_test(
                [
                    TransportMock.Write(
                        b'<iq id="foo" type="error"><error type="cancel">'
                        b'<feature-not-implemented'
                        b' xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"/>'
                        b'</error></iq>'
                    ),
                    TransportMock.Write(
                        b'<r xmlns="urn:xmpp:sm:3"/>',
                        response=[
                            TransportMock.Receive(
                                b'<a xmlns="urn:xmpp:sm:3" h="1"/>',
                            ),
                            TransportMock.Receive(
                                b'<r xmlns="urn:xmpp:sm:3"/>',
                            )
                        ]
                    ),
                    TransportMock.Write(
                        b'<a xmlns="urn:xmpp:sm:3" h="1"/>'
                    )
                ],
                stimulus=[
                    TransportMock.Receive(
                        b'<iq type="get" id="foo">'
                        b'<payload xmlns="fnord"/>'
                        b'</iq>'
                    )
                ],
                partial=True
            )
        )
Пример #28
0
    def test_hard_deadtime_kills_stream(self):
        """An unanswered ping ends with an Abort and a timeout failure.

        Soft timeout and round-trip time are both set to 0.1s. After the
        client writes its ping IQ and nothing answers it, the transport must
        see an Abort and the stream must stop running. ``on_failure`` must
        then carry a ConnectionError whose text mentions "timeout".
        """
        import aioxmpp.protocol
        import aioxmpp.stream

        major, minor = 1, 0

        features = asyncio.Future()
        xmlstream = aioxmpp.protocol.XMLStream(to=TEST_PEER,
                                               sorted_attributes=True,
                                               features_future=features)
        transport = TransportMock(self, xmlstream)
        stream = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
        stream.soft_timeout = timedelta(seconds=0.1)
        stream.round_trip_time = timedelta(seconds=0.1)

        failure = stream.on_failure.future()

        peer_header = PEER_STREAM_HEADER_TEMPLATE.format(
            minor=minor, major=major).encode("utf-8")
        run_coroutine(transport.run_test(
            [
                TransportMock.Write(
                    STREAM_HEADER,
                    response=[TransportMock.Receive(peer_header)],
                ),
            ],
            partial=True,
        ))

        self.assertEqual(xmlstream.state, aioxmpp.protocol.State.OPEN)

        stream.start(xmlstream)

        real_iq = aioxmpp.IQ

        def iq_with_fixed_id(*args, **kwargs):
            # Pin the id so the expected ping bytes are predictable.
            iq = real_iq(*args, **kwargs)
            iq.id_ = "ping"
            return iq

        ping_request = (b'<iq id="ping" type="get">'
                        b'<ping xmlns="urn:xmpp:ping"/></iq>')
        with unittest.mock.patch("aioxmpp.stanza.IQ") as IQ:
            IQ.side_effect = iq_with_fixed_id
            run_coroutine(transport.run_test(
                [TransportMock.Write(ping_request)],
                partial=True,
            ))

        # The ping is never answered, so the script now expects an Abort.
        run_coroutine(transport.run_test([TransportMock.Abort()]))

        # Yield to the loop once so pending callbacks can run.
        run_coroutine(asyncio.sleep(0))

        self.assertFalse(stream.running)

        self.assertTrue(failure.done())
        error = failure.exception()
        self.assertIsInstance(error, ConnectionError)
        self.assertIn("timeout", str(error))
Пример #29
0
    def test_sm_works_correctly_with_invalid_payload(self):
        """An IQ with an unknown payload is answered with an error under SM.

        Opens an XMLStream over a TransportMock and receives stream features
        advertising ``urn:xmpp:sm:3``. It then enables resumable Stream
        Management and completes an ``<r/>``/``<a h="0"/>`` exchange. Next it
        injects an ``<iq type="get">`` carrying an unknown ``fnord`` payload.
        The expected result is a ``service-unavailable`` error reply,
        followed by an ``<r/>``/``<a h="1"/>`` exchange.
        """
        import aioxmpp.protocol
        import aioxmpp.stream

        version = (1, 0)

        # Passed as features_future; asserted to be done once the peer's
        # <stream:features/> has been received.
        fut = asyncio.Future()
        p = aioxmpp.protocol.XMLStream(to=TEST_PEER,
                                       sorted_attributes=True,
                                       features_future=fut)
        t = TransportMock(self, p)
        s = aioxmpp.stream.StanzaStream(TEST_FROM.bare())
        # NOTE(review): presumably this short timeout is what makes the
        # client emit the <r/> requests expected below — confirm.
        s.soft_timeout = timedelta(seconds=0.25)

        # Stream header exchange; the peer's features advertise SM v3.
        run_coroutine(
            t.run_test([
                TransportMock.Write(
                    STREAM_HEADER,
                    response=[
                        TransportMock.Receive(
                            PEER_STREAM_HEADER_TEMPLATE.format(
                                minor=version[1],
                                major=version[0]).encode("utf-8")),
                        TransportMock.Receive(
                            b"<stream:features><sm xmlns='urn:xmpp:sm:3'/>"
                            b"</stream:features>")
                    ]),
            ],
                       partial=True))

        self.assertEqual(p.state, aioxmpp.protocol.State.OPEN)

        self.assertTrue(fut.done())

        s.start(p)
        # Enable resumable SM; the peer answers <enabled/> with id "foo".
        run_coroutine_with_peer(
            s.start_sm(),
            t.run_test([
                TransportMock.Write(
                    b'<enable xmlns="urn:xmpp:sm:3" resume="true"/>',
                    response=[
                        TransportMock.Receive(
                            b'<enabled xmlns="urn:xmpp:sm:3" '
                            b'resume="true" id="foo"/>')
                    ])
            ],
                       partial=True))

        self.assertTrue(s.sm_enabled)
        self.assertEqual(s.sm_id, "foo")
        self.assertTrue(s.sm_resumable)

        # Client sends <r/>; the peer acks with h="0" and sends its own <r/>,
        # which the client must answer with <a h="0"/>.
        run_coroutine(
            t.run_test([
                TransportMock.Write(
                    b'<r xmlns="urn:xmpp:sm:3"/>',
                    response=[
                        TransportMock.Receive(
                            b'<a xmlns="urn:xmpp:sm:3" h="0"/>', ),
                        TransportMock.Receive(b'<r xmlns="urn:xmpp:sm:3"/>', )
                    ]),
                TransportMock.Write(b'<a xmlns="urn:xmpp:sm:3" h="0"/>')
            ],
                       partial=True))

        # The stimulus IQ with an unknown payload must be answered with a
        # service-unavailable error, after which the exchanged SM counters
        # are h="1".
        run_coroutine(
            t.run_test([
                TransportMock.Write(
                    b'<iq id="foo" type="error"><error type="cancel">'
                    b'<service-unavailable'
                    b' xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"/>'
                    b'</error></iq>'),
                TransportMock.Write(
                    b'<r xmlns="urn:xmpp:sm:3"/>',
                    response=[
                        TransportMock.Receive(
                            b'<a xmlns="urn:xmpp:sm:3" h="1"/>', ),
                        TransportMock.Receive(b'<r xmlns="urn:xmpp:sm:3"/>', )
                    ]),
                TransportMock.Write(b'<a xmlns="urn:xmpp:sm:3" h="1"/>')
            ],
                       stimulus=[
                           TransportMock.Receive(b'<iq type="get" id="foo">'
                                                 b'<payload xmlns="fnord"/>'
                                                 b'</iq>')
                       ],
                       partial=True))
Пример #30
0
 def setUp(self):
     self.protocol = make_protocol_mock()
     self.loop = asyncio.get_event_loop()
     self.t = TransportMock(self, self.protocol, loop=self.loop)