Beispiel #1
0
    def _task_table(self, task_id):
        """Look up one task in the task table and decode its entry.

        Args:
            task_id: The ID of the task to look up. Its ``id()`` method is
                called to obtain the value sent with the Redis command and
                shown (hex-encoded) in the error message.

        Returns:
            A dictionary with the keys "State", "LocalSchedulerID" and
                "TaskSpec". TASK_STATUS_MAPPING should be used to parse the
                "State" field into a human-readable string. "TaskSpec" is a
                nested dictionary holding the task's IDs, counters,
                arguments, return object IDs and required resources.

        Raises:
            Exception: If the RAY.TASK_TABLE_GET command returns None for
                this task ID.
            AssertionError: If the task spec does not list exactly three
                required resources.
        """
        reply_bytes = self._execute_command(
            task_id, "RAY.TASK_TABLE_GET", task_id.id())
        if reply_bytes is None:
            message = "There is no entry for task ID {} in the task table."
            raise Exception(message.format(binary_to_hex(task_id.id())))

        reply = TaskReply.GetRootAsTaskReply(reply_bytes, 0)
        spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

        def decode_arg(arg):
            # An argument with an empty object ID carries pickled data;
            # otherwise the object ID itself is the argument.
            object_id = arg.ObjectId()
            if len(object_id) == 0:
                return pickle.loads(arg.Data())
            return binary_to_object_id(object_id)

        args = [decode_arg(spec.Args(i)) for i in range(spec.ArgsLength())]

        # TODO(atumanov): Instead of hard coding these indices, we should use
        # the flatbuffer constants.
        resource_names = ("CPUs", "GPUs", "CustomResource")
        assert spec.RequiredResourcesLength() == 3
        required_resources = {
            name: spec.RequiredResources(index)
            for index, name in enumerate(resource_names)
        }

        return_object_ids = [
            binary_to_object_id(spec.Returns(i))
            for i in range(spec.ReturnsLength())
        ]

        task_spec_info = {
            "DriverID": binary_to_hex(spec.DriverId()),
            "TaskID": binary_to_hex(spec.TaskId()),
            "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
            "ParentCounter": spec.ParentCounter(),
            "ActorID": binary_to_hex(spec.ActorId()),
            "ActorCounter": spec.ActorCounter(),
            "FunctionID": binary_to_hex(spec.FunctionId()),
            "Args": args,
            "ReturnObjectIDs": return_object_ids,
            "RequiredResources": required_resources,
        }

        return {
            "State": reply.State(),
            "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
            "TaskSpec": task_spec_info,
        }
Beispiel #2
0
    def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
        """Collect IDs of control-state entries for a driver from a shard.

        Args:
            driver_id: The ID of the driver. It is compared directly against
                the DriverId() of each task spec found in the shard.
            redis_shard_index: The index of the Redis shard to query.

        Returns:
            Lists of IDs: (returned_object_ids, task_ids, put_objects). The
                first two are relevant to the driver and are safe to delete.
                The last contains all "put" objects in this redis shard; each
                element is an (object_id, corresponding task_id) pair.

        Raises:
            KeyError: If a scanned task-table entry has no b"TaskSpec" field
                or an object-info entry has no b"is_put" / b"task" field.
        """
        # TODO(zongheng): consider adding save & restore functionalities.
        redis = self.state.redis_clients[redis_shard_index]
        task_table_infos = {}  # task id -> TaskInfo messages

        # Scan the task table & filter to get the list of tasks belonging to
        # this driver.  Use a cursor in order not to block the redis shards.
        for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
            entry = redis.hgetall(key)
            task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
            if driver_id != task_info.DriverId():
                # Ignore tasks that aren't from this driver.
                continue
            task_table_infos[task_info.TaskId()] = task_info

        # Get the list of objects returned by these tasks.  Note these might
        # not belong to this redis shard.
        returned_object_ids = []
        for task_info in task_table_infos.values():
            returned_object_ids.extend(
                task_info.Returns(i) for i in range(task_info.ReturnsLength()))

        # Also record all the ray.put()'d objects.
        prefix_length = len(OBJECT_INFO_PREFIX)
        put_objects = []
        for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
            entry = redis.hgetall(key)
            # The hash is read with bytes keys, so its values are bytes too.
            # Comparing against the str "0" is always False on Python 3 and
            # would let non-put objects through; compare against bytes.
            if entry[b"is_put"] == b"0":
                continue
            # Strip the prefix by length rather than with split(): the raw
            # object ID may itself contain the prefix bytes, in which case
            # split()[1] would return a truncated ID.
            object_id = key[prefix_length:]
            put_objects.append((object_id, entry[b"task"]))

        # Return a real list of task IDs (not a dict view) so the result
        # matches the documented contract on Python 3 as well.
        return returned_object_ids, list(task_table_infos), put_objects
Beispiel #3
0
    def _task_table(self, task_id_binary):
        """Look up one task in the task table and decode its entry.

        Args:
            task_id_binary: A string of bytes with the task ID to get
                information about.

        Returns:
            A dictionary with the keys "State" (the task state looked up in
                task_state_mapping), "LocalSchedulerID" (hex-encoded) and
                "TaskSpec" (a nested dictionary holding the task's IDs,
                counters, arguments, return object IDs and required
                resources).

        Raises:
            Exception: If the RAY.TASK_TABLE_GET command returns None for
                this task ID.
            AssertionError: If the task spec does not list exactly two
                required resources.
        """
        reply_bytes = self.redis_client.execute_command(
            "RAY.TASK_TABLE_GET", task_id_binary)
        if reply_bytes is None:
            message = "There is no entry for task ID {} in the task table."
            raise Exception(message.format(binary_to_hex(task_id_binary)))

        reply = TaskReply.GetRootAsTaskReply(reply_bytes, 0)
        spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

        args = []
        for arg in (spec.Args(i) for i in range(spec.ArgsLength())):
            object_id = arg.ObjectId()
            if len(object_id) == 0:
                # No object ID: the argument value is stored pickled.
                args.append(pickle.loads(arg.Data()))
            else:
                args.append(binary_to_object_id(object_id))

        resource_names = ("CPUs", "GPUs")
        assert spec.RequiredResourcesLength() == 2
        required_resources = {
            name: spec.RequiredResources(index)
            for index, name in enumerate(resource_names)
        }

        return_object_ids = [
            binary_to_object_id(spec.Returns(i))
            for i in range(spec.ReturnsLength())
        ]

        task_spec_info = {
            "DriverID": binary_to_hex(spec.DriverId()),
            "TaskID": binary_to_hex(spec.TaskId()),
            "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
            "ParentCounter": spec.ParentCounter(),
            "ActorID": binary_to_hex(spec.ActorId()),
            "ActorCounter": spec.ActorCounter(),
            "FunctionID": binary_to_hex(spec.FunctionId()),
            "Args": args,
            "ReturnObjectIDs": return_object_ids,
            "RequiredResources": required_resources,
        }

        return {
            "State": task_state_mapping[reply.State()],
            "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
            "TaskSpec": task_spec_info,
        }