def test_proxy_manager_lifecycle(shutdown_only):
    """
    Exercise the lifetime of one SpecificServer owned by a ProxyManager.

    Properties checked:
    1. The SpecificServer is created on the first free port.
    2. The SpecificServer comes alive and a log file mentions its port.
    3. The SpecificServer process exits when no client connects.
    4. The ProxyManager gets the port of the exited server back.
    """
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)

    client_id = "client1"
    pm.create_specific_server(client_id)
    started = pm.start_specific_server(client_id, JobConfig())
    assert started

    # The channel must become ready, i.e. a live server is listening behind it.
    channel = pm.get_channel(client_id)
    grpc.channel_ready_future(channel).result(timeout=5)

    server = pm._get_server_for_client(client_id)
    assert server.port == free_ports[0], f"Free Ports are: {free_ports}"

    # A log file whose name contains the port must exist in the session logs.
    log_pattern = os.path.join(pm.node.get_session_dir_path(), "logs",
                               "ray_client_server*")
    log_files = glob(log_pattern)
    assert any(str(free_ports[0]) in name for name in log_files)

    # Block (up to 10s) until the server process has exited.
    server_handle = server.process_handle_future.result()
    server_handle.process.wait(10)

    # Wait for reconcile loop
    time.sleep(2)
    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == free_ports[1]
def test_pip_job_config(shutdown_only, pip_as_str, tmp_path):
    """Check that pip packages named in the job's runtime env get installed.

    When ``pip_as_str`` is truthy the requirements are passed as the path of
    a requirements file written under ``tmp_path``; otherwise they are passed
    as a list of requirement strings.
    """
    if not pip_as_str:
        runtime_env = {"pip": ["pip-install-test==0.5"]}
    else:
        req_dir = tmp_path / "pip_requirements"
        req_dir.mkdir()
        req_file = req_dir / "requirements.txt"
        # NOTE(review): literal reproduced exactly as it appears in this view;
        # confirm its line breaks against the original file.
        req_file.write_text(""" pip-install-test==0.5 """)
        runtime_env = {"pip": str(req_file)}

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    # The package must NOT be importable on the test machine itself.
    with pytest.raises(ModuleNotFoundError):
        import pip_install_test  # noqa

    # ...but it must be importable inside the task.
    assert ray.get(f.remote())
def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
    """Push every runtime env package the cluster does not have yet.

    For each URI reported by the job config, the cluster is asked whether the
    package exists. Missing packages are built locally (unless the local file
    is already present) from the working directory and modules named in the
    job config, and then pushed.

    Args:
        job_config (JobConfig): The job config of driver.
    """
    assert _internal_kv_initialized()

    for pkg_uri in job_config.get_runtime_env_uris():
        if package_exists(pkg_uri):
            continue

        runtime_env = job_config.runtime_env
        working_dir = runtime_env.get("working_dir")
        py_modules = runtime_env.get("py_modules")
        excludes = runtime_env.get("excludes") or []

        local_path = _get_local_path(pkg_uri)
        local_pkg = Path(local_path)
        logger.info(f"{pkg_uri} doesn't exist. Create new package with"
                    f" {working_dir} and {py_modules}")
        # Only build the archive when it is not already on local disk.
        if not local_pkg.exists():
            create_project_package(working_dir, py_modules, excludes,
                                   local_path)
        # Push the data to remote storage
        pushed_bytes = push_package(pkg_uri, local_pkg)
        logger.info(f"{pkg_uri} has been pushed with {pushed_bytes} bytes")
def test_pip_job_config(shutdown_only, pip_as_str):
    """Check that pip packages named in the job's runtime env get installed.

    The Ray wheel path is included in the requirements. With ``pip_as_str``
    the requirements are one string, otherwise a list of requirement strings.
    """
    ray_wheel_path = os.path.join("/ray/.whl", get_wheel_filename())

    if not pip_as_str:
        pip_spec = [
            ray_wheel_path,
            "pip-install-test==0.5",
            "opentelemetry-api==1.0.0rc1",
            "opentelemetry-sdk==1.0.0rc1",
        ]
    else:
        # NOTE(review): literal reproduced exactly as it appears in this view;
        # confirm its line breaks against the original file.
        pip_spec = f""" {ray_wheel_path} pip-install-test==0.5 opentelemetry-api==1.0.0rc1 opentelemetry-sdk==1.0.0rc1 """
    runtime_env = {"pip": pip_spec}

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    # The package must NOT be importable on the test machine itself.
    with pytest.raises(ModuleNotFoundError):
        import pip_install_test  # noqa

    assert ray.get(f.remote())
def test_prepare_runtime_init_req_modified_job():
    """
    Verify that `prepare_runtime_init_req` passes the extracted JobConfig
    through `ray_client_server_env_prep`, writes the modified config back
    into the request, and leaves `ray_init_kwargs` unchanged.
    """
    original_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc"
    )
    inner_init = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(original_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False}),
    )
    init_req = ray_client_pb2.DataRequest(init=inner_init)

    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config

    # Swap in the namespace-rewriting hook only for the duration of the call.
    with patch.object(proxier, "ray_client_server_env_prep", modify_namespace):
        req, new_config = proxier.prepare_runtime_init_req(init_req)

    assert new_config.ray_namespace == "test_value"
    config_in_request = pickle.loads(req.init.job_config)
    assert config_in_request.serialize() == new_config.serialize()
    kwargs_in_request = json.loads(req.init.ray_init_kwargs)
    assert kwargs_in_request == {"log_to_driver": False}
def test_conda_create_job_config(shutdown_only):
    """Check that a conda env described in the JobConfig's runtime env is
    created on demand and used by tasks.

    The env pins the interpreter to the running Python's micro version and
    installs the Ray wheel plus the test packages through pip.
    """
    # E.g. 3.6.13
    python_micro_version_dots = ".".join(
        str(part) for part in sys.version_info[:3])
    ray_wheel_path = os.path.join("/ray/.whl", get_wheel_filename())

    pip_section = {
        "pip": [
            ray_wheel_path,
            "pip-install-test==0.5",
            "opentelemetry-api==1.0.0rc1",
            "opentelemetry-sdk==1.0.0rc1",
        ]
    }
    conda_dependencies = [
        f"python={python_micro_version_dots}",
        "pip",
        pip_section,
    ]
    runtime_env = {"conda": {"dependencies": conda_dependencies}}

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    # The package must NOT be importable on the test machine itself.
    with pytest.raises(ModuleNotFoundError):
        import pip_install_test  # noqa

    assert ray.get(f.remote())
def test_proxy_manager_lifecycle(shutdown_only):
    """
    Exercise the lifetime of one SpecificServer owned by a ProxyManager.

    Properties checked:
    1. The SpecificServer is created on the first port (45000).
    2. The SpecificServer comes alive.
    3. The SpecificServer process exits when no client connects.
    4. The ProxyManager gets the port of the exited server back.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"

    redis_address = ray_instance["redis_address"]
    session_dir = ray_instance["session_dir"]
    pm = proxier.ProxyManager(redis_address, session_dir)
    pm._free_ports = [45000, 45001]

    client_id = "client1"
    started = pm.start_specific_server(client_id, JobConfig())
    assert started

    # The channel must become ready, i.e. a live server is listening behind it.
    channel = pm.get_channel(client_id)
    grpc.channel_ready_future(channel).result(timeout=5)

    server = pm._get_server_for_client(client_id)
    assert server.port == 45000

    # Block (up to 10s) until the server process has exited.
    server_handle = server.process_handle()
    server_handle.process.wait(10)

    # Wait for reconcile loop
    time.sleep(2)
    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == 45001
def test_conda_create_job_config(shutdown_only):
    """Check that a conda env described in the JobConfig's runtime env is
    created on demand and used by tasks."""
    pip_section = {
        "pip": [
            "pip-install-test==0.5",
            "opentelemetry-api==1.0.0rc1",
            "opentelemetry-sdk==1.0.0rc1",
        ]
    }
    runtime_env = {"conda": {"dependencies": ["pip", pip_section]}}

    ray.init(job_config=JobConfig(runtime_env=runtime_env))

    @ray.remote
    def f():
        import pip_install_test  # noqa
        return True

    # The package must NOT be importable on the test machine itself.
    with pytest.raises(ModuleNotFoundError):
        import pip_install_test  # noqa

    assert ray.get(f.remote())
def test_default_actor_lifetime(default_actor_lifetime, child_actor_lifetime):
    """Check which lifetime a child actor gets once its owner is killed.

    The effective lifetime is ``child_actor_lifetime`` when given, otherwise
    the job-level ``default_actor_lifetime``. A "detached" child must still
    answer after the owner is killed; any other lifetime must make the child
    die within the timeout.

    Ray is always shut down at the end, even when an assertion fails, so a
    failing case cannot leave a running cluster behind for the next one.
    """

    @ray.remote
    class OwnerActor:
        def create_child_actor(self, child_actor_lifetime):
            if child_actor_lifetime is None:
                self._child_actor = ChildActor.remote()
            else:
                self._child_actor = ChildActor.options(
                    lifetime=child_actor_lifetime).remote()
            assert "ok" == ray.get(self._child_actor.ready.remote())
            return self._child_actor

        def get_pid(self):
            return os.getpid()

        def ready(self):
            return "ok"

    @ray.remote
    class ChildActor:
        def ready(self):
            return "ok"

    if default_actor_lifetime is not None:
        ray.init(job_config=JobConfig(
            default_actor_lifetime=default_actor_lifetime))
    else:
        ray.init()

    try:
        # 1. create owner and invoke create_child_actor.
        owner = OwnerActor.remote()
        child = ray.get(owner.create_child_actor.remote(child_actor_lifetime))
        assert "ok" == ray.get(child.ready.remote())

        # 2. Kill owner and make sure it's dead.
        owner_pid = ray.get(owner.get_pid.remote())
        os.kill(owner_pid, SIGKILL)
        wait_for_pid_to_exit(owner_pid)

        # 3. Assert child state.
        def is_child_actor_dead():
            try:
                ray.get(child.ready.remote())
                return False
            except RayActorError:
                return True

        # The per-actor option overrides the job-level default.
        actual_lifetime = default_actor_lifetime
        if child_actor_lifetime is not None:
            actual_lifetime = child_actor_lifetime

        assert actual_lifetime is not None
        if actual_lifetime == "detached":
            # Give the child time to die (it must not) before probing it.
            time.sleep(5)
            assert not is_child_actor_dead()
        else:
            wait_for_condition(is_child_actor_dead, timeout=5)
    finally:
        ray.shutdown()
def test_prepare_runtime_init_req_no_modification():
    """
    Verify that `prepare_runtime_init_req` extracts the JobConfig from the
    request and returns a DataRequest that still carries an equivalent
    config and the same `ray_init_kwargs`.
    """
    job_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc")
    inner_init = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False}))
    init_req = ray_client_pb2.DataRequest(init=inner_init)

    req, new_config = proxier.prepare_runtime_init_req(init_req)

    assert new_config.serialize() == job_config.serialize()
    assert isinstance(req, ray_client_pb2.DataRequest)
    config_in_request = pickle.loads(req.init.job_config)
    assert config_in_request.serialize() == new_config.serialize()
    kwargs_in_request = json.loads(req.init.ray_init_kwargs)
    assert kwargs_in_request == {"log_to_driver": False}
def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool:
    """
    Start up a RayClient Server for an incoming client to
    communicate with. Returns whether creation was successful.

    The SpecificServer entry for ``client_id`` must already exist; this
    method launches the process for it and publishes the process handle
    through ``specific_server.set_result``.

    Args:
        client_id: Identifier of the client the server is started for.
        job_config: Job config whose serialized runtime env determines the
            runtime env context handed to the server process.

    Returns:
        True if the launched process is still running (``poll()`` is None)
        when this method returns, False otherwise.
    """
    specific_server = self._get_server_for_client(client_id)
    assert specific_server, f"Server has not been created for: {client_id}"

    # Per-port log files that receive the server's stdout and stderr.
    output, error = self.node.get_log_file_handles(
        f"ray_client_server_{specific_server.port}", unique=True
    )

    serialized_runtime_env = job_config.get_serialized_runtime_env()
    if not serialized_runtime_env or serialized_runtime_env == "{}":
        # Empty runtime env: use a default context without calling
        # `_create_runtime_env`.
        # TODO(edoakes): can we just remove this case and always send it
        # to the agent?
        serialized_runtime_env_context = RuntimeEnvContext().serialize()
    else:
        serialized_runtime_env_context = self._create_runtime_env(
            serialized_runtime_env=serialized_runtime_env,
            specific_server=specific_server,
        )

    proc = start_ray_client_server(
        self.address,
        self.node.node_ip_address,
        specific_server.port,
        stdout_file=output,
        stderr_file=error,
        fate_share=self.fate_share,
        server_type="specific-server",
        serialized_runtime_env_context=serialized_runtime_env_context,
        redis_password=self._redis_password,
    )

    # Wait for the process being run transitions from the shim process
    # to the actual RayClient Server.
    pid = proc.process.pid
    if sys.platform != "win32":
        psutil_proc = psutil.Process(pid)
    else:
        psutil_proc = None
    # Don't use `psutil` on Win32
    # On win32 `psutil_proc` is None, so the wait loop below is skipped.
    while psutil_proc is not None:
        if proc.process.poll() is not None:
            # The process already exited: stop waiting and report failure
            # through the return value.
            logger.error(f"SpecificServer startup failed for client: {client_id}")
            break
        cmd = psutil_proc.cmdline()
        if _match_running_client_server(cmd):
            break
        logger.debug("Waiting for Process to reach the actual client server.")
        time.sleep(0.5)
    # Publish the handle whether or not startup succeeded.
    specific_server.set_result(proc)
    logger.info(
        f"SpecificServer started on port: {specific_server.port} "
        f"with PID: {pid} for client: {client_id}"
    )
    return proc.process.poll() is None
def start_specific_server(self, client_id: str,
                          job_config: JobConfig) -> bool:
    """
    Start up a RayClient Server for an incoming client to
    communicate with. Returns whether creation was successful.

    The SpecificServer entry for ``client_id`` must already exist; this
    method sets up the working dir, launches the process and publishes the
    process handle through ``specific_server.set_result``.

    Args:
        client_id: Identifier of the client the server is started for.
        job_config: Job config whose serialized runtime env is parsed as
            JSON and forwarded to the server process.

    Returns:
        True if the launched process is still running (``poll()`` is None)
        when this method returns, False otherwise.
    """
    specific_server = self._get_server_for_client(client_id)
    assert specific_server, f"Server has not been created for: {client_id}"

    # Per-port log files that receive the server's stdout and stderr.
    output, error = self.node.get_log_file_handles(
        f"ray_client_server_{specific_server.port}", unique=True)

    serialized_runtime_env = job_config.get_serialized_runtime_env()
    runtime_env = json.loads(serialized_runtime_env)

    # Set up the working_dir for the server.
    # TODO(edoakes): this should go be unified with the worker setup code
    # by going through the runtime_env agent.
    context = RuntimeEnvContext(
        env_vars=runtime_env.get("env_vars"),
        resources_dir=self.node.get_runtime_env_dir_path())
    working_dir_pkg.setup_working_dir(runtime_env, context)

    proc = start_ray_client_server(
        self.redis_address,
        specific_server.port,
        stdout_file=output,
        stderr_file=error,
        fate_share=self.fate_share,
        server_type="specific-server",
        serialized_runtime_env=serialized_runtime_env,
        serialized_runtime_env_context=context.serialize(),
        redis_password=self._redis_password)

    # Wait for the process being run transitions from the shim process
    # to the actual RayClient Server.
    pid = proc.process.pid
    if sys.platform != "win32":
        psutil_proc = psutil.Process(pid)
    else:
        psutil_proc = None
    # Don't use `psutil` on Win32
    # On win32 `psutil_proc` is None, so the wait loop below is skipped.
    while psutil_proc is not None:
        if proc.process.poll() is not None:
            # The process already exited: stop waiting and report failure
            # through the return value.
            logger.error(
                f"SpecificServer startup failed for client: {client_id}")
            break
        cmd = psutil_proc.cmdline()
        if _match_running_client_server(cmd):
            break
        logger.debug(
            "Waiting for Process to reach the actual client server.")
        time.sleep(0.5)
    # Publish the handle whether or not startup succeeded.
    specific_server.set_result(proc)
    logger.info(f"SpecificServer started on port: {specific_server.port} "
                f"with PID: {pid} for client: {client_id}")
    return proc.process.poll() is None
def test_controller_starts_java_replica(shutdown_only):  # noqa: F811
    """Deploy a Java echo deployment through the Serve controller and call it
    both through the raw replica actor handle and through a Serve handle."""
    ray.init(
        num_cpus=8,
        namespace="default_test_namespace",
        # A dummy code search path to enable cross language.
        job_config=JobConfig(code_search_path=["."]),
    )
    client = serve.start(detached=True)
    controller = client._controller

    deployment_config = DeploymentConfig()
    deployment_config.deployment_language = JAVA
    deployment_config.is_cross_language = True

    replica_config = ReplicaConfig.create(
        "io.ray.serve.util.ExampleEchoDeployment",
        init_args=["my_prefix "],
    )

    # Deploy it
    deployment_name = "my_java"
    deploy_ref = controller.deploy.remote(
        name=deployment_name,
        deployment_config_proto_bytes=deployment_config.to_proto_bytes(),
        replica_config_proto_bytes=replica_config.to_proto_bytes(),
        route_prefix=None,
        deployer_job_id=ray.get_runtime_context().job_id,
    )
    updating = ray.get(deploy_ref)
    assert updating
    client._wait_for_deployment_healthy(deployment_name)

    # Let's try to call it!
    all_handles = ray.get(controller._all_running_replicas.remote())
    replica_handle = all_handles["my_java"][0].actor_handle
    metadata_bytes = RequestMetadata(
        request_id="id-1",
        endpoint="endpoint",
        call_method="call",
    ).SerializeToString()
    body_bytes = RequestWrapper(
        body=msgpack_serialize("hello")).SerializeToString()
    reply_ref = replica_handle.handleRequest.remote(metadata_bytes, body_bytes)
    assert ray.get(reply_ref) == "my_prefix hello"

    # Same deployment, this time through the Serve handle API.
    handle = serve.get_deployment("my_java").get_handle()
    handle_reply_ref = handle.remote("hello handle")
    assert ray.get(handle_reply_ref) == "my_prefix hello handle"

    ray.get(controller.delete_deployment.remote(deployment_name))
    client._wait_for_deployment_deleted(deployment_name)
def prepare_runtime_init_req(
        req: ray_client_pb2.InitRequest
) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.

    Returns the request (unchanged here) together with the JobConfig that was
    unpickled from ``req.job_config``, or a default JobConfig when that field
    is empty.
    """
    import pickle

    job_config = JobConfig()
    serialized_config = req.job_config
    if serialized_config:
        # NOTE(review): this unpickles bytes taken from the request; pickle
        # is unsafe on untrusted input, so confirm the sender is trusted.
        job_config = pickle.loads(serialized_config)
    return req, job_config
def _server_init(
    self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
):
    """Initialize the server.

    When a job config is given, its py_modules and working_dir are uploaded
    if needed, the "excludes" entry is dropped, and the pickled config is sent
    with the init request. gRPC errors are re-raised as decoded exceptions.

    Raises:
        ConnectionAbortedError: If the server's init response is not ok.
    """
    if ray_init_kwargs is None:
        ray_init_kwargs = {}

    try:
        serialized_job_config = None
        if job_config is not None:
            with tempfile.TemporaryDirectory() as tmp_dir:
                env = job_config.runtime_env or {}
                env = upload_py_modules_if_needed(env, tmp_dir, logger=logger)
                env = upload_working_dir_if_needed(env, tmp_dir, logger=logger)
                # Remove excludes, it isn't relevant after the upload step.
                env.pop("excludes", None)
                job_config.set_runtime_env(env, validate=True)

            serialized_job_config = pickle.dumps(job_config)

        init_request = ray_client_pb2.InitRequest(
            job_config=serialized_job_config,
            ray_init_kwargs=json.dumps(ray_init_kwargs),
            reconnect_grace_period=self._reconnect_grace_period,
        )
        response = self.data_client.Init(init_request)
        if not response.ok:
            raise ConnectionAbortedError(
                f"Initialization failure from server:\n{response.msg}"
            )
    except grpc.RpcError as e:
        raise decode_exception(e)
def test_job_config_conda_env(conda_envs):
    """Check that a task runs inside the conda env named in the JobConfig.

    For each TensorFlow version a fresh Ray instance is started with the
    runtime env ``{"conda": "tf-<version>"}`` and the task must report that
    version. Ray is shut down after every iteration, even when the assertion
    fails, so the next ``ray.init`` is not affected by a failed one.
    """
    import tensorflow as tf

    # NOTE(review): `tf` is captured from this function's scope; presumably
    # it resolves to the conda env's TensorFlow in the worker - confirm.
    @ray.remote
    def get_conda_env():
        return tf.__version__

    for tf_version in ["2.2.0", "2.3.0"]:
        runtime_env = {"conda": f"tf-{tf_version}"}
        ray.init(job_config=JobConfig(runtime_env=runtime_env))
        try:
            assert ray.get(get_conda_env.remote()) == tf_version
        finally:
            ray.shutdown()
def rewrite_runtime_env_uris(job_config: JobConfig) -> None:
    """Fill in the runtime env URIs of ``job_config`` when they are missing.

    Nothing happens when the runtime env already has a non-None "uris" entry,
    or when neither "working_dir" nor "py_modules" is set. Otherwise a single
    GCS URI is set, built from the project package name of the working
    directory, modules and excludes.

    Args:
        job_config (JobConfig): The job config.
    """
    runtime_env = job_config.runtime_env

    # For now, we only support local directory and packages
    if runtime_env.get("uris") is not None:
        return

    working_dir = runtime_env.get("working_dir")
    py_modules = runtime_env.get("py_modules")
    if not (working_dir or py_modules):
        return

    excludes = runtime_env.get("excludes")
    if excludes is None:
        excludes = []
    pkg_name = get_project_package_name(working_dir, py_modules, excludes)
    gcs_uri = Protocol.GCS.value + "://" + pkg_name
    job_config.set_runtime_env_uris([gcs_uri])
def start_specific_server(self, client_id: str,
                          job_config: JobConfig) -> bool:
    """
    Start up a RayClient Server for an incoming client to
    communicate with. Returns whether creation was successful.

    A SpecificServer entry (port, pending process-handle future and gRPC
    channel) is registered for ``client_id`` before the process is launched;
    the future is resolved with the process handle once the startup wait ends.

    Args:
        client_id: Identifier of the client the server is started for.
        job_config: Job config whose serialized runtime env is forwarded to
            the server process.

    Returns:
        True if the launched process is still running (``poll()`` is None)
        when this method returns, False otherwise.
    """
    with self.server_lock:
        port = self._get_unused_port()
        handle_ready = futures.Future()
        specific_server = SpecificServer(
            port=port,
            process_handle_future=handle_ready,
            channel=grpc.insecure_channel(
                f"localhost:{port}", options=GRPC_OPTIONS))
        self.servers[client_id] = specific_server

    serialized_runtime_env = job_config.get_serialized_runtime_env()

    proc = start_ray_client_server(
        self.redis_address,
        port,
        fate_share=self.fate_share,
        server_type="specific-server",
        serialized_runtime_env=serialized_runtime_env,
        session_dir=self._get_session_dir())

    # Wait for the process being run transitions from the shim process
    # to the actual RayClient Server.
    pid = proc.process.pid
    if sys.platform != "win32":
        psutil_proc = psutil.Process(pid)
    else:
        psutil_proc = None
    # Don't use `psutil` on Win32
    # On win32 `psutil_proc` is None, so the wait loop below is skipped.
    while psutil_proc is not None:
        if proc.process.poll() is not None:
            # The process already exited: stop waiting and report failure
            # through the return value.
            logger.error(
                f"SpecificServer startup failed for client: {client_id}")
            break
        cmd = psutil_proc.cmdline()
        # The command line is matched by its third element being the client
        # server module name.
        if len(cmd) > 3 and cmd[2] == "ray.util.client.server":
            break
        logger.debug(
            "Waiting for Process to reach the actual client server.")
        time.sleep(0.5)
    # Resolve the future whether or not startup succeeded.
    handle_ready.set_result(proc)
    logger.info(f"SpecificServer started on port: {port} with PID: {pid} "
                f"for client: {client_id}")
    return proc.process.poll() is None
def test_prepare_runtime_init_req_modified_job():
    """
    Check that `prepare_runtime_init_req` properly extracts the JobConfig and
    modifies it according to `ray_client_server_env_prep`.

    The hook on the `proxier` module is replaced only for the duration of the
    call and restored afterwards, so the replacement cannot leak into other
    tests that run in the same process.
    """
    job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
    init_req = ray_client_pb2.DataRequest(init=ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config)))

    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config

    original_prep = proxier.ray_client_server_env_prep
    proxier.ray_client_server_env_prep = modify_namespace
    try:
        req, new_config = proxier.prepare_runtime_init_req(iter([init_req]))
    finally:
        # Always put the real hook back, even if the call above raised.
        proxier.ray_client_server_env_prep = original_prep

    assert new_config.ray_namespace == "test_value"
    assert pickle.loads(
        req.init.job_config).serialize() == new_config.serialize()
def rewrite_working_dir_uri(job_config: JobConfig) -> None:
    """Fill in the "working_dir_uri" entry of the job's runtime env.

    Nothing happens when "working_dir_uri" is already set to a truthy value,
    or when neither "working_dir" nor "local_modules" is set. Otherwise the
    entry becomes a GCS URI built from the project package name of the
    working directory and modules.

    Args:
        job_config (JobConfig): The job config.
    """
    runtime_env = job_config.runtime_env

    # For now, we only support local directory and packages
    working_dir = runtime_env.get("working_dir")
    required_modules = runtime_env.get("local_modules")

    if runtime_env.get("working_dir_uri"):
        return
    if not (working_dir or required_modules):
        return

    pkg_name = get_project_package_name(working_dir, required_modules)
    runtime_env["working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
def test_proxy_manager_bad_startup(shutdown_only):
    """
    Check that a SpecificServer which fails to start (its JobConfig names a
    conda env that does not exist) is cleaned up: no channel remains for the
    client and both ports are free again.
    """
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    proxier.CHECK_CHANNEL_TIMEOUT_S = 1
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)

    client_id = "client1"
    pm.create_specific_server(client_id)
    bad_config = JobConfig(
        runtime_env={"conda": "conda-env-that-sadly-does-not-exist"})
    started = pm.start_specific_server(client_id, bad_config)
    assert not started

    # Wait for reconcile loop
    time.sleep(2)
    assert pm.get_channel(client_id) is None
    assert len(pm._free_ports) == 2
def prepare_runtime_init_req(iterator: Iterator[ray_client_pb2.DataRequest]
                             ) -> Tuple[ray_client_pb2.DataRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.

    The first message of ``iterator`` is consumed and must be of type "init".
    Its job config (a default one when the field is empty) is passed through
    ``ray_client_server_env_prep`` and the result is written back into the
    message's init request.

    Returns:
        The consumed (and updated) first message and the prepared JobConfig.
    """
    first_msg = next(iterator)
    msg_type = first_msg.WhichOneof("type")
    assert msg_type == "init", ("Received initial message of type "
                                f"{msg_type}, not 'init'.")

    job_config = JobConfig()
    serialized_config = first_msg.init.job_config
    if serialized_config:
        job_config = pickle.loads(serialized_config)

    new_job_config = ray_client_server_env_prep(job_config)
    replacement_init = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(new_job_config))
    first_msg.init.CopyFrom(replacement_init)
    return first_msg, new_job_config
def test_proxy_manager_lifecycle(shutdown_only):
    """
    Creates a ProxyManager and tests basic handling of the lifetime of a
    specific RayClient Server. It checks the following properties:
    1. The SpecificServer is created using the first port.
    2. The SpecificServer comes alive and has a log associated with it.
    3. The SpecificServer destructs itself when no client connects.
    4. The ProxyManager returns the port of the destructed SpecificServer.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm = proxier.ProxyManager(
        ray_instance["redis_address"],
        session_dir=ray_instance["session_dir"])
    # NOTE: We use different ports between runs because sometimes the port is
    # not released, introducing flakiness.
    # `random.sample` draws without replacement, so the two ports are always
    # distinct; drawing with replacement could return the same port twice and
    # make the port assertions below meaningless.
    port_one, port_two = random.sample(range(45000, 45100), k=2)
    pm._free_ports = [port_one, port_two]
    client = "client1"

    pm.create_specific_server(client)
    assert pm.start_specific_server(client, JobConfig())
    # Channel should be ready and corresponding to an existing server
    grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)

    proc = pm._get_server_for_client(client)
    assert proc.port == port_one, f"Free Ports are: [{port_one}, {port_two}]"

    # A log file whose name contains the port must exist in the session logs.
    log_files_path = os.path.join(pm.node.get_session_dir_path(), "logs",
                                  "ray_client_server*")
    files = glob(log_files_path)
    assert any(str(port_one) in f for f in files)

    # Block (up to 10s) until the server process has exited.
    proc.process_handle_future.result().process.wait(10)
    # Wait for reconcile loop
    time.sleep(2)

    assert len(pm._free_ports) == 2
    assert pm._get_unused_port() == port_two
def test_proxy_manager_bad_startup(shutdown_only):
    """
    Check that a SpecificServer which fails to start (its JobConfig names a
    conda env that does not exist) is cleaned up: no channel remains for the
    client and both ports are free again.
    """
    ray_instance = ray.init()
    proxier.CHECK_PROCESS_INTERVAL_S = 1
    proxier.CHECK_CHANNEL_TIMEOUT_S = 1

    redis_address = ray_instance["redis_address"]
    session_dir = ray_instance["session_dir"]
    pm = proxier.ProxyManager(redis_address, session_dir)
    pm._free_ports = [46000, 46001]

    client_id = "client1"
    bad_config = JobConfig(
        runtime_env={"conda": "conda-env-that-sadly-does-not-exist"})
    started = pm.start_specific_server(client_id, bad_config)
    assert not started

    # Wait for reconcile loop
    time.sleep(2)
    assert pm.get_channel(client_id) is None
    assert len(pm._free_ports) == 2
def upload_runtime_env_package_if_needed(
        self,
        job_config: JobConfig,
        logger: Optional[logging.Logger] = default_logger):
    """Push every runtime env package the cluster does not have yet.

    For each URI reported by the job config, the cluster is asked whether the
    package exists. Missing packages are built locally (unless the local file
    is already present) from the working directory and modules named in the
    job config, and then pushed.

    Args:
        job_config (JobConfig): The job config of driver.
        logger: Logger for progress messages; ``None`` falls back to the
            default logger.
    """
    if logger is None:
        logger = default_logger

    pkg_uris = job_config.get_runtime_env_uris()
    if len(pkg_uris) == 0:
        return  # Return early to avoid internal kv check in this case.

    for pkg_uri in pkg_uris:
        if package_exists(pkg_uri):
            continue

        runtime_env = job_config.runtime_env
        working_dir = runtime_env.get("working_dir")
        py_modules = runtime_env.get("py_modules")
        excludes = runtime_env.get("excludes") or []

        local_path = self._get_local_path(pkg_uri)
        local_pkg = Path(local_path)
        logger.info(f"{pkg_uri} doesn't exist. Create new package with"
                    f" {working_dir} and {py_modules}")
        # Only build the archive when it is not already on local disk.
        if not local_pkg.exists():
            create_project_package(
                working_dir,
                py_modules,
                excludes,
                local_path,
                logger=logger)
        # Push the data to remote storage
        pushed_bytes = push_package(pkg_uri, local_pkg)
        logger.info(f"{pkg_uri} has been pushed with {pushed_bytes} bytes")
def connect(
    self,
    conn_str: str,
    job_config: JobConfig = None,
    secure: bool = False,
    metadata: List[Tuple[str, str]] = None,
    connection_retries: int = 3,
    namespace: str = None,
    *,
    ignore_version: bool = False,
    _credentials: Optional[grpc.ChannelCredentials] = None,
    ray_init_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Connect the Ray Client to a server.

    Args:
        conn_str: Connection string, in the form "[host]:port"
        job_config: The job config of the server.
        secure: Whether to use a TLS secured gRPC channel
        metadata: gRPC metadata to send on connect
        connection_retries: number of connection attempts to make
        namespace: If not None, set as the Ray namespace on the job config
            (a new JobConfig is created when none was passed).
        ignore_version: whether to ignore Python or Ray version mismatches.
            This should only be used for debugging purposes.
        _credentials: gRPC channel credentials forwarded to the Worker.
        ray_init_kwargs: Extra init kwargs forwarded to the server init call.
            Its "logging_level" and "logging_format" entries, when not None,
            override the default logger settings used here.

    Returns:
        Dictionary of connection info, e.g., {"num_clients": 1}.
        Returns None when a client worker already exists and was
        connected through init.

    Raises:
        Exception: If a client worker already exists and was not connected
            through init. Any exception raised while connecting is re-raised
            after `disconnect()` has been called.
    """
    # Delay imports until connect to avoid circular imports.
    from ray.util.client.worker import Worker

    if self.client_worker is not None:
        if self._connected_with_init:
            return
        raise Exception("ray.init() called, but ray client is already connected")
    if not self._inside_client_test:
        # If we're calling a client connect specifically and we're not
        # currently in client mode, ensure we are.
        _explicitly_enable_client_mode()
    if namespace is not None:
        job_config = job_config or JobConfig()
        job_config.set_ray_namespace(namespace)

    # Start from the default logger settings; explicit init kwargs win.
    logging_level = ray_constants.LOGGER_LEVEL
    logging_format = ray_constants.LOGGER_FORMAT

    if ray_init_kwargs is not None:
        if ray_init_kwargs.get("logging_level") is not None:
            logging_level = ray_init_kwargs["logging_level"]
        if ray_init_kwargs.get("logging_format") is not None:
            logging_format = ray_init_kwargs["logging_format"]

    setup_logger(logging_level, logging_format)

    try:
        self.client_worker = Worker(
            conn_str,
            secure=secure,
            _credentials=_credentials,
            metadata=metadata,
            connection_retries=connection_retries,
        )
        self.api.worker = self.client_worker
        self.client_worker._server_init(job_config, ray_init_kwargs)
        conn_info = self.client_worker.connection_info()
        self._check_versions(conn_info, ignore_version)
        self._register_serializers()
        return conn_info
    except Exception:
        # Tear down the half-established connection before propagating.
        self.disconnect()
        raise
def __init__(self, address: Optional[str]) -> None:
    """Store the target address and start from a default JobConfig.

    Args:
        address: Address to remember for this instance; may be None.
    """
    self.address = address
    # Fresh default job config owned by this instance.
    self._job_config = JobConfig()
def test_proxy_manager_internal_kv(shutdown_only, with_specific_server):
    """
    Test that proxy manager can use internal kv with and without a
    SpecificServer and that once a SpecificServer is started up, it
    goes through it.

    With ``with_specific_server`` the in-process internal kv functions are
    patched to raise, so the KV calls can only succeed if they do not use
    them.
    """

    proxier.CHECK_PROCESS_INTERVAL_S = 1
    # The timeout has likely been set to 1 in an earlier test. Increase timeout
    # to wait for the channel to become ready.
    proxier.CHECK_CHANNEL_TIMEOUT_S = 5
    os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
    pm, free_ports = start_ray_and_proxy_manager(n_ports=2)
    client = "client1"

    task_servicer = proxier.RayletServicerProxy(None, pm)

    def make_internal_kv_calls():
        # First put of a new key: must not report an existing value.
        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val")
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert not response.already_exists

        # Second put without overwrite: key exists and the value is kept
        # (checked by the get below).
        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val2")
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert response.already_exists

        response = task_servicer.KVGet(ray_client_pb2.KVGetRequest(key=b"key"))
        assert isinstance(response, ray_client_pb2.KVGetResponse)
        assert response.value == b"val"

        # Put with overwrite=True: key exists and the value is replaced.
        response = task_servicer.KVPut(
            ray_client_pb2.KVPutRequest(key=b"key", value=b"val2", overwrite=True)
        )
        assert isinstance(response, ray_client_pb2.KVPutResponse)
        assert response.already_exists

        response = task_servicer.KVGet(ray_client_pb2.KVGetRequest(key=b"key"))
        assert isinstance(response, ray_client_pb2.KVGetResponse)
        assert response.value == b"val2"

    # Make the servicer resolve every call to our fixed client id.
    with patch(
        "ray.util.client.server.proxier._get_client_id_from_context"
    ) as mock_get_client_id:
        mock_get_client_id.return_value = client

        if with_specific_server:
            pm.create_specific_server(client)
            assert pm.start_specific_server(client, JobConfig())
            channel = pm.get_channel(client)
            assert channel is not None
            task_servicer.Init(
                ray_client_pb2.InitRequest(job_config=pickle.dumps(JobConfig()))
            )

            # Mock out the internal kv calls in this process to raise an
            # exception if they're called. This verifies that we are not
            # making any calls in the proxier if there is a SpecificServer
            # started up.
            with patch(
                "ray.experimental.internal_kv._internal_kv_put"
            ) as mock_put, patch(
                "ray.experimental.internal_kv._internal_kv_get"
            ) as mock_get, patch(
                "ray.experimental.internal_kv._internal_kv_initialized"
            ) as mock_initialized:
                mock_put.side_effect = Exception("This shouldn't be called!")
                mock_get.side_effect = Exception("This shouldn't be called!")
                mock_initialized.side_effect = Exception("This shouldn't be called!")
                make_internal_kv_calls()
        else:
            make_internal_kv_calls()
def modify_namespace(job_config: JobConfig):
    """Set the namespace of ``job_config`` to "test_value" and return the
    same (mutated) object."""
    job_config.set_ray_namespace("test_value")
    return job_config
class ClientBuilder:
    """
    Builder for a Ray Client connection. This class can be subclassed by
    custom builder classes to modify connection behavior to include additional
    features or altered semantics. One example is the ``_LocalClientBuilder``.
    """

    def __init__(self, address: Optional[str]) -> None:
        self.address = address
        self._job_config = JobConfig()
        # Extra keyword arguments forwarded to ray.init() on the server.
        self._remote_init_kwargs = {}
        # Whether to allow connections to multiple clusters
        # (allow_multiple=True).
        self._allow_multiple_connections = False
        self._credentials = None
        # Set to False if ClientBuilder is being constructed by internal
        # methods
        self._deprecation_warn_enabled = True

    def env(self, env: Dict[str, Any]) -> "ClientBuilder":
        """
        Set an environment for the session.

        Args:
            env (Dict[str, Any]): A runtime environment to use for this
                connection. See :ref:`runtime-environments` for what values
                are accepted in this dict.

        Returns:
            ClientBuilder: This builder, to allow call chaining.
        """
        self._job_config.set_runtime_env(env)
        return self

    def namespace(self, namespace: str) -> "ClientBuilder":
        """
        Sets the namespace for the session.

        Args:
            namespace (str): Namespace to use.

        Returns:
            ClientBuilder: This builder, to allow call chaining.
        """
        self._job_config.set_ray_namespace(namespace)
        return self

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientContext: Information about the connection. This includes
                the server's version of Python & Ray as well as the
                dashboard_url.

        Raises:
            ValueError: If a client is already connected with
                allow_multiple=True, no default client is connected, and
                this builder was not configured with allow_multiple=True.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # Should be done *after* the deprecation warning, since warning will
        # check if those values are already set.
        self._fill_defaults_from_env()

        # If it has already connected to the cluster with allow_multiple=True,
        # connect to the default one is not allowed.
        # But if it has connected to the default one, connect to other clients
        # with allow_multiple=True is allowed
        default_cli_connected = ray.util.client.ray.is_connected()
        has_cli_connected = ray.util.client.num_connected_contexts() > 0
        if (
            not self._allow_multiple_connections
            and not default_cli_connected
            and has_cli_connected
        ):
            raise ValueError(
                "The client has already connected to the cluster "
                "with allow_multiple=True. Please set allow_multiple=True"
                " to proceed"
            )

        old_ray_cxt = None
        if self._allow_multiple_connections:
            # Temporarily swap out the current context so that the new
            # connection does not replace it.
            old_ray_cxt = ray.util.client.ray.set_context(None)

        try:
            client_info_dict = ray.util.client_connect.connect(
                self.address,
                job_config=self._job_config,
                _credentials=self._credentials,
                ray_init_kwargs=self._remote_init_kwargs,
            )
            get_dashboard_url = ray.remote(ray.worker.get_dashboard_url)
            dashboard_url = ray.get(get_dashboard_url.options(num_cpus=0).remote())
            cxt = ClientContext(
                dashboard_url=dashboard_url,
                python_version=client_info_dict["python_version"],
                ray_version=client_info_dict["ray_version"],
                ray_commit=client_info_dict["ray_commit"],
                protocol_version=client_info_dict["protocol_version"],
                _num_clients=client_info_dict["num_clients"],
                _context_to_restore=ray.util.client.ray.get_context(),
            )
        finally:
            # Restore the previous context even when connecting fails, so a
            # failed allow_multiple connection does not discard it.
            if self._allow_multiple_connections:
                ray.util.client.ray.set_context(old_ray_cxt)
        return cxt

    def _fill_defaults_from_env(self):
        """
        Fill the namespace and runtime env from environment variables, but
        only for values that have not already been set on the job config.
        """
        # Check environment variables for default values
        namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
        if namespace_env_var and self._job_config.ray_namespace is None:
            self.namespace(namespace_env_var)

        runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
        if runtime_env_var and self._job_config.runtime_env is None:
            # The environment variable holds the runtime env as JSON.
            self.env(json.loads(runtime_env_var))

    def _init_args(self, **kwargs) -> "ClientBuilder":
        """
        When a client builder is constructed through ray.init, for example
        `ray.init(ray://..., namespace=...)`, all of the arguments passed into
        ray.init with non-default values are passed again into this method.
        Custom client builders can override this method to do their own
        handling/validation of arguments.

        Returns:
            ClientBuilder: This builder, to allow call chaining.

        Raises:
            RuntimeError: If any remaining keyword argument is not a
                parameter of ``ray_driver_init``.
        """
        # Use namespace and runtime_env from ray.init call
        if kwargs.get("namespace") is not None:
            self.namespace(kwargs.pop("namespace"))
        if kwargs.get("runtime_env") is not None:
            self.env(kwargs.pop("runtime_env"))

        # allow_multiple is consumed by the builder itself. An explicit False
        # must be consumed as well; otherwise it would be left among the
        # kwargs that are validated and forwarded to the server below.
        allow_multiple = kwargs.get("allow_multiple")
        if isinstance(allow_multiple, bool):
            self._allow_multiple_connections = allow_multiple
            del kwargs["allow_multiple"]

        if "_credentials" in kwargs:
            self._credentials = kwargs.pop("_credentials")

        if kwargs:
            # Everything left over must be a valid ray.init() argument.
            expected_sig = inspect.signature(ray_driver_init)
            extra_args = set(kwargs).difference(expected_sig.parameters)
            if extra_args:
                # Sorted so that the error message is deterministic.
                raise RuntimeError(
                    "Got unexpected kwargs: {}".format(", ".join(sorted(extra_args)))
                )
            self._remote_init_kwargs = kwargs
            forwarded = ", ".join(kwargs)
            logger.info(
                "Passing the following kwargs to ray.init() "
                f"on the server: {forwarded}"
            )
        return self

    def _client_deprecation_warn(self) -> None:
        """
        Generates a warning for users if this ClientBuilder instance was
        created directly or through ray.client, instead of relying on
        internal methods (ray.init, or auto init)
        """
        namespace = self._job_config.ray_namespace
        runtime_env = self._job_config.runtime_env
        replacement_args = []
        if self.address:
            if isinstance(self, _LocalClientBuilder):
                # Address might be set for LocalClientBuilder if ray.client()
                # is called while ray_current_cluster is set
                # (see _get_builder_from_address). In this case,
                # leave off the ray:// so the user attaches the driver directly
                replacement_args.append(f'"{self.address}"')
            else:
                replacement_args.append(f'"ray://{self.address}"')
        if namespace:
            replacement_args.append(f'namespace="{namespace}"')
        if runtime_env:
            # Use a placeholder here, since the real runtime_env would be
            # difficult to read if formatted in directly
            replacement_args.append("runtime_env=<your_runtime_env>")
        args_str = ", ".join(replacement_args)
        replacement_call = f"ray.init({args_str})"

        # Note: stack level is set to 3 since we want the warning to reach the
        # call to ray.client(...).connect(). The intervening frames are
        # connect() -> client_deprecation_warn() -> warnings.warn()
        # https://docs.python.org/3/library/warnings.html#available-functions
        warnings.warn(
            "Starting a connection through `ray.client` will be deprecated "
            "in future ray versions in favor of `ray.init`. See the docs for "
            f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
            "`ray.client().connect()` with the following:\n"
            f"      {replacement_call}\n",
            DeprecationWarning,
            stacklevel=3,
        )