def testBasic(self, compression):
    """Round-trips a range dataset through io.save/io.load with `compression`."""
    source = dataset_ops.Dataset.range(42)
    io.save(source, self._test_dir, compression=compression)
    # Loading must use the same compression setting that was used for saving.
    restored = io.load(
        self._test_dir, source.element_spec, compression=compression)
    self.assertDatasetProduces(restored, range(42))
def testCustomShardFunction(self):
    """Saves with a custom shard_func and checks the order in which loading yields elements.

    `x // 21` maps 0..20 to shard 0 and 21..41 to shard 1. The expected
    output alternates between the two shards: 0, 21, 1, 22, ...
    """
    source = dataset_ops.Dataset.range(42)
    io.save(source, self._test_dir, shard_func=lambda x: x // 21)
    restored = io.load(self._test_dir, source.element_spec)
    interleaved = [value for i in range(21) for value in (i, i + 21)]
    self.assertDatasetProduces(restored, interleaved)
def testCustomReaderFunction(self):
    """Loads with a custom reader_func that reads the shards one after another.

    Elements are sharded by `x % 7`. The reader flat_maps over the shard
    datasets, so each shard's elements appear contiguously:
    0, 7, ..., 35, 1, 8, ...
    """
    source = dataset_ops.Dataset.range(42)
    io.save(source, self._test_dir, shard_func=lambda x: x % 7)
    restored = io.load(
        self._test_dir,
        source.element_spec,
        reader_func=lambda shards: shards.flat_map(lambda shard: shard))
    shard_by_shard = [
        value for shard_id in range(7) for value in range(shard_id, 42, 7)
    ]
    self.assertDatasetProduces(restored, shard_by_shard)
def testElementSpecOptional(self):
    """io.load without an element_spec reproduces a nested (scalar, dict, tuple) dataset."""
    scalars = dataset_ops.Dataset.range(42)
    mapping = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
    pairs = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4]))
    nested = dataset_ops.Dataset.zip((scalars, mapping, pairs))
    io.save(nested, self._test_dir)
    # No element_spec is passed here; the loaded dataset must still match.
    self.assertDatasetsEqual(nested, io.load(self._test_dir))
def testSaveInsideFunction(self, compression):
    """io.save called from inside a def_function.function can be loaded back."""
    source = dataset_ops.Dataset.range(42)

    @def_function.function
    def _save():
        io.save(source, self._test_dir, compression=compression)

    _save()
    loaded = io.load(
        self._test_dir, source.element_spec, compression=compression)
    self.assertDatasetProduces(loaded, range(42))
def testRepeatAndPrefetch(self):
    """Regression test for github.com/tensorflow/tensorflow/issues/49165.

    Shuffles, batches, repeats and prefetches a loaded dataset, then pulls
    30 batches to check that iteration succeeds.
    """
    io.save(
        dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32)),
        self._test_dir)
    pipeline = (
        io.load(self._test_dir)
        .shuffle(buffer_size=16)
        .batch(16)
        .repeat()
        .prefetch(1))
    next_batch = self.getNext(pipeline)
    for _ in range(30):
        self.evaluate(next_batch())
def testLoadCheckpointIteratorMultipleBreaks(self):
    """Checkpoints before each element and checks that restoring replays that element."""
    source = dataset_ops.Dataset.range(3)
    io.save(source, self._save_dir)
    iterator = iter(io.load(self._save_dir))
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    for expected in range(len(source)):
        path = ckpt.save(self._checkpoint_prefix)
        self.assertAllEqual(expected, iterator.get_next())
        # Rewind to the checkpoint taken just before the read above.
        ckpt.restore(path).run_restore_ops()
        self.assertAllEqual(expected, iterator.get_next())
    with self.assertRaises(errors.OutOfRangeError):
        iterator.get_next()
def testLoadCheckpointFullyUsedIterator(self):
    """A checkpoint taken after exhausting the iterator restores to the exhausted state."""
    source = dataset_ops.Dataset.range(3)
    io.save(source, self._save_dir)
    iterator = iter(io.load(self._save_dir))
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    for expected in range(3):
        self.assertAllEqual(expected, iterator.get_next())
    path = ckpt.save(self._checkpoint_prefix)
    ckpt.restore(path).run_restore_ops()
    with self.assertRaises(errors.OutOfRangeError):
        iterator.get_next()
def testCardinality(self):
    """A loaded dataset reports the cardinality of the saved one (42)."""
    source = dataset_ops.Dataset.range(42)
    io.save(source, self._test_dir)
    restored = io.load(self._test_dir, source.element_spec)
    restored_cardinality = self.evaluate(restored.cardinality())
    self.assertEqual(restored_cardinality, 42)
def _build_ds(self):
    """Returns the dataset loaded from `self._save_dir` (no element_spec passed)."""
    loaded = io.load(self._save_dir)
    return loaded
def testElementSpecRequired(self):
    """In this test's context, io.load without an element_spec must raise ValueError."""
    source = dataset_ops.Dataset.range(42)
    io.save(source, self._test_dir)
    with self.assertRaises(ValueError):
        io.load(self._test_dir)