Example #1
0
def get_basic_meterics(img, img_gt):
    """Convert a prediction/reference tensor pair to images and score them.

    Both inputs are converted with ``tensor2image`` (defined elsewhere). The
    converted images are then compared with ``PSNR`` and ``SSIM``, and SSIM is
    called with ``multichannel=True``.

    Args:
        img: Predicted tensor.
        img_gt: Ground-truth tensor.

    Returns:
        ``(img, img_gt, psnr, ssim)``: the two converted images followed by
        the two metric values.
    """
    pred_image, ref_image = tensor2image(img), tensor2image(img_gt)
    score_psnr = PSNR(pred_image, ref_image)
    score_ssim = SSIM(pred_image, ref_image, multichannel=True)
    return pred_image, ref_image, score_psnr, score_ssim
def test(model, testloader, save_prefix, learned, testing, device):
    """Evaluate ``model`` on ``testloader``, save reconstructions, print metrics.

    Each batch is moved to ``device`` and passed to ``model(data, target)``
    under ``torch.no_grad()``. The first of the three returned values is used
    as the reconstruction.

    Channel 0 of every reconstructed sample is scored against channel 0 of
    the input with PSNR/SSIM. It is also written to
    ``save_prefix + "<label>_<index>.png"``, where ``<label>`` is the argmax
    of ``target`` and ``<index>`` is the sample's running position over the
    whole loader.

    Args:
        model: Network returning a 3-tuple whose first item is the reconstruction.
        testloader: Iterable of ``(data, target)`` batches.
        save_prefix: String prepended to every saved PNG filename.
        learned: First integer shown in the printed ``[learned, testing]`` tag.
        testing: Second integer shown in the printed tag.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None. Prints the per-image mean PSNR and SSIM (``nan`` when the loader
        is empty).
    """
    model.eval()

    psnr_sum = 0.
    ssim_sum = 0.
    # Metrics are accumulated per image, so the mean must divide by the number
    # of images rather than the number of batches (len(testloader)).
    num_images = 0

    for data, target in testloader:
        name = target.topk(1)[1]
        data = data.to(device)
        target = target.to(device)

        with torch.no_grad():
            reconstructed_x, _, _ = model(data, target)
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around. Assumes tensors nominally lie in [0, 1]; TODO confirm.
        reconstructed_x = (np.clip(reconstructed_x.cpu().numpy(), 0, 1)
                           * 255).astype(np.uint8)
        original_x = (np.clip(data.cpu().numpy(), 0, 1) * 255).astype(np.uint8)

        for j in range(reconstructed_x.shape[0]):
            psnr_sum += PSNR(original_x[j, 0], reconstructed_x[j, 0])
            ssim_sum += SSIM(original_x[j, 0], reconstructed_x[j, 0])
            im = Image.fromarray(reconstructed_x[j, 0]).convert("L")
            # A running counter keeps filenames unique even when the final
            # batch is smaller than the others.
            im.save(save_prefix + "{}_{}.png".format(
                str(name[j].item()), str(num_images)))
            num_images += 1

    if num_images:
        psnr = psnr_sum / num_images
        ssim = ssim_sum / num_images
    else:
        psnr = ssim = float('nan')
    print('[%d, %d] PSNR: %.4f, SSIM %.4f' % (learned, testing, psnr, ssim))
    return
Example #3
0
def compute_measures(img_path, lab_path, scale, mode='rgb'):
    """Load a result/reference image pair from disk and score it.

    Args:
        img_path: Path of the image being evaluated.
        lab_path: Path of the reference ("label") image.
        scale: When greater than 1, both images are passed through
            ``imgcrop`` and then ``shave`` with this factor before scoring.
        mode: ``'ycbcr'`` scores only channel 0 of ``color.rgb2ycbcr`` (the Y
            channel). Any other value scores the images as loaded.

    Returns:
        ``(process_info, psnr, ssim)``. ``process_info`` is a progress message
        naming the evaluated file. Both metrics use ``data_range=255``.
    """
    pred = io.imread(img_path)
    ref = io.imread(lab_path)

    if mode == 'ycbcr':
        pred = color.rgb2ycbcr(pred)[:, :, 0]
        ref = color.rgb2ycbcr(ref)[:, :, 0]

    if scale > 1:
        pred = shave(imgcrop(pred, scale), scale)
        ref = shave(imgcrop(ref, scale), scale)

    # A 2-D reference is single-channel; anything else is scored per channel.
    per_channel = len(ref.shape) != 2
    psnr = PSNR(pred, ref, data_range=255)
    ssim = SSIM(pred, ref, data_range=255, multichannel=per_channel)

    file_name = img_path.rsplit('/', 1)[-1]
    return f'Processing {file_name} ...', psnr, ssim
Example #4
0
def PSNRHeatMap(gt_file, test_file, window_size, psnrs, out_dir=None,
                max_frames=51):
    """Write a video visualising per-window PSNR between two videos.

    Each frame pair is tiled into ``window_size`` x ``window_size`` windows.
    Every window is filled with its PSNR (``data_range=255``) against the
    matching ground-truth window.

    PSNR is clamped to [0, 60] dB and mapped linearly so that 0 dB becomes 255
    and 60 dB becomes 0. The JET colormap is then applied, and
    ``round(psnrs[k], 2)`` is drawn on output frame ``k``.

    Args:
        gt_file: Path of the ground-truth video.
        test_file: Path of the video under test. Its basename, up to
            ``.mp4``, names the output file.
        window_size: Side length of the square PSNR windows, in pixels.
        psnrs: Per-frame PSNR values. One entry is needed per written frame.
        out_dir: Output root directory. Defaults to ``sys.argv[1]`` (the
            previous hard-wired behaviour). The video is written to
            ``<out_dir>/PSNR/<name>_PSNR_CROP.mp4``.
        max_frames: Maximum number of frames written (default 51, as before).
    """
    if out_dir is None:
        out_dir = sys.argv[1]

    gt = cv2.VideoCapture(gt_file)
    test = cv2.VideoCapture(test_file)
    out = None
    try:
        # cv2 reports these properties as floats; VideoWriter needs ints.
        width = int(gt.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(gt.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(width, height)

        outname = test_file.split('.mp4')[0].split('/')[-1]
        out_path = f'{out_dir}/PSNR/{outname}_PSNR_CROP.mp4'
        print(out_path)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(out_path, fourcc, 5.0, (width, height))

        cnt = 0
        while cnt < max_frames:
            ret1, gt_frame = gt.read()
            ret2, test_frame = test.read()
            if not (ret1 and ret2):
                break

            pic = np.zeros((height, width))
            for i in range(0, height, window_size):
                for j in range(0, width, window_size):
                    pic[i:i + window_size, j:j + window_size] = PSNR(
                        gt_frame[i:i + window_size, j:j + window_size],
                        test_frame[i:i + window_size, j:j + window_size],
                        data_range=255)
            # Identical windows give an infinite PSNR, and anything above
            # 60 dB would map below 0 and wrap when cast to uint8, so clamp.
            pic = np.clip(pic, 0, 60)
            pic = np.uint8(255 - pic * 255 / 60)
            pic = cv2.applyColorMap(pic, cv2.COLORMAP_JET)
            psnr = round(psnrs[cnt], 2)
            cv2.putText(pic, f"PSNR = {psnr}", (10, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
            out.write(pic)
            cnt += 1
    finally:
        # Release every native handle, even if decoding or writing fails.
        if out is not None:
            out.release()
        gt.release()
        test.release()
Example #5
0
    def optimize_parameters(self, r_low, g1_low, g2_low, b_low, gt,
                            for_amplifier):
        """Run one optimisation step, with periodic logging and checkpoints.

        The model is called as ``self.model(r_low, g1_low, g2_low, b_low,
        for_amplifier)`` and returns ``(pred_output, gamma)``. Both
        ``pred_output`` and ``gt`` hold 64 groups of 3 channels. Group
        ``8 * ii + jj`` is scattered to pixel offset ``(ii, jj)`` of every 8x8
        cell of a ``self.opt['patch']`` x ``self.opt['patch']`` image, and the
        losses are computed on those unpacked images.

        The loss is the sum of:
        - ``criterion``;
        - ``1e-6`` times the L1 norm of all model parameters;
        - weighted TV, colour and perceptual terms.
        The TV/colour weights are 1000/10 while ``self.count`` is below
        ``self.opt['switch']`` and 400/1 afterwards. The perceptual weight is
        always 3.

        Side effects:
        - sets ``self.loss`` and increments ``self.count``;
        - every ``self.opt['fig_freq']`` steps, prints PSNR/SSIM/NRMSE of
          sample 0 and writes ``pred_<count>.jpg`` / ``GT_<count>.jpg``;
        - when ``self.count`` is in ``self.opt['save_freq']``, saves the model
          and optimizer state dicts.
        """
        r_low = r_low.to(self.device)
        g1_low = g1_low.to(self.device)
        g2_low = g2_low.to(self.device)
        b_low = b_low.to(self.device)
        gt = gt.to(self.device)
        for_amplifier = for_amplifier.to(self.device)

        self.model.train()
        self.optimizer.zero_grad()

        pred_output, gamma = self.model(r_low, g1_low, g2_low, b_low,
                                        for_amplifier)

        # Unpack with explicit loops so the layout is easy to follow; a faster
        # vectorised version exists in the TESTING code. The buffers are sized
        # from opt['patch'] and the batch size rather than a fixed 1x512x512,
        # so they always match the strided slices written below.
        patch = self.opt['patch']
        batch = gt.shape[0]
        plot_out_GT = torch.zeros(batch, 3, patch, patch,
                                  dtype=torch.float, device=self.device)
        plot_out_pred = torch.zeros(batch, 3, patch, patch,
                                    dtype=torch.float, device=self.device)
        for ii in range(8):
            for jj in range(8):
                ch = 3 * (8 * ii + jj)
                plot_out_GT[:, :, ii:patch:8,
                            jj:patch:8] = gt[:, ch:ch + 3, :, :]
                plot_out_pred[:, :, ii:patch:8,
                              jj:patch:8] = pred_output[:, ch:ch + 3, :, :]

        plot_out_pred = self.relu(plot_out_pred)

        blurLoss = self.criterion(plot_out_GT, plot_out_pred)

        # L1 penalty over every model parameter.
        regularization_loss = 0
        for param in self.model.parameters():
            regularization_loss += torch.sum(torch.abs(param))

        # The two training phases differ only in the TV and colour weights.
        if self.count < self.opt['switch']:
            tv_weight, color_weight = 1000, 10
        else:
            tv_weight, color_weight = 400, 1
        self.loss = (blurLoss) + (1e-6 * regularization_loss) + (
            tv_weight * self.TVLoss(plot_out_pred)) + (
                color_weight * self.blur_color_loss(plot_out_GT.detach(),
                                                    plot_out_pred)) + (
                    3 * self.perceptual_loss(plot_out_GT.detach(),
                                             plot_out_pred))
        if self.count % self.opt['text_prnt_freq'] == 0:
            print('Count : {}\n'.format(self.count))
            print(gamma)

        self.loss.backward()
        self.optimizer.step()

        self.count += 1
        if self.count % 10 == 0:
            print(self.count)

        if self.count % self.opt['fig_freq'] == 0:
            # Sample 0 with axis 0 (channels) moved last, clipped to [0, 1]
            # and scaled to uint8.
            plot_out_pred = (np.clip(
                plot_out_pred[0].detach().cpu().numpy().transpose(1, 2, 0), 0,
                1) * 255).astype(np.uint8)
            plot_out_GT = (np.clip(
                plot_out_GT[0].detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
                           * 255).astype(np.uint8)

            print("PSNR: {0:.3f}, SSIM: {1:.3f}, RMSE:{2:.3f}".format(
                PSNR(plot_out_GT, plot_out_pred),
                SSIM(plot_out_GT, plot_out_pred, multichannel=True),
                NRMSE(plot_out_GT, plot_out_pred)))

            imageio.imwrite('pred_{}.jpg'.format(self.count), plot_out_pred)
            imageio.imwrite('GT_{}.jpg'.format(self.count), plot_out_GT)

            print(gamma)

        if self.count in self.opt['save_freq']:
            torch.save(
                {
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }, self.opt['dir_root'] + 'weights/' + self.opt['exp_name'] +
                '_{}'.format(self.count))
Example #6
0
    def optimize_parameters(self, low, gt, for_amplifier):
        """Evaluate the model on one input and accumulate PSNR/SSIM.

        Despite its name, this method updates no parameters. The model runs in
        eval mode under ``torch.no_grad()`` as ``self.model(low,
        for_amplifier)`` and returns ``(pred, gamma)``.

        Sample 0 of ``pred`` and of ``gt`` is clipped to [0, 1], scaled to
        uint8 and scored with PSNR/SSIM/NRMSE. PSNR and SSIM are added to
        ``self.psnr`` / ``self.ssim``, and those totals divided by
        ``self.count`` are printed as running means.
        """
        low = low.to(self.device)
        gt = gt.to(self.device)
        for_amplifier = for_amplifier.to(self.device)

        self.model.eval()
        with torch.no_grad():
            pred, gamma = self.model(low, for_amplifier)

        self.count += 1

        def to_uint8_image(batch):
            # Sample 0 with axis 0 (channels) moved last, clipped to [0, 1]
            # and scaled to 0..255.
            return (np.clip(batch[0].detach().cpu().numpy().transpose(1, 2, 0),
                            0, 1) * 255).astype(np.uint8)

        plot_out_pred = to_uint8_image(pred)
        plot_out_GT = to_uint8_image(gt)

        psnrr = PSNR(plot_out_GT, plot_out_pred)
        ssimm = SSIM(plot_out_GT, plot_out_pred, multichannel=True)

        print("PSNR: {0:.3f}, SSIM: {1:.3f}, RMSE:{2:.3f}".format(
            psnrr, ssimm, NRMSE(plot_out_GT, plot_out_pred)))

        self.psnr += psnrr
        self.ssim += ssimm

        print('Mean PSNR: {}'.format(self.psnr / self.count))
        print('Mean SSIM: {}'.format(self.ssim / self.count))

        print(gamma)
Example #7
0
    def run_test(self, num_images=10):
        """Decompose test images and report per-component PSNR/SSIM.

        Builds a graph that splits an input into ``ROs`` and a residual
        ``Res``, concatenated along axis 0:
        - when ``FLAGS.DecMethod == 'RODec'``, it uses ``self.test_infer`` and
          restores ``FLAGS.RODec_checkpoint``;
        - otherwise it uses ``self.batch_svd``.

        For each of the first ``num_images`` files (sorted by name) in
        ``<input_data_dir>/<dataset>/HR``, both images of the ``(ob, gt)``
        pair from ``self.read_image`` are decomposed. Each of the
        ``self.num_decomp + 1`` components is then compared with
        ``data_range=1.0``.

        Prints the runtime statistics and the average PSNR/SSIM per component.

        Args:
            num_images: Maximum number of images to evaluate. The default of 10
                matches the previously hard-coded value. It is clamped to the
                number of files present.

        Raises:
            ValueError: If there is no image to evaluate, or if
                ``self.out_channel`` is neither 1 nor 3.
        """
        test_HR_dir = '{}/{}/HR'.format(FLAGS.input_data_dir, FLAGS.dataset)
        psnrs, ssims, durations = [], [], []
        HR_names = sorted(os.listdir(test_HR_dir))
        # Clamp so a directory with fewer files no longer raises IndexError.
        num_images = min(num_images, len(HR_names))
        if num_images <= 0:
            raise ValueError(
                'No test images to evaluate in {}'.format(test_HR_dir))

        with tf.Graph().as_default():
            image_pl = tf.placeholder(tf.float32, shape=(1, FLAGS.patch_size, FLAGS.patch_size, FLAGS.out_channel))
            # The context manager closes the session even if evaluation raises.
            with tf.Session() as sess:
                if FLAGS.DecMethod == 'RODec':
                    Res, ROs = self.test_infer(image_pl)
                    output = tf.concat([ROs, Res], axis=0)
                    saver = tf.train.Saver()
                    saver.restore(sess, FLAGS.RODec_checkpoint)
                else:
                    ROs, ROs_sum = self.batch_svd(image_pl, self.num_decomp)
                    Res = image_pl - ROs_sum
                    output = tf.concat([ROs, Res], axis=0)

                for number in tqdm(range(num_images)):
                    start_time = time.time()
                    HR_path = os.path.join(test_HR_dir, HR_names[number])
                    ob, gt = self.read_image(HR_path)
                    ob_dec = sess.run(
                        output, feed_dict={image_pl: np.expand_dims(ob, 0)})
                    gt_dec = sess.run(
                        output, feed_dict={image_pl: np.expand_dims(gt, 0)})
                    durations.append(time.time() - start_time)

                    # Compare each decomposed component of ob against gt.
                    psnr = []
                    ssim = []
                    for i in range(self.num_decomp + 1):
                        if self.out_channel == 1:
                            p = PSNR(ob_dec[i,:,:,0], gt_dec[i,:,:,0], data_range=1.0)
                            s = SSIM(ob_dec[i,:,:,0], gt_dec[i,:,:,0], data_range=1.0, multichannel=False)
                        elif self.out_channel == 3:
                            p = PSNR(ob_dec[i], gt_dec[i], data_range=1.0)
                            s = SSIM(ob_dec[i], gt_dec[i], data_range=1.0, multichannel=True)
                        else:
                            raise ValueError('Invalid out channel, must be 1 or 3')
                        psnr.append(p)
                        ssim.append(s)
                    psnrs.append(psnr)
                    ssims.append(ssim)

            dur_mean = np.mean(np.array(durations))
            dur_std = np.std(np.array(durations))
            # Rows are images and columns are components, so average over axis 0.
            aver_psnr = np.mean(np.array(psnrs), axis=0)
            aver_ssim = np.mean(np.array(ssims), axis=0)
            print('-------------------------------------')
            print('Runtime -> mean: %0.1f  std: %0.2f' % (dur_mean, dur_std))
            print('-------------------------------------')
            for i in range(self.num_decomp):
                print('Average PSNR for X%d : %0.2f' % (i, aver_psnr[i]))
            print('Average PSNR for E%d : %0.2f' % (self.num_decomp, aver_psnr[self.num_decomp]))
            print('-------------------------------------')
            for i in range(self.num_decomp):
                print('Average SSIM for X%d : %0.4f' % (i, aver_ssim[i]))
            print('Average SSIM for E%d : %0.4f' % (self.num_decomp, aver_ssim[self.num_decomp]))