Ejemplo n.º 1
0
def test_container_image_conversion():
    """Exercise image-template resolution against a two-image ``ImageConfig``.

    Empty input falls back to the default image, plain names pass through
    untouched, ``{{.image.<name>.<field>}}`` templates are substituted, and
    templates naming an unknown image or field raise ``AssertionError``.
    """
    primary = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    secondary = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=primary, images=[primary, secondary])

    # (input, expected) pairs that must resolve successfully, checked in order.
    resolvable = [
        (None, "xyz.com/abc:tag1"),
        ("", "xyz.com/abc:tag1"),
        ("abc", "abc"),
        ("abc:latest", "abc:latest"),
        ("abc:{{.image.default.version}}", "abc:tag1"),
        ("{{.image.default.fqn}}:{{.image.default.version}}", "xyz.com/abc:tag1"),
        ("{{.image.other.fqn}}:{{.image.other.version}}", "xyz.com/other:tag-other"),
        ("{{.image.other.fqn}}:{{.image.default.version}}", "xyz.com/other:tag1"),
        ("{{.image.other.fqn}}", "xyz.com/other"),
        # The plural "images" prefix is accepted as well as "image".
        ("{{.images.other.fqn}}", "xyz.com/other"),
    ]
    for template, expected in resolvable:
        assert get_registerable_container_image(template, cfg) == expected

    # Templates that must be rejected.
    rejected = [
        "{{.image.blah.fqn}}:{{.image.other.version}}",
        "{{.image.fqn}}:{{.image.other.version}}",
        "{{.image.blah}}",
    ]
    for template in rejected:
        with pytest.raises(AssertionError):
            get_registerable_container_image(template, cfg)

    # A bare image reference expands to the full "fqn:tag" form.
    assert get_registerable_container_image("{{.image.default}}", cfg) == "xyz.com/abc:tag1"
Ejemplo n.º 2
0
def test_py_func_task_get_container():
    """The container of a function task carries the default image and both env maps."""

    def foo(i: int):
        pass

    images = [
        Image(name="default", fqn="xyz.com/abc", tag="tag1"),
        Image(name="other", fqn="xyz.com/other", tag="tag-other"),
    ]
    image_cfg = ImageConfig(default_image=images[0], images=images)
    ss = SerializationSettings(project="p", domain="d", version="v", image_config=image_cfg, env={"FOO": "bar"})

    container = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"}).get_container(ss)

    assert container.image == "xyz.com/abc:tag1"
    # Settings-level and task-level environment variables both show up.
    assert container.env == {"FOO": "bar", "BAZ": "baz"}
Ejemplo n.º 3
0
def test_resources_override():
    """``with_overrides`` on a mapped task node is serialized as node-level resource overrides."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        requested = Resources(cpu="1", mem="100", ephemeral_storage="500Mi")
        capped = Resources(cpu="2", mem="200", ephemeral_storage="1Gi")
        return map_task(t1)(a=a).with_overrides(requests=requested, limits=capped)

    image_cfg = ImageConfig(Image(name="name", fqn="image", tag="name"))
    ss = flytekit.configuration.SerializationSettings(
        project="test_proj", domain="test_domain", version="abc", image_config=image_cfg, env={}
    )
    wf_spec = get_serializable(OrderedDict(), ss, my_wf)

    nodes = wf_spec.template.nodes
    assert len(nodes) == 1
    overrides = nodes[0].task_node.overrides
    assert overrides is not None

    # Short aliases keep the expected entry lists readable.
    entry = _resources_models.ResourceEntry
    kind = _resources_models.ResourceName
    assert overrides.resources.requests == [
        entry(kind.CPU, "1"),
        entry(kind.MEMORY, "100"),
        entry(kind.EPHEMERAL_STORAGE, "500Mi"),
    ]
    assert overrides.resources.limits == [
        entry(kind.CPU, "2"),
        entry(kind.MEMORY, "200"),
        entry(kind.EPHEMERAL_STORAGE, "1Gi"),
    ]
Ejemplo n.º 4
0
def test_pod_task_serialized():
    """A ``Pod``-configured task serializes its pod metadata and primary container name."""
    pod_config = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    # The decorator must produce the pod-specific task class with the config attached.
    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    ss = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    template = get_serializable(OrderedDict(), ss, simple_pod_task).template

    assert template.task_type_version == 2
    assert template.config["primary_container_name"] == "an undefined container"
    assert template.k8s_pod.metadata.labels == {"label": "foo"}
    assert template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert template.k8s_pod.pod_spec is not None
Ejemplo n.º 5
0
def test_pytorch_task():
    """A ``PyTorch``-configured task runs locally and reports its custom config, resources and type."""

    @task(task_config=PyTorch(num_workers=10), cache=True, cache_version="1", requests=Resources(cpu="1"))
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    # Local execution simply returns the first argument.
    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    image_cfg = ImageConfig(default_image=image, images=[image])
    ss = SerializationSettings(
        project="project", domain="domain", version="version", env={"FOO": "baz"}, image_config=image_cfg
    )

    assert my_pytorch_task.get_custom(ss) == {"workers": 10}
    resources = my_pytorch_task.resources
    # No limits were given, so they compare equal to an empty Resources().
    assert resources.limits == Resources()
    assert resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
Ejemplo n.º 6
0
def test_tensorflow_task():
    """A ``TfJob``-configured task runs locally and reports its custom config, resources and type."""
    tf_config = TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1)

    @task(task_config=tf_config, cache=True, requests=Resources(cpu="1"), cache_version="1")
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    # Local execution simply returns the first argument.
    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    image_cfg = ImageConfig(default_image=image, images=[image])
    ss = SerializationSettings(
        project="project", domain="domain", version="version", env={"FOO": "baz"}, image_config=image_cfg
    )

    expected_custom = {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.get_custom(ss) == expected_custom
    resources = my_tensorflow_task.resources
    # No limits were given, so they compare equal to an empty Resources().
    assert resources.limits == Resources()
    assert resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
Ejemplo n.º 7
0
def test_two(two_sample_inputs):
    """A dynamic task over two ``MyInput`` values yields a two-element output collection."""
    first, second = two_sample_inputs[0], two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        return [item.main_product for item in a]

    ss = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serializing = FlyteContextManager.current_context().with_serialization_settings(ss)
    with FlyteContextManager.with_context(serializing) as outer_ctx:
        # Switch the nested context into task-execution mode before dispatching.
        executing = outer_ctx.with_execution_state(
            outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        )
        with FlyteContextManager.with_context(executing) as ctx:
            literal_inputs = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [first, second]}, type_hints={"a": List[MyInput]}
            )
            result = dt1.dispatch_execute(ctx, literal_inputs)
            assert len(result.literals["o0"].collection.literals) == 2
Ejemplo n.º 8
0
def test_mpi_task():
    """An ``MPIJob``-configured task runs locally and reports its custom config and type."""
    mpi_config = MPIJob(num_workers=10, num_launcher_replicas=10, slots=1)

    @task(task_config=mpi_config, requests=Resources(cpu="1"), cache=True, cache_version="1")
    def my_mpi_task(x: int, y: str) -> int:
        return x

    # Local execution simply returns the first argument.
    assert my_mpi_task(x=10, y="hello") == 10
    assert my_mpi_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    image_cfg = ImageConfig(default_image=image, images=[image])
    ss = SerializationSettings(
        project="project", domain="domain", version="version", env={"FOO": "baz"}, image_config=image_cfg
    )

    expected_custom = {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
    assert my_mpi_task.get_custom(ss) == expected_custom
    assert my_mpi_task.task_type == "mpi"
Ejemplo n.º 9
0
def test_dont_convert_remotes():
    """A remote ``FlyteDirectory`` fed through a dynamic task keeps its original URI in the node binding."""

    @task
    def t1(in1: FlyteDirectory):
        print(in1)

    @dynamic
    def dyn(in1: FlyteDirectory):
        t1(in1=in1)

    remote_dir = FlyteDirectory("s3://anything")

    ss = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    base_ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(base_ctx.with_serialization_settings(ss)) as outer_ctx:
        # Switch the nested context into task-execution mode before dispatching.
        task_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with context_manager.FlyteContextManager.with_context(outer_ctx.with_execution_state(task_state)) as ctx:
            multipart = BlobType("", dimensionality=BlobType.BlobDimensionality.MULTIPART)
            dir_literal = TypeEngine.to_literal(ctx, remote_dir, FlyteDirectory, multipart)
            inputs = LiteralMap(literals={"in1": dir_literal})
            compiled = dyn.dispatch_execute(ctx, inputs)
            assert compiled.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything"
Ejemplo n.º 10
0
def test_serialization_branch():
    """A workflow with a single ``conditional`` runs locally and serializes to a task node plus a branch node.

    The ``NamedTuple`` return types are declared with the field-list form;
    the keyword-argument form (``NamedTuple("X", c=int)``) is deprecated
    since Python 3.13 and scheduled for removal.
    """

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", [("c", int)]):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", [("c", str)]):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", [("c", str)]):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return conditional("test1").if_(c.c == 4).then(t1(c=c.c).c).else_().then(t2().c)

    # Local execution takes the "if" branch only when the input equals 4.
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec is not None
    # One node for ``mimic`` and one for the conditional itself.
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[1].branch_node is not None
Ejemplo n.º 11
0
def test_ref_dynamic_lp():
    """Calling a reference launch plan five times inside a dynamic task compiles to five nodes."""

    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        @reference_launch_plan(project="project", domain="domain", name="name", version="version")
        def ref_lp1(p1: str, p2: str) -> int:
            ...

        return [ref_lp1(p1="hello", p2=str(a)) for _ in range(a)]

    ss = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(ss)) as outer_ctx:
        # Compile under task-execution mode.
        task_state = outer_ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(outer_ctx.with_execution_state(task_state)) as ctx:
            djspec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            assert len(djspec.nodes) == 5
Ejemplo n.º 12
0
def test_query_no_inputs_or_outputs():
    """A ``HiveTask`` with no inputs and no output schema serializes with an empty interface."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    image = Image(name="default", fqn="test", tag="tag")
    ss = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    interface = get_serializable(OrderedDict(), ss, hive_task).template.interface
    assert len(interface.inputs) == 0
    assert len(interface.outputs) == 0

    # The enclosing workflow only has to serialize without raising.
    get_serializable(OrderedDict(), ss, my_wf)
Ejemplo n.º 13
0
def test_serialization_settings_transport():
    """``SerializationSettings`` round-trips through its serialized-context transport form."""
    image = Image(name="default", fqn="test", tag="tag")
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/opt/blah/blah/blah",
        distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
    )
    original = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(default_image=image, images=[image]),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=fast_settings,
    )

    transport = original.serialized_context
    enriched = original.with_serialized_context()

    # The source settings keep their env; the derived copy carries the transport string.
    assert original.env == {"hello": "blah"}
    assert enriched.env
    assert enriched.env[SERIALIZED_CONTEXT_ENV_VAR] == transport

    restored = SerializationSettings.from_transport(transport)
    assert restored is not None
    assert restored == original
    # Pins the exact length of the transport string for this input.
    assert len(transport) == 376
Ejemplo n.º 14
0
def test_serialization_branch_complex_2():
    """A workflow with two chained conditionals serializes, and the first branch node reads ``n0.t1_int_output``.

    The ``NamedTuple`` return type is declared with the field-list form;
    the keyword-argument form (``NamedTuple("X", a=int)``) is deprecated
    since Python 3.13 and scheduled for removal.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = conditional("test2").if_(d == "hello ").then(t2(a="It is hello")).else_().then(t2(a="Not Hello!"))
        return x, f

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf_spec is not None
    # The first conditional's input is bound to t1's integer output.
    assert wf_spec.template.nodes[1].inputs[0].var == "n0.t1_int_output"
Ejemplo n.º 15
0
def serialization_settings():
    """Return ``SerializationSettings`` with fixed project/domain/version, no env and one default image."""
    image = Image(name="default", fqn="test", tag="tag")
    image_cfg = ImageConfig(default_image=image, images=[image])
    return flytekit.configuration.SerializationSettings(
        project="project", domain="domain", version="version", env=None, image_config=image_cfg
    )
Ejemplo n.º 16
0
def _get_reg_settings():
    """Return ``SerializationSettings`` with a single default image and ``FOO=baz`` in the env."""
    image = Image(name="default", fqn="test", tag="tag")
    image_cfg = ImageConfig(default_image=image, images=[image])
    return SerializationSettings(
        project="project", domain="domain", version="version", env={"FOO": "baz"}, image_config=image_cfg
    )
Ejemplo n.º 17
0
def test_dc_dyn_directory(folders_and_files_setup):
    """A dynamic task returning nested ``FlyteDirectory`` fields keeps their remote ``gs://`` URIs."""
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    def build_input(static_uri, external_uri):
        # The two inputs differ only in the directory URIs of their apriori config.
        return MyInput(
            main_product=FlyteFile(folders_and_files_setup[0]),
            apriori_config=MyAprioriConfiguration(
                static_data_dir=FlyteDirectory(static_uri),
                external_data_dir=FlyteDirectory(external_uri),
            ),
            proxy_config=proxy_c,
            proxy_params=proxy_p,
        )

    my_input_gcs = build_input("gs://my-bucket/one", "gs://my-bucket/two")
    my_input_gcs_2 = build_input("gs://my-bucket/three", "gs://my-bucket/four")

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        return [item.apriori_config.external_data_dir for item in a]

    ss = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    base_ctx = FlyteContextManager.current_context()
    task_state = base_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
    builder = base_ctx.new_builder().with_serialization_settings(ss).with_execution_state(task_state)

    with FlyteContextManager.with_context(builder) as ctx:
        literal_inputs = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        result = dt1.dispatch_execute(ctx, literal_inputs)
        assert result.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert result.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
Ejemplo n.º 18
0
def test_lps(resource_type):
    """Check local-call errors and serialization of a reference entity of the given ``resource_type``.

    Outside a compilation context the entity must refuse to run; inside one it
    must demand its inputs and otherwise return a ``VoidPromise``. When used
    in a workflow it serializes either as a launch-plan reference or as a task
    reference, depending on ``resource_type``.
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as local_err:
        ref_entity()
    # Inspect the raised exception itself; the string form of pytest's
    # ExceptionInfo wrapper has changed between pytest versions.
    assert "You must mock this out" in str(local_err.value)

    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()) as ctx:
        with pytest.raises(Exception) as missing_input_err:
            ref_entity()
        assert "Input was not specified" in str(missing_input_err.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, wf1)
    assert len(wf_spec.template.interface.inputs) == 2
    assert len(wf_spec.template.interface.outputs) == 0
    assert len(wf_spec.template.nodes) == 1

    node = wf_spec.template.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert node.workflow_node.launchplan_ref.project == "proj"
        assert node.workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    else:
        assert node.task_node.reference_id.project == "proj"
        assert node.task_node.reference_id.name == "app.other.flyte_entity"
Ejemplo n.º 19
0
def test_fast_pod_task_serialization():
    """With fast serialization enabled, the primary container args start with the fast-execute wrapper."""
    primary = V1Container(name="primary")
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[primary]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    ss = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), ss, simple_pod_task)

    # Expected args, split into the three sections separated by "--".
    fast_wrapper = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
    ]
    execute_cmd = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
    ]
    resolver_args = ["task-module", "tests.test_pod", "task-name", "simple_pod_task"]

    actual_args = serialized.template.k8s_pod.pod_spec["containers"][0]["args"]
    assert actual_args == fast_wrapper + execute_cmd + resolver_args
Ejemplo n.º 20
0
def test_serialization():
    """A ``SnowflakeTask`` and a workflow wrapping it serialize with the expected SQL, config and bindings."""
    sf_config = SnowflakeConfig(
        account="snowflake",
        warehouse="my_warehouse",
        schema="my_schema",
        database="my_database",
    )
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=sf_config,
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    ss = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    template = get_serializable(OrderedDict(), ss, snowflake_task).template

    sql = template.sql
    assert "{{ .rawOutputDataPrefix" in sql.statement
    assert "insert overwrite directory" in sql.statement
    assert sql.dialect == sql.Dialect.ANSI

    # Every SnowflakeConfig field is surfaced in the template config.
    assert "snowflake" == template.config["account"]
    assert "my_warehouse" == template.config["warehouse"]
    assert "my_schema" == template.config["schema"]
    assert "my_database" == template.config["database"]
    assert len(template.interface.inputs) == 1
    assert len(template.interface.outputs) == 1

    wf_template = get_serializable(OrderedDict(), ss, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    wf_output = wf_template.outputs[0]
    assert wf_output.var == "o0"
    # The workflow output is wired to the "results" output of node n0.
    assert wf_output.binding.promise.node_id == "n0"
    assert wf_output.binding.promise.var == "results"
Ejemplo n.º 21
0
def test_wf1_with_fast_dynamic():
    """A dynamic task compiled under fast-serialization settings emits fast-execute container args."""

    @task
    def t1(a: int) -> str:
        return "fast-" + str(a + 2)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        return my_subwf(a=a)

    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/User/flyte/workflows",
        distribution_location="s3://my-s3-bucket/fast/123",
    )
    ss = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=fast_settings,
    )
    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(ss)) as outer_ctx:
        # Dispatch under task-execution mode.
        task_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(outer_ctx.with_execution_state(task_state)) as ctx:
            literal_inputs = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
            job_spec = my_subwf.dispatch_execute(ctx, literal_inputs)
            # Five nodes, all sharing one task definition.
            assert len(job_spec._nodes) == 5
            assert len(job_spec.tasks) == 1
            joined_args = " ".join(job_spec.tasks[0].container.args)
            assert joined_args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    # Both nested contexts must have been popped again.
    assert context_manager.FlyteContextManager.size() == 1
Ejemplo n.º 22
0
def test_serialization():
    """Two ``ContainerTask`` objects and a workflow composing them serialize with the raw image name.

    The adder task is bound to ``sum_task`` rather than ``sum`` so the
    builtin ``sum`` is not shadowed inside this function; the task's
    registered name (``"sum"``) is unchanged.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out",
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    # Two square nodes plus one sum node.
    assert len(wf_spec.template.nodes) == 3

    # Each task keeps the image given to it, not the default image from the settings.
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
Ejemplo n.º 23
0
def test_map_pod_task_serialization():
    """Mapping a pod task rewrites the primary container command to the map-execute entrypoint."""
    primary = V1Container(name="primary")
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[primary]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    ss = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Test that target is correctly serialized with an updated command
    containers = mapped_task.get_k8s_pod(ss).pod_spec["containers"]
    assert len(containers) == 1

    execute_cmd = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
    ]
    resolver_args = ["task-module", "tests.test_pod", "task-name", "simple_pod_task"]
    assert containers[0]["args"] == execute_cmd + resolver_args

    assert {"primary_container_name": "primary"} == mapped_task.get_config(ss)
Ejemplo n.º 24
0
def test_serialization():
    """A ``HiveTask`` with a schema input and a workflow wrapping it serialize with the expected query and bindings.

    The Hive configuration is passed as ``task_config=``, matching the other
    ``HiveTask``/``SnowflakeTask`` constructions in these tests.
    NOTE(review): the previous ``config=`` keyword was presumably swallowed by
    ``**kwargs`` and never applied - confirm against the ``HiveTask`` signature.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    # The query template is stored verbatim under custom["query"]["query"].
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    # The workflow output is wired to the "results" output of node n0.
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
Ejemplo n.º 25
0
def test_sql_command():
    """Check the trailing container args produced when ``not_tk`` is serialized.

    The last five args must select the default task template resolver and end
    with the SQLAlchemy task executor path.
    """
    # NOTE(review): `not_tk` is not defined in this function; it is expected to exist at module level.
    expected_tail = [
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekitplugins.sqlalchemy.task.SQLAlchemyTaskExecutor",
    ]

    image = Image(name="default", fqn="test", tag="tag")
    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(OrderedDict(), settings, not_tk)
    container_args = serialized.template.container.args
    assert container_args[-5:] == expected_tail
Ejemplo n.º 26
0
def test_lp_with_output():
    """Use a reference launch plan with outputs inside a workflow.

    First the workflow is run with the launch plan patched out, checking both
    the final result and the arguments the mock received. Then the workflow is
    serialized and the second node is checked to point at the referenced
    launch plan.
    """
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(ref_lp)
    def run_with_mocked_lp(lp_mock):
        # The mock stands in for the launch plan, so its return value feeds t2 directly.
        lp_mock.return_value = (False, 30)
        result = wf1()
        assert result == "q: False r: 30"
        lp_mock.assert_called_with(a="hello", b=88)

    run_with_mocked_lp()

    image_config = ImageConfig(Image(name="name", fqn="image", tag="name"))
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=image_config,
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), settings, wf1)

    # Node index 1 is the launch plan call (t1 precedes it in wf1).
    lp_reference = wf_spec.template.nodes[1].workflow_node.launchplan_ref
    assert lp_reference.project == "proj"
    assert lp_reference.name == "app.other.flyte_entity"
Ejemplo n.º 27
0
def test_ref():
    """Check the identifier of ``ref_t1`` and the spec produced by serializing it.

    Serialization must yield a ``ReferenceSpec`` wrapping a ``ReferenceTemplate``
    that carries the same identifier and the TASK resource type.
    """
    # NOTE(review): `ref_t1` is not defined in this function; it is expected to exist at module level.
    ref_id = ref_t1.id
    assert ref_id.project == "flytesnacks"
    assert ref_id.domain == "development"
    assert ref_id.name == "recipes.aaa.simple.join_strings"
    assert ref_id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    image_config = ImageConfig(Image(name="name", fqn="image", tag="name"))
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=image_config,
        env={},
    )

    spec = get_serializable(OrderedDict(), settings, ref_t1)
    assert isinstance(spec, ReferenceSpec)

    template = spec.template
    assert isinstance(template, ReferenceTemplate)
    assert template.id == ref_id
    assert template.resource_type == _identifier_model.ResourceType.TASK
Ejemplo n.º 28
0
def test_interruptible_override(interruptible):
    """Check that a node-level ``interruptible`` override reaches the serialized node.

    ``interruptible`` is supplied from outside this function (presumably a
    pytest fixture or parametrization that is not visible here -- confirm).
    """

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: str) -> str:
        return t1(a=a).with_overrides(interruptible=interruptible)

    image_config = ImageConfig(Image(name="name", fqn="image", tag="name"))
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=image_config,
        env={},
    )

    nodes = get_serializable(OrderedDict(), settings, my_wf).template.nodes
    assert len(nodes) == 1
    assert nodes[0].metadata.interruptible == interruptible
Ejemplo n.º 29
0
def test_ref_sub_wf():
    """Exercise a reference workflow entity outside and inside a compilation context.

    Covers four situations:

    * calling the reference directly raises and tells the user to mock it;
    * calling it without inputs while compiling raises an input error;
    * calling it with inputs while compiling yields a ``VoidPromise`` (it has no outputs);
    * serializing a workflow that uses it as a sub-workflow is rejected.

    The message checks use ``pytest.raises(..., match=...)`` so they are evaluated
    against the raised exception's own text rather than against the string form of
    pytest's ``ExceptionInfo`` wrapper.
    """
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception, match="You must mock this out"):
        ref_entity()

    # Inside a compilation state the call builds a node instead of executing, so
    # the failure mode switches from "must mock" to input validation.
    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()):
        with pytest.raises(Exception, match="Input was not specified"):
            ref_entity()

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception, match="currently unsupported"):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
Ejemplo n.º 30
0
def test_serialization():
    """Serialize an AthenaTask and a workflow that calls it, then inspect both specs.

    The task spec must keep the query template under ``statement`` in its custom
    payload, carry the configured database, catalog and workgroup, and expose
    one input and one output. The workflow spec must bind its only output,
    ``o0``, to the ``results`` variable of node ``n0``.
    """
    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    # Task-level checks: template placeholders and config values in the custom payload.
    task_template = get_serializable(OrderedDict(), settings, athena_task).template
    custom = task_template.custom
    for fragment in ("{{ .rawOutputDataPrefix", "insert overwrite directory"):
        assert fragment in custom["statement"]
    expected_custom = (
        ("schema", "mnist"),
        ("catalog", "my_catalog"),
        ("routingGroup", "my_wg"),
    )
    for key, expected_value in expected_custom:
        assert expected_value == custom[key]
    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    # Workflow-level checks: a fresh cache is passed so the workflow is serialized independently.
    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"