Пример #1
0
def test_ref():
    """A reference task keeps its declared identifier and serializes to None.

    Serialization is attempted under two unrelated settings; in both cases
    ``get_serializable`` is expected to return None for the reference task.
    """
    expected_id = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }

    @reference_task(**expected_id)
    def ref_t1(a: typing.List[str]) -> str:
        ...

    for field_name, value in expected_id.items():
        assert getattr(ref_t1.id, field_name) == value

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    assert get_serializable(OrderedDict(), first_settings, ref_t1) is None

    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    assert get_serializable(OrderedDict(), second_settings, ref_t1) is None
Пример #2
0
def test_condition_tuple_branches():
    """Serialize a workflow whose conditional yields a two-output tuple.

    Locally, (3, 2) produces (5, 1). The serialized workflow's first node is a branch
    whose ``then`` node references the task ``test_conditions.sum_sub``.
    """
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte will only make `sum` and `sub` available as outputs because they are common between all branches.
        # The local is called `add` rather than `sum` so the builtin is not shadowed.
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )

        return add, sub

    x, y = math_ops(a=3, b=2)
    assert x == 5
    assert y == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    sdk_wf = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert sdk_wf.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id.name == "test_conditions.sum_sub"
Пример #3
0
def test_serialization_branch_sub_wf():
    """Serialize a workflow whose conditional branches to either a task or a sub-workflow.

    Checks that the first serialized node exists and binds exactly one input, whose
    ``var`` is ``.a``.
    """
    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf is not None
    # Check the node exists before dereferencing its inputs.
    assert wf.nodes[0] is not None
    assert len(wf.nodes[0].inputs) == 1
    assert wf.nodes[0].inputs[0].var == ".a"
Пример #4
0
def validate_image(ctx: typing.Any, param: str, values: tuple) -> ImageConfig:
    """
    Validate the supplied image values and collect them into an ``ImageConfig``.

    Each value is either ``name=img`` or a bare ``img``. A bare ``img`` is looked up under
    ``_DEFAULT_IMAGE_NAME``. Any image whose resolved name equals ``_DEFAULT_IMAGE_NAME``
    (including one given explicitly as ``<default name>=img``) becomes the default image,
    and at most one default image may be given. All other images are appended to
    ``images`` in input order.

    :param ctx: Callback context; not used by this function.
    :param param: The parameter being validated; only interpolated into the error message.
    :param values: The raw image values.
    :return: ``ImageConfig(default_image, images)``; ``default_image`` is None when no
        default image was supplied.
    :raises click.BadParameter: If more than one value resolves to the default image name.
    """
    default_image = None
    images = []
    for v in values:
        if "=" in v:
            # Split on the first "=" only; everything after it is handed on as the image reference.
            splits = v.split("=", maxsplit=1)
            img = look_up_image_info(name=splits[0],
                                     tag=splits[1],
                                     optional_tag=False)
        else:
            img = look_up_image_info(_DEFAULT_IMAGE_NAME, v, False)

        if img.name == _DEFAULT_IMAGE_NAME:
            # Explicit None check: a previously seen default must be rejected regardless of the
            # Image object's truthiness.
            if default_image is not None:
                raise click.BadParameter(
                    f"Only one default image can be specified. Received multiple {default_image} & {img} for {param}"
                )
            default_image = img
        else:
            images.append(img)

    return ImageConfig(default_image, images)
Пример #5
0
def test_serialization():
    """Serialize a map task over module-level ``t1`` and check its container template.

    The template must be a ``container_array`` task (type version 1) whose container
    runs ``pyflyte-map-execute`` with the default task resolver pointed at ``t1``.
    """
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))
    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    template = get_serializable(OrderedDict(), settings, mapped).template

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "test_map_task",
        "task-name",
        "t1",
    ]
    assert template.type == "container_array"
    assert template.task_type_version == 1
    assert template.container.args == expected_args
Пример #6
0
def test_condition_tuple_branches():
    """Serialize a workflow whose conditional yields two outputs from ``sum_sub``.

    Locally, (3, 2) produces (5, 1). The serialized workflow template has a single node
    whose ``then`` branch references the task ``test_conditions.sum_sub``.
    """
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )

    nodes = get_serializable(OrderedDict(), settings, math_ops).template.nodes
    assert len(nodes) == 1
    then_node = nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
Пример #7
0
def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> str:
    """
    Resolve an image specification into a concrete container image reference.

    ``img`` may contain placeholders matched by ``_IMAGE_REPLACE_REGEX``, for example
    ``{{.image.<name>.fqn}}``, ``{{.image.<name>.version}}`` or ``{{.image.<name>}}``.
    Each placeholder is replaced using the image that ``cfg.find_image(<name>)`` returns.

    :param img: Configured image: None/empty, a literal image reference, or a template.
    :param cfg: Registration configuration used to look up named images.
    :return: ``<fqn>:<tag>`` of ``cfg.default_image`` when ``img`` is None or empty.
        ``img`` unchanged when it contains no placeholders. Otherwise ``img`` with every
        placeholder substituted.
    :raises AssertionError: If a placeholder is malformed, names an image missing from
        ``cfg``, or requests an unsupported attribute.
    """
    if img is None or img == "":
        return f"{cfg.default_image.fqn}:{cfg.default_image.tag}"

    # findall returns a (possibly empty) list, never None.
    matches = _IMAGE_REPLACE_REGEX.findall(img)
    if not matches:
        return img
    for m in matches:
        # Each match is expected to unpack as (placeholder text, image name, attribute).
        if len(m) < 3:
            # The placeholder examples are kept out of the f-string so the literal "{{ }}"
            # syntax is printed consistently throughout the message.
            raise AssertionError(
                "Image specification should be of the form <fqn>:<tag> OR <fqn>:{{.image.default.version}} OR "
                "{{.image.xyz.fqn}}:{{.image.xyz.version}} OR {{.image.xyz}} - "
                f"Received {m}"
            )
        replace_group, name, attr = m
        if name is None or name == "":
            raise AssertionError(f"Image format is incorrect {m}")
        img_cfg = cfg.find_image(name)
        if img_cfg is None:
            raise AssertionError(f"Image Config with name {name} not found in the configuration")
        if attr == "version":
            # Fall back to the default image's tag when the named image has none.
            tag = img_cfg.tag if img_cfg.tag is not None else cfg.default_image.tag
            img = img.replace(replace_group, tag)
        elif attr == "fqn":
            img = img.replace(replace_group, img_cfg.fqn)
        elif attr == "":
            # No attribute (e.g. a bare {{.image.<name>}}) expands to the image's full reference.
            img = img.replace(replace_group, img_cfg.full)
        else:
            raise AssertionError(f"Only fqn and version are supported replacements, {attr} is not supported")
    return img
Пример #8
0
def test_serialization_branch_complex_2():
    """Serialize a workflow with two chained conditionals and check a cross-node binding.

    The second serialized node (index 1) must bind its first input to the
    ``t1_int_output`` output of node ``n0``.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = (
            conditional("test2")
            .if_(d == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return x, f

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert spec is not None
    first_input = spec.template.nodes[1].inputs[0]
    assert first_input.var == "n0.t1_int_output"
Пример #9
0
def test_workflow_values():
    """Workflow-level ``interruptible`` and ``failure_policy`` reach the serialized workflow.

    The serialized metadata defaults must be interruptible, and ``metadata.on_failure``
    must equal 1 (presumably the enum value of FAIL_AFTER_EXECUTABLE_NODES_COMPLETE — confirm).
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
    def wf(a: int) -> (str, str):
        x, y = t1(a=a)
        _, v = t1(a=x)
        return y, v

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    serialized = get_serializable(OrderedDict(), settings, wf)
    assert serialized.metadata_defaults.interruptible
    assert serialized.metadata.on_failure == 1
Пример #10
0
def test_dont_convert_remotes():
    """A remote FlyteFile passed into a dynamic task keeps its remote URI.

    The file literal is built under task-execution mode, the dynamic task is dispatched,
    and the first node of the resulting workflow must still bind ``s3://anything``.
    """
    @task
    def t1(in1: FlyteFile):
        print(in1)

    @dynamic
    def dyn(in1: FlyteFile):
        t1(in1=in1)

    remote_file = FlyteFile("s3://anything")
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    current = context_manager.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=settings) as serialization_ctx:
        with serialization_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            single_blob = BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE)
            file_literal = TypeEngine.to_literal(exec_ctx, remote_file, FlyteFile, single_blob)
            dynamic_wf = dyn.dispatch_execute(exec_ctx, LiteralMap(literals={"in1": file_literal}))
            first_binding = dynamic_wf.nodes[0].inputs[0].binding
            assert first_binding.scalar.blob.uri == "s3://anything"
Пример #11
0
def test_container_image_conversion():
    """Exercise get_registerable_container_image on literal images and templated placeholders."""
    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])

    resolved_cases = [
        (None, "xyz.com/abc:tag1"),
        ("", "xyz.com/abc:tag1"),
        ("abc", "abc"),
        ("abc:latest", "abc:latest"),
        ("abc:{{.image.default.version}}", "abc:tag1"),
        ("{{.image.default.fqn}}:{{.image.default.version}}", "xyz.com/abc:tag1"),
        ("{{.image.other.fqn}}:{{.image.other.version}}", "xyz.com/other:tag-other"),
        ("{{.image.other.fqn}}:{{.image.default.version}}", "xyz.com/other:tag1"),
        ("{{.image.other.fqn}}", "xyz.com/other"),
        # Works with images instead of just image
        ("{{.images.other.fqn}}", "xyz.com/other"),
    ]
    for spec, expected in resolved_cases:
        assert get_registerable_container_image(spec, cfg) == expected

    rejected_specs = (
        "{{.image.blah.fqn}}:{{.image.other.version}}",
        "{{.image.fqn}}:{{.image.other.version}}",
        "{{.image.blah}}",
    )
    for spec in rejected_specs:
        with pytest.raises(AssertionError):
            get_registerable_container_image(spec, cfg)

    assert get_registerable_container_image("{{.image.default}}", cfg) == "xyz.com/abc:tag1"
Пример #12
0
def test_environment():
    """Task-level ``environment`` is merged with the serialization-settings env.

    The settings provide FOO=foo and BAR=bar, and the task provides FOO=foofoo and BAZ=baz.
    The serialized container env keeps BAR and takes FOO and BAZ from the task.
    """
    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )
    current = context_manager.FlyteContext.current_context()
    with current.new_compilation_context():
        serialized = get_serializable(OrderedDict(), settings, t1)
        expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
        assert serialized.container.env == expected_env
Пример #13
0
def test_serialization_branch():
    """Serialize a workflow that returns the value of a conditional.

    Locally, a=4 takes the ``then`` branch ("world") and a=2 takes the ``else`` branch
    ("hello"). The serialized template has two nodes, the second of which is a branch node.
    """
    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return (
            conditional("test1")
            .if_(c.c == 4)
            .then(t1(c=c.c).c)
            .else_()
            .then(t2().c)
        )

    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert spec is not None
    nodes = spec.template.nodes
    assert len(nodes) == 2
    assert nodes[1].branch_node is not None
Пример #14
0
def test_ref():
    """A reference task serializes under any settings while keeping its own identifier.

    With the first settings, the serialized entity carries the reference's id and an
    interface with input ``a`` and output ``o0``. With the second settings, it reports
    ``has_registered`` and still uses the declared project/domain/name/version.
    """
    expected_id = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }

    @reference_task(**expected_id)
    def ref_t1(a: typing.List[str]) -> str:
        ...

    for field_name, value in expected_id.items():
        assert getattr(ref_t1.id, field_name) == value

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    first_entity = get_serializable(first_settings, ref_t1)
    assert first_entity.id == ref_t1.id
    assert first_entity.interface.inputs["a"] is not None
    assert first_entity.interface.outputs["o0"] is not None

    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    second_entity = get_serializable(second_settings, ref_t1)
    assert second_entity.has_registered
    for field_name, value in expected_id.items():
        assert getattr(second_entity.id, field_name) == value
Пример #15
0
def test_wf1_with_fast_dynamic():
    """Compile a dynamic task under fast-serialization settings.

    With fast serialization enabled and the fast-distribution values in the execution
    state's additional context, compiling the dynamic task with a=5 yields five nodes and
    one task. That task's container args start with the ``pyflyte-fast-execute`` prefix.
    Afterwards, ``FlyteContextManager.size()`` is back to 1.
    """
    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    manager = context_manager.FlyteContextManager
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    fast_context = {
        "dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
        "dynamic_dest_dir": "/User/flyte/workflows",
    }

    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as outer_ctx:
        task_exec_state = outer_ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION,
            additional_context=fast_context,
        )
        with manager.with_context(outer_ctx.with_execution_state(task_exec_state)) as inner_ctx:
            job_spec = my_subwf.compile_into_workflow(inner_ctx, my_subwf._task_function, a=5)
            assert len(job_spec._nodes) == 5
            assert len(job_spec.tasks) == 1
            command_line = " ".join(job_spec.tasks[0].container.args)
            assert command_line.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    assert manager.size() == 1
Пример #16
0
def test_lps(resource_type):
    """Reference entities can be called and serialized inside a workflow.

    ``resource_type`` is supplied externally (presumably a pytest parametrization — not
    visible here).

    - Calling the reference outside compilation must fail with a "You must mock this out" error.
    - Inside a compilation context it requires its inputs, and returns a VoidPromise when
      called with them (no outputs are declared).
    - The serialized workflow has one node that points at ``proj``/``app.other.flyte_entity``
      through the field matching ``resource_type``.
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    # Check the raised exception's message, not the string form of pytest's ExceptionInfo wrapper.
    assert "You must mock this out" in str(e.value)

    with ctx.new_compilation_context():
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert sdk_wf.nodes[0].workflow_node.launchplan_ref.project == "proj"
        assert sdk_wf.nodes[
            0].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        assert sdk_wf.nodes[0].workflow_node.sub_workflow_ref.project == "proj"
        assert sdk_wf.nodes[
            0].workflow_node.sub_workflow_ref.name == "app.other.flyte_entity"
    else:
        assert sdk_wf.nodes[0].task_node.reference_id.project == "proj"
        assert sdk_wf.nodes[
            0].task_node.reference_id.name == "app.other.flyte_entity"
Пример #17
0
def test_py_func_task_get_container():
    """PythonFunctionTask.get_container picks the default image and combines env values.

    The container image is the config's default image (``xyz.com/abc:tag1``). Its env
    holds both the settings env (FOO) and the task's own environment (BAZ).
    """
    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    image_cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])

    settings = SerializationSettings(
        project="p",
        domain="d",
        version="v",
        image_config=image_cfg,
        env={"FOO": "bar"},
    )

    container = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"}).get_container(settings)
    assert container.image == "xyz.com/abc:tag1"
    assert container.env == {"FOO": "bar", "BAZ": "baz"}
Пример #18
0
def test_serialization():
    """Serialize a workflow built from two raw ContainerTasks (``square`` and ``sum``).

    The serialized workflow template has three nodes (two ``square`` calls and one
    ``sum``), and each task's serialized container image is ``alpine``.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    # Named `sum_task` so the builtin `sum` is not shadowed; the task itself is still
    # named "sum" through the `name=` argument.
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings,
                               raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
Пример #19
0
def test_lp_serialize():
    """Serialize launch plans with and without default inputs.

    Without defaults, input ``a`` is a required parameter. With ``default_inputs={"a": 3}``,
    ``a`` is optional and carries an integer literal default of 3. Neither plan has fixed
    inputs. The default must also survive a round trip through the Flyte IDL form.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        _, v = t1(a=x)
        return y, v

    lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    plain = get_serializable(OrderedDict(), settings, lp)
    assert len(plain.default_inputs.parameters) == 1
    assert plain.default_inputs.parameters["a"].required
    assert len(plain.fixed_inputs.literals) == 0

    defaulted = get_serializable(OrderedDict(), settings, lp_with_defaults)
    params = defaulted.default_inputs.parameters
    assert len(params) == 1
    assert not params["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert params["a"].default == expected_default
    assert len(defaulted.fixed_inputs.literals) == 0

    # The oneof must be respected, which is tricky with booleans: once a default is set,
    # `required` has to be None rather than False. Round-trip through IDL to check this.
    round_tripped = Parameter.from_flyte_idl(params["a"].to_flyte_idl())
    assert round_tripped.default is not None
Пример #20
0
def test_wf1_with_dynamic():
    """Run a workflow containing a dynamic task locally, then compile the dynamic task alone.

    Locally, my_wf(a=5, b="hello ") yields "hello hello " and ["world-2", ..., "world-6"].
    Compiled in task-execution mode with a=5, the dynamic task expands into five nodes
    backed by a single task. Afterwards, ``FlyteContextManager.size()`` is back to 1.
    """
    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    count = 5
    result = my_wf(a=count, b="hello ")
    assert result == ("hello hello ", [f"world-{i}" for i in range(2, count + 2)])

    manager = context_manager.FlyteContextManager
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as outer_ctx:
        task_exec_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(outer_ctx.with_execution_state(task_exec_state)) as inner_ctx:
            job_spec = my_subwf.compile_into_workflow(inner_ctx, my_subwf._task_function, a=5)
            assert len(job_spec._nodes) == 5
            assert len(job_spec.tasks) == 1

    assert manager.size() == 1
Пример #21
0
def test_lp_with_output():
    """A reference launch plan with outputs can be mocked locally and serialized.

    With the reference mocked to return (False, 30), wf1 returns "q: False r: 30", and the
    mock is called with t1's outputs. In the serialized workflow, the second node
    (index 1) references launch plan ``proj`` / ``app.other.flyte_entity``.
    """
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(ref_lp)
    def check_with_mocked_lp(ref_mock):
        ref_mock.return_value = (False, 30)
        assert wf1() == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    check_with_mocked_lp()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, wf1)
    lp_ref = spec.template.nodes[1].workflow_node.launchplan_ref
    assert lp_ref.project == "proj"
    assert lp_ref.name == "app.other.flyte_entity"
Пример #22
0
def test_resources():
    """Task ``requests``/``limits`` are serialized as container resource entries.

    t1 requests cpu=1 and is limited to cpu=2, mem=400M. t2 requests cpu=3 and sets no
    limits, so its limits list is empty.
    """
    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    entry = _resource_models.ResourceEntry
    cpu = _resource_models.ResourceName.CPU
    memory = _resource_models.ResourceName.MEMORY

    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_new_compilation_state()):
        t1_resources = get_serializable(OrderedDict(), settings, t1).template.container.resources
        assert t1_resources.requests == [entry(cpu, "1")]
        assert t1_resources.limits == [entry(cpu, "2"), entry(memory, "400M")]

        t2_resources = get_serializable(OrderedDict(), settings, t2).template.container.resources
        assert t2_resources.requests == [entry(cpu, "3")]
        assert t2_resources.limits == []
Пример #23
0
def test_serialization_workflow_def():
    """Map tasks over the same function but with different metadata are serialized separately.

    w1 uses a map task created outside the workflow (retries=1), and w2 creates one inline
    (retries=2). After both workflows are serialized into one shared OrderedDict, it must
    contain two MapPythonTask keys whose name contains "complex".
    """
    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return maptask(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    # Shared across both calls so entities from w1 and w2 accumulate in one mapping.
    serialized_control_plane_entities = OrderedDict()
    wf1_spec = get_serializable(serialized_control_plane_entities,
                                serialization_settings, w1)
    assert wf1_spec.template is not None
    assert len(wf1_spec.template.nodes) == 1

    wf2_spec = get_serializable(serialized_control_plane_entities,
                                serialization_settings, w2)
    assert wf2_spec.template is not None
    assert len(wf2_spec.template.nodes) == 1

    tasks_seen = [
        entity
        for entity in serialized_control_plane_entities
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]
    assert len(tasks_seen) == 2
Пример #24
0
def test_sql_command():
    """Serialize the module-level ``not_tk`` task and check the resolver tail of its args.

    The last five container args must select the customized-container task-template
    resolver and the SQLite3TaskExecutor.
    """
    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    container_args = get_serializable(OrderedDict(), settings, not_tk).template.container.args
    expected_tail = [
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekit.extras.sqlite3.task.SQLite3TaskExecutor",
    ]
    assert container_args[-len(expected_tail):] == expected_tail
Пример #25
0
def test_nested_dynamic():
    """Run a workflow that contains a nested ``@dynamic`` locally, then compile it.

    Local execution must produce the concatenated greeting plus one ``t1``
    result per index. Compiling the dynamic task under task-execution mode
    with ``a=5`` must yield five nodes.
    """

    @task
    def t1(a: int) -> str:
        return "world-" + str(a + 2)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        @dynamic
        def my_subwf(a: int) -> typing.List[str]:
            # One t1 call per index in range(a).
            return [t1(a=i) for i in range(a)]

        greeting = t2(a=b, b=b)
        greetings = my_subwf(a=a)
        return greeting, greetings

    count = 5
    result = my_wf(a=count, b="hello ")
    expected = ("hello hello ", [f"world-{i}" for i in range(2, count + 2)])
    assert result == expected

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    nested_my_subwf = my_wf.get_all_tasks()[0]

    base_ctx = context_manager.FlyteContext.current_context()
    with base_ctx.new_serialization_settings(serialization_settings=settings) as ser_ctx:
        with ser_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = nested_my_subwf.compile_into_workflow(
                exec_ctx, False, nested_my_subwf._task_function, a=5
            )
            assert len(dynamic_job_spec._nodes) == 5
Пример #26
0
def test_sql_command():
    """Serialize the module-level ``not_tk`` task and check its container args.

    The trailing arguments must select the default task resolver and
    identify the task by module and name.
    """
    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    serialized_task = get_serializable(OrderedDict(), settings, not_tk)

    expected_tail = [
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.sqlalchemy.test_task",
        "task-name",
        "tk",
    ]
    container_args = serialized_task.template.container.args
    assert container_args[-len(expected_tail):] == expected_tail
Пример #27
0
def test_ref_dynamic():
    """Compiling a dynamic task that calls a reference task is expected to raise."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="sample.reference.task",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: int) -> str:
        ...

    # Declared but not invoked by my_subwf below.
    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [ref_t1(a=i) for i in range(a)]

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    manager = context_manager.FlyteContextManager
    serialization_ctx = manager.current_context().with_serialization_settings(settings)

    with manager.with_context(serialization_ctx) as ctx:
        # Switch into task-execution mode, as the dynamic task would be at run time.
        task_exec_state = ctx.execution_state.with_params(
            mode=context_manager.ExecutionState.Mode.TASK_EXECUTION
        )
        with manager.with_context(ctx.with_execution_state(task_exec_state)) as exec_ctx:
            with pytest.raises(Exception):
                my_subwf.compile_into_workflow(
                    exec_ctx, False, my_subwf._task_function, a=5
                )
Пример #28
0
def test_ref_sub_wf():
    """Reference workflows must be mocked locally and cannot be serialized.

    Checks three things:
      * calling the reference entity outside compilation raises,
      * inside a compilation state, missing inputs raise and supplied inputs
        yield a ``VoidPromise`` (the entity declares ``outputs={}``),
      * serializing a workflow that uses it as a sub-workflow raises.
    """
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    # Check the exception's own message. Formatting the ExceptionInfo wrapper
    # goes through a size-limited, escaped repr that can hide the text.
    assert "You must mock this out" in str(e.value)

    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()):
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
Пример #29
0
def test_resource_overrides():
    """Resource overrides set on a map-task node must appear in the serialized task node."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mapped = map_task(t1)
        requests = Resources(cpu="1", mem="100")
        limits = Resources(cpu="2", mem="200")
        node = create_node(mapped, a=a).with_overrides(requests=requests, limits=limits)
        return node.o0

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)
    assert len(wf_spec.template.nodes) == 1

    overrides = wf_spec.template.nodes[0].task_node.overrides
    assert overrides is not None

    res = _resources_models
    assert overrides.resources.requests == [
        res.ResourceEntry(res.ResourceName.CPU, "1"),
        res.ResourceEntry(res.ResourceName.MEMORY, "100"),
    ]
    assert overrides.resources.limits == [
        res.ResourceEntry(res.ResourceName.CPU, "2"),
        res.ResourceEntry(res.ResourceName.MEMORY, "200"),
    ]
Пример #30
0
def test_serialization_branch_compound_conditions():
    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_wf(a: int) -> int:
        d = (conditional("test1").if_((a == 4) | (a == 3)).then(t1(a=a)).elif_(
            a < 6).then(t1(a=a)).else_().fail("Unable to choose branch"))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec is not None
    assert len(wf_spec.template.nodes[0].inputs) == 1
    assert wf_spec.template.nodes[0].inputs[0].var == ".a"