Example #1
0
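
Both versions below assume the same imports; a hedged reconstruction (the source of load_states_from_checkpoint is an assumption — DPR conversion scripts typically define it inline or import it from the original facebookresearch/DPR code):

# Assumed imports for the snippets below (not part of the original listing).
from transformers import BertConfig, DPRConfig, DPRReader
from transformers.testing_utils import torch_device  # "cuda" if available, else "cpu"
# load_states_from_checkpoint is not a transformers API; the conversion
# script defines it or imports it from the original DPR repository:
from dpr.utils.model_utils import load_states_from_checkpoint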
def create_and_check_dpr_reader(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    # Build the reader, move it to the test device, and disable dropout.
    model = DPRReader(config=config)
    model.to(torch_device)
    model.eval()
    result = model(input_ids, attention_mask=input_mask)
    # The span heads emit one start/end logit per token; relevance is one
    # logit per passage in the batch.
    self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
    self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
    self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
def load_dpr_model(self):
    model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    state_dict = {}
    for key, value in saved_state.model_dict.items():
        # Original DPR checkpoints store the BERT weights under "encoder.*",
        # while the HF DPRReader expects "encoder.bert_model.*"; the
        # projection head ("encoder.encode_proj.*") keeps its original name.
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model
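
A hedged usage sketch for the conversion method above (the holder class and every path are illustrative assumptions; only load_dpr_model itself comes from the listing):

class DPRReaderState:  # hypothetical container for the checkpoint path
    def __init__(self, src_file):
        self.src_file = src_file

    load_dpr_model = load_dpr_model  # reuse the function above as a method

model = DPRReaderState("checkpoints/dpr_reader.cp").load_dpr_model()
model.save_pretrained("converted/dpr-reader")  # standard PreTrainedModel API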
def create_and_check_reader(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    model = DPRReader(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids,
        attention_mask=input_mask,
    )

    # Same shape checks as above, written with attribute access on the
    # returned ModelOutput and direct tuple comparison.
    self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
    self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
    self.parent.assertEqual(result.relevance_logits.shape, (self.batch_size,))
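
Design note: the updated test reads the logits as attributes of the returned ModelOutput and compares .shape tuples directly instead of converting .size() to lists; the checks are equivalent, this is simply the current transformers test idiom.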
def load_dpr_model(self):
    model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
    # That commit registers position_ids as a buffer, so it is now part of
    # the model's state_dict; seed it from the freshly initialized model,
    # since checkpoints that predate the buffer do not contain it and a
    # strict load would otherwise fail with a missing key.
    state_dict = {
        "encoder.bert_model.embeddings.position_ids": (
            model.span_predictor.encoder.bert_model.embeddings.position_ids
        )
    }
    for key, value in saved_state.model_dict.items():
        # Same "encoder.*" -> "encoder.bert_model.*" remapping as above,
        # again leaving the projection head untouched.
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model
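
A version-agnostic variant (a sketch, not from the source) avoids hard-coding the buffer name by seeding every key the fresh model expects but the checkpoint lacks:

# Hedged alternative to the hard-coded position_ids entry above: copy any
# tensor the freshly initialized model expects that the checkpoint lacks.
reference = model.span_predictor.state_dict()
for key in set(reference) - set(state_dict):
    state_dict[key] = reference[key]
model.span_predictor.load_state_dict(state_dict)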