# -*- coding: utf-8 -*-
import torch
from sklearn.metrics import roc_curve, auc
from transformers import BertTokenizer

# Project-local imports; adjust the module paths to this repository's layout.
from config import Config
from model import BERTForMultiLabelSequenceClassification
from data import TextProcessor, convert_single_example


def init_model(config, model_path):
    '''Load a trained checkpoint and return the model in eval mode.'''
    # `config` is passed in explicitly, and the constructor receives
    # `config.num_classes` to match the other call sites in this file.
    model = BERTForMultiLabelSequenceClassification(config, config.num_classes)
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    model.to(config.device)
    return model
def convert(model_path, to_onnx_path):
    '''Export the trained PyTorch model to ONNX with dynamic batch/sequence axes.'''
    config = Config('.')
    opset_version = 11
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_gpu else 'cpu')
    print(device)

    # Dummy inputs fix the input order and dtypes for tracing; the actual
    # sizes are made dynamic via `dynamic_axes` below.
    inputs = {
        'input_ids': torch.ones([1, 32], dtype=torch.long).to(device),
        'token_type_ids': torch.ones([1, 32], dtype=torch.long).to(device),
        'attention_mask': torch.ones([1, 32], dtype=torch.long).to(device)
    }

    model = BERTForMultiLabelSequenceClassification(config, config.num_classes)
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    model.to(device)

    with torch.no_grad():
        symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
        logits_name = {0: 'batch_size', 1: 'num_class'}
        torch.onnx.export(
            model,
            args=tuple(inputs.values()),
            f=to_onnx_path,
            verbose=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['input_ids', 'segment_ids', 'input_mask'],
            output_names=['logits'],
            dynamic_axes={
                'input_ids': symbolic_names,
                'segment_ids': symbolic_names,
                'input_mask': symbolic_names,
                'logits': logits_name
            })
    print('model exported at', to_onnx_path)
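# A minimal sketch of sanity-checking the exported graph with onnxruntime
# (assumes the `onnxruntime` package is installed; `check_onnx_export` is an
# illustrative helper, not part of the original code). It feeds the same dummy
# shapes used for the export and compares the ONNX logits with PyTorch's.
def check_onnx_export(onnx_path, model, device):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path)
    dummy = np.ones([1, 32], dtype=np.int64)
    ort_logits = session.run(['logits'], {
        'input_ids': dummy,
        'segment_ids': dummy,
        'input_mask': dummy
    })[0]

    with torch.no_grad():
        ones = torch.ones([1, 32], dtype=torch.long).to(device)
        torch_logits = model(ones, ones, ones).cpu().numpy()

    # Loose tolerance: constant folding and opset lowering can introduce
    # small numeric differences.
    assert np.allclose(ort_logits, torch_logits, atol=1e-4)
    print('ONNX output matches PyTorch output')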
class Predict:

    def __init__(self,
                 config,
                 model_path,
                 label_path,
                 bert_path='chinese-bert-wwm',
                 max_seq_length=32):
        self.config = config
        self.model_path = model_path
        self.label_path = label_path
        self.bert_path = bert_path
        self.model = BERTForMultiLabelSequenceClassification(
            self.config, self.config.num_classes)
        self.model.load_state_dict(
            torch.load(self.model_path, map_location=torch.device('cpu')))
        # Half precision halves memory and speeds up GPU inference; note that
        # fp16 inference is generally not supported on CPU.
        self.model.half()
        self.model.eval()
        self.model.to(self.config.device)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.max_seq_length = max_seq_length
        self.processor = TextProcessor()
        self.labels = self.processor.get_labels(self.label_path)
        self.label2id = {label: id_ for id_, label in enumerate(self.labels)}
        self.id2label = {id_: label for id_, label in enumerate(self.labels)}

    def run(self, record):
        '''Predict the fine-grained (sub-class) label for a text pair.'''
        text_a, text_b = record[0], record[1]
        example = self.processor._create_single_example(text_a, text_b)
        feature = convert_single_example(example, self.max_seq_length,
                                         self.tokenizer)
        input_ids = torch.tensor(feature.input_ids,
                                 dtype=torch.long).unsqueeze(0).to(
                                     self.config.device)
        segment_ids = torch.tensor(feature.segment_ids,
                                   dtype=torch.long).unsqueeze(0).to(
                                       self.config.device)
        input_mask = torch.tensor(feature.input_mask,
                                  dtype=torch.long).unsqueeze(0).to(
                                      self.config.device)
        logits = self.model(input_ids, segment_ids, input_mask).detach()
        # Sigmoid over the logits, then keep the probability of class 1,
        # e.g. [0.123] for a single example.
        prob = logits.sigmoid()[:, 1].tolist()
        return prob[0]

    def collect_badcase(self, data_path):
        '''Collect examples whose predicted label disagrees with the gold label.'''
        badcase = []
        cnt = 0
        with open(data_path, 'r', encoding='utf-8') as reader:
            for record in reader:
                print(f'record {cnt + 1}...')
                cnt += 1
                text_a, text_b, label = record.strip().split('\t')
                pre = self.run([text_a, text_b])
                pre_label = '1' if pre > 0.5 else '0'
                if pre_label != label:
                    badcase.append('\t'.join(
                        [text_a, text_b, label, pre_label, str(pre)]))
        return badcase

    def evaluate(self, data_path):
        '''Evaluate the model on the full dataset and return the AUC.'''
        labels = []
        pres = []
        cnt = 0
        with open(data_path, 'r', encoding='utf-8') as reader:
            for record in reader:
                print(f'record {cnt + 1}...')
                cnt += 1
                text_a, text_b, label = record.strip().split('\t')
                pre = self.run([text_a, text_b])
                labels.append(int(label))
                pres.append(pre)
        fpr, tpr, th = roc_curve(labels, pres, pos_label=1)
        auc_score = auc(fpr, tpr)
        return auc_score, pres

    def inference(self, data_path, to_path):
        '''Run prediction over an unlabeled file and write one score per line.'''
        pres = []
        cnt = 0
        with open(data_path, 'r', encoding='utf-8') as reader:
            for record in reader:
                print(f'record {cnt + 1}...')
                cnt += 1
                text_a, text_b = record.strip().split('\t')
                pre = self.run([text_a, text_b])
                pres.append(pre)
        with open(to_path, 'w', encoding='utf-8') as writer:
            for pre in pres:
                writer.write(str(pre) + '\n')
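# A minimal usage sketch. The file paths below are placeholders, and the TSV
# formats are inferred from how the methods above read their inputs:
# evaluate() expects "text_a\ttext_b\tlabel" lines, inference() expects
# "text_a\ttext_b" lines.
if __name__ == '__main__':
    config = Config('.')
    predictor = Predict(config,
                        model_path='checkpoints/model.bin',
                        label_path='data/labels.txt')

    # Score a single pair of Chinese sentences (the tokenizer is chinese-bert-wwm).
    prob = predictor.run(['今天天气怎么样', '今天天气如何'])
    print('positive probability:', prob)

    # Evaluate on a labeled test set.
    auc_score, _ = predictor.evaluate('data/test.tsv')
    print('AUC:', auc_score)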