def check_started(self):
    if self._state == ReplicaState.RUNNING:
        return True
    assert self._state == ReplicaState.STARTING, (
        f"State must be {ReplicaState.STARTING}, *not* {self._state}")

    ready, _ = ray.wait([self._startup_obj_ref], timeout=0)
    if len(ready) == 1:
        self._state = ReplicaState.RUNNING
        return True

    time_since_start = time.time() - self._start_time
    if (time_since_start > SLOW_STARTUP_WARNING_S
            and time.time() - self._prev_slow_startup_warning_time >
            SLOW_STARTUP_WARNING_PERIOD_S):
        # Filter to relevant resources.
        required = {
            k: v
            for k, v in self._actor_resources.items() if v > 0
        }
        available = {
            k: v
            for k, v in ray.available_resources().items() if k in required
        }
        logger.warning(
            f"Replica '{self._replica_tag}' for backend "
            f"'{self._backend_tag}' has taken more than "
            f"{time_since_start:.0f}s to start up. This may be "
            "caused by waiting for the cluster to auto-scale or "
            "because the backend constructor is slow. Resources required: "
            f"{required}, resources available: {available}.")
        self._prev_slow_startup_warning_time = time.time()

    return False
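# Standalone sketch of the non-blocking readiness probe above: ray.wait with
# timeout=0 returns immediately, splitting refs into ready and pending lists.
# The slow task below is a placeholder for illustration.
import time
import ray

ray.init()

@ray.remote
def slow_startup():
    time.sleep(5)
    return "ready"

ref = slow_startup.remote()
ready, pending = ray.wait([ref], timeout=0)  # never blocks
print(len(ready) == 1)  # False until the task has finished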
async def _start_pending_backend_replicas(
        self, current_state: SystemState) -> None:
    """Starts the pending backend replicas in self.backend_replicas_to_start.

    Waits for replicas to start up, then removes them from
    self.backend_replicas_to_start.
    """
    fut_to_replica_info = {}
    for backend_tag, replicas_to_create in \
            self.backend_replicas_to_start.items():
        for replica_tag in replicas_to_create:
            replica_handle = await self._start_backend_replica(
                current_state, backend_tag, replica_tag)
            ready_future = replica_handle.ready.remote().as_future()
            fut_to_replica_info[ready_future] = (backend_tag, replica_tag,
                                                 replica_handle)

    start = time.time()
    prev_warning = start
    while fut_to_replica_info:
        if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
            prev_warning = time.time()
            logger.warning("Waited {:.2f}s for replicas to start up. Make "
                           "sure there are enough resources to create the "
                           "replicas.".format(time.time() - start))

        done, pending = await asyncio.wait(
            list(fut_to_replica_info.keys()), timeout=1)
        for fut in done:
            (backend_tag, replica_tag,
             replica_handle) = fut_to_replica_info.pop(fut)
            self.backend_replicas[backend_tag][replica_tag] = replica_handle

    self.backend_replicas_to_start.clear()
def check_started(self) -> bool:
    """Check if the replica has started. If so, transition to RUNNING.

    Should handle the case where the replica has already stopped.
    """
    if self._state == ReplicaState.RUNNING:
        return True
    assert self._state == ReplicaState.STARTING, (
        f"State must be {ReplicaState.STARTING}, *not* {self._state}")

    if self._actor.check_ready():
        self._state = ReplicaState.RUNNING
        return True

    time_since_start = time.time() - self._start_time
    if (time_since_start > SLOW_STARTUP_WARNING_S
            and time.time() - self._prev_slow_startup_warning_time >
            SLOW_STARTUP_WARNING_PERIOD_S):
        required, available = self._actor.resource_requirements()
        logger.warning(
            f"Replica '{self._replica_tag}' for backend "
            f"'{self._backend_tag}' has taken more than "
            f"{time_since_start:.0f}s to start up. This may be "
            "caused by waiting for the cluster to auto-scale or "
            "because the backend constructor is slow. Resources required: "
            f"{required}, resources available: {available}.")
        self._prev_slow_startup_warning_time = time.time()

    return False
def get(self, key: str) -> Optional[bytes]:
    """Get the value associated with the given key from the store.

    Args:
        key (str)

    Returns:
        The bytes value. If the key wasn't found, returns None.
    """
    if not isinstance(key, str):
        raise TypeError("key must be a string, got: {}.".format(type(key)))

    try:
        response = self._s3.get_object(
            Bucket=self._bucket, Key=self.get_storage_key(key))
        return response["Body"].read()
    except ClientError as e:
        if e.response["Error"]["Code"] == "NoSuchKey":
            logger.warning(f"No such key in s3 for key = {key}")
            return None
        else:
            message = e.response["Error"]["Message"]
            logger.error(f"Encountered ClientError while calling get() "
                         f"in RayExternalKVStore: {message}")
            raise e
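# Standalone sketch of the same read-with-miss-handling pattern against raw
# boto3, with the bucket and key names as placeholders. boto3.client,
# get_object, and the NoSuchKey error code are real APIs.
import boto3
from botocore.exceptions import ClientError

s3 = boto3.client("s3")
try:
    value = s3.get_object(
        Bucket="my-bucket", Key="serve/checkpoint")["Body"].read()
except ClientError as e:
    if e.response["Error"]["Code"] == "NoSuchKey":
        value = None  # a missing key behaves like a cache miss
    else:
        raise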
def get_handle(self,
               endpoint_name: str,
               missing_ok: Optional[bool] = False) -> RayServeHandle:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.

    Returns:
        RayServeHandle
    """
    if not missing_ok and endpoint_name not in ray.get(
            self._controller.get_all_endpoints.remote()):
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    routers = list(ray.get(self._controller.get_routers.remote()).values())
    current_node_id = ray.get_runtime_context().node_id.hex()
    try:
        router_chosen = next(
            filter(lambda r: get_node_id_for_actor(r) == current_node_id,
                   routers))
    except StopIteration:
        logger.warning(
            f"When getting a handle for {endpoint_name}, Serve can't find "
            "a router on the same node. Serve will use a random router.")
        router_chosen = random.choice(routers)

    return RayServeHandle(
        router_chosen,
        endpoint_name,
    )
def _get_target_nodes(self) -> List[Tuple[str, str]]:
    """Return the list of (id, resource_key) to deploy HTTP servers on."""
    location = self._config.location
    target_nodes = get_all_node_ids()

    if location == DeploymentMode.NoServer:
        return []

    if location == DeploymentMode.HeadOnly:
        head_node_resource_key = get_current_node_resource_key()
        return [(node_id, node_resource)
                for node_id, node_resource in target_nodes
                if node_resource == head_node_resource_key][:1]

    if location == DeploymentMode.FixedNumber:
        num_replicas = self._config.fixed_number_replicas
        if num_replicas > len(target_nodes):
            logger.warning(
                "You specified fixed_number_replicas="
                f"{num_replicas} but there are only "
                f"{len(target_nodes)} total nodes. Serve will start one "
                "HTTP proxy per node.")
            num_replicas = len(target_nodes)

        # Seed the random state so sample is deterministic.
        # i.e. it will always return the same set of nodes.
        random.seed(self._config.fixed_number_selection_seed)
        return random.sample(sorted(target_nodes), k=num_replicas)

    return target_nodes
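# Sketch of the seeded-sampling trick used for DeploymentMode.FixedNumber
# above: fixing the seed makes random.sample deterministic, so the same
# subset of nodes is chosen on every controller restart. The node names and
# seed value are placeholders.
import random

nodes = sorted(["node-a", "node-b", "node-c", "node-d"])
random.seed(1234)  # assumed seed value
print(random.sample(nodes, k=2))  # identical output on every run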
def shutdown(self) -> None:
    """Completely shut down the connected Serve instance.

    Shuts down all processes and deletes all state associated with the
    instance.
    """
    if (not self._shutdown) and ray.is_initialized():
        for goal_id in ray.get(self._controller.shutdown.remote()):
            self._wait_for_goal(goal_id)

        ray.kill(self._controller, no_restart=True)

        # Wait until the named actor entry is removed as well.
        started = time.time()
        while True:
            try:
                controller_namespace = _get_controller_namespace(
                    self._detached)
                ray.get_actor(
                    self._controller_name, namespace=controller_namespace)
                if time.time() - started > 5:
                    logger.warning(
                        "Waited 5s for Serve to shutdown gracefully but "
                        "the controller is still not cleaned up. "
                        "You can ignore this warning if you are shutting "
                        "down the Ray cluster.")
                    break
            except ValueError:  # The actor name has been removed.
                break

        self._shutdown = True
def get_handle(self,
               endpoint_name: str,
               missing_ok: Optional[bool] = False,
               sync: bool = True) -> RayServeHandle:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.
        sync (bool): If true, then Serve will return a ServeHandle that
            works everywhere. Otherwise, Serve will return a ServeHandle
            that's only usable in an asyncio loop.

    Returns:
        RayServeHandle
    """
    if not missing_ok and endpoint_name not in ray.get(
            self._controller.get_all_endpoints.remote()):
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    if asyncio.get_event_loop().is_running() and sync:
        logger.warning(
            "You are retrieving a ServeHandle inside an asyncio loop. "
            "Try getting client.get_handle(.., sync=False) to get better "
            "performance.")

    if endpoint_name not in self._handle_cache:
        handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
        self._handle_cache[endpoint_name] = handle

    return self._handle_cache[endpoint_name]
def _check_http_and_checkpoint_options(
        client: Client,
        http_options: Union[dict, HTTPOptions],
        checkpoint_path: str,
) -> None:
    if checkpoint_path and checkpoint_path != client.checkpoint_path:
        logger.warning(
            f"The new client checkpoint path '{checkpoint_path}' "
            f"is different from the existing one '{client.checkpoint_path}'. "
            "The new checkpoint path is ignored.")

    if http_options:
        client_http_options = client.http_config
        new_http_options = http_options if isinstance(
            http_options, HTTPOptions) else HTTPOptions.parse_obj(
                http_options)
        different_fields = []
        all_http_option_fields = new_http_options.__dict__
        for field in all_http_option_fields:
            if getattr(new_http_options, field) != getattr(
                    client_http_options, field):
                different_fields.append(field)

        if len(different_fields):
            logger.warning(
                "The new client HTTP config differs from the existing one "
                f"in the following fields: {different_fields}. "
                "The new HTTP config is ignored.")
def get_handle(
        self,
        endpoint_name: str,
        missing_ok: Optional[bool] = False,
        sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.
        sync (bool): If true, then Serve will return a ServeHandle that
            works everywhere. Otherwise, Serve will return a ServeHandle
            that's only usable in an asyncio loop.

    Returns:
        RayServeHandle
    """
    all_endpoints = ray.get(self._controller.get_all_endpoints.remote())
    if not missing_ok and endpoint_name not in all_endpoints:
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    if asyncio.get_event_loop().is_running() and sync:
        logger.warning(
            "You are retrieving a sync handle inside an asyncio loop. "
            "Try getting client.get_handle(.., sync=False) to get better "
            "performance. Learn more at https://docs.ray.io/en/master/"
            "serve/http-servehandle.html#sync-and-async-handles")

    if not asyncio.get_event_loop().is_running() and not sync:
        logger.warning(
            "You are retrieving an async handle outside an asyncio loop. "
            "You should make sure client.get_handle is called inside a "
            "running event loop. Or call client.get_handle(.., sync=True) "
            "to create a sync handle. Learn more at https://docs.ray.io/en/"
            "master/serve/http-servehandle.html#sync-and-async-handles")

    if endpoint_name in all_endpoints:
        this_endpoint = all_endpoints[endpoint_name]
        python_methods: List[str] = this_endpoint["python_methods"]
    else:
        # This can happen in the missing_ok=True case.
        # handle.method_name.remote won't work and the user must
        # use the legacy handle.options(method).remote().
        python_methods: List[str] = []

    # NOTE(simon): this extra layer of router seems unnecessary
    # BUT it's needed still because of the shared asyncio thread.
    router = self._get_proxied_router(sync=sync, endpoint=endpoint_name)
    if sync:
        handle = RayServeSyncHandle(
            router, endpoint_name, known_python_methods=python_methods)
    else:
        handle = RayServeHandle(
            router, endpoint_name, known_python_methods=python_methods)

    return handle
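# Hedged usage sketch for the sync/async handle split above. `client` is
# assumed to be an already-connected Serve client, and the endpoint name and
# payload are placeholders. The two-step await (handle.remote returns an
# ObjectRef that is awaited separately) mirrors the proxy code in this
# excerpt.
import ray

# Synchronous path: safe outside any event loop.
handle = client.get_handle("my_endpoint", sync=True)
result = ray.get(handle.remote("payload"))

# Asynchronous path: only valid inside a running asyncio loop.
async def call_endpoint():
    async_handle = client.get_handle("my_endpoint", sync=False)
    object_ref = await async_handle.remote("payload")
    return await object_ref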
def check(self, *args, _internal=False, **kwargs):
    if self._shutdown:
        raise RayServeException("Client has already been shut down.")
    if not _internal:
        logger.warning(
            "The client-based API is being deprecated in favor of global "
            "API calls (e.g., `serve.create_backend()`). Please replace "
            "all instances of `client.api_call()` with "
            "`serve.api_call()`.")
    return f(self, *args, **kwargs)
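# check() above is the inner wrapper of a decorator whose outer layer is not
# shown in this excerpt. A minimal sketch of the likely shape, with the
# decorator name _ensure_connected as an assumption:
from functools import wraps

def _ensure_connected(f):  # hypothetical name for the enclosing decorator
    @wraps(f)
    def check(self, *args, _internal=False, **kwargs):
        if self._shutdown:
            raise RuntimeError("Client has already been shut down.")
        return f(self, *args, **kwargs)

    return check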
async def retry_method(*args, **kwargs):
    while True:
        result = await f(*args, **kwargs)
        if isinstance(result, ray.exceptions.RayActorError):
            logger.warning(
                "Actor method '{}' failed, retrying after 100ms.".format(
                    name))
            await asyncio.sleep(0.1)
        else:
            return result
async def __call__(self, scope, receive, send):
    http_body_bytes = await self.receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    del scope["router"]
    del scope["endpoint"]

    # Modify the path and root path so that reverse lookups and redirection
    # work as expected. We do this here instead of in replicas so it can be
    # changed without restarting the replicas.
    scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
    scope["root_path"] = self.path_prefix

    handle = self.handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to replica failure. There are "
                f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(
            error_message, status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
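# Standalone sketch of the exponential backoff used in the retry loop above:
# retry a flaky async call, doubling the sleep after each failure. The
# `flaky_call` argument, exception type, and retry limit are placeholders.
import asyncio

async def call_with_backoff(flaky_call, max_retries=5):
    backoff_time_s = 0.05
    for _ in range(max_retries):
        try:
            return await flaky_call()
        except ConnectionError:
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2  # exponential backoff
    raise RuntimeError("All retries exhausted.")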
def _process_update(self, updates: Dict[str, UpdatedObject]):
    if isinstance(updates, (ray.exceptions.RayActorError)):
        # This can happen during shutdown where the controller is
        # intentionally killed, the client should just gracefully
        # exit.
        logger.debug("LongPollClient failed to connect to host. "
                     "Shutting down.")
        self.is_running = False
        return

    if isinstance(updates, ConnectionError):
        logger.warning("LongPollClient connection failed, shutting down.")
        self.is_running = False
        return

    if isinstance(updates, (ray.exceptions.RayTaskError)):
        if isinstance(updates.as_instanceof_cause(),
                      (asyncio.TimeoutError)):
            logger.debug("LongPollClient polling timed out. Retrying.")
        else:
            # Some error happened in the controller. It could be a bug or
            # some undesired state.
            logger.error("LongPollHost errored\n" + updates.traceback_str)
        self._poll_next()
        return

    logger.debug(f"LongPollClient {self} received updates for keys: "
                 f"{list(updates.keys())}.")
    for key, update in updates.items():
        self.object_snapshots[key] = update.object_snapshot
        self.snapshot_ids[key] = update.snapshot_id
        callback = self.key_listeners[key]

        # Bind the parameters because closures are late-binding.
        # https://docs.python-guide.org/writing/gotchas/#late-binding-closures # noqa: E501
        def chained(callback=callback, arg=update.object_snapshot):
            callback(arg)
            self._on_callback_completed(trigger_at=len(updates))

        if self.event_loop is None:
            chained()
        else:
            # Schedule the next iteration only if the loop is running.
            # The event loop might not be running if users used a cached
            # version across loops.
            if self.event_loop.is_running():
                self.event_loop.call_soon_threadsafe(chained)
            else:
                logger.error(
                    "The event loop is closed, shutting down long poll "
                    "client.")
                self.is_running = False
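# Sketch of the late-binding gotcha that the default-argument trick in
# chained() above works around: closures capture variables, not values, so
# without binding every callback sees the loop variable's final value.
callbacks = [lambda: print(i) for i in range(3)]
for cb in callbacks:
    cb()  # prints 2, 2, 2 (late binding)

bound = [lambda i=i: print(i) for i in range(3)]
for cb in bound:
    cb()  # prints 0, 1, 2 (value bound at definition time)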
async def _send_request_to_handle(handle, scope, receive, send):
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    if "router" in scope:
        del scope["router"]
    if "endpoint" in scope:
        del scope["endpoint"]

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses is 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to replica failure. There are "
                f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(
            error_message, status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
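# Rough sketch of the stdlib-pickle-vs-cloudpickle point in the comment
# above: for a plain dataclass, pickle.dumps is typically much cheaper than
# cloudpickle.dumps. Exact speedups vary by machine and payload; the wrapper
# class here is a placeholder and cloudpickle must be installed.
import pickle
import timeit
from dataclasses import dataclass

import cloudpickle

@dataclass
class FakeRequest:
    scope: dict
    body: bytes

req = FakeRequest({"path": "/", "method": "GET"}, b"payload")
print("pickle:     ", timeit.timeit(lambda: pickle.dumps(req), number=10_000))
print("cloudpickle:", timeit.timeit(lambda: cloudpickle.dumps(req), number=10_000))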
def _checkpoint(self):
    """Checkpoint internal state and write it to the KV store."""
    logger.debug("Writing checkpoint")
    start = time.time()
    checkpoint = pickle.dumps(
        (self.routes, self.backends, self.traffic_policies, self.replicas,
         self.replicas_to_start, self.replicas_to_stop))

    self.kv_store_client.put("checkpoint", checkpoint)
    logger.debug("Wrote checkpoint in {:.2f}s".format(time.time() - start))

    if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
        logger.warning("Intentionally crashing after checkpoint")
        os._exit(0)
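# Hedged sketch of the recovery half implied by _checkpoint() above: on
# controller restart, read the blob back from the KV store and unpickle the
# state tuple. The function name is an assumption; the key "checkpoint"
# matches the put() call above.
import pickle

def restore_checkpoint(kv_store_client):
    blob = kv_store_client.get("checkpoint")
    if blob is None:
        return None  # no checkpoint has been written yet
    return pickle.loads(blob)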
def update(self) -> bool:
    """Updates the state of all running replicas to match the goal state."""
    self._scale_all_backends()

    for goal_id in self._completed_goals():
        self._goal_manager.complete_goal(goal_id)

    transitioned_backend_tags = set()
    for backend_tag, replicas in self._replicas.items():
        for replica in replicas.pop(states=[ReplicaState.RUNNING]):
            if replica.check_health():
                replicas.add(ReplicaState.RUNNING, replica)
            else:
                logger.warning(f"Replica {replica.replica_tag} of backend "
                               f"{backend_tag} failed health check, "
                               "stopping it.")
                replica.set_should_stop(0)
                replicas.add(ReplicaState.SHOULD_STOP, replica)

        for replica in replicas.pop(states=[ReplicaState.SHOULD_START]):
            replica.start(self._backend_metadata[backend_tag])
            replicas.add(ReplicaState.STARTING, replica)

        for replica in replicas.pop(states=[ReplicaState.SHOULD_STOP]):
            # This replica should be taken off the handle's replica set.
            transitioned_backend_tags.add(backend_tag)
            replica.stop()
            replicas.add(ReplicaState.STOPPING, replica)

        for replica in replicas.pop(states=[ReplicaState.STARTING]):
            if replica.check_started():
                # This replica should now be added to the handle's
                # replica set.
                replicas.add(ReplicaState.RUNNING, replica)
                transitioned_backend_tags.add(backend_tag)
            else:
                replicas.add(ReplicaState.STARTING, replica)

        for replica in replicas.pop(states=[ReplicaState.STOPPING]):
            if not replica.check_stopped():
                replicas.add(ReplicaState.STOPPING, replica)

    if len(transitioned_backend_tags) > 0:
        self._checkpoint()
        for tag in transitioned_backend_tags:
            self._notify_replica_handles_changed(tag)
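# Toy sketch of the pop-then-re-add container pattern update() relies on:
# pop() drains all replicas in the given states, so each one must be
# explicitly re-added after inspection, which makes dropped transitions hard
# to miss. All names here are placeholders for illustration.
from collections import defaultdict

class ToyReplicaSet:
    def __init__(self):
        self._by_state = defaultdict(list)

    def add(self, state, replica):
        self._by_state[state].append(replica)

    def pop(self, states):
        drained = []
        for state in states:
            drained.extend(self._by_state.pop(state, []))
        return drained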
def _checkpoint(self) -> None:
    """Checkpoint internal state and write it to the KV store."""
    assert self.write_lock.locked()
    logger.debug("Writing checkpoint")
    start = time.time()

    checkpoint = pickle.dumps(
        Checkpoint(self.configuration_store, self.actor_reconciler))

    self.kv_store.put(CHECKPOINT_KEY, checkpoint)
    logger.debug("Wrote checkpoint in {:.2f}s".format(time.time() - start))

    if (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
            and self.detached):
        logger.warning("Intentionally crashing after checkpoint")
        os._exit(0)
async def backend_control_loop(self):
    start = time.time()
    prev_warning = start
    need_to_continue = True
    while need_to_continue:
        if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
            prev_warning = time.time()
            logger.warning("Waited {:.2f}s for replicas to start up. Make "
                           "sure there are enough resources to create the "
                           "replicas.".format(time.time() - start))

        need_to_continue = (
            await self._check_currently_starting_replicas()
            or await self._check_currently_stopping_replicas())

        # The sleep must be awaited; a bare asyncio.sleep(1) creates an
        # un-awaited coroutine and never actually pauses the loop.
        await asyncio.sleep(1)
def _checkpoint(self) -> None:
    """Checkpoint internal state and write it to the KV store."""
    assert self.write_lock.locked()
    logger.debug("Writing checkpoint")
    start = time.time()

    checkpoint = pickle.dumps(
        Checkpoint(self.backend_state.checkpoint(),
                   self._serializable_inflight_results))

    self.kv_store.put(CHECKPOINT_KEY, checkpoint)
    logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))

    if (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
            and self.detached):
        logger.warning("Intentionally crashing after checkpoint")
        os._exit(0)
def _checkpoint(self):
    """Checkpoint internal state and write it to the KV store."""
    assert self.write_lock.locked()
    logger.debug("Writing checkpoint")
    start = time.time()
    checkpoint = pickle.dumps(
        (self.routes, list(self.routers.keys()), self.backends,
         self.traffic_policies, self.replicas, self.replicas_to_start,
         self.replicas_to_stop, self.backends_to_remove,
         self.endpoints_to_remove))

    self.kv_store.put(CHECKPOINT_KEY, checkpoint)
    logger.debug("Wrote checkpoint in {:.2f}s".format(time.time() - start))

    if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
        logger.warning("Intentionally crashing after checkpoint")
        os._exit(0)
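# Sketch of the fault-injection knob shared by the _checkpoint() variants
# here: a crash probability (normally 0) makes the controller hard-exit
# right after persisting state, so checkpoint recovery gets exercised under
# test. Reading it from an env var is an assumption for illustration, as is
# the variable name.
import os
import random

_CRASH_AFTER_CHECKPOINT_PROBABILITY = float(
    os.environ.get("SERVE_CRASH_AFTER_CHECKPOINT_PROB", "0"))

def maybe_crash_after_checkpoint():
    if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
        os._exit(0)  # hard exit, skipping cleanup, to simulate a crash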
def get(self, key: str) -> Optional[bytes]:
    """Get the value associated with the given key from the store.

    Args:
        key (str)

    Returns:
        The bytes value. If the key wasn't found, returns None.
    """
    if not isinstance(key, str):
        raise TypeError("key must be a string, got: {}.".format(type(key)))

    try:
        blob = self._bucket.blob(blob_name=self.get_storage_key(key))
        return blob.download_as_bytes()
    except NotFound:
        logger.warning(f"No such key in GCS for key = {key}")
        return None
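# Companion sketch to the GCS-backed get() above, showing how the bucket
# object is typically obtained. google-cloud-storage's Client, bucket, blob,
# and download_as_bytes are real APIs; the bucket and key names are
# placeholders.
from google.cloud import storage
from google.cloud.exceptions import NotFound

bucket = storage.Client().bucket("my-serve-bucket")
try:
    data = bucket.blob("serve/checkpoint").download_as_bytes()
except NotFound:
    data = None  # a missing key behaves like a cache miss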
def _validate_batch_size(self):
    if (self.max_batch_size is not None
            and not self.internal_metadata.accepts_batches
            and self.max_batch_size > 1):
        raise ValueError(
            "max_batch_size is set in config but the function or "
            "method does not accept batching. Please use "
            "@serve.accept_batch to explicitly mark that the function or "
            "method accepts a list of requests as an argument.")

    if self.max_batch_size is not None:
        logger.warning(
            "Setting max_batch_size and batch_wait_timeout in the "
            "BackendConfig is deprecated in favor of using the "
            "@serve.batch decorator at the application level. Please see "
            "the documentation for details: "
            "https://docs.ray.io/en/master/serve/ml-models.html#request-batching."  # noqa:E501
        )
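# Hedged sketch of the @serve.batch replacement that the warning above
# points to: the decorated method receives a list of requests and must
# return a list of results of equal length. The class name and parameter
# values are placeholders; @serve.batch is the real Ray Serve decorator.
from ray import serve

class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def handle_batch(self, requests):
        # One element in, one element out, in the same order.
        return [r.query_params.get("val") for r in requests]

    async def __call__(self, request):
        # Called with a single request; batching happens transparently.
        return await self.handle_batch(request)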
def get_handle(
        self,
        endpoint_name: str,
        missing_ok: Optional[bool] = False,
        sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.
        sync (bool): If true, then Serve will return a ServeHandle that
            works everywhere. Otherwise, Serve will return a ServeHandle
            that's only usable in an asyncio loop.

    Returns:
        RayServeHandle
    """
    if not missing_ok and endpoint_name not in ray.get(
            self._controller.get_all_endpoints.remote()):
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    if asyncio.get_event_loop().is_running() and sync:
        logger.warning(
            "You are retrieving a sync handle inside an asyncio loop. "
            "Try getting client.get_handle(.., sync=False) to get better "
            "performance. Learn more at https://docs.ray.io/en/master/"
            "serve/advanced.html#sync-and-async-handles")

    if not asyncio.get_event_loop().is_running() and not sync:
        logger.warning(
            "You are retrieving an async handle outside an asyncio loop. "
            "You should make sure client.get_handle is called inside a "
            "running event loop. Or call client.get_handle(.., sync=True) "
            "to create a sync handle. Learn more at https://docs.ray.io/en/"
            "master/serve/advanced.html#sync-and-async-handles")

    if sync:
        handle = RayServeSyncHandle(
            self._get_proxied_router(sync=sync), endpoint_name)
    else:
        handle = RayServeHandle(
            self._get_proxied_router(sync=sync), endpoint_name)

    return handle
async def update_actor_state(self, start_time: float) -> bool:
    """Returns whether the number of starting or stopping replicas changed."""
    num_starting = len(self.currently_starting_replicas)
    num_stopping = len(self.currently_stopping_replicas)

    num_pending_starts = await self._check_currently_starting_replicas()
    num_pending_stops = await self._check_currently_stopping_replicas()

    time_running = int(time.time() - start_time)
    if (time_running > 0
            and time_running % REPLICA_STARTUP_TIME_WARNING_S == 0):
        delta = time.time() - start_time
        logger.warning(
            f"Waited {delta:.2f}s for {num_pending_starts} replicas "
            f"to start up or {num_pending_stops} replicas to shut down."
            " Make sure there are enough resources to create the "
            "replicas.")

    return (len(self.currently_starting_replicas) != num_starting
            or len(self.currently_stopping_replicas) != num_stopping)
async def __call__(self, scope, receive, send):
    http_body_bytes = await self.receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}

    # Modify the path and root path so that reverse lookups and redirection
    # work as expected. We do this here instead of in replicas so it can be
    # changed without restarting the replicas.
    scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
    scope["root_path"] = self.path_prefix

    starlette_request = build_starlette_request(scope, http_body_bytes)
    handle = self.handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_ACTOR_FAILURE_RETRIES:
        object_ref = await handle.remote(starlette_request)
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to actor failure. There are "
                f"{MAX_ACTOR_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(
            error_message, status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
def _process_update(self, updates: Dict[str, UpdatedObject]):
    if isinstance(updates, (ray.exceptions.RayActorError)):
        # This can happen during shutdown where the controller is
        # intentionally killed, the client should just gracefully
        # exit.
        logger.debug("LongPollClient failed to connect to host. Shutting down.")
        self.is_running = False
        return

    if isinstance(updates, ConnectionError):
        logger.warning("LongPollClient connection failed, shutting down.")
        self.is_running = False
        return

    if isinstance(updates, (ray.exceptions.RayTaskError)):
        if isinstance(updates.as_instanceof_cause(), (asyncio.TimeoutError)):
            logger.debug("LongPollClient polling timed out. Retrying.")
        else:
            # Some error happened in the controller. It could be a bug or
            # some undesired state.
            logger.error("LongPollHost errored\n" + updates.traceback_str)
        # We must call this in the event loop so it works in Ray Client.
        # See https://github.com/ray-project/ray/issues/20971
        self._schedule_to_event_loop(self._poll_next)
        return

    logger.debug(
        f"LongPollClient {self} received updates for keys: "
        f"{list(updates.keys())}."
    )
    for key, update in updates.items():
        self.object_snapshots[key] = update.object_snapshot
        self.snapshot_ids[key] = update.snapshot_id
        callback = self.key_listeners[key]

        # Bind the parameters because closures are late-binding.
        # https://docs.python-guide.org/writing/gotchas/#late-binding-closures  # noqa: E501
        def chained(callback=callback, arg=update.object_snapshot):
            callback(arg)
            self._on_callback_completed(trigger_at=len(updates))

        self._schedule_to_event_loop(chained)
async def backend_control_loop(self):
    start = time.time()
    prev_warning = start
    need_to_continue = True
    num_pending_starts, num_pending_stops = 0, 0
    while need_to_continue:
        if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
            prev_warning = time.time()
            delta = time.time() - start
            logger.warning(
                f"Waited {delta:.2f}s for {num_pending_starts} replicas "
                f"to start up or {num_pending_stops} replicas to shut down."
                " Make sure there are enough resources to create the "
                "replicas.")

        num_pending_starts = await self._check_currently_starting_replicas()
        num_pending_stops = await self._check_currently_stopping_replicas()
        need_to_continue = num_pending_starts or num_pending_stops

        # The sleep must be awaited; a bare asyncio.sleep(1) creates an
        # un-awaited coroutine and never actually pauses the loop.
        await asyncio.sleep(1)
def check_started(self):
    if self._state == ReplicaState.RUNNING:
        return True
    assert self._state == ReplicaState.STARTING, (
        f"State must be {ReplicaState.STARTING}, *not* {self._state}")

    ready, _ = ray.wait([self._startup_obj_ref], timeout=0)
    if len(ready) == 1:
        self._state = ReplicaState.RUNNING
        return True

    time_since_start = time.time() - self._start_time
    if (time_since_start > SLOW_STARTUP_WARNING_S
            and time.time() - self._prev_slow_startup_warning_time >
            SLOW_STARTUP_WARNING_PERIOD_S):
        logger.warning(
            f"Replica '{self._replica_tag}' for backend "
            f"'{self._backend_tag}' has taken more than "
            f"{time_since_start:.0f}s to start up. This may be "
            "caused by waiting for the cluster to auto-scale or "
            "because the backend constructor is slow.")
        self._prev_slow_startup_warning_time = time.time()

    return False
async def __call__(self, scope, receive, send):
    http_body_bytes = await self.receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_ACTOR_FAILURE_RETRIES:
        object_ref = await self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
                                  DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers).remote(
                build_starlette_request(scope, http_body_bytes))
        try:
            result = await object_ref
            break
        except RayActorError:
            logger.warning(
                "Request failed due to actor failure. There are "
                f"{MAX_ACTOR_FAILURE_RETRIES - retries} retries "
                "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
            retries += 1

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(
            error_message, status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)