def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token)
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(padded, [-1])
    is_valid = tf.cast(tf.greater(flattened_padded, 0), dtype=tf.int32)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    non_spad_mask = tf.cast(
        tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * non_spad_mask
    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

    # dilate
    reshaped_ids = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    reshaped_mask = tf.reshape(
        tf.sequence_mask(output_length, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    concat_ids = tf.concat(
        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
    concat_mask = tf.concat(
        [reshaped_mask, tf.zeros_like(reshaped_mask, dtype=tf.int32)],
        axis=-1)
    dilated_ids = tf.reshape(concat_ids, [batch_size, max_seq_length * 2])
    dilated_mask = tf.reshape(concat_mask, [batch_size, max_seq_length * 2])

    return dilated_ids, dilated_mask

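
# --- Illustrative sketch (not part of the model graph) ----------------------
# The "dilate" step above interleaves every token id with a zero slot, turning
# [batch, max_seq_length] ids into [batch, max_seq_length * 2] dilated ids.
# A minimal NumPy version of that reshape trick, on made-up ids:
def _demo_dilate_interleave():
    import numpy as np
    ids = np.array([[7, 8, 9]])                              # [batch=1, L=3]
    stacked = np.stack([ids, np.zeros_like(ids)], axis=-1)   # [1, 3, 2]
    dilated = stacked.reshape(ids.shape[0], -1)              # [1, 6]
    return dilated                                           # [[7 0 8 0 9 0]]
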
def mask(inputs, key_masks=None, type=None):
    '''Masks paddings on keys or queries to inputs.

    inputs: 3d tensor. (h*N, T_q, T_k)
    key_masks: 2d tensor. (N, T_k)
    type: string. 'key' | 'future'

    e.g.,
    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

           [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
            [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

           [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
            [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1
    if type in ('k', 'key', 'keys'):
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num
    # elif type in ('q', 'query', 'queries'):
    #     # Generate masks
    #     masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
    #     masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
    #     masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
    #
    #     # Apply masks to inputs
    #     outputs = inputs * masks
    elif type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
        paddings = tf.ones_like(future_masks) * padding_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
    else:
        print('Check if you entered type correctly!')

    return outputs

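
# --- Illustrative usage sketch (assumes a TF1-style graph/session) ----------
# The 'future' branch builds a lower-triangular mask so that query position i
# can only attend to key positions j <= i. A minimal call of the helper above
# on made-up attention logits:
def _demo_future_mask():
    scores = tf.random_normal([1, 4, 4])   # (h*N, T_q, T_k) attention logits
    masked = mask(scores, type='future')   # future positions -> large negative
    return masked                          # feed to softmax afterwards
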
def crf_log_norm(inputs, sequence_lengths, transition_params):
    '''Computes the normalization for a CRF.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
      log_norm: A [batch_size] vector of normalizers for a CRF.
    '''
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm.
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])

    # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp
    # over the 'initial state' (the unary potentials).
    def _single_seq_fn():
        log_norm = tf.reduce_logsumexp(first_input, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm),
                            log_norm)
        return log_norm

    def _multi_seq_fn():
        '''Forward computation of alpha values.'''
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get
        # the partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype),
            sequence_lengths - 1)
        _, alphas = rnn.dynamic_rnn(
            cell=forward_cell,
            inputs=rest_of_input,
            sequence_length=sequence_lengths_less_one,
            initial_state=first_input,
            dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm),
                            log_norm)
        return log_norm

    return smart.smart_cond(
        pred=tf.equal(
            util.get_shape_list(inputs)[1] or tf.shape(inputs)[1], 1),
        true_fn=_single_seq_fn,
        false_fn=_multi_seq_fn)

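
# --- Illustrative sketch (pure NumPy, not used by the graph code) -----------
# For a short sequence, the normalizer computed by the forward algorithm is
# the log-sum-exp of the scores of *all* possible tag sequences. A brute-force
# check on toy potentials:
def _demo_crf_log_norm_brute_force():
    import itertools
    import numpy as np
    unary = np.random.randn(3, 2)           # [seq_len=3, num_tags=2]
    trans = np.random.randn(2, 2)           # [num_tags, num_tags]
    scores = []
    for tags in itertools.product(range(2), repeat=3):
        s = sum(unary[t, tag] for t, tag in enumerate(tags))
        s += sum(trans[tags[t], tags[t + 1]] for t in range(2))
        scores.append(s)
    return np.log(np.sum(np.exp(scores)))   # the log partition function
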
def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token)
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(padded, [-1])
    is_valid = tf.cast(tf.greater(flattened_padded, 0), dtype=tf.int32)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    nonpad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * nonpad_mask

    # dilate
    reshaped = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    concatenated = tf.concat([reshaped, tf.zeros_like(reshaped)], axis=-1)
    dilated_ids = tf.reshape(concatenated, [batch_size, max_seq_length * 2])

    input_mask = tf.reduce_sum(nonpad_mask, axis=-1)
    dilated_mask = tf.sequence_mask(
        input_mask, dilated_seq_length, dtype=tf.int32)

    return dilated_ids, dilated_mask

def _get_fake_data(self, inputs, mlm_logits):
    '''Sample from the generator to create corrupted input.'''
    inputs = unmask(inputs)
    disallow = tf.one_hot(
        inputs.masked_lm_ids,
        depth=self.bert_config.vocab_size,
        dtype=tf.float32) if self.config.disallow_correct else None
    sampled_tokens = tf.stop_gradient(sample_from_softmax(
        mlm_logits / self.config.temperature, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        'FakedData', ['inputs', 'is_fake_tokens', 'sampled_tokens'])
    return FakedData(
        inputs=updated_inputs,
        is_fake_tokens=labels,
        sampled_tokens=sampled_tokens)

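
# --- Illustrative sketch (assumed behavior of `sample_from_softmax`) --------
# The sampling helper is defined elsewhere; a common way to draw one token id
# per masked position from the generator logits, while optionally forbidding
# the original (correct) id, is Gumbel-max sampling. A NumPy sketch with toy
# shapes and a made-up disallow mask:
def _demo_gumbel_max_sampling():
    import numpy as np
    logits = np.random.randn(4, 10)                  # [positions, vocab]
    disallow = np.zeros_like(logits)
    disallow[np.arange(4), [1, 2, 3, 4]] = 1.0       # ids to forbid
    gumbel = -np.log(-np.log(np.random.uniform(1e-9, 1.0, logits.shape)))
    sampled_ids = np.argmax(logits - 1000.0 * disallow + gumbel, axis=-1)
    return sampled_ids                               # avoids the forbidden ids
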
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    '''Computes the unnormalized score for a tag sequence.

    Args:
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
        we compute the unnormalized score.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
      sequence_scores: A [batch_size] vector of unnormalized sequence scores.
    '''
    # If max_seq_len is 1, we skip the score calculation and simply gather the
    # unary potentials of the single tag.
    def _single_seq_fn():
        batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
        example_inds = tf.reshape(
            tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
        sequence_scores = tf.gather_nd(
            tf.squeeze(inputs, [1]),
            tf.concat([example_inds, tag_indices], axis=1))
        sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                                   tf.zeros_like(sequence_scores),
                                   sequence_scores)
        return sequence_scores

    def _multi_seq_fn():
        # Compute the scores of the given tag sequence.
        unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary_scores = crf_binary_score(
            tag_indices, sequence_lengths, transition_params)
        sequence_scores = unary_scores + binary_scores
        return sequence_scores

    return smart.smart_cond(
        pred=tf.equal(
            util.get_shape_list(inputs)[1] or tf.shape(inputs)[1], 1),
        true_fn=_single_seq_fn,
        false_fn=_multi_seq_fn)

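
# --- Illustrative sketch (standard CRF identity, not module code) -----------
# Together with `crf_log_norm`, the unnormalized score gives the sequence
# log-likelihood used for training:
#
#     log p(tags | inputs) = crf_sequence_score(...) - crf_log_norm(...)
#
# A minimal NumPy version of the unnormalized score for one toy sequence:
def _demo_crf_sequence_score():
    import numpy as np
    unary = np.random.randn(3, 2)          # [seq_len=3, num_tags=2]
    trans = np.random.randn(2, 2)          # [num_tags, num_tags]
    tags = [0, 1, 1]
    unary_score = sum(unary[t, tag] for t, tag in enumerate(tags))
    binary_score = sum(trans[tags[t], tags[t + 1]] for t in range(2))
    return unary_score + binary_score
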
def positional_encoding(inputs, maxlen, masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal positional encoding. See section 3.5 of the Transformer paper.

    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T.
    masking: Boolean. If True, padding positions are set to zeros.
    scope: Optional scope for `variable_scope`.

    Returns a 3d tensor that has the same shape as inputs.
    '''
    E = inputs.get_shape().as_list()[-1]  # static
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1]  # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # position indices
        position_ind = tf.tile(
            tf.expand_dims(tf.range(T), 0), [N, 1])  # (N, T)

        # First part of the PE function: sin and cos argument
        position_enc = np.array([
            [pos / np.power(10000, (i - i % 2) / E) for i in range(E)]
            for pos in range(maxlen)])

        # Second part: apply sin to even columns and cos to odd columns.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32)  # (maxlen, E)

        # lookup
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # masks
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)

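
# --- Illustrative sketch (pure NumPy, mirrors the table built above) --------
# For a tiny embedding size the sinusoid table is easy to inspect: even
# columns hold sin(pos / 10000^(2i/E)), odd columns hold cos of the same
# argument, so row 0 comes out as [0, 1, 0, 1].
def _demo_sinusoid_table(maxlen=4, E=4):
    import numpy as np
    table = np.array([
        [pos / np.power(10000, (i - i % 2) / E) for i in range(E)]
        for pos in range(maxlen)])
    table[:, 0::2] = np.sin(table[:, 0::2])   # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])   # dim 2i+1
    return table                              # shape (maxlen, E)
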
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Samples a permutation of the factorization order, and creates an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [seq_len], input ids.
      targets: int64 Tensor in shape [seq_len], target ids.
      is_masked: bool Tensor in shape [seq_len]. True means being selected
        for partial prediction.
      perm_size: the length of the longest permutation. Could be set to
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(
        tf.logical_not(is_masked), non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked
    #    (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q

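
# --- Illustrative sketch (pure NumPy, mirrors the index shuffling above) ----
# The reshape / transpose / shuffle / transpose trick permutes positions only
# within consecutive blocks of length `perm_size`, and applies the same
# permutation to every block. A toy example with seq_len=6, perm_size=3:
def _demo_blockwise_permutation(seq_len=6, perm_size=3):
    import numpy as np
    index = np.arange(seq_len).reshape(-1, perm_size).T.copy()  # [perm_size, n_blocks]
    np.random.shuffle(index)           # shuffle rows, i.e. within-block order
    return index.T.reshape(-1)         # e.g. [2 0 1 5 3 4]
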
def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, d_head, d_inner,
                   dropout, dropatt, attn_type, bi_data, initializer,
                   is_training, mem_len=None, inp_q=None, mems=None,
                   same_length=False, clamp_len=-1, untie_r=False,
                   use_tpu=True, input_mask=None, perm_mask=None, seg_id=None,
                   reuse_len=None, ff_activation='relu', target_mapping=None,
                   use_bfloat16=False, scope='transformer',
                   tilda_embeddings=None, **kwargs):
    '''Defines a Transformer-XL computation graph with additional support for
    XLNet.

    Args:
      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model],
        memory from previous batches. The length of the list equals n_layer.
        If None, no memory is used.
      perm_mask: float32 Tensor in shape [len, len, bsz].
        If perm_mask[i, j, k] = 0, i attends to j in batch k;
        if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
        If None, each position attends to all the others.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
        If target_mapping[i, j, k] = 1, the i-th prediction in batch k is on
        the j-th token. Only used during pretraining for partial prediction.
        Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz].
        1 for tokens with losses and 0 for tokens without losses.
        Only used during pretraining for two-stream attention.
        Set to None during finetuning.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, 'relu' or 'gelu'.
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.
      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      init: str, the initialization scheme, either 'normal' or 'uniform'.
      init_range: float, initialize the parameters with a uniform distribution
        in [-init_range, init_range]. Only effective when init='uniform'.
      init_std: float, initialize the parameters with a normal distribution
        with mean 0 and stddev init_std. Only effective when init='normal'.
      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be cached
        and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
        Usually set to True during pretraining and False during finetuning.
      clamp_len: int, clamp all relative distances larger than clamp_len.
        -1 means no clamping.
      same_length: bool, whether to use the same attention length for each
        token.
      summary_type: str, 'last', 'first', 'mean', or 'attn'. The method to
        pool the input to get a vector representation.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            data_mask = tf.cast(data_mask, dtype=tf.float32)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
                if target_mapping is not None:
                    word_emb_q = tf.tile(
                        mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
                else:
                    inp_q_ext = inp_q[:, :, None]
                    word_emb_q = \
                        inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k

        output_h = tf.layers.dropout(
            word_emb_k, dropout, training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(
                word_emb_q, dropout, training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(
            qlen, klen, d_model, clamp_len, attn_type, bi_data,
            bsz=bsz, dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(
                _cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(
                        inp=output_g,
                        d_model=d_model,
                        d_inner=d_inner,
                        dropout=dropout,
                        kernel_initializer=initializer,
                        activation_type=ff_activation,
                        is_training=is_training)

                output_h = positionwise_ffn(
                    inp=output_h,
                    d_model=d_model,
                    d_inner=d_inner,
                    dropout=dropout,
                    kernel_initializer=initializer,
                    activation_type=ff_activation,
                    is_training=is_training,
                    reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout,
                                       training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout,
                                       training=is_training)

        return output, new_mems, lookup_table
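
# --- Illustrative sketch (assumed semantics of the causal attention mask) ---
# `_create_mask` is defined elsewhere; for attn_type == 'uni' (and ignoring
# same_length), the intended [qlen, klen] mask is typically 1 where a query
# position may NOT attend, i.e. strictly-future positions, while the first
# `mlen` cached-memory columns stay visible to every query. A NumPy sketch:
def _demo_causal_mask(qlen=3, mlen=2):
    import numpy as np
    attn_mask = np.zeros([qlen, mlen + qlen])
    for i in range(qlen):
        attn_mask[i, mlen + i + 1:] = 1.0   # query i cannot see future keys
    return attn_mask
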