def __init__(self):
    """Load the whole dream corpus and build the vocabulary / tokenizer."""
    super(DreamDataset, self).__init__()
    # Read the raw source/target sentence pairs from the training corpus.
    self.sents_src, self.sents_tgt = read_corpus(Config.dream_train_corpus_path)
    # Vocabulary: token -> id, plus the reverse id -> token mapping.
    self.word2idx = load_bert_vocab()
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
    self.tokenizer = Tokenizer(self.word2idx)
def __init__(self, config: BertConfig):
    """Build the seq2seq model: a BERT encoder plus an LM-prediction-head decoder."""
    super(Seq2SeqModel, self).__init__()
    # Sizes pulled from the configuration.
    self.hidden_dim = config.hidden_size
    self.vocab_size = config.vocab_size
    # Encoder; the decoder head is tied to the encoder's word-embedding weights.
    self.bert = BertModel(config)
    self.decoder = BertLMPredictionHead(
        config, self.bert.embeddings.word_embeddings.weight)
    # Vocabulary and tokenizer used for generation.
    self.word2ix = load_bert_vocab()
    self.tokenizer = Tokenizer(self.word2ix)
def train():
    """Fine-tune the seq2seq BERT model on the dream-interpretation corpus."""
    # Dataset and batching.
    dataset = DreamDataset()
    dataloader = DataLoader(dataset, batch_size=Config.batch_size,
                            shuffle=True, collate_fn=collate_fn)

    # Build the model with a vocab-sized config and load pretrained weights.
    word2idx = load_bert_vocab()
    bertconfig = BertConfig(vocab_size=len(word2idx))
    bert_model = Seq2SeqModel(config=bertconfig)
    load_model(bert_model, Config.pretrain_model_path)
    bert_model.to(Config.device)

    # Optimizer over all model parameters.
    optimizer = torch.optim.Adam(bert_model.parameters(),
                                 lr=Config.learning_rate, weight_decay=1e-3)

    step = 0
    for epoch in range(Config.EPOCH):
        total_loss = 0
        batch_count = 0
        for token_ids, token_type_ids, target_ids in dataloader:
            start_time = time.time()
            step += 1
            batch_count += 1
            token_ids = token_ids.to(Config.device)
            token_type_ids = token_type_ids.to(Config.device)
            target_ids = target_ids.to(Config.device)

            # Passing labels makes the model compute and return the loss.
            predictions, loss = bert_model(token_ids, token_type_ids,
                                           labels=target_ids,
                                           device=Config.device)

            # Standard optimisation step: clear grads, backprop, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            time_str = datetime.datetime.now().isoformat()
            log_str = 'time:{}, epoch:{}, step:{}, loss:{:8f}, spend_time:{:6f}'.format(
                time_str, epoch, step, loss, time.time() - start_time)
            rainbow(log_str)

            # Accumulate for this epoch's mean loss.
            total_loss += loss.item()
            # Periodic checkpoint of the latest weights.
            if step % 30 == 0:
                torch.save(bert_model.state_dict(), './bert_dream.bin')

        print("当前epoch:{}, 平均损失:{}".format(epoch, total_loss / batch_count))
        # Keep an extra snapshot every 10 epochs.
        if epoch % 10 == 0:
            save_path = "./data/" + "pytorch_bert_gen_epoch{}.bin".format(
                str(epoch))
            torch.save(bert_model.state_dict(), save_path)
            print("{} saved!".format(save_path))
""" @file : interface.py @author: xiaolu @time : 2020-03-25 """ import torch from seq2seq_bert import Seq2SeqModel from bert_model import BertConfig from tokenizer import load_bert_vocab if __name__ == '__main__': word2idx = load_bert_vocab() config = BertConfig(len(word2idx)) bert_seq2seq = Seq2SeqModel(config) # 加载模型 checkpoint = torch.load('./bert_dream.bin', map_location=torch.device("cpu")) bert_seq2seq.load_state_dict(checkpoint) bert_seq2seq.eval() test_data = [ "梦见大街上人群涌动、拥拥而行的景象", "梦见司机将你送到目的地", "梦见别人把弓箭作为礼物送给自己", "梦见中奖了", "梦见大富豪" ] for text in test_data: print(bert_seq2seq.generate(text, beam_size=3))