Example no. 1
    def backward_Cycle_B(self):

        self.loss_Cycle_B = loss.cycle_consistency_loss(
            self.realB,
            self.cycleB,
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_Cycle_B.backward(retain_graph=True)
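Example no. 1 delegates to loss.cycle_consistency_loss, which is defined elsewhere in that project. As a rough sketch only, assuming loss_weight_config maps a loss name to a scalar weight (both the key name and the default below are assumptions), such a function might look like:

import torch

def cycle_consistency_loss(real_images, reconstructed_images, method='L1',
                           loss_weight_config=None):
    # Hypothetical sketch: weighted pixel-wise distance between the real batch
    # and its reconstruction G_B2A(G_A2B(real)).
    weight = (loss_weight_config or {}).get('cycle_loss', 10.0)  # assumed key and default
    if method == 'L1':
        return weight * torch.mean(torch.abs(real_images - reconstructed_images))
    # fall back to a squared-error ('L2') variant
    return weight * torch.mean((real_images - reconstructed_images) ** 2)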
Example no. 2
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X, test_dataloader_Y,
                  n_epochs=1000):

    print_every = 10

    # keep track of losses over time
    losses = []

    # Get some fixed data from domains X and Y for sampling. These are images that are held
    # constant throughout training, that allow us to inspect the model's performance.
    # make sure to scale to a range -1 to 1

    fixed_Y = next(iter(dataloader_Y))[0] ## shape: [batchsize, channels, height, width]
    fixed_X = next(iter(dataloader_X))[0]  ## shape: [batchsize, channels, height, width]
    global_step=0
    # NOTE: the epoch range uses the globals `pretrain_epoch` and `config`; the n_epochs argument is not used here
    for epoch in range(pretrain_epoch, config['hyperparams']['epochs']+1):
      epochG_loss = 0
      runningG_loss = 0
      runningDX_loss = 0
      runningDY_loss = 0
      mbps = 0 #mini batches per epoch

      for batch_id, (x, _) in tqdm(enumerate(dataloader_X), total=len(dataloader_X)):
        #  with torch.no_grad():
           mbps += 1
           global_step = global_step+1
           y, _ = next(iter(dataloader_Y))  # NOTE: re-creates the Y iterator every step, so dataloader_Y must shuffle to yield different batches
           images_X = x # make sure to scale to a range -1 to 1
           images_Y = y
           del y
           # move images to GPU if available (otherwise stay on CPU)
           images_X = images_X.to(device)
           images_Y = images_Y.to(device)
          #  print("start:  ",convert_size(torch.cuda.memory_allocated(device=device)))

           d_x_optimizer.zero_grad()
           out_x = D_X(images_X)
           D_X_real_loss = loss.real_mse_loss(out_x)
           fake_X = G_YtoX(images_Y)
           out_x = D_X(fake_X)
           D_X_fake_loss = loss.fake_mse_loss(out_x)
           d_x_loss = D_X_real_loss + D_X_fake_loss
           d_x_loss.backward()
           d_x_optimizer.step()
           runningDX_loss += d_x_loss.item()  # accumulate a plain float; the original no-op .detach() calls discarded their results
           del D_X_fake_loss, D_X_real_loss, out_x, fake_X
           torch.cuda.empty_cache()

          #  print("end: DX block  and start DY", convert_size(torch.cuda.memory_allocated(device=device)))

           d_y_optimizer.zero_grad()
           out_y = D_Y(images_Y)
           D_Y_real_loss = loss.real_mse_loss(out_y)
           fake_Y = G_XtoY(images_X)
           out_y = D_Y(fake_Y)
           D_Y_fake_loss = loss.fake_mse_loss(out_y)
           d_y_loss = D_Y_real_loss + D_Y_fake_loss
           d_y_loss.backward()
           d_y_optimizer.step()
           runningDY_loss += d_y_loss.item()  # accumulate a plain float, not a graph-attached tensor
           del D_Y_fake_loss, D_Y_real_loss, out_y, fake_Y
           torch.cuda.empty_cache()
          #  print("End: DY ",convert_size(torch.cuda.memory_allocated(device=device)))


           g_optimizer.zero_grad()
           fake_Y = G_XtoY(images_X)
           out_y = D_Y(fake_Y)
           g_XtoY_loss = loss.real_mse_loss(out_y)
           reconstructed_X = G_YtoX(fake_Y)

           reconstructed_x_loss = loss.cycle_consistency_loss(images_X, reconstructed_X,
                                                              lambda_weight= config['hyperparams']['lambda_weight'])

           featuresY = loss_network(images_Y)
           featuresFakeY = loss_network(fake_Y)

           # print("\nFake Y: ", fake_Y.shape, "   imagesY: ", images_Y.shape,"\n",featuresY[1].data.shape, "   ", featuresFakeY[1].data.shape)
           # exit()
           CONTENT_WEIGHT = config['hyperparams']['Content_Weight']
           # keep gradients flowing through the generated features; detach only the target features
           contentloss = CONTENT_WEIGHT * mse_loss(featuresFakeY[1], featuresY[1].detach())
           del featuresY, featuresFakeY; torch.cuda.empty_cache()

           IDENTITY_WEIGHT = config['hyperparams']['Identity_Weight']

           downsample = nn.Upsample(scale_factor=0.25, mode='bicubic', align_corners=True)
           identity_loss = IDENTITY_WEIGHT * mse_loss(downsample(fake_Y), images_X)

           TOTAL_VARIATION_WEIGHT = config['hyperparams']['TotalVariation_Weight']
           tvloss = TOTAL_VARIATION_WEIGHT * loss.tv_loss(fake_Y, 0.25)

           g_total_loss = g_XtoY_loss + reconstructed_x_loss + identity_loss + tvloss + contentloss
          #  tvloss + content_loss_Y + identity_loss
           g_total_loss.backward()
           g_optimizer.step()
           del out_y, fake_Y, g_XtoY_loss, reconstructed_x_loss, reconstructed_X
          #  , tvloss content_loss_Y, identity_loss
          #  print("end: ", convert_size(torch.cuda.memory_allocated(device=device)))

           runningG_loss += g_total_loss.item()  # accumulate a plain float, not a graph-attached tensor
           writer.add_scalar('D/Y', d_y_loss.item(), global_step=global_step)
           writer.add_scalar('D/X', d_x_loss.item(), global_step=global_step)
           writer.add_scalar('G/TV', tvloss.item(), global_step=global_step)
           writer.add_scalar('G/Identity', identity_loss.item(), global_step=global_step)
           writer.add_scalar('G/Content', contentloss.item(), global_step=global_step)

           bs=config["exp_params"]["batch_size"]
           if config["logging_params"]["log"] and mbps % config["logging_params"]["log_interval"] == 0:
              with torch.no_grad():
                   G_XtoY.eval()
                   y=G_XtoY(fixed_X.to(device))
                   bs, c, h,w = y.size()
                   # x = F.interpolate(fixed_X[0:8, :,:,:], size=h)
                    concat = torch.cat([nn.Upsample(scale_factor=4, mode='bicubic', align_corners=True)(fixed_X.to(device)), y[0:8, :, :, :]], dim=0)
                   print("Saved image!")
                   writer.add_image(tag=str(epoch)+'/'+str(epoch)+'/'+str(mbps), img_tensor=vutils.make_grid(concat.to('cpu'), normalize=False,
                                                      pad_value=1, nrow=8), global_step=global_step)
                   G_XtoY.train()
           # print('Mini-batch no: {}, at epoch [{:3d}/{:3d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f}| g_total_loss: {:6.4f}'
           #          .format(mbps, epoch, n_epochs,  d_x_loss.item() , d_y_loss.item() , g_total_loss.item() ))
           # print(' TV-loss: ', tvloss.item(), '  content loss:', contentloss.item(),
           #                              '  identity loss:', identity_loss.item() )

      fid = calc_fid(G_XtoY.eval())
      G_XtoY.train()
      writer.add_scalar('Epoch/FID', fid, global_step=epoch)

      losses.append((runningDX_loss/mbps, runningDY_loss/mbps, runningG_loss/mbps))
      writer.add_scalar('Epoch/G', runningG_loss/mbps, global_step=epoch )
      writer.add_scalar('Epoch/D_X', runningDX_loss/mbps, global_step=epoch )
      writer.add_scalar('Epoch/D_Y', runningDY_loss/mbps, global_step=epoch )
      print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(epoch, n_epochs,
                                                        runningDX_loss/mbps ,  runningDY_loss/mbps,  runningG_loss/mbps ))


    return losses
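Example no. 2 (and no. 3 below) call several helpers from an external loss module that is not shown: real_mse_loss, fake_mse_loss and tv_loss. A minimal sketch of plausible implementations, assuming the least-squares GAN objective and an anisotropic total-variation penalty (signatures inferred from the call sites above):

import torch

def real_mse_loss(D_out):
    # LSGAN term for outputs the discriminator should score as real (target = 1)
    return torch.mean((D_out - 1) ** 2)

def fake_mse_loss(D_out):
    # LSGAN term for outputs the discriminator should score as fake (target = 0)
    return torch.mean(D_out ** 2)

def tv_loss(img, weight):
    # anisotropic total variation on an NCHW batch: mean absolute difference
    # between vertically and horizontally neighbouring pixels
    diff_h = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]).mean()
    diff_w = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1]).mean()
    return weight * (diff_h + diff_w)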
Example no. 3
def training_loop(dataloader_X, dataloader_Y, #test_dataloader_X, test_dataloader_Y,
                  n_epochs=1000,
                  G_XtoY=None, G_YtoX=None, D_X=None, D_Y=None, lr=0.0002, beta1=0.5, beta2=0.999, save_path = './'):

    #If save folder does not exist, create it
    if not os.path.isdir(save_path):
        os.mkdir(save_path)
    g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())  # Get generator parameters

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(g_params, lr, [beta1, beta2])
    d_x_optimizer = optim.Adam(D_X.parameters(), lr, [beta1, beta2])
    d_y_optimizer = optim.Adam(D_Y.parameters(), lr, [beta1, beta2])

    print_every = 5

    # keep track of losses over time
    losses = []

    # test_iter_X = iter(test_dataloader_X)
    # test_iter_Y = iter(test_dataloader_Y)

    # Get some fixed data from domains X and Y for sampling. These are images that are held
    # constant throughout training, that allow us to inspect the model's performance.
    # fixed_X = test_iter_X.next()[0]
    # fixed_Y = test_iter_Y.next()[0]
    # fixed_X = scale(fixed_X)  # make sure to scale to a range -1 to 1
    # fixed_Y = scale(fixed_Y)

    # batches per epoch
    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)
    batches_per_epoch = min(len(iter_X), len(iter_Y))

    for epoch in range(1, n_epochs + 1):

        # Reset iterators for each epoch
        if epoch % batches_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X = next(iter_X)  # iterator.next() was removed in Python 3; use the built-in next()
        # images_X = scale(images_X)  # make sure to scale to a range -1 to 1

        images_Y = next(iter_Y)
        # images_Y = scale(images_Y)

        # move images to GPU if available (otherwise stay on CPU)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        images_X = images_X.to(device)
        images_Y = images_Y.to(device)

        # ============================================
        #            TRAIN THE DISCRIMINATORS
        # ============================================

        ##   First: D_X, real and fake loss components   ##

        # Train with real images
        d_x_optimizer.zero_grad()

        # 1. Compute the discriminator losses on real images
        out_x = D_X(images_X)
        D_X_real_loss = loss.real_mse_loss(out_x)

        # Train with fake images

        # 2. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 3. Compute the fake loss for D_X
        out_x = D_X(fake_X)
        D_X_fake_loss = loss.fake_mse_loss(out_x)

        # 4. Compute the total loss and perform backprop
        d_x_loss = D_X_real_loss + D_X_fake_loss
        d_x_loss.backward()
        d_x_optimizer.step()

        ##   Second: D_Y, real and fake loss components   ##

        # Train with real images
        d_y_optimizer.zero_grad()

        # 1. Compute the discriminator losses on real images
        out_y = D_Y(images_Y)
        D_Y_real_loss = loss.real_mse_loss(out_y)

        # Train with fake images

        # 2. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 3. Compute the fake loss for D_Y
        out_y = D_Y(fake_Y)
        D_Y_fake_loss = loss.fake_mse_loss(out_y)

        # 4. Compute the total loss and perform backprop
        d_y_loss = D_Y_real_loss + D_Y_fake_loss
        d_y_loss.backward()
        d_y_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATORS
        # =========================================

        ##    First: generate fake X images and reconstructed Y images    ##
        g_optimizer.zero_grad()

        # 1. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 2. Compute the generator loss based on domain X
        out_x = D_X(fake_X)
        g_YtoX_loss = loss.real_mse_loss(out_x)

        # 3. Create a reconstructed y
        # 4. Compute the cycle consistency loss (the reconstruction loss)
        reconstructed_Y = G_XtoY(fake_X)
        # print(images_Y.shape) #[8, 3, 215, 215]
        # print(reconstructed_Y.shape) #[8, 3, 208, 208]
        reconstructed_y_loss = loss.cycle_consistency_loss(images_Y, reconstructed_Y, lambda_weight=10)

        ##    Second: generate fake Y images and reconstructed X images    ##

        # 1. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 2. Compute the generator loss based on domain Y
        out_y = D_Y(fake_Y)
        g_XtoY_loss = loss.real_mse_loss(out_y)

        # 3. Create a reconstructed x
        # 4. Compute the cycle consistency loss (the reconstruction loss)
        reconstructed_X = G_YtoX(fake_Y)
        reconstructed_x_loss = loss.cycle_consistency_loss(images_X, reconstructed_X, lambda_weight=10)

        # 5. Add up all generator and reconstructed losses and perform backprop
        g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
        g_total_loss.backward()
        g_optimizer.step()

        # Print the log info
        if epoch % 50 == 0:
            filename = save_path + str(epoch)
            # append real and fake discriminator losses and the generator loss
            losses.append((d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                epoch, n_epochs, d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))

            # generate image
            X_fake = G_XtoY(images_X)
            X_fake = unnormalize_images(X_fake)
            images_X = unnormalize_images(images_X)
            grid_image_real = torchvision.utils.make_grid(images_X.cpu())
            grid_image_fake = torchvision.utils.make_grid(X_fake.cpu())
            grid_image = torch.cat((grid_image_real, grid_image_fake), 1)
            saveim = np.transpose(grid_image.data.numpy().astype(np.uint8), (1, 2, 0))
            # plt.figure(figsize=(20, 10))
            # plt.imshow(saveim)
            # plt.savefig(filename + '_' + 'XtoY.jpg')
            path = filename + '_' + 'XtoY.jpg'
            imageio.imwrite(path, saveim)
            print('Saved {}'.format(path))

            Y_fake = G_YtoX(images_Y)
            Y_fake = unnormalize_images(Y_fake)
            images_Y = unnormalize_images(images_Y)
            grid_image_real = torchvision.utils.make_grid(images_Y.cpu())
            grid_image_fake = torchvision.utils.make_grid(Y_fake.cpu())
            grid_image = torch.cat((grid_image_real, grid_image_fake), 1)
            saveim = np.transpose(grid_image.data.numpy().astype(np.uint8), (1, 2, 0))
            # plt.figure(figsize=(20, 10))
            # plt.imshow(saveim)
            # plt.savefig(filename + '_' + 'YtoX.jpg')
            path = filename + '_' + 'YtoX.jpg'
            imageio.imwrite(path, saveim)
            print('Saved {}'.format(path))

        # sample_every = 1
        # # Save the generated samples
        # if epoch % sample_every == 0:
        #     G_YtoX.eval()  # set generators to eval mode for sample generation
        #     G_XtoY.eval()
        #     helper.save_samples(epoch, images_Y, images_X, G_YtoX, G_XtoY, batch_size=8)
        #     G_YtoX.train()
        #     G_XtoY.train()

        # uncomment these lines, if you want to save your model
    #         checkpoint_every=1000
    #         # Save the model parameters
    #         if epoch % checkpoint_every == 0:
    #             checkpoint(epoch, G_XtoY, G_YtoX, D_X, D_Y)

    return losses, G_XtoY, G_YtoX, D_X, D_Y
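A short usage sketch for the loop in Example no. 3. The CycleGenerator and Discriminator constructors and the two dataloaders are placeholders here, since they are built elsewhere in the original project:

# Hypothetical wiring: CycleGenerator, Discriminator and the dataloaders are
# assumed to exist elsewhere in the project.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_XtoY, G_YtoX = CycleGenerator().to(device), CycleGenerator().to(device)
D_X, D_Y = Discriminator().to(device), Discriminator().to(device)

losses, G_XtoY, G_YtoX, D_X, D_Y = training_loop(
    dataloader_X, dataloader_Y,
    n_epochs=4000,
    G_XtoY=G_XtoY, G_YtoX=G_YtoX, D_X=D_X, D_Y=D_Y,
    lr=0.0002, beta1=0.5, beta2=0.999,
    save_path='./samples/')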
Example no. 4
            ##    First: generate fake X images and reconstructed Y images    ##
            g_optimizer.zero_grad()

            # 1. Generate fake images that look like domain X based on real images in domain Y
            fake_X = G_YtoX(images_Y)

            # 2. Compute the generator loss based on domain X
            out_x = D_X(fake_X)
            g_YtoX_loss = loss.real_mse_loss(out_x)

            # 3. Create a reconstructed y
            # 4. Compute the cycle consistency loss (the reconstruction loss)
            reconstructed_Y = G_XtoY(fake_X)
            # print(images_Y.shape) #[8, 3, 215, 215]
            # print(reconstructed_Y.shape) #[8, 3, 208, 208]
            reconstructed_y_loss = loss.cycle_consistency_loss(
                images_Y, reconstructed_Y, lambda_weight=Y_lambda_weight)

            ##    Second: generate fake Y images and reconstructed X images    ##

            # 1. Generate fake images that look like domain Y based on real images in domain X
            fake_Y = G_XtoY(images_X)

            # 2. Compute the generator loss based on domain Y
            out_y = D_Y(fake_Y)
            g_XtoY_loss = loss.real_mse_loss(out_y)

            # 3. Create a reconstructed x
            # 4. Compute the cycle consistency loss (the reconstruction loss)
            reconstructed_X = G_YtoX(fake_Y)
            reconstructed_x_loss = loss.cycle_consistency_loss(
                images_X, reconstructed_X, lambda_weight=X_lambda_weight)
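
Example no. 4 ends with the two cycle-consistency terms. Following the same pattern as Example no. 3, the generator step would normally close by summing the adversarial and reconstruction losses and backpropagating; a sketch of that continuation:

            # 5. Add up all generator and reconstructed losses and perform backprop
            g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
            g_total_loss.backward()
            g_optimizer.step()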