def test_prepare_runtime_init_req_fails():
    """
    A connection whose first request is not an Init request must be
    rejected with an AssertionError.
    """
    # Build a Put request, i.e. anything other than the expected Init.
    put_payload = ray_client_pb2.PutRequest()
    non_init_req = ray_client_pb2.DataRequest(put=put_payload)

    with pytest.raises(AssertionError):
        proxier.prepare_runtime_init_req(non_init_req)
def test_prepare_runtime_init_req_modified_job():
    """
    With `ray_client_server_env_prep` replaced by a stub that rewrites the
    namespace, `prepare_runtime_init_req` must return the modified JobConfig,
    embed that same config in the returned request, and leave
    `ray_init_kwargs` untouched.
    """
    # NOTE(review): this file defines another function with the same name
    # further down; Python keeps only the last definition, so confirm which
    # version is meant to run.
    original_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc"
    )
    init_payload = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(original_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False}),
    )
    init_req = ray_client_pb2.DataRequest(init=init_payload)

    def force_test_namespace(job_config: JobConfig):
        # Stand-in for the real env-prep hook: only the namespace changes.
        job_config.set_ray_namespace("test_value")
        return job_config

    # Swap the hook in for the duration of the call only.
    with patch.object(
        proxier, "ray_client_server_env_prep", force_test_namespace
    ):
        req, new_config = proxier.prepare_runtime_init_req(init_req)

    assert new_config.ray_namespace == "test_value"

    embedded_config = pickle.loads(req.init.job_config)
    assert embedded_config.serialize() == new_config.serialize()

    forwarded_kwargs = json.loads(req.init.ray_init_kwargs)
    assert forwarded_kwargs == {"log_to_driver": False}
def test_prepare_runtime_init_req_no_modification():
    """
    `prepare_runtime_init_req` must hand back a DataRequest together with a
    JobConfig that serializes identically to the one sent in the Init request.
    """
    # NOTE(review): this file defines another function with the same name
    # further down; Python keeps only the last definition, so confirm which
    # version is meant to run.
    job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
    init_payload = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config)
    )
    init_req = ray_client_pb2.DataRequest(init=init_payload)

    req, new_config = proxier.prepare_runtime_init_req(init_req)

    # The extracted config matches what was sent.
    assert new_config.serialize() == job_config.serialize()

    # The returned request is still a DataRequest carrying that same config.
    assert isinstance(req, ray_client_pb2.DataRequest)
    embedded_config = pickle.loads(req.init.job_config)
    assert embedded_config.serialize() == new_config.serialize()
def test_prepare_runtime_init_req_no_modification():
    """
    `prepare_runtime_init_req` must hand back a DataRequest together with a
    JobConfig that serializes identically to the one sent in the Init
    request, and must keep `ray_init_kwargs` intact.
    """
    job_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc"
    )
    init_payload = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(job_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False}),
    )
    init_req = ray_client_pb2.DataRequest(init=init_payload)

    req, new_config = proxier.prepare_runtime_init_req(init_req)

    # The extracted config matches what was sent.
    assert new_config.serialize() == job_config.serialize()

    # The returned request is still a DataRequest carrying that same config.
    assert isinstance(req, ray_client_pb2.DataRequest)
    embedded_config = pickle.loads(req.init.job_config)
    assert embedded_config.serialize() == new_config.serialize()

    # The init kwargs survive the round trip unchanged.
    forwarded_kwargs = json.loads(req.init.ray_init_kwargs)
    assert forwarded_kwargs == {"log_to_driver": False}
def test_prepare_runtime_init_req_modified_job():
    """
    Check that `prepare_runtime_init_req` properly extracts the JobConfig and
    modifies it according to `ray_client_server_env_prep`.

    The hook is replaced with a stub that rewrites the namespace; the test
    then verifies that the returned JobConfig carries the new namespace and
    that the returned request embeds that same config.
    """
    job_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
    init_req = ray_client_pb2.DataRequest(
        init=ray_client_pb2.InitRequest(job_config=pickle.dumps(job_config))
    )

    def modify_namespace(job_config: JobConfig):
        job_config.set_ray_namespace("test_value")
        return job_config

    # Use patch.object rather than assigning to the module attribute: a bare
    # assignment is never undone, so the stub would stay installed for every
    # test that runs afterwards in the same process.
    with patch.object(proxier, "ray_client_server_env_prep", modify_namespace):
        # Pass the DataRequest itself, as the other tests of
        # `prepare_runtime_init_req` in this file do.
        req, new_config = proxier.prepare_runtime_init_req(init_req)

    assert new_config.ray_namespace == "test_value"
    assert pickle.loads(req.init.job_config).serialize() == new_config.serialize()