Example #1
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
    """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name + "/layer_norm"):
        scale = mtf.get_variable(x.mesh,
                                 "layer_norm_scale",
                                 mtf.TensorShape([dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "layer_norm_bias",
                                mtf.TensorShape([dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)
        reduced_shape = x.shape - dim
        mean = mtf.reduce_mean(x, output_shape=reduced_shape)
        variance = mtf.reduce_mean(mtf.square(x - mean),
                                   output_shape=reduced_shape)
        norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
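
A minimal usage sketch for layer_norm, assuming a single-process mesh built with the standard Mesh TensorFlow imports; the tensor names and sizes below are illustrative, not part of the original source.

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 32)
model_dim = mtf.Dimension("d_model", 512)
tf_x = tf.random_normal([32, 512])
x = mtf.import_tf_tensor(mesh, tf_x, mtf.Shape([batch_dim, model_dim]))
y = layer_norm(x, model_dim)  # normalized over d_model; same shape as x
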
Example #2
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
    """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
    with tf.variable_scope(name, default_name="batch_norm", values=[x]):
        batch_dim = x.shape.dims[0]
        reduced_shape = x.shape - batch_dim
        scale = mtf.get_variable(x.mesh,
                                 "batch_norm_scale",
                                 mtf.Shape([batch_dim]),
                                 initializer=tf.ones_initializer(),
                                 activation_dtype=x.dtype)
        bias = mtf.get_variable(x.mesh,
                                "batch_norm_bias",
                                mtf.Shape([batch_dim]),
                                initializer=tf.zeros_initializer(),
                                activation_dtype=x.dtype)

        moving_mean = mtf.get_variable(
            x.mesh,
            "moving_mean",
            reduced_shape,
            initializer=tf.random_normal_initializer(stddev=1.0),
            activation_dtype=x.dtype,
            trainable=False)
        moving_variance = mtf.get_variable(x.mesh,
                                           "moving_variance",
                                           reduced_shape,
                                           initializer=tf.ones_initializer(),
                                           activation_dtype=x.dtype,
                                           trainable=False)

        # At training time, calculate mean and variance and normalize across batch
        # dim.
        if is_training:
            mean = mtf.reduce_mean(x, output_shape=reduced_shape)
            variance = mtf.reduce_mean(mtf.square(x - mean),
                                       output_shape=reduced_shape)
            norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

            # Update running mean and running variance.
            moving_mean = mtf.assign(
                moving_mean, momentum * moving_mean + (1 - momentum) * mean)
            moving_variance = mtf.assign(
                moving_variance,
                momentum * moving_variance + (1 - momentum) * variance)
        else:
            # At eval and test time, use the running mean and variance.
            norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
        return norm_x * scale + bias
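
A minimal usage sketch, assuming x is the mtf.Tensor from the layer_norm sketch above; the momentum value is illustrative.

y = batch_norm(x, is_training=True, momentum=0.999)  # batch stats + moving-average updates
# With is_training=False the same function instead normalizes using the stored
# moving_mean / moving_variance and performs no assignments.
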
Example #3
def cv_squared(x):
    """The squared coefficient of variation of a sample.

  Useful as a loss to encourage a positive distribution to be more uniform.
  Epsilons added for numerical stability.
  Returns 0 for an empty Tensor.

  Args:
    x: a mtf.Tensor

  Returns:
    a mtf Scalar
  """
    epsilon = 1e-10
    mean = mtf.reduce_mean(x)
    variance = mtf.reduce_mean(mtf.square(x - mean))
    return variance / (mtf.square(mean) + epsilon)
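
A small illustrative sketch of cv_squared as a load-balancing-style penalty, assuming the mesh from the layer_norm sketch above; the dimension name and values are made up.

experts_dim = mtf.Dimension("experts", 4)
importance = mtf.import_tf_tensor(
    mesh, tf.constant([4.0, 1.0, 1.0, 2.0]), mtf.Shape([experts_dim]))
load_balancing_loss = cv_squared(importance)  # 0 only when all entries are equal
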
Example #4
def toy_model(features, mesh):
  """A toy model implemented with Mesh TensorFlow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  h = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
  y = mtf_layers.dense(h, io_dim, name='layer2', use_bias=False)

  loss = mtf.reduce_sum(mtf.square(y - x))
  return y, loss
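
A sketch of how a model function like this is typically lowered and executed, assuming a single-device placement mesh; the mesh shape, layout rules, and device string below are assumptions rather than part of the original example.

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
features = tf.random_uniform([FLAGS.batch_size, FLAGS.io_size])
y, loss = toy_model(features, mesh)

# Lower the mtf graph onto a concrete device and export the loss as a tf.Tensor.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"),
    mtf.convert_to_layout_rules("batch:all"),
    ["gpu:0"])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)
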
Example #5
    def mtf_model_fn(self, features, mesh):
        """A toy two-layer autoencoder built on mesh; returns (None, loss)."""
        hparams = self._hparams
        # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])
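        # Deterministic input: the outer product of two linspace vectors, i.e. a
        # rank-1 matrix of shape [batch_size, io_size] for the model to reconstruct.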
        tf_x = tf.matmul(
            tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                       [hparams.batch_size, 1]),
            tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                       [1, hparams.io_size]))
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        io_dim = mtf.Dimension("io", hparams.io_size)

        x = mtf.infeed_fully_replicated(mesh, tf_x,
                                        mtf.TensorShape([batch_dim, io_dim]))
        h = mtf_layers.dense(x, hidden_dim, name="layer1", use_bias=False)
        y = mtf_layers.dense(h, io_dim, name="layer2", use_bias=False)
        loss = mtf.reduce_sum(mtf.square(y - x))
        return None, loss
Example #6
def reduce_rms(x):
    """Root-mean-square of all entries of x, as an mtf scalar."""
    return mtf.sqrt(mtf.reduce_mean(mtf.square(x)))
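
reduce_rms is the helper used for update clipping in the Adafactor example below (Example #7). A tiny illustrative sketch, assuming the mesh from the layer_norm sketch above:

d = mtf.Dimension("d", 3)
t = mtf.import_tf_tensor(mesh, tf.constant([3.0, 4.0, 0.0]), mtf.Shape([d]))
rms = reduce_rms(t)  # sqrt((9 + 16 + 0) / 3) ≈ 2.89
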
Example #7
    def apply_grad(self, grad, var):
        # create slots
        factored_dims = self._factored_dims(var.shape)
        if factored_dims:
            d0, d1 = factored_dims
            vr_shape = var.shape - d0
            vc_shape = var.shape - d1
            vr = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vr",
                                  vr_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
            vc = mtf.get_variable(var.mesh,
                                  var.name + "_slot_vc",
                                  vc_shape,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
        else:
            v = mtf.get_variable(var.mesh,
                                 var.name + "_slot_v",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        if self._beta1:
            m = mtf.get_variable(var.mesh,
                                 var.name + "_slot_m",
                                 var.shape,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

        with tf.variable_scope(var.name + "/adafactor"):
            grad_squared = mtf.square(grad) + self._epsilon1
            decay_rate = self._decay_rate
            old_val = var.value
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(
                    old_val) * self._learning_rate
            else:
                update_scale = self._learning_rate
            mixing_rate = 1.0 - decay_rate
            updates = []
            if factored_dims:
                grad_squared_row_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vr_shape)
                grad_squared_col_mean = mtf.reduce_mean(grad_squared,
                                                        output_shape=vc_shape)
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                vr_update = mtf.assign(vr, new_vr)
                vc_update = mtf.assign(vc, new_vc)
                updates.extend([vr_update, vc_update])
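                # Factored second-moment estimate (Adafactor): the implied
                # v[i, j] is new_vc[i] * new_vr[j] / mean(new_vr), and the
                # gradient below is rescaled by rsqrt of that estimate.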
                long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
                r_factor = mtf.rsqrt(new_vr / long_term_mean)
                c_factor = mtf.rsqrt(new_vc)
                x = grad * r_factor * c_factor
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                v_update = mtf.assign(v, new_v)
                updates.append(v_update)
                x = grad * mtf.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = mtf.maximum(
                    1.0,
                    reduce_rms(x) / self._clipping_threshold)
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = self._beta1 * m.value + (1.0 - self._beta1) * subtrahend
                subtrahend = new_m
                updates.append(mtf.assign(m, new_m))
            new_val = old_val - subtrahend
            var_update = mtf.assign(var, new_val)
            updates.append(var_update)
            return updates
Example #8
    def normalize(x):
        """Scale-only (RMS-style) normalization over the model dimension."""
        scale = layer_norm_vars.pop(0)
        variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
        return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale
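
Unlike Example #1, this variant subtracts no mean and adds no bias: it rescales by the root-mean-square over the model dimension only, with per-layer scale variables popped from an external layer_norm_vars list. A standalone sketch of the same idea, with assumed names (rms_norm, dim, scale, epsilon):

def rms_norm(x, dim, scale, epsilon=1e-6):
    """Scale-only normalization: divide x by its RMS over dim."""
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=dim)
    return x * mtf.rsqrt(variance + epsilon) * scale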