def log_files(self):
    """Collect the contents of every log file stored in Redis.

    Every key matching ``LOGFILE*`` is read as a Redis list; the second
    ``:``-separated component of the key name is used as the node IP.

    Returns:
        A dict mapping IP address -> {log key name -> list of decoded
        log lines}.
    """
    result = {}
    for raw_key in self.redis_client.keys("LOGFILE*"):
        key_name = decode(raw_key)
        node_ip = key_name.split(":")[1]
        lines = [
            decode(entry)
            for entry in self.redis_client.lrange(key_name, 0, -1)
        ]
        result.setdefault(node_ip, {})[key_name] = lines
    return result
def _error_messages(self, job_id):
    """Look up the error messages recorded in the GCS for one job.

    Args:
        job_id: The ``ray.DriverID`` of the job whose errors are wanted.

    Returns:
        A list of dicts with keys ``"type"``, ``"message"`` and
        ``"timestamp"``, one per error entry; empty when none exist.
    """
    assert isinstance(job_id, ray.DriverID)
    reply = self.redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
        job_id.binary())

    if reply is None:
        # Nothing has been recorded for this job.
        return []

    entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)

    def to_error_dict(raw_entry):
        error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
            raw_entry, 0)
        assert job_id.binary() == error_data.JobId()
        return {
            "type": decode(error_data.Type()),
            "message": decode(error_data.ErrorMessage()),
            "timestamp": error_data.Timestamp(),
        }

    return [
        to_error_dict(entries.Entries(index))
        for index in range(entries.EntriesLength())
    ]
def parse_client_table(redis_client):
    """Read the client table from the GCS.

    Args:
        redis_client: A client connected to the primary Redis shard.

    Returns:
        A list with one dict per client ID describing the most recent
        table entry seen for that client.
    """
    nil_client_id = ray_constants.ID_SIZE * b"\xff"
    reply = redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
        nil_client_id)

    # An empty table is only expected right after the cluster starts.
    if reply is None:
        return []

    gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
    latest_by_client = {}
    # The log is append-only, so a later entry for a client replaces the
    # earlier one.
    for entry_index in range(gcs_entry.EntriesLength()):
        client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            gcs_entry.Entries(entry_index), 0)

        resources = {}
        for label_index in range(client.ResourcesTotalLabelLength()):
            label = decode(client.ResourcesTotalLabel(label_index))
            resources[label] = client.ResourcesTotalCapacity(label_index)

        client_id = ray.utils.binary_to_hex(client.ClientId())
        is_insertion = client.IsInsertion()
        if not is_insertion:
            # A removal must follow an insertion and may not repeat.
            assert client_id in latest_by_client, "Client removed not found!"
            assert latest_by_client[client_id]["IsInsertion"], (
                "Unexpected duplicate removal of client.")

        latest_by_client[client_id] = {
            "ClientID": client_id,
            "IsInsertion": is_insertion,
            "NodeManagerAddress": decode(
                client.NodeManagerAddress(), allow_none=True),
            "NodeManagerPort": client.NodeManagerPort(),
            "ObjectManagerPort": client.ObjectManagerPort(),
            "ObjectStoreSocketName": decode(
                client.ObjectStoreSocketName(), allow_none=True),
            "RayletSocketName": decode(
                client.RayletSocketName(), allow_none=True),
            "Resources": resources
        }
    return list(latest_by_client.values())
def fetch_and_register_remote_function(self, key):
    """Import a remote function exported to Redis under ``key``.

    A placeholder that raises is registered first, so the function ID has
    an entry even if unpickling fails. On success the placeholder is
    replaced by the real function and this worker is appended to
    ``FunctionTable:<function id>``. On failure the traceback is pushed to
    the exporting driver.

    Args:
        key: The Redis hash key holding the exported function.
    """
    (driver_id_str, function_id_str, function_name, serialized_function,
     _num_return_vals, module, _resources,
     max_calls) = self._worker.redis_client.hmget(key, [
         "driver_id", "function_id", "name", "function", "num_return_vals",
         "module", "resources", "max_calls"
     ])
    function_id = ray.FunctionID(function_id_str)
    driver_id = ray.DriverID(driver_id_str)
    function_name = decode(function_name)
    max_calls = int(max_calls)
    module = decode(module)

    # This is a placeholder in case the function can't be unpickled. This
    # will be overwritten if the function is successfully registered. It
    # accepts any arguments so that a call with task arguments surfaces this
    # message instead of a TypeError about the placeholder's signature.
    def f(*args, **kwargs):
        raise Exception("This function was not imported properly.")

    self._function_execution_info[driver_id][function_id] = (
        FunctionExecutionInfo(
            function=f, function_name=function_name, max_calls=max_calls))
    self._num_task_executions[driver_id][function_id] = 0

    try:
        function = pickle.loads(serialized_function)
    except Exception:
        # If an exception was thrown when the remote function was imported,
        # we record the traceback and notify the scheduler of the failure.
        traceback_str = format_error_message(traceback.format_exc())
        # Log the error message.
        push_error_to_driver(
            self._worker,
            ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
            "Failed to unpickle the remote function '{}' with function ID "
            "{}. Traceback:\n{}".format(function_name, function_id.hex(),
                                        traceback_str),
            driver_id=driver_id)
    else:
        # The below line is necessary. Because in the driver process,
        # if the function is defined in the file where the python script
        # was started from, its module is `__main__`.
        # However in the worker process, the `__main__` module is a
        # different module, which is `default_worker.py`
        function.__module__ = module
        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=function,
                function_name=function_name,
                max_calls=max_calls))
        # Add the function to the function table.
        self._worker.redis_client.rpush(
            b"FunctionTable:" + function_id.binary(), self._worker.worker_id)
def _profile_table(self, batch_id):
    """Return the profile events stored in the GCS for one batch.

    Args:
        batch_id: Identifier of the batch of profile events to fetch.

    Returns:
        A flat list of event dicts (event type, component id/type, node IP,
        start/end time and JSON-decoded extra data). The list is empty when
        the batch has no table entry.
    """
    # TODO(rkn): This method should support limiting the number of log
    # events and should also support returning a window of events.
    reply = self._execute_command(batch_id, "RAY.TABLE_LOOKUP",
                                  ray.gcs_utils.TablePrefix.PROFILE, "",
                                  batch_id.binary())
    if reply is None:
        return []

    gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
        reply, 0)
    events = []
    for entry_index in range(gcs_entries.EntriesLength()):
        table_data = (
            ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
                gcs_entries.Entries(entry_index), 0))

        # These fields are shared by every event in this table entry.
        component_type = decode(table_data.ComponentType())
        component_id = binary_to_hex(table_data.ComponentId())
        node_ip_address = decode(
            table_data.NodeIpAddress(), allow_none=True)

        for event_index in range(table_data.ProfileEventsLength()):
            event = table_data.ProfileEvents(event_index)
            events.append({
                "event_type": decode(event.EventType()),
                "component_id": component_id,
                "node_ip_address": node_ip_address,
                "component_type": component_type,
                "start_time": event.StartTime(),
                "end_time": event.EndTime(),
                "extra_data": json.loads(decode(event.ExtraData())),
            })
    return events
def fetch_and_execute_function_to_run(self, key):
    """Run an arbitrary function, exported to Redis, on this worker.

    The function is skipped when all of the following hold:
    - it was exported with ``run_on_other_drivers`` equal to ``"False"``;
    - this worker is in ``ray.SCRIPT_MODE``;
    - the exporting driver ID differs from this worker's task driver ID.

    Any exception raised while unpickling or running the function is pushed
    to the exporting driver.

    Args:
        key: The Redis hash key holding the exported function.
    """
    driver_id, serialized_function, run_on_other_drivers = (
        self.redis_client.hmget(
            key, ["driver_id", "function", "run_on_other_drivers"]))

    skip = (utils.decode(run_on_other_drivers) == "False"
            and self.worker.mode == ray.SCRIPT_MODE
            and driver_id != self.worker.task_driver_id.binary())
    if skip:
        return

    try:
        function = pickle.loads(serialized_function)
        function({"worker": self.worker})
    except Exception:
        # Report the failure, with its traceback, to the exporting driver.
        utils.push_error_to_driver(
            self.worker,
            ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
            traceback.format_exc(),
            driver_id=ray.DriverID(driver_id))
def function_table(self, function_id=None):
    """Fetch and parse the function table.

    Args:
        function_id: Currently unused; the whole table is always returned.

    Returns:
        A dict mapping hex function ID to a dict with ``"DriverID"``,
        ``"Module"`` and ``"Name"`` entries.
    """
    self._check_connected()
    table = {}
    for redis_key in self.redis_client.keys(FUNCTION_PREFIX + "*"):
        fields = self.redis_client.hgetall(redis_key)
        table[binary_to_hex(fields[b"function_id"])] = {
            "DriverID": binary_to_hex(fields[b"driver_id"]),
            "Module": decode(fields[b"module"]),
            "Name": decode(fields[b"name"]),
        }
    return table
def workers(self):
    """Get a dictionary mapping worker ID to worker information.

    Worker records are read from Redis hashes whose keys start with
    ``Workers:``; the rest of the key is the binary worker ID.

    Returns:
        A dict mapping hex worker ID to a dict with ``node_ip_address`` and
        ``plasma_store_socket``, plus ``stderr_file`` / ``stdout_file`` when
        those fields are present.
    """
    prefix = "Workers:"
    # Match exactly the prefix that is stripped below, so unrelated keys
    # that merely start with "Worker" are not parsed as worker records.
    worker_keys = self.redis_client.keys(prefix + "*")
    workers_data = {}
    for worker_key in worker_keys:
        worker_info = self.redis_client.hgetall(worker_key)
        worker_id = binary_to_hex(worker_key[len(prefix):])
        entry = {
            "node_ip_address": decode(worker_info[b"node_ip_address"]),
            "plasma_store_socket": decode(
                worker_info[b"plasma_store_socket"])
        }
        for optional_field in ("stderr_file", "stdout_file"):
            raw_key = optional_field.encode("ascii")
            if raw_key in worker_info:
                entry[optional_field] = decode(worker_info[raw_key])
        workers_data[worker_id] = entry
    return workers_data
def workers(self):
    """Get a dictionary mapping worker ID to worker information.

    Worker records are read from Redis hashes whose keys start with
    ``Workers:``; the rest of the key is the binary worker ID.

    Returns:
        A dict mapping hex worker ID to a dict with ``node_ip_address`` and
        ``plasma_store_socket``, plus ``stderr_file`` / ``stdout_file`` when
        those fields are present.
    """
    prefix = "Workers:"
    # Match exactly the prefix that is stripped below, so unrelated keys
    # that merely start with "Worker" are not parsed as worker records.
    worker_keys = self.redis_client.keys(prefix + "*")
    workers_data = {}
    for worker_key in worker_keys:
        worker_info = self.redis_client.hgetall(worker_key)
        worker_id = binary_to_hex(worker_key[len(prefix):])
        entry = {
            "node_ip_address": decode(worker_info[b"node_ip_address"]),
            "plasma_store_socket": decode(worker_info[b"plasma_store_socket"])
        }
        for optional_field in ("stderr_file", "stdout_file"):
            raw_key = optional_field.encode("ascii")
            if raw_key in worker_info:
                entry[optional_field] = decode(worker_info[raw_key])
        workers_data[worker_id] = entry
    return workers_data
def actors(self):
    """Return information about every actor recorded in Redis.

    Returns:
        A dict mapping hex actor ID to a dict with ``class_id``,
        ``driver_id`` and ``local_scheduler_id`` (hex strings), ``num_gpus``
        (int) and ``removed`` (True when the stored value is "True").
    """
    prefix_length = len("Actor:")
    result = dict()
    for actor_key in self.redis_client.keys("Actor:*"):
        fields = self.redis_client.hgetall(actor_key)
        raw_actor_id = actor_key[prefix_length:]
        # The key suffix is expected to be a 20-byte binary actor ID.
        assert len(raw_actor_id) == 20
        result[binary_to_hex(raw_actor_id)] = {
            "class_id": binary_to_hex(fields[b"class_id"]),
            "driver_id": binary_to_hex(fields[b"driver_id"]),
            "local_scheduler_id": binary_to_hex(
                fields[b"local_scheduler_id"]),
            "num_gpus": int(fields[b"num_gpus"]),
            "removed": decode(fields[b"removed"]) == "True",
        }
    return result
def client_table(self):
    """Fetch and parse the Redis DB client table.

    Returns:
        A dict mapping node IP address to a list of parsed client records.
        Each record holds ``ClientType``, ``Deleted`` and ``DBClientID``.
        Any other hash field is stored under its ASCII-decoded name: as a
        float when the client type is ``local_scheduler`` (resource
        quantities), otherwise as a decoded string.
    """
    self._check_connected()
    # Hash fields that get a dedicated output key and parser.
    named_fields = {
        b"client_type": ("ClientType", decode),
        b"deleted": ("Deleted", lambda raw: bool(int(decode(raw)))),
        b"ray_client_id": ("DBClientID", binary_to_hex),
        b"manager_address": ("AuxAddress", decode),
        b"local_scheduler_socket_name": ("LocalSchedulerSocketName",
                                         decode),
    }

    nodes = dict()
    for redis_key in self.redis_client.keys(DB_CLIENT_PREFIX + "*"):
        raw = self.redis_client.hgetall(redis_key)
        node_ip_address = decode(raw[b"node_ip_address"])
        clients_on_node = nodes.setdefault(node_ip_address, [])
        assert b"client_type" in raw
        assert b"deleted" in raw
        assert b"ray_client_id" in raw
        is_local_scheduler = raw[b"client_type"] == b"local_scheduler"

        parsed = {}
        for field, value in raw.items():
            if field == b"node_ip_address":
                continue
            if field in named_fields:
                output_key, parse = named_fields[field]
                parsed[output_key] = parse(value)
            elif is_local_scheduler:
                # The remaining fields are resource types.
                parsed[field.decode("ascii")] = float(decode(value))
            else:
                parsed[field.decode("ascii")] = decode(value)
        clients_on_node.append(parsed)
    return nodes
def client_table(self):
    """Fetch and parse the Redis DB client table.

    Returns:
        A dict mapping node IP address to a list of client records.
        - Always present: ``ClientType``, ``Deleted`` and ``DBClientID``.
        - Present only when the corresponding hash field exists:
          ``AuxAddress``, ``NumCPUs``, ``NumGPUs``, ``NumCustomResource``
          and ``LocalSchedulerSocketName``.
    """
    self._check_connected()

    def as_float(raw):
        return float(decode(raw))

    # (hash field, output key, parser) for fields that may be absent.
    optional_fields = (
        (b"manager_address", "AuxAddress", decode),
        (b"num_cpus", "NumCPUs", as_float),
        (b"num_gpus", "NumGPUs", as_float),
        (b"num_custom_resource", "NumCustomResource", as_float),
        (b"local_scheduler_socket_name", "LocalSchedulerSocketName",
         decode),
    )

    node_info = dict()
    for redis_key in self.redis_client.keys(DB_CLIENT_PREFIX + "*"):
        raw = self.redis_client.hgetall(redis_key)
        node_ip_address = decode(raw[b"node_ip_address"])
        clients = node_info.setdefault(node_ip_address, [])
        record = {
            "ClientType": decode(raw[b"client_type"]),
            "Deleted": bool(int(decode(raw[b"deleted"]))),
            "DBClientID": binary_to_hex(raw[b"ray_client_id"])
        }
        for field, output_key, parse in optional_fields:
            if field in raw:
                record[output_key] = parse(raw[field])
        clients.append(record)
    return node_info
def fetch_and_register_remote_function(self, key):
    """Import a remote function exported to Redis under ``key``.

    On success the unpickled function is registered for its job and this
    worker is appended to ``FunctionTable:<function id>``. If unpickling
    fails, a placeholder that raises ``RuntimeError`` is registered instead
    and the traceback is pushed to the job's driver.

    Args:
        key: The Redis hash key holding the exported function.
    """
    requested = [
        "job_id", "function_id", "function_name", "function", "module",
        "max_calls"
    ]
    (raw_job_id, raw_function_id, raw_name, serialized_function,
     raw_module, raw_max_calls) = self._worker.redis_client.hmget(
         key, requested)
    function_id = ray.FunctionID(raw_function_id)
    job_id = ray.JobID(raw_job_id)
    function_name = decode(raw_name)
    max_calls = int(raw_max_calls)
    module = decode(raw_module)

    def register(target):
        self._function_execution_info[job_id][function_id] = (
            FunctionExecutionInfo(function=target,
                                  function_name=function_name,
                                  max_calls=max_calls))

    # Called from the import thread. Hold the lock so that no other thread
    # can observe a partially registered function.
    with self.lock:
        self._num_task_executions[job_id][function_id] = 0
        try:
            function = pickle.loads(serialized_function)
        except Exception:

            def f(*args, **kwargs):
                raise RuntimeError(
                    "This function was not imported properly.")

            # Register a stub that reports the failed import when called.
            register(f)
            traceback_str = format_error_message(traceback.format_exc())
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                "Failed to unpickle the remote function "
                f"'{function_name}' with "
                f"function ID {function_id.hex()}. "
                f"Traceback:\n{traceback_str}",
                job_id=job_id)
        else:
            # A function defined in the driver's launching script has
            # module `__main__`, but `__main__` in a worker is a different
            # module, so restore the exported module name.
            function.__module__ = module
            register(function)
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
def available_resources(self):
    """Get the current available cluster resources.

    This is different from `cluster_resources` in that this will return idle
    (available) resources rather than total resources.

    Note that this information can grow stale as tasks start and finish.
    The call polls the heartbeat channel until a heartbeat has been
    collected from every live client.

    Returns:
        A dictionary mapping resource name to the quantity of that resource
        currently available, summed over all live clients.
    """
    available_resources_by_id = {}

    subscribe_clients = [
        redis_client.pubsub(ignore_subscribe_messages=True)
        for redis_client in self.redis_clients
    ]
    try:
        for subscribe_client in subscribe_clients:
            subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)

        client_ids = self._live_client_ids()

        while set(available_resources_by_id.keys()) != client_ids:
            for subscribe_client in subscribe_clients:
                # Parse client message
                raw_message = subscribe_client.get_message()
                if (raw_message is None or raw_message["channel"] !=
                        ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
                    continue
                data = raw_message["data"]
                gcs_entries = (
                    ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                        data, 0))
                heartbeat_data = gcs_entries.Entries(0)
                message = (ray.gcs_utils.HeartbeatTableData.
                           GetRootAsHeartbeatTableData(heartbeat_data, 0))
                # Calculate available resources for this client
                num_resources = message.ResourcesAvailableLabelLength()
                dynamic_resources = {}
                for i in range(num_resources):
                    resource_id = decode(message.ResourcesAvailableLabel(i))
                    dynamic_resources[resource_id] = (
                        message.ResourcesAvailableCapacity(i))

                # Update available resources for this client
                client_id = ray.utils.binary_to_hex(message.ClientId())
                available_resources_by_id[client_id] = dynamic_resources

            # Update clients in cluster
            client_ids = self._live_client_ids()

            # Remove disconnected clients. Iterate over a snapshot of the
            # keys: deleting from a dict while iterating its live keys view
            # raises RuntimeError in Python 3.
            for client_id in list(available_resources_by_id.keys()):
                if client_id not in client_ids:
                    del available_resources_by_id[client_id]
    finally:
        # Close the pubsub clients to avoid leaking file descriptors, even
        # if polling raised.
        for subscribe_client in subscribe_clients:
            subscribe_client.close()

    # Calculate total available resources
    total_available_resources = defaultdict(int)
    for available_resources in available_resources_by_id.values():
        for resource_id, num_available in available_resources.items():
            total_available_resources[resource_id] += num_available

    return dict(total_available_resources)
def client_table(self):
    """Fetch and parse the Redis DB client table.

    Returns:
        Information about the Ray clients in the cluster.
        - Without the raylet: a dict mapping node IP address to a list of
          parsed client records.
        - With the raylet: a list with one dict per GCS client-table entry,
          removals included. The list is empty when the table has no
          entries yet.
    """
    self._check_connected()
    if not self.use_raylet:
        db_client_keys = self.redis_client.keys(
            ray.gcs_utils.DB_CLIENT_PREFIX + "*")
        node_info = {}
        for key in db_client_keys:
            client_info = self.redis_client.hgetall(key)
            node_ip_address = decode(client_info[b"node_ip_address"])
            if node_ip_address not in node_info:
                node_info[node_ip_address] = []
            client_info_parsed = {}
            assert b"client_type" in client_info
            assert b"deleted" in client_info
            assert b"ray_client_id" in client_info
            for field, value in client_info.items():
                if field == b"node_ip_address":
                    pass
                elif field == b"client_type":
                    client_info_parsed["ClientType"] = decode(value)
                elif field == b"deleted":
                    client_info_parsed["Deleted"] = bool(int(decode(value)))
                elif field == b"ray_client_id":
                    client_info_parsed["DBClientID"] = binary_to_hex(value)
                elif field == b"manager_address":
                    client_info_parsed["AuxAddress"] = decode(value)
                elif field == b"local_scheduler_socket_name":
                    client_info_parsed["LocalSchedulerSocketName"] = (
                        decode(value))
                elif client_info[b"client_type"] == b"local_scheduler":
                    # The remaining fields are resource types.
                    client_info_parsed[field.decode("ascii")] = float(
                        decode(value))
                else:
                    client_info_parsed[field.decode("ascii")] = decode(
                        value)
            node_info[node_ip_address].append(client_info_parsed)
        return node_info

    # This is the raylet code path.
    NIL_CLIENT_ID = 20 * b"\xff"
    message = self.redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
        NIL_CLIENT_ID)
    # The lookup returns None when no clients are registered yet (e.g.
    # immediately after startup); the flatbuffer parser cannot read None.
    if message is None:
        return []

    def decode_optional(raw):
        # Unset flatbuffer string fields are returned as None; keep them as
        # None instead of failing on ``None.decode``.
        return None if raw is None else raw.decode("ascii")

    node_info = []
    gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
        message, 0)
    for i in range(gcs_entry.EntriesLength()):
        client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            gcs_entry.Entries(i), 0)
        resources = {
            client.ResourcesTotalLabel(j).decode("ascii"):
            client.ResourcesTotalCapacity(j)
            for j in range(client.ResourcesTotalLabelLength())
        }
        node_info.append({
            "ClientID": ray.utils.binary_to_hex(client.ClientId()),
            "IsInsertion": client.IsInsertion(),
            "NodeManagerAddress": decode_optional(
                client.NodeManagerAddress()),
            "NodeManagerPort": client.NodeManagerPort(),
            "ObjectManagerPort": client.ObjectManagerPort(),
            "ObjectStoreSocketName": decode_optional(
                client.ObjectStoreSocketName()),
            "RayletSocketName": decode_optional(client.RayletSocketName()),
            "Resources": resources
        })
    return node_info
def parse_client_table(redis_client):
    """Read the client table.

    Args:
        redis_client: A client to the primary Redis shard.

    Returns:
        A list with one dict per client, holding that client's most recent
        GCS entry. The list follows the order in which the clients'
        insertion entries appear in the log.
    """
    nil_client_id = ray.ObjectID.nil().binary()
    reply = redis_client.execute_command("RAY.TABLE_LOOKUP",
                                         ray.gcs_utils.TablePrefix.CLIENT,
                                         "", nil_client_id)
    # No reply is only expected right after the cluster is started.
    if reply is None:
        return []

    gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
    latest_by_id = {}
    join_order = []
    # The log is append-only; later entries for a client overwrite earlier.
    for entry_index in range(gcs_entry.EntriesLength()):
        client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            gcs_entry.Entries(entry_index), 0)
        resources = {
            decode(client.ResourcesTotalLabel(j)):
            client.ResourcesTotalCapacity(j)
            for j in range(client.ResourcesTotalLabelLength())
        }
        client_id = ray.utils.binary_to_hex(client.ClientId())
        inserted = client.IsInsertion()

        if inserted:
            join_order.append(client_id)
        else:
            # A removal must follow an insertion and may not repeat.
            assert client_id in latest_by_id, "Client removed not found!"
            assert latest_by_id[client_id]["IsInsertion"], (
                "Unexpected duplicate removal of client.")

        latest_by_id[client_id] = {
            "ClientID": client_id,
            "IsInsertion": inserted,
            "NodeManagerAddress": decode(client.NodeManagerAddress(),
                                         allow_none=True),
            "NodeManagerPort": client.NodeManagerPort(),
            "ObjectManagerPort": client.ObjectManagerPort(),
            "ObjectStoreSocketName": decode(client.ObjectStoreSocketName(),
                                            allow_none=True),
            "RayletSocketName": decode(client.RayletSocketName(),
                                       allow_none=True),
            "Resources": resources
        }

    # Dead clients appear twice in the log (insertion, then removal). So the
    # output order comes from the insertion entries, not from the dict;
    # plain dicts only guarantee insertion order from Python 3.7 on.
    return [latest_by_id[client_id] for client_id in join_order]
def fetch_and_register_actor(self, actor_class_key):
    """Import an actor class and register executors for its methods.

    This is called by the worker's import thread when the worker receives
    the actor_class export, assuming that the worker is an actor for that
    class.
    - First, executors that raise are registered for every exported method
      name, so that a class which fails to unpickle still produces error
      messages.
    - If unpickling succeeds, those executors are replaced by executors for
      the real class methods.
    - If it fails, the traceback is pushed to the driver.

    Args:
        actor_class_key: The key in Redis to use to fetch the actor.
    """
    worker = self._worker
    actor_id = worker.actor_id
    (driver_id_str, class_name, module, pickled_class,
     actor_method_names) = worker.redis_client.hmget(
         actor_class_key,
         ["driver_id", "class_name", "module", "class", "actor_method_names"])

    class_name = decode(class_name)
    module = decode(module)
    driver_id = ray.DriverID(driver_id_str)
    actor_method_names = json.loads(decode(actor_method_names))
    if sys.version_info < (3, 0):
        # Python 2's json yields unicode; convert names back to str.
        actor_method_names = [
            name.encode("ascii") for name in actor_method_names
        ]

    def register(method_name, method, imported):
        # Install an executor for ``method_name`` and return its function ID.
        function_id = FunctionDescriptor(module, method_name,
                                         class_name).function_id
        executor = self._make_actor_method_executor(
            method_name, method, actor_imported=imported)
        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(function=executor,
                                  function_name=method_name,
                                  max_calls=0))
        return function_id

    # Stand-in actor, used only if the real class cannot be unpickled.
    class TemporaryActor(object):
        pass

    worker.actors[actor_id] = TemporaryActor()

    def temporary_actor_method(*xs):
        raise Exception(
            "The actor with name {} failed to be imported, "
            "and so cannot execute this method".format(class_name))

    for method_name in actor_method_names:
        function_id = register(method_name, temporary_actor_method, False)
        self._num_task_executions[driver_id][function_id] = 0

    try:
        unpickled_class = pickle.loads(pickled_class)
        worker.actor_class = unpickled_class
    except Exception:
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        push_error_to_driver(
            worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
            "Failed to unpickle actor class '{}' for actor ID {}. "
            "Traceback:\n{}".format(class_name, actor_id.hex(),
                                    traceback_str), driver_id)
        # TODO(rkn): In the future, it might make sense to have the worker
        # exit here. However, currently that would lead to hanging if
        # someone calls ray.get on a method invoked on the actor.
        return

    # TODO(pcm): Why is the below line necessary?
    unpickled_class.__module__ = module
    worker.actors[actor_id] = unpickled_class.__new__(unpickled_class)
    for method_name, method in inspect.getmembers(
            unpickled_class, predicate=is_function_or_method):
        register(method_name, method, True)
def client_table(self):
    """Fetch and parse the GCS client table.

    Returns:
        A list with one dict per client ID, holding that client's most
        recent table entry. The list is empty when the table has no
        entries yet.
    """
    self._check_connected()

    nil_client_id = ray_constants.ID_SIZE * b"\xff"
    reply = self.redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
        nil_client_id)

    # No reply is only expected right after the cluster starts.
    if reply is None:
        return []

    table = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
    latest = {}
    # Entries are append-only, so a later entry for a client replaces the
    # earlier one.
    for index in range(table.EntriesLength()):
        client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            table.Entries(index), 0)
        resources = dict(
            (decode(client.ResourcesTotalLabel(j)),
             client.ResourcesTotalCapacity(j))
            for j in range(client.ResourcesTotalLabelLength()))
        client_id = ray.utils.binary_to_hex(client.ClientId())

        if not client.IsInsertion():
            # A removal must follow an insertion and may not repeat.
            assert client_id in latest, "Client removed not found!"
            assert latest[client_id]["IsInsertion"], (
                "Unexpected duplicate removal of client.")

        latest[client_id] = {
            "ClientID": client_id,
            "IsInsertion": client.IsInsertion(),
            "NodeManagerAddress": decode(client.NodeManagerAddress(),
                                         allow_none=True),
            "NodeManagerPort": client.NodeManagerPort(),
            "ObjectManagerPort": client.ObjectManagerPort(),
            "ObjectStoreSocketName": decode(client.ObjectStoreSocketName(),
                                            allow_none=True),
            "RayletSocketName": decode(client.RayletSocketName(),
                                       allow_none=True),
            "Resources": resources
        }
    return list(latest.values())
def fetch_and_register_actor(actor_class_key, worker):
    """Import an actor class and register executors for its methods.

    This is called by the worker's import thread when the worker receives
    the actor_class export, assuming that the worker is an actor for that
    class.
    - First, executors that raise are installed for every exported method
      name, so that a class which fails to unpickle still produces error
      messages.
    - If unpickling succeeds, they are replaced by executors for the real
      methods.
    - If it fails, the traceback is pushed to the driver.

    Args:
        actor_class_key: The key in Redis to use to fetch the actor.
        worker: The worker to use.
    """
    actor_id_str = worker.actor_id
    (driver_id, _class_id, class_name, module, pickled_class,
     checkpoint_interval, actor_method_names) = worker.redis_client.hmget(
         actor_class_key, [
             "driver_id", "class_id", "class_name", "module", "class",
             "checkpoint_interval", "actor_method_names"
         ])

    class_name = decode(class_name)
    module = decode(module)
    checkpoint_interval = int(checkpoint_interval)
    actor_method_names = json.loads(decode(actor_method_names))

    def install_executor(method_name, method, imported):
        # Register an executor for ``method_name``; return its function ID.
        function_id = compute_actor_method_function_id(
            class_name, method_name).id()
        worker.function_execution_info[driver_id][function_id] = (
            ray.worker.FunctionExecutionInfo(
                function=make_actor_method_executor(
                    worker, method_name, method, actor_imported=imported),
                function_name=method_name,
                max_calls=0))
        return function_id

    # Stand-in actor, used only if the real class cannot be unpickled.
    class TemporaryActor(object):
        pass

    worker.actors[actor_id_str] = TemporaryActor()
    worker.actor_checkpoint_interval = checkpoint_interval

    def temporary_actor_method(*xs):
        raise Exception("The actor with name {} failed to be imported, and so "
                        "cannot execute this method".format(class_name))

    for method_name in actor_method_names:
        function_id = install_executor(method_name, temporary_actor_method,
                                       False)
        worker.num_task_executions[driver_id][function_id] = 0

    try:
        unpickled_class = pickle.loads(pickled_class)
        worker.actor_class = unpickled_class
    except Exception:
        traceback_str = ray.utils.format_error_message(traceback.format_exc())
        push_error_to_driver(worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                             traceback_str, driver_id,
                             data={"actor_id": actor_id_str})
        # TODO(rkn): In the future, it might make sense to have the worker exit
        # here. However, currently that would lead to hanging if someone calls
        # ray.get on a method invoked on the actor.
        return

    # TODO(pcm): Why is the below line necessary?
    unpickled_class.__module__ = module
    worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)

    def is_actor_method(member):
        return (inspect.isfunction(member) or inspect.ismethod(member)
                or is_cython(member))

    for method_name, method in inspect.getmembers(unpickled_class,
                                                  predicate=is_actor_method):
        install_executor(method_name, method, True)
def parse_client_table(redis_client):
    """Read the client table.

    Args:
        redis_client: A client to the primary Redis shard.

    Returns:
        A list with one dict per client, holding that client's latest GCS
        entry. The list is ordered by the position of each insertion entry
        in the log.
    """
    reply = redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
        ray.ObjectID.nil().binary())
    # No reply is only expected right after the cluster is started.
    if reply is None:
        return []

    def build_record(client, client_id, resources):
        # Flatten one ClientTableData entry into a plain dict.
        return {
            "ClientID": client_id,
            "IsInsertion": client.IsInsertion(),
            "NodeManagerAddress": decode(
                client.NodeManagerAddress(), allow_none=True),
            "NodeManagerPort": client.NodeManagerPort(),
            "ObjectManagerPort": client.ObjectManagerPort(),
            "ObjectStoreSocketName": decode(
                client.ObjectStoreSocketName(), allow_none=True),
            "RayletSocketName": decode(
                client.RayletSocketName(), allow_none=True),
            "Resources": resources
        }

    gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
    records = {}
    insertion_order = []
    for position in range(gcs_entry.EntriesLength()):
        client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            gcs_entry.Entries(position), 0)
        resources = {
            decode(client.ResourcesTotalLabel(k)):
            client.ResourcesTotalCapacity(k)
            for k in range(client.ResourcesTotalLabelLength())
        }
        client_id = ray.utils.binary_to_hex(client.ClientId())
        if client.IsInsertion():
            insertion_order.append(client_id)
        else:
            # A removal must follow an insertion and may not repeat.
            assert client_id in records, "Client removed not found!"
            assert records[client_id]["IsInsertion"], (
                "Unexpected duplicate removal of client.")
        # Later entries for a client overwrite earlier ones.
        records[client_id] = build_record(client, client_id, resources)

    # Dead clients appear twice in the log (insertion, then removal), so the
    # order is taken from the insertion entries. Plain dicts only guarantee
    # insertion order from Python 3.7 on.
    return [records[client_id] for client_id in insertion_order]
def fetch_and_register_remote_function(self, key):
    """Import a remote function exported to Redis under ``key``.

    A placeholder that raises is registered first, so the function ID has
    an entry even if unpickling fails. On success the placeholder is
    replaced by the real function and this worker is appended to
    ``FunctionTable:<function id>``. On failure the traceback is pushed to
    the exporting driver.

    Args:
        key: The Redis hash key holding the exported function.
    """
    (driver_id_str, function_id_str, function_name, serialized_function,
     _num_return_vals, module, _resources,
     max_calls) = self._worker.redis_client.hmget(key, [
         "driver_id", "function_id", "name", "function", "num_return_vals",
         "module", "resources", "max_calls"
     ])
    function_id = ray.FunctionID(function_id_str)
    driver_id = ray.DriverID(driver_id_str)
    function_name = decode(function_name)
    max_calls = int(max_calls)
    module = decode(module)

    # This is a placeholder in case the function can't be unpickled. This
    # will be overwritten if the function is successfully registered. It
    # accepts any arguments so that a call with task arguments surfaces this
    # message instead of a TypeError about the placeholder's signature.
    def f(*args, **kwargs):
        raise Exception("This function was not imported properly.")

    # This function is called by ImportThread. This operation needs to be
    # atomic. Otherwise, there is race condition. Another thread may use
    # the temporary function above before the real function is ready.
    with self.lock:
        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=f,
                function_name=function_name,
                max_calls=max_calls))
        self._num_task_executions[driver_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was
            # imported, we record the traceback and notify the scheduler
            # of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                "Failed to unpickle the remote function '{}' with "
                "function ID {}. Traceback:\n{}".format(
                    function_name, function_id.hex(), traceback_str),
                driver_id=driver_id)
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python
            # script was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)