def crf_binary_score(tag_indices, sequence_lengths, transition_params):
    '''
    Computes the binary scores of tag sequences.

    The binary score of one sequence is the sum of
    transition_params[tag_t, tag_{t+1}] over every consecutive tag pair whose
    second element lies inside the true sequence length.

    Args:
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.
    Returns:
        binary_scores: A [batch_size] vector of binary scores.
    '''
    num_tags = transition_params.get_shape()[0]
    seq_len = tf.shape(tag_indices)[1]
    num_transitions = seq_len - 1

    # Transition t goes from column t to column t + 1: dropping the last
    # column yields the "from" tags, dropping the first yields the "to" tags.
    from_tags = tf.slice(tag_indices, [0, 0], [-1, num_transitions])
    to_tags = tf.slice(tag_indices, [0, 1], [-1, num_transitions])

    # Row-major position of (from, to) inside the flattened transition matrix.
    pair_ids = from_tags * num_tags + to_tags
    flat_params = tf.reshape(transition_params, [-1])
    pair_scores = tf.gather(flat_params, pair_ids)

    # Transition t counts only if position t + 1 is a real token, so the
    # first column of the length mask is discarded.
    valid = tf.sequence_mask(sequence_lengths, maxlen=seq_len,
                             dtype=tf.float32)
    valid = tf.slice(valid, [0, 1], [-1, -1])
    return tf.reduce_sum(pair_scores * valid, 1)
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    '''
    Computes the unary scores of tag sequences.

    For every position the potential inputs[b, t, tag_indices[b, t]] is
    picked, and the picked values are summed over the positions inside each
    true sequence length.

    Args:
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials.
    Returns:
        unary_scores: A [batch_size] vector of unary scores.
    '''
    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    max_seq_len = input_shape[1]
    num_tags = input_shape[2]

    # Offset of the first potential of (b, t) in the flattened `inputs`:
    # b * max_seq_len * num_tags + t * num_tags.
    batch_offsets = tf.range(batch_size) * max_seq_len * num_tags
    step_offsets = tf.range(max_seq_len) * num_tags
    offsets = batch_offsets[:, None] + step_offsets[None, :]
    # The offsets must share tag_indices' integer type (int32 or int64).
    if tag_indices.dtype == tf.int64:
        offsets = tf.cast(offsets, tf.int64)

    flat_ids = tf.reshape(offsets + tag_indices, [-1])
    picked = tf.gather(tf.reshape(inputs, [-1]), flat_ids)
    unary_scores = tf.reshape(picked, [batch_size, max_seq_len])

    length_mask = tf.sequence_mask(
        sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32)
    return tf.reduce_sum(unary_scores * length_mask, 1)
def mask(inputs, key_masks=None, type=None):
    '''Masks padded keys or future positions of attention logits.

    Masked entries receive the large negative value -2**32 + 1 so that they
    vanish after a softmax.

    inputs: 3d tensor. (h*N, T_q, T_k)
    key_masks: 2d tensor. (N, T_k). Only used for key masking; entries equal
      to 1 mark keys to mask (they add -2**32 + 1), 0 leaves them untouched.
      The masks are tiled h times along the first axis to match `inputs`.
    type: string. 'k' | 'key' | 'keys' for key masking,
      'f' | 'future' | 'right' to mask every position j > i (causal).

    Raises:
      ValueError: if `type` is not one of the values above.

    e.g.,
    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.], [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1
    if type in ('k', 'key', 'keys'):
        key_masks = tf.to_float(key_masks)
        # Repeat the (N, T_k) masks once per head -> (h*N, T_k).
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, T_k)
        outputs = inputs + key_masks * padding_num
    elif type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (h*N, T_q, T_k)
        paddings = tf.ones_like(future_masks) * padding_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
    else:
        # Without this, `outputs` would be unbound at the return below.
        raise ValueError(
            "Unsupported mask type %r; expected one of 'k', 'key', 'keys', "
            "'f', 'future', 'right'." % (type,))
    return outputs
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt,
                  is_training, scale):
    '''Core relative positional attention operations.

    Einsum subscripts: i = query position, j = key position, b = batch,
    n = head, d = head dimension, s = segment-relation index.

    Args:
      q_head: [qlen, bsz, n_head, d_head] queries.
      k_head_h: [klen, bsz, n_head, d_head] content keys.
      v_head_h: [klen, bsz, n_head, d_head] content values.
      k_head_r: [rlen, bsz, n_head, d_head] relative-position keys; the
        resulting scores are shifted and trimmed to klen by `rel_shift`.
      seg_embed: [s, n_head, d_head] segment embeddings (unused when
        seg_mat is None).
      seg_mat: [qlen, klen, bsz, s] segment relation matrix, or None.
      r_w_bias, r_r_bias, r_s_bias: biases added to `q_head` for the
        content, position and segment terms respectively.
      attn_mask: broadcastable to [qlen, klen, bsz, n_head]; entries equal
        to 1 block attention, 0 allow it. None disables masking.
      dropatt: dropout rate on the attention probabilities.
      is_training: whether dropout is active.
      scale: multiplier applied to the summed scores.

    Returns:
      [qlen, bsz, n_head, d_head] attention output.
    '''
    # content based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

    # position based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
    bd = rel_shift(bd, klen=tf.shape(ac)[1])

    # segment based attention score
    if seg_mat is None:
        ef = 0
    else:
        ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
        ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

    # merge attention scores and perform masking
    attn_score = (ac + bd + ef) * scale
    if attn_mask is not None:
        # Push blocked logits towards -inf. In float16, 1e30 overflows to
        # inf and inf * 0 yields NaN for unmasked entries, so use the
        # largest finite float16 magnitude there instead.
        if attn_score.dtype == tf.float16:
            attn_score = attn_score - 65500 * attn_mask
        else:
            attn_score = attn_score - 1e30 * attn_mask

    # attention probability (softmax over the key axis j)
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    return attn_vec
def __init__(self, is_training, input_tensor, input_mask, label_ids,
             label_size=2, sample_weight=None, scope='cls/sequence', name='',
             hidden_dropout_prob=0.1, initializer_range=0.02, trainable=True,
             **kwargs):
    '''Builds a per-token classification head and its masked loss.

    Results are stored under the key `name`: argmax predictions in
    ``self.preds``, softmax probabilities in ``self.probs`` and per-example
    losses in ``self.losses``; the batch-mean loss goes to
    ``self.total_loss``. (``self.preds``/``probs``/``losses`` are assumed to
    be dicts created by the base class -- confirm.)

    Args:
      is_training: dropout with `hidden_dropout_prob` is applied only when
        True.
      input_tensor: hidden states whose last two static dimensions are
        (seq_length, hidden_size).
      input_mask: [batch, seq_length] token mask.
      label_ids: integer labels, one-hot encoded to `label_size` classes.
      label_size: number of classes.
      sample_weight: optional per-example weight multiplied into the loss.
      scope: variable scope of the output weights and bias.
      name: key under which results are stored.
      hidden_dropout_prob: dropout rate used while training.
      initializer_range: passed to `util.create_initializer`.
      trainable: whether the created variables are trainable.
      **kwargs: forwarded to the base-class constructor.
    '''
    super().__init__(**kwargs)

    batch_size = tf.shape(input_tensor)[0]
    static_dims = input_tensor.shape.as_list()
    seq_length, hidden_size = static_dims[-2], static_dims[-1]

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        dropout_rate = hidden_dropout_prob if is_training else 0.0
        hidden = util.dropout(input_tensor, dropout_rate)
        hidden = tf.reshape(hidden, [-1, hidden_size])

        logits = tf.nn.bias_add(
            tf.matmul(hidden, output_weights, transpose_b=True),
            output_bias)
        logits = tf.reshape(logits, [-1, seq_length, label_size])

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        # NOTE(review): this averages over the label axis instead of summing,
        # i.e. cross-entropy scaled by 1 / label_size.
        token_losses = -tf.reduce_mean(one_hot_labels * log_probs, axis=-1)

        # Position i keeps input_mask[:, i + 1] and the first/last positions
        # are zeroed. For a mask of leading ones, this drops position 0 and
        # the last unmasked position (presumably [CLS] and [SEP] -- confirm).
        edge = tf.zeros((batch_size, 1), dtype=tf.float32)
        loss_mask = tf.concat(
            [edge, tf.cast(input_mask[:, 2:], dtype=tf.float32), edge],
            axis=-1)
        token_losses *= loss_mask

        per_example_loss = tf.reduce_mean(token_losses, axis=-1)
        if sample_weight is not None:
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def rel_shift(x, klen=-1):
    '''perform relative shift to form the relative attention score.

    `x` is a 4-D tensor. Its first two axes are reinterpreted by reshaping
    with them swapped, the first resulting row is dropped, the tensor is
    reshaped back with one fewer column, and only the first `klen` columns
    of axis 1 are kept (all of them when klen == -1).
    '''
    dims = tf.shape(x)
    rows, cols = dims[0], dims[1]
    trailing = [dims[2], dims[3]]

    # Reading the (rows, cols) block as (cols, rows) and dropping one leading
    # row shifts each original row by a different offset.
    shifted = tf.reshape(x, [cols, rows] + trailing)
    shifted = tf.slice(shifted, [1, 0, 0, 0], [-1, -1, -1, -1])
    shifted = tf.reshape(shifted, [rows, cols - 1] + trailing)
    return tf.slice(shifted, [0, 0, 0, 0], [-1, klen, -1, -1])
def _single_seq_fn():
    '''Scores tag sequences when max_seq_len is 1.

    `inputs`, `tag_indices` and `sequence_lengths` are not parameters; they
    are read from the enclosing scope. For each example b the score is
    inputs[b, 0, tag_indices[b, 0]], replaced by 0 when
    sequence_lengths[b] <= 0.
    '''
    index_dtype = tag_indices.dtype
    num_examples = tf.shape(inputs, out_type=index_dtype)[0]
    row_ids = tf.reshape(tf.range(num_examples, dtype=index_dtype), [-1, 1])
    # (example, tag) coordinates into the squeezed [batch, num_tags] potentials.
    coords = tf.concat([row_ids, tag_indices], axis=1)
    scores = tf.gather_nd(tf.squeeze(inputs, [1]), coords)
    is_empty = tf.less_equal(sequence_lengths, 0)
    return tf.where(is_empty, tf.zeros_like(scores), scores)
def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout,
                       dropatt, input_mask, is_training, initializer,
                       scope=None, reuse=None, use_proj=True):
    '''
    Pools sequence features into one vector per example.

    Different classification tasks may or may not share the same parameters
    to summarize the sequence features. If shared, one can keep the `scope`
    to the default value `None`. Otherwise, one should specify a different
    `scope` for each task.

    `hidden` is indexed with axis 0 as the position axis and axis 1 as the
    batch axis. `summary_type` selects the pooling:
      'last'  -> hidden[-1]
      'first' -> hidden[0]
      'mean'  -> mean over axis 0
      'attn'  -> multi-head attention from a learned query vector.
    Any other value raises ValueError. When `use_proj` is True the pooled
    vector goes through a tanh dense layer, and dropout is always applied.
    '''
    # The default scope name's spelling is part of the variable names, so it
    # must stay as written.
    with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):
        if summary_type in ('last', 'first'):
            summary = hidden[-1] if summary_type == 'last' else hidden[0]
        elif summary_type == 'mean':
            summary = tf.reduce_mean(hidden, axis=0)
        elif summary_type == 'attn':
            batch = tf.shape(hidden)[1]
            query = tf.get_variable('summary_bias', [d_model],
                                    dtype=hidden.dtype,
                                    initializer=initializer)
            query = tf.tile(query[None, None], [1, batch, 1])
            attn_mask = (None if input_mask is None
                         else input_mask[None, :, :, None])
            summary = multihead_attn(
                query, hidden, hidden, attn_mask,
                d_model, n_head, d_head, dropout, dropatt,
                is_training, initializer, residual=False)[0]
        else:
            raise ValueError('Unsupported summary type %s' % summary_type)

        # use another projection as in BERT
        if use_proj:
            summary = tf.layers.dense(
                summary,
                d_model,
                activation=tf.tanh,
                kernel_initializer=initializer,
                name='summary')

        summary = tf.layers.dropout(
            summary, dropout, training=is_training, name='dropout')

    return summary
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal Positional_Encoding. See 3.5 of "Attention Is All You Need".

    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T (positions are looked up in a table of
      `maxlen` rows).
    masking: Boolean. If True, every element where `inputs` is exactly 0 is
      set to 0 in the output (so all-zero padding vectors stay zero).
    scope: Optional scope for `variable_scope`.

    returns
    3d float32 tensor that has the same shape as inputs.
    '''
    E = inputs.get_shape().as_list()[-1]  # static
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1]  # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # position indices
        position_ind = tf.tile(
            tf.expand_dims(tf.range(T), 0), [N, 1])  # (N, T)

        # First part of the PE function: sin and cos argument.
        # angle[pos, i] = pos / 10000^((i - i % 2) / E), so columns 2k and
        # 2k + 1 share the same frequency.
        positions = np.arange(maxlen)[:, None]  # (maxlen, 1)
        dims = np.arange(E)  # (E,)
        position_enc = positions / np.power(10000, (dims - dims % 2) / E)

        # Second part: apply sine to even columns and cosine to odd ones.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32)  # (maxlen, E)

        # lookup
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # masks
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)
def crf_log_norm(inputs, sequence_lengths, transition_params):
    '''
    Computes the normalization for a CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials to use as input to the CRF layer.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
        log_norm: A [batch_size] vector of normalizers for a CRF. Entries
            whose sequence length is <= 0 are 0.
    '''
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm.
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])

    # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp
    # over the 'initial state' (the unary potentials).
    def _single_seq_fn():
        log_norm = tf.reduce_logsumexp(first_input, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm),
                            log_norm)
        return log_norm

    def _multi_seq_fn():
        '''Forward computation of alpha values.'''
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get
        # the partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype),
            sequence_lengths - 1)
        _, alphas = rnn.dynamic_rnn(cell=forward_cell,
                                    inputs=rest_of_input,
                                    sequence_length=sequence_lengths_less_one,
                                    initial_state=first_input,
                                    dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm),
                            log_norm)
        return log_norm

    # Use the static length when known so smart_cond can pick a branch while
    # building the graph; otherwise fall back to the runtime length. Never
    # write `static_len or tf.shape(...)[1]` here: a shape helper that
    # returns a Tensor for unknown dims would make `or` evaluate the Tensor's
    # truth value, which raises TypeError in graph mode.
    static_max_seq_len = inputs.shape.as_list()[1]
    if static_max_seq_len is not None:
        max_seq_len = static_max_seq_len
    else:
        max_seq_len = tf.shape(inputs)[1]
    return smart.smart_cond(pred=tf.equal(max_seq_len, 1),
                            true_fn=_single_seq_fn,
                            false_fn=_multi_seq_fn)
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    '''
    Computes the unnormalized score for a tag sequence.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
            potentials to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for
            which we compute the unnormalized score.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.
    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence
            scores.
    '''

    # If max_seq_len is 1, we skip the score calculation and simply gather the
    # unary potentials of the single tag.
    def _single_seq_fn():
        batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
        example_inds = tf.reshape(
            tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
        sequence_scores = tf.gather_nd(
            tf.squeeze(inputs, [1]),
            tf.concat([example_inds, tag_indices], axis=1))
        # Sequences with length <= 0 score 0.
        sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                                   tf.zeros_like(sequence_scores),
                                   sequence_scores)
        return sequence_scores

    def _multi_seq_fn():
        # Compute the scores of the given tag sequence.
        unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                         transition_params)
        sequence_scores = unary_scores + binary_scores
        return sequence_scores

    # Use the static length when known so smart_cond can pick a branch while
    # building the graph; otherwise fall back to the runtime length. Never
    # write `static_len or tf.shape(...)[1]` here: a shape helper that
    # returns a Tensor for unknown dims would make `or` evaluate the Tensor's
    # truth value, which raises TypeError in graph mode.
    static_max_seq_len = inputs.shape.as_list()[1]
    if static_max_seq_len is not None:
        max_seq_len = static_max_seq_len
    else:
        max_seq_len = tf.shape(inputs)[1]
    return smart.smart_cond(pred=tf.equal(max_seq_len, 1),
                            true_fn=_single_seq_fn,
                            false_fn=_multi_seq_fn)
def get_shape_list(tensor, expected_rank=None, name=None):
    '''Returns a list of the shape of tensor, preferring static dimensions.

    Args:
      tensor: the tensor whose shape is wanted.
      expected_rank: optional; when given, `assert_rank` is called with it.
      name: optional name passed to `assert_rank`; defaults to
        `tensor.name`.

    Returns:
      One entry per dimension: the Python int when the dimension is
      statically known, otherwise the matching element of
      `tf.shape(tensor)`.
    '''
    if name is None:
        name = tensor.name
    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    dims = tensor.shape.as_list()
    unknown_axes = [axis for axis, size in enumerate(dims) if size is None]
    if unknown_axes:
        runtime_dims = tf.shape(tensor)
        for axis in unknown_axes:
            dims[axis] = runtime_dims[axis]
    return dims
def _forward(input_ids, past=None):
    '''Runs the decoder stack and returns (logits, present).

    `hparams`, `tilda_embeddings`, `block` and `norm` are not parameters;
    they are read from the enclosing scope.

    Args:
      input_ids: integer token ids of shape [batch, sequence].
      past: optional cached state; its axis 1 is unstacked into one entry per
        layer (there must be hparams.n_layer of them) and its axis -2 gives
        the number of already-processed positions.

    Returns:
      logits: [batch, sequence, hparams.n_vocab], computed with the word
        embedding matrix as the output projection.
      present: the per-layer `present` values stacked along axis 1.
    '''
    batch, sequence = shape_list(input_ids)

    if tilda_embeddings is not None:
        wte = tilda_embeddings
    else:
        wte = tf.get_variable(
            'word_embeddings', [hparams.n_vocab, hparams.n_embed],
            initializer=tf.random_normal_initializer(stddev=0.02))
    wpe = tf.get_variable(
        'wpe', [hparams.n_ctx, hparams.n_embed],
        initializer=tf.random_normal_initializer(stddev=0.01))

    past_length = tf.shape(past)[-2] if past is not None else 0
    token_emb = tf.gather(wte, input_ids)
    pos_emb = tf.gather(wpe, positions_for(input_ids, past_length))
    h = token_emb + pos_emb

    # stacked transformer layers
    if past is None:
        layer_pasts = [None] * hparams.n_layer
    else:
        layer_pasts = tf.unstack(past, axis=1)
    assert len(layer_pasts) == hparams.n_layer

    presents = []
    for layer, layer_past in enumerate(layer_pasts):
        h, layer_present = block(
            h, 'h%d' % layer, past=layer_past, hparams=hparams)
        presents.append(layer_present)
    present = tf.stack(presents, axis=1)
    h = norm(h, 'ln_f')

    # Language model loss. Do tokens <n predict token n?
    flat_h = tf.reshape(h, [batch * sequence, hparams.n_embed])
    logits = tf.matmul(flat_h, wte, transpose_b=True)
    logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
    return logits, present
def positions_for(tokens, past_length):
    '''Returns position ids for `tokens`, offset by `past_length`.

    Builds the range [past_length, past_length + nsteps), where nsteps is
    tf.shape(tokens)[1], and passes it to `expand_tile` with the batch size
    tf.shape(tokens)[0] (presumably yielding [batch_size, nsteps] -- confirm
    against expand_tile).
    '''
    token_shape = tf.shape(tokens)
    steps = tf.range(token_shape[1])
    return expand_tile(past_length + steps, token_shape[0])
def shape_list(x):
    '''Deal with dynamic shape in tensorflow cleanly.

    Returns one entry per dimension of `x`: the static Python int when it is
    known, otherwise the matching element of `tf.shape(x)`.
    '''
    static_dims = x.shape.as_list()
    runtime_dims = tf.shape(x)
    return [size if size is not None else runtime_dims[axis]
            for axis, size in enumerate(static_dims)]
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
    Sample a permutation of the factorization order, and create an attention
    mask accordingly.

    All tensors are batched. One permutation is sampled and shared by every
    example in the batch.

    Args:
        inputs: int64 Tensor in shape [batch_size, seq_len], input ids.
        targets: int64 Tensor in shape [batch_size, seq_len], target ids.
        is_masked: bool Tensor in shape [batch_size, seq_len]. True means
            being selected for partial prediction.
        perm_size: the length of longest permutation. Could be set to be
            reuse_len. Should not be larger than reuse_len or there will be
            data leaks. `seq_len` must be a multiple of it.
        seq_len: int, sequence length.

    Returns:
        perm_mask: bool [batch_size, seq_len, seq_len]; perm_mask[b, i, j]
            True means position i cannot attend to position j.
        new_targets: [batch_size, seq_len]; inputs[:, 0] followed by
            targets[:, :-1].
        target_mask: float32 [batch_size, seq_len]; 1 for masked
            non-functional tokens (the prediction targets), 0 otherwise.
        inputs_k: `inputs` unchanged.
        inputs_q: same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices. Positions are grouped into blocks of
    # `perm_size` consecutive tokens; one random order of the in-block offsets
    # is drawn and applied to every block.
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # perm_mask[b, i, j]:
    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The token condition is on the key position j, so it must broadcast over
    # the query axis i ([:, None, :]). Broadcasting over j instead would let
    # non-masked tokens see masked targets, leaking them.
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        masked_or_func_tokens[:, None, :])

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def _expand_features(module, split_placeholders):
    '''Builds permutation-LM features from batched placeholders, in place.

    Reads split_placeholders['input'], ['target'] and ['is_masked'], all
    [batch, module.max_seq_length]. Each sequence is split into a reuse part
    (the first module.reuse_seq_length tokens) and a non-reuse part, and a
    local permutation is sampled for each part with `_local_perm`. The
    results are written back under 'perm_mask', 'input_k', 'input_q',
    'target' and 'target_mask'. When module._num_predict is set,
    'target_mapping' is also written and targets are compacted to
    module._num_predict slots per example.

    Returns:
      split_placeholders (the same dict, mutated).
    '''
    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    reuse_len = module.reuse_seq_length
    non_reuse_len = module.max_seq_length - reuse_len
    assert (module.perm_size <= reuse_len and
            module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :reuse_len],
            target[:, :reuse_len],
            is_masked[:, :reuse_len],
            module.perm_size,
            reuse_len)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, reuse_len:],
            target[:, reuse_len:],
            is_masked[:, reuse_len:],
            module.perm_size,
            non_reuse_len)

    # Reuse tokens may not attend to the non-reuse part (1 = blocked), while
    # non-reuse tokens may attend to the whole reuse part (0 = allowed).
    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, reuse_len, non_reuse_len])
    ], axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, reuse_len]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ], axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        num_predict = module._num_predict
        max_len = module.max_seq_length
        bool_target_mask = tf.cast(target_mask, tf.bool)

        # Slot of each target position among its own example's targets, in
        # ascending position order (0 for the first target, 1 for the next,
        # ...). Non-target positions get -1, which tf.one_hot maps to zeros.
        slots = tf.cumsum(tf.cast(bool_target_mask, tf.int32), axis=1,
                          exclusive=True)
        slots = tf.where(bool_target_mask, slots, -tf.ones_like(slots))

        ##### target_mapping
        # target_mapping[b, k, j] = 1 iff position j is the k-th target of
        # example b. Slots past the number of targets stay all-zero rows
        # (padding); targets beyond `num_predict` are dropped.
        target_mapping = tf.one_hot(slots, num_predict, axis=1,
                                    dtype=tf.float32)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, num_predict, max_len])

        ##### target
        # Exact integer selection of the target id for each slot (0 for
        # padded slots).
        slot_select = tf.cast(target_mapping, target.dtype)
        target = tf.reduce_sum(slot_select * target[:, None, :], axis=-1)
        split_placeholders['target'] = tf.reshape(target, [-1, num_predict])

        ##### target mask: 1 for filled slots, 0 for padding
        target_mask = tf.reduce_sum(target_mapping, axis=-1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, num_predict])
    else:
        split_placeholders['target'] = tf.reshape(
            target, [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(
        input_k, [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(
        input_q, [-1, module.max_seq_length])

    return split_placeholders
def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, d_head, d_inner,
                   dropout, dropatt, attn_type, bi_data, initializer,
                   is_training, mem_len=None, inp_q=None, mems=None,
                   same_length=False, clamp_len=-1, untie_r=False,
                   use_tpu=True, input_mask=None, perm_mask=None, seg_id=None,
                   reuse_len=None, ff_activation='relu', target_mapping=None,
                   use_bfloat16=False, scope='transformer',
                   tilda_embeddings=None, **kwargs):
    '''
    Defines a Transformer-XL computation graph with additional support for
    XLNet.

    Args:
      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      n_token: int, the vocab size.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      attn_type: str, 'uni' (causal mask) or 'bi' (no causal mask).
      bi_data: bool, whether to use bidirectional input pipeline. Usually set
        to True during pretraining and False during finetuning.
      initializer: A tf initializer.
      is_training: bool, whether in training mode.
      mem_len: int, the number of tokens to cache.
      inp_q: float32 Tensor in shape [len, bsz]. 1 for tokens with losses and
        0 for tokens without losses. Only used during pretraining for
        two-stream attention. Set to None during finetuning.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model],
        memory from previous batches. The length of the list equals n_layer.
        If None, no memory is used.
      same_length: bool, whether to use the same attention length for each
        token.
      clamp_len: int, clamp all relative distances larger than clamp_len.
        -1 means no clamping.
      untie_r: bool, whether to untie the biases in attention.
      use_tpu: bool, whether TPUs are used.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding.
      perm_mask: float32 Tensor in shape [len, len, bsz].
        If perm_mask[i, j, k] = 0, i attend to j in batch k;
        if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
        If None, each position attends to all the others.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      reuse_len: int, the number of tokens in the current batch to be cached
        and reused in the future.
      ff_activation: str, 'relu' or 'gelu'.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
        If target_mapping[i, j, k] = 1, the i-th predict in batch k is
        on the j-th token.
        Only used during pretraining for partial prediction.
        Set to None during finetuning.
      use_bfloat16: bool, use bfloat16 instead of float32.
      scope: scope name for the computation graph.
      tilda_embeddings: optional; forwarded unchanged to `embedding_lookup`
        (presumably a precomputed table replacing the learned one -- confirm).
      **kwargs: accepted and ignored.

    Returns:
      output: the last layer's query stream when `inp_q` is given, otherwise
        the content stream, after dropout.
      new_mems: list of n_layer memories produced by `_cache_mem`.
      lookup_table: the word embedding table from `embedding_lookup`.

    Raises:
      ValueError: if `attn_type` is neither 'uni' nor 'bi'.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # Use tf_float so the mask matches `mems_mask` and the causal
            # `attn_mask`; mixing float32 with bfloat16 in concat/add fails
            # when use_bfloat16 is True.
            data_mask = tf.cast(data_mask, dtype=tf_float)
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
            # Same as attn_mask, except that each query may always attend to
            # its own position (the -eye term clears that diagonal entry).
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
            if target_mapping is not None:
                word_emb_q = tf.tile(
                    mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
            else:
                # Match the embedding dtype (bfloat16 when use_bfloat16).
                inp_q_ext = tf.cast(inp_q[:, :, None], dtype=tf_float)
                word_emb_q = \
                    inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout,
                                     training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(word_emb_q, dropout,
                                         training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(qlen, klen, d_model,
                                               clamp_len, attn_type, bi_data,
                                               bsz=bsz, dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    # The content-stream FFN below shares the query-stream
                    # FFN's variables.
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(
                        inp=output_g,
                        d_model=d_model,
                        d_inner=d_inner,
                        dropout=dropout,
                        kernel_initializer=initializer,
                        activation_type=ff_activation,
                        is_training=is_training)

                output_h = positionwise_ffn(
                    inp=output_h,
                    d_model=d_model,
                    d_inner=d_inner,
                    dropout=dropout,
                    kernel_initializer=initializer,
                    activation_type=ff_activation,
                    is_training=is_training,
                    reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout,
                                       training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout,
                                       training=is_training)

        return output, new_mems, lookup_table