Example #1
    def testFixedLengthReader(self):
        dataset = readers.FixedLengthRecordDataset(
            self._createFixedLengthRecordFiles(), self._record_bytes)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example #2
 def _build_reader_dataset_graph():
   filenames = ["test"]  # Does not exist but we don't care in this test.
   iterator = readers.FixedLengthRecordDataset(
       filenames, 1, 0, 0).make_initializable_iterator()
   init_op = iterator.initializer
   get_next_op = iterator.get_next()
   save_op = _save_op(iterator._iterator_resource)
   restore_op = _restore_op(iterator._iterator_resource)
   return init_op, get_next_op, save_op, restore_op
Example #3
    def testFixedLengthReaderWithFlatMap(self):
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createFixedLengthRecordFiles())
        dataset = dataset.flat_map(
            lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example #4
 def dataset_fn(filenames, num_epochs, batch_size=None,
                compression_type=None):
     repeat_dataset = readers.FixedLengthRecordDataset(
         filenames,
         self._record_bytes,
         self._header_bytes,
         self._footer_bytes,
         compression_type=compression_type).repeat(num_epochs)
     if batch_size:
         return repeat_dataset.batch(batch_size)
     return repeat_dataset
Example #5
 def _build_iterator_graph(self, num_epochs):
   filenames = self._createFiles()
   dataset = (readers.FixedLengthRecordDataset(
       filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
              .repeat(num_epochs))
   iterator = dataset.make_initializable_iterator()
   init_op = iterator.initializer
   get_next_op = iterator.get_next()
   save_op = self._save_op(iterator._iterator_resource)
   restore_op = self._restore_op(iterator._iterator_resource)
   return init_op, get_next_op, save_op, restore_op
Example #6
 def testFixedLengthRecordDatasetBuffering(self):
     test_filenames = self._createFiles()
     dataset = readers.FixedLengthRecordDataset(test_filenames,
                                                self._record_bytes,
                                                self._header_bytes,
                                                self._footer_bytes,
                                                buffer_size=10)
     expected_output = []
     for j in range(self._num_files):
         expected_output.extend(
             [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset, expected_output=expected_output)
Example #7
 def _build_reader_dataset_graph():
     filenames = ["test"
                  ]  # Does not exist but we don't care in this test.
     path = _iterator_checkpoint_prefix()
     iterator = readers.FixedLengthRecordDataset(
         filenames, 1, 0, 0).make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()
     save_op = gen_dataset_ops.save_iterator(
         iterator._iterator_resource, path)
     restore_op = gen_dataset_ops.restore_iterator(
         iterator._iterator_resource, path)
     return init_op, get_next_op, save_op, restore_op
Example #8
 def testFixedLengthRecordDatasetParallelRead(self):
     test_filenames = self._createFiles()
     dataset = readers.FixedLengthRecordDataset(test_filenames,
                                                self._record_bytes,
                                                self._header_bytes,
                                                self._footer_bytes,
                                                buffer_size=10,
                                                num_parallel_reads=4)
     expected_output = []
     for j in range(self._num_files):
         expected_output.extend(
             [self._record(j, i) for i in range(self._num_records)])
     self.assertDatasetProduces(dataset,
                                expected_output=expected_output,
                                assert_items_equal=True)
Example #9
 def testFixedLengthRecordDatasetWrongSize(self):
     test_filenames = self._createFiles()
     dataset = readers.FixedLengthRecordDataset(
         test_filenames,
         self._record_bytes + 1,  # Incorrect record length.
         self._header_bytes,
         self._footer_bytes,
         buffer_size=10)
      self.assertDatasetProduces(
          dataset,
          expected_error=(
              errors.InvalidArgumentError,
              r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
              r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
              r"which is not an exact multiple of the record length \(4 bytes\)."
          ))
Example #10
  def testFixedLengthRecordDatasetBuffering(self):
    test_filenames = self._createFiles()
    dataset = readers.FixedLengthRecordDataset(
        test_filenames,
        self._record_bytes,
        self._header_bytes,
        self._footer_bytes,
        buffer_size=10)
    iterator = dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())
Example #11
  def testFixedLengthRecordDatasetWrongSize(self):
    test_filenames = self._createFiles()
    dataset = readers.FixedLengthRecordDataset(
        test_filenames,
        self._record_bytes + 1,  # Incorrect record length.
        self._header_bytes,
        self._footer_bytes,
        buffer_size=10)
    iterator = dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
          r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
          r"which is not an exact multiple of the record length \(4 bytes\)."):
        sess.run(iterator.get_next())
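
The expected error in the wrong-size tests above is plain arithmetic: the message implies a 5-byte header, a 21-byte body (seven 3-byte records), and a 2-byte footer, and reading with a record length of self._record_bytes + 1 = 4 fails because 21 is not a multiple of 4. A minimal sketch of that check, with the sizes inferred from the error message rather than taken from the test fixtures:

header_bytes, footer_bytes = 5, 2
body_bytes = 21                    # from the error message above
record_bytes = 3                   # correct length: 21 / 3 = 7 records
assert body_bytes % record_bytes == 0
assert body_bytes % (record_bytes + 1) != 0  # what the wrong-size tests provoke
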
Example #12
    def __init__(self,
                 filenames,
                 record_bytes,
                 header_bytes=None,
                 footer_bytes=None,
                 buffer_size=None):
        """Creates a `FixedLengthRecordDataset`.

        Args:
          filenames: A `tf.string` tensor containing one or more filenames.
          record_bytes: A `tf.int64` scalar representing the number of bytes
            in each record.
          header_bytes: (Optional.) A `tf.int64` scalar representing the
            number of bytes to skip at the start of a file.
          footer_bytes: (Optional.) A `tf.int64` scalar representing the
            number of bytes to ignore at the end of a file.
          buffer_size: (Optional.) A `tf.int64` scalar representing the number
            of bytes to buffer when reading.
        """
        dataset = readers.FixedLengthRecordDataset(filenames, record_bytes,
                                                   header_bytes, footer_bytes,
                                                   buffer_size)
        super(FixedLengthRecordDataset, self).__init__(dataset)
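
Since this wrapper simply forwards to tf.data's FixedLengthRecordDataset, a short usage sketch may help; the file name and byte layout below are illustrative assumptions, not part of the example:

import tensorflow as tf

# Hypothetical file written with a 5-byte header, 4-byte records, and a
# 2-byte footer; each dataset element is a scalar tf.string of 4 bytes.
dataset = tf.data.FixedLengthRecordDataset(
    filenames=["records.bin"],
    record_bytes=4,
    header_bytes=5,
    footer_bytes=2,
    buffer_size=1024)
for record in dataset:
    print(record.numpy())
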
Example #13
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = dataset.make_initializable_iterator()
Example #14
class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
       1),
      ("FromSparseTensorSlices",
       dataset_ops.Dataset.from_sparse_tensor_slices(
           sparse_tensor.SparseTensor(
               indices=np.array([[0, 0], [1, 0], [2, 0]]),
               values=np.array([0, 0, 0]),
               dense_shape=np.array([3, 1])))),
      ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
      ("Range", dataset_ops.Dataset.range(10)),
      ("TextLine", readers.TextLineDataset("")),
      ("TFRecord", readers.TFRecordDataset(""), 1),
  )
  def testDatasetSourceInputs(self, dataset, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset._inputs()))

  @parameterized.named_parameters(
      ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
       dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       dataset_ops.Dataset.range(0)),
      ("ParallelInterleave",
       make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
       dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0),
        (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))
Example #15
 def testFixedLengthRecordInputs(self):
     dataset = readers.FixedLengthRecordDataset("", 42)
     self.checkNumInputs(dataset, 0)
Example #16
  def testFixedLengthRecordDataset(self):
    test_filenames = self._createFiles()
    filenames = array_ops.placeholder(dtypes.string, shape=[None])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = (readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                      .repeat(num_epochs))
    batch_dataset = repeat_dataset.batch(batch_size)

    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Basic test: read from file 0.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[0]],
                              num_epochs: 1})
      for i in range(self._num_records):
        self.assertEqual(self._record(0, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from file 1.
      sess.run(
          init_op, feed_dict={filenames: [test_filenames[1]],
                              num_epochs: 1})
      for i in range(self._num_records):
        self.assertEqual(self._record(1, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read from both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
      for j in range(self._num_files):
        for i in range(self._num_records):
          self.assertEqual(self._record(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test repeated iteration through both files.
      sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
      for _ in range(10):
        for j in range(self._num_files):
          for i in range(self._num_records):
            self.assertEqual(self._record(j, i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test batched and repeated iteration through both files.
      sess.run(
          init_batch_op,
          feed_dict={
              filenames: test_filenames,
              num_epochs: 10,
              batch_size: self._num_records
          })
      for _ in range(10):
        for j in range(self._num_files):
          self.assertAllEqual(
              [self._record(j, i) for i in range(self._num_records)],
              sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
Example #17
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
Example #18
 def FixedLengthFile(filename):
     return readers.FixedLengthRecordDataset(filename, record_bytes)
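
A helper like this is typically composed with flat_map or interleave, as Example 3 does. A minimal sketch, assuming record_bytes is defined and the listed files exist:

filenames = dataset_ops.Dataset.from_tensor_slices(["a.bin", "b.bin"])
records = filenames.flat_map(FixedLengthFile)  # one fixed-length record per element
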
Example #19
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        dataset = dataset_ops.Dataset.range(10)
        with self.cached_session() as sess:
            graph = graph_pb2.GraphDef().FromString(
                sess.run(dataset._as_serialized_graph()))
            self.assertTrue(
                any(node.op == "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetOpsTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

    def testDatasetComplexSourceInputs(self):
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEqual(0, len(dataset_fn._inputs()))

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testCollectInputs(self):
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue[0]
            queue = queue[1:]
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    def testOptionsDefault(self):
        ds = dataset_ops.Dataset.range(0)
        self.assertEqual(dataset_ops.Options(), ds.options())

    def testOptionsOnce(self):
        options = dataset_ops.Options()
        ds = dataset_ops.Dataset.range(0).with_options(options).cache()
        self.assertEqual(options, ds.options())

    def testOptionsTwiceSame(self):
        options = dataset_ops.Options()
        options.experimental_autotune = True
        ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
            options)
        self.assertEqual(options, ds.options())

    def testOptionsTwiceDifferent(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_filter_fusion = False
        ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
            options2)
        self.assertTrue(ds.options().experimental_autotune)
        # Explicitly check that flag is False since assertFalse allows None
        self.assertIs(ds.options().experimental_filter_fusion, False)

    def testOptionsTwiceDifferentError(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_autotune = False
        with self.assertRaisesRegexp(ValueError,
                                     "Cannot merge incompatible values"):
            dataset_ops.Dataset.range(0).with_options(options1).with_options(
                options2)

    def testOptionsMergeOptionsFromMultipleInputs(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_filter_fusion = True
        ds = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.range(0).with_options(options1),
             dataset_ops.Dataset.range(0).with_options(options2)))
        self.assertTrue(ds.options().experimental_autotune)
        self.assertTrue(ds.options().experimental_filter_fusion)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)
Example #20
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        dataset = dataset_ops.Dataset.range(10)
        with self.cached_session() as sess:
            graph = graph_pb2.GraphDef().FromString(
                sess.run(dataset._as_serialized_graph()))
            self.assertTrue(
                any(node.op == "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator",
         dataset_ops.Dataset.from_generator(make_gen.__func__(),
                                            dtypes.int32), 1),
        ("FromSparseTensorSlices",
         dataset_ops.Dataset.from_sparse_tensor_slices(
             sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                          [2, 0]]),
                                        values=np.array([0, 0, 0]),
                                        dense_shape=np.array([3, 1])))),
        ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
        ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
        ("Range", dataset_ops.Dataset.range(10)),
        ("TextLine", readers.TextLineDataset("")),
        ("TFRecord", readers.TFRecordDataset(""), 1),
    )
    def testDatasetSourceInputs(self, dataset, num_inputs=0):
        self.assertEqual(num_inputs, len(dataset._inputs()))

    @parameterized.named_parameters(
        ("Apply", make_apply_fn.__func__(
            dataset_ops.Dataset.range(0)), dataset_ops.Dataset.range(0)),
        ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         dataset_ops.Dataset.range(0)),
        ("Interleave", make_interleave_fn.__func__(
            dataset_ops.Dataset.range(0)), dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         dataset_ops.Dataset.range(0)),
        ("ParallelInterleave",
         make_interleave_fn.__func__(dataset_ops.Dataset.range(0),
                                     2), dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip,
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip,
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testCollectInputs(self):
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue[0]
            queue = queue[1:]
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    def testOptionsDefault(self):
        ds = dataset_ops.Dataset.range(0)
        self.assertEqual(dataset_ops.Options(), ds.options())

    def testOptionsOnce(self):
        options = dataset_ops.Options()
        ds = dataset_ops.Dataset.range(0).with_options(options).cache()
        self.assertEqual(options, ds.options())

    def testOptionsTwiceSame(self):
        options = dataset_ops.Options()
        options.experimental_autotune = True
        ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
            options)
        self.assertEqual(options, ds.options())

    def testOptionsTwiceDifferent(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_filter_fusion = False
        ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
            options2)
        self.assertTrue(ds.options().experimental_autotune)
        self.assertFalse(ds.options().experimental_filter_fusion)

    def testOptionsTwiceDifferentError(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_autotune = False
        with self.assertRaisesRegexp(
                ValueError, "Cannot merge incompatible values of option"):
            dataset_ops.Dataset.range(0).with_options(options1).with_options(
                options2)
Example #21
 def _build_iterator_graph(self, num_epochs, compression_type=None):
     filenames = self._createFiles()
     return core_readers.FixedLengthRecordDataset(
         filenames, self._record_bytes, self._header_bytes,
         self._footer_bytes,
         compression_type=compression_type).repeat(num_epochs)