예제 #1
0
import typing
from collections import OrderedDict

import mock
import pytest

import flytekit.configuration
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.models.core.workflow import Node
from flytekit.tools.translator import get_serializable

# Shared image/settings used by the tasks in this example.
default_img = Image(fqn="test", name="default", tag="tag")
serialization_settings = SerializationSettings(
    # default_img is both the default image and the sole entry in `images`.
    image_config=ImageConfig(images=[default_img], default_image=default_img),
    env=None,
    version="version",
    domain="domain",
    project="project",
)


@task
def five() -> int:
    # Constant-producing task: always yields the integer 5.
    value = 5
    return value


@task
def square(n: float) -> float:
    """
def default_image_config():
    """Return an ImageConfig whose default image has fqn ``docker.io/xyz`` and tag ``some-git-hash``."""
    return ImageConfig(
        default_image=Image(name="default", fqn="docker.io/xyz", tag="some-git-hash")
    )
예제 #3
0
    protocol_prefix,
)

# Named column types built with kwtypes; keyword order is significant.
my_cols = kwtypes(
    w=typing.Dict[str, typing.Dict[str, int]],
    x=typing.List[typing.List[int]],
    y=int,
    z=str,
)

# A minimal two-column Arrow schema: an int32 column and a string column.
fields = [
    ("some_int", pa.int32()),
    ("some_string", pa.string()),
]
arrow_schema = pa.schema(fields)

serialization_settings = flytekit.configuration.SerializationSettings(
    env={},
    image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
    version="123",
    domain="dom",
    project="proj",
)


def test_protocol():
    """protocol_prefix returns "s3" for an s3:// URI and "/" for an absolute local path."""
    expected_by_uri = {
        "s3://my-s3-bucket/file": "s3",
        "/file": "/",
    }
    for uri, expected in expected_by_uri.items():
        assert protocol_prefix(uri) == expected


def generate_pandas() -> pd.DataFrame:
    """Build a small two-row sample DataFrame with ``name`` and ``age`` columns."""
    columns = {
        "name": ["Tom", "Joseph"],
        "age": [20, 22],
    }
    return pd.DataFrame(columns)


def test_types_pandas():
    """Type-handling test for pandas DataFrames.

    NOTE(review): the body here looks truncated. It only binds the DataFrame
    class to a local and asserts nothing; confirm against the upstream test.
    """
    pt = pd.DataFrame
예제 #4
0
def test_normal_task():
    """Exercise ``create_node`` inside workflows.

    Checks that:
    - nodes built with ``create_node`` execute locally and serialize with the
      expected number of nodes and outputs;
    - ``runs_before`` and ``>>`` both record t3's node as upstream of t2's;
    - passing a positional argument to ``create_node`` raises ``FlyteAssertion``.
    """

    @task
    def t1(a: str) -> str:
        return a + " world"

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=str(i)))
        return s

    @workflow
    def my_wf(a: str) -> (str, typing.List[str]):
        t1_node = create_node(t1, a=a)
        dyn_node = create_node(my_subwf, a=3)
        return t1_node.o0, dyn_node.o0

    # Local execution of the workflow.
    r, x = my_wf(a="hello")
    assert r == "hello world"
    assert x == ["0 world", "1 world", "2 world"]

    # A single settings object serves every get_serializable call in this test.
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert len(wf_spec.template.nodes) == 2
    assert len(wf_spec.template.outputs) == 2

    @task
    def t2():
        ...

    @task
    def t3():
        ...

    @workflow
    def empty_wf():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node.runs_before(t2_node)

    # Test that VoidPromises can handle runs_before
    empty_wf()

    @workflow
    def empty_wf2():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        # Operator form; expected to record the same dependency as runs_before.
        t3_node >> t2_node

    # In both workflows t2's node (n0) must list t3's node (n1) as upstream.
    for wf in (empty_wf, empty_wf2):
        wf_spec = get_serializable(OrderedDict(), serialization_settings, wf)
        assert wf_spec.template.nodes[0].upstream_node_ids[0] == "n1"
        assert wf_spec.template.nodes[0].id == "n0"

    # An extra positional argument to create_node must be rejected.
    with pytest.raises(FlyteAssertion):

        @workflow
        def empty_wf2():
            create_node(t2, "foo")
예제 #5
0
def test_fast():
    """Fast-serialize a dynamic task whose subtask is a pod task.

    The compiled subtask's primary container must carry the
    pyflyte-fast-execute entrypoint and the requested resources/limits.
    """
    requests_gpu = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    limits_gpu = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        # A pod whose only container is the primary "flytetask" container.
        return Pod(
            pod_spec=V1PodSpec(containers=[V1Container(name="flytetask")]),
            primary_container_name="flytetask",
        )

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=requests_gpu,
        limits=limits_gpu,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=requests_gpu, limits=limits_gpu)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/User/flyte/workflows",
        distribution_location="s3://my-s3-bucket/fast/123",
    )
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=fast_settings,
    )

    fcm = context_manager.FlyteContextManager
    serialization_ctx = fcm.current_context().with_serialization_settings(serialization_settings)
    with fcm.with_context(serialization_ctx) as outer_ctx:
        task_exec_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with fcm.with_context(outer_ctx.with_execution_state(task_exec_state)) as task_ctx:
            inputs = TypeEngine.dict_to_literal_map(task_ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(task_ctx, inputs)

            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1

            primary = dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
            assert " ".join(primary["args"]).startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows")
            assert primary["resources"]["limits"]["cpu"] == "124M"
            assert primary["resources"]["requests"]["gpu"] == "1"

    # Both pushed contexts must have been popped again.
    assert fcm.size() == 1
예제 #6
0
def test_pod_task_deserialization():
    """A pod task's serialized pod spec deserializes cleanly into a Kubernetes ``V1PodSpec``.

    The pod spec is round-tripped through ``ApiClient().deserialize``. The test
    then checks the containers, entrypoint args, volume mounts, resources and
    environment, and the ``primary_container_name`` in the task config.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(task_config=pod,
          requests=Resources(cpu="10"),
          limits=Resources(gpu="2"),
          environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    # One settings instance drives both get_k8s_pod and get_config below.
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    target = simple_pod_task.get_k8s_pod(settings)

    # Test that custom is correctly serialized by deserializing it with the python API client
    response = MagicMock()
    response.data = json.dumps(target.pod_spec)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)

    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2
    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"
    assert primary_container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container.volume_mounts[0].mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(
        limits={"gpu": "2"}, requests={"cpu": "10"})
    # The task's environment (FOO=bar) is expected, not the settings' env (FOO=baz).
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(settings)
    assert config["primary_container_name"] == "a container"
예제 #7
0
def test_dynamic_pod_task():
    """A ``@dynamic`` task configured with a Pod behaves like a pod task.

    Checks the serialized pod spec, the resources and the task config. It then
    compiles the dynamic body and expects one node per loop iteration.
    """
    dynamic_pod = Pod(pod_spec=get_pod_spec(),
                      primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    assert isinstance(dynamic_pod_task, PodFunctionTask)
    default_img = Image(name="default", fqn="test", tag="tag")
    # One settings instance drives both get_k8s_pod and get_config below.
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    pod_spec = dynamic_pod_task.get_k8s_pod(settings).pod_spec
    assert len(pod_spec["containers"]) == 2
    primary_container = pod_spec["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    # Resources(ephemeral_storage=...) is expected under the key "ephemeral-storage".
    assert primary_container["resources"] == {
        "requests": {
            "cpu": "10"
        },
        "limits": {
            "ephemeral-storage": "1Gi",
            "gpu": "2"
        },
    }

    config = dynamic_pod_task.get_config(settings)
    assert config["primary_container_name"] == "a container"

    # Compile the dynamic body under task-execution mode, the same way test_fast sets up its context.
    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(
            Image(name="name", fqn="image", tag="name")),
        env={},
    )
    fcm = context_manager.FlyteContextManager
    with fcm.with_context(
            fcm.current_context().with_serialization_settings(compile_settings)) as ctx:
        with fcm.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        mode=ExecutionState.Mode.TASK_EXECUTION))) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                ctx, dynamic_pod_task._task_function, a=5)
            # a=5 -> five t1 calls -> five nodes.
            assert len(dynamic_job_spec._nodes) == 5
예제 #8
0
def test_pod_task():
    """Serialize a pod task and check the resulting pod spec dict.

    Asserts the restartPolicy, that the spec has two containers ("a container"
    as primary and "another container" second), and that the primary container
    carries the pyflyte-execute args, volume mounts, resources and the task's
    environment.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")

    # The settings env (FOO=baz) deliberately differs from the task's environment (FOO=bar).
    pod_spec = simple_pod_task.get_k8s_pod(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img,
                                     images=[default_img]),
        )).pod_spec

    assert pod_spec["restartPolicy"] == "OnFailure"
    assert len(pod_spec["containers"]) == 2
    primary_container = pod_spec["containers"][0]
    assert primary_container["name"] == "a container"
    # NOTE(review): "tests.test_pod" presumably reflects this file's module path;
    # update it if the test file moves.
    assert primary_container["args"] == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container["volumeMounts"][0]["mountPath"] == "some/where"
    assert primary_container["volumeMounts"][0]["name"] == "volume mount"
    # Resources(ephemeral_storage=...) is expected under the key "ephemeral-storage".
    assert primary_container["resources"] == {
        "requests": {
            "cpu": "10"
        },
        "limits": {
            "ephemeral-storage": "1Gi",
            "gpu": "2"
        },
    }
    # The task-level environment is expected here, not the SerializationSettings env.
    assert primary_container["env"] == [{"name": "FOO", "value": "bar"}]
    assert pod_spec["containers"][1]["name"] == "another container"