args = parser.parse_args()
args.n_gpu = torch.cuda.device_count()
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
device = torch.device(
    "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.device = device

# Load the tokenizer, configuration, and fine-tuned QA weights, then switch
# the model to evaluation mode.
tokenizer = BertTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
    do_lower_case=False)
config = BertConfig.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
model = BertForQuestionAnswering(config)
model_state_dict = args.state_dict
model.load_state_dict(torch.load(model_state_dict))
model.to(args.device)
model.eval()

input_file = args.predict_file


def handle_file(input_file, context, question):
    # Overwrite the context and the questions of the first paragraph in a
    # SQuAD-style JSON file so the same file can be reused for prediction.
    with open(input_file, "r") as reader:
        orig_data = json.load(reader)
    orig_data["data"][0]['paragraphs'][0]['context'] = context
    for i in range(len(question)):
        orig_data["data"][0]['paragraphs'][0]['qas'][i]['question'] = question[i]
    with open(input_file, "w") as writer:
        writer.write(json.dumps(orig_data, indent=4) + "\n")
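# Usage sketch for handle_file (the file path and texts below are illustrative
# only, not part of the original script). It rewrites the first paragraph of a
# SQuAD-style prediction file in place, so the number of questions passed in
# should not exceed the number of 'qas' entries already in that paragraph:
#
#   handle_file("dev.json",
#               context="Some passage to answer questions about.",
#               question=["First question?", "Second question?"])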
class QA_BERT_SingleSpan():

    def __init__(self, model_state_dict) -> None:
        # CPU-only inference: force no_cuda so the device is always "cpu".
        no_cuda = True
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese',
                                                       do_lower_case=False)
        config = BertConfig.from_pretrained('bert-base-chinese')
        self.model = BertForQuestionAnswering(config)
        self.model.load_state_dict(
            torch.load(model_state_dict, map_location='cpu'))
        self.model.to(self.device)
        self.model.eval()

    # TODO
    def predict_old(self, query, search_size, max_query_length,
                    max_answer_length, max_seq_length, n_best_size,
                    doc_stride, verbose_logging, es_index,
                    null_score_diff_threshold, prefix=""):
        dataset, examples, features, recall_scores = build_dataset_example_feature(
            query, self.tokenizer, max_query_length, max_seq_length,
            doc_stride, search_size, es_index)
        eval_dataloader = DataLoader(dataset, sampler=None, batch_size=16)
        all_results = []
        for batch in eval_dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2]
                }
                example_indices = batch[3]
                outputs = self.model(**inputs)
            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                result = RawResult(unique_id=unique_id,
                                   start_logits=to_list(outputs[0][i]),
                                   end_logits=to_list(outputs[1][i]))
                all_results.append(result)
        all_predictions = write_predictions(examples, features, all_results,
                                            n_best_size, max_answer_length,
                                            True, None, None, None,
                                            verbose_logging, False,
                                            null_score_diff_threshold)
        return all_predictions, recall_scores

    def predict(self, doc=None, query=None):
        """
        Args:
            doc: input document
            query: input query utterance

        Returns:
            The predictions produced by write_predictions.
        """
        if doc is not None:
            pass  # TODO: insert the document
        else:
            dataset, examples, features, _ = build_dataset_example_feature_by_context_query(
                query, self.tokenizer, 16, 384, 128)
            eval_dataloader = DataLoader(dataset, sampler=None, batch_size=16)
            all_results = []
            for batch in eval_dataloader:
                self.model.eval()
                batch = tuple(t.to(self.device) for t in batch)
                with torch.no_grad():
                    inputs = {
                        'input_ids': batch[0],
                        'attention_mask': batch[1],
                        'token_type_ids': batch[2]
                    }
                    example_indices = batch[3]
                    outputs = self.model(**inputs)
                for i, example_index in enumerate(example_indices):
                    eval_feature = features[example_index.item()]
                    unique_id = int(eval_feature.unique_id)
                    result = RawResult(unique_id=unique_id,
                                       start_logits=to_list(outputs[0][i]),
                                       end_logits=to_list(outputs[1][i]))
                    all_results.append(result)
            return write_predictions(examples, features, all_results, 3, 24,
                                     True, None, None, None, False, False,
                                     0.0)
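# Usage sketch for QA_BERT_SingleSpan (the checkpoint filename is illustrative,
# and the exact shape expected for `query` is determined by
# build_dataset_example_feature_by_context_query, defined elsewhere in this
# project):
#
#   qa = QA_BERT_SingleSpan("finetuned_bert_qa.bin")
#   predictions = qa.predict(query=query)  # n-best answers from write_predictions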