def test_serialization_branch_complex():
    """Serialize a workflow with two chained conditionals and check the resulting node layout."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        num, word = t1(a=a)
        # First branch selects t2's input based on t1's integer output, failing if no case matches.
        picked = (
            conditional("test1")
            .if_(num == 4)
            .then(t2(a=b))
            .elif_(num >= 5)
            .then(t2(a=word))
            .else_()
            .fail("Unable to choose branch")
        )
        # Second branch consumes the output of the first branch.
        final = (
            conditional("test2")
            .if_(picked == "hello ")
            .then(t2(a="It is hello"))
            .else_()
            .then(t2(a="Not Hello!"))
        )
        return num, final

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf = get_serializable(serialization_settings, my_wf)
    assert wf is not None
    # One task node followed by the two branch nodes.
    assert len(wf.nodes) == 3
    assert wf.nodes[1].branch_node is not None
    assert wf.nodes[2].branch_node is not None
def test_serialization_branch():
    """A single if/else conditional runs locally and serializes to two nodes."""

    @task
    def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
        return (a,)

    @task
    def t1(c: int) -> typing.NamedTuple("OutputsBC", c=str):
        return ("world",)

    @task
    def t2() -> typing.NamedTuple("OutputsBC", c=str):
        return ("hello",)

    @workflow
    def my_wf(a: int) -> str:
        mimicked = mimic(a=a)
        # Branch on mimic's single named output `c`.
        return (
            conditional("test1")
            .if_(mimicked.c == 4)
            .then(t1(c=mimicked.c).c)
            .else_()
            .then(t2().c)
        )

    assert my_wf(a=4) == "world"
    assert my_wf(a=2) == "hello"

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf = get_serializable(serialization_settings, my_wf)
    assert wf is not None
    assert len(wf.nodes) == 2
    assert wf.nodes[1].branch_node is not None
def test_serialization_branch_sub_wf():
    """A conditional whose else-branch calls a sub-workflow serializes with its input binding."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf = get_serializable(serialization_settings, my_wf)
    assert wf is not None
    # Check the node itself before inspecting its attributes, so this assertion can actually fail.
    assert wf.nodes[0] is not None
    assert len(wf.nodes[0].inputs) == 1
    assert wf.nodes[0].inputs[0].var == ".a"
def test_workflow_values():
    """Workflow-level interruptible and failure_policy settings appear in the serialized metadata."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        bumped = a + 2
        return bumped, "world-" + str(bumped)

    @workflow(
        interruptible=True,
        failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )
    def wf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    serialization_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf)
    assert sdk_wf.metadata_defaults.interruptible
    # on_failure is compared as an integer; presumably 1 is FAIL_AFTER_EXECUTABLE_NODES_COMPLETE's wire value.
    assert sdk_wf.metadata.on_failure == 1
def test_query_no_inputs_or_outputs():
    """A Hive query with no inputs and no output schema serializes with an empty interface."""
    hive_query = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template=""" insert into extant_table (1, 'two') """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_query()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    sdk_task = get_serializable(serialization_settings, hive_query)
    assert len(sdk_task.interface.inputs) == 0
    assert len(sdk_task.interface.outputs) == 0

    # The workflow wrapping the query must also serialize without error.
    get_serializable(serialization_settings, my_wf)
def test_ref():
    """A reference task keeps its declared identifier whatever serialization settings are used."""
    expected_id = {
        "project": "flytesnacks",
        "domain": "development",
        "name": "recipes.aaa.simple.join_strings",
        "version": "553018f39e519bdb2597b652639c30ce16b99c79",
    }

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    for field_name, field_value in expected_id.items():
        assert getattr(ref_t1.id, field_name) == field_value

    first_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    ss = get_serializable(first_settings, ref_t1)
    assert ss.id == ref_t1.id
    assert ss.interface.inputs["a"] is not None
    assert ss.interface.outputs["o0"] is not None

    # Different project/domain/version in the settings must not leak into the reference's id.
    second_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    sdk_task = get_serializable(second_settings, ref_t1)
    assert sdk_task.has_registered
    for field_name, field_value in expected_id.items():
        assert getattr(sdk_task.id, field_name) == field_value
def test_serialization_images():
    """Container image templates resolve against the images config file and FLYTE_INTERNAL_IMAGE."""

    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    @task
    def t3():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    # Remember the caller's value so this test does not leak FLYTE_INTERNAL_IMAGE into later tests.
    previous_image = os.environ.get("FLYTE_INTERNAL_IMAGE")
    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    try:
        # NOTE(review): set_flyte_config_file presumably changes process-wide config too; not undone here.
        set_flyte_config_file(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
        )
        rs = context_manager.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=get_image_config(),
        )

        t1_ser = get_serializable(rs, t1)
        assert t1_ser.container.image == "docker.io/xyz:version"
        t1_ser.to_flyte_idl()

        t2_ser = get_serializable(rs, t2)
        assert t2_ser.container.image == "docker.io/default:version"

        t3_ser = get_serializable(rs, t3)
        assert t3_ser.container.image == "docker.io/default:version"

        t4_ser = get_serializable(rs, t4)
        assert t4_ser.container.image == "docker.io/org/myimage:latest"

        t5_ser = get_serializable(rs, t5)
        assert t5_ser.container.image == "docker.io/org/myimage:version"
    finally:
        if previous_image is None:
            os.environ.pop("FLYTE_INTERNAL_IMAGE", None)
        else:
            os.environ["FLYTE_INTERNAL_IMAGE"] = previous_image
def test_serialization():
    """Raw ContainerTasks wired into a workflow serialize with their own images."""
    # NOTE(review): another test named test_serialization appears in this view; if both live in the same
    # module, pytest only collects the later definition.
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    # Named sum_task rather than `sum` so the builtin is not shadowed inside this test.
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf = get_serializable(serialization_settings, raw_container_wf)
    assert wf is not None
    # Two square nodes plus the sum node.
    assert len(wf.nodes) == 3

    sqn = get_serializable(serialization_settings, square)
    assert sqn.container.image == "alpine"
    sumn = get_serializable(serialization_settings, sum_task)
    assert sumn.container.image == "alpine"
def test_dynamic_pod_task():
    """A @dynamic task with a Pod config is a PodFunctionTask and compiles its loop into nodes."""
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(task_config=dynamic_pod)
    def dynamic_pod_task(a: int) -> List[int]:
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    pod_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    custom = dynamic_pod_task.get_custom(pod_settings)
    assert len(custom["podSpec"]["containers"]) == 2

    compile_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    outer_ctx = context_manager.FlyteContext.current_context()
    with outer_ctx.new_serialization_settings(serialization_settings=compile_settings) as ser_ctx:
        with ser_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(
                exec_ctx, dynamic_pod_task._task_function, a=5
            )
            # One t1 node per loop iteration.
            assert len(dynamic_job_spec._nodes) == 5
def test_lp_with_output():
    """A reference launch plan's outputs can feed a downstream task, locally (mocked) and when serialized."""
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        text, number = t1()
        flag, count = ref_lp(a=text, b=number)
        return t2(q=flag, r=count)

    @patch(ref_lp)
    def inner_test(ref_mock):
        ref_mock.return_value = (False, 30)
        assert wf1() == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf1)
    lp_ref = sdk_wf.nodes[1].workflow_node.launchplan_ref
    assert lp_ref.project == "proj"
    assert lp_ref.name == "app.other.flyte_entity"
def test_lps(resource_type):
    """Reference entities of any resource type must be mocked to run and serialize as references.

    :param resource_type: the ``_identifier_model.ResourceType`` under test (presumably supplied by
        ``pytest.mark.parametrize`` outside this view).
    """
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    # `match` checks str() of the raised exception itself. Formatting the ExceptionInfo wrapper goes through
    # a length-limited repr that can truncate long messages.
    with pytest.raises(Exception, match="You must mock this out"):
        ref_entity()

    with ctx.new_compilation_context() as ctx:
        with pytest.raises(Exception, match="Input was not specified"):
            ref_entity()

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1

    # Each resource type is referenced through a different node field.
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert sdk_wf.nodes[0].workflow_node.launchplan_ref.project == "proj"
        assert sdk_wf.nodes[0].workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        assert sdk_wf.nodes[0].workflow_node.sub_workflow_ref.project == "proj"
        assert sdk_wf.nodes[0].workflow_node.sub_workflow_ref.name == "app.other.flyte_entity"
    else:
        assert sdk_wf.nodes[0].task_node.reference_id.project == "proj"
        assert sdk_wf.nodes[0].task_node.reference_id.name == "app.other.flyte_entity"
def test_environment():
    """Task-level environment overrides and extends the serialization-level env."""

    @task(environment={"FOO": "foofoo", "BAZ": "baz"})
    def t1(a: int) -> str:
        return "now it's " + str(a + 2)

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={"FOO": "foo", "BAR": "bar"},
    )
    compile_ctx = context_manager.FlyteContext.current_context().new_compilation_context()
    with compile_ctx:
        sdk_task = get_serializable(serialization_settings, t1)
        # FOO comes from the task, BAR from the settings, BAZ only from the task.
        assert sdk_task.container.env == {"FOO": "foofoo", "BAR": "bar", "BAZ": "baz"}
def test_serialization():
    """A Hive query task with schema input/output serializes its query and wires the workflow output."""
    hive_query = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template=""" set engine=tez; insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet -- will be unique per retry select * from blah where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}' """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_query(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )

    sdk_task = get_serializable(serialization_settings, hive_query)
    query_text = sdk_task.custom["query"]["query"]
    assert "{{ .rawOutputDataPrefix" in query_text
    assert "insert overwrite directory" in query_text
    assert len(sdk_task.interface.inputs) == 2
    assert len(sdk_task.interface.outputs) == 1

    sdk_wf = get_serializable(serialization_settings, my_wf)
    assert sdk_wf.interface.outputs["o0"].type.schema is not None
    first_output = sdk_wf.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
def test_lp_serialize():
    """Launch plans serialize their default inputs, and default parameters round-trip through IDL."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        bumped = a + 2
        return bumped, "world-" + str(bumped)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    serialization_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )
    plain_lp = get_serializable(serialization_settings, lp)
    assert len(plain_lp.default_inputs.parameters) == 0
    assert len(plain_lp.fixed_inputs.literals) == 0

    defaults_lp = get_serializable(serialization_settings, lp_with_defaults)
    assert len(defaults_lp.default_inputs.parameters) == 1
    assert len(defaults_lp.fixed_inputs.literals) == 0

    # The parameter's oneof must survive an IDL round trip. With a default present, the `required` field has to be
    # None rather than False, which is easy to get wrong with booleans.
    round_tripped = Parameter.from_flyte_idl(defaults_lp.default_inputs.parameters["a"].to_flyte_idl())
    assert round_tripped.default is not None
def test_wf1_with_dynamic():
    """A workflow calling a @dynamic task runs locally and the dynamic task compiles one node per item."""

    @task
    def t1(a: int) -> str:
        return "world-" + str(a + 2)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int, b: str) -> (str, typing.List[str]):
        doubled = t2(a=b, b=b)
        greetings = my_subwf(a=a)
        return doubled, greetings

    count = 5
    result = my_wf(a=count, b="hello ")
    assert result == ("hello hello ", ["world-" + str(i) for i in range(2, count + 2)])

    compile_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    outer_ctx = context_manager.FlyteContext.current_context()
    with outer_ctx.new_serialization_settings(serialization_settings=compile_settings) as ser_ctx:
        with ser_ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as exec_ctx:
            dynamic_job_spec = my_subwf.compile_into_workflow(exec_ctx, my_subwf._task_function, a=5)
            assert len(dynamic_job_spec._nodes) == 5
def test_resources():
    """Task requests/limits serialize into resource entries, and unset limits are empty."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        return "now it's " + str(a + 2)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        return "now it's " + str(a + 200)

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    entry = _resource_models.ResourceEntry
    names = _resource_models.ResourceName
    with context_manager.FlyteContext.current_context().new_compilation_context():
        t1_resources = get_serializable(serialization_settings, t1).container.resources
        assert t1_resources.requests == [entry(names.CPU, "1")]
        assert t1_resources.limits == [entry(names.CPU, "2"), entry(names.MEMORY, "400M")]

        t2_resources = get_serializable(serialization_settings, t2).container.resources
        assert t2_resources.requests == [entry(names.CPU, "3")]
        assert t2_resources.limits == []
def test_dynamic_conditional():
    """A recursive dynamic merge sort, branching through a conditional workflow, compiles to the expected tasks."""

    @task
    def split(in1: typing.List[int]) -> (typing.List[int], typing.List[int], int):
        # Cut both halves at the same index so no element is dropped. Use integer division so the third output
        # really is an int, matching the declared output type.
        half = len(in1) // 2
        return in1[:half], in1[half:], half

    # One sample implementation for merging. In a more real world example, this might merge file streams and only
    # load chunks into the memory.
    @task
    def merge(x: typing.List[int], y: typing.List[int]) -> typing.List[int]:
        n1 = len(x)
        n2 = len(y)
        # Plain list literal: `list[int]()` raises TypeError before Python 3.9.
        result = []
        i = 0
        j = 0

        # Traverse both arrays, always taking the smaller current element. Ties go to the second array.
        while i < n1 and j < n2:
            if x[i] < y[j]:
                result.append(x[i])
                i += 1
            else:
                result.append(y[j])
                j += 1

        # At most one of these still has elements; append them in order.
        result.extend(x[i:])
        result.extend(y[j:])
        return result

    # This runs the sorting completely locally. It's faster and more efficient to do so if the entire list fits in
    # memory.
    @task
    def merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @task
    def also_merge_sort_locally(in1: typing.List[int]) -> typing.List[int]:
        return sorted(in1)

    @dynamic
    def merge_sort_remotely(in1: typing.List[int]) -> typing.List[int]:
        x, y, new_count = split(in1=in1)
        sorted_x = merge_sort(in1=x, count=new_count)
        sorted_y = merge_sort(in1=y, count=new_count)
        return merge(x=sorted_x, y=sorted_y)

    @workflow
    def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
        return (
            conditional("terminal_case")
            .if_(count < 500)
            .then(merge_sort_locally(in1=in1))
            .elif_(count < 1000)
            .then(also_merge_sort_locally(in1=in1))
            .else_()
            .then(merge_sort_remotely(in1=in1))
        )

    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=context_manager.SerializationSettings(
            project="test_proj",
            domain="test_domain",
            version="abc",
            image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
            env={},
        )
    ) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            dynamic_job_spec = merge_sort_remotely.compile_into_workflow(
                ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
            )
            assert len(dynamic_job_spec.tasks) == 5
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Discover every registerable entity under ``pkgs`` and write each one as a numbered protobuf file.

    Registration has to go through Admin's endpoints, which take these objects:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    Calling .to_flyte_idl() on the discovered entities would instead produce:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.core.workflow_pb2.WorkflowTemplate
        flyteidl.core.tasks_pb2.TaskTemplate

    The serialize function therefore has special logic that translates Workflows and Tasks.

    :param pkgs: packages passed to ``iterate_registerable_entities_in_order``.
    :param local_source_root: forwarded to ``iterate_registerable_entities_in_order``.
    :param folder: directory to write files into. When falsy, the bare file names are used.
    :param mode: serialization mode for Python tasks. ``None`` means ``SerializationMode.DEFAULT``, and
        ``SerializationMode.FAST`` passes ``fast=True``.
    :param image: image name. Passed to ``get_image_config`` and stored in the settings' ``env`` under the
        IMAGE config key.
    :param config_path: stored in the settings' ``env`` under the CONFIGURATION_PATH key. Falls back to that
        setting's current value when falsy.
    :param flytekit_virtualenv_root: joined with the default relative entrypoint location to form the entrypoint
        path.
    :raises AssertionError: if a Python task is found while ``mode`` is neither DEFAULT nor FAST.
    """

    def _in_folder(file_name):
        # Prefix the output folder only when one was requested.
        return _os.path.join(folder, file_name) if folder else file_name

    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path or _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )

    current = flyte_context.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=settings) as ctx:
        loaded_entities = []

        # module = python module, attr_name = a name from dir(module), obj = the object (e.g. SdkWorkflow)
        for module, attr_name, obj in iterate_registerable_entities_in_order(
            pkgs, local_source_root=local_source_root
        ):
            name = _utils.fqdn(module.__name__, attr_name, entity_type=obj.resource_type)
            _logging.debug(
                "Found module {}\n K: {} Instantiated in {}".format(module, attr_name, obj._instantiated_in)
            )
            obj._id = _identifier.Identifier(
                obj.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            loaded_entities.append(obj)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=module, name=attr_name, o=obj))

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode or SerializationMode.DEFAULT

        # TODO: Clean up the copy(). It exists because get_default_launch_plan may create a LaunchPlan object,
        # which is added to the FlyteEntities.entities list while we are iterating over it.
        # TODO: Add a reachability check. The constructors always add entities, so odd cases arise: a workflow
        # created inside another workflow would also be registered, and so would inner tasks that nothing can
        # reach (generated workflows might be fine to keep). A user may also import dir_b.workflows from
        # dir_a.workflows while the workflow packages only list dir_a.
        for entity in flyte_context.FlyteEntities.entities.copy():
            if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                continue

            if not isinstance(entity, PythonTask):
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.DEFAULT:
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
            else:
                raise AssertionError(f"Unrecognized serialization mode: {mode}")
            loaded_entities.append(serializable)

            if isinstance(entity, Workflow):
                # Every workflow also gets its default launch plan serialized.
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                loaded_entities.append(get_serializable(ctx.serialization_settings, default_lp))

        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for index, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue

            serialized = entity.serialize()
            fname_index = str(index).zfill(zero_padded_length)
            fname = f"{fname_index}_{entity.id.name}.pb"
            click.echo(f" Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n {fname}")
            _write_proto_to_file(serialized, _in_folder(fname))

            # Not everything serialized necessarily carries an identifier (the TaskTemplate does, others may not).
            # Write an explicit identifier file recording the choices made here (project/domain, etc.) so that
            # registration cannot pick a different project and create mismatches with internal ids.
            identifier_fname = f"{fname_index}_{entity._id.name}.identifier.pb"
            _write_proto_to_file(entity._id.to_flyte_idl(), _in_folder(identifier_fname))
def test_normal_task():
    """create_node outputs and explicit node ordering (runs_before / >>) serialize as expected."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @workflow
    def my_wf(a: str) -> str:
        t1_node = create_node(t1, a=a)
        return t1_node.o0

    assert my_wf(a="hello") == "hello world"

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(serialization_settings, my_wf)
    assert len(sdk_wf.nodes) == 1
    assert len(sdk_wf.outputs) == 1

    @task
    def t2():
        ...

    @task
    def t3():
        ...

    @workflow
    def empty_wf():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node.runs_before(t2_node)

    # Test that VoidPromises can handle runs_before
    empty_wf()

    @workflow
    def empty_wf2():
        t2_node = create_node(t2)
        t3_node = create_node(t3)
        t3_node >> t2_node

    # Both spellings of the ordering must make n0 (t2) depend on n1 (t3).
    for ordered_wf in (empty_wf, empty_wf2):
        sdk_wf = get_serializable(serialization_settings, ordered_wf)
        assert sdk_wf.nodes[0].upstream_node_ids[0] == "n1"
        assert sdk_wf.nodes[0].id == "n0"
import typing from flytekit import ContainerTask from flytekit.annotated import context_manager from flytekit.annotated.base_task import kwtypes from flytekit.annotated.context_manager import Image, ImageConfig from flytekit.annotated.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.annotated.task import ReferenceTask, task from flytekit.annotated.workflow import ReferenceWorkflow, workflow from flytekit.common.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = context_manager.SerializationSettings( project="project", domain="domain", version="version", env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) def test_references(): rlp = ReferenceLaunchPlan("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes()) sdk_lp = get_serializable(serialization_settings, rlp) assert sdk_lp.has_registered