def testTaskTableSubscribe(self):
    """Check that a task table add is published to every matching pattern.

    Three pattern subscriptions are registered (any task, tasks in a given
    scheduling state, tasks for a given local scheduler), one task is added,
    and the test expects one notification per subscription describing that
    task.
    """
    scheduling_state = 1
    local_scheduler_id = "local_scheduler_id"
    encoded_scheduler_id = local_scheduler_id.encode("ascii")

    # Subscribe to the task table. The order matters: the acknowledgement
    # payloads checked below are expected to be 1, 2 and 3 in turn.
    pubsub = self.redis.pubsub()
    patterns = (
        "{prefix}*:*".format(prefix=TASK_PREFIX),
        "{prefix}*:{state}".format(
            prefix=TASK_PREFIX, state=scheduling_state),
        "{prefix}{local_scheduler_id}:*".format(
            prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id),
    )
    for pattern in patterns:
        pubsub.psubscribe(pattern)

    self.redis.execute_command(
        "RAY.TASK_TABLE_ADD", b"task_id", scheduling_state,
        encoded_scheduler_id, b"task_spec")

    # Receive the acknowledgement messages, one per subscription.
    for expected_ack in (1, 2, 3):
        self.assertEqual(get_next_message(pubsub)["data"], expected_ack)

    # Receive the actual data: one notification for each subscription.
    for _ in range(3):
        payload = get_next_message(pubsub)["data"]
        # Check that the notification object is correct.
        reply = TaskReply.GetRootAsTaskReply(payload, 0)
        self.assertEqual(reply.TaskId(), b"task_id")
        self.assertEqual(reply.State(), scheduling_state)
        self.assertEqual(reply.LocalSchedulerId(), encoded_scheduler_id)
        self.assertEqual(reply.TaskSpec(), b"task_spec")
def cleanup_task_table(self):
    """Clean up global state for failed local schedulers.

    This marks any tasks that were scheduled on dead local schedulers as
    TASK_STATUS_LOST. A local scheduler is deemed dead if it is not in
    self.local_schedulers.

    Only tasks whose update was acknowledged with b"OK" are counted in the
    final "Marked ... tasks as lost" log message; failed updates are logged
    individually and not counted.
    """
    task_keys = self.redis.scan_iter(match="{prefix}*".format(
        prefix=TASK_PREFIX))
    num_tasks_updated = 0
    for task_key in task_keys:
        # Strip the table prefix to recover the raw task ID.
        task_id = task_key[len(TASK_PREFIX):]
        response = self.redis.execute_command("RAY.TASK_TABLE_GET",
                                              task_id)
        if response is None:
            # There is no entry to parse for this key (for example, it was
            # removed after the scan returned it), so there is nothing to
            # mark as lost.
            continue
        # Parse the serialized task object.
        task_object = TaskReply.GetRootAsTaskReply(response, 0)
        local_scheduler_id = task_object.LocalSchedulerId()
        # See if the corresponding local scheduler is alive.
        if local_scheduler_id in self.local_schedulers:
            continue
        ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
                                        task_id,
                                        TASK_STATUS_LOST,
                                        NIL_ID)
        if ok != b"OK":
            log.warn("Failed to update lost task for dead scheduler.")
            continue
        # Count the task only once the update has been acknowledged.
        num_tasks_updated += 1

    if num_tasks_updated > 0:
        log.warn("Marked {} tasks as lost.".format(num_tasks_updated))
def tearDown(self):
    """Verify that every node was used, then shut the Ray cluster down."""
    self.assertTrue(ray.services.all_processes_alive())

    # Determine the IDs of all local schedulers that had a task scheduled or
    # submitted.
    client = redis.StrictRedis(port=self.redis_port)
    seen_local_scheduler_ids = set()
    for key in client.keys("TT:*"):
        # Drop the three-character "TT:" prefix to get the task ID.
        reply_bytes = client.execute_command("ray.task_table_get", key[3:])
        reply = TaskReply.GetRootAsTaskReply(reply_bytes, 0)
        seen_local_scheduler_ids.add(reply.LocalSchedulerId())

    # Make sure that all nodes in the cluster were used by checking that the
    # number of distinct local scheduler IDs that had a task scheduled or
    # submitted matches the number of local schedulers started. We add one
    # to account for NIL_LOCAL_SCHEDULER_ID, which is the local scheduler ID
    # associated with the driver task, since it is not scheduled by a
    # particular local scheduler.
    self.assertEqual(len(seen_local_scheduler_ids),
                     self.num_local_schedulers + 1)

    # Clean up the Ray cluster.
    ray.worker.cleanup()
def check_task_reply(message, task_args):
    """Assert that a serialized task reply matches the expected fields.

    Relies on `self` from the enclosing scope for the assertions.

    Args:
        message: The serialized TaskReply to parse.
        task_args: A sequence of exactly three items: the expected state,
            local scheduler ID and task spec, in that order.
    """
    expected_state, expected_scheduler_id, expected_spec = task_args
    reply = TaskReply.GetRootAsTaskReply(message, 0)
    expectations = (
        (reply.State, expected_state),
        (reply.LocalSchedulerId, expected_scheduler_id),
        (reply.TaskSpec, expected_spec),
    )
    for read_field, expected in expectations:
        self.assertEqual(read_field(), expected)
def check_task_reply(message, task_args, updated=False):
    """Assert that a serialized task reply matches the expected fields.

    Relies on `self` from the enclosing scope for the assertions.

    Args:
        message: The serialized TaskReply to parse.
        task_args: A sequence of exactly four items: the expected state,
            local scheduler ID, execution dependencies string and task spec,
            in that order. The execution dependencies string is unpacked but
            not compared against the reply.
        updated: The expected value of the reply's Updated() field.
    """
    (expected_state, expected_scheduler_id, _execution_dependencies,
     expected_spec) = task_args
    reply = TaskReply.GetRootAsTaskReply(message, 0)
    expectations = (
        (reply.State, expected_state),
        (reply.LocalSchedulerId, expected_scheduler_id),
        (reply.TaskSpec, expected_spec),
        (reply.Updated, updated),
    )
    for read_field, expected in expectations:
        self.assertEqual(read_field(), expected)
def _task_table(self, task_id):
    """Fetch and parse the task table information for a single task ID.

    Args:
        task_id: The task ID to get information about. Its id() method
            supplies the binary form that is sent to Redis.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID" and
            "TaskSpec". TASK_STATUS_MAPPING should be used to parse the
            "State" field into a human-readable string.

    Raises:
        Exception: If the task table has no entry for the task ID.
    """
    response = self._execute_command(task_id,
                                     "RAY.TASK_TABLE_GET",
                                     task_id.id())
    if response is None:
        raise Exception("There is no entry for task ID {} in the task "
                        "table.".format(binary_to_hex(task_id.id())))

    reply = TaskReply.GetRootAsTaskReply(response, 0)
    spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

    def decode_argument(arg):
        # An argument is passed either by object ID or as pickled data.
        # NOTE(review): pickle.loads runs on bytes read from the task
        # table -- confirm this data is always trusted.
        if len(arg.ObjectId()) != 0:
            return binary_to_object_id(arg.ObjectId())
        return pickle.loads(arg.Data())

    args = [decode_argument(spec.Args(i)) for i in range(spec.ArgsLength())]

    # TODO(atumanov): Instead of hard coding these indices, we should use
    # the flatbuffer constants.
    assert spec.RequiredResourcesLength() == 3
    required_resources = {
        "CPUs": spec.RequiredResources(0),
        "GPUs": spec.RequiredResources(1),
        "CustomResource": spec.RequiredResources(2),
    }

    task_spec_info = {
        "DriverID": binary_to_hex(spec.DriverId()),
        "TaskID": binary_to_hex(spec.TaskId()),
        "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
        "ParentCounter": spec.ParentCounter(),
        "ActorID": binary_to_hex(spec.ActorId()),
        "ActorCounter": spec.ActorCounter(),
        "FunctionID": binary_to_hex(spec.FunctionId()),
        "Args": args,
        "ReturnObjectIDs": [
            binary_to_object_id(spec.Returns(i))
            for i in range(spec.ReturnsLength())
        ],
        "RequiredResources": required_resources,
    }

    return {
        "State": reply.State(),
        "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
        "TaskSpec": task_spec_info,
    }
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
    """Add a task and check the notification received for it.

    Args:
        p: The subscription object to read the next message from.
        scheduling_state: The scheduling state to add the task with.
        local_scheduler_id: The local scheduler ID, as a string; it is
            ASCII-encoded before being sent to Redis.
    """
    encoded_scheduler_id = local_scheduler_id.encode("ascii")
    self.redis.execute_command(
        "RAY.TASK_TABLE_ADD", b"task_id", scheduling_state,
        encoded_scheduler_id, b"task_spec")

    # Receive the data.
    payload = get_next_message(p)["data"]

    # Check that the notification object is correct.
    reply = TaskReply.GetRootAsTaskReply(payload, 0)
    self.assertEqual(reply.TaskId(), b"task_id")
    self.assertEqual(reply.State(), scheduling_state)
    self.assertEqual(reply.LocalSchedulerId(), encoded_scheduler_id)
    self.assertEqual(reply.TaskSpec(), b"task_spec")
def _task_table(self, task_id):
    """Fetch and parse the task table information for a single task ID.

    Args:
        task_id: The task ID to get information about. Its id() method
            supplies the binary form that is sent to Redis.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID",
            "ExecutionDependenciesString", "SpillbackCount" and "TaskSpec".
            TASK_STATUS_MAPPING should be used to parse the "State" field
            into a human-readable string.

    Raises:
        Exception: If the task table has no entry for the task ID.
    """
    response = self._execute_command(task_id,
                                     "RAY.TASK_TABLE_GET",
                                     task_id.id())
    if response is None:
        raise Exception("There is no entry for task ID {} in the task "
                        "table.".format(binary_to_hex(task_id.id())))

    reply = TaskReply.GetRootAsTaskReply(response, 0)
    serialized_spec = reply.TaskSpec()
    spec = ray.local_scheduler.task_from_string(serialized_spec)

    task_spec_info = {
        "DriverID": binary_to_hex(spec.driver_id().id()),
        "TaskID": binary_to_hex(spec.task_id().id()),
        "ParentTaskID": binary_to_hex(spec.parent_task_id().id()),
        "ParentCounter": spec.parent_counter(),
        "ActorID": binary_to_hex(spec.actor_id().id()),
        "ActorCounter": spec.actor_counter(),
        "FunctionID": binary_to_hex(spec.function_id().id()),
        "Args": spec.arguments(),
        "ReturnObjectIDs": spec.returns(),
        "RequiredResources": spec.required_resources(),
    }

    return {
        "State": reply.State(),
        "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
        "ExecutionDependenciesString": reply.ExecutionDependencies(),
        "SpillbackCount": reply.SpillbackCount(),
        "TaskSpec": task_spec_info,
    }
def tearDown(self):
    """Verify that every node was used, then shut the Ray cluster down."""
    self.assertTrue(ray.services.all_processes_alive())

    # Make sure that all nodes in the cluster were used by checking where
    # tasks were scheduled and/or submitted from.
    client = redis.StrictRedis(port=self.redis_port)
    seen_local_scheduler_ids = set()
    for key in client.keys("TT:*"):
        # Drop the three-character "TT:" prefix to get the task ID.
        reply_bytes = client.execute_command("ray.task_table_get", key[3:])
        reply = TaskReply.GetRootAsTaskReply(reply_bytes, 0)
        seen_local_scheduler_ids.add(reply.LocalSchedulerId())

    self.assertEqual(len(seen_local_scheduler_ids),
                     self.num_local_schedulers)

    # Clean up the Ray cluster.
    ray.worker.cleanup()
def _task_table(self, task_id):
    """Fetch and parse the task table information for a single task ID.

    Args:
        task_id: The task ID to get information about. Its id() method
            supplies the binary form that is sent to Redis.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID",
            "ExecutionDependenciesString", "ExecutionDependencies",
            "SpillbackCount" and "TaskSpec". TASK_STATUS_MAPPING should be
            used to parse the "State" field into a human-readable string.

    Raises:
        Exception: If the task table has no entry for the task ID.
    """
    response = self._execute_command(task_id,
                                     "RAY.TASK_TABLE_GET",
                                     task_id.id())
    if response is None:
        raise Exception("There is no entry for task ID {} in the task "
                        "table.".format(binary_to_hex(task_id.id())))

    reply = TaskReply.GetRootAsTaskReply(response, 0)
    serialized_spec = reply.TaskSpec()
    spec = ray.local_scheduler.task_from_string(serialized_spec)

    task_spec_info = {
        "DriverID": binary_to_hex(spec.driver_id().id()),
        "TaskID": binary_to_hex(spec.task_id().id()),
        "ParentTaskID": binary_to_hex(spec.parent_task_id().id()),
        "ParentCounter": spec.parent_counter(),
        "ActorID": binary_to_hex(spec.actor_id().id()),
        "ActorCreationID": binary_to_hex(spec.actor_creation_id().id()),
        "ActorCreationDummyObjectID": binary_to_hex(
            spec.actor_creation_dummy_object_id().id()),
        "ActorCounter": spec.actor_counter(),
        "FunctionID": binary_to_hex(spec.function_id().id()),
        "Args": spec.arguments(),
        "ReturnObjectIDs": spec.returns(),
        "RequiredResources": spec.required_resources(),
    }

    # Decode the serialized execution dependencies into object IDs.
    dependencies_message = (
        TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
            reply.ExecutionDependencies(), 0))
    execution_dependencies = []
    for i in range(dependencies_message.ExecutionDependenciesLength()):
        execution_dependencies.append(
            ray.local_scheduler.ObjectID(
                dependencies_message.ExecutionDependencies(i)))

    # TODO(rkn): The return fields ExecutionDependenciesString and
    # ExecutionDependencies are redundant, so we should remove
    # ExecutionDependencies. However, it is currently used in monitor.py.
    return {
        "State": reply.State(),
        "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
        "ExecutionDependenciesString": reply.ExecutionDependencies(),
        "ExecutionDependencies": execution_dependencies,
        "SpillbackCount": reply.SpillbackCount(),
        "TaskSpec": task_spec_info,
    }
def _task_table(self, task_id_binary):
    """Fetch and parse the task table information for a single task ID.

    Args:
        task_id_binary: A string of bytes with the task ID to get
            information about.

    Returns:
        A dictionary with the keys "State", "LocalSchedulerID" and
            "TaskSpec". The "State" value is the reply's state looked up in
            task_state_mapping.

    Raises:
        Exception: If the task table has no entry for the task ID.
    """
    response = self.redis_client.execute_command("RAY.TASK_TABLE_GET",
                                                 task_id_binary)
    if response is None:
        raise Exception(
            "There is no entry for task ID {} in the task table.".format(
                binary_to_hex(task_id_binary)))

    reply = TaskReply.GetRootAsTaskReply(response, 0)
    spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

    def decode_argument(arg):
        # An argument is passed either by object ID or as pickled data.
        # NOTE(review): pickle.loads runs on bytes read from the task
        # table -- confirm this data is always trusted.
        if len(arg.ObjectId()) != 0:
            return binary_to_object_id(arg.ObjectId())
        return pickle.loads(arg.Data())

    args = [decode_argument(spec.Args(i)) for i in range(spec.ArgsLength())]

    # The resource vector is read by fixed index: 0 is CPUs, 1 is GPUs.
    assert spec.RequiredResourcesLength() == 2
    required_resources = {
        "CPUs": spec.RequiredResources(0),
        "GPUs": spec.RequiredResources(1),
    }

    task_spec_info = {
        "DriverID": binary_to_hex(spec.DriverId()),
        "TaskID": binary_to_hex(spec.TaskId()),
        "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
        "ParentCounter": spec.ParentCounter(),
        "ActorID": binary_to_hex(spec.ActorId()),
        "ActorCounter": spec.ActorCounter(),
        "FunctionID": binary_to_hex(spec.FunctionId()),
        "Args": args,
        "ReturnObjectIDs": [
            binary_to_object_id(spec.Returns(i))
            for i in range(spec.ReturnsLength())
        ],
        "RequiredResources": required_resources,
    }

    return {
        "State": task_state_mapping[reply.State()],
        "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
        "TaskSpec": task_spec_info,
    }