Example #1
def test_sync_job_config(shutdown_only):
    num_java_workers_per_process = 8
    runtime_env = {"env_vars": {"key": "value"}}

    ray.init(job_config=ray.job_config.JobConfig(
        num_java_workers_per_process=num_java_workers_per_process,
        runtime_env=runtime_env,
    ))

    # Check that the job config is synchronized at the driver side.
    job_config = ray.worker.global_worker.core_worker.get_job_config()
    assert job_config.num_java_workers_per_process == num_java_workers_per_process
    job_runtime_env = RuntimeEnv.deserialize(
        job_config.runtime_env_info.serialized_runtime_env)
    assert job_runtime_env.env_vars() == runtime_env["env_vars"]

    @ray.remote
    def get_job_config():
        job_config = ray.worker.global_worker.core_worker.get_job_config()
        return job_config.SerializeToString()

    # Check that the job config is synchronized at the worker side.
    job_config = gcs_utils.JobConfig()
    job_config.ParseFromString(ray.get(get_job_config.remote()))
    assert job_config.num_java_workers_per_process == num_java_workers_per_process
    job_runtime_env = RuntimeEnv.deserialize(
        job_config.runtime_env_info.serialized_runtime_env)
    assert job_runtime_env.env_vars() == runtime_env["env_vars"]
Example #2
def test_to_make_ensure_runtime_env_api(start_cluster):
    # Make sure RuntimeEnv can be used interchangeably with an unstructured
    # dictionary in the relevant API calls.
    ENV_KEY = "TEST_RUNTIME_ENV"

    @ray.remote(runtime_env=RuntimeEnv(env_vars={ENV_KEY: "f1"}))
    def f1():
        assert os.environ.get(ENV_KEY) == "f1"

    ray.get(f1.remote())

    @ray.remote
    def f2():
        assert os.environ.get(ENV_KEY) == "f2"

    ray.get(
        f2.options(runtime_env=RuntimeEnv(env_vars={ENV_KEY: "f2"})).remote())

    @ray.remote(runtime_env=RuntimeEnv(env_vars={ENV_KEY: "a1"}))
    class A1:
        def f(self):
            assert os.environ.get(ENV_KEY) == "a1"

    a1 = A1.remote()
    ray.get(a1.f.remote())

    @ray.remote
    class A2:
        def f(self):
            assert os.environ.get(ENV_KEY) == "a2"

    a2 = A2.options(runtime_env=RuntimeEnv(env_vars={ENV_KEY: "a2"})).remote()
    ray.get(a2.f.remote())
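The interchangeability above can also be checked head to head; a minimal sketch (assuming a running Ray cluster and the imports from the test), not part of the original example:

def test_dict_and_class_interchangeable():
    # Sketch: the same env expressed as a plain dict and as a RuntimeEnv
    # should produce identical worker environments.
    env_as_dict = {"env_vars": {"TEST_RUNTIME_ENV": "x"}}
    env_as_class = RuntimeEnv(env_vars={"TEST_RUNTIME_ENV": "x"})

    @ray.remote
    def read_env():
        return os.environ.get("TEST_RUNTIME_ENV")

    assert ray.get(read_env.options(runtime_env=env_as_dict).remote()) == "x"
    assert ray.get(read_env.options(runtime_env=env_as_class).remote()) == "x"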
Example #3
def test_serialize_deserialize(option):
    runtime_env = dict()
    if option == "pip_list":
        runtime_env["pip"] = ["pkg1", "pkg2"]
    elif option == "conda_name":
        runtime_env["conda"] = "env_name"
    elif option == "conda_dict":
        runtime_env["conda"] = {"dependencies": ["dep1", "dep2"]}
    elif option == "container":
        runtime_env["container"] = {
            "image": "anyscale/ray-ml:nightly-py38-cpu",
            "worker_path": "/root/python/ray/workers/default_worker.py",
            "run_options": ["--cap-drop SYS_ADMIN", "--log-level=debug"],
        }
    elif option == "plugins":
        runtime_env["plugins"] = {
            "class_path1": {
                "config1": "val1"
            },
            "class_path2": "string_config",
        }
    else:
        raise ValueError("unexpected option " + str(option))

    proto_runtime_env = RuntimeEnv(**runtime_env,
                                   _validate=False)._proto_runtime_env
    cls_runtime_env = RuntimeEnv.from_proto(proto_runtime_env)
    assert cls_runtime_env.to_dict() == runtime_env
Example #4
    def test_serialization(self):
        env1 = RuntimeEnv(
            pip=["requests"],
            env_vars={"hi1": "hi1", "hi2": "hi2"},
        )

        env2 = RuntimeEnv(
            env_vars={"hi2": "hi2", "hi1": "hi1"},
            pip=["requests"],
        )

        assert env1 == env2

        serialized_env1 = env1.serialize()
        serialized_env2 = env2.serialize()

        # Key ordering shouldn't matter.
        assert serialized_env1 == serialized_env2

        deserialized_env1 = RuntimeEnv.deserialize(serialized_env1)
        deserialized_env2 = RuntimeEnv.deserialize(serialized_env2)

        assert env1 == deserialized_env1 == env2 == deserialized_env2
Example #5
def test_runtime_env_config(start_cluster):
    _, address = start_cluster
    bad_configs = []
    bad_configs.append({"setup_timeout_seconds": 10.0})
    bad_configs.append({"setup_timeout_seconds": 0})
    bad_configs.append({"setup_timeout_seconds": "10"})

    good_configs = []
    good_configs.append({"setup_timeout_seconds": 10})
    good_configs.append({"setup_timeout_seconds": -1})

    @ray.remote
    def f():
        return True

    def raise_exception_run(fun, *args, **kwargs):
        try:
            fun(*args, **kwargs)
        except Exception:
            pass
        else:
            assert False

    for bad_config in bad_configs:

        def run(runtime_env):
            raise_exception_run(ray.init, address, runtime_env=runtime_env)
            raise_exception_run(f.options, runtime_env=runtime_env)

        runtime_env = {"config": bad_config}
        run(runtime_env)

        raise_exception_run(RuntimeEnvConfig, **bad_config)
        raise_exception_run(RuntimeEnv, config=bad_config)

    for good_config in good_configs:

        def run(runtime_env):
            ray.shutdown()
            ray.init(address, runtime_env=runtime_env)
            assert ray.get(f.options(runtime_env=runtime_env).remote())

        runtime_env = {"config": good_config}
        run(runtime_env)
        runtime_env = {"config": RuntimeEnvConfig(**good_config)}
        run(runtime_env)
        runtime_env = RuntimeEnv(config=good_config)
        run(runtime_env)
        runtime_env = RuntimeEnv(config=RuntimeEnvConfig(**good_config))
        run(runtime_env)
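For reference, the validated config object can also be attached to a RuntimeEnv directly; a minimal sketch (values illustrative):

# -1 disables the setup timeout entirely; a positive integer bounds
# environment creation time (see the agent handler in Example #26).
env = RuntimeEnv(pip=["requests"],
                 config=RuntimeEnvConfig(setup_timeout_seconds=60))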
Example #6
    def test_ray_commit_injection(self):
        # Should not be injected if no pip and conda.
        result = RuntimeEnv(env_vars={"hi": "hi"})
        assert "_ray_commit" not in result

        # Should be injected if pip or conda present.
        result = RuntimeEnv(pip=["requests"])
        assert "_ray_commit" in result

        result = RuntimeEnv(conda="env_name")
        assert "_ray_commit" in result

        # Should not override if passed.
        result = RuntimeEnv(conda="env_name", _ray_commit="Blah")
        assert result["_ray_commit"] == "Blah"
Example #7
def get_runtime_env_info(
    runtime_env: RuntimeEnv,
    *,
    is_job_runtime_env: bool = False,
    serialize: bool = False,
):
    """Create runtime env info from runtime env.

    In the user interface, the argument `runtime_env` contains some fields
    which not contained in `ProtoRuntimeEnv` but in `ProtoRuntimeEnvInfo`,
    such as `eager_install`. This function will extract those fields from
    `RuntimeEnv` and create a new `ProtoRuntimeEnvInfo`, and serialize it.
    """
    proto_runtime_env_info = ProtoRuntimeEnvInfo()

    proto_runtime_env_info.uris[:] = runtime_env.get_uris()

    # Normally, `RuntimeEnv` should guarantee the accuracy of the field
    # eager_install, but the internal code does not yet completely prohibit
    # direct modification of fields in RuntimeEnv, so we check it here to be
    # safe.
    # TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv` and change the
    # runtime_env of all internal code from dict to RuntimeEnv.

    eager_install = runtime_env.get("eager_install")
    if is_job_runtime_env or eager_install is not None:
        if eager_install is None:
            eager_install = True
        elif not isinstance(eager_install, bool):
            raise TypeError(
                f"eager_install must be a boolean, got {type(eager_install)}.")
        proto_runtime_env_info.runtime_env_eager_install = eager_install

    runtime_env_config = runtime_env.get("config")
    if runtime_env_config is None:
        runtime_env_config = RuntimeEnvConfig.default_config()
    else:
        runtime_env_config = RuntimeEnvConfig.parse_and_validate_runtime_env_config(
            runtime_env_config)

    proto_runtime_env_info.runtime_env_config.CopyFrom(
        runtime_env_config.build_proto_runtime_env_config())

    proto_runtime_env_info.serialized_runtime_env = runtime_env.serialize()

    if not serialize:
        return proto_runtime_env_info

    return json_format.MessageToJson(proto_runtime_env_info)
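A small usage sketch for the helper above (values illustrative, and the helper assumed to be in scope):

env = RuntimeEnv(env_vars={"KEY": "value"})
# Returns a ProtoRuntimeEnvInfo message; with serialize=True, its JSON form.
info_json = get_runtime_env_info(env, is_job_runtime_env=True, serialize=True)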
Example #8
    async def get_job_info(self):
        """Return info for each job.  Here a job is a Ray driver."""
        request = gcs_service_pb2.GetAllJobInfoRequest()
        reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)

        jobs = {}
        for job_table_entry in reply.job_info_list:
            job_id = job_table_entry.job_id.hex()
            metadata = dict(job_table_entry.config.metadata)
            config = {
                "namespace": job_table_entry.config.ray_namespace,
                "metadata": metadata,
                "runtime_env": RuntimeEnv.deserialize(
                    job_table_entry.config.runtime_env_info
                    .serialized_runtime_env),
            }
            info = self._get_job_info(metadata)
            entry = {
                "status": None if info is None else info.status,
                "status_message": None if info is None else info.message,
                "is_dead": job_table_entry.is_dead,
                "start_time": job_table_entry.start_time,
                "end_time": job_table_entry.end_time,
                "config": config,
            }
            jobs[job_id] = entry

        return jobs
Example #9
def test_get_conda_dict_with_ray_inserted_m1_wheel(monkeypatch):
    # Disable dev mode to prevent Ray dependencies being automatically inserted
    # into the conda dict.
    if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE") is not None:
        monkeypatch.delenv("RAY_RUNTIME_ENV_LOCAL_DEV_MODE")
    if os.environ.get("RAY_CI_POST_WHEEL_TESTS") is not None:
        monkeypatch.delenv("RAY_CI_POST_WHEEL_TESTS")
    monkeypatch.setattr(ray, "__version__", "1.9.0")
    monkeypatch.setattr(ray, "__commit__", "92599d9127e228fe8d0a2d94ca75754ec21c4ae4")
    monkeypatch.setattr(sys, "version_info", (3, 9, 7, "final", 0))
    # Simulate running on an M1 Mac.
    monkeypatch.setattr(sys, "platform", "darwin")
    monkeypatch.setattr(platform, "machine", lambda: "arm64")

    input_conda = {"dependencies": ["blah", "pip", {"pip": ["pip_pkg"]}]}
    runtime_env = RuntimeEnv(conda=input_conda)
    output_conda = _get_conda_dict_with_ray_inserted(runtime_env)
    # M1 wheels are not uploaded to AWS S3.  So rather than have an S3 URL
    # inserted as a dependency, we should just have the string "ray==1.9.0".
    assert output_conda == {
        "dependencies": [
            "blah",
            "pip",
            {"pip": ["ray==1.9.0", "ray[default]", "pip_pkg"]},
            "python=3.9.7",
        ]
    }
Example #10
    def test_inject_current_ray(self):
        # Should not be injected if not provided by env var.
        result = RuntimeEnv(env_vars={"hi": "hi"})
        assert "_inject_current_ray" not in result

        os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1"

        # Should be injected if provided by env var.
        result = RuntimeEnv()
        assert result["_inject_current_ray"]

        # Should be preserved if passed.
        result = RuntimeEnv(_inject_current_ray=False)
        assert not result["_inject_current_ray"]

        del os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"]
Example #11
    def submit_job(
        self,
        *,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        runtime_env = runtime_env or {}
        metadata = metadata or {}
        metadata.update(self._default_metadata)

        self._upload_working_dir_if_needed(runtime_env)
        self._upload_py_modules_if_needed(runtime_env)

        # Run the RuntimeEnv constructor to parse local pip/conda requirements files.
        runtime_env = RuntimeEnv(**runtime_env).to_dict()

        req = JobSubmitRequest(
            entrypoint=entrypoint,
            job_id=job_id,
            runtime_env=runtime_env,
            metadata=metadata,
        )

        logger.debug(f"Submitting job with job_id={job_id}.")
        r = self._do_request("POST",
                             "/api/jobs/",
                             json_data=dataclasses.asdict(req))

        if r.status_code == 200:
            return JobSubmitResponse(**r.json()).job_id
        else:
            self._raise_error(r)
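A usage sketch for the client method above, mirroring the docstring shown in Example #27 (assumes a Ray dashboard reachable at 127.0.0.1:8265):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(
    entrypoint="python script.py",
    runtime_env={"working_dir": "./", "pip": ["requests==2.26.0"]},
)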
Example #12
def test_pip(start_cluster):
    cluster, address = start_cluster
    ray.init(address)

    runtime_env = RuntimeEnv()
    pip = Pip(packages=["pip-install-test==0.5"])
    runtime_env.set("pip", pip)

    @ray.remote
    class Actor:
        def foo(self):
            import pip_install_test  # noqa

            return "hello"

    a = Actor.options(runtime_env=runtime_env).remote()
    assert ray.get(a.foo.remote()) == "hello"
Example #13
def test_serialize_deserialize(option):
    runtime_env = dict()
    if option == "pip_list":
        runtime_env["pip"] = ["pkg1", "pkg2"]
    elif option == "pip_dict":
        runtime_env["pip"] = {
            "packages": ["pkg1", "pkg2"],
            "pip_check": False,
            "pip_version": "<22,>20",
        }
    elif option == "conda_name":
        runtime_env["conda"] = "env_name"
    elif option == "conda_dict":
        runtime_env["conda"] = {"dependencies": ["dep1", "dep2"]}
    elif option == "container":
        runtime_env["container"] = {
            "image": "anyscale/ray-ml:nightly-py38-cpu",
            "worker_path":
            "/root/python/ray/_private/workers/default_worker.py",
            "run_options": ["--cap-drop SYS_ADMIN", "--log-level=debug"],
        }
    elif option == "plugins":
        runtime_env["plugins"] = {
            "class_path1": {
                "config1": "val1"
            },
            "class_path2": "string_config",
        }
    else:
        raise ValueError("unexpected option " + str(option))

    proto_runtime_env = RuntimeEnv(**runtime_env,
                                   _validate=False).build_proto_runtime_env()
    cls_runtime_env = RuntimeEnv.from_proto(proto_runtime_env)
    cls_runtime_env_dict = cls_runtime_env.to_dict()

    if "pip" in runtime_env and isinstance(runtime_env["pip"], list):
        pip_config_in_cls_runtime_env = cls_runtime_env_dict.pop("pip")
        pip_config_in_runtime_env = runtime_env.pop("pip")
        assert {
            "packages": pip_config_in_runtime_env,
            "pip_check": False,
        } == pip_config_in_cls_runtime_env

    assert cls_runtime_env_dict == runtime_env
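The special-casing of the pip list above exists because a bare package list is normalized to dict form during the proto round trip; a sketch under the same assumptions as the test:

env = RuntimeEnv(pip=["pkg1", "pkg2"], _validate=False)
round_tripped = RuntimeEnv.from_proto(env.build_proto_runtime_env())
# The bare list comes back in dict form, with pip_check defaulting to False.
assert round_tripped.to_dict()["pip"] == {
    "packages": ["pkg1", "pkg2"],
    "pip_check": False,
}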
Example #14
    def _validate_runtime_env(self):
        # TODO(edoakes): this is really unfortunate, but JobConfig is imported
        # all over the place so this causes circular imports. We should remove
        # this dependency and pass in a validated runtime_env instead.
        from ray.runtime_env import RuntimeEnv

        if isinstance(self.runtime_env, RuntimeEnv):
            return self.runtime_env
        return RuntimeEnv(**self.runtime_env)
Example #15
    def runtime_env(self):
        """Get the runtime env used for the current driver or worker.

        Returns:
            The runtime env currently used by this worker. The type of the
            return value is ray.runtime_env.RuntimeEnv.
        """

        return RuntimeEnv.deserialize(self.get_runtime_env_string())
Example #16
 def test_validate_working_dir(self, set_runtime_env_plugin_schemas):
     runtime_env = RuntimeEnv()
     runtime_env.set("working_dir", "https://abc/file.zip")
     with pytest.raises(jsonschema.exceptions.ValidationError,
                        match="working_dir"):
         runtime_env.set("working_dir", ["https://abc/file.zip"])
     runtime_env["working_dir"] = "https://abc/file.zip"
     with pytest.raises(jsonschema.exceptions.ValidationError,
                        match="working_dir"):
         runtime_env["working_dir"] = ["https://abc/file.zip"]
Example #17
def parse_runtime_env(runtime_env: Optional[Union[Dict, RuntimeEnv]]):
    # Parse local pip/conda config files here. If we instead did it in
    # .remote(), it would get run in the Ray Client server, which runs on
    # a remote node where the files aren't available.
    if runtime_env:
        if isinstance(runtime_env, dict):
            return RuntimeEnv(**(runtime_env or {}))
        raise TypeError(
            "runtime_env must be dict or RuntimeEnv, "
            f"but got: {type(runtime_env)}"
        )
    else:
        # Keep the new_runtime_env as None.  In .remote(), we need to know
        # if runtime_env is None to know whether or not to fall back to the
        # runtime_env specified in the @ray.remote decorator.
        return None
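A behavior sketch for the helper above (assuming parse_runtime_env and RuntimeEnv are in scope):

assert parse_runtime_env(None) is None
assert isinstance(parse_runtime_env({"env_vars": {"A": "a"}}), RuntimeEnv)
# Anything truthy that is not a dict (e.g. a string) raises TypeError.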
Example #18
 def test_validate_pip(self, set_runtime_env_plugin_schemas):
     runtime_env = RuntimeEnv()
     runtime_env.set("pip", {"packages": ["requests"], "pip_check": True})
     with pytest.raises(jsonschema.exceptions.ValidationError,
                        match="pip_check"):
         runtime_env.set("pip", {
             "packages": ["requests"],
             "pip_check": "1"
         })
     runtime_env["pip"] = {"packages": ["requests"], "pip_check": True}
     with pytest.raises(jsonschema.exceptions.ValidationError,
                        match="pip_check"):
         runtime_env["pip"] = {"packages": ["requests"], "pip_check": "1"}
Example #19
    async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]:
        """List all runtime env information from the cluster.

        Returns:
            A list of runtime env information in the cluster.
            The schema of returned "dict" is equivalent to the
            `RuntimeEnvState` protobuf message.
            We don't have id -> data mapping like other API because runtime env
            doesn't have unique ids.
        """
        replies = await asyncio.gather(*[
            self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
            for node_id in self._client.get_all_registered_agent_ids()
        ])
        result = []
        for node_id, reply in zip(self._client.get_all_registered_agent_ids(),
                                  replies):
            states = reply.runtime_env_states
            for state in states:
                data = self._message_to_dict(message=state,
                                             fields_to_decode=[])
                # Need to deserialize this field.
                data["runtime_env"] = RuntimeEnv.deserialize(
                    data["runtime_env"]).to_dict()
                data["node_id"] = node_id
                data = filter_fields(data, RuntimeEnvState)
                result.append(data)

        # Sort to make the output deterministic.
        def sort_func(entry):
            # If the creation time is not present (the runtime env failed to
            # be created or hasn't been created yet), the entry gets the
            # highest priority. Otherwise, larger creation times come first.
            if "creation_time_ms" not in entry:
                return float("inf")
            elif entry["creation_time_ms"] is None:
                return float("inf")
            else:
                return float(entry["creation_time_ms"])

        result.sort(key=sort_func, reverse=True)
        return list(islice(result, option.limit))
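The sort key above can be exercised in isolation; a standalone sketch (not part of the API):

def sort_func(entry):
    value = entry.get("creation_time_ms")
    return float("inf") if value is None else float(value)

entries = [{"creation_time_ms": 10}, {}, {"creation_time_ms": 15}]
entries.sort(key=sort_func, reverse=True)
# Entries without a creation time sort first, then larger times.
assert entries == [{}, {"creation_time_ms": 15}, {"creation_time_ms": 10}]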
Example #20
    def runtime_env(self):
        """Get the runtime env of the current job/worker.

        If this API is called in driver or ray client, returns the job level runtime
        env.
        If this API is called in workers/actors, returns the worker level runtime env.

        Returns:
            A new ray.runtime_env.RuntimeEnv instance.

        To merge from the current runtime env in some specific cases, you can get the
        current runtime env by this API and modify it by yourself.

        Example:

            >>> # Inherit current runtime env, except `env_vars`
            >>> Actor.options( # doctest: +SKIP
            ...     runtime_env=ray.get_runtime_context().runtime_env.update(
            ...     {"env_vars": {"A": "a", "B": "b"}})
            ... )

        """

        return RuntimeEnv.deserialize(self._get_runtime_env_string())
Example #21
    async def DeleteRuntimeEnvIfPossible(self, request, context):
        self._logger.info(
            f"Got request from {request.source_process} to decrease "
            "reference for runtime env: "
            f"{request.serialized_runtime_env}.")

        try:
            runtime_env = RuntimeEnv.deserialize(
                request.serialized_runtime_env)
        except Exception as e:
            self._logger.exception("[Decrease] Failed to parse runtime env: "
                                   f"{request.serialized_runtime_env}")
            return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)),
            )

        self._reference_table.decrease_reference(
            runtime_env, request.serialized_runtime_env,
            request.source_process)

        return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK)
Example #22
async def test_api_manager_list_runtime_envs(state_api_manager):
    data_source_client = state_api_manager.data_source_client
    data_source_client.get_all_registered_agent_ids = MagicMock()
    data_source_client.get_all_registered_agent_ids.return_value = [
        "1", "2", "3"
    ]

    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}),
                                  creation_time=10),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=list_api_options())
    data_source_client.get_runtime_envs_info.assert_any_call(
        "1", timeout=DEFAULT_RPC_TIMEOUT)
    data_source_client.get_runtime_envs_info.assert_any_call(
        "2", timeout=DEFAULT_RPC_TIMEOUT)
    data_source_client.get_runtime_envs_info.assert_any_call(
        "3", timeout=DEFAULT_RPC_TIMEOUT)
    assert len(result) == 3
    verify_schema(RuntimeEnvState, result[0])
    verify_schema(RuntimeEnvState, result[1])
    verify_schema(RuntimeEnvState, result[2])

    # Make sure the higher creation time is sorted first.
    assert "creation_time_ms" not in result[0]
    result[1]["creation_time_ms"] > result[2]["creation_time_ms"]
    """
    Test limit
    """
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
    ]
    result = await state_api_manager.list_runtime_envs(option=list_api_options(
        limit=1))
    assert len(result) == 1
Example #23
def test_convert_from_and_to_dataclass():
    runtime_env = RuntimeEnv()
    test_plugin = TestPlugin(
        field1=[
            ValueType(nfield1=["a", "b", "c"], nfield2=False),
            ValueType(nfield1=["d", "e"], nfield2=True),
        ],
        field2="abc",
    )
    runtime_env.set("test_plugin", test_plugin)
    serialized_runtime_env = runtime_env.serialize()
    assert "test_plugin" in serialized_runtime_env
    runtime_env_2 = RuntimeEnv.deserialize(serialized_runtime_env)
    test_plugin_2 = runtime_env_2.get("test_plugin", data_class=TestPlugin)
    assert len(test_plugin_2.field1) == 2
    assert test_plugin_2.field1[0].nfield1 == ["a", "b", "c"]
    assert test_plugin_2.field1[0].nfield2 is False
    assert test_plugin_2.field1[1].nfield1 == ["d", "e"]
    assert test_plugin_2.field1[1].nfield2 is True
    assert test_plugin_2.field2 == "abc"
Example #24
def test_key_with_value_none():
    parsed_runtime_env = RuntimeEnv(pip=None)
    assert parsed_runtime_env == {}
Example #25
        async def _setup_runtime_env(
            serialized_runtime_env, serialized_allocated_resource_instances
        ):
            runtime_env = RuntimeEnv.deserialize(serialized_runtime_env)
            allocated_resource: dict = json.loads(
                serialized_allocated_resource_instances or "{}"
            )

            # Use a separate logger for each job.
            per_job_logger = self.get_or_create_logger(request.job_id)
            # TODO(chenk008): Add log about allocated_resource to
            # avoid lint error. That will be moved to cgroup plugin.
            per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
            context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
            await self._container_manager.setup(
                runtime_env, context, logger=per_job_logger
            )

            for (manager, uri_cache) in [
                (self._working_dir_manager, self._working_dir_uri_cache),
                (self._conda_manager, self._conda_uri_cache),
                (self._pip_manager, self._pip_uri_cache),
            ]:
                uri = manager.get_uri(runtime_env)
                if uri is not None:
                    if uri not in uri_cache:
                        per_job_logger.debug(f"Cache miss for URI {uri}.")
                        size_bytes = await manager.create(
                            uri, runtime_env, context, logger=per_job_logger
                        )
                        uri_cache.add(uri, size_bytes, logger=per_job_logger)
                    else:
                        per_job_logger.debug(f"Cache hit for URI {uri}.")
                        uri_cache.mark_used(uri, logger=per_job_logger)
                manager.modify_context(uri, runtime_env, context)

            # Set up py_modules. For now, py_modules uses multiple URIs so
            # the logic is slightly different from working_dir, conda, and
            # pip above.
            py_modules_uris = self._py_modules_manager.get_uris(runtime_env)
            if py_modules_uris is not None:
                for uri in py_modules_uris:
                    if uri not in self._py_modules_uri_cache:
                        per_job_logger.debug(f"Cache miss for URI {uri}.")
                        size_bytes = await self._py_modules_manager.create(
                            uri, runtime_env, context, logger=per_job_logger
                        )
                        self._py_modules_uri_cache.add(
                            uri, size_bytes, logger=per_job_logger
                        )
                    else:
                        per_job_logger.debug(f"Cache hit for URI {uri}.")
                        self._py_modules_uri_cache.mark_used(uri, logger=per_job_logger)
            self._py_modules_manager.modify_context(
                py_modules_uris, runtime_env, context
            )

            # Add the mapping of URIs -> the serialized environment to be
            # used for cache invalidation.
            if runtime_env.working_dir_uri():
                uri = runtime_env.working_dir_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.py_modules_uris():
                for uri in runtime_env.py_modules_uris():
                    self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.conda_uri():
                uri = runtime_env.conda_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.pip_uri():
                uri = runtime_env.pip_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.plugin_uris():
                for uri in runtime_env.plugin_uris():
                    self._uris_to_envs[uri].add(serialized_runtime_env)

            def setup_plugins():
                # Run setup function from all the plugins
                for plugin_class_path, config in runtime_env.plugins():
                    per_job_logger.debug(
                        f"Setting up runtime env plugin {plugin_class_path}"
                    )
                    plugin_class = import_attr(plugin_class_path)
                    # TODO(simon): implement uri support
                    plugin_class.create(
                        "uri not implemented", json.loads(config), context
                    )
                    plugin_class.modify_context(
                        "uri not implemented", json.loads(config), context
                    )

            loop = asyncio.get_event_loop()
            # The plugins' setup methods are synchronous; run them in another
            # thread to avoid blocking the asyncio loop.
            await loop.run_in_executor(None, setup_plugins)

            return context
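The per-URI cache logic above repeats a get-or-create pattern that can be summarized in a standalone sketch (all names invented for illustration):

def get_or_create(uri, cache, create_fn, logger):
    if uri in cache:
        # Cache hit: just mark the URI as recently used.
        logger.debug(f"Cache hit for URI {uri}.")
        cache.mark_used(uri)
    else:
        # Cache miss: create the resource and record its size for eviction.
        logger.debug(f"Cache miss for URI {uri}.")
        size_bytes = create_fn(uri)
        cache.add(uri, size_bytes)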
Example #26
    async def GetOrCreateRuntimeEnv(self, request, context):
        self._logger.debug(
            f"Got request from {request.source_process} to increase "
            "reference for runtime env: "
            f"{request.serialized_runtime_env}.")

        async def _setup_runtime_env(runtime_env, serialized_runtime_env,
                                     serialized_allocated_resource_instances):
            allocated_resource: dict = json.loads(
                serialized_allocated_resource_instances or "{}")
            # Use a separate logger for each job.
            per_job_logger = self.get_or_create_logger(request.job_id)
            # TODO(chenk008): Add log about allocated_resource to
            # avoid lint error. That will be moved to cgroup plugin.
            per_job_logger.debug(f"Worker has resource :"
                                 f"{allocated_resource}")
            context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
            await self._container_manager.setup(runtime_env,
                                                context,
                                                logger=per_job_logger)

            for manager in self._base_plugin_cache_managers.values():
                await manager.create_if_needed(runtime_env,
                                               context,
                                               logger=per_job_logger)

            def setup_plugins():
                # Run setup function from all the plugins
                for name, config in runtime_env.plugins():
                    per_job_logger.debug(
                        f"Setting up runtime env plugin {name}")
                    plugin = self._runtime_env_plugin_manager.get_plugin(name)
                    if plugin is None:
                        raise RuntimeError(
                            f"runtime env plugin {name} not found.")
                    # TODO(architkulkarni): implement uri support
                    plugin.validate(runtime_env)
                    plugin.create("uri not implemented", json.loads(config),
                                  context)
                    plugin.modify_context(
                        "uri not implemented",
                        json.loads(config),
                        context,
                        per_job_logger,
                    )

            loop = asyncio.get_event_loop()
            # The plugins' setup methods are synchronous; run them in another
            # thread to avoid blocking the asyncio loop.
            await loop.run_in_executor(None, setup_plugins)

            return context

        async def _create_runtime_env_with_retry(
            runtime_env,
            serialized_runtime_env,
            serialized_allocated_resource_instances,
            setup_timeout_seconds,
        ) -> Tuple[bool, str, str]:
            """
            Create runtime env with retry times. This function won't raise exceptions.

            Args:
                runtime_env(RuntimeEnv): The instance of RuntimeEnv class.
                serialized_runtime_env(str): The serialized runtime env.
                serialized_allocated_resource_instances(str): The serialized allocated
                resource instances.
                setup_timeout_seconds(int): The timeout of runtime environment creation.

            Returns:
                a tuple which contains result(bool), runtime env context(str), error
                message(str).

            """
            self._logger.info(
                f"Creating runtime env: {serialized_env} with timeout "
                f"{setup_timeout_seconds} seconds.")
            serialized_context = None
            error_message = None
            for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
                try:
                    # python 3.6 requires the type of input is `Future`,
                    # python 3.7+ only requires the type of input is `Awaitable`
                    # TODO(Catch-Bull): remove create_task when ray drop python 3.6
                    runtime_env_setup_task = create_task(
                        _setup_runtime_env(
                            runtime_env,
                            serialized_env,
                            request.serialized_allocated_resource_instances,
                        ))
                    runtime_env_context = await asyncio.wait_for(
                        runtime_env_setup_task, timeout=setup_timeout_seconds)
                    serialized_context = runtime_env_context.serialize()
                    error_message = None
                    break
                except Exception as e:
                    err_msg = f"Failed to create runtime env {serialized_env}."
                    self._logger.exception(err_msg)
                    error_message = "".join(
                        traceback.format_exception(type(e), e,
                                                   e.__traceback__))
                    await asyncio.sleep(
                        runtime_env_consts.RUNTIME_ENV_RETRY_INTERVAL_MS / 1000
                    )
            if error_message:
                self._logger.error(
                    "Runtime env creation failed for %d times, "
                    "don't retry any more.",
                    runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
                )
                return False, None, error_message
            else:
                self._logger.info(
                    "Successfully created runtime env: %s, the context: %s",
                    serialized_env,
                    serialized_context,
                )
                return True, serialized_context, None

        try:
            serialized_env = request.serialized_runtime_env
            runtime_env = RuntimeEnv.deserialize(serialized_env)
        except Exception as e:
            self._logger.exception("[Increase] Failed to parse runtime env: "
                                   f"{serialized_env}")
            return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)),
            )

        # Increase reference
        self._reference_table.increase_reference(runtime_env, serialized_env,
                                                 request.source_process)

        if serialized_env not in self._env_locks:
            # async lock to prevent the same env being concurrently installed
            self._env_locks[serialized_env] = asyncio.Lock()

        async with self._env_locks[serialized_env]:
            if serialized_env in self._env_cache:
                result = self._env_cache[serialized_env]
                if result.success:
                    context = result.result
                    self._logger.info("Runtime env already created "
                                      f"successfully. Env: {serialized_env}, "
                                      f"context: {context}")
                    return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                        status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                        serialized_runtime_env_context=context,
                    )
                else:
                    error_message = result.result
                    self._logger.info("Runtime env already failed. "
                                      f"Env: {serialized_env}, "
                                      f"err: {error_message}")
                    # Recover the reference.
                    self._reference_table.decrease_reference(
                        runtime_env, serialized_env, request.source_process)
                    return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                        status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                        error_message=error_message,
                    )

            if SLEEP_FOR_TESTING_S:
                self._logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
                time.sleep(int(SLEEP_FOR_TESTING_S))

            runtime_env_config = RuntimeEnvConfig.from_proto(
                request.runtime_env_config)
            # According to the documentation of `asyncio.wait_for`,
            # None means the timeout logic is disabled.
            setup_timeout_seconds = (
                None if runtime_env_config["setup_timeout_seconds"] == -1 else
                runtime_env_config["setup_timeout_seconds"])

            start = time.perf_counter()
            (
                successful,
                serialized_context,
                error_message,
            ) = await _create_runtime_env_with_retry(
                runtime_env,
                serialized_env,
                request.serialized_allocated_resource_instances,
                setup_timeout_seconds,
            )
            creation_time_ms = int(
                round((time.perf_counter() - start) * 1000, 0))
            if not successful:
                # Recover the reference.
                self._reference_table.decrease_reference(
                    runtime_env, serialized_env, request.source_process)
            # Add the result to env cache.
            self._env_cache[serialized_env] = CreatedEnvResult(
                successful,
                serialized_context if successful else error_message,
                creation_time_ms,
            )
            # Reply the RPC
            return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK
                if successful else agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                serialized_runtime_env_context=serialized_context,
                error_message=error_message,
            )
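The timeout convention used above (-1 means "no timeout") maps directly onto asyncio.wait_for; a standalone sketch:

import asyncio

async def run_with_optional_timeout(coro, setup_timeout_seconds: int):
    # asyncio.wait_for treats timeout=None as "wait forever".
    timeout = None if setup_timeout_seconds == -1 else setup_timeout_seconds
    return await asyncio.wait_for(coro, timeout=timeout)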
Example #27
File: sdk.py Project: krfricke/ray
    def submit_job(
        self,
        *,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        """Submit and execute a job asynchronously.

        When a job is submitted, it runs once to completion or failure. Retries or
        different runs with different parameters should be handled by the
        submitter. Jobs are bound to the lifetime of a Ray cluster, so if the
        cluster goes down, all running jobs on that cluster will be terminated.

        Example:
            >>> from ray.job_submission import JobSubmissionClient
            >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
            >>> client.submit_job( # doctest: +SKIP
            ...     entrypoint="python script.py",
            ...     runtime_env={
            ...         "working_dir": "./",
            ...         "pip": ["requests==2.26.0"]
            ...     }
            ... )  # doctest: +SKIP
            'raysubmit_4LamXRuQpYdSMg7J'

        Args:
            entrypoint: The shell command to run for this job.
            job_id: A unique ID for this job.
            runtime_env: The runtime environment to install and run this job in.
            metadata: Arbitrary data to store along with this job.

        Returns:
            The job ID of the submitted job.  If not specified, this is a randomly
            generated unique ID.

        Raises:
            RuntimeError: If the request to the job server fails, or if the specified
            job_id has already been used by a job on this cluster.
        """
        runtime_env = runtime_env or {}
        metadata = metadata or {}
        metadata.update(self._default_metadata)

        self._upload_working_dir_if_needed(runtime_env)
        self._upload_py_modules_if_needed(runtime_env)

        # Run the RuntimeEnv constructor to parse local pip/conda requirements files.
        runtime_env = RuntimeEnv(**runtime_env).to_dict()

        req = JobSubmitRequest(
            entrypoint=entrypoint,
            job_id=job_id,
            runtime_env=runtime_env,
            metadata=metadata,
        )

        logger.debug(f"Submitting job with job_id={job_id}.")
        r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))

        if r.status_code == 200:
            return JobSubmitResponse(**r.json()).job_id
        else:
            self._raise_error(r)
Example #28
def test_reference_table():
    expected_unused_uris = []
    expected_unused_runtime_env = str()

    def uris_parser(runtime_env) -> Tuple[str, UriType]:
        result = list()
        result.append((runtime_env.working_dir(), UriType.WORKING_DIR))
        py_module_uris = runtime_env.py_modules()
        for uri in py_module_uris:
            result.append((uri, UriType.PY_MODULES))
        return result

    def unused_uris_processor(unused_uris: List[Tuple[str, UriType]]) -> None:
        nonlocal expected_unused_uris
        assert expected_unused_uris
        for unused in unused_uris:
            assert unused in expected_unused_uris
            expected_unused_uris.remove(unused)
        assert not expected_unused_uris

    def unused_runtime_env_processor(unused_runtime_env: str) -> None:
        nonlocal expected_unused_runtime_env
        assert expected_unused_runtime_env
        assert expected_unused_runtime_env == unused_runtime_env
        expected_unused_runtime_env = None

    reference_table = ReferenceTable(
        uris_parser, unused_uris_processor, unused_runtime_env_processor
    )
    runtime_env_1 = RuntimeEnv(
        working_dir="s3://working_dir_1.zip",
        py_modules=["s3://py_module_A.zip", "s3://py_module_B.zip"],
    )
    runtime_env_2 = RuntimeEnv(
        working_dir="s3://working_dir_2.zip",
        py_modules=["s3://py_module_A.zip", "s3://py_module_C.zip"],
    )
    # Add runtime env 1
    reference_table.increase_reference(
        runtime_env_1, runtime_env_1.serialize(), "raylet"
    )
    # Add runtime env 2
    reference_table.increase_reference(
        runtime_env_2, runtime_env_2.serialize(), "raylet"
    )
    # Add runtime env 1 by `client_server`, this will be skipped by reference table.
    reference_table.increase_reference(
        runtime_env_1, runtime_env_1.serialize(), "client_server"
    )

    # Remove runtime env 1
    expected_unused_uris.append(("s3://working_dir_1.zip", UriType.WORKING_DIR))
    expected_unused_uris.append(("s3://py_module_B.zip", UriType.PY_MODULES))
    expected_unused_runtime_env = runtime_env_1.serialize()
    reference_table.decrease_reference(
        runtime_env_1, runtime_env_1.serialize(), "raylet"
    )
    assert not expected_unused_uris
    assert not expected_unused_runtime_env

    # Remove runtime env 2
    expected_unused_uris.append(("s3://working_dir_2.zip", UriType.WORKING_DIR))
    expected_unused_uris.append(("s3://py_module_A.zip", UriType.PY_MODULES))
    expected_unused_uris.append(("s3://py_module_C.zip", UriType.PY_MODULES))
    expected_unused_runtime_env = runtime_env_2.serialize()
    reference_table.decrease_reference(
        runtime_env_2, runtime_env_2.serialize(), "raylet"
    )
    assert not expected_unused_uris
    assert not expected_unused_runtime_env
Example #29
    async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
        """List all runtime env information from the cluster.

        Returns:
            A list of runtime env information in the cluster.
            The schema of returned "dict" is equivalent to the
            `RuntimeEnvState` protobuf message.
            We don't have id -> data mapping like other API because runtime env
            doesn't have unique ids.
        """
        agent_ids = self._client.get_all_registered_agent_ids()
        replies = await asyncio.gather(
            *[
                self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
                for node_id in agent_ids
            ],
            return_exceptions=True,
        )

        result = []
        unresponsive_nodes = 0
        for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies):
            if isinstance(reply, DataSourceUnavailable):
                unresponsive_nodes += 1
                continue
            elif isinstance(reply, Exception):
                raise reply

            states = reply.runtime_env_states
            for state in states:
                data = self._message_to_dict(message=state, fields_to_decode=[])
                # Need to deserialize this field.
                data["runtime_env"] = RuntimeEnv.deserialize(
                    data["runtime_env"]
                ).to_dict()
                data["node_id"] = node_id
                result.append(data)

        partial_failure_warning = None
        if len(agent_ids) > 0 and unresponsive_nodes > 0:
            warning_msg = NODE_QUERY_FAILURE_WARNING.format(
                type="agent",
                total=len(agent_ids),
                network_failures=unresponsive_nodes,
                log_command="dashboard_agent.log",
            )
            if unresponsive_nodes == len(agent_ids):
                raise DataSourceUnavailable(warning_msg)
            partial_failure_warning = (
                f"The returned data may contain incomplete result. {warning_msg}"
            )

        result = self._filter(result, option.filters, RuntimeEnvState)

        # Sort to make the output deterministic.
        def sort_func(entry):
            # If the creation time is not present (the runtime env failed to
            # be created or hasn't been created yet), the entry gets the
            # highest priority. Otherwise, larger creation times come first.
            if "creation_time_ms" not in entry:
                return float("inf")
            elif entry["creation_time_ms"] is None:
                return float("inf")
            else:
                return float(entry["creation_time_ms"])

        result.sort(key=sort_func, reverse=True)
        return ListApiResponse(
            result=list(islice(result, option.limit)),
            partial_failure_warning=partial_failure_warning,
        )
Example #30
def test_runtime_env_interface():

    # Test the interface related to working_dir
    default_working_dir = "s3://bucket/key.zip"
    modify_working_dir = "s3://bucket/key_A.zip"
    runtime_env = RuntimeEnv(working_dir=default_working_dir)
    runtime_env_dict = runtime_env.to_dict()
    assert runtime_env.working_dir_uri() == default_working_dir
    runtime_env["working_dir"] = modify_working_dir
    runtime_env_dict["working_dir"] = modify_working_dir
    assert runtime_env.working_dir_uri() == modify_working_dir
    assert runtime_env.to_dict() == runtime_env_dict
    # Test that the modification of working_dir also works on
    # proto serialization
    assert runtime_env_dict == RuntimeEnv.from_proto(
        runtime_env.build_proto_runtime_env())
    runtime_env.pop("working_dir")
    assert runtime_env.to_dict() == {}

    # Test the interface related to py_modules
    init_py_modules = ["s3://bucket/key_1.zip", "s3://bucket/key_2.zip"]
    addition_py_modules = ["s3://bucket/key_3.zip", "s3://bucket/key_4.zip"]
    runtime_env = RuntimeEnv(py_modules=init_py_modules)
    runtime_env_dict = runtime_env.to_dict()
    assert set(runtime_env.py_modules_uris()) == set(init_py_modules)
    runtime_env["py_modules"].extend(addition_py_modules)
    runtime_env_dict["py_modules"].extend(addition_py_modules)
    assert set(runtime_env.py_modules_uris()) == set(init_py_modules +
                                                     addition_py_modules)
    assert runtime_env.to_dict() == runtime_env_dict
    # Test that the modification of py_modules also works on
    # proto serialization
    assert runtime_env_dict == RuntimeEnv.from_proto(
        runtime_env.build_proto_runtime_env())
    runtime_env.pop("py_modules")
    assert runtime_env.to_dict() == {}

    # Test the interface related to env_vars
    init_env_vars = {"A": "a", "B": "b"}
    update_env_vars = {"C": "c"}
    runtime_env = RuntimeEnv(env_vars=init_env_vars)
    runtime_env_dict = runtime_env.to_dict()
    runtime_env["env_vars"].update(update_env_vars)
    runtime_env_dict["env_vars"].update(update_env_vars)
    init_env_vars_copy = init_env_vars.copy()
    init_env_vars_copy.update(update_env_vars)
    assert runtime_env["env_vars"] == init_env_vars_copy
    assert runtime_env_dict == runtime_env.to_dict()
    # Test that the modification of env_vars also works on
    # proto serialization
    assert runtime_env_dict == RuntimeEnv.from_proto(
        runtime_env.build_proto_runtime_env())
    runtime_env.pop("env_vars")
    assert runtime_env.to_dict() == {}

    # Test the interface related to conda
    conda_name = "conda"
    modify_conda_name = "conda_A"
    conda_config = {"dependencies": ["dep1", "dep2"]}
    runtime_env = RuntimeEnv(conda=conda_name)
    runtime_env_dict = runtime_env.to_dict()
    assert runtime_env.has_conda()
    assert runtime_env.conda_env_name() == conda_name
    assert runtime_env.conda_config() is None
    runtime_env["conda"] = modify_conda_name
    runtime_env_dict["conda"] = modify_conda_name
    assert runtime_env_dict == runtime_env.to_dict()
    assert runtime_env.has_conda()
    assert runtime_env.conda_env_name() == modify_conda_name
    assert runtime_env.conda_config() is None
    runtime_env["conda"] = conda_config
    runtime_env_dict["conda"] = conda_config
    assert runtime_env_dict == runtime_env.to_dict()
    assert runtime_env.has_conda()
    assert runtime_env.conda_env_name() is None
    assert runtime_env.conda_config() == json.dumps(conda_config,
                                                    sort_keys=True)
    # Test that the modification of conda also works on
    # proto serialization
    assert runtime_env_dict == RuntimeEnv.from_proto(
        runtime_env.build_proto_runtime_env())
    runtime_env.pop("conda")
    assert runtime_env.to_dict() == {"_ray_commit": "{{RAY_COMMIT_SHA}}"}

    # Test the interface related to pip
    with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir):
        requirement_file = os.path.join(tmpdir, "requirements.txt")
        requirement_packages = ["dep5", "dep6"]
        with open(requirement_file, "wt") as f:
            for package in requirement_packages:
                f.write(package)
                f.write("\n")

        pip_packages = ["dep1", "dep2"]
        addition_pip_packages = ["dep3", "dep4"]
        runtime_env = RuntimeEnv(pip=pip_packages)
        runtime_env_dict = runtime_env.to_dict()
        assert runtime_env.has_pip()
        assert set(runtime_env.pip_config()["packages"]) == set(pip_packages)
        assert runtime_env.virtualenv_name() is None
        runtime_env["pip"]["packages"].extend(addition_pip_packages)
        runtime_env_dict["pip"]["packages"].extend(addition_pip_packages)
        # The default value of pip_check is False
        runtime_env_dict["pip"]["pip_check"] = False
        assert runtime_env_dict == runtime_env.to_dict()
        assert runtime_env.has_pip()
        assert set(
            runtime_env.pip_config()["packages"]) == set(pip_packages +
                                                         addition_pip_packages)
        assert runtime_env.virtualenv_name() is None
        runtime_env["pip"] = requirement_file
        runtime_env_dict["pip"] = requirement_packages
        assert runtime_env.has_pip()
        assert set(
            runtime_env.pip_config()["packages"]) == set(requirement_packages)
        assert runtime_env.virtualenv_name() is None
        # The default value of pip_check is False
        runtime_env_dict["pip"] = dict(packages=runtime_env_dict["pip"],
                                       pip_check=False)
        assert runtime_env_dict == runtime_env.to_dict()
        # Test that the modification of pip also works on
        # proto serialization
        assert runtime_env_dict == RuntimeEnv.from_proto(
            runtime_env.build_proto_runtime_env())
        runtime_env.pop("pip")
        assert runtime_env.to_dict() == {"_ray_commit": "{{RAY_COMMIT_SHA}}"}

    # Test conflict
    with pytest.raises(ValueError):
        RuntimeEnv(pip=pip_packages, conda=conda_name)

    runtime_env = RuntimeEnv(pip=pip_packages)
    runtime_env["conda"] = conda_name
    with pytest.raises(ValueError):
        runtime_env.serialize()

    # Test the interface related to container
    container_init = {
        "image": "anyscale/ray-ml:nightly-py38-cpu",
        "worker_path": "/root/python/ray/workers/default_worker.py",
        "run_options": ["--cap-drop SYS_ADMIN", "--log-level=debug"],
    }
    update_container = {"image": "test_modify"}
    runtime_env = RuntimeEnv(container=container_init)
    runtime_env_dict = runtime_env.to_dict()
    assert runtime_env.has_py_container()
    assert runtime_env.py_container_image() == container_init["image"]
    assert (runtime_env.py_container_worker_path()
            == container_init["worker_path"])
    assert (runtime_env.py_container_run_options()
            == container_init["run_options"])
    runtime_env["container"].update(update_container)
    runtime_env_dict["container"].update(update_container)
    container_copy = container_init
    container_copy.update(update_container)
    assert runtime_env_dict == runtime_env.to_dict()
    assert runtime_env.has_py_container()
    assert runtime_env.py_container_image() == container_copy["image"]
    assert (runtime_env.py_container_worker_path()
            == container_copy["worker_path"])
    assert (runtime_env.py_container_run_options()
            == container_copy["run_options"])
    # Test that the modification of container also works on
    # proto serialization
    assert runtime_env_dict == RuntimeEnv.from_proto(
        runtime_env.build_proto_runtime_env())
    runtime_env.pop("container")
    assert runtime_env.to_dict() == {}