def test_imperative_with_list_io():
    """A list returned by one task can be passed straight into another task."""

    @task
    def t1(a: int) -> typing.List[int]:
        return [1, a, 3]

    @task
    def t2(a: typing.List[int]) -> int:
        return sum(a)

    wf = ImperativeWorkflow(name="my.workflow.a")
    producer = wf.add_entity(t1, a=2)
    consumer = wf.add_entity(t2, a=producer.outputs["o0"])
    wf.add_workflow_output("from_n0t2", consumer.outputs["o0"])

    # t1 yields [1, 2, 3]; t2 sums it.
    result = wf()
    assert result == 6
def test_imperative_list_bound_output():
    """A workflow output can be a list literal that combines outputs of several nodes."""

    @task
    def t1() -> int:
        return 3

    @task
    def t2(a: typing.List[int]) -> int:
        return sum(a)

    wf = ImperativeWorkflow(name="my.workflow.a")
    const_node = wf.add_entity(t1)
    sum_node = wf.add_entity(t2, a=[1, 2, 3])

    # Bind both node outputs into a single List[int] workflow output.
    combined = [const_node.outputs["o0"], sum_node.outputs["o0"]]
    wf.add_workflow_output("wf0", combined, python_type=typing.List[int])

    result = wf()
    assert result == [3, 6]
def test_imperative_wf_list_input():
    """A workflow-level list input and a task's list output can both feed one task.

    Also checks that the resulting workflow serializes with two task nodes.
    """

    @task
    def t1(a: int) -> typing.List[int]:
        return [1, a, 3]

    @task
    def t2(a: typing.List[int], b: typing.List[int]) -> int:
        return sum(a) + sum(b)

    wb = ImperativeWorkflow(name="my.workflow.a")
    wf_in1 = wb.add_workflow_input("in1", typing.List[int])
    t1_node = wb.add_entity(t1, a=2)
    t2_node = wb.add_entity(t2, a=t1_node.outputs["o0"], b=wf_in1)
    wb.add_workflow_output("from_n0t2", t2_node.outputs["o0"])

    # sum([1, 2, 3]) + sum([5, 6, 7]) == 6 + 18
    assert wb(in1=[5, 6, 7]) == 24

    # The serialized workflow exposes its nodes via ``.template`` (this is how the
    # other imperative tests in this module read them); ``srz_wf.nodes`` does not
    # address the node list.
    srz_wf = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(srz_wf.template.nodes) == 2
    assert srz_wf.template.nodes[0].task_node is not None
def test_codecov():
    """Cover get_promise validation errors and bad-argument handling when calling a workflow."""
    # Binding data with no value, and binding data with a non-promise "promise",
    # must both be rejected by get_promise.
    for binding_kwargs in ({}, {"promise": 3}):
        with pytest.raises(FlyteValidationException):
            get_promise(literal_models.BindingData(**binding_kwargs), {})

    @task
    def t1(a: str) -> str:
        return a + " world"

    wf = ImperativeWorkflow(name="my.workflow")
    wf.add_workflow_input("in1", str)
    greet = wf.add_entity(t1, a=wf.inputs["in1"])
    wf.add_workflow_output("from_n0t1", greet.outputs["o0"])

    assert wf(in1="hello") == "hello world"

    # Calling with a positional argument is expected to raise AssertionError.
    with pytest.raises(AssertionError):
        wf(3)
    # An input name the workflow never declared ("in2") is expected to raise ValueError.
    with pytest.raises(ValueError):
        wf(in2="hello")
# Image and serialization settings shared by the get_serializable calls in this module.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: str) -> str:
    return a + " world"


# Module-level imperative workflow: a single str input ``in1`` routed through ``t1``,
# exposed as the output ``from_n0t1``.
wb = ImperativeWorkflow(name="my.workflow")
wb.add_workflow_input("in1", str)
node = wb.add_entity(t1, a=wb.inputs["in1"])
wb.add_workflow_output("from_n0t1", node.outputs["o0"])


def test_base_case():
    """Calling the module-level workflow runs ``t1`` on its input."""
    assert wb(in1="hello") == "hello world"


# Please see https://github.com/flyteorg/flyte/issues/854 for more information.
# This mock_patch_wf object is a duplicate of the wb object above. Because of issue #854, we can't
# reuse the same object.
# TODO: Remove this duplicate object pending resolution of #854
mock_patch_wf = ImperativeWorkflow(name="my.workflow")
mock_patch_wf.add_workflow_input("in1", str)
def test_imperative():
    """Build, run and serialize an imperative workflow, then reuse it as a subworkflow and via a launch plan.

    The ``docs_*_start`` / ``docs_*_end`` comment markers appear to delimit snippets
    extracted into the documentation; keep the code between them self-contained.

    NOTE(review): a later function in this module is also named ``test_imperative``.
    That later definition rebinds the name, so pytest only collects it and this test
    never runs. Consider renaming one of the two.
    """
    # Re import with alias
    from flytekit.core.workflow import ImperativeWorkflow as Workflow  # noqa

    # docs_tasks_start
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # docs_tasks_end

    # docs_start
    # Create the workflow with a name. This needs to be unique within the project and takes the place of the function
    # name that's used for regular decorated function-based workflows.
    wb = Workflow(name="my_workflow")
    # Adds a top level input to the workflow. This is like an input to a workflow function.
    wb.add_workflow_input("in1", str)
    # Call your tasks.
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    # This is analogous to a return statement
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    # docs_end

    assert wb(in1="hello") == "hello world"

    # Serialized form: two task nodes, one bound output, one input and one output in the interface.
    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # The decorated workflow below is the function-based equivalent of ``wb``; it is
    # defined for the docs snippet but never invoked by this test.
    # docs_equivalent_start
    nt = typing.NamedTuple("wf_output", from_n0t1=str)

    @workflow
    def my_workflow(in1: str) -> nt:
        x = t1(a=in1)
        t2()
        return nt(
            x,
        )

    # docs_equivalent_end

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my_workflow"

    # Parent workflow that embeds ``wb`` as a subworkflow.
    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my_workflow"
    assert len(wb2_spec.sub_workflows) == 1

    # Parent workflow that references ``wb`` through the launch plan instead.
    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
def test_imperative_tuples():
    """Multi-output tasks expose outputs as ``o0``, ``o1``, ...; an unknown output name raises KeyError."""

    @task
    def t1() -> (int, str):
        return 3, "three"

    @task
    def t3(a: int, b: str) -> typing.Tuple[int, str]:
        return a + 2, "world" + b

    wb = ImperativeWorkflow(name="my.workflow.a")
    t1_node = wb.add_entity(t1)
    t3_node = wb.add_entity(t3, a=t1_node.outputs["o0"], b=t1_node.outputs["o1"])
    wb.add_workflow_output("wf0", t3_node.outputs["o0"], python_type=int)
    wb.add_workflow_output("wf1", t3_node.outputs["o1"], python_type=str)
    res = wb()
    assert res == (5, "worldthree")

    # Set up the second workflow outside ``pytest.raises``. Only the bad output
    # lookup should be able to satisfy the expected KeyError; a KeyError raised
    # during setup would otherwise make the test pass spuriously.
    bad_wb = ImperativeWorkflow(name="my.workflow.b")
    bad_t1_node = bad_wb.add_entity(t1)
    with pytest.raises(KeyError):
        bad_wb.add_entity(t3, a=bad_t1_node.outputs["bad"], b=bad_t1_node.outputs["o2"])
def test_imperative():
    """Run and serialize a small imperative workflow, then embed it in parents via subworkflow and launch plan.

    NOTE(review): an earlier function in this module is also named ``test_imperative``.
    This definition rebinds that name, so pytest only collects this one and the
    earlier test never runs. Consider renaming one of the two.
    """

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    child = ImperativeWorkflow(name="my.workflow")
    child.add_workflow_input("in1", str)
    greet = child.add_entity(t1, a=child.inputs["in1"])
    child.add_entity(t2)
    child.add_workflow_output("from_n0t1", greet.outputs["o0"])

    assert child(in1="hello") == "hello world"

    child_template = get_serializable(OrderedDict(), serialization_settings, child).template
    assert len(child_template.nodes) == 2
    assert child_template.nodes[0].task_node is not None
    assert len(child_template.outputs) == 1
    assert child_template.outputs[0].var == "from_n0t1"
    assert len(child_template.interface.inputs) == 1
    assert len(child_template.interface.outputs) == 1

    # A launch plan created from the workflow can be serialized as well.
    launch_plan = LaunchPlan.create("test_wb", child)
    lp_model = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert lp_model.spec.workflow_id.name == "my.workflow"

    def check_parent_interface(template):
        # Each parent wraps the child in a single node, with one str input "p_in1"
        # and one str output "parent_wf_output".
        assert len(template.nodes) == 1
        assert len(template.interface.inputs) == 1
        assert template.interface.inputs["p_in1"].type.simple is not None
        assert len(template.interface.outputs) == 1
        assert template.interface.outputs["parent_wf_output"].type.simple is not None

    # Parent that embeds the child as a subworkflow.
    sub_parent = ImperativeWorkflow(name="parent.imperative")
    sub_in = sub_parent.add_workflow_input("p_in1", str)
    sub_node = sub_parent.add_subwf(child, in1=sub_in)
    sub_parent.add_workflow_output("parent_wf_output", sub_node.from_n0t1, str)
    sub_spec = get_serializable(OrderedDict(), serialization_settings, sub_parent)
    check_parent_interface(sub_spec.template)
    assert sub_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(sub_spec.sub_workflows) == 1

    # Parent that references the child through the launch plan.
    lp_parent = ImperativeWorkflow(name="parent.imperative")
    lp_in = lp_parent.add_workflow_input("p_in1", str)
    lp_node = lp_parent.add_launch_plan(launch_plan, in1=lp_in)
    lp_parent.add_workflow_output("parent_wf_output", lp_node.from_n0t1, str)
    lp_spec = get_serializable(OrderedDict(), serialization_settings, lp_parent)
    check_parent_interface(lp_spec.template)
    assert lp_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"