Example #1
0
async def test_async_client(serve_instance):
    """End-to-end test of LongPollAsyncClient against a LongPollHost actor.

    Verifies that (1) sync callbacks are rejected at construction, (2) the
    initial snapshot values are delivered for both keys, and (3) a later
    update to one key is observed by both the snapshot map and the callback.
    """
    host = ray.remote(LongPollHost).remote()

    # Write two values so the client has data to fetch immediately.
    ray.get(host.notify_changed.remote("key_1", 100))
    ray.get(host.notify_changed.remote("key_2", 999))

    # Check that construction fails with a sync callback; the async client
    # requires coroutine-function callbacks.
    def callback(result, key):
        pass

    with pytest.raises(ValueError):
        LongPollAsyncClient(host, {"key": callback})

    callback_results = dict()

    async def key_1_callback(result):
        callback_results["key_1"] = result

    async def key_2_callback(result):
        callback_results["key_2"] = result

    client = LongPollAsyncClient(host, {
        "key_1": key_1_callback,
        "key_2": key_2_callback,
    })

    # BUGFIX: wait until *both* keys have arrived, not just one. Snapshots
    # for the two keys can land in separate long-poll responses, so the old
    # `len(...) == 0` check raced with the assertions below.
    while len(client.object_snapshots) < 2:
        # Yield the loop for client to get the result
        await asyncio.sleep(0.2)

    assert client.object_snapshots["key_1"] == 100
    assert client.object_snapshots["key_2"] == 999

    ray.get(host.notify_changed.remote("key_2", 1999))

    # Poll (bounded, so a failure can't hang the suite) until the updated
    # value for key_2 is observed.
    values = set()
    for _ in range(3):
        values.add(client.object_snapshots["key_2"])
        if 1999 in values:
            break
        await asyncio.sleep(1)
    assert 1999 in values

    assert callback_results == {"key_1": 100, "key_2": 1999}
Example #2
0
 async def setup_in_async_loop(self):
     """Finish initialization that must run inside an async event loop.

     NOTE(simon): __init__ may be invoked from a sync context, but
     LongPollAsyncClient requires an async context, so its creation is
     deferred to this coroutine instead of happening in __init__.
     """
     key_callbacks = {
         LongPollKey.TRAFFIC_POLICIES: self._update_traffic_policies,
         LongPollKey.REPLICA_HANDLES: self._update_replica_handles,
         LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
     }
     self.long_poll_client = LongPollAsyncClient(self.controller,
                                                 key_callbacks)
Example #3
0
    def __init__(self, controller_name):
        """Look up the named controller actor and set up routing + metrics."""
        controller = ray.get_actor(controller_name)
        self.router = Router(controller)

        # Keep the route table up to date via the controller's long-poll host.
        route_callbacks = {LongPollKey.ROUTE_TABLE: self._update_route_table}
        self.long_poll_client = LongPollAsyncClient(controller,
                                                    route_callbacks)

        self.request_counter = metrics.Count(
            "num_http_requests",
            description="The number of HTTP requests processed",
            tag_keys=("route", ))
Example #4
0
    def __init__(self, controller_name):
        """Connect to the given controller and set up routing + metrics."""
        # Point serve.connect() at the controller instance this proxy is
        # running in before connecting.
        ray.serve.api._set_internal_controller_name(controller_name)
        self.client = ray.serve.connect()

        controller = ray.get_actor(controller_name)
        # Route table starts empty; refreshed via long polling below.
        self.route_table = {}
        self.router = Router(controller)
        self.long_poll_client = LongPollAsyncClient(
            controller,
            {LongPollKey.ROUTE_TABLE: self._update_route_table})

        self.request_counter = metrics.Count(
            "num_http_requests",
            description="The number of HTTP requests processed",
            tag_keys=("route", ))
Example #5
0
    def __init__(self, controller_name):
        """Wire this proxy to the named controller and set up routing state."""
        # Set the controller name so that serve.connect() targets the
        # controller instance this proxy is running in.
        ray.serve.api._set_internal_replica_context(None, None,
                                                    controller_name)
        self.client = ray.serve.connect()

        controller = ray.get_actor(controller_name)

        self.router = starlette.routing.Router(default=self._not_found)

        # Maps route -> (endpoint_tag, methods); refreshed via long polling.
        self.route_table: Dict[str, Tuple[EndpointTag, List[str]]] = {}

        route_callbacks = {LongPollKey.ROUTE_TABLE: self._update_route_table}
        self.long_poll_client = LongPollAsyncClient(controller,
                                                    route_callbacks)

        self.request_counter = metrics.Counter(
            "serve_num_http_requests",
            description="The number of HTTP requests processed.",
            tag_keys=("route", ))
Example #6
0
    def __init__(self, _callable: Callable, backend_config: BackendConfig,
                 is_function: bool, controller_handle: ActorHandle) -> None:
        """Wrap the user callable as a replica: queueing, metrics, config
        long-polling, logging, and the background main loop."""
        replica_context = ray.serve.api.get_replica_context()
        self.backend_tag = replica_context.backend_tag
        self.replica_tag = replica_context.replica_tag
        self.callable = _callable
        self.is_function = is_function

        self.config = backend_config
        # A max_batch_size of None/0 falls back to unbatched (size 1).
        batch_size = self.config.max_batch_size or 1
        self.batch_queue = _BatchQueue(batch_size,
                                       self.config.batch_wait_timeout)
        self.reconfigure(self.config.user_config)

        self.num_ongoing_requests = 0

        # Default tag sets shared by the metrics below; each metric gets its
        # own copy so none of them alias the same dict.
        backend_tags = {"backend": self.backend_tag}
        replica_tags = {
            "backend": self.backend_tag,
            "replica": self.replica_tag,
        }

        self.request_counter = metrics.Counter(
            "serve_backend_request_counter",
            description=("The number of queries that have been "
                         "processed in this replica."),
            tag_keys=("backend", ))
        self.request_counter.set_default_tags(dict(backend_tags))

        # Watch for backend config changes pushed by the controller.
        self.long_poll_client = LongPollAsyncClient(controller_handle, {
            LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
        })

        self.error_counter = metrics.Counter(
            "serve_backend_error_counter",
            description=("The number of exceptions that have "
                         "occurred in the backend."),
            tag_keys=("backend", ))
        self.error_counter.set_default_tags(dict(backend_tags))

        self.restart_counter = metrics.Counter(
            "serve_backend_replica_starts",
            description=("The number of times this replica "
                         "has been restarted due to failure."),
            tag_keys=("backend", "replica"))
        self.restart_counter.set_default_tags(dict(replica_tags))

        self.queuing_latency_tracker = metrics.Histogram(
            "serve_backend_queuing_latency_ms",
            description=("The latency for queries in the replica's queue "
                         "waiting to be processed or batched."),
            boundaries=DEFAULT_LATENCY_BUCKET_MS,
            tag_keys=("backend", "replica"))
        self.queuing_latency_tracker.set_default_tags(dict(replica_tags))

        self.processing_latency_tracker = metrics.Histogram(
            "serve_backend_processing_latency_ms",
            description="The latency for queries to be processed.",
            boundaries=DEFAULT_LATENCY_BUCKET_MS,
            tag_keys=("backend", "replica", "batch_size"))
        self.processing_latency_tracker.set_default_tags(dict(replica_tags))

        self.num_queued_items = metrics.Gauge(
            "serve_replica_queued_queries",
            description=("The current number of queries queued in "
                         "the backend replicas."),
            tag_keys=("backend", "replica"))
        self.num_queued_items.set_default_tags(dict(replica_tags))

        self.num_processing_items = metrics.Gauge(
            "serve_replica_processing_queries",
            description="The current number of queries being processed.",
            tag_keys=("backend", "replica"))
        self.num_processing_items.set_default_tags(dict(replica_tags))

        # Count this (re)start.
        self.restart_counter.inc()

        # Tag every "ray" log record with this replica's identity.
        log_suffix = (f" component=serve backend={self.backend_tag} "
                      f"replica={self.replica_tag}")
        ray_logger = logging.getLogger("ray")
        for handler in ray_logger.handlers:
            handler.setFormatter(
                logging.Formatter(handler.formatter._fmt + log_suffix))

        # Kick off the replica's main processing loop in the background.
        event_loop = asyncio.get_event_loop()
        event_loop.create_task(self.main_loop())