def send_get_response(result: Any) -> None:
    """Push the GetResponses for ``result`` to the main DataPath loop.

    This is called when the object is ready on the server side. The
    serialized object is cut into pieces of at most
    ``OBJECT_TRANSFER_CHUNK_SIZE`` bytes and one ``DataResponse`` per piece
    is put on ``result_queue``, beginning with ``request.start_chunk_id``.
    If anything fails, a single response with ``valid=False`` and the
    pickled exception is queued instead.

    Args:
        result: The object to serialize and send back.
    """
    try:
        serialized = dumps_from_server(result, client_id, self)
        total_size = len(serialized)
        if total_size <= 0:
            # Raised explicitly instead of using `assert` so the check still
            # runs under `python -O`. Without it an empty payload would
            # produce zero chunks and no response would be queued at all.
            raise AssertionError("Serialized object cannot be zero bytes")
        total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
        for chunk_id in range(request.start_chunk_id, total_chunks):
            start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
            # The final chunk may be shorter than the chunk size.
            end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
            get_resp = ray_client_pb2.GetResponse(
                valid=True,
                data=serialized[start:end],
                chunk_id=chunk_id,
                total_chunks=total_chunks,
                total_size=total_size,
            )
            chunk_resp = ray_client_pb2.DataResponse(
                get=get_resp, req_id=req_id)
            result_queue.put(chunk_resp)
    except Exception as exc:
        # Report the failure to the client rather than leaving it without
        # a reply for this request.
        get_resp = ray_client_pb2.GetResponse(
            valid=False, error=cloudpickle.dumps(exc))
        resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
        result_queue.put(resp)
def Datapath(self, request_iterator, context):
    """Serve one client's data stream, connecting Ray for the first client.

    Each request (``get``, ``put``, ``release`` or ``connection_info``) is
    answered with one ``DataResponse`` carrying the request's ``req_id``.
    When the stream ends, ``basic_service.release_all`` is called for the
    client and, if this stream was counted and no clients remain,
    ``ray.shutdown()`` is called.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; its invocation metadata must contain a
            ``client_id`` entry.

    Yields:
        One ``ray_client_pb2.DataResponse`` per request.

    Raises:
        Exception: If a request has a type that is not handled here.
    """
    metadata = dict(context.invocation_metadata())
    client_id = metadata["client_id"]
    if client_id == "":
        logger.error("Client connecting with no client_id")
        return
    logger.info(f"New data connection from client {client_id}")
    # Whether this stream has been added to `_num_clients`. The cleanup in
    # `finally` must only undo the increment if it actually happened;
    # otherwise a failing `ray_connect_handler()` would drive the counter
    # negative and later clients would never observe `_num_clients == 0`.
    client_counted = False
    try:
        with self._clients_lock:
            with disable_client_hook():
                if self._num_clients == 0 and not ray.is_initialized():
                    self.ray_connect_handler()
            self._num_clients += 1
            client_counted = True
        for req in request_iterator:
            resp = None
            req_type = req.WhichOneof("type")
            if req_type == "get":
                get_resp = self.basic_service._get_object(
                    req.get, client_id)
                resp = ray_client_pb2.DataResponse(get=get_resp)
            elif req_type == "put":
                put_resp = self.basic_service._put_object(
                    req.put, client_id)
                resp = ray_client_pb2.DataResponse(put=put_resp)
            elif req_type == "release":
                released = [
                    self.basic_service.release(client_id, rel_id)
                    for rel_id in req.release.ids
                ]
                resp = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=released))
            elif req_type == "connection_info":
                resp = ray_client_pb2.DataResponse(
                    connection_info=self._build_connection_response())
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            resp.req_id = req.req_id
            yield resp
    except grpc.RpcError as e:
        logger.debug(f"Closing data channel: {e}")
    finally:
        logger.info(f"Lost data connection from client {client_id}")
        self.basic_service.release_all(client_id)
        with self._clients_lock:
            if client_counted:
                self._num_clients -= 1
                with disable_client_hook():
                    if self._num_clients == 0:
                        ray.shutdown()
def Datapath(self, request_iterator, context):
    """Serve one client's bidirectional data stream.

    Every incoming request (``get``, ``put``, ``release`` or
    ``connection_info``) is answered with a single ``DataResponse`` that
    echoes the request's ``req_id``. The stream is counted in
    ``self._num_clients`` while it is open, and
    ``basic_service.release_all`` is called for the client when it ends.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; its invocation metadata must contain a
            ``client_id`` entry.

    Yields:
        One ``ray_client_pb2.DataResponse`` per request.

    Raises:
        Exception: If a request has a type that is not handled here.
    """
    client_id = dict(context.invocation_metadata())["client_id"]
    if client_id == "":
        logger.error("Client connecting with no client_id")
        return
    logger.info(f"New data connection from client {client_id}")
    try:
        with self._clients_lock:
            self._num_clients += 1
        for request in request_iterator:
            req_type = request.WhichOneof("type")
            if req_type == "get":
                reply = ray_client_pb2.DataResponse(
                    get=self.basic_service._get_object(
                        request.get, client_id))
            elif req_type == "put":
                reply = ray_client_pb2.DataResponse(
                    put=self.basic_service._put_object(
                        request.put, client_id))
            elif req_type == "release":
                outcomes = [
                    self.basic_service.release(client_id, ref_id)
                    for ref_id in request.release.ids
                ]
                reply = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=outcomes))
            elif req_type == "connection_info":
                # Take a snapshot of the counter under the lock, then build
                # the response outside of it.
                with self._clients_lock:
                    client_count = self._num_clients
                version = "{}.{}.{}".format(*sys.version_info[:3])
                reply = ray_client_pb2.DataResponse(
                    connection_info=ray_client_pb2.ConnectionInfoResponse(
                        num_clients=client_count,
                        python_version=version,
                        ray_version=ray.__version__,
                        ray_commit=ray.__commit__))
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            reply.req_id = request.req_id
            yield reply
    except grpc.RpcError as e:
        logger.debug(f"Closing data channel: {e}")
    finally:
        logger.info(f"Lost data connection from client {client_id}")
        self.basic_service.release_all(client_id)
        with self._clients_lock:
            self._num_clients -= 1
def Datapath(self, request_iterator, context):
    """Proxy a client's data stream to that client's specific server.

    The first request is read, rewritten by ``prepare_runtime_init_req``
    and used to start the specific server. If startup or the channel
    lookup fails, one ``InitResponse`` with ``ok=False`` and the formatted
    traceback is yielded and the stream ends. Otherwise every backend
    response is passed through ``modify_connection_info_resp`` and yielded.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context from which the client id is read.

    Yields:
        ``ray_client_pb2.DataResponse`` messages for the client.
    """
    client_id = _get_client_id_from_context(context)
    if client_id == "":
        return

    # Create Placeholder *before* reading the first request.
    specific_server = self.proxy_manager.create_specific_server(client_id)
    try:
        with self.clients_lock:
            self.num_clients += 1

        logger.info(f"New data connection from client {client_id}: ")
        first_req = next(request_iterator)
        try:
            rewritten_req, job_config = prepare_runtime_init_req(first_req)
            started = self.proxy_manager.start_specific_server(
                client_id, job_config)
            if not started:
                logger.error(
                    f"Server startup failed for client: {client_id}, "
                    f"using JobConfig: {job_config}!")
                raise RuntimeError(
                    "Starting Ray client server failed. This is most "
                    "likely because the runtime_env failed to be "
                    "installed. See ray_client_server_[port].err on the "
                    "head node of the cluster for the relevant logs.")
            backend_channel = self.proxy_manager.get_channel(client_id)
            if backend_channel is None:
                logger.error(f"Channel not found for {client_id}")
                raise RuntimeError(
                    "Proxy failed to Connect to backend! Check "
                    "`ray_client_server.err` on the cluster.")
            backend = ray_client_pb2_grpc.RayletDataStreamerStub(
                backend_channel)
        except Exception:
            # Report the startup failure to the client as a failed init.
            failure = ray_client_pb2.DataResponse(
                init=ray_client_pb2.InitResponse(
                    ok=False, msg=traceback.format_exc()))
            failure.req_id = first_req.req_id
            yield failure
            return

        # Put the rewritten init request back at the head of the stream.
        forwarded = chain([rewritten_req], request_iterator)
        backend_stream = backend.Datapath(
            forwarded, metadata=[("client_id", client_id)])
        for backend_resp in backend_stream:
            yield self.modify_connection_info_resp(backend_resp)
    except Exception:
        logger.exception("Proxying Datapath failed!")
    finally:
        specific_server.set_result(None)
        with self.clients_lock:
            logger.debug(f"Client detached: {client_id}")
            self.num_clients -= 1
def send_get_response(result: Any) -> None:
    """Queue the GetResponse for ``result`` for the main DataPath loop.

    This is called when the object is ready on the server side. A failure
    while serializing or building the response is reported to the client
    as a ``GetResponse`` with ``valid=False`` and the pickled exception.

    Args:
        result: The object to serialize and send back.
    """
    try:
        payload = ray_client_pb2.GetResponse(
            valid=True, data=dumps_from_server(result, client_id, self))
    except Exception as err:
        payload = ray_client_pb2.GetResponse(
            valid=False, error=cloudpickle.dumps(err))
    result_queue.put(
        ray_client_pb2.DataResponse(get=payload, req_id=req_id))
def modify_connection_info_resp(self,
                                init_resp: ray_client_pb2.DataResponse
                                ) -> ray_client_pb2.DataResponse:
    """
    Replace `num_clients` in a ConnectionInfoResponse with this object's
    own count, because individual SpecificServers only have **one** client.

    Responses of any other type are returned unchanged. For a
    connection-info response a copy is modified and returned; the input
    message itself is not mutated.
    """
    if init_resp.WhichOneof("type") == "connection_info":
        patched = ray_client_pb2.DataResponse()
        patched.CopyFrom(init_resp)
        with self.clients_lock:
            patched.connection_info.num_clients = self.num_clients
        return patched
    return init_resp
def _init(self, req_init, client_id):
    """Admit a client unless the connection threshold has been reached.

    The whole check-and-admit sequence runs under ``self.clients_lock``.
    The threshold is half of ``CLIENT_SERVER_MAX_THREADS``.

    Args:
        req_init: The init request, forwarded to ``basic_service.Init``.
        client_id: Id of the connecting client; used in the log message.

    Returns:
        A ``DataResponse`` wrapping the init result when the client is
        admitted, or ``None`` when it is rejected.
    """
    with self.clients_lock:
        threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
        if self.num_clients < threshold:
            # Below the limit: run Init first and count the client only
            # once Init has returned.
            init_result = self.basic_service.Init(req_init)
            self.num_clients += 1
            return ray_client_pb2.DataResponse(init=init_result, )

        logger.warning(
            f"[Data Servicer]: Num clients {self.num_clients} "
            f"has reached the threshold {threshold}. "
            f"Rejecting client: {client_id}. ")
        if log_once("client_threshold"):
            logger.warning(
                "You can configure the client connection "
                "threshold by setting the "
                "RAY_CLIENT_SERVER_MAX_THREADS env var "
                f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
        return None
def Datapath(self, request_iterator, context):
    """Serve one client's data stream with response caching and reconnects.

    A background thread running ``fill_queue`` feeds an internal queue from
    ``request_iterator``; this generator drains that queue until it
    receives ``None``. ``DataResponse`` items found on the queue (results
    of async gets) are yielded as they are. Each ``DataRequest`` is
    dispatched on its type and answered with a ``DataResponse`` carrying
    the same ``req_id``, except for requests that are still incomplete
    (chunked ``put``/``task``), async gets whose result is not ready yet,
    and ``acknowledge`` requests.

    When the stream ends, cleanup of the client's session is delayed by the
    client's reconnect grace period unless cleanup was requested, and is
    skipped altogether if the session was already cleaned up or the client
    reconnected in the meantime.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; ``client_id`` is read from its invocation
            metadata, and error status is set on it on failure.

    Yields:
        ``ray_client_pb2.DataResponse`` messages for the client.
    """
    start_time = time.time()
    # set to True if client shuts down gracefully
    cleanup_requested = False
    metadata = dict(context.invocation_metadata())
    client_id = metadata.get("client_id")
    if client_id is None:
        logger.error("Client connecting with no client_id")
        return
    logger.debug(f"New data connection from client {client_id}: ")
    accepted_connection = self._init(client_id, context, start_time)
    if not accepted_connection:
        return
    # Looked up only after the connection has been accepted, so a rejected
    # client never touches `self.response_caches`.
    response_cache = self.response_caches[client_id]
    # Set to False if client requests a reconnect grace period of 0
    reconnect_enabled = True
    try:
        request_queue = Queue()
        queue_filler_thread = Thread(
            target=fill_queue,
            daemon=True,
            args=(request_iterator, request_queue))
        queue_filler_thread.start()
        # For non `async get` requests, this loop yields immediately
        # For `async get` requests, this loop:
        #      1) does not yield, it just continues
        #      2) When the result is ready, it yields
        for req in iter(request_queue.get, None):
            if isinstance(req, ray_client_pb2.DataResponse):
                # Early shortcut if this is the result of an async get.
                yield req
                continue

            assert isinstance(req, ray_client_pb2.DataRequest)
            if _should_cache(req) and reconnect_enabled:
                cached_resp = response_cache.check_cache(req.req_id)
                if isinstance(cached_resp, Exception):
                    # Cache state is invalid, raise exception
                    raise cached_resp
                if cached_resp is not None:
                    yield cached_resp
                    continue

            resp = None
            req_type = req.WhichOneof("type")
            if req_type == "init":
                resp_init = self.basic_service.Init(req.init)
                resp = ray_client_pb2.DataResponse(init=resp_init, )
                with self.clients_lock:
                    self.reconnect_grace_periods[
                        client_id] = req.init.reconnect_grace_period
                    if req.init.reconnect_grace_period == 0:
                        reconnect_enabled = False

            elif req_type == "get":
                if req.get.asynchronous:
                    get_resp = self.basic_service._async_get_object(
                        req.get, client_id, req.req_id, request_queue)
                    if get_resp is None:
                        # Skip sending a response for this request and
                        # continue to the next request. The response for
                        # this request will be sent when the object is
                        # ready.
                        continue
                else:
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                resp = ray_client_pb2.DataResponse(get=get_resp)
            elif req_type == "put":
                if not self.put_request_chunk_collector.add_chunk(
                        req, req.put):
                    # Put request still in progress
                    continue
                put_resp = self.basic_service._put_object(
                    self.put_request_chunk_collector.data,
                    req.put.client_ref_id,
                    client_id,
                )
                self.put_request_chunk_collector.reset()
                resp = ray_client_pb2.DataResponse(put=put_resp)
            elif req_type == "release":
                released = []
                for rel_id in req.release.ids:
                    rel = self.basic_service.release(client_id, rel_id)
                    released.append(rel)
                resp = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=released))
            elif req_type == "connection_info":
                resp = ray_client_pb2.DataResponse(
                    connection_info=self._build_connection_response())
            elif req_type == "prep_runtime_env":
                with self.clients_lock:
                    resp_prep = self.basic_service.PrepRuntimeEnv(
                        req.prep_runtime_env)
                    resp = ray_client_pb2.DataResponse(
                        prep_runtime_env=resp_prep)
            elif req_type == "connection_cleanup":
                cleanup_requested = True
                cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                resp = ray_client_pb2.DataResponse(
                    connection_cleanup=cleanup_resp)
            elif req_type == "acknowledge":
                # Clean up acknowledged cache entries
                response_cache.cleanup(req.acknowledge.req_id)
                continue
            elif req_type == "task":
                with self.clients_lock:
                    task = req.task
                    if not self.client_task_chunk_collector.add_chunk(
                            req, task):
                        # Not all serialized arguments have arrived
                        continue
                    arglist, kwargs = loads_from_client(
                        self.client_task_chunk_collector.data,
                        self.basic_service)
                    self.client_task_chunk_collector.reset()
                    resp_ticket = self.basic_service.Schedule(
                        req.task, arglist, kwargs, context)
                    resp = ray_client_pb2.DataResponse(
                        task_ticket=resp_ticket)
            elif req_type == "terminate":
                with self.clients_lock:
                    response = self.basic_service.Terminate(
                        req.terminate, context)
                    resp = ray_client_pb2.DataResponse(terminate=response)
            elif req_type == "list_named_actors":
                with self.clients_lock:
                    response = self.basic_service.ListNamedActors(
                        req.list_named_actors)
                    resp = ray_client_pb2.DataResponse(
                        list_named_actors=response)
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            resp.req_id = req.req_id
            if _should_cache(req) and reconnect_enabled:
                response_cache.update_cache(req.req_id, resp)
            yield resp
    except Exception as e:
        logger.exception("Error in data channel:")
        recoverable = _propagate_error_in_context(e, context)
        invalid_cache = response_cache.invalidate(e)
        if not recoverable or invalid_cache:
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            # Connection isn't recoverable, skip cleanup
            cleanup_requested = True
    finally:
        logger.debug(f"Stream is broken with client {client_id}")
        queue_filler_thread.join(QUEUE_JOIN_SECONDS)
        if queue_filler_thread.is_alive():
            logger.error(
                "Queue filler thread failed to join before timeout: {}".
                format(QUEUE_JOIN_SECONDS))
        cleanup_delay = self.reconnect_grace_periods.get(client_id)
        if not cleanup_requested and cleanup_delay is not None:
            # The first literal ends with a space so the delay does not run
            # into the word "by" in the emitted message.
            logger.debug("Cleanup wasn't requested, delaying cleanup by "
                         f"{cleanup_delay} seconds.")
            # Delay cleanup, since client may attempt a reconnect
            # Wait on the "stopped" event in case the grpc server is
            # stopped and we can clean up earlier.
            self.stopped.wait(timeout=cleanup_delay)
        else:
            logger.debug("Cleanup was requested, cleaning up immediately.")
        with self.clients_lock:
            if client_id not in self.client_last_seen:
                logger.debug("Connection already cleaned up.")
                # Some other connection has already cleaned up this
                # client's session. This can happen if the client
                # reconnects and then gracefully shuts down immediately.
                return
            last_seen = self.client_last_seen[client_id]
            if last_seen > start_time:
                # The client successfully reconnected and updated
                # last seen some time during the grace period
                logger.debug("Client reconnected, skipping cleanup")
                return
            # Either the client shut down gracefully, or the client
            # failed to reconnect within the grace period. Clean up
            # the connection.
            self.basic_service.release_all(client_id)
            del self.client_last_seen[client_id]
            if client_id in self.reconnect_grace_periods:
                del self.reconnect_grace_periods[client_id]
            if client_id in self.response_caches:
                del self.response_caches[client_id]
            self.num_clients -= 1
            logger.debug(f"Removed client {client_id}, "
                         f"remaining={self.num_clients}")

            # It's important to keep the Ray shutdown
            # within this locked context or else Ray could hang.
            # NOTE: it is strange to start ray in server.py but shut it
            # down here. Consider consolidating ray lifetime management.
            with disable_client_hook():
                if self.num_clients == 0:
                    logger.debug("Shutting down ray.")
                    ray.shutdown()
def Datapath(self, request_iterator, context):
    """Proxy a client's data stream to its specific server, with reconnects.

    For a new connection, a specific server is created for the client, the
    first request is read and rewritten by ``prepare_runtime_init_req``,
    and the server is started with the resulting job config. If that
    fails, one ``InitResponse`` with ``ok=False`` and the formatted
    traceback is yielded and the stream ends. For a reconnect, the client
    must still be present in ``self.clients_last_seen``; otherwise the
    context is set to ``NOT_FOUND`` and nothing is yielded.

    Backend responses are passed through ``modify_connection_info_resp``
    before being yielded. When the stream ends, cleanup is delayed by the
    client's reconnect grace period unless cleanup was requested, and is
    skipped if the session is already gone or a newer stream for the same
    client has started.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context from which the client id and the
            reconnecting flag are read.

    Yields:
        ``ray_client_pb2.DataResponse`` messages for the client.
    """
    cleanup_requested = False
    start_time = time.time()
    client_id = _get_client_id_from_context(context)
    if client_id == "":
        return
    reconnecting = _get_reconnecting_from_context(context)

    if reconnecting:
        with self.clients_lock:
            if client_id not in self.clients_last_seen:
                # Client took too long to reconnect, session has already
                # been cleaned up
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Attempted to reconnect a session that has already "
                    "been cleaned up")
                return
            # Record this stream's start time; the `finally` block of an
            # older stream compares against it to detect the reconnect.
            self.clients_last_seen[client_id] = start_time
        server = self.proxy_manager._get_server_for_client(client_id)
        channel = self.proxy_manager.get_channel(client_id)
        # iterator doesn't need modification on reconnect
        new_iter = request_iterator
    else:
        # Create Placeholder *before* reading the first request.
        server = self.proxy_manager.create_specific_server(client_id)
        with self.clients_lock:
            self.clients_last_seen[client_id] = start_time
            self.num_clients += 1

    try:
        if not reconnecting:
            logger.info(f"New data connection from client {client_id}: ")
            init_req = next(request_iterator)
            with self.clients_lock:
                self.reconnect_grace_periods[client_id] = \
                    init_req.init.reconnect_grace_period
            try:
                modified_init_req, job_config = prepare_runtime_init_req(
                    init_req)
                if not self.proxy_manager.start_specific_server(
                        client_id, job_config):
                    logger.error(
                        f"Server startup failed for client: {client_id}, "
                        f"using JobConfig: {job_config}!")
                    raise RuntimeError(
                        "Starting Ray client server failed. See "
                        f"ray_client_server_{server.port}.err for "
                        "detailed logs.")
                channel = self.proxy_manager.get_channel(client_id)
                if channel is None:
                    logger.error(f"Channel not found for {client_id}")
                    raise RuntimeError(
                        "Proxy failed to Connect to backend! Check "
                        "`ray_client_server.err` and "
                        f"`ray_client_server_{server.port}.err` on the "
                        "head node of the cluster for the relevant logs. "
                        "By default these are located at "
                        "/tmp/ray/session_latest/logs.")
            except Exception:
                # Report the startup failure to the client as a failed
                # init response, then end the stream.
                init_resp = ray_client_pb2.DataResponse(
                    init=ray_client_pb2.InitResponse(
                        ok=False, msg=traceback.format_exc()))
                init_resp.req_id = init_req.req_id
                yield init_resp
                return None

            # Put the rewritten init request back at the head of the stream.
            new_iter = chain([modified_init_req], request_iterator)

        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        metadata = [("client_id", client_id), ("reconnecting",
                                               str(reconnecting))]
        resp_stream = stub.Datapath(new_iter, metadata=metadata)
        for resp in resp_stream:
            resp_type = resp.WhichOneof("type")
            if resp_type == "connection_cleanup":
                # Specific server is skipping cleanup, proxier should too
                cleanup_requested = True
            yield self.modify_connection_info_resp(resp)
    except Exception as e:
        logger.exception("Proxying Datapath failed!")
        # Propagate error through context
        recoverable = _propagate_error_in_context(e, context)
        if not recoverable:
            # Client shouldn't attempt to recover, clean up connection
            cleanup_requested = True
    finally:
        cleanup_delay = self.reconnect_grace_periods.get(client_id)
        if not cleanup_requested and cleanup_delay is not None:
            # Delay cleanup, since client may attempt a reconnect
            # Wait on stopped event in case the server closes and we
            # can clean up earlier
            self.stopped.wait(timeout=cleanup_delay)
        with self.clients_lock:
            if client_id not in self.clients_last_seen:
                logger.info(f"{client_id} not found. Skipping clean up.")
                # Connection has already been cleaned up
                return
            last_seen = self.clients_last_seen[client_id]
            logger.info(
                f"{client_id} last started stream at {last_seen}. Current "
                f"stream started at {start_time}.")
            if last_seen > start_time:
                logger.info("Client reconnected. Skipping cleanup.")
                # Client has reconnected, don't clean up
                return
            logger.debug(f"Client detached: {client_id}")
            self.num_clients -= 1
            del self.clients_last_seen[client_id]
            if client_id in self.reconnect_grace_periods:
                del self.reconnect_grace_periods[client_id]
            server.set_result(None)
def Datapath(self, request_iterator, context):
    """Serve one client's data stream, including asynchronous gets.

    A background thread running ``fill_queue`` feeds an internal queue from
    ``request_iterator``; this generator drains that queue until it
    receives ``None``. ``DataResponse`` items found on the queue (results
    of async gets) are yielded unchanged; each ``DataRequest`` is answered
    with a ``DataResponse`` carrying the same ``req_id``, except for async
    gets whose result is not ready yet.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; its invocation metadata must contain a
            ``client_id`` entry.

    Yields:
        ``ray_client_pb2.DataResponse`` messages for the client.

    Raises:
        Exception: If a request has a type that is not handled here.
    """
    client_id = dict(context.invocation_metadata())["client_id"]
    if client_id == "":
        logger.error("Client connecting with no client_id")
        return
    logger.debug(f"New data connection from client {client_id}: ")
    if not self._init(client_id, context):
        return
    try:
        inbox = Queue()
        filler = Thread(
            target=fill_queue, daemon=True, args=(request_iterator, inbox))
        filler.start()
        # For non `async get` requests, this loop yields immediately
        # For `async get` requests, this loop:
        #      1) does not yield, it just continues
        #      2) When the result is ready, it yields
        for item in iter(inbox.get, None):
            if isinstance(item, ray_client_pb2.DataResponse):
                # Early shortcut if this is the result of an async get.
                yield item
                continue

            assert isinstance(item, ray_client_pb2.DataRequest)
            req_type = item.WhichOneof("type")
            if req_type == "init":
                reply = ray_client_pb2.DataResponse(
                    init=self.basic_service.Init(item.init), )
            elif req_type == "get":
                if item.get.asynchronous:
                    fetched = self.basic_service._async_get_object(
                        item.get, client_id, item.req_id, inbox)
                    if fetched is None:
                        # No response is sent for this request now; it
                        # will be sent when the object is ready.
                        continue
                else:
                    fetched = self.basic_service._get_object(
                        item.get, client_id)
                reply = ray_client_pb2.DataResponse(get=fetched)
            elif req_type == "put":
                reply = ray_client_pb2.DataResponse(
                    put=self.basic_service._put_object(item.put, client_id))
            elif req_type == "release":
                outcomes = [
                    self.basic_service.release(client_id, ref_id)
                    for ref_id in item.release.ids
                ]
                reply = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=outcomes))
            elif req_type == "connection_info":
                reply = ray_client_pb2.DataResponse(
                    connection_info=self._build_connection_response())
            elif req_type == "prep_runtime_env":
                with self.clients_lock:
                    prepared = self.basic_service.PrepRuntimeEnv(
                        item.prep_runtime_env)
                    reply = ray_client_pb2.DataResponse(
                        prep_runtime_env=prepared)
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            reply.req_id = item.req_id
            yield reply
    except grpc.RpcError as e:
        logger.debug(f"Closing data channel: {e}")
    finally:
        logger.debug(f"Lost data connection from client {client_id}")
        self.basic_service.release_all(client_id)
        filler.join(QUEUE_JOIN_SECONDS)
        if filler.is_alive():
            logger.error(
                "Queue filler thread failed to join before timeout: {}".
                format(QUEUE_JOIN_SECONDS))
        with self.clients_lock:
            # Could fail before client accounting happens
            self.num_clients -= 1
            logger.debug(f"Removed clients. {self.num_clients}")

            # It's important to keep the Ray shutdown
            # within this locked context or else Ray could hang.
            with disable_client_hook():
                if self.num_clients == 0:
                    logger.debug("Shutting down ray.")
                    ray.shutdown()
def Datapath(self, request_iterator, context):
    """Serve one client's data stream; the client must send ``init`` first.

    An ``init`` request is passed to ``self._init``. If that returns
    ``None`` the context is set to ``RESOURCE_EXHAUSTED`` and the stream
    ends. Any other request is only served after a successful ``init`` and
    is answered with one ``DataResponse`` carrying the request's
    ``req_id``.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; its invocation metadata must contain a
            ``client_id`` entry.

    Yields:
        One ``ray_client_pb2.DataResponse`` per request.

    Raises:
        AssertionError: If a non-``init`` request arrives before the
            connection has been accepted.
        Exception: If a request has a type that is not handled here.
    """
    metadata = dict(context.invocation_metadata())
    client_id = metadata["client_id"]
    accepted_connection = False
    if client_id == "":
        logger.error("Client connecting with no client_id")
        return
    logger.debug(f"New data connection from client {client_id}: ")
    try:
        for req in request_iterator:
            resp = None
            req_type = req.WhichOneof("type")
            if req_type == "init":
                resp = self._init(req.init, client_id)
                if resp is None:
                    context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                    return
                logger.debug(f"Accepted data connection from {client_id}. "
                             f"Total clients: {self.num_clients}")
                accepted_connection = True
            elif not accepted_connection:
                # Raised explicitly instead of using `assert` so the
                # protocol check still runs under `python -O`; otherwise a
                # client that never sent `init` would be served.
                raise AssertionError(
                    f"Client {client_id} sent a {req_type} request before "
                    "the connection was accepted")
            elif req_type == "get":
                get_resp = self.basic_service._get_object(
                    req.get, client_id)
                resp = ray_client_pb2.DataResponse(get=get_resp)
            elif req_type == "put":
                put_resp = self.basic_service._put_object(
                    req.put, client_id)
                resp = ray_client_pb2.DataResponse(put=put_resp)
            elif req_type == "release":
                released = [
                    self.basic_service.release(client_id, rel_id)
                    for rel_id in req.release.ids
                ]
                resp = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=released))
            elif req_type == "connection_info":
                resp = ray_client_pb2.DataResponse(
                    connection_info=self._build_connection_response())
            elif req_type == "prep_runtime_env":
                with self.clients_lock:
                    resp_prep = self.basic_service.PrepRuntimeEnv(
                        req.prep_runtime_env)
                    resp = ray_client_pb2.DataResponse(
                        prep_runtime_env=resp_prep)
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            resp.req_id = req.req_id
            yield resp
    except grpc.RpcError as e:
        logger.debug(f"Closing data channel: {e}")
    finally:
        logger.debug(f"Lost data connection from client {client_id}")
        self.basic_service.release_all(client_id)

        with self.clients_lock:
            if accepted_connection:
                # Could fail before client accounting happens
                self.num_clients -= 1
                logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()
def Datapath(self, request_iterator, context):
    """Serve one client's data stream, enforcing the client threshold.

    Under ``self.clients_lock`` this calls ``self.ray_connect_handler()``
    when no clients are counted and Ray is not initialized, then rejects
    the connection with ``RESOURCE_EXHAUSTED`` if the number of clients has
    reached half of ``CLIENT_SERVER_MAX_THREADS``. Accepted streams answer
    each request (``get``, ``put``, ``release`` or ``connection_info``)
    with one ``DataResponse`` carrying the request's ``req_id``.

    Args:
        request_iterator: Iterator of incoming data requests.
        context: gRPC context; its invocation metadata must contain a
            ``client_id`` entry.

    Yields:
        One ``ray_client_pb2.DataResponse`` per request.

    Raises:
        Exception: If a request has a type that is not handled here.
    """
    metadata = dict(context.invocation_metadata())
    client_id = metadata["client_id"]
    accepted_connection = False
    if client_id == "":
        logger.error("Client connecting with no client_id")
        return
    logger.debug(f"New data connection from client {client_id}: ")
    try:
        with self.clients_lock:
            with disable_client_hook():
                # It's important to keep the ray initialization call
                # within this locked context or else Ray could hang.
                if self.num_clients == 0 and not ray.is_initialized():
                    self.ray_connect_handler()
            threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
            if self.num_clients >= threshold:
                context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                logger.warning(
                    f"[Data Servicer]: Num clients {self.num_clients} "
                    f"has reached the threshold {threshold}. "
                    f"Rejecting client: {metadata['client_id']}. ")
                if log_once("client_threshold"):
                    logger.warning(
                        "You can configure the client connection "
                        "threshold by setting the "
                        "RAY_CLIENT_SERVER_MAX_THREADS env var "
                        f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
                return
            self.num_clients += 1
            logger.debug(f"Accepted data connection from {client_id}. "
                         f"Total clients: {self.num_clients}")
        accepted_connection = True
        for request in request_iterator:
            req_type = request.WhichOneof("type")
            if req_type == "get":
                reply = ray_client_pb2.DataResponse(
                    get=self.basic_service._get_object(
                        request.get, client_id))
            elif req_type == "put":
                reply = ray_client_pb2.DataResponse(
                    put=self.basic_service._put_object(
                        request.put, client_id))
            elif req_type == "release":
                outcomes = [
                    self.basic_service.release(client_id, ref_id)
                    for ref_id in request.release.ids
                ]
                reply = ray_client_pb2.DataResponse(
                    release=ray_client_pb2.ReleaseResponse(ok=outcomes))
            elif req_type == "connection_info":
                reply = ray_client_pb2.DataResponse(
                    connection_info=self._build_connection_response())
            else:
                raise Exception(f"Unreachable code: Request type "
                                f"{req_type} not handled in Datapath")
            reply.req_id = request.req_id
            yield reply
    except grpc.RpcError as e:
        logger.debug(f"Closing data channel: {e}")
    finally:
        logger.debug(f"Lost data connection from client {client_id}")
        self.basic_service.release_all(client_id)

        with self.clients_lock:
            if accepted_connection:
                # Could fail before client accounting happens
                self.num_clients -= 1
                logger.debug(f"Removed clients. {self.num_clients}")

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()