Example #1
    def testExpiredSnapshotRewrite(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     pending_snapshot_expiry_seconds=1))
        next1 = self.getNext(dataset1)

        # Don't finish reading dataset1, so it is never finalized
        for _ in range(500):
            self.evaluate(next1())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

        time.sleep(2)

        # We create dataset2 only after running through dataset1 because, in
        # eager mode, the snapshot state is determined as soon as the dataset is
        # created. We want the snapshot state for dataset2 to be determined only
        # after the first snapshot has expired.
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     pending_snapshot_expiry_seconds=1))
        next2 = self.getNext(dataset2)

        for _ in range(500):
            self.evaluate(next2())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)
Example #2
    def testReadSnapshotParallelAfterWrite(self, compression):
        self.setUpTFRecord(10, 4000)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 4000)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=1024 * 1024,
                                     num_reader_threads=2,
                                     reader_buffer_size=10,
                                     compression=compression))
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

        # remove the original files and try to read the data back only from
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=1024 * 1024,
                                     num_reader_threads=2,
                                     reader_buffer_size=10,
                                     compression=compression))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
Example #3
    def testForceReadMode(self):
        tmpdir = self.snapshot_dir

        # We write a copy of the snapshot first.
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     mode="write",
                                     snapshot_name="my_custom_snapshot"))
        self.assertDatasetProduces(dataset, list(range(10)))

        # We move the run to a new name.
        shutil.move(os.path.join(tmpdir, "custom-my_custom_snapshot"),
                    os.path.join(tmpdir, "custom-my_custom_snapshot_2"))

        # Even though snapshot.metadata points to the old run, which no longer
        # exists after the move, we force the dataset to read from the run we
        # specify.
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     mode="read",
                                     snapshot_name="my_custom_snapshot_2"))
        self.assertDatasetProduces(dataset, list(range(10)))

        # We should still have one snapshot and one run.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #4
    def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                    size):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     compression=compression,
                                     num_writer_threads=threads,
                                     writer_buffer_size=size))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from
        # snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
Example #5
    def testAdditionalOperationsAfterReadBack(self):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset2, expected)

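        # substr_v2(x, 2, 1000) strips the first two bytes of each record, so
        # b"Record ..." comes back as b"cord ...".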
        expected_after = [
            b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir))
        dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
        self.assertDatasetProduces(dataset3, expected_after)
Example #6
    def testWriteDifferentPipelinesInOneDirectory(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1000)))

        dataset = dataset_ops.Dataset.range(1001)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1001)))

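        # The two pipelines have different graph fingerprints, so two separate
        # snapshots end up in the same directory.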
        self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
Example #7
  def testForceReadNonexistentSnapshot(self):
    tmpdir = self.snapshot_dir
    dataset = dataset_ops.Dataset.range(10)
    # Forcing read mode on a snapshot directory that was never written to
    # should raise NotFoundError.
    with self.assertRaises(errors.NotFoundError):
      dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="read"))
      get_next = self.getNext(dataset)
      self.evaluate(get_next())
Example #8
    def testGetNextCreatesDir(self):
        tmpdir = self.snapshot_dir

        # We create two iterators but call getNext on only one.
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1001)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        _ = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next1())

        # We check that only one directory is created.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #9
    def testWriteSnapshotMultipleSimultaneous(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        next2 = self.getNext(dataset2)

        for i in range(0, 1000):
            self.assertEqual(i, self.evaluate(next1()))
            self.assertEqual(i, self.evaluate(next2()))

        # We check that only one copy of the metadata has been written; the
        # pipeline that lost the race falls back to passthrough mode.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #10
    def testWriteSnapshotSimpleSuccessful(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset, list(range(1000)))

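        # A single complete pass should leave one snapshot fingerprint with one
        # run and one shard file.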
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #11
    def testWriteSnapshotRepeatAfterwards(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

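        # Repeating after the snapshot transformation should still produce a
        # single written run.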
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #12
    def testForceWriteMode(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write"))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        # We will end up writing 10 different runs.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)
Example #13
    def testForcePassthroughMode(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, mode="passthrough"))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

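        # Passthrough mode bypasses the snapshot entirely, so nothing is written
        # to the snapshot directory.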
        self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)
Example #14
  def testReadShuffledSnapshotWithSeedAfterWrite(self):
    self.setUpTFRecord(num_files=10, num_records=50)
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 50)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(
            tmpdir,
            shard_size_bytes=10,
            shuffle_on_read=True,
            shuffle_seed=123456))
    next2 = self.getNext(dataset2)

    dataset3 = core_readers._TFRecordDataset(filenames)
    dataset3 = dataset3.apply(
        snapshot.legacy_snapshot(
            tmpdir,
            shard_size_bytes=10,
            shuffle_on_read=True,
            shuffle_seed=123456))
    next3 = self.getNext(dataset3)

    # make sure that the items are read back in the same order for both datasets
    for _ in range(500):
      res2 = self.evaluate(next2())
      res3 = self.evaluate(next3())
      self.assertEqual(res2, res3)
Example #15
    def testSameFingerprintWithDifferentInitializationOrder(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(0, 100)
        dataset2 = dataset_ops.Dataset.range(100, 200)
        dataset3 = dataset_ops.Dataset.range(200, 300)

        dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

        dataset4 = dataset_ops.Dataset.range(200, 300)
        dataset5 = dataset_ops.Dataset.range(100, 200)
        dataset6 = dataset_ops.Dataset.range(0, 100)

        dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

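        # The component datasets are created in a different order, but the
        # assembled pipeline is identical, so both map to the same fingerprint.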
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
Example #16
    def testReadShuffledSnapshotAfterWrite(self):
        self.setUpTFRecord(num_files=10, num_records=50)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 50)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=100,
                                     shuffle_on_read=True))
        next2 = self.getNext(dataset2)

        res1 = self.evaluate(next2())
        res2 = self.evaluate(next2())
        res3 = self.evaluate(next2())
        res4 = self.evaluate(next2())
        res5 = self.evaluate(next2())

        # make sure that we don't read the file back in the same order.
        self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

        # make sure all the elements are still there
        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=100,
                                     shuffle_on_read=True))
        self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)
Example #17
    def testSnapshotArgsCreateNewSnapshot(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000))
        next1 = self.getNext(dataset1)

        for _ in range(1000):
            self.evaluate(next1())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

        # Create second snapshot with a different shard_size_bytes
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000))
        next2 = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next2())
        self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
Example #18
  def testReadSnapshotBackAfterWrite(self, compression):
    self.setUpTFRecord()
    filenames = self.test_filenames

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]

    tmpdir = self.snapshot_dir
    dataset = core_readers._TFRecordDataset(filenames)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset, expected)

    # remove the original files and try to read the data back only from snapshot
    self.removeTFRecords()

    dataset2 = core_readers._TFRecordDataset(filenames)
    dataset2 = dataset2.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    self.assertDatasetProduces(dataset2, expected)
Example #19
  def testSpecifySnapshotNameWriteAndRead(self):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, snapshot_name="my_custom_snapshot"))
    dataset = dataset.repeat(10)
    self.assertDatasetProduces(dataset, list(range(10)) * 10)

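    # The named snapshot is written under a "custom-my_custom_snapshot"
    # directory containing a "custom" run.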
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
    self.assertTrue(
        os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot")))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmpdir, "custom-my_custom_snapshot", "custom")))
Example #20
    def _createSimpleDataset(self,
                             num_elements,
                             tmp_dir=None,
                             compression=snapshot.COMPRESSION_NONE):
        if not tmp_dir:
            tmp_dir = self._makeSnapshotDirectory()

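        # Each element is a constant [50, 50, 3] float tensor; num_elements
        # controls how many times it is repeated.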
        dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
        dataset = dataset.map(
            lambda x: gen_array_ops.broadcast_to(x, [50, 50, 3]))
        dataset = dataset.repeat(num_elements)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmp_dir, compression=compression))

        return dataset
Example #21
  def ds_fn():
    # num_threads, pending_snapshot_expiry_seconds, shard_size_bytes and repeat
    # are presumably supplied by the enclosing test.
    self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
    if not os.path.exists(self.snapshot_dir):
      os.mkdir(self.snapshot_dir)
    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(
            self.snapshot_dir,
            num_writer_threads=num_threads,
            writer_buffer_size=2 * num_threads,
            num_reader_threads=num_threads,
            reader_buffer_size=2 * num_threads,
            pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
            shard_size_bytes=shard_size_bytes))
    if repeat:
      dataset = dataset.repeat(2)
    return dataset
Example #22
  def testSpecifyShardSize(self, compression):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
    dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
    dataset = dataset.repeat(10)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(
            tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression))
    next_fn = self.getNext(dataset)

    for _ in range(10):
      self.evaluate(next_fn())

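    # Each element is roughly 4 MB uncompressed, so without compression the
    # data rolls over the 10 MB shard limit into multiple shard files.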
    num_files = 1
    if compression == snapshot.COMPRESSION_NONE:
      num_files = 3
    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)
Example #23
  def testWriteSnapshotMixTypes(self, compression):
    tmpdir = self.snapshot_dir

    dataset = dataset_ops.Dataset.range(10)

    def map_fn(x):
      return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)

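    # Each element is a tuple mixing int64 and string tensors.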
    dataset = dataset.map(map_fn)
    dataset = dataset.apply(
        snapshot.legacy_snapshot(tmpdir, compression=compression))
    dataset = dataset.repeat(10)

    expected = []
    for i in range(10):
      expected.append((i, str(i), str(2 * i), 2 * i))
    self.assertDatasetProduces(dataset, expected * 10)

    self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)