Example #1
def test_backend_config_update():
    b = BackendConfig({"num_replicas": 1, "max_batch_size": 1})

    # Test updating a key works.
    b.update({"num_replicas": 2})
    assert b.num_replicas == 2
    # Check that not specifying a key doesn't update it.
    assert b.max_batch_size == 1

    # Check that passing an invalid key fails.
    with pytest.raises(ValueError):
        b.update({"unknown": 1})

    # Check that input is validated.
    with pytest.raises(TypeError):
        b.update({"num_replicas": "hello"})
    with pytest.raises(ValueError):
        b.update({"num_replicas": -1})

    # Test batch validation.
    b = BackendConfig({}, accepts_batches=False)
    b.update({"max_batch_size": 1})
    with pytest.raises(ValueError):
        b.update({"max_batch_size": 2})

    b = BackendConfig({}, accepts_batches=True)
    b.update({"max_batch_size": 2})
Example #2
def test_list_backends(serve_instance):
    serve.init()

    @serve.accept_batch
    def f():
        pass

    serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10))
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend" in backends
    assert backends["backend"]["max_batch_size"] == 10

    serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
    backends = serve.list_backends()
    assert len(backends) == 2
    assert backends["backend2"]["num_replicas"] == 10

    serve.delete_backend("backend")
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend2" in backends

    serve.delete_backend("backend2")
    assert len(serve.list_backends()) == 0
Example #3
def test_backend_user_config(serve_instance):
    client = serve_instance

    class Counter:
        def __init__(self):
            self.count = 10

        def __call__(self, starlette_request):
            return self.count, os.getpid()

        def reconfigure(self, config):
            self.count = config["count"]

    config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2})
    client.create_backend("counter", Counter, config=config)
    client.create_endpoint("counter", backend="counter")
    handle = client.get_handle("counter")

    def check(val, num_replicas):
        pids_seen = set()
        for i in range(100):
            result = ray.get(handle.remote())
            if str(result[0]) != val:
                return False
            pids_seen.add(result[1])
        return len(pids_seen) == num_replicas

    wait_for_condition(lambda: check("123", 2))

    client.update_backend_config("counter", BackendConfig(num_replicas=3))
    wait_for_condition(lambda: check("123", 3))

    config = BackendConfig(user_config={"count": 456})
    client.update_backend_config("counter", config)
    wait_for_condition(lambda: check("456", 3))
Example #4
def test_updating_config(serve_instance):
    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve.context.batch_size
            return [1] * batch_size

    serve.create_backend(
        "bsimple:v1",
        BatchSimple,
        config=BackendConfig(max_batch_size=2, num_replicas=3))
    serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")

    controller = serve.api._get_controller()
    old_replica_tag_list = ray.get(
        controller._list_replicas.remote("bsimple:v1"))

    serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
    new_replica_tag_list = ray.get(
        controller._list_replicas.remote("bsimple:v1"))
    new_all_tag_list = []
    for worker_dict in ray.get(
            controller.get_all_worker_handles.remote()).values():
        new_all_tag_list.extend(list(worker_dict.keys()))

    # The old and new replica tag lists should be identical,
    # and both should be a subset of new_all_tag_list.
    assert set(old_replica_tag_list) <= set(new_all_tag_list)
    assert set(old_replica_tag_list) == set(new_replica_tag_list)
Example #5
def test_updating_config(serve_instance):
    client = serve_instance

    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, request):
            return [1] * len(request)

    config = BackendConfig(max_batch_size=2, num_replicas=3)
    client.create_backend("bsimple:v1", BatchSimple, config=config)
    client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")

    controller = client._controller
    old_replica_tag_list = list(
        ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())

    update_config = BackendConfig(max_batch_size=5)
    client.update_backend_config("bsimple:v1", update_config)
    new_replica_tag_list = list(
        ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())
    new_all_tag_list = []
    for worker_dict in ray.get(
            controller._all_replica_handles.remote()).values():
        new_all_tag_list.extend(list(worker_dict.keys()))

    # The old and new replica tag lists should be identical,
    # and both should be a subset of new_all_tag_list.
    assert set(old_replica_tag_list) <= set(new_all_tag_list)
    assert set(old_replica_tag_list) == set(new_replica_tag_list)
Example #6
def test_imported_backend(serve_instance):
    config = BackendConfig(user_config="config")
    serve.create_backend("imported",
                         "ray.serve.utils.MockImportedBackend",
                         "input_arg",
                         config=config)
    serve.create_endpoint("imported", backend="imported")

    # Basic sanity check.
    handle = serve.get_handle("imported")
    assert ray.get(handle.remote()) == {"arg": "input_arg", "config": "config"}

    # Check that updating backend config works.
    serve.update_backend_config("imported",
                                BackendConfig(user_config="new_config"))
    assert ray.get(handle.remote()) == {
        "arg": "input_arg",
        "config": "new_config"
    }

    # Check that other call methods work.
    handle = handle.options(method_name="other_method")
    assert ray.get(handle.remote("hello")) == "hello"

    # Check that functions work as well.
    serve.create_backend("imported_func",
                         "ray.serve.utils.mock_imported_function")
    serve.create_endpoint("imported_func", backend="imported_func")
    handle = serve.get_handle("imported_func")
    assert ray.get(handle.remote("hello")) == "hello"
Example #7
def test_list_backends(serve_instance):
    client = serve_instance

    @serve.accept_batch
    def f():
        pass

    config1 = BackendConfig(max_batch_size=10)
    client.create_backend("backend", f, config=config1)
    backends = client.list_backends()
    assert len(backends) == 1
    assert "backend" in backends
    assert backends["backend"].max_batch_size == 10

    config2 = BackendConfig(num_replicas=10)
    client.create_backend("backend2", f, config=config2)
    backends = client.list_backends()
    assert len(backends) == 2
    assert backends["backend2"].num_replicas == 10

    client.delete_backend("backend")
    backends = client.list_backends()
    assert len(backends) == 1
    assert "backend2" in backends

    client.delete_backend("backend2")
    assert len(client.list_backends()) == 0
Example #8
def test_list_backends(serve_instance, use_legacy_config):
    client = serve_instance

    @serve.accept_batch
    def f():
        pass

    config1 = {
        "max_batch_size": 10
    } if use_legacy_config else BackendConfig(max_batch_size=10)
    client.create_backend("backend", f, config=config1)
    backends = client.list_backends()
    assert len(backends) == 1
    assert "backend" in backends
    assert backends["backend"]["max_batch_size"] == 10

    config2 = {
        "num_replicas": 10
    } if use_legacy_config else BackendConfig(num_replicas=10)
    client.create_backend("backend2", f, config=config2)
    backends = client.list_backends()
    assert len(backends) == 2
    assert backends["backend2"]["num_replicas"] == 10

    client.delete_backend("backend")
    backends = client.list_backends()
    assert len(backends) == 1
    assert "backend2" in backends

    client.delete_backend("backend2")
    assert len(client.list_backends()) == 0
Example #9
def test_updating_config(serve_instance):
    @serve.deployment(
        "bsimple", config=BackendConfig(max_batch_size=2, num_replicas=2))
    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, request):
            return [1] * len(request)

    BatchSimple.deploy()

    controller = serve.api._global_client._controller
    old_replica_tag_list = list(
        ray.get(controller._all_replica_handles.remote())["bsimple"].keys())

    BatchSimple.options(
        config=BackendConfig(max_batch_size=5, num_replicas=2)).deploy()
    new_replica_tag_list = list(
        ray.get(controller._all_replica_handles.remote())["bsimple"].keys())
    new_all_tag_list = []
    for worker_dict in ray.get(
            controller._all_replica_handles.remote()).values():
        new_all_tag_list.extend(list(worker_dict.keys()))

    # The old and new replica tag lists should be identical,
    # and both should be a subset of new_all_tag_list.
    assert set(old_replica_tag_list) <= set(new_all_tag_list)
    assert set(old_replica_tag_list) == set(new_replica_tag_list)
Example #10
def test_scaling_replicas(serve_instance):
    @serve.deployment("counter", config=BackendConfig(num_replicas=2))
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    Counter.deploy()

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/counter").json()
        counter_result.append(resp)

    # If the load is shared between the two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    Counter.options(config=BackendConfig(num_replicas=1)).deploy()

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/counter").json()
        counter_result.append(resp)
    # It takes some time for a replica to spin down, but the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
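
Why the first assertion holds: each replica keeps its own count, so the 10 requests are split between two independent counters. A minimal simulation, assuming requests alternate strictly between the two replicas (Serve does not guarantee this, which is why the test only asserts max < 10):

counters = [0, 0]
results = []
for i in range(10):
    counters[i % 2] += 1  # each replica increments only its own counter
    results.append(counters[i % 2])
assert max(results) == 5  # under alternation, neither counter can reach 10
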
Example #11
def test_backend_user_config(serve_instance):
    client = serve_instance

    class Counter:
        def __init__(self):
            self.count = 10

        def __call__(self, flask_request):
            return self.count, os.getpid()

        def reconfigure(self, config):
            self.count = config["count"]

    config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2})
    client.create_backend("counter", Counter, config=config)
    client.create_endpoint("counter", backend="counter", route="/counter")
    handle = client.get_handle("counter")

    def check(val, num_replicas):
        pids_seen = set()
        for i in range(100):
            result = ray.get(handle.remote())
            assert str(result[0]) == val, result[0]
            pids_seen.add(result[1])
        assert len(pids_seen) == num_replicas

    check("123", 2)

    client.update_backend_config("counter", BackendConfig(num_replicas=3))
    check("123", 3)

    config = BackendConfig(user_config={"count": 456})
    client.update_backend_config("counter", config)
    check("456", 3)
Example #12
def test_scaling_replicas(serve_instance):
    client = serve_instance

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    config = BackendConfig(num_replicas=2)
    client.create_backend("counter:v1", Counter, config=config)

    client.create_endpoint("counter", backend="counter:v1", route="/increment")

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)

    # If the load is shared between the two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    update_config = BackendConfig(num_replicas=1)
    client.update_backend_config("counter:v1", update_config)

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)
    # It takes some time for a replica to spin down, but the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
Example #13
def test_with_proto():
    # Test roundtrip
    config = BackendConfig(num_replicas=100, max_concurrent_queries=16)
    assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes())

    # Test user_config object
    config = BackendConfig(user_config={"python": ("native", ["objects"])})
    assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes())
Example #14
def test_backend_config_validation():
    # Test unknown key.
    with pytest.raises(ValidationError):
        BackendConfig(unknown_key=-1)

    # Test num_replicas validation.
    BackendConfig(num_replicas=1)
    with pytest.raises(ValidationError, match="type_error"):
        BackendConfig(num_replicas="hello")
    with pytest.raises(ValidationError, match="value_error"):
        BackendConfig(num_replicas=-1)

    # Test dynamic default for max_concurrent_queries.
    assert BackendConfig().max_concurrent_queries == 100
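
BackendConfig validates like a pydantic model here: unknown keys are rejected, and bad types or values surface as a ValidationError carrying "type_error"/"value_error" codes. A minimal sketch of the same behavior with a hypothetical MiniConfig model (pydantic v1 API):

from pydantic import BaseModel, ValidationError, validator

class MiniConfig(BaseModel):
    num_replicas: int = 1
    max_concurrent_queries: int = 100

    class Config:
        extra = "forbid"  # unknown keys raise ValidationError

    @validator("num_replicas")
    def _non_negative(cls, v):
        if v < 0:
            raise ValueError("num_replicas must be non-negative")
        return v

MiniConfig(num_replicas=1)  # ok
for bad in [{"unknown_key": -1}, {"num_replicas": "hello"},
            {"num_replicas": -1}]:
    try:
        MiniConfig(**bad)
    except ValidationError:
        pass  # each case is rejected at construction time
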
Example #15
def setup_worker(name,
                 backend_def,
                 init_args=None,
                 backend_config=BackendConfig(),
                 controller_name=""):
    if init_args is None:
        init_args = ()

    @ray.remote
    class WorkerActor:
        async def __init__(self):
            self.worker = object.__new__(create_backend_replica(backend_def))
            await self.worker.__init__(name, name + ":tag", init_args,
                                       backend_config, controller_name)

        def ready(self):
            pass

        @ray.method(num_returns=2)
        async def handle_request(self, *args, **kwargs):
            return await self.worker.handle_request(*args, **kwargs)

        def update_config(self, new_config):
            return self.worker.update_config(new_config)

        async def drain_pending_queries(self):
            return await self.worker.drain_pending_queries()

    worker = WorkerActor.remote()
    ray.get(worker.ready.remote())
    return worker
Example #16
def test_scaling_replicas(serve_instance):
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_backend(
        "counter:v1", Counter, config=BackendConfig(num_replicas=2))
    serve.create_endpoint("counter", backend="counter:v1", route="/increment")

    # Keep checking the routing table until /increment is populated
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)

    # If the load is shared between the two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    serve.update_backend_config("counter:v1", {"num_replicas": 1})

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)
    # It takes some time for a replica to spin down, but the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
Example #17
def test_batching(serve_instance):
    client = serve_instance

    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, requests):
            self.count += 1
            batch_size = len(requests)
            return [self.count] * batch_size

    # Set the max batch size and the batch wait timeout.
    config = BackendConfig(max_batch_size=5, batch_wait_timeout=1)
    client.create_backend("counter:v11", BatchingExample, config=config)
    client.create_endpoint("counter1",
                           backend="counter:v11",
                           route="/increment2")

    future_list = []
    handle = client.get_handle("counter1")
    for _ in range(20):
        f = handle.remote(temp=1)
        future_list.append(f)

    counter_result = ray.get(future_list)
    # Since count is incremented once per batch of queries, if at least one
    # __call__ handled a batch of size greater than 1, the max counter
    # result must be less than 20.
    assert max(counter_result) < 20
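
The two knobs used above map onto a simple policy: collect queries until max_batch_size is reached or batch_wait_timeout expires, then invoke the backend once for the whole batch. A minimal asyncio sketch of that policy (not Ray's implementation; the queue holds (request, future) pairs):

import asyncio

async def batch_loop(queue, handle_batch, max_batch_size=5,
                     batch_wait_timeout=1.0):
    loop = asyncio.get_event_loop()
    while True:
        # Block for the first query, then keep filling the batch until it
        # is full or the timeout expires.
        batch = [await queue.get()]
        deadline = loop.time() + batch_wait_timeout
        while len(batch) < max_batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        requests = [request for request, _ in batch]
        results = handle_batch(requests)  # one __call__ per batch
        for (_, future), result in zip(batch, results):
            future.set_result(result)
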
Example #18
async def test_servable_batch_error(serve_instance, router,
                                    mock_controller_with_name):
    @serve.accept_batch
    class ErrorBatcher:
        def error_different_size(self, requests):
            return [""] * (len(requests) + 10)

        def error_non_iterable(self, _):
            return 42

        def return_np_array(self, requests):
            return np.array([1] * len(requests)).astype(np.int32)

    backend_config = BackendConfig(
        max_batch_size=4,
        internal_metadata=BackendMetadata(accepts_batches=True))
    await add_servable_to_router(ErrorBatcher,
                                 router,
                                 mock_controller_with_name[0],
                                 backend_config=backend_config)

    with pytest.raises(RayServeException, match="doesn't preserve batch size"):
        different_size = make_request_param("error_different_size")
        await (await router.assign_request.remote(different_size))

    with pytest.raises(RayServeException, match="iterable"):
        non_iterable = make_request_param("error_non_iterable")
        await (await router.assign_request.remote(non_iterable))

    np_array = make_request_param("return_np_array")
    result_np_value = await (await router.assign_request.remote(np_array))
    assert isinstance(result_np_value, np.int32)
Example #19
async def test_task_runner_custom_method_batch(serve_instance, router,
                                               mock_controller_with_name):
    @serve.accept_batch
    class Batcher:
        def a(self, requests):
            return ["a-{}".format(i) for i in range(len(requests))]

        def b(self, requests):
            return ["b-{}".format(i) for i in range(len(requests))]

    backend_config = BackendConfig(
        max_batch_size=4,
        batch_wait_timeout=10,
        internal_metadata=BackendMetadata(accepts_batches=True))
    await add_servable_to_router(Batcher,
                                 router,
                                 mock_controller_with_name[0],
                                 backend_config=backend_config)

    a_query_param = make_request_param("a")
    b_query_param = make_request_param("b")

    futures = [
        await router.assign_request.remote(a_query_param) for _ in range(2)
    ]
    futures += [
        await router.assign_request.remote(b_query_param) for _ in range(2)
    ]

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
Example #20
async def test_task_runner_custom_method_batch(serve_instance):
    q = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, Batcher)

    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
    await q.set_backend_config.remote(
        CONSUMER_NAME,
        BackendConfig({
            "max_batch_size": 10
        }, accepts_batches=True))

    a_query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
    b_query_param = RequestMetadata(
        PRODUCER_NAME, context.TaskContext.Python, call_method="b")

    futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
    futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]

    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)

    gathered = await asyncio.gather(*futures)
    assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
Example #21
def test_create_infeasible_error(serve_instance):
    client = serve_instance

    def f():
        pass

    # A nonexistent resource should be infeasible.
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        client.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "MagicMLResource": 100
            }})

    # Even though each replica might be feasible, the total might not be.
    current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
    num_replicas = current_cpus + 20
    config = BackendConfig(num_replicas=num_replicas)
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        client.create_backend("f:1",
                              f,
                              ray_actor_options={"resources": {
                                  "CPU": 1,
                              }},
                              config=config)
Example #22
def test_batching(serve_instance):
    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            self.count += 1
            batch_size = serve.context.batch_size
            return [self.count] * batch_size

    # Set the max batch size and the batch wait timeout.
    serve.create_backend(
        "counter:v11",
        BatchingExample,
        config=BackendConfig(max_batch_size=5, batch_wait_timeout=1))
    serve.create_endpoint(
        "counter1", backend="counter:v11", route="/increment2")

    # Keep checking the routing table until /increment2 is populated.
    while "/increment2" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    future_list = []
    handle = serve.get_handle("counter1")
    for _ in range(20):
        f = handle.remote(temp=1)
        future_list.append(f)

    counter_result = ray.get(future_list)
    # Since count is incremented once per batch of queries, if at least one
    # __call__ handled a batch of size greater than 1, the max counter
    # result must be less than 20.
    assert max(counter_result) < 20
Example #23
def test_parallel_start(serve_instance):
    # Test the ability to start multiple replicas in parallel.
    # In the past, when Serve scaled up a backend, it did so one replica at
    # a time, waiting for each to initialize. This test guards against that
    # regression by preventing the first replica from finishing
    # initialization until the second replica has also started.
    @ray.remote
    class Barrier:
        def __init__(self, release_on):
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                self.event.set()
            else:
                await self.event.wait()

    barrier = Barrier.remote(release_on=2)

    class LongStartingServable:
        def __init__(self):
            ray.get(barrier.wait.remote(), timeout=10)

        def __call__(self, _):
            return "Ready"

    serve.create_backend(
        "p:v0", LongStartingServable, config=BackendConfig(num_replicas=2))
    serve.create_endpoint("test-parallel", backend="p:v0")
    handle = serve.get_handle("test-parallel")

    ray.get(handle.remote(), timeout=10)
Example #24
async def test_task_runner_perform_async(serve_instance,
                                         mock_controller_with_name):
    @ray.remote
    class Barrier:
        def __init__(self, release_on):
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                self.event.set()
            else:
                await self.event.wait()

    barrier = Barrier.remote(release_on=10)

    async def wait_and_go(*args, **kwargs):
        await barrier.wait.remote()
        return "done!"

    config = BackendConfig(max_concurrent_queries=10)

    worker, router = await add_servable_to_router(
        wait_and_go, *mock_controller_with_name, backend_config=config)

    query_param = make_request_param()

    done, not_done = await asyncio.wait(
        [(await router.assign_request(query_param)) for _ in range(10)],
        timeout=10)
    assert len(done) == 10
    for item in done:
        assert await item == "done!"
Example #25
def test_create_infeasible_error(serve_instance):
    client = serve_instance

    def f():
        pass

    # A nonexistent resource should be infeasible.
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        client.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "MagicMLResource": 100
            }})

    # Even though each replica might be feasible, the total might not be.
    current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        client.create_backend("f:1",
                              f,
                              ray_actor_options={"resources": {
                                  "CPU": 1,
                              }},
                              config=BackendConfig(num_replicas=(current_cpus +
                                                                 20)))

    # No replica should be created!
    replicas = ray.get(client._controller._list_replicas.remote("f:1"))
    assert len(replicas) == 0
Example #26
async def test_task_runner_perform_batch(serve_instance):
    q = ray.remote(Router).remote()
    await q.setup.remote()

    def batcher(*args, **kwargs):
        return [serve.context.batch_size] * serve.context.batch_size

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    config = BackendConfig({
        "max_batch_size": 2,
        "batch_wait_timeout": 10
    },
                           accepts_batches=True)

    worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    await q.set_backend_config.remote(CONSUMER_NAME, config)
    await q.set_traffic.remote(PRODUCER_NAME,
                               TrafficPolicy({CONSUMER_NAME: 1.0}))

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)

    my_batch_sizes = await asyncio.gather(
        *[q.enqueue_request.remote(query_param) for _ in range(3)])
    assert my_batch_sizes == [2, 2, 1]
Example #27
def create_backend(backend_tag,
                   func_or_class,
                   *actor_init_args,
                   ray_actor_options=None,
                   config=None):
    """Create a backend with the provided tag.

    The backend will serve requests with func_or_class.

    Args:
        backend_tag (str): a unique tag assigned to identify this backend.
        func_or_class (callable, class): a function or a class implementing
            __call__.
        actor_init_args (optional): the arguments to pass to the class
            initialization method.
        ray_actor_options (optional): options to be passed into the
            @ray.remote decorator for the backend actor.
        config: (optional) configuration options for this backend.
    """
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise TypeError("config must be a dictionary.")

    replica_config = ReplicaConfig(func_or_class,
                                   *actor_init_args,
                                   ray_actor_options=ray_actor_options)
    backend_config = BackendConfig(config, replica_config.accepts_batches)

    retry_actor_failures(master_actor.create_backend, backend_tag,
                         backend_config, replica_config)
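
A minimal usage sketch for this API (hypothetical echo backend; assumes serve.init() has already been called, and uses a config dict since that is what this legacy signature validates above):

def echo(flask_request):
    return "hello"

create_backend("echo:v1", echo, config={"num_replicas": 2})
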
Example #28
def setup_worker(name,
                 func_or_class,
                 init_args=None,
                 backend_config=BackendConfig({})):
    if init_args is None:
        init_args = ()

    @ray.remote
    class WorkerActor:
        def __init__(self):
            self.worker = create_backend_worker(func_or_class)(name,
                                                               name + ":tag",
                                                               init_args,
                                                               backend_config)

        def ready(self):
            pass

        async def handle_request(self, *args, **kwargs):
            return await self.worker.handle_request(*args, **kwargs)

        def update_config(self, new_config):
            return self.worker.update_config(new_config)

    worker = WorkerActor.remote()
    ray.get(worker.ready.remote())
    return worker
Example #29
def test_worker_replica_failure(serve_instance):
    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0

        def inc_and_get(self):
            self.count += 1
            return self.count

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, counter):
            self.should_hang = False
            self.index = ray.get(counter.inc_and_get.remote())
            if self.index > 2:
                while True:
                    pass

        def __call__(self, *args):
            return self.index

    counter = Counter.remote()
    serve.create_backend("replica_failure", Worker, counter)
    serve.update_backend_config("replica_failure",
                                BackendConfig(num_replicas=2))
    serve.create_endpoint("replica_failure",
                          backend="replica_failure",
                          route="/replica_failure")

    # Wait until both replicas have been started.
    responses = set()
    start = time.time()
    while time.time() - start < 30:
        time.sleep(0.1)
        response = request_with_retries("/replica_failure", timeout=1).text
        assert response in ["1", "2"]
        responses.add(response)
        if len(responses) > 1:
            break
    else:
        raise TimeoutError("Timed out waiting for replicas after 30s.")

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(handles[0], no_restart=False)

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
Example #30
def test_worker_replica_failure(serve_instance):
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                with open(path, "w") as f:
                    if num == 2:
                        self.should_hang = True
                    else:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            # Return the PID so the two replicas' responses are
            # distinguishable below.
            return os.getpid()

    temp_path = os.path.join(tempfile.gettempdir(),
                             serve.utils.get_random_letters())
    serve.create_backend("replica_failure", Worker, temp_path)
    serve.update_backend_config("replica_failure",
                                BackendConfig(num_replicas=2))
    serve.create_endpoint("replica_failure",
                          backend="replica_failure",
                          route="/replica_failure")

    # Wait until both replicas have been started.
    responses = set()
    while len(responses) < 2:
        responses.add(request_with_retries("/replica_failure", timeout=1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(handles[0], no_restart=False)

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)