示例#1
0
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
示例#2
0
    def execute(self, inputs, context=None):
        """
        Just execute the task and write the outputs to where they belong
        :param flytekit.models.literals.LiteralMap inputs:
        :param dict[Text, Text] context:
        :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
        """

        with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
            with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
                with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                    with _data_proxy.RemoteDataContext():
                        output_file_dict = dict()

                        # This sets the logging level for user code and is the only place an sdk setting gets
                        # used at runtime.  Optionally, Propeller can set an internal config setting which
                        # takes precedence.
                        log_level = _internal_config.LOGGING_LEVEL.get(
                        ) or _sdk_config.LOGGING_LEVEL.get()
                        _logging.getLogger().setLevel(log_level)

                        try:
                            output_file_dict = self.sdk_task.execute(
                                _common_engine.EngineContext(
                                    execution_id=_identifier.
                                    WorkflowExecutionIdentifier(
                                        project=_internal_config.
                                        EXECUTION_PROJECT.get(),
                                        domain=_internal_config.
                                        EXECUTION_DOMAIN.get(),
                                        name=_internal_config.EXECUTION_NAME.
                                        get()),
                                    execution_date=_datetime.utcnow(),
                                    stats=_get_stats(
                                        # Stats metric path will be:
                                        # registration_project.registration_domain.app.module.task_name.user_stats
                                        # and it will be tagged with execution-level values for project/domain/wf/lp
                                        "{}.{}.{}.user_stats".format(
                                            _internal_config.TASK_PROJECT.get(
                                            )
                                            or _internal_config.PROJECT.get(),
                                            _internal_config.TASK_DOMAIN.get()
                                            or _internal_config.DOMAIN.get(),
                                            _internal_config.TASK_NAME.get()
                                            or _internal_config.NAME.get()),
                                        tags={
                                            'exec_project':
                                            _internal_config.EXECUTION_PROJECT.
                                            get(),
                                            'exec_domain':
                                            _internal_config.EXECUTION_DOMAIN.
                                            get(),
                                            'exec_workflow':
                                            _internal_config.
                                            EXECUTION_WORKFLOW.get(),
                                            'exec_launchplan':
                                            _internal_config.
                                            EXECUTION_LAUNCHPLAN.get(),
                                            'api_version':
                                            _api_version
                                        }),
                                    logging=_logging,
                                    tmp_dir=task_dir),
                                inputs)
                        except _exception_scopes.FlyteScopedException as e:
                            _logging.error(
                                "!!! Begin Error Captured by Flyte !!!")
                            output_file_dict[
                                _constants.
                                ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                    _error_models.ContainerError(
                                        e.error_code, e.verbose_message,
                                        e.kind))
                            _logging.error(e.verbose_message)
                            _logging.error(
                                "!!! End Error Captured by Flyte !!!")
                        except Exception:
                            _logging.error(
                                "!!! Begin Unknown System Error Captured by Flyte !!!"
                            )
                            exc_str = _traceback.format_exc()
                            output_file_dict[
                                _constants.
                                ERROR_FILE_NAME] = _error_models.ErrorDocument(
                                    _error_models.ContainerError(
                                        "SYSTEM:Unknown", exc_str,
                                        _error_models.ContainerError.Kind.
                                        RECOVERABLE))
                            _logging.error(exc_str)
                            _logging.error(
                                "!!! End Error Captured by Flyte !!!")
                        finally:
                            for k, v in _six.iteritems(output_file_dict):
                                _common_utils.write_proto_to_file(
                                    v.to_flyte_idl(),
                                    _os.path.join(temp_dir.name, k))
                            _data_proxy.Data.put_data(temp_dir.name,
                                                      context['output_prefix'],
                                                      is_multipart=True)
示例#3
0
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """

    :param raw_output_data_prefix:
    :param checkpoint_path:
    :param prev_checkpoint:
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
      task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return:
    """
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()
    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx