Esempio n. 1
0
  def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                          unused_outputs: Mapping[Text, Any]) -> Any:
    """Returns a user defined preprocessing_fn.

    Exactly one of labels.MODULE_FILE and labels.PREPROCESSING_FN must be
    supplied in `inputs`:

    - MODULE_FILE: the function named 'preprocessing_fn' is imported from
      that source via `import_utils.import_func_from_source`.
    - PREPROCESSING_FN: a dotted path of the form 'package.module.fn_name';
      'package.module' and 'fn_name' are passed to
      `import_utils.import_func_from_module`.

    Args:
      inputs: A dictionary of labelled input values.
      unused_outputs: A dictionary of labelled output values.

    Returns:
      User defined function.

    Raises:
      ValueError: When neither or both of MODULE_FILE and PREPROCESSING_FN
        are present in inputs, or when PREPROCESSING_FN is not a dotted path
        with a non-empty module part and a non-empty function name.
    """
    # Probe both labels non-strictly; a missing or empty value counts as
    # "not supplied".
    has_module_file = bool(
        common.GetSoleValue(inputs, labels.MODULE_FILE, strict=False))
    has_preprocessing_fn = bool(
        common.GetSoleValue(inputs, labels.PREPROCESSING_FN, strict=False))

    if has_module_file == has_preprocessing_fn:
      raise ValueError(
          'Neither or both of MODULE_FILE and PREPROCESSING_FN have been '
          'supplied in inputs.')

    if has_module_file:
      return import_utils.import_func_from_source(
          common.GetSoleValue(inputs, labels.MODULE_FILE), 'preprocessing_fn')

    # Re-fetch with GetSoleValue's default `strict` setting, as before.
    preprocessing_fn_path = common.GetSoleValue(inputs,
                                                labels.PREPROCESSING_FN)
    # Split on the last '.', e.g. 'a.b.fn' -> ('a.b', '.', 'fn').
    module_path, _, fn_name = preprocessing_fn_path.rpartition('.')
    # Without this check a value such as 'fn' would try to import the module
    # '' and fail with an unhelpful error far from the configuration.
    if not module_path or not fn_name:
      raise ValueError(
          'PREPROCESSING_FN must be a dotted path of the form '
          '"module.function_name", got: {!r}'.format(preprocessing_fn_path))
    return import_utils.import_func_from_module(module_path, fn_name)
Esempio n. 2
0
    def _GetPreprocessingFn(self, inputs, unused_outputs):
        """Loads the user-supplied preprocessing_fn.

        Args:
          inputs: A dictionary of labelled input values. The value stored
            under labels.PREPROCESSING_FN is handed to io_utils.import_func.
          unused_outputs: A dictionary of labelled output values (ignored).

        Returns:
          User defined function.
        """
        fn_source = common.GetSoleValue(inputs, labels.PREPROCESSING_FN)
        fn_name = 'preprocessing_fn'
        return io_utils.import_func(fn_source, fn_name)
Esempio n. 3
0
  def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                          unused_outputs: Mapping[Text, Any]) -> Any:
    """Imports the user's preprocessing_fn from the configured source.

    Args:
      inputs: A dictionary of labelled input values. The value stored under
        labels.PREPROCESSING_FN is handed to
        import_utils.import_func_from_source.
      unused_outputs: A dictionary of labelled output values (ignored).

    Returns:
      User defined function.
    """
    fn_source = common.GetSoleValue(inputs, labels.PREPROCESSING_FN)
    return import_utils.import_func_from_source(fn_source, 'preprocessing_fn')
Esempio n. 4
0
  def _RunBeamImpl(self, inputs: Mapping[Text, Any],
                   outputs: Mapping[Text, Any], preprocessing_fn: Any,
                   input_dataset_metadata: dataset_metadata.DatasetMetadata,
                   raw_examples_data_format: Text, transform_output_path: Text,
                   compute_statistics: bool,
                   materialize_output_paths: Sequence[Text]) -> _Status:
    """Perform data preprocessing with a Beam pipeline.

    The pipeline (created by self._CreatePipeline) analyzes the
    analyze-and-transform datasets, optionally reading and writing analyzer
    cache. It writes the raw metadata and the transform_fn under
    transform_output_path. When compute_statistics or materialize_output_paths
    is set, it also transforms every dataset to compute pre/post-transform
    statistics and/or materialize the transformed examples.

    Args:
      inputs: A dictionary of labelled input values.
      outputs: A dictionary of labelled output values.
      preprocessing_fn: The tf.Transform preprocessing_fn.
      input_dataset_metadata: A DatasetMetadata object for the input data.
      raw_examples_data_format: A string describing the raw data format.
      transform_output_path: An absolute path to write the output to.
      compute_statistics: A bool indicating whether or not compute statistics.
      materialize_output_paths: Paths to materialized outputs.

    Raises:
      AssertionError: If per-set statistics output paths (when computing
        statistics) or materialization output paths are given, but their
        number differs from the number of transform datasets.

    Returns:
      Status of the execution.
    """
    raw_examples_file_format = common.GetSoleValue(
        inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
    analyze_and_transform_data_paths = common.GetValues(
        inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
    transform_only_data_paths = common.GetValues(
        inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
    stats_use_tfdv = common.GetSoleValue(inputs,
                                         labels.TFT_STATISTICS_USE_TFDV_LABEL)
    per_set_stats_output_paths = common.GetValues(
        outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
    temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

    # Analyzer cache directories are optional (fetched non-strictly).
    input_cache_dir = common.GetSoleValue(
        inputs, labels.CACHE_INPUT_PATH_LABEL, strict=False)
    output_cache_dir = common.GetSoleValue(
        outputs, labels.CACHE_OUTPUT_PATH_LABEL, strict=False)

    tf.logging.info('Analyze and transform data patterns: %s',
                    list(enumerate(analyze_and_transform_data_paths)))
    tf.logging.info('Transform data patterns: %s',
                    list(enumerate(transform_only_data_paths)))
    tf.logging.info('Transform materialization output paths: %s',
                    list(enumerate(materialize_output_paths)))
    tf.logging.info('Transform output path: %s', transform_output_path)

    feature_spec = schema_utils.schema_as_feature_spec(
        _GetSchemaProto(input_dataset_metadata)).feature_spec
    try:
      analyze_input_columns = tft.get_analyze_input_columns(
          preprocessing_fn, feature_spec)
      transform_input_columns = (
          tft.get_transform_input_columns(preprocessing_fn, feature_spec))
    except AttributeError:
      # If using TFT 1.12, fall back to assuming all features are used.
      analyze_input_columns = feature_spec.keys()
      transform_input_columns = feature_spec.keys()
    # Use the same dataset (same columns) for AnalyzeDataset and computing
    # pre-transform stats so that the data will only be read once for these
    # two operations.
    if compute_statistics:
      analyze_input_columns = list(
          set(list(analyze_input_columns) + list(transform_input_columns)))
    # Raw-example input keeps its metadata as is; otherwise restrict the
    # schema to the columns each phase actually reads.
    if input_dataset_metadata.schema is _RAW_EXAMPLE_SCHEMA:
      analyze_input_dataset_metadata = input_dataset_metadata
      transform_input_dataset_metadata = input_dataset_metadata
    else:
      analyze_input_dataset_metadata = dataset_metadata.DatasetMetadata(
          dataset_schema.from_feature_spec(
              {feature: feature_spec[feature]
               for feature in analyze_input_columns}))
      transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
          dataset_schema.from_feature_spec(
              {feature: feature_spec[feature]
               for feature in transform_input_columns}))

    # Datasets may only be merged when nothing downstream needs them
    # individually (per-set stats, per-set materialization, or cache).
    can_process_jointly = not bool(per_set_stats_output_paths or
                                   materialize_output_paths or output_cache_dir)
    analyze_data_list = self._MakeDatasetList(
        analyze_and_transform_data_paths, raw_examples_file_format,
        raw_examples_data_format, analyze_input_dataset_metadata,
        can_process_jointly)
    transform_data_list = self._MakeDatasetList(
        list(analyze_and_transform_data_paths) +
        list(transform_only_data_paths), raw_examples_file_format,
        raw_examples_data_format, transform_input_dataset_metadata,
        can_process_jointly)

    desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

    with self._CreatePipeline(outputs) as p:
      with tft_beam.Context(
          temp_dir=temp_path,
          desired_batch_size=desired_batch_size,
          passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
          use_deep_copy_optimization=True):
        # pylint: disable=expression-not-assigned
        # pylint: disable=no-value-for-parameter

        _ = (
            p | self._IncrementColumnUsageCounter(
                len(feature_spec.keys()), len(analyze_input_columns),
                len(transform_input_columns)))

        (new_analyze_data_dict, input_cache, flat_data_required) = (
            p | self._OptimizeRun(input_cache_dir, output_cache_dir,
                                  analyze_data_list, feature_spec,
                                  preprocessing_fn, self._GetCacheSource()))
        # Removing unneeded datasets if they won't be needed for
        # materialization. This means that these datasets won't be included in
        # the statistics computation or profiling either.
        if not materialize_output_paths:
          analyze_data_list = [
              d for d in new_analyze_data_dict.values() if d is not None
          ]

        analyze_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    analyze_input_dataset_metadata.schema))

        for (idx, dataset) in enumerate(analyze_data_list):
          dataset.encoded = (
              p | 'ReadAnalysisDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeAnalysisDataset[{}]'.format(idx) >>
              self._DecodeInputs(analyze_decode_fn))

        # Map each dataset key to its decoded PCollection; keys whose dataset
        # is None are kept with a None value.
        input_analysis_data = {}
        for key, dataset in six.iteritems(new_analyze_data_dict):
          if dataset is None:
            input_analysis_data[key] = None
          else:
            input_analysis_data[key] = dataset.decoded

        if flat_data_required:
          flat_input_analysis_data = (
              [dataset.decoded for dataset in analyze_data_list]
              | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=p))
        else:
          flat_input_analysis_data = None
        if input_cache:
          tf.logging.info('Analyzing data with cache.')
        transform_fn, cache_output = (
            (flat_input_analysis_data, input_analysis_data, input_cache,
             input_dataset_metadata)
            | 'AnalyzeDataset' >> tft_beam.AnalyzeDatasetWithCache(
                preprocessing_fn, pipeline=p))

        # Write the raw/input metadata.
        (input_dataset_metadata
         | 'WriteMetadata' >> tft_beam.WriteMetadata(
             os.path.join(transform_output_path,
                          tft.TFTransformOutput.RAW_METADATA_DIR), p))

        # WriteTransformFn writes transform_fn and metadata to subdirectories
        # tensorflow_transform.SAVED_MODEL_DIR and
        # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
        (transform_fn |
         'WriteTransformFn' >> tft_beam.WriteTransformFn(transform_output_path))

        if output_cache_dir is not None and cache_output is not None:
          # TODO(b/37788560): Possibly make this part of the beam graph.
          tf.io.gfile.makedirs(output_cache_dir)
          if input_cache_dir is not None:
            # Log only when an input cache was actually supplied; otherwise
            # the message would read "Using existing cache in: None".
            tf.logging.info('Using existing cache in: %s', input_cache_dir)
            # Only copy cache that is relevant to this iteration. This is
            # assuming that this pipeline operates on rolling ranges, so those
            # cache entries may also be relevant for future iterations.
            for span_cache_dir in input_analysis_data:
              full_span_cache_dir = os.path.join(input_cache_dir,
                                                 span_cache_dir)
              if tf.io.gfile.isdir(full_span_cache_dir):
                self._CopyCache(full_span_cache_dir,
                                os.path.join(output_cache_dir, span_cache_dir))

          (cache_output
           | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
               p, output_cache_dir, sink=self._GetCacheSink()))

        if compute_statistics or materialize_output_paths:
          # Do not compute pre-transform stats if the input format is raw proto,
          # as StatsGen would treat any input as tf.Example.
          if (compute_statistics and
              not self._IsDataFormatProto(raw_examples_data_format)):
            # Aggregated feature stats before transformation.
            pre_transform_feature_stats_path = os.path.join(
                transform_output_path,
                tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

            schema_proto = _GetSchemaProto(analyze_input_dataset_metadata)
            ([
                dataset.decoded if stats_use_tfdv else dataset.encoded
                for dataset in analyze_data_list
            ]
             | 'FlattenPreTransformAnalysisDatasets' >> beam.Flatten(pipeline=p)
             | 'GenerateAggregatePreTransformAnalysisStats' >>
             self._GenerateStats(
                 pre_transform_feature_stats_path,
                 schema_proto,
                 use_deep_copy_optimization=True,
                 use_tfdv=stats_use_tfdv))

          transform_decode_fn = (
              self._GetDecodeFunction(raw_examples_data_format,
                                      transform_input_dataset_metadata.schema))
          # transform_data_list is a superset of analyze_data_list, we pay the
          # cost to read the same dataset (analyze_data_list) again here to
          # prevent certain beam runner from doing large temp materialization.
          for (idx, dataset) in enumerate(transform_data_list):
            dataset.encoded = (
                p
                | 'ReadTransformDataset[{}]'.format(idx) >>
                self._ReadExamples(dataset))
            dataset.decoded = (
                dataset.encoded
                | 'DecodeTransformDataset[{}]'.format(idx) >>
                self._DecodeInputs(transform_decode_fn))
            (dataset.transformed,
             metadata) = (((dataset.decoded, transform_input_dataset_metadata),
                           transform_fn)
                          | 'TransformDataset[{}]'.format(idx) >>
                          tft_beam.TransformDataset())

            # Encoded output is needed for materialization and for non-TFDV
            # statistics.
            if materialize_output_paths or not stats_use_tfdv:
              dataset.transformed_and_encoded = (
                  dataset.transformed
                  | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                      self._EncodeAsExamples(), metadata))

          if compute_statistics:
            # Aggregated feature stats after transformation.
            _, metadata = transform_fn
            post_transform_feature_stats_path = os.path.join(
                transform_output_path,
                tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

            # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
            # schema. Currently input dataset schema only contains dtypes,
            # and other metadata is dropped due to roundtrip to tensors.
            transformed_schema_proto = _GetSchemaProto(metadata)

            # Pass the pipeline so that Flatten also works on an empty list,
            # matching the other Flatten calls in this method.
            ([(dataset.transformed
               if stats_use_tfdv else dataset.transformed_and_encoded)
              for dataset in transform_data_list]
             | 'FlattenPostTransformAnalysisDatasets' >>
             beam.Flatten(pipeline=p)
             | 'GenerateAggregatePostTransformAnalysisStats' >>
             self._GenerateStats(
                 post_transform_feature_stats_path,
                 transformed_schema_proto,
                 use_tfdv=stats_use_tfdv))

            if per_set_stats_output_paths:
              assert len(transform_data_list) == len(per_set_stats_output_paths)
              # TODO(b/67632871): Remove duplicate stats gen compute that is
              # done both on a flattened view of the data, and on each span
              # below.
              bundles = zip(transform_data_list, per_set_stats_output_paths)
              for (idx, (dataset, output_path)) in enumerate(bundles):
                if stats_use_tfdv:
                  data = dataset.transformed
                else:
                  data = dataset.transformed_and_encoded
                (data
                 | 'GeneratePostTransformStats[{}]'.format(idx) >>
                 self._GenerateStats(
                     output_path,
                     transformed_schema_proto,
                     use_tfdv=stats_use_tfdv))

          if materialize_output_paths:
            assert len(transform_data_list) == len(materialize_output_paths)
            bundles = zip(transform_data_list, materialize_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              (dataset.transformed_and_encoded
               | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                   raw_examples_file_format, output_path))

    return _Status.OK()
Esempio n. 5
0
  def Transform(self, inputs: Mapping[Text, Any], outputs: Mapping[Text, Any],
                status_file: Text) -> None:
    """Executes on request.

    This is the implementation part of transform executor. This is intended for
    using or extending the executor without artifact dependency.

    If statistics are not requested, nothing is materialized, and the
    preprocessing_fn needs no analysis, the transform graph is produced
    in-place via self._RunInPlaceImpl. Otherwise a Beam pipeline is run via
    self._RunBeamImpl.

    Args:
      inputs: A dictionary of labelled input values, including:
        - labels.COMPUTE_STATISTICS_LABEL: Whether to compute statistics.
        - labels.SCHEMA_PATH_LABEL: Path to schema file.
        - labels.EXAMPLES_FILE_FORMAT_LABEL: Example file format, optional.
        - labels.EXAMPLES_DATA_FORMAT_LABEL: Example data format.
        - labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL: Paths or path patterns
          to analyze and transform data.
        - labels.TRANSFORM_ONLY_DATA_PATHS_LABEL: Paths or path patterns to
          transform only data.
        - labels.TFT_STATISTICS_USE_TFDV_LABEL: Whether to use TFDV to compute
          statistics.
        - labels.PREPROCESSING_FN: Path to a Python module that contains the
          preprocessing_fn, optional. How the value is interpreted is up to
          self._GetPreprocessingFn.
      outputs: A dictionary of labelled output values, including:
        - labels.PER_SET_STATS_OUTPUT_PATHS_LABEL: Paths to statistics output,
          optional.
        - labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL: A path to
          TFTransformOutput output.
        - labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL: Paths to transform
          materialization.
        - labels.TEMP_OUTPUT_LABEL: A path to temporary directory.
      status_file: Where the status should be written (not yet implemented)

    Raises:
      ValueError: If the schema has no features and the data is not to be
        decoded as raw examples.
    """

    del status_file  # unused
    compute_statistics = common.GetSoleValue(inputs,
                                             labels.COMPUTE_STATISTICS_LABEL)
    transform_output_path = common.GetSoleValue(
        outputs, labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL)
    raw_examples_data_format = common.GetSoleValue(
        inputs, labels.EXAMPLES_DATA_FORMAT_LABEL)
    schema = common.GetSoleValue(inputs, labels.SCHEMA_PATH_LABEL)
    input_dataset_metadata = self._ReadMetadata(raw_examples_data_format,
                                                schema)

    tf.logging.info('Inputs to executor.Transform function: {}'.format(inputs))
    tf.logging.info(
        'Outputs to executor.Transform function: {}'.format(outputs))

    feature_spec = schema_utils.schema_as_feature_spec(
        _GetSchemaProto(input_dataset_metadata)).feature_spec

    # NOTE: We disallow an empty schema, which we detect by testing the
    # number of columns.  While in principle an empty schema is valid, in
    # practice this is a sign of a user error, and this is a convenient
    # place to catch that error.
    if (not feature_spec and
        not self._ShouldDecodeAsRawExample(raw_examples_data_format)):
      raise ValueError(messages.SCHEMA_EMPTY)

    preprocessing_fn = self._GetPreprocessingFn(inputs, outputs)

    materialize_output_paths = common.GetValues(
        outputs, labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL)

    # Inspecting the preprocessing_fn even if we know we need a full pass in
    # order to fail faster if it fails.
    try:
      analyze_input_columns = tft.get_analyze_input_columns(
          preprocessing_fn, feature_spec)
    except AttributeError:
      # If using TFT 1.12, fall back to assuming all features are used.
      analyze_input_columns = feature_spec.keys()

    if not compute_statistics and not materialize_output_paths:
      if analyze_input_columns:
        tf.logging.warning(
            'Not using the in-place Transform because the following features '
            'require analyzing: {}'.format(tuple(analyze_input_columns)))
      else:
        tf.logging.warning(
            'Using the in-place Transform since compute_statistics=False, '
            'it does not materialize transformed data, and the configured '
            'preprocessing_fn appears to not require analyzing the data.')
        self._RunInPlaceImpl(preprocessing_fn, input_dataset_metadata,
                             transform_output_path)
        # TODO(b/122478841): Writes status to status file.
        return
    self._RunBeamImpl(inputs, outputs, preprocessing_fn, input_dataset_metadata,
                      raw_examples_data_format, transform_output_path,
                      compute_statistics, materialize_output_paths)
Esempio n. 6
0
    def _RunBeamImpl(self, inputs, outputs, preprocessing_fn,
                     input_dataset_metadata, raw_examples_data_format,
                     transform_output_path, compute_statistics,
                     materialize_output_paths):
        """Perform data preprocessing with a Beam pipeline.

        The pipeline (created by self._CreatePipeline) analyzes the
        analyze-and-transform datasets and writes the raw metadata and the
        transform_fn under transform_output_path. When compute_statistics or
        materialize_output_paths is set, it also transforms every dataset to
        compute pre/post-transform statistics and/or materialize the
        transformed examples.

        Args:
          inputs: A dictionary of labelled input values.
          outputs: A dictionary of labelled output values.
          preprocessing_fn: The tf.Transform preprocessing_fn.
          input_dataset_metadata: A DatasetMetadata object for the input data.
          raw_examples_data_format: A string describing the raw data format.
          transform_output_path: An absolute path to write the output to.
          compute_statistics: A bool indicating whether or not compute
            statistics.
          materialize_output_paths: Paths to materialized outputs.

        Raises:
          AssertionError: If per-set statistics output paths (when computing
            statistics) or materialization output paths are given, but their
            number differs from the number of transform datasets.

        Returns:
          Status of the execution.
        """
        raw_examples_file_format = common.GetSoleValue(
            inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
        analyze_and_transform_data_paths = common.GetValues(
            inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
        transform_only_data_paths = common.GetValues(
            inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
        stats_use_tfdv = common.GetSoleValue(
            inputs, labels.TFT_STATISTICS_USE_TFDV_LABEL)
        per_set_stats_output_paths = common.GetValues(
            outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
        temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

        tf.logging.info('Analyze and transform data patterns: %s',
                        list(enumerate(analyze_and_transform_data_paths)))
        tf.logging.info('Transform data patterns: %s',
                        list(enumerate(transform_only_data_paths)))
        tf.logging.info('Transform materialization output paths: %s',
                        list(enumerate(materialize_output_paths)))
        tf.logging.info('Transform output path: %s', transform_output_path)

        feature_spec = input_dataset_metadata.schema.as_feature_spec()
        try:
            analyze_input_columns = tft.get_analyze_input_columns(
                preprocessing_fn, feature_spec)
            transform_input_columns = (tft.get_transform_input_columns(
                preprocessing_fn, feature_spec))
        except AttributeError:
            # If using TFT 1.12, fall back to assuming all features are used.
            analyze_input_columns = feature_spec.keys()
            transform_input_columns = feature_spec.keys()
        # Use the same dataset (same columns) for AnalyzeDataset and computing
        # pre-transform stats so that the data will only be read once for these
        # two operations.
        if compute_statistics:
            analyze_input_columns = list(
                set(
                    list(analyze_input_columns) +
                    list(transform_input_columns)))
        # Work on copies so the caller's metadata is not mutated when the
        # schemas are narrowed below.
        analyze_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
        transform_input_dataset_metadata = copy.deepcopy(
            input_dataset_metadata)
        if input_dataset_metadata.schema is not _RAW_EXAMPLE_SCHEMA:
            # Restrict each schema to the columns that phase actually reads.
            analyze_input_dataset_metadata.schema = (
                dataset_schema.from_feature_spec({
                    feature: feature_spec[feature]
                    for feature in analyze_input_columns
                }))
            transform_input_dataset_metadata.schema = (
                dataset_schema.from_feature_spec({
                    feature: feature_spec[feature]
                    for feature in transform_input_columns
                }))

        # Datasets may only be merged when nothing downstream needs them
        # individually (per-set stats or per-set materialization).
        can_process_jointly = not bool(per_set_stats_output_paths
                                       or materialize_output_paths)
        analyze_data_list = self._MakeDatasetList(
            analyze_and_transform_data_paths, raw_examples_file_format,
            raw_examples_data_format, analyze_input_dataset_metadata,
            can_process_jointly)
        transform_data_list = self._MakeDatasetList(
            list(analyze_and_transform_data_paths) +
            list(transform_only_data_paths), raw_examples_file_format,
            raw_examples_data_format, transform_input_dataset_metadata,
            can_process_jointly)

        desired_batch_size = self._GetDesiredBatchSize(
            raw_examples_data_format)

        with self._CreatePipeline(outputs) as p:
            with tft_beam.Context(
                    temp_dir=temp_path,
                    desired_batch_size=desired_batch_size,
                    passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
                    use_deep_copy_optimization=True):
                # pylint: disable=expression-not-assigned
                # pylint: disable=no-value-for-parameter

                analyze_decode_fn = (self._GetDecodeFunction(
                    raw_examples_data_format,
                    analyze_input_dataset_metadata.schema))

                for (idx, dataset) in enumerate(analyze_data_list):
                    dataset.encoded = (p
                                       | 'ReadAnalysisDataset[{}]'.format(idx)
                                       >> self._ReadExamples(dataset))
                    dataset.decoded = (
                        dataset.encoded
                        | 'DecodeAnalysisDataset[{}]'.format(idx) >>
                        self._DecodeInputs(analyze_decode_fn))

                # Passing the pipeline lets Flatten handle an empty list.
                input_analysis_data = (
                    [dataset.decoded for dataset in analyze_data_list]
                    | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=p))
                transform_fn = ((input_analysis_data, input_dataset_metadata)
                                | 'AnalyzeDataset' >>
                                tft_beam.AnalyzeDataset(preprocessing_fn))
                # Write the raw/input metadata.
                (input_dataset_metadata
                 | 'WriteMetadata' >> tft_beam.WriteMetadata(
                     os.path.join(transform_output_path,
                                  tft.TFTransformOutput.RAW_METADATA_DIR), p))

                # WriteTransformFn writes transform_fn and metadata to
                # subdirectories tensorflow_transform.SAVED_MODEL_DIR and
                # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
                (transform_fn | 'WriteTransformFn' >>
                 tft_beam.WriteTransformFn(transform_output_path))

                if compute_statistics or materialize_output_paths:
                    # Do not compute pre-transform stats if the input format is
                    # raw proto, as StatsGen would treat any input as
                    # tf.Example.
                    if (compute_statistics and not self._IsDataFormatProto(
                            raw_examples_data_format)):
                        # Aggregated feature stats before transformation.
                        pre_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            PRE_TRANSFORM_FEATURE_STATS_PATH)

                        # TODO(b/70392441): Retain tf.Metadata (e.g.,
                        # IntDomain) in schema. Currently input dataset schema
                        # only contains dtypes, and other metadata is dropped
                        # due to roundtrip to tensors.
                        schema_proto = schema_utils.schema_from_feature_spec(
                            analyze_input_dataset_metadata.schema.
                            as_feature_spec())
                        ([
                            dataset.decoded
                            if stats_use_tfdv else dataset.encoded
                            for dataset in analyze_data_list
                        ]
                         | 'FlattenPreTransformAnalysisDatasets' >>
                         beam.Flatten(pipeline=p)
                         | 'GenerateAggregatePreTransformAnalysisStats' >>
                         self._GenerateStats(pre_transform_feature_stats_path,
                                             schema_proto,
                                             use_deep_copy_optimization=True,
                                             use_tfdv=stats_use_tfdv))

                    transform_decode_fn = (self._GetDecodeFunction(
                        raw_examples_data_format,
                        transform_input_dataset_metadata.schema))
                    # transform_data_list is a superset of analyze_data_list,
                    # we pay the cost to read the same dataset
                    # (analyze_data_list) again here to prevent certain beam
                    # runner from doing large temp materialization.
                    for (idx, dataset) in enumerate(transform_data_list):
                        dataset.encoded = (
                            p
                            | 'ReadTransformDataset[{}]'.format(idx) >>
                            self._ReadExamples(dataset))
                        dataset.decoded = (
                            dataset.encoded
                            | 'DecodeTransformDataset[{}]'.format(idx) >>
                            self._DecodeInputs(transform_decode_fn))
                        (dataset.transformed, metadata) = (
                            ((dataset.decoded,
                              transform_input_dataset_metadata), transform_fn)
                            | 'TransformDataset[{}]'.format(idx) >>
                            tft_beam.TransformDataset())

                        # Encoded output is needed for materialization and for
                        # non-TFDV statistics.
                        if materialize_output_paths or not stats_use_tfdv:
                            dataset.transformed_and_encoded = (
                                dataset.transformed
                                | 'EncodeTransformedDataset[{}]'.format(idx) >>
                                beam.ParDo(self._EncodeAsExamples(), metadata))

                    if compute_statistics:
                        # Aggregated feature stats after transformation.
                        _, metadata = transform_fn
                        post_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            POST_TRANSFORM_FEATURE_STATS_PATH)

                        # TODO(b/70392441): Retain tf.Metadata (e.g.,
                        # IntDomain) in schema. Currently input dataset schema
                        # only contains dtypes, and other metadata is dropped
                        # due to roundtrip to tensors.
                        transformed_schema_proto = (
                            schema_utils.schema_from_feature_spec(
                                metadata.schema.as_feature_spec()))

                        ([(dataset.transformed if stats_use_tfdv else
                           dataset.transformed_and_encoded)
                          for dataset in transform_data_list]
                         | 'FlattenPostTransformAnalysisDatasets' >>
                         beam.Flatten(pipeline=p)
                         | 'GenerateAggregatePostTransformAnalysisStats' >>
                         self._GenerateStats(post_transform_feature_stats_path,
                                             transformed_schema_proto,
                                             use_tfdv=stats_use_tfdv))

                        if per_set_stats_output_paths:
                            assert len(transform_data_list) == len(
                                per_set_stats_output_paths)
                            # TODO(b/67632871): Remove duplicate stats gen
                            # compute that is done both on a flattened view of
                            # the data, and on each span below.
                            bundles = zip(transform_data_list,
                                          per_set_stats_output_paths)
                            for (idx, (dataset,
                                       output_path)) in enumerate(bundles):
                                if stats_use_tfdv:
                                    data = dataset.transformed
                                else:
                                    data = dataset.transformed_and_encoded
                                (data
                                 | 'GeneratePostTransformStats[{}]'.format(idx)
                                 >> self._GenerateStats(
                                     output_path,
                                     transformed_schema_proto,
                                     use_tfdv=stats_use_tfdv))

                    if materialize_output_paths:
                        assert len(transform_data_list) == len(
                            materialize_output_paths)
                        bundles = zip(transform_data_list,
                                      materialize_output_paths)
                        for (idx, (dataset,
                                   output_path)) in enumerate(bundles):
                            (dataset.transformed_and_encoded
                             | 'Materialize[{}]'.format(idx) >>
                             self._WriteExamples(raw_examples_file_format,
                                                 output_path))

        return _Status.OK()