Example #1
0
def cal_dpsnr_dssim(raw_frame, cmp_frame, enhanced_t):
    """Return the PSNR gain (in dB) of an enhanced frame over its compressed input.

    Both PSNR values are measured against ``raw_frame`` with ``data_range=1.0``
    via ``utils.cal_psnr``. Despite the function name, only the delta-PSNR is
    returned; the delta-SSIM computation is disabled.

    Args:
        raw_frame: reference (uncompressed) frame as a numpy array.
        cmp_frame: compressed frame as a numpy array, same layout as ``raw_frame``.
        enhanced_t: enhanced frame tensor with a leading size-1 dimension,
            which is squeezed away before comparison.

    Returns:
        PSNR(enhanced, raw) - PSNR(compressed, raw).
    """
    enhanced_np = torch.squeeze(enhanced_t, 0).detach().cpu().numpy()
    psnr_enhanced = utils.cal_psnr(enhanced_np, raw_frame, data_range=1.0)
    psnr_compressed = utils.cal_psnr(cmp_frame, raw_frame, data_range=1.0)
    return psnr_enhanced - psnr_compressed
Example #2
0
    def test(self, sess, test_files, ckpt_dir, save_dir):
        """
        Test MAP denoising

        Parameters
        ----------
        sess - Tensorflow session
        test_files - list of filenames of images to test
        ckpt_dir - checkpoint directory containing the pretrained model
        save_dir - directory into which the noisy and estimate images will be saved

        Returns
        -------
        float - average PSNR over ``test_files`` (also printed)

        Raises
        ------
        AssertionError - if ``test_files`` is empty or the checkpoint fails to load
        """
        # init variables
        # NOTE: tf.initialize_all_variables is deprecated in newer TF 1.x in
        # favour of tf.global_variables_initializer; kept for compatibility
        # with whatever TF version this code targets.
        tf.initialize_all_variables().run()
        # Explicit raises instead of `assert` so the checks survive `python -O`
        # (otherwise a failed load would print SUCCESS and test random weights).
        # AssertionError is kept so callers see the same exception type.
        if len(test_files) == 0:
            raise AssertionError('No testing data!')
        load_model_status, _ = self.load(sess, ckpt_dir)
        if not load_model_status:
            raise AssertionError('[!] Load weights FAILED...')
        print(" [*] Load weights SUCCESS...")
        psnr_sum = 0
        print("[*] " + 'noise variance: ' + str(self.stddev**2) +
              " start testing...")
        # The channel count is fixed for the model, so choose the loader once.
        load_fn = load_images if self.channels == 1 else load_images_rgb
        for idx, test_file in enumerate(test_files):
            clean_image = load_fn(test_file).astype(np.float32) / 255.0

            output_clean_image, noisy_image = sess.run(
                [self.Dv, self.v],
                feed_dict={
                    self.v_ph: clean_image,
                    self.is_training_ph: False
                })
            groundtruth = np.clip(255 * clean_image, 0, 255).astype('uint8')
            noisyimage = np.clip(255 * noisy_image, 0, 255).astype('uint8')
            outputimage = np.clip(255 * output_clean_image, 0,
                                  255).astype('uint8')
            # calculate PSNR
            psnr = cal_psnr(groundtruth, outputimage)
            print("img%d PSNR: %.2f" % (idx, psnr))
            psnr_sum += psnr
            save_images(path.join(save_dir, 'noisy%d.png' % idx), noisyimage)
            save_images(path.join(save_dir, 'denoised%d.png' % idx),
                        outputimage)
        avg_psnr = psnr_sum / len(test_files)
        print("--- Average PSNR %.2f ---" % avg_psnr)
        return avg_psnr
Example #3
0
def validation(img, name, save_imgs=False, save_dir=None):
    """Downscale ``img`` with the learned kernels, upscale it back and score it.

    Uses the module-level ``kernel_generation_net``, ``downsampler_net`` and
    ``upscale_net``. They are switched to eval mode and not switched back.

    Args:
        img: batched image tensor. Values are assumed to lie in [0, 1], since the
            code scales by 255 (TODO confirm). Only the first batch element is
            scored and saved.
        name: file-name prefix used when saving images.
        save_imgs: if True and ``save_dir`` is set, write ``<name>_orig.png``,
            ``<name>_down.png`` and ``<name>_recon.png`` into ``save_dir``.
        save_dir: output directory for the saved images.

    Returns:
        (psnr, ssim): ``utils.cal_psnr`` on the full uint8 images and
        ``utils.calc_ssim`` on their Y channel (via ``rgb2ycbcr``). Both are
        computed with a ``SCALE``-pixel border cropped off.
    """
    kernel_generation_net.eval()
    downsampler_net.eval()
    upscale_net.eval()

    # Pure inference: results are only read back as numpy arrays, so there is
    # no reason to spend memory recording an autograd graph.
    with torch.no_grad():
        kernels, offsets_h, offsets_v = kernel_generation_net(img)
        downscaled_img = downsampler_net(img, kernels, offsets_h, offsets_v,
                                         OFFSET_UNIT)
        # Quantise the low-resolution image to integer 0-255 levels.
        downscaled_img = torch.round(torch.clamp(downscaled_img, 0, 1) * 255)

        reconstructed_img = upscale_net(downscaled_img / 255.0)

        # Round before the uint8 cast: np.uint8 truncates, and float error can
        # leave e.g. 254.99998 (from k/255 * 255), which would become 254.
        img = torch.round(img * 255)
        reconstructed_img = torch.round(
            torch.clamp(reconstructed_img, 0, 1) * 255)

    # Move axis 1 to the end (channels-last, as PIL / rgb2ycbcr expect), cast
    # to uint8 and keep only the first image of the batch.
    orig_img = np.uint8(img.cpu().numpy().transpose(0, 2, 3, 1))[0, ...].squeeze()
    downscaled_img = np.uint8(
        downscaled_img.cpu().numpy().transpose(0, 2, 3, 1))[0, ...].squeeze()
    recon_img = np.uint8(
        reconstructed_img.cpu().numpy().transpose(0, 2, 3, 1))[0, ...].squeeze()

    if save_imgs and save_dir:
        for suffix, array in (('_orig.png', orig_img),
                              ('_down.png', downscaled_img),
                              ('_recon.png', recon_img)):
            Image.fromarray(array).save(os.path.join(save_dir, name + suffix))

    psnr = utils.cal_psnr(orig_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          recon_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          benchmark=BENCHMARK)

    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0]
    orig_img_y = orig_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
    recon_img_y = recon_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]

    ssim = utils.calc_ssim(recon_img_y, orig_img_y)

    return psnr, ssim
Example #4
0
    def evaluate(self, sess, iter_num, test_data, sample_dir, summary_writer):
        """
        Denoise every evaluation image, log PSNR summaries and save comparisons.

        Parameters
        ----------
        sess - TensorFlow session
        iter_num - training iteration number (summary step and file-name tag)
        test_data - list of 4-D arrays of varying size, pixel values in 0-255
        sample_dir - directory the ground-truth/noisy/denoised images are saved to
        summary_writer - TensorFlow summary writer receiving the PSNR summaries

        Returns
        -------
        None; the average PSNR is printed.
        """
        print("[*] Evaluating...")
        running_total = 0
        for idx, raw_image in enumerate(test_data):
            feed = {
                self.v_ph: raw_image.astype(np.float32) / 255.0,
                self.is_training_ph: False,
            }
            denoised, noisy, summary = sess.run(
                [self.Dv, self.v, self.summary_psnr], feed_dict=feed)
            summary_writer.add_summary(summary, iter_num)

            # Bring everything back to 8-bit for PSNR and saving.
            truth_u8 = np.clip(raw_image, 0, 255).astype('uint8')
            noisy_u8 = np.clip(255 * noisy, 0, 255).astype('uint8')
            denoised_u8 = np.clip(255 * denoised, 0, 255).astype('uint8')

            running_total += cal_psnr(truth_u8, denoised_u8)
            target = path.join(sample_dir,
                               'test%d_%d.png' % (idx + 1, iter_num))
            save_images(target, truth_u8, noisy_u8, denoised_u8)
        avg_psnr = running_total / len(test_data)

        print("--- Test ---- Average PSNR %.2f ---" % avg_psnr)
Example #5
0
    def evaluate(self, iter_num, evaln_data, evalc_data, sample_dir,
                 summary_merged, summary_writer):
        """
        Denoise each (noisy, clean) evaluation pair, log summaries and save images.

        -i- evaln_data : list of 4-D arrays (sizes may differ); noisy images, 0-255.
        -i- evalc_data : list of 4-D arrays (sizes may differ); clean images, 0-255,
            indexed in step with ``evaln_data``.
        Per-image and average PSNR are printed; nothing is returned.
        """
        print("[*] Evaluating...")
        running_total = 0
        for idx, noisy_raw in enumerate(evaln_data):
            clean_raw = evalc_data[idx]
            feed = {
                self.X: noisy_raw.astype(np.float32) / 255.0,
                self.Y_: clean_raw.astype(np.float32) / 255.0,
                self.is_training: False,
            }
            prediction, summary = self.sess.run([self.Y, summary_merged],
                                                feed_dict=feed)
            summary_writer.add_summary(summary, iter_num)

            clean_u8 = np.clip(clean_raw, 0, 255).astype('uint8')
            noisy_u8 = np.clip(noisy_raw, 0, 255).astype('uint8')
            pred_u8 = np.clip(255 * prediction, 0, 255).astype('uint8')

            psnr = cal_psnr(clean_u8, pred_u8)
            print("img%d PSNR: %.2f" % (idx + 1, psnr))
            running_total += psnr
            target = os.path.join(sample_dir,
                                  'test%d_%d.png' % (idx + 1, iter_num))
            save_images(target, clean_u8, noisy_u8, pred_u8)
        avg_psnr = running_total / len(evaln_data)

        print("--- Test ---- Average PSNR %.2f ---" % avg_psnr)
def _frequency_bands(img, in_button, window_len, n_imgrow, n_imgcol):
    """Split one 2-D image into ``in_button`` channels stacked on the last axis."""
    if in_button == 1:
        # single channel: the image itself
        return img.reshape([n_imgrow, n_imgcol, 1])
    once = mean_filter(img, kernelsize=window_len)
    if in_button == 2:
        # channel 1: filtered; channel 2: residue (original - filtered)
        return np.stack([once, img - once], axis=2)
    # in_button == 3: filtered twice, residue 1 (once - twice), residue 2 (original - once)
    # NOTE(review): the inner mean_filter uses its default kernel size rather
    # than window_len (kept from the original) -- confirm this is intended.
    twice = mean_filter(mean_filter(img), kernelsize=window_len)
    return np.stack([twice, once - twice, img - once], axis=2)


def _build_split(noisy, labels, in_button, window_len, n_imgrow, n_imgcol):
    """Band-decompose inputs (noisy) and targets (noisy - label) for one data split."""
    data = np.zeros([len(noisy), n_imgrow, n_imgcol, in_button])
    target = np.zeros([len(labels), n_imgrow, n_imgcol, in_button])
    for i, (x, y) in enumerate(zip(noisy, labels)):
        data[i] = _frequency_bands(x, in_button, window_len, n_imgrow, n_imgcol)
        # noisy - label is the clean image, which is the training target
        target[i] = _frequency_bands(x - y, in_button, window_len, n_imgrow,
                                     n_imgcol)
    return data, target


def experiment(data, label, n_sample=200, n_test=40, n_imgrow=300,
               n_imgcol=300, shuffle_button=3, in_button=3, window_len=7):
    """Train and test an FRCNN denoiser on a frequency-band decomposition.

    Parameters
    ----------
    data, label : arrays indexed ``[sample, row, col]``; ``data - label`` is
        used as the clean image.
    n_sample : number of samples used.
    n_test : number of trailing samples (after ordering) held out for testing.
    n_imgrow, n_imgcol : image height and width.
    shuffle_button : 1 shuffle everything, 2 no shuffle, 3 shuffle only the
        training part.
    in_button : number of bands (1 image only, 2 filtered + residue,
        3 twice-filtered + two residues).
    window_len : kernel size passed to ``mean_filter``.

    Returns
    -------
    (ori_psnr, dnd_psnr) : ``(n_test, 1)`` arrays of per-image PSNR of the
    noisy and the denoised test images. They are also saved as .mat files
    under ./tobedown/.

    Raises
    ------
    ValueError : unknown ``shuffle_button`` or ``in_button``.
    """
    # Any other value would silently train on all-zero arrays.
    if in_button not in (1, 2, 3):
        raise ValueError("in_button must be 1, 2 or 3, got %r" % (in_button,))

    n_train = n_sample - n_test
    if shuffle_button == 1:
        # 1.shuffle the whole dataset
        order = nr.permutation(n_sample)
        print("[*] shuffle the whole dataset")
    elif shuffle_button == 2:
        # 2.do not shuffle
        order = np.arange(n_sample)
        print("[*] do not shuffle")
    elif shuffle_button == 3:
        # 3. shuffle the training and validation only
        order = np.concatenate((nr.permutation(n_train),
                                np.arange(n_train, n_sample)),
                               axis=0)
        print("[*] shuffle the training and validation only")
    else:
        print("[*] shuffle button not confirmed")
        # Previously fell through and crashed with NameError on `order`.
        raise ValueError("shuffle_button must be 1, 2 or 3, got %r"
                         % (shuffle_button,))

    shuffledata = data[order, :, :]
    shufflelabel = label[order, :, :]
    # split input data and test data
    in_data = shuffledata[0:n_train]
    in_label = shufflelabel[0:n_train]
    t_data = shuffledata[n_train:n_sample, :]
    t_label = shufflelabel[n_train:n_sample, :]

    t0 = time.time()
    train_data, train_label = _build_split(in_data, in_label, in_button,
                                           window_len, n_imgrow, n_imgcol)
    print("[*] train data ready")
    t1 = time.time()
    print("Total time running: %s seconds" % (str(t1 - t0)))

    t0 = time.time()
    test_data, test_label = _build_split(t_data, t_label, in_button,
                                         window_len, n_imgrow, n_imgcol)
    print("[*] test data ready")
    t1 = time.time()
    print("Total time running: %s seconds" % (str(t1 - t0)))

    CNNclass = FRCNN_model(image_size=[n_imgrow, n_imgcol],
                           in_channel=in_button)

    model = CNNclass.build_model()

    model, hist = CNNclass.train_model(model, train_data, train_label)
    denoised = CNNclass.test_model(model, test_data)

    # file.write needs a str (the raw loss list raised TypeError), and the
    # `with` block actually closes the file (`output.close` lacked its parens).
    with open('log.txt', 'w+') as output:
        output.write(str(hist.history['loss']) + '\n')

    # calculate the PSNR of this experiment
    ori_psnr = np.zeros([n_test, 1])
    dnd_psnr = np.zeros([n_test, 1])

    for i in range(n_test):
        noisy_img = t_data[i]
        # the network predicts bands; summing them reconstructs the image
        denoised_img = (denoised[i, :, :, :].sum(axis=2)).reshape(
            [n_imgrow, n_imgcol])
        real_img = (t_data[i] - t_label[i]).reshape([n_imgrow, n_imgcol])
        ori_psnr[i] = cal_psnr(real_img, noisy_img)
        dnd_psnr[i] = cal_psnr(real_img, denoised_img)

    print("---> original PSNR is %.4f dB" % np.mean(ori_psnr))
    print("---> denoised PSNR is %.4f dB" % np.mean(dnd_psnr))

    # Append so the loss history written above is not overwritten.
    with open('log.txt', 'a') as output:
        output.write("---> original PSNR is %.4f dB\n" % np.mean(ori_psnr))
        output.write("---> denoised PSNR is %.4f dB\n" % np.mean(dnd_psnr))

    # save experiment results
    savingpath = './tobedown/tobedown_in' + str(in_button) + '_winlen' + str(
        window_len)
    os.makedirs(savingpath, exist_ok=True)
    postfix = str(in_button) + '_' + str(window_len)
    sio.savemat(os.path.join(savingpath, 'denoised' + postfix + '.mat'),
                {'denoised' + postfix:
                 denoised.sum(axis=3).reshape([n_test, n_imgrow, n_imgcol])})
    # one file per band: denoised_ch1..denoised_ch<in_button>
    for ch in range(in_button):
        sio.savemat(
            os.path.join(savingpath,
                         'denoised_ch%d' % (ch + 1) + postfix + '.mat'),
            {'denoised%d' % (ch + 1) + postfix:
             denoised[:, :, :, ch].reshape([n_test, n_imgrow, n_imgcol])})

    sio.savemat(os.path.join(savingpath, 'noisy' + postfix + '.mat'),
                {'noisy' + postfix: t_data})
    sio.savemat(os.path.join(savingpath, 'real' + postfix + '.mat'),
                {'real' + postfix:
                 (t_data - t_label).reshape([n_test, n_imgrow, n_imgcol])})
    sio.savemat(os.path.join(savingpath, 'ori_psnr' + postfix + '.mat'),
                {'ori_psnr' + postfix: ori_psnr})
    sio.savemat(os.path.join(savingpath, 'dnd_psnr' + postfix + '.mat'),
                {'dnd_psnr' + postfix: dnd_psnr})

    return ori_psnr, dnd_psnr
Example #7
0
def val_loop(stsr, val_loader, val_dataset, epoch):
    """Evaluate ``stsr`` on the validation loader and write PSNR/SSIM averages.

    Every output (TS, ST, MERGE, REFINE) is scored against the ground-truth
    middle frame ``GT``. The HR/LR metrics need the full ``HR`` sequence, which
    only exists when ``args.forward_MsMt`` is set; otherwise they stay 0.
    Averages are divided by ``len(val_dataset)`` (HR by an extra 2, since it
    sums two frames) and written to ``<save_dir>/vimeo_record.txt``.
    """
    ### validation
    avg_PSNR_TS = 0
    avg_PSNR_ST = 0
    avg_PSNR_MERGE = 0
    avg_PSNR_RESIDUAL = 0
    avg_PSNR_HR = 0
    avg_PSNR_LR = 0

    avg_SSIM_TS = 0
    avg_SSIM_ST = 0
    avg_SSIM_MERGE = 0
    avg_SSIM_RESIDUAL = 0
    avg_SSIM_HR = 0
    avg_SSIM_LR = 0

    stsr.eval()

    with torch.no_grad():
        for val_data in val_loader:
            if args.forward_MsMt:
                HR = val_data['HR'].to(device)
                LR = torch.stack([nn_down(HR[:, 0]), nn_down(HR[:, 1]), nn_down(HR[:, 2])], dim=1)
                LR = LR.clamp(0, 1).detach()
                GT = HR[:, 1]
                I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2, I_F_2, mask_1, mask_2, I_R_basic, I_R_2 = stsr(LR[:, 0], LR[:, 2])
            else:
                ST = val_data['ST'].to(device)
                TS = val_data['TS'].to(device)
                GT = val_data['GT'].to(device)
                I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2, I_F_2, mask_1, mask_2, I_R_basic, I_R_2 = stsr(ST, TS)

            B, C, H, W = GT.size()

            for b_id in range(B):
                # GT[b_id] equals HR[b_id, 1] in the MsMt branch and, unlike HR,
                # is also defined in the ST/TS branch (HR raised NameError there).
                gt = GT[b_id]
                avg_PSNR_TS += utils.cal_psnr(I_TS_2[b_id], gt).item()
                avg_PSNR_ST += utils.cal_psnr(I_ST_2[b_id], gt).item()
                avg_PSNR_MERGE += utils.cal_psnr(I_F_2[b_id], gt).item()
                avg_PSNR_RESIDUAL += utils.cal_psnr(I_R_2[b_id], gt).item()

                avg_SSIM_TS += utils.cal_ssim(I_TS_2[b_id], gt)
                avg_SSIM_ST += utils.cal_ssim(I_ST_2[b_id], gt)
                avg_SSIM_MERGE += utils.cal_ssim(I_F_2[b_id], gt)
                avg_SSIM_RESIDUAL += utils.cal_ssim(I_R_2[b_id], gt)

                # HR and LR references exist only in the MsMt branch.
                if args.forward_MsMt:
                    avg_PSNR_HR += utils.cal_psnr(I_H_1[b_id], HR[b_id, 0]).item()+utils.cal_psnr(I_H_3[b_id], HR[b_id, 2]).item()
                    avg_PSNR_LR += utils.cal_psnr(I_L_2[b_id], LR[b_id, 1]).item()
                    avg_SSIM_HR += utils.cal_ssim(I_H_1[b_id], HR[b_id, 0])+utils.cal_ssim(I_H_3[b_id], HR[b_id, 2])
                    avg_SSIM_LR += utils.cal_ssim(I_L_2[b_id], LR[b_id, 1])

    n_val = len(val_dataset)
    with open(os.path.join(save_dir, 'vimeo_record.txt'), 'w') as f:
        print('PSNR_HR: {}'.format(avg_PSNR_HR/n_val/2), file=f)
        print('PSNR_LR: {}'.format(avg_PSNR_LR/n_val), file=f)
        print('PSNR_TS: {}'.format(avg_PSNR_TS/n_val), file=f)
        print('PSNR_ST: {}'.format(avg_PSNR_ST/n_val), file=f)
        print('PSNR_MERGE: {}'.format(avg_PSNR_MERGE/n_val), file=f)
        print('PSNR_REFINE: {}'.format(avg_PSNR_RESIDUAL/n_val), file=f)

        print('SSIM_HR: {}'.format(avg_SSIM_HR/n_val/2), file=f)
        print('SSIM_LR: {}'.format(avg_SSIM_LR/n_val), file=f)
        print('SSIM_TS: {}'.format(avg_SSIM_TS/n_val), file=f)
        print('SSIM_ST: {}'.format(avg_SSIM_ST/n_val), file=f)
        print('SSIM_MERGE: {}'.format(avg_SSIM_MERGE/n_val), file=f)
        print('SSIM_REFINE: {}'.format(avg_SSIM_RESIDUAL/n_val), file=f)

    # NOTE(review): everything below looks like a fragment of a different
    # (training) loop spliced into this function. It uses `trainloader`,
    # `featExNets` and `opt`, none of which are defined in this function.
    # `ori_psnr +=` targets a name that is never initialised here, so it would
    # raise UnboundLocalError when epoch == 0. Kept verbatim; confirm against
    # the original source and relocate or remove.
    avg_err, avg_psnr = 0, 0
    acc_rec = 0
    acc_f_diff = 0

    start_time = time.time()

    for z, data in enumerate(tqdm(trainloader)):
        ori_v = torch.autograd.Variable(data['ori'],
                                        requires_grad=False).cuda()
        de_v = torch.autograd.Variable(data['de'], requires_grad=False).cuda()
        residual = ori_v - de_v
        reconstruction, features = featExNets(residual)

        if epoch == 0:
            ori_psnr += utils.cal_psnr(ori_v.cpu().data.numpy(),
                                       de_v.cpu().data.numpy(),
                                       data_range=1.0).item()

        # epoch 0 to 4 we use real residual patterns to train upSamplingNets and refineNets
        # epoch 5 to - we use approximated residual patterns to train upSamplingNets and refineNets

        if epoch >= 5:
            c = 1  # weight for loss
            pick = []

            pre_kmmodel = utils.load_obj(opt.logging_root +
                                         '/kmeans/kmmodel_%d' % (epoch - 1))
            pre_centerPatch = utils.load_obj(opt.logging_root +
                                             '/kmeans/centerPatch_%d' %
                                             (epoch - 1))
Example #9
0
def test(model):
    """Evaluate all five outputs of ``model`` on the RAISE test frames.

    For each entry of ``order_QPorQF``, every test frame is compressed-loaded
    and enhanced by all five outputs (the timing includes all five, hence "no
    early-exit"). The output whose index matches the entry's position in
    ``order_QPorQF`` is the one reported for that QP/QF. Frames for which
    ``isplane(raw_frame)`` is true are skipped.

    Per-frame results go to ``fp_each``, per-QP averages and the overall mean
    delta-PSNR to ``fp_ave`` and stdout. Relies on the module-level settings
    used below (``dir_test``, ``suffix_data_path``, ``tab``, ``nfs_test_used``,
    frame geometry, ``dev``, ``opt_output`` and the open log files).
    """
    with torch.no_grad():

        raw_path = os.path.join(dir_test, "RAISE_raw_" + suffix_data_path)

        # Average over however many QPs/QFs are configured (was hard-coded 5).
        n_qp = len(order_QPorQF)
        dpsnr_sum = 0.0

        for QPorQF in order_QPorQF:  # test order, not the output order
            # Output k+1 is reported for the k-th entry of order_QPorQF.
            # list() so this also works for tuples / numpy arrays.
            qp_rank = list(order_QPorQF).index(QPorQF)

            cmp_path = os.path.join(
                dir_test,
                "RAISE_" + tab + str(QPorQF) + "_" + suffix_data_path)

            dpsnr_ave = 0.0
            time_total = 0.0
            nfs_test_final = nfs_test_used

            for ite_frame in range(nfs_test_used):

                raw_frame = utils.y_import(
                    raw_path,
                    height_frame,
                    width_frame,
                    nfs=1,
                    startfrm=ite_frame).astype(
                        np.float32)[:, start_height:start_height + height_test,
                                    start_width:start_width + width_test] / 255
                cmp_frame = utils.y_import(
                    cmp_path,
                    height_frame,
                    width_frame,
                    nfs=1,
                    startfrm=ite_frame).astype(
                        np.float32)[:, start_height:start_height + height_test,
                                    start_width:start_width + width_test] / 255
                if isplane(raw_frame):  # plain frame => no need to enhance => invalid
                    nfs_test_final -= 1
                    continue

                # batch_size * height * width => batch_size * channel * height * width
                cmp_t = torch.from_numpy(cmp_frame).to(dev).view(
                    1, 1, height_test, width_test)

                start_time = time.time()
                # enhs[0] comes from the shallowest output, enhs[4] from the deepest.
                enhs = [model(cmp_t, depth) for depth in range(1, 6)]
                time_total += time.time() - start_time

                if opt_output:  # save frame as png (out=0 is the compressed input)
                    for out_id, frame_t in enumerate([cmp_t] + enhs):
                        func_output(ite_frame, QPorQF, frame_t, out=out_id)

                # delta-PSNR of every output; the selected one is reused rather
                # than computed a second time.
                dps = [cal_dpsnr_dssim(raw_frame, cmp_frame, enh) for enh in enhs]
                dpsnr = dps[qp_rank]
                print("\rframe %4d|%4d - dpsnr %.3f - %s %2d    " %
                      (ite_frame + 1, nfs_test_used, dpsnr, tab, QPorQF),
                      end="",
                      flush=True)

                dpsnr_ave += dpsnr

                ori_psnr = utils.cal_psnr(cmp_frame, raw_frame, data_range=1.0)
                fp_each.write(
                    "frame %d - %s %2d - ori psnr: %.3f - dpsnr from o1 to o5: %.3f, %.3f, %.3f, %.3f, %.3f\n"
                    % ((ite_frame, tab, QPorQF, ori_psnr) + tuple(dps)))
                fp_each.flush()

            dpsnr_ave = dpsnr_ave / nfs_test_final
            fps = nfs_test_final / time_total

            print(
                "\r=== dpsnr: {:.3f} - {:s} {:2d} - fps {:.1f} (no early-exit) ==="
                .format(dpsnr_ave, tab, QPorQF, fps) + 10 * " ",
                flush=True)
            fp_ave.write(
                "=== dpsnr: %.3f - %s %2d - fps %.1f (no early-exit) ===\n" %
                (dpsnr_ave, tab, QPorQF, fps))

            fp_ave.flush()

            dpsnr_sum += dpsnr_ave

        print("=== dpsnr: %.3f ===" % (dpsnr_sum / n_qp, ), flush=True)
        fp_ave.write("=== dpsnr: %.3f ===\n" % (dpsnr_sum / n_qp))
        fp_ave.flush()
                                         '/kmeans/centerPatch_%d' % i)

        ori_psnr, ori_ssim = 0, 0
        avg_err, avg_psnr, avg_ssim = 0, 0, 0
        avg_f_diff = 0
        start_time = time.time()

        for _, data in enumerate(testloader):
            ori_v = torch.autograd.Variable(data['ori'],
                                            requires_grad=False).cuda()
            de_v = torch.autograd.Variable(data['de'],
                                           requires_grad=False).cuda()
            residual = ori_v - de_v

            ori_psnr += utils.cal_psnr(ori_v.cpu().data.numpy(),
                                       de_v.cpu().data.numpy(),
                                       data_range=1.0).item() / len(testset)
            ori_ssim += utils.cal_ssim(
                ori_v.squeeze().cpu().data.numpy().transpose(1, 2, 0),
                de_v.squeeze().cpu().data.numpy().transpose(1, 2, 0),
                data_range=1.0,
                multichannel=True).item() / len(testset)

            _, features = featExNets(residual)

            pick = []
            patchResFeat = features.squeeze().permute(
                1, 2, 0).contiguous().view(
                    -1,
                    features.size()[1]).cpu().detach().data.numpy()
            prediction = pre_kmmodel.predict(patchResFeat.astype(np.float64))
Example #11
0
def val_loop(stsr, val_loader, val_dataset, epoch):
    """Validate ``stsr`` on ``val_loader`` and return the mean residual PSNR.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    For every sample, PSNR (``utils.cal_psnr``) is accumulated for the
    ``I_R_basic`` (MERGE), ``I_R_2`` (RESIDUAL), ``I_TS_2`` (TS) and
    ``I_ST_2`` (ST) outputs against the ground-truth frame ``GT``. When
    ``args.train_MsMt`` is set, the two predicted HR end frames (HR) and the
    predicted LR middle frame (LR) are scored as well. Every sum is divided
    by ``len(val_dataset)`` and the resulting dict is printed.

    Args:
        stsr: model, called as ``stsr(a, b)``; its result is unpacked into
            exactly ten outputs.
        val_loader: iterable of dict batches (key ``'HR'`` when
            ``args.train_MsMt`` is set, otherwise ``'ST'``/``'TS'``/``'GT'``).
        val_dataset: only its ``len()`` is used, as the averaging denominator.
        epoch: not used by this function.

    Returns:
        Sum of ``I_R_2`` PSNRs divided by ``len(val_dataset)``.
    """
    sums = dict.fromkeys(('TS', 'ST', 'MERGE', 'RESIDUAL', 'HR', 'LR'), 0)

    stsr.eval()

    with torch.no_grad():
        for batch in tqdm(val_loader):
            if not args.train_MsMt:
                ST = batch['ST'].to(device)
                TS = batch['TS'].to(device)
                GT = batch['GT'].to(device)
                outputs = stsr(ST, TS)
            else:
                # Build the LR triplet by downsampling each of the three HR
                # frames; the model sees only the first and last LR frames.
                HR = batch['HR'].to(device)
                LR = torch.stack([nn_down(HR[:, t]) for t in range(3)], dim=1)
                LR = LR.clamp(0, 1).detach()
                GT = HR[:, 1]
                outputs = stsr(LR[:, 0], LR[:, 2])

            (I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2,
             _, _, _, I_R_basic, I_R_2) = outputs

            # The 4-way unpack also asserts that GT is a 4-D tensor.
            batch_size, _, _, _ = GT.size()

            for k in range(batch_size):
                target = GT[k]
                sums['MERGE'] += utils.cal_psnr(I_R_basic[k], target).item()
                sums['RESIDUAL'] += utils.cal_psnr(I_R_2[k], target).item()
                sums['TS'] += utils.cal_psnr(I_TS_2[k], target).item()
                sums['ST'] += utils.cal_psnr(I_ST_2[k], target).item()
                if args.train_MsMt:
                    # Both reconstructed HR end frames contribute to 'HR'.
                    hr_first = utils.cal_psnr(I_H_1[k], HR[k, 0]).item()
                    hr_last = utils.cal_psnr(I_H_3[k], HR[k, 2]).item()
                    sums['HR'] += hr_first + hr_last
                    sums['LR'] += utils.cal_psnr(I_L_2[k], LR[k, 1]).item()

    n = len(val_dataset)
    log = {
        'PSNR_TS': sums['TS'] / n,
        'PSNR_ST': sums['ST'] / n,
        'PSNR_MERGE': sums['MERGE'] / n,
        'PSNR_RESIDUAL': sums['RESIDUAL'] / n,
    }
    if args.train_MsMt:
        # Two HR frames were summed per sample, hence the extra halving.
        log['PSNR_HR'] = sums['HR'] / n / 2.
        log['PSNR_LR'] = sums['LR'] / n
    print(log)

    return log['PSNR_RESIDUAL']
Example #12
0
    def train(self):
        """Train the model from the stored global step up to ``args.num_iter``.

        Each iteration draws a random entry of ``self.data.scale_list`` (the
        degradation list when ``args.degrade`` is set), fetches a batch with
        ``self.data.get_batch`` and runs one ``train_op`` step. Based on the
        ``global_step`` value returned by that step it also:

        * prints loss and mean time per iteration every ``args.print_freq`` steps,
        * writes loss / PSNR / learning-rate summaries every ``args.log_freq`` steps,
        * saves the model every ``args.save_freq`` steps,
        * every ``args.valid_freq`` steps, logs validation PSNR for each entry
          of ``scale_list``, using the last ``args.num_valid`` images.
        """

        def mean_or_zero(values):
            # Average over the pairs actually evaluated (a slice may be shorter
            # than num_batch / num_valid); 0.0 when nothing was evaluated.
            return sum(values) / len(values) if values else 0.0

        print("Begin training...")
        whole_time = 0.0
        init_global_step = self.model.global_step.eval(self.sess)

        # For Track 2, 3, and 4: multi-scale => multi-degradation.
        # This does not depend on the iteration, so do it once up front.
        if self.args.degrade:
            self.data.scale_list = self.data.degra_list

        for step_idx in range(init_global_step, self.args.num_iter):
            start_time = time.time()

            # randomly select a scale (or degradation) from scale_list
            idx_scale = np.random.choice(len(self.data.scale_list))
            scale = self.data.scale_list[idx_scale]

            # get batch data for the selected scale
            train_in_imgs, train_tar_imgs = self.data.get_batch(
                batch_size=self.args.num_batch, idx_scale=idx_scale)

            # one optimisation step
            feed_dict = {self.model.input: train_in_imgs,
                         self.model.target: train_tar_imgs,
                         self.model.flag_scale: scale}
            _, loss, lr, output, global_step = self.sess.run(
                [self.model.train_op, self.model.loss,
                 self.model.learning_rate, self.model.output,
                 self.model.global_step],
                feed_dict=feed_dict)

            # Mean wall-clock time per iteration so far. Iterations are counted
            # locally so this cannot divide by zero even if the returned
            # global_step has not advanced.
            whole_time += time.time() - start_time
            mean_duration = whole_time / (step_idx - init_global_step + 1)

            # ---------------- print loss and duration of training ----------------
            if global_step % self.args.print_freq == 0:
                # loss is a float; '%d' would truncate it (e.g. 0.03 -> 0).
                print('Loss: %.5f, Step: %d / %d (%.3f sec/batch)'
                      % (loss, global_step, self.args.num_iter, mean_duration))

            # ------------- log the loss, PSNR, and lr of training -------------
            if global_step % self.args.log_freq == 0:
                # NOTE(review): in degrade mode ``scale`` is a degradation-list
                # entry, while validation below passes 4 to cal_psnr -- confirm
                # which value is intended here.
                train_psnrs = [cal_psnr(out_img, tar_img, scale)
                               for out_img, tar_img in zip(output, train_tar_imgs)]
                summaries_dict = {
                    'loss': loss,
                    'PSNR': mean_or_zero(train_psnrs),
                    'lr': lr,
                }
                self.logger.write(summaries_dict, global_step, is_train=True,
                                  idx_scale=idx_scale)

            # ---------------------- save the trained model ----------------------
            if global_step % self.args.save_freq == 0:
                self.model.save(self.sess)

            # -------------------- log the PSNR of validation --------------------
            if global_step % self.args.valid_freq == 0:
                # Validate every scale in use. Separate loop names avoid
                # shadowing the training-iteration idx_scale/scale.
                for valid_idx, valid_scale in enumerate(self.data.scale_list):
                    if self.args.degrade:
                        # degradation entries are validated with scale 4
                        valid_scale = 4

                    valid_in_imgs = self.data.dataset[valid_idx][-self.args.num_valid:]
                    # Use the same ``degrade`` flag as the rest of this method
                    # (the original read a separate ``is_degrade`` attribute).
                    if self.args.degrade:
                        # assumes targets are stored len(scale_list) entries
                        # after their degraded inputs -- as indexed here
                        target_idx = valid_idx + len(self.data.scale_list)
                    else:
                        target_idx = -1
                    valid_tar_imgs = self.data.dataset[target_idx][-self.args.num_valid:]

                    # inference on validation images & PSNR per pair
                    valid_psnrs = []
                    for in_img, tar_img in zip(valid_in_imgs, valid_tar_imgs):
                        tar_img = mod_crop(tar_img, valid_scale)
                        out_img = chop_forward(in_img, self.sess, self.model,
                                               scale=valid_scale, shave=10)
                        valid_psnrs.append(cal_psnr(out_img, tar_img, valid_scale))

                    self.logger.write({'PSNR': mean_or_zero(valid_psnrs)},
                                      global_step, is_train=False,
                                      idx_scale=valid_idx)
        print("Training is done!")