Пример #1
0
    def train_srcnn(self, iteration):
        """Train the SRCNN model and checkpoint it after every epoch.

        Low-resolution inputs are read from ``./dataset/training/X{k}/`` for
        scale factors k = 2, 3, 4 and paired, by sorted filename order, with
        the high-resolution labels in ``./dataset/training/gray/``. Each epoch
        walks the label list in batches of ``batch_size`` images and, for
        every batch, runs one Adam step per scale factor.

        Args:
            iteration: Number of epochs to run.

        Side effects:
            Initializes all global variables in ``self.sess``, restores
            ``self.save_path`` first when ``self.pre_trained`` is set, prints
            the mean loss of every epoch and saves a checkpoint to
            ``self.save_path`` after each epoch.
        """
        # images = low resolution, labels = high resolution
        sess = self.sess
        scales = range(2, 5)
        batch_size = 3

        # Glob the file lists once; they do not change during training.
        train_label_list = sorted(glob.glob('./dataset/training/gray/*.*'))
        train_image_lists = {
            k: sorted(glob.glob('./dataset/training/X{}/*.*'.format(k)))
            for k in scales
        }
        num_image = len(train_label_list)

        sr_model = SRCNN(channel_length=self.c_length, image=self.x)
        _, _, prediction = sr_model.build_model()

        with tf.name_scope("mse_loss"):
            loss = tf.reduce_mean(tf.square(self.y - prediction))
        train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

        # Ceiling division so the trailing partial batch is trained on too;
        # its end index is clamped to num_image below.
        num_batch = (num_image + batch_size - 1) // batch_size
        # Number of optimizer steps per epoch, used to average the loss.
        num_steps = num_batch * len(scales)

        sess.run(tf.global_variables_initializer())

        # Built after minimize() so the optimizer's variables exist and are
        # included in the checkpoint.
        saver = tf.train.Saver(max_to_keep=1)
        if self.pre_trained:
            saver.restore(sess, self.save_path)

        for i in range(iteration):
            total_mse_loss = 0
            for j in range(num_batch):
                start = j * batch_size
                end = min(start + batch_size, num_image)
                for k in scales:
                    batch_image, batch_label = preprocess.load_data(
                        train_image_lists[k], train_label_list, start, end,
                        self.patch_size, self.num_patch_per_image)
                    mse_loss, _ = sess.run([loss, train_op],
                                           feed_dict={
                                               self.x: batch_image,
                                               self.y: batch_label
                                           })
                    total_mse_loss += mse_loss / num_steps

            print('In', i + 1, 'epoch, current loss is',
                  '{:.5f}'.format(total_mse_loss))
            saver.save(sess, save_path=self.save_path)

        print('Train completed')
Пример #2
0
    def test(self, mode, inference):
        """Evaluate a trained SRCNN or VDSR model by PSNR on the test set.

        Builds the model selected by ``mode`` and restores its weights from
        ``self.save_path``. For each scale factor 2, 3, 4 it compares the
        prediction for every image in ``./dataset/test/X{scale}/`` with the
        matching high-resolution label in ``./dataset/test/gray/`` (paired
        by sorted filename order). It prints each image's PSNR and the
        per-scale average.

        Args:
            mode: ``'SRCNN'`` or ``'VDSR'``.
            inference: If true, also save each prediction (and, for VDSR,
                the residual) as a PNG under ``./restored_{mode}/{self.date}/``.

        Raises:
            ValueError: If ``mode`` is neither ``'SRCNN'`` nor ``'VDSR'``.
        """
        import os

        if mode not in ('SRCNN', 'VDSR'):
            raise ValueError(
                "mode must be 'SRCNN' or 'VDSR', got {!r}".format(mode))

        # images = low resolution, labels = high resolution
        sess = self.sess
        scales = range(2, 5)

        test_label_list = sorted(glob.glob('./dataset/test/gray/*.*'))
        num_image = len(test_label_list)

        residual = None
        if mode == 'SRCNN':
            sr_model = SRCNN(channel_length=self.c_length, image=self.x)
            _, _, prediction = sr_model.build_model()
        else:
            sr_model = VDSR(channel_length=self.c_length, image=self.x)
            prediction, residual, _ = sr_model.build_model()

        with tf.name_scope("PSNR"):
            # 10 * log10(255^2 / MSE); tf.log is the natural log, so divide
            # by ln(10) to get log10.
            psnr = 10 * tf.log(255 * 255 * tf.reciprocal(
                tf.reduce_mean(tf.square(self.y - prediction)))) / tf.log(
                    tf.constant(10, dtype='float32'))

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, self.save_path)

        # Fetch PSNR and, when saving images, the outputs from a single
        # sess.run so each test image goes through the network only once.
        fetches = [psnr]
        if inference:
            fetches.append(prediction)
            if residual is not None:
                fetches.append(residual)
            os.makedirs('./restored_{0}/{1}'.format(mode, self.date),
                        exist_ok=True)

        def to_uint8(arr):
            # Clip before casting: float -> uint8 conversion of values outside
            # [0, 255] wraps or is undefined, producing garbage pixels.
            # Negative residual values therefore render as 0.
            return np.clip(np.squeeze(arr), 0, 255).astype(dtype='uint8')

        for j in scales:
            # The low-resolution file list depends only on the scale factor.
            test_image_list = sorted(
                glob.glob('./dataset/test/X{}/*.*'.format(j)))
            avg_psnr = 0
            for i in range(num_image):
                with Image.open(test_image_list[i]) as img:
                    test_image = np.array(img)[np.newaxis, :, :, np.newaxis]
                with Image.open(test_label_list[i]) as img:
                    test_label = np.array(img)
                # Crop the label so height and width are multiples of the
                # scale factor.
                h = test_label.shape[0] - test_label.shape[0] % j
                w = test_label.shape[1] - test_label.shape[1] % j
                test_label = test_label[np.newaxis, 0:h, 0:w, np.newaxis]

                results = sess.run(fetches,
                                   feed_dict={
                                       self.x: test_image,
                                       self.y: test_label
                                   })
                final_psnr = results[0]
                print('X{} : Test PSNR is '.format(j), final_psnr)
                avg_psnr += final_psnr

                if inference:
                    pred_image = Image.fromarray(to_uint8(results[1]))
                    pred_image.save('./restored_{0}/{3}/{1}_X{2}.png'.format(
                        mode, i, j, self.date))
                    if residual is not None:
                        res_image = Image.fromarray(to_uint8(results[2]))
                        res_image.save(
                            './restored_{0}/{3}/{1}_X{2}_res.png'.format(
                                mode, i, j, self.date))

            # Average over the number of test images actually evaluated.
            avg = avg_psnr / num_image if num_image else float('nan')
            print('X{} : Avg PSNR is '.format(j), avg)