Example #1
0
def main():
    """Train the AAE on CelebA and log samples and generative metrics to wandb.

    Workflow:
      1. Load cached real-image statistics for the inception-based scorer from
         ``./inception_model_info/``. If the cache file is missing, compute
         them from ``train_loader`` and write the cache.
      2. For each batch, train encoder+decoder with ``update_autoencoder``,
         then the discriminator ``n_iter`` times with ``update_discriminator``,
         then the encoder with ``update_generator``.
      3. Every ``save_image_interval`` epochs, log the image returned by
         ``save_images`` to wandb.
      4. Every ``loss_calculation_interval`` epochs, score the images returned
         by ``sample_image`` during that epoch. Log precision, recall, FID,
         real/fake inception score, density and coverage. Also record the
         losses of that epoch's last batch, which are passed to ``save_losses``
         at the end of training.
    """
    import itertools

    import tqdm

    # Scorer used both for the real-image statistics and the fake-image metrics.
    inception_model_score = generative_model_score.GenerativeModelScore()
    inception_model_score.lazy_mode(True)

    batch_size = 64
    epochs = 1000
    img_size = 32
    save_image_interval = 5
    loss_calculation_interval = 10
    latent_dim = 10
    n_iter = 3
    # Batch size for the inception forward passes. It is kept separate from the
    # training batch size so the two can be tuned independently.
    inception_batch_size = 64

    wandb.login()
    wandb.init(project="AAE",
               config={
                   "batch_size": batch_size,
                   "epochs": epochs,
                   "img_size": img_size,
                   "save_image_interval": save_image_interval,
                   "loss_calculation_interval": loss_calculation_interval,
                   "latent_dim": latent_dim,
                   "n_iter": n_iter,
               })

    train_loader, validation_loader, test_loader = get_celebA_dataset(
        batch_size, img_size)
    image_shape = [3, img_size, img_size]

    _load_or_compute_real_images_info(inception_model_score, train_loader,
                                      inception_batch_size)

    encoder = Encoder(latent_dim, image_shape).cuda()
    decoder = Decoder(latent_dim, image_shape).cuda()
    discriminator = Discriminator(latent_dim).cuda()
    ae_optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                                    decoder.parameters()),
                                    lr=1e-4)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    g_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)

    # Losses of the last batch of every scored epoch, consumed by save_losses.
    r_losses = []
    d_losses = []
    g_losses = []

    for i in range(epochs):
        is_scoring_epoch = i % loss_calculation_interval == 0

        for each_batch in tqdm.tqdm(train_loader):
            # torch.autograd.Variable is a deprecated no-op wrapper; a plain
            # tensor works the same.
            X_train_batch = each_batch[0].cuda()
            r_loss = update_autoencoder(ae_optimizer, X_train_batch, encoder,
                                        decoder)

            for _ in range(n_iter):
                d_loss = update_discriminator(d_optimizer, X_train_batch,
                                              encoder, discriminator,
                                              latent_dim)

            g_loss = update_generator(g_optimizer, X_train_batch, encoder,
                                      discriminator)

            # Sampled images are only consumed on scoring epochs. Skipping the
            # call on other epochs avoids a discarded forward pass and a
            # GPU->CPU copy on every batch.
            if is_scoring_epoch:
                sampled_images = sample_image(encoder, decoder,
                                              X_train_batch).detach().cpu()
                inception_model_score.put_fake(sampled_images)

        if i % save_image_interval == 0:
            image = save_images(n_row=10,
                                epoch=i,
                                latent_dim=latent_dim,
                                model=decoder)
            wandb.log({'image': wandb.Image(image, caption='%s_epochs' % i)},
                      step=i)

        if is_scoring_epoch:
            metrics = _score_fake_images(inception_model_score,
                                         (encoder, decoder, discriminator),
                                         inception_batch_size)

            precision = metrics['precision']
            recall = metrics['recall']
            fid = metrics['fid']
            inception_score_real = metrics['real_is']
            inception_score_fake = metrics['fake_is']

            wandb.log(
                {
                    "precision": precision,
                    "recall": recall,
                    "fid": fid,
                    "inception_score_real": inception_score_real,
                    "inception_score_fake": inception_score_fake,
                    "density": metrics['density'],
                    "coverage": metrics['coverage']
                },
                step=i)

            r_losses.append(r_loss)
            d_losses.append(d_loss)
            g_losses.append(g_loss)
            save_scores_and_print(i + 1, epochs, r_loss, d_loss, g_loss,
                                  precision, recall, fid, inception_score_real,
                                  inception_score_fake)

        inception_model_score.clear_fake()
    save_losses(epochs, loss_calculation_interval, r_losses, d_losses,
                g_losses)
    wandb.finish()


def _load_or_compute_real_images_info(inception_model_score, train_loader,
                                      inception_batch_size,
                                      info_dir='./inception_model_info/'):
    """Fill ``inception_model_score`` with real-image statistics, using a disk cache.

    The cache file is ``<md5 of str(train_loader.dataset)>.pickle`` inside
    ``info_dir``.

    If the file exists, it is loaded. Otherwise:
      - every training batch is queued as real images,
      - statistics are computed with the inception model on CUDA,
      - the result is saved, creating ``info_dir`` if needed,
      - the inception model is moved back to the CPU.
    """
    import hashlib
    import os

    file_name = hashlib.md5(
        str(train_loader.dataset).encode()).hexdigest() + '.pickle'
    info_path = os.path.join(info_dir, file_name)

    if os.path.exists(info_path):
        print("Using generated real image info.")
        print(train_loader.dataset)
        inception_model_score.load_real_images_info(info_path)
        return

    inception_model_score.model_to('cuda')
    for each_batch in train_loader:
        inception_model_score.put_real(each_batch[0])
    inception_model_score.lazy_forward(batch_size=inception_batch_size,
                                       device='cuda',
                                       real_forward=True)
    inception_model_score.calculate_real_image_statistics()
    # The cache directory may not exist yet (e.g. a fresh checkout).
    os.makedirs(info_dir, exist_ok=True)
    inception_model_score.save_real_images_info(info_path)
    # Offload the inception model so training has the GPU to itself.
    inception_model_score.model_to('cpu')


def _score_fake_images(inception_model_score, models, inception_batch_size):
    """Score the queued fake images and return the scorer's metrics dict.

    The training ``models`` are moved to the CPU while the inception model runs
    on CUDA. Afterwards the inception model goes back to the CPU and the
    training models back to CUDA. ``Module.to`` moves parameters in place, so
    references held by the optimizers stay valid.
    """
    for model in models:
        model.to('cpu')
    inception_model_score.model_to('cuda')

    inception_model_score.lazy_forward(batch_size=inception_batch_size,
                                       device='cuda',
                                       fake_forward=True)
    inception_model_score.calculate_fake_image_statistics()
    metrics = inception_model_score.calculate_generative_score()

    inception_model_score.model_to('cpu')
    for model in models:
        model.to('cuda')
    return metrics
Example #2
0
import argparse
import distutils
import numpy as np
import pandas as pd
import torch
import generative_model_score
# Module-level GenerativeModelScore instance, constructed when this module is
# imported and shared by everything in the module that refers to it.
inception_model_score = generative_model_score.GenerativeModelScore()
import matplotlib
import matplotlib.pyplot as plt
import wandb
from torch.autograd import Variable
import tqdm
import os
import hashlib
from PIL import Image
import data_helper
import prior_factory
import model
from torchvision.utils import save_image
import seaborn as sns
import time
# Select matplotlib's non-interactive 'Agg' backend so figures can be written
# to files without a display.
# NOTE(review): pyplot is already imported above. Older matplotlib releases
# required use() to run before importing pyplot; confirm that the installed
# version honors this switch.
matplotlib.use('Agg')
# Module-level placeholder. Presumably it is assigned a start timestamp
# elsewhere (`time` is imported above); TODO confirm against the rest of the
# file.
start_time = None


def load_inception_model(train_loader, dataset, image_size, environment):
    global device
    if environment == 'nuri':
        icm_path = './inception_model_info/'
    else:
        icm_path = '../../../inception_model_info/'