def _train_epoch(self, data_loader, epoch):
    # switch model to training mode
    self.model.train()
    # iterate over len(data) / batch_size batches
    z_all = []
    xhat_plot, x_plot = [], []
    for i, (img, meta) in enumerate(data_loader):
        self.num_steps += 1
        self.opt.zero_grad()
        # send batch to the current device and run the forward pass
        img = img.to(self.device)
        xhat, z, mu, logvar = self.model(img)
        # compute the loss, backpropagate and take an optimization step
        loss = self._loss(img, xhat, mu, logvar, train=True, ep=epoch)
        loss.backward()
        self.opt.step()
        self._report_train(i)
        # collect latent means for the latent-space plot
        z_all.append(mu.data.cpu().numpy())
        # keep the last batch for the reconstruction plot
        if i == len(data_loader) - 1:
            xhat_plot = xhat.data.cpu().numpy()
            x_plot = img.data.cpu().numpy()

    # subsample the latent codes so the scatter plot stays readable;
    # cap at the number of available samples to avoid a ValueError
    z_all = np.concatenate(z_all)
    n_plot = min(1000, z_all.shape[0])
    z_all = z_all[np.random.choice(z_all.shape[0], n_plot, replace=False), :]

    # log reconstruction wall and latent-space plots to W&B every epoch
    if epoch % 1 == 0:
        wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
        self.wb.log({'Train_Recon': self.wb.Image(wall)},
                    step=self.num_steps)
        latent_plot = plot_latent_space(z_all, y=None)
        self.wb.log({'Latent_space': self.wb.Image(latent_plot)},
                    step=self.num_steps)
def _train_epoch(self, data_loader, epoch): """Training loop for a given epoch. Triningo goes over batches, images and latent space plots are logged to W&B Parameters ---------- data_loader : pytorch object data loader object with training items epoch : int epoch number Returns ------- """ # switch model to training mode self.model.train() # iterate over len(data)/batch_size mu_all = [] xhat_plot, x_plot = [], [] for i, (img, phy) in enumerate(data_loader): self.num_steps += 1 self.opt.zero_grad() img = img.to(self.device) phy = phy.to(self.device) xhat, z, mu, logvar = self.model(img, phy=phy) # calculate loss value loss = self._loss(img, xhat, mu, logvar, train=True, ep=epoch) # calculate the gradients loss.backward() # perform optimization step accordig to the gradients self.opt.step() self._report_train(i) # aux variables for latter plots mu_all.append(mu.data.cpu().numpy()) if i == len(data_loader) - 2: xhat_plot = xhat.data.cpu().numpy() x_plot = img.data.cpu().numpy() mu_all = np.concatenate(mu_all) mu_all = mu_all[ np.random.choice(mu_all.shape[0], 5000, replace=False), :] # plot reconstructed images ever 2 epochs if epoch % 2 == 0: wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch) self.wb.log({'Train_Recon': self.wb.Image(wall)}, step=self.num_steps) if epoch % 2 == 0: latent_plot = plot_latent_space(mu_all, y=None) self.wb.log({'Latent_space': self.wb.Image(latent_plot)}, step=self.num_steps)
def _test_epoch(self, test_loader, epoch): """Testing loop for a given epoch. Triningo goes over batches, images and latent space plots are logged to W&B logger Parameters ---------- data_loader : pytorch object data loader object with training items epoch : int epoch number Returns ------- """ # swich model to evaluation mode, this make it deterministic self.model.eval() with torch.no_grad(): xhat_plot, x_plot = [], [] for i, (img, phy) in enumerate(test_loader): # send data to current device img = img.to(self.device) phy = phy.to(self.device) xhat, z, mu, logvar = self.model(img, phy=phy) # calculate loss value loss = self._loss(img, xhat, mu, logvar, train=False, ep=epoch) # aux variables for plots if i == len(test_loader) - 2: xhat_plot = xhat.data.cpu().numpy() x_plot = img.data.cpu().numpy() self._report_test(epoch) # plot reconstructed images ever 2 epochs if epoch % 2 == 0: wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch) self.wb.log({'Test_Recon': self.wb.Image(wall)}, step=self.num_steps) return loss
def _test_epoch(self, test_loader, epoch):
    # switch model to evaluation mode, this makes it deterministic
    self.model.eval()
    with torch.no_grad():
        xhat_plot, x_plot = [], []
        for i, (img, meta) in enumerate(test_loader):
            # send data to the current device and run the forward pass
            img = img.to(self.device)
            xhat, z, mu, logvar = self.model(img)
            # calculate loss value
            loss = self._loss(img, xhat, mu, logvar, train=False, ep=epoch)
            # keep the last batch for the reconstruction plot
            if i == len(test_loader) - 1:
                xhat_plot = xhat.data.cpu().numpy()
                x_plot = img.data.cpu().numpy()

        self._report_test(epoch)

        # plot reconstructed images every 2 epochs and log them to W&B
        if epoch % 2 == 0:
            wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
            self.wb.log({'Test_Recon': self.wb.Image(wall)},
                        step=self.num_steps)
    return loss
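# The outer driver that calls these epoch methods is not part of this snippet.
# Below is a minimal sketch of how they could be wired together; the attribute
# names `num_epochs`, `train_loader`, `test_loader` and the optional scheduler
# `self.sch` are hypothetical, not taken from the original code.
def train(self):
    for epoch in range(1, self.num_epochs + 1):
        self._train_epoch(self.train_loader, epoch)
        test_loss = self._test_epoch(self.test_loader, epoch)
        # step an LR scheduler on the test loss if one is configured
        if getattr(self, 'sch', None) is not None:
            self.sch.step(test_loss)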