Example #1
  def test_example_proto_coder_bad_default_value(self):
    input_schema = dataset_schema.from_feature_spec({
        'scalar_feature_2': tf.FixedLenFeature(shape=[2], dtype=tf.float32,
                                               default_value=[1.0]),
    })
    with self.assertRaisesRegexp(ValueError,
                                 'got default value with incorrect shape'):
      example_proto_coder.ExampleProtoCoder(input_schema)

    input_schema = dataset_schema.from_feature_spec({
        'scalar_feature_2': tf.FixedLenFeature(shape=[], dtype=tf.float32,
                                               default_value=[0.0]),
    })
    with self.assertRaisesRegexp(ValueError,
                                 'got default value with incorrect shape'):
      example_proto_coder.ExampleProtoCoder(input_schema)

    input_schema = dataset_schema.from_feature_spec({
        '2d_vector_feature':
            tf.FixedLenFeature(
                shape=[2, 3],
                dtype=tf.float32,
                default_value=[[1.0, 1.0], [1.0]]),
    })
    with self.assertRaisesRegexp(ValueError,
                                 'got default value with incorrect shape'):
      example_proto_coder.ExampleProtoCoder(input_schema)
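For contrast, here is a minimal sketch (not part of the original test, and assuming the usual tf.Transform import paths) of a default value that does match the declared shape, so the coder constructs without raising:

import tensorflow as tf
from tensorflow_transform.coders import example_proto_coder
from tensorflow_transform.tf_metadata import dataset_schema

# Scalar feature with a scalar default: shape and default agree, no ValueError.
valid_schema = dataset_schema.from_feature_spec({
    'scalar_feature': tf.FixedLenFeature(shape=[], dtype=tf.float32,
                                         default_value=1.0),
})
coder = example_proto_coder.ExampleProtoCoder(valid_schema)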
Example #2
  def testMalformedSparseFeatures(self):
    tensors = {
        'a': tf.sparse_placeholder(tf.int64),
    }

    # Invalid indices.
    schema = dataset_schema.from_feature_spec({
        'a': tf.SparseFeature('idx', 'val', tf.float32, 10)
    })
    instances = [{'a': ([-1, 2], [1.0, 2.0])}]
    with self.assertRaisesRegexp(
        ValueError, 'has index .* out of range'):
      impl_helper.make_feed_dict(tensors, schema, instances)

    instances = [{'a': ([11, 1], [1.0, 2.0])}]
    with self.assertRaisesRegexp(
        ValueError, 'has index .* out of range'):
      impl_helper.make_feed_dict(tensors, schema, instances)

    # Indices and values of different lengths.
    schema = dataset_schema.from_feature_spec({
        'a': tf.SparseFeature('idx', 'val', tf.float32, 10)
    })
    instances = [{'a': ([1, 2], [1])}]
    with self.assertRaisesRegexp(
        ValueError, 'indices and values of different lengths'):
      impl_helper.make_feed_dict(tensors, schema, instances)

    # Tuple of the wrong length.
    instances = [{'a': ([1], [2], [3])}]
    with self.assertRaisesRegexp(
        ValueError, 'too many values to unpack'):
      impl_helper.make_feed_dict(tensors, schema, instances)
Example #3
  def testMakeOutputDictErrorSparse(self):
    schema = dataset_schema.from_feature_spec({
        'a': tf.VarLenFeature(tf.string)
    })

    # SparseTensor that cannot be represented as VarLenFeature.
    fetches = {
        'a': tf.SparseTensorValue(indices=np.array([(0, 2), (0, 4), (0, 8)]),
                                  values=np.array([10.0, 20.0, 30.0]),
                                  dense_shape=(1, 20))
    }
    with self.assertRaisesRegexp(
        ValueError, 'cannot be decoded by ListColumnRepresentation'):
      impl_helper.to_instance_dicts(schema, fetches)

    # SparseTensor of invalid rank.
    fetches = {
        'a': tf.SparseTensorValue(
            indices=np.array([(0, 0, 1), (0, 0, 2), (0, 0, 3)]),
            values=np.array([10.0, 20.0, 30.0]),
            dense_shape=(1, 10, 10))
    }
    with self.assertRaisesRegexp(
        ValueError, 'cannot be decoded by ListColumnRepresentation'):
      impl_helper.to_instance_dicts(schema, fetches)

    # SparseTensor with indices that are out of order.
    fetches = {
        'a': tf.SparseTensorValue(indices=np.array([(0, 2), (2, 4), (1, 8)]),
                                  values=np.array([10.0, 20.0, 30.0]),
                                  dense_shape=(3, 20))
    }
    with self.assertRaisesRegexp(
        ValueError, 'Encountered out-of-order sparse index'):
      impl_helper.to_instance_dicts(schema, fetches)

    # SparseTensors with different batch dimension sizes.
    schema = dataset_schema.from_feature_spec({
        'a': tf.VarLenFeature(tf.string),
        'b': tf.VarLenFeature(tf.string)
    })
    fetches = {
        'a': tf.SparseTensorValue(indices=np.array([(0, 0)]),
                                  values=np.array([10.0]),
                                  dense_shape=(1, 20)),
        'b': tf.SparseTensorValue(indices=np.array([(0, 0)]),
                                  values=np.array([10.0]),
                                  dense_shape=(2, 20))
    }
    with self.assertRaisesRegexp(
        ValueError,
        r'Inconsistent batch sizes: "\w" had batch dimension \d, "\w" had batch'
        r' dimension \d'):
      impl_helper.to_instance_dicts(schema, fetches)
Example #4
 def test_decode_error(self, feature_spec, ascii_proto, error_msg,
                       error_type=ValueError, **kwargs):
   schema = dataset_schema.from_feature_spec(feature_spec)
   coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs)
   serialized_proto = _ascii_to_binary(ascii_proto)
   with self.assertRaisesRegexp(error_type, error_msg):
     coder.decode(serialized_proto)
Example #5
def make_csv_coder(schema):
    """Return a coder for tf.transform to read csv files."""
    raw_feature_spec = get_raw_feature_spec(schema)
    parsing_schema = dataset_schema.from_feature_spec(raw_feature_spec)
    return tft_coders.CsvCoder(_CSV_COLUMNS_NAMES,
                               parsing_schema,
                               delimiter='|')
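A hypothetical usage sketch (the schema, _CSV_COLUMNS_NAMES, and the row contents below are illustrative, not from the surrounding module): the returned coder turns one pipe-delimited line into an instance dict and is typically applied with beam.Map.

csv_coder = make_csv_coder(schema)
# One raw row, delimited with '|' to match the coder above.
instance = csv_coder.decode('value_a|0.5|7')
# `instance` maps each name in _CSV_COLUMNS_NAMES to its parsed value; in a
# pipeline this usually appears as: raw_data = lines | beam.Map(csv_coder.decode)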
Example #6
def build_serving_input_fn(tf_transform_beam, params):
    """Creates an input function reading from raw data.

    Args:
        tf_transform_beam: tf.Transform output wrapper exposing
            transform_raw_features().
        params: hyperparameters providing sentence_len and word_len.

    Returns:
        The serving input function.
    """
    TRANSFORM_INPUT_SCHEMA = dataset_schema.from_feature_spec({
        'text':
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        "chars":
        tf.FixedLenFeature(
            shape=[int(params.sentence_len),
                   int(params.word_len)],
            dtype=tf.int64),
        'sentence_length':
        tf.FixedLenFeature(shape=[], dtype=tf.int64),
        'chars_in_word':
        tf.FixedLenFeature(shape=[int(params.sentence_len)], dtype=tf.int64),
    })
    raw_feature_spec = TRANSFORM_INPUT_SCHEMA.as_feature_spec()

    def serving_input_fn():
        """Receiver function that converts raw features into transformed features.

        :return: ServingInputReceiver
        """
        raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            raw_feature_spec)
        raw_features, _, _ = raw_input_fn()
        transformed_features = tf_transform_beam.transform_raw_features(
            raw_features)
        return tf.estimator.export.ServingInputReceiver(
            transformed_features, raw_features)

    return serving_input_fn
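A sketch of how the returned function is typically wired up, assuming a TF 1.x `estimator` and the `tf_transform_beam` / `params` objects are built elsewhere (names here are illustrative):

serving_input_fn = build_serving_input_fn(tf_transform_beam, params)
# Export a SavedModel whose serving signature parses raw tf.Examples and runs
# the tf.Transform graph before handing features to the model.
estimator.export_savedmodel('exported_model', serving_input_fn)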
Example #7
def tfrecord_schema():
    return dataset_schema.from_feature_spec({
        'label':
        tf.FixedLenFeature([], dtype=tf.int64),
        'image':
        tf.FixedLenFeature([], dtype=tf.string)
    })
Example #8
def make_input_schema(mode=tf.contrib.learn.ModeKeys.TRAIN):
  """Input schema definition.

  Args:
    mode: tf.contrib.learn.ModeKeys specifying if the schema is being used for
      train/eval or prediction.
  Returns:
    A `Schema` object.
  """

  result = {}
  result[LABEL_COLUMN] = tf.FixedLenFeature(shape=[], dtype=tf.int64)
  result[DISPLAY_ID_COLUMN] = tf.FixedLenFeature(shape=[], dtype=tf.float32)
  #result[AD_ID_COLUMN] = tf.VarLenFeature(dtype=tf.float32)
  result[IS_LEAK_COLUMN] = tf.FixedLenFeature(shape=[], dtype=tf.int64)

  for name in BOOL_COLUMNS:
    result[name] = tf.VarLenFeature(dtype=tf.int64)
  # TODO: Create dummy features that indicate whether any of the numeric
  # features is null (the current default value of 0 might introduce noise).
  for name in FLOAT_COLUMNS_LOG_BIN_TRANSFORM + FLOAT_COLUMNS_SIMPLE_BIN_TRANSFORM:
    result[name] = tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=0.0)
  for name in INT_COLUMNS:
    result[name] = tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=0.0)
  for name in CATEGORICAL_COLUMNS:
    result[name] = tf.VarLenFeature(dtype=tf.float32)
  for multi_category in DOC_CATEGORICAL_MULTIVALUED_COLUMNS:
    for category in DOC_CATEGORICAL_MULTIVALUED_COLUMNS[multi_category]:
      result[category] = tf.VarLenFeature(dtype=tf.float32)

  return dataset_schema.from_feature_spec(result)
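As in the other examples on this page, the returned schema is usually wrapped in a DatasetMetadata before being paired with raw data for analysis; a minimal sketch:

from tensorflow_transform.tf_metadata import dataset_metadata

raw_data_metadata = dataset_metadata.DatasetMetadata(make_input_schema())
# (raw_data, raw_data_metadata) then feeds AnalyzeDataset /
# AnalyzeAndTransformDataset together with a preprocessing_fn.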
Example #9
  def testMakeOutputDictVarLen(self):
    # Specifically test the empty ndarray optimization codepaths.
    schema = dataset_schema.from_feature_spec({
        'a': tf.VarLenFeature(tf.int64),
        'b': tf.VarLenFeature(tf.float32),
        'c': tf.VarLenFeature(tf.string),
    })

    fetches = {
        'a': tf.SparseTensorValue(
            indices=np.array([(0, 0), (2, 0)]),
            values=np.array([0, 1], np.int64),
            dense_shape=(4, 1)),
        'b': tf.SparseTensorValue(
            indices=np.array([(0, 0), (2, 0)]),
            values=np.array([0.5, 1.5], np.float32),
            dense_shape=(4, 1)),
        'c': tf.SparseTensorValue(
            indices=np.array([(0, 0), (2, 0)]),
            values=np.array(['hello', 'goodbye'], np.object),
            dense_shape=(4, 1)),
    }

    instance_dicts = impl_helper.to_instance_dicts(schema, fetches)
    self.assertEqual(4, len(instance_dicts))
    self.assertEqual(instance_dicts[1]['a'].dtype, np.int64)
    self.assertEqual(instance_dicts[3]['a'].dtype, np.int64)
    self.assertEqual(instance_dicts[1]['b'].dtype, np.float32)
    self.assertEqual(instance_dicts[3]['b'].dtype, np.float32)
    self.assertEqual(instance_dicts[1]['c'].dtype, np.object)
    self.assertEqual(instance_dicts[3]['c'].dtype, np.object)
Example #10
  def test_preprocessing_fn(self):
    schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt')
    schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
    feature_spec = taxi_utils._get_raw_feature_spec(schema)
    working_dir = self.get_temp_dir()
    transform_output_path = os.path.join(working_dir, 'transform_output')
    transformed_examples_path = os.path.join(
        working_dir, 'transformed_examples')

    # Run very simplified version of executor logic.
    # TODO(kestert): Replace with tft_unit.assertAnalyzeAndTransformResults.
    # Generate legacy `DatasetMetadata` object.  Future version of Transform
    # will accept the `Schema` proto directly.
    legacy_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(feature_spec))
    decoder = tft.coders.ExampleProtoCoder(legacy_metadata.schema)
    with beam.Pipeline() as p:
      with tft_beam.Context(temp_dir=os.path.join(working_dir, 'tmp')):
        examples = (
            p
            | 'ReadTrainData' >> beam.io.ReadFromTFRecord(
                os.path.join(self._testdata_path, 'csv_example_gen/train/*'),
                coder=beam.coders.BytesCoder(),
                # TODO(b/114938612): Eventually remove this override.
                validate=False)
            | 'DecodeTrainData' >> beam.Map(decoder.decode))
        (transformed_examples, transformed_metadata), transform_fn = (
            (examples, legacy_metadata)
            | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(
                taxi_utils.preprocessing_fn))

        # WriteTransformFn writes transform_fn and metadata to subdirectories
        # tensorflow_transform.SAVED_MODEL_DIR and
        # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
        # pylint: disable=expression-not-assigned
        (transform_fn
         | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
             transform_output_path))

        encoder = tft.coders.ExampleProtoCoder(transformed_metadata.schema)
        (transformed_examples
         | 'EncodeTrainData' >> beam.Map(encoder.encode)
         | 'WriteTrainData' >> beam.io.WriteToTFRecord(
             os.path.join(transformed_examples_path,
                          'train/transformed_examples.gz'),
             coder=beam.coders.BytesCoder()))
        # pylint: enable=expression-not-assigned

    # Verify the output matches golden output.
    # NOTE: we don't verify that transformed examples match golden output.
    expected_transformed_schema = io_utils.parse_pbtxt_file(
        os.path.join(
            self._testdata_path,
            'transform/transform_output/transformed_metadata/schema.pbtxt'),
        schema_pb2.Schema())
    transformed_schema = io_utils.parse_pbtxt_file(
        os.path.join(transform_output_path,
                     'transformed_metadata/schema.pbtxt'),
        schema_pb2.Schema())
    self.assertEqual(transformed_schema, expected_transformed_schema)
Example #11
 def test_make_feed_list(self, feature_spec, instances, feed_dict):
     schema = dataset_schema.from_feature_spec(feature_spec)
     feature_names = list(feature_spec.keys())
     expected_feed_list = [feed_dict[key] for key in feature_names]
     np.testing.assert_equal(
         impl_helper.make_feed_list(feature_names, schema, instances),
         expected_feed_list)
Example #12
  def test_example_proto_coder_default_value(self):
    input_schema = dataset_schema.from_feature_spec({
        'scalar_feature_3':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
        'scalar_feature_4':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=0.0),
        '1d_vector_feature':
            tf.FixedLenFeature(
                shape=[1], dtype=tf.float32, default_value=[2.0]),
        '2d_vector_feature':
            tf.FixedLenFeature(
                shape=[2, 2],
                dtype=tf.float32,
                default_value=[[1.0, 2.0], [3.0, 4.0]]),
    })
    coder = example_proto_coder.ExampleProtoCoder(input_schema)

    # Python types.
    example_proto_text = """
    features {
    }
    """
    example = tf.train.Example()
    text_format.Merge(example_proto_text, example)
    data = example.SerializeToString()

    # Assert the data is decoded into the expected format.
    expected_decoded = {
        'scalar_feature_3': 1.0,
        'scalar_feature_4': 0.0,
        '1d_vector_feature': [2.0],
        '2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]],
    }
    decoded = coder.decode(data)
    np.testing.assert_equal(expected_decoded, decoded)
Example #13
    def test_decode_errors(self):
        input_schema = dataset_schema.from_feature_spec({
            'b':
            tf.FixedLenFeature(shape=[], dtype=tf.float32),
            'a':
            tf.FixedLenFeature(shape=[], dtype=tf.string),
        })
        coder = csv_coder.CsvCoder(column_names=['a', 'b'],
                                   schema=input_schema)

        # Test bad csv.
        with self.assertRaisesRegexp(
                csv_coder.DecodeError,
                '\'int\' object has no attribute \'encode\': 123'):
            coder.decode(123)

        # Test extra column.
        with self.assertRaisesRegexp(
                csv_coder.DecodeError,
                'Columns do not match specified csv headers'):
            coder.decode('1,2,')

        # Test missing column.
        with self.assertRaisesRegexp(
                csv_coder.DecodeError,
                'Columns do not match specified csv headers'):
            coder.decode('a_value')

        # Test empty row.
        with self.assertRaisesRegexp(
                csv_coder.DecodeError,
                'Columns do not match specified csv headers'):
            coder.decode('')
Example #14
def make_input_schema(mode=tf.contrib.learn.ModeKeys.TRAIN):
    """Input schema definition.

  Args:
    mode: tf.contrib.learn.ModeKeys specifying if the schema is being used for
      train/eval or prediction.
  Returns:
    A `Schema` object.
  """
    result = ({} if mode == tf.contrib.learn.ModeKeys.INFER else {
        'score': tf.FixedLenFeature(shape=[], dtype=tf.float32)
    })
    result.update({
        'subreddit':
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        'author':
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        'comment_body':
        tf.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
        'comment_parent_body':
        tf.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
        'toplevel':
        tf.FixedLenFeature(shape=[], dtype=tf.int64),
    })
    return dataset_schema.from_feature_spec(result)
Example #15
def serve_input_fn(tf_transform_beam=None):
    SENTENCE_MAX_LENGTH = 50
    raw_feature_spec = dataset_schema.from_feature_spec({
        "text":
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        "head":
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        "taill":
        tf.FixedLenFeature(shape=[], dtype=tf.string),
        "distance_to_head":
        tf.FixedLenFeature(shape=[SENTENCE_MAX_LENGTH], dtype=tf.int64),
        "distance_to_tail":
        tf.FixedLenFeature(shape=[SENTENCE_MAX_LENGTH], dtype=tf.int64),
        "sentence_length":
        tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "relation":
        tf.FixedLenFeature(shape=[], dtype=tf.int64),
    })

    def build_serving_input_fn(tf_transform_beam=None):
        """Receiver function that converts raw features into transformed features.

        :return: ServingInputReceiver
        """
        raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            raw_feature_spec.as_feature_spec())
        raw_features, _, _ = raw_input_fn()
        transformed_features = tf_transform_beam.transform_raw_features(
            raw_features)
        return tf.estimator.export.ServingInputReceiver(
            transformed_features, raw_features)

    return build_serving_input_fn(tf_transform_beam)
Example #16
 def test_decode_non_serialized(self, feature_spec, ascii_proto, instance,
                                **kwargs):
   schema = dataset_schema.from_feature_spec(feature_spec)
   coder = example_proto_coder.ExampleProtoCoder(
       schema, serialized=False, **kwargs)
   proto = _ascii_to_example(ascii_proto)
   np.testing.assert_equal(coder.decode(proto), instance)
Example #17
  def test_example_proto_coder_error(self):
    input_schema = dataset_schema.from_feature_spec({
        '2d_vector_feature': tf.FixedLenFeature(shape=[2, 2], dtype=tf.int64),
    })
    coder = example_proto_coder.ExampleProtoCoder(input_schema)

    example_decoded_value = {
        '2d_vector_feature': [1, 2, 3]
    }
    example_proto_text = """
    features {
      feature { key: "1d_vector_feature"
                value { int64_list { value: [ 1, 2, 3 ] } } }
    }
    """
    example = tf.train.Example()
    text_format.Merge(example_proto_text, example)

    # Ensure that we raise an exception for trying to encode invalid data.
    with self.assertRaisesRegexp(ValueError, 'got wrong number of values'):
      _ = coder.encode(example_decoded_value)

    # Ensure that we raise an exception for trying to parse invalid data.
    with self.assertRaisesRegexp(ValueError, 'got wrong number of values'):
      _ = coder.decode(example.SerializeToString())
Example #18
def main(_):
    # Define schema.
    raw_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec({
            'text':
            tf.FixedLenFeature([], tf.string),
            'language_code':
            tf.FixedLenFeature([], tf.string),
        }))

    # Add in padding tokens.
    reserved_tokens = FLAGS.reserved_tokens
    if FLAGS.num_pad_tokens:
        padded_tokens = ['<pad>']
        padded_tokens += [
            '<pad%d>' % i for i in range(1, FLAGS.num_pad_tokens)
        ]
        reserved_tokens = padded_tokens + reserved_tokens

    params = learner.Params(FLAGS.upper_thresh, FLAGS.lower_thresh,
                            FLAGS.num_iterations, FLAGS.max_input_tokens,
                            FLAGS.max_token_length, FLAGS.max_unique_chars,
                            FLAGS.vocab_size, FLAGS.slack_ratio,
                            FLAGS.include_joiner_token, FLAGS.joiner,
                            reserved_tokens)

    generate_vocab(FLAGS.data_file, FLAGS.vocab_file, FLAGS.metrics_file,
                   raw_metadata, params)
Example #19
    def __init__(self):
        self.FEATURES = []
        # vocabulary features
        self.FEATURES.append('share_id')
        # self.FEATURES.append('time')
        # float features
        self.FEATURES.append('close_b0')
        self.FEATURES.append('close_b1')
        self.FEATURES.append('close_b2')
        self.FEATURES.append('close_b3')
        self.FEATURES.append('close_b4')
        self.FEATURES.append('close_b5')
        self.FEATURES.append('close_b6')
        self.FEATURES.append('close_b7')
        self.FEATURES.append('close_b8')
        self.FEATURES.append('close_b9')
        self.FEATURES.append('close_b10')
        self.FEATURES.append('close_b11')
        self.FEATURES.append('close_b12')
        self.FEATURES.append('close_b13')
        self.FEATURES.append('close_b14')
        self.FEATURES.append('close_b15')
        self.FEATURES.append('close_b16')
        self.FEATURES.append('close_b17')
        self.FEATURES.append('close_b18')
        self.FEATURES.append('close_b19')
        self.FEATURES.append('close_b20')

        self.TARGETS = []
        self.TARGETS.append(Target.ROR_20_DAYS_BOOL)

        # store the features by data type
        self.number_features = []
        self.vocabulary_features = []
        for key in self.FEATURES:
            if feature_extractor_definition[key][3] == FORMAT_NUMBER:
                self.number_features.append(key)
            elif feature_extractor_definition[key][3] == FORMAT_VOCABULARY:
                self.vocabulary_features.append(key)
            else:
                raise Exception("unsupported feature types in TFT")

        # use all features in feature_extractor_definition
        features_spec = {}
        for key in self.FEATURES + self.TARGETS:
            if key in feature_extractor_definition:
                feature_def = feature_extractor_definition[key]
                if feature_def[2] == "tf.FixedLenFeature":
                    features_spec[key] = tf.FixedLenFeature([], feature_def[4])
                else:
                    log.error("unsupported key : " + key)
            else:
                log.error(
                    "key {} is in FEATURES or TARGETS but missing from "
                    "feature_extractor_definition".format(key))

        self.features_metadata = dataset_metadata.DatasetMetadata(
            dataset_schema.from_feature_spec(features_spec))
        self.feature_spec = features_spec
Example #20
 def test_to_instance_dicts_error(self,
                                  feature_spec,
                                  feed_dict,
                                  error_msg,
                                  error_type=ValueError):
     schema = dataset_schema.from_feature_spec(feature_spec)
     with self.assertRaisesRegexp(error_type, error_msg):
         impl_helper.to_instance_dicts(schema, feed_dict)
Example #21
def get_metadata() -> dataset_metadata.DatasetMetadata:
    return dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec({
            'cash':
            tf.io.FixedLenFeature([], tf.int64),
            'year_norm':
            tf.io.FixedLenFeature([], tf.float32),
            'start_time_norm_midnight':
            tf.io.FixedLenFeature([], tf.float32),
            'start_time_norm_noon':
            tf.io.FixedLenFeature([], tf.float32),
            'pickup_lat_std':
            tf.io.FixedLenFeature([], tf.float32),
            'pickup_long_std':
            tf.io.FixedLenFeature([], tf.float32),
            'pickup_lat_centered':
            tf.io.FixedLenFeature([], tf.float32),
            'pickup_long_centered':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_MONDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_TUESDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_WEDNESDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_THURSDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_FRIDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_SATURDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'day_of_week_SUNDAY':
            tf.io.FixedLenFeature([], tf.float32),
            'month_JANUARY':
            tf.io.FixedLenFeature([], tf.float32),
            'month_FEBRUARY':
            tf.io.FixedLenFeature([], tf.float32),
            'month_MARCH':
            tf.io.FixedLenFeature([], tf.float32),
            'month_APRIL':
            tf.io.FixedLenFeature([], tf.float32),
            'month_MAY':
            tf.io.FixedLenFeature([], tf.float32),
            'month_JUNE':
            tf.io.FixedLenFeature([], tf.float32),
            'month_JULY':
            tf.io.FixedLenFeature([], tf.float32),
            'month_AUGUST':
            tf.io.FixedLenFeature([], tf.float32),
            'month_SEPTEMBER':
            tf.io.FixedLenFeature([], tf.float32),
            'month_OCTOBER':
            tf.io.FixedLenFeature([], tf.float32),
            'month_NOVEMBER':
            tf.io.FixedLenFeature([], tf.float32),
            'month_DECEMBER':
            tf.io.FixedLenFeature([], tf.float32)
        }))
Example #22
 def test_constructor_error(self,
                            columns,
                            feature_spec,
                            error_msg,
                            error_type=ValueError,
                            **kwargs):
   schema = dataset_schema.from_feature_spec(feature_spec)
   with self.assertRaisesRegexp(error_type, error_msg):
     csv_coder.CsvCoder(columns, schema, **kwargs)
Example #23
 def test_make_feed_dict(self, feature_spec, instances, feed_dict):
   tensors = tf.parse_example(tf.placeholder(tf.string, [None]), feature_spec)
   schema = dataset_schema.from_feature_spec(feature_spec)
   # feed_dict contains feature names as keys, replace these with the
   # actual tensors.
   feed_dict = {tensors[key]: value for key, value in feed_dict.items()}
   np.testing.assert_equal(
       impl_helper.make_feed_dict(tensors, schema, instances),
       feed_dict)
Example #24
def make_spec(output_dir, batch_size=None):
    fixed_shape = [batch_size, 1] if batch_size is not None else []
    spec = {}
    spec[LABEL_COLUMN] = tf.FixedLenFeature(shape=fixed_shape,
                                            dtype=tf.int64,
                                            default_value=None)
    spec[DISPLAY_ID_COLUMN] = tf.FixedLenFeature(shape=fixed_shape,
                                                 dtype=tf.int64,
                                                 default_value=None)
    spec[IS_LEAK_COLUMN] = tf.FixedLenFeature(shape=fixed_shape,
                                              dtype=tf.int64,
                                              default_value=None)
    spec[DISPLAY_ID_AND_IS_LEAK_ENCODED_COLUMN] = tf.FixedLenFeature(
        shape=fixed_shape, dtype=tf.int64, default_value=None)

    for name in BOOL_COLUMNS:
        spec[name] = tf.FixedLenFeature(shape=fixed_shape,
                                        dtype=tf.int64,
                                        default_value=None)
    for name in FLOAT_COLUMNS_LOG_BIN_TRANSFORM + FLOAT_COLUMNS_SIMPLE_BIN_TRANSFORM:
        spec[name] = tf.FixedLenFeature(shape=fixed_shape,
                                        dtype=tf.float32,
                                        default_value=None)
    for name in FLOAT_COLUMNS_SIMPLE_BIN_TRANSFORM:
        spec[name + '_binned'] = tf.FixedLenFeature(shape=fixed_shape,
                                                    dtype=tf.int64,
                                                    default_value=None)
    for name in FLOAT_COLUMNS_LOG_BIN_TRANSFORM:
        spec[name + '_binned'] = tf.FixedLenFeature(shape=fixed_shape,
                                                    dtype=tf.int64,
                                                    default_value=None)
        spec[name + '_log_01scaled'] = tf.FixedLenFeature(shape=fixed_shape,
                                                          dtype=tf.float32,
                                                          default_value=None)
    for name in INT_COLUMNS:
        spec[name + '_log_int'] = tf.FixedLenFeature(shape=fixed_shape,
                                                     dtype=tf.int64,
                                                     default_value=None)
        spec[name + '_log_01scaled'] = tf.FixedLenFeature(shape=fixed_shape,
                                                          dtype=tf.float32,
                                                          default_value=None)
    for name in BOOL_COLUMNS + CATEGORICAL_COLUMNS:
        spec[name] = tf.FixedLenFeature(shape=fixed_shape,
                                        dtype=tf.int64,
                                        default_value=None)

    for multi_category in DOC_CATEGORICAL_MULTIVALUED_COLUMNS:
        #spec[multi_category] = tf.VarLenFeature(dtype=tf.int64)
        shape = fixed_shape[:-1] + [
            len(DOC_CATEGORICAL_MULTIVALUED_COLUMNS[multi_category])
        ]
        spec[multi_category] = tf.FixedLenFeature(shape=shape, dtype=tf.int64)

    metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.from_feature_spec(spec))

    metadata_io.write_metadata(metadata, output_dir)
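A sketch of the reverse direction, assuming the same output_dir that make_spec wrote to: the persisted metadata can be read back and turned into a feature spec for parsing.

from tensorflow_transform.tf_metadata import metadata_io

metadata = metadata_io.read_metadata(output_dir)
feature_spec = metadata.schema.as_feature_spec()  # dict of Fixed/VarLenFeatures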
Example #25
 def test_make_feed_dict_error(self,
                               feature_spec,
                               instances,
                               error_msg,
                               error_type=ValueError):
     tensors = tf.parse_example(tf.placeholder(tf.string, [None]),
                                feature_spec)
     schema = dataset_schema.from_feature_spec(feature_spec)
     with self.assertRaisesRegexp(error_type, error_msg):
         impl_helper.make_feed_dict(tensors, schema, instances)
Example #26
def main(_):
  # Generate schema of input data.
  raw_metadata = dataset_metadata.DatasetMetadata(
      dataset_schema.from_feature_spec({
          'text': tf.FixedLenFeature([], tf.string),
          'language_code': tf.FixedLenFeature([], tf.string),
      }))

  pipeline = word_count(FLAGS.input_path, FLAGS.output_path, raw_metadata)
  pipeline.run().wait_until_finish()
Example #27
def _remove_columns_from_metadata(metadata, excluded_columns):
  """Remove columns from metadata without mutating original metadata."""
  feature_spec = metadata.schema.as_feature_spec()
  domains = metadata.schema.domains()
  new_feature_spec = {name: spec for name, spec in feature_spec.items()
                      if name not in excluded_columns}
  new_domains = {name: spec for name, spec in domains.items()
                 if name not in excluded_columns}
  return dataset_metadata.DatasetMetadata(
      dataset_schema.from_feature_spec(new_feature_spec, new_domains))
Example #28
 def test_encode_error(self,
                       feature_spec,
                       instance,
                       error_msg,
                       error_type=ValueError,
                       **kwargs):
     schema = dataset_schema.from_feature_spec(feature_spec)
     coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs)
     with self.assertRaisesRegexp(error_type, error_msg):
         coder.encode(instance)
Example #29
    def test_example_proto_coder_bad_default_value(self):
        input_schema = dataset_schema.from_feature_spec({
            'scalar_feature_2':
            tf.FixedLenFeature(shape=[2],
                               dtype=tf.float32,
                               default_value=[1.0, 2.0]),
        })
        with self.assertRaisesRegexp(
                ValueError, 'only scalar default values are supported'):
            example_proto_coder.ExampleProtoCoder(input_schema)

        input_schema = dataset_schema.from_feature_spec({
            'scalar_feature_2':
            tf.FixedLenFeature(shape=[], dtype=tf.float32,
                               default_value=[1.0]),
        })
        with self.assertRaisesRegexp(
                ValueError, 'only scalar default values are supported'):
            example_proto_coder.ExampleProtoCoder(input_schema)
Example #30
 def test_encode_error(self,
                       columns,
                       feature_spec,
                       instance,
                       error_msg,
                       error_type=ValueError,
                       **kwargs):
   schema = dataset_schema.from_feature_spec(feature_spec)
   coder = csv_coder.CsvCoder(columns, schema, **kwargs)
   with self.assertRaisesRegexp(error_type, error_msg):
     coder.encode(instance)
Example #31
  def test_sequence_feature_not_supported(self):
    feature_spec = {
        # FixedLenSequenceFeatures
        'fixed_seq_bool':
            tf.FixedLenSequenceFeature(shape=[10], dtype=tf.bool),
        'fixed_seq_bool_allow_missing':
            tf.FixedLenSequenceFeature(
                shape=[5], dtype=tf.bool, allow_missing=True),
        'fixed_seq_int':
            tf.FixedLenSequenceFeature(shape=[5], dtype=tf.int64),
        'fixed_seq_float':
            tf.FixedLenSequenceFeature(shape=[5], dtype=tf.float32),
        'fixed_seq_string':
            tf.FixedLenSequenceFeature(shape=[5], dtype=tf.string),
    }

    with self.assertRaisesRegexp(ValueError,
                                 'DatasetSchema does not support '
                                 'FixedLenSequenceFeature yet.'):
      sch.from_feature_spec(feature_spec)
Example #32
def make_input_schema(mode=tf.contrib.learn.ModeKeys.TRAIN):
  """Input schema definition.

  Args:
    mode: tf.contrib.learn.ModeKeys specifying if the schema is being used for
      train/eval or prediction.
  Returns:
    A `Schema` object.
  """
  result = ({} if mode == tf.contrib.learn.ModeKeys.INFER
            else {'clicked': tf.FixedLenFeature(shape=[], dtype=tf.int64)})
  for name in INTEGER_COLUMN_NAMES:
    result[name] = tf.FixedLenFeature(
        shape=[], dtype=tf.int64, default_value=-1)
  for name in CATEGORICAL_COLUMN_NAMES:
    result[name] = tf.FixedLenFeature(shape=[], dtype=tf.string,
                                      default_value='')

  return dataset_schema.from_feature_spec(result)
Example #33
  def _ReadSchema(self, data_format,
                  schema_path):
    """Returns a TFT schema for the input data.

    Args:
      data_format: name of the input data format.
      schema_path: path to schema file.

    Returns:
      A schema representing the provided set of columns.
    """

    if self._ShouldDecodeAsRawExample(data_format):
      return _RAW_EXAMPLE_SCHEMA
    schema = self._GetSchema(schema_path)
    # TODO(b/77351671): Remove this conversion to tf.Transform's internal
    # schema format.
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    return dataset_schema.from_feature_spec(feature_spec)
Example #34
def _make_schema(columns, types, default_values):
  """Input schema definition.

  Args:
    columns: column names for fields appearing in input.
    types: column types for fields appearing in input.
    default_values: default values for fields appearing in input.
  Returns:
    A `Schema` object built from the per-column feature spec.
  """
  result = {}
  assert len(columns) == len(types)
  assert len(columns) == len(default_values)
  for c, t, v in zip(columns, types, default_values):
    if isinstance(t, list):
      result[c] = tf.VarLenFeature(dtype=t[0])
    else:
      result[c] = tf.FixedLenFeature(shape=[], dtype=t, default_value=v)
  return dataset_schema.from_feature_spec(result)
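A hypothetical call, purely illustrative (these column names, types, and defaults are not from the original pipeline): a list-valued type selects the VarLenFeature branch, anything else becomes a scalar FixedLenFeature with the given default.

schema = _make_schema(
    columns=['age', 'gender', 'tags'],
    types=[tf.int64, tf.string, [tf.string]],   # list type => VarLenFeature
    default_values=[0, '', None])               # default unused for 'tags'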
Example #35
def make_input_schema(mode=tf.contrib.learn.ModeKeys.TRAIN):
  """Input schema definition.

  Args:
    mode: tf.contrib.learn.ModeKeys specifying if the schema is being used for
      train/eval or prediction.
  Returns:
    A `Schema` object.
  """
  result = ({} if mode == tf.contrib.learn.ModeKeys.INFER else {
      'score': tf.FixedLenFeature(shape=[], dtype=tf.float32)
  })
  result.update({
      'subreddit': tf.FixedLenFeature(shape=[], dtype=tf.string),
      'author': tf.FixedLenFeature(shape=[], dtype=tf.string),
      'comment_body': tf.FixedLenFeature(shape=[], dtype=tf.string,
                                         default_value=''),
      'comment_parent_body': tf.FixedLenFeature(shape=[], dtype=tf.string,
                                                default_value=''),
      'toplevel': tf.FixedLenFeature(shape=[], dtype=tf.int64),
  })
  return dataset_schema.from_feature_spec(result)
Example #36
def transform_data(input_handle,
                   outfile_prefix,
                   working_dir,
                   setup_file, ts1, ts2,
                   project=None,
                   max_rows=None,
                   mode=None,
                   stage=None,
                   preprocessing_fn=None):
  """The main tf.transform method which analyzes and transforms data.

  Args:
    input_handle: BigQuery table name to process, specified as DATASET.TABLE,
      or path to a csv file with input data.
    outfile_prefix: Filename prefix for emitted transformed examples.
    working_dir: Directory in which transformed examples and the transform
      function will be emitted.
    setup_file: Setup file shipped to Dataflow workers when mode is 'cloud'.
    ts1: Timestamp bound forwarded to make_sql when querying BigQuery.
    ts2: Timestamp bound forwarded to make_sql when querying BigQuery.
    project: GCP project to run the pipeline in.
    max_rows: Number of rows to query from BigQuery.
    mode: 'local' to run with DirectRunner, 'cloud' to run on Dataflow.
    stage: Stage name used for output naming; defaults to 'train'.
    preprocessing_fn: Optional preprocessing_fn; defaults to the built-in
      def_preprocessing_fn.
  """

  def def_preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.

    Args:
      inputs: map from feature keys to raw not-yet-transformed features.

    Returns:
      Map from string feature key to transformed feature operations.
    """
    outputs = {}
    for key in taxi.DENSE_FLOAT_FEATURE_KEYS:
      # Preserve this feature as a dense float, setting nan's to the mean.
      outputs[taxi.transformed_name(key)] = transform.scale_to_z_score(
          _fill_in_missing(inputs[key]))

    for key in taxi.VOCAB_FEATURE_KEYS:
      # Build a vocabulary for this feature.
      outputs[
          taxi.transformed_name(key)] = transform.compute_and_apply_vocabulary(
              _fill_in_missing(inputs[key]),
              top_k=taxi.VOCAB_SIZE,
              num_oov_buckets=taxi.OOV_SIZE)

    for key in taxi.BUCKET_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = transform.bucketize(
          _fill_in_missing(inputs[key]), taxi.FEATURE_BUCKET_COUNT)

    for key in taxi.CATEGORICAL_FEATURE_KEYS:
      outputs[taxi.transformed_name(key)] = _fill_in_missing(inputs[key])

    # Was this passenger a big tipper?
    taxi_fare = _fill_in_missing(inputs[taxi.FARE_KEY])
    tips = _fill_in_missing(inputs[taxi.LABEL_KEY])
    outputs[taxi.transformed_name(taxi.LABEL_KEY)] = tf.where(
        tf.is_nan(taxi_fare),
        tf.cast(tf.zeros_like(taxi_fare), tf.int64),
        # Test if the tip was > 20% of the fare.
        tf.cast(
            tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))),
            tf.int64))

    return outputs

  preprocessing_fn = preprocessing_fn or def_preprocessing_fn

  print('ts1 %s, ts2 %s' % (ts1,ts2))

  schema = taxi.read_schema('./schema.pbtxt')
  raw_feature_spec = taxi.get_raw_feature_spec(schema)
  raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)

  transform_dir = None

  temp_dir = os.path.join(working_dir, 'tmp')
  if stage is None:
    stage = 'train'

  if mode == 'local':
    options = {
      'project': project}
    pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
    runner = 'DirectRunner'
  elif mode == 'cloud':
    options = {
      'job_name': 'tft-' + stage + '-' + str(uuid.uuid4()),
      'temp_location': temp_dir,
      'project': project,
      'save_main_session': True,
      'setup_file': setup_file
    }
    pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
    runner = 'DataflowRunner'
  else:
    raise ValueError("Invalid mode %s." % mode)

  with beam.Pipeline(runner, options=pipeline_options) as pipeline:
    with beam_impl.Context(temp_dir=temp_dir):
      csv_coder = taxi.make_csv_coder(schema)
      # temp tft bug workaround
      mcsv_coder = make_mcsv_coder(schema)
      if 'csv' in input_handle.lower():
      # if input_handle.lower().endswith('csv'):
        raw_data = (
            pipeline
            | 'ReadFromText' >> beam.io.ReadFromText(
                input_handle, skip_header_lines=1)
            | 'ParseCSV' >> beam.Map(csv_coder.decode))
      else:
        query = make_sql(input_handle, ts1, ts2, stage, max_rows=max_rows, for_eval=False)
        raw_data1 = (
            pipeline
            | 'ReadBigQuery' >> beam.io.Read(
                beam.io.BigQuerySource(query=query, use_standard_sql=True)))
        raw_data = (
            raw_data1
            | 'CleanData' >> beam.Map(
                lambda x: (taxi.clean_raw_data_dict(x, raw_feature_spec))))

      if transform_dir is None:
        transform_fn = (
            (raw_data, raw_data_metadata)
            | ('Analyze' >> beam_impl.AnalyzeDataset(preprocessing_fn)))

        _ = (
            transform_fn
            | ('WriteTransformFn' >>
               transform_fn_io.WriteTransformFn(working_dir)))
      else:
        transform_fn = pipeline | transform_fn_io.ReadTransformFn(transform_dir)

      # Shuffling the data before materialization will improve training
      # effectiveness downstream.
      shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()

      (transformed_data, transformed_metadata) = (
          ((shuffled_data, raw_data_metadata), transform_fn)
          | 'Transform' >> beam_impl.TransformDataset())

      if 'csv' not in input_handle.lower():  # if querying BQ
        _ = (
            raw_data
            | beam.Map(mcsv_coder.encode)
            | beam.io.WriteToText(os.path.join(working_dir, '{}.csv'.format(stage)), num_shards=1)
            )

      coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
      _ = (
          transformed_data
          | 'SerializeExamples' >> beam.Map(coder.encode)
          | 'WriteExamples' >> beam.io.WriteToTFRecord(
              os.path.join(working_dir, outfile_prefix), file_name_suffix='.gz')
      )
Example #37
from tensorflow.core.example import example_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2
# pylint: enable=g-direct-tensorflow-import
from tfx.components.base import base_executor
from tfx.components.transform import common
from tfx.components.transform import labels
from tfx.components.transform import messages
from tfx.utils import io_utils
from tfx.utils import types


RAW_EXAMPLE_KEY = 'raw_example'

# Schema to use if the input data should be decoded as raw example.
_RAW_EXAMPLE_SCHEMA = dataset_schema.from_feature_spec(
    {RAW_EXAMPLE_KEY: tf.FixedLenFeature([], tf.string)})

# TODO(b/123519698): Simplify the code by removing the key structure.
_TRANSFORM_INTERNAL_FEATURE_FOR_KEY = '__TFT_PASS_KEY__'

# Default file name prefix for transformed_examples.
_DEFAULT_TRANSFORMED_EXAMPLES_PREFIX = 'transformed_examples'

# Temporary path inside transform_output used for tft.beam
# TODO(b/125451545): Provide a safe temp path from base executor instead.
_TEMP_DIR_IN_TRANSFORM_OUTPUT = '.temp_path'


# TODO(b/122478841): Move it to a common place that is shared across components.
class _Status(object):
  """Status that reports success or error status of an execution."""
Example #38
def make_proto_coder(schema):
  raw_feature_spec = get_raw_feature_spec(schema)
  raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  return tft_coders.ExampleProtoCoder(raw_schema)
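Usage sketch, assuming `schema` and an `instance` dict matching its feature spec exist: the coder serializes one instance into tf.Example bytes suitable for WriteToTFRecord.

proto_coder = make_proto_coder(schema)
serialized_example = proto_coder.encode(instance)  # bytes for a TFRecord file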
Example #39
def make_csv_coder(schema):
  """Return a coder for tf.transform to read csv files."""
  raw_feature_spec = get_raw_feature_spec(schema)
  parsing_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  return tft_coders.CsvCoder(CSV_COLUMN_NAMES, parsing_schema)
Example #40
  def _RunBeamImpl(self, inputs,
                   outputs, preprocessing_fn,
                   input_dataset_metadata,
                   raw_examples_data_format, transform_output_path,
                   compute_statistics,
                   materialize_output_paths):
    """Perform data preprocessing with Apache Beam.

    Args:
      inputs: A dictionary of labelled input values.
      outputs: A dictionary of labelled output values.
      preprocessing_fn: The tf.Transform preprocessing_fn.
      input_dataset_metadata: A DatasetMetadata object for the input data.
      raw_examples_data_format: A string describing the raw data format.
      transform_output_path: An absolute path to write the output to.
      compute_statistics: A bool indicating whether or not to compute statistics.
      materialize_output_paths: Paths to materialized outputs.

    Raises:
      RuntimeError: If reset() is not invoked between two consecutive run() calls.
      ValueError: If the schema is empty.

    Returns:
      Status of the execution.
    """
    raw_examples_file_format = common.GetSoleValue(
        inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
    analyze_and_transform_data_paths = common.GetValues(
        inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
    transform_only_data_paths = common.GetValues(
        inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
    stats_use_tfdv = common.GetSoleValue(inputs,
                                         labels.TFT_STATISTICS_USE_TFDV_LABEL)
    per_set_stats_output_paths = common.GetValues(
        outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
    temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

    tf.logging.info('Analyze and transform data patterns: %s',
                    list(enumerate(analyze_and_transform_data_paths)))
    tf.logging.info('Transform data patterns: %s',
                    list(enumerate(transform_only_data_paths)))
    tf.logging.info('Transform materialization output paths: %s',
                    list(enumerate(materialize_output_paths)))
    tf.logging.info('Transform output path: %s', transform_output_path)

    feature_spec = input_dataset_metadata.schema.as_feature_spec()
    try:
      analyze_input_columns = tft.get_analyze_input_columns(
          preprocessing_fn, feature_spec)
      transform_input_columns = (
          tft.get_transform_input_columns(preprocessing_fn, feature_spec))
    except AttributeError:
      # If using TFT 1.12, fall back to assuming all features are used.
      analyze_input_columns = feature_spec.keys()
      transform_input_columns = feature_spec.keys()
    # Use the same dataset (same columns) for AnalyzeDataset and computing
    # pre-transform stats so that the data will only be read once for these
    # two operations.
    if compute_statistics:
      analyze_input_columns = list(
          set(list(analyze_input_columns) + list(transform_input_columns)))
    analyze_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
    transform_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
    if input_dataset_metadata.schema is not _RAW_EXAMPLE_SCHEMA:
      analyze_input_dataset_metadata.schema = dataset_schema.from_feature_spec(
          {feature: feature_spec[feature] for feature in analyze_input_columns})
      transform_input_dataset_metadata.schema = (
          dataset_schema.from_feature_spec({
              feature: feature_spec[feature]
              for feature in transform_input_columns
          }))

    can_process_jointly = not bool(per_set_stats_output_paths or
                                   materialize_output_paths)
    analyze_data_list = self._MakeDatasetList(
        analyze_and_transform_data_paths, raw_examples_file_format,
        raw_examples_data_format, analyze_input_dataset_metadata,
        can_process_jointly)
    transform_data_list = self._MakeDatasetList(
        list(analyze_and_transform_data_paths) +
        list(transform_only_data_paths), raw_examples_file_format,
        raw_examples_data_format, transform_input_dataset_metadata,
        can_process_jointly)

    desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

    with self._CreatePipeline(outputs) as p:
      with tft_beam.Context(
          temp_dir=temp_path,
          desired_batch_size=desired_batch_size,
          passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
          use_deep_copy_optimization=True):
        # pylint: disable=expression-not-assigned
        # pylint: disable=no-value-for-parameter

        analyze_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    analyze_input_dataset_metadata.schema))

        for (idx, dataset) in enumerate(analyze_data_list):
          dataset.encoded = (
              p | 'ReadAnalysisDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeAnalysisDataset[{}]'.format(idx) >>
              self._DecodeInputs(analyze_decode_fn))

        input_analysis_data = (
            [dataset.decoded for dataset in analyze_data_list]
            | 'FlattenAnalysisDatasets' >> beam.Flatten())
        transform_fn = (
            (input_analysis_data, input_dataset_metadata)
            | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn))
        # Write the raw/input metadata.
        (input_dataset_metadata
         | 'WriteMetadata' >> tft_beam.WriteMetadata(
             os.path.join(transform_output_path,
                          tft.TFTransformOutput.RAW_METADATA_DIR), p))

        # WriteTransformFn writes transform_fn and metadata to subdirectories
        # tensorflow_transform.SAVED_MODEL_DIR and
        # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
        (transform_fn |
         'WriteTransformFn' >> tft_beam.WriteTransformFn(transform_output_path))

        if compute_statistics or materialize_output_paths:
          # Do not compute pre-transform stats if the input format is raw proto,
          # as StatsGen would treat any input as tf.Example.
          if (compute_statistics and
              not self._IsDataFormatProto(raw_examples_data_format)):
            # Aggregated feature stats before transformation.
            pre_transform_feature_stats_path = os.path.join(
                transform_output_path,
                tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

            # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
            # schema. Currently input dataset schema only contains dtypes,
            # and other metadata is dropped due to roundtrip to tensors.
            schema_proto = schema_utils.schema_from_feature_spec(
                analyze_input_dataset_metadata.schema.as_feature_spec())
            ([
                dataset.decoded if stats_use_tfdv else dataset.encoded
                for dataset in analyze_data_list
            ]
             | 'FlattenPreTransformAnalysisDatasets' >> beam.Flatten()
             | 'GenerateAggregatePreTransformAnalysisStats' >>
             self._GenerateStats(
                 pre_transform_feature_stats_path,
                 schema_proto,
                 use_deep_copy_optimization=True,
                 use_tfdv=stats_use_tfdv))

          transform_decode_fn = (
              self._GetDecodeFunction(raw_examples_data_format,
                                      transform_input_dataset_metadata.schema))
          # transform_data_list is a superset of analyze_data_list; we pay the
          # cost of reading the same dataset (analyze_data_list) again here to
          # prevent certain Beam runners from doing large temp materialization.
          for (idx, dataset) in enumerate(transform_data_list):
            dataset.encoded = (
                p
                | 'ReadTransformDataset[{}]'.format(idx) >>
                self._ReadExamples(dataset))
            dataset.decoded = (
                dataset.encoded
                | 'DecodeTransformDataset[{}]'.format(idx) >>
                self._DecodeInputs(transform_decode_fn))
            (dataset.transformed,
             metadata) = (((dataset.decoded, transform_input_dataset_metadata),
                           transform_fn)
                          | 'TransformDataset[{}]'.format(idx) >>
                          tft_beam.TransformDataset())

            if materialize_output_paths or not stats_use_tfdv:
              dataset.transformed_and_encoded = (
                  dataset.transformed
                  | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                      self._EncodeAsExamples(), metadata))

          if compute_statistics:
            # Aggregated feature stats after transformation.
            _, metadata = transform_fn
            post_transform_feature_stats_path = os.path.join(
                transform_output_path,
                tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

            # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
            # schema. Currently input dataset schema only contains dtypes,
            # and other metadata is dropped due to roundtrip to tensors.
            transformed_schema_proto = schema_utils.schema_from_feature_spec(
                metadata.schema.as_feature_spec())

            ([(dataset.transformed
               if stats_use_tfdv else dataset.transformed_and_encoded)
              for dataset in transform_data_list]
             | 'FlattenPostTransformAnalysisDatasets' >> beam.Flatten()
             | 'GenerateAggregatePostTransformAnalysisStats' >>
             self._GenerateStats(
                 post_transform_feature_stats_path,
                 transformed_schema_proto,
                 use_tfdv=stats_use_tfdv))

            if per_set_stats_output_paths:
              assert len(transform_data_list) == len(per_set_stats_output_paths)
              # TODO(b/67632871): Remove duplicate stats gen compute that is
              # done both on a flattened view of the data, and on each span
              # below.
              bundles = zip(transform_data_list, per_set_stats_output_paths)
              for (idx, (dataset, output_path)) in enumerate(bundles):
                if stats_use_tfdv:
                  data = dataset.transformed
                else:
                  data = dataset.transformed_and_encoded
                (data
                 | 'GeneratePostTransformStats[{}]'.format(idx) >>
                 self._GenerateStats(
                     output_path,
                     transformed_schema_proto,
                     use_tfdv=stats_use_tfdv))

          if materialize_output_paths:
            assert len(transform_data_list) == len(materialize_output_paths)
            bundles = zip(transform_data_list, materialize_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              (dataset.transformed_and_encoded
               | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                   raw_examples_file_format, output_path))

    return _Status.OK()