Ejemplo n.º 1
0
    def __getitem__(self, idx):
        """Build the LR/HR sample for ``idx``.

        ``idx`` is wrapped modulo the number of listed images. The LR file is
        looked up under ``self.dataroot_lr`` using the HR file name with
        ``x<scale>`` inserted in front of the extension.

        Returns:
            dict: ``{'LR': ..., 'HR': ...}`` holding the results of
            ``self.np2tensor`` for each image.
        """
        name = self.img_list[idx % len(self.img_list)]
        stem, suffix = osp.splitext(name)
        lr_name = ''.join((stem, 'x{}'.format(self.scale), suffix))

        hr = self.load_img(osp.join(self.dataroot_hr, name))
        lr = self.load_img(osp.join(self.dataroot_lr, lr_name))

        # Optional degradation is applied to the LR image only.
        noise = self.noise
        if noise is not None:
            lr = self.add_noise(lr, noise['type'], noise['value'])

        # Single-channel mode: convert and append a trailing channel axis.
        if self.train_Y:
            lr, hr = (rgb2ycbcr(img)[:, :, np.newaxis] for img in (lr, hr))

        # Patch extraction and augmentation happen for the training split only.
        if self.split == 'train':
            lr, hr = self.get_patch(lr, hr, self.ps, self.scale)
            lr, hr = self.augment(lr, hr, self.use_flip, self.use_rot)

        return {'LR': self.np2tensor(lr), 'HR': self.np2tensor(hr)}
Ejemplo n.º 2
0
    def __getitem__(self, index):
        """Return a patch pair cut from the ``index``-th input/target images.

        Returns:
            dict: ``'input'`` and ``'target'`` tensors, plus ``'input_up'``
            (the input patch resized by ``self.sr_factor``) when
            ``self.input_up`` is set.
        """
        lr_patch, hr_patch = get_patch(self.input[index, :, :, :],
                                       self.target[index, :, :, :],
                                       self.patch_size, self.sr_factor)

        if not self.rgb:
            # Keep only channel 0 of the converted image, as an (H, W, 1) array.
            lr_patch = np.expand_dims(utils.rgb2ycbcr(lr_patch)[:, :, 0], 2)
            hr_patch = np.expand_dims(utils.rgb2ycbcr(hr_patch)[:, :, 0], 2)

        want_upsampled = self.input_up
        if want_upsampled:
            # Resize while the patch is still a numpy array.
            lr_resized = imresize(lr_patch, scalar_scale=self.sr_factor)

        sample = {
            'input': utils.np2tensor(lr_patch, self.rgb_range),
            'target': utils.np2tensor(hr_patch, self.rgb_range),
        }
        if want_upsampled:
            sample['input_up'] = utils.np2tensor(lr_resized, self.rgb_range)
        return sample
Ejemplo n.º 3
0
    def comput_PSNR_SSIM(self, pred, gt, shave_border=0):
        """Compute PSNR and SSIM between a prediction and its ground truth.

        Tensors are converted with ``tensor2np`` (using ``self.opt.rgb_range``)
        and cast to float32; other inputs are used as given. Both images are
        cropped by ``shave_border`` pixels on every side. Three-channel pairs
        are compared on channel 0 of ``rgb2ycbcr``; single-channel pairs are
        compared directly.

        Returns:
            tuple: ``(psnr, ssim)`` as returned by ``calc_PSNR`` / ``calc_ssim``.

        Raises:
            ValueError: if the two images are not both 3-channel or both
                1-channel.
        """
        def _to_float_array(img):
            # Only tensors are converted; arrays pass through untouched.
            if isinstance(img, torch.Tensor):
                return tensor2np(img, self.opt.rgb_range).astype(np.float32)
            return img

        pred = _to_float_array(pred)
        gt = _to_float_array(gt)

        # The crop window is derived from the prediction's size for both images.
        height, width = pred.shape[:2]
        rows = slice(shave_border, height - shave_border)
        cols = slice(shave_border, width - shave_border)
        pred = pred[rows, cols]
        gt = gt[rows, cols]

        if pred.shape[2] == 3 and gt.shape[2] == 3:
            pred_y, gt_y = rgb2ycbcr(pred)[:, :, 0], rgb2ycbcr(gt)[:, :, 0]
        elif pred.shape[2] == 1 and gt.shape[2] == 1:
            pred_y, gt_y = pred[:, :, 0], gt[:, :, 0]
        else:
            raise ValueError('Input or output channel is not 1 or 3!')

        return calc_PSNR(pred_y, gt_y), calc_ssim(pred_y, gt_y)
Ejemplo n.º 4
0
    def gen_tfrecords(self, save_dir='DIV2K/tfrecords', tfrecord_num=10):
        """Cut every PNG in ``self.data_dir`` into patches and write TFRecords.

        Each image is converted with ``utils.rgb2ycbcr``; channel 0 is turned
        into uint8 and tiled into ``patch_size`` x ``patch_size`` patches with
        a stride of two thirds of the patch size. Patches are distributed
        round-robin over ``tfrecord_num`` files named ``data<i>.tfrecords``.
        The total patch count is written to ``sample_num.txt``.

        Args:
            save_dir: Output directory; created if missing.
            tfrecord_num: Number of TFRecord shards to write.
        """
        file_num = tfrecord_num
        sample_num = 0
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        fs = []
        try:
            for i in range(file_num):
                fs.append(
                    tf.python_io.TFRecordWriter(
                        os.path.join(save_dir, 'data%d.tfrecords' % i)))

            img_paths = sorted(glob.glob(os.path.join(self.data_dir, '*.png')))
            p = self.patch_size
            step = p // 3 * 2  # neighbouring patches overlap by roughly a third
            for img_path in img_paths:
                print('processing %s' % img_path)
                img = misc.imread(img_path)
                y = utils.img_to_uint8(utils.rgb2ycbcr(img)[:, :, 0])
                height, width = y.shape
                for h in range(0, height - p + 1, step):
                    for w in range(0, width - p + 1, step):
                        gt = y[h:h + p, w:w + p]
                        assert gt.shape[:] == (p, p)
                        assert gt.dtype == np.uint8
                        # tobytes() replaces the deprecated tostring() alias
                        # and yields the same bytes.
                        example = tf.train.Example(features=tf.train.Features(
                            feature={'gt': _bytes_feature(gt.tobytes())}))
                        fs[sample_num % file_num].write(
                            example.SerializeToString())
                        sample_num += 1
            print('example number: %d' % sample_num)
            np.savetxt(os.path.join(save_dir, 'sample_num.txt'),
                       np.asarray([sample_num]), '%d')
        finally:
            # Close every writer that was opened, even if an image failed.
            for f in fs:
                f.close()
Ejemplo n.º 5
0
    def upscale(self, im_l, s):
        """Upscale an LR image by factor ``s`` using ``self.upscale_alg``.

        Only the first channel is passed to ``self.upscale_alg``. For a
        3-channel input the remaining channels come from
        ``utils.imresize`` on the converted image, and the result is
        converted back with ``utils.ycbcr2rgb``.

        Args:
            im_l: LR image, float np array in [0, 255]; (H, W), (H, W, 1)
                or (H, W, 3).
            s: Scale factor forwarded to ``self.upscale_alg`` and
                ``utils.imresize``.

        Returns:
            tuple: ``(im_h, im_h_y)``, both clipped to [0, 255]. ``im_h`` is
            the full output image; ``im_h_y`` is the upscaled first channel.
            For grayscale input both hold the same values.
        """
        im_l = im_l / 255.0

        # A trailing singleton channel is plain grayscale; drop it so the
        # input follows the 2-D path instead of failing on the assignment
        # below.
        if len(im_l.shape) == 3 and im_l.shape[2] == 1:
            im_l = im_l[:, :, 0]

        # Decide once whether colour has to be recovered, so the conversion
        # and the recovery step can never disagree.
        is_color = len(im_l.shape) == 3 and im_l.shape[2] == 3

        if is_color:
            im_l_ycbcr = utils.rgb2ycbcr(im_l)
        else:
            im_l_ycbcr = np.zeros([im_l.shape[0], im_l.shape[1], 3])
            im_l_ycbcr[:, :, 0] = im_l
            im_l_ycbcr[:, :, 1] = im_l
            im_l_ycbcr[:, :, 2] = im_l

        im_l_y = im_l_ycbcr[:, :, 0] * 255  # [16 235]
        im_h_y = self.upscale_alg(im_l_y, s)

        # recover color
        if is_color:
            im_ycbcr = utils.imresize(im_l_ycbcr, s)
            im_ycbcr[:, :, 0] = im_h_y / 255.0  # [16/255 235/255]
            im_h = utils.ycbcr2rgb(im_ycbcr) * 255.0
        else:
            im_h = im_h_y

        im_h = np.clip(im_h, 0, 255)
        im_h_y = np.clip(im_h_y, 0, 255)
        return im_h, im_h_y
Ejemplo n.º 6
0
    def upscale(self, im_l, s):
        """Super-resolve ``im_l`` by factor ``s`` via ``self.upscale_alg``.

        The first channel is upscaled with ``self.upscale_alg``. When the
        input has three channels, the other two are resized with
        ``utils.imresize`` and the image is converted back with
        ``utils.ycbcr2rgb``.

        Args:
            im_l: LR image, float np array in [0, 255]. Accepted layouts are
                (H, W), (H, W, 1) and (H, W, 3).
            s: Upscaling factor.

        Returns:
            tuple: ``(im_h, im_h_y)`` -- the HR image and its upscaled first
            channel, both clipped to [0, 255].
        """
        im_l = im_l / 255.0

        # Treat (H, W, 1) as grayscale; previously this layout raised on the
        # channel assignment below.
        if len(im_l.shape) == 3 and im_l.shape[2] == 1:
            im_l = im_l[:, :, 0]

        # One flag drives both the conversion and the colour recovery, so
        # the two branches stay in step.
        has_color = len(im_l.shape) == 3 and im_l.shape[2] == 3

        if has_color:
            im_l_ycbcr = utils.rgb2ycbcr(im_l)
        else:
            # Replicate the single channel into all three planes.
            im_l_ycbcr = np.zeros([im_l.shape[0], im_l.shape[1], 3])
            for channel in range(3):
                im_l_ycbcr[:, :, channel] = im_l

        im_l_y = im_l_ycbcr[:, :, 0] * 255  # [16 235]
        im_h_y = self.upscale_alg(im_l_y, s)

        # recover color
        if has_color:
            im_ycbcr = utils.imresize(im_l_ycbcr, s)
            im_ycbcr[:, :, 0] = im_h_y / 255.0  # [16/255 235/255]
            im_h = utils.ycbcr2rgb(im_ycbcr) * 255.0
        else:
            im_h = im_h_y

        im_h = np.clip(im_h, 0, 255)
        im_h_y = np.clip(im_h_y, 0, 255)
        return im_h, im_h_y
Ejemplo n.º 7
0
    def gen_generator(self):
        """Yield random ``(lr, bic, gt)`` training triplets from PNG images.

        For every ``*.png`` in ``self.data_dir`` (sorted),
        ``self.one_img_patch_num`` random ``patch_size`` crops are taken.
        Channel 0 of ``utils.rgb2ycbcr`` is scaled by 1/255, randomly flipped
        along each axis, then downscaled and re-upscaled with bicubic
        ``misc.imresize``.

        Yields:
            tuple: ``(lr, bic, gt)`` -- the downscaled patch, its bicubic
            upscaling, and the ground-truth patch.
        """
        patch_size = self.patch_size
        scale = self.scale
        patches_per_image = self.one_img_patch_num
        pattern = os.path.join(self.data_dir, '*.png')

        for img_path in sorted(glob.glob(pattern)):
            img = misc.imread(img_path)
            height, width, _ = img.shape
            for _ in range(patches_per_image):
                # Random top-left corner such that the crop stays inside.
                top = np.random.randint(height - patch_size + 1)
                left = np.random.randint(width - patch_size + 1)
                crop = img[top:top + patch_size, left:left + patch_size]
                gt = np.float32(utils.rgb2ycbcr(crop)[:, :, 0]) / 255.0

                # Two independent draws, uniform in [0, 1), decide the flips.
                flip_rows = np.random.rand() < 0.5
                flip_cols = np.random.rand() < 0.5
                if flip_rows:
                    gt = gt[::-1, :]
                if flip_cols:
                    gt = gt[:, ::-1]

                lr = misc.imresize(gt, 1.0 / scale, 'bicubic', 'F')
                bic = misc.imresize(lr, scale / 1.0, 'bicubic', 'F')
                yield lr, bic, gt
Ejemplo n.º 8
0
    def __getitem__(self, index):
        """Load the ``index``-th test pair and return it as a dict of tensors.

        Both images are read with OpenCV and converted from BGR to RGB; the
        target is first cropped with ``utils.modcrop`` for ``self.sr_factor``.

        Returns:
            dict: always ``'name'``, ``'input'`` and ``'target'``. When
            ``self.rgb`` is false, ``'input'`` is channel 0 of the converted
            image and ``'input_cbcr'`` / ``'input_rgb'`` carry the remaining
            channels and the original RGB input. When ``self.target_down``
            is set, ``'target_down'`` holds the target resized by
            ``1 / sr_factor``.
        """
        entry = self.test_path[index]
        sample = {'name': entry['save']}

        lr_img = cv2.cvtColor(cv2.imread(entry['input']), cv2.COLOR_BGR2RGB)

        hr_img = utils.modcrop(cv2.imread(entry['target']), self.sr_factor)
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

        if not self.rgb:
            # Keep an RGB copy before the input is reduced to one channel.
            rgb_tensor = utils.np2tensor(np.copy(lr_img), self.rgb_range)
            ycbcr = utils.rgb2ycbcr(lr_img)
            lr_img = np.expand_dims(ycbcr[:, :, 0], 2)
            sample['input_cbcr'] = utils.np2tensor(ycbcr[:, :, 1:],
                                                   self.rgb_range)
            sample['input_rgb'] = rgb_tensor

        if self.target_down:
            shrunk = imresize(hr_img, scalar_scale=1 / self.sr_factor)
            sample['target_down'] = utils.np2tensor(shrunk, self.rgb_range)

        sample['input'] = utils.np2tensor(lr_img, self.rgb_range)
        sample['target'] = utils.np2tensor(hr_img, self.rgb_range)
        return sample
Ejemplo n.º 9
0
    def __getitem__(self, index):
        """Load the ``index``-th training pair and return a random patch pair.

        Images come either from ``.npy`` files (``self.npy_reader``) or from
        OpenCV, in which case they are converted from BGR to RGB. The target
        is cropped with ``utils.modcrop`` for ``self.sr_factor`` before
        patches are drawn with ``get_patch``.

        Returns:
            dict: ``'input'`` and ``'target'`` tensors, preceded by
            ``'target_down'`` (target patch resized by ``1 / sr_factor``)
            when ``self.target_down`` is set.
        """
        entry = self.train_path[index]
        sample = {}

        if self.npy_reader:
            lr_img = np.load(entry['input'], allow_pickle=False)
            hr_img = utils.modcrop(
                np.load(entry['target'], allow_pickle=False), self.sr_factor)
        else:
            lr_img = cv2.cvtColor(cv2.imread(entry['input']),
                                  cv2.COLOR_BGR2RGB)
            hr_img = utils.modcrop(cv2.imread(entry['target']),
                                   self.sr_factor)
            hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

        # Disabled experiment: retry up to 10 times until the patch has
        # enough local variance.
        # for i in range(10):
        #     subim_in, subim_tar = get_patch(input_, target_, self.patch_size, self.sr_factor)
        # win_mean = ndimage.uniform_filter(subim_in[:, :, 0], (5, 5))
        # win_sqr_mean = ndimage.uniform_filter(subim_in[:, :, 0]**2, (5, 5))
        # win_var = win_sqr_mean - win_mean**2
        #
        # if np.sum(win_var) / (win_var.shape[0]*win_var.shape[1]) > 30:
        #     break

        lr_patch, hr_patch = get_patch(lr_img, hr_img, self.patch_size,
                                       self.sr_factor)

        if not self.rgb:
            # Reduce both patches to channel 0, kept as (H, W, 1).
            lr_patch = np.expand_dims(utils.rgb2ycbcr(lr_patch)[:, :, 0], 2)
            hr_patch = np.expand_dims(utils.rgb2ycbcr(hr_patch)[:, :, 0], 2)

        if self.target_down:
            shrunk = imresize(hr_patch, scalar_scale=1 / self.sr_factor)
            sample['target_down'] = utils.np2tensor(shrunk, self.rgb_range)

        sample['input'] = utils.np2tensor(lr_patch, self.rgb_range)
        sample['target'] = utils.np2tensor(hr_patch, self.rgb_range)
        return sample
Ejemplo n.º 10
0
    def __getitem__(self, index):
        """Load the ``index``-th ``.npy`` input/target pair as tensors.

        Returns:
            dict: ``'input'`` and ``'target'`` tensors, plus ``'input_up'``
            (input resized by ``self.sr_factor`` and rounded) when
            ``self.input_up`` is set.
        """
        lr_img = np.load(self.input_path[index])
        hr_img = np.load(self.target_path[index])

        if not self.rgb:
            lr_img = utils.rgb2ycbcr(lr_img)
            hr_img = utils.rgb2ycbcr(hr_img)

        want_upsampled = self.input_up
        if want_upsampled:
            # Must be computed from the numpy image, before tensor conversion.
            lr_resized = imresize(lr_img, scalar_scale=self.sr_factor).round()

        sample = {
            'input': utils.np2tensor(lr_img, self.rgb_range),
            'target': utils.np2tensor(hr_img, self.rgb_range),
        }
        if want_upsampled:
            sample['input_up'] = utils.np2tensor(lr_resized, self.rgb_range)
        return sample
Ejemplo n.º 11
0
    def evaluate(self):
        """Run the network over ``self.test_loader`` and return the mean PSNR.

        The network is put in eval mode for the duration and switched back to
        train mode afterwards. If the LR batch is smaller than the HR batch
        it is first upsampled with nearest-neighbour interpolation. Every
        image in a batch is clamped to [0, 255], rounded, converted to a
        uint8 HWC array, cropped by ``opt.crop`` pixels per side and scored
        with ``utils.calculate_psnr`` (after ``utils.rgb2ycbcr`` when
        ``opt.eval_y_only`` is set).

        When ``opt.save_result`` is set, each batch is saved under
        ``<opt.save_root>/<opt.dataset>/<batch number>.png``.

        Returns:
            The PSNR sum divided by ``len(self.test_loader.dataset)``.
        """
        opt = self.opt
        self.net.eval()

        if opt.save_result:
            save_root = os.path.join(opt.save_root, opt.dataset)
            # Make sure the output directory exists before the first save.
            os.makedirs(save_root, exist_ok=True)

        psnr = 0
        for batch_idx, inputs in enumerate(self.test_loader):
            HR = inputs[0].to(self.dev)
            LR = inputs[1].to(self.dev)
            ORI_LR = LR.clone().detach()

            # match the resolution of (LR, HR) due to CutBlur
            if HR.size() != LR.size():
                scale = HR.size(2) // LR.size(2)
                LR = F.interpolate(LR, scale_factor=scale, mode="nearest")

            SR = self.net(LR)
            if isinstance(SR, (list, tuple)):
                SR = SR[-1]

            SR = SR.detach()
            # iter over batch; a separate index keeps the batch counter intact
            for item_idx in range(HR.size(0)):
                hr = HR[item_idx].clamp(0, 255).round().cpu().byte()
                hr = hr.permute(1, 2, 0).numpy()
                sr = SR[item_idx].clamp(0, 255).round().cpu().byte()
                sr = sr.permute(1, 2, 0).numpy()

                # A crop of 0 must keep the whole image; x[0:-0] would be
                # empty.
                if opt.crop > 0:
                    hr = hr[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
                    sr = sr[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
                if opt.eval_y_only:
                    hr = utils.rgb2ycbcr(hr)
                    sr = utils.rgb2ycbcr(sr)
                psnr += utils.calculate_psnr(hr, sr)

            if opt.save_result:
                # Numbered by batch, so successive batches do not overwrite
                # each other.
                save_path = os.path.join(save_root,
                                         "{:04d}.png".format(batch_idx + 1))
                utils.save_batch_hr_lr(HR, SR, ORI_LR, save_path)

        self.net.train()

        return psnr / len(self.test_loader.dataset)
Ejemplo n.º 12
0
    def __getitem__(self, index):
        """Read the ``index``-th image and synthesise its LR counterpart.

        The target is the image as loaded by ``cv2.imread``; the input is the
        target resized by ``1 / self.sr_factor``.

        Returns:
            dict: ``'input'`` and ``'target'`` tensors, plus ``'input_up'``
            (input resized back by ``self.sr_factor`` and rounded) when
            ``self.input_up`` is set.
        """
        # NOTE(review): cv2.imread yields BGR and no BGR->RGB conversion is
        # done here before utils.rgb2ycbcr -- confirm this is intended.
        hr_img = cv2.imread(self.images_path[index])
        lr_img = imresize(hr_img, scalar_scale=1 / self.sr_factor)

        if not self.rgb:
            lr_img = utils.rgb2ycbcr(lr_img)
            hr_img = utils.rgb2ycbcr(hr_img)

        want_upsampled = self.input_up
        if want_upsampled:
            # Resize while the input is still a numpy array.
            lr_resized = imresize(lr_img, scalar_scale=self.sr_factor).round()

        sample = {
            'input': utils.np2tensor(lr_img, self.rgb_range),
            'target': utils.np2tensor(hr_img, self.rgb_range),
        }
        if want_upsampled:
            sample['input_up'] = utils.np2tensor(lr_resized, self.rgb_range)
        return sample
Ejemplo n.º 13
0
    def __getitem__(self, idx):
        """Build the LR/HR sample for ``idx``.

        HR and LR images share the same file name and are read from
        ``self.dataroot_hr`` and ``self.dataroot_lr`` respectively.

        Returns:
            dict: ``{'LR': ..., 'HR': ...}`` holding the results of
            ``self.np2tensor`` for each image.
        """
        name = self.img_list[idx]
        hr = self.load_img(osp.join(self.dataroot_hr, name))
        lr = self.load_img(osp.join(self.dataroot_lr, name))

        # Optional degradation is applied to the LR image only.
        noise = self.noise
        if noise is not None:
            lr = self.add_noise(lr, noise['type'], noise['value'])

        if self.train_Y:
            lr, hr = (rgb2ycbcr(img) for img in (lr, hr))

        # Patch extraction and augmentation happen for the training split only.
        if self.split == 'train':
            lr, hr = self.get_patch(lr, hr, self.ps, self.scale)
            lr, hr = self.augment(lr, hr, self.use_flip, self.use_rot)

        return {'LR': self.np2tensor(lr), 'HR': self.np2tensor(hr)}
Ejemplo n.º 14
0
    def evaluate(self):
        """Run the network over ``self.test_loader`` and return the mean PSNR.

        The network is put in eval mode for the duration and switched back to
        train mode afterwards. If the LR batch is smaller than the HR batch
        it is first upsampled with nearest-neighbour interpolation. Only the
        first image of each batch is scored: it is clamped to [0, 255],
        rounded, converted to a uint8 HWC array, optionally saved, cropped by
        ``opt.crop`` pixels per side and passed to ``utils.calculate_psnr``
        (after ``utils.rgb2ycbcr`` when ``opt.eval_y_only`` is set).

        Returns:
            The PSNR sum divided by the number of batches in the loader.
        """
        opt = self.opt
        self.net.eval()

        if opt.save_result:
            save_root = os.path.join(opt.save_root, opt.dataset)
            os.makedirs(save_root, exist_ok=True)

        psnr = 0
        for i, inputs in enumerate(self.test_loader):
            HR = inputs[0].to(self.dev)
            LR = inputs[1].to(self.dev)

            # match the resolution of (LR, HR) due to CutBlur
            if HR.size() != LR.size():
                scale = HR.size(2) // LR.size(2)
                LR = F.interpolate(LR, scale_factor=scale, mode="nearest")

            SR = self.net(LR).detach()
            HR = HR[0].clamp(0, 255).round().cpu().byte().permute(1, 2,
                                                                  0).numpy()
            SR = SR[0].clamp(0, 255).round().cpu().byte().permute(1, 2,
                                                                  0).numpy()

            if opt.save_result:
                # The uncropped output is what gets written to disk.
                save_path = os.path.join(save_root, "{:04d}.png".format(i + 1))
                io.imsave(save_path, SR)

            # A crop of 0 must keep the whole image; x[0:-0] would be empty.
            if opt.crop > 0:
                HR = HR[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
                SR = SR[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
            if opt.eval_y_only:
                HR = utils.rgb2ycbcr(HR)
                SR = utils.rgb2ycbcr(SR)
            psnr += utils.calculate_psnr(HR, SR)

        self.net.train()

        return psnr / len(self.test_loader)
Ejemplo n.º 15
0
def evalimg(im_h_y, im_gt, shave=0):
    """Compute RMSE and PSNR between a predicted image and the ground truth.

    A 3-D ground truth is converted with ``utils.rgb2ycbcr`` and its channel
    0 is used; a 2-D ground truth is used as is. Both images are quantised to
    8 bits before differencing.

    Args:
        im_h_y: Predicted image, values nominally in [0, 255].
        im_gt: Ground truth, either 2-D or a 3-D colour image in [0, 255].
        shave: Border width removed from the difference image with
            ``utils.shave`` before scoring; 0 disables it.

    Returns:
        dict: ``{'rmse': ..., 'psnr': ...}`` with PSNR based on a peak of 255.
    """
    if len(im_gt.shape) == 3:
        im_gt_ycbcr = utils.rgb2ycbcr(im_gt / 255.0) * 255.0
        im_gt_y = im_gt_ycbcr[:, :, 0]
    else:
        im_gt_y = im_gt

    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (e.g. 300 -> 44) instead of saturating. In-range values are
    # still truncated exactly as before.
    pred_q = np.clip(im_h_y, 0, 255).astype(np.uint8).astype(np.float32)
    gt_q = np.clip(im_gt_y, 0, 255).astype(np.uint8).astype(np.float32)
    diff = pred_q - gt_q
    if shave > 0:
        diff = utils.shave(diff, [shave, shave])

    res = {}
    res['rmse'] = np.sqrt((diff**2).mean())
    res['psnr'] = 20*np.log10(255.0/res['rmse'])
    return res
Ejemplo n.º 16
0
def evalimg(im_h_y, im_gt, shave=0):
    """Score a predicted image against the ground truth (RMSE and PSNR).

    For a 3-D ground truth, channel 0 of ``utils.rgb2ycbcr`` is compared;
    a 2-D ground truth is compared directly. Both sides are quantised to
    8 bits first.

    Args:
        im_h_y: Predicted image, values nominally in [0, 255].
        im_gt: Ground truth, 2-D or a 3-D colour image in [0, 255].
        shave: Border width stripped from the difference image via
            ``utils.shave``; nothing is stripped when 0.

    Returns:
        dict: ``'rmse'`` and ``'psnr'`` (peak value 255).
    """
    if len(im_gt.shape) == 3:
        im_gt_ycbcr = utils.rgb2ycbcr(im_gt / 255.0) * 255.0
        im_gt_y = im_gt_ycbcr[:, :, 0]
    else:
        im_gt_y = im_gt

    def _quantise(img):
        # Saturate to [0, 255] first; a bare float -> uint8 cast wraps
        # out-of-range values instead of clamping them. In-range values are
        # truncated, as before.
        return np.clip(img, 0, 255).astype(np.uint8).astype(np.float32)

    diff = _quantise(im_h_y) - _quantise(im_gt_y)
    if shave > 0:
        diff = utils.shave(diff, [shave, shave])

    res = {}
    res['rmse'] = np.sqrt((diff**2).mean())
    res['psnr'] = 20 * np.log10(255.0 / res['rmse'])
    return res
Ejemplo n.º 17
0
def evalimg(im_h_y, im_gt, shave=0):
    """Score a predicted image against the ground truth (RMSE and PSNR).

    For a 3-D ground truth, channel 0 of ``utils.rgb2ycbcr`` is compared;
    a 2-D ground truth is compared directly. Both sides are clipped to
    [0, 255] and rounded to the nearest integer before differencing.

    Args:
        im_h_y: Predicted image.
        im_gt: Ground truth, 2-D or a 3-D colour image in [0, 255].
        shave: Border width stripped from the difference image via
            ``utils.shave``; nothing is stripped when 0.

    Returns:
        dict: ``'rmse'`` and ``'psnr'`` (peak value 255).
    """
    im_gt_y = im_gt
    if im_gt.ndim == 3:
        im_gt_y = (utils.rgb2ycbcr(im_gt / 255.0) * 255.0)[:, :, 0]

    def _to_8bit_levels(img):
        # Saturate, then round to integer levels (the result stays float).
        return np.rint(np.clip(img, 0, 255))

    diff = _to_8bit_levels(im_h_y) - _to_8bit_levels(im_gt_y)
    if shave > 0:
        diff = utils.shave(diff, [shave, shave])

    rmse = np.sqrt(np.mean(diff**2))
    return {'rmse': rmse, 'psnr': 20 * np.log10(255.0 / rmse)}
Ejemplo n.º 18
0
    def slice_reconstruction(size, slice, ang_tar):
        """Reconstruct ``slice`` with the network and resize it.

        Runs the graph output ``y_out`` on the input through the
        surrounding ``sess`` and resizes the result bicubically to
        ``[ang_tar, size]``. Which channels go through the network depends
        on the module-level ``FLAG_RGB`` switch. The returned array is
        clamped to [0, 1].

        Args:
            size: Second entry of the bicubic resize target.
            slice: 4-D array with channels on the last axis (it is transposed
                with a 4-axis permutation and indexed ``[..., 0:1]`` etc.).
                NOTE(review): the meaning of the first three axes is not
                visible from here -- confirm against the caller.
            ang_tar: First entry of the bicubic resize target.

        Returns:
            The reconstructed array, clamped to [0, 1].
        """
        # ---------------- Model -------------------- #
        # NOTE: ``slice_y`` is written to module scope as a side effect.
        global slice_y
        with sess.as_default():
            if FLAG_RGB:
                # The colour conversion is disabled here, so the three
                # channels are fed to the network exactly as passed in,
                # despite the y/cb/cr names below.
                # slice_ycbcr = utils.rgb2ycbcr(slice)
                # Swap the first two axes and add a leading batch axis.
                slice = np.transpose(slice, (1, 0, 2, 3))
                slice = np.expand_dims(slice, axis=0)

                # Split into three single-channel inputs.
                slice_y = slice[:, :, :, :, 0:1]
                slice_cb = slice[:, :, :, :, 1:2]
                slice_cr = slice[:, :, :, :, 2:3]

                # Each channel goes through the same network separately.
                slice_y = sess.run(y_out, feed_dict={x: slice_y})
                slice_cb = sess.run(y_out, feed_dict={x: slice_cb})
                slice_cr = sess.run(y_out, feed_dict={x: slice_cr})

                # Re-stack the channels, drop the batch axis and undo the
                # axis swap before resizing.
                slice_ycbcr = np.concatenate((slice_y, slice_cb, slice_cr),
                                             axis=-1)
                slice_ycbcr = np.transpose(slice_ycbcr[0, :, :, :, :],
                                           (1, 0, 2, 3))
                slice_ycbcr = tf.convert_to_tensor(slice_ycbcr)
                slice_ycbcr = tf.image.resize_bicubic(slice_ycbcr,
                                                      [ang_tar, size])
                slice = sess.run(slice_ycbcr)
                # slice = utils.ycbcr2rgb(slice_ycbcr)
            else:
                # Only channel 0 of the converted image goes through the
                # network; the full converted image is resized bicubically.
                slice_ycbcr = utils.rgb2ycbcr(slice)
                slice_y = np.transpose(slice_ycbcr[:, :, :, 0:1], (1, 0, 2, 3))

                slice_ycbcr = tf.convert_to_tensor(slice_ycbcr)
                slice_ycbcr = tf.image.resize_bicubic(slice_ycbcr,
                                                      [ang_tar, size])
                slice_ycbcr = sess.run(slice_ycbcr)

                slice_y = np.expand_dims(slice_y, axis=0)
                slice_y = sess.run(y_out, feed_dict={x: slice_y})
                slice_y = tf.convert_to_tensor(
                    np.transpose(slice_y[0], (1, 0, 2, 3)))
                slice_y = tf.image.resize_bicubic(slice_y, [ang_tar, size])
                # Overwrite channel 0 of the resized image with the network
                # output, then convert back to RGB.
                slice_ycbcr[:, :, :, 0:1] = sess.run(slice_y)
                slice = utils.ycbcr2rgb(slice_ycbcr)
            # Clamp the result to [0, 1].
            slice = np.minimum(np.maximum(slice, 0), 1)
        return slice
Ejemplo n.º 19
0
def evaluate(model, update_step, writer, bucket, engine):
    """Evaluate ``model`` on Set14 and Set5 and record the results.

    For each benchmark the LMDB-backed ``ValDataset`` is iterated with batch
    size 1. The first sample of each set is rendered as a GT / SR / restored
    LQ grid, saved to ``images/tmp_image.png`` and uploaded; for Set5 the
    grid is also sent to ``writer``. PSNR and SSIM are accumulated on the
    border-cropped images and on their ``rgb2ycbcr`` conversion, then
    averaged over the dataset.

    The eight averages (Set14 first, then Set5) are inserted into the
    ``odesr01_04_val`` table through ``engine``; Set5 averages are also
    logged as scalars. The model is switched to eval mode on entry and back
    to train mode at the end.

    Args:
        model: Network to evaluate; called on the LQ batch on ``cuda:0``.
        update_step: Step used for the writer and the upload path.
        writer: Object providing ``add_image`` / ``add_scalar``.
        bucket: Destination handle passed to ``upload_to_cloud``.
        engine: Object providing ``execute`` for the SQL insert.
    """
    device = torch.device('cuda:0')
    model.eval()
    unnorm = UnNormalize(0.5, 0.5)
    scalar_tags = ('psrn_rgb', 'psrn_y', 'ssim_rgb', 'ssim_y')
    metrics_list = []

    for eval_name in ('Set14', 'Set5'):
        eval_path = os.path.join(args.eval_path, eval_name)
        lmdb_prefix = os.path.join(eval_path, eval_name)

        eval_set = ValDataset(lmdb_prefix + '.lmdb',
                              lmdb_prefix + '_LQ.lmdb',
                              lmdb_prefix + '_LQ_restored.lmdb',
                              args.scale)
        eval_loader = DataLoader(
            eval_set, batch_size=1, shuffle=False, num_workers=4)

        # Running sums, in the same order as ``scalar_tags``.
        totals = [0.0, 0.0, 0.0, 0.0]

        for step, batch in enumerate(eval_loader):
            img_HQ = batch['img_GT']
            img_LQ = batch['img_LQ'].to(device)
            img_LQ_r = batch['img_LQ_r']

            with torch.no_grad():
                # model output is in [-1, 1]; unnorm maps it to [0, 1]
                img_SR = unnorm(model(img_LQ))

            if step == 0:
                # Preview grid for the first sample of each benchmark.
                preview = torch.cat(
                    [img_HQ, img_SR.detach().cpu(), img_LQ_r], dim=0)
                grid = vutils.make_grid(preview, nrow=3, normalize=False)
                T.ToPILImage()(grid).save('images/tmp_image.png')
                upload_to_cloud(
                    bucket, 'images/tmp_image.png',
                    'odesr01_04/image_progress/{}/gen_step_{}'.format(
                        eval_name, update_step * args.update_freq))
                if eval_name == 'Set5':
                    writer.add_image('Set5', grid, update_step)

            # Strip a border of ``args.scale`` pixels before scoring.
            border = args.scale
            hq_rgb = img_HQ[0].permute(2, 1, 0).cpu().numpy()
            hq_rgb = hq_rgb[border:-border, border:-border, :]
            sr_rgb = img_SR[0].permute(2, 1, 0).detach().cpu().numpy()
            sr_rgb = sr_rgb[border:-border, border:-border, :]
            hq_y = rgb2ycbcr(hq_rgb)
            sr_y = rgb2ycbcr(sr_rgb)

            totals[0] += calculate_psnr(hq_rgb * 255, sr_rgb * 255)
            totals[1] += calculate_psnr(hq_y * 255, sr_y * 255)
            totals[2] += calculate_ssim(hq_rgb * 255, sr_rgb * 255)
            totals[3] += calculate_ssim(hq_y * 255, sr_y * 255)

        averages = [total / len(eval_loader.dataset) for total in totals]
        metrics_list.extend(averages)

        if eval_name == 'Set5':
            for tag, value in zip(scalar_tags, averages):
                writer.add_scalar(tag, value, update_step)

    query = '''
        INSERT INTO odesr01_04_val
            (set14_psnr_rgb, set14_psnr_y, set14_ssim_rgb, set14_ssim_y,
            set5_psnr_rgb, set5_psnr_y, set5_ssim_rgb, set5_ssim_y)
        VALUES (%f, %f, %f, %f, %f, %f, %f, %f)
    ''' % tuple(metrics_list)
    engine.execute(query)
    model.train()
Ejemplo n.º 20
0
if not MODEL == 'BICUBIC':
    saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
if not MODEL == 'BICUBIC':
    saver.restore(sess, MODEL_CKPT_PATH)

fs = glob.glob(os.path.join(TEST_DIR, '*.bmp'))
psnrs = []
for f in fs:
    img = misc.imread(f)
    lr_img = misc.imresize(img, 1.0 / SCALE, 'bicubic')
    lr_y = utils.rgb2ycbcr(lr_img)[:, :, :1]
    lr_y = np.expand_dims(lr_y, 0).astype(np.float32) / 255.0
    start = time.clock()
    res_y = sess.run(res, feed_dict={lr: lr_y})
    end = time.clock()
    res_y = np.clip(res_y, 0, 1)[0] * 255.0
    bic_img = misc.imresize(lr_img, SCALE / 1.0, 'bicubic')

    bic_ycbcr = utils.rgb2ycbcr(bic_img)
    bic_ycbcr[:, :, :1] = res_y
    res_img = utils.img_to_uint8(utils.ycbcr2rgb(bic_ycbcr))
    img_name = f.split(os.sep)[-1]
    misc.imsave(os.path.join(OUTPUT_DIR, img_name), res_img)

    gt_y = utils.rgb2ycbcr(img)[:, :, :1]
    psnr = utils.psnr(res_y[SCALE:-SCALE, SCALE:-SCALE], gt_y[SCALE:-SCALE,