Example #1
  def testForceReadMode(self):
    tmpdir = self.snapshot_dir

    # We write a copy of the snapshot first.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir, mode="write", snapshot_name="my_custom_snapshot"))
    self.assertDatasetProduces(dataset, list(range(10)))

    # We move the run to a new name.
    shutil.move(
        os.path.join(tmpdir, "custom-my_custom_snapshot"),
        os.path.join(tmpdir, "custom-my_custom_snapshot_2"))

    # Even though snapshot.metadata still points to the old run, which no
    # longer exists after the move, we force the snapshot to read from the
    # run we specify.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir, mode="read", snapshot_name="my_custom_snapshot_2"))
    self.assertDatasetProduces(dataset, list(range(10)))

    # We should still have one snapshot and one run.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #2
    def testReadSnapshotDatasetCustomReaderFn(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
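        # The reader_func receives a dataset of per-shard datasets and must
        # return a single dataset of elements; here we interleave four shards
        # in parallel.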
        dataset = dataset.apply(
            snapshot.snapshot(
                self._snapshot_dir,
                reader_func=(
                    lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                        lambda x: x,
                        cycle_length=4,
                        num_parallel_calls=4))))
        self.assertDatasetProduces(dataset, expected)
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

        self.removeTFRecords()
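        # With the source TFRecords deleted, the elements can only come from
        # the snapshot on disk.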
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(
                self._snapshot_dir,
                reader_func=(
                    lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                        lambda x: x,
                        cycle_length=4,
                        num_parallel_calls=4))))
        self.assertDatasetProducesSet(dataset2, expected)
Example #3
  def testReadShuffledSnapshotWithSeedAfterWrite(self):
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, shard_size_bytes=10, shuffle_on_read=True,
                          shuffle_seed=123456))
    next2 = self.getNext(dataset2)

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(
        snapshot.snapshot(tmpdir, shard_size_bytes=10,
                          shuffle_on_read=True, shuffle_seed=123456))
    next3 = self.getNext(dataset3)

    # make sure that the items are read back in the same order for both datasets
    for _ in range(500):
      res2 = self.evaluate(next2())
      res3 = self.evaluate(next3())
      self.assertEqual(res2, res3)
Example #4
  def testReadShuffledSnapshotAfterWrite(self):
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
    next2 = self.getNext(dataset2)

    res1 = self.evaluate(next2())
    res2 = self.evaluate(next2())
    res3 = self.evaluate(next2())
    res4 = self.evaluate(next2())
    res5 = self.evaluate(next2())

    # make sure that we don't read the file back in the same order.
    self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

    # make sure all the elements are still there
    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(snapshot.snapshot(tmpdir, shuffle_on_read=True))
    self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
Example #5
  def testReadSnapshotParallelAfterWrite(self, compression):
    self.setUpTFRecord(10, 4000)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 4000)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir,
            shard_size_bytes=1024 * 1024,
            num_reader_threads=2,
            reader_buffer_size=10,
            compression=compression))
    self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    # remove the original files and try to read the data back only from
    # snapshot.
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(
            tmpdir,
            shard_size_bytes=1024 * 1024,
            num_reader_threads=2,
            reader_buffer_size=10,
            compression=compression))
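    # With num_reader_threads=2, elements may come back out of order, so the
    # contents are compared as a multiset.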
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
Example #6
  def testReadSnapshotBackAfterMultiThreadedWrite(
      self, compression, threads, size):
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.snapshot(
            tmpdir,
            compression=compression,
            num_writer_threads=threads,
            writer_buffer_size=size))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from
    # snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, compression=compression))
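    # Multi-threaded writing may persist elements out of order, so the
    # read-back comparison ignores order.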
    self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
Example #7
  def testExpiredSnapshotRewrite(self):
    tmpdir = self.snapshot_dir

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(
        snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
    next1 = self.getNext(dataset1)

    # Don't finish reading dataset1, so it is never finalized
    for _ in range(500):
      self.evaluate(next1())
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    time.sleep(2)

    # We create dataset2 only after running through dataset1 because, in
    # eager mode, the snapshot state is determined immediately upon dataset
    # creation. We only want to determine the snapshot state for dataset2
    # after the first snapshot has expired.
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
    next2 = self.getNext(dataset2)

    for _ in range(500):
      self.evaluate(next2())
    self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)
Example #8
  def testAdditionalOperationsAfterReadBack(self):
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset2, expected)

    expected_after = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(snapshot.snapshot(tmpdir))
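    # substr_v2(x, 2, 1000) drops the first two bytes, turning "Record ..."
    # into "cord ...".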
    dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    self.assertDatasetProduces(dataset3, expected_after)
Example #9
  def testAdditionalOperationsAfterReadBack(self):
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset2, expected)

    expected_after = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(snapshot.snapshot(tmpdir))
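    # substr_v2(x, 2, 1000) drops the first two bytes, turning "Record ..."
    # into "cord ...".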
    dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    self.assertDatasetProduces(dataset3, expected_after)
Example #10
  def testWriteDifferentPipelinesInOneDirectory(self):
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1000)))

    dataset = dataset_ops.Dataset.range(1001)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1001)))

    self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
Example #11
    def testRoundtripEmptySnapshot(self):
        dataset = dataset_ops.Dataset.range(0)
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset, [])
        self.assertSnapshotDirectoryContains(self._snapshot_dir,
                                             num_fingerprints=1,
                                             num_runs_per_fingerprint=1,
                                             num_snapshot_shards_per_run=0)

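        # Reading the empty snapshot back should also produce no elements.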
        dataset2 = dataset_ops.Dataset.range(0)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset2, [])
Example #12
  def testWriteDifferentPipelinesInOneDirectory(self):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1000)))

    dataset = dataset_ops.Dataset.range(1001)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(1001)))

    self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
Example #13
    def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset1, list(range(1000)))
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset2, list(range(1000)))

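        # The second pipeline has the same fingerprint, so it reads the
        # existing run instead of writing a new one.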
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())
Example #14
  def testForceReadNonexistentSnapshot(self):
    tmpdir = self.snapshot_dir
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaises(errors.NotFoundError):
      dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="read"))
      get_next = self.getNext(dataset)
      self.evaluate(get_next())
Example #15
  def testSnapshotDatasetInvalidReaderFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(self._snapshot_dir, reader_func=lambda x: x + 1))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())
Example #16
  def testSnapshotDatasetInvalidShardFn(self):
    dataset = dataset_ops.Dataset.range(1000)
    with self.assertRaises(TypeError):
      dataset = dataset.apply(
          snapshot.snapshot(
              self._snapshot_dir, shard_func=lambda _: "invalid_fn"))
      next_fn = self.getNext(dataset)
      self.evaluate(next_fn())
Example #17
  def testWriteSnapshotMultiFileSuccessful(self):
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.range(20000)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(20000)))

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 2)
Example #18
  def testWriteSnapshotSimpleSuccessful(self, compression):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, list(range(1000)))

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #20
  def testGetNextCreatesDir(self):
    tmpdir = self.snapshot_dir

    # We create two iterators but call getNext on only one.
    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
    next1 = self.getNext(dataset1)

    dataset2 = dataset_ops.Dataset.range(1001)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    _ = self.getNext(dataset2)

    for _ in range(1000):
      self.evaluate(next1())

    # We check that only one directory is created.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #21
  def testForcePassthroughMode(self):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="passthrough"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

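    # In passthrough mode nothing is written to disk, so the snapshot
    # directory stays empty.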
    self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)
Example #22
    def testReadSnapshotDatasetAutoWriteSnappyRead(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(self._snapshot_dir, compression="AUTO"))
        self.assertDatasetProduces(dataset, expected)

        self.removeTFRecords()
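        # "AUTO" writes Snappy-compressed shards here, so the snapshot can be
        # read back with compression="SNAPPY" explicitly.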
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(self._snapshot_dir, compression="SNAPPY"))
        self.assertDatasetProduces(dataset2, expected)
Example #23
  def testForceReadNonexistentNamedSnapshot(self):
    tmpdir = self.makeSnapshotDirectory()
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaises(errors.NotFoundError):
      dataset = dataset.apply(
          snapshot.snapshot(
              tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot"))
      get_next = self.getNext(dataset)
      self.evaluate(get_next())
Example #24
  def testWriteSnapshotRepeatAfterwards(self):
    tmpdir = self.makeSnapshotDirectory()

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #25
  def testWriteSnapshotMultipleSimultaneous(self):
    tmpdir = self.makeSnapshotDirectory()

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
    next1 = self.getNext(dataset1)

    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    next2 = self.getNext(dataset2)

    for _ in range(1000):
      self.evaluate(next1())
      self.evaluate(next2())

    # We check that only one copy of the metadata has been written; the
    # iterator that lost the race falls back to passthrough mode.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #26
  def testWriteSnapshotRepeatAfterwards(self, compression):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #27
  def testWriteSnapshotMultipleSimultaneous(self):
    tmpdir = self.snapshot_dir

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
    next1 = self.getNext(dataset1)

    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    next2 = self.getNext(dataset2)

    for i in range(0, 1000):
      self.assertEqual(i, self.evaluate(next1()))
      self.assertEqual(i, self.evaluate(next2()))

    # We check that only one copy of the metadata has been written; the
    # iterator that lost the race falls back to passthrough mode.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #28
  def testForceWriteMode(self):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="write"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

    # We will end up writing 10 different runs.
    self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)
Example #29
        def ds_fn():
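            # `repeat` is captured from the enclosing scope; this nested
            # factory presumably serves a dataset checkpointing test.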
            self._snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
            if not os.path.exists(self._snapshot_dir):
                os.mkdir(self._snapshot_dir)

            dataset = dataset_ops.Dataset.range(100)
            dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
            if repeat:
                dataset = dataset.repeat(2)
            return dataset
Example #30
  def testReadUsingFlatMap(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
    self.assertDatasetProduces(dataset, list(range(1000)))
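    # Reading the snapshot through from_tensors().flat_map() should produce
    # the same elements.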
    flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(lambda x: x)
    self.assertDatasetProduces(flat_map, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=multiprocessing.cpu_count())
Example #31
  def testSameFingerprintWithDifferentInitializationOrder(self):
    tmpdir = self.snapshot_dir

    dataset1 = dataset_ops.Dataset.range(0, 100)
    dataset2 = dataset_ops.Dataset.range(100, 200)
    dataset3 = dataset_ops.Dataset.range(200, 300)

    dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(300)))

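    # Building the same pipeline from datasets constructed in a different
    # order should produce the same fingerprint and reuse the snapshot.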
    dataset4 = dataset_ops.Dataset.range(200, 300)
    dataset5 = dataset_ops.Dataset.range(100, 200)
    dataset6 = dataset_ops.Dataset.range(0, 100)

    dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, list(range(300)))

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #32
    def testReadSnapshotBackAfterWrite(self):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset2, expected)
Example #33
  def _createSimpleDataset(self, num_elems, tmp_dir=None):
    if not tmp_dir:
      tmp_dir = self._makeSnapshotDirectory()

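    # Build num_elems identical 50x50x3 float tensors and snapshot them.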
    dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
    dataset = dataset.map(
        lambda x: gen_array_ops.broadcast_to(x, [50, 50, 3]))
    dataset = dataset.repeat(num_elems)
    dataset = dataset.apply(snapshot.snapshot(tmp_dir))

    return dataset
Example #34
  def testReadSnapshotBackAfterWrite(self):
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.makeSnapshotDirectory()
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
    self.assertDatasetProduces(dataset2, expected)
Example #35
  def testSnapshotArgsCreateNewSnapshot(self):
    tmpdir = self.snapshot_dir

    dataset1 = dataset_ops.Dataset.range(1000)
    dataset1 = dataset1.apply(
        snapshot.snapshot(tmpdir, shard_size_bytes=10000))
    next1 = self.getNext(dataset1)

    for _ in range(1000):
      self.evaluate(next1())
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    # Create second snapshot with a different shard_size_bytes
    dataset2 = dataset_ops.Dataset.range(1000)
    dataset2 = dataset2.apply(
        snapshot.snapshot(tmpdir, shard_size_bytes=20000))
    next2 = self.getNext(dataset2)

    for _ in range(1000):
      self.evaluate(next2())
    self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
Example #36
  def testWriteSnapshotCustomShardFunction(self):
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.enumerate()
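    # shard_func routes each element to a shard; even and odd enumeration
    # indices go to two separate shards.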
    dataset = dataset.apply(
        snapshot.snapshot(
            self._snapshot_dir, shard_func=lambda i, _: i % 2))
    dataset = dataset.map(lambda _, elem: elem)
    self.assertDatasetProduces(dataset, list(range(1000)))
    self.assertSnapshotDirectoryContains(
        self._snapshot_dir,
        num_fingerprints=1,
        num_runs_per_fingerprint=1,
        num_snapshot_shards_per_run=2)