def test(self):
        sml_x = tf.placeholder(tf.float32,
                               [None, self.idims[0], self.idims[1], 3])
        odims = tf.placeholder(tf.int32, [2])
        gener_x = generator(sml_x, odims, is_training=False, reuse=False)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            saver = tf.train.Saver()
            try:
                saver.restore(sess,
                              '/'.join(['models', self.model, self.model]))
            except Exception as e:
                print('Model could not be restored. Exiting.\nError: ' + e)
                exit()
            makedirs(self.out_path)
            print('Performing super resolution ...')
            for idx in range(0, self.dataset_size, self.batch_size):
                start, end = idx, min(idx + self.batch_size, self.dataset_size)
                batch = range(start, end)
                batch_big = self.dataset[batch] / 255.0
                batch_sml = np.array([
                    imresize(img, size=(self.idims[0], self.idims[1], 3))
                    for img in batch_big
                ])
                superres_imgs = sess.run(gener_x,
                                         feed_dict={
                                             sml_x: batch_sml,
                                             odims: self.odims
                                         })
                superres_imgs = np.array(superres_imgs * 255.0, dtype=np.uint8)
                nearest = np.array([
                    imresize(
                        img, size=superres_imgs.shape[1:], interp='nearest')
                    for img in batch_sml
                ],
                                   dtype=np.uint8)
                bilinear = np.array([
                    imresize(
                        img, size=superres_imgs.shape[1:], interp='bilinear')
                    for img in batch_sml
                ],
                                    dtype=np.uint8)
                bicubic = np.array([
                    imresize(
                        img, size=superres_imgs.shape[1:], interp='bicubic')
                    for img in batch_sml
                ],
                                   dtype=np.uint8)
                lanczos = np.array([
                    imresize(
                        img, size=superres_imgs.shape[1:], interp='lanczos')
                    for img in batch_sml
                ],
                                   dtype=np.uint8)
                original = np.array([
                    imresize(img, size=(self.odims[0], self.odims[1], 3))
                    for img in batch_big
                ],
                                    dtype=np.uint8)
                images = np.concatenate((nearest, bilinear, bicubic, lanczos,
                                         superres_imgs, original), 2)

                def display(im_data):
                    dpi = 80
                    height, width, depth = im_data.shape
                    figsize = width / float(dpi), height / float(dpi)
                    fig = plt.figure(figsize=figsize)
                    ax = fig.add_axes([0, 0, 1, 1])
                    ax.axis('off')
                    ax.imshow(im_data, cmap='gray')
                    plt.show()

                for (i, og, nn, bl, bc, la, sr,
                     image) in zip(range(100), original, nearest, bilinear,
                                   bicubic, lanczos, superres_imgs, images):
                    nn, _ = compare_ssim(og, og, nn)
                    bl, _ = compare_ssim(og, og, bl)
                    bc, _ = compare_ssim(og, og, bc)
                    la, _ = compare_ssim(og, og, la)
                    sr, _ = compare_ssim(og, og, sr)
                    # display(image)
                    plt.subplot(111)
                    title = 'Nearest               Bilinear               Bicubic               Lanczos               SRGAN               Original'.format(
                        nn, bl, bc, la, sr)
                    plt.title(title)
                    title = '{0:.4f}                      {1:.4f}                      {2:.4f}                      {3:.4f}                      {4:.4f}                      {5:.4f}'.format(
                        nn, bl, bc, la, sr, 1.000)
                    plt.xlabel(title)
                    plt.xticks([])
                    plt.yticks([])

                    plt.imshow(image)
                    plt.show()
                    # imsave('%s/%d.png' % (self.out_path, start+i), image)
                print('%d/%d saved successfully.' % (end, self.dataset_size))
Exemple #2
0
    def test(self):
        sml_x = tf.placeholder(tf.float32,
                               [None, self.idims[0], self.idims[1], 3])
        odims = tf.placeholder(tf.int32, [2])
        gener_x = generator(sml_x, odims, is_training=False, reuse=False)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            saver = tf.train.Saver()
            try:
                saver.restore(sess,
                              '/'.join(['models', self.model, self.model]))
            except Exception as e:
                print('Model could not be restored. Exiting.\nError: ' + e)
                exit()
            succ, total = 0, 0
            avg_1, avg_2 = 0, 0

            print('Performing super resolution ...')
            for idx in range(0, self.dataset_size, self.batch_size):
                start, end = idx, min(idx + self.batch_size, self.dataset_size)
                batch = range(start, end)
                batch_big = self.dataset[batch] / 255.0
                batch_sml = np.array([
                    imresize(img, size=(self.idims[0], self.idims[1], 3))
                    for img in batch_big
                ])
                superres_imgs = sess.run(gener_x,
                                         feed_dict={
                                             sml_x: batch_sml,
                                             odims: self.odims
                                         })
                interpolated_imgs = np.array([
                    imresize(img, size=superres_imgs.shape[1:]) / 255.0
                    for img in batch_sml
                ])

                for i in range(len(batch_sml)):
                    original = np.array(imresize(batch_big[i],
                                                 size=(self.odims[0],
                                                       self.odims[1])),
                                        dtype=np.uint8)
                    superres = np.array(superres_imgs[i] * 255.0,
                                        dtype=np.uint8)
                    interpolated = np.array(imresize(batch_sml[i],
                                                     size=(self.odims[0],
                                                           self.odims[1]),
                                                     interp=args.interpol),
                                            dtype=np.uint8)
                    ssim_1, ssim_2 = compare_ssim(original, superres,
                                                  interpolated)
                    if ssim_1 <= ssim_2:
                        succ += 1
                    total += 1
                    avg_1 += ssim_1
                    avg_2 += ssim_2
                print('%d/%d completed.' % (end, self.dataset_size))
                print('Average SSIM: {0:.4f}, {1:.4f}'.format(
                    avg_1 / total, avg_2 / total))
            print('{}/{} images have better SSIM'.format(succ, total))
            print('Average SSIM: {0:.4f}, {1:.4f}'.format(
                avg_1 / total, avg_2 / total))