Example #1
0
 def __init__(self, args):
     """Initialise the LotteryVAE from the ``args`` configuration object.

     Runs the cooperative ``super()`` chain, then explicitly calls
     ``LotteryAE.__init__`` and ``VAE.__init__`` with the same ``args``,
     copies ``args.pruning`` onto the instance, and sets an ``unprunable``
     flag on the ``mu`` and ``log_var`` attributes.
     """
     super(LotteryVAE, self).__init__(args)
     # NOTE(review): these explicit base-class __init__ calls run in addition
     # to the super() chain above; depending on the MRO a base may be
     # initialised more than once (and if a base is a torch nn.Module, a second
     # nn.Module.__init__ would reset already-registered submodules). Confirm
     # this is intended before reordering or removing any of these calls.
     LotteryAE.__init__(self, args)
     VAE.__init__(self, args)
     self.pruning = args.pruning
     # Flag the latent heads so pruning code skips them; presumably mu/log_var
     # are the mean / log-variance layers created by VAE.__init__ — TODO confirm.
     self.mu.unprunable = True
     self.log_var.unprunable = True
Example #2
0
def train_vae(vae:VAE, dataset, dataloader_workers=8, lr=5e-5, kld_coef=0.1, noised_coef=0.2, epochs=400,
              batches_per_epoch=1000, batch_size=6, log_images_every=10):
    """Train ``vae`` on ``dataset``, logging losses, timings and sample images to wandb.

    An "epoch" here is a fixed number of optimizer steps (``batches_per_epoch``)
    rather than one full pass over ``dataset``: a fresh DataLoader iterator is
    created at the start of every epoch and restarted whenever it is exhausted
    mid-epoch.

    Args:
        vae: model under training. ``vae(x)`` must return ``(z, kld)`` and
            ``vae.decode(z)`` a reconstruction of ``x``. Inputs are moved to
            CUDA, so the model must live there too.
        dataset: indexable dataset of equally-shaped tensors. It must contain
            at least 25955 items because fixed indices select the logged test
            images.
        dataloader_workers: number of DataLoader worker processes.
        lr: Adam learning rate.
        kld_coef: weight of the KL-divergence term in the loss.
        noised_coef: weight of the noised-latent reconstruction loss; currently
            unused because that loss term is disabled.
        epochs: number of epochs to run.
        batches_per_epoch: optimizer steps per epoch.
        batch_size: samples per batch.
        log_images_every: log sample images every this many epochs
            (epoch 0 included).
    """
    wandb.init(project="CST-GAN-2021")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=dataloader_workers, pin_memory=True)
    optim = Adam(vae.parameters(), lr)
    # Fixed samples that are re-encoded/decoded for the periodic image logs.
    test_images = torch.stack([dataset[10094], dataset[1282], dataset[25954], dataset[25513], dataset[7007]]).cuda()

    for i in range(epochs):
        data_iter = iter(dataloader)
        for k in tqdm(range(batches_per_epoch)):
            step = i * batches_per_epoch + k + 1

            start = time.perf_counter()
            # next() raises StopIteration on exhaustion rather than returning
            # None, so request a None sentinel and restart the loader (a new
            # shuffled pass) when the current one runs out.
            x = next(data_iter, None)
            if x is None:
                data_iter = iter(dataloader)
                x = next(data_iter)
            get_batch_time = time.perf_counter() - start

            start = time.perf_counter()
            x = x.cuda()
            load_batch_time = time.perf_counter() - start

            z, kld = vae(x)
            x_rec = vae.decode(z)

            kld_loss = kld.mean()
            # Squared error summed over the last three dims (presumably C, H, W)
            # per sample, then averaged over the batch.
            rec_loss = ((x - x_rec) ** 2).sum((-1, -2, -3)).mean()

            # Normalise by the spatial size so the loss scale does not grow
            # with image resolution.
            scale_factor = x.size(-1) * x.size(-2)
            loss = (rec_loss + kld_coef * kld_loss) / scale_factor

            # Drop gradients entirely (cheaper than zero-filling them).
            for param in vae.parameters():
                param.grad = None
            loss.backward()
            optim.step()

            wandb.log({"kld_loss": kld_loss.item(),
                       "rec_loss": rec_loss.item(),
                       "loss": loss.item(),
                       "get_batch_time": get_batch_time,
                       "load_batch_time": load_batch_time,
                       "step": step}, step=step)

        if i % log_images_every == 0:
            # (i + 1) * batches_per_epoch is the last step of this epoch.
            wandb.log(build_images_to_log(vae, test_images), step=(i + 1) * batches_per_epoch)
Example #3
0
def build_images_to_log(vae:VAE, source):
    """Build a wandb log dict of original, reconstructed, noised and generated images.

    Keys, in insertion order:
        ``orig_imgs``: the ``source`` batch itself.
        ``rec_imgs``: ``vae.decode(vae.encode(source))``.
        ``noised_{j}_imgs`` for j = 1..len(z): latent j is replaced by fresh
            ``randn_like`` noise four times in a row (other latents kept from
            the encoding), decoding the whole batch after every draw.
        ``gen_imgs``: decode of a latent list drawn entirely with ``randn_like``.

    Everything after the ``orig_imgs`` entry runs under ``torch.no_grad()``.
    ``vae.encode`` is expected to return an indexable sequence of latents.
    """
    def as_images(batch):
        # Move to CPU and swap dims 0 and 2 of each sample before wrapping.
        # NOTE(review): for CHW tensors this yields W x H x C (a transposed
        # picture); permute(1, 2, 0) would give H x W x C — confirm the layout.
        return [wandb.Image(sample.transpose(0, 2).numpy()) for sample in batch.cpu()]

    log_dict = {"orig_imgs": as_images(source)}
    with torch.no_grad():
        latents = vae.encode(source)
        log_dict["rec_imgs"] = as_images(vae.decode(latents))

        for i in range(len(latents)):
            perturbed = list(latents)
            noised = []
            for _ in range(4):
                perturbed[i] = torch.randn_like(perturbed[i])
                noised += as_images(vae.decode(perturbed))
            log_dict[f"noised_{i+1}_imgs"] = noised

        random_latents = [torch.randn_like(latent) for latent in latents]
        log_dict["gen_imgs"] = as_images(vae.decode(random_latents))

    return log_dict
Example #4
0
from trainers.vae import train_vae
from models.vae.vae import VAE
from datasets.images_datasets import RawImagesDataset
from utils.pytorch import init_torch

if __name__ == "__main__":
    # Entry point: build a CUDA VAE and train it on a local image folder.
    init_torch()

    # Architecture options forwarded verbatim to the VAE constructor.
    vae_kwargs = {
        "skip_features_start": 2,
        "num_skip_downsamplings": 5,
        "sc_encoder": True,
        "sc_features": 8,
        "sc_encoder_prod_ups": True,
        "dec_product_skip": True,
    }
    vae = VAE(**vae_kwargs).cuda()

    dataset = RawImagesDataset("/home/argentumwalker/Projects/#DATA/deviantart")

    # Training overrides; anything not listed keeps train_vae's default.
    train_kwargs = {
        "dataloader_workers": 6,
        "batch_size": 4,
        "lr": 5e-5,
        "epochs": 500,
        "kld_coef": 0.1,
    }
    train_vae(vae, dataset, **train_kwargs)
Example #5
0
 def __init__(self, encoder, decoder, synth, args, upsampler=None):
     """Initialise both bases: DDSSynth with the given components, then VAE.

     Args:
         encoder, decoder, synth, args, upsampler: forwarded unchanged to
             ``DDSSynth.__init__``; ``VAE.__init__`` receives no arguments.
     """
     DDSSynth.__init__(self, encoder, decoder, synth, args, upsampler)
     # NOTE(review): VAE.__init__ runs *after* DDSSynth.__init__. If VAE (or a
     # base it shares with DDSSynth) is a torch nn.Module, re-running
     # nn.Module.__init__ here would discard submodules registered above —
     # confirm this order and that VAE.__init__ does not reset shared state.
     VAE.__init__(self)
Example #6
0
     raise Exception('Unknown reconstruction loss ' + args.rec_loss)
 # Construct encoder and decoder shared by every auto-encoder variant below.
 encoder, decoder = construct_encoder_decoder(args.input_size,
                                              args.encoder_dims,
                                              args.latent_dims,
                                              channels=args.channels,
                                              n_layers=args.n_layers,
                                              hidden_size=args.n_hidden,
                                              n_mlp=args.n_layers // 2,
                                              type_mod=args.layers,
                                              args=args)
 # Construct specific type of AE selected by args.model.
 if args.model == 'ae':
     model = AE(encoder, decoder, args.encoder_dims, args.latent_dims)
 elif args.model == 'vae':
     model = VAE(encoder, decoder, args.input_size, args.encoder_dims,
                 args.latent_dims)
 elif args.model == 'wae':
     model = WAE(encoder, decoder, args.input_size, args.encoder_dims,
                 args.latent_dims)
 elif args.model == 'vae_flow':
     # Construct the normalizing flow, amortized from the input.
     flow, blocks = construct_flow(args.latent_dims,
                                   flow_type=args.flow,
                                   flow_length=args.flow_length,
                                   amortization='input')
     # Construct full VAE with given flow
     model = VAEFlow(encoder, decoder, flow, args.input_size,
                     args.encoder_dims, args.latent_dims)
 else:
     # Fail fast (like the reconstruction-loss check) instead of leaving
     # `model` unbound for the code that follows.
     raise Exception('Unknown model ' + str(args.model))
 # Construct specific regressor
 regression_model = construct_regressor(args.latent_dims,
                                        args.output_size,