Example #1
0
def test_serialization_images():
    """Check which container image each task ends up with after serialization.

    Five tasks are declared with different ``container_image`` values (fully
    templated, partially templated, hard-coded, and omitted) and each one is
    serialized against the image config loaded from ``configs/images.config``.
    """

    # Named image ("xyz") combined with the default image's version.
    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    # Both the name and the version come from the default image.
    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    # No container_image given at all.
    @task
    def t3():
        pass

    # Fully hard-coded image, no templating.
    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    # Hard-coded image name with a templated version.
    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    this_dir = os.path.dirname(os.path.realpath(__file__))
    set_flyte_config_file(os.path.join(this_dir, "configs/images.config"))
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=get_image_config(),
    )

    def serialized(entity):
        # Every task is serialized against its own fresh registry.
        return get_serializable(OrderedDict(), settings, entity)

    t1_spec = serialized(t1)
    assert t1_spec.template.container.image == "docker.io/xyz:version"
    # Also make sure the resulting spec can be converted to its IDL form.
    t1_spec.to_flyte_idl()

    expected_images = (
        (t2, "docker.io/default:version"),
        (t3, "docker.io/default:version"),
        (t4, "docker.io/org/myimage:latest"),
        (t5, "docker.io/org/myimage:version"),
    )
    for entity, image in expected_images:
        assert serialized(entity).template.container.image == image
Example #2
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Discover registerable entities in ``pkgs`` and write each one out as a serialized ``.pb`` file.

    Registration has to comply with Admin's endpoints, which take the following objects:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.admin.workflow_pb2.WorkflowSpec
    flyteidl.admin.task_pb2.TaskSpec

    Merely calling .to_flyte_idl() on all the discovered entities would instead produce:
    flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
    flyteidl.core.workflow_pb2.WorkflowTemplate
    flyteidl.core.tasks_pb2.TaskTemplate

    For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.

    :param list[Text] pkgs: Packages handed to ``iterate_registerable_entities_in_order``.
    :param Text local_source_root: Forwarded to ``iterate_registerable_entities_in_order``.
    :param Text folder: If set, output file names are joined onto this directory.
    :param mode: Serialization mode; a falsy value is replaced by ``SerializationMode.DEFAULT``.
    :param Text image: Image name used for the image config and the image env var.
    :param Text config_path: Config path placed in the env; falls back to the configured value.
    :param Text flytekit_virtualenv_root: Root used to build the entrypoint path.

    :raises AssertionError: if a PythonTask is encountered while ``mode`` is neither DEFAULT nor FAST.
    :return: None
    """

    # Loop variable legend used below:
    #   module    = python file
    #   attr_name = value of dir(module), type str
    #   legacy    = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path or _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }

    image_config = flyte_context.get_image_config(img_name=image)
    entrypoint_settings = flyte_context.EntrypointSettings(
        path=_os.path.join(flytekit_virtualenv_root,
                           _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC))
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=image_config,
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=entrypoint_settings,
    )

    base_ctx = flyte_context.FlyteContext.current_context()
    with base_ctx.new_serialization_settings(
            serialization_settings=serialization_settings) as ctx:
        loaded_entities = []
        discovered = iterate_registerable_entities_in_order(
            pkgs, local_source_root=local_source_root)
        for module, attr_name, legacy in discovered:
            name = _utils.fqdn(module.__name__,
                               attr_name,
                               entity_type=legacy.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    module, attr_name, legacy._instantiated_in))
            # Stamp the entity with a placeholder identifier.
            legacy._id = _identifier.Identifier(
                legacy.resource_type,
                _PROJECT_PLACEHOLDER,
                _DOMAIN_PLACEHOLDER,
                name,
                _VERSION_PLACEHOLDER,
            )
            loaded_entities.append(legacy)
            ctx.serialization_settings.add_instance_var(
                InstanceVar(module=module, name=attr_name, o=legacy))

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        mode = mode or SerializationMode.DEFAULT
        # TODO: Clean up the copy() - get_default_launch_plan may create a LaunchPlan object, which gets added to
        #  the very FlyteEntities.entities list being iterated over here.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Entities are always added by their constructor, so weird things can
            #  happen. A workflow created inside a workflow probably should not be registered - or should it?
            #  Inner tasks certainly should not be, since there is no way to reach them, but generated workflows
            #  might be fine. Also a user may import dir_b.workflows from dir_a.workflows while the workflow
            #  packages only specify dir_a.
            if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
                continue

            # Only PythonTasks are sensitive to the serialization mode.
            if not isinstance(
                    entity, PythonTask) or mode == SerializationMode.DEFAULT:
                serializable = get_serializable(ctx.serialization_settings,
                                                entity)
            elif mode == SerializationMode.FAST:
                serializable = get_serializable(ctx.serialization_settings,
                                                entity,
                                                fast=True)
            else:
                raise AssertionError(
                    f"Unrecognized serialization mode: {mode}")
            loaded_entities.append(serializable)

            # Every workflow additionally gets its default launch plan serialized.
            if isinstance(entity, Workflow):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                loaded_entities.append(
                    get_serializable(ctx.serialization_settings, default_lp))

        index_width = _determine_text_chars(len(loaded_entities))
        for position, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(
                    f"Skipping entity {entity.id} because already registered")
                continue
            serialized = entity.serialize()
            fname = "{}_{}_{}.pb".format(
                str(position).zfill(index_width), entity.id.name,
                entity.id.resource_type)
            click.echo(
                f"  Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n    {fname}"
            )
            target = _os.path.join(folder, fname) if folder else fname
            _write_proto_to_file(serialized, target)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")
Example #3
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files. Defaults to the current directory (".") when None.
    :param mode: Regular vs fast. When None, ``SerializationMode.DEFAULT`` is used.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Forwarded unchanged to ``SerializationSettings(python_interpreter=...)``.
    :raises AssertionError: if ``mode`` is given but is neither ``SerializationMode.DEFAULT`` nor
        ``SerializationMode.FAST``.
    """

    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path
        if config_path else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var:
        image,
    }

    # The signature advertises ``mode=None`` as the default, so an omitted mode must mean "regular" serialization
    # instead of tripping the validation below.
    if mode is None:
        mode = SerializationMode.DEFAULT
    if not (mode == SerializationMode.DEFAULT
            or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")
    fast_serialization_settings = flyte_context.FastSerializationSettings(
        enabled=mode == SerializationMode.FAST,
        # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
    )
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root,
                               _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)),
        fast_serialization_settings=fast_serialization_settings,
    )
    ctx = flyte_context.FlyteContextManager.current_context(
    ).with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(
                pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    m, k, o._instantiated_in))
            # Stamp each legacy entity with a placeholder identifier before anything is serialized.
            o._id = _identifier.Identifier(o.resource_type,
                                           _PROJECT_PLACEHOLDER,
                                           _DOMAIN_PLACEHOLDER, name,
                                           _VERSION_PLACEHOLDER)
            old_style_entities.append(o)

        # Serialization happens in a second pass, once every legacy entity has its identifier assigned.
        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(
                    f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        new_api_model_values = get_registrable_entities(ctx)

        loaded_entities = serialized_old_style_entities + new_api_model_values
        if folder is None:
            folder = "."
        persist_registrable_entities(loaded_entities, folder)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")
Example #4
0
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions

    Sets the root logger level, creates a working directory, builds the ExecutionParameters, selects a
    FileAccessProvider based on the configured cloud provider, and then calls ``_dispatch_execute`` inside nested
    file-access, serialization-settings and execution contexts.

    :param task_def: The task object handed to ``_dispatch_execute``.
    :param inputs: Passed through unchanged to ``_dispatch_execute``.
    :param output_prefix: Passed through unchanged to ``_dispatch_execute``.
    :param raw_output_data_prefix: Given to the AWS S3 / GCS proxy constructor; not used for the LOCAL provider.
    :raises Exception: if the configured cloud provider is not AWS, GCP or LOCAL.
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    # The internal logging level wins; the SDK-level setting is only used when the internal one is falsy.
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    # Imported at call time rather than at module level; the version is attached to the stats tags below.
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    # Pick the remote proxy that matches the configured cloud provider; anything unrecognized is an error.
    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    # Each ``with`` below rebinds ``ctx`` to the newly entered, more specific context.
    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        # Unlike serialization, project/domain/version here come from the task's own configured values.
        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
Example #5
0
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    Serialize every discovered entity and write the following protobuf types to ``folder`` ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    A written file can be inspected with (launch plan shown here) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    The trailing index in each file name identifies the entity type; see
    :py:class:`flytekit.models.core.identifier.ResourceType`.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast; a falsy value is replaced by ``SerializationMode.DEFAULT``.
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Forwarded unchanged to ``SerializationSettings(python_interpreter=...)``.
    :raises AssertionError: if a PythonTask is encountered while ``mode`` is neither DEFAULT nor FAST.
    :raises Exception: if a serialized entity is not a TaskSpec, WorkflowSpec or LaunchPlan message.
    """

    # Loop variable legend used below:
    #   module    = python file
    #   attr_name = value of dir(module), type str
    #   legacy    = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var:
        config_path or _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }

    image_config = flyte_context.get_image_config(img_name=image)
    entrypoint_settings = flyte_context.EntrypointSettings(
        path=_os.path.join(flytekit_virtualenv_root,
                           _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC))
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=image_config,
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=entrypoint_settings,
    )

    base_ctx = flyte_context.FlyteContextManager.current_context()
    ctx = base_ctx.with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        # Pass 1 handles legacy API entities only (SdkTask, SdkWorkflow, etc.): the _get_entity_to_module
        # function behind this iterator does not understand anything else.
        legacy_entities = []
        discovered = iterate_registerable_entities_in_order(
            pkgs, local_source_root=local_source_root)
        for module, attr_name, legacy in discovered:
            name = _utils.fqdn(module.__name__,
                               attr_name,
                               entity_type=legacy.resource_type)
            _logging.debug(
                "Found module {}\n   K: {} Instantiated in {}".format(
                    module, attr_name, legacy._instantiated_in))
            legacy._id = _identifier.Identifier(
                legacy.resource_type,
                _PROJECT_PLACEHOLDER,
                _DOMAIN_PLACEHOLDER,
                name,
                _VERSION_PLACEHOLDER,
            )
            legacy_entities.append(legacy)

        # Pass 2 serializes the legacy entities once all of them carry an identifier.
        serialized_legacy = []
        for legacy in legacy_entities:
            if not legacy.has_registered:
                serialized_legacy.append(legacy.serialize())
            else:
                _logging.info(
                    f"Skipping entity {legacy.id} because already registered")

        click.echo(
            f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows"
        )

        mode = mode or SerializationMode.DEFAULT

        registry = OrderedDict()
        # TODO: Clean up the copy() - get_default_launch_plan may create a LaunchPlan object, which gets added to
        #  the very FlyteEntities.entities list being iterated over here.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Entities are always added by their constructor, so weird things can
            #  happen. A workflow created inside a workflow probably should not be registered - or should it?
            #  Inner tasks certainly should not be, since there is no way to reach them, but generated workflows
            #  might be fine. Also a user may import dir_b.workflows from dir_a.workflows while the workflow
            #  packages only specify dir_a.
            if not isinstance(entity, (PythonTask, WorkflowBase, LaunchPlan)):
                continue

            # Only PythonTasks are sensitive to the serialization mode.
            if not isinstance(
                    entity, PythonTask) or mode == SerializationMode.DEFAULT:
                get_serializable(registry, ctx.serialization_settings, entity)
            elif mode == SerializationMode.FAST:
                get_serializable(registry,
                                 ctx.serialization_settings,
                                 entity,
                                 fast=True)
            else:
                raise AssertionError(
                    f"Unrecognized serialization mode: {mode}")

            # Every workflow additionally gets its default launch plan serialized.
            if isinstance(entity, WorkflowBase):
                default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                get_serializable(registry, ctx.serialization_settings,
                                 default_lp)

        # Filter first, convert second - kept as two separate passes over the registry values.
        registrable = [
            model for model in registry.values()
            if _should_register_with_admin(model)
        ]
        new_api_model_values = [model.to_flyte_idl() for model in registrable]

        loaded_entities = serialized_legacy + new_api_model_values
        index_width = _determine_text_chars(len(loaded_entities))
        for position, entity in enumerate(loaded_entities):
            fname_index = str(position).zfill(index_width)
            # The numeric suffix is chosen per protobuf message type.
            if isinstance(entity, _idl_admin_TaskSpec):
                fname = "{}_{}_1.pb".format(fname_index,
                                            entity.template.id.name)
            elif isinstance(entity, _idl_admin_WorkflowSpec):
                fname = "{}_{}_2.pb".format(fname_index,
                                            entity.template.id.name)
            elif isinstance(entity, _idl_admin_LaunchPlan):
                fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
            else:
                raise Exception(f"Bad format {type(entity)}")
            click.echo(f"  Writing to file: {fname}")
            target = _os.path.join(folder, fname) if folder else fname
            _write_proto_to_file(entity, target)

        click.secho(
            f"Successfully serialized {len(loaded_entities)} flyte objects",
            fg="green")