def test_split_message(self):
    """Test splitting messages.

    ``split_message`` must report how many blocks it produced and yield the
    UTF-8 encoded chunks of the requested size, the last one possibly shorter.
    """
    cases = [
        (2, 4, [b"ab", b"cd", b"ef", b"gh"]),
        (3, 3, [b"abc", b"def", b"gh"]),
    ]
    for size, expected_count, expected_chunks in cases:
        count, chunks = split_message("abcdefgh", size)
        self.assertEqual(count, expected_count)
        self.assertEqual(list(chunks), expected_chunks)
示例#2
0
    async def test_serve(self):
        """Test the serve method of the server.

        Two requests ("greet" and "say_goodbye") are fed through one mock
        websocket. Each must be answered with a header (version 1, one block,
        no files) followed by the handler's reply. The disconnect handler must
        fire once with the connection's id. Serving the same websocket a
        second time must produce a new connection id.
        """
        sent = []
        disconnected = []
        seen_ids = []

        def handle_request(command, data, temp_directory, connection_id):
            # Remember which connection asked, then echo "<command>-<data>!".
            seen_ids.append(connection_id)
            return "".join((command, "-", data, "!")), []

        server = CommunicationEngineServer(
            host=None,
            port=None,
            handle_request=handle_request,
            handle_disconnect=lambda u: disconnected.append(u),
        )

        incoming = [
            encode_header([1, 3, 0, "greet"], LEN_HEADER),
            *split_message("Hello", block_size=2)[1],
            encode_header([1, 2, 0, "say_goodbye"], LEN_HEADER),
            *split_message("Bye", block_size=2)[1],
        ]
        websocket = MockWebsocket(id=12, to_recv=incoming, sent_data=sent)

        await server._serve(websocket, None)

        # Each reply is a header message followed by a single payload block.
        expected_replies = ((0, b"greet-Hello!"), (2, b"say_goodbye-Bye!"))
        for header_at, reply in expected_replies:
            version, num_blocks, num_files = decode_header(
                sent[header_at], LEN_HEADER)
            self.assertEqual(version, 1)
            self.assertEqual(num_blocks, 1)
            self.assertEqual(num_files, 0)
            self.assertEqual(sent[header_at + 1], reply)
        self.assertEqual(disconnected, [seen_ids[0]])

        # Serve the same websocket again: the previous id must not be reused.
        websocket.reset()
        await server._serve(websocket, None)
        self.assertEqual(len(seen_ids), 4)
        self.assertEqual(seen_ids[0], seen_ids[1])
        self.assertNotEqual(seen_ids[1], seen_ids[2])
        self.assertEqual(seen_ids[2], seen_ids[3])
示例#3
0
    async def test_serve(self):
        """Test serve method of the server.

        Feeds one request with two attached files through a mock websocket.
        It then checks:

        - the reply header;
        - the response payload;
        - the encoded file header;
        - the arguments the request handler received;
        - that the server removed its per-connection file-hash state.
        """
        import os
        import tempfile

        request = None
        response = []

        def handle_request(command, data, temp_directory, connection_id):
            nonlocal request
            request = (command, data, temp_directory, connection_id)
            return "response", FILE_PATHS

        s = CommunicationEngineServer(
            host=None,
            port=None,
            handle_request=handle_request,
            handle_disconnect=lambda u: None,
        )
        ws = MockWebsocket(
            id=0,
            to_recv=[
                encode_header([1, 4, 2, "test"], LEN_HEADER),
                *split_message("data", block_size=1)[1],
                encode_header([2, FILES[0]], LEN_FILES_HEADER),
                ("0" * BLOCK_SIZE).encode("utf-8"),
                "0".encode("utf-8"),
                encode_header([2, FILES[1]], LEN_FILES_HEADER),
                ("1" * BLOCK_SIZE).encode("utf-8"),
                ("1" * BLOCK_SIZE).encode("utf-8"),
            ],
            sent_data=response,
        )
        await s._serve(ws, None)
        self.assertEqual(list(s._file_hashes.keys()), [])
        self.assertEqual(list(s._file_hashes.values()), [])
        version, num_blocks, num_files = decode_header(response[0], LEN_HEADER)
        self.assertEqual(version, 1)
        self.assertEqual(num_blocks, 1)
        self.assertEqual(num_files, 1)
        self.assertEqual(response[1], b"response")
        num_blocks, filename = decode_header(response[2], LEN_FILES_HEADER)
        self.assertEqual(num_blocks, 0)
        self.assertEqual(filename, FILES[2])
        self.assertEqual(request[0], "test")
        self.assertEqual(request[1], "data")
        # Assumes the server creates the per-request directory with
        # tempfile.TemporaryDirectory(), which places it directly inside
        # tempfile.gettempdir() with the default "tmp" prefix. Comparing
        # against gettempdir() instead of a hard-coded "/tmp" keeps the check
        # valid when TMPDIR points elsewhere, e.g. an IDE installed as a
        # Flatpak uses /run/user/1000/app/<app-id>/.
        temp_directory = request[2]
        self.assertEqual(os.path.dirname(temp_directory),
                         tempfile.gettempdir())
        self.assertTrue(os.path.basename(temp_directory).startswith("tmp"))
        self.assertIsInstance(request[3], uuid.UUID)
示例#4
0
    async def test_request(self):
        """Test the request method of the client.

        Sends two requests through a mock websocket. For each one it checks:

        - the encoded header (version 1, one block, no files, the command);
        - the raw payload block;
        - that the server's reply was delivered to the response handler.
        """
        requests = []
        responses = []

        def collect(data, temp_directory):
            responses.append(data)

        client = CommunicationEngineClient(uri=None, handle_response=collect)
        client.websocket = MockWebsocket(
            id=7,
            to_recv=[
                encode_header([1, 3, 0], LEN_HEADER),
                *split_message("hello", block_size=2)[1],
                encode_header([1, 2, 0], LEN_HEADER),
                *split_message("bye", block_size=2)[1],
            ],
            sent_data=requests,
        )

        # (command, payload, expected payload block, responses seen so far)
        rounds = (
            ("greet", "Hello", b"Hello", ["hello"]),
            ("say_goodbye", "Bye", b"Bye", ["hello", "bye"]),
        )
        for round_no, (command, payload, block, seen) in enumerate(rounds):
            await client._request(command, payload)
            # Every request adds exactly two messages: header + one block.
            header_at = 2 * round_no
            version, num_blocks, num_files, sent_command = decode_header(
                requests[header_at], LEN_HEADER)
            self.assertEqual(version, 1)
            self.assertEqual(num_blocks, 1)
            self.assertEqual(num_files, 0)
            self.assertEqual(sent_command, command)
            self.assertEqual(requests[header_at + 1], block)
            self.assertEqual(responses, seen)
示例#5
0
    def _encode(self, command, data, files):
        """Encode a request for the server as a sequence of messages.

        This is a generator, so nothing is encoded until it is iterated.

        Args:
            command (str): The command to execute.
            data (str): The JSON data to send.
            files: The files to attach, passed on to ``encode_files``.
                ``None`` (or any other empty value) means no files.

        Yields:
            bytes: The encoded parts, in order:

            1. the header (protocol version 1, number of data blocks, number
               of files and the command);
            2. the data blocks produced by ``split_message``;
            3. the parts produced by ``encode_files``.
        """
        files = files or []
        num_blocks, blocks = split_message(data)
        # Protocol version written into the request header.
        version = 1
        yield encode_header(
            [version, num_blocks, len(files), command], LEN_HEADER)
        yield from blocks
        yield from encode_files(files)
示例#6
0
    async def _serve(self, websocket, _):
        """Wait for requests, compute responses and serve them to the user.

        For each request received on ``websocket``, this:

        1. decodes it into a fresh temporary directory;
        2. passes it to the request handler;
        3. sends back a header, the response blocks and the encoded files.

        A normal close by the client (``ConnectionClosedOK``) ends the loop
        quietly; any other exception propagates. In every case the
        per-connection state is removed and the disconnect handler is called.

        Args:
            websocket (Websocket): The websockets object.
            _ (str): The path of the URI (will be ignored).
        """
        # Reuse the id and file-hash cache already registered for this
        # websocket, if any; otherwise register new ones.
        connection_id = self._connection_ids.setdefault(
            websocket, uuid.uuid4())
        file_hashes = self._file_hashes.setdefault(connection_id, {})
        try:
            while True:
                # The temporary directory is removed again after each request.
                with tempfile.TemporaryDirectory() as temp_dir:
                    command, data = await self._decode(
                        websocket, temp_dir, file_hashes, connection_id)
                    # let session handle the request
                    response, files = self._handle_request(
                        command=command,
                        data=data,
                        temp_directory=temp_dir,
                        connection_id=connection_id,
                    )

                    # Treat a missing file list (None) as "no files" instead
                    # of crashing on len(None) below. filter_files presumably
                    # drops files whose hashes this connection already knows.
                    files = filter_files(files or [], file_hashes)
                    # Pass the values as arguments so string formatting is
                    # deferred to the logging framework.
                    logger.debug("Response: %s with %s files",
                                 response[:DEBUG_MAX], len(files))
                    num_blocks, blocks = split_message(response)
                    await websocket.send(
                        encode_header([VERSION, num_blocks, len(files)],
                                      LEN_HEADER))
                    for part in blocks:
                        await websocket.send(part)
                    for part in encode_files(files):
                        await websocket.send(part)
        except ws_exceptions.ConnectionClosedOK:
            pass
        finally:
            logger.debug("Connection %s closed!", connection_id)
            del self._file_hashes[connection_id]
            del self._connection_ids[websocket]
            self._handle_disconnect(connection_id)
示例#7
0
    async def test_serve(self):
        """Test serve method of the server.

        Feeds one request with two attached files through a mock websocket.
        It then checks:

        - the reply header;
        - the response payload;
        - the encoded file header;
        - the arguments the request handler received;
        - that the server removed its per-connection file-hash state.
        """
        import os
        import tempfile

        request = None
        response = []

        def handle_request(command, data, temp_directory, connection_id):
            nonlocal request
            request = (command, data, temp_directory, connection_id)
            return "response", FILE_PATHS

        s = CommunicationEngineServer(host=None,
                                      port=None,
                                      handle_request=handle_request,
                                      handle_disconnect=lambda u: None)
        ws = MockWebsocket(id=0,
                           to_recv=[
                               encode_header([1, 4, 2, "test"], LEN_HEADER),
                               *split_message("data", block_size=1)[1],
                               encode_header([2, FILES[0]], LEN_FILES_HEADER),
                               ("0" * BLOCK_SIZE).encode("utf-8"),
                               "0".encode("utf-8"),
                               encode_header([2, FILES[1]], LEN_FILES_HEADER),
                               ("1" * BLOCK_SIZE).encode("utf-8"),
                               ("1" * BLOCK_SIZE).encode("utf-8")
                           ],
                           sent_data=response)
        await s._serve(ws, None)
        self.assertEqual(list(s._file_hashes.keys()), [])
        self.assertEqual(list(s._file_hashes.values()), [])
        version, num_blocks, num_files = decode_header(response[0], LEN_HEADER)
        self.assertEqual(version, 1)
        self.assertEqual(num_blocks, 1)
        self.assertEqual(num_files, 1)
        self.assertEqual(response[1], b'response')
        num_blocks, filename = decode_header(response[2], LEN_FILES_HEADER)
        self.assertEqual(num_blocks, 0)
        self.assertEqual(filename, FILES[2])
        self.assertEqual(request[0], "test")
        self.assertEqual(request[1], "data")
        # Assumes the server creates the per-request directory with
        # tempfile.TemporaryDirectory(), which places it directly inside
        # tempfile.gettempdir() with the default "tmp" prefix. Unlike a
        # hard-coded "/tmp/tmp" check on posix only, this also holds when
        # TMPDIR points elsewhere (macOS, Flatpak sandboxes, CI runners) and
        # on non-posix systems.
        temp_directory = request[2]
        self.assertEqual(os.path.dirname(temp_directory),
                         tempfile.gettempdir())
        self.assertTrue(os.path.basename(temp_directory).startswith("tmp"))
        self.assertIsInstance(request[3], uuid.UUID)