Exemple #1
0
 def testTaskTableSubscribe(self):
   """Check that one RAY.TASK_TABLE_ADD reaches three pattern subscriptions.

   Subscribes with a catch-all pattern, a pattern keyed on the scheduling
   state, and a pattern keyed on the local scheduler ID. After adding a
   single task, it expects three acknowledgement messages (data 1, 2, 3).
   It then expects three notifications, each describing the added task.
   """
   scheduling_state = 1
   local_scheduler_id = "local_scheduler_id"
   encoded_scheduler_id = local_scheduler_id.encode("ascii")
   patterns = [
       "{prefix}*:*".format(prefix=TASK_PREFIX),
       "{prefix}*:{state}".format(
           prefix=TASK_PREFIX, state=scheduling_state),
       "{prefix}{local_scheduler_id}:*".format(
           prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id),
   ]
   # Register the pattern subscriptions in order.
   pubsub = self.redis.pubsub()
   for pattern in patterns:
     pubsub.psubscribe(pattern)
   self.redis.execute_command("RAY.TASK_TABLE_ADD", b"task_id",
                              scheduling_state, encoded_scheduler_id,
                              b"task_spec")
   # Acknowledgements come first, carrying data 1, 2, 3 (presumably the
   # psubscribe confirmations).
   for expected_ack in range(1, len(patterns) + 1):
     self.assertEqual(get_next_message(pubsub)["data"], expected_ack)
   # Then one notification per subscription, all describing the same task.
   for _ in patterns:
     payload = get_next_message(pubsub)["data"]
     reply = TaskReply.GetRootAsTaskReply(payload, 0)
     self.assertEqual(reply.TaskId(), b"task_id")
     self.assertEqual(reply.State(), scheduling_state)
     self.assertEqual(reply.LocalSchedulerId(), encoded_scheduler_id)
     self.assertEqual(reply.TaskSpec(), b"task_spec")
Exemple #2
0
    def cleanup_task_table(self):
        """Clean up global state for failed local schedulers.

        This marks any tasks that were scheduled on dead local schedulers as
        TASK_STATUS_LOST. A local scheduler is deemed dead if it is not in
        self.local_schedulers.

        Entries that disappear between the key scan and the lookup
        (RAY.TASK_TABLE_GET returns None) are skipped. Only updates that the
        server acknowledges with b"OK" are counted in the summary log line.
        """
        task_ids = self.redis.scan_iter(match="{prefix}*".format(
            prefix=TASK_PREFIX))
        num_tasks_updated = 0
        for task_id in task_ids:
            # Strip the key prefix to recover the bare task ID.
            task_id = task_id[len(TASK_PREFIX):]
            response = self.redis.execute_command("RAY.TASK_TABLE_GET",
                                                  task_id)
            if response is None:
                # The entry was removed after the scan; nothing to update.
                continue
            # Parse the serialized task object.
            task_object = TaskReply.GetRootAsTaskReply(response, 0)
            local_scheduler_id = task_object.LocalSchedulerId()
            # Tasks on a live local scheduler are left alone.
            if local_scheduler_id in self.local_schedulers:
                continue
            ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
                                            task_id, TASK_STATUS_LOST,
                                            NIL_ID)
            if ok != b"OK":
                log.warning("Failed to update lost task for dead scheduler.")
                continue
            num_tasks_updated += 1
        if num_tasks_updated > 0:
            log.warning("Marked {} tasks as lost.".format(num_tasks_updated))
Exemple #3
0
    def tearDown(self):
        """Check that every local scheduler was used, then stop the cluster.

        Every task table entry's local scheduler ID is collected. The number
        of distinct IDs must equal the number of local schedulers started
        plus one, the extra being NIL_LOCAL_SCHEDULER_ID for the driver task.
        """
        self.assertTrue(ray.services.all_processes_alive())

        # Task table keys look like "TT:<task_id>"; slicing off the first
        # three characters yields the bare task ID.
        redis_client = redis.StrictRedis(port=self.redis_port)
        seen_scheduler_ids = set()
        for table_key in redis_client.keys("TT:*"):
            raw_reply = redis_client.execute_command("ray.task_table_get",
                                                     table_key[3:])
            reply = TaskReply.GetRootAsTaskReply(raw_reply, 0)
            seen_scheduler_ids.add(reply.LocalSchedulerId())

        # The driver task is not scheduled by a particular local scheduler,
        # so it contributes NIL_LOCAL_SCHEDULER_ID; hence the "+ 1".
        self.assertEqual(len(seen_scheduler_ids),
                         self.num_local_schedulers + 1)

        # Clean up the Ray cluster.
        ray.worker.cleanup()
Exemple #4
0
 def check_task_reply(message, task_args):
     """Assert that a serialized TaskReply carries the expected fields.

     Args:
         message: Serialized TaskReply bytes.
         task_args: A 3-tuple (task_status, local_scheduler_id, task_spec).
     """
     expected_status, expected_scheduler_id, expected_spec = task_args
     reply = TaskReply.GetRootAsTaskReply(message, 0)
     # NOTE(review): `self` is not a parameter; presumably this helper is a
     # closure inside a test method and uses that method's `self`.
     self.assertEqual(reply.State(), expected_status)
     self.assertEqual(reply.LocalSchedulerId(), expected_scheduler_id)
     self.assertEqual(reply.TaskSpec(), expected_spec)
 def check_task_reply(message, task_args, updated=False):
     """Assert that a serialized TaskReply carries the expected fields.

     Args:
         message: Serialized TaskReply bytes.
         task_args: A 4-tuple (task_status, local_scheduler_id,
             execution_dependencies_string, task_spec). The execution
             dependencies entry is unpacked but not compared.
         updated: The expected value of the reply's Updated() flag.
     """
     (expected_status, expected_scheduler_id, _execution_dependencies,
      expected_spec) = task_args
     reply = TaskReply.GetRootAsTaskReply(message, 0)
     # NOTE(review): `self` is not a parameter; presumably this helper is a
     # closure inside a test method and uses that method's `self`.
     self.assertEqual(reply.State(), expected_status)
     self.assertEqual(reply.LocalSchedulerId(), expected_scheduler_id)
     self.assertEqual(reply.TaskSpec(), expected_spec)
     self.assertEqual(reply.Updated(), updated)
Exemple #6
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to get information about. Its raw
                bytes (``task_id.id()``) are the RAY.TASK_TABLE_GET key, and
                the ID object itself is passed to ``self._execute_command``.

        Returns:
            A dictionary with information about the task ID in question.
                TASK_STATUS_MAPPING should be used to parse the "State" field
                into a human-readable string.

        Raises:
            Exception: If there is no task table entry for the task ID.
        """
        task_id_binary = task_id.id()
        task_table_response = self._execute_command(task_id,
                                                    "RAY.TASK_TABLE_GET",
                                                    task_id_binary)
        if task_table_response is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id_binary)))
        task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
                                                          0)
        # The TaskSpec field is itself a serialized TaskInfo flatbuffer.
        task_spec = task_table_message.TaskSpec()
        task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
        args = []
        for i in range(task_spec_message.ArgsLength()):
            arg = task_spec_message.Args(i)
            if len(arg.ObjectId()) != 0:
                # A non-empty ObjectId means the argument is an object ID.
                args.append(binary_to_object_id(arg.ObjectId()))
            else:
                # Otherwise the argument's value is stored pickled in Data().
                # NOTE(review): pickle.loads executes arbitrary code if the
                # contents of the store are not trusted.
                args.append(pickle.loads(arg.Data()))
        # TODO(atumanov): Instead of hard coding these indices, we should use
        # the flatbuffer constants.
        # NOTE(review): this check disappears under `python -O`.
        assert task_spec_message.RequiredResourcesLength() == 3
        required_resources = {
            "CPUs": task_spec_message.RequiredResources(0),
            "GPUs": task_spec_message.RequiredResources(1),
            "CustomResource": task_spec_message.RequiredResources(2)}
        task_spec_info = {
            "DriverID": binary_to_hex(task_spec_message.DriverId()),
            "TaskID": binary_to_hex(task_spec_message.TaskId()),
            "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
            "ParentCounter": task_spec_message.ParentCounter(),
            "ActorID": binary_to_hex(task_spec_message.ActorId()),
            "ActorCounter": task_spec_message.ActorCounter(),
            "FunctionID": binary_to_hex(task_spec_message.FunctionId()),
            "Args": args,
            "ReturnObjectIDs": [binary_to_object_id(
                                    task_spec_message.Returns(i))
                                for i in range(
                                    task_spec_message.ReturnsLength())],
            "RequiredResources": required_resources}

        return {"State": task_table_message.State(),
                "LocalSchedulerID": binary_to_hex(
                    task_table_message.LocalSchedulerId()),
                "TaskSpec": task_spec_info}
Exemple #7
0
 def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
   """Add one task and check the notification read from `p`.

   Args:
     p: The subscription handed to get_next_message (presumably a redis
       pubsub object already subscribed to a matching task pattern).
     scheduling_state: The scheduling state stored with the new task.
     local_scheduler_id: The local scheduler ID as a str. It is
       ASCII-encoded before being sent and when compared.
   """
   expected_scheduler_id = local_scheduler_id.encode("ascii")
   self.redis.execute_command("RAY.TASK_TABLE_ADD", b"task_id",
                              scheduling_state, expected_scheduler_id,
                              b"task_spec")
   # Read a single message and decode it as the task notification.
   payload = get_next_message(p)["data"]
   reply = TaskReply.GetRootAsTaskReply(payload, 0)
   self.assertEqual(reply.TaskId(), b"task_id")
   self.assertEqual(reply.State(), scheduling_state)
   self.assertEqual(reply.LocalSchedulerId(), expected_scheduler_id)
   self.assertEqual(reply.TaskSpec(), b"task_spec")
Exemple #8
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to get information about. Its raw
                bytes (``task_id.id()``) are the RAY.TASK_TABLE_GET key, and
                the ID object itself is passed to ``self._execute_command``.

        Returns:
            A dictionary with information about the task ID in question.
                TASK_STATUS_MAPPING should be used to parse the "State" field
                into a human-readable string.

        Raises:
            Exception: If there is no task table entry for the task ID.
        """
        task_id_binary = task_id.id()
        task_table_response = self._execute_command(task_id,
                                                    "RAY.TASK_TABLE_GET",
                                                    task_id_binary)
        if task_table_response is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id_binary)))
        task_table_message = TaskReply.GetRootAsTaskReply(
            task_table_response, 0)
        # Deserialize the TaskSpec bytes with the local scheduler's parser.
        task_spec = ray.local_scheduler.task_from_string(
            task_table_message.TaskSpec())

        task_spec_info = {
            "DriverID": binary_to_hex(task_spec.driver_id().id()),
            "TaskID": binary_to_hex(task_spec.task_id().id()),
            "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
            "ParentCounter": task_spec.parent_counter(),
            "ActorID": binary_to_hex(task_spec.actor_id().id()),
            "ActorCounter": task_spec.actor_counter(),
            "FunctionID": binary_to_hex(task_spec.function_id().id()),
            "Args": task_spec.arguments(),
            "ReturnObjectIDs": task_spec.returns(),
            "RequiredResources": task_spec.required_resources()
        }

        return {
            "State":
            task_table_message.State(),
            "LocalSchedulerID":
            binary_to_hex(task_table_message.LocalSchedulerId()),
            "ExecutionDependenciesString":
            task_table_message.ExecutionDependencies(),
            "SpillbackCount":
            task_table_message.SpillbackCount(),
            "TaskSpec":
            task_spec_info
        }
Exemple #9
0
    def tearDown(self):
        """Check that every local scheduler was used, then stop the cluster.

        Every task table entry's local scheduler ID is collected. The number
        of distinct IDs must equal the number of local schedulers started.
        """
        self.assertTrue(ray.services.all_processes_alive())

        # Make sure that all nodes in the cluster were used by checking where
        # tasks were scheduled and/or submitted from.
        client = redis.StrictRedis(port=self.redis_port)
        distinct_ids = set()
        for table_key in client.keys("TT:*"):
            # Drop the three-character "TT:" prefix to get the bare task ID.
            raw_reply = client.execute_command("ray.task_table_get",
                                               table_key[3:])
            reply = TaskReply.GetRootAsTaskReply(raw_reply, 0)
            distinct_ids.add(reply.LocalSchedulerId())

        self.assertEqual(len(distinct_ids), self.num_local_schedulers)

        # Clean up the Ray cluster.
        ray.worker.cleanup()
Exemple #10
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to get information about. Its raw
                bytes (``task_id.id()``) are the RAY.TASK_TABLE_GET key, and
                the ID object itself is passed to ``self._execute_command``.

        Returns:
            A dictionary with information about the task ID in question.
                TASK_STATUS_MAPPING should be used to parse the "State" field
                into a human-readable string.

        Raises:
            Exception: If there is no task table entry for the task ID.
        """
        task_id_binary = task_id.id()
        task_table_response = self._execute_command(task_id,
                                                    "RAY.TASK_TABLE_GET",
                                                    task_id_binary)
        if task_table_response is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id_binary)))
        task_table_message = TaskReply.GetRootAsTaskReply(
            task_table_response, 0)
        # Deserialize the TaskSpec bytes with the local scheduler's parser.
        task_spec = ray.local_scheduler.task_from_string(
            task_table_message.TaskSpec())

        task_spec_info = {
            "DriverID":
            binary_to_hex(task_spec.driver_id().id()),
            "TaskID":
            binary_to_hex(task_spec.task_id().id()),
            "ParentTaskID":
            binary_to_hex(task_spec.parent_task_id().id()),
            "ParentCounter":
            task_spec.parent_counter(),
            "ActorID":
            binary_to_hex(task_spec.actor_id().id()),
            "ActorCreationID":
            binary_to_hex(task_spec.actor_creation_id().id()),
            "ActorCreationDummyObjectID":
            binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
            "ActorCounter":
            task_spec.actor_counter(),
            "FunctionID":
            binary_to_hex(task_spec.function_id().id()),
            "Args":
            task_spec.arguments(),
            "ReturnObjectIDs":
            task_spec.returns(),
            "RequiredResources":
            task_spec.required_resources()
        }

        # The ExecutionDependencies field is itself a serialized
        # TaskExecutionDependencies flatbuffer; read it once and unpack it
        # into a list of ObjectIDs.
        execution_dependencies_string = (
            task_table_message.ExecutionDependencies())
        execution_dependencies_message = (
            TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
                execution_dependencies_string, 0))
        execution_dependencies = [
            ray.local_scheduler.ObjectID(
                execution_dependencies_message.ExecutionDependencies(i))
            for i in range(
                execution_dependencies_message.ExecutionDependenciesLength())
        ]

        # TODO(rkn): The return fields ExecutionDependenciesString and
        # ExecutionDependencies are redundant, so we should remove
        # ExecutionDependencies. However, it is currently used in monitor.py.

        return {
            "State":
            task_table_message.State(),
            "LocalSchedulerID":
            binary_to_hex(task_table_message.LocalSchedulerId()),
            "ExecutionDependenciesString":
            execution_dependencies_string,
            "ExecutionDependencies":
            execution_dependencies,
            "SpillbackCount":
            task_table_message.SpillbackCount(),
            "TaskSpec":
            task_spec_info
        }
Exemple #11
0
    def _task_table(self, task_id_binary):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id_binary: A string of bytes with the task ID to get
                information about.

        Returns:
            A dictionary with information about the task ID in question. The
                "State" field is the value ``task_state_mapping`` gives for
                the stored state.

        Raises:
            Exception: If there is no task table entry for the task ID.
        """
        task_table_response = self.redis_client.execute_command(
            "RAY.TASK_TABLE_GET", task_id_binary)
        if task_table_response is None:
            raise Exception(
                "There is no entry for task ID {} in the task table.".format(
                    binary_to_hex(task_id_binary)))
        task_table_message = TaskReply.GetRootAsTaskReply(
            task_table_response, 0)
        # The TaskSpec field is itself a serialized TaskInfo flatbuffer.
        task_spec = task_table_message.TaskSpec()
        task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
        args = []
        for i in range(task_spec_message.ArgsLength()):
            arg = task_spec_message.Args(i)
            if len(arg.ObjectId()) != 0:
                # A non-empty ObjectId means the argument is an object ID.
                args.append(binary_to_object_id(arg.ObjectId()))
            else:
                # Otherwise the argument's value is stored pickled in Data().
                # NOTE(review): pickle.loads executes arbitrary code if the
                # contents of Redis are not trusted.
                args.append(pickle.loads(arg.Data()))
        # NOTE(review): this check disappears under `python -O`.
        assert task_spec_message.RequiredResourcesLength() == 2
        required_resources = {
            "CPUs": task_spec_message.RequiredResources(0),
            "GPUs": task_spec_message.RequiredResources(1)
        }
        task_spec_info = {
            "DriverID":
            binary_to_hex(task_spec_message.DriverId()),
            "TaskID":
            binary_to_hex(task_spec_message.TaskId()),
            "ParentTaskID":
            binary_to_hex(task_spec_message.ParentTaskId()),
            "ParentCounter":
            task_spec_message.ParentCounter(),
            "ActorID":
            binary_to_hex(task_spec_message.ActorId()),
            "ActorCounter":
            task_spec_message.ActorCounter(),
            "FunctionID":
            binary_to_hex(task_spec_message.FunctionId()),
            "Args":
            args,
            "ReturnObjectIDs": [
                binary_to_object_id(task_spec_message.Returns(i))
                for i in range(task_spec_message.ReturnsLength())
            ],
            "RequiredResources":
            required_resources
        }

        return {
            "State":
            task_state_mapping[task_table_message.State()],
            "LocalSchedulerID":
            binary_to_hex(task_table_message.LocalSchedulerId()),
            "TaskSpec":
            task_spec_info
        }