Example #1
    def testCaptureHashTable(self):
        # NOTE(mrry): We must use the V2 variants of `HashTable`
        # etc. because these produce a `tf.resource`-typed output that is
        # compatible with the in-graph function implementation.
        default_val = -1
        keys = constant_op.constant(["brain", "salad", "surgery"])
        values = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

        input_sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])

        iterator = dataset_ops.make_initializable_iterator(
            input_sentences.map(
                lambda x: string_ops.string_split([x]).values).map(
                    table.lookup))
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(table.initializer)
            sess.run(init_op)
            sess.run(get_next)
            sess.run(get_next)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
Example #2
    def testCaptureHashTableInSharedIterator(self):
        worker, _ = test_util.create_local_cluster(1, 1)

        # NOTE(mrry): We must use the V2 variants of `HashTable`
        # etc. because these produce a `tf.resource`-typed output that is
        # compatible with the in-graph function implementation.
        default_val = -1
        keys = constant_op.constant(["brain", "salad", "surgery"])
        values = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values),
            default_val,
            shared_name="shared_table")

        input_sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])

        iterator = (input_sentences.map(
            lambda x: string_ops.string_split([x]).values).map(
                table.lookup).make_initializable_iterator(
                    shared_name="shared_iterator"))
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with session.Session(worker[0].target) as sess:
            sess.run(table.initializer)
            sess.run(init_op)
            self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

        with session.Session(worker[0].target) as sess:
            self.assertAllEqual([2, 0], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
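Both tests above exercise the graph-mode pattern of capturing a hash table inside Dataset.map and driving it through session runs. For reference, a minimal eager-mode sketch of the same idea using the public tf.lookup and tf.data APIs (not part of the tests above) could look like this:

import tensorflow as tf

# Build a token -> id table and capture it in a tf.data pipeline (eager mode).
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

sentences = tf.data.Dataset.from_tensor_slices(
    ["brain brain tank salad surgery", "surgery brain"])
# tf.strings.split on a scalar string returns a rank-1 tensor, so table.lookup
# can be mapped directly over the split tokens.
ids = sentences.map(lambda x: tf.strings.split(x)).map(table.lookup)

for element in ids:
    print(element.numpy())  # [0 0 -1 1 2], then [2 0]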
Example #3
    def test_filter_input_subsample_vocab(self):
        """Tests input filtering based on vocab subsampling."""
        # The outputs are non-deterministic, so set random seed to help ensure
        # that the outputs remain constant for testing.
        random_seed.set_random_seed(42)

        input_tensor = tf.constant([
            # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
            b"the",
            b"answer",  # Not in vocab. (Always discarded)
            b"to",  # keep_prob = 0.75.
            b"life",  # keep_prob > 1. (Always kept)
            b"and",  # keep_prob = 0.48.
            b"universe"  # Below vocab threshold of 3. (Always discarded)
        ])
        keys = tf.constant([b"and", b"life", b"the", b"to", b"universe"])
        values = tf.constant([40, 8, 30, 20, 2], tf.dtypes.int64)
        vocab_freq_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), -1)

        output = skip_gram_ops._filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=vocab_freq_table,
            vocab_min_count=3,
            vocab_subsampling=0.05,
            corpus_size=tf.math.reduce_sum(values),
            seed=9)
        self.assertAllEqual([b"the", b"to", b"life", b"and"], output)
Example #4
def _create_table(vocab, num_oov=1):
  init = lookup_ops.KeyValueTensorInitializer(
      vocab,
      math_ops.range(
          array_ops.size(vocab, out_type=dtypes.int64), dtype=dtypes.int64),
      key_dtype=dtypes.string,
      value_dtype=dtypes.int64)
  return lookup_ops.StaticVocabularyTableV1(
      init, num_oov, lookup_key_dtype=dtypes.string)
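StaticVocabularyTableV1 sends out-of-vocabulary keys to one of num_oov extra buckets instead of returning a fixed default value. A minimal sketch of the same construction with the public TF2 API (the vocabulary and lookups below are illustrative, not taken from the example above):

import tensorflow as tf

vocab = tf.constant(["brain", "salad", "surgery"])
init = tf.lookup.KeyValueTensorInitializer(
    keys=vocab,
    values=tf.range(tf.size(vocab, out_type=tf.int64), dtype=tf.int64))
# One extra bucket: out-of-vocabulary strings hash into id 3 (= vocab size).
table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets=1)

print(table.lookup(tf.constant(["salad", "unknown"])).numpy())  # [1 3]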
Example #5
    def test_skip_gram_sample_errors(self):
        """Tests various errors raised by skip_gram_sample()."""
        input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

        invalid_skips = (
            # min_skips and max_skips must be >= 0.
            (-1, 2),
            (1, -2),
            # min_skips must be <= max_skips.
            (2, 1))
        for min_skips, max_skips in invalid_skips:
            with self.assertRaises(errors.InvalidArgumentError):
                text.skip_gram_sample(input_tensor,
                                      min_skips=min_skips,
                                      max_skips=max_skips)

        # Eager tensor must be rank 1
        with self.assertRaises(errors.InvalidArgumentError):
            invalid_tensor = constant_op.constant([[b"the"], [b"quick"],
                                                   [b"brown"]])
            text.skip_gram_sample(invalid_tensor)

        # vocab_freq_table must be provided if vocab_min_count,
        # vocab_subsampling, or corpus_size is specified.
        dummy_input = constant_op.constant([""])
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=None,
                                  vocab_min_count=1)
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=None,
                                  vocab_subsampling=1e-5)
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=None,
                                  corpus_size=100)
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=None,
                                  vocab_subsampling=1e-5,
                                  corpus_size=100)

        # vocab_subsampling and corpus_size must both be present or absent.
        dummy_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer([b"foo"], [10]), -1)
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=dummy_table,
                                  vocab_subsampling=None,
                                  corpus_size=100)
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input,
                                  vocab_freq_table=dummy_table,
                                  vocab_subsampling=1e-5,
                                  corpus_size=None)
Example #6
 def get_graph_def():
   with ops.Graph().as_default() as g:
     x = constant_op.constant([2, 9], name="x")
     keys = constant_op.constant([1, 2], name="keys")
     values = constant_op.constant([3, 4], name="values")
     default = constant_op.constant(-1, name="default")
     table = lookup_ops.StaticHashTable(
         lookup_ops.KeyValueTensorInitializer(keys, values), default)
     _ = table.lookup(x)
   return g.as_graph_def()
Example #7
 def testDistributeLookupTable(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     keys_tensor = constant_op.constant([1, 2])
     vals_tensor = constant_op.constant([11, 12])
     table = lookup_ops.StaticHashTable(
         lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
     ds = dataset_ops.Dataset.range(3, output_type=dtypes.int32)
     ds = ds.map(table.lookup)
     ds = self.make_distributed_dataset(ds, cluster)
     self.assertDatasetProduces(ds, [-1, 11, 12])
Example #8
    def test_filter_input_filter_vocab(self):
        """
        Tests input filtering based on vocab frequency table and thresholds.
        """
        input_tensor = tf.constant(
            [b"the", b"answer", b"to", b"life", b"and", b"universe"])
        keys = tf.constant([b"and", b"life", b"the", b"to",
                                     b"universe"])
        values = tf.constant([0, 1, 2, 3, 4], tf.dtypes.int64)
        vocab_freq_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), -1)

        # No vocab_freq_table specified - output should be the same as input
        no_table_output = skip_gram_ops._filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=None,
            vocab_min_count=None,
            vocab_subsampling=None,
            corpus_size=None,
            seed=None)
        self.assertAllEqual(input_tensor, no_table_output)

        # vocab_freq_table specified, but no vocab_min_count - output should
        # have filtered out tokens not in the table (b"answer").
        table_output = skip_gram_ops._filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=vocab_freq_table,
            vocab_min_count=None,
            vocab_subsampling=None,
            corpus_size=None,
            seed=None)
        self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
                            table_output)

        # vocab_freq_table and vocab_min_count specified - output should have
        # filtered out tokens whose frequencies are below the threshold
        # (b"and": 0, b"life": 1).
        threshold_output = skip_gram_ops._filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=vocab_freq_table,
            vocab_min_count=2,
            vocab_subsampling=None,
            corpus_size=None,
            seed=None)
        self.assertAllEqual([b"the", b"to", b"universe"], threshold_output)
Example #9
 def make_initializer(self, init_source, vals):
   if init_source == "textfile":
     file = os.path.join(self.get_temp_dir(), "text_file_initializer")
     with open(file, "w") as f:
       f.write("\n".join(str(v) for v in vals) + "\n")
     return lookup_ops.TextFileInitializer(
         filename=file,
         key_dtype=dtypes.int64,
         key_index=lookup_ops.TextFileIndex.LINE_NUMBER,
         value_dtype=dtypes.int64,
         value_index=lookup_ops.TextFileIndex.WHOLE_LINE)
   elif init_source == "keyvaluetensor":
     keys_tensor = constant_op.constant(
         list(range(len(vals))), dtype=dtypes.int64)
     vals_tensor = constant_op.constant(vals)
     return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
   else:
     raise ValueError("Unrecognized init_source: " + init_source)
Example #10
def read_tf_vocab(input_file, UNK):
    """Read vocabulary and return a tf hashtable"""
    if input_file is None:
        return None

    keys, values = [], []
    fin = tf.io.gfile.GFile(input_file, 'r')
    for line in fin:
        word = split(strip(line))[0]
        keys.append(word)
        values.append(len(values))
    fin.close()
    UNK_ID = keys.index(UNK)

    initializer = lookup_ops.KeyValueTensorInitializer(tf.constant(keys),
                                                       tf.constant(values))
    vocab_table = lookup_ops.HashTable(initializer, UNK_ID)
    return initializer, vocab_table
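The same token-to-id mapping can be written directly against the public TF2 API. A minimal sketch, with a hypothetical vocabulary list and the UNK token's id used as the default for out-of-vocabulary words, mirroring the function above:

import tensorflow as tf

vocab = ["<unk>", "the", "quick", "brown"]  # hypothetical vocabulary
unk_id = vocab.index("<unk>")

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(vocab), tf.range(len(vocab), dtype=tf.int64)),
    default_value=unk_id)

print(table.lookup(tf.constant(["quick", "fox"])).numpy())  # [2 0]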
Example #11
  def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
    self._dataset_fn_tracing_count = 0
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      vals = [0, 1, 2]
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      initializer = lookup_ops.KeyValueTensorInitializer(
          keys_tensor, vals_tensor)
      per_worker_table = lookup_ops.StaticHashTable(
          initializer, default_value=-1)

    @def_function.function
    def dataset_fn(input_context):
      tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      global_batch_size = 2
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
          batch_size, drop_remainder=True)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
      dataset = dataset.map(per_worker_table.lookup)
      self._dataset_fn_tracing_count += 1
      return dataset

    dist_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(inputs):
      # inputs should be [0, 1, -1]
      return math_ops.reduce_sum(inputs)

    def train_steps(iterator, steps):

      for _ in math_ops.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

    train_steps(dist_iterator, steps=5)
    self.assertEqual(self._dataset_fn_tracing_count, 1)
Example #12
    def __init__(self, init_source, filepath):
      vals = [0, 1, 2]
      if init_source == "textfile":

        with open(filepath, "w") as f:
          f.write("\n".join(str(v) for v in vals) + "\n")

        self.initializer = lookup_ops.TextFileInitializer(
            filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
            dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
      else:
        keys_tensor = constant_op.constant(
            list(range(len(vals))), dtype=dtypes.int64)
        vals_tensor = constant_op.constant(vals)
        self.initializer = lookup_ops.KeyValueTensorInitializer(
            keys_tensor, vals_tensor)

      self.table = lookup_ops.StaticHashTable(
          self.initializer, default_value=-2)
Example #13
    def testDetokenizeFailsForSparseVocab(self):
        vocab = ["a", "##b", "##c"]
        ids = [0, 10, 20]
        init = lookup_ops.KeyValueTensorInitializer(vocab,
                                                    ids,
                                                    key_dtype=dtypes.string,
                                                    value_dtype=dtypes.int64)
        table = lookup_ops.StaticVocabularyTableV1(
            init, num_oov_buckets=1, lookup_key_dtype=dtypes.string)
        self.evaluate(table.initializer)

        tokenizer = WordpieceTokenizer(table)
        words = ragged_factory_ops.constant([["abb", "abc"], ["abcbc"]])
        subwords_ids = tokenizer.tokenize(words)

        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    "detokenize.*?dense on the interval"):
            result = tokenizer.detokenize(subwords_ids)
            self.evaluate(result)
Example #14
def read_tf_vocab(input_file, UNK):
    """Read vocabulary and return a tf hashtable"""
    if input_file is None:
        return None

    keys, values = [], []
    if input_file.endswith('.gz'):
        f = tf.gfile.Open(input_file, 'r')
        fin = gzip.GzipFile(fileobj=f)
    else:
        fin = tf.gfile.Open(input_file, 'r')
    for line in fin:
        word = split(strip(line))[0]
        keys.append(word)
        values.append(len(values))
    fin.close()
    UNK_ID = keys.index(UNK)
    vocab_table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(tf.constant(keys),
                                             tf.constant(values)), UNK_ID)
    return vocab_table
Example #15
 def testDistributeLookupTable(self, init_from_file):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     if init_from_file:
         file = os.path.join(self.get_temp_dir(), "distribute_lookup_table")
         with open(file, "w") as f:
             f.write("10\n11\n")
         initializer = lookup_ops.TextFileInitializer(
             file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
             dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
     else:
         keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
         vals_tensor = constant_op.constant([10, 11])
         initializer = lookup_ops.KeyValueTensorInitializer(
             keys_tensor, vals_tensor)
     table = lookup_ops.StaticHashTable(initializer, -1)
     ds = dataset_ops.Dataset.range(3)
     ds = ds.map(table.lookup)
     ds = self.make_distributed_dataset(ds, cluster)
     self.evaluate(lookup_ops.tables_initializer())
     self.assertDatasetProduces(ds, [10, 11, -1],
                                requires_initialization=True)
Example #16
  def _make_model_with_tables(self):
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table1_initializer = lookup_ops.KeyValueTensorInitializer(keys, values)
    table1 = lookup_ops.HashTable(table1_initializer, default_val)

    table2_file = self._make_asset("test\nfoo\nbrain\n")
    table2_initializer = lookup_ops.TextFileIdTableInitializer(table2_file)
    table2 = lookup_ops.HashTable(table2_initializer, default_val)

    def _make_lookup_function(table):
      signature = [tensor_spec.TensorSpec(None, dtypes.string)]
      return def_function.function(input_signature=signature)(
          lambda x: table.lookup(x))  # pylint: disable=unnecessary-lambda

    root = tracking.AutoTrackable()
    root.table1 = table1
    root.lookup1 = _make_lookup_function(table1)
    root.table2 = table2
    root.lookup2 = _make_lookup_function(table2)
    return root
Example #17
def read_tf_vocab_inverse(input_file, UNK):
    """Read vocabulary (token->id) and return a tf hashtable (id->token)"""
    if input_file is None:
        return None

    keys, values = [], []
    if input_file.endswith('.gz'):
        f = tf.io.gfile.GFile(input_file, 'r')
        fin = gzip.GzipFile(fileobj=f)
    else:
        fin = tf.io.gfile.GFile(input_file, 'r')
    for line in fin:
        word = split(strip(line))[0]

        keys.append(len(keys))
        values.append(word)
    fin.close()

    initializer = lookup_ops.KeyValueTensorInitializer(tf.constant(keys),
                                                       tf.constant(values))
    vocab_table = lookup_ops.HashTable(initializer, UNK)
    return initializer, vocab_table
Example #18
    def testLookupTableGraphSerialization(self, init_from_file):
        if init_from_file:
            file = os.path.join(self.get_temp_dir(),
                                "lookup_table_graph_serialize")
            with open(file, "w") as f:
                f.write("10\n11\n")
            initializer = lookup_ops.TextFileInitializer(
                file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
                dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
        else:
            keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
            vals_tensor = constant_op.constant([10, 11])
            initializer = lookup_ops.KeyValueTensorInitializer(
                keys_tensor, vals_tensor)

        table = lookup_ops.StaticHashTable(initializer, -1)
        dataset = dataset_ops.Dataset.range(3)
        dataset = dataset.map(table.lookup)
        self.evaluate(lookup_ops.tables_initializer())
        round_tripped = self.graphRoundTrip(dataset)
        del table
        del dataset
        self.assertDatasetProduces(round_tripped, [10, 11, -1],
                                   requires_initialization=True)
Example #19
 def collecting_function(x):
   _ = lookup_ops.HashTable(
       lookup_ops.KeyValueTensorInitializer([], []), 0.0, name="t1")
   return x
Example #20
 def keyValueTensorInitializer(self, vals):
   keys_tensor = constant_op.constant(
       list(range(len(vals))), dtype=dtypes.int64)
   vals_tensor = constant_op.constant(vals)
   return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
Example #21
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    # REvo added
    tgt_table = codecs.open(src_vocab_file, 'r').readlines()
    tmp_ids = []
    tmp_words = []
    for i in range(len(tgt_table)):
        tmp_ids.append(i)
        tmp_words.append(tgt_table[i].strip())

    with graph.as_default(), tf.container(scope or "infer"):
        # Constant vocab table
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        # reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        #     tgt_vocab_file, default_value=vocab_utils.UNK, name="reverse_table")
        # added
        vals = tf.constant(tmp_words, dtype=tf.string)
        keys = tf.constant(tmp_ids, dtype=tf.int64)
        reverse_tgt_vocab_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, vals),
            "<unk>",
            name="reverse_table")
        #

        # debug
        print("SRC:", src_vocab_table)
        print("SRC type:", type(src_vocab_table))
        #
        src_placeholder = tf.placeholder(shape=[None],
                                         dtype=tf.string,
                                         name="src_place")
        # batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64, name="batch_place")
        batch_size_placeholder = tf.constant(1,
                                             dtype=tf.int64,
                                             name="batch_place")

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)

        # Debug
        # with tf.Session() as sess:
        #     # init
        #     sess.run(
        #         iterator.initializer,
        #         feed_dict={
        #             src_placeholder: iterator.infer_data,
        #             batch_size_placeholder: 64
        #         })
        #     value = sess.run(iterator.source)
        #     print ("value:", value)
        # sys.exit()

    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator,
                      insert_op=(src_vocab_table.init, tgt_vocab_table.init,
                                 reverse_tgt_vocab_table.init))
Example #22
def build_dataset(file_pattern,
                  input_config,
                  batch_size,
                  include_labels=True,
                  reverse_time_series_prob=0,
                  shuffle_filenames=False,
                  shuffle_values_buffer=0,
                  repeat=1,
                  use_tpu=False):
    """Builds an input pipeline that reads a dataset from sharded TFRecord files.

    Args:
      file_pattern: File pattern matching input TFRecord files, e.g.
          "/tmp/train-?????-of-00100". May also be a comma-separated list of file
          patterns.
      input_config: ConfigDict containing feature and label specifications.
      batch_size: The number of examples per batch.
      include_labels: Whether to read labels from the input files.
      reverse_time_series_prob: If > 0, the time series features will be randomly
          reversed with this probability. Within a given example, either all time
          series features will be reversed, or none will be reversed.
      shuffle_filenames: Whether to shuffle the order of TFRecord files between
          epochs.
      shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
      repeat: The number of times to repeat the dataset. If None or -1 the dataset
          will repeat indefinitely.
      use_tpu: Whether to build the dataset for TPU.

    Raises:
      ValueError: If an input file pattern does not match any files, or if the
          label IDs in input_config.label_map are not contiguous integers starting
          at 0.

    Returns:
      A tf.data.Dataset object.
    """
    file_patterns = file_pattern.split(",")
    filenames = []
    for p in file_patterns:
        matches = tf.io.gfile.glob(p)
        if not matches:
            raise ValueError("Found no input files matching %s" % p)
        filenames.extend(matches)
    tf.compat.v1.logging.info("Building input pipeline from %d files matching patterns: %s",
                              len(filenames), file_patterns)

    if include_labels:
        # Ensure that the label ids are contiguous integers starting at 0.
        label_ids = set(input_config.label_map.values())
        if label_ids != set(range(len(label_ids))):
            raise ValueError(
                "Label IDs must be contiguous integers starting at 0. Got: %s" %
                label_ids)

        # Create a HashTable mapping label strings to integer ids.
        table_initializer = lookup_ops.KeyValueTensorInitializer(
            keys=list(input_config.label_map.keys()),
            values=list(input_config.label_map.values()),
            key_dtype=tf.string,
            value_dtype=tf.int32)
        label_to_id = lookup_ops.HashTable(
            table_initializer, default_value=-1)

    def _example_parser(serialized_example):
        """Parses a single tf.Example into feature and label tensors."""
        # Set specifications for parsing the features.
        data_fields = {
            feature_name: tf.io.FixedLenFeature([feature.length], tf.float32)
            for feature_name, feature in input_config.features.items()
        }
        if include_labels:
            data_fields[input_config.label_feature] = tf.io.FixedLenFeature(
                [], tf.string)

        # Parse the features.
        parsed_features = tf.io.parse_single_example(
            serialized=serialized_example, features=data_fields)

        if reverse_time_series_prob > 0:
            # Randomly reverse time series features with probability
            # reverse_time_series_prob.
            should_reverse = tf.less(
                tf.random.uniform([], 0, 1),
                reverse_time_series_prob,
                name="should_reverse")

        # Reorganize outputs.
        output = {}
        for feature_name, value in parsed_features.items():
            if include_labels and feature_name == input_config.label_feature:
                label_id = label_to_id.lookup(value)
                # Ensure that the label_id is nonnegative to verify a successful hash
                # map lookup.
                assert_known_label = tf.Assert(
                    tf.greater_equal(label_id, tf.cast(0, dtype=tf.int32)),
                    ["Unknown label string:", value])
                with tf.control_dependencies([assert_known_label]):
                    label_id = tf.identity(label_id)

                # We use the plural name "labels" in the output due to batching.
                output["labels"] = label_id
            elif input_config.features[feature_name].is_time_series:
                # Possibly reverse.
                if reverse_time_series_prob > 0:
                    # pylint:disable=cell-var-from-loop
                    value = tf.cond(
                        pred=should_reverse,
                        true_fn=lambda: tf.reverse(value, axis=[0]),
                        false_fn=lambda: tf.identity(value))
                    # pylint:enable=cell-var-from-loop
                if "time_series_features" not in output:
                    output["time_series_features"] = {}
                output["time_series_features"][feature_name] = value
            else:
                if "aux_features" not in output:
                    output["aux_features"] = {}
                output["aux_features"][feature_name] = value

        return output

    # Create a string dataset of filenames, and possibly shuffle.
    filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if len(filenames) > 1 and shuffle_filenames:
        filename_dataset = filename_dataset.shuffle(len(filenames))

    # Read serialized Example protos.
    dataset = filename_dataset.flat_map(tf.data.TFRecordDataset)

    # Possibly shuffle. Note that we shuffle before repeat(), so we only shuffle
    # elements among each "epoch" of data, and not across epochs of data.
    if shuffle_values_buffer > 0:
        dataset = dataset.shuffle(shuffle_values_buffer)

    # Repeat.
    if repeat != 1:
        dataset = dataset.repeat(repeat)

    # Map the parser over the dataset.
    dataset = dataset.map(_example_parser, num_parallel_calls=4)

    # Batch results by up to batch_size.
    dataset = dataset.batch(batch_size)
    if repeat == -1 or repeat is None:
        # The dataset repeats infinitely before batching, so each batch has the
        # maximum number of elements.
        dataset = set_batch_size(dataset, batch_size)
    elif use_tpu:
        # TPU requires all dimensions to be fixed. Since the dataset does not repeat
        # infinitely before batching, the final batch may have fewer than batch_size
        # elements. Therefore we pad to ensure that the final batch has batch_size
        # elements.
        dataset = pad_dataset_to_batch_size(dataset, batch_size)

    # Prefetch a few batches.
    dataset = dataset.prefetch(max(1, int(256 / batch_size)))

    return dataset
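The label-mapping portion of build_dataset reduces to a small, self-contained pattern: a static hash table from label strings to integer ids, with -1 marking an unknown label. A minimal sketch with a hypothetical label_map (not the one used by the pipeline above):

import tensorflow as tf

label_map = {"PC": 0, "AFP": 1, "NTP": 2}  # hypothetical label ids
label_to_id = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=list(label_map.keys()),
        values=list(label_map.values()),
        key_dtype=tf.string,
        value_dtype=tf.int32),
    default_value=-1)

# A successful lookup returns a nonnegative id; unknown labels map to -1,
# which the pipeline above turns into an assertion failure.
print(label_to_id.lookup(tf.constant("AFP")).numpy())    # 1
print(label_to_id.lookup(tf.constant("other")).numpy())  # -1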