Exemple #1
0
 def create_and_check_bert_for_question_answering(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     sequence_labels,
     token_labels,
     choice_labels,
 ):
     """Run ``BertForQuestionAnswering`` once and check its output shapes.

     The model is built from ``config``, switched to eval mode, and called
     with ``sequence_labels`` passed twice as the last two positional
     arguments. Both logit tensors must have size
     ``[self.batch_size, self.seq_length]``; the loss is then handed to
     ``self.check_loss_output``.

     ``token_labels`` and ``choice_labels`` are accepted but not used here.
     """
     qa_model = BertForQuestionAnswering(config=config)
     qa_model.eval()

     outputs = qa_model(input_ids, token_type_ids, input_mask,
                        sequence_labels, sequence_labels)
     # The model is expected to return exactly three values here.
     loss, start_logits, end_logits = outputs
     result = dict(loss=loss,
                   start_logits=start_logits,
                   end_logits=end_logits)

     # Start and end logits share the same expected size.
     for logits_key in ("start_logits", "end_logits"):
         self.parent.assertListEqual(
             list(result[logits_key].size()),
             [self.batch_size, self.seq_length],
         )
     self.check_loss_output(result)
# Inference-time setup: parse the command line, choose the device, build the
# tokenizer/config/model and restore the fine-tuned weights.
args = parser.parse_args()
args.n_gpu = torch.cuda.device_count()
# Scale the evaluation batch size by the number of GPUs (at least 1).
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
device = torch.device(
    "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.device = device

# Tokenizer and config are both resolved from ``tokenizer_name`` when given,
# otherwise from ``model_name_or_path``.
pretrained_name = (args.tokenizer_name
                   if args.tokenizer_name else args.model_name_or_path)
tokenizer = BertTokenizer.from_pretrained(pretrained_name,
                                          do_lower_case=False)
config = BertConfig.from_pretrained(pretrained_name)
model = BertForQuestionAnswering(config)

model_state_dict = args.state_dict
# map_location remaps tensors saved on another device (e.g. a checkpoint
# written on GPU) onto the device selected above; without it, loading such a
# checkpoint on a CPU-only run raises an error.
model.load_state_dict(torch.load(model_state_dict, map_location=args.device))
model.to(args.device)
model.eval()  # inference only: switch the module to evaluation mode
input_file = args.predict_file


def handle_file(input_file, context, question):
    """Rewrite the context and questions of the first paragraph in a JSON file.

    The file is read, ``data[0].paragraphs[0].context`` is replaced with
    ``context``, the ``question`` field of the i-th entry in that paragraph's
    ``qas`` list is replaced with ``question[i]``, and the result is written
    back to the same path (indented by 4 spaces, followed by a newline).

    Args:
        input_file: Path of the JSON file to update in place.
        context: New value for the paragraph's ``context`` field.
        question: Sequence of new question texts, applied by position.

    Raises:
        IndexError: If ``question`` has more items than the paragraph has
            ``qas`` entries. The file is left untouched in that case.
    """
    # Explicit UTF-8: the platform default encoding may not be able to decode
    # non-ASCII text already present in the file.
    with open(input_file, "r", encoding="utf-8") as reader:
        orig_data = json.load(reader)

    paragraph = orig_data["data"][0]['paragraphs'][0]
    paragraph['context'] = context
    for i, text in enumerate(question):
        paragraph['qas'][i]['question'] = text

    with open(input_file, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(orig_data, indent=4) + "\n")


def run():
Exemple #3
0
class QA_BERT_SingleSpan():
    """Question-answering wrapper around a ``BertForQuestionAnswering`` model.

    The model is built from the ``bert-base-chinese`` configuration, loaded
    with the weights found at ``model_state_dict`` and kept in evaluation
    mode. ``predict`` and ``predict_old`` build a dataset for a query, run the
    model over it in batches and pass the raw logits to
    ``write_predictions``.
    """

    # Batch size used for every inference DataLoader.
    _BATCH_SIZE = 16
    # Positional arguments passed to
    # build_dataset_example_feature_by_context_query by ``predict``.
    # NOTE(review): names assume the same argument order as
    # build_dataset_example_feature in ``predict_old`` -- confirm.
    _MAX_QUERY_LENGTH = 16
    _MAX_SEQ_LENGTH = 384
    _DOC_STRIDE = 128
    # Values ``predict`` passes to write_predictions in the positions where
    # ``predict_old`` passes n_best_size and max_answer_length.
    _N_BEST_SIZE = 3
    _MAX_ANSWER_LENGTH = 24

    def __init__(self, model_state_dict, no_cuda=True) -> None:
        """Load tokenizer, config and fine-tuned weights.

        Args:
            model_state_dict: Path (or file-like object) accepted by
                ``torch.load`` that holds the model's state dict.
            no_cuda: When true (the default, matching the previous
                hard-coded behavior) the model always runs on CPU; when
                false, CUDA is used if it is available.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese',
                                                       do_lower_case=False)
        config = BertConfig.from_pretrained('bert-base-chinese')
        self.model = BertForQuestionAnswering(config)
        # Load onto CPU first so a checkpoint saved on GPU can be restored on
        # any machine, then move the model to the selected device.
        self.model.load_state_dict(
            torch.load(model_state_dict, map_location='cpu'))
        self.model.to(self.device)
        self.model.eval()

    def _collect_raw_results(self, dataset, features):
        """Run the model over ``dataset`` and return one RawResult per row.

        Each batch is expected to hold, in order: input ids, attention mask,
        token type ids and the indices into ``features`` of its rows.
        """
        eval_dataloader = DataLoader(dataset,
                                     sampler=None,
                                     batch_size=self._BATCH_SIZE)
        all_results = []
        # Setting eval mode once is enough; it does not change per batch.
        self.model.eval()
        for batch in eval_dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                outputs = self.model(input_ids=batch[0],
                                     attention_mask=batch[1],
                                     token_type_ids=batch[2])
            start_logits, end_logits = outputs[0], outputs[1]
            for i, example_index in enumerate(batch[3]):
                eval_feature = features[example_index.item()]
                all_results.append(
                    RawResult(unique_id=int(eval_feature.unique_id),
                              start_logits=to_list(start_logits[i]),
                              end_logits=to_list(end_logits[i])))
        return all_results

    def predict_old(self,
                    query,
                    search_size,
                    max_query_length,
                    max_answer_length,
                    max_seq_length,
                    n_best_size,
                    doc_stride,
                    verbose_logging,
                    es_index,
                    null_score_diff_threshold,
                    prefix=""):
        """Answer ``query`` using ``build_dataset_example_feature``.

        Args:
            query: The question text.
            search_size, es_index: Forwarded to
                ``build_dataset_example_feature``.
            max_query_length, max_seq_length, doc_stride: Forwarded to
                ``build_dataset_example_feature``.
            n_best_size, max_answer_length, verbose_logging,
            null_score_diff_threshold: Forwarded to ``write_predictions``.
            prefix: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(all_predictions, recall_scores)`` where
            ``all_predictions`` is the value returned by
            ``write_predictions`` and ``recall_scores`` comes from
            ``build_dataset_example_feature``.
        """
        dataset, examples, features, recall_scores = build_dataset_example_feature(
            query, self.tokenizer, max_query_length, max_seq_length,
            doc_stride, search_size, es_index)
        all_results = self._collect_raw_results(dataset, features)
        all_predictions = write_predictions(examples, features, all_results,
                                            n_best_size, max_answer_length,
                                            True, None, None, None,
                                            verbose_logging, False,
                                            null_score_diff_threshold)
        return all_predictions, recall_scores

    def predict(self, doc=None, query=None):
        """Answer ``query`` with fixed inference settings.

        Args:
            doc: Input document. Inserting a document is not implemented
                yet, so this must be left as ``None``.
            query: Input query text.

        Returns:
            The value returned by ``write_predictions``.

        Raises:
            NotImplementedError: If ``doc`` is given. Previously this path
                fell through and failed later with an unbound ``dataset``.
        """
        if doc is not None:
            # TODO: insert the document before querying.
            raise NotImplementedError(
                "predict() does not support inserting a document yet; "
                "call it with doc=None")

        dataset, examples, features, _ = build_dataset_example_feature_by_context_query(
            query, self.tokenizer, self._MAX_QUERY_LENGTH,
            self._MAX_SEQ_LENGTH, self._DOC_STRIDE)
        all_results = self._collect_raw_results(dataset, features)
        # Same positional layout as the write_predictions call in
        # predict_old, with verbose logging off and a 0.0 threshold.
        return write_predictions(examples, features, all_results,
                                 self._N_BEST_SIZE, self._MAX_ANSWER_LENGTH,
                                 True, None, None, None, False, False, 0.0)