import typing from collections import OrderedDict import mock import pytest import flytekit.configuration from flytekit import task, workflow from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional from flytekit.models.core.workflow import Node from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( project="project", domain="domain", version="version", env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) @task def five() -> int: return 5 @task def square(n: float) -> float: """
def default_image_config():
    """Return an ImageConfig whose default image is named "default" (fqn docker.io/xyz, tag some-git-hash)."""
    image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash")
    config = ImageConfig(default_image=image)
    return config
protocol_prefix, ) my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) fields = [("some_int", pa.int32()), ("some_string", pa.string())] arrow_schema = pa.schema(fields) serialization_settings = flytekit.configuration.SerializationSettings( project="proj", domain="dom", version="123", image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")), env={}, ) def test_protocol(): assert protocol_prefix("s3://my-s3-bucket/file") == "s3" assert protocol_prefix("/file") == "/" def generate_pandas() -> pd.DataFrame: return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]}) def test_types_pandas(): pt = pd.DataFrame
def test_normal_task():
    """Exercise ``create_node`` with regular and dynamic tasks.

    Checks that:
    * node outputs (``.o0``) can be returned from a workflow, both locally and
      after serialization;
    * ``runs_before`` and ``>>`` both record an upstream dependency between
      output-less nodes;
    * passing a positional argument to ``create_node`` raises ``FlyteAssertion``.
    """

    @task
    def t1(a: str) -> str:
        return a + " world"

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=str(i)))
        return s

    @workflow
    def my_wf(a: str) -> (str, typing.List[str]):
        t1_node = create_node(t1, a=a)
        dyn_node = create_node(my_subwf, a=3)
        return t1_node.o0, dyn_node.o0

    r, x = my_wf(a="hello")
    assert r == "hello world"
    assert x == ["0 world", "1 world", "2 world"]

    # Built once and reused for every get_serializable call below.
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert len(wf_spec.template.nodes) == 2
    assert len(wf_spec.template.outputs) == 2

    @task
    def t2():
        ...

    @task
    def t3():
        ...

    @workflow
    def empty_wf():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node.runs_before(t2_node)

    # Test that VoidPromises can handle runs_before
    empty_wf()

    @workflow
    def empty_wf2():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node >> t2_node

    # Both workflows order t3 before t2, so node n0 must list n1 as its upstream.
    # Ids appear to follow create_node call order (t2 first); see the id assertions.
    wf_spec = get_serializable(OrderedDict(), serialization_settings, empty_wf)
    assert wf_spec.template.nodes[0].upstream_node_ids[0] == "n1"
    assert wf_spec.template.nodes[0].id == "n0"

    wf_spec = get_serializable(OrderedDict(), serialization_settings, empty_wf2)
    assert wf_spec.template.nodes[0].upstream_node_ids[0] == "n1"
    assert wf_spec.template.nodes[0].id == "n0"

    # create_node only accepts keyword inputs; a positional argument must be rejected.
    with pytest.raises(FlyteAssertion):

        @workflow
        def empty_wf2():
            create_node(t2, "foo")
def test_fast():
    """Check a fast-serialized ``@dynamic`` task that launches a pod subtask.

    Fast serialization is enabled and the context is put in TASK_EXECUTION
    mode. Executing ``dynamic_task_with_pod_subtask`` under that context must
    produce exactly one node and one task. That task's primary container must
    run ``pyflyte-fast-execute`` with the configured distribution location and
    destination dir, and must carry the declared resources.
    """
    REQUESTS_GPU = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    # NOTE(review): a cpu limit of "124M" looks like it may have been meant as
    # "124m"; the resource assertion below pins the string exactly as written.
    LIMITS_GPU = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        primary_container = V1Container(name="flytetask")
        pod_spec = V1PodSpec(containers=[primary_container])
        return Pod(pod_spec=pod_spec, primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1

            # Primary container of the single pod subtask emitted by the dynamic body.
            container = dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
            args = " ".join(container["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            assert container["resources"]["limits"]["cpu"] == "124M"
            assert container["resources"]["requests"]["gpu"] == "1"

    # Once both `with` blocks have exited, the context stack is expected back at size 1.
    assert context_manager.FlyteContextManager.size() == 1
def test_pod_task_deserialization():
    """Round-trip a Pod task's serialized pod spec through the python API client.

    The JSON-dumped ``pod_spec`` from ``get_k8s_pod`` is turned back into a
    ``V1PodSpec`` with ``ApiClient().deserialize``. The test then checks the
    restart policy, both containers, and the primary container's args, volume
    mounts, resources and env. Finally it checks that ``get_config`` reports
    the primary container name.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    # Shared by get_k8s_pod and get_config below.
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    target = simple_pod_task.get_k8s_pod(settings)

    # Test that custom is correctly serialized by deserializing it with the python API client.
    # ApiClient.deserialize takes a response-like object; only `.data` is populated here.
    response = MagicMock()
    response.data = json.dumps(target.pod_spec)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)

    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2
    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"
    assert primary_container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container.volume_mounts[0].mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"})
    # The task-level environment ({"FOO": "bar"}) is what the container is expected to carry.
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(settings)
    assert config["primary_container_name"] == "a container"
def test_dynamic_pod_task():
    """A ``@dynamic`` task with a Pod config serializes like a pod task and compiles its sub-nodes.

    Checks the following:
    * the serialized pod spec's container count;
    * the primary container's resources;
    * the primary container name reported by ``get_config``;
    * that compiling the dynamic body with ``a=5`` in TASK_EXECUTION mode
      yields five nodes.
    """
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    # Shared by get_k8s_pod and get_config below.
    pod_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    pod_spec = dynamic_pod_task.get_k8s_pod(pod_settings).pod_spec
    assert len(pod_spec["containers"]) == 2
    primary_container = pod_spec["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    assert primary_container["resources"] == {
        "requests": {"cpu": "10"},
        "limits": {"ephemeral-storage": "1Gi", "gpu": "2"},
    }

    config = dynamic_pod_task.get_config(pod_settings)
    assert config["primary_container_name"] == "a container"

    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # Enter the context through FlyteContextManager, matching how contexts are entered elsewhere.
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(compile_settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
        ) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
            # One t1 node per loop iteration for a=5.
            assert len(dynamic_job_spec._nodes) == 5
def test_pod_task():
    """Serialize a Pod-configured ``@task`` and inspect the resulting pod spec dict.

    Verifies the following:
    * the restart policy and the number of containers;
    * the primary container's name, execute args, first volume mount,
      resources and env;
    * the second container's name.
    """
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    spec = simple_pod_task.get_k8s_pod(settings).pod_spec
    assert spec["restartPolicy"] == "OnFailure"

    containers = spec["containers"]
    assert len(containers) == 2

    main = containers[0]
    assert main["name"] == "a container"

    # Command line the default resolver is expected to emit for this task.
    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert main["args"] == expected_args

    first_mount = main["volumeMounts"][0]
    assert first_mount["mountPath"] == "some/where"
    assert first_mount["name"] == "volume mount"

    expected_resources = {
        "requests": {"cpu": "10"},
        "limits": {"ephemeral-storage": "1Gi", "gpu": "2"},
    }
    assert main["resources"] == expected_resources
    assert main["env"] == [{"name": "FOO", "value": "bar"}]

    assert containers[1]["name"] == "another container"