def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Normalizes x along dim, then applies a learned scale and bias.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension, the dimension that is reduced when computing the
      mean and variance.
    epsilon: a floating point number added to the variance before taking the
      reciprocal square root.
    name: a string. "/layer_norm" is appended to form the variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  # Both learned parameters are vectors along the normalized dimension.
  param_shape = mtf.Shape([dim])
  with tf.variable_scope(name + "/layer_norm"):
    scale = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        param_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "layer_norm_bias",
        param_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    # Statistics keep every dimension of x except dim.
    stats_shape = x.shape - dim
    mean = mtf.reduce_mean(x, output_shape=stats_shape)
    variance = mtf.reduce_mean(
        mtf.square(x - mean), output_shape=stats_shape)

    centered = x - mean
    inv_stddev = mtf.rsqrt(variance + epsilon)
    return centered * inv_stddev * scale + bias
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
  """Batch normalization over the first dimension of x.

  The mean and variance are reduced over x.shape.dims[0] only, so the
  statistics, the moving averages and the learned scale and bias all have the
  shape of x with that first dimension removed.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a Python boolean, whether mode is training. It selects the
      branch at graph-construction time.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number added to the variance before taking the
      reciprocal square root.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    batch_dim = x.shape.dims[0]
    reduced_shape = x.shape - batch_dim

    # The learned gain and offset must line up with the normalized statistics
    # (one value per non-batch position). Shaping them by batch_dim would give
    # every example in the batch its own parameters and tie the variable shape
    # to the batch size.
    scale = mtf.get_variable(
        x.mesh,
        "batch_norm_scale",
        reduced_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "batch_norm_bias",
        reduced_shape,
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    # NOTE(review): the moving mean starts from random normal values rather
    # than zeros -- confirm this is intended.
    moving_mean = mtf.get_variable(
        x.mesh,
        "moving_mean",
        reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh,
        "moving_variance",
        reduced_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    if is_training:
      # At training time, calculate mean and variance and normalize across
      # batch dim.
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)
      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance. The assign operations are
      # created here but not returned to the caller.
      # NOTE(review): verify that the surrounding training loop runs them.
      mtf.assign(
          moving_mean, momentum * moving_mean + (1 - momentum) * mean)
      mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance)
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)

    return norm_x * scale + bias
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          name=None):
  """Dense layer computing activation(einsum(x, kernel) + bias).

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension appended to the output shape.
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype, passed to the kernel variable.
    slice_dtype: a tf.dtype, passed to the kernel variable.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  expert_dims = [] if expert_dims is None else expert_dims
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]

  kernel_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  # The output keeps every dimension of x that is not reduced away.
  kept_dims = [d for d in x.shape.dims if d not in reduced_dims]
  result_shape = mtf.Shape(kept_dims + [output_dim])

  with tf.variable_scope(name, default_name="dense"):
    # Initializer stddev is (product of the reduced dimension sizes) ** -0.5.
    fan_in = mtf.list_product(d.size for d in reduced_dims)
    kernel = mtf.get_variable(
        x.mesh,
        "kernel",
        kernel_shape,
        initializer=tf.random_normal_initializer(stddev=fan_in ** -0.5),
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        activation_dtype=x.dtype)
    out = mtf.einsum([x, kernel], result_shape)

    if use_bias:
      bias_shape = mtf.Shape(expert_dims + [output_dim])
      out += mtf.get_variable(
          x.mesh,
          "bias",
          bias_shape,
          initializer=tf.zeros_initializer(),
          activation_dtype=x.dtype)

    return out if activation is None else activation(out)
def multihead_attention_vars(mesh, heads, io_channels, kv_channels,
                             master_dtype, slice_dtype, activation_dtype):
  """Creates the four multihead-attention parameters as one stacked variable.

  A single variable named "qkvo" of shape
  [qkvo=4, heads, io_channels, kv_channels] is created and unstacked along
  its leading dimension. Each of the four slices is drawn from a normal
  distribution with its own standard deviation.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)

  io_size = io_channels.size
  qk_stddev = (io_size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_size ** -0.5
  o_stddev = (io_size * heads.size) ** -0.5
  # One standard deviation per slice, in q, k, v, o order.
  slice_stddevs = [qk_stddev, qk_stddev, v_stddev, o_stddev]

  def qkvo_initializer(shape,
                       dtype=None,
                       partition_info=None,
                       verify_shape=None):
    """Unit normal noise scaled per slice along the leading axis."""
    del partition_info, verify_shape  # Accepted but unused.
    noise = tf.random_normal(shape, dtype=dtype)
    # Reshaped to [4, 1, 1, 1] so it broadcasts over the other three axes.
    per_slice_scale = tf.reshape(
        tf.cast(slice_stddevs, dtype or tf.float32), [4, 1, 1, 1])
    return noise * per_slice_scale

  stacked = mtf.get_variable(
      mesh,
      "qkvo",
      mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer,
      master_dtype=master_dtype,
      slice_dtype=slice_dtype,
      activation_dtype=activation_dtype)
  return tuple(mtf.unstack(stacked, qkvo))
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Two-layer feed-forward block: projection, ReLU, projection back.

  Both projection matrices live in one variable named "kernel" of shape
  [io=2, io_channels, hidden_channels], where io_channels is the last
  dimension of x. The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float. When it is not 0.0, mtf.dropout is applied to
      the hidden activations with 1.0 - dropout as its second argument.
    dropout_broadcast_dims: an optional list of mtf.Dimension, subtracted
      from the hidden shape to form the dropout noise shape.
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    init_stddev = (hidden_channels.size * io_channels.size) ** -0.25

    # Leading dimension of size 2 stacks the input and output projections.
    io = mtf.Dimension("io", 2)
    stacked_kernel = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=init_stddev),
        activation_dtype=x.dtype)
    w_in, w_out = mtf.unstack(stacked_kernel, io)

    hidden = mtf.relu(mtf.einsum([x, w_in]))
    if dropout != 0.0:
      keep_prob = 1.0 - dropout
      noise_shape = hidden.shape - dropout_broadcast_dims
      hidden = mtf.dropout(hidden, keep_prob, noise_shape=noise_shape)

    return mtf.einsum([hidden, w_out])
def apply_grad(self, grad, var):
  """Builds the Adafactor update operations for a single variable.

  When self._factored_dims(var.shape) yields a pair of dimensions, the
  second-moment estimate is kept as two reduced accumulators (vr, vc), each
  with one of those dimensions removed. Otherwise a single accumulator with
  the variable's full shape is used. A momentum slot is added when
  self._beta1 is truthy.

  Args:
    grad: the gradient for var, or None.
    var: the variable to update. Its mesh, name, shape and value are read.

  Returns:
    a list of the operations returned by mtf.assign / mtf.assign_sub for the
    slot variables and for var itself. Empty if grad is None.
  """
  if grad is None:
    tf.logging.warning("Gradient is None for variable %s" % var)
    return []

  def zero_slot(suffix, shape):
    # Non-trainable, zero-initialized slot named after the variable.
    return mtf.get_variable(
        var.mesh,
        var.name + suffix,
        shape,
        initializer=tf.zeros_initializer(),
        trainable=False)

  # Create slots.
  grad = mtf.to_float(grad)
  factored_dims = self._factored_dims(var.shape)
  if factored_dims:
    d0, d1 = factored_dims
    vr_shape = var.shape - d0
    vc_shape = var.shape - d1
    vr = zero_slot("_slot_vr", vr_shape)
    vc = zero_slot("_slot_vc", vc_shape)
  else:
    v = zero_slot("_slot_v", var.shape)
  if self._beta1:
    m = zero_slot("_slot_m", var.shape)

  with tf.variable_scope(var.name + "/adafactor"):
    grad_squared = mtf.square(grad) + self._epsilon1
    decay_rate = self._decay_rate
    old_val = mtf.to_float(var.value)
    update_scale = (
        self._parameter_scale(old_val) * self._learning_rate
        if self._multiply_by_parameter_scale else self._learning_rate)
    mixing_rate = 1.0 - decay_rate

    if factored_dims:
      # Exponential moving averages of the squared gradient, reduced to the
      # shapes of the two factored accumulators.
      row_mean = mtf.reduce_mean(grad_squared, output_shape=vr_shape)
      col_mean = mtf.reduce_mean(grad_squared, output_shape=vc_shape)
      new_vr = vr * decay_rate + row_mean * mixing_rate
      new_vc = vc * decay_rate + col_mean * mixing_rate
      updates = [mtf.assign(vr, new_vr), mtf.assign(vc, new_vc)]

      # new_vr is divided by its own mean over d1 before the rsqrt.
      long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
      r_factor = mtf.rsqrt(new_vr / long_term_mean)
      c_factor = mtf.rsqrt(new_vc)
      scaled_grad = grad * r_factor * c_factor
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      updates = [mtf.assign(v, new_v)]
      scaled_grad = grad * mtf.rsqrt(new_v)

    if self._clipping_threshold is not None:
      # Divide by max(1, rms / threshold): a no-op while the RMS is at or
      # below the threshold.
      clipping_denom = mtf.maximum(
          1.0, reduce_rms(scaled_grad) / self._clipping_threshold)
      scaled_grad /= clipping_denom

    subtrahend = scaled_grad * update_scale
    if self._beta1:
      # Blend the new step into the momentum slot and apply the blend.
      carried = m * tf.constant(self._beta1)
      fresh = subtrahend * tf.constant(1.0 - self._beta1)
      subtrahend = carried + fresh
      updates.append(mtf.assign(m, subtrahend))

    updates.append(mtf.assign_sub(var, subtrahend))
    return updates