예제 #1
0
def test_http_proxy_failure(serve_instance):
    """Check the endpoint keeps answering across an HTTP proxy kill.

    A "hello1" backend is served first. After ``_kill_http_proxy()`` the
    endpoint's traffic is moved to a "hello2" backend, and every later
    request must return the new reply.
    """
    serve.init()

    def reply_v1():
        return "hello1"

    def reply_v2():
        return "hello2"

    serve.create_backend("proxy_failure:v1", reply_v1)
    serve.create_endpoint(
        "proxy_failure", backend="proxy_failure:v1", route="/proxy_failure")

    warmup = request_with_retries("/proxy_failure", timeout=1.0)
    assert warmup.text == "hello1"

    for _ in range(10):
        assert request_with_retries(
            "/proxy_failure", timeout=30).text == "hello1"

    _kill_http_proxy()

    serve.create_backend("proxy_failure:v2", reply_v2)
    serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})

    for _ in range(10):
        assert request_with_retries(
            "/proxy_failure", timeout=30).text == "hello2"
예제 #2
0
def test_shutdown(serve_instance):
    """serve.shutdown() makes the API unusable and removes system actors.

    After shutdown, ``serve.list_backends()`` must raise, and eventually
    none of the controller/proxy/metric-sink actors of the named instance
    can be looked up.
    """
    def noop():
        pass

    name = "shutdown"
    serve.init(name=name, http_port=8003)
    serve.create_backend("backend", noop)
    serve.create_endpoint("endpoint", backend="backend")

    serve.shutdown()
    # The API should refuse to talk to the torn-down instance.
    with pytest.raises(RayServeException, match="Please run serve.init"):
        serve.list_backends()

    def actor_exists(actor_name):
        try:
            ray.get_actor(format_actor_name(actor_name, name))
        except ValueError:
            return False
        return True

    def check_dead():
        system_actors = (constants.SERVE_CONTROLLER_NAME,
                         constants.SERVE_PROXY_NAME,
                         constants.SERVE_METRIC_SINK_NAME)
        return not any(actor_exists(a) for a in system_actors)

    wait_for_condition(check_dead)
예제 #3
0
def test_worker_restart(serve_instance):
    """A killed backend worker is restarted and serves from a new process.

    The backend replies with its own PID, so a reply that differs from the
    first one proves the request was handled by a fresh worker.

    Raises:
        TimeoutError: if no reply from a new PID arrives within 30s.
    """
    serve.init()

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend("worker_failure:v1", Worker1)
    serve.create_endpoint(
        "worker_failure", backend="worker_failure:v1", route="/worker_failure")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=1).text

    # Kill the worker; no_restart=False lets Ray bring it back up.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(handles[0], no_restart=False)

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
        time.sleep(0.1)
    else:
        # Raise explicitly: `assert False` is stripped under `python -O`,
        # which would let the timeout pass silently.
        raise TimeoutError("Timed out waiting for worker to die.")
예제 #4
0
def test_updating_config(serve_instance):
    """Updating max_concurrent_queries must not restart the replicas.

    The replica tags of "bsimple:v1" are compared before and after the
    config update. They must be identical and must still appear in the
    controller's full replica table.
    """
    class BatchSimple:
        def __init__(self):
            self.count = 0

        def __call__(self, request):
            return 1

    config = BackendConfig(max_concurrent_queries=2, num_replicas=3)
    serve.create_backend("bsimple:v1", BatchSimple, config=config)
    serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")

    controller = serve.api._global_client._controller
    old_replica_tag_list = list(
        ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())

    update_config = BackendConfig(max_concurrent_queries=5)
    serve.update_backend_config("bsimple:v1", update_config)

    # Read the replica table once, so both lists below come from the same
    # snapshot and only one controller round-trip is needed.
    all_replica_handles = ray.get(controller._all_replica_handles.remote())
    new_replica_tag_list = list(all_replica_handles["bsimple:v1"].keys())
    new_all_tag_list = [
        tag for worker_dict in all_replica_handles.values()
        for tag in worker_dict
    ]

    # The old and new replica tag lists should be identical
    # and should be a subset of all tags.
    assert set(old_replica_tag_list) <= set(new_all_tag_list)
    assert set(old_replica_tag_list) == set(new_replica_tag_list)
예제 #5
0
def test_scaling_replicas(serve_instance):
    """Scaling a backend from 2 replicas down to 1 concentrates the load.

    Each replica keeps its own request counter. With two replicas, no
    single counter reaches 10 over 10 requests. With one replica, the
    returned counts should span most of a 10-request batch.

    Raises:
        TimeoutError: if ``/increment`` never appears in the route table.
    """
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
    serve.create_endpoint("counter", backend="counter:v1", route="/increment")

    # Keep checking the routing table until /increment is populated, but
    # fail instead of hanging the whole test run if it never shows up.
    deadline = time.time() + 30
    while "/increment" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        if time.time() > deadline:
            raise TimeoutError(
                "Route /increment was not populated after 30s.")
        time.sleep(0.2)

    def send_requests(num=10):
        return [
            requests.get("http://127.0.0.1:8000/increment").json()
            for _ in range(num)
        ]

    # If the load is shared among two replicas, the max result cannot be 10.
    assert max(send_requests()) < 10

    serve.update_backend_config("counter:v1", {"num_replicas": 1})

    counter_result = send_requests()
    # A request may still hit the replica that is spinning down, but the
    # majority should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
예제 #6
0
def test_scaling_replicas(serve_instance):
    """Shrinking a backend from two replicas to one via BackendConfig.

    Each replica counts its own requests. With two replicas, no single
    counter should reach 10 over 10 requests. After scaling down to one
    replica, the returned counts should cover most of a 10-request batch.
    """
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve.create_backend(
        "counter:v1", Counter, config=BackendConfig(num_replicas=2))

    serve.create_endpoint("counter", backend="counter:v1", route="/increment")

    def hit_increment():
        results = []
        for _ in range(10):
            results.append(
                requests.get("http://127.0.0.1:8000/increment").json())
        return results

    # Load is spread over two replicas, so no counter can reach 10.
    two_replica_counts = hit_increment()
    assert max(two_replica_counts) < 10

    serve.update_backend_config("counter:v1", BackendConfig(num_replicas=1))

    one_replica_counts = hit_increment()
    # The replica being removed may still answer a few requests, but most
    # should come from the single remaining replica.
    spread = max(one_replica_counts) - min(one_replica_counts)
    assert spread > 6
예제 #7
0
def test_backend_user_config(serve_instance):
    """user_config reaches replicas via reconfigure() and survives scaling.

    Each reply is (count, pid). A check passes when every reply over 100
    calls carries the expected count and the number of distinct PIDs seen
    equals the expected replica count.
    """
    class Counter:
        def __init__(self):
            self.count = 10

        def __call__(self, starlette_request):
            return self.count, os.getpid()

        def reconfigure(self, config):
            self.count = config["count"]

    initial = BackendConfig(
        num_replicas=2, user_config={"count": 123, "b": 2})
    serve.create_backend("counter", Counter, config=initial)
    serve.create_endpoint("counter", backend="counter")
    handle = serve.get_handle("counter")

    def check(val, num_replicas):
        pids = set()
        for _ in range(100):
            count, pid = ray.get(handle.remote())
            if str(count) != val:
                return False
            pids.add(pid)
        return len(pids) == num_replicas

    wait_for_condition(lambda: check("123", 2))

    # Scaling up must keep the previously applied user_config value.
    serve.update_backend_config("counter", BackendConfig(num_replicas=3))
    wait_for_condition(lambda: check("123", 3))

    # Updating only user_config should reconfigure all three replicas.
    serve.update_backend_config(
        "counter", BackendConfig(user_config={"count": 456}))
    wait_for_condition(lambda: check("456", 3))
예제 #8
0
def test_http_proxy_failure(serve_instance):
    """The endpoint keeps working after all HTTP proxies are killed.

    After the kill, traffic is moved to a new backend. The check is retried
    with wait_for_condition until ten consecutive replies are "hello2".
    """
    def hello_v1(_):
        return "hello1"

    def hello_v2(_):
        return "hello2"

    serve.create_backend("proxy_failure:v1", hello_v1)
    serve.create_endpoint("proxy_failure",
                          backend="proxy_failure:v1",
                          route="/proxy_failure")

    initial = request_with_retries("/proxy_failure", timeout=1.0).text
    assert initial == "hello1"

    for _ in range(10):
        assert request_with_retries("/proxy_failure",
                                    timeout=30).text == "hello1"

    _kill_http_proxies()

    serve.create_backend("proxy_failure:v2", hello_v2)
    serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})

    def check_new():
        # all() stops at the first stale reply.
        return all(
            request_with_retries("/proxy_failure", timeout=30).text == "hello2"
            for _ in range(10))

    wait_for_condition(check_new)
예제 #9
0
def test_batching(serve_instance):
    """A batching backend processes several queued queries per call.

    ``count`` is bumped once per ``__call__``, i.e. once per batch. If any
    call handled more than one query, the max count over 20 queries stays
    below 20.

    Raises:
        TimeoutError: if ``/increment2`` never appears in the route table.
    """
    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            self.count += 1
            batch_size = serve.context.batch_size
            return [self.count] * batch_size

    serve.create_endpoint("counter1", "/increment2")

    # Keep checking the routing table until /increment2 is populated, but
    # fail instead of hanging forever if it never shows up.
    deadline = time.time() + 30
    while "/increment2" not in requests.get(
            "http://127.0.0.1:8000/-/routes").json():
        if time.time() > deadline:
            raise TimeoutError(
                "Route /increment2 was not populated after 30s.")
        time.sleep(0.2)

    # Set the max batch size.
    serve.create_backend("counter:v11",
                         BatchingExample,
                         config={"max_batch_size": 5})
    serve.set_traffic("counter1", {"counter:v11": 1.0})

    handle = serve.get_handle("counter1")
    future_list = [handle.remote(temp=1) for _ in range(20)]

    counter_result = ray.get(future_list)
    # count is only updated once per batch of queries. If at least one
    # __call__ received a batch larger than 1, the max result is below 20.
    assert max(counter_result) < 20
예제 #10
0
def test_updating_config(serve_instance):
    """Changing only max_batch_size must not replace the backend's replicas.

    Replica tags are read from the master actor before and after the
    config update. They must match and must all still be present in the
    full worker-handle table.
    """
    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve.context.batch_size
            return [1] * batch_size

    serve.create_endpoint("bsimple", "/bsimple")
    initial_config = {"max_batch_size": 2, "num_replicas": 3}
    serve.create_backend("bsimple:v1", BatchSimple, config=initial_config)

    master = serve.api._get_master_actor()

    def replica_tags():
        return ray.get(master._list_replicas.remote("bsimple:v1"))

    before = replica_tags()
    serve.update_backend_config("bsimple:v1", {"max_batch_size": 5})
    after = replica_tags()

    every_tag = set()
    for worker_dict in ray.get(
            master.get_all_worker_handles.remote()).values():
        every_tag.update(worker_dict.keys())

    # Same replicas before and after, all still registered with the master.
    assert set(before) <= every_tag
    assert set(before) == set(after)
예제 #11
0
def test_router_failure(serve_instance):
    """Requests survive a router kill and later traffic changes.

    The "hello1" backend must keep answering after ``_kill_router()``, and
    a later switch to a "hello2" backend must take effect.
    """
    serve.init()
    serve.create_endpoint("router_failure", "/router_failure")

    def expect_ten(text):
        for _ in range(10):
            got = request_with_retries("/router_failure", timeout=30)
            assert got.text == text

    def hello1():
        return "hello1"

    serve.create_backend(hello1, "router_failure:v1")
    serve.set_traffic("router_failure", {"router_failure:v1": 1.0})

    assert request_with_retries("/router_failure", timeout=5).text == "hello1"
    expect_ten("hello1")

    _kill_router()
    expect_ten("hello1")

    def hello2():
        return "hello2"

    serve.create_backend(hello2, "router_failure:v2")
    serve.set_traffic("router_failure", {"router_failure:v2": 1.0})
    expect_ten("hello2")
예제 #12
0
def test_app_level_batching(serve_instance):
    """@serve.batch groups concurrent handle calls into shared batches.

    ``count`` is incremented once per ``handle_batch`` call. If any call
    processed more than one request, the max returned count over 20
    requests stays below 20.
    """
    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.batch(max_batch_size=5, batch_wait_timeout_s=1)
        async def handle_batch(self, requests):
            self.count += 1
            return [self.count] * len(requests)

        async def __call__(self, request):
            return await self.handle_batch(request)

    # The batch size limit is set by the @serve.batch decorator above.
    serve.create_backend("counter:v11", BatchingExample)
    serve.create_endpoint("counter1",
                          backend="counter:v11",
                          route="/increment2")

    handle = serve.get_handle("counter1")
    futures = [handle.remote(temp=1) for _ in range(20)]

    results = ray.get(futures)
    # One count per batch: any batch larger than one keeps the max below 20.
    assert max(results) < 20
예제 #13
0
def test_np_in_composed_model(serve_instance):
    """Regression test: numpy arrays passed between composed backends.

    See https://github.com/ray-project/ray/issues/9441, which failed with
    "AttributeError: 'bytes' object has no attribute 'readonly'" in
    cloudpickle's _from_numpy_buffer.
    """
    def sum_model(_request, data=None):
        return np.sum(data)

    class ComposedModel:
        def __init__(self):
            self.model = serve.get_handle("sum_model")

        async def __call__(self, _request):
            # A 10x10 array of ones, so the downstream sum is 100.
            return await self.model.remote(data=np.ones((10, 10)))

    serve.create_backend("sum_model", sum_model)
    serve.create_endpoint("sum_model", backend="sum_model")
    serve.create_backend("model", ComposedModel)
    serve.create_endpoint(
        "model", backend="model", route="/model", methods=["GET"])

    response = requests.get("http://127.0.0.1:8000/model")
    assert response.status_code == 200
    assert response.json() == 100.0
예제 #14
0
def test_e2e(serve_instance):
    """End-to-end: route registration, then GET/POST through a backend.

    Polls ``/-/routes`` with exponential backoff until the "/api" route is
    registered, then checks that the echo backend sees the HTTP method.

    Raises:
        AssertionError: if the route table is still wrong after all
            retries, or if a reply carries the wrong method.
    """
    serve.init()
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])

    max_retries = 5
    retry_count = max_retries
    timeout_sleep = 0.5
    while True:
        try:
            resp = requests.get("http://127.0.0.1:8000/-/routes",
                                timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            # Covers both HTTP failures and a not-yet-updated route table
            # (the AssertionError above); back off exponentially.
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                # Raise explicitly: `assert False` is stripped under
                # `python -O`, and this loop would then never terminate.
                raise AssertionError(
                    ("Route table hasn't been updated after {} tries. "
                     "The latest error was {}").format(max_retries, e)) from e

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend("echo:v1", function)
    serve.set_traffic("endpoint", {"echo:v1": 1.0})

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
예제 #15
0
파일: test_api.py 프로젝트: yonkshi/ray
def test_shadow_traffic(serve_instance):
    """Shadowed backends receive the configured share of traffic.

    backend1 serves every request. backend2-4 receive shadow copies at
    proportions 1.0, 0.5 and 0.1. A remote actor tallies calls per backend
    so the counts can be compared once they settle.
    """
    @ray.remote
    class RequestCounter:
        def __init__(self):
            self.requests = defaultdict(int)

        def record(self, backend):
            self.requests[backend] += 1

        def get(self, backend):
            return self.requests[backend]

    counter = RequestCounter.remote()

    def f():
        ray.get(counter.record.remote("backend1"))
        return "hello"

    def f_shadow_1():
        ray.get(counter.record.remote("backend2"))
        return "oops"

    def f_shadow_2():
        ray.get(counter.record.remote("backend3"))
        return "oops"

    def f_shadow_3():
        ray.get(counter.record.remote("backend4"))
        return "oops"

    for tag, func in (("backend1", f), ("backend2", f_shadow_1),
                      ("backend3", f_shadow_2), ("backend4", f_shadow_3)):
        serve.create_backend(tag, func)

    serve.create_endpoint("endpoint", backend="backend1", route="/api")
    for tag, proportion in (("backend2", 1.0), ("backend3", 0.5),
                            ("backend4", 0.1)):
        serve.shadow_traffic("endpoint", tag, proportion)

    num_requests = 100
    start = time.time()
    for _ in range(num_requests):
        assert requests.get("http://127.0.0.1:8000/api").text == "hello"
    print("Finished 100 requests in {}s.".format(time.time() - start))

    def count(backend):
        return ray.get(counter.get.remote(backend))

    def check_requests():
        conditions = [
            count("backend1") == num_requests,
            count("backend2") == count("backend1"),
            count("backend3") < count("backend2"),
            count("backend4") < count("backend3"),
            count("backend4") > 0,
        ]
        return all(conditions)

    wait_for_condition(check_requests)
예제 #16
0
def test_worker_replica_failure(serve_instance):
    """Killing one of two replicas doesn't stop the endpoint from serving.

    A shared counter actor hands each Worker an index. Workers whose index
    is above 2 (i.e. restarted ones) spin forever in __init__, so after one
    replica is killed only the surviving replica can answer.
    """
    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0

        def inc_and_get(self):
            self.count += 1
            return self.count

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, counter):
            self.should_hang = False
            self.index = ray.get(counter.inc_and_get.remote())
            if self.index > 2:
                while True:
                    pass

        def __call__(self, *args):
            return self.index

    counter = Counter.remote()
    serve.create_backend("replica_failure", Worker, counter)
    serve.update_backend_config("replica_failure",
                                BackendConfig(num_replicas=2))
    serve.create_endpoint("replica_failure",
                          backend="replica_failure",
                          route="/replica_failure")

    # Poll until both replicas ("1" and "2") have answered at least once.
    seen = set()
    deadline = time.time() + 30
    while time.time() < deadline:
        time.sleep(0.1)
        reply = request_with_retries("/replica_failure", timeout=1).text
        assert reply in ["1", "2"]
        seen.add(reply)
        if len(seen) > 1:
            break
    else:
        raise TimeoutError("Timed out waiting for replicas after 30s.")

    # Kill one of the replicas and let Ray restart it.
    replica_handles = _get_worker_handles("replica_failure")
    assert len(replica_handles) == 2
    ray.kill(replica_handles[0], no_restart=False)

    def request_until_served():
        # The timeout must be short because a request routed to the
        # restarting (hung) worker blocks; retry until the healthy one answers.
        while True:
            try:
                request_with_retries("/replica_failure", timeout=0.1)
                return
            except TimeoutError:
                time.sleep(0.1)

    # Check that the other replica still serves requests.
    for _ in range(10):
        request_until_served()
예제 #17
0
def test_worker_replica_failure(serve_instance):
    """Killing one of two replicas doesn't stop the endpoint from serving.

    Worker start-ups are counted in a temp file. The first two workers to
    start record themselves, and any later (restarted) worker spins forever
    in ``__init__``. After one replica is killed, only the survivor can
    answer requests.

    Raises:
        TimeoutError: if two distinct replicas don't answer within 30s.
    """
    original_retries = serve.http_proxy.MAX_ACTOR_DEAD_RETRIES
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    temp_path = os.path.join(tempfile.gettempdir(),
                             serve.utils.get_random_letters())
    try:
        serve.init()

        class Worker:
            # Assumes that two replicas are started. Will hang forever in
            # the constructor for any workers that are restarted.
            # NOTE(review): the file read/write is not atomic; two replicas
            # starting at the same moment could both record "1".
            def __init__(self, path):
                num_started = 0
                if os.path.exists(path):
                    with open(path, "r") as f:
                        num_started = int(f.read())
                if num_started >= 2:
                    # Leave the file alone: opening it with "w" here would
                    # truncate it and make the next reader's int("") fail.
                    while True:
                        pass
                with open(path, "w") as f:
                    f.write(str(num_started + 1))

            def __call__(self):
                # Reply with the PID so the two replicas can be told apart.
                return os.getpid()

        serve.create_backend("replica_failure", Worker, temp_path)
        serve.update_backend_config("replica_failure",
                                    BackendConfig(num_replicas=2))
        serve.create_endpoint("replica_failure",
                              backend="replica_failure",
                              route="/replica_failure")

        # Wait until both replicas have been started, i.e. two distinct
        # PIDs have answered.
        responses = set()
        start = time.time()
        while len(responses) < 2:
            if time.time() - start > 30:
                raise TimeoutError(
                    "Timed out waiting for replicas after 30s.")
            responses.add(
                request_with_retries("/replica_failure", timeout=1).text)
            time.sleep(0.1)

        # Kill one of the replicas.
        handles = _get_worker_handles("replica_failure")
        assert len(handles) == 2
        ray.kill(handles[0], no_restart=False)

        # Check that the other replica still serves requests.
        for _ in range(10):
            while True:
                try:
                    # The timeout needs to be small here because the request
                    # to the restarting worker will hang.
                    request_with_retries("/replica_failure", timeout=0.1)
                    break
                except TimeoutError:
                    time.sleep(0.1)
    finally:
        # Undo the module-level tweak and remove the start-up counter file.
        serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = original_retries
        if os.path.exists(temp_path):
            os.remove(temp_path)
예제 #18
0
def test_set_traffic_missing_data(serve_instance):
    """set_traffic rejects unknown backends and unknown endpoints."""
    endpoint_name = "foobar"
    backend_name = "foo_backend"
    serve.create_backend(backend_name, lambda: 5)
    serve.create_endpoint(endpoint_name, backend=backend_name)

    bad_calls = [
        (endpoint_name, {"nonexistent_backend": 1.0}),
        ("nonexistent_endpoint_name", {backend_name: 1.0}),
    ]
    for endpoint, traffic in bad_calls:
        with pytest.raises(ValueError):
            serve.set_traffic(endpoint, traffic)
예제 #19
0
def test_repeated_get_handle_cached(serve_instance):
    """Repeated get_handle calls for one endpoint yield a single handle."""
    def f(_):
        return ""

    serve.create_backend("m", f)
    serve.create_endpoint("m", backend="m")

    unique_handles = set()
    for _ in range(100):
        unique_handles.add(serve.get_handle("m"))
    assert len(unique_handles) == 1
예제 #20
0
def test_no_route(serve_instance):
    """An endpoint created without a route is still reachable by handle."""
    def func(_, i=1):
        return 1

    serve.create_backend("backend:1", func)
    serve.create_endpoint("noroute-endpoint", backend="backend:1")
    handle = serve.get_handle("noroute-endpoint")
    assert ray.get(handle.remote(i=1)) == 1
예제 #21
0
def test_reject_duplicate_route(serve_instance):
    """Two endpoints cannot share the same HTTP route."""
    def f():
        pass

    serve.create_backend("backend", f)

    route = "/foo"
    serve.create_endpoint("bar", backend="backend", route=route)

    # A second endpoint on the already-taken route must be rejected.
    with pytest.raises(ValueError):
        serve.create_endpoint("foo", backend="backend", route=route)
예제 #22
0
def test_serve_metrics(serve_instance):
    """All expected Serve metric names appear on the exporter at :9999.

    Metrics are polled until present. If they never all show up, one final
    round runs in assert mode so the failure names the missing metric.

    Raises:
        AssertionError: if an expected metric is still missing.
        requests.ConnectionError: if the exporter is unreachable in the
            final round.
    """
    @serve.accept_batch
    def batcher(starlette_requests):
        return ["hello"] * len(starlette_requests)

    serve.create_backend("metrics", batcher)
    serve.create_endpoint("metrics", backend="metrics", route="/metrics")

    # send 10 concurrent requests
    url = "http://127.0.0.1:8000/metrics"
    ray.get([block_until_http_ready.remote(url) for _ in range(10)])

    expected_metrics = [
        # counter
        "num_router_requests_total",
        "num_http_requests_total",
        "backend_queued_queries_total",
        "backend_request_counter_requests_total",
        "backend_worker_starts_restarts_total",
        # histogram
        "backend_processing_latency_ms_bucket",
        "backend_processing_latency_ms_count",
        "backend_processing_latency_ms_sum",
        "backend_queuing_latency_ms_bucket",
        "backend_queuing_latency_ms_count",
        "backend_queuing_latency_ms_sum",
        # gauge
        "replica_processing_queries",
        "replica_queued_queries",
        # handle
        "serve_handle_request_counter",
        # ReplicaSet
        "backend_queued_queries"
    ]

    def verify_metrics(do_assert=False):
        """Return True if every expected metric name is in the scrape.

        With ``do_assert=True``, a missing metric or an unreachable
        exporter raises instead of returning False.
        """
        try:
            resp = requests.get("http://127.0.0.1:9999").text
        # Requests will fail if we are crashing the controller
        except requests.ConnectionError:
            if do_assert:
                raise
            return False

        for metric in expected_metrics:
            if do_assert:
                # For the final error round
                assert metric in resp
            elif metric not in resp:
                # For the wait_for_condition
                return False
        return True

    try:
        wait_for_condition(verify_metrics, retry_interval_ms=500)
    except RuntimeError:
        # The final round must assert; otherwise a timeout passes silently.
        verify_metrics(do_assert=True)
예제 #23
0
async def test_system_metric_endpoints(serve_instance):
    """/-/metrics exposes HTTP, router and backend-error counters.

    The backend always divides by zero, so the error counter gets
    populated. Metric values are ignored because the check may be retried.
    """
    def test_error_counter(flask_request):
        1 / 0

    serve.create_backend("m:v1", test_error_counter)
    serve.create_endpoint("test_metrics", backend="m:v1", route="/measure")
    serve.set_traffic("test_metrics", {"m:v1": 1})

    target_metrics = [{
        "info": {
            "name": "num_http_requests",
            "type": "MetricType.COUNTER",
            "route": "/measure"
        },
    }, {
        "info": {
            "name": "num_router_requests",
            "type": "MetricType.COUNTER",
            "endpoint": "test_metrics"
        },
    }, {
        "info": {
            "name": "backend_error_counter",
            "type": "MetricType.COUNTER",
            "backend": "m:v1"
        },
    }]

    # Check metrics are exposed under the http endpoint.
    def test_metric_endpoint():
        requests.get("http://127.0.0.1:8000/measure", timeout=5)
        metrics = requests.get(
            "http://127.0.0.1:8000/-/metrics", timeout=5).json()

        # Drop the values in place; only the metric descriptors are compared.
        for entry in metrics:
            entry.pop("value")

        for target in target_metrics:
            assert target in metrics

    for _ in range(3):
        try:
            test_metric_endpoint()
            break
        except (AssertionError, requests.ReadTimeout):
            # Metrics may not have been propagated yet
            time.sleep(2)
            print("Metric not correct, retrying...")
    else:
        # Last attempt outside the try, so a failure surfaces to pytest.
        test_metric_endpoint()
예제 #24
0
def test_reject_duplicate_endpoint(serve_instance):
    """Creating an endpoint whose name already exists raises ValueError."""
    def f():
        pass

    serve.create_backend("backend", f)

    name = "foo"
    serve.create_endpoint(name, backend="backend", route="/ok")
    # Same name on a different route is still a duplicate endpoint.
    with pytest.raises(ValueError):
        serve.create_endpoint(name, backend="backend", route="/different")
예제 #25
0
def test_controller_failure(serve_instance):
    """Serve keeps working while the controller actor is repeatedly killed.

    The controller is killed (restart allowed) between API calls. The
    existing route must keep serving, and backends/endpoints created around
    the kills must still take effect.
    """
    def kill_controller():
        ray.kill(serve.api._global_client._controller, no_restart=False)

    def reply(route):
        return request_with_retries(route, timeout=30).text

    def hello1(_):
        return "hello1"

    def hello2(_):
        return "hello2"

    def hello3(_):
        return "hello3"

    serve.create_backend("controller_failure:v1", hello1)
    serve.create_endpoint(
        "controller_failure",
        backend="controller_failure:v1",
        route="/controller_failure")

    assert request_with_retries(
        "/controller_failure", timeout=1).text == "hello1"
    for _ in range(10):
        assert reply("/controller_failure") == "hello1"

    kill_controller()
    for _ in range(10):
        assert reply("/controller_failure") == "hello1"

    kill_controller()
    serve.create_backend("controller_failure:v2", hello2)
    serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})

    wait_for_condition(lambda: reply("/controller_failure") == "hello2")

    kill_controller()
    serve.create_backend("controller_failure_2", hello3)
    kill_controller()
    serve.create_endpoint(
        "controller_failure_2",
        backend="controller_failure_2",
        route="/controller_failure_2")
    kill_controller()

    for _ in range(10):
        assert reply("/controller_failure") == "hello2"
        assert reply("/controller_failure_2") == "hello3"
예제 #26
0
def test_controller_failure(serve_instance):
    """Serve keeps working while the controller is repeatedly killed.

    The controller is killed (restart allowed) between API calls. The
    existing route must keep answering, and backends/endpoints created
    around the kills must still take effect.
    """
    serve.init()

    def kill_controller():
        ray.kill(serve.api._get_controller(), no_restart=False)

    def expect(route, text):
        response = request_with_retries(route, timeout=30)
        assert response.text == text

    def hello1():
        return "hello1"

    def hello2():
        return "hello2"

    def hello3():
        return "hello3"

    serve.create_backend("controller_failure:v1", hello1)
    serve.create_endpoint(
        "controller_failure",
        backend="controller_failure:v1",
        route="/controller_failure")

    assert request_with_retries(
        "/controller_failure", timeout=1).text == "hello1"
    for _ in range(10):
        expect("/controller_failure", "hello1")

    kill_controller()
    for _ in range(10):
        expect("/controller_failure", "hello1")

    kill_controller()
    serve.create_backend("controller_failure:v2", hello2)
    serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})
    for _ in range(10):
        expect("/controller_failure", "hello2")

    kill_controller()
    serve.create_backend("controller_failure_2", hello3)
    kill_controller()
    serve.create_endpoint(
        "controller_failure_2",
        backend="controller_failure_2",
        route="/controller_failure_2")
    kill_controller()

    for _ in range(10):
        expect("/controller_failure", "hello2")
        expect("/controller_failure_2", "hello3")
예제 #27
0
def test_no_route(serve_instance):
    """A route-less endpoint can still be queried through a handle."""
    serve.create_endpoint("noroute-endpoint")

    def func(_, i=1):
        return 1

    serve.create_backend(func, "backend:1")
    serve.set_traffic("noroute-endpoint", {"backend:1": 1.0})
    result = ray.get(serve.get_handle("noroute-endpoint").remote(i=1))
    assert result == 1
예제 #28
0
def test_controller_inflight_requests_clear(serve_instance):
    """Creating a backend and endpoint leaves no extra pending goals."""
    controller = serve.api._global_client._controller

    def pending_goals():
        return ray.get(controller._num_pending_goals.remote())

    baseline = pending_goals()

    def function(_):
        return "hello"

    serve.create_backend("tst", function)
    serve.create_endpoint("end_pt", backend="tst")

    assert pending_goals() == baseline
예제 #29
0
def test_nonblocking():
    """A GET through a freshly created endpoint echoes back "GET"."""
    serve.init()
    serve.create_endpoint("nonblocking", "/nonblocking")

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend(function, "nonblocking:v1")
    serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})

    body = requests.get("http://127.0.0.1:8000/nonblocking").json()
    assert body["method"] == "GET"
예제 #30
0
def test_connect(detached, ray_shutdown):
    """Serve API calls work from inside a running backend.

    Parametrized over detached and non-detached instances. The backend
    itself calls ``serve.create_backend``, and the backend it creates must
    then be visible from the driver. Presumably this exercises Serve's
    implicit connection to the running instance from a worker; confirm
    against the serve.connect implementation.
    """
    ray.init(num_cpus=16)
    serve.start(detached=detached)

    def connect_in_backend(_):
        serve.create_backend("backend-ception", connect_in_backend)

    serve.create_backend("connect_in_backend", connect_in_backend)
    serve.create_endpoint("endpoint", backend="connect_in_backend")
    handle = serve.get_handle("endpoint")
    ray.get(handle.remote())
    assert "backend-ception" in serve.list_backends().keys()