async def test_task_runner_custom_method_single(serve_instance):
    """Dispatch requests to named methods of a non-batching backend class.

    Requests whose metadata carries ``call_method="a"`` / ``"b"`` must be
    answered by the matching method; an unknown method name must surface to
    the caller as ``ray.exceptions.RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, NonBatcher)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # Await the traffic update so it is guaranteed to be applied before the
    # first request is enqueued; un-awaited calls are not guaranteed to run in
    # submission order if the queue is an async actor.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
    a_result = await q.enqueue_request.remote(query_param)
    assert a_result == "a"

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="b")
    b_result = await q.enqueue_request.remote(query_param)
    assert b_result == "b"

    # A method the backend class does not define must fail the request.
    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await q.enqueue_request.remote(query_param)
async def test_task_runner_custom_method_batch(serve_instance):
    """Route batched requests to named methods of an ``accept_batch`` class.

    Two ``a`` and two ``b`` requests are queued before any replica exists,
    then a replica is attached. Each method returns one label per batch
    element, so the combined results must be ``a-0, a-1, b-0, b-1``.
    """
    queue = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

    consumer = "runner"
    producer = "producer"

    worker = setup_worker(consumer, Batcher)
    await queue.set_traffic.remote(producer, {consumer: 1.0})
    batch_config = BackendConfig({"max_batch_size": 10}, accepts_batches=True)
    await queue.set_backend_config.remote(consumer, batch_config)

    # Queue two requests per method; the same metadata object is reused for
    # both requests of a method.
    pending = []
    for method_name in ("a", "b"):
        metadata = RequestMetadata(
            producer, context.TaskContext.Python, call_method=method_name)
        pending.extend(
            queue.enqueue_request.remote(metadata) for _ in range(2))

    # Attach the replica only after everything is queued.
    await queue.add_new_worker.remote(consumer, "replica1", worker)
    results = await asyncio.gather(*pending)
    assert set(results) == {"a-0", "a-1", "b-0", "b-1"}
async def test_ray_serve_mixin(serve_instance):
    """Serve a user class mixed with ``RayServeMixin`` as a custom actor.

    ``MyAdder(3)`` adds its increment to the ``i`` keyword argument, so each
    query must come back as ``query + 3``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    CONSUMER_NAME = "runner-cls"
    PRODUCER_NAME = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    @ray.remote
    class CustomActor(MyAdder, RayServeMixin):
        pass

    runner = CustomActor.remote(3)

    # Make sure the runner is registered with the queue before it starts
    # fetching work.
    await runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    # Deliberately not awaited: elsewhere this call is only awaited after
    # requests are queued, so it presumably blocks until work arrives.
    runner._ray_serve_fetch.remote()
    # Await the link so routing is in place before the first request.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query + 3
async def test_task_runner_custom_method_batch(serve_instance):
    """Exercise method dispatch and output validation for a batching backend.

    Two ``a`` and two ``b`` requests are queued before a replica is attached;
    their combined results must be ``a-0, a-1, b-0, b-1``. The error methods
    check that a result whose length differs from the batch size, or a
    non-iterable result, is reported as ``RayServeException``. The numpy
    method checks that an element of an ``np.int32`` array comes back as an
    ``np.int32`` value.
    """
    queue = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

        def error_different_size(self, _):
            # Twice as many outputs as inputs.
            return [""] * (serve.context.batch_size * 2)

        def error_non_iterable(self, _):
            # A scalar instead of one result per request.
            return 42

        def return_np_array(self, _):
            return np.array([1] * serve.context.batch_size).astype(np.int32)

    consumer = "runner"
    producer = "producer"

    worker = setup_worker(consumer, Batcher)
    await queue.set_traffic.remote(producer, {consumer: 1.0})
    batch_config = BackendConfig({"max_batch_size": 10}, accepts_batches=True)
    await queue.set_backend_config.remote(consumer, batch_config)

    def request_for(method_name):
        # Metadata for a Python-context request to ``method_name``.
        return RequestMetadata(
            producer, context.TaskContext.Python, call_method=method_name)

    pending = []
    for method_name in ("a", "b"):
        metadata = request_for(method_name)
        pending.extend(
            queue.enqueue_request.remote(metadata) for _ in range(2))

    # Attach the replica only after all four requests are queued.
    await queue.add_new_worker.remote(consumer, "replica1", worker)
    results = await asyncio.gather(*pending)
    assert set(results) == {"a-0", "a-1", "b-0", "b-1"}

    mismatched = request_for("error_different_size")
    with pytest.raises(RayServeException,
                       match="doesn't preserve batch size"):
        await queue.enqueue_request.remote(mismatched)

    scalar = request_for("error_non_iterable")
    with pytest.raises(RayServeException, match="iterable"):
        await queue.enqueue_request.remote(scalar)

    np_value = await queue.enqueue_request.remote(
        request_for("return_np_array"))
    assert isinstance(np_value, np.int32)
async def test_task_runner_check_context(serve_instance):
    """A Python-context request must not be able to read ``flask_request``.

    The backend reads ``flask_request.args``. Without a web context this is
    expected to throw, which the caller sees as ``RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # Await so the traffic rule is guaranteed to be applied before enqueueing.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
    result_oid = q.enqueue_request.remote(query_param, i=42)
    with pytest.raises(ray.exceptions.RayTaskError):
        await result_oid
async def test_runner_actor(serve_instance):
    """Round-trip values through a function backend that echoes ``i``."""
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        return i

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "prod"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # Await so the traffic rule is guaranteed to be applied before enqueueing.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query
async def test_round_robin(serve_instance, task_runner_mock_actor):
    """Send 20 requests over a 50/50 split between two mock backends.

    Afterwards, the most recent call recorded by each mock runner must carry
    ``1`` as its first positional argument. This presumably shows that both
    backends were reached.
    """
    queue = RoundRobinPolicyQueueActor.remote()
    await queue.set_traffic.remote(
        "svc", {"backend-rr": 0.5, "backend-rr-2": 0.5})

    first_runner = make_task_runner_mock()
    second_runner = make_task_runner_mock()

    # NOTE: this test mirrors test_split_traffic_random (defined elsewhere);
    # per the original comment the two differ in only one place, presumably
    # the queue policy. Keep them in sync.
    await queue.add_new_worker.remote("backend-rr", "replica-1", first_runner)
    await queue.add_new_worker.remote("backend-rr-2", "replica-1",
                                      second_runner)

    for _ in range(20):
        await queue.enqueue_request.remote(RequestMetadata("svc", None), 1)

    recent_calls = []
    for runner in (first_runner, second_runner):
        recent_calls.append(await runner.get_recent_call.remote())
    first_args = [call.request_args[0] for call in recent_calls]
    assert first_args == [1, 1]
async def test_runner_actor(serve_instance):
    """Round-trip values through a ``TaskRunnerActor`` wired up by hand.

    The wrapped function echoes its ``i`` keyword argument.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        return i

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "prod"

    runner = TaskRunnerActor.remote(echo)

    # Register the runner with the queue before it starts fetching.
    await runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    # Deliberately not awaited: presumably blocks until work is available.
    runner._ray_serve_fetch.remote()
    # Await so routing is in place before the first request is enqueued.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query
async def test_task_runner_custom_method_batch(serve_instance):
    """Batch requests to named methods of a ``RayServeMixin`` custom actor.

    Two ``a`` and two ``b`` requests are queued before the runner fetches;
    the combined results must be ``a-0, a-1, b-0, b-1``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

    @ray.remote
    class CustomActor(Batcher, RayServeMixin):
        pass

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = CustomActor.remote()

    # Use ``await`` rather than ``ray.get``: a blocking ``ray.get`` inside a
    # coroutine stalls the event loop.
    await runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)

    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
    await q.set_backend_config.remote(
        CONSUMER_NAME,
        BackendConfig(max_batch_size=10).__dict__)

    a_query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
    b_query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="b")

    futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
    futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]

    # Fetch only after all four requests are queued.
    await runner._ray_serve_fetch.remote()

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
async def test_task_runner_custom_method_single(serve_instance):
    """Dispatch requests to named methods of a non-batching custom actor.

    Requests carrying ``call_method="a"`` / ``"b"`` must reach the matching
    method; an unknown method name must surface as ``RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    @ray.remote
    class CustomActor(NonBatcher, RayServeMixin):
        pass

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = CustomActor.remote()

    # ``await`` instead of a blocking ``ray.get`` inside the coroutine.
    await runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    # Deliberately not awaited: presumably blocks until work is available.
    runner._ray_serve_fetch.remote()
    # Await so routing is in place before the first request is enqueued.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
    a_result = await q.enqueue_request.remote(query_param)
    assert a_result == "a"

    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="b")
    b_result = await q.enqueue_request.remote(query_param)
    assert b_result == "b"

    # A method the actor does not define must fail the request.
    query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await q.enqueue_request.remote(query_param)
async def test_task_runner_check_context(serve_instance):
    """A Python-context request must not be able to read ``flask_request``.

    The function wrapped by ``TaskRunnerActor`` reads ``flask_request.args``.
    Without a web context this is expected to throw, which the caller sees as
    ``RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = TaskRunnerActor.remote(echo)

    # Register the runner with the queue before it starts fetching.
    await runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    # Deliberately not awaited: presumably blocks until work is available.
    runner._ray_serve_fetch.remote()
    # Await so routing is in place before the request is enqueued.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
    result_oid = q.enqueue_request.remote(query_param, i=42)
    with pytest.raises(ray.exceptions.RayTaskError):
        await result_oid
async def test_ray_serve_mixin(serve_instance):
    """Serve a stateful class built via ``setup_worker`` with ``init_args``.

    ``MyAdder(3)`` adds its increment to the ``i`` keyword argument, so each
    query must come back as ``query + 3``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    CONSUMER_NAME = "runner-cls"
    PRODUCER_NAME = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    # Await so the traffic rule is guaranteed to be applied before enqueueing.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query + 3