Пример #1
0
def ssim_skimage(_ndarr_input,
                 _ndarr_ref,
                 _multichannel=False,
                 _win_size=11,
                 _K1=0.01,
                 _K2=0.03,
                 _sigma=1.5,
                 _R=255.0):
    """
    Score two images with ``SSIM`` after dividing both by ``_R``.

    Both arrays are rescaled by ``_R`` and compared with ``data_range=1.0``
    and Gaussian weighting enabled.

    :param _ndarr_input:  image under test; expected range (0, 255),
                          dtype np.float32 (uint8 is rejected)
    :param _ndarr_ref:    reference image; same expectations as above
    :param _multichannel: True or False. Overridden to True when both arrays
                          have at least 3 dims and ``shape[2] >= 2``, and to
                          False when both arrays are 2-D; otherwise the
                          caller's value is used
    :param _win_size:     forwarded to ``SSIM`` as ``win_size``
    :param _K1:           forwarded to ``SSIM`` as ``K1``
    :param _K2:           forwarded to ``SSIM`` as ``K2``
    :param _sigma:        forwarded to ``SSIM`` as ``sigma``
    :param _R:            divisor applied to both arrays before scoring
    :return:              SSIM value
    :raises ValueError:   if either array has dtype uint8
    """
    pair = (_ndarr_input, _ndarr_ref)

    if any(arr.dtype == 'uint8' for arr in pair):
        raise ValueError('The ndarray.dtype should not be uint8.')

    def _looks_multichannel(arr):
        # A third axis holding at least two entries is treated as channels.
        return arr.ndim >= 3 and arr.shape[2] >= 2

    if all(_looks_multichannel(arr) for arr in pair):
        _multichannel = True
    elif all(arr.ndim == 2 for arr in pair):
        _multichannel = False

    scaled_input = _ndarr_input / _R
    scaled_ref = _ndarr_ref / _R
    return SSIM(scaled_input,
                scaled_ref,
                multichannel=_multichannel,
                win_size=_win_size,
                data_range=1.0,
                gaussian_weights=True,
                K1=_K1,
                K2=_K2,
                sigma=_sigma)
def test(model, testloader, save_prefix, learned, testing, device):
    """Reconstruct every test batch, save the images and print mean PSNR/SSIM.

    For each ``(data, target)`` batch the model is called as
    ``model(data, target)`` under ``torch.no_grad()``; the first returned
    value is taken as the reconstruction. Inputs and reconstructions are
    scaled by 255 and cast to uint8, then channel 0 of every sample is scored
    with ``PSNR`` / ``SSIM`` and saved as a greyscale PNG named
    ``<save_prefix><label>_<running index>.png``.

    :param model:       network to evaluate; switched to eval mode
    :param testloader:  iterable yielding ``(data, target)`` batches
    :param save_prefix: string prepended to every output file name
    :param learned:     first value shown in the ``[%d, %d]`` report prefix
    :param testing:     second value shown in the ``[%d, %d]`` report prefix
    :param device:      device the batches are moved to
    :return:            None; results are printed and written to disk
    """
    model.eval()

    psnr_total = 0.
    ssim_total = 0.
    # Number of images scored so far. It is both the averaging denominator
    # and the running file index, so a smaller final batch neither skews the
    # mean nor reuses an index that an earlier batch already wrote.
    num_images = 0

    for data, target in testloader:
        # Index of the largest target entry per sample, used as the label in
        # the output file name.
        name = target.topk(1)[1]
        data = data.to(device)
        target = target.to(device)

        with torch.no_grad():
            reconstructed_x, _, _ = model(data, target)
        reconstructed_x = reconstructed_x.cpu().numpy()
        reconstructed_x = (reconstructed_x * 255).astype(np.uint8)
        original_x = (data.cpu().numpy() * 255).astype(np.uint8)

        for j in range(reconstructed_x.shape[0]):
            psnr_total += PSNR(original_x[j, 0], reconstructed_x[j, 0])
            ssim_total += SSIM(original_x[j, 0], reconstructed_x[j, 0])
            im = Image.fromarray(reconstructed_x[j, 0])
            im = im.convert("L")
            im.save(save_prefix + "{}_{}.png".format(
                str(name[j].item()), str(num_images)))
            num_images += 1

    if num_images == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('[%d, %d] no test images were evaluated' % (learned, testing))
        return

    # The sums are accumulated per image, so average over images rather than
    # over batches.
    psnr = psnr_total / num_images
    ssim = ssim_total / num_images
    print('[%d, %d] PSNR: %.4f, SSIM %.4f' % (learned, testing, psnr, ssim))
    return
Пример #3
0
def get_basic_meterics(img,img_gt):
    """Convert both tensors with ``tensor2image`` and score the pair.

    :param img:    tensor to evaluate
    :param img_gt: ground-truth tensor
    :return:       tuple ``(converted img, converted img_gt, psnr, ssim)``;
                   SSIM is computed with ``multichannel=True``
    """
    converted = tensor2image(img)
    converted_gt = tensor2image(img_gt)

    psnr_value = PSNR(converted, converted_gt)
    ssim_value = SSIM(converted, converted_gt, multichannel=True)

    return converted, converted_gt, psnr_value, ssim_value
Пример #4
0
 def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
     """Score a prediction against its target and build a comparison strip.

     ``inp``, ``output.data`` and ``target.data`` are each converted with
     ``self.tensor2im``. PSNR and SSIM (``multichannel=True``) compare the
     converted output with the converted target.

     Returns ``(psnr, ssim, vis_img)`` where ``vis_img`` is the input,
     prediction and target images stacked with ``np.hstack`` in that order.
     """
     input_img, fake_img, real_img = (
         self.tensor2im(inp),
         self.tensor2im(output.data),
         self.tensor2im(target.data),
     )

     psnr_value = PSNR(fake_img, real_img)
     ssim_value = SSIM(fake_img, real_img, multichannel=True)

     side_by_side = np.hstack((input_img, fake_img, real_img))
     return psnr_value, ssim_value, side_by_side
Пример #5
0
def compute_measures(img_path, lab_path, scale, mode='rgb'):
    """Load an image and its label image and score them with PSNR and SSIM.

    :param img_path: path of the image to evaluate
    :param lab_path: path of the label (reference) image
    :param scale:    when greater than 1, both images go through
                     ``imgcrop`` and then ``shave`` with this value
    :param mode:     ``'ycbcr'`` keeps only channel 0 of the
                     ``color.rgb2ycbcr`` conversion; any other value leaves
                     the loaded images unchanged
    :return:         tuple ``(process_info, psnr, ssim)`` where
                     ``process_info`` is a progress string naming the file
    """
    img, lab = io.imread(img_path), io.imread(lab_path)

    if mode == 'ycbcr':
        img, lab = (color.rgb2ycbcr(im)[:, :, 0] for im in (img, lab))

    if scale > 1:
        img, lab = (shave(imgcrop(im, scale), scale) for im in (img, lab))

    # A 2-D label is scored as single-channel; anything else as multichannel.
    multichannel = len(lab.shape) != 2
    psnr = PSNR(img, lab, data_range=255)
    ssim = SSIM(img, lab, data_range=255, multichannel=multichannel)

    file_name = img_path.split('/')[-1]
    process_info = 'Processing {} ...'.format(file_name)

    return process_info, psnr, ssim
Пример #6
0
    def optimize_parameters(self, r_low, g1_low, g2_low, b_low, gt,
                            for_amplifier):
        """Run one training step, with periodic logging, plots and checkpoints.

        The inputs are moved to ``self.device`` and the model is called as
        ``self.model(r_low, g1_low, g2_low, b_low, for_amplifier)``, returning
        the prediction and ``gamma``. Prediction and ground truth are unpacked
        from 3-channel groups into 1x3x512x512 images, the total loss is stored
        in ``self.loss`` and back-propagated, and ``self.count`` is incremented.

        All options are read from ``self.opt`` (keys ``patch``, ``switch``,
        ``text_prnt_freq``, ``fig_freq``, ``save_freq``, ``dir_root``,
        ``exp_name``) rather than from a module-level ``opt``.

        :param r_low:         first model input
        :param g1_low:        second model input
        :param g2_low:        third model input
        :param b_low:         fourth model input
        :param gt:            ground truth, indexed in 3-channel groups on
                              axis 1
        :param for_amplifier: fifth model input
        :return:              None
        """
        r_low = r_low.to(self.device)
        g1_low = g1_low.to(self.device)
        g2_low = g2_low.to(self.device)
        b_low = b_low.to(self.device)
        gt = gt.to(self.device)
        for_amplifier = for_amplifier.to(self.device)

        self.model.train()
        self.optimizer.zero_grad()

        pred_output, gamma = self.model(r_low, g1_low, g2_low, b_low,
                                        for_amplifier)

        # Unpack: each consecutive group of 3 channels is scattered onto the
        # strided grid positions (ii::8, jj::8) of the full-size image. The
        # loop form is kept for readability.
        # NOTE(review): the canvas is hard-coded to 512x512 while the slices
        # stop at opt['patch'] - confirm the two are meant to be equal.
        patch = self.opt['patch']
        plot_out_GT = torch.zeros(1, 3, 512, 512,
                                  dtype=torch.float).to(self.device)
        plot_out_pred = torch.zeros(1, 3, 512, 512,
                                    dtype=torch.float).to(self.device)
        channel = 0
        for ii in range(8):
            for jj in range(8):
                plot_out_GT[:, :, ii:patch:8, jj:patch:8] = \
                    gt[:, channel:channel + 3, :, :]
                plot_out_pred[:, :, ii:patch:8, jj:patch:8] = \
                    pred_output[:, channel:channel + 3, :, :]
                channel += 3

        plot_out_pred = self.relu(plot_out_pred)

        blurLoss = self.criterion(plot_out_GT, plot_out_pred)

        # Sum of absolute parameter values (L1 penalty on the weights).
        regularization_loss = 0
        for param in self.model.parameters():
            regularization_loss += torch.sum(torch.abs(param))

        # The two training phases differ only in the TV and colour weights.
        if self.count < self.opt['switch']:
            tv_weight, color_weight = 1000, 10
        else:
            tv_weight, color_weight = 400, 1

        target = plot_out_GT.detach()
        self.loss = (blurLoss) + (1e-6 * regularization_loss) + (
            tv_weight * self.TVLoss(plot_out_pred)) + (
                color_weight * self.blur_color_loss(target, plot_out_pred)
            ) + (3 * self.perceptual_loss(target, plot_out_pred))

        if self.count % self.opt['text_prnt_freq'] == 0:
            print('Count : {}\n'.format(self.count))
            print(gamma)

        self.loss.backward()
        self.optimizer.step()

        self.count += 1
        if self.count % 10 == 0:
            print(self.count)

        if self.count % self.opt['fig_freq'] == 0:

            def to_uint8_image(batch):
                # First sample, (3, H, W) -> (H, W, 3), clipped to [0, 1] and
                # rescaled to 8-bit.
                hwc = batch[0].detach().cpu().numpy().transpose(1, 2, 0)
                return (np.clip(hwc, 0, 1) * 255).astype(np.uint8)

            plot_out_pred = to_uint8_image(plot_out_pred)
            plot_out_GT = to_uint8_image(plot_out_GT)

            print("PSNR: {0:.3f}, SSIM: {1:.3f}, RMSE:{2:.3f}".format(
                PSNR(plot_out_GT, plot_out_pred),
                SSIM(plot_out_GT, plot_out_pred, multichannel=True),
                NRMSE(plot_out_GT, plot_out_pred)))

            imageio.imwrite('pred_{}.jpg'.format(self.count), plot_out_pred)
            imageio.imwrite('GT_{}.jpg'.format(self.count), plot_out_GT)

            print(gamma)

        # 'save_freq' is a collection of step counts at which to checkpoint.
        if self.count in self.opt['save_freq']:
            torch.save(
                {
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }, self.opt['dir_root'] + 'weights/' + self.opt['exp_name'] +
                '_{}'.format(self.count))
Пример #7
0
            # Canvas that the processed blocks are written back into.
            y = torch.zeros(batch_size, args.channels, h, w)

            # Copy block number `count` of `temp` into the block_size-wide
            # window starting at (a, b), walking idx1 in the outer loop and
            # idx2 in the inner loop.
            count = 0
            for a in idx1:
                for b in idx2:
                    y[:, :, a:a + args.block_size,
                      b:b + args.block_size] = temp[count, :, :, :, :]
                    count = count + 1

            # Drop the last h_lack rows and w_lack columns (presumably padding
            # added before blocking - TODO confirm), then remove size-1 axes.
            image = y[:, :, 0:h - h_lack, 0:w - w_lack]
            image = torch.squeeze(image)

            # PSNR computed by hand with a peak value of 1; identical images
            # give mse == 0 and therefore an infinite PSNR.
            diff = image.numpy() - ori_x.numpy()
            mse = np.mean(np.square(diff))
            psnr = 10 * np.log10(1 / mse)
            PSNR_total = PSNR_total + psnr

            ssim = SSIM(image.numpy(), ori_x.numpy(), data_range=1)
            SSIM_total = SSIM_total + ssim

            # Save the reconstruction under the current index `i`.
            image = tensor2image(image)
            image.save("./dataset/result/image/({}).jpg".format(i))
            print(
                "=> process {} done! time: {:.3f}s, PSNR: {:.3f}, SSIM: {:.3f}."
                .format(i, end_time - start_time, psnr, ssim))

        # Summary: the accumulated totals are averaged over File_No.
        print(
            "=> All the {} images done!, your AVG PSNR: {:.3f}, AVG SSIM: {:.3f}."
            .format(File_No, PSNR_total / File_No, SSIM_total / File_No))
Пример #8
0
    def optimize_parameters(self, low, gt, for_amplifier):
        """Evaluate one sample and update the running PSNR/SSIM averages.

        Despite its name this method performs no optimisation: the model is
        put in eval mode and called under ``torch.no_grad()``. The first
        sample of the prediction and of ``gt`` is clipped to [0, 1], converted
        to a uint8 image and scored with ``PSNR``, ``SSIM``
        (``multichannel=True``) and ``NRMSE``. The scores and the running
        means are printed, followed by ``gamma``.

        Side effects: increments ``self.count`` and adds the new scores to
        ``self.psnr`` and ``self.ssim``.

        :param low:           model input
        :param gt:            ground truth compared against the prediction
        :param for_amplifier: second model input
        :return:              None
        """
        low = low.to(self.device)
        gt = gt.to(self.device)
        for_amplifier = for_amplifier.to(self.device)

        self.model.eval()

        with torch.no_grad():
            pred, gamma = self.model(low, for_amplifier)

        self.count += 1

        def to_uint8_image(batch):
            # First sample of the batch with axis 0 moved to the end
            # (presumably CHW -> HWC), clipped to [0, 1] and rescaled to 8-bit.
            arr = batch[0].detach().cpu().numpy().transpose(1, 2, 0)
            return (np.clip(arr, 0, 1) * 255).astype(np.uint8)

        plot_out_pred = to_uint8_image(pred)
        plot_out_GT = to_uint8_image(gt)

        psnrr = PSNR(plot_out_GT, plot_out_pred)
        ssimm = SSIM(plot_out_GT, plot_out_pred, multichannel=True)

        print("PSNR: {0:.3f}, SSIM: {1:.3f}, RMSE:{2:.3f}".format(
            psnrr, ssimm, NRMSE(plot_out_GT, plot_out_pred)))

        self.psnr += psnrr
        self.ssim += ssimm

        # Running means over every sample evaluated so far.
        print('Mean PSNR: {}'.format(self.psnr / self.count))
        print('Mean SSIM: {}'.format(self.ssim / self.count))

        print(gamma)
Пример #9
0
    def run_test(self):
        """Decompose up to 10 test images and report runtime, PSNR and SSIM.

        Images are read from ``<input_data_dir>/<dataset>/HR`` in sorted
        file-name order. For each one, ``self.read_image`` returns an
        observed image and a ground-truth image; both are run through the
        same decomposition graph, and each of the ``self.num_decomp + 1``
        output slices is compared between the two with ``PSNR`` and ``SSIM``
        (``data_range=1.0``). Mean and standard deviation of the per-image
        runtime and the per-slice average metrics are printed.

        With ``FLAGS.DecMethod == 'RODec'`` the graph comes from
        ``self.test_infer`` and weights are restored from
        ``FLAGS.RODec_checkpoint``; otherwise ``self.batch_svd`` is used.

        :raises ValueError: if the HR directory contains no files, or if
                            ``self.out_channel`` is neither 1 nor 3
        """
        test_HR_dir = '{}/{}/HR'.format(FLAGS.input_data_dir, FLAGS.dataset)
        psnrs, ssims, durations = [], [], []
        # Evaluate at most the first 10 images; slicing (rather than indexing
        # a fixed range(10)) keeps this working for smaller test sets.
        HR_names = sorted(os.listdir(test_HR_dir))[:10]
        if not HR_names:
            raise ValueError('No test images found in {}'.format(test_HR_dir))

        with tf.Graph().as_default():
            image_pl = tf.placeholder(tf.float32, shape=(1, FLAGS.patch_size, FLAGS.patch_size, FLAGS.out_channel))
            sess = tf.Session()
            try:
                if FLAGS.DecMethod == 'RODec':
                    Res, ROs = self.test_infer(image_pl)
                    output = tf.concat([ROs, Res], axis=0)
                    saver = tf.train.Saver()
                    saver.restore(sess, FLAGS.RODec_checkpoint)
                else:
                    ROs, ROs_sum = self.batch_svd(image_pl, self.num_decomp)
                    # Residual: whatever the summed components do not explain.
                    Res = image_pl - ROs_sum
                    output = tf.concat([ROs, Res], axis=0)

                for HR_name in tqdm(HR_names):
                    start_time = time.time()
                    HR_path = os.path.join(test_HR_dir, HR_name)
                    ob, gt = self.read_image(HR_path)
                    # Add a leading batch axis of size 1 to match image_pl.
                    input_ob = np.expand_dims(ob, 0)
                    ob_dec = sess.run(output, feed_dict={image_pl: input_ob})
                    input_gt = np.expand_dims(gt, 0)
                    gt_dec = sess.run(output, feed_dict={image_pl: input_gt})
                    # The duration covers reading plus both decompositions.
                    durations.append(time.time() - start_time)

                    # Compute PSNR and SSIM for every output slice.
                    psnr = []
                    ssim = []
                    for i in range(self.num_decomp + 1):
                        if self.out_channel == 1:
                            p = PSNR(ob_dec[i, :, :, 0], gt_dec[i, :, :, 0], data_range=1.0)
                            s = SSIM(ob_dec[i, :, :, 0], gt_dec[i, :, :, 0], data_range=1.0, multichannel=False)
                        elif self.out_channel == 3:
                            p = PSNR(ob_dec[i], gt_dec[i], data_range=1.0)
                            s = SSIM(ob_dec[i], gt_dec[i], data_range=1.0, multichannel=True)
                        else:
                            raise ValueError('Invalid out channel, must be 1 or 3')
                        psnr.append(p)
                        ssim.append(s)
                    psnrs.append(psnr)
                    ssims.append(ssim)
            finally:
                # Release the session's resources even if evaluation fails.
                sess.close()

            dur_mean = np.mean(np.array(durations))
            dur_std = np.std(np.array(durations))
            # Average over images (axis 0), leaving one value per slice.
            aver_psnr = np.mean(np.array(psnrs), axis=0)
            aver_ssim = np.mean(np.array(ssims), axis=0)
            print('-------------------------------------')
            print('Runtime -> mean: %0.1f  std: %0.2f' % (dur_mean, dur_std))
            print('-------------------------------------')
            for i in range(self.num_decomp):
                print('Average PSNR for X%d : %0.2f' % (i, aver_psnr[i]))
            print('Average PSNR for E%d : %0.2f' % (self.num_decomp, aver_psnr[self.num_decomp]))
            print('-------------------------------------')
            for i in range(self.num_decomp):
                print('Average SSIM for X%d : %0.4f' % (i, aver_ssim[i]))
            print('Average SSIM for E%d : %0.4f' % (self.num_decomp, aver_ssim[self.num_decomp]))