def connect_to_client_or_not(connect_to_client: bool):
    """Utility for running test logic with and without a Ray client connection.

    Used as a context manager (see the example below); it yields exactly once.

    If connect_to_client is True, will connect to Ray client in context
    (via ``ray_start_client_server(namespace="")`` plus
    ``enable_client_mode()``).
    If connect_to_client is False, does nothing.

    How to use:
    Given a test of the following form:

    def test_<name>(args):
        <initialize a ray cluster>
        <use the ray cluster>

    Modify the test to

    @pytest.mark.parametrize("connect_to_client", [False, True])
    def test_<name>(args, connect_to_client):
        <initialize a ray cluster>
        with connect_to_client_or_not(connect_to_client):
            <use the ray cluster>

    Parameterize the argument connect_to_client over True, False to run the
    test with and without a Ray client connection.
    """
    if connect_to_client:
        with ray_start_client_server(namespace=""), enable_client_mode():
            yield
    else:
        yield
def test_pretask_posttask_shared_state_multi_client(ray_start_regular_shared):
    """Client-mode variant of the shared-state pretask/posttask callback test.

    Four Ray-Dask callbacks are stacked while a Dask graph is computed over a
    Ray client connection. Each callback with both hooks must see its own
    pretask return value in its posttask hook. A callback with only a posttask
    hook must see ``None``, even though another stacked callback returns
    "baz" from its pretask hook.
    """

    class SuffixEchoCallback(RayDaskCallback):
        def __init__(self, suffix):
            self.suffix = suffix

        def _ray_pretask(self, key, object_refs):
            return key + self.suffix

        def _ray_posttask(self, key, result, pre_state):
            # The state handed back must be exactly what our pretask produced.
            assert pre_state == key + self.suffix

    class PretaskReturnsBaz(RayDaskCallback):
        def _ray_pretask(self, key, object_refs):
            return "baz"

    class PosttaskExpectsNone(RayDaskCallback):
        def _ray_posttask(self, key, result, pre_state):
            # No pretask hook here, so no state should reach this callback.
            assert pre_state is None

    foo_callback = SuffixEchoCallback("foo")
    baz_callback = PretaskReturnsBaz()
    none_callback = PosttaskExpectsNone()
    bar_callback = SuffixEchoCallback("bar")

    with ray_start_client_server(), enable_client_mode():
        with foo_callback, baz_callback, none_callback, bar_callback:
            delayed_sum = add(2, 3)
            computed = delayed_sum.compute(scheduler=ray_dask_get)
        assert computed == 5
# Renamed from ``test_rllib_integration``: a later ``test_rllib_integration``
# definition in this module rebinds that name, which shadowed this test so
# pytest never collected or ran it.
def test_rllib_integration_multiagent(ray_start_regular_shared):
    """Run the RLlib rock-paper-scissors multi-agent example over Ray client."""
    with ray_start_client_server():
        # Confirming the behavior of this context manager.
        # (Client mode hook not yet enabled.)
        assert not client_mode_should_convert()
        # Need to enable this for client APIs to be used.
        with enable_client_mode():
            # Confirming mode hook is enabled.
            assert client_mode_should_convert()
            rock_paper_scissors_multiagent.main()
def test_rllib_integration_tune(ray_start_regular_shared):
    """Launch a short Tune run of RLlib's "DQN" through a Ray client connection.

    The client mode hook must be off right after the client server starts, and
    on once ``enable_client_mode()`` has been entered.
    """
    dqn_config = {"env": "CartPole-v1"}
    stop_criteria = {"training_iteration": 2}
    with ray_start_client_server():
        # Starting the client server by itself must not enable the hook.
        assert not client_mode_should_convert(auto_init=True)
        with enable_client_mode():
            # Client APIs are only routed once the hook is switched on.
            assert client_mode_should_convert(auto_init=True)
            tune.run("DQN", config=dqn_config, stop=stop_criteria)
def test_client_gpu_ids(call_ray_stop_only):
    """``ray.get_gpu_ids()`` in client mode depends on a live client connection.

    Without a connection the call must raise with a specific message. Once a
    client server is running, it must return an empty list.
    """
    import ray

    expected_error = ("Ray Client is not connected."
                      " Please connect by calling `ray.init`.")
    ray.init(num_cpus=2)

    with enable_client_mode():
        # No client connection exists yet.
        with pytest.raises(Exception) as excinfo:
            ray.get_gpu_ids()
        assert str(excinfo.value) == expected_error

        with ray_start_client_server():
            # A client connection is now available.
            assert ray.get_gpu_ids() == []
async def test_serve_handle(ray_start_regular_shared):
    """Deploy a trivial Serve deployment over Ray client and query its handle.

    The handle's ``remote()`` result is checked twice: once through the
    connected client's ``get`` and once by awaiting it directly.
    """
    with ray_start_client_server() as ray_client:
        from ray import serve

        with enable_client_mode():
            serve.start()

            @serve.deployment
            def hello():
                return "hello"

            hello.deploy()
            hello_handle = hello.get_handle()
            assert ray_client.get(hello_handle.remote()) == "hello"
            assert await hello_handle.remote() == "hello"
def test_client_mode_hook_thread_safe(ray_start_regular_shared):
    """Check that ``disable_client_hook()`` only affects the calling thread.

    A worker thread disables the client hook and then blocks on a lock. While
    the worker is inside ``disable_client_hook()``, the main thread must still
    see client mode as enabled. The worker records what it saw inside the
    context (expected False) and after leaving it (expected True).
    """
    with ray_start_client_server():
        with enable_client_mode():
            assert client_mode_should_convert(auto_init=True)
            # Set by the worker once it is inside disable_client_hook(), so
            # the main-thread check below really overlaps the disabled region.
            hook_disabled = threading.Event()
            # Held by the main thread; the worker blocks on it while the hook
            # is disabled for that thread.
            lock = threading.Lock()
            lock.acquire()
            q = queue.Queue()

            def disable():
                with disable_client_hook():
                    q.put(client_mode_should_convert(auto_init=True))
                    hook_disabled.set()
                    lock.acquire()
                q.put(client_mode_should_convert(auto_init=True))

            t = threading.Thread(target=disable)
            t.start()
            try:
                assert hook_disabled.wait(timeout=30), \
                    "Worker thread never entered disable_client_hook"
                assert client_mode_should_convert(auto_init=True)
            finally:
                # Always unblock and reap the worker; otherwise a failed
                # assertion above would leave it blocked on the lock forever
                # and hang the test process.
                lock.release()
                t.join()
            # The worker has finished, so both results are already queued;
            # get_nowait() fails fast (queue.Empty) if the worker died early.
            assert q.get_nowait() is False, \
                "Threaded disable_client_hook failed to disable"
            assert q.get_nowait() is True, \
                "Threaded disable_client_hook failed to re-enable"
def test_rllib_integration(ray_start_regular_shared):
    """Train RLlib's SimpleQ on a local worker through a Ray client connection.

    Checks that the client mode hook is off right after the client server
    starts and on inside ``enable_client_mode()``. It then samples and trains
    for a couple of iterations, verifying each sample batch size against
    ``rollout_fragment_length``.
    """
    with ray_start_client_server():
        import ray.rllib.agents.dqn as dqn

        # Confirming the behavior of this context manager.
        # (Client mode hook not yet enabled.)
        assert not client_mode_should_convert()
        # Need to enable this for client APIs to be used.
        with enable_client_mode():
            # Confirming mode hook is enabled.
            assert client_mode_should_convert()

            config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
            # Run locally.
            config["num_workers"] = 0
            # Test with compression.
            config["compress_observations"] = True
            num_iterations = 2
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v1")
            try:
                rw = trainer.workers.local_worker()
                for _ in range(num_iterations):
                    sb = rw.sample()
                    assert sb.count == config["rollout_fragment_length"]
                    trainer.train()
            finally:
                # Stop the trainer even if an assertion fails, so it does not
                # keep holding resources on the (presumably shared) cluster.
                trainer.stop()