def test_failed_and_resumed_workflow(workflow_start_regular, tmp_path):
    workflow_id = "simple"
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def simple():
        if error_flag.exists():
            raise ValueError()
        return 0

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(simple.bind(), workflow_id=workflow_id)

    workflow_metadata_failed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_failed["status"] == "FAILED"

    error_flag.unlink()
    assert workflow.resume(workflow_id) == 0

    workflow_metadata_resumed = workflow.get_metadata(workflow_id)
    assert workflow_metadata_resumed["status"] == "SUCCESSFUL"

    # Make sure resume updated the running metrics.
    assert (workflow_metadata_resumed["stats"]["start_time"]
            > workflow_metadata_failed["stats"]["start_time"])
    assert (workflow_metadata_resumed["stats"]["end_time"]
            > workflow_metadata_failed["stats"]["end_time"])

def test_nested_workflow(workflow_start_regular):
    @workflow.options(name="inner", metadata={"inner_k": "inner_v"})
    @ray.remote
    def inner():
        time.sleep(2)
        return 10

    @workflow.options(name="outer", metadata={"outer_k": "outer_v"})
    @ray.remote
    def outer():
        time.sleep(2)
        return workflow.continuation(inner.bind())

    workflow.run(
        outer.bind(), workflow_id="nested", metadata={"workflow_k": "workflow_v"}
    )

    workflow_metadata = workflow.get_metadata("nested")
    outer_step_metadata = workflow.get_metadata("nested", "outer")
    inner_step_metadata = workflow.get_metadata("nested", "inner")

    assert workflow_metadata["user_metadata"] == {"workflow_k": "workflow_v"}
    assert outer_step_metadata["user_metadata"] == {"outer_k": "outer_v"}
    assert inner_step_metadata["user_metadata"] == {"inner_k": "inner_v"}

    assert (workflow_metadata["stats"]["end_time"]
            >= workflow_metadata["stats"]["start_time"] + 4)
    assert (outer_step_metadata["stats"]["end_time"]
            >= outer_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["end_time"]
            >= inner_step_metadata["stats"]["start_time"] + 2)
    assert (inner_step_metadata["stats"]["start_time"]
            >= outer_step_metadata["stats"]["end_time"])

def test_recovery_simple_3(workflow_start_regular):
    @ray.remote
    def append1(x):
        return x + "[append1]"

    @ray.remote
    def append2(x):
        return x + "[append2]"

    @ray.remote
    def simple(x):
        x = append1.bind(x)
        y = the_failed_step.bind(x)
        z = append2.bind(y)
        return workflow.continuation(z)

    utils.unset_global_mark()
    workflow_id = "test_recovery_simple_3"
    with pytest.raises(workflow.WorkflowExecutionError):
        # Internally we get WorkerCrashedError.
        workflow.run(simple.bind("x"), workflow_id=workflow_id)

    assert workflow.get_status(workflow_id) == workflow.WorkflowStatus.FAILED

    utils.set_global_mark()
    assert workflow.resume(workflow_id) == "foo(x[append1])[append2]"

    utils.unset_global_mark()
    # Resume from the workflow output checkpoint.
    assert workflow.resume(workflow_id) == "foo(x[append1])[append2]"

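# The recovery tests in this section reference a module-level `the_failed_step`
# task defined elsewhere in the file. A minimal sketch consistent with the
# assertions: the worker crashes until the global mark is set, then the task
# returns "foo(<input>)". The exact definition, including `max_retries=0`, is
# an assumption.
@ray.remote(max_retries=0)
def the_failed_step(x):
    if not utils.check_global_mark():
        import os

        # Kill the worker process to simulate a crash; the driver observes
        # this as a WorkerCrashedError wrapped in WorkflowExecutionError.
        os.kill(os.getpid(), 9)
    return "foo(" + x + ")"
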
def test_successful_workflow(workflow_start_regular):
    user_step_metadata = {"k1": "v1"}
    user_run_metadata = {"k2": "v2"}
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name, metadata=user_step_metadata)
    @ray.remote
    def simple():
        time.sleep(2)
        return 0

    workflow.run(simple.bind(), workflow_id=workflow_id, metadata=user_run_metadata)

    workflow_metadata = workflow.get_metadata("simple")
    assert workflow_metadata["status"] == "SUCCESSFUL"
    assert workflow_metadata["user_metadata"] == user_run_metadata
    assert "start_time" in workflow_metadata["stats"]
    assert "end_time" in workflow_metadata["stats"]
    assert (workflow_metadata["stats"]["end_time"]
            >= workflow_metadata["stats"]["start_time"] + 2)

    step_metadata = workflow.get_metadata("simple", "simple_step")
    assert step_metadata["user_metadata"] == user_step_metadata
    assert "start_time" in step_metadata["stats"]
    assert "end_time" in step_metadata["stats"]
    assert (step_metadata["stats"]["end_time"]
            >= step_metadata["stats"]["start_time"] + 2)

def test_dedupe_serialization(workflow_start_regular_shared):
    @ray.remote(num_cpus=0)
    class Counter:
        def __init__(self):
            self.count = 0

        def incr(self):
            self.count += 1

        def get_count(self):
            return self.count

    counter = Counter.remote()

    class CustomClass:
        def __getstate__(self):
            # Count the number of times this class is serialized.
            ray.get(counter.incr.remote())
            return {}

    ref = ray.put(CustomClass())
    list_of_refs = [ref for _ in range(2)]

    # One serialization for the ray.put.
    assert ray.get(counter.get_count.remote()) == 1

    single = identity.bind((ref,))
    double = identity.bind(list_of_refs)

    workflow.run(gather.bind(single, double))

    # One more for hashing the ref, and one for uploading.
    assert ray.get(counter.get_count.remote()) == 3

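# `identity` and `gather` are module-level helper tasks used by several tests
# in this section but defined elsewhere in the file. A minimal sketch of what
# they are assumed to look like:
@ray.remote
def identity(x):
    return x


@ray.remote
def gather(*args):
    return args
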
def test_nested_catch_exception_3(workflow_start_regular_shared, tmp_path):
    """Test the case where the exception is not raised by the output task
    of a nested DAG."""

    @ray.remote
    def f3():
        return 10

    @ray.remote
    def f3_exc():
        raise ValueError()

    @ray.remote
    def f2(x):
        return x

    @ray.remote
    def f1(exc):
        if exc:
            return workflow.continuation(f2.bind(f3_exc.bind()))
        else:
            return workflow.continuation(f2.bind(f3.bind()))

    ret, err = workflow.run(
        f1.options(**workflow.options(catch_exceptions=True)).bind(True)
    )
    assert ret is None
    assert isinstance(err, ValueError)

    assert (10, None) == workflow.run(
        f1.options(**workflow.options(catch_exceptions=True)).bind(False)
    )

def test_dedupe_indirect(workflow_start_regular_shared, tmp_path):
    counter = Path(tmp_path) / "counter.txt"
    lock = Path(tmp_path) / "lock.txt"
    counter.write_text("0")

    @ray.remote
    def incr():
        with FileLock(str(lock)):
            c = int(counter.read_text())
            c += 1
            counter.write_text(f"{c}")

    @ray.remote
    def identity(a):
        return a

    @ray.remote
    def join(*a):
        return counter.read_text()

    # Here `a` is passed to two steps, and we need to ensure
    # it is only executed once.
    a = incr.bind()
    i1 = identity.bind(a)
    i2 = identity.bind(a)
    assert "1" == workflow.run(join.bind(i1, i2))
    assert "2" == workflow.run(join.bind(i1, i2))

    # Pass `a` multiple times.
    assert "3" == workflow.run(join.bind(a, a, a, a))
    assert "4" == workflow.run(join.bind(a, a, a, a))

def test_get_output_3(workflow_start_regular, tmp_path):
    cnt_file = tmp_path / "counter"
    cnt_file.write_text("0")
    error_flag = tmp_path / "error"
    error_flag.touch()

    @ray.remote
    def incr():
        v = int(cnt_file.read_text())
        cnt_file.write_text(str(v + 1))
        if error_flag.exists():
            raise ValueError()
        return 10

    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(incr.options(max_retries=0).bind(), workflow_id="incr")

    assert cnt_file.read_text() == "1"

    from ray.exceptions import RaySystemError

    # TODO(suquark): We should prevent Ray from raising "RaySystemError" in
    # workflows, because "RaySystemError" does not inherit from the underlying
    # error, so users and developers cannot catch the expected error.
    # I find this issue very annoying.
    with pytest.raises((RaySystemError, ValueError)):
        workflow.get_output("incr")

    assert cnt_file.read_text() == "1"
    error_flag.unlink()
    with pytest.raises((RaySystemError, ValueError)):
        workflow.get_output("incr")
    assert workflow.resume("incr") == 10

def test_dynamic_output(workflow_start_regular_shared):
    @ray.remote
    def exponential_fail(k, n):
        if n > 0:
            if n < 3:
                raise Exception("Failed intentionally")
            return workflow.continuation(
                exponential_fail.options(**workflow.options(name=f"step_{n}")).bind(
                    k * 2, n - 1
                )
            )
        return k

    # When the workflow fails, the dynamic output should point to the
    # latest successful step.
    try:
        workflow.run(
            exponential_fail.options(**workflow.options(name="step_0")).bind(3, 10),
            workflow_id="dynamic_output",
        )
    except Exception:
        pass

    from ray.workflow.workflow_storage import get_workflow_storage

    wf_storage = get_workflow_storage(workflow_id="dynamic_output")
    result = wf_storage.inspect_step("step_0")
    assert result.output_step_id == "step_3"

def test_resume_different_storage(shutdown_only, tmp_path):
    @ray.remote
    def constant():
        return 31416

    ray.init(storage=str(tmp_path))
    workflow.init()
    workflow.run(constant.bind(), workflow_id="const")
    assert workflow.resume(workflow_id="const") == 31416

def test_sleep_checkpointing(workflow_start_regular_shared):
    """Test that the workflow sleep only starts after `run`,
    not when the step is defined."""
    sleep_step = workflow.sleep(2)
    time.sleep(2)
    start_time = time.time()
    workflow.run(sleep_step)
    end_time = time.time()
    duration = end_time - start_time
    assert 1 < duration

def test_user_metadata_not_dict(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    with pytest.raises(ValueError):
        workflow.run_async(simple.options(**workflow.options(metadata="x")).bind())

    with pytest.raises(ValueError):
        workflow.run(simple.bind(), metadata="x")

def test_user_metadata_not_json_serializable(workflow_start_regular):
    @ray.remote
    def simple():
        return 0

    class X:
        pass

    with pytest.raises(ValueError):
        workflow.run_async(
            simple.options(**workflow.options(metadata={"x": X()})).bind()
        )

    with pytest.raises(ValueError):
        workflow.run(simple.bind(), metadata={"x": X()})

def test_recovery_simple_1(workflow_start_regular):
    utils.unset_global_mark()
    workflow_id = "test_recovery_simple_1"
    with pytest.raises(workflow.WorkflowExecutionError):
        # Internally we get WorkerCrashedError.
        workflow.run(the_failed_step.bind("x"), workflow_id=workflow_id)

    assert workflow.get_status(workflow_id) == workflow.WorkflowStatus.FAILED

    utils.set_global_mark()
    assert workflow.resume(workflow_id) == "foo(x)"

    utils.unset_global_mark()
    # Resume from the workflow output checkpoint.
    assert workflow.resume(workflow_id) == "foo(x)"

def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1

    # This test also shows a different "style" of running workflows.
    assert workflow.run(incr.bind(0), workflow_id="test_dynamic_workflow_ref") == 1
    # Without a rerun, this just returns the previous result.
    assert (
        workflow.run(
            incr.bind(WorkflowRef("incr")), workflow_id="test_dynamic_workflow_ref"
        )
        == 1
    )

def test_wf_run(workflow_start_regular_shared, tmp_path):
    counter = tmp_path / "counter"
    counter.write_text("0")

    @ray.remote
    def f():
        v = int(counter.read_text()) + 1
        counter.write_text(str(v))

    workflow.run(f.bind(), workflow_id="abc")
    assert counter.read_text() == "1"
    # This will not rerun the job from the beginning.
    workflow.run(f.bind(), workflow_id="abc")
    assert counter.read_text() == "1"

def test_wf_no_run(shutdown_only):
    # Workflows should be able to run without an explicit workflow.init().
    ray.shutdown()

    @ray.remote
    def f1():
        pass

    f1.bind()

    @ray.remote
    def f2(*w):
        pass

    workflow.run(f2.bind(*[f1.bind() for _ in range(10)]))

def test_user_metadata_empty(workflow_start_regular):
    step_name = "simple_step"
    workflow_id = "simple"

    @workflow.options(name=step_name)
    @ray.remote
    def simple():
        return 0

    workflow.run(simple.bind(), workflow_id=workflow_id)

    assert workflow.get_metadata("simple")["user_metadata"] == {}
    assert workflow.get_metadata("simple", "simple_step")["user_metadata"] == {}

def test_dedupe_download_raw_ref(workflow_start_regular):
    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        workflows = [identity.bind(ref) for _ in range(100)]
        workflow.run(gather.bind(*workflows))

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1

def test_partial(workflow_start_regular_shared):
    ys = [1, 2, 3]

    def add(x, y):
        return x + y

    from functools import partial

    f1 = workflow.step(partial(add, 10)).step(10)
    assert "__anonymous_func__" in f1._name
    assert f1.run() == 20

    fs = [partial(add, y=y) for y in ys]

    @ray.remote
    def chain_func(*args, **kw_argv):
        # Get the first function as a start.
        wf_step = workflow.step(fs[0]).step(*args, **kw_argv)
        for i in range(1, len(fs)):
            # Convert each remaining function into a workflow step and feed it
            # the previous step's output as input.
            wf_step = workflow.step(fs[i]).step(wf_step)
        return wf_step

    assert workflow.run(chain_func.bind(1)) == 7

def test_nested_workflow_no_download(workflow_start_regular):
    """Test that we _only_ load from storage on recovery.

    For a nested workflow step, we should checkpoint the input/output,
    but continue to reuse the in-memory value.
    """

    @ray.remote
    def recursive(ref, count):
        if count == 0:
            return ref
        return workflow.continuation(recursive.bind(ref, count - 1))

    with tempfile.TemporaryDirectory() as temp_dir:
        debug_store = DebugStorage(temp_dir)
        utils._alter_storage(debug_store)

        ref = ray.put("hello")
        result = workflow.run(recursive.bind([ref], 10))

        ops = debug_store._logged_storage.get_op_counter()
        get_objects_count = 0
        for key in ops["get"]:
            if "objects" in key:
                get_objects_count += 1
        assert get_objects_count == 1, "We should only download the object once."
        put_objects_count = 0
        for key in ops["put"]:
            if "objects" in key:
                print(key)
                put_objects_count += 1
        assert (
            put_objects_count == 1
        ), "We should detect that the object already exists before uploading."
        assert ray.get(result) == ["hello"]

def test_get_output_1(workflow_start_regular, tmp_path):
    @ray.remote
    def simple(v):
        return v

    assert 0 == workflow.run(simple.bind(0), workflow_id="simple")
    assert 0 == workflow.get_output("simple")

def test_event_during_arg_resolution(workflow_start_regular_shared):
    """If a workflow's arguments are being executed when the event occurs,
    the workflow should run immediately with no issues.
    """

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            utils.set_global_mark("event_returning")

    @ray.remote
    def triggers_event():
        utils.set_global_mark()
        while not utils.check_global_mark("event_returning"):
            time.sleep(0.1)

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.run(gather.bind(event_promise, triggers_event.bind())) == (
        None,
        None,
    )

def test_event_after_arg_resolution(workflow_start_regular_shared):
    """Ensure that a workflow resolves all of its non-event arguments
    while it is waiting for the event to occur.
    """

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            # Give the other step time to finish.
            await asyncio.sleep(1)

    @ray.remote
    def triggers_event():
        utils.set_global_mark()

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.run(gather.bind(event_promise, triggers_event.bind())) == (
        None,
        None,
    )

def test_object_deref(workflow_start_regular_shared):
    @ray.remote
    def empty_list():
        return [1]

    @ray.remote
    def receive_workflow(workflow):
        pass

    @ray.remote
    def return_workflow():
        return empty_list.bind()

    @ray.remote
    def return_data() -> ray.ObjectRef:
        return ray.put(np.ones(4096))

    @ray.remote
    def receive_data(data: "ray.ObjectRef[np.ndarray]"):
        return ray.get(data)

    # Test that we are forbidden from directly passing a workflow to Ray.
    x = empty_list.bind()
    with pytest.raises(ValueError):
        ray.put(x)
    with pytest.raises(ValueError):
        ray.get(receive_workflow.remote(x))
    with pytest.raises(ValueError):
        ray.get(return_workflow.remote())

    # Test returning an object ref.
    obj = return_data.bind()
    arr: np.ndarray = workflow.run(receive_data.bind(obj))
    assert np.array_equal(arr, np.ones(4096))

def test_objectref_inputs(workflow_start_regular_shared):
    @ray.remote
    def nested_workflow(n: int):
        if n <= 0:
            return "nested"
        else:
            return workflow.continuation(nested_workflow.bind(n - 1))

    @ray.remote
    def deref_check(u: int, x: str, y: List[str], z: List[Dict[str, str]]):
        try:
            return (
                u == 42
                and x == "nested"
                and isinstance(y[0], ray.ObjectRef)
                and ray.get(y) == ["nested"]
                and isinstance(z[0]["output"], ray.ObjectRef)
                and ray.get(z[0]["output"]) == "nested"
            ), f"{u}, {x}, {y}, {z}"
        except Exception as e:
            return False, str(e)

    output, s = workflow.run(
        deref_check.bind(
            ray.put(42),
            nested_workflow.bind(10),
            [nested_workflow.bind(9)],
            [{"output": nested_workflow.bind(7)}],
        )
    )
    assert output is True, s

def test_dataset_2(workflow_start_regular_shared):
    ds_ref = gen_dataset_2.bind()
    transformed_ref = transform_dataset_1.bind(ds_ref)
    output_ref = sum_dataset.bind(transformed_ref)

    result = workflow.run(output_ref)
    assert result == 2 * sum(range(1000))

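# `gen_dataset_2`, `transform_dataset_1`, and `sum_dataset` are module-level
# helpers defined elsewhere in the file. A plausible sketch consistent with
# the expected result of 2 * sum(range(1000)); the exact definitions are
# assumptions:
@ray.remote
def gen_dataset_2():
    return ray.data.range(1000)


@ray.remote
def transform_dataset_1(in_ds):
    # Doubling every row makes the sum equal to 2 * sum(range(1000)).
    return in_ds.map(lambda x: x * 2)


@ray.remote
def sum_dataset(ds):
    return ds.sum()
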
def test_checkpoint_dag_recovery_partial(workflow_start_regular_shared):
    utils.unset_global_mark()

    start = time.time()
    with pytest.raises(workflow.WorkflowExecutionError):
        workflow.run(
            checkpoint_dag.bind(False), workflow_id="checkpoint_partial_recovery"
        )
    run_duration_partial = time.time() - start

    utils.set_global_mark()

    start = time.time()
    recovered = workflow.resume("checkpoint_partial_recovery")
    recover_duration_partial = time.time() - start
    assert np.isclose(recovered, np.arange(SIZE).mean())

    print(f"[partial] run_duration = {run_duration_partial}, "
          f"recover_duration = {recover_duration_partial}")

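# `SIZE` and `checkpoint_dag` are defined elsewhere in the file. A sketch of
# the shape the test above assumes: a DAG whose intermediate task fails until
# the global mark is set, and whose final output is np.arange(SIZE).mean().
# The names, the value of SIZE, and the per-task checkpoint options below are
# assumptions for illustration:
SIZE = 2**10


@ray.remote
def large_input():
    return np.arange(SIZE)


@ray.remote
def identity_may_fail(x):
    # Fail on the first run so that resume() has to recover from the
    # checkpointed partial results.
    if not utils.check_global_mark():
        raise ValueError("Simulated failure before the mark is set.")
    return x


@ray.remote
def average(x):
    return np.mean(x)


@ray.remote
def checkpoint_dag(checkpoint):
    x = large_input.options(**workflow.options(checkpoint=checkpoint)).bind()
    y = identity_may_fail.options(**workflow.options(checkpoint=checkpoint)).bind(x)
    return workflow.continuation(average.bind(y))
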
def test_same_object_many_workflows(workflow_start_regular_shared):
    """Ensure that when we dedupe uploads, we upload the object once per
    workflow, since different workflows shouldn't look in each other's
    object directories.
    """

    @ray.remote
    def f(a):
        return [a[0]]

    x = {0: ray.put(10)}

    result1 = workflow.run(f.bind(x))
    result2 = workflow.run(f.bind(x))
    print(result1)
    print(result2)

    assert ray.get(*result1) == 10
    assert ray.get(*result2) == 10

def test_workflow_queuing_1(shutdown_only, tmp_path):
    ray.init(storage=str(tmp_path))
    workflow.init(max_running_workflows=2, max_pending_workflows=2)

    import queue

    import filelock

    lock_path = str(tmp_path / ".lock")

    @ray.remote
    def long_running(x):
        with filelock.FileLock(lock_path):
            return x

    wfs = [long_running.bind(i) for i in range(5)]

    with filelock.FileLock(lock_path):
        refs = [
            workflow.run_async(wfs[i], workflow_id=f"workflow_{i}") for i in range(4)
        ]

        assert sorted(x[0] for x in workflow.list_all({workflow.RUNNING})) == [
            "workflow_0",
            "workflow_1",
        ]
        assert sorted(x[0] for x in workflow.list_all({workflow.PENDING})) == [
            "workflow_2",
            "workflow_3",
        ]

        with pytest.raises(queue.Full, match="Workflow queue has been full"):
            workflow.run(wfs[4], workflow_id="workflow_4")

    assert ray.get(refs) == [0, 1, 2, 3]
    assert workflow.run(wfs[4], workflow_id="workflow_4") == 4
    assert sorted(x[0] for x in workflow.list_all({workflow.SUCCESSFUL})) == [
        "workflow_0",
        "workflow_1",
        "workflow_2",
        "workflow_3",
        "workflow_4",
    ]

    for i in range(5):
        assert workflow.get_output(f"workflow_{i}") == i
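
# For reference: with max_running_workflows=2 and max_pending_workflows=2,
# at most two workflows execute concurrently while two more wait in the
# pending queue, so a fifth concurrent submission is rejected with
# queue.Full, as exercised above. Once the file lock is released and the
# queue drains, the same workflow_id can be submitted successfully.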