Example #1
def train():
    # Inspired by https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    # Lists to keep track of progress
    num_epochs = args.max_epoch
    #img_list = []
    iters = 0
    print("Starting Training Loop...")
    sys.stdout.flush()
    for epoch in range(num_epochs):
        start = time.time()
        G_losses = []
        D_losses = []
        for i, data in enumerate(train_loader, 0):
            global_i = len(train_loader) * epoch + i
            ############################
            # D network
            ###########################
            ## Train with all-real batch
            #with torch.autograd.detect_anomaly():
            optimizerD.zero_grad()
            real_image, mask, masked_image, loss_mask = data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device)
            jitter_real = torch.empty_like(real_image, device=device).uniform_(-0.05 * (0.99 ** epoch), 0.05 * (0.99 ** epoch))
            jitter_fake = torch.empty_like(real_image, device=device).uniform_(-0.05 * (0.99 ** epoch), 0.05 * (0.99 ** epoch))
            #real_preds, real_feats = netD(real_image, mask)
            real_preds, real_feats = netD(torch.clamp(real_image + jitter_real, -1, 1), mask)
            ## Train with all-fake batch
            # noise = torch.randn(b_size, nz, 1, 1, device=device)

            _, skips = netS1(masked_image)
            embed, _ = netS1(real_image, False)
            #embed, _ = netS2(real_image, False)
            style_code, mu, sigma = netM(embed)
            fake = netG(style_code, mask, skips)
            fake_preds, fake_feats = netD(torch.clamp(fake.detach() + jitter_fake, -1, 1), mask)
            #fake_preds, fake_feats = netD(fake.detach(), mask)
            errD = 0.0
            for fp, rp in zip(fake_preds, real_preds):
                errD += losses.hinge_loss_discriminator(fp, rp)
            errD.backward()
            optimizerD.step()

            # dump train metrics to tensorboard
            if writer is not None:
                writer.add_scalar(f"loss_D", errD.item(), global_i)
            ############################
            # G network
            ###########################
            optimizerM.zero_grad()
            optimizerS1.zero_grad()
            #optimizerS2.zero_grad()
            optimizerG.zero_grad()
            #l1 = losses.masked_l1(fake, masked_image, loss_mask) * args.cycle_lambda
            #l1.backward(retain_graph=True)
            dkl = losses.KL_divergence(mu, sigma) * args.kl_lambda
            dkl.backward(retain_graph=True)
            # VGG features for the perceptual loss; Vgg19Full (set up in the
            # later examples) is assumed to return a list of feature maps.
            fake_vgg_f = vgg(fake)
            real_vgg_f = vgg(real_image)
            errG_p = 0.0
            for ff, rf in zip(fake_vgg_f, real_vgg_f):
                errG_p += losses.perceptual_loss(ff, rf.detach(), args.fm_lambda)
            errG_p.backward(retain_graph=True)
            fake_preds, fake_feats = netD(fake, mask)
            errG_hinge = 0.0
            for fp in fake_preds:
                errG_hinge += losses.hinge_loss_generator(fp)
            errG_hinge.backward(retain_graph=True)
            errG_fm = 0.0
            for ff, rf in zip(fake_feats, real_feats):
                errG_fm += losses.perceptual_loss(ff, rf.detach(), args.fm_lambda)
            errG_fm.backward()
            errG = errG_hinge.item() + errG_fm.item() + errG_p.item() #+ l1.item()

            if args.G_orth > 0.0:
                losses.ortho(netG, args.G_orth,
                            blacklist=[])
            optimizerG.step()
            optimizerS1.step()
            #optimizerS2.step()
            optimizerM.step()
            if writer is not None:
                writer.add_scalar(f"loss_G", errG, global_i)
            # Output training stats
            if i % 500 == 499:
               print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                    % (epoch, num_epochs, i, len(train_loader), errD.item(), errG))
               sys.stdout.flush()
               with torch.no_grad():
                    netG.eval()
                    netS1.eval()
                    #netS2.eval()
                    netM.eval()
                    _, test_skips = netS1(fixed_test_images)
                    test_embed, _ = netS1(fixed_test_real_images, False)
                    #test_embed, _ = netS2(fixed_test_real_images, False)
                    style_code, _, _ = netM(test_embed)
                    test_generated = netG(style_code, fixed_test_masks, test_skips).detach().cpu()
                    netG.train()
                    netS1.train()
                    #netS2.train()
                    netM.train()
               grid = vutils.make_grid(test_generated.data[:16], normalize=True)
               vutils.save_image(test_generated.data[:16],
                                 '%s/%d.png' % (test_save_dir, epoch), normalize=True)
               if writer is not None:
                   writer.add_image('generated', grid, global_i)
            G_losses.append(errG)
            D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on a fixed test batch
        end = time.time()
        hours, rem = divmod(end - start, 3600)
        minutes, seconds = divmod(rem, 60)
        if epoch % args.save_frequency == 0:
            with torch.no_grad():
                netG.eval()
                netS1.eval()
                netM.eval()
                _, test_skips = netS1(fixed_test_images)
                test_embed, _ = netS1(fixed_test_real_images, False)
                #test_embed, _ = netS2(fixed_test_real_images, False)
                style_code, _, _ = netM(test_embed)
                test_generated = netG(style_code, fixed_test_masks, test_skips).detach().cpu()
                netG.train()
                netS1.train()
                #netS2.train()
                netM.train()
            # img_list.append(fake.data.numpy())

            print("Epoch %d - Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(epoch, int(hours), int(minutes), seconds))
            sys.stdout.flush()
            _ = vutils.save_image(test_generated.data[:16], '%s/%d.png' % (test_save_dir, epoch), normalize=True)
            # writer.add_image('generated', tim, epoch, dataformats='HWC')
            torch.save(netG.state_dict(), os.path.join(args.root_path, 'NetG' + best_model_path))
            torch.save(netD.state_dict(), os.path.join(args.root_path, 'NetD' + best_model_path))
            # plt.imsave(os.path.join(
            # './{}/'.format(test_save_dir) + 'img{}.png'.format(datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))),
            # ((img_list[-1][0] + 1) / 2.0).transpose([1, 2, 0]), cmap='gray', interpolation="none")
            if writer is not None:
               writer.add_scalar(f"loss_G_epoch", np.sum(G_losses) / len(train_loader), epoch)
               writer.add_scalar(f"loss_D_epoch", np.sum(D_losses) / len(train_loader), epoch)
            iters += 1
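
# For reference: a minimal sketch of the `losses` helpers called above, under
# the assumption that they implement the standard hinge GAN objective and the
# closed-form KL against a unit Gaussian. The actual losses module is not
# shown in these examples, so treat this as illustrative only.
import torch
import torch.nn.functional as F

def hinge_loss_discriminator(fake_preds, real_preds):
    # Hinge loss for D: push real logits above +1 and fake logits below -1.
    return F.relu(1.0 - real_preds).mean() + F.relu(1.0 + fake_preds).mean()

def hinge_loss_generator(fake_preds):
    # Generator half of the hinge objective: raise D's logits on fakes.
    return -fake_preds.mean()

def KL_divergence(mu, log_var):
    # KL(N(mu, exp(log_var)) || N(0, I)) in closed form; assumes the mapping
    # network emits log-variance, a common but unverified convention here.
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())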
Example #2
transform = Compose([Resize(resize_height, resize_width),
                    HorizontalFlip(p=0.5),
                    ToTensorV2()])
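
# Usage note, assuming these are albumentations transforms (ToTensorV2 comes
# from albumentations.pytorch): they take named arrays and return a dict, so
# a dataset's __getitem__ would apply them roughly as in this hypothetical
# helper.
def _apply_transform_example(image, mask):
    # image: HxWxC uint8 array, mask: HxW array (names are illustrative).
    augmented = transform(image=image, mask=mask)
    return augmented['image'], augmented['mask']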


# NOTE: the right-hand side of this assignment is missing in the source;
# DeepFashion2Dataset is assumed here, mirroring the commented val_dataset
# line in the next example.
train_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'train'),
                                    transform=transform,
                                    return_masked_image=True)
print('Loading Dataset...')
sys.stdout.flush()
train_loader = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

test_batch = next(iter(train_loader))
fixed_test_images = test_batch[2].to(device)
fixed_test_masks = test_batch[1].to(device)
fixed_test_real_images = test_batch[0].to(device)

_ = vutils.save_image(fixed_test_real_images.cpu().data[:16], '!test.png', normalize=True)
_ = vutils.save_image(fixed_test_images.cpu().data[:16], '!test_noise.png', normalize=True)

##MODE
netD = MultiscaleDiscriminator(args.mask_channels + 3).to(device)
netD.apply(weights_init)

netG = GauGANUnetStylizationGenerator(args.mask_channels, args.encoder_latent_dim, 2, args.unet_ch, device).to(device)
netG.apply(weights_init)

netS1 = StyleEncoder(args.encoder_latent_dim, args.unet_ch, 2).to(device)
netS1.apply(weights_init)

#netS2 = StyleEncoder(args.encoder_latent_dim, args.unet_ch, 2, need_skips=False).to(device)
#netS2.apply(weights_init)
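
# `weights_init` is applied to every network but never defined in these
# examples. Given the DCGAN tutorial cited at the top of train(), it is
# plausibly that tutorial's standard initializer; a sketch under that
# assumption:
import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: N(0, 0.02) conv weights, N(1, 0.02) batch-norm
    # scales, zero batch-norm biases.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)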
Example #3
# NOTE: the header and opening lines of this example are missing in the
# source; the dataset call is reconstructed from its surviving keyword
# arguments and the commented val_dataset line below.
train_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'train'),
                                    transform=transform,
                                    return_masked_image=True)
#val_dataset = DeepFashion2Dataset(os.path.join(args.data_root, 'validation'),  transform=transform, return_masked_image= True )

print('Loading Dataset...')
sys.stdout.flush()
train_loader = data_utils.DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True)
#val_loader = data_utils.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

test_batch = next(iter(train_loader))
fixed_test_images = test_batch[2].to(device)
fixed_test_masks = test_batch[1].to(device)

_ = vutils.save_image(test_batch[0].data[:16], '!test.png', normalize=True)

##MODE
netD = MultiscaleDiscriminator(args.mask_channels + 3).to(device)
netD.apply(weights_init)

netG = GauGANUnetGenerator(args.mask_channels, args.encoder_latent_dim, 2,
                           args.unet_ch).to(device)
netG.apply(weights_init)

netE = UnetEncoder(args.encoder_latent_dim, args.unet_ch, 2).to(device)
netE.apply(weights_init)

vgg = Vgg19Full().to(device)
vgg.eval()
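
# train() sums `losses.perceptual_loss(ff, rf.detach(), args.fm_lambda)` over
# VGG feature levels and over discriminator features (feature matching). A
# minimal sketch, assuming the loss is a weighted L1 distance between one
# pair of feature maps:
import torch.nn.functional as F

def perceptual_loss(fake_feat, real_feat, weight):
    # Weighted L1 between two feature maps; the caller sums over all levels.
    return weight * F.l1_loss(fake_feat, real_feat)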
Example #4
def RL_restore(blurry_tensor,
               initial_output,
               kernels,
               masks,
               n_iters,
               GPU,
               SAVE_INTERMEDIATE=True,
               method='basic',
               saturation_threshold=0.7,
               reg_factor=1e-3,
               gamma_correction_factor=1.0,
               isDebug=False):

    epsilon = 1e-6

    debug_folder = './RL_debug'
    if not os.path.exists(debug_folder):
        os.makedirs(debug_folder)

    blurry_tensor_ph = blurry_tensor**gamma_correction_factor  # to photons space
    output = initial_output
    kernels_flipped = torch.flip(kernels, dims=(2, 3))
    for it in range(n_iters):
        output_ph = output**gamma_correction_factor
        output_reblurred_ph = forward_reblur(output_ph,
                                             kernels,
                                             masks,
                                             GPU,
                                             size='same',
                                             padding_mode='reflect',
                                             manage_saturated_pixels=False,
                                             max_value=1)
        #output_reblurred = output_reblurred_ph**(1.0/gamma_correction_factor)  # to pixels space
        if method == 'basic':
            relative_blur = torch.div(blurry_tensor_ph,
                                      output_reblurred_ph + epsilon)

        elif method == 'function':
            R = apply_saturation_function(output_reblurred_ph, max_value=1)
            R_prima = apply_saturation_function(output_reblurred_ph,
                                                max_value=1,
                                                get_derivative=True)
            relative_blur = torch.div(blurry_tensor_ph * R_prima,
                                      R + epsilon) + 1 - R_prima
        elif method == 'masked':
            mask = blurry_tensor < saturation_threshold
            relative_blur = torch.div(
                blurry_tensor_ph * mask,
                output_reblurred_ph + epsilon) + 1 - mask.float()
        else:
            raise ValueError('unknown saturation method: %s' % method)

        error_estimate = forward_reblur(relative_blur,
                                        kernels_flipped,
                                        masks,
                                        GPU,
                                        size='same',
                                        padding_mode='reflect',
                                        manage_saturated_pixels=False,
                                        max_value=1)

        output_ph = output_ph * error_estimate
        J_reg_grad = reg_factor * normalised_gradient_divergence(output)
        # multiplicative TV-prior correction (regularized Richardson-Lucy)
        output = output_ph**(1.0 / gamma_correction_factor) / (1.0 - J_reg_grad)

        if isDebug:
            # compute ||K*I -B||
            # reblur_loss = model.reblurLoss(2*output_reblurred,  model.real_A[:,:, K//2:-K//2+1, K//2:-K//2+1]) * opt.lambda_reblur
            reblur_loss = torch.mean(
                (output_reblurred_ph**(1.0 / gamma_correction_factor) -
                 blurry_tensor)**2)
            PSNR_reblur = 10 * np.log10(1 / reblur_loss.item())
            print('PSNR_reblur', PSNR_reblur)
            if (((SAVE_INTERMEDIATE and it % np.max([1, n_iters // 10]) == 0)
                 or it == (n_iters - 1))):

                if it == (n_iters - 1):
                    filename = os.path.join(
                        debug_folder, 'iter_%06i_restored.png' % n_iters)
                else:
                    filename = os.path.join(debug_folder,
                                            'iter_%06i.png' % it)

                save_image(tensor2im(output[0].detach().clamp(0, 1) - 0.5),
                           filename)

                print(it, 'PSNR_reblur: ', PSNR_reblur)

    return output
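
# `normalised_gradient_divergence` supplies the total-variation prior of
# regularized Richardson-Lucy, i.e. the div(grad J / |grad J|) term in the
# 1 / (1 - reg_factor * div) correction above. A rough sketch, assuming
# forward differences for the gradient and backward differences for the
# divergence:
import torch

def normalised_gradient_divergence(img, eps=1e-8):
    # img: (B, C, H, W) tensor. Returns div(grad(img) / |grad(img)|).
    gx = torch.zeros_like(img)
    gy = torch.zeros_like(img)
    gx[..., :, :-1] = img[..., :, 1:] - img[..., :, :-1]  # forward diff in x
    gy[..., :-1, :] = img[..., 1:, :] - img[..., :-1, :]  # forward diff in y
    norm = torch.sqrt(gx ** 2 + gy ** 2 + eps)
    nx, ny = gx / norm, gy / norm
    div = torch.zeros_like(img)
    div[..., :, 0] += nx[..., :, 0]                      # x boundary
    div[..., :, 1:] += nx[..., :, 1:] - nx[..., :, :-1]  # backward diff in x
    div[..., 0, :] += ny[..., 0, :]                      # y boundary
    div[..., 1:, :] += ny[..., 1:, :] - ny[..., :-1, :]  # backward diff in y
    return div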
Example #5
            blurry_image = gray2rgb(blurry_image)
        new_shape = (int(args.resize_factor * M), int(args.resize_factor * N),
                     C)
        blurry_image = resize(blurry_image, new_shape).astype(np.float32)

    initial_image = blurry_image.copy()

    blurry_tensor = transforms.ToTensor()(blurry_image)
    blurry_tensor = blurry_tensor[None, :, :, :]
    blurry_tensor = blurry_tensor.cuda(args.gpu_id)

    initial_restoration_tensor = transforms.ToTensor()(initial_image)
    initial_restoration_tensor = initial_restoration_tensor[None, :, :, :]
    initial_restoration_tensor = initial_restoration_tensor.cuda(args.gpu_id)

    save_image(tensor2im(initial_restoration_tensor[0] - 0.5),
               os.path.join(args.output_folder, img_name + '.png'))

    with torch.no_grad():
        blurry_tensor_to_compute_kernels = blurry_tensor**args.gamma_factor - 0.5
        kernels, masks = two_heads(blurry_tensor_to_compute_kernels)
        save_kernels_grid(
            blurry_tensor[0], kernels[0], masks[0],
            os.path.join(args.output_folder, img_name + '_kernels' + '.png'))

    output = initial_restoration_tensor

    with torch.no_grad():

        if args.saturation_method == 'combined':
            output = combined_RL_restore(
                blurry_tensor,