def test_ref():
    """Reference tasks keep their declared identity and serialize to nothing.

    ``get_serializable`` is expected to return ``None`` for the reference task under
    both sets of serialization settings exercised below.
    """
    expected_id = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }

    @reference_task(**expected_id)
    def ref_t1(a: typing.List[str]) -> str:
        ...

    for field, value in expected_id.items():
        assert getattr(ref_t1.id, field) == value

    # (project, domain, version, image fqn, image tag) for each settings variant.
    variants = (
        ("test_proj", "test_domain", "abc", "image", "name"),
        ("proj", "dom", "123", "asdf/fdsa", "123"),
    )
    for project, domain, version, fqn, tag in variants:
        settings = context_manager.SerializationSettings(
            project=project,
            domain=domain,
            version=version,
            image_config=ImageConfig(Image(name="name", fqn=fqn, tag=tag)),
            env={},
        )
        assert get_serializable(OrderedDict(), settings, ref_t1) is None
def test_condition_tuple_branches():
    """A conditional whose branch returns a NamedTuple can be unpacked and serialized.

    Runs ``math_ops`` locally, then checks that the serialized branch node's ``then``
    case references the ``sum_sub`` task.
    """

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        # Flyte will only make `sum` and `sub` available as outputs because they are common between all branches.
        # The local is named `add` so the builtin `sum` is not shadowed.
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return add, sub

    x, y = math_ops(a=3, b=2)
    assert x == 5
    assert y == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    sdk_wf = get_serializable(OrderedDict(), serialization_settings, math_ops)
    assert sdk_wf.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id.name == "test_conditions.sum_sub"
def test_serialization_branch_sub_wf():
    """A conditional whose ``else`` branch calls a sub-workflow serializes correctly.

    The first serialized node should have exactly one input, bound to the workflow input ``a``.
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf is not None
    # Check the node exists before dereferencing it. The original made this check last,
    # after `.inputs` had already been accessed, so it could never fail.
    assert wf.nodes[0] is not None
    assert len(wf.nodes[0].inputs) == 1
    assert wf.nodes[0].inputs[0].var == ".a"
def validate_image(ctx: typing.Any, param: str, values: tuple) -> ImageConfig:
    """
    Build an ImageConfig from the raw image strings supplied for an option.

    Each entry is either ``name=img`` (a named image) or a bare ``img``. A bare image,
    or one whose resolved name equals ``_DEFAULT_IMAGE_NAME``, becomes the default image.
    At most one default image may be supplied.

    :param ctx: not used by this function.
    :param param: the parameter being validated; only interpolated into the error message.
    :param values: the raw image strings to parse.
    :return: ImageConfig built from the default image (``None`` if none was given) and the other images.
    :raises click.BadParameter: if a second default image is encountered.
    """
    default_image = None
    images = []
    for raw in values:
        # partition splits on the first "=" only, matching split("=", maxsplit=1).
        name, sep, spec = raw.partition("=")
        if sep:
            img = look_up_image_info(name=name, tag=spec, optional_tag=False)
        else:
            img = look_up_image_info(_DEFAULT_IMAGE_NAME, raw, False)

        if img.name != _DEFAULT_IMAGE_NAME:
            images.append(img)
            continue

        if default_image:
            raise click.BadParameter(
                f"Only one default image can be specified. Received multiple {default_image} & {img} for {param}"
            )
        default_image = img

    return ImageConfig(default_image, images)
def test_serialization():
    """A map task over ``t1`` serializes as a ``container_array`` task with the map-execute entrypoint."""
    maptask = map_task(t1, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    template = get_serializable(OrderedDict(), settings, maptask).template

    assert template.type == "container_array"
    assert template.task_type_version == 1

    entrypoint = [
        "pyflyte-map-execute",
        "--inputs", "{{.input}}",
        "--output-prefix", "{{.outputPrefix}}",
        "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}",
    ]
    resolver = [
        "--resolver", "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module", "test_map_task",
        "task-name", "t1",
    ]
    assert template.container.args == entrypoint + resolver
def test_condition_tuple_branches():
    """Unpacking a NamedTuple-returning conditional works locally and after serialization."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> (int, int):
        plus, minus = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return plus, minus

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    nodes = get_serializable(OrderedDict(), settings, math_ops).template.nodes
    assert len(nodes) == 1
    then_task = nodes[0].branch_node.if_else.case.then_node.task_node
    assert then_task.reference_id.name == "test_conditions.sum_sub"
def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> str:
    """
    Resolve a possibly templated container image string into a concrete image reference.

    Each match of ``_IMAGE_REPLACE_REGEX`` in ``img`` is unpacked as
    ``(replace_group, name, attr)``. The matched text is then replaced with the named
    image's tag (``attr == "version"``; falls back to the default image's tag when the
    named image has no tag), its fqn (``attr == "fqn"``), or its full reference
    (``attr == ""``).

    :param img: Configured image. May be a literal image reference or contain placeholders.
    :param cfg: Registration configuration; supplies named images and the default image.
    :return: The resolved image. When ``img`` is ``None`` or empty, returns
        ``"<default fqn>:<default tag>"``. When ``img`` has no placeholders, returns it unchanged.
    :raises AssertionError: If a placeholder is malformed, names an image not found in ``cfg``,
        or requests an unsupported attribute.
    """
    if img is None or img == "":
        return f"{cfg.default_image.fqn}:{cfg.default_image.tag}"

    matches = _IMAGE_REPLACE_REGEX.findall(img)
    if not matches:
        # No placeholders: the string is already a concrete image reference.
        return img

    for m in matches:
        if len(m) < 3:
            # The second literal is an f-string, so braces must be doubled twice to
            # print the literal `{{...}}` placeholder syntax.
            raise AssertionError(
                "Image specification should be of the form <fqn>:<tag> OR <fqn>:{{.image.default.version}} OR "
                f"{{{{.image.xyz.fqn}}}}:{{{{.image.xyz.version}}}} OR {{{{.image.xyz}}}} - Received {m}"
            )
        replace_group, name, attr = m
        if name is None or name == "":
            raise AssertionError(f"Image format is incorrect {m}")
        img_cfg = cfg.find_image(name)
        if img_cfg is None:
            raise AssertionError(f"Image Config with name {name} not found in the configuration")
        if attr == "version":
            # A named image without its own tag inherits the default image's tag.
            tag = img_cfg.tag if img_cfg.tag is not None else cfg.default_image.tag
            img = img.replace(replace_group, tag)
        elif attr == "fqn":
            img = img.replace(replace_group, img_cfg.fqn)
        elif attr == "":
            img = img.replace(replace_group, img_cfg.full)
        else:
            raise AssertionError(f"Only fqn and version are supported replacements, {attr} is not supported")
    return img
def test_serialization_branch_complex_2():
    """Two chained conditionals serialize, and the first one is bound to ``t1``'s output."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        num, text = t1(a=a)
        chosen = (
            conditional("test1")
            .if_(num == 4)
            .then(t2(a=b))
            .elif_(num >= 5)
            .then(t2(a=text))
            .else_()
            .fail("Unable to choose branch")
        )
        greeting = (
            conditional("test2")
            .if_(chosen == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return num, greeting

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)
    assert wf_spec is not None
    # The node at index 1 takes its first input from node n0's `t1_int_output`.
    first_binding = wf_spec.template.nodes[1].inputs[0]
    assert first_binding.var == "n0.t1_int_output"
def test_workflow_values():
    """Workflow-level ``interruptible`` and ``failure_policy`` reach the serialized workflow."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
    def wf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_wf = get_serializable(OrderedDict(), settings, wf)
    assert sdk_wf.metadata_defaults.interruptible
    # NOTE(review): 1 is presumably the serialized value of FAIL_AFTER_EXECUTABLE_NODES_COMPLETE; confirm.
    assert sdk_wf.metadata.on_failure == 1
def test_dont_convert_remotes():
    """A FlyteFile pointing at a remote URI keeps that URI when passed through a dynamic task.

    ``dyn`` is dispatched with a literal built from ``FlyteFile("s3://anything")``. The
    resulting workflow's first node input must still carry ``s3://anything``.
    """

    @task
    def t1(in1: FlyteFile):
        print(in1)

    @dynamic
    def dyn(in1: FlyteFile):
        t1(in1=in1)

    fd = FlyteFile("s3://anything")

    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    ) as ctx:
        # Dispatch the dynamic task in TASK_EXECUTION mode.
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            lit = TypeEngine.to_literal(
                ctx, fd, FlyteFile, BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE)
            )
            lm = LiteralMap(literals={"in1": lit})
            wf = dyn.dispatch_execute(ctx, lm)
            # The remote URI must come through unchanged.
            assert wf.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything"
def test_container_image_conversion():
    """Table-driven checks for placeholder substitution in ``get_registerable_container_image``."""
    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])

    resolvable = [
        (None, "xyz.com/abc:tag1"),
        ("", "xyz.com/abc:tag1"),
        ("abc", "abc"),
        ("abc:latest", "abc:latest"),
        ("abc:{{.image.default.version}}", "abc:tag1"),
        ("{{.image.default.fqn}}:{{.image.default.version}}", "xyz.com/abc:tag1"),
        ("{{.image.other.fqn}}:{{.image.other.version}}", "xyz.com/other:tag-other"),
        ("{{.image.other.fqn}}:{{.image.default.version}}", "xyz.com/other:tag1"),
        ("{{.image.other.fqn}}", "xyz.com/other"),
        # Works with images instead of just image
        ("{{.images.other.fqn}}", "xyz.com/other"),
    ]
    for spec, expected in resolvable:
        assert get_registerable_container_image(spec, cfg) == expected

    invalid = (
        "{{.image.blah.fqn}}:{{.image.other.version}}",
        "{{.image.fqn}}:{{.image.other.version}}",
        "{{.image.blah}}",
    )
    for bad_spec in invalid:
        with pytest.raises(AssertionError):
            get_registerable_container_image(bad_spec, cfg)

    assert get_registerable_container_image("{{.image.default}}", cfg) == "xyz.com/abc:tag1"
def test_environment():
    """Task ``environment`` entries are merged with the serialization-level ``env``.

    On the shared key FOO, the task's value wins. BAR comes from the settings and BAZ from the task.
    """

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )
    expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}

    with context_manager.FlyteContext.current_context().new_compilation_context():
        sdk_task = get_serializable(OrderedDict(), settings, t1)
        assert sdk_task.container.env == expected_env
def test_serialization_branch():
    """A conditional that returns a NamedTuple field runs locally and serializes as a branch node."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        mimicked = mimic(a=a)
        return (
            conditional("test1")
            .if_(mimicked.c == 4)
            .then(t1(c=mimicked.c).c)
            .else_()
            .then(t2().c)
        )

    for value, expected in ((4, "world"), (2, "hello")):
        assert my_wf(a=value) == expected

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)
    assert wf_spec is not None
    nodes = wf_spec.template.nodes
    assert len(nodes) == 2
    assert nodes[1].branch_node is not None
def test_ref():
    """A reference task keeps its declared identity through serialization.

    Both serializations return an object whose id matches the declared reference,
    regardless of the project/domain/version in the serialization settings.
    """
    identity = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }

    @reference_task(**identity)
    def ref_t1(a: typing.List[str]) -> str:
        ...

    for attr, expected in identity.items():
        assert getattr(ref_t1.id, attr) == expected

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    ss = get_serializable(first_settings, ref_t1)
    assert ss.id == ref_t1.id
    assert ss.interface.inputs["a"] is not None
    assert ss.interface.outputs["o0"] is not None

    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_task = get_serializable(second_settings, ref_t1)
    assert sdk_task.has_registered
    for attr, expected in identity.items():
        assert getattr(sdk_task.id, attr) == expected
def test_wf1_with_fast_dynamic():
    """Compiling a dynamic task with fast serialization enabled emits a fast-execute entrypoint.

    Fast serialization is enabled, and ``dynamic_addl_distro`` / ``dynamic_dest_dir``
    are set in the execution state's additional context. Under these conditions, the
    args of the single generated task must start with ``pyflyte-fast-execute``, pointing
    at that distribution and destination directory.
    """

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    # Outer context: serialization settings with fast serialization switched on.
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            context_manager.SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
                fast_serialization_settings=FastSerializationSettings(enabled=True),
            )
        )
    ) as ctx:
        # Inner context: TASK_EXECUTION mode plus the distribution/destination values.
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                    additional_context={
                        "dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
                        "dynamic_dest_dir": "/User/flyte/workflows",
                    },
                )
            )
        ) as ctx:
            dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            # With a=5 the loop calls t1 five times: five nodes, one distinct task.
            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    # After both `with` blocks exit, the context stack should be back to one entry.
    assert context_manager.FlyteContextManager.size() == 1
def test_lps(resource_type):
    """A reference entity can be called inside a workflow and serialized.

    ``resource_type`` is presumably supplied by a pytest parametrization outside this view
    (launch plan / workflow / task); confirm against the decorator. The serialized node
    is checked via the reference field that matches the resource type.
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    # Outside compilation, calling the reference entity fails with a "mock this out" error.
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    with ctx.new_compilation_context() as ctx:
        # During compilation, a call with missing inputs is rejected...
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"
        # ...while a fully specified call (no declared outputs) yields a VoidPromise.
        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1
    # Which reference field is populated depends on the kind of reference entity.
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert sdk_wf.nodes[0].workflow_node.launchplan_ref.project == "proj"
        assert sdk_wf.nodes[0].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        assert sdk_wf.nodes[0].workflow_node.sub_workflow_ref.project == "proj"
        assert sdk_wf.nodes[0].workflow_node.sub_workflow_ref.name == "app.other.flyte_entity"
    else:
        assert sdk_wf.nodes[0].task_node.reference_id.project == "proj"
        assert sdk_wf.nodes[0].task_node.reference_id.name == "app.other.flyte_entity"
def test_py_func_task_get_container():
    """``get_container`` uses the default image and merges settings env with task environment."""

    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    settings = SerializationSettings(
        project="p",
        domain="d",
        version="v",
        image_config=ImageConfig(default_image=default_img, images=[default_img, other_img]),
        env={"FOO": "bar"},
    )

    container = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"}).get_container(settings)
    assert container.image == "xyz.com/abc:tag1"
    assert container.env == {"FOO": "bar", "BAZ": "baz"}
def test_serialization():
    """Raw ContainerTasks serialize with their own image, both alone and inside a workflow.

    The workflow calls ``square`` twice and ``sum`` once, so it should serialize to three nodes.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    # Bound to `sum_task` rather than `sum` so the builtin is not shadowed; the Flyte task
    # name is still "sum" via the explicit `name=` argument.
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
def test_lp_serialize():
    """Launch plans serialize their default inputs, both required and with a default value."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    # Without defaults, `a` is a required parameter.
    sdk_lp = get_serializable(OrderedDict(), settings, lp)
    params = sdk_lp.default_inputs.parameters
    assert len(params) == 1
    assert params["a"].required
    assert len(sdk_lp.fixed_inputs.literals) == 0

    # With a default of 3, `a` is optional and carries an integer literal.
    sdk_lp = get_serializable(OrderedDict(), settings, lp_with_defaults)
    params = sdk_lp.default_inputs.parameters
    assert len(params) == 1
    assert not params["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert params["a"].default == expected_default
    assert len(sdk_lp.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    round_tripped = Parameter.from_flyte_idl(params["a"].to_flyte_idl())
    assert round_tripped.default is not None
def test_wf1_with_dynamic():
    """A workflow calling a dynamic task runs locally, and the dynamic task compiles.

    The workflow is first executed locally. Then ``my_subwf`` is compiled in
    TASK_EXECUTION mode and should produce five nodes backed by one task.
    """

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    v = 5
    x = my_wf(a=v, b="hello ")
    # t1 adds 2 to each of range(v), hence range(2, v + 2).
    assert x == ("hello hello ", ["world-" + str(i) for i in range(2, v + 2)])

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            context_manager.SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        new_exc_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(new_exc_state)) as ctx:
            dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            # Five calls to t1 -> five nodes, one distinct task.
            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1

    # After both `with` blocks exit, the context stack should be back to one entry.
    assert context_manager.FlyteContextManager.size() == 1
def test_lp_with_output():
    """A reference launch plan with outputs can be mocked locally and serializes as a launchplan_ref."""
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(ref_lp)
    def inner_test(ref_mock):
        ref_mock.return_value = (False, 30)
        assert wf1() == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), settings, wf1)
    lp_ref = wf_spec.template.nodes[1].workflow_node.launchplan_ref
    assert lp_ref.project == "proj"
    assert lp_ref.name == "app.other.flyte_entity"
def test_resources():
    """Task ``requests``/``limits`` are serialized as ordered ResourceEntry lists."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    entry = _resource_models.ResourceEntry
    cpu = _resource_models.ResourceName.CPU
    memory = _resource_models.ResourceName.MEMORY

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_new_compilation_state()
    ):
        t1_resources = get_serializable(OrderedDict(), settings, t1).template.container.resources
        assert t1_resources.requests == [entry(cpu, "1")]
        assert t1_resources.limits == [entry(cpu, "2"), entry(memory, "400M")]

        t2_resources = get_serializable(OrderedDict(), settings, t2).template.container.resources
        assert t2_resources.requests == [entry(cpu, "3")]
        assert t2_resources.limits == []
def test_serialization_workflow_def():
    """Map tasks built before a workflow and inline within it are both serialized.

    ``w1`` uses a map task created before the workflow (retries=1). ``w2`` creates one
    inline (retries=2). Both are serialized into the same entity map, which should then
    hold two MapPythonTask entries wrapping ``complex_task``.
    """

    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return maptask(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    # Shared across both calls so the entities from w1 and w2 accumulate in one map.
    serialized_control_plane_entities = OrderedDict()
    wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1)
    assert wf1_spec.template is not None
    assert len(wf1_spec.template.nodes) == 1

    wf2_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w2)
    assert wf2_spec.template is not None
    assert len(wf2_spec.template.nodes) == 1

    tasks_seen = [
        entity
        for entity in serialized_control_plane_entities
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]
    assert len(tasks_seen) == 2
def test_sql_command():
    """The SQLite3 task serializes with the task-template resolver and its executor class."""
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    srz_t = get_serializable(OrderedDict(), settings, not_tk)

    expected_tail = [
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekit.extras.sqlite3.task.SQLite3TaskExecutor",
    ]
    args = srz_t.template.container.args
    assert args[-len(expected_tail):] == expected_tail
def test_nested_dynamic():
    """A dynamic task defined inside a workflow body runs locally and compiles into a sub-workflow."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        # The dynamic task is declared inside the workflow function rather than at test level.
        @dynamic
        def my_subwf(a: int) -> typing.List[str]:
            s = []
            for i in range(a):
                s.append(t1(a=i))
            return s

        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    v = 5
    x = my_wf(a=v, b="hello ")
    # t1 adds 2 to each of range(v), hence range(2, v + 2).
    assert x == ("hello hello ", ["world-" + str(i) for i in range(2, v + 2)])

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    # NOTE(review): assumes get_all_tasks() lists the nested dynamic task first; confirm ordering.
    nested_my_subwf = my_wf.get_all_tasks()[0]

    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=settings
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            # The positional `False` is a flag of this compile_into_workflow signature; its meaning is not visible here.
            dynamic_job_spec = nested_my_subwf.compile_into_workflow(
                ctx, False, nested_my_subwf._task_function, a=5
            )
            assert len(dynamic_job_spec._nodes) == 5
def test_sql_command():
    """The SQL task serializes with the default auto-container resolver and its module/name loader args."""
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    srz_t = get_serializable(OrderedDict(), settings, not_tk)

    # NOTE(review): serializes `not_tk` but expects task-name "tk"; confirm both names refer to the same task object.
    expected_tail = [
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.sqlalchemy.test_task",
        "task-name",
        "tk",
    ]
    args = srz_t.template.container.args
    assert args[-len(expected_tail):] == expected_tail
def test_ref_dynamic():
    """Compiling a dynamic task that calls a reference task is expected to raise."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="sample.reference.task",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: int) -> str:
        ...

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(ref_t1(a=i))
        return s

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            context_manager.SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        new_exc_state = ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(new_exc_state)) as ctx:
            # NOTE(review): `Exception` is very broad. Any failure, including a signature mismatch
            # in compile_into_workflow, would satisfy this; consider asserting on the message.
            with pytest.raises(Exception):
                my_subwf.compile_into_workflow(ctx, False, my_subwf._task_function, a=5)
def test_ref_sub_wf():
    """A reference workflow can be called during compilation but cannot be serialized as a sub-workflow."""
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    # Outside compilation, calling the reference entity fails with a "mock this out" error.
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in f"{e}"

    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()) as ctx:
        # During compilation, missing inputs are rejected...
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in f"{e}"

        # ...while a fully specified call (no declared outputs) yields a VoidPromise.
        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
def test_resource_overrides():
    """``with_overrides`` on a map-task node is serialized into the node's resource overrides."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mappy = map_task(t1)
        map_node = create_node(mappy, a=a).with_overrides(
            requests=Resources(cpu="1", mem="100"), limits=Resources(cpu="2", mem="200")
        )
        return map_node.o0

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    nodes = get_serializable(OrderedDict(), settings, my_wf).template.nodes
    assert len(nodes) == 1

    overrides = nodes[0].task_node.overrides
    assert overrides is not None

    entry = _resources_models.ResourceEntry
    names = _resources_models.ResourceName
    assert overrides.resources.requests == [entry(names.CPU, "1"), entry(names.MEMORY, "100")]
    assert overrides.resources.limits == [entry(names.CPU, "2"), entry(names.MEMORY, "200")]
def test_serialization_branch_compound_conditions():
    """A conditional with an OR-combined condition serializes with a single input bound to ``a``."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_wf(a: int) -> int:
        return (
            conditional("test1")
            .if_((a == 4) | (a == 3))
            .then(t1(a=a))
            .elif_(a < 6)
            .then(t1(a=a))
            .else_()
            .fail("Unable to choose branch")
        )

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)
    assert wf_spec is not None
    branch_inputs = wf_spec.template.nodes[0].inputs
    assert len(branch_inputs) == 1
    assert branch_inputs[0].var == ".a"