# 示例#1 (Example #1)
0
def test_ray_dask_basic(ray_start_1_cpu):
    """Check that a Dask delayed graph runs on Ray via each entry point.

    The same ``add(2, 4)`` graph is computed four ways: with an explicit
    ``scheduler=ray_dask_get`` argument, with the scheduler enabled globally
    through ``enable_dask_on_ray()``, with ``enable_dask_on_ray()`` used as a
    context manager, and from inside a Ray task.

    Args:
        ray_start_1_cpu: pytest fixture; not referenced in the body.
            Presumably it starts a Ray runtime for the test -- confirm
            against the fixture definition.
    """
    # NOTE(review): ``ray``, ``dask`` and ``ProgressBarCallback`` are not
    # imported inside this function; they are expected to come from module
    # scope, which is not visible here -- confirm.
    from ray.util.dask import ray_dask_get, enable_dask_on_ray, \
        disable_dask_on_ray

    @ray.remote
    def stringify(x):
        return "The answer is {}".format(x)

    zero_id = ray.put(0)

    def add(x, y):
        # Can retrieve ray objects from inside Dask.
        zero = ray.get(zero_id)
        # Can call Ray methods from inside Dask.
        return ray.get(stringify.remote(x + y + zero))

    add = dask.delayed(add)

    expected = "The answer is 6"
    # Test with explicit scheduler argument.
    assert add(2, 4).compute(scheduler=ray_dask_get) == expected

    # Test with config setter.
    enable_dask_on_ray()
    try:
        assert add(2, 4).compute() == expected
    finally:
        # Always undo the global scheduler override, even when the assertion
        # fails, so the setting cannot leak into whatever runs next.
        disable_dask_on_ray()

    # Test with config setter as context manager.
    with enable_dask_on_ray():
        assert add(2, 4).compute() == expected

    # Test within Ray task.

    @ray.remote
    def call_add():
        z = add(2, 4)
        with ProgressBarCallback():
            r = z.compute(scheduler=ray_dask_get)
        return r

    ans = ray.get(call_add.remote())
    # Pass the actual value as the assertion message to ease debugging.
    assert ans == expected, ans
# Example script: run Dask array and dataframe computations on Ray, first by
# passing the scheduler explicitly, then through the Dask config helper.
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import ray
from ray.util.dask import (
    disable_dask_on_ray,
    enable_dask_on_ray,
    ray_dask_get,
)

# Start Ray.
# Tip: If connecting to an existing cluster, use ray.init(address="auto").
ray.init()

try:
    d_arr = da.from_array(np.random.randint(0, 1000, size=(256, 256)))

    # The Dask scheduler submits the underlying task graph to Ray.
    d_arr.mean().compute(scheduler=ray_dask_get)

    # Use our Dask config helper to set the scheduler to ray_dask_get
    # globally, without having to specify it on each compute call.
    enable_dask_on_ray()
    try:
        df = dd.from_pandas(
            pd.DataFrame(np.random.randint(0, 100, size=(1024, 2)),
                         columns=["age", "grade"]),
            npartitions=2)
        df.groupby(["age"]).mean().compute()
    finally:
        # Undo the global scheduler override even if the computation fails.
        disable_dask_on_ray()

    # The Dask config helper can be used as a context manager, limiting the
    # scope of the Dask-on-Ray scheduler to the context.
    with enable_dask_on_ray():
        d_arr.mean().compute()
finally:
    # Shut Ray down whether or not the computations above succeeded.
    ray.shutdown()