Example 1
    def __init__(self, vertex_config, model_name):
        super().__init__()
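        # Create one endpoint and backend per vertex in the config and collect
        # a handle for each stage so they can be chained together below.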
        handle_list = list()
        for node_id in vertex_config.keys():
            backend_config = vertex_config[node_id]
            with serve_reference.using_router(node_id):
                serve_reference.create_endpoint(node_id)
                config = serve_reference.BackendConfig(**backend_config)
                if node_id == "prepoc":
                    min_img_size = 224
                    transform = transforms.Compose([
                        transforms.Resize(min_img_size),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225],
                        ),
                    ])
                    serve_reference.create_backend(Transform,
                                                   node_id,
                                                   transform,
                                                   backend_config=config)
                elif node_id == "model":
                    serve_reference.create_backend(
                        PredictModelPytorch,
                        node_id,
                        model_name,
                        True,
                        backend_config=config,
                    )
                serve_reference.link(node_id, node_id)
                handle_list.append(serve_reference.get_handle(node_id))

        self.chain_handle = ChainHandle(handle_list)
Example 2
def test_killing_replicas(serve_instance):
    class Simple:
        def __init__(self):
            self.count = 0

        def __call__(self, flask_request, temp=None):
            return temp

    serve_reference.create_endpoint("simple", "/simple")
    b_config = BackendConfig(num_replicas=3, num_cpus=2)
    serve_reference.create_backend(Simple,
                                   "simple:v1",
                                   backend_config=b_config)
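    # Record the replica tags created under the original config.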
    global_state = serve_reference.api._get_global_state()
    old_replica_tag_list = global_state.backend_table.list_replicas(
        "simple:v1")

    bnew_config = serve_reference.get_backend_config("simple:v1")
    # Change the resource requirement (num_cpus); this should force the
    # existing replicas to be killed and restarted under new replica tags.
    bnew_config.num_cpus = 1
    # apply the updated config
    serve_reference.set_backend_config("simple:v1", bnew_config)
    new_replica_tag_list = global_state.backend_table.list_replicas(
        "simple:v1")
    global_state.refresh_actor_handle_cache()
    new_all_tag_list = list(global_state.actor_handle_cache.keys())

    # the new_replica_tag_list must be a subset of new_all_tag_list
    assert set(new_replica_tag_list) <= set(new_all_tag_list)

    # the old_replica_tag_list must not be a subset of new_all_tag_list
    assert not set(old_replica_tag_list) <= set(new_all_tag_list)
Example 3
def test_not_killing_replicas(serve_instance):
    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve_reference.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve_reference.context.batch_size
            return [1] * batch_size

    serve_reference.create_endpoint("bsimple", "/bsimple")
    b_config = BackendConfig(num_replicas=3, max_batch_size=2)
    serve_reference.create_backend(BatchSimple,
                                   "bsimple:v1",
                                   backend_config=b_config)
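    # Snapshot the replica tags created under the original config.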
    global_state = serve_reference.api._get_global_state()
    old_replica_tag_list = global_state.backend_table.list_replicas(
        "bsimple:v1")

    bnew_config = serve_reference.get_backend_config("bsimple:v1")
    # Change only max_batch_size; this should not require restarting the
    # replicas, so the replica tags are expected to stay the same.
    bnew_config.max_batch_size = 5
    # apply the updated config
    serve_reference.set_backend_config("bsimple:v1", bnew_config)
    new_replica_tag_list = global_state.backend_table.list_replicas(
        "bsimple:v1")
    global_state.refresh_actor_handle_cache()
    new_all_tag_list = list(global_state.actor_handle_cache.keys())

    # the old and new replica tag lists should be identical
    # and should be a subset of new_all_tag_list
    assert set(old_replica_tag_list) <= set(new_all_tag_list)
    assert set(old_replica_tag_list) == set(new_replica_tag_list)
Example 4
def test_e2e(serve_instance):
    serve_reference.init()  # so we have access to global state
    serve_reference.create_endpoint("endpoint",
                                    "/api",
                                    methods=["GET", "POST"])
    result = serve_reference.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    retry_count = 5
    timeout_sleep = 0.5
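    # Poll the HTTP proxy's route table with exponential backoff until it
    # reflects the newly created endpoint.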
    while True:
        try:
            resp = requests.get("http://127.0.0.1:8000/-/routes",
                                timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
            retry_count -= 1
            if retry_count == 0:
                assert False, ("Route table hasn't been updated after 3 tries."
                               "The latest error was {}").format(e)

    def function(flask_request):
        return {"method": flask_request.method}

    serve_reference.create_backend(function, "echo:v1")
    serve_reference.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
Example 5
def main(batch_size, num_warmups, num_queries, return_type):
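    # Benchmark a single no-op backend, with and without batching enabled.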
    serve_reference.init()

    def noop(_, *args, **kwargs):
        bs = serve_reference.context.batch_size
        assert (bs == batch_size
                ), f"worker received batch size {bs}, expected {batch_size}"
        result = ""

        if return_type == "torch":
            result = torch.zeros((3, 224, 224))

        if bs is None:  # No batching
            return result
        else:
            return [result] * bs

    if batch_size:
        noop = serve_reference.accept_batch(noop)

    with serve_reference.using_router("noop"):
        serve_reference.create_endpoint("noop", "/noop")
        config = serve_reference.BackendConfig(max_batch_size=batch_size)
        serve_reference.create_backend(noop, "noop", backend_config=config)
        serve_reference.link("noop", "noop")
        handle = serve_reference.get_handle("noop")

    latency = []
    for i in tqdm(range(num_warmups + num_queries)):
        if i == num_warmups:
            serve_reference.clear_trace()

        start = time.perf_counter()

        if not batch_size:
            ray.get(
                # This is how to pass a higher level metadata to the tracing
                # context
                handle.options(tracing_metadata={
                    "demo": "pipeline-id"
                }).remote())
        else:
            ray.get(handle.enqueue_batch(val=[1] * batch_size))
            # ray.get([handle.remote() for _ in range(batch_size)])

        end = time.perf_counter()
        latency.append(end - start)

    # Remove initial samples
    latency = latency[num_warmups:]

    series = pd.Series(latency) * 1000
    print("Latency for single noop backend (ms)")
    print(series.describe(percentiles=[0.5, 0.9, 0.95, 0.99]))

    _, trace_file = tempfile.mkstemp(suffix=".json")
    with open(trace_file, "w") as f:
        json.dump(serve_reference.get_trace(), f)
    print(f"trace file written to {trace_file}")
Example 6
    def __init__(self, max_batch_size, pipeline_length):
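        # Build a linear pipeline of `pipeline_length` noop services.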
        self.plength = pipeline_length
        self.handles = list()

        for index in range(self.plength):
            node_id = f"service-{index}"
            with serve_reference.using_router(node_id):
                serve_reference.create_endpoint(node_id)
                config = serve_reference.BackendConfig(
                    max_batch_size=max_batch_size, num_replicas=1)
                serve_reference.create_backend(noop,
                                               node_id,
                                               backend_config=config)
                serve_reference.link(node_id, node_id)
                self.handles.append(serve_reference.get_handle(node_id))
Example 7
def main():
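    # Measure Resnet18 throughput (max batch size 8) while sweeping the number
    # of replicas from 1 to 8.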

    TAG = "Resnet18"
    min_img_size = 224
    transform = transforms.Compose([
        transforms.Resize(min_img_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    for num_replica in range(1, 9):
        # initialize serve
        serve_reference.init(start_server=False)

        serve_handle = None
        with serve_reference.using_router(TAG):
            serve_reference.create_endpoint(TAG)
            config = serve_reference.BackendConfig(max_batch_size=8,
                                                   num_replicas=num_replica,
                                                   num_gpus=1)
            serve_reference.create_backend(
                PredictModelPytorch,
                TAG,
                transform,
                "resnet18",
                True,
                backend_config=config,
            )
            serve_reference.link(TAG, TAG)
            serve_handle = serve_reference.get_handle(TAG)

        img = base64.b64encode(open("elephant.jpg", "rb").read())

        # warmup
        ready_refs, _ = ray.wait(
            [serve_handle.remote(data=img) for _ in range(200)], 200)
        complete_oids, _ = ray.wait(ray.get(ready_refs), num_returns=200)
        del ready_refs
        del complete_oids

        qps = throughput_calculation(serve_handle, {"data": img}, 2000)
        print(f"[Resnet18] Batch Size: 8 Replica: {num_replica} "
              f"Throughput: {qps} QPS")

        serve_reference.shutdown()
Example 8
def test_no_route(serve_instance):
    serve_reference.create_endpoint("noroute-endpoint")
    global_state = serve_reference.api._get_global_state()

    result = global_state.route_table.list_service(include_headless=True)
    assert result[NO_ROUTE_KEY] == ["noroute-endpoint"]

    without_headless_result = global_state.route_table.list_service()
    assert NO_ROUTE_KEY not in without_headless_result

    def func(_, i=1):
        return 1

    serve_reference.create_backend(func, "backend:1")
    serve_reference.link("noroute-endpoint", "backend:1")
    service_handle = serve_reference.get_handle("noroute-endpoint")
    result = ray.get(service_handle.remote(i=1))
    assert result == 1
Example 9
def test_scaling_replicas(serve_instance):
    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    serve_reference.create_endpoint("counter", "/increment")

    # Keep checking the routing table until /increment is populated
    while ("/increment"
           not in requests.get("http://127.0.0.1:8000/-/routes").json()):
        time.sleep(0.2)

    b_config = BackendConfig(num_replicas=2)
    serve_reference.create_backend(Counter,
                                   "counter:v1",
                                   backend_config=b_config)
    serve_reference.link("counter", "counter:v1")

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)

    # If the load is shared between the two replicas, the max result cannot be 10.
    assert max(counter_result) < 10

    b_config = serve_reference.get_backend_config("counter:v1")
    b_config.num_replicas = 1
    serve_reference.set_backend_config("counter:v1", b_config)

    counter_result = []
    for _ in range(10):
        resp = requests.get("http://127.0.0.1:8000/increment").json()
        counter_result.append(resp)
    # Give some time for a replica to spin down, but the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
Example 10
def test_batching_exception(serve_instance):
    class NoListReturned:
        def __init__(self):
            self.count = 0

        @serve_reference.accept_batch
        def __call__(self, flask_request, temp=None):
            batch_size = serve_reference.context.batch_size
            return batch_size

    serve_reference.create_endpoint("exception-test", "/noListReturned")
    # set the max batch size
    b_config = BackendConfig(max_batch_size=5)
    serve_reference.create_backend(NoListReturned,
                                   "exception:v1",
                                   backend_config=b_config)
    serve_reference.link("exception-test", "exception:v1")

    handle = serve_reference.get_handle("exception-test")
    with pytest.raises(ray.exceptions.RayTaskError):
        assert ray.get(handle.remote(temp=1))
Example 11
def test_batching(serve_instance):
    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve_reference.accept_batch
        def __call__(self, flask_request, temp=None):
            self.count += 1
            batch_size = serve_reference.context.batch_size
            return [self.count] * batch_size

    serve_reference.create_endpoint("counter1", "/increment")

    # Keep checking the routing table until /increment is populated
    while ("/increment"
           not in requests.get("http://127.0.0.1:8000/-/routes").json()):
        time.sleep(0.2)

    # set the max batch size
    b_config = BackendConfig(max_batch_size=5)
    serve_reference.create_backend(BatchingExample,
                                   "counter:v11",
                                   backend_config=b_config)
    serve_reference.link("counter1", "counter:v11")

    future_list = []
    handle = serve_reference.get_handle("counter1")
    for _ in range(20):
        f = handle.remote(temp=1)
        future_list.append(f)

    counter_result = ray.get(future_list)
    # Since count is only updated once per batch of queries, if there is at
    # least one __call__ invocation with a batch size greater than 1, the
    # counter result will always be less than 20.
    assert max(counter_result) < 20
Example 12
def main(num_replicas, method):
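    # Compare two submission strategies: "chain" passes each upstream result
    # handle directly into the downstream call, while "group" submits all
    # upstream calls first and then all downstream calls.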
    for node_id in ["up", "down"]:
        with serve_reference.using_router(node_id):
            serve_reference.create_endpoint(node_id)
            config = serve_reference.BackendConfig(max_batch_size=1,
                                                   num_replicas=num_replicas)
            serve_reference.create_backend(noop,
                                           node_id,
                                           backend_config=config)
            serve_reference.link(node_id, node_id)

    with serve_reference.using_router("up"):
        up_handle = serve_reference.get_handle("up")
    with serve_reference.using_router("down"):
        down_handle = serve_reference.get_handle("down")

    start = time.perf_counter()
    oids = []

    if method == "chain":
        for i in range(num_queries):
            r = up_handle.options(tracing_metadata={
                "pipeline-id": i
            }).remote(sleep_time=0.01, data=image_data)
            r = down_handle.options(tracing_metadata={
                "pipeline-id": i
            }).remote(
                sleep_time=0.02,
                data=r  # torch tensor
            )
            oids.append(r)
    elif method == "group":
        res = [
            up_handle.options(tracing_metadata={
                "pipeline-id": i
            }).remote(sleep_time=0.01, data=image_data)
            for i in range(num_queries)
        ]
        oids = [
            down_handle.options(tracing_metadata={
                "pipeline-id": i
            }).remote(
                sleep_time=0.02,
                data=d  # torch tensor
            ) for i, d in enumerate(res)
        ]
    else:
        raise RuntimeError("Unreachable")
    print(f"Submission time {time.perf_counter() - start}")

    ray.wait(oids, len(oids))
    end = time.perf_counter()

    duration = end - start
    qps = num_queries / duration

    print(f"Throughput {qps}")

    _, trace_file = tempfile.mkstemp(suffix=".json")
    with open(trace_file, "w") as f:
        json.dump(serve_reference.get_trace(), f)
    print(f"trace file written to {trace_file}")