def _task_table(self, task_id):
    """Look up one task in the task table and decode it into plain dicts.

    Args:
        task_id: The ID of the task to look up. It must provide an
            ``id()`` method; the value returned by that method is what is
            sent with the RAY.TASK_TABLE_GET command and what is shown
            (hex-encoded) in the error message.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID" and
        "TaskSpec". "State" is the raw value taken from the task table
        reply; TASK_STATUS_MAPPING (defined outside this block) should be
        used to turn it into a human-readable string. "TaskSpec" is a
        nested dictionary describing the task specification (IDs,
        counters, arguments, return object IDs and required resources).

    Raises:
        Exception: If the task table lookup returns None, i.e. there is
            no entry for the given task ID.
    """
    response = self._execute_command(task_id, "RAY.TASK_TABLE_GET",
                                     task_id.id())
    if response is None:
        message = "There is no entry for task ID {} in the task table."
        raise Exception(message.format(binary_to_hex(task_id.id())))

    reply = TaskReply.GetRootAsTaskReply(response, 0)
    spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

    def decode_argument(index):
        # An argument carries either an object ID (passed by reference) or
        # pickled data (passed by value); an empty object ID selects the
        # latter.
        # NOTE(review): pickle.loads runs on bytes read back from the task
        # table -- confirm that store is trusted.
        argument = spec.Args(index)
        if len(argument.ObjectId()) == 0:
            return pickle.loads(argument.Data())
        return binary_to_object_id(argument.ObjectId())

    arguments = [decode_argument(index)
                 for index in range(spec.ArgsLength())]

    # TODO(atumanov): Instead of hard coding these indices, we should use
    # the flatbuffer constants.
    assert spec.RequiredResourcesLength() == 3
    resources = {}
    for index, label in enumerate(("CPUs", "GPUs", "CustomResource")):
        resources[label] = spec.RequiredResources(index)

    spec_info = {
        "DriverID": binary_to_hex(spec.DriverId()),
        "TaskID": binary_to_hex(spec.TaskId()),
        "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
        "ParentCounter": spec.ParentCounter(),
        "ActorID": binary_to_hex(spec.ActorId()),
        "ActorCounter": spec.ActorCounter(),
        "FunctionID": binary_to_hex(spec.FunctionId()),
        "Args": arguments,
    }
    return_ids = []
    for index in range(spec.ReturnsLength()):
        return_ids.append(binary_to_object_id(spec.Returns(index)))
    spec_info["ReturnObjectIDs"] = return_ids
    spec_info["RequiredResources"] = resources

    result = {"State": reply.State()}
    result["LocalSchedulerID"] = binary_to_hex(reply.LocalSchedulerId())
    result["TaskSpec"] = spec_info
    return result
def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
    """Collect IDs of control-state entries for a driver from a shard.

    Args:
        driver_id: The ID of the driver. It is compared directly against
            the value returned by ``TaskInfo.DriverId()``.
        redis_shard_index: The index of the Redis shard to query (an index
            into ``self.state.redis_clients``).

    Returns:
        A tuple of three lists:
        (returned_object_ids, task_ids, put_objects). The first two are
        relevant to the driver and are safe to delete. The last contains
        all "put" objects in this redis shard; each element is an
        (object_id, corresponding task_id) pair.
    """
    # TODO(zongheng): consider adding save & restore functionalities.
    redis = self.state.redis_clients[redis_shard_index]
    task_table_infos = {}  # task id -> TaskInfo messages

    # Scan the task table & filter to get the list of tasks that belong to
    # this driver. Use a cursor in order not to block the redis shards.
    for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
        entry = redis.hgetall(key)
        task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
        if driver_id != task_info.DriverId():
            # Ignore tasks that aren't from this driver.
            continue
        task_table_infos[task_info.TaskId()] = task_info

    # Get the list of objects returned by these tasks. Note these might
    # not belong to this redis shard.
    returned_object_ids = []
    for task_info in task_table_infos.values():
        returned_object_ids.extend(
            task_info.Returns(i) for i in range(task_info.ReturnsLength()))

    # Also record all the ray.put()'d objects.
    put_objects = []
    for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
        entry = redis.hgetall(key)
        # The hash is read with bytes keys, so its values are bytes too.
        # Comparing against the str "0" is always False on Python 3, which
        # would let every object through as a "put" object.
        if entry[b"is_put"] == b"0":
            continue
        # Strip the prefix by length rather than splitting on it: a binary
        # object ID may itself contain the prefix bytes, and splitting
        # would then truncate the ID.
        object_id = key[len(OBJECT_INFO_PREFIX):]
        task_id = entry[b"task"]
        put_objects.append((object_id, task_id))

    # Return a real list (not a dict view on Python 3) so the result
    # matches the documented contract and does not alias the local dict.
    return returned_object_ids, list(task_table_infos), put_objects
def _task_table(self, task_id_binary):
    """Fetch and decode the task table entry for a single task ID.

    Args:
        task_id_binary: A string of bytes with the task ID to get
            information about.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID" and
        "TaskSpec". "State" is the reply's state value looked up in
        task_state_mapping (defined outside this block). "TaskSpec" is a
        nested dictionary describing the task specification (IDs,
        counters, arguments, return object IDs and required resources).

    Raises:
        Exception: If the RAY.TASK_TABLE_GET command returns None, i.e.
            there is no entry for the given task ID.
    """
    raw_reply = self.redis_client.execute_command("RAY.TASK_TABLE_GET",
                                                  task_id_binary)
    if raw_reply is None:
        not_found = "There is no entry for task ID {} in the task table."
        raise Exception(not_found.format(binary_to_hex(task_id_binary)))

    reply = TaskReply.GetRootAsTaskReply(raw_reply, 0)
    spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

    # Each argument is either an object ID (by reference) or pickled data
    # (by value); a non-empty object ID selects the former.
    # NOTE(review): pickle.loads runs on bytes read back from the task
    # table -- confirm that store is trusted.
    parsed_args = []
    for position in range(spec.ArgsLength()):
        entry = spec.Args(position)
        if len(entry.ObjectId()) != 0:
            value = binary_to_object_id(entry.ObjectId())
        else:
            value = pickle.loads(entry.Data())
        parsed_args.append(value)

    # Exactly two resource slots are expected: index 0 is reported as
    # "CPUs" and index 1 as "GPUs".
    assert spec.RequiredResourcesLength() == 2
    resources = dict(zip(
        ("CPUs", "GPUs"),
        (spec.RequiredResources(0), spec.RequiredResources(1)),
    ))

    spec_info = {
        "DriverID": binary_to_hex(spec.DriverId()),
        "TaskID": binary_to_hex(spec.TaskId()),
        "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
        "ParentCounter": spec.ParentCounter(),
        "ActorID": binary_to_hex(spec.ActorId()),
        "ActorCounter": spec.ActorCounter(),
        "FunctionID": binary_to_hex(spec.FunctionId()),
        "Args": parsed_args,
    }
    return_ids = []
    for position in range(spec.ReturnsLength()):
        return_ids.append(binary_to_object_id(spec.Returns(position)))
    spec_info["ReturnObjectIDs"] = return_ids
    spec_info["RequiredResources"] = resources

    summary = {"State": task_state_mapping[reply.State()]}
    summary["LocalSchedulerID"] = binary_to_hex(reply.LocalSchedulerId())
    summary["TaskSpec"] = spec_info
    return summary