def test_http_proxy_failure(serve_instance):
    """Serve keeps answering on a route after ``_kill_http_proxy()`` is called.

    A v1 backend answers "hello1"; after the proxy is killed, all traffic
    is moved to a v2 backend that must answer "hello2".
    """
    serve.init()

    def function():
        return "hello1"

    serve.create_backend("proxy_failure:v1", function)
    serve.create_endpoint(
        "proxy_failure",
        backend="proxy_failure:v1",
        route="/proxy_failure",
    )

    first_reply = request_with_retries("/proxy_failure", timeout=1.0)
    assert first_reply.text == "hello1"

    for _attempt in range(10):
        reply = request_with_retries("/proxy_failure", timeout=30)
        assert reply.text == "hello1"

    _kill_http_proxy()

    # Redefine the handler so the second backend is distinguishable.
    def function():
        return "hello2"

    serve.create_backend("proxy_failure:v2", function)
    serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})

    for _attempt in range(10):
        reply = request_with_retries("/proxy_failure", timeout=30)
        assert reply.text == "hello2"
def test_shutdown(serve_instance):
    """After ``serve.shutdown()`` the API raises and the named actors vanish."""

    def f():
        pass

    instance_name = "shutdown"
    serve.init(name=instance_name, http_port=8003)
    serve.create_backend("backend", f)
    serve.create_endpoint("endpoint", backend="backend")

    serve.shutdown()
    with pytest.raises(RayServeException, match="Please run serve.init"):
        serve.list_backends()

    def actor_exists(actor_name):
        # ray.get_actor raising ValueError is treated as "actor is gone".
        try:
            ray.get_actor(format_actor_name(actor_name, instance_name))
        except ValueError:
            return False
        return True

    def check_dead():
        serve_actor_names = [
            constants.SERVE_CONTROLLER_NAME,
            constants.SERVE_PROXY_NAME,
            constants.SERVE_METRIC_SINK_NAME,
        ]
        # Stops at the first actor that can still be looked up.
        return not any(actor_exists(name) for name in serve_actor_names)

    wait_for_condition(check_dead)
def test_worker_restart(serve_instance):
    """A killed worker is replaced by a new process that serves the route.

    The backend returns its own PID; after ``ray.kill`` on the single
    worker handle, the test polls until the route reports a different PID.

    Raises:
        AssertionError: if no different PID is observed within 30 seconds.
    """
    serve.init()

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend("worker_failure:v1", Worker1)
    serve.create_endpoint(
        "worker_failure",
        backend="worker_failure:v1",
        route="/worker_failure",
    )

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(handles[0], no_restart=False)

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        # Raise explicitly: `assert False` is removed under `python -O`,
        # which would let a timed-out test pass silently.
        raise AssertionError("Timed out waiting for worker to die.")
def test_updating_config(serve_instance):
    """Updating ``max_concurrent_queries`` must leave the replica tags as-is."""

    class BatchSimple:
        def __init__(self):
            self.count = 0

        def __call__(self, request):
            return 1

    config = BackendConfig(max_concurrent_queries=2, num_replicas=3)
    serve.create_backend("bsimple:v1", BatchSimple, config=config)
    serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")

    controller = serve.api._global_client._controller

    def fetch_replica_handles():
        # One remote round trip per call; maps backend tag -> replica dict.
        return ray.get(controller._all_replica_handles.remote())

    tags_before = set(fetch_replica_handles()["bsimple:v1"].keys())

    update_config = BackendConfig(max_concurrent_queries=5)
    serve.update_backend_config("bsimple:v1", update_config)

    tags_after = set(fetch_replica_handles()["bsimple:v1"].keys())
    every_tag = {
        tag
        for worker_dict in fetch_replica_handles().values()
        for tag in worker_dict.keys()
    }

    # The old and new replica tags should be identical, and both should be
    # contained in the set of all tags known to the controller.
    assert tags_before <= every_tag
    assert tags_before == tags_after
def test_scaling_replicas(serve_instance):
    """Scale a counting backend from two replicas to one via dict configs."""

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
    serve.create_endpoint("counter", backend="counter:v1", route="/increment")

    # Keep checking the routing table until /increment is populated.
    while True:
        routes = requests.get("http://127.0.0.1:8000/-/routes").json()
        if "/increment" in routes:
            break
        time.sleep(0.2)

    def hit_counter_ten_times():
        return [
            requests.get("http://127.0.0.1:8000/increment").json()
            for _ in range(10)
        ]

    counter_result = hit_counter_ten_times()

    # If the load is shared among two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    serve.update_backend_config("counter:v1", {"num_replicas": 1})

    counter_result = hit_counter_ten_times()

    # A replica may still be spinning down, but the majority of the requests
    # should be served by the only remaining replica.
    spread = max(counter_result) - min(counter_result)
    assert spread > 6
def test_scaling_replicas(serve_instance):
    """Scale a counting backend from two replicas to one via BackendConfig."""

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    config = BackendConfig(num_replicas=2)
    serve.create_backend("counter:v1", Counter, config=config)
    serve.create_endpoint("counter", backend="counter:v1", route="/increment")

    def hit_counter_ten_times():
        return [
            requests.get("http://127.0.0.1:8000/increment").json()
            for _ in range(10)
        ]

    counter_result = hit_counter_ten_times()

    # If the load is shared among two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    update_config = BackendConfig(num_replicas=1)
    serve.update_backend_config("counter:v1", update_config)

    counter_result = hit_counter_ten_times()

    # A replica may still be spinning down, but the majority of the requests
    # should be served by the only remaining replica.
    spread = max(counter_result) - min(counter_result)
    assert spread > 6
def test_backend_user_config(serve_instance):
    """``user_config`` reaches ``reconfigure`` and survives a replica resize."""

    class Counter:
        def __init__(self):
            self.count = 10

        def __call__(self, starlette_request):
            return self.count, os.getpid()

        def reconfigure(self, config):
            self.count = config["count"]

    config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2})
    serve.create_backend("counter", Counter, config=config)
    serve.create_endpoint("counter", backend="counter")
    handle = serve.get_handle("counter")

    def check(val, num_replicas):
        # 100 calls: every reply must carry `val`, and the number of distinct
        # PIDs seen must equal `num_replicas`.
        distinct_pids = set()
        for _ in range(100):
            reply = ray.get(handle.remote())
            if str(reply[0]) != val:
                return False
            distinct_pids.add(reply[1])
        return len(distinct_pids) == num_replicas

    wait_for_condition(lambda: check("123", 2))

    serve.update_backend_config("counter", BackendConfig(num_replicas=3))
    wait_for_condition(lambda: check("123", 3))

    config = BackendConfig(user_config={"count": 456})
    serve.update_backend_config("counter", config)
    wait_for_condition(lambda: check("456", 3))
def test_http_proxy_failure(serve_instance):
    """The route recovers and picks up new traffic after the proxies are killed."""

    def function(_):
        return "hello1"

    serve.create_backend("proxy_failure:v1", function)
    serve.create_endpoint(
        "proxy_failure",
        backend="proxy_failure:v1",
        route="/proxy_failure",
    )

    first_reply = request_with_retries("/proxy_failure", timeout=1.0)
    assert first_reply.text == "hello1"

    for _attempt in range(10):
        reply = request_with_retries("/proxy_failure", timeout=30)
        assert reply.text == "hello1"

    _kill_http_proxies()

    # Redefine the handler so the second backend is distinguishable.
    def function(_):
        return "hello2"

    serve.create_backend("proxy_failure:v2", function)
    serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})

    def check_new():
        # True only when ten consecutive replies come from the v2 backend;
        # stops issuing requests at the first mismatch.
        return all(
            request_with_retries("/proxy_failure", timeout=30).text == "hello2"
            for _ in range(10)
        )

    wait_for_condition(check_new)
def test_batching(serve_instance):
    """With ``max_batch_size`` 5, twenty calls need fewer than twenty batches."""

    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            self.count += 1
            batch_size = serve.context.batch_size
            return [self.count] * batch_size

    serve.create_endpoint("counter1", "/increment2")

    # Keep checking the routing table until /increment2 is populated.
    while True:
        routes = requests.get("http://127.0.0.1:8000/-/routes").json()
        if "/increment2" in routes:
            break
        time.sleep(0.2)

    # Set the max batch size.
    serve.create_backend(
        "counter:v11",
        BatchingExample,
        config={"max_batch_size": 5},
    )
    serve.set_traffic("counter1", {"counter:v11": 1.0})

    handle = serve.get_handle("counter1")
    future_list = [handle.remote(temp=1) for _ in range(20)]
    counter_result = ray.get(future_list)

    # The count is bumped once per batch, not once per query. If at least one
    # __call__ received a batch larger than one, the highest count returned
    # must be below 20.
    assert max(counter_result) < 20
def test_updating_config(serve_instance):
    """Raising ``max_batch_size`` must leave the replica tags as they were."""

    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve.context.batch_size
            return [1] * batch_size

    serve.create_endpoint("bsimple", "/bsimple")
    serve.create_backend(
        "bsimple:v1",
        BatchSimple,
        config={"max_batch_size": 2, "num_replicas": 3},
    )

    master_actor = serve.api._get_master_actor()

    tags_before = set(ray.get(master_actor._list_replicas.remote("bsimple:v1")))

    serve.update_backend_config("bsimple:v1", {"max_batch_size": 5})

    tags_after = set(ray.get(master_actor._list_replicas.remote("bsimple:v1")))

    all_worker_handles = ray.get(master_actor.get_all_worker_handles.remote())
    every_tag = {
        tag
        for worker_dict in all_worker_handles.values()
        for tag in worker_dict.keys()
    }

    # The old and new replica tags should be identical, and both should be
    # contained in the set of all worker tags.
    assert tags_before <= every_tag
    assert tags_before == tags_after
def test_router_failure(serve_instance):
    """The route keeps serving, and accepts new traffic, after ``_kill_router()``."""
    serve.init()
    serve.create_endpoint("router_failure", "/router_failure")

    def expect_ten_replies(expected_text):
        # Ten sequential requests, each of which must return `expected_text`.
        for _ in range(10):
            reply = request_with_retries("/router_failure", timeout=30)
            assert reply.text == expected_text

    def function():
        return "hello1"

    serve.create_backend(function, "router_failure:v1")
    serve.set_traffic("router_failure", {"router_failure:v1": 1.0})

    first_reply = request_with_retries("/router_failure", timeout=5)
    assert first_reply.text == "hello1"

    expect_ten_replies("hello1")

    _kill_router()

    expect_ten_replies("hello1")

    # Redefine the handler so the second backend is distinguishable.
    def function():
        return "hello2"

    serve.create_backend(function, "router_failure:v2")
    serve.set_traffic("router_failure", {"router_failure:v2": 1.0})

    expect_ten_replies("hello2")
def test_app_level_batching(serve_instance):
    """``@serve.batch`` groups calls so twenty requests need under twenty batches."""

    class BatchingExample:
        def __init__(self):
            self.count = 0

        # The batch size limit is declared here, on the decorator.
        @serve.batch(max_batch_size=5, batch_wait_timeout_s=1)
        async def handle_batch(self, requests):
            self.count += 1
            batch_size = len(requests)
            return [self.count] * batch_size

        async def __call__(self, request):
            return await self.handle_batch(request)

    serve.create_backend("counter:v11", BatchingExample)
    serve.create_endpoint(
        "counter1",
        backend="counter:v11",
        route="/increment2",
    )

    handle = serve.get_handle("counter1")
    future_list = [handle.remote(temp=1) for _ in range(20)]
    counter_result = ray.get(future_list)

    # The count is bumped once per batch, not once per query. If at least one
    # handle_batch call received more than one request, the highest count
    # returned must be below 20.
    assert max(counter_result) < 20
def test_np_in_composed_model(serve_instance):
    """A backend can pass a NumPy array to another backend through a handle."""
    # Regression test for https://github.com/ray-project/ray/issues/9441
    # AttributeError: 'bytes' object has no attribute 'readonly'
    # in cloudpickle _from_numpy_buffer

    def sum_model(_request, data=None):
        return np.sum(data)

    class ComposedModel:
        def __init__(self):
            self.model = serve.get_handle("sum_model")

        async def __call__(self, _request):
            data = np.ones((10, 10))
            result = await self.model.remote(data=data)
            return result

    serve.create_backend("sum_model", sum_model)
    serve.create_endpoint("sum_model", backend="sum_model")

    serve.create_backend("model", ComposedModel)
    serve.create_endpoint(
        "model",
        backend="model",
        route="/model",
        methods=["GET"],
    )

    http_reply = requests.get("http://127.0.0.1:8000/model")
    assert http_reply.status_code == 200
    # A 10x10 array of ones sums to 100.
    assert http_reply.json() == 100.0
def test_e2e(serve_instance):
    """Create an endpoint, wait for its route, then exercise GET and POST.

    The route table at ``/-/routes`` is polled with exponential backoff
    (starting at 0.5 s, doubling after each failure) for up to five tries.

    Raises:
        AssertionError: if the route table never matches the expected
            content, or if the backend reports the wrong HTTP method.
    """
    serve.init()
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])

    max_tries = 5
    timeout_sleep = 0.5
    for attempt in range(1, max_tries + 1):
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            # Deliberately broad: connection errors, bad JSON and a stale
            # route table are all retried the same way.
            if attempt == max_tries:
                # Explicit raise (not `assert False`) so the failure survives
                # `python -O`; the message reports the real number of tries.
                raise AssertionError(
                    "Route table hasn't been updated after {} tries. "
                    "The latest error was {}".format(max_tries, e)) from e
            time.sleep(timeout_sleep)
            timeout_sleep *= 2

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend("echo:v1", function)
    serve.set_traffic("endpoint", {"echo:v1": 1.0})

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
def test_shadow_traffic(serve_instance):
    """Shadow backends receive traffic according to their shadow proportions.

    Each backend records its own invocations in a shared counter actor;
    only "backend1" is the endpoint's real backend, so every HTTP reply
    must be "hello".
    """

    @ray.remote
    class RequestCounter:
        def __init__(self):
            self.requests = defaultdict(int)

        def record(self, backend):
            self.requests[backend] += 1

        def get(self, backend):
            return self.requests[backend]

    counter = RequestCounter.remote()

    def f():
        ray.get(counter.record.remote("backend1"))
        return "hello"

    def f_shadow_1():
        ray.get(counter.record.remote("backend2"))
        return "oops"

    def f_shadow_2():
        ray.get(counter.record.remote("backend3"))
        return "oops"

    def f_shadow_3():
        ray.get(counter.record.remote("backend4"))
        return "oops"

    serve.create_backend("backend1", f)
    serve.create_backend("backend2", f_shadow_1)
    serve.create_backend("backend3", f_shadow_2)
    serve.create_backend("backend4", f_shadow_3)

    serve.create_endpoint("endpoint", backend="backend1", route="/api")

    serve.shadow_traffic("endpoint", "backend2", 1.0)
    serve.shadow_traffic("endpoint", "backend3", 0.5)
    serve.shadow_traffic("endpoint", "backend4", 0.1)

    num_requests = 100
    start = time.time()
    for _ in range(num_requests):
        assert requests.get("http://127.0.0.1:8000/api").text == "hello"
    elapsed = time.time() - start
    print("Finished 100 requests in {}s.".format(elapsed))

    def requests_to_backend(backend):
        return ray.get(counter.get.remote(backend))

    def check_requests():
        # Every comparison queries the counter afresh and all of them are
        # evaluated before the verdict is combined.
        primary_got_all = requests_to_backend("backend1") == num_requests
        full_shadow_matches = (
            requests_to_backend("backend2") == requests_to_backend("backend1"))
        half_below_full = (
            requests_to_backend("backend3") < requests_to_backend("backend2"))
        tenth_below_half = (
            requests_to_backend("backend4") < requests_to_backend("backend3"))
        tenth_got_some = requests_to_backend("backend4") > 0
        return all((
            primary_got_all,
            full_shadow_matches,
            half_below_full,
            tenth_below_half,
            tenth_got_some,
        ))

    wait_for_condition(check_requests)
def test_worker_replica_failure(serve_instance):
    """One replica keeps serving while the other is stuck restarting."""

    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0

        def inc_and_get(self):
            self.count += 1
            return self.count

    class Worker:
        # Assumes that two replicas are started. Any worker constructed after
        # the first two spins forever in its constructor.
        def __init__(self, counter):
            self.should_hang = False
            self.index = ray.get(counter.inc_and_get.remote())
            if self.index > 2:
                while True:
                    pass

        def __call__(self, *args):
            return self.index

    counter = Counter.remote()
    serve.create_backend("replica_failure", Worker, counter)
    serve.update_backend_config(
        "replica_failure", BackendConfig(num_replicas=2))
    serve.create_endpoint(
        "replica_failure",
        backend="replica_failure",
        route="/replica_failure",
    )

    # Wait until both replicas have been started.
    seen_indices = set()
    began = time.time()
    while time.time() - began < 30:
        time.sleep(0.1)
        reply_text = request_with_retries("/replica_failure", timeout=1).text
        assert reply_text in ["1", "2"]
        seen_indices.add(reply_text)
        if len(seen_indices) > 1:
            break
    else:
        raise TimeoutError("Timed out waiting for replicas after 30s.")

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(handles[0], no_restart=False)

    def request_until_answered():
        # The timeout needs to be small here because a request routed to
        # the restarting worker will hang.
        while True:
            try:
                request_with_retries("/replica_failure", timeout=0.1)
            except TimeoutError:
                time.sleep(0.1)
            else:
                return

    # Check that the other replica still serves requests.
    for _ in range(10):
        request_until_answered()
def test_worker_replica_failure(serve_instance):
    """One replica keeps serving while the other is stuck restarting.

    Workers count their own constructions in a shared temp file; the third
    construction (i.e. the restart after ``ray.kill``) hangs forever in
    ``__init__``. Each worker answers with its PID so that the test can
    tell when two distinct replicas are serving.

    Raises:
        TimeoutError: if two distinct replicas are not observed within 30 s.
    """
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                if num == 2:
                    # Leave the file untouched so the count stays readable.
                    self.should_hang = True
                else:
                    with open(path, "w") as f:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            # Return something replica-specific so callers can distinguish
            # the two replicas.
            return os.getpid()

    temp_path = os.path.join(tempfile.gettempdir(),
                             serve.utils.get_random_letters())
    serve.create_backend("replica_failure", Worker, temp_path)
    serve.update_backend_config("replica_failure",
                                BackendConfig(num_replicas=2))
    serve.create_endpoint(
        "replica_failure",
        backend="replica_failure",
        route="/replica_failure")

    # Wait until both replicas have been started. The previous condition
    # (`len(responses) == 1` on an empty set) never entered the loop.
    responses = set()
    deadline = time.time() + 30
    while len(responses) < 2:
        if time.time() > deadline:
            raise TimeoutError("Timed out waiting for replicas after 30s.")
        responses.add(
            request_with_retries("/replica_failure", timeout=1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(handles[0], no_restart=False)

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
def test_set_traffic_missing_data(serve_instance):
    """``set_traffic`` raises ValueError for an unknown backend or endpoint."""
    endpoint_name = "foobar"
    backend_name = "foo_backend"

    serve.create_backend(backend_name, lambda: 5)
    serve.create_endpoint(endpoint_name, backend=backend_name)

    # Known endpoint, unknown backend.
    with pytest.raises(ValueError):
        serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})

    # Unknown endpoint, known backend.
    with pytest.raises(ValueError):
        serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
def test_repeated_get_handle_cached(serve_instance):
    """A hundred ``get_handle`` calls for one endpoint collapse to one set entry."""

    def f(_):
        return ""

    serve.create_backend("m", f)
    serve.create_endpoint("m", backend="m")

    handle_sets = set()
    for _ in range(100):
        handle_sets.add(serve.get_handle("m"))

    assert len(handle_sets) == 1
def test_no_route(serve_instance):
    """An endpoint created without a route is still callable via its handle."""

    def func(_, i=1):
        return 1

    serve.create_backend("backend:1", func)
    serve.create_endpoint("noroute-endpoint", backend="backend:1")

    service_handle = serve.get_handle("noroute-endpoint")
    pending = service_handle.remote(i=1)
    assert ray.get(pending) == 1
def test_reject_duplicate_route(serve_instance):
    """Creating a second endpoint on an already-used route raises ValueError."""

    def f():
        pass

    serve.create_backend("backend", f)

    route = "/foo"
    serve.create_endpoint("bar", backend="backend", route=route)

    # Different endpoint name, same route.
    with pytest.raises(ValueError):
        serve.create_endpoint("foo", backend="backend", route=route)
def test_serve_metrics(serve_instance):
    """All expected metric names appear in the text served on port 9999.

    The metrics page is polled via ``wait_for_condition``; if that raises
    ``RuntimeError``, one last round is run with assertions enabled so the
    failure names the missing metric.

    Raises:
        AssertionError: if a metric is still missing, or the metrics page is
            unreachable, in the final round.
    """

    @serve.accept_batch
    def batcher(starlette_requests):
        return ["hello"] * len(starlette_requests)

    serve.create_backend("metrics", batcher)
    serve.create_endpoint("metrics", backend="metrics", route="/metrics")

    # send 10 concurrent requests
    url = "http://127.0.0.1:8000/metrics"
    ray.get([block_until_http_ready.remote(url) for _ in range(10)])

    def verify_metrics(do_assert=False):
        try:
            resp = requests.get("http://127.0.0.1:9999").text
        # Requests will fail if we are crashing the controller
        except requests.ConnectionError:
            return False

        expected_metrics = [
            # counter
            "num_router_requests_total",
            "num_http_requests_total",
            "backend_queued_queries_total",
            "backend_request_counter_requests_total",
            "backend_worker_starts_restarts_total",
            # histogram
            "backend_processing_latency_ms_bucket",
            "backend_processing_latency_ms_count",
            "backend_processing_latency_ms_sum",
            "backend_queuing_latency_ms_bucket",
            "backend_queuing_latency_ms_count",
            "backend_queuing_latency_ms_sum",
            # gauge
            "replica_processing_queries",
            "replica_queued_queries",
            # handle
            "serve_handle_request_counter",
            # ReplicaSet
            "backend_queued_queries",
        ]
        for metric in expected_metrics:
            # For the final error round
            if do_assert:
                assert metric in resp, (
                    "metric {!r} missing from metrics page".format(metric))
            # For the wait_for_condition
            elif metric not in resp:
                return False
        return True

    try:
        wait_for_condition(verify_metrics, retry_interval_ms=500)
    except RuntimeError:
        # Final round must actually assert; calling verify_metrics() with
        # the default do_assert=False would discard the failure.
        assert verify_metrics(do_assert=True), (
            "metrics page at http://127.0.0.1:9999 was unreachable")
async def test_system_metric_endpoints(serve_instance):
    """Counters for HTTP, router and backend errors show up at ``/-/metrics``."""

    def test_error_counter(flask_request):
        # Always raises ZeroDivisionError.
        1 / 0

    serve.create_backend("m:v1", test_error_counter)
    serve.create_endpoint("test_metrics", backend="m:v1", route="/measure")
    serve.set_traffic("test_metrics", {"m:v1": 1})

    # Check metrics are exposed under http endpoint
    def test_metric_endpoint():
        requests.get("http://127.0.0.1:8000/measure", timeout=5)
        in_memory_metric = requests.get(
            "http://127.0.0.1:8000/-/metrics", timeout=5).json()

        # We don't want to check the values since this check might be retried.
        for reported in in_memory_metric:
            reported.pop("value")
        in_memory_metric_without_values = list(in_memory_metric)

        http_counter = {
            "info": {
                "name": "num_http_requests",
                "type": "MetricType.COUNTER",
                "route": "/measure",
            },
        }
        router_counter = {
            "info": {
                "name": "num_router_requests",
                "type": "MetricType.COUNTER",
                "endpoint": "test_metrics",
            },
        }
        backend_error_counter = {
            "info": {
                "name": "backend_error_counter",
                "type": "MetricType.COUNTER",
                "backend": "m:v1",
            },
        }

        for target in (http_counter, router_counter, backend_error_counter):
            assert target in in_memory_metric_without_values

    for _ in range(3):
        try:
            test_metric_endpoint()
            break
        except (AssertionError, requests.ReadTimeout):
            # Metrics may not have been propagated yet
            time.sleep(2)
            print("Metric not correct, retrying...")
    else:
        # All three tries failed: run once more, unguarded, so the real
        # error propagates.
        test_metric_endpoint()
def test_reject_duplicate_endpoint(serve_instance):
    """Re-using an endpoint name, even with a new route, raises ValueError."""

    def f():
        pass

    serve.create_backend("backend", f)

    endpoint_name = "foo"
    serve.create_endpoint(endpoint_name, backend="backend", route="/ok")

    # Same endpoint name, different route.
    with pytest.raises(ValueError):
        serve.create_endpoint(
            endpoint_name,
            backend="backend",
            route="/different",
        )
def test_controller_failure(serve_instance):
    """Serving and API calls keep working while the controller is killed repeatedly."""

    def kill_controller():
        # Look the controller up on every call; no_restart=False lets it
        # be restarted.
        ray.kill(serve.api._global_client._controller, no_restart=False)

    def function(_):
        return "hello1"

    serve.create_backend("controller_failure:v1", function)
    serve.create_endpoint(
        "controller_failure",
        backend="controller_failure:v1",
        route="/controller_failure",
    )

    first_reply = request_with_retries("/controller_failure", timeout=1)
    assert first_reply.text == "hello1"

    for _ in range(10):
        reply = request_with_retries("/controller_failure", timeout=30)
        assert reply.text == "hello1"

    kill_controller()

    for _ in range(10):
        reply = request_with_retries("/controller_failure", timeout=30)
        assert reply.text == "hello1"

    def function(_):
        return "hello2"

    kill_controller()

    serve.create_backend("controller_failure:v2", function)
    serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})

    def check_controller_failure():
        reply = request_with_retries("/controller_failure", timeout=30)
        return reply.text == "hello2"

    wait_for_condition(check_controller_failure)

    def function(_):
        return "hello3"

    # Kill the controller around each API call of the second deployment.
    kill_controller()
    serve.create_backend("controller_failure_2", function)
    kill_controller()
    serve.create_endpoint(
        "controller_failure_2",
        backend="controller_failure_2",
        route="/controller_failure_2",
    )
    kill_controller()

    for _ in range(10):
        reply = request_with_retries("/controller_failure", timeout=30)
        assert reply.text == "hello2"
        reply = request_with_retries("/controller_failure_2", timeout=30)
        assert reply.text == "hello3"
def test_controller_failure(serve_instance):
    """Serving and API calls keep working while the controller is killed repeatedly."""
    serve.init()

    def kill_controller():
        # Fetch the controller on every call; no_restart=False lets it be
        # restarted.
        ray.kill(serve.api._get_controller(), no_restart=False)

    def expect_ten_replies(expected_text):
        for _ in range(10):
            reply = request_with_retries("/controller_failure", timeout=30)
            assert reply.text == expected_text

    def function():
        return "hello1"

    serve.create_backend("controller_failure:v1", function)
    serve.create_endpoint(
        "controller_failure",
        backend="controller_failure:v1",
        route="/controller_failure",
    )

    first_reply = request_with_retries("/controller_failure", timeout=1)
    assert first_reply.text == "hello1"

    expect_ten_replies("hello1")

    kill_controller()

    expect_ten_replies("hello1")

    def function():
        return "hello2"

    kill_controller()

    serve.create_backend("controller_failure:v2", function)
    serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})

    expect_ten_replies("hello2")

    def function():
        return "hello3"

    # Kill the controller around each API call of the second deployment.
    kill_controller()
    serve.create_backend("controller_failure_2", function)
    kill_controller()
    serve.create_endpoint(
        "controller_failure_2",
        backend="controller_failure_2",
        route="/controller_failure_2",
    )
    kill_controller()

    for _ in range(10):
        reply = request_with_retries("/controller_failure", timeout=30)
        assert reply.text == "hello2"
        reply = request_with_retries("/controller_failure_2", timeout=30)
        assert reply.text == "hello3"
def test_no_route(serve_instance):
    """An endpoint created without a route is still callable via its handle."""
    serve.create_endpoint("noroute-endpoint")

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.set_traffic("noroute-endpoint", {"backend:1": 1.0})

    service_handle = serve.get_handle("noroute-endpoint")
    pending = service_handle.remote(i=1)
    assert ray.get(pending) == 1
def test_controller_inflight_requests_clear(serve_instance):
    """Creating a backend and endpoint leaves the pending-goal count unchanged."""
    controller = serve.api._global_client._controller

    def pending_goal_count():
        return ray.get(controller._num_pending_goals.remote())

    initial_number_reqs = pending_goal_count()

    def function(_):
        return "hello"

    serve.create_backend("tst", function)
    serve.create_endpoint("end_pt", backend="tst")

    goals_added = pending_goal_count() - initial_number_reqs
    assert goals_added == 0
def test_nonblocking():
    """A GET on a freshly wired endpoint reaches the backend."""
    serve.init()
    serve.create_endpoint("nonblocking", "/nonblocking")

    def function(flask_request):
        # Echo the HTTP method of the incoming request.
        return {"method": flask_request.method}

    serve.create_backend(function, "nonblocking:v1")
    serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})

    payload = requests.get("http://127.0.0.1:8000/nonblocking").json()
    assert payload["method"] == "GET"
def test_connect(detached, ray_shutdown):
    """A backend can itself call the serve API to register another backend."""
    # Exercised with whatever `detached` value the caller supplies.
    ray.init(num_cpus=16)
    serve.start(detached=detached)

    def connect_in_backend(_):
        # Runs inside the backend: registers a second backend through the
        # serve API.
        serve.create_backend("backend-ception", connect_in_backend)

    serve.create_backend("connect_in_backend", connect_in_backend)
    serve.create_endpoint("endpoint", backend="connect_in_backend")

    handle = serve.get_handle("endpoint")
    ray.get(handle.remote())

    registered_backends = serve.list_backends().keys()
    assert "backend-ception" in registered_backends