import json
from collections import OrderedDict
from typing import List
from unittest.mock import MagicMock

# NOTE: exact import paths vary across flytekit releases; these assume the
# 0.x-era layout that these tests (e.g. the "plugins.tests.pod.test_pod"
# task-module path asserted below) were written against.
import flytekit
from flytekit import Resources, TaskMetadata, dynamic, kwtypes, map_task, task, workflow
from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.context_manager import (
    ExecutionParameters,
    ExecutionState,
    FastSerializationSettings,
    Image,
    ImageConfig,
    SerializationSettings,
)
from flytekit.types.schema import FlyteSchema
from flytekitplugins.athena import AthenaConfig, AthenaTask
from flytekitplugins.hive import HiveConfig, HiveTask
from flytekitplugins.kftensorflow import TfJob
from flytekitplugins.pod import Pod
from flytekitplugins.pod.task import PodFunctionTask
from flytekitplugins.spark import Spark
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount


def test_dynamic_pod_task():
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    assert isinstance(dynamic_pod_task, PodFunctionTask)
    default_img = Image(name="default", fqn="test", tag="tag")
    custom = dynamic_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert len(custom["podSpec"]["containers"]) == 2
    primary_container = custom["podSpec"]["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    assert primary_container["resources"] == {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }

    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
            assert len(dynamic_job_spec._nodes) == 5
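# The pod tests in this module call a shared get_pod_spec() helper that is not
# defined in this excerpt. The sketch below is an assumption reconstructed from
# the assertions in these tests (an "OnFailure" restart policy and two
# containers, with a volume mount on the primary); the real helper may differ.
def get_pod_spec():
    a_container = V1Container(
        name="a container",
        volume_mounts=[V1VolumeMount(mount_path="some/where", name="volume mount")],
    )
    another_container = V1Container(name="another container")
    return V1PodSpec(restart_policy="OnFailure", containers=[a_container, another_container])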
def test_pod_task():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    custom = simple_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert custom["podSpec"]["restartPolicy"] == "OnFailure"
    assert len(custom["podSpec"]["containers"]) == 2
    primary_container = custom["podSpec"]["containers"][0]
    assert primary_container["name"] == "a container"
    assert primary_container["args"] == [
        "pyflyte-execute",
        "--task-module",
        "pod.test_pod",
        "--task-name",
        "simple_pod_task",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
    ]
    assert primary_container["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}]
    assert primary_container["resources"] == {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }
    assert primary_container["env"] == [{"name": "FOO", "value": "bar"}]
    assert custom["podSpec"]["containers"][1]["name"] == "another container"
    assert custom["primaryContainerName"] == "a container"
def test_pod_task_serialized():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.task_type_version == 1
    assert serialized.config["primary_container_name"] == "an undefined container"
def test_query_no_inputs_or_outputs():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert len(task_spec.template.interface.inputs) == 0
    assert len(task_spec.template.interface.outputs) == 0

    get_serializable(OrderedDict(), serialization_settings, my_wf)
def test_tensorflow_task():
    @task(
        task_config=TfJob(
            num_workers=10, per_replica_requests=Resources(cpu="1"), num_ps_replicas=1, num_chief_replicas=1
        ),
        cache=True,
        cache_version="1",
    )
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.resources.limits == Resources()
    assert my_tensorflow_task.resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
def test_spark_task():
    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == {"spark": "1"}

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    retrieved_settings = my_spark.get_custom(settings)
    assert retrieved_settings["sparkConf"] == {"spark": "1"}

    pb = ExecutionParameters.new_builder()
    pb.working_dir = "/tmp"
    pb.execution_id = "ex:local:local:local"
    p = pb.build()
    new_p = my_spark.pre_execute(p)
    assert new_p is not None
    assert new_p.has_attr("SPARK_SESSION")

    assert my_spark.sess is not None
    configs = my_spark.sess.sparkContext.getConf().getAll()
    assert ("spark", "1") in configs
    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs
def test_pod_task_undefined_primary():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    custom = simple_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert len(custom["containers"]) == 3
    primary_container = custom["containers"][2]
    assert primary_container["name"] == "an undefined container"

    config = simple_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert config["primary_container_name"] == "an undefined container"
def test_serialization():
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]
    assert "mnist" == task_spec.template.custom["schema"]
    assert "my_catalog" == task_spec.template.custom["catalog"]
    assert "my_wg" == task_spec.template.custom["routingGroup"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_serialization():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_fast_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), serialization_settings, simple_pod_task)
    assert serialized.template.k8s_pod.pod_spec["containers"][0]["args"] == [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
def test_map_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    # Test that target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(serialization_settings).pod_spec
    assert len(pod_spec["containers"]) == 1
    assert pod_spec["containers"][0]["args"] == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert {"primary_container_name": "primary"} == mapped_task.get_config(serialization_settings)
def test_pod_task_deserialization():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    custom = simple_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )

    # Test that custom is correctly serialized by deserializing it with the python API client
    response = MagicMock()
    response.data = json.dumps(custom)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)
    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2

    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"
    assert primary_container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.pod.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container.volume_mounts[0].mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"})
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert config["primary_container_name"] == "a container"
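# Usage note (an assumption, not confirmed by this excerpt): given the
# repository layout implied by the "plugins.tests.pod.test_pod" module path
# asserted above, individual tests can be run with pytest, e.g.:
#   pytest plugins/tests/pod/test_pod.py::test_pod_task_deserialization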