Пример #1
0
def plot_image(epoch, generator, dataloader, dim=(1, 3), figsize=(15, 5)):
    """Save an LR / generated / HR comparison strip for every image in ``dataloader``.

    For each image, the denormalized low-resolution input is upsampled to the
    HR spatial size with nearest-neighbour interpolation. It is concatenated
    along the last (width) axis with the clamped generator output and the HR
    target. The strip is saved under ``args.output_path``, and the file name
    embeds the epoch and the image's PSNR / SSIM.

    Args:
        epoch: Epoch number; only used in the output file name.
        generator: Model called as ``generator(imgs_lr)`` on a whole batch.
        dataloader: Iterable of dicts holding ``"hr"`` and ``"lr"`` batches.
        dim: Unused; kept for backward compatibility.
        figsize: Unused; kept for backward compatibility.
    """
    # Running index over all saved images. The old ``i * len(dataloader) + j``
    # scheme mixed the number of batches with the batch size and produced
    # colliding (overwritten) files whenever the batch size was > 1.
    image_index = 0

    # Plotting only needs forward passes; skip building autograd graphs.
    with torch.no_grad():
        for imgs in tqdm(dataloader):
            imgs_hr = imgs["hr"].type(Tensor)
            imgs_lr = imgs["lr"].type(Tensor)

            # Scale the output into [0, 1] before measuring / saving it.
            gen_hr = generator(imgs_lr).clamp(0, 1)

            # Denormalize the input for display (after the generator has used it).
            imgs_lr = denormalize(imgs_lr)

            # Target size (H, W) for upsampling the LR input; shape[2:4] indexing
            # assumes NCHW batches.
            hr_size = (imgs_hr.shape[2], imgs_hr.shape[3])

            for j in range(imgs_lr.shape[0]):
                psnr_val = psnr_fn(gen_hr[[j]], imgs_hr[[j]]).mean().item()
                ssim_val = ssim_fn(gen_hr[[j]], imgs_hr[[j]]).mean().item()

                # NOTE(review): ':' in the file name is not allowed on Windows.
                file_name = os.path.join(
                    args.output_path, 'generated_image_' + str(image_index) +
                    "_" + str(epoch) + "_PSNR : " + str(round(psnr_val, 2)) +
                    " SSIM : " + str(round(ssim_val, 2)) + '.png')

                lr_image = F.interpolate(imgs_lr[[j]], hr_size,
                                         mode='nearest')[0]

                # Side-by-side strip: LR (upsampled) | generated | HR.
                concat_image = torch.cat((lr_image, gen_hr[j], imgs_hr[j]), 2)
                save_image(concat_image, file_name)

                image_index += 1
Пример #2
0
def plot_image(generators, dataloader, model_names):
    """Save a grid comparing several generators on every image in ``dataloader``.

    Each grid holds, in order: the denormalized LR input (nearest-neighbour
    upsampled to the HR size), the clamped output of each generator, and the
    HR target, laid out two images per row. For every image, one row
    ``[args.dataset_name, model_name, psnr, ssim]`` per model is written with
    ``write_to_csv_file`` to a CSV named like the image.

    Args:
        generators: Models, each called as ``generator(imgs_lr)`` on a batch.
        dataloader: Iterable of dicts holding ``"hr"`` and ``"lr"`` batches.
        model_names: Display name for each generator, in the same order.

    Raises:
        ValueError: If there are more model names than generators.
    """
    generators = list(generators)  # iterated once per batch
    if len(model_names) > len(generators):
        raise ValueError("got %d model names for %d generators" %
                         (len(model_names), len(generators)))

    # Running index over all saved images. The old code derived it from the
    # batch index, which a later loop variable also clobbered, so files
    # collided or were misnumbered.
    image_index = 0

    with torch.no_grad():
        for imgs in tqdm(dataloader):
            imgs_hr = imgs["hr"].type(Tensor)
            imgs_lr = imgs["lr"].type(Tensor)

            # Run every generator once per batch rather than once per image.
            gen_outputs = [generator(imgs_lr).clamp(0, 1)
                           for generator in generators]

            # Denormalize a copy so imgs_lr keeps the generators' input format.
            lr_display = denormalize(imgs_lr.clone())
            hr_size = (imgs_hr.shape[2], imgs_hr.shape[3])

            for j in range(imgs_lr.shape[0]):
                file_name = os.path.join(
                    args.output_path,
                    'generated_image_11_' + str(image_index) + '.png')
                csv_file_name = os.path.join(
                    args.output_path,
                    'generated_image_11_' + str(image_index) + '.csv')

                lr_image = F.interpolate(lr_display[[j]], hr_size,
                                         mode='nearest')[0]
                if args.debug:
                    print("LR Shape", lr_image.shape)

                list_psnr = [psnr_fn(gen_hr[[j]], imgs_hr[[j]]).mean().item()
                             for gen_hr in gen_outputs]
                list_ssim = [ssim_fn(gen_hr[[j]], imgs_hr[[j]]).mean().item()
                             for gen_hr in gen_outputs]

                # Stack LR, every model's output, then HR along the batch axis.
                final_image = torch.cat(
                    [lr_image.unsqueeze(0)] +
                    [gen_hr[[j]] for gen_hr in gen_outputs] +
                    [imgs_hr[[j]]], 0)

                if args.debug:
                    print("Final Shape", final_image.shape)

                grid_image = torchvision.utils.make_grid(final_image, nrow=2)

                if args.debug:
                    print("Grid Shape", grid_image.shape)

                save_image(grid_image, file_name)

                for model_name, psnr_val, ssim_val in zip(model_names,
                                                          list_psnr,
                                                          list_ssim):
                    write_to_csv_file(csv_file_name, [
                        args.dataset_name, model_name, psnr_val, ssim_val
                    ])

                image_index += 1
Пример #3
0
def test(generator, dataloader):
    """Evaluate ``generator`` on ``dataloader`` and log the mean PSNR / SSIM.

    Appends ``[args.dataset_name, args.model_name, args.model_path, psnr, ssim]``
    via ``write_to_csv_file`` to ``test_log.csv`` in ``args.output_path``.
    The metrics are averaged over images: each batch's mean is weighted by
    its batch size, so a short final batch is not over-counted.

    Args:
        generator: Model called as ``generator(imgs_lr)`` on a batch.
        dataloader: Iterable of dicts holding ``"hr"`` and ``"lr"`` batches.

    Raises:
        ValueError: If ``dataloader`` yields no images.
    """
    import torch  # local import so torch.no_grad is available here

    psnr_total = 0.0
    ssim_total = 0.0
    n_images = 0

    # Evaluation only needs forward passes; skip building autograd graphs.
    with torch.no_grad():
        for imgs in tqdm(dataloader):
            imgs_hr = imgs['hr'].type(Tensor)
            imgs_lr = imgs['lr'].type(Tensor)

            # Scale the output into [0, 1] before measuring it.
            gen_hr = generator(imgs_lr).clamp(0, 1)

            batch_size = imgs_hr.shape[0]
            psnr_total += psnr_fn(gen_hr, imgs_hr).mean().item() * batch_size
            ssim_total += ssim_fn(gen_hr, imgs_hr).mean().item() * batch_size
            n_images += batch_size

    if n_images == 0:
        raise ValueError("test dataloader yielded no images")

    psnr_val = psnr_total / n_images
    ssim_val = ssim_total / n_images
    write_to_csv_file(
        os.path.join(args.output_path, 'test_log.csv'),
        [args.dataset_name, args.model_name, args.model_path, psnr_val,
         ssim_val])
Пример #4
0
def train():
    """Train the super-resolution GAN configured by the module-level ``args``.

    Builds the generator and discriminator, plus a second discriminator when
    ``args.method == 'M3'``. Optionally resumes from ``args.checkpoint_path``.
    Then alternates generator and discriminator updates for ``args.epochs``
    epochs.

    While the global batch index is below ``args.warmup_batches``, only the
    generator is trained, with a pixel-wise L1/L2 loss. After warm-up, the
    generator objective is selected by ``args.method``:

    * ``M1``: weighted VGG content loss + weighted adversarial loss.
    * ``M2``: ``hv_loss`` over (weighted adversarial, weighted content).
    * ``M3``: content loss + ``hv_loss`` over the two discriminators'
      adversarial losses.
    * ``M4``: content + adversarial + ``hv_loss`` over
      (1 - PSNR / max_psnr, 1 - SSIM).
    * ``M5``: content + adversarial + linear (1 - PSNR / max_psnr) + (1 - SSIM).
    * ``M6``: adversarial + ``hv_loss`` over PSNR / SSIM of normalized VGG
      features.
    * ``M7``: like ``M4`` for the first ``args.loss_mem`` epochs. After that,
      each term is weighted by the inverse of its mean over the last
      ``args.loss_mem`` epochs.

    ``args.include_l1`` adds an extra weighted L1 pixel loss. Per-epoch
    averages are appended to ``train_log.csv``. Sample images, test metrics
    and checkpoints are produced every ``args.plot_every``,
    ``args.test_every`` and ``args.save_model_every`` epochs respectively.

    Raises:
        ValueError: If ``args.gan``, ``args.vgg_criterion``, ``args.method`` or
            (when used) ``args.warmup_loss`` is not a supported option.
    """
    print("GPU : ", torch.cuda.is_available())

    # These options are otherwise only read after warm-up, so a typo would
    # surface as an UnboundLocalError long into training. Fail fast instead.
    if args.gan not in ('RAGAN', 'VGAN'):
        raise ValueError("Unsupported args.gan: %r" % (args.gan,))
    if args.vgg_criterion not in ('L1', 'L2'):
        raise ValueError("Unsupported args.vgg_criterion: %r" %
                         (args.vgg_criterion,))
    if args.method not in ('M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7'):
        raise ValueError("Unsupported args.method: %r" % (args.method,))

    generator, optimizer_G, scheduler_G = model.get_generator(args)
    generator.to(device)

    discriminator, optimizer_D = model.get_discriminator(args)
    discriminator.to(device)

    if args.method == 'M3':
        discriminator2, optimizer_D2 = model.get_discrminator2(args)
        discriminator2.to(device)

    start_epoch = 0

    if args.resume_training:
        if args.debug:
            print("Resuming Training")
        # map_location lets a checkpoint written on GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(args.checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch']
        generator.load_state_dict(checkpoint['gen_state_dict'])
        optimizer_G.load_state_dict(checkpoint['gen_optimizer_dict'])
        scheduler_G.load_state_dict(checkpoint['gen_scheduler_dict'])
        discriminator.load_state_dict(checkpoint['dis_state_dict'])
        optimizer_D.load_state_dict(checkpoint['dis_optimizer_dict'])
        # Older checkpoints did not store the second discriminator.
        if args.method == 'M3' and 'dis2_state_dict' in checkpoint:
            discriminator2.load_state_dict(checkpoint['dis2_state_dict'])
            optimizer_D2.load_state_dict(checkpoint['dis2_optimizer_dict'])

    feature_extractor = VGGFeatureExtractor().to(device)

    # Set feature extractor to inference mode
    feature_extractor.eval()

    # Losses
    bce_loss = torch.nn.BCEWithLogitsLoss().to(device)
    l1_loss = torch.nn.L1Loss().to(device)
    l2_loss = torch.nn.MSELoss().to(device)

    # equal to negative of hypervolume of input losses
    hv_loss = HVLoss().to(device)

    dataloader = data.dataloader(args)
    test_dataloader = data.dataloader(args, train=False)

    batch_count = len(dataloader)

    generator.train()

    # Number of per-batch statistics accumulated for the chosen method
    # (matches the losses_gen updates at the end of the batch loop).
    loss_len = 6
    if args.method in ('M4', 'M7'):
        loss_len = 7
    elif args.method == 'M6':
        loss_len = 9

    # One row per completed epoch: [epoch, *mean per-batch statistics].
    # A plain list, so a genuinely all-zero row (e.g. an epoch spent entirely
    # in warm-up) is not mistaken for "no history yet".
    loss_history = []

    for epoch in range(start_epoch, start_epoch + args.epochs):
        losses_gen = np.zeros(loss_len)

        # M7 loss memory: once loss_mem epochs have elapsed, weight each term
        # by the inverse of its mean over the last loss_mem epochs. Row
        # columns: 1 = content, 2 = adversarial, 3 = hv, 5 = discriminator.
        use_loss_memory = (args.method == 'M7'
                           and (epoch - start_epoch) >= args.loss_mem)
        if use_loss_memory:
            recent = np.asarray(loss_history[-args.loss_mem:])
            weight_vgg = (1 / recent[:, 1].mean()) * args.mem_vgg_weight
            weight_bce = (1 / recent[:, 2].mean()) * args.mem_bce_weight
            weight_hv = (1 / recent[:, 3].mean()) * args.mem_hv_weight
            weight_dis = (1 / recent[:, 5].mean()) * args.mem_bce_weight

        for i, imgs in tqdm(enumerate(dataloader)):
            batches_done = epoch * len(dataloader) + i

            # Configure model input
            imgs_hr = imgs["hr"].to(device)
            imgs_lr = imgs["lr"].to(device)

            # ------------------
            #  Train Generators
            # ------------------

            # Freeze the discriminator(s) during the generator step.
            for p in discriminator.parameters():
                p.requires_grad = False

            if args.method == 'M3':
                for p in discriminator2.parameters():
                    p.requires_grad = False

            optimizer_G.zero_grad()

            # gen_hr is used unclamped for all training losses below.
            gen_hr = generator(imgs_lr)

            if batches_done < args.warmup_batches:
                # Warm-up: pixel-wise loss against the ground truth only.
                if args.warmup_loss == "L1":
                    loss_pixel = l1_loss(gen_hr, imgs_hr)
                elif args.warmup_loss == "L2":
                    loss_pixel = l2_loss(gen_hr, imgs_hr)
                else:
                    raise ValueError("Unsupported args.warmup_loss: %r" %
                                     (args.warmup_loss,))
                loss_pixel.backward()
                optimizer_G.step()
                if args.debug:
                    print("[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]" %
                          (epoch, args.epochs, i, len(dataloader),
                           loss_pixel.item()))
                continue

            # Extract validity predictions from discriminator
            pred_real = discriminator(imgs_hr).detach()
            pred_fake = discriminator(gen_hr)

            # Adversarial ground truths
            valid = torch.ones_like(pred_real)
            fake = torch.zeros_like(pred_real)

            if args.gan == 'RAGAN':
                # Adversarial loss (relativistic average GAN)
                loss_GAN = bce_loss(
                    pred_fake - pred_real.mean(0, keepdim=True), valid)
            elif args.gan == "VGAN":
                # Adversarial loss (vanilla GAN)
                loss_GAN = bce_loss(pred_fake, valid)

            if args.method == 'M3':
                # Same adversarial loss against the second discriminator.
                pred_real2 = discriminator2(imgs_hr).detach()
                pred_fake2 = discriminator2(gen_hr)

                valid2 = torch.ones_like(pred_real2)
                if args.gan == 'RAGAN':
                    loss_GAN2 = bce_loss(
                        pred_fake2 - pred_real2.mean(0, keepdim=True), valid2)
                elif args.gan == "VGAN":
                    loss_GAN2 = bce_loss(pred_fake2, valid2)

            # Content loss
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr).detach()
            if args.vgg_criterion == 'L1':
                loss_content = l1_loss(gen_features, real_features)
            elif args.vgg_criterion == 'L2':
                loss_content = l2_loss(gen_features, real_features)

            psnr_val = psnr_fn(gen_hr, imgs_hr.detach())
            ssim_val = ssim_fn(gen_hr, imgs_hr.detach())

            # Total generator loss
            if args.method == 'M4':
                loss_hv_psnr = hv_loss(1 - (psnr_val / args.max_psnr),
                                       1 - ssim_val)
                loss_G = (loss_content * args.weight_vgg) + (
                    loss_hv_psnr * args.weight_hv) + (args.weight_gan *
                                                      loss_GAN)
            elif args.method == 'M1':
                loss_G = (loss_content * args.weight_vgg) + (args.weight_gan *
                                                             loss_GAN)
            elif args.method == 'M5':
                psnr_term = (1 - (psnr_val / args.max_psnr)).mean()
                ssim_term = (1 - ssim_val).mean()
                loss_G = (loss_content * args.weight_vgg) + (
                    args.weight_gan * loss_GAN) + (args.weight_pslinear *
                                                   (ssim_term + psnr_term))
            elif args.method == 'M6':
                real_features_normalized = normalize_VGG_features(
                    real_features)
                gen_features_normalized = normalize_VGG_features(gen_features)
                psnr_vgg_val = psnr_fn(gen_features_normalized,
                                       real_features_normalized)
                ssim_vgg_val = ssim_fn_vgg(gen_features_normalized,
                                           real_features_normalized)
                loss_vgg_hv = hv_loss(1 - (psnr_vgg_val / args.max_psnr),
                                      1 - ssim_vgg_val)
                loss_G = (args.weight_vgg_hv *
                          loss_vgg_hv) + (args.weight_gan * loss_GAN)
            elif args.method == 'M7':
                loss_hv_psnr = hv_loss(1 - (psnr_val / args.max_psnr),
                                       1 - ssim_val)
                if use_loss_memory:
                    loss_G = (loss_content * weight_vgg) + (
                        loss_hv_psnr * weight_hv) + (loss_GAN * weight_bce)
                else:
                    loss_G = (loss_content * args.weight_vgg) + (
                        loss_hv_psnr * args.weight_hv) + (args.weight_gan *
                                                          loss_GAN)
            elif args.method == "M2":
                loss_G = hv_loss(loss_GAN * args.weight_gan,
                                 loss_content * args.weight_vgg)
            elif args.method == 'M3':
                loss_G = (args.weight_vgg * loss_content) + (
                    args.weight_hv * hv_loss(loss_GAN * args.weight_gan,
                                             loss_GAN2 * args.weight_gan))

            if args.include_l1:
                loss_G += (args.weight_l1 * l1_loss(gen_hr, imgs_hr))

            loss_G.backward()
            optimizer_G.step()
            scheduler_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            for p in discriminator.parameters():
                p.requires_grad = True

            if args.method == 'M3':
                for p in discriminator2.parameters():
                    p.requires_grad = True

                pred_real2 = discriminator2(imgs_hr)
                pred_fake2 = discriminator2(gen_hr.detach())
                valid2 = torch.ones_like(pred_real2)
                fake2 = torch.zeros_like(pred_real2)
                if args.gan == "RAGAN":
                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real2 = bce_loss(
                        pred_real2 - pred_fake2.mean(0, keepdim=True), valid2)
                    loss_fake2 = bce_loss(
                        pred_fake2 - pred_real2.mean(0, keepdim=True), fake2)
                elif args.gan == "VGAN":
                    # Adversarial loss for real and fake images (vanilla GAN)
                    loss_real2 = bce_loss(pred_real2, valid2)
                    loss_fake2 = bce_loss(pred_fake2, fake2)

                optimizer_D2.zero_grad()
                loss_D2 = (loss_real2 + loss_fake2) / 2
                loss_D2.backward()
                optimizer_D2.step()

            optimizer_D.zero_grad()

            pred_real = discriminator(imgs_hr)
            pred_fake = discriminator(gen_hr.detach())

            if args.gan == "RAGAN":
                # Adversarial loss for real and fake images (relativistic average GAN)
                loss_real = bce_loss(
                    pred_real - pred_fake.mean(0, keepdim=True), valid)
                loss_fake = bce_loss(
                    pred_fake - pred_real.mean(0, keepdim=True), fake)
            elif args.gan == "VGAN":
                # Adversarial loss for real and fake images (vanilla GAN)
                loss_real = bce_loss(pred_real, valid)
                loss_fake = bce_loss(pred_fake, fake)

            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            if use_loss_memory:
                loss_D = loss_D / weight_dis

            loss_D.backward()
            optimizer_D.step()

            # Per-batch statistics; the column order defines loss_history rows.
            if args.method == "M4" or args.method == "M7":
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_hv_psnr.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])
            elif args.method == "M6":
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    psnr_vgg_val.mean().item(),
                    ssim_vgg_val.mean().item(),
                    loss_vgg_hv.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])
            else:
                losses_gen += np.array([
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_G.item(),
                    loss_D.item(),
                    psnr_val.mean().item(),
                    ssim_val.mean().item(),
                ])

        # NOTE(review): warm-up batches add nothing to losses_gen but are still
        # counted in batch_count, so partially warm-up epochs average low.
        losses_gen /= batch_count
        epoch_row = [epoch] + list(losses_gen)

        write_to_csv_file(os.path.join(args.output_path, 'train_log.csv'),
                          epoch_row)
        loss_history.append(epoch_row)

        if epoch % args.print_every == 0:
            print('Epoch', epoch, 'Loss GAN :', epoch_row)

        if epoch % args.plot_every == 0:
            plot_image(epoch, generator, test_dataloader)

        if epoch % args.test_every == 0:
            # NOTE(review): confirm this matches the signature of the test()
            # defined alongside this function.
            test(epoch, generator, test_dataloader)

        if epoch % args.save_model_every == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'gen_state_dict': generator.state_dict(),
                'gen_optimizer_dict': optimizer_G.state_dict(),
                'gen_scheduler_dict': scheduler_G.state_dict(),
                'dis_state_dict': discriminator.state_dict(),
                'dis_optimizer_dict': optimizer_D.state_dict(),
            }
            # Without these, resuming M3 silently re-initialises discriminator2.
            if args.method == 'M3':
                checkpoint['dis2_state_dict'] = discriminator2.state_dict()
                checkpoint['dis2_optimizer_dict'] = optimizer_D2.state_dict()
            os.makedirs(os.path.join(args.output_path, 'saved_model'),
                        exist_ok=True)
            torch.save(
                checkpoint,
                os.path.join(args.output_path, 'saved_model',
                             'checkpoint_' + str(epoch) + ".pth"))