# NOTE: the project-local import paths below are assumptions; adjust them to wherever
# config, data, Vocab, Batcher, Model, get_cuda, get_enc_data and get_dec_data live
# in this repository.
import os
import sys
import time

import torch
import torch.nn.functional as F

from data_util import config, data
from data_util.batcher import Batcher
from data_util.data import Vocab
from model import Model
from train_util import get_cuda, get_enc_data, get_dec_data


class Train(object):
    def __init__(self, opt):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        self.opt = opt
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        time.sleep(5)  # give the batcher's background threads time to fill the queue

    def save_model(self, iter):
        save_path = config.save_model_path + "%07d.tar" % iter
        torch.save({
            'iter': iter + 1,
            'model_dict': self.model.state_dict(),
            # key must match the one read back in setup_train()
            'trainer_dict': self.trainer.state_dict()
        }, save_path)

    def setup_train(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        self.trainer = torch.optim.Adam(self.model.parameters(), lr=config.lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = os.path.join(config.save_model_path, self.opt.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint['iter']
            self.model.load_state_dict(checkpoint['model_dict'])
            self.trainer.load_state_dict(checkpoint['trainer_dict'])
            print("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = torch.optim.Adam(self.model.parameters(), lr=self.opt.new_lr)
            # Equivalent alternative: update the lr of the existing optimizer in place.
            # for params in self.trainer.param_groups:
            #     params['lr'] = self.opt.new_lr
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                        extra_zeros, enc_batch_extend_vocab, batch):
        '''
        Scheduled sampling: with probability 0.25 the previously generated token is fed
        as the decoder input, and with probability 0.75 the ground-truth label is used.

        Inputs:
            enc_out: encoder output at every time step.
                [batch_size, max_seq_len, 2 * hidden_dim]
            enc_hidden: hidden and cell state of the final encoder cell, (h, c).
                Each is [batch_size, hidden_dim].
            enc_padding_mask: distinguishes padding from real tokens in the encoder
                input; every example is padded to the length of the longest one in the
                batch. [batch_size, max_seq_len], 0 = padding, 1 = real token.
            ct_e: context vector obtained by attending over the encoder outputs at the
                current decoder step. [batch_size, 2 * hidden_dim], updated every step.
            extra_zeros: placeholder for the OOV words. [batch_size, max_art_oovs]
            enc_batch_extend_vocab: the input batch in which each article's OOV words
                are replaced by their temporary OOV ids. [batch_size, max_seq_len]
            batch: the input batch, an instance of the Batch class.
        '''
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)
        step_losses = []
        h_t = (enc_hidden[0], enc_hidden[1])
        x_t = get_cuda(torch.LongTensor(len(enc_out)).fill_(self.start_id))
        prev_s = None
        sum_temporal_srcs = None
        for t in range(min(max_dec_len, config.max_dec_steps)):
            # For each article in the batch, draw a random number to decide whether this
            # step uses the ground-truth label (1) or the previously generated token (0).
            use_ground_truth = get_cuda((torch.rand(len(enc_out)) > 0.25)).long()
            x_t = use_ground_truth * dec_batch[:, t] + (1 - use_ground_truth) * x_t
            # The last dimension of x_t does not need to be config.vocab_size: the
            # embedding layer maps each integer id to a vector, appending an emb_dim
            # dimension.
            x_t = self.model.embeds(x_t)
            final_dist, h_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, h_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            target = target_batch[:, t]
            log_probs = torch.log(final_dist + config.eps)
            step_loss = F.nll_loss(log_probs, target, reduction='none',
                                   ignore_index=self.pad_id)
            step_losses.append(step_loss)
            # final_dist: [batch_size, config.vocab_size + batch.max_art_oovs]
            # Draw one sample per row along the second dimension;
            # the returned values are the sampled indices.
            # x_t: [batch_size, 1] --> [batch_size]
            x_t = torch.multinomial(final_dist, 1).squeeze()
            is_oov = (x_t >= config.vocab_size).long()
            # Replace sampled OOV ids with UNK so they can be embedded at the next step.
            # x_t: [batch_size]
            x_t = (1 - is_oov) * x_t.detach() + is_oov * self.unk_id

        losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = losses / dec_lens
        mle_loss = torch.mean(batch_avg_loss)
        return mle_loss

    # All the steps performed for a single training iteration.
    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, \
            extra_zeros, context = get_enc_data(batch)
        enc_batch = self.model.embeds(enc_batch)
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        if self.opt.train_mle == 'yes':
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden, enc_padding_mask,
                                            context, extra_zeros,
                                            enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(torch.FloatTensor([0]))

        self.trainer.zero_grad()
        mle_loss.backward()
        self.trainer.step()
        return mle_loss.item()

    # The actual training loop.
    def train_iters(self):
        iter = self.setup_train()
        count = mle_total = 0
        while iter <= config.max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss = self.train_one_batch(batch)
            except KeyboardInterrupt:
                print("-------------Keyboard Interrupt------------")
                exit(0)
            mle_total += mle_loss
            count += 1
            iter += 1
            if iter % 1000 == 0:
                mle_avg = mle_total / count
                print('iter:', iter, 'mle_loss:', "%.3f" % mle_avg)
                count = mle_total = 0
                sys.stdout.flush()
            if iter % 2000 == 0:
                self.save_model(iter)
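

# --- Hypothetical driver (not part of the original file) ---
# A minimal sketch of how Train might be invoked, assuming `opt` is an argparse
# namespace exposing the attributes used above (train_mle, load_model, new_lr).
# The flag names and defaults are illustrative assumptions, not the project's
# actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_mle', type=str, default='yes')
    parser.add_argument('--load_model', type=str, default=None)
    parser.add_argument('--new_lr', type=float, default=None)
    opt = parser.parse_args()

    trainer = Train(opt)
    trainer.train_iters()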