Esempio n. 1
0
def test_ssim_single_ch_identical():
    """Check the torch SSIM (structural similarity) metric on grayscale images.

    Every test image is converted to grayscale and compared with itself;
    the resulting SSIM must equal 1.
    """
    metric = SSIM(data_range=255, channels=1)
    for rgb in _load_test_images():
        gray = color.rgb2gray(rgb)
        # (H, W) -> (1, 1, H, W): prepend the batch and channel axes
        batch = torch.Tensor(gray)[None, None]

        score = metric(batch, batch).numpy()
        np.testing.assert_allclose(score, 1., rtol=RTOL)
Esempio n. 2
0
def test_ssim_single_channel():
    """Compare the torch SSIM (structural similarity) metric with the reference on grayscale images.

    Each test image is converted to grayscale and degraded with
    ``_corrupt_image``; the torch SSIM of the (clean, degraded) pair must match
    ``_groundtruth_ssim``.
    """
    metric = SSIM(data_range=255, channels=1)
    for rgb in _load_test_images():
        clean = color.rgb2gray(rgb)
        degraded = _corrupt_image(clean)
        # (H, W) -> (1, 1, H, W): prepend the batch and channel axes
        clean_batch = torch.Tensor(clean)[None, None]
        degraded_batch = torch.Tensor(degraded)[None, None]

        actual = metric(clean_batch, degraded_batch).numpy()
        expected = _groundtruth_ssim(clean, degraded, multichannel=False)
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
Esempio n. 3
0
def test_ssim_multi_ch_identical():
    """Check the torch SSIM (structural similarity) metric on multi-channel images.

    Every test image (the metric is configured for 4 channels) is compared
    with itself; the resulting SSIM must equal 1.
    """
    metric = SSIM(data_range=255, channels=4)
    for picture in _load_test_images():
        # HWC -> CHW, then add the batch axis to obtain NCHW
        batch = torch.Tensor(picture).permute(2, 0, 1).unsqueeze(0)

        score = metric(batch, batch).numpy()
        np.testing.assert_allclose(score, 1., rtol=RTOL)
Esempio n. 4
0
def test_ssim_multi_channel():
    """Compare the torch SSIM (structural similarity) metric with the reference on multi-channel images.

    Each test image is degraded with ``_corrupt_image``; the torch SSIM of the
    (clean, degraded) pair must match ``_groundtruth_ssim`` in multichannel
    mode.
    """
    metric = SSIM(data_range=255, channels=4)
    for picture in _load_test_images():
        degraded = _corrupt_image(picture)
        # HWC -> CHW, then add the batch axis to obtain NCHW
        clean_batch = torch.Tensor(picture).permute(2, 0, 1).unsqueeze(0)
        degraded_batch = torch.Tensor(degraded).permute(2, 0, 1).unsqueeze(0)

        actual = metric(clean_batch, degraded_batch).numpy()
        expected = _groundtruth_ssim(picture, degraded, multichannel=True)
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
Esempio n. 5
0
    def validate(self, val_batch, current_step, crop_size=4):
        """Run the generator on a validation set and return mean PSNR / SSIM.

        For every sample, ``self.G`` is run in eval mode (without gradients)
        on ``val_data['LR']``. The first SR image of the batch is written to
        ``<val_image_dir>/<img_name>/<img_name>_<current_step>.png``. It is
        then compared with the first HR image after ``crop_size`` pixels have
        been removed from every spatial border.

        Args:
            val_batch: iterable of dicts providing at least ``'LR'``, ``'HR'``
                and ``'LR_path'``.
            current_step: integer step number used in the saved file name.
            crop_size: border width (pixels) removed on each side before the
                metrics are computed. Defaults to 4, the previously hard-coded
                value. ``0`` disables cropping.

        Returns:
            Tuple ``(avg_psnr, avg_ssim)`` averaged over all samples, or
            ``(0.0, 0.0)`` when ``val_batch`` yields no samples (previously a
            ZeroDivisionError).
        """
        total_psnr = 0.0
        total_ssim = 0.0
        count = 0
        # slice(c, -c) is empty for c == 0, so fall back to the full axis.
        border = slice(crop_size, -crop_size) if crop_size > 0 else slice(None)

        for val_data in val_batch:
            count += 1
            img_name = os.path.splitext(
                os.path.basename(val_data['LR_path'][0]))[0]
            img_dir = os.path.join(
                self.opt['path']['checkpoints']['val_image_dir'], img_name)
            util.mkdir(img_dir)

            self.val_lr = val_data['LR'].to(self.device)
            self.val_hr = val_data['HR'].to(self.device)

            self.G.eval()
            try:
                with torch.no_grad():
                    self.val_sr = self.G(self.val_lr)
            finally:
                # Restore training mode even if the forward pass raises.
                self.G.train()

            sr_img = util.tensor2img(self.val_sr.detach()[0].float().cpu())  # uint8
            gt_img = util.tensor2img(self.val_hr.detach()[0].float().cpu())  # uint8

            # Save SR images for reference
            save_img_path = os.path.join(
                img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            cv2.imwrite(save_img_path, sr_img)

            # Crop the borders and compute the metrics on the 0-255 float
            # scale. ``...`` keeps an optional trailing channel axis, so both
            # (H, W) and (H, W, C) images are supported.
            cropped_sr_img = sr_img[border, border, ...].astype(float)
            cropped_gt_img = gt_img[border, border, ...].astype(float)
            total_psnr += PSNR(cropped_sr_img, cropped_gt_img)
            total_ssim += SSIM(cropped_sr_img, cropped_gt_img)

        if count == 0:
            return 0.0, 0.0
        return total_psnr / count, total_ssim / count
Esempio n. 6
0
def test(args, model, test_dataloader):
    """Run super-resolution on ``test_dataloader`` and save / report results.

    For each sample, the centre image is processed as follows:

    - expanded to ``args.batch_size``;
    - cropped to ``args.im_crop_H`` x ``args.im_crop_W``;
    - downsampled with ``avg_pool2d(kernel_size=args.scale)``;
    - passed through ``net_Feature_Head``, then ``net_Feature_extractor``
      (with a residual connection), then ``net_Upscalar``, then ``net_G1``.

    The first SR image of the batch, clipped to [-1, 1], is written to
    ``<result_dir>/tempo/SR_<i>.png``.

    When ``args.have_gt`` is set, PSNR and SSIM between the SR output and the
    cropped centre image are printed per sample, together with running
    averages over the samples seen so far. When ``args.save_result`` is set,
    the HR, LR, adversarial and reference images are also written below
    ``args.result_dir``.

    Args:
        args: options namespace. Reads ``batch_size``, ``im_crop_H``,
            ``im_crop_W``, ``scale``, ``result_dir``, ``have_gt`` and
            ``save_result``.
        model: object exposing ``net_Feature_Head``,
            ``net_Feature_extractor``, ``net_Upscalar`` and ``net_G1``.
        test_dataloader: iterable of dicts with the keys ``'image_center'``,
            ``'image_others'``, ``'img_adv_cen'`` and ``'img_adv_ref'``.
    """
    PSNR_total = []
    SSIM_total = []
    # NOTE(review): model.eval() is not called (it was commented out
    # originally), so the sub-networks run in whatever mode they are
    # currently in -- confirm this is intended.
    print('=====> test sr begin!')

    # skimage.io.imsave does not create missing folders; create them once
    # here instead of on every iteration.
    os.makedirs(os.path.join(args.result_dir, 'tempo'), exist_ok=True)
    if args.save_result:
        _make_test_result_dirs(args.result_dir, args.batch_size)

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            # Replicate the centre / adversarial-centre images across the
            # batch. The reference stacks drop their leading loader axis.
            img_ref = data['image_center'].expand(args.batch_size, -1, -1, -1)
            img_adv_cen = data['img_adv_cen'].expand(
                args.batch_size, -1, -1, -1)
            img_oth = data['image_others'].squeeze(0)
            img_adv_ref = data['img_adv_ref'].squeeze(0)

            image_others = (img_oth.cuda())[:, :, :args.im_crop_H, :args.
                                            im_crop_W].clone().float()
            image_ref = (img_ref.cuda())[:, :, :args.im_crop_H, :args.
                                         im_crop_W].clone().float()
            lr_image_ref = nn.functional.avg_pool2d(image_ref,
                                                    kernel_size=args.scale)
            image_adv_cen = img_adv_cen.cuda().clone().float()
            image_adv_ref = img_adv_ref.cuda().clone().float()

            lr_feature_head = model.net_Feature_Head(lr_image_ref)
            lr_content_feature = model.net_Feature_extractor(lr_feature_head)
            lr_content_output = lr_feature_head + lr_content_feature
            hr_val = model.net_Upscalar(lr_content_output)
            hr_val = model.net_G1(hr_val)

            # img_as_ubyte rejects float images outside [-1, 1], so clip first.
            hr_val_numpy = hr_val.cpu()[0].permute(1, 2, 0).numpy()
            hr_val_numpy[hr_val_numpy > 1] = 1
            hr_val_numpy[hr_val_numpy < -1] = -1

            img_sr = skimage.img_as_ubyte(hr_val_numpy)
            skimage.io.imsave(
                os.path.join(args.result_dir, 'tempo', 'SR_{}.png'.format(i)),
                img_sr)

            if args.have_gt:
                PSNR_value = PSNR(hr_val.data, image_ref)
                SSIM_value = SSIM(hr_val.data, image_ref)
                PSNR_total.append(PSNR_value)
                SSIM_total.append(SSIM_value)

                # Report the number of samples actually averaged, not the
                # 0-based loop index.
                n_seen = len(PSNR_total)
                print('PSNR: {} for patch {}'.format(PSNR_value, i))
                print('SSIM: {} for patch {}'.format(SSIM_value, i))
                print('Average PSNR: {} for {} patches'.format(
                    sum(PSNR_total) / n_seen, n_seen))
                print('Average SSIM: {} for {} patches'.format(
                    sum(SSIM_total) / n_seen, n_seen))

            if args.save_result:
                _save_test_result_images(args.result_dir, i, args.batch_size,
                                         img_ref, lr_image_ref, image_adv_cen,
                                         image_adv_ref, image_others)


def _make_test_result_dirs(result_dir, batch_size):
    """Create the folders written by ``_save_test_result_images``."""
    for sub_dir in ('HR', 'LR', 'ADV_CEN'):
        os.makedirs(os.path.join(result_dir, sub_dir), exist_ok=True)
    for j in range(batch_size):
        os.makedirs(os.path.join(result_dir, 'ADV_REF', '{}'.format(j)),
                    exist_ok=True)
        os.makedirs(os.path.join(result_dir, 'REF', '{}'.format(j)),
                    exist_ok=True)


def _chw_to_ubyte(tensor):
    """Convert a CHW tensor (CPU or GPU) into an HWC uint8 numpy image."""
    return skimage.img_as_ubyte(tensor.cpu().permute(1, 2, 0).numpy())


def _save_test_result_images(result_dir, i, batch_size, hr_images, lr_images,
                             adv_center, adv_refs, others):
    """Write the HR / LR / adversarial / reference images of sample ``i``."""
    file_name = '{}.png'.format(i)
    # Index the first batch element instead of squeezing, so that an
    # expanded batch (batch_size > 1) still yields a single CHW image.
    skimage.io.imsave(os.path.join(result_dir, 'HR', file_name),
                      _chw_to_ubyte(hr_images[0]))
    skimage.io.imsave(os.path.join(result_dir, 'LR', file_name),
                      _chw_to_ubyte(lr_images[0]))
    skimage.io.imsave(os.path.join(result_dir, 'ADV_CEN', file_name),
                      _chw_to_ubyte(adv_center[0]))
    for j in range(batch_size):
        skimage.io.imsave(
            os.path.join(result_dir, 'ADV_REF', '{}'.format(j), file_name),
            _chw_to_ubyte(adv_refs[j]))
        skimage.io.imsave(
            os.path.join(result_dir, 'REF', '{}'.format(j), file_name),
            _chw_to_ubyte(others[j]))
Esempio n. 7
0
def test_lr(args, model, test_dataloader):
    """Super-resolve pre-existing low-resolution images and report metrics.

    For each sample, ``data['lr_image']`` is expanded to ``args.batch_size``
    and moved to the GPU. It is then super-resolved as
    ``net_G1(net_sr(lr) + upsample_4(lr))``. The first SR image of the batch,
    clipped to [-1, 1], is written to ``<result_dir>/SR/SR_<i>.png``.

    When ``args.have_gt`` is set, PSNR and SSIM against ``data['hr_image']``
    are printed per sample, together with running averages over the samples
    seen so far.

    Args:
        args: options namespace. Reads ``batch_size``, ``result_dir`` and
            ``have_gt``.
        model: object exposing ``net_sr``, ``upsample_4`` and ``net_G1``.
        test_dataloader: iterable of dicts with the key ``'lr_image'`` (and
            ``'hr_image'`` when ``args.have_gt`` is set).
    """
    # NOTE(review): model.eval() is not called (it was commented out
    # originally) -- confirm the sub-networks are meant to run in their
    # current mode.
    print('=====> test existing lr begin!')
    PSNR_total = []
    SSIM_total = []

    # skimage.io.imsave does not create missing folders.
    os.makedirs(os.path.join(args.result_dir, 'SR'), exist_ok=True)

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            img_lr = data['lr_image'].expand(args.batch_size, -1, -1, -1)
            img_lr = img_lr.cuda().clone().float()

            # Residual SR on top of the upsampled input, refined by net_G1.
            hr_val = model.net_sr(img_lr) + model.upsample_4(img_lr)
            hr_val = model.net_G1(hr_val)

            # img_as_ubyte rejects float images outside [-1, 1], so clip first.
            hr_val_numpy = hr_val.cpu()[0].permute(1, 2, 0).numpy()
            hr_val_numpy[hr_val_numpy > 1] = 1
            hr_val_numpy[hr_val_numpy < -1] = -1

            img_sr = skimage.img_as_ubyte(hr_val_numpy)
            skimage.io.imsave(
                os.path.join(args.result_dir, 'SR', 'SR_{}.png'.format(i)),
                img_sr)

            if args.have_gt:
                img_hr = data['hr_image'].expand(args.batch_size, -1, -1, -1)
                img_hr = img_hr.cuda().clone().float()

                PSNR_value = PSNR(hr_val.data, img_hr)
                SSIM_value = SSIM(hr_val.data, img_hr)
                PSNR_total.append(PSNR_value)
                SSIM_total.append(SSIM_value)

                # Report the number of samples actually averaged, not the
                # 0-based loop index.
                n_seen = len(PSNR_total)
                print('PSNR: {} for patch {}'.format(PSNR_value, i))
                print('SSIM: {} for patch {}'.format(SSIM_value, i))
                print('Average PSNR: {} for {} patches'.format(
                    sum(PSNR_total) / n_seen, n_seen))
                print('Average SSIM: {} for {} patches'.format(
                    sum(SSIM_total) / n_seen, n_seen))