def _compute_merge_qkv(self, antecedent):
  """Computes qkv all in one call using MoE layer."""
  def _replace_d_model_dim(t):
    """Used to replace the `d_model` dim with `heads`."""
    new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
    return mtf.reshape(
        t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
  if self.expert_computation == "qkv":
    # NOTE: This assumes query and memory antecedent are the same.
    qk = self.moe_layer.call(self.context, antecedent)
    # Split qk here since they went through the expert layers.
    q, k = qk
    q = _replace_d_model_dim(q)
    k = _replace_d_model_dim(k)
  elif self.expert_computation == "q":
    q = self.moe_layer.call(self.context, antecedent)
    q = _replace_d_model_dim(q)
    # Compute key/value normally
    k = mtf.layers.us_einsum(
        [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
  elif self.expert_computation == "kv":
    k = self.moe_layer.call(self.context, antecedent)
    k = _replace_d_model_dim(k)
    # Compute query normally
    q = mtf.layers.us_einsum(
        [antecedent, self.wq], reduced_dims=[self.query_input_dim])
  else:
    raise ValueError("Invalid expert computation mode: {}".format(
        self.expert_computation))
  # Scale query
  q *= self.key_dim.size ** -0.5
  self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
  self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
def get_indices(self, keys: mtf.Tensor,
                query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
  """Generate score and indices for the query."""
  score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
  scores = mtf.einsum([query, keys],
                      output_shape=score_shape)  # [b, l, h, 2, n_keys]
  knn_dim = mtf.Dimension("knn", self.knn)
  scores, indices = mtf.top_k(scores, score_shape.dims[-1],
                              knn_dim)  # [b, l, h, 2, knn]

  # Computes the top cartesian products and their indices
  knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)
  scores1, scores2 = mtf.unstack(scores, scores.shape.dims[-2])
  scores2 = mtf.rename_dimension(scores2, "knn", "knn2")
  out_shape = mtf.Shape(scores1.shape.dims + scores2.shape.dims[-1:])
  all_scores = mtf.add(scores1, scores2, output_shape=out_shape)
  all_scores = mtf.replace_dimensions(all_scores, out_shape[-2:],
                                      knn_square_dim)

  indices1, indices2 = mtf.unstack(indices, indices.shape.dims[-2])
  indices1 = mtf.multiply(indices1, self.n_keys)
  indices2 = mtf.rename_dimension(indices2, "knn", "knn2")
  all_indices = mtf.add(indices1, indices2, output_shape=out_shape)
  all_indices = mtf.replace_dimensions(all_indices, out_shape[-2:],
                                       knn_square_dim)

  scores, best_indices = mtf.top_k(all_scores, all_scores.shape.dims[-1],
                                   knn_dim)
  return scores, mtf.gather(all_indices, best_indices, knn_square_dim)
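# The cartesian-product trick in get_indices (product-key memory style) can be
# hard to follow in Mesh TensorFlow shapes. The standalone NumPy sketch below is
# an illustrative assumption of the same idea, not part of the layer: take the
# top-k over each half-key score vector, add every pair from the two top-k
# lists, recover the flat key index as i * n_keys + j, and re-select the top-k
# among the k*k candidates.
import numpy as np

def product_key_topk_sketch(scores1, scores2, knn):
  """scores1, scores2: [n_keys] score vectors for the two half-keys."""
  n_keys = scores1.shape[0]
  top1 = np.argsort(-scores1)[:knn]          # indices of the best first halves
  top2 = np.argsort(-scores2)[:knn]          # indices of the best second halves
  # All knn x knn combined scores and their flat indices in the full key table.
  all_scores = scores1[top1][:, None] + scores2[top2][None, :]   # [knn, knn]
  all_indices = top1[:, None] * n_keys + top2[None, :]           # [knn, knn]
  flat = np.argsort(-all_scores.ravel())[:knn]
  return all_scores.ravel()[flat], all_indices.ravel()[flat]

# Example: with n_keys=4 half-keys per table, the full table has 16 keys.
s1, s2 = np.array([0.1, 2.0, 0.3, 1.0]), np.array([1.5, 0.2, 0.9, 0.0])
best_scores, best_keys = product_key_topk_sketch(s1, s2, knn=2)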
def mdha_shared_qk(self, query_antecedent, context):
  """MDHA QK shared projection."""
  ret = mtf.layers.us_einsum([query_antecedent, self.wq],
                             reduced_dims=[self.query_input_dim])
  with tf.variable_scope("qk_dconv"):
    len_dim = context.length_dim
    context.length_dim = ret.shape.dims[-2]
    ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
    context.length_dim = len_dim
  q = mtf.layers.dense(
      ret,
      ret.shape.dims[-1:],
      use_bias=False,
      activation=None,
      variable_dtype=context.variable_dtype,
      reduced_dims=ret.shape.dims[-1:],
      name="q_solo_project",
      expert_dims=context.model.ensemble_dims)
  k = ret
  if self.combine_dims:
    q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
    k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
  if not self.fold_scaling_into_initializer:
    q *= self.key_dim.size**-0.5
  return q, k
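# A rough, illustrative NumPy sketch of what a causal depthwise 1-D convolution
# (as invoked via causal_depthwise_conv above) does; this is an assumption about
# the behavior, not the mtf implementation: left-pad the sequence so position t
# only mixes positions <= t, and use one small kernel per channel instead of a
# dense channel mix.
import numpy as np

def causal_depthwise_conv_sketch(x, kernels):
  """x: [length, channels]; kernels: [kernel_size, channels] per-channel taps."""
  kernel_size, channels = kernels.shape
  padded = np.concatenate([np.zeros((kernel_size - 1, channels)), x], axis=0)
  out = np.zeros_like(x)
  for t in range(x.shape[0]):
    window = padded[t:t + kernel_size]        # window ending at t; no future leak
    out[t] = (window * kernels).sum(axis=0)   # depthwise: per-channel taps only
  return out

x = np.random.randn(8, 4)
k = np.random.randn(3, 4)                     # kernel_size=3, as in the layer
y = causal_depthwise_conv_sketch(x, k)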
def _compute_merge_qkv(self, antecedent):
  """Computes qkv all in one call using MoE layer."""
  # NOTE: This assumes query and memory antecedent are the same.
  qk = self.moe_layer.call(self.context, antecedent)
  # Split qk here since they went through the expert layers.
  q, k = qk
  # Scale query
  q *= self.key_dim.size ** -0.5
  self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
  self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
def call(self, context, x, losses=None):
  """Call the layer."""
  memory_length = self.memory_length(context)
  q = self.compute_q(context, x)
  if context.mode == "incremental":
    m = x
  else:
    m = mtf.replace_dimensions(x, context.length_dim, memory_length)
  k = self.compute_k(context, m)
  v = self.compute_v(context, m)
  if context.mode == "incremental":
    one_hot = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    inv_one_hot = 1.0 - one_hot
    old_k, old_v = context.get_states(2)
    k = old_k * inv_one_hot + k * one_hot
    v = old_v * inv_one_hot + v * one_hot
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)
  if context.mode == "incremental" or context.mode == "first_part":
    context.record_new_states([k, v])
  bias = self.compute_bias(context, memory_position, x,
                           self.softmax_heads_dims, q)
  return self.attention_internal(context, x, m, q, k, v, memory_length, bias)
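# The incremental branches in these call() methods update the key/value cache by
# blending the old cache with the new projection through a one-hot vector at the
# current decode position. A minimal NumPy sketch of that update rule
# (illustrative only; the real tensors also carry batch and head dimensions):
import numpy as np

def one_hot_cache_update(old_cache, new_value, position):
  """old_cache: [memory_length, d]; new_value: [d]; returns the updated cache."""
  memory_length = old_cache.shape[0]
  one_hot = np.eye(memory_length)[position][:, None]   # [memory_length, 1]
  inv_one_hot = 1.0 - one_hot
  # Every slot except `position` keeps its old entry; `position` takes the new one.
  return old_cache * inv_one_hot + new_value[None, :] * one_hot

cache = np.zeros((4, 2))
cache = one_hot_cache_update(cache, np.array([1.0, 2.0]), position=1)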
def call(self, context, x, losses=None):
  """Call the layer."""
  params = self.make_params(context)
  q = params.compute_q(x)
  memory_length = self.memory_length(context)
  if context.mode == "incremental":
    m = x
  else:
    m = mtf.replace_dimensions(x, context.length_dim, memory_length)
  if self.shared_kv:
    kv = params.compute_kv(m)
  else:
    k = params.compute_k(m)
    v = params.compute_v(m)
  if context.mode == "incremental":
    one_hot = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    inv_one_hot = 1.0 - one_hot
    if self.shared_kv:
      old_kv = context.get_states(1)
      kv = old_kv * inv_one_hot + kv * one_hot
    else:
      old_k, old_v = context.get_states(2)
      k = old_k * inv_one_hot + k * one_hot
      v = old_v * inv_one_hot + v * one_hot
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)
  if context.mode == "incremental" or context.mode == "first_part":
    context.record_new_states([kv] if self.shared_kv else [k, v])
  if self.shared_kv:
    k = kv
    v = kv
  if self.attention_func == "hybrid":
    o = attention.hybrid_attention(
        q, k, v, context, memory_length, self.kv_dim, self.kv_dim,
        self.compute_bias(context, memory_position, x,
                          params.query_heads_dims),
        **self.attention_kwargs_from_context(context))
  else:
    o = attention.attention(
        q, k, v, memory_length, self.kv_dim, self.kv_dim,
        self.compute_bias(context, memory_position, x,
                          params.query_heads_dims),
        **self.attention_kwargs_from_context(context))
  return params.compute_output(o, output_shape=x.shape)
def mdha_v(self, memory_antecedent, context):
  """MDHA V projection."""
  ret = mtf.layers.us_einsum([memory_antecedent, self.wv],
                             reduced_dims=[self.memory_input_dim])
  with tf.variable_scope("v_dconv"):
    len_dim = context.length_dim
    context.length_dim = ret.shape.dims[-2]
    ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
    context.length_dim = len_dim
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.v_dims)
  return ret
def mdha_q(self, query_antecedent, context):
  """MDHA Q projection."""
  ret = mtf.layers.us_einsum([query_antecedent, self.wq],
                             reduced_dims=[self.query_input_dim])
  with tf.variable_scope("q_dconv"):
    len_dim = context.length_dim
    context.length_dim = ret.shape.dims[-2]
    ret = causal_depthwise_conv(ret, context=context, kernel_size=3)
    context.length_dim = len_dim
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
  if not self.fold_scaling_into_initializer:
    ret *= self.key_dim.size**-0.5
  return ret
def compute_q(self, query_antecedent):
  """Compute query Tensor q.

  Args:
    query_antecedent: a Tensor with dimensions {query_input_dim} + other_dims

  Returns:
    a Tensor with dimensions query_heads_dims + {key_dim} + other_dims
  """
  ret = mtf.einsum([query_antecedent, self.wq],
                   reduced_dims=[self.query_input_dim])
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
  return ret
def compute_kv(self, memory_antecedent):
  """Compute key/value Tensor kv.

  Args:
    memory_antecedent: a Tensor with dimensions {memory_input_dim} + other_dims

  Returns:
    a Tensor with dimensions memory_heads_dims + {key_dim} + other_dims
  """
  if not self.shared_kv:
    raise ValueError("compute_kv can only be called with shared_kv")
  ret = mtf.einsum([memory_antecedent, self.wkv],
                   reduced_dims=[self.memory_input_dim])
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.k_dims)
  return ret
def compute_q(self, query_antecedent):
  """Compute query Tensor q.

  Args:
    query_antecedent: a Tensor with dimensions {query_input_dim} + other_dims

  Returns:
    a Tensor with dimensions query_heads_dims + {key_dim} + other_dims
  """
  ret = mtf.layers.us_einsum(
      [query_antecedent, self.wq], reduced_dims=[self.query_input_dim])
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
  if not self.fold_scaling_into_initializer:
    ret *= self.key_dim.size ** -0.5
  return ret
def compute_v(self, memory_antecedent):
  """Compute value Tensor v.

  Args:
    memory_antecedent: a Tensor with dimensions {memory_input_dim} + other_dims

  Returns:
    a Tensor with dimensions memory_heads_dims + {value_dim} + other_dims
  """
  if self.shared_kv:
    raise ValueError("compute_v cannot be called with shared_kv")
  ret = mtf.layers.us_einsum(
      [memory_antecedent, self.wv], reduced_dims=[self.memory_input_dim])
  if self.combine_dims:
    ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.v_dims)
  return ret
def compute_output(self, o, output_shape=None):
  """Compute output of multihead attention.

  Args:
    o: a Tensor with dimensions query_heads_dims + {value_dim} + other_dims
    output_shape: an optional Shape

  Returns:
    a Tensor with shape: {output_dim} + other_dims
  """
  if self.combine_dims:
    o = mtf.transpose(o, o.shape - self.o_dims + self.o_dims)
    o = mtf.replace_dimensions(o, self.o_dims, self.wo.shape.dims[-2])
    reduced_dims = [self.wo.shape.dims[-2]]
  else:
    reduced_dims = self.o_dims
  return mtf.einsum(
      [o, self.wo], output_shape=output_shape, reduced_dims=reduced_dims)
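# compute_output relies on the fact that flattening the head and value axes and
# doing one matmul is the same contraction as reducing over the separate axes.
# A hedged NumPy sketch of that equivalence (names and shapes are illustrative,
# not the mtf API):
import numpy as np

batch, length, heads, value_dim, d_model = 2, 5, 3, 4, 6
o = np.random.randn(batch, length, heads, value_dim)
wo = np.random.randn(heads, value_dim, d_model)

# combine_dims-style path: flatten (heads, value_dim), then a single matmul.
combined = o.reshape(batch, length, heads * value_dim) @ wo.reshape(
    heads * value_dim, d_model)
# Separate-dims path: reduce over heads and value_dim directly.
separate = np.einsum("blhv,hvd->bld", o, wo)
assert np.allclose(combined, separate)        # both give [batch, length, d_model]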
def call(self, context, x, losses=None):
  """Call the layer."""
  memory_length = self.memory_length(context)
  q = self.compute_q(context, x)
  if context.mode == "incremental":
    m = x
  else:
    m = mtf.replace_dimensions(x, context.length_dim, memory_length)
  if context.mode == "incremental":
    one_hot = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    inv_one_hot = 1.0 - one_hot
    old_m, = context.get_states(1)
    m = old_m * inv_one_hot + one_hot * m
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)
  if context.mode == "incremental" or context.mode == "first_part":
    context.record_new_states([m])
  bias = self.compute_bias(context, memory_position, x, self.heads_dims, q)
  return self.attention_internal(context, q, m, memory_length, bias)
def self_attention(self, x, attention_bias):
  """Performs multi-headed self-attention with output projection.

  Args:
    x: output of previous layer
    attention_bias: optional float32 Tensor broadcastable to shape
      x.shape - self.model_dim + self.memory_seq_dim,
      to be added to attention logits.
      This may be used to mask out padding regions of the memory.

  Returns:
    float Tensor with the same shape as x
  """
  queries = mtf.layers.dense(
      x,
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="query",
      use_bias=self.config.use_bias)
  keys = mtf.layers.dense(
      mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="key",
      use_bias=self.config.use_bias)
  values = mtf.layers.dense(
      mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="value",
      use_bias=self.config.use_bias)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  attention_scores = mtf.einsum(
      [queries, keys], reduced_dims=[self.size_per_head_dim])
  attention_scores *= self.size_per_head_dim.size ** -0.5

  if attention_bias is not None:
    attention_scores += attention_bias

  # Normalize the attention scores to probabilities.
  attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = mtf.dropout(
      attention_probs,
      is_training=(self.config.attention_probs_dropout_prob == 0.0),
      keep_prob=1.0 - self.config.attention_probs_dropout_prob)

  output = mtf.einsum([attention_probs, values],
                      reduced_dims=[self.memory_seq_dim])

  # linear transformation back to shape of query_antecedent
  output = mtf.layers.dense(
      output,
      reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
      new_dims=[self.model_dim],
      kernel_initializer=self.dense_initializer,
      name="output",
      use_bias=self.config.use_bias)
  output = mtf.transpose(output, x.shape)
  return output
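# For reference, the core of self_attention above is plain scaled dot-product
# attention. A compact NumPy sketch of the math (single head, no dropout and no
# input/output projections; an illustration, not the mtf implementation):
import numpy as np

def scaled_dot_product_attention_sketch(q, k, v, bias=None):
  """q: [lq, d]; k, v: [lm, d]; bias: optional additive [lq, lm] logit mask."""
  d = q.shape[-1]
  scores = q @ k.T * d ** -0.5                  # raw logits, scaled by 1/sqrt(d)
  if bias is not None:
    scores = scores + bias                      # e.g. -10000.0 at padded positions
  probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
  probs = probs / probs.sum(axis=-1, keepdims=True)   # softmax over the memory axis
  return probs @ v

q = np.random.randn(4, 8)
kv = np.random.randn(6, 8)
out = scaled_dot_product_attention_sketch(q, kv, kv)  # [4, 8]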
def __init__(self,
             config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             scope=None,
             mesh_shape="",
             layout=""):
  self.config = copy.deepcopy(config)
  del config
  if not is_training:
    self.config.layer_output_dropout_prob = 0.0
    self.config.attention_probs_dropout_prob = 0.0
    self.config.feedforward_intermediate_dropout_prob = 0.0
  input_shape = input_ids.shape
  assert input_shape.ndims == 2

  self._seq_dim = input_shape.dims[1]
  self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
  self._extra_losses = []
  mesh = input_ids.mesh

  if token_type_ids is None:
    token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

  with tf.variable_scope(scope, default_name="bert"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      self.embedding_table = mtf.get_variable(
          mesh, "word_embeddings",
          mtf.Shape([self.vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      self.word_embedding_output = mtf.gather(
          self.embedding_table, input_ids, self.vocab_dim)

      # Add positional embeddings and token type embeddings, then layer
      # normalize and perform dropout.
      self.embedding_output = self.word_embedding_output

      token_type_table = mtf.get_variable(
          mesh, "token_type_embeddings",
          mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      if token_type_ids is not None:
        self.embedding_output += mtf.gather(
            token_type_table, token_type_ids, self.token_type_vocab_dim)
      if self.config.position_signal == "embedding":
        full_position_table = mtf.get_variable(
            mesh, "position_embeddings",
            mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        short_position_table = mtf.rename_dimension(
            mtf.slice(full_position_table, 0, self.seq_dim.size,
                      self.max_position_embeddings_dim.name),
            self.max_position_embeddings_dim.name, self.seq_dim.name)
        self.embedding_output += short_position_table
      self.embedding_output = self.normalize(self.embedding_output)
      self.embedding_output = mtf.dropout(
          self.embedding_output, is_training,
          keep_prob=1.0 - self.config.layer_output_dropout_prob)

    with tf.variable_scope("encoder"):
      attention_biases = []
      if input_mask:
        # [batch_dim, memory_seq_dim]
        attention_biases.append(
            (1.0 - mtf.to_float(mtf.replace_dimensions(
                input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
      if self.config.position_signal == "relative_attention_bias":
        buckets_dim = mtf.Dimension("buckets", 32)
        rp_bucket = _relative_position_bucket(
            mtf.range(mesh, self.memory_seq_dim, tf.int32)
            - mtf.range(mesh, self.seq_dim, tf.int32),
            num_buckets=buckets_dim.size)
        bias_var = mtf.get_variable(
            mesh, "relative_attention_bias",
            [self.num_heads_dim, buckets_dim],
            initializer=tf.zeros_initializer())
        attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
      attention_bias = mtf.add_n(attention_biases)

      prev_layer_output = self.embedding_output
      self.all_encoder_layers = []
      for block_num in range(self.config.num_blocks):
        with tf.variable_scope("block_%d" % block_num):
          for layer_idx, layer_type in enumerate(self.config.block_layers):
            layer_name = layer_type
            count = self.config.block_layers[:layer_idx].count(layer_type)
            if count:
              layer_name += "_%d" % count
            with tf.variable_scope(layer_name):
              x = prev_layer_output
              if self.config.residual_structure == "direct":
                x = self.normalize(x)
              if layer_type == "attention":
                x = self.self_attention(x, attention_bias)
              elif layer_type == "feedforward":
                x = self.feedforward(x)
              elif layer_type == "moe":
                x = self.moe(x, layout, mesh_shape, input_mask, is_training)
              else:
                raise ValueError("unknown layer type " + layer_type)
              x = mtf.dropout(
                  x, is_training,
                  keep_prob=1.0 - self.config.layer_output_dropout_prob)
              layer_output = prev_layer_output + x
              if self.config.residual_structure == "original":
                layer_output = self.normalize(layer_output)
              prev_layer_output = layer_output
              self.all_encoder_layers.append(layer_output)

    self.sequence_output = prev_layer_output
    if self.config.residual_structure == "direct":
      self.sequence_output = self.normalize(self.sequence_output)

    # The "pooler" converts the encoded sequence tensor of shape
    # [batch_dim, seq_dim, hidden_size] to a tensor of shape
    # [batch_dim, hidden_size]. This is necessary for segment-level
    # (or segment-pair-level) classification tasks where we need a fixed
    # dimensional representation of the segment.
    with tf.variable_scope("pooler"):
      # We "pool" the model by simply taking the hidden state corresponding
      # to the first token. We assume that this has been pre-trained.
      first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
      self.pooled_output = mtf.layers.dense(
          first_token_tensor,
          reduced_dims=[self.model_dim],
          new_dims=[self.model_dim],
          activation=mtf.tanh,
          kernel_initializer=self.dense_initializer,
          use_bias=self.config.use_bias)
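# The encoder above converts a 0/1 input mask into an additive attention bias:
# positions with mask 0 receive a large negative logit so softmax assigns them
# (almost) zero weight. A tiny NumPy sketch of that conversion (illustrative):
import numpy as np

input_mask = np.array([1, 1, 1, 0, 0])          # 1 = real token, 0 = padding
attention_bias = (1.0 - input_mask.astype(np.float32)) * -10000.0
# attention_bias == [0., 0., 0., -10000., -10000.]; adding it to the logits
# before softmax effectively masks out the padded memory positions.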
def rename_length_to_memory_length(self, x, context):
  return mtf.replace_dimensions(
      x, context.length_dim, self.memory_length(context))
def _my_reshape(x):
  if x and resplittable_dim in x.shape.dims:
    return mtf.replace_dimensions(
        x, resplittable_dim, [new_dim_high, new_dim_low])
  else:
    return x
def _reshape_memory(x):
  x = mtf.replace_dimensions(
      x, length_dim, [num_blocks, memory_block_length])
  return (mtf.left_halo_exchange if fully_autoregressive
          else mtf.halo_exchange)(
              x, num_blocks, memory_block_length, radius)
def _reshape_query(x):
  return mtf.replace_dimensions(
      x, length_dim, [num_blocks, query_block_length])
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()

  Returns:
    a Tensor with the shape x.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length, key_dim, value_dim, bias,
                **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                length_dim)
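# The block-size search at the top of local_attention_1d picks the largest
# divisor of length_per_split that does not exceed max(radius, 128), so that
# each shard of the sequence splits evenly into equal-length blocks. A
# standalone sketch of just that search (illustrative, mirroring the loop above):
def choose_block_length(length_per_split, radius):
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  return block_length

assert choose_block_length(1024, 96) == 128    # cap is 128, which divides 1024
assert choose_block_length(1536, 200) == 192   # largest divisor of 1536 <= 200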
def combine_batch_dims(self, x):
  if len(self.batch_dims) <= 1:
    return x
  return mtf.replace_dimensions(
      x, self.batch_dims, mtf.combined_dimension(self.batch_dims))
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq,
         memory_length_dim, variable_dtype, context=None):
  # x :: [batch, seq, n_embd]
  x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

  # n_state is the same as config["n_embd"], which is also the same as dim_embd.
  assert n_state.size % params["n_head"] == 0

  dim_heads = mtf.Dimension("heads", params["n_head"])

  num_mem_kv = params.get("num_mem_kv", 0)
  use_num_mem_kv = num_mem_kv > 0

  with tf.variable_scope(scope):
    # Compute attention inputs
    dim_kv = mtf.Dimension("features_per_head",
                           params["n_embd"] // params["n_head"])
    mtfparams = mtf.transformer.attention.attention_params_simple(
        x.mesh,
        io_dim=dim_embd,
        kv_dim=dim_kv,
        heads_dim=dim_heads,
        variable_dtype=variable_dtype
    )
    q = mtfparams.compute_q(x)
    k = mtfparams.compute_k(x)
    v = mtfparams.compute_v(x)

    if is_incremental_inference(context):
      one_hot = mtf.one_hot(context.position - 1, dim_seq,
                            dtype=variable_dtype.master_dtype)
      inv_one_hot = 1.0 - one_hot
      old_k, old_v = context.get_states(2)
      k = old_k * inv_one_hot + k * one_hot
      v = old_v * inv_one_hot + v * one_hot

    if exists(context):
      context.record_new_states([k, v])

    with tf.variable_scope("attention"):
      if attention_type == "local":
        # `local_attention_1d` has built-in autoregressive masking, so we
        # don't need mask_attn_weights.
        radius = params.get("local_attention_radius", 256)
        if is_incremental_inference(context):
          q *= one_hot
        a = mtf_transformer.attention.local_attention_1d(
            q, k, v,
            length_dim=k.shape[1],
            key_dim=dim_kv,
            value_dim=dim_kv,
            radius=radius,
            length_dim_num_splits=1,
            fully_autoregressive=params["causal"],
            attention_kwargs={},
        )
        if is_incremental_inference(context):
          a = mtf.gather(a, context.position - 1, dim_seq)

      elif attention_type == "global":
        # TODO: pass in fake context
        # Broadcast mask bias across batch and heads
        if exists(bias):
          if not is_incremental_inference(context):
            broadcasted_bias = mtf.broadcast(
                bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
          else:
            # In the incremental case, a custom mask needs to be built that
            # masks out all key/values greater than the current position.
            bias = mtf.gather(bias, context.position - 1, dim_seq)
            broadcasted_bias = mtf.broadcast(
                bias, [dim_batch, dim_heads, bias.shape[-1]])

        # memory key / values, from all-attention paper
        if use_num_mem_kv:
          k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads,
                                   variable_dtype, mesh)

        k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
        v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

        attn_dropout_rate = (
            params["attn_dropout"] if params["mode"] == "train" else 0)

        a = mtf_transformer.attention.attention(
            q, k, v,
            memory_length_dim=memory_length_dim,
            key_dim=dim_kv,
            value_dim=dim_kv,
            bias=broadcasted_bias,
            dropout_rate=attn_dropout_rate
        )

      elif attention_type == "linear":
        linear_attn_fn = (
            causal_linear_attention if params["causal"] else linear_attention)
        a = linear_attn_fn(q, k, v)

      else:
        raise NotImplementedError(
            "Unknown attention type {}!".format(attention_type))

    with tf.variable_scope("compute_output"):
      a = mtfparams.compute_output(a, x_shape)

    with tf.variable_scope("compute_output_bias"):
      b = mtf.get_variable(
          x.mesh, "o_b", [dim_embd],
          initializer=tf.constant_initializer(0),
          master_dtype=variable_dtype.master_dtype,
          slice_dtype=variable_dtype.slice_dtype,
          activation_dtype=variable_dtype.activation_dtype)
      a += b

    if params["mode"] == "train" and params["res_dropout"] > 0:
      a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
    return a
def call(self, context, x, losses=None):
  """Call the layer."""
  params = self.make_params(context)
  if self.share_qk_rep:
    q, k = params.mdha_shared_qk(x, context)
  else:
    q = params.mdha_q(x, context)
  memory_length = self.memory_length(context)
  if context.mode == "incremental":
    m = x
  else:
    if self.share_qk_rep:
      k = mtf.replace_dimensions(k, context.length_dim, memory_length)
    m = mtf.replace_dimensions(x, context.length_dim, memory_length)
  if self.shared_kv:
    kv = params.compute_kv(m)
  else:
    if not self.share_qk_rep:
      k = params.mdha_k(m, context)
    v = params.mdha_v(m, context)
  if context.mode == "incremental":
    one_hot = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    inv_one_hot = 1.0 - one_hot
    if self.shared_kv:
      old_kv = context.get_states(1)
      kv = old_kv * inv_one_hot + kv * one_hot
    else:
      old_k, old_v = context.get_states(2)
      k = old_k * inv_one_hot + k * one_hot
      v = old_v * inv_one_hot + v * one_hot
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)
  if context.mode == "incremental" or context.mode == "first_part":
    context.record_new_states([kv] if self.shared_kv else [k, v])
  if self.shared_kv:
    k = kv
    v = kv
  o = self.attention_fn(
      q, k, v,
      context=context,
      memory_length_dim=memory_length,
      key_dim=self.kv_dim,
      value_dim=self.kv_dim,
      bias=self.compute_bias(context, memory_position, x,
                             params.query_heads_dims, q),
      **self.attention_kwargs_from_context(context))
  attention_output_shape = self.expected_attention_output_shape(x, params)
  attention_output = params.compute_output(
      o, output_shape=attention_output_shape)
  return self.layer_output_from_attention_output(context, attention_output,
                                                 losses)
def attention(self, x, n_state, mask, attention_type="global", name="attn"):
  # x :: [batch, seq, n_embd]
  batch_dim, seq_dim, embd_dim = x_shape = x.shape
  assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"

  with tf.variable_scope(name):
    # Compute attention inputs
    mtfparams = mtf.transformer.attention.attention_params_simple(
        x.mesh,
        io_dim=self.dimensions["embed_dim"],
        kv_dim=self.dimensions["kv_dim"],
        heads_dim=self.dimensions["heads_dim"],
        variable_dtype=self.variable_dtype)
    q = mtfparams.compute_q(x)
    k = mtfparams.compute_k(x)
    v = mtfparams.compute_v(x)

    if self.is_incremental_inference:
      one_hot = mtf.one_hot(self.context.position - 1, seq_dim,
                            dtype=self.variable_dtype.master_dtype)
      inv_one_hot = 1.0 - one_hot
      old_k, old_v = self.context.get_states(2)
      k = old_k * inv_one_hot + k * one_hot
      v = old_v * inv_one_hot + v * one_hot

    if exists(self.context):
      self.context.record_new_states([k, v])

    with tf.variable_scope("attention"):
      if attention_type == "local":
        # `local_attention_1d` has built-in autoregressive masking, so we
        # don't need mask_attn_weights.
        radius = self.params.get("local_attention_radius", 256)
        if self.is_incremental_inference:
          q *= one_hot
        a = mtf_transformer.attention.local_attention_1d(
            q, k, v,
            length_dim=k.shape[1],
            key_dim=self.dimensions["kv_dim"],
            value_dim=self.dimensions["kv_dim"],
            radius=radius,
            length_dim_num_splits=1,
            fully_autoregressive=True,
            attention_kwargs={},
        )
        if self.is_incremental_inference:
          a = mtf.gather(a, self.context.position - 1, seq_dim)

      elif attention_type == "global":
        if exists(mask):
          if not self.is_incremental_inference:
            broadcasted_mask = mtf.broadcast(
                mask,
                [batch_dim, self.dimensions["heads_dim"],
                 mask.shape[-2], mask.shape[-1]])  # TODO: not sure this is correct
          else:
            # In the incremental case, a custom mask needs to be built that
            # masks out all key/values greater than the current position.
            mask = mtf.gather(mask, self.context.position - 1, seq_dim)
            broadcasted_mask = mtf.broadcast(
                mask,
                [batch_dim, self.dimensions["heads_dim"], mask.shape[-1]])

        k = mtf.replace_dimensions(
            k, k.shape[1], self.dimensions["memory_len_dim"])
        v = mtf.replace_dimensions(
            v, v.shape[1], self.dimensions["memory_len_dim"])

        attn_dropout_rate = self.params.get(
            "attention_dropout", 0) if self.mode == "train" else 0

        a = mtf_transformer.attention.attention(
            q, k, v,
            memory_length_dim=self.dimensions["memory_len_dim"],
            key_dim=self.dimensions["kv_dim"],
            value_dim=self.dimensions["kv_dim"],
            bias=broadcasted_mask,
            dropout_rate=attn_dropout_rate)

      else:
        raise NotImplementedError(
            "Unknown attention type {}!".format(attention_type))

    with tf.variable_scope("compute_output"):
      a = mtfparams.compute_output(a, x_shape)

    with tf.variable_scope("compute_output_bias"):
      b = mtf.get_variable(
          x.mesh, "o_b", [embd_dim],
          initializer=tf.constant_initializer(0),
          master_dtype=self.variable_dtype.master_dtype,
          slice_dtype=self.variable_dtype.slice_dtype,
          activation_dtype=self.variable_dtype.activation_dtype)
      a += b

    residual_dropout = self.params.get("residual_dropout", 0)
    if self.mode == "train" and residual_dropout > 0:
      a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")

    return a