import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def primer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Primer normalization over dimension `dim`.

  Args:
    x: a mtf.Tensor whose shape contains `dim`.
    dim: a mtf.Dimension.
    epsilon: a floating point number.
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name + "/primer_norm"):
    scale = mtf.get_variable(
        x.mesh,
        "primer_norm_scale",
        mtf.Shape([dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "primer_norm_bias",
        mtf.Shape([dim]),
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)
    reduced_shape = x.shape - dim
    mean = mtf.reduce_mean(x, output_shape=reduced_shape)
    mean_centered_x = x - mean
    # E[x * (x - mean)] equals E[(x - mean)^2], so this is the variance
    # computed with one fewer elementwise subtraction.
    pseudo_variance = mtf.reduce_mean(
        x * mean_centered_x, output_shape=reduced_shape)
    norm_x = mean_centered_x * mtf.rsqrt(pseudo_variance + epsilon)
    return norm_x * scale + bias
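# Minimal NumPy restatement of primer_norm's math, as a sanity check that the
# pseudo-variance above matches an ordinary variance. This is an illustrative
# sketch only; `primer_norm_reference` is a hypothetical name, not library code.
import numpy as np

def primer_norm_reference(x, scale, bias, epsilon=1e-6):
  # Normalize over the last axis, mirroring primer_norm over `dim`.
  mean = x.mean(axis=-1, keepdims=True)
  centered = x - mean
  pseudo_variance = (x * centered).mean(axis=-1, keepdims=True)
  # E[x * (x - mean)] == E[(x - mean)^2] because E[mean * (x - mean)] == 0.
  assert np.allclose(pseudo_variance,
                     np.square(centered).mean(axis=-1, keepdims=True))
  return centered / np.sqrt(pseudo_variance + epsilon) * scale + bias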
def _layer_norm(self, context, x, name=None):
  """RMS-style layer norm: scales by root-mean-square, no mean subtraction."""
  with tf.variable_scope(name, default_name="layer_norm"):
    scale = mtf.get_variable(
        context.mesh,
        "scale",
        mtf.Shape([context.model_dim]),
        initializer=tf.ones_initializer(),
        dtype=context.variable_dtype)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=context.model_dim)
    return x * mtf.rsqrt(variance + self._norm_epsilon) * scale
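# Quick NumPy illustration (stand-alone, hypothetical values) of why skipping
# the mean subtraction matters: for inputs with nonzero mean, the RMS-style
# norm above and a full layer norm give different outputs.
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
rms = x / np.sqrt(np.square(x).mean() + 1e-6)               # as in _layer_norm
centered = x - x.mean()
ln = centered / np.sqrt(np.square(centered).mean() + 1e-6)  # classic layer norm
print(rms)  # output mean is nonzero
print(ln)   # output mean is ~0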
def norm(x, axis=None, epsilon=1e-5):
  """Layer norm over `axis` (defaults to the last dim), without scale/bias."""
  axis = default(axis, x.shape[-1])  # `default` is a helper from this codebase
  u = mtf.reduce_mean(x, reduced_dim=axis)
  s = mtf.reduce_mean(mtf.square(x - u), reduced_dim=axis)
  u = mtf.broadcast(u, x.shape)
  s = mtf.broadcast(s, x.shape)
  return (x - u) * mtf.rsqrt(s + epsilon)
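# The snippet above relies on a small `default` helper defined elsewhere in
# its codebase; a plausible minimal sketch (an assumption, not the original):
def default(value, fallback):
  # Use `value` if given, otherwise fall back.
  return value if value is not None else fallback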
def layer_norm(
    x,
    dim: mtf.Dimension,
    epsilon: float = 1e-6,
    subtract_mean=True,
    use_scale=True,
    use_bias=True,
    name=None,
):
  """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    subtract_mean: a boolean
    use_scale: a boolean
    use_bias: a boolean
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="layer_norm"):
    if subtract_mean:
      x -= mtf.reduce_mean(x, reduced_dim=dim)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=dim)
    x *= mtf.rsqrt(variance + epsilon)
    if use_scale:
      x *= mtf.get_variable(
          x.mesh,
          "scale",
          mtf.Shape([dim]),
          initializer=tf.ones_initializer(),
          activation_dtype=x.dtype,
      )
    if use_bias:
      x += mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape([dim]),
          initializer=tf.zeros_initializer(),
          activation_dtype=x.dtype,
      )
    return x
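# Hypothetical usage sketch: with the right flags, layer_norm reproduces the
# other variants in this file. Graph construction only (nothing is lowered to
# a device mesh or executed), and it assumes TF1-compat behavior; `mesh`,
# `batch`, and `d_model` are made-up names for illustration.
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
d_model = mtf.Dimension("d_model", 8)
x = mtf.zeros(mesh, mtf.Shape([batch, d_model]))

full_ln = layer_norm(x, d_model)                 # classic layer norm
rms_ln = layer_norm(x, d_model, subtract_mean=False,
                    use_bias=False)              # RMS-style, like _layer_norm
plain = layer_norm(x, d_model, use_scale=False,
                   use_bias=False)               # like norm() above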
def normalize(x):
  # Closure from the surrounding layer-builder: each call consumes the next
  # pre-created scale variable; `layer_norm_vars`, `self.model_dim`, and
  # `hparams` come from the enclosing scope.
  scale = layer_norm_vars.pop(0)
  variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
  return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale
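# Self-contained NumPy sketch of the pattern above (all names here are
# stand-ins, not the original model's): scale vectors are created up front,
# one per layer, and each normalize call consumes the next one in order.
import numpy as np

norm_epsilon = 1e-6
layer_norm_vars = [np.ones(8) for _ in range(2)]  # one scale per layer

def normalize_sketch(x):
  scale = layer_norm_vars.pop(0)  # next layer's scale
  variance = np.square(x).mean(axis=-1, keepdims=True)
  return x / np.sqrt(variance + norm_epsilon) * scale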
def norm(x, axis, epsilon=1e-8):
  """Mean-subtracting layer norm over `axis`, without learned scale or bias."""
  x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u")
  s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s")
  return x * mtf.rsqrt(s + epsilon)