def testForceReadMode(self):
  """Reads a named run in forced "read" mode after renaming it on disk."""
  snapshot_root = self.snapshot_dir
  expected = list(range(10))

  # Materialize a snapshot under a custom name first.
  written = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(
          snapshot_root, mode="write", snapshot_name="my_custom_snapshot"))
  self.assertDatasetProduces(written, expected)

  # Rename the run directory that was just produced.
  old_run = os.path.join(snapshot_root, "custom-my_custom_snapshot")
  new_run = os.path.join(snapshot_root, "custom-my_custom_snapshot_2")
  shutil.move(old_run, new_run)

  # snapshot.metadata still refers to the old run, which is gone after the
  # move; forcing read mode with the new name must read the renamed run.
  read_back = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(
          snapshot_root, mode="read", snapshot_name="my_custom_snapshot_2"))
  self.assertDatasetProduces(read_back, expected)

  # There should still be exactly one snapshot and one run.
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testReadSnapshotDatasetCustomReaderFn(self):
  """Writes and re-reads a snapshot using an interleaving `reader_func`."""
  self.createTFRecords()
  record_files = self._test_filenames

  expected = []
  for file_index in range(10):
    for record_index in range(100):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  first_pass = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(
          self._snapshot_dir,
          reader_func=(
              lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                  lambda x: x, cycle_length=4, num_parallel_calls=4))))
  self.assertDatasetProduces(first_pass, expected)
  self.assertSnapshotDirectoryContains(
      self._snapshot_dir,
      num_fingerprints=1,
      num_runs_per_fingerprint=1,
      num_snapshot_shards_per_run=multiprocessing.cpu_count())

  # The source files are removed before the second pipeline is built.
  self.removeTFRecords()
  second_pass = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(
          self._snapshot_dir,
          reader_func=(
              lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                  lambda x: x, cycle_length=4, num_parallel_calls=4))))
  self.assertDatasetProducesSet(second_pass, expected)
def testReadShuffledSnapshotWithSeedAfterWrite(self):
  """Checks two seeded shuffled reads of one snapshot yield the same order."""
  self.setUpTFRecord(num_files=10, num_records=50)
  record_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(50):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_root = self.snapshot_dir
  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, shard_size_bytes=10))
  self.assertDatasetProduces(written, expected)

  # Drop the source files so the reads below can only come from the snapshot.
  self.removeTFRecords()

  first_reader = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(
          snapshot_root,
          shard_size_bytes=10,
          shuffle_on_read=True,
          shuffle_seed=123456))
  first_next = self.getNext(first_reader)

  second_reader = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(
          snapshot_root,
          shard_size_bytes=10,
          shuffle_on_read=True,
          shuffle_seed=123456))
  second_next = self.getNext(second_reader)

  # Both readers share a seed, so every element must match pairwise.
  for _ in range(500):
    from_first = self.evaluate(first_next())
    from_second = self.evaluate(second_next())
    self.assertEqual(from_first, from_second)
def testReadShuffledSnapshotAfterWrite(self):
  """Checks that `shuffle_on_read` reorders elements without losing any."""
  self.setUpTFRecord(num_files=10, num_records=50)
  record_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(50):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_root = self.makeSnapshotDirectory()
  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, shard_size_bytes=10))
  self.assertDatasetProduces(written, expected)

  # Drop the source files so the reads below can only come from the snapshot.
  self.removeTFRecords()

  shuffled = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, shuffle_on_read=True))
  shuffled_next = self.getNext(shuffled)
  first_five = [self.evaluate(shuffled_next()) for _ in range(5)]

  # The leading elements must not come back in the original order.
  self.assertNotEqual(first_five, expected[:5])

  # Every element must still be present, in any order.
  shuffled_again = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, shuffle_on_read=True))
  self.assertDatasetProduces(
      shuffled_again, expected, assert_items_equal=True)
def testReadSnapshotParallelAfterWrite(self, compression):
  """Writes then re-reads a large snapshot using two reader threads.

  Args:
    compression: Compression setting forwarded to `snapshot.snapshot`.
  """
  self.setUpTFRecord(10, 4000)
  record_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(4000):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_root = self.snapshot_dir
  # Identical options are used for the write pass and the read pass.
  snapshot_options = dict(
      shard_size_bytes=1024 * 1024,
      num_reader_threads=2,
      reader_buffer_size=10,
      compression=compression)

  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, **snapshot_options))
  self.assertDatasetProduces(written, expected, assert_items_equal=True)

  # Drop the source files so the read below can only come from the snapshot.
  self.removeTFRecords()
  read_back = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, **snapshot_options))
  self.assertDatasetProduces(read_back, expected, assert_items_equal=True)
def testReadSnapshotBackAfterMultiThreadedWrite(
    self, compression, threads, size):
  """Writes a snapshot with several writer threads and reads it back.

  Args:
    compression: Compression setting forwarded to `snapshot.snapshot`.
    threads: Value passed as `num_writer_threads` for the write pass.
    size: Value passed as `writer_buffer_size` for the write pass.
  """
  self.setUpTFRecord()
  record_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(10):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_root = self.snapshot_dir
  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(
          snapshot_root,
          compression=compression,
          num_writer_threads=threads,
          writer_buffer_size=size))
  self.assertDatasetProduces(written, expected)

  # Drop the source files so the read below can only come from the snapshot.
  self.removeTFRecords()
  read_back = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root, compression=compression))
  self.assertDatasetProduces(read_back, expected, assert_items_equal=True)
def testExpiredSnapshotRewrite(self):
  """Checks that an expired, unfinished snapshot triggers a second run."""
  snapshot_root = self.snapshot_dir

  abandoned = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root, pending_snapshot_expiry_seconds=1))
  abandoned_next = self.getNext(abandoned)

  # Stop halfway through so the first snapshot is never finalized.
  for _ in range(500):
    self.evaluate(abandoned_next())
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)

  # Wait longer than the one-second pending-snapshot expiry.
  time.sleep(2)

  # The second dataset is deliberately created only now: in eager mode the
  # snapshot state is decided at dataset creation, and that decision must
  # happen after the first snapshot has expired.
  retry = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root, pending_snapshot_expiry_seconds=1))
  retry_next = self.getNext(retry)
  for _ in range(500):
    self.evaluate(retry_next())
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 2, 1)
def testAdditionalOperationsAfterReadBack(self):
  """Applies a `map` on top of a snapshot that is being read back."""
  self.setUpTFRecord()
  record_files = self.test_filenames
  # (record, file) index pairs in the order the records are expected.
  index_pairs = [(r, f) for f in range(10) for r in range(10)]
  expected = [b"Record %d of file %d" % pair for pair in index_pairs]

  snapshot_root = self.snapshot_dir
  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(written, expected)

  # Drop the source files so the reads below can only come from the snapshot.
  self.removeTFRecords()
  read_back = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(read_back, expected)

  # Same records with the first two bytes stripped by the map below.
  expected_after = [b"cord %d of file %d" % pair for pair in index_pairs]
  trimmed = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_root))
  trimmed = trimmed.map(lambda x: string_ops.substr_v2(x, 2, 1000))
  self.assertDatasetProduces(trimmed, expected_after)
def testAdditionalOperationsAfterReadBack(self):
  """Reads a snapshot back and verifies a further `map` over its elements."""
  self.setUpTFRecord()
  source_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(10):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_path = self.makeSnapshotDirectory()
  write_pass = core_readers._TFRecordDataset(source_files).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(write_pass, expected)

  # With the source files gone, elements can only come from the snapshot.
  self.removeTFRecords()
  read_pass = core_readers._TFRecordDataset(source_files).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(read_pass, expected)

  expected_after = []
  for file_index in range(10):
    for record_index in range(10):
      expected_after.append(
          b"cord %d of file %d" % (record_index, file_index))

  mapped_pass = core_readers._TFRecordDataset(source_files).apply(
      snapshot.snapshot(snapshot_path))
  # Drop the first two bytes of every record.
  mapped_pass = mapped_pass.map(lambda x: string_ops.substr_v2(x, 2, 1000))
  self.assertDatasetProduces(mapped_pass, expected_after)
def testWriteDifferentPipelinesInOneDirectory(self):
  """Snapshots two different pipelines into the same directory."""
  snapshot_path = self.makeSnapshotDirectory()

  shorter = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(shorter, list(range(1000)))

  longer = dataset_ops.Dataset.range(1001).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(longer, list(range(1001)))

  # Two fingerprints are expected, each with a single run.
  self.assertSnapshotDirectoryContains(snapshot_path, 2, 1, 1)
def testRoundtripEmptySnapshot(self):
  """Writes a snapshot of an empty dataset and reads it back.

  The second pipeline is built from its own fresh `range(0)` source so that
  it has the same structure as the first pipeline. Previously the second
  snapshot was applied to the first, already-snapshotted dataset, leaving
  the freshly created `dataset2` unused.
  """
  dataset = dataset_ops.Dataset.range(0)
  dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
  self.assertDatasetProduces(dataset, [])
  self.assertSnapshotDirectoryContains(
      self._snapshot_dir,
      num_fingerprints=1,
      num_runs_per_fingerprint=1,
      num_snapshot_shards_per_run=0)

  dataset2 = dataset_ops.Dataset.range(0)
  # Apply the snapshot to dataset2 itself, not to the first pipeline.
  dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
  self.assertDatasetProduces(dataset2, [])
def testWriteDifferentPipelinesInOneDirectory(self):
  """Checks that two distinct pipelines yield two snapshots in one dir."""
  snapshot_root = self.snapshot_dir

  first = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(first, list(range(1000)))

  second = dataset_ops.Dataset.range(1001).apply(
      snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(second, list(range(1001)))

  # Two fingerprints are expected, each with a single run.
  self.assertSnapshotDirectoryContains(snapshot_root, 2, 1, 1)
def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
  """Runs two identical pipelines to completion against one snapshot dir."""
  expected = list(range(1000))

  first = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(self._snapshot_dir))
  self.assertDatasetProduces(first, expected)

  second = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(self._snapshot_dir))
  self.assertDatasetProduces(second, expected)

  # Only one fingerprint with one run is expected after both pipelines.
  self.assertSnapshotDirectoryContains(
      self._snapshot_dir,
      num_fingerprints=1,
      num_runs_per_fingerprint=1,
      num_snapshot_shards_per_run=multiprocessing.cpu_count())
def testForceReadNonexistentSnapshot(self):
  """Expects `NotFoundError` from forced "read" mode with nothing written."""
  snapshot_root = self.snapshot_dir
  source = dataset_ops.Dataset.range(10)
  with self.assertRaises(errors.NotFoundError):
    forced_read = source.apply(snapshot.snapshot(snapshot_root, mode="read"))
    fetch = self.getNext(forced_read)
    self.evaluate(fetch())
def testSnapshotDatasetInvalidReaderFn(self):
  """Expects `TypeError` when `reader_func` is `lambda x: x + 1`."""
  source = dataset_ops.Dataset.range(1000)
  with self.assertRaises(TypeError):
    broken = source.apply(
        snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
    fetch = self.getNext(broken)
    self.evaluate(fetch())
def testSnapshotDatasetInvalidShardFn(self):
  """Expects `TypeError` when `shard_func` returns a string."""
  source = dataset_ops.Dataset.range(1000)
  with self.assertRaises(TypeError):
    broken = source.apply(
        snapshot.snapshot(
            self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
    fetch = self.getNext(broken)
    self.evaluate(fetch())
def testWriteSnapshotMultiFileSuccessful(self):
  """Snapshots 20000 elements and expects them spread over two files."""
  snapshot_path = self.makeSnapshotDirectory()
  snapshotted = dataset_ops.Dataset.range(20000).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(snapshotted, list(range(20000)))
  self.assertSnapshotDirectoryContains(snapshot_path, 1, 1, 2)
def testWriteSnapshotSimpleSuccessful(self, compression):
  """Snapshots a small range dataset with the given compression.

  Args:
    compression: Compression setting forwarded to `snapshot.snapshot`.
  """
  snapshot_root = self.snapshot_dir
  snapshotted = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root, compression=compression))
  self.assertDatasetProduces(snapshotted, list(range(1000)))
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testGetNextCreatesDir(self):
  """Checks that only the iterator that is actually pulled creates a dir."""
  snapshot_root = self.snapshot_dir

  # Two iterators are created, but only the first one is ever advanced.
  consumed = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root))
  consumed_next = self.getNext(consumed)

  idle = dataset_ops.Dataset.range(1001).apply(
      snapshot.snapshot(snapshot_root))
  _ = self.getNext(idle)

  for _ in range(1000):
    self.evaluate(consumed_next())

  # Exactly one snapshot directory must exist.
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testForcePassthroughMode(self):
  """Checks that "passthrough" mode writes nothing to the snapshot dir."""
  snapshot_root = self.snapshot_dir
  passthrough = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(snapshot_root, mode="passthrough"))
  repeated = passthrough.repeat(10)
  self.assertDatasetProduces(repeated, list(range(10)) * 10)
  self.assertSnapshotDirectoryContains(snapshot_root, 0, 0, 0)
def testReadSnapshotDatasetAutoWriteSnappyRead(self):
  """Writes with compression="AUTO" and reads back with "SNAPPY"."""
  self.createTFRecords()
  record_files = self._test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(100):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(self._snapshot_dir, compression="AUTO"))
  self.assertDatasetProduces(written, expected)

  # The source files are removed before the second pipeline is built.
  self.removeTFRecords()
  read_back = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(self._snapshot_dir, compression="SNAPPY"))
  self.assertDatasetProduces(read_back, expected)
def testForceReadNonexistentNamedSnapshot(self):
  """Expects `NotFoundError` when force-reading an unwritten named run."""
  snapshot_path = self.makeSnapshotDirectory()
  source = dataset_ops.Dataset.range(10)
  with self.assertRaises(errors.NotFoundError):
    forced_read = source.apply(
        snapshot.snapshot(
            snapshot_path,
            mode="read",
            snapshot_name="my_nonexistent_snapshot"))
    fetch = self.getNext(forced_read)
    self.evaluate(fetch())
def testWriteSnapshotRepeatAfterwards(self):
  """Repeats a snapshotted dataset and expects a single snapshot run."""
  snapshot_path = self.makeSnapshotDirectory()
  snapshotted = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(snapshot_path))
  repeated = snapshotted.repeat(10)
  self.assertDatasetProduces(repeated, list(range(10)) * 10)
  self.assertSnapshotDirectoryContains(snapshot_path, 1, 1, 1)
def testWriteSnapshotMultipleSimultaneous(self):
  """Interleaves two identical snapshot pipelines over one directory.

  Each element pulled from either iterator is compared against its expected
  value, so a pipeline that yields wrong data fails the test instead of
  being silently consumed.
  """
  tmpdir = self.makeSnapshotDirectory()

  dataset1 = dataset_ops.Dataset.range(1000)
  dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
  next1 = self.getNext(dataset1)

  dataset2 = dataset_ops.Dataset.range(1000)
  dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
  next2 = self.getNext(dataset2)

  # Advance both iterators in lockstep and verify every value they yield.
  for i in range(1000):
    self.assertEqual(i, self.evaluate(next1()))
    self.assertEqual(i, self.evaluate(next2()))

  # we check that only one copy of the metadata has been written, and the
  # one that lost the race would be in passthrough mode.
  self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
def testWriteSnapshotRepeatAfterwards(self, compression):
  """Repeats a snapshotted dataset written with the given compression.

  Args:
    compression: Compression setting forwarded to `snapshot.snapshot`.
  """
  snapshot_root = self.snapshot_dir
  snapshotted = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(snapshot_root, compression=compression))
  repeated = snapshotted.repeat(10)
  self.assertDatasetProduces(repeated, list(range(10)) * 10)
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testWriteSnapshotMultipleSimultaneous(self):
  """Pulls two identical snapshot pipelines in lockstep, checking values."""
  snapshot_root = self.snapshot_dir

  first = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root))
  first_next = self.getNext(first)

  second = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(snapshot_root))
  second_next = self.getNext(second)

  for expected_value in range(1000):
    self.assertEqual(expected_value, self.evaluate(first_next()))
    self.assertEqual(expected_value, self.evaluate(second_next()))

  # Only one copy of the metadata should have been written; the pipeline
  # that lost the race would be in passthrough mode.
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testForceWriteMode(self):
  """Checks that forced "write" mode writes a new run on every repeat."""
  snapshot_root = self.snapshot_dir
  forced_write = dataset_ops.Dataset.range(10).apply(
      snapshot.snapshot(snapshot_root, mode="write"))
  repeated = forced_write.repeat(10)
  self.assertDatasetProduces(repeated, list(range(10)) * 10)

  # Ten repeats are expected to leave ten separate runs behind.
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 10, 1)
def ds_fn():
  """Builds a snapshotted `range(100)` dataset, doubled if `repeat` is set.

  Stores the snapshot directory path on `self._snapshot_dir` and creates
  the directory when the path does not exist yet.
  """
  snapshot_path = os.path.join(self.get_temp_dir(), "snapshot")
  self._snapshot_dir = snapshot_path
  if not os.path.exists(snapshot_path):
    os.mkdir(snapshot_path)

  snapshotted = dataset_ops.Dataset.range(100).apply(
      snapshot.snapshot(snapshot_path))
  return snapshotted.repeat(2) if repeat else snapshotted
def testReadUsingFlatMap(self):
  """Reads a snapshotted dataset a second time through `flat_map`."""
  expected = list(range(1000))
  snapshotted = dataset_ops.Dataset.range(1000).apply(
      snapshot.snapshot(self._snapshot_dir))
  self.assertDatasetProduces(snapshotted, expected)

  # Wrap the dataset as a single element and flatten it back out.
  nested = dataset_ops.Dataset.from_tensors(snapshotted)
  flattened = nested.flat_map(lambda x: x)
  self.assertDatasetProduces(flattened, expected)

  self.assertSnapshotDirectoryContains(
      self._snapshot_dir,
      num_fingerprints=1,
      num_runs_per_fingerprint=1,
      num_snapshot_shards_per_run=multiprocessing.cpu_count())
def testSameFingerprintWithDifferentInitializationOrder(self):
  """Builds the same pipeline from parts created in a different order."""
  snapshot_root = self.snapshot_dir
  expected = list(range(300))

  # First pipeline: the three pieces are created in ascending order.
  low = dataset_ops.Dataset.range(0, 100)
  mid = dataset_ops.Dataset.range(100, 200)
  high = dataset_ops.Dataset.range(200, 300)
  forward = low.concatenate(mid).concatenate(high)
  forward = forward.apply(snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(forward, expected)

  # Second pipeline: the same pieces are created in descending order but
  # concatenated in the same ascending order as before.
  high_again = dataset_ops.Dataset.range(200, 300)
  mid_again = dataset_ops.Dataset.range(100, 200)
  low_again = dataset_ops.Dataset.range(0, 100)
  backward = low_again.concatenate(mid_again).concatenate(high_again)
  backward = backward.apply(snapshot.snapshot(snapshot_root))
  self.assertDatasetProduces(backward, expected)

  # A single fingerprint with a single run is expected.
  self.assertSnapshotDirectoryContains(snapshot_root, 1, 1, 1)
def testReadSnapshotBackAfterWrite(self):
  """Writes a snapshot, removes the source files, and reads it back."""
  self.setUpTFRecord()
  record_files = self.test_filenames
  expected = []
  for file_index in range(10):
    for record_index in range(10):
      expected.append(b"Record %d of file %d" % (record_index, file_index))

  snapshot_path = self.makeSnapshotDirectory()
  written = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(written, expected)

  # Drop the source files so the read below can only come from the snapshot.
  self.removeTFRecords()
  read_back = core_readers._TFRecordDataset(record_files).apply(
      snapshot.snapshot(snapshot_path))
  self.assertDatasetProduces(read_back, expected)
def _createSimpleDataset(self, num_elems, tmp_dir=None):
  """Builds a snapshotted dataset of `num_elems` broadcast float tensors.

  Args:
    num_elems: Count passed to `repeat` on the single-element dataset.
    tmp_dir: Snapshot directory; when falsy, a new one is obtained from
      `self._makeSnapshotDirectory()`.

  Returns:
    The dataset with the snapshot transformation applied.
  """
  target_dir = tmp_dir or self._makeSnapshotDirectory()

  single = dataset_ops.Dataset.from_tensor_slices([1.0])
  # Expand the scalar to a [50, 50, 3] tensor.
  broadcast = single.map(
      lambda x: gen_array_ops.broadcast_to(x, [50, 50, 3]))
  repeated = broadcast.repeat(num_elems)
  return repeated.apply(snapshot.snapshot(target_dir))
def testSnapshotArgsCreateNewSnapshot(self):
  """Checks that changing `shard_size_bytes` produces a second snapshot.

  The second pipeline is built from its own fresh `range(1000)` source so
  that the two pipelines differ only in the snapshot arguments. Previously
  the second snapshot was stacked on top of the first pipeline and the
  freshly created `dataset2` was discarded.
  """
  tmpdir = self.snapshot_dir

  dataset1 = dataset_ops.Dataset.range(1000)
  dataset1 = dataset1.apply(
      snapshot.snapshot(tmpdir, shard_size_bytes=10000))
  next1 = self.getNext(dataset1)

  for _ in range(1000):
    self.evaluate(next1())
  self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

  # Create second snapshot with a different shard_size_bytes
  dataset2 = dataset_ops.Dataset.range(1000)
  # Apply the snapshot to dataset2 itself, not to the first pipeline.
  dataset2 = dataset2.apply(
      snapshot.snapshot(tmpdir, shard_size_bytes=20000))
  next2 = self.getNext(dataset2)

  for _ in range(1000):
    self.evaluate(next2())
  self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
def testWriteSnapshotCustomShardFunction(self):
  """Shards a snapshot by index parity and expects two shards per run."""
  enumerated = dataset_ops.Dataset.range(1000).enumerate()
  snapshotted = enumerated.apply(
      snapshot.snapshot(self._snapshot_dir, shard_func=lambda i, _: i % 2))
  # Strip the enumeration index again so only the elements remain.
  elements_only = snapshotted.map(lambda _, elem: elem)
  self.assertDatasetProduces(elements_only, list(range(1000)))
  self.assertSnapshotDirectoryContains(
      self._snapshot_dir,
      num_fingerprints=1,
      num_runs_per_fingerprint=1,
      num_snapshot_shards_per_run=2)