def test_pod_task():
    """Verify the ``get_custom`` payload of a pod task built from ``get_pod_spec()``.

    The test checks the restart policy, the two containers, the primary
    container's args / volume mounts / resources / env, and the
    ``primaryContainerName`` entry.

    NOTE(review): a second ``test_pod_task`` is defined later in this file.
    A later module-level definition rebinds the name, so this version would
    not be collected by pytest unless one of the two is renamed.
    """
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = simple_pod_task.get_custom(settings)

    assert custom["podSpec"]["restartPolicy"] == "OnFailure"
    containers = custom["podSpec"]["containers"]
    assert len(containers) == 2

    # The primary container is expected first in the serialized spec.
    primary = containers[0]
    assert primary["name"] == "a container"

    expected_args = [
        "pyflyte-execute",
        "--task-module",
        "pod.test_pod",
        "--task-name",
        "simple_pod_task",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
    ]
    assert primary["args"] == expected_args

    assert primary["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}]
    assert primary["resources"] == {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }
    # The task-level environment ("bar") is what the container is expected to
    # carry, not the value from the serialization settings ("baz").
    assert primary["env"] == [{"name": "FOO", "value": "bar"}]

    assert custom["podSpec"]["containers"][1]["name"] == "another container"
    assert custom["primaryContainerName"] == "a container"
def test_pod_task_serialized():
    """Verify the serialized template of a pod task carrying labels and annotations.

    Serializes the task through ``get_serializable`` and checks the task type
    version, the primary container name in the template config, and the pod
    metadata (labels / annotations) plus the presence of a pod spec.
    """
    pod_config = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    template = get_serializable(OrderedDict(), settings, simple_pod_task).template

    assert template.task_type_version == 2
    assert template.config["primary_container_name"] == "an undefined container"
    assert template.k8s_pod.metadata.labels == {"label": "foo"}
    assert template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert template.k8s_pod.pod_spec is not None
def test_fast_pod_task_serialization():
    """Verify the container args of a pod task serialized with fast serialization enabled.

    The expected command wraps ``pyflyte-execute`` in ``pyflyte-fast-execute``
    with the remote package path and destination directory placeholders.
    """
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )

    expected_args = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    serialized = get_serializable(OrderedDict(), settings, simple_pod_task)
    assert serialized.template.k8s_pod.pod_spec["containers"][0]["args"] == expected_args
def test_map_pod_task_serialization():
    """Verify the pod spec and config produced when a pod task is wrapped by ``map_task``.

    The single container is expected to run ``pyflyte-map-execute`` and the
    config is expected to hold only the primary container name.
    """
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Test that target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(settings).pod_spec
    assert len(pod_spec["containers"]) == 1

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert pod_spec["containers"][0]["args"] == expected_args

    config = mapped_task.get_config(settings)
    assert {"primary_container_name": "primary"} == config
def test_dynamic_pod_task():
    """Verify that a ``@dynamic`` task configured with a ``Pod`` behaves as a pod task.

    First checks the ``get_custom`` payload (two containers, resources on the
    primary one), then compiles the dynamic task with ``a=5`` inside a task
    execution context and expects five nodes in the resulting job spec.
    """
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    image = Image(name="default", fqn="test", tag="tag")
    custom_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    custom = dynamic_pod_task.get_custom(custom_settings)

    assert len(custom["podSpec"]["containers"]) == 2
    primary = custom["podSpec"]["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    assert primary["resources"] == {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }

    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    current = context_manager.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=compile_settings) as ser_ctx:
        # Compilation runs under a task-execution context layered on top of
        # the serialization context.
        with ser_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                exec_ctx, dynamic_pod_task._task_function, a=5
            )
            assert len(dynamic_job_spec._nodes) == 5
def test_pod_task_undefined_primary():
    """Verify pod serialization when the primary container name is not in the pod spec.

    The serialized spec is expected to hold three containers with the primary
    one last (presumably appended because ``get_pod_spec()`` does not define a
    container with that name; confirm against the helper), and the config is
    expected to report the same primary container name.
    """
    primary_name = "an undefined container"
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name=primary_name)

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")

    def make_settings():
        # Fresh, identical settings for each call, as in the original test.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=image, images=[image]),
        )

    pod_spec = simple_pod_task.get_k8s_pod(make_settings()).pod_spec
    assert len(pod_spec["containers"]) == 3

    primary = pod_spec["containers"][2]
    assert primary["name"] == "an undefined container"

    config = simple_pod_task.get_config(make_settings())
    assert config["primary_container_name"] == "an undefined container"
def get_minimal_pod_task_config() -> Pod:
    """Return a ``Pod`` config whose spec holds one container named ``"flytetask"``.

    The same name is used as the primary container name.
    """
    container_name = "flytetask"
    return Pod(
        pod_spec=V1PodSpec(containers=[V1Container(name=container_name)]),
        primary_container_name=container_name,
    )
def test_pod_task_deserialization():
    """Verify that a serialized pod spec round-trips through the Kubernetes API client.

    The pod spec from ``get_k8s_pod`` is dumped to JSON, deserialized into a
    ``V1PodSpec`` via ``ApiClient().deserialize``, and its containers, args,
    volume mounts, resources and env are checked. The task config is checked
    for the primary container name as well.
    """
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")

    def make_settings():
        # Fresh, identical settings for each call, as in the original test.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=image, images=[image]),
        )

    target = simple_pod_task.get_k8s_pod(make_settings())

    # Test that custom is correctly serialized by deserializing it with the python API client
    # ``deserialize`` reads the JSON payload from the response's ``data``
    # attribute, so a mock carrying only that attribute is enough.
    fake_response = MagicMock()
    fake_response.data = json.dumps(target.pod_spec)
    restored_spec = ApiClient().deserialize(fake_response, V1PodSpec)

    assert restored_spec.restart_policy == "OnFailure"
    assert len(restored_spec.containers) == 2

    primary = restored_spec.containers[0]
    assert primary.name == "a container"

    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary.args == expected_args

    assert primary.volume_mounts[0].mount_path == "some/where"
    assert primary.volume_mounts[0].name == "volume mount"
    assert primary.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"})
    assert primary.env == [V1EnvVar(name="FOO", value="bar")]
    assert restored_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(make_settings())
    assert config["primary_container_name"] == "a container"
def test_pod_task():
    """Verify the ``get_k8s_pod`` pod spec of a pod task built from ``get_pod_spec()``.

    Checks the restart policy, the two containers, and the primary
    container's args, volume mounts, resources (including ephemeral storage)
    and env.

    NOTE(review): an earlier ``test_pod_task`` is defined in this file; this
    later definition rebinds the name, so only this version would be
    collected by pytest unless one of the two is renamed.
    """
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    pod_spec = simple_pod_task.get_k8s_pod(settings).pod_spec

    assert pod_spec["restartPolicy"] == "OnFailure"
    assert len(pod_spec["containers"]) == 2

    primary = pod_spec["containers"][0]
    assert primary["name"] == "a container"

    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary["args"] == expected_args

    first_mount = primary["volumeMounts"][0]
    assert first_mount["mountPath"] == "some/where"
    assert first_mount["name"] == "volume mount"

    assert primary["resources"] == {
        "requests": {"cpu": "10"},
        "limits": {"ephemeral-storage": "1Gi", "gpu": "2"},
    }
    assert primary["env"] == [{"name": "FOO", "value": "bar"}]
    assert pod_spec["containers"][1]["name"] == "another container"