コード例 #1
0
def connect_to_client_or_not(connect_to_client: bool):
    """Utility for running test logic with and without a Ray client connection.

    If ``connect_to_client`` is True, starts a Ray client server (with an
    empty namespace), enables client mode, and yields inside that context.
    If ``connect_to_client`` is False, yields once and does nothing else.

    How to use:
    Given a test of the following form:

        def test_<name>(args):
            <initialize a ray cluster>
            <use the ray cluster>

    Modify the test to:

        @pytest.mark.parametrize("connect_to_client", [False, True])
        def test_<name>(args, connect_to_client):
            <initialize a ray cluster>
            with connect_to_client_or_not(connect_to_client):
                <use the ray cluster>

    Parametrizing ``connect_to_client`` over False and True runs the test
    both without and with a Ray client connection.

    Args:
        connect_to_client: Whether to connect to a Ray client server for the
            duration of the ``with`` block.
    """
    # NOTE(review): the ``with`` usage above requires this generator to be
    # wrapped by ``contextlib.contextmanager``; that decorator is not visible
    # in this snippet — confirm it is applied where this is defined.
    if connect_to_client:
        with ray_start_client_server(namespace=""), enable_client_mode():
            yield
    else:
        yield
コード例 #2
0
def test_pretask_posttask_shared_state_multi_client(ray_start_regular_shared):
    """
    Stack four Dask-on-Ray callbacks over a Ray client connection and check
    that every callback's posttask hook receives only the state returned by
    its own pretask hook.
    """

    class SuffixStateCallback(RayDaskCallback):
        # Pretask returns ``key + suffix``; posttask checks it got that back.
        def __init__(self, suffix):
            self.suffix = suffix

        def _ray_pretask(self, key, object_refs):
            return key + self.suffix

        def _ray_posttask(self, key, result, pre_state):
            assert pre_state == key + self.suffix

    class PretaskOnly(RayDaskCallback):
        def _ray_pretask(self, key, object_refs):
            return "baz"

    class PosttaskOnly(RayDaskCallback):
        def _ray_posttask(self, key, result, pre_state):
            # This callback defines no pretask hook, so it must not see the
            # "baz" state produced by PretaskOnly.
            assert pre_state is None

    foo_cb, baz_cb, none_cb, bar_cb = (
        SuffixStateCallback("foo"),
        PretaskOnly(),
        PosttaskOnly(),
        SuffixStateCallback("bar"),
    )

    # A single ``with`` over several managers enters and exits them in the
    # same order as the equivalent nested ``with`` statements.
    with ray_start_client_server(), enable_client_mode(), \
            foo_cb, baz_cb, none_cb, bar_cb:
        result = add(2, 3).compute(scheduler=ray_dask_get)

    assert result == 5
コード例 #3
0
def test_rllib_integration(ray_start_regular_shared):
    """Run the RLlib rock-paper-scissors multi-agent example via Ray client."""
    with ray_start_client_server():
        # Starting the client server alone leaves the client-mode hook off;
        # it has to be switched on explicitly.
        hook_on = client_mode_should_convert()
        assert not hook_on
        with enable_client_mode():
            # Inside enable_client_mode() the hook must report it is active.
            hook_on = client_mode_should_convert()
            assert hook_on

            rock_paper_scissors_multiagent.main()
コード例 #4
0
def test_rllib_integration_tune(ray_start_regular_shared):
    """Run a short Tune DQN experiment on CartPole-v1 through Ray client."""
    with ray_start_client_server():
        # The client server by itself does not activate the client-mode hook.
        hook_on = client_mode_should_convert(auto_init=True)
        assert not hook_on
        with enable_client_mode():
            # After enabling client mode the hook must be active.
            hook_on = client_mode_should_convert(auto_init=True)
            assert hook_on
            experiment_config = {"env": "CartPole-v1"}
            stop_criteria = {"training_iteration": 2}
            tune.run("DQN", config=experiment_config, stop=stop_criteria)
コード例 #5
0
def test_client_gpu_ids(call_ray_stop_only):
    """Check ray.get_gpu_ids() in client mode, before and after connecting."""
    import ray
    ray.init(num_cpus=2)

    not_connected_msg = ("Ray Client is not connected."
                         " Please connect by calling `ray.init`.")

    with enable_client_mode():
        # Client mode is on, but no client connection exists yet, so the
        # call must fail with the "not connected" message.
        with pytest.raises(Exception) as excinfo:
            ray.get_gpu_ids()
        assert str(excinfo.value) == not_connected_msg

        with ray_start_client_server():
            # With a client connection in place the call succeeds.
            gpu_ids = ray.get_gpu_ids()
            assert gpu_ids == []
コード例 #6
0
async def test_serve_handle(ray_start_regular_shared):
    """Deploy a Serve deployment over Ray client and query it via its handle.

    The handle's result is checked both by fetching it through the client
    object and by awaiting the handle call directly.
    """
    with ray_start_client_server() as client:
        from ray import serve
        with enable_client_mode():
            serve.start()

            # The function name is kept as-is: the deployment is created
            # from this function.
            @serve.deployment
            def hello():
                return "hello"

            hello.deploy()
            handle = hello.get_handle()

            sync_reply = client.get(handle.remote())
            assert sync_reply == "hello"

            async_reply = await handle.remote()
            assert async_reply == "hello"
コード例 #7
0
ファイル: test_client.py プロジェクト: felipeeeantunes/ray
def test_client_mode_hook_thread_safe(ray_start_regular_shared):
    """Check that disable_client_hook() only affects the calling thread.

    A worker thread enters disable_client_hook(), records the hook state it
    sees, and then blocks on a lock. While the worker is still inside that
    context, the main thread must still see the hook enabled. After the
    worker leaves the context, it records the state again, which must be
    enabled.
    """
    with ray_start_client_server():
        with enable_client_mode():
            assert client_mode_should_convert(auto_init=True)
            lock = threading.Lock()
            # Held by the main thread so the worker stays inside
            # disable_client_hook() until we release it.
            lock.acquire()
            q = queue.Queue()

            def disable():
                with disable_client_hook():
                    q.put(client_mode_should_convert(auto_init=True))
                    lock.acquire()
                q.put(client_mode_should_convert(auto_init=True))

            # daemon=True so a stuck worker can never hang interpreter exit.
            t = threading.Thread(target=disable, daemon=True)
            t.start()
            try:
                # Block until the worker has actually entered
                # disable_client_hook(); otherwise the assertion below could
                # run before the worker disables anything.
                disabled_in_thread = q.get(timeout=30)
                assert client_mode_should_convert(auto_init=True)
            finally:
                # Always let the worker finish, even if an assertion failed.
                lock.release()
                t.join(timeout=30)
            assert disabled_in_thread is False, \
                "Threaded disable_client_hook failed to disable"
            assert q.get(timeout=30) is True, \
                "Threaded disable_client_hook failed to re-enable"
コード例 #8
0
def test_rllib_integration(ray_start_regular_shared):
    """Train a local SimpleQ trainer on CartPole-v1 through Ray client.

    Checks that the client-mode hook is off until ``enable_client_mode`` is
    entered. It then samples from the local rollout worker and trains for a
    few iterations with observation compression enabled.
    """
    with ray_start_client_server():
        import ray.rllib.agents.dqn as dqn
        # Starting the client server alone does not enable the client-mode
        # hook.
        assert not client_mode_should_convert()
        # Need to enable this for client APIs to be used.
        with enable_client_mode():
            # Confirming mode hook is enabled.
            assert client_mode_should_convert()

            config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
            # Run locally (no remote rollout workers).
            config["num_workers"] = 0
            # Test with compression.
            config["compress_observations"] = True
            num_iterations = 2
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v1")
            try:
                rw = trainer.workers.local_worker()
                for _ in range(num_iterations):
                    sb = rw.sample()
                    # The sampled batch size must equal the configured
                    # rollout fragment length.
                    assert sb.count == config["rollout_fragment_length"]
                    trainer.train()
            finally:
                # Release the trainer's resources even if an assertion fails.
                # The cluster fixture (ray_start_regular_shared) is presumably
                # reused by other tests — confirm.
                trainer.stop()