import argparse
import os

import torch
import torchvision.utils as vutils

import dcgan


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_images', type=int, default=25)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--model', required=True)
    opts = parser.parse_args()

    # GPU
    device = 'cpu'
    if opts.cuda:
        if torch.cuda.is_available():
            device = 'cuda:0'
        else:
            print('CUDA not available!')

    # Load pretrained model
    netG = dcgan.Generator()
    netG.load_state_dict(torch.load(opts.model, map_location=device))
    netG = netG.to(device)
    netG.eval()

    # Start generation (no_grad: inference only, no autograd graph needed)
    with torch.no_grad():
        noise = torch.randn(opts.num_images, NOISE_LENGTH, device=device)
        fake = netG(noise)
    os.makedirs("samples", exist_ok=True)
    vutils.save_image(fake,
                      "samples/generate.png",
                      nrow=5,
                      normalize=True,
                      range=(-1, 1))  # newer torchvision renames `range` to `value_range`
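# NOISE_LENGTH is a module-level constant not included in this excerpt.
# Given that test_output_shape below feeds the generator (2, 100) noise,
# it is presumably:
NOISE_LENGTH = 100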
import torch

import dcgan


def test_output_shape():
    netG = dcgan.Generator()
    noise = torch.randn(2, 100)
    fake = netG(noise)
    assert fake.shape == torch.Size([2, 3, 64, 64])

    netD = dcgan.Critic()
    fake = torch.randn(2, 3, 64, 64)
    output = netD(fake)
    assert output.shape == torch.Size((2, ))
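# For context, a standard DCGAN generator satisfying the interface these
# assertions pin down could look like the sketch below: flat (N, 100) noise
# in, (N, 3, 64, 64) images out. This is an illustrative assumption, not the
# repository's actual dcgan.py.
import torch
import torch.nn as nn


class SketchGenerator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # (ngf*8, 4, 4) -> (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4, 8, 8) -> (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2, 16, 16) -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # (ngf, 32, 32) -> (nc, 64, 64); Tanh maps pixels to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        # Accept flat (N, nz) noise and reshape to (N, nz, 1, 1)
        return self.main(z.view(z.size(0), -1, 1, 1))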
def main():
    set_random_seed()
    device, gpu_ids = util.get_available_devices()

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def evaluate(source_images_path, keep_images=True):
        dataset = datasets.ImageFolder(root=source_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_gen_images_path = os.path.join(opt.output_path, opt.version,
                                              opt.eval_mode)
        os.makedirs(output_gen_images_path, exist_ok=True)

        # Resized copies of the source images are cached next to the originals
        output_source_images_path = source_images_path + "_" + str(opt.img_size)
        source_images_available = True
        if not os.path.exists(output_source_images_path):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(dataloader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1)).to(device)
            with torch.no_grad():
                gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_gen_images_path, i),
                                  normalize=True)
                if not source_images_available:
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_gen_images_path, output_source_images_path)

        if not keep_images:
            print("Deleting images generated for validation...")
            rmtree(output_gen_images_path)

        return fid

    test_images_path = os.path.join(opt.data_path, "test")
    val_images_path = os.path.join(opt.data_path, "val")
    model_path = os.path.join(opt.output_path, opt.version, opt.model_file)

    netG = dcgan.Generator(nc, nz, ngf).to(device)
    if opt.model_file.endswith(".pt"):
        netG.load_state_dict(torch.load(model_path, map_location=device))
    elif opt.model_file.endswith(".tar"):
        checkpoint = torch.load(model_path, map_location=device)
        netG.load_state_dict(checkpoint['g_state_dict'])
    netG.eval()

    if opt.eval_mode == "val":
        source_images_path = val_images_path
    else:
        source_images_path = test_images_path

    if opt.eval_mode == "val" or opt.eval_mode == "test":
        print("Evaluating model...")
        fid = evaluate(source_images_path)
        print("FID: {}".format(fid))
def train(dataloader,
          gradient_penalty,
          n_epochs,
          lr,
          n_critic=5,
          c=0.01,
          penalty_factor=10):
    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # WGAN-GP uses Adam with betas (0, 0.9); vanilla WGAN uses RMSprop
    netG = dcgan.Generator().to(device)
    netG.apply(weights_init)
    if gradient_penalty:
        optG = optim.Adam(netG.parameters(), lr, betas=(0, 0.9))
    else:
        optG = optim.RMSprop(netG.parameters(), lr)

    netD = dcgan.Critic(gradient_penalty).to(device)
    netD.apply(weights_init)
    if gradient_penalty:
        optD = optim.Adam(netD.parameters(), lr, betas=(0, 0.9))
    else:
        optD = optim.RMSprop(netD.parameters(), lr)

    noise = torch.randn(dataloader.batch_size, NOISE_LENGTH, device=device)

    batches_done = 0
    for epoch in range(n_epochs):
        for i, (data, _) in enumerate(dataloader):
            # Train critic
            data = data.to(device)
            real_batch_size = len(data)
            optD.zero_grad()
            noise.normal_()
            fake = netG(noise[:real_batch_size])
            # Gradients from G are not used, so detach to avoid computing them.
            # Maximize the critic objective by minimizing its negation
            lossD = -(netD(data).mean() - netD(fake.detach()).mean())
            if gradient_penalty:
                lossD += penalty_factor * compute_gradient_penalty(
                    data, fake, netD)
            lossD.backward()
            optD.step()

            # Clamp the weights (weight clipping enforces the Lipschitz
            # constraint when no gradient penalty is used)
            if not gradient_penalty:
                for p in netD.parameters():
                    p.data.clamp_(-c, c)

            # Train generator once every n_critic critic updates
            if i % n_critic == 0:
                optG.zero_grad()
                noise.normal_()
                # Minimize the negated critic score (the EM distance estimate)
                lossG = -netD(netG(noise[:real_batch_size])).mean()
                lossG.backward()
                optG.step()

                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                      (epoch, n_epochs, batches_done % len(dataloader),
                       len(dataloader), lossD.item(), lossG.item()))

                if batches_done % 100 == 0:
                    fake = netG(NOISE_BASELINE)
                    vutils.save_image(fake,
                                      "samples/%d.png" % batches_done,
                                      nrow=5,
                                      normalize=True,
                                      range=(-1, 1))
                batches_done += 1

        # Save the model
        torch.save(netG.state_dict(), 'pretrain/netG_epoch_%d.pth' % epoch)
        torch.save(netD.state_dict(), 'pretrain/netD_epoch_%d.pth' % epoch)
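# NOISE_BASELINE (a fixed noise batch used for progress snapshots) and
# compute_gradient_penalty are module-level definitions not shown in this
# excerpt. The penalty presumably follows the standard WGAN-GP form
# (Gulrajani et al., 2017); a minimal sketch, not necessarily this repo's
# exact implementation:
def compute_gradient_penalty(real, fake, netD):
    """Penalty pushing the critic's gradient norm toward 1 on points
    interpolated between real and fake samples."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real +
                    (1 - alpha) * fake.detach()).requires_grad_(True)
    d_interpolates = netD(interpolates)
    # create_graph=True so the penalty itself can be backpropagated through
    gradients = torch.autograd.grad(outputs=d_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones_like(d_interpolates),
                                    create_graph=True,
                                    retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()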
        transforms.ToTensor()
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=2)

###### model #######
ngf = opt.ngf
ndf = opt.ndf
nz = opt.nz
nc = 3

netG = dcgan.Generator(nc, ngf, nz, opt.fineSize)
netD = dcgan.Discriminator(nc, ndf, opt.hidden_size, opt.fineSize)
if opt.cuda:
    netG.cuda()
    netD.cuda()

#### setup optimizer #####
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr,
                              betas=(opt.beta1, 0.999))
criterion = nn.L1Loss()

####### Variables ##########
noise = torch.FloatTensor(opt.batchSize, opt.nz, 1, 1)
real = torch.FloatTensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
label = torch.FloatTensor(1)
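# Note: the torch.FloatTensor(...) constructors above are legacy (pre-0.4)
# idioms that allocate uninitialized CPU tensors. On current PyTorch the same
# preallocated buffers would typically be written as follows (a sketch
# assuming the same shapes):
noise = torch.empty(opt.batchSize, opt.nz, 1, 1)
real = torch.empty(opt.batchSize, nc, opt.fineSize, opt.fineSize)
label = torch.empty(1)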
def main():
    set_random_seed()
    device, gpu_ids = util.get_available_devices()

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Arguments
    opt = args.get_setup_args()

    # Number of channels in the training images
    nc = opt.channels
    # Size of z latent vector (i.e. size of generator input)
    nz = opt.latent_dim
    # Size of feature maps in generator
    ngf = 64
    # Size of feature maps in discriminator
    ndf = 64

    num_classes = opt.num_classes

    train_images_path = os.path.join(opt.data_path, "train")
    val_images_path = os.path.join(opt.data_path, "val")
    output_model_path = os.path.join(opt.output_path, opt.version)
    output_train_images_path = os.path.join(opt.output_path, opt.version,
                                            "train")
    output_sample_images_path = os.path.join(opt.output_path, opt.version,
                                             "sample")

    os.makedirs(output_train_images_path, exist_ok=True)
    os.makedirs(output_sample_images_path, exist_ok=True)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Initialize generator and discriminator
    netG = dcgan.Generator(nc, nz, ngf).to(device)
    netG.apply(weights_init)
    netD = dcgan.Discriminator(nc, ndf).to(device)
    netD.apply(weights_init)

    # Create batch of latent vectors to visualize the progress of the generator
    # sample_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1
    fake_label = 0

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr,
                            betas=(opt.b1, opt.b2))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr,
                            betas=(opt.b1, opt.b2))

    train_set = datasets.ImageFolder(root=train_images_path,
                                     transform=transforms.Compose([
                                         transforms.Resize(
                                             (opt.img_size, opt.img_size)),
                                         transforms.ToTensor(),
                                         transforms.Normalize(
                                             (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ]))
    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=opt.num_workers)

    # ----------
    #  Training
    # ----------

    G_losses = []
    D_losses = []
    FIDs = []
    val_epochs = []

    def eval_fid(gen_images_path, eval_images_path):
        print("Calculating FID...")
        fid = fid_score.calculate_fid_given_paths(
            (gen_images_path, eval_images_path), opt.batch_size, device)
        return fid

    def validate(keep_images=True):
        val_set = datasets.ImageFolder(root=val_images_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(
                                               (opt.img_size, opt.img_size)),
                                           transforms.ToTensor()
                                       ]))
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 num_workers=opt.num_workers)

        output_images_path = os.path.join(opt.output_path, opt.version, "val")
        os.makedirs(output_images_path, exist_ok=True)

        # Resized copies of the source images are cached next to the originals
        output_source_images_path = val_images_path + "_" + str(opt.img_size)
        source_images_available = True
        if not os.path.exists(output_source_images_path):
            os.makedirs(output_source_images_path)
            source_images_available = False

        images_done = 0
        for _, data in enumerate(val_loader, 0):
            images, _ = data
            batch_size = images.size(0)
            noise = torch.randn((batch_size, nz, 1, 1)).to(device)
            gen_images = netG(noise)
            for i in range(images_done, images_done + batch_size):
                vutils.save_image(gen_images[i - images_done, :, :, :],
                                  "{}/{}.jpg".format(output_images_path, i),
                                  normalize=True)
                if not source_images_available:
                    vutils.save_image(images[i - images_done, :, :, :],
                                      "{}/{}.jpg".format(
                                          output_source_images_path, i),
                                      normalize=True)
            images_done += batch_size

        fid = eval_fid(output_images_path, output_source_images_path)

        if not keep_images:
            print("Deleting images generated for validation...")
            rmtree(output_images_path)

        return fid

    def sample_images(num_images, batches_done):
        # Sample noise
        z = torch.randn((num_classes * num_images, nz, 1, 1)).to(device)
        sample_imgs = netG(z)
        vutils.save_image(sample_imgs.data,
                          "{}/{}.png".format(output_sample_images_path,
                                             batches_done),
                          nrow=num_images,
                          padding=2,
                          normalize=True)

    def save_loss_plot(path):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses, label="G")
        plt.plot(D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(path)

    def save_fid_plot(FIDs, epochs, path):
        plt.figure(figsize=(10, 5))
        plt.title("FID on Validation Set")
        plt.plot(epochs, FIDs)
        plt.xlabel("epochs")
        plt.ylabel("FID")
        plt.savefig(path)
        plt.close()

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(1, opt.num_epochs + 1):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_imgs = data[0].to(device)
            batch_size = real_imgs.size(0)
            # BCELoss expects float targets, hence dtype=torch.float
            label = torch.full((batch_size, ), real_label,
                               dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_imgs).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of
            # the all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Save losses for plotting
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Output training stats
            if i % opt.print_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f] [D(x): %.4f] [D(G(z)): %.4f / %.4f]"
                    % (epoch, opt.num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            batches_done = epoch * len(dataloader) + i
            # Epochs are 1-based here, so the final epoch is opt.num_epochs
            if (batches_done % opt.sample_interval == 0) or (
                    (epoch == opt.num_epochs) and (i == len(dataloader) - 1)):
                # Put G in eval mode
                netG.eval()
                with torch.no_grad():
                    sample_images(opt.num_sample_images, batches_done)
                    vutils.save_image(fake.data[:25],
                                      "{}/{}.png".format(
                                          output_train_images_path,
                                          batches_done),
                                      nrow=5,
                                      padding=2,
                                      normalize=True)
                # Put G back in train mode
                netG.train()

        # Save model checkpoint
        if epoch != opt.num_epochs and epoch % opt.checkpoint_epochs == 0:
            print("Checkpoint at epoch {}".format(epoch))
            print("Saving generator model...")
            torch.save(
                netG.state_dict(),
                os.path.join(output_model_path,
                             "model_checkpoint_{}.pt".format(epoch)))
            print("Saving G & D loss plot...")
            save_loss_plot(
                os.path.join(opt.output_path, opt.version,
                             "loss_plot_{}.png".format(epoch)))
            print("Validating model...")
            netG.eval()
            with torch.no_grad():
                fid = validate(keep_images=False)
            print("Validation FID: {}".format(fid))
            with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
                      "a") as f:
                f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
            FIDs.append(fid)
            val_epochs.append(epoch)
            print("Saving FID plot...")
            save_fid_plot(
                FIDs, val_epochs,
                os.path.join(opt.output_path, opt.version,
                             "fid_plot_{}.png".format(epoch)))
            netG.train()

    print("Saving final generator model...")
    torch.save(netG.state_dict(), os.path.join(output_model_path, "model.pt"))
    print("Done!")

    print("Saving final G & D loss plot...")
    save_loss_plot(os.path.join(opt.output_path, opt.version, "loss_plot.png"))
    print("Done!")

    print("Validating final model...")
    netG.eval()
    with torch.no_grad():
        fid = validate()
    print("Final Validation FID: {}".format(fid))
    with open(os.path.join(opt.output_path, opt.version, "FIDs.txt"),
              "a") as f:
        f.write("Epoch: {}, FID: {}\n".format(epoch, fid))
    FIDs.append(fid)
    val_epochs.append(epoch)
    print("Saving final FID plot...")
    save_fid_plot(FIDs, val_epochs,
                  os.path.join(opt.output_path, opt.version, "fid_plot.png"))
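# set_random_seed and util.get_available_devices are project helpers not
# included in this excerpt. Plausible minimal implementations (assumptions,
# not the repo's actual code):
import random

import numpy as np
import torch


def set_random_seed(seed=42):
    # Seed Python, NumPy and PyTorch RNGs for reproducible runs
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_available_devices():
    # Return the primary device plus the list of visible GPU ids
    if torch.cuda.is_available():
        gpu_ids = list(range(torch.cuda.device_count()))
        return torch.device("cuda:0"), gpu_ids
    return torch.device("cpu"), []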