Example #1
0
 def variables_for_ema(self):
   """Collects this task's variables that should get EMA shadow copies.

   The candidate pool is every trainable variable together with every
   variable in the moving-average collection. When
   `params.train.ema_decay_moving_vars` is set, the 'moving_vars' collection
   is added to the pool as well. The pool is then restricted to variables
   that appear in `self.vars`.

   Returns:
     A set containing the selected variables.
   """
   train_cfg = self.params.train
   trainable = set(tf.trainable_variables())
   averaged = set(tf.moving_average_variables())
   ema_vars = trainable | averaged
   if train_cfg.ema_decay_moving_vars:
     ema_vars |= set(tf.get_collection('moving_vars'))
   # Only keep variables owned by this task.
   ema_vars &= set(self.vars.Flatten())
   for variable in ema_vars:
     tf.logging.debug('variables_for_ema: %s', variable.name)
   return ema_vars
Example #2
0
 def ApplyExponentialMovingAverage(self, ema):
   """Wraps `self.train_op` with an op updating exponential moving average.

   The moving average is applied to every trainable and moving-average
   variable owned by this task, not only to the bprop variables. This way a
   shadow '/ExponentialMovingAverage' variable is created for each of them.
   The op returned by `ema.apply` is built under a control dependency on the
   previous `self._train_op`, and it replaces `self._train_op`.

   Args:
     ema: object whose `apply(var_list)` builds the update op; presumably a
       `tf.train.ExponentialMovingAverage` (not verifiable from here).
   """
   candidates = set(tf.trainable_variables()) | set(
       tf.moving_average_variables())
   # Restrict to variables that belong to this task.
   candidates &= set(self.vars.Flatten())
   for v in candidates:
     tf.logging.debug('ApplyExponentialMovingAverage: %s', v.name)
   previous_ops = [self._train_op]
   with tf.control_dependencies(previous_ops):
     with tf.name_scope('moving_average'):
       self._train_op = ema.apply(candidates)
Example #3
0
 def ApplyExponentialMovingAverage(self, ema):
   """Wraps `self.train_op` with an op updating exponential moving average.

   The moving average covers every trainable and moving-average variable of
   this task, not only the bprop variables, so each gets a shadow
   '/ExponentialMovingAverage' variable. If
   `params.train.ema_decay_moving_vars` is set, the 'moving_vars' collection
   is included too. The op built by `ema.apply` is appended to
   `self._post_train_ops`.

   Args:
     ema: object whose `apply(var_list)` builds the update op; presumably a
       `tf.train.ExponentialMovingAverage` (not verifiable from here).

   Raises:
     ValueError: if layer variable creation has not reached the COMPLETED
       status, i.e. InstantiateVariables has not finished.
   """
   status = self._create_variables_status
   completed = base_layer._CreateLayerVariablesStatus.COMPLETED  # pylint: disable=protected-access
   if status != completed:
     raise ValueError(
         'ApplyExponentialMovingAverage called before InstantiateVariables!')
   # TODO(rpang): raise an exception if this is called in the eval mode.
   train_cfg = self.params.train
   tracked = set(tf.trainable_variables()) | set(
       tf.moving_average_variables())
   if train_cfg.ema_decay_moving_vars:
     tracked |= set(tf.get_collection('moving_vars'))
   # Restrict to variables that belong to this task.
   tracked &= set(self.vars.Flatten())
   for v in tracked:
     tf.logging.debug('ApplyExponentialMovingAverage: %s', v.name)
   with tf.name_scope('moving_average'):
     self._post_train_ops.append(ema.apply(tracked))