Esempio n. 1
0
async def test_task_runner_custom_method_single(serve_instance):
    """Check that ``call_method`` picks which method of a backend class runs.

    One worker built from ``NonBatcher`` is registered under the consumer
    name and all producer traffic is pointed at it. Requests naming ``a``
    and ``b`` must return those methods' values, and a request naming a
    method the class does not define must raise
    ``ray.exceptions.RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, NonBatcher)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    # Await the routing update so that a failure inside set_traffic surfaces
    # here instead of being dropped together with the un-awaited result.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    def make_request_param(call_method):
        # All requests in this test differ only in the target method name.
        return RequestMetadata(
            PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)

    a_result = await q.enqueue_request.remote(make_request_param("a"))
    assert a_result == "a"

    b_result = await q.enqueue_request.remote(make_request_param("b"))
    assert b_result == "b"

    # Build the metadata outside the raises block so that only the remote
    # call itself can satisfy the expected exception.
    missing_method_param = make_request_param("non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await q.enqueue_request.remote(missing_method_param)
Esempio n. 2
0
async def test_task_runner_custom_method_batch(serve_instance):
    """Check ``call_method`` dispatch for a backend that accepts batches.

    Two requests for ``a`` and two for ``b`` are enqueued before the worker
    is attached to the queue; the gathered results are expected to be
    exactly ``a-0``, ``a-1``, ``b-0`` and ``b-1``.
    """
    router = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            indices = range(serve.context.batch_size)
            return ["a-{}".format(idx) for idx in indices]

        def b(self, _):
            indices = range(serve.context.batch_size)
            return ["b-{}".format(idx) for idx in indices]

    consumer = "runner"
    producer = "producer"

    worker = setup_worker(consumer, Batcher)

    await router.set_traffic.remote(producer, {consumer: 1.0})
    batch_config = BackendConfig({"max_batch_size": 10}, accepts_batches=True)
    await router.set_backend_config.remote(consumer, batch_config)

    param_for_a = RequestMetadata(
        producer, context.TaskContext.Python, call_method="a")
    param_for_b = RequestMetadata(
        producer, context.TaskContext.Python, call_method="b")

    # Same submission order as before: two "a" requests, then two "b".
    pending = [
        router.enqueue_request.remote(param)
        for param in (param_for_a, param_for_a, param_for_b, param_for_b)
    ]

    # The worker is only attached after the requests have been submitted.
    await router.add_new_worker.remote(consumer, "replica1", worker)

    results = await asyncio.gather(*pending)
    assert set(results) == {"a-0", "a-1", "b-0", "b-1"}
Esempio n. 3
0
async def test_ray_serve_mixin(serve_instance):
    """Check that a user class combined with ``RayServeMixin`` serves requests.

    ``MyAdder`` is constructed with an increment of 3, so each request
    carrying ``i`` is expected to come back as ``i + 3``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    CONSUMER_NAME = "runner-cls"
    PRODUCER_NAME = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    @ray.remote
    class CustomActor(MyAdder, RayServeMixin):
        pass

    runner = CustomActor.remote(3)

    runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    runner._ray_serve_fetch.remote()

    # Await the link call so that a failure while wiring producer to
    # consumer is raised here rather than being silently discarded.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query + 3
Esempio n. 4
0
async def test_task_runner_custom_method_batch(serve_instance):
    """Check batched ``call_method`` dispatch, including malformed returns.

    Covers four behaviours of a batch-accepting backend:

    * ``a`` / ``b`` requests enqueued before the worker is attached are
      expected to resolve to ``a-0``, ``a-1``, ``b-0`` and ``b-1``;
    * a method returning a list twice the batch size must raise
      ``RayServeException`` matching "doesn't preserve batch size";
    * a method returning a non-iterable must raise ``RayServeException``
      matching "iterable";
    * a method returning a NumPy ``int32`` array must hand the caller a
      single ``np.int32`` element.
    """
    q = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

        def error_different_size(self, _):
            return [""] * (serve.context.batch_size * 2)

        def error_non_iterable(self, _):
            return 42

        def return_np_array(self, _):
            return np.array([1] * serve.context.batch_size).astype(np.int32)

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, Batcher)

    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
    await q.set_backend_config.remote(
        CONSUMER_NAME,
        BackendConfig({
            "max_batch_size": 10
        }, accepts_batches=True))

    def make_request_param(call_method):
        return RequestMetadata(
            PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)

    a_query_param = make_request_param("a")
    b_query_param = make_request_param("b")

    futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
    futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]

    # The worker is attached only after the four requests were submitted.
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}

    # Request metadata is built outside each raises block so that only the
    # awaited remote call can satisfy the expected exception.
    different_size = make_request_param("error_different_size")
    with pytest.raises(RayServeException, match="doesn't preserve batch size"):
        await q.enqueue_request.remote(different_size)

    non_iterable = make_request_param("error_non_iterable")
    with pytest.raises(RayServeException, match="iterable"):
        await q.enqueue_request.remote(non_iterable)

    np_array = make_request_param("return_np_array")
    result_np_value = await q.enqueue_request.remote(np_array)
    assert isinstance(result_np_value, np.int32)
Esempio n. 5
0
async def test_task_runner_check_context(serve_instance):
    """Check that touching ``flask_request`` in a Python-context call fails.

    The backend reads ``flask_request.args`` although the request is sent
    with ``context.TaskContext.Python``; awaiting the result is expected to
    raise ``ray.exceptions.RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    # Await the routing update so that a failure inside set_traffic is
    # reported here and cannot be mistaken for the error expected below.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
    result_oid = q.enqueue_request.remote(query_param, i=42)

    with pytest.raises(ray.exceptions.RayTaskError):
        await result_oid
Esempio n. 6
0
async def test_runner_actor(serve_instance):
    """Check that a function backend echoes the ``i`` keyword argument.

    A single worker wrapping ``echo`` receives all producer traffic; each
    enqueued request must resolve to the value passed as ``i``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        return i

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "prod"

    worker = setup_worker(CONSUMER_NAME, echo)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    # Await the routing update so that a failure inside set_traffic surfaces
    # here instead of being dropped together with the un-awaited result.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query
Esempio n. 7
0
async def test_round_robin(serve_instance, task_runner_mock_actor):
    """Check that both backends of an evenly split service get requests.

    Twenty requests carrying the positional argument ``1`` are sent to
    ``svc``; afterwards the most recent call recorded by each mock runner
    must have ``1`` as its first request argument.
    """
    router = RoundRobinPolicyQueueActor.remote()

    even_split = {"backend-rr": 0.5, "backend-rr-2": 0.5}
    await router.set_traffic.remote("svc", even_split)
    runner_1 = make_task_runner_mock()
    runner_2 = make_task_runner_mock()

    # NOTE: this is the only difference between the
    # test_split_traffic_random and test_round_robin
    await router.add_new_worker.remote("backend-rr", "replica-1", runner_1)
    await router.add_new_worker.remote("backend-rr-2", "replica-1", runner_2)

    for _ in range(20):
        await router.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Fetch the last recorded call from each mock, one runner at a time.
    recent_calls = []
    for mock_runner in (runner_1, runner_2):
        recent_calls.append(await mock_runner.get_recent_call.remote())

    first_args = [call.request_args[0] for call in recent_calls]
    assert first_args == [1, 1]
Esempio n. 8
0
async def test_runner_actor(serve_instance):
    """Check that a ``TaskRunnerActor`` echoes the ``i`` keyword argument.

    The runner wraps ``echo``, is set up against the queue and linked to
    the producer; each enqueued request must resolve to the value passed
    as ``i``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        return i

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "prod"

    runner = TaskRunnerActor.remote(echo)
    runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    runner._ray_serve_fetch.remote()

    # Await the link call so that a failure while wiring producer to
    # consumer is raised here rather than being silently discarded.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query
Esempio n. 9
0
async def test_task_runner_custom_method_batch(serve_instance):
    """Check ``call_method`` dispatch for a batching ``RayServeMixin`` actor.

    Two requests for ``a`` and two for ``b`` are enqueued before the runner
    is asked to fetch work; the gathered results are expected to be exactly
    ``a-0``, ``a-1``, ``b-0`` and ``b-1``.
    """
    router = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            indices = range(serve.context.batch_size)
            return ["a-{}".format(idx) for idx in indices]

        def b(self, _):
            indices = range(serve.context.batch_size)
            return ["b-{}".format(idx) for idx in indices]

    @ray.remote
    class CustomActor(Batcher, RayServeMixin):
        pass

    consumer = "runner"
    producer = "producer"

    actor = CustomActor.remote()

    # Block until the actor has finished its serve setup.
    setup_ref = actor._ray_serve_setup.remote(consumer, router, actor)
    ray.get(setup_ref)

    await router.link.remote(producer, consumer)
    config_dict = BackendConfig(max_batch_size=10).__dict__
    await router.set_backend_config.remote(consumer, config_dict)

    param_for_a = RequestMetadata(producer,
                                  context.TaskContext.Python,
                                  call_method="a")
    param_for_b = RequestMetadata(producer,
                                  context.TaskContext.Python,
                                  call_method="b")

    # Same submission order as before: two "a" requests, then two "b".
    pending = [
        router.enqueue_request.remote(param)
        for param in (param_for_a, param_for_a, param_for_b, param_for_b)
    ]

    # Work is fetched only after all four requests have been submitted.
    await actor._ray_serve_fetch.remote()

    results = await asyncio.gather(*pending)
    assert set(results) == {"a-0", "a-1", "b-0", "b-1"}
Esempio n. 10
0
async def test_task_runner_custom_method_single(serve_instance):
    """Check ``call_method`` dispatch for a non-batching mixin actor.

    ``NonBatcher`` is combined with ``RayServeMixin`` into a Ray actor and
    linked to the producer. Requests naming ``a`` and ``b`` must return
    those methods' values, and a request naming a method the class does not
    define must raise ``ray.exceptions.RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    @ray.remote
    class CustomActor(NonBatcher, RayServeMixin):
        pass

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = CustomActor.remote()

    ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
    runner._ray_serve_fetch.remote()

    # Await the link call so that a failure while wiring producer to
    # consumer is raised here rather than being silently discarded.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

    def make_request_param(call_method):
        # All requests in this test differ only in the target method name.
        return RequestMetadata(PRODUCER_NAME,
                               context.TaskContext.Python,
                               call_method=call_method)

    a_result = await q.enqueue_request.remote(make_request_param("a"))
    assert a_result == "a"

    b_result = await q.enqueue_request.remote(make_request_param("b"))
    assert b_result == "b"

    # Build the metadata outside the raises block so that only the remote
    # call itself can satisfy the expected exception.
    missing_method_param = make_request_param("non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await q.enqueue_request.remote(missing_method_param)
Esempio n. 11
0
async def test_task_runner_check_context(serve_instance):
    """Check that touching ``flask_request`` in a Python-context call fails.

    The ``TaskRunnerActor`` backend reads ``flask_request.args`` although
    the request is sent with ``context.TaskContext.Python``; awaiting the
    result is expected to raise ``ray.exceptions.RayTaskError``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = TaskRunnerActor.remote(echo)

    runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
    runner._ray_serve_fetch.remote()

    # Await the link call so that a failure while wiring producer to
    # consumer is reported here and cannot be mistaken for the error
    # expected below.
    await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
    result_oid = q.enqueue_request.remote(query_param, i=42)

    with pytest.raises(ray.exceptions.RayTaskError):
        await result_oid
Esempio n. 12
0
async def test_ray_serve_mixin(serve_instance):
    """Check that a class backend built with init args serves requests.

    ``MyAdder`` is set up through ``setup_worker`` with ``init_args=(3, )``,
    so each request carrying ``i`` is expected to come back as ``i + 3``.
    """
    q = RoundRobinPolicyQueueActor.remote()

    CONSUMER_NAME = "runner-cls"
    PRODUCER_NAME = "prod-cls"

    class MyAdder:
        def __init__(self, inc):
            self.increment = inc

        def __call__(self, flask_request, i=None):
            return i + self.increment

    worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    # Await the routing update so that a failure inside set_traffic surfaces
    # here instead of being dropped together with the un-awaited result.
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    for query in [333, 444, 555]:
        query_param = RequestMetadata(PRODUCER_NAME,
                                      context.TaskContext.Python)
        result = await q.enqueue_request.remote(query_param, i=query)
        assert result == query + 3