Пример #1
0
def serialize_tasks_only(pkgs, folder=None):
    """
    Serialize every discovered SdkTask to protobuf files, stamping placeholder project/domain/version values.

    Two files are written per task: ``<index>_<name>.pb`` holding the output of ``serialize()`` and
    ``<index>_<name>.identifier.pb`` holding the task's identifier. ``<index>`` is the discovery position,
    zero-padded to the width returned by ``_determine_text_chars``.

    :param list[Text] pkgs: Packages to scan for registerable tasks.
    :param Text folder: Optional output directory; when falsy, the file names are used as-is (relative paths).

    :return: None
    """
    # Each discovered item is (module, attribute name from dir(module), task object).
    tasks = []
    for module, attr_name, task in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
        task_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)
        _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, task._instantiated_in))
        # Identifier uses placeholder project/domain/version values.
        task._id = _identifier.Identifier(
            task.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, task_name, _VERSION_PLACEHOLDER
        )
        tasks.append(task)

    width = _determine_text_chars(len(tasks))

    def _in_folder(file_name):
        # Prefix with the output directory only when one was given.
        return _os.path.join(folder, file_name) if folder else file_name

    for index, task in enumerate(tasks):
        payload = task.serialize()
        prefix = str(index).zfill(width)
        base_name = "{}_{}.pb".format(prefix, task._id.name)
        click.echo("  Writing {} to\n    {}".format(task._id, base_name))
        _write_proto_to_file(payload, _in_folder(base_name))

        _write_proto_to_file(
            task._id.to_flyte_idl(),
            _in_folder("{}_{}.identifier.pb".format(prefix, task._id.name)),
        )
Пример #2
0
    def get_command(self, ctx, lp_argument):
        """
        Resolve ``lp_argument`` to an executable entity and build the click command for it.

        The entity is obtained in one of two ways. If the list function has already run in this context, the
        discovered entities are cached in ``ctx.obj['lps']`` keyed by fully-qualified name. Otherwise, for example
        when the command is actually being run, nothing is cached. In that case the packages from ``ctx.obj`` are
        loaded again and scanned for a matching name.

        The resolved entity's identifier is then set from the project/domain/version stored in ``ctx.obj``.

        :param ctx: click context whose ``obj`` holds the packages, project, domain and version.
        :param Text lp_argument: fully-qualified name of the launch plan to load.
        :raises Exception: if no entity with that name can be found, whether cached or freshly loaded.
        """
        launch_plan = None
        pkgs = ctx.obj[_constants.CTX_PACKAGES]

        if 'lps' in ctx.obj:
            # Use .get() so an unknown name reaches the descriptive error below instead of raising a bare KeyError.
            launch_plan = ctx.obj['lps'].get(lp_argument)
        else:
            for m, k, lp in iterate_registerable_entities_in_order(
                    pkgs,
                    include_entities={_executable_mixins.ExecutableEntity},
                    detect_unreferenced_entities=False):
                safe_name = _utils.fqdn(m.__name__,
                                        k,
                                        entity_type=lp.resource_type)
                # No early break: if several entities share the name, the last one found wins.
                if lp_argument == safe_name:
                    launch_plan = lp

        if launch_plan is None:
            raise Exception(
                'Could not load launch plan {}'.format(lp_argument))

        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            ctx.obj[_constants.CTX_PROJECT], ctx.obj[_constants.CTX_DOMAIN],
            lp_argument, ctx.obj[_constants.CTX_VERSION])
        return self._get_command(ctx, launch_plan, lp_argument)
Пример #3
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discovered task, workflow and launch plan under the given project/domain/version.

    All entities are collected first, and each is stamped with its identifier. They are then processed in
    discovery order. When ``test`` is truthy nothing is registered; each entity is only echoed as
    "Would register ...".

    :param Text project: project for every identifier.
    :param Text domain: domain for every identifier.
    :param list[Text] pkgs: packages to scan for registerable entities.
    :param bool test: dry-run switch.
    :param Text version: version for every identifier.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')
    click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    # Collect (module, attribute name, entity) triples and assign each entity its identifier up front.
    entities = []
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        fq_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, entity._instantiated_in))
        entity._id = _identifier.Identifier(entity.resource_type, project, domain, fq_name, version)
        entities.append(entity)

    line_template = "Would register {:20} {}" if test else "Registering {:20} {}"
    for entity in entities:
        click.echo(line_template.format("{}:".format(entity.entity_type_text), entity.id.name))
        if not test:
            entity.register(project, domain, entity.id.name, version)
Пример #4
0
def serialize_tasks(pkgs):
    """
    Write the flyte IDL form of every discovered SdkTask to a relative path ``<fqdn>.pb``.

    :param list[Text] pkgs: packages to scan for tasks.
    """
    discovered = iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask})
    for module, attr_name, task in discovered:
        target = '{}.pb'.format(_utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type))
        click.echo('Writing task {} to {}'.format(task.id, target))
        _write_proto_to_file(task.to_flyte_idl(), target)
Пример #5
0
def register_tasks_only(project, domain, pkgs, test, version):
    """
    Register every discovered SdkTask under the given project/domain/version.

    :param Text project: project to register into.
    :param Text domain: domain to register into.
    :param list[Text] pkgs: packages to scan for tasks.
    :param bool test: dry-run switch. When truthy, tasks are only echoed ("Would register task ...") and nothing
        is registered, as the "not doing anything" message promises.
    :param Text version: version to register the tasks under.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    click.echo('Running task only registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    # Discover all tasks by loading the module
    for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask}):
        name = _utils.fqdn(m.__name__, k, entity_type=t.resource_type)
        # Honor the dry-run switch: report instead of registering.
        if test:
            click.echo("Would register task {:20} {}".format("{}:".format(t.entity_type_text), name))
        else:
            click.echo("Registering task {:20} {}".format("{}:".format(t.entity_type_text), name))
            t.register(project, domain, name, version)
Пример #6
0
def serialize_all(project, domain, pkgs, version, folder=None):
    """
    In order to register, we have to comply with Admin's endpoints. Those endpoints take the following objects. These
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.

    Two files are written per entity: ``<index>_<name>.pb`` with the output of ``serialize()``, and
    ``<index>_<name>.identifier.pb`` with the entity's identifier.

    :param Text project: project stamped into every entity's identifier.
    :param Text domain: domain stamped into every identifier.
    :param list[Text] pkgs: packages to scan for registerable entities.
    :param Text version: version stamped into every identifier.
    :param Text folder: optional output directory; when falsy, the file names are used as relative paths.

    :return: None
    """
    # Each discovered item is (module, attribute name from dir(module), entity object).
    discovered = []
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        entity_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(
            module, attr_name, entity._instantiated_in))
        entity._id = _identifier.Identifier(entity.resource_type, project, domain, entity_name, version)
        discovered.append(entity)

    pad_to = _determine_text_chars(len(discovered))
    for position, entity in enumerate(discovered):
        payload = entity.serialize()
        prefix = str(position).zfill(pad_to)
        pb_name = '{}_{}.pb'.format(prefix, entity._id.name)
        click.echo('  Writing {} to\n    {}'.format(entity._id, pb_name))
        _write_proto_to_file(payload, _os.path.join(folder, pb_name) if folder else pb_name)

        # Not everything serialized will necessarily have an identifier field in it, even though some do (like the
        # TaskTemplate). To be more rigorous, we write an explicit identifier file that reflects the choices (like
        # project/domain, etc.) made for this serialize call. We should not allow users to specify a different project
        # for instance come registration time, to avoid mismatches between potential internal ids like the TaskTemplate
        # and the registered entity.
        id_name = '{}_{}.identifier.pb'.format(prefix, entity._id.name)
        _write_proto_to_file(entity._id.to_flyte_idl(), _os.path.join(folder, id_name) if folder else id_name)
Пример #7
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discovered task, workflow and launch plan under the given project/domain/version.

    :param Text project: project to register into.
    :param Text domain: domain to register into.
    :param list[Text] pkgs: packages to scan for registerable entities.
    :param bool test: dry-run switch. When truthy, entities are only echoed ("Would register ...") and nothing is
        registered, as the "not doing anything" message promises.
    :param Text version: version to register the entities under.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    for m, k, o in iterate_registerable_entities_in_order(pkgs):
        name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
        # Honor the dry-run switch: report instead of registering.
        if test:
            click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), name))
        else:
            click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), name))
            o.register(project, domain, name, version)
Пример #8
0
def activate_all_impl(project, domain, version, pkgs):
    """
    Set every discovered launch plan, at ``version``, to the ACTIVE state.

    :param Text project: project of the launch plans.
    :param Text domain: domain of the launch plans.
    :param Text version: launch plan version to activate.
    :param list[Text] pkgs: packages to scan for launch plans.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans

    # Discover all launch plans by loading the modules
    launch_plans = iterate_registerable_entities_in_order(pkgs, include_entities={_SdkLaunchPlan})
    for module, attr_name, launch_plan in launch_plans:
        lp_name = _utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type)
        # Point the launch plan at this exact project/domain/name/version before changing its state.
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN, project, domain, lp_name, version)
        launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
Пример #9
0
def serialize_workflows(pkgs):
    """
    Write a WorkflowClosure for every discovered SdkWorkflow. Each closure holds the workflow plus the tasks its
    task nodes reference.

    Each closure is written to the relative path ``<fqdn>.pb``.

    :param list[Text] pkgs: packages to scan for tasks and workflows.
    :raises KeyError: if a task node references a task id that was not discovered in ``pkgs``.
    """
    # Index tasks by their unique identifier so they can be compiled into each workflow's closure.
    tasks_by_id = {
        task.id: task
        for _, _, task in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask})
    }

    for module, attr_name, workflow in iterate_registerable_entities_in_order(
            pkgs, include_entities={_workflow.SdkWorkflow}):
        fq_name = _utils.fqdn(module.__name__, attr_name, entity_type=workflow.resource_type)
        click.echo('Serializing {}'.format(fq_name))
        referenced_tasks = [
            tasks_by_id[node.task_node.reference_id]
            for node in workflow.nodes
            if node.task_node is not None
        ]

        closure_pb = _WorkflowClosure(workflow=workflow, tasks=referenced_tasks).to_flyte_idl()

        out_name = '{}.pb'.format(fq_name)
        click.echo('  Writing workflow closure {}'.format(out_name))
        _write_proto_to_file(closure_pb, out_name)
Пример #10
0
    def list_commands(self, ctx):
        """
        Return the sorted fully-qualified names of all executable entities found in the configured packages.

        As a side effect, the entities are cached in ``ctx.obj['lps']`` keyed by name. A later ``get_command`` in
        the same context can then use them without loading the packages again.
        """
        pkgs = ctx.obj[_constants.CTX_PACKAGES]
        names = []
        entities_by_name = {}
        # Discover all launch plans by loading the modules
        discovered = iterate_registerable_entities_in_order(
            pkgs, include_entities={_executable_mixins.ExecutableEntity})
        for module, attr_name, entity in discovered:
            name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
            names.append(name)
            entities_by_name[name] = entity

        ctx.obj['lps'] = entities_by_name
        return sorted(names)
Пример #11
0
def register_all(project, domain, pkgs, test, version):
    """
    Register every discovered task, workflow and launch plan, or only report them when ``test`` is truthy.

    :param Text project: project to register into.
    :param Text domain: domain to register into.
    :param list[Text] pkgs: packages to scan for registerable entities.
    :param bool test: dry-run switch; when truthy, each entity is echoed as "Would register ..." and not registered.
    :param Text version: version to register the entities under.
    """
    if test:
        click.echo('Test switch enabled, not doing anything...')

    click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
        project, domain, pkgs, version))

    line_template = "Would register {:20} {}" if test else "Registering {:20} {}"
    # Each discovered item is (module, attribute name from dir(module), entity object).
    for module, attr_name, entity in iterate_registerable_entities_in_order(pkgs):
        fq_name = _utils.fqdn(module.__name__, attr_name, entity_type=entity.resource_type)
        click.echo(line_template.format("{}:".format(entity.entity_type_text), fq_name))
        if not test:
            entity.register(project, domain, fq_name, version)
Пример #12
0
def fast_register_all(
    project: str,
    domain: str,
    pkgs: _List[str],
    test: bool,
    version: str,
    source_dir: _os.PathLike,
    dest_dir: _os.PathLike = None,
):
    """
    Fast-register every discovered task, workflow and launch plan.

    The code in ``source_dir`` is uploaded as a package under a digest. The digest is ``version`` when given;
    otherwise it is computed from ``source_dir``. The same digest is used as the version of every entity.
    SdkRunnableTasks are fast-registered against the uploaded package, and all other entities are registered
    normally.

    :param project: project to register into.
    :param domain: domain to register into.
    :param pkgs: packages to scan for registerable entities.
    :param test: dry-run switch. When truthy, nothing is uploaded or registered and entities are only echoed.
    :param version: explicit version; when empty, a digest of ``source_dir`` is used instead.
    :param source_dir: local directory that is packaged and uploaded.
    :param dest_dir: passed through to ``fast_register`` for runnable tasks.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    digest = version if version else _compute_digest(source_dir)
    # Honor the dry-run switch: upload the code package only when entities will actually be registered.
    remote_package_path = None
    if not test:
        remote_package_path = _upload_package(
            source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())

    click.echo(
        "Running task, workflow, and launch plan fast registration for {}, {}, {} with version {} and code dir {}"
        .format(project, domain, pkgs, digest, source_dir))

    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    for m, k, o in iterate_registerable_entities_in_order(pkgs):
        name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
        o._id = _identifier.Identifier(o.resource_type, project, domain, name,
                                       digest)

        if test:
            click.echo("Would fast register {:20} {}".format(
                "{}:".format(o.entity_type_text), o.id.name))
        else:
            click.echo("Fast registering {:20} {}".format(
                "{}:".format(o.entity_type_text), o.id.name))
            # NOTE(review): the return value is discarded; the call is kept in case the helper has side effects.
            # Confirm and drop it if it is a pure lookup.
            _get_additional_distribution_loc(
                _sdk_config.FAST_REGISTRATION_DIR.get(), digest)
            if isinstance(o, _sdk_runnable_task.SdkRunnableTask):
                o.fast_register(project, domain, o.id.name, digest,
                                remote_package_path, dest_dir)
            else:
                o.register(project, domain, o.id.name, digest)
Пример #13
0
    def list_commands(self, ctx):
        """
        Return the sorted fully-qualified names of all launch plans found in the configured packages.

        As a side effect, the launch plans are cached in ``ctx.obj["lps"]`` keyed by name. A later
        ``get_command`` in the same context can then use them without loading the packages again.
        """
        pkgs = ctx.obj[_constants.CTX_PACKAGES]
        names = []
        plans_by_name = {}
        # Discover all launch plans by loading the modules
        for module, attr_name, plan in iterate_registerable_entities_in_order(
                pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False):
            plan_name = _utils.fqdn(module.__name__, attr_name, entity_type=plan.resource_type)
            names.append(plan_name)
            plans_by_name[plan_name] = plan

        ctx.obj["lps"] = plans_by_name
        return sorted(names)
Пример #14
0
def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False):
    """
    Set every discovered launch plan of ``version`` in ``project``/``domain`` to the ACTIVE state.

    :param Text project: project of the launch plans.
    :param Text domain: domain of the launch plans.
    :param Text version: launch plan version to activate.
    :param list[Text] pkgs: packages to scan for launch plans.
    :param bool ignore_schedules: when truthy, launch plans whose ``is_scheduled`` is truthy are left untouched.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans

    # Discover all launch plans by loading the modules
    _logging.info(
        f"Setting this version's {version} launch plans active in {project} {domain}"
    )
    for module, attr_name, launch_plan in iterate_registerable_entities_in_order(
            pkgs,
            include_entities={_SdkLaunchPlan},
            detect_unreferenced_entities=False):
        lp_name = _utils.fqdn(module.__name__, attr_name, entity_type=launch_plan.resource_type)
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN, project, domain, lp_name, version)
        # Scheduled launch plans are skipped only when the caller asked to ignore schedules.
        if launch_plan.is_scheduled and ignore_schedules:
            continue
        _logging.info(f"Setting active {lp_name}")
        launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
Пример #15
0
def register_tasks_only(project, domain, pkgs, test, version):
    """
    Register every discovered SdkTask, or only report them when ``test`` is truthy.

    :param Text project: project to register into.
    :param Text domain: domain to register into.
    :param list[Text] pkgs: packages to scan for tasks.
    :param bool test: dry-run switch; when truthy, tasks are echoed as "Would register task ..." and not registered.
    :param Text version: version to register the tasks under.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    click.echo(
        "Running task only registration for {}, {}, {} with version {}".format(
            project, domain, pkgs, version))

    message = "Would register task {:20} {}" if test else "Registering task {:20} {}"
    # Discover all tasks by loading the module
    for module, attr_name, task in iterate_registerable_entities_in_order(
            pkgs, include_entities={_task.SdkTask}):
        task_name = _utils.fqdn(module.__name__, attr_name, entity_type=task.resource_type)
        click.echo(message.format("{}:".format(task.entity_type_text), task_name))
        if not test:
            task.register(project, domain, task_name, version)
Пример #16
0
def fast_register_tasks_only(
    project: str,
    domain: str,
    pkgs: _List[str],
    test: bool,
    version: str,
    source_dir: _os.PathLike,
    dest_dir: _os.PathLike = None,
):
    """
    Fast-register every discovered SdkTask.

    The code in ``source_dir`` is uploaded as a package under a digest. The digest is ``version`` when given;
    otherwise it is computed from ``source_dir``. The digest is also used as the task version. SdkRunnableTasks
    are fast-registered against the uploaded package, and all other tasks are registered normally.

    :param project: project to register into.
    :param domain: domain to register into.
    :param pkgs: packages to scan for tasks.
    :param test: dry-run switch. When truthy, nothing is uploaded or registered and tasks are only echoed.
    :param version: explicit version; when empty, a digest of ``source_dir`` is used instead.
    :param source_dir: local directory that is packaged and uploaded.
    :param dest_dir: passed through to ``fast_register`` for runnable tasks.
    """
    if test:
        click.echo("Test switch enabled, not doing anything...")

    digest = version if version else _compute_digest(source_dir)
    # Honor the dry-run switch: upload the code package only when tasks will actually be registered.
    remote_package_path = None
    if not test:
        remote_package_path = _upload_package(
            source_dir, digest, _sdk_config.FAST_REGISTRATION_DIR.get())

    click.echo(
        "Running task only fast registration for {}, {}, {} with version {} and code dir {}"
        .format(project, domain, pkgs, digest, source_dir))

    # Discover all tasks by loading the module
    for m, k, t in iterate_registerable_entities_in_order(
            pkgs, include_entities={_task.SdkTask}):
        name = _utils.fqdn(m.__name__, k, entity_type=t.resource_type)

        if test:
            click.echo("Would fast register task {:20} {}".format(
                "{}:".format(t.entity_type_text), name))
        else:
            click.echo("Fast registering task {:20} {}".format(
                "{}:".format(t.entity_type_text), name))
            if isinstance(t, _sdk_runnable_task.SdkRunnableTask):
                t.fast_register(project, domain, name, digest,
                                remote_package_path, dest_dir)
            else:
                t.register(project, domain, name, digest)
Пример #17
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    In order to register, we have to comply with Admin's endpoints. Those endpoints take the following objects. These
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.

    :param list[Text] pkgs: packages to scan for registerable entities.
    :param Text local_source_root: forwarded to ``iterate_registerable_entities_in_order``.
    :param Text folder: optional output directory; when falsy, file names are used as relative paths.
    :param SerializationMode mode: DEFAULT or FAST; a falsy value means DEFAULT. Any other value raises
        AssertionError when the first PythonTask is serialized.
    :param Text image: default image name; also exported through the IMAGE env var of the serialization settings.
    :param Text config_path: config file exported through CONFIGURATION_PATH; falls back to the current setting.
    :param Text flytekit_virtualenv_root: virtualenv root the entrypoint path is built from.

    :return: None
    """
    # Environment handed to the serialization settings: config path (explicit or current) and default image.
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: (
            config_path if config_path else _internal_config.CONFIGURATION_PATH.get()
        ),
        _internal_config.IMAGE.env_var: image,
    }

    settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
    )
    current = flyte_context.FlyteContext.current_context()
    with current.new_serialization_settings(serialization_settings=settings) as ctx:
        loaded_entities = []
        # Legacy entities: each is (module, attribute name from dir(module), object). They get placeholder ids and
        # are recorded as instance vars on the active serialization settings.
        for module, attr_name, legacy in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(module.__name__, attr_name, entity_type=legacy.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, legacy._instantiated_in))
            legacy._id = _identifier.Identifier(
                legacy.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER)
            loaded_entities.append(legacy)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=module, name=attr_name, o=legacy))

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode or SerializationMode.DEFAULT
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                continue

            # Only PythonTasks depend on the serialization mode.
            if not isinstance(entity, PythonTask):
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.DEFAULT:
                serializable = get_serializable(ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
            else:
                raise AssertionError(f"Unrecognized serialization mode: {mode}")
            loaded_entities.append(serializable)

            # Every workflow is accompanied by its default launch plan.
            if isinstance(entity, Workflow):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                loaded_entities.append(get_serializable(ctx.serialization_settings, default_lp))

        width = _determine_text_chars(len(loaded_entities))
        for index, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            payload = entity.serialize()
            fname = "{}_{}_{}.pb".format(str(index).zfill(width), entity.id.name, entity.id.resource_type)
            click.echo(f"  Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n    {fname}")
            _write_proto_to_file(payload, _os.path.join(folder, fname) if folder else fname)

        # NOTE(review): this count includes entities skipped above as already registered.
        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")
Пример #18
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files. ``None`` means the current directory (``"."``).
    :param mode: Regular vs fast. ``None`` (the default) means ``SerializationMode.DEFAULT``.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Passed through to the serialization settings.
    :raises AssertionError: if ``mode`` is neither DEFAULT nor FAST.
    """

    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path
        if config_path else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var:
        image,
    }

    # The signature defaults mode to None. Treat a missing mode as DEFAULT instead of rejecting the parameter's
    # own default as "unrecognized".
    mode = mode if mode else SerializationMode.DEFAULT
    if mode not in (SerializationMode.DEFAULT, SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")
    fast_serialization_settings = flyte_context.FastSerializationSettings(
        enabled=mode == SerializationMode.FAST,
        # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
    )
    # NOTE(review): flytekit_virtualenv_root defaults to None, but os.path.join(None, ...) raises TypeError,
    # so callers effectively must provide it.
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root,
                               _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
        fast_serialization_settings=fast_serialization_settings,
    )
    ctx = flyte_context.FlyteContextManager.current_context(
    ).with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    m, k, o._instantiated_in))
            o._id = _identifier.Identifier(o.resource_type,
                                           _PROJECT_PLACEHOLDER,
                                           _DOMAIN_PLACEHOLDER, name,
                                           _VERSION_PLACEHOLDER)
            old_style_entities.append(o)

        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(
                    f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        new_api_model_values = get_registrable_entities(ctx)

        loaded_entities = serialized_old_style_entities + new_api_model_values
        if folder is None:
            folder = "."
        persist_registrable_entities(loaded_entities, folder)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")
Пример #19
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files; when falsy, file names are used as relative paths.
    :param mode: Regular vs fast. A falsy value means DEFAULT; any other unrecognized value raises AssertionError
        when the first PythonTask is serialized.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Passed through to the serialization settings.
    :raises Exception: ("Bad format ...") if a serialized object is not a TaskSpec, WorkflowSpec or LaunchPlan.
    """
    # Environment handed to the serialization settings: config path (explicit or current) and default image.
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: (
            config_path if config_path else _internal_config.CONFIGURATION_PATH.get()
        ),
        _internal_config.IMAGE.env_var: image,
    }

    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
    )
    base_ctx = flyte_context.FlyteContextManager.current_context()
    ctx = base_ctx.with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        # Legacy API entities (SdkTask, SdkWorkflow, ...): each is (module, attribute name from dir(module), object).
        # The _get_entity_to_module function that this iterate calls only works on legacy objects.
        old_style_entities = []
        for module, attr_name, legacy in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(module.__name__, attr_name, entity_type=legacy.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(module, attr_name, legacy._instantiated_in))
            legacy._id = _identifier.Identifier(
                legacy.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER)
            old_style_entities.append(legacy)

        serialized_old_style_entities = []
        for legacy in old_style_entities:
            if legacy.has_registered:
                _logging.info(f"Skipping entity {legacy.id} because already registered")
                continue
            serialized_old_style_entities.append(legacy.serialize())

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode or SerializationMode.DEFAULT

        # get_serializable fills this ordered mapping in place.
        new_api_serializable_entities = OrderedDict()
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if not isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
                continue

            # Only PythonTasks depend on the serialization mode.
            if not isinstance(entity, PythonTask):
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)
            elif mode == SerializationMode.DEFAULT:
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, fast=True)
            else:
                raise AssertionError(f"Unrecognized serialization mode: {mode}")

            # Every workflow is accompanied by its default launch plan.
            if isinstance(entity, WorkflowBase):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                get_serializable(new_api_serializable_entities, ctx.serialization_settings, default_lp)

        admin_models = [
            model for model in new_api_serializable_entities.values() if _should_register_with_admin(model)
        ]
        new_api_model_values = [model.to_flyte_idl() for model in admin_models]

        def _file_name(prefix, proto):
            # The trailing digit identifies the kind: 1 = task spec, 2 = workflow spec, 3 = launch plan.
            if isinstance(proto, _idl_admin_TaskSpec):
                return "{}_{}_1.pb".format(prefix, proto.template.id.name)
            if isinstance(proto, _idl_admin_WorkflowSpec):
                return "{}_{}_2.pb".format(prefix, proto.template.id.name)
            if isinstance(proto, _idl_admin_LaunchPlan):
                return "{}_{}_3.pb".format(prefix, proto.id.name)
            raise Exception(f"Bad format {type(proto)}")

        loaded_entities = serialized_old_style_entities + new_api_model_values
        width = _determine_text_chars(len(loaded_entities))
        for index, proto in enumerate(loaded_entities):
            fname = _file_name(str(index).zfill(width), proto)
            click.echo(f"  Writing to file: {fname}")
            _write_proto_to_file(proto, _os.path.join(folder, fname) if folder else fname)

        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")