def multihead_encdec_attention_incremental(query_antecedent,
                                           q_var, o_var, k, v,
                                           mask,
                                           name="multihead_attention"):
  """Incremental attention over encoder (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  memory_dims is a subset of query_dims

  Args:
    query_antecedent: a mtf.Tensor with shape query_dims + [io_channels]
    q_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    o_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    k: memory_dims + [heads, memory_length, kv_channels]
    v: memory_dims + [heads, memory_length, kv_channels]
    mask: mask Tensor (see attention_mask())
    name: an optional string.

  Returns:
    A mtf.Tensor with shape query_dims + [io_channels]
  """
  heads, _, kv_channels = k.shape.dims[-3:]
  query_dims = query_antecedent.shape.dims[:-1]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.TensorShape(query_dims + [heads, kv_channels]))
    o = dot_product_attention(q, k, v, mask)
    return mtf.einsum([o, o_var], query_antecedent.shape)
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
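# Usage sketch (illustrative, not part of the original module): applying the
# feed-forward block to an existing activations Tensor `x` of shape
# [batch, length, io_channels]. The hidden size chosen here is an assumption.
def _example_feedforward(x):  # hypothetical helper
  hidden = mtf.Dimension("feedforward_hidden", 2048)  # assumed size
  # The output has the same shape as x; the hidden dimension exists only
  # inside the layer.
  return dense_relu_dense(x, hidden, dropout=0.1, name="ffn")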
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        mask,
                        kv_channels,
                        heads,
                        dropout=0.0,
                        dropout_broadcast_dims=None,
                        name="multihead_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional)
    mask: mask Tensor (see attention_mask())
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, query_length, io_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch, query_length, io_channels = query_antecedent.shape.dims
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)
    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)
    memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
    if memory_batch != batch:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.TensorShape([batch, heads, query_length, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.TensorShape([batch, heads, memory_length, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.TensorShape([batch, heads, memory_length, kv_channels]))
    o = dot_product_attention(q, k, v, mask, dropout, dropout_broadcast_dims)
    return mtf.einsum(
        [o, o_var], mtf.TensorShape([batch, query_length, io_channels]))
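# Usage sketch (illustrative, not part of the original module): encoder-style
# self-attention over a Tensor `x` of shape [batch, length, io_channels].
# Passing memory_antecedent=None makes the layer attend over `x` itself; the
# mask would typically come from one of the attention_mask_*() helpers (e.g.
# attention_mask_ignore_padding, as used in _sample below). The head count and
# key/value size here are assumptions.
def _example_self_attention(x, mask):  # hypothetical helper
  kv_channels = mtf.Dimension("kv_channels", 64)  # assumed sizes
  heads = mtf.Dimension("heads", 8)
  return multihead_attention(
      x, None, mask, kv_channels, heads, dropout=0.1, name="self_attention")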
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
  """Incremental self-attention (one decode step).

  In order to use only one variable containing the four weight matrices
  packed together, we insist that the query and memory antecedents have the
  same dimensionality (io_channels) and that the keys and values have the
  same dimensionality (kv_channels).

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-1]
  io_channels = query_antecedent.shape.dims[-1]
  heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)
    memory_antecedent = query_antecedent
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.Shape(batch_dims + [heads, kv_channels]))
    # Write this step's key/value into position step_num of the caches.
    k = prev_k + mtf.multiply(
        k, mtf.one_hot(step_num, memory_length), output_shape=prev_k.shape)
    v = prev_v + mtf.multiply(
        v, mtf.one_hot(step_num, memory_length), output_shape=prev_v.shape)
    # Mask out memory positions beyond the current step.
    mask = mtf.to_float(mtf.greater(
        mtf.range(query_antecedent.mesh, memory_length, dtype=tf.int32),
        step_num)) * -1e9
    o = dot_product_attention(q, k, v, mask)
    y = mtf.einsum([o, o_var], query_antecedent.shape)
    return y, k, v
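# Decoding sketch (illustrative, not part of the original module): one
# incremental step threads the key/value caches through the layer. The caches
# start out as zeros of shape [batch..., heads, memory_length, kv_channels]
# (compare the initial_kv_states construction in _sample below), and step_num
# is an mtf scalar such as mtf.constant(mesh, 0, dtype=tf.int32).
def _example_decode_step(x, prev_k, prev_v, step_num):  # hypothetical helper
  y, new_k, new_v = multihead_self_attention_incremental(
      x, prev_k, prev_v, step_num, name="self_attention")
  # y feeds the rest of the decoder layer; new_k / new_v become prev_k /
  # prev_v of the next step.
  return y, new_k, new_v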
def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None, name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  if expert_dims is None:
    expert_dims = []
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]
  w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  output_shape = mtf.Shape(
      [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])
  with tf.variable_scope(name, default_name="dense"):
    stddev = mtf.list_product(d.size for d in reduced_dims) ** -0.5
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        w_shape,
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    y = mtf.einsum([x, w], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + [output_dim]),
          initializer=tf.zeros_initializer(),
          activation_dtype=x.dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
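# Usage sketch (illustrative, not part of the original module): projecting an
# existing activations Tensor `x` to a new channel dimension, reducing over
# the last dimension of x by default. The output dimension name and size are
# assumptions.
def _example_projection(x):  # hypothetical helper
  proj_dim = mtf.Dimension("proj", 1024)  # assumed size
  return dense(x, proj_dim, activation=mtf.relu, name="proj")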
def dot_product_attention(q, k, v, mask, dropout=0.0,
                          dropout_broadcast_dims=None):
  """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading
      dimensions are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match those of q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match those of q.
    mask: mask Tensor (see attention_mask())
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  length_kv = k.shape.dims[-2]
  logits_shape = mtf.TensorShape(q.shape.dims[:-1] + [length_kv])
  logits = mtf.einsum([q, k], logits_shape)
  if mask is not None:
    logits += mask
  weights = mtf.softmax(logits, length_kv)
  if dropout != 0.0:
    weights = mtf.dropout(
        weights, 1.0 - dropout,
        noise_shape=weights.shape - dropout_broadcast_dims)
  depth_v = v.shape.dims[-1]
  outputs_shape = mtf.TensorShape(q.shape.dims[:-1] + [depth_v])
  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
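# Shape walkthrough (illustrative): with q of shape
# [batch, heads, length_q, depth_k] and k of shape
# [batch, heads, length_kv, depth_k], the first einsum contracts depth_k and
# yields logits of shape [batch, heads, length_q, length_kv]; the softmax is
# taken over length_kv, and the second einsum contracts length_kv against v
# ([batch, heads, length_kv, depth_v]) to give
# [batch, heads, length_q, depth_v].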
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
  """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts. The expert choices
  and the combination weights are determined by a learned gating function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model. This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding) that each
    sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences. This would allow more freedom for
    individual sequences to be unbalanced. Unfortunately, that would slow
    down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation. Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better. We want to be able to substitute
    different code for the experts themselves. We also want to integrate this
    gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
  batch_dim, length_dim, input_dim = inputs.shape.dims

  # Each sequence sends expert_capacity positions to each expert.
  expert_capacity = min(
      length_dim.size,
      int((length_dim.size * 2 * overhead) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

  # This is the learned gating function.
  # shape = [batch_dim, length_dim, experts_dim_unsplit]
  gates = mtf.softmax(dense(inputs, experts_dim_unsplit), experts_dim_unsplit)

  assignment_shape = mtf.TensorShape(
      [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

  backward_assignment = mtf.slicewise(
      functools.partial(_truncated_top_2_gating,
                        expert_capacity=expert_capacity),
      [gates],
      output_shape=assignment_shape,
      splittable_dims=[batch_dim],
      name="backward_assignment")

  forward_assignment = mtf.cast(
      mtf.cast(backward_assignment, tf.bool), inputs.dtype)

  # put num_experts dimension first to make split easier in alltoall
  expert_inputs = mtf.einsum([inputs, forward_assignment], mtf.TensorShape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  expert_inputs = mtf.reshape(expert_inputs, mtf.TensorShape(
      [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

  # Now feed the expert inputs through the experts.
  h = dense(expert_inputs, hidden_dim, expert_dims=[experts_dim],
            activation=mtf.relu, name="x0")
  expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

  expert_output = mtf.reshape(expert_output, mtf.TensorShape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  output = mtf.einsum([expert_output, backward_assignment], mtf.TensorShape(
      [batch_dim, length_dim, output_dim]))

  importance = mtf.reduce_sum(
      backward_assignment,
      output_shape=mtf.TensorShape([batch_dim, experts_dim_unsplit]))

  loss = cv_squared(importance) * loss_coef
  return output, loss
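# Worked capacity example (illustrative): with length_dim.size = 256,
# experts_dim.size = 16 and overhead = 1.0, each sequence reserves
#   expert_capacity = min(256, int(256 * 2 * 1.0 / 16)) = 32
# slots per expert, i.e. at most 32 of a sequence's 256 positions can be
# routed to any single expert; assignments beyond that are truncated by the
# gating function.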
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
  """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_size. Attention for a
  given query position can only see memory positions less than or equal to
  the query position, in the corresponding block and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently,
      memory_length must have the same size as query_length, but a
      different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, representing receptive fields for attention.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    batch, query_length, io_channels = query_antecedent.shape.dims
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)

    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)
    memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
    if memory_batch != batch:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")

    # Get query q, keys k and values v.
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.TensorShape([batch, heads, query_length, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.TensorShape([batch, heads, memory_length, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.TensorShape([batch, heads, memory_length, kv_channels]))

    # Let's assume for now we don't have padding and the block length equally
    # divides the memory length.
    block_length = (query_length.size
                    if query_length.size < block_length * 2
                    else block_length)
    blength = mtf.Dimension("block_length", block_length)
    mlength = mtf.Dimension("mem_block_length", block_length)
    num_blocks = mtf.Dimension("num_blocks", query_length.size // block_length)

    q = mtf.reshape(
        q, mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
    k = mtf.reshape(
        k, mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
    v = mtf.reshape(
        v, mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

    # compute attention for the first query block.
    def first_block_attention():
      """Compute attention for the first block."""
      first_q = mtf.slice(q, 0, 1, num_blocks.name)
      first_k = mtf.slice(k, 0, 1, num_blocks.name)
      first_v = mtf.slice(v, 0, 1, num_blocks.name)
      block = first_q.shape.dims[2]

      first_logits = mtf.einsum(
          [first_q, first_k],
          mtf.TensorShape([batch, heads, block, blength, mlength]))
      weights = mtf.softmax(first_logits, mlength)
      first_output = mtf.einsum(
          [weights, first_v],
          mtf.TensorShape([batch, heads, block, blength, kv_channels]))
      return first_output

    # Attention for first block, since query_length = key_length.
    first_output = first_block_attention()

    # Concatenate two adjacent blocks to compute the overlapping memory block.
    def local(x):
      """Helper function to get memory blocks."""
      prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
      cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
      local_block = mtf.concat([prev_block, cur_block], mlength.name)
      return local_block

    local_k = local(k)
    local_v = local(v)
    mblocks = local_k.shape.dims[2]
    mlength = local_k.shape.dims[3]

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block_size is known.
    mask = attention_bias_local_block(query_antecedent.mesh, blength, mlength)

    # Remove the first block from q since we already computed that.
    tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

    # Compatibility between q and k for rest of the blocks.
    # Shape [batch, heads, num_blocks - 1, block_length, local_length]
    attention = mtf.einsum(
        [tail_q, local_k],
        mtf.TensorShape([batch, heads, mblocks, blength, mlength]))
    attention += mask
    attention = mtf.softmax(attention, mlength)

    # Run attention for rest of the blocks.
    # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
    output = mtf.einsum(
        [attention, local_v],
        mtf.TensorShape([batch, heads, mblocks, blength, kv_channels]))

    # Now concatenate the first and rest of the blocks.
    final_output = mtf.concat([first_output, output], num_blocks.name)
    final_output = mtf.reshape(
        final_output,
        mtf.TensorShape([batch, heads, query_length, kv_channels]))
    return mtf.einsum(
        [final_output, o_var],
        mtf.TensorShape([batch, query_length, io_channels]))
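# Block-size example (illustrative): the layer shrinks block_length when the
# sequence is short. With query_length.size = 200 and block_length = 128,
# 200 < 2 * 128, so block_length becomes 200 and the whole sequence is handled
# by the single "first block" path; with query_length.size = 512, block_length
# stays 128 and num_blocks = 4.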
def local_self_attention_spatial_blocks(query_antecedent,
                                        kv_channels,
                                        heads,
                                        memory_w_dim=None,
                                        mask_right=False,
                                        name=None):
  """Attention to the source position and a neighborhood to the left or right.

  The sequence is divided into blocks of length block_size. Attention for a
  given query position can only see memory positions less than or equal to
  the query position, in the corresponding block and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, num_w_blocks, w_dim, io_channels]
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    memory_w_dim: mtf Dimension, for the memory width block.
    mask_right: bool, flag specifying whether we mask out attention to the
      right for the decoder.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, num_w_blocks, w_dim, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent]):
    w_dim, io_channels = query_antecedent.shape.dims[-2:]
    batch, num_w_blocks = query_antecedent.shape.dims[:2]
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)

    # Rename dimensions for the memory width.
    memory_antecedent = mtf.rename_dimension(
        query_antecedent, w_dim.name, memory_w_dim.name)

    # Call einsum over the query and memory to get query q, keys k and
    # values v.
    q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
    k = mtf.einsum(
        [memory_antecedent, k_var],
        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
    v = mtf.einsum(
        [memory_antecedent, v_var],
        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))

    # Halo exchange for memory blocks.
    if memory_w_dim is not None:
      k, v = local_1d_halo_exchange(k, v, num_w_blocks, w_dim, memory_w_dim,
                                    mask_right)

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block_size is known.
    mask = None
    if mask_right:
      mask = attention_bias_local_block(query_antecedent.mesh,
                                        w_dim, memory_w_dim)

    output = dot_product_attention(q, k, v, mask=mask)

    return mtf.einsum(
        [output, o_var],
        mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if self.has_input:
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf_layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute the encoder-decoder attention variables and the projected
    # keys/values for each decoder layer; they are reused at every step.
    encdec_tensors = []
    for layer_num in xrange(hparams.num_decoder_layers):
      with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
        q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
            mesh, self.heads_dim, self.model_dim,
            self.kv_dim, self.activation_dtype)
        k = mtf.einsum(
            [encoder_output, k_var],
            mtf.Shape([self.batch_dim, self.heads_dim,
                       self.memory_length_dim, self.kv_dim]))
        v = mtf.einsum(
            [encoder_output, v_var],
            mtf.Shape([self.batch_dim, self.heads_dim,
                       self.memory_length_dim, self.kv_dim]))
      encdec_tensors.append((q_var, o_var, k, v))
    partial_targets = None
  else:
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(
          partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)

  if hparams.beam_size == 1:
    ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  initial_kv_states = (
      [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
      (2 * hparams.num_decoder_layers))

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    self_attention_k = states[:hparams.num_decoder_layers]
    self_attention_v = states[hparams.num_decoder_layers:]
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_self_attention_k, new_self_attention_v = (
          self._decoder_layer_stack_incremental(
              x,
              step_num,
              encdec_tensors,
              self_attention_k,
              self_attention_v,
              encdec_attention_mask=encoder_attention_mask))
    logits = mtf.matmul(x, softmax_var)
    return logits, new_self_attention_k + new_self_attention_v

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf_beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_kv_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if self.has_input:
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf_beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_kv_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu)
    return mtf.gather(
        beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def transformer_moe_layer_v1(inputs, output_dim, hparams, train):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py.  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts)

  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups".  For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  orig_inputs = inputs
  input_dim = inputs.shape.dims[-1]
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
  group_size_dim = mtf.Dimension("group", hparams.moe_group_size)
  batch_dim = mtf.Dimension(
      orig_inputs.shape[0].name,
      orig_inputs.shape.size // (group_size_dim.size * input_dim.size))
  inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

  # Each sequence sends expert_capacity positions to each expert.
  capacity_factor = (
      hparams.moe_capacity_factor_train if train else
      hparams.moe_capacity_factor_eval)
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

  if hparams.moe_gating == "top_2":
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # put num_experts dimension first to make split easier in alltoall
  expert_inputs = mtf.einsum([inputs, dispatch_tensor], mtf.Shape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  expert_inputs = mtf.reshape(expert_inputs, mtf.Shape(
      [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

  # Now feed the expert inputs through the experts.
  h = mtf_layers.dense(
      expert_inputs, hidden_dim, expert_dims=[experts_dim],
      activation=mtf.relu, use_bias=False, name="x0")
  expert_output = mtf_layers.dense(
      h, output_dim, expert_dims=[experts_dim], use_bias=False, name="x1")

  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  output = mtf.einsum([expert_output, combine_tensor], mtf.Shape(
      [batch_dim, group_size_dim, output_dim]))

  output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

  return output, loss * hparams.moe_loss_coef
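# Worked capacity example (illustrative): with hparams.moe_group_size = 1024,
# hparams.moe_num_experts = 16 and a train-time capacity factor of 1.25, each
# group reserves
#   expert_capacity = min(1024, int(1024 * 1.25 / 16)) = 80
# slots per expert, so at most 80 of a group's 1024 positions can be
# dispatched to any single expert in one forward pass.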
def transformer_moe_layer_v2(inputs, output_dim, hparams, train):
  """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py.  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  One set of params is shared by the experts in the first level, and a
  different set of params is used per expert in the second level.

  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts) +
    (moe_hidden_size * hparams.num_experts) * hparams.num_experts

  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups".  For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
    a, b: batch size
    l: original sequence length
    m: input depth
    n: output depth
    g, h: number of groups
    s, t: group size
    x, y: number of experts
    c, d: expert capacity

    input: [a0, b1, l, m]
    input: [a0, g1, s, m]
    dispatch_tensor_x: [a0, g1, s, x, c]
    expert_input: [a0, g1, x, c, m]
    alltoall: [a0, g, x1, c, m]
    alltoall: [a0, g, x1, c, m]
    transpose: [x1, a0, g, c, m]
    reshape: [x1, h0, s, m]
    assignment2: [x1, h0, t, y, d]
    expert_input2: [x1, h0, y, d, m]
    alltoall: [x1, h, y0, d, m]
    ...
    reverse of that

    gating params 0: [m, x]
    gating params 1: [x1, m, y]

    expert params:
      [x1, y0, m, hidden]
      [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
  if insert_outer_batch_dim:
    inputs = mtf.reshape(
        inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

  assert len(hparams.moe_num_experts) == 2
  a0, b1, l, m = inputs.shape.dims
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
  y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
  x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
  y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
  n = output_dim

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (g.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      b1.size * l.size, hparams.moe_group_size,
      _tensor_dim_to_mesh_dim_size(hparams, b1))
  g1 = mtf.Dimension(b1.name, num_groups)
  g = mtf.Dimension(b1.name + "_unsplit", g1.size)
  s = mtf.Dimension("group_size_x", group_size)

  # Each sequence sends (at most?) expert_capacity positions to each expert.
  # Static expert_capacity dimension is needed for expert batch sizes
  capacity_factor = (
      hparams.moe_capacity_factor_train if train else
      hparams.moe_capacity_factor_eval)
  expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
  c = mtf.Dimension("expert_capacity_x", expert_capacity)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (h.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      a0.size * g.size * c.size,
      hparams.moe_group_size,
      _tensor_dim_to_mesh_dim_size(hparams, a0))
  t = mtf.Dimension("group_size_y", group_size)
  h0 = mtf.Dimension(a0.name, num_groups)
  h = mtf.Dimension(a0.name + "_unsplit", h0.size)

  expert_capacity = min(
      t.size,
      int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
  d = mtf.Dimension("expert_capacity_y", expert_capacity)

  # First level of expert routing
  # Reshape the inner batch size to a multiple of group_dim g1 and
  # group_size_dim s.
  inputs = mtf.reshape(inputs, [a0, g1, s, m])

  # Get the assignments for the first level.
  # dispatch_tensor_x has shape [a0, g1, s, x, c]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=x,
        expert_capacity_dim=c,
        hparams=hparams,
        train=train)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])

  # we construct an "importance" Tensor for the inputs to the second-level
  # gating.  The importance of an input is 1.0 if it represents the
  # first-choice expert-group and 0.5 if it represents the second-choice
  # expert group.  This is used by the second-level gating.
  importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
  importance = 0.5 * (
      mtf.to_float(mtf.greater(importance, 0.5)) +
      mtf.to_float(mtf.greater(importance, 0.0)))

  # First level, all to all. Here we change the split dimension from g1 to x1.
  expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
      [x1, a0, g, c, m]))
  importance = mtf.reshape(importance, [x1, a0, g, c])

  # Second level of expert routing
  # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
  # and group_size_dim t.
  inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
  importance = mtf.reshape(importance, [x1, h0, t])

  # Get the assignments for the second level.
  # dispatch_tensor_y has shape [x1, h0, t, y, d]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
        inputs=inputs_y,
        outer_expert_dims=[x1],
        experts_dim=y,
        expert_capacity_dim=d,
        hparams=hparams,
        train=train,
        importance=importance)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                               [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  # Now feed the expert inputs through the experts.
  hidden_output = mtf_layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      activation=mtf.relu, use_bias=False, name="expert0")
  expert_output = mtf_layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      use_bias=False, name="expert1")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

  # Reshape the combined tensor from inner level to now contain
  # outer_batch_dim a0 and group_dim g
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim
  # b1 and the original sequence length
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef