def test_wf_resolving():
    """Run ``my_wf`` locally, then check each of its tasks serializes with the workflow as resolver."""
    result = my_wf(a=3, b="hello")
    assert result == (5, "helloworld")

    # The location is expected to be the dotted path of the name the workflow is bound to.
    expected_location = "tests.flytekit.unit.core.test_resolver.my_wf"
    assert my_wf.location == expected_location

    tasks_in_wf = my_wf.get_all_tasks()
    assert len(tasks_in_wf) == 2  # Two tasks were declared inside

    # Each task should carry the workflow's location as its resolver,
    # and its own index as the only resolver argument.
    for index, wf_task in enumerate(tasks_in_wf):
        spec = get_serializable(OrderedDict(), serialization_settings, wf_task)
        assert spec.template.container.args[-4:] == [
            "--resolver",
            expected_location,
            "--",
            str(index),
        ]
def test_basics():
    """Serialize a two-task workflow, one task under fast-serialization settings, and a launch plan."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    wf_template = get_serializable(OrderedDict(), serialization_settings, my_wf).template
    assert len(wf_template.interface.inputs) == 2
    assert len(wf_template.interface.outputs) == 2
    assert len(wf_template.nodes) == 2
    assert wf_template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # Original author's note: gets cached the first time around so it's not actually fast.
    settings_builder = serialization_settings.new_builder()
    fast_settings = settings_builder.with_fast_serialization_settings(
        FastSerializationSettings(enabled=True)
    ).build()
    fast_task_spec = get_serializable(OrderedDict(), fast_settings, t1)
    assert "pyflyte-execute" in fast_task_spec.template.container.args

    launch_plan = LaunchPlan.create("testlp", my_wf)
    launch_plan_model = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert launch_plan_model.id.name == "testlp"
def test_serialization():
    """Serialize a workflow built from two raw-container tasks and check the resulting specs.

    The workflow squares both inputs and adds the results, so three nodes are expected.
    Both container tasks must keep their explicitly configured ``alpine`` image.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    # Bound to ``sum_task`` rather than ``sum`` so the builtin is not shadowed inside this test;
    # the Flyte task name itself is still "sum".
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out",
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"

    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum_task)
    assert sumn_spec.template.container.image == "alpine"
def test_normal_task(mock_client):
    """Fetch a task through a mocked client and check a workflow node references it by name."""
    task_proto_path = os.path.join(responses_dir, "admin.task_pb2.Task.pb")
    task_proto = load_proto_from_file(task_pb2.Task, task_proto_path)
    # Whatever the client is asked for, it hands back the task loaded from the stored response.
    mock_client.get_task.return_value = task_models.Task.from_flyte_idl(task_proto)
    fetched_task = rr.fetch_task(name="merge_sort_remotely", version="tst")

    @workflow
    def my_wf(numbers: typing.List[int], run_local_at_count: int) -> typing.List[int]:
        t1_node = create_node(fetched_task, numbers=numbers, run_local_at_count=run_local_at_count)
        return t1_node.o0

    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
    )
    first_node = get_serializable(OrderedDict(), settings, my_wf).template.nodes[0]
    assert first_node.task_node.reference_id.name == "merge_sort_remotely"
def test_nonfunction_task_and_df_input():
    """Chain a SQLite task and two reference tasks in an imperative workflow and check its interface."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="ref_t1",
        version="fast56d8ce2e373baf011f4d3532e45f0a9b",
    )
    def ref_t1(
        dataframe: pd.DataFrame,
        imputation_method: str = "median",
    ) -> pd.DataFrame:
        ...

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="ref_t2",
        version="aedbd6fe44051c171fd966c280c5c3036f658831",
    )
    def ref_t2(
        dataframe: pd.DataFrame,
        split_mask: int,
        num_features: int,
    ) -> pd.DataFrame:
        ...

    builder = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
    builder.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])

    sqlite_config = SQLite3Config(
        uri="https://sample/data",
        compressed=True,
    )
    sql_task = SQLite3Task(
        name="dummy.sqlite.task",
        query_template="select * from data",
        inputs=kwtypes(),
        output_schema_type=FlyteSchema,
        task_config=sqlite_config,
    )

    # Wire the three tasks in sequence: SQL results -> ref_t1 -> ref_t2.
    sql_node = builder.add_task(sql_task)
    t1_node = builder.add_task(ref_t1, dataframe=sql_node.outputs["results"], imputation_method="mean")
    t2_node = builder.add_task(
        ref_t2,
        dataframe=t1_node.outputs["o0"],
        split_mask=24,
        num_features=15,
    )
    builder.add_workflow_output("output_from_t3", t2_node.outputs["o0"], python_type=pd.DataFrame)

    wf_template = get_serializable(OrderedDict(), serialization_settings, builder).template
    assert len(wf_template.nodes) == 3

    interface = wf_template.interface
    assert len(interface.inputs) == 1
    assert interface.inputs["sqlite_archive"].type.blob is not None
    assert len(interface.outputs) == 1

    output_type = interface.outputs["output_from_t3"].type
    assert output_type.structured_dataset_type is not None
    assert output_type.structured_dataset_type == StructuredDatasetType(format="parquet")
def test_condition_tuple_branches():
    """Unpack a two-output task from a conditional branch, run it, and check the serialized branch node."""

    @task
    def sum_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, sub=int):
        return a + b, a - b

    @workflow
    def math_ops(a: int, b: int) -> typing.Tuple[int, int]:
        add, sub = (
            conditional("noDivByZero")
            .if_(a > b)
            .then(sum_sub(a=a, b=b))
            .else_()
            .fail("Only positive results are allowed")
        )
        return add, sub

    total, difference = math_ops(a=3, b=2)
    assert total == 5
    assert difference == 1

    nodes = get_serializable(OrderedDict(), serialization_settings, math_ops).template.nodes
    assert len(nodes) == 1

    then_node = nodes[0].branch_node.if_else.case.then_node
    assert then_node.task_node.reference_id.name == "tests.flytekit.unit.core.test_conditions.sum_sub"
def test_pod_task_serialized():
    """Serialize a pod-configured task and check the pod metadata and config on the template."""
    pod_config = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )
    template = get_serializable(OrderedDict(), settings, simple_pod_task).template

    assert template.task_type_version == 2
    assert template.config["primary_container_name"] == "an undefined container"
    assert template.k8s_pod.metadata.labels == {"label": "foo"}
    assert template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert template.k8s_pod.pod_spec is not None
def test_serialization_branch():
    """Run a workflow with an if/else conditional on both branches, then check its serialized nodes."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        c = mimic(a=a)
        return (
            conditional("test1")
            .if_(c.c == 4)
            .then(t1(c=c.c).c)
            .else_()
            .then(t2().c)
        )

    # a == 4 takes the "if" branch, anything else falls through to "else".
    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )
    serialized_wf = get_serializable(OrderedDict(), settings, my_wf)
    assert serialized_wf is not None

    nodes = serialized_wf.template.nodes
    assert len(nodes) == 2
    assert nodes[1].branch_node is not None
def test_launch_plan_with_fixed_input():
    """Serialize a workflow that calls a launch plan with fixed and default inputs."""

    @task
    def greet(day_of_week: str, number: int, am: bool) -> str:
        greeting = "Have a great " + day_of_week + " "
        greeting += "morning" if am else "evening"
        return greeting + "!" * number

    @workflow
    def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
        return greet(day_of_week=day_of_week, number=number, am=am)

    morning_greeting = LaunchPlan.create(
        "morning_greeting",
        go_greet,
        fixed_inputs={"am": True},
        default_inputs={"number": 1},
    )

    @workflow
    def morning_greeter_caller(day_of_week: str) -> str:
        greeting = morning_greeting(day_of_week=day_of_week)
        return greeting

    settings_builder = serialization_settings.new_builder()
    fast_settings = settings_builder.with_fast_serialization_settings(
        FastSerializationSettings(enabled=True)
    ).build()

    # The serialized entity is the calling workflow.
    caller_template = get_serializable(OrderedDict(), fast_settings, morning_greeter_caller).template
    assert len(caller_template.interface.inputs) == 1
    assert len(caller_template.interface.outputs) == 1
    assert len(caller_template.nodes) == 1
    assert len(caller_template.nodes[0].inputs) == 2
def test_resolver_load_task():
    """Write a serialized task template to disk and load it back through ``TaskTemplateResolver``.

    The template is written inside a temporary directory that is removed when the test ends.
    Taking ``NamedTemporaryFile().name`` instead would discard the handle at once, leaving a
    reused path that races with other processes and a written file that is never cleaned up.
    """
    import os  # function-scope import so the block does not rely on a module-level one

    # any task is fine, just copied one
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh",
            "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out",
        ],
    )

    resolver = TaskTemplateResolver()
    ts = get_serializable(OrderedDict(), serialization_settings, square)

    with tempfile.TemporaryDirectory() as tmp_dir:
        file = os.path.join(tmp_dir, "task_template.pb")
        # load_task should create an instance of the path to the object given, doesn't need to be a real executor
        write_proto_to_file(ts.template.to_flyte_idl(), file)
        shim_task = resolver.load_task([file, f"{Placeholder.__module__}.Placeholder"])

        # Assertions stay inside the context in case any template data is read lazily from the file.
        assert isinstance(shim_task.executor, Placeholder)
        assert shim_task.task_template.id.name == "square"
        assert shim_task.task_template.interface.inputs["val"] is not None
        assert shim_task.task_template.interface.outputs["out"] is not None
def test_wf_nested_comp():
    """Run and serialize a workflow that declares another workflow inside its own body."""

    @task
    def t1(a: int) -> int:
        return a + 5

    @workflow
    def outer() -> typing.Tuple[int, int]:
        # You should not do this. This is just here for testing.
        @workflow
        def wf2() -> int:
            return t1(a=5)

        return t1(a=3), wf2()

    assert (8, 10) == outer()

    serialized_entities = OrderedDict()
    outer_spec = get_serializable(serialized_entities, serialization_settings, outer)
    outer_template = outer_spec.template
    assert len(outer_template.interface.outputs) == 2
    assert len(outer_template.nodes) == 2
    assert outer_template.nodes[1].workflow_node is not None

    # The nested workflow is expected to contain just the one task node.
    nested_nodes = outer_spec.sub_workflows[0].nodes
    assert len(nested_nodes) == 1
    only_node = nested_nodes[0]
    assert only_node.id == "n0"
    assert only_node.task_node.reference_id.name == "tests.flytekit.unit.core.test_workflows.t1"
def test_resources_override():
    """Check that ``with_overrides`` on a mapped task node shows up as resource overrides after serialization."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mappy = map_task(t1)
        map_node = mappy(a=a).with_overrides(
            requests=Resources(cpu="1", mem="100", ephemeral_storage="500Mi"),
            limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi"),
        )
        return map_node

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    nodes = get_serializable(OrderedDict(), settings, my_wf).template.nodes
    assert len(nodes) == 1

    overrides = nodes[0].task_node.overrides
    assert overrides is not None

    # Short aliases keep the expected lists readable.
    entry = _resources_models.ResourceEntry
    resource_name = _resources_models.ResourceName

    assert overrides.resources.requests == [
        entry(resource_name.CPU, "1"),
        entry(resource_name.MEMORY, "100"),
        entry(resource_name.EPHEMERAL_STORAGE, "500Mi"),
    ]
    assert overrides.resources.limits == [
        entry(resource_name.CPU, "2"),
        entry(resource_name.MEMORY, "200"),
        entry(resource_name.EPHEMERAL_STORAGE, "1Gi"),
    ]
def test_serialization(serialization_settings):
    """Serialize a map task over ``t1`` and check its type, custom fields and container arguments."""
    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.flytekit.unit.core.test_map_task",
        "task-name",
        "t1",
    ]

    mapped = map_task(t1, metadata=TaskMetadata(retries=1))
    template = get_serializable(OrderedDict(), serialization_settings, mapped).template

    # By default all map_task tasks will have their custom fields set.
    assert template.custom["minSuccessRatio"] == 1.0
    assert template.type == "container_array"
    assert template.task_type_version == 1
    assert template.container.args == expected_args
def test_serialization_branch_complex_2():
    """Serialize a workflow with two chained conditionals and check the first branch node's input binding."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = (
            conditional("test2")
            .if_(d == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return x, f

    default_image = Image(name="default", fqn="test", tag="tag")
    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_image, images=[default_image]),
    )
    serialized_wf = get_serializable(OrderedDict(), settings, my_wf)
    assert serialized_wf is not None

    first_branch_input = serialized_wf.template.nodes[1].inputs[0]
    assert first_branch_input.var == "n0.t1_int_output"
def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings):
    """Check that keyword options passed to ``map_task`` end up in the template's ``custom`` field."""
    mapped = map_task(t1, **custom_fields_dict)
    serialized_custom = get_serializable(OrderedDict(), serialization_settings, mapped).template.custom
    assert serialized_custom == expected_custom_fields
def test_example_module():
    """Check that a task returning ``torch.nn.Module`` serializes its output with the PyTorch blob format."""

    @task
    def t1() -> torch.nn.Module:
        return torch.nn.BatchNorm1d(3, track_running_stats=True)

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: ``is`` only passes while both sides happen to be the very same string
    # object, which is an implementation detail rather than what this test means to check.
    assert task_spec.template.interface.outputs["o0"].type.blob.format == PyTorchModuleTransformer.PYTORCH_FORMAT
def test_example_tensor():
    """Check that a task returning ``torch.Tensor`` serializes its output with the PyTorch blob format."""

    @task
    def t1(array: torch.Tensor) -> torch.Tensor:
        return torch.flatten(array)

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: ``is`` only passes while both sides happen to be the very same string
    # object, which is an implementation detail rather than what this test means to check.
    assert task_spec.template.interface.outputs["o0"].type.blob.format == PyTorchTensorTransformer.PYTORCH_FORMAT
def test_wf_docstring():
    """Check the input and output descriptions on the serialized interface of ``my_wf_example``."""
    interface = get_serializable(OrderedDict(), serialization_settings, my_wf_example).template.interface

    outputs = interface.outputs
    assert len(outputs) == 2
    assert outputs["o0"].description == "outputs"
    assert outputs["o1"].description == "outputs"

    inputs = interface.inputs
    assert len(inputs) == 1
    assert inputs["a"].description == "input a"
def test_example():
    """Check that a task returning ``np.ndarray`` serializes its output with the NumPy array blob format."""

    @task
    def t1(array: np.ndarray) -> np.ndarray:
        return array.flatten()

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: ``is`` only passes while both sides happen to be the very same string
    # object, which is an implementation detail rather than what this test means to check.
    assert task_spec.template.interface.outputs["o0"].type.blob.format == NumpyArrayTransformer.NUMPY_ARRAY_FORMAT
def test_serialization_images():
    """Check that ``container_image`` templates are resolved against the image config when serializing.

    The test sets ``FLYTE_INTERNAL_IMAGE`` and puts back its previous state on exit, so the
    value does not leak into tests that run afterwards in the same process.
    """

    @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.abc.fqn}}:{{.image.xyz.version}}")
    def t2():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.xyz.version}}")
    def t5(a: int) -> int:
        return a

    @task(container_image="{{.image.xyz_123.fqn}}:{{.image.xyz_123.version}}")
    def t6(a: int) -> int:
        return a

    env_key = "FLYTE_INTERNAL_IMAGE"
    previous_image = os.environ.get(env_key)
    os.environ[env_key] = "docker.io/default:version"
    try:
        imgs = ImageConfig.auto(
            config_file=os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
        )
        rs = flytekit.configuration.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=imgs,
        )

        t1_spec = get_serializable(OrderedDict(), rs, t1)
        assert t1_spec.template.container.image == "docker.io/xyz:latest"
        # Also make sure the spec converts to its IDL form without raising.
        t1_spec.to_flyte_idl()

        t2_spec = get_serializable(OrderedDict(), rs, t2)
        assert t2_spec.template.container.image == "docker.io/abc:latest"

        t4_spec = get_serializable(OrderedDict(), rs, t4)
        assert t4_spec.template.container.image == "docker.io/org/myimage:latest"

        t5_spec = get_serializable(OrderedDict(), rs, t5)
        assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

        t6_spec = get_serializable(OrderedDict(), rs, t6)
        assert t6_spec.template.container.image == "docker.io/xyz_123:v1"
    finally:
        # Restore the environment exactly as it was found.
        if previous_image is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_image
def get_registrable_entities(
    ctx: flyte_context.FlyteContext, options: typing.Optional[Options] = None
) -> typing.List[RegistrableEntity]:
    """
    Serialize every known task, workflow and launch plan and return the IDL objects to register.

    Each workflow also gets its default launch plan serialized. Serialized models rejected by
    ``_should_register_with_admin`` are left out of the result.

    :param ctx: context whose ``serialization_settings`` are used for every serialization call
    :param options: optional serialization options forwarded to ``get_serializable``
    :raises FlyteValidationException: if ``_find_duplicate_tasks`` reports any duplicated task spec
    :return: the ``to_flyte_idl()`` form of every retained entity
    """
    serialized = OrderedDict()
    registrable_types = (PythonTask, WorkflowBase, LaunchPlan)

    # TODO: Clean up the copy() - it's here because get_default_launch_plan may create a LaunchPlan
    # object, which gets added to the FlyteEntities.entities list that is being iterated over.
    for entity in flyte_context.FlyteEntities.entities.copy():
        if not isinstance(entity, registrable_types):
            continue
        get_serializable(serialized, ctx.serialization_settings, entity, options=options)
        if isinstance(entity, WorkflowBase):
            default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
            get_serializable(serialized, ctx.serialization_settings, default_lp, options)

    to_register = [model for model in serialized.values() if _should_register_with_admin(model)]

    task_specs: typing.List[task_models.TaskSpec] = [
        model for model in to_register if isinstance(model, task_models.TaskSpec)
    ]
    # Duplicate tasks are those sharing the same metadata identifiers
    # (see :py:class:`flytekit.common.core.identifier.Identifier`). They are invalid at registration
    # time and usually indicate user error, so this common mistake is caught at serialization time.
    duplicates = _find_duplicate_tasks(task_specs)
    if len(duplicates) > 0:
        duplicate_task_names = [dup.template.id.name for dup in duplicates]
        raise FlyteValidationException(
            f"Multiple definitions of the following tasks were found: {duplicate_task_names}"
        )

    return [model.to_flyte_idl() for model in to_register]
def test_calling_wf():
    """Promote serialized workflows back to ``FlyteWorkflow`` objects and call them from parent workflows."""
    # No way to fetch from Admin in unit tests so we serialize and then promote back
    sub_serialized = OrderedDict()
    sub_spec = get_serializable(sub_serialized, serialization_settings, sub_wf)
    task_templates, _sub_wf_specs, _sub_lp_specs = gather_dependent_entities(sub_serialized)
    fwf = FlyteWorkflow.promote_from_model(sub_spec.template, tasks=task_templates)

    @workflow
    def parent_1(a: int, b: str) -> typing.Tuple[int, str]:
        y = t1(a=a)
        return fwf(a=y, b=b)

    # Same serialize-then-promote approach for the first parent workflow.
    p1_serialized = OrderedDict()
    p1_spec = get_serializable(p1_serialized, serialization_settings, parent_1)

    # Get task_specs from the second one, merge with the first one. Admin normally would be the one to do this.
    task_templates_p1, _p1_wf_specs, _p1_lp_specs = gather_dependent_entities(p1_serialized)
    for template_id, template in task_templates.items():
        task_templates_p1[template_id] = template

    # Pick out the subworkflow templates from the ordereddict. The output of gather_dependent_entities
    # cannot be used for this because that function only looks for WorkflowSpecs.
    subwf_templates = {
        entity.id: entity for entity in p1_serialized.values() if isinstance(entity, WorkflowTemplate)
    }
    fwf_p1 = FlyteWorkflow.promote_from_model(
        p1_spec.template,
        sub_workflows=subwf_templates,
        tasks=task_templates_p1,
    )

    @workflow
    def parent_2(a: int, b: str) -> typing.Tuple[int, str]:
        x, y = fwf_p1(a=a, b=b)
        z = t1(a=x)
        return z, y

    p2_spec = get_serializable(OrderedDict(), serialization_settings, parent_2)

    # Make sure both were picked up.
    assert len(p2_spec.sub_workflows) == 2
def test_nested_condition_2():
    """Serialize and run a workflow whose conditional holds another conditional in its first branch."""

    @workflow
    def multiplier_2(my_input: float) -> float:
        return (
            conditional("fractions")
            .if_((my_input > 0.1) & (my_input < 1.0))
            .then(
                conditional("inner_fractions")
                .if_(my_input < 0.5)
                .then(double(n=my_input))
                .elif_((my_input > 0.5) & (my_input < 0.7))
                .then(square(n=my_input))
                .else_()
                .fail("Only <0.7 allowed")
            )
            .elif_((my_input > 1.0) & (my_input < 10.0))
            .then(square(n=my_input))
            .else_()
            .then(double(n=my_input))
        )

    nodes = get_serializable(OrderedDict(), serialization_settings, multiplier_2).template.nodes
    assert len(nodes) == 1

    outer_node = nodes[0]
    assert isinstance(outer_node, Node)
    assert outer_node.id == "n0"
    assert outer_node.branch_node is not None

    outer_if_else = outer_node.branch_node.if_else
    assert outer_if_else is not None
    assert outer_if_else.case is not None
    assert outer_if_else.case.then_node is not None

    # The first case of the outer conditional is itself a branch node.
    inner_node = outer_if_else.case.then_node
    assert inner_node.id == "n0"
    inner_if_else = inner_node.branch_node.if_else
    assert inner_if_else.case.then_node.task_node is not None
    assert inner_if_else.case.then_node.id == "n0"
    assert len(inner_if_else.other) == 1
    assert inner_if_else.other[0].then_node.id == "n1"

    # Ensure other cases exist
    assert len(outer_if_else.other) == 1
    outer_elif_node = outer_if_else.other[0].then_node
    assert outer_elif_node.task_node is not None
    assert outer_elif_node.id == "n1"

    with pytest.raises(ValueError):
        multiplier_2(my_input=0.7)

    assert multiplier_2(my_input=0.3) == 0.6
    assert multiplier_2(my_input=5.0) == 25
    assert multiplier_2(my_input=10.0) == 20
def test_ref_sub_wf():
    """Check that a workflow calling a reference *workflow* entity is rejected at serialization time."""
    reference_wf = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    current_ctx = context_manager.FlyteContext.current_context()

    # Called under the current context, the reference entity is expected to raise.
    with pytest.raises(Exception) as e:
        reference_wf()
    assert "You must mock this out" in f"{e}"

    compilation_ctx = current_ctx.with_new_compilation_state()
    with context_manager.FlyteContextManager.with_context(compilation_ctx) as current_ctx:
        # Under a compilation state, missing inputs are reported instead.
        with pytest.raises(Exception) as e:
            reference_wf()
        assert "Input was not specified" in f"{e}"

        promise = reference_wf(a="hello", b=3)
        assert isinstance(promise, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        reference_wf(a=a, b=b)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    # Subworkflows as references don't work (probably ever): a network call to admin would be needed to
    # get the structure of the subworkflow, and the whole point of reference entities is that there is
    # no network call.
    with pytest.raises(Exception, match="currently unsupported"):
        get_serializable(OrderedDict(), settings, wf1)
def test_serialization_named_return():
    """Check that a workflow returning a two-field ``NamedTuple`` serializes outputs named after the fields."""

    @task
    def t1() -> str:
        return "Hello"

    @workflow
    def wf() -> typing.NamedTuple("OP", a=str, b=str):
        return t1(), t1()

    outputs = get_serializable(OrderedDict(), serialization_settings, wf).template.interface.outputs
    assert len(outputs) == 2
    assert list(outputs.keys()) == ["a", "b"]
def test_lps(resource_type):
    """Serialize a workflow that calls a reference entity of ``resource_type`` and check the node's reference."""
    reference = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    current_ctx = context_manager.FlyteContext.current_context()

    # Called under the current context, the reference entity is expected to raise.
    with pytest.raises(Exception) as e:
        reference()
    assert "You must mock this out" in f"{e}"

    compilation_ctx = current_ctx.with_new_compilation_state()
    with context_manager.FlyteContextManager.with_context(compilation_ctx) as current_ctx:
        # Under a compilation state, missing inputs are reported instead.
        with pytest.raises(Exception) as e:
            reference()
        assert "Input was not specified" in f"{e}"

        promise = reference(a="hello", b=3)
        assert isinstance(promise, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        reference(a=a, b=b)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    template = get_serializable(OrderedDict(), settings, wf1).template
    assert len(template.interface.inputs) == 2
    assert len(template.interface.outputs) == 0
    assert len(template.nodes) == 1

    # Launch plans are referenced through a workflow node; every other resource type through a task node.
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert template.nodes[0].workflow_node.launchplan_ref.project == "proj"
        assert template.nodes[0].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    else:
        assert template.nodes[0].task_node.reference_id.project == "proj"
        assert template.nodes[0].task_node.reference_id.name == "app.other.flyte_entity"
def test_calling_lp():
    """Promote a serialized launch plan to a ``FlyteLaunchPlan`` and call it from a workflow.

    The node produced for the call must reference the id of the launch plan that was serialized.
    """
    sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
    serialized = OrderedDict()
    lp_model = get_serializable(serialized, serialization_settings, sub_wf_lp)

    # Only the workflow specs are needed; take the first one, as the original ``for ...: break`` did.
    _, wf_specs, _ = gather_dependent_entities(serialized)
    _, spec = next(iter(wf_specs.items()))

    remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
    # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
    # part of the IDL object but is something FlyteRemote does
    remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)

    serialized = OrderedDict()

    @workflow
    def wf2(a: int) -> typing.Tuple[int, str]:
        return remote_lp(a=a, b="hello")

    wf_spec = get_serializable(serialized, serialization_settings, wf2)
    assert wf_spec.template.nodes[0].workflow_node.launchplan_ref == lp_model.id
def test_serialization_named_outputs_single():
    """Check serialization and local execution of a workflow with a single named output."""

    @task
    def t1() -> typing.NamedTuple("OP", a=str):
        return "Hello"

    @workflow
    def wf() -> typing.NamedTuple("OP", a=str):
        return t1().a

    outputs = get_serializable(OrderedDict(), serialization_settings, wf).template.interface.outputs
    assert len(outputs) == 1
    assert list(outputs.keys()) == ["a"]

    # Running locally gives a result whose ``a`` attribute holds the task's return value.
    result = wf()
    assert result.a == "Hello"
def test_serialization_types():
    """Check that ``List[Dict[str, int]]`` outputs serialize as a collection of integer-valued maps."""

    @task(cache=True, cache_version="1.0.0")
    def squared(value: int) -> typing.List[typing.Dict[str, int]]:
        return [
            {"squared_value": value**2},
        ]

    @workflow
    def compute_square_wf(input_integer: int) -> typing.List[typing.Dict[str, int]]:
        compute_square_result = squared(value=input_integer)
        return compute_square_result

    # The workflow and the task are serialized separately and must agree on the output type.
    wf_output = get_serializable(
        OrderedDict(), serialization_settings, compute_square_wf
    ).template.interface.outputs["o0"]
    assert wf_output.type.collection_type.map_value_type.simple == SimpleType.INTEGER

    task_output = get_serializable(
        OrderedDict(), serialization_settings, squared
    ).template.interface.outputs["o0"]
    assert task_output.type.collection_type.map_value_type.simple == SimpleType.INTEGER
def test_all_node_types():
    """Run ``my_wf_example`` locally, then check the node layout of its serialized form."""
    assert my_wf_example(a=1) == (6, 16)

    serialized_entities = OrderedDict()
    wf_model = get_serializable(serialized_entities, serialization_settings, my_wf_example)

    template = wf_model.template
    assert len(template.interface.outputs) == 2
    assert len(template.nodes) == 4
    assert template.nodes[2].workflow_node is not None

    # The first sub-workflow is expected to hold a single task node.
    sub_wf_nodes = wf_model.sub_workflows[0].nodes
    assert len(sub_wf_nodes) == 1
    only_node = sub_wf_nodes[0]
    assert only_node.id == "n0"
    assert only_node.task_node.reference_id.name == "tests.flytekit.unit.core.test_workflows.add_5"