Beispiel #1
0
class Predict(object):
    """Single-text classifier around a fine-tuned BERT checkpoint.

    Builds ``Model`` from the shared ``config``, restores its weights from
    ``model_path`` and exposes :meth:`predict` for classifying one text.
    """

    def __init__(self,
                 model_path=config.root_path + '/model/saved_dict/bert.ckpt',
                 bert_path=config.root_path + '/model/bert-wwm/',
                 is_cuda=config.is_cuda):
        """Load the tokenizer, build the model and restore its weights.

        Args:
            model_path: Path of the saved ``state_dict`` checkpoint.
            bert_path: Directory ``BertTokenizer.from_pretrained`` reads from.
            is_cuda: Stored as ``self.is_cuda`` for backward compatibility;
                inputs are moved to whatever device the model lives on.

        Note:
            Mutates the shared ``config`` (``bert_path``, ``hidden_size``)
            before building ``Model``.
        """
        self.model_path = model_path
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.is_cuda = is_cuda
        # NOTE(review): the model backbone path is '/model/bert/' while the
        # tokenizer defaults to '/model/bert-wwm/'; presumably the two share
        # a vocabulary — confirm.
        config.bert_path = config.root_path + '/model/bert/'
        config.hidden_size = 768
        self.model = Model(config).to(config.device)
        # map_location allows restoring a checkpoint saved on another device
        # (e.g. a GPU-trained checkpoint on a CPU-only host).
        checkpoint = torch.load(self.model_path, map_location=config.device)
        # strict=False: missing/unexpected keys in the checkpoint are ignored.
        self.model.load_state_dict(checkpoint, strict=False)
        self.model.eval()

    def process_data(self, text, is_cuda=config.is_cuda):
        """Tokenize ``text`` into padded BERT input tensors.

        Args:
            text: The text to encode.
            is_cuda: Unused; kept for backward compatibility. Device
                placement happens in :meth:`predict`.

        Returns:
            Tuple ``(input_ids, token_type_ids, attention_mask)`` of tensors,
            each holding one row right-padded with 0 to at least
            ``config.max_length`` entries.
        """
        def padding(indice, max_length, pad_idx=0):
            """Right-pad each list in ``indice`` to ``max_length`` with ``pad_idx``.

            Lists already at least ``max_length`` long are left unchanged.
            All three inputs (ids, token type ids, attention mask) are padded
            with 0.
            """
            pad_indice = [
                item + [pad_idx] * max(0, max_length - len(item))
                for item in indice
            ]
            return torch.tensor(pad_indice)

        text_dict = self.tokenizer.encode_plus(
            text,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=config.max_length,  # Pad & truncate all sentences.
            pad_to_max_length=True,
            return_attention_mask=True,  # Construct attn. masks.
        )

        input_ids = text_dict['input_ids']
        attention_mask = text_dict['attention_mask']
        token_type_ids = text_dict['token_type_ids']

        # Manual padding is a safety net in case the tokenizer did not pad.
        token_ids_padded = padding([input_ids], config.max_length)
        token_type_ids_padded = padding([token_type_ids], config.max_length)
        attention_mask_padded = padding([attention_mask], config.max_length)
        return token_ids_padded, token_type_ids_padded, attention_mask_padded

    def predict(self, text):
        """Classify a single ``text``.

        Returns:
            ``(label, score)``: ``label`` is the index of the largest model
            output along dim 1 for the first (only) row, as a NumPy integer;
            ``score`` is that output's value as a Python float. Whether it is
            a logit or a probability depends on ``Model`` (not visible here).
        """
        (token_ids_padded, token_type_ids_padded,
         attention_mask_padded) = self.process_data(text)
        # Send inputs to the device holding the model's weights so that a
        # mismatch between ``is_cuda`` and ``config.device`` cannot occur.
        device = next(self.model.parameters()).device
        token_ids_padded = token_ids_padded.to(device)
        token_type_ids_padded = token_type_ids_padded.to(device)
        attention_mask_padded = attention_mask_padded.to(device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(
                (token_ids_padded, attention_mask_padded, token_type_ids_padded))
        best_idx = torch.max(outputs, 1)[1]
        label = best_idx.cpu().numpy()[0]
        score = outputs[0, best_idx[0]].item()
        return label, score
Beispiel #2
0
 def __init__(self,
              model_path=config.root_path + '/model/saved_dict/bert.ckpt',
              bert_path=config.root_path + '/model/bert-wwm/',
              is_cuda=config.is_cuda):
     """Load the tokenizer, build the model and restore its weights.

     Args:
         model_path: Path of the saved ``state_dict`` checkpoint.
         bert_path: Directory ``BertTokenizer.from_pretrained`` reads from.
         is_cuda: Stored as ``self.is_cuda``.

     Note:
         Mutates the shared ``config`` (``bert_path``, ``hidden_size``)
         before building ``Model``.
     """
     self.model_path = model_path
     self.tokenizer = BertTokenizer.from_pretrained(bert_path)
     self.is_cuda = is_cuda
     # NOTE(review): the model backbone path is '/model/bert/' while the
     # tokenizer defaults to '/model/bert-wwm/'; presumably the two share a
     # vocabulary — confirm.
     config.bert_path = config.root_path + '/model/bert/'
     config.hidden_size = 768
     self.model = Model(config).to(config.device)
     # map_location allows restoring a checkpoint saved on another device
     # (e.g. a GPU-trained checkpoint on a CPU-only host).
     checkpoint = torch.load(self.model_path, map_location=config.device)
     # strict=False: missing/unexpected keys in the checkpoint are ignored.
     self.model.load_state_dict(checkpoint, strict=False)
     self.model.eval()
Beispiel #3
0
        end_time = time.time()
        print('Elasped time: {}'.format(end_time - start_time))

    preds = np.argmax(preds, axis=1)
    eval_acc = (preds == out_label_ids).mean()

    return eval_acc


if __name__ == '__main__':
    # Set seed
    set_seed(config)

    config.bert_path = config.root_path + '/model/bert/'
    config.hidden_size = 768
    model = Model(config).to(config.device)

    checkpoint = torch.load(config.root_path + '/model/saved_dict/bert.ckpt')
    model.load_state_dict(checkpoint, strict=False)
    tokenizer = BertTokenizer.from_pretrained(
        config.root_path + '/model/bert', do_lower_case=config.do_lower_case)
    model.to(config.device)
    print('finish model load')
    if config.visualize > -1:
        start_pos = config.visualize
        end_pos = start_pos + 1
    else:
        start_pos = config.start_pos
        end_pos = config.end_pos
    print('load data')
    test_dataset = MyDataset(config.test_file,