def attn(x, scope, n_state, *, past, hparams):
    '''Multi-head attention over `x`, optionally extended by cached k/v.

    Args:
      x: Tensor of rank 3; per the original notes it should be laid out as
        [batch, sequence, features].
      scope: variable scope name under which the `c_attn` and `c_proj`
        projections are created.
      n_state: width of each of the q/k/v projections; must be divisible by
        `hparams.n_head`.
      past: `None`, or a rank-5 Tensor that should be laid out as
        [batch, 2, heads, sequence, features], where the axis of size 2
        holds [k, v].
      hparams: hyper-parameter object; only `n_head` is read here.

    Returns:
      A pair `(a, present)`: the projected attention output, and the keys
      and values computed from `x` alone (before `past` is prepended),
      stacked along axis 1.
    '''
    # Rank checks only; the layouts themselves are documented above.
    assert x.shape.ndims == 3
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5

    def _to_heads(t):
        # [batch, sequence, features] -> [batch, heads, sequence, features]
        per_head = split_states(t, hparams.n_head)
        return tf.transpose(per_head, [0, 2, 1, 3])

    def _from_heads(t):
        # Inverse of `_to_heads`.
        return merge_states(tf.transpose(t, [0, 2, 1, 3]))

    def _apply_mask(scores):
        # `scores` is [batch, heads, dst_sequence, src_sequence], where
        # information flows from src to dst. Positions where the mask is 0
        # are replaced by a large negative constant.
        _, _, n_dst, n_src = shape_list(scores)
        mask = attention_mask(n_dst, n_src, dtype=scores.dtype)
        mask = tf.reshape(mask, [1, 1, n_dst, n_src])
        return scores * mask - tf.cast(1e10, scores.dtype) * (1 - mask)

    def _attend(query, key, value):
        # query/key/value are [batch, heads, sequence, features].
        scores = tf.matmul(query, key, transpose_b=True)
        # Scale by 1/sqrt(per-head feature width).
        scores = scores * tf.rsqrt(
            tf.cast(value.shape[-1].value, scores.dtype))
        probs = softmax(_apply_mask(scores))
        return tf.matmul(probs, value)

    with tf.variable_scope(scope):
        qkv = conv1d(x, 'c_attn', n_state * 3)
        q, k, v = (_to_heads(part) for part in tf.split(qkv, 3, axis=2))

        # Cache only the freshly computed keys/values.
        present = tf.stack([k, v], axis=1)

        if past is not None:
            past_k, past_v = tf.unstack(past, axis=1)
            k = tf.concat([past_k, k], axis=-2)
            v = tf.concat([past_v, v], axis=-2)

        out = _from_heads(_attend(q, k, v))
        out = conv1d(out, 'c_proj', n_state)
        return out, present
def _forward(input_ids, past=None):
    '''Run the embedding layer and stacked transformer blocks.

    Reads `tilda_embeddings` and `hparams` from the enclosing scope. When
    `tilda_embeddings` is None a `word_embeddings` variable is created;
    otherwise the supplied tensor is used as the token embedding table.

    Args:
      input_ids: 2-D tensor of token ids, unpacked as (batch, sequence).
      past: optional cached state; when given it is unstacked along axis 1
        into one entry per layer, and its second-to-last dimension is used
        as the position offset.

    Returns:
      A pair `(logits, present)`: logits reshaped to
      [batch, sequence, hparams.n_vocab], and the per-layer `present`
      values stacked along axis 1.
    '''
    batch, sequence = shape_list(input_ids)

    # Token embedding table: reuse the shared one when provided.
    if tilda_embeddings is not None:
        wte = tilda_embeddings
    else:
        wte = tf.get_variable(
            'word_embeddings', [hparams.n_vocab, hparams.n_embed],
            initializer=tf.random_normal_initializer(stddev=0.02))

    # Position embedding table.
    wpe = tf.get_variable(
        'wpe', [hparams.n_ctx, hparams.n_embed],
        initializer=tf.random_normal_initializer(stddev=0.01))

    if past is None:
        past_length = 0
    else:
        past_length = tf.shape(past)[-2]

    token_emb = tf.gather(wte, input_ids)
    position_emb = tf.gather(wpe, positions_for(input_ids, past_length))
    h = token_emb + position_emb

    # Stacked transformer layers, one cached entry (or None) per layer.
    if past is not None:
        layer_pasts = tf.unstack(past, axis=1)
    else:
        layer_pasts = [None] * hparams.n_layer
    assert len(layer_pasts) == hparams.n_layer

    presents = []
    for layer_idx, layer_past in enumerate(layer_pasts):
        h, layer_present = block(
            h, 'h%d' % layer_idx, past=layer_past, hparams=hparams)
        presents.append(layer_present)
    present = tf.stack(presents, axis=1)
    h = norm(h, 'ln_f')

    # Language model logits: project hidden states back onto the token
    # embedding table (output weights are tied to `wte`).
    h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
    flat_logits = tf.matmul(h_flat, wte, transpose_b=True)
    logits = tf.reshape(flat_logits, [batch, sequence, hparams.n_vocab])
    return logits, present
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=util.gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    '''Multi-headed, multi-layer Transformer from 'Attention is All You Need'.

    This is almost an exact implementation of the original Transformer
    encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/
    tensor2tensor/models/transformer.py

    Args:
      input_tensor: float Tensor of shape
        [batch_size, seq_length, hidden_size].
      attention_mask: (optional) int32 Tensor of shape
        [batch_size, seq_length, seq_length], with 1 for positions that can
        be attended to and 0 in positions that should not be.
      hidden_size: int. Hidden size of the Transformer.
      num_hidden_layers: int. Number of layers (blocks) in the Transformer.
      num_attention_heads: int. Number of attention heads in the
        Transformer.
      intermediate_size: int. The size of the 'intermediate' (a.k.a., feed
        forward) layer.
      intermediate_act_fn: function. The non-linear activation function to
        apply to the output of the intermediate/feed-forward layer.
      hidden_dropout_prob: float. Dropout probability for the hidden layers.
      attention_probs_dropout_prob: float. Dropout probability of the
        attention probabilities.
      initializer_range: float. Range of the initializer (stddev of
        truncated normal).
      do_return_all_layers: Whether to return the outputs of all layers or
        just the final layer.

    Returns:
      A tuple `(output, attn_maps)`.
      `output` is, when `do_return_all_layers` is False, the final hidden
      layer reshaped back to the shape of `input_tensor`
      ([batch_size, seq_length, hidden_size]); when True, the outputs of
      every layer reshaped the same way and stacked along a new leading
      axis.
      `attn_maps` is the second value returned by `attention_layer` for
      each layer, stacked along a new leading axis.

    Raises:
      ValueError: A Tensor shape or parameter is invalid.
    '''
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            'The hidden size (%d) is not a multiple of the number of attention '
            'heads (%d)' % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = util.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            'The width of the input tensor (%d) != hidden size (%d)'
            % (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back
    # and forth from a 3D tensor to a 2D tensor. Re-shapes are normally free
    # on the GPU/CPU but may not be free on the TPU, so we want to minimize
    # them to help the optimizer.
    prev_output = util.reshape_to_matrix(input_tensor)

    attn_maps = []
    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope('layer_%d' % layer_idx):

            with tf.variable_scope('attention'):
                with tf.variable_scope('self'):
                    attention_head, probs = attention_layer(
                        from_tensor=prev_output,
                        to_tensor=prev_output,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=(
                            attention_probs_dropout_prob),
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length)
                    attn_maps.append(probs)

                # Only self-attention is computed here, so there is exactly
                # one head tensor and nothing to concatenate.
                attention_output = attention_head

                # Run a linear projection of `hidden_size` then add a
                # residual with the layer input (`prev_output`).
                with tf.variable_scope('output'):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=util.create_initializer(
                            initializer_range))
                    attention_output = util.dropout(
                        attention_output, hidden_dropout_prob)
                    attention_output = util.layer_norm(
                        attention_output + prev_output)

            # The activation is only applied to the 'intermediate' hidden
            # layer.
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=util.create_initializer(
                        initializer_range))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope('output'):
                prev_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=util.create_initializer(
                        initializer_range))
                prev_output = util.dropout(prev_output, hidden_dropout_prob)
                prev_output = util.layer_norm(prev_output + attention_output)
                all_layer_outputs.append(prev_output)

    attn_maps = tf.stack(attn_maps, 0)
    if do_return_all_layers:
        return tf.stack([
            util.reshape_from_matrix(layer, input_shape)
            for layer in all_layer_outputs
        ], 0), attn_maps
    else:
        return util.reshape_from_matrix(prev_output, input_shape), attn_maps