Example #1
0
def test_prepare_runtime_init_req_modified_job():
    """
    Verify that `prepare_runtime_init_req` unpickles the JobConfig carried by
    the init request and replaces it with whatever `ray_client_server_env_prep`
    returns, while leaving `ray_init_kwargs` intact on the forwarded request.
    """
    original_config = JobConfig(
        runtime_env={"env_vars": {"KEY": "VALUE"}}, ray_namespace="abc"
    )
    init_kwargs = {"log_to_driver": False}
    request = ray_client_pb2.DataRequest(
        init=ray_client_pb2.InitRequest(
            job_config=pickle.dumps(original_config),
            ray_init_kwargs=json.dumps(init_kwargs),
        )
    )

    def _override_namespace(config: JobConfig):
        # Stand-in for the env-prep hook: mutate the config and hand it back.
        config.set_ray_namespace("test_value")
        return config

    with patch.object(proxier, "ray_client_server_env_prep", _override_namespace):
        out_req, out_config = proxier.prepare_runtime_init_req(request)

    assert out_config.ray_namespace == "test_value"
    # The request forwarded downstream must carry the same (modified) config.
    forwarded_config = pickle.loads(out_req.init.job_config)
    assert forwarded_config.serialize() == out_config.serialize()
    assert json.loads(out_req.init.ray_init_kwargs) == init_kwargs
Example #2
0
def test_prepare_runtime_init_req_no_modification():
    """
    Verify that `prepare_runtime_init_req`, with the default env-prep hook,
    recovers the JobConfig from the init request unchanged and returns a
    DataRequest that still carries both that config and `ray_init_kwargs`.
    """
    expected_config = JobConfig(worker_env={"KEY": "VALUE"}, ray_namespace="abc")
    inner = ray_client_pb2.InitRequest(
        job_config=pickle.dumps(expected_config),
        ray_init_kwargs=json.dumps({"log_to_driver": False}),
    )
    request = ray_client_pb2.DataRequest(init=inner)

    result_req, result_config = proxier.prepare_runtime_init_req(request)

    assert result_config.serialize() == expected_config.serialize()
    assert isinstance(result_req, ray_client_pb2.DataRequest)
    round_tripped = pickle.loads(result_req.init.job_config)
    assert round_tripped.serialize() == result_config.serialize()
    assert json.loads(result_req.init.ray_init_kwargs) == {"log_to_driver": False}
Example #3
0
def prepare_runtime_init_req(iterator: Iterator[ray_client_pb2.DataRequest]
                             ) -> Tuple[ray_client_pb2.DataRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.

    Consumes exactly one message from ``iterator``, which must be an ``init``
    request. Its pickled JobConfig (or a default ``JobConfig()`` when the
    request carries none) is passed through ``ray_client_server_env_prep``.
    The result is re-pickled and written back into the request in place. All
    other fields of the InitRequest are preserved.

    Args:
        iterator: Stream of DataRequests from the client; only the first
            message is read.

    Returns:
        A tuple ``(init_req, new_job_config)``. ``init_req`` is the same
        DataRequest object that was read, now carrying the re-pickled config.
        ``new_job_config`` is the JobConfig returned by
        ``ray_client_server_env_prep``.

    Raises:
        AssertionError: If the first message is not of type ``init``.
        StopIteration: If ``iterator`` yields no messages.
    """
    init_req = next(iterator)
    init_type = init_req.WhichOneof("type")
    assert init_type == "init", ("Received initial message of type "
                                 f"{init_type}, not 'init'.")
    req = init_req.init
    job_config = JobConfig()
    if req.job_config:
        # NOTE(review): this unpickles bytes received from the client; it is
        # only safe if the client is trusted to execute code here — confirm.
        job_config = pickle.loads(req.job_config)
    new_job_config = ray_client_server_env_prep(job_config)
    # Overwrite only the job_config field on the live submessage. Building a
    # fresh InitRequest and CopyFrom-ing it over `init_req.init` would discard
    # every other field the client set (e.g. ray_init_kwargs).
    req.job_config = pickle.dumps(new_job_config)
    return (init_req, new_job_config)