Code example #1
    generator = Network.Generator(data_format=data_format,
                                  axis=axis,
                                  shared_axis=shared_axis).build()
    generator.load_weights(model_path)

    generator2 = Network.Generator(data_format=data_format,
                                   axis=axis,
                                   shared_axis=shared_axis).build()
    generator2.load_weights(model_path2)

    predicted_images = generator.predict(lr_images)
    predicted_images2 = generator2.predict(lr_images)

    for index in range(batch_size):
        fig = plt.figure()
        ax = fig.add_subplot(2, 2, 1)
        ax.imshow(
            utils.deprocess_HR(predicted_images2[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("SRGAN-VGG54_real_bs16_8epochs")

        ax = fig.add_subplot(2, 2, 2)
        ax.imshow(utils.deprocess_HR(hr_images[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("Original")

        ax = fig.add_subplot(2, 2, 3)
        ax.imshow(utils.deprocess_LR(lr_images[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("Low-res")

        ax = fig.add_subplot(2, 2, 4)
        ax.imshow(utils.deprocess_HR(predicted_images[index]).astype(np.uint8))
Code example #2
    generator_mse = Network.Generator(data_format=data_format,
                                      axis=axis,
                                      shared_axis=shared_axis).build()
    generator_srgan = Network.Generator(data_format=data_format,
                                        axis=axis,
                                        shared_axis=shared_axis).build()

    generator_mse.load_weights(model_path_mse)
    generator_srgan.load_weights(model_path_srgan)

    predicted_images_mse = generator_mse.predict(hr_images)
    predicted_images_srgan = generator_srgan.predict(hr_images)
    predicted_images_lowres = generator_srgan.predict(lr_images)

    for index in range(batch_size):
        fig = plt.figure()
        ax = fig.add_subplot(5, 1, 1)
        ax.imshow(
            utils.deprocess_HR(predicted_images_mse[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("MSE")

        ax = fig.add_subplot(5, 1, 2)
        ax.imshow(utils.deprocess_LR(hr_images[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("Original")

        ax = fig.add_subplot(5, 1, 3)
        ax.imshow(
            utils.deprocess_HR(predicted_images_srgan[index]).astype(np.uint8))
        ax.axis("off")
        ax.set_title("SRGAN")

        ax = fig.add_subplot(5, 1, 4)
Code example #3
File: images_gen.py  Project: Tubbz-alt/eyeSRGAN
    train_ds = train_ds.map(
        _map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_ds = train_ds.batch(batch_size)
    train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

    iterator = iter(train_ds)

    # model_path_srgan = 'E:\\TFM\\outputs\\Resultados_Test\\Sin_Textura_4\\checkpoints\\SRGAN-VGG54\\generator_best.h5'
    # model_path_srgan = 'saved_weights/SRGAN-VGG54_lrchanged_wm_bs20_10epochs_earlys4_lower/generator_best.h5'
    model_path_srgan = 'saved_weights/SRGAN-VGG54_lrchanged_wm_bs20_10epochs_earlys4_lower/generator_best.h5'
    # model_path_mse = 'saved_weights/SRResNet-MSE/best_weights.hdf5'
    # model_path_srgan = 'saved_weights/SRGAN-VGG54-20bs-2e-u18/generator_best.h5'

    generator_srgan = Network.Generator(data_format=data_format,
                                        axis=axis,
                                        shared_axis=shared_axis).build()

    generator_srgan.load_weights(model_path_srgan)

    for i, file in enumerate(list_files):
        print(i)
        hr_images, lr_images = next(iterator)

        predicted_images = generator_srgan.predict(lr_images)

        base_name = os.path.splitext(file)[0]
        output_path = f'{base_name}_sr.png'
        io.imsave(output_path,
                  utils.deprocess_HR(predicted_images[0]).astype(np.uint8))
        # io.imsave('./outputs/' + f'{i}_gen_sr.png', utils.deprocess_HR(predicted_images[0]).astype(np.uint8))
        # io.imsave('./outputs/' + f'{i}_orig.png', utils.deprocess_LR(crop_hr[0, :, :]).astype(np.uint8))
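
`_map_fn` is referenced by the pipeline above but not shown in this excerpt. A minimal sketch of what such a mapping function could look like, assuming PNG inputs, a 4x downscale factor, and bicubic resizing (all assumptions for illustration, not taken from the repository):

def _map_fn_sketch(file_path):
    # Hypothetical mapping function: decode one HR image and derive its LR
    # counterpart by 4x bicubic downscaling. Scaling the pixel values to the
    # ranges expected by utils.preprocess_*/deprocess_* is omitted here.
    img = tf.io.decode_png(tf.io.read_file(file_path), channels=3)
    hr = tf.cast(img, tf.float32)
    lr = tf.image.resize(hr, tf.shape(hr)[:2] // 4, method='bicubic')
    return hr, lr
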
Code example #4
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        list_mse = []
        list_psnr = []

        current_filepath = os.path.join(self.filepath,
                                        'epoch_{}'.format(epoch + 1))
        if os.path.isdir(current_filepath):
            rmtree(current_filepath)
        os.mkdir(current_filepath)

        start = time()
        for i, file_HR in enumerate(self.list_files_HR):
            img_hr = cv2.imread(file_HR, cv2.IMREAD_COLOR)
            img_name = file_HR.split('/')[-1].split('_')[0]

            if self.color_mode == 'bgr':
                img_hr = cv2.cvtColor(img_hr, cv2.COLOR_BGR2RGB)

            new_shape = img_hr.shape

            img_lr = cv2.resize(img_hr,
                                (new_shape[1] // self.downscale_factor,
                                 new_shape[0] // self.downscale_factor),
                                interpolation=cv2.INTER_LINEAR)

            img_cubic = cv2.resize(img_lr, (new_shape[1], new_shape[0]),
                                   interpolation=cv2.INTER_CUBIC)

            img_lr = preprocess_LR(img_lr)

            if self.data_format == 'channels_first':
                # (h,w,c) to (c,h,w)
                img_lr = np.transpose(img_lr, (2, 0, 1))

            img_sr = self.model.predict_on_batch(np.expand_dims(img_lr,
                                                                axis=0))[0]

            if self.data_format == 'channels_first':
                # (c,h,w) to (h,w,c)
                img_sr = np.transpose(img_sr, (1, 2, 0))

            # compute MSE (with images in range [0, 255])
            img_sr = deprocess_HR(img_sr)
            mse = np.mean(np.square(img_hr[self.slice] - img_sr[self.slice]))
            list_mse.append(mse)

            # compute PSNR
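            # PSNR = 10 * log10(MAX^2 / MSE) = 20 * log10(MAX / sqrt(MSE)) with
            # MAX = 255; the value is capped at 100 dB below when MSE is exactly 0.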
            if mse == 0.:
                psnr = 100
            else:
                psnr = 20 * np.log10(255. / np.sqrt(mse))
            list_psnr.append(psnr)

            if self.filepath is not None:
                global_image = np.zeros((new_shape[0], 3 * new_shape[1], 3),
                                        dtype=np.uint8)
                global_image[:, 0:new_shape[1], :] = img_cubic.astype(np.uint8)
                global_image[:, new_shape[1]:2 *
                             new_shape[1], :] = img_sr.astype(np.uint8)
                global_image[:, 2 * new_shape[1]:3 *
                             new_shape[1], :] = img_hr.astype(np.uint8)

                # add black padding
                global_image_ext = np.zeros(
                    (global_image.shape[0] + 50, global_image.shape[1], 3),
                    dtype=np.uint8)
                global_image_ext[0:global_image.shape[0], :, :] = global_image
                global_image_ext = cv2.putText(
                    img=np.copy(global_image_ext),
                    text="MSE = {:.3f} | PSNR = {:.3f}".format(mse, psnr),
                    org=(0, new_shape[0] + 50),
                    fontFace=1,
                    fontScale=2,
                    color=(255, 255, 255),
                    thickness=2)

                cv2.imwrite(os.path.join(current_filepath,
                                         '{}.png'.format(img_name)),
                            global_image_ext.astype(np.uint8),
                            params=[cv2.IMWRITE_PNG_COMPRESSION, 3])

        self.logs['mse'].append(np.mean(list_mse))
        self.logs['psnr'].append(np.mean(list_psnr))
        stop = time()

        if self.verbose > 0:
            print(
                '\nBSD100 Callback - Epoch %05d: MSE = %s  || PSNR = %s  in %05d s'
                % (epoch + 1, np.mean(list_mse), np.mean(list_psnr),
                   stop - start))
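
Only `on_epoch_end` is shown above, so the callback's class name, its constructor arguments, and the training objects referenced below are assumptions and placeholders; a minimal sketch of how such a callback might be attached to training:

# Hypothetical wiring; BSD100Callback and its arguments are inferred from the
# attributes used in on_epoch_end (filepath, list_files_HR, color_mode, ...),
# and bsd100_dir / generator / train_ds are placeholders.
bsd100_cb = BSD100Callback(filepath='./outputs/BSD100_eval',
                           list_files_HR=glob(os.path.join(bsd100_dir, '*_HR.png')),
                           downscale_factor=4,
                           color_mode='rgb',
                           data_format='channels_last',
                           verbose=1)

generator.fit(train_ds,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              callbacks=[bsd100_cb])
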
Code example #5
def BSD100_evaluate(model,
                    step,
                    directory,
                    filepath=None,
                    verbose=1,
                    color_mode='rgb',
                    downscale_factor=4,
                    margin=5,
                    data_format='channels_last'):
    """
    :param model: model to evaluate on BSD100
    :param step: step id. Used if filepath is not None.
    :param directory: path to the BSD100 dataset directory
    :param filepath: directory where LR|SR|HR images will be saved. If None, it does not save images.
    :param downscale_factor: downscale factor to apply to LR images
    :param color_mode: RGB or BGR mode
    :param verbose: verbose level. 0 or 1
    :param data_format: order of dimensions for TensorFlow
    :param margin: remove margin pixels strips from each border
    """

    list_files_HR = glob(os.path.join(directory, "*_HR.png"))
    print("Found {} images ...".format(len(list_files_HR)))

    list_mse = []
    list_psnr = []

    slice = np.s_[margin:-margin, margin:-margin, :]

    if filepath is not None:
        current_filepath = os.path.join(filepath, 'epoch_{}'.format(step))
        if os.path.isdir(current_filepath):
            rmtree(current_filepath)
        os.mkdir(current_filepath)

    start = time()
    for i, file_HR in enumerate(list_files_HR):
        img_hr = cv2.imread(file_HR, cv2.IMREAD_COLOR)
        img_name = file_HR.split('/')[-1].split('_')[0]

        if color_mode == 'bgr':
            img_hr = cv2.cvtColor(img_hr, cv2.COLOR_BGR2RGB)

        new_shape = img_hr.shape

        img_lr = cv2.resize(img_hr, (new_shape[1] // downscale_factor,
                                     new_shape[0] // downscale_factor),
                            interpolation=cv2.INTER_LINEAR)

        img_cubic = cv2.resize(img_lr, (new_shape[1], new_shape[0]),
                               interpolation=cv2.INTER_CUBIC)

        img_lr = preprocess_LR(img_lr)

        if data_format == 'channels_first':
            # (h,w,c) to (c,h,w)
            img_lr = np.transpose(img_lr, (2, 0, 1))

        img_sr = model.predict_on_batch(np.expand_dims(img_lr, axis=0))[0]

        if data_format == 'channels_first':
            # (c,h,w) to (h,w,c)
            img_sr = np.transpose(img_sr, (1, 2, 0))

        # compute MSE (with images in range [0, 255])
        img_sr = deprocess_HR(img_sr)
        mse = np.mean(np.square(img_hr[slice] - img_sr[slice]))
        list_mse.append(mse)

        # compute PSNR
        if mse == 0.:
            psnr = 100
        else:
            psnr = 20 * np.log10(255. / np.sqrt(mse))
        list_psnr.append(psnr)

        if filepath is not None:
            global_image = np.zeros((new_shape[0], 3 * new_shape[1], 3),
                                    dtype=np.uint8)
            global_image[:, 0:new_shape[1], :] = img_cubic.astype(np.uint8)
            global_image[:, new_shape[1]:2 * new_shape[1], :] = img_sr.astype(
                np.uint8)
            global_image[:, 2 * new_shape[1]:3 *
                         new_shape[1], :] = img_hr.astype(np.uint8)

            # add black padding
            global_image_ext = np.zeros(
                (global_image.shape[0] + 50, global_image.shape[1], 3),
                dtype=np.uint8)
            global_image_ext[0:global_image.shape[0], :, :] = global_image
            global_image_ext = cv2.putText(
                img=np.copy(global_image_ext),
                text="MSE = {:.3f} | PSNR = {:.3f}".format(mse, psnr),
                org=(0, new_shape[0] + 50),
                fontFace=1,
                fontScale=2,
                color=(255, 255, 255),
                thickness=2)

            cv2.imwrite(os.path.join(current_filepath,
                                     '{}.png'.format(img_name)),
                        global_image_ext.astype(np.uint8),
                        params=[cv2.IMWRITE_PNG_COMPRESSION, 3])
    stop = time()

    if verbose > 0:
        print(
            '\nBSD100 Callback - Epoch %05d: MSE = %s  || PSNR = %s  in %05d s'
            % (step, np.mean(list_mse), np.mean(list_psnr), stop - start))
    return np.mean(list_mse), np.mean(list_psnr)
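
A hedged usage sketch for BSD100_evaluate; the directory and output paths below are placeholders, and `generator_srgan` is assumed to be built and loaded as in the earlier snippets:

# Placeholder paths; point them at the local BSD100 copy and an output folder.
mean_mse, mean_psnr = BSD100_evaluate(generator_srgan,
                                      step=1,
                                      directory='datasets/BSD100',
                                      filepath='./outputs/BSD100_eval',
                                      verbose=1)
print('MSE: {:.3f} | PSNR: {:.3f} dB'.format(mean_mse, mean_psnr))
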
Code example #6
File: Train.py  Project: Tubbz-alt/eyeSRGAN
    iterator = iter(train_ds)
    batch_LR, batch_HR = next(iterator)
    # batch_LR, batch_HR = batch_gen.next()

    print(batch_LR.numpy().shape) 
    print(batch_HR.numpy().shape)

    if data_format == 'channels_first':
        batch_HR = np.transpose(batch_HR, (0, 2, 3, 1))
        batch_LR = np.transpose(batch_LR, (0, 2, 3, 1))

    fig, axes = plt.subplots(4, 2, figsize=(7, 15))
    for i in range(4):
        axes[i, 0].imshow(utils.deprocess_LR(batch_LR.numpy()[i]).astype(np.uint8))
        axes[i, 1].imshow(utils.deprocess_HR(batch_HR.numpy()[i]).astype(np.uint8))

    common_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.9)
    # common_optimizer = tf.keras.optimizers.RMSprop(lr=1e-4)

    epochs = 20
    steps_per_epoch = int(len(list_files) // batch_size)

    eval_freq = 3000
    info_freq = 100
    checkpoint_freq = 3000

    if os.path.isdir('./outputs/checkpoints/SRGAN-VGG54/'):
        shutil.rmtree('./outputs/checkpoints/SRGAN-VGG54/')
    os.makedirs('./outputs/checkpoints/SRGAN-VGG54/')