Example #1
def test_proxy_manager_lifecycle(shutdown_only):
    """
    Creates a ProxyManager and tests basic handling of the lifetime of a
    specific RayClient Server. It checks the following properties:
    1. The SpecificServer is created using the first port.
    2. The SpecificServer comes alive and has a log associated with it.
    3. The SpecificServer destructs itself when no client connects.
    4. The ProxyManager returns the port of the destructed SpecificServer.
    """
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)
    client = "client1"

    pm.create_specific_server(client)
    assert pm.start_specific_server(client, JobConfig())
    # Channel should be ready and correspond to an existing server
    grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)

    proc = pm._get_server_for_client(client)
    assert proc.port == free_ports[0], f"Free Ports are: {free_ports}"

    log_files_path = os.path.join(pm.node.get_session_dir_path(), "logs",
                                  "ray_client_server*")
    files = glob(log_files_path)
    assert any(str(free_ports[0]) in f for f in files)

    proc.process_handle_future.result().process.wait(10)
    # Wait for reconcile loop
    time.sleep(2)

    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == free_ports[1]
Example #2
def test_pip_job_config(shutdown_only, pip_as_str, tmp_path):
    """Tests dynamic installation of pip packages in a task's runtime env."""

    if pip_as_str:
        d = tmp_path / "pip_requirements"
        d.mkdir()
        p = d / "requirements.txt"
        requirements_txt = """
        pip-install-test==0.5
        """
        p.write_text(requirements_txt)
        runtime_env = {"pip": str(p)}
    else:
        runtime_env = {"pip": ["pip-install-test==0.5"]}

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    with pytest.raises(ModuleNotFoundError):
        # Ensure pip-install-test is not installed on the test machine
        import pip_install_test  # noqa
    assert ray.get(f.remote())
Example #3
def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
    """Upload runtime env if it's not there.

    It checks whether the runtime environment package exists in the cluster.
    If it doesn't, a package is created from the working directory and
    modules defined in the job config and then uploaded to the cluster.

    Args:
        job_config (JobConfig): The job config of driver.
    """
    assert _internal_kv_initialized()
    pkg_uris = job_config.get_runtime_env_uris()
    for pkg_uri in pkg_uris:
        if not package_exists(pkg_uri):
            file_path = _get_local_path(pkg_uri)
            pkg_file = Path(file_path)
            working_dir = job_config.runtime_env.get("working_dir")
            py_modules = job_config.runtime_env.get("py_modules")
            excludes = job_config.runtime_env.get("excludes") or []
            logger.info(f"{pkg_uri} doesn't exist. Create new package with"
                        f" {working_dir} and {py_modules}")
            if not pkg_file.exists():
                create_project_package(working_dir, py_modules, excludes,
                                       file_path)
            # Push the data to remote storage
            pkg_size = push_package(pkg_uri, pkg_file)
            logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")
Example #4
def test_pip_job_config(shutdown_only, pip_as_str):
    """Tests dynamic installation of pip packages in a task's runtime env."""

    ray_wheel_path = os.path.join("/ray/.whl", get_wheel_filename())
    if pip_as_str:
        requirements_txt = f"""
        {ray_wheel_path}
        pip-install-test==0.5
        opentelemetry-api==1.0.0rc1
        opentelemetry-sdk==1.0.0rc1
        """
        runtime_env = {"pip": requirements_txt}
    else:
        runtime_env = {
            "pip": [
                ray_wheel_path, "pip-install-test==0.5",
                "opentelemetry-api==1.0.0rc1", "opentelemetry-sdk==1.0.0rc1"
            ]
        }

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    with pytest.raises(ModuleNotFoundError):
        # Ensure pip-install-test is not installed on the test machine
        import pip_install_test  # noqa
    assert ray.get(f.remote())
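Taken together with Example #2, these tests suggest that the "pip" field of a runtime_env can be given in several shapes. The sketch below is illustrative only; the file path is hypothetical.

# Shapes of the "pip" field exercised by the tests above.
runtime_env_from_list = {"pip": ["pip-install-test==0.5"]}   # list of requirement specifiers
runtime_env_from_file = {"pip": "/tmp/requirements.txt"}     # path to a requirements.txt (hypothetical)
runtime_env_from_text = {"pip": "pip-install-test==0.5\n"}   # raw requirements text, as in this example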
Example #5
def test_prepare_runtime_init_req_modified_job():
    """
    Check that `prepare_runtime_init_req` properly extracts the JobConfig and
    modifies it according to `ray_client_server_env_prep`.
    """
    job_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc"
    )
    init_req = ray_client_pb2.DataRequest(
        init=ray_client_pb2.InitRequest(
            job_config=pickle.dumps(job_config),
            ray_init_kwargs=json.dumps({"log_to_driver": False}),
        )
    )

    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config

    with patch.object(proxier, "ray_client_server_env_prep", modify_namespace):
        req, new_config = proxier.prepare_runtime_init_req(init_req)

    assert new_config.ray_namespace == "test_value"
    assert pickle.loads(req.init.job_config).serialize() == new_config.serialize()
    assert json.loads(req.init.ray_init_kwargs) == {"log_to_driver": False}
Example #6
def test_conda_create_job_config(shutdown_only):
    """Tests dynamic conda env creation in a runtime env in the JobConfig."""

    ray_wheel_filename = get_wheel_filename()
    # E.g. 3.6.13
    python_micro_version_dots = ".".join(map(str, sys.version_info[:3]))
    ray_wheel_path = os.path.join("/ray/.whl", ray_wheel_filename)

    runtime_env = {
        "conda": {
            "dependencies": [
                f"python={python_micro_version_dots}", "pip", {
                    "pip": [
                        ray_wheel_path, "pip-install-test==0.5",
                        "opentelemetry-api==1.0.0rc1",
                        "opentelemetry-sdk==1.0.0rc1"
                    ]
                }
            ]
        }
    }
    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    with pytest.raises(ModuleNotFoundError):
        # Ensure pip-install-test is not installed on the test machine
        import pip_install_test  # noqa
    assert ray.get(f.remote())
Example #7
def test_proxy_manager_lifecycle(shutdown_only):
    """
    Creates a ProxyManager and tests basic handling of the lifetime of a
    specific RayClient Server. It checks the following properties:
    1. The SpecificServer is created using the first port.
    2. The SpecificServer comes alive.
    3. The SpecificServer destructs itself when no client connects.
    4. The ProxyManager returns the port of the destructed SpecificServer.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm = proxier.ProxyManager(ray_instance["redis_address"],
                              ray_instance["session_dir"])
    pm._free_ports = [45000, 45001]
    client = "client1"

    assert pm.start_specific_server(client, JobConfig())
    # Channel should be ready and correspond to an existing server
    grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)

    proc = pm._get_server_for_client(client)
    assert proc.port == 45000

    proc.process_handle().process.wait(10)
    # Wait for reconcile loop
    time.sleep(2)

    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == 45001
Example #8
def test_conda_create_job_config(shutdown_only):
    """Tests dynamic conda env creation in a runtime env in the JobConfig."""

    runtime_env = {
        "conda": {
            "dependencies": [
                "pip", {
                    "pip": [
                        "pip-install-test==0.5", "opentelemetry-api==1.0.0rc1",
                        "opentelemetry-sdk==1.0.0rc1"
                    ]
                }
            ]
        }
    }
    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    with pytest.raises(ModuleNotFoundError):
        # Ensure pip-install-test is not installed on the test machine
        import pip_install_test  # noqa
    assert ray.get(f.remote())
Example #9
def test_default_actor_lifetime(default_actor_lifetime, child_actor_lifetime):
    @ray.remote
    class OwnerActor:
        def create_child_actor(self, child_actor_lifetime):
            if child_actor_lifetime is None:
                self._child_actor = ChildActor.remote()
            else:
                self._child_actor = ChildActor.options(
                    lifetime=child_actor_lifetime).remote()
            assert "ok" == ray.get(self._child_actor.ready.remote())
            return self._child_actor

        def get_pid(self):
            return os.getpid()

        def ready(self):
            return "ok"

    @ray.remote
    class ChildActor:
        def ready(self):
            return "ok"

    if default_actor_lifetime is not None:
        ray.init(job_config=JobConfig(
            default_actor_lifetime=default_actor_lifetime))
    else:
        ray.init()

    # 1. create owner and invoke create_child_actor.
    owner = OwnerActor.remote()
    child = ray.get(owner.create_child_actor.remote(child_actor_lifetime))
    assert "ok" == ray.get(child.ready.remote())
    # 2. Kill owner and make sure it's dead.
    owner_pid = ray.get(owner.get_pid.remote())
    os.kill(owner_pid, SIGKILL)
    wait_for_pid_to_exit(owner_pid)

    # 3. Assert child state.

    def is_child_actor_dead():
        try:
            ray.get(child.ready.remote())
            return False
        except RayActorError:
            return True

    actual_lifetime = default_actor_lifetime
    if child_actor_lifetime is not None:
        actual_lifetime = child_actor_lifetime

    assert actual_lifetime is not None
    if actual_lifetime == "detached":
        time.sleep(5)
        assert not is_child_actor_dead()
    else:
        wait_for_condition(is_child_actor_dead, timeout=5)

    ray.shutdown()
Example #10
def test_prepare_runtime_init_req_no_modification():
    """
    Check that `prepare_runtime_init_req` properly extracts the JobConfig.
    """
    job_config = JobConfig(runtime_env={"env_vars": {
        "KEY": "VALUE"
    }},
                           ray_namespace="abc")
    init_req = ray_client_pb2.DataRequest(init=ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False})), )
    req, new_config = proxier.prepare_runtime_init_req(init_req)
    assert new_config.serialize() == job_config.serialize()
    assert isinstance(req, ray_client_pb2.DataRequest)
    assert pickle.loads(
        req.init.job_config).serialize() == new_config.serialize()
    assert json.loads(req.init.ray_init_kwargs) == {"log_to_driver": False}
Example #11
    def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool:
        """
        Start up a RayClient Server for an incoming client to
        communicate with. Returns whether creation was successful.
        """
        specific_server = self._get_server_for_client(client_id)
        assert specific_server, f"Server has not been created for: {client_id}"

        output, error = self.node.get_log_file_handles(
            f"ray_client_server_{specific_server.port}", unique=True
        )

        serialized_runtime_env = job_config.get_serialized_runtime_env()
        if not serialized_runtime_env or serialized_runtime_env == "{}":
            # TODO(edoakes): can we just remove this case and always send it
            # to the agent?
            serialized_runtime_env_context = RuntimeEnvContext().serialize()
        else:
            serialized_runtime_env_context = self._create_runtime_env(
                serialized_runtime_env=serialized_runtime_env,
                specific_server=specific_server,
            )

        proc = start_ray_client_server(
            self.address,
            self.node.node_ip_address,
            specific_server.port,
            stdout_file=output,
            stderr_file=error,
            fate_share=self.fate_share,
            server_type="specific-server",
            serialized_runtime_env_context=serialized_runtime_env_context,
            redis_password=self._redis_password,
        )

        # Wait for the process being run to transition from the shim process
        # to the actual RayClient Server.
        pid = proc.process.pid
        if sys.platform != "win32":
            psutil_proc = psutil.Process(pid)
        else:
            psutil_proc = None
        # Don't use `psutil` on Win32
        while psutil_proc is not None:
            if proc.process.poll() is not None:
                logger.error(f"SpecificServer startup failed for client: {client_id}")
                break
            cmd = psutil_proc.cmdline()
            if _match_running_client_server(cmd):
                break
            logger.debug("Waiting for Process to reach the actual client server.")
            time.sleep(0.5)
        specific_server.set_result(proc)
        logger.info(
            f"SpecificServer started on port: {specific_server.port} "
            f"with PID: {pid} for client: {client_id}"
        )
        return proc.process.poll() is None
Example #12
    def start_specific_server(self, client_id: str,
                              job_config: JobConfig) -> bool:
        """
        Start up a RayClient Server for an incoming client to
        communicate with. Returns whether creation was successful.
        """
        specific_server = self._get_server_for_client(client_id)
        assert specific_server, f"Server has not been created for: {client_id}"

        output, error = self.node.get_log_file_handles(
            f"ray_client_server_{specific_server.port}", unique=True)

        serialized_runtime_env = job_config.get_serialized_runtime_env()
        runtime_env = json.loads(serialized_runtime_env)

        # Set up the working_dir for the server.
        # TODO(edoakes): this should be unified with the worker setup code
        # by going through the runtime_env agent.
        context = RuntimeEnvContext(
            env_vars=runtime_env.get("env_vars"),
            resources_dir=self.node.get_runtime_env_dir_path())
        working_dir_pkg.setup_working_dir(runtime_env, context)

        proc = start_ray_client_server(
            self.redis_address,
            specific_server.port,
            stdout_file=output,
            stderr_file=error,
            fate_share=self.fate_share,
            server_type="specific-server",
            serialized_runtime_env=serialized_runtime_env,
            serialized_runtime_env_context=context.serialize(),
            redis_password=self._redis_password)

        # Wait for the process being run to transition from the shim process
        # to the actual RayClient Server.
        pid = proc.process.pid
        if sys.platform != "win32":
            psutil_proc = psutil.Process(pid)
        else:
            psutil_proc = None
        # Don't use `psutil` on Win32
        while psutil_proc is not None:
            if proc.process.poll() is not None:
                logger.error(
                    f"SpecificServer startup failed for client: {client_id}")
                break
            cmd = psutil_proc.cmdline()
            if _match_running_client_server(cmd):
                break
            logger.debug(
                "Waiting for Process to reach the actual client server.")
            time.sleep(0.5)
        specific_server.set_result(proc)
        logger.info(f"SpecificServer started on port: {specific_server.port} "
                    f"with PID: {pid} for client: {client_id}")
        return proc.process.poll() is None
Example #13
def test_controller_starts_java_replica(shutdown_only):  # noqa: F811
    ray.init(
        num_cpus=8,
        namespace="default_test_namespace",
        # A dummy code search path to enable cross language.
        job_config=JobConfig(code_search_path=["."]),
    )
    client = serve.start(detached=True)

    controller = client._controller

    config = DeploymentConfig()
    config.deployment_language = JAVA
    config.is_cross_language = True

    replica_config = ReplicaConfig.create(
        "io.ray.serve.util.ExampleEchoDeployment",
        init_args=["my_prefix "],
    )

    # Deploy it
    deployment_name = "my_java"
    updating = ray.get(
        controller.deploy.remote(
            name=deployment_name,
            deployment_config_proto_bytes=config.to_proto_bytes(),
            replica_config_proto_bytes=replica_config.to_proto_bytes(),
            route_prefix=None,
            deployer_job_id=ray.get_runtime_context().job_id,
        )
    )
    assert updating
    client._wait_for_deployment_healthy(deployment_name)

    # Let's try to call it!
    all_handles = ray.get(controller._all_running_replicas.remote())
    backend_handle = all_handles["my_java"][0].actor_handle
    out = backend_handle.handleRequest.remote(
        RequestMetadata(
            request_id="id-1",
            endpoint="endpoint",
            call_method="call",
        ).SerializeToString(),
        RequestWrapper(body=msgpack_serialize("hello")).SerializeToString(),
    )
    assert ray.get(out) == "my_prefix hello"

    handle = serve.get_deployment("my_java").get_handle()
    handle_out = handle.remote("hello handle")
    assert ray.get(handle_out) == "my_prefix hello handle"

    ray.get(controller.delete_deployment.remote(deployment_name))
    client._wait_for_deployment_deleted(deployment_name)
Example #14
def prepare_runtime_init_req(
    req: ray_client_pb2.InitRequest
) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.
    """
    job_config = JobConfig()
    if req.job_config:
        import pickle
        job_config = pickle.loads(req.job_config)

    return (req, job_config)
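A minimal sketch of calling this variant directly. The namespace value is arbitrary, and the `pickle`, `ray_client_pb2`, and `JobConfig` imports are assumed to be available as in the surrounding code.

# Hypothetical call: the JobConfig round-trips through pickle inside the InitRequest.
init_req = ray_client_pb2.InitRequest(
    job_config=pickle.dumps(JobConfig(ray_namespace="example")))
req, job_config = prepare_runtime_init_req(init_req)
assert job_config.ray_namespace == "example"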
Example #15
    def _server_init(
        self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Initialize the server"""
        if ray_init_kwargs is None:
            ray_init_kwargs = {}
        try:
            if job_config is None:
                serialized_job_config = None
            else:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    runtime_env = job_config.runtime_env or {}
                    runtime_env = upload_py_modules_if_needed(
                        runtime_env, tmp_dir, logger=logger
                    )
                    runtime_env = upload_working_dir_if_needed(
                        runtime_env, tmp_dir, logger=logger
                    )
                    # Remove excludes, it isn't relevant after the upload step.
                    runtime_env.pop("excludes", None)
                    job_config.set_runtime_env(runtime_env, validate=True)

                serialized_job_config = pickle.dumps(job_config)

            response = self.data_client.Init(
                ray_client_pb2.InitRequest(
                    job_config=serialized_job_config,
                    ray_init_kwargs=json.dumps(ray_init_kwargs),
                    reconnect_grace_period=self._reconnect_grace_period,
                )
            )
            if not response.ok:
                raise ConnectionAbortedError(
                    f"Initialization failure from server:\n{response.msg}"
                )

        except grpc.RpcError as e:
            raise decode_exception(e)
Example #16
def test_job_config_conda_env(conda_envs):
    import tensorflow as tf

    tf_version = "2.2.0"

    @ray.remote
    def get_conda_env():
        return tf.__version__

    for tf_version in ["2.2.0", "2.3.0"]:
        runtime_env = {"conda": f"tf-{tf_version}"}
        ray.init(job_config=JobConfig(runtime_env=runtime_env))
        assert ray.get(get_conda_env.remote()) == tf_version
        ray.shutdown()
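This test uses the short form of the "conda" field, a name that presumably refers to an environment created ahead of time by the `conda_envs` fixture, while Example #6 uses the long form, an inline environment specification. A sketch of the two shapes, for contrast (the env name is hypothetical):

runtime_env_named = {"conda": "tf-2.2.0"}  # refers to an existing conda env by name
runtime_env_inline = {                     # environment created on the fly
    "conda": {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]}
}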
Example #17
def rewrite_runtime_env_uris(job_config: JobConfig) -> None:
    """Rewrite the uris field in job_config.

    This function updates the uris field of the runtime_env in job_config.
    The uris are generated from the hash of the required files and modules.

    Args:
        job_config (JobConfig): The job config.
    """
    # For now, we only support local directory and packages
    uris = job_config.runtime_env.get("uris")
    if uris is not None:
        return
    working_dir = job_config.runtime_env.get("working_dir")
    py_modules = job_config.runtime_env.get("py_modules")
    excludes = job_config.runtime_env.get("excludes")
    if working_dir or py_modules:
        if excludes is None:
            excludes = []
        pkg_name = get_project_package_name(working_dir, py_modules, excludes)
        job_config.set_runtime_env_uris(
            [Protocol.GCS.value + "://" + pkg_name])
Example #18
    def start_specific_server(self, client_id: str,
                              job_config: JobConfig) -> bool:
        """
        Start up a RayClient Server for an incoming client to
        communicate with. Returns whether creation was successful.
        """
        with self.server_lock:
            port = self._get_unused_port()
            handle_ready = futures.Future()
            specific_server = SpecificServer(
                port=port,
                process_handle_future=handle_ready,
                channel=grpc.insecure_channel(f"localhost:{port}",
                                              options=GRPC_OPTIONS))
            self.servers[client_id] = specific_server

        serialized_runtime_env = job_config.get_serialized_runtime_env()

        proc = start_ray_client_server(
            self.redis_address,
            port,
            fate_share=self.fate_share,
            server_type="specific-server",
            serialized_runtime_env=serialized_runtime_env,
            session_dir=self._get_session_dir())

        # Wait for the process being run to transition from the shim process
        # to the actual RayClient Server.
        pid = proc.process.pid
        if sys.platform != "win32":
            psutil_proc = psutil.Process(pid)
        else:
            psutil_proc = None
        # Don't use `psutil` on Win32
        while psutil_proc is not None:
            if proc.process.poll() is not None:
                logger.error(
                    f"SpecificServer startup failed for client: {client_id}")
                break
            cmd = psutil_proc.cmdline()
            if len(cmd) > 3 and cmd[2] == "ray.util.client.server":
                break
            logger.debug(
                "Waiting for Process to reach the actual client server.")
            time.sleep(0.5)
        handle_ready.set_result(proc)
        logger.info(f"SpecificServer started on port: {port} with PID: {pid} "
                    f"for client: {client_id}")
        return proc.process.poll() is None
Example #19
def test_prepare_runtime_init_req_modified_job():
    """
    Check that `prepare_runtime_init_req` properly extracts the JobConfig and
    modifies it according to `ray_client_server_env_prep`.
    """
    job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
    init_req = ray_client_pb2.DataRequest(init=ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config)))

    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config

    proxier.ray_client_server_env_prep = modify_namespace
    req, new_config = proxier.prepare_runtime_init_req(iter([init_req]))

    assert new_config.ray_namespace == "test_value"
    assert pickle.loads(
        req.init.job_config).serialize() == new_config.serialize()
Example #20
def rewrite_working_dir_uri(job_config: JobConfig) -> None:
    """Rewrite the working dir uri field in job_config.

    This function updates the working_dir_uri field of the runtime_env in
    job_config. The uri is generated from the hash of the required files
    and modules.

    Args:
        job_config (JobConfig): The job config.
    """
    # For now, we only support local directory and packages
    working_dir = job_config.runtime_env.get("working_dir")
    required_modules = job_config.runtime_env.get("local_modules")

    if (not job_config.runtime_env.get("working_dir_uri")) and (
            working_dir or required_modules):
        pkg_name = get_project_package_name(working_dir, required_modules)
        job_config.runtime_env[
            "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
Example #21
def test_proxy_manager_bad_startup(shutdown_only):
    """
    Test that when a SpecificServer fails to start (because of a bad
    JobConfig), it is properly GC'd.
    """
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    proxier.CHECK_CHANNEL_TIMEOUT_S = 1
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)
    client = "client1"

    pm.create_specific_server(client)
    assert not pm.start_specific_server(
        client, JobConfig(runtime_env={"conda": "conda-env-that-sadly-does-not-exist"})
    )
    # Wait for reconcile loop
    time.sleep(2)
    assert pm.get_channel(client) is None

    assert len(pm._free_ports) == 2
Example #22
def prepare_runtime_init_req(iterator: Iterator[ray_client_pb2.DataRequest]
                             ) -> Tuple[ray_client_pb2.DataRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.
    """
    init_req = next(iterator)
    init_type = init_req.WhichOneof("type")
    assert init_type == "init", ("Received initial message of type "
                                 f"{init_type}, not 'init'.")
    req = init_req.init
    job_config = JobConfig()
    if req.job_config:
        job_config = pickle.loads(req.job_config)
    new_job_config = ray_client_server_env_prep(job_config)
    modified_init_req = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(new_job_config))

    init_req.init.CopyFrom(modified_init_req)
    return (init_req, new_job_config)
Example #23
def test_proxy_manager_lifecycle(shutdown_only):
    """
    Creates a ProxyManager and tests basic handling of the lifetime of a
    specific RayClient Server. It checks the following properties:
    1. The SpecificServer is created using the first port.
    2. The SpecificServer comes alive and has a log associated with it.
    3. The SpecificServer destructs itself when no client connects.
    4. The ProxyManager returns the port of the destructed SpecificServer.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm = proxier.ProxyManager(ray_instance["redis_address"],
                              session_dir=ray_instance["session_dir"])
    # NOTE: We use different ports between runs because sometimes the port is
    # not released, introducing flakiness.
    port_one, port_two = random.choices(range(45000, 45100), k=2)
    pm._free_ports = [port_one, port_two]
    client = "client1"

    pm.create_specific_server(client)
    assert pm.start_specific_server(client, JobConfig())
    # Channel should be ready and correspond to an existing server
    grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)

    proc = pm._get_server_for_client(client)
    assert proc.port == port_one, f"Free Ports are: [{port_one}, {port_two}]"

    log_files_path = os.path.join(pm.node.get_session_dir_path(), "logs",
                                  "ray_client_server*")
    files = glob(log_files_path)
    assert any(str(port_one) in f for f in files)

    proc.process_handle_future.result().process.wait(10)
    # Wait for reconcile loop
    time.sleep(2)

    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == port_two
Example #24
def test_proxy_manager_bad_startup(shutdown_only):
    """
    Test that when a SpecificServer fails to start (because of a bad
    JobConfig), it is properly GC'd.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    proxier.CHECK_CHANNEL_TIMEOUT_S = 1
    pm = proxier.ProxyManager(ray_instance["redis_address"],
                              ray_instance["session_dir"])
    pm._free_ports = [46000, 46001]
    client = "client1"

    assert not pm.start_specific_server(
        client,
        JobConfig(
            runtime_env={"conda": "conda-env-that-sadly-does-not-exist"}))
    # Wait for reconcile loop
    time.sleep(2)
    assert pm.get_channel(client) is None

    assert len(pm._free_ports) == 2
Example #25
    def upload_runtime_env_package_if_needed(
            self,
            job_config: JobConfig,
            logger: Optional[logging.Logger] = default_logger):
        """Upload runtime env if it's not there.

        It checks whether the runtime environment package exists in the
        cluster. If it doesn't, a package is created from the working
        directory and modules defined in the job config and then uploaded
        to the cluster.

        Args:
            job_config (JobConfig): The job config of driver.
        """
        if logger is None:
            logger = default_logger

        pkg_uris = job_config.get_runtime_env_uris()
        if len(pkg_uris) == 0:
            return  # Return early to avoid internal kv check in this case.
        for pkg_uri in pkg_uris:
            if not package_exists(pkg_uri):
                file_path = self._get_local_path(pkg_uri)
                pkg_file = Path(file_path)
                working_dir = job_config.runtime_env.get("working_dir")
                py_modules = job_config.runtime_env.get("py_modules")
                excludes = job_config.runtime_env.get("excludes") or []
                logger.info(f"{pkg_uri} doesn't exist. Create new package with"
                            f" {working_dir} and {py_modules}")
                if not pkg_file.exists():
                    create_project_package(working_dir,
                                           py_modules,
                                           excludes,
                                           file_path,
                                           logger=logger)
                # Push the data to remote storage
                pkg_size = push_package(pkg_uri, pkg_file)
                logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")
Example #26
    def connect(
        self,
        conn_str: str,
        job_config: JobConfig = None,
        secure: bool = False,
        metadata: List[Tuple[str, str]] = None,
        connection_retries: int = 3,
        namespace: str = None,
        *,
        ignore_version: bool = False,
        _credentials: Optional[grpc.ChannelCredentials] = None,
        ray_init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Connect the Ray Client to a server.

        Args:
            conn_str: Connection string, in the form "[host]:port"
            job_config: The job config of the server.
            secure: Whether to use a TLS secured gRPC channel
            metadata: gRPC metadata to send on connect
            connection_retries: number of connection attempts to make
            ignore_version: whether to ignore Python or Ray version mismatches.
                This should only be used for debugging purposes.

        Returns:
            Dictionary of connection info, e.g., {"num_clients": 1}.
        """
        # Delay imports until connect to avoid circular imports.
        from ray.util.client.worker import Worker

        if self.client_worker is not None:
            if self._connected_with_init:
                return
            raise Exception("ray.init() called, but ray client is already connected")
        if not self._inside_client_test:
            # If we're calling a client connect specifically and we're not
            # currently in client mode, ensure we are.
            _explicitly_enable_client_mode()
        if namespace is not None:
            job_config = job_config or JobConfig()
            job_config.set_ray_namespace(namespace)

        logging_level = ray_constants.LOGGER_LEVEL
        logging_format = ray_constants.LOGGER_FORMAT

        if ray_init_kwargs is not None:
            if ray_init_kwargs.get("logging_level") is not None:
                logging_level = ray_init_kwargs["logging_level"]
            if ray_init_kwargs.get("logging_format") is not None:
                logging_format = ray_init_kwargs["logging_format"]

        setup_logger(logging_level, logging_format)

        try:
            self.client_worker = Worker(
                conn_str,
                secure=secure,
                _credentials=_credentials,
                metadata=metadata,
                connection_retries=connection_retries,
            )
            self.api.worker = self.client_worker
            self.client_worker._server_init(job_config, ray_init_kwargs)
            conn_info = self.client_worker.connection_info()
            self._check_versions(conn_info, ignore_version)
            self._register_serializers()
            return conn_info
        except Exception:
            self.disconnect()
            raise
Example #27
    def __init__(self, address: Optional[str]) -> None:
        self.address = address
        self._job_config = JobConfig()
Example #28
def test_proxy_manager_internal_kv(shutdown_only, with_specific_server):
    """
    Test that proxy manager can use internal kv with and without a
    SpecificServer and that once a SpecificServer is started up, it
    goes through it.
    """

    proxier.CHECK_PROCESS_INTERVAL_S = 1
    # The timeout has likely been set to 1 in an earlier test. Increase timeout
    # to wait for the channel to become ready.
    proxier.CHECK_CHANNEL_TIMEOUT_S = 5
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)
    client = "client1"

    task_servicer = proxier.RayletServicerProxy(None, pm)

    def make_internal_kv_calls():
        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val")
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert not response.already_exists

        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val2")
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert response.already_exists

        response = task_servicer.KVGet(ray_client_pb2.KVGetRequest(key=b"key"))
        assert isinstance(response, ray_client_pb2.KVGetResponse)
        assert response.value == b"val"

        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val2", overwrite=True)
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert response.already_exists

        response = task_servicer.KVGet(ray_client_pb2.KVGetRequest(key=b"key"))
        assert isinstance(response, ray_client_pb2.KVGetResponse)
        assert response.value == b"val2"

    with patch(
        "ray.util.client.server.proxier._get_client_id_from_context"
    ) as mock_get_client_id:
        mock_get_client_id.return_value = client

        if with_specific_server:
            pm.create_specific_server(client)
            assert pm.start_specific_server(client, JobConfig())
            channel = pm.get_channel(client)
            assert channel is not None
            task_servicer.Init(
                ray_client_pb2.InitRequest(job_config=pickle.dumps(JobConfig()))
            )

            # Mock out the internal kv calls in this process to raise an
            # exception if they're called. This verifies that we are not
            # making any calls in the proxier if there is a SpecificServer
            # started up.
            with patch(
                "ray.experimental.internal_kv._internal_kv_put"
            ) as mock_put, patch(
                "ray.experimental.internal_kv._internal_kv_get"
            ) as mock_get, patch(
                "ray.experimental.internal_kv._internal_kv_initialized"
            ) as mock_initialized:
                mock_put.side_effect = Exception("This shouldn't be called!")
                mock_get.side_effect = Exception("This shouldn't be called!")
                mock_initialized.side_effect = Exception("This shouldn't be called!")
                make_internal_kv_calls()
        else:
            make_internal_kv_calls()
Example #29
    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config
Example #30
class ClientBuilder:
    """
    Builder for a Ray Client connection. This class can be subclassed by
    custom builder classes to modify connection behavior to include additional
    features or altered semantics. One example is the ``_LocalClientBuilder``.
    """

    def __init__(self, address: Optional[str]) -> None:
        self.address = address
        self._job_config = JobConfig()
        self._remote_init_kwargs = {}
        # Whether to allow connections to multiple clusters
        # (allow_multiple=True).
        self._allow_multiple_connections = False
        self._credentials = None
        # Set to False if ClientBuilder is being constructed by internal
        # methods
        self._deprecation_warn_enabled = True

    def env(self, env: Dict[str, Any]) -> "ClientBuilder":
        """
        Set an environment for the session.
        Args:
            env (Dict[str, Any]): A runtime environment to use for this
                connection. See :ref:`runtime-environments` for what values
                are accepted in this dict.
        """
        self._job_config.set_runtime_env(env)
        return self

    def namespace(self, namespace: str) -> "ClientBuilder":
        """
        Sets the namespace for the session.
        Args:
            namespace (str): Namespace to use.
        """
        self._job_config.set_ray_namespace(namespace)
        return self

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientContext: Dataclass with information about the connection. This
                includes the server's version of Python & Ray as well as the
                dashboard_url.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # Should be done *after* the deprecation warning, since warning will
        # check if those values are already set.
        self._fill_defaults_from_env()

        # If we have already connected to the cluster with allow_multiple=True,
        # connecting to the default client is not allowed.
        # But if we have connected to the default one, connecting to other
        # clients with allow_multiple=True is allowed.
        default_cli_connected = ray.util.client.ray.is_connected()
        has_cli_connected = ray.util.client.num_connected_contexts() > 0
        if (
            not self._allow_multiple_connections
            and not default_cli_connected
            and has_cli_connected
        ):
            raise ValueError(
                "The client has already connected to the cluster "
                "with allow_multiple=True. Please set allow_multiple=True"
                " to proceed"
            )

        old_ray_cxt = None
        if self._allow_multiple_connections:
            old_ray_cxt = ray.util.client.ray.set_context(None)

        client_info_dict = ray.util.client_connect.connect(
            self.address,
            job_config=self._job_config,
            _credentials=self._credentials,
            ray_init_kwargs=self._remote_init_kwargs,
        )
        get_dashboard_url = ray.remote(ray.worker.get_dashboard_url)
        dashboard_url = ray.get(get_dashboard_url.options(num_cpus=0).remote())
        cxt = ClientContext(
            dashboard_url=dashboard_url,
            python_version=client_info_dict["python_version"],
            ray_version=client_info_dict["ray_version"],
            ray_commit=client_info_dict["ray_commit"],
            protocol_version=client_info_dict["protocol_version"],
            _num_clients=client_info_dict["num_clients"],
            _context_to_restore=ray.util.client.ray.get_context(),
        )
        if self._allow_multiple_connections:
            ray.util.client.ray.set_context(old_ray_cxt)
        return cxt

    def _fill_defaults_from_env(self):
        # Check environment variables for default values
        namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
        if namespace_env_var and self._job_config.ray_namespace is None:
            self.namespace(namespace_env_var)

        runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
        if runtime_env_var and self._job_config.runtime_env is None:
            self.env(json.loads(runtime_env_var))

    def _init_args(self, **kwargs) -> "ClientBuilder":
        """
        When a client builder is constructed through ray.init, for example
        `ray.init("ray://...", namespace=...)`, all of the
        arguments passed into ray.init with non-default values are passed
        again into this method. Custom client builders can override this method
        to do their own handling/validation of arguments.
        """
        # Use namespace and runtime_env from ray.init call
        if kwargs.get("namespace") is not None:
            self.namespace(kwargs["namespace"])
            del kwargs["namespace"]
        if kwargs.get("runtime_env") is not None:
            self.env(kwargs["runtime_env"])
            del kwargs["runtime_env"]

        if kwargs.get("allow_multiple") is True:
            self._allow_multiple_connections = True
            del kwargs["allow_multiple"]

        if "_credentials" in kwargs.keys():
            self._credentials = kwargs["_credentials"]
            del kwargs["_credentials"]

        if kwargs:
            expected_sig = inspect.signature(ray_driver_init)
            extra_args = set(kwargs.keys()).difference(expected_sig.parameters.keys())
            if len(extra_args) > 0:
                raise RuntimeError(
                    "Got unexpected kwargs: {}".format(", ".join(extra_args))
                )
            self._remote_init_kwargs = kwargs
            unknown = ", ".join(kwargs)
            logger.info(
                "Passing the following kwargs to ray.init() "
                f"on the server: {unknown}"
            )
        return self

    def _client_deprecation_warn(self) -> None:
        """
        Generates a warning for users if this ClientBuilder instance was
        created directly or through ray.client, instead of relying on
        internal methods (ray.init, or auto init)
        """
        namespace = self._job_config.ray_namespace
        runtime_env = self._job_config.runtime_env
        replacement_args = []
        if self.address:
            if isinstance(self, _LocalClientBuilder):
                # Address might be set for LocalClientBuilder if ray.client()
                # is called while ray_current_cluster is set
                # (see _get_builder_from_address). In this case,
                # leave off the ray:// so the user attaches the driver directly
                replacement_args.append(f'"{self.address}"')
            else:
                replacement_args.append(f'"ray://{self.address}"')
        if namespace:
            replacement_args.append(f'namespace="{namespace}"')
        if runtime_env:
            # Use a placeholder here, since the real runtime_env would be
            # difficult to read if formatted in directly
            replacement_args.append("runtime_env=<your_runtime_env>")
        args_str = ", ".join(replacement_args)
        replacement_call = f"ray.init({args_str})"

        # Note: stack level is set to 3 since we want the warning to reach the
        # call to ray.client(...).connect(). The intervening frames are
        # connect() -> client_deprecation_warn() -> warnings.warn()
        # https://docs.python.org/3/library/warnings.html#available-functions
        warnings.warn(
            "Starting a connection through `ray.client` will be deprecated "
            "in future ray versions in favor of `ray.init`. See the docs for "
            f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
            "`ray.client().connect()` with the following:\n"
            f"      {replacement_call}\n",
            DeprecationWarning,
            stacklevel=3,
        )
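A minimal usage sketch of the builder above. The address, namespace, and runtime env are hypothetical, and in practice the builder is normally obtained through `ray.client(...)` rather than constructed directly.

# Hypothetical: connect to a Ray Client server with a namespace and runtime env.
ctx = (
    ClientBuilder("127.0.0.1:10001")
    .namespace("my_namespace")
    .env({"pip": ["pip-install-test==0.5"]})
    .connect()
)
print(ctx.python_version, ctx.ray_version)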