Example #1
def test_worker_restart(serve_instance):
    serve.init()
    serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend(Worker1, "worker_failure:v1")
    serve.link("worker_failure", "worker_failure:v1")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=0.1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(list(handles.values())[0])

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        assert False, "Timed out waiting for worker to die."
Example #2
def test_batching(serve_instance):
    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            self.count += 1
            batch_size = serve.context.batch_size
            return [self.count] * batch_size

    serve.create_endpoint("counter1", "/increment")

    # Keep checking the routing table until /increment is populated
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    # set the max batch size
    b_config = BackendConfig(max_batch_size=5)
    serve.create_backend(
        BatchingExample, "counter:v11", backend_config=b_config)
    serve.link("counter1", "counter:v11")

    future_list = []
    handle = serve.get_handle("counter1")
    for _ in range(20):
        f = handle.remote(temp=1)
        future_list.append(f)

    counter_result = ray.get(future_list)
    # Since count is only updated once per batch of queries, if at least one
    # __call__ was invoked with a batch size greater than 1, the max result
    # will be less than 20.
    assert max(counter_result) < 20
Example #3
def test_http_proxy_failure(serve_instance):
    serve.init()
    serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])

    def function():
        return "hello1"

    serve.create_backend(function, "proxy_failure:v1")
    serve.link("proxy_failure", "proxy_failure:v1")

    assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello1"

    _kill_http_proxy()

    def function():
        return "hello2"

    serve.create_backend(function, "proxy_failure:v2")
    serve.link("proxy_failure", "proxy_failure:v2")

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello2"
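`_kill_http_proxy` is another internal test helper not shown in this excerpt.
A sketch of the idea, assuming the proxy runs as a named actor that can be
looked up (the actor name below is an assumption, not Serve's real one):

import ray


def _kill_http_proxy():
    # Hypothetical sketch: fetch the HTTP proxy actor by name and kill it,
    # forcing Serve to bring up a fresh proxy.
    proxy = ray.get_actor("SERVE_HTTP_PROXY")  # name is an assumption
    ray.kill(proxy)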
Example #4
def test_http_proxy_failure(serve_instance):
    serve.init()
    serve.create_endpoint("failure_endpoint",
                          "/failure_endpoint",
                          methods=["GET"])

    def function(flask_request):
        return "hello1"

    serve.create_backend(function, "failure:v1")
    serve.link("failure_endpoint", "failure:v1")

    def verify_response(response):
        assert response.text == "hello1"

    request_with_retries("/failure_endpoint", verify_response, timeout=0)

    _kill_http_proxy()

    request_with_retries("/failure_endpoint", verify_response, timeout=30)

    _kill_http_proxy()

    def function(flask_request):
        return "hello2"

    serve.create_backend(function, "failure:v2")
    serve.link("failure_endpoint", "failure:v2")

    def verify_response(response):
        assert response.text == "hello2"

    request_with_retries("/failure_endpoint", verify_response, timeout=30)
Example #5
def test_e2e(serve_instance):
    serve.init()
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])

    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                assert False, ("Route table hasn't been updated after 3 tries."
                               "The latest error was {}").format(e)

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
Example #6
def test_e2e(serve_instance):
    serve.init()  # so we have access to global state
    serve.create_endpoint("endpoint", "/api", blocking=True)
    result = serve.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json()
            assert resp == result
            break
        except Exception:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                assert False, "Route table hasn't been updated after 3 tries."

    def function(flask_request):
        return "OK"

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
    assert resp == "OK"
Example #7
def test_no_route(serve_instance):
    serve.create_endpoint("noroute-endpoint")

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.link("noroute-endpoint", "backend:1")
    service_handle = serve.get_handle("noroute-endpoint")
    result = ray.get(service_handle.remote(i=1))
    assert result == 1
Example #8
def test_no_route(serve_instance):
    serve.create_endpoint("noroute-endpoint", blocking=True)
    global_state = serve.api._get_global_state()

    result = global_state.route_table.list_service(include_headless=True)
    assert result[NO_ROUTE_KEY] == ["noroute-endpoint"]

    without_headless_result = global_state.route_table.list_service()
    assert NO_ROUTE_KEY not in without_headless_result

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.link("noroute-endpoint", "backend:1")
    service_handle = serve.get_handle("noroute-endpoint")
    result = ray.get(service_handle.remote(i=1))
    assert result == 1
Example #9
def test_batching_exception(serve_instance):
    class NoListReturned:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve.context.batch_size
            return batch_size

    serve.create_endpoint("exception-test", "/noListReturned")
    # set the max batch size
    b_config = BackendConfig(max_batch_size=5)
    serve.create_backend(
        NoListReturned, "exception:v1", backend_config=b_config)
    serve.link("exception-test", "exception:v1")

    handle = serve.get_handle("exception-test")
    with pytest.raises(ray.exceptions.RayTaskError):
        assert ray.get(handle.remote(temp=1))
Example #10
def test_scaling_replicas(serve_instance):
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_endpoint("counter", "/increment")

    # Keep checking the routing table until /increment is populated
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    b_config = BackendConfig(num_replicas=2)
    serve.create_backend(Counter, "counter:v1", backend_config=b_config)
    serve.link("counter", "counter:v1")

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)

    # If the load is shared between two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    b_config = serve.get_backend_config("counter:v1")
    b_config.num_replicas = 1
    serve.set_backend_config("counter:v1", b_config)

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)
    # Give some time for a replica to spin down, but the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
Example #11
def test_handle_in_endpoint(serve_instance):
    serve.init()

    class Endpoint1:
        def __call__(self, flask_request):
            return "hello"

    class Endpoint2:
        def __init__(self):
            self.handle = serve.get_handle("endpoint1", missing_ok=True)

        def __call__(self):
            return ray.get(self.handle.remote())

    serve.create_endpoint("endpoint1", "/endpoint1", methods=["GET", "POST"])
    serve.create_backend(Endpoint1, "endpoint1:v0")
    serve.link("endpoint1", "endpoint1:v0")

    serve.create_endpoint("endpoint2", "/endpoint2", methods=["GET", "POST"])
    serve.create_backend(Endpoint2, "endpoint2:v0")
    serve.link("endpoint2", "endpoint2:v0")

    assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
Example #12
def test_worker_replica_failure(serve_instance):
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()
    serve.create_endpoint("replica_failure",
                          "/replica_failure",
                          methods=["GET"])

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                with open(path, "w") as f:
                    if num == 2:
                        self.should_hang = True
                    else:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            pass

    temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters()
    serve.create_backend(Worker, "replica_failure", temp_path)
    backend_config = serve.get_backend_config("replica_failure")
    backend_config.num_replicas = 2
    serve.set_backend_config("replica_failure", backend_config)
    serve.link("replica_failure", "replica_failure")

    # Wait until both replicas have been started.
    responses = set()
    while len(responses) < 2:
        responses.add(
            request_with_retries("/replica_failure", timeout=0.1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(list(handles.values())[0])

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
"""
Example service that prints out http context.
"""

import time

import requests

from ray import serve
from ray.serve.utils import pformat_color_json


def echo(flask_request):
    return "hello " + flask_request.args.get("name", "serve!")


serve.init(blocking=True)

serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")

while True:
    resp = requests.get("http://127.0.0.1:8000/echo").json()
    print(pformat_color_json(resp))

    print("...Sleeping for 2 seconds...")
    time.sleep(2)
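The loop above never exercises the optional "name" query parameter that
echo() reads; a one-off request that does, assuming the script is running,
could look like:

import requests

# Pass the "name" argument that echo() reads from flask_request.args.
resp = requests.get("http://127.0.0.1:8000/echo", params={"name": "ray"})
print(resp.json())  # expected: "hello ray"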
Example #14
# NOTE: the top of this example is truncated in the excerpt; the imports and
# class skeleton below are reconstructed to match the surviving fragment and
# the driver code that follows.
import ray
from ray import serve
from ray.serve import BackendConfig


class MagicCounter:
    def __init__(self, increment):
        self.increment = increment

    @serve.accept_batch
    def __call__(self, flask_request, base_number=None):
        # batch_size is the number of queries folded into this call.
        batch_size = serve.context.batch_size
        if base_number is not None:
            result = []
            for base_num in base_number:
                ret_str = "Number: {} Batch size: {}".format(
                    base_num, batch_size)
                result.append(ret_str)
            return result
        return ""


serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter")
# specify max_batch_size in BackendConfig
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
    MagicCounter, "counter:v1", 42, backend_config=b_config)  # increment=42
print("Backend Config for backend: 'counter:v1'")
print(b_config)
serve.link("magic_counter", "counter:v1")

handle = serve.get_handle("magic_counter")
future_list = []

# fire 30 requests
for r in range(30):
    print("> [REMOTE] Pinging handle.remote(base_number={})".format(r))
    f = handle.remote(base_number=r)
    future_list.append(f)

# get results of queries as they complete
left_futures = future_list
while left_futures:
    completed_futures, remaining_futures = ray.wait(left_futures, timeout=0.05)
    if len(completed_futures) > 0:
                                            "--batch_size")
else:
    args["--batch_size"] = 1

ray.init(address=args["--ray_address"], redis_password=args["--ray_password"])
serve.init(start_server=False)

input_p = Path(args["--input_directory"])
output_p = Path(args["--output_directory"])

all_wavs = list(input_p.rglob("**/*.WAV"))

# model = RunSplitter()
# predictions = model(None, audio_paths=all_wavs[0:10])
# print(predictions)

serve.create_endpoint("splitter")
serve.create_backend(
    RunSplitter,
    "splitter:v0",
    backend_config=serve.BackendConfig(num_replicas=args["--num_nodes"],
                                       max_batch_size=args["--batch_size"]),
)
serve.link("splitter", "splitter:v0")

handle = serve.get_handle("splitter")

ids = [handle.remote(audio_paths=audio_path) for audio_path in all_wavs]
results = ray.get(ids)
print(results)