Beispiel #1
0
def test_condition_tuple_branches():
    """Run and serialize a workflow whose conditional yields a pair of values.

    The ``else`` branch fails outright, so the only values that can flow out of
    the conditional are the two outputs of ``sum_sub``.
    """

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte exposes only `sum` and `sub` as outputs here, because they are
        # common between all branches.
        branch_outputs = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        total, difference = branch_outputs
        return total, difference

    # Local execution: 3 + 2 and 3 - 2.
    added, subtracted = math_ops(a=3, b=2)
    assert added == 5
    assert subtracted == 1

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(settings, math_ops)
    # The `then` side of the single branch node must point at the sum_sub task.
    then_node = serialized.nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
Beispiel #2
0
def test_workflow_values():
    """Check that workflow-level decorator options survive serialization.

    The workflow is declared interruptible with the
    FAIL_AFTER_EXECUTABLE_NODES_COMPLETE failure policy; both settings are
    then read back from the serialized workflow.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        bumped = a + 2
        return bumped, f"world-{bumped}"

    @workflow(
        interruptible=True,
        failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )
    def wf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    image = Image(name="name", fqn="asdf/fdsa", tag="123")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(image),
        env={},
    )

    serialized = get_serializable(settings, wf)
    assert serialized.metadata_defaults.interruptible
    assert serialized.metadata.on_failure == 1
def test_serialization_branch():
    """Exercise a simple if/else conditional locally and after serialization.

    Locally the workflow must pick "world" when the mimicked input equals 4 and
    "hello" otherwise; the serialized form must contain two nodes, the second
    being the branch node.
    """

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        mimicked = mimic(a=a)
        return (
            conditional("test1")
            .if_(mimicked.c == 4)
            .then(t1(c=mimicked.c).c)
            .else_()
            .then(t2().c)
        )

    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(settings, my_wf)
    assert serialized is not None
    assert len(serialized.nodes) == 2
    assert serialized.nodes[1].branch_node is not None
Beispiel #4
0
def test_pod_task():
    """Verify the custom payload produced for a task configured with a Pod.

    The task-level resources and environment are expected to show up on the
    primary container of the generated pod spec.
    """
    primary_name = "a container"
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name=primary_name)

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = simple_pod_task.get_custom(settings)

    pod_spec = custom["podSpec"]
    containers = pod_spec["containers"]
    assert pod_spec["restartPolicy"] == "OnFailure"
    assert len(containers) == 2

    expected_args = [
        "pyflyte-execute",
        "--task-module",
        "pod.test_pod",
        "--task-name",
        "simple_pod_task",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
    ]
    expected_mounts = [{"mountPath": "some/where", "name": "volume mount"}]
    expected_resources = {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }

    primary_container = containers[0]
    assert primary_container["name"] == "a container"
    assert primary_container["args"] == expected_args
    assert primary_container["volumeMounts"] == expected_mounts
    assert primary_container["resources"] == expected_resources
    # The task decorator's environment ("bar") is what lands on the container.
    assert primary_container["env"] == [{"name": "FOO", "value": "bar"}]

    assert containers[1]["name"] == "another container"
    assert custom["primaryContainerName"] == "a container"
def test_serialization_branch_complex():
    """Serialize a workflow that chains two conditionals after a task.

    The first conditional has if/elif/else-fail arms; its result feeds the
    second conditional. Three nodes are expected, the last two being branches.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        t1_int, t1_str = t1(a=a)
        first_choice = (
            conditional("test1")
            .if_(t1_int == 4)
            .then(t2(a=b))
            .elif_(t1_int >= 5)
            .then(t2(a=t1_str))
            .else_()
            .fail("Unable to choose branch")
        )
        second_choice = (
            conditional("test2")
            .if_(first_choice == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return t1_int, second_choice

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(settings, my_wf)
    assert serialized is not None
    assert len(serialized.nodes) == 3
    assert serialized.nodes[1].branch_node is not None
    assert serialized.nodes[2].branch_node is not None
def test_serialization_branch_sub_wf():
    """Serialize a conditional whose else-arm invokes a sub-workflow.

    The first serialized node must carry exactly one input, bound to ".a".
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        chosen = (
            conditional("test1")
            .if_(a > 3)
            .then(t1(a=a))
            .else_()
            .then(my_sub_wf(a=a))
        )
        return chosen

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(settings, my_wf)
    assert serialized is not None

    first_node = serialized.nodes[0]
    assert len(first_node.inputs) == 1
    assert first_node.inputs[0].var == ".a"
    assert first_node is not None
def test_query_no_inputs_or_outputs():
    """Serialize a HiveTask declared without inputs or an output schema.

    The serialized task interface must be empty on both sides, and a workflow
    that merely calls the task must serialize without raising.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    serialized_task = get_serializable(settings, hive_task)
    interface = serialized_task.interface
    assert len(interface.inputs) == 0
    assert len(interface.outputs) == 0

    # No assertion on the result: serialization simply has to succeed.
    get_serializable(settings, my_wf)
Beispiel #8
0
def test_pytorch_task():
    """Check local execution and the serialized config of a PyTorch task."""
    pytorch_config = PyTorch(
        num_workers=10,
        per_replica_requests=Resources(cpu="1"),
    )

    @task(task_config=pytorch_config, cache=True, cache_version="1")
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    # Calling the task locally just runs the Python function.
    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    assert my_pytorch_task.get_custom(settings) == {"workers": 10}

    resources = my_pytorch_task.resources
    assert resources.limits == Resources()
    assert resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
Beispiel #9
0
def test_dynamic_pod_task():
    """Check a @dynamic task configured with a Pod.

    Verifies the pod custom payload has two containers and that compiling the
    dynamic task with ``a=5`` yields five nodes.
    """
    dynamic_pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="a container",
    )

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(task_config=dynamic_pod)
    def dynamic_pod_task(a: int) -> List[int]:
        # One t1 invocation per index, in ascending order.
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    image = Image(name="default", fqn="test", tag="tag")
    pod_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = dynamic_pod_task.get_custom(pod_settings)
    assert len(custom["podSpec"]["containers"]) == 2

    compile_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    current = context_manager.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=compile_settings) as settings_ctx:
        with settings_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                exec_ctx, dynamic_pod_task._task_function, a=5
            )
            assert len(dynamic_job_spec._nodes) == 5
def _get_reg_settings():
    """Return SerializationSettings built around a single default image."""
    image = Image(name="default", fqn="test", tag="tag")
    return SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
def test_ref():
    """Check that a reference task keeps its declared identifier.

    The identifier is verified on the task object itself and again after
    serialization under two different sets of serialization settings.
    """
    ref_version = "553018f39e519bdb2597b652639c30ce16b99c79"

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version=ref_version,
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    ref_id = ref_t1.id
    assert ref_id.project == "flytesnacks"
    assert ref_id.domain == "development"
    assert ref_id.name == "recipes.aaa.simple.join_strings"
    assert ref_id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialized = get_serializable(first_settings, ref_t1)
    assert serialized.id == ref_t1.id
    assert serialized.interface.inputs["a"] is not None
    assert serialized.interface.outputs["o0"] is not None

    # Settings with a different project/domain/version must not leak into the
    # reference task's identifier.
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_task = get_serializable(second_settings, ref_t1)
    assert sdk_task.has_registered

    sdk_id = sdk_task.id
    assert sdk_id.project == "flytesnacks"
    assert sdk_id.domain == "development"
    assert sdk_id.name == "recipes.aaa.simple.join_strings"
    assert sdk_id.version == "553018f39e519bdb2597b652639c30ce16b99c79"
def test_serialization():
    """Serialize two raw-container tasks and a workflow that combines them.

    The workflow squares both inputs and adds the results, so three nodes are
    expected; each serialized task must keep its "alpine" image.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    # Named `add` locally to avoid hiding the builtin; the task name stays "sum".
    add = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out",
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        squared_first = square(val=val1)
        squared_second = square(val=val2)
        return add(x=squared_first, y=squared_second)

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized_wf = get_serializable(settings, raw_container_wf)
    assert serialized_wf is not None
    assert len(serialized_wf.nodes) == 3

    serialized_square = get_serializable(settings, square)
    assert serialized_square.container.image == "alpine"
    serialized_add = get_serializable(settings, add)
    assert serialized_add.container.image == "alpine"
def test_lp_with_output():
    """Use a reference launch plan that has outputs inside a workflow.

    The launch plan is patched for local execution, then the workflow is
    serialized and the launch plan reference on the second node is checked.
    """
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        greeting, number = t1()
        flag, count = ref_lp(a=greeting, b=number)
        return t2(q=flag, r=count)

    @patch(ref_lp)
    def inner_test(ref_mock):
        # The mock stands in for the remote launch plan during local execution.
        ref_mock.return_value = (False, 30)
        assert wf1() == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialized = get_serializable(settings, wf1)
    lp_node = serialized.nodes[1].workflow_node
    assert lp_node.launchplan_ref.project == "proj"
    assert lp_node.launchplan_ref.name == "app.other.flyte_entity"
def test_lps(resource_type):
    """Exercise a reference entity of the given resource type.

    Covers three things: calling it outside a workflow raises, calling it in a
    compilation context validates inputs and returns a VoidPromise, and a
    workflow using it serializes to a single node that points at the entity.

    NOTE(review): ``resource_type`` is presumably supplied by a pytest
    parametrization or fixture that is not visible in this block -- confirm.
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in str(e)

    with ctx.new_compilation_context():
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialized = get_serializable(settings, wf1)
    assert len(serialized.interface.inputs) == 2
    assert len(serialized.interface.outputs) == 0
    assert len(serialized.nodes) == 1

    # Pick the identifier that matches the kind of entity being referenced.
    node = serialized.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        reference = node.workflow_node.launchplan_ref
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        reference = node.workflow_node.sub_workflow_ref
    else:
        reference = node.task_node.reference_id
    assert reference.project == "proj"
    assert reference.name == "app.other.flyte_entity"
def test_environment():
    """Check how task-level and settings-level environment variables combine.

    The expected container env holds keys from both sources, with the task's
    value for the key ("FOO") that both define.
    """

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        return f"now it's {a + 2}"

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )
    expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}

    current = context_manager.FlyteContext.current_context()
    with current.new_compilation_context():
        serialized_task = get_serializable(settings, t1)
        assert serialized_task.container.env == expected_env
def test_serialization():
    """Serialize a HiveTask with inputs and a schema output, plus its workflow.

    Checks the query text carried in the task's custom payload, the size of the
    task interface, and how the workflow output is bound to the task node.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    serialized_task = get_serializable(settings, hive_task)
    query_text = serialized_task.custom["query"]["query"]
    assert "{{ .rawOutputDataPrefix" in query_text
    assert "insert overwrite directory" in query_text
    assert len(serialized_task.interface.inputs) == 2
    assert len(serialized_task.interface.outputs) == 1

    serialized_wf = get_serializable(settings, my_wf)
    assert serialized_wf.interface.outputs["o0"].type.schema is not None

    # The workflow's single output is wired to the "results" var of node n0.
    wf_output = serialized_wf.outputs[0]
    assert wf_output.var == "o0"
    assert wf_output.binding.promise.node_id == "n0"
    assert wf_output.binding.promise.var == "results"
def test_lp_serialize():
    """Serialize launch plans created with and without default inputs."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        bumped = a + 2
        return bumped, f"world-{bumped}"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create(
        "serialize_test2",
        my_subwf,
        default_inputs={"a": 3},
    )

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    serialized_plain = get_serializable(settings, plain_lp)
    assert len(serialized_plain.default_inputs.parameters) == 0
    assert len(serialized_plain.fixed_inputs.literals) == 0

    serialized_defaults = get_serializable(settings, lp_with_defaults)
    assert len(serialized_defaults.default_inputs.parameters) == 1
    assert len(serialized_defaults.fixed_inputs.literals) == 0

    # Make sure the oneof is respected. Tricky with booleans: if a default is
    # specified, the required field needs to be None, not False. Round-trip the
    # parameter through its IDL form and confirm the default is still there.
    round_tripped = Parameter.from_flyte_idl(
        serialized_defaults.default_inputs.parameters["a"].to_flyte_idl()
    )
    assert round_tripped.default is not None
Beispiel #18
0
def test_spark_task():
    """Check that a Spark task keeps its spark_conf on the object and in its custom payload."""
    spark_conf = {"spark": "1"}

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == spark_conf

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    custom = my_spark.get_custom(settings)
    assert custom["sparkConf"] == spark_conf
def test_wf1_with_dynamic():
    """Run a workflow containing a @dynamic task, then compile that task.

    Local execution must return the concatenated string and one "world-N"
    entry per iteration; compiling the dynamic task with ``a=5`` must produce
    five nodes.
    """

    @task
    def t1(a: int) -> str:
        return f"world-{a + 2}"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        joined = t2(a=b, b=b)
        worlds = my_subwf(a=a)
        return joined, worlds

    count = 5
    expected_worlds = [f"world-{i}" for i in range(2, count + 2)]
    assert my_wf(a=count, b="hello ") == ("hello hello ", expected_worlds)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    current = context_manager.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=settings) as settings_ctx:
        with settings_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = my_subwf.compile_into_workflow(exec_ctx, my_subwf._task_function, a=5)
            assert len(dynamic_job_spec._nodes) == 5
def test_resources():
    """Check that task resource requests and limits reach the serialized container."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        return f"now it's {a + 2}"

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        return f"now it's {a + 200}"

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    entry = _resource_models.ResourceEntry
    names = _resource_models.ResourceName

    current = context_manager.FlyteContext.current_context()
    with current.new_compilation_context():
        first = get_serializable(settings, t1).container.resources
        assert first.requests == [entry(names.CPU, "1")]
        assert first.limits == [
            entry(names.CPU, "2"),
            entry(names.MEMORY, "400M"),
        ]

        # t2 declares only requests, so its limits list must be empty.
        second = get_serializable(settings, t2).container.resources
        assert second.requests == [entry(names.CPU, "3")]
        assert second.limits == []
def test_dynamic_conditional():
    """Compile a recursive merge sort built from a conditional and a dynamic task.

    ``merge_sort`` picks a strategy from ``count``: two local sorts for small
    inputs and ``merge_sort_remotely`` otherwise, which splits the list and
    calls ``merge_sort`` again on each half. Compiling the dynamic task for a
    four-element list is expected to register five tasks.
    """

    @task
    def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int):
        # Integer midpoint: keeps the declared `int` output an int, and the two
        # slices together cover every element (the upper half starts at `mid`,
        # not `mid + 1`, so nothing is dropped).
        mid = len(in1) // 2
        return in1[:mid], in1[mid:], mid

    # One sample implementation for merging. In a more real world example, this might merge file streams and only load
    # chunks into the memory.
    @task
    def merge(x: typing.List[int], y: typing.List[int]) -> typing.List[int]:
        n1 = len(x)
        n2 = len(y)
        # Plain list literal: `list[int]()` only works on Python 3.9+.
        result = []
        i = 0
        j = 0

        # Walk both lists, always taking the smaller head; ties take from `y`.
        while i < n1 and j < n2:
            if x[i] < y[j]:
                result.append(x[i])
                i = i + 1
            else:
                result.append(y[j])
                j = j + 1

        # At most one of the lists still has elements; append the leftovers.
        result.extend(x[i:])
        result.extend(y[j:])

        return result

    # This runs the sorting completely locally. It's faster and more efficient to do so if the entire list fits in memory.
    @task
    def merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @task
    def also_merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @dynamic
    def merge_sort_remotely(in1: typing.List[int]) -> typing.List[int]:
        x, y, new_count = split(in1=in1)
        sorted_x = merge_sort(in1=x, count=new_count)
        sorted_y = merge_sort(in1=y, count=new_count)
        return merge(x=sorted_x, y=sorted_y)

    @workflow
    def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
        return (
            conditional("terminal_case")
            .if_(count < 500)
            .then(merge_sort_locally(in1=in1))
            .elif_(count < 1000)
            .then(also_merge_sort_locally(in1=in1))
            .else_()
            .then(merge_sort_remotely(in1=in1))
        )

    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = merge_sort_remotely.compile_into_workflow(
                ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
            )
            assert len(dynamic_job_spec.tasks) == 5
Beispiel #22
0
import typing

from flytekit import ContainerTask
from flytekit.annotated import context_manager
from flytekit.annotated.base_task import kwtypes
from flytekit.annotated.context_manager import Image, ImageConfig
from flytekit.annotated.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.annotated.task import ReferenceTask, task
from flytekit.annotated.workflow import ReferenceWorkflow, workflow
from flytekit.common.translator import get_serializable

# Module-level fixtures. A single image serves both as the default image and as
# the only entry in the image list of the image config.
default_img = Image(name="default", fqn="test", tag="tag")
# Serialization settings handed to get_serializable() by the test below.
serialization_settings = context_manager.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_references():
    """Serialize a reference launch plan and check it is marked as registered."""
    reference_lp = ReferenceLaunchPlan(
        "media",
        "stg",
        "some.name",
        "cafe",
        inputs=kwtypes(in1=str),
        outputs=kwtypes(),
    )
    serialized = get_serializable(serialization_settings, reference_lp)
    assert serialized.has_registered
Beispiel #23
0
def test_normal_task():
    """Exercise create_node() with a regular task and with output-less tasks.

    Also checks that an explicit ordering, declared either with
    ``runs_before`` or with ``>>``, shows up as an upstream node id in the
    serialized workflow.
    """

    def make_settings():
        # Built fresh for each serialization pass, matching how the settings
        # are used below.
        return context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )

    @task
    def t1(a: str) -> str:
        return a + " world"

    @workflow
    def my_wf(a: str) -> str:
        t1_node = create_node(t1, a=a)
        return t1_node.o0

    assert my_wf(a="hello") == "hello world"

    serialized = get_serializable(make_settings(), my_wf)
    assert len(serialized.nodes) == 1
    assert len(serialized.outputs) == 1

    @task
    def t2():
        ...

    @task
    def t3():
        ...

    @workflow
    def empty_wf():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node.runs_before(t2_node)

    # Test that VoidPromises can handle runs_before
    empty_wf()

    @workflow
    def empty_wf2():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node >> t2_node

    settings = make_settings()
    # Both spellings of the ordering must serialize the same way: node n0 lists
    # n1 as its upstream dependency.
    for ordered_wf in (empty_wf, empty_wf2):
        serialized = get_serializable(settings, ordered_wf)
        first_node = serialized.nodes[0]
        assert first_node.upstream_node_ids[0] == "n1"
        assert first_node.id == "n0"