class Trainer: def __init__(self, config): self.config = config self.step = 0 self.vocab = Vocab(config.vocab_file, config.vocab_size) self.train_data = CNNDMDataset('train', config.data_path, config, self.vocab) self.validate_data = CNNDMDataset('val', config.data_path, config, self.vocab) # self.model = Model(config).to(device) # self.optimizer = None self.setup(config) def setup(self, config): model = Model(config) checkpoint = None if config.train_from != '': logging('Train from %s' % config.train_from) checkpoint = torch.load(config.train_from, map_location='cpu') model.load_state_dict(checkpoint['model']) self.step = checkpoint['step'] self.model = model.to(device) self.optimizer = Adagrad(model.parameters(), lr=config.learning_rate, initial_accumulator_value=config.initial_acc) if checkpoint is not None: self.optimizer.load_state_dict(checkpoint['optimizer']) def train_one(self, batch): config = self.config enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \ get_input_from_batch(batch, config, device) dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \ get_output_from_batch(batch, device) encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder( enc_batch, enc_lens) s_t_1 = self.model.reduce_state(encoder_hidden) step_losses = [] for di in range(max_dec_len): y_t_1 = dec_batch[:, di] # Teacher forcing final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder( y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di) target = target_batch[:, di] gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze() step_loss = -torch.log(gold_probs + config.eps) if config.is_coverage: step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1) step_loss = step_loss + config.cov_loss_wt * step_coverage_loss coverage = next_coverage step_mask = dec_padding_mask[:, di] step_loss = step_loss * step_mask step_losses.append(step_loss) sum_losses = torch.sum(torch.stack(step_losses, 1), 1) batch_avg_loss = sum_losses / dec_lens_var loss = torch.mean(batch_avg_loss) return loss def train(self): config = self.config train_loader = DataLoader(self.train_data, batch_size=config.batch_size, shuffle=True, collate_fn=Collate()) running_avg_loss = 0 self.model.train() for e in range(config.train_epoch): for batch in train_loader: self.step += 1 self.optimizer.zero_grad() loss = self.train_one(batch) loss.backward() clip_grad_norm_(self.model.parameters(), config.max_grad_norm) self.optimizer.step() #print(loss.item()) running_avg_loss = calc_running_avg_loss( loss.item(), running_avg_loss) if self.step % config.report_every == 0: logging("Step %d Train loss %.3f" % (self.step, running_avg_loss)) if self.step % config.validate_every == 0: self.validate() if self.step % config.save_every == 0: self.save(self.step) if self.step % config.test_every == 0: pass @torch.no_grad() def validate(self): self.model.eval() validate_loader = DataLoader(self.validate_data, batch_size=self.config.batch_size, shuffle=False, collate_fn=Collate()) losses = [] for batch in validate_loader: loss = self.train_one(batch) losses.append(loss.item()) self.model.train() ave_loss = sum(losses) / len(losses) logging('Validate loss : %f' % ave_loss) def save(self, step): state = { 'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'step': step } save_path = os.path.join(self.config.model_path, 'model_s%d.pt' % step) logging('Saving model step %d to %s...' % (step, save_path)) torch.save(state, save_path)
class Trainer: def __init__(self, config): self.config = config self.device = config['device'] self.step = 0 if os.path.exists('../vocab.pt'): self.vocab = torch.load('../vocab.pt') else: self.vocab = Vocab(config['vocab_file'], config['vocab_size']) torch.save(self.vocab, '../vocab.pt') self.train_data = CNNDMDataset('train', config['data_path'], config, self.vocab) self.validate_data = CNNDMDataset('val', config['data_path'], config, self.vocab) self.setup(config) def setup(self, config): self.model = Model(config).to(config['device']) self.optimizer = Adagrad(self.model.parameters(), lr=config['learning_rate'], initial_accumulator_value=0.1) # self.optimizer = Adam(self.model.parameters(),lr = config['learning_rate'],betas = config['betas']) checkpoint = None if config[ 'train_from'] != '': # Counter在两次mostCommon间, 相同频率的元素可能以不同的次序输出...! logging('Train from %s' % config['train_from']) checkpoint = torch.load(config['train_from'], map_location='cpu') self.model.load_state_dict(checkpoint['model']) self.step = checkpoint['step'] self.vocab = checkpoint['vocab'] self.optimizer.load_state_dict(checkpoint['optimizer']) # print('State dict parameters:') # for n in model.state_dict().keys(): # print(n) #self.optimizer = Adam(self.model.parameters(),lr = config['learning_rate'],betas = config['betas']) def train_one(self, batch): """ coverage not implemented """ config = self.config enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros = \ get_input_from_batch(batch, config, self.device) dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \ get_output_from_batch(batch, self.device) pred = self.model(enc_batch, dec_batch, enc_padding_mask, dec_padding_mask, enc_batch_extend_vocab, extra_zeros) # >>>>>>>> DEBUG Session <<<<<<<<< # print("ENC\n") # print(enc_batch) # print("DEC\n") # print(dec_batch) # print("TGT\n") # print(target_batch) # print("ENCP\n") # print(enc_padding_mask) # print("DECP\n") # print(dec_padding_mask) # encs = [self.vocab.id2word(int(v)) for v in enc_batch[:, 0]] # decs = [self.vocab.id2word(int(v)) for v in dec_batch[:, 0]] # print(' '.join(encs)) # print(' '.join(decs)) #print(pred.max(dim=-1)[1][:,0]) # #loss = self.model.nll_loss(pred, target_batch, dec_lens_var) loss = self.model.label_smoothing_loss(pred, target_batch) return loss def train(self): config = self.config train_loader = DataLoader(self.train_data, batch_size=config['batch_size'], shuffle=True, collate_fn=Collate()) running_avg_loss = 0 self.model.train() for _ in range(config['train_epoch']): for batch in train_loader: self.step += 1 loss = self.train_one(batch) running_avg_loss = calc_running_avg_loss( loss.item(), running_avg_loss) loss.div(float(config['gradient_accum'])).backward() if self.step % config[ 'gradient_accum'] == 0: # gradient accumulation clip_grad_norm_(self.model.parameters(), config['max_grad_norm']) self.optimizer.step() self.optimizer.zero_grad() if self.step % config['report_every'] == 0: logging("Step %d Train loss %.3f" % (self.step, running_avg_loss)) if self.step % config['save_every'] == 0: self.save() if self.step % config['validate_every'] == 0: self.validate() @torch.no_grad() def validate(self): self.model.eval() validate_loader = DataLoader(self.validate_data, batch_size=self.config['batch_size'], shuffle=False, collate_fn=Collate()) losses = [] for batch in tqdm(validate_loader): loss = self.train_one(batch) losses.append(loss.item()) self.model.train() ave_loss = sum(losses) / len(losses) logging('Validate loss : %f' % ave_loss) def save(self): state = { 'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'step': self.step, 'vocab': self.vocab } save_path = os.path.join(self.config['model_path'], 'model_s%d.pt' % self.step) logging('Saving model step %d to %s...' % (self.step, save_path)) torch.save(state, save_path)
class Trainer(): def __init__(self, model, args, train_dataset, eval_dataset, test_dataset, vocab, is_train=True): self.model = model #.to(args.device) self.args = args self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.test_dataset = test_dataset self.is_train = is_train self.vocab = vocab self.params = list(model.encoder.parameters()) + \ list(model.decoder.parameters()) + list(model.reduce_state.parameters()) initial_lr = args.lr_coverage if args.is_coverage else args.lr self.optimizer = Adagrad( self.params, lr=initial_lr, initial_accumulator_value=args.adagrad_init_acc) def get_train_dataloader(self): if self.train_dataset is None: raise ValueError('Trainer: training requires a train_dataset.') return BucketIterator(dataset=self.train_dataset, batch_size=self.args.batch_size, device=self.args.device, sort_key=lambda x: len(x.source), sort_within_batch=True) def get_eval_dataloader(self): if self.eval_dataset is None: raise ValueError('Trainer: eval requires a eval_dataset.') return BucketIterator(dataset=self.eval_dataset, batch_size=self.args.batch_size, device=self.args.device, sort_key=lambda x: len(x.source), sort_within_batch=True) def get_test_dataloader(self): if self.test_dataset is None: raise ValueError('Trainer: testing requires a test_dataset.') return BucketIterator(dataset=self.test_dataset, batch_size=self.args.batch_size, device=self.args.device, sort_key=lambda x: len(x.source), sort_within_batch=True) def get_mask(self, batch): # print('each batch', batch[0].size()) maxlen = batch[0].size()[1] max_enc_seq_len = batch[1] mask = torch.arange(maxlen).to(self.args.device) mask = mask[None, :] < max_enc_seq_len[:, None] # print(batch.source[0]*mask) return mask def get_extra_features(self, batch): unk_index = self.vocab.stoi[UNKNOWN_TOKEN] batch = batch.cpu().detach().numpy() batch_size = batch.shape[0] max_art_oovs = max([Counter(sample)[unk_index] for sample in batch]) extra_zeros = None enc_batch_extend_vocab = np.full_like( batch, fill_value=self.vocab.stoi[PAD_TOKEN]) max_art_oovs = 0 for i, sample_index in enumerate(batch): oov_word_count = len(self.vocab) for j, word_index in enumerate(sample_index): if word_index == unk_index: enc_batch_extend_vocab[i, j] = oov_word_count oov_word_count += 1 max_art_oovs = max(max_art_oovs, oov_word_count) max_art_oovs -= len(self.vocab) enc_batch_extend_vocab = Variable( torch.from_numpy(enc_batch_extend_vocab).long()) extra_zeros = Variable(torch.zeros((batch_size, max_art_oovs))) return extra_zeros, enc_batch_extend_vocab, max_art_oovs def save_model(self, running_avg_loss, iter, model_dir): if not os.path.exists(model_dir): os.makedirs(model_dir) state = { 'iter': iter, 'encoder_state_dict': self.model.encoder.state_dict(), 'decoder_state_dict': self.model.decoder.state_dict(), 'reduce_state_dict': self.model.reduce_state.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_loss, 'vocab': self.vocab } model_save_path = os.path.join( model_dir, 'model_%d_%d' % (iter, int(time.time()))) torch.save(state, model_save_path) def evaluate(self, eval_dataset=None, iter=0, is_test=False): if is_test: eval_iter = self.get_test_dataloader() else: eval_iter = self.get_eval_dataloader() self.model.eval() running_avg_loss = 0 with torch.no_grad(): for i, batch in tqdm(enumerate(eval_iter), total=len(eval_iter)): # print(batch.source[0].size()) # exit() batch_size = batch.batch_size # encoder part enc_padding_mask = self.get_mask(batch.source) enc_batch = batch.source[0] enc_lens = batch.source[1] encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder( enc_batch, enc_lens) s_t_1 = self.model.reduce_state(encoder_hidden) coverage = Variable(torch.zeros(batch.source[0].size())).to( self.args.device) c_t_1 = Variable( torch.zeros( (batch_size, 2 * self.args.hidden_dim))).to(self.args.device) extra_zeros, enc_batch_extend_vocab, max_art_oovs = self.get_extra_features( batch.source[0]) extra_zeros = extra_zeros.to(self.args.device) enc_batch_extend_vocab = enc_batch_extend_vocab.to( self.args.device) # decoder part dec_batch = batch.target[0][:, :-1] # print(dec_batch.size()) target_batch = batch.target[0][:, 0:] dec_lens_var = batch.target[1] dec_padding_mask = self.get_mask(batch.target) max_dec_len = max(dec_lens_var) step_losses = [] for di in range(min(max_dec_len, self.args.max_dec_steps) - 1): y_t_1 = dec_batch[:, di] # Teacher forcing final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder( y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di) target = target_batch[:, di] gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze() step_loss = -torch.log(gold_probs + self.args.eps) if self.args.is_coverage: step_coverage_loss = torch.sum( torch.min(attn_dist, coverage), 1) step_loss = step_loss + self.args.cov_loss_wt * step_coverage_loss coverage = next_coverage step_mask = dec_padding_mask[:, di] step_loss = step_loss * step_mask step_losses.append(step_loss) sum_losses = torch.sum(torch.stack(step_losses, 1), 1) batch_avg_loss = sum_losses / dec_lens_var loss = torch.mean(batch_avg_loss) norm = clip_grad_norm_(self.model.encoder.parameters(), self.args.max_grad_norm) clip_grad_norm_(self.model.decoder.parameters(), self.args.max_grad_norm) clip_grad_norm_(self.model.reduce_state.parameters(), self.args.max_grad_norm) self.optimizer.step() # running_avg_loss = loss if running_avg_loss == 0 else running_avg_loss * decay + (1 - decay) * loss # running_avg_loss = min(running_avg_loss, 12) name = 'Test' if is_test else 'Evaluation' calc_running_avg_loss(loss.item(), running_avg_loss, summary_writer, iter, name) # iter += 1 # def predict(self, source_sentence): def train(self, model_path=None): train_iter = self.get_train_dataloader() iter, running_avg_loss = 0, 0 start = time.time() for epoch in range(self.args.epoches): print(f"Epoch: {epoch+1}") self.model.train() for i, batch in tqdm(enumerate(train_iter), total=len(train_iter)): # print(batch.source[0].size()) # exit() batch_size = batch.batch_size # encoder part enc_padding_mask = self.get_mask(batch.source) enc_batch = batch.source[0] enc_lens = batch.source[1] encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder( enc_batch, enc_lens) s_t_1 = self.model.reduce_state(encoder_hidden) coverage = Variable(torch.zeros(batch.source[0].size())).to( self.args.device) c_t_1 = Variable( torch.zeros( (batch_size, 2 * self.args.hidden_dim))).to(self.args.device) extra_zeros, enc_batch_extend_vocab, max_art_oovs = self.get_extra_features( batch.source[0]) extra_zeros = extra_zeros.to(self.args.device) enc_batch_extend_vocab = enc_batch_extend_vocab.to( self.args.device) # decoder part dec_batch = batch.target[0][:, :-1] # print(dec_batch.size()) target_batch = batch.target[0][:, 0:] dec_lens_var = batch.target[1] dec_padding_mask = self.get_mask(batch.target) max_dec_len = max(dec_lens_var) step_losses = [] for di in range(min(max_dec_len, self.args.max_dec_steps) - 1): y_t_1 = dec_batch[:, di] # Teacher forcing final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder( y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di) target = target_batch[:, di] gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze() step_loss = -torch.log(gold_probs + self.args.eps) if self.args.is_coverage: step_coverage_loss = torch.sum( torch.min(attn_dist, coverage), 1) step_loss = step_loss + self.args.cov_loss_wt * step_coverage_loss coverage = next_coverage step_mask = dec_padding_mask[:, di] step_loss = step_loss * step_mask step_losses.append(step_loss) sum_losses = torch.sum(torch.stack(step_losses, 1), 1) batch_avg_loss = sum_losses / dec_lens_var loss = torch.mean(batch_avg_loss) loss.backward() norm = clip_grad_norm_(self.model.encoder.parameters(), self.args.max_grad_norm) clip_grad_norm_(self.model.decoder.parameters(), self.args.max_grad_norm) clip_grad_norm_(self.model.reduce_state.parameters(), self.args.max_grad_norm) self.optimizer.step() running_avg_loss = calc_running_avg_loss( loss.item(), running_avg_loss, summary_writer, iter, 'Train') iter += 1 if iter % self.args.flush: # print('flush') summary_writer.flush() # print_interval = 10 # if iter % print_interval == 0: # print(f'steps {iter}, batch number: {i} with {time.time() - start} seconds, loss: {loss}') # start = time.time() # if iter % 300 == 0: self.save_model(running_avg_loss, iter, model_dir) self.evaluate(self.eval_dataset, epoch) self.evaluate(self.test_dataset, epoch, True)