def test_condition_tuple_branches():
    """A conditional branch that yields a NamedTuple serializes with the task reference intact."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte will only make `sum` and `sub` available as outputs because they are common between all branches
        res_sum, res_sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return res_sum, res_sub

    # Local execution takes the if_ branch for a=3 > b=2.
    out_sum, out_sub = math_ops(a=3, b=2)
    assert out_sum == 5
    assert out_sub == 1

    img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    sdk_wf = get_serializable(settings, math_ops)
    branch = sdk_wf.nodes[0].branch_node
    assert branch.if_else.case.then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_workflow_values():
    """Workflow-level metadata (interruptible flag, failure policy) survives serialization."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow(
        interruptible=True,
        failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )
    def wf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_wf = get_serializable(settings, wf)
    assert sdk_wf.metadata_defaults.interruptible
    # FAIL_AFTER_EXECUTABLE_NODES_COMPLETE serializes as enum value 1.
    assert sdk_wf.metadata.on_failure == 1
def test_serialization_branch():
    """A simple two-way conditional runs locally and serializes to a branch node."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return conditional("test1").if_(c.c == 4).then(t1(c=c.c).c).else_().then(t2().c)

    # Both branches are exercised locally.
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    wf = get_serializable(settings, my_wf)
    assert wf is not None
    assert len(wf.nodes) == 2
    # Node 0 is the mimic task; node 1 is the conditional.
    assert wf.nodes[1].branch_node is not None
def test_pod_task():
    """A Pod-configured task emits the user pod spec plus an injected primary container."""
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    img = Image(name="default", fqn="test", tag="tag")
    custom = simple_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=img, images=[img]),
        )
    )

    assert custom["podSpec"]["restartPolicy"] == "OnFailure"
    assert len(custom["podSpec"]["containers"]) == 2

    primary = custom["podSpec"]["containers"][0]
    assert primary["name"] == "a container"
    # The primary container is rewritten to invoke pyflyte-execute.
    assert primary["args"] == [
        "pyflyte-execute",
        "--task-module",
        "pod.test_pod",
        "--task-name",
        "simple_pod_task",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
    ]
    assert primary["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}]
    # Requests/limits from the decorator land on the primary container only.
    assert primary["resources"] == {
        "requests": {"cpu": {"string": "10"}},
        "limits": {"gpu": {"string": "2"}},
    }
    assert primary["env"] == [{"name": "FOO", "value": "bar"}]

    assert custom["podSpec"]["containers"][1]["name"] == "another container"
    assert custom["primaryContainerName"] == "a container"
def test_serialization_branch_complex():
    """Two chained conditionals serialize into two distinct branch nodes."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        # Second conditional consumes the first one's output.
        f = (
            conditional("test2")
            .if_(d == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return x, f

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    wf = get_serializable(settings, my_wf)
    assert wf is not None
    assert len(wf.nodes) == 3
    assert wf.nodes[1].branch_node is not None
    assert wf.nodes[2].branch_node is not None
def test_serialization_branch_sub_wf():
    """A conditional whose else-branch launches a sub-workflow serializes with its input wired through."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    wf = get_serializable(settings, my_wf)
    assert wf is not None
    assert len(wf.nodes[0].inputs) == 1
    # Branch-node inputs are prefixed with "." in the serialized form.
    assert wf.nodes[0].inputs[0].var == ".a"
    assert wf.nodes[0] is not None
def test_query_no_inputs_or_outputs():
    """A HiveTask with an empty interface serializes cleanly, alone and inside a workflow."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=img, images=[img]),
        env={},
    )
    sdk_task = get_serializable(settings, hive_task)
    assert len(sdk_task.interface.inputs) == 0
    assert len(sdk_task.interface.outputs) == 0

    # Serializing the enclosing workflow must also succeed (VoidPromise handling).
    get_serializable(settings, my_wf)
def test_pytorch_task():
    """PyTorch task config serializes worker count and carries per-replica resources."""

    @task(
        task_config=PyTorch(num_workers=10, per_replica_requests=Resources(cpu="1")),
        cache=True,
        cache_version="1",
    )
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    # Local execution behaves like a plain function.
    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    assert my_pytorch_task.get_custom(settings) == {"workers": 10}
    # per_replica_requests is surfaced as the task's resource requests.
    assert my_pytorch_task.resources.limits == Resources()
    assert my_pytorch_task.resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
def test_dynamic_pod_task():
    """A @dynamic task with a Pod config serializes its pod spec and compiles to a dynamic job spec."""
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(task_config=dynamic_pod)
    def dynamic_pod_task(a: int) -> List[int]:
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    img = Image(name="default", fqn="test", tag="tag")
    custom = dynamic_pod_task.get_custom(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=img, images=[img]),
        )
    )
    assert len(custom["podSpec"]["containers"]) == 2

    # Compiling the dynamic task at execution time yields one node per loop iteration.
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=settings
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                ctx, dynamic_pod_task._task_function, a=5
            )
            assert len(dynamic_job_spec._nodes) == 5
def _get_reg_settings():
    """Build the standard SerializationSettings fixture shared by these tests."""
    img = Image(name="default", fqn="test", tag="tag")
    return SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=img, images=[img]),
    )
def test_ref():
    """A @reference_task keeps its fixed id regardless of the serialization settings used."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    ss = get_serializable(settings, ref_t1)
    assert ss.id == ref_t1.id
    assert ss.interface.inputs["a"] is not None
    assert ss.interface.outputs["o0"] is not None

    # Different settings must not alter the reference identity.
    other_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_task = get_serializable(other_settings, ref_t1)
    assert sdk_task.has_registered
    assert sdk_task.id.project == "flytesnacks"
    assert sdk_task.id.domain == "development"
    assert sdk_task.id.name == "recipes.aaa.simple.join_strings"
    assert sdk_task.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"
def test_serialization():
    """Raw-container tasks and a workflow composed of them serialize with their images."""
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    # Named sum_task locally to avoid shadowing the builtin; the task name stays "sum".
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out",
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    wf = get_serializable(settings, raw_container_wf)
    assert wf is not None
    assert len(wf.nodes) == 3

    sqn = get_serializable(settings, square)
    assert sqn.container.image == "alpine"
    sumn = get_serializable(settings, sum_task)
    assert sumn.container.image == "alpine"
def test_lp_with_output():
    """A reference launch plan with outputs can be mocked locally and serializes as a launchplan node."""
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    # References can't execute locally, so patch the launch plan for the local run.
    @patch(ref_lp)
    def inner_test(ref_mock):
        ref_mock.return_value = (False, 30)
        result = wf1()
        assert result == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(settings, wf1)
    assert sdk_wf.nodes[1].workflow_node.launchplan_ref.project == "proj"
    assert sdk_wf.nodes[1].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
def test_lps(resource_type):
    """Reference entities of each resource type serialize to the matching node kind."""
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()

    # Outside compilation, calling a reference must demand a mock.
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    with ctx.new_compilation_context() as ctx:
        # During compilation, inputs are required...
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"
        # ...and with no declared outputs the call yields a VoidPromise.
        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1

    node = sdk_wf.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert node.workflow_node.launchplan_ref.project == "proj"
        assert node.workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        assert node.workflow_node.sub_workflow_ref.project == "proj"
        assert node.workflow_node.sub_workflow_ref.name == "app.other.flyte_entity"
    else:
        assert node.task_node.reference_id.project == "proj"
        assert node.task_node.reference_id.name == "app.other.flyte_entity"
def test_environment():
    """Task-level environment variables override and merge with serialization-level ones."""

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )

    with context_manager.FlyteContext.current_context().new_compilation_context():
        sdk_task = get_serializable(settings, t1)
        # Task env wins on conflict (FOO); the rest is the union of both dicts.
        assert sdk_task.container.env == {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
def test_serialization():
    """A HiveTask's query template, interface, and workflow output bindings serialize correctly."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=img, images=[img]),
        env={},
    )
    sdk_task = get_serializable(settings, hive_task)
    assert "{{ .rawOutputDataPrefix" in sdk_task.custom["query"]["query"]
    assert "insert overwrite directory" in sdk_task.custom["query"]["query"]
    assert len(sdk_task.interface.inputs) == 2
    assert len(sdk_task.interface.outputs) == 1

    sdk_wf = get_serializable(settings, my_wf)
    assert sdk_wf.interface.outputs["o0"].type.schema is not None
    # The workflow output is bound to the hive task node's "results" variable.
    assert sdk_wf.outputs[0].var == "o0"
    assert sdk_wf.outputs[0].binding.promise.node_id == "n0"
    assert sdk_wf.outputs[0].binding.promise.var == "results"
def test_lp_serialize():
    """Launch plans serialize defaults/fixed inputs correctly, including the required/default oneof."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create(
        "serialize_test2", my_subwf, default_inputs={"a": 3}
    )

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    sdk_lp = get_serializable(settings, lp)
    assert len(sdk_lp.default_inputs.parameters) == 0
    assert len(sdk_lp.fixed_inputs.literals) == 0

    sdk_lp = get_serializable(settings, lp_with_defaults)
    assert len(sdk_lp.default_inputs.parameters) == 1
    assert len(sdk_lp.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    parameter_a = sdk_lp.default_inputs.parameters["a"]
    parameter_a = Parameter.from_flyte_idl(parameter_a.to_flyte_idl())
    assert parameter_a.default is not None
def test_spark_task():
    """Spark task config is retained and surfaces sparkConf in the custom payload."""

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return 10

    assert my_spark.task_config is not None
    assert my_spark.task_config.spark_conf == {"spark": "1"}

    img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=img, images=[img]),
    )
    custom = my_spark.get_custom(settings)
    assert custom["sparkConf"] == {"spark": "1"}
def test_wf1_with_dynamic():
    """A @dynamic sub-workflow runs locally and compiles to one node per appended task."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    count = 5
    result = my_wf(a=count, b="hello ")
    assert result == ("hello hello ", ["world-" + str(i) for i in range(2, count + 2)])

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=settings
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            assert len(dynamic_job_spec._nodes) == 5
def test_resources():
    """Requests/limits from the @task decorator map onto serialized container resources."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    with context_manager.FlyteContext.current_context().new_compilation_context():
        sdk_task = get_serializable(settings, t1)
        assert sdk_task.container.resources.requests == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "1")
        ]
        assert sdk_task.container.resources.limits == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "2"),
            _resource_models.ResourceEntry(_resource_models.ResourceName.MEMORY, "400M"),
        ]

        # No limits declared on t2 => empty limits list.
        sdk_task2 = get_serializable(settings, t2)
        assert sdk_task2.container.resources.requests == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "3")
        ]
        assert sdk_task2.container.resources.limits == []
def test_dynamic_conditional():
    """A recursive merge-sort built from a conditional inside a @dynamic task compiles to 5 tasks."""

    @task
    def split(
        in1: typing.List[int],
    ) -> (typing.List[int], typing.List[int], int):
        mid = int(len(in1) / 2)
        return in1[0:mid], in1[mid + 1 :], len(in1) / 2

    # One sample implementation for merging. In a more real world example, this might merge file streams and only load
    # chunks into the memory.
    @task
    def merge(x: typing.List[int], y: typing.List[int]) -> typing.List[int]:
        len_x = len(x)
        len_y = len(y)
        merged = list[int]()
        i = 0
        j = 0

        # Take the smaller head element from either list until one side runs out.
        while i < len_x and j < len_y:
            if x[i] < y[j]:
                merged.append(x[i])
                i = i + 1
            else:
                merged.append(y[j])
                j = j + 1

        # Drain whatever remains of x.
        while i < len_x:
            merged.append(x[i])
            i = i + 1

        # Drain whatever remains of y.
        while j < len_y:
            merged.append(y[j])
            j = j + 1

        return merged

    # This runs the sorting completely locally. It's faster and more efficient to do so if the entire list fits in memory.
    @task
    def merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @task
    def also_merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @dynamic
    def merge_sort_remotely(in1: typing.List[int]) -> typing.List[int]:
        x, y, new_count = split(in1=in1)
        sorted_x = merge_sort(in1=x, count=new_count)
        sorted_y = merge_sort(in1=y, count=new_count)
        return merge(x=sorted_x, y=sorted_y)

    @workflow
    def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
        # Small inputs sort locally; large inputs recurse through the dynamic task.
        return (
            conditional("terminal_case")
            .if_(count < 500)
            .then(merge_sort_locally(in1=in1))
            .elif_(count < 1000)
            .then(also_merge_sort_locally(in1=in1))
            .else_()
            .then(merge_sort_remotely(in1=in1))
        )

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=settings
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = merge_sort_remotely.compile_into_workflow(
                ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
            )
            assert len(dynamic_job_spec.tasks) == 5
import typing

from flytekit import ContainerTask
from flytekit.annotated import context_manager
from flytekit.annotated.base_task import kwtypes
from flytekit.annotated.context_manager import Image, ImageConfig
from flytekit.annotated.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.annotated.task import ReferenceTask, task
from flytekit.annotated.workflow import ReferenceWorkflow, workflow
from flytekit.common.translator import get_serializable

# Module-level serialization fixture shared by the tests in this file.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_references():
    """A ReferenceLaunchPlan serializes as an already-registered entity."""
    rlp = ReferenceLaunchPlan("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    sdk_lp = get_serializable(serialization_settings, rlp)
    assert sdk_lp.has_registered
def test_normal_task():
    """create_node works for value-returning tasks and for VoidPromise ordering (runs_before / >>)."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @workflow
    def my_wf(a: str) -> str:
        t1_node = create_node(t1, a=a)
        return t1_node.o0

    assert my_wf(a="hello") == "hello world"

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(settings, my_wf)
    assert len(sdk_wf.nodes) == 1
    assert len(sdk_wf.outputs) == 1

    @task
    def t2():
        ...

    @task
    def t3():
        ...

    @workflow
    def empty_wf():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node.runs_before(t2_node)

    # Test that VoidPromises can handle runs_before
    empty_wf()

    @workflow
    def empty_wf2():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node >> t2_node

    # Both ordering spellings serialize to the same upstream dependency (n0 depends on n1).
    sdk_wf = get_serializable(settings, empty_wf)
    assert sdk_wf.nodes[0].upstream_node_ids[0] == "n1"
    assert sdk_wf.nodes[0].id == "n0"

    sdk_wf = get_serializable(settings, empty_wf2)
    assert sdk_wf.nodes[0].upstream_node_ids[0] == "n1"
    assert sdk_wf.nodes[0].id == "n0"