示例#1
0
 def train(self, n=None, **kwargs):
     """Train for one pass over ``self.train_loader``, then save decoder samples.

     For every mini-batch the ``'BCE'`` variant of ``vae_loss`` is
     back-propagated and one optimizer step is taken. Per-batch metrics go
     into a fresh ``Logger``, which is dumped on the first batch and on every
     ``self.config['log.freq']``-th batch. After the pass, 64 latent vectors
     drawn from a standard normal are decoded, viewed as 64 single-channel
     28x28 images and handed to ``save_image`` with the path
     ``<logdir>/sample_<n>.png``.

     Args:
         n: Value logged under ``'epoch'`` and embedded in the sample file name.
         **kwargs: Must provide ``'logdir'``, the output directory.

     Returns:
         The ``Logger`` holding the metrics recorded during this pass.
     """
     self.model.train()
     device = self.model.device

     logger = Logger()
     for batch_idx, (batch, _) in enumerate(self.train_loader):
         tick = perf_counter()

         # Forward pass and loss.
         batch = batch.to(device)
         re_x, mu, logvar = self.model(batch)
         losses = vae_loss(re_x, batch, mu, logvar, 'BCE')

         # Backward pass and parameter update.
         self.optimizer.zero_grad()
         losses['loss'].backward()
         self.optimizer.step()
         self.model.total_iter += 1

         # Record the metrics in a fixed key order.
         entries = (
             ('epoch', n),
             ('iteration', self.model.total_iter),
             ('mini-batch', batch_idx),
             ('train_loss', losses['loss'].item()),
             ('reconstruction_loss', losses['re_loss'].item()),
             ('KL_loss', losses['KL_loss'].item()),
         )
         for key, value in entries:
             logger(key, value)
         logger('num_seconds', round(perf_counter() - tick, 1))

         is_first_batch = batch_idx == 0
         if is_first_batch or (batch_idx + 1) % self.config['log.freq'] == 0:
             logger.dump(keys=None, index=-1, indent=0, border='-'*50)

     # The extra list nesting only adds a leading axis; the mean is still
     # taken over every logged value.
     mean_loss = np.mean([logger.logs['train_loss']])
     print(f'====> Average loss: {mean_loss}')

     # Decode standard Gaussian noise into images; no gradients are needed.
     with torch.no_grad():
         noise = torch.randn(64, self.config['nn.z_dim']).to(device)
         samples = self.model.decode(noise).cpu()
         save_image(samples.view(64, 1, 28, 28), f'{kwargs["logdir"]}/sample_{n}.png')

     return logger
示例#2
0
    def eval(self, n=None, **kwargs):
        """Evaluate on ``self.test_loader`` and save a reconstruction comparison.

        The ``'BCE'`` variant of ``vae_loss`` is computed for every test batch
        without gradients and recorded under ``'eval_loss'``; the mean over
        batches is printed. Then up to 8 images from the first batch the
        loader yields are reconstructed, stacked after the originals, and
        passed to ``save_image`` with the path
        ``<logdir>/reconstruction_<n>.png``.

        Args:
            n: Identifier embedded in the output file name (the same role it
                has for the sample image written by ``train``).
            **kwargs: Must provide ``'logdir'``, the output directory.

        Returns:
            The ``Logger`` holding the per-batch ``'eval_loss'`` values.
        """
        self.model.eval()

        logger = Logger()
        for data, _ in self.test_loader:
            data = data.to(self.model.device)
            with torch.no_grad():
                re_x, mu, logvar = self.model(data)
                out = vae_loss(re_x, data, mu, logvar, 'BCE')
                logger('eval_loss', out['loss'].item())
        mean_loss = np.mean(logger.logs['eval_loss'])
        print(f'====> Test set loss: {mean_loss}')

        # Reconstruct some test images from the first batch the loader yields
        # (random only if the loader itself shuffles).
        data, _ = next(iter(self.test_loader))
        data = data.to(self.model.device)
        # Keep the image count in its own name: reusing ``n`` here would
        # clobber the caller's identifier and make every call write to the
        # same file.
        num_images = min(data.size(0), 8)
        originals = data[:num_images]
        with torch.no_grad():
            re_x, _, _ = self.model(originals)
        compare_img = torch.cat([originals.cpu(), re_x.cpu().view(-1, 1, 28, 28)])
        save_image(compare_img,
                   f'{kwargs["logdir"]}/reconstruction_{n}.png',
                   nrow=num_images)

        return logger
示例#3
0
def test_model(model, test_loader, device):
    """Return the mean VAE loss of ``model`` over every sample in ``test_loader``.

    The model is switched to eval mode and all forward passes run without
    gradient tracking. Each batch loss is weighted by its batch size so that a
    smaller final batch does not skew the average.

    NOTE(review): the weighting assumes ``vae_loss`` returns a per-sample mean
    as its first element, which is how ``train_model`` treats it -- confirm.

    Args:
        model: Callable returning ``(output, mean, logvar)`` for a batch.
        test_loader: Iterable of ``(x, target)`` pairs; the target is ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        The sample-weighted average loss as a Python float.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    import torch  # local import keeps the block self-contained

    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for x, _ in test_loader:
            x = x.to(device)
            output, mean, logvar = model(x)
            loss, _, _ = vae_loss(x, output, mean, logvar)

            batch_size = x.shape[0]
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute a loss")
    return total_loss / total_samples
示例#4
0
def reconstruction_error(dataloader, model, beta):
    """Print and return the mean of ``vae_loss`` over the batches of ``dataloader``.

    Each batch is moved to the module-level ``device``, passed through
    ``model`` without gradient tracking, and scored with
    ``vae_loss(..., beta)``. The result is the unweighted average of the
    per-batch loss values.

    Args:
        dataloader: Iterable of ``(image_batch, target)`` pairs; the target
            is ignored.
        model: Callable returning ``(reconstruction, mu, logvar)``.
        beta: Forwarded unchanged as the last argument of ``vae_loss``.

    Returns:
        The average per-batch loss.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    running_total = 0
    num_batches = 0
    for num_batches, (images, _) in enumerate(dataloader, start=1):
        with torch.no_grad():
            images = images.to(device)
            recon, mu, logvar = model(images)
            batch_loss = vae_loss(recon, images, mu, logvar, beta)
            running_total += batch_loss.item()

    loss = running_total / num_batches
    print('average reconstruction error: %f' % (loss))
    return loss
示例#5
0
def train_model(
    model,
    train_loader,
    val_loader,
    epochs,
    optimizer,
    device,
    image_dims,
    save_freq=5,
    data_cmap="Greys_r",
):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Every epoch runs one optimizer step per training batch while a tqdm bar
    shows the running sample-weighted training loss, then prints the value
    returned by ``test_model`` on ``val_loader``. On epochs whose index is a
    multiple of ``save_freq`` a checkpoint is saved and reconstruction and
    prior-sample plots are produced.

    Args:
        model: Callable returning ``(output, mean, logvar)`` for a batch.
        train_loader: Iterable of ``(x, target)`` training pairs.
        val_loader: Iterable of validation batches.
        epochs: Number of epochs to run.
        optimizer: Optimizer stepping the model parameters.
        device: Device the inputs are moved to.
        image_dims: Forwarded to the plotting helpers.
        save_freq: Checkpoint/plot interval in epochs.
        data_cmap: Colormap name forwarded to the plotting helpers.
    """
    print("Training model...")

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        progress = tqdm(train_loader, desc="Epoch " + str(epoch))
        loss_sum = 0
        images_seen = 0
        for inputs, _ in progress:
            inputs = inputs.to(device)

            optimizer.zero_grad()
            output, mean, logvar = model(inputs)
            loss, _, _ = vae_loss(inputs, output, mean, logvar)
            loss.backward()
            optimizer.step()

            # Weight each batch loss by its size so the running value is a
            # per-image average.
            n_images = inputs.shape[0]
            images_seen += n_images
            loss_sum += loss.item() * n_images

            progress.set_postfix(
                {"training_loss": loss_sum / images_seen})

        # ---- validation pass ----
        model.eval()
        val_loss = test_model(model, val_loader, device)
        print("\tValidation loss: " + str(val_loss))

        # ---- periodic checkpoint and plots ----
        if epoch % save_freq != 0:
            continue
        save_checkpoint(model, epoch)
        plot_reconstructions(model, next(iter(val_loader)), device, epoch,
                             image_dims, data_cmap)
        plot_samples_from_prior(model, device, epoch, image_dims, data_cmap)
示例#6
0
# Adam over all VAE parameters, with weight decay.
optim = torch.optim.Adam(params=vae.parameters(), lr=lr, weight_decay=1e-5)

# Per-run summary writer directory and model file path.
writer = SummaryWriter(f'runs/{run_id}')
PATH = f"models/{run_id}.pt"

for epoch in range(EPOCHS):
    # Per-epoch accumulators for the average batch loss.
    train_loss = 0
    num_batches = 0

    for image_batch, _ in train_dataloader:
        image_batch = image_batch.to(device)
        if not conv:
            # When ``conv`` is falsy, flatten every image to a 784-vector.
            image_batch = image_batch.reshape(-1, 784)

        # Clear old gradients, then forward, loss, backward, update.
        optim.zero_grad()
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar, beta)
        loss.backward()
        optim.step()

        num_batches += 1
        train_loss += loss.item()

    # Average over batches, log it, and save the whole model object,
    # replacing the file written by the previous epoch.
    train_loss /= num_batches
    writer.add_scalar('train/loss', train_loss, epoch + 1)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, EPOCHS, train_loss))
    torch.save(vae, PATH)