示例#1
0
 def setUp(self):
     super().setUp()
     self.loop = asyncio.new_event_loop()
     asyncio.set_event_loop(self.loop)
     self.protocol = WebSocketCommonProtocol(debug=True)
     self.transport = TransportMock()
     self.transport.connect(self.loop, self.protocol)
示例#2
0
 def setUp(self):
     super().setUp()
     self.loop = asyncio.new_event_loop()
     asyncio.set_event_loop(self.loop)
     # Disable pings to make it easier to test what frames are sent exactly.
     self.protocol = WebSocketCommonProtocol(ping_interval=None)
     self.transport = TransportMock()
     self.transport.setup_mock(self.loop, self.protocol)
示例#3
0
 def restart_protocol_with_keepalive_ping(self,
                                          ping_interval=3 * MS,
                                          ping_timeout=3 * MS):
     initial_protocol = self.protocol
     # copied from tearDown
     self.transport.close()
     self.loop.run_until_complete(self.protocol.close())
     # copied from setUp, but enables keepalive pings
     self.protocol = WebSocketCommonProtocol(ping_interval=ping_interval,
                                             ping_timeout=ping_timeout)
     self.transport = TransportMock()
     self.transport.setup_mock(self.loop, self.protocol)
     self.protocol.is_client = initial_protocol.is_client
     self.protocol.side = initial_protocol.side
示例#4
0
async def websocket(request: Request, ws: WebSocket) -> None:
    logger.info('WebSocket connected')
    app = request.app.app

    # send chat history
    chat_log = await app.redis.lrange(CHAT_LOG, 0, MAX_CHAT_LOG) or []
    chat_log = list(map(lambda x: json.loads(x), chat_log))

    await ws.send(
        make_server_action(WSServerActionType.REPLACE_CHAT_LOG,
                           {'log': chat_log}))

    async with app.subscribe_chat() as (pubsub_channel, pubsub_get):
        ws_receiver = asyncio.ensure_future(ws.recv())
        pubsub_receiver = asyncio.ensure_future(pubsub_get())

        try:
            while True:
                dones, pendings = await asyncio.wait(
                    [ws_receiver, pubsub_receiver],
                    timeout=WAIT_TIMEOUT,
                    return_when=asyncio.FIRST_COMPLETED)

                for done in dones:
                    if done is ws_receiver:
                        ws_data, ws_receiver = ws_receiver.result(
                        ), asyncio.ensure_future(ws.recv())
                        await app.redis.publish(
                            pubsub_channel.name, await
                            receive_ws_message(app, ws, json.loads(ws_data)))
                    elif done is pubsub_receiver:
                        pubsub_data, pubsub_receiver = pubsub_receiver.result(
                        ), asyncio.ensure_future(pubsub_get())
                        logger.debug('pubsub: %r', pubsub_data)
                        await ws.send(
                            make_server_action(
                                WSServerActionType.APPEND_CHAT_EVENT,
                                json.loads(pubsub_data)))

        finally:
            ws_receiver.cancel()
            pubsub_receiver.cancel()
示例#5
0
class CommonTests:
    """
    Mixin that defines most tests but doesn't inherit unittest.TestCase.

    Tests are run by the ServerTests and ClientTests subclasses.

    """
    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Disable pings to make it easier to test what frames are sent exactly.
        self.protocol = WebSocketCommonProtocol(ping_interval=None)
        self.transport = TransportMock()
        self.transport.setup_mock(self.loop, self.protocol)

    def tearDown(self):
        self.transport.close()
        self.loop.run_until_complete(self.protocol.close())
        self.loop.close()
        super().tearDown()

    # Utilities for writing tests.

    def run_loop_once(self):
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def make_drain_slow(self, delay=MS):
        # Process connection_made in order to initialize self.protocol.writer.
        self.run_loop_once()

        original_drain = self.protocol.writer.drain

        @asyncio.coroutine
        def delayed_drain():
            yield from asyncio.sleep(delay, loop=self.loop)
            yield from original_drain()

        self.protocol.writer.drain = delayed_drain

    close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close'))
    local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local'))
    remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote'))

    @property
    def ensure_future(self):
        return functools.partial(asyncio_ensure_future, loop=self.loop)

    def receive_frame(self, frame):
        """
        Make the protocol receive a frame.

        """
        writer = self.protocol.data_received
        mask = not self.protocol.is_client
        frame.write(writer, mask=mask)

    def receive_eof(self):
        """
        Make the protocol receive the end of the data stream.

        Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an
        actual transport would close itself after calling it. This function
        emulates that behavior.

        """
        self.protocol.eof_received()
        self.loop.call_soon(self.transport.close)

    def receive_eof_if_client(self):
        """
        Like receive_eof, but only if this is the client side.

        Since the server is supposed to initiate the termination of the TCP
        connection, this method helps making tests work for both sides.

        """
        if self.protocol.is_client:
            self.receive_eof()

    def close_connection(self, code=1000, reason='close'):
        """
        Execute a closing handshake.

        This puts the connection in the CLOSED state.

        """
        close_frame_data = serialize_close(code, reason)
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()
        # Trigger the closing handshake from the local side and complete it.
        self.loop.run_until_complete(self.protocol.close(code, reason))
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSED

    def half_close_connection_local(self, code=1000, reason='close'):
        """
        Start a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        The current implementation returns a task that must be awaited or
        canceled, else asyncio complains about destroying a pending task.

        """
        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the local endpoint.
        close_task = self.ensure_future(self.protocol.close(code, reason))
        self.run_loop_once()  # wait_for executes
        self.run_loop_once()  # write_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at 1ms intervals so the test can run
        # at each point even it goes back to the event loop several times.
        self.loop.call_later(MS, self.receive_frame,
                             Frame(True, OP_CLOSE, close_frame_data))
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

        # This task must be awaited or canceled by the caller.
        return close_task

    def half_close_connection_remote(self, code=1000, reason='close'):
        """
        Receive a closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        """
        # On the server side, websockets completes the closing handshake and
        # closes the TCP connection immediately. Yield to the event loop after
        # sending the close frame to run the test while the connection is in
        # the CLOSING state.
        if not self.protocol.is_client:
            self.make_drain_slow()

        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the remote endpoint.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.run_loop_once()  # read_frame executes
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

        assert self.protocol.state is State.CLOSING

        # Complete the closing sequence at 1ms intervals so the test can run
        # at each point even it goes back to the event loop several times.
        self.loop.call_later(2 * MS, self.receive_eof_if_client)

    def process_invalid_frames(self):
        """
        Make the protocol fail quickly after simulating invalid data.

        To achieve this, this function triggers the protocol's eof_received,
        which interrupts pending reads waiting for more data.

        """
        self.run_loop_once()
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.close_connection_task)

    def last_sent_frame(self):
        """
        Read the last frame sent to the transport.

        This method assumes that at most one frame was sent. It raises an
        AssertionError otherwise.

        """
        stream = asyncio.StreamReader(loop=self.loop)

        for (data, ), kw in self.transport.write.call_args_list:
            stream.feed_data(data)
        self.transport.write.call_args_list = []
        stream.feed_eof()

        if stream.at_eof():
            frame = None
        else:
            frame = self.loop.run_until_complete(
                Frame.read(stream.readexactly, mask=self.protocol.is_client))

        if not stream.at_eof():  # pragma: no cover
            data = self.loop.run_until_complete(stream.read())
            raise AssertionError("Trailing data found: {!r}".format(data))

        return frame

    def assertOneFrameSent(self, *args):
        self.assertEqual(self.last_sent_frame(), Frame(*args))

    def assertNoFrameSent(self):
        self.assertIsNone(self.last_sent_frame())

    def assertConnectionClosed(self, code, message):
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # A close frame was received.
        self.assertEqual(self.protocol.close_code, code)
        self.assertEqual(self.protocol.close_reason, message)

    def assertConnectionFailed(self, code, message):
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, State.CLOSED)
        # No close frame was received.
        self.assertEqual(self.protocol.close_code, 1006)
        self.assertEqual(self.protocol.close_reason, '')
        # A close frame was sent -- unless the connection was already lost.
        if code == 1006:
            self.assertNoFrameSent()
        else:
            self.assertOneFrameSent(True, OP_CLOSE,
                                    serialize_close(code, message))

    @contextlib.contextmanager
    def assertCompletesWithin(self, min_time, max_time):
        t0 = self.loop.time()
        yield
        t1 = self.loop.time()
        dt = t1 - t0
        self.assertGreaterEqual(dt, min_time,
                                "Too fast: {} < {}".format(dt, min_time))
        self.assertLess(dt, max_time,
                        "Too slow: {} >= {}".format(dt, max_time))

    # Test public attributes.

    def test_local_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.local_address, ('host', 4312))
        get_extra_info.assert_called_with('sockname', None)

    def test_local_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertEqual(self.protocol.local_address, None)
        finally:
            self.protocol.writer = _writer

    def test_remote_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info

        self.assertEqual(self.protocol.remote_address, ('host', 4312))
        get_extra_info.assert_called_with('peername', None)

    def test_remote_address_before_connection(self):
        # Emulate the situation before connection_open() runs.
        self.protocol.writer, _writer = None, self.protocol.writer

        try:
            self.assertEqual(self.protocol.remote_address, None)
        finally:
            self.protocol.writer = _writer

    def test_open(self):
        self.assertTrue(self.protocol.open)
        self.close_connection()
        self.assertFalse(self.protocol.open)

    def test_closed(self):
        self.assertFalse(self.protocol.closed)
        self.close_connection()
        self.assertTrue(self.protocol.closed)

    # Test the recv coroutine.

    def test_recv_text(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_recv_binary(self):
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_recv_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_recv_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_protocol_error(self):
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_recv_unicode_error(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1007, '')

    def test_recv_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_recv_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_recv_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_recv_other_error(self):
        @asyncio.coroutine
        def read_message():
            raise Exception("BOOM")

        self.protocol.read_message = read_message
        self.process_invalid_frames()
        self.assertConnectionFailed(1011, '')

    def test_recv_canceled(self):
        recv = self.ensure_future(self.protocol.recv())
        self.loop.call_soon(recv.cancel)
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(recv)

        # The next frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    # Test the send coroutine.

    def test_send_text(self):
        self.loop.run_until_complete(self.protocol.send('café'))
        self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8'))

    def test_send_binary(self):
        self.loop.run_until_complete(self.protocol.send(b'tea'))
        self.assertOneFrameSent(True, OP_BINARY, b'tea')

    def test_send_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(42))
        self.assertNoFrameSent()

    def test_send_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))

        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_send_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))

        self.assertNoFrameSent()

    def test_send_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))

        self.assertNoFrameSent()

    # Test the ping coroutine.

    def test_ping_default(self):
        self.loop.run_until_complete(self.protocol.ping())
        # With our testing tools, it's more convenient to extract the expected
        # ping data from the library's internals than from the frame sent.
        ping_data = next(iter(self.protocol.pings))
        self.assertIsInstance(ping_data, bytes)
        self.assertEqual(len(ping_data), 4)
        self.assertOneFrameSent(True, OP_PING, ping_data)

    def test_ping_text(self):
        self.loop.run_until_complete(self.protocol.ping('café'))
        self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8'))

    def test_ping_binary(self):
        self.loop.run_until_complete(self.protocol.ping(b'tea'))
        self.assertOneFrameSent(True, OP_PING, b'tea')

    def test_ping_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.ping(42))
        self.assertNoFrameSent()

    def test_ping_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())

        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_ping_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())

        self.assertNoFrameSent()

    def test_ping_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())

        self.assertNoFrameSent()

    # Test the pong coroutine.

    def test_pong_default(self):
        self.loop.run_until_complete(self.protocol.pong())
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_pong_text(self):
        self.loop.run_until_complete(self.protocol.pong('café'))
        self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8'))

    def test_pong_binary(self):
        self.loop.run_until_complete(self.protocol.pong(b'tea'))
        self.assertOneFrameSent(True, OP_PONG, b'tea')

    def test_pong_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.pong(42))
        self.assertNoFrameSent()

    def test_pong_on_closing_connection_local(self):
        close_task = self.half_close_connection_local()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())

        self.assertNoFrameSent()

        self.loop.run_until_complete(close_task)  # cleanup

    def test_pong_on_closing_connection_remote(self):
        self.half_close_connection_remote()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())

        self.assertNoFrameSent()

    def test_pong_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())

        self.assertNoFrameSent()

    # Test the protocol's logic for acknowledging pings with pongs.

    def test_answer_ping(self):
        self.receive_frame(Frame(True, OP_PING, b'test'))
        self.run_loop_once()
        self.assertOneFrameSent(True, OP_PONG, b'test')

    def test_ignore_pong(self):
        self.receive_frame(Frame(True, OP_PONG, b'test'))
        self.run_loop_once()
        self.assertNoFrameSent()

    def test_acknowledge_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        self.assertFalse(ping.done())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.done())

    def test_cancel_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        # Remove the frame from the buffer, else close_connection() complains.
        self.last_sent_frame()
        self.assertFalse(ping.cancelled())
        self.close_connection()
        self.assertTrue(ping.cancelled())

    def test_acknowledge_previous_pings(self):
        pings = [(self.loop.run_until_complete(self.protocol.ping()),
                  self.last_sent_frame()) for i in range(3)]
        # Unsolicited pong doesn't acknowledge pings
        self.receive_frame(Frame(True, OP_PONG, b''))
        self.run_loop_once()
        self.run_loop_once()
        self.assertFalse(pings[0][0].done())
        self.assertFalse(pings[1][0].done())
        self.assertFalse(pings[2][0].done())
        # Pong acknowledges all previous pings
        self.receive_frame(Frame(True, OP_PONG, pings[1][1].data))
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(pings[0][0].done())
        self.assertTrue(pings[1][0].done())
        self.assertFalse(pings[2][0].done())

    def test_canceled_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        ping.cancel()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.run_loop_once()
        self.run_loop_once()
        self.assertTrue(ping.cancelled())

    def test_duplicate_ping(self):
        self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertOneFrameSent(True, OP_PING, b'foobar')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertNoFrameSent()

    # Test the protocol's logic for rebuilding fragmented messages.

    def test_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_fragmented_binary(self):
        self.receive_frame(Frame(False, OP_BINARY, b't'))
        self.receive_frame(Frame(False, OP_CONT, b'e'))
        self.receive_frame(Frame(True, OP_CONT, b'a'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_fragmented_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        self.process_invalid_frames()
        self.assertConnectionFailed(1009, '')

    def test_fragmented_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_fragmented_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_control_frame_within_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_PING, b''))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_unterminated_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        # Missing the second part of the fragmented frame.
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        self.process_invalid_frames()
        self.assertConnectionFailed(1002, '')

    def test_close_handshake_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CLOSE, b''))
        self.process_invalid_frames()
        # The RFC may have overlooked this case: it says that control frames
        # can be interjected in the middle of a fragmented message and that a
        # close frame must be echoed. Even though there's an unterminated
        # message, technically, the closing handshake was successful.
        self.assertConnectionClosed(1005, '')

    def test_connection_close_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionFailed(1006, '')

    # Test miscellaneous code paths to ensure full coverage.

    def test_connection_lost(self):
        # Test calling connection_lost without going through close_connection.
        self.protocol.connection_lost(None)

        self.assertConnectionFailed(1006, '')

    def test_ensure_open_before_opening_handshake(self):
        # Simulate a bug by forcibly reverting the protocol state.
        self.protocol.state = State.CONNECTING

        with self.assertRaises(InvalidState):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_ensure_open_during_unclean_close(self):
        # Process connection_made in order to start transfer_data_task.
        self.run_loop_once()

        # Ensure the test terminates quickly.
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Simulate the case when close() times out sending a close frame.
        self.protocol.fail_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_legacy_recv(self):
        # By default legacy_recv in disabled.
        self.assertEqual(self.protocol.legacy_recv, False)

        self.close_connection()

        # Enable legacy_recv.
        self.protocol.legacy_recv = True

        # Now recv() returns None instead of raising ConnectionClosed.
        self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))

    def test_connection_closed_attributes(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed) as context:
            self.loop.run_until_complete(self.protocol.recv())

        connection_closed_exc = context.exception
        self.assertEqual(connection_closed_exc.code, 1000)
        self.assertEqual(connection_closed_exc.reason, 'close')

    # Test the protocol logic for sending keepalive pings.

    def restart_protocol_with_keepalive_ping(self,
                                             ping_interval=3 * MS,
                                             ping_timeout=3 * MS):
        initial_protocol = self.protocol
        # copied from tearDown
        self.transport.close()
        self.loop.run_until_complete(self.protocol.close())
        # copied from setUp, but enables keepalive pings
        self.protocol = WebSocketCommonProtocol(ping_interval=ping_interval,
                                                ping_timeout=ping_timeout)
        self.transport = TransportMock()
        self.transport.setup_mock(self.loop, self.protocol)
        self.protocol.is_client = initial_protocol.is_client
        self.protocol.side = initial_protocol.side

    def test_keepalive_ping(self):
        self.restart_protocol_with_keepalive_ping()

        # Ping is sent at 3ms and acknowledged at 4ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_1, = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)
        self.receive_frame(Frame(True, OP_PONG, ping_1))

        # Next ping is sent at 7ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_2, = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_2)

        # The keepalive ping task goes on.
        self.assertFalse(self.protocol.keepalive_ping_task.done())

    def test_keepalive_ping_not_acknowledged_closes_connection(self):
        self.restart_protocol_with_keepalive_ping()

        # Ping is sent at 3ms and not acknowleged.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_1, = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)

        # Connection is closed at 6ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, ''))

        # The keepalive ping task is complete.
        self.assertEqual(self.protocol.keepalive_ping_task.result(), None)

    def test_keepalive_ping_stops_when_connection_closing(self):
        self.restart_protocol_with_keepalive_ping()
        close_task = self.half_close_connection_local()

        # No ping sent at 3ms because the closing handshake is in progress.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertNoFrameSent()

        # The keepalive ping task terminated.
        self.assertTrue(self.protocol.keepalive_ping_task.cancelled())

        self.loop.run_until_complete(close_task)  # cleanup

    def test_keepalive_ping_stops_when_connection_closed(self):
        self.restart_protocol_with_keepalive_ping()
        self.close_connection()

        # The keepalive ping task terminated.
        self.assertTrue(self.protocol.keepalive_ping_task.cancelled())

    def test_keepalive_ping_with_no_ping_interval(self):
        self.restart_protocol_with_keepalive_ping(ping_interval=None)

        # No ping is sent at 3ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        self.assertNoFrameSent()

    def test_keepalive_ping_with_no_ping_timeout(self):
        self.restart_protocol_with_keepalive_ping(ping_timeout=None)

        # Ping is sent at 3ms and not acknowleged.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_1, = tuple(self.protocol.pings)
        self.assertOneFrameSent(True, OP_PING, ping_1)

        # Next ping is sent at 7ms anyway.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))
        ping_1_again, ping_2 = tuple(self.protocol.pings)
        self.assertEqual(ping_1, ping_1_again)
        self.assertOneFrameSent(True, OP_PING, ping_2)

        # The keepalive ping task goes on.
        self.assertFalse(self.protocol.keepalive_ping_task.done())

    def test_keepalive_ping_unexpected_error(self):
        self.restart_protocol_with_keepalive_ping()

        @asyncio.coroutine
        def ping():
            raise Exception("BOOM")

        self.protocol.ping = ping

        # The keepalive ping task fails when sending a ping at 3ms.
        self.loop.run_until_complete(asyncio.sleep(4 * MS))

        # The keepalive ping task is complete.
        # It logs and swallows the exception.
        self.assertEqual(self.protocol.keepalive_ping_task.result(), None)

    # Test the protocol logic for closing the connection.

    def test_local_close(self):
        # Emulate how the remote endpoint answers the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Run the closing handshake.
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_remote_close(self):
        # Emulate how the remote endpoint initiates the closing handshake.
        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        # Wait for some data in order to process the handshake.
        # After recv() raises ConnectionClosed, the connection is closed.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_simultaneous_close(self):
        # Receive the incoming close frame right after self.protocol.close()
        # starts executing. This reproduces the error described in:
        # https://github.com/aaugustin/websockets/issues/339
        self.loop.call_soon(self.receive_frame, self.remote_close)
        self.loop.call_soon(self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='local'))

        self.assertConnectionClosed(1000, 'remote')
        # The current implementation sends a close frame in response to the
        # close frame received from the remote end. It skips the close frame
        # that should be sent as a result of calling close().
        self.assertOneFrameSent(*self.remote_close)

    def test_close_preserves_incoming_frames(self):
        self.receive_frame(Frame(True, OP_TEXT, b'hello'))

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        next_message = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(next_message, 'hello')

    def test_close_protocol_error(self):
        invalid_close_frame = Frame(True, OP_CLOSE, b'\x00')
        self.receive_frame(invalid_close_frame)
        self.receive_eof_if_client()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1002, '')

    def test_close_connection_lost(self):
        self.receive_eof()
        self.run_loop_once()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionFailed(1006, '')

    def test_local_close_during_recv(self):
        recv = self.ensure_future(self.protocol.recv())

        self.loop.call_later(MS, self.receive_frame, self.close_frame)
        self.loop.call_later(MS, self.receive_eof_if_client)

        self.loop.run_until_complete(self.protocol.close(reason='close'))

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(recv)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_remote_close_during_recv because it would be identical
    # to test_remote_close.

    def test_remote_close_during_send(self):
        self.make_drain_slow()
        send = self.ensure_future(self.protocol.send('hello'))

        self.receive_frame(self.close_frame)
        self.receive_eof()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(send)

        self.assertConnectionClosed(1000, 'close')
示例#6
0
class CommonTests:
    """
    Mixin that defines most tests but doesn't inherit unittest.TestCase.

    Tests are run by the ServerTests and ClientTests subclasses.

    """
    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.protocol = WebSocketCommonProtocol(debug=True)
        self.transport = TransportMock()
        self.transport.connect(self.loop, self.protocol)

    def tearDown(self):
        self.loop.run_until_complete(
            self.protocol.close_connection(force=True))
        self.loop.close()
        super().tearDown()

    # Utilities for writing tests.

    def run_loop_once(self):
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def make_drain_slow(self):
        # Process connection_made in order to initialize self.protocol.writer.
        self.run_loop_once()

        original_drain = self.protocol.writer.drain

        @asyncio.coroutine
        def delayed_drain():
            yield from asyncio.sleep(3 * MS, loop=self.loop)
            yield from original_drain()

        self.protocol.writer.drain = delayed_drain

    close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close'))
    local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local'))
    remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote'))

    @property
    def async (self):
        return functools.partial(asyncio_ensure_future, loop=self.loop)

    def receive_frame(self, frame):
        """
        Make the protocol receive a frame.

        """
        writer = self.protocol.data_received
        mask = not self.protocol.is_client
        self.loop.call_soon(write_frame, frame, writer, mask)

    def receive_eof(self):
        """
        Make the protocol receive the end of stream.

        WebSocketCommonProtocol.eof_received returns None — it is inherited
        from StreamReaderProtocol. (Returning True wouldn't work on secure
        connections anyway.) As a consequence, actual transports close
        themselves after calling it.

        To emulate this behavior, this function closes the transport just
        after calling the protocol's eof_received. Closing the transport has
        the side-effect calling the protocol's connection_lost.

        """
        self.loop.call_soon(self.protocol.eof_received)
        self.loop.call_soon(self.loop.call_soon, self.transport.close)

    def receive_eof_if_client(self):
        """
        Like receive_eof, but only if this is the client side.

        Since the server is supposed to initiate the termination of the TCP
        connection, this method helps making tests work for both sides.

        """
        if self.protocol.is_client:
            self.receive_eof()

    def close_connection(self, code=1000, reason='close'):
        """
        Close the connection with a standard closing handshake.

        This puts the connection in the CLOSED state.

        """
        close_frame_data = serialize_close(code, reason)
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()
        # Trigger the closing handshake from the local side and complete it.
        self.loop.run_until_complete(self.protocol.close(code, reason))
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)

    def close_connection_partial(self, code=1000, reason='close'):
        """
        Initiate a standard closing handshake but do not complete it.

        The main difference with `close_connection` is that the connection is
        left in the CLOSING state until the event loop runs again.

        """
        close_frame_data = serialize_close(code, reason)
        # Trigger the closing handshake from the local side.
        self. async (self.protocol.close(code, reason))
        self.run_loop_once()
        # Empty the outgoing data stream so we can make assertions later on.
        self.assertOneFrameSent(True, OP_CLOSE, close_frame_data)
        # Prepare the response to the closing handshake from the remote side.
        self.receive_frame(Frame(True, OP_CLOSE, close_frame_data))
        self.receive_eof_if_client()

    def process_invalid_frames(self):
        """
        Make the protocol fail quickly after simulating invalid data.

        To achieve this, this function triggers the protocol's eof_received,
        which interrupts pending reads waiting for more data. It delays this
        operation with call_later because the protocol must start processing
        frames first. Otherwise it will see a closed connection and no data.

        """
        self.loop.call_later(MS, self.receive_eof)
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def process_control_frames(self):
        """
        Process control frames received by the protocol.

        To ensure that recv completes quickly, receive an additional dummy
        frame, which recv() will drop.

        """
        self.receive_frame(Frame(True, OP_TEXT, b''))
        next_message = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(next_message, '')

    def last_sent_frame(self):
        """
        Read the last frame sent to the transport.

        This method assumes that at most one frame was sent. It raises an
        AssertionError otherwise.

        """
        stream = asyncio.StreamReader(loop=self.loop)

        for (data, ), kw in self.transport.write.call_args_list:
            stream.feed_data(data)
        self.transport.write.call_args_list = []
        stream.feed_eof()

        if stream.at_eof():
            frame = None
        else:
            frame = self.loop.run_until_complete(
                read_frame(stream.readexactly, self.protocol.is_client))

        if not stream.at_eof():  # pragma: no cover
            data = self.loop.run_until_complete(stream.read())
            raise AssertionError("Trailing data found: {!r}".format(data))

        return frame

    def assertOneFrameSent(self, fin, opcode, data):
        self.assertEqual(self.last_sent_frame(), Frame(fin, opcode, data))

    def assertNoFrameSent(self):
        self.assertIsNone(self.last_sent_frame())

    def assertConnectionClosed(self, code, message):
        # The following line guarantees that connection_lost was called.
        self.assertEqual(self.protocol.state, CLOSED)
        self.assertEqual(self.protocol.close_code, code)
        self.assertEqual(self.protocol.close_reason, message)

    @contextlib.contextmanager
    def assertCompletesWithin(self, min_time, max_time):
        t0 = self.loop.time()
        yield
        t1 = self.loop.time()
        dt = t1 - t0
        self.assertGreaterEqual(dt, min_time,
                                "Too fast: {} < {}".format(dt, min_time))
        self.assertLess(dt, max_time,
                        "Too slow: {} >= {}".format(dt, max_time))

    # Test public attributes.

    def test_local_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info
        # The connection isn't established yet.
        self.assertEqual(self.protocol.local_address, None)
        self.run_loop_once()
        # The connection is established.
        self.assertEqual(self.protocol.local_address, ('host', 4312))
        get_extra_info.assert_called_with('sockname', None)

    def test_remote_address(self):
        get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
        self.transport.get_extra_info = get_extra_info
        # The connection isn't established yet.
        self.assertEqual(self.protocol.remote_address, None)
        self.run_loop_once()
        # The connection is established.
        self.assertEqual(self.protocol.remote_address, ('host', 4312))
        get_extra_info.assert_called_with('peername', None)

    def test_open(self):
        self.assertTrue(self.protocol.open)
        self.close_connection()
        self.assertFalse(self.protocol.open)

    def test_state_name(self):
        self.assertEqual(self.protocol.state_name, 'OPEN')
        self.close_connection()
        self.assertEqual(self.protocol.state_name, 'CLOSED')

    # Test the recv coroutine.

    def test_recv_text(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_recv_binary(self):
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_recv_on_closing_connection(self):
        self.close_connection_partial()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

    def test_recv_protocol_error(self):
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionClosed(1002, '')

    def test_recv_unicode_error(self):
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1')))
        self.process_invalid_frames()
        self.assertConnectionClosed(1007, '')

    def test_recv_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        self.process_invalid_frames()
        self.assertConnectionClosed(1009, '')

    def test_recv_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        self.process_invalid_frames()
        self.assertConnectionClosed(1009, '')

    def test_recv_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_recv_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_recv_other_error(self):
        @asyncio.coroutine
        def read_message():
            raise Exception("BOOM")

        self.protocol.read_message = read_message
        self.process_invalid_frames()
        with self.assertRaises(Exception):
            self.loop.run_until_complete(self.protocol.worker_task)
        self.assertConnectionClosed(1011, '')

    def test_recv_cancelled(self):
        recv = self. async (self.protocol.recv())
        self.loop.call_soon(recv.cancel)
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(recv)

        # The next frame doesn't disappear in a vacuum (it used to).
        self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    # Test the send coroutine.

    def test_send_text(self):
        self.loop.run_until_complete(self.protocol.send('café'))
        self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8'))

    def test_send_binary(self):
        self.loop.run_until_complete(self.protocol.send(b'tea'))
        self.assertOneFrameSent(True, OP_BINARY, b'tea')

    def test_send_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.send(42))
        self.assertNoFrameSent()

    def test_send_on_closing_connection(self):
        self.close_connection_partial()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    def test_send_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.send('foobar'))
        self.assertNoFrameSent()

    # Test the ping coroutine.

    def test_ping_default(self):
        self.loop.run_until_complete(self.protocol.ping())
        # With our testing tools, it's more convenient to extract the expected
        # ping data from the library's internals than from the frame sent.
        ping_data = next(iter(self.protocol.pings))
        self.assertIsInstance(ping_data, bytes)
        self.assertEqual(len(ping_data), 4)
        self.assertOneFrameSent(True, OP_PING, ping_data)

    def test_ping_text(self):
        self.loop.run_until_complete(self.protocol.ping('café'))
        self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8'))

    def test_ping_binary(self):
        self.loop.run_until_complete(self.protocol.ping(b'tea'))
        self.assertOneFrameSent(True, OP_PING, b'tea')

    def test_ping_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.ping(42))
        self.assertNoFrameSent()

    def test_ping_on_closing_connection(self):
        self.close_connection_partial()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    def test_ping_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.ping())
        self.assertNoFrameSent()

    # Test the pong coroutine.

    def test_pong_default(self):
        self.loop.run_until_complete(self.protocol.pong())
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_pong_text(self):
        self.loop.run_until_complete(self.protocol.pong('café'))
        self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8'))

    def test_pong_binary(self):
        self.loop.run_until_complete(self.protocol.pong(b'tea'))
        self.assertOneFrameSent(True, OP_PONG, b'tea')

    def test_pong_type_error(self):
        with self.assertRaises(TypeError):
            self.loop.run_until_complete(self.protocol.pong(42))
        self.assertNoFrameSent()

    def test_pong_on_closing_connection(self):
        self.close_connection_partial()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    def test_pong_on_closed_connection(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.pong())
        self.assertNoFrameSent()

    # Test the protocol's logic for acknowledging pings with pongs.

    def test_answer_ping(self):
        self.receive_frame(Frame(True, OP_PING, b'test'))
        self.process_control_frames()
        self.assertOneFrameSent(True, OP_PONG, b'test')

    def test_ignore_pong(self):
        self.receive_frame(Frame(True, OP_PONG, b'test'))
        self.process_control_frames()
        self.assertNoFrameSent()

    def test_acknowledge_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        self.assertFalse(ping.done())
        ping_frame = self.last_sent_frame()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.process_control_frames()
        self.assertTrue(ping.done())

    def test_acknowledge_previous_pings(self):
        pings = [(
            self.loop.run_until_complete(self.protocol.ping()),
            self.last_sent_frame(),
        ) for i in range(3)]
        # Unsolicited pong doesn't acknowledge pings
        self.receive_frame(Frame(True, OP_PONG, b''))
        self.process_control_frames()
        self.assertFalse(pings[0][0].done())
        self.assertFalse(pings[1][0].done())
        self.assertFalse(pings[2][0].done())
        # Pong acknowledges all previous pings
        self.receive_frame(Frame(True, OP_PONG, pings[1][1].data))
        self.process_control_frames()
        self.assertTrue(pings[0][0].done())
        self.assertTrue(pings[1][0].done())
        self.assertFalse(pings[2][0].done())

    def test_cancel_ping(self):
        ping = self.loop.run_until_complete(self.protocol.ping())
        ping_frame = self.last_sent_frame()
        ping.cancel()
        pong_frame = Frame(True, OP_PONG, ping_frame.data)
        self.receive_frame(pong_frame)
        self.process_control_frames()
        self.assertTrue(ping.cancelled())

    def test_duplicate_ping(self):
        self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertOneFrameSent(True, OP_PING, b'foobar')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.protocol.ping(b'foobar'))
        self.assertNoFrameSent()

    # Test the protocol's logic for rebuilding fragmented messages.

    def test_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')

    def test_fragmented_binary(self):
        self.receive_frame(Frame(False, OP_BINARY, b't'))
        self.receive_frame(Frame(False, OP_CONT, b'e'))
        self.receive_frame(Frame(True, OP_CONT, b'a'))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea')

    def test_fragmented_text_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        self.process_invalid_frames()
        self.assertConnectionClosed(1009, '')

    def test_fragmented_binary_payload_too_big(self):
        self.protocol.max_size = 1024
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        self.process_invalid_frames()
        self.assertConnectionClosed(1009, '')

    def test_fragmented_text_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100))
        self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café' * 205)

    def test_fragmented_binary_no_max_size(self):
        self.protocol.max_size = None  # for test coverage
        self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171))
        self.receive_frame(Frame(True, OP_CONT, b'tea' * 171))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, b'tea' * 342)

    def test_control_frame_within_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_PING, b''))
        self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8')))
        data = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(data, 'café')
        self.assertOneFrameSent(True, OP_PONG, b'')

    def test_unterminated_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        # Missing the second part of the fragmented frame.
        self.receive_frame(Frame(True, OP_BINARY, b'tea'))
        self.process_invalid_frames()
        self.assertConnectionClosed(1002, '')

    def test_close_handshake_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.receive_frame(Frame(True, OP_CLOSE, b''))
        self.process_invalid_frames()
        self.assertConnectionClosed(1005, '')

    def test_connection_close_in_fragmented_text(self):
        self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8')))
        self.process_invalid_frames()
        self.assertConnectionClosed(1006, '')

    # Test miscellaneous code paths to ensure full coverage.

    def test_connection_lost(self):
        # Test calling connection_lost without going through close_connection.
        self.protocol.connection_lost(None)

        self.assertConnectionClosed(1006, '')

    def test_ensure_connection_before_opening_handshake(self):
        self.protocol.state = CONNECTING

        with self.assertRaises(InvalidState):
            self.loop.run_until_complete(self.protocol.ensure_open())

    def test_legacy_recv(self):
        # By default legacy_recv in disabled.
        self.assertEqual(self.protocol.legacy_recv, False)

        self.close_connection()

        # Enable legacy_recv.
        self.protocol.legacy_recv = True

        # Now recv() returns None instead of raising ConnectionClosed.
        self.assertIsNone(self.loop.run_until_complete(self.protocol.recv()))

    def test_connection_closed_attributes(self):
        self.close_connection()

        with self.assertRaises(ConnectionClosed) as context:
            self.loop.run_until_complete(self.protocol.recv())

        connection_closed = context.exception
        self.assertEqual(connection_closed.code, 1000)
        self.assertEqual(connection_closed.reason, 'close')

    # Test the protocol logic for closing the connection.

    def test_local_close(self):
        # Emulate how the remote endpoint answers the closing handshake.
        self.receive_frame(self.close_frame)
        self.receive_eof_if_client()

        # Run the closing handshake.
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_remote_close(self):
        # Emulate how the remote endpoint initiates the closing handshake.
        self.receive_frame(self.close_frame)
        self.receive_eof_if_client()

        # Wait for some data in order to process the handshake.
        # After recv() raises ConnectionClosed, the connection is closed.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        # Closing the connection again is a no-op.
        self.loop.run_until_complete(self.protocol.close(reason='oh noes!'))

        self.assertConnectionClosed(1000, 'close')
        self.assertNoFrameSent()

    def test_simultaneous_close(self):
        self.receive_frame(self.remote_close)
        self.receive_eof_if_client()
        self.loop.run_until_complete(self.protocol.close(reason='local'))

        # The close code and reason are taken from the remote side because
        # that's presumably more useful that the values from the local side.
        self.assertConnectionClosed(1000, 'remote')
        self.assertOneFrameSent(*self.local_close)

    def test_close_preserves_incoming_frames(self):
        self.receive_frame(Frame(True, OP_TEXT, b'hello'))
        self.receive_frame(self.close_frame)
        self.receive_eof_if_client()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1000, 'close')
        self.assertOneFrameSent(*self.close_frame)

        next_message = self.loop.run_until_complete(self.protocol.recv())
        self.assertEqual(next_message, 'hello')

    def test_close_protocol_error(self):
        invalid_close_frame = Frame(True, OP_CLOSE, b'\x00')
        self.receive_frame(invalid_close_frame)
        self.receive_eof_if_client()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1002, '')

    def test_close_connection_lost(self):
        self.receive_eof()
        self.loop.run_until_complete(self.protocol.close(reason='close'))

        self.assertConnectionClosed(1006, '')

    def test_remote_close_race_with_failing_connection(self):
        self.make_drain_slow()

        # Fail the connection while answering a close frame from the client.
        self.loop.call_soon(self.receive_frame, self.remote_close)
        self.loop.call_later(MS, self. async, self.protocol.fail_connection())
        # The client expects the server to close the connection.
        # Simulate it instead of waiting for the connection timeout.
        self.loop.call_later(MS, self.receive_eof_if_client)

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.protocol.recv())

        # The closing handshake was completed by fail_connection.
        self.assertConnectionClosed(1011, '')
        self.assertOneFrameSent(*self.remote_close)

    def test_local_close_during_recv(self):
        recv = self. async (self.protocol.recv())

        self.receive_frame(self.close_frame)
        self.receive_eof_if_client()

        self.loop.run_until_complete(self.protocol.close(reason='close'))

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(recv)

        self.assertConnectionClosed(1000, 'close')

    # There is no test_remote_close_during_recv because it would be identical
    # to test_remote_close.

    def test_remote_close_during_send(self):
        self.make_drain_slow()
        send = self. async (self.protocol.send('hello'))

        self.receive_frame(self.close_frame)
        self.receive_eof()

        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(send)

        self.assertConnectionClosed(1006, '')