Пример #1
0
def test_pod_task_serialized():
    """Serialize a pod task and check its task type version and config.

    The task is configured with the primary container name
    "an undefined container"; the serialized entity must report task type
    version 1 and carry that same name in its config.
    """
    container_name = "an undefined container"
    pod_config = Pod(pod_spec=get_pod_spec(),
                     primary_container_name=container_name)

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    # The decorator must produce the pod-specific task class and keep the
    # config object it was given.
    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=image_config,
    )

    entity = get_serializable(OrderedDict(), settings, simple_pod_task)

    assert entity.task_type_version == 1
    assert entity.config["primary_container_name"] == container_name
Пример #2
0
def test_pod_task():
    """Build the Kubernetes pod for a pod task and inspect its pod spec.

    Checks the restart policy, the number and names of the containers, and the
    primary container's args, first volume mount, resources and environment.
    """
    pod_config = Pod(pod_spec=get_pod_spec(),
                     primary_container_name="a container")

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    pod_spec = simple_pod_task.get_k8s_pod(settings).pod_spec

    # Expected values, kept together so the assertions below read linearly.
    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    expected_resources = {
        "requests": {"cpu": "10"},
        "limits": {"gpu": "2"},
    }

    assert pod_spec["restartPolicy"] == "OnFailure"

    containers = pod_spec["containers"]
    assert len(containers) == 2

    primary = containers[0]
    assert primary["name"] == "a container"
    assert primary["args"] == expected_args

    first_mount = primary["volumeMounts"][0]
    assert first_mount["mountPath"] == "some/where"
    assert first_mount["name"] == "volume mount"

    assert primary["resources"] == expected_resources
    # The expected env carries the task's FOO=bar, not the settings' FOO=baz.
    assert primary["env"] == [{"name": "FOO", "value": "bar"}]

    assert containers[1]["name"] == "another container"
Пример #3
0
def test_fast_pod_task_serialization():
    """Serialize a pod task with fast serialization enabled.

    The first container's args in the serialized pod spec must be the
    ``pyflyte-fast-execute`` wrapper followed by the regular
    ``pyflyte-execute`` command line.
    """
    primary_only_spec = V1PodSpec(
        restart_policy="OnFailure",
        containers=[V1Container(name="primary")],
    )
    pod_config = Pod(pod_spec=primary_only_spec,
                     primary_container_name="primary")

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    entity = get_serializable(OrderedDict(), settings, simple_pod_task)

    # Fast-execute prefix, up to and including the "--" separator.
    fast_prefix = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
    ]
    # Regular execute command that follows the prefix.
    execute_command = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    first_container = entity.template.k8s_pod.pod_spec["containers"][0]
    assert first_container["args"] == fast_prefix + execute_command
Пример #4
0
def test_map_pod_task_serialization():
    """Wrap a pod task in ``map_task`` and check the pod it produces.

    The mapped task's pod spec must contain a single container whose args
    start with ``pyflyte-map-execute``, and its config must name "primary" as
    the primary container.
    """
    primary_only_spec = V1PodSpec(
        restart_policy="OnFailure",
        containers=[V1Container(name="primary")],
    )
    pod_config = Pod(pod_spec=primary_only_spec,
                     primary_container_name="primary")

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    expected_config = {"primary_container_name": "primary"}

    # The target must be serialized with the map-specific command.
    containers = mapped_task.get_k8s_pod(serialization_settings).pod_spec[
        "containers"]
    assert len(containers) == 1
    assert containers[0]["args"] == expected_args

    assert expected_config == mapped_task.get_config(serialization_settings)
Пример #5
0
def test_pod_task_undefined_primary():
    """Check custom and config for the primary name "an undefined container".

    The custom payload is expected to hold three containers, with the one
    named "an undefined container" at index 2, and the task config must report
    that same name as the primary container.
    """
    container_name = "an undefined container"
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name=container_name)

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    # One settings object serves both get_custom() and get_config(); the
    # two calls previously rebuilt an identical instance each time.
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    custom = simple_pod_task.get_custom(settings)
    assert len(custom["containers"]) == 3

    primary_container = custom["containers"][2]
    assert primary_container["name"] == container_name

    config = simple_pod_task.get_config(settings)
    assert config["primary_container_name"] == container_name
Пример #6
0
def test_pod_task_deserialization():
    """Round-trip a pod task's custom payload through the Kubernetes client.

    The custom dict is JSON-encoded, deserialized into a ``V1PodSpec`` with
    ``ApiClient().deserialize`` and then checked field by field. Finally the
    task config must name "a container" as the primary container.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    # One settings object serves both get_custom() and get_config(); the
    # two calls previously rebuilt an identical instance each time.
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    custom = simple_pod_task.get_custom(settings)

    # Test that custom is correctly serialized by deserializing it with the
    # python API client. deserialize() is handed a mock response object whose
    # ``data`` attribute holds the JSON text.
    response = MagicMock()
    response.data = json.dumps(custom)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)

    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2

    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"
    assert primary_container.args == expected_args
    assert primary_container.volume_mounts[0].mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(
        limits={"gpu": "2"}, requests={"cpu": "10"})
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(settings)
    assert config["primary_container_name"] == "a container"
Пример #7
0
def test_dynamic_pod_task():
    """Check a ``@dynamic`` task configured with a pod task config.

    Covers three things: the custom payload (two containers, with the
    requested resources on the first one), the config's primary container
    name, and compiling the dynamic task with ``a=5``, which must yield five
    nodes.
    """
    dynamic_pod = Pod(pod_spec=get_pod_spec(),
                      primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    # One settings object serves both get_custom() and get_config(); the
    # two calls previously rebuilt an identical instance each time.
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    custom = dynamic_pod_task.get_custom(settings)
    assert len(custom["containers"]) == 2
    primary_container = custom["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    assert primary_container["resources"] == {
        "requests": {
            "cpu": "10"
        },
        "limits": {
            "gpu": "2"
        },
    }

    config = dynamic_pod_task.get_config(settings)
    assert config["primary_container_name"] == "a container"

    # Compile under a separate set of serialization settings. The two context
    # managers get distinct names so the inner one no longer shadows the
    # outer one.
    current_ctx = context_manager.FlyteContext.current_context()
    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(
            Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with current_ctx.new_serialization_settings(
            serialization_settings=compile_settings) as serialization_ctx:
        with serialization_ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION) as execution_ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                execution_ctx, dynamic_pod_task._task_function, a=5)
            # One node per t1 call made in the loop for a=5.
            assert len(dynamic_job_spec._nodes) == 5