Example #1
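All examples in this collection are pytest excerpts from Ray's workflow test suite. They rely on fixtures such as workflow_start_regular / workflow_start_regular_shared (assumed to start Ray and initialize the workflow backend) and on roughly the following imports; this is a sketch, and exact module paths may vary across Ray versions:

import asyncio
import tempfile
import time
from pathlib import Path
from typing import Dict, List

import numpy as np
import pytest
from filelock import FileLock

import ray
from ray import workflow
from ray.workflow.common import WorkflowRef  # assumed location
from ray.workflow.tests import utils  # assumed location of shared test helpers
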
def test_failed_and_resumed_workflow(workflow_start_regular, tmp_path):

    workflow_id = "simple"
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def simple():
        if error_flag.exists():
            raise ValueError()
        return 0

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(simple.bind(), workflow_id=workflow_id)

    workflow_metadata_failed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_failed["status"] == "FAILED"

    error_flag.unlink()
    assert workflow.resume(workflow_id) == 0

    workflow_metadata_resumed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_resumed["status"] == "SUCCESSFUL"

    # Make sure resume updated the run's timing stats.
    assert (workflow_metadata_resumed["stats"]["start_time"] >
            workflow_metadata_failed["stats"]["start_time"])
    assert (workflow_metadata_resumed["stats"]["end_time"] >
            workflow_metadata_failed["stats"]["end_time"])
Example #2
def test_nested_workflow(workflow_start_regular):
    @workflow.options(name="inner", metadata={"inner_k": "inner_v"})
    @ray.remote
    def inner():
        time.sleep(2)
        return 10

    @workflow.options(name="outer", metadata={"outer_k": "outer_v"})
    @ray.remote
    def outer():
        time.sleep(2)
        return workflow.continuation(inner.bind())

    workflow.run(outer.bind(),
                 workflow_id="nested",
                 metadata={"workflow_k": "workflow_v"})

    workflow_metadata = workflow.get_metadata("nested")
    outer_step_metadata = workflow.get_metadata("nested", "outer")
    inner_step_metadata = workflow.get_metadata("nested", "inner")

    assert workflow_metadata["user_metadata"] == {"workflow_k": "workflow_v"}
    assert outer_step_metadata["user_metadata"] == {"outer_k": "outer_v"}
    assert inner_step_metadata["user_metadata"] == {"inner_k": "inner_v"}

    assert (workflow_metadata["stats"]["end_time"] >=
            workflow_metadata["stats"]["start_time"] + 4)
    assert (outer_step_metadata["stats"]["end_time"] >=
            outer_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["end_time"] >=
            inner_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["start_time"] >=
            outer_step_metadata["stats"]["end_time"])
Example #3
def test_recovery_simple_3(workflow_start_regular):
    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"

    @ray.remote
    def simple(x):
        x = append1.bind(x)
        y = the_failed_step.bind(x)
        z = append2.bind(y)
        return workflow.continuation(z)

    utils.unset_global_mark()
    workflow_id = "test_recovery_simple_3"
    with pytest.raises(workflow.WorkflowExecutionError):
        # Internally we get a WorkerCrashedError.
        workflow.run(simple.bind("x"), workflow_id=workflow_id)

    assert workflow.get_status(workflow_id) == workflow.WorkflowStatus.FAILED

    utils.set_global_mark()
    assert workflow.resume(workflow_id) == "foo(x[append1])[append2]"
    utils.unset_global_mark()
    # Resume from the workflow output checkpoint.
    assert workflow.resume(workflow_id) == "foo(x[append1])[append2]"
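Several examples (#3, #14, #23, #24, #28) use the utils global-mark helpers and the_failed_step without defining them. set_global_mark / unset_global_mark / check_global_mark are assumed to toggle and test a marker (e.g., a file) visible across worker processes. Judging from the "foo(x)" results and the WorkerCrashedError comment, the_failed_step is presumably shaped like this hypothetical sketch:

@ray.remote
def the_failed_step(x):
    # Crash the worker until the global mark is set; Ray surfaces this as a
    # WorkerCrashedError wrapped in a WorkflowExecutionError.
    if not utils.check_global_mark():
        import os
        os.kill(os.getpid(), 9)
    return "foo(" + x + ")"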
Example #4
def test_successful_workflow(workflow_start_regular):

    user_step_metadata = {"k1": "v1"}
    user_run_metadata = {"k2": "v2"}
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name, metadata=user_step_metadata)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.run(simple.bind(),
                 workflow_id=workflow_id,
                 metadata=user_run_metadata)

    workflow_metadata = workflow.get_metadata("simple")
    assert workflow_metadata["status"] == "SUCCESSFUL"
    assert workflow_metadata["user_metadata"] == user_run_metadata
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (workflow_metadata["stats"]["end_time"] >=
            workflow_metadata["stats"]["start_time"] + 2)

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert step_metadata["user_metadata"] == user_step_metadata
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (step_metadata["stats"]["end_time"] >=
            step_metadata["stats"]["start_time"] + 2)
Example #5
def test_dedupe_serialization(workflow_start_regular_shared):
    @ray.remote(num_cpus=0)
    class Counter:
        def __init__(self):
            self.count = 0

        def incr(self):
            self.count += 1

        def get_count(self):
            return self.count

    counter = Counter.remote()

    class CustomClass:
        def __getstate__(self):
            # Count the number of times this class is serialized.
            ray.get(counter.incr.remote())
            return {}

    ref = ray.put(CustomClass())
    list_of_refs = [ref for _ in range(2)]

    # One for the ray.put
    assert ray.get(counter.get_count.remote()) == 1

    single = identity.bind((ref,))
    double = identity.bind(list_of_refs)

    workflow.run(gather.bind(single, double))

    # Two more: one for hashing the ref and one for uploading it.
    assert ray.get(counter.get_count.remote()) == 3
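test_dedupe_serialization and the later dedupe tests call identity and gather helpers that are not shown. Plausible definitions, inferred from how they are used:

@ray.remote
def identity(x):
    return x

@ray.remote
def gather(*args):
    return args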
Example #6
def test_nested_catch_exception_3(workflow_start_regular_shared, tmp_path):
    """Test the case where the exception is not raised by the output task of
    a nested DAG."""

    @ray.remote
    def f3():
        return 10

    @ray.remote
    def f3_exc():
        raise ValueError()

    @ray.remote
    def f2(x):
        return x

    @ray.remote
    def f1(exc):
        if exc:
            return workflow.continuation(f2.bind(f3_exc.bind()))
        else:
            return workflow.continuation(f2.bind(f3.bind()))

    ret, err = workflow.run(
        f1.options(**workflow.options(catch_exceptions=True)).bind(True)
    )
    assert ret is None
    assert isinstance(err, ValueError)

    assert (10, None) == workflow.run(
        f1.options(**workflow.options(catch_exceptions=True)).bind(False)
    )
Example #7
def test_dedupe_indirect(workflow_start_regular_shared, tmp_path):
    counter = Path(tmp_path) / "counter.txt"
    lock = Path(tmp_path) / "lock.txt"
    counter.write_text("0")

    @ray.remote
    def incr():
        with FileLock(str(lock)):
            c = int(counter.read_text())
            c += 1
            counter.write_text(f"{c}")

    @ray.remote
    def identity(a):
        return a

    @ray.remote
    def join(*a):
        # The arguments are ignored; they exist only to make `join` depend
        # on its inputs.
        return counter.read_text()

    # Here `a` is passed to two steps, and we need to ensure
    # it's executed only once per workflow run.
    a = incr.bind()
    i1 = identity.bind(a)
    i2 = identity.bind(a)
    assert "1" == workflow.run(join.bind(i1, i2))
    assert "2" == workflow.run(join.bind(i1, i2))
    # pass a multiple times
    assert "3" == workflow.run(join.bind(a, a, a, a))
    assert "4" == workflow.run(join.bind(a, a, a, a))
Example #8
def test_get_output_3(workflow_start_regular, tmp_path):
    cnt_file = tmp_path / "counter"
    cnt_file.write_text("0")
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def incr():
        v = int(cnt_file.read_text())
        cnt_file.write_text(str(v + 1))
        if error_flag.exists():
            raise ValueError()
        return 10

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(incr.options(max_retries=0).bind(), workflow_id="incr")

    assert cnt_file.read_text() == "1"

    from ray.exceptions import RaySystemError

    # TODO(suquark): We should prevent Ray from raising "RaySystemError" in
    #   workflows, because "RaySystemError" does not inherit from the underlying
    #   error, so users and developers cannot catch the expected error.
    #   I find this issue very annoying.
    with pytest.raises((RaySystemError, ValueError)):
        workflow.get_output("incr")

    assert cnt_file.read_text() == "1"
    error_flag.unlink()
    with pytest.raises((RaySystemError, ValueError)):
        workflow.get_output("incr")
    assert workflow.resume("incr") == 10
Example #9
def test_dynamic_output(workflow_start_regular_shared):
    @ray.remote
    def exponential_fail(k, n):
        if n > 0:
            if n < 3:
                raise Exception("Failed intentionally")
            return workflow.continuation(
                exponential_fail.options(**workflow.options(
                    name=f"step_{n}")).bind(k * 2, n - 1))
        return k

    # When a workflow fails, the dynamic output should point to the
    # latest successful step.
    try:
        workflow.run(
            exponential_fail.options(**workflow.options(name="step_0")).bind(
                3, 10),
            workflow_id="dynamic_output",
        )
    except Exception:
        pass
    from ray.workflow.workflow_storage import get_workflow_storage

    wf_storage = get_workflow_storage(workflow_id="dynamic_output")
    result = wf_storage.inspect_step("step_0")
    assert result.output_step_id == "step_3"
Example #10
def test_resume_different_storage(shutdown_only, tmp_path):
    @ray.remote
    def constant():
        return 31416

    ray.init(storage=str(tmp_path))
    workflow.init()
    workflow.run(constant.bind(), workflow_id="const")
    assert workflow.resume(workflow_id="const") == 31416
Example #11
def test_sleep_checkpointing(workflow_start_regular_shared):
    """Test that the workflow sleep only starts after `run` not when the step is
    defined."""
    sleep_step = workflow.sleep(2)
    time.sleep(2)
    start_time = time.time()
    workflow.run(sleep_step)
    end_time = time.time()
    duration = end_time - start_time
    assert 1 < duration
Example #12
def test_user_metadata_not_dict(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    with pytest.raises(ValueError):
        workflow.run_async(
            simple.options(**workflow.options(metadata="x")).bind())

    with pytest.raises(ValueError):
        workflow.run(simple.bind(), metadata="x")
Example #13
def test_user_metadata_not_json_serializable(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    class X:
        pass

    with pytest.raises(ValueError):
        workflow.run_async(
            simple.options(**workflow.options(metadata={"x": X()})).bind())

    with pytest.raises(ValueError):
        workflow.run(simple.bind(), metadata={"x": X()})
Example #14
def test_recovery_simple_1(workflow_start_regular):
    utils.unset_global_mark()
    workflow_id = "test_recovery_simple_1"
    with pytest.raises(workflow.WorkflowExecutionError):
        # Internally we get a WorkerCrashedError.
        workflow.run(the_failed_step.bind("x"), workflow_id=workflow_id)

    assert workflow.get_status(workflow_id) == workflow.WorkflowStatus.FAILED

    utils.set_global_mark()
    assert workflow.resume(workflow_id) == "foo(x)"
    utils.unset_global_mark()
    # Resume from the workflow output checkpoint.
    assert workflow.resume(workflow_id) == "foo(x)"
Example #15
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1

    # This test also shows a different "style" of running workflows.
    assert workflow.run(incr.bind(0), workflow_id="test_dynamic_workflow_ref") == 1
    # Without a rerun, it will just return the previous result.
    assert (
        workflow.run(
            incr.bind(WorkflowRef("incr")), workflow_id="test_dynamic_workflow_ref"
        )
        == 1
    )
Example #16
def test_wf_run(workflow_start_regular_shared, tmp_path):
    counter = tmp_path / "counter"
    counter.write_text("0")

    @ray.remote
    def f():
        v = int(counter.read_text()) + 1
        counter.write_text(str(v))

    workflow.run(f.bind(), workflow_id="abc")
    assert counter.read_text() == "1"
    # This will not rerun the job from the beginning.
    workflow.run(f.bind(), workflow_id="abc")
    assert counter.read_text() == "1"
Example #17
def test_wf_no_run(shutdown_only):
    # Workflows should be able to run without an explicit workflow.init().
    ray.shutdown()

    @ray.remote
    def f1():
        pass

    # Binding alone should not require an initialized Ray or workflow backend.
    f1.bind()

    @ray.remote
    def f2(*w):
        pass

    workflow.run(f2.bind(*[f1.bind() for _ in range(10)]))
Example #18
def test_user_metadata_empty(workflow_start_regular):

    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        return 0

    workflow.run(simple.bind(), workflow_id=workflow_id)

    assert workflow.get_metadata("simple")["user_metadata"] == {}
    assert workflow.get_metadata("simple",
                                 "simple_step")["user_metadata"] == {}
Example #19
def test_dedupe_download_raw_ref(workflow_start_regular):
    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        workflows = [identity.bind(ref) for _ in range(100)]

        workflow.run(gather.bind(*workflows))

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1
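DebugStorage and utils._alter_storage are test-suite helpers that are not shown here: DebugStorage appears to wrap a storage backend and log every get/put per key (exposed via _logged_storage.get_op_counter()), and utils._alter_storage presumably installs it as the workflow storage so the tests can count object downloads and uploads.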
Example #20
def test_partial(workflow_start_regular_shared):
    ys = [1, 2, 3]

    def add(x, y):
        return x + y

    from functools import partial

    f1 = workflow.step(partial(add, 10)).step(10)

    assert "__anonymous_func__" in f1._name
    assert f1.run() == 20

    fs = [partial(add, y=y) for y in ys]

    @ray.remote
    def chain_func(*args, **kw_argv):
        # Start the chain with the first function.
        wf_step = workflow.step(fs[0]).step(*args, **kw_argv)
        for i in range(1, len(fs)):
            # Convert each remaining function into a workflow step,
            # feeding the previous step's output into it.
            wf_step = workflow.step(fs[i]).step(wf_step)
        return wf_step

    assert workflow.run(chain_func.bind(1)) == 7
Example #21
def test_nested_workflow_no_download(workflow_start_regular):
    """Test that we _only_ load from storage on recovery. For a nested workflow
    step, we should checkpoint the input/output, but continue to reuse the
    in-memory value.
    """

    @ray.remote
    def recursive(ref, count):
        if count == 0:
            return ref
        return workflow.continuation(recursive.bind(ref, count - 1))

    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        result = workflow.run(recursive.bind([ref], 10))

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1, "We should only get once when resuming."
        put_objects_count = 0
        for key in ops["put"]:
            if "objects" in key:
                print(key)
                put_objects_count += 1
        assert (
            put_objects_count == 1
        ), "We should detect the object exists before uploading"
        assert ray.get(result) == ["hello"]
Example #22
def test_get_output_1(workflow_start_regular, tmp_path):
    @ray.remote
    def simple(v):
        return v

    assert 0 == workflow.run(simple.bind(0), workflow_id="simple")
    assert 0 == workflow.get_output("simple")
Example #23
def test_event_during_arg_resolution(workflow_start_regular_shared):
    """If a workflow's arguments are being executed when the event occurs, the
    workflow should run immediately with no issues.
    """

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            utils.set_global_mark("event_returning")

    @ray.remote
    def triggers_event():
        utils.set_global_mark()
        while not utils.check_global_mark("event_returning"):
            time.sleep(0.1)

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.run(gather.bind(event_promise, triggers_event.bind())) == (
        None,
        None,
    )
Example #24
def test_event_after_arg_resolution(workflow_start_regular_shared):
    """Ensure that a workflow resolves all of its non-event arguments while it
    waits for the event to occur.
    """

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            # Give the other step time to finish.
            await asyncio.sleep(1)

    @ray.remote
    def triggers_event():
        utils.set_global_mark()

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)

    assert workflow.run(gather.bind(event_promise, triggers_event.bind())) == (
        None,
        None,
    )
Example #25
def test_object_deref(workflow_start_regular_shared):
    @ray.remote
    def one_item_list():
        return [1]

    @ray.remote
    def receive_workflow(wf):
        pass

    @ray.remote
    def return_workflow():
        return one_item_list.bind()

    @ray.remote
    def return_data() -> ray.ObjectRef:
        return ray.put(np.ones(4096))

    @ray.remote
    def receive_data(data: "ray.ObjectRef[np.ndarray]"):
        return ray.get(data)

    # Test that we are forbidden from directly passing a workflow to Ray.
    x = one_item_list.bind()
    with pytest.raises(ValueError):
        ray.put(x)
    with pytest.raises(ValueError):
        ray.get(receive_workflow.remote(x))
    with pytest.raises(ValueError):
        ray.get(return_workflow.remote())

    # Test returning an object ref.
    obj = return_data.bind()
    arr: np.ndarray = workflow.run(receive_data.bind(obj))
    assert np.array_equal(arr, np.ones(4096))
Example #26
def test_objectref_inputs(workflow_start_regular_shared):
    @ray.remote
    def nested_workflow(n: int):
        if n <= 0:
            return "nested"
        else:
            return workflow.continuation(nested_workflow.bind(n - 1))

    @ray.remote
    def deref_check(u: int, x: str, y: List[str], z: List[Dict[str, str]]):
        try:
            return (u == 42 and x == "nested"
                    and isinstance(y[0], ray.ObjectRef)
                    and ray.get(y) == ["nested"]
                    and isinstance(z[0]["output"], ray.ObjectRef) and ray.get(
                        z[0]["output"]) == "nested"), f"{u}, {x}, {y}, {z}"
        except Exception as e:
            return False, str(e)

    output, s = workflow.run(
        deref_check.bind(
            ray.put(42),
            nested_workflow.bind(10),
            [nested_workflow.bind(9)],
            [{
                "output": nested_workflow.bind(7)
            }],
        ))
    assert output is True, s
Example #27
def test_dataset_2(workflow_start_regular_shared):
    ds_ref = gen_dataset_2.bind()
    transformed_ref = transform_dataset_1.bind(ds_ref)
    output_ref = sum_dataset.bind(transformed_ref)

    result = workflow.run(output_ref)
    assert result == 2 * sum(range(1000))
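The three tasks used by test_dataset_2 are defined elsewhere. Since the result equals 2 * sum(range(1000)), they are presumably along these lines (a sketch, not the exact originals):

@ray.remote
def gen_dataset_2():
    return ray.data.range(1000)

@ray.remote
def transform_dataset_1(in_ds):
    return in_ds.map(lambda x: x * 2)

@ray.remote
def sum_dataset(ds):
    return ds.sum()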
Example #28
def test_checkpoint_dag_recovery_partial(workflow_start_regular_shared):
    utils.unset_global_mark()

    start = time.time()
    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(checkpoint_dag.bind(False),
                     workflow_id="checkpoint_partial_recovery")
    run_duration_partial = time.time() - start

    utils.set_global_mark()

    start = time.time()
    recovered = workflow.resume("checkpoint_partial_recovery")
    recover_duration_partial = time.time() - start
    assert np.isclose(recovered, np.arange(SIZE).mean())
    print(f"[partial] run_duration = {run_duration_partial}, "
          f"recover_duration = {recover_duration_partial}")
Example #29
def test_same_object_many_workflows(workflow_start_regular_shared):
    """Ensure that when we dedupe uploads, we upload the object once per workflow,
    since different workflows shouldn't look in each other's object directories.
    """
    @ray.remote
    def f(a):
        return [a[0]]

    x = {0: ray.put(10)}

    result1 = workflow.run(f.bind(x))
    result2 = workflow.run(f.bind(x))
    print(result1)
    print(result2)

    assert ray.get(*result1) == 10
    assert ray.get(*result2) == 10
Example #30
def test_workflow_queuing_1(shutdown_only, tmp_path):
    ray.init(storage=str(tmp_path))
    workflow.init(max_running_workflows=2, max_pending_workflows=2)

    import queue
    import filelock

    lock_path = str(tmp_path / ".lock")

    @ray.remote
    def long_running(x):
        with filelock.FileLock(lock_path):
            return x

    wfs = [long_running.bind(i) for i in range(5)]

    with filelock.FileLock(lock_path):
        refs = [
            workflow.run_async(wfs[i], workflow_id=f"workflow_{i}")
            for i in range(4)
        ]

        assert sorted(x[0] for x in workflow.list_all({workflow.RUNNING})) == [
            "workflow_0",
            "workflow_1",
        ]
        assert sorted(x[0] for x in workflow.list_all({workflow.PENDING})) == [
            "workflow_2",
            "workflow_3",
        ]

        with pytest.raises(queue.Full, match="Workflow queue has been full"):
            workflow.run(wfs[4], workflow_id="workflow_4")

    assert ray.get(refs) == [0, 1, 2, 3]
    assert workflow.run(wfs[4], workflow_id="workflow_4") == 4
    assert sorted(x[0] for x in workflow.list_all({workflow.SUCCESSFUL})) == [
        "workflow_0",
        "workflow_1",
        "workflow_2",
        "workflow_3",
        "workflow_4",
    ]
    for i in range(5):
        assert workflow.get_output(f"workflow_{i}") == i