Example #1
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Image Pool #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)

    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(), G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []

    # Train #
    print("Training Unsupervised Attention-Guided GAN started with a total of {} epochs.".format(config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)

            masked_fake_B = fake_B * attn_A + real_A * (1 - attn_A)

            masked_fake_B *= attn_A
            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)

            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)

            masked_fake_A = fake_A * attn_B + real_B * (1 - attn_B)

            masked_fake_A *= attn_B
            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)

            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Cycle Consistency Loss using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)

            # Cycle Consistency Loss using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)

            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)

            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            # Train Discriminator A using real B (D_A scores domain-B images) #
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)

            # Add Pooling #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A

            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()

            # Train Discriminator B using real A (D_B scores domain-A images) #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B

            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())

            ####################
            # Print Statistics #
            ####################

            if (i+1) % config.print_every == 0:
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Losses {:.4f} | D B Losses {:.4f} | G A Losses {:.4f} | G B Losses {:.4f}".
                      format(epoch+1, config.num_epochs, i+1, total_batch, np.average(D_A_losses), np.average(D_B_losses), np.average(G_A_losses), np.average(G_B_losses)))

                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A, Attn_A, Attn_B, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch+1)))

    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
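
Note: ImageMaskPool is not defined in this listing. A minimal sketch, assuming it mirrors CycleGAN's ImagePool history buffer but stores (image, mask) pairs so that query can hand the discriminator older masked fakes; the 50/50 swap rule and fill-first behavior are assumptions carried over from ImagePool:

import random
import torch

class ImageMaskPool:
    """History buffer of (image, mask) pairs; returns either the incoming
    pair or a randomly stored one, as in CycleGAN's ImagePool."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images, self.masks = [], []

    def query(self, images, masks):
        if self.pool_size == 0:                        # buffer disabled
            return images, masks
        out_images, out_masks = [], []
        for image, mask in zip(images, masks):
            image, mask = image.unsqueeze(0), mask.unsqueeze(0)
            if len(self.images) < self.pool_size:      # fill the buffer first
                self.images.append(image)
                self.masks.append(mask)
                out_images.append(image)
                out_masks.append(mask)
            elif random.random() > 0.5:                # swap with a stored pair
                idx = random.randrange(self.pool_size)
                out_images.append(self.images[idx].clone())
                out_masks.append(self.masks[idx].clone())
                self.images[idx], self.masks[idx] = image, mask
            else:                                      # pass the new pair through
                out_images.append(image)
                out_masks.append(mask)
        return torch.cat(out_images, 0), torch.cat(out_masks, 0)
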
Example #2
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader_selfie, train_loader_anime = get_selfie2anime_loader(
        'train', config.batch_size)
    total_batch = min(len(train_loader_selfie), len(train_loader_anime))  # zip() stops at the shorter loader

    test_loader_selfie, test_loader_anime = get_selfie2anime_loader(
        'test', config.val_batch_size)

    # Prepare Networks #
    D_A = Discriminator(num_layers=7)
    D_B = Discriminator(num_layers=7)
    L_A = Discriminator(num_layers=5)
    L_B = Discriminator(num_layers=5)
    G_A2B = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)
    G_B2A = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)

    networks = [D_A, D_B, L_A, L_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    Adversarial_loss = nn.MSELoss()
    Cycle_loss = nn.L1Loss()
    BCE_loss = nn.BCEWithLogitsLoss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters(),
                                     L_A.parameters(), L_B.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Rho Clipper to constrain the values of rho in AdaILN and ILN #
    Rho_Clipper = RhoClipper(0, 1)

    # Lists #
    D_losses = []
    G_losses = []

    # Train #
    print("Training U-GAT-IT started with a total of {} epochs.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):

        for i, (selfie, anime) in enumerate(
                zip(train_loader_selfie, train_loader_anime)):

            # Data Preparation #
            real_A = selfie.to(device)
            real_B = anime.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=True)

            # Forward Data #
            fake_B, _, _ = G_A2B(real_A)
            fake_A, _, _ = G_B2A(real_B)

            G_real_A, G_real_A_cam, _ = D_A(real_A)
            L_real_A, L_real_A_cam, _ = L_A(real_A)
            G_real_B, G_real_B_cam, _ = D_B(real_B)
            L_real_B, L_real_B_cam, _ = L_B(real_B)

            # Detach so the discriminator pass leaves no stale gradients in the generators #
            G_fake_A, G_fake_A_cam, _ = D_A(fake_A.detach())
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A.detach())
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B.detach())
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B.detach())

            # Adversarial Loss of Discriminator #
            real_labels = torch.ones(G_real_A.shape).to(device)
            D_ad_real_loss_GA = Adversarial_loss(G_real_A, real_labels)

            fake_labels = torch.zeros(G_fake_A.shape).to(device)
            D_ad_fake_loss_GA = Adversarial_loss(G_fake_A, fake_labels)

            D_ad_loss_GA = D_ad_real_loss_GA + D_ad_fake_loss_GA

            real_labels = torch.ones(G_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_GA = Adversarial_loss(G_real_A_cam, real_labels)

            fake_labels = torch.zeros(G_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_GA = Adversarial_loss(G_fake_A_cam, fake_labels)

            D_ad_cam_loss_GA = D_ad_cam_real_loss_GA + D_ad_cam_fake_loss_GA

            real_labels = torch.ones(G_real_B.shape).to(device)
            D_ad_real_loss_GB = Adversarial_loss(G_real_B, real_labels)

            fake_labels = torch.zeros(G_fake_B.shape).to(device)
            D_ad_fake_loss_GB = Adversarial_loss(G_fake_B, fake_labels)

            D_ad_loss_GB = D_ad_real_loss_GB + D_ad_fake_loss_GB

            real_labels = torch.ones(G_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_GB = Adversarial_loss(G_real_B_cam, real_labels)

            fake_labels = torch.zeros(G_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_GB = Adversarial_loss(G_fake_B_cam, fake_labels)

            D_ad_cam_loss_GB = D_ad_cam_real_loss_GB + D_ad_cam_fake_loss_GB

            # Adversarial Loss of L #
            real_labels = torch.ones(L_real_A.shape).to(device)
            D_ad_real_loss_LA = Adversarial_loss(L_real_A, real_labels)

            fake_labels = torch.zeros(L_fake_A.shape).to(device)
            D_ad_fake_loss_LA = Adversarial_loss(L_fake_A, fake_labels)

            D_ad_loss_LA = D_ad_real_loss_LA + D_ad_fake_loss_LA

            real_labels = torch.ones(L_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_LA = Adversarial_loss(L_real_A_cam, real_labels)

            fake_labels = torch.zeros(L_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_LA = Adversarial_loss(L_fake_A_cam, fake_labels)

            D_ad_cam_loss_LA = D_ad_cam_real_loss_LA + D_ad_cam_fake_loss_LA

            real_labels = torch.ones(L_real_B.shape).to(device)
            D_ad_real_loss_LB = Adversarial_loss(L_real_B, real_labels)

            fake_labels = torch.zeros(L_fake_B.shape).to(device)
            D_ad_fake_loss_LB = Adversarial_loss(L_fake_B, fake_labels)

            D_ad_loss_LB = D_ad_real_loss_LB + D_ad_fake_loss_LB

            real_labels = torch.ones(L_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_LB = Adversarial_loss(L_real_B_cam, real_labels)

            fake_labels = torch.zeros(L_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_LB = Adversarial_loss(L_fake_B_cam, fake_labels)

            D_ad_cam_loss_LB = D_ad_cam_real_loss_LB + D_ad_cam_fake_loss_LB

            # Calculate Each Discriminator Loss #
            D_loss_A = D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA
            D_loss_B = D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=False)

            # Forward Data #
            fake_B, fake_B_cam, _ = G_A2B(real_A)
            fake_A, fake_A_cam, _ = G_B2A(real_B)

            fake_ABA, _, _ = G_B2A(fake_B)
            fake_BAB, _, _ = G_A2B(fake_A)

            # Identity mappings: each generator reproduces an image from its own output domain #
            fake_A2A, fake_A2A_cam, _ = G_B2A(real_A)
            fake_B2B, fake_B2B_cam, _ = G_A2B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Generator #
            real_labels = torch.ones(G_fake_A.shape).to(device)
            G_adv_fake_loss_A = Adversarial_loss(G_fake_A, real_labels)

            real_labels = torch.ones(G_fake_A_cam.shape).to(device)
            G_adv_cam_fake_loss_A = Adversarial_loss(G_fake_A_cam, real_labels)

            G_adv_loss_A = G_adv_fake_loss_A + G_adv_cam_fake_loss_A

            real_labels = torch.ones(G_fake_B.shape).to(device)
            G_adv_fake_loss_B = Adversarial_loss(G_fake_B, real_labels)

            real_labels = torch.ones(G_fake_B_cam.shape).to(device)
            G_adv_cam_fake_loss_B = Adversarial_loss(G_fake_B_cam, real_labels)

            G_adv_loss_B = G_adv_fake_loss_B + G_adv_cam_fake_loss_B

            # Adversarial Loss of L #
            real_labels = torch.ones(L_fake_A.shape).to(device)
            L_adv_fake_loss_A = Adversarial_loss(L_fake_A, real_labels)

            real_labels = torch.ones(L_fake_A_cam.shape).to(device)
            L_adv_cam_fake_loss_A = Adversarial_loss(L_fake_A_cam, real_labels)

            L_adv_loss_A = L_adv_fake_loss_A + L_adv_cam_fake_loss_A

            real_labels = torch.ones(L_fake_B.shape).to(device)
            L_adv_fake_loss_B = Adversarial_loss(L_fake_B, real_labels)

            real_labels = torch.ones(L_fake_B_cam.shape).to(device)
            L_adv_cam_fake_loss_B = Adversarial_loss(L_fake_B_cam, real_labels)

            L_adv_loss_B = L_adv_fake_loss_B + L_adv_cam_fake_loss_B

            # Cycle Consistency Loss #
            G_recon_loss_A = Cycle_loss(fake_ABA, real_A)
            G_recon_loss_B = Cycle_loss(fake_BAB, real_B)

            G_identity_loss_A = Cycle_loss(fake_A2A, real_A)
            G_identity_loss_B = Cycle_loss(fake_B2B, real_B)

            G_cycle_loss_A = G_recon_loss_A + G_identity_loss_A
            G_cycle_loss_B = G_recon_loss_B + G_identity_loss_B

            # CAM Loss #
            real_labels = torch.ones(fake_A_cam.shape).to(device)
            G_cam_real_loss_A = BCE_loss(fake_A_cam, real_labels)

            fake_labels = torch.zeros(fake_A2A_cam.shape).to(device)
            G_cam_fake_loss_A = BCE_loss(fake_A2A_cam, fake_labels)

            G_cam_loss_A = G_cam_real_loss_A + G_cam_fake_loss_A

            real_labels = torch.ones(fake_B_cam.shape).to(device)
            G_cam_real_loss_B = BCE_loss(fake_B_cam, real_labels)

            fake_labels = torch.zeros(fake_B2B_cam.shape).to(device)
            G_cam_fake_loss_B = BCE_loss(fake_B2B_cam, fake_labels)

            G_cam_loss_B = G_cam_real_loss_B + G_cam_fake_loss_B

            # Calculate Each Generator Loss #
            G_loss_A = G_adv_loss_A + L_adv_loss_A + config.lambda_cycle * G_cycle_loss_A + config.lambda_cam * G_cam_loss_A
            G_loss_B = G_adv_loss_B + L_adv_loss_B + config.lambda_cycle * G_cycle_loss_B + config.lambda_cam * G_cam_loss_B

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Apply Rho Clipper to Generators #
            G_A2B.apply(Rho_Clipper)
            G_B2A.apply(Rho_Clipper)

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "U-GAT-IT | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(test_loader_selfie, G_A2B, epoch,
                             config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                D_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                D_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_A2B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    # Make a GIF file #
    make_gifs_train('U-GAT-IT', config.samples_path)

    print("Training finished.")
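
Note: RhoClipper is applied to the generators through nn.Module.apply, which visits every submodule. A minimal sketch, assuming it follows the U-GAT-IT reference implementation and simply clamps each learnable rho parameter of the AdaILN/ILN layers into [min, max]:

class RhoClipper:
    """Callable for nn.Module.apply(); clamps any `rho` parameter in place."""

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Only AdaILN/ILN layers define a rho parameter #
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
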
Example #3
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(
        purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers #
    D_A_optim = torch.optim.Adam(D_A.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Training #
    print("Training CycleGAN started with a total of {} epochs.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse,
                zebra) in enumerate(zip(train_horse_loader,
                                        train_zebra_loader)):

            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity Loss #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(
                identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(
                identity_B, real_B)

            # Cycle Loss #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(
                reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(
                reconstructed_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A + G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB

            # Back Propagation and Update #
            G_loss.backward()  # the discriminators see detached fakes, so the graph need not be retained
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake Loss #
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Calculate Total Discriminator A Loss #
            D_loss_A = config.lambda_identity * (D_real_loss_A +
                                                 D_fake_loss_A).mean()

            # Back Propagation and Update #
            D_loss_A.backward()
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake Loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Calculate Total Discriminator B Loss #
            D_loss_B = config.lambda_identity * (loss_real_B +
                                                 loss_fake_B).mean()

            # Back Propagation and Update #
            D_loss_B.backward()
            D_B_optim.step()

            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses_A), np.average(D_losses_B),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B,
                              G_B2A, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
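
Note: set_requires_grad and get_lr_scheduler are shared by all of these training loops but never shown. A minimal sketch, assuming the usual freeze/unfreeze helper (accepting a single network or a list, as both call styles appear above) and a CycleGAN-style linear decay; config.decay_epochs is a hypothetical field marking when the decay starts:

import torch

def set_requires_grad(networks, requires_grad=False):
    """Freeze or unfreeze parameters; accepts a single network or a list."""
    if not isinstance(networks, (list, tuple)):
        networks = [networks]
    for network in networks:
        for param in network.parameters():
            param.requires_grad = requires_grad

def get_lr_scheduler(optimizer):
    """Hold the learning rate constant, then decay it linearly to zero."""
    def lambda_rule(epoch):
        decay = max(0, epoch - config.decay_epochs)
        return 1.0 - decay / float(config.num_epochs - config.decay_epochs)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
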
Example #4
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader = get_edges2shoes_loader('train', config.batch_size)
    val_loader = get_edges2shoes_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)
    D_A = Discriminator().to(device)
    D_B = Discriminator().to(device)

    # Loss Function #
    criterion_Adversarial = nn.BCELoss()
    criterion_Recon = nn.MSELoss()
    criterion_Feature = nn.HingeEmbeddingLoss()

    # Optimizers #
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Constants #
    iters = 0

    # Training #
    print("Training DiscoGAN started with a total of {} epochs.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Models #
            G_A2B.zero_grad()
            G_B2A.zero_grad()
            D_A.zero_grad()
            D_B.zero_grad()

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ################
            # Forward Data #
            ################

            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)

            prob_real_A, A_real_features = D_A(real_A)
            prob_fake_A, A_fake_features = D_A(fake_A)

            prob_real_B, B_real_features = D_B(real_B)
            prob_fake_B, B_fake_features = D_B(fake_B)

            #######################
            # Train Discriminator #
            #######################

            # Discriminator A #
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_A = (D_real_loss_A + D_fake_loss_A).mean()

            # Discriminator B #
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_real_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_fake_loss_B = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_B = (D_real_loss_B + D_fake_loss_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_adv_loss_A = criterion_Adversarial(prob_fake_A, real_labels)

            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_adv_loss_B = criterion_Adversarial(prob_fake_B, real_labels)

            # Feature Loss #
            G_feature_loss_A = feature_loss(criterion_Feature, A_real_features,
                                            A_fake_features)
            G_feature_loss_B = feature_loss(criterion_Feature, B_real_features,
                                            B_fake_features)

            # Reconstruction Loss #
            fake_ABA = G_B2A(fake_B)
            fake_BAB = G_A2B(fake_A)

            G_recon_loss_A = criterion_Recon(fake_ABA, real_A)
            G_recon_loss_B = criterion_Recon(fake_BAB, real_B)

            if iters < config.decay_gan_loss:
                rate = config.starting_rate
            else:
                if iters == config.decay_gan_loss:
                    print("Now the rate is changed to {}".format(
                        config.changed_rate))
                rate = config.changed_rate

            G_loss_A = (G_adv_loss_A * 0.1 + G_feature_loss_A * 0.9) * (
                1. - rate) + G_recon_loss_A * rate
            G_loss_B = (G_adv_loss_B * 0.1 + G_feature_loss_B * 0.9) * (
                1. - rate) + G_recon_loss_B * rate

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            if iters % config.num_train_gen == 0:
                D_loss.backward()
                D_optim.step()
            else:
                G_loss.backward()
                G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "DiscoGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G_A2B, G_B2A, epoch,
                              config.samples_path)

            iters += 1

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('DiscoGAN', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
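
Note: feature_loss pairs per-layer discriminator features with the HingeEmbeddingLoss criterion. One plausible reading, following the original DiscoGAN code: square the difference of batch-averaged feature maps and feed it to the criterion with target +1, which reduces the hinge loss to the mean of that squared difference:

import torch

def feature_loss(criterion, real_features, fake_features):
    """Feature matching: pull fake feature statistics toward real ones,
    layer by layer."""
    loss = 0
    for real_feature, fake_feature in zip(real_features, fake_features):
        l2 = (real_feature.mean(0) - fake_feature.mean(0)) ** 2
        loss += criterion(l2, torch.ones_like(l2))
    return loss
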
Example #5
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader = get_facades_loader('train', config.batch_size)
    val_loader = get_facades_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Criterion #
    criterion_Adversarial = nn.BCELoss()
    criterion_Pixelwise = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Training #
    print("Training Pix2Pix started with a total of {} epochs.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            # Prevent Discriminator Update during Generator Update #
            set_requires_grad(D, requires_grad=False)

            # Adversarial Loss #
            fake_B = G(real_A)
            prob_fake = D(fake_B, real_A)
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss_fake = criterion_Adversarial(prob_fake, real_labels)

            # Pixel-Wise Loss #
            G_loss_pixelwise = criterion_Pixelwise(fake_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_loss_fake + config.l1_lambda * G_loss_pixelwise

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            # Re-enable Discriminator Updates #
            set_requires_grad(D, requires_grad=True)

            # Adversarial Loss #
            prob_real = D(real_B, real_A)
            real_labels = torch.ones(prob_real.size()).to(device)
            D_real_loss = criterion_Adversarial(prob_real, real_labels)

            fake_B = G(real_A)
            prob_fake = D(fake_B.detach(), real_A)
            fake_labels = torch.zeros(prob_fake.size()).to(device)
            D_fake_loss = criterion_Adversarial(prob_fake, fake_labels)

            # Calculate Total Discriminator Loss #
            D_loss = torch.mean(D_real_loss + D_fake_loss)

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Pix2Pix | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'Pix2Pix_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('Pix2Pix', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.plots_path)

    print("Training finished.")
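
Note: the calls D(fake_B, real_A) and D(real_B, real_A) imply a conditional discriminator that scores a translation together with the input it was generated from. A minimal PatchGAN-style sketch, assuming channel-wise concatenation; the layer sizes are illustrative, not this repository's actual architecture:

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator (illustrative sizes)."""

    def __init__(self, in_channels=3):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels * 2, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, stride=1, padding=1),
            nn.Sigmoid(),  # nn.BCELoss above expects probabilities
        )

    def forward(self, image, condition):
        # Condition by concatenating translation and input along channels #
        return self.main(torch.cat([image, condition], dim=1))
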
Example #6
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Paths #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_loader = get_edges2handbags_loader(purpose='train',
                                             batch_size=config.batch_size)
    val_loader = get_edges2handbags_loader(purpose='val',
                                           batch_size=config.batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D_cVAE = Discriminator()
    D_cLR = Discriminator()
    E = Encoder(config.z_dim)
    G = Generator(config.z_dim)

    networks = [D_cVAE, D_cLR, E, G]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Recon = nn.L1Loss()
    criterion_Adversarial = nn.MSELoss()

    # Optimizers #
    D_cVAE_optim = torch.optim.Adam(D_cVAE.parameters(),
                                    lr=config.lr,
                                    betas=(0.5, 0.999))
    D_cLR_optim = torch.optim.Adam(D_cLR.parameters(),
                                   lr=config.lr,
                                   betas=(0.5, 0.999))
    E_optim = torch.optim.Adam(E.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_cVAE_optim_scheduler = get_lr_scheduler(D_cVAE_optim)
    D_cLR_optim_scheduler = get_lr_scheduler(D_cLR_optim)
    E_optim_scheduler = get_lr_scheduler(E_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, E_G_losses, G_losses = [], [], []

    # Fixed Noise #
    fixed_noise = torch.randn(config.test_size, config.num_images,
                              config.z_dim).to(device)

    # Training #
    print("Training BicycleGAN started with a total of {} epochs.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (sketch, target) in enumerate(train_loader):

            # Data Preparation #
            sketch = sketch.to(device)
            target = target.to(device)

            # Separate Data for D_cVAE-GAN and D_cLR-GAN #
            cVAE_data = {
                'sketch': sketch[0].unsqueeze(dim=0),
                'target': target[0].unsqueeze(dim=0)
            }
            cLR_data = {
                'sketch': sketch[1].unsqueeze(dim=0),
                'target': target[1].unsqueeze(dim=0)
            }

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Train Discriminators #
            set_requires_grad([D_cVAE, D_cLR], requires_grad=True)

            ################################
            # Train Discriminator cVAE-GAN #
            ################################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)

            # Forward to Discriminator cVAE-GAN #
            prob_real_D_cVAE_1, prob_real_D_cVAE_2 = D_cVAE(
                cVAE_data['target'])
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(
                fake_image_cVAE.detach())

            # Adversarial Loss using cVAE_1 #
            real_labels = torch.ones(prob_real_D_cVAE_1.size()).to(device)
            D_cVAE_1_real_loss = criterion_Adversarial(prob_real_D_cVAE_1,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_1.size()).to(device)
            D_cVAE_1_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_1,
                                                       fake_labels)

            D_cVAE_1_loss = D_cVAE_1_real_loss + D_cVAE_1_fake_loss

            # Adversarial Loss using cVAE_2 #
            real_labels = torch.ones(prob_real_D_cVAE_2.size()).to(device)
            D_cVAE_2_real_loss = criterion_Adversarial(prob_real_D_cVAE_2,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_2.size()).to(device)
            D_cVAE_2_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_2,
                                                       fake_labels)

            D_cVAE_2_loss = D_cVAE_2_real_loss + D_cVAE_2_fake_loss

            ###########################
            # Train Discriminator cLR #
            ###########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)

            # Forward to Discriminator cLR-GAN #
            prob_real_D_cLR_1, prob_real_D_cLR_2 = D_cLR(cLR_data['target'])
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(
                fake_image_cLR.detach())

            # Adversarial Loss using cLR-1 #
            real_labels = torch.ones(prob_real_D_cLR_1.size()).to(device)
            D_cLR_1_real_loss = criterion_Adversarial(prob_real_D_cLR_1,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_1.size()).to(device)
            D_cLR_1_fake_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                      fake_labels)

            D_cLR_1_loss = D_cLR_1_real_loss + D_cLR_1_fake_loss

            # Adversarial Loss using cLR-2 #
            real_labels = torch.ones(prob_real_D_cLR_2.size()).to(device)
            D_cLR_2_real_loss = criterion_Adversarial(prob_real_D_cLR_2,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_2.size()).to(device)
            D_cLR_2_fake_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                      fake_labels)

            D_cLR_2_loss = D_cLR_2_real_loss + D_cLR_2_fake_loss

            # Calculate Total Discriminator Loss #
            D_loss = D_cVAE_1_loss + D_cVAE_2_loss + D_cLR_1_loss + D_cLR_2_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_cVAE_optim.step()
            D_cLR_optim.step()

            set_requires_grad([D_cVAE, D_cLR], requires_grad=False)

            ###############################
            # Train Encoder and Generator #
            ###############################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(fake_image_cVAE)

            # Adversarial Loss using cVAE #
            real_labels = torch.ones(prob_fake_D_cVAE_1.size()).to(device)
            E_G_adv_cVAE_1_loss = criterion_Adversarial(
                prob_fake_D_cVAE_1, real_labels)

            real_labels = torch.ones(prob_fake_D_cVAE_2.size()).to(device)
            E_G_adv_cVAE_2_loss = criterion_Adversarial(
                prob_fake_D_cVAE_2, real_labels)

            E_G_adv_cVAE_loss = E_G_adv_cVAE_1_loss + E_G_adv_cVAE_2_loss

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(fake_image_cLR)

            # Adversarial Loss of cLR #
            real_labels = torch.ones(prob_fake_D_cLR_1.size()).to(device)
            E_G_adv_cLR_1_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                       real_labels)

            real_labels = torch.ones(prob_fake_D_cLR_2.size()).to(device)
            E_G_adv_cLR_2_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                       real_labels)

            E_G_adv_cLR_loss = E_G_adv_cLR_1_loss + E_G_adv_cLR_2_loss

            # KL Divergence against N(0, 1) #
            E_KL_div_loss = config.lambda_KL * torch.sum(
                0.5 * (mean**2 + std - 2 * torch.log(std) - 1))

            # Reconstruction Loss #
            E_G_recon_loss = config.lambda_Image * criterion_Recon(
                fake_image_cVAE, cVAE_data['target'])

            # Total Encoder and Generator Loss #
            E_G_loss = E_G_adv_cVAE_loss + E_G_adv_cLR_loss + E_KL_div_loss + E_G_recon_loss

            # Back Propagation and Update #
            E_G_loss.backward()
            E_optim.step()
            G_optim.step()

            ########################
            # Train Generator Only #
            ########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            mean, std = E(fake_image_cLR)

            # Reconstruction Loss #
            G_recon_loss = criterion_Recon(mean, random_z)

            # Calculate Total Generator Loss #
            G_loss = config.lambda_Z * G_recon_loss

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            E_G_losses.append(E_G_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "BicycleGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | E_G Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(E_G_losses),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, fixed_noise, epoch,
                              config.num_images, config.samples_path)

        # Adjust Learning Rate #
        D_cVAE_optim_scheduler.step()
        D_cLR_optim_scheduler.step()
        E_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Models #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'BicycleGAN_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("BicycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, E_G_losses, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
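
Note: the loop assumes E returns (mean, std) directly, so encoded_z = mean + random_z * std is the reparameterization trick. A minimal sketch of that interface; the convolutional trunk is a stand-in, and whether std denotes the standard deviation or the variance in the KL term above is left as in the source:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Maps an image to the (mean, std) of a z_dim-dimensional Gaussian."""

    def __init__(self, z_dim, feat_dim=256):
        super().__init__()
        self.features = nn.Sequential(      # stand-in feature extractor
            nn.Conv2d(3, feat_dim, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc_mean = nn.Linear(feat_dim, z_dim)
        self.fc_std = nn.Linear(feat_dim, z_dim)

    def forward(self, x):
        h = self.features(x).flatten(1)
        mean = self.fc_mean(h)
        std = F.softplus(self.fc_std(h))    # keep std strictly positive
        return mean, std

# Reparameterization as used above: encoded_z = mean + random_z * std,
# with random_z ~ N(0, I), so gradients flow into mean and std.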