def test_get_prediction_loss_cosine(self):
  """Checks output shapes and dtypes of loss_fns.get_prediction_loss_cosine.

  Builds a batch of two pairs of 6-dim vectors, labelled 0 and 1, and verifies
  that the function returns a scalar float32 loss plus float32 per-example loss
  and similarity vectors of shape [batch].
  """
  input_tensor_1 = tf.constant(
      [[0.5, 0.7, 0.8, 0.9, 0.1, 0.1], [0.1, 0.3, 0.3, 0.3, 0.1, 0.1]],
      dtype=tf.float32)
  input_tensor_2 = tf.constant(
      [[0.1, 0.2, 0.2, 0.2, 0.2, 0.1], [0.1, 0.4, 0.4, 0.4, 0.1, 0.1]],
      dtype=tf.float32)
  labels = tf.constant([0, 1.0], dtype=tf.float32)
  neg_to_pos_example_ratio = 1.0
  similarity_score_amplifier = 6.0
  loss, per_example_loss, similarities = (
      loss_fns.get_prediction_loss_cosine(
          input_tensor_1=input_tensor_1,
          input_tensor_2=input_tensor_2,
          labels=labels,
          similarity_score_amplifier=similarity_score_amplifier,
          neg_to_pos_example_ratio=neg_to_pos_example_ratio))
  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer()])
    # Fetch every output in a single run: the graph is executed once, and all
    # three values are guaranteed to come from the same execution.
    loss_numpy, per_example_loss_numpy, similarities_numpy = sess.run(
        [loss, per_example_loss, similarities])
  self.assertEqual(loss_numpy.shape, ())
  self.assertDTypeEqual(loss_numpy, np.float32)
  self.assertEqual(per_example_loss_numpy.shape, (2,))
  self.assertDTypeEqual(per_example_loss_numpy, np.float32)
  self.assertEqual(similarities_numpy.shape, (2,))
  self.assertDTypeEqual(similarities_numpy, np.float32)
def build_smith_dual_encoder(dual_encoder_config,
                             train_mode,
                             is_training,
                             input_ids_1,
                             input_mask_1,
                             masked_lm_positions_1,
                             masked_lm_ids_1,
                             masked_lm_weights_1,
                             input_ids_2,
                             input_mask_2,
                             masked_lm_positions_2,
                             masked_lm_ids_2,
                             masked_lm_weights_2,
                             use_one_hot_embeddings,
                             documents_match_labels,
                             debugging=False):
  """Build the dual encoder SMITH model.

  Both input texts go through the same pipeline:
    1. Sentence representations (and the word-level masked LM losses) are
       produced by layers.learn_sent_reps_normal_loop.
    2. Optionally, when encoder_config.use_masked_sentence_lm_loss is set,
       some sentence representations are replaced by a shared, learned
       sentence-mask embedding. This produces a masked sentence LM loss.
    3. A document-level modeling.DocBertModel runs over the sentence
       representations.
    4. The document-level output is optionally combined with the sentence
       representations, according to encoder_config.doc_rep_combine_mode.
  Finally, the two document representations are scored against
  documents_match_labels with loss_fns.get_prediction_loss_cosine.

  Args:
    dual_encoder_config: the configuration for the dual encoder model. Its
      encoder_config, train_eval_config and loss_config attributes are read.
    train_mode: string. The current training mode. It can be finetune,
      pretrain or joint_train. Passed through to
      layers.learn_sent_reps_normal_loop.
    is_training: bool. Whether the model is being built for training. It also
      selects train_batch_size (True) or eval_batch_size (False) as the static
      batch size used for sentence masking.
    input_ids_1: int Tensor with shape [batch, max_seq_length]. The input ids
      of input examples of text 1.
    input_mask_1: int Tensor with shape [batch, max_seq_length]. The input
      masks of input examples of text 1.
    masked_lm_positions_1: int Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction positions of
      input examples of text 1. This can be useful to compute the masked word
      prediction LM loss.
    masked_lm_ids_1: int Tensor with shape [batch, max_predictions_per_seq].
      The input masked LM prediction ids of input examples of text 1. It is the
      ground truth in the masked word LM prediction task. This can be useful to
      compute the masked word prediction LM loss.
    masked_lm_weights_1: float Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction weights of input
      examples of text 1.
    input_ids_2: int Tensor with shape [batch, max_seq_length]. The input ids
      of input examples of text 2.
    input_mask_2: int Tensor with shape [batch, max_seq_length]. The input
      masks of input examples of text 2.
    masked_lm_positions_2: int Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction positions of
      input examples of text 2. This can be useful to compute the masked word
      prediction LM loss.
    masked_lm_ids_2: int Tensor with shape [batch, max_predictions_per_seq].
      The input masked LM prediction ids of input examples of text 2. It is the
      ground truth in the masked word LM prediction task. This can be useful to
      compute the masked word prediction LM loss.
    masked_lm_weights_2: float Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction weights of input
      examples of text 2.
    use_one_hot_embeddings: bool. Whether to use one-hot embeddings. Passed
      through to layers.learn_sent_reps_normal_loop.
    documents_match_labels: float Tensor with shape [batch]. The ground truth
      match labels for the input document pairs.
    debugging: bool. If True, input_mask_doc_level_1_tensor is wrapped in
      tf.Print. Both document-level masks are then printed whenever it is
      evaluated.

  Returns:
    A 21-tuple, in this order:
      masked_lm_loss_doc_1, masked_lm_loss_doc_2: word-level masked LM losses
        for text 1 / text 2, from layers.learn_sent_reps_normal_loop.
      masked_lm_example_loss_doc_1, masked_lm_example_loss_doc_2: per-example
        word-level masked LM losses, from the same call.
      masked_lm_weights_doc_1, masked_lm_weights_doc_2: word-level masked LM
        weights, from the same call.
      masked_sent_lm_loss_1, masked_sent_lm_loss_2: masked sentence LM losses
        from layers.get_masked_sent_lm_output, or the Python int 0 when
        encoder_config.use_masked_sentence_lm_loss is False.
      masked_sent_per_example_loss_1, masked_sent_per_example_loss_2:
        per-example masked sentence LM losses, or tf.zeros(1) when disabled.
      masked_sent_weight_1, masked_sent_weight_2: masked sentence weights from
        layers.get_doc_rep_with_masked_sent, or tf.zeros(1) when disabled.
      final_doc_rep_1, final_doc_rep_2: the final document representations,
        built according to encoder_config.doc_rep_combine_mode.
      input_sent_reps_doc_1_unmask, input_sent_reps_doc_2_unmask: the sentence
        representations before any sentence masking.
      output_sent_reps_doc_1, output_sent_reps_doc_2: the sequence output of
        the document-level DocBertModel, i.e. contextualized sentence
        representations.
      siamese_loss, siamese_example_loss, siamese_logits: the three outputs of
        loss_fns.get_prediction_loss_cosine on the final document
        representations. These are the loss, the per-example loss and
        (presumably, despite the "logits" name) per-example similarity
        scores. TODO confirm the third output against loss_fns.

  Raises:
    ValueError: if the doc_rep_combine_mode in dual_encoder_config is invalid.
      The check happens only after the sentence-level and document-level
      encoder subgraphs have already been added to the graph.
  """
  bert_config = modeling.BertConfig.from_json_file(
      dual_encoder_config.encoder_config.bert_config_file)
  doc_bert_config = modeling.BertConfig.from_json_file(
      dual_encoder_config.encoder_config.doc_bert_config_file)
  # Sentence-level encoding of both texts, plus the word-level masked LM
  # outputs.
  (input_sent_reps_doc_1_unmask, input_mask_doc_level_1_tensor,
   input_sent_reps_doc_2_unmask, input_mask_doc_level_2_tensor,
   masked_lm_loss_doc_1, masked_lm_loss_doc_2, masked_lm_example_loss_doc_1,
   masked_lm_example_loss_doc_2, masked_lm_weights_doc_1,
   masked_lm_weights_doc_2) = layers.learn_sent_reps_normal_loop(
       dual_encoder_config, is_training, train_mode, input_ids_1, input_mask_1,
       masked_lm_positions_1, masked_lm_ids_1, masked_lm_weights_1,
       input_ids_2, input_mask_2, masked_lm_positions_2, masked_lm_ids_2,
       masked_lm_weights_2, use_one_hot_embeddings)
  if debugging:
    # tf.Print is an identity op with a print side effect. Both masks are
    # printed only when the rewrapped mask-1 tensor is evaluated.
    input_mask_doc_level_1_tensor = tf.Print(
        input_mask_doc_level_1_tensor,
        [input_mask_doc_level_1_tensor, input_mask_doc_level_2_tensor],
        message="input_mask_doc_level_1_tensor in build_smith_dual_encoder",
        summarize=30)
  if dual_encoder_config.encoder_config.use_masked_sentence_lm_loss:
    # get_doc_rep_with_masked_sent needs a static batch size, so it is taken
    # from the config rather than from the tensors.
    batch_size_static = (
        dual_encoder_config.train_eval_config.train_batch_size
        if is_training else dual_encoder_config.train_eval_config.eval_batch_size)
    # Generates the sentence-masked document representations.
    with tf.variable_scope("mask_sent_in_doc", reuse=tf.AUTO_REUSE):
      # Randomly initialize a masked sentence vector and reuse it.
      # We also need to return the masked sentence position index to get the
      # ground truth labels for the masked positions. The shape of
      # sent_mask_embedding is [hidden].
      sent_mask_embedding = tf.get_variable(
          name="sentence_mask_embedding",
          shape=[bert_config.hidden_size],
          initializer=tf.truncated_normal_initializer(
              stddev=bert_config.initializer_range))
      # Output Shape: [batch, loop_sent_number_per_doc, hidden].
      (input_sent_reps_doc_1_masked, masked_sent_index_1,
       masked_sent_weight_1) = layers.get_doc_rep_with_masked_sent(
           input_sent_reps_doc=input_sent_reps_doc_1_unmask,
           sent_mask_embedding=sent_mask_embedding,
           input_mask_doc_level=input_mask_doc_level_1_tensor,
           batch_size_static=batch_size_static,
           max_masked_sent_per_doc=dual_encoder_config.encoder_config
           .max_masked_sent_per_doc,
           loop_sent_number_per_doc=dual_encoder_config.encoder_config
           .loop_sent_number_per_doc)
      # The same sent_mask_embedding variable is shared by both texts.
      (input_sent_reps_doc_2_masked, masked_sent_index_2,
       masked_sent_weight_2) = layers.get_doc_rep_with_masked_sent(
           input_sent_reps_doc=input_sent_reps_doc_2_unmask,
           sent_mask_embedding=sent_mask_embedding,
           input_mask_doc_level=input_mask_doc_level_2_tensor,
           batch_size_static=batch_size_static,
           max_masked_sent_per_doc=dual_encoder_config.encoder_config
           .max_masked_sent_per_doc,
           loop_sent_number_per_doc=dual_encoder_config.encoder_config
           .loop_sent_number_per_doc)
    # Learn the document representations based on masked sentence embeddings.
    # Note that the variables in the DocBert model are not within the
    # "mask_sent_in_doc" variable scope.
    model_doc_1 = modeling.DocBertModel(
        config=doc_bert_config,
        is_training=is_training,
        input_reps=input_sent_reps_doc_1_masked,
        input_mask=input_mask_doc_level_1_tensor)
    model_doc_2 = modeling.DocBertModel(
        config=doc_bert_config,
        is_training=is_training,
        input_reps=input_sent_reps_doc_2_masked,
        input_mask=input_mask_doc_level_2_tensor)
    # Shape of masked_sent_lm_loss_1 is [1].
    # Shape of masked_sent_per_example_loss_1 is [batch *
    # max_predictions_per_seq].
    # The unmasked sentence representations are passed in as the prediction
    # targets for the masked sentence positions.
    (masked_sent_lm_loss_1, masked_sent_per_example_loss_1,
     _) = layers.get_masked_sent_lm_output(doc_bert_config,
                                          model_doc_1.get_sequence_output(),
                                          input_sent_reps_doc_1_unmask,
                                          masked_sent_index_1,
                                          masked_sent_weight_1)
    (masked_sent_lm_loss_2, masked_sent_per_example_loss_2,
     _) = layers.get_masked_sent_lm_output(doc_bert_config,
                                          model_doc_2.get_sequence_output(),
                                          input_sent_reps_doc_2_unmask,
                                          masked_sent_index_2,
                                          masked_sent_weight_2)
  else:
    # Learn the document representations based on unmasked sentence embeddings.
    model_doc_1 = modeling.DocBertModel(
        config=doc_bert_config,
        is_training=is_training,
        input_reps=input_sent_reps_doc_1_unmask,
        input_mask=input_mask_doc_level_1_tensor)
    model_doc_2 = modeling.DocBertModel(
        config=doc_bert_config,
        is_training=is_training,
        input_reps=input_sent_reps_doc_2_unmask,
        input_mask=input_mask_doc_level_2_tensor)
    # Placeholder outputs so the return tuple keeps the same arity when the
    # masked sentence LM loss is disabled.
    # NOTE(review): the losses are the Python int 0, not tensors as in the
    # other branch. Callers that expect a tf.Tensor here should confirm
    # this is acceptable.
    masked_sent_lm_loss_1 = 0
    masked_sent_lm_loss_2 = 0
    masked_sent_per_example_loss_1 = tf.zeros(1)
    masked_sent_per_example_loss_2 = tf.zeros(1)
    masked_sent_weight_1 = tf.zeros(1)
    masked_sent_weight_2 = tf.zeros(1)

  # Both documents share the same dense layer variables through AUTO_REUSE.
  with tf.variable_scope("seq_rep_from_bert_doc_dense", reuse=tf.AUTO_REUSE):
    normalized_doc_rep_1 = layers.get_seq_rep_from_bert(model_doc_1)
    normalized_doc_rep_2 = layers.get_seq_rep_from_bert(model_doc_2)

    # We also dump the contextualized sentence embeddings output by the
    # document-level Transformer model. These representations may be useful
    # for sentence-level tasks.
  output_sent_reps_doc_1 = model_doc_1.get_sequence_output()
  output_sent_reps_doc_2 = model_doc_2.get_sequence_output()

  # Here we support multiple modes to generate the final document
  # representations based on the word/sentence/document level representations:
  # 1. normal: only use the document level representation as the final document
  # representation.
  # 2. sum_concat: first compute the sum of all sentence level representations.
  # Then concatenate the sum vector with the document level representation.
  # 3. mean_concat: first compute the mean of all sentence level
  # representations. Then concatenate the mean vector with the document level
  # representation.
  # 4. attention: first compute the weighted sum of sentence level
  # representations with an attention mechanism, then concatenate the weighted
  # sum vector with the document level representation.
  # The document level mask indicates whether each sentence is a real
  # sentence (1) or a padded sentence (0). The shape of
  # input_mask_doc_level_1_tensor is [batch, max_doc_length_by_sentence]. The
  # shape of input_sent_reps_doc_1_unmask is
  # [batch, max_doc_length_by_sentence, hidden].
  # NOTE(review): none of the modes below applies the document level mask.
  # reduce_sum/reduce_mean run over every sentence position, including padded
  # ones, and reduce_mean divides by max_doc_length_by_sentence rather than by
  # the real sentence count. No mask is passed to get_attention_weighted_sum
  # either. Confirm this is intended.
  final_doc_rep_combine_mode = dual_encoder_config.encoder_config.doc_rep_combine_mode
  if final_doc_rep_combine_mode == constants.DOC_COMBINE_NORMAL:
    final_doc_rep_1 = normalized_doc_rep_1
    final_doc_rep_2 = normalized_doc_rep_2
  elif final_doc_rep_combine_mode == constants.DOC_COMBINE_SUM_CONCAT:
    # Output Shape: [batch, 2*hidden].
    final_doc_rep_1 = tf.concat(
        [tf.reduce_sum(input_sent_reps_doc_1_unmask, 1), normalized_doc_rep_1],
        axis=1)
    final_doc_rep_2 = tf.concat(
        [tf.reduce_sum(input_sent_reps_doc_2_unmask, 1), normalized_doc_rep_2],
        axis=1)
  elif final_doc_rep_combine_mode == constants.DOC_COMBINE_MEAN_CONCAT:
    final_doc_rep_1 = tf.concat(
        [tf.reduce_mean(input_sent_reps_doc_1_unmask, 1), normalized_doc_rep_1],
        axis=1)
    final_doc_rep_2 = tf.concat(
        [tf.reduce_mean(input_sent_reps_doc_2_unmask, 1), normalized_doc_rep_2],
        axis=1)
  elif final_doc_rep_combine_mode == constants.DOC_COMBINE_ATTENTION:
    final_doc_rep_1 = tf.concat([
        layers.get_attention_weighted_sum(
            input_sent_reps_doc_1_unmask, bert_config, is_training,
            dual_encoder_config.encoder_config.doc_rep_combine_attention_size),
        normalized_doc_rep_1
    ], axis=1)
    final_doc_rep_2 = tf.concat([
        layers.get_attention_weighted_sum(
            input_sent_reps_doc_2_unmask, bert_config, is_training,
            dual_encoder_config.encoder_config.doc_rep_combine_attention_size),
        normalized_doc_rep_2
    ], axis=1)
  else:
    raise ValueError("Only normal, sum_concat, mean_concat and attention are"
                     " supported: %s" % final_doc_rep_combine_mode)
  # Document-pair matching loss computed on the final representations.
  (siamese_loss, siamese_example_loss,
   siamese_logits) = loss_fns.get_prediction_loss_cosine(
       input_tensor_1=final_doc_rep_1,
       input_tensor_2=final_doc_rep_2,
       labels=documents_match_labels,
       similarity_score_amplifier=dual_encoder_config.loss_config
       .similarity_score_amplifier,
       neg_to_pos_example_ratio=dual_encoder_config.train_eval_config
       .neg_to_pos_example_ratio)

  # The shape of masked_lm_loss_doc is [1].
  # The shape of masked_lm_example_loss_doc is [batch * max_predictions_per_seq,
  # max_doc_length_by_sentence].
  return (masked_lm_loss_doc_1, masked_lm_loss_doc_2,
          masked_lm_example_loss_doc_1, masked_lm_example_loss_doc_2,
          masked_lm_weights_doc_1, masked_lm_weights_doc_2,
          masked_sent_lm_loss_1, masked_sent_lm_loss_2,
          masked_sent_per_example_loss_1, masked_sent_per_example_loss_2,
          masked_sent_weight_1, masked_sent_weight_2, final_doc_rep_1,
          final_doc_rep_2, input_sent_reps_doc_1_unmask,
          input_sent_reps_doc_2_unmask, output_sent_reps_doc_1,
          output_sent_reps_doc_2, siamese_loss, siamese_example_loss,
          siamese_logits)