import os

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Project-level helpers (read_corpus, load_chinese_base_vocab, load_bert,
# load_model_params, Tokenizer, BertDataset, collate_fn, CRFLayer) are
# assumed to be importable from the surrounding bert.seq2seq package.


def __init__(self):
    # Load the data.
    data_dir = "data"
    self.vocab_path = "corpus/bert-base-chinese-vocab.txt"
    self.sents_src, self.sents_tgt = read_corpus(data_dir, self.vocab_path)
    self.model_name = "bert"
    self.model_path = "weights/bert-base-chinese-pytorch_model.bin"
    self.recent_model_path = ""
    self.model_save_path = "./bert_model.bin"
    self.batch_size = 24
    self.lr = 1e-5
    self.word2idx = load_chinese_base_vocab(self.vocab_path)
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    print("device: " + str(self.device))
    # Build the model, load the pretrained weights, and move it to the device.
    self.bert_model = load_bert(self.vocab_path, model_name=self.model_name)
    load_model_params(self.bert_model, self.model_path)
    self.bert_model.to(self.device)
    self.optim_parameters = list(self.bert_model.parameters())
    self.optimizer = torch.optim.Adam(self.optim_parameters,
                                      lr=self.lr,
                                      weight_decay=1e-3)
    dataset = BertDataset(self.sents_src, self.sents_tgt, self.vocab_path)
    self.dataloader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn)
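
# A minimal sketch of the epoch loop this __init__ prepares for. The batch
# layout (token_ids, token_type_ids, target_ids) and the loss-returning
# forward call are assumptions for illustration, not the repo's actual
# interface.
def train(self, epoch):
    self.bert_model.train()
    for token_ids, token_type_ids, target_ids in self.dataloader:
        token_ids = token_ids.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        target_ids = target_ids.to(self.device)
        # Forward pass is assumed to return (predictions, loss).
        predictions, loss = self.bert_model(token_ids,
                                            token_type_ids,
                                            labels=target_ids)
        self.optimizer.zero_grad()  # clear stale gradients
        loss.backward()             # backpropagate
        self.optimizer.step()       # update parameters
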
def __init__(self, vocab_path, predicate_num, model_name="roberta"):
    super(BertRelationExtrac, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.predicate_num = predicate_num
    config = None
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertPredictionHeadTransform, BertLayerNorm
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        self.layer_norm = BertLayerNorm(config.hidden_size)
        self.layer_norm_cond = BertLayerNorm(config.hidden_size,
                                             conditional=True)
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertPredictionHeadTransform, BertLayerNorm
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        self.layer_norm = BertLayerNorm(config.hidden_size)
        self.layer_norm_cond = BertLayerNorm(config.hidden_size,
                                             conditional=True)
    else:
        raise ValueError("unsupported model_name: " + model_name)

    # Per-token start/end logits for subject spans, plus start/end logits
    # for the object span of every predicate class.
    self.subject_pred = nn.Linear(config.hidden_size, 2)
    self.activation = nn.Sigmoid()
    self.object_pred = nn.Linear(config.hidden_size, 2 * self.predicate_num)
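
# A minimal sketch of how the two heads above are typically applied, assuming
# `encoded` is the BERT sequence output of shape [batch, seq_len, hidden].
# The method name and reshaping are illustrative, not the repo's own code.
def predict_entities_sketch(self, encoded):
    # Per-token start/end probabilities for subject spans.
    subject_prob = self.activation(self.subject_pred(encoded))   # [b, s, 2]
    # Per-token start/end probabilities for every predicate class.
    object_prob = self.activation(self.object_pred(encoded))     # [b, s, 2p]
    object_prob = object_prob.view(object_prob.size(0),
                                   object_prob.size(1),
                                   self.predicate_num, 2)        # [b, s, p, 2]
    return subject_prob, object_prob
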
def __init__(self, sents_src, sents_tgt, vocab_path):
    # The init method typically loads all of the data up front.
    super(BertDataset, self).__init__()
    self.sents_src = sents_src
    self.sents_tgt = sents_tgt
    self.word2idx = load_chinese_base_vocab(vocab_path)
    # Invert the vocabulary: map each index back to its token.
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
    self.tokenizer = Tokenizer(self.word2idx)
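
# A sketch of the __getitem__/__len__ pair such a Dataset needs. It assumes
# the Tokenizer's encode accepts a (src, tgt) pair and returns
# (token_ids, token_type_ids); the output dict keys are assumptions about
# what collate_fn expects downstream.
def __getitem__(self, i):
    src = self.sents_src[i]
    tgt = self.sents_tgt[i]
    token_ids, token_type_ids = self.tokenizer.encode(src, tgt)
    return {"token_ids": token_ids, "token_type_ids": token_type_ids}

def __len__(self):
    return len(self.sents_src)
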
def read_corpus(dir_path, vocab_path):
    sents_src = []
    sents_tgt = []
    word2idx = load_chinese_base_vocab(vocab_path)
    tokenizer = Tokenizer(word2idx)
    for file_name in os.listdir(dir_path):
        file_path = os.path.join(dir_path, file_name)
        if os.path.isdir(file_path) or not file_path.endswith(".csv"):
            continue
        print(file_path)
        df = pd.read_csv(file_path)
        for _, row in df.iterrows():
            title, poem = row.iloc[0], row.iloc[3]
            if not isinstance(title, str) or not isinstance(poem, str):
                continue
            if len(title) > 8 or len(title) < 2:
                # Filter out poems whose titles are too long or too short.
                continue
            if " " in title:
                # The title contains a space; keep only the part before it.
                title = title.split(" ")[0]
            encode_text = tokenizer.encode(poem)[0]
            if word2idx["[UNK]"] in encode_text:
                # Skip poems containing characters outside the vocabulary.
                continue
            if len(poem) == 24 and (poem[5] == ","
                                    or poem[5] == "。"):
                # Five-character quatrain: 4 lines x (5 chars + punctuation).
                sents_src.append(title + "##" + "五言绝句")
                sents_tgt.append(poem)
            elif len(poem) == 32 and (poem[7] == ","
                                      or poem[7] == "。"):
                # Seven-character quatrain: 4 lines x (7 chars + punctuation).
                sents_src.append(title + "##" + "七言绝句")
                sents_tgt.append(poem)
            elif len(poem) == 48 and (poem[5] == ","
                                      or poem[5] == "。"):
                # Five-character regulated verse: 8 lines x (5 chars + punctuation).
                sents_src.append(title + "##" + "五言律诗")
                sents_tgt.append(poem)
            elif len(poem) == 64 and (poem[7] == ","
                                      or poem[7] == "。"):
                # Seven-character regulated verse: 8 lines x (7 chars + punctuation).
                sents_src.append(title + "##" + "七言律诗")
                sents_tgt.append(poem)

    print("Total poems: " + str(len(sents_src)))
    return sents_src, sents_tgt
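
# Example usage with the default paths from above. Each source string reads
# "title##form" and the matching target string is the poem body; the pairs
# printed here are whatever the CSVs happen to contain.
def _read_corpus_demo():
    src, tgt = read_corpus("data", "corpus/bert-base-chinese-vocab.txt")
    for s, t in zip(src[:3], tgt[:3]):
        print(s, "->", t)
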
def __init__(self, vocab_path, target_size, model_name="roberta"):
    super(BertClsClassifier, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.tokenizer = Tokenizer(self.word2ix)
    self.target_size = target_size
    config = None
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
    else:
        raise ValueError("unsupported model_name: " + model_name)

    # Project the pooled sentence representation onto the label space.
    self.final_dense = nn.Linear(config.hidden_size, self.target_size)
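
# A minimal sketch of the classification step, assuming `encoded` is the BERT
# sequence output [batch, seq_len, hidden] and the [CLS] vector at position 0
# is used for pooling; the method name is illustrative.
def classify_sketch(self, encoded):
    cls_vector = encoded[:, 0]             # [batch, hidden]
    logits = self.final_dense(cls_vector)  # [batch, target_size]
    return logits
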
def __init__(self, vocab_path, target_size, model_name="roberta"):
    super(BertSeqLabelingCRF, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.target_size = target_size
    config = None
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertPredictionHeadTransform
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        self.transform = BertPredictionHeadTransform(config)
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertPredictionHeadTransform
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        self.transform = BertPredictionHeadTransform(config)
    else:
        raise ValueError("unsupported model_name: " + model_name)

    # Per-token emission scores, decoded jointly by the CRF layer.
    self.final_dense = nn.Linear(config.hidden_size, self.target_size)
    self.crf_layer = CRFLayer(self.target_size)
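
# A minimal sketch of how these layers compose for sequence labeling; the
# method name is illustrative and the CRF layer's loss/decode API is not
# shown, since it is not part of this excerpt.
def label_sketch(self, encoded):
    features = self.transform(encoded)      # [batch, seq_len, hidden]
    emissions = self.final_dense(features)  # [batch, seq_len, target_size]
    # Training would score gold tag paths with self.crf_layer; inference
    # would Viterbi-decode the best path over these emissions.
    return emissions
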
def __init__(self, vocab_path, model_name="roberta"):
    super(Seq2SeqModel, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.tokenizer = Tokenizer(self.word2ix)
    config = None
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertLMPredictionHead
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        # The LM head ties its projection to the input word embeddings.
        self.decoder = BertLMPredictionHead(
            config, self.bert.embeddings.word_embeddings.weight)
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertLMPredictionHead
        config = BertConfig(len(self.word2ix))
        self.bert = BertModel(config)
        self.decoder = BertLMPredictionHead(
            config, self.bert.embeddings.word_embeddings.weight)
    else:
        raise ValueError("unsupported model_name: " + model_name)

    self.hidden_dim = config.hidden_size
    self.vocab_size = config.vocab_size
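
# A minimal sketch of the UniLM-style attention mask that a single-encoder
# seq2seq model like this relies on: source tokens (token_type 0) attend
# bidirectionally, while target tokens (token_type 1) attend to the full
# source plus earlier target positions only. This is the standard
# construction, written here independently of the repo's own code.
def unilm_mask_sketch(token_type_ids):
    seq_len = token_type_ids.size(1)
    ones = torch.ones(1, 1, seq_len, seq_len, device=token_type_ids.device)
    causal = torch.tril(ones)                              # j <= i
    s = token_type_ids.unsqueeze(1).unsqueeze(2).float()   # key type,   [b,1,1,s]
    t = token_type_ids.unsqueeze(1).unsqueeze(3).float()   # query type, [b,1,s,1]
    # Attention (i -> j) is allowed if the key j is a source token, or both
    # positions are in the target and j does not lie in the future.
    mask = (1.0 - s) + s * t * causal
    return (mask > 0).float()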