def __init__(self):
    # load the training data
    data_dir = "data"
    self.vocab_path = "corpus/bert-base-chinese-vocab.txt"
    self.sents_src, self.sents_tgt = read_corpus(data_dir, self.vocab_path)
    self.model_name = "bert"
    self.model_path = "weights/bert-base-chinese-pytorch_model.bin"
    self.recent_model_path = ""
    self.model_save_path = "./bert_model.bin"
    self.batch_size = 24
    self.lr = 1e-5
    self.word2idx = load_chinese_base_vocab(self.vocab_path)
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: " + str(self.device))
    # build the model, load the pretrained weights, and move it to the device
    self.bert_model = load_bert(self.vocab_path, model_name=self.model_name)
    load_model_params(self.bert_model, self.model_path)
    self.bert_model.to(self.device)
    self.optim_parameters = list(self.bert_model.parameters())
    self.optimizer = torch.optim.Adam(self.optim_parameters,
                                      lr=self.lr, weight_decay=1e-3)
    dataset = BertDataset(self.sents_src, self.sents_tgt, self.vocab_path)
    self.dataloader = DataLoader(dataset, batch_size=self.batch_size,
                                 shuffle=True, collate_fn=collate_fn)
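# A minimal training-step sketch for the trainer above -- an assumption, not
# part of the original code. It presumes each batch from self.dataloader
# yields (token_ids, token_type_ids, target_ids) and that the model returns
# (predictions, loss) when labels are passed, a convention common in
# UniLM-style seq2seq BERT code but not confirmed by this snippet.
def train_epoch_sketch(trainer):
    trainer.bert_model.train()
    total_loss = 0.0
    for token_ids, token_type_ids, target_ids in trainer.dataloader:
        token_ids = token_ids.to(trainer.device)
        token_type_ids = token_type_ids.to(trainer.device)
        target_ids = target_ids.to(trainer.device)
        # forward pass; the (predictions, loss) return shape is an assumption
        predictions, loss = trainer.bert_model(token_ids, token_type_ids,
                                               labels=target_ids)
        trainer.optimizer.zero_grad()
        loss.backward()
        trainer.optimizer.step()
        total_loss += loss.item()
    return total_loss / len(trainer.dataloader)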
def __init__(self, vocab_path, predicate_num, model_name="roberta"):
    super(BertRelationExtrac, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.predicate_num = predicate_num
    # both branches build the same modules; only the import source differs
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertLayerNorm
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertLayerNorm
    else:
        raise ValueError("unsupported model_name: " + model_name)
    config = BertConfig(len(self.word2ix))
    self.bert = BertModel(config)
    self.layer_norm = BertLayerNorm(config.hidden_size)
    self.layer_norm_cond = BertLayerNorm(config.hidden_size, conditional=True)
    # per-token subject head: (start, end) logits
    self.subject_pred = nn.Linear(config.hidden_size, 2)
    self.activation = nn.Sigmoid()
    # per-token object head: (start, end) logits for every predicate
    self.object_pred = nn.Linear(config.hidden_size, 2 * self.predicate_num)
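# A hypothetical shape check for the two heads above (not in the original
# code): subject_pred scores each token as a subject start/end, and
# object_pred does the same per predicate, hence the 2 * predicate_num width.
# `hidden` stands in for the encoder output of shape (batch, seq_len,
# hidden_size); 768 is the usual BERT-base hidden size, assumed here.
import torch
import torch.nn as nn

batch, seq_len, hidden_size, predicate_num = 2, 16, 768, 18
hidden = torch.randn(batch, seq_len, hidden_size)
subject_logits = nn.Linear(hidden_size, 2)(hidden)
object_logits = nn.Linear(hidden_size, 2 * predicate_num)(hidden)
print(subject_logits.shape)  # torch.Size([2, 16, 2])
print(object_logits.shape)   # torch.Size([2, 16, 36])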
def __init__(self, sents_src, sents_tgt, vocab_path):
    # __init__ typically loads all the data up front
    super(BertDataset, self).__init__()
    self.sents_src = sents_src
    self.sents_tgt = sents_tgt
    self.word2idx = load_chinese_base_vocab(vocab_path)
    # invert the vocab: map each index back to its token
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
    self.tokenizer = Tokenizer(self.word2idx)
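# BertDataset above also needs __getitem__/__len__ to work with a DataLoader,
# and the trainer references a collate_fn for padding. A plausible sketch of
# both, assuming Tokenizer.encode(src, tgt) returns (token_ids,
# token_type_ids) as Python lists -- the real signatures may differ.
import torch

def __getitem__(self, i):
    token_ids, token_type_ids = self.tokenizer.encode(self.sents_src[i],
                                                      self.sents_tgt[i])
    return {"token_ids": token_ids, "token_type_ids": token_type_ids}

def __len__(self):
    return len(self.sents_src)

def collate_fn(batch):
    # pad every sequence in the batch to the longest one with zeros
    max_len = max(len(item["token_ids"]) for item in batch)

    def pad(seq):
        return seq + [0] * (max_len - len(seq))

    token_ids = torch.tensor([pad(item["token_ids"]) for item in batch])
    token_type_ids = torch.tensor([pad(item["token_type_ids"]) for item in batch])
    # seq2seq LM targets are the inputs shifted left by one position
    target_ids = token_ids[:, 1:].contiguous()
    return token_ids, token_type_ids, target_ids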
def read_corpus(dir_path, vocab_path):
    sents_src = []
    sents_tgt = []
    word2idx = load_chinese_base_vocab(vocab_path)
    tokenizer = Tokenizer(word2idx)
    for file1 in os.listdir(dir_path):
        # check the full path, not the bare filename relative to the cwd
        file_path = os.path.join(dir_path, file1)
        if os.path.isdir(file_path):
            continue
        print(file_path)
        if not file_path.endswith("csv"):
            continue
        df = pd.read_csv(file_path)
        for index, row in df.iterrows():
            title, poem = row.iloc[0], row.iloc[3]
            if not isinstance(title, str) or not isinstance(poem, str):
                continue
            if len(title) > 8 or len(title) < 2:
                # drop poems whose titles are too long or too short
                continue
            if " " in title:
                # the title contains a space; keep only the part before it
                title = title.split(" ")[0]
            encode_text = tokenizer.encode(poem)[0]
            if word2idx["[UNK]"] in encode_text:
                # skip poems containing characters outside the vocabulary
                continue
            # bucket by form: 4 or 8 lines of 5 or 7 characters plus punctuation
            if len(poem) == 24 and (poem[5] == "," or poem[5] == "。"):
                # five-character quatrain (wuyan jueju)
                sents_src.append(title + "##" + "五言绝句")
                sents_tgt.append(poem)
            elif len(poem) == 32 and (poem[7] == "," or poem[7] == "。"):
                # seven-character quatrain (qiyan jueju)
                sents_src.append(title + "##" + "七言绝句")
                sents_tgt.append(poem)
            elif len(poem) == 48 and (poem[5] == "," or poem[5] == "。"):
                # five-character regulated verse (wuyan lüshi)
                sents_src.append(title + "##" + "五言律诗")
                sents_tgt.append(poem)
            elif len(poem) == 64 and (poem[7] == "," or poem[7] == "。"):
                # seven-character regulated verse (qiyan lüshi)
                sents_src.append(title + "##" + "七言律诗")
                sents_tgt.append(poem)
    print("poems collected: " + str(len(sents_src)))
    return sents_src, sents_tgt
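# Hypothetical usage of read_corpus: point it at a directory of CSVs and it
# returns paired "title##form" sources and poem targets. The paths mirror the
# trainer's defaults above; adjust them to your own layout.
if __name__ == "__main__":
    srcs, tgts = read_corpus("data", "corpus/bert-base-chinese-vocab.txt")
    for s, t in zip(srcs[:2], tgts[:2]):
        print(s, "->", t)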
def __init__(self, vocab_path, target_size, model_name="roberta"):
    super(BertClsClassifier, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.tokenizer = Tokenizer(self.word2ix)
    self.target_size = target_size
    # both branches build the same modules; only the import source differs
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel
    else:
        raise ValueError("unsupported model_name: " + model_name)
    config = BertConfig(len(self.word2ix))
    self.bert = BertModel(config)
    # classification head on top of the BERT output
    self.final_dense = nn.Linear(config.hidden_size, self.target_size)
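# Hypothetical construction of the classifier above for a 3-way label set.
# Only instantiation is shown, since the forward pass is not part of this
# snippet; 768 is the assumed BERT-base hidden size.
classifier = BertClsClassifier("corpus/bert-base-chinese-vocab.txt",
                               target_size=3, model_name="roberta")
print(classifier.final_dense)  # expected: Linear(in_features=768, out_features=3, ...)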
def __init__(self, vocab_path, target_size, model_name="roberta"):
    super(BertSeqLabelingCRF, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.target_size = target_size
    # both branches build the same modules; only the import source differs
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertPredictionHeadTransform
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertPredictionHeadTransform
    else:
        raise ValueError("unsupported model_name: " + model_name)
    config = BertConfig(len(self.word2ix))
    self.bert = BertModel(config)
    self.transform = BertPredictionHeadTransform(config)
    # per-token emission scores, decoded through a CRF layer
    self.final_dense = nn.Linear(config.hidden_size, self.target_size)
    self.crf_layer = CRFLayer(self.target_size)
def __init__(self, vocab_path, model_name="roberta"):
    super(Seq2SeqModel, self).__init__()
    self.word2ix = load_chinese_base_vocab(vocab_path)
    self.tokenizer = Tokenizer(self.word2ix)
    # both branches build the same modules; only the import source differs
    if model_name == "roberta":
        from bert.seq2seq.model.roberta_model import BertModel, BertConfig, BertLMPredictionHead
    elif model_name == "bert":
        from bert.seq2seq.model.bert_model import BertConfig, BertModel, BertLMPredictionHead
    else:
        raise ValueError("unsupported model_name: " + model_name)
    config = BertConfig(len(self.word2ix))
    self.bert = BertModel(config)
    # LM head tied to the input embedding matrix (weight sharing)
    self.decoder = BertLMPredictionHead(config,
                                        self.bert.embeddings.word_embeddings.weight)
    self.hidden_dim = config.hidden_size
    self.vocab_size = config.vocab_size
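# Hypothetical construction of the seq2seq model above. Passing the word
# embedding weight into BertLMPredictionHead ties the output projection to
# the input embeddings, so the LM head adds no separate vocab-sized matrix.
model = Seq2SeqModel("corpus/bert-base-chinese-vocab.txt", model_name="roberta")
print(model.hidden_dim, model.vocab_size)  # e.g. 768 and the vocab size, assuming BERT-base defaults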