def Reset(self, sess=None):
  """Restarts iteration from the beginning of every registered dataset.

  Args:
    sess: A session, required (and used) only in graph mode to run the
      iterator initializer ops; unused in eager mode.
  """
  if tf.executing_eagerly_outside_functions():
    # Eager mode: discard the old iterators and build a fresh one per dataset.
    fresh_iterators = {}
    for key, ds in self._dataset.items():
      fresh_iterators[key] = iter(ds)
    self._iterator = fresh_iterators
  else:
    # Graph mode: re-run the initializer op of each existing iterator.
    init_ops = [iterator.initializer for iterator in self._iterator.values()]
    sess.run(init_ops)
  super().Reset(sess)
def _InitIterator(self):
  """Lazily builds and caches the dataset and iterator for this host."""
  if self.host_id in self._dataset:
    # Already initialized for this host; nothing to do.
    return
  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by dataset function.
    dataset = self.GetDataset()
    opts = tf.data.Options()
    # Deterministic element order only when running under unit tests.
    opts.experimental_deterministic = bool(self.cluster.in_unit_test)
    dataset = dataset.with_options(opts)
    self._dataset[self.host_id] = dataset
    if tf.executing_eagerly_outside_functions():
      iterator = iter(dataset)
    else:
      iterator = tf.data.make_initializable_iterator(dataset)
    self._iterator[self.host_id] = iterator
def _InitIterator(self):
  """Override of the root's _InitIterator to support dataset repeat.

  Builds the tf.data.Dataset for this host and, depending on the input
  generator's params, makes it repeat forever: either by cycling a
  fixed-length prefix (`repeat_steps`) or by appending a sentinel batch
  before repeating (`repeat_with_sentinel`). The resulting dataset and
  iterator are cached in self._dataset / self._iterator keyed by host_id.
  """
  if self.host_id in self._dataset:
    # Already initialized for this host; nothing to do.
    return
  p = self.params
  # Repeat configuration is read from the wrapped input generator's params;
  # both attributes default to None when absent.
  self._repeat_steps = getattr(self._input_generator.params, 'repeat_steps',
                               None)
  self._repeat_with_sentinel = getattr(self._input_generator.params,
                                       'repeat_with_sentinel', None)
  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by dataset function.
    ds = self.GetDataset()
    if self._repeat_steps:
      # Cycle over the first `repeat_steps` batches forever.
      tf.logging.info('Repeating dataset every %d steps.', self._repeat_steps)
      ds = ds.take(self._repeat_steps).repeat()
    elif self._repeat_with_sentinel:
      tf.logging.info('Attaching sentinel to end of dataset and repeat.')
      # Dataset should contain batches of type NestedMap.
      # Build an all-zeros batch with the same structure and dtypes as a
      # real batch. NOTE(review): tf.zeros(x.shape, ...) assumes the
      # element_spec shapes are fully defined — confirm upstream batching.
      sentinel_batch = ds.element_spec.Transform(
          lambda x: tf.zeros(x.shape, dtype=x.dtype))
      # Fill the dummy sentinel batch's sentinel_key tensor with
      # sentinel_value so downstream consumers can detect the boundary.
      sentinel_batch[p.sentinel_key] = tf.fill(
          sentinel_batch[p.sentinel_key].shape, p.sentinel_value)
      tf.logging.info('attaching sentinel %r', sentinel_batch[p.sentinel_key])
      tf.logging.info('sentinel type %r',
                      sentinel_batch[p.sentinel_key].dtype)
      # Append the sentinel as a one-element dataset, then repeat forever.
      ds = ds.concatenate(
          tf.data.Dataset.from_tensors(sentinel_batch)).repeat()
    options = tf.data.Options()
    # Deterministic element order only when running under unit tests.
    options.experimental_deterministic = bool(self.cluster.in_unit_test)
    ds = ds.with_options(options)
    self._dataset[self.host_id] = ds
    if tf.executing_eagerly_outside_functions():
      # Eager iterators are usable immediately.
      it = iter(ds)
    else:
      # Graph-mode iterators must later be initialized via their
      # `initializer` op (see Initialize/Reset).
      it = tf.data.make_initializable_iterator(ds)
    self._iterator[self.host_id] = it
def Initialize(self, sess=None):
  """Initializes dataset iterators.

  Args:
    sess: A session, required (and used) only in graph mode to run the
      iterator initializer ops; unused in eager mode.
  """
  if not tf.executing_eagerly_outside_functions():
    # Graph-mode initializable iterators need an explicit session run;
    # eager iterators are ready as soon as they are constructed.
    init_ops = [iterator.initializer for iterator in self._iterator.values()]
    sess.run(init_ops)
  super().Initialize(sess)