def test_get_node_execution_outputs(mock_client_factory, execution_data_locations):
    mock_client = MagicMock()
    mock_client.get_node_execution_data = MagicMock(
        return_value=_execution_models.NodeExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]
        )
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier(
                "project",
                "domain",
                "name",
            ),
        )
    )
    # The call under test fetches *outputs*, so name the result accordingly.
    outputs = engine.FlyteNodeExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals['b'].scalar.primitive.integer == 2
    mock_client.get_node_execution_data.assert_called_once_with(
        identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier(
                "project",
                "domain",
                "name",
            ),
        )
    )

def test_get_task_execution_inputs(mock_client_factory, execution_data_locations):
    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1]
        )
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.TaskExecutionIdentifier(
            identifier.Identifier(identifier.ResourceType.TASK, 'project', 'domain', 'task-name', 'version'),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                ),
            ),
            0,
        )
    )
    inputs = engine.FlyteTaskExecution(m).get_inputs()
    assert len(inputs.literals) == 1
    assert inputs.literals['a'].scalar.primitive.integer == 1
    mock_client.get_task_execution_data.assert_called_once_with(
        identifier.TaskExecutionIdentifier(
            identifier.Identifier(identifier.ResourceType.TASK, 'project', 'domain', 'task-name', 'version'),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                ),
            ),
            0,
        )
    )

def test_get_full_node_execution_outputs(mock_client_factory):
    mock_client = MagicMock()
    mock_client.get_node_execution_data = MagicMock(
        return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier(
                "project",
                "domain",
                "name",
            ),
        )
    )
    outputs = engine.FlyteNodeExecution(m).get_outputs()
    assert len(outputs.literals) == 1
    assert outputs.literals["b"].scalar.primitive.integer == 2
    mock_client.get_node_execution_data.assert_called_once_with(
        identifier.NodeExecutionIdentifier(
            "node-a",
            identifier.WorkflowExecutionIdentifier(
                "project",
                "domain",
                "name",
            ),
        )
    )

def test_get_full_task_execution_inputs(mock_client_factory):
    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP)
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.TaskExecutionIdentifier(
            identifier.Identifier(
                identifier.ResourceType.TASK,
                "project",
                "domain",
                "task-name",
                "version",
            ),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                ),
            ),
            0,
        )
    )
    inputs = engine.FlyteTaskExecution(m).get_inputs()
    assert len(inputs.literals) == 1
    assert inputs.literals["a"].scalar.primitive.integer == 1
    mock_client.get_task_execution_data.assert_called_once_with(
        identifier.TaskExecutionIdentifier(
            identifier.Identifier(
                identifier.ResourceType.TASK,
                "project",
                "domain",
                "task-name",
                "version",
            ),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                ),
            ),
            0,
        )
    )

def test_execution_notification_soft_overrides(mock_client_factory):
    mock_client = MagicMock()
    mock_client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )
    )
    notification = _common_models.Notification(
        [0, 1, 2], email=_common_models.EmailNotification(["*****@*****.**"])
    )
    engine.FlyteLaunchPlan(m).execute(
        'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[notification]
    )
    mock_client.create_execution.assert_called_once_with(
        'xp',
        'xd',
        'xn',
        _execution_models.ExecutionSpec(
            identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"),
            literals.LiteralMap({}),
            _execution_models.ExecutionMetadata(
                _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, 'sdk', 0
            ),
            notifications=_execution_models.NotificationList([notification]),
        ),
    )

def initialize():
    """
    Re-initializes the context and erases the entire context
    """
    # This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally
    default_execution_id = _identifier.WorkflowExecutionIdentifier(project="local", domain="local", name="local")
    # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
    # are already acquainted with
    default_user_space_params = ExecutionParameters(
        execution_id=str(_SdkWorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
        execution_date=_datetime.datetime.utcnow(),
        stats=_mock_stats.MockStats(),
        logging=_logging,
        tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
    )
    default_context = FlyteContext(file_access=_data_proxy.default_local_file_access_provider)
    default_context = default_context.with_execution_state(
        default_context.new_execution_state().with_params(user_space_params=default_user_space_params)
    ).build()
    default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe())
    FlyteContextManager._OBJS = [default_context]

def launch(
    self,
    project,
    domain,
    name=None,
    inputs=None,
    notification_overrides=None,
    label_overrides=None,
    annotation_overrides=None,
    auth_role=None,
):
    """
    Executes the task as a single task execution and returns the resulting execution.
    :param Text project:
    :param Text domain:
    :param Text name:
    :param flytekit.models.literals.LiteralMap inputs: The inputs to pass
    :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the
        notifications.
    :param flytekit.models.common.Labels label_overrides:
    :param flytekit.models.common.Annotations annotation_overrides:
    :param flytekit.models.common.AuthRole auth_role:
    :rtype: flytekit.models.execution.Execution
    """
    # An empty list (as opposed to None) means the caller explicitly asked to disable all notifications.
    disable_all = notification_overrides == []
    if disable_all:
        notification_overrides = None
    else:
        notification_overrides = _execution_models.NotificationList(notification_overrides or [])
        disable_all = None

    if not auth_role:
        assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get()
        kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get()
        if not (assumable_iam_role or kubernetes_service_account):
            _logging.warning(
                "Using deprecated `role` from config. "
                "Please update your config to use `assumable_iam_role` instead"
            )
            assumable_iam_role = _sdk_config.ROLE.get()
        auth_role = _common_models.AuthRole(
            assumable_iam_role=assumable_iam_role,
            kubernetes_service_account=kubernetes_service_account,
        )

    try:
        # TODO(katrogan): Add handling to register the underlying task if it's not already.
        client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client
        exec_id = client.create_execution(
            project,
            domain,
            name,
            _execution_models.ExecutionSpec(
                self.sdk_task.id,
                _execution_models.ExecutionMetadata(
                    _execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
                    'sdk',  # TODO: get principal
                    0,  # TODO: Detect nesting
                ),
                notifications=notification_overrides,
                disable_all=disable_all,
                labels=label_overrides,
                annotations=annotation_overrides,
                auth_role=auth_role,
            ),
            inputs,
        )
    except _user_exceptions.FlyteEntityAlreadyExistsException:
        exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name)
    return client.get_execution(exec_id)

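# A minimal usage sketch for the method above (hedged: the `fetch` call and the
# project/domain/version values are illustrative placeholders, not fixtures from this module).
# Passing an empty list for `notification_overrides` explicitly disables all notifications,
# while passing None keeps the defaults:
#
#   task = SdkTask.fetch("flytesnacks", "development", "my_task", "v1")  # hypothetical fetch
#   execution = task.launch(
#       project="flytesnacks",
#       domain="development",
#       inputs=literals.LiteralMap({}),
#       notification_overrides=[],  # disable all notifications
#   )
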
@classmethod
def from_python_std(cls, string):
    """
    Parses a string in the correct format into an identifier
    :param Text string:
    :rtype: TaskExecutionIdentifier
    """
    segments = string.split(":")
    if len(segments) != 10:
        raise _user_exceptions.FlyteValueException(
            string,
            "The provided string was not in a parseable format. The string for an identifier must be in the format"
            " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.",
        )

    resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments

    if resource_type != "te":
        raise _user_exceptions.FlyteValueException(
            resource_type,
            "The provided string could not be parsed. The first element of a task execution identifier must be"
            " 'te'.",
        )

    return cls(
        task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv),
        node_execution_id=_core_identifier.NodeExecutionIdentifier(
            node_id=node_id,
            execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en),
        ),
        retry_attempt=int(retry),
    )

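# An illustrative round-trip through the parser above (all segment values are placeholders;
# `TaskExecutionIdentifier` is the class defining `from_python_std`):
#
#   tid = TaskExecutionIdentifier.from_python_std(
#       "te:my_proj:dev:my_exec:n0:my_proj:dev:my_task:v1:0"
#   )
#   assert tid.retry_attempt == 0
#   assert tid.node_execution_id.node_id == "n0"
#   assert tid.task_id.version == "v1"
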
def test_task_node_metadata():
    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_exec_id = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)
    te_id = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3)
    ds_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "t1", "abcdef")
    tag = catalog.CatalogArtifactTag("my-artifact-id", "some name")
    catalog_metadata = catalog.CatalogMetadata(dataset_id=ds_id, artifact_tag=tag, source_task_execution=te_id)

    obj = node_execution_models.TaskNodeMetadata(cache_status=0, catalog_key=catalog_metadata)
    assert obj.cache_status == 0
    assert obj.catalog_key == catalog_metadata

    obj2 = node_execution_models.TaskNodeMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj2 == obj

def test_execution_notification_overrides(mock_client_factory):
    mock_client = MagicMock()
    mock_client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )
    )
    engine.FlyteLaunchPlan(m).execute('xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[])
    mock_client.create_execution.assert_called_once_with(
        'xp',
        'xd',
        'xn',
        _execution_models.ExecutionSpec(
            identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"),
            literals.LiteralMap({}),
            _execution_models.ExecutionMetadata(
                _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, 'sdk', 0
            ),
            disable_all=True,
        ),
    )

def launch(
    self,
    project,
    domain,
    name,
    inputs,
    notification_overrides=None,
    label_overrides=None,
    annotation_overrides=None,
):
    """
    Creates a workflow execution using parameters specified in the launch plan.
    :param Text project:
    :param Text domain:
    :param Text name:
    :param flytekit.models.literals.LiteralMap inputs:
    :param list[flytekit.models.common.Notification] notification_overrides: If specified, override the
        notifications.
    :param flytekit.models.common.Labels label_overrides:
    :param flytekit.models.common.Annotations annotation_overrides:
    :rtype: flytekit.models.execution.Execution
    """
    # An empty list (as opposed to None) means the caller explicitly asked to disable all notifications.
    disable_all = notification_overrides == []
    if disable_all:
        notification_overrides = None
    else:
        notification_overrides = _execution_models.NotificationList(notification_overrides or [])
        disable_all = None

    try:
        client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client
        exec_id = client.create_execution(
            project,
            domain,
            name,
            _execution_models.ExecutionSpec(
                self.sdk_launch_plan.id,
                _execution_models.ExecutionMetadata(
                    _execution_models.ExecutionMetadata.ExecutionMode.MANUAL,
                    "sdk",  # TODO: get principal
                    0,  # TODO: Detect nesting
                ),
                notifications=notification_overrides,
                disable_all=disable_all,
                labels=label_overrides,
                annotations=annotation_overrides,
            ),
            inputs,
        )
    except _user_exceptions.FlyteEntityAlreadyExistsException:
        exec_id = _identifier.WorkflowExecutionIdentifier(project, domain, name)
    return client.get_execution(exec_id)

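# Hedged usage sketch for the launch-plan variant (the `fetch` call and all values are
# illustrative placeholders); `notification_overrides` behaves exactly as in the task
# `launch` above:
#
#   lp = SdkLaunchPlan.fetch("flytesnacks", "development", "my_lp", "v1")  # hypothetical fetch
#   execution = lp.launch("flytesnacks", "development", "my-exec-name", literals.LiteralMap({}))
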
def test_workflow_execution_identifier():
    identifier = _identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    assert identifier == _identifier.WorkflowExecutionIdentifier.from_urn("ex:project:domain:name")
    assert identifier == _identifier.WorkflowExecutionIdentifier.promote_from_model(
        _core_identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    )
    assert identifier.__str__() == "ex:project:domain:name"

def test_node_execution_identifier():
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    obj = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)
    assert obj.node_id == "node_id"
    assert obj.execution_id == wf_exec_id

    obj2 = identifier.NodeExecutionIdentifier.from_flyte_idl(obj.to_flyte_idl())
    assert obj2 == obj
    assert obj2.node_id == "node_id"
    assert obj2.execution_id == wf_exec_id

def test_workflow_node_metadata():
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    obj = node_execution_models.WorkflowNodeMetadata(execution_id=wf_exec_id)
    assert obj.execution_id is wf_exec_id

    obj2 = node_execution_models.WorkflowNodeMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2

def test_workflow_execution_identifier():
    obj = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    assert obj.project == "project"
    assert obj.domain == "domain"
    assert obj.name == "name"

    obj2 = identifier.WorkflowExecutionIdentifier.from_flyte_idl(obj.to_flyte_idl())
    assert obj2 == obj
    assert obj2.project == "project"
    assert obj2.domain == "domain"
    assert obj2.name == "name"

def test_get_execution_inputs(mock_client_factory, execution_data_locations):
    mock_client = MagicMock()
    mock_client.get_execution_data = MagicMock(
        return_value=_execution_models.WorkflowExecutionGetDataResponse(
            execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP
        )
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.WorkflowExecutionIdentifier(
            "project",
            "domain",
            "name",
        )
    )
    inputs = engine.FlyteWorkflowExecution(m).get_inputs()
    assert len(inputs.literals) == 1
    assert inputs.literals["a"].scalar.primitive.integer == 1
    mock_client.get_execution_data.assert_called_once_with(
        identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    )

def test_exec_params():
    ep = ExecutionParameters(
        execution_id=id_models.WorkflowExecutionIdentifier("p", "d", "n"),
        task_id=id_models.Identifier(id_models.ResourceType.TASK, "local", "local", "local", "local"),
        execution_date=datetime.utcnow(),
        stats=mock_stats.MockStats(),
        logging=None,
        tmp_dir="/tmp",
        raw_output_prefix="",
        decks=[],
    )

    assert ep.task_id.name == "local"

def test_task_execution_identifier():
    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_exec_id = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)
    obj = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3)
    assert obj.retry_attempt == 3
    assert obj.task_id == task_id
    assert obj.node_execution_id == node_exec_id

    obj2 = identifier.TaskExecutionIdentifier.from_flyte_idl(obj.to_flyte_idl())
    assert obj2 == obj
    assert obj2.retry_attempt == 3
    assert obj2.task_id == task_id
    assert obj2.node_execution_id == node_exec_id

def test_task_execution_identifier():
    task_id = _identifier.Identifier(_core_identifier.ResourceType.TASK, "project", "domain", "name", "version")
    node_execution_id = _core_identifier.NodeExecutionIdentifier(
        node_id="n0",
        execution_id=_core_identifier.WorkflowExecutionIdentifier("project", "domain", "name"),
    )
    identifier = _identifier.TaskExecutionIdentifier(
        task_id=task_id,
        node_execution_id=node_execution_id,
        retry_attempt=0,
    )
    assert identifier == _identifier.TaskExecutionIdentifier.from_urn(
        "te:project:domain:name:n0:project:domain:name:version:0"
    )
    assert identifier == _identifier.TaskExecutionIdentifier.promote_from_model(
        _core_identifier.TaskExecutionIdentifier(task_id, node_execution_id, 0)
    )
    assert identifier.__str__() == "te:project:domain:name:n0:project:domain:name:version:0"

def test_execution_annotation_overrides(mock_client_factory):
    mock_client = MagicMock()
    mock_client.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")
    )
    mock_client_factory.return_value = mock_client
    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )
    )
    annotations = _common_models.Annotations({"my": "annotation"})
    engine.FlyteLaunchPlan(m).launch(
        "xp",
        "xd",
        "xn",
        literals.LiteralMap({}),
        notification_overrides=[],
        annotation_overrides=annotations,
    )
    mock_client.create_execution.assert_called_once_with(
        "xp",
        "xd",
        "xn",
        _execution_models.ExecutionSpec(
            identifier.Identifier(
                identifier.ResourceType.LAUNCH_PLAN,
                "project",
                "domain",
                "name",
                "version",
            ),
            _execution_models.ExecutionMetadata(
                _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0
            ),
            disable_all=True,
            annotations=annotations,
        ),
        literals.LiteralMap({}),
    )

def initialize():
    """
    Re-initializes the context and erases the entire context
    """
    # This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally
    default_execution_id = _identifier.WorkflowExecutionIdentifier(project="local", domain="local", name="local")

    cfg = Config.auto()
    # Ensure a local directory is available for users to work with.
    user_space_path = os.path.join(cfg.local_sandbox_path, "user_space")
    pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True)

    # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
    # are already acquainted with
    default_context = FlyteContext(file_access=default_local_file_access_provider)
    default_user_space_params = ExecutionParameters(
        execution_id=WorkflowExecutionIdentifier.promote_from_model(default_execution_id),
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"),
        execution_date=_datetime.datetime.utcnow(),
        stats=mock_stats.MockStats(),
        logging=user_space_logger,
        tmp_dir=user_space_path,
        raw_output_prefix=default_context.file_access._raw_output_prefix,
        decks=[],
    )

    default_context = default_context.with_execution_state(
        default_context.new_execution_state().with_params(user_space_params=default_user_space_params)
    ).build()
    default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe())
    flyte_context_Var.set([default_context])

def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    :param raw_output_data_prefix:
    :param checkpoint_path:
    :param prev_checkpoint:
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
        task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return:
    """
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx

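# Since `setup_execution` yields, it is presumably wrapped as a context manager (e.g. with
# @contextlib.contextmanager) at its definition site. A minimal usage sketch, assuming the
# FLYTE_INTERNAL_* environment variables are set and using a placeholder output prefix:
#
#   with setup_execution(raw_output_data_prefix="s3://my-bucket/raw") as ctx:
#       # ctx now carries TASK_EXECUTION state and the user-space ExecutionParameters
#       _dispatch_execute(ctx, task_def, inputs, output_prefix)
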
def execute(self, inputs, context=None):
    """
    Just execute the task and write the outputs to where they belong
    :param flytekit.models.literals.LiteralMap inputs:
    :param dict[Text, Text] context:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
        with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
            with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                with _data_proxy.RemoteDataContext():
                    output_file_dict = dict()

                    # This sets the logging level for user code and is the only place an sdk setting gets
                    # used at runtime. Optionally, Propeller can set an internal config setting which
                    # takes precedence.
                    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                    _logging.getLogger().setLevel(log_level)

                    try:
                        output_file_dict = self.sdk_task.execute(
                            _common_engine.EngineContext(
                                execution_id=_identifier.WorkflowExecutionIdentifier(
                                    project=_internal_config.EXECUTION_PROJECT.get(),
                                    domain=_internal_config.EXECUTION_DOMAIN.get(),
                                    name=_internal_config.EXECUTION_NAME.get(),
                                ),
                                execution_date=_datetime.utcnow(),
                                stats=_get_stats(
                                    # Stats metric path will be:
                                    # registration_project.registration_domain.app.module.task_name.user_stats
                                    # and it will be tagged with execution-level values for project/domain/wf/lp
                                    "{}.{}.{}.user_stats".format(
                                        _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                                        _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                                        _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
                                    ),
                                    tags={
                                        'exec_project': _internal_config.EXECUTION_PROJECT.get(),
                                        'exec_domain': _internal_config.EXECUTION_DOMAIN.get(),
                                        'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(),
                                        'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(),
                                        'api_version': _api_version,
                                    },
                                ),
                                logging=_logging,
                                tmp_dir=task_dir,
                            ),
                            inputs,
                        )
                    except _exception_scopes.FlyteScopedException as e:
                        _logging.error("!!! Begin Error Captured by Flyte !!!")
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(e.error_code, e.verbose_message, e.kind)
                        )
                        _logging.error(e.verbose_message)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    except Exception:
                        _logging.error("!!! Begin Unknown System Error Captured by Flyte !!!")
                        exc_str = _traceback.format_exc()
                        output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
                            _error_models.ContainerError(
                                "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE
                            )
                        )
                        _logging.error(exc_str)
                        _logging.error("!!! End Error Captured by Flyte !!!")
                    finally:
                        for k, v in _six.iteritems(output_file_dict):
                            _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k))
                        _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True)

    if self._flyte_client is not None:
        return self._flyte_client
    elif self._parent is not None:
        return self._parent.flyte_client
    else:
        raise Exception("No flyte_client initialized")


# Hack... we'll think of something better in the future
class FlyteEntities(object):
    entities = []


# This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally
default_execution_id = _identifier.WorkflowExecutionIdentifier(project="local", domain="local", name="local")

# Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
# are already acquainted with
default_user_space_params = ExecutionParameters(
    execution_id=str(_SdkWorkflowExecutionIdentifier.promote_from_model(default_execution_id)),
    execution_date=_datetime.datetime.utcnow(),
    stats=_mock_stats.MockStats(),
    logging=_logging,
    tmp_dir=os.path.join(_sdk_config.LOCAL_SANDBOX.get(), "user_space"),
)
default_context = FlyteContext(
    user_space_params=default_user_space_params,
    file_access=_data_proxy.default_local_file_access_provider,
)

def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)