def inference():
    """Generate attribute-translated CelebA test images with a trained StarGAN.

    Loads the generator checkpoint saved after ``config.num_epochs`` epochs,
    translates every test batch into each target label returned by
    ``create_labels`` and saves, per batch, one PNG in
    ``config.inference_path`` that concatenates the input images with all of
    their translations along dim 3. The saved PNGs are then assembled into a
    GIF with ``make_gifs_test``.
    """
    # Inference Path #
    make_dirs(config.inference_path)

    # Prepare Data Loader #
    test_loader = get_celeba_loader('test', config.batch_size, config.selected_attrs)

    # Prepare Generator #
    G = Generator(num_classes=len(config.selected_attrs)).to(device)
    weights_file = os.path.join(
        config.weights_path,
        'StarGAN_Generator_Epoch_{}.pkl'.format(config.num_epochs))
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # loaded onto whatever ``device`` this run uses (e.g. a CPU-only machine).
    G.load_state_dict(torch.load(weights_file, map_location=device))

    # Test #
    print("StarGAN | Generating Aligned CelebA Images started...")

    # Pure inference: disable autograd so the generator's activations are not
    # kept alive for a backward pass that never happens.
    with torch.no_grad():
        for i, (image, label) in enumerate(test_loader):

            # Prepare Data #
            image = image.to(device)
            fixed_labels = create_labels(label, selected_attrs=config.selected_attrs)

            # Generate Fake Images #
            x_fake_list = [image]
            for c_fixed in fixed_labels:
                x_fake_list.append(G(image, c_fixed))
            # dim 3 is presumably the width axis (NCHW) -- TODO confirm;
            # the inputs and their translations end up side by side.
            x_concat = torch.cat(x_fake_list, dim=3)

            # Save Images #
            save_image(
                denorm(x_concat.cpu()),
                os.path.join(
                    config.inference_path,
                    'StarGAN_Aligned_CelebA_Results_%04d.png' % (i + 1)),
                nrow=1,
                padding=0)

    make_gifs_test("StarGAN", config.inference_path)
def train():
    """Train the hinge-loss GAN on CelebA and write samples, weights and plots.

    Each iteration first updates the discriminator on
    ``mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))`` (generated images
    detached), then updates the generator on ``-mean(D(G(z)))`` with a fresh
    forward pass. Only the first element of each network's output is used.
    Samples from a fixed noise batch are written every epoch, generator
    weights every ``config.save_every`` epochs, and a GIF plus a loss plot
    are produced when training ends.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    for directory in (config.samples_path, config.weights_path, config.plots_path):
        make_dirs(directory)

    # Prepare Data Loader #
    loader = get_celeba_loader(path=config.celeba_path, batch_size=config.batch_size)
    n_batches = len(loader)

    # Prepare Networks (creation order matters: both consume the seeded RNG) #
    disc = Discriminator().to(device)
    gen = Generator().to(device)

    # Optimizers over trainable parameters only #
    disc_params = (p for p in disc.parameters() if p.requires_grad)
    gen_params = (p for p in gen.parameters() if p.requires_grad)
    disc_opt = torch.optim.Adam(disc_params, lr=config.D_lr, betas=(0.0, 0.9))
    gen_opt = torch.optim.Adam(gen_params, lr=config.G_lr, betas=(0.0, 0.9))
    disc_sched = get_lr_scheduler(disc_opt)
    gen_sched = get_lr_scheduler(gen_opt)

    # Fixed Noise #
    noise_shape = (config.batch_size, config.noise_dim, 1, 1)
    fixed_noise = torch.randn(*noise_shape).to(device)

    # Loss History #
    disc_history, gen_history = [], []

    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for step, (images, _) in enumerate(loader, start=1):
            images = images.to(device)
            z = torch.randn(*noise_shape).to(device)

            disc_opt.zero_grad()
            gen_opt.zero_grad()

            # Discriminator: hinge loss on real and detached generated images #
            real_score = disc(images)[0]
            fake_score = disc(gen(z)[0].detach())[0]
            real_term = torch.relu(1.0 - real_score).mean()
            fake_term = torch.relu(1.0 + fake_score).mean()
            d_loss = real_term + fake_term
            d_loss.backward()
            disc_opt.step()

            # Generator: raise the (updated) discriminator's score of G(z) #
            g_loss = -disc(gen(z)[0])[0].mean()
            g_loss.backward()
            gen_opt.step()

            disc_history.append(d_loss.item())
            gen_history.append(g_loss.item())

            if step % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, step, n_batches,
                            np.mean(disc_history), np.mean(gen_history)))

        # Per-epoch samples and learning-rate schedule #
        sample_images(gen, fixed_noise, epoch)
        disc_sched.step()
        gen_sched.step()

        if (epoch + 1) % config.save_every == 0:
            checkpoint = os.path.join(
                config.weights_path,
                'Face_Generator_Epoch_{}.pkl'.format(epoch + 1))
            torch.save(gen.state_dict(), checkpoint)

    # Summaries #
    make_gifs_train("Face_Generation", config.samples_path)
    plot_losses(disc_history, gen_history, config.num_epochs, config.plots_path)
    print("Training finished.")
def train():
    """Train the energy-based GAN (auto-encoder discriminator) on CelebA.

    ``D`` returns a pair; the second element is compared with ``nn.MSELoss``
    against the image fed in (presumably its reconstruction), and that MSE
    acts as the energy. The discriminator minimizes the real-image energy
    plus ``clamp(config.margin - fake_energy, min=0)``. The generator
    minimizes the fake-image energy plus ``config.lambda_pt`` times
    ``pulling_away`` of the first element of ``D``'s output. Samples are
    written every epoch, generator weights every ``config.save_every``
    epochs, and a GIF plus a loss plot are produced at the end.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    for path in (config.samples_path, config.weights_path, config.plots_path):
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path, batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion = nn.MSELoss()

    # Optimizer #
    D_optim = torch.optim.Adam(D.parameters(), lr=config.D_lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(), lr=config.G_lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim, 1, 1).to(device)

    # Lists #
    D_losses, G_losses = [], []

    # Train #
    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, _) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Adversarial Loss using Real Image #
            _, prob_real = D(images)
            D_real_loss = criterion(prob_real, images)

            # Adversarial Loss using Fake Image #
            noise = torch.randn(config.batch_size, config.noise_dim, 1, 1).to(device)
            fake_images = G(noise)
            # Detach once and use the SAME tensor as D's input AND as the MSE
            # target. A non-detached target lets D_loss.backward() write
            # gradients into G, which G_optim.step() below would then apply
            # (G's grads are only zeroed at the top of the iteration).
            fake_detached = fake_images.detach()
            _, prob_fake = D(fake_detached)
            D_fake_loss = criterion(prob_fake, fake_detached)
            D_fake_loss = torch.clamp(config.margin - D_fake_loss, min=0)

            # Calculate Total Discriminator Loss #
            # For margin > 0 this condition holds whenever the fake energy
            # is strictly positive.
            D_loss = D_real_loss
            if D_fake_loss.item() < config.margin:
                D_loss = D_loss + D_fake_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            fake_images = G(noise)
            encoded, prob_fake = D(fake_images)
            G_loss = criterion(prob_fake, fake_images)

            # Pulling Away Loss #
            G_pulling_away_loss = pulling_away(encoded)

            # Calculate Total Generator Loss #
            G_loss = G_loss + config.lambda_pt * G_pulling_away_loss

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.mean(D_losses), np.mean(G_losses)))

        # Sample Images #
        sample_images(G, fixed_noise, epoch)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
def train():
    """Train the VAE-GAN (encoder, decoder, discriminator) on CelebA.

    ``G(images)`` returns ``(reconstruction, mu, log_var)`` and
    ``G.decoder(noise)`` decodes prior samples. Each iteration runs three
    separate updates:

    * Discriminator: BCE-with-logits for real -> 1, reconstruction -> 0 and
      decoded noise -> 0.
    * Decoder: the negated sum of the same three BCE terms plus
      ``config.gamma`` times the MSE between ``D.feature`` of the
      reconstruction and of the real images.
    * Encoder: ``-0.5 * mean(1 + log_var - mu**2 - exp(log_var))`` plus
      ``config.beta`` times the same feature-space MSE.

    Each optimizer's gradients are cleared right before its own backward
    pass, so no update sees gradients left over from another sub-step.
    Samples are written every epoch, generator weights every
    ``config.save_every`` epochs, and GIFs plus a loss plot at the end.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    for path in (config.samples_gen_path, config.samples_recon_path,
                 config.weights_path, config.plots_path):
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path, batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_l2 = nn.MSELoss()

    # Optimizer #
    Disc_optim = torch.optim.Adam(D.parameters(), lr=config.D_lr, betas=(0.5, 0.999))
    Enc_optim = torch.optim.Adam(G.encoder.parameters(), lr=config.G_lr, betas=(0.5, 0.999))
    Dec_optim = torch.optim.Adam(G.decoder.parameters(), lr=config.G_lr, betas=(0.5, 0.999))

    Disc_optim_scheduler = get_lr_scheduler(Disc_optim)
    Enc_optim_scheduler = get_lr_scheduler(Enc_optim)
    Dec_optim_scheduler = get_lr_scheduler(Dec_optim)

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim).to(device)

    # Lists #
    Disc_losses, Enc_losses, Dec_losses = [], [], []

    # Train #
    print("Training has started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, _) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)

            #######################
            # Train Discriminator #
            #######################

            Disc_optim.zero_grad()

            # Noise and Generate Fake Images #
            noise = torch.randn(images.size(0), config.noise_dim).to(device)
            fake_images, _, _ = G(images)

            # Discriminator GAN Loss #
            # Generator outputs are detached: this step must only produce
            # discriminator gradients, otherwise they accumulate in the
            # encoder/decoder and get applied by Enc_optim/Dec_optim below.
            # Targets are shaped like D's output, so a smaller final batch
            # does not break the BCE size check.
            prob_real = D(images)
            Disc_real_loss = criterion_bce(prob_real, torch.ones_like(prob_real))
            prob_fake = D(fake_images.detach())
            Disc_fake_loss = criterion_bce(prob_fake, torch.zeros_like(prob_fake))
            prob_p_fake = D(G.decoder(noise).detach())
            Disc_fake_p_loss = criterion_bce(prob_p_fake, torch.zeros_like(prob_p_fake))

            # Calculate Total Discriminator Loss #
            Disc_loss = Disc_real_loss + Disc_fake_loss + Disc_fake_p_loss

            # Back Propagation and Update #
            Disc_loss.backward()
            Disc_optim.step()

            #################
            # Train Decoder #
            #################

            Dec_optim.zero_grad()

            # Generate Fake Images #
            fake_images, mu, log_var = G(images)

            # Decoder GAN Loss #
            prob_real = D(images)
            Dec_real_loss = criterion_bce(prob_real, torch.ones_like(prob_real))
            prob_fake = D(fake_images)
            Dec_fake_loss = criterion_bce(prob_fake, torch.zeros_like(prob_fake))
            prob_p_fake = D(G.decoder(noise))
            Dec_fake_p_loss = criterion_bce(prob_p_fake, torch.zeros_like(prob_p_fake))

            # Calculate Total Decoder GAN Loss #
            Dec_gan_loss = -1 * (Dec_fake_loss + Dec_fake_p_loss + Dec_real_loss)

            # Decoder Reconstruction Loss #
            Dec_recon_loss = criterion_l2(D.feature(fake_images), D.feature(images))

            # Calculate Total Decoder Loss #
            Dec_loss = Dec_gan_loss + config.gamma * Dec_recon_loss

            # Back Propagation and Update #
            Dec_loss.backward()
            Dec_optim.step()

            #################
            # Train Encoder #
            #################

            # Dec_loss.backward() also reached the encoder through G(images);
            # clear those gradients so only Enc_loss drives this update.
            Enc_optim.zero_grad()

            # Generate Fake Images #
            fake_images, mu, log_var = G(images)

            # Encoder Prior Loss #
            Enc_prior_loss = -0.5 * torch.mean(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))

            # Encoder Reconstruction Loss #
            Enc_recon_loss = criterion_l2(D.feature(fake_images), D.feature(images))

            # Calculate Total Encoder Loss #
            Enc_loss = Enc_prior_loss + config.beta * Enc_recon_loss

            # Back Propagation and Update #
            Enc_loss.backward()
            Enc_optim.step()

            # Add items to Lists #
            Disc_losses.append(Disc_loss.item())
            Enc_losses.append(Enc_loss.item())
            Dec_losses.append(Dec_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epochs [{}/{}] | Iterations [{}/{}] | Disc Loss {:.4f} | Enc Loss {:.4f} | Dec Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(Disc_losses), np.average(Enc_losses),
                            np.average(Dec_losses)))

        # Sample Images #
        sample_images(G, celeba_loader, fixed_noise, epoch)

        # Adjust Learning Rate #
        Disc_optim_scheduler.step()
        Enc_optim_scheduler.step()
        Dec_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_gen_path)
    make_gifs_train("Face_Reconstruction", config.samples_recon_path)

    # Plot Losses #
    plot_losses(Disc_losses, Enc_losses, Dec_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
def train():
    """Train a BEGAN-style model (auto-encoder discriminator) on CelebA.

    With ``L(v) = L1(D(v), v)``, each iteration:

    * updates the discriminator on ``L(x) - k_t * L(G(z))``;
    * updates the generator on ``L(G(z))``;
    * moves ``k_t`` by ``0.001 * (0.7 * L(x) - L(G(z)))`` and clips it to
      ``[0, 1]``.

    Samples are written every epoch, generator weights every
    ``config.save_every`` epochs, and a GIF plus a loss plot at the end.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    for path in (config.samples_path, config.weights_path, config.plots_path):
        make_dirs(path)

    # Prepare Data Loader #
    celeba_loader = get_celeba_loader(path=config.celeba_path, batch_size=config.batch_size)
    total_batch = len(celeba_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Loss Function #
    criterion = nn.L1Loss().to(device)

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Fixed Noise #
    fixed_noise = torch.randn(config.batch_size, config.noise_dim, 1, 1).to(device)

    # Constants #
    k_t = 0        # weight of the fake term in D's loss, kept in [0, 1]
    lr_k = 0.001   # step size for k_t
    gamma = 0.7    # target ratio used in the k_t balance term

    # Train #
    print("Training started with total epoch of {}.".format(config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (images, _) in enumerate(celeba_loader):

            # Data Preparation #
            images = images.to(device)
            noise = torch.randn(config.batch_size, config.noise_dim, 1, 1).to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            # Adversarial Loss using Real Image #
            prob_real = D(images)
            D_real_loss = criterion(prob_real, images)

            # Adversarial Loss using Generated Image #
            fake_images = G(noise)
            # Use the detached tensor as both D's input and the L1 target. A
            # non-detached target would let D_loss.backward() push -k_t-scaled
            # gradients into G, which G_optim.step() below would then apply.
            fake_detached = fake_images.detach()
            prob_fake = D(fake_detached)
            D_fake_loss = criterion(prob_fake, fake_detached)

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss - k_t * D_fake_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss using Generated Image #
            fake_images = G(noise)
            prob_fake = D(fake_images)

            # Calculate Total Generator Loss #
            G_loss = criterion(prob_fake, fake_images)

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Update Constants #
            balance = (gamma * D_real_loss - G_loss).item()
            k_t += lr_k * balance
            k_t = min(max(k_t, 0), 1)

            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Epoch [{}/{}] | Iter [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

        # Sample Images #
        sample_images(G, fixed_noise, epoch)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'Face_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("Face_Generation", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
def train():
    """Train StarGAN on the CelebA training split.

    Every iteration updates the discriminator with a Wasserstein adversarial
    loss, a domain-classification loss (``criterion_CLS`` weighted by
    ``config.lambda_cls``) and a gradient penalty (weighted by
    ``config.lambda_gp``). Every ``config.n_critics`` iterations the
    generator is updated with the adversarial loss, the classification loss
    on the target labels and an L1 cycle-reconstruction loss (weighted by
    ``config.lambda_recon``). Target labels are the batch's own labels
    shuffled across samples. Translations of a fixed batch are saved every
    epoch, generator weights every ``config.save_every`` epochs, and a GIF
    plus a loss plot at the end.
    """
    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    for path in (config.samples_path, config.weights_path, config.plots_path):
        make_dirs(path)

    # Prepare Data Loader #
    train_loader = get_celeba_loader('train', config.batch_size, config.selected_attrs)
    total_batch = len(train_loader)

    fixed_image, original_label = next(iter(train_loader))
    fixed_image = fixed_image.to(device)
    fixed_labels_list = create_labels(original_label, config.selected_attrs)

    # Prepare Networks #
    D = Discriminator(num_classes=len(config.selected_attrs)).to(device)
    G = Generator(num_classes=len(config.selected_attrs)).to(device)

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Train #
    print("Training StarGAN started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):
        # Iterate the loader's batches directly. Calling next(iter(loader))
        # inside the loop would rebuild the iterator (and any worker
        # processes) every step and never walk through the dataset.
        for i, (real_image, label) in enumerate(train_loader):

            # Data Preparation #
            real_image = real_image.to(device)
            label = label.to(device)

            # Target domains: the batch's labels permuted across samples.
            rand_idx = torch.randperm(label.size(0))
            target_label = label[rand_idx].to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad(D, requires_grad=True)

            # Discriminator Loss using Real Image #
            prob_real_src, prob_real_cls = D(real_image)
            D_real_loss = -torch.mean(prob_real_src)
            D_cls_loss = config.lambda_cls * criterion_CLS(prob_real_cls, label)

            # Discriminator Loss using Generated Image #
            fake_image = G(real_image, target_label)
            prob_fake_src, prob_fake_cls = D(fake_image.detach())
            D_fake_loss = torch.mean(prob_fake_src)

            # Discriminator Loss using Wasserstein GAN Gradient Penalty #
            # NOTE(review): fake_image is passed un-detached, so this term's
            # backward also runs through G. G's grads are zeroed again before
            # the generator step, so its update is unaffected. Detaching here
            # would save work, but get_gradient_penalty may rely on its input
            # requiring grad -- confirm before changing.
            D_gp_loss = config.lambda_gp * get_gradient_penalty(real_image, fake_image, D)

            # Calculate Total Discriminator Loss #
            D_loss = D_real_loss + D_fake_loss + D_cls_loss + D_gp_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())

            ###################
            # Train Generator #
            ###################

            if (i + 1) % config.n_critics == 0:

                # Prevent Discriminator Update during Generator Update #
                set_requires_grad(D, requires_grad=False)

                # Initialize Optimizers #
                D_optim.zero_grad()
                G_optim.zero_grad()

                # Generator Loss using Fake Images #
                fake_image = G(real_image, target_label)
                prob_fake_src, prob_fake_cls = D(fake_image)
                G_fake_loss = -torch.mean(prob_fake_src)
                G_cls_loss = config.lambda_cls * criterion_CLS(prob_fake_cls, target_label)

                # Reconstruction Loss #
                recon_image = G(fake_image, label)
                G_recon_loss = config.lambda_recon * torch.mean(torch.abs(real_image - recon_image))

                # Calculate Total Generator Loss #
                G_loss = G_fake_loss + G_recon_loss + G_cls_loss

                # Back Propagation and Update #
                G_loss.backward()
                G_optim.step()

                # Add items to Lists #
                G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "StarGAN | Epoch [{}/{}] | Iteration [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

        # Save Sample Images #
        save_samples(fixed_image, fixed_labels_list, G, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(config.weights_path,
                             'StarGAN_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('StarGAN', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training Finished.")