def testBasic(self, compression):
  """Round-trips a simple range dataset through save/load with `compression`."""
  original = dataset_ops.Dataset.range(42)
  io.save(original, self._test_dir, compression=compression)
  restored = io.load(
      self._test_dir, original.element_spec, compression=compression)
  # The reloaded dataset must yield exactly the elements that were saved.
  self.assertDatasetProduces(restored, range(42))
def testCustomShardFunction(self):
  """Saves with a user-supplied shard function and checks element order."""
  original = dataset_ops.Dataset.range(42)
  # Two shards: elements 0..20 go to shard 0, elements 21..41 to shard 1.
  io.save(original, self._test_dir, shard_func=lambda x: x // 21)
  restored = io.load(self._test_dir, original.element_spec)
  # Loading interleaves the shards, so pairs (i, i + 21) come out together.
  expected = [value for i in range(21) for value in (i, i + 21)]
  self.assertDatasetProduces(restored, expected)
def testSaveCheckpointingAPIIncorrectArgs(self):
  """Passing an unknown key in `checkpoint_args` must raise TypeError."""
  dataset_to_save = dataset_ops.Dataset.range(42)
  bad_args = {
      "directory": self._checkpoint_prefix,
      "incorrect_arg": "incorrect_arg",
  }
  with self.assertRaises(TypeError):
    io.save(dataset_to_save, self._save_dir, checkpoint_args=bad_args)
def testSaveCheckpointingAPI(self):
  """Checks how many files default per-element checkpointing produces."""
  dataset_to_save = dataset_ops.Dataset.range(40)
  io.save(
      dataset_to_save,
      self._save_dir,
      checkpoint_args={
          "directory": self._checkpoint_prefix,
          "max_to_keep": 50,
      })
  # By default, we checkpoint every increment. Each checkpoint writes a
  # file containing the data and a file containing the index. There is
  # also an overall checkpoint file. Thus, we expect (2 * 40) + 1 files.
  file_count = len(os.listdir(self._checkpoint_prefix))
  self.assertEqual(81, file_count)
def testCustomReaderFunction(self):
  """Loads with a flat_map reader so shards are consumed one at a time."""
  original = dataset_ops.Dataset.range(42)
  # Seven shards keyed by residue mod 7.
  io.save(original, self._test_dir, shard_func=lambda x: x % 7)
  restored = io.load(
      self._test_dir,
      original.element_spec,
      reader_func=lambda datasets: datasets.flat_map(lambda ds: ds))
  # flat_map drains each shard fully before the next, so all elements with
  # residue 0 come first, then residue 1, etc.
  expected = [elem for residue in range(7) for elem in range(residue, 42, 7)]
  self.assertDatasetProduces(restored, expected)
def testElementSpecOptional(self):
  """`load` should infer the element spec when none is provided."""
  range_ds = dataset_ops.Dataset.range(42)
  dict_ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
  tuple_ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4]))
  # Zip a structurally rich dataset to exercise nested element specs.
  zipped = dataset_ops.Dataset.zip((range_ds, dict_ds, tuple_ds))
  io.save(zipped, self._test_dir)
  # Note: no element_spec argument — it must be recovered from disk.
  reloaded = io.load(self._test_dir)
  self.assertDatasetsEqual(zipped, reloaded)
def testRepeatAndPrefetch(self):
  """This test reproduces github.com/tensorflow/tensorflow/issues/49165."""
  source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
  io.save(source, self._test_dir)
  pipeline = (
      io.load(self._test_dir)
      .shuffle(buffer_size=16)
      .batch(16)
      .repeat()
      .prefetch(1))
  next_element = self.getNext(pipeline)
  # Pull more batches than one epoch provides to exercise repeat + prefetch.
  for _ in range(30):
    self.evaluate(next_element())
def testSaveCheckpointingAPICustomCheckpointInterval(self):
  """Checks the file count when checkpointing every 5 elements of 40."""
  dataset_to_save = dataset_ops.Dataset.range(40)
  step_counter = variables.Variable(0, trainable=False)
  io.save(
      dataset_to_save,
      self._save_dir,
      checkpoint_args={
          "checkpoint_interval": 5,
          "step_counter": step_counter,
          "directory": self._checkpoint_prefix,
          "max_to_keep": 10,
      })
  # 40 elements / interval of 5 = 8 checkpoints; each writes a data file
  # plus an index file, and there is one overall checkpoint file.
  # We expect (2 * 8) + 1 files.
  file_count = len(os.listdir(self._checkpoint_prefix))
  self.assertEqual(17, file_count)
def testLoadCheckpointIteratorMultipleBreaks(self):
  """Saves/restores the iterator at every position and replays one element."""
  io.save(dataset_ops.Dataset.range(3), self._save_dir)
  reloaded = io.load(self._save_dir)
  iterator = iter(reloaded)
  get_next = iterator.get_next
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  for position in range(len(reloaded)):
    # Checkpoint BEFORE consuming, so restoring replays the same element.
    path = ckpt.save(self._checkpoint_prefix)
    self.assertAllEqual(position, get_next())
    ckpt.restore(path).run_restore_ops()
    self.assertAllEqual(position, get_next())
  with self.assertRaises(errors.OutOfRangeError):
    get_next()
def testLoadCheckpointFullyUsedIterator(self):
  """A checkpoint of an exhausted iterator restores to the exhausted state."""
  io.save(dataset_ops.Dataset.range(3), self._save_dir)
  iterator = iter(io.load(self._save_dir))
  get_next = iterator.get_next
  ckpt = trackable_utils.Checkpoint(iterator=iterator)
  # Drain the iterator completely.
  for expected in (0, 1, 2):
    self.assertAllEqual(expected, get_next())
  # Checkpointing and restoring at end-of-sequence must keep it exhausted.
  path = ckpt.save(self._checkpoint_prefix)
  ckpt.restore(path).run_restore_ops()
  with self.assertRaises(errors.OutOfRangeError):
    get_next()
def testCardinality(self):
  """The loaded dataset reports the same cardinality as the saved one."""
  original = dataset_ops.Dataset.range(42)
  io.save(original, self._test_dir)
  restored = io.load(self._test_dir, original.element_spec)
  self.assertEqual(self.evaluate(restored.cardinality()), 42)
def save_fn():
  # Saves the enclosing scope's `dataset` to the test directory.
  # NOTE(review): `dataset`, `self`, and `compression` are closed over from
  # the surrounding function, which is not visible in this chunk.
  io.save(dataset, self._test_dir, compression=compression)
def test(self, verify_fn):
  """Saves a 42-element dataset, then runs the serialization verifier."""
  io.save(dataset_ops.Dataset.range(42), self._save_dir)
  verify_fn(self, self._build_ds, num_outputs=42)
def testElementSpecRequired(self):
  """`load` without an element spec raises ValueError in this mode."""
  io.save(dataset_ops.Dataset.range(42), self._test_dir)
  with self.assertRaises(ValueError):
    # Omitting element_spec is expected to fail here.
    _ = io.load(self._test_dir)
def test(self, verify_fn):
  """Serialization round-trip check; currently disabled pending reader support."""
  self.skipTest(
      "TODO(jsimsa): Re-enable once snapshot reader supports serialization.")
  io.save(dataset_ops.Dataset.range(42), self._save_dir)
  verify_fn(self, self._build_ds, num_outputs=42)