def test_get_prediction_loss_cosine(self):
        """Checks the shape and dtype of every cosine loss output."""
        # A batch of two example pairs; each side is a 6-dimensional vector.
        first_vectors = tf.constant(
            [[0.5, 0.7, 0.8, 0.9, 0.1, 0.1], [0.1, 0.3, 0.3, 0.3, 0.1, 0.1]],
            dtype=tf.float32)
        second_vectors = tf.constant(
            [[0.1, 0.2, 0.2, 0.2, 0.2, 0.1], [0.1, 0.4, 0.4, 0.4, 0.1, 0.1]],
            dtype=tf.float32)
        match_labels = tf.constant([0, 1.0], dtype=tf.float32)
        amplifier = 6.0
        negative_ratio = 1.0

        outputs = loss_fns.get_prediction_loss_cosine(
            input_tensor_1=first_vectors,
            input_tensor_2=second_vectors,
            labels=match_labels,
            similarity_score_amplifier=amplifier,
            neg_to_pos_example_ratio=negative_ratio)
        loss, per_example_loss, similarities = outputs

        # The scalar loss has an empty shape; the other two outputs hold one
        # value per example in the batch of two.
        expected_shapes = [(), (2, ), (2, )]

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer()])
            # Each tensor is evaluated with its own run call, in order.
            fetched_values = [
                sess.run(tensor)
                for tensor in (loss, per_example_loss, similarities)
            ]
            for value, shape in zip(fetched_values, expected_shapes):
                self.assertEqual(value.shape, shape)
                self.assertDTypeEqual(value, np.float32)
# Example #2
# 0
def _build_doc_model(doc_bert_config, is_training, input_reps, input_mask):
  """Creates the document level DocBertModel for one document.

  Args:
    doc_bert_config: the BertConfig of the document level model.
    is_training: bool. Forwarded to DocBertModel.
    input_reps: the sentence level representations fed to the model.
    input_mask: the document level mask that belongs to input_reps.

  Returns:
    The constructed modeling.DocBertModel instance.
  """
  return modeling.DocBertModel(
      config=doc_bert_config,
      is_training=is_training,
      input_reps=input_reps,
      input_mask=input_mask)


def _combine_doc_rep(combine_mode, sent_reps, doc_rep, bert_config,
                     encoder_config, is_training):
  """Builds the final representation of a single document.

  Supported modes:
  1. normal: only use the document level representation.
  2. sum_concat: sum the sentence level representations over axis 1 and
     concatenate the sum with the document level representation.
  3. mean_concat: same as sum_concat, but with the mean instead of the sum.
  4. attention: compute a weighted sum of the sentence level representations
     with layers.get_attention_weighted_sum and concatenate it with the
     document level representation.

  Args:
    combine_mode: the doc_rep_combine_mode value of the encoder config.
    sent_reps: the unmasked sentence level representations of the document.
      The caller documents its shape as
      [batch, max_doc_length_by_sentence, hidden].
    doc_rep: the document level representation of the same document.
    bert_config: the sentence level BertConfig. Only used in attention mode.
    encoder_config: the encoder config. Only doc_rep_combine_attention_size is
      read from it, and only in attention mode.
    is_training: bool. Only used in attention mode.

  Returns:
    doc_rep unchanged in normal mode. In every other mode the pooled sentence
    representation concatenated with doc_rep along axis 1.

  Raises:
    ValueError: if combine_mode is not one of the supported modes.
  """
  if combine_mode == constants.DOC_COMBINE_NORMAL:
    return doc_rep

  # NOTE(review): the pooling below runs over every sentence position of
  # sent_reps; the document level mask is not applied here. Confirm that
  # padded sentence positions are meant to take part in the sum/mean.
  if combine_mode == constants.DOC_COMBINE_SUM_CONCAT:
    pooled_sent_rep = tf.reduce_sum(sent_reps, 1)
  elif combine_mode == constants.DOC_COMBINE_MEAN_CONCAT:
    pooled_sent_rep = tf.reduce_mean(sent_reps, 1)
  elif combine_mode == constants.DOC_COMBINE_ATTENTION:
    pooled_sent_rep = layers.get_attention_weighted_sum(
        sent_reps, bert_config, is_training,
        encoder_config.doc_rep_combine_attention_size)
  else:
    raise ValueError("Only normal, sum_concat, mean_concat and attention are"
                     " supported: %s" % combine_mode)
  # Output Shape: [batch, 2*hidden].
  return tf.concat([pooled_sent_rep, doc_rep], axis=1)


def build_smith_dual_encoder(dual_encoder_config,
                             train_mode,
                             is_training,
                             input_ids_1,
                             input_mask_1,
                             masked_lm_positions_1,
                             masked_lm_ids_1,
                             masked_lm_weights_1,
                             input_ids_2,
                             input_mask_2,
                             masked_lm_positions_2,
                             masked_lm_ids_2,
                             masked_lm_weights_2,
                             use_one_hot_embeddings,
                             documents_match_labels,
                             debugging=False):
  """Build the dual encoder SMITH model.

  The same pipeline is applied to both input texts: sentence level
  representations are learned first, a document level DocBertModel is run on
  top of them (on sentence-masked inputs when the masked sentence LM loss is
  enabled), the final document representations are assembled according to
  doc_rep_combine_mode, and a cosine based matching loss is computed between
  the two final document representations.

  Args:
    dual_encoder_config: the configuration file for the dual encoder model.
    train_mode: string. The current train mode. It can be finetune, pretrain or
      joint_train.
    is_training: bool. Whether it is in training mode.
    input_ids_1: int Tensor with shape [batch, max_seq_length]. The input ids of
      input examples of text 1.
    input_mask_1: int Tensor with shape [batch, max_seq_length]. The input masks
      of input examples of text 1.
    masked_lm_positions_1: int Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction positions of
      input examples of text 1. This can be useful to compute the masked word
      prediction LM loss.
    masked_lm_ids_1: int Tensor with shape [batch, max_predictions_per_seq]. The
      input masked LM prediction ids of input examples of text 1. It is the
      ground truth in the masked word LM prediction task. This can be useful to
      compute the masked word prediction LM loss.
    masked_lm_weights_1: float Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction weights of input
      examples of text 1.
    input_ids_2: int Tensor with shape [batch, max_seq_length]. The input ids of
      input examples of text 2.
    input_mask_2: int Tensor with shape [batch, max_seq_length]. The input masks
      of input examples of text 2.
    masked_lm_positions_2: int Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction positions of
      input examples of text 2. This can be useful to compute the masked word
      prediction LM loss.
    masked_lm_ids_2: int Tensor with shape [batch, max_predictions_per_seq]. The
      input masked LM prediction ids of input examples of text 2. It is the
      ground truth in the masked word LM prediction task. This can be useful to
      compute the masked word prediction LM loss.
    masked_lm_weights_2: float Tensor with shape [batch,
      max_predictions_per_seq]. The input masked LM prediction weights of input
      examples of text 2.
    use_one_hot_embeddings: bool. Whether to use one hot embeddings.
    documents_match_labels: float Tensor with shape [batch]. The ground truth
      labels for the input examples.
    debugging: bool. Whether it is in the debugging mode.

  Returns:
    A tuple with 21 elements, in this order:
      masked_lm_loss_doc_1, masked_lm_loss_doc_2: the masked word LM losses of
        text 1 and text 2.
      masked_lm_example_loss_doc_1, masked_lm_example_loss_doc_2: the per
        example masked word LM losses.
      masked_lm_weights_doc_1, masked_lm_weights_doc_2: the masked word LM
        weights returned by layers.learn_sent_reps_normal_loop.
      masked_sent_lm_loss_1, masked_sent_lm_loss_2: the masked sentence LM
        losses. The Python int 0 when use_masked_sentence_lm_loss is disabled.
      masked_sent_per_example_loss_1, masked_sent_per_example_loss_2: the per
        example masked sentence LM losses. tf.zeros(1) when
        use_masked_sentence_lm_loss is disabled.
      masked_sent_weight_1, masked_sent_weight_2: the masked sentence weights.
        tf.zeros(1) when use_masked_sentence_lm_loss is disabled.
      final_doc_rep_1, final_doc_rep_2: the final document representations
        built according to doc_rep_combine_mode.
      input_sent_reps_doc_1_unmask, input_sent_reps_doc_2_unmask: the unmasked
        sentence level representations.
      output_sent_reps_doc_1, output_sent_reps_doc_2: the sequence outputs of
        the document level models.
      siamese_loss, siamese_example_loss, siamese_logits: the three values
        returned by loss_fns.get_prediction_loss_cosine for the two final
        document representations.

  Raises:
    ValueError: if the doc_rep_combine_mode in dual_encoder_config is invalid.
  """
  encoder_config = dual_encoder_config.encoder_config
  bert_config = modeling.BertConfig.from_json_file(
      encoder_config.bert_config_file)
  doc_bert_config = modeling.BertConfig.from_json_file(
      encoder_config.doc_bert_config_file)
  (input_sent_reps_doc_1_unmask, input_mask_doc_level_1_tensor,
   input_sent_reps_doc_2_unmask, input_mask_doc_level_2_tensor,
   masked_lm_loss_doc_1, masked_lm_loss_doc_2, masked_lm_example_loss_doc_1,
   masked_lm_example_loss_doc_2, masked_lm_weights_doc_1,
   masked_lm_weights_doc_2) = layers.learn_sent_reps_normal_loop(
       dual_encoder_config, is_training, train_mode, input_ids_1, input_mask_1,
       masked_lm_positions_1, masked_lm_ids_1, masked_lm_weights_1, input_ids_2,
       input_mask_2, masked_lm_positions_2, masked_lm_ids_2,
       masked_lm_weights_2, use_one_hot_embeddings)
  if debugging:
    # Attach a print op that shows both document level masks.
    input_mask_doc_level_1_tensor = tf.Print(
        input_mask_doc_level_1_tensor,
        [input_mask_doc_level_1_tensor, input_mask_doc_level_2_tensor],
        message="input_mask_doc_level_1_tensor in build_smith_dual_encoder",
        summarize=30)

  if encoder_config.use_masked_sentence_lm_loss:
    train_eval_config = dual_encoder_config.train_eval_config
    if is_training:
      batch_size_static = train_eval_config.train_batch_size
    else:
      batch_size_static = train_eval_config.eval_batch_size
    # Generates the sentence masked document representations.
    with tf.variable_scope("mask_sent_in_doc", reuse=tf.AUTO_REUSE):
      # Randomly initialize a masked sentence vector and reuse it.
      # We also need to return the masked sentence position index to get the
      # ground truth labels for the masked positions. The shape of
      # sent_mask_embedding is [hidden].
      sent_mask_embedding = tf.get_variable(
          name="sentence_mask_embedding",
          shape=[bert_config.hidden_size],
          initializer=tf.truncated_normal_initializer(
              stddev=bert_config.initializer_range))
      # Arguments shared by the masking calls of both documents.
      masking_kwargs = dict(
          sent_mask_embedding=sent_mask_embedding,
          batch_size_static=batch_size_static,
          max_masked_sent_per_doc=encoder_config.max_masked_sent_per_doc,
          loop_sent_number_per_doc=encoder_config.loop_sent_number_per_doc)
      # Output Shape: [batch, loop_sent_number_per_doc, hidden].
      (input_sent_reps_doc_1_masked, masked_sent_index_1,
       masked_sent_weight_1) = layers.get_doc_rep_with_masked_sent(
           input_sent_reps_doc=input_sent_reps_doc_1_unmask,
           input_mask_doc_level=input_mask_doc_level_1_tensor,
           **masking_kwargs)
      (input_sent_reps_doc_2_masked, masked_sent_index_2,
       masked_sent_weight_2) = layers.get_doc_rep_with_masked_sent(
           input_sent_reps_doc=input_sent_reps_doc_2_unmask,
           input_mask_doc_level=input_mask_doc_level_2_tensor,
           **masking_kwargs)
    # Learn the document representations based on masked sentence embeddings.
    # Note that the variables in the DocBert model are not within the
    # "mask_sent_in_doc" variable scope.
    model_doc_1 = _build_doc_model(doc_bert_config, is_training,
                                   input_sent_reps_doc_1_masked,
                                   input_mask_doc_level_1_tensor)
    model_doc_2 = _build_doc_model(doc_bert_config, is_training,
                                   input_sent_reps_doc_2_masked,
                                   input_mask_doc_level_2_tensor)
    # Shape of masked_sent_lm_loss_1 [1].
    # Shape of masked_sent_lm_example_loss_1 is [batch *
    # max_predictions_per_seq].
    (masked_sent_lm_loss_1, masked_sent_per_example_loss_1,
     _) = layers.get_masked_sent_lm_output(doc_bert_config,
                                           model_doc_1.get_sequence_output(),
                                           input_sent_reps_doc_1_unmask,
                                           masked_sent_index_1,
                                           masked_sent_weight_1)
    (masked_sent_lm_loss_2, masked_sent_per_example_loss_2,
     _) = layers.get_masked_sent_lm_output(doc_bert_config,
                                           model_doc_2.get_sequence_output(),
                                           input_sent_reps_doc_2_unmask,
                                           masked_sent_index_2,
                                           masked_sent_weight_2)
  else:
    # Learn the document representations based on unmasked sentence embeddings.
    model_doc_1 = _build_doc_model(doc_bert_config, is_training,
                                   input_sent_reps_doc_1_unmask,
                                   input_mask_doc_level_1_tensor)
    model_doc_2 = _build_doc_model(doc_bert_config, is_training,
                                   input_sent_reps_doc_2_unmask,
                                   input_mask_doc_level_2_tensor)
    # Placeholders so that the returned tuple has the same layout whether or
    # not the masked sentence LM loss is enabled.
    masked_sent_lm_loss_1 = 0
    masked_sent_lm_loss_2 = 0
    masked_sent_per_example_loss_1 = tf.zeros(1)
    masked_sent_per_example_loss_2 = tf.zeros(1)
    masked_sent_weight_1 = tf.zeros(1)
    masked_sent_weight_2 = tf.zeros(1)

  with tf.variable_scope("seq_rep_from_bert_doc_dense", reuse=tf.AUTO_REUSE):
    normalized_doc_rep_1 = layers.get_seq_rep_from_bert(model_doc_1)
    normalized_doc_rep_2 = layers.get_seq_rep_from_bert(model_doc_2)

    # We also dump the contextualized sentence embedding output by document
    # level Transformer model. These representations may be useful for
    # sentence level tasks.
    output_sent_reps_doc_1 = model_doc_1.get_sequence_output()
    output_sent_reps_doc_2 = model_doc_2.get_sequence_output()

  # Multiple modes are supported to generate the final document
  # representations from the sentence/document level representations; see
  # _combine_doc_rep for the list. The document level mask indicates whether
  # each sentence is a real sentence (1) or a padded sentence (0). The shape of
  # input_mask_doc_level_1_tensor is [batch, max_doc_length_by_sentence]. The
  # shape of input_sent_reps_doc_1_unmask is
  # [batch, max_doc_length_by_sentence, hidden].
  final_doc_rep_combine_mode = encoder_config.doc_rep_combine_mode
  final_doc_rep_1 = _combine_doc_rep(
      final_doc_rep_combine_mode, input_sent_reps_doc_1_unmask,
      normalized_doc_rep_1, bert_config, encoder_config, is_training)
  final_doc_rep_2 = _combine_doc_rep(
      final_doc_rep_combine_mode, input_sent_reps_doc_2_unmask,
      normalized_doc_rep_2, bert_config, encoder_config, is_training)

  (siamese_loss, siamese_example_loss,
   siamese_logits) = loss_fns.get_prediction_loss_cosine(
       input_tensor_1=final_doc_rep_1,
       input_tensor_2=final_doc_rep_2,
       labels=documents_match_labels,
       similarity_score_amplifier=dual_encoder_config.loss_config
       .similarity_score_amplifier,
       neg_to_pos_example_ratio=dual_encoder_config.train_eval_config
       .neg_to_pos_example_ratio)

  # The shape of masked_lm_loss_doc is [1].
  # The shape of masked_lm_example_loss_doc is [batch * max_predictions_per_seq,
  # max_doc_length_by_sentence].
  return (masked_lm_loss_doc_1, masked_lm_loss_doc_2,
          masked_lm_example_loss_doc_1, masked_lm_example_loss_doc_2,
          masked_lm_weights_doc_1, masked_lm_weights_doc_2,
          masked_sent_lm_loss_1, masked_sent_lm_loss_2,
          masked_sent_per_example_loss_1, masked_sent_per_example_loss_2,
          masked_sent_weight_1, masked_sent_weight_2, final_doc_rep_1,
          final_doc_rep_2, input_sent_reps_doc_1_unmask,
          input_sent_reps_doc_2_unmask, output_sent_reps_doc_1,
          output_sent_reps_doc_2, siamese_loss, siamese_example_loss,
          siamese_logits)