Beispiel #1
0
    def get_http_proxy_names(self) -> bytes:
        """Serialize the HTTP proxy actor names as a protobuf message.

        Returns:
            The bytes of an ``ActorNameList`` message whose ``names`` field is
            filled from the values of
            ``self.http_state.get_http_proxy_names()``.
        """
        # The generated protobuf module is imported at call time rather than
        # at module load.
        from ray.serve.generated.serve_pb2 import ActorNameList

        proxy_names = self.http_state.get_http_proxy_names().values()
        message = ActorNameList(names=proxy_names)
        return message.SerializeToString()
Beispiel #2
0
def test_listen_for_change_java(serve_instance):
    """Check the serialized results of ``listen_for_change_java`` for three keys.

    A plain string key is expected to come back as the encoded ``str`` of its
    value, the ``ROUTE_TABLE`` key as an ``EndpointSet`` message, and a
    ``(RUNNING_REPLICAS, <deployment>)`` key as an ``ActorNameList`` message.
    """
    host = ray.remote(LongPollHost).remote()

    def poll(key):
        # Send a serialized LongPollRequest for ``key`` with snapshot id -1
        # and parse the returned bytes as a LongPollResult.
        request = LongPollRequest(**{"keys_to_snapshot_ids": {key: -1}})
        reply: bytes = ray.get(
            host.listen_for_change_java.remote(request.SerializeToString()))
        return LongPollResult.FromString(reply)

    # --- Case 1: plain key, snapshot is the stringified value. ---
    ray.get(host.notify_changed.remote("key_1", 999))
    plain_result = poll("key_1")
    assert set(plain_result.updated_objects.keys()) == {"key_1"}
    plain_snapshot = plain_result.updated_objects["key_1"].object_snapshot
    assert plain_snapshot.decode() == "999"

    # --- Case 2: route table, snapshot is an EndpointSet message. ---
    endpoints: Dict[EndpointTag, EndpointInfo] = {
        "deployment_name": EndpointInfo(route="/test/xlang/poll"),
        "deployment_name1": EndpointInfo(route="/test/xlang/poll1"),
    }
    ray.get(
        host.notify_changed.remote(LongPollNamespace.ROUTE_TABLE, endpoints))
    route_result = poll("ROUTE_TABLE")
    assert set(route_result.updated_objects.keys()) == {"ROUTE_TABLE"}
    route_snapshot = route_result.updated_objects["ROUTE_TABLE"].object_snapshot
    endpoint_set = EndpointSet.FromString(route_snapshot)
    expected_tags = {"deployment_name", "deployment_name1"}
    assert set(endpoint_set.endpoints.keys()) == expected_tags
    assert endpoint_set.endpoints["deployment_name"].route == "/test/xlang/poll"

    # --- Case 3: running replicas, snapshot is an ActorNameList message. ---
    replicas = []
    for index in range(2):
        replicas.append(
            RunningReplicaInfo(
                deployment_name="deployment_name",
                replica_tag=str(index),
                actor_handle=host,
                max_concurrent_queries=1,
            ))
    replicas_key = (LongPollNamespace.RUNNING_REPLICAS, "deployment_name")
    ray.get(host.notify_changed.remote(replicas_key, replicas))
    replica_result = poll("(RUNNING_REPLICAS,deployment_name)")
    replica_snapshot = replica_result.updated_objects[
        "(RUNNING_REPLICAS,deployment_name)"].object_snapshot
    replica_name_list = ActorNameList.FromString(replica_snapshot)
    assert replica_name_list.names == ["SERVE_REPLICA::0", "SERVE_REPLICA::1"]
Beispiel #3
0
 def _object_snapshot_to_proto_bytes(
     self, key: KeyType, object_snapshot: Any
 ) -> bytes:
     """Turn a long-poll snapshot into bytes, picking the encoding from ``key``.

     Args:
         key: The long-poll key the snapshot belongs to.
         object_snapshot: The value to serialize.

     Returns:
         A serialized ``EndpointSet`` when ``key`` is the route-table key, a
         serialized ``ActorNameList`` when ``key`` is a tuple whose first
         element is ``RUNNING_REPLICAS``, and otherwise the encoded
         ``str(object_snapshot)``.
     """
     if key == LongPollNamespace.ROUTE_TABLE:
         # Here the snapshot is a Dict[EndpointTag, EndpointInfo]; only the
         # route of each endpoint is copied into the proto message.
         route_table = EndpointSet(
             endpoints={
                 tag: EndpointInfoProto(route=info.route)
                 for tag, info in object_snapshot.items()
             }
         )
         return route_table.SerializeToString()

     if isinstance(key, tuple) and key[0] == LongPollNamespace.RUNNING_REPLICAS:
         # Here the snapshot is a List[RunningReplicaInfo]; each replica tag
         # is turned into a prefixed actor name.
         replica_actor_names = []
         for replica in object_snapshot:
             replica_actor_names.append(
                 f"{ReplicaName.prefix}{format_actor_name(replica.replica_tag)}"
             )
         return ActorNameList(names=replica_actor_names).SerializeToString()

     # Any other key: fall back to the textual form of the snapshot.
     return str(object_snapshot).encode()
Beispiel #4
0
def test_fixed_number_proxies(ray_cluster):
    """Start Serve with ``FixedNumber`` proxy placement on a three-node cluster.

    Verifies that omitting ``fixed_number_replicas`` raises a validation
    error, and that asking for two replicas yields exactly two HTTP proxies,
    both in the controller's proxy mapping and in its serialized name list.
    """
    cluster = ray_cluster
    head_node = cluster.add_node(num_cpus=4)
    # Two more nodes in addition to the head node.
    for _ in range(2):
        cluster.add_node(num_cpus=4)

    ray.init(head_node.address)
    assert len(ray._private.state.node_ids()) == 3

    # "FixedNumber" without a replica count must be rejected.
    incomplete_options = {"location": "FixedNumber"}
    with pytest.raises(
        pydantic.ValidationError,
        match="you must specify the `fixed_number_replicas` parameter.",
    ):
        serve.start(http_options=incomplete_options)

    complete_options = {
        "port": new_port(),
        "location": "FixedNumber",
        "fixed_number_replicas": 2,
    }
    serve.start(http_options=complete_options)

    # Only the controller and two http proxies should be started.
    controller = get_global_client()._controller
    proxies_by_node = ray.get(controller.get_http_proxies.remote())
    assert len(proxies_by_node) == 2

    serialized_names = ray.get(controller.get_http_proxy_names.remote())
    name_list = ActorNameList.FromString(serialized_names)
    assert len(name_list.names) == 2

    serve.shutdown()
    ray.shutdown()
    cluster.shutdown()