def main():
    """Train an adversarial autoencoder (AAE) on CelebA and log generative metrics to wandb.

    Steps:
      1. Load cached Inception statistics for the real training images, or compute
         and cache them (see ``_prepare_real_image_statistics``).
      2. Train encoder / decoder / discriminator for ``epochs`` epochs.
      3. Every ``save_image_interval`` epochs, log the image returned by
         ``save_images`` to wandb.
      4. Every ``loss_calculation_interval`` epochs, score the images produced by
         ``sample_image`` on each training batch against the real statistics.
         The metrics are precision, recall, FID, real/fake IS, density and coverage.
         They are logged to wandb and recorded together with that epoch's
         last-batch losses.
    """
    import itertools

    import torch
    import tqdm

    # Load real images info or generate real images info.
    inception_model_score = generative_model_score.GenerativeModelScore()
    inception_model_score.lazy_mode(True)

    batch_size = 64
    epochs = 1000
    img_size = 32
    save_image_interval = 5
    loss_calculation_interval = 10
    latent_dim = 10
    n_iter = 3
    # Batch size for running the Inception model over queued images.
    # This is independent of the training batch size.
    inception_batch_size = 64

    wandb.login()
    wandb.init(project="AAE", config={
        "batch_size": batch_size,
        "epochs": epochs,
        "img_size": img_size,
        "save_image_interval": save_image_interval,
        "loss_calculation_interval": loss_calculation_interval,
        "latent_dim": latent_dim,
        "n_iter": n_iter,
    })

    train_loader, _validation_loader, _test_loader = get_celebA_dataset(batch_size, img_size)
    # To train on CIFAR instead: train_loader = get_cifar1_dataset(batch_size)
    image_shape = [3, img_size, img_size]

    _prepare_real_image_statistics(inception_model_score, train_loader, inception_batch_size)

    encoder = Encoder(latent_dim, image_shape).cuda()
    decoder = Decoder(latent_dim, image_shape).cuda()
    discriminator = Discriminator(latent_dim).cuda()
    # ae_optimizer covers encoder + decoder.
    # g_optimizer covers only the encoder, which acts as the generator of latent codes.
    ae_optimizer = torch.optim.Adam(
        itertools.chain(encoder.parameters(), decoder.parameters()), lr=1e-4)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    g_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)

    r_losses, d_losses, g_losses = [], [], []
    precisions, recalls, fids = [], [], []
    inception_scores_real, inception_scores_fake = [], []

    for i in range(epochs):
        evaluate_this_epoch = i % loss_calculation_interval == 0

        for each_batch in tqdm.tqdm(train_loader):
            X_train_batch = each_batch[0].cuda()
            r_loss = update_autoencoder(ae_optimizer, X_train_batch, encoder, decoder)
            # Several discriminator updates per generator (encoder) update.
            for _ in range(n_iter):
                d_loss = update_discriminator(d_optimizer, X_train_batch, encoder,
                                              discriminator, latent_dim)
            g_loss = update_generator(g_optimizer, X_train_batch, encoder, discriminator)
            if evaluate_this_epoch:
                # Only sample when the images will actually be scored.
                # On other epochs the output would be discarded.
                sampled_images = sample_image(encoder, decoder, X_train_batch).detach().cpu()
                inception_model_score.put_fake(sampled_images)

        if i % save_image_interval == 0:
            image = save_images(n_row=10, epoch=i, latent_dim=latent_dim, model=decoder)
            wandb.log({'image': wandb.Image(image, caption='%s_epochs' % i)}, step=i)

        if evaluate_this_epoch:
            metrics = _evaluate_generative_metrics(
                inception_model_score, (encoder, decoder, discriminator), inception_batch_size)
            precision = metrics['precision']
            recall = metrics['recall']
            fid = metrics['fid']
            inception_score_real = metrics['real_is']
            inception_score_fake = metrics['fake_is']
            density = metrics['density']
            coverage = metrics['coverage']

            wandb.log({
                "precision": precision,
                "recall": recall,
                "fid": fid,
                "inception_score_real": inception_score_real,
                "inception_score_fake": inception_score_fake,
                "density": density,
                "coverage": coverage,
            }, step=i)

            # The recorded losses are those of the last batch of this epoch.
            r_losses.append(r_loss)
            d_losses.append(d_loss)
            g_losses.append(g_loss)
            precisions.append(precision)
            recalls.append(recall)
            fids.append(fid)
            inception_scores_real.append(inception_score_real)
            inception_scores_fake.append(inception_score_fake)
            save_scores_and_print(i + 1, epochs, r_loss, d_loss, g_loss, precision, recall,
                                  fid, inception_score_real, inception_score_fake)
            # Drop this epoch's fake images so the next evaluation starts fresh.
            inception_model_score.clear_fake()

    save_losses(epochs, loss_calculation_interval, r_losses, d_losses, g_losses)
    wandb.finish()


def _prepare_real_image_statistics(inception_model_score, train_loader, inception_batch_size,
                                   info_dir='./inception_model_info/'):
    """Load cached real-image Inception statistics, or compute and cache them.

    The cache file name is the MD5 hex digest of ``str(train_loader.dataset)``
    plus ``.pickle``. A dataset whose string representation differs therefore
    gets its own cache entry.

    Args:
        inception_model_score: the ``GenerativeModelScore`` instance to populate.
        train_loader: iterable of batches whose element 0 is the image batch.
        inception_batch_size: batch size passed to ``lazy_forward``.
        info_dir: directory holding the cached statistics; created if missing.
    """
    import hashlib
    import os

    file_name = hashlib.md5(str(train_loader.dataset).encode()).hexdigest() + '.pickle'
    info_path = os.path.join(info_dir, file_name)

    if os.path.exists(info_path):
        print("Using generated real image info.")
        print(train_loader.dataset)
        inception_model_score.load_real_images_info(info_path)
        return

    inception_model_score.model_to('cuda')
    # Hand every real training batch to the scorer, then run lazy_forward over them.
    for each_batch in train_loader:
        inception_model_score.put_real(each_batch[0])
    inception_model_score.lazy_forward(batch_size=inception_batch_size, device='cuda',
                                       real_forward=True)
    inception_model_score.calculate_real_image_statistics()
    # Save for subsequent experiments; the directory may not exist on a first run.
    os.makedirs(info_dir, exist_ok=True)
    inception_model_score.save_real_images_info(info_path)
    # Move the Inception model off the GPU before training starts.
    inception_model_score.model_to('cpu')


def _evaluate_generative_metrics(inception_model_score, gan_models, inception_batch_size):
    """Score the queued fake images and return the metrics dict.

    The GAN modules are moved to the CPU while the Inception model uses the GPU,
    then moved back afterwards. ``nn.Module.to`` moves a module in place, so
    optimizers holding references to their parameters remain valid.

    Args:
        inception_model_score: scorer that already holds real statistics and
            queued fake images.
        gan_models: iterable of ``nn.Module`` objects to offload during scoring.
        inception_batch_size: batch size passed to ``lazy_forward``.

    Returns:
        The dict returned by ``calculate_generative_score``.
    """
    for module in gan_models:
        module.to('cpu')
    inception_model_score.model_to('cuda')

    inception_model_score.lazy_forward(batch_size=inception_batch_size, device='cuda',
                                       fake_forward=True)
    inception_model_score.calculate_fake_image_statistics()
    metrics = inception_model_score.calculate_generative_score()

    inception_model_score.model_to('cpu')
    for module in gan_models:
        module.to('cuda')
    return metrics
import argparse import distutils import numpy as np import pandas as pd import torch import generative_model_score inception_model_score = generative_model_score.GenerativeModelScore() import matplotlib import matplotlib.pyplot as plt import wandb from torch.autograd import Variable import tqdm import os import hashlib from PIL import Image import data_helper import prior_factory import model from torchvision.utils import save_image import seaborn as sns import time matplotlib.use('Agg') start_time = None def load_inception_model(train_loader, dataset, image_size, environment): global device if environment == 'nuri': icm_path = './inception_model_info/' else: icm_path = '../../../inception_model_info/'