예제 #1
0
    def _train_epoch(self, data_loader, epoch):
        """Run one training epoch and log diagnostics to W&B.

        Every batch of ``data_loader`` goes through a forward pass,
        ``self._loss``, a backward pass and an optimizer step. After the
        epoch, a reconstruction wall built from the last batch and a
        latent-space plot built from a random subsample of the collected
        posterior means ``mu`` are logged to ``self.wb``.

        Parameters
        ----------
        data_loader : iterable
            Supports ``len()`` and yields ``(img, meta)`` pairs; ``meta`` is
            not used here.
        epoch : int
            Current epoch number, forwarded to ``self._loss`` and the
            reconstruction plotting helper.

        Returns
        -------
        None
        """
        self.model.train()
        n_batches = len(data_loader)
        mu_all = []
        xhat_plot, x_plot = [], []
        for i, (img, _meta) in enumerate(data_loader):
            self.num_steps += 1
            self.opt.zero_grad()
            img = img.to(self.device)

            xhat, _z, mu, logvar = self.model(img)

            loss = self._loss(img, xhat, mu, logvar, train=True, ep=epoch)
            loss.backward()
            self.opt.step()

            self._report_train(i)
            # detach() instead of .data: same values, but keeps autograd's
            # safety checks intact.
            mu_all.append(mu.detach().cpu().numpy())
            if i == n_batches - 1:
                # Keep the final batch for the reconstruction plot.
                xhat_plot = xhat.detach().cpu().numpy()
                x_plot = img.detach().cpu().numpy()

        if not mu_all:
            # Empty loader: nothing was collected, and np.concatenate would
            # raise on an empty list.
            return

        mu_all = np.concatenate(mu_all)
        # Subsample at most 1000 latent vectors; sampling without replacement
        # raises if the requested size exceeds the number of rows.
        n_points = min(1000, mu_all.shape[0])
        mu_all = mu_all[
            np.random.choice(mu_all.shape[0], n_points, replace=False), :]

        # Log both plots every epoch.
        wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
        self.wb.log({'Train_Recon': self.wb.Image(wall)},
                    step=self.num_steps)

        latent_plot = plot_latent_space(mu_all, y=None)
        self.wb.log({'Latent_space': self.wb.Image(latent_plot)},
                    step=self.num_steps)
예제 #2
0
    def _train_epoch(self, data_loader, epoch):
        """Run one training epoch; log reconstruction and latent plots to W&B.

        Every batch goes through a forward pass, ``self._loss``, a backward
        pass and an optimizer step. Posterior means ``mu`` from all batches
        are collected. On even epochs, a reconstruction wall and a
        latent-space plot built from a random subsample of ``mu`` are logged
        to ``self.wb``.

        Parameters
        ----------
        data_loader : iterable
            Supports ``len()`` and yields ``(img, phy)`` pairs; ``phy`` is
            moved to the device and passed to the model as ``phy=``.
        epoch : int
            Current epoch number, forwarded to ``self._loss`` and the
            reconstruction plotting helper.

        Returns
        -------
        None
        """
        # Switch the model to training mode.
        self.model.train()
        n_batches = len(data_loader)
        # The second-to-last batch feeds the reconstruction plot (presumably
        # to avoid a smaller trailing batch -- confirm). Clamp to 0 so a
        # single-batch loader still yields something to plot.
        plot_idx = max(n_batches - 2, 0)
        mu_all = []
        xhat_plot, x_plot = [], []
        for i, (img, phy) in enumerate(data_loader):
            self.num_steps += 1
            self.opt.zero_grad()
            img = img.to(self.device)
            phy = phy.to(self.device)

            xhat, _z, mu, logvar = self.model(img, phy=phy)
            loss = self._loss(img, xhat, mu, logvar, train=True, ep=epoch)
            # Back-propagate, then apply the optimizer update.
            loss.backward()
            self.opt.step()

            self._report_train(i)
            # Collect values for the plots made after the loop.
            mu_all.append(mu.detach().cpu().numpy())
            if i == plot_idx:
                xhat_plot = xhat.detach().cpu().numpy()
                x_plot = img.detach().cpu().numpy()

        if not mu_all:
            # Empty loader: nothing was collected, and np.concatenate would
            # raise on an empty list.
            return

        mu_all = np.concatenate(mu_all)
        # Subsample at most 5000 latent vectors; sampling without replacement
        # raises if the requested size exceeds the number of rows.
        n_points = min(5000, mu_all.shape[0])
        mu_all = mu_all[
            np.random.choice(mu_all.shape[0], n_points, replace=False), :]

        # Plot reconstructions and latent space every 2 epochs.
        if epoch % 2 == 0:
            wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
            self.wb.log({'Train_Recon': self.wb.Image(wall)},
                        step=self.num_steps)

            latent_plot = plot_latent_space(mu_all, y=None)
            self.wb.log({'Latent_space': self.wb.Image(latent_plot)},
                        step=self.num_steps)
예제 #3
0
    def _test_epoch(self, test_loader, epoch):
        """Evaluate the model for one epoch and log reconstructions to W&B.

        Every batch of ``test_loader`` is run through the model in
        evaluation mode with gradient tracking disabled, and
        ``self._loss`` is computed with ``train=False``. Afterwards
        ``self._report_test(epoch)`` is called. On even epochs, a
        reconstruction wall is logged to ``self.wb``.

        Parameters
        ----------
        test_loader : iterable
            Supports ``len()`` and yields ``(img, phy)`` pairs; ``phy`` is
            moved to the device and passed to the model as ``phy=``.
        epoch : int
            Current epoch number, forwarded to ``self._loss`` and the
            plotting helper.

        Returns
        -------
        loss
            Value returned by ``self._loss`` for the LAST batch only (not an
            epoch average), or ``None`` if ``test_loader`` yielded nothing.
        """
        # Switch the model to evaluation mode.
        self.model.eval()
        n_batches = len(test_loader)
        # The second-to-last batch feeds the reconstruction plot (presumably
        # to avoid a smaller trailing batch -- confirm). Clamp to 0 so a
        # single-batch loader still yields something to plot.
        plot_idx = max(n_batches - 2, 0)
        loss = None
        xhat_plot = x_plot = None
        with torch.no_grad():
            for i, (img, phy) in enumerate(test_loader):
                # Send data to the current device.
                img = img.to(self.device)
                phy = phy.to(self.device)
                xhat, _z, mu, logvar = self.model(img, phy=phy)
                loss = self._loss(img, xhat, mu, logvar, train=False, ep=epoch)

                if i == plot_idx:
                    # Under no_grad these tensors carry no graph, so no
                    # detach/.data is needed before converting to numpy.
                    xhat_plot = xhat.cpu().numpy()
                    x_plot = img.cpu().numpy()

        self._report_test(epoch)

        # Plot reconstructed images every 2 epochs, if any batch was seen.
        if epoch % 2 == 0 and x_plot is not None:
            wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
            self.wb.log({'Test_Recon': self.wb.Image(wall)},
                        step=self.num_steps)

        return loss
예제 #4
0
    def _test_epoch(self, test_loader, epoch):
        """Evaluate the model for one epoch and log reconstructions to W&B.

        Every batch of ``test_loader`` is run through the model in
        evaluation mode with gradient tracking disabled, and
        ``self._loss`` is computed with ``train=False``. Afterwards
        ``self._report_test(epoch)`` is called. On even epochs, a
        reconstruction wall built from the last batch is logged to
        ``self.wb``.

        Parameters
        ----------
        test_loader : iterable
            Supports ``len()`` and yields ``(img, meta)`` pairs; ``meta`` is
            not used here.
        epoch : int
            Current epoch number, forwarded to ``self._loss`` and the
            plotting helper.

        Returns
        -------
        loss
            Value returned by ``self._loss`` for the LAST batch only (not an
            epoch average), or ``None`` if ``test_loader`` yielded nothing.
        """
        self.model.eval()
        n_batches = len(test_loader)
        loss = None
        xhat_plot = x_plot = None
        with torch.no_grad():
            for i, (img, _meta) in enumerate(test_loader):
                img = img.to(self.device)
                xhat, _z, mu, logvar = self.model(img)
                loss = self._loss(img, xhat, mu, logvar, train=False, ep=epoch)

                if i == n_batches - 1:
                    # Keep the final batch for the reconstruction plot; under
                    # no_grad no detach/.data is needed.
                    xhat_plot = xhat.cpu().numpy()
                    x_plot = img.cpu().numpy()

        self._report_test(epoch)

        # Log a reconstruction wall to W&B every 2 epochs, if any batch
        # was seen.
        if epoch % 2 == 0 and x_plot is not None:
            wall = plot_recon_wall(xhat_plot, x_plot, epoch=epoch)
            self.wb.log({'Test_Recon': self.wb.Image(wall)},
                        step=self.num_steps)

        return loss