def _initialize_global_state(self,
                             redis_address,
                             redis_password=None,
                             timeout=20):
    """Initialize the GlobalState object by connecting to Redis.

    It's possible that certain keys in Redis may not have been fully
    populated yet. In this case, we will retry this method until they
    have been populated or we exceed a timeout.

    Args:
        redis_address: The Redis address to connect.
        redis_password: The password of the redis server.
        timeout: Seconds to keep retrying before giving up.

    Raises:
        TimeoutError: If the shard metadata is not populated within
            `timeout` seconds.
    """
    self.redis_client = services.create_redis_client(
        redis_address, redis_password)
    self.global_state_accessor = GlobalStateAccessor(
        redis_address, redis_password, False)
    self.global_state_accessor.connect()
    start_time = time.time()

    num_redis_shards = None
    redis_shard_addresses = []

    # Track success explicitly: re-checking the clock after the loop can
    # falsely report a timeout when the last iteration succeeded just as
    # the deadline passed.
    initialized = False
    while time.time() - start_time < timeout:
        # Attempt to get the number of Redis shards.
        num_redis_shards = self.redis_client.get("NumRedisShards")
        if num_redis_shards is None:
            print("Waiting longer for NumRedisShards to be populated.")
            time.sleep(1)
            continue
        num_redis_shards = int(num_redis_shards)
        assert num_redis_shards >= 1, (
            "Expected at least one Redis "
            "shard, found {}.".format(num_redis_shards))

        # Attempt to get all of the Redis shards.
        redis_shard_addresses = self.redis_client.lrange(
            "RedisShards", start=0, end=-1)
        if len(redis_shard_addresses) != num_redis_shards:
            print("Waiting longer for RedisShards to be populated.")
            time.sleep(1)
            continue

        # If we got here then we successfully got all of the information.
        initialized = True
        break

    if not initialized:
        raise TimeoutError("Timed out while attempting to initialize the "
                           "global state. num_redis_shards = {}, "
                           "redis_shard_addresses = {}".format(
                               num_redis_shards, redis_shard_addresses))

    # Get the rest of the information.
    self.redis_clients = []
    for shard_address in redis_shard_addresses:
        self.redis_clients.append(
            services.create_redis_client(shard_address.decode(),
                                         redis_password))
def __init__(self,
             redis_address,
             autoscaling_config,
             redis_password=None,
             prefix_cluster_info=False):
    """Connect the Redis clients and set up the (optional) autoscaler."""
    # Connect the global state API and the raw primary Redis client.
    ray.state.state._initialize_global_state(
        redis_address, redis_password=redis_password)
    self.redis = ray._private.services.create_redis_client(
        redis_address, password=redis_password)
    self.global_state_accessor = GlobalStateAccessor(
        redis_address, redis_password, False)
    self.global_state_accessor.connect()

    # Point the global worker at our Redis client so _internal_kv works
    # for the autoscaler.
    worker = ray.worker.global_worker
    worker.redis_client = self.redis
    worker.mode = 0

    # Mapping from raylet client ID to IP address, consulted when
    # updating the load metrics.
    self.raylet_id_to_ip_map = {}
    self.load_metrics = LoadMetrics(local_ip=redis_address.split(":")[0])

    if not autoscaling_config:
        self.autoscaler = None
        self.autoscaling_config = None
    else:
        self.autoscaler = StandardAutoscaler(
            autoscaling_config,
            self.load_metrics,
            prefix_cluster_info=prefix_cluster_info)
        self.autoscaling_config = autoscaling_config
def __init__(self, redis_address, autoscaling_config, redis_password=None):
    """Connect Redis clients, pubsub, and the (optional) autoscaler."""
    # Connect the global state API and the primary Redis client.
    ray.state.state._initialize_global_state(
        redis_address, redis_password=redis_password)
    self.redis = ray._private.services.create_redis_client(
        redis_address, password=redis_password)
    self.global_state_accessor = GlobalStateAccessor(
        redis_address, redis_password, False)
    self.global_state_accessor.connect()

    # Make _internal_kv usable by the autoscaler via the global worker.
    worker = ray.worker.global_worker
    worker.redis_client = self.redis
    worker.mode = 0

    # Subscription to the primary Redis server.
    self.primary_subscribe_client = self.redis.pubsub(
        ignore_subscribe_messages=True)

    # Raylet client ID -> IP address, consulted when updating the load
    # metrics.
    self.raylet_id_to_ip_map = {}
    self.load_metrics = LoadMetrics(local_ip=redis_address.split(":")[0])

    if autoscaling_config:
        self.autoscaler = StandardAutoscaler(autoscaling_config,
                                             self.load_metrics)
        self.autoscaling_config = autoscaling_config
    else:
        self.autoscaler = None
        self.autoscaling_config = None
def make_global_state_accessor(address_info):
    """Build and connect a GlobalStateAccessor for the given cluster.

    Args:
        address_info: Dict holding "redis_address" and/or "gcs_address",
            depending on the bootstrap mode.

    Returns:
        A connected GlobalStateAccessor instance.
    """
    if gcs_utils.use_gcs_for_bootstrap():
        options = GcsClientOptions.from_gcs_address(address_info["gcs_address"])
    else:
        options = GcsClientOptions.from_redis_address(
            address_info["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD
        )
    accessor = GlobalStateAccessor(options)
    accessor.connect()
    return accessor
def _really_init_global_state(self, timeout=20):
    """Connect to Redis and wait until the shard metadata is populated.

    Retries until both "NumRedisShards" and the matching number of
    entries in "RedisShards" are visible, then connects a client to
    every shard.

    Args:
        timeout: Seconds to keep retrying before giving up.

    Raises:
        TimeoutError: If the shard metadata is not populated within
            `timeout` seconds.
    """
    self.redis_client = services.create_redis_client(
        self.redis_address, self.redis_password)
    self.global_state_accessor = GlobalStateAccessor(
        self.redis_address, self.redis_password, False)
    self.global_state_accessor.connect()
    start_time = time.time()

    num_redis_shards = None
    redis_shard_addresses = []

    # Track success explicitly: re-checking the clock after the loop can
    # falsely report a timeout when the last iteration succeeded just as
    # the deadline passed.
    initialized = False
    while time.time() - start_time < timeout:
        # Attempt to get the number of Redis shards.
        num_redis_shards = self.redis_client.get("NumRedisShards")
        if num_redis_shards is None:
            print("Waiting longer for NumRedisShards to be populated.")
            time.sleep(1)
            continue
        num_redis_shards = int(num_redis_shards)
        assert num_redis_shards >= 1, (
            f"Expected at least one Redis shard, found {num_redis_shards}."
        )

        # Attempt to get all of the Redis shards.
        redis_shard_addresses = self.redis_client.lrange(
            "RedisShards", start=0, end=-1)
        if len(redis_shard_addresses) != num_redis_shards:
            print("Waiting longer for RedisShards to be populated.")
            time.sleep(1)
            continue

        # If we got here then we successfully got all of the information.
        initialized = True
        break

    if not initialized:
        raise TimeoutError("Timed out while attempting to initialize the "
                           "global state. "
                           f"num_redis_shards = {num_redis_shards}, "
                           "redis_shard_addresses = "
                           f"{redis_shard_addresses}")

    # Get the rest of the information.
    self.redis_clients = []
    for shard_address in redis_shard_addresses:
        self.redis_clients.append(
            services.create_redis_client(shard_address.decode(),
                                         self.redis_password))
def test_actor_resource_demand(shutdown_only):
    """Resource demands reported via GCS track infeasible actor requests."""
    ray.shutdown()
    cluster = ray.init(num_cpus=3)
    global_state_accessor = GlobalStateAccessor(
        cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    global_state_accessor.connect()

    def get_demands():
        # Fetch and parse the latest aggregated resource usage from GCS.
        message = global_state_accessor.get_all_resource_usage()
        batch = gcs_utils.ResourceUsageBatchData.FromString(message)
        return batch.resource_load_by_shape.resource_demands

    @ray.remote(num_cpus=2)
    class Actor:
        def foo(self):
            return "ok"

    a = Actor.remote()
    ray.get(a.foo.remote())
    time.sleep(1)
    # The actor is scheduled so there should be no more demands left.
    assert len(get_demands()) == 0

    @ray.remote(num_cpus=80)
    class Actor2:
        pass

    actors = [Actor2.remote()]
    time.sleep(1)
    # This actor cannot be scheduled.
    demands = get_demands()
    assert len(demands) == 1
    assert demands[0].shape == {"CPU": 80.0}
    assert demands[0].num_infeasible_requests_queued == 1

    actors.append(Actor2.remote())
    time.sleep(1)
    # Two actors cannot be scheduled.
    demands = get_demands()
    assert len(demands) == 1
    assert demands[0].num_infeasible_requests_queued == 2

    global_state_accessor.disconnect()
def test_kill_pending_actor_with_no_restart_true():
    """Killing a pending actor with no_restart=True clears its demand."""
    cluster = ray.init()
    global_state_accessor = GlobalStateAccessor(
        cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    global_state_accessor.connect()

    @ray.remote(resources={"WORKER": 1.0})
    class PendingActor:
        pass

    # Kill actor with `no_restart=True`.
    actor = PendingActor.remote()
    # TODO(ffbin): The raylet doesn't guarantee the order when dealing with
    # RequestWorkerLease and CancelWorkerLease. If we kill the actor
    # immediately after creating the actor, we may not be able to clean up
    # the request cached by the raylet.
    # See https://github.com/ray-project/ray/issues/13545 for details.
    time.sleep(1)
    ray.kill(actor, no_restart=True)

    def no_demands_left():
        message = global_state_accessor.get_all_resource_usage()
        usages = gcs_utils.ResourceUsageBatchData.FromString(message)
        return len(usages.resource_load_by_shape.resource_demands) == 0

    # Actor is dead, so the infeasible task queue length is 0.
    wait_for_condition(no_demands_left, timeout=10)
    global_state_accessor.disconnect()
    ray.shutdown()
def test_backlog_report(shutdown_only):
    """Worker backlog sizes propagate through the resource usage reports."""
    cluster = ray.init(
        num_cpus=1, _system_config={
            "report_worker_backlog": True,
        })
    global_state_accessor = GlobalStateAccessor(
        cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    global_state_accessor.connect()

    @ray.remote(num_cpus=1)
    def foo(x):
        print(".")
        time.sleep(x)
        return None

    def backlog_size_set():
        message = global_state_accessor.get_all_resource_usage()
        if message is None:
            return False
        usage = ray.gcs_utils.ResourceUsageBatchData.FromString(message)
        demands = usage.resource_load_by_shape.resource_demands
        if len(demands) != 1:
            return False
        backlog_size = demands[0].backlog_size
        print(backlog_size)
        # Ideally we'd want to assert backlog_size == 8, but guaranteeing
        # the order that submissions will occur is too hard/flaky.
        return backlog_size > 0

    # We want this first task to finish.
    refs = [foo.remote(0.5)]
    # These tasks should all start _before_ the first one finishes.
    refs.extend([foo.remote(1000) for _ in range(9)])

    # Now there's 1 request running, 1 queued in the raylet, and 8 queued in
    # the worker backlog.
    ray.get(refs[0])

    # First request finishes, second request is now running, third lease
    # request is sent to the raylet with backlog=7.
    ray.test_utils.wait_for_condition(backlog_size_set, timeout=2)
    global_state_accessor.disconnect()
def test_heartbeat_ip(shutdown_only):
    """The node manager address in resource reports matches our own IP."""
    cluster = ray.init(num_cpus=1)
    accessor = GlobalStateAccessor(
        cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    accessor.connect()
    self_ip = ray.util.get_node_ip_address()

    def self_ip_is_set():
        message = accessor.get_all_resource_usage()
        if message is None:
            return False
        usage = gcs_utils.ResourceUsageBatchData.FromString(message)
        return usage.batch[0].node_manager_address == self_ip

    wait_for_condition(self_ip_is_set, timeout=2)
    accessor.disconnect()
class GlobalState:
    """A class used to interface with the Ray control state.

    # TODO(zongheng): In the future move this to use Ray's redis module in the
    # backend to cut down on # of request RPCs.

    Attributes:
        redis_client: The Redis client used to query the primary redis server.
        redis_clients: Redis clients for each of the Redis shards.
        global_state_accessor: The client used to query gcs table from gcs
            server.
    """

    def __init__(self):
        """Create a GlobalState object."""
        # The redis server storing metadata, such as function table, client
        # table, log files, event logs, workers/actions info.
        self.redis_client = None
        # Clients for the redis shards, storing the object table & task table.
        self.redis_clients = None
        # Connected lazily by _initialize_global_state.
        self.global_state_accessor = None

    def _check_connected(self):
        """Check that the object has been initialized before it is used.

        Raises:
            RuntimeError: An exception is raised if ray.init() has not been
                called yet.
        """
        if self.redis_client is None:
            raise RuntimeError("The ray global state API cannot be used "
                               "before ray.init has been called.")

        if self.redis_clients is None:
            raise RuntimeError("The ray global state API cannot be used "
                               "before ray.init has been called.")

        if self.global_state_accessor is None:
            raise RuntimeError("The ray global state API cannot be used "
                               "before ray.init has been called.")

    def disconnect(self):
        """Disconnect global state from GCS."""
        self.redis_client = None
        self.redis_clients = None
        if self.global_state_accessor is not None:
            self.global_state_accessor.disconnect()
            self.global_state_accessor = None

    def _initialize_global_state(self,
                                 redis_address,
                                 redis_password=None,
                                 timeout=20):
        """Initialize the GlobalState object by connecting to Redis.

        It's possible that certain keys in Redis may not have been fully
        populated yet. In this case, we will retry this method until they
        have been populated or we exceed a timeout.

        Args:
            redis_address: The Redis address to connect.
            redis_password: The password of the redis server.
            timeout: Seconds to keep retrying before raising TimeoutError.
        """
        self.redis_client = services.create_redis_client(
            redis_address, redis_password)
        self.global_state_accessor = GlobalStateAccessor(
            redis_address, redis_password, False)
        self.global_state_accessor.connect()
        start_time = time.time()

        num_redis_shards = None
        redis_shard_addresses = []

        while time.time() - start_time < timeout:
            # Attempt to get the number of Redis shards.
            num_redis_shards = self.redis_client.get("NumRedisShards")
            if num_redis_shards is None:
                print("Waiting longer for NumRedisShards to be populated.")
                time.sleep(1)
                continue
            num_redis_shards = int(num_redis_shards)
            assert num_redis_shards >= 1, (
                "Expected at least one Redis "
                "shard, found {}.".format(num_redis_shards))

            # Attempt to get all of the Redis shards.
            redis_shard_addresses = self.redis_client.lrange(
                "RedisShards", start=0, end=-1)
            if len(redis_shard_addresses) != num_redis_shards:
                print("Waiting longer for RedisShards to be populated.")
                time.sleep(1)
                continue

            # If we got here then we successfully got all of the information.
            break

        # Check to see if we timed out.
        if time.time() - start_time >= timeout:
            raise TimeoutError("Timed out while attempting to initialize the "
                               "global state. num_redis_shards = {}, "
                               "redis_shard_addresses = {}".format(
                                   num_redis_shards, redis_shard_addresses))

        # Get the rest of the information.
        self.redis_clients = []
        for shard_address in redis_shard_addresses:
            self.redis_clients.append(
                services.create_redis_client(shard_address.decode(),
                                             redis_password))

    def _execute_command(self, key, *args):
        """Execute a Redis command on the appropriate Redis shard based on key.

        Args:
            key: The object ID or the task ID that the query is about.
            args: The command to run.

        Returns:
            The value returned by the Redis command.
        """
        # Shard selection mirrors the hash the backend used to place the key.
        client = self.redis_clients[key.redis_shard_hash() % len(
            self.redis_clients)]
        return client.execute_command(*args)

    def _keys(self, pattern):
        """Execute the KEYS command on all Redis shards.

        Args:
            pattern: The KEYS pattern to query.

        Returns:
            The concatenated list of results from all shards.
        """
        result = []
        for client in self.redis_clients:
            result.extend(list(client.scan_iter(match=pattern)))
        return result

    def object_table(self, object_id=None):
        """Fetch and parse the object table info for one or more object IDs.

        Args:
            object_id: An object ID to fetch information about. If this is
                None, then the entire object table is fetched.

        Returns:
            Information from the object table.
        """
        self._check_connected()

        if object_id is not None:
            # Return information about a single object ID.
            object_id = ray.ObjectID(hex_to_binary(object_id))
            object_info = self.global_state_accessor.get_object_info(object_id)
            if object_info is None:
                return {}
            else:
                object_location_info = gcs_utils.ObjectLocationInfo.FromString(
                    object_info)
                return self._gen_object_info(object_location_info)
        else:
            # Return the entire object table keyed by hex object ID.
            object_table = self.global_state_accessor.get_object_table()
            results = {}
            for i in range(len(object_table)):
                object_location_info = gcs_utils.ObjectLocationInfo.FromString(
                    object_table[i])
                results[binary_to_hex(object_location_info.object_id)] = \
                    self._gen_object_info(object_location_info)
            return results

    def _gen_object_info(self, object_location_info):
        """Parse object location info.

        Returns:
            Information from object.
        """
        locations = []
        for location in object_location_info.locations:
            locations.append(ray.utils.binary_to_hex(location.manager))

        object_info = {
            "ObjectID": ray.utils.binary_to_hex(
                object_location_info.object_id),
            "Locations": locations,
        }
        return object_info

    def _actor_table(self, actor_id):
        """Fetch and parse the actor table information for a single actor ID.

        Args:
            actor_id: A actor ID to get information about.

        Returns:
            A dictionary with information about the actor ID in question.
        """
        assert isinstance(actor_id, ray.ActorID)
        message = self.redis_client.execute_command(
            "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ACTOR"), "",
            actor_id.binary())
        if message is None:
            return {}
        gcs_entries = gcs_utils.GcsEntry.FromString(message)
        assert len(gcs_entries.entries) > 0
        # Use the most recent entry for this actor.
        actor_table_data = gcs_utils.ActorTableData.FromString(
            gcs_entries.entries[-1])

        actor_info = {
            "ActorID": binary_to_hex(actor_table_data.actor_id),
            "JobID": binary_to_hex(actor_table_data.job_id),
            "Address": {
                "IPAddress": actor_table_data.address.ip_address,
                "Port": actor_table_data.address.port
            },
            "OwnerAddress": {
                "IPAddress": actor_table_data.owner_address.ip_address,
                "Port": actor_table_data.owner_address.port
            },
            "State": actor_table_data.state,
            "Timestamp": actor_table_data.timestamp,
        }
        return actor_info

    def actor_table(self, actor_id=None):
        """Fetch and parse the actor table information for one or more actor IDs.

        Args:
            actor_id: A hex string of the actor ID to fetch information about.
                If this is None, then the actor table is fetched.

        Returns:
            Information from the actor table.
        """
        self._check_connected()
        if actor_id is not None:
            actor_id = ray.ActorID(hex_to_binary(actor_id))
            return self._actor_table(actor_id)
        else:
            # Scan for every actor key and look each one up individually.
            actor_table_keys = list(
                self.redis_client.scan_iter(
                    match=gcs_utils.TablePrefix_ACTOR_string + "*"))
            actor_ids_binary = [
                key[len(gcs_utils.TablePrefix_ACTOR_string):]
                for key in actor_table_keys
            ]

            results = {}
            for actor_id_binary in actor_ids_binary:
                results[binary_to_hex(actor_id_binary)] = self._actor_table(
                    ray.ActorID(actor_id_binary))
            return results

    def client_table(self):
        """Fetch and parse the Redis DB client table.

        Returns:
            Information about the Ray clients in the cluster.
        """
        self._check_connected()
        client_table = _parse_client_table(self.redis_client)

        for client in client_table:
            # These are equivalent and is better for application developers.
            client["alive"] = client["Alive"]
        return client_table

    def job_table(self):
        """Fetch and parse the Redis job table.

        Returns:
            Information about the Ray jobs in the cluster,
            namely a list of dicts with keys:
            - "JobID" (identifier for the job),
            - "DriverIPAddress" (IP address of the driver for this job),
            - "DriverPid" (process ID of the driver for this job),
            - "StartTime" (UNIX timestamp of the start time of this job),
            - "StopTime" (UNIX timestamp of the stop time of this job, if any)
        """
        self._check_connected()

        job_table = self.global_state_accessor.get_job_table()

        results = []
        for i in range(len(job_table)):
            entry = gcs_utils.JobTableData.FromString(job_table[i])
            job_info = {}
            job_info["JobID"] = entry.job_id.hex()
            job_info["DriverIPAddress"] = entry.driver_ip_address
            job_info["DriverPid"] = entry.driver_pid
            # The single timestamp field is a stop time for dead jobs and a
            # start time for live ones.
            if entry.is_dead:
                job_info["StopTime"] = entry.timestamp
            else:
                job_info["StartTime"] = entry.timestamp
            results.append(job_info)

        return results

    def profile_table(self):
        """Fetch and parse the profile events from GCS.

        Returns:
            A dict mapping component ID (hex) to the list of profile event
                dicts recorded for that component.
        """
        self._check_connected()
        result = defaultdict(list)
        profile_table = self.global_state_accessor.get_profile_table()
        for i in range(len(profile_table)):
            profile = gcs_utils.ProfileTableData.FromString(profile_table[i])

            component_type = profile.component_type
            component_id = binary_to_hex(profile.component_id)
            node_ip_address = profile.node_ip_address

            for event in profile.profile_events:
                # extra_data is stored as JSON text; tolerate malformed data.
                try:
                    extra_data = json.loads(event.extra_data)
                except ValueError:
                    extra_data = {}
                profile_event = {
                    "event_type": event.event_type,
                    "component_id": component_id,
                    "node_ip_address": node_ip_address,
                    "component_type": component_type,
                    "start_time": event.start_time,
                    "end_time": event.end_time,
                    "extra_data": extra_data
                }

                result[component_id].append(profile_event)

        return dict(result)

    def _seconds_to_microseconds(self, time_in_seconds):
        """A helper function for converting seconds to microseconds."""
        time_in_microseconds = 10**6 * time_in_seconds
        return time_in_microseconds

    # Colors are specified at
    # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html.  # noqa: E501
    _default_color_mapping = defaultdict(
        lambda: "generic_work", {
            "worker_idle": "cq_build_abandoned",
            "task": "rail_response",
            "task:deserialize_arguments": "rail_load",
            "task:execute": "rail_animation",
            "task:store_outputs": "rail_idle",
            "wait_for_function": "detailed_memory_dump",
            "ray.get": "good",
            "ray.put": "terrible",
            "ray.wait": "vsync_highlight_color",
            "submit_task": "background_memory_dump",
            "fetch_and_run_function": "detailed_memory_dump",
            "register_remote_function": "detailed_memory_dump",
        })

    # These colors are for use in Chrome tracing.
    _chrome_tracing_colors = [
        "thread_state_uninterruptible",
        "thread_state_iowait",
        "thread_state_running",
        "thread_state_runnable",
        "thread_state_sleeping",
        "thread_state_unknown",
        "background_memory_dump",
        "light_memory_dump",
        "detailed_memory_dump",
        "vsync_highlight_color",
        "generic_work",
        "good",
        "bad",
        "terrible",
        # "black",
        # "grey",
        # "white",
        "yellow",
        "olive",
        "rail_response",
        "rail_animation",
        "rail_idle",
        "rail_load",
        "startup",
        "heap_dump_stack_frame",
        "heap_dump_object_type",
        "heap_dump_child_node_arrow",
        "cq_build_running",
        "cq_build_passed",
        "cq_build_failed",
        "cq_build_abandoned",
        "cq_build_attempt_runnig",
        "cq_build_attempt_passed",
        "cq_build_attempt_failed",
    ]

    def chrome_tracing_dump(self, filename=None):
        """Return a list of profiling events that can viewed as a timeline.

        To view this information as a timeline, simply dump it as a json file
        by passing in "filename" or using using json.dump, and then load go to
        chrome://tracing in the Chrome web browser and load the dumped file.
        Make sure to enable "Flow events" in the "View Options" menu.

        Args:
            filename: If a filename is provided, the timeline is dumped to
                that file.

        Returns:
            If filename is not provided, this returns a list of profiling
                events. Each profile event is a dictionary.
        """
        # TODO(rkn): Support including the task specification data in the
        # timeline.
        # TODO(rkn): This should support viewing just a window of time or a
        # limited number of events.
        self._check_connected()

        profile_table = self.profile_table()
        all_events = []

        for component_id_hex, component_events in profile_table.items():
            # Only consider workers and drivers.
            component_type = component_events[0]["component_type"]
            if component_type not in ["worker", "driver"]:
                continue

            for event in component_events:
                new_event = {
                    # The category of the event.
                    "cat": event["event_type"],
                    # The string displayed on the event.
                    "name": event["event_type"],
                    # The identifier for the group of rows that the event
                    # appears in.
                    "pid": event["node_ip_address"],
                    # The identifier for the row that the event appears in.
                    "tid": event["component_type"] + ":" +
                    event["component_id"],
                    # The start time in microseconds.
                    "ts": self._seconds_to_microseconds(event["start_time"]),
                    # The duration in microseconds.
                    "dur": self._seconds_to_microseconds(event["end_time"] -
                                                         event["start_time"]),
                    # What is this?
                    "ph": "X",
                    # This is the name of the color to display the box in.
                    "cname": self._default_color_mapping[event["event_type"]],
                    # The extra user-defined data.
                    "args": event["extra_data"],
                }

                # Modify the json with the additional user-defined extra data.
                # This can be used to add fields or override existing fields.
                if "cname" in event["extra_data"]:
                    new_event["cname"] = event["extra_data"]["cname"]
                if "name" in event["extra_data"]:
                    new_event["name"] = event["extra_data"]["name"]

                all_events.append(new_event)

        if filename is not None:
            with open(filename, "w") as outfile:
                json.dump(all_events, outfile)
        else:
            return all_events

    def chrome_tracing_object_transfer_dump(self, filename=None):
        """Return a list of transfer events that can viewed as a timeline.

        To view this information as a timeline, simply dump it as a json file
        by passing in "filename" or using using json.dump, and then load go to
        chrome://tracing in the Chrome web browser and load the dumped file.
        Make sure to enable "Flow events" in the "View Options" menu.

        Args:
            filename: If a filename is provided, the timeline is dumped to
                that file.

        Returns:
            If filename is not provided, this returns a list of profiling
                events. Each profile event is a dictionary.
        """
        self._check_connected()

        node_id_to_address = {}
        for node_info in self.client_table():
            node_id_to_address[node_info["NodeID"]] = "{}:{}".format(
                node_info["NodeManagerAddress"],
                node_info["ObjectManagerPort"])

        all_events = []

        for key, items in self.profile_table().items():
            # Only consider object manager events.
            if items[0]["component_type"] != "object_manager":
                continue

            for event in items:
                if event["event_type"] == "transfer_send":
                    object_id, remote_node_id, _, _ = event["extra_data"]

                elif event["event_type"] == "transfer_receive":
                    object_id, remote_node_id, _, _ = event["extra_data"]

                elif event["event_type"] == "receive_pull_request":
                    object_id, remote_node_id = event["extra_data"]

                else:
                    assert False, "This should be unreachable."

                # Choose a color by reading the first couple of hex digits of
                # the object ID as an integer and turning that into a color.
                object_id_int = int(object_id[:2], 16)
                color = self._chrome_tracing_colors[object_id_int % len(
                    self._chrome_tracing_colors)]

                new_event = {
                    # The category of the event.
                    "cat": event["event_type"],
                    # The string displayed on the event.
                    "name": event["event_type"],
                    # The identifier for the group of rows that the event
                    # appears in.
                    "pid": node_id_to_address[key],
                    # The identifier for the row that the event appears in.
                    "tid": node_id_to_address[remote_node_id],
                    # The start time in microseconds.
                    "ts": self._seconds_to_microseconds(event["start_time"]),
                    # The duration in microseconds.
                    "dur": self._seconds_to_microseconds(event["end_time"] -
                                                         event["start_time"]),
                    # What is this?
                    "ph": "X",
                    # This is the name of the color to display the box in.
                    "cname": color,
                    # The extra user-defined data.
                    "args": event["extra_data"],
                }
                all_events.append(new_event)

                # Add another box with a color indicating whether it was a
                # send or a receive event.
                if event["event_type"] == "transfer_send":
                    additional_event = new_event.copy()
                    additional_event["cname"] = "black"
                    all_events.append(additional_event)
                elif event["event_type"] == "transfer_receive":
                    additional_event = new_event.copy()
                    additional_event["cname"] = "grey"
                    all_events.append(additional_event)
                else:
                    pass

        if filename is not None:
            with open(filename, "w") as outfile:
                json.dump(all_events, outfile)
        else:
            return all_events

    def workers(self):
        """Get a dictionary mapping worker ID to worker information."""
        self._check_connected()

        worker_keys = self.redis_client.keys("Worker*")
        workers_data = {}

        for worker_key in worker_keys:
            worker_info = self.redis_client.hgetall(worker_key)
            worker_id = binary_to_hex(worker_key[len("Workers:"):])

            workers_data[worker_id] = {
                "node_ip_address": decode(worker_info[b"node_ip_address"]),
                "plasma_store_socket": decode(
                    worker_info[b"plasma_store_socket"])
            }
            # stdout/stderr files are only present for some workers.
            if b"stderr_file" in worker_info:
                workers_data[worker_id]["stderr_file"] = decode(
                    worker_info[b"stderr_file"])
            if b"stdout_file" in worker_info:
                workers_data[worker_id]["stdout_file"] = decode(
                    worker_info[b"stdout_file"])
        return workers_data

    def _job_length(self):
        """Scan event logs for the earliest/latest timestamps and task count.

        Returns:
            A tuple (overall_smallest, overall_largest, num_tasks); all zeros
                when no tasks were recorded.
        """
        event_log_sets = self.redis_client.keys("event_log*")
        overall_smallest = sys.maxsize
        overall_largest = 0
        num_tasks = 0
        for event_log_set in event_log_sets:
            fwd_range = self.redis_client.zrange(
                event_log_set, start=0, end=0, withscores=True)
            overall_smallest = min(overall_smallest, fwd_range[0][1])

            rev_range = self.redis_client.zrevrange(
                event_log_set, start=0, end=0, withscores=True)
            overall_largest = max(overall_largest, rev_range[0][1])

            num_tasks += self.redis_client.zcount(
                event_log_set, min=0, max=time.time())
        if num_tasks == 0:
            return 0, 0, 0
        return overall_smallest, overall_largest, num_tasks

    def cluster_resources(self):
        """Get the current total cluster resources.

        Note that this information can grow stale as nodes are added to or
        removed from the cluster.

        Returns:
            A dictionary mapping resource name to the total quantity of that
                resource in the cluster.
        """
        self._check_connected()

        resources = defaultdict(int)
        clients = self.client_table()
        for client in clients:
            # Only count resources from latest entries of live clients.
            if client["Alive"]:
                for key, value in client["Resources"].items():
                    resources[key] += value
        return dict(resources)

    def _live_client_ids(self):
        """Returns a set of client IDs corresponding to clients still alive."""
        return {
            client["NodeID"]
            for client in self.client_table() if (client["Alive"])
        }

    def available_resources(self):
        """Get the current available cluster resources.

        This is different from `cluster_resources` in that this will return
        idle (available) resources rather than total resources.

        Note that this information can grow stale as tasks start and finish.

        Returns:
            A dictionary mapping resource name to the total quantity of that
                resource in the cluster.
        """
        self._check_connected()

        available_resources_by_id = {}

        # Subscribe to heartbeats on every shard so we see each node's
        # latest availability report.
        subscribe_clients = [
            redis_client.pubsub(ignore_subscribe_messages=True)
            for redis_client in self.redis_clients
        ]
        for subscribe_client in subscribe_clients:
            subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL)

        client_ids = self._live_client_ids()

        # Loop until we have received a heartbeat from every live client.
        while set(available_resources_by_id.keys()) != client_ids:
            for subscribe_client in subscribe_clients:
                # Parse client message
                raw_message = subscribe_client.get_message()
                if (raw_message is None or raw_message["channel"] !=
                        gcs_utils.XRAY_HEARTBEAT_CHANNEL):
                    continue
                data = raw_message["data"]
                gcs_entries = gcs_utils.GcsEntry.FromString(data)
                heartbeat_data = gcs_entries.entries[0]
                message = gcs_utils.HeartbeatTableData.FromString(
                    heartbeat_data)
                # Calculate available resources for this client
                num_resources = len(message.resources_available_label)
                dynamic_resources = {}
                for i in range(num_resources):
                    resource_id = message.resources_available_label[i]
                    dynamic_resources[resource_id] = (
                        message.resources_available_capacity[i])

                # Update available resources for this client
                client_id = ray.utils.binary_to_hex(message.client_id)
                available_resources_by_id[client_id] = dynamic_resources

            # Update clients in cluster
            client_ids = self._live_client_ids()

            # Remove disconnected clients
            for client_id in list(available_resources_by_id.keys()):
                if client_id not in client_ids:
                    del available_resources_by_id[client_id]

        # Calculate total available resources
        total_available_resources = defaultdict(int)
        for available_resources in available_resources_by_id.values():
            for resource_id, num_available in available_resources.items():
                total_available_resources[resource_id] += num_available

        # Close the pubsub clients to avoid leaking file descriptors.
        for subscribe_client in subscribe_clients:
            subscribe_client.close()

        return dict(total_available_resources)

    def _error_messages(self, job_id):
        """Get the error messages for a specific driver.

        Args:
            job_id: The ID of the job to get the errors for.

        Returns:
            A list of the error messages for this driver.
        """
        assert isinstance(job_id, ray.JobID)
        message = self.redis_client.execute_command(
            "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
            job_id.binary())

        # If there are no errors, return early.
        if message is None:
            return []

        gcs_entries = gcs_utils.GcsEntry.FromString(message)
        error_messages = []
        for entry in gcs_entries.entries:
            error_data = gcs_utils.ErrorTableData.FromString(entry)
            assert job_id.binary() == error_data.job_id
            error_message = {
                "type": error_data.type,
                "message": error_data.error_message,
                "timestamp": error_data.timestamp,
            }
            error_messages.append(error_message)
        return error_messages

    def error_messages(self, job_id=None):
        """Get the error messages for all drivers or a specific driver.

        Args:
            job_id: The specific job to get the errors for. If this is
                None, then this method retrieves the errors for all jobs.

        Returns:
            A list of the error messages for the specified driver if one was
                given, or a dictionary mapping from job ID to a list of error
                messages for that driver otherwise.
        """
        self._check_connected()

        if job_id is not None:
            assert isinstance(job_id, ray.JobID)
            return self._error_messages(job_id)

        error_table_keys = self.redis_client.keys(
            gcs_utils.TablePrefix_ERROR_INFO_string + "*")
        job_ids = [
            key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
            for key in error_table_keys
        ]

        return {
            binary_to_hex(job_id): self._error_messages(ray.JobID(job_id))
            for job_id in job_ids
        }

    def actor_checkpoint_info(self, actor_id):
        """Get checkpoint info for the given actor id.

        Args:
            actor_id: Actor's ID.

        Returns:
            A dictionary with information about the actor's checkpoint IDs and
                their timestamps.
        """
        self._check_connected()
        message = self._execute_command(
            actor_id,
            "RAY.TABLE_LOOKUP",
            gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"),
            "",
            actor_id.binary(),
        )
        if message is None:
            return None
        gcs_entry = gcs_utils.GcsEntry.FromString(message)
        entry = gcs_utils.ActorCheckpointIdData.FromString(
            gcs_entry.entries[0])
        checkpoint_ids = [
            ray.ActorCheckpointID(checkpoint_id)
            for checkpoint_id in entry.checkpoint_ids
        ]
        return {
            "ActorID": ray.utils.binary_to_hex(entry.actor_id),
            "CheckpointIds": checkpoint_ids,
            "Timestamps": list(entry.timestamps),
        }
class Monitor:
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        primary_subscribe_client: A pubsub client for the Redis server. This
            is used to receive notifications about failed components.
    """

    def __init__(self, redis_address, autoscaling_config,
                 redis_password=None):
        # Initialize the Redis clients.
        ray.state.state._initialize_global_state(
            redis_address, redis_password=redis_password)
        self.redis = ray._private.services.create_redis_client(
            redis_address, password=redis_password)
        self.global_state_accessor = GlobalStateAccessor(
            redis_address, redis_password, False)
        self.global_state_accessor.connect()
        # Set the redis client and mode so _internal_kv works for autoscaler.
        worker = ray.worker.global_worker
        worker.redis_client = self.redis
        # NOTE(review): mode 0 presumably corresponds to script/driver mode;
        # confirm against ray.worker's mode constants.
        worker.mode = 0
        # Setup subscriptions to the primary Redis server and the Redis
        # shards.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from raylet client ID to IP address to use
        # for updating the load metrics.
        self.raylet_id_to_ip_map = {}
        # The head node IP is the host part of the Redis address.
        head_node_ip = redis_address.split(":")[0]
        self.load_metrics = LoadMetrics(local_ip=head_node_ip)
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
            self.autoscaling_config = autoscaling_config
        else:
            self.autoscaler = None
            self.autoscaling_config = None

    def __del__(self):
        """Destruct the monitor object."""
        # We close the pubsub client to avoid leaking file descriptors.
        # getattr-style guard: __init__ may have failed before the attribute
        # was set, and __del__ must not raise.
        try:
            primary_subscribe_client = self.primary_subscribe_client
        except AttributeError:
            primary_subscribe_client = None
        if primary_subscribe_client is not None:
            primary_subscribe_client.close()
        if self.global_state_accessor is not None:
            self.global_state_accessor.disconnect()
            self.global_state_accessor = None

    def subscribe(self, channel):
        """Subscribe to the given channel on the primary Redis shard.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.subscribe(channel)

    def update_load_metrics(self):
        """Fetches heartbeat data from GCS and updates load metrics."""
        all_heartbeat = self.global_state_accessor.get_all_heartbeat()
        heartbeat_batch_data = \
            ray.gcs_utils.HeartbeatBatchTableData.FromString(all_heartbeat)
        for heartbeat_message in heartbeat_batch_data.batch:
            resource_load = dict(heartbeat_message.resource_load)
            total_resources = dict(heartbeat_message.resources_total)
            available_resources = dict(heartbeat_message.resources_available)

            # Demand info comes from the batch, not the per-node message.
            waiting_bundles, infeasible_bundles = parse_resource_demands(
                heartbeat_batch_data.resource_load_by_shape)

            pending_placement_groups = list(
                heartbeat_batch_data.placement_group_load.placement_group_data)

            # Update the load metrics for this raylet.
            node_id = ray.utils.binary_to_hex(heartbeat_message.node_id)
            ip = self.raylet_id_to_ip_map.get(node_id)
            if ip:
                self.load_metrics.update(ip, total_resources,
                                         available_resources, resource_load,
                                         waiting_bundles, infeasible_bundles,
                                         pending_placement_groups)
            else:
                logger.warning(
                    f"Monitor: could not find ip for node {node_id}")

    def autoscaler_resource_request_handler(self, _, data):
        """Handle a notification of a resource request for the autoscaler.

        This channel and method are only used by the manual
        `ray.autoscaler.sdk.request_resources` api.

        Args:
            channel: unused
            data: a resource request as JSON, e.g. {"CPU": 1}
        """
        if not self.autoscaler:
            return
        try:
            self.autoscaler.request_resources(json.loads(data))
        except Exception:
            # We don't want this to kill the monitor.
            traceback.print_exc()

    def process_messages(self, max_messages=10000):
        """Process all messages ready in the subscription channels.

        This reads messages from the subscription channels and calls the
        appropriate handlers until there are no messages left.

        Args:
            max_messages: The maximum number of messages to process before
                returning.
        """
        subscribe_clients = [self.primary_subscribe_client]
        for subscribe_client in subscribe_clients:
            for _ in range(max_messages):
                message = None
                try:
                    # Non-blocking poll; returns None when the channel is
                    # drained.
                    message = subscribe_client.get_message()
                except redis.exceptions.ConnectionError:
                    pass
                if message is None:
                    # Continue on to the next subscribe client.
                    break

                # Parse the message.
                channel = message["channel"]
                data = message["data"]

                if (channel ==
                        ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL):
                    message_handler = self.autoscaler_resource_request_handler
                else:
                    # Only one channel is ever subscribed (see _run).
                    assert False, "This code should be unreachable."

                # Call the handler.
                message_handler(channel, data)

    def update_raylet_map(self, _append_port=False):
        """Updates internal raylet map.

        Args:
            _append_port (bool): Defaults to False. Appending the port is
                useful in testing, as mock clusters have many nodes with
                the same IP and cannot be uniquely identified.
        """
        all_raylet_nodes = ray.nodes()
        self.raylet_id_to_ip_map = {}
        for raylet_info in all_raylet_nodes:
            # Fall back to the newer field names when the legacy ones are
            # absent.
            node_id = (raylet_info.get("DBClientID") or raylet_info["NodeID"])
            ip_address = (raylet_info.get("AuxAddress")
                          or raylet_info["NodeManagerAddress"]).split(":")[0]
            if _append_port:
                ip_address += ":" + str(raylet_info["NodeManagerPort"])
            self.raylet_id_to_ip_map[node_id] = ip_address

    def _run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead
        database clients and cleaning up state accordingly.
        """
        self.subscribe(ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL)

        # Handle messages from the subscription channels.
        while True:
            # Process autoscaling actions
            if self.autoscaler:
                # Only used to update the load metrics for the autoscaler.
                self.update_raylet_map()
                self.update_load_metrics()
                self.autoscaler.update()

            # Process a round of messages.
            self.process_messages()

            # Wait for a autoscaler update interval before processing the
            # next round of messages.
            time.sleep(AUTOSCALER_UPDATE_INTERVAL_S)

    def destroy_autoscaler_workers(self):
        """Cleanup the autoscaler, in case of an exception in the run() method.

        We kill the worker nodes, but retain the head node in order to keep
        logs around, keeping costs minimal. This monitor process runs on the
        head node anyway, so this is more reliable."""
        if self.autoscaler is None:
            return  # Nothing to clean up.
        if self.autoscaling_config is None:
            # This is a logic error in the program. Can't do anything.
            logger.error(
                "Monitor: Cleanup failed due to lack of autoscaler config.")
            return
        logger.info("Monitor: Exception caught. Taking down workers...")
        clean = False
        # Retry teardown forever: leaked workers cost money, so a transient
        # failure must not abort the cleanup.
        while not clean:
            try:
                teardown_cluster(
                    config_file=self.autoscaling_config,
                    yes=True,  # Non-interactive.
                    workers_only=True,  # Retain head node for logs.
                    override_cluster_name=None,
                    keep_min_workers=True,  # Retain minimal amount of workers.
                )
                clean = True
                logger.info("Monitor: Workers taken down.")
            except Exception:
                logger.error("Monitor: Cleanup exception. Trying again...")
                time.sleep(2)

    def run(self):
        try:
            self._run()
        except Exception:
            logger.exception("Error in monitor loop")
            if self.autoscaler:
                self.autoscaler.kill_workers()
            raise
def test_load_report(shutdown_only, max_shapes): resource1 = "A" resource2 = "B" cluster = ray.init(num_cpus=1, resources={resource1: 1}, _system_config={ "max_resource_shapes_per_load_report": max_shapes, }) global_state_accessor = GlobalStateAccessor( cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) global_state_accessor.connect() @ray.remote def sleep(): time.sleep(1000) sleep.remote() for _ in range(3): sleep.remote() sleep.options(resources={resource1: 1}).remote() sleep.options(resources={resource2: 1}).remote() class Checker: def __init__(self): self.report = None def check_load_report(self): message = global_state_accessor.get_all_resource_usage() if message is None: return False resource_usage = ray.gcs_utils.ResourceUsageBatchData.FromString( message) self.report = \ resource_usage.resource_load_by_shape.resource_demands if max_shapes == 0: return True elif max_shapes == 2: return len(self.report) >= 2 else: return len(self.report) >= 3 # Wait for load information to arrive. checker = Checker() ray.test_utils.wait_for_condition(checker.check_load_report) # Check that we respect the max shapes limit. if max_shapes != -1: assert len(checker.report) <= max_shapes print(checker.report) if max_shapes > 0: # Check that we always include the 1-CPU resource shape. one_cpu_shape = {"CPU": 1} one_cpu_found = False for demand in checker.report: if demand.shape == one_cpu_shape: one_cpu_found = True assert one_cpu_found # Check that we differentiate between infeasible and ready tasks. for demand in checker.report: if resource2 in demand.shape: assert demand.num_infeasible_requests_queued > 0 assert demand.num_ready_requests_queued == 0 else: assert demand.num_ready_requests_queued > 0 assert demand.num_infeasible_requests_queued == 0 global_state_accessor.disconnect()
class Monitor:
    """Autoscaling monitor.

    This process periodically collects stats from the GCS and triggers
    autoscaler updates.

    Attributes:
        redis: A connection to the Redis server.
    """

    def __init__(self,
                 redis_address,
                 autoscaling_config,
                 redis_password=None,
                 prefix_cluster_info=False):
        # Initialize the Redis clients.
        ray.state.state._initialize_global_state(
            redis_address, redis_password=redis_password)
        self.redis = ray._private.services.create_redis_client(
            redis_address, password=redis_password)
        self.global_state_accessor = GlobalStateAccessor(
            redis_address, redis_password, False)
        self.global_state_accessor.connect()
        # Set the redis client and mode so _internal_kv works for autoscaler.
        worker = ray.worker.global_worker
        worker.redis_client = self.redis
        # NOTE(review): mode 0 presumably corresponds to script/driver mode;
        # confirm against ray.worker's mode constants.
        worker.mode = 0
        # Keep a mapping from raylet client ID to IP address to use
        # for updating the load metrics.
        self.raylet_id_to_ip_map = {}
        # The head node IP is the host part of the Redis address.
        head_node_ip = redis_address.split(":")[0]
        self.load_metrics = LoadMetrics(local_ip=head_node_ip)
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(
                autoscaling_config,
                self.load_metrics,
                prefix_cluster_info=prefix_cluster_info)
            self.autoscaling_config = autoscaling_config
        else:
            self.autoscaler = None
            self.autoscaling_config = None

    def __del__(self):
        """Destruct the monitor object."""
        # Disconnect the GCS accessor so we do not leak its connection; this
        # version of the monitor holds no pubsub client.
        if self.global_state_accessor is not None:
            self.global_state_accessor.disconnect()
            self.global_state_accessor = None

    def update_load_metrics(self):
        """Fetches resource usage data from GCS and updates load metrics."""
        all_resources = self.global_state_accessor.get_all_resource_usage()
        resources_batch_data = \
            ray.gcs_utils.ResourceUsageBatchData.FromString(all_resources)
        for resource_message in resources_batch_data.batch:
            resource_load = dict(resource_message.resource_load)
            total_resources = dict(resource_message.resources_total)
            available_resources = dict(resource_message.resources_available)

            # Demand info comes from the batch, not the per-node message.
            waiting_bundles, infeasible_bundles = parse_resource_demands(
                resources_batch_data.resource_load_by_shape)

            pending_placement_groups = list(
                resources_batch_data.placement_group_load.placement_group_data)

            # Update the load metrics for this raylet.
            node_id = ray.utils.binary_to_hex(resource_message.node_id)
            ip = self.raylet_id_to_ip_map.get(node_id)
            if ip:
                self.load_metrics.update(ip, total_resources,
                                         available_resources, resource_load,
                                         waiting_bundles, infeasible_bundles,
                                         pending_placement_groups)
            else:
                logger.warning(
                    f"Monitor: could not find ip for node {node_id}")

    def update_resource_requests(self):
        """Fetches resource requests from the internal KV and updates load."""
        if not _internal_kv_initialized():
            return
        data = _internal_kv_get(
            ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL)
        if data:
            try:
                resource_request = json.loads(data)
                self.load_metrics.set_resource_requests(resource_request)
            except Exception:
                # Best-effort: a malformed request must not kill the monitor.
                logger.exception("Error parsing resource requests")

    def autoscaler_resource_request_handler(self, _, data):
        """Handle a notification of a resource request for the autoscaler.

        This channel and method are only used by the manual
        `ray.autoscaler.sdk.request_resources` api.

        Args:
            channel: unused
            data: a resource request as JSON, e.g. {"CPU": 1}
        """
        resource_request = json.loads(data)
        self.load_metrics.set_resource_requests(resource_request)

    def update_raylet_map(self, _append_port=False):
        """Updates internal raylet map.

        Args:
            _append_port (bool): Defaults to False. Appending the port is
                useful in testing, as mock clusters have many nodes with
                the same IP and cannot be uniquely identified.
        """
        all_raylet_nodes = ray.nodes()
        self.raylet_id_to_ip_map = {}
        for raylet_info in all_raylet_nodes:
            # Fall back to the newer field names when the legacy ones are
            # absent.
            node_id = (raylet_info.get("DBClientID") or raylet_info["NodeID"])
            ip_address = (raylet_info.get("AuxAddress")
                          or raylet_info["NodeManagerAddress"]).split(":")[0]
            if _append_port:
                ip_address += ":" + str(raylet_info["NodeManagerPort"])
            self.raylet_id_to_ip_map[node_id] = ip_address

    def _run(self):
        """Run the monitor loop."""
        while True:
            self.update_raylet_map()
            self.update_load_metrics()
            self.update_resource_requests()
            status = {
                "load_metrics_report": self.load_metrics.summary()._asdict()
            }

            # Process autoscaling actions
            if self.autoscaler:
                # Only used to update the load metrics for the autoscaler.
                self.autoscaler.update()
                status["autoscaler_report"] = self.autoscaler.summary(
                )._asdict()

            # Publish the status snapshot for `ray status` / debugging.
            as_json = json.dumps(status)
            if _internal_kv_initialized():
                _internal_kv_put(
                    DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True)

            # Wait for a autoscaler update interval before processing the
            # next round of messages.
            time.sleep(AUTOSCALER_UPDATE_INTERVAL_S)

    def destroy_autoscaler_workers(self):
        """Cleanup the autoscaler, in case of an exception in the run() method.

        We kill the worker nodes, but retain the head node in order to keep
        logs around, keeping costs minimal. This monitor process runs on the
        head node anyway, so this is more reliable."""
        if self.autoscaler is None:
            return  # Nothing to clean up.
        if self.autoscaling_config is None:
            # This is a logic error in the program. Can't do anything.
            logger.error(
                "Monitor: Cleanup failed due to lack of autoscaler config.")
            return
        logger.info("Monitor: Exception caught. Taking down workers...")
        clean = False
        # Retry teardown forever: leaked workers cost money, so a transient
        # failure must not abort the cleanup.
        while not clean:
            try:
                teardown_cluster(
                    config_file=self.autoscaling_config,
                    yes=True,  # Non-interactive.
                    workers_only=True,  # Retain head node for logs.
                    override_cluster_name=None,
                    keep_min_workers=True,  # Retain minimal amount of workers.
                )
                clean = True
                logger.info("Monitor: Workers taken down.")
            except Exception:
                logger.error("Monitor: Cleanup exception. Trying again...")
                time.sleep(2)

    def run(self):
        try:
            self._run()
        except Exception:
            logger.exception("Error in monitor loop")
            if self.autoscaler:
                self.autoscaler.kill_workers()
            raise
def _really_init_global_state(self): self.global_state_accessor = GlobalStateAccessor(self.gcs_options) self.global_state_accessor.connect()
class GlobalState:
    """A class used to interface with the Ray control state.

    Attributes:
        global_state_accessor: The client used to query gcs table from gcs
            server.
    """

    def __init__(self):
        """Create a GlobalState object."""
        # Args used for lazy init of this object.
        self.gcs_options = None
        self.global_state_accessor = None

    def _check_connected(self):
        """Ensure that the object has been initialized before it is used.

        This lazily initializes clients needed for state accessors.

        Raises:
            RuntimeError: An exception is raised if ray.init() has not been
                called yet.
        """
        if (self.gcs_options is not None
                and self.global_state_accessor is None):
            self._really_init_global_state()

        # _really_init_global_state should have set self.global_state_accessor
        if self.global_state_accessor is None:
            raise ray.exceptions.RaySystemError(
                "Ray has not been started yet. You can start Ray with "
                "'ray.init()'.")

    def disconnect(self):
        """Disconnect global state from GCS."""
        self.gcs_options = None
        if self.global_state_accessor is not None:
            self.global_state_accessor.disconnect()
            self.global_state_accessor = None

    def _initialize_global_state(self, gcs_options):
        """Set args for lazily initialization of the GlobalState object.

        It's possible that certain keys in gcs kv may not have been fully
        populated yet. In this case, we will retry this method until they
        have been populated or we exceed a timeout.

        Args:
            gcs_options: The client options for gcs
        """
        # Save args for lazy init of global state. This avoids opening extra
        # gcs connections from each worker until needed.
        self.gcs_options = gcs_options

    def _really_init_global_state(self):
        """Create and connect the GCS accessor (called lazily)."""
        self.global_state_accessor = GlobalStateAccessor(self.gcs_options)
        self.global_state_accessor.connect()

    def actor_table(self, actor_id):
        """Fetch and parse the actor table information for a single actor ID.

        Args:
            actor_id: A hex string of the actor ID to fetch information about.
                If this is None, then the actor table is fetched.

        Returns:
            Information from the actor table.
        """
        self._check_connected()

        if actor_id is not None:
            actor_id = ray.ActorID(hex_to_binary(actor_id))
            actor_info = self.global_state_accessor.get_actor_info(actor_id)
            if actor_info is None:
                return {}
            else:
                actor_table_data = gcs_utils.ActorTableData.FromString(
                    actor_info)
                return self._gen_actor_info(actor_table_data)
        else:
            actor_table = self.global_state_accessor.get_actor_table()
            results = {}
            for i in range(len(actor_table)):
                actor_table_data = gcs_utils.ActorTableData.FromString(
                    actor_table[i])
                results[binary_to_hex(actor_table_data.actor_id)] = \
                    self._gen_actor_info(actor_table_data)
            return results

    def _gen_actor_info(self, actor_table_data):
        """Parse actor table data.

        Returns:
            Information from actor table.
        """
        actor_info = {
            "ActorID": binary_to_hex(actor_table_data.actor_id),
            "ActorClassName": actor_table_data.class_name,
            "IsDetached": actor_table_data.is_detached,
            "Name": actor_table_data.name,
            "JobID": binary_to_hex(actor_table_data.job_id),
            "Address": {
                "IPAddress": actor_table_data.address.ip_address,
                "Port": actor_table_data.address.port,
                "NodeID": binary_to_hex(actor_table_data.address.raylet_id),
            },
            "OwnerAddress": {
                "IPAddress": actor_table_data.owner_address.ip_address,
                "Port": actor_table_data.owner_address.port,
                "NodeID": binary_to_hex(
                    actor_table_data.owner_address.raylet_id),
            },
            # Translate the numeric state enum into its symbolic name.
            "State": gcs_pb2.ActorTableData.ActorState.DESCRIPTOR.
            values_by_number[actor_table_data.state].name,
            "NumRestarts": actor_table_data.num_restarts,
            "Timestamp": actor_table_data.timestamp,
            "StartTime": actor_table_data.start_time,
            "EndTime": actor_table_data.end_time,
            "DeathCause": actor_table_data.death_cause
        }
        return actor_info

    def node_resource_table(self, node_id=None):
        """Fetch and parse the node resource table for a single node.

        Args:
            node_id: A hex string of the node ID to fetch information about.
                NOTE(review): despite the None default, the body requires a
                node_id (hex_to_binary(None) would raise) — all in-file
                callers pass one.

        Returns:
            Information from the node resource table.
        """
        self._check_connected()

        node_id = ray.NodeID(hex_to_binary(node_id))
        node_resource_bytes = \
            self.global_state_accessor.get_node_resource_info(node_id)
        if node_resource_bytes is None:
            return {}
        else:
            node_resource_info = gcs_utils.ResourceMap.FromString(
                node_resource_bytes)
            return {
                key: value.resource_capacity
                for key, value in node_resource_info.items.items()
            }

    def node_table(self):
        """Fetch and parse the Gcs node info table.

        Returns:
            Information about the node in the cluster.
        """
        self._check_connected()

        node_table = self.global_state_accessor.get_node_table()

        results = []
        for node_info_item in node_table:
            item = gcs_utils.GcsNodeInfo.FromString(node_info_item)
            node_info = {
                "NodeID": ray._private.utils.binary_to_hex(item.node_id),
                "Alive": item.state ==
                gcs_utils.GcsNodeInfo.GcsNodeState.Value("ALIVE"),
                "NodeManagerAddress": item.node_manager_address,
                "NodeManagerHostname": item.node_manager_hostname,
                "NodeManagerPort": item.node_manager_port,
                "ObjectManagerPort": item.object_manager_port,
                "ObjectStoreSocketName": item.object_store_socket_name,
                "RayletSocketName": item.raylet_socket_name,
                "MetricsExportPort": item.metrics_export_port,
            }
            # Duplicate key kept for backward compatibility with callers
            # that read the lowercase form.
            node_info["alive"] = node_info["Alive"]
            # Only live nodes have meaningful resource entries.
            node_info["Resources"] = self.node_resource_table(
                node_info["NodeID"]) if node_info["Alive"] else {}
            results.append(node_info)
        return results

    def job_table(self):
        """Fetch and parse the gcs job table.

        Returns:
            Information about the Ray jobs in the cluster,
            namely a list of dicts with keys:
            - "JobID" (identifier for the job),
            - "DriverIPAddress" (IP address of the driver for this job),
            - "DriverPid" (process ID of the driver for this job),
            - "StartTime" (UNIX timestamp of the start time of this job),
            - "StopTime" (UNIX timestamp of the stop time of this job, if any)
        """
        self._check_connected()

        job_table = self.global_state_accessor.get_job_table()

        results = []
        for i in range(len(job_table)):
            entry = gcs_utils.JobTableData.FromString(job_table[i])
            job_info = {}
            job_info["JobID"] = entry.job_id.hex()
            job_info["DriverIPAddress"] = entry.driver_ip_address
            job_info["DriverPid"] = entry.driver_pid
            job_info["Timestamp"] = entry.timestamp
            job_info["StartTime"] = entry.start_time
            job_info["EndTime"] = entry.end_time
            job_info["IsDead"] = entry.is_dead
            results.append(job_info)

        return results

    def next_job_id(self):
        """Get next job id from GCS.

        Returns:
            Next job id in the cluster.
        """
        self._check_connected()

        return ray.JobID.from_int(self.global_state_accessor.get_next_job_id())

    def profile_table(self):
        """Fetch all profiling events, grouped by component ID."""
        self._check_connected()

        result = defaultdict(list)
        profile_table = self.global_state_accessor.get_profile_table()
        for i in range(len(profile_table)):
            profile = gcs_utils.ProfileTableData.FromString(profile_table[i])

            component_type = profile.component_type
            component_id = binary_to_hex(profile.component_id)
            node_ip_address = profile.node_ip_address

            for event in profile.profile_events:
                try:
                    extra_data = json.loads(event.extra_data)
                except ValueError:
                    # Tolerate malformed extra_data rather than dropping the
                    # whole event.
                    extra_data = {}
                profile_event = {
                    "event_type": event.event_type,
                    "component_id": component_id,
                    "node_ip_address": node_ip_address,
                    "component_type": component_type,
                    "start_time": event.start_time,
                    "end_time": event.end_time,
                    "extra_data": extra_data
                }

                result[component_id].append(profile_event)

        return dict(result)

    def get_placement_group_by_name(self, placement_group_name,
                                    ray_namespace):
        """Look up a placement group by name within a namespace."""
        self._check_connected()

        placement_group_info = (
            self.global_state_accessor.get_placement_group_by_name(
                placement_group_name, ray_namespace))
        if placement_group_info is None:
            return None
        else:
            placement_group_table_data = \
                gcs_utils.PlacementGroupTableData.FromString(
                    placement_group_info)
            return self._gen_placement_group_info(placement_group_table_data)

    def placement_group_table(self, placement_group_id=None):
        """Fetch info for one placement group, or all of them."""
        self._check_connected()

        if placement_group_id is not None:
            placement_group_id = ray.PlacementGroupID(
                hex_to_binary(placement_group_id.hex()))
            placement_group_info = (
                self.global_state_accessor.get_placement_group_info(
                    placement_group_id))
            if placement_group_info is None:
                return {}
            else:
                placement_group_info = (gcs_utils.PlacementGroupTableData.
                                        FromString(placement_group_info))
                return self._gen_placement_group_info(placement_group_info)
        else:
            placement_group_table = self.global_state_accessor.\
                get_placement_group_table()
            results = {}
            for placement_group_info in placement_group_table:
                placement_group_table_data = gcs_utils.\
                    PlacementGroupTableData.FromString(placement_group_info)
                placement_group_id = binary_to_hex(
                    placement_group_table_data.placement_group_id)
                results[placement_group_id] = \
                    self._gen_placement_group_info(placement_group_table_data)

            return results

    def _gen_placement_group_info(self, placement_group_info):
        """Convert a PlacementGroupTableData protobuf into a plain dict."""
        # This should be imported here, otherwise, it will error doc build.
        from ray.core.generated.common_pb2 import PlacementStrategy

        def get_state(state):
            if state == gcs_utils.PlacementGroupTableData.PENDING:
                return "PENDING"
            elif state == gcs_utils.PlacementGroupTableData.CREATED:
                return "CREATED"
            else:
                return "REMOVED"

        def get_strategy(strategy):
            if strategy == PlacementStrategy.PACK:
                return "PACK"
            elif strategy == PlacementStrategy.STRICT_PACK:
                return "STRICT_PACK"
            elif strategy == PlacementStrategy.STRICT_SPREAD:
                return "STRICT_SPREAD"
            elif strategy == PlacementStrategy.SPREAD:
                return "SPREAD"
            else:
                # BUGFIX: interpolate the offending value, not the enum
                # class itself.
                raise ValueError(f"Invalid strategy returned: {strategy}")

        # Validate before dereferencing (the assert previously came after
        # the first attribute access, making it useless).
        assert placement_group_info is not None
        stats = placement_group_info.stats
        return {
            "placement_group_id": binary_to_hex(
                placement_group_info.placement_group_id),
            "name": placement_group_info.name,
            "bundles": {
                # The value here is needs to be dictionarified
                # otherwise, the payload becomes unserializable.
                bundle.bundle_id.bundle_index:
                MessageToDict(bundle)["unitResources"]
                for bundle in placement_group_info.bundles
            },
            "strategy": get_strategy(placement_group_info.strategy),
            "state": get_state(placement_group_info.state),
            "stats": {
                "end_to_end_creation_latency_ms": (
                    stats.end_to_end_creation_latency_us / 1000.0),
                "scheduling_latency_ms": (
                    stats.scheduling_latency_us / 1000.0),
                "scheduling_attempt": stats.scheduling_attempt,
                "highest_retry_delay_ms": stats.highest_retry_delay_ms,
                "scheduling_state": gcs_pb2.PlacementGroupStats.
                SchedulingState.DESCRIPTOR.
                values_by_number[stats.scheduling_state].name
            }
        }

    def _seconds_to_microseconds(self, time_in_seconds):
        """A helper function for converting seconds to microseconds."""
        time_in_microseconds = 10**6 * time_in_seconds
        return time_in_microseconds

    # Colors are specified at
    # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html.  # noqa: E501
    _default_color_mapping = defaultdict(
        lambda: "generic_work", {
            "worker_idle": "cq_build_abandoned",
            "task": "rail_response",
            "task:deserialize_arguments": "rail_load",
            "task:execute": "rail_animation",
            "task:store_outputs": "rail_idle",
            "wait_for_function": "detailed_memory_dump",
            "ray.get": "good",
            "ray.put": "terrible",
            "ray.wait": "vsync_highlight_color",
            "submit_task": "background_memory_dump",
            "fetch_and_run_function": "detailed_memory_dump",
            "register_remote_function": "detailed_memory_dump",
        })

    # These colors are for use in Chrome tracing.
    _chrome_tracing_colors = [
        "thread_state_uninterruptible",
        "thread_state_iowait",
        "thread_state_running",
        "thread_state_runnable",
        "thread_state_sleeping",
        "thread_state_unknown",
        "background_memory_dump",
        "light_memory_dump",
        "detailed_memory_dump",
        "vsync_highlight_color",
        "generic_work",
        "good",
        "bad",
        "terrible",
        # "black",
        # "grey",
        # "white",
        "yellow",
        "olive",
        "rail_response",
        "rail_animation",
        "rail_idle",
        "rail_load",
        "startup",
        "heap_dump_stack_frame",
        "heap_dump_object_type",
        "heap_dump_child_node_arrow",
        "cq_build_running",
        "cq_build_passed",
        "cq_build_failed",
        "cq_build_abandoned",
        "cq_build_attempt_runnig",
        "cq_build_attempt_passed",
        "cq_build_attempt_failed",
    ]

    def chrome_tracing_dump(self, filename=None):
        """Return a list of profiling events that can viewed as a timeline.

        To view this information as a timeline, simply dump it as a json file
        by passing in "filename" or using json.dump, and then load go to
        chrome://tracing in the Chrome web browser and load the dumped file.
        Make sure to enable "Flow events" in the "View Options" menu.

        Args:
            filename: If a filename is provided, the timeline is dumped to
                that file.

        Returns:
            If filename is not provided, this returns a list of profiling
                events. Each profile event is a dictionary.
        """
        # TODO(rkn): Support including the task specification data in the
        # timeline.
        # TODO(rkn): This should support viewing just a window of time or a
        # limited number of events.
        self._check_connected()

        profile_table = self.profile_table()
        all_events = []

        for component_id_hex, component_events in profile_table.items():
            # Only consider workers and drivers.
            component_type = component_events[0]["component_type"]
            if component_type not in ["worker", "driver"]:
                continue

            for event in component_events:
                new_event = {
                    # The category of the event.
                    "cat": event["event_type"],
                    # The string displayed on the event.
                    "name": event["event_type"],
                    # The identifier for the group of rows that the event
                    # appears in.
                    "pid": event["node_ip_address"],
                    # The identifier for the row that the event appears in.
                    "tid": event["component_type"] + ":" +
                    event["component_id"],
                    # The start time in microseconds.
                    "ts": self._seconds_to_microseconds(event["start_time"]),
                    # The duration in microseconds.
                    "dur": self._seconds_to_microseconds(event["end_time"] -
                                                         event["start_time"]),
                    # "X" is a Chrome tracing "complete" event (has both a
                    # start and a duration).
                    "ph": "X",
                    # This is the name of the color to display the box in.
                    "cname": self._default_color_mapping[event["event_type"]],
                    # The extra user-defined data.
                    "args": event["extra_data"],
                }

                # Modify the json with the additional user-defined extra data.
                # This can be used to add fields or override existing fields.
                if "cname" in event["extra_data"]:
                    new_event["cname"] = event["extra_data"]["cname"]
                if "name" in event["extra_data"]:
                    new_event["name"] = event["extra_data"]["name"]

                all_events.append(new_event)

        if not all_events:
            logger.warning(
                "No profiling events found. Ray profiling must be enabled "
                "by setting RAY_PROFILING=1.")

        if filename is not None:
            with open(filename, "w") as outfile:
                json.dump(all_events, outfile)
        else:
            return all_events

    def chrome_tracing_object_transfer_dump(self, filename=None):
        """Return a list of transfer events that can viewed as a timeline.

        To view this information as a timeline, simply dump it as a json file
        by passing in "filename" or using json.dump, and then load go to
        chrome://tracing in the Chrome web browser and load the dumped file.
        Make sure to enable "Flow events" in the "View Options" menu.

        Args:
            filename: If a filename is provided, the timeline is dumped to
                that file.

        Returns:
            If filename is not provided, this returns a list of profiling
                events. Each profile event is a dictionary.
        """
        self._check_connected()

        node_id_to_address = {}
        for node_info in self.node_table():
            node_id_to_address[node_info["NodeID"]] = "{}:{}".format(
                node_info["NodeManagerAddress"],
                node_info["ObjectManagerPort"])

        all_events = []

        for key, items in self.profile_table().items():
            # Only consider object manager events.
            if items[0]["component_type"] != "object_manager":
                continue

            for event in items:
                if event["event_type"] == "transfer_send":
                    object_ref, remote_node_id, _, _ = event["extra_data"]
                elif event["event_type"] == "transfer_receive":
                    object_ref, remote_node_id, _ = event["extra_data"]
                elif event["event_type"] == "receive_pull_request":
                    object_ref, remote_node_id = event["extra_data"]
                else:
                    assert False, "This should be unreachable."
                # Choose a color by reading the first couple of hex digits of
                # the object ref as an integer and turning that into a color.
                object_ref_int = int(object_ref[:2], 16)
                color = self._chrome_tracing_colors[object_ref_int % len(
                    self._chrome_tracing_colors)]

                new_event = {
                    # The category of the event.
                    "cat": event["event_type"],
                    # The string displayed on the event.
                    "name": event["event_type"],
                    # The identifier for the group of rows that the event
                    # appears in.
                    "pid": node_id_to_address[key],
                    # The identifier for the row that the event appears in.
                    "tid": node_id_to_address[remote_node_id],
                    # The start time in microseconds.
                    "ts": self._seconds_to_microseconds(event["start_time"]),
                    # The duration in microseconds.
                    "dur": self._seconds_to_microseconds(event["end_time"] -
                                                         event["start_time"]),
                    # "X" is a Chrome tracing "complete" event.
                    "ph": "X",
                    # This is the name of the color to display the box in.
                    "cname": color,
                    # The extra user-defined data.
                    "args": event["extra_data"],
                }
                all_events.append(new_event)

                # Add another box with a color indicating whether it was a
                # send or a receive event.
                if event["event_type"] == "transfer_send":
                    additional_event = new_event.copy()
                    additional_event["cname"] = "black"
                    all_events.append(additional_event)
                elif event["event_type"] == "transfer_receive":
                    additional_event = new_event.copy()
                    additional_event["cname"] = "grey"
                    all_events.append(additional_event)
                else:
                    pass

        if filename is not None:
            with open(filename, "w") as outfile:
                json.dump(all_events, outfile)
        else:
            return all_events

    def workers(self):
        """Get a dictionary mapping worker ID to worker information."""
        self._check_connected()

        # Get all data in worker table
        worker_table = self.global_state_accessor.get_worker_table()
        workers_data = {}
        for i in range(len(worker_table)):
            worker_table_data = gcs_utils.WorkerTableData.FromString(
                worker_table[i])
            # Drivers and dead workers are excluded from the result.
            if worker_table_data.is_alive and \
                    worker_table_data.worker_type == gcs_utils.WORKER:
                worker_id = binary_to_hex(
                    worker_table_data.worker_address.worker_id)
                worker_info = worker_table_data.worker_info

                workers_data[worker_id] = {
                    "node_ip_address": decode(
                        worker_info[b"node_ip_address"]),
                    "plasma_store_socket": decode(
                        worker_info[b"plasma_store_socket"])
                }
                if b"stderr_file" in worker_info:
                    workers_data[worker_id]["stderr_file"] = decode(
                        worker_info[b"stderr_file"])
                if b"stdout_file" in worker_info:
                    workers_data[worker_id]["stdout_file"] = decode(
                        worker_info[b"stdout_file"])
        return workers_data

    def add_worker(self, worker_id, worker_type, worker_info):
        """Add a worker to the cluster.

        Args:
            worker_id: ID of this worker. Type is bytes.
            worker_type: Type of this worker. Value is gcs_utils.DRIVER or
                gcs_utils.WORKER.
            worker_info: Info of this worker. Type is dict{str: str}.

        Returns:
            Whether the operation succeeded.
        """
        worker_data = gcs_utils.WorkerTableData()
        worker_data.is_alive = True
        worker_data.worker_address.worker_id = worker_id
        worker_data.worker_type = worker_type
        for k, v in worker_info.items():
            worker_data.worker_info[k] = bytes(v, encoding="utf-8")
        return self.global_state_accessor.add_worker_info(
            worker_data.SerializeToString())

    def cluster_resources(self):
        """Get the current total cluster resources.

        Note that this information can grow stale as nodes are added to or
        removed from the cluster.

        Returns:
            A dictionary mapping resource name to the total quantity of that
                resource in the cluster.
        """
        self._check_connected()

        resources = defaultdict(int)
        nodes = self.node_table()
        for node in nodes:
            # Only count resources from latest entries of live nodes.
            if node["Alive"]:
                for key, value in node["Resources"].items():
                    resources[key] += value
        return dict(resources)

    def _live_node_ids(self):
        """Returns a set of node IDs corresponding to nodes still alive."""
        return {
            node["NodeID"]
            for node in self.node_table() if (node["Alive"])
        }

    def _available_resources_per_node(self):
        """Returns a dictionary mapping node id to available resources."""
        self._check_connected()

        available_resources_by_id = {}

        all_available_resources = \
            self.global_state_accessor.get_all_available_resources()
        for available_resource in all_available_resources:
            message = gcs_utils.AvailableResources.FromString(
                available_resource)
            # Calculate available resources for this node.
            dynamic_resources = {}
            for resource_id, capacity in \
                    message.resources_available.items():
                dynamic_resources[resource_id] = capacity
            # Update available resources for this node.
            node_id = ray._private.utils.binary_to_hex(message.node_id)
            available_resources_by_id[node_id] = dynamic_resources

        # Update nodes in cluster.
        node_ids = self._live_node_ids()
        # Remove disconnected nodes.
        for node_id in list(available_resources_by_id.keys()):
            if node_id not in node_ids:
                del available_resources_by_id[node_id]

        return available_resources_by_id

    def available_resources(self):
        """Get the current available cluster resources.

        This is different from `cluster_resources` in that this will return
        idle (available) resources rather than total resources.

        Note that this information can grow stale as tasks start and finish.

        Returns:
            A dictionary mapping resource name to the total quantity of that
                resource in the cluster.
        """
        self._check_connected()

        available_resources_by_id = self._available_resources_per_node()

        # Calculate total available resources.
        total_available_resources = defaultdict(int)
        for available_resources in available_resources_by_id.values():
            for resource_id, num_available in available_resources.items():
                total_available_resources[resource_id] += num_available

        return dict(total_available_resources)

    def get_system_config(self):
        """Get the system config of the cluster."""
        self._check_connected()
        return json.loads(self.global_state_accessor.get_system_config())

    def get_node_to_connect_for_driver(self, node_ip_address):
        """Get the node to connect for a Ray driver."""
        self._check_connected()
        node_info_str = (self.global_state_accessor.
                         get_node_to_connect_for_driver(node_ip_address))
        return gcs_utils.GcsNodeInfo.FromString(node_info_str)
def make_global_state_accessor(ray_context):
    """Create and connect a ``GlobalStateAccessor`` for a Ray context.

    The accessor is connected before being returned; the caller is
    responsible for calling ``disconnect()`` on it when finished.
    """
    options = GcsClientOptions.from_gcs_address(
        ray_context.address_info["gcs_address"])
    accessor = GlobalStateAccessor(options)
    accessor.connect()
    return accessor
def initialize(num_cpus: int,
               num_gpus: int,
               log_root_path: str,
               log_name: Optional[str] = None,
               logger_cls: type = TBXLogger,
               launch_tensorboard: bool = True,
               debug: bool = False,
               verbose: bool = True) -> Callable[[Dict[str, Any]], Logger]:
    """Initialize Ray and Tensorboard daemons.

    It will be used later for almost everything from dashboard,
    remote/client management, to multithreaded environment.

    .. note:
        The default Tensorboard port will be used, namely 6006 if
        available, using 0.0.0.0 (binding to all IPv4 addresses on local
        machine). Similarly, Ray dashboard port is 8265 if available. In
        both cases, the port will be increased iteratively until one is
        found available.

    :param num_cpus: Maximum number of CPU threads that can be executed in
                     parallel. Note that it does not actually reserve part
                     of the CPU, so that several processes can reserve the
                     number of threads available on the system at the same
                     time.
    :param num_gpus: Maximum number of GPU unit that can be used, which can
                     be fractional to only allocate part of the resource.
                     Note that contrary to CPU resource, the memory is
                     likely to actually be reserve and allocated by the
                     process, in particular using Tensorflow backend.
    :param log_root_path: Fullpath of root log directory.
    :param log_name: Name of the subdirectory where to save data. `None` to
                     use default name, empty string '' to set it
                     interactively in command prompt. It must be a valid
                     Python identifier.
                     Optional: full date _ hostname by default.
    :param logger_cls: Custom logger class type deriving from `TBXLogger`.
                       Optional: `TBXLogger` by default.
    :param launch_tensorboard: Whether or not to launch tensorboard
                               automatically.
                               Optional: Enabled by default.
    :param debug: Whether or not to display debugging trace.
                  Optional: Disabled by default.
    :param verbose: Whether or not to print information about what is going
                    on. Optional: True by default.

    :returns: lambda function to pass a `ray.Trainer` to monitor learning
              progress in Tensorboard.
    """
    # Make sure provided logger class derives from ray.tune.logger.Logger
    assert issubclass(logger_cls, Logger), (
        "Logger class must derive from `ray.tune.logger.Logger`")

    # Check if cluster servers are already running, and if requested
    # resources are available.
    is_cluster_running = False
    redis_addresses = services.find_redis_address()
    if redis_addresses:
        for redis_address in redis_addresses:
            # Connect to redis global state accessor
            global_state_accessor = GlobalStateAccessor(
                redis_address, ray_constants.REDIS_DEFAULT_PASSWORD)
            global_state_accessor.connect()

            # Sum the available capacity of every resource over all nodes
            # of this candidate cluster.
            resources: Dict[str, int] = defaultdict(int)
            for info in global_state_accessor.get_all_available_resources():
                # pylint: disable=no-member
                message = ray.gcs_utils.AvailableResources.FromString(info)
                for field, capacity in message.resources_available.items():
                    resources[field] += capacity

            # Disconnect global state accessor
            time.sleep(0.1)
            global_state_accessor.disconnect()

            # Check if enough computation resources are available
            is_cluster_running = (resources["CPU"] >= num_cpus
                                  and resources["GPU"] >= num_gpus)

            # Stop looking as soon as a cluster with enough resources is
            # found
            if is_cluster_running:
                break

    # Connect to Ray server if necessary, starting one if not already
    # running
    if not ray.is_initialized():
        if not is_cluster_running:
            # Start new Ray server, if not already running
            ray.init(
                # Address of Ray cluster to connect to, if any
                address=None,
                # Number of CPUs assigned to each raylet
                num_cpus=num_cpus,
                # Number of GPUs assigned to each raylet
                num_gpus=num_gpus,
                # Enable object eviction in LRU order under memory pressure
                _lru_evict=False,
                # Whether or not to execute the code serially (debugging)
                local_mode=debug,
                # Logging level
                logging_level=logging.DEBUG if debug else logging.ERROR,
                # Whether to redirect outputs from every worker to driver
                log_to_driver=debug,
                # Whether to start Ray dashboard to monitor cluster status
                include_dashboard=True,
                # The host to bind the dashboard server to
                dashboard_host="0.0.0.0")
        else:
            # Connect to existing Ray cluster
            ray.init(
                address="auto",
                _lru_evict=False,
                local_mode=debug,
                logging_level=logging.DEBUG if debug else logging.ERROR,
                log_to_driver=debug,
                include_dashboard=False)

    # Configure Tensorboard
    if launch_tensorboard:
        tb = TensorBoard()
        tb.configure(host="0.0.0.0", logdir=os.path.abspath(log_root_path))
        url = tb.launch()
        if verbose:
            print(f"Started Tensorboard {url}.",
                  f"Root directory: {log_root_path}")

    # Define log filename interactively if requested
    if log_name == "":
        while True:
            log_name = input(
                "Enter desired log subdirectory name (empty for default)...")
            if not log_name or re.match(r'^[A-Za-z0-9_]+$', log_name):
                break
            # Fixed user-facing typo: "Unvalid" -> "Invalid".
            print("Invalid name. Only Python identifiers are supported.")

    # Handling of default log name and sanity checks
    if not log_name:
        log_name = "_".join((
            datetime.now().strftime("%Y_%m_%d_%H_%M_%S"),
            re.sub(r'[^A-Za-z0-9_]', "_", socket.gethostname())))
    else:
        assert re.match(r'^[A-Za-z0-9_]+$', log_name), (
            "Log name must be a valid Python identifier.")

    # Create log directory
    log_path = os.path.join(log_root_path, log_name)
    pathlib.Path(log_path).mkdir(parents=True, exist_ok=True)
    if verbose:
        print(f"Tensorboard logfiles directory: {log_path}")

    # Define Ray logger
    def logger_creator(config: Dict[str, Any]) -> Logger:
        return logger_cls(config, log_path)

    return logger_creator
def test_placement_group_load_report(ray_start_cluster):
    """Check that pending placement groups show up in the GCS resource load.

    Creates placement groups that cannot be scheduled and verifies that the
    ``placement_group_load`` field of the resource usage report tracks them
    as nodes are added/feasibility changes.
    """
    cluster = ray_start_cluster
    # Add a head node that doesn't have gpu resource.
    cluster.add_node(num_cpus=4)

    ray.init(address=cluster.address)
    global_state_accessor = GlobalStateAccessor(
        cluster.address, ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    global_state_accessor.connect()

    class PgLoadChecker:
        # Helper predicates polled via wait_for_condition below. Each reads
        # the latest resource usage batch and inspects the placement group
        # load it carries.
        def nothing_is_ready(self):
            # Both placement groups are still pending, so both must appear
            # in the reported load.
            resource_usage = self._read_resource_usage()
            if not resource_usage:
                return False
            if resource_usage.HasField("placement_group_load"):
                pg_load = resource_usage.placement_group_load
                return len(pg_load.placement_group_data) == 2
            return False

        def only_first_one_ready(self):
            # The feasible group was scheduled; only the infeasible one
            # should remain in the load.
            resource_usage = self._read_resource_usage()
            if not resource_usage:
                return False
            if resource_usage.HasField("placement_group_load"):
                pg_load = resource_usage.placement_group_load
                return len(pg_load.placement_group_data) == 1
            return False

        def two_infeasible_pg(self):
            # Two infeasible groups outstanding -> two load entries.
            resource_usage = self._read_resource_usage()
            if not resource_usage:
                return False
            if resource_usage.HasField("placement_group_load"):
                pg_load = resource_usage.placement_group_load
                return len(pg_load.placement_group_data) == 2
            return False

        def _read_resource_usage(self):
            # Returns the parsed ResourceUsageBatchData, or False when the
            # GCS has not published one yet.
            message = global_state_accessor.get_all_resource_usage()
            if message is None:
                return False

            resource_usage = ray.gcs_utils.ResourceUsageBatchData.FromString(
                message)
            return resource_usage

    checker = PgLoadChecker()

    # Create 2 placement groups that are infeasible.
    pg_feasible = ray.util.placement_group([{"A": 1}])
    pg_infeasible = ray.util.placement_group([{"B": 1}])
    _, unready = ray.wait(
        [pg_feasible.ready(), pg_infeasible.ready()], timeout=0)
    assert len(unready) == 2
    ray.test_utils.wait_for_condition(checker.nothing_is_ready)

    # Add a node that makes pg feasible. Make sure load include this change.
    cluster.add_node(resources={"A": 1})
    ray.get(pg_feasible.ready())
    ray.test_utils.wait_for_condition(checker.only_first_one_ready)
    # Create one more infeasible pg and make sure load is properly updated.
    pg_infeasible_second = ray.util.placement_group([{"C": 1}])
    _, unready = ray.wait([pg_infeasible_second.ready()], timeout=0)
    assert len(unready) == 1
    ray.test_utils.wait_for_condition(checker.two_infeasible_pg)
    global_state_accessor.disconnect()
class GlobalState:
    """A class used to interface with the Ray control state.

    # TODO(zongheng): In the future move this to use Ray's redis module in the
    # backend to cut down on # of request RPCs.

    Attributes:
        redis_client: The Redis client used to query the primary redis server.
        redis_clients: Redis clients for each of the Redis shards.
        global_state_accessor: The client used to query gcs table from gcs
            server.
    """

    def __init__(self):
        """Create a GlobalState object."""
        # The redis server storing metadata, such as function table, client
        # table, log files, event logs, workers/actions info.
        self.redis_client = None
        # Clients for the redis shards, storing the object table & task table.
        self.redis_clients = None
        # GCS accessor; populated by _initialize_global_state().
        self.global_state_accessor = None

    def _check_connected(self):
        """Check that the object has been initialized before it is used.

        Raises:
            ray.exceptions.RaySystemError: if ray.init() has not been called
                yet.
        """
        if (self.redis_client is None or self.redis_clients is None
                or self.global_state_accessor is None):
            raise ray.exceptions.RaySystemError(
                "Ray has not been started yet. You can start Ray with "
                "'ray.init()'.")

    def disconnect(self):
        """Disconnect global state from GCS."""
        self.redis_client = None
        self.redis_clients = None
        if self.global_state_accessor is not None:
            self.global_state_accessor.disconnect()
            self.global_state_accessor = None

    def _initialize_global_state(self,
                                 redis_address,
                                 redis_password=None,
                                 timeout=20):
        """Initialize the GlobalState object by connecting to Redis.

        It's possible that certain keys in Redis may not have been fully
        populated yet. In this case, we will retry this method until they
        have been populated or we exceed a timeout.

        Args:
            redis_address: The Redis address to connect.
            redis_password: The password of the redis server.
            timeout: Seconds to keep retrying before raising TimeoutError.
        """
        self.redis_client = services.create_redis_client(
            redis_address, redis_password)
        self.global_state_accessor = GlobalStateAccessor(
            redis_address, redis_password, False)
        self.global_state_accessor.connect()
        start_time = time.time()

        num_redis_shards = None
        redis_shard_addresses = []

        # Poll until both the shard count and the shard address list are
        # fully populated (they are written by the head node asynchronously).
        while time.time() - start_time < timeout:
            # Attempt to get the number of Redis shards.
            num_redis_shards = self.redis_client.get("NumRedisShards")
            if num_redis_shards is None:
                print("Waiting longer for NumRedisShards to be populated.")
                time.sleep(1)
                continue
            num_redis_shards = int(num_redis_shards)
            assert num_redis_shards >= 1, (
                f"Expected at least one Redis shard, found {num_redis_shards}."
            )

            # Attempt to get all of the Redis shards.
            redis_shard_addresses = self.redis_client.lrange(
                "RedisShards", start=0, end=-1)
            if len(redis_shard_addresses) != num_redis_shards:
                print("Waiting longer for RedisShards to be populated.")
                time.sleep(1)
                continue

            # If we got here then we successfully got all of the information.
            break

        # Check to see if we timed out.
        if time.time() - start_time >= timeout:
            raise TimeoutError("Timed out while attempting to initialize the "
                               "global state. "
                               f"num_redis_shards = {num_redis_shards}, "
                               "redis_shard_addresses = "
                               f"{redis_shard_addresses}")

        # Get the rest of the information.
        self.redis_clients = []
        for shard_address in redis_shard_addresses:
            self.redis_clients.append(
                services.create_redis_client(shard_address.decode(),
                                             redis_password))

    def _execute_command(self, key, *args):
        """Execute a Redis command on the appropriate Redis shard based on
        key.

        Args:
            key: The object ref or the task ID that the query is about.
            args: The command to run.

        Returns:
            The value returned by the Redis command.
        """
        # Shard selection mirrors the writer side: hash of the key modulo
        # the number of shards.
        client = self.redis_clients[key.redis_shard_hash() % len(
            self.redis_clients)]
        return client.execute_command(*args)
    def object_table(self, object_ref=None):
        """Fetch and parse the object table info for one or more object refs.

        Args:
            object_ref: An object ref to fetch information about. If this is
                None, then the entire object table is fetched.

        Returns:
            Information from the object table.
        """
        self._check_connected()

        if object_ref is not None:
            # Single-object lookup: the argument is a hex string that is
            # converted back to a binary ObjectRef for the GCS query.
            object_ref = ray.ObjectRef(hex_to_binary(object_ref))
            object_info = self.global_state_accessor.get_object_info(
                object_ref)
            if object_info is None:
                return {}
            else:
                object_location_info = gcs_utils.ObjectLocationInfo.FromString(
                    object_info)
                return self._gen_object_info(object_location_info)
        else:
            # Full-table scan: build a dict keyed by hex object id.
            object_table = self.global_state_accessor.get_object_table()
            results = {}
            for i in range(len(object_table)):
                object_location_info = gcs_utils.ObjectLocationInfo.FromString(
                    object_table[i])
                results[binary_to_hex(object_location_info.object_id)] = \
                    self._gen_object_info(object_location_info)
            return results

    def _gen_object_info(self, object_location_info):
        """Parse object location info.

        Returns:
            Information from object.
        """
        # One entry per node (manager) currently holding the object.
        locations = []
        for location in object_location_info.locations:
            locations.append(ray.utils.binary_to_hex(location.manager))

        object_info = {
            "ObjectRef": ray.utils.binary_to_hex(
                object_location_info.object_id),
            "Locations": locations,
        }
        return object_info
    def _gen_actor_info(self, actor_table_data):
        """Parse actor table data.

        Returns:
            Information from actor table.
        """
        actor_info = {
            "ActorID": binary_to_hex(actor_table_data.actor_id),
            "JobID": binary_to_hex(actor_table_data.job_id),
            # Address of the worker currently hosting the actor.
            "Address": {
                "IPAddress": actor_table_data.address.ip_address,
                "Port": actor_table_data.address.port,
                "NodeID": binary_to_hex(actor_table_data.address.raylet_id),
            },
            # Address of the worker that owns (created) the actor.
            "OwnerAddress": {
                "IPAddress": actor_table_data.owner_address.ip_address,
                "Port": actor_table_data.owner_address.port,
                "NodeID": binary_to_hex(
                    actor_table_data.owner_address.raylet_id),
            },
            "State": actor_table_data.state,
            "NumRestarts": actor_table_data.num_restarts,
            "Timestamp": actor_table_data.timestamp,
        }
        return actor_info

    def node_resource_table(self, node_id=None):
        """Fetch and parse the node resource table info for one.

        Args:
            node_id: An node ID to fetch information about.

        Returns:
            Information from the node resource table.
        """
        # NOTE(review): despite the `None` default, `node_id` is converted
        # unconditionally below, so passing None would fail in
        # hex_to_binary — callers appear to always pass a hex ID; confirm.
        self._check_connected()

        node_id = ray.NodeID(hex_to_binary(node_id))
        node_resource_bytes = \
            self.global_state_accessor.get_node_resource_info(node_id)
        if node_resource_bytes is None:
            return {}
        else:
            node_resource_info = gcs_utils.ResourceMap.FromString(
                node_resource_bytes)
            return {
                key: value.resource_capacity
                for key, value in node_resource_info.items.items()
            }
Returns: Information about the node in the cluster. """ self._check_connected() node_table = self.global_state_accessor.get_node_table() results = [] for node_info_item in node_table: item = gcs_utils.GcsNodeInfo.FromString(node_info_item) node_info = { "NodeID": ray.utils.binary_to_hex(item.node_id), "Alive": item.state == gcs_utils.GcsNodeInfo.GcsNodeState.Value( "ALIVE"), "NodeManagerAddress": item.node_manager_address, "NodeManagerHostname": item.node_manager_hostname, "NodeManagerPort": item.node_manager_port, "ObjectManagerPort": item.object_manager_port, "ObjectStoreSocketName": item.object_store_socket_name, "RayletSocketName": item.raylet_socket_name, "MetricsExportPort": item.metrics_export_port, } node_info["alive"] = node_info["Alive"] node_info["Resources"] = self.node_resource_table( node_info["NodeID"]) if node_info["Alive"] else {} results.append(node_info) return results def job_table(self): """Fetch and parse the Redis job table. Returns: Information about the Ray jobs in the cluster, namely a list of dicts with keys: - "JobID" (identifier for the job), - "DriverIPAddress" (IP address of the driver for this job), - "DriverPid" (process ID of the driver for this job), - "StartTime" (UNIX timestamp of the start time of this job), - "StopTime" (UNIX timestamp of the stop time of this job, if any) """ self._check_connected() job_table = self.global_state_accessor.get_job_table() results = [] for i in range(len(job_table)): entry = gcs_utils.JobTableData.FromString(job_table[i]) job_info = {} job_info["JobID"] = entry.job_id.hex() job_info["DriverIPAddress"] = entry.driver_ip_address job_info["DriverPid"] = entry.driver_pid if entry.is_dead: job_info["StopTime"] = entry.timestamp else: job_info["StartTime"] = entry.timestamp results.append(job_info) return results def profile_table(self): self._check_connected() result = defaultdict(list) profile_table = self.global_state_accessor.get_profile_table() for i in range(len(profile_table)): profile = 
gcs_utils.ProfileTableData.FromString(profile_table[i]) component_type = profile.component_type component_id = binary_to_hex(profile.component_id) node_ip_address = profile.node_ip_address for event in profile.profile_events: try: extra_data = json.loads(event.extra_data) except ValueError: extra_data = {} profile_event = { "event_type": event.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, "start_time": event.start_time, "end_time": event.end_time, "extra_data": extra_data } result[component_id].append(profile_event) return dict(result) def placement_group_table(self, placement_group_id=None): self._check_connected() if placement_group_id is not None: placement_group_id = ray.PlacementGroupID( hex_to_binary(placement_group_id.hex())) placement_group_info = ( self.global_state_accessor.get_placement_group_info( placement_group_id)) if placement_group_info is None: return {} else: placement_group_info = (gcs_utils.PlacementGroupTableData. FromString(placement_group_info)) return self._gen_placement_group_info(placement_group_info) else: placement_group_table = self.global_state_accessor.\ get_placement_group_table() results = {} for placement_group_info in placement_group_table: placement_group_table_data = gcs_utils.\ PlacementGroupTableData.FromString(placement_group_info) placement_group_id = binary_to_hex( placement_group_table_data.placement_group_id) results[placement_group_id] = \ self._gen_placement_group_info(placement_group_table_data) return results def _gen_placement_group_info(self, placement_group_info): # This should be imported here, otherwise, it will error doc build. 
from ray.core.generated.common_pb2 import PlacementStrategy def get_state(state): if state == ray.gcs_utils.PlacementGroupTableData.PENDING: return "PENDING" elif state == ray.gcs_utils.PlacementGroupTableData.CREATED: return "CREATED" else: return "REMOVED" def get_strategy(strategy): if strategy == PlacementStrategy.PACK: return "PACK" elif strategy == PlacementStrategy.STRICT_PACK: return "STRICT_PACK" elif strategy == PlacementStrategy.STRICT_SPREAD: return "STRICT_SPREAD" elif strategy == PlacementStrategy.SPREAD: return "SPREAD" else: raise ValueError( f"Invalid strategy returned: {PlacementStrategy}") assert placement_group_info is not None return { "placement_group_id": binary_to_hex(placement_group_info.placement_group_id), "name": placement_group_info.name, "bundles": { # The value here is needs to be dictionarified # otherwise, the payload becomes unserializable. bundle.bundle_id.bundle_index: MessageToDict(bundle)["unitResources"] for bundle in placement_group_info.bundles }, "strategy": get_strategy(placement_group_info.strategy), "state": get_state(placement_group_info.state), } def _seconds_to_microseconds(self, time_in_seconds): """A helper function for converting seconds to microseconds.""" time_in_microseconds = 10**6 * time_in_seconds return time_in_microseconds # Colors are specified at # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 _default_color_mapping = defaultdict( lambda: "generic_work", { "worker_idle": "cq_build_abandoned", "task": "rail_response", "task:deserialize_arguments": "rail_load", "task:execute": "rail_animation", "task:store_outputs": "rail_idle", "wait_for_function": "detailed_memory_dump", "ray.get": "good", "ray.put": "terrible", "ray.wait": "vsync_highlight_color", "submit_task": "background_memory_dump", "fetch_and_run_function": "detailed_memory_dump", "register_remote_function": "detailed_memory_dump", }) # These colors are for use in Chrome tracing. 
    # Reserved color names accepted by Chrome tracing; used to pick a
    # deterministic box color per object ref. NOTE(review): the
    # "cq_build_attempt_runnig" spelling matches the (misspelled) name in
    # catapult's color_scheme.html — do not "fix" it here.
    _chrome_tracing_colors = [
        "thread_state_uninterruptible",
        "thread_state_iowait",
        "thread_state_running",
        "thread_state_runnable",
        "thread_state_sleeping",
        "thread_state_unknown",
        "background_memory_dump",
        "light_memory_dump",
        "detailed_memory_dump",
        "vsync_highlight_color",
        "generic_work",
        "good",
        "bad",
        "terrible",
        # "black",
        # "grey",
        # "white",
        "yellow",
        "olive",
        "rail_response",
        "rail_animation",
        "rail_idle",
        "rail_load",
        "startup",
        "heap_dump_stack_frame",
        "heap_dump_object_type",
        "heap_dump_child_node_arrow",
        "cq_build_running",
        "cq_build_passed",
        "cq_build_failed",
        "cq_build_abandoned",
        "cq_build_attempt_runnig",
        "cq_build_attempt_passed",
        "cq_build_attempt_failed",
    ]

    def chrome_tracing_dump(self, filename=None):
        """Return a list of profiling events that can be viewed as a timeline.

        To view this information as a timeline, simply dump it as a json file
        by passing in "filename" or using json.dump, then go to
        chrome://tracing in the Chrome web browser and load the dumped file.
        Make sure to enable "Flow events" in the "View Options" menu.

        Args:
            filename: If a filename is provided, the timeline is dumped to
                that file.

        Returns:
            If filename is not provided, this returns a list of profiling
                events. Each profile event is a dictionary.
        """
        # TODO(rkn): Support including the task specification data in the
        # timeline.
        # TODO(rkn): This should support viewing just a window of time or a
        # limited number of events.
        self._check_connected()

        profile_table = self.profile_table()
        all_events = []

        for component_id_hex, component_events in profile_table.items():
            # Only consider workers and drivers.
            component_type = component_events[0]["component_type"]
            if component_type not in ["worker", "driver"]:
                continue

            for event in component_events:
                new_event = {
                    # The category of the event.
                    "cat": event["event_type"],
                    # The string displayed on the event.
                    "name": event["event_type"],
                    # The identifier for the group of rows that the event
                    # appears in.
                    "pid": event["node_ip_address"],
                    # The identifier for the row that the event appears in.
                    "tid": event["component_type"] + ":" +
                           event["component_id"],
                    # The start time in microseconds.
                    "ts": self._seconds_to_microseconds(event["start_time"]),
                    # The duration in microseconds.
                    "dur": self._seconds_to_microseconds(event["end_time"] -
                                                         event["start_time"]),
                    # The event phase: "X" is a "complete" (duration) event
                    # in the Chrome trace event format.
                    "ph": "X",
                    # This is the name of the color to display the box in.
                    "cname": self._default_color_mapping[event["event_type"]],
                    # The extra user-defined data.
                    "args": event["extra_data"],
                }

                # Modify the json with the additional user-defined extra data.
                # This can be used to add fields or override existing fields.
                if "cname" in event["extra_data"]:
                    new_event["cname"] = event["extra_data"]["cname"]
                if "name" in event["extra_data"]:
                    new_event["name"] = event["extra_data"]["name"]

                all_events.append(new_event)

        if filename is not None:
            with open(filename, "w") as outfile:
                json.dump(all_events, outfile)
        else:
            return all_events
    def workers(self):
        """Get a dictionary mapping worker ID to worker information."""
        self._check_connected()

        # Get all data in worker table
        worker_table = self.global_state_accessor.get_worker_table()
        workers_data = {}
        for i in range(len(worker_table)):
            worker_table_data = gcs_utils.WorkerTableData.FromString(
                worker_table[i])
            # Only include live processes registered as workers; drivers and
            # dead workers are filtered out.
            if worker_table_data.is_alive and \
                    worker_table_data.worker_type == gcs_utils.WORKER:
                worker_id = binary_to_hex(
                    worker_table_data.worker_address.worker_id)
                worker_info = worker_table_data.worker_info

                workers_data[worker_id] = {
                    "node_ip_address": decode(
                        worker_info[b"node_ip_address"]),
                    "plasma_store_socket": decode(
                        worker_info[b"plasma_store_socket"])
                }
                # stderr/stdout redirection files are optional fields.
                if b"stderr_file" in worker_info:
                    workers_data[worker_id]["stderr_file"] = decode(
                        worker_info[b"stderr_file"])
                if b"stdout_file" in worker_info:
                    workers_data[worker_id]["stdout_file"] = decode(
                        worker_info[b"stdout_file"])
        return workers_data
    def _job_length(self):
        """Estimate the time span and task count of the job from event logs.

        Scans every Redis key matching "event_log*"; each is a sorted set of
        events scored by timestamp.

        Returns:
            A tuple (earliest_event_time, latest_event_time, num_tasks),
            or (0, 0, 0) when no tasks were counted.
        """
        event_log_sets = self.redis_client.keys("event_log*")

        overall_smallest = sys.maxsize
        overall_largest = 0
        num_tasks = 0
        for event_log_set in event_log_sets:
            # First entry in ascending score order = earliest event.
            # NOTE(review): assumes every event_log set is non-empty;
            # fwd_range[0] would raise IndexError otherwise — confirm.
            fwd_range = self.redis_client.zrange(
                event_log_set, start=0, end=0, withscores=True)
            overall_smallest = min(overall_smallest, fwd_range[0][1])

            # First entry in descending score order = latest event.
            rev_range = self.redis_client.zrevrange(
                event_log_set, start=0, end=0, withscores=True)
            overall_largest = max(overall_largest, rev_range[0][1])

            num_tasks += self.redis_client.zcount(
                event_log_set, min=0, max=time.time())
        if num_tasks == 0:
            return 0, 0, 0
        return overall_smallest, overall_largest, num_tasks
if client["Alive"]: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) def _live_client_ids(self): """Returns a set of client IDs corresponding to clients still alive.""" return { client["NodeID"] for client in self.node_table() if (client["Alive"]) } def _available_resources_per_node(self): """Returns a dictionary mapping node id to avaiable resources.""" available_resources_by_id = {} all_available_resources = \ self.global_state_accessor.get_all_available_resources() for available_resource in all_available_resources: message = ray.gcs_utils.AvailableResources.FromString( available_resource) # Calculate available resources for this node. dynamic_resources = {} for resource_id, capacity in \ message.resources_available.items(): dynamic_resources[resource_id] = capacity # Update available resources for this node. node_id = ray.utils.binary_to_hex(message.node_id) available_resources_by_id[node_id] = dynamic_resources # Update nodes in cluster. node_ids = self._live_client_ids() # Remove disconnected nodes. for node_id in available_resources_by_id.keys(): if node_id not in node_ids: del available_resources_by_id[node_id] return available_resources_by_id def available_resources(self): """Get the current available cluster resources. This is different from `cluster_resources` in that this will return idle (available) resources rather than total resources. Note that this information can grow stale as tasks start and finish. Returns: A dictionary mapping resource name to the total quantity of that resource in the cluster. """ self._check_connected() available_resources_by_id = self._available_resources_per_node() # Calculate total available resources. 
    def actor_checkpoint_info(self, actor_id):
        """Get checkpoint info for the given actor id.

        Args:
            actor_id: Actor's ID.

        Returns:
            A dictionary with information about the actor's checkpoint IDs
            and their timestamps, or None when no entry exists.
        """
        self._check_connected()
        # Direct table lookup on the Redis shard owning this actor's entry.
        message = self._execute_command(
            actor_id,
            "RAY.TABLE_LOOKUP",
            gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"),
            "",
            actor_id.binary(),
        )
        if message is None:
            return None
        gcs_entry = gcs_utils.GcsEntry.FromString(message)
        entry = gcs_utils.ActorCheckpointIdData.FromString(
            gcs_entry.entries[0])
        checkpoint_ids = [
            ray.ActorCheckpointID(checkpoint_id)
            for checkpoint_id in entry.checkpoint_ids
        ]
        return {
            "ActorID": ray.utils.binary_to_hex(entry.actor_id),
            "CheckpointIds": checkpoint_ids,
            "Timestamps": list(entry.timestamps),
        }
    def _really_init_global_state(self, timeout=20):
        # Construct and connect the GCS accessor used for global-state
        # queries, using the Redis address/password stored on this object.
        # NOTE(review): `timeout` is accepted but unused here — presumably
        # kept for interface compatibility; confirm against callers.
        self.global_state_accessor = GlobalStateAccessor(
            self.redis_address, self.redis_password)
        self.global_state_accessor.connect()