def test_worker_replica_failure(serve_instance):
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()
    serve.create_endpoint(
        "replica_failure", "/replica_failure", methods=["GET"])

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                with open(path, "w") as f:
                    if num == 2:
                        self.should_hang = True
                    else:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            pass

    temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters()
    serve.create_backend(Worker, "replica_failure", temp_path)
    backend_config = serve.get_backend_config("replica_failure")
    backend_config.num_replicas = 2
    serve.set_backend_config("replica_failure", backend_config)
    serve.set_traffic("replica_failure", {"replica_failure": 1.0})

    # Wait until both replicas have been started. (The original condition
    # `len(responses) == 1` could never be true on entry, so the loop was
    # dead code; `< 2` matches the stated intent.)
    responses = set()
    while len(responses) < 2:
        responses.add(
            request_with_retries("/replica_failure", timeout=0.1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(handles[0])

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
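# `request_with_retries` is used by the failure tests in this section but is
# not defined here. The following is a minimal sketch of what such a helper
# might look like, assuming a Serve HTTP proxy on 127.0.0.1:8000 and an
# optional callback that validates the response; it is an illustration, not
# the original helper.
import time

import requests


def request_with_retries(endpoint, callback=None, timeout=30):
    start = time.time()
    while True:
        try:
            response = requests.get(
                "http://127.0.0.1:8000" + endpoint, timeout=timeout)
            if callback is not None:
                callback(response)  # Raises AssertionError on bad responses.
            return response
        except requests.RequestException:
            if time.time() - start > timeout:
                raise TimeoutError(
                    "Request to {} failed for {}s".format(endpoint, timeout))
            time.sleep(0.1)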
def test_e2e(serve_instance):
    serve.init()  # so we have access to global state
    serve.create_endpoint("endpoint", "/api", blocking=True)
    result = serve.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json()
            assert resp == result
            break
        except Exception:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                assert False, "Route table hasn't been updated after 5 tries."

    def function(flask_request):
        return "OK"

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
    assert resp == "OK"
def test_list_endpoints(serve_instance):
    serve.init()

    def f():
        pass

    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])
    serve.create_endpoint("endpoint2", methods=["POST"])
    serve.create_backend("backend", f)
    serve.set_traffic("endpoint2", {"backend": 1.0})

    endpoints = serve.list_endpoints()
    assert "endpoint" in endpoints
    assert endpoints["endpoint"] == {
        "route": "/api",
        "methods": ["GET", "POST"],
        "traffic": {}
    }

    assert "endpoint2" in endpoints
    assert endpoints["endpoint2"] == {
        "route": None,
        "methods": ["POST"],
        "traffic": {
            "backend": 1.0
        }
    }

    serve.delete_endpoint("endpoint")
    assert "endpoint2" in serve.list_endpoints()

    serve.delete_endpoint("endpoint2")
    assert len(serve.list_endpoints()) == 0
def __init__(self, func_to_run):
    serve.init()
    self.func = func_to_run
    # This attribute lets argument inspection work with the inner function.
    self.__wrapped__ = func_to_run
def serve_instance():
    _, new_db_path = tempfile.mkstemp(suffix=".test.db")
    serve.init(
        kv_store_path=new_db_path,
        blocking=True,
        ray_init_kwargs={"num_cpus": 36})
    yield
    os.remove(new_db_path)
def backend_setup(tag: str, worker_args: Tuple, replicas: int,
                  max_batch_size: int) -> None:
    """
    Sets up the backend for the distributed explanation task.

    Parameters
    ----------
    tag
        A tag for the backend component. The same tag must be passed to
        `endpoint_setup`.
    worker_args
        A tuple containing the arguments for initialising the explainer and
        fitting it.
    replicas
        The number of backend replicas that serve explanations.
    max_batch_size
        Maximum number of requests to batch and send to a worker process.
    """

    serve.init()
    if max_batch_size == 1:
        config = {'num_replicas': max(replicas, 1)}
        serve.create_backend(tag, KernelShapModel, *worker_args)
    else:
        config = {
            'num_replicas': max(replicas, 1),
            'max_batch_size': max_batch_size
        }
        serve.create_backend(tag, BatchKernelShapModel, *worker_args)
    serve.update_backend_config(tag, config)
    logging.info(f"Backends: {serve.list_backends()}")
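# The docstring above refers to a matching `endpoint_setup`, which does not
# appear in this section. A plausible minimal sketch, assuming the endpoint
# simply exposes the backend created by `backend_setup` over HTTP and accepts
# POSTed explanation requests; the signature and methods are assumptions.
def endpoint_setup(endpoint_tag: str, backend_tag: str, route: str) -> None:
    """Creates an endpoint that routes all traffic to `backend_tag`."""
    serve.init()
    serve.create_endpoint(
        endpoint_tag, backend=backend_tag, route=route, methods=['POST'])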
def test_http_proxy_failure(serve_instance):
    serve.init()
    serve.create_endpoint(
        "failure_endpoint", "/failure_endpoint", methods=["GET"])

    def function(flask_request):
        return "hello1"

    serve.create_backend(function, "failure:v1")
    serve.link("failure_endpoint", "failure:v1")

    def verify_response(response):
        assert response.text == "hello1"

    request_with_retries("/failure_endpoint", verify_response, timeout=0)

    _kill_http_proxy()
    request_with_retries("/failure_endpoint", verify_response, timeout=30)

    _kill_http_proxy()

    def function(flask_request):
        return "hello2"

    serve.create_backend(function, "failure:v2")
    serve.link("failure_endpoint", "failure:v2")

    def verify_response(response):
        assert response.text == "hello2"

    request_with_retries("/failure_endpoint", verify_response, timeout=30)
def test_worker_restart(serve_instance):
    serve.init()
    serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend(Worker1, "worker_failure:v1")
    serve.link("worker_failure", "worker_failure:v1")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=0.1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(list(handles.values())[0])

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        assert False, "Timed out waiting for worker to die."
def test_http_proxy_failure(serve_instance):
    serve.init()

    def function():
        return "hello1"

    serve.create_backend("proxy_failure:v1", function)
    serve.create_endpoint(
        "proxy_failure", backend="proxy_failure:v1", route="/proxy_failure")

    assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello1"

    _kill_routers()

    def function():
        return "hello2"

    serve.create_backend("proxy_failure:v2", function)
    serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello2"
def test_e2e(serve_instance):
    serve.init()

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend("echo:v1", function)
    serve.create_endpoint(
        "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])

    retry_count = 5
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                assert False, ("Route table hasn't been updated after 5 "
                               "tries. The latest error was {}").format(e)

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
def test_create_infeasible_error(serve_instance):
    serve.init()

    def f():
        pass

    # A nonexistent resource should be infeasible.
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        serve.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "MagicMLResource": 100
            }})

    # Even if each replica is feasible, the total might not be.
    current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        serve.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "CPU": 1,
            }},
            config={"num_replicas": current_cpus + 20})

    # No replica should be created! (The tag must match the backend created
    # above, i.e. "f:1", not "f1".)
    replicas = ray.get(serve.api.master_actor._list_replicas.remote("f:1"))
    assert len(replicas) == 0
def serve_new_model(model_dir, checkpoint, config, metrics, day, gpu=False):
    print("Serving checkpoint: {}".format(checkpoint))

    checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint,
                                                    config, metrics)

    serve.init()
    backend_name = "mnist:day_{}".format(day)

    serve.create_backend(backend_name, MNISTBackend, checkpoint_path, config,
                         metrics, gpu)

    if "mnist" not in serve.list_endpoints():
        # First time we serve a model - create an endpoint.
        serve.create_endpoint(
            "mnist", backend=backend_name, route="/mnist", methods=["POST"])
    else:
        # The endpoint already exists, so route all traffic to the new model.
        # Here you could also implement an incremental rollout, where only
        # a part of the traffic is sent to the new backend and the
        # rest is sent to the existing backends.
        serve.set_traffic("mnist", {backend_name: 1.0})

    # Delete previously existing backends.
    for existing_backend in serve.list_backends():
        if existing_backend.startswith("mnist:day") and \
                existing_backend != backend_name:
            serve.delete_backend(existing_backend)

    return True
def test_list_backends(serve_instance):
    serve.init()

    @serve.accept_batch
    def f():
        pass

    serve.create_backend("backend", f, config={"max_batch_size": 10})
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend" in backends
    assert backends["backend"]["max_batch_size"] == 10

    serve.create_backend("backend2", f, config={"num_replicas": 10})
    backends = serve.list_backends()
    assert len(backends) == 2
    assert backends["backend2"]["num_replicas"] == 10

    serve.delete_backend("backend")
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend2" in backends

    serve.delete_backend("backend2")
    assert len(serve.list_backends()) == 0
def test_middleware():
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware

    port = new_port()
    serve.init(
        http_port=port,
        http_middlewares=[
            Middleware(
                CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
        ])
    ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes"))

    # Snatched several test cases from Starlette:
    # https://github.com/encode/starlette/blob/master/tests/middleware/test_cors.py
    headers = {
        "Origin": "https://example.org",
        "Access-Control-Request-Method": "GET",
    }
    root = f"http://localhost:{port}"
    resp = requests.options(root, headers=headers)
    assert resp.headers["access-control-allow-origin"] == "*"

    resp = requests.get(f"{root}/-/routes", headers=headers)
    assert resp.headers["access-control-allow-origin"] == "*"
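# `new_port` and `block_until_http_ready` come from the surrounding test
# utilities and are not shown here. A rough sketch under the assumption that
# `new_port` asks the OS for a free ephemeral port and `block_until_http_ready`
# is a Ray task that polls a URL until it responds.
import socket
import time

import ray
import requests


def new_port() -> int:
    # Binding to port 0 lets the OS pick a free port; return it and let the
    # socket close so the caller can bind it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


@ray.remote
def block_until_http_ready(http_endpoint, timeout=30):
    start = time.time()
    while time.time() - start < timeout:
        try:
            requests.get(http_endpoint)
            return
        except requests.RequestException:
            time.sleep(0.1)
    raise TimeoutError("{} not ready after {}s".format(http_endpoint, timeout))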
def test_shutdown(serve_instance):
    def f():
        pass

    instance_name = "shutdown"
    serve.init(name=instance_name, http_port=8002)
    serve.create_backend("backend", f)
    serve.create_endpoint("endpoint", backend="backend")

    serve.shutdown()
    with pytest.raises(RayServeException, match="Please run serve.init"):
        serve.list_backends()

    def check_dead():
        for actor_name in [
                constants.SERVE_MASTER_NAME, constants.SERVE_PROXY_NAME,
                constants.SERVE_ROUTER_NAME, constants.SERVE_METRIC_SINK_NAME
        ]:
            try:
                ray.get_actor(format_actor_name(actor_name, instance_name))
                return False
            except ValueError:
                pass
        return True

    assert wait_for_condition(check_dead)
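# `wait_for_condition` is assumed here to poll a predicate until it returns
# True or a timeout expires, returning a bool so it composes with `assert`.
# A minimal sketch under that assumption.
import time


def wait_for_condition(predicate, timeout=10, retry_interval_ms=100):
    start = time.time()
    while time.time() - start <= timeout:
        if predicate():
            return True
        time.sleep(retry_interval_ms / 1000.0)
    return False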
def test_http_proxy_failure(serve_instance):
    serve.init()
    serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])

    def function():
        return "hello1"

    serve.create_backend(function, "proxy_failure:v1")
    serve.link("proxy_failure", "proxy_failure:v1")

    assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello1"

    _kill_http_proxy()

    def function():
        return "hello2"

    serve.create_backend(function, "proxy_failure:v2")
    serve.link("proxy_failure", "proxy_failure:v2")

    for _ in range(10):
        response = request_with_retries("/proxy_failure", timeout=30)
        assert response.text == "hello2"
def test_router_failure(serve_instance):
    serve.init()
    serve.create_endpoint("router_failure", "/router_failure")

    def function():
        return "hello1"

    serve.create_backend("router_failure:v1", function)
    serve.set_traffic("router_failure", {"router_failure:v1": 1.0})

    assert request_with_retries("/router_failure", timeout=5).text == "hello1"

    for _ in range(10):
        response = request_with_retries("/router_failure", timeout=30)
        assert response.text == "hello1"

    _kill_router()

    for _ in range(10):
        response = request_with_retries("/router_failure", timeout=30)
        assert response.text == "hello1"

    def function():
        return "hello2"

    serve.create_backend("router_failure:v2", function)
    serve.set_traffic("router_failure", {"router_failure:v2": 1.0})

    for _ in range(10):
        response = request_with_retries("/router_failure", timeout=30)
        assert response.text == "hello2"
def test_handle_in_endpoint(serve_instance):
    serve.init()

    class Endpoint1:
        def __call__(self, flask_request):
            return "hello"

    class Endpoint2:
        def __init__(self):
            self.handle = serve.get_handle("endpoint1", missing_ok=True)

        def __call__(self):
            return ray.get(self.handle.remote())

    serve.create_backend("endpoint1:v0", Endpoint1)
    serve.create_endpoint(
        "endpoint1",
        backend="endpoint1:v0",
        route="/endpoint1",
        methods=["GET", "POST"])

    serve.create_backend("endpoint2:v0", Endpoint2)
    serve.create_endpoint(
        "endpoint2",
        backend="endpoint2:v0",
        route="/endpoint2",
        methods=["GET", "POST"])

    assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
async def main():
    ray.init(log_to_driver=False)
    serve.init()
    serve.create_backend("backend", backend)
    serve.create_endpoint("endpoint", backend="backend", route="/api")

    actors = [Client.remote() for _ in range(NUM_CLIENTS)]

    for num_replicas in [1, 8]:
        for backend_config in [
            {
                "max_batch_size": 1,
                "max_concurrent_queries": 1
            },
            {
                "max_batch_size": 1,
                "max_concurrent_queries": 10000
            },
            {
                "max_batch_size": 10000,
                "max_concurrent_queries": 10000
            },
        ]:
            backend_config["num_replicas"] = num_replicas
            serve.update_backend_config("backend", backend_config)
            print(repr(backend_config) + ":")
            async with aiohttp.ClientSession() as session:
                # TODO(edoakes): large data causes broken pipe errors.
                for data_size in ["small"]:
                    await trial(actors, session, data_size)
def test_worker_restart(serve_instance):
    serve.init()

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend("worker_failure:v1", Worker1)
    serve.create_endpoint(
        "worker_failure", backend="worker_failure:v1", route="/worker_failure")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(handles[0], no_restart=False)

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        assert False, "Timed out waiting for worker to die."
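# `_get_worker_handles` is not defined in this section; the restart tests use
# it to fetch the actor handles of a backend's replicas so they can be killed.
# A rough sketch, assuming the controller exposes a method that returns
# replica handles keyed by backend tag; the method name is an assumption, and
# the snippets above disagree on whether the result is a list or a dict, so
# the exact return type varies across versions.
def _get_worker_handles(backend_tag):
    controller = serve.api._get_controller()
    worker_handles = ray.get(controller.get_all_worker_handles.remote())
    return list(worker_handles[backend_tag].values())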
def main():
    if not os.path.exists('results'):
        os.mkdir('results')

    data = load_data()
    X_explain = data['all']['X']['processed']['test'].toarray()
    max_batch_size = [int(elem) for elem in args.max_batch_size][0]
    batch_mode, replicas = args.batch_mode, args.replicas

    ray.init(address='auto')  # connect to the cluster
    # Listen on 0.0.0.0 to make the endpoint accessible from other machines.
    serve.init(http_host='0.0.0.0')

    host, route = os.environ.get("RAY_HEAD_SERVICE_HOST", args.host), "explain"
    url = f"http://{host}:{args.port}/{route}"
    backend_tag = "kernel_shap:b100"  # b100 means 100 background samples
    endpoint_tag = f"{backend_tag}_endpoint"
    worker_args = prepare_explainer_args(data)

    if batch_mode == 'ray':
        backend_setup(backend_tag, worker_args, replicas, max_batch_size)
        logging.info(f"Batching with max_batch_size of {max_batch_size} ...")
    else:  # minibatches are sent to the ray worker
        backend_setup(backend_tag, worker_args, replicas, 1)
        logging.info(f"Distributing minibatches of size {max_batch_size} ...")
    endpoint_setup(endpoint_tag, backend_tag, route=f"/{route}")

    run_explainer(
        X_explain,
        args.n_runs,
        replicas,
        max_batch_size,
        batch_mode=batch_mode,
        url=url)
def __init__(self, backend_tag, replica_tag, init_args):
    serve.init()
    # `func_or_class` and `is_function` are captured from the enclosing scope.
    if is_function:
        _callable = func_or_class
    else:
        _callable = func_or_class(*init_args)

    self.backend = RayServeWorker(backend_tag, _callable, is_function)
def scale(backend_tag, num_replicas):
    if num_replicas <= 0:
        # The original constructed click.Abort without raising it, so the
        # check was a no-op; raise an exception that displays the message.
        raise click.ClickException(
            "Cannot set the number of replicas to be smaller than or "
            "equal to 0.")
    ray.init(address="auto")
    serve.init()
    serve.scale(backend_tag, num_replicas)
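# A hypothetical example of how `scale` might be wired up as a click
# subcommand; the command and option names are illustrative assumptions, not
# taken from the source.
import click


@click.command()
@click.option("--backend-tag", required=True, type=str)
@click.option("--num-replicas", required=True, type=int)
def scale_command(backend_tag, num_replicas):
    scale(backend_tag, num_replicas)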
def serve_instance(_shared_serve_instance):
    serve.init()
    yield
    master = serve.api._get_master_actor()
    # Clear all state between tests to avoid naming collisions.
    for endpoint in retry_actor_failures(master.get_all_endpoints):
        serve.delete_endpoint(endpoint)
    for backend in retry_actor_failures(master.get_all_backends):
        serve.delete_backend(backend)
async def __init__(self, host, port, instance_name=None):
    serve.init(name=instance_name)
    self.app = HTTPProxy()
    await self.app.fetch_config_from_controller(instance_name)
    self.host = host
    self.port = port

    # Start running the HTTP server on the event loop.
    asyncio.get_event_loop().create_task(self.run())
def test_controller_failure(serve_instance):
    serve.init()

    def function():
        return "hello1"

    serve.create_backend("controller_failure:v1", function)
    serve.create_endpoint(
        "controller_failure",
        backend="controller_failure:v1",
        route="/controller_failure")

    assert request_with_retries(
        "/controller_failure", timeout=1).text == "hello1"

    for _ in range(10):
        response = request_with_retries("/controller_failure", timeout=30)
        assert response.text == "hello1"

    ray.kill(serve.api._get_controller(), no_restart=False)

    for _ in range(10):
        response = request_with_retries("/controller_failure", timeout=30)
        assert response.text == "hello1"

    def function():
        return "hello2"

    ray.kill(serve.api._get_controller(), no_restart=False)

    serve.create_backend("controller_failure:v2", function)
    serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})

    for _ in range(10):
        response = request_with_retries("/controller_failure", timeout=30)
        assert response.text == "hello2"

    def function():
        return "hello3"

    ray.kill(serve.api._get_controller(), no_restart=False)
    serve.create_backend("controller_failure_2", function)
    ray.kill(serve.api._get_controller(), no_restart=False)
    serve.create_endpoint(
        "controller_failure_2",
        backend="controller_failure_2",
        route="/controller_failure_2")
    ray.kill(serve.api._get_controller(), no_restart=False)

    for _ in range(10):
        response = request_with_retries("/controller_failure", timeout=30)
        assert response.text == "hello2"
        response = request_with_retries("/controller_failure_2", timeout=30)
        assert response.text == "hello3"
def serve_instance(_shared_serve_instance):
    serve.init()
    yield
    # Re-init if necessary.
    serve.init()
    controller = serve.api._get_controller()
    # Clear all state between tests to avoid naming collisions.
    for endpoint in ray.get(controller.get_all_endpoints.remote()):
        serve.delete_endpoint(endpoint)
    for backend in ray.get(controller.get_all_backends.remote()):
        serve.delete_backend(backend)
def test_nonblocking():
    serve.init()
    serve.create_endpoint("nonblocking", "/nonblocking")

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend(function, "nonblocking:v1")
    serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})

    resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
    assert resp == "GET"
def test_nonblocking():
    serve.init()

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend("nonblocking:v1", function)
    serve.create_endpoint(
        "nonblocking", backend="nonblocking:v1", route="/nonblocking")

    resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
    assert resp == "GET"
def test_endpoint_input_validation(serve_instance):
    serve.init()

    def f():
        pass

    serve.create_backend("backend", f)
    with pytest.raises(TypeError):
        serve.create_endpoint("endpoint")
    with pytest.raises(TypeError):
        serve.create_endpoint("endpoint", route="/hello")
    with pytest.raises(TypeError):
        serve.create_endpoint("endpoint", backend=2)
    serve.create_endpoint("endpoint", backend="backend")