def main(args):
    """Entry point: train a CycleGAN-style model and persist the results.

    Args:
        args: parsed CLI namespace providing Monet_Path, Pictures_Path,
            Save_Path, batch_size, epochs and device attributes.
    """
    device = args.device
    save_path = args.Save_Path

    # Build the data loader over the two image domains.
    dataset = get_data_loader(args.Monet_Path, args.Pictures_Path, args.batch_size)

    # Two generators (A->B and B->A) plus one discriminator per domain,
    # all moved to the requested device and normally initialised.
    generator_AB = Generator().to(device)
    generator_BA = Generator().to(device)
    discriminator_A = Discriminator().to(device)
    discriminator_B = Discriminator().to(device)
    for network in (generator_AB, generator_BA, discriminator_A, discriminator_B):
        network.apply(weights_init_normal)

    # One Adam optimizer shared by both generators, one shared by both
    # discriminators.
    generator_params = itertools.chain(generator_AB.parameters(),
                                       generator_BA.parameters())
    discriminator_params = itertools.chain(discriminator_A.parameters(),
                                           discriminator_B.parameters())
    G_optimizer = torch.optim.Adam(generator_params, lr=2e-4)
    D_optimizer = torch.optim.Adam(discriminator_params, lr=2e-4)

    trainer = Trainer(
        generator_ab=generator_AB,
        generator_ba=generator_BA,
        discriminator_a=discriminator_A,
        discriminator_b=discriminator_B,
        generator_optimizer=G_optimizer,
        discriminator_optimizer=D_optimizer,
        n_epochs=args.epochs,
        dataloader=dataset,
        device=device,
    )

    # Run the training loop.
    trainer.train()

    # Persist the training-loss log, then the weights of all four networks.
    trainer.log.save(os.path.join(save_path, 'save_loss.txt'))
    for filename, network in (
        ('generator_AB.pt', generator_AB),
        ('generator_BA.pt', generator_BA),
        ('discriminator_A.pt', discriminator_A),
        ('discriminator_B.pt', discriminator_B),
    ):
        torch.save(network.state_dict(), os.path.join(save_path, filename))
class GANSolver(object):
    """Train and sample a GAN with a least-squares adversarial loss.

    Wires a Generator/Discriminator pair to Adam optimizers, runs the
    alternating D/G training loop, periodically writes sample image grids
    and per-epoch checkpoints, and can reload the final checkpoint to
    sample images.
    """

    def __init__(self, config, data_loader):
        """Copy hyper-parameters from ``config`` and build the networks.

        Args:
            config: namespace with g_conv_dim, d_conv_dim, z_dim, beta1,
                beta2, image_size, num_epochs, batch_size, sample_size, lr,
                log_step, sample_step, sample_path and model_path.
            data_loader: iterable yielding batches of real images
                (presumably normalised to (-1, 1) — see de_normalize).
        """
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step  # NOTE(review): stored but never read here
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.build_model()

    def build_model(self):
        """Build generator and discriminator plus their Adam optimizers."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim).to(self.device)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim).to(self.device)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        if self.cuda:
            # Let cuDNN auto-tune conv algorithms for fixed-size inputs.
            cudnn.benchmark = True

    def to_data(self, x):
        """Move ``x`` to the CPU (when on GPU) and return its raw tensor."""
        if self.cuda:
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers of both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    @staticmethod
    def de_normalize(x):
        """Convert from tanh range (-1, 1) to image range (0, 1)."""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    @staticmethod
    def least_square_loss(output, target):
        """Least-squares GAN loss: mean squared error against ``target``."""
        return torch.mean((output - target) ** 2)

    def fixed_noise(self):
        """Return a fresh (batch_size, z_dim) latent noise batch.

        BUGFIX: the original called ``self.torch.randn`` — GANSolver has no
        ``torch`` attribute, so this raised AttributeError; it must use the
        module-level ``torch``.
        """
        return torch.randn(self.batch_size, self.z_dim, device=self.device)

    def save_model(self, epoch):
        """Checkpoint both networks for ``epoch`` (1-based in the filename)."""
        g_path = os.path.join(self.model_path,
                              'GAN-generator-%d.pkl' % (epoch + 1))
        d_path = os.path.join(self.model_path,
                              'GAN-discriminator-%d.pkl' % (epoch + 1))
        torch.save(self.generator.state_dict(), g_path)
        torch.save(self.discriminator.state_dict(), d_path)

    def save_fakes(self, step, epoch):
        """Every ``sample_step`` steps, save a grid of generated images."""
        if (step + 1) % self.sample_step == 0:
            fake_images = self.generator(self.fixed_noise())
            torchvision.utils.save_image(
                self.de_normalize(fake_images.data),
                os.path.join(
                    self.sample_path,
                    'GAN-fake_samples-%d-%d.png' % (epoch + 1, step + 1)))

    def train(self):
        """Run the alternating D/G training loop over ``num_epochs`` epochs."""
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            print("===> Epoch [%d/%d]" % (epoch + 1, self.num_epochs))
            for i, images in enumerate(self.data_loader):
                # ===================== Train D ===================== #
                images = images.to(self.device)
                batch_size = images.size(0)
                noise = torch.randn(batch_size, self.z_dim, device=self.device)

                # D should score real images as 1.  L2 loss is used instead
                # of binary cross-entropy (optional, for stable training).
                outputs = self.discriminator(images)
                real_loss = self.least_square_loss(outputs, 1)

                # D should score generated images as 0.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                fake_loss = self.least_square_loss(outputs, 0)

                # Backpropagation + optimize.
                self.reset_grad()
                d_loss = real_loss + fake_loss
                d_loss.backward()
                self.d_optimizer.step()

                # ===================== Train G ===================== #
                noise = torch.randn(batch_size, self.z_dim, device=self.device)

                # G wants D to score its samples as real (target 1).
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                g_loss = self.least_square_loss(outputs, 1)

                # Backpropagation + optimize.
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Log progress via the progress bar.
                progress_bar(
                    i, total_step,
                    'd_real_loss: %.4f | d_fake_loss: %.4f | g_loss: %.4f' %
                    (real_loss.item(), fake_loss.item(), g_loss.item()))

                # Save sampled images (no-op except every sample_step steps).
                self.save_fakes(step=i, epoch=epoch)

            # Save the model parameters once per epoch.
            self.save_model(epoch=epoch)

    def sample(self):
        """Load the final checkpoint and save ``sample_size`` fake images.

        BUGFIX: the checkpoint filenames now match what ``save_model``
        actually writes ('GAN-generator-…' / 'GAN-discriminator-…'); the
        old names ('generator-…') were never written, so loading always
        failed with FileNotFoundError.
        """
        g_path = os.path.join(self.model_path,
                              'GAN-generator-%d.pkl' % self.num_epochs)
        d_path = os.path.join(self.model_path,
                              'GAN-discriminator-%d.pkl' % self.num_epochs)
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample the images from fixed random noise.
        noise = torch.randn(self.sample_size, self.z_dim, device=self.device)
        with torch.no_grad():
            fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.de_normalize(fake_images.data),
                                     sample_path, nrow=12)
        print("Saved sampled images to '%s'" % sample_path)
############################ # (2) Update G network: maximize log(D(G(z))) ########################### netG.zero_grad() label.fill_(real_label) # fake labels are real for generator cost output = netD(fake) errG = criterion(output, label) errG.backward() D_G_z2 = output.mean().item() optimizerG.step() print( '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) # save the output if i % 100 == 0: print('saving the output') vutils.save_image(real_cpu, 'output/real_samples.png', normalize=True) fake = netG(fixed_noise) vutils.save_image(fake.detach(), 'output/fake_samples_epoch_%03d.png' % (epoch), normalize=True) # Check pointing for every epoch torch.save(netG.state_dict(), 'weights/netG_epoch_%d.pth' % (epoch)) torch.save(netD.state_dict(), 'weights/netD_epoch_%d.pth' % (epoch))