import os
from itertools import chain

import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image

# Project-local helpers assumed from usage below: config, device, make_dirs,
# get_horse2zebra_loader, Attention, Generator, Discriminator, save_samples,
# sample_images, make_gifs_train, make_gifs_test, plot_losses. Hedged sketches
# of ImageMaskPool, denorm, get_lr_scheduler and set_requires_grad follow the
# functions that use them.


def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loaders #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Image Pools #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)

    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)

    # Loss Functions #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()),
                               lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(),
                                     G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []

    # Train #
    print("Training Unsupervised Attention-Guided GAN with a total of {} epochs.".format(config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)
            masked_fake_B = fake_B * attn_A + real_A * (1 - attn_A)
            masked_fake_B *= attn_A

            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)
            masked_fake_A = fake_A * attn_B + real_B * (1 - attn_B)
            masked_fake_A *= attn_B

            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Cycle Consistency using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)

            # Cycle Consistency using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)

            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)

            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            # Train Discriminator A using real B (D_A scores domain-B images) #
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)

            # Add Pooling #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A

            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()

            # Train Discriminator B using real A (D_B scores domain-A images) #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B

            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Loss {:.4f} | D B Loss {:.4f} | G A Loss {:.4f} | G B Loss {:.4f}"
                      .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                              np.average(D_A_losses), np.average(D_B_losses),
                              np.average(G_A_losses), np.average(G_B_losses)))

                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A,
                             Attn_A, Attn_B, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(Attn_A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(Attn_B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
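# `ImageMaskPool` is imported from elsewhere in the repository. The sketch below
# is an assumption, not the original implementation: a history buffer in the
# spirit of the CycleGAN image pool, extended to carry each fake image together
# with its attention mask so the pair stays aligned when old samples are
# replayed to the discriminator.
import random


class ImageMaskPool:
    """Buffer of (fake image, attention mask) pairs for discriminator updates."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images, self.masks = [], []

    def query(self, images, masks):
        # With pooling disabled, pass the current batch through unchanged.
        if self.pool_size == 0:
            return images, masks

        out_images, out_masks = [], []
        for image, mask in zip(images, masks):
            image, mask = image.unsqueeze(0), mask.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Fill the buffer before replaying anything.
                self.images.append(image)
                self.masks.append(mask)
                out_images.append(image)
                out_masks.append(mask)
            elif random.random() > 0.5:
                # Replay a stored pair and stash the current one in its place.
                idx = random.randrange(self.pool_size)
                out_images.append(self.images[idx].clone())
                out_masks.append(self.masks[idx].clone())
                self.images[idx], self.masks[idx] = image, mask
            else:
                out_images.append(image)
                out_masks.append(mask)
        return torch.cat(out_images, dim=0), torch.cat(out_masks, dim=0)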
def inference():

    # Inference Paths #
    paths = [config.inference_path_H2Z, config.inference_path_Z2H]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loaders #
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader('test', config.val_batch_size)

    # Prepare Attention and Generator Networks #
    Attn_A = Attention().to(device)
    Attn_B = Attention().to(device)
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)

    Attn_A.load_state_dict(torch.load(os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(config.num_epochs))))
    Attn_B.load_state_dict(torch.load(os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(config.num_epochs))))
    G_A2B.load_state_dict(torch.load(os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(config.num_epochs))))
    G_B2A.load_state_dict(torch.load(os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(config.num_epochs))))

    # Test #
    print("UAG-GAN | Generating Horse2Zebra images started...")
    with torch.no_grad():  # no gradients are needed at inference time
        for i, (horse, zebra) in enumerate(zip(test_horse_loader, test_zebra_loader)):

            # Prepare Data #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Generate Attention Maps (replicated to 3 channels and rescaled to [-1, 1] for visualization) #
            attn_A = Attn_A(real_A)
            attn_A = attn_A.repeat(1, 3, 1, 1)
            attn_A = 2 * attn_A - 1

            attn_B = Attn_B(real_B)
            attn_B = attn_B.repeat(1, 3, 1, 1)
            attn_B = 2 * attn_B - 1

            # Generate Fake Images #
            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)

            # Save Images (Horse -> Zebra) #
            result = torch.cat((real_A, attn_A, fake_B), dim=0)
            save_image(denorm(result.data),
                       os.path.join(config.inference_path_H2Z, 'UAG-GAN_Horse2Zebra_Results_%03d.png' % (i + 1)))

            # Save Images (Zebra -> Horse) #
            result = torch.cat((real_B, attn_B, fake_A), dim=0)
            save_image(denorm(result.data),
                       os.path.join(config.inference_path_Z2H, 'UAG-GAN_Zebra2Horse_Results_%03d.png' % (i + 1)))

    # Make GIF files #
    make_gifs_test("UAG-GAN", "Horse2Zebra", config.inference_path_H2Z)
    make_gifs_test("UAG-GAN", "Zebra2Horse", config.inference_path_Z2H)
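# `denorm` is defined elsewhere in the repository. A minimal sketch, assuming
# the data loaders normalize images to [-1, 1]: it maps tensors back to [0, 1]
# so `save_image` writes them correctly, clamping for safety.
def denorm(x):
    return ((x + 1) / 2).clamp(0, 1)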
# CycleGAN versions of train() and inference() (from the repository's separate CycleGAN script) #


def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loaders #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]
    for network in networks:
        network.to(device)

    # Loss Functions #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers #
    D_A_optim = torch.optim.Adam(D_A.parameters(), lr=config.lr, betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Train #
    print("Training CycleGAN with a total of {} epochs.".format(config.num_epochs))

    for epoch in range(config.num_epochs):
        for i, (horse, zebra) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity Loss #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(identity_B, real_B)

            # Cycle Loss #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(reconstructed_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = (G_mse_loss_B2A + G_mse_loss_A2B
                      + G_identity_loss_A + G_identity_loss_B
                      + G_cycle_loss_ABA + G_cycle_loss_BAB)

            # Back Propagation and Update (the fakes are detached below, so the graph need not be retained) #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake Loss #
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Calculate Total Discriminator A Loss (halved, following the CycleGAN objective) #
            D_loss_A = 0.5 * (D_real_loss_A + D_fake_loss_A)

            # Back Propagation and Update #
            D_loss_A.backward()
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_real_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake Loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_fake_loss_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Calculate Total Discriminator B Loss (halved, following the CycleGAN objective) #
            D_loss_B = 0.5 * (D_real_loss_B + D_fake_loss_B)

            # Back Propagation and Update #
            D_loss_B.backward()
            D_B_optim.step()

            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print("CycleGAN | Epoch [{}/{}] | Iteration [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                      .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                              np.average(D_losses_A), np.average(D_losses_B), np.average(G_losses)))

                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B, G_B2A, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
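# `get_lr_scheduler` and `set_requires_grad` are imported helpers; the sketches
# below are assumptions. The scheduler follows the usual CycleGAN recipe of a
# constant learning rate followed by a linear decay to zero, with
# `config.decay_epochs` as an assumed config field marking where decay begins.
def get_lr_scheduler(optimizer):
    def lambda_rule(epoch):
        # Constant LR up to decay_epochs, then decay linearly to zero.
        return 1.0 - max(0, epoch - config.decay_epochs) / float(config.num_epochs - config.decay_epochs)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)


def set_requires_grad(networks, requires_grad=False):
    # Freeze or unfreeze whole networks, e.g. the discriminators during
    # generator updates, so no gradients are accumulated for them.
    for network in networks:
        for param in network.parameters():
            param.requires_grad = requires_grad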
def inference():

    # Inference Paths #
    paths = [config.inference_path_H2Z, config.inference_path_Z2H]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loaders #
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader('test', config.val_batch_size)

    # Prepare Generators #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)

    G_A2B.load_state_dict(torch.load(os.path.join(config.weights_path, 'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(config.num_epochs))))
    G_B2A.load_state_dict(torch.load(os.path.join(config.weights_path, 'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(config.num_epochs))))

    # Test #
    print("CycleGAN | Generating Horse2Zebra images started...")
    with torch.no_grad():  # no gradients are needed at inference time
        for i, (horse, zebra) in enumerate(zip(test_horse_loader, test_zebra_loader)):

            # Prepare Data #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Generate Fake Images #
            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)

            # Generate Reconstructed Images #
            fake_ABA = G_B2A(fake_B)
            fake_BAB = G_A2B(fake_A)

            # Save Images (Horse -> Zebra) #
            result = torch.cat((real_A, fake_B, fake_ABA), dim=0)
            save_image(denorm(result.data),
                       os.path.join(config.inference_path_H2Z, 'CycleGAN_Horse2Zebra_Results_%03d.png' % (i + 1)),
                       nrow=3, normalize=True)

            # Save Images (Zebra -> Horse) #
            result = torch.cat((real_B, fake_A, fake_BAB), dim=0)
            save_image(denorm(result.data),
                       os.path.join(config.inference_path_Z2H, 'CycleGAN_Zebra2Horse_Results_%03d.png' % (i + 1)),
                       nrow=3, normalize=True)

    # Make GIF files #
    make_gifs_test("CycleGAN", "Horse2Zebra", config.inference_path_H2Z)
    make_gifs_test("CycleGAN", "Zebra2Horse", config.inference_path_Z2H)
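# A minimal entry point, assumed for illustration; the original repository may
# instead dispatch train()/inference() from a separate main script with flags.
if __name__ == "__main__":
    train()
    inference()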