class Monitor:
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        primary_subscribe_client: A pubsub client for the Redis server.
            This is used to receive notifications about failed components.
    """

    def __init__(self, redis_address, autoscaling_config,
                 redis_password=None):
        # Initialize the Redis clients.
        ray.state.state._initialize_global_state(
            redis_address, redis_password=redis_password)
        self.redis = ray.services.create_redis_client(
            redis_address, password=redis_password)
        # Create the pubsub client used to subscribe on the primary Redis
        # server.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from raylet client ID to IP address to use
        # for updating the load metrics.
        self.raylet_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
            self.autoscaling_config = autoscaling_config
        else:
            self.autoscaler = None
            self.autoscaling_config = None

    def __del__(self):
        """Destruct the monitor object.

        Closes the pubsub client to avoid leaking file descriptors. Tolerates
        a partially constructed object (e.g. ``__init__`` raised before the
        pubsub client was created).
        """
        client = getattr(self, "primary_subscribe_client", None)
        if client is not None:
            client.close()

    def subscribe(self, channel):
        """Subscribe to the given channel on the primary Redis shard.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.subscribe(channel)

    def xray_heartbeat_batch_handler(self, unused_channel, data):
        """Handle an xray heartbeat batch message from Redis.

        Updates the load metrics for every raylet in the batch whose IP is
        known from ``raylet_id_to_ip_map``; unknown raylets are logged.

        Args:
            unused_channel: The message channel (unused).
            data: The serialized ``GcsEntry`` holding the heartbeat batch.
        """
        gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
        heartbeat_data = gcs_entries.entries[0]
        message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
            heartbeat_data)
        for heartbeat_message in message.batch:
            resource_load = dict(
                zip(heartbeat_message.resource_load_label,
                    heartbeat_message.resource_load_capacity))
            total_resources = dict(
                zip(heartbeat_message.resources_total_label,
                    heartbeat_message.resources_total_capacity))
            available_resources = dict(
                zip(heartbeat_message.resources_available_label,
                    heartbeat_message.resources_available_capacity))
            # Resources present in the totals but missing from the available
            # set are treated as having 0 available.
            for resource in total_resources:
                available_resources.setdefault(resource, 0.0)

            # Update the load metrics for this raylet.
            client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
            ip = self.raylet_id_to_ip_map.get(client_id)
            if ip:
                self.load_metrics.update(ip, total_resources,
                                         available_resources, resource_load)
            else:
                logger.warning(
                    "Monitor: "
                    "could not find ip for client {}".format(client_id))

    def _xray_clean_up_entries_for_job(self, job_id):
        """Remove this job's object/task entries from redis.

        Removes control-state entries of all tasks and task return objects
        belonging to the driver. Deletion is best effort: a shortfall between
        the keys requested and the keys deleted is only logged.

        Args:
            job_id: The job id (binary).
        """
        xray_task_table_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        xray_object_table_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        # Collect the ids of every task submitted by this driver.
        job_id_hex = binary_to_hex(job_id)
        job_task_id_bins = set()
        for task_id_hex, task_info in ray.tasks().items():
            if task_info["TaskSpec"]["JobID"] != job_id_hex:
                # Ignore tasks that aren't from this driver.
                continue
            job_task_id_bins.add(hex_to_binary(task_id_hex))

        # Get objects associated with the driver (created by its tasks).
        job_object_id_bins = set()
        for object_id in ray.objects():
            task_id_bin = ray._raylet.compute_task_id(object_id).binary()
            if task_id_bin in job_task_id_bins:
                job_object_id_bins.add(object_id.binary())

        redis_clients = ray.state.state.redis_clients
        num_shards = len(redis_clients)

        def to_shard_index(id_bin):
            # Task ids and object ids hash to shards via their own id types.
            if len(id_bin) == ray.TaskID.size():
                shard_hash = binary_to_task_id(id_bin).redis_shard_hash()
            else:
                shard_hash = binary_to_object_id(id_bin).redis_shard_hash()
            return shard_hash % num_shards

        # Form the redis keys to delete, grouped by shard.
        sharded_keys = [[] for _ in range(num_shards)]
        for task_id_bin in job_task_id_bins:
            sharded_keys[to_shard_index(task_id_bin)].append(
                xray_task_table_prefix + task_id_bin)
        for object_id_bin in job_object_id_bins:
            sharded_keys[to_shard_index(object_id_bin)].append(
                xray_object_table_prefix + object_id_bin)

        # Remove with best effort.
        for shard_index, keys in enumerate(sharded_keys):
            if not keys:
                continue
            redis_client = redis_clients[shard_index]
            num_deleted = redis_client.delete(*keys)
            logger.info("Monitor: "
                        "Removed {} dead redis entries of the "
                        "driver from redis shard {}.".format(
                            num_deleted, shard_index))
            if num_deleted != len(keys):
                logger.warning("Monitor: "
                               "Failed to remove {} relevant redis "
                               "entries from redis shard {}.".format(
                                   len(keys) - num_deleted, shard_index))

    def xray_job_notification_handler(self, unused_channel, data):
        """Handle a notification that a job has been added or removed.

        When the job is marked dead, its task/object entries are removed.

        Args:
            unused_channel: The message channel.
            data: The message data.
        """
        gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
        job_data = gcs_entries.entries[0]
        message = ray.gcs_utils.JobTableData.FromString(job_data)
        job_id = message.job_id
        if message.is_dead:
            logger.info("Monitor: "
                        "XRay Driver {} has been removed.".format(
                            binary_to_hex(job_id)))
            self._xray_clean_up_entries_for_job(job_id)

    def autoscaler_resource_request_handler(self, _, data):
        """Handle a notification of a resource request for the autoscaler.

        Args:
            channel: unused
            data: a resource request as JSON, e.g. {"CPU": 1}
        """
        if not self.autoscaler:
            return
        try:
            self.autoscaler.request_resources(json.loads(data))
        except Exception:
            # We don't want this to kill the monitor; log with traceback.
            logger.exception(
                "Monitor: failed to handle autoscaler resource request.")

    def process_messages(self, max_messages=10000):
        """Process all messages ready in the subscription channels.

        This reads messages from the subscription channels and calls the
        appropriate handlers until there are no messages left.

        Args:
            max_messages: The maximum number of messages to process before
                returning.

        Raises:
            RuntimeError: If a message arrives on a channel with no handler.
        """
        subscribe_clients = [self.primary_subscribe_client]
        for subscribe_client in subscribe_clients:
            for _ in range(max_messages):
                message = subscribe_client.get_message()
                if message is None:
                    # Continue on to the next subscribe client.
                    break

                # Parse the message.
                channel = message["channel"]
                data = message["data"]

                # Determine the appropriate message handler.
                if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
                    # Similar functionality as raylet info channel
                    message_handler = self.xray_heartbeat_batch_handler
                elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL:
                    # Handles driver death.
                    message_handler = self.xray_job_notification_handler
                elif (channel ==
                      ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL):
                    message_handler = self.autoscaler_resource_request_handler
                else:
                    # Raise explicitly instead of `assert False`: asserts are
                    # stripped under `python -O`, which would reuse the
                    # previous iteration's handler or hit UnboundLocalError.
                    raise RuntimeError(
                        "Monitor: received a message on unexpected channel "
                        "{!r}. This code should be unreachable.".format(
                            channel))

                # Call the handler.
                message_handler(channel, data)

    def update_raylet_map(self, _append_port=False):
        """Updates internal raylet map.

        Args:
            _append_port (bool): Defaults to False. Appending the port is
                useful in testing, as mock clusters have many nodes with
                the same IP and cannot be uniquely identified.
        """
        all_raylet_nodes = ray.nodes()
        self.raylet_id_to_ip_map = {}
        for raylet_info in all_raylet_nodes:
            node_id = (raylet_info.get("DBClientID") or raylet_info["NodeID"])
            ip_address = (raylet_info.get("AuxAddress")
                          or raylet_info["NodeManagerAddress"]).split(":")[0]
            if _append_port:
                ip_address += ":" + str(raylet_info["NodeManagerPort"])
            self.raylet_id_to_ip_map[node_id] = ip_address

    def _run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead database
        clients and cleaning up state accordingly.
        """
        # Initialize the subscription channel.
        self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
        self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)

        if self.autoscaler:
            self.subscribe(
                ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL)

        # TODO(rkn): If there were any dead clients at startup, we should clean
        # up the associated state in the state tables.

        # Handle messages from the subscription channels.
        while True:
            # Update the mapping from raylet client ID to IP address.
            # This is only used to update the load metrics for the autoscaler.
            self.update_raylet_map()

            # Process autoscaling actions
            if self.autoscaler:
                self.autoscaler.update()

            # Process a round of messages.
            self.process_messages()

            # Wait for a heartbeat interval before processing the next round of
            # messages.
            time.sleep(
                ray._config.raylet_heartbeat_timeout_milliseconds() * 1e-3)

    def destroy_autoscaler_workers(self):
        """Cleanup the autoscaler, in case of an exception in the run() method.

        We kill the worker nodes, but retain the head node in order to keep
        logs around, keeping costs minimal. This monitor process runs on the
        head node anyway, so this is more reliable."""

        if self.autoscaler is None:
            return  # Nothing to clean up.

        if self.autoscaling_config is None:
            # This is a logic error in the program. Can't do anything.
            logger.error(
                "Monitor: Cleanup failed due to lack of autoscaler config.")
            return

        logger.info("Monitor: Exception caught. Taking down workers...")
        clean = False
        while not clean:
            try:
                teardown_cluster(
                    config_file=self.autoscaling_config,
                    yes=True,  # Non-interactive.
                    workers_only=True,  # Retain head node for logs.
                    override_cluster_name=None,
                    keep_min_workers=True,  # Retain minimal amount of workers.
                )
                clean = True
                logger.info("Monitor: Workers taken down.")
            except Exception:
                # Include the traceback so repeated failures are diagnosable.
                logger.exception("Monitor: Cleanup exception. Trying again...")
                time.sleep(2)

    def run(self):
        try:
            self._run()
        except Exception:
            logger.exception("Error in monitor loop")
            if self.autoscaler:
                self.autoscaler.kill_workers()
            raise
class Monitor(object):
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        primary_subscribe_client: A pubsub client for the Redis server.
            This is used to receive notifications about failed components.
    """

    def __init__(self, redis_address, autoscaling_config,
                 redis_password=None):
        # Initialize the Redis clients. Use the `redis_address` parameter, not
        # a module-level `args` object.
        ray.state.state._initialize_global_state(
            redis_address, redis_password=redis_password)
        self.redis = ray.services.create_redis_client(
            redis_address, password=redis_password)
        # Create the pubsub client used to subscribe on the primary Redis
        # server.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from raylet client ID to IP address to use
        # for updating the load metrics.
        self.raylet_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
        else:
            self.autoscaler = None

        # Experimental feature: GCS flushing.
        self.issue_gcs_flushes = "RAY_USE_NEW_GCS" in os.environ
        self.gcs_flush_policy = None
        if self.issue_gcs_flushes:
            # Data is stored under the first data shard, so we issue flushes to
            # that redis server.
            addr_port = self.redis.lrange("RedisShards", 0, -1)
            if len(addr_port) > 1:
                logger.warning(
                    "Monitor: "
                    "TODO: if launching > 1 redis shard, flushing needs to "
                    "touch shards in parallel.")
                self.issue_gcs_flushes = False
            elif not addr_port:
                # No shard registered: there is nothing to flush against
                # (indexing addr_port[0] would raise IndexError).
                logger.warning("Monitor: no redis shards registered; "
                               "turning off flushing.")
                self.issue_gcs_flushes = False
            else:
                addr_port = addr_port[0].split(b":")
                self.redis_shard = redis.StrictRedis(
                    host=addr_port[0],
                    port=addr_port[1],
                    password=redis_password)
                try:
                    self.redis_shard.execute_command("HEAD.FLUSH 0")
                except redis.exceptions.ResponseError as e:
                    logger.info(
                        "Monitor: "
                        "Turning off flushing due to exception: {}".format(
                            str(e)))
                    self.issue_gcs_flushes = False

    def __del__(self):
        """Destruct the monitor object.

        Closes the pubsub client to avoid leaking file descriptors. Tolerates
        a partially constructed object (e.g. ``__init__`` raised before the
        pubsub client was created).
        """
        client = getattr(self, "primary_subscribe_client", None)
        if client is not None:
            client.close()

    def subscribe(self, channel):
        """Subscribe to the given channel on the primary Redis shard.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.subscribe(channel)

    def xray_heartbeat_batch_handler(self, unused_channel, data):
        """Handle an xray heartbeat batch message from Redis.

        Updates the load metrics for every raylet in the batch whose IP is
        known from ``raylet_id_to_ip_map``; unknown raylets are logged.

        Args:
            unused_channel: The message channel (unused).
            data: The serialized ``GcsEntry`` holding the heartbeat batch.
        """
        gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
        heartbeat_data = gcs_entries.entries[0]
        message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
            heartbeat_data)
        for heartbeat_message in message.batch:
            # Build each mapping from its own label/capacity lists. Nothing
            # guarantees the total and available lists have equal length or
            # order, so they must not be indexed in lockstep.
            static_resources = dict(
                zip(heartbeat_message.resources_total_label,
                    heartbeat_message.resources_total_capacity))
            dynamic_resources = dict(
                zip(heartbeat_message.resources_available_label,
                    heartbeat_message.resources_available_capacity))
            # Resources present in the totals but missing from the available
            # set are treated as having 0 available.
            for resource in static_resources:
                dynamic_resources.setdefault(resource, 0.0)

            # Update the load metrics for this raylet.
            client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
            ip = self.raylet_id_to_ip_map.get(client_id)
            if ip:
                self.load_metrics.update(ip, static_resources,
                                         dynamic_resources)
            else:
                logger.warning(
                    "Monitor: "
                    "could not find ip for client {}".format(client_id))

    def _xray_clean_up_entries_for_job(self, job_id):
        """Remove this job's object/task entries from redis.

        Removes control-state entries of all tasks and task return objects
        belonging to the driver. Deletion is best effort: a shortfall between
        the keys requested and the keys deleted is only logged.

        Args:
            job_id: The job id (binary).
        """
        xray_task_table_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        xray_object_table_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        # Collect the ids of every task submitted by this driver.
        job_id_hex = binary_to_hex(job_id)
        job_task_id_bins = set()
        for task_id_hex, task_info in ray.tasks().items():
            if task_info["TaskSpec"]["JobID"] != job_id_hex:
                # Ignore tasks that aren't from this driver.
                continue
            job_task_id_bins.add(hex_to_binary(task_id_hex))

        # Get objects associated with the driver (created by its tasks).
        job_object_id_bins = set()
        for object_id in ray.objects():
            task_id_bin = ray._raylet.compute_task_id(object_id).binary()
            if task_id_bin in job_task_id_bins:
                job_object_id_bins.add(object_id.binary())

        redis_clients = ray.state.state.redis_clients
        num_shards = len(redis_clients)

        def to_shard_index(id_bin):
            # Task ids and object ids hash to shards via their own id types.
            if len(id_bin) == ray.TaskID.size():
                shard_hash = binary_to_task_id(id_bin).redis_shard_hash()
            else:
                shard_hash = binary_to_object_id(id_bin).redis_shard_hash()
            return shard_hash % num_shards

        # Form the redis keys to delete, grouped by shard.
        sharded_keys = [[] for _ in range(num_shards)]
        for task_id_bin in job_task_id_bins:
            sharded_keys[to_shard_index(task_id_bin)].append(
                xray_task_table_prefix + task_id_bin)
        for object_id_bin in job_object_id_bins:
            sharded_keys[to_shard_index(object_id_bin)].append(
                xray_object_table_prefix + object_id_bin)

        # Remove with best effort. `redis_client` avoids shadowing the
        # `redis` module used elsewhere in this class.
        for shard_index, keys in enumerate(sharded_keys):
            if not keys:
                continue
            redis_client = redis_clients[shard_index]
            num_deleted = redis_client.delete(*keys)
            logger.info("Monitor: "
                        "Removed {} dead redis entries of the "
                        "driver from redis shard {}.".format(
                            num_deleted, shard_index))
            if num_deleted != len(keys):
                logger.warning("Monitor: "
                               "Failed to remove {} relevant redis "
                               "entries from redis shard {}.".format(
                                   len(keys) - num_deleted, shard_index))

    def xray_job_notification_handler(self, unused_channel, data):
        """Handle a notification that a job has been added or removed.

        When the job is marked dead, its task/object entries are removed.

        Args:
            unused_channel: The message channel.
            data: The message data.
        """
        gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
        job_data = gcs_entries.entries[0]
        message = ray.gcs_utils.JobTableData.FromString(job_data)
        job_id = message.job_id
        if message.is_dead:
            logger.info("Monitor: "
                        "XRay Driver {} has been removed.".format(
                            binary_to_hex(job_id)))
            self._xray_clean_up_entries_for_job(job_id)

    def process_messages(self, max_messages=10000):
        """Process all messages ready in the subscription channels.

        This reads messages from the subscription channels and calls the
        appropriate handlers until there are no messages left.

        Args:
            max_messages: The maximum number of messages to process before
                returning.

        Raises:
            Exception: If a message arrives on a channel with no handler.
        """
        subscribe_clients = [self.primary_subscribe_client]
        for subscribe_client in subscribe_clients:
            for _ in range(max_messages):
                message = subscribe_client.get_message()
                if message is None:
                    # Continue on to the next subscribe client.
                    break

                # Parse the message.
                channel = message["channel"]
                data = message["data"]

                # Determine the appropriate message handler.
                if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
                    # Similar functionality as raylet info channel
                    message_handler = self.xray_heartbeat_batch_handler
                elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL:
                    # Handles driver death.
                    message_handler = self.xray_job_notification_handler
                else:
                    raise Exception("This code should be unreachable.")

                # Call the handler.
                message_handler(channel, data)

    def update_raylet_map(self):
        """Rebuild the mapping from raylet client ID to IP address."""
        all_raylet_nodes = ray.nodes()
        self.raylet_id_to_ip_map = {}
        for raylet_info in all_raylet_nodes:
            client_id = (raylet_info.get("DBClientID")
                         or raylet_info["ClientID"])
            ip_address = (raylet_info.get("AuxAddress")
                          or raylet_info["NodeManagerAddress"]).split(":")[0]
            self.raylet_id_to_ip_map[client_id] = ip_address

    def _maybe_flush_gcs(self):
        """Experimental: issue a flush request to the GCS.

        The purpose of this feature is to control GCS memory usage.

        To activate this feature, Ray must be compiled with the flag
        RAY_USE_NEW_GCS set, and Ray must be started at run time with the flag
        as well.
        """
        if not self.issue_gcs_flushes:
            return
        if self.gcs_flush_policy is None:
            serialized = self.redis.get("gcs_flushing_policy")
            if serialized is None:
                # Client has not set any policy; by default flushing is off.
                return
            # SECURITY NOTE(review): this unpickles a value read from redis.
            # Anyone able to write the "gcs_flushing_policy" key can run
            # arbitrary code in the monitor; only safe if redis is trusted.
            self.gcs_flush_policy = pickle.loads(serialized)

        if not self.gcs_flush_policy.should_flush(self.redis_shard):
            return

        max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
        num_flushed = self.redis_shard.execute_command(
            "HEAD.FLUSH {}".format(max_entries_to_flush))
        logger.info("Monitor: num_flushed {}".format(num_flushed))

        # This flushes event log and log files.
        ray.experimental.flush_redis_unsafe(self.redis)

        self.gcs_flush_policy.record_flush()

    def _run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead database
        clients and cleaning up state accordingly.
        """
        # Initialize the subscription channel.
        self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
        self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)

        # TODO(rkn): If there were any dead clients at startup, we should clean
        # up the associated state in the state tables.

        # Handle messages from the subscription channels.
        while True:
            # Update the mapping from raylet client ID to IP address.
            # This is only used to update the load metrics for the autoscaler.
            self.update_raylet_map()

            # Process autoscaling actions
            if self.autoscaler:
                self.autoscaler.update()

            self._maybe_flush_gcs()

            # Process a round of messages.
            self.process_messages()

            # Wait for a heartbeat interval before processing the next round of
            # messages.
            time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)

    def run(self):
        try:
            self._run()
        except Exception:
            logger.exception("Error in monitor loop")
            if self.autoscaler:
                self.autoscaler.kill_workers()
            raise
class Monitor:
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        primary_subscribe_client: A pubsub client for the Redis server.
            This is used to receive notifications about failed components.
    """

    def __init__(self, redis_address, autoscaling_config,
                 redis_password=None):
        # Initialize the Redis clients.
        ray.state.state._initialize_global_state(
            redis_address, redis_password=redis_password)
        self.redis = ray.services.create_redis_client(
            redis_address, password=redis_password)
        # Set the redis client and mode so _internal_kv works for autoscaler.
        worker = ray.worker.global_worker
        worker.redis_client = self.redis
        worker.mode = 0
        # Create the pubsub client used to subscribe on the primary Redis
        # server.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from raylet client ID to IP address to use
        # for updating the load metrics.
        self.raylet_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
            self.autoscaling_config = autoscaling_config
        else:
            self.autoscaler = None
            self.autoscaling_config = None

    def __del__(self):
        """Destruct the monitor object.

        Closes the pubsub client to avoid leaking file descriptors. Tolerates
        a partially constructed object (e.g. ``__init__`` raised before the
        pubsub client was created).
        """
        client = getattr(self, "primary_subscribe_client", None)
        if client is not None:
            client.close()

    def subscribe(self, channel):
        """Subscribe to the given channel on the primary Redis shard.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.subscribe(channel)

    def psubscribe(self, pattern):
        """Subscribe to the given pattern on the primary Redis shard.

        Args:
            pattern (str): The pattern to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.psubscribe(pattern)

    def handle_resource_demands(self, resource_load_by_shape):
        """Handle the message.resource_load_by_shape protobuf for the demand
        based autoscaling. Catch and log all exceptions so this doesn't
        interfere with the utilization based autoscaler until we're confident
        this is stable.

        Each demand's ``shape`` map becomes one bundle dict, and the list of
        bundles is passed to ``autoscaler.request_resources``. Does nothing
        when there is no autoscaler or ``resource_load_by_shape`` is None.

        Args:
            resource_load_by_shape (pb2.gcs.ResourceLoad): The resource demands
                in protobuf form or None.
        """
        try:
            if not self.autoscaler or resource_load_by_shape is None:
                return
            bundles = [
                dict(resource_demand_pb.shape)
                for resource_demand_pb in resource_load_by_shape.resource_demands
            ]
            self.autoscaler.request_resources(bundles)
        except Exception:
            logger.exception("Monitor: failed to handle resource demands.")

    def xray_heartbeat_batch_handler(self, unused_channel, data):
        """Handle an xray heartbeat batch message from Redis.

        Updates the load metrics for every raylet in the batch whose IP is
        known from ``raylet_id_to_ip_map``, then forwards the batch's
        resource demands to ``handle_resource_demands``.

        Args:
            unused_channel: The message channel (unused).
            data: The serialized ``PubSubMessage`` holding the batch.
        """
        pub_message = ray.gcs_utils.PubSubMessage.FromString(data)
        heartbeat_data = pub_message.data
        message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
            heartbeat_data)
        for heartbeat_message in message.batch:
            resource_load = dict(heartbeat_message.resource_load)
            total_resources = dict(heartbeat_message.resources_total)
            available_resources = dict(heartbeat_message.resources_available)
            # Resources present in the totals but missing from the available
            # set are treated as having 0 available.
            for resource in total_resources:
                available_resources.setdefault(resource, 0.0)

            # Update the load metrics for this raylet.
            client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
            ip = self.raylet_id_to_ip_map.get(client_id)
            if ip:
                self.load_metrics.update(ip, total_resources,
                                         available_resources, resource_load)
            else:
                logger.warning(
                    f"Monitor: could not find ip for client {client_id}")
        self.handle_resource_demands(message.resource_load_by_shape)

    def xray_job_notification_handler(self, unused_channel, data):
        """Handle a notification that a job has been added or removed.

        Only logs job removal.

        Args:
            unused_channel: The message channel.
            data: The message data.
        """
        pub_message = ray.gcs_utils.PubSubMessage.FromString(data)
        job_data = pub_message.data
        message = ray.gcs_utils.JobTableData.FromString(job_data)
        job_id = message.job_id
        if message.is_dead:
            logger.info("Monitor: "
                        "XRay Driver {} has been removed.".format(
                            binary_to_hex(job_id)))

    def autoscaler_resource_request_handler(self, _, data):
        """Handle a notification of a resource request for the autoscaler.

        Args:
            channel: unused
            data: a resource request as JSON, e.g. {"CPU": 1}
        """
        if not self.autoscaler:
            return
        try:
            self.autoscaler.request_resources(json.loads(data))
        except Exception:
            # We don't want this to kill the monitor; log with traceback.
            logger.exception(
                "Monitor: failed to handle autoscaler resource request.")

    def process_messages(self, max_messages=10000):
        """Process all messages ready in the subscription channels.

        This reads messages from the subscription channels and calls the
        appropriate handlers until there are no messages left.

        Args:
            max_messages: The maximum number of messages to process before
                returning.

        Raises:
            RuntimeError: If a message arrives with no matching handler.
        """
        subscribe_clients = [self.primary_subscribe_client]
        for subscribe_client in subscribe_clients:
            for _ in range(max_messages):
                message = None
                try:
                    message = subscribe_client.get_message()
                except redis.exceptions.ConnectionError:
                    # Best effort: a dropped connection ends this round;
                    # _run calls this method again on its next iteration.
                    logger.debug("Monitor: connection error while reading "
                                 "from redis; will retry next round.")
                if message is None:
                    # Continue on to the next subscribe client.
                    break

                # Parse the message.
                pattern = message["pattern"]
                channel = message["channel"]
                data = message["data"]

                # Determine the appropriate message handler.
                if pattern == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN:
                    # Similar functionality as raylet info channel
                    message_handler = self.xray_heartbeat_batch_handler
                elif pattern == ray.gcs_utils.XRAY_JOB_PATTERN:
                    # Handles driver death.
                    message_handler = self.xray_job_notification_handler
                elif (channel ==
                      ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL):
                    message_handler = self.autoscaler_resource_request_handler
                else:
                    # Raise explicitly instead of `assert False`: asserts are
                    # stripped under `python -O`, which would reuse the
                    # previous iteration's handler or hit UnboundLocalError.
                    raise RuntimeError(
                        f"Monitor: received a message on unexpected channel "
                        f"{channel!r} (pattern {pattern!r}). This code should "
                        f"be unreachable.")

                # Call the handler.
                message_handler(channel, data)

    def update_raylet_map(self, _append_port=False):
        """Updates internal raylet map.

        Args:
            _append_port (bool): Defaults to False. Appending the port is
                useful in testing, as mock clusters have many nodes with
                the same IP and cannot be uniquely identified.
        """
        all_raylet_nodes = ray.nodes()
        self.raylet_id_to_ip_map = {}
        for raylet_info in all_raylet_nodes:
            node_id = (raylet_info.get("DBClientID") or raylet_info["NodeID"])
            ip_address = (raylet_info.get("AuxAddress")
                          or raylet_info["NodeManagerAddress"]).split(":")[0]
            if _append_port:
                ip_address += ":" + str(raylet_info["NodeManagerPort"])
            self.raylet_id_to_ip_map[node_id] = ip_address

    def _run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead database
        clients and cleaning up state accordingly.
        """
        # Initialize the mapping from raylet client ID to IP address.
        self.update_raylet_map()

        # Initialize the subscription channel.
        self.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN)
        self.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN)

        if self.autoscaler:
            self.subscribe(
                ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL)

        # TODO(rkn): If there were any dead clients at startup, we should clean
        # up the associated state in the state tables.

        # Handle messages from the subscription channels.
        while True:
            # Process autoscaling actions
            if self.autoscaler:
                # Only used to update the load metrics for the autoscaler.
                self.update_raylet_map()
                self.autoscaler.update()

            # Process a round of messages.
            self.process_messages()

            # Wait for a heartbeat interval before processing the next round of
            # messages.
            time.sleep(
                ray._config.raylet_heartbeat_timeout_milliseconds() * 1e-3)

    def destroy_autoscaler_workers(self):
        """Cleanup the autoscaler, in case of an exception in the run() method.

        We kill the worker nodes, but retain the head node in order to keep
        logs around, keeping costs minimal. This monitor process runs on the
        head node anyway, so this is more reliable."""

        if self.autoscaler is None:
            return  # Nothing to clean up.

        if self.autoscaling_config is None:
            # This is a logic error in the program. Can't do anything.
            logger.error(
                "Monitor: Cleanup failed due to lack of autoscaler config.")
            return

        logger.info("Monitor: Exception caught. Taking down workers...")
        clean = False
        while not clean:
            try:
                teardown_cluster(
                    config_file=self.autoscaling_config,
                    yes=True,  # Non-interactive.
                    workers_only=True,  # Retain head node for logs.
                    override_cluster_name=None,
                    keep_min_workers=True,  # Retain minimal amount of workers.
                )
                clean = True
                logger.info("Monitor: Workers taken down.")
            except Exception:
                # Include the traceback so repeated failures are diagnosable.
                logger.exception("Monitor: Cleanup exception. Trying again...")
                time.sleep(2)

    def run(self):
        try:
            self._run()
        except Exception:
            logger.exception("Error in monitor loop")
            if self.autoscaler:
                self.autoscaler.kill_workers()
            raise