Exemple #1
0
    def check_started(self):
        """Report whether this replica has finished starting up.

        Returns True when the replica is already RUNNING, or when its startup
        object ref is ready (in which case the state is set to RUNNING).
        Otherwise returns False. Once startup has taken longer than
        SLOW_STARTUP_WARNING_S, a warning listing required and available
        resources is logged, at most once per SLOW_STARTUP_WARNING_PERIOD_S.
        """
        if self._state == ReplicaState.RUNNING:
            return True
        assert self._state == ReplicaState.STARTING, (
            f"State must be {ReplicaState.STARTING}, *not* {self._state}")

        # Poll the startup object ref with a zero timeout.
        finished_refs, _ = ray.wait([self._startup_obj_ref], timeout=0)
        if len(finished_refs) == 1:
            self._state = ReplicaState.RUNNING
            return True

        time_since_start = time.time() - self._start_time
        if time_since_start > SLOW_STARTUP_WARNING_S:
            since_last_warning = (
                time.time() - self._prev_slow_startup_warning_time)
            if since_last_warning > SLOW_STARTUP_WARNING_PERIOD_S:
                # Only report resources the actor actually asks for.
                required = {}
                for resource, amount in self._actor_resources.items():
                    if amount > 0:
                        required[resource] = amount

                available = {}
                for resource, amount in ray.available_resources().items():
                    if resource in required:
                        available[resource] = amount

                logger.warning(
                    f"Replica '{self._replica_tag}' for backend "
                    f"'{self._backend_tag}' has taken more than "
                    f"{time_since_start:.0f}s to start up. This may be "
                    "caused by waiting for the cluster to auto-scale or "
                    "because the backend constructor is slow. Resources "
                    "required: "
                    f"{required}, resources available: {available}.")
                self._prev_slow_startup_warning_time = time.time()

        return False
Exemple #2
0
    async def _start_pending_backend_replicas(
            self, current_state: SystemState) -> None:
        """Launch every replica queued in self.backend_replicas_to_start.

        Each queued replica is started, then this coroutine waits until the
        ``ready`` call of every replica has completed, recording each handle
        in self.backend_replicas as it finishes. While waiting, a warning is
        logged whenever more than REPLICA_STARTUP_TIME_WARNING_S has passed
        since the previous warning. Finally self.backend_replicas_to_start
        is cleared.
        """
        # Maps each "ready" future to (backend_tag, replica_tag, handle).
        awaiting_ready = {}
        queued = self.backend_replicas_to_start
        for backend_tag, replica_tags in queued.items():
            for replica_tag in replica_tags:
                handle = await self._start_backend_replica(
                    current_state, backend_tag, replica_tag)
                ready_fut = handle.ready.remote().as_future()
                awaiting_ready[ready_fut] = (backend_tag, replica_tag, handle)

        wait_began = time.time()
        last_warned = wait_began
        while awaiting_ready:
            if time.time() - last_warned > REPLICA_STARTUP_TIME_WARNING_S:
                last_warned = time.time()
                waited_s = time.time() - wait_began
                logger.warning("Waited {:.2f}s for replicas to start up. Make "
                               "sure there are enough resources to create the "
                               "replicas.".format(waited_s))

            # Wake up at least once a second so the warning above can fire.
            finished, _ = await asyncio.wait(list(awaiting_ready), timeout=1)
            for ready_fut in finished:
                backend_tag, replica_tag, handle = awaiting_ready.pop(
                    ready_fut)
                self.backend_replicas[backend_tag][replica_tag] = handle

        self.backend_replicas_to_start.clear()
Exemple #3
0
    def check_started(self) -> bool:
        """Poll the replica's startup; move it to RUNNING once it is ready.

        Returns:
            True if the replica is already RUNNING or its actor reports
            ready (the state is then set to RUNNING); False while it is still
            starting. After SLOW_STARTUP_WARNING_S of waiting, a warning with
            the resource requirements is logged, at most once per
            SLOW_STARTUP_WARNING_PERIOD_S.

        NOTE: the original author's note says this should handle the case
        where the replica has already stopped; the code below asserts that
        the state is STARTING when it is not RUNNING.
        """
        if self._state == ReplicaState.RUNNING:
            return True
        assert self._state == ReplicaState.STARTING, (
            f"State must be {ReplicaState.STARTING}, *not* {self._state}")

        actor_ready = self._actor.check_ready()
        if actor_ready:
            self._state = ReplicaState.RUNNING
            return True

        time_since_start = time.time() - self._start_time
        if time_since_start > SLOW_STARTUP_WARNING_S:
            since_last_warning = (
                time.time() - self._prev_slow_startup_warning_time)
            if since_last_warning > SLOW_STARTUP_WARNING_PERIOD_S:
                required, available = self._actor.resource_requirements()
                logger.warning(
                    f"Replica '{self._replica_tag}' for backend "
                    f"'{self._backend_tag}' has taken more than "
                    f"{time_since_start:.0f}s to start up. This may be "
                    "caused by waiting for the cluster to auto-scale or "
                    "because the backend constructor is slow. Resources "
                    "required: "
                    f"{required}, resources available: {available}.")
                self._prev_slow_startup_warning_time = time.time()

        return False
Exemple #4
0
    def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under ``key`` from the S3 bucket.

        Args:
            key (str): Key to look up; it is mapped to an S3 object key via
                ``self.get_storage_key``.

        Returns:
            The object's bytes, or None when S3 reports ``NoSuchKey``.

        Raises:
            TypeError: If ``key`` is not a string.
            ClientError: Re-raised for any S3 error other than ``NoSuchKey``.
        """
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))

        try:
            s3_object = self._s3.get_object(
                Bucket=self._bucket, Key=self.get_storage_key(key))
            return s3_object["Body"].read()
        except ClientError as e:
            error_info = e.response["Error"]
            if error_info["Code"] != "NoSuchKey":
                message = error_info["Message"]
                logger.error("Encountered ClientError while calling get() "
                             f"in RayExternalKVStore: {message}")
                raise e
            # A missing key is an expected outcome, not a failure.
            logger.warning(f"No such key in s3 for key = {key}")
            return None
Exemple #5
0
    def get_handle(self,
                   endpoint_name: str,
                   missing_ok: Optional[bool] = False) -> RayServeHandle:
        """Retrieve RayServeHandle for service endpoint to invoke it from Python.

        A router on the same node as the caller is preferred; if none is
        found, a warning is logged and a random router is used.

        Args:
            endpoint_name (str): A registered service endpoint.
            missing_ok (bool): If true, then Serve won't check the endpoint is
                registered. False by default.

        Returns:
            RayServeHandle

        Raises:
            KeyError: If ``missing_ok`` is false and the endpoint is not
                registered with the controller.
        """
        if not missing_ok:
            known_endpoints = ray.get(
                self._controller.get_all_endpoints.remote())
            if endpoint_name not in known_endpoints:
                raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

        routers = list(ray.get(self._controller.get_routers.remote()).values())
        current_node_id = ray.get_runtime_context().node_id.hex()

        # Take the first router that lives on this node, if any.
        for candidate in routers:
            if get_node_id_for_actor(candidate) == current_node_id:
                router_chosen = candidate
                break
        else:
            logger.warning(
                f"When getting a handle for {endpoint_name}, Serve can't find "
                "a router on the same node. Serve will use a random router.")
            router_chosen = random.choice(routers)

        return RayServeHandle(router_chosen, endpoint_name)
Exemple #6
0
    def _get_target_nodes(self) -> List[Tuple[str, str]]:
        """Return the list of (id, resource_key) to deploy HTTP servers on.

        The result depends on ``self._config.location``:

        * ``DeploymentMode.NoServer``: an empty list.
        * ``DeploymentMode.HeadOnly``: at most one node, the first whose
          resource key equals the current node's resource key.
        * ``DeploymentMode.FixedNumber``: a deterministic sample of
          ``self._config.fixed_number_replicas`` nodes (capped at the number
          of nodes, with a warning), seeded by
          ``self._config.fixed_number_selection_seed``.
        * Any other mode: every node returned by ``get_all_node_ids()``.
        """
        location = self._config.location
        target_nodes = get_all_node_ids()

        if location == DeploymentMode.NoServer:
            return []

        if location == DeploymentMode.HeadOnly:
            head_node_resource_key = get_current_node_resource_key()
            return [(node_id, node_resource)
                    for node_id, node_resource in target_nodes
                    if node_resource == head_node_resource_key][:1]

        if location == DeploymentMode.FixedNumber:
            num_replicas = self._config.fixed_number_replicas
            if num_replicas > len(target_nodes):
                logger.warning(
                    "You specified fixed_number_replicas="
                    f"{num_replicas} but there are only "
                    f"{len(target_nodes)} total nodes. Serve will start one "
                    "HTTP proxy per node.")
                num_replicas = len(target_nodes)

            # Use a private, seeded generator so the sample is deterministic
            # (it always returns the same set of nodes) without reseeding the
            # process-wide `random` module state as a side effect. A seeded
            # random.Random yields the same sample as random.seed() followed
            # by random.sample().
            rng = random.Random(self._config.fixed_number_selection_seed)
            return rng.sample(sorted(target_nodes), k=num_replicas)

        return target_nodes
Exemple #7
0
    def shutdown(self) -> None:
        """Completely shut down the connected Serve instance.

        Asks the controller to shut down, waits for each goal it returns,
        kills the controller actor, then polls until the controller's named
        actor entry is gone (giving up with a warning after about 5 seconds).
        Does nothing if this client was already shut down or Ray is not
        initialized.
        """
        if self._shutdown or not ray.is_initialized():
            return

        pending_goal_ids = ray.get(self._controller.shutdown.remote())
        for goal_id in pending_goal_ids:
            self._wait_for_goal(goal_id)

        ray.kill(self._controller, no_restart=True)

        # Wait for the named actor entry to be removed as well.
        wait_began = time.time()
        while True:
            try:
                namespace = _get_controller_namespace(self._detached)
                ray.get_actor(self._controller_name, namespace=namespace)
                waited_s = time.time() - wait_began
                if waited_s > 5:
                    logger.warning(
                        "Waited 5s for Serve to shutdown gracefully but "
                        "the controller is still not cleaned up. "
                        "You can ignore this warning if you are shutting "
                        "down the Ray cluster.")
                    break
            except ValueError:  # actor name is removed
                break

        self._shutdown = True
Exemple #8
0
    def get_handle(self,
                   endpoint_name: str,
                   missing_ok: Optional[bool] = False,
                   sync: bool = True) -> RayServeHandle:
        """Retrieve RayServeHandle for service endpoint to invoke it from Python.

        Handles are cached per (endpoint_name, sync) pair, so a sync handle
        is never returned for a ``sync=False`` request or vice versa.

        Args:
            endpoint_name (str): A registered service endpoint.
            missing_ok (bool): If true, then Serve won't check the endpoint is
                registered. False by default.
            sync (bool): If true, then Serve will return a ServeHandle that
                works everywhere. Otherwise, Serve will return a ServeHandle
                that's only usable in asyncio loop.

        Returns:
            RayServeHandle

        Raises:
            KeyError: If ``missing_ok`` is false and the endpoint is not
                registered with the controller.
        """
        if not missing_ok and endpoint_name not in ray.get(
                self._controller.get_all_endpoints.remote()):
            raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

        if asyncio.get_event_loop().is_running() and sync:
            logger.warning(
                "You are retrieving a ServeHandle inside an asyncio loop. "
                "Try getting client.get_handle(.., sync=False) to get better "
                "performance.")

        # The cache key must include `sync`: keying on the endpoint name alone
        # would hand back a handle built with the other `sync` setting.
        # NOTE(review): any other code that reads or evicts _handle_cache by
        # bare endpoint name must use the same (endpoint_name, sync) key --
        # confirm against the rest of the class.
        cache_key = (endpoint_name, sync)
        if cache_key not in self._handle_cache:
            self._handle_cache[cache_key] = RayServeHandle(
                self._controller, endpoint_name, sync=sync)
        return self._handle_cache[cache_key]
Exemple #9
0
def _check_http_and_checkpoint_options(
        client: Client,
        http_options: Union[dict, HTTPOptions],
        checkpoint_path: str,
) -> None:
    """Warn when requested options differ from those of an existing client.

    Nothing on ``client`` is modified; mismatches are only logged.

    Args:
        client: Existing client whose ``checkpoint_path`` and ``http_config``
            are compared against.
        http_options: Requested HTTP options, either an HTTPOptions instance
            or a dict that is parsed with ``HTTPOptions.parse_obj``. Skipped
            when falsy.
        checkpoint_path: Requested checkpoint path. Skipped when falsy.
    """
    if checkpoint_path and checkpoint_path != client.checkpoint_path:
        logger.warning(
            f"The new client checkpoint path '{checkpoint_path}' "
            f"is different from the existing one '{client.checkpoint_path}'. "
            "The new checkpoint path is ignored.")

    if not http_options:
        return

    existing_options = client.http_config
    if isinstance(http_options, HTTPOptions):
        requested_options = http_options
    else:
        requested_options = HTTPOptions.parse_obj(http_options)

    # Compare field by field so the warning can name what differs.
    different_fields = [
        field for field in requested_options.__dict__
        if getattr(requested_options, field) != getattr(
            existing_options, field)
    ]

    if different_fields:
        logger.warning(
            "The new client HTTP config differs from the existing one "
            f"in the following fields: {different_fields}. "
            "The new HTTP config is ignored.")
Exemple #10
0
    def get_handle(
            self,
            endpoint_name: str,
            missing_ok: Optional[bool] = False,
            sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
        """Retrieve RayServeHandle for service endpoint to invoke it from Python.

        Args:
            endpoint_name (str): A registered service endpoint.
            missing_ok (bool): If true, then Serve won't check the endpoint is
                registered. False by default.
            sync (bool): If true, then Serve will return a ServeHandle that
                works everywhere. Otherwise, Serve will return a ServeHandle
                that's only usable in asyncio loop.

        Returns:
            A RayServeSyncHandle when ``sync`` is true, otherwise a
            RayServeHandle.

        Raises:
            KeyError: If ``missing_ok`` is false and the endpoint is not
                registered with the controller.
        """
        all_endpoints = ray.get(self._controller.get_all_endpoints.remote())
        endpoint_known = endpoint_name in all_endpoints
        if not missing_ok and not endpoint_known:
            raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

        # Warn when the requested handle flavour does not match the context
        # it is being requested from.
        in_running_loop = asyncio.get_event_loop().is_running()
        if in_running_loop and sync:
            logger.warning(
                "You are retrieving a sync handle inside an asyncio loop. "
                "Try getting client.get_handle(.., sync=False) to get better "
                "performance. Learn more at https://docs.ray.io/en/master/"
                "serve/http-servehandle.html#sync-and-async-handles")
        elif not in_running_loop and not sync:
            logger.warning(
                "You are retrieving an async handle outside an asyncio loop. "
                "You should make sure client.get_handle is called inside a "
                "running event loop. Or call client.get_handle(.., sync=True) "
                "to create sync handle. Learn more at https://docs.ray.io/en/"
                "master/serve/http-servehandle.html#sync-and-async-handles")

        # An unknown endpoint can only get here in the missing_ok=True case.
        # Then handle.method_name.remote won't work and the user must use the
        # legacy handle.options(method).remote().
        python_methods: List[str] = (
            all_endpoints[endpoint_name]["python_methods"]
            if endpoint_known else [])

        # NOTE(simon): this extra layer of router seems unnecessary
        # BUT it's needed still because of the shared asyncio thread.
        router = self._get_proxied_router(sync=sync, endpoint=endpoint_name)
        handle_cls = RayServeSyncHandle if sync else RayServeHandle
        return handle_cls(
            router, endpoint_name, known_python_methods=python_methods)
Exemple #11
0
 def check(self, *args, _internal=False, **kwargs):
     """Guard a client API call, then delegate to the wrapped ``f``.

     Raises RayServeException if the client has been shut down. Unless
     ``_internal`` is true, a deprecation warning for the client-based API
     is logged before ``f`` is called with the remaining arguments.
     """
     if self._shutdown:
         raise RayServeException("Client has already been shut down.")

     called_by_user = not _internal
     if called_by_user:
         deprecation_notice = (
             "The client-based API is being deprecated in favor of global "
             "API calls (e.g., `serve.create_backend()`). Please replace "
             "all instances of `client.api_call()` with "
             "`serve.api_call()`.")
         logger.warning(deprecation_notice)

     return f(self, *args, **kwargs)
Exemple #12
0
 async def retry_method(*args, **kwargs):
     """Call the wrapped ``f`` until it stops returning a RayActorError.

     ``f`` and ``name`` come from the enclosing scope. Whenever the awaited
     result is a ``ray.exceptions.RayActorError`` instance, a warning is
     logged and the call is retried after 100ms; any other result is
     returned as-is.
     """
     outcome = await f(*args, **kwargs)
     while isinstance(outcome, ray.exceptions.RayActorError):
         logger.warning(
             "Actor method '{}' failed, retrying after 100ms.".format(name))
         await asyncio.sleep(0.1)
         outcome = await f(*args, **kwargs)
     return outcome
Exemple #13
0
    async def __call__(self, scope, receive, send):
        """ASGI entry point: forward the HTTP request to the backend handle.

        Reads the request body, strips the route prefix from the path, sends
        the request through ``self.handle`` and writes the result back as an
        HTTP response. The request is retried with exponential backoff when
        the replica fails (RayActorError), up to MAX_REPLICA_FAILURE_RETRIES
        attempts; if every attempt fails, a 500 response is sent.

        Args:
            scope: ASGI connection scope; it is modified in place.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}

        # scope["router"] and scope["endpoint"] contain references to a router
        # and endpoint object, respectively, which each in turn contain a
        # reference to the Serve client, which cannot be serialized.
        # The solution is to delete these from scope, as they will not be used.
        # pop() with a default keeps this safe when a key is absent.
        scope.pop("router", None)
        scope.pop("endpoint", None)

        # Modify the path and root path so that reverse lookups and redirection
        # work as expected. We do this here instead of in replicas so it can be
        # changed without restarting the replicas.
        scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
        scope["root_path"] = self.path_prefix
        handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)

        # NOTE(edoakes): it's important that we defer building the starlette
        # request until it reaches the backend replica to avoid unnecessary
        # serialization cost, so we use a simple dataclass here.
        request = HTTPRequestWrapper(scope, http_body_bytes)

        retries = 0
        backoff_time_s = 0.05
        while retries < MAX_REPLICA_FAILURE_RETRIES:
            object_ref = await handle.remote(request)
            try:
                result = await object_ref
                break
            except RayActorError:
                # Count this failed attempt first so the logged number is the
                # number of attempts actually left.
                retries += 1
                logger.warning(
                    "Request failed due to replica failure. There are "
                    f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                backoff_time_s *= 2
        else:
            # Every attempt failed, so `result` was never assigned. Answer
            # with an explicit 500 instead of hitting an UnboundLocalError.
            error_message = ("Request failed due to replica failure after "
                             f"{MAX_REPLICA_FAILURE_RETRIES} attempts.")
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)
Exemple #14
0
    def _process_update(self, updates: Dict[str, UpdatedObject]):
        """Handle the outcome of one long-poll request.

        ``updates`` is normally a mapping of key -> UpdatedObject, but it may
        instead be an exception instance:

        * RayActorError or ConnectionError: stop the client.
        * RayTaskError: log (debug for a timeout, error otherwise) and poll
          again.

        For a real update, the snapshot and snapshot id of each key are
        stored and the key's listener is invoked with the new snapshot,
        either inline or on ``self.event_loop``.
        """
        if isinstance(updates, ray.exceptions.RayActorError):
            # This can happen during shutdown where the controller is
            # intentionally killed, the client should just gracefully
            # exit.
            logger.debug("LongPollClient failed to connect to host. "
                         "Shutting down.")
            self.is_running = False
            return

        if isinstance(updates, ConnectionError):
            logger.warning("LongPollClient connection failed, shutting down.")
            self.is_running = False
            return

        if isinstance(updates, ray.exceptions.RayTaskError):
            cause = updates.as_instanceof_cause()
            if isinstance(cause, asyncio.TimeoutError):
                logger.debug("LongPollClient polling timed out. Retrying.")
            else:
                # Some error happened in the controller. It could be a bug or
                # some undesired state.
                logger.error("LongPollHost errored\n" + updates.traceback_str)
            self._poll_next()
            return

        logger.debug(f"LongPollClient {self} received updates for keys: "
                     f"{list(updates.keys())}.")
        for key, update in updates.items():
            snapshot = update.object_snapshot
            self.object_snapshots[key] = snapshot
            self.snapshot_ids[key] = update.snapshot_id
            listener = self.key_listeners[key]

            # Bind the parameters because closures are late-binding.
            # https://docs.python-guide.org/writing/gotchas/#late-binding-closures # noqa: E501
            def run_listener(listener=listener, snapshot=snapshot):
                listener(snapshot)
                self._on_callback_completed(trigger_at=len(updates))

            loop = self.event_loop
            if loop is None:
                run_listener()
            elif loop.is_running():
                # Schedule the next iteration only if the loop is running.
                # The event loop might not be running if users used a cached
                # version across loops.
                loop.call_soon_threadsafe(run_listener)
            else:
                logger.error(
                    "The event loop is closed, shutting down long poll "
                    "client.")
                self.is_running = False
Exemple #15
0
async def _send_request_to_handle(handle, scope, receive, send):
    """Forward one HTTP request to ``handle`` and send back the response.

    Reads the request body, configures the handle from the request headers,
    pickles the request and sends it through the handle. The request is
    retried with exponential backoff when the replica fails (RayActorError),
    up to MAX_REPLICA_FAILURE_RETRIES attempts; if every attempt fails, a 500
    response is sent.

    Args:
        handle: Serve handle used to reach the backend.
        scope: ASGI connection scope; it is modified in place.
        receive: ASGI receive callable.
        send: ASGI send callable.
    """
    http_body_bytes = await receive_http_body(scope, receive, send)

    headers = {k.decode(): v.decode() for k, v in scope["headers"]}
    handle = handle.options(
        method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
        shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
        http_method=scope["method"].upper(),
        http_headers=headers,
    )

    # scope["router"] and scope["endpoint"] contain references to a router
    # and endpoint object, respectively, which each in turn contain a
    # reference to the Serve client, which cannot be serialized.
    # The solution is to delete these from scope, as they will not be used.
    # TODO(edoakes): this can be removed once we deprecate the old API.
    scope.pop("router", None)
    scope.pop("endpoint", None)

    # NOTE(edoakes): it's important that we defer building the starlette
    # request until it reaches the backend replica to avoid unnecessary
    # serialization cost, so we use a simple dataclass here.
    request = HTTPRequestWrapper(scope, http_body_bytes)
    # Perform a pickle here to improve latency. Stdlib pickle for simple
    # dataclasses are 10-100x faster than cloudpickle.
    request = pickle.dumps(request)

    retries = 0
    backoff_time_s = 0.05
    while retries < MAX_REPLICA_FAILURE_RETRIES:
        object_ref = await handle.remote(request)
        try:
            result = await object_ref
            break
        except RayActorError:
            # Count this failed attempt first so the logged number is the
            # number of attempts actually left.
            retries += 1
            logger.warning("Request failed due to replica failure. There are "
                           f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
                           "remaining.")
            await asyncio.sleep(backoff_time_s)
            backoff_time_s *= 2
    else:
        # Every attempt failed, so `result` was never assigned. Answer with
        # an explicit 500 instead of hitting an UnboundLocalError below.
        error_message = ("Request failed due to replica failure after "
                         f"{MAX_REPLICA_FAILURE_RETRIES} attempts.")
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
        return

    if isinstance(result, RayTaskError):
        error_message = "Task Error. Traceback: {}.".format(result)
        await Response(error_message,
                       status_code=500).send(scope, receive, send)
    elif isinstance(result, starlette.responses.Response):
        await result(scope, receive, send)
    else:
        await Response(result).send(scope, receive, send)
    def _checkpoint(self):
        """Pickle the internal state and store it in the KV store.

        The state is written under the key "checkpoint". After writing, the
        process exits on purpose with probability
        _CRASH_AFTER_CHECKPOINT_PROBABILITY (presumably a fault-injection
        hook for testing recovery -- confirm where the constant is defined).
        """
        logger.debug("Writing checkpoint")
        began = time.time()

        state = (self.routes, self.backends, self.traffic_policies,
                 self.replicas, self.replicas_to_start, self.replicas_to_stop)
        checkpoint = pickle.dumps(state)
        self.kv_store_client.put("checkpoint", checkpoint)

        elapsed = time.time() - began
        logger.debug("Wrote checkpoint in {:.2f}".format(elapsed))

        crash_now = random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
        if crash_now:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)
Exemple #17
0
    def update(self) -> None:
        """Updates the state of all running replicas to match the goal state.

        One pass does the following, in order:

        1. Scale all backends and complete any finished goals.
        2. For each backend, move its replicas through the state machine:
           RUNNING replicas that fail the health check go to SHOULD_STOP,
           SHOULD_START replicas are started (-> STARTING), SHOULD_STOP
           replicas are stopped (-> STOPPING), STARTING replicas that have
           started become RUNNING, and STOPPING replicas that have stopped
           are dropped.
        3. If any backend's set of serving replicas changed, checkpoint and
           notify that backend's handles.

        This method does not return a value.
        """
        self._scale_all_backends()

        for goal_id in self._completed_goals():
            self._goal_manager.complete_goal(goal_id)

        transitioned_backend_tags = set()
        for backend_tag, replicas in self._replicas.items():
            for replica in replicas.pop(states=[ReplicaState.RUNNING]):
                if replica.check_health():
                    replicas.add(ReplicaState.RUNNING, replica)
                else:
                    logger.warning(
                        f"Replica {replica.replica_tag} of backend "
                        f"{backend_tag} failed health check, stopping it.")
                    replica.set_should_stop(0)
                    replicas.add(ReplicaState.SHOULD_STOP, replica)

            for replica in replicas.pop(states=[ReplicaState.SHOULD_START]):
                replica.start(self._backend_metadata[backend_tag])
                replicas.add(ReplicaState.STARTING, replica)

            for replica in replicas.pop(states=[ReplicaState.SHOULD_STOP]):
                # This replica should be taken off handle's replica set.
                transitioned_backend_tags.add(backend_tag)
                replica.stop()
                replicas.add(ReplicaState.STOPPING, replica)

            for replica in replicas.pop(states=[ReplicaState.STARTING]):
                if replica.check_started():
                    # This replica should be now be added to handle's replica
                    # set.
                    replicas.add(ReplicaState.RUNNING, replica)
                    transitioned_backend_tags.add(backend_tag)
                else:
                    replicas.add(ReplicaState.STARTING, replica)

            for replica in replicas.pop(states=[ReplicaState.STOPPING]):
                # Replicas that have fully stopped are not re-added, which
                # removes them from tracking.
                if not replica.check_stopped():
                    replicas.add(ReplicaState.STOPPING, replica)

        if transitioned_backend_tags:
            self._checkpoint()
            # A plain loop: the notifications are wanted for their side
            # effects, not for a list of results.
            for tag in transitioned_backend_tags:
                self._notify_replica_handles_changed(tag)
Exemple #18
0
    def _checkpoint(self) -> None:
        """Pickle the internal state and store it in the KV store.

        Must be called with ``self.write_lock`` held. The configuration store
        and actor reconciler are wrapped in a Checkpoint and written under
        CHECKPOINT_KEY. After writing, a detached instance exits on purpose
        with probability _CRASH_AFTER_CHECKPOINT_PROBABILITY (presumably a
        fault-injection hook for testing recovery -- confirm where the
        constant is defined).
        """
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        began = time.time()

        snapshot = Checkpoint(self.configuration_store, self.actor_reconciler)
        checkpoint = pickle.dumps(snapshot)
        self.kv_store.put(CHECKPOINT_KEY, checkpoint)

        elapsed = time.time() - began
        logger.debug("Wrote checkpoint in {:.2f}".format(elapsed))

        crash_now = (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
                     and self.detached)
        if crash_now:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)
Exemple #19
0
    async def backend_control_loop(self):
        """Poll starting and stopping replicas until none remain pending.

        Each iteration checks the currently starting replicas and, only if
        that check returns a falsy value, the currently stopping replicas.
        The loop continues while either check returns a truthy value, pausing
        one second between iterations. A warning is logged whenever more than
        REPLICA_STARTUP_TIME_WARNING_S has passed since the previous warning.
        """
        start = time.time()
        prev_warning = start
        need_to_continue = True
        while need_to_continue:
            if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
                prev_warning = time.time()
                logger.warning("Waited {:.2f}s for replicas to start up. Make "
                               "sure there are enough resources to create the "
                               "replicas.".format(time.time() - start))

            need_to_continue = (
                await self._check_currently_starting_replicas()
                or await self._check_currently_stopping_replicas())

            # The sleep must be awaited: calling asyncio.sleep() without
            # `await` only creates a coroutine object that never runs, so the
            # loop would spin without pausing.
            await asyncio.sleep(1)
Exemple #20
0
    def _checkpoint(self) -> None:
        """Pickle the internal state and store it in the KV store.

        Must be called with ``self.write_lock`` held. The backend state's
        checkpoint and the serializable in-flight results are wrapped in a
        Checkpoint and written under CHECKPOINT_KEY. After writing, a
        detached instance exits on purpose with probability
        _CRASH_AFTER_CHECKPOINT_PROBABILITY (presumably a fault-injection
        hook for testing recovery -- confirm where the constant is defined).
        """
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        began = time.time()

        snapshot = Checkpoint(self.backend_state.checkpoint(),
                              self._serializable_inflight_results)
        checkpoint = pickle.dumps(snapshot)
        self.kv_store.put(CHECKPOINT_KEY, checkpoint)

        elapsed = time.time() - began
        logger.debug("Wrote checkpoint in {:.3f}s".format(elapsed))

        crash_now = (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
                     and self.detached)
        if crash_now:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)
Exemple #21
0
    def _checkpoint(self):
        """Pickle the internal state and store it in the KV store.

        Must be called with ``self.write_lock`` held. The state tuple is
        written under CHECKPOINT_KEY. After writing, the process exits on
        purpose with probability _CRASH_AFTER_CHECKPOINT_PROBABILITY
        (presumably a fault-injection hook for testing recovery -- confirm
        where the constant is defined).
        """
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        began = time.time()

        # Only the router keys are stored, not the router objects.
        state = (
            self.routes,
            list(self.routers.keys()),
            self.backends,
            self.traffic_policies,
            self.replicas,
            self.replicas_to_start,
            self.replicas_to_stop,
            self.backends_to_remove,
            self.endpoints_to_remove,
        )
        checkpoint = pickle.dumps(state)
        self.kv_store.put(CHECKPOINT_KEY, checkpoint)

        elapsed = time.time() - began
        logger.debug("Wrote checkpoint in {:.2f}".format(elapsed))

        crash_now = random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
        if crash_now:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)
Exemple #22
0
    def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under ``key`` from the GCS bucket.

        Args:
            key (str): Key to look up; it is mapped to a blob name via
                ``self.get_storage_key``.

        Returns:
            The blob's bytes, or None when the blob is not found.

        Raises:
            TypeError: If ``key`` is not a string.
        """
        if not isinstance(key, str):
            raise TypeError("key must be a string, got: {}.".format(type(key)))

        try:
            blob = self._bucket.blob(blob_name=self.get_storage_key(key))
            payload = blob.download_as_bytes()
        except NotFound:
            # A missing key is an expected outcome, not a failure.
            logger.warning(f"No such key in GCS for key = {key}")
            return None
        return payload
Exemple #23
0
    def _validate_batch_size(self):
        if (self.max_batch_size is not None
                and not self.internal_metadata.accepts_batches
                and self.max_batch_size > 1):
            raise ValueError(
                "max_batch_size is set in config but the function or "
                "method does not accept batching. Please use "
                "@serve.accept_batch to explicitly mark that the function or "
                "method accepts a list of requests as an argument.")

        if self.max_batch_size is not None:
            logger.warning(
                "Setting max_batch_size and batch_wait_timeout in the "
                "BackendConfig are deprecated in favor of using the "
                "@serve.batch decorator in the application level. Please see "
                "the documentation for details: "
                "https://docs.ray.io/en/master/serve/ml-models.html#request-batching."  # noqa:E501
            )
Exemple #24
0
    def get_handle(
            self,
            endpoint_name: str,
            missing_ok: Optional[bool] = False,
            sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
        """Retrieve RayServeHandle for service endpoint to invoke it from Python.

        Args:
            endpoint_name (str): A registered service endpoint.
            missing_ok (bool): If true, then Serve won't check the endpoint is
                registered. False by default.
            sync (bool): If true, then Serve will return a ServeHandle that
                works everywhere. Otherwise, Serve will return a ServeHandle
                that's only usable in asyncio loop.

        Returns:
            A RayServeSyncHandle when ``sync`` is true, otherwise a
            RayServeHandle.

        Raises:
            KeyError: If ``missing_ok`` is false and the endpoint is not
                registered with the controller.
        """
        if not missing_ok:
            known_endpoints = ray.get(
                self._controller.get_all_endpoints.remote())
            if endpoint_name not in known_endpoints:
                raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

        # Warn when the requested handle flavour does not match the context
        # it is being requested from.
        in_running_loop = asyncio.get_event_loop().is_running()
        if in_running_loop and sync:
            logger.warning(
                "You are retrieving a sync handle inside an asyncio loop. "
                "Try getting client.get_handle(.., sync=False) to get better "
                "performance. Learn more at https://docs.ray.io/en/master/"
                "serve/advanced.html#sync-and-async-handles")
        elif not in_running_loop and not sync:
            logger.warning(
                "You are retrieving an async handle outside an asyncio loop. "
                "You should make sure client.get_handle is called inside a "
                "running event loop. Or call client.get_handle(.., sync=True) "
                "to create sync handle. Learn more at https://docs.ray.io/en/"
                "master/serve/advanced.html#sync-and-async-handles")

        handle_cls = RayServeSyncHandle if sync else RayServeHandle
        return handle_cls(self._get_proxied_router(sync=sync), endpoint_name)
Exemple #25
0
    async def update_actor_state(self, start_time: float) -> bool:
        """Poll in-flight replica starts/stops and report whether they moved.

        Args:
            start_time: ``time.time()`` timestamp the wait is measured from;
                used only to decide when to emit the slow-progress warning.

        Returns:
            True if the size of ``currently_starting_replicas`` or of
            ``currently_stopping_replicas`` differs from its size on entry.
        """
        starting_before = len(self.currently_starting_replicas)
        stopping_before = len(self.currently_stopping_replicas)

        pending_starts = await self._check_currently_starting_replicas()
        pending_stops = await self._check_currently_stopping_replicas()

        # Warn only on whole-second marks that are a multiple of the
        # warning interval, and never at second zero.
        whole_seconds = int(time.time() - start_time)
        on_warning_mark = (
            whole_seconds > 0
            and whole_seconds % REPLICA_STARTUP_TIME_WARNING_S == 0)
        if on_warning_mark:
            waited = time.time() - start_time
            logger.warning(
                f"Waited {waited:.2f}s for {pending_starts} replicas "
                f"to start up or {pending_stops} replicas to shutdown."
                " Make sure there are enough resources to create the "
                "replicas.")

        if len(self.currently_starting_replicas) != starting_before:
            return True
        return len(self.currently_stopping_replicas) != stopping_before
Exemple #26
0
    async def __call__(self, scope, receive, send):
        """ASGI entry point: forward one HTTP request through the handle.

        Reads the request body, strips ``self.path_prefix`` from the request
        path, dispatches the request via ``self.handle`` and writes the
        outcome back to the client. A call that fails with ``RayActorError``
        is retried with exponential backoff for at most
        ``MAX_ACTOR_FAILURE_RETRIES`` attempts; when every attempt fails a
        500 response is sent.

        Args:
            scope: ASGI connection scope. ``path`` and ``root_path`` are
                rewritten in place.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}

        # Modify the path and root path so that reverse lookups and redirection
        # work as expected. We do this here instead of in replicas so it can be
        # changed without restarting the replicas.
        scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
        scope["root_path"] = self.path_prefix
        starlette_request = build_starlette_request(scope, http_body_bytes)
        handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)

        retries = 0
        backoff_time_s = 0.05
        while retries < MAX_ACTOR_FAILURE_RETRIES:
            object_ref = await handle.remote(starlette_request)
            try:
                result = await object_ref
                break
            except RayActorError:
                # Count the failed attempt first so the logged number is the
                # attempts actually left, not one more than that.
                retries += 1
                logger.warning(
                    "Request failed due to actor failure. There are "
                    f"{MAX_ACTOR_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                backoff_time_s *= 2
        else:
            # The loop ended without `break`, so `result` was never bound.
            # Answer with an explicit 500 instead of crashing on it below.
            error_message = (
                f"Task failed with {MAX_ACTOR_FAILURE_RETRIES} retries.")
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
            return

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(error_message,
                           status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)
Exemple #27
0
    def _process_update(self, updates: Dict[str, UpdatedObject]):
        """Handle the outcome of one long-poll request.

        ``updates`` is either an exception instance describing why the poll
        failed, or a mapping from key to its updated object. Actor and
        connection failures stop the client; task errors are logged and the
        next poll is scheduled; a real update refreshes the stored
        snapshots and schedules each key's listener on the event loop.
        """
        if isinstance(updates, ray.exceptions.RayActorError):
            # Seen during shutdown, when the controller is killed on
            # purpose; the client simply stops without complaint.
            logger.debug("LongPollClient failed to connect to host. Shutting down.")
            self.is_running = False
            return

        if isinstance(updates, ConnectionError):
            logger.warning("LongPollClient connection failed, shutting down.")
            self.is_running = False
            return

        if isinstance(updates, ray.exceptions.RayTaskError):
            timed_out = isinstance(updates.as_instanceof_cause(),
                                   asyncio.TimeoutError)
            if timed_out:
                logger.debug("LongPollClient polling timed out. Retrying.")
            else:
                # The controller itself failed: either a bug or some
                # undesired state on its side.
                logger.error("LongPollHost errored\n" + updates.traceback_str)
            # Scheduling through the event loop is required for this to
            # work in Ray Client.
            # See https://github.com/ray-project/ray/issues/20971
            self._schedule_to_event_loop(self._poll_next)
            return

        logger.debug(
            f"LongPollClient {self} received updates for keys: "
            f"{list(updates.keys())}."
        )
        for key, updated in updates.items():
            snapshot = updated.object_snapshot
            self.object_snapshots[key] = snapshot
            self.snapshot_ids[key] = updated.snapshot_id
            listener = self.key_listeners[key]

            # Default arguments pin this iteration's values; a plain closure
            # would see only the last ones (late binding).
            # https://docs.python-guide.org/writing/gotchas/#late-binding-closures # noqa: E501
            def notify(listener=listener, snapshot=snapshot):
                listener(snapshot)
                self._on_callback_completed(trigger_at=len(updates))

            self._schedule_to_event_loop(notify)
Exemple #28
0
    async def backend_control_loop(self):
        """Poll once per second until no replica start or stop is pending.

        Each pass awaits ``_check_currently_starting_replicas`` and
        ``_check_currently_stopping_replicas`` and keeps looping while either
        returns a truthy value (both are logged as pending counts). A warning
        is logged whenever more than ``REPLICA_STARTUP_TIME_WARNING_S``
        seconds have passed since the previous warning (or since the loop
        started).
        """
        start = time.time()
        prev_warning = start
        need_to_continue = True
        num_pending_starts, num_pending_stops = 0, 0
        while need_to_continue:
            if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
                prev_warning = time.time()
                delta = time.time() - start
                logger.warning(
                    f"Waited {delta:.2f}s for {num_pending_starts} replicas "
                    f"to start up or {num_pending_stops} replicas to shutdown."
                    " Make sure there are enough resources to create the "
                    "replicas.")

            num_pending_starts = await self._check_currently_starting_replicas(
            )
            num_pending_stops = await self._check_currently_stopping_replicas()
            need_to_continue = num_pending_starts or num_pending_stops

            if need_to_continue:
                # `asyncio.sleep` only builds a coroutine; it must be awaited
                # to pause between polls and yield to the event loop. The
                # pause is skipped once nothing is pending so the loop
                # returns promptly.
                await asyncio.sleep(1)
Exemple #29
0
    def check_started(self):
        """Report whether this replica has finished starting up.

        Returns True if the replica is already RUNNING, or if its startup
        object ref is ready now (in which case the state is moved to
        RUNNING). Otherwise returns False, logging a rate-limited warning
        once startup has exceeded ``SLOW_STARTUP_WARNING_S`` seconds.

        The replica must be in the STARTING state unless it is RUNNING.
        """
        if self._state == ReplicaState.RUNNING:
            return True
        assert self._state == ReplicaState.STARTING, (
            f"State must be {ReplicaState.STARTING}, *not* {self._state}")

        # timeout=0 makes this a non-blocking readiness probe.
        finished, _ = ray.wait([self._startup_obj_ref], timeout=0)
        if len(finished) == 1:
            self._state = ReplicaState.RUNNING
            return True

        time_since_start = time.time() - self._start_time
        if time_since_start > SLOW_STARTUP_WARNING_S:
            since_last_warning = (
                time.time() - self._prev_slow_startup_warning_time)
            if since_last_warning > SLOW_STARTUP_WARNING_PERIOD_S:
                logger.warning(
                    f"Replica '{self._replica_tag}' for backend "
                    f"'{self._backend_tag}' has taken more than "
                    f"{time_since_start:.0f}s to start up. This may be "
                    "caused by waiting for the cluster to auto-scale or "
                    "because the backend constructor is slow.")
                self._prev_slow_startup_warning_time = time.time()

        return False
Exemple #30
0
    async def __call__(self, scope, receive, send):
        """ASGI entry point: forward one HTTP request through the handle.

        Reads the request body, dispatches the request via ``self.handle``
        and writes the outcome back to the client. A call that fails with
        ``RayActorError`` is retried with exponential backoff for at most
        ``MAX_ACTOR_FAILURE_RETRIES`` attempts; when every attempt fails a
        500 response is sent.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}

        retries = 0
        backoff_time_s = 0.05
        while retries < MAX_ACTOR_FAILURE_RETRIES:
            # The handle options and the request object are rebuilt for
            # every attempt, as in the original flow.
            object_ref = await self.handle.options(
                method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                        DEFAULT.VALUE),
                shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
                                      DEFAULT.VALUE),
                http_method=scope["method"].upper(),
                http_headers=headers).remote(
                    build_starlette_request(scope, http_body_bytes))

            try:
                result = await object_ref
                break
            except RayActorError:
                # Count the failed attempt first so the logged number is the
                # attempts actually left, not one more than that.
                retries += 1
                logger.warning(
                    "Request failed due to actor failure. There are "
                    f"{MAX_ACTOR_FAILURE_RETRIES - retries} retries "
                    "remaining.")
                await asyncio.sleep(backoff_time_s)
                backoff_time_s *= 2
        else:
            # The loop ended without `break`, so `result` was never bound.
            # Answer with an explicit 500 instead of crashing on it below.
            error_message = (
                f"Task failed with {MAX_ACTOR_FAILURE_RETRIES} retries.")
            await Response(
                error_message, status_code=500).send(scope, receive, send)
            return

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(
                error_message, status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)