Example #1
 def testTFRecordDatasetConstructorErrorsTensorInput(self):
     with self.assertRaisesRegex(TypeError,
                                 "filenames.*must be.*Tensor.*string"):
         readers.TFRecordDataset([1, 2, 3])
     with self.assertRaisesRegex(TypeError,
                                 "filenames.*must be.*Tensor.*string"):
         readers.TFRecordDataset(constant_op.constant([1, 2, 3]))
     # convert_to_tensor raises different errors in graph and eager
     with self.assertRaises(Exception):
         readers.TFRecordDataset(object())
 def testConstructorErrorsTensorInput(self):
     with self.assertRaisesRegex(
             TypeError,
             "The `filenames` argument must contain `tf.string` elements. Got "
             "`tf.int32` elements."):
         readers.TFRecordDataset([1, 2, 3])
     with self.assertRaisesRegex(
             TypeError,
             "The `filenames` argument must contain `tf.string` elements. Got "
             "`tf.int32` elements."):
         readers.TFRecordDataset(constant_op.constant([1, 2, 3]))
     # convert_to_tensor raises different errors in graph and eager
     with self.assertRaises(Exception):
         readers.TFRecordDataset(object())
    def make_dataset(self,
                     num_epochs,
                     batch_size=1,
                     compression_type=None,
                     buffer_size=None):
        filenames = self._createFiles()
        if compression_type == "ZLIB":
            zlib_files = []
            for i, fn in enumerate(filenames):
                with open(fn, "rb") as f:
                    cdata = zlib.compress(f.read())
                    zfn = os.path.join(self.get_temp_dir(),
                                       "tfrecord_%s.z" % i)
                    with open(zfn, "wb") as f:
                        f.write(cdata)
                    zlib_files.append(zfn)
            filenames = zlib_files

        elif compression_type == "GZIP":
            gzip_files = []
            for i, fn in enumerate(filenames):
                with open(fn, "rb") as f:
                    gzfn = os.path.join(self.get_temp_dir(),
                                        "tfrecord_%s.gz" % i)
                    with gzip.GzipFile(gzfn, "wb") as gzf:
                        gzf.write(f.read())
                    gzip_files.append(gzfn)
            filenames = gzip_files

        return readers.TFRecordDataset(
            filenames, compression_type,
            buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
  def testName(self):
    files = [self._filenames[0]]

    expected_output = [self._record(0, i) for i in range(self._num_records)]
    ds = readers.TFRecordDataset(files, name="tf_record_dataset")
    self.assertDatasetProduces(
        ds, expected_output=expected_output, assert_items_equal=True)
  def testPathlib(self):
    files = [pathlib.Path(self._filenames[0])]

    expected_output = [self._record(0, i) for i in range(self._num_records)]
    ds = readers.TFRecordDataset(files)
    self.assertDatasetProduces(
        ds, expected_output=expected_output, assert_items_equal=True)
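The constructor tests above assert that `filenames` must contain `tf.string` elements, and the `make_dataset` helper builds ZLIB- and GZIP-compressed copies of the test files. For orientation, here is a minimal sketch of valid construction through the public `tf.data` API; the file paths are hypothetical and are not part of the tests above.

import tensorflow as tf

# Filenames must be strings or a tf.string tensor; integer inputs
# trigger the TypeErrors asserted in the constructor tests.
filenames = ["/tmp/data-00000.tfrecord", "/tmp/data-00001.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

# A tf.string constant (or a pathlib.Path, as testPathlib shows) is also accepted.
dataset = tf.data.TFRecordDataset(tf.constant(filenames))

for raw_record in dataset.take(1):
    print(raw_record.numpy())  # raw serialized record bytes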
Example #6
 def testReadFromDatasetOfFiles(self):
     files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
     expected_output = []
     for j in range(self._num_files):
         expected_output.extend(
             [self._record(j, i) for i in range(self._num_records)])
     dataset = readers.TFRecordDataset(files)
     self.assertDatasetProduces(dataset, expected_output=expected_output)
    def testDirectFilenameTFRecordReaderPipeline(self):
        dataset = core_readers.TFRecordDataset(self._filenames)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)
Example #8
    def testZip(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
        self._verifySimpleShardingOutput(dataset, record_fn)
Example #9
 def testReadWithBuffer(self):
     one_mebibyte = 2**20
     dataset = readers.TFRecordDataset(self.test_filenames,
                                       buffer_size=one_mebibyte)
     expected_output = []
     for j in range(self._num_files):
         expected_output.extend(
             [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output)
Example #10
 def testReadTenEpochsFromDatasetOfFilesInParallel(self):
   files = dataset_ops.Dataset.from_tensor_slices(
       self.test_filenames).repeat(10)
   expected_output = []
   for j in range(self._num_files):
     expected_output.extend(
         [self._record(j, i) for i in range(self._num_records)])
   dataset = readers.TFRecordDataset(files, num_parallel_reads=4)
   self.assertDatasetProduces(
       dataset, expected_output=expected_output * 10, assert_items_equal=True)
  def setUp(self):
    super(TFRecordWriterTest, self).setUp()
    self._num_records = 7
    self.filename = array_ops.placeholder(dtypes.string, shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])

    input_dataset = readers.TFRecordDataset([self.filename],
                                            self.compression_type)
    self.writer = writers.TFRecordWriter(
        self._outputFilename(), self.compression_type).write(input_dataset)
 def testReadWithBuffer(self):
   one_mebibyte = 2**20
   d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
   iterator = d.make_one_shot_iterator()
   with self.test_session() as sess:
     for j in range(self._num_files):
       for i in range(self._num_records):
         self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next()))
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(iterator.get_next())
 def testReadFromDatasetOfFiles(self):
   files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
   d = readers.TFRecordDataset(files)
   iterator = d.make_one_shot_iterator()
   next_element = iterator.get_next()
   with self.test_session() as sess:
     for j in range(self._num_files):
       for i in range(self._num_records):
         self.assertAllEqual(self._record(j, i), sess.run(next_element))
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(next_element)
Example #14
    def _dataset_factory(self,
                         filenames,
                         compression_type="",
                         num_epochs=1,
                         batch_size=None):

        repeat_dataset = readers.TFRecordDataset(
            filenames, compression_type).repeat(num_epochs)
        if batch_size:
            return repeat_dataset.batch(batch_size)
        return repeat_dataset
Example #15
    def __init__(self, filenames, compression_type=None, buffer_size=None):
        """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
    """
        dataset = readers.TFRecordDataset(filenames, compression_type,
                                          buffer_size)
        super(TFRecordDataset, self).__init__(dataset)
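As a complement to the docstring above, here is a short round-trip sketch of what `compression_type` and `buffer_size` correspond to in practice. It is written against the public `tf.io`/`tf.data` API rather than the internal `readers` module, and the path is hypothetical.

import tensorflow as tf

path = "/tmp/example.tfrecord.gz"  # hypothetical path

# Write a few records with GZIP compression.
with tf.io.TFRecordWriter(path, options="GZIP") as writer:
    for payload in [b"a", b"aa", b"aaa"]:
        writer.write(payload)

# Read them back: compression_type must match what was written, and
# buffer_size is the read buffer in bytes (0 means no buffering).
dataset = tf.data.TFRecordDataset(
    [path], compression_type="GZIP", buffer_size=1024 * 1024)
print([r.numpy() for r in dataset])  # [b'a', b'aa', b'aaa']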
Example #16
    def testAsFunctionFromReader(self):
        with ops.device("CPU"):
            file_path = os.path.join(
                self.get_temp_dir(),
                "{}.tfrecord.gz".format("tf_record_asset"))
            with tf_record.TFRecordWriter(file_path, "GZIP") as f:
                for v in ["a", "aa", "aaa"]:
                    f.write(str(v))
            original_dataset = readers.TFRecordDataset([file_path],
                                                       compression_type="GZIP")
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = dataset_ops._VariantDataset(
                variant, original_dataset.element_spec)
            self.assertDatasetProduces(revived_dataset, ["a", "aa", "aaa"])
Example #17
 def testReadWithEquivalentDataset(self):
     features = {
         "file": parsing_ops.FixedLenFeature([], dtypes.int64),
         "record": parsing_ops.FixedLenFeature([], dtypes.int64),
     }
     dataset = (core_readers.TFRecordDataset(
         self._filenames).map(lambda x: parsing_ops.parse_single_example(
             x, features)).repeat(10).batch(2))
     next_element = self.getNext(dataset)
     for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
             range(self._num_files), 2, 10):
         actual_batch = self.evaluate(next_element())
         self.assertAllEqual(file_batch, actual_batch["file"])
         self.assertAllEqual(record_batch, actual_batch["record"])
     with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(next_element())
Example #18
 def testReadTenEpochsFromDatasetOfFilesInParallel(self):
     files = dataset_ops.Dataset.from_tensor_slices(
         self.test_filenames).repeat(10)
     d = readers.TFRecordDataset(files, num_parallel_reads=4)
     iterator = d.make_one_shot_iterator()
     next_element = iterator.get_next()
     expected = []
     actual = []
     with self.test_session() as sess:
         for _ in range(10):
             for j in range(self._num_files):
                 for i in range(self._num_records):
                     expected.append(self._record(j, i))
                     actual.append(sess.run(next_element))
         with self.assertRaises(errors.OutOfRangeError):
             sess.run(next_element)
         self.assertEqual(sorted(expected), sorted(actual))
Example #19
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
  def setUp(self):
    super(TFRecordDatasetTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7

    self.test_filenames = self._createFiles()

    self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
    self.num_epochs = array_ops.placeholder_with_default(
        constant_op.constant(1, dtypes.int64), shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = core_readers.TFRecordDataset(
        self.filenames, self.compression_type).repeat(self.num_epochs)
    batch_dataset = repeat_dataset.batch(self.batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    self.init_op = iterator.make_initializer(repeat_dataset)
    self.init_batch_op = iterator.make_initializer(batch_dataset)
    self.get_next = iterator.get_next()
    def testReadWithEquivalentDataset(self):
        features = {
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
        }
        dataset = (core_readers.TFRecordDataset(self.test_filenames).map(
            lambda x: parsing_ops.parse_single_example(x, features)).repeat(
                10).batch(2))
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        next_element = iterator.get_next()

        with self.test_session() as sess:
            sess.run(init_op)
            for file_batch, _, _, _, record_batch in self._next_expected_batch(
                    range(self._num_files), 2, 10):
                actual_batch = sess.run(next_element)
                self.assertAllEqual(file_batch, actual_batch["file"])
                self.assertAllEqual(record_batch, actual_batch["record"])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
Example #22
    def testConcat(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset1.concatenate(dataset2)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._record(r, f),
                                        sess.run(next_element))
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._text_line(r, f),
                                        sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
    def testShard(self):
        filename = self._createFile()
        dataset = readers.TFRecordDataset([filename])

        def reduce_func(key, dataset):
            shard_filename = string_ops.string_join(
                [filename, string_ops.as_string(key)])
            writer = writers.TFRecordWriter(shard_filename)
            writer.write(dataset.map(lambda _, x: x))
            return dataset_ops.Dataset.from_tensors(shard_filename)

        dataset = dataset.enumerate()
        dataset = dataset.apply(
            grouping.group_by_window(lambda i, _: i % 2, reduce_func,
                                     dtypes.int64.max))

        get_next = self.getNext(dataset)
        for i in range(2):
            shard_filename = (filename + str(i)).encode()
            self.assertEqual(self.evaluate(get_next()), shard_filename)
            for j, r in enumerate(
                    tf_record.tf_record_iterator(shard_filename)):
                self.assertAllEqual(self._record(i + 2 * j), r)
    def testTFRecordDatasetIgnoreError(self):
        filenames = []
        for i in range(5):
            fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
            filenames.append(fn)
            writer = python_io.TFRecordWriter(fn)
            for _ in range(10):
                writer.write(b"record")
            writer.close()
            # Append corrupted data
            with open(fn, "a") as f:
                f.write("corrupted data")

        dataset = readers.TFRecordDataset(filenames).apply(
            error_ops.ignore_errors())
        get_next = self.getNext(dataset)

        # All of the files are present.
        for _ in filenames:
            for _ in range(10):
                self.assertEqual(b"record", self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
Example #25
def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
    buffer_size = 8 * 1024 * 1024  # 8 MiB per file
    dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
    return dataset
Example #26
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = dataset.make_initializable_iterator()
Example #27
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
Example #28
 def testTFRecordInputs(self):
     dataset = readers.TFRecordDataset("")
     self.checkNumInputs(dataset, 1)
Example #29
def _TFRecordDataset(filename):
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
  return dataset
Example #30
def make_tf_record_dataset(file_pattern,
                           batch_size,
                           parser_fn=None,
                           num_epochs=None,
                           shuffle=True,
                           shuffle_buffer_size=None,
                           shuffle_seed=None,
                           prefetch_buffer_size=optimization.AUTOTUNE,
                           num_parallel_reads=None,
                           num_parallel_parser_calls=None,
                           drop_final_batch=False):
  """Reads and optionally parses TFRecord files into a dataset.

  Provides common functionality such as batching, optional parsing, shuffling,
  and performant defaults.

  Args:
    file_pattern: List of files or patterns of TFRecord file paths.
      See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    parser_fn: (Optional.) A function accepting string input to parse
      and process the record contents. This function must map records
      to components of a fixed shape, so they may be batched. By
      default, uses the record contents unmodified.
    num_epochs: (Optional.) An int specifying the number of times this
      dataset is repeated.  If None (the default), cycles through the
      dataset forever.
    shuffle: (Optional.) A bool that indicates whether the input
      should be shuffled. Defaults to `True`.
    shuffle_buffer_size: (Optional.) Buffer size to use for
      shuffling. A large buffer size ensures better shuffling, but
      increases memory usage and startup time.
    shuffle_seed: (Optional.) Randomization seed to use for shuffling.
    prefetch_buffer_size: (Optional.) An int specifying the number of
      feature batches to prefetch for performance improvement.
      Defaults to auto-tune. Set to 0 to disable prefetching.
    num_parallel_reads: (Optional.) Number of threads used to read
      records from files. By default or if set to a value >1, the
      results will be interleaved.
    num_parallel_parser_calls: (Optional.) Number of records to parse
      in parallel. Defaults to an automatic selection.
    drop_final_batch: (Optional.) Whether the last batch should be
      dropped in case its size is smaller than `batch_size`; the
      default behavior is not to drop the smaller batch.

  Returns:
    A dataset, where each element matches the output of `parser_fn`
    except it will have an additional leading `batch_size` dimension,
    or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
    unspecified.
  """
  files = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  if num_parallel_reads is None:
    # Note: We considered auto-tuning this value, but there is a concern
    # that this affects the mixing of records from different files, which
    # could affect training convergence/accuracy, so we are defaulting to
    # a constant for now.
    num_parallel_reads = 24
  dataset = core_readers.TFRecordDataset(
      files, num_parallel_reads=num_parallel_reads)

  if shuffle_buffer_size is None:
    # TODO(josh11b): Auto-tune this value when not specified
    shuffle_buffer_size = 10000
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  drop_final_batch = drop_final_batch or num_epochs is None

  if parser_fn is None:
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
  else:
    # TODO(josh11b): if num_parallel_parser_calls is None, use some function
    # of num cores instead of map_and_batch's default behavior of one batch.
    dataset = dataset.apply(batching.map_and_batch(
        parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
        drop_remainder=drop_final_batch))

  if prefetch_buffer_size == 0:
    return dataset
  else:
    return dataset.prefetch(buffer_size=prefetch_buffer_size)
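A hedged usage sketch of the reader defined above. The file pattern and feature spec are hypothetical, and the call assumes `make_tf_record_dataset` is importable from the module where it is defined; it is not necessarily exposed under the public `tf.data` namespace.

import tensorflow as tf

feature_spec = {  # hypothetical schema
    "label": tf.io.FixedLenFeature([], tf.int64),
    "text": tf.io.FixedLenFeature([], tf.string),
}

def parse_fn(serialized_example):
    # With parser_fn set, map_and_batch applies this per record before
    # batching, so parse_single_example is the right call here.
    return tf.io.parse_single_example(serialized_example, feature_spec)

dataset = make_tf_record_dataset(
    file_pattern="/tmp/train-*.tfrecord",  # hypothetical pattern
    batch_size=32,
    parser_fn=parse_fn,
    num_epochs=1,
    shuffle=True,
    shuffle_buffer_size=10000)

for batch in dataset.take(1):
    print(batch["label"].shape)  # (32,) assuming at least 32 records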