def _update_traffic_policy(self, traffic_policy: TrafficPolicy):
    """Install a new traffic policy and reconcile the replica sets with it.

    The endpoint policy is replaced by a ``RandomEndpointPolicy`` wrapping
    ``traffic_policy``.  The keys of ``self.backend_replicas`` are then
    compared against ``traffic_policy.backend_tags``: a replica set is
    created for every tag reported as added and the entry is deleted for
    every tag reported as removed.  Finally the "endpoint registered"
    event is set so that anything waiting on it can proceed.

    Args:
        traffic_policy: The traffic policy to apply.
    """
    self.endpoint_policy = RandomEndpointPolicy(traffic_policy)

    # The third element of the delta is not needed here.
    new_tags, stale_tags, _ = compute_iterable_delta(
        self.backend_replicas.keys(),
        traffic_policy.backend_tags,
    )
    for new_tag in new_tags:
        self._get_or_create_replica_set(new_tag)
    for stale_tag in stale_tags:
        del self.backend_replicas[stale_tag]

    # Signal that a policy has been received (only on the first update).
    registered = self._pending_endpoint_registered
    if not registered.is_set():
        registered.set()
async def _update_traffic_policies(self, traffic_policies):
    """Store a policy for every endpoint and wake waiters for each one.

    For each ``(endpoint, traffic_policy)`` pair a ``RandomEndpointPolicy``
    is stored in ``self.endpoint_policies``.  If an event is registered
    for that endpoint in ``self._pending_endpoints`` it is removed from
    the mapping and set.

    Args:
        traffic_policies: Mapping from endpoint to its traffic policy.
    """
    for endpoint, policy in traffic_policies.items():
        self.endpoint_policies[endpoint] = RandomEndpointPolicy(policy)

        # Remove-and-fetch in one step; None means nobody was waiting.
        waiter = self._pending_endpoints.pop(endpoint, None)
        if waiter is not None:
            waiter.set()
async def _update_traffic_policies(self, traffic_policies):
    """Apply a new set of traffic policies and resolve pending futures.

    The incoming mapping is diffed against ``self.endpoint_policies``.
    Endpoints reported as added or updated get a fresh
    ``RandomEndpointPolicy``; endpoints reported as removed are deleted.
    In both cases, a future registered for the endpoint in
    ``self._pending_endpoints`` is popped and resolved with the matching
    ``_PendingEndpointFound`` value.

    Args:
        traffic_policies: Mapping from endpoint to its traffic policy.
    """
    added, removed, updated = compute_dict_delta(self.endpoint_policies,
                                                 traffic_policies)

    def _resolve_pending(endpoint, outcome):
        # Pop the waiting future for this endpoint, if any, and complete it.
        pending = self._pending_endpoints
        if endpoint in pending:
            pending.pop(endpoint).set_result(outcome)

    for endpoint, policy in ChainMap(added, updated).items():
        self.endpoint_policies[endpoint] = RandomEndpointPolicy(policy)
        _resolve_pending(endpoint, _PendingEndpointFound.ADDED)

    for endpoint in removed:
        del self.endpoint_policies[endpoint]
        _resolve_pending(endpoint, _PendingEndpointFound.REMOVED)
class Router:
    """Routes queries for one endpoint to a backend and one of its replicas."""

    def __init__(
            self,
            controller_handle: ActorHandle,
            endpoint_tag: EndpointTag,
            loop: Optional[asyncio.BaseEventLoop] = None,
    ):
        """Router process incoming queries: choose backend, and assign replica.

        Args:
            controller_handle(ActorHandle): The controller handle.
            endpoint_tag(EndpointTag): The endpoint whose traffic policy
                this router subscribes to via long-poll.
            loop: Event loop used for the long-poll callbacks and passed to
                the replica sets. Defaults to ``asyncio.get_event_loop()``.
        """
        # Local import: only needed for the version gate below.
        import sys

        self.controller = controller_handle
        self.endpoint_tag = endpoint_tag

        # Set by the first long-poll update; None until then.
        self.endpoint_policy: Optional[EndpointPolicy] = None
        self.backend_replicas: Dict[BackendTag, ReplicaSet] = dict()

        # The ``loop`` argument of asyncio primitives was deprecated in
        # Python 3.8 and removed in 3.10, where passing it (even as None)
        # raises TypeError. On 3.10+ the event binds to the running loop on
        # first use; on older interpreters keep the explicit binding.
        if sys.version_info >= (3, 10):
            self._pending_endpoint_registered = asyncio.Event()
        else:
            self._pending_endpoint_registered = asyncio.Event(loop=loop)
        self._loop = loop or asyncio.get_event_loop()

        # -- Metrics Registration -- #
        self.num_router_requests = metrics.Counter(
            "serve_num_router_requests",
            description="The number of requests processed by the router.",
            tag_keys=("endpoint", ))

        # Subscribe to traffic-policy updates for this endpoint; the
        # callback is invoked in ``self._loop``.
        self.long_poll_client = LongPollClient(
            self.controller,
            {
                (LongPollNamespace.TRAFFIC_POLICIES, endpoint_tag): self.
                _update_traffic_policy,
            },
            call_in_event_loop=self._loop,
        )

    def _update_traffic_policy(self, traffic_policy: TrafficPolicy):
        """Install a new traffic policy and reconcile the replica sets.

        Creates a replica set for every backend tag reported as added,
        drops the entry for every tag reported as removed, and marks the
        endpoint as registered so pending requests can proceed.

        Args:
            traffic_policy: The traffic policy to apply.
        """
        self.endpoint_policy = RandomEndpointPolicy(traffic_policy)

        backend_tags = traffic_policy.backend_tags
        added, removed, _ = compute_iterable_delta(
            self.backend_replicas.keys(),
            backend_tags,
        )
        for tag in added:
            self._get_or_create_replica_set(tag)
        for tag in removed:
            del self.backend_replicas[tag]

        if not self._pending_endpoint_registered.is_set():
            self._pending_endpoint_registered.set()

    def _get_or_create_replica_set(self, tag):
        """Return the replica set for ``tag``, creating it on first use."""
        if tag not in self.backend_replicas:
            self.backend_replicas[tag] = ReplicaSet(self.controller, tag,
                                                    self._loop)
        return self.backend_replicas[tag]

    async def assign_request(
            self,
            request_meta: RequestMetadata,
            *request_args,
            **request_kwargs,
    ):
        """Assign a query and return an object ref representing the result.

        Args:
            request_meta: Metadata of the request; its ``endpoint`` is used
                for the error message and the metrics tag.
            *request_args: Positional arguments forwarded in the query.
            **request_kwargs: Keyword arguments forwarded in the query.

        Returns:
            The value returned by ``assign_replica`` for the chosen backend.

        Raises:
            RayServeException: If no traffic policy has been received
                within 5 seconds.
        """
        endpoint = request_meta.endpoint
        query = Query(
            args=list(request_args),
            kwargs=request_kwargs,
            metadata=request_meta,
        )

        if not self._pending_endpoint_registered.is_set():
            # This can happen when the router is created but the endpoint
            # information hasn't been retrieved via long-poll yet.
            try:
                await asyncio.wait_for(
                    self._pending_endpoint_registered.wait(),
                    timeout=5,
                )
            except asyncio.TimeoutError:
                raise RayServeException(
                    f"Endpoint {endpoint} doesn't exist after 5s timeout. "
                    "Marking the query failed.")

        # The first backend serves the request; any remaining ones are
        # shadow backends that receive the same query.
        chosen_backend, *shadow_backends = self.endpoint_policy.assign(query)

        result_ref = await self._get_or_create_replica_set(
            chosen_backend).assign_replica(query)
        for backend in shadow_backends:
            # The result of a shadow assignment is intentionally discarded.
            (await self._get_or_create_replica_set(backend)
             .assign_replica(query))

        self.num_router_requests.inc(tags={"endpoint": endpoint})

        return result_ref
async def set_traffic(self, endpoint, traffic_policy):
    """Record a new traffic policy for ``endpoint`` and flush its queue.

    While holding ``self.flush_lock``, a ``RandomEndpointPolicy`` built
    from ``traffic_policy`` is stored in ``self.traffic`` and
    ``flush_endpoint_queue`` is called for the endpoint.

    Args:
        endpoint: The endpoint whose traffic is being set.
        traffic_policy: The traffic policy to wrap and store.
    """
    logger.debug("Setting traffic for endpoint %s to %s", endpoint,
                 traffic_policy)

    async with self.flush_lock:
        policy = RandomEndpointPolicy(traffic_policy)
        self.traffic[endpoint] = policy
        # NOTE(review): the call is not awaited; if flush_endpoint_queue is
        # a coroutine function it will never run -- confirm it is sync.
        self.flush_endpoint_queue(endpoint)