def test_do_with_cache(self):
  # First run that creates cache.
  self._exec_properties['module_file'] = self._module_file
  metrics = self._run_pipeline_get_metrics()

  # The test data has 9909 instances in the train dataset, and 5091 instances
  # in the eval dataset. Since the analysis dataset (train) is read twice when
  # no input cache is present (once for analysis and once for transform), the
  # expected value of the num_instances counter is: 9909 * 2 + 5091 = 24909.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 24909)
  self._verify_transform_outputs(store_cache=True)

  # Second run from cache.
  self._output_data_dir = self._get_output_data_dir('2nd_run')
  analyzer_cache_artifact = standard_artifacts.TransformCache()
  analyzer_cache_artifact.uri = self._updated_analyzer_cache_artifact.uri
  self._make_base_do_params(self._source_data_dir, self._output_data_dir)
  self._input_dict[executor.ANALYZER_CACHE_KEY] = [analyzer_cache_artifact]
  self._exec_properties['module_file'] = self._module_file
  metrics = self._run_pipeline_get_metrics()

  # Since the input cache should now cover all analysis (train) paths, the
  # train and eval sets are each read exactly once for transform. Thus, the
  # expected value of the num_instances counter is: 9909 + 5091 = 15000.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 15000)
  self._verify_transform_outputs(store_cache=True)
def test_construct_with_cache_disabled_but_input_cache(self):
  with self.assertRaises(ValueError):
    _ = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        disable_analyzer_cache=True,
        analyzer_cache=channel_utils.as_channel(
            [standard_artifacts.TransformCache()]))
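# For contrast with the ValueError case above, a sketch of the two valid
# configurations (assuming the same `self.examples`/`self.schema` fixtures):
# either disable caching and pass no cache channel, or leave caching enabled
# and optionally wire in an analyzer_cache channel. Illustrative only.
def _construct_valid_transforms_sketch(self):
  no_cache = component.Transform(
      examples=self.examples,
      schema=self.schema,
      preprocessing_fn='my_preprocessing_fn',
      disable_analyzer_cache=True)
  with_cache = component.Transform(
      examples=self.examples,
      schema=self.schema,
      preprocessing_fn='my_preprocessing_fn',
      analyzer_cache=channel_utils.as_channel(
          [standard_artifacts.TransformCache()]))
  return no_cache, with_cache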
def _make_base_do_params(self, source_data_dir, output_data_dir):
  # Create input dict.
  example1 = standard_artifacts.Examples()
  example1.uri = self._ARTIFACT1_URI
  example1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  example2 = copy.deepcopy(example1)
  example2.uri = self._ARTIFACT2_URI
  self._example_artifacts = [example1, example2]

  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
  }

  # Create output dict.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_graph')
  transformed1 = standard_artifacts.Examples()
  transformed1.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '0')
  transformed2 = standard_artifacts.Examples()
  transformed2.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '1')
  self._transformed_example_artifacts = [transformed1, transformed2]

  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()
  self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache()
  self._updated_analyzer_cache_artifact.uri = os.path.join(
      self._output_data_dir, 'CACHE')

  self._output_dict = {
      standard_component_specs.TRANSFORM_GRAPH_KEY: [
          self._transformed_output
      ],
      standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
          self._transformed_example_artifacts[:1],
      executor.TEMP_PATH_KEY: [temp_path_output],
      standard_component_specs.UPDATED_ANALYZER_CACHE_KEY: [
          self._updated_analyzer_cache_artifact
      ],
  }

  # Create exec properties skeleton.
  self._exec_properties = {}
def test_do_with_cache(self, provide_first_input_cache):
  # First run that creates cache.
  self._exec_properties[
      standard_component_specs.MODULE_FILE_KEY] = self._module_file
  if provide_first_input_cache:
    self._input_dict[standard_component_specs.ANALYZER_CACHE_KEY] = []
  metrics = self._run_pipeline_get_metrics()

  # The test data has 9909 instances in the train dataset, and 5091 instances
  # in the eval dataset. Since the analysis dataset (train) is read twice when
  # no input cache is present (once for analysis and once for transform), the
  # expected value of the num_instances counter is: 9909 * 2 + 5091 = 24909.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                 ['tfx.Transform'])
  # Additionally we have 24909 instances due to generating statistics.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                 ['tfx.DataValidation'])
  self._verify_transform_outputs(store_cache=True)

  # Second run from cache.
  self._output_data_dir = self._get_output_data_dir('2nd_run')
  analyzer_cache_artifact = standard_artifacts.TransformCache()
  analyzer_cache_artifact.uri = self._updated_analyzer_cache_artifact.uri
  self._make_base_do_params(self._SOURCE_DATA_DIR, self._output_data_dir)
  self._input_dict[standard_component_specs.ANALYZER_CACHE_KEY] = [
      analyzer_cache_artifact
  ]
  self._exec_properties[
      standard_component_specs.MODULE_FILE_KEY] = self._module_file
  metrics = self._run_pipeline_get_metrics()

  # Since the input cache should now cover all analysis (train) paths, the
  # train and eval sets are each read exactly once for transform. Thus, the
  # expected value of the num_instances counter is: 9909 + 5091 = 15000.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 15000,
                                 ['tfx.Transform'])
  # Additionally we have 24909 instances due to generating statistics.
  self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                 ['tfx.DataValidation'])
  self._verify_transform_outputs(store_cache=True)
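# `test_do_with_cache` above takes a `provide_first_input_cache` argument, so
# the test is evidently parameterized. A minimal sketch of one way to declare
# that with absl.testing.parameterized (an assumption; the real file may use
# a different decorator or test case names):
from absl.testing import parameterized


class _CacheTestSketch(parameterized.TestCase):

  @parameterized.parameters(False, True)
  def test_do_with_cache(self, provide_first_input_cache):
    # The real body runs the pipeline twice, as shown above; this sketch only
    # demonstrates that the parameter arrives per test case.
    self.assertIn(provide_first_input_cache, (False, True))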
def testGetCachePathEntry(self):
  # Empty case.
  self.assertEmpty(
      executor_utils.GetCachePathEntry(
          standard_component_specs.ANALYZER_CACHE_KEY, {}))

  cache_artifact = standard_artifacts.TransformCache()
  cache_artifact.uri = '/dummy'

  # Input cache entry.
  result = executor_utils.GetCachePathEntry(
      standard_component_specs.ANALYZER_CACHE_KEY,
      {standard_component_specs.ANALYZER_CACHE_KEY: [cache_artifact]})
  self.assertEqual({labels.CACHE_INPUT_PATH_LABEL: '/dummy'}, result)

  # Output cache entry.
  result = executor_utils.GetCachePathEntry(
      standard_component_specs.UPDATED_ANALYZER_CACHE_KEY,
      {standard_component_specs.UPDATED_ANALYZER_CACHE_KEY: [cache_artifact]})
  self.assertEqual({labels.CACHE_OUTPUT_PATH_LABEL: '/dummy'}, result)
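# From the assertions above, GetCachePathEntry evidently maps a cache key and
# an artifact dict to a single-entry {label: uri} dict, and returns {} when
# no cache artifact is present. A minimal re-implementation sketch of that
# assumed contract (illustrative only, not the actual tfx implementation):
def _get_cache_path_entry_sketch(key, artifact_dict):
  artifacts = artifact_dict.get(key, [])
  if not artifacts:
    return {}
  label = (labels.CACHE_INPUT_PATH_LABEL
           if key == standard_component_specs.ANALYZER_CACHE_KEY
           else labels.CACHE_OUTPUT_PATH_LABEL)
  return {label: artifacts[0].uri}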
def _make_base_do_params(self, source_data_dir, output_data_dir):
  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

  self._input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.SCHEMA_KEY: [schema_artifact],
  }

  # Create output dict.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_graph')
  self._transformed_examples = standard_artifacts.Examples()
  self._transformed_examples.uri = os.path.join(output_data_dir,
                                                'transformed_examples')
  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()
  self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache()
  self._updated_analyzer_cache_artifact.uri = os.path.join(
      self._output_data_dir, 'CACHE')

  self._output_dict = {
      executor.TRANSFORM_GRAPH_KEY: [self._transformed_output],
      executor.TRANSFORMED_EXAMPLES_KEY: [self._transformed_examples],
      executor.TEMP_PATH_KEY: [temp_path_output],
      executor.UPDATED_ANALYZER_CACHE_KEY: [
          self._updated_analyzer_cache_artifact
      ],
  }

  # Create exec properties skeleton.
  self._exec_properties = {}
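# The dicts assembled by _make_base_do_params feed the executor's
# Do(input_dict, output_dict, exec_properties) call. A minimal sketch of the
# assumed invocation (the real tests wrap this with metrics collection via
# _run_pipeline_get_metrics):
def _run_executor_sketch(test_case):
  # `test_case` is a test instance on which _make_base_do_params has run.
  transform_executor = executor.Executor()
  transform_executor.Do(test_case._input_dict, test_case._output_dict,
                        test_case._exec_properties)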
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    analyzer_cache: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    materialize: bool = True,
    disable_analyzer_cache: bool = False,
    custom_config: Optional[Dict[Text, Any]] = None):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the two splits 'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. Exactly one of
      'module_file' or 'preprocessing_fn' must be supplied. The function
      needs to have the following signature:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      where the values of the input and returned Dict are either tf.Tensor
      or tf.SparseTensor. If additional inputs are needed for
      preprocessing_fn, they can be passed in custom_config:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any],
                           custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
    preprocessing_fn: The path to a python function that implements a
      'preprocessing_fn'. See 'module_file' for the expected signature of
      the function. Exactly one of 'module_file' or 'preprocessing_fn' must
      be supplied.
    transform_graph: Optional output 'TransformPath' channel for the output
      of 'tf.Transform', which includes an exported TensorFlow graph
      suitable for both training and serving.
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    input_data: Backwards compatibility alias for the 'examples' argument.
    analyzer_cache: Optional input 'TransformCache' channel containing
      cached information from previous Transform runs. When provided,
      Transform will try to use the cached calculation if possible.
    instance_name: Optional unique instance name. Necessary iff multiple
      transform components are declared in the same pipeline.
    materialize: If True, write transformed examples as an output. If False,
      `transformed_examples` must not be provided.
    disable_analyzer_cache: If False, Transform will use the input cache if
      provided and write cache output. If True, `analyzer_cache` must not be
      provided.
    custom_config: A dict which contains additional parameters that will be
      passed to preprocessing_fn.

  Raises:
    ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
      is supplied.
  """
  if input_data:
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be "
        'supplied.')

  transform_graph = transform_graph or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])

  if materialize and transformed_examples is None:
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples,
        # TODO(b/161548528): remove the hardcoded artifact.
        artifacts=[standard_artifacts.Examples()],
        matching_channel_name='examples')
  elif not materialize and transformed_examples is not None:
    raise ValueError(
        'Must not specify transformed_examples when materialize is False.')

  if disable_analyzer_cache:
    updated_analyzer_cache = None
    if analyzer_cache:
      raise ValueError(
          '`analyzer_cache` must not be set when `disable_analyzer_cache` '
          'is True.')
  else:
    updated_analyzer_cache = types.Channel(
        type=standard_artifacts.TransformCache,
        artifacts=[standard_artifacts.TransformCache()])

  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples,
      analyzer_cache=analyzer_cache,
      updated_analyzer_cache=updated_analyzer_cache,
      custom_config=json.dumps(custom_config))
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
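# A minimal usage sketch for the constructor above, wiring Transform into a
# pipeline downstream of ExampleGen and SchemaGen. The upstream components
# and the module file path are illustrative assumptions:
def _build_transform_sketch(example_gen, schema_gen):
  # `example_gen` and `schema_gen` are assumed upstream components whose
  # outputs provide the 'examples' and 'schema' channels.
  return Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file='path/to/preprocessing.py')  # hypothetical module path


def _build_uncached_transform_sketch(example_gen, schema_gen):
  # Variant with analyzer caching disabled; per the constructor contract, no
  # `analyzer_cache` may be passed in this configuration. The instance_name
  # keeps it distinct if both Transforms live in one pipeline.
  return Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file='path/to/preprocessing.py',  # hypothetical module path
      disable_analyzer_cache=True,
      instance_name='uncached')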