# Assumed imports for this fragment (module paths follow the Mesh TensorFlow
# code inside tensor2tensor; adjust to the local revision):
import tensorflow as tf

from tensor2tensor.mesh_tensorflow import mesh_tensorflow as mtf
from tensor2tensor.mesh_tensorflow import mtf_layers


def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  # Separate initializer stddevs for the q/k, v, and o slices of the stacked
  # variable.
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5

  def qkvo_initializer(shape,
                       dtype=None,
                       partition_info=None,
                       verify_shape=None):
    del partition_info, verify_shape
    return tf.random_normal(shape, dtype=dtype) * tf.reshape(
        [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])

  var = mtf.get_variable(
      mesh, "qkvo",
      mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer,
      activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
  return q_var, k_var, v_var, o_var
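
# Hypothetical usage sketch (not part of the original module): builds the four
# attention parameter Tensors on a freshly constructed Mesh. The graph/mesh
# setup and the dimension names and sizes below are illustrative assumptions.
def _example_multihead_attention_vars():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  heads = mtf.Dimension("heads", 8)
  io_channels = mtf.Dimension("d_model", 512)
  kv_channels = mtf.Dimension("d_kv", 64)
  # Each returned Tensor has shape [heads, io_channels, kv_channels].
  return multihead_attention_vars(
      mesh, heads, io_channels, kv_channels, tf.float32)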
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    # wi projects io_channels -> hidden_channels; wo projects
    # hidden_channels -> io_channels.
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
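
# Hypothetical usage sketch (not part of the original module): applies
# dense_relu_dense to an imported activation Tensor. The dimension names,
# sizes, and zero-valued input are illustrative assumptions;
# mtf.import_tf_tensor is used as in the Mesh TensorFlow library to wrap a
# tf.Tensor.
def _example_dense_relu_dense():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 16)
  d_model = mtf.Dimension("d_model", 512)
  d_ff = mtf.Dimension("d_ff", 2048)
  x = mtf.import_tf_tensor(
      mesh, tf.zeros([2, 16, 512]),
      shape=mtf.Shape([batch, length, d_model]))
  # Output has the same shape as x; dropout noise is broadcast along `length`.
  return dense_relu_dense(x, d_ff, dropout=0.1,
                          dropout_broadcast_dims=[length])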
def _decoder_layer_stack_incremental(self,
                                     x,
                                     step_num,
                                     encdec_tensors,
                                     self_attention_k,
                                     self_attention_v,
                                     encdec_attention_mask=None):
  """Decoder layer stack during inference.

  We are processing only one position at a time.

  The self-attention keys and values have already been computed for previous
  positions.  In addition to the decoder output, we need to produce the
  updated self-attention keys and values.

  If there is an encoder, then additional Tensors are supplied in
  encdec_tensors, which give us the keys and values for encoder-decoder
  attention as well as the weight matrices q_var and o_var.

  Args:
    x: a mtf.Tensor with shape [batch_dim, model_dim]
    step_num: an mtf integer Scalar
    encdec_tensors: an optional list of num_layers tuples, each of the form
      (q_var, o_var, k, v)
    self_attention_k: an optional list of num_layers Tensors each with shape
      [batch, heads, memory_length, kv_channels]
    self_attention_v: an optional list of num_layers Tensors each with shape
      [batch, heads, memory_length, kv_channels]
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

  Returns:
    y: a mtf.Tensor with shape [batch_dim, model_dim]
    new_self_attention_k: a list of num_layers mtf.Tensors, with the same
      shapes as the elements of self_attention_k
    new_self_attention_v: a list of num_layers mtf.Tensors, with the same
      shapes as the elements of self_attention_v

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  num_layers = hparams.num_decoder_layers
  num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

  def normalize(x):
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

  new_self_attention_k = []
  new_self_attention_v = []
  for layer in range(num_layers):
    with tf.variable_scope("layer_%d" % layer):
      # Self attention layer
      y, new_k, new_v = mtf_layers.multihead_self_attention_incremental(
          normalize(x),
          prev_k=self_attention_k[layer],
          prev_v=self_attention_v[layer],
          step_num=step_num,
          name="self_attention")
      new_self_attention_k.append(new_k)
      new_self_attention_v.append(new_v)
      x += y
      if encdec_tensors is not None:
        # Encoder-Decoder attention layer
        q_var, o_var, k, v = encdec_tensors[layer]
        x += mtf_layers.multihead_encdec_attention_incremental(
            normalize(x),
            q_var, o_var, k, v,
            encdec_attention_mask,
            name="encdec_attention")
      # ffn layer
      x += self._feedforward_layer(normalize(x), hparams)
  x = normalize(x)
  assert not layer_norm_vars
  return x, new_self_attention_k, new_self_attention_v
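
# Hypothetical sketch (not part of the original class): initializing the
# per-layer self-attention caches that _decoder_layer_stack_incremental
# consumes and re-emits. The argument names are assumptions; each cache entry
# is a zero Tensor of shape [batch, heads, memory_length, kv_channels], and
# mtf.zeros is assumed to be available as in the Mesh TensorFlow library.
def _example_init_decoder_cache(mesh, num_layers, batch_dim, heads_dim,
                                memory_length_dim, kv_dim, activation_dtype):
  cache_shape = mtf.Shape(
      [batch_dim, heads_dim, memory_length_dim, kv_dim])
  zeros = mtf.zeros(mesh, cache_shape, dtype=activation_dtype)
  self_attention_k = [zeros] * num_layers
  self_attention_v = [zeros] * num_layers
  return self_attention_k, self_attention_v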
def _layer_stack(self,
                 x,
                 num_layers,
                 encoder_output=None,
                 self_attention_mask=None,
                 encdec_attention_mask=None,
                 losses=None):
  """Encoder or decoder stack.

  Args:
    x: a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
    num_layers: an integer
    encoder_output: an optional mtf.Tensor with shape
      [batch_dim, encoder_length_dim, model_dim]
    self_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, memory_length_dim] containing values 0 or -inf.
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
    losses: a list to be appended-to

  Returns:
    a mtf.Tensor with shape [batch_dim, length_dim, model_dim]

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams

  def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

  num_layer_norms = num_layers * (2 if encoder_output is None else 3) + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

  def normalize(x):
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

  for layer in range(num_layers):
    with tf.variable_scope("layer_%d" % layer):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf_layers.multihead_attention(
              normalize(x), None,
              self_attention_mask, self.kv_dim, self.heads_dim,
              dropout=hparams.attention_dropout,
              dropout_broadcast_dims=[self.length_dim],
              name="self_attention"))
      if encoder_output is not None:
        # Encoder-Decoder attention layer
        x += layer_prepostprocess_dropout(
            mtf_layers.multihead_attention(
                normalize(x), encoder_output,
                encdec_attention_mask, self.kv_dim, self.heads_dim,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                name="encdec_attention"))
      # ffn layer
      x += layer_prepostprocess_dropout(
          self._feedforward_layer(normalize(x), losses=losses))
  x = layer_prepostprocess_dropout(normalize(x))
  assert not layer_norm_vars
  return x
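
# Illustrative helper (not part of the original class): the stacked
# "layer_norm_scale" variable holds one scale per normalize() call. A layer
# with encoder-decoder attention normalizes three times (self-attention,
# encdec attention, ffn), a layer without it twice, plus one final normalize
# at the end of the stack, which is why
# num_layer_norms = num_layers * (2 or 3) + 1 above.
def _example_count_layer_norms(num_layers, has_encoder_output):
  per_layer = 3 if has_encoder_output else 2
  # e.g. 6 decoder layers with encoder output -> 6 * 3 + 1 = 19 scales.
  return num_layers * per_layer + 1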