def chain_func(*args, **kw_argv):
    """Compose the functions in ``fs`` into a chain of workflow steps.

    The first function in ``fs`` is wrapped with ``workflow.step`` and called
    with the caller's ``*args`` / ``**kw_argv``. Every following function is
    wrapped the same way and receives the previous step's output as its only
    argument.

    ``fs`` and ``workflow`` are looked up from an enclosing/global scope that
    is not defined here.

    Returns:
        The workflow step built for the last function in ``fs``.

    Raises:
        Indexing ``fs[0]`` fails (``IndexError`` for a list) when ``fs`` is
        empty.
    """
    # Start the chain with the first function, fed by the caller's arguments.
    wf_step = workflow.step(fs[0]).step(*args, **kw_argv)
    # Convert each remaining function into a workflow step and feed it the
    # previous step's output.
    for fn in fs[1:]:
        wf_step = workflow.step(fn).step(wf_step)
    return wf_step
def test_partial(workflow_start_regular_shared):
    """Check that ``functools.partial`` objects work as workflow steps.

    Covers two cases:
      * a single partial wrapped directly with ``workflow.step``, whose step
        name is expected to contain ``"__anonymous_func__"``;
      * a chain of partials composed inside a Ray remote function and run via
        ``workflow.run``.
    """
    from functools import partial

    ys = [1, 2, 3]

    def add(x, y):
        return x + y

    # A partial bound to x=10, then called with y=10 -> 20. Partial objects
    # have no __name__ of their own, which is presumably why the step gets an
    # anonymous name.
    f1 = workflow.step(partial(add, 10)).step(10)
    assert "__anonymous_func__" in f1._name
    assert f1.run() == 20

    # Bind y by keyword so each partial adds its own y to its input x.
    fs = [partial(add, y=y) for y in ys]

    @ray.remote
    def chain_func(*args, **kw_argv):
        # Get the first function as a start, fed by the caller's arguments.
        wf_step = workflow.step(fs[0]).step(*args, **kw_argv)
        # Convert each remaining function into a workflow step and use the
        # previous output as its input.
        for fn in fs[1:]:
            wf_step = workflow.step(fn).step(wf_step)
        return wf_step

    # 1 -> 1 + 1 = 2 -> 2 + 2 = 4 -> 4 + 3 = 7
    assert workflow.run(chain_func.bind(1)) == 7
def _node_visitor(node: Any) -> Any:
    """Translate a single DAG node into its workflow counterpart.

    * ``FunctionNode``        -> a workflow step built from the node's body,
                                 bound options and bound arguments.
    * ``InputAtrributeNode``  -> the result of ``node._execute_impl()``.
    * ``InputNode``           -> ``input_context`` (taken from an enclosing
                                 scope not shown here).
    * any non-``DAGNode``     -> returned unchanged.
    * any other ``DAGNode``   -> ``TypeError``.

    The checks run in this order. NOTE(review): the node classes are
    presumably ``DAGNode`` subclasses, so the generic ``DAGNode`` test must
    stay last.
    """
    if isinstance(node, FunctionNode):
        # "_resolve_like_object_ref_in_args" indicates we should resolve the
        # workflow like an ObjectRef, when included in the arguments of
        # another workflow.
        step_builder = workflow.step(node._body).options(
            **node._bound_options, _resolve_like_object_ref_in_args=True
        )
        return step_builder.step(*node._bound_args, **node._bound_kwargs)
    elif isinstance(node, InputAtrributeNode):
        # NOTE(review): "Atrribute" spelling must match the class's real name.
        return node._execute_impl()  # get data from input node
    elif isinstance(node, InputNode):
        return input_context  # replace input node with input data
    elif isinstance(node, DAGNode):
        # A DAG node of a kind this converter does not know how to map.
        raise TypeError(f"Unsupported DAG node: {node}")
    else:
        return node  # plain (non-DAG) objects pass through untouched