def create_model(self):
  input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
                                       self.vocab_size)

  input_mask = None
  if self.use_input_mask:
    input_mask = BertModelTest.ids_tensor(
        [self.batch_size, self.seq_length], vocab_size=2)

  token_type_ids = None
  if self.use_token_type_ids:
    token_type_ids = BertModelTest.ids_tensor(
        [self.batch_size, self.seq_length], self.type_vocab_size)

  config = modeling.BertConfig(
      vocab_size=self.vocab_size,
      hidden_size=self.hidden_size,
      num_hidden_layers=self.num_hidden_layers,
      num_attention_heads=self.num_attention_heads,
      intermediate_size=self.intermediate_size,
      hidden_act=self.hidden_act,
      hidden_dropout_prob=self.hidden_dropout_prob,
      attention_probs_dropout_prob=self.attention_probs_dropout_prob,
      max_position_embeddings=self.max_position_embeddings,
      type_vocab_size=self.type_vocab_size,
      initializer_range=self.initializer_range)

  model = modeling.BertModel(
      config=config,
      is_training=self.is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=token_type_ids,
      scope=self.scope)

  outputs = {
      "embedding_output": model.get_embedding_output(),
      "sequence_output": model.get_sequence_output(),
      "pooled_output": model.get_pooled_output(),
      "all_encoder_layers": model.get_all_encoder_layers(),
  }
  return outputs
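
# Hypothetical usage sketch (not part of the original test code): given the
# outputs dict returned by create_model(), a TF1-style test could check the
# static shapes of the BertModel outputs. The helper name
# `_check_bert_output_shapes` and its batch_size/seq_length/hidden_size
# arguments are illustrative assumptions, not part of the BERT test suite.
def _check_bert_output_shapes(outputs, batch_size, seq_length, hidden_size):
  """Asserts the expected static shapes of the BertModel outputs."""
  # Embedding and sequence outputs: [batch_size, seq_length, hidden_size].
  assert outputs["embedding_output"].shape.as_list() == [
      batch_size, seq_length, hidden_size]
  assert outputs["sequence_output"].shape.as_list() == [
      batch_size, seq_length, hidden_size]
  # Pooled output: [batch_size, hidden_size].
  assert outputs["pooled_output"].shape.as_list() == [batch_size, hidden_size]
  # Each encoder layer output has the same shape as the sequence output.
  for layer_output in outputs["all_encoder_layers"]:
    assert layer_output.shape.as_list() == [
        batch_size, seq_length, hidden_size]
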
def get_sent_reps_masks_normal_loop(sent_index,
                                    input_sent_reps_doc,
                                    input_mask_doc_level,
                                    masked_lm_loss_doc,
                                    masked_lm_example_loss_doc,
                                    masked_lm_weights_doc,
                                    dual_encoder_config,
                                    is_training,
                                    train_mode,
                                    input_ids,
                                    input_mask,
                                    masked_lm_positions,
                                    masked_lm_ids,
                                    masked_lm_weights,
                                    use_one_hot_embeddings,
                                    debugging=False):
  """Gets the sentence encodings, mask ids and masked word LM loss.

  Args:
    sent_index: The index of the currently looped sentence.
    input_sent_reps_doc: The representations of all sentences in the doc
      learned by BERT.
    input_mask_doc_level: The document level input masks, which indicate
      whether a sentence is a real sentence or a padded sentence.
    masked_lm_loss_doc: The sum of all the masked word LM losses.
    masked_lm_example_loss_doc: The per example masked word LM loss.
    masked_lm_weights_doc: The weights of the masked LM words. If the position
      corresponds to a real masked word, the weight is 1.0; if it is a padded
      mask position, the weight is 0.0.
    dual_encoder_config: The config of the dual encoder.
    is_training: Whether it is in training mode.
    train_mode: string. The train mode, which can be finetune, joint_train, or
      pretrain.
    input_ids: The ids of the input tokens.
    input_mask: The mask of the input tokens.
    masked_lm_positions: The positions of the masked words in language model
      training.
    masked_lm_ids: The ids of the masked words in LM model training.
    masked_lm_weights: The weights of the masked words in LM model training.
    use_one_hot_embeddings: Whether to use one-hot embeddings. It should be
      True for runs on TPUs.
    debugging: bool. Whether it is in debugging mode.

  Returns:
    A tuple of tensors with the learned sentence representations, the document
    level input masks and the masked word LM losses and weights.
  """
  # Collect token information for the current sentence.
  bert_config = modeling.BertConfig.from_json_file(
      dual_encoder_config.encoder_config.bert_config_file)
  max_sent_length_by_word = dual_encoder_config.encoder_config.max_sent_length_by_word
  sent_bert_trainable = dual_encoder_config.encoder_config.sent_bert_trainable
  max_predictions_per_seq = dual_encoder_config.encoder_config.max_predictions_per_seq
  sent_start = sent_index * max_sent_length_by_word
  input_ids_cur_sent = tf.slice(input_ids, [0, sent_start],
                                [-1, max_sent_length_by_word])
  # Output shape: [batch, max_sent_length_by_word].
  input_mask_cur_sent = tf.slice(input_mask, [0, sent_start],
                                 [-1, max_sent_length_by_word])
  # Output shape: [batch].
  input_mask_cur_sent_max = tf.reduce_max(input_mask_cur_sent, 1)
  # Output shape: [loop_sent_number_per_doc, batch].
  input_mask_doc_level.append(input_mask_cur_sent_max)
  if debugging:
    input_ids_cur_sent = tf.Print(
        input_ids_cur_sent, [input_ids_cur_sent, input_mask_cur_sent],
        message="input_ids_cur_sent in get_sent_reps_masks_lm_loss",
        summarize=20)
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids_cur_sent,
      input_mask=input_mask_cur_sent,
      use_one_hot_embeddings=use_one_hot_embeddings,
      sent_bert_trainable=sent_bert_trainable)
  with tf.variable_scope("seq_rep_from_bert_sent_dense", reuse=tf.AUTO_REUSE):
    normalized_siamese_input_tensor = get_seq_rep_from_bert(model)
  input_sent_reps_doc.append(normalized_siamese_input_tensor)

  if (train_mode == constants.TRAIN_MODE_PRETRAIN or
      train_mode == constants.TRAIN_MODE_JOINT_TRAIN):
    # Collect masked token information for the current sentence.
    sent_mask_lm_token_start = sent_index * max_predictions_per_seq
    # Output shape: [batch, max_predictions_per_seq].
    masked_lm_positions_cur_sent = tf.slice(masked_lm_positions,
                                            [0, sent_mask_lm_token_start],
                                            [-1, max_predictions_per_seq])
    masked_lm_ids_cur_sent = tf.slice(masked_lm_ids,
                                      [0, sent_mask_lm_token_start],
                                      [-1, max_predictions_per_seq])
    masked_lm_weights_cur_sent = tf.slice(masked_lm_weights,
                                          [0, sent_mask_lm_token_start],
                                          [-1, max_predictions_per_seq])
    # In the processed data of the SMITH model, the masked LM positions are
    # global indices starting from the 1st token of the whole sequence, so we
    # need to transform each global position into a local position for the
    # current sentence. The local position index starts from 0:
    # local_index = global_index mod max_sent_length_by_word.
    masked_lm_positions_cur_sent = tf.mod(masked_lm_positions_cur_sent,
                                          max_sent_length_by_word)
    # Shape of masked_lm_loss_cur_sent is [1].
    # Shape of masked_lm_example_loss_cur_sent is
    # [batch, max_predictions_per_seq].
    (masked_lm_loss_cur_sent, masked_lm_example_loss_cur_sent,
     _) = get_masked_lm_output(bert_config, model.get_sequence_output(),
                               model.get_embedding_table(),
                               masked_lm_positions_cur_sent,
                               masked_lm_ids_cur_sent,
                               masked_lm_weights_cur_sent)
    # Output shape: [1].
    masked_lm_loss_doc += masked_lm_loss_cur_sent
    # Output shape: [loop_sent_number_per_doc, batch * max_predictions_per_seq].
    masked_lm_example_loss_doc.append(masked_lm_example_loss_cur_sent)
    # Output shape: [loop_sent_number_per_doc, batch, max_predictions_per_seq].
    masked_lm_weights_doc.append(masked_lm_weights_cur_sent)
  return (input_sent_reps_doc, input_mask_doc_level, masked_lm_loss_doc,
          masked_lm_example_loss_doc, masked_lm_weights_doc)
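
# Hypothetical caller sketch (an assumption about usage, not the original SMITH
# code): get_sent_reps_masks_normal_loop is written to be applied once per
# sentence, threading the accumulator lists and the scalar LM loss through each
# call. A plain Python loop over `loop_sent_number_per_doc` sentence positions
# could look like the following; the helper name `_encode_all_sentences` and
# the stacking at the end are illustrative assumptions.
def _encode_all_sentences(loop_sent_number_per_doc, dual_encoder_config,
                          is_training, train_mode, input_ids, input_mask,
                          masked_lm_positions, masked_lm_ids,
                          masked_lm_weights, use_one_hot_embeddings):
  """Runs the per-sentence encoder over every sentence position in the doc."""
  input_sent_reps_doc = []
  input_mask_doc_level = []
  masked_lm_loss_doc = 0.0
  masked_lm_example_loss_doc = []
  masked_lm_weights_doc = []
  for sent_index in range(loop_sent_number_per_doc):
    (input_sent_reps_doc, input_mask_doc_level, masked_lm_loss_doc,
     masked_lm_example_loss_doc,
     masked_lm_weights_doc) = get_sent_reps_masks_normal_loop(
         sent_index, input_sent_reps_doc, input_mask_doc_level,
         masked_lm_loss_doc, masked_lm_example_loss_doc, masked_lm_weights_doc,
         dual_encoder_config, is_training, train_mode, input_ids, input_mask,
         masked_lm_positions, masked_lm_ids, masked_lm_weights,
         use_one_hot_embeddings)
  # Stack the per-sentence results into doc-level tensors:
  # [batch, loop_sent_number_per_doc, rep_dim] for the sentence reps and
  # [batch, loop_sent_number_per_doc] for the document level masks.
  input_sent_reps_doc = tf.stack(input_sent_reps_doc, axis=1)
  input_mask_doc_level = tf.stack(input_mask_doc_level, axis=1)
  return (input_sent_reps_doc, input_mask_doc_level, masked_lm_loss_doc,
          masked_lm_example_loss_doc, masked_lm_weights_doc)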