def variables_for_ema(self):
  """Returns the set of this task's variables that EMA should track.

  Starts from all trainable and moving-average variables, optionally adds
  the 'moving_vars' collection (when `p.train.ema_decay_moving_vars` is
  set), and keeps only variables that belong to this task
  (`self.vars.Flatten()`).
  """
  p = self.params
  ema_vars = set(tf.trainable_variables())
  ema_vars.update(tf.moving_average_variables())
  if p.train.ema_decay_moving_vars:
    # Also shadow batch-norm style statistics kept in 'moving_vars'.
    ema_vars.update(tf.get_collection('moving_vars'))
  # Restrict to variables owned by this task.
  ema_vars.intersection_update(self.vars.Flatten())
  for v in ema_vars:
    tf.logging.debug('variables_for_ema: %s', v.name)
  return ema_vars
def ApplyExponentialMovingAverage(self, ema):
  """Wraps `self.train_op` with an op updating exponential moving average.

  Args:
    ema: a `tf.train.ExponentialMovingAverage` instance whose `apply()` op
      is chained after the current `self._train_op`.
  """
  p = self.params
  # We need to apply EMA to trainable and moving average variable of this
  # Task, not just bprop vars, so that we create a shadow
  # '/ExponentialMovingAverage' variable for every trainable and moving
  # average variable.
  all_vars = set(tf.trainable_variables()) | set(
      tf.moving_average_variables())
  # Fix: honor p.train.ema_decay_moving_vars, consistent with
  # variables_for_ema — otherwise variables registered in the
  # 'moving_vars' collection never receive EMA shadows even when the
  # config asks for them.
  if p.train.ema_decay_moving_vars:
    all_vars |= set(tf.get_collection('moving_vars'))
  # Restrict to variables owned by this task.
  all_vars &= set(self.vars.Flatten())
  for var in all_vars:
    tf.logging.debug('ApplyExponentialMovingAverage: %s', var.name)
  # Run the EMA update strictly after the training step it shadows.
  with tf.control_dependencies([self._train_op
                               ]), tf.name_scope('moving_average'):
    self._train_op = ema.apply(all_vars)
def ApplyExponentialMovingAverage(self, ema):
  """Wraps `self.train_op` with an op updating exponential moving average."""
  # EMA shadow variables can only be created once every variable exists.
  status = self._create_variables_status
  if status != base_layer._CreateLayerVariablesStatus.COMPLETED:  # pylint: disable=protected-access
    raise ValueError(
        'ApplyExponentialMovingAverage called before InstantiateVariables!')
  # TODO(rpang): raise an exception if this is called in the eval mode.
  # Shadow every trainable and moving-average variable of this Task, not
  # just bprop vars, so each one gets a '/ExponentialMovingAverage' copy.
  tracked = set(tf.trainable_variables())
  tracked.update(tf.moving_average_variables())
  if self.params.train.ema_decay_moving_vars:
    tracked.update(tf.get_collection('moving_vars'))
  # Keep only variables that belong to this task.
  tracked.intersection_update(self.vars.Flatten())
  for v in tracked:
    tf.logging.debug('ApplyExponentialMovingAverage: %s', v.name)
  with tf.name_scope('moving_average'):
    self._post_train_ops.append(ema.apply(tracked))