Ejemplo n.º 1
0
 def trial_info(self, trial):
     """Describe a trial as a flat dictionary.

     The trial config and a copy of its last result are serialized with
     cloudpickle and hex-encoded via ``binary_to_hex``.

     Args:
         trial: The trial to describe.

     Returns:
         A dict with the keys "id", "trainable_name", "config", "status"
         and "result".
     """
     last = trial.last_result
     # Work on a copy of the result; any falsy result is recorded as None.
     snapshot = last.copy() if last else None

     def _encode(obj):
         return binary_to_hex(cloudpickle.dumps(obj))

     return {
         "id": trial.trial_id,
         "trainable_name": trial.trainable_name,
         "config": _encode(trial.config),
         "status": trial.status,
         "result": _encode(snapshot),
     }
Ejemplo n.º 2
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to get information about. Its
                ``id()`` bytes are used as the table lookup key.

        Returns:
            A dictionary with two entries: "ExecutionSpec" (the task's
            dependencies, last timestamp and number of forwards) and
            "TaskSpec" (IDs, counters, arguments, returns, required
            resources and function descriptor fields of the task).

        Raises:
            Exception: If the task table has no entry for ``task_id``.
        """
        message = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
                                        ray.gcs_utils.TablePrefix.RAYLET_TASK,
                                        "", task_id.id())
        if message is None:
            # Report the missing entry explicitly instead of letting the
            # flatbuffer parser fail on a None reply.
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id.id())))
        gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
            message, 0)

        # A task is expected to have exactly one entry in the task table.
        assert gcs_entries.EntriesLength() == 1

        task_table_message = ray.gcs_utils.Task.GetRootAsTask(
            gcs_entries.Entries(0), 0)

        execution_spec = task_table_message.TaskExecutionSpec()
        # Convert the serialized specification into a task object.
        task_spec = ray.raylet.task_from_string(
            task_table_message.TaskSpecification())
        function_descriptor = FunctionDescriptor.from_bytes_list(
            task_spec.function_descriptor_list())
        task_spec_info = {
            "DriverID": binary_to_hex(task_spec.driver_id().id()),
            "TaskID": binary_to_hex(task_spec.task_id().id()),
            "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
            "ParentCounter": task_spec.parent_counter(),
            "ActorID": binary_to_hex(task_spec.actor_id().id()),
            "ActorCreationID": binary_to_hex(
                task_spec.actor_creation_id().id()),
            "ActorCreationDummyObjectID": binary_to_hex(
                task_spec.actor_creation_dummy_object_id().id()),
            "ActorCounter": task_spec.actor_counter(),
            "Args": task_spec.arguments(),
            "ReturnObjectIDs": task_spec.returns(),
            "RequiredResources": task_spec.required_resources(),
            "FunctionID": binary_to_hex(function_descriptor.function_id.id()),
            "FunctionHash": binary_to_hex(function_descriptor.function_hash),
            "ModuleName": function_descriptor.module_name,
            "ClassName": function_descriptor.class_name,
            "FunctionName": function_descriptor.function_name,
        }

        return {
            "ExecutionSpec": {
                "Dependencies": [
                    execution_spec.Dependencies(i)
                    for i in range(execution_spec.DependenciesLength())
                ],
                "LastTimestamp": execution_spec.LastTimestamp(),
                "NumForwards": execution_spec.NumForwards()
            },
            "TaskSpec": task_spec_info
        }
Ejemplo n.º 3
0
 def actors(self):
     """Return information about every actor recorded in Redis.

     Actors are read from the hashes stored under "Actor:<id>" keys.

     Returns:
         A dict mapping hex actor IDs to dicts with the actor's hex-encoded
         "class_id", "driver_id" and "local_scheduler_id", its integer
         "num_gpus" and a boolean "removed" flag.
     """
     key_prefix = "Actor:"
     actors_by_id = {}
     for actor_key in self.redis_client.keys(key_prefix + "*"):
         fields = self.redis_client.hgetall(actor_key)
         raw_actor_id = actor_key[len(key_prefix):]
         # The key suffix is expected to be a 20-byte binary actor ID.
         assert len(raw_actor_id) == 20
         actors_by_id[binary_to_hex(raw_actor_id)] = {
             "class_id": binary_to_hex(fields[b"class_id"]),
             "driver_id": binary_to_hex(fields[b"driver_id"]),
             "local_scheduler_id": binary_to_hex(
                 fields[b"local_scheduler_id"]),
             "num_gpus": int(fields[b"num_gpus"]),
             # Stored as text; only the exact string "True" counts.
             "removed": decode(fields[b"removed"]) == "True",
         }
     return actors_by_id
Ejemplo n.º 4
0
    def db_client_notification_handler(self, unused_channel, data):
        """Handle a notification from the db_client table from Redis.

        Notifications are parsed with the SubscribeToDBClientTableReply
        flatbuffer. Insertions are ignored. A deleted local scheduler is
        added to ``self.dead_local_schedulers``; a deleted plasma manager is
        added to ``self.dead_plasma_managers`` and dropped from
        ``self.live_plasma_managers``. Cleanup of the associated state in
        the state tables should be handled by the caller.

        Args:
            unused_channel: The channel the message arrived on (ignored).
            data: The raw flatbuffer-encoded notification.
        """
        notification_object = (SubscribeToDBClientTableReply.
                               GetRootAsSubscribeToDBClientTableReply(data, 0))
        db_client_id = binary_to_hex(notification_object.DbClientId())
        client_type = notification_object.ClientType()
        is_insertion = notification_object.IsInsertion()

        # If the update was an insertion, we ignore it.
        if is_insertion:
            return

        # If the update was a deletion, add them to our accounting for dead
        # local schedulers and plasma managers. Set insertion is idempotent,
        # so repeated deletion notifications are harmless.
        log.warning("Removed {}, client ID {}".format(
            client_type, db_client_id))
        if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
            self.dead_local_schedulers.add(db_client_id)
        elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
            self.dead_plasma_managers.add(db_client_id)
            # Stop tracking this plasma manager's heartbeats, since it's
            # already dead. pop() with a default avoids a KeyError on a
            # repeated notification or for a manager that was never tracked.
            self.live_plasma_managers.pop(db_client_id, None)
Ejemplo n.º 5
0
    def __getstate__(self):
        """Build the picklable state (memento) for this Trial.

        The checkpoint, config, custom loggers, sync function and last
        result are cloudpickled and hex-encoded. The runner and result
        logger handles are dropped, a RUNNING trial is recorded as PENDING,
        and an active result logger is flushed. Only valid when the trial's
        checkpoint is stored on DISK.

        Returns:
            A deep copy of the resulting state dict.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        snapshot = dict(self.__dict__)
        snapshot["resources"] = resources_to_json(self.resources)

        to_pickle = (
            ("_checkpoint", self._checkpoint),
            ("config", self.config),
            ("custom_loggers", self.custom_loggers),
            ("sync_function", self.sync_function),
            ("last_result", self.last_result),
        )
        snapshot.update(
            (name, binary_to_hex(cloudpickle.dumps(obj)))
            for name, obj in to_pickle)

        snapshot["runner"] = None
        snapshot["result_logger"] = None
        # The runner is not saved, so a running trial is persisted as pending.
        if self.status == Trial.RUNNING:
            snapshot["status"] = Trial.PENDING

        logger_started = bool(self.result_logger)
        if logger_started:
            self.result_logger.flush()
        snapshot["__logger_started__"] = logger_started
        return copy.deepcopy(snapshot)
Ejemplo n.º 6
0
    def task_table(self, task_id=None):
        """Fetch and parse raylet task table information.

        Args:
            task_id: Hex string of a single task ID to look up. When None,
                every entry in the raylet task table is fetched.

        Returns:
            The parsed entry for ``task_id``, or, when no ID is given, a
            dict mapping hex task IDs to their parsed entries.
        """
        self._check_connected()
        if task_id is not None:
            return self._task_table(ray.TaskID(hex_to_binary(task_id)))

        prefix = ray.gcs_utils.TablePrefix_RAYLET_TASK_string
        entries = {}
        for key in self._keys(prefix + "*"):
            raw_id = key[len(prefix):]
            entries[binary_to_hex(raw_id)] = self._task_table(
                ray.TaskID(raw_id))
        return entries
Ejemplo n.º 7
0
    def error_messages(self, job_id=None):
        """Get the error messages for all jobs or a specific job.

        Args:
            job_id: A ``ray.DriverID`` selecting the job whose errors are
                returned. If None, the errors of every job found in the
                error table are returned.

        Returns:
            For a specific job, the result of ``_error_messages`` for it.
            Otherwise a dictionary mapping hex job IDs to the list of error
            messages for that job.
        """
        if job_id is None:
            prefix = ray.gcs_utils.TablePrefix_ERROR_INFO_string
            raw_job_ids = [
                key[len(prefix):]
                for key in self.redis_client.keys(prefix + "*")
            ]
            return {
                binary_to_hex(raw_id): self._error_messages(
                    ray.DriverID(raw_id))
                for raw_id in raw_job_ids
            }

        assert isinstance(job_id, ray.DriverID)
        return self._error_messages(job_id)
Ejemplo n.º 8
0
    def _xray_clean_up_entries_for_driver(self, driver_id):
        """Delete a driver's task and object entries from the Redis shards.

        Collects every raylet task whose spec names this driver and every
        object whose creating task is one of those tasks. The matching keys
        are then deleted shard by shard. Deletion is best-effort: a
        shortfall is only logged.

        Args:
            driver_id: The binary driver ID.
        """
        task_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        object_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        all_tasks = self.state.task_table()
        driver_id_hex = binary_to_hex(driver_id)
        # Binary IDs of the tasks submitted by this driver.
        task_ids = {
            hex_to_binary(task_hex)
            for task_hex, entry in all_tasks.items()
            if entry["TaskSpec"]["DriverID"] == driver_id_hex
        }

        # Binary IDs of the objects created by those tasks.
        object_ids = {
            object_id.id()
            for object_id in self.state.object_table()
            if ray.raylet.compute_task_id(object_id).id() in task_ids
        }

        def shard_of(id_bin):
            return binary_to_object_id(id_bin).redis_shard_hash() % len(
                self.state.redis_clients)

        # Group the redis keys to delete by shard.
        keys_by_shard = [[] for _ in range(len(self.state.redis_clients))]
        for task_bin in task_ids:
            keys_by_shard[shard_of(task_bin)].append(task_prefix + task_bin)
        for object_bin in object_ids:
            keys_by_shard[shard_of(object_bin)].append(
                object_prefix + object_bin)

        # Remove with best effort.
        for shard_index, keys in enumerate(keys_by_shard):
            if not keys:
                continue
            shard_client = self.state.redis_clients[shard_index]
            num_deleted = shard_client.delete(*keys)
            logger.info("Removed {} dead redis entries of the driver from"
                        " redis shard {}.".format(num_deleted, shard_index))
            if num_deleted != len(keys):
                logger.warning("Failed to remove {} relevant redis entries"
                               " from redis shard {}.".format(
                                   len(keys) - num_deleted, shard_index))
Ejemplo n.º 9
0
    def workers(self):
        """Get a dictionary mapping worker ID to worker information.

        Only keys with the ``Workers:`` prefix are read, so the prefix
        stripped from each key always matches the pattern used to find it.

        Returns:
            A dict keyed by hex worker ID. Each value holds the worker's
            "local_scheduler_socket", "node_ip_address",
            "plasma_manager_socket" and "plasma_store_socket", plus
            "stderr_file" and "stdout_file" when those are recorded.
        """
        worker_prefix = "Workers:"
        required_fields = ("local_scheduler_socket", "node_ip_address",
                           "plasma_manager_socket", "plasma_store_socket")
        optional_fields = ("stderr_file", "stdout_file")
        workers_data = {}

        for worker_key in self.redis_client.keys(worker_prefix + "*"):
            worker_info = self.redis_client.hgetall(worker_key)
            worker_id = binary_to_hex(worker_key[len(worker_prefix):])

            worker_entry = {
                field: worker_info[field.encode("ascii")].decode("ascii")
                for field in required_fields
            }
            for field in optional_fields:
                raw_field = field.encode("ascii")
                if raw_field in worker_info:
                    worker_entry[field] = worker_info[raw_field].decode(
                        "ascii")
            workers_data[worker_id] = worker_entry
        return workers_data
Ejemplo n.º 10
0
    def function_table(self, function_id=None):
        """Fetch and parse the function table.

        Args:
            function_id: Not used by this method; the whole table is always
                returned.

        Returns:
            A dictionary mapping hex function IDs to dicts with the
                function's hex "DriverID", its "Module" and its "Name".
        """
        self._check_connected()
        parsed_table = {}
        for key in self.redis_client.keys(FUNCTION_PREFIX + "*"):
            fields = self.redis_client.hgetall(key)
            entry = {
                "DriverID": binary_to_hex(fields[b"driver_id"]),
                "Module": decode(fields[b"module"]),
                "Name": decode(fields[b"name"]),
            }
            parsed_table[binary_to_hex(fields[b"function_id"])] = entry
        return parsed_table
Ejemplo n.º 11
0
def _task_table_shard(shard_index):
    """Fetch every task table entry stored on one Redis shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        A dict mapping hex task IDs to the parsed task table entries.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_len = len(TASK_PREFIX)
    entries = {}
    for key in shard_client.keys(TASK_PREFIX + b"*"):
        raw_id = key[prefix_len:]
        entries[binary_to_hex(raw_id)] = ray.global_state._task_table(
            ray.ObjectID(raw_id))
    return entries
Ejemplo n.º 12
0
def _object_table_shard(shard_index):
    """Fetch every object table entry stored on one Redis shard.

    Args:
        shard_index: Index into ``ray.global_state.redis_clients``.

    Returns:
        A dict mapping hex object IDs to the parsed object table entries.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_len = len(OBJECT_LOCATION_PREFIX)
    entries = {}
    for key in shard_client.keys(OBJECT_LOCATION_PREFIX + b"*"):
        raw_id = key[prefix_len:]
        entries[binary_to_hex(raw_id)] = ray.global_state._object_table(
            ray.ObjectID(raw_id))
    return entries
Ejemplo n.º 13
0
    def _profile_table(self, batch_id):
        """Get the profile events for a given batch of profile events.

        Args:
            batch_id: An identifier for a batch of profile events; its
                ``binary()`` bytes are the lookup key.

        Returns:
            A list of profile event dicts, one per event across all table
            entries, or an empty list if the lookup returned nothing.
        """
        # TODO(rkn): This method should support limiting the number of log
        # events and should also support returning a window of events.
        reply = self._execute_command(batch_id, "RAY.TABLE_LOOKUP",
                                      ray.gcs_utils.TablePrefix.PROFILE, "",
                                      batch_id.binary())
        if reply is None:
            return []

        entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
        events = []
        for entry_index in range(entries.EntriesLength()):
            data = ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
                entries.Entries(entry_index), 0)
            # These fields are shared by every event in this entry.
            component_type = decode(data.ComponentType())
            component_id = binary_to_hex(data.ComponentId())
            node_ip_address = decode(data.NodeIpAddress(), allow_none=True)

            for event_index in range(data.ProfileEventsLength()):
                event = data.ProfileEvents(event_index)
                events.append({
                    "event_type": decode(event.EventType()),
                    "component_id": component_id,
                    "node_ip_address": node_ip_address,
                    "component_type": component_type,
                    "start_time": event.StartTime(),
                    "end_time": event.EndTime(),
                    "extra_data": json.loads(decode(event.ExtraData())),
                })
        return events
Ejemplo n.º 14
0
    def cleanup_actors(self):
        """Recreate any live actors whose corresponding local scheduler died.

        Each actor that is not marked removed and whose local scheduler is
        in ``self.dead_local_schedulers`` is handled as follows: a new local
        scheduler is chosen, the actor's recreation on it is published, and
        the actor's "local_scheduler_id" field in Redis is updated.
        """
        actor_info = self.state.actors()
        for actor_id, info in actor_info.items():
            if (not info["removed"] and
                    info["local_scheduler_id"] in self.dead_local_schedulers):
                # Choose a new local scheduler to run the actor.
                local_scheduler_id = ray.utils.select_local_scheduler(
                    info["driver_id"],
                    self.state.local_schedulers(), info["num_gpus"],
                    self.redis)
                new_local_scheduler_hex = binary_to_hex(local_scheduler_id)
                # The new local scheduler should not be the same as the old
                # local scheduler. TODO(rkn): This should not be an assert, it
                # should be something more benign.
                assert new_local_scheduler_hex != info["local_scheduler_id"]
                # Announce to all of the local schedulers that the actor should
                # be recreated on this new local scheduler.
                ray.utils.publish_actor_creation(
                    hex_to_binary(actor_id),
                    hex_to_binary(info["driver_id"]), local_scheduler_id, True,
                    self.redis)
                log.info("Actor {} for driver {} was on dead local scheduler "
                         "{}. It is being recreated on local scheduler {}"
                         .format(actor_id, info["driver_id"],
                                 info["local_scheduler_id"],
                                 new_local_scheduler_hex))
                # Update the actor info in Redis.
                self.redis.hset(b"Actor:" + hex_to_binary(actor_id),
                                "local_scheduler_id", local_scheduler_id)
Ejemplo n.º 15
0
    def _object_table(self, object_id):
        """Fetch and parse the object table information for a single object ID.

        Args:
            object_id: The object to look up, given either as a
                ``ray.local_scheduler.ObjectID`` or as a hex string.

        Returns:
            A dictionary with the object's "ManagerIDs" (hex IDs from the
            object location lookup, or None when that lookup returned
            nothing), "TaskID", "IsPut", "DataSize" and "Hash".
        """
        # Allow the argument to be either an ObjectID or a hex string.
        if not isinstance(object_id, ray.local_scheduler.ObjectID):
            object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id))

        # Return information about a single object ID.
        object_locations = self._execute_command(object_id,
                                                 "RAY.OBJECT_TABLE_LOOKUP",
                                                 object_id.id())
        if object_locations is not None:
            manager_ids = [binary_to_hex(manager_id)
                           for manager_id in object_locations]
        else:
            manager_ids = None

        result_table_response = self._execute_command(
            object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
        # NOTE(review): a None response here would reach the flatbuffer
        # parser unchecked; confirm the result table always has an entry.
        result_table_message = ResultTableReply.GetRootAsResultTableReply(
            result_table_response, 0)

        result = {"ManagerIDs": manager_ids,
                  "TaskID": binary_to_hex(result_table_message.TaskId()),
                  "IsPut": bool(result_table_message.IsPut()),
                  "DataSize": result_table_message.DataSize(),
                  "Hash": binary_to_hex(result_table_message.Hash())}

        return result
Ejemplo n.º 16
0
    def xray_driver_removed_handler(self, unused_channel, data):
        """Handle a notification that an XRay driver has been removed.

        Parses the driver ID from the first entry of the GCS table entry,
        logs the removal and passes the driver ID to
        ``_xray_clean_up_entries_for_driver``.

        Args:
            unused_channel: The message channel (ignored).
            data: The raw GcsTableEntry message whose first entry is a
                DriverTableData.
        """
        table_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
            data, 0)
        driver_table_data = (
            ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
                table_entry.Entries(0), 0))
        removed_driver_id = driver_table_data.DriverId()
        logger.info("XRay Driver {} has been removed.".format(
            binary_to_hex(removed_driver_id)))
        self._xray_clean_up_entries_for_driver(removed_driver_id)
Ejemplo n.º 17
0
    def __init__(
            self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
            experiment_tag=None, resources=Resources(cpu=1, gpu=0),
            stopping_criterion=None, checkpoint_freq=0,
            restore_path=None, upload_dir=None, max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.

        Raises:
            TuneError: If ``trainable_name`` is not a registered trainable,
                or if a key of ``stopping_criterion`` is not one of
                ``TrainingResult._fields``.
        """
        # Validate the inputs before any state is set.
        known = _default_registry.contains(TRAINABLE_CLASS, trainable_name)
        if not known:
            raise TuneError("Unknown trainable: " + trainable_name)

        for criterion_key in stopping_criterion or ():
            if criterion_key not in TrainingResult._fields:
                raise TuneError(
                    "Stopping condition key `{}` must be one of {}".format(
                        criterion_key, TrainingResult._fields))

        # Trial configuration, taken from the constructor arguments.
        self.trainable_name = trainable_name
        self.config = config or {}
        self.local_dir = local_dir
        self.experiment_tag = experiment_tag
        # NOTE(review): the default Resources instance is created once and
        # shared by every Trial that omits `resources`; this is fine only if
        # Resources is immutable - confirm.
        self.resources = resources
        self.stopping_criterion = stopping_criterion or {}
        self.checkpoint_freq = checkpoint_freq
        self.upload_dir = upload_dir
        self.verbose = True
        self.max_failures = max_failures

        # Mutable run state, updated while the trial executes.
        self.last_result = None
        self._checkpoint_path = restore_path
        self._checkpoint_obj = None
        self.runner = None
        self.status = Trial.PENDING
        self.location = None
        self.logdir = None
        self.result_logger = None
        self.last_debug = 0
        # Short random identifier: the first 8 hex characters.
        self.trial_id = binary_to_hex(random_string())[:8]
        self.error_file = None
        self.num_failures = 0
Ejemplo n.º 18
0
    def client_table(self):
        """Fetch and parse the Redis DB client table.

        Returns:
            A dict mapping each node IP address to a list of parsed client
            dicts, one per DB client registered on that node.
        """
        self._check_connected()
        # Fields with a fixed output name and a converter for their value.
        renamed_fields = {
            b"client_type": ("ClientType", decode),
            b"deleted": ("Deleted", lambda raw: bool(int(decode(raw)))),
            b"ray_client_id": ("DBClientID", binary_to_hex),
            b"manager_address": ("AuxAddress", decode),
            b"local_scheduler_socket_name": (
                "LocalSchedulerSocketName", decode),
        }
        node_info = {}
        for key in self.redis_client.keys(DB_CLIENT_PREFIX + "*"):
            client_info = self.redis_client.hgetall(key)
            node_ip_address = decode(client_info[b"node_ip_address"])
            clients_on_node = node_info.setdefault(node_ip_address, [])
            for required_field in (b"client_type", b"deleted",
                                   b"ray_client_id"):
                assert required_field in client_info
            is_local_scheduler = (
                client_info[b"client_type"] == b"local_scheduler")

            parsed = {}
            for field, value in client_info.items():
                if field == b"node_ip_address":
                    continue
                if field in renamed_fields:
                    out_name, convert = renamed_fields[field]
                    parsed[out_name] = convert(value)
                elif is_local_scheduler:
                    # The remaining fields are resource types.
                    parsed[field.decode("ascii")] = float(decode(value))
                else:
                    parsed[field.decode("ascii")] = decode(value)

            clients_on_node.append(parsed)

        return node_info
Ejemplo n.º 19
0
    def workers(self):
        """Get a dictionary mapping worker ID to worker information.

        Only keys with the ``Workers:`` prefix are read, so the prefix
        stripped from each key always matches the pattern used to find it.

        Returns:
            A dict keyed by hex worker ID. Each value holds the worker's
            "node_ip_address" and "plasma_store_socket", plus "stderr_file"
            and "stdout_file" when those are recorded.
        """
        worker_prefix = "Workers:"
        workers_data = {}

        for worker_key in self.redis_client.keys(worker_prefix + "*"):
            worker_info = self.redis_client.hgetall(worker_key)
            worker_id = binary_to_hex(worker_key[len(worker_prefix):])

            workers_data[worker_id] = {
                "node_ip_address": decode(worker_info[b"node_ip_address"]),
                "plasma_store_socket": decode(
                    worker_info[b"plasma_store_socket"])
            }
            for field in (b"stderr_file", b"stdout_file"):
                if field in worker_info:
                    workers_data[worker_id][field.decode("ascii")] = decode(
                        worker_info[field])
        return workers_data
Ejemplo n.º 20
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to get information about. Its
                ``id()`` bytes are used as the lookup key.

        Returns:
            A dictionary with information about the task ID in question.
                TASK_STATUS_MAPPING should be used to parse the "State" field
                into a human-readable string.

        Raises:
            Exception: If the task table has no entry for ``task_id``.
        """
        task_table_response = self._execute_command(task_id,
                                                    "RAY.TASK_TABLE_GET",
                                                    task_id.id())
        if task_table_response is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id.id())))
        task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
                                                          0)
        # Convert the serialized specification into a task object.
        task_spec = ray.local_scheduler.task_from_string(
            task_table_message.TaskSpec())

        task_spec_info = {
            "DriverID": binary_to_hex(task_spec.driver_id().id()),
            "TaskID": binary_to_hex(task_spec.task_id().id()),
            "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
            "ParentCounter": task_spec.parent_counter(),
            "ActorID": binary_to_hex(task_spec.actor_id().id()),
            "ActorCounter": task_spec.actor_counter(),
            "FunctionID": binary_to_hex(task_spec.function_id().id()),
            "Args": task_spec.arguments(),
            "ReturnObjectIDs": task_spec.returns(),
            "RequiredResources": task_spec.required_resources()}

        return {"State": task_table_message.State(),
                "LocalSchedulerID": binary_to_hex(
                    task_table_message.LocalSchedulerId()),
                "ExecutionDependenciesString":
                    task_table_message.ExecutionDependencies(),
                "SpillbackCount":
                    task_table_message.SpillbackCount(),
                "TaskSpec": task_spec_info}
Ejemplo n.º 21
0
    def __getstate__(self):
        """Memento generator for Trial.

        Cloudpickles and hex-encodes every field named in
        ``self._nonjson_fields`` (a missing field is encoded as None),
        converts ``resources`` with ``resources_to_json``, drops the runner
        and result logger handles, and flushes the result logger if one is
        active. The trial status is saved unchanged.
        Note this can only occur if the trial holds a DISK checkpoint.

        Returns:
            A deep copy of the resulting state dict.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        # Live handles cannot be saved; they are recreated on restore.
        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush()
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)
Ejemplo n.º 22
0
    def task_table(self, task_id=None):
        """Fetch and parse task table information for one or all tasks.

        Args:
            task_id: Hex string of the task ID to look up. When None, every
                entry under ``TASK_PREFIX`` is fetched.

        Returns:
            The parsed entry for ``task_id``, or, when no ID is given, a dict
            mapping hex task IDs to their parsed entries.
        """
        self._check_connected()
        if task_id is not None:
            return self._task_table(
                ray.local_scheduler.ObjectID(hex_to_binary(task_id)))

        prefix_len = len(TASK_PREFIX)
        entries = {}
        for key in self._keys(TASK_PREFIX + "*"):
            raw_id = key[prefix_len:]
            entries[binary_to_hex(raw_id)] = self._task_table(
                ray.local_scheduler.ObjectID(raw_id))
        return entries
Ejemplo n.º 23
0
    def driver_removed_handler(self, unused_channel, data):
        """Handle a notification that a driver has been removed.

        Cleans up the driver's entries via ``_clean_up_entries_for_driver``.
        It also releases any GPU resources reserved for that driver in Redis
        on each local scheduler that reports a positive "GPU" count.

        Args:
            unused_channel: The message channel (ignored).
            data: The raw DriverTableMessage flatbuffer.
        """
        message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
        driver_id = message.DriverId()
        # Computed once; reused in every transaction and log line below.
        driver_id_hex = binary_to_hex(driver_id)
        log.info("Driver {} has been removed.".format(driver_id_hex))

        # Get a list of the local schedulers that have not been deleted.
        local_schedulers = ray.global_state.local_schedulers()

        self._clean_up_entries_for_driver(driver_id)

        # Release any GPU resources that have been reserved for this driver in
        # Redis.
        for local_scheduler in local_schedulers:
            if local_scheduler.get("GPU", 0) > 0:
                local_scheduler_id = local_scheduler["DBClientID"]

                num_gpus_returned = 0

                # Perform a transaction to return the GPUs.
                with self.redis.pipeline() as pipe:
                    while True:
                        try:
                            # If this key is changed before the transaction
                            # below (the multi/exec block), then the
                            # transaction will not take place.
                            pipe.watch(local_scheduler_id)

                            result = pipe.hget(local_scheduler_id,
                                               "gpus_in_use")
                            gpus_in_use = (dict() if result is None else
                                           json.loads(result.decode("ascii")))

                            # Recomputed on every attempt so a value popped
                            # in an attempt aborted by WatchError is not
                            # reported after a retry finds nothing.
                            num_gpus_returned = gpus_in_use.pop(
                                driver_id_hex, 0)

                            pipe.multi()

                            pipe.hset(local_scheduler_id, "gpus_in_use",
                                      json.dumps(gpus_in_use))

                            pipe.execute()
                            # If a WatchError is not raised, then the
                            # operations should have gone through atomically.
                            break
                        except redis.WatchError:
                            # Another client must have changed the watched key
                            # between the time we started WATCHing it and the
                            # pipeline's execution. We should just retry.
                            continue

                log.info("Driver {} is returning GPU IDs {} to local "
                         "scheduler {}.".format(
                             driver_id_hex, num_gpus_returned,
                             local_scheduler_id))
Ejemplo n.º 24
0
    def _object_table(self, object_id):
        """Fetch and parse the object table information for a single object ID.

        Args:
            object_id: The object ID to look up, either as a ray.ObjectID or
                as a hex string (converted with hex_to_binary).

        Returns:
            Without the raylet: a dictionary with the "ManagerIDs" (hex IDs,
                or None if the location lookup returned nothing), "TaskID",
                "IsPut", "DataSize" and "Hash" of the object.
            With the raylet: a list with one dictionary per object table
                entry, each holding "DataSize", "Manager", "IsEviction" and
                "NumEvictions". The list is empty if the lookup returns no
                message.
        """
        # Allow the argument to be either an ObjectID or a hex string.
        if not isinstance(object_id, ray.ObjectID):
            object_id = ray.ObjectID(hex_to_binary(object_id))

        # Return information about a single object ID.
        if not self.use_raylet:
            # Use the non-raylet code path.
            object_locations = self._execute_command(
                object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id())
            if object_locations is not None:
                manager_ids = [
                    binary_to_hex(manager_id)
                    for manager_id in object_locations
                ]
            else:
                manager_ids = None

            result_table_response = self._execute_command(
                object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
            result_table_message = (
                ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
                    result_table_response, 0))

            result = {
                "ManagerIDs": manager_ids,
                "TaskID": binary_to_hex(result_table_message.TaskId()),
                "IsPut": bool(result_table_message.IsPut()),
                "DataSize": result_table_message.DataSize(),
                "Hash": binary_to_hex(result_table_message.Hash())
            }

        else:
            # Use the raylet code path.
            message = self.redis_client.execute_command(
                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "",
                object_id.id())
            result = []
            # With no message there are no entries to parse; report an empty
            # list instead of handing None to the flatbuffer parser.
            if message is not None:
                gcs_entry = (
                    ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                        message, 0))

                for i in range(gcs_entry.EntriesLength()):
                    entry = (
                        ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
                            gcs_entry.Entries(i), 0))
                    object_info = {
                        "DataSize": entry.ObjectSize(),
                        "Manager": entry.Manager(),
                        "IsEviction": entry.IsEviction(),
                        "NumEvictions": entry.NumEvictions()
                    }
                    result.append(object_info)

        return result
Ejemplo n.º 25
0
    def _task_table(self, task_id):
        """Fetch and parse the task table information for a single task ID.

        Args:
            task_id: The ID of the task to look up; ``task_id.id()`` supplies
                the binary key used in the Redis lookup.

        Returns:
            Without the raylet: a dictionary with information about the task
                ID in question. TASK_STATUS_MAPPING should be used to parse
                the "State" field into a human-readable string.
            With the raylet: a list with one dictionary per entry stored for
                the task ID, each holding "ExecutionSpec" and "TaskSpec"
                information. The list is empty if the lookup returns no
                message.

        Raises:
            Exception: Without the raylet, if there is no entry for the task
                ID in the task table.
        """
        if not self.use_raylet:
            # Use the non-raylet code path.
            task_table_response = self._execute_command(
                task_id, "RAY.TASK_TABLE_GET", task_id.id())
            if task_table_response is None:
                raise Exception("There is no entry for task ID {} in the task "
                                "table.".format(binary_to_hex(task_id.id())))
            task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
                task_table_response, 0)
            task_spec = task_table_message.TaskSpec()
            task_spec = ray.local_scheduler.task_from_string(task_spec)

            task_spec_info = {
                "DriverID": binary_to_hex(task_spec.driver_id().id()),
                "TaskID": binary_to_hex(task_spec.task_id().id()),
                "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
                "ParentCounter": task_spec.parent_counter(),
                "ActorID": binary_to_hex(task_spec.actor_id().id()),
                "ActorCreationID": binary_to_hex(
                    task_spec.actor_creation_id().id()),
                "ActorCreationDummyObjectID": binary_to_hex(
                    task_spec.actor_creation_dummy_object_id().id()),
                "ActorCounter": task_spec.actor_counter(),
                "FunctionID": binary_to_hex(task_spec.function_id().id()),
                "Args": task_spec.arguments(),
                "ReturnObjectIDs": task_spec.returns(),
                "RequiredResources": task_spec.required_resources()
            }

            execution_dependencies_message = (
                ray.gcs_utils.TaskExecutionDependencies.
                GetRootAsTaskExecutionDependencies(
                    task_table_message.ExecutionDependencies(), 0))
            execution_dependencies = [
                ray.ObjectID(
                    execution_dependencies_message.ExecutionDependencies(i))
                for i in range(execution_dependencies_message.
                               ExecutionDependenciesLength())
            ]

            # TODO(rkn): The return fields ExecutionDependenciesString and
            # ExecutionDependencies are redundant, so we should remove
            # ExecutionDependencies. However, it is currently used in
            # monitor.py.

            return {
                "State": task_table_message.State(),
                "LocalSchedulerID": binary_to_hex(
                    task_table_message.LocalSchedulerId()),
                "ExecutionDependenciesString": task_table_message.
                ExecutionDependencies(),
                "ExecutionDependencies": execution_dependencies,
                "SpillbackCount": task_table_message.SpillbackCount(),
                "TaskSpec": task_spec_info
            }

        else:
            # Use the raylet code path.
            message = self.redis_client.execute_command(
                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "",
                task_id.id())
            # With no message there are no entries to parse.
            if message is None:
                return []
            gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                message, 0)

            info = []
            for entry_index in range(gcs_entries.EntriesLength()):
                # Decode this iteration's entry so that every stored entry is
                # reported exactly once.
                task_table_message = ray.gcs_utils.Task.GetRootAsTask(
                    gcs_entries.Entries(entry_index), 0)
                execution_spec = task_table_message.TaskExecutionSpec()
                task_spec = task_table_message.TaskSpecification()
                task_spec = ray.local_scheduler.task_from_string(task_spec)
                task_spec_info = {
                    "DriverID": binary_to_hex(task_spec.driver_id().id()),
                    "TaskID": binary_to_hex(task_spec.task_id().id()),
                    "ParentTaskID": binary_to_hex(
                        task_spec.parent_task_id().id()),
                    "ParentCounter": task_spec.parent_counter(),
                    "ActorID": binary_to_hex(task_spec.actor_id().id()),
                    "ActorCreationID": binary_to_hex(
                        task_spec.actor_creation_id().id()),
                    "ActorCreationDummyObjectID": binary_to_hex(
                        task_spec.actor_creation_dummy_object_id().id()),
                    "ActorCounter": task_spec.actor_counter(),
                    "FunctionID": binary_to_hex(task_spec.function_id().id()),
                    "Args": task_spec.arguments(),
                    "ReturnObjectIDs": task_spec.returns(),
                    "RequiredResources": task_spec.required_resources()
                }

                info.append({
                    "ExecutionSpec": {
                        "Dependencies": [
                            execution_spec.Dependencies(dep_index)
                            for dep_index in range(
                                execution_spec.DependenciesLength())
                        ],
                        "LastTimestamp": execution_spec.LastTimestamp(),
                        "NumForwards": execution_spec.NumForwards()
                    },
                    "TaskSpec": task_spec_info
                })

            return info
Ejemplo n.º 26
0
 def _to_cloudpickle(self, obj):
     """Encode ``obj`` as a tagged, hex-encoded cloudpickle payload.

     Returns:
         A dict whose "_type" is "CLOUDPICKLE_FALLBACK" and whose "value" is
         the hex encoding of ``cloudpickle.dumps(obj)``.
     """
     payload = binary_to_hex(cloudpickle.dumps(obj))
     return {"_type": "CLOUDPICKLE_FALLBACK", "value": payload}
Ejemplo n.º 27
0
Archivo: actor.py Proyecto: dmuestc/ray
def fetch_and_register_actor(actor_class_key, worker):
    """Import an actor.

    This will be called by the worker's import thread when the worker receives
    the actor_class export, assuming that the worker is an actor for that
    class.

    The exported fields are read from the Redis hash at ``actor_class_key``.
    Before the class is unpickled, a placeholder executor that raises an
    exception is registered for every exported method name. As a result, the
    worker has an entry for each method even if unpickling fails.

    If unpickling succeeds, the placeholders are overwritten for every method
    found on the class. An instance created with ``__new__`` (i.e. without
    running ``__init__``) is then stored in ``worker.actors``. If unpickling
    fails, the traceback is pushed to the driver and the placeholders stay
    in place.

    Args:
        actor_class_key: Redis key of the hash holding the exported actor
            class fields ("driver_id", "class_id", "class_name", "module",
            "class", "checkpoint_interval", "actor_method_names").
        worker: The worker hosting the actor. Its ``actor_id`` and
            ``redis_client`` are read. Its ``actors``,
            ``actor_checkpoint_interval``, ``functions``,
            ``num_task_executions``, ``actor_class``, ``driver_id`` and
            ``local_scheduler_id`` are written.
    """
    # NOTE(review): despite the name, this value is concatenated with
    # b"Actor:" further down, so it is presumably a bytes ID -- confirm.
    actor_id_str = worker.actor_id
    # The fields are unpacked in the order in which they are requested.
    (driver_id, class_id, class_name, module, pickled_class,
     checkpoint_interval,
     actor_method_names) = worker.redis_client.hmget(actor_class_key, [
         "driver_id", "class_id", "class_name", "module", "class",
         "checkpoint_interval", "actor_method_names"
     ])

    # actor_name (decoded) is only used in the placeholder error message. The
    # raw class_name bytes are still what is passed to
    # register_actor_signatures and compute_actor_method_function_id.
    actor_name = class_name.decode("ascii")
    module = module.decode("ascii")
    checkpoint_interval = int(checkpoint_interval)
    actor_method_names = json.loads(actor_method_names.decode("ascii"))

    # Create a temporary actor with some temporary methods so that if the actor
    # fails to be unpickled, the temporary actor can be used (just to produce
    # error messages and to prevent the driver from hanging).
    class TemporaryActor(object):
        pass

    worker.actors[actor_id_str] = TemporaryActor()
    worker.actor_checkpoint_interval = checkpoint_interval

    # Accepts any positional arguments so it can stand in for every method.
    def temporary_actor_method(*xs):
        raise Exception("The actor with name {} failed to be imported, and so "
                        "cannot execute this method".format(actor_name))

    # Register the actor method signatures.
    register_actor_signatures(worker, driver_id, class_name,
                              actor_method_names)
    # Register the actor method executors. Each exported method name gets the
    # failing placeholder and its execution counter is reset to 0.
    for actor_method_name in actor_method_names:
        function_id = compute_actor_method_function_id(class_name,
                                                       actor_method_name).id()
        temporary_executor = make_actor_method_executor(
            worker, actor_method_name, temporary_actor_method)
        worker.functions[driver_id][function_id] = (actor_method_name,
                                                    temporary_executor)
        worker.num_task_executions[driver_id][function_id] = 0

    try:
        # NOTE(review): pickle.loads executes code from the Redis payload;
        # this assumes only trusted drivers can write actor class exports.
        unpickled_class = pickle.loads(pickled_class)
        worker.actor_class = unpickled_class
    except Exception:
        # If an exception was thrown when the actor was imported, we record the
        # traceback and notify the scheduler of the failure.
        traceback_str = ray.worker.format_error_message(traceback.format_exc())
        # Log the error message.
        # NOTE(review): the error is reported under the
        # "register_actor_signatures" type although the failure came from
        # unpickling the class -- confirm that is intended.
        push_error_to_driver(worker.redis_client,
                             "register_actor_signatures",
                             traceback_str,
                             driver_id,
                             data={"actor_id": actor_id_str})
        # TODO(rkn): In the future, it might make sense to have the worker exit
        # here. However, currently that would lead to hanging if someone calls
        # ray.get on a method invoked on the actor.
    else:
        # TODO(pcm): Why is the below line necessary?
        unpickled_class.__module__ = module
        # __new__ allocates the instance without running __init__.
        worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
        # Collect plain functions, methods, and Cython functions on the class.
        actor_methods = inspect.getmembers(
            unpickled_class,
            predicate=(lambda x: (inspect.isfunction(x) or inspect.ismethod(x)
                                  or is_cython(x))))
        for actor_method_name, actor_method in actor_methods:
            function_id = compute_actor_method_function_id(
                class_name, actor_method_name).id()
            # Overwrites the placeholder registered above for this method.
            executor = make_actor_method_executor(worker, actor_method_name,
                                                  actor_method)
            worker.functions[driver_id][function_id] = (actor_method_name,
                                                        executor)
            # We do not set worker.function_properties[driver_id][function_id]
            # because we currently do need the actor worker to submit new tasks
            # for the actor.

        # Store some extra information that will be used when the actor exits
        # to release GPU resources.
        worker.driver_id = binary_to_hex(driver_id)
        # NOTE(review): if the Actor hash has no local_scheduler_id field, hget
        # returns None and binary_to_hex would likely fail -- confirm the
        # field is always set before this runs.
        local_scheduler_id = worker.redis_client.hget(b"Actor:" + actor_id_str,
                                                      "local_scheduler_id")
        worker.local_scheduler_id = binary_to_hex(local_scheduler_id)
Ejemplo n.º 28
0
 def generate_id(cls):
     """Return the first 8 hex characters of ``_random_string()``'s hex form."""
     random_hex = binary_to_hex(_random_string())
     return random_hex[:8]
Ejemplo n.º 29
0
    def _task_table(self, task_id):
        """Look up one task in the task table and decode its contents.

        Args:
            task_id: The ID of the task to look up; ``task_id.id()`` gives the
                binary key passed to the "RAY.TASK_TABLE_GET" command.

        Returns:
            A dictionary with the task's "State", "LocalSchedulerID",
                "ExecutionDependenciesString", "ExecutionDependencies",
                "SpillbackCount" and "TaskSpec". TASK_STATUS_MAPPING should be
                used to parse the "State" field into a human-readable string.

        Raises:
            Exception: If the task table holds no entry for ``task_id``.
        """
        raw_reply = self._execute_command(task_id, "RAY.TASK_TABLE_GET",
                                          task_id.id())
        if raw_reply is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id.id())))
        reply = TaskReply.GetRootAsTaskReply(raw_reply, 0)
        spec = ray.local_scheduler.task_from_string(reply.TaskSpec())

        def hex_of(accessor):
            # Each accessor returns an ID object whose .id() is raw bytes.
            return binary_to_hex(accessor().id())

        spec_info = {
            "DriverID": hex_of(spec.driver_id),
            "TaskID": hex_of(spec.task_id),
            "ParentTaskID": hex_of(spec.parent_task_id),
            "ParentCounter": spec.parent_counter(),
            "ActorID": hex_of(spec.actor_id),
            "ActorCreationID": hex_of(spec.actor_creation_id),
            "ActorCreationDummyObjectID": hex_of(
                spec.actor_creation_dummy_object_id),
            "ActorCounter": spec.actor_counter(),
            "FunctionID": hex_of(spec.function_id),
            "Args": spec.arguments(),
            "ReturnObjectIDs": spec.returns(),
            "RequiredResources": spec.required_resources(),
        }

        deps_table = (
            TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
                reply.ExecutionDependencies(), 0))
        dependency_ids = []
        for position in range(deps_table.ExecutionDependenciesLength()):
            dependency_ids.append(
                ray.ObjectID(deps_table.ExecutionDependencies(position)))

        # TODO(rkn): ExecutionDependenciesString and ExecutionDependencies
        # carry the same information, so ExecutionDependencies should be
        # removed. However, monitor.py still reads it.
        return {
            "State": reply.State(),
            "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
            "ExecutionDependenciesString": reply.ExecutionDependencies(),
            "ExecutionDependencies": dependency_ids,
            "SpillbackCount": reply.SpillbackCount(),
            "TaskSpec": spec_info,
        }
Ejemplo n.º 30
0
Archivo: state.py Proyecto: the-sea/ray
    def _task_table(self, task_id):
        """Look up one task in the task table and decode its TaskInfo spec.

        Args:
            task_id: The ID of the task to look up; ``task_id.id()`` gives the
                binary key passed to the "RAY.TASK_TABLE_GET" command.

        Returns:
            A dictionary with the task's "State", "LocalSchedulerID" and
                decoded "TaskSpec". TASK_STATUS_MAPPING should be used to
                parse the "State" field into a human-readable string.

        Raises:
            Exception: If the task table holds no entry for ``task_id``.
        """
        raw_reply = self._execute_command(task_id, "RAY.TASK_TABLE_GET",
                                          task_id.id())
        if raw_reply is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id.id())))
        reply = TaskReply.GetRootAsTaskReply(raw_reply, 0)
        spec = TaskInfo.GetRootAsTaskInfo(reply.TaskSpec(), 0)

        def decode_arg(arg):
            # An argument carries either an object ID or pickled inline data.
            object_id_bytes = arg.ObjectId()
            if len(object_id_bytes) == 0:
                return pickle.loads(arg.Data())
            return binary_to_object_id(object_id_bytes)

        args = [decode_arg(spec.Args(k)) for k in range(spec.ArgsLength())]

        # TODO(atumanov): Instead of hard coding these indices, we should use
        # the flatbuffer constants.
        assert spec.RequiredResourcesLength() == 3
        resource_names = ("CPUs", "GPUs", "CustomResource")
        required_resources = dict(
            (name, spec.RequiredResources(slot))
            for slot, name in enumerate(resource_names))

        return_ids = [
            binary_to_object_id(spec.Returns(k))
            for k in range(spec.ReturnsLength())
        ]
        spec_info = {
            "DriverID": binary_to_hex(spec.DriverId()),
            "TaskID": binary_to_hex(spec.TaskId()),
            "ParentTaskID": binary_to_hex(spec.ParentTaskId()),
            "ParentCounter": spec.ParentCounter(),
            "ActorID": binary_to_hex(spec.ActorId()),
            "ActorCounter": spec.ActorCounter(),
            "FunctionID": binary_to_hex(spec.FunctionId()),
            "Args": args,
            "ReturnObjectIDs": return_ids,
            "RequiredResources": required_resources,
        }

        return {
            "State": reply.State(),
            "LocalSchedulerID": binary_to_hex(reply.LocalSchedulerId()),
            "TaskSpec": spec_info,
        }
Ejemplo n.º 31
0
    def driver_removed_handler(self, channel, data):
        """Handle a notification that a driver has been removed.

        This releases any GPU resources that were reserved for that driver in
        Redis.

        Args:
            channel: The channel the notification arrived on (not used).
            data: The serialized DriverTableMessage describing the removed
                driver.
        """
        message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
        driver_id = message.DriverId()
        # The "gpus_in_use" hash field is keyed by the hex driver ID.
        driver_id_hex = binary_to_hex(driver_id)
        log.info("Driver {} has been removed.".format(driver_id_hex))

        # Get a list of the local schedulers (the node IPs are not needed).
        client_table = ray.global_state.client_table()
        local_schedulers = [
            client for clients in client_table.values() for client in clients
            if client["ClientType"] == "local_scheduler"
        ]

        # Release any GPU resources that have been reserved for this driver in
        # Redis.
        for local_scheduler in local_schedulers:
            if int(local_scheduler["NumGPUs"]) <= 0:
                continue
            local_scheduler_id = local_scheduler["DBClientID"]

            num_gpus_returned = 0

            # Perform a transaction to return the GPUs.
            with self.redis.pipeline() as pipe:
                while True:
                    try:
                        # If this key is changed before the transaction below
                        # (the multi/exec block), then the transaction will
                        # not take place.
                        pipe.watch(local_scheduler_id)

                        result = pipe.hget(local_scheduler_id, "gpus_in_use")
                        if result is None:
                            gpus_in_use = dict()
                        else:
                            # Redis replies may be bytes, and json.loads only
                            # accepts bytes from Python 3.6 on.
                            if isinstance(result, bytes):
                                result = result.decode("ascii")
                            gpus_in_use = json.loads(result)

                        if driver_id_hex in gpus_in_use:
                            num_gpus_returned = gpus_in_use.pop(driver_id_hex)

                        pipe.multi()

                        pipe.hset(local_scheduler_id, "gpus_in_use",
                                  json.dumps(gpus_in_use))

                        pipe.execute()
                        # If a WatchError is not raised, then the operations
                        # should have gone through atomically.
                        break
                    except redis.WatchError:
                        # Another client must have changed the watched key
                        # between the time we started WATCHing it and the
                        # pipeline's execution. We should just retry.
                        continue

            log.info("Driver {} is returning GPU IDs {} to local "
                     "scheduler {}.".format(driver_id_hex, num_gpus_returned,
                                            local_scheduler_id))
Ejemplo n.º 32
0
 def generate_id(cls):
     """Return the first 8 hex characters of ``random_string()``'s hex form."""
     random_hex = binary_to_hex(random_string())
     return random_hex[:8]
Ejemplo n.º 33
0
    def _task_table(self, task_id):
        """Look up one task in the raylet task table and decode it.

        Args:
            task_id: The ray.TaskID of the task to look up.

        Returns:
            A dictionary with an "ExecutionSpec" entry (dependencies, last
                timestamp and number of forwards) and a "TaskSpec" entry
                describing the task. It is an empty dictionary if the lookup
                returned no message.
        """
        assert isinstance(task_id, ray.TaskID)
        reply = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
                                      ray.gcs_utils.TablePrefix.RAYLET_TASK,
                                      "", task_id.binary())
        if reply is None:
            return {}
        entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)

        # Exactly one entry is expected for a task ID.
        assert entries.EntriesLength() == 1

        task_message = ray.gcs_utils.Task.GetRootAsTask(entries.Entries(0), 0)
        execution_spec = task_message.TaskExecutionSpec()
        task = ray._raylet.Task.from_string(task_message.TaskSpecification())
        descriptor = FunctionDescriptor.from_bytes_list(
            task.function_descriptor_list())

        spec_info = {
            "DriverID": task.driver_id().hex(),
            "TaskID": task.task_id().hex(),
            "ParentTaskID": task.parent_task_id().hex(),
            "ParentCounter": task.parent_counter(),
            "ActorID": task.actor_id().hex(),
            "ActorCreationID": task.actor_creation_id().hex(),
            "ActorCreationDummyObjectID": (
                task.actor_creation_dummy_object_id().hex()),
            "ActorCounter": task.actor_counter(),
            "Args": task.arguments(),
            "ReturnObjectIDs": task.returns(),
            "RequiredResources": task.required_resources(),
            "FunctionID": descriptor.function_id.hex(),
            "FunctionHash": binary_to_hex(descriptor.function_hash),
            "ModuleName": descriptor.module_name,
            "ClassName": descriptor.class_name,
            "FunctionName": descriptor.function_name,
        }

        dependencies = [
            execution_spec.Dependencies(position)
            for position in range(execution_spec.DependenciesLength())
        ]
        return {
            "ExecutionSpec": {
                "Dependencies": dependencies,
                "LastTimestamp": execution_spec.LastTimestamp(),
                "NumForwards": execution_spec.NumForwards(),
            },
            "TaskSpec": spec_info,
        }
Ejemplo n.º 34
0
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
    """Attempt to acquire GPUs on a particular local scheduler for an actor.

    The reservation is recorded in the "gpus_in_use" field of the local
    scheduler's Redis hash, a JSON object that maps hex driver IDs to the
    number of GPUs reserved by each driver. The read-modify-write is done in
    a WATCH/MULTI/EXEC transaction. It is retried whenever another client
    modifies the watched key first.

    Args:
        num_gpus: The number of GPUs to acquire. Must be nonzero.
        driver_id: The ID of the driver responsible for creating the actor.
        local_scheduler: Information about the local scheduler; its
            "DBClientID" and "NumGPUs" entries are used.
        worker: The worker whose ``redis_client`` performs the transaction.

    Returns:
        True if the GPUs were successfully reserved and False otherwise.
    """
    assert num_gpus != 0
    local_scheduler_id = local_scheduler["DBClientID"]
    local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
    # Use the hex driver ID as the key so that the dictionary is JSON
    # serializable.
    driver_id_hex = binary_to_hex(driver_id)

    # Attempt to acquire GPUs atomically.
    with worker.redis_client.pipeline() as pipe:
        while True:
            try:
                # If this key is changed before the transaction below (the
                # multi/exec block), then the transaction will not take place.
                pipe.watch(local_scheduler_id)

                # Figure out how many GPUs are currently in use. Read through
                # the watching pipeline, which executes commands immediately
                # between watch() and multi().
                result = pipe.hget(local_scheduler_id, "gpus_in_use")
                if result is None:
                    gpus_in_use = dict()
                else:
                    if isinstance(result, bytes):
                        result = result.decode("ascii")
                    gpus_in_use = json.loads(result)
                num_gpus_in_use = sum(gpus_in_use.values())
                assert num_gpus_in_use <= local_scheduler_total_gpus

                success = (
                    local_scheduler_total_gpus - num_gpus_in_use >= num_gpus)

                pipe.multi()
                if success:
                    # There are enough available GPUs, so reserve them and
                    # stick the updated counts back in Redis.
                    gpus_in_use[driver_id_hex] = (
                        gpus_in_use.get(driver_id_hex, 0) + num_gpus)
                    pipe.hset(local_scheduler_id, "gpus_in_use",
                              json.dumps(gpus_in_use))
                pipe.execute()
                # If a WatchError is not raised, then the operations should
                # have gone through atomically.
                return success
            except redis.WatchError:
                # Another client must have changed the watched key between
                # the time we started WATCHing it and the pipeline's
                # execution. We should just retry.
                continue
Ejemplo n.º 35
0
    def client_table(self):
        """Fetch and parse the Redis DB client table.

        Returns:
            Without the raylet: a dictionary mapping each node IP address to
                a list of parsed client entries on that node.
            With the raylet: a list with one dictionary per client table
                entry.
        """
        self._check_connected()

        if self.use_raylet:
            # This is the raylet code path; the lookup key is 20 0xff bytes.
            nil_client_id = 20 * b"\xff"
            reply = self.redis_client.execute_command(
                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
                nil_client_id)
            clients = []
            entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                reply, 0)
            parse_client = (
                ray.gcs_utils.ClientTableData.GetRootAsClientTableData)

            for entry_index in range(entries.EntriesLength()):
                client = parse_client(entries.Entries(entry_index), 0)

                resources = {}
                for label_index in range(client.ResourcesTotalLabelLength()):
                    label = client.ResourcesTotalLabel(label_index).decode(
                        "ascii")
                    resources[label] = client.ResourcesTotalCapacity(
                        label_index)

                clients.append({
                    "ClientID": ray.utils.binary_to_hex(client.ClientId()),
                    "IsInsertion": client.IsInsertion(),
                    "NodeManagerAddress": client.NodeManagerAddress().decode(
                        "ascii"),
                    "NodeManagerPort": client.NodeManagerPort(),
                    "ObjectManagerPort": client.ObjectManagerPort(),
                    "ObjectStoreSocketName": (
                        client.ObjectStoreSocketName().decode("ascii")),
                    "RayletSocketName": client.RayletSocketName().decode(
                        "ascii"),
                    "Resources": resources,
                })
            return clients

        node_info = {}
        client_keys = self.redis_client.keys(
            ray.gcs_utils.DB_CLIENT_PREFIX + "*")
        for key in client_keys:
            client_info = self.redis_client.hgetall(key)
            node_ip_address = decode(client_info[b"node_ip_address"])
            for required_field in (b"client_type", b"deleted",
                                   b"ray_client_id"):
                assert required_field in client_info

            parsed = {}
            for field, value in client_info.items():
                if field == b"node_ip_address":
                    continue
                if field == b"client_type":
                    parsed["ClientType"] = decode(value)
                elif field == b"deleted":
                    parsed["Deleted"] = bool(int(decode(value)))
                elif field == b"ray_client_id":
                    parsed["DBClientID"] = binary_to_hex(value)
                elif field == b"manager_address":
                    parsed["AuxAddress"] = decode(value)
                elif field == b"local_scheduler_socket_name":
                    parsed["LocalSchedulerSocketName"] = decode(value)
                elif client_info[b"client_type"] == b"local_scheduler":
                    # Any other local scheduler field is a resource amount.
                    parsed[field.decode("ascii")] = float(decode(value))
                else:
                    parsed[field.decode("ascii")] = decode(value)

            node_info.setdefault(node_ip_address, []).append(parsed)

        return node_info
Ejemplo n.º 36
0
 def __repr__(self):
     """Describe the descriptor by module, class, function and source hash.

     The result has the form
     ``FunctionDescriptor:<module>.<class>.<function>.<hex source hash>``.
     """
     parts = (self._module_name, self._class_name, self._function_name,
              binary_to_hex(self._function_source_hash))
     return "FunctionDescriptor:" + ".".join(parts)
Ejemplo n.º 37
0
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
  """Attempt to acquire GPUs on a particular local scheduler for an actor.

  GPU reservations are stored in Redis in the hash keyed by the local
  scheduler's DBClientID, under the field "gpus_in_use", as a JSON object
  mapping a hex driver ID to the list of GPU IDs that driver holds. The
  read-modify-write runs as a WATCH/MULTI/EXEC transaction and is retried
  whenever another client modifies that key concurrently.

  Args:
    num_gpus: The number of GPUs to acquire.
    driver_id: The ID of the driver responsible for creating the actor.
    local_scheduler: Information about the local scheduler; the keys
      "DBClientID" and "NumGPUs" are read.
    worker: The worker whose redis_client is used for the transaction.

  Returns:
    A list of the GPU IDs that were successfully acquired. This should have
      length either equal to num_gpus or equal to 0. When GPUs are acquired,
      the lowest-numbered free IDs are chosen.
  """
  local_scheduler_id = local_scheduler["DBClientID"]
  local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])

  # Attempt to acquire GPU IDs atomically.
  with worker.redis_client.pipeline() as pipe:
    while True:
      # Start every attempt from a clean slate so an aborted transaction
      # never leaks a previous attempt's selection into the result.
      gpus_to_acquire = []
      try:
        # If this key is changed before the transaction below (the multi/exec
        # block), then the transaction will not take place.
        pipe.watch(local_scheduler_id)

        # After WATCH the pipeline executes commands immediately, so read
        # through it: the read then uses the connection holding the WATCH.
        result = pipe.hget(local_scheduler_id, "gpus_in_use")
        gpus_in_use = dict() if result is None else json.loads(result)
        all_gpu_ids_in_use = [
            gpu_id for gpu_ids in gpus_in_use.values() for gpu_id in gpu_ids
        ]
        assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus
        assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use)

        pipe.multi()

        if local_scheduler_total_gpus - len(all_gpu_ids_in_use) >= num_gpus:
          # There are enough available GPUs, so try to reserve some.
          free_gpu_ids = set(range(local_scheduler_total_gpus))
          for gpu_id in all_gpu_ids_in_use:
            # Raises KeyError for an out-of-range stored GPU ID instead of
            # silently ignoring corrupted state.
            free_gpu_ids.remove(gpu_id)
          # Sort so the choice is deterministic; set order is unspecified.
          gpus_to_acquire = sorted(free_gpu_ids)[:num_gpus]

          # Use the hex driver ID so that the dictionary is JSON serializable.
          driver_id_hex = binary_to_hex(driver_id)
          gpus_in_use.setdefault(driver_id_hex, []).extend(gpus_to_acquire)

          # Stick the updated GPU IDs back in Redis.
          pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use))

        pipe.execute()
        # If a WatchError is not raised, then the operations should have gone
        # through atomically.
        break
      except redis.WatchError:
        # Another client must have changed the watched key between the time we
        # started WATCHing it and the pipeline's execution. Retry from scratch.
        continue

  return gpus_to_acquire
Ejemplo n.º 38
0
    def _xray_clean_up_entries_for_driver(self, driver_id):
        """Delete a driver's task and object entries from the redis shards.

        A raylet task is targeted when the "DriverID" in its first task-table
        entry matches this driver. An object is targeted when its creating
        task is one of those tasks. Deletion is best effort: shortfalls are
        only logged.

        Args:
            driver_id: The driver id.
        """
        task_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        object_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        all_tasks = self.state.task_table()
        driver_hex = binary_to_hex(driver_id)
        # Binary IDs of this driver's tasks; tasks without entries are skipped.
        task_bins = {
            hex_to_binary(task_hex)
            for task_hex, task_entries in all_tasks.items()
            if len(task_entries) != 0
            and task_entries[0]["TaskSpec"]["DriverID"] == driver_hex
        }

        # Collect objects whose creating task belongs to the driver.
        all_objects = self.state.object_table()
        object_bins = set()
        for object_id, object_entries in all_objects.items():
            assert len(object_entries) > 0
            creating_task_bin = ray.local_scheduler.compute_task_id(
                object_id).id()
            if creating_task_bin in task_bins:
                object_bins.add(object_id.id())

        def shard_of(id_bin):
            shard_count = len(self.state.redis_clients)
            return binary_to_object_id(id_bin).redis_shard_hash() % shard_count

        # Bucket the fully-prefixed redis keys by the shard that owns them.
        keys_by_shard = [[] for _ in self.state.redis_clients]
        for task_bin in task_bins:
            keys_by_shard[shard_of(task_bin)].append(task_prefix + task_bin)
        for object_bin in object_bins:
            keys_by_shard[shard_of(object_bin)].append(
                object_prefix + object_bin)

        # Best effort: report, but do not fail on, keys that were not removed.
        for shard_index, shard_keys in enumerate(keys_by_shard):
            if not shard_keys:
                continue
            shard_client = self.state.redis_clients[shard_index]
            num_deleted = shard_client.delete(*shard_keys)
            log.info("Removed {} dead redis entries of the driver"
                     " from redis shard {}.".format(num_deleted, shard_index))
            if num_deleted != len(shard_keys):
                log.warning("Failed to remove {} relevant redis entries"
                            " from redis shard {}.".format(
                                len(shard_keys) - num_deleted, shard_index))
Ejemplo n.º 39
0
 def __repr__(self):
     """Describe this descriptor as 'FunctionDescriptor:mod.cls.fn.hash'."""
     source_hash_hex = binary_to_hex(self._function_source_hash)
     dotted_path = ".".join((self._module_name, self._class_name,
                             self._function_name, source_hash_hex))
     return "FunctionDescriptor:" + dotted_path
Ejemplo n.º 40
0
    def _xray_clean_up_entries_for_job(self, job_id):
        """Remove this job's object/task entries from redis.

        Removes the control-state entries of every task whose spec "JobID"
        matches this job. It also removes the entries of every object whose
        creating task is one of those tasks. Deletion is best effort: a
        mismatch between requested and deleted key counts is only logged.

        Args:
            job_id: The job id (binary; it is hex-encoded for comparison).
        """

        xray_task_table_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        xray_object_table_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        task_table_objects = ray.tasks()
        job_id_hex = binary_to_hex(job_id)
        job_task_id_bins = set()
        for task_id_hex, task_info in task_table_objects.items():
            task_job_id_hex = task_info["TaskSpec"]["JobID"]
            if job_id_hex != task_job_id_hex:
                # Ignore tasks that aren't from this job.
                continue
            job_task_id_bins.add(hex_to_binary(task_id_hex))

        # Get objects whose creating task belongs to the job.
        object_table_objects = ray.objects()
        job_object_id_bins = set()
        for object_id in object_table_objects:
            task_id_bin = ray._raylet.compute_task_id(object_id).binary()
            if task_id_bin in job_task_id_bins:
                job_object_id_bins.add(object_id.binary())

        def to_shard_index(id_bin):
            # Task IDs and object IDs are hashed with their own ID types;
            # the binary length tells them apart.
            num_shards = len(ray.state.state.redis_clients)
            if len(id_bin) == ray.TaskID.size():
                return binary_to_task_id(id_bin).redis_shard_hash() % num_shards
            return binary_to_object_id(id_bin).redis_shard_hash() % num_shards

        # Form the redis keys to delete, bucketed by owning shard.
        sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
        for task_id_bin in job_task_id_bins:
            sharded_keys[to_shard_index(task_id_bin)].append(
                xray_task_table_prefix + task_id_bin)
        for object_id_bin in job_object_id_bins:
            sharded_keys[to_shard_index(object_id_bin)].append(
                xray_object_table_prefix + object_id_bin)

        # Remove with best effort.
        for shard_index, keys in enumerate(sharded_keys):
            if not keys:
                continue
            # Named so it does not shadow the `redis` module.
            redis_client = ray.state.state.redis_clients[shard_index]
            num_deleted = redis_client.delete(*keys)
            logger.info("Monitor: "
                        "Removed {} dead redis entries of the "
                        "driver from redis shard {}.".format(
                            num_deleted, shard_index))
            if num_deleted != len(keys):
                logger.warning("Monitor: "
                               "Failed to remove {} relevant redis "
                               "entries from redis shard {}.".format(
                                   len(keys) - num_deleted, shard_index))
Ejemplo n.º 41
0
    def _task_table(self, task_id):
        """Look up and decode the raylet task-table entry for one task.

        Args:
            task_id: The ray.TaskID whose entry should be fetched.

        Returns:
            A dict with an "ExecutionSpec" sub-dict and a "TaskSpec" sub-dict
            describing the task, or an empty dict if the lookup returned
            None.
        """
        assert isinstance(task_id, ray.TaskID)
        raw_reply = self._execute_command(
            task_id, "RAY.TABLE_LOOKUP",
            gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary())
        if raw_reply is None:
            return {}

        entries = gcs_utils.GcsEntry.FromString(raw_reply).entries
        # Exactly one entry is expected for a single-task lookup.
        assert len(entries) == 1
        task_proto = gcs_utils.TaskTableData.FromString(entries[0]).task

        spec = ray._raylet.TaskSpec.from_string(
            task_proto.task_spec.SerializeToString())
        descriptor = FunctionDescriptor.from_bytes_list(
            spec.function_descriptor_list())

        task_spec_info = {
            "JobID": spec.job_id().hex(),
            "TaskID": spec.task_id().hex(),
            "ParentTaskID": spec.parent_task_id().hex(),
            "ParentCounter": spec.parent_counter(),
            "ActorID": spec.actor_id().hex(),
            "ActorCreationID": spec.actor_creation_id().hex(),
            "ActorCreationDummyObjectID":
                spec.actor_creation_dummy_object_id().hex(),
            "PreviousActorTaskDummyObjectID":
                spec.previous_actor_task_dummy_object_id().hex(),
            "ActorCounter": spec.actor_counter(),
            "Args": spec.arguments(),
            "ReturnObjectIDs": spec.returns(),
            "RequiredResources": spec.required_resources(),
            "FunctionID": descriptor.function_id.hex(),
            "FunctionHash": binary_to_hex(descriptor.function_hash),
            "ModuleName": descriptor.module_name,
            "ClassName": descriptor.class_name,
            "FunctionName": descriptor.function_name,
        }

        exec_spec = ray._raylet.TaskExecutionSpec.from_string(
            task_proto.task_execution_spec.SerializeToString())
        return {
            "ExecutionSpec": {"NumForwards": exec_spec.num_forwards()},
            "TaskSpec": task_spec_info,
        }
Ejemplo n.º 42
0
Archivo: utils.py Proyecto: zzmcdc/ray
 def default(self, obj):
     """Serialize bytes as a hex string; defer everything else to the base.

     json.JSONEncoder.default raises TypeError for unsupported objects, so
     non-bytes values still produce the standard serialization error.
     """
     if not isinstance(obj, bytes):
         return json.JSONEncoder.default(self, obj)
     return binary_to_hex(obj)