def compute_mask(self, context, memory_position):
  """Compute attention mask.

  Args:
    context: a transformer.Context
    memory_position: an int32 Tensor whose shape contains the
      memory_length dimension.

  Returns:
    a Tensor or None
  """
  masks = []
  min_relative_position = self.min_relative_position(context)
  max_relative_position = self.max_relative_position(context)
  if max_relative_position is not None or min_relative_position is not None:
    relative_position = memory_position - context.position
    if min_relative_position is not None:
      illegal = mtf.less(relative_position, min_relative_position)
      masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
    if max_relative_position is not None:
      illegal = mtf.greater(relative_position, max_relative_position)
      masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
  if (context.sequence_id is not None and
      isinstance(context.sequence_id, mtf.Tensor) and
      context.length_dim in context.sequence_id.shape):
    masks.append(mtf.cast(
        mtf.not_equal(
            context.sequence_id,
            self.rename_length_to_memory_length(
                context.sequence_id, context)),
        context.activation_dtype) * -1e9)
  return mtf.add_n(masks) if masks else None
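# Illustrative sketch (not part of the library): how the relative-position
# window in compute_mask above becomes an additive mask.  Plain NumPy stands
# in for mtf; the length and window bounds are made-up values.
import numpy as np

length = 6
position = np.arange(length)[:, None]          # query positions
memory_position = np.arange(length)[None, :]   # memory positions
relative_position = memory_position - position

min_rel, max_rel = -2, 0  # hypothetical local-attention window
illegal = (relative_position < min_rel) | (relative_position > max_rel)
mask = illegal.astype(np.float32) * -1e9
# Adding `mask` to the attention logits drives illegal positions to ~zero
# weight after the softmax.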
def clip_by_global_norm(grads, clip_norm):
  """Clip the gradients so that their global (l2) norm is at most clip_norm."""
  global_norm = mtf.sqrt(
      mtf.add_n(
          [mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
  # If global_norm <= clip_norm, the multiplier is 1 and the gradients pass
  # through unchanged; otherwise they are scaled down to norm clip_norm.
  multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
  clipped_grads = [None if t is None else t * multiplier for t in grads]
  return clipped_grads, global_norm
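# Minimal NumPy sketch of the same global-norm clipping math as
# clip_by_global_norm above (assumptions: dense gradients, no None entries).
import numpy as np

def clip_by_global_norm_np(grads, clip_norm):
  global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
  multiplier = clip_norm / max(global_norm, clip_norm)
  return [g * multiplier for g in grads], global_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]  # global norm = 13
clipped, norm = clip_by_global_norm_np(grads, clip_norm=1.0)
# norm == 13.0; every gradient is scaled by 1/13, which preserves the
# overall update direction.  With norm <= clip_norm, nothing changes.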
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                sequence_id=None,
                position=None,
                encoder_output=None,
                encoder_sequence_id=None,
                shared_params=None):
  """Compute logits based on inputs (all positions in parallel).

  This is called during training and evaluation.

  Args:
    inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      For training autoregressive models this should be equal to
      mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
    compute_loss: a boolean
    mode: a tf.estimator.ModeKeys
    variable_dtype: a mtf.VariableDType
    sequence_id: an optional Tensor
    position: an optional Tensor
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    shared_params: an optional dictionary

  Returns:
    logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
    loss: an optional Scalar (if compute_loss=True)
  """
  context = Context(
      mesh=inputs.mesh,
      batch_dims=inputs.shape.dims[:-1],
      length_dim=inputs.shape.dims[-1],
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode=mode,
      autoregressive=self.autoregressive,
      losses=[] if compute_loss else None,
      sequence_id=sequence_id,
      position=position,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context, inputs, targets)
  if compute_loss:
    loss = mtf.add_n(context.losses)
  else:
    loss = None
  return logits, loss
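# Illustrative sketch of the inputs/targets relationship described in the
# docstring above; np.roll stands in for mtf.shift and the token ids are
# made up.
import numpy as np

targets = np.array([17, 42, 9, 1])   # sequence token ids, EOS = 1
inputs = np.roll(targets, 1)
inputs[0] = 0                        # wrap=False: the shifted-in slot is 0
# inputs == [0, 17, 42, 9]: at step t the model reads inputs[t] and is
# trained to predict targets[t] (standard teacher forcing).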
def talking_heads(
    self, context, inp, name, input_heads_dims, output_heads_dims,
    dynamic_projections_from=None):
  """Linearly project between sets of heads dimensions (talking heads).

  If dynamic_projections_from is given, additional input-dependent
  projections are computed from those tensors and added to the static one.
  """
  shared_dims = [d for d in input_heads_dims if d in output_heads_dims]
  reduced_dims = [d for d in input_heads_dims if d not in output_heads_dims]
  new_dims = [d for d in output_heads_dims if d not in input_heads_dims]
  if not (reduced_dims or new_dims):
    # Output dimensions are the same as input dimensions.  Return the input.
    return inp
  elif dynamic_projections_from:
    # There are one or more dynamic talking-heads projections.
    with tf.variable_scope(name):
      # Static projection - this is the same as the static projection in the
      # "else" case below.  We create the weight matrix with get_variable
      # instead of calling mtf.layers.dense() so that we can fold the
      # static projection into one of the dynamic projections.
      static_p_initializer = mtf.layers.VarianceScalingInitializer()(
          reduced_dims, new_dims)
      static_p_shape = (
          context.model.ensemble_dims + shared_dims + reduced_dims + new_dims)
      static_p = mtf.get_variable(
          inp.mesh,
          "kernel",
          static_p_shape,
          initializer=static_p_initializer,
          dtype=context.variable_dtype)
      ps = []
      for i, dp_from in enumerate(dynamic_projections_from):
        kernel_initializer = mtf.layers.VarianceScalingInitializer(
            self.dynamic_projections_init_scale
            / mtf.Shape(reduced_dims).size)
        ps.append(
            mtf.layers.dense(
                dp_from, reduced_dims=[context.model.model_dim],
                new_dims=shared_dims + reduced_dims + new_dims,
                use_bias=False, activation=None,
                variable_dtype=context.variable_dtype,
                name="%s_dynamic_%d" % (name, i),
                expert_dims=context.model.ensemble_dims,
                kernel_initializer=kernel_initializer))
      # Fold the static projection into one of the dynamic projections.
      # Mathematically, we could add all the dynamic projections together
      # here, but it would create a very large tensor which contained
      # both the query-length and memory-length dimensions, and would
      # probably be slower in practice.
      ps[0] += static_p
      return mtf.add_n(
          [mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps])
  else:
    # No dynamic projections.  Static talking-heads projection only.
    return mtf.layers.dense(
        inp, reduced_dims=reduced_dims,
        new_dims=new_dims,
        use_bias=False, activation=None,
        variable_dtype=context.variable_dtype,
        name=name, expert_dims=context.model.ensemble_dims + shared_dims)
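# Sketch (plain NumPy, made-up 2-D shapes) of why folding the static
# projection into ps[0] above is valid: einsum is linear in the projection
# matrix, so einsum(x, P_static) + einsum(x, P_dyn) == einsum(x, P_static +
# P_dyn).  The real dynamic projections carry extra batch/length dims that
# broadcast against the static kernel; the linearity argument is unchanged.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # [query, input_heads]
p_static = rng.normal(size=(8, 6))   # [input_heads, output_heads]
p_dynamic = rng.normal(size=(8, 6))

separate = (np.einsum("qi,io->qo", x, p_static)
            + np.einsum("qi,io->qo", x, p_dynamic))
folded = np.einsum("qi,io->qo", x, p_static + p_dynamic)
assert np.allclose(separate, folded)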
def call(self, context, x, losses=None):
  """Call the layer."""
  wq, wk, wv, wo = mtf.layers.multihead_attention_params(
      context.mesh, self.heads_dim, context.model_dim,
      self.kv_dim, context.variable_dtype)
  memory_length = mtf.Dimension("memory_length", context.length_dim.size)
  q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
  if context.mode == "incremental":
    m = x
  else:
    m = mtf.rename_dimension(x, context.length_dim.name, "memory_length")
  k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
  v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
  if context.mode == "incremental":
    # Write the new key/value into the cache at the current decoding
    # position, leaving all other positions untouched.
    old_k, old_v = context.get_states(2)
    one_hot = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    inv_one_hot = 1.0 - one_hot
    k = old_k * inv_one_hot + k * one_hot
    v = old_v * inv_one_hot + v * one_hot
  if context.mode == "incremental" or context.mode == "first_part":
    context.record_new_states([k, v])
  masks = []
  if context.autoregressive:
    # Prevent attention to future positions.
    masks.append(
        mtf.cast(
            mtf.less(
                context.position,
                mtf.range(context.mesh, memory_length, dtype=tf.int32)),
            context.activation_dtype) * -1e9)
  if (context.sequence_id is not None and
      isinstance(context.sequence_id, mtf.Tensor) and
      context.length_dim in context.sequence_id.shape):
    # In packed examples, prevent attention across different sequences.
    masks.append(
        mtf.cast(
            mtf.not_equal(
                context.sequence_id,
                mtf.layers.rename_length_to_memory_length(
                    context.sequence_id)),
            context.activation_dtype) * -1e9)
  mask = mtf.add_n(masks) if masks else None
  o = mtf.layers.dot_product_attention_v2(
      q, k, v,
      memory_length,
      self.kv_dim,
      self.kv_dim,
      mask,
      self.dropout_rate if context.train else 0.0,
      [context.length_dim])
  return mtf.einsum(
      [o, wo], x.shape, reduced_dims=[self.heads_dim, self.kv_dim])
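# NumPy sketch (hypothetical sizes) of the one-hot cache update used in the
# incremental branch above: write the new key into the cache slot at
# `position` without touching any other slot.
import numpy as np

memory_length, kv_dim = 5, 3
old_k = np.zeros((memory_length, kv_dim))
new_k = np.ones((1, kv_dim))                         # key for the current step
position = 2

one_hot = np.eye(memory_length)[position][:, None]  # [memory_length, 1]
k = old_k * (1.0 - one_hot) + new_k * one_hot
# Only row 2 of the cache is replaced; rows 0, 1, 3, 4 keep their old values.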
def compute_mask(self, context, memory_position):
  """Compute attention mask.

  Args:
    context: a transformer.Context
    memory_position: an int32 Tensor whose shape contains the
      memory_length dimension.

  Returns:
    a Tensor or None
  """
  masks = []
  min_relative_position = self.min_relative_position(context)
  max_relative_position = self.max_relative_position(context)
  if max_relative_position is not None or min_relative_position is not None:
    relative_position = memory_position - context.position
    if min_relative_position is not None:
      illegal = mtf.less(relative_position, min_relative_position)
      masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
    if max_relative_position is not None:
      illegal = mtf.greater(relative_position, max_relative_position)
      masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
  sequence_id = None
  # Subsequence id should only be set if we are in the decoder and have
  # multiple targets per input.  This allows each sub-target to attend only
  # to itself.
  if isinstance(context.subsequence_id, mtf.Tensor):
    sequence_id = context.subsequence_id
  elif isinstance(context.sequence_id, mtf.Tensor):
    sequence_id = context.sequence_id
  if (sequence_id is not None and
      context.length_dim in sequence_id.shape):
    masks.append(
        mtf.cast(
            mtf.not_equal(
                sequence_id,
                self.rename_length_to_memory_length(sequence_id, context)),
            context.activation_dtype) * -1e9)
  return mtf.add_n(masks) if masks else None
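# Sketch (NumPy, made-up ids) of the (sub)sequence-id mask above: in a
# packed batch, a position may only attend to positions carrying the same id.
import numpy as np

sequence_id = np.array([1, 1, 1, 2, 2])   # two packed sequences
q_id = sequence_id[:, None]               # [length, 1]
m_id = sequence_id[None, :]               # [1, memory_length]
mask = (q_id != m_id).astype(np.float32) * -1e9
# mask[i, j] is -1e9 whenever i and j belong to different sequences, so
# attention across the packing boundary gets ~zero weight after the softmax.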
def get_extra_loss(self):
  return mtf.add_n(self._extra_losses)
def __init__(self,
             config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             scope=None,
             mesh_shape="",
             layout=""):
  self.config = copy.deepcopy(config)
  del config
  if not is_training:
    self.config.layer_output_dropout_prob = 0.0
    self.config.attention_probs_dropout_prob = 0.0
    self.config.feedforward_intermediate_dropout_prob = 0.0
  input_shape = input_ids.shape
  assert input_shape.ndims == 2

  self._seq_dim = input_shape.dims[1]
  self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
  self._extra_losses = []
  mesh = input_ids.mesh

  if token_type_ids is None:
    token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

  with tf.variable_scope(scope, default_name="bert"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      self.embedding_table = mtf.get_variable(
          mesh, "word_embeddings",
          mtf.Shape([self.vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      self.word_embedding_output = mtf.gather(
          self.embedding_table, input_ids, self.vocab_dim)

      # Add positional embeddings and token type embeddings, then layer
      # normalize and perform dropout.
      self.embedding_output = self.word_embedding_output

      token_type_table = mtf.get_variable(
          mesh, "token_type_embeddings",
          mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      if token_type_ids is not None:
        self.embedding_output += mtf.gather(
            token_type_table, token_type_ids, self.token_type_vocab_dim)
      if self.config.position_signal == "embedding":
        full_position_table = mtf.get_variable(
            mesh, "position_embeddings",
            mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        short_position_table = mtf.rename_dimension(
            mtf.slice(full_position_table, 0, self.seq_dim.size,
                      self.max_position_embeddings_dim.name),
            self.max_position_embeddings_dim.name, self.seq_dim.name)
        self.embedding_output += short_position_table
      self.embedding_output = self.normalize(self.embedding_output)
      self.embedding_output = mtf.dropout(
          self.embedding_output, is_training,
          keep_prob=1.0 - self.config.layer_output_dropout_prob)

    with tf.variable_scope("encoder"):
      attention_biases = []
      if input_mask is not None:
        # [batch_dim, memory_seq_dim]
        attention_biases.append(
            (1.0 - mtf.to_float(mtf.replace_dimensions(
                input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
      if self.config.position_signal == "relative_attention_bias":
        buckets_dim = mtf.Dimension("buckets", 32)
        rp_bucket = _relative_position_bucket(
            mtf.range(mesh, self.memory_seq_dim, tf.int32)
            - mtf.range(mesh, self.seq_dim, tf.int32),
            num_buckets=buckets_dim.size)
        bias_var = mtf.get_variable(
            mesh, "relative_attention_bias",
            [self.num_heads_dim, buckets_dim],
            initializer=tf.zeros_initializer())
        attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))

      attention_bias = mtf.add_n(attention_biases)

      prev_layer_output = self.embedding_output
      self.all_encoder_layers = []
      for block_num in range(self.config.num_blocks):
        with tf.variable_scope("block_%d" % block_num):
          for layer_idx, layer_type in enumerate(self.config.block_layers):
            layer_name = layer_type
            count = self.config.block_layers[:layer_idx].count(layer_type)
            if count:
              layer_name += "_%d" % count
            with tf.variable_scope(layer_name):
              x = prev_layer_output
              if self.config.residual_structure == "direct":
                x = self.normalize(x)
              if layer_type == "attention":
                x = self.self_attention(x, attention_bias)
              elif layer_type == "feedforward":
                x = self.feedforward(x)
              elif layer_type == "moe":
                x = self.moe(x, layout, mesh_shape, input_mask, is_training)
              else:
                raise ValueError("unknown layer type " + layer_type)
              x = mtf.dropout(
                  x, is_training,
                  keep_prob=1.0 - self.config.layer_output_dropout_prob)
              layer_output = prev_layer_output + x
              if self.config.residual_structure == "original":
                layer_output = self.normalize(layer_output)
              prev_layer_output = layer_output
              self.all_encoder_layers.append(layer_output)

      self.sequence_output = prev_layer_output
      if self.config.residual_structure == "direct":
        self.sequence_output = self.normalize(self.sequence_output)

    # The "pooler" converts the encoded sequence tensor of shape
    # [batch_dim, seq_dim, hidden_size] to a tensor of shape
    # [batch_dim, hidden_size].  This is necessary for segment-level
    # (or segment-pair-level) classification tasks where we need a fixed
    # dimensional representation of the segment.
    with tf.variable_scope("pooler"):
      # We "pool" the model by simply taking the hidden state corresponding
      # to the first token.  We assume that this has been pre-trained.
      first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
      self.pooled_output = mtf.layers.dense(
          first_token_tensor,
          reduced_dims=[self.model_dim],
          new_dims=[self.model_dim],
          activation=mtf.tanh,
          kernel_initializer=self.dense_initializer,
          use_bias=self.config.use_bias)
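# Sketch (NumPy) of the relative-position computation feeding
# _relative_position_bucket in the encoder above: broadcasting a memory-range
# against a query-range yields the full [seq, memory_seq] offset matrix.
# The bucketing itself is omitted here.
import numpy as np

seq_len = 4
relative_position = (np.arange(seq_len)[None, :]     # memory positions
                     - np.arange(seq_len)[:, None])  # query positions
# relative_position[i, j] == j - i, e.g. row 0 is [0, 1, 2, 3].  The learned
# bias is then gathered per bucket; since attention_bias is built once before
# the block loop, it is shared by every attention layer in the stack.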
def compute_bias(self, context, memory_position, x):
  """Compute attention bias.

  Args:
    context: a transformer.Context
    memory_position: an int32 Tensor whose shape contains the
      memory_length dimension.
    x: a Tensor - the query antecedent - required for relative attention

  Returns:
    a Tensor or None
  """
  min_relative_position = self.min_relative_position(context)
  max_relative_position = self.max_relative_position(context)
  # We can often cache the result of this function between similar layers.
  can_cache = (
      self.relative_attention_type is None or
      self.relative_attention_type == "bias_shared")
  if can_cache:
    cache_key = ("self_attention_mask",
                 min_relative_position,
                 max_relative_position,
                 self.relative_attention_type,
                 self.num_heads)
    if cache_key in context.cache:
      return context.cache[cache_key]
  biases = []
  relative_position = memory_position - context.position
  if min_relative_position is not None:
    visible = mtf.greater_equal(relative_position, min_relative_position)
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))
  if max_relative_position is not None:
    visible = mtf.less_equal(relative_position, max_relative_position)
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))
  if context.read_priority is not None:
    visible = mtf.greater_equal(
        context.read_priority,
        mtf.layers.rename_length_to_memory_length(context.write_priority))
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))

  sequence_id = None
  # Subsequence id should only be set if we are in the decoder and have
  # multiple targets per input.  This allows each sub-target to attend only
  # to itself.
  if isinstance(context.subsequence_id, mtf.Tensor):
    sequence_id = context.subsequence_id
  elif isinstance(context.sequence_id, mtf.Tensor):
    sequence_id = context.sequence_id
  if (sequence_id is not None and
      context.length_dim in sequence_id.shape):
    visible = mtf.equal(
        sequence_id,
        self.rename_length_to_memory_length(sequence_id, context))
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))

  if self.relative_attention_type is not None:
    buckets_dim = mtf.Dimension(
        "buckets", self.relative_attention_num_buckets)
    heads_dim = mtf.Dimension("heads", self.num_heads)
    bidirectional = not context.model.fully_autoregressive
    rp_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=buckets_dim.size)
    if (self.relative_attention_type == "bias" or
        self.relative_attention_type == "bias_shared"):
      values = mtf.get_variable(
          context.mesh, "relative_attention_bias",
          [heads_dim, buckets_dim], dtype=context.variable_dtype)
    elif self.relative_attention_type == "contextual":
      values = layers.dense(
          x, [buckets_dim, heads_dim],
          variable_dtype=context.variable_dtype,
          name="relative_attention_contextual")
    else:
      raise ValueError(
          "unrecognized relative_attention_type \"%s\""
          % self.relative_attention_type)
    biases.append(mtf.gather(values, rp_bucket, buckets_dim))
  ret = mtf.add_n(biases) if biases else None
  if can_cache:
    context.cache[cache_key] = ret
  return ret
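# Conceptual NumPy sketch of what attention.visibility_mask_to_attention_bias
# does with the boolean masks built above: visible positions get a bias of 0,
# invisible ones a large negative number.  The helper name below is a local
# stand-in, not the library function.
import numpy as np

def visibility_mask_to_attention_bias_np(visible, dtype=np.float32):
  return (1.0 - visible.astype(dtype)) * -1e9

visible = np.array([[True, False], [True, True]])
bias = visibility_mask_to_attention_bias_np(visible)
# bias == [[0, -1e9], [0, 0]].  Summing several such biases with add_n
# composes the constraints: a position must be visible under every mask to
# keep a bias of 0.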
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                attributes=None,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                sequence_id=None,
                subsequence_id=None,
                position=None,
                encoder_output=None,
                encoder_sequence_id=None,
                encoder_inputs=None,
                shared_params=None,
                layer_outputs=None,
                encoder_layer_outputs=None,
                z=None):
  """Compute logits based on inputs (all positions in parallel).

  This is called during training and evaluation.

  Args:
    inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      For training autoregressive models this should be equal to
      mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
    compute_loss: a boolean
    attributes: an optional int32 Tensor with shape
      [<batch_dims>, length_dim] or [<batch_dims>]
    mode: a tf.estimator.ModeKeys
    variable_dtype: a mtf.VariableDType
    sequence_id: an optional Tensor
    subsequence_id: an optional Tensor
    position: an optional Tensor
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    encoder_inputs: an optional Tensor
    shared_params: an optional dictionary
    layer_outputs: an optional list to append Tensor layer activations to
    encoder_layer_outputs: optional - readonly list of tensor activations
      when decoding, one per each input layer + the embedding layer
    z: an optional Tensor

  Returns:
    logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
    loss: an optional Scalar (if compute_loss=True)
  """
  batch_dims = inputs.shape.dims[:-1]
  length_dim = inputs.shape.dims[-1]
  length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
  if not self.positional_embedding:
    # To make relative attention faster, we drop the information about the
    # position in the subsequence.  The relative attention code then
    # assumes that the positions are given by index in the tensor,
    # which still leads to the correct computation of relative position.
    position = None
  if position is None:
    position_is_default = True
    position = length_range
  else:
    position_is_default = False
  if self.input_full_attention:
    # The inputs part of each sequence can fully attend within itself.
    full_attention_region = delimited_lm_inputs_mask(targets)
    # We can include one additional position to the right - the position
    # where the final EOS of the inputs is read and the first target token
    # is predicted.
    full_attention_region = mtf.logical_or(
        full_attention_region,
        mtf.shift(full_attention_region, offset=1, dim=length_dim,
                  wrap=False))
    # We set read_priority and write_priority to 0 in the full-attention
    # region and equal to the position elsewhere.
    read_priority = write_priority = length_range * mtf.cast(
        mtf.logical_not(full_attention_region), tf.int32)
  elif self.autoregressive:
    # Vanilla autoregressive model - each position can see previous
    # positions.
    read_priority = write_priority = length_range
  else:
    read_priority = write_priority = None
  context = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode=mode,
      losses=[] if compute_loss else None,
      sequence_id=sequence_id,
      subsequence_id=subsequence_id,
      position=position,
      position_is_default=position_is_default,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      shared_params=shared_params,
      layer_outputs=layer_outputs,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=read_priority,
      inputs=inputs,
      encoder_inputs=encoder_inputs)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context, inputs, targets, attributes, z=z)
  if compute_loss:
    loss = mtf.add_n(context.losses)
  else:
    loss = None
  return logits, loss
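# Sketch (NumPy, hypothetical sequence) of the read/write priorities set in
# the input_full_attention branch above.  Positions in the full-attention
# (inputs) region get priority 0; the rest keep their position index.  As in
# compute_bias, a position may read from positions whose write_priority is
# <= its read_priority, giving bidirectional attention over the inputs and
# causal attention over the targets.
import numpy as np

length_range = np.arange(6)
full_attention_region = np.array([True, True, True, False, False, False])
read_priority = write_priority = length_range * (~full_attention_region)
# read_priority == [0, 0, 0, 3, 4, 5]
visible = read_priority[:, None] >= write_priority[None, :]
# Rows 0-2 (inputs) see all inputs but no targets; row 4 sees positions 0-4.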