import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from torch.nn.utils import clip_grad_norm_


def fit(self, model, data):
    def get_params():
        return (p for p in model.parameters() if p.requires_grad)

    model.train()
    log = Logger()
    n_epoch = self.config.num_epochs
    # `device` was undefined in the original body; take it from the config,
    # as the other two trainers do.
    device = torch.device(self.config.device)
    optimizer = optim.Adam(get_params(), lr=self.config.lr)

    for epoch in range(n_epoch):
        # KL warm-up: keep the KL term switched off until epoch `kl_start`
        if epoch < self.config.kl_start:
            kl_w = 0
        else:
            kl_w = self.config.kl_w

        word_acc, topo_acc, assm_acc, steo_acc, all_kl = 0, 0, 0, 0, 0
        with tqdm.tqdm(data) as train_dataloader:
            train_dataloader.set_description('Train (epoch #{})'.format(epoch))
            for it, batch in enumerate(train_dataloader):
                model.zero_grad()
                # The model returns the total loss, the KL divergence, and
                # word / topology / assembly / stereochemistry accuracies.
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, kl_w)
                loss.backward()
                optimizer.step()

                word_acc += wacc
                topo_acc += tacc
                assm_acc += sacc
                steo_acc += dacc
                all_kl += kl_div

                # Report running means (accuracies as percentages)
                postfix = {'kl': all_kl / (it + 1),
                           'word': word_acc / (it + 1) * 100,
                           'topo': topo_acc / (it + 1) * 100,
                           'assm': assm_acc / (it + 1) * 100,
                           'steo': steo_acc / (it + 1) * 100}
                train_dataloader.set_postfix(postfix)
                log.append(postfix)
                log.save(self.config.log_file)

        # Checkpoint every `save_frequency` epochs (move to CPU for saving)
        if epoch % self.config.save_frequency == 0:
            model.to('cpu')
            torch.save(model.state_dict(),
                       self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
            model.to(device)
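# All three `fit` variants log metrics through a `Logger` helper that is not
# shown in this section. The sketch below is a minimal stand-in, assuming the
# behavior the call sites imply: `append` takes a dict of metrics, string
# indexing (e.g. ilog['kl_loss']) returns a column, integer indexing returns
# a row, and `save` writes everything to CSV. The real implementation may
# differ.
from collections import UserList

import pandas as pd


class Logger(UserList):
    """A list of per-step metric dicts with column access and CSV export."""

    def __getitem__(self, key):
        # String keys return a whole column, e.g. log['loss'][-10:]
        if isinstance(key, str):
            return [row[key] for row in self.data]
        return self.data[key]

    def save(self, path):
        # Dump the accumulated rows as a CSV table
        pd.DataFrame(self.data).to_csv(path, index=False)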
def fit(self, model, data):
    def get_params():
        return (p for p in model.parameters() if p.requires_grad)

    # `data` is either a (train, validation) pair or a single train loader
    if isinstance(data, tuple):
        train_dataloader, val_dataloader = data
    else:
        train_dataloader = data
        val_dataloader = None

    num_epochs = self.config.num_epochs
    device = torch.device(self.config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(get_params(), lr=self.config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                self.config.step_size,
                                                self.config.gamma)
    elog = Logger()

    for epoch in range(num_epochs):
        model.train()

        # Wrap the loaders in fresh progress bars instead of rebinding them:
        # the original rebound `train_dataloader = tqdm.tqdm(train_dataloader)`
        # inside the loop, wrapping the loader in another tqdm every epoch.
        train_bar = tqdm.tqdm(train_dataloader)
        train_bar.set_description('Train (epoch #{})'.format(epoch))
        loss = self._pass_data(model, train_bar, criterion, optimizer)
        elog.append({'loss': loss})

        if val_dataloader is not None:
            val_bar = tqdm.tqdm(val_dataloader)
            val_bar.set_description('Validation (epoch #{})'.format(epoch))
            self._pass_data(model, val_bar, criterion)

        if epoch % self.config.save_frequency == 0:
            model.to('cpu')
            torch.save(model.state_dict(),
                       self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
            model.to(device)

        elog.save(self.config.log_file)

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # it at the start of the epoch (as the original did) skips the
        # initial learning rate.
        scheduler.step()

    torch.save(model.state_dict(), self.config.model_save)
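# `_pass_data` runs one full pass over a dataloader; its signature follows
# from the call sites above, but its body is not shown in this section. The
# sketch below assumes each batch is an (inputs, targets) pair and that the
# model returns per-token logits of shape (batch, seq_len, vocab); the actual
# method may unpack batches differently.
def _pass_data(self, model, dataloader, criterion, optimizer=None):
    # Train when an optimizer is supplied, evaluate otherwise
    is_train = optimizer is not None
    model.train(is_train)
    total_loss, n_batches = 0.0, 0
    with torch.set_grad_enabled(is_train):
        for inputs, targets in dataloader:
            logits = model(inputs)
            # Flatten batch and time dimensions for CrossEntropyLoss
            loss = criterion(logits.view(-1, logits.size(-1)),
                             targets.view(-1))
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    return total_loss / max(n_batches, 1)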
def fit(self, model, data):
    def get_params():
        return (p for p in model.vae.parameters() if p.requires_grad)

    model.train()
    n_epoch = self._n_epoch()
    kl_annealer = KLAnnealer(n_epoch, self.config)
    optimizer = optim.Adam(get_params(), lr=self.config.lr_start)
    lr_annealer = CosineAnnealingLRWithRestart(optimizer, self.config)
    device = torch.device(self.config.device)
    n_last = self.config.n_last
    elog, ilog = Logger(), Logger()

    for epoch in range(n_epoch):
        # Epoch start: anneal the KL weight
        kl_weight = kl_annealer(epoch)

        # Iterations
        T = tqdm.tqdm(data)
        for i, x in enumerate(T):
            # Forward
            kl_loss, recon_loss = model(x)
            loss = kl_weight * kl_loss + recon_loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(get_params(), self.config.grad_clipping)
            optimizer.step()

            # Log per-iteration metrics
            lr = optimizer.param_groups[0]['lr']
            ilog.append({
                'epoch': epoch,
                'kl_loss': kl_loss.item(),
                'recon_loss': recon_loss.item(),
                'loss': loss.item(),
                'kl_weight': kl_weight,
                'lr': lr
            })

            # Update the progress bar with means over the last `n_last` iters
            kl_loss_value = np.mean(ilog['kl_loss'][-n_last:])
            recon_loss_value = np.mean(ilog['recon_loss'][-n_last:])
            loss_value = np.mean(ilog['loss'][-n_last:])
            postfix = [
                f'loss={loss_value:.5f}',
                f'(kl={kl_loss_value:.5f}',
                f'recon={recon_loss_value:.5f})',
                f'klw={kl_weight:.5f} lr={lr:.5f}'
            ]
            T.set_postfix_str(' '.join(postfix))
            T.set_description(f'Train (epoch #{epoch})')
            T.refresh()

        # Epoch-level log: last iteration's metadata plus smoothed losses
        elog.append({
            **{k: v for k, v in ilog[-1].items() if 'loss' not in k},
            'kl_loss': kl_loss_value,
            'recon_loss': recon_loss_value,
            'loss': loss_value
        })

        # Checkpoint every `save_frequency` epochs (move to CPU for saving)
        if epoch % self.config.save_frequency == 0:
            model.to('cpu')
            torch.save(model.state_dict(),
                       self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch))
            model.to(device)

        elog.save(self.config.log_file)

        # Epoch end: step the cosine LR schedule
        lr_annealer.step()

    torch.save(model.state_dict(), self.config.model_save)
    return elog, ilog
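# `KLAnnealer` and `CosineAnnealingLRWithRestart` come from elsewhere in the
# codebase. Below is a minimal sketch of the KL annealer only, assuming
# config fields kl_start, kl_w_start, and kl_w_end: a linear ramp of the KL
# weight from kl_w_start to kl_w_end that begins at epoch kl_start. The LR
# annealer is a cosine schedule with warm restarts and is not reproduced
# here.
class KLAnnealer:
    def __init__(self, n_epoch, config):
        self.i_start = config.kl_start
        self.w_start = config.kl_w_start
        self.w_max = config.kl_w_end
        # Per-epoch increment once the ramp has started
        self.inc = (self.w_max - self.w_start) / (n_epoch - self.i_start)

    def __call__(self, i):
        k = (i - self.i_start) if i >= self.i_start else 0
        return self.w_start + k * self.inc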