Example #1
    def apply_grad(self, grad, var):
        """See base class."""
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []
        grad = mtf.to_float(grad)

        assignments = []

        m = mtf.get_variable(var.mesh,
                             var.name + "/adam_m",
                             var.shape,
                             initializer=tf.zeros_initializer(),
                             trainable=False)

        v = mtf.get_variable(var.mesh,
                             var.name + "/adam_v",
                             var.shape,
                             initializer=tf.zeros_initializer(),
                             trainable=False)

        # Standard Adam moment updates (note: the bias-correction terms of
        # vanilla Adam are omitted in this variant).
        next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
        next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

        update = next_m / (mtf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += self.weight_decay_rate * var.value

        update_with_lr = self.learning_rate * update

        var_update = mtf.assign_sub(var, update_with_lr)

        assignments.extend(
            [var_update,
             mtf.assign(m, next_m),
             mtf.assign(v, next_v)])
        return assignments
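This apply_grad is usually not called directly; a training step computes mtf gradients, hands each (grad, var) pair to the optimizer, and lowers the resulting mtf assignment ops to TensorFlow ops. Below is a minimal sketch of that wiring, assuming `graph`, `mesh`, `mesh_impl`, a scalar `loss`, and `optimizer` (an instance of the optimizer above) already exist; the import lines are likewise assumptions matching the mtf/tf names used in the example.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf  # assumed TF1-style import, matching tf.logging above

# Assumed to exist already: `graph` (mtf.Graph), `mesh` and `mesh_impl`,
# a scalar mtf `loss`, and `optimizer` (an instance of the optimizer above).
var_grads = mtf.gradients(
    [loss], [v.outputs[0] for v in graph.trainable_variables])

update_ops = []
for grad, var in zip(var_grads, graph.trainable_variables):
    # Each call returns a list of mtf assignment ops (the var, m and v updates).
    update_ops.extend(optimizer.apply_grad(grad, var))

# Lower the mtf assignment ops to plain TensorFlow ops and group them.
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
train_op = tf.group(tf_update_ops)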
Example #2
def reduce_rms(x):
    return mtf.sqrt(mtf.reduce_mean(mtf.square(x)))
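Example #3 below is the same helper, except that it forwards keyword arguments (such as reduced_dim) to mtf.reduce_mean, so the RMS can be taken over a single dimension rather than the whole tensor.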
Example #3
def reduce_rms(x, **kwargs):
  return mtf.sqrt(mtf.reduce_mean(mtf.square(x), **kwargs))
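Both versions compute the root-mean-square of a tensor. A typical use, sketched here under the assumption of an Adafactor-style clipping step (the names `update` and `clipping_threshold` are hypothetical), is to rescale a parameter update whenever its RMS exceeds a threshold:

# Hypothetical names: `update` is an mtf.Tensor of parameter updates and
# `clipping_threshold` is a float such as 1.0.
clipping_denom = mtf.maximum(1.0, reduce_rms(update) / clipping_threshold)
clipped_update = update / clipping_denom  # unchanged when RMS <= threshold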