def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
    """Saves two one-shot iterators, reads them, restores, and re-reads.

    Both iterators are created from the same shuffle dataset and registered
    in the SAVEABLE_OBJECTS collection. The state is saved before any element
    is read; the outputs read after the save must match the outputs read
    after the restore (compared with `self.match`).

    Args:
      reshuffle_each_iteration: Forwarded to `self._build_shuffle_dataset`.
      buffer_size: Forwarded to `self._build_shuffle_dataset`.
    """
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats

    with ops.Graph().as_default() as graph:
        dataset = self._build_shuffle_dataset(
            range_limit=range_limit,
            num_repeats=num_repeats,
            buffer_size=buffer_size,
            # No explicit seed: iterator seeds are generated
            # non-deterministically.
            seed=None,
            reshuffle_each_iteration=reshuffle_each_iteration)
        iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
        get_next_ops = [iterator.get_next() for iterator in iterators]

        # Build every saveable first, then register them, so that the
        # saver created below picks both iterators up.
        saveables = list(
            map(contrib_iterator_ops.make_saveable_from_iterator, iterators))
        for saveable_obj in saveables:
            ops.add_to_collection(
                ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
        saver = saver_lib.Saver(allow_empty=True)

        with self.session(graph=graph) as sess:

            def read_all():
                # One evaluation per expected output, both iterators at once.
                return [
                    self.evaluate(get_next_ops) for _ in range(num_outputs)
                ]

            self._save(sess, saver)
            expected = read_all()
            self._restore(saver, sess)
            actual = read_all()
            self.match(expected, actual)
def _build_input_pipeline(self, name, num_outputs):
    """Builds a checkpointable range -> shuffle -> prefetch pipeline.

    All ops are created under the name scope `name`, and the iterator's
    saveable is added to the SAVEABLE_OBJECTS collection.

    Args:
      name: Name scope under which the pipeline is built.
      num_outputs: Upper bound passed to `Dataset.range`.

    Returns:
      A tuple `(initializer, get_next)` taken from the pipeline's
      initializable iterator.
    """
    with ops.name_scope(name):
        dataset = dataset_ops.Dataset.range(num_outputs)
        dataset = dataset.shuffle(10, reshuffle_each_iteration=False)
        dataset = dataset.prefetch(10)
        iterator = dataset.make_initializable_iterator()

        # Register the iterator so a Saver built later includes its state.
        ops.add_to_collection(
            ops.GraphKeys.SAVEABLE_OBJECTS,
            contrib_iterator_ops.make_saveable_from_iterator(iterator))

        init_op = iterator.initializer
        next_element = iterator.get_next()
        return init_op, next_element
def _build_input_pipeline(self, name, num_outputs):
    """Creates a shuffled, prefetched range pipeline whose iterator is saveable.

    The whole pipeline lives under the name scope `name`. The saveable built
    from the iterator is appended to the SAVEABLE_OBJECTS collection.

    Args:
      name: Name scope to build the ops in.
      num_outputs: Argument given to `Dataset.range`.

    Returns:
      The pair `(iterator.initializer, iterator.get_next())`.
    """
    with ops.name_scope(name):
        ranged = dataset_ops.Dataset.range(num_outputs)
        shuffled = ranged.shuffle(10, reshuffle_each_iteration=False)
        prefetched = shuffled.prefetch(10)
        iterator = prefetched.make_initializable_iterator()

        saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(
            iterator)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

        initializer = iterator.initializer
        get_next = iterator.get_next()
        return initializer, get_next
def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    """Builds an iterator from `ds_fn`'s structure only, plus a Saver.

    The iterator is created with `Iterator.from_structure` using the output
    types, shapes and classes reported by the `self._get_output_*` helpers.
    No initializer is created here. The iterator's saveable is registered in
    the SAVEABLE_OBJECTS collection before the Saver is constructed.

    Args:
      ds_fn: Dataset factory passed to the `self._get_output_*` helpers.
      sparse_tensors: If True, the components returned by `get_next()` are
        unpacked into a `SparseTensor`.

    Returns:
      A tuple `(get_next, saver)`.
    """
    output_types = self._get_output_types(ds_fn)
    output_shapes = self._get_output_shapes(ds_fn)
    output_classes = self._get_output_classes(ds_fn)
    iterator = iterator_ops.Iterator.from_structure(
        output_types,
        output_shapes=output_shapes,
        output_classes=output_classes)

    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))

    # get_next() is called exactly once; only its wrapping differs.
    next_element = iterator.get_next()
    if sparse_tensors:
        next_element = sparse_tensor.SparseTensor(*next_element)

    return next_element, saver_lib.Saver(allow_empty=True)
def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    """Creates a structure-only iterator for `ds_fn` and a Saver for it.

    `Iterator.from_structure` is fed the types, shapes and classes obtained
    from `self._get_output_types/_shapes/_classes(ds_fn)`; this method does
    not create an initializer. The iterator's saveable is added to the
    SAVEABLE_OBJECTS collection so the Saver created afterwards sees it.

    Args:
      ds_fn: Dataset factory handed to the `self._get_output_*` helpers.
      sparse_tensors: When True, `get_next()`'s components are splatted into
        `sparse_tensor.SparseTensor`.

    Returns:
      `(get_next, saver)`.
    """
    structure_types = self._get_output_types(ds_fn)
    structure_kwargs = {
        "output_shapes": self._get_output_shapes(ds_fn),
        "output_classes": self._get_output_classes(ds_fn),
    }
    iterator = iterator_ops.Iterator.from_structure(
        structure_types, **structure_kwargs)

    saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

    components = iterator.get_next()
    get_next = (
        sparse_tensor.SparseTensor(*components)
        if sparse_tensors else components)

    saver = saver_lib.Saver(allow_empty=True)
    return get_next, saver
def _build_graph(self, ds_fn, sparse_tensors=False):
    """Builds an initializable, checkpointable iterator over `ds_fn()`.

    The iterator's saveable is added to the SAVEABLE_OBJECTS collection, the
    init/get_next ops are handed to `self._add_iterator_ops_to_collection`,
    and a Saver is created last.

    Args:
      ds_fn: Zero-argument callable returning the dataset to iterate.
      sparse_tensors: If True, the components returned by `get_next()` are
        unpacked into a `SparseTensor`.

    Returns:
      A tuple `(init_op, get_next, saver)`.
    """
    dataset = ds_fn()
    iterator = dataset_ops.make_initializable_iterator(dataset)

    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(iterator))

    init_op = iterator.initializer
    # get_next() is called exactly once; only its wrapping differs.
    next_element = iterator.get_next()
    if sparse_tensors:
        next_element = sparse_tensor.SparseTensor(*next_element)

    self._add_iterator_ops_to_collection(
        init_op, next_element, ds_fn, sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, next_element, saver
def _build_graph(self, ds_fn, sparse_tensors=False):
    """Creates the iterator ops and Saver for the dataset returned by `ds_fn`.

    Steps, in order: make an initializable iterator, register its saveable
    in SAVEABLE_OBJECTS, fetch the initializer and next-element ops, pass
    them to `self._add_iterator_ops_to_collection`, then build the Saver.

    Args:
      ds_fn: Zero-argument callable that returns a dataset.
      sparse_tensors: When True, `get_next()`'s components are splatted into
        `sparse_tensor.SparseTensor`.

    Returns:
      `(init_op, get_next, saver)`.
    """
    iterator = dataset_ops.make_initializable_iterator(ds_fn())

    saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

    initializer = iterator.initializer
    components = iterator.get_next()
    get_next = (
        sparse_tensor.SparseTensor(*components)
        if sparse_tensors else components)

    self._add_iterator_ops_to_collection(
        initializer, get_next, ds_fn, sparse_tensors)
    return initializer, get_next, saver_lib.Saver(allow_empty=True)
def make_saveable_from_iterator(iterator):
    """Returns a SaveableObject for saving/restoring iterator state using Saver.

    This is a thin wrapper that delegates to
    `iterator_ops.make_saveable_from_iterator`.

    For example:

    ```python
    with tf.Graph().as_default():
      ds = tf.data.Dataset.range(10)
      iterator = ds.make_initializable_iterator()
      # Build the iterator SaveableObject.
      saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
      # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
      # it can be automatically saved using Saver.
      tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
      saver = tf.train.Saver()

      while continue_training:
        ... Perform training ...
        if should_save_checkpoint:
          saver.save()
    ```

    Note: When restoring the iterator, the existing iterator state is
    completely discarded. This means that any changes you may have made to
    the Dataset graph will be discarded as well! This includes the new
    Dataset graph that you may have built during validation. So, while
    running validation, make sure to run the initializer for the validation
    input pipeline after restoring the checkpoint.

    Note: Not all iterators support checkpointing yet. Attempting to save the
    state of an unsupported iterator will throw an error.

    Args:
      iterator: Iterator.

    Returns:
      A SaveableObject for the state of `iterator`.
    """
    saveable_obj = iterator_ops.make_saveable_from_iterator(iterator)
    return saveable_obj
def testMultipleIterators(self):
    """Checks save/restore of two one-shot iterators over a shuffle dataset.

    For every combination of `reshuffle_each_iteration` and buffer size, a
    fresh graph is built with two one-shot iterators over the same dataset.
    Both are registered in SAVEABLE_OBJECTS. The state is saved before any
    element is read; the outputs read after the save must match the outputs
    read after the restore (compared with `self.match`).
    """
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats

    for reshuffle_each_iteration in (True, False):
        for buffer_size in (1, 3, 5, 8, 10):
            with ops.Graph().as_default() as graph:
                dataset = self._build_shuffle_dataset(
                    range_limit=range_limit,
                    num_repeats=num_repeats,
                    buffer_size=buffer_size,
                    # No explicit seed: iterator seeds are generated
                    # non-deterministically.
                    seed=None,
                    reshuffle_each_iteration=reshuffle_each_iteration)
                iterators = [
                    dataset.make_one_shot_iterator() for _ in range(2)
                ]
                get_next_ops = [
                    iterator.get_next() for iterator in iterators
                ]

                # Build every saveable first, then register them, so that
                # the saver created below picks both iterators up.
                saveables = list(
                    map(contrib_iterator_ops.make_saveable_from_iterator,
                        iterators))
                for saveable_obj in saveables:
                    ops.add_to_collection(
                        ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
                saver = saver_lib.Saver(allow_empty=True)

                with self.session(graph=graph) as sess:
                    self._save(sess, saver)
                    expected = [
                        sess.run(get_next_ops) for _ in range(num_outputs)
                    ]
                    self._restore(saver, sess)
                    actual = [
                        sess.run(get_next_ops) for _ in range(num_outputs)
                    ]
                    self.match(expected, actual)
def make_saveable_from_iterator(iterator):
    """Returns a SaveableObject for saving/restoring iterator state using Saver.

    This is a thin wrapper that delegates to
    `iterator_ops.make_saveable_from_iterator`.

    For example:

    ```python
    with tf.Graph().as_default():
      ds = tf.data.Dataset.range(10)
      iterator = ds.make_initializable_iterator()
      # Build the iterator SaveableObject.
      saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
      # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
      # it can be automatically saved using Saver.
      tf.compat.v1.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
      saver = tf.compat.v1.train.Saver()

      while continue_training:
        ... Perform training ...
        if should_save_checkpoint:
          saver.save()
    ```

    Note: When restoring the iterator, the existing iterator state is
    completely discarded. This means that any changes you may have made to
    the Dataset graph will be discarded as well! This includes the new
    Dataset graph that you may have built during validation. So, while
    running validation, make sure to run the initializer for the validation
    input pipeline after restoring the checkpoint.

    Note: Not all iterators support checkpointing yet. Attempting to save the
    state of an unsupported iterator will throw an error.

    Args:
      iterator: Iterator.

    Returns:
      A SaveableObject for the state of `iterator`.
    """
    saveable_obj = iterator_ops.make_saveable_from_iterator(iterator)
    return saveable_obj