def _restore_iterator(self):
    """Build a restore op and get_next for a fresh string-scalar iterator.

    Returns:
      Tuple of (restore_op, get_next) for an iterator restored from this
      test's checkpoint path.
    """
    iterator = iterator_ops.Iterator.from_structure(dtypes.string,
                                                    tensor_shape.scalar())
    next_element = iterator.get_next()
    restore_op = gen_dataset_ops.restore_iterator(
        iterator._iterator_resource, self._iterator_checkpoint_path())
    return restore_op, next_element
def _build_graph(start, stop, num_epochs, path):
    """Build a repeated range-dataset iterator plus save/restore ops.

    Args:
      start: first value of the range.
      stop: one past the last value of the range.
      num_epochs: number of repetitions of the range.
      path: checkpoint path for the save/restore ops.

    Returns:
      Tuple of (init_op, get_next, save_op, restore_op).
    """
    it = dataset_ops.Dataset.range(start, stop).repeat(
        num_epochs).make_initializable_iterator()
    init_op = it.initializer
    next_elem = it.get_next()
    resource = it._iterator_resource
    save_op = gen_dataset_ops.save_iterator(resource, path)
    restore_op = gen_dataset_ops.restore_iterator(resource, path)
    return init_op, next_elem, save_op, restore_op
def _build_graph(start, stop):
    """Build a range-dataset iterator plus save/restore ops.

    NOTE(review): `self` is a free variable here (not a parameter) —
    presumably this def is nested inside a test method; confirm against
    the enclosing scope.

    Returns:
      Tuple of (init_op, get_next, save_op, restore_op).
    """
    it = dataset_ops.Dataset.range(start, stop).make_initializable_iterator()
    init_op = it.initializer
    next_elem = it.get_next()
    ckpt = self._iterator_checkpoint_prefix()
    save_op = gen_dataset_ops.save_iterator(it._iterator_resource, ckpt)
    restore_op = gen_dataset_ops.restore_iterator(it._iterator_resource, ckpt)
    return init_op, next_elem, save_op, restore_op
def _build_reader_dataset_graph():
    """Build a fixed-length-record reader iterator plus save/restore ops.

    The referenced file need not exist: the graph is only constructed,
    and the file is not opened until the ops actually run.

    Returns:
      Tuple of (init_op, get_next_op, save_op, restore_op).
    """
    filenames = ["test"]  # Does not exist but we don't care in this test.
    # NOTE(review): bare call, while sibling helpers use
    # self._iterator_checkpoint_prefix() — confirm a module-level
    # _iterator_checkpoint_prefix exists in this file.
    ckpt = _iterator_checkpoint_prefix()
    it = readers.FixedLengthRecordDataset(
        filenames, 1, 0, 0).make_initializable_iterator()
    init_op = it.initializer
    next_op = it.get_next()
    save_op = gen_dataset_ops.save_iterator(it._iterator_resource, ckpt)
    restore_op = gen_dataset_ops.restore_iterator(it._iterator_resource, ckpt)
    return init_op, next_op, save_op, restore_op
def _build_iterator_graph(self, num_epochs):
    """Build an iterator over this test's record files, with save/restore.

    Args:
      num_epochs: number of passes over the fixed-length-record files.

    Returns:
      Tuple of (init_op, get_next_op, save_op, restore_op).
    """
    filenames = self._createFiles()
    ckpt = self._iterator_checkpoint_path()
    ds = readers.FixedLengthRecordDataset(
        filenames, self._record_bytes, self._header_bytes,
        self._footer_bytes).repeat(num_epochs)
    it = ds.make_initializable_iterator()
    init_op = it.initializer
    next_op = it.get_next()
    save_op = gen_dataset_ops.save_iterator(it._iterator_resource, ckpt)
    restore_op = gen_dataset_ops.restore_iterator(it._iterator_resource, ckpt)
    return init_op, next_op, save_op, restore_op
def testRestoreWithoutBuildingDatasetGraph(self):
    """Save an iterator mid-stream, then restore it into an empty iterator
    built only from output structure, without rebuilding the dataset graph.
    """

    def _build_graph(start, stop, num_epochs, path):
        # Range dataset repeated num_epochs times, with save/restore ops.
        dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()
        save_op = gen_dataset_ops.save_iterator(
            iterator._iterator_resource, path)
        restore_op = gen_dataset_ops.restore_iterator(
            iterator._iterator_resource, path)
        return init_op, get_next, save_op, restore_op

    # Saving and restoring in different sessions.
    start = 2
    stop = 10
    num_epochs = 5
    break_point = 5
    break_epoch = 3
    path = self._iterator_checkpoint_prefix()
    with ops.Graph().as_default() as g:
        init_op, get_next, save_op, _ = _build_graph(
            start, stop, num_epochs, path)
        with self.test_session(graph=g) as sess:
            sess.run(variables.global_variables_initializer())
            sess.run(init_op)
            # Consume break_epoch full epochs plus part of the next one,
            # then checkpoint.
            for _ in range(break_epoch):
                for i in range(start, stop):
                    self.assertEqual(i, sess.run(get_next))
            for i in range(start, break_point):
                self.assertEqual(i, sess.run(get_next))
            sess.run(save_op)

    with ops.Graph().as_default() as g:
        # Create an empty IteratorResource and restore the Iterator into it.
        output_types = dtypes.int64
        output_shapes = tensor_shape.scalar()
        # Fixed: use iterator_ops.Iterator.from_structure, consistent with
        # the other tests in this file; Iterator lives in iterator_ops here,
        # not dataset_ops — TODO confirm against the file's imports.
        iterator = iterator_ops.Iterator.from_structure(
            output_types, output_shapes)
        restore_op = gen_dataset_ops.restore_iterator(
            iterator._iterator_resource, path)
        get_next = iterator.get_next()
        with self.test_session(graph=g) as sess:
            sess.run(restore_op)
            # Finish the interrupted epoch, then drain the remaining epochs.
            for i in range(break_point, stop):
                self.assertEqual(i, sess.run(get_next))
            for _ in range(break_epoch + 1, num_epochs):
                for i in range(start, stop):
                    self.assertEqual(i, sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
def testRestoreWithoutBuildingDatasetGraph(self):
    """Checkpoint an iterator mid-stream and restore its state into a
    fresh structure-only iterator in a separate graph and session.
    """

    def _make_graph(first, last, epochs, ckpt_path):
        # Repeated range dataset plus save/restore ops at ckpt_path.
        it = dataset_ops.Dataset.range(first, last).repeat(
            epochs).make_initializable_iterator()
        init_op = it.initializer
        next_elem = it.get_next()
        resource = it._iterator_resource
        save_op = gen_dataset_ops.save_iterator(resource, ckpt_path)
        restore_op = gen_dataset_ops.restore_iterator(resource, ckpt_path)
        return init_op, next_elem, save_op, restore_op

    # Save and restore happen in two separate graphs/sessions.
    first = 2
    last = 10
    epochs = 5
    stop_value = 5
    stop_epoch = 3
    ckpt_path = self._iterator_checkpoint_prefix()
    with ops.Graph().as_default() as g:
        init_op, next_elem, save_op, _ = _make_graph(first, last, epochs,
                                                     ckpt_path)
        with self.test_session(graph=g) as sess:
            sess.run(variables.global_variables_initializer())
            sess.run(init_op)
            # Consume stop_epoch full epochs plus part of the next one,
            # then checkpoint.
            for _ in range(stop_epoch):
                for expected in range(first, last):
                    self.assertEqual(expected, sess.run(next_elem))
            for expected in range(first, stop_value):
                self.assertEqual(expected, sess.run(next_elem))
            sess.run(save_op)

    with ops.Graph().as_default() as g:
        # Create an empty IteratorResource and restore the Iterator into it.
        it = iterator_ops.Iterator.from_structure(dtypes.int64,
                                                  tensor_shape.scalar())
        restore_op = gen_dataset_ops.restore_iterator(it._iterator_resource,
                                                      ckpt_path)
        next_elem = it.get_next()
        with self.test_session(graph=g) as sess:
            sess.run(restore_op)
            # Finish the interrupted epoch, then drain the remaining epochs.
            for expected in range(stop_value, last):
                self.assertEqual(expected, sess.run(next_elem))
            for _ in range(stop_epoch + 1, epochs):
                for expected in range(first, last):
                    self.assertEqual(expected, sess.run(next_elem))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_elem)