예제 #1
0
def test_imperative_call_from_normal():
    """An imperative workflow can be run directly, nested in a @workflow, and wrapped by a launch plan."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    expected = "hello world"

    imperative_wf = ImperativeWorkflow(name="my.workflow")
    imperative_wf.add_workflow_input("in1", str)
    t1_node = imperative_wf.add_entity(t1, a=imperative_wf.inputs["in1"])
    imperative_wf.add_workflow_output("from_n0t1", t1_node.outputs["o0"])

    # Direct local execution.
    assert imperative_wf(in1="hello") == expected

    @workflow
    def my_functional_wf(a: str) -> str:
        return imperative_wf(in1=a)

    # Nested inside a decorated workflow.
    assert my_functional_wf(a="hello") == expected

    # Launch plan with its only input pinned, so it can be called with no arguments.
    pinned_lp = LaunchPlan.create("test_wb_2", imperative_wf, fixed_inputs={"in1": "hello"})

    @workflow
    def my_functional_wf_lp() -> str:
        return pinned_lp()

    assert my_functional_wf_lp() == expected
예제 #2
0
def test_imperative():
    """Build an imperative workflow, run it locally, then serialize it and a launch plan built from it."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    builder = ImperativeWorkflow(name="my.workflow")
    builder.add_workflow_input("in1", str)
    first_node = builder.add_entity(t1, a=builder.inputs["in1"])
    # t2 takes no inputs and returns nothing; it is only added as an extra node.
    builder.add_entity(t2)
    builder.add_workflow_output("from_n0t1", first_node.outputs["o0"])

    assert builder(in1="hello") == "hello world"

    wf_model = get_serializable(OrderedDict(), serialization_settings, builder)
    nodes = wf_model.nodes
    outputs = wf_model.outputs
    iface = wf_model.interface
    assert len(nodes) == 2
    assert nodes[0].task_node is not None
    assert len(outputs) == 1
    assert outputs[0].var == "from_n0t1"
    assert len(iface.inputs) == 1
    assert len(iface.outputs) == 1

    # A launch plan derived from the imperative workflow serializes as well.
    launch_plan = LaunchPlan.create("test_wb", builder)
    lp_model = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert lp_model.workflow_id.name == "my.workflow"
예제 #3
0
def test_basics():
    """Serialize a two-task workflow, one of its tasks, and a launch plan built from the workflow."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        num, text = t1(a=a)
        combined = t2(a=text, b=b)
        return num, combined

    wf_model = get_serializable(OrderedDict(), serialization_settings, my_wf, False)
    iface = wf_model.interface
    assert len(iface.inputs) == 2
    assert len(iface.outputs) == 2
    assert len(wf_model.nodes) == 2

    # Gets cached the first time around so it's not actually fast.
    task_model = get_serializable(OrderedDict(), serialization_settings, t1, True)
    assert "pyflyte-execute" in task_model.container.args

    launch_plan = LaunchPlan.create("testlp", my_wf)
    lp_model = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert lp_model.id.name == "testlp"
예제 #4
0
def test_execute_python_workflow_and_launch_plan(flyteclient,
                                                 flyte_workflows_register,
                                                 flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.basic_workflow import my_wf

    # Strip the package prefix so the workflow name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    wf_execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=version, wait=True)
    assert wf_execution.outputs["o0"] == 12
    assert wf_execution.outputs["o1"] == "xyzworld"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    lp_execution = remote.execute(launch_plan, inputs={"a": 14, "b": "foobar"}, version=version, wait=True)
    assert lp_execution.outputs["o0"] == 16
    assert lp_execution.outputs["o1"] == "foobarworld"

    # Fetching the launch-plan execution back by name must give the same inputs and outputs.
    fetched = remote.fetch_execution(name=lp_execution.id.name)
    assert lp_execution.inputs == fetched.inputs
    assert lp_execution.outputs == fetched.outputs
예제 #5
0
def test_launch_plan_with_fixed_input():
    """A launch plan with fixed and default inputs serializes as a single node inside a calling workflow."""

    @task
    def greet(day_of_week: str, number: int, am: bool) -> str:
        time_of_day = "morning" if am else "evening"
        return "Have a great " + day_of_week + " " + time_of_day + "!" * number

    @workflow
    def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
        return greet(day_of_week=day_of_week, number=number, am=am)

    # "am" is supplied as a fixed input and "number" as a default input.
    morning_greeting = LaunchPlan.create(
        "morning_greeting",
        go_greet,
        fixed_inputs={"am": True},
        default_inputs={"number": 1},
    )

    @workflow
    def morning_greeter_caller(day_of_week: str) -> str:
        return morning_greeting(day_of_week=day_of_week)

    fast_settings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    wf_spec = get_serializable(OrderedDict(), fast_settings, morning_greeter_caller)
    template = wf_spec.template
    assert len(template.interface.inputs) == 1
    assert len(template.interface.outputs) == 1
    assert len(template.nodes) == 1
    # The single launch-plan node is expected to carry two input bindings.
    assert len(template.nodes[0].inputs) == 2
예제 #6
0
def test_execute_python_workflow_dict_of_string_to_string(
        flyteclient, flyte_workflows_register, flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.dict_str_wf import my_wf

    # Strip the package prefix so the workflow name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    wf_input: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"}
    wf_execution = remote.execute(my_wf, inputs={"d": wf_input}, version=version, wait=True)
    assert json.loads(wf_execution.outputs["o0"]) == {"k1": "v1", "k2": "v2"}

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    lp_input = {"k2": "vvvv", "abc": "def"}
    lp_execution = remote.execute(launch_plan, inputs={"d": lp_input}, version=version, wait=True)
    assert json.loads(lp_execution.outputs["o0"]) == {"k2": "vvvv", "abc": "def"}
예제 #7
0
def test_call_normal():
    """A decorated workflow and a launch plan over it can each be used as a node of an imperative workflow."""

    @task
    def t1(a: int) -> (int, str):
        return a + 2, "world"

    @workflow
    def my_functional_wf(a: int) -> (int, str):
        return t1(a=a)

    my_functional_lp = LaunchPlan.create("my_functional_wf.lp0",
                                         my_functional_wf,
                                         default_inputs={"a": 3})

    def expose_outputs(imperative, node):
        # Re-export both outputs of the wrapped entity as top-level workflow outputs.
        imperative.add_workflow_output("from_n0_1", node.outputs["o0"])
        imperative.add_workflow_output("from_n0_2", node.outputs["o1"])

    # Workflow node whose input is bound to a literal value.
    wf_wrapper = ImperativeWorkflow(name="imperio")
    expose_outputs(wf_wrapper, wf_wrapper.add_entity(my_functional_wf, a=3))
    assert wf_wrapper() == (5, "world")

    # Launch-plan node that relies on the launch plan's default input.
    lp_wrapper = ImperativeWorkflow(name="imperio")
    expose_outputs(lp_wrapper, lp_wrapper.add_entity(my_functional_lp))
    assert lp_wrapper() == (5, "world")
예제 #8
0
def test_basics():
    """Serialize a two-task workflow (as a spec with a template), one task, and a launch plan."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        num, text = t1(a=a)
        combined = t2(a=text, b=b)
        return num, combined

    wf_template = get_serializable(OrderedDict(), serialization_settings, my_wf, False).template
    assert len(wf_template.interface.inputs) == 2
    assert len(wf_template.interface.outputs) == 2
    assert len(wf_template.nodes) == 2
    assert wf_template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # Gets cached the first time around so it's not actually fast.
    task_template = get_serializable(OrderedDict(), serialization_settings, t1, True).template
    assert "pyflyte-execute" in task_template.container.args

    launch_plan = LaunchPlan.create("testlp", my_wf)
    lp_model = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert lp_model.id.name == "testlp"
예제 #9
0
def test_imperative():
    """Serialize an imperative workflow, a launch plan over it, and two parents that embed it."""

    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    def assert_single_node_parent(spec):
        # Checks shared by both parents: one node, one simple-typed input, one simple-typed output.
        tmpl = spec.template
        assert len(tmpl.nodes) == 1
        assert len(tmpl.interface.inputs) == 1
        assert tmpl.interface.inputs["p_in1"].type.simple is not None
        assert len(tmpl.interface.outputs) == 1
        assert tmpl.interface.outputs["parent_wf_output"].type.simple is not None

    child = ImperativeWorkflow(name="my.workflow")
    child.add_workflow_input("in1", str)
    t1_node = child.add_entity(t1, a=child.inputs["in1"])
    child.add_entity(t2)
    child.add_workflow_output("from_n0t1", t1_node.outputs["o0"])

    assert child(in1="hello") == "hello world"

    child_tmpl = get_serializable(OrderedDict(), serialization_settings, child).template
    assert len(child_tmpl.nodes) == 2
    assert child_tmpl.nodes[0].task_node is not None
    assert len(child_tmpl.outputs) == 1
    assert child_tmpl.outputs[0].var == "from_n0t1"
    assert len(child_tmpl.interface.inputs) == 1
    assert len(child_tmpl.interface.outputs) == 1

    # A launch plan over the child serializes and refers back to the child workflow.
    lp = LaunchPlan.create("test_wb", child)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my.workflow"

    # Parent that embeds the child as a sub-workflow.
    subwf_parent = ImperativeWorkflow(name="parent.imperative")
    p_in1 = subwf_parent.add_workflow_input("p_in1", str)
    subwf_node = subwf_parent.add_subwf(child, in1=p_in1)
    subwf_parent.add_workflow_output("parent_wf_output", subwf_node.from_n0t1, str)
    subwf_parent_spec = get_serializable(OrderedDict(), serialization_settings, subwf_parent)
    assert_single_node_parent(subwf_parent_spec)
    assert subwf_parent_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my.workflow"
    assert len(subwf_parent_spec.sub_workflows) == 1

    # Parent that calls the child through its launch plan.
    lp_parent = ImperativeWorkflow(name="parent.imperative")
    p_in1 = lp_parent.add_workflow_input("p_in1", str)
    lp_node = lp_parent.add_launch_plan(lp, in1=p_in1)
    lp_parent.add_workflow_output("parent_wf_output", lp_node.from_n0t1, str)
    lp_parent_spec = get_serializable(OrderedDict(), serialization_settings, lp_parent)
    assert_single_node_parent(lp_parent_spec)
    assert lp_parent_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
예제 #10
0
def test_lp_from_ref_wf():
    """A launch plan can wrap a reference workflow and exposes the reference's identifier."""

    @reference_workflow(project="project", domain="domain", name="name", version="version")
    def ref_wf1(p1: str, p2: str) -> None:
        ...

    lp = LaunchPlan.create("reference-wf-12345", ref_wf1, fixed_inputs={"p1": "p1-value", "p2": "p2-value"})
    assert lp.name == "reference-wf-12345"
    assert lp.workflow == ref_wf1

    wf_id = lp.workflow.id
    expected_id_fields = {"name": "name", "project": "project", "domain": "domain", "version": "version"}
    for field_name, expected_value in expected_id_fields.items():
        assert getattr(wf_id, field_name) == expected_value
예제 #11
0
def test_file_handling_remote_default_wf_input():
    """A remote URL used as a workflow default becomes the blob URI of the launch plan's default parameter."""
    SAMPLE_DATA = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"

    @task
    def t1(fname: os.PathLike) -> int:
        with open(fname, "r") as fh:
            line_count = len(fh.readlines())
        return line_count

    @workflow
    def my_wf(fname: os.PathLike = SAMPLE_DATA) -> int:
        return t1(fname=fname)

    fname_default = my_wf.python_interface.inputs_with_defaults["fname"][1]
    assert fname_default == SAMPLE_DATA

    sample_lp = LaunchPlan.create("test_launch_plan", my_wf)
    default_literal = sample_lp.parameters.parameters["fname"].default
    assert default_literal.scalar.blob.uri == SAMPLE_DATA
예제 #12
0
def test_schedule_with_lp():
    """A FixedRate schedule on a launch plan converts to the equivalent schedule model."""

    @task
    def double(a: int) -> int:
        return a * 2

    @workflow
    def quadruple(a: int) -> int:
        doubled = double(a=a)
        return double(a=doubled)

    every_twelve_hours = FixedRate(_datetime.timedelta(hours=12), "kickoff_input")
    lp = LaunchPlan.create("schedule_test", quadruple, schedule=every_twelve_hours)

    expected_rate = _schedule_models.Schedule.FixedRate(12, _schedule_models.Schedule.FixedRateUnit.HOUR)
    assert lp.schedule == _schedule_models.Schedule("kickoff_input", rate=expected_rate)
예제 #13
0
def test_execute_python_workflow_list_of_floats(flyteclient,
                                                flyte_workflows_register,
                                                flyte_remote_env):
    """Test execution of a @workflow-decorated python function and launchplan that are already registered."""
    from mock_flyte_repo.workflows.basic.list_float_wf import my_wf

    # Strip the package prefix so the workflow name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    xs: typing.List[float] = [42.24, 999.1, 0.0001]
    wf_run = remote.execute(my_wf, inputs={"xs": xs}, version=version, wait=True)
    assert wf_run.outputs["o0"] == "[42.24, 999.1, 0.0001]"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    lp_run = remote.execute(launch_plan, inputs={"xs": [-1.1, 0.12345]}, version=version, wait=True)
    assert lp_run.outputs["o0"] == "[-1.1, 0.12345]"
예제 #14
0
def test_with_launch_plan():
    """Notifications passed to LaunchPlan.create compare equal to those exposed by the launch plan."""

    @task
    def double(a: int) -> int:
        return a * 2

    @workflow
    def quadruple(a: int) -> int:
        doubled = double(a=a)
        return double(a=doubled)

    def success_email():
        return notification.Email(phases=[_workflow_execution_succeeded],
                                  recipients_email=["*****@*****.**"])

    lp = LaunchPlan.create("notif_test", quadruple, notifications=[success_email()])
    assert lp.notifications == [success_email()]
예제 #15
0
def get_registrable_entities(ctx: flyte_context.FlyteContext) -> typing.List:
    """
    Serialize every known task, workflow and launch plan, and return the IDL objects that Admin should receive.

    Each workflow additionally contributes its default launch plan. Models rejected by
    ``_should_register_with_admin`` are dropped before conversion with ``to_flyte_idl()``.

    :param ctx: Context whose ``serialization_settings`` drive serialization.
    :return: List of IDL objects, in the insertion order of the serialized-entity mapping.
    """
    entity_mapping = OrderedDict()
    # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
    #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
    for entity in flyte_context.FlyteEntities.entities.copy():
        if not isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
            continue

        get_serializable(entity_mapping, ctx.serialization_settings, entity)

        if isinstance(entity, WorkflowBase):
            default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
            get_serializable(entity_mapping, ctx.serialization_settings, default_lp)

    registrable = [model for model in entity_mapping.values() if _should_register_with_admin(model)]
    return [model.to_flyte_idl() for model in registrable]
예제 #16
0
def test_calling_lp():
    """A launch plan promoted to a FlyteLaunchPlan can be called in a workflow and serializes as an LP reference."""
    sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
    serialized = OrderedDict()
    lp_model = get_serializable(serialized, serialization_settings, sub_wf_lp)
    _, wf_specs, _ = gather_dependent_entities(serialized)
    # Take the first gathered workflow spec (presumably sub_wf's) to borrow its interface. Fail with a clear
    # message instead of an unbound-variable NameError if nothing was gathered.
    assert wf_specs, "expected at least one workflow spec to be gathered for the launch plan"
    spec = next(iter(wf_specs.values()))

    remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
    # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
    # part of the IDL object but is something FlyteRemote does
    remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)

    @workflow
    def wf2(a: int) -> typing.Tuple[int, str]:
        return remote_lp(a=a, b="hello")

    wf_spec = get_serializable(OrderedDict(), serialization_settings, wf2)
    # The calling node must reference the original launch plan by its identifier.
    assert wf_spec.template.nodes[0].workflow_node.launchplan_ref == lp_model.id
예제 #17
0
def test_imperative():
    """
    Build, run and serialize an imperative workflow, then embed it in two parents (as a sub-workflow and via
    a launch plan).

    The ``docs_*`` comment markers delimit snippets; they are presumably pulled verbatim into documentation,
    so text between a ``*_start``/``*_end`` pair should be edited with care.
    """
    # Re import with alias
    from flytekit.core.workflow import ImperativeWorkflow as Workflow  # noqa

    # docs_tasks_start
    @task
    def t1(a: str) -> str:
        return a + " world"

    @task
    def t2():
        print("side effect")

    # docs_tasks_end

    # docs_start
    # Create the workflow with a name. This needs to be unique within the project and takes the place of the function
    # name that's used for regular decorated function-based workflows.
    wb = Workflow(name="my_workflow")
    # Adds a top level input to the workflow. This is like an input to a workflow function.
    wb.add_workflow_input("in1", str)
    # Call your tasks.
    node = wb.add_entity(t1, a=wb.inputs["in1"])
    wb.add_entity(t2)
    # This is analogous to a return statement
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])
    # docs_end

    # Local execution of the imperative workflow.
    assert wb(in1="hello") == "hello world"

    # Serialized form: two nodes (t1 then t2), one input, one output named "from_n0t1".
    wf_spec = get_serializable(OrderedDict(), serialization_settings, wb)
    assert len(wf_spec.template.nodes) == 2
    assert wf_spec.template.nodes[0].task_node is not None
    assert len(wf_spec.template.outputs) == 1
    assert wf_spec.template.outputs[0].var == "from_n0t1"
    assert len(wf_spec.template.interface.inputs) == 1
    assert len(wf_spec.template.interface.outputs) == 1

    # docs_equivalent_start
    nt = typing.NamedTuple("wf_output", from_n0t1=str)

    @workflow
    def my_workflow(in1: str) -> nt:
        x = t1(a=in1)
        t2()
        return nt(
            x,
        )

    # docs_equivalent_end
    # NOTE: my_workflow above only illustrates the decorated equivalent; nothing below calls or serializes it.

    # Create launch plan from wf, that can also be serialized.
    lp = LaunchPlan.create("test_wb", wb)
    lp_model = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model.spec.workflow_id.name == "my_workflow"

    # Parent workflow embedding wb as a sub-workflow. ImperativeWorkflow is the same class imported above as
    # Workflow.
    wb2 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb2.add_workflow_input("p_in1", str)
    p_node0 = wb2.add_subwf(wb, in1=p_in1)
    wb2.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb2_spec = get_serializable(OrderedDict(), serialization_settings, wb2)
    assert len(wb2_spec.template.nodes) == 1
    assert len(wb2_spec.template.interface.inputs) == 1
    assert wb2_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb2_spec.template.interface.outputs) == 1
    assert wb2_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb2_spec.template.nodes[0].workflow_node.sub_workflow_ref.name == "my_workflow"
    assert len(wb2_spec.sub_workflows) == 1

    # Parent workflow that calls wb through its launch plan instead of embedding it.
    wb3 = ImperativeWorkflow(name="parent.imperative")
    p_in1 = wb3.add_workflow_input("p_in1", str)
    p_node0 = wb3.add_launch_plan(lp, in1=p_in1)
    wb3.add_workflow_output("parent_wf_output", p_node0.from_n0t1, str)
    wb3_spec = get_serializable(OrderedDict(), serialization_settings, wb3)
    assert len(wb3_spec.template.nodes) == 1
    assert len(wb3_spec.template.interface.inputs) == 1
    assert wb3_spec.template.interface.inputs["p_in1"].type.simple is not None
    assert len(wb3_spec.template.interface.outputs) == 1
    assert wb3_spec.template.interface.outputs["parent_wf_output"].type.simple is not None
    assert wb3_spec.template.nodes[0].workflow_node.launchplan_ref.name == "test_wb"
예제 #18
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Discover registrable entities in ``pkgs`` and write each one out as a ``.pb`` protobuf file.

    In order to register, we have to comply with Admin's endpoints. Those endpoints take the following objects:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.

    :param list[Text] pkgs: Packages to scan with ``iterate_registerable_entities_in_order``.
    :param Text local_source_root: Forwarded to ``iterate_registerable_entities_in_order``.
    :param Text folder: Directory the ``.pb`` files are written into. If falsy, the bare relative file name is used.
    :param SerializationMode mode: ``DEFAULT`` or ``FAST`` (applies to Python tasks). ``None`` means ``DEFAULT``.
    :param Text image: Default image name. Also exported through the IMAGE config env var.
    :param Text config_path: Config file path. Falls back to the currently configured path when not given.
    :param Text flytekit_virtualenv_root: Virtualenv root used to build the entrypoint path.
    :raises AssertionError: If ``mode`` is not DEFAULT/FAST and a PythonTask is encountered.
    :return: None. Results are written to files and echoed to the console.
    """

    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    # Environment handed to SerializationSettings: config file path and default image.
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path
        if config_path else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var:
        image,
    }

    # Project/domain/version are placeholders; presumably they are substituted at registration time — TODO confirm.
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        # NOTE(review): flytekit_virtualenv_root defaults to None, and os.path.join(None, ...) raises TypeError,
        #  so callers must always pass it. Consider validating it up front.
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root,
                               _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
    )
    with flyte_context.FlyteContext.current_context(
    ).new_serialization_settings(
            serialization_settings=serialization_settings) as ctx:
        # Holds both legacy entities (found below) and serialized new-API entities (appended later).
        loaded_entities = []
        for m, k, o in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    m, k, o._instantiated_in))
            # Give the entity a placeholder identifier named after its fully-qualified module attribute.
            o._id = _identifier.Identifier(o.resource_type,
                                           _PROJECT_PLACEHOLDER,
                                           _DOMAIN_PLACEHOLDER, name,
                                           _VERSION_PLACEHOLDER)
            loaded_entities.append(o)
            # Record where the entity was found (module + attribute name) on the serialization settings.
            ctx.serialization_settings.add_instance_var(
                InstanceVar(module=m, name=k, o=o))

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        mode = mode if mode else SerializationMode.DEFAULT
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a

            # NOTE(review): get_serializable is called here without an entity-mapping argument. Confirm this
            #  matches the get_serializable signature in use for this version.
            if isinstance(entity, PythonTask) or isinstance(
                    entity, Workflow) or isinstance(entity, LaunchPlan):
                if isinstance(entity, PythonTask):
                    if mode == SerializationMode.DEFAULT:
                        serializable = get_serializable(
                            ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        serializable = get_serializable(
                            ctx.serialization_settings, entity, fast=True)
                    else:
                        raise AssertionError(
                            f"Unrecognized serialization mode: {mode}")
                else:
                    serializable = get_serializable(ctx.serialization_settings,
                                                    entity)
                loaded_entities.append(serializable)

                # Every workflow also gets its default launch plan serialized alongside it.
                if isinstance(entity, Workflow):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    launch_plan = get_serializable(ctx.serialization_settings,
                                                   lp)
                    loaded_entities.append(launch_plan)

        # Index prefixes are zero-padded to a common width, presumably so file names sort in this order.
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        # NOTE(review): this loop assumes every item (legacy or serialized new-API) exposes has_registered,
        #  serialize() and id — confirm for objects returned by get_serializable.
        for i, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(
                    f"Skipping entity {entity.id} because already registered")
                continue
            serialized = entity.serialize()
            fname_index = str(i).zfill(zero_padded_length)
            fname = "{}_{}_{}.pb".format(fname_index, entity.id.name,
                                         entity.id.resource_type)
            click.echo(
                f"  Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n    {fname}"
            )
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(serialized, fname)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")
예제 #19
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files. If falsy, the bare relative file name is used.
    :param mode: Regular vs fast. ``None`` means ``SerializationMode.DEFAULT``.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Interpreter path forwarded to ``SerializationSettings``.
    :raises AssertionError: If ``mode`` is not DEFAULT/FAST and a PythonTask is encountered.
    :raises Exception: If a serialized object is not a TaskSpec, WorkflowSpec or LaunchPlan proto.
    """

    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    # Environment handed to SerializationSettings: config file path and default image.
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path
        if config_path else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var:
        image,
    }

    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        # NOTE(review): flytekit_virtualenv_root defaults to None, and os.path.join(None, ...) raises TypeError,
        #  so callers must always pass it. Consider validating it up front.
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root,
                               _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
    )
    ctx = flyte_context.FlyteContextManager.current_context(
    ).with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    m, k, o._instantiated_in))
            # Give the entity a placeholder identifier named after its fully-qualified module attribute.
            o._id = _identifier.Identifier(o.resource_type,
                                           _PROJECT_PLACEHOLDER,
                                           _DOMAIN_PLACEHOLDER, name,
                                           _VERSION_PLACEHOLDER)
            old_style_entities.append(o)

        # Legacy entities serialize themselves; already-registered ones are skipped.
        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(
                    f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        mode = mode if mode else SerializationMode.DEFAULT

        # get_serializable fills this mapping with every model it produces (including dependencies).
        new_api_serializable_entities = OrderedDict()
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if isinstance(entity, PythonTask) or isinstance(
                    entity, WorkflowBase) or isinstance(entity, LaunchPlan):
                if isinstance(entity, PythonTask):
                    if mode == SerializationMode.DEFAULT:
                        get_serializable(new_api_serializable_entities,
                                         ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        get_serializable(new_api_serializable_entities,
                                         ctx.serialization_settings,
                                         entity,
                                         fast=True)
                    else:
                        raise AssertionError(
                            f"Unrecognized serialization mode: {mode}")
                else:
                    get_serializable(new_api_serializable_entities,
                                     ctx.serialization_settings, entity)

                # Every workflow also gets its default launch plan serialized alongside it.
                if isinstance(entity, WorkflowBase):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    get_serializable(new_api_serializable_entities,
                                     ctx.serialization_settings, lp)

        # Keep only models Admin should receive, then convert them to IDL protobufs.
        new_api_model_values = list(new_api_serializable_entities.values())
        new_api_model_values = list(
            filter(_should_register_with_admin, new_api_model_values))
        new_api_model_values = [v.to_flyte_idl() for v in new_api_model_values]

        loaded_entities = serialized_old_style_entities + new_api_model_values
        # Index prefixes are zero-padded to a common width, presumably so file names sort in this order.
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            fname_index = str(i).zfill(zero_padded_length)
            # Trailing _1/_2/_3 follows the ResourceType index referenced in the docstring (task/workflow/LP).
            if isinstance(entity, _idl_admin_TaskSpec):
                fname = "{}_{}_1.pb".format(fname_index,
                                            entity.template.id.name)
            elif isinstance(entity, _idl_admin_WorkflowSpec):
                fname = "{}_{}_2.pb".format(fname_index,
                                            entity.template.id.name)
            elif isinstance(entity, _idl_admin_LaunchPlan):
                fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
            else:
                raise Exception(f"Bad format {type(entity)}")
            click.echo(f"  Writing to file: {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(entity, fname)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")