async def test_task_runner_perform_batch(serve_instance):
    """Queries should be grouped into batches of at most max_batch_size."""
    router = ray.remote(Router).remote()

    def batcher(*args, **kwargs):
        # Report the batch size the worker observed, once per query in it.
        return [serve.context.batch_size] * serve.context.batch_size

    backend_name = "runner"
    endpoint_name = "producer"

    batch_config = BackendConfig(
        {
            "max_batch_size": 2,
            "batch_wait_timeout": 10
        }, accepts_batches=True)

    replica = setup_worker(backend_name, batcher, backend_config=batch_config)
    await router.add_new_worker.remote(backend_name, "replica1", replica)
    await router.set_backend_config.remote(backend_name, batch_config)
    await router.set_traffic.remote(endpoint_name,
                                    TrafficPolicy({backend_name: 1.0}))

    metadata = RequestMetadata(endpoint_name, context.TaskContext.Python)
    refs = [router.enqueue_request.remote(metadata) for _ in range(3)]
    # With max_batch_size=2, three queries arrive as a batch of 2 then 1.
    assert await asyncio.gather(*refs) == [2, 2, 1]
async def test_alter_backend(serve_instance, task_runner_mock_actor):
    """Traffic can be repointed from one backend to another at runtime."""
    router = ray.remote(Router).remote()

    # Route all "svc" traffic to the first backend and send a request.
    await router.set_traffic.remote("svc",
                                    TrafficPolicy({"backend-alter": 1}))
    await router.add_new_worker.remote("backend-alter", "replica-1",
                                       task_runner_mock_actor)
    await router.enqueue_request.remote(RequestMetadata("svc", None), 1)
    last_call = await task_runner_mock_actor.get_recent_call.remote()
    assert last_call.request_args[0] == 1

    # Swap the policy to a second backend; the next request must follow it.
    await router.set_traffic.remote("svc",
                                    TrafficPolicy({"backend-alter-2": 1}))
    await router.add_new_worker.remote("backend-alter-2", "replica-1",
                                       task_runner_mock_actor)
    await router.enqueue_request.remote(RequestMetadata("svc", None), 2)
    last_call = await task_runner_mock_actor.get_recent_call.remote()
    assert last_call.request_args[0] == 2
async def test_task_runner_custom_method_single(serve_instance):
    """enqueue_request dispatches call_method to the named worker method."""
    q = ray.remote(Router).remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, NonBatcher)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # FIX: await the traffic update so routing state is committed before
    # enqueueing (this call was previously fired without await, unlike the
    # equivalent calls in the other tests in this file).
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
    a_result = await q.enqueue_request.remote(query_param)
    assert a_result == "a"

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="b")
    b_result = await q.enqueue_request.remote(query_param)
    assert b_result == "b"

    # Requesting a method that doesn't exist surfaces as a RayTaskError.
    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await q.enqueue_request.remote(query_param)
async def test_router_use_max_concurrency(serve_instance):
    """With max_concurrent_queries=1, the router keeps at most one query
    in flight per replica and buffers the rest in its backend queue.

    Uses a SignalActor to hold queries open so the router's internal
    counters can be inspected at each stage.
    """
    signal = SignalActor.remote()

    @ray.remote
    class MockWorker:
        # Blocks on the signal actor until the test releases it.
        async def handle_request(self, request):
            await signal.wait.remote()
            return "DONE"

        def ready(self):
            pass

    # Subclass only to expose the router's private bookkeeping for asserts.
    class VisibleRouter(Router):
        def get_queues(self):
            return self.queries_counter, self.backend_queues

    worker = MockWorker.remote()
    q = ray.remote(VisibleRouter).remote()
    await q.setup.remote()
    BACKEND_NAME = "max-concurrent-test"
    config = BackendConfig({"max_concurrent_queries": 1})
    await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
    await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
    await q.set_backend_config.remote(BACKEND_NAME, config)

    # We send over two queries
    first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Neither queries should be available (both are blocked on the signal).
    with pytest.raises(ray.exceptions.RayTimeoutError):
        ray.get([first_query, second_query], timeout=0.2)

    # Let's retrieve the router internal state
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should be just one inflight request
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # The second query is buffered
    assert len(backend_queues["max-concurrent-test"]) == 1

    # Let's unblock the first query
    await signal.send.remote(clear=True)
    assert await first_query == "DONE"

    # The internal state of router should have changed.
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should still be one inflight request (the second was promoted
    # from the buffer as soon as the first completed).
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # But there shouldn't be any queries in the queue
    assert len(backend_queues["max-concurrent-test"]) == 0

    # Unblocking the second query
    await signal.send.remote(clear=True)
    assert await second_query == "DONE"

    # Checking the internal state of the router one more time: fully idle.
    queries_counter, backend_queues = await q.get_queues.remote()
    assert queries_counter["max-concurrent-test:replica-tag"] == 0
    assert len(backend_queues["max-concurrent-test"]) == 0
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
    """A single request flows through the router to its one backend."""
    q = ray.remote(Router).remote()
    # FIX: await the setup calls so routing state is committed before
    # enqueueing (both remote calls were previously fired without await,
    # unlike the equivalent calls in the other tests in this file).
    await q.set_traffic.remote("svc",
                               TrafficPolicy({"backend-single-prod": 1.0}))
    await q.add_new_worker.remote("backend-single-prod", "replica-1",
                                  task_runner_mock_actor)

    # Make sure we get the request result back
    result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    assert result == "DONE"

    # Make sure it's the right request
    got_work = await task_runner_mock_actor.get_recent_call.remote()
    assert got_work.request_args[0] == 1
    assert got_work.request_kwargs == {}
async def test_task_runner_check_context(serve_instance):
    """Touching the flask_request outside a web context must fail."""
    q = ray.remote(Router).remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # FIX: await the traffic update so routing state is committed before
    # enqueueing (previously fired without await, unlike sibling tests).
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
    result_oid = q.enqueue_request.remote(query_param, i=42)
    # The worker's flask_request access surfaces as a RayTaskError.
    with pytest.raises(ray.exceptions.RayTaskError):
        await result_oid
async def test_runner_actor(serve_instance):
    """Keyword arguments round-trip through the router to the worker."""
    q = ray.remote(Router).remote()

    def echo(flask_request, i=None):
        return i

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "prod"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # FIX: await the traffic update so routing state is committed before
    # enqueueing (previously fired without await, unlike sibling tests).
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query
async def test_task_runner_perform_async(serve_instance):
    """With max_concurrent_queries=10, ten blocked queries run in parallel.

    Each query blocks on a barrier that only releases once all ten are
    waiting, so the test can only pass if the router dispatched them
    concurrently rather than serially.
    """
    q = ray.remote(Router).remote()
    await q.setup.remote()

    @ray.remote
    class Barrier:
        def __init__(self, release_on):
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                self.event.set()
            else:
                await self.event.wait()

    barrier = Barrier.remote(release_on=10)

    async def wait_and_go(*args, **kwargs):
        await barrier.wait.remote()
        return "done!"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)

    worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    await q.set_backend_config.remote(CONSUMER_NAME, config)
    # FIX: await the traffic update so routing state is committed before
    # enqueueing (previously fired without await, unlike sibling tests).
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)

    done, not_done = await asyncio.wait(
        [q.enqueue_request.remote(query_param) for _ in range(10)],
        timeout=10)
    assert len(done) == 10
    for item in done:
        # BUG FIX: the original evaluated `await item == "done!"` and
        # discarded the boolean, so the results were never checked.
        assert await item == "done!"
async def test_slo(serve_instance, task_runner_mock_actor):
    """Buffered queries are served in order of their SLO deadlines."""
    router = ray.remote(Router).remote()
    await router.set_traffic.remote("svc",
                                    TrafficPolicy({"backend-slo": 1.0}))

    # Enqueue ten requests before any worker exists; request i carries an
    # SLO of (1000 - 100*i) ms, so later requests have tighter deadlines.
    pending = [
        router.enqueue_request.remote(
            RequestMetadata("svc", None, relative_slo_ms=1000 - 100 * i), i)
        for i in range(10)
    ]

    # Adding a worker drains the buffered queries in SLO-priority order.
    await router.add_new_worker.remote("backend-slo", "replica-1",
                                       task_runner_mock_actor)
    await asyncio.gather(*pending)

    all_calls = await task_runner_mock_actor.get_all_calls.remote()
    # The tightest deadline (request 9) must have been served first.
    served = [call.request_args[0] for call in all_calls[-10:]]
    assert served == list(reversed(range(10)))
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
    """A 50/50 traffic split should reach both backends within 20 requests."""
    router = ray.remote(Router).remote()
    await router.set_traffic.remote(
        "svc",
        TrafficPolicy({
            "backend-split": 0.5,
            "backend-split-2": 0.5
        }))
    runner_1, runner_2 = mock_task_runner(), mock_task_runner()
    await router.add_new_worker.remote("backend-split", "replica-1", runner_1)
    await router.add_new_worker.remote("backend-split-2", "replica-1",
                                       runner_2)

    # With a 50% split, the probability that all 20 requests land on a
    # single backend is 0.5^20 ~= 1e-6, so both runners should see work.
    for _ in range(20):
        await router.enqueue_request.remote(RequestMetadata("svc", None), 1)

    for runner in (runner_1, runner_2):
        recent = await runner.get_recent_call.remote()
        assert recent.request_args[0] == 1
async def test_ray_serve_mixin(serve_instance):
    """Class-based backends are constructed with init_args and invoked."""
    q = ray.remote(Router).remote()

    CONSUMER_NAME = "runner-cls"
    PRODUCER_NAME = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # FIX: await the traffic update so routing state is committed before
    # enqueueing (previously fired without await, unlike sibling tests).
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query + 3
async def test_shard_key(serve_instance, task_runner_mock_actor):
    """Requests carrying the same shard key always map to the same backend."""
    router = ray.remote(Router).remote()
    await router.setup.remote()

    num_backends = 5
    runners = [mock_task_runner() for _ in range(num_backends)]
    traffic_dict = {}
    for idx, runner in enumerate(runners):
        name = "backend-split-" + str(idx)
        traffic_dict[name] = 1.0 / num_backends
        await router.add_new_worker.remote(name, "replica-1", runner)
    await router.set_traffic.remote("svc", TrafficPolicy(traffic_dict))

    # Generate random shard keys and send one request for each.
    shard_keys = [get_random_letters() for _ in range(100)]
    for key in shard_keys:
        await router.enqueue_request.remote(
            RequestMetadata("svc", None, shard_key=key), key)

    # Record which shard keys each backend saw, then reset the mocks.
    seen_by_runner = defaultdict(set)
    for idx, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            seen_by_runner[idx].add(call.request_args[0])
        await runner.clear_calls.remote()

    # Send queries with the same shard keys a second time.
    for key in shard_keys:
        await router.enqueue_request.remote(
            RequestMetadata("svc", None, shard_key=key), key)

    # Every key must land on the same backend it was assigned before.
    for idx, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            assert call.request_args[0] in seen_by_runner[idx]
async def test_task_runner_custom_method_batch(serve_instance):
    """Batched backends dispatch call_method per batch and validate that
    the worker returns one result per query in the batch."""
    q = ray.remote(Router).remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

        def error_different_size(self, _):
            # Invalid: returns twice as many results as queries in the batch.
            return [""] * (serve.context.batch_size * 2)

        def error_non_iterable(self, _):
            # Invalid: returns a scalar instead of an iterable of results.
            return 42

        def return_np_array(self, _):
            # Valid: numpy arrays are iterable with one entry per query.
            return np.array([1] * serve.context.batch_size).astype(np.int32)

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"
    backend_config = BackendConfig(
        {
            "max_batch_size": 4,
            "batch_wait_timeout": 2
        }, accepts_batches=True)

    worker = setup_worker(
        CONSUMER_NAME, Batcher, backend_config=backend_config)
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))
    await q.set_backend_config.remote(CONSUMER_NAME, backend_config)

    def make_request_param(call_method):
        return RequestMetadata(
            PRODUCER_NAME, context.TaskContext.Python,
            call_method=call_method)

    a_query_param = make_request_param("a")
    b_query_param = make_request_param("b")

    # Enqueue before the worker exists so the queries are buffered and then
    # drained as batches once the replica is added below.
    futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
    futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]

    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}

    with pytest.raises(RayServeException, match="doesn't preserve batch size"):
        different_size = make_request_param("error_different_size")
        await q.enqueue_request.remote(different_size)

    with pytest.raises(RayServeException, match="iterable"):
        non_iterable = make_request_param("error_non_iterable")
        await q.enqueue_request.remote(non_iterable)

    np_array = make_request_param("return_np_array")
    result_np_value = await q.enqueue_request.remote(np_array)
    assert isinstance(result_np_value, np.int32)