class Router:
    """A router that routes requests to available workers."""

    async def setup(self, instance_name=None):
        # Note: Several queues are used in the router.
        # - When a request comes in, it's placed inside its corresponding
        #   endpoint_queue.
        # - The endpoint_queue is dequeued during the flush operation, which
        #   moves the queries to the backend buffer_queue. Here we match a
        #   request for an endpoint to a backend given some policy.
        # - The worker_queue is used to collect idle actor handles. These
        #   handles are dequeued during the second stage of the flush
        #   operation, which assigns queries in buffer_queue to actor handles.

        # -- Queues -- #

        # endpoint_name -> request queue
        # We use FIFO (left to right) ordering. New items should be added
        # using appendleft(), and old items should be removed via pop().
        self.endpoint_queues: DefaultDict[str, deque] = defaultdict(deque)
        # backend_name -> worker replica tag queue
        self.worker_queues: DefaultDict[str, deque] = defaultdict(deque)
        # backend_name -> worker payload queue
        self.backend_queues = defaultdict(blist.sortedlist)

        # -- Metadata -- #

        # endpoint_name -> traffic_policy
        self.traffic = dict()
        # backend_name -> backend_config
        self.backend_info = dict()
        # replica tag -> worker_handle
        self.replicas = dict()
        # replica_tag -> concurrent queries counter
        self.queries_counter = defaultdict(lambda: 0)

        # -- Synchronization -- #

        # This lock guarantees that only one flush operation can happen at a
        # time. Without the lock, multiple flush operations could pop from the
        # same buffer_queue and worker_queue and create a deadlock, e.g. one
        # operation holding the only query while the other flush operation
        # holds the only idle replica. Additionally, allowing only one flush
        # operation at a time simplifies the design overhead for custom
        # queuing and batching policies.
        self.flush_lock = asyncio.Lock()

        # -- State Restoration -- #

        # Fetch the worker handles, traffic policies, and backend configs from
        # the controller. We use a "pull-based" approach instead of pushing
        # them from the controller so that the router can transparently
        # recover from failure.
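        # Illustrative note (added, not in the original source): conceptually,
        # this pull-based recovery rebuilds all routing state from the
        # controller,
        #   traffic policies -> self.traffic
        #   worker handles   -> self.replicas (and the idle worker_queues)
        #   backend configs  -> self.backend_info
        # so a restarted router needs no push-style replay from the
        # controller.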
        serve.init(name=instance_name)
        controller = serve.api._get_controller()

        traffic_policies = ray.get(controller.get_traffic_policies.remote())
        for endpoint, traffic_policy in traffic_policies.items():
            await self.set_traffic(endpoint, traffic_policy)

        backend_dict = ray.get(controller.get_all_worker_handles.remote())
        for backend_tag, replica_dict in backend_dict.items():
            for replica_tag, worker in replica_dict.items():
                await self.add_new_worker(backend_tag, replica_tag, worker)

        backend_configs = ray.get(controller.get_backend_configs.remote())
        for backend, backend_config in backend_configs.items():
            await self.set_backend_config(backend, backend_config)

        # -- Metric Registration -- #
        [metric_exporter] = ray.get(controller.get_metric_exporter.remote())
        self.metric_client = MetricClient(metric_exporter)
        self.num_router_requests = self.metric_client.new_counter(
            "num_router_requests",
            description="Number of requests processed by the router.",
            label_names=("endpoint", ))
        self.num_error_endpoint_request = self.metric_client.new_counter(
            "num_error_endpoint_requests",
            description=("Number of requests errored when getting result "
                         "for endpoint."),
            label_names=("endpoint", ))
        self.num_error_backend_request = self.metric_client.new_counter(
            "num_error_backend_requests",
            description=("Number of requests errored when getting result "
                         "from backend."),
            label_names=("backend", ))

    def is_ready(self):
        return True

    async def enqueue_request(self, request_meta, *request_args,
                              **request_kwargs):
        endpoint = request_meta.endpoint
        logger.debug("Received a request for endpoint {}".format(endpoint))
        self.num_router_requests.labels(endpoint=endpoint).add()

        # Check whether the specified SLO is directly the wall clock time.
        if request_meta.absolute_slo_ms is not None:
            request_slo_ms = request_meta.absolute_slo_ms
        else:
            request_slo_ms = request_meta.adjust_relative_slo_ms()
        request_context = request_meta.request_context
        query = Query(
            request_args,
            request_kwargs,
            request_context,
            request_slo_ms,
            call_method=request_meta.call_method,
            shard_key=request_meta.shard_key,
            async_future=asyncio.get_event_loop().create_future())
        async with self.flush_lock:
            self.endpoint_queues[endpoint].appendleft(query)
            self.flush_endpoint_queue(endpoint)

        try:
            result = await query.async_future
        except RayTaskError as e:
            self.num_error_endpoint_request.labels(endpoint=endpoint).add()
            result = e
        return result

    async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
        backend_replica_tag = backend_tag + ":" + replica_tag
        if backend_replica_tag in self.replicas:
            return
        self.replicas[backend_replica_tag] = worker_handle

        logger.debug("New worker added for backend '{}'".format(backend_tag))
        await self.mark_worker_idle(backend_tag, backend_replica_tag)

    async def mark_worker_idle(self, backend_tag, backend_replica_tag):
        logger.debug("Marking backend with tag {} as idle.".format(
            backend_replica_tag))
        if backend_replica_tag not in self.replicas:
            return

        async with self.flush_lock:
            # NOTE(simon): This is an O(n) operation where n=len(worker_queue)
            if backend_replica_tag not in self.worker_queues[backend_tag]:
                self.worker_queues[backend_tag].appendleft(backend_replica_tag)
            self.flush_backend_queues([backend_tag])

    async def remove_worker(self, backend_tag, replica_tag):
        backend_replica_tag = backend_tag + ":" + replica_tag
        if backend_replica_tag not in self.replicas:
            return

        # We need this lock because we modify worker_queue here.
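        # (Hedged note, added) Without this lock, the removal below could
        # interleave with a flush that is concurrently popping idle replica
        # tags from the same worker_queue.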
        async with self.flush_lock:
            del self.replicas[backend_replica_tag]
            try:
                self.worker_queues[backend_tag].remove(backend_replica_tag)
            except ValueError:
                # The replica doesn't exist in the idle worker queue. That's
                # fine because the worker might not have returned a result
                # yet.
                pass

    async def set_traffic(self, endpoint, traffic_policy):
        logger.debug("Setting traffic for endpoint %s to %s", endpoint,
                     traffic_policy)
        async with self.flush_lock:
            self.traffic[endpoint] = RandomEndpointPolicy(traffic_policy)
            self.flush_endpoint_queue(endpoint)

    async def remove_endpoint(self, endpoint):
        logger.debug("Removing endpoint {}".format(endpoint))
        async with self.flush_lock:
            self.flush_endpoint_queue(endpoint)
            if endpoint in self.endpoint_queues:
                del self.endpoint_queues[endpoint]
            if endpoint in self.traffic:
                del self.traffic[endpoint]

    async def set_backend_config(self, backend, config):
        logger.debug("Setting backend config for "
                     "backend {} to {}.".format(backend, config))
        async with self.flush_lock:
            self.backend_info[backend] = config

    async def remove_backend(self, backend):
        logger.debug("Removing backend {}".format(backend))
        async with self.flush_lock:
            self.flush_backend_queues([backend])
            if backend in self.backend_info:
                del self.backend_info[backend]
            if backend in self.worker_queues:
                del self.worker_queues[backend]
            if backend in self.backend_queues:
                del self.backend_queues[backend]

    def flush_endpoint_queue(self, endpoint):
        """Attempt to schedule any pending requests to available backends."""
        assert self.flush_lock.locked()
        if endpoint not in self.traffic:
            return
        backends_to_flush = self.traffic[endpoint].flush(
            self.endpoint_queues[endpoint], self.backend_queues)
        self.flush_backend_queues(backends_to_flush)

    # Flushes the specified backend queues and assigns work to workers.
    def flush_backend_queues(self, backends_to_flush):
        assert self.flush_lock.locked()
        for backend in backends_to_flush:
            # No workers available.
            if len(self.worker_queues[backend]) == 0:
                continue
            # No work to do.
            if len(self.backend_queues[backend]) == 0:
                continue

            buffer_queue = self.backend_queues[backend]
            worker_queue = self.worker_queues[backend]

            logger.debug("Assigning queries for backend {} with buffer "
                         "queue size {} and worker queue size {}".format(
                             backend, len(buffer_queue), len(worker_queue)))

            self._assign_query_to_worker(
                backend,
                buffer_queue,
                worker_queue,
            )

    async def _do_query(self, backend, backend_replica_tag, req):
        # If the worker died, this will be a RayActorError. Just return it and
        # let the HTTP proxy handle the retry logic.
        logger.debug("Sending query to replica: " + backend_replica_tag)
        start = time.time()
        worker = self.replicas[backend_replica_tag]
        try:
            object_ref = worker.handle_request.remote(req.ray_serialize())
            if req.is_shadow_query:
                # No need to actually get the result, but we do need to wait
                # until the call completes to mark the worker idle.
                await asyncio.wait([object_ref])
                result = ""
            else:
                result = await object_ref
        except RayTaskError as error:
            self.num_error_backend_request.labels(backend=backend).add()
            result = error
        self.queries_counter[backend_replica_tag] -= 1
        await self.mark_worker_idle(backend, backend_replica_tag)
        logger.debug("Got result in {:.2f}s".format(time.time() - start))
        return result

    def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
        overloaded_replicas = set()
        while len(buffer_queue) and len(worker_queue):
            backend_replica_tag = worker_queue.pop()

            # The replica might have been deleted already.
            if backend_replica_tag not in self.replicas:
                continue

            # We have reached the end of the worker queue where all replicas
            # are overloaded.
            if backend_replica_tag in overloaded_replicas:
                break

            # This replica has too many in flight and processing queries.
            max_queries = 1
            if backend in self.backend_info:
                max_queries = self.backend_info[backend].max_concurrent_queries
            curr_queries = self.queries_counter[backend_replica_tag]
            if curr_queries >= max_queries:
                # Put the worker back to the queue.
                worker_queue.appendleft(backend_replica_tag)
                overloaded_replicas.add(backend_replica_tag)
                logger.debug(
                    "Skipping backend {} because it has {} in flight "
                    "requests which exceeded the concurrency limit.".format(
                        backend, curr_queries))
                continue

            request = buffer_queue.pop(0)
            self.queries_counter[backend_replica_tag] += 1
            future = asyncio.get_event_loop().create_task(
                self._do_query(backend, backend_replica_tag, request))

            # For shadow queries, just ignore the result.
            if not request.is_shadow_query:
                chain_future(future, request.async_future)

            worker_queue.appendleft(backend_replica_tag)
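
# A minimal usage sketch (illustrative only, not part of the original module;
# the instance name, endpoint name, and arguments below are assumptions):
#
#   router = Router()
#   await router.setup(instance_name="default")
#   meta = RequestMetadata("my_endpoint", TaskContext.Python)
#   result = await router.enqueue_request(meta, "input-data")
#
# enqueue_request buffers the query in the endpoint queue and resolves its
# async_future once a replica returns a result (or an error). The HTTPProxy
# class below exposes this Router behind an ASGI HTTP endpoint.
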
class HTTPProxy:
    """This class should be instantiated and run by an ASGI server.

    >>> import uvicorn
    >>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
    # blocks forever
    """

    async def fetch_config_from_controller(self, instance_name=None):
        assert ray.is_initialized()
        controller = serve.api._get_controller()

        self.route_table = await controller.get_router_config.remote()

        # The exporter is required to return results for the /-/metrics
        # endpoint.
        [self.metric_exporter] = await controller.get_metric_exporter.remote()

        self.metric_client = MetricClient(self.metric_exporter)
        self.request_counter = self.metric_client.new_counter(
            "num_http_requests",
            description="The number of requests processed.",
            label_names=("route", ))

        self.router = Router()
        await self.router.setup(instance_name)

    def set_route_table(self, route_table):
        self.route_table = route_table

    async def receive_http_body(self, scope, receive, send):
        body_buffer = []
        more_body = True
        while more_body:
            message = await receive()
            assert message["type"] == "http.request"

            more_body = message["more_body"]
            body_buffer.append(message["body"])

        return b"".join(body_buffer)

    def _parse_latency_slo(self, scope):
        query_string = scope["query_string"].decode("ascii")
        query_kwargs = parse_qs(query_string)

        relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
        absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
        relative_slo_ms = self._validate_slo_ms(relative_slo_ms)
        absolute_slo_ms = self._validate_slo_ms(absolute_slo_ms)
        if relative_slo_ms is not None and absolute_slo_ms is not None:
            raise ValueError("Both relative and absolute SLOs "
                             "cannot be specified.")
        return relative_slo_ms, absolute_slo_ms

    def _validate_slo_ms(self, request_slo_ms):
        if request_slo_ms is None:
            return None
        if len(request_slo_ms) != 1:
            raise ValueError(
                "Multiple SLOs specified, please specify only one.")
        request_slo_ms = request_slo_ms[0]
        request_slo_ms = float(request_slo_ms)
        if request_slo_ms < 0:
            raise ValueError(
                "Request SLO must be positive, it is {}.".format(
                    request_slo_ms))
        return request_slo_ms

    def _make_error_sender(self, scope, receive, send):
        async def sender(error_message, status_code):
            response = Response(error_message, status_code=status_code)
            await response.send(scope, receive, send)

        return sender

    async def _handle_system_request(self, scope, receive, send):
        current_path = scope["path"]
        if current_path == "/-/routes":
            await Response(self.route_table).send(scope, receive, send)
        elif current_path == "/-/metrics":
            metric_info = await self.metric_exporter.inspect_metrics.remote()
            await Response(metric_info).send(scope, receive, send)
        else:
            await Response(
                "System path {} not found".format(current_path),
                status_code=404).send(scope, receive, send)

    async def __call__(self, scope, receive, send):
        # NOTE: This implements the ASGI protocol specified at
        # https://asgi.readthedocs.io/en/latest/specs/index.html

        error_sender = self._make_error_sender(scope, receive, send)

        assert self.route_table is not None, (
            "Route table must be set via set_route_table.")
        assert scope["type"] == "http"
        current_path = scope["path"]

        self.request_counter.labels(route=current_path).add()

        if current_path.startswith("/-/"):
            await self._handle_system_request(scope, receive, send)
            return

        try:
            endpoint_name, methods_allowed = self.route_table[current_path]
        except KeyError:
            error_message = ("Path {} not found. "
                             "Please ping http://.../-/routes for the "
                             "routing table.").format(current_path)
            await error_sender(error_message, 404)
            return

        if scope["method"] not in methods_allowed:
            error_message = ("Methods {} not allowed. "
                             "Available HTTP methods are {}.").format(
                                 scope["method"], methods_allowed)
            await error_sender(error_message, 405)
            return

        http_body_bytes = await self.receive_http_body(scope, receive, send)

        # Get slo_ms before enqueuing the query.
        try:
            relative_slo_ms, absolute_slo_ms = self._parse_latency_slo(scope)
        except ValueError as e:
            await error_sender(str(e), 400)
            return

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}
        request_metadata = RequestMetadata(
            endpoint_name,
            TaskContext.Web,
            relative_slo_ms=relative_slo_ms,
            absolute_slo_ms=absolute_slo_ms,
            call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
        )

        result = await self.router.enqueue_request(request_metadata, scope,
                                                   http_body_bytes)

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await error_sender(error_message, 500)
        else:
            await Response(result).send(scope, receive, send)


class Router:
    """A router that routes requests to available workers.

    The traffic policy is used to assign requests to workers. It splits the
    traffic among different replicas probabilistically:

    1. When all backends are ready to receive traffic, we will randomly
       choose a backend based on the weights assigned by the traffic policy
       dictionary.
    2. When more than one but not all backends are ready, we will normalize
       the weights of the ready backends to 1 and choose a backend via
       sampling.
    3. When only one backend is ready, we will use that backend exclusively.
    """

    async def __init__(self, cluster_name=None):
        # Note: Several queues are used in the router.
        # - When a request comes in, it's placed inside its corresponding
        #   endpoint_queue.
        # - The endpoint_queue is dequeued during the flush operation, which
        #   moves the queries to the backend buffer_queue. Here we match a
        #   request for an endpoint to a backend given some policy.
        # - The worker_queue is used to collect idle actor handles. These
        #   handles are dequeued during the second stage of the flush
        #   operation, which assigns queries in buffer_queue to actor handles.

        # -- Queues -- #

        # endpoint_name -> request queue
        self.endpoint_queues: DefaultDict[str, asyncio.Queue] = defaultdict(
            asyncio.Queue)
        # backend_name -> worker request queue
        self.worker_queues: DefaultDict[str, asyncio.Queue] = defaultdict(
            asyncio.Queue)
        # backend_name -> worker payload queue
        self.backend_queues = defaultdict(blist.sortedlist)

        # -- Metadata -- #

        # endpoint_name -> traffic_policy
        self.traffic = dict()
        # backend_name -> backend_config
        self.backend_info = dict()
        # replica tag -> worker_handle
        self.replicas = dict()

        # -- Synchronization -- #

        # This lock guarantees that only one flush operation can happen at a
        # time. Without the lock, multiple flush operations could pop from the
        # same buffer_queue and worker_queue and create a deadlock, e.g. one
        # operation holding the only query while the other flush operation
        # holds the only idle replica. Additionally, allowing only one flush
        # operation at a time simplifies the design overhead for custom
        # queuing and batching policies.
        self.flush_lock = asyncio.Lock()

        # Fetch the worker handles, traffic policies, and backend configs from
        # the master actor. We use a "pull-based" approach instead of pushing
        # them from the master so that the router can transparently recover
        # from failure.
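        # Worked example (illustrative numbers, added for clarity): with a
        # traffic dict of {"v1": 0.7, "v2": 0.2, "v3": 0.1}, if only v1 and v2
        # currently have idle workers, their weights are renormalized to
        # roughly 0.78 and 0.22 before sampling; if only v3 is ready, it
        # receives all of the traffic.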
        serve.init(cluster_name=cluster_name)
        master_actor = serve.api._get_master_actor()

        traffic_policies = retry_actor_failures(
            master_actor.get_traffic_policies)
        for endpoint, traffic_policy in traffic_policies.items():
            await self.set_traffic(endpoint, traffic_policy)

        backend_dict = retry_actor_failures(
            master_actor.get_all_worker_handles)
        for backend_tag, replica_dict in backend_dict.items():
            for replica_tag, worker in replica_dict.items():
                await self.add_new_worker(backend_tag, replica_tag, worker)

        backend_configs = retry_actor_failures(
            master_actor.get_backend_configs)
        for backend, backend_config in backend_configs.items():
            await self.set_backend_config(backend, backend_config)

        [metric_exporter] = retry_actor_failures(
            master_actor.get_metric_exporter)
        self.metric_client = MetricClient(metric_exporter)
        self.num_router_requests = self.metric_client.new_counter(
            "num_router_requests",
            description="Number of requests processed by the router.",
            label_names=("endpoint", ))
        self.num_error_endpoint_request = self.metric_client.new_counter(
            "num_error_endpoint_requests",
            description=("Number of requests errored when getting result "
                         "for endpoint."),
            label_names=("endpoint", ))
        self.num_error_backend_request = self.metric_client.new_counter(
            "num_error_backend_requests",
            description=("Number of requests errored when getting result "
                         "from backend."),
            label_names=("backend", ))

    def is_ready(self):
        return True

    async def enqueue_request(self, request_meta, *request_args,
                              **request_kwargs):
        endpoint = request_meta.endpoint
        logger.debug("Received a request for endpoint {}".format(endpoint))
        self.num_router_requests.labels(endpoint=endpoint).add()

        # Check whether the specified SLO is directly the wall clock time.
        if request_meta.absolute_slo_ms is not None:
            request_slo_ms = request_meta.absolute_slo_ms
        else:
            request_slo_ms = request_meta.adjust_relative_slo_ms()
        request_context = request_meta.request_context
        query = Query(
            request_args,
            request_kwargs,
            request_context,
            request_slo_ms,
            call_method=request_meta.call_method,
            shard_key=request_meta.shard_key,
            async_future=asyncio.get_event_loop().create_future())
        await self.endpoint_queues[endpoint].put(query)
        async with self.flush_lock:
            await self.flush_endpoint_queue(endpoint)

        # Note: a future change could be to directly return the ObjectID from
        # the replica task submission.
        try:
            result = await query.async_future
        except RayTaskError as e:
            self.num_error_endpoint_request.labels(endpoint=endpoint).add()
            result = e
        return result

    async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
        backend_replica_tag = backend_tag + ":" + replica_tag
        if backend_replica_tag in self.replicas:
            return
        self.replicas[backend_replica_tag] = worker_handle

        logger.debug("New worker added for backend '{}'".format(backend_tag))
        await self.mark_worker_idle(backend_tag, backend_replica_tag)

    async def mark_worker_idle(self, backend_tag, backend_replica_tag):
        if backend_replica_tag not in self.replicas:
            return
        await self.worker_queues[backend_tag].put(backend_replica_tag)
        async with self.flush_lock:
            await self.flush_backend_queues([backend_tag])

    async def remove_worker(self, backend_tag, replica_tag):
        backend_replica_tag = backend_tag + ":" + replica_tag
        if backend_replica_tag not in self.replicas:
            return
        del self.replicas[backend_replica_tag]

        # We need this lock because we modify worker_queue here.
        async with self.flush_lock:
            old_queue = self.worker_queues[backend_tag]
            new_queue = asyncio.Queue()

            while not old_queue.empty():
                curr_tag = await old_queue.get()
                if curr_tag != backend_replica_tag:
                    await new_queue.put(curr_tag)

            self.worker_queues[backend_tag] = new_queue

    async def set_traffic(self, endpoint, traffic_dict):
        logger.debug("Setting traffic for endpoint %s to %s", endpoint,
                     traffic_dict)
        async with self.flush_lock:
            self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
            await self.flush_endpoint_queue(endpoint)

    async def remove_endpoint(self, endpoint):
        logger.debug("Removing endpoint {}".format(endpoint))
        async with self.flush_lock:
            await self.flush_endpoint_queue(endpoint)
            if endpoint in self.endpoint_queues:
                del self.endpoint_queues[endpoint]
            if endpoint in self.traffic:
                del self.traffic[endpoint]

    async def set_backend_config(self, backend, config):
        logger.debug("Setting backend config for "
                     "backend {} to {}.".format(backend, config))
        self.backend_info[backend] = config

    async def remove_backend(self, backend):
        logger.debug("Removing backend {}".format(backend))
        async with self.flush_lock:
            await self.flush_backend_queues([backend])
            if backend in self.backend_info:
                del self.backend_info[backend]
            if backend in self.worker_queues:
                del self.worker_queues[backend]
            if backend in self.backend_queues:
                del self.backend_queues[backend]

    async def flush_endpoint_queue(self, endpoint):
        """Attempt to schedule any pending requests to available backends."""
        assert self.flush_lock.locked()
        if endpoint not in self.traffic:
            return
        backends_to_flush = await self.traffic[endpoint].flush(
            self.endpoint_queues[endpoint], self.backend_queues)
        await self.flush_backend_queues(backends_to_flush)

    def _get_available_backends(self, endpoint):
        backends_in_policy = set(self.traffic[endpoint].keys())
        available_workers = {
            backend
            for backend, queues in self.worker_queues.items()
            if queues.qsize() > 0
        }
        return list(backends_in_policy.intersection(available_workers))

    # Flushes the specified backend queues and assigns work to workers.
    async def flush_backend_queues(self, backends_to_flush):
        assert self.flush_lock.locked()
        for backend in backends_to_flush:
            # No workers available.
            if self.worker_queues[backend].qsize() == 0:
                continue
            # No work to do.
            if len(self.backend_queues[backend]) == 0:
                continue

            buffer_queue = self.backend_queues[backend]
            worker_queue = self.worker_queues[backend]

            logger.debug("Assigning queries for backend {} with buffer "
                         "queue size {} and worker queue size {}".format(
                             backend, len(buffer_queue),
                             worker_queue.qsize()))

            max_batch_size = None
            if backend in self.backend_info:
                max_batch_size = self.backend_info[backend].max_batch_size

            await self._assign_query_to_worker(backend, buffer_queue,
                                               worker_queue, max_batch_size)

    async def _do_query(self, backend, backend_replica_tag, req):
        # If the worker died, this will be a RayActorError. Just return it and
        # let the HTTP proxy handle the retry logic.
        logger.debug("Sending query to replica: " + backend_replica_tag)
        start = time.time()
        worker = self.replicas[backend_replica_tag]
        try:
            result = await worker.handle_request.remote(req)
        except RayTaskError as error:
            self.num_error_backend_request.labels(backend=backend).add()
            result = error
        await self.mark_worker_idle(backend, backend_replica_tag)
        logger.debug("Got result in {:.2f}s".format(time.time() - start))
        return result

    async def _assign_query_to_worker(self,
                                      backend,
                                      buffer_queue,
                                      worker_queue,
                                      max_batch_size=None):
        while len(buffer_queue) and worker_queue.qsize():
            backend_replica_tag = await worker_queue.get()

            if max_batch_size is None:  # No batching
                request = buffer_queue.pop(0)
                future = asyncio.get_event_loop().create_task(
                    self._do_query(backend, backend_replica_tag, request))
                # Chaining satisfies request.async_future with the future's
                # result.
                asyncio.futures._chain_future(future, request.async_future)
            else:
                real_batch_size = min(len(buffer_queue), max_batch_size)
                requests = [
                    buffer_queue.pop(0) for _ in range(real_batch_size)
                ]

                # Split requests by method type.
                requests_group = defaultdict(list)
                for request in requests:
                    requests_group[request.call_method].append(request)

                for group in requests_group.values():
                    future = asyncio.get_event_loop().create_task(
                        self._do_query(backend, backend_replica_tag, group))
                    future.add_done_callback(
                        _make_future_unwrapper(
                            client_futures=[
                                req.async_future for req in group
                            ],
                            host_future=future))
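
# Hedged sketch (an assumption for illustration, not the original helper):
# _make_future_unwrapper is expected to fan the batched result back out to
# each request's future, roughly along these lines:
#
#   def _make_future_unwrapper(client_futures, host_future):
#       def unwrap(_):
#           try:
#               batch_result = host_future.result()
#           except Exception as e:
#               for f in client_futures:
#                   f.set_exception(e)
#               return
#           for f, item in zip(client_futures, batch_result):
#               f.set_result(item)
#       return unwrap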