Пример #1
0
def test_dynamic_pod_task():
    """Check that a ``@dynamic`` task configured with a ``Pod`` serializes and compiles.

    The test asserts that the decorated function is a ``PodFunctionTask``, that
    the custom payload holds two containers with the requested resources on the
    first one, and that compiling with ``a=5`` produces five nodes.
    """
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = dynamic_pod_task.get_custom(settings)

    containers = custom["podSpec"]["containers"]
    assert len(containers) == 2
    primary_container = containers[0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    expected_resources = {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }
    assert primary_container["resources"] == expected_resources

    # Compile the dynamic task inside a task-execution context.
    current = context_manager.FlyteContext.current_context()
    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with current.new_serialization_settings(serialization_settings=compile_settings) as ser_ctx:
        with ser_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            job_spec = dynamic_pod_task.compile_into_workflow(
                exec_ctx, dynamic_pod_task._task_function, a=5
            )
            assert len(job_spec._nodes) == 5
Пример #2
0
def test_pod_task():
    """Check the custom payload produced by a ``@task`` configured with a ``Pod``.

    Asserts the restart policy, the container count, and the primary
    container's name, args, volume mounts, resources and environment.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = simple_pod_task.get_custom(settings)

    assert custom["podSpec"]["restartPolicy"] == "OnFailure"
    assert len(custom["podSpec"]["containers"]) == 2

    primary_container = custom["podSpec"]["containers"][0]
    assert primary_container["name"] == "a container"

    expected_args = [
        "pyflyte-execute",
        "--task-module",
        "pod.test_pod",
        "--task-name",
        "simple_pod_task",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
    ]
    assert primary_container["args"] == expected_args

    expected_mounts = [{"mountPath": "some/where", "name": "volume mount"}]
    assert primary_container["volumeMounts"] == expected_mounts

    expected_resources = {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }
    assert primary_container["resources"] == expected_resources

    # The task-level environment is asserted here, not the serialization env.
    assert primary_container["env"] == [{"name": "FOO", "value": "bar"}]
    assert custom["podSpec"]["containers"][1]["name"] == "another container"
    assert custom["primaryContainerName"] == "a container"
Пример #3
0
def test_pod_task_serialized():
    """Check the task type version and config of a serialized pod task.

    The primary container name used here ("an undefined container") is
    expected to be echoed back in the serialized config.
    """
    primary_name = "an undefined container"
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name=primary_name)

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=image_config,
    )

    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.task_type_version == 1
    assert serialized.config["primary_container_name"] == "an undefined container"
Пример #4
0
def test_query_no_inputs_or_outputs():
    """Check that a ``HiveTask`` with no inputs and no output schema serializes.

    The serialized task must expose an empty interface, and a workflow that
    calls the task must serialize without raising.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    sdk_task = get_serializable(serialization_settings, hive_task)
    interface = sdk_task.interface
    assert len(interface.inputs) == 0
    assert len(sdk_task.interface.outputs) == 0

    # Only checks that serializing the workflow does not raise.
    get_serializable(serialization_settings, my_wf)
Пример #5
0
def test_tensorflow_task():
    """Check local execution and the custom payload of a ``TfJob`` task.

    Asserts the local call result, the replica counts returned by
    ``get_custom``, the task resources and the task type.
    """
    tf_config = TfJob(
        num_workers=10,
        per_replica_requests=Resources(cpu="1"),
        num_ps_replicas=1,
        num_chief_replicas=1,
    )

    @task(task_config=tf_config, cache=True, cache_version="1")
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_custom = {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.get_custom(settings) == expected_custom

    task_resources = my_tensorflow_task.resources
    assert task_resources.limits == Resources()
    assert my_tensorflow_task.resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
Пример #6
0
def test_spark_task():
    """Check the custom payload and the ``pre_execute`` step of a ``Spark`` task.

    After ``pre_execute`` the returned parameters must carry a
    ``SPARK_SESSION`` attribute, and the task's session configuration must
    contain both the user-supplied conf and the app name.
    """
    spark_conf = {"spark": "1"}

    @task(task_config=Spark(spark_conf=spark_conf))
    def my_spark(a: str) -> int:
        spark = flytekit.current_context().spark_session
        assert spark.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == {"spark": "1"}

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    retrieved_settings = my_spark.get_custom(settings)
    assert retrieved_settings["sparkConf"] == {"spark": "1"}

    # Build execution parameters by hand and run the pre-execute step.
    builder = ExecutionParameters.new_builder()
    builder.working_dir = "/tmp"
    builder.execution_id = "ex:local:local:local"
    params = builder.build()
    updated_params = my_spark.pre_execute(params)
    assert updated_params is not None
    assert updated_params.has_attr("SPARK_SESSION")

    assert my_spark.sess is not None
    session_conf = my_spark.sess.sparkContext.getConf().getAll()
    assert ("spark", "1") in session_conf
    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in session_conf
Пример #7
0
def test_serialization():
    """Check serialization of an ``AthenaTask`` and of a workflow that calls it.

    Asserts the statement, schema, catalog and routing group stored in the
    task's custom payload, the interface sizes, and the workflow's output
    binding.
    """
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    template = task_spec.template
    assert "{{ .rawOutputDataPrefix" in template.custom["statement"]
    assert "insert overwrite directory" in template.custom["statement"]
    assert "mnist" == template.custom["schema"]
    assert "my_catalog" == template.custom["catalog"]
    assert "my_wg" == template.custom["routingGroup"]
    assert len(template.interface.inputs) == 1
    assert len(template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    assert wf_template.outputs[0].var == "o0"
    output_promise = wf_template.outputs[0].binding.promise
    assert output_promise.node_id == "n0"
    assert wf_template.outputs[0].binding.promise.var == "results"
Пример #8
0
def test_serialization():
    """Check serialization of a ``HiveTask`` and of a workflow that calls it.

    Asserts the query text stored in the task's custom payload, the
    interface sizes, and the workflow's output binding.
    """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    template = task_spec.template
    assert "{{ .rawOutputDataPrefix" in template.custom["query"]["query"]
    assert "insert overwrite directory" in template.custom["query"]["query"]
    assert len(template.interface.inputs) == 2
    assert len(template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    assert wf_template.outputs[0].var == "o0"
    output_promise = wf_template.outputs[0].binding.promise
    assert output_promise.node_id == "n0"
    assert wf_template.outputs[0].binding.promise.var == "results"
Пример #9
0
def test_fast_pod_task_serialization():
    """Check the container args of a pod task serialized with fast serialization enabled.

    The first container's args are expected to begin with the
    ``pyflyte-fast-execute`` prefix followed by the ``pyflyte-execute`` command.
    """
    pod_spec = V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")])
    pod = Pod(pod_spec=pod_spec, primary_container_name="primary")

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), serialization_settings, simple_pod_task)

    expected_args = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    first_container = serialized.template.k8s_pod.pod_spec["containers"][0]
    assert first_container["args"] == expected_args
Пример #10
0
def test_map_pod_task_serialization():
    """Check the pod spec and config of a pod task wrapped with ``map_task``.

    The single container's args are expected to use the
    ``pyflyte-map-execute`` command, and the config must carry the primary
    container name.
    """
    pod_spec_in = V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")])
    pod = Pod(pod_spec=pod_spec_in, primary_container_name="primary")

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Test that target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(serialization_settings).pod_spec

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert len(pod_spec["containers"]) == 1
    assert pod_spec["containers"][0]["args"] == expected_args

    expected_config = {"primary_container_name": "primary"}
    assert expected_config == mapped_task.get_config(serialization_settings)
Пример #11
0
def test_pod_task_undefined_primary():
    """Check a pod task whose primary container name is not in the supplied pod spec.

    The custom payload is expected to contain three containers, the last of
    which carries the primary container name; the config must report the
    same name.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")

    def build_settings():
        # A fresh settings object is built for every call, as in the original test.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=image, images=[image]),
        )

    custom = simple_pod_task.get_custom(build_settings())

    assert len(custom["containers"]) == 3

    primary_container = custom["containers"][2]
    assert primary_container["name"] == "an undefined container"

    config = simple_pod_task.get_config(build_settings())
    assert config["primary_container_name"] == "an undefined container"
Пример #12
0
def test_spark_task():
    """Check that a ``Spark`` task keeps its conf and exposes it via ``get_custom``."""
    spark_conf = {"spark": "1"}

    @task(task_config=Spark(spark_conf=spark_conf))
    def my_spark(a: str) -> int:
        spark = flytekit.current_context().spark_session
        assert spark.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == {"spark": "1"}

    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=image_config,
    )

    retrieved_settings = my_spark.get_custom(settings)
    assert retrieved_settings["sparkConf"] == {"spark": "1"}
Пример #13
0
def test_pod_task_deserialization():
    """Check that a pod task's custom payload deserializes into a ``V1PodSpec``.

    The payload is dumped to JSON, fed through ``ApiClient().deserialize`` via
    a mocked response, and the resulting pod spec is asserted field by field.
    The task config is then checked for the primary container name.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")

    def build_settings():
        # A fresh settings object is built for every call, as in the original test.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=image, images=[image]),
        )

    custom = simple_pod_task.get_custom(build_settings())

    # Test that custom is correctly serialized by deserializing it with the python API client
    response = MagicMock()
    response.data = json.dumps(custom)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)

    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2

    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"

    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container.args == expected_args

    first_mount = primary_container.volume_mounts[0]
    assert first_mount.mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(
        limits={"gpu": "2"}, requests={"cpu": "10"}
    )
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(build_settings())
    assert config["primary_container_name"] == "a container"