def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  masked_lm_weights = inputs.masked_lm_weights
  with tf.variable_scope("generator_predictions"):
    if self._config.uniform_generator:
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
    else:
      relevant_hidden = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      hidden = tf.layers.dense(
          relevant_hidden,
          units=modeling.get_shape_list(model.get_embedding_table())[-1],
          activation=modeling.get_activation(self._bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              self._bert_config.initializer_range))
      hidden = modeling.layer_norm(hidden)
      output_bias = tf.get_variable(
          "output_bias",
          shape=[self._bert_config.vocab_size],
          initializer=tf.zeros_initializer())
      logits = tf.matmul(hidden, model.get_embedding_table(),
                         transpose_b=True)
      logits = tf.nn.bias_add(logits, output_bias)

    oh_labels = tf.one_hot(
        inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
        dtype=tf.float32)

    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

    numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
    denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
    loss = numerator / denominator
    preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

    MLMOutput = collections.namedtuple(
        "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
    return MLMOutput(
        logits=logits, probs=probs, per_example_loss=label_log_probs,
        loss=loss, preds=preds)
def scatter_update(sequence, updates, positions):
  """Scatter-update a sequence.

  Args:
    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
    updates: A tensor of size batch_size*seq_len(*depth)
    positions: A [batch_size, n_positions] tensor

  Returns:
    A tuple of two tensors. First is a [batch_size, seq_len] or
    [batch_size, seq_len, depth] tensor of "sequence" with elements at
    "positions" replaced by the values at "updates". Updates to index 0 are
    ignored. If there are duplicated positions the update is only applied once.
    Second is a [batch_size, seq_len] mask tensor of which inputs were updated.
  """
  shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    B, L, D = shape
  else:
    B, L = shape
    D = 1
    sequence = tf.expand_dims(sequence, -1)
  N = modeling.get_shape_list(positions)[1]

  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(positions + shift, [-1, 1])
  flat_updates = tf.reshape(updates, [-1, D])
  updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
  updates = tf.reshape(updates, [B, L, D])

  flat_updates_mask = tf.ones([B * N], tf.int32)
  updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
  updates_mask = tf.reshape(updates_mask, [B, L])
  not_first_token = tf.concat(
      [tf.zeros((B, 1), tf.int32), tf.ones((B, L - 1), tf.int32)], -1)
  updates_mask *= not_first_token
  updates_mask_3d = tf.expand_dims(updates_mask, -1)

  # account for duplicate positions
  if sequence.dtype == tf.float32:
    updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
    updates /= tf.maximum(1.0, updates_mask_3d)
  else:
    assert sequence.dtype == tf.int32
    updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
  updates_mask = tf.minimum(updates_mask, 1)
  updates_mask_3d = tf.minimum(updates_mask_3d, 1)

  updated_sequence = (((1 - updates_mask_3d) * sequence) +
                      (updates_mask_3d * updates))
  if not depth_dimension:
    updated_sequence = tf.squeeze(updated_sequence, -1)

  return updated_sequence, updates_mask
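# A minimal usage sketch for scatter_update (hypothetical toy values; assumes
# the TF1.x graph-mode `tf` and `modeling` imports used throughout this file).
# Position 2 is duplicated in the second row to show that duplicate updates
# are applied only once.
example_sequence = tf.constant([[5, 6, 7, 8],
                                [9, 10, 11, 12]], tf.int32)   # [B=2, L=4]
example_positions = tf.constant([[1, 3],
                                 [2, 2]], tf.int32)           # [B=2, N=2]
example_updates = tf.constant([[103, 103],
                               [103, 103]], tf.int32)         # e.g. [MASK] ids
example_updated, example_mask = scatter_update(
    example_sequence, example_updates, example_positions)
# example_updated -> [[5, 103, 7, 103], [9, 10, 103, 12]]
# example_mask    -> [[0, 1, 0, 1],     [0, 0, 1, 0]]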
def _get_entropy_output(self, inputs: pretrain_data.Inputs, model):
  """Per-token entropy of the masked language modeling softmax layer."""
  with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    hidden = tf.layers.dense(
        model.get_sequence_output(),
        units=modeling.get_shape_list(model.get_embedding_table())[-1],
        activation=modeling.get_activation(self._bert_config.hidden_act),
        kernel_initializer=modeling.create_initializer(
            self._bert_config.initializer_range))
    hidden = modeling.layer_norm(hidden)
    output_bias = tf.get_variable(
        "output_bias",
        shape=[self._bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(hidden, model.get_embedding_table(), transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    entropy = -tf.reduce_sum(log_probs * probs, axis=[2])

    EntropyOutput = collections.namedtuple(
        "EntropyOutput", ["logits", "probs", "log_probs", "entropy"])
    return EntropyOutput(
        logits=logits, probs=probs, log_probs=log_probs, entropy=entropy)
def gather_positions(sequence, positions):
  """Gathers the vectors at the specific positions over a minibatch.

  Args:
    sequence: A [batch_size, seq_length] or
        [batch_size, seq_length, depth] tensor of values
    positions: A [batch_size, n_positions] tensor of indices

  Returns:
    A [batch_size, n_positions] or [batch_size, n_positions, depth] tensor of
    the values at the indices
  """
  shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    B, L, D = shape
  else:
    B, L = shape
    D = 1
    sequence = tf.expand_dims(sequence, -1)
  position_shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(positions + position_shift, [-1])
  flat_sequence = tf.reshape(sequence, [B * L, D])
  gathered = tf.gather(flat_sequence, flat_positions)
  if depth_dimension:
    return tf.reshape(gathered, [B, -1, D])
  else:
    return tf.reshape(gathered, [B, -1])
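# A minimal usage sketch for gather_positions (hypothetical toy values; same
# TF1.x / modeling assumptions as above).
example_sequence = tf.constant([[10, 11, 12, 13],
                                [20, 21, 22, 23]], tf.int32)  # [B=2, L=4]
example_positions = tf.constant([[0, 2],
                                 [1, 3]], tf.int32)           # [B=2, n_positions=2]
example_gathered = gather_positions(example_sequence, example_positions)
# example_gathered -> [[10, 12], [21, 23]]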
def sample_from_softmax(logits, disallow=None):
  if disallow is not None:
    logits -= 1000.0 * disallow
  uniform_noise = tf.random.uniform(
      modeling.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  return tf.one_hot(
      tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1,
                output_type=tf.int32),
      logits.shape[-1])
def sample_from_top_k(logits, temperature, disallow=None,
                      straight_through=False, k=20):
  """Gumbel-softmax sampling restricted to the top-k logits per position."""
  logits_shape = modeling.get_shape_list(logits, expected_rank=[2, 3])
  depth_dimension = (len(logits_shape) == 3)
  if depth_dimension:
    reshape_logits = tf.reshape(logits, [-1, logits_shape[-1]])
  else:
    reshape_logits = logits

  # Mask out everything below the k-th largest logit.
  values, _ = tf.nn.top_k(reshape_logits, k=k)
  min_values = values[:, -1, tf.newaxis]
  reshape_topk_logits = tf.where(
      reshape_logits < min_values,
      tf.ones_like(reshape_logits, dtype=logits.dtype) * -1e10,
      reshape_logits,
  )
  topk_logits = tf.reshape(reshape_topk_logits, logits_shape)
  if disallow is not None:
    topk_logits -= 1e10 * disallow

  uniform_noise = tf.random.uniform(
      modeling.get_shape_list(topk_logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  gumbel_logits = (topk_logits + gumbel_noise) / temperature
  gumbel_probs = tf.nn.softmax(gumbel_logits)
  hard_token_ids = tf.one_hot(
      tf.argmax(gumbel_probs, axis=-1, output_type=tf.int32),
      topk_logits.shape[-1])
  if straight_through:
    gumbel_dense = tf.stop_gradient(hard_token_ids - gumbel_probs) + gumbel_probs
  else:
    gumbel_dense = gumbel_probs
  return gumbel_dense
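# A minimal usage sketch for sample_from_top_k (hypothetical shapes and vocab
# size): draw soft one-hot samples restricted to each position's top-k logits,
# with the straight-through estimator so the forward pass uses hard samples
# while gradients flow through the soft Gumbel-softmax probabilities.
example_logits = tf.random.normal([8, 20, 30522])   # assumed [B, N, vocab]
example_sampled = sample_from_top_k(example_logits, temperature=1.0,
                                    straight_through=True, k=20)
example_sampled_ids = tf.argmax(example_sampled, axis=-1,
                                output_type=tf.int32)  # [8, 20]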
def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
             is_training, features, num_train_steps):
  # Create a shared transformer encoder
  bert_config = training_utils.get_bert_config(config)
  self.bert_config = bert_config
  if config.debug:
    bert_config.num_hidden_layers = 3
    bert_config.hidden_size = 144
    bert_config.intermediate_size = 144 * 4
    bert_config.num_attention_heads = 4

  # multi-choice mrc
  if any([isinstance(x, qa_tasks.MQATask) for x in tasks]):
    seq_len = config.max_seq_length
    assert seq_len <= bert_config.max_position_embeddings
    bs, total_len = modeling.get_shape_list(features["input_ids"],
                                            expected_rank=2)
    to_shape = [bs * config.max_options_num * config.evidences_top_k, seq_len]
    bert_model = modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=tf.reshape(features["input_ids"], to_shape),
        input_mask=tf.reshape(features["input_mask"], to_shape),
        token_type_ids=tf.reshape(features["segment_ids"], to_shape),
        use_one_hot_embeddings=config.use_tpu,
        embedding_size=config.embedding_size)
  else:
    assert config.max_seq_length <= bert_config.max_position_embeddings
    bert_model = modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        token_type_ids=features["segment_ids"],
        use_one_hot_embeddings=config.use_tpu,
        embedding_size=config.embedding_size)

  percent_done = (
      tf.cast(tf.train.get_or_create_global_step(), tf.float32) /
      tf.cast(num_train_steps, tf.float32))

  # Add specific tasks
  self.outputs = {"task_id": features["task_id"]}
  losses = []
  for task in tasks:
    with tf.variable_scope("task_specific/" + task.name):
      task_losses, task_outputs = task.get_prediction_module(
          bert_model, features, is_training, percent_done)
      losses.append(task_losses)
      self.outputs[task.name] = task_outputs
  self.loss = tf.reduce_sum(
      tf.stack(losses, -1) *
      tf.one_hot(features["task_id"], len(config.task_names)))
def _get_masked_lm_output(inputs, model):
  """Masked language modeling softmax layer."""
  # Uniform generator: all-zero logits tiled over the hard-coded
  # 21228-token vocabulary.
  with tf.variable_scope("generator_predictions"):
    logits = tf.zeros(21228)
    logits_tiled = tf.zeros(
        modeling.get_shape_list(inputs.masked_lm_ids) + [21228])
    logits_tiled += tf.reshape(logits, [1, 1, 21228])
    logits = logits_tiled
    return get_softmax_output(logits, inputs.masked_lm_ids,
                              inputs.masked_lm_weights, 21228)
def get_prediction_module(self, bert_model, features, is_training,
                          percent_done):
  # [bs * options_num * top_k, hidden_dim]
  final_hidden = bert_model.get_pooled_output()
  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=2)
  # input_ids: [bs, options_num * top_k * seq_len]
  input_ids_shape = modeling.get_shape_list(features["input_ids"],
                                            expected_rank=2)
  batch_size = input_ids_shape[0]
  hidden_dim = final_hidden_shape[1]
  final_hidden_reshape = tf.reshape(
      final_hidden,
      [batch_size, self.config.max_options_num,
       self.config.evidences_top_k * hidden_dim])

  # def single(hidden, mask, y):
  #   logits = tf.squeeze(tf.layers.dense(hidden, 1), -1)
  #   mask = mask[:self.config.max_options_num]
  #   y = y[:self.config.max_options_num]
  #   logits_masked = logits + 1e8 * (mask - 1)
  #   loss = tf.nn.softmax_cross_entropy_with_logits(labels=y,
  #                                                  logits=logits_masked)
  #   return logits_masked, loss
  #
  # def combination_single(hidden, mask, y):
  #   logits = tf.squeeze(tf.layers.dense(hidden, 1), -1)  # todo: share or not?
  #   logits = tf.layers.dense(logits, 2 ** self.config.max_options_num)
  #   logits_masked = logits + 1e8 * (mask - 1)
  #   loss = tf.nn.softmax_cross_entropy_with_logits(labels=y,
  #                                                  logits=logits_masked)
  #   return logits_masked, loss
  #
  # logits, loss = tf.cond()

  # Per-option binary loss on the raw answer labels.
  logits = tf.squeeze(tf.layers.dense(final_hidden_reshape, 1), -1)
  loss1 = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.to_float(features[self.name + "_answer_ids_raw"]),
      logits=logits)
  loss1 = tf.reduce_mean(loss1, axis=-1)

  # Combination classifier over all 2^options answer subsets.
  logits = tf.layers.dense(logits, 2 ** self.config.max_options_num)
  logits_masked = logits + 1e8 * tf.to_float(
      features[self.name + "_answer_mask"] - 1)
  loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=features[self.name + "_answer_ids"], logits=logits)
  loss = loss * 0.2 + loss1 * 0.8

  return loss, dict(
      loss=loss,
      logits=logits_masked,
      eid=features[self.name + "_eid"],
  )
def get_token_logits(input_reprs, embedding_table, bert_config):
  hidden = tf.layers.dense(
      input_reprs,
      units=modeling.get_shape_list(embedding_table)[-1],
      activation=modeling.get_activation(bert_config.hidden_act),
      kernel_initializer=modeling.create_initializer(
          bert_config.initializer_range))
  hidden = modeling.layer_norm(hidden)
  output_bias = tf.get_variable(
      "output_bias",
      shape=[bert_config.vocab_size],
      initializer=tf.zeros_initializer())
  logits = tf.matmul(hidden, embedding_table, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)
  return logits
def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor
def _build_teacher(self, states, inputs: pretrain_data.Inputs, is_training,
                   name="teacher", reuse=False, **kwargs):
  """Build teacher network to estimate token score."""
  input_shape = get_shape_list(states, expected_rank=3)
  prev_output = states
  hidden_size = self._teacher_config.hidden_size
  num_hidden_layers = self._teacher_config.num_hidden_layers
  if is_training:
    hidden_dropout_prob = self._teacher_config.hidden_dropout_prob
  else:
    hidden_dropout_prob = 0.0

  with tf.variable_scope("teacher", reuse=reuse):
    for layer_idx in range(num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer_idx):
        layer_input = prev_output
        layer_output = tf.layers.dense(
            layer_input,
            hidden_size,
            activation=get_activation("gelu"),
            kernel_initializer=create_initializer(
                self._teacher_config.initializer_range))
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output)
        prev_output = layer_output
    sequence_output = prev_output

    with tf.variable_scope("bernoulli"):
      with tf.variable_scope("transform"):
        logits = tf.layers.dense(
            sequence_output,
            units=1,
            kernel_initializer=create_initializer(
                self._teacher_config.initializer_range))
    action_probs = tf.nn.sigmoid(logits)
    action_probs = tf.squeeze(action_probs)

    TeacherOutput = collections.namedtuple("TeacherOutput", ["action_probs"])
    return TeacherOutput(action_probs=action_probs)
def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  with tf.variable_scope("generator_predictions"):
    if self._config.uniform_generator:
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
    else:
      relevant_reprs = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      logits = get_token_logits(
          relevant_reprs, model.get_embedding_table(), self._bert_config)
    return get_softmax_output(
        logits, inputs.masked_lm_ids, inputs.masked_lm_weights,
        self._bert_config.vocab_size)
def _get_autoencoder_output(self, inputs: pretrain_data.Inputs, model):
  """Auto-Encoder softmax layer."""
  with tf.variable_scope("autoencoder_predictions"):
    relevant_hidden = model.get_sequence_output()
    hidden = tf.layers.dense(
        relevant_hidden,
        units=modeling.get_shape_list(model.get_embedding_table())[-1],
        activation=modeling.get_activation(self._bert_config.hidden_act),
        kernel_initializer=modeling.create_initializer(
            self._bert_config.initializer_range))
    hidden = modeling.layer_norm(hidden)
    output_bias = tf.get_variable(
        "output_bias",
        shape=[self._bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(hidden, model.get_embedding_table(), transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    oh_labels = tf.one_hot(
        inputs.input_ids, depth=self._bert_config.vocab_size,
        dtype=tf.float32)

    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

    # Cast the integer input mask so it can weight the float per-token losses.
    input_mask = tf.cast(inputs.input_mask, tf.float32)
    numerator = tf.reduce_sum(input_mask * label_log_probs)
    denominator = tf.reduce_sum(input_mask) + 1e-6
    loss = numerator / denominator
    preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

    AEOutput = collections.namedtuple(
        "AEOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
    return AEOutput(
        logits=logits, probs=probs, per_example_loss=label_log_probs,
        loss=loss, preds=preds)
def sample_from_softmax(logits, temperature, disallow=None,
                        straight_through=False):
  if disallow is not None:
    logits -= 1000.0 * disallow
  uniform_noise = tf.random.uniform(
      modeling.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  gumbel_logits = (logits + gumbel_noise) / temperature
  gumbel_probs = tf.nn.softmax(gumbel_logits)
  hard_token_ids = tf.one_hot(
      tf.argmax(gumbel_probs, axis=-1, output_type=tf.int32),
      logits.shape[-1])
  if straight_through:
    gumbel_dense = tf.stop_gradient(hard_token_ids - gumbel_probs) + gumbel_probs
  else:
    gumbel_dense = gumbel_probs
  return gumbel_dense
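# A minimal sketch (hypothetical shapes) contrasting the two modes above: with
# straight_through=False the sampler returns the soft Gumbel-softmax
# probabilities, while straight_through=True returns hard one-hot samples in
# the forward pass and backpropagates through the soft probabilities.
example_logits = tf.random.normal([4, 8, 30522])    # assumed [B, N, vocab]
example_soft = sample_from_softmax(example_logits, temperature=1.0)
example_hard = sample_from_softmax(example_logits, temperature=1.0,
                                   straight_through=True)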
def _sample_masking_subset(self, inputs: pretrain_data.Inputs, action_probs):
  # Calculate shifted action_probs
  input_mask = inputs.input_mask
  segment_ids = inputs.segment_ids
  input_ids = inputs.input_ids

  shape = modeling.get_shape_list(input_ids, expected_rank=2)
  batch_size = shape[0]
  max_seq_len = shape[1]

  def _remove_special_token(elems):
    action_prob = tf.cast(elems[0], tf.float32)
    segment = tf.cast(elems[1], tf.int32)
    input = tf.cast(elems[2], tf.int32)
    mask = tf.cast(elems[3], tf.int32)

    seq_len = tf.reduce_sum(mask)
    seg1_len = seq_len - tf.reduce_sum(segment)
    seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
    seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
    seq2_idx = tf.range(start=seg1_len, limit=seq2_limit, dtype=tf.int32)
    mask_idx = tf.range(start=seq_len, limit=max_seq_len, dtype=tf.int32)
    index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)

    seq1_prob = action_prob[1:seg1_len - 1]
    seq2_prob = action_prob[seg1_len:seq2_limit]
    mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
    cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob], axis=0)

    cleaned_mask = tf.concat([
        mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
        mask[seq_len:max_seq_len]
    ], axis=0)
    cleaned_input = tf.concat([
        input[1:seg1_len - 1], input[seg1_len:seq_len - 1],
        input[seq_len:max_seq_len]
    ], axis=0)

    cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
    index_tensor = index_tensor[0:max_seq_len - 3]
    cleaned_input = cleaned_input[0:max_seq_len - 3]
    cleaned_mask = cleaned_mask[0:max_seq_len - 3]

    return (cleaned_action_prob, index_tensor, cleaned_input, cleaned_mask)

  # Remove CLS and SEP action probs
  elems = tf.stack([
      action_probs,
      tf.cast(segment_ids, tf.float32),
      tf.cast(input_ids, tf.float32),
      tf.cast(input_mask, tf.float32)
  ], 1)
  cleaned_action_probs, index_tensors, cleaned_inputs, cleaned_input_mask = \
      tf.map_fn(_remove_special_token, elems,
                dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
                parallel_iterations=1)

  logZ, log_prob = self._calculate_partition_table(
      cleaned_input_mask, cleaned_action_probs,
      self._config.max_predictions_per_seq)

  samples, log_q = self._sampling_a_subset(
      logZ, log_prob, self._config.max_predictions_per_seq)

  # Collect masked_lm_ids and masked_lm_positions
  zero_values = tf.zeros_like(index_tensors, tf.int32)
  selected_position = tf.where(tf.equal(samples, 1), index_tensors,
                               zero_values)
  masked_lm_positions, _ = tf.nn.top_k(
      selected_position, self._config.max_predictions_per_seq, sorted=False)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]), flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

  # Update the input ids: 85% [MASK], 7.5% random token, 7.5% unchanged
  replaced_prob = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replaced_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids,
      tf.fill([batch_size, self._config.max_predictions_per_seq],
              self._vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replaced_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq],
      minval=0, maxval=len(self._vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
  inv_vocab = self._inv_vocab

  # Apply mask on input
  if self._config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)
      ## TODO: save to the mask choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
        tf.py_func(pretty_print,
                   [inputs_ids, masked_lm_ids, masked_lm_positions,
                    masked_lm_weights, inputs.tag_ids],
                   (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  masked_input = pretrain_data.get_updated_inputs(
      inputs,
      # Pass the updated (masked) ids rather than the original input_ids.
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=tf.stop_gradient(masked_lm_positions),
      masked_lm_ids=tf.stop_gradient(masked_lm_ids),
      masked_lm_weights=tf.stop_gradient(masked_lm_weights),
      tag_ids=inputs.tag_ids)

  return log_q, masked_input
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor of
      scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should not
      be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens for
      multiple rounds of masking

  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  vocab = tokenizer.vocab
  inv_vocab = tokenizer.inv_vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  if (config.masking_strategy == RAND_STRATEGY or
      config.masking_strategy == MIX_ADV_STRATEGY):
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == POS_STRATEGY:
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float
    # Preferred positions get weight 1.0, unfavored positions get weight 0.05.
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == ENTROPY_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == MIX_POS_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float
    # Preferred positions get weight 1.0, unfavored positions get weight 0.05.
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    pos_sample_prob = (proposal_distribution * candidate_mask_float)
    # Per example, pick the random or the POS-based strategy with equal
    # probability.
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   pos_sample_prob * (1 - strategy_prob))
  elif config.masking_strategy == MIX_ENTROPY_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    entropy_sample_prob = (proposal_distribution * candidate_mask_float)
    # Per example, pick the random or the entropy-based strategy with equal
    # probability.
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   entropy_sample_prob * (1 - strategy_prob))
  else:
    raise ValueError(
        "{} strategy is not supported".format(config.masking_strategy))
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(sample_logits, N,
                                              dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids: 85% [MASK], 7.5% random token, 7.5% unchanged
  replace_prob = tf.random.uniform([B, N])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replace_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replace_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform([B, N], minval=0, maxval=len(vocab),
                                    dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  if config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)
      ## TODO: save to the mask choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
        tf.py_func(pretty_print,
                   [inputs_ids, masked_lm_ids, masked_lm_positions,
                    masked_lm_weights, inputs.tag_ids],
                   (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights,
      tag_ids=inputs.tag_ids)
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor of
      scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should not
      be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens for
      multiple rounds of masking

  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  vocab = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case).vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  sample_prob = (proposal_distribution * candidate_mask_float)
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(sample_logits, N,
                                              dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights)
def get_prediction_module(self, bert_model, features, is_training,
                          percent_done):
  final_hidden = bert_model.get_sequence_output()
  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]

  answer_mask = tf.cast(features["input_mask"], tf.float32)
  answer_mask *= tf.cast(features["segment_ids"], tf.float32)
  answer_mask += tf.one_hot(0, seq_length)

  start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

  start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
  start_top_index = tf.zeros([batch_size, self.config.beam_size], tf.int32)
  end_top_log_probs = tf.zeros(
      [batch_size, self.config.beam_size, self.config.beam_size])
  end_top_index = tf.zeros(
      [batch_size, self.config.beam_size, self.config.beam_size], tf.int32)

  if self.config.joint_prediction:
    start_logits += 1000.0 * (answer_mask - 1)
    start_log_probs = tf.nn.log_softmax(start_logits)
    start_top_log_probs, start_top_index = tf.nn.top_k(
        start_log_probs, k=self.config.beam_size)

    if not is_training:
      # batch, beam, length, hidden
      end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                             [1, self.config.beam_size, 1, 1])
      # batch, beam, length
      start_index = tf.one_hot(start_top_index, depth=seq_length, axis=-1,
                               dtype=tf.float32)
      # batch, beam, hidden
      start_features = tf.reduce_sum(
          tf.expand_dims(final_hidden, 1) * tf.expand_dims(start_index, -1),
          axis=-2)
      # batch, beam, length, hidden
      start_features = tf.tile(tf.expand_dims(start_features, 2),
                               [1, 1, seq_length, 1])
    else:
      start_index = tf.one_hot(features[self.name + "_start_positions"],
                               depth=seq_length, axis=-1, dtype=tf.float32)
      start_features = tf.reduce_sum(
          tf.expand_dims(start_index, -1) * final_hidden, axis=1)
      start_features = tf.tile(tf.expand_dims(start_features, 1),
                               [1, seq_length, 1])
      end_features = final_hidden

    final_repr = tf.concat([start_features, end_features], -1)
    final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu,
                                 name="qa_hidden")
    # batch, beam, length (batch, length when training)
    end_logits = tf.squeeze(tf.layers.dense(final_repr, 1), -1,
                            name="qa_logits")
    if is_training:
      end_logits += 1000.0 * (answer_mask - 1)
    else:
      end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

    if not is_training:
      end_log_probs = tf.nn.log_softmax(end_logits)
      end_top_log_probs, end_top_index = tf.nn.top_k(
          end_log_probs, k=self.config.beam_size)
      end_logits = tf.zeros([batch_size, seq_length])
  else:
    end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
    start_logits += 1000.0 * (answer_mask - 1)
    end_logits += 1000.0 * (answer_mask - 1)

  def compute_loss(logits, positions):
    one_hot_positions = tf.one_hot(positions, depth=seq_length,
                                   dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
    return loss

  start_positions = features[self.name + "_start_positions"]
  end_positions = features[self.name + "_end_positions"]

  start_loss = compute_loss(start_logits, start_positions)
  end_loss = compute_loss(end_logits, end_positions)

  losses = (start_loss + end_loss) / 2.0

  answerable_logit = tf.zeros([batch_size])
  if self.config.answerable_classifier:
    final_repr = final_hidden[:, 0]
    if self.config.answerable_uses_start_logits:
      start_p = tf.nn.softmax(start_logits)
      start_feature = tf.reduce_sum(
          tf.expand_dims(start_p, -1) * final_hidden, axis=1)
      final_repr = tf.concat([final_repr, start_feature], -1)
      final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu)
    answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
    answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(features[self.name + "_is_impossible"], tf.float32),
        logits=answerable_logit)
    losses += answerable_loss * self.config.answerable_weight

  return losses, dict(
      loss=losses,
      start_logits=start_logits,
      end_logits=end_logits,
      answerable_logit=answerable_logit,
      start_positions=features[self.name + "_start_positions"],
      end_positions=features[self.name + "_end_positions"],
      start_top_log_probs=start_top_log_probs,
      start_top_index=start_top_index,
      end_top_log_probs=end_top_log_probs,
      end_top_index=end_top_index,
      eid=features[self.name + "_eid"],
  )
def get_prediction_module(self, bert_model, features, is_training,
                          percent_done):
  final_hidden = bert_model.get_sequence_output()

  # sgnet
  # dep_mask_x = features[self.name + "_dep_mask_x"]
  # dep_mask_y = features[self.name + "_dep_mask_y"]
  # dep_mask_len = features[self.name + "_dep_mask_len"]
  #
  # def fn(xyz):
  #   x = xyz[0]
  #   y = xyz[1]
  #   length = xyz[2]
  #   x = x[:length]
  #   y = y[:length]
  #   st = tf.SparseTensor(
  #       indices=tf.cast(tf.transpose([x, y]), tf.int64),
  #       values=tf.ones_like(x, dtype=tf.float32),
  #       dense_shape=[self.config.max_seq_length, self.config.max_seq_length])
  #   dt = tf.sparse_tensor_to_dense(st)
  #   return dt
  #
  # dep_mask = tf.map_fn(fn, (dep_mask_x, dep_mask_y, dep_mask_len),
  #                      dtype=tf.float32)
  # dep_mask = features["squad_dep_mask"]
  # dep_mask = tf.reshape(
  #     dep_mask, [-1, self.config.max_seq_length, self.config.max_seq_length])
  # with tf.variable_scope("dependence"):
  #   bert_config = bert_model.config
  #   dep_att_output, _ = modeling.transformer_model(
  #       input_tensor=final_hidden,
  #       attention_mask=dep_mask,
  #       hidden_size=bert_config.hidden_size,
  #       num_hidden_layers=1,
  #       num_attention_heads=bert_config.num_attention_heads,
  #       intermediate_size=bert_config.intermediate_size,
  #       intermediate_act_fn=modeling.get_activation(bert_config.hidden_act),
  #       hidden_dropout_prob=bert_config.hidden_dropout_prob,
  #       attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
  #       initializer_range=bert_config.initializer_range,
  #       do_return_all_layers=False)
  #   weight = tf.get_variable(name="weight", dtype=tf.float32,
  #                            initializer=tf.zeros_initializer(),
  #                            shape=(), trainable=True)
  #   weight = tf.sigmoid(weight)
  #   final_hidden = weight * final_hidden + (1 - weight) * dep_att_output

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]

  answer_mask = tf.cast(features["input_mask"], tf.float32)
  answer_mask *= tf.cast(features["segment_ids"], tf.float32)
  answer_mask += tf.one_hot(0, seq_length)

  start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

  start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
  start_top_index = tf.zeros([batch_size, self.config.beam_size], tf.int32)
  end_top_log_probs = tf.zeros(
      [batch_size, self.config.beam_size, self.config.beam_size])
  end_top_index = tf.zeros(
      [batch_size, self.config.beam_size, self.config.beam_size], tf.int32)

  if self.config.joint_prediction:
    start_logits += 1000.0 * (answer_mask - 1)
    start_log_probs = tf.nn.log_softmax(start_logits)
    start_top_log_probs, start_top_index = tf.nn.top_k(
        start_log_probs, k=self.config.beam_size)

    if not is_training:
      # batch, beam, length, hidden
      end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                             [1, self.config.beam_size, 1, 1])
      # batch, beam, length
      start_index = tf.one_hot(start_top_index, depth=seq_length, axis=-1,
                               dtype=tf.float32)
      # batch, beam, hidden
      start_features = tf.reduce_sum(
          tf.expand_dims(final_hidden, 1) * tf.expand_dims(start_index, -1),
          axis=-2)
      # batch, beam, length, hidden
      start_features = tf.tile(tf.expand_dims(start_features, 2),
                               [1, 1, seq_length, 1])
    else:
      start_index = tf.one_hot(features[self.name + "_start_positions"],
                               depth=seq_length, axis=-1, dtype=tf.float32)
      start_features = tf.reduce_sum(
          tf.expand_dims(start_index, -1) * final_hidden, axis=1)
      start_features = tf.tile(tf.expand_dims(start_features, 1),
                               [1, seq_length, 1])
      end_features = final_hidden

    final_repr = tf.concat([start_features, end_features], -1)
    final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu,
                                 name="qa_hidden")
    # batch, beam, length (batch, length when training)
    end_logits = tf.squeeze(tf.layers.dense(final_repr, 1), -1,
                            name="qa_logits")
    if is_training:
      end_logits += 1000.0 * (answer_mask - 1)
    else:
      end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

    if not is_training:
      end_log_probs = tf.nn.log_softmax(end_logits)
      end_top_log_probs, end_top_index = tf.nn.top_k(
          end_log_probs, k=self.config.beam_size)
      end_logits = tf.zeros([batch_size, seq_length])
  else:
    end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
    start_logits += 1000.0 * (answer_mask - 1)
    end_logits += 1000.0 * (answer_mask - 1)

  def compute_loss(logits, positions):
    one_hot_positions = tf.one_hot(positions, depth=seq_length,
                                   dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
    return loss

  start_positions = features[self.name + "_start_positions"]
  end_positions = features[self.name + "_end_positions"]

  start_loss = compute_loss(start_logits, start_positions)
  end_loss = compute_loss(end_logits, end_positions)

  losses = (start_loss + end_loss) / 2.0

  # plausible answer loss
  plau_logits = tf.layers.dense(final_hidden, 2)
  plau_logits = tf.reshape(plau_logits, [batch_size, seq_length, 2])
  plau_logits = tf.transpose(plau_logits, [2, 0, 1])
  unstacked_logits = tf.unstack(plau_logits, axis=0)
  (plau_start_logits, plau_end_logits) = (unstacked_logits[0],
                                          unstacked_logits[1])
  plau_start_logits += 1000.0 * (answer_mask - 1)
  plau_end_logits += 1000.0 * (answer_mask - 1)
  plau_start_positions = features[self.name + "_plau_answer_start"]
  plau_end_positions = features[self.name + "_plau_answer_end"]
  plau_start_loss = compute_loss(plau_start_logits, plau_start_positions)
  plau_end_loss = compute_loss(plau_end_logits, plau_end_positions)
  losses += (plau_start_loss + plau_end_loss) / 2.0

  # def compute_loss_for_plau(start_logits, end_logits, start_positions,
  #                           end_positions, start_positions_true,
  #                           alpha=1.0, beta=1.0):
  #   start_probs = tf.nn.softmax(start_logits)
  #   end_probs = tf.nn.softmax(end_logits)
  #   log_neg_start_probs = tf.log(tf.clip_by_value(1 - start_probs, 1e-30, 1))
  #   log_neg_end_probs = tf.log(tf.clip_by_value(1 - end_probs, 1e-30, 1))
  #   start_positions_mask = tf.cast(
  #       tf.sequence_mask(start_positions, maxlen=seq_length), tf.float32)
  #   end_positions_mask = tf.cast(
  #       tf.sequence_mask(end_positions + 1, maxlen=seq_length), tf.float32)
  #   positions_mask = end_positions_mask - start_positions_mask
  #   one_hot_positions = tf.one_hot(
  #       start_positions_true, depth=seq_length, dtype=tf.float32)
  #   # Ignore the no-answer spans produced by document slicing.
  #   positions_mask = positions_mask * (1 - one_hot_positions)
  #
  #   # mask_0 = tf.zeros([batch_size, 1])
  #   # mask_1 = tf.ones([batch_size, seq_length - 1])
  #   # zero_mask = tf.concat([mask_0, mask_1], axis=1)
  #   # positions_mask = positions_mask * zero_mask
  #   loss1 = -tf.reduce_sum(positions_mask * log_neg_start_probs, axis=-1)
  #   loss1 = tf.reduce_mean(loss1)
  #   loss2 = -tf.reduce_sum(positions_mask * log_neg_end_probs, axis=-1)
  #   loss2 = tf.reduce_mean(loss2)
  #   return (loss1 * alpha + loss2 * beta) * 0.5
  #
  # plau_loss = compute_loss_for_plau(start_logits, end_logits,
  #                                   features[self.name + "_plau_answer_start"],
  #                                   features[self.name + "_plau_answer_end"],
  #                                   features[self.name + "_start_positions"],
  #                                   1.0, 1.0)
  # losses += plau_loss

  answerable_logit = tf.zeros([batch_size])
  if self.config.answerable_classifier:
    final_repr = final_hidden[:, 0]
    if self.config.answerable_uses_start_logits:
      start_p = tf.nn.softmax(start_logits)
      start_feature = tf.reduce_sum(
          tf.expand_dims(start_p, -1) * final_hidden, axis=1)
      final_repr = tf.concat([final_repr, start_feature], -1)
      final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu)
    answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
    answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.cast(features[self.name + "_is_impossible"], tf.float32),
        logits=answerable_logit)
    losses += answerable_loss * self.config.answerable_weight

  return losses, dict(
      loss=losses,
      start_logits=start_logits,
      end_logits=end_logits,
      answerable_logit=answerable_logit,
      start_positions=features[self.name + "_start_positions"],
      end_positions=features[self.name + "_end_positions"],
      start_top_log_probs=start_top_log_probs,
      start_top_index=start_top_index,
      end_top_log_probs=end_top_log_probs,
      end_top_index=end_top_index,
      eid=features[self.name + "_eid"],
  )
def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
  """Masked language modeling softmax layer."""
  masked_lm_weights = inputs.masked_lm_weights
  with tf.variable_scope("generator_predictions"):
    if (self._config.uniform_generator or self._config.identity_generator or
        self._config.heuristic_generator):
      logits = tf.zeros(self._bert_config.vocab_size)
      logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
      logits = logits_tiled
    else:
      relevant_hidden = pretrain_helpers.gather_positions(
          model.get_sequence_output(), inputs.masked_lm_positions)
      hidden = tf.layers.dense(
          relevant_hidden,
          units=modeling.get_shape_list(model.get_embedding_table())[-1],
          activation=modeling.get_activation(self._bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              self._bert_config.initializer_range))
      hidden = modeling.layer_norm(hidden)
      output_bias = tf.get_variable(
          "output_bias",
          shape=[self._bert_config.vocab_size],
          initializer=tf.zeros_initializer())
      logits = tf.matmul(hidden, model.get_embedding_table(),
                         transpose_b=True)
      logits = tf.nn.bias_add(logits, output_bias)

    oh_labels = tf.one_hot(
        inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
        dtype=tf.float32)

    probs = tf.nn.softmax(logits)

    if self._config.identity_generator:
      identity_logits = tf.zeros(self._bert_config.vocab_size)
      identity_logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      masked_identity_weights = tf.one_hot(
          inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
          dtype=tf.float32)
      identity_logits_tiled += 25.0 * masked_identity_weights
      identity_logits_tiled += tf.reshape(
          identity_logits, [1, 1, self._bert_config.vocab_size])
      identity_logits = identity_logits_tiled
      identity_probs = tf.nn.softmax(identity_logits)

      identity_weight = (
          self.global_step /
          tf.cast(self._config.num_train_steps, tf.float32)
      ) * self._config.max_identity_weight
      probs = probs * (1 - identity_weight) + identity_probs * identity_weight
      logits = tf.math.log(probs)  # softmax(log(probs)) = probs
    elif self._config.heuristic_generator:
      synonym_logits = tf.zeros(self._bert_config.vocab_size)
      synonym_logits_tiled = tf.zeros(
          modeling.get_shape_list(inputs.masked_lm_ids) +
          [self._bert_config.vocab_size])
      masked_synonym_weights = tf.reduce_sum(
          tf.one_hot(inputs.masked_synonym_ids,
                     depth=self._bert_config.vocab_size, dtype=tf.float32),
          -2)
      padded_synonym_mask = tf.concat(
          [tf.zeros([1]), tf.ones([self._bert_config.vocab_size - 1])], 0)
      masked_synonym_weights *= tf.expand_dims(
          tf.expand_dims(padded_synonym_mask, 0), 0)
      synonym_logits_tiled += 25.0 * masked_synonym_weights
      synonym_logits_tiled += tf.reshape(
          synonym_logits, [1, 1, self._bert_config.vocab_size])
      synonym_logits = synonym_logits_tiled
      synonym_probs = tf.nn.softmax(synonym_logits)

      if self._config.synonym_scheduler_type == 'linear':
        synonym_weight = (
            self.global_step /
            tf.cast(self._config.num_train_steps, tf.float32)
        ) * self._config.max_synonym_weight
        probs = probs * (1 - synonym_weight) + synonym_probs * synonym_weight
        logits = tf.math.log(probs)  # softmax(log(probs)) = probs

    log_probs = tf.nn.log_softmax(logits)
    label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

    numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
    denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
    loss = numerator / denominator
    preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

    MLMOutput = collections.namedtuple(
        "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
    return MLMOutput(
        logits=logits, probs=probs, per_example_loss=label_log_probs,
        loss=loss, preds=preds)
def _sampling_a_subset(self, logZ, logp, max_predictions_per_seq):
  shape = modeling.get_shape_list(logp, expected_rank=2)
  seq_len = shape[1]

  def gather_z_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    # set negative indices to zeros
    mask = tf.zeros_like(positions, dtype=tf.int32)
    masked_position = tf.reduce_max(tf.stack([positions, mask]), 0)

    index = tf.reshape(
        tf.cast(tf.where(tf.equal(mask, 0)), dtype=tf.int32), [-1])
    flat_offsets = index * (max_predictions_per_seq + 1)
    flat_positions = masked_position + flat_offsets
    flat_sequence_tensor = tf.reshape(sequence_tensor, [-1])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor

  def sampling_loop_cond(j, subset, count, left, log_q):
    # j < N and left > 0
    # we want to exclude last tokens, because it's always a special token [SEP]
    return tf.logical_or(tf.less(j, seq_len),
                         tf.greater(tf.reduce_sum(left), 0))

  def sampling_body(j, subset, count, left, log_q):
    # calculate log_q_yes and log_q_no
    logp_j = logp[:, j]
    log_Z_total = gather_z_indexes(logZ[:, j, :], left)  # b
    log_Z_yes = gather_z_indexes(logZ[:, j + 1, :], left - 1)  # b
    logit_yes = logp_j + log_Z_yes - log_Z_total
    logit_no = tf.log(tf.clip_by_value(1 - tf.exp(logit_yes), 1e-20, 1.0))
    # draw 2 Gumbel noise and compute action by argmax
    logits = tf.transpose(tf.stack([logit_no, logit_yes]), [1, 0])
    actions = gumbel.gumbel_softmax(logits)
    action_mask = tf.cast(tf.argmax(actions, 1), dtype=tf.int32)
    no_left_mask = tf.where(tf.greater(left, 0),
                            tf.ones_like(left, dtype=tf.int32),
                            tf.zeros_like(left, dtype=tf.int32))
    output = action_mask * no_left_mask
    actions = tf.reduce_max(actions, 1)
    log_actions = tf.log(actions)
    # compute log_q_j and update count and subset
    count = count + output
    left = left - output
    log_q = log_q + log_actions
    subset = subset.write(j, output)
    return [tf.add(j, 1), subset, count, left, log_q]

  with tf.variable_scope("teacher/sampling"):
    # Batch sampling
    subset = tf.TensorArray(tf.int32, size=seq_len)
    count = tf.zeros_like(logp[:, 0], dtype=tf.dtypes.int32)
    left = tf.ones_like(logp[:, 0], dtype=tf.dtypes.int32)
    left = left * max_predictions_per_seq
    log_q = tf.zeros_like(count, dtype=tf.dtypes.float32)

    _, subset, count, left, log_q = tf.while_loop(
        sampling_loop_cond, sampling_body,
        [tf.constant(0), subset, count, left, log_q])

    subset = subset.stack()  # K x b x N
    subset = tf.transpose(subset, [1, 0])

  return subset, log_q
def _calculate_partition_table(self, input_mask, action_prob,
                               max_predictions_per_seq):
  shape = modeling.get_shape_list(action_prob, expected_rank=2)
  seq_len = shape[1]

  with tf.variable_scope("teacher/dp"):
    '''
    Calculate DP table: aims to calculate logZ[0, K]
    # We add an extra row so that when we calculate log_q_yes, we don't have
    # an out-of-bound error
    # Z[b, N+1, k] = log 0 - we do not allow to choose anything
    # logZ size batch_size x N+1 x K+1
    '''
    initZ = tf.TensorArray(tf.float32, size=max_predictions_per_seq + 1)
    logZ_0 = tf.zeros_like(input_mask, dtype=tf.float32)  # size b x N
    logZ_0 = tf.pad(logZ_0, [[0, 0], [0, 1]], "CONSTANT")  # size b x N+1
    initZ = initZ.write(tf.constant(0), logZ_0)

    # mask logp
    action_prob = tf.cast(input_mask, dtype=tf.float32) * action_prob
    action_prob = tf.clip_by_value(action_prob, 1e-20, 1.0)
    logp = tf.log(action_prob)
    logp_no = tf.log(1 - action_prob)
    accum_logp = tf.cumsum(logp, axis=1, reverse=True)

    def accum_cond(j, logZ_j, logb, loga):
      return tf.greater(j, -1)

    def accum_body(j, logZ_j, logb, loga):
      # logb: log_yes = logp[j] + logZ[j+1, k-1] -- already computed
      # loga: log_no = log(1-p[j]) + logZ[j+1, k]
      logb_j = tf.squeeze(logb[:, j])
      log_one_minus_p_j = tf.squeeze(logp_no[:, j])
      loga = loga + log_one_minus_p_j
      next_logZ_j = tf.math.reduce_logsumexp(tf.stack([loga, logb_j]), 0)
      logZ_j = logZ_j.write(j, next_logZ_j)
      return [tf.subtract(j, 1), logZ_j, logb, next_logZ_j]

    def dp_loop_cond(k, logZ, lastZ):
      return tf.less(k, max_predictions_per_seq + 1)

    def dp_body(k, logZ, lastZ):
      '''
      case j < N-k+1:
        logZ[j,k] = log_sum(log(1-pi(j)) + logZ[j+1,k],
                            logp(j) + logZ[j+1,k-1])
      case j = N-k+1:
        logZ[j,k] = accum_logp[j]
      case j > N-k+1:
        logZ[j,k] = 0
      '''
      # shift lastZ one step
      shifted_lastZ = tf.roll(lastZ[:, :-1], shift=1, axis=1)  # logZ[j+1, k-1]
      log_yes = logp + shifted_lastZ  # b x N
      logZ_j = tf.TensorArray(tf.float32, size=seq_len + 1)
      init_value = accum_logp[:, seq_len - k]
      logZ_j = logZ_j.write(seq_len - k, init_value)
      _, logZ_j, logb, loga = tf.while_loop(
          accum_cond, accum_body,
          [seq_len - k - 1, logZ_j, log_yes, init_value])
      logZ_j = logZ_j.stack()  # N x b
      logZ_j = tf.transpose(logZ_j, [1, 0])  # b x N
      logZ = logZ.write(k, logZ_j)
      return [tf.add(k, 1), logZ, logZ_j]

    k = tf.constant(1)
    _, logZ, lastZ = tf.while_loop(
        dp_loop_cond, dp_body, [k, initZ, logZ_0],
        shape_invariants=[
            k.get_shape(),
            tf.TensorShape([]),
            tf.TensorShape([None, None])
        ])
    logZ = logZ.stack()  # N x b x N
    logZ = tf.transpose(logZ, [1, 2, 0])
  return logZ, logp
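# A short note on the recursion implemented by dp_body/accum_body above (a
# sketch in the notation of the docstrings, not additional executable code):
# writing p_j for the teacher's selection probability of token j, the
# partition function over subsets of size k drawn from tokens j..N satisfies
#
#   Z[j, k] = (1 - p_j) * Z[j+1, k] + p_j * Z[j+1, k-1],
#
# which in log-space is exactly
#
#   logZ[:, j, k] = logsumexp(log(1 - p_j) + logZ[:, j+1, k],
#                             log(p_j)     + logZ[:, j+1, k-1]),
#
# with the boundary entry at j = N - k initialised from accum_logp (all
# remaining tokens must be selected there).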
def _get_richer_data(self, fake_data):
  inputs_tf = fake_data.inputs.input_ids
  labels_tf = fake_data.is_fake_tokens
  lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)

  # retrieve the basic config
  V = self._bert_config.vocab_size
  # sub: 10%, del + ins: 5%
  N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
  B, L = modeling.get_shape_list(inputs_tf)
  nlms = 0
  bilm = None
  if self._config.use_bilm:
    with open(self._config.bilm_file, 'rb') as f:
      bilm = tf.constant(np.load(f), tf.int32)
    _, nlms = modeling.get_shape_list(bilm)

  # make multiple partitions for edit op
  splits_list = []
  for i in range(B):
    one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
    one, _ = tf.unique(one)
    one = tf.cond(
        tf.less(tf.shape(one)[0], N * 2 + 1),
        lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
        lambda: tf.sort(tf.reshape(one[:N * 2 + 1], [1, N * 2 + 1])))
    splits_list.append(one[:, 2::2])
  splits_tf = tf.concat(splits_list, 0)
  splits_up = tf.concat(
      [splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
  splits_lo = tf.concat(
      [tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
  size_splits = splits_up - splits_lo

  # update the inputs and labels giving random insertion and deletion
  new_labels_list = []
  new_inputs_list = []
  for i in range(B):
    inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
    labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
    one_inputs = []
    one_labels = []
    size_split = len(inputs_splits)
    inputs_end = inputs_splits[-1]
    labels_end = labels_splits[-1]
    for j in range(size_split - 1):
      inputs = inputs_splits[j]
      labels = labels_splits[j]
      # label 1 for substitution
      rand_op = random.randint(2, self._config.num_preds - 1)
      if rand_op == 2:
        # label 2 for insertion
        if bilm is None:
          # noise
          insert_tok = tf.random.uniform([1], 1, V, tf.int32)
        else:
          # 2-gram prediction
          insert_tok = tf.expand_dims(
              bilm[inputs[-1], random.randint(0, nlms - 1)], 0)
        is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
        inputs = tf.cond(is_end_valid,
                         lambda: tf.concat([inputs, insert_tok], 0),
                         lambda: inputs)
        labels = tf.cond(is_end_valid,
                         lambda: tf.concat([labels, tf.constant([2])], 0),
                         lambda: labels)
        inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1],
                             lambda: inputs_end)
        labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1],
                             lambda: labels_end)
      elif rand_op == 3:
        # label 3 for deletion
        labels = tf.concat([labels[:-2], tf.constant([3])], 0)
        inputs = inputs[:-1]
        inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
        labels_end = tf.concat([labels_end, tf.constant([0])], 0)
      elif rand_op == 4:
        # label 4 for swapping
        labels = tf.concat([labels[:-1], tf.constant([4])], 0)
        inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
      one_labels.append(labels)
      one_inputs.append(inputs)
    one_inputs.append(inputs_end)
    one_labels.append(labels_end)
    one_inputs_tf = tf.concat(one_inputs, 0)
    one_labels_tf = tf.concat(one_labels, 0)
    one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1),
                            lambda: inputs_tf[i, :],
                            lambda: one_inputs_tf)
    one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1),
                            lambda: labels_tf[i, :],
                            lambda: one_labels_tf)
    new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
    new_labels_list.append(tf.expand_dims(one_labels_tf, 0))
  new_inputs_tf = tf.concat(new_inputs_list, 0)
  new_labels_tf = tf.concat(new_labels_list, 0)
  new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)

  updated_inputs = pretrain_data.get_updated_inputs(
      fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
  RicherData = collections.namedtuple(
      "RicherData", ["inputs", "is_fake_tokens", "sampled_tokens"])
  return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                    sampled_tokens=fake_data.sampled_tokens)
def __init__(self, config: configure_pretraining.PretrainingConfig, features,
             is_training):
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  self._teacher_config = training_utils.get_teacher_config(config)

  embedding_size = (
      self._bert_config.hidden_size if config.embedding_size is None
      else config.embedding_size)

  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  self._vocab = tokenizer.vocab
  self._inv_vocab = tokenizer.inv_vocab

  # Mask the input
  inputs = pretrain_data.features_to_inputs(features)
  old_model = self._build_transformer(inputs, is_training,
                                      embedding_size=embedding_size)
  input_states = old_model.get_sequence_output()
  input_states = tf.stop_gradient(input_states)

  teacher_output = self._build_teacher(input_states, inputs, is_training,
                                       embedding_size=embedding_size)
  # calculate the proposal distribution
  action_prob = teacher_output.action_probs  # pi(x_i)

  coin_toss = tf.random.uniform([])
  log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)

  if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
    random_masked_input = pretrain_helpers.mask(
        config, pretrain_data.features_to_inputs(features), config.mask_prob)
    B, L = modeling.get_shape_list(inputs.input_ids)
    N = config.max_predictions_per_seq
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
    l_strategy_prob = tf.tile(strategy_prob, [1, L])
    n_strategy_prob = tf.tile(strategy_prob, [1, N])
    mix_input_ids = (
        masked_inputs.input_ids * l_strategy_prob +
        random_masked_input.input_ids * (1 - l_strategy_prob))
    mix_masked_lm_positions = (
        masked_inputs.masked_lm_positions * n_strategy_prob +
        random_masked_input.masked_lm_positions * (1 - n_strategy_prob))
    mix_masked_lm_ids = (
        masked_inputs.masked_lm_ids * n_strategy_prob +
        random_masked_input.masked_lm_ids * (1 - n_strategy_prob))
    n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
    mix_masked_lm_weights = (
        masked_inputs.masked_lm_weights * n_strategy_prob +
        random_masked_input.masked_lm_weights * (1 - n_strategy_prob))
    mix_masked_inputs = pretrain_data.get_updated_inputs(
        inputs,
        input_ids=tf.stop_gradient(mix_input_ids),
        masked_lm_positions=mix_masked_lm_positions,
        masked_lm_ids=mix_masked_lm_ids,
        masked_lm_weights=mix_masked_lm_weights,
        tag_ids=inputs.tag_ids)
    masked_inputs = mix_masked_inputs

  # BERT model
  model = self._build_transformer(masked_inputs, is_training,
                                  reuse=tf.AUTO_REUSE,
                                  embedding_size=embedding_size)
  mlm_output = self._get_masked_lm_output(masked_inputs, model)
  self.total_loss = mlm_output.loss

  # Teacher reward is the -log p(x_S|x;B)
  reward = tf.stop_gradient(tf.reduce_mean(mlm_output.per_example_loss, 1))
  self._baseline = tf.reduce_mean(reward, -1)
  self._std = tf.math.reduce_std(reward, -1)

  # Calculate teacher loss
  def compute_teacher_loss(log_q, reward, baseline, std):
    advantage = tf.abs((reward - baseline) / std)
    advantage = tf.stop_gradient(advantage)
    log_q = tf.Print(log_q, [log_q], "log_q: ")
    teacher_loss = tf.reduce_mean(-log_q * advantage)
    return teacher_loss

  teacher_loss = tf.cond(
      coin_toss < 0.1,
      lambda: compute_teacher_loss(log_q, reward, self._baseline, self._std),
      lambda: tf.constant(0.0))
  self.total_loss = mlm_output.loss + teacher_loss
  self.teacher_loss = teacher_loss
  self.mlm_loss = mlm_output.loss

  # Evaluation
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  # Compute the loss and accuracy of the model.
  d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
  metrics = dict()
  metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
      labels=tf.reshape(d["masked_lm_ids"], [-1]),
      predictions=tf.reshape(d["masked_lm_preds"], [-1]),
      weights=tf.reshape(d["masked_lm_weights"], [-1]))
  metrics["masked_lm_loss"] = tf.metrics.mean(
      values=tf.reshape(d["mlm_loss"], [-1]),
      weights=tf.reshape(d["masked_lm_weights"], [-1]))
  self.eval_metrics = metrics