Esempio n. 1
0
def test_ref():
    """Check a ``reference_task``'s identifier and that serializing it yields ``None``.

    The reference task is serialized twice, each time with a fresh
    ``OrderedDict`` and a different ``SerializationSettings`` instance; both
    calls are expected to return ``None``.
    """
    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    # The identifier must mirror the decorator arguments.
    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    # First pass: settings that share nothing with the reference identifier.
    first_image = Image(name="name", fqn="image", tag="name")
    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(first_image),
        env={},
    )
    assert get_serializable(OrderedDict(), first_settings, ref_t1) is None

    # Second pass: a different project/domain/version/image combination.
    second_image = Image(name="name", fqn="asdf/fdsa", tag="123")
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(second_image),
        env={},
    )
    assert get_serializable(OrderedDict(), second_settings, ref_t1) is None
Esempio n. 2
0
def test_imperative():
    """Build an ``ImperativeWorkflow`` by hand, run it locally and serialize it.

    Also creates a launch plan from the workflow and checks that the serialized
    launch plan points back at the workflow by name.
    """
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # Assemble the workflow: one input, two task nodes, one output.
    imperative_wf = ImperativeWorkflow(name="my.workflow")
    imperative_wf.add_workflow_input("in1", str)
    t1_node = imperative_wf.add_entity(t1, a=imperative_wf.inputs["in1"])
    imperative_wf.add_entity(t2)
    imperative_wf.add_workflow_output("from_n0t1", t1_node.outputs["o0"])

    # Local execution.
    local_result = imperative_wf(in1="hello")
    assert local_result == "hello world"

    # Serialized form of the workflow.
    serialized_wf = get_serializable(
        OrderedDict(), serialization_settings, imperative_wf
    )
    assert len(serialized_wf.nodes) == 2
    assert serialized_wf.nodes[0].task_node is not None
    assert len(serialized_wf.outputs) == 1
    assert serialized_wf.outputs[0].var == "from_n0t1"
    assert len(serialized_wf.interface.inputs) == 1
    assert len(serialized_wf.interface.outputs) == 1

    # A launch plan created from the workflow can be serialized as well.
    launch_plan_for_wf = LaunchPlan.create("test_wb", imperative_wf)
    serialized_lp = get_serializable(
        OrderedDict(), serialization_settings, launch_plan_for_wf
    )
    assert serialized_lp.workflow_id.name == "my.workflow"
Esempio n. 3
0
def test_basics():
    """Serialize a two-task workflow, one of its tasks, and a launch plan.

    Uses the module-level ``serialization_settings`` and a fresh
    ``OrderedDict`` for every ``get_serializable`` call.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        int_out, str_out = t1(a=a)
        joined = t2(a=str_out, b=b)
        return int_out, joined

    # Workflow spec: interface sizes, node count and resource type.
    wf_spec = get_serializable(
        OrderedDict(), serialization_settings, my_wf, False
    )
    wf_template = wf_spec.template
    assert len(wf_template.interface.inputs) == 2
    assert len(wf_template.interface.outputs) == 2
    assert len(wf_template.nodes) == 2
    assert wf_template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # The fourth positional argument is True here, yet the regular
    # "pyflyte-execute" entrypoint is what the test expects in the args
    # (the original note attributes this to caching on the first call).
    task_spec = get_serializable(
        OrderedDict(), serialization_settings, t1, True
    )
    assert "pyflyte-execute" in task_spec.template.container.args

    # Launch plan model keeps the name it was created with.
    lp = LaunchPlan.create("testlp", my_wf)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.id.name == "testlp"
Esempio n. 4
0
def test_references():
    """Serialize each kind of reference entity and check ``has_registered``.

    A reference launch plan, task and workflow are created in that order with
    identical constructor arguments, serialized with the module-level
    ``serialization_settings``, and the result's ``has_registered`` must be truthy.
    """
    for reference_cls in (ReferenceLaunchPlan, ReferenceTask, ReferenceWorkflow):
        reference_entity = reference_cls(
            "media",
            "stg",
            "some.name",
            "cafe",
            inputs=kwtypes(in1=str),
            outputs=kwtypes(),
        )
        serialized = get_serializable(serialization_settings, reference_entity)
        assert serialized.has_registered
def test_query_no_inputs_or_outputs():
    """Serialize a ``HiveTask`` declared with no inputs and no output schema.

    The serialized task must expose an empty interface, and a workflow that
    only calls the task must serialize without raising.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    # A single image acts both as the default and as the only listed image.
    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    serialized_task = get_serializable(settings, hive_task)
    task_interface = serialized_task.interface
    assert len(task_interface.inputs) == 0
    assert len(task_interface.outputs) == 0

    # Only checks that serializing the wrapping workflow does not raise.
    get_serializable(settings, my_wf)
Esempio n. 6
0
def test_basics():
    """Serialize a two-task workflow, one task, and a launch plan.

    This variant calls ``get_serializable`` with the settings as the first
    argument and reads ``interface``/``nodes``/``container`` directly off the
    returned objects.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        int_out, str_out = t1(a=a)
        joined = t2(a=str_out, b=b)
        return int_out, joined

    # Workflow: two inputs, two outputs, two nodes.
    serialized_wf = get_serializable(serialization_settings, my_wf, False)
    wf_interface = serialized_wf.interface
    assert len(wf_interface.inputs) == 2
    assert len(wf_interface.outputs) == 2
    assert len(serialized_wf.nodes) == 2

    # The third positional argument is True here, yet the regular
    # "pyflyte-execute" entrypoint is expected (the original note attributes
    # this to the task being cached by the first call).
    serialized_task = get_serializable(serialization_settings, t1, True)
    assert "pyflyte-execute" in serialized_task.container.args

    # Launch plan keeps the name it was created with.
    lp = LaunchPlan.create("testlp", my_wf)
    serialized_lp = get_serializable(serialization_settings, lp)
    assert serialized_lp.id.name == "testlp"
Esempio n. 7
0
def test_references():
    """Serializing any reference entity is expected to return ``None``.

    A reference launch plan, task and workflow are created in that order with
    identical constructor arguments; each is serialized with a fresh
    ``OrderedDict`` and the module-level ``serialization_settings``.
    """
    for reference_cls in (ReferenceLaunchPlan, ReferenceTask, ReferenceWorkflow):
        reference_entity = reference_cls(
            "media",
            "stg",
            "some.name",
            "cafe",
            inputs=kwtypes(in1=str),
            outputs=kwtypes(),
        )
        serialized = get_serializable(
            OrderedDict(), serialization_settings, reference_entity
        )
        assert serialized is None
Esempio n. 8
0
def test_imperative():
    """Build imperative workflows step by step and check their serialized specs.

    Three workflows are exercised:

    * ``wb``  - two task nodes, one input and one output;
    * ``wb2`` - a parent workflow that embeds ``wb`` as a sub-workflow;
    * ``wb3`` - a parent workflow that calls ``wb`` through a launch plan.

    The workflows are assembled statement by statement, so the order of the
    ``add_*`` calls below determines the node indices asserted on.
    """
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # Base workflow: input "in1" feeds t1; t2 takes no inputs; t1's output is
    # exposed as the workflow output "from_n0t1".
    wb = ImperativeWorkflow(name="my.workflow")
    wb.add_workflow_input("in1", str)
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])

    # Local execution of the imperative workflow.
    assert wb(in1="hello") == "hello world"

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my.workflow"

    # Parent workflow that embeds ``wb`` as a sub-workflow node.
    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs[
        "parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[
        0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(wb2_spec.sub_workflows) == 1

    # Parent workflow (same name as wb2) that reaches ``wb`` through the
    # launch plan instead of embedding it directly.
    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs[
        "parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[
        0].workflow_node.launchplan_ref.name == "test_wb"
Esempio n. 9
0
def test_wf_resolving():
    """Check resolver arguments of tasks that are declared inside a workflow.

    Two tasks are defined within the workflow body. After faking the
    workflow's module-level location, each serialized task's container args
    must end with ``--resolver <location> -- <index>``.
    """
    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        @task
        def t1(a: int) -> (int, str):
            return a + 2, "world"

        @task
        def t2(a: str, b: str) -> str:
            return b + a

        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    # Local execution still works with the nested task definitions.
    x = my_wf(a=3, b="hello")
    assert x == (5, "helloworld")

    # Because the workflow is nested inside a test, calling location will fail as it tries to find the LHS that the
    # workflow was assigned to
    with pytest.raises(Exception):
        _ = my_wf.location

    # Pretend my_wf was not actually nested, but somehow assigned to example_var_name at the module layer
    my_wf._instantiated_in = "example.module"
    my_wf._lhs = "example_var_name"

    workflows_tasks = my_wf.get_all_tasks()
    assert len(workflows_tasks) == 2  # Two tasks were declared inside

    # The tasks should get the location the workflow was assigned to as the resolver.
    # The args are the index.
    srz_t0 = get_serializable(OrderedDict(), serialization_settings,
                              workflows_tasks[0])
    assert srz_t0.container.args[-4:] == [
        "--resolver",
        "example.module.example_var_name",
        "--",
        "0",
    ]

    # Same check for the second task; only the trailing index differs.
    srz_t1 = get_serializable(OrderedDict(), serialization_settings,
                              workflows_tasks[1])
    assert srz_t1.container.args[-4:] == [
        "--resolver",
        "example.module.example_var_name",
        "--",
        "1",
    ]
Esempio n. 10
0
def test_serialization():
    """Serialize a workflow built from two raw ``ContainerTask`` definitions.

    ``square`` and ``sum_task`` are shell one-liners run in an ``alpine``
    image. The workflow squares both inputs and adds the results, giving
    three nodes. Each task is also serialized on its own to check that the
    configured image is carried over to the task template.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    # Bound to ``sum_task`` rather than ``sum`` so the builtin is not shadowed;
    # the task's registered name is still the explicit name="sum".
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    # Two ``square`` calls plus one ``sum_task`` call.
    wf_spec = get_serializable(OrderedDict(), serialization_settings,
                               raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    # The task's own image is expected, not the default image from the settings.
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings,
                                 sum_task)
    assert sumn_spec.template.container.image == "alpine"
Esempio n. 11
0
def test_lp_serialize():
    """Serialize launch plans with and without default inputs.

    Without defaults the workflow input ``a`` must be reported as required;
    with ``default_inputs={"a": 3}`` it must not be required and must carry an
    integer literal of 3 as its default. Neither plan has fixed inputs.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create(
        "serialize_test2", my_subwf, default_inputs={"a": 3}
    )

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    # Launch plan without defaults: "a" is a required parameter.
    serialized_plain = get_serializable(OrderedDict(), settings, plain_lp)
    plain_params = serialized_plain.default_inputs.parameters
    assert len(plain_params) == 1
    assert plain_params["a"].required
    assert len(serialized_plain.fixed_inputs.literals) == 0

    # Launch plan with a default: "a" is optional and defaults to integer 3.
    serialized_defaults = get_serializable(
        OrderedDict(), settings, lp_with_defaults
    )
    default_params = serialized_defaults.default_inputs.parameters
    assert len(default_params) == 1
    assert not default_params["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(
            primitive=_literal_models.Primitive(integer=3)
        )
    )
    assert default_params["a"].default == expected_default
    assert len(serialized_defaults.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    round_tripped = Parameter.from_flyte_idl(default_params["a"].to_flyte_idl())
    assert round_tripped.default is not None
Esempio n. 12
0
def test_pod_task_serialized():
    """Serialize a task configured with a ``Pod`` task config.

    Checks that the decorated function becomes a ``PodFunctionTask`` holding
    the given config, and that the serialized form reports task type
    version 1 and the configured primary container name.
    """
    primary_name = "an undefined container"
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name=primary_name)

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=image_config,
    )

    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.task_type_version == 1
    assert serialized.config["primary_container_name"] == primary_name
def test_serialization_branch_sub_wf():
    """Serialize a workflow whose conditional picks between a task and a sub-workflow.

    The first serialized node must exist and take exactly one input, bound to
    the variable name ".a".
    """
    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(
            my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf = get_serializable(serialization_settings, my_wf)
    assert wf is not None
    # Check the node exists before reading its attributes; done afterwards, a
    # missing node would surface as an AttributeError rather than a clear
    # assertion failure.
    assert wf.nodes[0] is not None
    assert len(wf.nodes[0].inputs) == 1
    assert wf.nodes[0].inputs[0].var == ".a"
def test_serialization_branch_complex():
    """Serialize a workflow with one task node followed by two conditionals.

    The first conditional has ``if``/``elif`` branches and a failing ``else``;
    the second branches on the first one's result. Three nodes are expected,
    the last two being branch nodes.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        first_choice = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        second_choice = (
            conditional("test2")
            .if_(first_choice == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return x, second_choice

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )

    serialized_wf = get_serializable(settings, my_wf)
    assert serialized_wf is not None
    nodes = serialized_wf.nodes
    assert len(nodes) == 3
    # Nodes at index 1 and 2 are the two conditionals.
    assert nodes[1].branch_node is not None
    assert nodes[2].branch_node is not None
Esempio n. 15
0
def test_condition_tuple_branches():
    """Run and serialize a workflow whose conditional returns a tuple of outputs.

    ``math_ops(a=3, b=2)`` takes the ``a > b`` branch and must return
    ``(5, 1)``. In the serialized workflow, the ``then`` node of the first
    branch must reference the ``sum_sub`` task.
    """
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte will only make `sum` and `sub` available as outputs because they are common between all branches
        # The first result is bound to ``add`` instead of ``sum`` so the
        # builtin is not shadowed; the task's output names are unchanged.
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )

        return add, sub

    x, y = math_ops(a=3, b=2)
    assert x == 5
    assert y == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    sdk_wf = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert sdk_wf.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_branch():
    """Run and serialize a workflow with one task node and one conditional.

    Locally, an input of 4 must select the "world" branch and any other
    value the "hello" branch. The serialized workflow must contain two
    nodes, the second being a branch node.
    """
    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a, )

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world", )

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello", )

    @workflow
    def my_wf(a: int) -> str:
        mimicked = mimic(a=a)
        return (
            conditional("test1")
            .if_(mimicked.c == 4)
            .then(t1(c=mimicked.c).c)
            .else_()
            .then(t2().c)
        )

    # Local execution covers both branches.
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )

    serialized_wf = get_serializable(settings, my_wf)
    assert serialized_wf is not None
    assert len(serialized_wf.nodes) == 2
    assert serialized_wf.nodes[1].branch_node is not None
Esempio n. 17
0
def test_environment():
    """Check how task-level and settings-level environment variables combine.

    The task declares FOO and BAZ, the serialization settings declare FOO and
    BAR. The serialized container env is expected to hold all three keys,
    with the task's value for FOO.
    """
    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings_env = {"FOO": "foo", "BAR": "bar"}
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env=settings_env,
    )

    expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
    current_ctx = context_manager.FlyteContext.current_context()
    with current_ctx.new_compilation_context():
        serialized_task = get_serializable(OrderedDict(), settings, t1)
        assert serialized_task.container.env == expected_env
Esempio n. 18
0
def test_condition_tuple_branches():
    """Run and serialize a workflow whose conditional yields two outputs.

    ``math_ops(a=3, b=2)`` must return ``(5, 1)``. The serialized workflow
    must have a single node whose first branch's ``then`` node references the
    ``sum_sub`` task.
    """
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        branch = conditional("noDivByZero").if_(a > b).then(sum_sub(a=a, b=b))
        add, sub = branch.else_().fail("Only positive results are allowed")
        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    wf_spec = get_serializable(OrderedDict(), serialization_settings, math_ops)
    nodes = wf_spec.template.nodes
    assert len(nodes) == 1
    then_node = nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
Esempio n. 19
0
def test_serialization_nested_subwf():
    """Serialize a parent workflow with two levels of sub-workflow nesting.

    ``parent_wf`` calls ``middle_subwf`` (which itself calls ``leaf_subwf``)
    and ``leaf_subwf`` directly. The spec must list both sub-workflows, and
    the middle one must contain a single node referencing the leaf.
    """
    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def leaf_subwf(a: int = 42) -> (int, int):
        x = t1(a=a)
        u = t1(a=x)
        return x, u

    @workflow
    def middle_subwf() -> (int, int):
        s1, s2 = leaf_subwf(a=50)
        return s2, s2

    @workflow
    def parent_wf() -> (int, int, int, int):
        m1, m2 = middle_subwf()
        l1, l2 = leaf_subwf()
        return m1, m2, l1, l2

    leaf_name = "test_serialization.leaf_subwf"
    middle_name = "test_serialization.middle_subwf"

    wf_spec = get_serializable(
        OrderedDict(), serialization_settings, parent_wf
    )
    assert wf_spec is not None
    assert len(wf_spec.sub_workflows) == 2

    # Index the sub-workflow templates by name for the checks below.
    sub_wfs_by_name = {sub.id.name: sub for sub in wf_spec.sub_workflows}
    assert sub_wfs_by_name.keys() == {leaf_name, middle_name}

    middle_nodes = sub_wfs_by_name[middle_name].nodes
    assert len(middle_nodes) == 1
    assert middle_nodes[0].workflow_node is not None
    assert middle_nodes[0].workflow_node.sub_workflow_ref.name == leaf_name
Esempio n. 20
0
def test_workflow_values():
    """Check that ``@workflow`` options show up in the serialized workflow.

    The workflow is declared interruptible with the
    ``FAIL_AFTER_EXECUTABLE_NODES_COMPLETE`` failure policy; the serialized
    form must report ``interruptible`` as truthy and ``on_failure == 1``.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    policy = WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE

    @workflow(interruptible=True, failure_policy=policy)
    def wf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    image = Image(name="name", fqn="asdf/fdsa", tag="123")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(image),
        env={},
    )

    serialized_wf = get_serializable(OrderedDict(), settings, wf)
    assert serialized_wf.metadata_defaults.interruptible
    assert serialized_wf.metadata.on_failure == 1
Esempio n. 21
0
def test_wf_nested_comp():
    """Run and serialize a workflow that defines another workflow in its body.

    ``outer`` must return ``(8, 10)`` locally. Its serialized model must have
    two outputs and two nodes (the second a workflow node), plus one
    sub-workflow whose only node calls ``t1``.
    """
    @task
    def t1(a: int) -> int:
        a = a + 5
        return a

    @workflow
    def outer() -> (int, int):
        # You should not do this. This is just here for testing.
        @workflow
        def wf2() -> int:
            return t1(a=5)

        return t1(a=3), wf2()

    assert (8, 10) == outer()

    entity_mapping = OrderedDict()
    serialized_outer = get_serializable(
        entity_mapping, serialization_settings, outer
    )
    model_wf = serialized_outer.serialize()

    outer_template = model_wf.template
    assert len(outer_template.interface.outputs.variables) == 2
    assert len(outer_template.nodes) == 2
    assert outer_template.nodes[1].workflow_node is not None

    # The nested workflow appears as the first sub-workflow.
    nested_nodes = model_wf.sub_workflows[0].nodes
    assert len(nested_nodes) == 1
    only_node = nested_nodes[0]
    assert only_node.id == "wf2-n0"
    assert only_node.task_node.reference_id.name == "test_workflows.t1"
Esempio n. 22
0
def test_serialization():
    """Serialize a map task and check its type, version and container args.

    The map task wraps ``t1`` (defined outside this function). The expected
    argument list is compared for exact equality, order included.
    """
    maptask = map_task(t1, metadata=TaskMetadata(retries=1))

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "test_map_task",
        "task-name",
        "t1",
    ]

    task_spec = get_serializable(OrderedDict(), settings, maptask)
    task_template = task_spec.template

    assert task_template.type == "container_array"
    assert task_template.task_type_version == 1
    assert task_template.container.args == expected_args
Esempio n. 23
0
def test_condition_tuple_branches():
    """Run and serialize a workflow whose conditional yields two outputs.

    ``math_ops(a=3, b=2)`` must return ``(5, 1)``. Serialized with locally
    built settings, the workflow must have a single node whose first branch's
    ``then`` node references the ``sum_sub`` task.
    """
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )

    wf_spec = get_serializable(OrderedDict(), settings, math_ops)
    nodes = wf_spec.template.nodes
    assert len(nodes) == 1
    then_node = nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_images():
    """Check how ``container_image`` templates resolve to concrete image names.

    Five tasks cover: a named image with the default version, the default
    image spelled out as a template, no image at all, a fully literal image,
    and a literal name with a templated version. The image configuration is
    read from ``configs/images.config`` next to this file, together with the
    ``FLYTE_INTERNAL_IMAGE`` environment variable.
    """
    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    @task
    def t3():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    # Remember any pre-existing value so the override below does not leak
    # into tests that run after this one.
    previous_image = os.environ.get("FLYTE_INTERNAL_IMAGE")
    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    try:
        set_flyte_config_file(
            os.path.join(os.path.dirname(os.path.realpath(__file__)),
                         "configs/images.config"))
        rs = context_manager.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=get_image_config(),
        )
        t1_ser = get_serializable(rs, t1)
        assert t1_ser.container.image == "docker.io/xyz:version"
        # Only checks that conversion to the IDL representation does not raise.
        t1_ser.to_flyte_idl()

        t2_ser = get_serializable(rs, t2)
        assert t2_ser.container.image == "docker.io/default:version"

        # No container_image given: the default image is expected.
        t3_ser = get_serializable(rs, t3)
        assert t3_ser.container.image == "docker.io/default:version"

        t4_ser = get_serializable(rs, t4)
        assert t4_ser.container.image == "docker.io/org/myimage:latest"

        t5_ser = get_serializable(rs, t5)
        assert t5_ser.container.image == "docker.io/org/myimage:version"
    finally:
        # Put the environment back exactly as it was found.
        if previous_image is None:
            os.environ.pop("FLYTE_INTERNAL_IMAGE", None)
        else:
            os.environ["FLYTE_INTERNAL_IMAGE"] = previous_image
Esempio n. 25
0
def test_resources():
    """Check that task resource requests and limits reach the serialized container.

    ``t1`` declares a CPU request plus CPU and memory limits; ``t2`` declares
    only a CPU request, so its limits list must be empty. Serialization runs
    inside a context with a new compilation state.
    """
    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    ctx_manager = context_manager.FlyteContextManager
    compilation_ctx = ctx_manager.current_context().with_new_compilation_state()
    with ctx_manager.with_context(compilation_ctx):
        entry = _resource_models.ResourceEntry
        resource_name = _resource_models.ResourceName

        # t1: one request entry, two limit entries (CPU first, then memory).
        t1_spec = get_serializable(OrderedDict(), settings, t1)
        t1_resources = t1_spec.template.container.resources
        assert t1_resources.requests == [entry(resource_name.CPU, "1")]
        assert t1_resources.limits == [
            entry(resource_name.CPU, "2"),
            entry(resource_name.MEMORY, "400M"),
        ]

        # t2: a request only, hence no limits.
        t2_spec = get_serializable(OrderedDict(), settings, t2)
        t2_resources = t2_spec.template.container.resources
        assert t2_resources.requests == [entry(resource_name.CPU, "3")]
        assert t2_resources.limits == []
Esempio n. 26
0
def test_serialization_workflow_def():
    """Serialize two workflows that each wrap ``complex_task`` in a map task.

    ``w1`` uses a map task created outside the workflow, ``w2`` creates one
    inline with different metadata. Both share one entity mapping, which must
    end up with two ``MapPythonTask`` keys whose names contain "complex".
    """
    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return maptask(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    default_img = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_img, images=[default_img])
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )

    # One mapping is shared by both calls so entities accumulate in it.
    serialized_control_plane_entities = OrderedDict()

    wf1_spec = get_serializable(serialized_control_plane_entities, settings, w1)
    assert wf1_spec.template is not None
    assert len(wf1_spec.template.nodes) == 1

    wf2_spec = get_serializable(serialized_control_plane_entities, settings, w2)
    assert wf2_spec.template is not None
    assert len(wf2_spec.template.nodes) == 1

    # Keys of the mapping are the entities themselves.
    tasks_seen = [
        entity
        for entity in serialized_control_plane_entities
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]

    assert len(tasks_seen) == 2
    print(tasks_seen[0])
Esempio n. 27
0
def test_fast():
    """Serialize ``t1`` with ``True`` as the third argument and inspect its args.

    The serialized task's container args are expected to include the
    "pyflyte-fast-execute" entry.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    # Declared as a task but not referenced by the assertion below.
    @task
    def t2(a: str, b: str) -> str:
        return b + a

    # NOTE(review): the trailing True is presumably the "fast" registration
    # flag of get_serializable -- confirm against the flytekit version in use.
    serialized_task = get_serializable(serialization_settings, t1, True)
    container_args = serialized_task.container.args
    assert "pyflyte-fast-execute" in container_args
Esempio n. 28
0
def test_nested_condition_2():
    """Check a conditional nested inside another, serialized and run locally.

    The first half walks the serialized branch-node structure; the second half
    executes the workflow with inputs chosen to hit each branch.
    """

    @workflow
    def multiplier_2(my_input: float) -> float:
        return (
            conditional("fractions")
            .if_((my_input > 0.1) & (my_input < 1.0))
            .then(
                conditional("inner_fractions")
                .if_(my_input < 0.5)
                .then(double(n=my_input))
                .elif_((my_input > 0.5) & (my_input < 0.7))
                .then(square(n=my_input))
                .else_()
                .fail("Only <0.7 allowed")
            )
            .elif_((my_input > 1.0) & (my_input < 10.0))
            .then(square(n=my_input))
            .else_()
            .then(double(n=my_input))
        )

    wf_spec = get_serializable(OrderedDict(), serialization_settings, multiplier_2)
    top_level_nodes = wf_spec.template.nodes
    assert len(top_level_nodes) == 1

    # The single top-level node is the outer "fractions" branch.
    outer_node = top_level_nodes[0]
    assert isinstance(outer_node, Node)
    assert outer_node.id == "n0"
    assert outer_node.branch_node is not None

    outer_if_else = outer_node.branch_node.if_else
    assert outer_if_else is not None
    assert outer_if_else.case is not None
    assert outer_if_else.case.then_node is not None

    # The outer if-case leads into the nested "inner_fractions" branch.
    inner_node = outer_if_else.case.then_node
    assert inner_node.id == "n0"
    inner_if_else = inner_node.branch_node.if_else
    assert inner_if_else.case.then_node.task_node is not None
    assert inner_if_else.case.then_node.id == "n0"
    assert len(inner_if_else.other) == 1
    assert inner_if_else.other[0].then_node.id == "n1"

    # Ensure other cases exist
    assert len(outer_if_else.other) == 1
    outer_elif = outer_if_else.other[0]
    assert outer_elif.then_node.task_node is not None
    assert outer_elif.then_node.id == "n1"

    # 0.7 enters the outer if-branch but matches neither inner condition, so
    # the inner else_().fail(...) is what gets hit.
    with pytest.raises(ValueError):
        multiplier_2(my_input=0.7)

    # (input, expected): inner if-branch, outer elif-branch, outer else-branch.
    # double/square come from outside this test; the expected values imply
    # they double and square their input.
    for value, expected in ((0.3, 0.6), (5, 25), (10, 20)):
        res = multiplier_2(my_input=value)
        assert res == expected
Esempio n. 29
0
def test_ref_sub_wf():
    """Exercise a reference workflow entity when called and when used as a sub-workflow.

    Covers three situations: calling the entity before a compilation state is
    installed, calling it inside a fresh compilation state (without and with
    its inputs), and serializing a workflow that calls it, which is expected
    to raise.
    """
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    # Match against the raised exception's own message. Formatting the
    # ExceptionInfo wrapper yields its repr (pytest >= 5), which only contains
    # the message incidentally.
    assert "You must mock this out" in str(e.value)

    with context_manager.FlyteContextManager.with_context(
            ctx.with_new_compilation_state()) as ctx:
        # Inside a compilation state the call is checked for missing inputs.
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        # The entity declares no outputs, so a fully-specified call yields a
        # VoidPromise.
        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
Esempio n. 30
0
def test_ref():
    """Check that a reference task keeps its declared identifier when serialized.

    The task is serialized twice with different serialization settings; in
    both cases the result is expected to carry the identifier given to
    ``@reference_task`` rather than the project/domain/version of the settings.
    """

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    # Identifier fields as declared in the decorator above.
    expected_id = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }
    for field, value in expected_id.items():
        assert getattr(ref_t1.id, field) == value

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialized = get_serializable(first_settings, ref_t1)
    assert serialized.id == ref_t1.id
    assert serialized.interface.inputs["a"] is not None
    assert serialized.interface.outputs["o0"] is not None

    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_task = get_serializable(second_settings, ref_t1)
    assert sdk_task.has_registered
    # Different settings must not leak into the reference task's identifier.
    for field, value in expected_id.items():
        assert getattr(sdk_task.id, field) == value