def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                        unused_outputs: Mapping[Text, Any]) -> Any:
  """Returns a user defined preprocessing_fn.

  Args:
    inputs: A dictionary of labelled input values.
    unused_outputs: A dictionary of labelled output values.

  Returns:
    User defined function.

  Raises:
    ValueError: When neither or both of MODULE_FILE and PREPROCESSING_FN
      are present in inputs.
  """
  has_module_file = bool(
      common.GetSoleValue(inputs, labels.MODULE_FILE, strict=False))
  has_preprocessing_fn = bool(
      common.GetSoleValue(inputs, labels.PREPROCESSING_FN, strict=False))

  if has_module_file == has_preprocessing_fn:
    raise ValueError(
        'Neither or both of MODULE_FILE and PREPROCESSING_FN have been '
        'supplied in inputs.')

  if has_module_file:
    return import_utils.import_func_from_source(
        common.GetSoleValue(inputs, labels.MODULE_FILE), 'preprocessing_fn')

  preprocessing_fn_path_split = common.GetSoleValue(
      inputs, labels.PREPROCESSING_FN).split('.')
  return import_utils.import_func_from_module(
      '.'.join(preprocessing_fn_path_split[0:-1]),
      preprocessing_fn_path_split[-1])
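# Editorial sketch (not part of the executor): the two mutually exclusive
# configurations accepted above might look as follows. The concrete paths and
# names are hypothetical placeholders, and wrapping values in lists assumes
# the multi-valued dict convention that common.GetSoleValue reads from.
#
#   inputs_via_file = {labels.MODULE_FILE: ['/path/to/preprocessing.py']}
#   inputs_via_name = {
#       labels.PREPROCESSING_FN: ['my_pkg.my_module.preprocessing_fn']}
#
# Supplying both, or neither, makes has_module_file == has_preprocessing_fn
# and triggers the ValueError above.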
def _GetPreprocessingFn(self, inputs, unused_outputs):
  """Returns a user defined preprocessing_fn.

  Args:
    inputs: A dictionary of labelled input values.
    unused_outputs: A dictionary of labelled output values.

  Returns:
    User defined function.
  """
  return io_utils.import_func(
      common.GetSoleValue(inputs, labels.PREPROCESSING_FN),
      'preprocessing_fn')
def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                        unused_outputs: Mapping[Text, Any]) -> Any:
  """Returns a user defined preprocessing_fn.

  Args:
    inputs: A dictionary of labelled input values.
    unused_outputs: A dictionary of labelled output values.

  Returns:
    User defined function.
  """
  return import_utils.import_func_from_source(
      common.GetSoleValue(inputs, labels.PREPROCESSING_FN),
      'preprocessing_fn')
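# Editorial sketch (not part of the executor): a minimal user-supplied module
# of the kind the loaders above import. The feature names 'x' and 'x_scaled'
# are hypothetical; tft.scale_to_z_score is a real tf.Transform mapper backed
# by a full-pass analyzer, which is why AnalyzeDataset is needed downstream.
#
#   import tensorflow_transform as tft
#
#   def preprocessing_fn(inputs):
#     """Scales feature 'x' to zero mean and unit variance."""
#     return {'x_scaled': tft.scale_to_z_score(inputs['x'])}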
def _RunBeamImpl(self, inputs: Mapping[Text, Any],
                 outputs: Mapping[Text, Any], preprocessing_fn: Any,
                 input_dataset_metadata: dataset_metadata.DatasetMetadata,
                 raw_examples_data_format: Text, transform_output_path: Text,
                 compute_statistics: bool,
                 materialize_output_paths: Sequence[Text]) -> _Status:
  """Performs data preprocessing with Apache Beam.

  Args:
    inputs: A dictionary of labelled input values.
    outputs: A dictionary of labelled output values.
    preprocessing_fn: The tf.Transform preprocessing_fn.
    input_dataset_metadata: A DatasetMetadata object for the input data.
    raw_examples_data_format: A string describing the raw data format.
    transform_output_path: An absolute path to write the output to.
    compute_statistics: A bool indicating whether or not to compute
      statistics.
    materialize_output_paths: Paths to materialized outputs.

  Raises:
    RuntimeError: If reset() is not invoked between two run() calls.
    ValueError: If the schema is empty.

  Returns:
    Status of the execution.
  """
  raw_examples_file_format = common.GetSoleValue(
      inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
  analyze_and_transform_data_paths = common.GetValues(
      inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
  transform_only_data_paths = common.GetValues(
      inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
  stats_use_tfdv = common.GetSoleValue(inputs,
                                       labels.TFT_STATISTICS_USE_TFDV_LABEL)
  per_set_stats_output_paths = common.GetValues(
      outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
  temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)
  input_cache_dir = common.GetSoleValue(
      inputs, labels.CACHE_INPUT_PATH_LABEL, strict=False)
  output_cache_dir = common.GetSoleValue(
      outputs, labels.CACHE_OUTPUT_PATH_LABEL, strict=False)

  tf.logging.info('Analyze and transform data patterns: %s',
                  list(enumerate(analyze_and_transform_data_paths)))
  tf.logging.info('Transform data patterns: %s',
                  list(enumerate(transform_only_data_paths)))
  tf.logging.info('Transform materialization output paths: %s',
                  list(enumerate(materialize_output_paths)))
  tf.logging.info('Transform output path: %s', transform_output_path)

  feature_spec = schema_utils.schema_as_feature_spec(
      _GetSchemaProto(input_dataset_metadata)).feature_spec
  try:
    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn, feature_spec)
    transform_input_columns = (
        tft.get_transform_input_columns(preprocessing_fn, feature_spec))
  except AttributeError:
    # If using TFT 1.12, fall back to assuming all features are used.
    analyze_input_columns = feature_spec.keys()
    transform_input_columns = feature_spec.keys()

  # Use the same dataset (same columns) for AnalyzeDataset and computing
  # pre-transform stats so that the data will only be read once for these
  # two operations.
  if compute_statistics:
    analyze_input_columns = list(
        set(list(analyze_input_columns) + list(transform_input_columns)))

  if input_dataset_metadata.schema is _RAW_EXAMPLE_SCHEMA:
    analyze_input_dataset_metadata = input_dataset_metadata
    transform_input_dataset_metadata = input_dataset_metadata
  else:
    analyze_input_dataset_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in analyze_input_columns}))
    transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in transform_input_columns}))

  can_process_jointly = not bool(per_set_stats_output_paths or
                                 materialize_output_paths or
                                 output_cache_dir)
  analyze_data_list = self._MakeDatasetList(
      analyze_and_transform_data_paths, raw_examples_file_format,
      raw_examples_data_format, analyze_input_dataset_metadata,
      can_process_jointly)
  transform_data_list = self._MakeDatasetList(
      list(analyze_and_transform_data_paths) +
      list(transform_only_data_paths), raw_examples_file_format,
      raw_examples_data_format, transform_input_dataset_metadata,
      can_process_jointly)

  desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

  with self._CreatePipeline(outputs) as p:
    with tft_beam.Context(
        temp_dir=temp_path,
        desired_batch_size=desired_batch_size,
        passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
        use_deep_copy_optimization=True):
      # pylint: disable=expression-not-assigned
      # pylint: disable=no-value-for-parameter

      _ = (
          p | self._IncrementColumnUsageCounter(
              len(feature_spec.keys()), len(analyze_input_columns),
              len(transform_input_columns)))

      (new_analyze_data_dict, input_cache, flat_data_required) = (
          p | self._OptimizeRun(input_cache_dir, output_cache_dir,
                                analyze_data_list, feature_spec,
                                preprocessing_fn, self._GetCacheSource()))
      # Remove unneeded datasets if they won't be needed for
      # materialization. This means that these datasets won't be included
      # in the statistics computation or profiling either.
      if not materialize_output_paths:
        analyze_data_list = [
            d for d in new_analyze_data_dict.values() if d is not None
        ]

      analyze_decode_fn = (
          self._GetDecodeFunction(raw_examples_data_format,
                                  analyze_input_dataset_metadata.schema))

      for (idx, dataset) in enumerate(analyze_data_list):
        dataset.encoded = (
            p
            | 'ReadAnalysisDataset[{}]'.format(idx) >>
            self._ReadExamples(dataset))
        dataset.decoded = (
            dataset.encoded
            | 'DecodeAnalysisDataset[{}]'.format(idx) >>
            self._DecodeInputs(analyze_decode_fn))

      input_analysis_data = {}
      for key, dataset in six.iteritems(new_analyze_data_dict):
        if dataset is None:
          input_analysis_data[key] = None
        else:
          input_analysis_data[key] = dataset.decoded

      if flat_data_required:
        flat_input_analysis_data = (
            [dataset.decoded for dataset in analyze_data_list]
            | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=p))
      else:
        flat_input_analysis_data = None

      if input_cache:
        tf.logging.info('Analyzing data with cache.')
      transform_fn, cache_output = (
          (flat_input_analysis_data, input_analysis_data, input_cache,
           input_dataset_metadata)
          | 'AnalyzeDataset' >> tft_beam.AnalyzeDatasetWithCache(
              preprocessing_fn, pipeline=p))

      # Write the raw/input metadata.
      (input_dataset_metadata
       | 'WriteMetadata' >> tft_beam.WriteMetadata(
           os.path.join(transform_output_path,
                        tft.TFTransformOutput.RAW_METADATA_DIR), p))

      # WriteTransformFn writes transform_fn and metadata to subdirectories
      # tensorflow_transform.SAVED_MODEL_DIR and
      # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
      (transform_fn
       | 'WriteTransformFn' >>
       tft_beam.WriteTransformFn(transform_output_path))

      if output_cache_dir is not None and cache_output is not None:
        # TODO(b/37788560): Possibly make this part of the beam graph.
        tf.io.gfile.makedirs(output_cache_dir)
        tf.logging.info('Using existing cache in: %s', input_cache_dir)
        if input_cache_dir is not None:
          # Only copy cache that is relevant to this iteration. This
          # assumes that the pipeline operates on rolling ranges, so those
          # cache entries may also be relevant for future iterations.
          for span_cache_dir in input_analysis_data:
            full_span_cache_dir = os.path.join(input_cache_dir,
                                               span_cache_dir)
            if tf.io.gfile.isdir(full_span_cache_dir):
              self._CopyCache(
                  full_span_cache_dir,
                  os.path.join(output_cache_dir, span_cache_dir))

        (cache_output
         | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
             p, output_cache_dir, sink=self._GetCacheSink()))

      if compute_statistics or materialize_output_paths:
        # Do not compute pre-transform stats if the input format is raw
        # proto, as StatsGen would treat any input as tf.Example.
        if (compute_statistics and
            not self._IsDataFormatProto(raw_examples_data_format)):
          # Aggregated feature stats before transformation.
          pre_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

          schema_proto = _GetSchemaProto(analyze_input_dataset_metadata)
          ([
              dataset.decoded if stats_use_tfdv else dataset.encoded
              for dataset in analyze_data_list
          ]
           | 'FlattenPreTransformAnalysisDatasets' >>
           beam.Flatten(pipeline=p)
           | 'GenerateAggregatePreTransformAnalysisStats' >>
           self._GenerateStats(
               pre_transform_feature_stats_path,
               schema_proto,
               use_deep_copy_optimization=True,
               use_tfdv=stats_use_tfdv))

        transform_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    transform_input_dataset_metadata.schema))
        # transform_data_list is a superset of analyze_data_list; we pay the
        # cost of reading the same dataset (analyze_data_list) again here to
        # prevent certain Beam runners from doing large temp
        # materialization.
        for (idx, dataset) in enumerate(transform_data_list):
          dataset.encoded = (
              p
              | 'ReadTransformDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeTransformDataset[{}]'.format(idx) >>
              self._DecodeInputs(transform_decode_fn))
          (dataset.transformed, metadata) = (
              ((dataset.decoded, transform_input_dataset_metadata),
               transform_fn)
              | 'TransformDataset[{}]'.format(idx) >>
              tft_beam.TransformDataset())

          if materialize_output_paths or not stats_use_tfdv:
            dataset.transformed_and_encoded = (
                dataset.transformed
                | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                    self._EncodeAsExamples(), metadata))

        if compute_statistics:
          # Aggregated feature stats after transformation.
          _, metadata = transform_fn
          post_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
          # schema. Currently input dataset schema only contains dtypes,
          # and other metadata is dropped due to roundtrip to tensors.
          transformed_schema_proto = _GetSchemaProto(metadata)

          ([(dataset.transformed
             if stats_use_tfdv else dataset.transformed_and_encoded)
            for dataset in transform_data_list]
           | 'FlattenPostTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePostTransformAnalysisStats' >>
           self._GenerateStats(
               post_transform_feature_stats_path,
               transformed_schema_proto,
               use_tfdv=stats_use_tfdv))

          if per_set_stats_output_paths:
            assert len(transform_data_list) == len(
                per_set_stats_output_paths)
            # TODO(b/67632871): Remove duplicate stats gen compute that is
            # done both on a flattened view of the data, and on each span
            # below.
            bundles = zip(transform_data_list, per_set_stats_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              if stats_use_tfdv:
                data = dataset.transformed
              else:
                data = dataset.transformed_and_encoded
              (data
               | 'GeneratePostTransformStats[{}]'.format(idx) >>
               self._GenerateStats(
                   output_path,
                   transformed_schema_proto,
                   use_tfdv=stats_use_tfdv))

        if materialize_output_paths:
          assert len(transform_data_list) == len(materialize_output_paths)
          bundles = zip(transform_data_list, materialize_output_paths)
          for (idx, (dataset, output_path)) in enumerate(bundles):
            (dataset.transformed_and_encoded
             | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                 raw_examples_file_format, output_path))

  return _Status.OK()
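# Editorial sketch (assumption-labelled, not the executor's code): stripped of
# the caching, statistics, and materialization plumbing above, the heart of
# this pipeline is tf.Transform's analyze/transform split. The variables
# raw_examples, input_dataset_metadata, preprocessing_fn, and
# transform_output_path stand in for the executor's inputs.
#
#   import apache_beam as beam
#   import tensorflow_transform.beam as tft_beam
#
#   with beam.Pipeline() as p, tft_beam.Context(temp_dir='/tmp/tft'):
#     raw_data = p | beam.Create(raw_examples)
#     (transformed_dataset, transform_fn) = (
#         (raw_data, input_dataset_metadata)
#         | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
#     transformed_data, transformed_metadata = transformed_dataset
#     _ = transform_fn | tft_beam.WriteTransformFn(transform_output_path)
#
# The executor instead runs AnalyzeDatasetWithCache and TransformDataset
# separately so that transform-only datasets can reuse the analysis result.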
def Transform(self, inputs: Mapping[Text, Any], outputs: Mapping[Text, Any],
              status_file: Text) -> None:
  """Executes on request.

  This is the implementation part of the transform executor. It is intended
  for using or extending the executor without an artifact dependency.

  Args:
    inputs: A dictionary of labelled input values, including:
      - labels.COMPUTE_STATISTICS_LABEL: Whether to compute statistics.
      - labels.SCHEMA_PATH_LABEL: Path to schema file.
      - labels.EXAMPLES_FILE_FORMAT_LABEL: Example file format, optional.
      - labels.EXAMPLES_DATA_FORMAT_LABEL: Example data format.
      - labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL: Paths or path
        patterns to analyze-and-transform data.
      - labels.TRANSFORM_ONLY_DATA_PATHS_LABEL: Paths or path patterns to
        transform-only data.
      - labels.TFT_STATISTICS_USE_TFDV_LABEL: Whether to use TFDV to
        compute statistics.
      - labels.MODULE_FILE: Path to a Python module that contains the
        preprocessing_fn, exclusive with PREPROCESSING_FN, optional.
      - labels.PREPROCESSING_FN: Fully qualified name of the
        preprocessing_fn, exclusive with MODULE_FILE, optional.
    outputs: A dictionary of labelled output values, including:
      - labels.PER_SET_STATS_OUTPUT_PATHS_LABEL: Paths to statistics
        output, optional.
      - labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL: A path to
        TFTransformOutput output.
      - labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL: Paths to transform
        materialization.
      - labels.TEMP_OUTPUT_LABEL: A path to a temporary directory.
    status_file: Where the status should be written (not yet implemented).
  """
  del status_file  # unused

  compute_statistics = common.GetSoleValue(inputs,
                                           labels.COMPUTE_STATISTICS_LABEL)
  transform_output_path = common.GetSoleValue(
      outputs, labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL)
  raw_examples_data_format = common.GetSoleValue(
      inputs, labels.EXAMPLES_DATA_FORMAT_LABEL)
  schema = common.GetSoleValue(inputs, labels.SCHEMA_PATH_LABEL)
  input_dataset_metadata = self._ReadMetadata(raw_examples_data_format,
                                              schema)

  tf.logging.info('Inputs to executor.Transform function: {}'.format(inputs))
  tf.logging.info(
      'Outputs to executor.Transform function: {}'.format(outputs))

  feature_spec = schema_utils.schema_as_feature_spec(
      _GetSchemaProto(input_dataset_metadata)).feature_spec

  # NOTE: We disallow an empty schema, which we detect by testing the
  # number of columns. While in principle an empty schema is valid, in
  # practice this is a sign of a user error, and this is a convenient
  # place to catch that error.
  if (not feature_spec and
      not self._ShouldDecodeAsRawExample(raw_examples_data_format)):
    raise ValueError(messages.SCHEMA_EMPTY)

  preprocessing_fn = self._GetPreprocessingFn(inputs, outputs)

  materialize_output_paths = common.GetValues(
      outputs, labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL)

  # Inspect the preprocessing_fn even when a full pass is known to be
  # needed, in order to fail fast if it is broken.
  try:
    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn, feature_spec)
  except AttributeError:
    # If using TFT 1.12, fall back to assuming all features are used.
    analyze_input_columns = feature_spec.keys()

  if not compute_statistics and not materialize_output_paths:
    if analyze_input_columns:
      tf.logging.warning(
          'Not using the in-place Transform because the following features '
          'require analyzing: {}'.format(
              tuple(c for c in analyze_input_columns)))
    else:
      tf.logging.warning(
          'Using the in-place Transform since compute_statistics=False, '
          'it does not materialize transformed data, and the configured '
          'preprocessing_fn appears to not require analyzing the data.')
      self._RunInPlaceImpl(preprocessing_fn, input_dataset_metadata,
                           transform_output_path)
      # TODO(b/122478841): Writes status to status file.
      return

  self._RunBeamImpl(inputs, outputs, preprocessing_fn,
                    input_dataset_metadata, raw_examples_data_format,
                    transform_output_path, compute_statistics,
                    materialize_output_paths)
  # TODO(b/122478841): Writes status to status file.
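# Editorial sketch (hypothetical invocation): how a caller might drive
# Transform with the labelled dictionaries documented in its docstring. All
# paths are placeholders, the 'tf_example' string stands in for whatever
# value the labels module defines for the data format, and wrapping values
# in lists assumes the convention that common.GetSoleValue/GetValues read.
#
#   executor.Transform(
#       inputs={
#           labels.COMPUTE_STATISTICS_LABEL: [True],
#           labels.SCHEMA_PATH_LABEL: ['/path/to/schema.pbtxt'],
#           labels.EXAMPLES_DATA_FORMAT_LABEL: ['tf_example'],
#           labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL: ['/data/train*'],
#           labels.PREPROCESSING_FN: ['my_pkg.my_module.preprocessing_fn'],
#       },
#       outputs={
#           labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL: ['/out/transform'],
#           labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL: ['/out/data'],
#           labels.TEMP_OUTPUT_LABEL: ['/out/tmp'],
#       },
#       status_file='')  # status_file is currently unused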
def _RunBeamImpl(self, inputs, outputs, preprocessing_fn,
                 input_dataset_metadata, raw_examples_data_format,
                 transform_output_path, compute_statistics,
                 materialize_output_paths):
  """Performs data preprocessing with Apache Beam.

  Args:
    inputs: A dictionary of labelled input values.
    outputs: A dictionary of labelled output values.
    preprocessing_fn: The tf.Transform preprocessing_fn.
    input_dataset_metadata: A DatasetMetadata object for the input data.
    raw_examples_data_format: A string describing the raw data format.
    transform_output_path: An absolute path to write the output to.
    compute_statistics: A bool indicating whether or not to compute
      statistics.
    materialize_output_paths: Paths to materialized outputs.

  Raises:
    RuntimeError: If reset() is not invoked between two run() calls.
    ValueError: If the schema is empty.

  Returns:
    Status of the execution.
  """
  raw_examples_file_format = common.GetSoleValue(
      inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
  analyze_and_transform_data_paths = common.GetValues(
      inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
  transform_only_data_paths = common.GetValues(
      inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
  stats_use_tfdv = common.GetSoleValue(inputs,
                                       labels.TFT_STATISTICS_USE_TFDV_LABEL)
  per_set_stats_output_paths = common.GetValues(
      outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
  temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

  tf.logging.info('Analyze and transform data patterns: %s',
                  list(enumerate(analyze_and_transform_data_paths)))
  tf.logging.info('Transform data patterns: %s',
                  list(enumerate(transform_only_data_paths)))
  tf.logging.info('Transform materialization output paths: %s',
                  list(enumerate(materialize_output_paths)))
  tf.logging.info('Transform output path: %s', transform_output_path)

  feature_spec = input_dataset_metadata.schema.as_feature_spec()
  try:
    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn, feature_spec)
    transform_input_columns = (
        tft.get_transform_input_columns(preprocessing_fn, feature_spec))
  except AttributeError:
    # If using TFT 1.12, fall back to assuming all features are used.
    analyze_input_columns = feature_spec.keys()
    transform_input_columns = feature_spec.keys()

  # Use the same dataset (same columns) for AnalyzeDataset and computing
  # pre-transform stats so that the data will only be read once for these
  # two operations.
  if compute_statistics:
    analyze_input_columns = list(
        set(list(analyze_input_columns) + list(transform_input_columns)))

  analyze_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
  transform_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
  if input_dataset_metadata.schema is not _RAW_EXAMPLE_SCHEMA:
    analyze_input_dataset_metadata.schema = (
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in analyze_input_columns}))
    transform_input_dataset_metadata.schema = (
        dataset_schema.from_feature_spec(
            {feature: feature_spec[feature]
             for feature in transform_input_columns}))

  can_process_jointly = not bool(per_set_stats_output_paths or
                                 materialize_output_paths)
  analyze_data_list = self._MakeDatasetList(
      analyze_and_transform_data_paths, raw_examples_file_format,
      raw_examples_data_format, analyze_input_dataset_metadata,
      can_process_jointly)
  transform_data_list = self._MakeDatasetList(
      list(analyze_and_transform_data_paths) +
      list(transform_only_data_paths), raw_examples_file_format,
      raw_examples_data_format, transform_input_dataset_metadata,
      can_process_jointly)

  desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

  with self._CreatePipeline(outputs) as p:
    with tft_beam.Context(
        temp_dir=temp_path,
        desired_batch_size=desired_batch_size,
        passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
        use_deep_copy_optimization=True):
      # pylint: disable=expression-not-assigned
      # pylint: disable=no-value-for-parameter

      analyze_decode_fn = (
          self._GetDecodeFunction(raw_examples_data_format,
                                  analyze_input_dataset_metadata.schema))

      for (idx, dataset) in enumerate(analyze_data_list):
        dataset.encoded = (
            p
            | 'ReadAnalysisDataset[{}]'.format(idx) >>
            self._ReadExamples(dataset))
        dataset.decoded = (
            dataset.encoded
            | 'DecodeAnalysisDataset[{}]'.format(idx) >>
            self._DecodeInputs(analyze_decode_fn))

      input_analysis_data = (
          [dataset.decoded for dataset in analyze_data_list]
          | 'FlattenAnalysisDatasets' >> beam.Flatten())
      transform_fn = (
          (input_analysis_data, input_dataset_metadata)
          | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn))

      # Write the raw/input metadata.
      (input_dataset_metadata
       | 'WriteMetadata' >> tft_beam.WriteMetadata(
           os.path.join(transform_output_path,
                        tft.TFTransformOutput.RAW_METADATA_DIR), p))

      # WriteTransformFn writes transform_fn and metadata to subdirectories
      # tensorflow_transform.SAVED_MODEL_DIR and
      # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
      (transform_fn
       | 'WriteTransformFn' >>
       tft_beam.WriteTransformFn(transform_output_path))

      if compute_statistics or materialize_output_paths:
        # Do not compute pre-transform stats if the input format is raw
        # proto, as StatsGen would treat any input as tf.Example.
        if (compute_statistics and
            not self._IsDataFormatProto(raw_examples_data_format)):
          # Aggregated feature stats before transformation.
          pre_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
          # schema. Currently input dataset schema only contains dtypes,
          # and other metadata is dropped due to roundtrip to tensors.
          schema_proto = schema_utils.schema_from_feature_spec(
              analyze_input_dataset_metadata.schema.as_feature_spec())
          ([
              dataset.decoded if stats_use_tfdv else dataset.encoded
              for dataset in analyze_data_list
          ]
           | 'FlattenPreTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePreTransformAnalysisStats' >>
           self._GenerateStats(
               pre_transform_feature_stats_path,
               schema_proto,
               use_deep_copy_optimization=True,
               use_tfdv=stats_use_tfdv))

        transform_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    transform_input_dataset_metadata.schema))
        # transform_data_list is a superset of analyze_data_list; we pay the
        # cost of reading the same dataset (analyze_data_list) again here to
        # prevent certain Beam runners from doing large temp
        # materialization.
        for (idx, dataset) in enumerate(transform_data_list):
          dataset.encoded = (
              p
              | 'ReadTransformDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeTransformDataset[{}]'.format(idx) >>
              self._DecodeInputs(transform_decode_fn))
          (dataset.transformed, metadata) = (
              ((dataset.decoded, transform_input_dataset_metadata),
               transform_fn)
              | 'TransformDataset[{}]'.format(idx) >>
              tft_beam.TransformDataset())

          if materialize_output_paths or not stats_use_tfdv:
            dataset.transformed_and_encoded = (
                dataset.transformed
                | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                    self._EncodeAsExamples(), metadata))

        if compute_statistics:
          # Aggregated feature stats after transformation.
          _, metadata = transform_fn
          post_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
          # schema. Currently input dataset schema only contains dtypes,
          # and other metadata is dropped due to roundtrip to tensors.
          transformed_schema_proto = schema_utils.schema_from_feature_spec(
              metadata.schema.as_feature_spec())

          ([(dataset.transformed
             if stats_use_tfdv else dataset.transformed_and_encoded)
            for dataset in transform_data_list]
           | 'FlattenPostTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePostTransformAnalysisStats' >>
           self._GenerateStats(
               post_transform_feature_stats_path,
               transformed_schema_proto,
               use_tfdv=stats_use_tfdv))

          if per_set_stats_output_paths:
            assert len(transform_data_list) == len(
                per_set_stats_output_paths)
            # TODO(b/67632871): Remove duplicate stats gen compute that is
            # done both on a flattened view of the data, and on each span
            # below.
            bundles = zip(transform_data_list, per_set_stats_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              if stats_use_tfdv:
                data = dataset.transformed
              else:
                data = dataset.transformed_and_encoded
              (data
               | 'GeneratePostTransformStats[{}]'.format(idx) >>
               self._GenerateStats(
                   output_path,
                   transformed_schema_proto,
                   use_tfdv=stats_use_tfdv))

        if materialize_output_paths:
          assert len(transform_data_list) == len(materialize_output_paths)
          bundles = zip(transform_data_list, materialize_output_paths)
          for (idx, (dataset, output_path)) in enumerate(bundles):
            (dataset.transformed_and_encoded
             | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                 raw_examples_file_format, output_path))

  return _Status.OK()