def test_get_output_3(workflow_start_regular, tmp_path):
    cnt_file = tmp_path / "counter"
    cnt_file.write_text("0")
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def incr():
        v = int(cnt_file.read_text())
        cnt_file.write_text(str(v + 1))
        if error_flag.exists():
            raise ValueError()
        return 10

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.create(incr.options(**workflow.options(max_retries=0)).bind()).run(
            "incr"
        )

    assert cnt_file.read_text() == "1"

    from ray.exceptions import RaySystemError

    # TODO(suquark): We should prevent Ray from raising "RaySystemError" in
    # workflow, because "RaySystemError" does not inherit the underlying
    # error, so users and developers cannot catch the expected error.
    # I find this issue very annoying.
    with pytest.raises((RaySystemError, ValueError)):
        ray.get(workflow.get_output("incr"))

    assert cnt_file.read_text() == "1"

    error_flag.unlink()
    with pytest.raises((RaySystemError, ValueError)):
        ray.get(workflow.get_output("incr"))
    assert ray.get(workflow.resume("incr")) == 10


def test_output_with_name(workflow_start_regular):
    @ray.remote
    def double(v):
        return 2 * v

    inner_task = double.options(**workflow.options(name="inner")).bind(1)
    outer_task = double.options(**workflow.options(name="outer")).bind(inner_task)
    result = workflow.create(outer_task).run_async("double")

    inner = workflow.get_output("double", name="inner")
    outer = workflow.get_output("double", name="outer")

    assert ray.get(inner) == 2
    assert ray.get(outer) == 4
    assert ray.get(result) == 4

    @workflow.options(name="double")
    @ray.remote
    def double_2(s):
        return s * 2

    inner_task = double_2.bind(1)
    outer_task = double_2.bind(inner_task)
    workflow_id = "double_2"
    result = workflow.create(outer_task).run_async(workflow_id)

    inner = workflow.get_output(workflow_id, name="double")
    outer = workflow.get_output(workflow_id, name="double_1")

    assert ray.get(inner) == 2
    assert ray.get(outer) == 4
    assert ray.get(result) == 4


def test_dedupe_indirect(workflow_start_regular_shared, tmp_path):
    counter = Path(tmp_path) / "counter.txt"
    lock = Path(tmp_path) / "lock.txt"
    counter.write_text("0")

    @ray.remote
    def incr():
        with FileLock(str(lock)):
            c = int(counter.read_text())
            c += 1
            counter.write_text(f"{c}")

    @ray.remote
    def identity(a):
        return a

    @ray.remote
    def join(*a):
        return counter.read_text()

    # Here `a` is passed to two steps and we need to ensure
    # it's only executed once.
    a = incr.bind()
    i1 = identity.bind(a)
    i2 = identity.bind(a)
    assert "1" == workflow.create(join.bind(i1, i2)).run()
    assert "2" == workflow.create(join.bind(i1, i2)).run()
    # Pass `a` multiple times.
    assert "3" == workflow.create(join.bind(a, a, a, a)).run()
    assert "4" == workflow.create(join.bind(a, a, a, a)).run()


def test_nested_catch_exception_3(workflow_start_regular_shared, tmp_path):
    """Test the case where the exception is not raised by the output task
    of a nested DAG."""

    @ray.remote
    def f3():
        return 10

    @ray.remote
    def f3_exc():
        raise ValueError()

    @ray.remote
    def f2(x):
        return x

    @ray.remote
    def f1(exc):
        if exc:
            return workflow.continuation(f2.bind(f3_exc.bind()))
        else:
            return workflow.continuation(f2.bind(f3.bind()))

    ret, err = workflow.create(
        f1.options(**workflow.options(catch_exceptions=True)).bind(True)
    ).run()
    assert ret is None
    assert isinstance(err, ValueError)

    assert (10, None) == workflow.create(
        f1.options(**workflow.options(catch_exceptions=True)).bind(False)
    ).run()


def test_runtime_metadata(workflow_start_regular):
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.create(simple.bind()).run(workflow_id)

    workflow_metadata = workflow.get_metadata("simple")
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (
        workflow_metadata["stats"]["end_time"]
        >= workflow_metadata["stats"]["start_time"] + 2
    )

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (
        step_metadata["stats"]["end_time"]
        >= step_metadata["stats"]["start_time"] + 2
    )


def test_nested_workflow(workflow_start_regular):
    @workflow.options(name="inner", metadata={"inner_k": "inner_v"})
    @ray.remote
    def inner():
        time.sleep(2)
        return 10

    @workflow.options(name="outer", metadata={"outer_k": "outer_v"})
    @ray.remote
    def outer():
        time.sleep(2)
        return workflow.continuation(inner.bind())

    workflow.create(outer.bind()).run("nested", metadata={"workflow_k": "workflow_v"})

    workflow_metadata = workflow.get_metadata("nested")
    outer_step_metadata = workflow.get_metadata("nested", "outer")
    inner_step_metadata = workflow.get_metadata("nested", "inner")

    assert workflow_metadata["user_metadata"] == {"workflow_k": "workflow_v"}
    assert outer_step_metadata["user_metadata"] == {"outer_k": "outer_v"}
    assert inner_step_metadata["user_metadata"] == {"inner_k": "inner_v"}

    assert (
        workflow_metadata["stats"]["end_time"]
        >= workflow_metadata["stats"]["start_time"] + 4
    )
    assert (
        outer_step_metadata["stats"]["end_time"]
        >= outer_step_metadata["stats"]["start_time"] + 2
    )
    assert (
        inner_step_metadata["stats"]["end_time"]
        >= inner_step_metadata["stats"]["start_time"] + 2
    )
    assert (
        inner_step_metadata["stats"]["start_time"]
        >= outer_step_metadata["stats"]["end_time"]
    )


def test_running_and_canceled_workflow(workflow_start_regular, tmp_path):
    workflow_id = "simple"
    flag = tmp_path / "flag"

    @ray.remote
    def simple():
        flag.touch()
        time.sleep(1000)
        return 0

    workflow.create(simple.bind()).run_async(workflow_id)

    # Wait until the step runs to make sure pre-run metadata is written.
    while not flag.exists():
        time.sleep(1)

    workflow_metadata = workflow.get_metadata(workflow_id)
    assert workflow_metadata["status"] == "RUNNING"
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" not in workflow_metadata["stats"]

    workflow.cancel(workflow_id)

    workflow_metadata = workflow.get_metadata(workflow_id)
    assert workflow_metadata["status"] == "CANCELED"
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" not in workflow_metadata["stats"]


def test_successful_workflow(workflow_start_regular):
    user_step_metadata = {"k1": "v1"}
    user_run_metadata = {"k2": "v2"}
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name, metadata=user_step_metadata)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.create(simple.bind()).run(workflow_id, metadata=user_run_metadata)

    workflow_metadata = workflow.get_metadata("simple")
    assert workflow_metadata["status"] == "SUCCESSFUL"
    assert workflow_metadata["user_metadata"] == user_run_metadata
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (
        workflow_metadata["stats"]["end_time"]
        >= workflow_metadata["stats"]["start_time"] + 2
    )

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert step_metadata["user_metadata"] == user_step_metadata
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (
        step_metadata["stats"]["end_time"]
        >= step_metadata["stats"]["start_time"] + 2
    )


def test_object_deref(workflow_start_regular_shared):
    @ray.remote
    def empty_list():
        return [1]

    @ray.remote
    def receive_workflow(workflow):
        pass

    @ray.remote
    def return_workflow():
        return workflow.create(empty_list.bind())

    @ray.remote
    def return_data() -> ray.ObjectRef:
        return ray.put(np.ones(4096))

    @ray.remote
    def receive_data(data: "ray.ObjectRef[np.ndarray]"):
        return ray.get(data)

    # Test that we are forbidden from directly passing a workflow to Ray.
    x = workflow.create(empty_list.bind())
    with pytest.raises(ValueError):
        ray.put(x)
    with pytest.raises(ValueError):
        ray.get(receive_workflow.remote(x))
    with pytest.raises(ValueError):
        ray.get(return_workflow.remote())

    # Test returning an object ref.
    obj = return_data.bind()
    arr: np.ndarray = workflow.create(receive_data.bind(obj)).run()
    assert np.array_equal(arr, np.ones(4096))


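# `identity` and `gather` are module-level helpers used by
# test_dedupe_serialization below; their original definitions are not part of
# this excerpt. A minimal sketch, assuming they simply pass values through:
@ray.remote
def identity(x):
    return x


@ray.remote
def gather(*args):
    return args

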
def test_dedupe_serialization(workflow_start_regular_shared):
    @ray.remote(num_cpus=0)
    class Counter:
        def __init__(self):
            self.count = 0

        def incr(self):
            self.count += 1

        def get_count(self):
            return self.count

    counter = Counter.remote()

    class CustomClass:
        def __getstate__(self):
            # Count the number of times this class is serialized.
            ray.get(counter.incr.remote())
            return {}

    ref = ray.put(CustomClass())
    list_of_refs = [ref for _ in range(2)]

    # One serialization for the ray.put.
    assert ray.get(counter.get_count.remote()) == 1

    single = identity.bind((ref,))
    double = identity.bind(list_of_refs)

    workflow.create(gather.bind(single, double)).run()

    # One more for hashing the ref, and one for uploading it.
    assert ray.get(counter.get_count.remote()) == 3


def test_failed_and_resumed_workflow(workflow_start_regular, tmp_path):
    workflow_id = "simple"
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def simple():
        if error_flag.exists():
            raise ValueError()
        return 0

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.create(simple.bind()).run(workflow_id)

    workflow_metadata_failed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_failed["status"] == "FAILED"

    error_flag.unlink()
    ref = workflow.resume(workflow_id)
    assert ray.get(ref) == 0

    workflow_metadata_resumed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_resumed["status"] == "SUCCESSFUL"

    # Make sure resume updated the running metrics.
    assert (
        workflow_metadata_resumed["stats"]["start_time"]
        > workflow_metadata_failed["stats"]["start_time"]
    )
    assert (
        workflow_metadata_resumed["stats"]["end_time"]
        > workflow_metadata_failed["stats"]["end_time"]
    )


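# test_recovery_simple below depends on two pieces defined outside this
# excerpt: the `utils` test module (whose set_global_mark/check_global_mark/
# unset_global_mark functions toggle a persistent flag, e.g. a file on disk)
# and a `the_failed_step` remote function. A minimal sketch of the latter,
# assuming it crashes the worker until the global mark is set (the exact
# definition in the test module may differ):
@ray.remote
def the_failed_step(x):
    if not utils.check_global_mark():
        import os

        # Simulate a worker crash, which Ray reports as a system-level error
        # (WorkerCrashedError wrapped in RaySystemError).
        os.kill(os.getpid(), 9)
    return "foo(" + x + ")"

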
def test_recovery_simple(workflow_start_regular):
    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"

    @ray.remote
    def simple(x):
        x = append1.bind(x)
        y = the_failed_step.bind(x)
        z = append2.bind(y)
        return workflow.continuation(z)

    utils.unset_global_mark()

    workflow_id = "test_recovery_simple"
    with pytest.raises(RaySystemError):
        # Internally we get WorkerCrashedError.
        workflow.create(simple.bind("x")).run(workflow_id=workflow_id)

    assert workflow.get_status(workflow_id) == workflow.WorkflowStatus.RESUMABLE

    utils.set_global_mark()
    output = workflow.resume(workflow_id)
    assert ray.get(output) == "foo(x[append1])[append2]"

    utils.unset_global_mark()
    # Resume from the workflow output checkpoint.
    output = workflow.resume(workflow_id)
    assert ray.get(output) == "foo(x[append1])[append2]"


def test_run_or_resume_during_running(workflow_start_regular_shared):
    @ray.remote
    def source1():
        return "[source1]"

    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"

    @ray.remote
    def simple_sequential():
        x = source1.bind()
        y = append1.bind(x)
        return workflow.continuation(append2.bind(y))

    output = workflow.create(simple_sequential.bind()).run_async(
        workflow_id="running_workflow"
    )
    with pytest.raises(RuntimeError):
        workflow.create(simple_sequential.bind()).run_async(
            workflow_id="running_workflow"
        )
    with pytest.raises(RuntimeError):
        workflow.resume(workflow_id="running_workflow")
    assert ray.get(output) == "[source1][append1][append2]"


def test_dag_to_workflow_execution(workflow_start_regular_shared):
    """This test constructs a DAG with complex dependencies
    and turns it into a workflow."""

    @ray.remote
    def begin(x, pos, a):
        return x * a + pos  # 23.14

    @ray.remote
    def left(x, c, a):
        return f"left({x}, {c}, {a})"

    @ray.remote
    def right(x, b, pos):
        return f"right({x}, {b}, {pos})"

    @ray.remote
    def end(lf, rt, b):
        return f"{lf},{rt};{b}"

    with pytest.raises(TypeError):
        workflow.create(begin.remote(1, 2, 3))

    with InputNode() as dag_input:
        f = begin.bind(2, dag_input[1], a=dag_input.a)
        lf = left.bind(f, "hello", dag_input.a)
        rt = right.bind(f, b=dag_input.b, pos=dag_input[0])
        b = end.bind(lf, rt, b=dag_input.b)

    wf = workflow.create(b, 2, 3.14, a=10, b="ok")
    assert len(list(wf._iter_workflows_in_dag())) == 4, "incorrect number of steps"
    assert wf.run() == "left(23.14, hello, 10),right(23.14, ok, 2);ok"


def test_resume_different_storage(shutdown_only, tmp_path):
    @ray.remote
    def constant():
        return 31416

    ray.init(storage=str(tmp_path))
    workflow.init()
    workflow.create(constant.bind()).run(workflow_id="const")
    assert ray.get(workflow.resume(workflow_id="const")) == 31416


def test_sleep_checkpointing(workflow_start_regular_shared):
    """Test that the workflow sleep only starts after `run`,
    not when the step is defined."""
    sleep_step = workflow.sleep(2)
    time.sleep(2)
    start_time = time.time()
    workflow.create(sleep_step).run()
    end_time = time.time()
    duration = end_time - start_time
    assert 1 < duration


def test_user_metadata_not_dict(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    with pytest.raises(ValueError):
        workflow.create(simple.options(**workflow.options(metadata="x")).bind())

    with pytest.raises(ValueError):
        workflow.create(simple.bind()).run(metadata="x")


def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1

    # This test also shows a different "style" of running workflows.
    first_step = workflow.create(incr.bind(0))
    assert first_step.run("test_dynamic_workflow_ref") == 1

    second_step = workflow.create(incr.bind(WorkflowRef(first_step.step_id)))
    # Without rerun, it'll just return the previous result.
    assert second_step.run("test_dynamic_workflow_ref") == 1


def test_user_metadata_not_json_serializable(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    class X:
        pass

    with pytest.raises(ValueError):
        workflow.create(
            simple.options(**workflow.options(metadata={"x": X()})).bind()
        )

    with pytest.raises(ValueError):
        workflow.create(simple.bind()).run(metadata={"x": X()})


def test_tail_recursion_optimization(workflow_start_regular_shared):
    @ray.remote
    def tail_recursion(n):
        import inspect

        # Check that the stack is not growing.
        assert len(inspect.stack(0)) < 20
        if n <= 0:
            return "ok"
        return workflow.continuation(
            tail_recursion.options(**workflow.options(allow_inplace=True)).bind(n - 1)
        )

    workflow.create(tail_recursion.bind(30)).run()


def test_wf_run(workflow_start_regular_shared, tmp_path):
    counter = tmp_path / "counter"
    counter.write_text("0")

    @ray.remote
    def f():
        v = int(counter.read_text()) + 1
        counter.write_text(str(v))

    workflow.create(f.bind()).run("abc")
    assert counter.read_text() == "1"
    # This will not rerun the job from the beginning.
    workflow.create(f.bind()).run("abc")
    assert counter.read_text() == "1"


def test_user_metadata_empty(workflow_start_regular):
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        return 0

    workflow.create(simple.bind()).run(workflow_id)

    assert workflow.get_metadata("simple")["user_metadata"] == {}
    assert workflow.get_metadata("simple", "simple_step")["user_metadata"] == {}


def test_wait_for_multiple_events(workflow_start_regular_shared):
    """If a workflow has multiple event arguments, it should wait for them
    at the same time."""

    class EventListener1(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener1")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event1"

    class EventListener2(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener2")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event2"

    @ray.remote
    def trivial_step(arg1, arg2):
        return f"{arg1} {arg2}"

    event1_promise = workflow.wait_for_event(EventListener1)
    event2_promise = workflow.wait_for_event(EventListener2)

    promise = workflow.create(
        trivial_step.bind(event1_promise, event2_promise)
    ).run_async()

    while not (
        utils.check_global_mark("listener1") and utils.check_global_mark("listener2")
    ):
        time.sleep(0.1)

    utils.set_global_mark("trigger_event")
    assert ray.get(promise) == "event1 event2"


def test_step_resources(workflow_start_regular, tmp_path):
    lock_path = str(tmp_path / "lock")

    # We use a signal actor here because we can't guarantee the order of
    # tasks sent from the worker to the raylet.
    signal_actor = SignalActor.remote()

    @ray.remote
    def step_run():
        ray.wait([signal_actor.send.remote()])
        with FileLock(lock_path):
            return None

    @ray.remote(num_cpus=1)
    def remote_run():
        return None

    lock = FileLock(lock_path)
    lock.acquire()
    ret = workflow.create(step_run.options(num_cpus=2).bind()).run_async()
    ray.wait([signal_actor.wait.remote()])
    obj = remote_run.remote()
    with pytest.raises(ray.exceptions.GetTimeoutError):
        ray.get(obj, timeout=2)
    lock.release()
    assert ray.get(ret) is None
    assert ray.get(obj) is None


def test_nested_workflow_no_download(workflow_start_regular):
    """Test that we _only_ load from storage on recovery.

    For a nested workflow step, we should checkpoint the input/output,
    but continue to reuse the in-memory value.
    """

    @ray.remote
    def recursive(ref, count):
        if count == 0:
            return ref
        return workflow.continuation(recursive.bind(ref, count - 1))

    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        result = workflow.create(recursive.bind([ref], 10)).run()

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1, "We should only get once when resuming."

        put_objects_count = 0
        for key in ops["put"]:
            if "objects" in key:
                print(key)
                put_objects_count += 1
        assert (
            put_objects_count == 1
        ), "We should detect that the object exists before uploading."

        assert ray.get(result) == ["hello"]


def test_get_output_1(workflow_start_regular, tmp_path):
    @ray.remote
    def simple(v):
        return v

    assert 0 == workflow.create(simple.bind(0)).run("simple")
    assert 0 == ray.get(workflow.get_output("simple"))


def test_event_during_arg_resolution(workflow_start_regular_shared):
    """If a workflow's arguments are being executed when the event occurs,
    the workflow should run immediately with no issues."""

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            utils.set_global_mark("event_returning")

    @ray.remote
    def triggers_event():
        utils.set_global_mark()
        while not utils.check_global_mark("event_returning"):
            time.sleep(0.1)

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.create(gather.bind(event_promise, triggers_event.bind())).run() == (
        None,
        None,
    )


def test_event_after_arg_resolution(workflow_start_regular_shared):
    """Ensure that a workflow resolves all of its non-event arguments
    while it is waiting for the event to occur."""

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            # Give the other step time to finish.
            await asyncio.sleep(1)

    @ray.remote
    def triggers_event():
        utils.set_global_mark()

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.create(gather.bind(event_promise, triggers_event.bind())).run() == (
        None,
        None,
    )