def test_ref():
    """A ``@reference_task`` keeps the identifier it was declared with and serializes to ``None``."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    # The identifier must mirror the decorator arguments exactly.
    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    # First round of serialization settings.
    first_image_config = ImageConfig(Image(name="name", fqn="image", tag="name"))
    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=first_image_config,
        env={},
    )
    first_result = get_serializable(OrderedDict(), first_settings, ref_t1)
    assert first_result is None

    # Different settings give the same outcome.
    second_image_config = ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123"))
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=second_image_config,
        env={},
    )
    second_result = get_serializable(OrderedDict(), second_settings, ref_t1)
    assert second_result is None
def test_imperative():
    """Build an ImperativeWorkflow, run it locally, then serialize it and a launch plan made from it."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # Assemble the workflow step by step: one input, two nodes, one output.
    imperative_wf = ImperativeWorkflow(name="my.workflow")
    imperative_wf.add_workflow_input("in1", str)
    t1_node = imperative_wf.add_entity(t1, a=imperative_wf.inputs["in1"])
    imperative_wf.add_entity(t2)
    imperative_wf.add_workflow_output("from_n0t1", t1_node.outputs["o0"])

    # Local execution.
    assert imperative_wf(in1="hello") == "hello world"

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, imperative_wf)
    assert len(serialized_wf.nodes) == 2
    assert serialized_wf.nodes[0].task_node is not None
    assert len(serialized_wf.outputs) == 1
    assert serialized_wf.outputs[0].var == "from_n0t1"
    assert len(serialized_wf.interface.inputs) == 1
    assert len(serialized_wf.interface.outputs) == 1

    # Create launch plan from wf, that can also be serialized.
    wf_lp = LaunchPlan.create("test_wb", imperative_wf)
    serialized_lp = get_serializable(OrderedDict(), serialization_settings, wf_lp)
    assert serialized_lp.workflow_id.name == "my.workflow"
def test_basics():
    """Serialize a two-task workflow, one of its tasks, and a launch plan, checking the resulting specs."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf, False)
    template = workflow_spec.template
    assert len(template.interface.inputs) == 2
    assert len(template.interface.outputs) == 2
    assert len(template.nodes) == 2
    assert template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # Gets cached the first time around so it's not actually fast.
    t1_spec = get_serializable(OrderedDict(), serialization_settings, t1, True)
    assert "pyflyte-execute" in t1_spec.template.container.args

    test_lp = LaunchPlan.create(
        "testlp",
        my_wf,
    )
    serialized_lp = get_serializable(OrderedDict(), serialization_settings, test_lp)
    assert serialized_lp.id.name == "testlp"
def test_references():
    """Reference launch plans, tasks and workflows all serialize to objects flagged as registered."""
    # Reference launch plan.
    ref_lp = ReferenceLaunchPlan("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    serialized_lp = get_serializable(serialization_settings, ref_lp)
    assert serialized_lp.has_registered

    # Reference task.
    ref_task = ReferenceTask("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    serialized_task = get_serializable(serialization_settings, ref_task)
    assert serialized_task.has_registered

    # Reference workflow.
    ref_wf = ReferenceWorkflow("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    serialized_wf = get_serializable(serialization_settings, ref_wf)
    assert serialized_wf.has_registered
def test_query_no_inputs_or_outputs():
    """A HiveTask declared without inputs or an output schema serializes with an empty interface."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template=""" insert into extant_table (1, 'two') """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=default_image, images=[default_image])
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    serialized_task = get_serializable(settings, hive_task)
    assert len(serialized_task.interface.inputs) == 0
    assert len(serialized_task.interface.outputs) == 0

    # The workflow wrapping the task only has to serialize without raising.
    get_serializable(settings, my_wf)
def test_basics():
    """Serialize a two-task workflow, one of its tasks, and a launch plan, checking the resulting models."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    serialized_wf = get_serializable(serialization_settings, my_wf, False)
    wf_interface = serialized_wf.interface
    assert len(wf_interface.inputs) == 2
    assert len(wf_interface.outputs) == 2
    assert len(serialized_wf.nodes) == 2

    # Gets cached the first time around so it's not actually fast.
    serialized_task = get_serializable(serialization_settings, t1, True)
    assert "pyflyte-execute" in serialized_task.container.args

    test_lp = LaunchPlan.create(
        "testlp",
        my_wf,
    )
    serialized_lp = get_serializable(serialization_settings, test_lp)
    assert serialized_lp.id.name == "testlp"
def test_references():
    """Reference launch plans, tasks and workflows each serialize to ``None``."""
    # Reference launch plan.
    ref_lp = ReferenceLaunchPlan("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    lp_result = get_serializable(OrderedDict(), serialization_settings, ref_lp)
    assert lp_result is None

    # Reference task.
    ref_task = ReferenceTask("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    task_result = get_serializable(OrderedDict(), serialization_settings, ref_task)
    assert task_result is None

    # Reference workflow.
    ref_wf = ReferenceWorkflow("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
    wf_result = get_serializable(OrderedDict(), serialization_settings, ref_wf)
    assert wf_result is None
def test_imperative():
    """Exercise ImperativeWorkflow: local run, serialization, a launch plan, and parents using it as a node."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # Base workflow: one input wired into t1, plus a t2 node with no inputs or outputs.
    base_wf = ImperativeWorkflow(name="my.workflow")
    base_wf.add_workflow_input("in1", str)
    t1_node = base_wf.add_entity(t1, a=base_wf.inputs["in1"])
    base_wf.add_entity(t2)
    base_wf.add_workflow_output("from_n0t1", t1_node.outputs["o0"])

    assert base_wf(in1="hello") == "hello world"

    base_spec = get_serializable(OrderedDict(), serialization_settings, base_wf)
    assert len(base_spec.template.nodes) == 2
    assert base_spec.template.nodes[0].task_node is not None
    assert len(base_spec.template.outputs) == 1
    assert base_spec.template.outputs[0].var == "from_n0t1"
    assert len(base_spec.template.interface.inputs) == 1
    assert len(base_spec.template.interface.outputs) == 1

    # Create launch plan from wf, that can also be serialized.
    base_lp = LaunchPlan.create("test_wb", base_wf)
    base_lp_model = get_serializable(OrderedDict(), serialization_settings, base_lp)
    assert base_lp_model.spec.workflow_id.name == "my.workflow"

    # Parent workflow that embeds the base workflow as a sub-workflow node.
    subwf_parent = ImperativeWorkflow(name="parent.imperative")
    parent_input = subwf_parent.add_workflow_input("p_in1", str)
    subwf_node = subwf_parent.add_subwf(base_wf, in1=parent_input)
    subwf_parent.add_workflow_output("parent_wf_output", subwf_node.from_n0t1, str)

    subwf_parent_spec = get_serializable(OrderedDict(), serialization_settings, subwf_parent)
    assert len(subwf_parent_spec.template.nodes) == 1
    assert len(subwf_parent_spec.template.interface.inputs) == 1
    assert subwf_parent_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(subwf_parent_spec.template.interface.outputs) == 1
    assert subwf_parent_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert subwf_parent_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(subwf_parent_spec.sub_workflows) == 1

    # Parent workflow that references the base workflow through its launch plan instead.
    lp_parent = ImperativeWorkflow(name="parent.imperative")
    parent_input = lp_parent.add_workflow_input("p_in1", str)
    lp_node = lp_parent.add_launch_plan(base_lp, in1=parent_input)
    lp_parent.add_workflow_output("parent_wf_output", lp_node.from_n0t1, str)

    lp_parent_spec = get_serializable(OrderedDict(), serialization_settings, lp_parent)
    assert len(lp_parent_spec.template.nodes) == 1
    assert len(lp_parent_spec.template.interface.inputs) == 1
    assert lp_parent_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(lp_parent_spec.template.interface.outputs) == 1
    assert lp_parent_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert lp_parent_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
def test_wf_resolving():
    """Tasks declared inside a workflow body serialize with the workflow's location as their resolver."""

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        @task
        def t1(a: int) -> (int, str):
            return a + 2, "world"

        @task
        def t2(a: str, b: str) -> str:
            return b + a

        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    local_result = my_wf(a=3, b="hello")
    assert local_result == (5, "helloworld")

    # Because the workflow is nested inside a test, calling location will fail as it tries to find the LHS that the
    # workflow was assigned to
    with pytest.raises(Exception):
        _ = my_wf.location

    # Pretend my_wf was not actually nested, but somehow assigned to example_var_name at the module layer
    my_wf._instantiated_in = "example.module"
    my_wf._lhs = "example_var_name"

    nested_tasks = my_wf.get_all_tasks()
    assert len(nested_tasks) == 2  # Two tasks were declared inside

    # The tasks should get the location the workflow was assigned to as the resolver.
    # The args are the index.
    first_serialized = get_serializable(OrderedDict(), serialization_settings, nested_tasks[0])
    assert first_serialized.container.args[-4:] == [
        "--resolver",
        "example.module.example_var_name",
        "--",
        "0",
    ]

    second_serialized = get_serializable(OrderedDict(), serialization_settings, nested_tasks[1])
    assert second_serialized.container.args[-4:] == [
        "--resolver",
        "example.module.example_var_name",
        "--",
        "1",
    ]
def test_serialization():
    """Serialize a workflow built from two raw ContainerTasks and check node count and container images."""
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    # Bound to `sum_task` rather than `sum` so the builtin is not shadowed inside this test;
    # the task's registered name ("sum") is unchanged.
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out",
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    # Two `square` calls plus one `sum` call give three nodes.
    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"

    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
def test_lp_serialize():
    """Launch plans serialize required vs. defaulted inputs correctly, including after an IDL round trip."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    defaulted_lp = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    image_config = ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123"))
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    # Without defaults, "a" is a required parameter.
    plain_model = get_serializable(OrderedDict(), settings, plain_lp)
    assert len(plain_model.default_inputs.parameters) == 1
    assert plain_model.default_inputs.parameters["a"].required
    assert len(plain_model.fixed_inputs.literals) == 0

    # With a default, "a" is no longer required and carries the literal 3.
    defaulted_model = get_serializable(OrderedDict(), settings, defaulted_lp)
    assert len(defaulted_model.default_inputs.parameters) == 1
    assert not defaulted_model.default_inputs.parameters["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert defaulted_model.default_inputs.parameters["a"].default == expected_default
    assert len(defaulted_model.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    original_param = defaulted_model.default_inputs.parameters["a"]
    round_tripped = Parameter.from_flyte_idl(original_param.to_flyte_idl())
    assert round_tripped.default is not None
def test_pod_task_serialized():
    """A task configured with a Pod becomes a PodFunctionTask and serializes its primary container name."""
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(task_config=pod_config, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    serialized_task = get_serializable(OrderedDict(), settings, simple_pod_task)
    assert serialized_task.task_type_version == 1
    assert serialized_task.config["primary_container_name"] == "an undefined container"
def test_serialization_branch_sub_wf():
    """Serialize a workflow whose conditional picks between a task and a sub-workflow.

    Checks that the first serialized node has exactly one input, bound to the variable ".a".
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    wf = get_serializable(serialization_settings, my_wf)
    assert wf is not None

    # Check the node exists before dereferencing it, so a missing node fails on this assertion
    # rather than with an AttributeError on `.inputs`.
    first_node = wf.nodes[0]
    assert first_node is not None
    assert len(first_node.inputs) == 1
    assert first_node.inputs[0].var == ".a"
def test_serialization_branch_complex():
    """A workflow with two chained conditionals serializes to three nodes, the last two being branch nodes."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = conditional("test2").if_(d == "hello ").then(t2(a="It is hello")).else_().then(t2(a="Not Hello!"))
        return x, f

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    serialized_wf = get_serializable(settings, my_wf)
    assert serialized_wf is not None
    assert len(serialized_wf.nodes) == 3
    assert serialized_wf.nodes[1].branch_node is not None
    assert serialized_wf.nodes[2].branch_node is not None
def test_condition_tuple_branches():
    """A conditional whose branch yields a two-value tuple runs locally and serializes as a branch node."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte will only make the `sum` and `sub` task outputs available because they are common between all
        # branches. The first one is bound to `add` locally so the builtin `sum` is not shadowed.
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )

        return add, sub

    x, y = math_ops(a=3, b=2)
    assert x == 5
    assert y == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    # The `then` branch of the conditional must point at the sum_sub task.
    sdk_wf = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert sdk_wf.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_branch():
    """A conditional on a named-tuple output runs locally on both branches and serializes to a branch node."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return conditional("test1").if_(c.c == 4).then(t1(c=c.c).c).else_().then(t2().c)

    # Local execution hits each branch once.
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    serialized_wf = get_serializable(settings, my_wf)
    assert serialized_wf is not None
    assert len(serialized_wf.nodes) == 2
    assert serialized_wf.nodes[1].branch_node is not None
def test_environment():
    """Task-level environment entries are merged with, and take precedence over, the settings' env."""

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )

    current_ctx = context_manager.FlyteContext.current_context()
    with current_ctx.new_compilation_context():
        serialized_task = get_serializable(OrderedDict(), settings, t1)
        # FOO comes from the task, BAR from the settings, BAZ from the task.
        expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
        assert serialized_task.container.env == expected_env
def test_condition_tuple_branches():
    """A conditional whose branch yields a two-value tuple runs locally and serializes to one branch node."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )

        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    workflow_spec = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert len(workflow_spec.template.nodes) == 1

    # The `then` branch of the conditional must point at the sum_sub task.
    then_node = workflow_spec.template.nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_nested_subwf():
    """A parent workflow with nested sub-workflows lists each distinct sub-workflow once in its spec."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def leaf_subwf(a: int = 42) -> (int, int):
        x = t1(a=a)
        u = t1(a=x)
        return x, u

    @workflow
    def middle_subwf() -> (int, int):
        s1, s2 = leaf_subwf(a=50)
        return s2, s2

    @workflow
    def parent_wf() -> (int, int, int, int):
        m1, m2 = middle_subwf()
        l1, l2 = leaf_subwf()
        return m1, m2, l1, l2

    parent_spec = get_serializable(OrderedDict(), serialization_settings, parent_wf)
    assert parent_spec is not None
    assert len(parent_spec.sub_workflows) == 2

    # Index the sub-workflow templates by name for the checks below.
    sub_wfs_by_name = {sub_wf.id.name: sub_wf for sub_wf in parent_spec.sub_workflows}
    assert sub_wfs_by_name.keys() == {"test_serialization.leaf_subwf", "test_serialization.middle_subwf"}

    middle = sub_wfs_by_name["test_serialization.middle_subwf"]
    assert len(middle.nodes) == 1
    assert middle.nodes[0].workflow_node is not None
    assert middle.nodes[0].workflow_node.sub_workflow_ref.name == "test_serialization.leaf_subwf"
def test_workflow_values():
    """``interruptible`` and ``failure_policy`` passed to ``@workflow`` show up in the serialized workflow."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow(
        interruptible=True,
        failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )
    def wf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    image_config = ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123"))
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    serialized_wf = get_serializable(OrderedDict(), settings, wf)
    assert serialized_wf.metadata_defaults.interruptible
    assert serialized_wf.metadata.on_failure == 1
def test_wf_nested_comp():
    """A workflow defined inside another workflow's body is serialized as a sub-workflow of the outer one."""

    @task
    def t1(a: int) -> int:
        a = a + 5
        return a

    @workflow
    def outer() -> (int, int):
        # You should not do this. This is just here for testing.
        @workflow
        def wf2() -> int:
            return t1(a=5)

        return t1(a=3), wf2()

    assert (8, 10) == outer()

    serialized_entities = OrderedDict()
    outer_sdk_wf = get_serializable(serialized_entities, serialization_settings, outer)
    outer_model = outer_sdk_wf.serialize()

    assert len(outer_model.template.interface.outputs.variables) == 2
    assert len(outer_model.template.nodes) == 2
    assert outer_model.template.nodes[1].workflow_node is not None

    # The nested workflow appears as the first sub-workflow, with a single task node.
    nested_wf = outer_model.sub_workflows[0]
    assert len(nested_wf.nodes) == 1
    assert nested_wf.nodes[0].id == "wf2-n0"
    assert nested_wf.nodes[0].task_node.reference_id.name == "test_workflows.t1"
def test_serialization():
    """A map task over ``t1`` serializes as a ``container_array`` task with the map-execute entrypoint."""
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    mapped_spec = get_serializable(OrderedDict(), settings, mapped)
    assert mapped_spec.template.type == "container_array"
    assert mapped_spec.template.task_type_version == 1

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "test_map_task",
        "task-name",
        "t1",
    ]
    assert mapped_spec.template.container.args == expected_args
def test_condition_tuple_branches():
    """A conditional whose branch yields a two-value tuple runs locally and serializes to one branch node."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    workflow_spec = get_serializable(OrderedDict(), settings, math_ops)
    assert len(workflow_spec.template.nodes) == 1

    # The `then` branch of the conditional must point at the sum_sub task.
    then_node = workflow_spec.template.nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_images():
    """``container_image`` templates are resolved against the configured images when tasks are serialized."""

    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    @task
    def t3():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    # NOTE(review): neither the environment variable nor the config file is restored afterwards,
    # so this state may leak into later tests -- confirm whether that is intended.
    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    set_flyte_config_file(os.path.join(tests_dir, "configs/images.config"))

    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=get_image_config(),
    )

    # Named image fqn combined with the default image's version.
    serialized_t1 = get_serializable(settings, t1)
    assert serialized_t1.container.image == "docker.io/xyz:version"
    serialized_t1.to_flyte_idl()

    # Default image referenced explicitly through the template.
    serialized_t2 = get_serializable(settings, t2)
    assert serialized_t2.container.image == "docker.io/default:version"

    # No container_image given.
    serialized_t3 = get_serializable(settings, t3)
    assert serialized_t3.container.image == "docker.io/default:version"

    # Fully literal image string.
    serialized_t4 = get_serializable(settings, t4)
    assert serialized_t4.container.image == "docker.io/org/myimage:latest"

    # Literal image name with a templated version.
    serialized_t5 = get_serializable(settings, t5)
    assert serialized_t5.container.image == "docker.io/org/myimage:version"
def test_resources():
    """Resource requests and limits declared on ``@task`` appear on the serialized container."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    cpu = _resource_models.ResourceName.CPU
    memory = _resource_models.ResourceName.MEMORY

    compilation_ctx = context_manager.FlyteContextManager.current_context().with_new_compilation_state()
    with context_manager.FlyteContextManager.with_context(compilation_ctx):
        # t1 declares both requests and limits.
        t1_spec = get_serializable(OrderedDict(), settings, t1)
        t1_resources = t1_spec.template.container.resources
        assert t1_resources.requests == [_resource_models.ResourceEntry(cpu, "1")]
        assert t1_resources.limits == [
            _resource_models.ResourceEntry(cpu, "2"),
            _resource_models.ResourceEntry(memory, "400M"),
        ]

        # t2 declares only requests, so its limits list is empty.
        t2_spec = get_serializable(OrderedDict(), settings, t2)
        t2_resources = t2_spec.template.container.resources
        assert t2_resources.requests == [_resource_models.ResourceEntry(cpu, "3")]
        assert t2_resources.limits == []
def test_serialization_workflow_def():
    """Two workflows using separately created map tasks over the same function register two map tasks."""

    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return maptask(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )

    # Both workflows are serialized into the same mapping so the entities they register can be inspected.
    entity_mapping = OrderedDict()

    w1_spec = get_serializable(entity_mapping, settings, w1)
    assert w1_spec.template is not None
    assert len(w1_spec.template.nodes) == 1

    w2_spec = get_serializable(entity_mapping, settings, w2)
    assert w2_spec.template is not None
    assert len(w2_spec.template.nodes) == 1

    map_tasks_seen = [
        entity for entity in entity_mapping if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]

    assert len(map_tasks_seen) == 2
    print(map_tasks_seen[0])
def test_fast():
    """Serializing a task with the fast flag set uses the ``pyflyte-fast-execute`` entrypoint."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    fast_serialized = get_serializable(serialization_settings, t1, True)
    container_args = fast_serialized.container.args
    assert "pyflyte-fast-execute" in container_args
def test_nested_condition_2():
    """A conditional nested inside another's ``then`` serializes to nested branch nodes and runs locally."""

    @workflow
    def multiplier_2(my_input: float) -> float:
        return (
            conditional("fractions")
            .if_((my_input > 0.1) & (my_input < 1.0))
            .then(
                conditional("inner_fractions")
                .if_(my_input < 0.5)
                .then(double(n=my_input))
                .elif_((my_input > 0.5) & (my_input < 0.7))
                .then(square(n=my_input))
                .else_()
                .fail("Only <0.7 allowed")
            )
            .elif_((my_input > 1.0) & (my_input < 10.0))
            .then(square(n=my_input))
            .else_()
            .then(double(n=my_input))
        )

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, multiplier_2)
    assert len(serialized_wf.template.nodes) == 1

    # Outer "fractions" conditional.
    outer_branch = serialized_wf.template.nodes[0]
    assert isinstance(outer_branch, Node)
    assert outer_branch.id == "n0"
    assert outer_branch.branch_node is not None

    outer_if_else = outer_branch.branch_node.if_else
    assert outer_if_else is not None
    assert outer_if_else.case is not None
    assert outer_if_else.case.then_node is not None

    # Inner "inner_fractions" conditional, reached through the outer `then`.
    inner_branch = outer_if_else.case.then_node
    assert inner_branch.id == "n0"
    inner_if_else = inner_branch.branch_node.if_else
    assert inner_if_else.case.then_node.task_node is not None
    assert inner_if_else.case.then_node.id == "n0"
    assert len(inner_if_else.other) == 1
    assert inner_if_else.other[0].then_node.id == "n1"

    # Ensure other cases exist
    assert len(outer_if_else.other) == 1
    assert outer_if_else.other[0].then_node.task_node is not None
    assert outer_if_else.other[0].then_node.id == "n1"

    # Local execution across the branches.
    with pytest.raises(ValueError):
        multiplier_2(my_input=0.7)

    doubled_fraction = multiplier_2(my_input=0.3)
    assert doubled_fraction == 0.6

    squared = multiplier_2(my_input=5)
    assert squared == 25

    doubled_large = multiplier_2(my_input=10)
    assert doubled_large == 20
def test_ref_sub_wf():
    """A reference workflow raises when called unmocked and cannot be serialized as a sub-workflow."""
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()

    # Outside a compilation context the call raises, asking for a mock.
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    compilation_ctx = ctx.with_new_compilation_state()
    with context_manager.FlyteContextManager.with_context(compilation_ctx) as ctx:
        # In a compilation context, missing inputs are reported...
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"

        # ...and a complete call returns a VoidPromise, since no outputs are declared.
        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), settings, wf1)
def test_ref():
    """A ``@reference_task`` serializes to a registered entity carrying the declared id and interface."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    # The identifier must mirror the decorator arguments exactly.
    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    first_image_config = ImageConfig(Image(name="name", fqn="image", tag="name"))
    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=first_image_config,
        env={},
    )
    first_serialized = get_serializable(first_settings, ref_t1)
    assert first_serialized.id == ref_t1.id
    assert first_serialized.interface.inputs["a"] is not None
    assert first_serialized.interface.outputs["o0"] is not None

    # Different settings must not change the identifier of the serialized reference.
    second_image_config = ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123"))
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=second_image_config,
        env={},
    )
    second_serialized = get_serializable(second_settings, ref_t1)
    assert second_serialized.has_registered
    assert second_serialized.id.project == "flytesnacks"
    assert second_serialized.id.domain == "development"
    assert second_serialized.id.name == "recipes.aaa.simple.join_strings"
    assert second_serialized.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"