예제 #1
0
def test_two_node_local_file(two_node_cluster, working_dir, client_mode):
    """Run a driver whose runtime_env points at a local ``working_dir``.

    A file named ``test_file`` containing "1" is written into
    ``working_dir``. A driver script is then rendered from ``driver_script``
    and executed; the test expects the last whitespace-separated token of
    its output to be "1000", ``PKG_DIR`` to contain exactly one entry, and
    the internal KV listing for "gcs://" to be empty afterwards.

    NOTE(review): ``driver_script`` is formatted with ``**locals()``, so the
    local variable names in this function (e.g. ``runtime_env``,
    ``execute_statement``, ``address``) are presumably referenced by the
    template -- confirm against the template before renaming any of them.
    """
    with open(os.path.join(working_dir, "test_file"), "w") as f:
        f.write("1")
    cluster, _ = two_node_cluster
    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
    # Test runtime_env with working_dir
    runtime_env = f"""{{  "working_dir": "{working_dir}" }}"""
    # Execute the following cmd in driver with runtime_env
    execute_statement = """
vals = ray.get([check_file.remote('test_file')] * 1000)
print(sum([int(v) for v in vals]))
"""
    # Every local defined above is handed to the template as a format field.
    script = driver_script.format(**locals())
    out = run_string_as_driver(script, env)
    # The driver prints the sum of 1000 values, each expected to be 1.
    assert out.strip().split()[-1] == "1000"
    assert len(list(Path(PKG_DIR).iterdir())) == 1
    assert len(kv._internal_kv_list("gcs://")) == 0
예제 #2
0
    async def get_serve_info(self) -> Dict[str, Any]:
        """Collect Serve deployment snapshots from Ray's internal KV store.

        Returns:
            A dict mapping the SHA-1 hex digest of each deployment name to
            that deployment's info, merged over every snapshot key found.
            An empty dict is returned when the Serve modules cannot be
            imported.
        """
        # Serve is imported lazily: when only ray[default] is installed its
        # dependencies may be missing and the import would raise
        # ModuleNotFoundError (#17712).
        try:
            from ray.serve.controller import SNAPSHOT_KEY as SERVE_SNAPSHOT_KEY
            from ray.serve.constants import SERVE_CONTROLLER_NAME
        except Exception:
            return {}

        serve_namespace = ray_constants.KV_NAMESPACE_SERVE

        # Serve wraps Ray's internal KV store and specially formats the keys:
        # SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY
        # List by controller name, then keep only the snapshot keys.
        snapshot_keys = [
            key for key in _internal_kv_list(
                SERVE_CONTROLLER_NAME, namespace=serve_namespace)
            if SERVE_SNAPSHOT_KEY in str(key)
        ]

        snapshots: List[Dict[str, Any]] = []
        for key in snapshot_keys:
            raw = _internal_kv_get(key, namespace=serve_namespace)
            if not raw:
                # A missing or empty value is treated as an empty snapshot.
                raw = b"{}"
            snapshots.append(json.loads(raw.decode("utf-8")))

        # Merge the deployments of all controllers; for a repeated name the
        # snapshot processed last wins.
        merged: Dict[str, Any] = {}
        for snapshot in snapshots:
            for name, info in snapshot.items():
                merged[name] = info

        # Key the result by the hash of each deployment name to prevent
        # collisions caused by the automatic conversion to camelcase by the
        # dashboard agent.
        hashed: Dict[str, Any] = {}
        for name, info in merged.items():
            hashed[hashlib.sha1(name.encode()).hexdigest()] = info
        return hashed
def check_internal_kv_gced():
    """Return True when the internal KV listing for "gcs://" is empty."""
    remaining_keys = kv._internal_kv_list("gcs://")
    return len(remaining_keys) == 0
예제 #4
0
        return "DONE"

    def get_step(self):
        """Return the value currently stored in ``self.current_step``."""
        return self.current_step

    def stop(self):
        """Set the ``stopped`` flag on this instance to True."""
        self.stopped = True


if __name__ == "__main__":
    # Script entry point: parse the step options, run a StepActor to
    # completion, then print the first "JOB:" entry from the internal KV.
    parser = argparse.ArgumentParser()
    parser.add_argument("--interval-s",
                        required=False,
                        type=int,
                        default=1,
                        help="Step interval, in seconds, passed to StepActor")
    parser.add_argument("--total-steps",
                        required=False,
                        type=int,
                        default=3,
                        help="Total number of steps, passed to StepActor")
    # parse_known_args: unrecognized command-line arguments are ignored
    # rather than treated as an error.
    args, _ = parser.parse_known_args()

    ray.init()
    step_actor = StepActor.remote(interval_s=args.interval_s,
                                  total_steps=args.total_steps)
    ref = step_actor.run.remote()
    # Block until the actor's run() call has finished and print its result.
    print(ray.get([ref]))
    # Assumes at least one "JOB:" key is listed; the [0] index raises
    # IndexError otherwise.
    job_key = ray_kv._internal_kv_list("JOB:")[0]
    print(f"{job_key}, {ray_kv._internal_kv_get(job_key)}")
예제 #5
0
 def get_all_jobs(self) -> Dict[str, JobStatusInfo]:
     """Return the status of every job listed under the status key prefix.

     Returns:
         A dict mapping each decoded key returned by the KV listing to the
         result of ``self.get_status`` for that key.
     """
     # NOTE(review): the listed keys are used as job ids verbatim; if the
     # listing returns keys that still carry JOB_STATUS_KEY_PREFIX, confirm
     # that get_status expects that form.
     listed_ids = [
         raw_id.decode()
         for raw_id in _internal_kv_list(self.JOB_STATUS_KEY_PREFIX)
     ]
     statuses: Dict[str, JobStatusInfo] = {}
     for job_id in listed_ids:
         statuses[job_id] = self.get_status(job_id)
     return statuses