Example #1
    def apply_grad(self, grad, var):
        """See base class."""
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []
        grad = mtf.to_float(grad)

        assignments = []

        m = mtf.get_variable(var.mesh,
                             var.name + "/adam_m",
                             var.shape,
                             initializer=tf.zeros_initializer(),
                             trainable=False)

        v = mtf.get_variable(var.mesh,
                             var.name + "/adam_v",
                             var.shape,
                             initializer=tf.zeros_initializer(),
                             trainable=False)

        # Standard Adam update.
        next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
        next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

        update = next_m / (mtf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += self.weight_decay_rate * var.value

        update_with_lr = self.learning_rate * update

        var_update = mtf.assign_sub(var, update_with_lr)

        assignments.extend(
            [var_update,
             mtf.assign(m, next_m),
             mtf.assign(v, next_v)])
        return assignments
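For reference, a sketch of the update rule this code applies. Like BERT-style AdamW implementations, it omits Adam's bias-correction terms; lambda below is weight_decay_rate, applied only to variables passing _do_use_weight_decay:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \theta_t = \theta_{t-1} - \alpha \bigl( m_t / (\sqrt{v_t} + \epsilon) + \lambda \theta_{t-1} \bigr)

Decaying the weights directly in the update, rather than adding an L2 term to the loss, keeps the decay out of the m and v accumulators, as the comment in the code explains.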
Example #2
  def apply_grad(self, grad, var):
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []

    updates = []
    # Momentum accumulator slot, stored alongside the variable.
    v = mtf.get_variable(
        var.mesh, var.name + "_momentum_v", var.shape,
        dtype=var.dtype, initializer=tf.zeros_initializer(), trainable=False)

    with tf.variable_scope(var.name + "/sgd_momentum"):
      # v <- momentum * v + lr * grad, then take the accumulated step.
      updates.append(mtf.assign(v, grad * self.lr + v * self.momentum))
      updates.append(mtf.assign_sub(var, v))

    return updates
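The two assignments implement momentum SGD with the learning rate folded into the velocity accumulator:

    v_t = \mu v_{t-1} + \alpha g_t
    \theta_t = \theta_{t-1} - v_t

One consequence of folding alpha into v_t: changing the learning rate mid-training also rescales the momentum already accumulated, unlike the alternative formulation v_t = \mu v_{t-1} + g_t with \theta_t = \theta_{t-1} - \alpha v_t.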
Example #3
    def apply_grad(self, grad, var):
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var)
            return []
        # create slots
        grad = mtf.to_float(grad)
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = mtf.to_float(var.value)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1) +
                         subtrahend * tf.constant(1.0 - self._beta1))
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            # It is critical to use assign_sub instead of mtf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            var_update = mtf.assign_sub(var, subtrahend)
            updates.append(var_update)
            return updates
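In the factored branch, the full second-moment accumulator is never materialized. As in the Adafactor paper (Shazeer & Stern, 2018), it is approximated from a row statistic R (shape var.shape - d0) and a column statistic C (shape var.shape - d1). Writing beta for decay_rate, the code above corresponds to:

    R \leftarrow \beta R + (1 - \beta)\,\operatorname{mean}_{d_0}(g^2 + \epsilon_1)
    C \leftarrow \beta C + (1 - \beta)\,\operatorname{mean}_{d_1}(g^2 + \epsilon_1)
    x = g \cdot \operatorname{rsqrt}(R / \operatorname{mean}_{d_1}(R)) \cdot \operatorname{rsqrt}(C)

which matches r_factor and c_factor in the code, followed by the RMS clipping step x \leftarrow x / \max(1, \operatorname{RMS}(x) / d) with d = _clipping_threshold.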
Example #4
File: layers.py Project: qixiuai/mesh
def batch_norm(x, is_training, momentum, epsilon=1e-9,
               init_zero=False, name=None):
  """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a small floating point number added to the variance to avoid
      division by zero.
    init_zero: a boolean, whether to initialize scale with 0's or 1's.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    if init_zero:
      gamma_initializer = tf.zeros_initializer()
    else:
      gamma_initializer = tf.ones_initializer()

    # Reduce over the leading three dims (e.g. batch, height, width); the
    # remaining dims hold the per-channel statistics.
    norm_dim = x.shape.dims[0:3]
    reduced_shape = x.shape - norm_dim

    scale = mtf.get_variable(
        x.mesh,
        "batch_norm_scale",
        reduced_shape,
        initializer=gamma_initializer,
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "batch_norm_bias",
        reduced_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    moving_mean = mtf.get_variable(
        x.mesh, "moving_mean", reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh, "moving_variance",
        reduced_shape, initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    # At training time, calculate mean and variance over the reduced dims and
    # normalize with them.
    if is_training:
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)

      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance.
      moving_mean = mtf.assign(
          moving_mean, momentum * moving_mean + (1-momentum) * mean)
      moving_variance = mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance)
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)

    return (norm_x * scale) + bias
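At training time this is the standard batch-norm transform, with the moving statistics updated by an exponential moving average (m is momentum; \mu_B and \sigma_B^2 are taken over the dims dropped by reduced_shape):

    \hat{x} = (x - \mu_B) / \sqrt{\sigma_B^2 + \epsilon}, \qquad y = \gamma \hat{x} + \beta
    \mu_{mov} \leftarrow m\,\mu_{mov} + (1 - m)\,\mu_B, \qquad \sigma^2_{mov} \leftarrow m\,\sigma^2_{mov} + (1 - m)\,\sigma_B^2

At eval time the moving statistics replace the batch statistics. A minimal call, assuming x is an mtf.Tensor and the chosen momentum is illustrative: y = batch_norm(x, is_training=True, momentum=0.99).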