Ejemplo n.º 1
0
def init_model(model_path):
    """Build the classifier and load its trained weights for inference.

    The network is constructed from the module-level ``config``. The
    checkpoint at ``model_path`` is deserialized onto the CPU first, then
    the model is put in eval mode and moved to ``config.device``.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            content is a state dict for the classifier.

    Returns:
        The model, in eval mode, on ``config.device``.
    """
    classifier = BERTForMultiLabelSequenceClassification(config)

    # Always deserialize onto CPU so the checkpoint loads regardless of the
    # device it was saved from; the model is moved to its target afterwards.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    classifier.load_state_dict(weights)

    classifier.eval()
    classifier.to(config.device)
    return classifier
Ejemplo n.º 2
0
def convert(model_path, to_path='bert_trt.pth', max_seq_length=32):
    """Convert a trained classifier with ``torch2trt`` and save the result.

    The model is built from ``Config('data')``, loaded from ``model_path``,
    moved to CUDA in eval mode and handed to ``torch2trt`` together with
    three placeholder inputs. The converted module's state dict is written
    to ``to_path``.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            content is a state dict for the classifier.
        to_path: Where the converted module's state dict is saved.
            Defaults to ``'bert_trt.pth'``, the previously hard-coded path.
        max_seq_length: Sequence length of the placeholder inputs.
            Defaults to 32, the previously hard-coded length.

    Note:
        Requires a CUDA device: inputs and model are moved with ``.cuda()``.
    """
    config = Config('data')

    # Three identical all-ones (1, max_seq_length) int32 tensors stand in for
    # the model's three positional inputs. Because they are identical, the
    # order in which they are passed to torch2trt does not matter here.
    dummy_inputs = [
        torch.ones((1, max_seq_length), dtype=torch.int32).cuda()
        for _ in range(3)
    ]

    model = BERTForMultiLabelSequenceClassification(config, config.num_classes)
    # Deserialize onto CPU first so the checkpoint loads regardless of the
    # device it was saved from.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval().cuda()

    model_trt = torch2trt(model, dummy_inputs)

    torch.save(model_trt.state_dict(), to_path)
    print('转化成功...')
Ejemplo n.º 3
0
def convert(model_path, to_onnx_path, max_seq_length=32, opset_version=11):
    """Export a trained classifier to ONNX with dynamic batch/sequence axes.

    The model is built from ``Config('.')``, loaded from ``model_path`` and
    exported with ``torch.onnx.export`` using all-ones placeholder inputs.
    CUDA device 0 is used when available, otherwise the CPU.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            content is a state dict for the classifier.
        to_onnx_path: Destination passed to ``torch.onnx.export`` as ``f``.
        max_seq_length: Sequence length of the placeholder inputs.
            Defaults to 32, the previously hard-coded length.
        opset_version: ONNX opset to export with. Defaults to 11, the
            previously hard-coded version.
    """
    config = Config('.')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # Single source of truth for the graph's input names. The placeholder
    # tensors are generated from this list so the names and the positional
    # arguments can never drift apart.
    input_names = ['input_ids', 'segment_ids', 'input_mask']
    dummy_inputs = tuple(
        torch.ones([1, max_seq_length], dtype=torch.long).to(device)
        for _ in input_names)

    model = BERTForMultiLabelSequenceClassification(config, config.num_classes)
    # Deserialize onto CPU first so the checkpoint loads regardless of the
    # device it was saved from.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    model.to(device)

    # Axis 0 and axis 1 of every input are exported as dynamic dimensions;
    # the output gets its own axis names.
    sequence_axes = {0: 'batch_size', 1: 'max_seq_len'}
    logits_axes = {0: 'batch_size', 1: 'num_class'}
    dynamic_axes = {name: sequence_axes for name in input_names}
    dynamic_axes['logits'] = logits_axes

    with torch.no_grad():
        torch.onnx.export(
            model,
            args=dummy_inputs,
            f=to_onnx_path,
            verbose=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=input_names,
            output_names=['logits'],
            dynamic_axes=dynamic_axes)

    print('model exported at', to_onnx_path)
Ejemplo n.º 4
0
class Predict:
    """Score text pairs with a trained BERT classifier.

    Wraps model loading, tokenization and single-example inference, plus
    helpers that run the model over tab-separated data files.
    """

    def __init__(self,
                 config,
                 model_path,
                 label_path,
                 bert_path='chinese-bert-wwm',
                 max_seq_length=32):
        """Load the model, tokenizer and label vocabulary.

        Args:
            config: Project configuration object; ``num_classes`` and
                ``device`` are read from it.
            model_path: Path to a checkpoint readable by ``torch.load``
                whose content is a state dict for the classifier.
            label_path: Passed to ``TextProcessor.get_labels`` to obtain
                the label list.
            bert_path: Name or path given to
                ``BertTokenizer.from_pretrained``.
            max_seq_length: Maximum sequence length used when converting
                an example to features.
        """
        self.config = config
        self.model_path = model_path
        self.label_path = label_path
        self.bert_path = bert_path

        self.model = BERTForMultiLabelSequenceClassification(
            self.config, self.config.num_classes)
        # Deserialize onto CPU first so the checkpoint loads regardless of
        # the device it was saved from.
        self.model.load_state_dict(
            torch.load(self.model_path, map_location=torch.device('cpu')))
        # NOTE(review): weights are cast to half precision before moving to
        # config.device -- confirm the target device supports fp16 inference.
        self.model.half()
        self.model.eval()
        self.model.to(self.config.device)

        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.max_seq_length = max_seq_length

        self.processor = TextProcessor()
        self.labels = self.processor.get_labels(self.label_path)
        self.label2id = {label: id_ for id_, label in enumerate(self.labels)}
        self.id2label = {id_: label for id_, label in enumerate(self.labels)}

    def _to_batch(self, values):
        """Return ``values`` as a long tensor with a leading batch axis on
        ``config.device``."""
        return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(
            self.config.device)

    def _iter_fields(self, data_path):
        """Yield the tab-separated fields of each non-blank line.

        Lines are stripped of surrounding whitespace before splitting.
        Blank lines (for example an extra empty line at the end of the
        file) carry no example and are skipped instead of failing later
        during unpacking. A progress message is printed per yielded record.
        """
        cnt = 0
        with open(data_path, 'r', encoding='utf-8') as reader:
            for record in reader:
                record = record.strip()
                if not record:
                    continue
                print(f'第{cnt+1}条记录...')
                cnt += 1
                yield record.split('\t')

    def run(self, record):
        """Predict the sub-category label score for one text pair.

        Args:
            record: Sequence whose first two items are ``text_a`` and
                ``text_b``.

        Returns:
            The sigmoid of the logit in column 1 for the single example.
        """
        text_a, text_b = record[0], record[1]
        example = self.processor._create_single_example(text_a, text_b)
        feature = convert_single_example(example, self.max_seq_length,
                                         self.tokenizer)

        input_ids = self._to_batch(feature.input_ids)
        segment_ids = self._to_batch(feature.segment_ids)
        input_mask = self._to_batch(feature.input_mask)

        # Inference only: disabling autograd avoids building a graph that
        # would be discarded immediately.
        with torch.no_grad():
            logits = self.model(input_ids, segment_ids, input_mask)

        prob = logits.sigmoid()[:, 1].tolist()  # one value, batch size is 1
        return prob[0]

    def collect_badcase(self, data_path):
        """Collect the examples whose thresholded prediction is wrong.

        Args:
            data_path: UTF-8 file with one ``text_a<TAB>text_b<TAB>label``
                record per line; blank lines are ignored.

        Returns:
            A list of tab-joined strings
            ``text_a, text_b, label, predicted_label, score`` for every
            record where the predicted label (``'1'`` when the score is
            above 0.5, otherwise ``'0'``) differs from ``label``.

        Raises:
            ValueError: If a non-blank line does not have three fields.
        """
        badcase = []
        for text_a, text_b, label in self._iter_fields(data_path):
            pre = self.run([text_a, text_b])
            pre_label = '1' if pre > 0.5 else '0'
            if pre_label != label:
                badcase.append('\t'.join(
                    [text_a, text_b, label, pre_label,
                     str(pre)]))

        return badcase

    def evaluate(self, data_path):
        """Test the model on the whole data set.

        Args:
            data_path: UTF-8 file with one ``text_a<TAB>text_b<TAB>label``
                record per line, where ``label`` parses as an int; blank
                lines are ignored.

        Returns:
            ``(auc_score, pres)``: the area under the ROC curve computed
            with ``roc_curve`` / ``auc`` (positive label 1), and the list
            of per-record scores in file order.

        Raises:
            ValueError: If a non-blank line does not have three fields or
                its label is not an integer.
        """
        labels = []
        pres = []
        for text_a, text_b, label in self._iter_fields(data_path):
            pre = self.run([text_a, text_b])
            labels.append(int(label))
            pres.append(pre)

        fpr, tpr, _ = roc_curve(labels, pres, pos_label=1)
        auc_score = auc(fpr, tpr)
        return auc_score, pres

    def inference(self, data_path, to_path):
        """Score every record of a file and write one score per line.

        All records are scored before ``to_path`` is opened for writing.

        Args:
            data_path: UTF-8 file with one ``text_a<TAB>text_b`` record per
                line; blank lines are ignored and produce no output line.
            to_path: Output file; receives ``str(score)`` per record.

        Raises:
            ValueError: If a non-blank line does not have two fields.
        """
        pres = [
            self.run([text_a, text_b])
            for text_a, text_b in self._iter_fields(data_path)
        ]

        with open(to_path, 'w', encoding='utf-8') as writer:
            for pre in pres:
                writer.write(str(pre) + '\n')