예제 #1
0
    def __getitem__(self, index):
        """Load one full-size evaluation sample.

        Reads the input/target pair listed at ``self.test_path[index]`` with
        OpenCV, converts both from BGR to RGB, crops the target with
        ``utils.modcrop(target, self.sr_factor)`` and converts the arrays to
        tensors with ``utils.np2tensor(..., self.rgb_range)``.

        Returns:
            dict with ``'name'`` (``path['save']``), ``'input'`` and
            ``'target'``. When ``self.rgb`` is false, ``'input'`` holds only
            channel 0 of ``utils.rgb2ycbcr(input)`` (as an H x W x 1 array
            before tensor conversion) and the dict also carries
            ``'input_cbcr'`` (channels 1 onward) and ``'input_rgb'`` (the RGB
            input). When ``self.target_down`` is true it also carries
            ``'target_down'``, the target resized by ``1 / self.sr_factor``.

        Raises:
            OSError: if OpenCV cannot read the input or target file.
        """
        path = self.test_path[index]
        images = {'name': path['save']}

        input_ = cv2.imread(path['input'])
        if input_ is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('cv2.imread could not read {}'.format(path['input']))
        input_ = cv2.cvtColor(input_, cv2.COLOR_BGR2RGB)

        target_ = cv2.imread(path['target'])
        if target_ is None:
            raise OSError('cv2.imread could not read {}'.format(path['target']))
        target_ = utils.modcrop(target_, self.sr_factor)
        target_ = cv2.cvtColor(target_, cv2.COLOR_BGR2RGB)

        if not self.rgb:
            # RGB copy of the input, kept alongside the single-channel version.
            input_out = utils.np2tensor(np.copy(input_), self.rgb_range)
            input_ = utils.rgb2ycbcr(input_)
            input_cbcr = input_[:, :, 1:]
            input_ = np.expand_dims(input_[:, :, 0], 2)
            input_cbcr = utils.np2tensor(input_cbcr, self.rgb_range)
            images.update({'input_cbcr': input_cbcr, 'input_rgb': input_out})

        if self.target_down:
            target_down = imresize(target_, scalar_scale=1 / self.sr_factor)
            images['target_down'] = utils.np2tensor(target_down,
                                                    self.rgb_range)

        images['input'] = utils.np2tensor(input_, self.rgb_range)
        images['target'] = utils.np2tensor(target_, self.rgb_range)
        return images
예제 #2
0
    def __getitem__(self, index):
        """Load one training sample and cut an aligned input/target patch.

        The pair listed at ``self.train_path[index]`` is read with ``np.load``
        when ``self.npy_reader`` is set. Otherwise it is read with OpenCV and
        converted from BGR to RGB. The target is cropped with
        ``utils.modcrop(target, self.sr_factor)``, and ``get_patch`` cuts a
        patch pair using ``self.patch_size`` and ``self.sr_factor``. When
        ``self.rgb`` is false, only channel 0 of ``utils.rgb2ycbcr`` is kept,
        as an H x W x 1 array.

        Returns:
            dict with ``'input'`` and ``'target'`` tensors, plus
            ``'target_down'`` (the target patch resized by
            ``1 / self.sr_factor``) when ``self.target_down`` is true.

        Raises:
            OSError: if OpenCV cannot read the input or target file.
        """
        path = self.train_path[index]
        images = {}
        if self.npy_reader:
            # NOTE(review): .npy arrays are used as stored (no BGR->RGB swap);
            # presumably they were saved in RGB order already - confirm.
            input_ = np.load(path['input'], allow_pickle=False)

            target_ = np.load(path['target'], allow_pickle=False)
            target_ = utils.modcrop(target_, self.sr_factor)
        else:
            input_ = cv2.imread(path['input'])
            if input_ is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError(
                    'cv2.imread could not read {}'.format(path['input']))
            input_ = cv2.cvtColor(input_, cv2.COLOR_BGR2RGB)

            target_ = cv2.imread(path['target'])
            if target_ is None:
                raise OSError(
                    'cv2.imread could not read {}'.format(path['target']))
            target_ = utils.modcrop(target_, self.sr_factor)
            target_ = cv2.cvtColor(target_, cv2.COLOR_BGR2RGB)

        subim_in, subim_tar = get_patch(input_, target_, self.patch_size,
                                        self.sr_factor)

        if not self.rgb:
            # Keep only channel 0 of the YCbCr conversion, as H x W x 1.
            subim_in = utils.rgb2ycbcr(subim_in)
            subim_tar = utils.rgb2ycbcr(subim_tar)
            subim_in = np.expand_dims(subim_in[:, :, 0], 2)
            subim_tar = np.expand_dims(subim_tar[:, :, 0], 2)

        if self.target_down:
            subim_target_down = imresize(subim_tar,
                                         scalar_scale=1 / self.sr_factor)
            images['target_down'] = utils.np2tensor(subim_target_down,
                                                    self.rgb_range)

        images['input'] = utils.np2tensor(subim_in, self.rgb_range)
        images['target'] = utils.np2tensor(subim_tar, self.rgb_range)
        return images
예제 #3
0
    def load_images(self, images):
        """Given a list of file names, return a list of images.

        Each file is read with OpenCV, converted from BGR to RGB as uint8,
        cropped with ``modcrop(img, 2)`` and then passed through
        ``dwt_shape``. The result keeps the order of *images*.

        Raises:
            OSError: if OpenCV cannot read one of the files.
        """
        out = []
        for image in images:
            bgr = cv2.imread(image)
            if bgr is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError('cv2.imread could not read {}'.format(image))
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)
            img = modcrop(img, 2)
            img = dwt_shape(img)
            out.append(img)

        return out
예제 #4
0
def main():
    """Evaluate a restored IDN checkpoint on one test set and report PSNR.

    Relies on module-level settings defined elsewhere: ``dataset``, ``scale``,
    ``model_path``, ``saved_path`` and ``rgb``. Each HR image in
    ``data/<dataset>`` (``*.bmp`` for Set5, ``*.png`` otherwise) is cropped
    with ``utils.modcrop``, downsampled, super-resolved by the network and
    saved under ``<saved_path>/<dataset>/x<scale>``. PSNR is computed on the
    ``utils.shave``-d images: channel 0 of ``sc.rgb2ycbcr`` when ``rgb`` is
    false, all channels otherwise.

    Raises:
        OSError: if no test images are found. This is checked before the
            model is built instead of failing with ZeroDivisionError at the end.
    """
    ## data
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    ext = '*.bmp' if dataset == 'Set5' else '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))
    if not hr_paths:
        raise OSError('no %s images found in %s' % (ext, test_hr_path))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')

    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, model_path)

    ## result
    save_path = os.path.join(saved_path, dataset, 'x' + str(scale))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    psnr_score = 0
    for i, hr_path in enumerate(hr_paths, start=1):
        print('processing image %d' % i)
        img_hr = utils.modcrop(misc.imread(hr_path), scale)
        # Network inputs: the LR image and its upsampled version `b`.
        img_lr = utils.downsample_fn(img_hr, scale=scale)
        img_b = utils.upsample_fn(img_lr, scale=scale)
        lr, b = utils.datatype([img_lr, img_b])
        # Add the batch axis expected by the [1, H, W, 3] placeholders.
        lr = lr[np.newaxis, :, :, :]
        b = b[np.newaxis, :, :, :]
        [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
        sr = utils.quantize(np.squeeze(sr))
        # Score on shaved copies; the full-size `sr` is what gets saved.
        img_sr = utils.shave(sr, scale)
        img_hr = utils.shave(img_hr, scale)
        if rgb:
            img_pre, img_label = img_sr, img_hr
        else:
            img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
            img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
        psnr_score += utils.compute_psnr(img_pre, img_label)
        misc.imsave(os.path.join(save_path, os.path.basename(hr_path)), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')
예제 #5
0
import cv2
import os
from matlab_imresize import imresize
from utils import modcrop, mkdir

origin_path = '/media/luo/data/data/super-resolution/VISTA/origin'
path = '/media/luo/data/data/super-resolution/VISTA'
# WebP quality levels and down-scaling factors to generate.
Q_list = [5, 10, 20, 30, 40, 50]
sr_factor = [2, 3, 4]

# For every image under origin_path and every factor i: crop the image to a
# multiple of i with modcrop, resize it by 1/i with imresize, and write one
# WebP file per quality level Q to <path>/x<i>/webp<Q>/<stem>.webp.
# NOTE(review): outputs are not split by sub-directory, so images with the
# same name in different sub-directories overwrite each other.
for root, _, names in os.walk(origin_path):
    for name in names:
        # os.walk also yields files from sub-directories, so each file lives
        # in `root`, not necessarily directly in origin_path.
        target_name = os.path.join(root, name)

        img_tar = cv2.imread(target_name)
        if img_tar is None:
            # cv2.imread returns None for unreadable or non-image files.
            print('skipping unreadable file: {}'.format(target_name))
            continue
        # splitext handles extensions of any length (e.g. '.jpeg', '.tif').
        save_name = os.path.splitext(name)[0] + '.webp'
        for i in sr_factor:
            img_tar_ = modcrop(img_tar, i)
            sr_path = os.path.join(path, 'x{}'.format(i))
            # 1.0 / i keeps this a float scale under Python 2 as well.
            img_down = imresize(img_tar_, scalar_scale=1.0 / i)
            for Q in Q_list:
                webp_path = os.path.join(sr_path, 'webp{}'.format(Q))
                if not os.path.exists(webp_path):
                    mkdir(webp_path)

                save = os.path.join(webp_path, save_name)
                cv2.imwrite(save, img_down, [cv2.IMWRITE_WEBP_QUALITY, Q])
예제 #6
0
            # SSIM op between the cropped output and target tensors; max_val=1.0
            # assumes both are scaled to [0, 1] (the images below are divided by 255).
            ssim_ = tf.image.ssim(output_crop_, target_crop_,
                                  max_val=1.0)  # ssim

            print("Computing PSNR/SSIM scores....")

            ssim_score = 0.0
            psnr_score = 0.0
            # NOTE(review): paths are built by plain string concatenation, so
            # test_data_dir and dataset presumably end with '/'; confirm.
            validation_images = os.listdir(test_data_dir + dataset)
            num_val_images = len(validation_images)

            for j in range(num_val_images):

                print("\rImage %d / %d" % (j + 1, num_val_images), end='')
                # NOTE(review): scipy.misc.imread / imresize were removed in
                # SciPy 1.2 / 1.3; this needs an older SciPy.
                image = misc.imread(test_data_dir + dataset +
                                    validation_images[j])
                # Crop so height and width suit the x4 down/up-sampling below.
                image = utils.modcrop(image, modulo=4)
                # Promote grayscale images to three identical channels.
                if len(image.shape) < 3:
                    image = image[..., np.newaxis]
                    image = np.concatenate([image] * 3, 2)

                # Bicubic baseline: downscale by 4, then upscale back by 4.
                image_bicubic = misc.imresize(image, 0.25, interp="bicubic")
                image_bicubic = misc.imresize(image_bicubic,
                                              4.0,
                                              interp="bicubic")

                # Add a batch axis and divide by 255 (assumes 8-bit images).
                # NOTE(review): the reshape to 3 channels fails for RGBA inputs.
                image_bicubic = np.reshape(image_bicubic, [
                    1, image_bicubic.shape[0], image_bicubic.shape[1], 3
                ]) / 255
                image_target = np.reshape(
                    image, [1, image.shape[0], image.shape[1], 3]) / 255
예제 #7
0
 def load_images(self, images):
     """Read every path in *images* as RGB uint8 and apply ``modcrop(..., 3)``.

     Returns a new list whose order matches *images*.
     """
     return [
         modcrop(misc.imread(path, mode='RGB').astype(np.uint8), 3)
         for path in images
     ]
예제 #8
0
    IMAGE_GT_FILE='./data/Set5/*.bmp'
    """

    # Checkpoints to evaluate and the upscaling factor they belong to.
    MODEL_FILE = [
        './mdl/weights_srnet_x2_52.p', './mdl/weights_srnet_x2_310.p'
    ]
    UP_SCALE = 2
    # Set to 1 to be consistent with SRCNN (presumably the border width that
    # is excluded when scoring - confirm where SHAVE is used).
    SHAVE = 1

    # load inputs
    # Ground-truth images, cropped with utils.modcrop(., UP_SCALE), as float32.
    im_gt = []
    files_gt = glob.glob(IMAGE_GT_FILE)
    for f in files_gt:
        #print 'loading', f
        im = np.array(Image.open(f))
        im = utils.modcrop(im, UP_SCALE).astype(np.float32)
        im_gt += [im]

    # Low-resolution inputs: read from IMAGE_FILE when it is set (only allowed
    # with exactly one ground-truth image), otherwise generated from each
    # ground truth.
    im_l = []
    if len(IMAGE_FILE) > 0:
        assert (len(im_gt) == 1)
        im_l = [np.array(Image.open(IMAGE_FILE)).astype(np.float32)]
    else:  # downscale from ground truth using MATLAB (via pymatbridge)
        try:
            from pymatbridge import Matlab
            mlab = Matlab()
            mlab.start()
            for im in im_gt:
                # Resize by 1/UP_SCALE with MATLAB's imresize; result is in `b`.
                mlab.set_variable('a', im)
                mlab.set_variable('s', 1.0 / UP_SCALE)
                mlab.run_code('b=imresize(a, s);')
예제 #9
0
# Evaluate the generator network on every file returned by utils.get_list.
filelist = utils.get_list(filepath, ext=ext)
# Per-image PSNR / SSIM buffers, one slot per file.
psnr_sr = np.zeros(len(filelist))
ssim_sr = np.zeros(len(filelist))

opt.is_train = False

model = networks.define_G(opt)
# Unwrap nn.DataParallel so the weights are loaded into the underlying module.
if isinstance(model, nn.DataParallel):
    model = model.module
model.load_state_dict(torch.load(opt.models), strict=True)
i = 0
for imname in filelist:
    if opt.isHR:
        # Ground truth with the same basename from test_hr_folder (presumably
        # ending in '/'), reordered BGR -> RGB and cropped to the scale factor.
        im_gt = cv2.imread(opt.test_hr_folder +
                           imname.split('/')[-1])[:, :, [2, 1, 0]]
        im_gt = utils.modcrop(im_gt, opt.upscale_factor)
    # Low-resolution input, reordered BGR -> RGB.
    im_l = cv2.imread(imname)[:, :, [2, 1, 0]]
    # NOTE(review): cv2.imread defaults to IMREAD_COLOR, which always returns
    # a 3-channel array, and the [:, :, [2, 1, 0]] indexing above would fail
    # on a 2-D array anyway - so the grayscale and >3-channel branches below
    # look unreachable. Confirm before relying on them.
    if len(im_l.shape) < 3:
        if opt.isHR:
            im_gt = im_gt[..., np.newaxis]
            im_gt = np.concatenate([im_gt] * 3, 2)
        im_l = im_l[..., np.newaxis]
        im_l = np.concatenate([im_l] * 3, 2)

    if im_l.shape[2] > 3:
        if opt.isHR:
            im_gt = im_gt[..., 0:3]
        im_l = im_l[..., 0:3]

    # Scale to [0, 1] (assumes 8-bit images) and reorder HWC -> CHW.
    im_input = im_l / 255.0
    im_input = np.transpose(im_input, (2, 0, 1))
예제 #10
0
def makeh5patches(args):
    """Build an HDF5 file of paired (input, target) training patches.

    Every sub-directory of ``args.input_folder`` contributes the images in its
    ``args.input_gen_folder`` / ``args.target_gen_folder`` folders (a random
    subset when ``args.random_N``). Inputs and targets are paired by position.

    Each pair is:
      * read with ``gf.pydicom_imread`` and cast to float32;
      * normalized with ``args.normalization_type``;
      * optionally downscale-augmented (``args.ds_augment``);
      * cropped with ``utils.modcrop`` by ``args.scale``;
      * passed through ``utils.interpolation_lr`` when ``args.scale != 1``;
      * optionally blurred/noised, then cut into patches.

    The patches may then be air-thresholded, rotation-augmented and shuffled.
    They are written to ``args.output_fname`` as the datasets ``'input'`` and
    ``'target'``. The layout is N x C x H x W when ``args.tensor_format ==
    'torch'`` and N x H x W x C otherwise.

    Args:
        args: parsed command-line namespace providing every option read below.

    The process exits with status 1, after printing a message, if no input
    images are found, if the input and target lists differ in length, or if
    an input/target pair differs in shape.
    """
    print('\n----------------------------------------')
    print('Command line arguements')
    print('----------------------------------------')
    for key, value in vars(args).items():
        print(key, ':', value)
    print('----------------------------------------')

    # Collect image paths per sub-directory ('/*/' makes glob match dirs only).
    all_dir_paths = sorted(glob.glob(args.input_folder + '/*/'))
    all_input_paths, all_target_paths = [], []

    # Per-image min/max before and after normalization, for the final summary.
    pre_norm_in_min, pre_norm_in_max = [], []
    pre_norm_tar_min, pre_norm_tar_max = [], []
    post_norm_in_min, post_norm_in_max = [], []
    post_norm_tar_min, post_norm_tar_max = [], []

    random_ind = None
    for dir_path in all_dir_paths:
        input_dir = os.path.join(dir_path, args.input_gen_folder)
        target_dir = os.path.join(dir_path, args.target_gen_folder)
        if args.random_N:
            # The same indices select from both folders (presumably to keep
            # input/target pairs aligned).
            random_ind = utils.get_sorted_random_ind(input_dir,
                                                     args.N_rand_imgs)
        all_input_paths.extend(utils.getimages4rmdir(input_dir, random_ind))
        all_target_paths.extend(utils.getimages4rmdir(target_dir, random_ind))

    print('\nTraining input image paths:')
    print(np.asarray(all_input_paths))
    print('\n\nTraining target image paths:')
    print(np.asarray(all_target_paths))

    if not all_input_paths:
        print("No input images found under", args.input_folder)
        print("Exiting the program")
        sys.exit(1)
    if len(all_input_paths) != len(all_target_paths):
        # Pairs are matched by position, so unequal counts would mis-pair them.
        print("MISMATCH in number of images: input", len(all_input_paths),
              "& target", len(all_target_paths))
        print("Exiting the program")
        sys.exit(1)

    # Directory for the optional sanity-check patch plots.
    sanity_chk_path = ('sanity_check/' + args.input_folder.split('/')[-1] +
                       '/norm_' + str(args.normalization_type) +
                       '_patch_size_' + str(args.patch_size))
    if not os.path.isdir(sanity_chk_path):
        os.makedirs(sanity_chk_path)

    # One blur/noise seed per image; only used when args.blurr_n_noise is set.
    if args.blurr_n_noise:
        seed = utils.bn_seed(len(all_input_paths), 0.4, 0.4)
    else:
        seed = [None] * len(all_input_paths)

    # Patch batches are collected in lists and concatenated once at the end.
    # Calling np.append in the loop re-copied the whole accumulated array on
    # every iteration (quadratic). The empty seed arrays keep the same patch
    # shape check and float64 result dtype as before.
    input_chunks = [np.empty([0, args.input_size, args.input_size, 1])]
    label_chunks = [np.empty([0, args.label_size, args.label_size, 1])]

    for i, (input_path, target_path) in enumerate(
            zip(all_input_paths, all_target_paths)):

        input_image = gf.pydicom_imread(input_path)
        target_image = gf.pydicom_imread(target_path)
        if input_image.shape != target_image.shape:
            print("MISMATCH in image size for input: ",
                  input_path.split('/')[-1], "& output: ",
                  target_path.split('/')[-1])
            print("Exiting the program")
            sys.exit(1)
        if i == 0:
            print('\nHere target images from training dataset is of type-', target_image.dtype,
                  '. And is assigned as-', (target_image.astype('float32')).dtype,
                  'before network training')
            print("\nFirst image pair (target : input) in the raw stack (i.e. before patching)"
                  " are of shapes {} : {}".format(target_image.shape, input_image.shape))

        target_image = target_image.astype('float32')
        input_image = input_image.astype('float32')
        if args.air_threshold:
            # Target before normalization, used only for air thresholding
            # (assumes img_pair_normalization returns new arrays).
            target_image_un = target_image

        pre_norm_in_min.append(np.min(input_image))
        pre_norm_in_max.append(np.max(input_image))
        pre_norm_tar_min.append(np.min(target_image))
        pre_norm_tar_max.append(np.max(target_image))

        # ------------------
        # Data normalization
        # ------------------
        input_image, target_image = utils.img_pair_normalization(
            input_image, target_image, args.normalization_type)
        post_norm_in_min.append(np.min(input_image))
        post_norm_in_max.append(np.max(input_image))
        post_norm_tar_min.append(np.min(target_image))
        post_norm_tar_max.append(np.max(target_image))

        # -----------------
        # Data Augmentation
        # -----------------
        if args.ds_augment:
            input_aug_images = utils.downsample_4r_augmentation(input_image)
            target_aug_images = utils.downsample_4r_augmentation(target_image)
            if args.air_threshold:
                target_un_aug_images = utils.downsample_4r_augmentation(
                    target_image_un)
            if i == 0:
                print("\nDownscale based data augmentation is PERFORMED")
                print("Also, each input-target image pair is downscaled by",
                      len(input_aug_images) - 1,
                      "different scaling factors due to downscale based augmentation")
        else:
            # Add a leading axis so the single image is iterated like a stack.
            h, w = input_image.shape
            input_aug_images = np.reshape(input_image, (1, h, w))
            target_aug_images = np.reshape(target_image, (1, h, w))
            if args.air_threshold:
                target_un_aug_images = np.reshape(target_image_un, (1, h, w))
            if i == 0:
                print("\nDownscale based data augmentation is NoT PERFORMED")

        # Degrade and patch every (augmented) image.
        for p in range(len(input_aug_images)):
            label_ = utils.modcrop(target_aug_images[p], args.scale)
            input_ = utils.modcrop(input_aug_images[p], args.scale)
            if args.air_threshold:
                # NOTE(review): unlike label_, this copy is not modcropped; if
                # the image size is not a multiple of args.scale its patches
                # may not line up with sub_label - confirm.
                un_label_ = target_un_aug_images[p]

            if args.scale != 1:
                input_ = utils.interpolation_lr(input_, args.scale)

            if args.blurr_n_noise:
                cinput_ = utils.add_blurr_n_noise(input_, seed[i])
            else:
                cinput_ = input_

            sub_input, sub_label = utils.overlap_based_sub_images(
                args, cinput_, label_)

            if args.air_threshold:
                _, sub_label_un = utils.overlap_based_sub_images(
                    args, cinput_, un_label_)
                sub_input, sub_label = utils.air_thresholding(
                    args, sub_input, sub_label, sub_label_un)

            if args.rot_augment:
                sub_input, sub_label = utils.rotation_based_augmentation(
                    args, sub_input, sub_label)
            input_chunks.append(sub_input)
            label_chunks.append(sub_label)

    sub_input_of_all_inputs = np.concatenate(input_chunks, axis=0)
    sub_label_of_all_labels = np.concatenate(label_chunks, axis=0)

    # --------------------------
    # Shuffling the patches
    # --------------------------
    if args.shuffle_patches:
        order = np.arange(len(sub_input_of_all_inputs))
        np.random.shuffle(order)
        sub_input_of_all_inputs = sub_input_of_all_inputs[order]
        sub_label_of_all_labels = sub_label_of_all_labels[order]

    # -----------------------------------------------------
    # Sanity check: plot windows of 12 consecutive patches
    # starting at 5 random positions
    # -----------------------------------------------------
    if args.sanity_plot_check:
        window = 12
        lr_N = len(sub_input_of_all_inputs)
        # NOTE(review): raises ValueError when there are fewer than
        # window + 5 patches.
        rand_num = random.sample(range(lr_N - window), 5)
        for s_ind in rand_num:
            e_ind = s_ind + window
            lr_out_path = os.path.join(
                sanity_chk_path,
                'lr_input_sub_img_rand_' + str(s_ind) + '.png')
            hr_out_path = os.path.join(
                sanity_chk_path,
                'hr_input_sub_img_rand_' + str(s_ind) + '.png')
            gf.multi2dplots(3,
                            4,
                            sub_input_of_all_inputs[s_ind:e_ind, :, :, 0],
                            0,
                            passed_fig_att={
                                "colorbar": False,
                                "figsize": [4, 4],
                                "out_path": lr_out_path
                            })
            gf.multi2dplots(3,
                            4,
                            sub_label_of_all_labels[s_ind:e_ind, :, :, 0],
                            0,
                            passed_fig_att={
                                "colorbar": False,
                                "figsize": [4 * args.scale, 4 * args.scale],
                                "out_path": hr_out_path
                            })

    # Data layout for the training framework:
    # torch expects [batch_size, channels, height, width]; any other
    # tensor_format (e.g. 'tf') keeps [batch_size, height, width, channels].
    if args.tensor_format == 'torch':
        sub_input_of_all_inputs = np.transpose(sub_input_of_all_inputs,
                                               (0, 3, 1, 2))
        sub_label_of_all_labels = np.transpose(sub_label_of_all_labels,
                                               (0, 3, 1, 2))

    # --------------------
    # creating h5 file
    # --------------------
    output_folder = os.path.dirname(args.output_fname)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    if output_folder and not os.path.isdir(output_folder):
        os.makedirs(output_folder)
    with h5py.File(args.output_fname, mode='w') as hf:
        hf.create_dataset('input', data=sub_input_of_all_inputs)
        hf.create_dataset('target', data=sub_label_of_all_labels)
    print("\nshape of the overall input  subimages: {}".format(
        sub_input_of_all_inputs.shape))
    print("shape of the overall target subimages: {}".format(
        sub_label_of_all_labels.shape))
    print("\nFinally, due to data normalization based on:",
          args.normalization_type)
    print("input image range changes from (%.4f, %.4f) to (%.4f, %.4f)" %
          (min(pre_norm_in_min), max(pre_norm_in_max), min(post_norm_in_min),
           max(post_norm_in_max)))
    print("target image range changes from (%.4f, %.4f) to (%.4f, %.4f)" %
          (min(pre_norm_tar_min), max(pre_norm_tar_max),
           min(post_norm_tar_min), max(post_norm_tar_max)))
    print('final sum of input, target is:', np.sum(sub_input_of_all_inputs),
          np.sum(sub_label_of_all_labels))
예제 #11
0
    # input with ground truth images only (Matlab required)
    IMAGE_FILE=''
    IMAGE_GT_FILE='./data/Set5/*.bmp'
    """

    # Checkpoints to evaluate and the upscaling factor they belong to.
    MODEL_FILE=['./mdl/weights_srnet_x2_52.p', './mdl/weights_srnet_x2_310.p']
    UP_SCALE=2
    # Set to 1 to be consistent with SRCNN (presumably the border width that
    # is excluded when scoring - confirm where SHAVE is used).
    SHAVE=1

    # load inputs
    # Ground-truth images, cropped with utils.modcrop(., UP_SCALE), as float32.
    im_gt = []
    files_gt = glob.glob(IMAGE_GT_FILE)
    for f in files_gt:
        #print 'loading', f
        im = np.array(Image.open(f))
        im = utils.modcrop(im, UP_SCALE).astype(np.float32)
        im_gt += [im]

    # Low-resolution inputs: read from IMAGE_FILE when it is set (only allowed
    # with exactly one ground-truth image), otherwise generated from each
    # ground truth.
    im_l = []
    if len(IMAGE_FILE)>0:
        assert(len(im_gt)==1)
        im_l = [np.array(Image.open(IMAGE_FILE)).astype(np.float32)]
    else: # downscale from ground truth using MATLAB (via pymatbridge)
        try:
            from pymatbridge import Matlab
            mlab = Matlab()
            mlab.start()
            for im in im_gt:
                # Resize by 1/UP_SCALE with MATLAB's imresize; result is in `b`.
                mlab.set_variable('a', im)
                mlab.set_variable('s', 1.0/UP_SCALE)
                mlab.run_code('b=imresize(a, s);')
예제 #12
0
    def train(self, config):
        """Train the network, or run it once over the test data.

        ``input_setup`` is called first. The (data, label) arrays are then read
        with ``read_data`` from ``./<config.checkpoint_dir>/train.h5`` when
        ``config.is_train`` is set, or ``test.h5`` otherwise. A checkpoint is
        restored with ``self.load`` when one is available.

        Training: gradient descent on ``self.loss`` for ``config.epoch``
        epochs of ``config.batch_size`` mini-batches (a trailing partial
        batch is skipped). Progress is printed every 10 steps and a
        checkpoint is saved every 500 steps.

        Testing: ``self.pred`` is evaluated on all data. The results and
        labels are stitched with ``merge`` on the ``nx x ny`` grid returned
        by ``input_setup``. ``original.bmp``, ``interpolation.bmp`` and
        ``srcnn.bmp`` are written to ``config.sample_dir``.

        Uses only ``print(...)`` calls with a single argument and ``range``,
        so the method runs unchanged on Python 2 and 3. The former
        ``print "..."`` statements and ``xrange`` fail on Python 3.
        """
        if config.is_train:
            input_setup(self.sess, config)
            data_dir = os.path.join('./{}'.format(config.checkpoint_dir),
                                    "train.h5")
        else:
            nx, ny = input_setup(self.sess, config)
            data_dir = os.path.join('./{}'.format(config.checkpoint_dir),
                                    "test.h5")

        train_data, train_label = read_data(data_dir)

        # Stochastic gradient descent with the standard backpropagation
        self.train_op = tf.train.GradientDescentOptimizer(
            config.learning_rate).minimize(self.loss)

        # tf.initialize_all_variables() is the pre-0.12 name of
        # tf.global_variables_initializer(); kept for older TensorFlow.
        tf.initialize_all_variables().run()

        counter = 0
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        if config.is_train:
            print("Training...")

            batch_size = config.batch_size
            for ep in range(config.epoch):
                # Run by batch images
                batch_idxs = len(train_data) // batch_size
                for idx in range(batch_idxs):
                    start, end = idx * batch_size, (idx + 1) * batch_size
                    batch_images = train_data[start:end]
                    batch_labels = train_label[start:end]

                    counter += 1
                    _, err = self.sess.run([self.train_op, self.loss],
                                           feed_dict={
                                               self.images: batch_images,
                                               self.labels: batch_labels
                                           })

                    if counter % 10 == 0:
                        print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]"
                              % ((ep + 1), counter, time.time() - start_time, err))

                    if counter % 500 == 0:
                        self.save(config.checkpoint_dir, counter)

        else:
            print("Testing...")

            print("Train data shape %s" % (train_data.shape,))
            print("Train label shape %s" % (train_label.shape,))

            result = self.pred.eval({
                self.images: train_data,
                self.labels: train_label
            })

            print("Result shape %s" % (result.shape,))
            print("nx ny %s %s" % (nx, ny))

            image = merge(result, [nx, ny])
            original_image = merge(train_label, [nx, ny])
            interpolation = down_upscale(modcrop(original_image, config.scale),
                                         scale=config.scale)

            imsave(
                original_image,
                os.path.join(os.getcwd(), config.sample_dir, "original.bmp"),
                config.is_RGB)
            imsave(
                interpolation,
                os.path.join(os.getcwd(), config.sample_dir,
                             "interpolation.bmp"), config.is_RGB)
            imsave(image,
                   os.path.join(os.getcwd(), config.sample_dir, "srcnn.bmp"),
                   config.is_RGB)
예제 #13
0
import math
import cv2
import glob
import os


def psnr(target, ref, scale):
    """Return the peak signal-to-noise ratio, in dB, between two 8-bit images.

    Args:
        target: image array (any shape), compared element-wise with ``ref``.
        ref: reference image array of the same shape.
        scale: accepted for API compatibility; currently unused (no border
            shaving is applied).

    Returns:
        ``20 * log10(255 / RMSE)`` as a float. Returns ``float('inf')`` for
        identical images (RMSE of zero) instead of raising ZeroDivisionError.
    """
    # Local import: numpy is not among the imports at the top of this script.
    import numpy as np

    target_data = np.asarray(target, dtype=np.float64)
    ref_data = np.asarray(ref, dtype=np.float64)
    rmse = math.sqrt(np.mean((ref_data - target_data) ** 2.))
    if rmse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / rmse)


if __name__ == "__main__":
    # Compare ./result/result.png with the first ./Test/Set5/*.bmp image.
    # NOTE(review): `modcrop` is called without a scale argument and is not
    # imported in this script - presumably provided elsewhere; verify.
    data_HR = glob.glob(os.path.join('./Test/Set5', "*.bmp"))
    print(data_HR)
    data_LR = glob.glob('./result/result.png')
    print(data_LR)
    hr = modcrop(cv2.imread(data_HR[0]))
    # NOTE(review): the result is shrunk to 1/3 before being compared with
    # the full-size ground truth, so the two arrays will generally differ in
    # shape; confirm that this resize is intended.
    lr_small = cv2.resize(cv2.imread(data_LR[0]), None,
                          fx=1.0 / 3, fy=1.0 / 3,
                          interpolation=cv2.INTER_CUBIC)
    lr = modcrop(lr_small)
    print(psnr(lr, hr, scale=3))
예제 #14
0
model.load_weights(weights_name)
# model.summary()

# evaluate each image in dataset directory
image_list = []
psnr_list = []
ssim_list = []
file_dir = os.listdir(test_dataset)

for file in file_dir:
    # read image and prepare input
    image_name = file
    image_list.append(image_name)
    # Work on channel 0 (luma) of the image read in YCbCr mode.
    img = imread(os.path.join(test_dataset, image_name), mode='YCbCr')
    x = np.array(img[:, :, 0])
    x = modcrop(x, scale)
    # Bicubic down- then up-sampling to build the network input.
    # NOTE(review): if imresize is scipy.misc.imresize, it returns uint8 and
    # byte-scales float input to its own min..max range. Resizing x_lr (already
    # divided by 255) and then dividing by 255 again therefore depends on that
    # rescaling. Confirm that the baseline intensities are as intended.
    x_lr = imresize(x, 1.0 / scale, 'bicubic') / 255.0
    x_bic = imresize(x_lr, 1.0 * scale, 'bicubic') / 255.0
    x = x / 255.0
    # Wavelet transform: one-level Haar DWT; the approximation and the three
    # detail sub-bands become 4 channels of a (1, h, w, 4) batch.
    cA, (cH, cV, cD) = pywt.dwt2(x_bic, 'haar')
    input_data = np.array([[cA, cH, cV, cD]])
    input_data = input_data.transpose([0, 2, 3, 1])

    # predict by pretrained model
    result = model.predict(input_data, batch_size=1, verbose=1)
    result = np.squeeze(result)
    # inverse Wavelet transform: split the 4 predicted channels back into
    # sub-bands (assumes the squeezed prediction is (h, w, 4)).
    rA, rH, rV, rD = result[:, :, 0], result[:, :, 1], result[:, :,
                                                              2], result[:, :,
                                                                         3]