def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
  """Get loss and log probs for the masked LM."""
  # Output Shape: [batch * max_predictions_per_seq, hidden].
  input_tensor = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/word_predictions", reuse=tf.AUTO_REUSE):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      # Output Shape: [batch * max_predictions_per_seq, hidden].
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    # Shape of input_tensor is [batch * max_predictions_per_seq,
    # embedding_size].
    # Shape of output_weights (embedding table) is [vocab_size,
    # embedding_size].
    # In the current BERT implementation, embedding_size = hidden.
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    # Output Shape: [batch * max_predictions_per_seq, vocab_size].
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Output Shape: [batch * max_predictions_per_seq].
    label_ids = tf.reshape(label_ids, [-1])
    # Output Shape: [batch * max_predictions_per_seq].
    label_weights = tf.reshape(label_weights, [-1])
    # Output Shape: [batch * max_predictions_per_seq, vocab_size].
    one_hot_labels = tf.one_hot(
        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    # Output Shape: [batch * max_predictions_per_seq].
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    # Output Shape: scalar.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    # Output Shape: scalar.
    denominator = tf.reduce_sum(label_weights) + 1e-5
    # Output Shape: scalar.
    loss = numerator / denominator

  # Shape of loss: scalar.
  # Shape of per_example_loss is [batch * max_predictions_per_seq].
  return (loss, per_example_loss, log_probs)
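
# The block below is an illustrative sketch only (hypothetical toy values, not
# part of the model and never called by it). It shows how the weighted average
# computed above, numerator / denominator, ignores padded prediction slots:
# positions whose label weight is 0.0 contribute nothing to the loss, and the
# small 1e-5 term only guards against division by zero when every slot is
# padding.
def _demo_weighted_masked_lm_loss():
  """Sketch of the weighted masked LM loss average on toy values."""
  import numpy as np
  per_example_loss = np.array([2.0, 1.0, 3.0, 0.7])  # Last slot is padding.
  label_weights = np.array([1.0, 1.0, 1.0, 0.0])     # 0.0 masks the padding.
  numerator = np.sum(label_weights * per_example_loss)  # 6.0
  denominator = np.sum(label_weights) + 1e-5            # ~3.0
  return numerator / denominator                        # ~2.0
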
def get_masked_sent_lm_output(bert_config,
                              input_tensor,
                              cur_sent_reps_doc_unmask,
                              sent_masked_positions,
                              sent_masked_weights,
                              debugging=False):
  """Get the sentence level masked LM loss.

  Args:
    bert_config: BertConfig object. The configuration file for the document
      level BERT model.
    input_tensor: float Tensor. The contextualized representations of all
      sentences learned by the document level BERT model. The shape is
      [batch, loop_sent_number_per_doc, hidden]. This is the model prediction.
    cur_sent_reps_doc_unmask: float Tensor. The unmasked sentence
      representations of the current document. The shape is
      [batch, loop_sent_number_per_doc, hidden]. This is the source of the
      ground truth and negative examples in the masked sentence prediction.
    sent_masked_positions: int Tensor. The masked sentence positions in the
      current document. The shape is [batch, max_masked_sent_per_doc].
    sent_masked_weights: float Tensor. The masked sentence weights in the
      current document. The shape is [batch, max_masked_sent_per_doc].
    debugging: bool. Whether it is in the debugging mode.

  Returns:
    The masked sentence LM loss, the masked sentence LM loss per example, and
    the log probabilities.
  """
  # The current method for masked sentence prediction: we approach this problem
  # as a multi-class classification problem, similar to the masked word LM
  # task. For each masked sentence position, the sentence in the current
  # position is the positive example. The other co-masked sentences in the
  # current document and in the other documents of the same batch are the
  # negative examples. We compute the cross entropy loss over the sentence
  # prediction task following the implementation of the masked word LM loss in
  # the BERT model.

  input_tensor_shape = modeling.get_shape_list(input_tensor)
  batch_size = input_tensor_shape[0]
  masked_position_shape = modeling.get_shape_list(sent_masked_positions)
  max_predictions_per_seq = masked_position_shape[1]

  # In the context of masked sentence prediction, max_predictions_per_seq is
  # the same as max_masked_sent_per_doc.
  # Output Shape: [batch * max_predictions_per_seq, hidden].
  # input_tensor is the model prediction for each position.
  input_tensor = gather_indexes(input_tensor, sent_masked_positions)
  # independent_sent_embeddings is the ground truth input sentence embeddings
  # for the document level BERT model. The output shape is
  # [batch * max_predictions_per_seq, hidden].
  independent_sent_embeddings = gather_indexes(cur_sent_reps_doc_unmask,
                                               sent_masked_positions)

  with tf.variable_scope("cls/sent_predictions", reuse=tf.AUTO_REUSE):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      # Output Shape: [batch * max_predictions_per_seq, hidden].
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each predicted position.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[batch_size * max_predictions_per_seq],
        initializer=tf.zeros_initializer())
    # Shape of input_tensor is [batch * max_predictions_per_seq, hidden].
    # Shape of independent_sent_embeddings is
    # [batch * max_predictions_per_seq, hidden].
    # Shape of logits is
    # [batch * max_predictions_per_seq, batch * max_predictions_per_seq].
    logits = tf.matmul(
        input_tensor, independent_sent_embeddings, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    # Output Shape:
    # [batch * max_predictions_per_seq, batch * max_predictions_per_seq].
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Output Shape: [batch * max_predictions_per_seq].
    # The label_ids are the label indices in the "sentence vocabulary". Thus,
    # if batch = 32 and max_predictions_per_seq = 2, the label ids are
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 63]. In the ground truth one-hot
    # label matrix, only the diagonal positions are 1; all other positions
    # are 0.
    label_ids = tf.range(
        0, batch_size * max_predictions_per_seq, dtype=tf.int32)
    if debugging:
      label_ids = tf.Print(
          label_ids, [label_ids],
          message="label_ids in get_masked_sent_lm_output",
          summarize=30)
    # Output Shape: [batch * max_predictions_per_seq].
    # label_weights is the flattened vector of sent_masked_weights, where the
    # weight is 1.0 for sampled real sentences and 0.0 for padding positions.
    label_weights = tf.reshape(sent_masked_weights, [-1])
    # Output Shape:
    # [batch * max_predictions_per_seq, batch * max_predictions_per_seq].
    one_hot_labels = tf.one_hot(
        label_ids,
        depth=batch_size * max_predictions_per_seq,
        dtype=tf.float32)

    # Output Shape: [batch * max_predictions_per_seq].
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    # Output Shape: scalar.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    # Output Shape: scalar.
    denominator = tf.reduce_sum(label_weights) + 1e-5
    # Output Shape: scalar.
    loss = numerator / denominator

  # Shape of loss: scalar.
  # Shape of per_example_loss is [batch * max_predictions_per_seq].
  return (loss, per_example_loss, log_probs)
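
# The block below is an illustrative sketch only (hypothetical toy sizes, not
# part of the model and never called by it). It shows the label construction
# used above: label_ids is simply range(batch * max_predictions_per_seq), so
# the one-hot label matrix is the identity. Each masked position's positive
# target is its own unmasked sentence embedding, and every other row (the
# co-masked sentences from the same and other documents in the batch) serves
# as a negative example.
def _demo_sent_lm_label_matrix(batch_size=2, max_masked_sent_per_doc=3):
  """Sketch of the identity one-hot label matrix for the masked sentence LM."""
  num_targets = batch_size * max_masked_sent_per_doc
  label_ids = tf.range(0, num_targets, dtype=tf.int32)  # [0, 1, 2, 3, 4, 5]
  # Shape: [num_targets, num_targets]; equal to the identity matrix.
  one_hot_labels = tf.one_hot(label_ids, depth=num_targets, dtype=tf.float32)
  return one_hot_labels
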