def _handle_failure(self, error):
    """Handle a fatal error raised from the monitor loop.

    Logs the active exception. If an autoscaler exists and
    ``RAY_AUTOSCALER_FATESHARE_WORKERS`` is ``"1"``, kills and then destroys
    the autoscaler workers. Stores the error text under
    ``DEBUG_AUTOSCALING_ERROR`` in internal KV when internal KV is
    initialized. Finally publishes a ``MONITOR_DIED_ERROR`` to drivers via a
    Redis client and, when one can be built, a GCS publisher.

    Args:
        error: The failure being reported; it is formatted into the message.
    """
    logger.exception("Error in monitor loop")

    should_kill_workers = (
        self.autoscaler is not None
        and os.environ.get("RAY_AUTOSCALER_FATESHARE_WORKERS", "") == "1")
    if should_kill_workers:
        self.autoscaler.kill_workers()
        # Take down autoscaler workers if necessary.
        self.destroy_autoscaler_workers()

    # Something went wrong, so push an error to all current and future
    # drivers.
    message = f"The autoscaler failed with the following error:\n{error}"
    if _internal_kv_initialized():
        _internal_kv_put(DEBUG_AUTOSCALING_ERROR, message, overwrite=True)

    redis_client = ray._private.services.create_redis_client(
        self.redis_address, password=self.redis_password)

    # NOTE(review): `args` is neither a parameter nor an attribute here; it
    # appears to be a module-level global (parsed CLI arguments?) — confirm
    # it is always defined when this method runs.
    if args.gcs_address:
        gcs_publisher = GcsPublisher(address=args.gcs_address)
    elif gcs_pubsub_enabled():
        gcs_publisher = GcsPublisher(
            address=get_gcs_address_from_redis(redis_client))
    else:
        gcs_publisher = None

    from ray._private.utils import publish_error_to_driver
    publish_error_to_driver(
        ray_constants.MONITOR_DIED_ERROR,
        message,
        redis_client=redis_client,
        gcs_publisher=gcs_publisher)
def use_gcs_for_bootstrap():
    """Return whether GCS bootstrapping is enabled.

    The value comes from ``ray._raylet.Config.bootstrap_with_gcs()`` and is
    returned unchanged. Whenever it is truthy, GCS pubsub is required to be
    enabled as well; that invariant is checked with ``assert``.
    """
    from ray._private.gcs_pubsub import gcs_pubsub_enabled
    from ray._raylet import Config

    bootstrap_with_gcs = Config.bootstrap_with_gcs()
    if not bootstrap_with_gcs:
        return bootstrap_with_gcs
    # GCS bootstrapping is only allowed together with GCS pubsub.
    assert gcs_pubsub_enabled()
    return bootstrap_with_gcs
async def run(self, server):
    """Register the reporter gRPC servicer and run the reporting loop.

    Reports are published through GCS pubsub (``GcsAioPublisher``) when it is
    enabled, otherwise through Redis pubsub. ``self._perform_iteration`` is
    awaited with the chosen ``publish(key, data)`` coroutine.

    Args:
        server: gRPC server the ReporterService servicer is added to.
    """
    reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
    agent = self._dashboard_agent

    if gcs_pubsub_enabled():
        gcs_addr = agent.gcs_address
        if gcs_addr is None:
            # Fall back to reading the GCS address from Redis. This pool is
            # only needed for this single read, so close it afterwards
            # instead of leaking its connections for the agent's lifetime.
            aioredis_client = await aioredis.create_redis_pool(
                address=agent.redis_address,
                password=agent.redis_password,
            )
            try:
                gcs_addr = await aioredis_client.get("GcsServerAddress")
            finally:
                aioredis_client.close()
                await aioredis_client.wait_closed()
            gcs_addr = gcs_addr.decode()
        publisher = GcsAioPublisher(address=gcs_addr)

        async def publish(key: str, data: str):
            await publisher.publish_resource_usage(key, data)

    else:
        # This pool stays open: every report is published through it.
        aioredis_client = await aioredis.create_redis_pool(
            address=agent.redis_address,
            password=agent.redis_password,
        )

        async def publish(key: str, data: str):
            await aioredis_client.publish(key, data)

    await self._perform_iteration(publish)
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
    """Killing the GCS server must make the dashboard process exit."""
    webui_url = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(webui_url) is True

    processes = ray.worker._global_node.all_processes

    def _first_process(process_type):
        # Wrap the first recorded process of this type in a psutil handle.
        return psutil.Process(processes[process_type][0].process.pid)

    dashboard_proc = _first_process(ray_constants.PROCESS_TYPE_DASHBOARD)
    gcs_server_proc = _first_process(ray_constants.PROCESS_TYPE_GCS_SERVER)

    current_status = dashboard_proc.status()
    assert current_status in [
        psutil.STATUS_RUNNING,
        psutil.STATUS_SLEEPING,
        psutil.STATUS_DISK_SLEEP,
    ]

    gcs_server_proc.kill()
    gcs_server_proc.wait()

    pubsub_enabled = gcs_pubsub_enabled()
    exit_code = dashboard_proc.wait(10)
    if pubsub_enabled:
        # With pubsub enabled, the exit is triggered by a pubsub error, so
        # only a non-zero exit code is checked.
        # TODO: Fix this exits logic for pubsub
        assert exit_code != 0
    else:
        # The dashboard exits by os._exit(-1)
        assert exit_code == 255
def use_gcs_for_bootstrap():
    """Return whether GCS bootstrapping is enabled via the environment.

    Reads ``RAY_bootstrap_with_gcs``. It is enabled unless the variable is
    unset or exactly ``"0"`` or ``"false"``. When it is enabled, GCS pubsub
    must be enabled too (checked with ``assert``).
    """
    import os
    from ray._private.gcs_pubsub import gcs_pubsub_enabled

    flag = os.environ.get("RAY_bootstrap_with_gcs")
    enabled = flag is not None and flag not in ("0", "false")
    if enabled:
        assert gcs_pubsub_enabled()
    return enabled
def init_log_pubsub():
    """Create and return a subscriber for log messages.

    When GCS pubsub is enabled, this is a ``GcsLogSubscriber`` connected to
    the global worker's GCS client address. Otherwise it is a Redis pubsub
    object (subscribe confirmations ignored) pattern-subscribed to
    ``gcs_utils.LOG_FILE_CHANNEL``.

    Returns:
        The already-subscribed subscriber. Its type depends on which pubsub
        backend is enabled.
    """
    if gcs_pubsub_enabled():
        s = GcsLogSubscriber(
            address=ray.worker.global_worker.gcs_client.address)
        s.subscribe()
    else:
        s = ray.worker.global_worker.redis_client.pubsub(
            ignore_subscribe_messages=True)
        s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
    return s
def init_error_pubsub():
    """Create and return a subscriber for error messages.

    When GCS pubsub is enabled, this is a ``GcsErrorSubscriber`` connected to
    the global worker's GCS client address. Otherwise it is a Redis pubsub
    object (subscribe confirmations ignored) pattern-subscribed to
    ``gcs_utils.RAY_ERROR_PUBSUB_PATTERN``.
    """
    worker = ray.worker.global_worker
    if not gcs_pubsub_enabled():
        subscriber = worker.redis_client.pubsub(
            ignore_subscribe_messages=True)
        subscriber.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
        return subscriber
    subscriber = GcsErrorSubscriber(address=worker.gcs_client.address)
    subscriber.subscribe()
    return subscriber
def __init__(self, logs_dir, redis_address, redis_password=None):
    """Initialize the log monitor object.

    Args:
        logs_dir: Stored as ``self.logs_dir``.
        redis_address: Address used to create ``self.redis_client``.
        redis_password: Optional password for that Redis connection.
    """
    self.ip = services.get_node_ip_address()
    self.logs_dir = logs_dir
    self.redis_client = ray._private.services.create_redis_client(
        redis_address, password=redis_password)

    # A GCS publisher exists only when GCS pubsub is enabled; its address
    # is looked up through the Redis client created above.
    if gcs_pubsub.gcs_pubsub_enabled():
        self.publisher = gcs_pubsub.GcsPublisher(
            address=gcs_utils.get_gcs_address_from_redis(self.redis_client))
    else:
        self.publisher = None

    # File-tracking state, initially empty.
    self.log_filenames = set()
    self.open_file_infos = []
    self.closed_file_infos = []
    self.can_open_more_files = True
def handle_pub_messages(msgs, timeout, expect_num):
    """Append actor pubsub messages from the module-level ``sub`` to ``msgs``.

    Stops once ``msgs`` holds at least ``expect_num`` entries or ``timeout``
    seconds have elapsed, whichever comes first. With GCS pubsub enabled,
    each ``sub.poll`` is bounded by the remaining time budget, so the call as
    a whole does not overrun ``timeout``. Otherwise Redis messages are
    fetched without blocking and decoded into ``ActorTableData``.

    Args:
        msgs: List that received actor data is appended to (mutated).
        timeout: Overall time budget in seconds.
        expect_num: Target number of entries in ``msgs``.
    """
    # monotonic() is immune to wall-clock adjustments while measuring time.
    deadline = time.monotonic() + timeout
    while len(msgs) < expect_num:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if gcs_pubsub.gcs_pubsub_enabled():
            _, actor_data = sub.poll(timeout=remaining)
        else:
            msg = sub.get_message()
            if msg is None:
                time.sleep(0.01)
                continue
            pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
            actor_data = gcs_utils.ActorTableData.FromString(
                pubsub_msg.data)
        if actor_data is None:
            continue
        msgs.append(actor_data)
def test_publish_error_to_driver(ray_start_regular, error_pubsub):
    """An error sent with publish_error_to_driver reaches ``error_pubsub``."""
    redis_client = ray._private.services.create_redis_client(
        ray_start_regular["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    if gcs_pubsub_enabled():
        gcs_addr = gcs_utils.get_gcs_address_from_redis(redis_client)
        gcs_publisher = GcsPublisher(address=gcs_addr)
    else:
        gcs_publisher = None

    error_type = ray_constants.DASHBOARD_AGENT_DIED_ERROR
    error_message = "Test error message"
    ray._private.utils.publish_error_to_driver(
        error_type,
        error_message,
        redis_client=redis_client,
        gcs_publisher=gcs_publisher)

    first_error = get_error_message(error_pubsub, 1, error_type)[0]
    assert first_error.type == error_type
    assert first_error.error_message == error_message
async def run(self, server):
    """Receive reporter-agent stats forever and store them per node.

    Marks ``self.service_discovery`` as a daemon and starts it. Then it
    consumes reporter messages: from a ``GcsAioResourceUsageSubscriber``
    when GCS pubsub is enabled, otherwise from Redis channels matching
    ``REPORTER_PREFIX*``. Each payload is JSON-decoded and stored in
    ``DataSource.node_physical_stats`` under the last ``:``-separated part of
    its key. A failure on one message is logged and the loop continues.

    Args:
        server: Not used by this method.
    """
    # Need daemon True to avoid dashboard hangs at exit.
    self.service_discovery.daemon = True
    self.service_discovery.start()

    def _record_stats(key, stats):
        # Key form: RAY_REPORTER:{node id hex}; keep only the node id.
        node_id = key.split(":")[-1]
        DataSource.node_physical_stats[node_id] = stats

    if gcs_pubsub_enabled():
        subscriber = GcsAioResourceUsageSubscriber(
            self._dashboard_head.gcs_address)
        await subscriber.subscribe()
        while True:
            try:
                # NOTE(review): `key.split(":")` requires a str key; an
                # older comment described it as bytes — confirm poll()'s
                # return type.
                key, data = await subscriber.poll()
                if key is not None:
                    _record_stats(key, json.loads(data))
            except Exception:
                logger.exception("Error receiving node physical stats "
                                 "from reporter agent.")
    else:
        from aioredis.pubsub import Receiver
        receiver = Receiver()
        reporter_key = f"{reporter_consts.REPORTER_PREFIX}*"
        await self._dashboard_head.aioredis_client.psubscribe(
            receiver.pattern(reporter_key))
        logger.info(f"Subscribed to {reporter_key}")
        async for _, msg in receiver.iter():
            try:
                raw_key, raw_data = msg
                stats = json.loads(ray._private.utils.decode(raw_data))
                _record_stats(raw_key.decode("utf-8"), stats)
            except Exception:
                logger.exception("Error receiving node physical stats "
                                 "from reporter agent.")
def test_function_unique_export(ray_start_regular):
    """Re-running a remote function must not export its function again."""
    @ray.remote
    def f():
        pass

    @ray.remote
    def g():
        ray.get(f.remote())

    if not gcs_pubsub_enabled():
        # Redis backend: the "Exports" list must not grow on later runs.
        ray.get(g.remote())
        num_exports = ray.worker.global_worker.redis_client.llen("Exports")
        ray.get([g.remote() for _ in range(5)])
        assert ray.worker.global_worker.redis_client.llen("Exports") == \
            num_exports
        return

    subscriber = GcsFunctionKeySubscriber(
        channel=ray.worker.global_worker.gcs_channel.channel())
    subscriber.subscribe()

    ray.get(g.remote())
    # Drain the exports generated by the first run of g().
    num_exports = 0
    while subscriber.poll(timeout=1) is not None:
        num_exports += 1
    print(f"num_exports after running g(): {num_exports}")

    # Later runs must not publish any further function keys.
    ray.get([g.remote() for _ in range(5)])
    key = subscriber.poll(timeout=1)
    assert key is None, f"Unexpected function key export: {key}"
loop = asyncio.get_event_loop() loop.run_until_complete(agent.run()) except Exception as e: # All these env vars should be available because # they are provided by the parent raylet. restart_count = os.environ["RESTART_COUNT"] max_restart_count = os.environ["MAX_RESTART_COUNT"] raylet_pid = os.environ["RAY_RAYLET_PID"] node_ip = args.node_ip_address if restart_count >= max_restart_count: # Agent is failed to be started many times. # Push an error to all drivers, so that users can know the # impact of the issue. redis_client = None gcs_publisher = None if gcs_pubsub_enabled(): if use_gcs_for_bootstrap(): gcs_publisher = GcsPublisher(args.gcs_address) else: redis_client = ray._private.services.create_redis_client( args.redis_address, password=args.redis_password) gcs_publisher = GcsPublisher( address=get_gcs_address_from_redis(redis_client)) else: redis_client = ray._private.services.create_redis_client( args.redis_address, password=args.redis_password) traceback_str = ray._private.utils.format_error_message( traceback.format_exc()) message = ( f"(ip={node_ip}) "
async def run(self):
    """Start the dashboard head and run it until the gRPC server stops.

    In order, this method:
    1. Connects to the GCS: client, internal KV, async gRPC channel,
       optional pubsub subscribers and the health-check thread.
    2. Starts the gRPC server and loads modules.
    3. Configures the HTTP server unless ``self.minimal``.
    4. Writes the HTTP and gRPC addresses to internal KV.
    5. Runs the background tasks and every module's ``run`` concurrently.

    After the gRPC server terminates, the HTTP server (if any) is cleaned up.
    """
    gcs_address = await self.get_gcs_address()

    # Dashboard will handle connection failure automatically
    self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
    internal_kv._initialize_internal_kv(self.gcs_client)
    self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
        gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)

    if gcs_pubsub_enabled():
        self.gcs_error_subscriber = GcsAioErrorSubscriber(
            address=gcs_address)
        self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
        for subscriber in (self.gcs_error_subscriber,
                           self.gcs_log_subscriber):
            await subscriber.subscribe()

    self.health_check_thread = GCSHealthCheckThread(gcs_address)
    self.health_check_thread.start()

    # Start a grpc asyncio server.
    await self.server.start()

    async def _drain_notify_queue():
        """Await every coroutine pushed onto the notify queue, forever."""
        while True:
            pending = await dashboard_utils.NotifyQueue.get()
            try:
                await pending
            except Exception:
                logger.exception(f"Error notifying coroutine {pending}")

    modules = self._load_modules()

    if self.minimal:
        http_host, http_port = self.http_host, self.http_port
    else:
        self.http_server = await self._configure_http_server(modules)
        http_host, http_port = self.http_server.get_address()

    kv_namespace = ray_constants.KV_NAMESPACE_DASHBOARD
    internal_kv._internal_kv_put(
        ray_constants.DASHBOARD_ADDRESS,
        f"{http_host}:{http_port}",
        namespace=kv_namespace,
    )
    # TODO: Use async version if performance is an issue
    # Write the dashboard head port to gcs kv.
    internal_kv._internal_kv_put(
        dashboard_consts.DASHBOARD_RPC_ADDRESS,
        f"{self.ip}:{self.grpc_port}",
        namespace=kv_namespace,
    )

    # Freeze signal after all modules loaded.
    dashboard_utils.SignalManager.freeze()

    background = [
        self._gcs_check_alive(),
        _drain_notify_queue(),
        DataOrganizer.purge(),
        DataOrganizer.organize(),
    ]
    module_runs = [module.run(self.server) for module in modules]
    await asyncio.gather(*background, *module_runs)
    await self.server.wait_for_termination()

    if self.http_server:
        await self.http_server.cleanup()
assert result == 2 # Check whether actor1 is alive or not. # NOTE: We can't execute it immediately after gcs restarts # because it takes time for the worker to exit. result = ray.get(actor1.method.remote(7)) assert result == 9 @pytest.mark.parametrize("ray_start_regular_with_external_redis", [ generate_system_config_map(num_heartbeats_timeout=20, gcs_rpc_server_reconnect_timeout_s=60) ], indirect=True) @pytest.mark.skipif( gcs_pubsub.gcs_pubsub_enabled(), reason="GCS pubsub may lose messages after GCS restarts. Need to " "implement re-fetching state in GCS client.") def test_gcs_server_restart_during_actor_creation( ray_start_regular_with_external_redis): ids = [] # We reduce the number of actors because there are too many actors created # and `Too many open files` error will be thrown. for i in range(0, 20): actor = Increase.remote() ids.append(actor.method.remote(1)) ray.worker._global_node.kill_gcs_server() ray.worker._global_node.start_gcs_server() # The timeout seems too long.
def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
    """Actor lifecycle changes are published with the expected fields.

    Subscribes to actor updates (GCS pubsub or Redis). It creates an actor
    and expects 3 messages, then kills it and expects a 4th. Finally it
    checks which fields each state's message carries.
    """
    timeout = 5
    assert wait_until_server_available(
        ray_start_with_dashboard["webui_url"]) is True
    address_info = ray_start_with_dashboard

    if gcs_pubsub.gcs_pubsub_enabled():
        sub = gcs_pubsub.GcsActorSubscriber(
            address=address_info["gcs_address"])
        sub.subscribe()
    else:
        address = address_info["redis_address"]
        address = address.split(":")
        assert len(address) == 2

        client = redis.StrictRedis(
            host=address[0],
            port=int(address[1]),
            password=ray_constants.REDIS_DEFAULT_PASSWORD,
        )

        sub = client.pubsub(ignore_subscribe_messages=True)
        sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)

    @ray.remote
    class DummyActor:
        def __init__(self):
            pass

    # Create a dummy actor.
    a = DummyActor.remote()

    def handle_pub_messages(msgs, timeout, expect_num):
        # Collect messages until `expect_num` are present or `timeout`
        # seconds pass. Each blocking GCS poll is bounded by the remaining
        # budget, so the call as a whole cannot overrun `timeout`.
        deadline = time.monotonic() + timeout
        while len(msgs) < expect_num:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            if gcs_pubsub.gcs_pubsub_enabled():
                _, actor_data = sub.poll(timeout=remaining)
            else:
                msg = sub.get_message()
                if msg is None:
                    time.sleep(0.01)
                    continue
                pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
                actor_data = gcs_utils.ActorTableData.FromString(
                    pubsub_msg.data)
            if actor_data is None:
                continue
            msgs.append(actor_data)

    msgs = []
    handle_pub_messages(msgs, timeout, 3)
    # Assert we received published actor messages with state
    # DEPENDENCIES_UNREADY, PENDING_CREATION and ALIVE.
    assert len(msgs) == 3, msgs

    # Kill actor.
    ray.kill(a)
    handle_pub_messages(msgs, timeout, 4)

    # Assert we received published actor messages with state DEAD.
    assert len(msgs) == 4

    def actor_table_data_to_dict(message):
        return dashboard_utils.message_to_dict(
            message, {
                "actorId", "parentId", "jobId", "workerId", "rayletId",
                "actorCreationDummyObjectId", "callerId", "taskId",
                "parentTaskId", "sourceActorId", "placementGroupId"
            },
            including_default_value_fields=False)

    non_state_keys = ("actorId", "jobId", "taskSpec")

    for msg in msgs:
        actor_data_dict = actor_table_data_to_dict(msg)
        # DEPENDENCIES_UNREADY is 0, which would not be kept in dict. We
        # need check its original value.
        if msg.state == 0:
            assert len(actor_data_dict) > 5
            for k in non_state_keys:
                assert k in actor_data_dict
        # For status that is not DEPENDENCIES_UNREADY, only states fields will
        # be published.
        elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
            assert actor_data_dict.keys() >= {
                "state", "address", "timestamp", "pid", "rayNamespace"
            }
        elif actor_data_dict["state"] == "PENDING_CREATION":
            assert actor_data_dict.keys() == {
                "state", "address", "actorId", "actorCreationDummyObjectId",
                "jobId", "ownerAddress", "taskSpec", "className",
                "serializedRuntimeEnv", "rayNamespace"
            }
        else:
            raise Exception("Unknown state: {}".format(
                actor_data_dict["state"]))
async def run(self):
    """Start the dashboard head and serve until the gRPC server stops.

    In order, this method:
    1. Creates the shared aiohttp client session.
    2. Connects to the GCS: client, internal KV, async gRPC channel,
       optional pubsub subscribers and the health-check thread.
    3. Starts the gRPC server and loads modules.
    4. Binds the HTTP server, trying up to ``1 + self.http_port_retries``
       consecutive ports starting at ``self.http_port``.
    5. Publishes the HTTP and gRPC addresses to internal KV.
    6. Runs the background tasks and every module's ``run`` concurrently.

    The aiohttp ``AppRunner`` is cleaned up when this coroutine exits:
    normally, by exception, or when no port could be bound.

    Raises:
        Exception: If no HTTP port could be bound within the retry budget.
    """
    # Create a http session for all modules.
    # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
    if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
        self.http_session = aiohttp.ClientSession(
            loop=asyncio.get_event_loop())
    else:
        self.http_session = aiohttp.ClientSession()

    gcs_address = await self.get_gcs_address()

    # Dashboard will handle connection failure automatically
    self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
    internal_kv._initialize_internal_kv(self.gcs_client)
    self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
        gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
    if gcs_pubsub_enabled():
        self.gcs_error_subscriber = GcsAioErrorSubscriber(
            address=gcs_address)
        self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
        await self.gcs_error_subscriber.subscribe()
        await self.gcs_log_subscriber.subscribe()

    self.health_check_thread = GCSHealthCheckThread(gcs_address)
    self.health_check_thread.start()

    # Start a grpc asyncio server.
    await self.server.start()

    async def _async_notify():
        """Notify signals from queue."""
        while True:
            co = await dashboard_utils.NotifyQueue.get()
            try:
                await co
            except Exception:
                logger.exception(f"Error notifying coroutine {co}")

    modules = self._load_modules()

    # Http server should be initialized after all modules loaded.
    # working_dir uploads for job submission can be up to 100MiB.
    app = aiohttp.web.Application(client_max_size=100 * 1024**2)
    app.add_routes(routes=routes.bound_routes())

    runner = aiohttp.web.AppRunner(app)
    await runner.setup()
    try:
        last_ex = None
        for _ in range(1 + self.http_port_retries):
            try:
                site = aiohttp.web.TCPSite(runner, self.http_host,
                                           self.http_port)
                await site.start()
                break
            except OSError as e:
                last_ex = e
                self.http_port += 1
                logger.warning("Try to use port %s: %s", self.http_port, e)
        else:
            raise Exception(f"Failed to find a valid port for dashboard after "
                            f"{self.http_port_retries} retries: {last_ex}")
        # NOTE(review): `site._server` is a private aiohttp attribute.
        http_host, http_port, *_ = site._server.sockets[0].getsockname()
        # If bound to an unspecified (wildcard) address, advertise self.ip.
        http_host = self.ip if ipaddress.ip_address(
            http_host).is_unspecified else http_host
        logger.info("Dashboard head http address: %s:%s", http_host,
                    http_port)

        # TODO: Use async version if performance is an issue
        # Write the dashboard head port to gcs kv.
        internal_kv._internal_kv_put(
            ray_constants.DASHBOARD_ADDRESS,
            f"{http_host}:{http_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
        internal_kv._internal_kv_put(
            dashboard_consts.DASHBOARD_RPC_ADDRESS,
            f"{self.ip}:{self.grpc_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD)

        # Dump registered http routes.
        dump_routes = [
            r for r in app.router.routes() if r.method != hdrs.METH_HEAD
        ]
        for r in dump_routes:
            logger.info(r)
        logger.info("Registered %s routes.", len(dump_routes))

        # Freeze signal after all modules loaded.
        dashboard_utils.SignalManager.freeze()
        concurrent_tasks = [
            self._gcs_check_alive(),
            _async_notify(),
            DataOrganizer.purge(),
            DataOrganizer.organize(),
        ]
        await asyncio.gather(*concurrent_tasks,
                             *(m.run(self.server) for m in modules))
        await self.server.wait_for_termination()
    finally:
        # Release the HTTP listening sockets and app resources on any exit.
        await runner.cleanup()
async def _update_actors(self):
    """Load all actors from the GCS, then keep ``DataSource`` up to date.

    Phase 1 retries ``GetAllActorInfo`` until it succeeds. The reply is used
    to reset ``DataSource.actors``, ``DataSource.job_actors`` and
    ``DataSource.node_actors``.

    Phase 2 subscribes to actor updates and applies each one as it arrives.
    The source is GCS pubsub when it is enabled, otherwise Redis channels
    matching ``ACTOR_CHANNEL:*``. Neither receive loop in this method has an
    exit condition.
    """
    # Get all actor info.
    while True:
        try:
            logger.info("Getting all actor info from GCS.")
            request = gcs_service_pb2.GetAllActorInfoRequest()
            reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                request, timeout=5)
            if reply.status.code == 0:
                # Index the reply by actor id.
                actors = {}
                for message in reply.actor_table_data:
                    actor_table_data = actor_table_data_to_dict(message)
                    actors[actor_table_data["actorId"]] = actor_table_data
                # Update actors.
                DataSource.actors.reset(actors)
                # Update node actors and job actors.
                job_actors = {}
                node_actors = {}
                for actor_id, actor_table_data in actors.items():
                    job_id = actor_table_data["jobId"]
                    node_id = actor_table_data["address"]["rayletId"]
                    job_actors.setdefault(job_id,
                                          {})[actor_id] = actor_table_data
                    # Update only when node_id is not Nil.
                    if node_id != actor_consts.NIL_NODE_ID:
                        node_actors.setdefault(
                            node_id, {})[actor_id] = actor_table_data
                DataSource.job_actors.reset(job_actors)
                DataSource.node_actors.reset(node_actors)
                logger.info("Received %d actor info from GCS.",
                            len(actors))
                break
            else:
                # Non-zero status: treat as a failure and retry below.
                raise Exception(
                    f"Failed to GetAllActorInfo: {reply.status.message}")
        except Exception:
            logger.exception("Error Getting all actor info from GCS.")
            await asyncio.sleep(
                actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

    # Fields copied from an update onto the stored record of an
    # already-known actor.
    state_keys = ("state", "address", "numRestarts", "timestamp", "pid")

    def process_actor_data_from_pubsub(actor_id, actor_table_data):
        """Merge one actor update into the DataSource actor indexes."""
        actor_table_data = actor_table_data_to_dict(actor_table_data)
        # If actor is not new registered but updated, we only update
        # states related fields.
        if actor_table_data["state"] != "DEPENDENCIES_UNREADY":
            # NOTE(review): presumably fails (KeyError?) if actor_id is not
            # yet in DataSource.actors; both call sites below catch and log
            # any exception.
            actor_table_data_copy = dict(DataSource.actors[actor_id])
            for k in state_keys:
                actor_table_data_copy[k] = actor_table_data[k]
            actor_table_data = actor_table_data_copy
        # The actor id used from here on comes from the record itself.
        actor_id = actor_table_data["actorId"]
        job_id = actor_table_data["jobId"]
        node_id = actor_table_data["address"]["rayletId"]
        # Update actors.
        DataSource.actors[actor_id] = actor_table_data
        # Update node actors (only when node_id is not Nil).
        if node_id != actor_consts.NIL_NODE_ID:
            # Copy-then-assign so the per-node dict is replaced, not
            # mutated in place.
            node_actors = dict(DataSource.node_actors.get(node_id, {}))
            node_actors[actor_id] = actor_table_data
            DataSource.node_actors[node_id] = node_actors
        # Update job actors.
        job_actors = dict(DataSource.job_actors.get(job_id, {}))
        job_actors[actor_id] = actor_table_data
        DataSource.job_actors[job_id] = job_actors

    # Receive actors from channel.
    if gcs_pubsub_enabled():
        gcs_addr = await self._dashboard_head.get_gcs_address()
        subscriber = GcsAioActorSubscriber(address=gcs_addr)
        await subscriber.subscribe()

        while True:
            try:
                actor_id, actor_table_data = await subscriber.poll()
                if actor_id is not None:
                    # Convert to lower case hex ID.
                    actor_id = actor_id.hex()
                    process_actor_data_from_pubsub(actor_id,
                                                   actor_table_data)
            except Exception:
                logger.exception("Error processing actor info from GCS.")

    else:
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()
        key = "{}:*".format(actor_consts.ACTOR_CHANNEL)
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        async for sender, msg in receiver.iter():
            try:
                actor_id, actor_table_data = msg
                # Strip the "<TablePrefix_ACTOR_string>:" prefix from the
                # channel name to get the actor id.
                actor_id = actor_id.decode(
                    "UTF-8")[len(gcs_utils.TablePrefix_ACTOR_string + ":"):]
                # Unwrap the PubSubMessage envelope into ActorTableData.
                pubsub_message = gcs_utils.PubSubMessage.FromString(
                    actor_table_data)
                actor_table_data = gcs_utils.ActorTableData.FromString(
                    pubsub_message.data)
                process_actor_data_from_pubsub(actor_id, actor_table_data)
            except Exception:
                logger.exception("Error processing actor info from Redis.")