Exemplo n.º 1
0
 def Reset(self, sess=None):
     """Rewinds the cached input iterators, then defers to the parent Reset.

     Eager mode: builds a fresh Python iterator for every dataset held in
     `self._dataset` and installs that new mapping as `self._iterator`.
     Graph mode: runs the `initializer` op of every iterator in
     `self._iterator` through `sess`.

     Args:
       sess: object whose `run` method executes the initializer ops in graph
         mode (must not be None there); forwarded to the parent's Reset in
         both modes.
     """
     if not tf.executing_eagerly_outside_functions():
         initializers = [
             iterator.initializer for iterator in self._iterator.values()
         ]
         sess.run(initializers)
     else:
         # Build the replacement mapping fully before swapping it in.
         fresh_iterators = {}
         for host_key, dataset in self._dataset.items():
             fresh_iterators[host_key] = iter(dataset)
         self._iterator = fresh_iterators
     super().Reset(sess)
Exemplo n.º 2
0
  def _InitIterator(self):
    """Creates and caches the dataset and its iterator for `self.host_id`.

    Does nothing when a dataset is already cached for the current host id.
    Otherwise stores the dataset in `self._dataset` and an iterator over it in
    `self._iterator`, both keyed by `self.host_id`.
    """
    if self.host_id not in self._dataset:
      with py_utils.GlobalStepContext(None):
        # Build the dataset with global_step hidden so the dataset function
        # does not capture that tensor.
        dataset = self.GetDataset()
      opts = tf.data.Options()
      # Request deterministic element order only when under a unit test.
      opts.experimental_deterministic = bool(self.cluster.in_unit_test)
      dataset = dataset.with_options(opts)
      self._dataset[self.host_id] = dataset
      if not tf.executing_eagerly_outside_functions():
        # Graph mode: the initializer must be run before use.
        iterator = tf.data.make_initializable_iterator(dataset)
      else:
        iterator = iter(dataset)
      self._iterator[self.host_id] = iterator
Exemplo n.º 3
0
    def _InitIterator(self):
        """Override of the root's _InitIterator to support dataset repeat.

        Builds the dataset for the current host once (calls for an
        already-cached `self.host_id` return immediately), optionally applies
        a repeat policy read from the input generator's params, and caches the
        dataset and an iterator over it in `self._dataset` / `self._iterator`
        keyed by `self.host_id`.

        Repeat policies (checked in this order; only the first truthy one is
        applied):
          * `repeat_steps`: keep the first `repeat_steps` elements and repeat
            them indefinitely.
          * `repeat_with_sentinel`: append one batch of zeros whose
            `p.sentinel_key` entry is filled with `p.sentinel_value`, then
            repeat the whole dataset indefinitely.

        Side effects: sets `self._repeat_steps` and
        `self._repeat_with_sentinel` (each None when the param is absent).
        """
        if self.host_id in self._dataset:
            return

        p = self.params
        self._repeat_steps = getattr(self._input_generator.params,
                                     'repeat_steps', None)
        self._repeat_with_sentinel = getattr(self._input_generator.params,
                                             'repeat_with_sentinel', None)

        with py_utils.GlobalStepContext(None):
            # Hide global_step tensor from being captured by dataset function.
            ds = self.GetDataset()
        if self._repeat_steps:
            tf.logging.info('Repeating dataset every %d steps.',
                            self._repeat_steps)
            ds = ds.take(self._repeat_steps).repeat()
        elif self._repeat_with_sentinel:
            tf.logging.info('Attaching sentinel to end of dataset and repeat.')
            # Dataset should contain batches of type NestedMap.
            # NOTE(review): tf.zeros needs fully-defined shapes, so this
            # presumably assumes static batch shapes -- confirm.
            sentinel_batch = ds.element_spec.Transform(
                lambda x: tf.zeros(x.shape, dtype=x.dtype))
            # Fill the dummy sentinel batch's sentinel_key tensor with
            # sentinel_value. tf.fill takes its output dtype from the fill
            # value, so convert the value to the key's own dtype first;
            # otherwise e.g. a Python int yields int32 for an int64 or float
            # key and concatenate() below rejects the mismatched element
            # types. convert_to_tensor still fails loudly on lossy
            # conversions (e.g. 1.5 -> int).
            sentinel_template = sentinel_batch[p.sentinel_key]
            sentinel_batch[p.sentinel_key] = tf.fill(
                sentinel_template.shape,
                tf.convert_to_tensor(p.sentinel_value,
                                     dtype=sentinel_template.dtype))
            tf.logging.info('attaching sentinel %r',
                            sentinel_batch[p.sentinel_key])
            tf.logging.info('sentinel type %r',
                            sentinel_batch[p.sentinel_key].dtype)
            ds = ds.concatenate(
                tf.data.Dataset.from_tensors(sentinel_batch)).repeat()
        options = tf.data.Options()
        # Deterministic element order only under unit tests.
        options.experimental_deterministic = bool(self.cluster.in_unit_test)
        ds = ds.with_options(options)
        self._dataset[self.host_id] = ds
        # Eager: a plain Python iterator. Graph: an initializable iterator
        # whose initializer has to be run before the first read.
        if tf.executing_eagerly_outside_functions():
            it = iter(ds)
        else:
            it = tf.data.make_initializable_iterator(ds)
        self._iterator[self.host_id] = it
Exemplo n.º 4
0
 def Initialize(self, sess=None):
   """Runs iterator initializers (graph mode only), then the parent Initialize.

   In eager mode no ops are run here. In graph mode the `initializer` op of
   every iterator in `self._iterator` is executed through `sess`.

   Args:
     sess: object whose `run` method executes the initializer ops in graph
       mode (must not be None there); forwarded to the parent's Initialize.
   """
   eager = tf.executing_eagerly_outside_functions()
   if not eager:
     init_ops = [iterator.initializer for iterator in self._iterator.values()]
     sess.run(init_ops)
   super().Initialize(sess)