Example #1
0
 def testBasic(self, compression):
     """Round-trips a simple range dataset through save/load."""
     original = dataset_ops.Dataset.range(42)
     io.save(original, self._test_dir, compression=compression)
     # Reload with the same compression codec and the spec of the source.
     restored = io.load(self._test_dir,
                        original.element_spec,
                        compression=compression)
     self.assertDatasetProduces(restored, range(42))
Example #2
0
 def testCustomShardFunction(self):
   """Saving with a custom shard_func round-robins elements across shards."""
   source = dataset_ops.Dataset.range(42)
   # Two shards: elements 0..20 map to shard 0, elements 21..41 to shard 1.
   io.save(source, self._test_dir, shard_func=lambda x: x // 21)
   reloaded = io.load(self._test_dir, source.element_spec)
   # The default reader interleaves the shards: 0, 21, 1, 22, 2, 23, ...
   expected = [value for i in range(21) for value in (i, i + 21)]
   self.assertDatasetProduces(reloaded, expected)
Example #3
0
 def testSaveCheckpointingAPIIncorrectArgs(self):
   """Passing an unrecognized key in checkpoint_args raises TypeError."""
   dataset = dataset_ops.Dataset.range(42)
   bad_args = {
       "directory": self._checkpoint_prefix,
       "incorrect_arg": "incorrect_arg",
   }
   # The unknown "incorrect_arg" key must be rejected by save().
   with self.assertRaises(TypeError):
     io.save(dataset, self._save_dir, checkpoint_args=bad_args)
Example #4
0
 def testSaveCheckpointingAPI(self):
   """Default checkpointing emits two files per element plus one overall file."""
   ds = dataset_ops.Dataset.range(40)
   io.save(ds, self._save_dir,
           checkpoint_args={"directory": self._checkpoint_prefix,
                            "max_to_keep": 50})
   # By default every increment is checkpointed. Each checkpoint writes a
   # data file and an index file, and there is one top-level checkpoint
   # file on top, hence (2 * 40) + 1 = 81 files in total.
   files_written = len(os.listdir(self._checkpoint_prefix))
   self.assertEqual(81, files_written)
Example #5
0
 def testCustomReaderFunction(self):
     """A flat_map reader_func drains shards one after another, in order."""
     source = dataset_ops.Dataset.range(42)
     # Seven shards, keyed by x % 7.
     io.save(source, self._test_dir, shard_func=lambda x: x % 7)
     reloaded = io.load(self._test_dir,
                        source.element_spec,
                        reader_func=lambda x: x.flat_map(lambda y: y))
     # flat_map reads shard 0 completely, then shard 1, and so on, so the
     # output is grouped by residue class modulo 7.
     expected = []
     for shard in range(7):
         expected += list(range(shard, 42, 7))
     self.assertDatasetProduces(reloaded, expected)
Example #6
0
 def testElementSpecOptional(self):
   """load() recovers the element spec when one is not passed explicitly."""
   # Zip together range, dict, and tuple structures to exercise a nested spec.
   components = (
       dataset_ops.Dataset.range(42),
       dataset_ops.Dataset.from_tensor_slices({"a": [1, 2],
                                               "b": [3, 4]}),
       dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])),
   )
   dataset = dataset_ops.Dataset.zip(components)
   io.save(dataset, self._test_dir)
   # No element_spec argument: it must come from the saved metadata.
   dataset_loaded = io.load(self._test_dir)
   self.assertDatasetsEqual(dataset, dataset_loaded)
Example #7
0
 def testRepeatAndPrefetch(self):
   """This test reproduces github.com/tensorflow/tensorflow/issues/49165."""
   source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
   io.save(source, self._test_dir)
   # Chain the exact transformations that triggered the reported failure.
   pipeline = (
       io.load(self._test_dir)
       .shuffle(buffer_size=16)
       .batch(16)
       .repeat()
       .prefetch(1))
   next_element = self.getNext(pipeline)
   # Pull well past one epoch to exercise repeat() over the loaded data.
   for _ in range(30):
     self.evaluate(next_element())
Example #8
0
 def testSaveCheckpointingAPICustomCheckpointInterval(self):
   """Checkpointing every 5 steps yields 8 checkpoints for 40 elements."""
   dataset = dataset_ops.Dataset.range(40)
   step_counter = variables.Variable(0, trainable=False)
   io.save(
       dataset,
       self._save_dir,
       checkpoint_args={
           "checkpoint_interval": 5,
           "step_counter": step_counter,
           "directory": self._checkpoint_prefix,
           "max_to_keep": 10,
       })
   # 40 elements / interval 5 = 8 checkpoints; each writes a data file and
   # an index file, plus one overall checkpoint file: (2 * 8) + 1 = 17.
   num_files = len(os.listdir(self._checkpoint_prefix))
   self.assertEqual(17, num_files)
Example #9
0
    def testLoadCheckpointIteratorMultipleBreaks(self):
        """Repeatedly checkpoints and restores a loaded iterator mid-stream."""
        dataset = dataset_ops.Dataset.range(3)
        io.save(dataset, self._save_dir)
        restored = io.load(self._save_dir)
        it = iter(restored)
        get_next = it.get_next

        ckpt = trackable_utils.Checkpoint(iterator=it)
        for expected in range(len(dataset)):
            # Snapshot the iterator state, consume one element, rewind to
            # the snapshot, and consume the very same element again.
            path = ckpt.save(self._checkpoint_prefix)
            self.assertAllEqual(expected, get_next())
            ckpt.restore(path).run_restore_ops()
            self.assertAllEqual(expected, get_next())
        with self.assertRaises(errors.OutOfRangeError):
            get_next()
Example #10
0
    def testLoadCheckpointFullyUsedIterator(self):
        """Restoring a checkpoint of an exhausted iterator keeps it exhausted."""
        dataset = dataset_ops.Dataset.range(3)
        io.save(dataset, self._save_dir)
        it = iter(io.load(self._save_dir))
        get_next = it.get_next

        ckpt = trackable_utils.Checkpoint(iterator=it)
        # Drain the iterator completely before taking the checkpoint.
        for expected in (0, 1, 2):
            self.assertAllEqual(expected, get_next())
        path = ckpt.save(self._checkpoint_prefix)
        ckpt.restore(path).run_restore_ops()
        # The restored state is end-of-sequence, so the next pull must fail.
        with self.assertRaises(errors.OutOfRangeError):
            get_next()
Example #11
0
 def testCardinality(self):
   """A loaded dataset reports the cardinality of the dataset that was saved."""
   saved = dataset_ops.Dataset.range(42)
   io.save(saved, self._test_dir)
   reloaded = io.load(self._test_dir, saved.element_spec)
   self.assertEqual(self.evaluate(reloaded.cardinality()), 42)
Example #12
0
 def save_fn():
     # Zero-argument closure: `dataset`, `self._test_dir`, and `compression`
     # are all free variables captured from an enclosing scope not visible
     # here — presumably wrapped (e.g. in a tf.function or thread) by the
     # caller; confirm against the enclosing test.
     io.save(dataset, self._test_dir, compression=compression)
Example #13
0
 def test(self, verify_fn):
   """Saves a range dataset, then checks serialization via verify_fn."""
   ds = dataset_ops.Dataset.range(42)
   io.save(ds, self._save_dir)
   # verify_fn rebuilds the dataset through self._build_ds and confirms
   # it yields exactly 42 outputs.
   verify_fn(self, self._build_ds, num_outputs=42)
Example #14
0
 def testElementSpecRequired(self):
   """Omitting the element spec on load raises ValueError."""
   saved = dataset_ops.Dataset.range(42)
   io.save(saved, self._test_dir)
   # load() is called without an element_spec argument and must reject it.
   with self.assertRaises(ValueError):
     _ = io.load(self._test_dir)
Example #15
0
 def test(self, verify_fn):
   """Serialization check for a saved range dataset; currently skipped."""
   # Disabled pending snapshot-reader serialization support; see message.
   self.skipTest(
       "TODO(jsimsa): Re-enable once snapshot reader supports serialization.")
   saved = dataset_ops.Dataset.range(42)
   io.save(saved, self._save_dir)
   verify_fn(self, self._build_ds, num_outputs=42)