def __init__(self):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print("device: " + str(self.device))

        self.bert_model = load_bert(word2idx, model_name=model_name)

        #load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)

        load_recent_model(self.bert_model, recent_model_path)

        self.bert_model.to(self.device)
        print(self.bert_model)

        self.optim_parameters = list(self.bert_model.parameters())
        self.optimizer = torch.optim.Adam(self.optim_parameters,
                                          lr=lr,
                                          weight_decay=1e-3)

        dataset = BertDataset()
        self.dataloader = DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     collate_fn=collate_fn)
 def __init__(self):
     # 加载数据
     data_dir = "./corpus/对联"
     self.vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # roberta模型字典的位置
     self.sents_src, self.sents_tgt = read_corpus(data_dir, self.vocab_path)
     self.model_name = "roberta"  # 选择模型名字
     self.model_path = "./state_dict/bert_duilian.model.epoch.0"  # roberta模型位置
     self.recent_model_path = ""  # 用于把已经训练好的模型继续训练
     self.model_save_path = "./bert_model.bin"
     self.batch_size = 16
     self.lr = 1e-5
     # 加载字典
     self.word2idx = load_chinese_base_vocab(self.vocab_path)
     # 判断是否有可用GPU
     self.device = torch.device(
         "cuda" if torch.cuda.is_available() else "cpu")
     print("device: " + str(self.device))
     # 定义模型
     self.bert_model = load_bert(self.vocab_path,
                                 model_name=self.model_name)
     ## 加载预训练的模型参数~
     load_recent_model(self.bert_model, self.model_path)
     # 将模型发送到计算设备(GPU或CPU)
     self.bert_model.to(self.device)
     # 声明需要优化的参数
     self.optim_parameters = list(self.bert_model.parameters())
     self.optimizer = torch.optim.Adam(self.optim_parameters,
                                       lr=self.lr,
                                       weight_decay=1e-3)
     # 声明自定义的数据加载器
     dataset = BertDataset(self.sents_src, self.sents_tgt, self.vocab_path)
     self.dataloader = DataLoader(dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)
Example #3
0
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params, load_recent_model

auto_title_model = "./state_dict/bert_auto_title_model2.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # roberta模型字典的位置
    model_name = "roberta"  # 选择模型名字
    # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin"  # roberta模型位
    # 加载字典
    word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
    # 定义模型
    bert_model = load_bert(word2idx, model_name=model_name)
    bert_model.to(device)
    bert_model.eval()
    ## 加载训练的模型参数~
    load_recent_model(bert_model,
                      recent_model_path=auto_title_model,
                      device=device)

    test_data = [
        "针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
        "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
        "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"
    ]

    for text in test_data:
        with torch.no_grad():
            print(bert_model.generate(text, beam_size=3, device=device))
                    else:
                        res[each_entity] = [cur_text]
                    flag = each_entity
                elif flag == each_entity:
                    res[each_entity][-1] += text[index - 1]
            else:
                flag = 0
        print(res)


if __name__ == "__main__":
    vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # roberta模型字典的位置
    model_name = "roberta"  # 选择模型名字
    # 加载字典
    word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
    tokenizer = Tokenizer(word2idx)
    # 定义模型
    bert_model = load_bert(word2idx,
                           model_name=model_name,
                           model_class="sequence_labeling_crf",
                           target_size=len(target))
    bert_model.to(device)
    bert_model.eval()
    ## 加载训练的模型参数~
    load_recent_model(bert_model, recent_model_path=model_path, device=device)
    test_data = [
        "日寇在京掠夺文物详情。", "以书结缘,把欧美,港台流行的食品类食谱汇集一堂。",
        "明天天津下雨,不知道杨永康主任还能不能来学校吃个饭。", "美国的华莱士,我和他谈笑风生", "看包公断案的戏"
    ]
    ner_print(bert_model, test_data, device=device)
    return objects


if __name__ == "__main__":

    # 定义模型
    bert_model = load_bert(word2idx,
                           model_class="relation_extrac",
                           model_name=model_name,
                           target_size=len(predicate2id))
    bert_model.eval()
    #   ## 加载预训练的模型参数~
    checkpoint = torch.load(relation_extrac_model, map_location="cpu")
    # print(checkpoint)
    load_recent_model(bert_model,
                      recent_model_path=relation_extrac_model,
                      device=device)
    text = [
        "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
        "李治即位后,萧淑妃受宠,王皇后为了排挤萧淑妃,答应李治让身在感业寺的武则天续起头发,重新纳入后宫",
        "《星空黑夜传奇》是连载于起点中文网的网络小说,作者是啤酒的罪孽"
    ]
    for d in text:
        with torch.no_grad():
            token_ids_test, segment_ids = tokenizer.encode(d, max_length=256)
            token_ids_test = torch.tensor(token_ids_test,
                                          device=device).view(1, -1)
            # 先预测subject
            pred_subject = bert_model.predict_subject(token_ids_test,
                                                      device=device)
            pred_subject = pred_subject.squeeze(0)