Esempio n. 1
0
def serialize_tasks_only(pkgs, folder=None):
    """
    Serialize every SdkTask found in ``pkgs`` to protobuf files.

    Each discovered task is assigned an identifier built from the placeholder project, domain and version
    constants. Two files are then written per task: ``<index>_<name>.pb`` with the serialized entity and
    ``<index>_<name>.identifier.pb`` with the identifier itself.

    :param list[Text] pkgs: packages handed to the entity iterator.
    :param Text folder: optional directory the file names are joined onto; ignored when falsy.

    :return: None
    """
    discovered = []
    task_stream = iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask})
    # module = python file, attr_name = entry of dir(module), task = the entity object itself
    for module, attr_name, task in task_stream:
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)
        _logging.debug(
            "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, task._instantiated_in)
        )
        task._id = _identifier.Identifier(
            task.resource_type,
            _PROJECT_PLACEHOLDER,
            _DOMAIN_PLACEHOLDER,
            qualified_name,
            _VERSION_PLACEHOLDER,
        )
        discovered.append(task)

    # The numeric prefix is zero padded so every file name has the same index width.
    index_width = _determine_text_chars(len(discovered))
    for position, task in enumerate(discovered):
        payload = task.serialize()
        padded_index = str(position).zfill(index_width)

        entity_fname = "{}_{}.pb".format(padded_index, task._id.name)
        click.echo("  Writing {} to\n    {}".format(task._id, entity_fname))
        entity_path = _os.path.join(folder, entity_fname) if folder else entity_fname
        _write_proto_to_file(payload, entity_path)

        id_fname = "{}_{}.identifier.pb".format(padded_index, task._id.name)
        id_path = _os.path.join(folder, id_fname) if folder else id_fname
        _write_proto_to_file(task._id.to_flyte_idl(), id_path)
Esempio n. 2
0
    def get_command(self, ctx, lp_argument):
        """
        Resolve ``lp_argument`` to an entity and build the command for it.

        When ``ctx.obj`` already holds an 'lps' entry, the entity is looked up there by key. Otherwise all
        executable entities in the configured packages are iterated and the one whose fully qualified name
        equals ``lp_argument`` is used (the last match wins, because iteration is not cut short).

        :param ctx: context whose ``obj`` mapping carries packages, project, domain and version.
        :param lp_argument: fully qualified name of the requested entity.
        :return: whatever ``self._get_command`` returns for the resolved entity.
        :raises Exception: if no entity could be resolved for ``lp_argument``.
        """
        packages = ctx.obj[_constants.CTX_PACKAGES]
        resolved = None

        if 'lps' not in ctx.obj:
            # Nothing cached, so load everything again and search by name.
            candidates = iterate_registerable_entities_in_order(
                packages,
                include_entities={_executable_mixins.ExecutableEntity},
                detect_unreferenced_entities=False)
            for module, attr_name, candidate in candidates:
                candidate_name = _utils.fqdn(module.__name__, attr_name, entity_type=candidate.resource_type)
                if lp_argument == candidate_name:
                    resolved = candidate
        else:
            resolved = ctx.obj['lps'][lp_argument]

        if resolved is None:
            raise Exception('Could not load launch plan {}'.format(lp_argument))

        resolved._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            ctx.obj[_constants.CTX_PROJECT],
            ctx.obj[_constants.CTX_DOMAIN],
            lp_argument,
            ctx.obj[_constants.CTX_VERSION],
        )
        return self._get_command(ctx, resolved, lp_argument)
Esempio n. 3
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discoverable entity in ``pkgs``, or only report them in test mode.

    All entities are first collected and given an identifier built from ``project``, ``domain``, their
    fully qualified name and ``version``. They are then processed in discovery order.

    :param project: project used for the identifiers and for registration.
    :param domain: domain used for the identifiers and for registration.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print what would be registered.
    :param version: version used for the identifiers and for registration.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')
    banner = 'Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version)
    click.echo(banner)

    # module = python file, attr_name = entry of dir(module), entity = object such as an SdkWorkflow
    collected = []
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        _logging.debug(
            "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, entity._instantiated_in)
        )
        entity._id = _identifier.Identifier(entity.resource_type, project, domain, qualified_name, version)
        collected.append(entity)

    message = "Would register {:20} {}" if test else "Registering {:20} {}"
    for entity in collected:
        label = "{}:".format(entity.entity_type_text)
        click.echo(message.format(label, entity.id.name))
        if not test:
            entity.register(project, domain, entity.id.name, version)
Esempio n. 4
0
    def auto_assign_name(self):
        """
        Work out the variable name this object was assigned to and derive its platform name from it.

        This goes hand in hand with the _InstanceTracker metaclass mentioned by the original author, which
        supplies the ``instantiated_in`` value read here. Consider user code such as:

            from some.other.module import wf
            my_launch_plan = wf.create_launch_plan()

            @dynamic_task
            def sample_task(wf_params):
                yield my_launch_plan()

        The launch plan should end up with a name ending in "my_launch_plan", because that is the variable
        it is bound to. While ``create_launch_plan()`` runs, however, the interpreter cannot know which
        variable the result will be assigned to, so the pairing has to be done after the fact.

        To do that, the module named by ``self.instantiated_in`` is imported again and every attribute
        listed by ``dir()`` is compared against ``self``. The first attribute that compares equal provides
        the name, which is stored in ``self._platform_valid_name``.

        Tasks do not need this as much: for

            @python_task
            def some_task()

        the name is the function name, known at creation time. In contrast, for

            xyz = SomeWorkflow.create_launch_plan()

        the name is only known after creation: it is "xyz", not "SomeWorkflow".

        :raises FlyteSystemException: if no attribute of the module compares equal to this object.
        """
        _logging.debug("Running name auto assign")
        module = _importlib.import_module(self.instantiated_in)

        for attr_name in dir(module):
            try:
                if not (getattr(module, attr_name) == self):
                    continue
                self._platform_valid_name = _utils.fqdn(
                    module.__name__, attr_name, entity_type=self.resource_type
                )
                _logging.debug("Auto-assigning name to {}".format(self._platform_valid_name))
                return
            except ValueError as error:
                # Some objects (empty pandas dataframes, for example) raise
                #   ValueError: The truth value of a ... is ambiguous. Use a.empty, a.bool(), a.item(),
                #   a.any() or a.all()
                # when the comparison result is used as a condition. Such objects are not registrable
                # entities anyway, so the error is logged and the scan moves on to the next attribute.
                _logging.warning("Caught ValueError {} while attempting to auto-assign name".format(error))

        _logging.error("Could not auto-assign name")
        raise _system_exceptions.FlyteSystemException("Error looking for object while auto-assigning name.")
Esempio n. 5
0
def serialize_tasks(pkgs):
    """
    Write one ``<fully qualified name>.pb`` file per SdkTask found in ``pkgs``.

    :param pkgs: packages handed to the entity iterator.
    """
    tasks = iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask})
    for module, attr_name, task in tasks:
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)
        fname = '{}.pb'.format(qualified_name)
        click.echo('Writing task {} to {}'.format(task.id, fname))
        _write_proto_to_file(task.to_flyte_idl(), fname)
Esempio n. 6
0
def register_tasks_only(project, domain, pkgs, test, version):
    """
    Register every SdkTask found in ``pkgs``, or only report them in test mode.

    Previously the ``test`` switch printed "not doing anything..." and then registered every task anyway.
    It is now a real dry run: tasks are still discovered, but ``register`` is never called.

    :param project: project to register the tasks under.
    :param domain: domain to register the tasks under.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which tasks would be registered.
    :param version: version to register the tasks with.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    click.echo('Running task only registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    # Discover all tasks by loading the module
    for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask}):
        name = _utils.fqdn(m.__name__, k, entity_type=t.resource_type)
        if test:
            # Dry run: report the task but do not contact the platform.
            click.echo('Would register task {}'.format(name))
            continue
        t.register(project, domain, name, version)
Esempio n. 7
0
def serialize_all(project, domain, pkgs, version, folder=None):
    """
    Serialize every discoverable entity in ``pkgs`` to protobuf files.

    Registration has to comply with Admin's endpoints, which take these objects:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    Simply calling .to_flyte_idl() on the discovered entities would instead produce:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    That is why each entity's ``serialize`` method is used here; per the original note, it contains the
    special translation logic for Workflows and Tasks.

    :param Text project: project placed in every identifier.
    :param Text domain: domain placed in every identifier.
    :param list[Text] pkgs: packages handed to the entity iterator.
    :param Text version: version placed in every identifier.
    :param Text folder: optional directory the file names are joined onto; ignored when falsy.

    :return: None
    """

    # module = python file, attr_name = entry of dir(module), found = object such as an SdkWorkflow
    collected = []
    for module, attr_name, found in iterate_registerable_entities_in_order(pkgs):
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=found.resource_type)
        _logging.debug(
            "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, found._instantiated_in)
        )
        found._id = _identifier.Identifier(found.resource_type, project, domain, qualified_name, version)
        collected.append(found)

    index_width = _determine_text_chars(len(collected))
    for position, item in enumerate(collected):
        payload = item.serialize()
        padded_index = str(position).zfill(index_width)

        entity_fname = '{}_{}.pb'.format(padded_index, item._id.name)
        click.echo('  Writing {} to\n    {}'.format(item._id, entity_fname))
        entity_path = _os.path.join(folder, entity_fname) if folder else entity_fname
        _write_proto_to_file(payload, entity_path)

        # Not every serialized object carries an identifier field (the TaskTemplate does, others may not).
        # An explicit identifier file records the project/domain/version choices made for this call, so
        # registration can avoid a mismatch between an internal id and the registered entity.
        id_fname = '{}_{}.identifier.pb'.format(padded_index, item._id.name)
        id_path = _os.path.join(folder, id_fname) if folder else id_fname
        _write_proto_to_file(item._id.to_flyte_idl(), id_path)
Esempio n. 8
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discoverable entity in ``pkgs``, or only report them in test mode.

    Previously the ``test`` switch printed "not doing anything..." and then registered every entity anyway.
    It is now a real dry run: entities are still discovered, but ``register`` is never called.

    :param project: project to register the entities under.
    :param domain: domain to register the entities under.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which entities would be registered.
    :param version: version to register the entities with.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    # m = module (i.e. python file), k = value of dir(m), o = object (e.g. SdkWorkflow)
    for m, k, o in iterate_registerable_entities_in_order(pkgs):
        name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
        if test:
            # Dry run: report the entity but do not contact the platform.
            click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), name))
            continue
        click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), name))
        o.register(project, domain, name, version)
Esempio n. 9
0
def activate_all_impl(project, domain, version, pkgs):
    """
    Set every launch plan found in ``pkgs`` to the ACTIVE state.

    :param project: project placed in each launch plan identifier.
    :param domain: domain placed in each launch plan identifier.
    :param version: version placed in each launch plan identifier.
    :param pkgs: packages handed to the entity iterator.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans

    # Loading the modules is what discovers the launch plans.
    launch_plans = iterate_registerable_entities_in_order(pkgs, include_entities={_SdkLaunchPlan})
    for module, attr_name, launch_plan in launch_plans:
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            project,
            domain,
            _utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type),
            version,
        )
        launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
Esempio n. 10
0
def serialize_workflows(pkgs):
    """
    Write one workflow closure file per SdkWorkflow found in ``pkgs``.

    A closure bundles the workflow with the tasks referenced by its task nodes and is written to
    ``<fully qualified name>.pb``.

    :param pkgs: packages handed to the entity iterator.
    """
    # Index the tasks by identifier so they can be compiled into the workflow closure.
    tasks_by_id = {
        task.id: task
        for _, _, task in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask})
    }

    workflows = iterate_registerable_entities_in_order(pkgs, include_entities={_workflow.SdkWorkflow})
    for module, attr_name, workflow in workflows:
        click.echo('Serializing {}'.format(
            _utils.fqdn(module.__name__, attr_name, entity_type=workflow.resource_type)))

        # Only nodes that have a task node contribute a task template.
        referenced_tasks = [
            tasks_by_id[node.task_node.reference_id]
            for node in workflow.nodes
            if node.task_node is not None
        ]
        closure_pb = _WorkflowClosure(workflow=workflow, tasks=referenced_tasks).to_flyte_idl()

        fname = '{}.pb'.format(
            _utils.fqdn(module.__name__, attr_name, entity_type=workflow.resource_type))
        click.echo('  Writing workflow closure {}'.format(fname))
        _write_proto_to_file(closure_pb, fname)
Esempio n. 11
0
    def list_commands(self, ctx):
        """
        Return the sorted fully qualified names of all executable entities in the configured packages.

        As a side effect, the discovered entities are stored in ``ctx.obj['lps']``, keyed by name.

        :param ctx: context whose ``obj`` mapping carries the packages to search.
        :return: sorted list of names.
        """
        packages = ctx.obj[_constants.CTX_PACKAGES]
        entities_by_name = {}
        names = []
        # Loading the modules is what discovers the entities.
        discovered = iterate_registerable_entities_in_order(
            packages, include_entities={_executable_mixins.ExecutableEntity})
        for module, attr_name, entity in discovered:
            qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
            names.append(qualified_name)
            entities_by_name[qualified_name] = entity

        ctx.obj['lps'] = entities_by_name
        return sorted(names)
Esempio n. 12
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discoverable entity in ``pkgs``, or only report them in test mode.

    :param project: project to register the entities under.
    :param domain: domain to register the entities under.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which entities would be registered.
    :param version: version to register the entities with.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    banner = 'Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version)
    click.echo(banner)

    message = "Would register {:20} {}" if test else "Registering {:20} {}"
    # module = python file, attr_name = entry of dir(module), entity = object such as an SdkWorkflow
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        label = "{}:".format(entity.entity_type_text)
        click.echo(message.format(label, qualified_name))
        if not test:
            entity.register(project, domain, qualified_name, version)
Esempio n. 13
0
def fast_register_all(
    project: str,
    domain: str,
    pkgs: _List[str],
    test: bool,
    version: str,
    source_dir: _os.PathLike,
    dest_dir: _os.PathLike = None,
):
    """
    Upload the code package and fast register every discoverable entity in ``pkgs``.

    The package is uploaded even in test mode; ``test`` only suppresses the per-entity registration calls.

    :param project: project used for identifiers and registration.
    :param domain: domain used for identifiers and registration.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which entities would be registered.
    :param version: version to use; when falsy, a digest computed from ``source_dir`` is used instead.
    :param source_dir: directory that is digested (if needed) and uploaded.
    :param dest_dir: forwarded to ``fast_register`` for runnable tasks.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    # An explicit version wins; otherwise the version is derived from the source directory contents.
    digest = version if version else _compute_digest(source_dir)
    remote_package_path = _upload_package(source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())

    banner = (
        "Running task, workflow, and launch plan fast registration for {}, {}, {} with version {} and code dir {}"
        .format(project, domain, pkgs, digest, source_dir))
    click.echo(banner)

    # module = python file, attr_name = entry of dir(module), entity = object such as an SdkWorkflow
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        entity._id = _identifier.Identifier(entity.resource_type, project, domain, qualified_name, digest)

        if test:
            click.echo("Would fast register {:20} {}".format("{}:".format(entity.entity_type_text), entity.id.name))
            continue

        click.echo("Fast registering {:20} {}".format("{}:".format(entity.entity_type_text), entity.id.name))
        _get_additional_distribution_loc(_sdk_config.FAST_REGISTRATION_DIR.get(), digest)
        # Only runnable tasks take the fast path; everything else is registered normally.
        if not isinstance(entity, _sdk_runnable_task.SdkRunnableTask):
            entity.register(project, domain, entity.id.name, digest)
        else:
            entity.fast_register(project, domain, entity.id.name, digest, remote_package_path, dest_dir)
Esempio n. 14
0
    def list_commands(self, ctx):
        """
        Return the sorted fully qualified names of all launch plans in the configured packages.

        As a side effect, the discovered launch plans are stored in ``ctx.obj["lps"]``, keyed by name.

        :param ctx: context whose ``obj`` mapping carries the packages to search.
        :return: sorted list of launch plan names.
        """
        packages = ctx.obj[_constants.CTX_PACKAGES]
        launch_plans_by_name = {}
        names = []
        # Loading the modules is what discovers the launch plans.
        discovered = iterate_registerable_entities_in_order(
            packages,
            include_entities={_SdkLaunchPlan},
            detect_unreferenced_entities=False)
        for module, attr_name, launch_plan in discovered:
            qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type)
            names.append(qualified_name)
            launch_plans_by_name[qualified_name] = launch_plan

        ctx.obj["lps"] = launch_plans_by_name
        return sorted(names)
Esempio n. 15
0
def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False):
    """
    Set the launch plans found in ``pkgs`` to the ACTIVE state.

    :param project: project placed in each launch plan identifier.
    :param domain: domain placed in each launch plan identifier.
    :param version: version placed in each launch plan identifier.
    :param pkgs: packages handed to the entity iterator.
    :param ignore_schedules: when truthy, launch plans whose ``is_scheduled`` is truthy are not updated.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans

    _logging.info(f"Setting this version's {version} launch plans active in {project} {domain}")

    # Loading the modules is what discovers the launch plans.
    launch_plans = iterate_registerable_entities_in_order(
        pkgs,
        include_entities={_SdkLaunchPlan},
        detect_unreferenced_entities=False)
    for module, attr_name, launch_plan in launch_plans:
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            project,
            domain,
            _utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type),
            version,
        )
        # Scheduled launch plans are skipped when the caller asked to ignore schedules.
        if launch_plan.is_scheduled and ignore_schedules:
            continue
        _logging.info(
            f"Setting active {_utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type)}"
        )
        launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
Esempio n. 16
0
def register_tasks_only(project, domain, pkgs, test, version):
    """
    Register every SdkTask found in ``pkgs``, or only report them in test mode.

    :param project: project to register the tasks under.
    :param domain: domain to register the tasks under.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which tasks would be registered.
    :param version: version to register the tasks with.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    banner = "Running task only registration for {}, {}, {} with version {}".format(
        project, domain, pkgs, version)
    click.echo(banner)

    message = "Would register task {:20} {}" if test else "Registering task {:20} {}"
    # Loading the modules is what discovers the tasks.
    tasks = iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask})
    for module, attr_name, task in tasks:
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)
        label = "{}:".format(task.entity_type_text)
        click.echo(message.format(label, qualified_name))
        if not test:
            task.register(project, domain, qualified_name, version)
Esempio n. 17
0
def fast_register_tasks_only(
    project: str,
    domain: str,
    pkgs: _List[str],
    test: bool,
    version: str,
    source_dir: _os.PathLike,
    dest_dir: _os.PathLike = None,
):
    """
    Upload the code package and fast register every SdkTask found in ``pkgs``.

    The package is uploaded even in test mode; ``test`` only suppresses the per-task registration calls.

    :param project: project to register the tasks under.
    :param domain: domain to register the tasks under.
    :param pkgs: packages handed to the entity iterator.
    :param test: when truthy, only print which tasks would be registered.
    :param version: version to use; when falsy, a digest computed from ``source_dir`` is used instead.
    :param source_dir: directory that is digested (if needed) and uploaded.
    :param dest_dir: forwarded to ``fast_register`` for runnable tasks.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    # An explicit version wins; otherwise the version is derived from the source directory contents.
    digest = version if version else _compute_digest(source_dir)
    remote_package_path = _upload_package(source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())

    banner = (
        "Running task only fast registration for {}, {}, {} with version {} and code dir {}"
        .format(project, domain, pkgs, digest, source_dir))
    click.echo(banner)

    # Loading the modules is what discovers the tasks.
    tasks = iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask})
    for module, attr_name, task in tasks:
        qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)

        if test:
            click.echo("Would fast register task {:20} {}".format("{}:".format(task.entity_type_text), qualified_name))
            continue

        click.echo("Fast registering task {:20} {}".format("{}:".format(task.entity_type_text), qualified_name))
        # Only runnable tasks take the fast path; every other task is registered normally.
        if not isinstance(task, _sdk_runnable_task.SdkRunnableTask):
            task.register(project, domain, qualified_name, digest)
        else:
            task.fast_register(project, domain, qualified_name, digest, remote_package_path, dest_dir)
Esempio n. 18
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Serialize every discoverable entity to protobuf files using placeholder project/domain/version values.

    Registration has to comply with Admin's endpoints, which take these objects:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    Simply calling .to_flyte_idl() on the discovered entities would instead produce:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    That is why each entity's ``serialize`` method is used here; per the original note, it contains the
    special translation logic for Workflows and Tasks.

    :param list[Text] pkgs: packages handed to the entity iterator.
    :param Text local_source_root: forwarded to the entity iterator.
    :param Text folder: optional directory the file names are joined onto; ignored when falsy.
    :param mode: serialization mode; a falsy value means ``SerializationMode.DEFAULT``.
    :param Text image: image name used for the image config and exported in the environment mapping.
    :param Text config_path: configuration path; falls back to the configured value when falsy.
    :param Text flytekit_virtualenv_root: virtualenv root, also used to build the entrypoint path.

    :return: None
    :raises AssertionError: if a PythonTask is encountered while ``mode`` is neither DEFAULT nor FAST.
    """
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path or _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }

    image_config = flyte_context.get_image_config(img_name=image)
    entrypoint_settings = flyte_context.EntrypointSettings(
        path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC))
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=image_config,
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=entrypoint_settings,
    )

    current_context = flyte_context.FlyteContext.current_context()
    with current_context.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
        to_write = []

        # First pass: entities found by the module iterator.
        # module = python file, attr_name = entry of dir(module), found = object such as an SdkWorkflow
        for module, attr_name, found in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=found.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, found._instantiated_in)
            )
            found._id = _identifier.Identifier(
                found.resource_type,
                _PROJECT_PLACEHOLDER,
                _DOMAIN_PLACEHOLDER,
                qualified_name,
                _VERSION_PLACEHOLDER,
            )
            to_write.append(found)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=module, name=attr_name, o=found))

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode or SerializationMode.DEFAULT
        # TODO: Clean up the copy() - it is needed because get_default_launch_plan may create a LaunchPlan
        #  object, which gets added to the very FlyteEntities.entities list being iterated.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Entities are always added by the constructor, so odd things can
            #  happen. If someone creates a workflow inside a workflow, the inner workflow should arguably not
            #  be registered. Inner tasks certainly should not, because there is no way to reach them, but
            #  generated workflows might be fine to include.
            #  Also a user may import dir_b.workflows from dir_a.workflows while the workflow packages only
            #  specify dir_a.
            if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                continue

            # Only PythonTasks are sensitive to the serialization mode.
            if not isinstance(entity, PythonTask):
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.DEFAULT:
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
            else:
                raise AssertionError(f"Unrecognized serialization mode: {mode}")
            to_write.append(serializable)

            # Each workflow is followed by its default launch plan.
            if isinstance(entity, Workflow):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                to_write.append(get_serializable(ctx.serialization_settings, default_lp))

        index_width = _determine_text_chars(len(to_write))
        for position, entity in enumerate(to_write):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            payload = entity.serialize()
            padded_index = str(position).zfill(index_width)
            fname = "{}_{}_{}.pb".format(padded_index, entity.id.name, entity.id.resource_type)
            click.echo(f"  Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n    {fname}")
            target_path = _os.path.join(folder, fname) if folder else fname
            _write_proto_to_file(payload, target_path)

        click.secho(f"Successfully serialized {len(to_write)} flyte objects", fg="green")
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    Write the following protobuf types to the given folder ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    The files can be inspected with (launch plan case shown) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with
    the entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files; "." is used when None.
    :param mode: Regular vs fast; must equal SerializationMode.DEFAULT or SerializationMode.FAST.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Passed through to the serialization settings.
    :raises AssertionError: if ``mode`` is neither DEFAULT nor FAST (this includes the default of None).
    """
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path or _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }

    if not (mode == SerializationMode.DEFAULT or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")
    fast_serialization_settings = flyte_context.FastSerializationSettings(
        enabled=mode == SerializationMode.FAST,
        # TODO: if the destination dir should become a serialization argument, initialize it here
    )

    image_config = flyte_context.get_image_config(img_name=image)
    entrypoint_settings = flyte_context.EntrypointSettings(
        path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC))
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=image_config,
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=entrypoint_settings,
        fast_serialization_settings=fast_serialization_settings,
    )

    base_context = flyte_context.FlyteContextManager.current_context()
    ctx = base_context.with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        # Legacy API entities (SdkTask, SdkWorkflow, etc.) come first. Per the original note, the
        # _get_entity_to_module function used by this iterator only works on legacy objects.
        legacy_entities = []
        for module, attr_name, found in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            qualified_name = _utils.fqdn(module.__name__, attr_name, entity_type=found.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, found._instantiated_in)
            )
            found._id = _identifier.Identifier(
                found.resource_type,
                _PROJECT_PLACEHOLDER,
                _DOMAIN_PLACEHOLDER,
                qualified_name,
                _VERSION_PLACEHOLDER,
            )
            legacy_entities.append(found)

        # Entities flagged as already registered are logged and left out of the output.
        serialized_legacy = []
        for legacy in legacy_entities:
            if not legacy.has_registered:
                serialized_legacy.append(legacy.serialize())
                continue
            _logging.info(f"Skipping entity {legacy.id} because already registered")

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        new_api_model_values = get_registrable_entities(ctx)

        everything = serialized_legacy + new_api_model_values
        if folder is None:
            folder = "."
        persist_registrable_entities(everything, folder)

        click.secho(f"Successfully serialized {len(everything)} flyte objects", fg="green")
Esempio n. 20
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    Serialize every registrable entity found in ``pkgs`` and write each one to its own protobuf file.

    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    Files are named ``<zero padded index>_<entity name>_<N>.pb`` where ``N`` is 1 for a task spec, 2 for a workflow
    spec and 3 for a launch plan.
    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.

    Legacy-API entities are written first, followed by the new-API tasks, workflows and launch plans recorded in
    ``flyte_context.FlyteEntities.entities``. Every new-API workflow additionally gets its default launch plan
    serialized.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files. When empty or None the bare file name is handed to
        the writer unchanged.
    :param mode: Regular vs fast. A falsy value means ``SerializationMode.DEFAULT``. Only tasks are affected: in
        ``FAST`` mode they are serialized with ``fast=True``.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization. When falsy, the current
        value of ``_internal_config.CONFIGURATION_PATH`` is used instead.
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Forwarded unchanged to the serialization settings.
    :raises AssertionError: if a task is encountered while ``mode`` is neither ``DEFAULT`` nor ``FAST``.
    :raises Exception: if a serialized entity is not a task spec, workflow spec or launch plan.
    """
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: (
            config_path if config_path else _internal_config.CONFIGURATION_PATH.get()
        ),
        _internal_config.IMAGE.env_var: image,
    }

    # NOTE(review): flytekit_virtualenv_root defaults to None but is passed straight to _os.path.join below.
    # If _os is the standard os module, that call raises TypeError for None - confirm callers always supply it.
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )
    ctx = flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        # m = module (i.e. python file)
        # k = value of dir(m), type str
        # o = object (e.g. SdkWorkflow)
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            old_style_entities.append(o)

        # Serialization is a separate pass so that every legacy entity already carries its placeholder id
        # before any of them is serialized.
        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode if mode else SerializationMode.DEFAULT

        new_api_serializable_entities = OrderedDict()
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if not isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
                continue

            # Only tasks are sensitive to the serialization mode; the mode is validated lazily, i.e. only once a
            # task is actually encountered.
            extra_kwargs = {}
            if isinstance(entity, PythonTask):
                if mode == SerializationMode.FAST:
                    extra_kwargs["fast"] = True
                elif mode != SerializationMode.DEFAULT:
                    raise AssertionError(f"Unrecognized serialization mode: {mode}")
            get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, **extra_kwargs)

            if isinstance(entity, WorkflowBase):
                # Each workflow also gets its default launch plan serialized alongside it.
                lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp)

        # Keep only what _should_register_with_admin accepts and convert the survivors to their IDL form.
        new_api_model_values = [
            v.to_flyte_idl() for v in new_api_serializable_entities.values() if _should_register_with_admin(v)
        ]

        loaded_entities = serialized_old_style_entities + new_api_model_values
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            fname_index = str(i).zfill(zero_padded_length)
            # The numeric suffix before ".pb" encodes the kind of entity stored in the file.
            if isinstance(entity, _idl_admin_TaskSpec):
                fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_WorkflowSpec):
                fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_LaunchPlan):
                fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
            else:
                raise Exception(f"Bad format {type(entity)}")
            # The echoed name is the bare file name, without the output folder.
            click.echo(f"  Writing to file: {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(entity, fname)

        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")