def create_model(self):
  input_ids = BertModelTest.ids_tensor(
      [self.batch_size, self.seq_length], self.vocab_size)

  input_mask = None
  if self.use_input_mask:
    input_mask = BertModelTest.ids_tensor(
        [self.batch_size, self.seq_length], vocab_size=2)

  token_type_ids = None
  if self.use_token_type_ids:
    token_type_ids = BertModelTest.ids_tensor(
        [self.batch_size, self.seq_length], self.type_vocab_size)

  config = modeling.BertConfig(
      vocab_size=self.vocab_size,
      hidden_size=self.hidden_size,
      num_hidden_layers=self.num_hidden_layers,
      num_attention_heads=self.num_attention_heads,
      intermediate_size=self.intermediate_size,
      hidden_act=self.hidden_act,
      hidden_dropout_prob=self.hidden_dropout_prob,
      attention_probs_dropout_prob=self.attention_probs_dropout_prob,
      max_position_embeddings=self.max_position_embeddings,
      type_vocab_size=self.type_vocab_size,
      initializer_range=self.initializer_range)

  model = modeling.BertModel(config=config,
                             is_training=self.is_training,
                             input_ids=input_ids,
                             input_mask=input_mask,
                             token_type_ids=token_type_ids,
                             scope=self.scope)

  outputs = {
      "embedding_output": model.get_embedding_output(),
      "sequence_output": model.get_sequence_output(),
      "pooled_output": model.get_pooled_output(),
      "all_encoder_layers": model.get_all_encoder_layers(),
  }
  return outputs
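
A minimal usage sketch, not part of the original test file: run the graph built by create_model in a TF1 session and check the sequence output shape. The tester instance and its attribute names are assumptions for illustration.

# Hypothetical usage: tester is an instance of the test helper class whose
# attributes (batch_size, seq_length, hidden_size, ...) are already set.
outputs = tester.create_model()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sequence_output = sess.run(outputs["sequence_output"])
  # BERT's sequence output has shape [batch_size, seq_length, hidden_size].
  assert sequence_output.shape == (tester.batch_size, tester.seq_length,
                                   tester.hidden_size)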
Example #2
def get_sent_reps_masks_normal_loop(sent_index,
                                    input_sent_reps_doc,
                                    input_mask_doc_level,
                                    masked_lm_loss_doc,
                                    masked_lm_example_loss_doc,
                                    masked_lm_weights_doc,
                                    dual_encoder_config,
                                    is_training,
                                    train_mode,
                                    input_ids,
                                    input_mask,
                                    masked_lm_positions,
                                    masked_lm_ids,
                                    masked_lm_weights,
                                    use_one_hot_embeddings,
                                    debugging=False):
  """Get the sentence encodings, mask ids and masked word LM loss.

  Args:
      sent_index: The index of the current looped sentence.
      input_sent_reps_doc: The representations of all sentences in the doc
        learned by BERT.
      input_mask_doc_level: The document level input masks, which indicate
        whether each sentence is a real sentence or a padded sentence.
      masked_lm_loss_doc: The sum of all the masked word LM losses.
      masked_lm_example_loss_doc: The per example masked word LM loss.
      masked_lm_weights_doc: The weights of the masked LM words. If the
        position corresponds to a real masked word, the weight is 1.0; if it is
        a padded mask position, the weight is 0.0.
      dual_encoder_config: The config of the dual encoder.
      is_training: Whether it is in the training mode.
      train_mode: string. The training mode, which can be finetune,
        joint_train, or pretrain.
      input_ids: The ids of the input tokens.
      input_mask: The mask of the input tokens.
      masked_lm_positions: The positions of the masked words in the language
        model training.
      masked_lm_ids: The ids of the masked words in LM training.
      masked_lm_weights: The weights of the masked words in LM training.
      use_one_hot_embeddings: Whether to use one-hot embeddings. It should be
        true for runs on TPUs.
      debugging: bool. Whether it is in the debugging mode.

  Returns:
    A tuple of the learned sentence representations, the document level input
    masks, and the masked word LM losses and weights.
  """
  # Collect token information for the current sentence.
  bert_config = modeling.BertConfig.from_json_file(
      dual_encoder_config.encoder_config.bert_config_file)
  max_sent_length_by_word = dual_encoder_config.encoder_config.max_sent_length_by_word
  sent_bert_trainable = dual_encoder_config.encoder_config.sent_bert_trainable
  max_predictions_per_seq = dual_encoder_config.encoder_config.max_predictions_per_seq
  sent_start = sent_index * max_sent_length_by_word
  input_ids_cur_sent = tf.slice(input_ids, [0, sent_start],
                                [-1, max_sent_length_by_word])
  # Output shape: [batch, max_sent_length_by_word].
  input_mask_cur_sent = tf.slice(input_mask, [0, sent_start],
                                 [-1, max_sent_length_by_word])
  # Output Shape:  [batch].
  input_mask_cur_sent_max = tf.reduce_max(input_mask_cur_sent, 1)
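  # A padded sentence has an all-zero token mask, so its max over tokens is 0;
  # any real token makes the max 1, marking the sentence as a real sentence.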
  # Output Shape:  [loop_sent_number_per_doc, batch].
  input_mask_doc_level.append(input_mask_cur_sent_max)
  if debugging:
    input_ids_cur_sent = tf.Print(
        input_ids_cur_sent, [input_ids_cur_sent, input_mask_cur_sent],
        message="input_ids_cur_sent in get_sent_reps_masks_lm_loss",
        summarize=20)
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids_cur_sent,
      input_mask=input_mask_cur_sent,
      use_one_hot_embeddings=use_one_hot_embeddings,
      sent_bert_trainable=sent_bert_trainable)
  with tf.variable_scope("seq_rep_from_bert_sent_dense", reuse=tf.AUTO_REUSE):
    normalized_siamese_input_tensor = get_seq_rep_from_bert(model)
  input_sent_reps_doc.append(normalized_siamese_input_tensor)

  if (train_mode == constants.TRAIN_MODE_PRETRAIN or
      train_mode == constants.TRAIN_MODE_JOINT_TRAIN):
    # Collect masked token information for the current sentence.
    sent_mask_lm_token_start = sent_index * max_predictions_per_seq
    # Output shape: [batch, max_predictions_per_seq].
    masked_lm_positions_cur_sent = tf.slice(masked_lm_positions,
                                            [0, sent_mask_lm_token_start],
                                            [-1, max_predictions_per_seq])
    masked_lm_ids_cur_sent = tf.slice(masked_lm_ids,
                                      [0, sent_mask_lm_token_start],
                                      [-1, max_predictions_per_seq])
    masked_lm_weights_cur_sent = tf.slice(masked_lm_weights,
                                          [0, sent_mask_lm_token_start],
                                          [-1, max_predictions_per_seq])
    # Since in the processed data of the SMITH model, the masked LM positions
    # are global indices starting from the first token of the whole sequence,
    # we need to transform each global position into a local position for the
    # current sentence. The local position index starts from 0:
    # local_index = global_index mod max_sent_length_by_word.
    masked_lm_positions_cur_sent = tf.mod(masked_lm_positions_cur_sent,
                                          max_sent_length_by_word)
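    # For example, with max_sent_length_by_word = 32, a global masked position
    # of 70 belongs to sentence index 2 and maps to local position 70 % 32 = 6.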
    # Shape of masked_lm_loss_cur_sent is [1].
    # Shape of masked_lm_example_loss_cur_sent is
    # [batch, max_predictions_per_seq].
    (masked_lm_loss_cur_sent, masked_lm_example_loss_cur_sent,
     _) = get_masked_lm_output(bert_config, model.get_sequence_output(),
                               model.get_embedding_table(),
                               masked_lm_positions_cur_sent,
                               masked_lm_ids_cur_sent,
                               masked_lm_weights_cur_sent)
    # Output Shape: [1].
    masked_lm_loss_doc += masked_lm_loss_cur_sent
    # Output Shape: [loop_sent_number_per_doc, batch * max_predictions_per_seq].
    masked_lm_example_loss_doc.append(masked_lm_example_loss_cur_sent)
    # Output Shape: [loop_sent_number_per_doc, batch, max_predictions_per_seq].
    masked_lm_weights_doc.append(masked_lm_weights_cur_sent)
  return (input_sent_reps_doc, input_mask_doc_level, masked_lm_loss_doc,
          masked_lm_example_loss_doc, masked_lm_weights_doc)
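
A minimal driver sketch, not from the original module: calling the helper in a Python loop over the sentences of a document and threading the accumulators through each call. loop_sent_number_per_doc and the input tensors are assumed to be defined by the caller.

# Hypothetical driver loop over the sentences of one document.
input_sent_reps_doc = []
input_mask_doc_level = []
masked_lm_loss_doc = tf.constant(0.0)
masked_lm_example_loss_doc = []
masked_lm_weights_doc = []
for sent_index in range(loop_sent_number_per_doc):
  (input_sent_reps_doc, input_mask_doc_level, masked_lm_loss_doc,
   masked_lm_example_loss_doc,
   masked_lm_weights_doc) = get_sent_reps_masks_normal_loop(
       sent_index, input_sent_reps_doc, input_mask_doc_level,
       masked_lm_loss_doc, masked_lm_example_loss_doc, masked_lm_weights_doc,
       dual_encoder_config, is_training, train_mode, input_ids, input_mask,
       masked_lm_positions, masked_lm_ids, masked_lm_weights,
       use_one_hot_embeddings)
# Stack the per-sentence lists into document level tensors, e.g.
# tf.stack(input_sent_reps_doc, axis=1) -> [batch, num_sents, rep_dim].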