Example #1
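These test functions are taken from Ray's workflow test suite. To run one standalone you would also need the test module's imports and pytest fixtures (fixture names such as workflow_start_regular and workflow_start_regular_shared come from Ray's test conftest). A minimal import preamble for the examples below, assuming a Ray version from the workflow-alpha era, might look like this:

import asyncio
import tempfile
import time
from pathlib import Path

import numpy as np
import pytest
from filelock import FileLock

import ray
from ray import workflow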
def test_get_output_3(workflow_start_regular, tmp_path):
    cnt_file = tmp_path / "counter"
    cnt_file.write_text("0")
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def incr():
        v = int(cnt_file.read_text())
        cnt_file.write_text(str(v + 1))
        if error_flag.exists():
            raise ValueError()
        return 10

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.create(incr.options(**workflow.options(max_retries=0)).bind()).run(
            "incr"
        )

    assert cnt_file.read_text() == "1"

    from ray.exceptions import RaySystemError

    # TODO(suquark): We should prevent Ray from raising "RaySystemError" in
    #   workflows, because "RaySystemError" does not inherit from the
    #   underlying error, so users and developers cannot catch the expected
    #   error. I feel this issue is very annoying.
    with pytest.raises((RaySystemError, ValueError)):
        ray.get(workflow.get_output("incr"))

    assert cnt_file.read_text() == "1"
    error_flag.unlink()
    with pytest.raises((RaySystemError, ValueError)):
        ray.get(workflow.get_output("incr"))
    assert ray.get(workflow.resume("incr")) == 10
Example #2
def test_output_with_name(workflow_start_regular):
    @ray.remote
    def double(v):
        return 2 * v

    inner_task = double.options(**workflow.options(name="inner")).bind(1)
    outer_task = double.options(**workflow.options(name="outer")).bind(inner_task)
    result = workflow.create(outer_task).run_async("double")
    inner = workflow.get_output("double", name="inner")
    outer = workflow.get_output("double", name="outer")

    assert ray.get(inner) == 2
    assert ray.get(outer) == 4
    assert ray.get(result) == 4

    @workflow.options(name="double")
    @ray.remote
    def double_2(s):
        return s * 2

    inner_task = double_2.bind(1)
    outer_task = double_2.bind(inner_task)
    workflow_id = "double_2"
    result = workflow.create(outer_task).run_async(workflow_id)

    inner = workflow.get_output(workflow_id, name="double")
    outer = workflow.get_output(workflow_id, name="double_1")

    assert ray.get(inner) == 2
    assert ray.get(outer) == 4
    assert ray.get(result) == 4
Example #3
def test_dedupe_indirect(workflow_start_regular_shared, tmp_path):
    counter = Path(tmp_path) / "counter.txt"
    lock = Path(tmp_path) / "lock.txt"
    counter.write_text("0")

    @ray.remote
    def incr():
        with FileLock(str(lock)):
            c = int(counter.read_text())
            c += 1
            counter.write_text(f"{c}")

    @ray.remote
    def identity(a):
        return a

    @ray.remote
    def join(*a):
        return counter.read_text()

    # Here `a` is passed to two steps, and we need to ensure
    # it is executed only once.
    a = incr.bind()
    i1 = identity.bind(a)
    i2 = identity.bind(a)
    assert "1" == workflow.create(join.bind(i1, i2)).run()
    assert "2" == workflow.create(join.bind(i1, i2)).run()
    # pass a multiple times
    assert "3" == workflow.create(join.bind(a, a, a, a)).run()
    assert "4" == workflow.create(join.bind(a, a, a, a)).run()
def test_nested_catch_exception_3(workflow_start_regular_shared, tmp_path):
    """Test the case where the exception is not raised by the output task of
    a nested DAG."""
    @ray.remote
    def f3():
        return 10

    @ray.remote
    def f3_exc():
        raise ValueError()

    @ray.remote
    def f2(x):
        return x

    @ray.remote
    def f1(exc):
        if exc:
            return workflow.continuation(f2.bind(f3_exc.bind()))
        else:
            return workflow.continuation(f2.bind(f3.bind()))

    ret, err = workflow.create(
        f1.options(**workflow.options(
            catch_exceptions=True)).bind(True)).run()
    assert ret is None
    assert isinstance(err, ValueError)

    assert (10, None) == workflow.create(
        f1.options(**workflow.options(
            catch_exceptions=True)).bind(False)).run()
Example #5
def test_runtime_metadata(workflow_start_regular):

    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.create(simple.bind()).run(workflow_id)

    workflow_metadata = workflow.get_metadata("simple")
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (
        workflow_metadata["stats"]["end_time"]
        >= workflow_metadata["stats"]["start_time"] + 2
    )

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (
        step_metadata["stats"]["end_time"] >= step_metadata["stats"]["start_time"] + 2
    )
Example #6
def test_nested_workflow(workflow_start_regular):
    @workflow.options(name="inner", metadata={"inner_k": "inner_v"})
    @ray.remote
    def inner():
        time.sleep(2)
        return 10

    @workflow.options(name="outer", metadata={"outer_k": "outer_v"})
    @ray.remote
    def outer():
        time.sleep(2)
        return workflow.continuation(inner.bind())

    workflow.create(outer.bind()).run("nested",
                                      metadata={"workflow_k": "workflow_v"})

    workflow_metadata = workflow.get_metadata("nested")
    outer_step_metadata = workflow.get_metadata("nested", "outer")
    inner_step_metadata = workflow.get_metadata("nested", "inner")

    assert workflow_metadata["user_metadata"] == {"workflow_k": "workflow_v"}
    assert outer_step_metadata["user_metadata"] == {"outer_k": "outer_v"}
    assert inner_step_metadata["user_metadata"] == {"inner_k": "inner_v"}

    assert (workflow_metadata["stats"]["end_time"] >=
            workflow_metadata["stats"]["start_time"] + 4)
    assert (outer_step_metadata["stats"]["end_time"] >=
            outer_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["end_time"] >=
            inner_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["start_time"] >=
            outer_step_metadata["stats"]["end_time"])
Example #7
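# NOTE: An older variant of the test in Example #1. In that Ray version, step
# failures surfaced as ray.exceptions.RaySystemError rather than
# workflow.WorkflowExecutionError, and retry options were passed directly to
# .options().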
def test_get_output_3(workflow_start_regular, tmp_path):
    cnt_file = tmp_path / "counter"
    cnt_file.write_text("0")
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def incr():
        v = int(cnt_file.read_text())
        cnt_file.write_text(str(v + 1))
        if error_flag.exists():
            raise ValueError()
        return 10

    with pytest.raises(ray.exceptions.RaySystemError):
        workflow.create(incr.options(max_retries=0).bind()).run("incr")

    assert cnt_file.read_text() == "1"

    with pytest.raises(ray.exceptions.RaySystemError):
        ray.get(workflow.get_output("incr"))

    assert cnt_file.read_text() == "1"
    error_flag.unlink()
    with pytest.raises(ray.exceptions.RaySystemError):
        ray.get(workflow.get_output("incr"))
    assert ray.get(workflow.resume("incr")) == 10
Example #8
def test_running_and_canceled_workflow(workflow_start_regular, tmp_path):

    workflow_id = "simple"
    flag = tmp_path / "flag"

    @ray.remote
    def simple():
        flag.touch()
        time.sleep(1000)
        return 0

    workflow.create(simple.bind()).run_async(workflow_id)

    # Wait until step runs to make sure pre-run metadata is written
    while not flag.exists():
        time.sleep(1)

    workflow_metadata = workflow.get_metadata(workflow_id)
    assert workflow_metadata["status"] == "RUNNING"
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" not in workflow_metadata["stats"]

    workflow.cancel(workflow_id)

    workflow_metadata = workflow.get_metadata(workflow_id)
    assert workflow_metadata["status"] == "CANCELED"
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" not in workflow_metadata["stats"]
Example #9
def test_successful_workflow(workflow_start_regular):

    user_step_metadata = {"k1": "v1"}
    user_run_metadata = {"k2": "v2"}
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name, metadata=user_step_metadata)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.create(simple.bind()).run(workflow_id, metadata=user_run_metadata)

    workflow_metadata = workflow.get_metadata("simple")
    assert workflow_metadata["status"] == "SUCCESSFUL"
    assert workflow_metadata["user_metadata"] == user_run_metadata
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (
        workflow_metadata["stats"]["end_time"]
        >= workflow_metadata["stats"]["start_time"] + 2
    )

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert step_metadata["user_metadata"] == user_step_metadata
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (
        step_metadata["stats"]["end_time"] >= step_metadata["stats"]["start_time"] + 2
    )
Example #10
def test_object_deref(workflow_start_regular_shared):
    @ray.remote
    def empty_list():
        return [1]

    @ray.remote
    def receive_workflow(workflow):
        pass

    @ray.remote
    def return_workflow():
        return workflow.create(empty_list.bind())

    @ray.remote
    def return_data() -> ray.ObjectRef:
        return ray.put(np.ones(4096))

    @ray.remote
    def receive_data(data: "ray.ObjectRef[np.ndarray]"):
        return ray.get(data)

    # Test that we are forbidden from directly passing a workflow to Ray.
    x = workflow.create(empty_list.bind())
    with pytest.raises(ValueError):
        ray.put(x)
    with pytest.raises(ValueError):
        ray.get(receive_workflow.remote(x))
    with pytest.raises(ValueError):
        ray.get(return_workflow.remote())

    # Test returning an object ref from a task.
    obj = return_data.bind()
    arr: np.ndarray = workflow.create(receive_data.bind(obj)).run()
    assert np.array_equal(arr, np.ones(4096))
Example #11
def test_dedupe_serialization(workflow_start_regular_shared):
    @ray.remote(num_cpus=0)
    class Counter:
        def __init__(self):
            self.count = 0

        def incr(self):
            self.count += 1

        def get_count(self):
            return self.count

    counter = Counter.remote()

    class CustomClass:
        def __getstate__(self):
            # Count the number of times this class is serialized.
            ray.get(counter.incr.remote())
            return {}

    ref = ray.put(CustomClass())
    list_of_refs = [ref for _ in range(2)]

    # One for the ray.put
    assert ray.get(counter.get_count.remote()) == 1
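
    # NOTE: `identity` and `gather` are assumed to be module-level @ray.remote
    # helpers from the original test file (identity returns its argument,
    # gather returns its argument tuple).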

    single = identity.bind((ref, ))
    double = identity.bind(list_of_refs)

    workflow.create(gather.bind(single, double)).run()

    # Two more: one for hashing the ref and one for uploading it.
    assert ray.get(counter.get_count.remote()) == 3
Example #12
def test_failed_and_resumed_workflow(workflow_start_regular, tmp_path):

    workflow_id = "simple"
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def simple():
        if error_flag.exists():
            raise ValueError()
        return 0

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.create(simple.bind()).run(workflow_id)

    workflow_metadata_failed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_failed["status"] == "FAILED"

    error_flag.unlink()
    ref = workflow.resume(workflow_id)
    assert ray.get(ref) == 0

    workflow_metadata_resumed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_resumed["status"] == "SUCCESSFUL"

    # make sure resume updated running metrics
    assert (
        workflow_metadata_resumed["stats"]["start_time"]
        > workflow_metadata_failed["stats"]["start_time"]
    )
    assert (
        workflow_metadata_resumed["stats"]["end_time"]
        > workflow_metadata_failed["stats"]["end_time"]
    )
Example #13
def test_recovery_simple(workflow_start_regular):
    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"
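
    # NOTE: `the_failed_step`, `utils`, and RaySystemError come from the
    # original test module (RaySystemError from ray.exceptions).
    # `the_failed_step` is assumed to crash the worker until
    # utils.set_global_mark() is called, after which it returns f"foo({x})".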

    @ray.remote
    def simple(x):
        x = append1.bind(x)
        y = the_failed_step.bind(x)
        z = append2.bind(y)
        return workflow.continuation(z)

    utils.unset_global_mark()
    workflow_id = "test_recovery_simple"
    with pytest.raises(RaySystemError):
        # internally we get WorkerCrashedError
        workflow.create(simple.bind("x")).run(workflow_id=workflow_id)

    assert workflow.get_status(
        workflow_id) == workflow.WorkflowStatus.RESUMABLE

    utils.set_global_mark()
    output = workflow.resume(workflow_id)
    assert ray.get(output) == "foo(x[append1])[append2]"
    utils.unset_global_mark()
    # resume from workflow output checkpoint
    output = workflow.resume(workflow_id)
    assert ray.get(output) == "foo(x[append1])[append2]"
Example #14
def test_run_or_resume_during_running(workflow_start_regular_shared):
    @ray.remote
    def source1():
        return "[source1]"

    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"

    @ray.remote
    def simple_sequential():
        x = source1.bind()
        y = append1.bind(x)
        return workflow.continuation(append2.bind(y))

    output = workflow.create(
        simple_sequential.bind()).run_async(workflow_id="running_workflow")
    with pytest.raises(RuntimeError):
        workflow.create(
            simple_sequential.bind()).run_async(workflow_id="running_workflow")
    with pytest.raises(RuntimeError):
        workflow.resume(workflow_id="running_workflow")
    assert ray.get(output) == "[source1][append1][append2]"
Example #15
def test_dag_to_workflow_execution(workflow_start_regular_shared):
    """This test constructs a DAG with complex dependencies
    and turns it into a workflow."""
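
    # NOTE: InputNode is assumed to be imported from ray.dag
    # (ray.experimental.dag in older Ray releases).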

    @ray.remote
    def begin(x, pos, a):
        return x * a + pos  # 23.14

    @ray.remote
    def left(x, c, a):
        return f"left({x}, {c}, {a})"

    @ray.remote
    def right(x, b, pos):
        return f"right({x}, {b}, {pos})"

    @ray.remote
    def end(lf, rt, b):
        return f"{lf},{rt};{b}"

    with pytest.raises(TypeError):
        workflow.create(begin.remote(1, 2, 3))

    with InputNode() as dag_input:
        f = begin.bind(2, dag_input[1], a=dag_input.a)
        lf = left.bind(f, "hello", dag_input.a)
        rt = right.bind(f, b=dag_input.b, pos=dag_input[0])
        b = end.bind(lf, rt, b=dag_input.b)

    wf = workflow.create(b, 2, 3.14, a=10, b="ok")
    assert len(list(wf._iter_workflows_in_dag())) == 4, "incorrect amount of steps"
    assert wf.run() == "left(23.14, hello, 10),right(23.14, ok, 2);ok"
Example #16
def test_resume_different_storage(shutdown_only, tmp_path):
    @ray.remote
    def constant():
        return 31416

    ray.init(storage=str(tmp_path))
    workflow.init()
    workflow.create(constant.bind()).run(workflow_id="const")
    assert ray.get(workflow.resume(workflow_id="const")) == 31416
Example #17
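# NOTE: An older variant of Example #16, from a Ray release where
# workflow.init() accepted a storage URL and
# workflow.storage.set_global_storage() existed.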
def test_resume_different_storage(ray_start_regular, tmp_path, reset_workflow):
    @ray.remote
    def constant():
        return 31416

    workflow.init(storage=str(tmp_path))
    workflow.create(constant.bind()).run(workflow_id="const")
    assert ray.get(workflow.resume(workflow_id="const")) == 31416
    workflow.storage.set_global_storage(None)
Example #18
def test_sleep_checkpointing(workflow_start_regular_shared):
    """Test that the workflow sleep only starts after `run` not when the step is
    defined."""
    sleep_step = workflow.sleep(2)
    time.sleep(2)
    start_time = time.time()
    workflow.create(sleep_step).run()
    end_time = time.time()
    duration = end_time - start_time
    assert 1 < duration
Example #19
def test_user_metadata_not_dict(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    with pytest.raises(ValueError):
        workflow.create(simple.options(**workflow.options(metadata="x")).bind())

    with pytest.raises(ValueError):
        workflow.create(simple.bind()).run(metadata="x")
Example #20
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1
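
    # NOTE: WorkflowRef is assumed to be imported from ray.workflow.common
    # in this Ray version.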

    # This test also shows different "style" of running workflows.
    first_step = workflow.create(incr.bind(0))
    assert first_step.run("test_dynamic_workflow_ref") == 1
    second_step = workflow.create(incr.bind(WorkflowRef(first_step.step_id)))
    # Without rerun, it'll just return the previous result
    assert second_step.run("test_dynamic_workflow_ref") == 1
Example #21
def test_user_metadata_not_json_serializable(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    class X:
        pass

    with pytest.raises(ValueError):
        workflow.create(simple.options(**workflow.options(metadata={"x": X()})).bind())

    with pytest.raises(ValueError):
        workflow.create(simple.bind()).run(metadata={"x": X()})
def test_tail_recursion_optimization(workflow_start_regular_shared):
    @ray.remote
    def tail_recursion(n):
        import inspect

        # check if the stack is growing
        assert len(inspect.stack(0)) < 20
        if n <= 0:
            return "ok"
        return workflow.continuation(
            tail_recursion.options(**workflow.options(
                allow_inplace=True)).bind(n - 1))

    workflow.create(tail_recursion.bind(30)).run()
Example #23
def test_wf_run(workflow_start_regular_shared, tmp_path):
    counter = tmp_path / "counter"
    counter.write_text("0")

    @ray.remote
    def f():
        v = int(counter.read_text()) + 1
        counter.write_text(str(v))

    workflow.create(f.bind()).run("abc")
    assert counter.read_text() == "1"
    # This will not rerun the job from the beginning.
    workflow.create(f.bind()).run("abc")
    assert counter.read_text() == "1"
Example #24
def test_user_metadata_empty(workflow_start_regular):

    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        return 0

    workflow.create(simple.bind()).run(workflow_id)

    assert workflow.get_metadata("simple")["user_metadata"] == {}
    assert workflow.get_metadata("simple", "simple_step")["user_metadata"] == {}
Example #25
def test_wait_for_multiple_events(workflow_start_regular_shared):
    """If a workflow has multiple event arguments, it should wait for them at the
    same time.
    """
    class EventListener1(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener1")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event1"

    class EventListener2(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener2")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event2"

    @ray.remote
    def trivial_step(arg1, arg2):
        return f"{arg1} {arg2}"

    event1_promise = workflow.wait_for_event(EventListener1)
    event2_promise = workflow.wait_for_event(EventListener2)

    promise = workflow.create(trivial_step.bind(event1_promise,
                                                event2_promise)).run_async()

    while not (utils.check_global_mark("listener1")
               and utils.check_global_mark("listener2")):
        time.sleep(0.1)

    utils.set_global_mark("trigger_event")
    assert ray.get(promise) == "event1 event2"
Example #26
def test_step_resources(workflow_start_regular, tmp_path):
    lock_path = str(tmp_path / "lock")
    # We use a signal actor here because we can't guarantee the order of tasks
    # sent from the worker to the raylet.
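    # NOTE: SignalActor is assumed to come from ray._private.test_utils.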
    signal_actor = SignalActor.remote()

    @ray.remote
    def step_run():
        ray.wait([signal_actor.send.remote()])
        with FileLock(lock_path):
            return None

    @ray.remote(num_cpus=1)
    def remote_run():
        return None

    lock = FileLock(lock_path)
    lock.acquire()
    ret = workflow.create(step_run.options(num_cpus=2).bind()).run_async()
    ray.wait([signal_actor.wait.remote()])
    obj = remote_run.remote()
    with pytest.raises(ray.exceptions.GetTimeoutError):
        ray.get(obj, timeout=2)
    lock.release()
    assert ray.get(ret) is None
    assert ray.get(obj) is None
Example #27
def test_nested_workflow_no_download(workflow_start_regular):
    """Test that we _only_ load from storage on recovery. For a nested workflow
    step, we should checkpoint the input/output, but continue to reuse the
    in-memory value.
    """
    @ray.remote
    def recursive(ref, count):
        if count == 0:
            return ref
        return workflow.continuation(recursive.bind(ref, count - 1))
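
    # NOTE: DebugStorage and utils._alter_storage are test-only helpers from
    # Ray's workflow test utilities; DebugStorage wraps the storage backend
    # and counts get/put operations per key.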

    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        result = workflow.create(recursive.bind([ref], 10)).run()

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1, "We should only get once when resuming."
        put_objects_count = 0
        for key in ops["put"]:
            if "objects" in key:
                print(key)
                put_objects_count += 1
    assert (put_objects_count == 1
            ), "We should detect that the object exists before uploading."
        assert ray.get(result) == ["hello"]
Example #28
def test_get_output_1(workflow_start_regular, tmp_path):
    @ray.remote
    def simple(v):
        return v

    assert 0 == workflow.create(simple.bind(0)).run("simple")
    assert 0 == ray.get(workflow.get_output("simple"))
Example #29
def test_event_during_arg_resolution(workflow_start_regular_shared):
    """If a workflow's arguments are being executed when the event occurs, the
    workflow should run immediately with no issues.
    """
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            utils.set_global_mark("event_returning")

    @ray.remote
    def triggers_event():
        utils.set_global_mark()
        while not utils.check_global_mark("event_returning"):
            time.sleep(0.1)

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.create(gather.bind(event_promise,
                                       triggers_event.bind())).run() == (
                                           None,
                                           None,
                                       )
Example #30
def test_event_after_arg_resolution(workflow_start_regular_shared):
    """Ensure that a workflow resolves all of its non-event arguments while it
    waits for the event to occur.
    """
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            # Give the other step time to finish.
            await asyncio.sleep(1)

    @ray.remote
    def triggers_event():
        utils.set_global_mark()

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)

    assert workflow.create(gather.bind(event_promise,
                                       triggers_event.bind())).run() == (
                                           None,
                                           None,
                                       )