def train_srcnn(self, iteration):
    """Train the SRCNN model on the grayscale training set.

    Low-resolution inputs are read from ``./dataset/training/X{k}/`` for the
    scale factors k = 2, 3, 4 and paired, by sorted filename order, with the
    high-resolution labels in ``./dataset/training/gray/``. Every batch of
    label images is trained once per scale factor. After each epoch the
    epoch-averaged MSE is printed and a checkpoint is written to
    ``self.save_path``.

    Args:
        iteration: Number of epochs to run over the training set.

    Raises:
        ValueError: If no label images are found in
            ``./dataset/training/gray/``.
    """
    # images = low resolution, labels = high resolution
    sess = self.sess
    scales = range(2, 5)  # scale factors k whose X{k} folders are used
    batch_size = 3

    # Load the file lists once; they do not change across batches or epochs.
    train_label_list = sorted(glob.glob('./dataset/training/gray/*.*'))
    num_image = len(train_label_list)
    if num_image == 0:
        raise ValueError(
            'No training label images found in ./dataset/training/gray/')
    train_image_lists = {
        k: sorted(glob.glob('./dataset/training/X{}/*.*'.format(k)))
        for k in scales
    }

    sr_model = SRCNN(channel_length=self.c_length, image=self.x)
    # build_model also returns two variable lists that could be given
    # separate learning rates via var_list; the single Adam optimizer below
    # updates all trainable variables, so they are not needed here.
    _, _, prediction = sr_model.build_model()
    with tf.name_scope("mse_loss"):
        loss = tf.reduce_mean(tf.square(self.y - prediction))
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

    # Ceiling division so the trailing partial batch is trained as well; its
    # slice end is clamped to num_image inside the loop.
    num_batch = (num_image + batch_size - 1) // batch_size
    steps_per_epoch = num_batch * len(scales)

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)
    if self.pre_trained:
        # Load checkpointed values over the freshly initialized variables.
        saver.restore(sess, self.save_path)

    for epoch in range(iteration):
        total_mse_loss = 0
        for j in range(num_batch):
            start = j * batch_size
            end = min(start + batch_size, num_image)
            for k in scales:
                batch_image, batch_label = preprocess.load_data(
                    train_image_lists[k], train_label_list, start, end,
                    self.patch_size, self.num_patch_per_image)
                mse_loss, _ = sess.run([loss, train_op], feed_dict={
                    self.x: batch_image,
                    self.y: batch_label,
                })
                # Running mean over every (batch, scale) step of the epoch.
                total_mse_loss += mse_loss / steps_per_epoch
        print('In', epoch + 1, 'epoch, current loss is',
              '{:.5f}'.format(total_mse_loss))
        saver.save(sess, save_path=self.save_path)
    print('Train completed')
def test(self, mode, inference):
    """Evaluate a trained SRCNN or VDSR model on the grayscale test set.

    For each scale factor j = 2, 3, 4, the low-resolution images in
    ``./dataset/test/X{j}/`` are paired, by sorted filename order, with the
    high-resolution labels in ``./dataset/test/gray/``. Each label is cropped
    so its height and width are multiples of j. The per-image PSNR and the
    per-scale average PSNR are printed.

    Args:
        mode: ``'SRCNN'`` or ``'VDSR'``; selects which network is built.
        inference: If true, also save the predicted images (and, for VDSR,
            the residual images) under ``./restored_{mode}/{self.date}/``.

    Raises:
        ValueError: If ``mode`` is not ``'SRCNN'`` or ``'VDSR'``, or if no
            label images are found in ``./dataset/test/gray/``.
    """
    # images = low resolution, labels = high resolution
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in ('SRCNN', 'VDSR'):
        raise ValueError(
            "mode must be 'SRCNN' or 'VDSR', got {!r}".format(mode))
    sess = self.sess
    scales = range(2, 5)

    test_label_list = sorted(glob.glob('./dataset/test/gray/*.*'))
    num_image = len(test_label_list)
    if num_image == 0:
        raise ValueError('No test label images found in ./dataset/test/gray/')

    residual = None
    if mode == 'SRCNN':
        sr_model = SRCNN(channel_length=self.c_length, image=self.x)
        _, _, prediction = sr_model.build_model()
    else:
        sr_model = VDSR(channel_length=self.c_length, image=self.x)
        prediction, residual, _ = sr_model.build_model()

    with tf.name_scope("PSNR"):
        # PSNR = 10 * log10(255^2 / MSE), expressed with natural logarithms;
        # the 255 peak assumes 8-bit pixel values.
        psnr = 10 * tf.log(255 * 255 * tf.reciprocal(
            tf.reduce_mean(tf.square(self.y - prediction)))) / tf.log(
                tf.constant(10, dtype='float32'))

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, self.save_path)

    # Fetch everything needed for one image in a single session run.
    fetches = [psnr]
    if inference:
        import os
        # Image.save does not create missing directories.
        os.makedirs('./restored_{0}/{1}'.format(mode, self.date),
                    exist_ok=True)
        fetches.append(prediction)
        if residual is not None:
            fetches.append(residual)

    def to_uint8(array):
        # Saturate to [0, 255] before casting; a bare astype('uint8') wraps
        # out-of-range values (e.g. -1 -> 255, 256 -> 0).
        return np.clip(np.squeeze(array), 0, 255).astype(dtype='uint8')

    def load_gray(path):
        # `with` closes the file handle that Image.open keeps open.
        with Image.open(path) as img:
            return np.array(img)

    for j in scales:
        test_image_list = sorted(
            glob.glob('./dataset/test/X{}/*.*'.format(j)))
        avg_psnr = 0
        for i in range(num_image):
            test_image = load_gray(test_image_list[i])
            test_image = test_image[np.newaxis, :, :, np.newaxis]
            test_label = load_gray(test_label_list[i])
            # Crop the label so both sides are divisible by the scale factor.
            h = test_label.shape[0] - test_label.shape[0] % j
            w = test_label.shape[1] - test_label.shape[1] % j
            test_label = test_label[np.newaxis, 0:h, 0:w, np.newaxis]

            results = sess.run(fetches, feed_dict={
                self.x: test_image,
                self.y: test_label,
            })
            final_psnr = results[0]
            print('X{} : Test PSNR is '.format(j), final_psnr)
            avg_psnr += final_psnr

            if inference:
                filename = './restored_{0}/{3}/{1}_X{2}.png'.format(
                    mode, i, j, self.date)
                Image.fromarray(to_uint8(results[1])).save(filename)
                if mode == 'VDSR':
                    # NOTE: negative residual values saturate to 0 here.
                    filename = './restored_{0}/{3}/{1}_X{2}_res.png'.format(
                        mode, i, j, self.date)
                    Image.fromarray(to_uint8(results[2])).save(filename)
        # Average over the actual number of test images (was hard-coded 5).
        print('X{} : Avg PSNR is '.format(j), avg_psnr / num_image)