Example #1
def test_wf_resolving():

    x = my_wf(a=3, b="hello")
    assert x == (5, "helloworld")

    # my_wf is declared at module level, so its location resolves to the path it was assigned to. (Had the workflow
    # been nested inside the test, calling location would fail, since it looks up the LHS the workflow was assigned to.)
    assert my_wf.location == "tests.flytekit.unit.core.test_resolver.my_wf"

    workflows_tasks = my_wf.get_all_tasks()
    assert len(workflows_tasks) == 2  # Two tasks were declared inside

    # The tasks should get the workflow's assigned location as their resolver; the resolver args are each task's
    # index within the workflow.
    srz_t0_spec = get_serializable(OrderedDict(), serialization_settings,
                                   workflows_tasks[0])
    assert srz_t0_spec.template.container.args[-4:] == [
        "--resolver",
        "tests.flytekit.unit.core.test_resolver.my_wf",
        "--",
        "0",
    ]

    srz_t1_spec = get_serializable(OrderedDict(), serialization_settings,
                                   workflows_tasks[1])
    assert srz_t1_spec.template.container.args[-4:] == [
        "--resolver",
        "tests.flytekit.unit.core.test_resolver.my_wf",
        "--",
        "1",
    ]
Example #2
def test_basics():
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert len(wf_spec.template.interface.inputs) == 2
    assert len(wf_spec.template.interface.outputs) == 2
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # Gets cached the first time around so it's not actually fast.
    ssettings = (
        serialization_settings.new_builder().with_fast_serialization_settings(
            FastSerializationSettings(enabled=True)).build())
    task_spec = get_serializable(OrderedDict(), ssettings, t1)
    assert "pyflyte-execute" in task_spec.template.container.args

    lp = LaunchPlan.create(
        "testlp",
        my_wf,
    )
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.id.name == "testlp"
Example #3
def test_serialization():
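    # Two raw-container tasks: each substitutes its typed inputs into a shell one-liner on an alpine image and
    # writes the result under its declared output_data_dir.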
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    sum = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings,
                               raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum)
    assert sumn_spec.template.container.image == "alpine"
Example #4
def test_normal_task(mock_client):
    merge_sort_remotely = load_proto_from_file(
        task_pb2.Task,
        os.path.join(responses_dir, "admin.task_pb2.Task.pb"),
    )
    admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely)
    mock_client.get_task.return_value = admin_task
    ft = rr.fetch_task(name="merge_sort_remotely", version="tst")

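    # The fetched task behaves like a remote reference; create_node wires it into a locally defined workflow.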
    @workflow
    def my_wf(numbers: typing.List[int],
              run_local_at_count: int) -> typing.List[int]:
        t1_node = create_node(ft,
                              numbers=numbers,
                              run_local_at_count=run_local_at_count)
        return t1_node.o0

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec.template.nodes[
        0].task_node.reference_id.name == "merge_sort_remotely"
Example #5
def test_nonfunction_task_and_df_input():
    @reference_task(
        project="flytesnacks",
        domain="development",
        name="ref_t1",
        version="fast56d8ce2e373baf011f4d3532e45f0a9b",
    )
    def ref_t1(
        dataframe: pd.DataFrame,
        imputation_method: str = "median",
    ) -> pd.DataFrame:
        ...

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="ref_t2",
        version="aedbd6fe44051c171fd966c280c5c3036f658831",
    )
    def ref_t2(
        dataframe: pd.DataFrame,
        split_mask: int,
        num_features: int,
    ) -> pd.DataFrame:
        ...

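    # Build the workflow imperatively: declare a file input, run the SQLite task, chain its results through the two
    # reference tasks, and expose the final dataframe as the workflow output.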
    wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
    wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])
    sql_task = SQLite3Task(
        name="dummy.sqlite.task",
        query_template="select * from data",
        inputs=kwtypes(),
        output_schema_type=FlyteSchema,
        task_config=SQLite3Config(
            uri="https://sample/data",
            compressed=True,
        ),
    )
    node_sql = wb.add_task(sql_task)
    node_t1 = wb.add_task(ref_t1, dataframe=node_sql.outputs["results"], imputation_method="mean")

    node_t2 = wb.add_task(
        ref_t2,
        dataframe=node_t1.outputs["o0"],
        split_mask=24,
        num_features=15,
    )
    wb.add_workflow_output("output_from_t3", node_t2.outputs["o0"], python_type=pd.DataFrame)

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 3

    assert len(wf_spec.template.interface.inputs) == 1
    assert wf_spec.template.interface.inputs["sqlite_archive"].type.blob is not None

    assert len(wf_spec.template.interface.outputs) == 1
    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type is not None
    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type == StructuredDatasetType(
        format="parquet"
    )
Example #6
def test_condition_tuple_branches():
    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> typing.Tuple[int, int]:
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )

        return add, sub

    x, y = math_ops(a=3, b=2)
    assert x == 5
    assert y == 1

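    # When serialized, the conditional becomes a single branch node whose then-branch references the sum_sub task.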
    wf_spec = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert len(wf_spec.template.nodes) == 1
    assert (
        wf_spec.template.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id.name
        == "tests.flytekit.unit.core.test_conditions.sum_sub"
    )
Example #7
def test_pod_task_serialized():
    pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(task_config=pod,
          requests=Resources(cpu="10"),
          limits=Resources(gpu="2"),
          environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
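    # Pod tasks serialize with task_type_version 2; the primary container name lands in config and the pod
    # spec and metadata under k8s_pod.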
    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.template.task_type_version == 2
    assert serialized.template.config[
        "primary_container_name"] == "an undefined container"
    assert serialized.template.k8s_pod.metadata.labels == {"label": "foo"}
    assert serialized.template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert serialized.template.k8s_pod.pod_spec is not None
Example #8
def test_serialization_branch():
    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a, )

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world", )

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello", )

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return conditional("test1").if_(c.c == 4).then(
            t1(c=c.c).c).else_().then(t2().c)

    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
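    # The compiled workflow has two nodes: the mimic task node followed by the branch node for the conditional.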
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec is not None
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[1].branch_node is not None
Example #9
def test_launch_plan_with_fixed_input():
    @task
    def greet(day_of_week: str, number: int, am: bool) -> str:
        greeting = "Have a great " + day_of_week + " "
        greeting += "morning" if am else "evening"
        return greeting + "!" * number

    @workflow
    def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
        return greet(day_of_week=day_of_week, number=number, am=am)

    morning_greeting = LaunchPlan.create(
        "morning_greeting",
        go_greet,
        fixed_inputs={"am": True},
        default_inputs={"number": 1},
    )

    @workflow
    def morning_greeter_caller(day_of_week: str) -> str:
        greeting = morning_greeting(day_of_week=day_of_week)
        return greeting

    settings = (
        serialization_settings.new_builder().with_fast_serialization_settings(
            FastSerializationSettings(enabled=True)).build())
    task_spec = get_serializable(OrderedDict(), settings,
                                 morning_greeter_caller)
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1
    assert len(task_spec.template.nodes) == 1
    assert len(task_spec.template.nodes[0].inputs) == 2
Example #10
def test_resolver_load_task():
    # any task is fine, just copied one
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    resolver = TaskTemplateResolver()
    ts = get_serializable(OrderedDict(), serialization_settings, square)
    file = tempfile.NamedTemporaryFile().name
    # load_task should instantiate whatever class path it is given; it doesn't need to be a real executor
    write_proto_to_file(ts.template.to_flyte_idl(), file)
    shim_task = resolver.load_task(
        [file, f"{Placeholder.__module__}.Placeholder"])
    assert isinstance(shim_task.executor, Placeholder)
    assert shim_task.task_template.id.name == "square"
    assert shim_task.task_template.interface.inputs["val"] is not None
    assert shim_task.task_template.interface.outputs["out"] is not None
Example #11
def test_wf_nested_comp():
    @task
    def t1(a: int) -> int:
        a = a + 5
        return a

    @workflow
    def outer() -> typing.Tuple[int, int]:
        # You should not do this. This is just here for testing.
        @workflow
        def wf2() -> int:
            return t1(a=5)

        return t1(a=3), wf2()

    assert (8, 10) == outer()
    entity_mapping = OrderedDict()

    model_wf = get_serializable(entity_mapping, serialization_settings, outer)

    assert len(model_wf.template.interface.outputs) == 2
    assert len(model_wf.template.nodes) == 2
    assert model_wf.template.nodes[1].workflow_node is not None

    sub_wf = model_wf.sub_workflows[0]
    assert len(sub_wf.nodes) == 1
    assert sub_wf.nodes[0].id == "n0"
    assert sub_wf.nodes[0].task_node.reference_id.name == "tests.flytekit.unit.core.test_workflows.t1"
Example #12
def test_resources_override():
    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mappy = map_task(t1)
        map_node = mappy(a=a).with_overrides(
            requests=Resources(cpu="1", mem="100", ephemeral_storage="500Mi"),
            limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi"),
        )
        return map_node

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
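    # Serialize and confirm the with_overrides resource requests and limits survive on the map-task node.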
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert len(wf_spec.template.nodes) == 1
    assert wf_spec.template.nodes[0].task_node.overrides is not None
    assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [
        _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "500Mi"),
    ]

    assert wf_spec.template.nodes[0].task_node.overrides.resources.limits == [
        _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "1Gi"),
    ]
Example #13
def test_serialization(serialization_settings):
    maptask = map_task(t1, metadata=TaskMetadata(retries=1))
    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 maptask)

    # By default all map_task tasks will have their custom fields set.
    assert task_spec.template.custom["minSuccessRatio"] == 1.0
    assert task_spec.template.type == "container_array"
    assert task_spec.template.task_type_version == 1
    assert task_spec.template.container.args == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.flytekit.unit.core.test_map_task",
        "task-name",
        "t1",
    ]
Example #14
def test_serialization_branch_complex_2():
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (conditional("test1").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(
            t2(a=y)).else_().fail("Unable to choose branch"))
        f = conditional("test2").if_(d == "hello ").then(
            t2(a="It is hello")).else_().then(t2(a="Not Hello!"))
        return x, f

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
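    # The first branch node consumes t1's named output, surfaced as n0.t1_int_output.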
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec is not None
    assert wf_spec.template.nodes[1].inputs[0].var == "n0.t1_int_output"
Example #15
def test_serialization_of_custom_fields(custom_fields_dict,
                                        expected_custom_fields,
                                        serialization_settings):
    maptask = map_task(t1, **custom_fields_dict)
    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 maptask)

    assert task_spec.template.custom == expected_custom_fields
Example #16
def test_example_module():
    @task
    def t1() -> torch.nn.Module:
        return torch.nn.BatchNorm1d(3, track_running_stats=True)

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    assert task_spec.template.interface.outputs[
        "o0"].type.blob.format is PyTorchModuleTransformer.PYTORCH_FORMAT
Example #17
def test_example_tensor():
    @task
    def t1(array: torch.Tensor) -> torch.Tensor:
        return torch.flatten(array)

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    assert task_spec.template.interface.outputs[
        "o0"].type.blob.format is PyTorchTensorTransformer.PYTORCH_FORMAT
Example #18
def test_wf_docstring():
    model_wf = get_serializable(OrderedDict(), serialization_settings, my_wf_example)

    assert len(model_wf.template.interface.outputs) == 2
    assert model_wf.template.interface.outputs["o0"].description == "outputs"
    assert model_wf.template.interface.outputs["o1"].description == "outputs"
    assert len(model_wf.template.interface.inputs) == 1
    assert model_wf.template.interface.inputs["a"].description == "input a"
Example #19
def test_example():
    @task
    def t1(array: np.ndarray) -> np.ndarray:
        return array.flatten()

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    assert task_spec.template.interface.outputs[
        "o0"].type.blob.format is NumpyArrayTransformer.NUMPY_ARRAY_FORMAT
Example #20
def test_serialization_images():
    @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.abc.fqn}}:{{.image.xyz.version}}")
    def t2():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.xyz.version}}")
    def t5(a: int) -> int:
        return a

    @task(container_image="{{.image.xyz_123.fqn}}:{{.image.xyz_123.version}}")
    def t6(a: int) -> int:
        return a

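    # Placeholders such as {{.image.xyz.fqn}} resolve against the ImageConfig loaded from images.config, with
    # FLYTE_INTERNAL_IMAGE supplying the default image.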
    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    imgs = ImageConfig.auto(config_file=os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs/images.config"))
    rs = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=imgs,
    )
    t1_spec = get_serializable(OrderedDict(), rs, t1)
    assert t1_spec.template.container.image == "docker.io/xyz:latest"
    t1_spec.to_flyte_idl()

    t2_spec = get_serializable(OrderedDict(), rs, t2)
    assert t2_spec.template.container.image == "docker.io/abc:latest"

    t4_spec = get_serializable(OrderedDict(), rs, t4)
    assert t4_spec.template.container.image == "docker.io/org/myimage:latest"

    t5_spec = get_serializable(OrderedDict(), rs, t5)
    assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

    t6_spec = get_serializable(OrderedDict(), rs, t6)
    assert t6_spec.template.container.image == "docker.io/xyz_123:v1"
Example #21
def get_registrable_entities(
    ctx: flyte_context.FlyteContext,
    options: typing.Optional[Options] = None
) -> typing.List[RegistrableEntity]:
    """
    Returns all entities that can be serialized and should be sent over to the Flyte backend. Entities that should
    not be registered with Admin are filtered out.
    """
    new_api_serializable_entities = OrderedDict()
    # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
    #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
    for entity in flyte_context.FlyteEntities.entities.copy():
        if isinstance(entity, PythonTask) or isinstance(
                entity, WorkflowBase) or isinstance(entity, LaunchPlan):
            get_serializable(new_api_serializable_entities,
                             ctx.serialization_settings,
                             entity,
                             options=options)

            if isinstance(entity, WorkflowBase):
                lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                get_serializable(new_api_serializable_entities,
                                 ctx.serialization_settings, lp, options)

    new_api_model_values = list(new_api_serializable_entities.values())
    entities_to_be_serialized = list(
        filter(_should_register_with_admin, new_api_model_values))
    serializable_tasks: typing.List[task_models.TaskSpec] = [
        entity for entity in entities_to_be_serialized
        if isinstance(entity, task_models.TaskSpec)
    ]
    # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same
    # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate
    # tasks are considered invalid at registration
    # time and usually indicate user error, so we catch this common mistake at serialization time.
    duplicate_tasks = _find_duplicate_tasks(serializable_tasks)
    if len(duplicate_tasks) > 0:
        duplicate_task_names = [
            task.template.id.name for task in duplicate_tasks
        ]
        raise FlyteValidationException(
            f"Multiple definitions of the following tasks were found: {duplicate_task_names}"
        )

    return [v.to_flyte_idl() for v in entities_to_be_serialized]
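A minimal usage sketch, assuming the modules that define your tasks and workflows are importable (decorating them adds them to FlyteEntities.entities at import time) and that the context helpers used elsewhere in these examples are available; the module name below is hypothetical, not part of flytekit:

import importlib

ss = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    image_config=ImageConfig.auto(),
)
ctx = flyte_context.FlyteContextManager.current_context()
with flyte_context.FlyteContextManager.with_context(
        ctx.with_serialization_settings(ss)) as ctx:
    # Importing the user module registers its tasks and workflows in FlyteEntities.entities.
    importlib.import_module("my_project.workflows")  # hypothetical module name
    registrable = get_registrable_entities(ctx)  # list of serialized IDL messages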
Example #22
def test_calling_wf():
    # No way to fetch from Admin in unit tests so we serialize and then promote back
    serialized = OrderedDict()
    wf_spec = get_serializable(serialized, serialization_settings, sub_wf)
    task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
    fwf = FlyteWorkflow.promote_from_model(wf_spec.template,
                                           tasks=task_templates)

    @workflow
    def parent_1(a: int, b: str) -> typing.Tuple[int, str]:
        y = t1(a=a)
        return fwf(a=y, b=b)

    # No way to fetch from Admin in unit tests so we serialize and then promote back
    serialized = OrderedDict()
    wf_spec = get_serializable(serialized, serialization_settings, parent_1)
    # Get task_specs from the second one, merge with the first one. Admin normally would be the one to do this.
    task_templates_p1, wf_specs, lp_specs = gather_dependent_entities(
        serialized)
    for k, v in task_templates.items():
        task_templates_p1[k] = v

    # Pick out the subworkflow templates from the ordereddict. We can't use the output of the gather_dependent_entities
    # function because that only looks for WorkflowSpecs
    subwf_templates = {
        x.id: x
        for x in list(
            filter(lambda x: isinstance(x, WorkflowTemplate),
                   serialized.values()))
    }
    fwf_p1 = FlyteWorkflow.promote_from_model(wf_spec.template,
                                              sub_workflows=subwf_templates,
                                              tasks=task_templates_p1)

    @workflow
    def parent_2(a: int, b: str) -> typing.Tuple[int, str]:
        x, y = fwf_p1(a=a, b=b)
        z = t1(a=x)
        return z, y

    serialized = OrderedDict()
    wf_spec = get_serializable(serialized, serialization_settings, parent_2)
    # Make sure both were picked up.
    assert len(wf_spec.sub_workflows) == 2
Example #23
def test_nested_condition_2():
    @workflow
    def multiplier_2(my_input: float) -> float:
        return (
            conditional("fractions")
            .if_((my_input > 0.1) & (my_input < 1.0))
            .then(
                conditional("inner_fractions")
                .if_(my_input < 0.5)
                .then(double(n=my_input))
                .elif_((my_input > 0.5) & (my_input < 0.7))
                .then(square(n=my_input))
                .else_()
                .fail("Only <0.7 allowed")
            )
            .elif_((my_input > 1.0) & (my_input < 10.0))
            .then(square(n=my_input))
            .else_()
            .then(double(n=my_input))
        )

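    # Serialize and walk the branch structure: the outer conditional is a single branch node, and its then-branch
    # holds the nested conditional as another branch node.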
    srz_wf = get_serializable(OrderedDict(), serialization_settings, multiplier_2)
    assert len(srz_wf.template.nodes) == 1
    fractions_branch = srz_wf.template.nodes[0]
    assert isinstance(fractions_branch, Node)
    assert fractions_branch.id == "n0"
    assert fractions_branch.branch_node is not None
    if_else_b = fractions_branch.branch_node.if_else
    assert if_else_b is not None
    assert if_else_b.case is not None
    assert if_else_b.case.then_node is not None
    inner_fractions_node = if_else_b.case.then_node
    assert inner_fractions_node.id == "n0"
    assert inner_fractions_node.branch_node.if_else.case.then_node.task_node is not None
    assert inner_fractions_node.branch_node.if_else.case.then_node.id == "n0"
    assert len(inner_fractions_node.branch_node.if_else.other) == 1
    assert inner_fractions_node.branch_node.if_else.other[0].then_node.id == "n1"

    # Ensure other cases exist
    assert len(if_else_b.other) == 1
    assert if_else_b.other[0].then_node.task_node is not None
    assert if_else_b.other[0].then_node.id == "n1"

    with pytest.raises(ValueError):
        multiplier_2(my_input=0.7)

    res = multiplier_2(my_input=0.3)
    assert res == 0.6

    res = multiplier_2(my_input=5.0)
    assert res == 25

    res = multiplier_2(my_input=10.0)
    assert res == 20
Example #24
def test_ref_sub_wf():
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    with context_manager.FlyteContextManager.with_context(
            ctx.with_new_compilation_state()) as ctx:
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception, match="currently unsupported"):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
Example #25
def test_serialization_named_return():
    @task
    def t1() -> str:
        return "Hello"

    @workflow
    def wf() -> typing.NamedTuple("OP", a=str, b=str):
        return t1(), t1()

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wf)
    assert len(wf_spec.template.interface.outputs) == 2
    assert list(wf_spec.template.interface.outputs.keys()) == ["a", "b"]
Example #26
def test_lps(resource_type):
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    with context_manager.FlyteContextManager.with_context(
            ctx.with_new_compilation_state()) as ctx:
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
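    # A reference entity serializes as a pointer: a launch plan shows up as a launchplan_ref workflow node, while a
    # task shows up as a task_node reference_id.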
    wf_spec = get_serializable(OrderedDict(), serialization_settings, wf1)
    assert len(wf_spec.template.interface.inputs) == 2
    assert len(wf_spec.template.interface.outputs) == 0
    assert len(wf_spec.template.nodes) == 1
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert wf_spec.template.nodes[
            0].workflow_node.launchplan_ref.project == "proj"
        assert wf_spec.template.nodes[
            0].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    else:
        assert wf_spec.template.nodes[
            0].task_node.reference_id.project == "proj"
        assert wf_spec.template.nodes[
            0].task_node.reference_id.name == "app.other.flyte_entity"
Example #27
def test_calling_lp():
    sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
    serialized = OrderedDict()
    lp_model = get_serializable(serialized, serialization_settings, sub_wf_lp)
    task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
    for wf_id, spec in wf_specs.items():
        break

    remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
    # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
    # part of the IDL object but is something FlyteRemote does
    remote_lp._interface = TypedInterface.promote_from_model(
        spec.template.interface)
    serialized = OrderedDict()

    @workflow
    def wf2(a: int) -> typing.Tuple[int, str]:
        return remote_lp(a=a, b="hello")

    wf_spec = get_serializable(serialized, serialization_settings, wf2)
    print(wf_spec.template.nodes[0].workflow_node.launchplan_ref)
    assert wf_spec.template.nodes[
        0].workflow_node.launchplan_ref == lp_model.id
Example #28
def test_serialization_named_outputs_single():
    @task
    def t1() -> typing.NamedTuple("OP", a=str):
        return "Hello"

    @workflow
    def wf() -> typing.NamedTuple("OP", a=str):
        return t1().a

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wf)
    assert len(wf_spec.template.interface.outputs) == 1
    assert list(wf_spec.template.interface.outputs.keys()) == ["a"]
    a = wf()
    assert a.a == "Hello"
Example #29
def test_serialization_types():
    @task(cache=True, cache_version="1.0.0")
    def squared(value: int) -> typing.List[typing.Dict[str, int]]:
        return [
            {
                "squared_value": value**2
            },
        ]

    @workflow
    def compute_square_wf(
            input_integer: int) -> typing.List[typing.Dict[str, int]]:
        compute_square_result = squared(value=input_integer)
        return compute_square_result

    wf_spec = get_serializable(OrderedDict(), serialization_settings,
                               compute_square_wf)
    assert wf_spec.template.interface.outputs[
        "o0"].type.collection_type.map_value_type.simple == SimpleType.INTEGER
    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 squared)
    assert task_spec.template.interface.outputs[
        "o0"].type.collection_type.map_value_type.simple == SimpleType.INTEGER
Example #30
def test_all_node_types():
    assert my_wf_example(a=1) == (6, 16)
    entity_mapping = OrderedDict()

    model_wf = get_serializable(entity_mapping, serialization_settings, my_wf_example)

    assert len(model_wf.template.interface.outputs) == 2
    assert len(model_wf.template.nodes) == 4
    assert model_wf.template.nodes[2].workflow_node is not None

    sub_wf = model_wf.sub_workflows[0]
    assert len(sub_wf.nodes) == 1
    assert sub_wf.nodes[0].id == "n0"
    assert sub_wf.nodes[0].task_node.reference_id.name == "tests.flytekit.unit.core.test_workflows.add_5"