def test_split_message(self):
    """Check that split_message chunks a string into byte blocks.

    One case divides evenly (every block is full); the other leaves a shorter
    final block.
    """
    cases = [
        (2, 4, [b"ab", b"cd", b"ef", b"gh"]),
        (3, 3, [b"abc", b"def", b"gh"]),
    ]
    for size, expected_count, expected_blocks in cases:
        count, blocks = split_message("abcdefgh", size)
        self.assertEqual(count, expected_count)
        self.assertEqual(list(blocks), expected_blocks)
async def test_serve(self):
    """Test the serve method of the server.

    Two requests are sent over one connection. The test checks:

    - each request gets its own response header and body;
    - both requests share one connection id;
    - the disconnect handler fires when the connection ends.

    A second session on the reset websocket must get a fresh connection id
    and must report its own disconnect.
    """
    response = []
    disconnects = []
    connection_ids = []

    def handle_request(command, data, temp_directory, connection_id):
        # The list is only mutated, never rebound, so no ``nonlocal`` needed.
        connection_ids.append(connection_id)
        return command + "-" + data + "!", []

    server = CommunicationEngineServer(
        host=None,
        port=None,
        handle_request=handle_request,
        handle_disconnect=lambda u: disconnects.append(u),
    )
    websocket = MockWebsocket(
        id=12,
        to_recv=[
            # "Hello" in 2-byte blocks -> "He", "ll", "o" (3 blocks)
            encode_header([1, 3, 0, "greet"], LEN_HEADER),
            *split_message("Hello", block_size=2)[1],
            # "Bye" in 2-byte blocks -> "By", "e" (2 blocks)
            encode_header([1, 2, 0, "say_goodbye"], LEN_HEADER),
            *split_message("Bye", block_size=2)[1],
        ],
        sent_data=response,
    )
    await server._serve(websocket, None)

    version, num_blocks, num_files = decode_header(response[0], LEN_HEADER)
    self.assertEqual(version, 1)
    self.assertEqual(num_blocks, 1)
    self.assertEqual(num_files, 0)
    self.assertEqual(response[1], b"greet-Hello!")
    version, num_blocks, num_files = decode_header(response[2], LEN_HEADER)
    self.assertEqual(version, 1)
    self.assertEqual(num_blocks, 1)
    self.assertEqual(num_files, 0)
    self.assertEqual(response[3], b"say_goodbye-Bye!")
    self.assertEqual(disconnects, [connection_ids[0]])

    websocket.reset()
    await server._serve(websocket, None)
    self.assertEqual(len(connection_ids), 4)
    self.assertEqual(connection_ids[0], connection_ids[1])
    self.assertNotEqual(connection_ids[1], connection_ids[2])
    self.assertEqual(connection_ids[2], connection_ids[3])
    # The second session must report its own disconnect, with its own id.
    self.assertEqual(disconnects, [connection_ids[0], connection_ids[2]])
async def test_serve(self):
    """Test serve method of the server with files attached to the request.

    The mock client sends a ``test`` command whose data is split into four
    one-byte blocks, followed by two files of two blocks each. The test checks
    that:

    - the per-connection file hashes are removed once the connection closes;
    - the response header, body and file header are sent as expected;
    - the request handler receives the command, the data, a temporary
      directory and a UUID connection id.
    """
    import os
    import tempfile

    request = None
    response = []

    def handle_request(command, data, temp_directory, connection_id):
        nonlocal request  # rebinds the outer name, so nonlocal is required
        request = (command, data, temp_directory, connection_id)
        return "response", FILE_PATHS

    s = CommunicationEngineServer(
        host=None,
        port=None,
        handle_request=handle_request,
        handle_disconnect=lambda u: None,
    )
    ws = MockWebsocket(
        id=0,
        to_recv=[
            encode_header([1, 4, 2, "test"], LEN_HEADER),
            *split_message("data", block_size=1)[1],
            encode_header([2, FILES[0]], LEN_FILES_HEADER),
            ("0" * BLOCK_SIZE).encode("utf-8"),
            "0".encode("utf-8"),
            encode_header([2, FILES[1]], LEN_FILES_HEADER),
            ("1" * BLOCK_SIZE).encode("utf-8"),
            ("1" * BLOCK_SIZE).encode("utf-8"),
        ],
        sent_data=response,
    )
    await s._serve(ws, None)

    self.assertEqual(list(s._file_hashes.keys()), [])
    self.assertEqual(list(s._file_hashes.values()), [])

    version, num_blocks, num_files = decode_header(response[0], LEN_HEADER)
    self.assertEqual(version, 1)
    self.assertEqual(num_blocks, 1)
    self.assertEqual(num_files, 1)
    self.assertEqual(response[1], b"response")
    num_blocks, filename = decode_header(response[2], LEN_FILES_HEADER)
    self.assertEqual(num_blocks, 0)
    self.assertEqual(filename, FILES[2])

    self.assertEqual(request[0], "test")
    self.assertEqual(request[1], "data")
    # The server creates the handler's working directory with
    # ``tempfile.TemporaryDirectory()`` and no explicit ``dir``, so it sits
    # directly inside ``tempfile.gettempdir()``. Checking the parent directory
    # instead of a hard-coded "/tmp/tmp" prefix also holds on macOS, with
    # TMPDIR set, and in sandboxes such as a Flatpak IDE
    # (e.g. /run/user/1000/app/com.jetbrains.PyCharm-Community/tmpgdjc38pd).
    self.assertEqual(os.path.dirname(request[2]), tempfile.gettempdir())
    self.assertIsInstance(request[3], uuid.UUID)
async def test_request(self):
    """Check that the client encodes requests and forwards responses.

    The mock websocket is preloaded with two server replies. Two round trips
    are made; after each one, the outgoing header and body are decoded and
    the responses handed to the callback so far are compared.
    """
    received = []
    sent = []

    def on_response(data, temp_directory):
        received.append(data)

    client = CommunicationEngineClient(uri=None, handle_response=on_response)
    client.websocket = MockWebsocket(
        id=7,
        to_recv=[
            encode_header([1, 3, 0], LEN_HEADER),
            *split_message("hello", block_size=2)[1],
            encode_header([1, 2, 0], LEN_HEADER),
            *split_message("bye", block_size=2)[1],
        ],
        sent_data=sent,
    )

    round_trips = [
        ("greet", "Hello", "hello"),
        ("say_goodbye", "Bye", "bye"),
    ]
    expected_responses = []
    for index, (command, payload, reply) in enumerate(round_trips):
        await client._request(command, payload)
        # Each request produces a header message followed by one data block.
        header_pos = 2 * index
        version, num_blocks, num_files, sent_command = decode_header(
            sent[header_pos], LEN_HEADER)
        self.assertEqual(version, 1)
        self.assertEqual(num_blocks, 1)
        self.assertEqual(num_files, 0)
        self.assertEqual(sent_command, command)
        self.assertEqual(sent[header_pos + 1], payload.encode())
        expected_responses.append(reply)
        self.assertEqual(received, expected_responses)
def _encode(self, command, data, files=None):
    """Encode a request for the server as a sequence of byte messages.

    Args:
        command (str): The command to execute.
        data (str): The json data to send.
        files (list | None): Files to send along with the request (passed
            to ``encode_files``). ``None`` is treated as "no files".

    Yields:
        bytes: First the request header (version, number of data blocks,
        number of files, command), then each block of ``data``, then the
        encoded files.
    """
    files = files or []
    num_blocks, blocks = split_message(data)
    version = 1
    yield encode_header(
        [version, num_blocks, len(files), command], LEN_HEADER)
    yield from blocks
    yield from encode_files(files)
async def _serve(self, websocket, _):
    """Wait for requests, compute responses and serve them to the user.

    Each request is decoded inside a fresh temporary directory and passed to
    the request handler. The reply consists of a header (version, number of
    response blocks, number of files), the response blocks and the encoded
    files that remain after ``filter_files``.

    A clean close (``ConnectionClosedOK``) ends the loop normally. Any other
    exception propagates, but only after the cleanup in ``finally`` has run.
    On exit the per-connection state is removed and the disconnect handler
    is called with the connection id.

    Args:
        websocket (Websocket): The websockets object.
        _ (str): The path of the URI (will be ignored).
    """
    # Reuse the id (and its file-hash cache) if this websocket is known.
    connection_id = self._connection_ids.setdefault(websocket, uuid.uuid4())
    file_hashes = self._file_hashes.setdefault(connection_id, {})
    try:
        while True:
            with tempfile.TemporaryDirectory() as temp_dir:
                command, data = await self._decode(
                    websocket, temp_dir, file_hashes, connection_id)

                # let session handle the request
                response, files = self._handle_request(
                    command=command,
                    data=data,
                    temp_directory=temp_dir,
                    connection_id=connection_id,
                )

                # Presumably drops files the client already has according to
                # file_hashes. A missing file list is treated like an empty
                # one, as in the client's encoder.
                files = filter_files(files or [], file_hashes)
                logger.debug("Response: %s with %s files",
                             response[:DEBUG_MAX], len(files))

                # send the response
                num_blocks, blocks = split_message(response)
                await websocket.send(
                    encode_header([VERSION, num_blocks, len(files)],
                                  LEN_HEADER))
                for part in blocks:
                    await websocket.send(part)
                for part in encode_files(files):
                    await websocket.send(part)
    except ws_exceptions.ConnectionClosedOK:
        pass
    finally:
        logger.debug("Connection %s closed!", connection_id)
        del self._file_hashes[connection_id]
        del self._connection_ids[websocket]
        self._handle_disconnect(connection_id)
async def test_serve(self):
    """Test serve method of the server.

    The mock client sends a four-block ``test`` request with two attached
    files. The test verifies:

    - the per-connection file hashes are cleared after the connection closes;
    - the response header, body and file header match expectations;
    - the handler is called with the command, the data, a temporary
      directory and a UUID connection id.
    """
    import os
    import tempfile

    request = None
    response = []

    def handle_request(command, data, temp_directory, connection_id):
        nonlocal request  # rebinds the outer name, so nonlocal is required
        request = (command, data, temp_directory, connection_id)
        return "response", FILE_PATHS

    s = CommunicationEngineServer(host=None, port=None,
                                  handle_request=handle_request,
                                  handle_disconnect=lambda u: None)
    ws = MockWebsocket(id=0, to_recv=[
        encode_header([1, 4, 2, "test"], LEN_HEADER),
        *split_message("data", block_size=1)[1],
        encode_header([2, FILES[0]], LEN_FILES_HEADER),
        ("0" * BLOCK_SIZE).encode("utf-8"),
        "0".encode("utf-8"),
        encode_header([2, FILES[1]], LEN_FILES_HEADER),
        ("1" * BLOCK_SIZE).encode("utf-8"),
        ("1" * BLOCK_SIZE).encode("utf-8"),
    ], sent_data=response)
    await s._serve(ws, None)

    self.assertEqual(list(s._file_hashes.keys()), [])
    self.assertEqual(list(s._file_hashes.values()), [])

    version, num_blocks, num_files = decode_header(response[0], LEN_HEADER)
    self.assertEqual(version, 1)
    self.assertEqual(num_blocks, 1)
    self.assertEqual(num_files, 1)
    self.assertEqual(response[1], b'response')
    num_blocks, filename = decode_header(response[2], LEN_FILES_HEADER)
    self.assertEqual(num_blocks, 0)
    self.assertEqual(filename, FILES[2])

    self.assertEqual(request[0], "test")
    self.assertEqual(request[1], "data")
    # The handler's working directory comes from
    # ``tempfile.TemporaryDirectory()`` without an explicit ``dir``, so its
    # parent is ``tempfile.gettempdir()``. A hard-coded "/tmp/tmp" prefix
    # guarded by ``os.name == "posix"`` fails on macOS (posix, but the temp
    # dir lives under /var/folders), when TMPDIR is set, and in sandboxed
    # IDEs such as Flatpak.
    self.assertEqual(os.path.dirname(request[2]), tempfile.gettempdir())
    self.assertIsInstance(request[3], uuid.UUID)