def assertInvalidResponseHeaders(self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw=='):
    """
    Yield freshly built, valid response headers for the caller to tamper
    with, then check that ``check_response`` rejects the modified headers.

    Presumably wrapped with ``contextlib.contextmanager`` where it is
    defined (the decorator is not visible here).

    :param exc_type: exception expected from ``check_response``; must be a
        subclass of ``InvalidHandshake``.
    :param key: handshake key passed to both ``build_response`` and
        ``check_response``.
    """
    response_headers = Headers()
    build_response(response_headers, key)
    # Hand the headers to the caller so it can make them invalid.
    yield response_headers
    # Guard against using this helper with an unrelated exception type.
    assert issubclass(exc_type, InvalidHandshake)
    with self.assertRaises(exc_type):
        check_response(response_headers, key)
async def connect(cls, client):
    """
    Request a gateway for ``client.room`` and open a websocket to it.

    :param client: An instance of tclib.TinychatClient.
    :type client: tclib.TinychatClient
    :return: the websocket (created via ``klass=cls``) after the join
        message has been sent, or None when no usable gateway was returned
        or the connection attempt timed out.
    :rtype: cls or None
    """
    ws = None
    log.info('requesting gateway')
    gateway = await connect_details(client.room)
    log.debug(f'gateway: {gateway}')
    # NOTE(review): presumably a usable gateway holds exactly the
    # 'endpoint' and 'token' entries used below; confirm against
    # connect_details().
    if len(gateway) != 2:
        log.warning(f'failed to get gateway: {gateway}')
        return ws

    try:
        ws = await asyncio.wait_for(
            websockets.connect(
                uri=gateway['endpoint'],
                origin='https://tinychat.com',
                subprotocols=['tc'],
                extra_headers=Headers(tc_headers()),
                loop=client.loop,
                klass=cls),
            timeout=60, loop=client.loop)
    except asyncio.TimeoutError:
        # check for connect timeout error.
        log.warning('timed out waiting for client connect.')
        # There is no websocket to join on. Falling through would call
        # ``None.join(...)`` and crash with AttributeError.
        return None

    # send the join message.
    await ws.join(client, gateway['token'])
    # assign the dispatcher.
    ws._dispatch = client.dispatch
    try:
        # make sure the websocket is in an open state.
        await ws.ensure_open()
    except websockets.exceptions.ConnectionClosed as e:
        log.warning(f'connection was closed: {e.code} {e.reason}')
    else:
        log.info(f'websocket connected to: {gateway["endpoint"]}')
    return ws
async def _listen_forever(self):
    """
    Keep a websocket to ``self.uri`` open and process incoming messages.

    The outer loop (re)connects whenever the connection fails or cannot
    be opened. The inner loop receives messages and passes each one to
    ``self._process_message``.

    After a receive timeout or a closed connection, the socket is probed
    with a ping bounded by ``self.ping_timeout``. If the probe also
    fails, the method waits ``self.sleep_time`` seconds and reconnects.
    Name-resolution and connection-refused errors also wait
    ``self.sleep_time`` seconds before retrying. Task cancellation is
    never swallowed.
    """
    while True:  # outer loop restarted every time the connection fails
        try:
            headers = Headers()
            headers['Authorization'] = f'Token {self.settings.session_key}'
            ssl_args = dict()
            if not self.settings.verify_certificate and 'wss:' in self.uri:
                # Certificate verification is explicitly disabled by the
                # settings. The argument-less SSLContext() is deprecated, so
                # build a client context and switch hostname checking off
                # before disabling verification (required in that order).
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                ssl_args['ssl'] = ssl_context
            async with websockets.connect(self.uri, extra_headers=headers, **ssl_args) as ws:
                self.logger.info(f'Websocket established with {self.uri}')
                self.logger.debug(
                    f'Ping response: {json.dumps(self.ping_data)}')
                await ws.send(str(AgoraWebsocketMessage(self.ping_data)))
                self.logger.info(
                    'App is running --> listening for server messages...')
                while True:  # listener loop
                    try:
                        msg = await ws.recv()
                    except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed) as e:
                        try:
                            self.logger.warning(
                                f'Connection probably closed (sending ping): {str(e)}'
                            )
                            pong = await ws.ping()
                            await asyncio.wait_for(
                                pong, timeout=self.ping_timeout)
                            continue
                        except asyncio.CancelledError:
                            # Cancellation must stop the listener; a bare
                            # except used to turn it into a reconnect.
                            raise
                        except Exception:
                            self.logger.warning(
                                f'Connection closed (trying again in {self.sleep_time}s): {str(e)}'
                            )
                            await asyncio.sleep(self.sleep_time)
                            break  # leave the listener loop and reconnect
                    await self._process_message(msg, ws)
        except socket.gaierror as e:
            self.logger.warning(f'Connection closed: {str(e)}')
            self.logger.warning('Trying to reconnect')
            # Back off so an unresolvable host does not spin in a tight
            # reconnect loop.
            await asyncio.sleep(self.sleep_time)
        except ConnectionRefusedError as e:
            self.logger.warning(f'Connection refused: {str(e)}')
            self.logger.warning('Trying to reconnect')
            # Same back-off as above for a server that refuses connections.
            await asyncio.sleep(self.sleep_time)
def __init__(self, addr: str = None, key: str = None, close_timeout: float = IO_TIMEOUT):
    """
    Set up a web socket client for the platform. No connection is opened here.

    Parameters
    ----------
    addr: str
        service address such as ws://host:port; ``self.ADDRESS`` is used
        when omitted.
    key: str
        API key identifying the caller. When given, it is stored in
        ``self.header`` as ``Authorization: Bearer <key>``; otherwise
        ``self.header`` is None.
    close_timeout: float
        stored as ``self.close_timeout``; defaults to ``IO_TIMEOUT``.
    """
    self.__addr = self.ADDRESS if addr is None else addr
    self.__key = key
    if key is None:
        self.header = None
    else:
        self.header = Headers(Authorization="Bearer " + key)
    self.close_timeout = close_timeout
    # flatbuffers builder created with an initial size argument of 1400.
    self.__builder = flatbuffers.Builder(1400)
    # No data handler yet; presumably registered later by the caller.
    self.__data_handler = None
    # Bounded queue holding at most 10 items.
    self.__queue = asyncio.Queue(10)
class ClientServerTests(unittest.TestCase):
    """
    End-to-end tests that run a websockets server (``serve``) and client
    (``connect``) against each other on a private event loop.

    NOTE(review): ``with_server`` / ``with_client`` are presumably
    decorators that start and stop the server/client around each test via
    ``start_server`` / ``start_client``; they are defined elsewhere.
    """

    # Plain ws:// by default. Presumably a subclass flips this to run the
    # suite over TLS; ``self.client_context`` (used in make_http_request)
    # is not defined in this class.
    secure = False

    def setUp(self):
        # Each test gets its own event loop, installed as the current one.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()

    def run_loop_once(self):
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def start_server(self, **kwds):
        # Disable compression by default in tests.
        kwds.setdefault('compression', None)
        # Disable pings by default in tests.
        kwds.setdefault('ping_interval', None)
        # Port 0 lets the OS pick a free port.
        start_server = serve(handler, 'localhost', 0, **kwds)
        self.server = self.loop.run_until_complete(start_server)

    def start_client(self, resource_name='/', user_info=None, **kwds):
        # Disable compression by default in tests.
        kwds.setdefault('compression', None)
        # Disable pings by default in tests.
        kwds.setdefault('ping_interval', None)
        # The URI scheme follows whether an ssl argument was supplied.
        secure = kwds.get('ssl') is not None
        server_uri = get_server_uri(self.server, secure, resource_name, user_info)
        start_client = connect(server_uri, **kwds)
        self.client = self.loop.run_until_complete(start_client)

    def stop_client(self):
        # Wait (at most 1 second) for the client's connection to finish closing.
        try:
            self.loop.run_until_complete(
                asyncio.wait_for(self.client.close_connection_task, timeout=1))
        except asyncio.TimeoutError:  # pragma: no cover
            self.fail("Client failed to stop")

    def stop_server(self):
        # Close the server and wait (at most 1 second) until it is fully closed.
        self.server.close()
        try:
            self.loop.run_until_complete(
                asyncio.wait_for(self.server.wait_closed(), timeout=1))
        except asyncio.TimeoutError:  # pragma: no cover
            self.fail("Server failed to stop")

    @contextlib.contextmanager
    def temp_server(self, **kwds):
        # Thin wrapper so tests can write ``with self.temp_server(...)``.
        with temp_test_server(self, **kwds):
            yield

    @contextlib.contextmanager
    def temp_client(self, *args, **kwds):
        # Thin wrapper so tests can write ``with self.temp_client(...)``.
        with temp_test_client(self, *args, **kwds):
            yield

    @with_server()
    @with_client()
    def test_basic(self):
        self.loop.run_until_complete(self.client.send("Hello!"))
        reply = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(reply, "Hello!")

    def test_server_close_while_client_connected(self):
        with self.temp_server(loop=self.loop):
            # This endpoint waits just a bit when the connection is canceled
            # in order to test that wait_closed() really waits for completion.
            self.start_client('/slow_stop')
        # The server is gone now, so the client must observe the closure.
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.client.recv())
        # Connection ends with 1001 going away.
        self.assertEqual(self.client.close_code, 1001)

    def test_explicit_event_loop(self):
        with self.temp_server(loop=self.loop):
            with self.temp_client(loop=self.loop):
                self.loop.run_until_complete(self.client.send("Hello!"))
                reply = self.loop.run_until_complete(self.client.recv())
                self.assertEqual(reply, "Hello!")

    # The way the legacy SSL implementation wraps sockets makes it extremely
    # hard to write a test for Python 3.4.
    @unittest.skipIf(sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+')
    @with_server()
    def test_explicit_socket(self):

        # Socket subclass that records whether recv()/send() were used, to
        # prove the client really talks over the socket passed via ``sock``.
        class TrackedSocket(socket.socket):
            def __init__(self, *args, **kwargs):
                self.used_for_read = False
                self.used_for_write = False
                super().__init__(*args, **kwargs)

            def recv(self, *args, **kwargs):
                self.used_for_read = True
                return super().recv(*args, **kwargs)

            def send(self, *args, **kwargs):
                self.used_for_write = True
                return super().send(*args, **kwargs)

        # Pick the server's IPv4 listening socket.
        server_socket = [
            sock for sock in self.server.sockets if sock.family == socket.AF_INET
        ][0]
        client_socket = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM)
        client_socket.connect(server_socket.getsockname())

        try:
            self.assertFalse(client_socket.used_for_read)
            self.assertFalse(client_socket.used_for_write)

            with self.temp_client(
                sock=client_socket,
                # "You must set server_hostname when using ssl without a host"
                server_hostname='localhost' if self.secure else None,
            ):
                self.loop.run_until_complete(self.client.send("Hello!"))
                reply = self.loop.run_until_complete(self.client.recv())
                self.assertEqual(reply, "Hello!")

            self.assertTrue(client_socket.used_for_read)
            self.assertTrue(client_socket.used_for_write)

        finally:
            client_socket.close()

    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets')
    def test_unix_socket(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            path = bytes(pathlib.Path(temp_dir) / 'websockets')

            # Like self.start_server() but with unix_serve().
            unix_server = unix_serve(handler, path)
            self.server = self.loop.run_until_complete(unix_server)

            client_socket = socket.socket(socket.AF_UNIX)
            client_socket.connect(path)

            try:
                with self.temp_client(sock=client_socket):
                    self.loop.run_until_complete(self.client.send("Hello!"))
                    reply = self.loop.run_until_complete(self.client.recv())
                    self.assertEqual(reply, "Hello!")

            finally:
                client_socket.close()
                self.stop_server()

    @with_server()
    @with_client('/attributes')
    def test_protocol_attributes(self):
        # The test could be connecting with IPv6 or IPv4.
        expected_client_attrs = [
            server_socket.getsockname()[:2] + (self.secure, )
            for server_socket in self.server.sockets
        ]
        client_attrs = (self.client.host, self.client.port, self.client.secure)
        self.assertIn(client_attrs, expected_client_attrs)

        # The '/attributes' endpoint presumably sends back repr() of the
        # server-side (host, port, secure) tuple.
        expected_server_attrs = ('localhost', 0, self.secure)
        server_attrs = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_attrs, repr(expected_server_attrs))

    @with_server()
    @with_client('/path')
    def test_protocol_path(self):
        client_path = self.client.path
        self.assertEqual(client_path, '/path')
        server_path = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_path, '/path')

    @with_server()
    @with_client('/headers', user_info=('user', 'pass'))
    def test_protocol_basic_auth(self):
        # base64("user:pass") == "dXNlcjpwYXNz"
        self.assertEqual(self.client.request_headers['Authorization'], 'Basic dXNlcjpwYXNz')

    # The '/headers' endpoint sends two messages: repr() of the request
    # headers, then repr() of the response headers (see the comparison
    # below). The tests that follow read one or both of them.
    @with_server()
    @with_client('/headers')
    def test_protocol_headers(self):
        client_req = self.client.request_headers
        client_resp = self.client.response_headers
        self.assertEqual(client_req['User-Agent'], USER_AGENT)
        self.assertEqual(client_resp['Server'], USER_AGENT)
        server_req = self.loop.run_until_complete(self.client.recv())
        server_resp = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_req, repr(client_req))
        self.assertEqual(server_resp, repr(client_resp))

    @with_server()
    @with_client('/headers', extra_headers=Headers({'X-Spam': 'Eggs'}))
    def test_protocol_custom_request_headers(self):
        req_headers = self.loop.run_until_complete(self.client.recv())
        self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", req_headers)

    @with_server()
    @with_client('/headers', extra_headers={'X-Spam': 'Eggs'})
    def test_protocol_custom_request_headers_dict(self):
        req_headers = self.loop.run_until_complete(self.client.recv())
        self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", req_headers)

    @with_server()
    @with_client('/headers', extra_headers=[('X-Spam', 'Eggs')])
    def test_protocol_custom_request_headers_list(self):
        req_headers = self.loop.run_until_complete(self.client.recv())
        self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", req_headers)

    @with_server()
    @with_client('/headers', extra_headers=[('User-Agent', 'Eggs')])
    def test_protocol_custom_request_user_agent(self):
        # A caller-supplied User-Agent replaces the default one rather than
        # being sent alongside it.
        req_headers = self.loop.run_until_complete(self.client.recv())
        self.loop.run_until_complete(self.client.recv())
        self.assertEqual(req_headers.count("User-Agent"), 1)
        self.assertIn("('User-Agent', 'Eggs')", req_headers)

    @with_server(extra_headers=lambda p, r: Headers({'X-Spam': 'Eggs'}))
    @with_client('/headers')
    def test_protocol_custom_response_headers_callable(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'})
    @with_client('/headers')
    def test_protocol_custom_response_headers_callable_dict(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')])
    @with_client('/headers')
    def test_protocol_custom_response_headers_callable_list(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers=Headers({'X-Spam': 'Eggs'}))
    @with_client('/headers')
    def test_protocol_custom_response_headers(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers={'X-Spam': 'Eggs'})
    @with_client('/headers')
    def test_protocol_custom_response_headers_dict(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers=[('X-Spam', 'Eggs')])
    @with_client('/headers')
    def test_protocol_custom_response_headers_list(self):
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)

    @with_server(extra_headers=[('Server', 'Eggs')])
    @with_client('/headers')
    def test_protocol_custom_response_user_agent(self):
        # A caller-supplied Server header replaces the default one.
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(resp_headers.count("Server"), 1)
        self.assertIn("('Server', 'Eggs')", resp_headers)

    def make_http_request(self, path='/'):
        # Set url to 'https?://<host>:<port><path>'.
        url = get_server_uri(self.server, resource_name=path, secure=self.secure)
        # NOTE(review): replace() rewrites every 'ws' substring, not just the
        # scheme; fine for the paths used here, but fragile for others.
        url = url.replace('ws', 'http')

        if self.secure:
            open_health_check = functools.partial(urllib.request.urlopen, url, context=self.client_context)
        else:
            open_health_check = functools.partial(urllib.request.urlopen, url)

        # urlopen blocks, so run it in the default executor; the returned
        # future is awaited by the caller via run_until_complete.
        return self.loop.run_in_executor(None, open_health_check)

    @with_server(create_protocol=HealthCheckServerProtocol)
    def test_http_request_http_endpoint(self):
        # Making a HTTP request to a HTTP endpoint succeeds.
        response = self.loop.run_until_complete(
            self.make_http_request('/__health__/'))

        with contextlib.closing(response):
            self.assertEqual(response.code, 200)
            self.assertEqual(response.read(), b'status = green\n')

    @with_server(create_protocol=HealthCheckServerProtocol)
    def test_http_request_ws_endpoint(self):
        # Making a HTTP request to a WS endpoint fails.
        with self.assertRaises(urllib.error.HTTPError) as raised:
            self.loop.run_until_complete(self.make_http_request())

        # 426 Upgrade Required, pointing the client at the websocket protocol.
        self.assertEqual(raised.exception.code, 426)
        self.assertEqual(raised.exception.headers['Upgrade'], 'websocket')

    @with_server(create_protocol=HealthCheckServerProtocol)
    def test_ws_connection_http_endpoint(self):
        # Making a WS connection to a HTTP endpoint fails.
        with self.assertRaises(InvalidStatusCode) as raised:
            self.start_client('/__health__/')

        self.assertEqual(raised.exception.status_code, 200)

    @with_server(create_protocol=HealthCheckServerProtocol)
    def test_ws_connection_ws_endpoint(self):
        # Making a WS connection to a WS endpoint succeeds.
        self.start_client()
        self.loop.run_until_complete(self.client.send("Hello!"))
        self.loop.run_until_complete(self.client.recv())
        self.stop_client()

    def assert_client_raises_code(self, status_code):
        # Helper: connecting must fail with the given HTTP status code.
        with self.assertRaises(InvalidStatusCode) as raised:
            self.start_client()
        self.assertEqual(raised.exception.status_code, status_code)

    @with_server(create_protocol=UnauthorizedServerProtocol)
    def test_server_create_protocol(self):
        self.assert_client_raises_code(401)

    @with_server(create_protocol=(
        lambda *args, **kwargs: UnauthorizedServerProtocol(*args, **kwargs)))
    def test_server_create_protocol_function(self):
        self.assert_client_raises_code(401)

    @with_server(klass=UnauthorizedServerProtocol)
    def test_server_klass_backwards_compatibility(self):
        self.assert_client_raises_code(401)

    # When both are given, create_protocol (403) must win over klass (401).
    @with_server(create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol)
    def test_server_create_protocol_over_klass(self):
        self.assert_client_raises_code(403)

    @with_server()
    @with_client('/path', create_protocol=FooClientProtocol)
    def test_client_create_protocol(self):
        self.assertIsInstance(self.client, FooClientProtocol)

    @with_server()
    @with_client(
        '/path',
        create_protocol=(
            lambda *args, **kwargs: FooClientProtocol(*args, **kwargs)),
    )
    def test_client_create_protocol_function(self):
        self.assertIsInstance(self.client, FooClientProtocol)

    @with_server()
    @with_client('/path', klass=FooClientProtocol)
    def test_client_klass(self):
        self.assertIsInstance(self.client, FooClientProtocol)

    @with_server()
    @with_client('/path', create_protocol=BarClientProtocol, klass=FooClientProtocol)
    def test_client_create_protocol_over_klass(self):
        self.assertIsInstance(self.client, BarClientProtocol)

    # The '/close_timeout' endpoint sends the server's close_timeout as text.
    # NOTE(review): eval() on received data is only acceptable because the
    # peer is the local test server; ast.literal_eval would be safer.
    @with_server(close_timeout=7)
    @with_client('/close_timeout')
    def test_server_close_timeout(self):
        close_timeout = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(eval(close_timeout), 7)

    @with_server(timeout=6)
    @with_client('/close_timeout')
    def test_server_timeout_backwards_compatibility(self):
        close_timeout = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(eval(close_timeout), 6)

    @with_server(close_timeout=7, timeout=6)
    @with_client('/close_timeout')
    def test_server_close_timeout_over_timeout(self):
        close_timeout = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(eval(close_timeout), 7)

    @with_server()
    @with_client('/close_timeout', close_timeout=7)
    def test_client_close_timeout(self):
        self.assertEqual(self.client.close_timeout, 7)

    @with_server()
    @with_client('/close_timeout', timeout=6)
    def test_client_timeout_backwards_compatibility(self):
        self.assertEqual(self.client.close_timeout, 6)

    @with_server()
    @with_client('/close_timeout', close_timeout=7, timeout=6)
    def test_client_close_timeout_over_timeout(self):
        self.assertEqual(self.client.close_timeout, 7)

    # The '/extensions' endpoint sends repr() of the server-side extensions;
    # each test compares it with the client's negotiated extensions.
    @with_server()
    @with_client('/extensions')
    def test_no_extension(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([]))
        self.assertEqual(repr(self.client.extensions), repr([]))

    @with_server(extensions=[ServerNoOpExtensionFactory()])
    @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()])
    def test_extension(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([NoOpExtension()]))
        self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()]))

    @with_server()
    @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()])
    def test_extension_not_accepted(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([]))
        self.assertEqual(repr(self.client.extensions), repr([]))

    @with_server(extensions=[ServerNoOpExtensionFactory()])
    @with_client('/extensions')
    def test_extension_not_requested(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([]))
        self.assertEqual(repr(self.client.extensions), repr([]))

    @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])])
    def test_extension_client_rejection(self):
        # The server answers with a parameter the client did not offer.
        with self.assertRaises(NegotiationError):
            self.start_client('/extensions', extensions=[ClientNoOpExtensionFactory()])

    @with_server(extensions=[
        # No match because the client doesn't send client_max_window_bits.
        ServerPerMessageDeflateFactory(client_max_window_bits=10),
        ServerPerMessageDeflateFactory(),
    ])
    @with_client('/extensions', extensions=[ClientPerMessageDeflateFactory()])
    def test_extension_no_match_then_match(self):
        # The order requested by the client has priority.
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]))
        self.assertEqual(
            repr(self.client.extensions),
            repr([PerMessageDeflate(False, False, 15, 15)]),
        )

    @with_server(extensions=[ServerPerMessageDeflateFactory()])
    @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()])
    def test_extension_mismatch(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([]))
        self.assertEqual(repr(self.client.extensions), repr([]))

    @with_server(extensions=[
        ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory()
    ])
    @with_client(
        '/extensions',
        extensions=[
            ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory()
        ],
    )
    def test_extension_order(self):
        # The order requested by the client has priority.
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(
            server_extensions,
            repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]),
        )
        self.assertEqual(
            repr(self.client.extensions),
            repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]),
        )

    # Mock decorators are applied bottom-up, so the patched method arrives as
    # the first extra argument.
    @with_server(extensions=[ServerNoOpExtensionFactory()])
    @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions')
    def test_extensions_error(self, _process_extensions):
        # Force the server to accept an extension the client did not request.
        _process_extensions.return_value = 'x-no-op', [NoOpExtension()]

        with self.assertRaises(NegotiationError):
            self.start_client('/extensions', extensions=[ClientPerMessageDeflateFactory()])

    @with_server(extensions=[ServerNoOpExtensionFactory()])
    @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions')
    def test_extensions_error_no_extensions(self, _process_extensions):
        _process_extensions.return_value = 'x-no-op', [NoOpExtension()]

        with self.assertRaises(InvalidHandshake):
            self.start_client('/extensions')

    @with_server(compression='deflate')
    @with_client('/extensions', compression='deflate')
    def test_compression_deflate(self):
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]))
        self.assertEqual(
            repr(self.client.extensions),
            repr([PerMessageDeflate(False, False, 15, 15)]),
        )

    @with_server(
        extensions=[
            ServerPerMessageDeflateFactory(client_no_context_takeover=True, server_max_window_bits=10)
        ],
        compression='deflate',  # overridden by explicit config
    )
    @with_client(
        '/extensions',
        extensions=[
            ClientPerMessageDeflateFactory(server_no_context_takeover=True, client_max_window_bits=12)
        ],
        compression='deflate',  # overridden by explicit config
    )
    def test_compression_deflate_and_explicit_config(self):
        # Each side sees the negotiated parameters from its own perspective,
        # hence the swapped argument order between server and client.
        server_extensions = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_extensions, repr([PerMessageDeflate(True, True, 12, 10)]))
        self.assertEqual(repr(self.client.extensions), repr([PerMessageDeflate(True, True, 10, 12)]))

    # NOTE(review): start_server()/start_client() return None, so these two
    # tests rely on the ValueError being raised before run_until_complete()
    # is reached.
    def test_compression_unsupported_server(self):
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.start_server(compression='xz'))

    @with_server()
    def test_compression_unsupported_client(self):
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(self.start_client(compression='xz'))

    # The '/subprotocol' endpoint sends repr() of the server's negotiated
    # subprotocol.
    @with_server()
    @with_client('/subprotocol')
    def test_no_subprotocol(self):
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)

    @with_server(subprotocols=['superchat', 'chat'])
    @with_client('/subprotocol', subprotocols=['otherchat', 'chat'])
    def test_subprotocol(self):
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr('chat'))
        self.assertEqual(self.client.subprotocol, 'chat')

    @with_server(subprotocols=['superchat'])
    @with_client('/subprotocol', subprotocols=['otherchat'])
    def test_subprotocol_not_accepted(self):
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)

    @with_server()
    @with_client('/subprotocol', subprotocols=['otherchat', 'chat'])
    def test_subprotocol_not_offered(self):
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)

    @with_server(subprotocols=['superchat', 'chat'])
    @with_client('/subprotocol')
    def test_subprotocol_not_requested(self):
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)

    @with_server(subprotocols=['superchat'])
    @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol')
    def test_subprotocol_error(self, _process_subprotocol):
        # Force the server to select a subprotocol the client didn't offer.
        _process_subprotocol.return_value = 'superchat'

        with self.assertRaises(NegotiationError):
            self.start_client('/subprotocol', subprotocols=['otherchat'])
        # Let pending server-side callbacks run before teardown.
        self.run_loop_once()

    @with_server(subprotocols=['superchat'])
    @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol')
    def test_subprotocol_error_no_subprotocols(self, _process_subprotocol):
        _process_subprotocol.return_value = 'superchat'

        with self.assertRaises(InvalidHandshake):
            self.start_client('/subprotocol')
        self.run_loop_once()

    @with_server(subprotocols=['superchat', 'chat'])
    @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol')
    def test_subprotocol_error_two_subprotocols(self, _process_subprotocol):
        # A comma-separated value means the server selected two subprotocols.
        _process_subprotocol.return_value = 'superchat, chat'

        with self.assertRaises(InvalidHandshake):
            self.start_client('/subprotocol', subprotocols=['superchat', 'chat'])
        self.run_loop_once()

    @with_server()
    @unittest.mock.patch('websockets.server.read_request')
    def test_server_receives_malformed_request(self, _read_request):
        _read_request.side_effect = ValueError("read_request failed")

        with self.assertRaises(InvalidHandshake):
            self.start_client()

    @with_server()
    @unittest.mock.patch('websockets.client.read_response')
    def test_client_receives_malformed_response(self, _read_response):
        _read_response.side_effect = ValueError("read_response failed")

        with self.assertRaises(InvalidHandshake):
            self.start_client()
        self.run_loop_once()

    @with_server()
    @unittest.mock.patch('websockets.client.build_request')
    def test_client_sends_invalid_handshake_request(self, _build_request):
        # Return a bogus key so the server cannot accept the handshake.
        def wrong_build_request(headers):
            return '42'
        _build_request.side_effect = wrong_build_request

        with self.assertRaises(InvalidHandshake):
            self.start_client()

    @with_server()
    @unittest.mock.patch('websockets.server.build_response')
    def test_server_sends_invalid_handshake_response(self, _build_response):
        # Build the response for the wrong key so the client rejects it.
        def wrong_build_response(headers, key):
            return build_response(headers, '42')
        _build_response.side_effect = wrong_build_response

        with self.assertRaises(InvalidHandshake):
            self.start_client()

    @with_server()
    @unittest.mock.patch('websockets.client.read_response')
    def test_server_does_not_switch_protocols(self, _read_response):
        # Read the real response but report status 400 instead of 101.
        # NOTE(review): asyncio.coroutine was removed in Python 3.11; port
        # this helper to ``async def`` / ``await`` on newer interpreters.
        @asyncio.coroutine
        def wrong_read_response(stream):
            status_code, headers = yield from read_response(stream)
            return 400, headers
        _read_response.side_effect = wrong_read_response

        with self.assertRaises(InvalidStatusCode):
            self.start_client()
        self.run_loop_once()

    @with_server()
    @unittest.mock.patch(
        'websockets.server.WebSocketServerProtocol.process_request')
    def test_server_error_in_handshake(self, _process_request):
        _process_request.side_effect = Exception("process_request crashed")

        with self.assertRaises(InvalidHandshake):
            self.start_client()

    @with_server()
    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send')
    def test_server_handler_crashes(self, send):
        send.side_effect = ValueError("send failed")

        with self.temp_client():
            self.loop.run_until_complete(self.client.send("Hello!"))
            with self.assertRaises(ConnectionClosed):
                self.loop.run_until_complete(self.client.recv())

        # Connection ends with an unexpected error.
        self.assertEqual(self.client.close_code, 1011)

    @with_server()
    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close')
    def test_server_close_crashes(self, close):
        close.side_effect = ValueError("close failed")

        with self.temp_client():
            self.loop.run_until_complete(self.client.send("Hello!"))
            reply = self.loop.run_until_complete(self.client.recv())
            self.assertEqual(reply, "Hello!")

        # Connection ends with an abnormal closure.
        self.assertEqual(self.client.close_code, 1006)

    @with_server()
    @with_client()
    @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake')
    def test_client_closes_connection_before_handshake(self, handshake):
        # We have mocked the handshake() method to prevent the client from
        # performing the opening handshake. Force it to close the connection.
        self.client.writer.close()
        # The server should stop properly anyway. It used to hang because the
        # task handling the connection was waiting for the opening handshake.

    @with_server()
    @unittest.mock.patch('websockets.server.read_request')
    def test_server_shuts_down_during_opening_handshake(self, _read_request):
        _read_request.side_effect = asyncio.CancelledError

        self.server.closing = True
        with self.assertRaises(InvalidHandshake) as raised:
            self.start_client()

        # Opening handshake fails with 503 Service Unavailable
        self.assertEqual(str(raised.exception), "Status code not 101: 503")

    @with_server()
    def test_server_shuts_down_during_connection_handling(self):
        with self.temp_client():
            self.server.close()
            with self.assertRaises(ConnectionClosed):
                self.loop.run_until_complete(self.client.recv())

        # Websocket connection terminates with 1001 Going Away.
        self.assertEqual(self.client.close_code, 1001)

    @with_server()
    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close')
    def test_server_shuts_down_during_connection_close(self, _close):
        _close.side_effect = asyncio.CancelledError

        self.server.closing = True
        with self.temp_client():
            self.loop.run_until_complete(self.client.send("Hello!"))
            reply = self.loop.run_until_complete(self.client.recv())
            self.assertEqual(reply, "Hello!")

        # Websocket connection terminates abnormally.
        self.assertEqual(self.client.close_code, 1006)

    @with_server(create_protocol=ForbiddenServerProtocol)
    def test_invalid_status_error_during_client_connect(self):
        with self.assertRaises(InvalidStatusCode) as raised:
            self.start_client()
        exception = raised.exception
        self.assertEqual(str(exception), "Status code not 101: 403")
        self.assertEqual(exception.status_code, 403)

    # Stacked patches are applied bottom-up: read_http_request's mock is the
    # first extra argument, write_http_response's mock the second.
    @with_server()
    @unittest.mock.patch(
        'websockets.server.WebSocketServerProtocol.write_http_response')
    @unittest.mock.patch(
        'websockets.server.WebSocketServerProtocol.read_http_request')
    def test_connection_error_during_opening_handshake(self, _read_http_request, _write_http_response):
        _read_http_request.side_effect = ConnectionError

        # This exception is currently platform-dependent. It was observed to
        # be ConnectionResetError on Linux in the non-SSL case, and
        # InvalidMessage otherwise (including both Linux and macOS). This
        # doesn't matter though since this test is primarily for testing a
        # code path on the server side.
        with self.assertRaises(Exception):
            self.start_client()

        # No response must be written if the network connection is broken.
        _write_http_response.assert_not_called()

    @with_server()
    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close')
    def test_connection_error_during_closing_handshake(self, close):
        close.side_effect = ConnectionError

        with self.temp_client():
            self.loop.run_until_complete(self.client.send("Hello!"))
            reply = self.loop.run_until_complete(self.client.recv())
            self.assertEqual(reply, "Hello!")

        # Connection ends with an abnormal closure.
        self.assertEqual(self.client.close_code, 1006)
def process_request(self, path, request_headers):
    """
    Refuse every request regardless of path or headers.

    Returns a ``(status, headers, body)`` triple: ``UNAUTHORIZED``, a
    ``Headers`` instance carrying ``X-Access: denied``, and an empty body.
    """
    # Test returning headers as a Headers instance (1/3)
    denial_headers = Headers([('X-Access', 'denied')])
    empty_body = b''
    return UNAUTHORIZED, denial_headers, empty_body
async def handshake(
    self,
    wsuri: websockets.WebSocketURI,
    origin: Optional[websockets.typing.Origin] = None,
    available_extensions: Optional[
        Sequence[websockets.extensions.base.ClientExtensionFactory]
    ] = None,
    available_subprotocols: Optional[
        Sequence[websockets.typing.Subprotocol]
    ] = None,
    extra_headers: Optional[websockets.http.HeadersLike] = None,
):
    """
    Perform the client side of the opening handshake.

    The request headers are assembled as in a regular websockets client,
    then run through ``aws_auth`` on a prepared ``GET`` request, and the
    resulting (signed) headers are what is actually written to the wire.

    :param origin: sets the Origin HTTP header
    :param available_extensions: list of supported extensions in the order
        in which they should be used
    :param available_subprotocols: list of supported subprotocols in order
        of decreasing preference
    :param extra_headers: sets additional HTTP request headers; it must be a
        :class:`~websockets.http.Headers` instance, a
        :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
        pairs
    :raises ~websockets.exceptions.InvalidHandshake: if the handshake fails

    """
    headers = Headers()

    # Omit the port from Host when it is the scheme's default one.
    default_port = 443 if wsuri.secure else 80
    headers["Host"] = (
        wsuri.host
        if wsuri.port == default_port  # pragma: no cover
        else f"{wsuri.host}:{wsuri.port}"
    )

    # NOTE(review): Basic auth from ``wsuri.user_info`` is not added here;
    # authentication presumably relies on the aws_auth signature below.

    if origin is not None:
        headers["Origin"] = origin

    # Adds the upgrade headers and returns the Sec-WebSocket-Key used later
    # to validate the server's accept header.
    sec_key = build_request(headers)

    if available_extensions is not None:
        offered = [
            (factory.name, factory.get_request_params())
            for factory in available_extensions
        ]
        headers["Sec-WebSocket-Extensions"] = websockets.headers.build_extension(
            offered
        )

    if available_subprotocols is not None:
        headers["Sec-WebSocket-Protocol"] = websockets.headers.build_subprotocol(
            available_subprotocols
        )

    if extra_headers is not None:
        if isinstance(extra_headers, Headers):
            pairs = extra_headers.raw_items()
        elif isinstance(extra_headers, collections.abc.Mapping):
            pairs = extra_headers.items()
        else:
            pairs = extra_headers
        for header_name, header_value in pairs:
            headers[header_name] = header_value

    headers.setdefault("User-Agent", websockets.http.USER_AGENT)

    # Sign the assembled headers, then send the signed version.
    signed = aws_auth(
        Request("GET", f"https://{wsuri.host}/", data="", headers=headers).prepare()
    )
    self.write_http_request(wsuri.resource_name, Headers(signed.headers))

    status_code, response_headers = await self.read_http_response()

    if status_code in (301, 302, 303, 307, 308):
        if "Location" not in response_headers:
            raise websockets.exceptions.InvalidHeader("Location")
        raise websockets.exceptions.RedirectHandshake(response_headers["Location"])
    if status_code != 101:
        raise websockets.exceptions.InvalidStatusCode(status_code)

    websockets.handshake.check_response(response_headers, sec_key)

    self.extensions = self.process_extensions(
        response_headers, available_extensions
    )
    self.subprotocol = self.process_subprotocol(
        response_headers, available_subprotocols
    )

    self.connection_open()
async def handleHTTP(  # type: ignore
        self, path: str, request_headers: Headers) -> Tuple[HTTPStatus, Union[
            Headers, Mapping[str, str], Iterable[Tuple[str, str]]], bytes]:
    """Authenticate an incoming HTTP request before the WebSocket upgrade.

    Supported paths and the headers each one requires (checked in order):

    * ``/token``  -- ``Token``; verified with ``self.database.tokenSignIn``.
    * ``/signIn`` -- ``Email``, ``Password``; verified with
      ``self.database.signIn``.
    * ``/signUp`` -- ``Password``, ``Email``, ``Username``; an account is
      created with ``self.database.signUp``.

    Args:
        path: Request path.
        request_headers: Incoming HTTP request headers.

    Returns:
        ``None`` when the request is accepted (presumably this is a
        websockets ``process_request`` hook, where ``None`` lets the
        handshake continue). Otherwise a ``(status, headers, body)`` tuple:

        * ``BAD_REQUEST`` if a required header is missing,
        * ``UNAUTHORIZED`` if the credentials are rejected,
        * ``FORBIDDEN`` if sign-up fails with ``database.UserExists``,
        * ``NOT_FOUND`` for any other path.

        Except for ``NOT_FOUND``, the response headers carry an ``error``
        entry describing the failure.
    """
    def reject(status: HTTPStatus, message: str):
        # A fresh Headers instance per response, carrying the error text.
        headers = Headers()
        headers['error'] = message
        return (status, headers, b'')

    required_headers = {
        "/token": ("Token",),
        "/signIn": ("Email", "Password"),
        "/signUp": ("Password", "Email", "Username"),
    }.get(path)
    if required_headers is None:
        return (HTTPStatus.NOT_FOUND, Headers(), b'')

    # Read the headers in order so the first missing one is reported.
    values = {}
    for name in required_headers:
        try:
            values[name] = request_headers[name]
        except KeyError:
            return reject(HTTPStatus.BAD_REQUEST,
                          f'No {name.lower()} header present')

    if path == "/token":
        try:
            await self.database.tokenSignIn(token=values["Token"])
        except database.InvalidCredentials:
            return reject(HTTPStatus.UNAUTHORIZED, 'Invalid token!')
    elif path == "/signIn":
        try:
            await self.database.signIn(email=values["Email"],
                                       password=values["Password"])
        except database.InvalidCredentials:
            # This path authenticates with email/password, not a token.
            return reject(HTTPStatus.UNAUTHORIZED,
                          'Invalid email or password!')
    else:  # "/signUp"
        # TODO(review): reCAPTCHA verification of a ``Response`` header via
        # ``self.recaptcha.verify`` is currently disabled.
        try:
            await self.database.signUp(username=values["Username"],
                                       password=values["Password"],
                                       email=values["Email"])
        except database.UserExists:
            return reject(HTTPStatus.FORBIDDEN, 'Email is taken.')

    return None
class WebSocketBase(AbstractWebSocketClient):
    """Client for the Andesite WebSocket handler.

    Args:
        ws_uri: Websocket endpoint to connect to.
        user_id: Bot's user id.
            If at the time of creation this is unknown, you may pass `None`,
            but then it needs to be set before connecting for the first time.
        password: Authorization for the Andesite node.
            Set to `None` if the node doesn't have a password.
        state: State handler to use. If `False` state handling is disabled.
            `None` to use the default state handler (`State`).
        max_connect_attempts: See the `max_connect_attempts` attribute.

    The client automatically keeps track of the current connection id and
    resumes the previous connection when calling `connect`, if there is any.
    You can delete the `connection_id` property to disable this.

    See Also:
        `AbstractWebSocketClient` for more details including a list
        of events that are dispatched.

    Attributes:
        max_connect_attempts (Optional[int]): Max amount of connection
            attempts to start before giving up. If `None`, there is no upper
            limit. This value can be overwritten when calling `connect`
            manually.
        web_socket_client (Optional[WebSocketClientProtocol]):
            Web socket client which is used.
            This attribute will be set once `connect` is called.
            Don't use the presence of this attribute to check whether
            the client is connected, use the `connected` property.
    """
    max_connect_attempts: Optional[int]

    web_socket_client: Optional[WebSocketClientProtocol]

    __closed: bool

    __ws_uri: str
    __headers: Headers
    __last_connection_id: Optional[str]

    __connect_lock: Optional[asyncio.Lock]
    __read_loop: Optional[asyncio.Future]

    _json_encoder: JSONEncoder
    _json_decoder: JSONDecoder

    def __init__(self, ws_uri: Union[str, URL], user_id: Optional[int],
                 password: Optional[str], *,
                 state: andesite.StateArgumentType = False,
                 max_connect_attempts: int = None) -> None:
        self.__ws_uri = str(ws_uri)
        self.__headers = Headers()

        if password is not None:
            self.__headers["Authorization"] = password
        if user_id is not None:
            self.user_id = user_id

        self.__last_connection_id = None

        self.max_connect_attempts = max_connect_attempts

        self.web_socket_client = None

        # can't create the lock here, because if the user uses
        # asyncio.run and creates the client outside of it, the loop
        # within the lock will not be the same as the loop used by
        # asyncio.run (as it creates a new loop every time)
        self.__connect_lock = None

        self.__closed = False
        self.__read_loop = None

        self._json_encoder = JSONEncoder()
        self._json_decoder = JSONDecoder()

        self.state = state

    def __repr__(self) -> str:
        return f"{type(self).__name__}(ws_uri={self.__ws_uri!r}, user_id={self.user_id!r}, " \
               f"password=[HIDDEN], state={self.state!r}, max_connect_attempts={self.max_connect_attempts!r})"

    def __str__(self) -> str:
        return f"{type(self).__name__}({self.__ws_uri})"

    @property
    def user_id(self) -> Optional[int]:
        """User id.

        This is only `None` if it wasn't passed to the constructor.

        You can set this property to a new user id.
        """
        # The id is stored as a header string; convert it back so the
        # getter honours its ``Optional[int]`` contract.
        value = self.__headers.get("User-Id")
        return None if value is None else int(value)

    @user_id.setter
    def user_id(self, user_id: int) -> None:
        # websockets' ``Headers`` appends on assignment instead of
        # replacing, so remove any previous value first to avoid sending
        # (and later failing to read) duplicate User-Id headers.
        with suppress(KeyError):
            del self.__headers["User-Id"]
        self.__headers["User-Id"] = str(user_id)

    @property
    def closed(self) -> bool:
        return self.__closed

    @property
    def connected(self) -> bool:
        if self.web_socket_client:
            return self.web_socket_client.open
        else:
            return False

    @property
    def connection_id(self) -> Optional[str]:
        return self.__last_connection_id

    @connection_id.deleter
    def connection_id(self) -> None:
        self.__last_connection_id = None

    @property
    def node_region(self) -> Optional[str]:
        client = self.web_socket_client
        if client:
            return client.response_headers.get("Andesite-Node-Region")

        return None

    @property
    def node_id(self) -> Optional[str]:
        client = self.web_socket_client
        if client:
            return client.response_headers.get("Andesite-Node-Id")

        return None

    def _get_connect_lock(self, *, loop: asyncio.AbstractEventLoop = None
                          ) -> asyncio.Lock:
        """Get the connect lock.

        The connect lock is only created once. Subsequent calls always return
        the same lock. The reason for the delayed creating is that the lock is
        bound to an event loop, which can change between __init__ and connect.
        """
        if self.__connect_lock is None:
            # The ``loop`` argument of asyncio.Lock was deprecated in 3.8 and
            # removed in 3.10 (even ``loop=None`` raises TypeError there), so
            # only forward it when a caller explicitly supplies one.
            if loop is None:
                self.__connect_lock = asyncio.Lock()
            else:
                self.__connect_lock = asyncio.Lock(loop=loop)

        return self.__connect_lock

    async def __connect(self, max_attempts: int = None) -> None:
        """Internal connect method.

        Args:
            max_attempts: Max amount of connection attempts to perform before
                aborting. This overwrites the instance attribute
                `max_connect_attempts`.

        Raises:
            ValueError: If client is already connected
            KeyError: If the user id hasn't been set.
            ConnectionError: If all connection attempts failed.

        Notes:
            If `max_attempts` is exceeded and the client gives up on
            connecting it is closed!
        """
        if self.connected:
            raise ValueError("Already connected!")

        headers = self.__headers

        if "User-Id" not in headers:
            raise KeyError("Trying to connect but user id unknown.\n"
                           "This is most likely the case because you didn't\n"
                           "set the user_id in the constructor and forgot to\n"
                           "set it before connecting!")

        # inject the connection id to resume previous connection. The stale
        # value is always removed first because assigning to ``Headers``
        # appends a second value instead of replacing the old one.
        with suppress(KeyError):
            del headers["Andesite-Resume-Id"]
        if self.__last_connection_id is not None:
            headers["Andesite-Resume-Id"] = self.__last_connection_id

        max_attempts = max_attempts or self.max_connect_attempts

        attempt = 0
        client = None
        while max_attempts is None or attempt < max_attempts:
            attempt += 1
            client = await try_connect(self.__ws_uri, extra_headers=headers)
            if client:
                break

            # Don't wait after the final attempt; give up right away.
            if max_attempts is not None and attempt >= max_attempts:
                break

            timeout = int(math.pow(attempt, 1.5))
            log.info(
                f"Connection unsuccessful, trying again in {timeout} seconds")
            await asyncio.sleep(timeout)

        if not client:
            self.__closed = True
            # ``attempt`` is the number of attempts actually performed.
            raise ConnectionError(
                f"Couldn't connect to {self.__ws_uri} after {attempt} attempts"
            )

        log.info("%s: connected", self)

        self.web_socket_client = client
        self.__start_read_loop()

        _ = self.event_target.emit(WebSocketConnectEvent(self))

    async def connect(self, *, max_attempts: int = None) -> None:
        if self.closed:
            raise ValueError("Client is closed and cannot be reused.")

        async with self._get_connect_lock():
            if not self.connected:
                await self.__connect(max_attempts)

    async def disconnect(self) -> None:
        async with self._get_connect_lock():
            self.__stop_read_loop()
            self.__last_connection_id = None

            if self.connected:
                await self.web_socket_client.close(reason="disconnect")
                _ = self.event_target.emit(WebSocketDisconnectEvent(
                    self, True))

    async def reset(self) -> None:
        await self.disconnect()
        del self.connection_id
        self.__closed = False

    async def close(self) -> None:
        await self.disconnect()
        self.__closed = True

    async def __web_socket_reader(self) -> None:
        """Internal web socket read loop.

        This method should never be called manually, see the following
        methods for controlling the reader.

        See Also:
            `WebSocketBase.__start_read_loop` to start the read loop.
            `WebSocketBase.__stop_read_loop` to stop the read loop.

        Notes:
            The read loop is automatically managed by the `connect` and
            `disconnect` methods. On `ConnectionClosed` it emits a
            disconnect event and calls `connect` to reconnect.
        """
        loop = asyncio.get_event_loop()

        def handle_msg(raw_msg: str) -> None:
            try:
                data: Dict[str, Any] = self._json_decoder.decode(raw_msg)
            except JSONDecodeError as e:
                log.error(
                    f"Couldn't parse received JSON data in {self}: {e}\nmsg: {raw_msg}"
                )
                return

            if not isinstance(data, dict):
                log.warning(
                    f"Received invalid message type in {self}. "
                    f"Expecting object, received type {type(data).__name__}: {data}"
                )
                return

            _ = self.event_target.emit(RawMsgReceiveEvent(self, data))

            try:
                op = data.pop("op")
            except KeyError:
                log.info(f"Ignoring message without op code in {self}: {data}")
                return

            event_type = data.get("type")
            cls = andesite.get_update_model(op, event_type)
            if cls is None:
                log.warning(
                    f"Ignoring message with unknown op \"{op}\" in {self}: {data}"
                )
                return

            try:
                message: andesite.ReceiveOperation = build_from_raw(cls, data)
            except Exception:
                log.exception(
                    f"Couldn't parse message in {self} from Andesite node to {cls}: {data}"
                )
                return

            message.client = self

            if isinstance(message, andesite.ConnectionUpdate):
                log.info(
                    f"received connection update, setting last connection id in {self}."
                )
                self.__last_connection_id = message.id

            _ = self.event_target.emit(message)

            if self.state is not None:
                loop.create_task(self.state._handle_andesite_message(message))

        while True:
            try:
                raw_msg = await self.web_socket_client.recv()
            except asyncio.CancelledError:
                break
            except ConnectionClosed:
                _ = self.event_target.emit(
                    WebSocketDisconnectEvent(self, False))
                log.error(
                    f"Disconnected from websocket, trying to reconnect {self}!"
                )
                await self.connect()
                continue

            if log.isEnabledFor(logging.DEBUG):
                log.debug(f"Received message in {self}: {raw_msg}")

            try:
                handle_msg(raw_msg)
            except Exception:
                log.exception("Exception in %s while handling message %s.",
                              self, raw_msg)

    def __start_read_loop(self, *,
                          loop: asyncio.AbstractEventLoop = None) -> None:
        """Start the web socket reader.

        If the reader is already running, this is a no-op.
        """
        if self.__read_loop and not self.__read_loop.done():
            return

        if loop is None:
            loop = asyncio.get_event_loop()

        self.__read_loop = loop.create_task(self.__web_socket_reader())

    def __stop_read_loop(self) -> None:
        """Stop the web socket reader.

        If the reader is already stopped, this is a no-op.
        """
        if not self.__read_loop:
            return

        self.__read_loop.cancel()

    async def send(self, guild_id: int, op: str,
                   payload: Dict[str, Any]) -> None:
        # Connect lazily when there is no open connection. (Previously the
        # check was inverted and only "reconnected" an already-open socket.)
        if not self.connected:
            log.info("%s: Not connected, connecting.", self)
            await self.connect()

        payload.update(guildId=str(guild_id), op=op)

        log.debug("%s: sending payload: %s", self, payload)
        _ = self.event_target.emit(RawMsgSendEvent(self, guild_id, op, payload))

        data = self._json_encoder.encode(payload)

        try:
            await self.web_socket_client.send(data)
        except ConnectionClosed:
            # let the websocket reader handle this
            log.warning(
                "%s: couldn't send message because the connection is closed: %s",
                self, payload)
        else:
            state = self.state
            if state:
                loop = asyncio.get_event_loop()
                loop.create_task(
                    state._handle_sent_message(guild_id, op, payload))
def test_str(self):
    """Each exception renders the expected human-readable message."""
    closed = "WebSocket connection is closed: code = "
    # fmt: off
    cases = [
        (InvalidHandshake("Invalid request"), "Invalid request"),
        (AbortHandshake(200, Headers(), b'OK\n'),
         "HTTP 200, 0 headers, 3 bytes"),
        (InvalidMessage("Malformed HTTP message"), "Malformed HTTP message"),
        (InvalidHeader('Name'), "Missing Name header"),
        (InvalidHeader('Name', None), "Missing Name header"),
        (InvalidHeader('Name', ''), "Empty Name header"),
        (InvalidHeader('Name', 'Value'), "Invalid Name header: Value"),
        (InvalidHeaderFormat('Sec-WebSocket-Protocol', "expected token",
                             'a=|', 3),
         "Invalid Sec-WebSocket-Protocol header: expected token at 3 in a=|"),
        (InvalidHeaderValue('Sec-WebSocket-Version', '42'),
         "Invalid Sec-WebSocket-Version header: 42"),
        (InvalidUpgrade('Upgrade'), "Missing Upgrade header"),
        (InvalidUpgrade('Connection', 'websocket'),
         "Invalid Connection header: websocket"),
        (InvalidOrigin('http://bad.origin'),
         "Invalid Origin header: http://bad.origin"),
        (InvalidStatusCode(403), "Status code not 101: 403"),
        (NegotiationError("Unsupported subprotocol: spam"),
         "Unsupported subprotocol: spam"),
        (InvalidParameterName('|'), "Invalid parameter name: |"),
        (InvalidParameterValue('a', '|'), "Invalid value for parameter a: |"),
        (DuplicateParameter('a'), "Duplicate parameter: a"),
        (InvalidState("WebSocket connection isn't established yet"),
         "WebSocket connection isn't established yet"),
        (ConnectionClosed(1000, ''), closed + "1000 (OK), no reason"),
        (ConnectionClosed(1001, 'bye'),
         closed + "1001 (going away), reason = bye"),
        (ConnectionClosed(1006, None),
         closed + "1006 (connection closed abnormally [internal]), no reason"),
        (ConnectionClosed(1016, None), closed + "1016 (unknown), no reason"),
        (ConnectionClosed(3000, None), closed + "3000 (registered), no reason"),
        (ConnectionClosed(4000, None),
         closed + "4000 (private use), no reason"),
        (InvalidURI("|"), "| isn't a valid URI"),
        (PayloadTooBig("Payload length exceeds limit: 2 > 1 bytes"),
         "Payload length exceeds limit: 2 > 1 bytes"),
        (WebSocketProtocolError("Invalid opcode: 7"), "Invalid opcode: 7"),
    ]
    # fmt: on
    for exc, expected in cases:
        with self.subTest(exception=exc):
            self.assertEqual(str(exc), expected)
def test_str(self):
    """Each exception renders the expected (lower-case style) message."""
    closed = "WebSocket connection is closed: code = "
    # fmt: off
    cases = [
        (InvalidHandshake("invalid request"), "invalid request"),
        (AbortHandshake(200, Headers(), b"OK\n"),
         "HTTP 200, 0 headers, 3 bytes"),
        (SecurityError("redirect from WSS to WS"), "redirect from WSS to WS"),
        (RedirectHandshake("wss://example.com"),
         "redirect to wss://example.com"),
        (InvalidMessage("malformed HTTP message"), "malformed HTTP message"),
        (InvalidHeader("Name"), "missing Name header"),
        (InvalidHeader("Name", None), "missing Name header"),
        (InvalidHeader("Name", ""), "empty Name header"),
        (InvalidHeader("Name", "Value"), "invalid Name header: Value"),
        (InvalidHeaderFormat("Sec-WebSocket-Protocol", "expected token",
                             "a=|", 3),
         "invalid Sec-WebSocket-Protocol header: expected token at 3 in a=|"),
        (InvalidHeaderValue("Sec-WebSocket-Version", "42"),
         "invalid Sec-WebSocket-Version header: 42"),
        (InvalidUpgrade("Upgrade"), "missing Upgrade header"),
        (InvalidUpgrade("Connection", "websocket"),
         "invalid Connection header: websocket"),
        (InvalidOrigin("http://bad.origin"),
         "invalid Origin header: http://bad.origin"),
        (InvalidStatusCode(403),
         "server rejected WebSocket connection: HTTP 403"),
        (NegotiationError("unsupported subprotocol: spam"),
         "unsupported subprotocol: spam"),
        (InvalidParameterName("|"), "invalid parameter name: |"),
        (InvalidParameterValue("a", "|"), "invalid value for parameter a: |"),
        (DuplicateParameter("a"), "duplicate parameter: a"),
        (InvalidState("WebSocket connection isn't established yet"),
         "WebSocket connection isn't established yet"),
        (ConnectionClosed(1000, ""), closed + "1000 (OK), no reason"),
        (ConnectionClosedOK(1001, "bye"),
         closed + "1001 (going away), reason = bye"),
        (ConnectionClosed(1006, None),
         closed + "1006 (connection closed abnormally [internal]), no reason"),
        (ConnectionClosedError(1016, None),
         closed + "1016 (unknown), no reason"),
        (ConnectionClosed(3000, None), closed + "3000 (registered), no reason"),
        (ConnectionClosed(4000, None),
         closed + "4000 (private use), no reason"),
        (InvalidURI("|"), "| isn't a valid URI"),
        (PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"),
         "payload length exceeds limit: 2 > 1 bytes"),
        (WebSocketProtocolError("invalid opcode: 7"), "invalid opcode: 7"),
    ]
    # fmt: on
    for exc, expected in cases:
        with self.subTest(exception=exc):
            self.assertEqual(str(exc), expected)