コード例 #1
0
  def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
    """Checks that two one-shot iterators over one shuffled dataset yield
    the same sequences before and after a save/restore round trip."""
    range_limit = 5
    num_repeats = 2
    total_outputs = range_limit * num_repeats

    def make_dataset():
      # pylint: disable=cell-var-from-loop
      return self._build_shuffle_dataset(
          range_limit=range_limit,
          num_repeats=num_repeats,
          buffer_size=buffer_size,
          seed=None,  # Iterator seeds are generated non-deterministically.
          reshuffle_each_iteration=reshuffle_each_iteration)
      # pylint: enable=cell-var-from-loop

    with ops.Graph().as_default() as graph:
      dataset = make_dataset()
      first_it = dataset.make_one_shot_iterator()
      second_it = dataset.make_one_shot_iterator()
      get_next_ops = [first_it.get_next(), second_it.get_next()]
      # Register both iterators so their state rides along in the checkpoint.
      for it in (first_it, second_it):
        saveable = contrib_iterator_ops.make_saveable_from_iterator(it)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
      saver = saver_lib.Saver(allow_empty=True)
      with self.session(graph=graph) as sess:
        self._save(sess, saver)
        expected = [self.evaluate(get_next_ops) for _ in range(total_outputs)]
        self._restore(saver, sess)
        actual = [self.evaluate(get_next_ops) for _ in range(total_outputs)]
        self.match(expected, actual)
コード例 #2
0
 def _build_input_pipeline(self, name, num_outputs):
   """Builds a named shuffle/prefetch pipeline with a checkpointable iterator.

   Returns the iterator's initializer op and its get_next op.
   """
   with ops.name_scope(name):
     dataset = dataset_ops.Dataset.range(num_outputs)
     dataset = dataset.shuffle(10, reshuffle_each_iteration=False)
     dataset = dataset.prefetch(10)
     iterator = dataset.make_initializable_iterator()
     # Register the iterator so Saver checkpoints its state automatically.
     ops.add_to_collection(
         ops.GraphKeys.SAVEABLE_OBJECTS,
         contrib_iterator_ops.make_saveable_from_iterator(iterator))
     return iterator.initializer, iterator.get_next()
 def _build_input_pipeline(self, name, num_outputs):
   """Creates a shuffled range pipeline under `name`, registers its iterator
   for checkpointing, and returns (initializer, get_next)."""
   with ops.name_scope(name):
     it = (dataset_ops.Dataset.range(num_outputs)
           .shuffle(10, reshuffle_each_iteration=False)
           .prefetch(10)
           .make_initializable_iterator())
     saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(it)
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
     return it.initializer, it.get_next()
コード例 #4
0
 def _build_empty_graph(self, ds_fn, sparse_tensors=False):
   """Builds a graph containing only a structure-based iterator and a Saver.

   The iterator carries no dataset, so it can only produce elements after a
   checkpoint written by the full graph is restored into it.
   """
   iterator = iterator_ops.Iterator.from_structure(
       self._get_output_types(ds_fn),
       output_shapes=self._get_output_shapes(ds_fn),
       output_classes=self._get_output_classes(ds_fn))
   # Make the iterator's state restorable via the SAVEABLE_OBJECTS collection.
   ops.add_to_collection(
       ops.GraphKeys.SAVEABLE_OBJECTS,
       contrib_iterator_ops.make_saveable_from_iterator(iterator))
   if sparse_tensors:
     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
   else:
     get_next = iterator.get_next()
   return get_next, saver_lib.Saver(allow_empty=True)
コード例 #5
0
 def _build_empty_graph(self, ds_fn, sparse_tensors=False):
   """Creates a dataset-less iterator plus a Saver for restore-only tests."""
   empty_iterator = iterator_ops.Iterator.from_structure(
       self._get_output_types(ds_fn),
       output_shapes=self._get_output_shapes(ds_fn),
       output_classes=self._get_output_classes(ds_fn))
   saveable = contrib_iterator_ops.make_saveable_from_iterator(empty_iterator)
   ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
   next_elem = empty_iterator.get_next()
   # Re-wrap the component tensors when the element is a SparseTensor.
   get_next = (
       sparse_tensor.SparseTensor(*next_elem) if sparse_tensors else next_elem)
   saver = saver_lib.Saver(allow_empty=True)
   return get_next, saver
コード例 #6
0
  def _build_graph(self, ds_fn, sparse_tensors=False):
    """Builds the full dataset graph with a checkpointable iterator.

    Returns (init_op, get_next, saver).
    """
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    # Register the iterator so its state is saved/restored with the Saver.
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    return init_op, get_next, saver_lib.Saver(allow_empty=True)
コード例 #7
0
  def _build_graph(self, ds_fn, sparse_tensors=False):
    """Constructs the dataset, registers its iterator as saveable, and
    returns the ops needed to drive and checkpoint it."""
    it = dataset_ops.make_initializable_iterator(ds_fn())

    saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(it)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    init_op = it.initializer
    next_elem = it.get_next()
    # SparseTensor elements come back as component tensors; re-wrap them.
    get_next = (
        sparse_tensor.SparseTensor(*next_elem) if sparse_tensors else next_elem)
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver
コード例 #8
0
def make_saveable_from_iterator(iterator):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.

  Returns:
    A SaveableObject that saves/restores the state of `iterator` when added
    to the `SAVEABLE_OBJECTS` collection and checkpointed with a
    `tf.train.Saver`.

  For example:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = ds.make_initializable_iterator()
    # Build the iterator SaveableObject.
    saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save()
  ```

  Note: When restoring the iterator, the existing iterator state is completely
  discarded. This means that any changes you may have made to the Dataset
  graph will be discarded as well! This includes the new Dataset graph
  that you may have built during validation. So, while running validation,
  make sure to run the initializer for the validation input pipeline after
  restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will throw an error.
  """
  # Thin forwarding wrapper around the core implementation.
  return iterator_ops.make_saveable_from_iterator(iterator)
コード例 #9
0
  def testMultipleIterators(self):
    """Verifies that save/restore keeps two one-shot iterators' outputs
    identical for every (reshuffle_each_iteration, buffer_size) combo."""
    range_limit = 5
    num_repeats = 2
    total_outputs = range_limit * num_repeats

    for reshuffle_each_iteration in [True, False]:
      for buffer_size in [1, 3, 5, 8, 10]:

        def ds_fn():
          # pylint: disable=cell-var-from-loop
          return self._build_shuffle_dataset(
              range_limit=range_limit,
              num_repeats=num_repeats,
              buffer_size=buffer_size,
              seed=None,  # Iterator seeds are generated non-deterministically.
              reshuffle_each_iteration=reshuffle_each_iteration)
          # pylint: enable=cell-var-from-loop

        with ops.Graph().as_default() as graph:
          dataset = ds_fn()
          iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
          get_next_ops = [it.get_next() for it in iterators]
          # Register every iterator so Saver checkpoints their state.
          for it in iterators:
            ops.add_to_collection(
                ops.GraphKeys.SAVEABLE_OBJECTS,
                contrib_iterator_ops.make_saveable_from_iterator(it))
          saver = saver_lib.Saver(allow_empty=True)
          with self.session(graph=graph) as sess:
            self._save(sess, saver)
            expected = [sess.run(get_next_ops) for _ in range(total_outputs)]
            self._restore(saver, sess)
            actual = [sess.run(get_next_ops) for _ in range(total_outputs)]
            self.match(expected, actual)
コード例 #10
0
def make_saveable_from_iterator(iterator):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.

  Example usage:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = ds.make_initializable_iterator()
    # Build the iterator SaveableObject.
    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.compat.v1.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.compat.v1.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save()
  ```

  Note: Restoring the iterator discards its existing state completely, which
  also discards any changes made to the Dataset graph since the checkpoint
  was written — including a new Dataset graph built during validation. So,
  while running validation, make sure to run the initializer for the
  validation input pipeline after restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will throw an error.
  """
  # Delegate to the core iterator_ops implementation.
  saveable = iterator_ops.make_saveable_from_iterator(iterator)
  return saveable