def test_grpc_client_credentials_are_passed_to_channel(monkeypatch):
    """Explicit ``_credentials`` must reach ``grpc.secure_channel``.

    ``grpc.secure_channel`` is replaced by a factory returning a fake
    channel. The test relies on Worker calling ``subscribe`` on that
    channel; the fake raises there, carrying the credentials it was built
    with, so the test can inspect them.
    """

    class ChannelReached(Exception):
        # Transports the credentials observed by the fake channel.
        def __init__(self, credentials):
            self.credentials = credentials

    class FakeChannel:
        def __init__(self, conn_str, credentials, options, compression):
            self.credentials = credentials

        def subscribe(self, f):
            # Abort Worker construction here and surface the credentials.
            raise ChannelReached(self.credentials)

    def fake_secure_channel(conn_str, credentials, options=None,
                            compression=None):
        return FakeChannel(conn_str, credentials, options, compression)

    monkeypatch.setattr(grpc, "secure_channel", fake_secure_channel)

    # Credentials should be respected whether secure is set or not.
    for secure_flag in (False, True):
        with pytest.raises(ChannelReached) as exc_info:
            Worker(secure=secure_flag, _credentials=Credentials("test"))
        assert exc_info.value.credentials.name == "test"
def connect(self,
            conn_str: str,
            secure: bool = False,
            metadata: List[Tuple[str, str]] = None) -> None:
    """Attach this context to a Ray Client server.

    Args:
        conn_str: Address of the server, formatted as "[host]:port".
        secure: If True, open a TLS-secured gRPC channel.
        metadata: gRPC metadata sent along with the connection.

    Raises:
        Exception: if a client worker already exists and this context was
            not connected through ray.init().
    """
    # Imported lazily: importing at module load would be circular.
    from ray.util.client.worker import Worker
    import ray._private.client_mode_hook

    if self.client_worker is not None:
        # A worker already exists: silently keep it only when it was set up
        # by ray.init(); otherwise refuse a second connection.
        if not self._connected_with_init:
            raise Exception(
                "ray.connect() called, but ray client is already connected")
        return

    if not self._inside_client_test:
        # An explicit client connect implies client mode must be enabled.
        ray._private.client_mode_hook._explicitly_enable_client_mode()

    worker = Worker(conn_str, secure=secure, metadata=metadata)
    self.client_worker = worker
    self.api.worker = worker
def connect(self,
            conn_str: str,
            job_config: JobConfig = None,
            secure: bool = False,
            metadata: List[Tuple[str, str]] = None,
            connection_retries: int = 3,
            namespace: str = None,
            *,
            ignore_version: bool = False) -> Dict[str, Any]:
    """Connect the Ray Client to a server.

    Args:
        conn_str: Connection string, in the form "[host]:port"
        job_config: The job config of the server.
        secure: Whether to use a TLS secured gRPC channel
        metadata: gRPC metadata to send on connect
        connection_retries: number of connection attempts to make
        namespace: Ray namespace to set on ``job_config``. If no job_config
            was given, a default ``JobConfig()`` is created to carry it.
        ignore_version: whether to ignore Python or Ray version mismatches.
            This should only be used for debugging purposes.

    Returns:
        Dictionary of connection info, e.g., {"num_clients": 1}.
        Returns None without reconnecting if a client worker already exists
        and was created through ray.init().

    Raises:
        Exception: if a client worker already exists and was not created
            through ray.init(). Any exception raised while creating the
            worker, initializing the server or checking versions is
            re-raised after ``self.disconnect()`` has been called.
    """
    # Delay imports until connect to avoid circular imports.
    from ray.util.client.worker import Worker
    import ray._private.client_mode_hook
    if self.client_worker is not None:
        if self._connected_with_init:
            return
        raise Exception(
            "ray.connect() called, but ray client is already connected")
    if not self._inside_client_test:
        # If we're calling a client connect specifically and we're not
        # currently in client mode, ensure we are.
        ray._private.client_mode_hook._explicitly_enable_client_mode()
    if namespace is not None:
        if job_config is None:
            job_config = JobConfig()
        # NOTE(review): this mutates a caller-supplied job_config in place;
        # confirm callers do not reuse it expecting the previous namespace.
        job_config.set_ray_namespace(namespace)
    if job_config is not None:
        runtime_env = json.loads(job_config.get_serialized_runtime_env())
        # Per the warning text, pip/conda environments are installed before
        # connect() returns, so tell the user why it may be slow.
        if runtime_env.get("pip") or runtime_env.get("conda"):
            logger.warning("The 'pip' or 'conda' field was specified in "
                           "the runtime env, so it may take some time to "
                           "install the environment before ray.connect() "
                           "returns.")
    try:
        self.client_worker = Worker(conn_str,
                                    secure=secure,
                                    metadata=metadata,
                                    connection_retries=connection_retries)
        self.api.worker = self.client_worker
        self.client_worker._server_init(job_config)
        conn_info = self.client_worker.connection_info()
        self._check_versions(conn_info, ignore_version)
        self._register_serializers()
        return conn_info
    except Exception:
        # Tear down whatever was set up before propagating the error.
        self.disconnect()
        raise
def connect(self,
            conn_str: str,
            job_config: JobConfig = None,
            secure: bool = False,
            metadata: List[Tuple[str, str]] = None,
            connection_retries: int = 3,
            *,
            ignore_version: bool = False) -> Dict[str, Any]:
    """Open a Ray Client connection to a server and return its info.

    Args:
        conn_str: Server address, formatted as "[host]:port".
        job_config: Job config handed to the server during initialization.
        secure: Use a TLS-secured gRPC channel when True.
        metadata: gRPC metadata attached to the connection.
        connection_retries: How many connection attempts to make.
        ignore_version: Do not fail on Python or Ray version mismatches.
            Intended for debugging only.

    Returns:
        Connection info dictionary, e.g. {"num_clients": 1}; None if a
        worker created by ray.init() is already attached.

    Raises:
        Exception: when a worker that did not come from ray.init() is
            already attached. Errors raised during setup are re-raised
            after ``self.disconnect()`` has run.
    """
    # Lazy imports: importing these at module scope would be circular.
    from ray.util.client.worker import Worker
    import ray._private.client_mode_hook

    if self.client_worker is not None:
        if not self._connected_with_init:
            raise Exception(
                "ray.connect() called, but ray client is already connected")
        return

    if not self._inside_client_test:
        # Explicitly connecting as a client means client mode must be on.
        ray._private.client_mode_hook._explicitly_enable_client_mode()

    try:
        self.client_worker = Worker(
            conn_str,
            secure=secure,
            metadata=metadata,
            connection_retries=connection_retries,
        )
        self.api.worker = self.client_worker
        self.client_worker._server_init(job_config)
        info = self.client_worker.connection_info()
        self._check_versions(info, ignore_version)
        self._register_serializers()
    except Exception:
        # Tear down via disconnect() before letting the error propagate.
        self.disconnect()
        raise
    return info
def test_grpc_client_credentials_are_generated(monkeypatch):
    """A secure Worker without ``_credentials`` must build SSL credentials.

    ``grpc.ssl_channel_credentials`` is replaced with a stub that raises,
    so observing the exception proves Worker(secure=True) tried to generate
    credentials itself.
    """

    class CredentialsRequested(Exception):
        def __init__(self, result):
            self.result = result

    marker = "ssl_channel_credentials called"

    def fake_ssl_channel_credentials():
        raise CredentialsRequested(marker)

    monkeypatch.setattr(grpc, "ssl_channel_credentials",
                        fake_ssl_channel_credentials)

    with pytest.raises(CredentialsRequested) as exc_info:
        Worker(secure=True)
    assert exc_info.value.result == marker
def connect(self,
            conn_str: str,
            secure: bool = False,
            metadata: List[Tuple[str, str]] = None,
            connection_retries: int = 3,
            *,
            ignore_version: bool = False) -> Dict[str, Any]:
    """Connect the Ray Client to a server.

    Args:
        conn_str: Connection string, in the form "[host]:port"
        secure: Whether to use a TLS secured gRPC channel
        metadata: gRPC metadata to send on connect
        connection_retries: number of connection attempts to make
        ignore_version: whether to ignore Python or Ray version mismatches.
            This should only be used for debugging purposes.

    Returns:
        Dictionary of connection info, e.g., {"num_clients": 1}.
        Returns None without reconnecting if a client worker already exists
        and was created through ray.init().

    Raises:
        Exception: if a client worker already exists and was not created
            through ray.init(). Any exception raised while creating the
            worker or checking versions is re-raised after
            ``self.disconnect()`` has been called.
    """
    # Delay imports until connect to avoid circular imports.
    from ray.util.client.worker import Worker
    import ray._private.client_mode_hook
    if self.client_worker is not None:
        if self._connected_with_init:
            return
        raise Exception(
            "ray.connect() called, but ray client is already connected")
    if not self._inside_client_test:
        # If we're calling a client connect specifically and we're not
        # currently in client mode, ensure we are.
        ray._private.client_mode_hook._explicitly_enable_client_mode()
    try:
        self.client_worker = Worker(conn_str,
                                    secure=secure,
                                    metadata=metadata,
                                    connection_retries=connection_retries)
        self.api.worker = self.client_worker
        conn_info = self.client_worker.connection_info()
        self._check_versions(conn_info, ignore_version)
        return conn_info
    except Exception:
        # Tear down whatever was set up before propagating the error.
        self.disconnect()
        raise
def connect(
    self,
    conn_str: str,
    job_config: JobConfig = None,
    secure: bool = False,
    metadata: List[Tuple[str, str]] = None,
    connection_retries: int = 3,
    namespace: str = None,
    *,
    ignore_version: bool = False,
    _credentials: Optional[grpc.ChannelCredentials] = None,
    ray_init_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Connect the Ray Client to a server.

    Args:
        conn_str: Connection string, in the form "[host]:port"
        job_config: The job config of the server.
        secure: Whether to use a TLS secured gRPC channel
        metadata: gRPC metadata to send on connect
        connection_retries: number of connection attempts to make
        namespace: Ray namespace to set on ``job_config``. If no job_config
            was given, a default ``JobConfig()`` is created to carry it.
        ignore_version: whether to ignore Python or Ray version mismatches.
            This should only be used for debugging purposes.
        _credentials: Explicit gRPC channel credentials, forwarded to the
            Worker unchanged.
        ray_init_kwargs: Keyword arguments forwarded to the server-side
            initialization (``_server_init``). Non-None ``logging_level`` /
            ``logging_format`` entries also override the ray_constants
            defaults used to configure the local logger.

    Returns:
        Dictionary of connection info, e.g., {"num_clients": 1}.
        Returns None without reconnecting if a client worker already exists
        and was created through ray.init().

    Raises:
        Exception: if a client worker already exists and was not created
            through ray.init(). Any exception raised while creating the
            worker, initializing the server or checking versions is
            re-raised after ``self.disconnect()`` has been called.
    """
    # Delay imports until connect to avoid circular imports.
    from ray.util.client.worker import Worker

    if self.client_worker is not None:
        if self._connected_with_init:
            return
        raise Exception("ray.init() called, but ray client is already connected")
    if not self._inside_client_test:
        # If we're calling a client connect specifically and we're not
        # currently in client mode, ensure we are.
        _explicitly_enable_client_mode()
    if namespace is not None:
        if job_config is None:
            job_config = JobConfig()
        # NOTE(review): this mutates a caller-supplied job_config in place;
        # confirm callers do not reuse it expecting the previous namespace.
        job_config.set_ray_namespace(namespace)

    # Only explicit (non-None) logging settings override the defaults.
    init_kwargs = ray_init_kwargs or {}
    logging_level = init_kwargs.get("logging_level")
    if logging_level is None:
        logging_level = ray_constants.LOGGER_LEVEL
    logging_format = init_kwargs.get("logging_format")
    if logging_format is None:
        logging_format = ray_constants.LOGGER_FORMAT
    setup_logger(logging_level, logging_format)

    try:
        self.client_worker = Worker(
            conn_str,
            secure=secure,
            _credentials=_credentials,
            metadata=metadata,
            connection_retries=connection_retries,
        )
        self.api.worker = self.client_worker
        # The original (possibly None) ray_init_kwargs is forwarded as-is.
        self.client_worker._server_init(job_config, ray_init_kwargs)
        conn_info = self.client_worker.connection_info()
        self._check_versions(conn_info, ignore_version)
        self._register_serializers()
        return conn_info
    except Exception:
        # Tear down whatever was set up before propagating the error.
        self.disconnect()
        raise