Example #1
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """
    async def __init__(self,
                       controller_name: str,
                       http_host: str,
                       http_port: str,
                       http_middlewares: List[Any],
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)
        # ConfigurationStore
        self.configuration_store = ConfigurationStore()
        # ActorStateReconciler
        self.actor_reconciler = ActorStateReconciler(controller_name, detached)
        # Stored so _checkpoint() can check it when deciding whether to
        # fault-injection-crash after writing a checkpoint.
        self.detached = detached

        # backend -> AutoscalingPolicy
        self.autoscaling_policies = dict()

        # Dictionary of backend_tag -> router_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        self.http_host = http_host
        self.http_port = http_port
        self.http_middlewares = http_middlewares

        # If starting the actor for the first time, starts up the other system
        # components. If recovering, fetches their actor handles.
        self.actor_reconciler._start_routers_if_needed(self.http_host,
                                                       self.http_port,
                                                       self.http_middlewares)

        # NOTE(edoakes): unfortunately, we can't completely recover from a
        # checkpoint in the constructor because we block while waiting for
        # other actors to start up, and those actors fetch soft state from
        # this actor. Because no other tasks will start executing until after
        # the constructor finishes, if we were to run this logic in the
        # constructor it could lead to deadlock between this actor and a child.
        # However we do need to guarantee that we have fully recovered from a
        # checkpoint before any other state-changing calls run. We address this
        # by acquiring the write_lock and then posting the task to recover from
        # a checkpoint to the event loop. Other state-changing calls acquire
        # this lock and will be blocked until recovering from the checkpoint
        # finishes.
        checkpoint = self.kv_store.get(CHECKPOINT_KEY)
        if checkpoint is None:
            logger.debug("No checkpoint found")
        else:
            await self.write_lock.acquire()
            asyncio.get_event_loop().create_task(
                self._recover_from_checkpoint(checkpoint))

        # NOTE(simon): Currently we do all-to-all broadcast. This means
        # any listeners will receive notification for all changes. This
        # can be a problem at scale, e.g. updating a single backend config
        # will broadcast the entire set of configs. In the future, we should
        # optimize the logic to support subscription by key.
        self.long_poll_host = LongPollerHost()
        self.notify_backend_configs_changed()
        self.notify_replica_handles_changed()
        self.notify_traffic_policies_changed()

        asyncio.get_event_loop().create_task(self.run_control_loop())

    def notify_replica_handles_changed(self):
        self.long_poll_host.notify_changed(
            "worker_handles", {
                backend_tag: list(replica_dict.values())
                for backend_tag, replica_dict in
                self.actor_reconciler.backend_replicas.items()
            })

    def notify_traffic_policies_changed(self):
        self.long_poll_host.notify_changed(
            "traffic_policies", self.configuration_store.traffic_policies)

    def notify_backend_configs_changed(self):
        self.long_poll_host.notify_changed(
            "backend_configs", self.configuration_store.get_backend_configs())

    async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
        """Proxy long pull client's listen request.

        Args:
            keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
              determine whether or not the host should immediately return the
              data or wait for the value to be changed.
        """
        return await (
            self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

    def get_routers(self) -> Dict[str, ActorHandle]:
        """Returns a dictionary of node ID to router actor handles."""
        return self.actor_reconciler.routers_cache

    def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
        """Called by the router on startup to fetch required state."""
        return self.configuration_store.routes

    def _checkpoint(self) -> None:
        """Checkpoint internal state and write it to the KV store."""
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        start = time.time()

        checkpoint = pickle.dumps(
            Checkpoint(self.configuration_store, self.actor_reconciler))

        self.kv_store.put(CHECKPOINT_KEY, checkpoint)
        logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))

        if random.random(
        ) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)

    async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
        """Recover the instance state from the provided checkpoint.

        Performs the following operations:
            1) Deserializes the internal state from the checkpoint.
            2) Pushes the latest configuration to the routers
               in case we crashed before updating them.
            3) Starts/stops any replicas that are pending creation or
               deletion.

        NOTE: this requires that self.write_lock is already acquired and will
        release it before returning.
        """
        assert self.write_lock.locked()

        start = time.time()
        logger.info("Recovering from checkpoint")

        restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
        # Restore ConfigurationStore
        self.configuration_store = restored_checkpoint.config

        # Restore ActorStateReconciler
        self.actor_reconciler = restored_checkpoint.reconciler

        self.autoscaling_policies = await self.actor_reconciler.\
            _recover_from_checkpoint(self.configuration_store, self)

        logger.info("Recovered from checkpoint in {:.3f}s".format(time.time() -
                                                                  start))

        self.write_lock.release()

    async def do_autoscale(self) -> None:
        for backend, info in self.configuration_store.backends.items():
            if backend not in self.autoscaling_policies:
                continue

            new_num_replicas = self.autoscaling_policies[backend].scale(
                self.backend_stats[backend], info.backend_config.num_replicas)
            if new_num_replicas > 0:
                await self.update_backend_config(
                    backend, BackendConfig(num_replicas=new_num_replicas))

    async def run_control_loop(self) -> None:
        while True:
            await self.do_autoscale()
            async with self.write_lock:
                self.actor_reconciler._start_routers_if_needed(
                    self.http_host, self.http_port, self.http_middlewares)
                checkpoint_required = self.actor_reconciler.\
                    _stop_routers_if_needed()
                if checkpoint_required:
                    self._checkpoint()

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def get_backend_configs(self) -> Dict[str, BackendConfig]:
        """Fetched by the router on startup."""
        return self.configuration_store.get_backend_configs()

    def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
        """Fetched by the router on startup."""
        return self.configuration_store.traffic_policies

    def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]:
        """Used only for testing."""
        return list(self.actor_reconciler.backend_replicas[backend_tag].keys())

    def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
        """Fetched by serve handles."""
        return self.configuration_store.traffic_policies[endpoint]

    def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
        """Fetched by the router on startup."""
        return self.actor_reconciler.backend_replicas

    def get_all_backends(self) -> Dict[str, BackendConfig]:
        """Returns a dictionary of backend tag to backend config."""
        return self.configuration_store.get_backend_configs()

    def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
        """Returns a dictionary of endpoint to endpoint config."""
        endpoints = {}
        for route, (endpoint,
                    methods) in self.configuration_store.routes.items():
            if endpoint in self.configuration_store.traffic_policies:
                traffic_policy = self.configuration_store.traffic_policies[
                    endpoint]
                traffic_dict = traffic_policy.traffic_dict
                shadow_dict = traffic_policy.shadow_dict
            else:
                traffic_dict = {}
                shadow_dict = {}

            endpoints[endpoint] = {
                "route": route if route.startswith("/") else None,
                "methods": methods,
                "traffic": traffic_dict,
                "shadows": shadow_dict,
            }
        return endpoints

    async def _set_traffic(self, endpoint_name: str,
                           traffic_dict: Dict[str, float]) -> None:
        if endpoint_name not in self.get_all_endpoints():
            raise ValueError("Attempted to assign traffic for an endpoint '{}'"
                             " that is not registered.".format(endpoint_name))

        assert isinstance(traffic_dict,
                          dict), "Traffic policy must be a dictionary."

        for backend in traffic_dict:
            if self.configuration_store.get_backend(backend) is None:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

        traffic_policy = TrafficPolicy(traffic_dict)
        self.configuration_store.traffic_policies[
            endpoint_name] = traffic_policy

        # NOTE(edoakes): we must write a checkpoint before pushing the
        # update to avoid inconsistent state if we crash after pushing the
        # update.
        self._checkpoint()

        self.notify_traffic_policies_changed()

    async def set_traffic(self, endpoint_name: str,
                          traffic_dict: Dict[str, float]) -> None:
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            await self._set_traffic(endpoint_name, traffic_dict)

    async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
                             proportion: float) -> None:
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if endpoint_name not in self.get_all_endpoints():
                raise ValueError(
                    "Attempted to shadow traffic from an "
                    "endpoint '{}' that is not registered.".format(
                        endpoint_name))

            if self.configuration_store.get_backend(backend_tag) is None:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            self.configuration_store.traffic_policies[
                endpoint_name].set_shadow(backend_tag, proportion)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()
            self.notify_traffic_policies_changed()

    # TODO(architkulkarni): add Optional for route after cloudpickle upgrade
    async def create_endpoint(self, endpoint: str,
                              traffic_dict: Dict[str, float], route,
                              methods) -> None:
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            # If this is a headless endpoint with no route, key the endpoint
            # based on its name.
            # TODO(edoakes): we should probably just store routes and endpoints
            # separately.
            if route is None:
                route = endpoint

            # TODO(edoakes): move this to client side.
            err_prefix = "Cannot create endpoint."
            if route in self.configuration_store.routes:
                # Ensures this method is idempotent.
                if self.configuration_store.routes[route] == (endpoint,
                                                              methods):
                    return
                else:
                    raise ValueError(
                        "{} Route '{}' is already registered.".format(
                            err_prefix, route))

            if endpoint in self.get_all_endpoints():
                raise ValueError(
                    "{} Endpoint '{}' is already registered.".format(
                        err_prefix, endpoint))

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.configuration_store.routes[route] = (endpoint, methods)

            # NOTE(edoakes): checkpoint is written in self._set_traffic.
            await self._set_traffic(endpoint, traffic_dict)
            await asyncio.gather(*[
                router.set_route_table.remote(self.configuration_store.routes)
                for router in self.actor_reconciler.router_handles()
            ])

    async def delete_endpoint(self, endpoint: str) -> None:
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified endpoint exists on the client.
            for route, (route_endpoint,
                        _) in self.configuration_store.routes.items():
                if route_endpoint == endpoint:
                    route_to_delete = route
                    break
            else:
                logger.info("Endpoint '{}' doesn't exist".format(endpoint))
                return

            # Remove the routing entry.
            del self.configuration_store.routes[route_to_delete]

            # Remove the traffic policy entry if it exists.
            if endpoint in self.configuration_store.traffic_policies:
                del self.configuration_store.traffic_policies[endpoint]

            self.actor_reconciler.endpoints_to_remove.append(endpoint)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # updates to the routers to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()

            await asyncio.gather(*[
                router.set_route_table.remote(self.configuration_store.routes)
                for router in self.actor_reconciler.router_handles()
            ])

    async def create_backend(self, backend_tag: BackendTag,
                             backend_config: BackendConfig,
                             replica_config: ReplicaConfig) -> None:
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            # Ensures this method is idempotent.
            backend_info = self.configuration_store.get_backend(backend_tag)
            if backend_info is not None:
                if (backend_info.backend_config == backend_config
                        and backend_info.replica_config == replica_config):
                    return

            backend_replica = create_backend_replica(
                replica_config.func_or_class)

            # Save the creator function that starts replicas, the arguments
            # to pass to it, and the configuration for the backend.
            self.configuration_store.add_backend(
                backend_tag,
                BackendInfo(worker_class=backend_replica,
                            backend_config=backend_config,
                            replica_config=replica_config))
            metadata = backend_config.internal_metadata
            if metadata.autoscaling_config is not None:
                self.autoscaling_policies[
                    backend_tag] = BasicAutoscalingPolicy(
                        backend_tag, metadata.autoscaling_config)

            try:
                self.actor_reconciler._scale_backend_replicas(
                    self.configuration_store.backends, backend_tag,
                    backend_config.num_replicas)
            except RayServeException as e:
                del self.configuration_store.backends[backend_tag]
                raise e

            # NOTE(edoakes): we must write a checkpoint before starting new
            # replicas or pushing the updated config to avoid inconsistent
            # state if we crash while making the change.
            self._checkpoint()
            await self.actor_reconciler._start_pending_backend_replicas(
                self.configuration_store)

            self.notify_replica_handles_changed()

            # Set the backend config inside the router
            # (particularly for max_concurrent_queries).
            self.notify_backend_configs_changed()
            await self.broadcast_backend_config(backend_tag)

    async def delete_backend(self, backend_tag: BackendTag) -> None:
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified backend exists on the client.
            if self.configuration_store.get_backend(backend_tag) is None:
                return

            # Check that the specified backend isn't used by any endpoints.
            for endpoint, traffic_policy in self.configuration_store.\
                    traffic_policies.items():
                if (backend_tag in traffic_policy.traffic_dict
                        or backend_tag in traffic_policy.shadow_dict):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))

            # Scale its replicas down to 0. This removes the backend's
            # replicas from self.actor_reconciler.backend_replicas; the
            # backend's metadata is deleted from the configuration store
            # below.
            self.actor_reconciler._scale_backend_replicas(
                self.configuration_store.backends, backend_tag, 0)

            # Remove the backend's metadata.
            del self.configuration_store.backends[backend_tag]
            if backend_tag in self.autoscaling_policies:
                del self.autoscaling_policies[backend_tag]

            # Add the intention to remove the backend from the router.
            self.actor_reconciler.backends_to_remove.append(backend_tag)

            # NOTE(edoakes): we must write a checkpoint before removing the
            # backend from the router to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()
            await self.actor_reconciler._stop_pending_backend_replicas()

            self.notify_replica_handles_changed()

    async def update_backend_config(self, backend_tag: BackendTag,
                                    config_options: BackendConfig) -> None:
        """Set the config for the specified backend."""
        async with self.write_lock:
            assert (self.configuration_store.get_backend(backend_tag)
                    ), "Backend {} is not registered.".format(backend_tag)
            assert isinstance(config_options, BackendConfig)

            stored_backend_config = self.configuration_store.get_backend(
                backend_tag).backend_config
            backend_config = stored_backend_config.copy(
                update=config_options.dict(exclude_unset=True))
            backend_config._validate_complete()
            self.configuration_store.get_backend(
                backend_tag).backend_config = backend_config

            # Scale the replicas with the new configuration.
            self.actor_reconciler._scale_backend_replicas(
                self.configuration_store.backends, backend_tag,
                backend_config.num_replicas)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()

            await self.actor_reconciler._start_pending_backend_replicas(
                self.configuration_store)
            await self.actor_reconciler._stop_pending_backend_replicas()

            self.notify_replica_handles_changed()

            # Inform the routers about the change in configuration
            # (particularly for setting max_batch_size).
            self.notify_backend_configs_changed()

            await self.broadcast_backend_config(backend_tag)

    async def broadcast_backend_config(self, backend_tag: BackendTag) -> None:
        backend_config = self.configuration_store.get_backend(
            backend_tag).backend_config
        broadcast_futures = [
            replica.update_config.remote(backend_config).as_future()
            for replica in
            self.actor_reconciler.get_replica_handles_for_backend(backend_tag)
        ]
        await asyncio.gather(*broadcast_futures)

    def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
        """Get the current config for the specified backend."""
        assert (self.configuration_store.get_backend(backend_tag)
                ), "Backend {} is not registered.".format(backend_tag)
        return self.configuration_store.get_backend(backend_tag).backend_config

    async def shutdown(self) -> None:
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for router in self.actor_reconciler.router_handles():
                ray.kill(router, no_restart=True)
            for replica in self.actor_reconciler.get_replica_handles():
                ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)
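
# The controller above pushes soft-state updates through a long-poll host:
# a client calls listen_for_change() with the snapshot IDs it last saw, and
# the host replies immediately if its data is newer, otherwise parks the
# request until a notify_*_changed() call bumps the snapshot. The real
# LongPollerHost is not shown in this file, so the following is a minimal
# standalone sketch of that protocol; the class and names below are
# illustrative assumptions, not the actual implementation.

import asyncio


class MiniLongPollHost:
    def __init__(self):
        self._snapshot_ids = {}  # key -> monotonically increasing version
        self._values = {}  # key -> latest value
        self._changed = asyncio.Condition()

    async def notify_changed(self, key, value):
        async with self._changed:
            self._snapshot_ids[key] = self._snapshot_ids.get(key, -1) + 1
            self._values[key] = value
            self._changed.notify_all()

    async def listen_for_change(self, keys_to_snapshot_ids):
        # Return (snapshot_id, value) for every key the caller is behind on;
        # block until at least one such key exists.
        async with self._changed:
            while True:
                updated = {
                    key: (self._snapshot_ids[key], self._values[key])
                    for key, snapshot_id in keys_to_snapshot_ids.items()
                    if self._snapshot_ids.get(key, -1) > snapshot_id
                }
                if updated:
                    return updated
                await self._changed.wait()


# Example: a listener that has seen nothing yet (snapshot ID -1) blocks until
# the first notify_changed() for "backend_configs" arrives.
async def _long_poll_demo():
    host = MiniLongPollHost()
    listener = asyncio.ensure_future(
        host.listen_for_change({"backend_configs": -1}))
    await host.notify_changed("backend_configs", {"backend": "config"})
    print(await listener)  # {'backend_configs': (0, {'backend': 'config'})}
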
Example #2
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """

    async def __init__(self, instance_name, http_host, http_port,
                       metric_exporter_class, _http_middlewares):
        # Unique name of the serve instance managed by this actor. Used to
        # namespace child actors and checkpoints.
        self.instance_name = instance_name
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=instance_name)
        # path -> (endpoint, methods).
        self.routes = dict()
        # backend -> BackendInfo.
        self.backends = dict()
        # backend -> AutoscalingPolicy
        self.autoscaling_policies = dict()
        # backend -> replica_tags.
        self.replicas = defaultdict(list)
        # replicas that should be started if recovering from a checkpoint.
        self.replicas_to_start = defaultdict(list)
        # replicas that should be stopped if recovering from a checkpoint.
        self.replicas_to_stop = defaultdict(list)
        # backends that should be removed from the router if recovering from a
        # checkpoint.
        self.backends_to_remove = list()
        # endpoints that should be removed from the router if recovering from a
        # checkpoint.
        self.endpoints_to_remove = list()
        # endpoint -> TrafficPolicy
        self.traffic_policies = dict()
        # Dictionary of backend tag to dictionaries of replica tag to worker.
        # TODO(edoakes): consider removing this and just using the names.
        self.workers = defaultdict(dict)
        # Dictionary of backend_tag -> router_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        # Cached handles to actors in the system.
        # node_id -> actor_handle
        self.routers = dict()
        self.metric_exporter = None

        self.http_host = http_host
        self.http_port = http_port
        self._http_middlewares = _http_middlewares

        # If starting the actor for the first time, starts up the other system
        # components. If recovering, fetches their actor handles.
        self._start_metric_exporter(metric_exporter_class)
        self._start_routers_if_needed()

        # NOTE(edoakes): unfortunately, we can't completely recover from a
        # checkpoint in the constructor because we block while waiting for
        # other actors to start up, and those actors fetch soft state from
        # this actor. Because no other tasks will start executing until after
        # the constructor finishes, if we were to run this logic in the
        # constructor it could lead to deadlock between this actor and a child.
        # However we do need to guarantee that we have fully recovered from a
        # checkpoint before any other state-changing calls run. We address this
        # by acquiring the write_lock and then posting the task to recover from
        # a checkpoint to the event loop. Other state-changing calls acquire
        # this lock and will be blocked until recovering from the checkpoint
        # finishes.
        checkpoint = self.kv_store.get(CHECKPOINT_KEY)
        if checkpoint is None:
            logger.debug("No checkpoint found")
        else:
            await self.write_lock.acquire()
            asyncio.get_event_loop().create_task(
                self._recover_from_checkpoint(checkpoint))

        asyncio.get_event_loop().create_task(self.run_control_loop())

    def _start_routers_if_needed(self):
        """Start a router on every node if it doesn't already exist."""
        for node_id, node_resource in get_all_node_ids():
            if node_id in self.routers:
                continue

            router_name = format_actor_name(SERVE_PROXY_NAME,
                                            self.instance_name, node_id)
            try:
                router = ray.get_actor(router_name)
            except ValueError:
                logger.info("Starting router with name '{}' on node '{}' "
                            "listening on '{}:{}'".format(
                                router_name, node_id, self.http_host,
                                self.http_port))
                router = HTTPProxyActor.options(
                    name=router_name,
                    max_concurrency=ASYNC_CONCURRENCY,
                    max_restarts=-1,
                    max_task_retries=-1,
                    resources={
                        node_resource: 0.01
                    },
                ).remote(
                    node_id,
                    self.http_host,
                    self.http_port,
                    instance_name=self.instance_name,
                    _http_middlewares=self._http_middlewares)

            self.routers[node_id] = router

    def _stop_routers_if_needed(self):
        """Removes router actors from any nodes that no longer exist.

        Returns whether or not any actors were removed (a checkpoint should
        be taken).
        """
        checkpoint_required = False
        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
        to_stop = []
        for node_id in self.routers:
            if node_id not in all_node_ids:
                logger.info(
                    "Removing router on removed node '{}'.".format(node_id))
                to_stop.append(node_id)

        for node_id in to_stop:
            router_handle = self.routers.pop(node_id)
            ray.kill(router_handle, no_restart=True)
            checkpoint_required = True

        return checkpoint_required

    def get_routers(self):
        """Returns a dictionary of node ID to router actor handles."""
        return self.routers

    def get_router_config(self):
        """Called by the router on startup to fetch required state."""
        return self.routes

    def _start_metric_exporter(self, metric_exporter_class):
        """Get the metric exporter belonging to this serve instance.

        If the metric exporter does not already exist, it will be started.
        """
        metric_sink_name = format_actor_name(SERVE_METRIC_SINK_NAME,
                                             self.instance_name)
        try:
            self.metric_exporter = ray.get_actor(metric_sink_name)
        except ValueError:
            logger.info("Starting metric exporter with name '{}'".format(
                metric_sink_name))
            self.metric_exporter = MetricExporterActor.options(
                name=metric_sink_name).remote(metric_exporter_class)

    def get_metric_exporter(self):
        """Returns a handle to the metric exporter managed by this actor."""
        return [self.metric_exporter]

    def _checkpoint(self):
        """Checkpoint internal state and write it to the KV store."""
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        start = time.time()
        checkpoint = pickle.dumps(
            (self.routes, list(
                self.routers.keys()), self.backends, self.traffic_policies,
             self.replicas, self.replicas_to_start, self.replicas_to_stop,
             self.backends_to_remove, self.endpoints_to_remove))

        self.kv_store.put(CHECKPOINT_KEY, checkpoint)
        logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))

        if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)

    async def _recover_from_checkpoint(self, checkpoint_bytes):
        """Recover the instance state from the provided checkpoint.

        Performs the following operations:
            1) Deserializes the internal state from the checkpoint.
            2) Pushes the latest configuration to the routers
               in case we crashed before updating them.
            3) Starts/stops any worker replicas that are pending creation or
               deletion.

        NOTE: this requires that self.write_lock is already acquired and will
        release it before returning.
        """
        assert self.write_lock.locked()

        start = time.time()
        logger.info("Recovering from checkpoint")

        # Load internal state from the checkpoint data.
        (
            self.routes,
            router_node_ids,
            self.backends,
            self.traffic_policies,
            self.replicas,
            self.replicas_to_start,
            self.replicas_to_stop,
            self.backends_to_remove,
            self.endpoints_to_remove,
        ) = pickle.loads(checkpoint_bytes)

        for node_id in router_node_ids:
            router_name = format_actor_name(SERVE_PROXY_NAME,
                                            self.instance_name, node_id)
            self.routers[node_id] = ray.get_actor(router_name)

        # Fetch actor handles for all of the backend replicas in the system.
        # All of these workers are guaranteed to already exist because they
        # would not be written to a checkpoint in self.workers until they
        # were created.
        for backend_tag, replica_tags in self.replicas.items():
            for replica_tag in replica_tags:
                replica_name = format_actor_name(replica_tag,
                                                 self.instance_name)
                self.workers[backend_tag][replica_tag] = ray.get_actor(
                    replica_name)

        # Push configuration state to the router.
        # TODO(edoakes): should we make this a pull-only model for simplicity?
        for endpoint, traffic_policy in self.traffic_policies.items():
            await asyncio.gather(*[
                router.set_traffic.remote(endpoint, traffic_policy)
                for router in self.routers.values()
            ])

        for backend_tag, replica_dict in self.workers.items():
            for replica_tag, worker in replica_dict.items():
                await asyncio.gather(*[
                    router.add_new_worker.remote(backend_tag, replica_tag,
                                                 worker)
                    for router in self.routers.values()
                ])

        for backend, info in self.backends.items():
            await asyncio.gather(*[
                router.set_backend_config.remote(backend, info.backend_config)
                for router in self.routers.values()
            ])
            await self.broadcast_backend_config(backend)
            if info.backend_config.autoscaling_config is not None:
                self.autoscaling_policies[backend] = BasicAutoscalingPolicy(
                    backend, info.backend_config.autoscaling_config)

        # Push configuration state to the routers.
        await asyncio.gather(*[
            router.set_route_table.remote(self.routes)
            for router in self.routers.values()
        ])

        # Start/stop any pending backend replicas.
        await self._start_pending_replicas()
        await self._stop_pending_replicas()

        # Remove any pending backends and endpoints.
        await self._remove_pending_backends()
        await self._remove_pending_endpoints()

        logger.info(
            "Recovered from checkpoint in {:.3f}s".format(time.time() - start))

        self.write_lock.release()

    async def do_autoscale(self):
        for backend in self.backends:
            if backend not in self.autoscaling_policies:
                continue

            new_num_replicas = self.autoscaling_policies[backend].scale(
                self.backend_stats[backend],
                self.backends[backend].backend_config.num_replicas)
            if new_num_replicas > 0:
                await self.update_backend_config(
                    backend, {"num_replicas": new_num_replicas})

    async def run_control_loop(self):
        while True:
            await self.do_autoscale()
            async with self.write_lock:
                self._start_routers_if_needed()
                checkpoint_required = self._stop_routers_if_needed()
                if checkpoint_required:
                    self._checkpoint()

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def get_backend_configs(self):
        """Fetched by the router on startup."""
        backend_configs = {}
        for backend, info in self.backends.items():
            backend_configs[backend] = info.backend_config
        return backend_configs

    def get_traffic_policies(self):
        """Fetched by the router on startup."""
        return self.traffic_policies

    def _list_replicas(self, backend_tag):
        """Used only for testing."""
        return self.replicas[backend_tag]

    def get_traffic_policy(self, endpoint):
        """Fetched by serve handles."""
        return self.traffic_policies[endpoint]

    async def _start_backend_worker(self, backend_tag, replica_tag):
        """Creates a backend worker and waits for it to start up.

        Assumes that the backend configuration has already been registered
        in self.backends.
        """
        logger.debug("Starting worker '{}' for backend '{}'.".format(
            replica_tag, backend_tag))
        backend_info = self.backends[backend_tag]

        replica_name = format_actor_name(replica_tag, self.instance_name)
        worker_handle = ray.remote(backend_info.worker_class).options(
            name=replica_name,
            max_restarts=-1,
            max_task_retries=-1,
            **backend_info.replica_config.ray_actor_options).remote(
                backend_tag,
                replica_tag,
                backend_info.replica_config.actor_init_args,
                backend_info.backend_config,
                instance_name=self.instance_name)
        # TODO(edoakes): we should probably have a timeout here.
        await worker_handle.ready.remote()
        return worker_handle

    async def _start_replica(self, backend_tag, replica_tag):
        # NOTE(edoakes): the replicas may already be created if we
        # failed after creating them but before writing a
        # checkpoint.
        try:
            worker_handle = ray.get_actor(replica_tag)
        except ValueError:
            worker_handle = await self._start_backend_worker(
                backend_tag, replica_tag)

        self.replicas[backend_tag].append(replica_tag)
        self.workers[backend_tag][replica_tag] = worker_handle

        # Register the worker with the router.
        await asyncio.gather(*[
            router.add_new_worker.remote(backend_tag, replica_tag,
                                         worker_handle)
            for router in self.routers.values()
        ])

    async def _start_pending_replicas(self):
        """Starts the pending backend replicas in self.replicas_to_start.

        Starts the worker, then pushes an update to the router to add it to
        the proper backend. If the worker has already been started, only
        updates the router.

        Clears self.replicas_to_start.
        """
        replica_started_futures = []
        for backend_tag, replicas_to_create in self.replicas_to_start.items():
            for replica_tag in replicas_to_create:
                replica_started_futures.append(
                    self._start_replica(backend_tag, replica_tag))

        # Wait on all creation task futures together.
        await asyncio.gather(*replica_started_futures)

        self.replicas_to_start.clear()

    async def _stop_pending_replicas(self):
        """Stops the pending backend replicas in self.replicas_to_stop.

        Removes workers from the router, kills them, and clears
        self.replicas_to_stop.
        """
        for backend_tag, replicas_to_stop in self.replicas_to_stop.items():
            for replica_tag in replicas_to_stop:
                # NOTE(edoakes): the replicas may already be stopped if we
                # failed after stopping them but before writing a checkpoint.
                try:
                    replica = ray.get_actor(replica_tag)
                except ValueError:
                    continue

                # Remove the replica from router. This call is idempotent.
                await asyncio.gather(*[
                    router.remove_worker.remote(backend_tag, replica_tag)
                    for router in self.routers.values()
                ])

                # TODO(edoakes): this logic isn't ideal because there may be
                # pending tasks still executing on the replica. However, if we
                # use replica.__ray_terminate__, we may send it while the
                # replica is being restarted and there's no way to tell if it
                # successfully killed the worker or not.
                ray.kill(replica, no_restart=True)

        self.replicas_to_stop.clear()

    async def _remove_pending_backends(self):
        """Removes the pending backends in self.backends_to_remove.

        Clears self.backends_to_remove.
        """
        for backend_tag in self.backends_to_remove:
            await asyncio.gather(*[
                router.remove_backend.remote(backend_tag)
                for router in self.routers.values()
            ])
        self.backends_to_remove.clear()

    async def _remove_pending_endpoints(self):
        """Removes the pending endpoints in self.endpoints_to_remove.

        Clears self.endpoints_to_remove.
        """
        for endpoint_tag in self.endpoints_to_remove:
            await asyncio.gather(*[
                router.remove_endpoint.remote(endpoint_tag)
                for router in self.routers.values()
            ])
        self.endpoints_to_remove.clear()

    def _scale_replicas(self, backend_tag, num_replicas):
        """Scale the given backend to the number of replicas.

        NOTE: this does not actually start or stop the replicas, but instead
        adds the intention to start/stop them to self.replicas_to_start and
        self.replicas_to_stop. The caller is responsible for then first writing
        a checkpoint and then actually starting/stopping the intended replicas.
        This avoids inconsistencies with starting/stopping a worker and then
        crashing before writing a checkpoint.
        """
        logger.debug("Scaling backend '{}' to {} replicas".format(
            backend_tag, num_replicas))
        assert (backend_tag in self.backends
                ), "Backend {} is not registered.".format(backend_tag)
        assert num_replicas >= 0, ("Number of replicas must be"
                                   " greater than or equal to 0.")

        current_num_replicas = len(self.replicas[backend_tag])
        delta_num_replicas = num_replicas - current_num_replicas

        backend_info = self.backends[backend_tag]
        if delta_num_replicas > 0:
            can_schedule = try_schedule_resources_on_nodes(
                requirements=[
                    backend_info.replica_config.resource_dict
                    for _ in range(delta_num_replicas)
                ],
                ray_nodes=ray.nodes())
            if _RESOURCE_CHECK_ENABLED and not all(can_schedule):
                num_possible = sum(can_schedule)
                raise RayServeException(
                    "Cannot scale backend {} to {} replicas. Ray Serve tried "
                    "to add {} replicas but the available resources only "
                    "allow {} to be added. To fix this, consider scaling the "
                    "backend to {} replicas or adding more resources to the "
                    "cluster. You can check available resources with "
                    "ray.nodes().".format(
                        backend_tag, num_replicas, delta_num_replicas,
                        num_possible, current_num_replicas + num_possible))

            logger.debug("Adding {} replicas to backend {}".format(
                delta_num_replicas, backend_tag))
            for _ in range(delta_num_replicas):
                replica_tag = "{}#{}".format(backend_tag, get_random_letters())
                self.replicas_to_start[backend_tag].append(replica_tag)

        elif delta_num_replicas < 0:
            logger.debug("Removing {} replicas from backend {}".format(
                -delta_num_replicas, backend_tag))
            assert len(self.replicas[backend_tag]) >= -delta_num_replicas
            for _ in range(-delta_num_replicas):
                replica_tag = self.replicas[backend_tag].pop()
                if len(self.replicas[backend_tag]) == 0:
                    del self.replicas[backend_tag]
                del self.workers[backend_tag][replica_tag]
                if len(self.workers[backend_tag]) == 0:
                    del self.workers[backend_tag]

                self.replicas_to_stop[backend_tag].append(replica_tag)

    def get_all_worker_handles(self):
        """Fetched by the router on startup."""
        return self.workers

    def get_all_backends(self):
        """Returns a dictionary of backend tag to backend config dict."""
        backends = {}
        for backend_tag, backend_info in self.backends.items():
            backends[backend_tag] = backend_info.backend_config.__dict__
        return backends

    def get_all_endpoints(self):
        """Returns a dictionary of endpoint to endpoint config."""
        endpoints = {}
        for route, (endpoint, methods) in self.routes.items():
            if endpoint in self.traffic_policies:
                traffic_policy = self.traffic_policies[endpoint]
                traffic_dict = traffic_policy.traffic_dict
                shadow_dict = traffic_policy.shadow_dict
            else:
                traffic_dict = {}
                shadow_dict = {}

            endpoints[endpoint] = {
                "route": route if route.startswith("/") else None,
                "methods": methods,
                "traffic": traffic_dict,
                "shadows": shadow_dict,
            }
        return endpoints

    async def _set_traffic(self, endpoint_name, traffic_dict):
        if endpoint_name not in self.get_all_endpoints():
            raise ValueError("Attempted to assign traffic for an endpoint '{}'"
                             " that is not registered.".format(endpoint_name))

        assert isinstance(traffic_dict,
                          dict), "Traffic policy must be a dictionary."

        for backend in traffic_dict:
            if backend not in self.backends:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

        traffic_policy = TrafficPolicy(traffic_dict)
        self.traffic_policies[endpoint_name] = traffic_policy

        # NOTE(edoakes): we must write a checkpoint before pushing the
        # update to avoid inconsistent state if we crash after pushing the
        # update.
        self._checkpoint()
        await asyncio.gather(*[
            router.set_traffic.remote(endpoint_name, traffic_policy)
            for router in self.routers.values()
        ])

    async def set_traffic(self, endpoint_name, traffic_dict):
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            await self._set_traffic(endpoint_name, traffic_dict)

    async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if endpoint_name not in self.get_all_endpoints():
                raise ValueError("Attempted to shadow traffic from an "
                                 "endpoint '{}' that is not registered."
                                 .format(endpoint_name))

            if backend_tag not in self.backends:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            self.traffic_policies[endpoint_name].set_shadow(
                backend_tag, proportion)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()
            await asyncio.gather(*[
                router.set_traffic.remote(
                    endpoint_name,
                    self.traffic_policies[endpoint_name],
                ) for router in self.routers.values()
            ])

    async def create_endpoint(self, endpoint, traffic_dict, route, methods):
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            # If this is a headless endpoint with no route, key the endpoint
            # based on its name.
            # TODO(edoakes): we should probably just store routes and endpoints
            # separately.
            if route is None:
                route = endpoint

            # TODO(edoakes): move this to client side.
            err_prefix = "Cannot create endpoint."
            if route in self.routes:
                # Ensures this method is idempotent.
                if self.routes[route] == (endpoint, methods):
                    return
                else:
                    raise ValueError(
                        "{} Route '{}' is already registered.".format(
                            err_prefix, route))

            if endpoint in self.get_all_endpoints():
                raise ValueError(
                    "{} Endpoint '{}' is already registered.".format(
                        err_prefix, endpoint))

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.routes[route] = (endpoint, methods)

            # NOTE(edoakes): checkpoint is written in self._set_traffic.
            await self._set_traffic(endpoint, traffic_dict)
            await asyncio.gather(*[
                router.set_route_table.remote(self.routes)
                for router in self.routers.values()
            ])

    async def delete_endpoint(self, endpoint):
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified endpoint exists on the client.
            for route, (route_endpoint, _) in self.routes.items():
                if route_endpoint == endpoint:
                    route_to_delete = route
                    break
            else:
                logger.info("Endpoint '{}' doesn't exist".format(endpoint))
                return

            # Remove the routing entry.
            del self.routes[route_to_delete]

            # Remove the traffic policy entry if it exists.
            if endpoint in self.traffic_policies:
                del self.traffic_policies[endpoint]

            self.endpoints_to_remove.append(endpoint)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # updates to the routers to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()

            await asyncio.gather(*[
                router.set_route_table.remote(self.routes)
                for router in self.routers.values()
            ])
            await self._remove_pending_endpoints()

    async def create_backend(self, backend_tag, backend_config,
                             replica_config):
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            # Ensures this method is idempotent.
            if backend_tag in self.backends:
                backend_info = self.backends[backend_tag]
                if (backend_info.backend_config == backend_config
                        and backend_info.replica_config == replica_config):
                    return

            backend_worker = create_backend_worker(
                replica_config.func_or_class)

            # Save the creator function that starts replicas, the arguments
            # to pass to it, and the configuration for the backend.
            self.backends[backend_tag] = BackendInfo(
                backend_worker, backend_config, replica_config)
            if backend_config.autoscaling_config is not None:
                self.autoscaling_policies[
                    backend_tag] = BasicAutoscalingPolicy(
                        backend_tag, backend_config.autoscaling_config)

            try:
                self._scale_replicas(backend_tag, backend_config.num_replicas)
            except RayServeException as e:
                del self.backends[backend_tag]
                raise e

            # NOTE(edoakes): we must write a checkpoint before starting new
            # replicas or pushing the updated config to avoid inconsistent
            # state if we crash while making the change.
            self._checkpoint()
            await self._start_pending_replicas()

            # Set the backend config inside the router
            # (particularly for max_batch_size).
            await asyncio.gather(*[
                router.set_backend_config.remote(backend_tag, backend_config)
                for router in self.routers.values()
            ])
            await self.broadcast_backend_config(backend_tag)

    async def delete_backend(self, backend_tag):
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified backend exists on the client.
            if backend_tag not in self.backends:
                return

            # Check that the specified backend isn't used by any endpoints.
            for endpoint, traffic_policy in self.traffic_policies.items():
                if (backend_tag in traffic_policy.traffic_dict
                        or backend_tag in traffic_policy.shadow_dict):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))

            # Scale its replicas down to 0. This also removes the backend's
            # entry from self.replicas; its metadata in self.backends is
            # deleted below.
            self._scale_replicas(backend_tag, 0)

            # Remove the backend's metadata.
            del self.backends[backend_tag]
            if backend_tag in self.autoscaling_policies:
                del self.autoscaling_policies[backend_tag]

            # Add the intention to remove the backend from the router.
            self.backends_to_remove.append(backend_tag)

            # NOTE(edoakes): we must write a checkpoint before removing the
            # backend from the router to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()
            await self._stop_pending_replicas()
            await self._remove_pending_backends()

    async def update_backend_config(self, backend_tag, config_options):
        """Set the config for the specified backend."""
        async with self.write_lock:
            assert (backend_tag in self.backends
                    ), "Backend {} is not registered.".format(backend_tag)
            assert isinstance(config_options, dict)

            self.backends[backend_tag].backend_config.update(config_options)
            backend_config = self.backends[backend_tag].backend_config

            # Scale the replicas with the new configuration.
            self._scale_replicas(backend_tag, backend_config.num_replicas)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()

            # Inform the router about the change in configuration
            # (particularly for setting max_batch_size).
            await asyncio.gather(*[
                router.set_backend_config.remote(backend_tag, backend_config)
                for router in self.routers.values()
            ])

            await self._start_pending_replicas()
            await self._stop_pending_replicas()

            await self.broadcast_backend_config(backend_tag)

    async def broadcast_backend_config(self, backend_tag):
        backend_config = self.backends[backend_tag].backend_config
        broadcast_futures = []
        for replica_tag in self.replicas[backend_tag]:
            try:
                replica = ray.get_actor(replica_tag)
            except ValueError:
                continue

            future = replica.update_config.remote(backend_config).as_future()
            broadcast_futures.append(future)
        if len(broadcast_futures) > 0:
            await asyncio.gather(*broadcast_futures)

    def get_backend_config(self, backend_tag):
        """Get the current config for the specified backend."""
        assert (backend_tag in self.backends
                ), "Backend {} is not registered.".format(backend_tag)
        return self.backends[backend_tag].backend_config

    async def shutdown(self):
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for router in self.routers.values():
                ray.kill(router, no_restart=True)
            ray.kill(self.metric_exporter, no_restart=True)
            for replica_dict in self.workers.values():
                for replica in replica_dict.values():
                    ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)

    async def report_queue_lengths(self, router_name, queue_lengths):
        # TODO: remove a router's stats when that router is removed.
        for backend, queue_length in queue_lengths.items():
            self.backend_stats[backend][router_name] = queue_length
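
A minimal sketch of how a client might exercise the idempotency guarantee
described in the controller docstring: if a state-changing call fails, simply
retry it. The actor name "serve_controller" and the retry helper are
illustrative assumptions, not part of the example above.

import ray

def create_backend_with_retries(backend_tag, backend_config, replica_config,
                                max_attempts=3):
    # Assumed actor name; real deployments use the controller_name passed
    # to the ServeController constructor.
    controller = ray.get_actor("serve_controller")
    for _ in range(max_attempts):
        try:
            # Safe to retry: create_backend is idempotent, so a call that
            # failed after partially applying its change can be re-issued.
            ray.get(controller.create_backend.remote(
                backend_tag, backend_config, replica_config))
            return
        except ray.exceptions.RayActorError:
            # The controller may have crashed and recovered; retry.
            continue
    raise RuntimeError(
        "create_backend({}) did not succeed after {} attempts.".format(
            backend_tag, max_attempts))
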
Example #3
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """

    async def __init__(self,
                       controller_name: str,
                       http_config: HTTPOptions,
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)

        # Dictionary of backend_tag -> proxy_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        # NOTE(simon): Currently we do all-to-all broadcast. This means
        # any listeners will receive notifications for all changes. This
        # can be a problem at scale, e.g. updating a single backend config
        # will broadcast the entire set of configs. In the future, we should
        # optimize the logic to support subscription by key.
        self.long_poll_host = LongPollHost()

        self.goal_manager = AsyncGoalManager()
        self.http_state = HTTPState(controller_name, detached, http_config)
        self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)
        self.backend_state = BackendState(controller_name, detached,
                                          self.kv_store, self.long_poll_host,
                                          self.goal_manager)

        asyncio.get_event_loop().create_task(self.run_control_loop())

    async def wait_for_goal(self, goal_id: GoalId) -> None:
        await self.goal_manager.wait_for_goal(goal_id)

    async def _num_pending_goals(self) -> int:
        return self.goal_manager.num_pending_goals()

    async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
        """Proxy long pull client's listen request.

        Args:
            keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
              determine whether or not the host should immediately return the
              data or wait for the value to be changed.
        """
        return await (
            self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

    def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
        """Returns a dictionary of node ID to http_proxy actor handles."""
        return self.http_state.get_http_proxy_handles()

    async def run_control_loop(self) -> None:
        while True:
            async with self.write_lock:
                self.http_state.update()
                self.backend_state.update()

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def _all_replica_handles(
            self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
        """Used for testing."""
        return self.backend_state.get_running_replica_handles()

    def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
        """Returns a dictionary of backend tag to backend config."""
        return self.backend_state.get_backend_configs()

    def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
        """Returns a dictionary of backend tag to backend config."""
        return self.endpoint_state.get_endpoints()

    def _validate_traffic_dict(self, traffic_dict: Dict[str, float]):
        for backend in traffic_dict:
            if self.backend_state.get_backend(backend) is None:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

    async def set_traffic(self, endpoint: str,
                          traffic_dict: Dict[str, float]) -> None:
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            self._validate_traffic_dict(traffic_dict)

            logger.info("Setting traffic for endpoint "
                        f"'{endpoint}' to '{traffic_dict}'.")

            self.endpoint_state.set_traffic_policy(endpoint,
                                                   TrafficPolicy(traffic_dict))

    async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
                             proportion: float) -> None:
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if self.backend_state.get_backend(backend_tag) is None:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            logger.info(
                "Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.".
                format(proportion, endpoint_name, backend_tag))

            self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
                                               proportion)

    async def create_endpoint(
            self,
            endpoint: str,
            traffic_dict: Dict[str, float],
            route: Optional[str],
            methods: List[str],
    ) -> None:
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            self._validate_traffic_dict(traffic_dict)

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.endpoint_state.create_endpoint(endpoint, route, methods,
                                                TrafficPolicy(traffic_dict))

    async def delete_endpoint(self, endpoint: str) -> None:
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            self.endpoint_state.delete_endpoint(endpoint)

    async def create_backend(
            self, backend_tag: BackendTag, backend_config: BackendConfig,
            replica_config: ReplicaConfig) -> Optional[GoalId]:
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            return self.backend_state.create_backend(
                backend_tag, backend_config, replica_config)

    async def delete_backend(self,
                             backend_tag: BackendTag,
                             force_kill: bool = False) -> Optional[GoalId]:
        async with self.write_lock:
            # Check that the specified backend isn't used by any endpoints.
            for endpoint, info in self.endpoint_state.get_endpoints().items():
                if (backend_tag in info["traffic"]
                        or backend_tag in info["shadows"]):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))
            return self.backend_state.delete_backend(backend_tag, force_kill)

    async def update_backend_config(self, backend_tag: BackendTag,
                                    config_options: BackendConfig) -> GoalId:
        """Set the config for the specified backend."""
        async with self.write_lock:
            return self.backend_state.update_backend_config(
                backend_tag, config_options)

    def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
        """Get the current config for the specified backend."""
        if self.backend_state.get_backend(backend_tag) is None:
            raise ValueError(f"Backend {backend_tag} is not registered.")
        return self.backend_state.get_backend(backend_tag).backend_config

    def get_http_config(self):
        """Return the HTTP proxy configuration."""
        return self.http_state.get_config()

    async def shutdown(self) -> None:
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for proxy in self.http_state.get_http_proxy_handles().values():
                ray.kill(proxy, no_restart=True)
            for replica_dict in self.backend_state.get_running_replica_handles(
            ).values():
                for replica in replica_dict.values():
                    ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)

    async def deploy(self, name: str, backend_config: BackendConfig,
                     replica_config: ReplicaConfig,
                     version: Optional[str]) -> Optional[GoalId]:
        """TODO."""
        async with self.write_lock:
            if version is None:
                version = RESERVED_VERSION_TAG
            else:
                if version == RESERVED_VERSION_TAG:
                    # TODO(edoakes): this is unlikely to ever be hit, but it's
                    # still ugly and should be removed once the old codepath
                    # can be deleted.
                    raise ValueError(
                        f"Version {RESERVED_VERSION_TAG} is reserved and "
                        "cannot be used by applications.")
            goal_id = self.backend_state.create_backend(
                name, backend_config, replica_config, version)

            self.endpoint_state.create_endpoint(name, f"/{name}",
                                                ["GET", "POST"],
                                                TrafficPolicy({
                                                    name: 1.0
                                                }))
            return goal_id
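
A sketch of driving the goal-based API above from async client code:
state-changing calls return an optional GoalId, and wait_for_goal blocks
until the control loop has reached the requested state. The controller handle
is assumed to be obtained elsewhere (e.g. via ray.get_actor).

async def deploy_and_wait(controller, name, backend_config, replica_config):
    # deploy() registers the backend and endpoint and returns a GoalId
    # tracking when the requested state has been reached.
    goal_id = await controller.deploy.remote(name, backend_config,
                                             replica_config, version=None)
    if goal_id is not None:
        # Returns once the goal is marked complete by the control loop.
        await controller.wait_for_goal.remote(goal_id)
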
Example #4
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """
    async def __init__(self,
                       controller_name: str,
                       http_config: HTTPOptions,
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)

        # Dictionary of backend_tag -> proxy_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        self.long_poll_host = LongPollHost()

        self.goal_manager = AsyncGoalManager()
        self.http_state = HTTPState(controller_name, detached, http_config)
        self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)
        self.backend_state = BackendState(controller_name, detached,
                                          self.kv_store, self.long_poll_host,
                                          self.goal_manager)

        asyncio.get_event_loop().create_task(self.run_control_loop())

    async def wait_for_goal(self, goal_id: GoalId) -> None:
        await self.goal_manager.wait_for_goal(goal_id)

    async def _num_pending_goals(self) -> int:
        return self.goal_manager.num_pending_goals()

    async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
        """Proxy long pull client's listen request.

        Args:
            keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
              determine whether or not the host should immediately return the
              data or wait for the value to be changed.
        """
        return await (
            self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

    def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
        """Returns a dictionary of node ID to http_proxy actor handles."""
        return self.http_state.get_http_proxy_handles()

    async def run_control_loop(self) -> None:
        while True:
            async with self.write_lock:
                try:
                    self.http_state.update()
                except Exception as e:
                    logger.error(f"Exception updating HTTP state: {e}")
                try:
                    self.backend_state.update()
                except Exception as e:
                    logger.error(f"Exception updating backend state: {e}")

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def _all_replica_handles(
            self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
        """Used for testing."""
        return self.backend_state.get_running_replica_handles()

    def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
        """Returns a dictionary of backend tag to backend config."""
        return self.backend_state.get_backend_configs()

    def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
        """Returns a dictionary of backend tag to backend config."""
        return self.endpoint_state.get_endpoints()

    def _validate_traffic_dict(self, traffic_dict: Dict[str, float]):
        for backend in traffic_dict:
            if self.backend_state.get_backend(backend) is None:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

    async def set_traffic(self, endpoint: str,
                          traffic_dict: Dict[str, float]) -> None:
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            self._validate_traffic_dict(traffic_dict)

            logger.info("Setting traffic for endpoint "
                        f"'{endpoint}' to '{traffic_dict}'.")

            self.endpoint_state.set_traffic_policy(endpoint,
                                                   TrafficPolicy(traffic_dict))

    async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
                             proportion: float) -> None:
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if self.backend_state.get_backend(backend_tag) is None:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            logger.info(
                "Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.".
                format(proportion, endpoint_name, backend_tag))

            self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
                                               proportion)

    async def create_endpoint(
        self,
        endpoint: str,
        traffic_dict: Dict[str, float],
        route: Optional[str],
        methods: List[str],
    ) -> None:
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            self._validate_traffic_dict(traffic_dict)

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.endpoint_state.create_endpoint(endpoint, route, methods,
                                                TrafficPolicy(traffic_dict))

    async def delete_endpoint(self, endpoint: str) -> None:
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            self.endpoint_state.delete_endpoint(endpoint)

    async def create_backend(
            self, backend_tag: BackendTag, backend_config: BackendConfig,
            replica_config: ReplicaConfig) -> Optional[GoalId]:
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            backend_info = BackendInfo(worker_class=create_backend_replica(
                replica_config.backend_def),
                                       version=RESERVED_VERSION_TAG,
                                       backend_config=backend_config,
                                       replica_config=replica_config)
            return self.backend_state.deploy_backend(backend_tag, backend_info)

    async def delete_backend(self,
                             backend_tag: BackendTag,
                             force_kill: bool = False) -> Optional[GoalId]:
        async with self.write_lock:
            # Check that the specified backend isn't used by any endpoints.
            for endpoint, info in self.endpoint_state.get_endpoints().items():
                if (backend_tag in info["traffic"]
                        or backend_tag in info["shadows"]):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))
            return self.backend_state.delete_backend(backend_tag, force_kill)

    async def update_backend_config(self, backend_tag: BackendTag,
                                    config_options: BackendConfig) -> GoalId:
        """Set the config for the specified backend."""
        async with self.write_lock:
            existing_info = self.backend_state.get_backend(backend_tag)
            if existing_info is None:
                raise ValueError(f"Backend {backend_tag} is not registered.")

            backend_info = BackendInfo(
                worker_class=existing_info.worker_class,
                version=existing_info.version,
                backend_config=existing_info.backend_config.copy(
                    update=config_options.dict(exclude_unset=True)),
                replica_config=existing_info.replica_config)
            return self.backend_state.deploy_backend(backend_tag, backend_info)

    def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
        """Get the current config for the specified backend."""
        if self.backend_state.get_backend(backend_tag) is None:
            raise ValueError(f"Backend {backend_tag} is not registered.")
        return self.backend_state.get_backend(backend_tag).backend_config

    def get_http_config(self):
        """Return the HTTP proxy configuration."""
        return self.http_state.get_config()

    async def shutdown(self) -> None:
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for proxy in self.http_state.get_http_proxy_handles().values():
                ray.kill(proxy, no_restart=True)
            for replica_dict in self.backend_state.get_running_replica_handles(
            ).values():
                for replica in replica_dict.values():
                    ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)

    async def deploy(self, name: str, backend_config: BackendConfig,
                     replica_config: ReplicaConfig, version: Optional[str],
                     route_prefix: Optional[str]) -> Optional[GoalId]:
        if route_prefix is None:
            route_prefix = f"/{name}"

        if replica_config.is_asgi_app:
            # When the backend is an ASGI application, we want to proxy it
            # with a prefixed path as well as proxy all HTTP methods.
            # {wildcard:path} is used so HTTPProxy's Starlette router can match
            # arbitrary paths.
            if route_prefix.endswith("/"):
                route_prefix = route_prefix[:-1]
            http_route = route_prefix + WILDCARD_PATH_SUFFIX
            http_methods = ALL_HTTP_METHODS
        else:
            http_route = route_prefix
            # Generic endpoint should support a limited subset of HTTP methods.
            http_methods = ["GET", "POST"]

        python_methods = []
        if inspect.isclass(replica_config.backend_def):
            for method_name, _ in inspect.getmembers(
                    replica_config.backend_def, inspect.isfunction):
                python_methods.append(method_name)

        async with self.write_lock:
            backend_info = BackendInfo(worker_class=create_backend_replica(
                replica_config.backend_def),
                                       version=version,
                                       backend_config=backend_config,
                                       replica_config=replica_config)

            goal_id = self.backend_state.deploy_backend(name, backend_info)
            self.endpoint_state.update_endpoint(name,
                                                http_route,
                                                http_methods,
                                                TrafficPolicy({name: 1.0}),
                                                python_methods=python_methods)
            return goal_id

    def delete_deployment(self, name: str) -> Optional[GoalId]:
        self.endpoint_state.delete_endpoint(name)
        return self.backend_state.delete_backend(name, force_kill=False)

    def get_deployment_info(self, name: str) -> Tuple[BackendInfo, str]:
        """Get the current information about a deployment.

        Args:
            name (str): The name of the deployment.

        Returns:
            (BackendInfo, route)

        Raises:
            KeyError if the deployment doesn't exist.
        """
        backend_info: BackendInfo = self.backend_state.get_backend(name)
        if backend_info is None:
            raise KeyError(f"Deployment {name} does not exist.")

        route = self.endpoint_state.get_endpoint_route(name)

        return backend_info, route
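
The listen_for_change method above implements a snapshot-ID long-poll
protocol. Here is a client-side sketch under assumed update fields
(snapshot_id and object_snapshot) and a user-supplied handle_update callback;
neither is defined in the example itself.

async def watch_keys(controller, keys, handle_update):
    # -1 means "no snapshot seen yet", forcing an immediate first response.
    snapshot_ids = {key: -1 for key in keys}
    while True:
        # Blocks until at least one watched key has a snapshot newer than
        # the one we last saw, then returns only the changed entries.
        updates = await controller.listen_for_change.remote(snapshot_ids)
        for key, update in updates.items():
            snapshot_ids[key] = update.snapshot_id
            handle_update(key, update.object_snapshot)
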
Example #5
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """

    async def __init__(self,
                       controller_name: str,
                       http_config: HTTPOptions,
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)
        # Stored so that _checkpoint can check whether to inject a crash.
        self.detached = detached

        # Dictionary of backend_tag -> proxy_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        # Map of awaiting results
        # TODO(ilr): Checkpoint this once this becomes asynchronous
        self.inflight_results: Dict[UUID, asyncio.Event] = dict()
        self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()

        # NOTE(simon): Currently we do all-to-all broadcast. This means
        # any listeners will receive notifications for all changes. This
        # can be a problem at scale, e.g. updating a single backend config
        # will broadcast the entire set of configs. In the future, we should
        # optimize the logic to support subscription by key.
        self.long_poll_host = LongPollHost()

        self.http_state = HTTPState(controller_name, detached, http_config)
        self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)

        checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
        if checkpoint_bytes is None:
            logger.debug("No checkpoint found")
            self.backend_state = BackendState(controller_name, detached)
        else:
            checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
            self.backend_state = BackendState(
                controller_name,
                detached,
                checkpoint=checkpoint.backend_state_checkpoint)

            self._serializable_inflight_results = checkpoint.inflight_reqs
            for uuid, fut_result in self._serializable_inflight_results.items(
            ):
                self._create_event_with_result(fut_result.requested_goal, uuid)

        self.notify_backend_configs_changed()
        self.notify_replica_handles_changed()

        asyncio.get_event_loop().create_task(self.run_control_loop())

    async def wait_for_event(self, uuid: UUID) -> bool:
        start = time.time()
        if uuid not in self.inflight_results:
            logger.debug(f"UUID ({uuid}) not found!!!")
            return True
        event = self.inflight_results[uuid]
        await event.wait()
        self.inflight_results.pop(uuid)
        self._serializable_inflight_results.pop(uuid)
        async with self.write_lock:
            self._checkpoint()
        logger.debug(f"Waiting for {uuid} took {time.time() - start} seconds")

        return True

    def _create_event_with_result(
            self,
            goal_state: Dict[str, Any],
            recreation_uuid: Optional[UUID] = None) -> UUID:
        # NOTE(ilr) Must be called before checkpointing!
        event = asyncio.Event()
        event.result = FutureResult(goal_state)
        uuid_val = recreation_uuid or uuid4()
        self.inflight_results[uuid_val] = event
        self._serializable_inflight_results[uuid_val] = event.result
        return uuid_val

    async def _num_inflight_results(self) -> int:
        return len(self.inflight_results)

    def notify_replica_handles_changed(self):
        self.long_poll_host.notify_changed(
            LongPollKey.REPLICA_HANDLES, {
                backend_tag: list(replica_dict.values())
                for backend_tag, replica_dict in
                self.backend_state.backend_replicas.items()
            })

    def notify_backend_configs_changed(self):
        self.long_poll_host.notify_changed(
            LongPollKey.BACKEND_CONFIGS,
            self.backend_state.get_backend_configs())

    async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
        """Proxy long pull client's listen request.

        Args:
            keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
              determine whether or not the host should immediately return the
              data or wait for the value to be changed.
        """
        return await (
            self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

    def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
        """Returns a dictionary of node ID to http_proxy actor handles."""
        return self.http_state.get_http_proxy_handles()

    def _checkpoint(self) -> None:
        """Checkpoint internal state and write it to the KV store."""
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        start = time.time()

        checkpoint = pickle.dumps(
            Checkpoint(self.backend_state.checkpoint(),
                       self._serializable_inflight_results))

        self.kv_store.put(CHECKPOINT_KEY, checkpoint)
        logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))

        if (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
                and self.detached):
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)

    async def reconcile_current_and_goal_backends(self):
        pass

    def set_goal_id(self, goal_id: UUID) -> None:
        event = self.inflight_results.get(goal_id)
        logger.debug(f"Setting goal id {goal_id}")
        if event:
            event.set()

    async def run_control_loop(self) -> None:
        while True:
            async with self.write_lock:
                self.http_state.update()

                completed_ids = self.backend_state.completed_goals()
                for done_id in completed_ids:
                    self.set_goal_id(done_id)
                delta_workers = await self.backend_state.update()
                if delta_workers:
                    self.notify_replica_handles_changed()
                    self._checkpoint()

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def _all_replica_handles(
            self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
        """Used for testing."""
        return self.backend_state.get_replica_handles()

    def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
        """Returns a dictionary of backend tag to backend config."""
        return self.backend_state.get_backend_configs()

    def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
        """Returns a dictionary of backend tag to backend config."""
        return self.endpoint_state.get_endpoints()

    def _set_traffic(self, endpoint_name: str,
                     traffic_dict: Dict[str, float]) -> None:
        self._validate_traffic_dict(traffic_dict)

        self.endpoint_state.set_traffic_policy(endpoint_name,
                                               TrafficPolicy(traffic_dict))

    def _validate_traffic_dict(self, traffic_dict: Dict[str, float]):
        for backend in traffic_dict:
            if self.backend_state.get_backend(backend) is None:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

    async def set_traffic(self, endpoint_name: str,
                          traffic_dict: Dict[str, float]) -> None:
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            self._set_traffic(endpoint_name, traffic_dict)

    async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
                             proportion: float) -> None:
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if self.backend_state.get_backend(backend_tag) is None:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            logger.info(
                "Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.".
                format(proportion, endpoint_name, backend_tag))

            self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
                                               proportion)

    # TODO(architkulkarni): add Optional for route after cloudpickle upgrade
    async def create_endpoint(self, endpoint: str,
                              traffic_dict: Dict[str, float], route,
                              methods: List[str]) -> None:
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            self._validate_traffic_dict(traffic_dict)

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.endpoint_state.create_endpoint(endpoint, route, methods,
                                                TrafficPolicy(traffic_dict))

    async def delete_endpoint(self, endpoint: str) -> None:
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            self.endpoint_state.delete_endpoint(endpoint)

    async def set_backend_goal(self, backend_tag: BackendTag,
                               backend_info: BackendInfo,
                               new_id: GoalId) -> None:
        # NOTE(ilr) Must checkpoint after doing this!
        existing_id_to_set = self.backend_state._set_backend_goal(
            backend_tag, backend_info, new_id)
        if existing_id_to_set:
            self.set_goal_id(existing_id_to_set)

    async def create_backend(self, backend_tag: BackendTag,
                             backend_config: BackendConfig,
                             replica_config: ReplicaConfig) -> Optional[UUID]:
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            # Ensures this method is idempotent.
            backend_info = self.backend_state.get_backend(backend_tag)
            if backend_info is not None:
                if (backend_info.backend_config == backend_config
                        and backend_info.replica_config == replica_config):
                    return

            backend_replica = create_backend_replica(
                replica_config.func_or_class)

            # Save the creator that starts replicas, the arguments to be
            # passed in, and the configuration for the backend.
            backend_info = BackendInfo(
                worker_class=backend_replica,
                backend_config=backend_config,
                replica_config=replica_config)

            return_uuid = self._create_event_with_result({
                backend_tag: backend_info
            })

            await self.set_backend_goal(backend_tag, backend_info, return_uuid)

            try:
                # This should eventually be handled by the control loop.
                self.backend_state.scale_backend_replicas(
                    backend_tag, backend_config.num_replicas)
            except RayServeException as e:
                del self.backend_state.backends[backend_tag]
                raise e

            # NOTE(edoakes): we must write a checkpoint before starting new
            # replicas or pushing the updated config to avoid inconsistent
            # state if we crash while making the change.
            self._checkpoint()
            self.notify_backend_configs_changed()
            return return_uuid

    async def delete_backend(self,
                             backend_tag: BackendTag,
                             force_kill: bool = False) -> Optional[UUID]:
        async with self.write_lock:
            # This method must be idempotent; validating that the specified
            # backend exists is the client's responsibility.
            if self.backend_state.get_backend(backend_tag) is None:
                return

            # Check that the specified backend isn't used by any endpoints.
            for endpoint, info in self.endpoint_state.get_endpoints().items():
                if (backend_tag in info["traffic"]
                        or backend_tag in info["shadows"]):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))

            # Scale its replicas down to 0.
            self.backend_state.scale_backend_replicas(backend_tag, 0,
                                                      force_kill)

            # Remove the backend's metadata.
            del self.backend_state.backends[backend_tag]

            # Add the intention to remove the backend from the routers.
            self.backend_state.backends_to_remove.append(backend_tag)

            return_uuid = self._create_event_with_result({backend_tag: None})
            # Set the backend's goal to deletion.
            await self.set_backend_goal(backend_tag, None, return_uuid)
            # NOTE(edoakes): we must write a checkpoint before removing the
            # backend from the routers to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()
            return return_uuid

    async def update_backend_config(self, backend_tag: BackendTag,
                                    config_options: BackendConfig) -> UUID:
        """Set the config for the specified backend."""
        async with self.write_lock:
            assert (self.backend_state.get_backend(backend_tag)
                    ), "Backend {} is not registered.".format(backend_tag)
            assert isinstance(config_options, BackendConfig)

            stored_backend_config = self.backend_state.get_backend(
                backend_tag).backend_config
            backend_config = stored_backend_config.copy(
                update=config_options.dict(exclude_unset=True))
            backend_config._validate_complete()
            self.backend_state.get_backend(
                backend_tag).backend_config = backend_config
            backend_info = self.backend_state.get_backend(backend_tag)

            return_uuid = self._create_event_with_result({
                backend_tag: backend_info
            })
            await self.set_backend_goal(backend_tag, backend_info, return_uuid)

            # Scale the replicas with the new configuration. This should
            # eventually be handled by the control loop.
            self.backend_state.scale_backend_replicas(
                backend_tag, backend_config.num_replicas)

            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()

            # Inform the routers and backend replicas about config changes.
            self.notify_backend_configs_changed()

            return return_uuid

    def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
        """Get the current config for the specified backend."""
        assert (self.backend_state.get_backend(backend_tag)
                ), "Backend {} is not registered.".format(backend_tag)
        return self.backend_state.get_backend(backend_tag).backend_config

    def get_http_config(self):
        """Return the HTTP proxy configuration."""
        return self.http_state.get_config()

    async def shutdown(self) -> None:
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for proxy in self.http_state.get_http_proxy_handles().values():
                ray.kill(proxy, no_restart=True)
            for replica_dict in self.backend_state.get_replica_handles(
            ).values():
                for replica in replica_dict.values():
                    ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)
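
A self-contained sketch of the inflight-results pattern used above: each
state change registers an asyncio.Event keyed by a UUID, the control loop
sets the event when the goal completes, and callers await it. All names here
are illustrative; this shows the pattern, not the controller's implementation.

import asyncio
from typing import Dict
from uuid import UUID, uuid4

class GoalTracker:
    def __init__(self):
        self.inflight: Dict[UUID, asyncio.Event] = {}

    def register(self) -> UUID:
        # Analogous to _create_event_with_result: allocate a UUID and an
        # Event that is set once the goal is reached.
        goal_id = uuid4()
        self.inflight[goal_id] = asyncio.Event()
        return goal_id

    def complete(self, goal_id: UUID) -> None:
        # Analogous to set_goal_id: wake any waiters on this goal.
        event = self.inflight.get(goal_id)
        if event:
            event.set()

    async def wait(self, goal_id: UUID) -> None:
        # Analogous to wait_for_event: block until the goal completes,
        # then drop the bookkeeping for it.
        event = self.inflight.get(goal_id)
        if event is None:
            return
        await event.wait()
        self.inflight.pop(goal_id, None)

async def main():
    tracker = GoalTracker()
    goal = tracker.register()
    asyncio.get_running_loop().call_later(0.05, tracker.complete, goal)
    await tracker.wait(goal)

asyncio.run(main())
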
Example #6
class ServeController:
    """Responsible for managing the state of the serving system.

    The controller implements fault tolerance by persisting its state in
    a new checkpoint each time a state change is made. If the actor crashes,
    the latest checkpoint is loaded and the state is recovered. Checkpoints
    are written/read using a provided KV-store interface.

    All hard state in the system is maintained by this actor and persisted via
    these checkpoints. Soft state required by other components is fetched by
    those actors from this actor on startup and updates are pushed out from
    this actor.

    All other actors started by the controller are named, detached actors
    so they will not fate share with the controller if it crashes.

    The following guarantees are provided for state-changing calls to the
    controller:
        - If the call succeeds, the change was made and will be reflected in
          the system even if the controller or other actors die unexpectedly.
        - If the call fails, the change may have been made but isn't guaranteed
          to have been. The client should retry in this case. Note that this
          requires all implementations here to be idempotent.
    """
    async def __init__(self,
                       controller_name: str,
                       http_config: HTTPConfig,
                       detached: bool = False):
        # Used to read/write checkpoints.
        self.kv_store = RayInternalKVStore(namespace=controller_name)
        # Stored so that _checkpoint can check whether to inject a crash.
        self.detached = detached
        self.actor_reconciler = ActorStateReconciler(controller_name, detached)

        # backend -> AutoscalingPolicy
        self.autoscaling_policies = dict()

        # Dictionary of backend_tag -> proxy_name -> most recent queue length.
        self.backend_stats = defaultdict(lambda: defaultdict(dict))

        # Used to ensure that only a single state-changing operation happens
        # at any given time.
        self.write_lock = asyncio.Lock()

        # Map of awaiting results
        # TODO(ilr): Checkpoint this once this becomes asynchronous
        self.inflight_results: Dict[UUID, asyncio.Event] = dict()
        self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()

        # HTTP state doesn't currently require a checkpoint.
        self.http_state = HTTPState(controller_name, detached, http_config)

        checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
        if checkpoint_bytes is None:
            logger.debug("No checkpoint found")
            self.backend_state = BackendState()
            self.endpoint_state = EndpointState()
        else:
            checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
            self.backend_state = BackendState(
                checkpoint=checkpoint.backend_state_checkpoint)
            self.endpoint_state = EndpointState(
                checkpoint=checkpoint.endpoint_state_checkpoint)
            await self._recover_from_checkpoint(checkpoint)

        # NOTE(simon): Currently we do all-to-all broadcast. This means
        # any listeners will receive notifications for all changes. This
        # can be a problem at scale, e.g. updating a single backend config
        # will broadcast the entire set of configs. In the future, we should
        # optimize the logic to support subscription by key.
        self.long_poll_host = LongPollHost()

        # The configs pushed out here get updated by
        # self._recover_from_checkpoint in the failure scenario, so that must
        # be run before we notify the changes.
        self.notify_backend_configs_changed()
        self.notify_replica_handles_changed()
        self.notify_traffic_policies_changed()
        self.notify_route_table_changed()

        asyncio.get_event_loop().create_task(self.run_control_loop())

    async def wait_for_event(self, uuid: UUID) -> bool:
        if uuid not in self.inflight_results:
            return True
        event = self.inflight_results[uuid]
        await event.wait()
        self.inflight_results.pop(uuid)
        self._serializable_inflight_results.pop(uuid)
        async with self.write_lock:
            self._checkpoint()

        return True

    def _create_event_with_result(
            self,
            goal_state: Dict[str, Any],
            recreation_uuid: Optional[UUID] = None) -> UUID:
        # NOTE(ilr) Must be called before checkpointing!
        event = asyncio.Event()
        event.result = FutureResult(goal_state)
        event.set()
        uuid_val = recreation_uuid or uuid4()
        self.inflight_results[uuid_val] = event
        self._serializable_inflight_results[uuid_val] = event.result
        return uuid_val

    async def _num_inflight_results(self) -> int:
        return len(self.inflight_results)

    def notify_replica_handles_changed(self):
        self.long_poll_host.notify_changed(
            LongPollKey.REPLICA_HANDLES, {
                backend_tag: list(replica_dict.values())
                for backend_tag, replica_dict in
                self.actor_reconciler.backend_replicas.items()
            })

    def notify_traffic_policies_changed(self):
        self.long_poll_host.notify_changed(
            LongPollKey.TRAFFIC_POLICIES,
            self.endpoint_state.traffic_policies,
        )

    def notify_backend_configs_changed(self):
        self.long_poll_host.notify_changed(
            LongPollKey.BACKEND_CONFIGS,
            self.backend_state.get_backend_configs())

    def notify_route_table_changed(self):
        self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
                                           self.endpoint_state.routes)

    async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
        """Proxy long pull client's listen request.

        Args:
            keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
              determine whether or not the host should immediately return the
              data or wait for the value to be changed.
        """
        return await (
            self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

    def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
        """Returns a dictionary of node ID to http_proxy actor handles."""
        return self.http_state.get_http_proxy_handles()

    def _checkpoint(self) -> None:
        """Checkpoint internal state and write it to the KV store."""
        assert self.write_lock.locked()
        logger.debug("Writing checkpoint")
        start = time.time()

        checkpoint = pickle.dumps(
            Checkpoint(self.endpoint_state.checkpoint(),
                       self.backend_state.checkpoint(), self.actor_reconciler,
                       self._serializable_inflight_results))

        self.kv_store.put(CHECKPOINT_KEY, checkpoint)
        logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))

        if (random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY
                and self.detached):
            logger.warning("Intentionally crashing after checkpoint")
            os._exit(0)

    async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Recover the instance state from the provided checkpoint.

        This should be called in the constructor to ensure that the internal
        state is updated before any other operations run. After running this,
        internal state will be updated and long-poll clients may be notified.

        Performs the following operations:
            1) Deserializes the internal state from the checkpoint.
            2) Starts/stops any replicas that are pending creation or
               deletion.
        """
        start = time.time()
        logger.info("Recovering from checkpoint")

        self.actor_reconciler = checkpoint.reconciler

        self._serializable_inflight_results = checkpoint.inflight_reqs
        for uuid, fut_result in self._serializable_inflight_results.items():
            self._create_event_with_result(fut_result.requested_goal, uuid)

        # NOTE(edoakes): unfortunately, we can't completely recover from a
        # checkpoint in the constructor because we block while waiting for
        # other actors to start up, and those actors fetch soft state from
        # this actor. Because no other tasks will start executing until after
        # the constructor finishes, if we were to run this logic in the
        # constructor it could lead to deadlock between this actor and a child.
        # However, we do need to guarantee that we have fully recovered from a
        # checkpoint before any other state-changing calls run. We address this
        # by acquiring the write_lock and then posting the task to recover from
        # a checkpoint to the event loop. Other state-changing calls acquire
        # this lock and will be blocked until recovering from the checkpoint
        # finishes. This can be removed once we move to the async control loop.

        async def finish_recover_from_checkpoint():
            assert self.write_lock.locked()
            self.autoscaling_policies = await self.actor_reconciler.\
                _recover_from_checkpoint(self.backend_state, self)
            self.write_lock.release()
            logger.info(
                "Recovered from checkpoint in {:.3f}s".format(time.time() -
                                                              start))

        await self.write_lock.acquire()
        asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
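
    # NOTE: the pattern above holds write_lock across the deferred recovery:
    # the constructor acquires the lock, schedules
    # finish_recover_from_checkpoint() on the event loop, and that task only
    # releases the lock after recovery completes. Any state-changing call
    # that tries to acquire write_lock therefore blocks until recovery is
    # done.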

    async def do_autoscale(self) -> None:
        for backend, info in self.backend_state.backends.items():
            if backend not in self.autoscaling_policies:
                continue

            new_num_replicas = self.autoscaling_policies[backend].scale(
                self.backend_stats[backend], info.backend_config.num_replicas)
            if new_num_replicas > 0:
                await self.update_backend_config(
                    backend, BackendConfig(num_replicas=new_num_replicas))

    async def reconcile_current_and_goal_backends(self):
        pass

    async def run_control_loop(self) -> None:
        while True:
            await self.do_autoscale()
            async with self.write_lock:
                self.http_state.update()

            await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

    def _all_replica_handles(
            self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
        """Used for testing."""
        return self.actor_reconciler.backend_replicas

    def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
        """Returns a dictionary of backend tag to backend config."""
        return self.backend_state.get_backend_configs()

    def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
        """Returns a dictionary of backend tag to backend config."""
        return self.endpoint_state.get_endpoints()

    async def _set_traffic(self, endpoint_name: str,
                           traffic_dict: Dict[str, float]) -> UUID:
        if endpoint_name not in self.endpoint_state.get_endpoints():
            raise ValueError("Attempted to assign traffic for an endpoint '{}'"
                             " that is not registered.".format(endpoint_name))

        assert isinstance(traffic_dict,
                          dict), "Traffic policy must be a dictionary."

        for backend in traffic_dict:
            if self.backend_state.get_backend(backend) is None:
                raise ValueError(
                    "Attempted to assign traffic to a backend '{}' that "
                    "is not registered.".format(backend))

        traffic_policy = TrafficPolicy(traffic_dict)
        self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy

        return_uuid = self._create_event_with_result(
            {endpoint_name: traffic_policy})
        # NOTE(edoakes): we must write a checkpoint before pushing the
        # update to avoid inconsistent state if we crash after pushing the
        # update.
        self._checkpoint()
        self.notify_traffic_policies_changed()
        return return_uuid
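
    # Illustrative sketch (names hypothetical) of the write path that every
    # state-changing call above and below follows: mutate in-memory state,
    # record the result event, persist a checkpoint, and only then notify
    # other actors, so a crash can never leave listeners ahead of the
    # persisted state.
    #
    #     def _apply_change(self, mutate, result, notify):
    #         mutate()                                   # in-memory update
    #         uuid = self._create_event_with_result(result)
    #         self._checkpoint()                         # persist first
    #         notify()                                   # then push updates
    #         return uuid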

    async def set_traffic(self, endpoint_name: str,
                          traffic_dict: Dict[str, float]) -> UUID:
        """Sets the traffic policy for the specified endpoint."""
        async with self.write_lock:
            return_uuid = await self._set_traffic(endpoint_name, traffic_dict)
        return return_uuid
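
    # Example usage (illustrative; assumes `controller` is a handle to this
    # actor and that the endpoint and both backends are registered):
    #
    #     await controller.set_traffic.remote(
    #         "my_endpoint", {"backend_v1": 0.9, "backend_v2": 0.1})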

    async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
                             proportion: float) -> UUID:
        """Shadow traffic from the endpoint to the backend."""
        async with self.write_lock:
            if endpoint_name not in self.endpoint_state.get_endpoints():
                raise ValueError(
                    "Attempted to shadow traffic from an "
                    "endpoint '{}' that is not registered.".format(
                        endpoint_name))

            if self.backend_state.get_backend(backend_tag) is None:
                raise ValueError(
                    "Attempted to shadow traffic to a backend '{}' that "
                    "is not registered.".format(backend_tag))

            self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
                backend_tag, proportion)

            traffic_policy = self.endpoint_state.traffic_policies[
                endpoint_name]

            return_uuid = self._create_event_with_result(
                {endpoint_name: traffic_policy})
            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()
            self.notify_traffic_policies_changed()
            return return_uuid
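
    # Example usage (illustrative): mirror half of "my_endpoint"'s traffic
    # to "backend_canary" without affecting the responses clients see:
    #
    #     await controller.shadow_traffic.remote(
    #         "my_endpoint", "backend_canary", 0.5)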

    # TODO(architkulkarni): add Optional for route after cloudpickle upgrade
    async def create_endpoint(self, endpoint: str,
                              traffic_dict: Dict[str, float], route,
                              methods) -> UUID:
        """Create a new endpoint with the specified route and methods.

        If the route is None, this is a "headless" endpoint that will not
        be exposed over HTTP and can only be accessed via a handle.
        """
        async with self.write_lock:
            # If this is a headless endpoint with no route, key the endpoint
            # based on its name.
            # TODO(edoakes): we should probably just store routes and endpoints
            # separately.
            if route is None:
                route = endpoint

            # TODO(edoakes): move this to client side.
            err_prefix = "Cannot create endpoint."
            if route in self.endpoint_state.routes:
                # Ensure this method is idempotent: re-registering the same
                # route -> (endpoint, methods) mapping is a no-op.
                if self.endpoint_state.routes[route] == (endpoint, methods):
                    return
                else:
                    raise ValueError(
                        "{} Route '{}' is already registered.".format(
                            err_prefix, route))

            if endpoint in self.endpoint_state.get_endpoints():
                raise ValueError(
                    "{} Endpoint '{}' is already registered.".format(
                        err_prefix, endpoint))

            logger.info(
                "Registering route '{}' to endpoint '{}' with methods '{}'.".
                format(route, endpoint, methods))

            self.endpoint_state.routes[route] = (endpoint, methods)

            # NOTE(edoakes): checkpoint is written in self._set_traffic.
            return_uuid = await self._set_traffic(endpoint, traffic_dict)
            self.notify_route_table_changed()
            return return_uuid
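
    # Example usage (illustrative): one endpoint exposed over HTTP and one
    # "headless" endpoint reachable only via a handle:
    #
    #     await controller.create_endpoint.remote(
    #         "my_endpoint", {"backend_v1": 1.0}, "/api", ["GET", "POST"])
    #     await controller.create_endpoint.remote(
    #         "internal_endpoint", {"backend_v1": 1.0}, None, ["GET"])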

    async def delete_endpoint(self, endpoint: str) -> UUID:
        """Delete the specified endpoint.

        Does not modify any corresponding backends.
        """
        logger.info("Deleting endpoint '{}'".format(endpoint))
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified endpoint exists on the client.
            for route, (route_endpoint,
                        _) in self.endpoint_state.routes.items():
                if route_endpoint == endpoint:
                    route_to_delete = route
                    break
            else:
                logger.info("Endpoint '{}' doesn't exist".format(endpoint))
                return

            # Remove the routing entry.
            del self.endpoint_state.routes[route_to_delete]

            # Remove the traffic policy entry if it exists.
            if endpoint in self.endpoint_state.traffic_policies:
                del self.endpoint_state.traffic_policies[endpoint]

            return_uuid = self._create_event_with_result({
                route_to_delete: None,
                endpoint: None
            })
            # NOTE(edoakes): we must write a checkpoint before pushing the
            # updates to the proxies to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()
            self.notify_route_table_changed()
            return return_uuid
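
    # Example usage (illustrative): delete_endpoint is idempotent, so a
    # client that timed out waiting for the first call can safely retry:
    #
    #     await controller.delete_endpoint.remote("my_endpoint")
    #     await controller.delete_endpoint.remote("my_endpoint")  # no-op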

    async def create_backend(self, backend_tag: BackendTag,
                             backend_config: BackendConfig,
                             replica_config: ReplicaConfig) -> UUID:
        """Register a new backend under the specified tag."""
        async with self.write_lock:
            # Ensures this method is idempotent.
            backend_info = self.backend_state.get_backend(backend_tag)
            if backend_info is not None:
                if (backend_info.backend_config == backend_config
                        and backend_info.replica_config == replica_config):
                    return

            backend_replica = create_backend_replica(
                replica_config.func_or_class)

            # Save the creator that starts replicas, the arguments to be
            # passed in, and the configuration for the backend.
            backend_info = BackendInfo(worker_class=backend_replica,
                                       backend_config=backend_config,
                                       replica_config=replica_config)
            self.backend_state.add_backend(backend_tag, backend_info)
            metadata = backend_config.internal_metadata
            if metadata.autoscaling_config is not None:
                self.autoscaling_policies[
                    backend_tag] = BasicAutoscalingPolicy(
                        backend_tag, metadata.autoscaling_config)

            try:
                # TODO: this should go through the control loop instead of
                # being called directly.
                self.actor_reconciler._scale_backend_replicas(
                    self.backend_state.backends, backend_tag,
                    backend_config.num_replicas)
            except RayServeException as e:
                del self.backend_state.backends[backend_tag]
                raise e

            return_uuid = self._create_event_with_result(
                {backend_tag: backend_info})
            # NOTE(edoakes): we must write a checkpoint before starting new
            # replicas or pushing the updated config to avoid inconsistent
            # state if we crash while making the change.
            self._checkpoint()
            await self.actor_reconciler._enqueue_pending_scale_changes_loop(
                self.backend_state)
            await self.actor_reconciler.backend_control_loop()

            self.notify_replica_handles_changed()

            # Set the backend config inside routers
            # (particularly for max_concurrent_queries).
            self.notify_backend_configs_changed()
            return return_uuid
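
    # Example usage (illustrative; the exact BackendConfig and ReplicaConfig
    # constructor arguments are assumptions, not this module's API):
    #
    #     def my_backend(request):
    #         return "hello"
    #
    #     await controller.create_backend.remote(
    #         "backend_v1", BackendConfig(num_replicas=2),
    #         ReplicaConfig(my_backend))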

    async def delete_backend(self, backend_tag: BackendTag) -> UUID:
        """Delete the specified backend.

        Fails if the backend is still referenced by any endpoint's traffic
        or shadow policy.
        """
        async with self.write_lock:
            # This method must be idempotent. We should validate that the
            # specified backend exists on the client.
            if self.backend_state.get_backend(backend_tag) is None:
                return

            # Check that the specified backend isn't used by any endpoints.
            for endpoint, traffic_policy in (
                    self.endpoint_state.traffic_policies.items()):
                if (backend_tag in traffic_policy.traffic_dict
                        or backend_tag in traffic_policy.shadow_dict):
                    raise ValueError("Backend '{}' is used by endpoint '{}' "
                                     "and cannot be deleted. Please remove "
                                     "the backend from all endpoints and try "
                                     "again.".format(backend_tag, endpoint))

            # Scale its replicas down to 0. This will also remove the backend
            # from self.backend_state.backends and
            # self.actor_reconciler.backend_replicas.

            # TODO: this should go through the control loop instead of
            # being called directly.
            self.actor_reconciler._scale_backend_replicas(
                self.backend_state.backends, backend_tag, 0)

            # Remove the backend's metadata.
            del self.backend_state.backends[backend_tag]
            if backend_tag in self.autoscaling_policies:
                del self.autoscaling_policies[backend_tag]

            # Add the intention to remove the backend from the routers.
            self.actor_reconciler.backends_to_remove.append(backend_tag)

            return_uuid = self._create_event_with_result({backend_tag: None})
            # NOTE(edoakes): we must write a checkpoint before removing the
            # backend from the routers to avoid inconsistent state if we crash
            # after pushing the update.
            self._checkpoint()
            await self.actor_reconciler._enqueue_pending_scale_changes_loop(
                self.backend_state)
            await self.actor_reconciler.backend_control_loop()

            self.notify_replica_handles_changed()
            return return_uuid
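
    # Example usage (illustrative): a backend still referenced by an
    # endpoint must be drained before it can be deleted:
    #
    #     await controller.set_traffic.remote("my_endpoint",
    #                                         {"backend_v2": 1.0})
    #     await controller.delete_backend.remote("backend_v1")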

    async def update_backend_config(self, backend_tag: BackendTag,
                                    config_options: BackendConfig) -> UUID:
        """Set the config for the specified backend."""
        async with self.write_lock:
            assert (self.backend_state.get_backend(backend_tag)
                    ), "Backend {} is not registered.".format(backend_tag)
            assert isinstance(config_options, BackendConfig)

            stored_backend_config = self.backend_state.get_backend(
                backend_tag).backend_config
            backend_config = stored_backend_config.copy(
                update=config_options.dict(exclude_unset=True))
            backend_config._validate_complete()
            self.backend_state.get_backend(
                backend_tag).backend_config = backend_config
            backend_info = self.backend_state.get_backend(backend_tag)

            # Scale the replicas with the new configuration.
            # TODO: this should go through the control loop instead of
            # being called directly.
            self.actor_reconciler._scale_backend_replicas(
                self.backend_state.backends, backend_tag,
                backend_config.num_replicas)

            return_uuid = self._create_event_with_result(
                {backend_tag: backend_info})
            # NOTE(edoakes): we must write a checkpoint before pushing the
            # update to avoid inconsistent state if we crash after pushing the
            # update.
            self._checkpoint()

            # Inform the routers about the configuration change
            # (particularly for setting max_batch_size).

            await self.actor_reconciler._enqueue_pending_scale_changes_loop(
                self.backend_state)
            await self.actor_reconciler.backend_control_loop()

            self.notify_replica_handles_changed()
            self.notify_backend_configs_changed()
            return return_uuid
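
    # Illustrative sketch of the partial-update merge above (assuming
    # BackendConfig is a pydantic model, as the copy()/dict() calls
    # suggest): dict(exclude_unset=True) keeps only the fields the caller
    # explicitly set, so copy(update=...) overrides just those fields.
    #
    #     old = BackendConfig(num_replicas=2)
    #     new = old.copy(
    #         update=BackendConfig(num_replicas=5).dict(exclude_unset=True))
    #     # new.num_replicas == 5; every other field keeps its old value.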

    def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
        """Get the current config for the specified backend."""
        assert (self.backend_state.get_backend(backend_tag)
                ), "Backend {} is not registered.".format(backend_tag)
        return self.backend_state.get_backend(backend_tag).backend_config

    def get_http_config(self):
        """Return the HTTP proxy configuration."""
        return self.http_state.get_config()

    async def shutdown(self) -> None:
        """Shuts down the serve instance completely."""
        async with self.write_lock:
            for proxy in self.http_state.get_http_proxy_handles().values():
                ray.kill(proxy, no_restart=True)
            for replica in self.actor_reconciler.get_replica_handles():
                ray.kill(replica, no_restart=True)
            self.kv_store.delete(CHECKPOINT_KEY)