def test_no_data_needed(self):
  span_0_key = 'span-0'
  span_1_key = 'span-1'

  def preprocessing_fn(inputs):
    return {k: tf.identity(v) for k, v in six.iteritems(inputs)}

  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'x': tf.io.FixedLenFeature([], tf.float32),
      }))
  input_data_dict = {
      span_0_key: None,
      span_1_key: None,
  }

  with _TestPipeline() as p:
    flat_data = None
    cache_dict = {
        span_0_key: {},
        span_1_key: {},
    }

    _, output_cache = (
        (flat_data, input_data_dict, cache_dict, input_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(
            preprocessing_fn, pipeline=p))
    self.assertFalse(output_cache)

def _RunBeamImpl(self, inputs: Mapping[Text, Any],
                 outputs: Mapping[Text, Any], preprocessing_fn: Any,
                 input_dataset_metadata: dataset_metadata.DatasetMetadata,
                 raw_examples_data_format: Text, transform_output_path: Text,
                 compute_statistics: bool,
                 materialize_output_paths: Sequence[Text]) -> _Status:
  """Performs data preprocessing with the tf.Transform Beam pipeline.

  Args:
    inputs: A dictionary of labelled input values.
    outputs: A dictionary of labelled output values.
    preprocessing_fn: The tf.Transform preprocessing_fn.
    input_dataset_metadata: A DatasetMetadata object for the input data.
    raw_examples_data_format: A string describing the raw data format.
    transform_output_path: An absolute path to write the output to.
    compute_statistics: A bool indicating whether or not to compute
      statistics.
    materialize_output_paths: Paths to materialized outputs.

  Raises:
    RuntimeError: If reset() is not invoked between two consecutive run()
      calls.
    ValueError: If the schema is empty.

  Returns:
    Status of the execution.
  """
  raw_examples_file_format = common.GetSoleValue(
      inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
  analyze_and_transform_data_paths = common.GetValues(
      inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
  transform_only_data_paths = common.GetValues(
      inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
  stats_use_tfdv = common.GetSoleValue(inputs,
                                       labels.TFT_STATISTICS_USE_TFDV_LABEL)
  per_set_stats_output_paths = common.GetValues(
      outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
  temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)
  input_cache_dir = common.GetSoleValue(
      inputs, labels.CACHE_INPUT_PATH_LABEL, strict=False)
  output_cache_dir = common.GetSoleValue(
      outputs, labels.CACHE_OUTPUT_PATH_LABEL, strict=False)

  tf.logging.info('Analyze and transform data patterns: %s',
                  list(enumerate(analyze_and_transform_data_paths)))
  tf.logging.info('Transform data patterns: %s',
                  list(enumerate(transform_only_data_paths)))
  tf.logging.info('Transform materialization output paths: %s',
                  list(enumerate(materialize_output_paths)))
  tf.logging.info('Transform output path: %s', transform_output_path)

  feature_spec = schema_utils.schema_as_feature_spec(
      _GetSchemaProto(input_dataset_metadata)).feature_spec
  try:
    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn, feature_spec)
    transform_input_columns = (
        tft.get_transform_input_columns(preprocessing_fn, feature_spec))
  except AttributeError:
    # If using TFT 1.12, fall back to assuming all features are used.
    analyze_input_columns = feature_spec.keys()
    transform_input_columns = feature_spec.keys()
  # Use the same dataset (same columns) for AnalyzeDataset and computing
  # pre-transform stats so that the data will only be read once for these
  # two operations.
  if compute_statistics:
    analyze_input_columns = list(
        set(list(analyze_input_columns) + list(transform_input_columns)))
  if input_dataset_metadata.schema is _RAW_EXAMPLE_SCHEMA:
    analyze_input_dataset_metadata = input_dataset_metadata
    transform_input_dataset_metadata = input_dataset_metadata
  else:
    analyze_input_dataset_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in analyze_input_columns}))
    transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in transform_input_columns}))

  can_process_jointly = not bool(per_set_stats_output_paths or
                                 materialize_output_paths or output_cache_dir)
  analyze_data_list = self._MakeDatasetList(analyze_and_transform_data_paths,
                                            raw_examples_file_format,
                                            raw_examples_data_format,
                                            analyze_input_dataset_metadata,
                                            can_process_jointly)
  transform_data_list = self._MakeDatasetList(
      list(analyze_and_transform_data_paths) +
      list(transform_only_data_paths), raw_examples_file_format,
      raw_examples_data_format, transform_input_dataset_metadata,
      can_process_jointly)

  desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

  with self._CreatePipeline(outputs) as p:
    with tft_beam.Context(
        temp_dir=temp_path,
        desired_batch_size=desired_batch_size,
        passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
        use_deep_copy_optimization=True):
      # pylint: disable=expression-not-assigned
      # pylint: disable=no-value-for-parameter

      _ = (
          p | self._IncrementColumnUsageCounter(
              len(feature_spec.keys()), len(analyze_input_columns),
              len(transform_input_columns)))

      (new_analyze_data_dict, input_cache, flat_data_required) = (
          p | self._OptimizeRun(input_cache_dir, output_cache_dir,
                                analyze_data_list, feature_spec,
                                preprocessing_fn, self._GetCacheSource()))
      # Remove datasets that are not needed for materialization; as a
      # result, they are also excluded from the statistics computation and
      # profiling.
      if not materialize_output_paths:
        analyze_data_list = [
            d for d in new_analyze_data_dict.values() if d is not None
        ]

      analyze_decode_fn = (
          self._GetDecodeFunction(raw_examples_data_format,
                                  analyze_input_dataset_metadata.schema))

      for (idx, dataset) in enumerate(analyze_data_list):
        dataset.encoded = (
            p
            | 'ReadAnalysisDataset[{}]'.format(idx) >>
            self._ReadExamples(dataset))
        dataset.decoded = (
            dataset.encoded
            | 'DecodeAnalysisDataset[{}]'.format(idx) >>
            self._DecodeInputs(analyze_decode_fn))

      input_analysis_data = {}
      for key, dataset in six.iteritems(new_analyze_data_dict):
        if dataset is None:
          input_analysis_data[key] = None
        else:
          input_analysis_data[key] = dataset.decoded

      if flat_data_required:
        flat_input_analysis_data = (
            [dataset.decoded for dataset in analyze_data_list]
            | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=p))
      else:
        flat_input_analysis_data = None
      if input_cache:
        tf.logging.info('Analyzing data with cache.')

      transform_fn, cache_output = (
          (flat_input_analysis_data, input_analysis_data, input_cache,
           input_dataset_metadata)
          | 'AnalyzeDataset' >> tft_beam.AnalyzeDatasetWithCache(
              preprocessing_fn, pipeline=p))

      # Write the raw/input metadata.
      (input_dataset_metadata
       | 'WriteMetadata' >> tft_beam.WriteMetadata(
           os.path.join(transform_output_path,
                        tft.TFTransformOutput.RAW_METADATA_DIR), p))

      # WriteTransformFn writes transform_fn and metadata to subdirectories
      # tensorflow_transform.SAVED_MODEL_DIR and
      # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
      (transform_fn
       | 'WriteTransformFn' >>
       tft_beam.WriteTransformFn(transform_output_path))

      if output_cache_dir is not None and cache_output is not None:
        # TODO(b/37788560): Possibly make this part of the beam graph.
        tf.io.gfile.makedirs(output_cache_dir)
        tf.logging.info('Using existing cache in: %s', input_cache_dir)
        if input_cache_dir is not None:
          # Only copy cache that is relevant to this iteration. This is
          # assuming that this pipeline operates on rolling ranges, so those
          # cache entries may also be relevant for future iterations.
          for span_cache_dir in input_analysis_data:
            full_span_cache_dir = os.path.join(input_cache_dir,
                                               span_cache_dir)
            if tf.io.gfile.isdir(full_span_cache_dir):
              self._CopyCache(
                  full_span_cache_dir,
                  os.path.join(output_cache_dir, span_cache_dir))

        (cache_output
         | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
             p, output_cache_dir, sink=self._GetCacheSink()))

      if compute_statistics or materialize_output_paths:
        # Do not compute pre-transform stats if the input format is raw
        # proto, as StatsGen would treat any input as tf.Example.
        if (compute_statistics and
            not self._IsDataFormatProto(raw_examples_data_format)):
          # Aggregated feature stats before transformation.
          pre_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

          schema_proto = _GetSchemaProto(analyze_input_dataset_metadata)
          ([
              dataset.decoded if stats_use_tfdv else dataset.encoded
              for dataset in analyze_data_list
          ]
           | 'FlattenPreTransformAnalysisDatasets' >> beam.Flatten(pipeline=p)
           | 'GenerateAggregatePreTransformAnalysisStats' >>
           self._GenerateStats(
               pre_transform_feature_stats_path,
               schema_proto,
               use_deep_copy_optimization=True,
               use_tfdv=stats_use_tfdv))

        transform_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    transform_input_dataset_metadata.schema))
        # transform_data_list is a superset of analyze_data_list; we pay the
        # cost of reading the same dataset (analyze_data_list) again here to
        # prevent certain Beam runners from doing a large temp
        # materialization.
        for (idx, dataset) in enumerate(transform_data_list):
          dataset.encoded = (
              p
              | 'ReadTransformDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeTransformDataset[{}]'.format(idx) >>
              self._DecodeInputs(transform_decode_fn))
          (dataset.transformed, metadata) = (
              ((dataset.decoded, transform_input_dataset_metadata),
               transform_fn)
              | 'TransformDataset[{}]'.format(idx) >>
              tft_beam.TransformDataset())

          if materialize_output_paths or not stats_use_tfdv:
            dataset.transformed_and_encoded = (
                dataset.transformed
                | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                    self._EncodeAsExamples(), metadata))

        if compute_statistics:
          # Aggregated feature stats after transformation.
          _, metadata = transform_fn
          post_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
          # schema. Currently input dataset schema only contains dtypes,
          # and other metadata is dropped due to roundtrip to tensors.
          transformed_schema_proto = _GetSchemaProto(metadata)

          ([(dataset.transformed
             if stats_use_tfdv else dataset.transformed_and_encoded)
            for dataset in transform_data_list]
           | 'FlattenPostTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePostTransformAnalysisStats' >>
           self._GenerateStats(
               post_transform_feature_stats_path,
               transformed_schema_proto,
               use_tfdv=stats_use_tfdv))

          if per_set_stats_output_paths:
            assert len(transform_data_list) == len(per_set_stats_output_paths)
            # TODO(b/67632871): Remove duplicate stats gen compute that is
            # done both on a flattened view of the data, and on each span
            # below.
            bundles = zip(transform_data_list, per_set_stats_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              if stats_use_tfdv:
                data = dataset.transformed
              else:
                data = dataset.transformed_and_encoded
              (data
               | 'GeneratePostTransformStats[{}]'.format(idx) >>
               self._GenerateStats(
                   output_path,
                   transformed_schema_proto,
                   use_tfdv=stats_use_tfdv))

        if materialize_output_paths:
          assert len(transform_data_list) == len(materialize_output_paths)
          bundles = zip(transform_data_list, materialize_output_paths)
          for (idx, (dataset, output_path)) in enumerate(bundles):
            (dataset.transformed_and_encoded
             | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                 raw_examples_file_format, output_path))

  return _Status.OK()

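# A minimal, self-contained sketch of the cached analyze/write flow that
# _RunBeamImpl wires up above, reduced to the public tft_beam / analyzer_cache
# calls used in the tests in this file. The function name and its parameters
# (`per_span_data_dict`, `input_metadata`, `preprocessing_fn`, `cache_dir`,
# `output_dir`) are illustrative assumptions, not part of the executor; this
# is a sketch of the pattern, not the executor's actual code path.
import itertools
import os
import tempfile

import apache_beam as beam
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.beam import analyzer_cache


def _analyze_with_cache_sketch(per_span_data_dict, input_metadata,
                               preprocessing_fn, cache_dir, output_dir):
  """Runs cached analysis end to end; assumes small in-memory inputs."""
  dataset_keys = list(per_span_data_dict.keys())
  with beam.Pipeline() as p:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      # Read whatever cache already exists for these spans (may be empty).
      input_cache = p | analyzer_cache.ReadAnalysisCacheFromFS(
          cache_dir, dataset_keys)
      # One PCollection per span, plus a flattened view of all spans, which
      # mirrors how the tests below construct their inputs.
      per_span_pcolls = {
          key: p | 'Create[{}]'.format(key) >> beam.Create(data)
          for key, data in per_span_data_dict.items()
      }
      flat_data = p | 'CreateFlatData' >> beam.Create(
          list(itertools.chain(*per_span_data_dict.values())))
      transform_fn, cache_output = (
          (flat_data, per_span_pcolls, input_cache, input_metadata)
          | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
      # Persist the newly computed cache entries and the transform graph.
      _ = cache_output | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
          p, cache_dir)
      _ = transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
          output_dir)
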
def test_non_frequency_vocabulary_merge(self):
  """This test compares vocabularies produced with and without cache."""

  mi_vocab_name = 'mutual_information_vocab'
  adjusted_mi_vocab_name = 'adjusted_mutual_information_vocab'
  weighted_frequency_vocab_name = 'weighted_frequency_vocab'

  def preprocessing_fn(inputs):
    _ = tft.vocabulary(
        inputs['s'],
        labels=inputs['label'],
        store_frequency=True,
        vocab_filename=mi_vocab_name,
        min_diff_from_avg=0.1,
        use_adjusted_mutual_info=False)

    _ = tft.vocabulary(
        inputs['s'],
        labels=inputs['label'],
        store_frequency=True,
        vocab_filename=adjusted_mi_vocab_name,
        min_diff_from_avg=1.0,
        use_adjusted_mutual_info=True)

    _ = tft.vocabulary(
        inputs['s'],
        weights=inputs['weight'],
        store_frequency=True,
        vocab_filename=weighted_frequency_vocab_name,
        use_adjusted_mutual_info=False)
    return inputs

  span_0_key = 'span-0'
  span_1_key = 'span-1'

  input_data = [
      dict(s='a', weight=1, label=1),
      dict(s='a', weight=0.5, label=1),
      dict(s='b', weight=0.75, label=1),
      dict(s='b', weight=1, label=0),
  ]
  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          's': tf.io.FixedLenFeature([], tf.string),
          'label': tf.io.FixedLenFeature([], tf.int64),
          'weight': tf.io.FixedLenFeature([], tf.float32),
      }))
  input_data_dict = {
      span_0_key: input_data,
      span_1_key: input_data,
  }

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(
        list(itertools.chain(*input_data_dict.values())))

    # Wrap each value in input_data_dict as a PCollection.
    input_data_pcoll_dict = {}
    for a, b in six.iteritems(input_data_dict):
      input_data_pcoll_dict[a] = p | a >> beam.Create(b)

    transform_fn_with_cache, output_cache = (
        (flat_data, input_data_pcoll_dict, {}, input_metadata)
        | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
    transform_fn_with_cache_dir = os.path.join(self.base_test_dir,
                                               'transform_fn_with_cache')
    _ = transform_fn_with_cache | tft_beam.WriteTransformFn(
        transform_fn_with_cache_dir)

    expected_accumulators = {
        b'__v0__VocabularyAccumulate[vocabulary]-<GhZ\xac\xb8\xa9\x8c\xce\x1c\xb2-ck\xca\xe8\xec\t%\x8f': [
            b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
            b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
            b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
            b'1.0]]'
        ],
        b'__v0__VocabularyAccumulate[vocabulary_1]-\xa6\xae\nd\xe3\xd1\x9f\xa0\xe2\xb4\x05j\xa5\xfd\x8c\xfaeN\xd1\x1f': [
            b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
            b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
            b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
            b'1.0]]'
        ],
        b"__v0__VocabularyAccumulate[vocabulary_2]-\x97\x1c>\x851\x94'\xdc\xdf\xfd\xcc\x86\xb7\xb8\xe1\xe8*\x89B\t":
            [b'["a", 1.5]', b'["b", 1.75]'],
    }
    spans = [span_0_key, span_1_key]
    self.assertCountEqual(output_cache.keys(), spans)
    for span in spans:
      self.assertCountEqual(output_cache[span].keys(),
                            expected_accumulators.keys())
      for idx, (key, value) in enumerate(
          six.iteritems(expected_accumulators)):
        beam_test_util.assert_that(
            output_cache[span][key],
            beam_test_util.equal_to(value),
            label='AssertCache[{}][{}]'.format(span, idx))

  # 4 from analysis on each of the input spans.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 6)
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 2)

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(input_data * 2)

    transform_fn_no_cache = ((flat_data, input_metadata)
                             | tft_beam.AnalyzeDataset(preprocessing_fn))

    transform_fn_no_cache_dir = os.path.join(self.base_test_dir,
                                             'transform_fn_no_cache')
    _ = transform_fn_no_cache | tft_beam.WriteTransformFn(
        transform_fn_no_cache_dir)

  # 4 from analysis on each of the input spans.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 0)
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 2)

  tft_output_cache = tft.TFTransformOutput(transform_fn_with_cache_dir)
  tft_output_no_cache = tft.TFTransformOutput(transform_fn_no_cache_dir)

  for vocab_filename in (mi_vocab_name, adjusted_mi_vocab_name,
                         weighted_frequency_vocab_name):
    cache_path = tft_output_cache.vocabulary_file_by_name(vocab_filename)
    no_cache_path = tft_output_no_cache.vocabulary_file_by_name(
        vocab_filename)
    with tf.io.gfile.GFile(cache_path, 'rb') as f1, tf.io.gfile.GFile(
        no_cache_path, 'rb') as f2:
      self.assertEqual(
          f1.readlines(), f2.readlines(),
          'vocab with cache != vocab without cache for: {}'.format(
              vocab_filename))

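# The counter assertions in these tests go through a `_get_counter_value`
# helper that is not part of this excerpt. The sketch below is an assumption
# about how such a helper could be written against the Beam metrics API: it
# expects the MetricResults of a finished pipeline (e.g. `result.metrics()`)
# and is illustrative only, not the test harness's actual implementation.
from apache_beam.metrics.metric import MetricsFilter


def _get_counter_value_sketch(metrics, name):
  """Sums the committed values of all Beam counters matching `name`."""
  counters = metrics.query(MetricsFilter().with_name(name))['counters']
  return sum(counter.committed for counter in counters)
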
def test_caching_vocab_for_integer_categorical(self):
  span_0_key = 'span-0'
  span_1_key = 'span-1'

  def preprocessing_fn(inputs):
    return {
        'x_vocab':
            tft.compute_and_apply_vocabulary(
                inputs['x'], frequency_threshold=2)
    }

  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'x': tf.io.FixedLenFeature([], tf.int64),
      }))
  input_data_dict = {
      span_0_key: [{
          'x': -2,
      }, {
          'x': -4,
      }, {
          'x': -1,
      }, {
          'x': 4,
      }],
      span_1_key: [{
          'x': -2,
      }, {
          'x': -1,
      }, {
          'x': 6,
      }, {
          'x': 7,
      }],
  }
  expected_transformed_data = [{
      'x_vocab': 0,
  }, {
      'x_vocab': 1,
  }, {
      'x_vocab': -1,
  }, {
      'x_vocab': -1,
  }]

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(
        list(itertools.chain(*input_data_dict.values())))

    cache_dict = {
        span_0_key: {
            b'__v0__VocabularyAccumulate[compute_and_apply_vocabulary/vocabulary]-\x05e\xfe4\x03H.P\xb5\xcb\xd22\xe3\x16\x15\xf8\xf5\xe38\xd9':
                p | 'CreateB' >> beam.Create(
                    [b'[-2, 2]', b'[-4, 1]', b'[-1, 1]', b'[4, 1]']),
        },
        span_1_key: {},
    }

    transform_fn, cache_output = (
        (flat_data, input_data_dict, cache_dict, input_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))

    dot_string = nodes.get_dot_graph(
        [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
    self.WriteRenderedDotFile(dot_string)

    self.assertNotIn(span_0_key, cache_output)

    _ = cache_output | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
        p, self._cache_dir)

    transformed_dataset = (((input_data_dict[span_1_key], input_metadata),
                            transform_fn)
                           | 'Transform' >> tft_beam.TransformDataset())

    transformed_data, _ = transformed_dataset

    beam_test_util.assert_that(
        transformed_data,
        beam_test_util.equal_to(expected_transformed_data),
        label='first')

  # 4 from analysis since 1 span was completely cached, and 4 from transform.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 1)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 1)
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 2)

def test_single_phase_run_twice(self):
  span_0_key = 'span-0'
  span_1_key = 'span-1'

  def preprocessing_fn(inputs):

    _ = tft.vocabulary(inputs['s'], vocab_filename='vocab1')

    _ = tft.bucketize(inputs['x'], 2, name='bucketize')

    return {
        'x_min':
            tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
        'x_mean':
            tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
        'y_min':
            tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
        'y_mean':
            tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
        's_integerized':
            tft.compute_and_apply_vocabulary(
                inputs['s'],
                labels=inputs['label'],
                use_adjusted_mutual_info=True),
    }

  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'x': tf.io.FixedLenFeature([], tf.float32),
          'y': tf.io.FixedLenFeature([], tf.float32),
          's': tf.io.FixedLenFeature([], tf.string),
          'label': tf.io.FixedLenFeature([], tf.int64),
      }))
  input_data_dict = {
      span_0_key: [{
          'x': -2,
          'y': 1,
          's': 'a',
          'label': 0,
      }, {
          'x': 4,
          'y': -4,
          's': 'a',
          'label': 1,
      }, {
          'x': 5,
          'y': 11,
          's': 'a',
          'label': 1,
      }, {
          'x': 1,
          'y': -4,
          's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'),
          'label': 1,
      }],
      span_1_key: [{
          'x': 12,
          'y': 1,
          's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'),
          'label': 0
      }, {
          'x': 10,
          'y': 1,
          's': 'c',
          'label': 1
      }],
  }
  expected_vocabulary_contents = np.array(
      [b'a', u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), b'c'], dtype=object)

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(
        list(itertools.chain(*input_data_dict.values())))

    # Wrap each value in input_data_dict as a PCollection.
    input_data_pcoll_dict = {}
    for a, b in six.iteritems(input_data_dict):
      input_data_pcoll_dict[a] = p | a >> beam.Create(b)

    transform_fn_1, cache_output = (
        (flat_data, input_data_pcoll_dict, {}, input_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
    _ = (
        cache_output
        | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
            p, self._cache_dir))

    transformed_dataset = (((input_data_pcoll_dict[span_1_key],
                             input_metadata), transform_fn_1)
                           | 'Transform' >> tft_beam.TransformDataset())

    del input_data_pcoll_dict
    transformed_data, unused_transformed_metadata = transformed_dataset

    expected_transformed_data = [
        {
            'x_mean': 5.0,
            'x_min': -2.0,
            'y_mean': 1.0,
            'y_min': -4.0,
            's_integerized': 0,
        },
        {
            'x_mean': 5.0,
            'x_min': -2.0,
            'y_mean': 1.0,
            'y_min': -4.0,
            's_integerized': 2,
        },
    ]
    beam_test_util.assert_that(
        transformed_data,
        beam_test_util.equal_to(expected_transformed_data),
        label='first')

    transform_fn_dir = os.path.join(self.base_test_dir, 'transform_fn_1')
    _ = transform_fn_1 | tft_beam.WriteTransformFn(transform_fn_dir)

  for key in input_data_dict:
    self.assertIn(key, cache_output)
    self.assertEqual(7, len(cache_output[key]))

  tf_transform_output = tft.TFTransformOutput(transform_fn_dir)
  vocab1_path = tf_transform_output.vocabulary_file_by_name('vocab1')
  self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents)

  # 6 from analyzing the 2 spans, and 2 from transform.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 14)
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 2)

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(
        list(itertools.chain(*input_data_dict.values())))

    # Wrap each value in input_data_dict as a PCollection.
    input_data_pcoll_dict = {}
    for a, b in six.iteritems(input_data_dict):
      input_data_pcoll_dict[a] = p | a >> beam.Create(b)

    input_cache = p | analyzer_cache.ReadAnalysisCacheFromFS(
        self._cache_dir, list(input_data_dict.keys()))

    transform_fn_2, second_output_cache = (
        (flat_data, input_data_pcoll_dict, input_cache, input_metadata)
        | 'AnalyzeAgain' >>
        (tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)))
    _ = (
        second_output_cache
        | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
            p, self._cache_dir))

    dot_string = nodes.get_dot_graph(
        [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
    self.WriteRenderedDotFile(dot_string)

    transformed_dataset = (((input_data_dict[span_1_key], input_metadata),
                            transform_fn_2)
                           | 'TransformAgain' >> tft_beam.TransformDataset())
    transformed_data, unused_transformed_metadata = transformed_dataset
    beam_test_util.assert_that(
        transformed_data,
        beam_test_util.equal_to(expected_transformed_data),
        label='second')

    transform_fn_dir = os.path.join(self.base_test_dir, 'transform_fn_2')
    _ = transform_fn_2 | tft_beam.WriteTransformFn(transform_fn_dir)

  tf_transform_output = tft.TFTransformOutput(transform_fn_dir)
  vocab1_path = tf_transform_output.vocabulary_file_by_name('vocab1')
  self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents)

  self.assertFalse(second_output_cache)

  # Only 2 from transform.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 2)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 14)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 0)

  # The root CreateSavedModel is optimized away because the data doesn't get
  # processed at all (only cache).
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 1)

def test_single_phase_mixed_analyzer_run_once(self):
  span_0_key = 'span-0'
  span_1_key = 'span-1'

  def preprocessing_fn(inputs):

    integerized_s = tft.compute_and_apply_vocabulary(inputs['s'])

    _ = tft.bucketize(inputs['x'], 2, name='bucketize')

    return {
        'integerized_s':
            integerized_s,
        'x_min':
            tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
        'x_mean':
            tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']),
        'y_min':
            tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
        'y_mean':
            tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']),
    }

  # Run AnalyzeAndTransform on some input data and compare with expected
  # output.
  input_data = [{'x': 12, 'y': 1, 's': 'd'}, {'x': 10, 'y': 1, 's': 'c'}]
  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'x': tf.io.FixedLenFeature([], tf.float32),
          'y': tf.io.FixedLenFeature([], tf.float32),
          's': tf.io.FixedLenFeature([], tf.string),
      }))
  input_data_dict = {
      span_0_key: [{
          'x': -2,
          'y': 1,
          's': 'b',
      }, {
          'x': 4,
          'y': -4,
          's': 'b',
      }],
      span_1_key: input_data,
  }

  with _TestPipeline() as p:
    flat_data = p | 'CreateInputData' >> beam.Create(
        list(itertools.chain(*input_data_dict.values())))
    cache_dict = {
        span_0_key: {
            b'__v0__CacheableCombineAccumulate[x_1/mean_and_var]-.\xc4t>ZBv\xea\xa5SU\xf4\x065\xc6\x1c\x81W\xf9\x1b':
                p | 'CreateA' >> beam.Create([b'[2.0, 1.0, 9.0, 0.0]']),
            b'__v0__CacheableCombineAccumulate[x/x]-\x95\xc5w\x88\x85\x8b5V\xc9\x00\xe0\x0f\x03\x1a\xdaL\x9d\xd5\xb3\xe3':
                p | 'CreateB' >> beam.Create([b'[2.0, 4.0]']),
            b'__v0__CacheableCombineAccumulate[y_1/mean_and_var]-E^\xb7VZ\xeew4rm\xab\xa3\xa4k|J\x80ck\x16':
                p | 'CreateC' >> beam.Create([b'[2.0, -1.5, 6.25, 0.0]']),
            b'__v0__CacheableCombineAccumulate[y/y]-\xdf\x1ey\x03\x1c\x96\xd5'
            b' e\x9bJ\xa1\xd2\xfc\x9c\x03\x0fM \xdb':
                p | 'CreateD' >> beam.Create([b'[4.0, 1.0]']),
        },
        span_1_key: {},
    }

    transform_fn, cache_output = (
        (flat_data, input_data_dict, cache_dict, input_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
    _ = (
        cache_output
        | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
            p, self._cache_dir))

    transformed_dataset = (((input_data_dict[span_1_key], input_metadata),
                            transform_fn)
                           | 'Transform' >> tft_beam.TransformDataset())

    dot_string = nodes.get_dot_graph(
        [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
    self.WriteRenderedDotFile(dot_string)

    # The output cache should not have entries for the cache that is present
    # in the input cache.
    self.assertEqual(
        len(cache_output[span_0_key]),
        len(cache_output[span_1_key]) - 4)

    transformed_data, unused_transformed_metadata = transformed_dataset

    expected_transformed = [
        {
            'x_mean': 6.0,
            'x_min': -2.0,
            'y_mean': -0.25,
            'y_min': -4.0,
            'integerized_s': 1,
        },
        {
            'x_mean': 6.0,
            'x_min': -2.0,
            'y_mean': -0.25,
            'y_min': -4.0,
            'integerized_s': 2,
        },
    ]
    beam_test_util.assert_that(
        transformed_data, beam_test_util.equal_to(expected_transformed))

    transform_fn_dir = os.path.join(self.base_test_dir, 'transform_fn')
    _ = transform_fn | tft_beam.WriteTransformFn(transform_fn_dir)

  # 4 from analyzing 2 spans, and 2 from transform.
  self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 6)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_decoded'), 4)
  self.assertEqual(_get_counter_value(p.metrics, 'cache_entries_encoded'), 8)
  self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'), 2)
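
# Follow-on illustration (an assumption about downstream usage, not part of
# the tests above): a separate pipeline can reload the directory written by
# tft_beam.WriteTransformFn and apply it to new raw data. The names
# `transform_fn_dir`, `new_raw_data` and `raw_metadata` are placeholders.
import tempfile

import apache_beam as beam
import tensorflow_transform.beam as tft_beam


def _apply_saved_transform_sketch(transform_fn_dir, new_raw_data,
                                  raw_metadata):
  """Applies a previously written transform_fn to new in-memory raw data."""
  with beam.Pipeline() as p:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      raw_data = p | 'CreateNewRawData' >> beam.Create(new_raw_data)
      transform_fn = p | 'ReadTransformFn' >> tft_beam.ReadTransformFn(
          transform_fn_dir)
      transformed_data, _ = (
          ((raw_data, raw_metadata), transform_fn)
          | 'Transform' >> tft_beam.TransformDataset())
      # transformed_data is a PCollection of transformed instance dicts.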