def _restore_iterator(self):
  """Create an empty string-scalar iterator plus an op that restores it.

  The iterator is built from its element structure only (``dtypes.string``
  elements with scalar shape); no dataset is attached to it.

  Returns:
    A ``(restore_op, get_next)`` pair. ``restore_op`` runs
    ``restore_iterator`` on the new iterator's resource using
    ``self._iterator_checkpoint_path()``; ``get_next`` is the iterator's
    next-element tensor.
  """
  it = iterator_ops.Iterator.from_structure(dtypes.string,
                                            tensor_shape.scalar())
  next_element = it.get_next()
  checkpoint = self._iterator_checkpoint_path()
  restore = gen_dataset_ops.restore_iterator(it._iterator_resource, checkpoint)
  return restore, next_element
예제 #2
0
 def _build_graph(start, stop, num_epochs, path):
   """Build ops for ``Dataset.range(start, stop)`` repeated ``num_epochs`` times.

   Returns:
     ``(init_op, get_next, save_op, restore_op)``. The save and restore ops
     both target the iterator's resource and use ``path``.
   """
   iterator = (dataset_ops.Dataset.range(start, stop)
               .repeat(num_epochs)
               .make_initializable_iterator())
   resource = iterator._iterator_resource
   init_op = iterator.initializer
   get_next = iterator.get_next()
   save_op = gen_dataset_ops.save_iterator(resource, path)
   restore_op = gen_dataset_ops.restore_iterator(resource, path)
   return init_op, get_next, save_op, restore_op
예제 #3
0
 def _build_graph(start, stop):
   """Build init/get_next/save/restore ops over ``Dataset.range(start, stop)``.

   The save and restore ops share the checkpoint prefix returned by
   ``self._iterator_checkpoint_prefix()``; ``self`` is taken from the
   enclosing scope.

   Returns:
     ``(init_op, get_next, save_op, restore_op)``.
   """
   dataset = dataset_ops.Dataset.range(start, stop)
   iterator = dataset.make_initializable_iterator()
   built = [iterator.initializer, iterator.get_next()]
   prefix = self._iterator_checkpoint_prefix()
   # Save first, then restore, both against the same iterator resource.
   for make_op in (gen_dataset_ops.save_iterator,
                   gen_dataset_ops.restore_iterator):
     built.append(make_op(iterator._iterator_resource, prefix))
   return tuple(built)
예제 #4
0
 def _build_graph(start, stop):
   """Return ``(init_op, get_next, save_op, restore_op)`` for a range dataset.

   The dataset is ``Dataset.range(start, stop)``. Saving and restoring use
   the prefix from ``self._iterator_checkpoint_prefix()``, where ``self``
   comes from the enclosing scope.
   """
   iterator = dataset_ops.Dataset.range(
       start, stop).make_initializable_iterator()
   init_op, get_next = iterator.initializer, iterator.get_next()
   checkpoint_prefix = self._iterator_checkpoint_prefix()
   resource = iterator._iterator_resource
   return (init_op,
           get_next,
           gen_dataset_ops.save_iterator(resource, checkpoint_prefix),
           gen_dataset_ops.restore_iterator(resource, checkpoint_prefix))
예제 #5
0
 def _build_reader_dataset_graph():
   """Build init/get_next/save/restore ops for a FixedLengthRecordDataset.

   The dataset is configured with 1-byte records and zero header and footer
   bytes, reading from a single file named "test". Save/restore use the
   prefix returned by ``_iterator_checkpoint_prefix()``.

   Returns:
     ``(init_op, get_next_op, save_op, restore_op)``.
   """
   # The file does not exist, but this test does not care.
   filenames = ["test"]
   path = _iterator_checkpoint_prefix()
   record_bytes, header_bytes, footer_bytes = 1, 0, 0
   dataset = readers.FixedLengthRecordDataset(filenames, record_bytes,
                                              header_bytes, footer_bytes)
   iterator = dataset.make_initializable_iterator()
   resource = iterator._iterator_resource
   init_op = iterator.initializer
   get_next_op = iterator.get_next()
   save_op = gen_dataset_ops.save_iterator(resource, path)
   restore_op = gen_dataset_ops.restore_iterator(resource, path)
   return init_op, get_next_op, save_op, restore_op
예제 #6
0
 def _build_reader_dataset_graph():
   """Return (init, get_next, save, restore) ops for a reader-backed iterator.

   The iterator reads ``FixedLengthRecordDataset(["test"], 1, 0, 0)`` and is
   saved/restored at ``_iterator_checkpoint_prefix()``.
   """
   path = _iterator_checkpoint_prefix()
   # "test" does not exist, but this test does not care.
   iterator = readers.FixedLengthRecordDataset(
       ["test"], 1, 0, 0).make_initializable_iterator()
   handle = iterator._iterator_resource
   return (iterator.initializer,
           iterator.get_next(),
           gen_dataset_ops.save_iterator(handle, path),
           gen_dataset_ops.restore_iterator(handle, path))
 def _build_iterator_graph(self, num_epochs):
   """Build ops over the test's fixed-length record files, repeated.

   The files come from ``self._createFiles()``. Record, header and footer
   sizes are read from ``self._record_bytes``, ``self._header_bytes`` and
   ``self._footer_bytes``.

   Args:
     num_epochs: value passed to ``Dataset.repeat``.

   Returns:
     ``(init_op, get_next_op, save_op, restore_op)``. The save and restore
     ops use ``self._iterator_checkpoint_path()``.
   """
   filenames = self._createFiles()
   path = self._iterator_checkpoint_path()
   reader = readers.FixedLengthRecordDataset(filenames, self._record_bytes,
                                             self._header_bytes,
                                             self._footer_bytes)
   iterator = reader.repeat(num_epochs).make_initializable_iterator()
   resource = iterator._iterator_resource
   init_op = iterator.initializer
   get_next_op = iterator.get_next()
   return (init_op, get_next_op,
           gen_dataset_ops.save_iterator(resource, path),
           gen_dataset_ops.restore_iterator(resource, path))
예제 #8
0
    def testRestoreWithoutBuildingDatasetGraph(self):
        """Save an iterator mid-stream, then restore it into a bare iterator.

        The first graph builds a repeated range dataset, consumes
        ``break_epoch`` full epochs plus part of the next one, and saves the
        iterator to ``path``. The second graph never builds the dataset. It
        creates an iterator from an int64 scalar structure only, restores
        the saved state into it, and checks the remaining elements followed
        by ``OutOfRangeError``.
        """

        def make_ops(lo, hi, epochs, ckpt):
            # (initializer, next element, save op, restore op) for
            # Dataset.range(lo, hi) repeated `epochs` times.
            ds = dataset_ops.Dataset.range(lo, hi).repeat(epochs)
            it = ds.make_initializable_iterator()
            resource = it._iterator_resource
            return (
                it.initializer,
                it.get_next(),
                gen_dataset_ops.save_iterator(resource, ckpt),
                gen_dataset_ops.restore_iterator(resource, ckpt),
            )

        def expect_values(sess, tensor, lo, hi):
            # Each run of `tensor` must yield the next integer in [lo, hi).
            for want in range(lo, hi):
                self.assertEqual(want, sess.run(tensor))

        # Saving and restoring in different sessions.
        start, stop = 2, 10
        num_epochs = 5
        break_point, break_epoch = 5, 3
        path = self._iterator_checkpoint_prefix()

        with ops.Graph().as_default() as g:
            init_op, get_next, save_op, _ = make_ops(
                start, stop, num_epochs, path)
            with self.test_session(graph=g) as sess:
                sess.run(variables.global_variables_initializer())
                sess.run(init_op)
                # Consume `break_epoch` complete epochs...
                for _ in range(break_epoch):
                    expect_values(sess, get_next, start, stop)
                # ...then stop part-way through the next one and save.
                expect_values(sess, get_next, start, break_point)
                sess.run(save_op)

        with ops.Graph().as_default() as g:
            # Create an empty IteratorResource and restore the Iterator into it.
            iterator = dataset_ops.Iterator.from_structure(
                dtypes.int64, tensor_shape.scalar())
            restore_op = gen_dataset_ops.restore_iterator(
                iterator._iterator_resource, path)
            get_next = iterator.get_next()
            with self.test_session(graph=g) as sess:
                sess.run(restore_op)
                # Finish the interrupted epoch, then the remaining full ones.
                expect_values(sess, get_next, break_point, stop)
                for _ in range(break_epoch + 1, num_epochs):
                    expect_values(sess, get_next, start, stop)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)
  def testRestoreWithoutBuildingDatasetGraph(self):
    """Restore a saved iterator into an iterator built from structure only.

    Graph one iterates a repeated ``Dataset.range`` for ``break_epoch``
    full epochs plus ``[start, break_point)`` of the next, then saves the
    iterator at ``path``. Graph two creates an int64-scalar iterator via
    ``Iterator.from_structure`` without any dataset, restores from ``path``,
    and expects the rest of the elements followed by ``OutOfRangeError``.
    """

    def build_ops(first, last, epochs, ckpt_path):
      # (initializer, next element, save op, restore op) for
      # Dataset.range(first, last) repeated `epochs` times.
      it = dataset_ops.Dataset.range(first, last).repeat(
          epochs).make_initializable_iterator()
      handle = it._iterator_resource
      init = it.initializer
      nxt = it.get_next()
      saver = gen_dataset_ops.save_iterator(handle, ckpt_path)
      restorer = gen_dataset_ops.restore_iterator(handle, ckpt_path)
      return init, nxt, saver, restorer

    def assert_sequence(sess, next_op, values):
      # Pull one element per expected value and compare in order.
      for value in values:
        self.assertEqual(value, sess.run(next_op))

    # Saving and restoring in different sessions.
    start = 2
    stop = 10
    num_epochs = 5
    break_point = 5
    break_epoch = 3
    path = self._iterator_checkpoint_prefix()

    with ops.Graph().as_default() as g:
      init_op, get_next, save_op, _ = build_ops(start, stop, num_epochs, path)
      with self.test_session(graph=g) as sess:
        sess.run(variables.global_variables_initializer())
        sess.run(init_op)
        for _ in range(break_epoch):
          assert_sequence(sess, get_next, range(start, stop))
        assert_sequence(sess, get_next, range(start, break_point))
        sess.run(save_op)

    with ops.Graph().as_default() as g:
      # Create an empty IteratorResource and restore the Iterator into it.
      iterator = iterator_ops.Iterator.from_structure(dtypes.int64,
                                                      tensor_shape.scalar())
      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
                                                    path)
      get_next = iterator.get_next()
      with self.test_session(graph=g) as sess:
        sess.run(restore_op)
        # The interrupted epoch first, then every remaining full epoch.
        assert_sequence(sess, get_next, range(break_point, stop))
        for _ in range(break_epoch + 1, num_epochs):
          assert_sequence(sess, get_next, range(start, stop))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)