# Imports assumed by this module (the excerpt uses the TF1-style API).
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Layer normalization over dimension dim.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension
    epsilon: a floating point number
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name + "/layer_norm"):
    scale = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh,
        "layer_norm_bias",
        mtf.Shape([dim]),
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)
    reduced_shape = x.shape - dim
    mean = mtf.reduce_mean(x, output_shape=reduced_shape)
    variance = mtf.reduce_mean(
        mtf.square(x - mean), output_shape=reduced_shape)
    norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
    return norm_x * scale + bias
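
# A minimal usage sketch for layer_norm (not part of the library): the mesh
# is assumed to exist already, and the "batch" / "d_model" Dimension names
# are hypothetical.
def _layer_norm_example(mesh):
  """Normalizes activations over the model dimension."""
  batch = mtf.Dimension("batch", 32)
  d_model = mtf.Dimension("d_model", 512)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
  # Mean and variance are computed over d_model; scale and bias are learned.
  return layer_norm(x, d_model)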
def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          variable_dtype=None,
          name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)
    variable_dtype: a mtf.VariableDType
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  if variable_dtype is None:
    variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
  if expert_dims is None:
    expert_dims = []
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]
  w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  output_shape = mtf.Shape(
      [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])
  with tf.variable_scope(name, default_name="dense"):
    stddev = mtf.list_product(d.size for d in reduced_dims) ** -0.5
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        w_shape,
        initializer=tf.random_normal_initializer(stddev=stddev),
        dtype=variable_dtype)
    w = mtf.cast(w, x.dtype)
    y = mtf.einsum([x, w], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + [output_dim]),
          initializer=tf.zeros_initializer(),
          dtype=variable_dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
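
# A minimal usage sketch (not part of the library): a two-layer feed-forward
# block with hypothetical Dimension names, assuming an existing mesh.
def _dense_example(mesh):
  """Projects d_model -> d_ff -> d_model with a ReLU in between."""
  batch = mtf.Dimension("batch", 32)
  d_model = mtf.Dimension("d_model", 512)
  d_ff = mtf.Dimension("d_ff", 2048)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
  h = dense(x, d_ff, activation=mtf.relu, name="ff1")
  # reduced_dims defaults to the last dimension (d_ff here).
  return dense(h, d_model, name="ff2")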
def multihead_attention_params(mesh, heads, io_channels, kv_channels,
                               variable_dtype, combine=False):
  """Create Parameters for Multihead Attention.

  If the combine flag is set to True, then we create only one variable
  which stacks together all of the parameters. Otherwise, we create four
  separate variables.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    variable_dtype: a mtf.VariableDType
    combine: a boolean

  Returns:
    wq: a Tensor with shape [heads, io_channels, kv_channels]
    wk: a Tensor with shape [heads, io_channels, kv_channels]
    wv: a Tensor with shape [heads, io_channels, kv_channels]
    wo: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  # TODO(noam): should be:
  #   o_stddev = (kv_channels.size * heads.size) ** -0.5
  # verify that this still works and change it.
  o_stddev = (io_channels.size * heads.size) ** -0.5
  if combine:
    def qkvo_initializer(shape,
                         dtype=None,
                         partition_info=None,
                         verify_shape=None):
      del partition_info, verify_shape
      return tf.random_normal(shape, dtype=dtype) * tf.reshape(
          tf.cast([qk_stddev, qk_stddev, v_stddev, o_stddev],
                  dtype or tf.float32), [4, 1, 1, 1])
    var = mtf.get_variable(
        mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
        initializer=qkvo_initializer, dtype=variable_dtype)
    return mtf.unstack(var, qkvo)
  else:
    return [mtf.get_variable(
        mesh, name, mtf.Shape([heads, io_channels, kv_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        dtype=variable_dtype) for name, stddev in zip(
            ["q", "k", "v", "o"],
            [qk_stddev, qk_stddev, v_stddev, o_stddev])]
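
# A minimal usage sketch (not part of the library): fetches the four
# projection weights with hypothetical sizes, assuming an existing mesh.
def _attention_params_example(mesh):
  """Creates stacked q/k/v/o weights and unstacks them."""
  heads = mtf.Dimension("heads", 8)
  io_channels = mtf.Dimension("d_model", 512)
  kv_channels = mtf.Dimension("d_kv", 64)
  var_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
  wq, wk, wv, wo = multihead_attention_params(
      mesh, heads, io_channels, kv_channels, var_dtype, combine=True)
  return wq, wk, wv, wo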
def apply_grad(self, grad, var):
  """See base class."""
  if grad is None:
    tf.logging.warning("Gradient is None for variable %s" % var.name)
    return []

  grad = mtf.to_float(grad)

  assignments = []

  m = mtf.get_variable(
      var.mesh, var.name + "/adam_m", var.shape,
      initializer=tf.zeros_initializer(), trainable=False)

  v = mtf.get_variable(
      var.mesh, var.name + "/adam_v", var.shape,
      initializer=tf.zeros_initializer(), trainable=False)

  # Standard Adam update.
  next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
  next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

  update = next_m / (mtf.sqrt(next_v) + self.epsilon)

  # Just adding the square of the weights to the loss function is *not* the
  # correct way of using L2 regularization/weight decay with Adam, since that
  # will interact with the m and v parameters in strange ways.
  #
  # Instead we want to decay the weights in a manner that doesn't interact
  # with the m/v parameters. This is equivalent to adding the square of the
  # weights to the loss with plain (non-momentum) SGD.
  if self._do_use_weight_decay(var.name):
    update += self.weight_decay_rate * var.value

  update_with_lr = self.learning_rate * update

  var_update = mtf.assign_sub(var, update_with_lr)

  assignments.extend(
      [var_update, mtf.assign(m, next_m), mtf.assign(v, next_v)])
  return assignments
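
# A minimal sketch of the surrounding training-step plumbing (not part of the
# library): the mesh is assumed to exist, `optimizer` is any object exposing
# the apply_grad method above, and graph lowering / session handling are
# omitted.
def _apply_grad_example(mesh, optimizer):
  """Builds a toy loss, differentiates it, and applies one update."""
  batch = mtf.Dimension("batch", 32)
  d_model = mtf.Dimension("d_model", 512)
  inputs = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
  logits = dense(inputs, mtf.Dimension("out", 10), name="logits")
  loss = mtf.reduce_mean(mtf.square(logits))
  var = mesh.graph.trainable_variables[0]            # e.g. the dense kernel
  (grad,) = mtf.gradients([loss], [var.outputs[0]])
  return optimizer.apply_grad(grad, var)             # list of assign ops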
def apply_grad(self, grad, var):
  """See base class."""
  if grad is None:
    tf.logging.warning("Gradient is None for variable %s" % var.name)
    return []
  updates = []
  v = mtf.get_variable(
      var.mesh, var.name + "_momentum_v", var.shape,
      dtype=var.dtype, initializer=tf.zeros_initializer(), trainable=False)

  with tf.variable_scope(var.name + "/sgd_momentum"):
    updates.append(mtf.assign(v, grad * self.lr + v * self.momentum))
    updates.append(mtf.assign_sub(var, v))

  return updates
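
# A tiny scalar walk-through of the update rule above (illustrative only):
# v_t = lr * g_t + momentum * v_{t-1}, then w_t = w_{t-1} - v_t.
def _sgd_momentum_demo():
  """Applies three steps of the momentum recurrence to a scalar weight."""
  lr, momentum = 0.1, 0.9
  w, v = 1.0, 0.0
  for g in (0.5, 0.5, 0.5):
    v = g * lr + v * momentum  # mirrors mtf.assign(v, grad * lr + v * momentum)
    w -= v                     # mirrors mtf.assign_sub(var, v)
  return w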
def conv2d(x, output_dim, filter_size=(3, 3),
           strides=(1, 1), padding="SAME", filter_initializer=None,
           variable_dtype=None, name=None):
  """2D Convolution.

  Args:
    x: a mtf.Tensor of format NHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format [filter_height, filter_width].
    strides: a list or tuple in format [stride_height, stride_width].
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor.
  """
  fh_dim = mtf.Dimension("fh", filter_size[0])
  fw_dim = mtf.Dimension("fw", filter_size[1])
  input_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv2d"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    conv_filter = mtf.get_variable(
        x.mesh, "kernel", [fh_dim, fw_dim, input_dim, output_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Pad stride in batch and channel dimensions.
    strides = [1] + list(strides) + [1]
    return mtf.Conv2dOperation(x, conv_filter, strides, padding).outputs[0]
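
# A minimal usage sketch (not part of the library): a 3x3 convolution over a
# small NHWC tensor with hypothetical Dimension names, assuming an existing
# mesh.
def _conv2d_example(mesh):
  """Applies a stride-1 SAME convolution, preserving height and width."""
  batch = mtf.Dimension("batch", 8)
  height = mtf.Dimension("height", 32)
  width = mtf.Dimension("width", 32)
  channels = mtf.Dimension("channels", 3)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, height, width, channels]))
  return conv2d(x, mtf.Dimension("filters", 16), name="conv1")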
def apply_grad(self, grad, var):
  """See base class."""
  if grad is None:
    tf.logging.warning("Gradient is None for variable %s" % var)
    return []
  # Create slots.
  grad = mtf.to_float(grad)
  factored_dims = self._factored_dims(var.shape)
  if factored_dims:
    d0, d1 = factored_dims
    vr_shape = var.shape - d0
    vc_shape = var.shape - d1
    vr = mtf.get_variable(
        var.mesh, var.name + "_slot_vr", vr_shape,
        initializer=tf.zeros_initializer(), trainable=False)
    vc = mtf.get_variable(
        var.mesh, var.name + "_slot_vc", vc_shape,
        initializer=tf.zeros_initializer(), trainable=False)
  else:
    v = mtf.get_variable(
        var.mesh, var.name + "_slot_v", var.shape,
        initializer=tf.zeros_initializer(), trainable=False)
  if self._beta1:
    m = mtf.get_variable(
        var.mesh, var.name + "_slot_m", var.shape,
        initializer=tf.zeros_initializer(), trainable=False)

  with tf.variable_scope(var.name + "/adafactor"):
    grad_squared = mtf.square(grad) + self._epsilon1
    decay_rate = self._decay_rate
    old_val = mtf.to_float(var.value)
    if self._multiply_by_parameter_scale:
      update_scale = self._parameter_scale(old_val) * self._learning_rate
    else:
      update_scale = self._learning_rate
    mixing_rate = 1.0 - decay_rate
    updates = []
    if factored_dims:
      grad_squared_row_mean = mtf.reduce_mean(
          grad_squared, output_shape=vr_shape)
      grad_squared_col_mean = mtf.reduce_mean(
          grad_squared, output_shape=vc_shape)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      vr_update = mtf.assign(vr, new_vr)
      vc_update = mtf.assign(vc, new_vc)
      updates.extend([vr_update, vc_update])
      long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
      r_factor = mtf.rsqrt(new_vr / long_term_mean)
      c_factor = mtf.rsqrt(new_vc)
      x = grad * r_factor * c_factor
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      v_update = mtf.assign(v, new_v)
      updates.append(v_update)
      x = grad * mtf.rsqrt(new_v)
    if self._clipping_threshold is not None:
      clipping_denom = mtf.maximum(
          1.0, reduce_rms(x) / self._clipping_threshold)
      x /= clipping_denom
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (m * tf.constant(self._beta1) +
               subtrahend * tf.constant(1.0 - self._beta1))
      subtrahend = new_m
      updates.append(mtf.assign(m, new_m))
    # It is critical to use assign_sub instead of mtf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    var_update = mtf.assign_sub(var, subtrahend)
    updates.append(var_update)
    return updates
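
# The factored branch above keeps only row and column means of the squared
# gradient instead of a full per-parameter second-moment slot. A minimal
# NumPy sketch (illustrative only; EMA decay and epsilon omitted) of the
# rank-1 estimate that r_factor * c_factor implicitly applies:
def _adafactor_factored_demo():
  """Reconstructs the implicit second-moment estimate for a 2D variable."""
  import numpy as np
  g2 = np.random.rand(4, 6)  # stands in for mtf.square(grad)
  vr = g2.mean(axis=0)       # analogous to new_vr (mean over d0)
  vc = g2.mean(axis=1)       # analogous to new_vc (mean over d1)
  estimate = np.outer(vc, vr) / vr.mean()
  # grad * r_factor * c_factor above equals grad / np.sqrt(estimate).
  return estimate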
def batch_norm(x, is_training, momentum, epsilon=1e-9,
               init_zero=False, name=None):
  """Batch normalization.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    init_zero: a boolean, whether to initialize scale with 0's or 1's.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    if init_zero:
      gamma_initializer = tf.zeros_initializer()
    else:
      gamma_initializer = tf.ones_initializer()

    norm_dim = x.shape.dims[0:3]
    reduced_shape = x.shape - norm_dim

    scale = mtf.get_variable(
        x.mesh, "batch_norm_scale", reduced_shape,
        initializer=gamma_initializer, activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh, "batch_norm_bias", reduced_shape,
        initializer=tf.zeros_initializer(), activation_dtype=x.dtype)

    moving_mean = mtf.get_variable(
        x.mesh, "moving_mean", reduced_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype, trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh, "moving_variance", reduced_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype, trainable=False)

    # At training time, calculate mean and variance and normalize across batch
    # dim.
    if is_training:
      mean = mtf.reduce_mean(x, output_shape=reduced_shape)
      variance = mtf.reduce_mean(
          mtf.square(x - mean), output_shape=reduced_shape)

      norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)

      # Update running mean and running variance.
      moving_mean = mtf.assign(
          moving_mean, momentum * moving_mean + (1 - momentum) * mean)
      moving_variance = mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * variance)
    else:
      # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)

    return (norm_x * scale) + bias
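
# A minimal usage sketch (not part of the library): batch norm in training
# mode on an NHWC tensor, with hypothetical Dimension names and an existing
# mesh. The assign ops created inside update the moving averages once the
# graph is lowered and run.
def _batch_norm_example(mesh):
  """Normalizes over batch/height/width, keeping per-channel statistics."""
  batch = mtf.Dimension("batch", 8)
  height = mtf.Dimension("height", 32)
  width = mtf.Dimension("width", 32)
  channels = mtf.Dimension("channels", 16)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, height, width, channels]))
  return batch_norm(x, is_training=True, momentum=0.999)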
def embedding_weights(
    mesh, vocab_dim, output_dim, variable_dtype, name="embedding"):
  """Create an embedding-weight matrix of shape [vocab_dim, output_dim]."""
  return mtf.get_variable(
      mesh, name, mtf.Shape([vocab_dim, output_dim]),
      dtype=variable_dtype, initializer=tf.random_normal_initializer())
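
# A minimal usage sketch (not part of the library): creates the table and
# looks up integer ids with mtf.gather; the Dimension names are hypothetical
# and the mesh is assumed to exist.
def _embedding_example(mesh):
  """Embeds a batch of token ids into d_model-sized vectors."""
  batch = mtf.Dimension("batch", 8)
  vocab = mtf.Dimension("vocab", 1000)
  d_model = mtf.Dimension("d_model", 512)
  var_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
  table = embedding_weights(mesh, vocab, d_model, var_dtype)
  ids = mtf.zeros(mesh, mtf.Shape([batch]), dtype=tf.int32)
  return mtf.gather(table, ids, vocab)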