コード例 #1
0
class TestTransportMock(unittest.TestCase):
    """Exercise ``TransportMock`` against scripted action sequences.

    Every test installs callbacks on a protocol mock, drives the transport
    through ``run_test`` with the list of actions it is expected to see and
    then checks either the calls recorded on the protocol or the error
    raised when the observed actions do not match the script.
    """

    def setUp(self):
        """Create the protocol mock, the loop handle and the transport."""
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def tearDown(self):
        """Drop the per-test fixtures created in ``setUp``."""
        del self.t
        del self.loop
        del self.protocol

    def _run_test(self, t, *args, **kwargs):
        """Run ``t.run_test(*args, **kwargs)`` to completion on ``self.loop``.

        Returns whatever ``run_coroutine`` returns; exceptions raised while
        the script runs propagate to the caller.
        """
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        """An empty script only makes and then loses the connection."""
        self._run_test(self.t, [])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_stimulus(self):
        """A bytes stimulus is delivered through ``data_received``."""
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_request_response(self):
        """A scripted response to a write reaches the protocol."""
        def data_received(data):
            # assertIn instead of a bare assert: still checked under -O and
            # consistent with test_asynchronous_request_response.
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_request_multiresponse(self):
        """A write may carry a list of responses, delivered one by one."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"bar", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(
                b"bar",
                response=[
                    TransportMock.Receive(b"bar"),
                    TransportMock.Receive(b"baz"),
                ],
            ),
            TransportMock.Write(b"baric"),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        """A close scheduled via ``call_soon`` is still checked."""
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_invalid_write(self):
        """Writing different data than scripted is reported as a mismatch."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                AssertionError,
                "mismatch of expected and written data"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_write(self):
        """A write with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(self.t, [])

    def test_catch_unexpected_close(self):
        """A close while a write is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_close(self):
        """A close with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [])

    def test_allow_asynchronous_partial_write(self):
        """Several scheduled partial writes satisfy one scripted write."""
        def connection_made(transport):
            for chunk in (b"f", b"o", b"o"):
                self.loop.call_soon(transport.write, chunk)

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_asynchronous_request_response(self):
        """Request/response also works when replies go through the loop."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_response_eof_received(self):
        """A ``ReceiveEof`` response triggers ``eof_received``."""
        def connection_made(transport):
            transport.close()

        # connection_made is replaced by a plain function, so it does not
        # show up in mock_calls below.
        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.ReceiveEof())],
        )
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_response_lose_connection(self):
        """``LoseConnection`` hands its argument to ``connection_lost``."""
        def connection_made(transport):
            transport.close()

        obj = object()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.LoseConnection(obj))],
        )
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [unittest.mock.call.connection_lost(obj)],
        )

    def test_writelines(self):
        """``writelines`` is matched against the concatenated payload."""
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foobar")])

    def test_can_write_eof(self):
        """The mock reports that it can write EOF."""
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        """A scripted abort accepts ``transport.abort()``."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Abort()])

    def test_write_eof(self):
        """A scripted write_eof accepts ``transport.write_eof()``."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.WriteEof()])

    def test_catch_surplus_write_eof(self):
        """write_eof with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [])

    def test_catch_unexpected_write_eof(self):
        """write_eof while an abort is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [TransportMock.Abort()])

    def test_catch_surplus_abort(self):
        """abort with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [])

    def test_catch_unexpected_abort(self):
        """abort while a write_eof is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [TransportMock.WriteEof()])

    def test_invalid_response(self):
        """A response that is not a valid action is a specification error."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                RuntimeError,
                "test specification incorrect"):
            self._run_test(self.t, [TransportMock.Write(b"foo", response=1)])

    def test_response_sequence(self):
        """Responses given as a list are delivered in list order."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        script = [
            TransportMock.Write(
                b"foo",
                response=[
                    TransportMock.Receive(b"foo"),
                    TransportMock.ReceiveEof(),
                ],
            ),
        ]
        self._run_test(self.t, script)

        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_clear_error_message(self):
        """Two writes that do not match the script raise AssertionError.

        NOTE(review): the name suggests the message text matters, but only
        the exception type is asserted here.
        """
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"baz")])

    def test_detached_response(self):
        """The response has not been delivered when ``write`` returns."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # Nothing may have been received yet at this point.
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo",
                                    response=TransportMock.Receive(b"bar")),
            ],
        )

    def test_no_response_conflict(self):
        """Back-to-back writes that both carry responses are accepted."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # The first response must not have arrived before the second
            # write is issued.
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo",
                                    response=TransportMock.Receive(b"baz")),
                TransportMock.Write(b"bar",
                                    response=TransportMock.Receive(b"baz")),
            ],
        )

    def test_partial(self):
        """After a ``partial=True`` run the transport takes a second script."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [TransportMock.Write(b"foo")], partial=True)

        # These actions happen between the two runs and are checked by the
        # second script.
        self.t.write_eof()
        self.t.close()

        self._run_test(
            self.t,
            [TransportMock.WriteEof(), TransportMock.Close()],
        )

    def test_no_starttls_by_default(self):
        """Without ``with_starttls`` the mock refuses STARTTLS."""
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            # Run on the same loop the transport was created with.
            run_coroutine(self.t.starttls(), loop=self.loop)

    def test_starttls(self):
        """STARTTLS runs the post-handshake callback with the transport."""
        self.t = TransportMock(
            self,
            self.protocol,
            with_starttls=True,
            loop=self.loop,
        )
        self.assertTrue(self.t.can_starttls())

        # Bind the future to the loop the transport uses instead of relying
        # on the implicit default loop.
        fut = self.loop.create_future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        post_handshake_callback = CoroutineMock()
        post_handshake_callback.return_value = None

        async def late_starttls():
            # Only start TLS once the connection has been made.
            await fut
            await self.t.starttls(ssl_context, post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [TransportMock.STARTTLS(ssl_context,
                                        post_handshake_callback)]),
        )

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        """STARTTLS also works when no post-handshake callback is given."""
        self.t = TransportMock(
            self,
            self.protocol,
            with_starttls=True,
            loop=self.loop,
        )
        self.assertTrue(self.t.can_starttls())

        fut = self.loop.create_future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test([TransportMock.STARTTLS(ssl_context, None)]),
        )

    def test_exception_from_stimulus_bubbles_up(self):
        """An exception raised by ``data_received`` reaches the caller."""
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            self._run_test(
                self.t,
                [],
                stimulus=TransportMock.Receive(b"foobar"),
            )

        # The very same exception object must come out again.
        self.assertIs(exc, ctx.exception)
コード例 #2
0
ファイル: test_testutils.py プロジェクト: horazont/aioxmpp
class TestTransportMock(unittest.TestCase):
    """Exercise ``TransportMock`` against scripted action sequences.

    Every test installs callbacks on a protocol mock, drives the transport
    through ``run_test`` with the list of actions it is expected to see and
    then checks either the calls recorded on the protocol or the error
    raised when the observed actions do not match the script.
    """

    def setUp(self):
        """Create the protocol mock, the loop handle and the transport."""
        self.protocol = make_protocol_mock()
        self.loop = asyncio.get_event_loop()
        self.t = TransportMock(self, self.protocol, loop=self.loop)

    def tearDown(self):
        """Drop the per-test fixtures created in ``setUp``."""
        del self.t
        del self.loop
        del self.protocol

    def _run_test(self, t, *args, **kwargs):
        """Run ``t.run_test(*args, **kwargs)`` to completion on ``self.loop``.

        Returns whatever ``run_coroutine`` returns; exceptions raised while
        the script runs propagate to the caller.
        """
        return run_coroutine(t.run_test(*args, **kwargs), loop=self.loop)

    def test_run_test(self):
        """An empty script only makes and then loses the connection."""
        self._run_test(self.t, [])
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_stimulus(self):
        """A bytes stimulus is delivered through ``data_received``."""
        self._run_test(self.t, [], stimulus=b"foo")
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.connection_made(self.t),
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_request_response(self):
        """A scripted response to a write reaches the protocol."""
        def data_received(data):
            # assertIn instead of a bare assert: still checked under -O and
            # consistent with test_asynchronous_request_response.
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_request_multiresponse(self):
        """A write may carry a list of responses, delivered one by one."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"bar", b"baz"})
            if data == b"foo":
                self.t.write(b"bar")
            elif data == b"bar":
                self.t.write(b"baric")
            elif data == b"baz":
                self.t.close()

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(
                b"bar",
                response=[
                    TransportMock.Receive(b"bar"),
                    TransportMock.Receive(b"baz"),
                ],
            ),
            TransportMock.Write(b"baric"),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_catch_asynchronous_invalid_action(self):
        """A close scheduled via ``call_soon`` is still checked."""
        def connection_made(transport):
            self.loop.call_soon(transport.close)

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_invalid_write(self):
        """Writing different data than scripted is reported as a mismatch."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                AssertionError,
                "mismatch of expected and written data"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_write(self):
        """A write with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.write(b"fnord")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write"):
            self._run_test(self.t, [])

    def test_catch_unexpected_close(self):
        """A close while a write is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_catch_surplus_close(self):
        """A close with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.close()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected close"):
            self._run_test(self.t, [])

    def test_allow_asynchronous_partial_write(self):
        """Several scheduled partial writes satisfy one scripted write."""
        def connection_made(transport):
            for chunk in (b"f", b"o", b"o"):
                self.loop.call_soon(transport.write, chunk)

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foo")])

    def test_asynchronous_request_response(self):
        """Request/response also works when replies go through the loop."""
        def data_received(data):
            self.assertIn(data, {b"foo", b"baz"})
            if data == b"foo":
                self.loop.call_soon(self.t.write, b"bar")
            elif data == b"baz":
                self.loop.call_soon(self.t.close)

        self.protocol.data_received = data_received
        script = [
            TransportMock.Write(b"bar",
                                response=TransportMock.Receive(b"baz")),
            TransportMock.Close(),
        ]
        self._run_test(self.t, script, stimulus=b"foo")

    def test_response_eof_received(self):
        """A ``ReceiveEof`` response triggers ``eof_received``."""
        def connection_made(transport):
            transport.close()

        # connection_made is replaced by a plain function, so it does not
        # show up in mock_calls below.
        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.ReceiveEof())],
        )
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_response_lose_connection(self):
        """``LoseConnection`` hands its argument to ``connection_lost``."""
        def connection_made(transport):
            transport.close()

        obj = object()

        self.protocol.connection_made = connection_made
        self._run_test(
            self.t,
            [TransportMock.Close(response=TransportMock.LoseConnection(obj))],
        )
        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [unittest.mock.call.connection_lost(obj)],
        )

    def test_writelines(self):
        """``writelines`` is matched against the concatenated payload."""
        def connection_made(transport):
            transport.writelines([b"foo", b"bar"])

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Write(b"foobar")])

    def test_can_write_eof(self):
        """The mock reports that it can write EOF."""
        self.assertTrue(self.t.can_write_eof())

    def test_abort(self):
        """A scripted abort accepts ``transport.abort()``."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.Abort()])

    def test_write_eof(self):
        """A scripted write_eof accepts ``transport.write_eof()``."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        self._run_test(self.t, [TransportMock.WriteEof()])

    def test_catch_surplus_write_eof(self):
        """write_eof with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [])

    def test_catch_unexpected_write_eof(self):
        """write_eof while an abort is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.write_eof()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected write_eof"):
            self._run_test(self.t, [TransportMock.Abort()])

    def test_catch_surplus_abort(self):
        """abort with an empty script is reported as unexpected."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [])

    def test_catch_unexpected_abort(self):
        """abort while a write_eof is scripted is reported as unexpected."""
        def connection_made(transport):
            transport.abort()

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(AssertionError, "unexpected abort"):
            self._run_test(self.t, [TransportMock.WriteEof()])

    def test_invalid_response(self):
        """A response that is not a valid action is a specification error."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        with self.assertRaisesRegex(
                RuntimeError,
                "test specification incorrect"):
            self._run_test(self.t, [TransportMock.Write(b"foo", response=1)])

    def test_response_sequence(self):
        """Responses given as a list are delivered in list order."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made
        script = [
            TransportMock.Write(
                b"foo",
                response=[
                    TransportMock.Receive(b"foo"),
                    TransportMock.ReceiveEof(),
                ],
            ),
        ]
        self._run_test(self.t, script)

        self.assertSequenceEqual(
            self.protocol.mock_calls,
            [
                unittest.mock.call.data_received(b"foo"),
                unittest.mock.call.eof_received(),
                unittest.mock.call.connection_lost(None),
            ],
        )

    def test_clear_error_message(self):
        """Two writes that do not match the script raise AssertionError.

        NOTE(review): the name suggests the message text matters, but only
        the exception type is asserted here.
        """
        def connection_made(transport):
            transport.write(b"foo")
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        with self.assertRaises(AssertionError):
            self._run_test(self.t, [TransportMock.Write(b"baz")])

    def test_detached_response(self):
        """The response has not been delivered when ``write`` returns."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # Nothing may have been received yet at this point.
            self.assertFalse(data)

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo",
                                    response=TransportMock.Receive(b"bar")),
            ],
        )

    def test_no_response_conflict(self):
        """Back-to-back writes that both carry responses are accepted."""
        data = []

        def data_received(blob):
            data.append(blob)

        def connection_made(transport):
            transport.write(b"foo")
            # The first response must not have arrived before the second
            # write is issued.
            self.assertFalse(data)
            transport.write(b"bar")

        self.protocol.connection_made = connection_made
        self.protocol.data_received = data_received

        self._run_test(
            self.t,
            [
                TransportMock.Write(b"foo",
                                    response=TransportMock.Receive(b"baz")),
                TransportMock.Write(b"bar",
                                    response=TransportMock.Receive(b"baz")),
            ],
        )

    def test_partial(self):
        """After a ``partial=True`` run the transport takes a second script."""
        def connection_made(transport):
            transport.write(b"foo")

        self.protocol.connection_made = connection_made

        self._run_test(self.t, [TransportMock.Write(b"foo")], partial=True)

        # These actions happen between the two runs and are checked by the
        # second script.
        self.t.write_eof()
        self.t.close()

        self._run_test(
            self.t,
            [TransportMock.WriteEof(), TransportMock.Close()],
        )

    def test_no_starttls_by_default(self):
        """Without ``with_starttls`` the mock refuses STARTTLS."""
        self.assertFalse(self.t.can_starttls())
        with self.assertRaises(RuntimeError):
            # Run on the same loop the transport was created with.
            run_coroutine(self.t.starttls(), loop=self.loop)

    def test_starttls(self):
        """STARTTLS calls the post-handshake callback with the transport."""
        self.t = TransportMock(
            self,
            self.protocol,
            with_starttls=True,
            loop=self.loop,
        )
        self.assertTrue(self.t.can_starttls())

        # Bind the future to the loop the transport uses instead of relying
        # on the implicit default loop.
        fut = self.loop.create_future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()
        post_handshake_callback = unittest.mock.Mock()
        # NOTE(review): the callback returns an empty list, presumably so
        # that a ``yield from`` on its result inside TransportMock.starttls
        # finishes immediately. If starttls awaits the callback instead,
        # this has to become an awaitable mock -- confirm against the
        # TransportMock implementation.
        post_handshake_callback.return_value = []

        # Native coroutine: the ``asyncio.coroutine`` decorator was removed
        # in Python 3.11.
        async def late_starttls():
            # Only start TLS once the connection has been made.
            await fut
            await self.t.starttls(ssl_context, post_handshake_callback)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test(
                [TransportMock.STARTTLS(ssl_context,
                                        post_handshake_callback)]),
        )

        post_handshake_callback.assert_called_once_with(self.t)

    def test_starttls_without_callback(self):
        """STARTTLS also works when no post-handshake callback is given."""
        self.t = TransportMock(
            self,
            self.protocol,
            with_starttls=True,
            loop=self.loop,
        )
        self.assertTrue(self.t.can_starttls())

        fut = self.loop.create_future()

        def connection_made(transport):
            fut.set_result(None)

        self.protocol.connection_made = connection_made

        ssl_context = unittest.mock.Mock()

        # Native coroutine: the ``asyncio.coroutine`` decorator was removed
        # in Python 3.11.
        async def late_starttls():
            await fut
            await self.t.starttls(ssl_context)

        run_coroutine_with_peer(
            late_starttls(),
            self.t.run_test([TransportMock.STARTTLS(ssl_context, None)]),
        )

    def test_exception_from_stimulus_bubbles_up(self):
        """An exception raised by ``data_received`` reaches the caller."""
        exc = ConnectionError("foobar")

        def data_received(data):
            raise exc

        self.protocol.data_received = data_received

        with self.assertRaises(ConnectionError) as ctx:
            self._run_test(
                self.t,
                [],
                stimulus=TransportMock.Receive(b"foobar"),
            )

        # The very same exception object must come out again.
        self.assertIs(exc, ctx.exception)