Ejemplo n.º 1
0
    def upscale(self, im_l, s):
        """Upscale ``im_l`` by factor ``s``, super-resolving only the luma channel.

        The Y channel is passed to ``self.upscale_alg``. For colour input the
        full YCbCr image is resized with ``utils.imresize``, its Y channel is
        replaced by the super-resolved one, and the result is converted back
        to RGB.

        Args:
            im_l: low-resolution image, float numpy array with values in
                [0, 255]. Shape (H, W) or (H, W, 1) for grayscale, or
                (H, W, 3) for RGB.
            s: scale factor, forwarded to ``self.upscale_alg`` and
                ``utils.imresize``.

        Returns:
            Tuple ``(im_h, im_h_y)``: the high-resolution image (RGB for colour
            input, equal to ``im_h_y`` for grayscale) and the super-resolved
            Y channel, both clipped to [0, 255].

        Raises:
            ValueError: if ``im_l`` is not an (H, W), (H, W, 1) or (H, W, 3)
                array.
        """
        im_l = np.asarray(im_l) / 255.0
        # A single trailing channel is treated as plain grayscale.
        if im_l.ndim == 3 and im_l.shape[2] == 1:
            im_l = im_l[:, :, 0]
        # One flag drives both the colour conversion and the colour recovery
        # below, so the two branches cannot disagree about the input type.
        is_color = im_l.ndim == 3 and im_l.shape[2] == 3
        if not is_color and im_l.ndim != 2:
            raise ValueError(
                "upscale expects an (H, W), (H, W, 1) or (H, W, 3) image, "
                "got shape %s" % (im_l.shape,))

        if is_color:
            im_l_ycbcr = utils.rgb2ycbcr(im_l)
            # Luma; presumably in [16, 235] if utils.rgb2ycbcr follows the
            # studio-swing (MATLAB) convention.
            im_l_y = im_l_ycbcr[:, :, 0] * 255
        else:
            # Grayscale: the intensity itself serves as the luma channel.
            im_l_y = im_l.astype(np.float64) * 255
        im_h_y = self.upscale_alg(im_l_y, s)

        # Recover colour: resize the full YCbCr image and swap in the new Y.
        if is_color:
            im_ycbcr = utils.imresize(im_l_ycbcr, s)
            im_ycbcr[:, :, 0] = im_h_y / 255.0
            im_h = utils.ycbcr2rgb(im_ycbcr) * 255.0
        else:
            im_h = im_h_y

        im_h = np.clip(im_h, 0, 255)
        im_h_y = np.clip(im_h_y, 0, 255)
        return im_h, im_h_y
Ejemplo n.º 2
0
    def upscale(self, im_l, s):
        """Upscale ``im_l`` by factor ``s``, super-resolving only the luma channel.

        The Y channel is passed to ``self.upscale_alg``. For colour input the
        full YCbCr image is resized with ``utils.imresize``, its Y channel is
        replaced by the super-resolved one, and the result is converted back
        to RGB.

        Args:
            im_l: low-resolution image, float numpy array with values in
                [0, 255]. Shape (H, W) or (H, W, 1) for grayscale, or
                (H, W, 3) for RGB.
            s: scale factor, forwarded to ``self.upscale_alg`` and
                ``utils.imresize``.

        Returns:
            Tuple ``(im_h, im_h_y)``: the high-resolution image (RGB for colour
            input, equal to ``im_h_y`` for grayscale) and the super-resolved
            Y channel, both clipped to [0, 255].

        Raises:
            ValueError: if ``im_l`` is not an (H, W), (H, W, 1) or (H, W, 3)
                array.
        """
        im_l = np.asarray(im_l) / 255.0
        # A single trailing channel is treated as plain grayscale.
        if im_l.ndim == 3 and im_l.shape[2] == 1:
            im_l = im_l[:, :, 0]
        # One flag drives both the colour conversion and the colour recovery
        # below, so the two branches cannot disagree about the input type.
        is_color = im_l.ndim == 3 and im_l.shape[2] == 3
        if not is_color and im_l.ndim != 2:
            raise ValueError(
                "upscale expects an (H, W), (H, W, 1) or (H, W, 3) image, "
                "got shape %s" % (im_l.shape,))

        if is_color:
            im_l_ycbcr = utils.rgb2ycbcr(im_l)
            # Luma; presumably in [16, 235] if utils.rgb2ycbcr follows the
            # studio-swing (MATLAB) convention.
            im_l_y = im_l_ycbcr[:, :, 0] * 255
        else:
            # Grayscale: the intensity itself serves as the luma channel.
            im_l_y = im_l.astype(np.float64) * 255
        im_h_y = self.upscale_alg(im_l_y, s)

        # Recover colour: resize the full YCbCr image and swap in the new Y.
        if is_color:
            im_ycbcr = utils.imresize(im_l_ycbcr, s)
            im_ycbcr[:, :, 0] = im_h_y / 255.0
            im_h = utils.ycbcr2rgb(im_ycbcr) * 255.0
        else:
            im_h = im_h_y

        im_h = np.clip(im_h, 0, 255)
        im_h_y = np.clip(im_h_y, 0, 255)
        return im_h, im_h_y
Ejemplo n.º 3
0
    def eval(self):
        """Run the model on the cached evaluation batch and convert results to RGB.

        The model's (quantized) prediction is written into channel 0 of a copy
        of ``self.eval_input_bicu``. That tensor, ``self.eval_input`` and
        ``self.eval_target`` are each reduced to their first batch element and
        converted with ``ycbcr2rgb`` into channel-first tensors.

        Side effects: ``self.output``, ``self.eval_input`` and
        ``self.eval_target`` are replaced by the converted tensors. A second
        call without reloading the inputs would therefore index into the
        already-converted tensors.

        Returns:
            dict with keys 'input', 'output' and 'target'.
        """
        def first_item_to_rgb(batch):
            # Item 0 of a 4-D batch (channels on dim 1) -> HWC numpy ->
            # ycbcr2rgb -> CHW tensor.
            hwc = batch[0].cpu().permute(1, 2, 0).numpy()
            return torch.from_numpy(ycbcr2rgb(hwc)).permute(2, 0, 1)

        with torch.no_grad():
            prediction = quantize(
                self.model(self.eval_input_y, self.eval_input_bicu_y),
                self.opt.rgb_range,
            )

        merged = self.eval_input_bicu.data.clone()
        merged[:, 0, :, :] = prediction
        self.output = first_item_to_rgb(merged)
        self.eval_input = first_item_to_rgb(self.eval_input)
        self.eval_target = first_item_to_rgb(self.eval_target)

        return {
            'input': self.eval_input,
            'output': self.output,
            'target': self.eval_target
        }
Ejemplo n.º 4
0
    def slice_reconstruction(size, slice, ang_tar):
        """Reconstruct one slice with the network output ``y_out`` plus bicubic resizing.

        Relies on names from the enclosing/module scope: ``sess``, ``FLAG_RGB``,
        ``x`` (fed via ``feed_dict``), ``y_out`` (the evaluated network output),
        ``np``, ``tf`` and ``utils``.

        Args:
            size: second entry of the target size passed to
                ``tf.image.resize_bicubic``.
            slice: 4-D array with channels on the last axis; its first two
                axes are swapped before being fed to the network. (This name
                shadows the builtin ``slice`` inside the function.)
            ang_tar: first entry of the target size passed to
                ``tf.image.resize_bicubic`` (presumably the target angular
                resolution -- confirm with caller).

        Returns:
            The reconstructed slice as a numpy array, clamped to [0, 1].
        """
        # ---------------- Model -------------------- #
        # Side effect: rebinds the module-level ``slice_y``.
        # NOTE(review): in the non-RGB branch its final value is an
        # unevaluated TF tensor (the resize op), not a numpy array.
        global slice_y
        # NOTE(review): every call creates new graph ops (convert_to_tensor +
        # resize_bicubic). In TF1 graph mode this presumably grows the default
        # graph on each call, since converting a numpy array embeds it as a
        # constant; consider building a placeholder-based resize op once.
        with sess.as_default():
            if FLAG_RGB:
                # No YCbCr conversion in this branch (see the commented-out
                # line below): each of the three input channels is run through
                # the network separately, so slice_cb / slice_cr actually hold
                # channels 1 and 2 of the input as given.
                # slice_ycbcr = utils.rgb2ycbcr(slice)
                slice = np.transpose(slice, (1, 0, 2, 3))
                slice = np.expand_dims(slice, axis=0)  # add leading batch axis

                slice_y = slice[:, :, :, :, 0:1]
                slice_cb = slice[:, :, :, :, 1:2]
                slice_cr = slice[:, :, :, :, 2:3]

                slice_y = sess.run(y_out, feed_dict={x: slice_y})
                slice_cb = sess.run(y_out, feed_dict={x: slice_cb})
                slice_cr = sess.run(y_out, feed_dict={x: slice_cr})

                # Re-stack the channels, drop the batch axis, undo the axis
                # swap, then bicubic-resize to [ang_tar, size].
                slice_ycbcr = np.concatenate((slice_y, slice_cb, slice_cr),
                                             axis=-1)
                slice_ycbcr = np.transpose(slice_ycbcr[0, :, :, :, :],
                                           (1, 0, 2, 3))
                slice_ycbcr = tf.convert_to_tensor(slice_ycbcr)
                slice_ycbcr = tf.image.resize_bicubic(slice_ycbcr,
                                                      [ang_tar, size])
                slice = sess.run(slice_ycbcr)
                # Result stays in the input's channel space (no conversion back).
                # slice = utils.ycbcr2rgb(slice_ycbcr)
            else:
                # Convert to YCbCr; only the Y channel goes through the network.
                slice_ycbcr = utils.rgb2ycbcr(slice)
                slice_y = np.transpose(slice_ycbcr[:, :, :, 0:1], (1, 0, 2, 3))

                # Bicubic-resize the full YCbCr slice; its chroma is kept.
                slice_ycbcr = tf.convert_to_tensor(slice_ycbcr)
                slice_ycbcr = tf.image.resize_bicubic(slice_ycbcr,
                                                      [ang_tar, size])
                slice_ycbcr = sess.run(slice_ycbcr)

                # Run Y through the network, undo the axis swap, resize to the
                # same target size and overwrite the resized slice's Y channel.
                slice_y = np.expand_dims(slice_y, axis=0)
                slice_y = sess.run(y_out, feed_dict={x: slice_y})
                slice_y = tf.convert_to_tensor(
                    np.transpose(slice_y[0], (1, 0, 2, 3)))
                slice_y = tf.image.resize_bicubic(slice_y, [ang_tar, size])
                slice_ycbcr[:, :, :, 0:1] = sess.run(slice_y)
                slice = utils.ycbcr2rgb(slice_ycbcr)
            # Clamp to [0, 1].
            slice = np.minimum(np.maximum(slice, 0), 1)
        return slice
Ejemplo n.º 5
0
# Evaluate the super-resolution model on every .bmp in TEST_DIR: downscale each
# image by SCALE, super-resolve its Y channel, write the colour result to
# OUTPUT_DIR and report the Y-channel PSNR (ignoring a SCALE-pixel border).
# NOTE(review): scipy.misc.imread / imresize / imsave were removed in
# SciPy 1.2/1.3; this script needs an older SciPy (with Pillow) to run.
sess.run(tf.global_variables_initializer())
if MODEL != 'BICUBIC':
    saver.restore(sess, MODEL_CKPT_PATH)

os.makedirs(OUTPUT_DIR, exist_ok=True)
fs = sorted(glob.glob(os.path.join(TEST_DIR, '*.bmp')))
psnrs = []
for f in fs:
    img = misc.imread(f)
    # Crop to a multiple of SCALE so that the upscaled result (assumed to be
    # exactly SCALE x the LR size) matches the ground-truth size for PSNR.
    img_h, img_w = img.shape[:2]
    img = img[:img_h - img_h % SCALE, :img_w - img_w % SCALE]
    lr_img = misc.imresize(img, 1.0 / SCALE, 'bicubic')
    lr_y = utils.rgb2ycbcr(lr_img)[:, :, :1]
    lr_y = np.expand_dims(lr_y, 0).astype(np.float32) / 255.0
    # time.clock() was removed in Python 3.8; perf_counter is the
    # recommended timer for measuring elapsed time.
    start = time.perf_counter()
    res_y = sess.run(res, feed_dict={lr: lr_y})
    end = time.perf_counter()
    res_y = np.clip(res_y, 0, 1)[0] * 255.0
    bic_img = misc.imresize(lr_img, SCALE / 1.0, 'bicubic')

    # Colour output: chroma from the bicubic image, luma from the network.
    bic_ycbcr = utils.rgb2ycbcr(bic_img)
    bic_ycbcr[:, :, :1] = res_y
    res_img = utils.img_to_uint8(utils.ycbcr2rgb(bic_ycbcr))
    img_name = os.path.basename(f)
    misc.imsave(os.path.join(OUTPUT_DIR, img_name), res_img)

    # PSNR on the Y channel, excluding a SCALE-pixel border.
    gt_y = utils.rgb2ycbcr(img)[:, :, :1]
    psnr = utils.psnr(res_y[SCALE:-SCALE, SCALE:-SCALE],
                      gt_y[SCALE:-SCALE, SCALE:-SCALE])
    psnrs.append(psnr)
    print(img_name, 'PSNR:', psnr, 'time:', end - start)

if psnrs:
    print('AVG PSNR:', np.mean(psnrs))
else:
    print('No .bmp files found in', TEST_DIR)