# Example #1
0
def test_imperative_with_list_io():
    """Pass a ``List[int]`` produced by one task into another task of an imperative workflow."""

    @task
    def t1(a: int) -> typing.List[int]:
        return [1, a, 3]

    @task
    def t2(a: typing.List[int]) -> int:
        return sum(a)

    builder = ImperativeWorkflow(name="my.workflow.a")
    # t1(a=2) yields [1, 2, 3]; t2 then sums that list.
    list_node = builder.add_entity(t1, a=2)
    sum_node = builder.add_entity(t2, a=list_node.outputs["o0"])
    builder.add_workflow_output("from_n0t2", sum_node.outputs["o0"])

    expected_total = 1 + 2 + 3
    assert builder() == expected_total
# Example #2
0
def test_imperative_list_bound_output():
    """A single workflow output can bind a list assembled from several node outputs."""

    @task
    def t1() -> int:
        return 3

    @task
    def t2(a: typing.List[int]) -> int:
        return sum(a)

    builder = ImperativeWorkflow(name="my.workflow.a")
    const_node = builder.add_entity(t1)
    sum_node = builder.add_entity(t2, a=[1, 2, 3])
    # Gather the outputs of two independent nodes into one list-typed workflow output.
    bound_outputs = [const_node.outputs["o0"], sum_node.outputs["o0"]]
    builder.add_workflow_output("wf0", bound_outputs, python_type=typing.List[int])

    assert builder() == [3, 6]
# Example #3
0
def test_imperative_wf_list_input():
    """A ``List[int]`` workflow input and a ``List[int]`` task output can both feed one task."""

    @task
    def t1(a: int) -> typing.List[int]:
        return [1, a, 3]

    @task
    def t2(a: typing.List[int], b: typing.List[int]) -> int:
        return sum(a) + sum(b)

    wb = ImperativeWorkflow(name="my.workflow.a")
    wf_in1 = wb.add_workflow_input("in1", typing.List[int])
    t1_node = wb.add_entity(t1, a=2)
    t2_node = wb.add_entity(t2, a=t1_node.outputs["o0"], b=wf_in1)
    wb.add_workflow_output("from_n0t2", t2_node.outputs["o0"])

    # sum([1, 2, 3]) + sum([5, 6, 7]) == 6 + 18 == 24
    assert wb(in1=[5, 6, 7]) == 24

    # get_serializable returns a spec object; the workflow's nodes live on its
    # ``template``, the same way the other serialized-workflow checks read them.
    srz_wf = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(srz_wf.template.nodes) == 2
    assert srz_wf.template.nodes[0].task_node is not None
# Example #4
0
def test_codecov():
    """Cover error branches of ``get_promise`` and of invoking an imperative workflow with bad arguments."""
    # get_promise must reject binding data it cannot resolve: an empty BindingData,
    # and one whose ``promise`` field is not a real promise object.
    with pytest.raises(FlyteValidationException):
        get_promise(literal_models.BindingData(), {})
    with pytest.raises(FlyteValidationException):
        get_promise(literal_models.BindingData(promise=3), {})

    @task
    def t1(a: str) -> str:
        return a + " world"

    builder = ImperativeWorkflow(name="my.workflow")
    builder.add_workflow_input("in1", str)
    greet_node = builder.add_entity(t1, a=builder.inputs["in1"])
    builder.add_workflow_output("from_n0t1", greet_node.outputs["o0"])

    assert builder(in1="hello") == "hello world"

    # Positional arguments are not accepted when calling the workflow.
    with pytest.raises(AssertionError):
        builder(3)

    # An unknown keyword input is rejected.
    with pytest.raises(ValueError):
        builder(in2="hello")
# Example #5
0
# Placeholder container image; name/fqn/tag are dummy values.
default_img = Image(name="default", fqn="test", tag="tag")
# Serialization settings shared by the tests in this module. project/domain/version
# are placeholder values; the image config registers ``default_img`` as both the
# default image and the only image available.
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


# Module-level task used by ``wb`` below: appends " world" to its string input.
@task
def t1(a: str) -> str:
    return a + " world"


# Module-level imperative workflow: its single str input "in1" feeds ``t1``, and
# t1's output "o0" is exposed as the workflow output "from_n0t1".
wb = ImperativeWorkflow(name="my.workflow")
wb.add_workflow_input("in1", str)
node = wb.add_entity(t1, a=wb.inputs["in1"])
wb.add_workflow_output("from_n0t1", node.outputs["o0"])


def test_base_case():
    """Running the module-level workflow appends " world" to its input."""
    result = wb(in1="hello")
    assert result == "hello world"


# Please see https://github.com/flyteorg/flyte/issues/854 for more information.
# This mock_patch_wf object duplicates the module-level ``wb`` workflow. Because of
# issue 854, the same workflow object cannot be reused here.
# TODO: Remove this duplicate object pending resolution of #854
mock_patch_wf = ImperativeWorkflow(name="my.workflow")
mock_patch_wf.add_workflow_input("in1", str)
# Example #6
0
def test_imperative():
    """Build an imperative workflow, run and serialize it, then reuse it as a sub-workflow and via a launch plan.

    The ``docs_*`` marker comments delimit snippets that appear to be extracted into
    the documentation (TODO confirm); keep the code between them self-contained.
    """
    # Re import with alias
    from flytekit.core.workflow import ImperativeWorkflow as Workflow  # noqa

    # docs_tasks_start
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # docs_tasks_end

    # docs_start
    # Create the workflow with a name. This needs to be unique within the project and takes the place of the function
    # name that's used for regular decorated function-based workflows.
    wb = Workflow(name="my_workflow")
    # Adds a top level input to the workflow. This is like an input to a workflow function.
    wb.add_workflow_input("in1", str)
    # Call your tasks.
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    # This is analogous to a return statement
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    # docs_end

    assert wb(in1="hello") == "hello world"

    # The serialized template has both task nodes (t1 and t2), a single output
    # binding named "from_n0t1", and a one-input / one-output interface.
    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # Decorator-based equivalent of ``wb``, shown for documentation; it is defined but not invoked here.
    # docs_equivalent_start
    nt = typing.NamedTuple("wf_output", from_n0t1=str)

    @workflow
    def my_workflow(in1: str) -> nt:
        x = t1(a=in1)
        t2()
        return nt(
            x,
        )

    # docs_equivalent_end

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my_workflow"

    # Embed ``wb`` directly as a sub-workflow node of a parent imperative workflow.
    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my_workflow"
    assert len(wb2_spec.sub_workflows) == 1

    # Reference ``wb`` through its launch plan instead of embedding it.
    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
# Example #7
0
def test_imperative_tuples():
    """Wire the individual outputs of a multi-output task and return them as workflow outputs."""

    @task
    def t1() -> (int, str):
        return 3, "three"

    @task
    def t3(a: int, b: str) -> typing.Tuple[int, str]:
        return a + 2, "world" + b

    wb = ImperativeWorkflow(name="my.workflow.a")
    t1_node = wb.add_entity(t1)
    t3_node = wb.add_entity(t3, a=t1_node.outputs["o0"], b=t1_node.outputs["o1"])
    wb.add_workflow_output("wf0", t3_node.outputs["o0"], python_type=int)
    wb.add_workflow_output("wf1", t3_node.outputs["o1"], python_type=str)
    res = wb()
    assert res == (5, "worldthree")

    # Build the second workflow outside the ``raises`` block so that only the
    # lookup of a nonexistent output ("bad") can satisfy the expected KeyError.
    # A KeyError raised while constructing the workflow or adding t1 would
    # otherwise make this negative check pass spuriously.
    wb = ImperativeWorkflow(name="my.workflow.b")
    t1_node = wb.add_entity(t1)
    with pytest.raises(KeyError):
        wb.add_entity(t3, a=t1_node.outputs["bad"], b=t1_node.outputs["o2"])
# Example #8
0
def test_imperative():
    """Build, run and serialize an imperative workflow, then nest it as a sub-workflow and via a launch plan."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    def check_parent_spec(spec):
        # Shared checks for a parent workflow that wraps a single child node and
        # declares one str input ("p_in1") and one str output ("parent_wf_output").
        tmpl = spec.template
        assert len(tmpl.nodes) == 1
        assert len(tmpl.interface.inputs) == 1
        assert tmpl.interface.inputs["p_in1"].type.simple is not None
        assert len(tmpl.interface.outputs) == 1
        assert tmpl.interface.outputs["parent_wf_output"].type.simple is not None
        return tmpl

    child = ImperativeWorkflow(name="my.workflow")
    child.add_workflow_input("in1", str)
    greet_node = child.add_entity(t1, a=child.inputs["in1"])
    child.add_entity(t2)
    child.add_workflow_output("from_n0t1", greet_node.outputs["o0"])

    assert child(in1="hello") == "hello world"

    # Serialized child: two task nodes, one output binding, one input and one output.
    child_tmpl = get_serializable(OrderedDict(), serialization_settings, child).template
    assert len(child_tmpl.nodes) == 2
    assert child_tmpl.nodes[0].task_node is not None
    assert len(child_tmpl.outputs) == 1
    assert child_tmpl.outputs[0].var == "from_n0t1"
    assert len(child_tmpl.interface.inputs) == 1
    assert len(child_tmpl.interface.outputs) == 1

    # A launch plan created from the workflow serializes with a reference to it.
    lp = LaunchPlan.create("test_wb", child)
    lp_spec = get_serializable(OrderedDict(), serialization_settings, lp).spec
    assert lp_spec.workflow_id.name == "my.workflow"

    # Parent that embeds the child directly as a sub-workflow node.
    subwf_parent = ImperativeWorkflow(name="parent.imperative")
    parent_in = subwf_parent.add_workflow_input("p_in1", str)
    sub_node = subwf_parent.add_subwf(child, in1=parent_in)
    subwf_parent.add_workflow_output("parent_wf_output", sub_node.from_n0t1, str)
    subwf_spec = get_serializable(OrderedDict(), serialization_settings, subwf_parent)
    subwf_tmpl = check_parent_spec(subwf_spec)
    assert subwf_tmpl.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(subwf_spec.sub_workflows) == 1

    # Parent that references the child through its launch plan instead.
    lp_parent = ImperativeWorkflow(name="parent.imperative")
    parent_in = lp_parent.add_workflow_input("p_in1", str)
    lp_node = lp_parent.add_launch_plan(lp, in1=parent_in)
    lp_parent.add_workflow_output("parent_wf_output", lp_node.from_n0t1, str)
    lp_parent_tmpl = check_parent_spec(get_serializable(OrderedDict(), serialization_settings, lp_parent))
    assert lp_parent_tmpl.nodes[0].workflow_node.launchplan_ref.name == "test_wb"