def test_worker_restart(serve_instance):
    """A killed backend worker should be replaced and serve from a new PID."""
    serve.init()
    serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])

    class Worker1:
        def __call__(self):
            # Expose the worker's PID so the test can detect a restart.
            return os.getpid()

    serve.create_backend(Worker1, "worker_failure:v1")
    serve.link("worker_failure", "worker_failure:v1")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=0.1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(list(handles.values())[0])

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        # while-else: runs only if the loop finished without `break`.
        assert False, "Timed out waiting for worker to die."
def test_batching(serve_instance):
    """Queries issued through a handle should be grouped into batches."""

    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            # The counter advances once per *batch*, not once per query.
            self.count += 1
            batch_size = serve.context.batch_size
            return [self.count] * batch_size

    serve.create_endpoint("counter1", "/increment")

    # Block until the route table advertises /increment.
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    # Allow up to five queries per batch.
    b_config = BackendConfig(max_batch_size=5)
    serve.create_backend(
        BatchingExample, "counter:v11", backend_config=b_config)
    serve.link("counter1", "counter:v11")

    handle = serve.get_handle("counter1")
    future_list = [handle.remote(temp=1) for _ in range(20)]
    counter_result = ray.get(future_list)

    # Since the count only updates once per batch, if at least one __call__
    # received a batch larger than one query, the maximum observed counter
    # value must be strictly below the number of queries (20).
    assert max(counter_result) < 20
def test_http_proxy_failure(serve_instance):
    """The HTTP proxy should keep serving after being killed and restarted."""
    serve.init()
    serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])

    def function():
        return "hello1"

    serve.create_backend(function, "proxy_failure:v1")
    serve.link("proxy_failure", "proxy_failure:v1")

    assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"
    for _ in range(10):
        assert request_with_retries(
            "/proxy_failure", timeout=30).text == "hello1"

    # Kill the proxy, then swap in a new backend while it restarts.
    _kill_http_proxy()

    def function():
        return "hello2"

    serve.create_backend(function, "proxy_failure:v2")
    serve.link("proxy_failure", "proxy_failure:v2")

    # Once the proxy is back, the new backend's responses should be served.
    for _ in range(10):
        assert request_with_retries(
            "/proxy_failure", timeout=30).text == "hello2"
def test_http_proxy_failure(serve_instance):
    """Requests should continue to succeed across HTTP proxy restarts."""
    serve.init()
    serve.create_endpoint("failure_endpoint", "/failure_endpoint",
                          methods=["GET"])

    def function(flask_request):
        return "hello1"

    serve.create_backend(function, "failure:v1")
    serve.link("failure_endpoint", "failure:v1")

    def verify_response(response):
        assert response.text == "hello1"

    # Sanity check with no retries before killing anything.
    request_with_retries("/failure_endpoint", verify_response, timeout=0)

    # Kill the proxy; retried requests should succeed once it recovers.
    _kill_http_proxy()
    request_with_retries("/failure_endpoint", verify_response, timeout=30)

    # Kill it again and swap in a new backend while it is restarting.
    _kill_http_proxy()

    def function(flask_request):
        return "hello2"

    serve.create_backend(function, "failure:v2")
    serve.link("failure_endpoint", "failure:v2")

    def verify_response(response):
        assert response.text == "hello2"

    request_with_retries("/failure_endpoint", verify_response, timeout=30)
def test_e2e(serve_instance):
    """End-to-end: route registration plus GET/POST method dispatch."""
    serve.init()
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])

    # Poll the route table with exponential backoff until /api appears.
    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                # Bug fix: the message previously said "3 tries" even though
                # 5 retries are attempted.
                assert False, ("Route table hasn't been updated after 5 tries."
                               "The latest error was {}").format(e)

    def function(flask_request):
        # Echo back the HTTP method so both verbs can be asserted on.
        return {"method": flask_request.method}

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
def test_e2e(serve_instance):
    """End-to-end: route table lookup and a simple backend response."""
    serve.init()  # so we have access to global state
    serve.create_endpoint("endpoint", "/api", blocking=True)
    result = serve.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    # Poll the proxy with exponential backoff until it reports the table.
    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json()
            assert resp == result
            break
        except Exception:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                # Bug fix: the message previously said "3 tries" even though
                # 5 retries are attempted.
                assert False, "Route table hasn't been updated after 5 tries."

    def function(flask_request):
        return "OK"

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
    assert resp == "OK"
def test_no_route(serve_instance):
    """An endpoint with no HTTP route is still reachable via a handle."""
    serve.create_endpoint("noroute-endpoint")

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.link("noroute-endpoint", "backend:1")

    service_handle = serve.get_handle("noroute-endpoint")
    assert ray.get(service_handle.remote(i=1)) == 1
def test_no_route(serve_instance):
    """Headless endpoints are listed only when explicitly requested."""
    serve.create_endpoint("noroute-endpoint", blocking=True)
    global_state = serve.api._get_global_state()

    # With headless services included, the endpoint shows up under
    # NO_ROUTE_KEY; without them, it must be absent entirely.
    result = global_state.route_table.list_service(include_headless=True)
    assert result[NO_ROUTE_KEY] == ["noroute-endpoint"]
    without_headless_result = global_state.route_table.list_service()
    assert NO_ROUTE_KEY not in without_headless_result

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.link("noroute-endpoint", "backend:1")

    # The endpoint should still be callable through a serve handle.
    service_handle = serve.get_handle("noroute-endpoint")
    assert ray.get(service_handle.remote(i=1)) == 1
def test_batching_exception(serve_instance):
    """A batched backend that returns a non-list should raise for callers."""

    class NoListReturned:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            # Deliberately broken: a batched backend must return one result
            # per query (a list), not a bare int.
            return serve.context.batch_size

    serve.create_endpoint("exception-test", "/noListReturned")

    # Force batching by allowing up to five queries per batch.
    b_config = BackendConfig(max_batch_size=5)
    serve.create_backend(
        NoListReturned, "exception:v1", backend_config=b_config)
    serve.link("exception-test", "exception:v1")

    handle = serve.get_handle("exception-test")
    with pytest.raises(ray.exceptions.RayTaskError):
        assert ray.get(handle.remote(temp=1))
def test_scaling_replicas(serve_instance):
    """Scaling a backend up and back down changes how load is distributed."""

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_endpoint("counter", "/increment")

    # Block until the route table advertises /increment.
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        time.sleep(0.2)

    b_config = BackendConfig(num_replicas=2)
    serve.create_backend(Counter, "counter:v1", backend_config=b_config)
    serve.link("counter", "counter:v1")

    counter_result = [
        requests.get("http://127.0.0.1:8000/increment").json()
        for _ in range(10)
    ]
    # With the load shared between two replicas, no single per-replica
    # counter can reach 10.
    assert max(counter_result) < 10

    # Scale back down to a single replica.
    b_config = serve.get_backend_config("counter:v1")
    b_config.num_replicas = 1
    serve.set_backend_config("counter:v1", b_config)

    counter_result = [
        requests.get("http://127.0.0.1:8000/increment").json()
        for _ in range(10)
    ]
    # Give some time for a replica to spin down. But the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
def test_handle_in_endpoint(serve_instance):
    """One endpoint can delegate to another through a serve handle."""
    serve.init()

    class Endpoint1:
        def __call__(self, flask_request):
            return "hello"

    class Endpoint2:
        def __init__(self):
            # endpoint1 may not be registered yet when this replica starts.
            self.handle = serve.get_handle("endpoint1", missing_ok=True)

        def __call__(self):
            # Forward the request to endpoint1 and return its result.
            return ray.get(self.handle.remote())

    for name, backend in (("endpoint1", Endpoint1), ("endpoint2", Endpoint2)):
        serve.create_endpoint(name, "/" + name, methods=["GET", "POST"])
        serve.create_backend(backend, name + ":v0")
        serve.link(name, name + ":v0")

    assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
def test_worker_replica_failure(serve_instance):
    """If one of two replicas is killed, the other keeps serving requests."""
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()
    serve.create_endpoint("replica_failure", "/replica_failure",
                          methods=["GET"])

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                with open(path, "w") as f:
                    if num == 2:
                        self.should_hang = True
                    else:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            pass

    temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters()
    serve.create_backend(Worker, "replica_failure", temp_path)
    backend_config = serve.get_backend_config("replica_failure")
    backend_config.num_replicas = 2
    serve.set_backend_config("replica_failure", backend_config)
    serve.link("replica_failure", "replica_failure")

    # Wait until both replicas have been started.
    # Bug fix: the condition was `len(responses) == 1`, which is False for
    # the initial empty set, so the loop body never ran and the test never
    # actually waited for the second replica.
    responses = set()
    while len(responses) < 2:
        responses.add(
            request_with_retries("/replica_failure", timeout=0.1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(list(handles.values())[0])

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
""" Example service that prints out http context. """ import time import requests from ray import serve from ray.serve.utils import pformat_color_json def echo(flask_request): return "hello " + flask_request.args.get("name", "serve!") serve.init(blocking=True) serve.create_endpoint("my_endpoint", "/echo") serve.create_backend(echo, "echo:v1") serve.link("my_endpoint", "echo:v1") while True: resp = requests.get("http://127.0.0.1:8000/echo").json() print(pformat_color_json(resp)) print("...Sleeping for 2 seconds...") time.sleep(2)
ret_str = "Number: {} Batch size: {}".format( base_num, batch_size) result.append(ret_str) return result return "" serve.init(blocking=True) serve.create_endpoint("magic_counter", "/counter") # specify max_batch_size in BackendConfig b_config = BackendConfig(max_batch_size=5) serve.create_backend( MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42 print("Backend Config for backend: 'counter:v1'") print(b_config) serve.link("magic_counter", "counter:v1") handle = serve.get_handle("magic_counter") future_list = [] # fire 30 requests for r in range(30): print("> [REMOTE] Pinging handle.remote(base_number={})".format(r)) f = handle.remote(base_number=r) future_list.append(f) # get results of queries as they complete left_futures = future_list while left_futures: completed_futures, remaining_futures = ray.wait(left_futures, timeout=0.05) if len(completed_futures) > 0:
"--batch_size") else: args["--batch_size"] = 1 ray.init(address=args["--ray_address"], redis_password=args["--ray_password"]) serve.init(start_server=False) input_p = Path(args["--input_directory"]) output_p = Path(args["--output_directory"]) all_wavs = list(input_p.rglob("**/*.WAV")) # model = RunSplitter() # predictions = model(None, audio_paths=all_wavs[0:10]) # print(predictions) serve.create_endpoint("splitter") serve.create_backend( RunSplitter, "splitter:v0", backend_config=serve.BackendConfig(num_replicas=args["--num_nodes"], max_batch_size=args["--batch_size"]), ) serve.link("splitter", "splitter:v0") handle = serve.get_handle("splitter") ids = [handle.remote(audio_paths=audio_path) for audio_path in all_wavs] results = ray.get(ids) print(results)