Example #1
0
async def test_task_runner_perform_batch(serve_instance):
    """Requests are grouped into batches of at most max_batch_size."""
    router = ray.remote(Router).remote()

    def batcher(*args, **kwargs):
        # Reply with the observed batch size, once per item in the batch.
        size = serve.context.batch_size
        return [size] * size

    consumer, producer = "runner", "producer"

    backend_config = BackendConfig(
        {
            "max_batch_size": 2,
            "batch_wait_timeout": 10
        },
        accepts_batches=True)

    worker = setup_worker(consumer, batcher, backend_config=backend_config)
    await router.add_new_worker.remote(consumer, "replica1", worker)
    await router.set_backend_config.remote(consumer, backend_config)
    await router.set_traffic.remote(producer, TrafficPolicy({consumer: 1.0}))

    metadata = RequestMetadata(producer, context.TaskContext.Python)

    # Three requests with max_batch_size=2 should form batches of 2, 2, 1.
    futures = [router.enqueue_request.remote(metadata) for _ in range(3)]
    batch_sizes = await asyncio.gather(*futures)
    assert batch_sizes == [2, 2, 1]
Example #2
0
async def test_alter_backend(serve_instance, task_runner_mock_actor):
    """Traffic can be re-pointed to a new backend after the first is used."""
    router = ray.remote(Router).remote()

    # Run the identical set-traffic / add-worker / enqueue cycle against
    # two different backends; each request must reach the active one.
    for payload, backend in enumerate(["backend-alter", "backend-alter-2"],
                                      start=1):
        await router.set_traffic.remote("svc", TrafficPolicy({backend: 1}))
        await router.add_new_worker.remote(backend, "replica-1",
                                           task_runner_mock_actor)
        await router.enqueue_request.remote(RequestMetadata("svc", None),
                                            payload)
        last_call = await task_runner_mock_actor.get_recent_call.remote()
        assert last_call.request_args[0] == payload
Example #3
0
async def test_task_runner_custom_method_single(serve_instance):
    """Requests can target a specific backend method via call_method."""
    router = ray.remote(Router).remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    consumer, producer = "runner", "producer"

    worker = setup_worker(consumer, NonBatcher)
    await router.add_new_worker.remote(consumer, "replica1", worker)

    router.set_traffic.remote(producer, TrafficPolicy({consumer: 1.0}))

    def metadata_for(method):
        # Request metadata that routes the call to the named method.
        return RequestMetadata(producer,
                               context.TaskContext.Python,
                               call_method=method)

    # Each existing method returns its own name.
    for method in ("a", "b"):
        result = await router.enqueue_request.remote(metadata_for(method))
        assert result == method

    # Targeting a missing method surfaces as a RayTaskError.
    with pytest.raises(ray.exceptions.RayTaskError):
        await router.enqueue_request.remote(metadata_for("non_exist"))
Example #4
0
async def test_router_use_max_concurrency(serve_instance):
    """With max_concurrent_queries=1, the router keeps exactly one request
    in flight per replica and buffers the rest, releasing the next query
    only after the current one completes.
    """
    signal = SignalActor.remote()

    @ray.remote
    class MockWorker:
        async def handle_request(self, request):
            # Block until the test fires the signal so that the router's
            # internal queues can be inspected mid-flight.
            await signal.wait.remote()
            return "DONE"

        def ready(self):
            pass

    class VisibleRouter(Router):
        # Expose the router's internal counters/queues for assertions.
        def get_queues(self):
            return self.queries_counter, self.backend_queues

    worker = MockWorker.remote()
    q = ray.remote(VisibleRouter).remote()
    await q.setup.remote()
    BACKEND_NAME = "max-concurrent-test"
    # Only one query may be dispatched to the replica at a time.
    config = BackendConfig({"max_concurrent_queries": 1})
    await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
    await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
    await q.set_backend_config.remote(BACKEND_NAME, config)

    # We send over two queries
    first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Neither queries should be available
    with pytest.raises(ray.exceptions.RayTimeoutError):
        ray.get([first_query, second_query], timeout=0.2)

    # Let's retrieve the router internal state
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should be just one inflight request
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # The second query is buffered
    assert len(backend_queues["max-concurrent-test"]) == 1

    # Let's unblock the first query
    await signal.send.remote(clear=True)
    assert await first_query == "DONE"

    # The internal state of router should have changed.
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should still be one inflight request
    # (the buffered second query was dispatched when the first finished).
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # But there shouldn't be any queries in the queue
    assert len(backend_queues["max-concurrent-test"]) == 0

    # Unblocking the second query
    await signal.send.remote(clear=True)
    assert await second_query == "DONE"

    # Checking the internal state of the router one more time
    queries_counter, backend_queues = await q.get_queues.remote()
    assert queries_counter["max-concurrent-test:replica-tag"] == 0
    assert len(backend_queues["max-concurrent-test"]) == 0
Example #5
0
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
    """A single producer/consumer pair routes one request end to end."""
    router = ray.remote(Router).remote()
    router.set_traffic.remote("svc",
                              TrafficPolicy({"backend-single-prod": 1.0}))
    router.add_new_worker.remote("backend-single-prod", "replica-1",
                                 task_runner_mock_actor)

    # The request result must make it back to the caller.
    result = await router.enqueue_request.remote(
        RequestMetadata("svc", None), 1)
    assert result == "DONE"

    # The worker must have received exactly the payload we sent.
    call = await task_runner_mock_actor.get_recent_call.remote()
    assert call.request_args[0] == 1
    assert call.request_kwargs == {}
Example #6
0
async def test_task_runner_check_context(serve_instance):
    """Touching flask_request from a non-web request fails in the worker."""
    router = ray.remote(Router).remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    consumer, producer = "runner", "producer"

    worker = setup_worker(consumer, echo)
    await router.add_new_worker.remote(consumer, "replica1", worker)

    router.set_traffic.remote(producer, TrafficPolicy({consumer: 1.0}))
    metadata = RequestMetadata(producer, context.TaskContext.Python)
    pending = router.enqueue_request.remote(metadata, i=42)

    # The worker-side failure surfaces to the caller as a RayTaskError.
    with pytest.raises(ray.exceptions.RayTaskError):
        await pending
Example #7
0
async def test_runner_actor(serve_instance):
    """A function backend echoes back the keyword argument it was sent."""
    router = ray.remote(Router).remote()

    def echo(flask_request, i=None):
        return i

    consumer, producer = "runner", "prod"

    worker = setup_worker(consumer, echo)
    await router.add_new_worker.remote(consumer, "replica1", worker)

    router.set_traffic.remote(producer, TrafficPolicy({consumer: 1.0}))

    # Several distinct payloads should each round-trip unchanged.
    for value in (333, 444, 555):
        metadata = RequestMetadata(producer, context.TaskContext.Python)
        result = await router.enqueue_request.remote(metadata, i=value)
        assert result == value
Example #8
0
async def test_task_runner_perform_async(serve_instance):
    """With max_concurrent_queries=10, ten blocking requests run at once.

    Each request waits on a shared barrier that releases only after all
    ten have arrived, so all ten completing within the timeout proves the
    router dispatched them concurrently rather than serially.
    """
    q = ray.remote(Router).remote()
    await q.setup.remote()

    @ray.remote
    class Barrier:
        def __init__(self, release_on):
            # Number of waiters required before everyone is released.
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                self.event.set()
            else:
                await self.event.wait()

    barrier = Barrier.remote(release_on=10)

    async def wait_and_go(*args, **kwargs):
        await barrier.wait.remote()
        return "done!"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)

    worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    await q.set_backend_config.remote(CONSUMER_NAME, config)
    q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)

    done, not_done = await asyncio.wait(
        [q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
    assert len(done) == 10
    for item in done:
        # Bug fix: the original `await item == "done!"` computed the
        # comparison and discarded it — the result was never checked.
        assert await item == "done!"
Example #9
0
async def test_slo(serve_instance, task_runner_mock_actor):
    """Buffered requests are drained in SLO order, tightest deadline first."""
    router = ray.remote(Router).remote()
    await router.set_traffic.remote("svc",
                                    TrafficPolicy({"backend-slo": 1.0}))

    # Enqueue before any worker exists so every request buffers; request i
    # carries a deadline of (1000 - 100*i) ms, so higher i is more urgent.
    pending = [
        router.enqueue_request.remote(
            RequestMetadata("svc", None, relative_slo_ms=1000 - 100 * i), i)
        for i in range(10)
    ]

    await router.add_new_worker.remote("backend-slo", "replica-1",
                                       task_runner_mock_actor)

    await asyncio.gather(*pending)

    # The last ten calls should arrive in reverse enqueue order: 9, 8, ... 0.
    all_calls = await task_runner_mock_actor.get_all_calls.remote()
    for expected, call in zip(range(9, -1, -1), all_calls[-10:]):
        assert call.request_args[0] == expected
Example #10
0
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
    """A 50/50 traffic split sends at least one request to each backend."""
    router = ray.remote(Router).remote()

    await router.set_traffic.remote(
        "svc", TrafficPolicy({
            "backend-split": 0.5,
            "backend-split-2": 0.5
        }))
    runners = [mock_task_runner() for _ in range(2)]
    await router.add_new_worker.remote("backend-split", "replica-1",
                                       runners[0])
    await router.add_new_worker.remote("backend-split-2", "replica-1",
                                       runners[1])

    # With a 50% split, the probability that all 20 requests land on a
    # single backend is about 0.5^20 ~ 1e-6, so both should see traffic.
    for _ in range(20):
        await router.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Each runner must have received at least one request with our payload.
    for runner in runners:
        call = await runner.get_recent_call.remote()
        assert call.request_args[0] == 1
Example #11
0
async def test_ray_serve_mixin(serve_instance):
    """A class-based backend constructed with init args serves requests."""
    router = ray.remote(Router).remote()

    consumer, producer = "runner-cls", "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    # The worker is constructed with increment=3.
    worker = setup_worker(consumer, MyAdder, init_args=(3, ))
    await router.add_new_worker.remote(consumer, "replica1", worker)

    router.set_traffic.remote(producer, TrafficPolicy({consumer: 1.0}))

    # Each payload comes back incremented by the init arg.
    for value in (333, 444, 555):
        metadata = RequestMetadata(producer, context.TaskContext.Python)
        result = await router.enqueue_request.remote(metadata, i=value)
        assert result == value + 3
Example #12
0
async def test_shard_key(serve_instance, task_runner_mock_actor):
    """The same shard key always routes to the same backend."""
    router = ray.remote(Router).remote()
    await router.setup.remote()

    num_backends = 5
    runners = [mock_task_runner() for _ in range(num_backends)]
    traffic = {}
    for idx, runner in enumerate(runners):
        name = "backend-split-" + str(idx)
        traffic[name] = 1.0 / num_backends
        await router.add_new_worker.remote(name, "replica-1", runner)
    await router.set_traffic.remote("svc", TrafficPolicy(traffic))

    shard_keys = [get_random_letters() for _ in range(100)]

    async def send_all():
        # One request per shard key; the key doubles as the payload so the
        # receiving runner can be identified afterwards.
        for key in shard_keys:
            await router.enqueue_request.remote(
                RequestMetadata("svc", None, shard_key=key), key)

    # First round: record which backend each shard key landed on.
    await send_all()
    seen_keys = defaultdict(set)
    for idx, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            seen_keys[idx].add(call.request_args[0])
        await runner.clear_calls.remote()

    # Second round with the same keys: every request must land on the
    # same backend as before.
    await send_all()
    for idx, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            assert call.request_args[0] in seen_keys[idx]
Example #13
0
async def test_task_runner_custom_method_batch(serve_instance):
    """Batched backends can dispatch to named methods, and malformed batch
    return values (wrong size, non-iterable) are reported as errors.
    """
    q = ray.remote(Router).remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            # One "a-i" element per item in the batch.
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

        def error_different_size(self, _):
            # Deliberately returns twice as many elements as the batch size.
            return [""] * (serve.context.batch_size * 2)

        def error_non_iterable(self, _):
            # Deliberately returns a scalar instead of an iterable.
            return 42

        def return_np_array(self, _):
            # Numpy arrays are valid batch results; elements keep dtype.
            return np.array([1] * serve.context.batch_size).astype(np.int32)

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    backend_config = BackendConfig(
        {
            "max_batch_size": 4,
            "batch_wait_timeout": 2
        }, accepts_batches=True)
    worker = setup_worker(CONSUMER_NAME,
                          Batcher,
                          backend_config=backend_config)

    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))
    await q.set_backend_config.remote(CONSUMER_NAME, backend_config)

    def make_request_param(call_method):
        # Request metadata targeting the given Batcher method.
        return RequestMetadata(PRODUCER_NAME,
                               context.TaskContext.Python,
                               call_method=call_method)

    a_query_param = make_request_param("a")
    b_query_param = make_request_param("b")

    # Enqueue all four requests BEFORE the worker is added so they buffer
    # in the router and are delivered together once the worker appears.
    futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
    futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]

    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}

    with pytest.raises(RayServeException, match="doesn't preserve batch size"):
        different_size = make_request_param("error_different_size")
        await q.enqueue_request.remote(different_size)

    with pytest.raises(RayServeException, match="iterable"):
        non_iterable = make_request_param("error_non_iterable")
        await q.enqueue_request.remote(non_iterable)

    np_array = make_request_param("return_np_array")
    result_np_value = await q.enqueue_request.remote(np_array)
    assert isinstance(result_np_value, np.int32)