def trial_info(self, trial):
    if trial.last_result:
        result = trial.last_result.copy()
    else:
        result = None
    info_dict = {
        "id": trial.trial_id,
        "trainable_name": trial.trainable_name,
        "config": binary_to_hex(cloudpickle.dumps(trial.config)),
        "status": trial.status,
        "result": binary_to_hex(cloudpickle.dumps(result))
    }
    return info_dict
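# A minimal decode-side sketch for trial_info() above (the helper name is
# hypothetical, not part of the original code). The "config" and "result"
# fields are cloudpickle payloads encoded with binary_to_hex, so the inverse
# is hex_to_binary followed by cloudpickle.loads; hex_to_binary is assumed to
# be the same ray.utils helper used elsewhere in these snippets.
import cloudpickle

from ray.utils import hex_to_binary


def decode_trial_info(info_dict):
    """Recover the Python objects from a trial_info() dictionary."""
    return {
        "id": info_dict["id"],
        "trainable_name": info_dict["trainable_name"],
        "status": info_dict["status"],
        "config": cloudpickle.loads(hex_to_binary(info_dict["config"])),
        "result": cloudpickle.loads(hex_to_binary(info_dict["result"])),
    }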
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id_binary: A string of bytes with the task ID to get information about. Returns: A dictionary with information about the task ID in question. """ message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "", task_id.id()) gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( message, 0) assert gcs_entries.EntriesLength() == 1 task_table_message = ray.gcs_utils.Task.GetRootAsTask( gcs_entries.Entries(0), 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() task_spec = ray.raylet.task_from_string(task_spec) function_descriptor_list = task_spec.function_descriptor_list() function_descriptor = FunctionDescriptor.from_bytes_list( function_descriptor_list) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), "ParentCounter": task_spec.parent_counter(), "ActorID": binary_to_hex(task_spec.actor_id().id()), "ActorCreationID": binary_to_hex( task_spec.actor_creation_id().id()), "ActorCreationDummyObjectID": binary_to_hex( task_spec.actor_creation_dummy_object_id().id()), "ActorCounter": task_spec.actor_counter(), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), "RequiredResources": task_spec.required_resources(), "FunctionID": binary_to_hex(function_descriptor.function_id.id()), "FunctionHash": binary_to_hex(function_descriptor.function_hash), "ModuleName": function_descriptor.module_name, "ClassName": function_descriptor.class_name, "FunctionName": function_descriptor.function_name, } return { "ExecutionSpec": { "Dependencies": [ execution_spec.Dependencies(i) for i in range(execution_spec.DependenciesLength()) ], "LastTimestamp": execution_spec.LastTimestamp(), "NumForwards": execution_spec.NumForwards() }, "TaskSpec": task_spec_info }
def actors(self):
    actor_keys = self.redis_client.keys("Actor:*")
    actor_info = dict()
    for key in actor_keys:
        info = self.redis_client.hgetall(key)
        actor_id = key[len("Actor:"):]
        assert len(actor_id) == 20
        actor_info[binary_to_hex(actor_id)] = {
            "class_id": binary_to_hex(info[b"class_id"]),
            "driver_id": binary_to_hex(info[b"driver_id"]),
            "local_scheduler_id": binary_to_hex(
                info[b"local_scheduler_id"]),
            "num_gpus": int(info[b"num_gpus"]),
            "removed": decode(info[b"removed"]) == "True"
        }
    return actor_info
def db_client_notification_handler(self, unused_channel, data): """Handle a notification from the db_client table from Redis. This handler processes notifications from the db_client table. Notifications should be parsed using the SubscribeToDBClientTableReply flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the associated state in the state tables should be handled by the caller. """ notification_object = (SubscribeToDBClientTableReply. GetRootAsSubscribeToDBClientTableReply(data, 0)) db_client_id = binary_to_hex(notification_object.DbClientId()) client_type = notification_object.ClientType() is_insertion = notification_object.IsInsertion() # If the update was an insertion, we ignore it. if is_insertion: return # If the update was a deletion, add them to our accounting for dead # local schedulers and plasma managers. log.warn("Removed {}, client ID {}".format(client_type, db_client_id)) if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: if db_client_id not in self.dead_local_schedulers: self.dead_local_schedulers.add(db_client_id) elif client_type == PLASMA_MANAGER_CLIENT_TYPE: if db_client_id not in self.dead_plasma_managers: self.dead_plasma_managers.add(db_client_id) # Stop tracking this plasma manager's heartbeats, since it's # already dead. del self.live_plasma_managers[db_client_id]
def __getstate__(self): """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a DISK checkpoint. """ assert self._checkpoint.storage == Checkpoint.DISK, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) pickle_data = { "_checkpoint": self._checkpoint, "config": self.config, "custom_loggers": self.custom_loggers, "sync_function": self.sync_function, "last_result": self.last_result } for key, value in pickle_data.items(): state[key] = binary_to_hex(cloudpickle.dumps(value)) state["runner"] = None state["result_logger"] = None if self.status == Trial.RUNNING: state["status"] = Trial.PENDING if self.result_logger: self.result_logger.flush() state["__logger_started__"] = True else: state["__logger_started__"] = False return copy.deepcopy(state)
def task_table(self, task_id=None):
    """Fetch and parse the task table information for one or more task IDs.

    Args:
        task_id: A hex string of the task ID to fetch information about. If
            this is None, then the task object table is fetched.

    Returns:
        Information from the task table.
    """
    self._check_connected()
    if task_id is not None:
        task_id = ray.TaskID(hex_to_binary(task_id))
        return self._task_table(task_id)
    else:
        task_table_keys = self._keys(
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
        task_ids_binary = [
            key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
            for key in task_table_keys
        ]

        results = {}
        for task_id_binary in task_ids_binary:
            results[binary_to_hex(task_id_binary)] = self._task_table(
                ray.TaskID(task_id_binary))
        return results
def error_messages(self, job_id=None):
    """Get the error messages for all jobs or a specific job.

    Args:
        job_id: The specific job to get the errors for. If this is None,
            then this method retrieves the errors for all jobs.

    Returns:
        A dictionary mapping job ID to a list of the error messages for
        that job.
    """
    if job_id is not None:
        assert isinstance(job_id, ray.DriverID)
        return self._error_messages(job_id)

    error_table_keys = self.redis_client.keys(
        ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
    job_ids = [
        key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
        for key in error_table_keys
    ]

    return {
        binary_to_hex(job_id): self._error_messages(ray.DriverID(job_id))
        for job_id in job_ids
    }
def _xray_clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from redis. Removes control-state entries of all tasks and task return objects belonging to the driver. Args: driver_id: The driver id. """ xray_task_table_prefix = ( ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii")) xray_object_table_prefix = ( ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) task_table_objects = self.state.task_table() driver_id_hex = binary_to_hex(driver_id) driver_task_id_bins = set() for task_id_hex, task_info in task_table_objects.items(): task_table_object = task_info["TaskSpec"] task_driver_id_hex = task_table_object["DriverID"] if driver_id_hex != task_driver_id_hex: # Ignore tasks that aren't from this driver. continue driver_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. object_table_objects = self.state.object_table() driver_object_id_bins = set() for object_id, _ in object_table_objects.items(): task_id_bin = ray.raylet.compute_task_id(object_id).id() if task_id_bin in driver_task_id_bins: driver_object_id_bins.add(object_id.id()) def to_shard_index(id_bin): return binary_to_object_id(id_bin).redis_shard_hash() % len( self.state.redis_clients) # Form the redis keys to delete. sharded_keys = [[] for _ in range(len(self.state.redis_clients))] for task_id_bin in driver_task_id_bins: sharded_keys[to_shard_index(task_id_bin)].append( xray_task_table_prefix + task_id_bin) for object_id_bin in driver_object_id_bins: sharded_keys[to_shard_index(object_id_bin)].append( xray_object_table_prefix + object_id_bin) # Remove with best effort. for shard_index in range(len(sharded_keys)): keys = sharded_keys[shard_index] if len(keys) == 0: continue redis = self.state.redis_clients[shard_index] num_deleted = redis.delete(*keys) logger.info("Removed {} dead redis entries of the driver from" " redis shard {}.".format(num_deleted, shard_index)) if num_deleted != len(keys): logger.warning("Failed to remove {} relevant redis entries" " from redis shard {}.".format( len(keys) - num_deleted, shard_index))
def workers(self): """Get a dictionary mapping worker ID to worker information.""" worker_keys = self.redis_client.keys("Worker*") workers_data = dict() for worker_key in worker_keys: worker_info = self.redis_client.hgetall(worker_key) worker_id = binary_to_hex(worker_key[len("Workers:"):]) workers_data[worker_id] = { "local_scheduler_socket": (worker_info[b"local_scheduler_socket"] .decode("ascii")), "node_ip_address": (worker_info[b"node_ip_address"] .decode("ascii")), "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] .decode("ascii")), "plasma_store_socket": (worker_info[b"plasma_store_socket"] .decode("ascii")) } if b"stderr_file" in worker_info: workers_data[worker_id]["stderr_file"] = ( worker_info[b"stderr_file"].decode("ascii")) if b"stdout_file" in worker_info: workers_data[worker_id]["stdout_file"] = ( worker_info[b"stdout_file"].decode("ascii")) return workers_data
def function_table(self, function_id=None): """Fetch and parse the function table. Returns: A dictionary that maps function IDs to information about the function. """ self._check_connected() function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*") results = {} for key in function_table_keys: info = self.redis_client.hgetall(key) function_info_parsed = { "DriverID": binary_to_hex(info[b"driver_id"]), "Module": decode(info[b"module"]), "Name": decode(info[b"name"])} results[binary_to_hex(info[b"function_id"])] = function_info_parsed return results
def _task_table_shard(shard_index):
    redis_client = ray.global_state.redis_clients[shard_index]
    task_table_keys = redis_client.keys(TASK_PREFIX + b"*")
    results = {}
    for key in task_table_keys:
        task_id_binary = key[len(TASK_PREFIX):]
        results[binary_to_hex(task_id_binary)] = ray.global_state._task_table(
            ray.ObjectID(task_id_binary))

    return results
def _object_table_shard(shard_index):
    redis_client = ray.global_state.redis_clients[shard_index]
    object_table_keys = redis_client.keys(OBJECT_LOCATION_PREFIX + b"*")
    results = {}
    for key in object_table_keys:
        object_id_binary = key[len(OBJECT_LOCATION_PREFIX):]
        results[binary_to_hex(object_id_binary)] = (
            ray.global_state._object_table(ray.ObjectID(object_id_binary)))

    return results
def _profile_table(self, batch_id): """Get the profile events for a given batch of profile events. Args: batch_id: An identifier for a batch of profile events. Returns: A list of the profile events for the specified batch. """ # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.PROFILE, "", batch_id.binary()) if message is None: return [] gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( message, 0) profile_events = [] for i in range(gcs_entries.EntriesLength()): profile_table_message = ( ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( gcs_entries.Entries(i), 0)) component_type = decode(profile_table_message.ComponentType()) component_id = binary_to_hex(profile_table_message.ComponentId()) node_ip_address = decode( profile_table_message.NodeIpAddress(), allow_none=True) for j in range(profile_table_message.ProfileEventsLength()): profile_event_message = profile_table_message.ProfileEvents(j) profile_event = { "event_type": decode(profile_event_message.EventType()), "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, "start_time": profile_event_message.StartTime(), "end_time": profile_event_message.EndTime(), "extra_data": json.loads( decode(profile_event_message.ExtraData())), } profile_events.append(profile_event) return profile_events
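# Usage sketch (illustrative only): each entry returned by _profile_table()
# is a plain dict, so downstream code can aggregate durations per component
# type. The field names match the dictionary built above; the helper itself
# is hypothetical.
from collections import defaultdict


def total_time_by_component(profile_events):
    """Sum (end_time - start_time) per component_type."""
    totals = defaultdict(float)
    for event in profile_events:
        totals[event["component_type"]] += (
            event["end_time"] - event["start_time"])
    return dict(totals)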
def cleanup_actors(self): """Recreate any live actors whose corresponding local scheduler died. For any live actor whose local scheduler just died, we choose a new local scheduler and broadcast a notification to create that actor. """ actor_info = self.state.actors() for actor_id, info in actor_info.items(): if (not info["removed"] and info["local_scheduler_id"] in self.dead_local_schedulers): # Choose a new local scheduler to run the actor. local_scheduler_id = ray.utils.select_local_scheduler( info["driver_id"], self.state.local_schedulers(), info["num_gpus"], self.redis) import sys sys.stdout.flush() # The new local scheduler should not be the same as the old # local scheduler. TODO(rkn): This should not be an assert, it # should be something more benign. assert (binary_to_hex(local_scheduler_id) != info["local_scheduler_id"]) # Announce to all of the local schedulers that the actor should # be recreated on this new local scheduler. ray.utils.publish_actor_creation( hex_to_binary(actor_id), hex_to_binary(info["driver_id"]), local_scheduler_id, True, self.redis) log.info("Actor {} for driver {} was on dead local scheduler " "{}. It is being recreated on local scheduler {}" .format(actor_id, info["driver_id"], info["local_scheduler_id"], binary_to_hex(local_scheduler_id))) # Update the actor info in Redis. self.redis.hset(b"Actor:" + hex_to_binary(actor_id), "local_scheduler_id", local_scheduler_id)
def _object_table(self, object_id): """Fetch and parse the object table information for a single object ID. Args: object_id_binary: A string of bytes with the object ID to get information about. Returns: A dictionary with information about the object ID in question. """ # Allow the argument to be either an ObjectID or a hex string. if not isinstance(object_id, ray.local_scheduler.ObjectID): object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. object_locations = self._execute_command(object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id()) if object_locations is not None: manager_ids = [binary_to_hex(manager_id) for manager_id in object_locations] else: manager_ids = None result_table_response = self._execute_command( object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) result_table_message = ResultTableReply.GetRootAsResultTableReply( result_table_response, 0) result = {"ManagerIDs": manager_ids, "TaskID": binary_to_hex(result_table_message.TaskId()), "IsPut": bool(result_table_message.IsPut()), "DataSize": result_table_message.DataSize(), "Hash": binary_to_hex(result_table_message.Hash())} return result
def xray_driver_removed_handler(self, unused_channel, data): """Handle a notification that a driver has been removed. Args: unused_channel: The message channel. data: The message data. """ gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( data, 0) driver_data = gcs_entries.Entries(0) message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( driver_data, 0) driver_id = message.DriverId() logger.info("XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) self._xray_clean_up_entries_for_driver(driver_id)
def __init__( self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR, experiment_tag=None, resources=Resources(cpu=1, gpu=0), stopping_criterion=None, checkpoint_freq=0, restore_path=None, upload_dir=None, max_failures=0): """Initialize a new trial. The args here take the same meaning as the command line flags defined in ray.tune.config_parser. """ if not _default_registry.contains( TRAINABLE_CLASS, trainable_name): raise TuneError("Unknown trainable: " + trainable_name) if stopping_criterion: for k in stopping_criterion: if k not in TrainingResult._fields: raise TuneError( "Stopping condition key `{}` must be one of {}".format( k, TrainingResult._fields)) # Trial config self.trainable_name = trainable_name self.config = config or {} self.local_dir = local_dir self.experiment_tag = experiment_tag self.resources = resources self.stopping_criterion = stopping_criterion or {} self.checkpoint_freq = checkpoint_freq self.upload_dir = upload_dir self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = None self._checkpoint_path = restore_path self._checkpoint_obj = None self.runner = None self.status = Trial.PENDING self.location = None self.logdir = None self.result_logger = None self.last_debug = 0 self.trial_id = binary_to_hex(random_string())[:8] self.error_file = None self.num_failures = 0
def client_table(self): """Fetch and parse the Redis DB client table. Returns: Information about the Ray clients in the cluster. """ self._check_connected() db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") node_info = dict() for key in db_client_keys: client_info = self.redis_client.hgetall(key) node_ip_address = decode(client_info[b"node_ip_address"]) if node_ip_address not in node_info: node_info[node_ip_address] = [] client_info_parsed = {} assert b"client_type" in client_info assert b"deleted" in client_info assert b"ray_client_id" in client_info for field, value in client_info.items(): if field == b"node_ip_address": pass elif field == b"client_type": client_info_parsed["ClientType"] = decode(value) elif field == b"deleted": client_info_parsed["Deleted"] = bool(int(decode(value))) elif field == b"ray_client_id": client_info_parsed["DBClientID"] = binary_to_hex(value) elif field == b"manager_address": client_info_parsed["AuxAddress"] = decode(value) elif field == b"local_scheduler_socket_name": client_info_parsed["LocalSchedulerSocketName"] = ( decode(value)) elif client_info[b"client_type"] == b"local_scheduler": # The remaining fields are resource types. client_info_parsed[field.decode("ascii")] = float( decode(value)) else: client_info_parsed[field.decode("ascii")] = decode(value) node_info[node_ip_address].append(client_info_parsed) return node_info
def workers(self): """Get a dictionary mapping worker ID to worker information.""" worker_keys = self.redis_client.keys("Worker*") workers_data = {} for worker_key in worker_keys: worker_info = self.redis_client.hgetall(worker_key) worker_id = binary_to_hex(worker_key[len("Workers:"):]) workers_data[worker_id] = { "node_ip_address": decode(worker_info[b"node_ip_address"]), "plasma_store_socket": decode( worker_info[b"plasma_store_socket"]) } if b"stderr_file" in worker_info: workers_data[worker_id]["stderr_file"] = decode( worker_info[b"stderr_file"]) if b"stdout_file" in worker_info: workers_data[worker_id]["stdout_file"] = decode( worker_info[b"stdout_file"]) return workers_data
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id_binary: A string of bytes with the task ID to get information about. Returns: A dictionary with information about the task ID in question. TASK_STATUS_MAPPING should be used to parse the "State" field into a human-readable string. """ task_table_response = self._execute_command(task_id, "RAY.TASK_TABLE_GET", task_id.id()) if task_table_response is None: raise Exception("There is no entry for task ID {} in the task " "table.".format(binary_to_hex(task_id.id()))) task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0) task_spec = task_table_message.TaskSpec() task_spec = ray.local_scheduler.task_from_string(task_spec) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), "ParentCounter": task_spec.parent_counter(), "ActorID": binary_to_hex(task_spec.actor_id().id()), "ActorCounter": task_spec.actor_counter(), "FunctionID": binary_to_hex(task_spec.function_id().id()), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), "RequiredResources": task_spec.required_resources()} return {"State": task_table_message.State(), "LocalSchedulerID": binary_to_hex( task_table_message.LocalSchedulerId()), "ExecutionDependenciesString": task_table_message.ExecutionDependencies(), "SpillbackCount": task_table_message.SpillbackCount(), "TaskSpec": task_spec_info}
def __getstate__(self):
    """Memento generator for Trial.

    Sets RUNNING trials to PENDING, and flushes the result logger.
    Note this can only occur if the trial holds a DISK checkpoint.
    """
    assert self._checkpoint.storage == Checkpoint.DISK, (
        "Checkpoint must not be in-memory.")
    state = self.__dict__.copy()
    state["resources"] = resources_to_json(self.resources)

    for key in self._nonjson_fields:
        state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

    state["runner"] = None
    state["result_logger"] = None
    if self.result_logger:
        self.result_logger.flush()
        state["__logger_started__"] = True
    else:
        state["__logger_started__"] = False
    return copy.deepcopy(state)
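# Sketch of a matching __setstate__, assuming json_to_resources is the
# counterpart of resources_to_json and that _nonjson_fields (either carried in
# the pickled state or defined on the class) lists the keys that __getstate__
# hex-pickled; logger re-creation is elided. The inverse of
# binary_to_hex(cloudpickle.dumps(x)) is cloudpickle.loads(hex_to_binary(x)).
def __setstate__(self, state):
    logger_started = state.pop("__logger_started__", False)
    state["resources"] = json_to_resources(state["resources"])

    # Accept _nonjson_fields from the pickled state or from the class.
    fields = state.get("_nonjson_fields",
                       getattr(self, "_nonjson_fields", []))
    for key in fields:
        state[key] = cloudpickle.loads(hex_to_binary(state[key]))

    self.__dict__.update(state)
    if logger_started:
        # Re-create the result logger here if the surrounding class provides
        # a hook for it.
        pass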
def task_table(self, task_id=None): """Fetch and parse the task table information for one or more task IDs. Args: task_id: A hex string of the task ID to fetch information about. If this is None, then the task object table is fetched. Returns: Information from the task table. """ self._check_connected() if task_id is not None: task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) return self._task_table(task_id) else: task_table_keys = self._keys(TASK_PREFIX + "*") results = {} for key in task_table_keys: task_id_binary = key[len(TASK_PREFIX):] results[binary_to_hex(task_id_binary)] = self._task_table( ray.local_scheduler.ObjectID(task_id_binary)) return results
def driver_removed_handler(self, unused_channel, data):
    """Handle a notification that a driver has been removed.

    This releases any GPU resources that were reserved for that driver in
    Redis.
    """
    message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
    driver_id = message.DriverId()
    log.info(
        "Driver {} has been removed.".format(binary_to_hex(driver_id)))

    # Get a list of the local schedulers that have not been deleted.
    local_schedulers = ray.global_state.local_schedulers()

    self._clean_up_entries_for_driver(driver_id)

    # Release any GPU resources that have been reserved for this driver in
    # Redis.
    for local_scheduler in local_schedulers:
        if local_scheduler.get("GPU", 0) > 0:
            local_scheduler_id = local_scheduler["DBClientID"]

            num_gpus_returned = 0

            # Perform a transaction to return the GPUs.
            with self.redis.pipeline() as pipe:
                while True:
                    try:
                        # If this key is changed before the transaction
                        # below (the multi/exec block), then the
                        # transaction will not take place.
                        pipe.watch(local_scheduler_id)

                        result = pipe.hget(local_scheduler_id,
                                           "gpus_in_use")
                        gpus_in_use = (dict() if result is None else
                                       json.loads(result.decode("ascii")))

                        driver_id_hex = binary_to_hex(driver_id)
                        if driver_id_hex in gpus_in_use:
                            num_gpus_returned = gpus_in_use.pop(
                                driver_id_hex)

                        pipe.multi()
                        pipe.hset(local_scheduler_id, "gpus_in_use",
                                  json.dumps(gpus_in_use))
                        pipe.execute()
                        # If a WatchError is not raised, then the
                        # operations should have gone through atomically.
                        break
                    except redis.WatchError:
                        # Another client must have changed the watched key
                        # between the time we started WATCHing it and the
                        # pipeline's execution. We should just retry.
                        continue

            log.info("Driver {} is returning GPU IDs {} to local "
                     "scheduler {}.".format(
                         binary_to_hex(driver_id), num_gpus_returned,
                         local_scheduler_id))
def _object_table(self, object_id): """Fetch and parse the object table information for a single object ID. Args: object_id_binary: A string of bytes with the object ID to get information about. Returns: A dictionary with information about the object ID in question. """ # Allow the argument to be either an ObjectID or a hex string. if not isinstance(object_id, ray.ObjectID): object_id = ray.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. if not self.use_raylet: # Use the non-raylet code path. object_locations = self._execute_command( object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id()) if object_locations is not None: manager_ids = [ binary_to_hex(manager_id) for manager_id in object_locations ] else: manager_ids = None result_table_response = self._execute_command( object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) result_table_message = ( ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply( result_table_response, 0)) result = { "ManagerIDs": manager_ids, "TaskID": binary_to_hex(result_table_message.TaskId()), "IsPut": bool(result_table_message.IsPut()), "DataSize": result_table_message.DataSize(), "Hash": binary_to_hex(result_table_message.Hash()) } else: # Use the raylet code path. message = self.redis_client.execute_command( "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "", object_id.id()) result = [] gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( message, 0) for i in range(gcs_entry.EntriesLength()): entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( gcs_entry.Entries(i), 0) object_info = { "DataSize": entry.ObjectSize(), "Manager": entry.Manager(), "IsEviction": entry.IsEviction(), "NumEvictions": entry.NumEvictions() } result.append(object_info) return result
def _task_table(self, task_id):
    """Fetch and parse the task table information for a single task ID.

    Args:
        task_id: A task ID to get information about.

    Returns:
        A dictionary with information about the task ID in question.
        TASK_STATUS_MAPPING should be used to parse the "State" field
        into a human-readable string.
    """
    if not self.use_raylet:
        # Use the non-raylet code path.
        task_table_response = self._execute_command(
            task_id, "RAY.TASK_TABLE_GET", task_id.id())
        if task_table_response is None:
            raise Exception("There is no entry for task ID {} in the task "
                            "table.".format(binary_to_hex(task_id.id())))
        task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
            task_table_response, 0)
        task_spec = task_table_message.TaskSpec()
        task_spec = ray.local_scheduler.task_from_string(task_spec)

        task_spec_info = {
            "DriverID": binary_to_hex(task_spec.driver_id().id()),
            "TaskID": binary_to_hex(task_spec.task_id().id()),
            "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
            "ParentCounter": task_spec.parent_counter(),
            "ActorID": binary_to_hex(task_spec.actor_id().id()),
            "ActorCreationID": binary_to_hex(
                task_spec.actor_creation_id().id()),
            "ActorCreationDummyObjectID": binary_to_hex(
                task_spec.actor_creation_dummy_object_id().id()),
            "ActorCounter": task_spec.actor_counter(),
            "FunctionID": binary_to_hex(task_spec.function_id().id()),
            "Args": task_spec.arguments(),
            "ReturnObjectIDs": task_spec.returns(),
            "RequiredResources": task_spec.required_resources()
        }

        execution_dependencies_message = (
            ray.gcs_utils.TaskExecutionDependencies.
            GetRootAsTaskExecutionDependencies(
                task_table_message.ExecutionDependencies(), 0))
        execution_dependencies = [
            ray.ObjectID(
                execution_dependencies_message.ExecutionDependencies(i))
            for i in range(execution_dependencies_message.
                           ExecutionDependenciesLength())
        ]

        # TODO(rkn): The return fields ExecutionDependenciesString and
        # ExecutionDependencies are redundant, so we should remove
        # ExecutionDependencies. However, it is currently used in
        # monitor.py.

        return {
            "State": task_table_message.State(),
            "LocalSchedulerID": binary_to_hex(
                task_table_message.LocalSchedulerId()),
            "ExecutionDependenciesString": task_table_message.
            ExecutionDependencies(),
            "ExecutionDependencies": execution_dependencies,
            "SpillbackCount": task_table_message.SpillbackCount(),
            "TaskSpec": task_spec_info
        }
    else:
        # Use the raylet code path.
        message = self.redis_client.execute_command(
            "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "",
            task_id.id())
        gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
            message, 0)

        info = []
        for i in range(gcs_entries.EntriesLength()):
            task_table_message = ray.gcs_utils.Task.GetRootAsTask(
                gcs_entries.Entries(i), 0)

            execution_spec = task_table_message.TaskExecutionSpec()
            task_spec = task_table_message.TaskSpecification()
            task_spec = ray.local_scheduler.task_from_string(task_spec)
            task_spec_info = {
                "DriverID": binary_to_hex(task_spec.driver_id().id()),
                "TaskID": binary_to_hex(task_spec.task_id().id()),
                "ParentTaskID": binary_to_hex(
                    task_spec.parent_task_id().id()),
                "ParentCounter": task_spec.parent_counter(),
                "ActorID": binary_to_hex(task_spec.actor_id().id()),
                "ActorCreationID": binary_to_hex(
                    task_spec.actor_creation_id().id()),
                "ActorCreationDummyObjectID": binary_to_hex(
                    task_spec.actor_creation_dummy_object_id().id()),
                "ActorCounter": task_spec.actor_counter(),
                "FunctionID": binary_to_hex(task_spec.function_id().id()),
                "Args": task_spec.arguments(),
                "ReturnObjectIDs": task_spec.returns(),
                "RequiredResources": task_spec.required_resources()
            }

            info.append({
                "ExecutionSpec": {
                    "Dependencies": [
                        execution_spec.Dependencies(i)
                        for i in range(execution_spec.DependenciesLength())
                    ],
                    "LastTimestamp": execution_spec.LastTimestamp(),
                    "NumForwards": execution_spec.NumForwards()
                },
                "TaskSpec": task_spec_info
            })

        return info
def _to_cloudpickle(self, obj):
    return {
        "_type": "CLOUDPICKLE_FALLBACK",
        "value": binary_to_hex(cloudpickle.dumps(obj))
    }
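# Hypothetical inverse of the fallback encoder above: recognize the
# "CLOUDPICKLE_FALLBACK" marker and undo binary_to_hex(cloudpickle.dumps(...)).
def _from_cloudpickle(self, obj):
    if isinstance(obj, dict) and obj.get("_type") == "CLOUDPICKLE_FALLBACK":
        return cloudpickle.loads(hex_to_binary(obj["value"]))
    return obj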
def fetch_and_register_actor(actor_class_key, worker): """Import an actor. This will be called by the worker's import thread when the worker receives the actor_class export, assuming that the worker is an actor for that class. """ actor_id_str = worker.actor_id (driver_id, class_id, class_name, module, pickled_class, checkpoint_interval, actor_method_names) = worker.redis_client.hmget(actor_class_key, [ "driver_id", "class_id", "class_name", "module", "class", "checkpoint_interval", "actor_method_names" ]) actor_name = class_name.decode("ascii") module = module.decode("ascii") checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(actor_method_names.decode("ascii")) # Create a temporary actor with some temporary methods so that if the actor # fails to be unpickled, the temporary actor can be used (just to produce # error messages and to prevent the driver from hanging). class TemporaryActor(object): pass worker.actors[actor_id_str] = TemporaryActor() worker.actor_checkpoint_interval = checkpoint_interval def temporary_actor_method(*xs): raise Exception("The actor with name {} failed to be imported, and so " "cannot execute this method".format(actor_name)) # Register the actor method signatures. register_actor_signatures(worker, driver_id, class_name, actor_method_names) # Register the actor method executors. for actor_method_name in actor_method_names: function_id = compute_actor_method_function_id(class_name, actor_method_name).id() temporary_executor = make_actor_method_executor( worker, actor_method_name, temporary_actor_method) worker.functions[driver_id][function_id] = (actor_method_name, temporary_executor) worker.num_task_executions[driver_id][function_id] = 0 try: unpickled_class = pickle.loads(pickled_class) worker.actor_class = unpickled_class except Exception: # If an exception was thrown when the actor was imported, we record the # traceback and notify the scheduler of the failure. traceback_str = ray.worker.format_error_message(traceback.format_exc()) # Log the error message. push_error_to_driver(worker.redis_client, "register_actor_signatures", traceback_str, driver_id, data={"actor_id": actor_id_str}) # TODO(rkn): In the future, it might make sense to have the worker exit # here. However, currently that would lead to hanging if someone calls # ray.get on a method invoked on the actor. else: # TODO(pcm): Why is the below line necessary? unpickled_class.__module__ = module worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) actor_methods = inspect.getmembers( unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or inspect.ismethod(x) or is_cython(x)))) for actor_method_name, actor_method in actor_methods: function_id = compute_actor_method_function_id( class_name, actor_method_name).id() executor = make_actor_method_executor(worker, actor_method_name, actor_method) worker.functions[driver_id][function_id] = (actor_method_name, executor) # We do not set worker.function_properties[driver_id][function_id] # because we currently do need the actor worker to submit new tasks # for the actor. # Store some extra information that will be used when the actor exits # to release GPU resources. worker.driver_id = binary_to_hex(driver_id) local_scheduler_id = worker.redis_client.hget(b"Actor:" + actor_id_str, "local_scheduler_id") worker.local_scheduler_id = binary_to_hex(local_scheduler_id)
def generate_id(cls): return binary_to_hex(_random_string())[:8]
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id_binary: A string of bytes with the task ID to get information about. Returns: A dictionary with information about the task ID in question. TASK_STATUS_MAPPING should be used to parse the "State" field into a human-readable string. """ task_table_response = self._execute_command(task_id, "RAY.TASK_TABLE_GET", task_id.id()) if task_table_response is None: raise Exception("There is no entry for task ID {} in the task " "table.".format(binary_to_hex(task_id.id()))) task_table_message = TaskReply.GetRootAsTaskReply( task_table_response, 0) task_spec = task_table_message.TaskSpec() task_spec = ray.local_scheduler.task_from_string(task_spec) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), "ParentCounter": task_spec.parent_counter(), "ActorID": binary_to_hex(task_spec.actor_id().id()), "ActorCreationID": binary_to_hex(task_spec.actor_creation_id().id()), "ActorCreationDummyObjectID": binary_to_hex(task_spec.actor_creation_dummy_object_id().id()), "ActorCounter": task_spec.actor_counter(), "FunctionID": binary_to_hex(task_spec.function_id().id()), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), "RequiredResources": task_spec.required_resources() } execution_dependencies_message = ( TaskExecutionDependencies.GetRootAsTaskExecutionDependencies( task_table_message.ExecutionDependencies(), 0)) execution_dependencies = [ ray.ObjectID( execution_dependencies_message.ExecutionDependencies(i)) for i in range( execution_dependencies_message.ExecutionDependenciesLength()) ] # TODO(rkn): The return fields ExecutionDependenciesString and # ExecutionDependencies are redundant, so we should remove # ExecutionDependencies. However, it is currently used in monitor.py. return { "State": task_table_message.State(), "LocalSchedulerID": binary_to_hex(task_table_message.LocalSchedulerId()), "ExecutionDependenciesString": task_table_message.ExecutionDependencies(), "ExecutionDependencies": execution_dependencies, "SpillbackCount": task_table_message.SpillbackCount(), "TaskSpec": task_spec_info }
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id_binary: A string of bytes with the task ID to get information about. Returns: A dictionary with information about the task ID in question. TASK_STATUS_MAPPING should be used to parse the "State" field into a human-readable string. """ task_table_response = self._execute_command(task_id, "RAY.TASK_TABLE_GET", task_id.id()) if task_table_response is None: raise Exception("There is no entry for task ID {} in the task " "table.".format(binary_to_hex(task_id.id()))) task_table_message = TaskReply.GetRootAsTaskReply( task_table_response, 0) task_spec = task_table_message.TaskSpec() task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0) args = [] for i in range(task_spec_message.ArgsLength()): arg = task_spec_message.Args(i) if len(arg.ObjectId()) != 0: args.append(binary_to_object_id(arg.ObjectId())) else: args.append(pickle.loads(arg.Data())) # TODO(atumanov): Instead of hard coding these indices, we should use # the flatbuffer constants. assert task_spec_message.RequiredResourcesLength() == 3 required_resources = { "CPUs": task_spec_message.RequiredResources(0), "GPUs": task_spec_message.RequiredResources(1), "CustomResource": task_spec_message.RequiredResources(2) } task_spec_info = { "DriverID": binary_to_hex(task_spec_message.DriverId()), "TaskID": binary_to_hex(task_spec_message.TaskId()), "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()), "ParentCounter": task_spec_message.ParentCounter(), "ActorID": binary_to_hex(task_spec_message.ActorId()), "ActorCounter": task_spec_message.ActorCounter(), "FunctionID": binary_to_hex(task_spec_message.FunctionId()), "Args": args, "ReturnObjectIDs": [ binary_to_object_id(task_spec_message.Returns(i)) for i in range(task_spec_message.ReturnsLength()) ], "RequiredResources": required_resources } return { "State": task_table_message.State(), "LocalSchedulerID": binary_to_hex(task_table_message.LocalSchedulerId()), "TaskSpec": task_spec_info }
def driver_removed_handler(self, channel, data):
    """Handle a notification that a driver has been removed.

    This releases any GPU resources that were reserved for that driver in
    Redis.
    """
    message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
    driver_id = message.DriverId()
    log.info("Driver {} has been removed."
             .format(binary_to_hex(driver_id)))

    # Get a list of the local schedulers.
    client_table = ray.global_state.client_table()
    local_schedulers = []
    for ip_address, clients in client_table.items():
        for client in clients:
            if client["ClientType"] == "local_scheduler":
                local_schedulers.append(client)

    # Release any GPU resources that have been reserved for this driver in
    # Redis.
    for local_scheduler in local_schedulers:
        if int(local_scheduler["NumGPUs"]) > 0:
            local_scheduler_id = local_scheduler["DBClientID"]

            num_gpus_returned = 0

            # Perform a transaction to return the GPUs.
            with self.redis.pipeline() as pipe:
                while True:
                    try:
                        # If this key is changed before the transaction
                        # below (the multi/exec block), then the
                        # transaction will not take place.
                        pipe.watch(local_scheduler_id)

                        result = pipe.hget(local_scheduler_id,
                                           "gpus_in_use")
                        gpus_in_use = (dict() if result is None else
                                       json.loads(result))

                        driver_id_hex = binary_to_hex(driver_id)
                        if driver_id_hex in gpus_in_use:
                            num_gpus_returned = gpus_in_use.pop(
                                driver_id_hex)

                        pipe.multi()
                        pipe.hset(local_scheduler_id, "gpus_in_use",
                                  json.dumps(gpus_in_use))
                        pipe.execute()
                        # If a WatchError is not raised, then the
                        # operations should have gone through atomically.
                        break
                    except redis.WatchError:
                        # Another client must have changed the watched key
                        # between the time we started WATCHing it and the
                        # pipeline's execution. We should just retry.
                        continue

            log.info("Driver {} is returning GPU IDs {} to local "
                     "scheduler {}.".format(binary_to_hex(driver_id),
                                            num_gpus_returned,
                                            local_scheduler_id))
def generate_id(cls): return binary_to_hex(random_string())[:8]
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id: A task ID to get information about. Returns: A dictionary with information about the task ID in question. """ assert isinstance(task_id, ray.TaskID) message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "", task_id.binary()) if message is None: return {} gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( message, 0) assert gcs_entries.EntriesLength() == 1 task_table_message = ray.gcs_utils.Task.GetRootAsTask( gcs_entries.Entries(0), 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() task = ray._raylet.Task.from_string(task_spec) function_descriptor_list = task.function_descriptor_list() function_descriptor = FunctionDescriptor.from_bytes_list( function_descriptor_list) task_spec_info = { "DriverID": task.driver_id().hex(), "TaskID": task.task_id().hex(), "ParentTaskID": task.parent_task_id().hex(), "ParentCounter": task.parent_counter(), "ActorID": (task.actor_id().hex()), "ActorCreationID": task.actor_creation_id().hex(), "ActorCreationDummyObjectID": (task.actor_creation_dummy_object_id().hex()), "ActorCounter": task.actor_counter(), "Args": task.arguments(), "ReturnObjectIDs": task.returns(), "RequiredResources": task.required_resources(), "FunctionID": function_descriptor.function_id.hex(), "FunctionHash": binary_to_hex(function_descriptor.function_hash), "ModuleName": function_descriptor.module_name, "ClassName": function_descriptor.class_name, "FunctionName": function_descriptor.function_name, } return { "ExecutionSpec": { "Dependencies": [ execution_spec.Dependencies(i) for i in range(execution_spec.DependenciesLength()) ], "LastTimestamp": execution_spec.LastTimestamp(), "NumForwards": execution_spec.NumForwards() }, "TaskSpec": task_spec_info }
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
    """Attempt to acquire GPUs on a particular local scheduler for an actor.

    Args:
        num_gpus: The number of GPUs to acquire.
        driver_id: The ID of the driver responsible for creating the actor.
        local_scheduler: Information about the local scheduler.

    Returns:
        True if the GPUs were successfully reserved and false otherwise.
    """
    assert num_gpus != 0

    local_scheduler_id = local_scheduler["DBClientID"]
    local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])

    success = False

    # Attempt to acquire GPU IDs atomically.
    with worker.redis_client.pipeline() as pipe:
        while True:
            try:
                # If this key is changed before the transaction below (the
                # multi/exec block), then the transaction will not take
                # place.
                pipe.watch(local_scheduler_id)

                # Figure out which GPUs are currently in use.
                result = worker.redis_client.hget(local_scheduler_id,
                                                  "gpus_in_use")
                gpus_in_use = dict() if result is None else json.loads(
                    result.decode("ascii"))
                num_gpus_in_use = 0
                for key in gpus_in_use:
                    num_gpus_in_use += gpus_in_use[key]
                assert num_gpus_in_use <= local_scheduler_total_gpus

                pipe.multi()

                if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
                    # There are enough available GPUs, so try to reserve
                    # some. We use the driver ID in hex as a dictionary key
                    # so that the dictionary is JSON serializable.
                    driver_id_hex = binary_to_hex(driver_id)
                    if driver_id_hex not in gpus_in_use:
                        gpus_in_use[driver_id_hex] = 0
                    gpus_in_use[driver_id_hex] += num_gpus

                    # Stick the updated GPU IDs back in Redis.
                    pipe.hset(local_scheduler_id, "gpus_in_use",
                              json.dumps(gpus_in_use))
                    success = True

                pipe.execute()
                # If a WatchError is not raised, then the operations should
                # have gone through atomically.
                break
            except redis.WatchError:
                # Another client must have changed the watched key between
                # the time we started WATCHing it and the pipeline's
                # execution. We should just retry.
                success = False
                continue

    return success
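# Standalone sketch of the optimistic-locking pattern used above, reduced to
# its essentials with redis-py. The key name and counter semantics here are
# illustrative, not part of the original code.
import json

import redis


def reserve_units(client, key, field, amount, capacity):
    """Atomically add `amount` to a JSON counter at (key, field) if it fits."""
    with client.pipeline() as pipe:
        while True:
            try:
                # WATCH the key so the transaction aborts if another client
                # writes it between the read below and EXEC.
                pipe.watch(key)
                raw = pipe.hget(key, field)
                in_use = {} if raw is None else json.loads(raw)
                total = sum(in_use.values())
                if capacity - total < amount:
                    pipe.unwatch()
                    return False
                in_use["me"] = in_use.get("me", 0) + amount
                pipe.multi()
                pipe.hset(key, field, json.dumps(in_use))
                pipe.execute()
                return True
            except redis.WatchError:
                # Someone else modified the key; retry the whole
                # read/modify/write cycle.
                continue


# Example (assumes a local Redis server):
# client = redis.StrictRedis()
# reserve_units(client, "scheduler:abc", "gpus_in_use", 2, capacity=8)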
def client_table(self): """Fetch and parse the Redis DB client table. Returns: Information about the Ray clients in the cluster. """ self._check_connected() if not self.use_raylet: db_client_keys = self.redis_client.keys( ray.gcs_utils.DB_CLIENT_PREFIX + "*") node_info = {} for key in db_client_keys: client_info = self.redis_client.hgetall(key) node_ip_address = decode(client_info[b"node_ip_address"]) if node_ip_address not in node_info: node_info[node_ip_address] = [] client_info_parsed = {} assert b"client_type" in client_info assert b"deleted" in client_info assert b"ray_client_id" in client_info for field, value in client_info.items(): if field == b"node_ip_address": pass elif field == b"client_type": client_info_parsed["ClientType"] = decode(value) elif field == b"deleted": client_info_parsed["Deleted"] = bool( int(decode(value))) elif field == b"ray_client_id": client_info_parsed["DBClientID"] = binary_to_hex(value) elif field == b"manager_address": client_info_parsed["AuxAddress"] = decode(value) elif field == b"local_scheduler_socket_name": client_info_parsed["LocalSchedulerSocketName"] = ( decode(value)) elif client_info[b"client_type"] == b"local_scheduler": # The remaining fields are resource types. client_info_parsed[field.decode("ascii")] = float( decode(value)) else: client_info_parsed[field.decode("ascii")] = decode( value) node_info[node_ip_address].append(client_info_parsed) return node_info else: # This is the raylet code path. NIL_CLIENT_ID = 20 * b"\xff" message = self.redis_client.execute_command( "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", NIL_CLIENT_ID) node_info = [] gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( message, 0) for i in range(gcs_entry.EntriesLength()): client = ( ray.gcs_utils.ClientTableData.GetRootAsClientTableData( gcs_entry.Entries(i), 0)) resources = { client.ResourcesTotalLabel(i).decode("ascii"): client.ResourcesTotalCapacity(i) for i in range(client.ResourcesTotalLabelLength()) } node_info.append({ "ClientID": ray.utils.binary_to_hex(client.ClientId()), "IsInsertion": client.IsInsertion(), "NodeManagerAddress": client.NodeManagerAddress().decode( "ascii"), "NodeManagerPort": client.NodeManagerPort(), "ObjectManagerPort": client.ObjectManagerPort(), "ObjectStoreSocketName": client.ObjectStoreSocketName() .decode("ascii"), "RayletSocketName": client.RayletSocketName().decode( "ascii"), "Resources": resources }) return node_info
def __repr__(self): return ("FunctionDescriptor:" + self._module_name + "." + self._class_name + "." + self._function_name + "." + binary_to_hex(self._function_source_hash))
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker): """Attempt to acquire GPUs on a particular local scheduler for an actor. Args: num_gpus: The number of GPUs to acquire. driver_id: The ID of the driver responsible for creating the actor. local_scheduler: Information about the local scheduler. Returns: A list of the GPU IDs that were successfully acquired. This should have length either equal to num_gpus or equal to 0. """ local_scheduler_id = local_scheduler["DBClientID"] local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) gpus_to_acquire = [] # Attempt to acquire GPU IDs atomically. with worker.redis_client.pipeline() as pipe: while True: try: # If this key is changed before the transaction below (the multi/exec # block), then the transaction will not take place. pipe.watch(local_scheduler_id) # Figure out which GPUs are currently in use. result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use") gpus_in_use = dict() if result is None else json.loads(result) all_gpu_ids_in_use = [] for key in gpus_in_use: all_gpu_ids_in_use += gpus_in_use[key] assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use) pipe.multi() if local_scheduler_total_gpus - len(all_gpu_ids_in_use) >= num_gpus: # There are enough available GPUs, so try to reserve some. all_gpu_ids = set(range(local_scheduler_total_gpus)) for gpu_id in all_gpu_ids_in_use: all_gpu_ids.remove(gpu_id) gpus_to_acquire = list(all_gpu_ids)[:num_gpus] # Use the hex driver ID so that the dictionary is JSON serializable. driver_id_hex = binary_to_hex(driver_id) if driver_id_hex not in gpus_in_use: gpus_in_use[driver_id_hex] = [] gpus_in_use[driver_id_hex] += gpus_to_acquire # Stick the updated GPU IDs back in Redis pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use)) pipe.execute() # If a WatchError is not raised, then the operations should have gone # through atomically. break except redis.WatchError: # Another client must have changed the watched key between the time we # started WATCHing it and the pipeline's execution. We should just # retry. gpus_to_acquire = [] continue return gpus_to_acquire
def _xray_clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from redis. Removes control-state entries of all tasks and task return objects belonging to the driver. Args: driver_id: The driver id. """ xray_task_table_prefix = ( ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii")) xray_object_table_prefix = ( ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) task_table_objects = self.state.task_table() driver_id_hex = binary_to_hex(driver_id) driver_task_id_bins = set() for task_id_hex in task_table_objects: if len(task_table_objects[task_id_hex]) == 0: continue task_table_object = task_table_objects[task_id_hex][0]["TaskSpec"] task_driver_id_hex = task_table_object["DriverID"] if driver_id_hex != task_driver_id_hex: # Ignore tasks that aren't from this driver. continue driver_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. object_table_objects = self.state.object_table() driver_object_id_bins = set() for object_id, object_table_object in object_table_objects.items(): assert len(object_table_object) > 0 task_id_bin = ray.local_scheduler.compute_task_id(object_id).id() if task_id_bin in driver_task_id_bins: driver_object_id_bins.add(object_id.id()) def to_shard_index(id_bin): return binary_to_object_id(id_bin).redis_shard_hash() % len( self.state.redis_clients) # Form the redis keys to delete. sharded_keys = [[] for _ in range(len(self.state.redis_clients))] for task_id_bin in driver_task_id_bins: sharded_keys[to_shard_index(task_id_bin)].append( xray_task_table_prefix + task_id_bin) for object_id_bin in driver_object_id_bins: sharded_keys[to_shard_index(object_id_bin)].append( xray_object_table_prefix + object_id_bin) # Remove with best effort. for shard_index in range(len(sharded_keys)): keys = sharded_keys[shard_index] if len(keys) == 0: continue redis = self.state.redis_clients[shard_index] num_deleted = redis.delete(*keys) log.info("Removed {} dead redis entries of the driver" " from redis shard {}.".format(num_deleted, shard_index)) if num_deleted != len(keys): log.warning("Failed to remove {} relevant redis entries" " from redis shard {}.".format( len(keys) - num_deleted, shard_index))
def _xray_clean_up_entries_for_job(self, job_id): """Remove this job's object/task entries from redis. Removes control-state entries of all tasks and task return objects belonging to the driver. Args: job_id: The job id. """ xray_task_table_prefix = ( ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii")) xray_object_table_prefix = ( ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) task_table_objects = ray.tasks() job_id_hex = binary_to_hex(job_id) job_task_id_bins = set() for task_id_hex, task_info in task_table_objects.items(): task_table_object = task_info["TaskSpec"] task_job_id_hex = task_table_object["JobID"] if job_id_hex != task_job_id_hex: # Ignore tasks that aren't from this driver. continue job_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. object_table_objects = ray.objects() job_object_id_bins = set() for object_id, _ in object_table_objects.items(): task_id_bin = ray._raylet.compute_task_id(object_id).binary() if task_id_bin in job_task_id_bins: job_object_id_bins.add(object_id.binary()) def to_shard_index(id_bin): if len(id_bin) == ray.TaskID.size(): return binary_to_task_id(id_bin).redis_shard_hash() % len( ray.state.state.redis_clients) else: return binary_to_object_id(id_bin).redis_shard_hash() % len( ray.state.state.redis_clients) # Form the redis keys to delete. sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))] for task_id_bin in job_task_id_bins: sharded_keys[to_shard_index(task_id_bin)].append( xray_task_table_prefix + task_id_bin) for object_id_bin in job_object_id_bins: sharded_keys[to_shard_index(object_id_bin)].append( xray_object_table_prefix + object_id_bin) # Remove with best effort. for shard_index in range(len(sharded_keys)): keys = sharded_keys[shard_index] if len(keys) == 0: continue redis = ray.state.state.redis_clients[shard_index] num_deleted = redis.delete(*keys) logger.info("Monitor: " "Removed {} dead redis entries of the " "driver from redis shard {}.".format( num_deleted, shard_index)) if num_deleted != len(keys): logger.warning("Monitor: " "Failed to remove {} relevant redis " "entries from redis shard {}.".format( len(keys) - num_deleted, shard_index))
def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. Args: task_id: A task ID to get information about. Returns: A dictionary with information about the task ID in question. """ assert isinstance(task_id, ray.TaskID) message = self._execute_command( task_id, "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} gcs_entries = gcs_utils.GcsEntry.FromString(message) assert len(gcs_entries.entries) == 1 task_table_data = gcs_utils.TaskTableData.FromString( gcs_entries.entries[0]) task = ray._raylet.TaskSpec.from_string( task_table_data.task.task_spec.SerializeToString()) function_descriptor_list = task.function_descriptor_list() function_descriptor = FunctionDescriptor.from_bytes_list( function_descriptor_list) task_spec_info = { "JobID": task.job_id().hex(), "TaskID": task.task_id().hex(), "ParentTaskID": task.parent_task_id().hex(), "ParentCounter": task.parent_counter(), "ActorID": (task.actor_id().hex()), "ActorCreationID": task.actor_creation_id().hex(), "ActorCreationDummyObjectID": (task.actor_creation_dummy_object_id().hex()), "PreviousActorTaskDummyObjectID": (task.previous_actor_task_dummy_object_id().hex()), "ActorCounter": task.actor_counter(), "Args": task.arguments(), "ReturnObjectIDs": task.returns(), "RequiredResources": task.required_resources(), "FunctionID": function_descriptor.function_id.hex(), "FunctionHash": binary_to_hex(function_descriptor.function_hash), "ModuleName": function_descriptor.module_name, "ClassName": function_descriptor.class_name, "FunctionName": function_descriptor.function_name, } execution_spec = ray._raylet.TaskExecutionSpec.from_string( task_table_data.task.task_execution_spec.SerializeToString()) return { "ExecutionSpec": { "NumForwards": execution_spec.num_forwards(), }, "TaskSpec": task_spec_info }
def default(self, obj):
    if isinstance(obj, bytes):
        return binary_to_hex(obj)

    # Let the base class default method raise the TypeError.
    return json.JSONEncoder.default(self, obj)
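# Usage sketch: plug the method above into a json.JSONEncoder subclass so any
# bytes value (for example a binary ID) is serialized as its hex string. The
# class name is illustrative; binary_to_hex is assumed to be the same
# ray.utils helper used throughout these snippets.
import json

from ray.utils import binary_to_hex


class HexBytesEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, bytes):
            return binary_to_hex(obj)
        return json.JSONEncoder.default(self, obj)


print(json.dumps({"object_id": b"\x00\xff"}, cls=HexBytesEncoder))
# -> {"object_id": "00ff"}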