def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Normalize x along dim, then apply a learned scale and bias.

  The mean and variance are computed by reducing x to the shape
  `x.shape - dim`; both learned variables have shape [dim].

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension, the dimension to normalize over.
    epsilon: a floating point number, added to the variance before taking
      the reciprocal square root.
    name: a string. The variables are created in the variable scope
      name + "/layer_norm".

  Returns:
    a mtf.Tensor with same shape as x.
  """
  param_shape = mtf.Shape([dim])
  with tf.variable_scope(name + "/layer_norm"):
    gain = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        param_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    offset = mtf.get_variable(
        x.mesh,
        "layer_norm_bias",
        param_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    stats_shape = x.shape - dim
    mu = mtf.reduce_mean(x, output_shape=stats_shape)
    squared_deviation = mtf.square(x - mu)
    sigma_squared = mtf.reduce_mean(
        squared_deviation, output_shape=stats_shape)

    centered = x - mu
    normalized = centered * mtf.rsqrt(sigma_squared + epsilon)
    return normalized * gain + offset
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
  """Batch normalization.

  Statistics are reduced over the first dimension of x (treated as the batch
  dimension). The learned scale and bias, as well as the moving mean and
  moving variance, have the shape of x with that dimension removed.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training. It is tested with a
      plain Python `if`, so it selects the branch when the graph is built.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number, added to the variance before taking the
      reciprocal square root.
    name: a string. variable scope (defaults to "batch_norm").

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    batch_dim = x.shape.dims[0]
    reduced_shape = x.shape - batch_dim

    # Scale and bias are per-feature parameters, so they take the shape of the
    # batch statistics. Shaping them by batch_dim would give each example
    # index its own parameters and tie the variables to the batch size.
    scale = mtf.get_variable(
        x.mesh,
        "batch_norm_scale",
        reduced_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "batch_norm_bias",
        reduced_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    moving_mean = mtf.get_variable(
        x.mesh,
        "moving_mean",
        reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh,
        "moving_variance",
        reduced_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    if is_training:
      # At training time, calculate mean and variance and normalize across
      # the batch dim.
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)
      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance.
      # NOTE(review): these assign ops are created but are not part of the
      # return value, so callers cannot fetch them from here -- confirm that
      # they are actually executed during training.
      mtf.assign(
          moving_mean, momentum * moving_mean + (1 - momentum) * mean)
      mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance)
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)

    return norm_x * scale + bias
def compress_mean(x, dim, compression_factor):
  """Shrink dim by a factor of compression_factor by averaging groups.

  dim is reshaped into two dimensions -- one keeping dim's name with size
  `dim.size // compression_factor`, and one named "compression_factor" with
  size compression_factor -- and the mean is then taken over the latter.

  Args:
    x: a Tensor
    dim: a dimension in x.shape
    compression_factor: an integer

  Returns:
    a Tensor
  """
  original_dims = x.shape.dims
  index = original_dims.index(dim)
  leading = original_dims[:index]
  trailing = original_dims[index + 1:]

  shrunk_dim = mtf.Dimension(dim.name, dim.size // compression_factor)
  group_dim = mtf.Dimension("compression_factor", compression_factor)

  grouped = mtf.reshape(x, leading + [shrunk_dim, group_dim] + trailing)
  return mtf.reduce_mean(grouped, reduced_dim=group_dim)
def reduce_rms(x):
  """Return the root-mean-square of x.

  Computed as sqrt(reduce_mean(square(x))). reduce_mean is called without an
  output shape or reduced dimension.

  Args:
    x: a tensor accepted by mtf.square.

  Returns:
    the result of mtf.sqrt applied to the mean of the squared values.
  """
  mean_square = mtf.reduce_mean(mtf.square(x))
  return mtf.sqrt(mean_square)
def apply_grad(self, grad, var):
  """Create the slot variables and update operations for one variable.

  Args:
    grad: the gradient for var, or None.
    var: the variable to update; its mesh, name, shape and value are read.

  Returns:
    a list of update operations (slot assignments followed by the
    assign_sub on var); an empty list when grad is None.
  """
  if grad is None:
    tf.logging.warning("Gradient is None for variable %s" % var)
    return []

  def make_slot(suffix, shape):
    # Non-trainable, zero-initialized accumulator named after var.
    return mtf.get_variable(
        var.mesh,
        var.name + suffix,
        shape,
        initializer=tf.zeros_initializer(),
        trainable=False)

  # Create slots.
  grad = mtf.to_float(grad)
  factored_dims = self._factored_dims(var.shape)
  is_factored = bool(factored_dims)
  if is_factored:
    # Two smaller accumulators, each with one of the factored dims removed.
    d0, d1 = factored_dims
    row_shape = var.shape - d0
    col_shape = var.shape - d1
    row_acc = make_slot("_slot_vr", row_shape)
    col_acc = make_slot("_slot_vc", col_shape)
  else:
    # A single accumulator with the full shape of var.
    full_acc = make_slot("_slot_v", var.shape)
  if self._beta1:
    momentum_acc = make_slot("_slot_m", var.shape)

  with tf.variable_scope(var.name + "/adafactor"):
    grad_sq = mtf.square(grad) + self._epsilon1
    decay = self._decay_rate
    current_value = mtf.to_float(var.value)
    update_scale = (
        self._parameter_scale(current_value) * self._learning_rate
        if self._multiply_by_parameter_scale
        else self._learning_rate)
    mix = 1.0 - decay
    update_ops = []

    if is_factored:
      row_mean = mtf.reduce_mean(grad_sq, output_shape=row_shape)
      col_mean = mtf.reduce_mean(grad_sq, output_shape=col_shape)
      next_row = row_acc * decay + row_mean * mix
      next_col = col_acc * decay + col_mean * mix
      update_ops.append(mtf.assign(row_acc, next_row))
      update_ops.append(mtf.assign(col_acc, next_col))
      # The row accumulator is divided by its mean over d1 before the
      # reciprocal square root is taken.
      row_acc_mean = mtf.reduce_mean(next_row, reduced_dim=d1)
      row_factor = mtf.rsqrt(next_row / row_acc_mean)
      col_factor = mtf.rsqrt(next_col)
      update = grad * row_factor * col_factor
    else:
      next_full = full_acc * decay + grad_sq * mix
      update_ops.append(mtf.assign(full_acc, next_full))
      update = grad * mtf.rsqrt(next_full)

    if self._clipping_threshold is not None:
      # Scale the update down only when its RMS exceeds the threshold.
      clip_denom = mtf.maximum(
          1.0, reduce_rms(update) / self._clipping_threshold)
      update /= clip_denom

    step = update * update_scale
    if self._beta1:
      # Blend the new step into the momentum slot and apply the blend.
      blended = (momentum_acc * tf.constant(self._beta1)
                 + step * tf.constant(1.0 - self._beta1))
      step = blended
      update_ops.append(mtf.assign(momentum_acc, blended))

    update_ops.append(mtf.assign_sub(var, step))
    return update_ops