def test_imperative_call_from_normal():
    @task
    def t1(a: str) -> str:
        return a + " world"

    wb = ImperativeWorkflow(name="my.workflow")
    wb.add_workflow_input("in1", str)
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    assert wb(in1="hello") == "hello world"

    @workflow
    def my_functional_wf(a: str) -> str:
        x = wb(in1=a)
        return x

    assert my_functional_wf(a="hello") == "hello world"

    # Create launch plan from wf
    lp = LaunchPlan.create("test_wb_2", wb, fixed_inputs={"in1": "hello"})

    @workflow
    def my_functional_wf_lp() -> str:
        x = lp()
        return x

    assert my_functional_wf_lp() == "hello world"
def test_imperative():
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    wb = ImperativeWorkflow(name="my.workflow")
    wb.add_workflow_input("in1", str)
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    assert wb(in1="hello") == "hello world"

    srz_wf = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(srz_wf.nodes) == 2
    assert srz_wf.nodes[0].task_node is not None
    assert len(srz_wf.outputs) == 1
    assert srz_wf.outputs[0].var == "from_n0t1"
    assert len(srz_wf.interface.inputs) == 1
    assert len(srz_wf.interface.outputs) == 1

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    srz_lp = get_serializable(OrderedDict(), serialization_settings, lp)
    assert srz_lp.workflow_id.name == "my.workflow"
def test_basics():
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    sdk_wf = get_serializable(OrderedDict(), serialization_settings, my_wf, False)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 2
    assert len(sdk_wf.nodes) == 2

    # Gets cached the first time around so it's not actually fast.
    sdk_task = get_serializable(OrderedDict(), serialization_settings, t1, True)
    assert "pyflyte-execute" in sdk_task.container.args

    lp = LaunchPlan.create(
        "testlp",
        my_wf,
    )
    sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp)
    assert sdk_lp.id.name == "testlp"
def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.basic_workflow import my_wf

    # make sure the task name is the same as the name used during registration
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=f"v{VERSION}", wait=True)
    assert execution.outputs["o0"] == 12
    assert execution.outputs["o1"] == "xyzworld"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    execution = remote.execute(launch_plan, inputs={"a": 14, "b": "foobar"}, version=f"v{VERSION}", wait=True)
    assert execution.outputs["o0"] == 16
    assert execution.outputs["o1"] == "foobarworld"

    flyte_workflow_execution = remote.fetch_execution(name=execution.id.name)
    assert execution.inputs == flyte_workflow_execution.inputs
    assert execution.outputs == flyte_workflow_execution.outputs
def test_launch_plan_with_fixed_input():
    @task
    def greet(day_of_week: str, number: int, am: bool) -> str:
        greeting = "Have a great " + day_of_week + " "
        greeting += "morning" if am else "evening"
        return greeting + "!" * number

    @workflow
    def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
        return greet(day_of_week=day_of_week, number=number, am=am)

    morning_greeting = LaunchPlan.create(
        "morning_greeting",
        go_greet,
        fixed_inputs={"am": True},
        default_inputs={"number": 1},
    )

    @workflow
    def morning_greeter_caller(day_of_week: str) -> str:
        greeting = morning_greeting(day_of_week=day_of_week)
        return greeting

    settings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    task_spec = get_serializable(OrderedDict(), settings, morning_greeter_caller)
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1
    assert len(task_spec.template.nodes) == 1
    assert len(task_spec.template.nodes[0].inputs) == 2
def test_execute_python_workflow_dict_of_string_to_string(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.dict_str_wf import my_wf

    # make sure the task name is the same as the name used during registration
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    d: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"}
    execution = remote.execute(my_wf, inputs={"d": d}, version=f"v{VERSION}", wait=True)
    assert json.loads(execution.outputs["o0"]) == {"k1": "v1", "k2": "v2"}

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    execution = remote.execute(launch_plan, inputs={"d": {"k2": "vvvv", "abc": "def"}}, version=f"v{VERSION}", wait=True)
    assert json.loads(execution.outputs["o0"]) == {"k2": "vvvv", "abc": "def"}
def test_call_normal():
    @task
    def t1(a: int) -> (int, str):
        return a + 2, "world"

    @workflow
    def my_functional_wf(a: int) -> (int, str):
        return t1(a=a)

    my_functional_lp = LaunchPlan.create("my_functional_wf.lp0", my_functional_wf, default_inputs={"a": 3})

    wb = ImperativeWorkflow(name="imperio")
    node = wb.add_entity(my_functional_wf, a=3)
    wb.add_workflow_output("from_n0_1", node.outputs["o0"])
    wb.add_workflow_output("from_n0_2", node.outputs["o1"])
    assert wb() == (5, "world")

    wb_lp = ImperativeWorkflow(name="imperio")
    node = wb_lp.add_entity(my_functional_lp)
    wb_lp.add_workflow_output("from_n0_1", node.outputs["o0"])
    wb_lp.add_workflow_output("from_n0_2", node.outputs["o1"])
    assert wb_lp() == (5, "world")
def test_basics():
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf, False)
    assert len(wf_spec.template.interface.inputs) == 2
    assert len(wf_spec.template.interface.outputs) == 2
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # Gets cached the first time around so it's not actually fast.
    task_spec = get_serializable(OrderedDict(), serialization_settings, t1, True)
    assert "pyflyte-execute" in task_spec.template.container.args

    lp = LaunchPlan.create(
        "testlp",
        my_wf,
    )
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.id.name == "testlp"
def test_imperative():
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    wb = ImperativeWorkflow(name="my.workflow")
    wb.add_workflow_input("in1", str)
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    assert wb(in1="hello") == "hello world"

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my.workflow"

    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(wb2_spec.sub_workflows) == 1

    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
def test_lp_from_ref_wf():
    @reference_workflow(project="project", domain="domain", name="name", version="version")
    def ref_wf1(p1: str, p2: str) -> None:
        ...

    lp = LaunchPlan.create("reference-wf-12345", ref_wf1, fixed_inputs={"p1": "p1-value", "p2": "p2-value"})
    assert lp.name == "reference-wf-12345"
    assert lp.workflow == ref_wf1
    assert lp.workflow.id.name == "name"
    assert lp.workflow.id.project == "project"
    assert lp.workflow.id.domain == "domain"
    assert lp.workflow.id.version == "version"
def test_file_handling_remote_default_wf_input():
    SAMPLE_DATA = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"

    @task
    def t1(fname: os.PathLike) -> int:
        with open(fname, "r") as fh:
            x = len(fh.readlines())
        return x

    @workflow
    def my_wf(fname: os.PathLike = SAMPLE_DATA) -> int:
        length = t1(fname=fname)
        return length

    assert my_wf.python_interface.inputs_with_defaults["fname"][1] == SAMPLE_DATA
    sample_lp = LaunchPlan.create("test_launch_plan", my_wf)
    assert sample_lp.parameters.parameters["fname"].default.scalar.blob.uri == SAMPLE_DATA
def test_schedule_with_lp():
    @task
    def double(a: int) -> int:
        return a * 2

    @workflow
    def quadruple(a: int) -> int:
        b = double(a=a)
        c = double(a=b)
        return c

    lp = LaunchPlan.create(
        "schedule_test",
        quadruple,
        schedule=FixedRate(_datetime.timedelta(hours=12), "kickoff_input"),
    )
    assert lp.schedule == _schedule_models.Schedule(
        "kickoff_input",
        rate=_schedule_models.Schedule.FixedRate(12, _schedule_models.Schedule.FixedRateUnit.HOUR),
    )
def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.list_float_wf import my_wf

    # make sure the task name is the same as the name used during registration
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    xs: typing.List[float] = [42.24, 999.1, 0.0001]
    execution = remote.execute(my_wf, inputs={"xs": xs}, version=f"v{VERSION}", wait=True)
    assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    execution = remote.execute(launch_plan, inputs={"xs": [-1.1, 0.12345]}, version=f"v{VERSION}", wait=True)
    assert execution.outputs["o0"] == "[-1.1, 0.12345]"
def test_with_launch_plan():
    @task
    def double(a: int) -> int:
        return a * 2

    @workflow
    def quadruple(a: int) -> int:
        b = double(a=a)
        c = double(a=b)
        return c

    lp = LaunchPlan.create(
        "notif_test",
        quadruple,
        notifications=[
            notification.Email(phases=[_workflow_execution_succeeded], recipients_email=["*****@*****.**"])
        ],
    )
    assert lp.notifications == [
        notification.Email(phases=[_workflow_execution_succeeded], recipients_email=["*****@*****.**"])
    ]
def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List:
    """
    Returns all entities that can be serialized and should be sent over to the Flyte backend. Entities that
    should not be registered with Admin are filtered out.
    """
    new_api_serializable_entities = OrderedDict()
    # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
    #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
    for entity in flyte_context.FlyteEntities.entities.copy():
        if isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
            get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)

            if isinstance(entity, WorkflowBase):
                lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp)

    new_api_model_values = list(new_api_serializable_entities.values())
    entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values))
    return [v.to_flyte_idl() for v in entities_to_be_serialized]
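# A minimal usage sketch for get_registrable_entities (an illustrative assumption,
# not part of the module): it presumes the current context already carries
# SerializationSettings, e.g. set up the way serialize_all below does.
#
#     ctx = flyte_context.FlyteContextManager.current_context()
#     idl_entities = get_registrable_entities(ctx)
#     for e in idl_entities:
#         print(type(e))  # TaskSpec / WorkflowSpec / LaunchPlan protobuf messages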
def test_calling_lp():
    sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
    serialized = OrderedDict()
    lp_model = get_serializable(serialized, serialization_settings, sub_wf_lp)
    task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
    # Grab the first (and, in this test, only) workflow spec.
    for wf_id, spec in wf_specs.items():
        break

    remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
    # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
    # part of the IDL object but is something FlyteRemote does
    remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)
    serialized = OrderedDict()

    @workflow
    def wf2(a: int) -> typing.Tuple[int, str]:
        return remote_lp(a=a, b="hello")

    wf_spec = get_serializable(serialized, serialization_settings, wf2)
    print(wf_spec.template.nodes[0].workflow_node.launchplan_ref)
    assert wf_spec.template.nodes[0].workflow_node.launchplan_ref == lp_model.id
def test_imperative():
    # Re import with alias
    from flytekit.core.workflow import ImperativeWorkflow as Workflow  # noqa

    # docs_tasks_start
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # docs_tasks_end

    # docs_start
    # Create the workflow with a name. This needs to be unique within the project and takes the place of the
    # function name that's used for regular decorated function-based workflows.
    wb = Workflow(name="my_workflow")
    # Adds a top level input to the workflow. This is like an input to a workflow function.
    wb.add_workflow_input("in1", str)
    # Call your tasks.
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    # This is analogous to a return statement
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    # docs_end

    assert wb(in1="hello") == "hello world"

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # docs_equivalent_start
    nt = typing.NamedTuple("wf_output", from_n0t1=str)

    @workflow
    def my_workflow(in1: str) -> nt:
        x = t1(a=in1)
        t2()
        return nt(x)

    # docs_equivalent_end

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my_workflow"

    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my_workflow"
    assert len(wb2_spec.sub_workflows) == 1

    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    In order to register, we have to comply with Admin's endpoints. Those endpoints take the following objects:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.core.workflow_pb2.WorkflowTemplate
        flyteidl.core.tasks_pb2.TaskTemplate

    For workflows and tasks, therefore, there is special logic in the serialize function that translates these
    objects.

    :param list[Text] pkgs:
    :param Text folder:
    :return:
    """
    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path
        if config_path
        else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )
    with flyte_context.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=serialization_settings
    ) as ctx:
        loaded_entities = []
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            loaded_entities.append(o)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=m, name=k, o=o))

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode if mode else SerializationMode.DEFAULT

        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to
            #  be registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know
            #  how to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                if isinstance(entity, PythonTask):
                    if mode == SerializationMode.DEFAULT:
                        serializable = get_serializable(ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
                    else:
                        raise AssertionError(f"Unrecognized serialization mode: {mode}")
                else:
                    serializable = get_serializable(ctx.serialization_settings, entity)
                loaded_entities.append(serializable)

                if isinstance(entity, Workflow):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    launch_plan = get_serializable(ctx.serialization_settings, lp)
                    loaded_entities.append(launch_plan)

        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized = entity.serialize()
            fname_index = str(i).zfill(zero_padded_length)
            fname = "{}_{}_{}.pb".format(fname_index, entity.id.name, entity.id.resource_type)
            click.echo(f"  Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n    {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(serialized, fname)

        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with
    the entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files.
    :param mode: Regular vs fast.
    :param image: The fully qualified and versioned default image to use.
    :param config_path: Path to the config file, if any, to be used during serialization.
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    """
    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path
        if config_path
        else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )
    ctx = flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            old_style_entities.append(o)

        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode if mode else SerializationMode.DEFAULT

        new_api_serializable_entities = OrderedDict()
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to
            #  be registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know
            #  how to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
                if isinstance(entity, PythonTask):
                    if mode == SerializationMode.DEFAULT:
                        get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, fast=True)
                    else:
                        raise AssertionError(f"Unrecognized serialization mode: {mode}")
                else:
                    get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)

                if isinstance(entity, WorkflowBase):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp)

        new_api_model_values = list(new_api_serializable_entities.values())
        new_api_model_values = list(filter(_should_register_with_admin, new_api_model_values))
        new_api_model_values = [v.to_flyte_idl() for v in new_api_model_values]

        loaded_entities = serialized_old_style_entities + new_api_model_values
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            fname_index = str(i).zfill(zero_padded_length)
            if isinstance(entity, _idl_admin_TaskSpec):
                fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_WorkflowSpec):
                fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_LaunchPlan):
                fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
            else:
                raise Exception(f"Bad format {type(entity)}")
            click.echo(f"  Writing to file: {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(entity, fname)

        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")
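# Example invocation of serialize_all (a hypothetical sketch; the package name,
# output folder, image tag, and virtualenv path below are illustrative
# assumptions, not values from this repo):
#
#     serialize_all(
#         pkgs=["myapp.workflows"],
#         local_source_root=".",
#         folder="_pb_output",
#         mode=SerializationMode.DEFAULT,
#         image="ghcr.io/example/myapp:v1",
#         flytekit_virtualenv_root="/opt/venv",
#     )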