def _build_execution_parameters(user_workspace_dir) -> ExecutionParameters:
    """Assemble the user-facing ExecutionParameters: execution id, stats client, logging, tmp dir."""
    from flytekit import __version__ as _api_version

    return ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )


def _build_file_access(cloud_provider, raw_output_data_prefix: str):
    """Return the FileAccessProvider matching the configured cloud provider.

    :raises ValueError: if the cloud provider is not AWS, GCP, or LOCAL.
    """
    if cloud_provider == _constants.CloudProvider.AWS:
        return _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    if cloud_provider == _constants.CloudProvider.GCP:
        return _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    if cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        return _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    # ValueError is a subclass of Exception, so existing broad handlers still catch this.
    raise ValueError(f"Bad cloud provider {cloud_provider}")


def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions.

    Configures logging, creates a scratch working directory for user code, builds the
    execution parameters and the cloud-specific file access provider, then layers the
    file-access, serialization, and execution contexts before dispatching the task.

    :param task_def: the native-typed task to run
    :param inputs: location of the serialized task inputs
    :param output_prefix: location where task outputs should be written
    :param raw_output_data_prefix: remote prefix for offloaded raw output data
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()

    # The internal (platform-set) log level takes precedence over the sdk setting.
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create a scratch directory for user code to use during this execution.
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)

    execution_parameters = _build_execution_parameters(user_workspace_dir)
    file_access = _build_file_access(cloud_provider, raw_output_data_prefix)

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
def execute(self, inputs, context=None):
    """
    Just execute the task and write the outputs to where they belong
    :param flytekit.models.literals.LiteralMap inputs:
    :param dict[Text, Text] context:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    # temp_dir collects the serialized output/error protos to upload; task_dir is the scratch
    # working directory handed to the task. Both are auto-deleted on exit.
    with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
        with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
            with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                with _data_proxy.RemoteDataContext():
                    output_file_dict = dict()

                    # This sets the logging level for user code and is the only place an sdk setting gets
                    # used at runtime.  Optionally, Propeller can set an internal config setting which
                    # takes precedence.
                    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                    _logging.getLogger().setLevel(log_level)
                    try:
                        output_file_dict = self.sdk_task.execute(
                            _common_engine.EngineContext(
                                execution_id=_identifier.WorkflowExecutionIdentifier(
                                    project=_internal_config.EXECUTION_PROJECT.get(),
                                    domain=_internal_config.EXECUTION_DOMAIN.get(),
                                    name=_internal_config.EXECUTION_NAME.get()),
                                # NOTE(review): this calls `_datetime.utcnow()` directly, which implies
                                # `_datetime` is the datetime *class* here (e.g. `from datetime import
                                # datetime as _datetime`) — confirm against this module's imports.
                                execution_date=_datetime.utcnow(),
                                stats=_get_stats(
                                    # Stats metric path will be:
                                    # registration_project.registration_domain.app.module.task_name.user_stats
                                    # and it will be tagged with execution-level values for project/domain/wf/lp
                                    "{}.{}.{}.user_stats".format(
                                        _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                                        _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                                        _internal_config.TASK_NAME.get() or _internal_config.NAME.get()),
                                    tags={
                                        'exec_project': _internal_config.EXECUTION_PROJECT.get(),
                                        'exec_domain': _internal_config.EXECUTION_DOMAIN.get(),
                                        'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(),
                                        'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(),
                                        'api_version': _api_version
                                    }),
                                logging=_logging,
                                tmp_dir=task_dir),
                            inputs)
                    except _exception_scopes.FlyteScopedException as e:
                        # A user-scoped failure: swallow the exception and record a structured error
                        # document so the platform can surface it, instead of crashing the container.
                        _logging.error("!!! Begin Error Captured by Flyte !!!")
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(
                                e.error_code, e.verbose_message, e.kind))
                        _logging.error(e.verbose_message)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    except Exception:
                        # Unknown system failure: record the traceback and mark it RECOVERABLE so the
                        # platform may retry.
                        _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!")
                        exc_str = _traceback.format_exc()
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(
                                "SYSTEM:Unknown", exc_str,
                                _error_models.ContainerError.Kind.RECOVERABLE))
                        _logging.error(exc_str)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    finally:
                        # Always upload whatever outputs/error documents were produced.
                        # NOTE(review): this indexes context['output_prefix'] although context defaults
                        # to None — presumably callers always pass a context dict; confirm.
                        for k, v in _six.iteritems(output_file_dict):
                            _common_utils.write_proto_to_file(
                                v.to_flyte_idl(), _os.path.join(temp_dir.name, k))
                        _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True)
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    Build a fully-configured FlyteContext for a task execution and yield it.

    :param raw_output_data_prefix: remote prefix under which raw output data is offloaded
    :param checkpoint_path: checkpoint destination; checkpointing is enabled when this is not None
    :param prev_checkpoint: source of a previously written checkpoint, if any
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
        task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return: yields the configured FlyteContext
    """
    # Execution- and task-scoped identifiers, each readable from a long-form or short-form env var.
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")
    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()

    # Scratch directory handed to user code for the lifetime of this execution.
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)

    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(project=exe_project, domain=exe_domain, name=exe_name),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:
        # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    execution_state = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    context_builder = ctx.new_builder().with_file_access(file_access).with_execution_state(execution_state)

    if compressed_serialization_settings:
        settings_builder = SerializationSettings.from_transport(compressed_serialization_settings).new_builder()
        settings_builder.project = exe_project
        settings_builder.domain = exe_domain
        settings_builder.version = tk_version
        if dynamic_addl_distro:
            # A dynamic task spawned here must fast-register its children against this distribution.
            settings_builder.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        context_builder = context_builder.with_serialization_settings(settings_builder.build())

    with FlyteContextManager.with_context(context_builder) as ctx:
        yield ctx