Example #1
0
def evaluate(calculate_lr_img_list,
             calculate_hr_img_list,
             pb_path,
             save_path,
             save=False):
    """Run a frozen SR graph over LR images and score it on the Y channel.

    Each LR image is fed as its Y plane (from ``sc.rgb2ypbpr``) together with
    the PbPr chroma of a ``scale``-times bicubic upscale. The network output
    is compared with the matching HR image on the Y channel of
    ``sc.rgb2ycbcr`` via ``calculate_metrics``.

    :param calculate_lr_img_list: paths of the low-resolution inputs.
    :param calculate_hr_img_list: paths of the HR references; entry ``i``
        must correspond to ``calculate_lr_img_list[i]``.
    :param pb_path: path to the frozen ``GraphDef`` (.pb) file.
    :param save_path: directory for SR outputs when ``save`` is true.
    :param save: write each SR output (file name from the HR name with
        'HR' replaced by 'SR') into ``save_path``.
    :return: ``(avg_psnr, avg_ssim)`` over all image pairs.
    :raises ValueError: if ``calculate_lr_img_list`` is empty.
    """
    if len(calculate_lr_img_list) == 0:
        # Fail fast instead of dividing by zero after loading the graph.
        raise ValueError('calculate_lr_img_list is empty')

    calculate_hr_imgs = [
        scipy.misc.imread(p, mode='RGB') for p in calculate_hr_img_list
    ]
    calculate_lr_imgs = [
        scipy.misc.imread(p, mode='RGB') for p in calculate_lr_img_list
    ]
    if save:
        # Create the output directory once rather than once per image.
        exists_or_mkdir(save_path)

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            y_image = sess.graph.get_tensor_by_name("input_image_evaluate_y:0")
            pbpr_image = sess.graph.get_tensor_by_name(
                "input_image_evaluate_pbpr:0")
            output_tensor = sess.graph.get_tensor_by_name(
                'test_sr_evaluator_i1_b0_g/target:0')
            sess.run(tf.global_variables_initializer())
            metrics = []
            for index, calculate_lr_img in enumerate(calculate_lr_imgs):
                calculate_hr_img = calculate_hr_imgs[index]
                size = calculate_lr_img.shape
                ypbpr = sc.rgb2ypbpr(calculate_lr_img / 255.0)
                # NOTE(review): ``scale`` is not a parameter; it is read from
                # the enclosing (module) scope — confirm it is defined there.
                x_scale = scipy.misc.imresize(
                    calculate_lr_img, [size[0] * scale, size[1] * scale],
                    interp='bicubic',
                    mode=None)
                y = np.expand_dims(ypbpr[..., 0], -1)
                # Chroma comes from the upscaled image so it already has the
                # output resolution.
                pbpr = sc.rgb2ypbpr(x_scale / 255)[..., 1:]

                out = sess.run(output_tensor, {y_image: [y], pbpr_image: [pbpr]})
                out = np.clip(out[0] * 255, 0, 255).astype(np.uint8)

                if save:
                    im = scipy.misc.toimage(out, high=255, low=0)
                    im.save(save_path + os.sep +
                            calculate_hr_img_list[index].split(
                                os.sep)[-1].replace('HR', 'SR'))
                out_ycbcr = sc.rgb2ycbcr(out)
                hr_ycbcr = sc.rgb2ycbcr(calculate_hr_img)
                metrics.append(
                    calculate_metrics([out_ycbcr[:, :, 0:1]],
                                      [hr_ycbcr[:, :, 0:1]]))
            avg_psnr = sum(m[0] for m in metrics) / len(metrics)
            avg_ssim = sum(m[1] for m in metrics) / len(metrics)

            return avg_psnr, avg_ssim
Example #2
0
    def __getitem__(self, index):
        """Build one (blurred, ground truth, PSF) sample.

        Loads the first entry that ``os.listdir`` returns for
        ``<folder><pathes[index]>/images/``, keeps its first three channels,
        blurs it with a kernel from ``make_kernel`` and, if ``self.noise`` is
        set, adds Gaussian noise (sigma 0.008) and clips to [0, 1]. Both
        images are converted to YCbCr with the Y plane passed through
        ``normalize_img``, optionally cropped to a 256x256 window, and
        returned as float tensors with a leading singleton axis.
        """
        images_dir = self.folder + self.pathes[index] + '/images/'
        first_file = os.listdir(images_dir)[0]
        ground_truth = img_as_float(imread(images_dir + first_file)[..., :3])

        kernel = make_kernel(kernels=self.kernels)
        degraded = blur(ground_truth.copy(), psf=kernel)
        if self.noise:
            noise = np.random.normal(scale=0.008, size=degraded.shape)
            degraded = (degraded + noise).clip(0, 1)

        kernel_tensor = torch.FloatTensor(kernel)[None, ...]

        def _normalized_ycbcr(rgb):
            converted = rgb2ycbcr(rgb)
            converted[..., 0] = normalize_img(converted[..., 0])
            return converted

        degraded_ycbcr = _normalized_ycbcr(degraded)
        truth_ycbcr = _normalized_ycbcr(ground_truth)

        if self.crop:
            top, left = self.get_crop_indexes(degraded_ycbcr)
            window = (slice(top, top + 256), slice(left, left + 256),
                      slice(None))
            degraded_ycbcr = degraded_ycbcr[window]
            truth_ycbcr = truth_ycbcr[window]

        return (torch.FloatTensor(degraded_ycbcr)[None, ...],
                torch.FloatTensor(truth_ycbcr)[None, ...],
                kernel_tensor)
Example #3
0
def validation(img_new, name, save_imgs=False, save_dir=None):
    """Upscale ``img_new`` with ``upscale_net`` and optionally save the pair.

    The input is clamped to [0, 1] and quantized to 8-bit levels before it
    is fed to the network. The first image of the batch is turned into an
    HWC uint8 array for both the original and the reconstruction. When
    ``save_imgs`` and ``save_dir`` are set, they are written as
    ``<name>_orig.png`` / ``<name>_recon.png``. Finally the Y plane of each
    is taken with ``rgb2ycbcr`` and ``SCALE`` border pixels are cropped.

    NOTE(review): the cropped Y planes are computed but never returned; the
    function may be truncated — confirm what callers expect.
    """
    upscale_net.eval()

    clamped = torch.clamp(img_new, 0, 1)
    with torch.no_grad():
        # Quantize to 8-bit levels, as a stored input image would be.
        quantized = torch.round(clamped * 255)
        reconstructed_img = upscale_net(quantized / 255.0)

    # Use the clamped input: np.uint8 does not saturate, so values outside
    # [0, 1] would wrap around instead of clipping.
    img = (clamped * 255).data.cpu().numpy().transpose(0, 2, 3, 1)
    img = np.uint8(img)

    reconstructed_img = torch.clamp(reconstructed_img, 0, 1) * 255
    reconstructed_img = reconstructed_img.data.cpu().numpy().transpose(
        0, 2, 3, 1)
    reconstructed_img = np.uint8(reconstructed_img)

    orig_img = img[0, ...].squeeze()
    recon_img = reconstructed_img[0, ...].squeeze()

    if save_imgs and save_dir:
        Image.fromarray(orig_img).save(
            os.path.join(save_dir, name + '_orig.png'))
        Image.fromarray(recon_img).save(
            os.path.join(save_dir, name + '_recon.png'))

    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0]
    orig_img_y = orig_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
    recon_img_y = recon_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
Example #4
0
def test(epoch):
    """Evaluate ``model`` on ``eval_data_loader`` and log the mean Y-PSNR.

    For each batch, the model is run on the whole input batch. The first
    target of the batch and the clamped, squeezed prediction are converted to
    the Y channel of ``skimage.color.rgb2ycbcr`` and compared with a data
    range of 255. The mean PSNR is written to ``writer`` as ``avg_psnr`` at
    ``global_step=epoch`` and printed.

    :param epoch: step value used for the TensorBoard scalar.
    :raises ValueError: if ``eval_data_loader`` is empty.
    """
    from skimage import color
    try:
        from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    except ImportError:  # scikit-image < 0.16
        from skimage.measure import compare_psnr

    num_batches = len(eval_data_loader)
    if num_batches == 0:
        raise ValueError('eval_data_loader is empty')

    avg_psnr = 0
    for batch in eval_data_loader:
        # ``Variable`` is a no-op wrapper in modern PyTorch; plain tensors suffice.
        input_ = batch[0].cuda(gpu_lists[0])
        target = batch[1][0].cuda(gpu_lists[0])

        with torch.no_grad():
            prediction = model(input_)
        target = color.rgb2ycbcr(
            target.cpu().numpy().transpose(1, 2, 0))[..., 0]
        prediction = color.rgb2ycbcr(prediction.cpu().squeeze().clamp(
            0, 1).numpy().transpose(1, 2, 0))[..., 0]
        avg_psnr += compare_psnr(target, prediction, data_range=255.)
    avg_psnr /= num_batches
    writer.add_scalar('avg_psnr', avg_psnr, global_step=epoch)
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
Example #5
0
def single_channel(img, method):
    """Reduce an RGB image to a single channel.

    :param img: PIL image (anything ``np.asarray`` accepts)
    :param method: ``'cb'`` for Cb, ``'cr'`` for inverted Cr (max minus Cr),
        ``'cbcr'`` for Cr plus inverted Cb, ``'sat'`` for HSV saturation;
        any other value selects the HSV value channel
    :return: single channel image
    """
    pixels = np.asarray(img)

    if method == 'sat':
        return color.rgb2hsv(pixels)[:, :, 1]
    if method not in ('cb', 'cr', 'cbcr'):
        return color.rgb2hsv(pixels)[:, :, 2]

    ycbcr = color.rgb2ycbcr(pixels)
    cb_plane = ycbcr[:, :, 1]
    cr_plane = ycbcr[:, :, 2]
    if method == 'cb':
        return cb_plane
    if method == 'cr':
        return np.max(cr_plane) - cr_plane
    return cr_plane + (np.max(cb_plane) - cb_plane)
Example #6
0
def PreProcess(img, params):
    """Build the HR/LR luminance pair and LR filter features for one image.

    ``img`` is degraded with ``ImgDownSample(img, params.blur_kernel,
    params.sr_scale)``. For 3-channel input, both images are reduced to the
    Y plane of ``rgb2ycbcr``. Otherwise each is used as a 2-D float plane,
    taking channel 0 if it is 3-D. The LR plane is resized by
    ``params.lr_feat_scale`` and convolved ("same" mode) with every filter in
    ``params.lr_filters``.

    :return: ``(img, img_ds, img_hr, img_lr, feat_lr)``; ``feat_lr`` has one
        channel per entry of ``params.lr_filters``.
    """
    img_ds = ImgDownSample(img, params.blur_kernel, params.sr_scale)

    def _plane(arr):
        # 2-D float plane; a 3-D non-RGB image contributes its first channel.
        arr = np.asarray(arr, dtype=float)
        return arr[:, :, 0] if arr.ndim == 3 else arr

    if np.ndim(img_ds) == 3 and np.shape(img_ds)[2] == 3:
        # ``float(array)`` only accepts size-1 arrays; cast element-wise.
        img_hr = np.asarray(rgb2ycbcr(img)[:, :, 0], dtype=float)
        img_lr = np.asarray(rgb2ycbcr(img_ds)[:, :, 0], dtype=float)
    else:
        img_hr = _plane(img)
        # The LR plane must come from the downsampled image, not ``img``.
        img_lr = _plane(img_ds)

    img_rs = resize(img_lr, params.lr_feat_scale)
    nr, nc = np.shape(img_rs)[:2]
    feat_lr = np.zeros((nr, nc, len(params.lr_filters)))

    for i, lr_filter in enumerate(params.lr_filters):
        feat_lr[:, :, i] = signal.convolve2d(img_rs, lr_filter, "same")

    return img, img_ds, img_hr, img_lr, feat_lr
Example #7
0
def valid():
    """Evaluate ``model`` on ``testing_data_loader`` and print PSNR/SSIM.

    Outputs and ground truth are converted with ``utils.tensor2np`` and
    shaved by ``args.scale`` pixels. When ``args.isY is True``, the metrics
    use the quantized Y plane of ``sc.rgb2ycbcr``; otherwise they use the
    full images. Per-batch and mean values are printed.
    """
    model.eval()
    psnr_sum = ssim_sum = 0
    total = len(testing_data_loader)
    for step, batch in enumerate(testing_data_loader):
        lr_tensor, hr_tensor = batch[0], batch[1]
        if args.cuda:
            lr_tensor, hr_tensor = lr_tensor.to(device), hr_tensor.to(device)

        with torch.no_grad():
            pre = model(lr_tensor)

        sr_np = utils.tensor2np(pre.detach()[0])
        gt_np = utils.tensor2np(hr_tensor.detach()[0])
        sr_np, gt_np = (utils.shave(sr_np, args.scale),
                        utils.shave(gt_np, args.scale))

        if args.isY is not True:
            im_label, im_pre = gt_np, sr_np
        else:
            im_label = utils.quantize(sc.rgb2ycbcr(gt_np)[:, :, 0])
            im_pre = utils.quantize(sc.rgb2ycbcr(sr_np)[:, :, 0])

        psnr = utils.compute_psnr(im_pre, im_label)
        ssim = utils.compute_ssim(im_pre, im_label)
        psnr_sum += psnr
        ssim_sum += ssim
        print(
            f" Valid {step}/{len(testing_data_loader)} with PSNR = {psnr} and SSIM = {ssim}"
        )
    print("===> Valid. psnr: {:.4f}, ssim: {:.4f}".format(
        psnr_sum / total, ssim_sum / total))
Example #8
0
    def __getitem__(self, i):
        """Return ``{'GT': ..., 'LR': ...}`` as CHW float32 arrays divided by 255.

        The GT image is read from disk (lazy mode) or taken from the
        in-memory list, which is reshuffled whenever ``i`` wraps to slot 0.
        It then passes through ``self.transform`` (if any), and
        ``self.LR_transform`` produces the LR image. When ``self.rgb`` is
        false, both are first scaled to [0, 1] and reduced to the Y plane of
        ``rgb2ycbcr``; the final range depends on that helper's output scale.
        """
        slot = i % self.real_size
        if self.lazy_load:
            gt_file = os.path.join(self.GT_path, self.GT_img[slot])
            GT = np.array(Image.open(gt_file).convert('RGB'))
        else:
            if slot == 0:
                # Reshuffle once per pass over the in-memory images.
                random.shuffle(self.GT_img)
            GT = self.GT_img[slot].astype(np.float32)

        if self.transform is not None:
            for step in self.transform:
                GT = step(GT)

        LR = self.LR_transform(GT)
        if not self.rgb:
            GT = rgb2ycbcr(GT.astype(np.float32) / 255.)[:, :, :1]
            LR = rgb2ycbcr(LR.astype(np.float32) / 255.)[:, :, :1]

        return {
            'GT': GT.transpose(2, 0, 1).astype(np.float32) / 255.,
            'LR': LR.transpose(2, 0, 1).astype(np.float32) / 255.,
        }
Example #9
0
def eval_psnr_and_ssim(im1, im2, scale):
    """Compute PSNR and SSIM between two images on their luminance.

    Each image is converted independently. Single-channel input is used as
    is; colour input is reduced to the Y plane of ``rgb2ycbcr`` divided by
    255. Both stay (H, W, 1) so the ``multichannel=True`` SSIM call treats
    the last axis as channels rather than the image width. For
    ``scale > 1``, both are mod-cropped and ``int(scale)`` border pixels are
    removed.

    :return: ``(psnr, ssim)`` with a data range of 1.0.
    """
    def _luma(im):
        im_t = np.atleast_3d(img_as_float(im))
        if im_t.shape[2] == 1:
            return im_t
        return rgb2ycbcr(im_t)[:, :, 0:1] / 255.0

    im1_t = _luma(im1)
    im2_t = _luma(im2)

    if scale > 1:
        im1_t = mod_crop(im1_t, scale)
        im2_t = mod_crop(im2_t, scale)

        # NOTE: only ``scale`` border pixels are cropped here; some benchmarks
        # (e.g. EDSR) crop scale + 6 instead.
        im1_t = crop_boundaries(im1_t, int(scale))
        im2_t = crop_boundaries(im2_t, int(scale))

    psnr_val = compare_psnr(im1_t, im2_t, data_range=1.0)
    ssim_val = compare_ssim(im1_t,
                            im2_t,
                            win_size=11,
                            gaussian_weights=True,
                            multichannel=True,
                            data_range=1.0,
                            K1=0.01,
                            K2=0.03,
                            sigma=1.5)

    return psnr_val, ssim_val
Example #10
0
    def valid(self):
        """Score the network and an upscaled-input baseline on HR test images.

        For every ``*_HR<image_form>`` file in
        ``<data_file_path>/image_SRF_<scale>``, the image is cropped to a
        multiple of ``scale`` (label) and downscaled by ``scale`` (network
        input). The input is also re-upscaled to label size as the baseline
        (commented "bicubic" originally; it uses the LANCZOS filter). Network
        output and baseline are compared with the label on the Y plane of
        ``sc.rgb2ycbcr`` through ``calculate_metrics``. Non-RGB files are
        skipped.

        :return: ``([avg_psnr, avg_ssim] of the network,
            [avg_psnr, avg_ssim] of the baseline)``.
        :raises ValueError: if no RGB HR image was found.
        """
        scale = self.scale
        data_files = glob.glob(
            os.path.join(
                self.data_file_path + os.sep + 'image_SRF_{}'.format(scale),
                '*_HR{}'.format(self.args.image_form)))

        # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
        resample = Image.LANCZOS
        metrics_1, metrics_2 = [], []
        for data, data_file in enumerate(data_files):
            # The context manager also closes images skipped via ``continue``.
            with Image.open(data_file) as img:
                if img.mode != 'RGB':
                    continue
                (width, height) = img.size
                (width_in, height_in) = width // scale, height // scale
                (width_lb, height_lb) = width_in * scale, height_in * scale
                label_ = img.resize((width_lb, height_lb), resample)  # as label
                input_ = img.resize((width_in, height_in), resample)  # as input
            valid_ = input_.resize((width_lb, height_lb), resample)  # as bicubic

            # np.asarray on an RGB PIL image yields (H, W, 3) directly.
            label_ = np.asarray(label_, dtype=np.float32) / 255
            input_ = np.asarray(input_, dtype=np.float32) / 255
            valid_ = np.asarray(valid_, dtype=np.float32) / 255

            feed_input = input_[np.newaxis, :]
            click = time.time()
            feed_dict = {self.net.input_: feed_input}
            output = self.sess.run(self.net.output, feed_dict=feed_dict)[0]
            print('Process image with shape: {:d} x {:d}, take time: {:.3f}s'.
                  format(height_in, width_in,
                         time.time() - click))
            if self.args.save:
                array_image_save(
                    output * 255,
                    os.path.join(self.out_dir,
                                 '{}_{}.png'.format(self.args.dataset, data)))
            output = np.clip((output * 255), 0, 255).astype(np.uint8)
            valid_ = np.clip((valid_ * 255), 0, 255).astype(np.uint8)
            label_ = np.clip((label_ * 255), 0, 255).astype(np.uint8)

            output_ycbcr = sc.rgb2ycbcr(output)
            hr_ycbcr = sc.rgb2ycbcr(label_)
            valid_ycbcr = sc.rgb2ycbcr(valid_)
            metrics_1.append(
                calculate_metrics([output_ycbcr[:, :, 0:1]],
                                  [hr_ycbcr[:, :, 0:1]]))
            metrics_2.append(
                calculate_metrics([valid_ycbcr[:, :, 0:1]],
                                  [hr_ycbcr[:, :, 0:1]]))

        if not metrics_1:
            raise ValueError(
                'no RGB HR images found for scale {}'.format(scale))
        # Averages are computed once, after all images have been scored.
        avg_psnr1 = sum(m[0] for m in metrics_1) / len(metrics_1)
        avg_ssim1 = sum(m[1] for m in metrics_1) / len(metrics_1)
        avg_psnr2 = sum(m[0] for m in metrics_2) / len(metrics_2)
        avg_ssim2 = sum(m[1] for m in metrics_2) / len(metrics_2)
        return [avg_psnr1, avg_ssim1], [avg_psnr2, avg_ssim2]
def gen_hdf5(label_path, input_path, outfile_name):
    """Cut aligned Y-channel patches from image pairs and write them to HDF5.

    For every file in ``label_path``, the same-named file in ``input_path``
    is loaded and resized to the label's size. Both are converted with
    ``color.rgb2ycbcr`` and divided by 255. ``size_input`` x ``size_input``
    patches of the Y plane are taken on a ``stride`` grid, shuffled, and
    written as datasets ``data`` and ``label`` of shape
    (count, 1, size, size) to ``<stem>.h5``. ``<stem>`` is ``outfile_name``
    without its extension.

    :param label_path: directory of ground-truth images.
    :param input_path: directory of degraded inputs with matching names.
    :param outfile_name: output name; its extension is replaced by '.h5'.
    """
    import glob

    stride = 36
    size_input = 41
    size_label = 41
    chunk_rows = 64
    data_patches = []
    label_patches = []

    ### load info
    # sorted(glob) replaces ``os.popen('ls ' + path)``: same listing without
    # passing an unescaped path to the shell.
    for label_name in sorted(glob.glob(os.path.join(label_path, '*'))):
        img_name = os.path.basename(label_name)
        input_name = os.path.join(input_path, img_name)
        try:
            img_org = io.imread(label_name)
            img_input = io.imread(input_name)
            img_input = img_input[:, :, :3]
        except Exception as err:  # unreadable or non-colour file: skip it
            print("err read:" + input_name + " (%s)" % err)
            continue
        if len(img_org.shape) != 3 or len(img_input.shape) != 3:
            print("err size:" + img_name)
            continue
        img_org = img_org[:, :, :3]  # drop alpha so rgb2ycbcr sees RGB
        h, w = img_org.shape[:2]
        img_label = color.rgb2ycbcr(img_org).astype(float) / 255.0
        img_input = cv2.resize(img_input, (w, h))
        img_input = color.rgb2ycbcr(img_input).astype(float) / 255.0
        im_label = img_label[:, :, 0]
        im_input = img_input[:, :, 0]

        # Same sampling grid as before: a margin of size_input + 1 at the
        # top/left and 2 * size_input + 1 at the bottom/right.
        for x in range(size_input + 1, h - 2 * size_input - 1, stride):
            for y in range(size_input + 1, w - 2 * size_input - 1, stride):
                # copy() so the full image planes can be freed.
                data_patches.append(
                    im_input[x:x + size_input, y:y + size_input].copy())
                label_patches.append(
                    im_label[x:x + size_label, y:y + size_label].copy())

    count = len(data_patches)
    print("count %s" % count)
    print("END LOOP IMAGE")
    order = random.sample(range(count), count)

    # Stack only the patches actually collected instead of preallocating
    # 5,000,000 slots (~67 GB per array).
    data = np.asarray(data_patches, dtype=float).reshape(
        count, 1, size_input, size_input)[order]
    label = np.asarray(label_patches, dtype=float).reshape(
        count, 1, size_label, size_label)[order]
    del data_patches, label_patches

    ## write
    setname = os.path.splitext(outfile_name)[0]
    h5_filename = '{}.h5'.format(setname)
    # h5py rejects chunks larger than a fixed-size dataset.
    rows = min(chunk_rows, count)
    data_chunks = (rows, 1, size_input, size_input) if count else None
    label_chunks = (rows, 1, size_label, size_label) if count else None
    with h5py.File(h5_filename, 'w') as h:
        h.create_dataset('data', data=data, chunks=data_chunks)
        h.create_dataset('label', data=label, chunks=label_chunks)
Example #12
0
def test(args):
    """Run the generator on the test set and write per-image PSNR.

    Outputs (clipped to [-1, 1]) and ground truth are mapped from [-1, 1] to
    [0, 1]. They are compared on the Y plane of ``rgb2ycbcr`` after shaving
    ``args.scale`` pixels. Per-image and mean PSNR go to ``./result.txt``,
    and each SR output is saved as ``output/res_<i>.png``.

    :param args: namespace providing ``GT_path``, ``LR_path``,
        ``num_workers``, ``res_num``, ``generator_path`` and ``scale``.
    """
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mydata(GT_path=args.GT_path,
                     LR_path=args.LR_path,
                     in_memory=False,
                     transform=None)
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=args.num_workers)

    generator = Generator(img_feat=3,
                          n_feats=64,
                          kernel_size=3,
                          num_block=args.res_num)
    generator.load_state_dict(
        torch.load(args.generator_path, map_location=device))
    generator = generator.to(device)
    generator.eval()

    os.makedirs('output', exist_ok=True)
    scale = args.scale
    psnr_list = []

    # ``with`` guarantees result.txt is flushed and closed.
    with open('./result.txt', 'w') as f, torch.no_grad():
        for i, te_data in enumerate(loader):
            gt = te_data['GT'].to(device)
            lr = te_data['LR'].to(device)

            _, _, h, w = lr.size()
            gt = gt[:, :, :h * scale, :w * scale]

            output, _ = generator(lr)

            output = np.clip(output[0].cpu().numpy(), -1.0, 1.0)
            gt = gt[0].cpu().numpy()

            output = ((output + 1.0) / 2.0).transpose(1, 2, 0)
            gt = ((gt + 1.0) / 2.0).transpose(1, 2, 0)

            y_output = rgb2ycbcr(output)[scale:-scale, scale:-scale, :1]
            y_gt = rgb2ycbcr(gt)[scale:-scale, scale:-scale, :1]

            psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range=1.0)
            psnr_list.append(psnr)
            f.write('psnr : %04f \n' % psnr)

            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('output/res_%04d.png' % i)

        f.write('avg psnr : %04f' % np.mean(psnr_list))
Example #13
0
def set_channel(img_in, img_tar, n_channel):
    """Match the channel count of an (input, target) pair to ``n_channel``.

    2-D images are first given a trailing channel axis. Then an RGB pair
    requested as 1 channel becomes its Y plane (H x W x 1, via
    ``sc.rgb2ycbcr``), and a single-channel pair requested as 3 channels is
    replicated along the last axis. Other combinations are returned
    unchanged.

    :return: ``(img_in, img_tar)``.
    """
    # Accept plain 2-D grayscale arrays, which previously failed to unpack.
    if np.ndim(img_in) == 2:
        img_in = np.expand_dims(img_in, 2)
    if np.ndim(img_tar) == 2:
        img_tar = np.expand_dims(img_tar, 2)
    c = np.shape(img_tar)[2]
    if n_channel == 1 and c == 3:
        img_in = np.expand_dims(sc.rgb2ycbcr(img_in)[:, :, 0], 2)
        img_tar = np.expand_dims(sc.rgb2ycbcr(img_tar)[:, :, 0], 2)
    elif n_channel == 3 and c == 1:
        img_in = np.concatenate([img_in] * n_channel, 2)
        img_tar = np.concatenate([img_tar] * n_channel, 2)

    return img_in, img_tar
Example #14
0
def setChannel(imgIn, imgTar, nChannel):
    """Match the channel count of an (input, target) pair to ``nChannel``.

    ``imgTar`` must be 3-D (H, W, C). An RGB pair requested as 1 channel
    becomes its Y plane (H x W x 1, via ``sc.rgb2ycbcr``). A single-channel
    pair requested as 3 channels is replicated along the last axis. Any
    other combination is returned unchanged.
    """
    _, _, channels = imgTar.shape
    if nChannel == 3 and channels == 1:
        return (np.concatenate([imgIn] * nChannel, 2),
                np.concatenate([imgTar] * nChannel, 2))
    if nChannel == 1 and channels == 3:
        return (np.expand_dims(sc.rgb2ycbcr(imgIn)[:, :, 0], 2),
                np.expand_dims(sc.rgb2ycbcr(imgTar)[:, :, 0], 2))
    return imgIn, imgTar
Example #15
0
def validation(img, name, save_imgs=False, save_dir=None):
    """Run the learned downscale -> upscale pipeline on ``img``.

    ``kernel_generation_net`` predicts kernels and offsets, and
    ``downsampler_net`` applies them (with ``OFFSET_UNIT``). The result is
    clamped and quantized to 8-bit levels, and ``upscale_net``
    reconstructs the image. The first image of the batch is converted to
    HWC uint8 for the original, the downscaled and the reconstructed
    versions. They are saved as ``<name>_{orig,down,recon}.png`` when
    ``save_imgs`` and ``save_dir`` are set. The Y planes of the original and
    the reconstruction are then cropped by ``SCALE`` pixels.

    NOTE(review): the cropped Y planes are not returned (the metric code is
    commented out in the source this came from) — confirm intended output.
    """
    kernel_generation_net.eval()
    downsampler_net.eval()
    upscale_net.eval()

    with torch.no_grad():
        kernels, offsets_h, offsets_v = kernel_generation_net(img)
        downscaled_img = downsampler_net(img, kernels, offsets_h, offsets_v,
                                         OFFSET_UNIT)
        # Quantize to 8-bit levels, as a stored low-resolution image would be.
        downscaled_img = torch.round(torch.clamp(downscaled_img, 0, 1) * 255)
        reconstructed_img = upscale_net(downscaled_img / 255.0)

    def _first_hwc_uint8(tensor):
        array = np.uint8(tensor.data.cpu().numpy().transpose(0, 2, 3, 1))
        return array[0, ...].squeeze()

    # Clamp before the cast: np.uint8 does not saturate out-of-range floats.
    orig_img = _first_hwc_uint8(torch.clamp(img, 0, 1) * 255)
    downscaled_img = _first_hwc_uint8(downscaled_img)
    recon_img = _first_hwc_uint8(torch.clamp(reconstructed_img, 0, 1) * 255)

    if save_imgs and save_dir:
        Image.fromarray(orig_img).save(
            os.path.join(save_dir, name + "_orig.png"))
        Image.fromarray(downscaled_img).save(
            os.path.join(save_dir, name + "_down.png"))
        Image.fromarray(recon_img).save(
            os.path.join(save_dir, name + "_recon.png"))

    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0]
    orig_img_y = orig_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
    recon_img_y = recon_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
Example #16
0
def set_channel1(img, n_channels=1):
    """Convert ``img`` to ``n_channels`` channels, keeping the chroma planes.

    A 2-D image gets a trailing channel axis first. An RGB image requested as
    1 channel is split into Y, Cb and Cr planes (each H x W x 1) with a
    single ``sc.rgb2ycbcr`` call. A 1-channel image requested as 3 channels
    is replicated along the last axis.

    :return: ``(img, cb, cr)``; ``cb`` and ``cr`` are ``None`` unless the
        RGB -> Y conversion happened (previously this raised
        UnboundLocalError).
    """
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    c = img.shape[2]
    cb = cr = None
    if n_channels == 1 and c == 3:
        ycbcr = sc.rgb2ycbcr(img)
        cb = ycbcr[:, :, 1:2]
        cr = ycbcr[:, :, 2:3]
        img = ycbcr[:, :, 0:1]
    elif n_channels == 3 and c == 1:
        img = np.concatenate([img] * n_channels, 2)
    return img, cb, cr
Example #17
0
 def test_yuv(self):
     """Pin rgb2{yuv,yiq,ypbpr,ycbcr} for a white and a pure-green pixel."""
     white = np.array([[[1.0, 1.0, 1.0]]])
     for convert, expected in ((rgb2yuv, [1, 0, 0]),
                               (rgb2yiq, [1, 0, 0]),
                               (rgb2ypbpr, [1, 0, 0]),
                               (rgb2ycbcr, [235, 128, 128])):
         assert_array_almost_equal(convert(white), np.array([[expected]]))

     green = np.array([[[0.0, 1.0, 0.0]]])
     for convert, expected in (
             (rgb2yuv, [0.587, -0.28886916, -0.51496512]),
             (rgb2yiq, [0.587, -0.27455667, -0.52273617]),
             (rgb2ypbpr, [0.587, -0.331264, -0.418688]),
             (rgb2ycbcr, [144.553, 53.797, 34.214])):
         assert_array_almost_equal(convert(green), np.array([[expected]]))
Example #18
0
 def test_yuv(self):
     """Check the colour-space converters on white and pure green."""
     def check(pixel, yuv, yiq, ypbpr, ycbcr):
         rgb = np.array([[pixel]])
         assert_array_almost_equal(rgb2yuv(rgb), np.array([[yuv]]))
         assert_array_almost_equal(rgb2yiq(rgb), np.array([[yiq]]))
         assert_array_almost_equal(rgb2ypbpr(rgb), np.array([[ypbpr]]))
         assert_array_almost_equal(rgb2ycbcr(rgb), np.array([[ycbcr]]))

     check([1.0, 1.0, 1.0],
           yuv=[1, 0, 0], yiq=[1, 0, 0], ypbpr=[1, 0, 0],
           ycbcr=[235, 128, 128])
     check([0.0, 1.0, 0.0],
           yuv=[0.587, -0.28886916, -0.51496512],
           yiq=[0.587, -0.27455667, -0.52273617],
           ypbpr=[0.587, -0.331264, -0.418688],
           ycbcr=[144.553, 53.797, 34.214])
Example #19
0
def calc_test_ssim(imGT, imSR, scale):
    """SSIM between ground truth and SR output on Y, after border shaving.

    Inputs with more than one channel are reduced to the Y plane of
    ``sc.rgb2ycbcr``. Both images are shaved by ``[scale, scale]`` before
    ``ssim`` is computed.
    """
    def _luma(im):
        is_color = len(im.shape) > 2 and im.shape[2] > 1
        return sc.rgb2ycbcr(im)[..., 0] if is_color else im

    gt_y = _luma(imGT)
    sr_y = _luma(imSR)
    return ssim(shave(gt_y, [scale, scale]), shave(sr_y, [scale, scale]))
Example #20
0
def FolderTo4DLF(path, ext, length):
    """Load an N x N grid of light-field views into a 5-D YCbCr array.

    All ``*.<ext>`` files in ``path`` are read with ``io.ImageCollection``.
    Their count must be a perfect square N**2, and view (i, j) is file
    ``j + i * N``. The centred ``length`` x ``length`` sub-grid is converted
    with ``color.rgb2ycbcr``; grayscale views are replicated to RGB first.

    :param path: folder containing the views.
    :param ext: file extension without the dot.
    :param length: side of the angular sub-grid to keep (must be <= N).
    :return: float32 array (length, length, height, width, 3). Y is divided
        by 255; Cb/Cr are left in ``rgb2ycbcr``'s units.
    :raises IOError: if no file matches.
    :raises ValueError: if the file count is not a square, the first view
        is not 2-D/3-D, or ``length`` exceeds N.
    """
    path_str = path + '/*.' + ext
    log.info('-' * 40)
    log.info('Loading %s files from %s' % (ext, path))
    img_data = io.ImageCollection(path_str)
    if len(img_data) == 0:
        raise IOError('No .%s file in this folder' % ext)
    N = int(sqrt(len(img_data)))
    if N ** 2 != len(img_data):
        raise ValueError('This folder does not have n^2 images!')
    first_shape = img_data[0].shape
    if len(first_shape) == 3:
        [height, width, channel] = first_shape
        print(channel, ' Channels, RGB input.')
    elif len(first_shape) == 2:
        [height, width] = first_shape
        print('1 Channels, Grayscale input.')
    else:
        raise ValueError('Not 1 or 3 channels')

    lf_shape = (N, N, height, width, 3)
    log.info('Initial LF shape: ' + str(lf_shape))
    # Floor division so length > N yields a negative border (int() truncated
    # -0.5 to 0). When N - length is odd, the extra view is dropped from the
    # far side.
    border = (N - length) // 2
    if border < 0:
        raise ValueError('Border {0} < 0'.format(border))
    out_lf_shape = (length, length, height, width, 3)
    log.info('Output LF shape: ' + str(out_lf_shape))
    lf = np.zeros(out_lf_shape, dtype=np.float32)
    # Iterate exactly ``length`` views per axis so an odd N - length cannot
    # overrun ``lf``.
    for i in range(length):
        for j in range(length):
            indx = (j + border) + (i + border) * N
            view = np.uint8(img_data[indx])
            if view.ndim == 3 and view.shape[2] >= 3:
                rgb_im = view[:, :, :3]  # drop alpha if present
            else:
                gray_img = view if view.ndim == 2 else view[:, :, 0]
                rgb_im = np.stack([gray_img, gray_img, gray_img], axis=2)
            im = color.rgb2ycbcr(rgb_im)

            lf[i, j, :, :, 0] = im[:, :, 0] / 255.0
            lf[i, j, :, :, 1:3] = im[:, :, 1:3]
    log.info('LF Range:')
    # The colour channel is the last axis (the old code indexed height).
    for c in range(3):
        plane = lf[..., c]
        log.info('Channel %d [%.2f %.2f]' % (c + 1, plane.max(), plane.min()))
    log.info('--------------------')
    return lf
Example #21
0
def predict(images, session=None, network=None, targets=None, border=0):
    """Super-resolve the luminance of each image with a TensorFlow network.

    Colour images are converted with ``color.rgb2ycbcr``; only Y (divided
    by 255) is fed to ``network.output``, and the prediction is recombined
    with the original Cb/Cr. Grayscale images are processed directly.

    :param images: iterable of HxW or HxWx3 arrays.
    :param session: TF session to use; a private one is created (and
        closed) when omitted.
    :param network: model with ``input``/``output`` tensors; loaded with
        ``load_model`` when omitted.
    :param targets: optional references (indexed like ``images``) for PSNR
        on Y.
    :param border: pixels excluded on each side for PSNR; 0 means no crop.
    :return: list of uint8 predictions, plus a list of PSNR values when
        ``targets`` is given.
    """
    session_passed = session is not None
    if not session_passed:
        session = tf.Session()

    # A plain [0:-0] slice is empty, so border == 0 must mean "no crop".
    crop = slice(border, -border) if border > 0 else slice(None)

    predictions = []
    if targets is not None:
        psnr = []

    try:
        if network is None:
            network = load_model(session)

        for i, image in enumerate(images):
            if len(image.shape) == 3:
                image_ycbcr = color.rgb2ycbcr(image)
                image_y = image_ycbcr[:, :, 0]
            else:
                image_y = image.copy()

            # ``np.float`` was removed from NumPy; the builtin is equivalent.
            image_y = image_y.astype(float) / 255
            reshaped_image_y = np.array([np.expand_dims(image_y, axis=2)])
            prediction = network.output.eval(
                feed_dict={network.input: reshaped_image_y},
                session=session)[0]
            prediction *= 255

            if targets is not None:
                if len(targets[i].shape) == 3:
                    target_y = color.rgb2ycbcr(targets[i])[:, :, 0]
                else:
                    target_y = targets[i].copy()

                psnr.append(utils.psnr(prediction[crop, crop, 0],
                                       target_y[crop, crop], maximum=255.0))

            if len(image.shape) == 3:
                prediction = color.ycbcr2rgb(np.concatenate(
                    (prediction, image_ycbcr[:, :, 1:3]), axis=2)) * 255
            else:
                prediction = prediction[:, :, 0]

            prediction = np.clip(prediction, 0, 255).astype(np.uint8)
            predictions.append(prediction)
    finally:
        # Close a privately created session even if inference fails.
        if not session_passed:
            session.close()

    if targets is not None:
        return predictions, psnr
    else:
        return predictions
Example #22
0
    def __getitem__(self, index):
        """Return ``(LR_image, label)`` Y-plane tensors for sample ``index``.

        Reads ``<root><index + 1>.png`` from both ``label_root`` and
        ``data_root``. Each is converted with ``color.rgb2ycbcr``, its Y
        plane is divided by 255, and a leading channel axis is added.
        """
        file_name = str(index + 1) + '.png'
        hr_rgb = io.imread(self.label_root + file_name)
        lr_rgb = io.imread(self.data_root + file_name)

        def _y_tensor(rgb):
            luma = color.rgb2ycbcr(rgb)[:, :, 0] / 255
            return torch.FloatTensor(luma).unsqueeze(0)

        label = _y_tensor(hr_rgb)
        LR_image = _y_tensor(lr_rgb)
        return LR_image, label
Example #23
0
def main():
    """Evaluate a restored IDN checkpoint on ``dataset`` and print mean PSNR.

    HR images (``*.bmp`` for Set5, ``*.png`` otherwise) under
    ``data/<dataset>`` are processed in three steps:

    - mod-cropped by ``scale``;
    - degraded with ``utils.downsample_fn`` and upsampled with
      ``utils.upsample_fn``;
    - super-resolved by the model.

    The quantized, shaved output is scored against the HR image, on Y
    unless ``rgb`` is set, and saved under ``<saved_path>/<dataset>/x<scale>``.
    Relies on module-level ``dataset``, ``scale``, ``rgb``, ``model_path``
    and ``saved_path``.

    :raises IOError: if no HR image is found.
    """
    ## data
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    ext = '*.bmp' if dataset == 'Set5' else '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))
    if not hr_paths:
        # Fail before building the model instead of dividing by zero at the end.
        raise IOError('no %s images found in %s' % (ext, test_hr_path))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')
    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    ## result
    save_path = os.path.join(saved_path, dataset + '/x' + str(scale))
    os.makedirs(save_path, exist_ok=True)

    psnr_score = 0
    # ``with`` releases the session's resources when evaluation finishes.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().restore(sess, model_path)

        for i, hr_path in enumerate(hr_paths):
            print('processing image %d' % (i + 1))
            img_hr = utils.modcrop(misc.imread(hr_path), scale)
            img_lr = utils.downsample_fn(img_hr, scale=scale)
            img_b = utils.upsample_fn(img_lr, scale=scale)
            [lr, b] = utils.datatype([img_lr, img_b])
            lr = lr[np.newaxis, :, :, :]
            b = b[np.newaxis, :, :, :]
            [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
            sr = utils.quantize(np.squeeze(sr))
            img_sr = utils.shave(sr, scale)
            img_hr = utils.shave(img_hr, scale)
            if not rgb:
                img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
                img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
            else:
                img_pre = img_sr
                img_label = img_hr
            psnr_score += utils.compute_psnr(img_pre, img_label)
            misc.imsave(os.path.join(save_path, os.path.basename(hr_path)), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')
Example #24
0
def val(net1, epoch, i, num_images=5):
    """Evaluate ``net1`` on the first ``num_images`` bicubic test pairs.

    Pairs are read from ``./data/test_data/bicubic/X<factor>/`` as
    ``img_NNN_SRF_<factor>_{HR,LR}.png``. The network is called on the Y
    plane of ``color.rgb2ycbcr`` divided by 255, shaped (1, 1, H, W), and
    PSNR/SSIM are computed in 8-bit units. When ``i % 500 == 0``, the last
    image's output Y is recombined with its HR chroma and saved under
    ``test_output/``.

    :param net1: the network under evaluation.
    :param epoch: current epoch, used to build the saved file index.
    :param i: current batch index.
    :param num_images: number of test pairs to use (default 5).
    :return: ``(loss, mean_psnr, mean_ssim)``. ``loss`` is the per-pixel
        squared error of the last image.
    :raises ValueError: if ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError('num_images must be at least 1')
    test_dir = './data/test_data/bicubic/X%d/' % opt.scaling_factor
    with torch.no_grad():
        psnr_ac = 0
        ssim_ac = 0
        for j in range(num_images):
            # %03d matches 'img_00' + str(j + 1) for j < 9 and continues with
            # img_010, img_011, ... instead of img_0010.
            stem = test_dir + 'img_%03d_SRF_%d' % (j + 1, opt.scaling_factor)
            label = io.imread(stem + '_HR.png')
            test = io.imread(stem + '_LR.png')

            label_ycbcr = color.rgb2ycbcr(label)
            test_ycbcr = color.rgb2ycbcr(test)
            label_y = label_ycbcr[:, :, 0] / 255
            test_y = test_ycbcr[:, :, 0] / 255

            label_cb = label_ycbcr[:, :, 1]
            label_cr = label_ycbcr[:, :, 2]

            label = torch.FloatTensor(label_y).unsqueeze(0).unsqueeze(0).cuda()
            test = torch.FloatTensor(test_y).unsqueeze(0).unsqueeze(0).cuda()

            output = torch.clamp(net1(test), 0.0, 1.0)
            loss = (output * 255 - label * 255).pow(2).sum() / (
                output.shape[2] * output.shape[3])
            mse = loss.item()
            # Identical images have zero error: report infinite PSNR instead
            # of raising ZeroDivisionError.
            psnr = float('inf') if mse == 0 else 10 * np.log10(255 * 255 / mse)

            output = output.squeeze(0).squeeze(0).cpu()
            label = label.squeeze(0).squeeze(0).cpu()

            output_array = np.array(output * 255).astype(np.float32)
            label_array = np.array(label * 255).astype(np.float32)
            ssim = measure.compare_ssim(output_array, label_array, data_range=255)

            psnr_ac += psnr
            ssim_ac += ssim

        # every 500 batches save test output
        if i % 500 == 0:
            # synthesize SR image from the predicted Y and the HR chroma
            SR_image = np.zeros([*label_array.shape, 3])
            SR_image[:, :, 0] = output_array
            SR_image[:, :, 1] = label_cb
            SR_image[:, :, 2] = label_cr
            save_index = str(int(epoch * (opt.num_data / opt.batch_size / 500) + (i + 1) / 500))
            SR_image = color.ycbcr2rgb(SR_image) * 255
            SR_image = np.clip(SR_image, a_min=0., a_max=255.)
            SR_image = SR_image.astype(np.uint8)
            io.imsave(test_dir + 'test_output/' + save_index + '.png', SR_image)

    return loss, psnr_ac / num_images, ssim_ac / num_images
Example #25
0
def compute_y_psnr(img_gt_rgb, img_out_rgb):
    """Return the PSNR of two RGB images measured on their luma (Y) channel.

    Both inputs must be float images in [-1, 1]. Singleton axes are squeezed
    away and each image is mapped to [0, 255], rounded and cast to uint8.
    It is then converted with ``color.rgb2ycbcr``, and its Y plane is clipped
    to [0, 255] and rounded before ``compute_psnr`` is called with a peak of 255.

    NOTE(review): the third argument given to ``compute_psnr`` is a peak
    value; confirm that the ``compute_psnr`` in scope for this snippet takes a
    peak there and not, e.g., a border width.
    """
    def _to_luma(img):
        # [-1, 1] float -> [0, 255] whole-valued pixels -> Y plane
        pixels = np.clip((np.squeeze(img) + 1.) / 2. * 255., 0, 255).round()
        luma = color.rgb2ycbcr(pixels.astype('uint8'))[:, :, 0]
        return np.clip(luma, 0, 255).round()

    return compute_psnr(_to_luma(img_gt_rgb), _to_luma(img_out_rgb), 255)
Example #26
0
def compute_psnr(im1, im2, shave_border):
    """Return ``(psnr, ssim)`` between two images on a single channel.

    3-D inputs are reduced to channel 0 of ``rgb2ycbcr`` (the Y/luma plane).
    Both images are then passed through ``shave(img, shave_border)``, cast to
    float32 and clipped to [0, 1] before the metrics are computed.

    Because the values are explicitly clipped to [0, 1], both metrics are told
    that the data range is 1.0. Without this, skimage's SSIM infers the range
    from the float dtype as ``1 - (-1) = 2``, which under-states the
    structural error.

    NOTE(review): the [0, 1] clip assumes the ``rgb2ycbcr`` in scope returns
    Y in [0, 1] for these inputs; with a [16, 235]-scaled Y every pixel would
    clip to 1. Confirm against the helper actually imported here.

    Args:
        im1: Reference image, 2-D (single channel) or 3-D (RGB).
        im2: Test image, same layout as ``im1``.
        shave_border: Forwarded unchanged to ``shave`` for both images.

    Returns:
        Tuple ``(psnr, ssim)``.
    """
    if len(im1.shape) == 3:
        im1 = rgb2ycbcr(im1)[:, :, 0]
    if len(im2.shape) == 3:
        im2 = rgb2ycbcr(im2)[:, :, 0]

    im1 = shave(im1, shave_border)
    im2 = shave(im2, shave_border)
    im1 = np.clip(im1.astype(np.float32), 0.0, 1.0)
    im2 = np.clip(im2.astype(np.float32), 0.0, 1.0)

    # Values are in [0, 1] by construction; state that range explicitly.
    psnr = compare_psnr(im1, im2, data_range=1.0)
    ssim = compare_ssim(im1, im2, data_range=1.0)

    return psnr, ssim
Example #27
0
 def prepare_image(self, batches):
     """Yield ``(lo_res, hi_res)`` luma arrays, one pair per batch.

     ``batch[0]`` holds the low-resolution RGB images and ``batch[1]`` the
     matching high-resolution ones; they are paired with ``zip``, so any
     surplus images in the longer sequence are ignored. Each image is cast to
     uint8 and converted with ``color.rgb2ycbcr``, and only its Y channel is
     kept. For each batch the Y planes are stacked with ``np.asarray``, given
     a new axis at position 1, cast to float32 and divided by 255.
     """
     def to_luma(img):
         # Y (luma) plane of one RGB image, after casting it to uint8.
         return color.rgb2ycbcr(img.astype('uint8'))[..., 0]

     def stack(planes):
         # Stack planes, insert a channel axis at 1, float32 and / 255.
         return np.expand_dims(np.asarray(planes), 1).astype('float32') / 255.

     for batch in batches:
         lo_planes, hi_planes = [], []
         for lo_img, hi_img in zip(batch[0], batch[1]):
             lo_planes.append(to_luma(lo_img))
             hi_planes.append(to_luma(hi_img))
         yield stack(lo_planes), stack(hi_planes)
Example #28
0
def calc_test_psnr(imGT, imSR, scale):
    """Return the PSNR between a ground-truth and a super-resolved image.

    Each image is prepared the same way:

    * an input with more than two dimensions and more than one channel is
      reduced to channel 0 of ``sc.rgb2ycbcr`` (the Y/luma plane);
    * ``shave`` is called with ``[scale, scale]`` (presumably cropping a
      border of that width — confirm against the helper);
    * values are divided by 255.

    The two prepared images are then passed to ``psnr``.
    """
    def _prepare(img):
        if len(img.shape) > 2 and img.shape[2] > 1:
            img = sc.rgb2ycbcr(img)[..., 0]
        img = shave(img, [scale, scale])
        return img / 255.0

    return psnr(_prepare(imGT), _prepare(imSR))
Example #29
0
File: main.py Project: qynan/DASAA
def test(epoch):
    """Evaluate ``model`` on ``eval_data_loader`` and log the averaged metrics.

    Each batch's prediction is compared with its target using three metrics,
    each averaged over the number of batches:

    * MSE  -- ``nn.MSELoss`` on the raw tensors;
    * SSIM -- ``pytorch_ssim.ssim`` on the raw tensors (unitless);
    * PSNR -- ``compare_psnr`` with ``data_range=255`` on channel 0 of
      ``skimage.color.rgb2ycbcr`` (the Y/luma plane). The prediction is
      clamped to [0, 1] before conversion.

    The averages are written to ``writer`` as ``avg_ssim``, ``avg_psnr`` and
    ``avg_mse`` at ``global_step=epoch`` and printed to stdout. If the loader
    is empty, a message is printed and nothing is logged.

    Relies on the module-level names ``model``, ``eval_data_loader``,
    ``writer``, ``nn`` and ``pytorch_ssim``. Leaves ``model`` in eval mode.

    Args:
        epoch: Step value used for the logged scalars.
    """
    from skimage.measure import compare_psnr
    from skimage import color

    model.eval()

    num_batches = len(eval_data_loader)
    if num_batches == 0:
        # Averaging over zero batches would raise ZeroDivisionError.
        print("===> Evaluation skipped: eval_data_loader is empty.")
        return

    mse = nn.MSELoss().cuda()  # stateless criterion: build it once
    total_psnr = total_mse = total_ssim = 0.0

    with torch.no_grad():  # inference only, no autograd graph needed
        for batch in eval_data_loader:
            lr_input = batch[0].cuda()
            target = batch[1].cuda()

            prediction = model(lr_input)

            # Accumulate Python floats instead of GPU tensors.
            total_mse += float(mse(prediction, target))
            total_ssim += float(pytorch_ssim.ssim(target, prediction))

            # NOTE(review): ``.squeeze()`` drops the batch axis and
            # ``transpose(1, 2, 0)`` needs a 3-D CHW array, so this assumes
            # an evaluation batch size of 1 -- confirm.
            target_y = color.rgb2ycbcr(
                target.cpu().squeeze().numpy().transpose(1, 2, 0))[..., 0]
            prediction_y = color.rgb2ycbcr(
                prediction.cpu().squeeze().clamp(0, 1).numpy()
                .transpose(1, 2, 0))[..., 0]
            total_psnr += compare_psnr(target_y, prediction_y, 255.)

    avg_ssim = total_ssim / num_batches
    avg_psnr = total_psnr / num_batches
    avg_mse = total_mse / num_batches

    writer.add_scalar('avg_ssim', avg_ssim, global_step=epoch)
    writer.add_scalar('avg_psnr', avg_psnr, global_step=epoch)
    writer.add_scalar('avg_mse', avg_mse, global_step=epoch)

    # SSIM is a unitless index; only PSNR is expressed in dB.
    print("===> Avg. SSIM: {:.4f}".format(avg_ssim))
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
    print("===> Avg. MSE: {:.4f}".format(avg_mse))
Example #30
0
 def test_yuv_roundtrip(self):
     """RGB -> colour space -> RGB must reproduce the (subsampled) input.

     Checks the YUV, YIQ, YPbPr, YCbCr and YDbDr conversion pairs, in that
     order, on ``self.img_rgb`` converted to float and subsampled by 16.
     """
     img_rgb = img_as_float(self.img_rgb)[::16, ::16]
     conversion_pairs = [
         (rgb2yuv, yuv2rgb),
         (rgb2yiq, yiq2rgb),
         (rgb2ypbpr, ypbpr2rgb),
         (rgb2ycbcr, ycbcr2rgb),
         (rgb2ydbdr, ydbdr2rgb),
     ]
     for forward, inverse in conversion_pairs:
         assert_array_almost_equal(inverse(forward(img_rgb)), img_rgb)
 def luminance(self, image):
     """Return channel 0 of ``rgb2ycbcr(image)`` with 4 pixels cropped per edge.

     The crop removes 4 rows from the top and bottom and 4 columns from the
     left and right. The values are returned with whatever dtype
     ``rgb2ycbcr`` produced.
     """
     border = 4
     luma = rgb2ycbcr(image)[:, :, 0]
     height, width = luma.shape[0], luma.shape[1]
     return luma[border:height - border, border:width - border]
 def test_yuv_roundtrip(self):
     """Each RGB <-> colour-space pair must round-trip the sample image.

     The sample is ``self.img_rgb`` converted to float and subsampled by 16.
     Pairs are checked in the order YUV, YIQ, YPbPr, YCbCr, YDbDr.

     NOTE(review): a method with this exact name and body appears a few lines
     earlier in this file; if both are in the same class, this definition
     replaces that one.
     """
     sample = img_as_float(self.img_rgb)[::16, ::16]
     for to_space, from_space in ((rgb2yuv, yuv2rgb),
                                  (rgb2yiq, yiq2rgb),
                                  (rgb2ypbpr, ypbpr2rgb),
                                  (rgb2ycbcr, ycbcr2rgb),
                                  (rgb2ydbdr, ydbdr2rgb)):
         restored = from_space(to_space(sample))
         assert_array_almost_equal(restored, sample)
Example #33
0
def rgb2ycbcr(num_random=124, seed=None):
    """Write RGB -> YCbCr regression fixtures for ``image-color-calc``.

    The RGB triples are pure red, pure green, pure blue, ``(12, 56, 43)`` and
    ``num_random`` random triples. Each triple is converted with
    ``color.rgb2ycbcr`` and two files are written in the current directory:

    * ``input``: one ``rgb2ycbcr[i]="echo R,G,B | image-color-calc ..."``
      command line per triple;
    * ``expected``: the expected ``output`` line (the RGB triple followed by
      the rounded YCbCr triple) and a ``status=0`` line per triple.

    Args:
        num_random: Number of random triples appended after the fixed ones.
            The default of 124 matches the previous hard-coded count.
        seed: If given, random triples come from a private
            ``random.Random(seed)`` so the fixtures are reproducible. If None,
            the module-level ``random`` generator is used as before.
    """
    name = "rgb2ycbcr"
    inputs = [(255, 0, 0),
              (0, 255, 0),
              (0, 0, 255),
              (12, 56, 43)]
    rng = random if seed is None else random.Random(seed)
    # ``range`` (not the Python-2-only ``xrange``) works on Python 2 and 3.
    for _ in range(num_random):
        inputs.append((rng.randint(0, 255), rng.randint(0, 255),
                       rng.randint(0, 255)))
    with open("input", "wt") as fi, open("expected", "wt") as fe:
        for counter, rgb in enumerate(inputs):
            pixel_rgb = np.array([[rgb]], dtype=np.uint8)
            pixel_ycbcr = color.rgb2ycbcr(pixel_rgb)
            t = (name, counter,) + rgb
            fi.write('%s[%d]="echo %d,%d,%d | image-color-calc convert --from rgb,ub --to ycbcr,ub"\n' % t)
            t = (name, counter,) + rgb + tuple(pixel_ycbcr[0][0].round().tolist()) + (name, counter,)
            fe.write('%s[%d]/output="%d,%d,%d,%d,%d,%d"\n%s[%d]/status=0\n' % t)
Example #34
0
def denoise_wavelet(image, sigma=None, wavelet='db1', mode='soft',
                    wavelet_levels=None, multichannel=False,
                    convert2ycbcr=False, method='BayesShrink'):
    """Perform wavelet denoising on an image.

    Parameters
    ----------
    image : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    sigma : float or list, optional
        The noise standard deviation used when computing the wavelet detail
        coefficient threshold(s). When None (default), the noise standard
        deviation is estimated via the method in [2]_.
    wavelet : string, optional
        The type of wavelet to perform and can be any of the options
        ``pywt.wavelist`` outputs. The default is `'db1'`. For example,
        ``wavelet`` can be any of ``{'db2', 'haar', 'sym9'}`` and many more.
    mode : {'soft', 'hard'}, optional
        An optional argument to choose the type of denoising performed. Note
        that choosing soft thresholding given additive noise finds the
        best approximation of the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use.  The default is
        three less than the maximum number of possible decomposition levels.
    multichannel : bool, optional
        Apply wavelet denoising separately for each channel (where channels
        correspond to the final axis of the array).
    convert2ycbcr : bool, optional
        If True and multichannel True, do the wavelet denoising in the YCbCr
        colorspace instead of the RGB color space. This typically results in
        better performance for RGB images.
    method : {'BayesShrink', 'VisuShrink'}, optional
        Thresholding method to be used. The currently supported methods are
        "BayesShrink" [1]_ and "VisuShrink" [2]_. Defaults to "BayesShrink".

    Returns
    -------
    out : ndarray
        Denoised image.

    Raises
    ------
    ValueError
        If `method` is not one of the supported methods.

    Notes
    -----
    The wavelet domain is a sparse representation of the image, and can be
    thought of similarly to the frequency domain of the Fourier transform.
    Sparse representations have most values zero or near-zero and truly random
    noise is (usually) represented by many small values in the wavelet domain.
    Setting all values below some threshold to 0 reduces the noise in the
    image, but larger thresholds also decrease the detail present in the image.

    If the input is 3D, this function performs wavelet denoising on each color
    plane separately. The output image is clipped to either [-1, 1] or
    [0, 1] depending on the input image range.

    When YCbCr conversion is done, every color channel is scaled between 0
    and 1, and `sigma` values are applied to these scaled color channels.
    A YCbCr channel that is constant (e.g. for a solid-color image) has
    nothing to denoise and cannot be rescaled, so it is left unchanged.

    Many wavelet coefficient thresholding approaches have been proposed. By
    default, ``denoise_wavelet`` applies BayesShrink, which is an adaptive
    thresholding method that computes separate thresholds for each wavelet
    sub-band as described in [1]_.

    If ``method == "VisuShrink"``, a single "universal threshold" is applied to
    all wavelet detail coefficients as described in [2]_. This threshold
    is designed to remove all Gaussian noise at a given ``sigma`` with high
    probability, but tends to produce images that appear overly smooth.

    Although any of the wavelets from ``PyWavelets`` can be selected, the
    thresholding methods assume an orthogonal wavelet transform and may not
    choose the threshold appropriately for biorthogonal wavelets. Orthogonal
    wavelets are desirable because white noise in the input remains white noise
    in the subbands. Biorthogonal wavelets lead to colored noise in the
    subbands. Additionally, the orthogonal wavelets in PyWavelets are
    orthonormal so that noise variance in the subbands remains identical to the
    noise variance of the input. Example orthogonal wavelets are the Daubechies
    (e.g. 'db2') or symmlet (e.g. 'sym2') families.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image Processing,
           IEEE Transactions on 9.9 (2000): 1532-1546.
           :DOI:`10.1109/83.862633`
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           :DOI:`10.1093/biomet/81.3.425`

    Examples
    --------
    >>> from skimage import color, data
    >>> img = img_as_float(data.astronaut())
    >>> img = color.rgb2gray(img)
    >>> img += 0.1 * np.random.randn(*img.shape)
    >>> img = np.clip(img, 0, 1)
    >>> denoised_img = denoise_wavelet(img, sigma=0.1)

    """
    if method not in ["BayesShrink", "VisuShrink"]:
        raise ValueError(
            ('Invalid method: {}. The currently supported methods are '
             '"BayesShrink" and "VisuShrink"').format(method))

    image = img_as_float(image)

    if multichannel:
        # A scalar (or None) sigma is broadcast to one value per channel.
        if isinstance(sigma, numbers.Number) or sigma is None:
            sigma = [sigma] * image.shape[-1]

        if convert2ycbcr:
            out = color.rgb2ycbcr(image)
            for i in range(3):
                ch_min, ch_max = out[..., i].min(), out[..., i].max()
                ch_range = ch_max - ch_min
                if ch_range == 0:
                    # A constant channel has nothing to denoise; normalizing
                    # it would divide by zero and fill the output with NaN.
                    continue
                # Renormalize this color channel to live in [0, 1].
                channel = (out[..., i] - ch_min) / ch_range
                out[..., i] = denoise_wavelet(channel, wavelet=wavelet,
                                              method=method, sigma=sigma[i],
                                              mode=mode,
                                              wavelet_levels=wavelet_levels)
                # Map the denoised channel back to its original range.
                out[..., i] = out[..., i] * ch_range + ch_min
            out = color.ycbcr2rgb(out)
        else:
            out = np.empty_like(image)
            for c in range(image.shape[-1]):
                out[..., c] = _wavelet_threshold(image[..., c],
                                                 wavelet=wavelet,
                                                 method=method,
                                                 sigma=sigma[c], mode=mode,
                                                 wavelet_levels=wavelet_levels)
    else:
        out = _wavelet_threshold(image, wavelet=wavelet, method=method,
                                 sigma=sigma, mode=mode,
                                 wavelet_levels=wavelet_levels)

    clip_range = (-1, 1) if image.min() < 0 else (0, 1)
    return np.clip(out, *clip_range)
Example #35
0
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
                    wavelet_levels=None, multichannel=False,
                    convert2ycbcr=False):
    """Perform wavelet denoising on an image.

    Parameters
    ----------
    img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
        Input data to be denoised. `img` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    sigma : float or list, optional
        The noise standard deviation used when computing the threshold
        adaptively as described in [1]_ for each color channel. When None
        (default), the noise standard deviation is estimated via the method in
        [2]_.
    wavelet : string, optional
        The type of wavelet to perform and can be any of the options
        ``pywt.wavelist`` outputs. The default is `'db1'`. For example,
        ``wavelet`` can be any of ``{'db2', 'haar', 'sym9'}`` and many more.
    mode : {'soft', 'hard'}, optional
        An optional argument to choose the type of denoising performed. Note
        that choosing soft thresholding given additive noise finds the
        best approximation of the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use.  The default is
        three less than the maximum number of possible decomposition levels.
        Honoured in every code path, including the YCbCr path.
    multichannel : bool, optional
        Apply wavelet denoising separately for each channel (where channels
        correspond to the final axis of the array).
    convert2ycbcr : bool, optional
        If True and multichannel True, do the wavelet denoising in the YCbCr
        colorspace instead of the RGB color space. This typically results in
        better performance for RGB images.

    Returns
    -------
    out : ndarray
        Denoised image.

    Notes
    -----
    The wavelet domain is a sparse representation of the image, and can be
    thought of similarly to the frequency domain of the Fourier transform.
    Sparse representations have most values zero or near-zero and truly random
    noise is (usually) represented by many small values in the wavelet domain.
    Setting all values below some threshold to 0 reduces the noise in the
    image, but larger thresholds also decrease the detail present in the image.

    If the input is 3D, this function performs wavelet denoising on each color
    plane separately. The output image is clipped to either [-1, 1] or
    [0, 1] depending on the input image range.

    When YCbCr conversion is done, every color channel is scaled between 0
    and 1, and `sigma` values are applied to these scaled color channels.
    A YCbCr channel that is constant (e.g. for a solid-color image) has
    nothing to denoise and cannot be rescaled, so it is left unchanged.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image Processing,
           IEEE Transactions on 9.9 (2000): 1532-1546.
           DOI: 10.1109/83.862633
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           DOI: 10.1093/biomet/81.3.425

    Examples
    --------
    >>> from skimage import color, data
    >>> img = img_as_float(data.astronaut())
    >>> img = color.rgb2gray(img)
    >>> img += 0.1 * np.random.randn(*img.shape)
    >>> img = np.clip(img, 0, 1)
    >>> denoised_img = denoise_wavelet(img, sigma=0.1)

    """
    img = img_as_float(img)

    if multichannel:
        # A scalar (or None) sigma is broadcast to one value per channel.
        if isinstance(sigma, numbers.Number) or sigma is None:
            sigma = [sigma] * img.shape[-1]

        if convert2ycbcr:
            out = color.rgb2ycbcr(img)
            for i in range(3):
                ch_min, ch_max = out[..., i].min(), out[..., i].max()
                ch_range = ch_max - ch_min
                if ch_range == 0:
                    # A constant channel has nothing to denoise; normalizing
                    # it would divide by zero and fill the output with NaN.
                    continue
                # Renormalize this color channel to live in [0, 1].
                channel = (out[..., i] - ch_min) / ch_range
                # Forward wavelet_levels so the documented parameter also
                # applies in the YCbCr path.
                out[..., i] = denoise_wavelet(channel, sigma=sigma[i],
                                              wavelet=wavelet, mode=mode,
                                              wavelet_levels=wavelet_levels)
                # Map the denoised channel back to its original range.
                out[..., i] = out[..., i] * ch_range + ch_min
            out = color.ycbcr2rgb(out)
        else:
            out = np.empty_like(img)
            for c in range(img.shape[-1]):
                out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
                                                 mode=mode, sigma=sigma[c],
                                                 wavelet_levels=wavelet_levels)

    else:
        out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
                                 sigma=sigma,
                                 wavelet_levels=wavelet_levels)

    clip_range = (-1, 1) if img.min() < 0 else (0, 1)
    return np.clip(out, *clip_range)