예제 #1
0
def test_sync_job_config(shutdown_only):
    """The job config given to ``ray.init`` is visible on driver and worker."""
    num_java_workers_per_process = 8
    worker_env = {"key": "value"}

    ray.init(job_config=ray.job_config.JobConfig(
        num_java_workers_per_process=num_java_workers_per_process,
        worker_env=worker_env,
    ))

    def check(config):
        # Both fields must round-trip unchanged from the values passed above.
        assert (config.num_java_workers_per_process ==
                num_java_workers_per_process)
        assert config.worker_env == worker_env

    # Driver side: the config must already be synchronized after init.
    check(ray.worker.global_worker.core_worker.get_job_config())

    @ray.remote
    def fetch_job_config():
        config = ray.worker.global_worker.core_worker.get_job_config()
        return config.SerializeToString()

    # Worker side: ship the serialized config back from a task and parse it.
    worker_config = gcs_utils.JobConfig()
    worker_config.ParseFromString(ray.get(fetch_job_config.remote()))
    check(worker_config)
예제 #2
0
def test_sync_job_config(shutdown_only):
    """The job config (incl. runtime_env) is visible on driver and worker."""
    num_java_workers_per_process = 8
    runtime_env = {"env_vars": {"key": "value"}}

    ray.init(job_config=ray.job_config.JobConfig(
        num_java_workers_per_process=num_java_workers_per_process,
        runtime_env=runtime_env,
    ))

    def check(config):
        # The Java worker count and the env vars carried inside the
        # serialized runtime env must both match what was passed to init.
        assert (config.num_java_workers_per_process ==
                num_java_workers_per_process)
        parsed_env = RuntimeEnv.deserialize(
            config.runtime_env_info.serialized_runtime_env)
        assert parsed_env.env_vars() == runtime_env["env_vars"]

    # Driver side: the config must already be synchronized after init.
    check(ray.worker.global_worker.core_worker.get_job_config())

    @ray.remote
    def fetch_job_config():
        config = ray.worker.global_worker.core_worker.get_job_config()
        return config.SerializeToString()

    # Worker side: ship the serialized config back from a task and parse it.
    worker_config = gcs_utils.JobConfig()
    worker_config.ParseFromString(ray.get(fetch_job_config.remote()))
    check(worker_config)
예제 #3
0
    def get_proto_job_config(self):
        """Build (once) and return the JobConfig protobuf message.

        The message is cached on ``self._cached_pb``; later calls return the
        same object without rebuilding it.
        """
        if self._cached_pb is not None:
            return self._cached_pb

        proto = gcs_utils.JobConfig()
        # With no namespace configured a random UUID is used; since the
        # message is cached, that generated value is stable across calls.
        proto.ray_namespace = (str(uuid.uuid4()) if self.ray_namespace is None
                               else self.ray_namespace)
        proto.num_java_workers_per_process = self.num_java_workers_per_process
        proto.jvm_options.extend(self.jvm_options)
        proto.code_search_path.extend(self.code_search_path)
        for key, value in self.metadata.items():
            proto.metadata[key] = value

        runtime_env, eager_install = self._validate_runtime_env()
        env_info = proto.runtime_env_info
        env_info.uris[:] = runtime_env.get_uris()
        env_info.serialized_runtime_env = runtime_env.serialize()
        env_info.runtime_env_eager_install = eager_install

        lifetime = self._default_actor_lifetime
        if lifetime is not None:
            proto.default_actor_lifetime = lifetime

        self._cached_pb = proto
        return self._cached_pb
예제 #4
0
    def get_proto_job_config(self):
        """Build (once) and return the JobConfig protobuf message.

        The message is cached on ``self._cached_pb``; later calls return the
        same object without rebuilding it.
        """
        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
        # all over the place so this causes circular imports. We should remove
        # this dependency and pass in a validated runtime_env instead.
        from ray.utils import get_runtime_env_info

        if self._cached_pb is not None:
            return self._cached_pb

        proto = gcs_utils.JobConfig()
        # With no namespace configured a random UUID is used; since the
        # message is cached, that generated value is stable across calls.
        proto.ray_namespace = (str(uuid.uuid4()) if self.ray_namespace is None
                               else self.ray_namespace)
        proto.num_java_workers_per_process = self.num_java_workers_per_process
        proto.jvm_options.extend(self.jvm_options)
        proto.code_search_path.extend(self.code_search_path)
        for key, value in self.metadata.items():
            proto.metadata[key] = value

        validated_env = self._validate_runtime_env()
        env_info = get_runtime_env_info(
            validated_env,
            is_job_runtime_env=True,
            serialize=False,
        )
        proto.runtime_env_info.CopyFrom(env_info)

        lifetime = self._default_actor_lifetime
        if lifetime is not None:
            proto.default_actor_lifetime = lifetime

        self._cached_pb = proto
        return self._cached_pb
예제 #5
0
파일: job_config.py 프로젝트: rlan/ray
 def get_proto_job_config(self):
     """Return the protobuf structure of JobConfig.

     The message is built on the first call and cached on
     ``self._cached_pb``; later calls return that same object. The cache is
     only populated once every field has been set, so if any step raises
     while building (for instance while computing the runtime env), no
     half-filled message is left behind and the next call retries.

     Returns:
         The cached ``gcs_utils.JobConfig`` message.
     """
     if self._cached_pb is None:
         # Build into a local first; see docstring for why the cache is
         # not assigned until the end.
         pb = gcs_utils.JobConfig()
         if self.ray_namespace is None:
             # No namespace configured: use a random one. It stays fixed
             # after the first successful call because the result is cached.
             pb.ray_namespace = str(uuid.uuid4())
         else:
             pb.ray_namespace = self.ray_namespace
         pb.num_java_workers_per_process = self.num_java_workers_per_process
         pb.jvm_options.extend(self.jvm_options)
         pb.code_search_path.extend(self.code_search_path)
         pb.runtime_env.uris[:] = self.get_runtime_env_uris()
         pb.runtime_env.serialized_runtime_env = (
             self.get_serialized_runtime_env())
         for k, v in self.metadata.items():
             pb.metadata[k] = v
         # Publish to the cache only after the message is complete.
         self._cached_pb = pb
     return self._cached_pb