def test_ref():
    """A reference task exposes the identifier it was declared with and serializes to ``None``."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    # The identifier must mirror the decorator arguments.
    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    # First set of settings: nothing is produced for a reference task.
    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    assert get_serializable(OrderedDict(), first_settings, ref_t1) is None

    # A different project/domain/version gives the same result.
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    assert get_serializable(OrderedDict(), second_settings, ref_t1) is None
def test_environment():
    """Check the container env produced when both the task and the settings define variables."""

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings_env = {"FOO": "foo", "BAR": "bar"}
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env=settings_env,
    )

    current = context_manager.FlyteContext.current_context()
    with current.new_compilation_context():
        serialized_task = get_serializable(OrderedDict(), settings, t1)
        # FOO comes from the task, BAR from the settings, BAZ from the task.
        expected_env = {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
        assert serialized_task.container.env == expected_env
def test_dont_convert_remotes():
    """A remote FlyteFile URI passed through a dynamic task shows up unchanged in the node binding."""

    @task
    def t1(in1: FlyteFile):
        print(in1)

    @dynamic
    def dyn(in1: FlyteFile):
        t1(in1=in1)

    fd = FlyteFile("s3://anything")

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    base_ctx = context_manager.FlyteContext.current_context()
    with base_ctx.new_serialization_settings(serialization_settings=settings) as settings_ctx:
        with settings_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            blob_type = BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE)
            file_literal = TypeEngine.to_literal(exec_ctx, fd, FlyteFile, blob_type)
            inputs = LiteralMap(literals={"in1": file_literal})

            compiled = dyn.dispatch_execute(exec_ctx, inputs)
            first_binding = compiled.nodes[0].inputs[0].binding
            assert first_binding.scalar.blob.uri == "s3://anything"
def test_workflow_values():
    """Workflow-level ``interruptible`` and ``failure_policy`` arguments appear on the serialized workflow."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow(
        interruptible=True,
        failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )
    def wf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    image_config = ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123"))
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=image_config,
        env={},
    )

    serialized_wf = get_serializable(OrderedDict(), settings, wf)
    assert serialized_wf.metadata_defaults.interruptible
    assert serialized_wf.metadata.on_failure == 1
def test_serialization_branch_sub_wf():
    """Serialize a workflow whose conditional picks between a task and a sub-workflow.

    Checks that the first node of the serialized workflow exists and has a
    single input bound to the variable ``.a``.
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    wf = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert wf is not None
    # Verify the node exists before reading its attributes; the original
    # placed this check last, where it could never fail.
    assert wf.nodes[0] is not None
    assert len(wf.nodes[0].inputs) == 1
    assert wf.nodes[0].inputs[0].var == ".a"
def test_serialization_branch_complex_2():
    """Serialize a workflow with two chained conditionals and inspect the second node's first input."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = conditional("test2").if_(d == "hello ").then(t2(a="It is hello")).else_().then(t2(a="Not Hello!"))
        return x, f

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert spec is not None
    # The branch node (index 1) reads t1's integer output from node n0.
    assert spec.template.nodes[1].inputs[0].var == "n0.t1_int_output"
def test_serialization():
    """Serialize a map task over the module-level ``t1`` and check its type, version and container args."""
    maptask = map_task(t1, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    spec = get_serializable(OrderedDict(), settings, maptask)
    template = spec.template
    assert template.type == "container_array"
    assert template.task_type_version == 1

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "test_map_task",
        "task-name",
        "t1",
    ]
    assert template.container.args == expected_args
def test_serialization_branch():
    """Run a two-way conditional workflow locally, then check its serialized form has a branch node."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return conditional("test1").if_(c.c == 4).then(t1(c=c.c).c).else_().then(t2().c)

    # Local execution: 4 takes the "if" arm, anything else the "else" arm.
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert spec is not None
    nodes = spec.template.nodes
    assert len(nodes) == 2
    assert nodes[1].branch_node is not None
def test_ref():
    """A reference task serializes to an already-registered entity carrying the declared identifier."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    assert ref_t1.id.project == "flytesnacks"
    assert ref_t1.id.domain == "development"
    assert ref_t1.id.name == "recipes.aaa.simple.join_strings"
    assert ref_t1.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialized = get_serializable(first_settings, ref_t1)
    assert serialized.id == ref_t1.id
    assert serialized.interface.inputs["a"] is not None
    assert serialized.interface.outputs["o0"] is not None

    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    registered = get_serializable(second_settings, ref_t1)
    assert registered.has_registered
    # The identifier comes from the reference, not from the serialization settings.
    assert registered.id.project == "flytesnacks"
    assert registered.id.domain == "development"
    assert registered.id.name == "recipes.aaa.simple.join_strings"
    assert registered.id.version == "553018f39e519bdb2597b652639c30ce16b99c79"
def test_wf1_with_fast_dynamic():
    """Compile a dynamic task with fast serialization enabled and check the generated container command."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    manager = context_manager.FlyteContextManager
    settings_ctx = manager.current_context().with_serialization_settings(
        context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
            fast_serialization_settings=FastSerializationSettings(enabled=True),
        )
    )
    with manager.with_context(settings_ctx) as ctx:
        exec_state = ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION,
            additional_context={
                "dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
                "dynamic_dest_dir": "/User/flyte/workflows",
            },
        )
        with manager.with_context(ctx.with_execution_state(exec_state)) as ctx:
            job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            assert len(job_spec._nodes) == 5
            assert len(job_spec.tasks) == 1

            command = " ".join(job_spec.tasks[0].container.args)
            assert command.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    # Checked after both context managers have exited.
    assert context_manager.FlyteContextManager.size() == 1
def test_lps(resource_type):
    """Exercise a reference entity of ``resource_type`` locally, at compile time and when serialized.

    :param resource_type: resource type passed to ``get_reference_entity``; it also selects which
        node attribute is inspected on the serialized workflow.
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    # Inspect the raised exception itself; formatting the ExceptionInfo wrapper
    # only matches by way of its repr.
    assert "You must mock this out" in str(e.value)

    with ctx.new_compilation_context() as ctx:
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1

    # Where the reference lives on the node depends on the kind of entity referenced.
    node = sdk_wf.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert node.workflow_node.launchplan_ref.project == "proj"
        assert node.workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        assert node.workflow_node.sub_workflow_ref.project == "proj"
        assert node.workflow_node.sub_workflow_ref.name == "app.other.flyte_entity"
    else:
        assert node.task_node.reference_id.project == "proj"
        assert node.task_node.reference_id.name == "app.other.flyte_entity"
def test_serialization():
    """Serialize a workflow built from two raw-container tasks, then each task on its own.

    Checks that the workflow spec has three nodes and that both task specs
    keep the ``alpine`` image they were declared with.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    # Bound to ``sum_task`` rather than ``sum`` so the builtin is not shadowed;
    # the task's registered name is still "sum".
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"

    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
def test_lp_serialize():
    """Serialize launch plans with and without default inputs and inspect their parameters."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    defaulted_lp = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    # Without defaults, "a" is a required parameter.
    serialized = get_serializable(OrderedDict(), settings, plain_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert serialized.default_inputs.parameters["a"].required
    assert len(serialized.fixed_inputs.literals) == 0

    # With a default, "a" is no longer required and carries the literal 3.
    serialized = get_serializable(OrderedDict(), settings, defaulted_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert not serialized.default_inputs.parameters["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert serialized.default_inputs.parameters["a"].default == expected_default
    assert len(serialized.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    round_tripped = Parameter.from_flyte_idl(serialized.default_inputs.parameters["a"].to_flyte_idl())
    assert round_tripped.default is not None
def test_wf1_with_dynamic():
    """Run a workflow containing a dynamic task locally, then compile the dynamic task on its own."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    count = 5
    result = my_wf(a=count, b="hello ")
    expected_strings = ["world-" + str(i) for i in range(2, count + 2)]
    assert result == ("hello hello ", expected_strings)

    manager = context_manager.FlyteContextManager
    settings_ctx = manager.current_context().with_serialization_settings(
        context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    )
    with manager.with_context(settings_ctx) as ctx:
        task_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ctx.with_execution_state(task_exec_state)) as ctx:
            job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
            assert len(job_spec._nodes) == 5
            assert len(job_spec.tasks) == 1

    # Checked after both context managers have exited.
    assert context_manager.FlyteContextManager.size() == 1
def test_serialization_images():
    """Check how ``container_image`` templates resolve to concrete image strings on serialization."""

    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    @task
    def t3():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    this_dir = os.path.dirname(os.path.realpath(__file__))
    set_flyte_config_file(os.path.join(this_dir, "configs/images.config"))

    rs = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=get_image_config(),
    )

    # t1 is also converted to IDL after its image has been checked.
    first_spec = get_serializable(OrderedDict(), rs, t1)
    assert first_spec.template.container.image == "docker.io/xyz:version"
    first_spec.to_flyte_idl()

    remaining_cases = (
        (t2, "docker.io/default:version"),
        (t3, "docker.io/default:version"),
        (t4, "docker.io/org/myimage:latest"),
        (t5, "docker.io/org/myimage:version"),
    )
    for entity, expected_image in remaining_cases:
        spec = get_serializable(OrderedDict(), rs, entity)
        assert spec.template.container.image == expected_image
def test_lp_with_output():
    """Use a reference launch plan with outputs inside a workflow: mocked locally, then serialized."""
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(ref_lp)
    def inner_test(ref_mock):
        # With the reference patched, the workflow runs locally using the mocked outputs.
        ref_mock.return_value = (False, 30)
        result = wf1()
        assert result == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, wf1)
    lp_node = spec.template.nodes[1]
    assert lp_node.workflow_node.launchplan_ref.project == "proj"
    assert lp_node.workflow_node.launchplan_ref.name == "app.other.flyte_entity"
def test_serialization_workflow_def():
    """Serialize two workflows that each use a map task over the same function, sharing one entity cache."""

    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return maptask(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Both workflows are serialized into the same cache so its keys can be inspected afterwards.
    entity_cache = OrderedDict()
    for wf_under_test in (w1, w2):
        spec = get_serializable(entity_cache, settings, wf_under_test)
        assert spec.template is not None
        assert len(spec.template.nodes) == 1

    tasks_seen = [
        entity
        for entity in list(entity_cache.keys())
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]
    assert len(tasks_seen) == 2
    print(tasks_seen[0])
def test_resources():
    """Task-level ``requests`` and ``limits`` appear as resource entries on the serialized container."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    entry = _resource_models.ResourceEntry
    names = _resource_models.ResourceName
    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_new_compilation_state()):
        first = get_serializable(OrderedDict(), settings, t1).template.container.resources
        assert first.requests == [entry(names.CPU, "1")]
        assert first.limits == [
            entry(names.CPU, "2"),
            entry(names.MEMORY, "400M"),
        ]

        # t2 declares only requests, so its limits list is empty.
        second = get_serializable(OrderedDict(), settings, t2).template.container.resources
        assert second.requests == [entry(names.CPU, "3")]
        assert second.limits == []
def test_sql_command():
    """Check the trailing container args of the serialized module-level ``not_tk`` task."""
    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(OrderedDict(), settings, not_tk)
    expected_tail = [
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekit.extras.sqlite3.task.SQLite3TaskExecutor",
    ]
    assert serialized.template.container.args[-5:] == expected_tail
def test_nested_dynamic():
    """A dynamic task declared inside a workflow body runs locally and can be compiled afterwards."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        @dynamic
        def my_subwf(a: int) -> typing.List[str]:
            s = []
            for i in range(a):
                s.append(t1(a=i))
            return s

        x = t2(a=b, b=b)
        v = my_subwf(a=a)
        return x, v

    count = 5
    result = my_wf(a=count, b="hello ")
    assert result == ("hello hello ", ["world-" + str(i) for i in range(2, count + 2)])

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    # The nested dynamic task is fetched back from the workflow's task list.
    nested_my_subwf = my_wf.get_all_tasks()[0]

    base_ctx = context_manager.FlyteContext.current_context()
    with base_ctx.new_serialization_settings(serialization_settings=settings) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            job_spec = nested_my_subwf.compile_into_workflow(ctx, False, nested_my_subwf._task_function, a=5)
            assert len(job_spec._nodes) == 5
def test_sql_command():
    """Check the trailing container args of the serialized module-level ``not_tk`` task."""
    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    serialized = get_serializable(OrderedDict(), settings, not_tk)
    expected_tail = [
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "plugins.tests.sqlalchemy.test_task",
        "task-name",
        "tk",
    ]
    assert serialized.template.container.args[-7:] == expected_tail
def test_get_registrable_entities():
    """``get_registrable_entities`` returns three spec/launch-plan objects for the entity list [foo, wf, "str"].

    ``foo`` and ``wf`` are defined at module level outside this function.
    """
    ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
        context_manager.SerializationSettings(
            project="p",
            domain="d",
            version="v",
            image_config=context_manager.ImageConfig(
                default_image=context_manager.Image("def", "docker.io/def", "latest")
            ),
        )
    )

    # FlyteEntities.entities is shared class-level state. Put the previous value
    # back afterwards so this fixture list does not leak into other tests.
    previous_entities = context_manager.FlyteEntities.entities
    context_manager.FlyteEntities.entities = [foo, wf, "str"]
    try:
        entities = serialize.get_registrable_entities(ctx)
    finally:
        context_manager.FlyteEntities.entities = previous_entities

    assert entities
    assert len(entities) == 3

    for e in entities:
        assert isinstance(e, (WorkflowSpec, TaskSpec, LaunchPlan)), f"found unknown entity {type(e)}"
def test_ref_dynamic():
    """Compiling a dynamic task that calls a reference task is expected to raise."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="sample.reference.task",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: int) -> str:
        ...

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(ref_t1(a=i))
        return s

    manager = context_manager.FlyteContextManager
    settings_ctx = manager.current_context().with_serialization_settings(
        context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    )
    with manager.with_context(settings_ctx) as ctx:
        task_exec_state = ctx.execution_state.with_params(
            mode=context_manager.ExecutionState.Mode.TASK_EXECUTION
        )
        with manager.with_context(ctx.with_execution_state(task_exec_state)) as ctx:
            with pytest.raises(Exception):
                my_subwf.compile_into_workflow(ctx, False, my_subwf._task_function, a=5)
def test_ref_sub_wf():
    """A reference to a workflow can be called at compile time but cannot be serialized in a workflow."""
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    with pytest.raises(Exception) as e:
        ref_entity()
    # Inspect the raised exception itself; formatting the ExceptionInfo wrapper
    # only matches by way of its repr.
    assert "You must mock this out" in str(e.value)

    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()) as ctx:
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with pytest.raises(Exception):
        # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
        # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
        # is no network call.
        get_serializable(OrderedDict(), serialization_settings, wf1)
def test_dynamic():
    """Dispatch a dynamic task that calls the module-level ``ft`` with fast serialization enabled."""

    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        s = []
        for i in range(a):
            s.append(ft(a=i))
        return s

    manager = context_manager.FlyteContextManager
    settings_ctx = manager.current_context().with_serialization_settings(
        context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
            fast_serialization_settings=FastSerializationSettings(enabled=True),
        )
    )
    with manager.with_context(settings_ctx) as ctx:
        task_exec_state = ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION,
        )
        with manager.with_context(ctx.with_execution_state(task_exec_state)) as ctx:
            inputs = TypeEngine.dict_to_literal_map(ctx, {"a": 2})

            # Test that it works
            job_spec = my_subwf.dispatch_execute(ctx, inputs)
            assert len(job_spec._nodes) == 2
            assert len(job_spec.tasks) == 1
            assert job_spec.tasks[0].id == ft.id

            # Test that the fast execute stuff does not get applied because the commands of tasks fetched from
            # Admin should never change.
            command = " ".join(job_spec.tasks[0].container.args)
            assert not command.startswith("pyflyte-fast-execute")
def test_resource_overrides():
    """Resource overrides applied to a map-task node appear on the serialized node."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mappy = map_task(t1)
        map_node = create_node(mappy, a=a).with_overrides(
            requests=Resources(cpu="1", mem="100"), limits=Resources(cpu="2", mem="200")
        )
        return map_node.o0

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert len(spec.template.nodes) == 1
    assert spec.template.nodes[0].task_node.overrides is not None

    entry = _resources_models.ResourceEntry
    names = _resources_models.ResourceName
    assert spec.template.nodes[0].task_node.overrides.resources.requests == [
        entry(names.CPU, "1"),
        entry(names.MEMORY, "100"),
    ]
    assert spec.template.nodes[0].task_node.overrides.resources.limits == [
        entry(names.CPU, "2"),
        entry(names.MEMORY, "200"),
    ]
def test_serialization_branch_compound_conditions():
    """Serialize a conditional with an OR-ed condition and check the branch node's single input."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_wf(a: int) -> int:
        d = (
            conditional("test1")
            .if_((a == 4) | (a == 3))
            .then(t1(a=a))
            .elif_(a < 6)
            .then(t1(a=a))
            .else_()
            .fail("Unable to choose branch")
        )
        return d

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    spec = get_serializable(OrderedDict(), settings, my_wf)
    assert spec is not None
    # Only "a" feeds the conditional, even though it appears in three comparisons.
    assert len(spec.template.nodes[0].inputs) == 1
    assert spec.template.nodes[0].inputs[0].var == ".a"
import pytest from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import Image, ImageConfig from flytekit.core.python_auto_container import default_task_resolver from flytekit.core.task import task from flytekit.core.workflow import workflow default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( project="project", domain="domain", version="version", env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) def test_wf_resolving(): @workflow def my_wf(a: int, b: str) -> (int, str): @task def t1(a: int) -> (int, str): return a + 2, "world" @task def t2(a: str, b: str) -> str: return b + a
def test_subworkflow_condition_serialization():
    """Test that subworkflows are correctly extracted from serialized workflows with conditionals."""

    @task
    def t() -> int:
        return 5

    @workflow
    def wf1() -> int:
        return t()

    @workflow
    def wf2() -> int:
        return t()

    @workflow
    def wf3() -> int:
        return t()

    @workflow
    def wf4() -> int:
        return t()

    @workflow
    def ifelse_branching(x: int) -> int:
        return conditional("simple branching test").if_(x == 2).then(wf1()).else_().then(wf2())

    @workflow
    def ifelse_branching_fail(x: int) -> int:
        return conditional("simple branching test").if_(x == 2).then(wf1()).else_().fail("failed")

    @workflow
    def if_elif_else_branching(x: int) -> int:
        return (
            conditional("test").if_(x == 2).then(wf1()).elif_(x == 3).then(wf2()).elif_(x == 4).then(wf3()).else_().then(wf4())  # noqa
        )

    @workflow
    def wf5() -> int:
        return t()

    @workflow
    def nested_branching(x: int) -> int:
        return conditional("nested test").if_(x == 2).then(ifelse_branching(x=x)).else_().then(wf5())

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    def qualified(*short_names):
        # Expected names are prefixed with this test module's name.
        return ["test_conditions.{}".format(short) for short in short_names]

    cases = [
        (ifelse_branching, qualified("wf1", "wf2")),
        (ifelse_branching_fail, qualified("wf1")),
        (if_elif_else_branching, qualified("wf1", "wf2", "wf3", "wf4")),
        (nested_branching, qualified("ifelse_branching", "wf1", "wf2", "wf5")),
    ]
    for wf, expected_subworkflows in cases:
        spec = get_serializable(OrderedDict(), settings, wf)
        found = spec.sub_workflows

        for sub_wf in found:
            assert sub_wf.id.name in expected_subworkflows
        assert len(found) == len(expected_subworkflows)
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Serialize every discovered entity and write each one to its own ``.pb`` file.

    Admin's registration endpoints accept these objects:
        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    Calling ``.to_flyte_idl()`` directly on the discovered entities would instead yield:
        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.core.workflow_pb2.WorkflowTemplate
        flyteidl.core.tasks_pb2.TaskTemplate

    so workflows and tasks go through ``get_serializable`` and ``serialize()`` rather than being converted directly.

    :param pkgs: packages passed to ``iterate_registerable_entities_in_order``.
    :param local_source_root: forwarded to ``iterate_registerable_entities_in_order``.
    :param folder: when truthy, output file names are joined onto this directory.
    :param mode: serialization mode; a falsy value means ``SerializationMode.DEFAULT``.
    :param image: image name used for the image config and the image environment variable.
    :param config_path: configuration path; when falsy, the configured default is read instead.
    :param flytekit_virtualenv_root: joined with the default relative entrypoint location to form the entrypoint path.
    :raises AssertionError: if a task is encountered and ``mode`` is neither DEFAULT nor FAST.
    """
    resolved_config_path = config_path if config_path else _internal_config.CONFIGURATION_PATH.get()
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: resolved_config_path,
        _internal_config.IMAGE.env_var: image,
    }

    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )

    base_ctx = flyte_context.FlyteContext.current_context()
    with base_ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
        loaded_entities = []

        # module = python module, attr_name = name found via dir(module), obj = the entity (e.g. SdkWorkflow)
        for module, attr_name, obj in iterate_registerable_entities_in_order(
            pkgs, local_source_root=local_source_root
        ):
            name = _utils.fqdn(module.__name__, attr_name, entity_type=obj.resource_type)
            _logging.debug(
                "Found module {}\n K: {} Instantiated in {}".format(module, attr_name, obj._instantiated_in)
            )
            obj._id = _identifier.Identifier(
                obj.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            loaded_entities.append(obj)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=module, name=attr_name, o=obj))

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode or SerializationMode.DEFAULT
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                continue

            # Only tasks are affected by the serialization mode.
            is_task = isinstance(entity, PythonTask)
            if not is_task or mode == SerializationMode.DEFAULT:
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
            else:
                raise AssertionError(f"Unrecognized serialization mode: {mode}")
            loaded_entities.append(serializable)

            # Each workflow also gets its default launch plan serialized.
            if isinstance(entity, Workflow):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                loaded_entities.append(get_serializable(ctx.serialization_settings, default_lp))

        pad_width = _determine_text_chars(len(loaded_entities))
        for index, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue

            serialized = entity.serialize()
            fname = "{}_{}_{}.pb".format(str(index).zfill(pad_width), entity.id.name, entity.id.resource_type)
            click.echo(f" Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(serialized, fname)

        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")