def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  output_artifacts = {
      IMPORT_RESULT_KEY:
          self._import_artifacts(
              source_uri=exec_properties[SOURCE_URI_KEY],
              destination_channel=output_dict[IMPORT_RESULT_KEY],
              reimport=exec_properties[REIMPORT_OPTION_KEY],
              split_names=exec_properties[SPLIT_KEY])
  }
  output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(
      output_artifacts[IMPORT_RESULT_KEY])
  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=output_artifacts,
      exec_properties={},
      execution_id=self._register_execution(
          exec_properties={},
          pipeline_info=pipeline_info,
          component_info=component_info),
      use_cached_results=False)
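# For reference, a hedged sketch of the exec_properties this importer-style
# driver expects. The *_KEY names are the module-level constants referenced
# above; the concrete values are invented for illustration.
example_exec_properties = {
    SOURCE_URI_KEY: 'gs://my-bucket/external/data',  # made-up source uri
    REIMPORT_OPTION_KEY: False,  # reuse an existing registration if present
    SPLIT_KEY: ['train', 'eval'],  # split names to record on the artifact
}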
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  resolver_class = exec_properties[RESOLVER_CLASS]
  if exec_properties[RESOLVER_CONFIGS]:
    resolver = resolver_class(**exec_properties[RESOLVER_CONFIGS])
  else:
    resolver = resolver_class()
  resolve_result = resolver.resolve(
      metadata_handler=self._metadata_handler,
      source_channels=input_dict.copy())
  if not resolve_result.has_complete_result:
    raise RuntimeError('Cannot resolve all artifacts as needed.')
  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=resolve_result.per_key_resolve_result,
      exec_properties=exec_properties,
      execution_id=self._register_execution(
          exec_properties={},
          pipeline_info=pipeline_info,
          component_info=component_info),
      use_cached_results=True)
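# A minimal sketch (not the library API) of the resolver object the driver
# above expects: resolve() takes the metadata handler plus the source
# channels and returns a result exposing `has_complete_result` and
# `per_key_resolve_result`. The class and namedtuple here are illustrative.
import collections

_ResolveResult = collections.namedtuple(
    '_ResolveResult', ['has_complete_result', 'per_key_resolve_result'])


class _LatestArtifactResolver:
  """Illustrative resolver: picks the newest artifact for each channel."""

  def resolve(self, metadata_handler, source_channels):
    per_key = {}
    for key, source_channel in source_channels.items():
      # Assumes the metadata handler can list artifacts by type name; for
      # this sketch the last-registered artifact counts as "latest".
      candidates = metadata_handler.get_artifacts_by_type(
          source_channel.type_name)
      per_key[key] = candidates[-1:]
    return _ResolveResult(
        has_complete_result=all(v for v in per_key.values()),
        per_key_resolve_result=per_key)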
def test_new_execution(self, mock_metadata_class, mock_driver_class,
                       mock_executor_class, mock_get_logger):
  self._setup_mocks(mock_metadata_class, mock_driver_class,
                    mock_executor_class, mock_get_logger)
  (adapter, input_dict, output_dict, exec_properties,
   driver_args) = self._setup_adapter_and_args()
  self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]
  self.mock_driver.prepare_execution.return_value = (
      data_types.ExecutionDecision(
          input_dict, output_dict, exec_properties, execution_id=12345))
  check_result = adapter.check_cache_and_maybe_prepare_execution(
      'cached_branch', 'uncached_branch', ti=self.mock_task_instance)
  mock_driver_class.assert_called_with(metadata_handler=self.mock_metadata)
  # Note: `called_with` is not a Mock assertion and silently passes; use
  # `assert_called_with` so a mismatch actually fails the test.
  self.mock_driver.prepare_execution.assert_called_with(
      input_dict, output_dict, exec_properties, driver_args)
  self.mock_task_instance.xcom_pull.assert_called_with(
      dag_id='input_one_component_id', key='input_one_key')
  calls = [
      mock.call(
          key='_exec_inputs', value=types.jsonify_tfx_type_dict(input_dict)),
      mock.call(
          key='_exec_outputs',
          value=types.jsonify_tfx_type_dict(output_dict)),
      mock.call(key='_exec_properties', value=json.dumps(exec_properties)),
      mock.call(key='_execution_id', value=12345)
  ]
  self.mock_task_instance.xcom_push.assert_has_calls(calls)
  self.assertEqual(check_result, 'uncached_branch')
def test_cached_execution(self, mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_get_logger):
  self._setup_mocks(mock_metadata_class, mock_driver_class,
                    mock_executor_class, mock_get_logger)
  (adapter, input_dict, output_dict, exec_properties,
   driver_args) = self._setup_adapter_and_args()
  self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]
  self.mock_driver.prepare_execution.return_value = (
      data_types.ExecutionDecision(input_dict, output_dict, exec_properties))
  check_result = adapter.check_cache_and_maybe_prepare_execution(
      'cached_branch', 'uncached_branch', ti=self.mock_task_instance)
  mock_get_logger.assert_called_with(self._logger_config)
  mock_driver_class.assert_called_with(metadata_handler=self.mock_metadata)
  # Note: `called_with` is not a Mock assertion and silently passes; use
  # `assert_called_with` so a mismatch actually fails the test.
  self.mock_driver.prepare_execution.assert_called_with(
      input_dict, output_dict, exec_properties, driver_args)
  self.mock_task_instance.xcom_pull.assert_called_with(
      dag_id='input_one_component_id', key='input_one_key')
  self.mock_task_instance.xcom_push.assert_called_with(
      key='output_one_key', value=self.output_one_json)
  self.assertEqual(check_result, 'cached_branch')
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  input_artifacts = channel_utils.unwrap_channel_dict(input_dict)
  output_artifacts = channel_utils.unwrap_channel_dict(output_dict)
  # Generate missing output artifact URIs.
  for name, artifacts in output_artifacts.items():
    for idx, artifact in enumerate(artifacts):
      if not artifact.uri:
        suffix = str(idx + 1) if idx > 0 else ''
        artifact.uri = os.path.join(
            pipeline_info.pipeline_root,
            'artifacts',
            name + suffix,
            'data',
        )
        fileio.makedirs(os.path.dirname(artifact.uri))
  return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                      exec_properties, 123, False)
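# Illustration only: the URIs the loop above assigns when the key 'examples'
# has two artifacts, assuming pipeline_root='/tmp/pipeline' (made-up path).
import os

pipeline_root = '/tmp/pipeline'
for idx in range(2):
  suffix = str(idx + 1) if idx > 0 else ''
  print(os.path.join(pipeline_root, 'artifacts', 'examples' + suffix, 'data'))
# -> /tmp/pipeline/artifacts/examples/data
# -> /tmp/pipeline/artifacts/examples2/data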
def _default_caching_handling(
    self,
    input_dict: Dict[Text, List[types.TfxArtifact]],
    output_dict: Dict[Text, List[types.TfxArtifact]],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
) -> data_types.ExecutionDecision:
  """Check cache for desired and applicable identical execution."""
  enable_cache = driver_args.enable_cache
  base_output_dir = driver_args.base_output_dir
  worker_name = driver_args.worker_name

  # If caching is enabled, try to get previous execution results and directly
  # use them as output.
  if enable_cache:
    output_result = self._get_output_from_previous_run(
        input_dict, output_dict, exec_properties, driver_args)
    if output_result:
      tf.logging.info('Found cache from previous run.')
      return data_types.ExecutionDecision(
          input_dict=input_dict,
          output_dict=output_result,
          exec_properties=exec_properties)

  # Previous run is not available; prepare a new execution.
  # Register the execution in metadata.
  execution_id = self._metadata_handler.prepare_execution(
      worker_name, exec_properties)
  tf.logging.info('Preparing new execution.')

  # Check that inputs exist and have valid states, and lock them so they are
  # not garbage-collected halfway through the execution.
  self._verify_artifacts(input_dict)

  # Update outputs.
  for name, output_list in output_dict.items():
    for artifact in output_list:
      artifact.uri = self._generate_output_uri(artifact, base_output_dir,
                                               name, execution_id)

  return data_types.ExecutionDecision(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      execution_id=execution_id)
def _run_driver(
    self, input_dict: Dict[Text, List[types.TfxArtifact]],
    output_dict: Dict[Text, List[types.TfxArtifact]],
    exec_properties: Dict[Text, Any]) -> data_types.ExecutionDecision:
  """Prepare inputs, outputs and execution properties for actual execution."""
  # TODO(jyzhao): support driver after go/tfx-oss-artifact-passing.
  tf.logging.info('Run driver for %s', self._name)

  # Return a fake result that makes sure execution_decision.execution_needed
  # is true to always trigger the executor.
  return data_types.ExecutionDecision(input_dict, output_dict,
                                      exec_properties, 1)
def pre_execution(
    self,
    input_dict: Dict[str, types.BaseChannel],
    output_dict: Dict[str, types.Channel],
    exec_properties: Dict[str, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  # Registers contexts and execution.
  contexts = self._metadata_handler.register_pipeline_contexts_if_not_exists(
      pipeline_info)
  execution = self._metadata_handler.register_execution(
      exec_properties=exec_properties,
      pipeline_info=pipeline_info,
      component_info=component_info,
      contexts=contexts)

  # Gets resolved artifacts.
  resolver_class = exec_properties[RESOLVER_STRATEGY_CLASS]
  if exec_properties[RESOLVER_CONFIG]:
    resolver = resolver_class(**exec_properties[RESOLVER_CONFIG])
  else:
    resolver = resolver_class()
  input_artifacts = self._build_input_dict(pipeline_info, input_dict)
  output_artifacts = resolver.resolve_artifacts(
      store=self._metadata_handler.store,
      input_dict=input_artifacts,
  )
  if output_artifacts is None:
    # No inputs are available, but the driver still needs an
    # ExecutionDecision, so use a dummy dict with no artifacts.
    output_artifacts = {key: [] for key in input_artifacts}

  # TODO(b/148828122): This is a temporary workaround for interactive mode.
  for k, c in output_dict.items():
    output_dict[k] = types.Channel(type=c.type).set_artifacts(
        output_artifacts[k])

  # Updates execution to reflect artifact resolution results and marks it
  # as cached.
  self._metadata_handler.update_execution(
      execution=execution,
      component_info=component_info,
      output_artifacts=output_artifacts,
      execution_state=metadata.EXECUTION_STATE_CACHED,
      contexts=contexts)

  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=output_artifacts,
      exec_properties=exec_properties,
      execution_id=execution.id,
      use_cached_results=True)
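# A minimal sketch of the strategy interface assumed above: resolve_artifacts
# receives the MLMD store and the resolved input dict, and returns either a
# filtered dict or None when resolution fails. The class name and filtering
# rule are illustrative, not the library's.
from typing import Any, Dict, List, Optional


class _TrivialStrategy:
  """Illustrative strategy: keep the last artifact per key, or fail."""

  def resolve_artifacts(
      self, store: Any,
      input_dict: Dict[str, List[Any]]) -> Optional[Dict[str, List[Any]]]:
    del store  # Unused in this sketch; kept to match the assumed signature.
    if any(not artifacts for artifacts in input_dict.values()):
      return None  # Signals that no usable inputs were found.
    return {key: artifacts[-1:] for key, artifacts in input_dict.items()}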
def pre_execution(
    self,
    input_dict: Dict[Text, channel.Channel],
    output_dict: Dict[Text, channel.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  input_artifacts = channel.unwrap_channel_dict(input_dict)
  output_artifacts = channel.unwrap_channel_dict(output_dict)
  tf.gfile.MakeDirs(pipeline_info.pipeline_root)
  types.get_single_instance(output_artifacts['output']).uri = os.path.join(
      pipeline_info.pipeline_root, 'output')
  return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                      exec_properties, 123, False)
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  # Registers contexts and execution.
  contexts = self._metadata_handler.register_pipeline_contexts_if_not_exists(
      pipeline_info)
  execution = self._metadata_handler.register_execution(
      exec_properties=exec_properties,
      pipeline_info=pipeline_info,
      component_info=component_info,
      contexts=contexts)

  # Gets resolved artifacts.
  resolver_class = exec_properties[RESOLVER_CLASS]
  if exec_properties[RESOLVER_CONFIGS]:
    resolver = resolver_class(**exec_properties[RESOLVER_CONFIGS])
  else:
    resolver = resolver_class()
  resolve_result = resolver.resolve(
      pipeline_info=pipeline_info,
      metadata_handler=self._metadata_handler,
      source_channels=input_dict.copy())

  # TODO(b/148828122): This is a temporary workaround for interactive mode.
  for k, c in output_dict.items():
    output_dict[k] = types.Channel(
        type=c.type, artifacts=resolve_result.per_key_resolve_result[k])

  # Updates execution to reflect artifact resolution results and marks it
  # as cached.
  self._metadata_handler.update_execution(
      execution=execution,
      component_info=component_info,
      output_artifacts=resolve_result.per_key_resolve_result,
      execution_state=metadata.EXECUTION_STATE_CACHED,
      contexts=contexts)

  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=resolve_result.per_key_resolve_result,
      exec_properties=exec_properties,
      execution_id=execution.id,
      use_cached_results=True)
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  # Registers contexts and execution.
  contexts = self._metadata_handler.register_pipeline_contexts_if_not_exists(
      pipeline_info)
  execution = self._metadata_handler.register_execution(
      exec_properties=exec_properties,
      pipeline_info=pipeline_info,
      component_info=component_info,
      contexts=contexts)

  # Create imported artifacts.
  output_artifacts = {
      IMPORT_RESULT_KEY: [
          self._prepare_artifact(
              uri=exec_properties[SOURCE_URI_KEY],
              properties=exec_properties[PROPERTIES_KEY],
              custom_properties=exec_properties[CUSTOM_PROPERTIES_KEY],
              destination_channel=output_dict[IMPORT_RESULT_KEY],
              reimport=exec_properties[REIMPORT_OPTION_KEY])
      ]
  }

  # Update execution with imported artifacts.
  self._metadata_handler.update_execution(
      execution=execution,
      component_info=component_info,
      output_artifacts=output_artifacts,
      execution_state=metadata.EXECUTION_STATE_CACHED,
      contexts=contexts)

  output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(
      output_artifacts[IMPORT_RESULT_KEY])

  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=output_artifacts,
      exec_properties=exec_properties,
      execution_id=execution.id,
      use_cached_results=False)
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  """Handle pre-execution logic.

  There are four steps:
    1. Fetches input artifacts from metadata and checks whether their uris
       exist.
    2. Registers execution.
    3. Decides whether a new execution is needed.
    4a. If (3), prepare output artifacts.
    4b. If not (3), fetch cached output artifacts.

  Args:
    input_dict: key -> Channel for inputs.
    output_dict: key -> Channel for outputs. Uris of the outputs are not
      assigned.
    exec_properties: Dict of other execution properties.
    driver_args: An instance of data_types.DriverArgs class.
    pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
      related properties including pipeline_name, pipeline_root and run_id.
    component_info: An instance of data_types.ComponentInfo, holding component
      related properties including component_type and component_id.

  Returns:
    data_types.ExecutionDecision object.

  Raises:
    RuntimeError: if any input has an empty uri.
  """
  # Step 1. Fetch inputs from metadata.
  input_artifacts = self.resolve_input_artifacts(input_dict, exec_properties,
                                                 driver_args, pipeline_info)
  self.verify_input_artifacts(artifacts_dict=input_artifacts)
  absl.logging.debug('Resolved input artifacts are: %s', input_artifacts)
  # Step 2. Register execution in metadata.
  execution_id = self._register_execution(
      exec_properties=exec_properties,
      pipeline_info=pipeline_info,
      component_info=component_info)
  output_artifacts = {}
  use_cached_results = False

  if driver_args.enable_cache:
    # TODO(b/136031301): Combine Step 3 and Step 4b after finishing migration
    # to go/tfx-oss-artifact-passing.
    # Step 3. Decide whether a new execution is needed.
    cached_execution_id = self._metadata_handler.previous_execution(
        input_artifacts=input_artifacts,
        exec_properties=exec_properties,
        pipeline_info=pipeline_info,
        component_info=component_info)
    if cached_execution_id:
      absl.logging.debug('Found cached_execution: %s', cached_execution_id)
      # Step 4b. New execution not needed. Fetch cached output artifacts.
      try:
        output_artifacts = self._fetch_cached_artifacts(
            output_dict=output_dict, cached_execution_id=cached_execution_id)
        absl.logging.debug('Cached output artifacts are: %s',
                           output_artifacts)
        use_cached_results = True
      except RuntimeError:
        absl.logging.warning(
            'Error when trying to get cached output artifacts')
        use_cached_results = False
  if not use_cached_results:
    absl.logging.debug('Cached results not found, moving on to new execution')
    # Step 4a. New execution is needed. Prepare output artifacts.
    output_artifacts = self._prepare_output_artifacts(
        output_dict=output_dict,
        execution_id=execution_id,
        pipeline_info=pipeline_info,
        component_info=component_info)
    absl.logging.debug(
        'Output artifacts skeleton for the upcoming execution are: %s',
        output_artifacts)
    exec_properties = self.resolve_exec_properties(exec_properties,
                                                   pipeline_info,
                                                   component_info)
    absl.logging.debug(
        'Execution properties for the upcoming execution are: %s',
        exec_properties)

  return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                      exec_properties, execution_id,
                                      use_cached_results)
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  """Handle pre-execution logic.

  There are four steps:
    1. Fetches input artifacts from metadata and checks whether their uris
       exist.
    2. Registers execution.
    3. Decides whether a new execution is needed.
    4a. If (3), prepare output artifacts.
    4b. If not (3), fetch cached output artifacts.

  Args:
    input_dict: key -> Channel for inputs.
    output_dict: key -> Channel for outputs. Uris of the outputs are not
      assigned.
    exec_properties: Dict of other execution properties.
    driver_args: An instance of data_types.DriverArgs class.
    pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
      related properties including pipeline_name, pipeline_root and run_id.
    component_info: An instance of data_types.ComponentInfo, holding component
      related properties including component_type and component_id.

  Returns:
    data_types.ExecutionDecision object.

  Raises:
    RuntimeError: if any input has an empty uri.
  """
  # Step 1. Fetch inputs from metadata.
  exec_properties = self.resolve_exec_properties(exec_properties,
                                                 pipeline_info,
                                                 component_info)
  input_artifacts = self.resolve_input_artifacts(input_dict, exec_properties,
                                                 driver_args, pipeline_info)
  self.verify_input_artifacts(artifacts_dict=input_artifacts)
  absl.logging.debug('Resolved input artifacts are: %s', input_artifacts)
  # Step 2. Register execution in metadata.
  contexts = self._metadata_handler.register_pipeline_contexts_if_not_exists(
      pipeline_info)
  execution = self._metadata_handler.register_execution(
      input_artifacts=input_artifacts,
      exec_properties=exec_properties,
      pipeline_info=pipeline_info,
      component_info=component_info,
      contexts=contexts)

  use_cached_results = False
  output_artifacts = None

  if driver_args.enable_cache:
    # Step 3. Decide whether a new execution is needed.
    output_artifacts = self._metadata_handler.get_cached_outputs(
        input_artifacts=input_artifacts,
        exec_properties=exec_properties,
        pipeline_info=pipeline_info,
        component_info=component_info)
  if output_artifacts is not None:
    # If the cache should be used, update the execution to reflect that.
    # Note that with this update, the publisher should / will be skipped.
    self._metadata_handler.update_execution(
        execution=execution,
        component_info=component_info,
        output_artifacts=output_artifacts,
        execution_state=metadata.EXECUTION_STATE_CACHED,
        contexts=contexts)
    use_cached_results = True
  else:
    absl.logging.debug('Cached results not found, moving on to new execution')
    # Step 4a. New execution is needed. Prepare output artifacts.
    output_artifacts = self._prepare_output_artifacts(
        input_artifacts=input_artifacts,
        output_dict=output_dict,
        exec_properties=exec_properties,
        execution_id=execution.id,
        pipeline_info=pipeline_info,
        component_info=component_info)
    absl.logging.debug(
        'Output artifacts skeleton for the upcoming execution are: %s',
        output_artifacts)
    # Update the execution to reflect refreshed output artifacts and
    # execution properties.
    self._metadata_handler.update_execution(
        execution=execution,
        component_info=component_info,
        output_artifacts=output_artifacts,
        exec_properties=exec_properties,
        contexts=contexts)
    absl.logging.debug(
        'Execution properties for the upcoming execution are: %s',
        exec_properties)

  return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                      exec_properties, execution.id,
                                      use_cached_results)
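# Hypothetical invocation sketch for the driver above. The constructor
# arguments for DriverArgs, PipelineInfo and ComponentInfo are inferred from
# how they are used in pre_execution and may differ from the real classes;
# `driver`, `input_channels` and `output_channels` are assumed to exist.
driver_args = data_types.DriverArgs(enable_cache=True)
pipeline_info = data_types.PipelineInfo(
    pipeline_name='my_pipeline',  # made-up name
    pipeline_root='/tmp/pipeline_root',  # made-up root
    run_id='run_0')
component_info = data_types.ComponentInfo(
    component_type='a.b.MyComponent',  # made-up component type
    component_id='MyComponent',
    pipeline_info=pipeline_info)
decision = driver.pre_execution(
    input_dict=input_channels,  # key -> Channel, from the component spec
    output_dict=output_channels,  # key -> Channel, uris not yet assigned
    exec_properties={},
    driver_args=driver_args,
    pipeline_info=pipeline_info,
    component_info=component_info)
if decision.use_cached_results:
  print('Skipping executor; outputs come from the cache.')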
def _default_caching_handling(
    self,
    input_dict: Dict[Text, List[types.TfxArtifact]],
    output_dict: Dict[Text, List[types.TfxArtifact]],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
) -> data_types.ExecutionDecision:
  """Check cache for desired and applicable identical execution."""
  enable_cache = driver_args.enable_cache
  base_output_dir = driver_args.base_output_dir
  worker_name = driver_args.worker_name

  # If caching is enabled, try to get previous execution results and directly
  # use them as output.
  if enable_cache:
    output_result = self._get_output_from_previous_run(
        input_dict, output_dict, exec_properties, driver_args)
    if output_result:
      tf.logging.info('Found cache from previous run.')
      return data_types.ExecutionDecision(
          input_dict=input_dict,
          output_dict=output_result,
          exec_properties=exec_properties)

  # Previous run is not available; prepare a new execution.
  # Register the execution in metadata.
  execution_id = self._metadata_handler.prepare_execution(
      worker_name, exec_properties)
  tf.logging.info('Preparing new execution.')

  # Check that inputs exist and have valid states, and lock them so they are
  # not garbage-collected halfway through the execution.
  self._verify_inputs(input_dict)

  # Update outputs.
  max_input_span = 0
  for input_list in input_dict.values():
    for single_input in input_list:
      max_input_span = max(max_input_span, single_input.span)
  # TODO(ruoyu): This location is dangerous because this function is not
  # guaranteed to be called on custom driver.
  for output_name, output_list in output_dict.items():
    for output_artifact in output_list:
      # Updates the output uri based on execution id and optional split.
      # The trailing empty string forces the uri to be a directory.
      output_artifact.uri = os.path.join(base_output_dir, output_name,
                                         str(execution_id),
                                         output_artifact.split, '')
      if tf.gfile.Exists(output_artifact.uri):
        msg = 'Output artifact uri {} already exists'.format(
            output_artifact.uri)
        tf.logging.error(msg)
        raise RuntimeError(msg)
      else:
        # TODO(zhitaoli): Consider refactoring this out into something
        # which can handle permission bits.
        tf.logging.info('Creating output artifact uri %s as directory',
                        output_artifact.uri)
        tf.gfile.MakeDirs(output_artifact.uri)
      # Default the output span to the max of the input spans.
      output_artifact.span = max_input_span

  return data_types.ExecutionDecision(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      execution_id=execution_id)
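# Illustration only: the uri produced by the loop above for made-up values
# base_output_dir='/tmp/out', output_name='examples', execution_id=5 and
# split='train'. The trailing '' makes os.path.join emit a final separator,
# so the uri reads as a directory.
import os

print(os.path.join('/tmp/out', 'examples', '5', 'train', ''))
# -> /tmp/out/examples/5/train/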