Example #1
0
    def expand(self, pipeline):
        """Builds the read -> decode -> analyze -> transform benchmark graph.

        Raw records from `self._dataset.read_raw_dataset()` are decoded with an
        `ExampleProtoCoder` built from `self._tf_metadata_schema`, analyzed with
        `self._preprocessing_fn`, and then transformed with the resulting
        transform_fn. When `self._generate_dataset` is truthy, the transform_fn
        is additionally copied to `self._dataset.tft_saved_model_path()`.

        Args:
          pipeline: The input this transform is applied to; it is the root the
            raw records are created from via `beam.Create`.

        Returns:
          A `(transformed_dataset, transformed_metadata)` tuple as produced by
          `tft_beam.TransformDataset`.
        """
        # TODO(b/147620802): Consider making this (and other parameters)
        # configurable to test more variants (e.g. with and without deep-copy
        # optimisation, with and without cache, etc).
        # NOTE(review): the directory created by mkdtemp() is not removed here;
        # it has to outlive this method because the pipeline runs later.
        with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
            example_coder = tft.coders.ExampleProtoCoder(
                self._tf_metadata_schema, serialized=False)

            raw_records = pipeline | "ReadDataset" >> beam.Create(
                self._dataset.read_raw_dataset())
            raw_data = raw_records | "Decode" >> beam.Map(example_coder.decode)
            raw_dataset = (raw_data, self._transform_input_dataset_metadata)

            transform_fn, output_metadata = (
                raw_dataset
                | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(
                    self._preprocessing_fn))

            if self._generate_dataset:
                saved_model_dir = self._dataset.tft_saved_model_path()
                _ = transform_fn | "CopySavedModel" >> _CopySavedModel(
                    dest_path=saved_model_dir)

            transformed = (
                (raw_dataset, (transform_fn, output_metadata))
                | "TransformDataset" >> tft_beam.TransformDataset())
        transformed_dataset, transformed_metadata = transformed
        return transformed_dataset, transformed_metadata
Example #2
0
def _main(argv=None):
    """Analyzes raw TFRecord examples and writes the resulting transform_fn.

    Required command-line flags:
      --raw_examples_path: file pattern read with `beam.io.ReadFromTFRecord`.
      --raw_examples_schema_path: schema location passed to `load_schema`.
      --preprocessing_module_path: Python file whose `preprocessing_fn` is used.
      --transform_fn_dir: output directory for `tft_beam.WriteTransformFn`.

    Flags not recognized above are forwarded to Beam as pipeline options.

    Args:
      argv: Optional argument list; `None` lets argparse read `sys.argv[1:]`.
    """
    logging.getLogger().setLevel(logging.INFO)

    arg_parser = argparse.ArgumentParser()
    for flag in ('--raw_examples_path',
                 '--raw_examples_schema_path',
                 '--preprocessing_module_path',
                 '--transform_fn_dir'):
        arg_parser.add_argument(flag, required=True)
    args, beam_args = arg_parser.parse_known_args(argv)

    schema = load_schema(args.raw_examples_schema_path)
    example_coder = tft.coders.ExampleProtoCoder(schema)
    input_metadata = dataset_metadata.DatasetMetadata(schema)

    preprocessing_module = load_module_from_file_path(
        'tft_preprocessing', args.preprocessing_module_path)
    preprocessing_fn = preprocessing_module.preprocessing_fn

    options = PipelineOptions(beam_args)
    # Pickle the main session so module-level state is available on workers.
    options.view_as(SetupOptions).save_main_session = True

    with beam.Pipeline(options=options) as p, \
            tft_beam.Context(temp_dir=get_beam_temp_dir(options)):
        examples = p | 'ReadRawExamples' >> beam.io.ReadFromTFRecord(
            args.raw_examples_path, coder=example_coder)
        transform_fn = (examples, input_metadata) | tft_beam.AnalyzeDataset(
            preprocessing_fn)
        _ = transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
            args.transform_fn_dir)
Example #3
0
  def expand(self, pipeline):
    """Builds the read -> decode -> analyze -> transform benchmark graph.

    Serialized records (at most `self._max_num_examples`) are decoded through
    `self._tfxio.BeamSource()`, analyzed with `self._preprocessing_fn`, and
    transformed with the resulting transform_fn. When
    `self._generate_dataset` is truthy, the transform_fn is also copied to
    `self._dataset.tft_saved_model_path(self._force_tf_compat_v1)`.

    Args:
      pipeline: The input this transform is applied to; it is the root the
        raw records are created from via `beam.Create`.

    Returns:
      A `(transformed_dataset, transformed_metadata)` tuple as produced by
      `tft_beam.TransformDataset`.
    """
    # TODO(b/147620802): Consider making this (and other parameters)
    # configurable to test more variants (e.g. with and without deep-copy
    # optimisation, with and without cache, etc).
    # NOTE(review): the directory created by mkdtemp() is not removed here;
    # it has to outlive this method because the pipeline runs later.
    with tft_beam.Context(temp_dir=tempfile.mkdtemp(),
                          force_tf_compat_v1=self._force_tf_compat_v1):
      serialized = pipeline | "ReadDataset" >> beam.Create(
          self._dataset.read_raw_dataset(
              deserialize=False, limit=self._max_num_examples))
      raw_data = serialized | "Decode" >> self._tfxio.BeamSource()

      analysis_input = (raw_data, self._tfxio.TensorAdapterConfig())
      transform_fn, output_metadata = (
          analysis_input
          | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self._preprocessing_fn))

      if self._generate_dataset:
        saved_model_dir = self._dataset.tft_saved_model_path(
            self._force_tf_compat_v1)
        _ = transform_fn | "CopySavedModel" >> _CopySavedModel(
            dest_path=saved_model_dir)

      transform_input = (raw_data, self._tfxio.TensorAdapterConfig())
      transformed = (
          (transform_input, (transform_fn, output_metadata))
          | "TransformDataset" >> tft_beam.TransformDataset())
    transformed_dataset, transformed_metadata = transformed
    return transformed_dataset, transformed_metadata
Example #4
0
def transform_train_and_eval(pipeline, train_data, eval_data, data_source,
                             transform_dir, output_dir, schema):
  """Analyzes the training data and transforms both train and eval data.

  A new transform function is always computed from the training data (it is
  never loaded from `transform_dir`). It is written to `transform_dir` and
  then applied to both the training and the evaluation data.

  Args:
    pipeline: Beam Pipeline instance.
    train_data: Training data location, read with `ReadData` according to
      `data_source` (presumably a CSV path or BigQuery table -- confirm).
    eval_data: Evaluation data location, read the same way as `train_data`.
    data_source: Kind of input data source; expects either `csv` or
      `bigquery`.
    transform_dir: Directory the computed `transform_fn` is written to with
      `tft_beam.WriteTransformFn`.
    output_dir: Directory passed to `transform_and_write` for the transformed
      train (`_TRAIN_PREFIX`) and eval (`_EVAL_PREFIX`) output.
    schema: A text-serialized TensorFlow metadata schema for the input data.
  """
  train_raw_data = (
      pipeline | 'ReadTrainData' >> ReadData(train_data, data_source, schema,
                                             tf.estimator.ModeKeys.TRAIN))
  eval_raw_data = (
      pipeline | 'ReadEvalData' >> ReadData(eval_data, data_source, schema,
                                            tf.estimator.ModeKeys.EVAL))

  # Keep the parsed analysis schema in its own name so the `schema` argument
  # (passed to ReadData above) is not shadowed.
  input_schema = utils.make_dataset_schema(
      schema, mode=tf.estimator.ModeKeys.TRAIN)
  input_metadata = dataset_metadata.DatasetMetadata(input_schema)

  logger.info('Creating new transform model.')
  transform_fn = ((train_raw_data, input_metadata)
                  | 'Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn))

  _ = transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
      transform_dir)

  _ = (train_raw_data
       | 'TransformAndWriteTraining' >> transform_and_write(
           input_metadata, output_dir, transform_fn, _TRAIN_PREFIX))
  _ = (eval_raw_data
       | 'TransformAndWriteEval' >> transform_and_write(
           input_metadata, output_dir, transform_fn, _EVAL_PREFIX))
Example #5
0
def transform_data(input_handle,
                   outfile_prefix,
                   working_dir,
                   schema_file,
                   transform_dir=None,
                   max_rows=None,
                   pipeline_args=None):
    """The main tf.transform method which analyzes and transforms data.

    Reads raw rows from a CSV file or a BigQuery query, obtains a transform
    function (by analyzing the data, or by loading it from `transform_dir`),
    applies it to a shuffled copy of the raw data, and writes the serialized
    transformed examples as TFRecords under `working_dir/outfile_prefix` with
    a '.gz' file suffix.

    Args:
      input_handle: BigQuery table name to process specified as DATASET.TABLE
        or path to csv file with input data. A handle whose lower-cased form
        ends in 'csv' is read as CSV; anything else is queried from BigQuery.
      outfile_prefix: Filename prefix for emitted transformed examples.
      working_dir: Directory in which transformed examples and transform
        function will be emitted; also used as the tf.Transform temp dir.
      schema_file: A file path that contains a text-serialized TensorFlow
        metadata schema of the input data.
      transform_dir: Directory in which the transform output is located. If
        provided, this will load the transform_fn from disk instead of
        computing it over the data. Hint: this is useful for transforming
        eval data.
      max_rows: Number of rows to query from BigQuery (unused for CSV input).
      pipeline_args: additional DataflowRunner or DirectRunner args passed to
        the beam pipeline.
    """
    def preprocessing_fn(inputs):
        """tf.transform's callback function for preprocessing inputs.

        Args:
          inputs: map from feature keys to raw not-yet-transformed features.

        Returns:
          Map from string feature key to transformed feature operations.
        """
        outputs = {}
        for key in taxi.DENSE_FLOAT_FEATURE_KEYS:
            # Preserve this feature as a dense float, setting nan's to the mean.
            outputs[taxi.transformed_name(key)] = transform.scale_to_z_score(
                _fill_in_missing(inputs[key]))

        for key in taxi.VOCAB_FEATURE_KEYS:
            # Build a vocabulary for this feature.
            outputs[taxi.transformed_name(
                key)] = transform.compute_and_apply_vocabulary(
                    _fill_in_missing(inputs[key]),
                    top_k=taxi.VOCAB_SIZE,
                    num_oov_buckets=taxi.OOV_SIZE)

        for key in taxi.BUCKET_FEATURE_KEYS:
            outputs[taxi.transformed_name(key)] = transform.bucketize(
                _fill_in_missing(inputs[key]), taxi.FEATURE_BUCKET_COUNT)

        for key in taxi.CATEGORICAL_FEATURE_KEYS:
            outputs[taxi.transformed_name(key)] = _fill_in_missing(inputs[key])

        # Was this passenger a big tipper?
        taxi_fare = _fill_in_missing(inputs[taxi.FARE_KEY])
        tips = _fill_in_missing(inputs[taxi.LABEL_KEY])
        outputs[taxi.transformed_name(taxi.LABEL_KEY)] = tf.where(
            tf.is_nan(taxi_fare),
            tf.cast(tf.zeros_like(taxi_fare), tf.int64),
            # Test if the tip was > 20% of the fare.
            tf.cast(tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
                    tf.int64))

        return outputs

    schema = taxi.read_schema(schema_file)
    raw_feature_spec = taxi.get_raw_feature_spec(schema)
    raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
    raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

    with beam.Pipeline(argv=pipeline_args) as pipeline:
        with tft_beam.Context(temp_dir=working_dir):
            if input_handle.lower().endswith('csv'):
                csv_coder = taxi.make_csv_coder(schema)
                raw_data = (pipeline
                            | 'ReadFromText' >> beam.io.ReadFromText(
                                input_handle, skip_header_lines=1))
                decode_transform = beam.Map(csv_coder.decode)
            else:
                query = taxi.make_sql(input_handle, max_rows, for_eval=False)
                raw_data = (pipeline
                            | 'ReadBigQuery' >> beam.io.Read(
                                beam.io.BigQuerySource(query=query,
                                                       use_standard_sql=True)))
                decode_transform = beam.Map(taxi.clean_raw_data_dict,
                                            raw_feature_spec=raw_feature_spec)

            if transform_dir is None:
                decoded_data = raw_data | 'DecodeForAnalyze' >> decode_transform
                transform_fn = (
                    (decoded_data, raw_data_metadata) |
                    ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

                _ = (transform_fn
                     | ('WriteTransformFn' >>
                        tft_beam.WriteTransformFn(working_dir)))
            else:
                transform_fn = pipeline | tft_beam.ReadTransformFn(
                    transform_dir)

            # Shuffling the data before materialization will improve Training
            # effectiveness downstream. Here we shuffle the raw_data (as opposed
            # to decoded data) since it has a compact representation.
            shuffled_data = (
                raw_data | 'RandomizeData' >> beam.transforms.Reshuffle())

            decoded_data = shuffled_data | 'DecodeForTransform' >> decode_transform
            (transformed_data, transformed_metadata) = (
                ((decoded_data, raw_data_metadata), transform_fn)
                | 'Transform' >> tft_beam.TransformDataset())

            coder = example_proto_coder.ExampleProtoCoder(
                transformed_metadata.schema)
            _ = (transformed_data
                 | 'SerializeExamples' >> beam.Map(coder.encode)
                 | 'WriteExamples' >> beam.io.WriteToTFRecord(
                     os.path.join(working_dir, outfile_prefix),
                     file_name_suffix='.gz'))
Example #6
0
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Re-analyzes and transforms the train/eval examples with a saved graph.

    The transform graph in `input_dict[TRANSFORM_GRAPH_KEY]` is loaded with
    `tft.TFTransformOutput`, and its `transform_raw_features` is used as the
    preprocessing function for `tft_beam.AnalyzeDataset` over the 'train'
    split of the examples. The resulting transform_fn is then applied to the
    'train' and 'eval' splits, and the transformed, serialized examples are
    written to the matching splits of `output_dict[TRANSFORMED_EXAMPLES_KEY]`.

    Args:
      input_dict: Input dict from input key to a list of artifacts. Keys read:
        - TRANSFORM_GRAPH_KEY: a single artifact holding the transform graph;
          the Beam/TFT temp directory (`_TEMP_DIR_IN_TRANSFORM_OUTPUT`) is
          placed under its URI.
        - SCHEMA_KEY: a single artifact whose directory holds the text-format
          `schema_pb2.Schema` of the raw examples (located with
          `io_utils.get_only_uri_in_dir`).
        - EXAMPLES_KEY: raw examples with 'train' and 'eval' splits.
      output_dict: Output dict from key to a list of artifacts. Keys read:
        - TRANSFORMED_EXAMPLES_KEY: artifacts whose 'train' and 'eval' split
          URIs receive the materialized examples, written with the
          `_DEFAULT_TRANSFORMED_EXAMPLES_PREFIX` file prefix.
      exec_properties: A dict of execution properties; this method only
        forwards it to `self._log_startup`.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)
    transform_graph_uri = artifact_utils.get_single_uri(
        input_dict[TRANSFORM_GRAPH_KEY])
    temp_path = os.path.join(transform_graph_uri, _TEMP_DIR_IN_TRANSFORM_OUTPUT)
    transformed_train_output = artifact_utils.get_split_uri(
        output_dict[TRANSFORMED_EXAMPLES_KEY], 'train')
    transformed_eval_output = artifact_utils.get_split_uri(
        output_dict[TRANSFORMED_EXAMPLES_KEY], 'eval')

    # Previously fitted transform graph; its `transform_raw_features` is used
    # as the preprocessing_fn for the analysis below.
    tf_transform_output = tft.TFTransformOutput(transform_graph_uri)

    schema_file = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(input_dict[SCHEMA_KEY]))
    schema_proto = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
    transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
        schema_proto)

    train_data_uri = artifact_utils.get_split_uri(
        input_dict[EXAMPLES_KEY], 'train')
    eval_data_uri = artifact_utils.get_split_uri(
        input_dict[EXAMPLES_KEY], 'eval')
    # Only the train split is analyzed; both splits are transformed and
    # materialized (paired positionally with `materialize_output_paths`).
    analyze_data_paths = [io_utils.all_files_pattern(train_data_uri)]
    transform_data_paths = [
        io_utils.all_files_pattern(train_data_uri),
        io_utils.all_files_pattern(eval_data_uri),
    ]
    materialize_output_paths = [
        os.path.join(transformed_train_output,
                     _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX),
        os.path.join(transformed_eval_output,
                     _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX),
    ]
    transform_data_list = self._MakeDatasetList(
        transform_data_paths, materialize_output_paths)
    analyze_data_list = self._MakeDatasetList(analyze_data_paths)

    with self._make_beam_pipeline() as pipeline:
      with tft_beam.Context(temp_dir=temp_path):
        # NOTE: Unclear if there is a difference between input_dataset_metadata
        # and transform_input_dataset_metadata. Look at Transform executor.
        decode_fn = tft.coders.ExampleProtoCoder(
            schema_proto, serialized=True).decode

        input_analysis_data = {}
        for dataset in analyze_data_list:
          infix = 'AnalysisIndex{}'.format(dataset.index)
          dataset.serialized = (
              pipeline
              | 'ReadDataset[{}]'.format(infix) >> self._ReadExamples(
                  dataset, transform_input_dataset_metadata))
          dataset.decoded = (
              dataset.serialized
              | 'Decode[{}]'.format(infix) >> self._DecodeInputs(decode_fn))
          input_analysis_data[dataset.dataset_key] = dataset.decoded

        # When `tft_beam.analyzer_cache` has no `DatasetKey` (presumably an
        # older tf.Transform -- confirm), the per-dataset dict is replaced by a
        # single flattened PCollection.
        if not hasattr(tft_beam.analyzer_cache, 'DatasetKey'):
          input_analysis_data = (
              [
                  dataset for dataset in input_analysis_data.values()
                  if dataset is not None
              ]
              | 'FlattenAnalysisDatasetsBecauseItIsRequired' >>
              beam.Flatten(pipeline=pipeline))

        transform_fn = (
            (input_analysis_data, transform_input_dataset_metadata)
            | 'Analyze' >> tft_beam.AnalyzeDataset(
                tf_transform_output.transform_raw_features, pipeline=pipeline))

        for dataset in transform_data_list:
          infix = 'TransformIndex{}'.format(dataset.index)
          dataset.serialized = (
              pipeline
              | 'ReadDataset[{}]'.format(infix) >> self._ReadExamples(
                  dataset, transform_input_dataset_metadata))

          dataset.decoded = (
              dataset.serialized
              | 'Decode[{}]'.format(infix) >> self._DecodeInputs(decode_fn))

          dataset.transformed, metadata = (
              ((dataset.decoded, transform_input_dataset_metadata),
               transform_fn)
              | 'Transform[{}]'.format(infix) >> tft_beam.TransformDataset())

          dataset.transformed_and_serialized = (
              dataset.transformed
              | 'EncodeAndSerialize[{}]'.format(infix) >> beam.ParDo(
                  self._EncodeAsSerializedExamples(),
                  _GetSchemaProto(metadata)))

          _ = (
              dataset.transformed_and_serialized
              | 'Materialize[{}]'.format(infix) >> self._WriteExamples(
                  dataset.materialize_output_path))
Example #7
0
    def _RunBeamImpl(self, inputs, outputs, preprocessing_fn,
                     input_dataset_metadata, raw_examples_data_format,
                     transform_output_path, compute_statistics,
                     materialize_output_paths):
        """Perform data preprocessing with the pipeline from `_CreatePipeline`.

        Analyzes the analyze-and-transform datasets with `preprocessing_fn`,
        writes the raw metadata and the resulting transform_fn under
        `transform_output_path`, and, when `compute_statistics` or
        `materialize_output_paths` is set, transforms every dataset and then
        computes statistics and/or writes the transformed examples.

        Args:
          inputs: A dictionary of labelled input values.
          outputs: A dictionary of labelled output values.
          preprocessing_fn: The tf.Transform preprocessing_fn.
          input_dataset_metadata: A DatasetMetadata object for the input data.
          raw_examples_data_format: A string describing the raw data format.
          transform_output_path: An absolute path to write the output to.
          compute_statistics: A bool indicating whether or not to compute
            statistics.
          materialize_output_paths: Paths to materialized outputs; when
            non-empty, one entry per transform dataset.

        Raises:
          AssertionError: If `per_set_stats_output_paths` or
            `materialize_output_paths` does not match the number of transform
            datasets (checked with `assert`, so skipped under `python -O`).

        NOTE(review): an earlier version of this docstring also listed
        RuntimeError (reset() not invoked between two run()) and ValueError
        (empty schema); neither is raised directly in this body -- confirm
        whether a helper raises them.

        Returns:
          `_Status.OK()` once the pipeline context has exited.
        """
        raw_examples_file_format = common.GetSoleValue(
            inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
        analyze_and_transform_data_paths = common.GetValues(
            inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
        transform_only_data_paths = common.GetValues(
            inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
        stats_use_tfdv = common.GetSoleValue(
            inputs, labels.TFT_STATISTICS_USE_TFDV_LABEL)
        per_set_stats_output_paths = common.GetValues(
            outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
        temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

        tf.logging.info('Analyze and transform data patterns: %s',
                        list(enumerate(analyze_and_transform_data_paths)))
        tf.logging.info('Transform data patterns: %s',
                        list(enumerate(transform_only_data_paths)))
        tf.logging.info('Transform materialization output paths: %s',
                        list(enumerate(materialize_output_paths)))
        tf.logging.info('Transform output path: %s', transform_output_path)

        # Determine which raw columns the analyze and transform phases read,
        # so each phase's decoder only needs those columns.
        feature_spec = input_dataset_metadata.schema.as_feature_spec()
        try:
            analyze_input_columns = tft.get_analyze_input_columns(
                preprocessing_fn, feature_spec)
            transform_input_columns = (tft.get_transform_input_columns(
                preprocessing_fn, feature_spec))
        except AttributeError:
            # If using TFT 1.12, fall back to assuming all features are used.
            analyze_input_columns = feature_spec.keys()
            transform_input_columns = feature_spec.keys()
        # Use the same dataset (same columns) for AnalyzeDataset and computing
        # pre-transform stats so that the data will only be read once for these
        # two operations.
        if compute_statistics:
            analyze_input_columns = list(
                set(
                    list(analyze_input_columns) +
                    list(transform_input_columns)))
        analyze_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
        transform_input_dataset_metadata = copy.deepcopy(
            input_dataset_metadata)
        # For the _RAW_EXAMPLE_SCHEMA sentinel the copies keep the full input
        # metadata; otherwise each copy is narrowed to its phase's columns.
        if input_dataset_metadata.schema is not _RAW_EXAMPLE_SCHEMA:
            analyze_input_dataset_metadata.schema = dataset_schema.from_feature_spec(
                {
                    feature: feature_spec[feature]
                    for feature in analyze_input_columns
                })
            transform_input_dataset_metadata.schema = (
                dataset_schema.from_feature_spec({
                    feature: feature_spec[feature]
                    for feature in transform_input_columns
                }))

        # Datasets may only be processed jointly when no per-dataset output
        # (per-set stats or materialization) has been requested.
        can_process_jointly = not bool(per_set_stats_output_paths
                                       or materialize_output_paths)
        analyze_data_list = self._MakeDatasetList(
            analyze_and_transform_data_paths, raw_examples_file_format,
            raw_examples_data_format, analyze_input_dataset_metadata,
            can_process_jointly)
        transform_data_list = self._MakeDatasetList(
            list(analyze_and_transform_data_paths) +
            list(transform_only_data_paths), raw_examples_file_format,
            raw_examples_data_format, transform_input_dataset_metadata,
            can_process_jointly)

        desired_batch_size = self._GetDesiredBatchSize(
            raw_examples_data_format)

        with self._CreatePipeline(outputs) as p:
            with tft_beam.Context(
                    temp_dir=temp_path,
                    desired_batch_size=desired_batch_size,
                    passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
                    use_deep_copy_optimization=True):
                # pylint: disable=expression-not-assigned
                # pylint: disable=no-value-for-parameter

                analyze_decode_fn = (self._GetDecodeFunction(
                    raw_examples_data_format,
                    analyze_input_dataset_metadata.schema))

                for (idx, dataset) in enumerate(analyze_data_list):
                    dataset.encoded = (p
                                       | 'ReadAnalysisDataset[{}]'.format(idx)
                                       >> self._ReadExamples(dataset))
                    dataset.decoded = (
                        dataset.encoded
                        | 'DecodeAnalysisDataset[{}]'.format(idx) >>
                        self._DecodeInputs(analyze_decode_fn))

                # All analysis datasets are flattened into one PCollection.
                input_analysis_data = (
                    [dataset.decoded for dataset in analyze_data_list]
                    | 'FlattenAnalysisDatasets' >> beam.Flatten())
                transform_fn = ((input_analysis_data, input_dataset_metadata)
                                | 'AnalyzeDataset' >>
                                tft_beam.AnalyzeDataset(preprocessing_fn))
                # Write the raw/input metadata.
                (input_dataset_metadata
                 | 'WriteMetadata' >> tft_beam.WriteMetadata(
                     os.path.join(transform_output_path,
                                  tft.TFTransformOutput.RAW_METADATA_DIR), p))

                # WriteTransformFn writes transform_fn and metadata to subdirectories
                # tensorflow_transform.SAVED_MODEL_DIR and
                # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
                (transform_fn | 'WriteTransformFn' >>
                 tft_beam.WriteTransformFn(transform_output_path))

                if compute_statistics or materialize_output_paths:
                    # Do not compute pre-transform stats if the input format is raw proto,
                    # as StatsGen would treat any input as tf.Example.
                    if (compute_statistics and not self._IsDataFormatProto(
                            raw_examples_data_format)):
                        # Aggregated feature stats before transformation.
                        pre_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            PRE_TRANSFORM_FEATURE_STATS_PATH)

                        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
                        # schema. Currently input dataset schema only contains dtypes,
                        # and other metadata is dropped due to roundtrip to tensors.
                        schema_proto = schema_utils.schema_from_feature_spec(
                            analyze_input_dataset_metadata.schema.
                            as_feature_spec())
                        ([
                            dataset.decoded
                            if stats_use_tfdv else dataset.encoded
                            for dataset in analyze_data_list
                        ]
                         | 'FlattenPreTransformAnalysisDatasets' >>
                         beam.Flatten()
                         | 'GenerateAggregatePreTransformAnalysisStats' >>
                         self._GenerateStats(pre_transform_feature_stats_path,
                                             schema_proto,
                                             use_deep_copy_optimization=True,
                                             use_tfdv=stats_use_tfdv))

                    transform_decode_fn = (self._GetDecodeFunction(
                        raw_examples_data_format,
                        transform_input_dataset_metadata.schema))
                    # transform_data_list is a superset of analyze_data_list, we pay the
                    # cost to read the same dataset (analyze_data_list) again here to
                    # prevent certain beam runner from doing large temp materialization.
                    for (idx, dataset) in enumerate(transform_data_list):
                        dataset.encoded = (
                            p
                            | 'ReadTransformDataset[{}]'.format(idx) >>
                            self._ReadExamples(dataset))
                        dataset.decoded = (
                            dataset.encoded
                            | 'DecodeTransformDataset[{}]'.format(idx) >>
                            self._DecodeInputs(transform_decode_fn))
                        (dataset.transformed, metadata) = (
                            ((dataset.decoded,
                              transform_input_dataset_metadata), transform_fn)
                            | 'TransformDataset[{}]'.format(idx) >>
                            tft_beam.TransformDataset())

                        # The encoded form is consumed below by materialization
                        # and by the non-TFDV statistics path.
                        if materialize_output_paths or not stats_use_tfdv:
                            dataset.transformed_and_encoded = (
                                dataset.transformed
                                | 'EncodeTransformedDataset[{}]'.format(idx) >>
                                beam.ParDo(self._EncodeAsExamples(), metadata))

                    if compute_statistics:
                        # Aggregated feature stats after transformation.
                        _, metadata = transform_fn
                        post_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            POST_TRANSFORM_FEATURE_STATS_PATH)

                        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
                        # schema. Currently input dataset schema only contains dtypes,
                        # and other metadata is dropped due to roundtrip to tensors.
                        transformed_schema_proto = schema_utils.schema_from_feature_spec(
                            metadata.schema.as_feature_spec())

                        ([(dataset.transformed if stats_use_tfdv else
                           dataset.transformed_and_encoded)
                          for dataset in transform_data_list]
                         | 'FlattenPostTransformAnalysisDatasets' >>
                         beam.Flatten()
                         | 'GenerateAggregatePostTransformAnalysisStats' >>
                         self._GenerateStats(post_transform_feature_stats_path,
                                             transformed_schema_proto,
                                             use_tfdv=stats_use_tfdv))

                        if per_set_stats_output_paths:
                            assert len(transform_data_list) == len(
                                per_set_stats_output_paths)
                            # TODO(b/67632871): Remove duplicate stats gen compute that is
                            # done both on a flattened view of the data, and on each span
                            # below.
                            bundles = zip(transform_data_list,
                                          per_set_stats_output_paths)
                            for (idx, (dataset,
                                       output_path)) in enumerate(bundles):
                                if stats_use_tfdv:
                                    data = dataset.transformed
                                else:
                                    data = dataset.transformed_and_encoded
                                (data
                                 | 'GeneratePostTransformStats[{}]'.format(idx)
                                 >> self._GenerateStats(
                                     output_path,
                                     transformed_schema_proto,
                                     use_tfdv=stats_use_tfdv))

                    if materialize_output_paths:
                        assert len(transform_data_list) == len(
                            materialize_output_paths)
                        bundles = zip(transform_data_list,
                                      materialize_output_paths)
                        for (idx, (dataset,
                                   output_path)) in enumerate(bundles):
                            (dataset.transformed_and_encoded
                             | 'Materialize[{}]'.format(idx) >>
                             self._WriteExamples(raw_examples_file_format,
                                                 output_path))

        return _Status.OK()
Example #8
0
def transform_data(input_handle,
                   outfile_prefix,
                   working_dir,
                   schema_file,
                   transform_dir=None,
                   max_rows=None,
                   pipeline_args=None):
  """The main tf.transform method which analyzes and transforms data.

  Reads raw rows either from a CSV file or from BigQuery, computes (or loads)
  a tf.transform `transform_fn`, applies it to the shuffled raw data and
  writes the transformed examples as gzipped TFRecord files.

  Args:
    input_handle: BigQuery table name to process specified as DATASET.TABLE or
      path to csv file with input data. Any value whose lower-cased form ends
      with 'csv' is treated as a CSV path; anything else as a BigQuery table.
    outfile_prefix: Filename prefix for emitted transformed examples.
    working_dir: Directory in which transformed examples and transform function
      will be emitted. Also used as the tf.transform temp directory.
    schema_file: A file path that contains a text-serialized TensorFlow
      metadata schema of the input data.
    transform_dir: Directory in which the transform output is located. If
      provided, this will load the transform_fn from disk instead of computing
      it over the data. Hint: this is useful for transforming eval data.
    max_rows: Number of rows to query from BigQuery. Ignored for CSV input.
    pipeline_args: additional DataflowRunner or DirectRunner args passed to the
      beam pipeline.
  """

  def transform_ngrams(text, ngram_range):
    """Splits `text` on spaces and returns the space-joined n-grams.

    Args:
      text: string tensor to tokenize.
      ngram_range: n-gram range forwarded unchanged to `transform.ngrams`.

    Returns:
      The SparseTensor produced by `transform.ngrams`.
    """
    # NOTE: wrapping the input in tf.Print for debugging was observed to make
    # the output concatenate with itself, so no debug printing is done here.
    # The result is a SparseTensor (indices / values / dense_shape), which
    # cannot be printed meaningfully as a single tensor anyway.
    return transform.ngrams(
        tf.string_split(text, delimiter=" "),
        ngram_range=ngram_range,
        separator=' ')

  def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.
    https://cloud.google.com/solutions/machine-learning/data-preprocessing-for-ml-with-tf-transform-pt2

    Args:
      inputs: map from feature keys to raw not-yet-transformed features.

    Returns:
      Map from string feature key to transformed feature operations.
    """
    outputs = {}
    for key in taxi.DENSE_FLOAT_FEATURE_KEYS:
      # Preserve this feature as a dense float, setting nan's to the mean.
      outputs[taxi.transformed_name(key)] = transform.scale_to_z_score(
          _fill_in_missing(inputs[key]))

    for key in taxi.VOCAB_FEATURE_KEYS:
      # Build a vocabulary for this feature.
      outputs[
          taxi.transformed_name(key)] = transform.compute_and_apply_vocabulary(
              _fill_in_missing(inputs[key]),
              top_k=taxi.VOCAB_SIZE,
              num_oov_buckets=taxi.OOV_SIZE)

    for key in taxi.FEATURE_NGRAM:
      # Extract n-grams from the space-separated text and build a vocabulary
      # over them.
      outputs[
          taxi.transformed_name(key)] = transform.compute_and_apply_vocabulary(
              transform_ngrams(_fill_in_missing(inputs[key]), taxi.NGRAM_RANGE),
              top_k=taxi.VOCAB_SIZE,
              num_oov_buckets=taxi.OOV_SIZE)

    for key in taxi.BUCKET_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = transform.bucketize(
          _fill_in_missing(inputs[key]), taxi.FEATURE_BUCKET_COUNT)

    for key in taxi.CATEGORICAL_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = _fill_in_missing(inputs[key])

    # Was this passenger a big tipper? The label is 1 when the tip exceeds
    # 20% of the fare, and 0 when the fare is NaN.
    taxi_fare = _fill_in_missing(inputs[taxi.FARE_KEY])
    tips = _fill_in_missing(inputs[taxi.LABEL_KEY])
    outputs[taxi.transformed_name(taxi.LABEL_KEY)] = tf.where(
        tf.is_nan(taxi_fare),
        tf.cast(tf.zeros_like(taxi_fare), tf.int64),
        # Test if the tip was > 20% of the fare.
        tf.cast(
            tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
            tf.int64))

    return outputs

  schema = taxi.read_schema(schema_file)
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

  with beam.Pipeline(argv=pipeline_args) as pipeline:
    with tft_beam.Context(temp_dir=working_dir):
      if input_handle.lower().endswith('csv'):
        csv_coder = taxi.make_csv_coder(schema, input_handle.lower())
        raw_data = (
            pipeline
            | 'ReadFromText' >> beam.io.ReadFromText(
                input_handle, skip_header_lines=1))
        decode_transform = beam.Map(csv_coder.decode)
      else:
        query = taxi.make_sql(input_handle, max_rows, for_eval=False)
        raw_data = (
            pipeline
            | 'ReadBigQuery' >> beam.io.Read(
                beam.io.BigQuerySource(query=query, use_standard_sql=True)))
        decode_transform = beam.Map(
            taxi.clean_raw_data_dict, raw_feature_spec=raw_feature_spec)

      if transform_dir is None:
        # Analyze the full (unshuffled) dataset and persist the transform_fn.
        decoded_data = raw_data | 'DecodeForAnalyze' >> decode_transform
        transform_fn = (
            (decoded_data, raw_data_metadata) |
            ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

        _ = (
            transform_fn
            | ('WriteTransformFn' >>
               tft_beam.WriteTransformFn(working_dir)))
      else:
        transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)

      # Shuffling the data before materialization will improve Training
      # effectiveness downstream. Here we shuffle the raw_data (as opposed to
      # decoded data) since it has a compact representation.
      shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()

      decoded_data = shuffled_data | 'DecodeForTransform' >> decode_transform
      (transformed_data, transformed_metadata) = (
          ((decoded_data, raw_data_metadata), transform_fn)
          | 'Transform' >> tft_beam.TransformDataset())

      coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
      _ = (
          transformed_data
          | 'SerializeExamples' >> beam.Map(coder.encode)
          | 'WriteExamples' >> beam.io.WriteToTFRecord(
              os.path.join(working_dir, outfile_prefix), file_name_suffix='.gz')
      )
def compute_gci_vocab(input_handle,
                      working_dir,
                      schema_file,
                      transform_dir=None,
                      max_rows=None,
                      pipeline_args=None):
    """Computes the 'gci' and 'label' vocabularies and writes the transform_fn.

    Unlike a full transform job, this runs only the tf.transform analysis
    phase over a CSV file. No transformed examples are written.

    Args:
      input_handle: Path to a CSV file (with one header line) containing the
        input data.
      working_dir: Directory to which the transform function is written via
        `tft_beam.WriteTransformFn`; also used as the tf.transform temp dir.
      schema_file: A file path that contains a text-serialized TensorFlow
        metadata schema of the input data.
      transform_dir: Currently unused; accepted for signature compatibility
        with the other transform entry points.
      max_rows: Currently unused; accepted for signature compatibility.
      pipeline_args: additional DataflowRunner or DirectRunner args passed to the
        beam pipeline.
    """
    def preprocessing_fn(inputs):
        """tf.transform callback that builds the 'gci' and 'label' vocabularies.

        Every string-typed CSV column whose name contains 'gci' (and which is
        not listed as unused) is stacked onto inputs['gci']. A single shared
        vocabulary named 'gci' is then computed over all of them. A second
        vocabulary named 'label' is computed over inputs['LAT_LON_10'].

        Args:
          inputs: map from feature keys to raw not-yet-transformed features.

        Returns:
          A dict with the single key 'gci' holding the stacked gci values.
        """
        column_names, _column_defaults, column_types, unused_columns = (
            setcolumn_list_original())
        vocab_feature_keys = [
            name for name, dtype in zip(column_names, column_types)
            if dtype is tf.string
        ]

        outputs = {}
        # Each column gets an extra axis at position 1 and is concatenated
        # along axis 0, so every gci-like column contributes entries to the
        # one shared vocabulary below.
        outputs['gci'] = tf.expand_dims(_fill_in_missing(inputs['gci']), 1)
        for key in vocab_feature_keys:
            if key in unused_columns:
                continue
            if 'gci' in key:
                # NOTE(review): if 'gci' itself is a string column it also
                # matches here, so its values are counted twice in the
                # vocabulary frequencies; confirm this weighting is intended.
                appendlist = tf.expand_dims(_fill_in_missing(inputs[key]), 1)
                outputs['gci'] = tf.concat([appendlist, outputs['gci']], 0)
        transform.vocabulary(outputs['gci'], vocab_filename='gci')
        transform.vocabulary(inputs['LAT_LON_10'], vocab_filename='label')
        return outputs

    schema = read_schema(schema_file)
    raw_feature_spec = get_raw_feature_spec(schema)
    raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
    raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

    with beam.Pipeline(argv=pipeline_args) as pipeline:
        with tft_beam.Context(temp_dir=working_dir):
            csv_coder = make_csv_coder(schema)
            raw_data = (pipeline
                        | 'ReadFromText' >> beam.io.ReadFromText(
                            input_handle, skip_header_lines=1))
            decode_transform = beam.Map(csv_coder.decode)

            decoded_data = raw_data | 'DecodeForAnalyze' >> decode_transform
            # Only the analysis phase is needed: the vocabularies are
            # produced by the analyzers inside preprocessing_fn.
            transform_fn = (
                (decoded_data, raw_data_metadata) |
                ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

            _ = (
                transform_fn
                |
                ('WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir)))
示例#10
0
    def test_non_frequency_vocabulary_merge(self):
        """This test compares vocabularies produced with and without cache.

        The same preprocessing_fn builds three vocabularies: mutual
        information, adjusted mutual information and weighted frequency. It is
        analyzed twice:
          1. with AnalyzeDatasetWithCache over two identical spans, checking
             the per-span cache entries that are produced, and
          2. with plain AnalyzeDataset over the same data flattened.
        The resulting vocabulary files must be byte-identical.
        """

        mi_vocab_name = 'mutual_information_vocab'
        adjusted_mi_vocab_name = 'adjusted_mutual_information_vocab'
        weighted_frequency_vocab_name = 'weighted_frequency_vocab'

        def preprocessing_fn(inputs):
            # Three vocabulary analyzers over the same column 's'; the inputs
            # are returned unchanged because only the vocab files matter here.
            _ = tft.vocabulary(inputs['s'],
                               labels=inputs['label'],
                               store_frequency=True,
                               vocab_filename=mi_vocab_name,
                               min_diff_from_avg=0.1,
                               use_adjusted_mutual_info=False)

            _ = tft.vocabulary(inputs['s'],
                               labels=inputs['label'],
                               store_frequency=True,
                               vocab_filename=adjusted_mi_vocab_name,
                               min_diff_from_avg=1.0,
                               use_adjusted_mutual_info=True)

            _ = tft.vocabulary(inputs['s'],
                               weights=inputs['weight'],
                               store_frequency=True,
                               vocab_filename=weighted_frequency_vocab_name,
                               use_adjusted_mutual_info=False)
            return inputs

        span_0_key = 'span-0'
        span_1_key = 'span-1'

        input_data = [
            dict(s='a', weight=1, label=1),
            dict(s='a', weight=0.5, label=1),
            dict(s='b', weight=0.75, label=1),
            dict(s='b', weight=1, label=0),
        ]
        input_metadata = dataset_metadata.DatasetMetadata(
            schema_utils.schema_from_feature_spec({
                's':
                tf.io.FixedLenFeature([], tf.string),
                'label':
                tf.io.FixedLenFeature([], tf.int64),
                'weight':
                tf.io.FixedLenFeature([], tf.float32),
            }))
        # Both spans carry the same 4 rows.
        input_data_dict = {
            span_0_key: input_data,
            span_1_key: input_data,
        }

        with _TestPipeline() as p:
            # All spans concatenated into one PCollection (8 rows).
            flat_data = p | 'CreateInputData' >> beam.Create(
                list(itertools.chain(*input_data_dict.values())))

            # wrap each value in input_data_dict as a pcoll.
            input_data_pcoll_dict = {}
            for a, b in six.iteritems(input_data_dict):
                input_data_pcoll_dict[a] = p | a >> beam.Create(b)

            # The empty dict is the (empty) input cache, so nothing is read
            # from cache and every span is analyzed from scratch.
            transform_fn_with_cache, output_cache = (
                (flat_data, input_data_pcoll_dict, {}, input_metadata)
                | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn))
            transform_fn_with_cache_dir = os.path.join(
                self.base_test_dir, 'transform_fn_with_cache')
            _ = transform_fn_with_cache | tft_beam.WriteTransformFn(
                transform_fn_with_cache_dir)

            # Expected encoded accumulators per span, one entry per vocabulary
            # analyzer. The keys are opaque encoded cache keys (the suffix
            # appears to be a hash), so they must be updated if the analyzer
            # graph changes.
            expected_accumulators = {
                b'__v0__VocabularyAccumulate[vocabulary]-<GhZ\xac\xb8\xa9\x8c\xce\x1c\xb2-ck\xca\xe8\xec\t%\x8f':
                [
                    b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
                    b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
                    b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
                    b'1.0]]'
                ],
                b'__v0__VocabularyAccumulate[vocabulary_1]-\xa6\xae\nd\xe3\xd1\x9f\xa0\xe2\xb4\x05j\xa5\xfd\x8c\xfaeN\xd1\x1f':
                [
                    b'["a", [2, [0.0, 1.0], [0.0, 0.0], 1.0]]',
                    b'["b", [2, [0.5, 0.5], [0.0, 0.0], 1.0]]',
                    b'["global_y_count_sentinel", [4, [0.25, 0.75], [0.0, 0.0], '
                    b'1.0]]'
                ],
                b"__v0__VocabularyAccumulate[vocabulary_2]-\x97\x1c>\x851\x94'\xdc\xdf\xfd\xcc\x86\xb7\xb8\xe1\xe8*\x89B\t":
                [b'["a", 1.5]', b'["b", 1.75]'],
            }
            spans = [span_0_key, span_1_key]
            self.assertCountEqual(output_cache.keys(), spans)
            for span in spans:
                self.assertCountEqual(output_cache[span].keys(),
                                      expected_accumulators.keys())
                for idx, (key, value) in enumerate(
                        six.iteritems(expected_accumulators)):
                    beam_test_util.assert_that(
                        output_cache[span][key],
                        beam_test_util.equal_to(value),
                        label='AssertCache[{}][{}]'.format(span, idx))

        # 4 from analysis on each of the input spans.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        # No cache was provided, so nothing is decoded; 3 accumulators are
        # encoded for each of the 2 spans.
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 6)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)

        # Reference run: same 8 rows, analyzed without any caching.
        with _TestPipeline() as p:
            flat_data = p | 'CreateInputData' >> beam.Create(input_data * 2)

            transform_fn_no_cache = ((flat_data, input_metadata)
                                     |
                                     tft_beam.AnalyzeDataset(preprocessing_fn))

            transform_fn_no_cache_dir = os.path.join(self.base_test_dir,
                                                     'transform_fn_no_cache')
            _ = transform_fn_no_cache | tft_beam.WriteTransformFn(
                transform_fn_no_cache_dir)

        # 8 instances: both copies of input_data analyzed as one flat dataset;
        # no cache entries are read or written in this run.
        self.assertEqual(_get_counter_value(p.metrics, 'num_instances'), 8)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_decoded'), 0)
        self.assertEqual(
            _get_counter_value(p.metrics, 'cache_entries_encoded'), 0)
        self.assertEqual(_get_counter_value(p.metrics, 'saved_models_created'),
                         2)

        tft_output_cache = tft.TFTransformOutput(transform_fn_with_cache_dir)
        tft_output_no_cache = tft.TFTransformOutput(transform_fn_no_cache_dir)

        # Each vocabulary file must be byte-identical between the two runs.
        for vocab_filename in (mi_vocab_name, adjusted_mi_vocab_name,
                               weighted_frequency_vocab_name):
            cache_path = tft_output_cache.vocabulary_file_by_name(
                vocab_filename)
            no_cache_path = tft_output_no_cache.vocabulary_file_by_name(
                vocab_filename)
            with tf.io.gfile.GFile(cache_path, 'rb') as f1, tf.io.gfile.GFile(
                    no_cache_path, 'rb') as f2:
                self.assertEqual(
                    f1.readlines(), f2.readlines(),
                    'vocab with cache != vocab without cache for: {}'.format(
                        vocab_filename))
示例#11
0
    'y': 2,
    's': 'world'
}, {
    'x': 3,
    'y': 3,
    's': 'hello'
}]

transformed_dataset, transform_fn = (
    (raw_data, raw_data_metadata)
    | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

# NOTE: AnalyzeAndTransformDataset combines two tft_beam transforms.
#   transformed_data, transform_fn = (
#       my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
# is the same as:
#   a = tft_beam.AnalyzeDataset(preprocessing_fn)
#   transform_fn = a.expand(my_data)
# AnalyzeDataset applies preprocessing_fn to the dataset `my_data` and returns
# a transform_fn object. All aggregation over the data happens in
# AnalyzeDataset: tf.Transform analyzers (like tft.mean()) have already been
# computed at this point and are baked in as constants. The transform_fn
# therefore holds, for example, the mean of column x, the min and max of
# column y, and the vocabulary used to map the strings to integers.
# The transform_fn is a pure function that is applied to every row of the
# incoming dataset. Because it is represented as a TensorFlow graph, it can be
# embedded into the serving graph.
# NOTE(review): `my_data` is not bound in the visible part of this snippet;
# the two illustrative lines below raise NameError unless it is defined
# earlier.
transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
# Equivalent explicit form:
#   t = tft_beam.TransformDataset()   # instantiate the transform
#   transformed_data = t.expand((my_data, transform_fn))
# TransformDataset takes a 2-tuple (dataset, transform_fn) and outputs a
# transformed "dataset".
transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()
# where:
#  my_data is a "dataset": presumably a (data, metadata) pair like the
#  (raw_data, raw_data_metadata) tuple passed above (original note truncated).
transformed_data, transformed_metadata = transformed_dataset
示例#12
0
def transform_data(bq_table,
                   step,
                   schema_file,
                   working_dir,
                   outfile_prefix,
                   max_rows=None,
                   transform_dir=None,
                   pipeline_args=None):
    """Reads rows from BigQuery, applies tf.transform and writes TFRecords.

    The rows are cleaned, then the transform_fn is either computed over them
    and written to `working_dir`, or loaded from `transform_dir`. It is
    applied to the shuffled data, and the transformed examples are written as
    TFRecord files.

    Args:
      bq_table: BigQuery table passed to
        `sql_queries.get_train_test_sql_query` to build the input query.
      step: Forwarded to `sql_queries.get_train_test_sql_query`; presumably
        selects the train or test split. Confirm against that helper.
      schema_file: A file path to the input data schema, read with
        `my_metadata.read_schema`.
      working_dir: Directory in which transformed examples and the transform
        function are written; also used as the tf.transform temp directory.
      outfile_prefix: Filename prefix for the emitted transformed examples.
      max_rows: Forwarded to `sql_queries.get_train_test_sql_query`;
        presumably limits the number of rows read.
      transform_dir: Directory in which the transform output is located. If
        provided, the transform_fn is loaded from disk instead of being
        computed over the data (useful for transforming eval data).
      pipeline_args: Additional DataflowRunner or DirectRunner args passed to
        the Beam pipeline.
    """

    def preprocessing_fn(inputs):
        """tf.transform's callback function for preprocessing inputs.

        Args:
          inputs: map from feature keys to raw not-yet-transformed features.

        Returns:
          Map from string feature key to transformed feature operations.
        """
        outputs = {}
        for key in my_metadata.NUMERIC_FEATURE_KEYS:
            # Preserve this feature as a dense float, setting nan's to the mean.
            outputs[my_metadata.transformed_name(key)] = transform.scale_to_z_score(
                _fill_in_missing(inputs[key]))

        for key in my_metadata.VOCAB_FEATURE_KEYS:
            # Build a vocabulary for this feature, saved under the
            # transformed feature name.
            outputs[my_metadata.transformed_name(key)] = transform.compute_and_apply_vocabulary(
                _fill_in_missing(inputs[key]),
                vocab_filename=my_metadata.transformed_name(key),
                num_oov_buckets=my_metadata.OOV_SIZE,
                top_k=my_metadata.VOCAB_SIZE
            )

        for key, hash_buckets in my_metadata.HASH_STRING_FEATURE_KEYS.items():
            outputs[my_metadata.transformed_name(key)] = transform.hash_strings(
                _fill_in_missing(inputs[key]),
                hash_buckets=hash_buckets
            )

        for key, nb_buckets in my_metadata.TO_BE_BUCKETIZED_FEATURE.items():
            outputs[my_metadata.transformed_name(key + '_bucketized')] = transform.bucketize(
                _fill_in_missing(inputs[key]), nb_buckets)

        # Was this passenger a big tipper? The label is 1 when the tip exceeds
        # 20% of the fare, and 0 when the fare is NaN.
        taxi_fare = _fill_in_missing(inputs[my_metadata.FARE_KEY])
        tips = _fill_in_missing(inputs[my_metadata.LABEL_KEY])
        outputs[my_metadata.transformed_name(my_metadata.LABEL_KEY)] = tf.where(
            tf.is_nan(taxi_fare),
            tf.cast(tf.zeros_like(taxi_fare), tf.int64),
            # Test if the tip was > 20% of the fare.
            tf.cast(
                tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
                tf.int64))

        return outputs

    schema = my_metadata.read_schema(schema_file)
    raw_feature_spec = my_metadata.get_raw_feature_spec(schema)
    raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
    raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

    with beam.Pipeline(argv=pipeline_args) as pipeline:
        with tft_beam.Context(temp_dir=working_dir):
            query = sql_queries.get_train_test_sql_query(bq_table, step, max_rows)
            raw_data = (
                pipeline
                | 'ReadBigQuery' >> beam.io.Read(
                    beam.io.BigQuerySource(query=query, use_standard_sql=True))
                | 'CleanData' >> beam.Map(
                    my_metadata.clean_raw_data_dict,
                    raw_feature_spec=raw_feature_spec))

            if transform_dir is None:
                transform_fn = (
                    (raw_data, raw_data_metadata)
                    | ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))

                _ = (
                    transform_fn
                    | ('WriteTransformFn' >>
                       tft_beam.WriteTransformFn(working_dir)))
            else:
                transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)

            # Shuffling the data before materialization will improve Training
            # effectiveness downstream.
            shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()

            (transformed_data, transformed_metadata) = (
                ((shuffled_data, raw_data_metadata), transform_fn)
                | 'Transform' >> tft_beam.TransformDataset())

            coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
            _ = (
                transformed_data
                | 'SerializeExamples' >> beam.Map(coder.encode)
                | 'WriteExamples' >> beam.io.WriteToTFRecord(
                    os.path.join(working_dir, outfile_prefix))
            )