Ejemplo n.º 1
0
 def testBasic(self, compression):
     """Saves a 42-element range dataset and loads it back with `compression`.

     The loaded dataset must produce exactly the values 0..41 in order.
     """
     source = dataset_ops.Dataset.range(42)
     io.save(source, self._test_dir, compression=compression)
     restored = io.load(
         self._test_dir, source.element_spec, compression=compression)
     self.assertDatasetProduces(restored, range(42))
Ejemplo n.º 2
0
 def testCustomShardFunction(self):
   """Saves with a two-way `shard_func` and checks the loaded element order.

   `x // 21` sends 0..20 to shard 0 and 21..41 to shard 1. The test expects
   the loaded output to alternate between the two shards: 0, 21, 1, 22, ...
   """
   source = dataset_ops.Dataset.range(42)
   io.save(source, self._test_dir, shard_func=lambda x: x // 21)
   restored = io.load(self._test_dir, source.element_spec)
   expected = [value for i in range(21) for value in (i, i + 21)]
   self.assertDatasetProduces(restored, expected)
Ejemplo n.º 3
0
 def testCustomReaderFunction(self):
     """Loads a 7-shard save through a custom `reader_func`.

     `x % 7` assigns each element to one of 7 shards. The reader function
     flat-maps the per-shard datasets one after another, so the expected
     output is shard 0's elements (0, 7, ..., 35), then shard 1's, and so on.
     """
     source = dataset_ops.Dataset.range(42)
     io.save(source, self._test_dir, shard_func=lambda x: x % 7)
     restored = io.load(
         self._test_dir,
         source.element_spec,
         reader_func=lambda shards: shards.flat_map(lambda shard: shard))
     expected = [
         value for shard_id in range(7) for value in range(shard_id, 42, 7)
     ]
     self.assertDatasetProduces(restored, expected)
Ejemplo n.º 4
0
 def testElementSpecOptional(self):
   """Loads a nested (scalar, dict, tuple) dataset without an element spec.

   The dataset read back via `io.load(path)` must equal the one that was
   saved.
   """
   components = (
       dataset_ops.Dataset.range(42),
       dataset_ops.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]}),
       dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])),
   )
   dataset = dataset_ops.Dataset.zip(components)
   io.save(dataset, self._test_dir)
   self.assertDatasetsEqual(dataset, io.load(self._test_dir))
Ejemplo n.º 5
0
  def testSaveInsideFunction(self, compression):
    """Calls `io.save` from inside a `def_function.function`, then reloads.

    The dataset loaded afterwards with the same `compression` must produce
    0..41.
    """
    source = dataset_ops.Dataset.range(42)

    @def_function.function
    def _traced_save():
      io.save(source, self._test_dir, compression=compression)

    _traced_save()
    restored = io.load(
        self._test_dir, source.element_spec, compression=compression)
    self.assertDatasetProduces(restored, range(42))
Ejemplo n.º 6
0
 def testRepeatAndPrefetch(self):
   """Regression test for github.com/tensorflow/tensorflow/issues/49165.

   Builds a shuffle -> batch -> repeat -> prefetch pipeline on top of a
   loaded dataset and pulls 30 batches from it.
   """
   source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
   io.save(source, self._test_dir)
   pipeline = (
       io.load(self._test_dir)
       .shuffle(buffer_size=16)
       .batch(16)
       .repeat()
       .prefetch(1))
   fetch_batch = self.getNext(pipeline)
   for _ in range(30):
     self.evaluate(fetch_batch())
Ejemplo n.º 7
0
    def testLoadCheckpointIteratorMultipleBreaks(self):
        """Checkpoints and restores an iterator over a loaded dataset each step.

        Before each element is read, the iterator state is saved. After the
        read, that state is restored, so the same element must come out
        again. Once all elements have been read, the next call raises
        OutOfRangeError.
        """
        dataset = dataset_ops.Dataset.range(3)
        io.save(dataset, self._save_dir)
        iterator = iter(io.load(self._save_dir))
        get_next = iterator.get_next
        checkpoint = trackable_utils.Checkpoint(iterator=iterator)

        for expected in range(len(dataset)):
            save_path = checkpoint.save(self._checkpoint_prefix)
            self.assertAllEqual(expected, get_next())
            # Rewind to the state captured just before `expected` was read.
            checkpoint.restore(save_path).run_restore_ops()
            self.assertAllEqual(expected, get_next())

        with self.assertRaises(errors.OutOfRangeError):
            get_next()
Ejemplo n.º 8
0
    def testLoadCheckpointFullyUsedIterator(self):
        """Restores a checkpoint taken after every element has been read.

        After all three elements are consumed, the iterator is checkpointed
        and restored. The next call must then raise OutOfRangeError.
        """
        dataset = dataset_ops.Dataset.range(3)
        io.save(dataset, self._save_dir)
        iterator = iter(io.load(self._save_dir))
        get_next = iterator.get_next
        checkpoint = trackable_utils.Checkpoint(iterator=iterator)

        for expected in range(3):
            self.assertAllEqual(expected, get_next())
        save_path = checkpoint.save(self._checkpoint_prefix)
        checkpoint.restore(save_path).run_restore_ops()

        with self.assertRaises(errors.OutOfRangeError):
            get_next()
Ejemplo n.º 9
0
 def testCardinality(self):
   """Checks that a loaded 42-element dataset reports a cardinality of 42."""
   source = dataset_ops.Dataset.range(42)
   io.save(source, self._test_dir)
   restored = io.load(self._test_dir, source.element_spec)
   cardinality = self.evaluate(restored.cardinality())
   self.assertEqual(cardinality, 42)
Ejemplo n.º 10
0
 def _build_ds(self):
   """Loads and returns the dataset stored at `self._save_dir`."""
   loaded = io.load(self._save_dir)
   return loaded
Ejemplo n.º 11
0
 def testElementSpecRequired(self):
   """Checks that `io.load` without an element spec raises ValueError."""
   source = dataset_ops.Dataset.range(42)
   io.save(source, self._test_dir)
   with self.assertRaises(ValueError):
     io.load(self._test_dir)