Exemple #1
0
def main(_):
    """Run super resolution over every labelled sample and report its quality.

    Builds the model graphs, initialises the variables from saved data, then
    calls ``model.do_super_resolution`` once per label file and prints the
    MSE and PSNR figures returned for it.

    Args:
        _: Unused positional argument (presumably the argv list passed by the
            application runner -- TODO confirm against the entry point).
    """
    print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())

    # Compare by value: `is ""` tests object identity, which only happens to
    # work through string interning and triggers a SyntaxWarning on 3.8+.
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name
    model = sr.SuperResolution(FLAGS, model_name=model_name)

    input_files = util.get_files_in_directory(FLAGS.data_dir + "/" + FLAGS.file + "/")
    # Split the flat input listing into consecutive groups of five paths;
    # group k is handed to the model together with label file k.
    test_filenames = [input_files[start:start + 5] for start in range(0, len(input_files), 5)]
    label_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + FLAGS.label_file + "/")

    # This entry point only evaluates, so the saved model is always loaded.
    FLAGS.load_model = True

    model.build_lap_graph()
    model.build_merge_graph()
    model.build_inference_graph()
    model.build_optimizer()
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    # Indexing (rather than zip) keeps the IndexError when there are fewer
    # input groups than label files instead of silently skipping labels.
    for index, label_path in enumerate(label_filenames):
        mse_bic, mse = model.do_super_resolution(label_path, test_filenames[index], FLAGS.output_dir)
        psnr_bic = util.get_psnr(mse_bic)
        psnr = util.get_psnr(mse)
        name = os.path.basename(label_path)
        print("%s MSE:%f, PSNR_bil:%f, PSNR:%f\n" % (name, mse, psnr_bic, psnr))
Exemple #2
0
    def evaluate(self):
        """Evaluate on the test set, log the summary and adjust the learning rate.

        Fetches the summary and MSE for the test images, writes the summary
        at the current step, then updates the best-validation bookkeeping:
        a negative or higher stored minimum is replaced by the new MSE;
        otherwise, once more than ``lr_decay_epoch`` epochs have passed since
        the stored epoch, the bookkeeping is reset to the current values and
        the learning rate is multiplied by ``lr_decay``. The PSNR of this
        evaluation is appended to the graph history.

        Returns:
            The MSE value fetched for the test set.
        """
        feed = {
            self.x: self.test.input.image,
            self.y: self.test.true.image,
            self.loss_alpha_input: self.loss_alpha,
        }
        summary, validation_mse = self.sess.run([self.summary_op, self.mse], feed_dict=feed)

        self.summary_writer.add_summary(summary, self.step)
        self.summary_writer.flush()

        is_new_best = self.min_validation_mse < 0 or self.min_validation_mse > validation_mse
        if is_new_best:
            self.min_validation_epoch = self.epochs_completed
            self.min_validation_mse = validation_mse
        elif self.epochs_completed > self.min_validation_epoch + self.lr_decay_epoch:
            # No improvement for longer than the decay window: restart the
            # window from here and shrink the learning rate.
            self.min_validation_epoch = self.epochs_completed
            self.min_validation_mse = validation_mse
            self.lr *= self.lr_decay

        self.psnr_graph_epoch.append(self.epochs_completed)
        self.psnr_graph_value.append(util.get_psnr(validation_mse, max_value=self.max_value))

        return validation_mse
	def train_batch(self, log_mse=False):
		"""Run one training step on the current batch and record its PSNR.

		Feeds the current batch images, learning rate and loss alpha to the
		session, increments ``self.step`` and stores the PSNR derived from
		the fetched MSE in ``self.training_psnr``.

		Args:
			log_mse: Accepted but not read by this method.
		"""
		feed = {
			self.x: self.batch_input_images,
			self.y: self.batch_true_images,
			self.lr_input: self.lr,
			self.loss_alpha_input: self.loss_alpha,
		}
		_step_result, batch_mse = self.sess.run([self.train_step, self.mse], feed_dict=feed)

		self.step += 1
		self.training_psnr = util.get_psnr(batch_mse, max_value=self.max_value)
	def print_status(self, mse):
		"""Print the current training status for the given MSE.

		Before any step has run only the initial MSE/PSNR is printed;
		afterwards the step counter, MSE, PSNR, last training PSNR, epoch,
		learning rate, loss alpha and average seconds per step are printed.

		Args:
			mse: MSE value to report; its PSNR is derived with ``self.max_value``.
		"""
		current_psnr = util.get_psnr(mse, max_value=self.max_value)

		if self.step == 0:
			print("Initial MSE:%f PSNR:%f" % (mse, current_psnr))
			return

		# Average wall-clock time per step since self.start_time.
		elapsed = time.time() - self.start_time
		seconds_per_step = elapsed / self.step
		status_line = "%s Step:%d MSE:%f PSNR:%f (%f)" % (
			util.get_now_date(), self.step, mse, current_psnr, self.training_psnr)
		print(status_line)
		schedule_line = "Epoch:%d LR:%f α:%f (%2.2fsec/step)" % (
			self.epochs_completed, self.lr, self.loss_alpha, seconds_per_step)
		print(schedule_line)
    def do_super_resolution_for_test(self, file_path, output_folder):
        """Super-resolve one test image, save the intermediate images and return the MSE.

        The aligned true image, the network input (``self.channels`` channels,
        YCbCr conversion requested) and a 3-channel colour input are saved
        first. When the true image has three channels and the model works on
        a single channel, the true image is converted to Y, the network
        output is recombined with the CbCr planes of the input and saved as
        the colour result. Otherwise the network is run on the input as is.

        Args:
            file_path: Path of the true image; also the base of every output name.
            output_folder: Folder under which all images are written.

        Returns:
            The MSE between the (possibly Y-converted) true image and the
            network output, computed with ``border_size=self.scale``.
        """
        filename, extension = os.path.splitext(file_path)
        # Fix: the original ignored `output_folder` and always wrote below a
        # hard-coded "output/" prefix.
        out_prefix = output_folder + "/"

        true_image = util.set_image_alignment(util.load_image(file_path),
                                              self.scale)
        util.save_image(out_prefix + file_path, true_image)

        input_image = util.load_input_image(file_path,
                                            channels=self.channels,
                                            scale=self.scale,
                                            alignment=self.scale,
                                            convert_ycbcr=True,
                                            jpeg_mode=self.jpeg_mode)
        util.save_image(out_prefix + filename + "_input" + extension,
                        input_image)
        input_color_image = util.load_input_image(file_path,
                                                  channels=3,
                                                  scale=self.scale,
                                                  alignment=self.scale,
                                                  convert_ycbcr=False,
                                                  jpeg_mode=self.jpeg_mode)
        util.save_image(out_prefix + filename + "_input_c" + extension,
                        input_color_image)

        is_color_true_image = len(true_image.shape) >= 3 and true_image.shape[2] == 3
        if is_color_true_image and self.channels == 1:
            # The model predicts a single plane only: compare against the Y
            # plane of the true image and borrow CbCr from the input.
            true_image = util.convert_rgb_to_y(true_image,
                                               jpeg_mode=self.jpeg_mode)
            util.save_image(out_prefix + filename + "_true" + extension,
                            true_image)
            input_ycbcr_image = util.load_input_image(file_path,
                                                      channels=3,
                                                      scale=self.scale,
                                                      alignment=self.scale,
                                                      convert_ycbcr=True,
                                                      jpeg_mode=self.jpeg_mode)
            output_image = self.do(input_image)
            output_color_image = util.convert_y_and_cbcr_to_rgb(
                output_image,
                input_ycbcr_image[:, :, 1:3],
                jpeg_mode=self.jpeg_mode)
            util.save_image(out_prefix + filename + "_result_c" + extension,
                            output_color_image)
        else:
            # Monochrome or RGB image: feed the input straight to the network.
            output_image = self.do(input_image)

        mse = util.compute_mse(true_image,
                               output_image,
                               border_size=self.scale)

        util.save_image(out_prefix + filename + "_result" + extension,
                        output_image)
        print("MSE:%f PSNR:%f" % (mse, util.get_psnr(mse)))
        return mse
Exemple #6
0
def main(_):
    """Optionally train the model, then super-resolve the test set and print a summary.

    Args:
        _: Unused positional argument (presumably the argv list passed by the
            application runner -- TODO confirm against the entry point).
    """
    # print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())  # start time

    # Name under which the model is saved. Compare by value: `is ""` tests
    # object identity and triggers a SyntaxWarning on Python 3.8+.
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (
            FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name

    model = sr.SuperResolution(FLAGS, model_name=model_name)  # build the DRCN model

    test_filenames = util.build_test_filenames(FLAGS.data_dir, FLAGS.dataset,
                                               FLAGS.scale)  # collect the test files
    if FLAGS.is_training:
        if FLAGS.dataset == "test":
            training_filenames = util.build_test_filenames(
                FLAGS.data_dir, FLAGS.dataset, FLAGS.scale)
        else:
            training_filenames = util.get_files_in_directory(
                FLAGS.data_dir + "/" + FLAGS.training_set + "/")  # collect the training files

        print("Loading and building cache images...")
        # Training data, test data, image cache directory, mini-batch size
        # and the stride at which samples are taken.
        model.load_datasets(
            FLAGS.cache_dir, training_filenames, test_filenames,
            FLAGS.batch_size,
            FLAGS.stride_size)
    else:
        # Evaluation only: variables must come from a saved model.
        FLAGS.load_model = True

    model.build_embedding_graph()  # embedding layers
    model.build_inference_graph()  # inference layers
    model.build_reconstruction_graph()  # reconstruction layers
    model.build_optimizer()  # create the loss function and the optimizer
    model.init_all_variables(load_initial_data=FLAGS.load_model)  # initialise variables

    if FLAGS.is_training:
        train(training_filenames, test_filenames, model)

    psnr = 0
    total_mse = 0
    for filename in test_filenames:
        mse = model.do_super_resolution_for_test(
            filename, FLAGS.output_dir)  # super-resolve one test image
        total_mse += mse
        psnr += util.get_psnr(mse)

    print("\n--- summary --- %s" % util.get_now_date())  # end time
    model.print_steps_completed()  # total epochs, steps, time, ... spent on training
    util.print_num_of_total_parameters()  # total number of parameters

    # Guard the averages: an empty test set used to raise ZeroDivisionError.
    test_count = len(test_filenames)
    if test_count > 0:
        print("Final MSE:%f, PSNR:%f" %
              (total_mse / test_count, psnr / test_count))
    else:
        print("Final MSE/PSNR unavailable: no test images were found.")
Exemple #7
0
def main(_):
    """Optionally train the model, then evaluate it on the test set.

    Builds the embedding, inference and reconstruction graphs plus the
    optimizer, trains when ``FLAGS.is_training`` is set, runs
    ``do_super_resolution_for_test`` on every test file and prints the
    averaged MSE and PSNR.

    Args:
        _: Unused positional argument (presumably the argv list passed by the
            application runner -- TODO confirm against the entry point).
    """
    print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())

    # `is ""` compared object identity (SyntaxWarning on Python 3.8+ and only
    # correct thanks to string interning); equality is what is meant.
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (
            FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name
    model = sr.SuperResolution(FLAGS, model_name=model_name)

    test_filenames = util.build_test_filenames(FLAGS.data_dir, FLAGS.dataset,
                                               FLAGS.scale)
    if FLAGS.is_training:
        if FLAGS.dataset == "test":
            training_filenames = util.build_test_filenames(
                FLAGS.data_dir, FLAGS.dataset, FLAGS.scale)
        else:
            training_filenames = util.get_files_in_directory(
                FLAGS.data_dir + "/" + FLAGS.training_set + "/")

        print("Loading and building cache images...")
        model.load_datasets(FLAGS.cache_dir, training_filenames,
                            test_filenames, FLAGS.batch_size,
                            FLAGS.stride_size)
    else:
        # Without training the variables have to be restored from disk.
        FLAGS.load_model = True

    model.build_embedding_graph()
    model.build_inference_graph()
    model.build_reconstruction_graph()
    model.build_optimizer()
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    if FLAGS.is_training:
        train(training_filenames, test_filenames, model)

    psnr = 0
    total_mse = 0
    for filename in test_filenames:
        mse = model.do_super_resolution_for_test(filename, FLAGS.output_dir)
        total_mse += mse
        psnr += util.get_psnr(mse)

    print("\n--- summary --- %s" % util.get_now_date())
    model.print_steps_completed()
    util.print_num_of_total_parameters()

    # An empty test set previously crashed here with ZeroDivisionError.
    evaluated = len(test_filenames)
    if evaluated > 0:
        print("Final MSE:%f, PSNR:%f" %
              (total_mse / evaluated, psnr / evaluated))
    else:
        print("Final MSE/PSNR unavailable: no test images were found.")
    def print_status(self, step, mse):
        """Print the training status for the given step and MSE.

        Args:
            step: Number of steps run so far; used to average the elapsed time.
            mse: MSE value to report; its PSNR is derived with ``self.max_value``.
        """
        elapsed = time.time() - self.start_time
        # Guard against step == 0, which used to raise ZeroDivisionError.
        processing_time = elapsed / step if step != 0 else 0.0

        # Average training PSNR over the samples accumulated so far.
        if self.trained > 0:
            training_psnr = self.training_psnr / self.trained
        else:
            training_psnr = 0

        print("%s Step:%d MSE:%f PSNR:%f (%f)" %
              (util.get_now_date(), step, mse,
               util.get_psnr(mse, max_value=self.max_value), training_psnr))
        print(
            "Epoch:%d LR:%f α:%f (%2.2fsec/step)" %
            (self.epochs_completed, self.lr, self.loss_alpha, processing_time))
Exemple #9
0
    def do_super_resolution_for_test(self, i, label_file_path, file_path, output_folder="output", output=True):
        """Super-resolve sample ``i`` from five stacked input images and return the MSE.

        Args:
            i: Index of the sample; entries ``5 * i`` .. ``5 * i + 4`` of
                ``file_path`` are loaded as its inputs.
            label_file_path: Path of the true image; also the base of every
                output file name.
            file_path: Indexable collection of input image paths (presumably
                a flat list holding five inputs per label -- TODO confirm
                against the caller).
            output_folder: Folder under which result images are written.
            output: When true, the true/input/result images are saved.

        Returns:
            The MSE between the true image (its Y plane in the colour case)
            and the network output, computed with ``border_size=self.scale``.
        """
        true_image = util.set_image_alignment(util.load_image(label_file_path), self.scale)
        output_folder = output_folder + "/"
        filename, extension = os.path.splitext(label_file_path)

        frames_per_sample = 5
        frame_indices = range(i * frames_per_sample, i * frames_per_sample + frames_per_sample)

        if len(true_image.shape) >= 3 and true_image.shape[2] == 3 and self.channels == 1:
            # Fix: the original indexed file_path[i] inside the loop over j,
            # loading the same image five times instead of five different ones.
            # Original note: "convert_ycbcr: True -> False" (the flag is not
            # passed to load_input_image here).
            frames = [
                util.load_input_image(file_path[j], channels=1, scale=self.scale, alignment=self.scale)
                for j in frame_indices
            ]
            # Stack the five inputs along the third axis.
            input_y_image = np.dstack(frames)
            true_ycbcr_image = util.convert_rgb_to_ycbcr(true_image, jpeg_mode=self.jpeg_mode)

            output_y_image = self.do(input_y_image, true_ycbcr_image[:, :, 0:1])
            mse = util.compute_mse(true_ycbcr_image[:, :, 0:1], output_y_image, border_size=self.scale)

            if output:
                output_color_image = util.convert_y_and_cbcr_to_rgb(output_y_image, true_ycbcr_image[:, :, 1:3],
                                                                    jpeg_mode=self.jpeg_mode)
                loss_image = util.get_loss_image(true_ycbcr_image[:, :, 0:1], output_y_image, border_size=self.scale)

                util.save_image(output_folder + label_file_path, true_image)
                util.save_image(output_folder + filename + "_input" + extension, input_y_image)
                util.save_image(output_folder + filename + "_true_y" + extension, true_ycbcr_image[:, :, 0:1])
                util.save_image(output_folder + filename + "_result" + extension, output_y_image)
                util.save_image(output_folder + filename + "_result_c" + extension, output_color_image)
                util.save_image(output_folder + filename + "_loss" + extension, loss_image)
        else:
            frames = []
            for j in frame_indices:
                # Same fix as above: use the loop index j, not i.
                input_image = util.load_input_image(file_path[j], channels=1, scale=self.scale, alignment=self.scale)
                # Original note: "convert_ycbcr: True -> False".
                frames.append(util.build_input_image(input_image, channels=self.channels, scale=self.scale,
                                                     alignment=self.scale, convert_ycbcr=False,
                                                     jpeg_mode=self.jpeg_mode))
            input_y_image = np.dstack(frames)
            output_image = self.do(input_y_image, true_image)
            mse = util.compute_mse(true_image, output_image, border_size=self.scale)

            if output:
                util.save_image(output_folder + label_file_path, true_image)
                util.save_image(output_folder + filename + "_result" + extension, output_image)

        print("MSE:%f PSNR:%f" % (mse, util.get_psnr(mse)))
        return mse
	def do_super_resolution_for_test(self, file_path, output_folder="output", output=True):
		"""Super-resolve one test image, optionally save the images and return the MSE.

		Args:
			file_path: Path of the true image; also the base of every output name.
			output_folder: Folder under which result images are written.
			output: When true, the true/input/result images are saved.

		Returns:
			The MSE between the true image (its Y plane in the colour case)
			and the network output, computed with ``border_size=self.scale``.
		"""
		filename, extension = os.path.splitext(file_path)
		output_folder = output_folder + "/"
		true_image = util.set_image_alignment(util.load_image(file_path), self.scale)

		shape = true_image.shape
		single_channel_on_color = len(shape) >= 3 and shape[2] == 3 and self.channels == 1

		if single_channel_on_color:
			# The network works on one plane: feed Y, compare against Y and
			# reuse the true CbCr planes for the colour result.
			network_input = util.build_input_image(
				true_image, channels=self.channels, scale=self.scale, alignment=self.scale,
				convert_ycbcr=True, jpeg_mode=self.jpeg_mode)
			ycbcr = util.convert_rgb_to_ycbcr(true_image, jpeg_mode=self.jpeg_mode)

			predicted_y = self.do(network_input)
			mse = util.compute_mse(ycbcr[:, :, 0:1], predicted_y, border_size=self.scale)

			if output:
				predicted_rgb = util.convert_y_and_cbcr_to_rgb(
					predicted_y, ycbcr[:, :, 1:3], jpeg_mode=self.jpeg_mode)
				difference = util.get_loss_image(ycbcr[:, :, 0:1], predicted_y, border_size=self.scale)

				util.save_image(output_folder + file_path, true_image)
				named_images = (
					("_input", network_input),
					("_true_y", ycbcr[:, :, 0:1]),
					("_result", predicted_y),
					("_result_c", predicted_rgb),
					("_loss", difference),
				)
				for suffix, image in named_images:
					util.save_image(output_folder + filename + suffix + extension, image)
		else:
			network_input = util.load_input_image(file_path, channels=1, scale=self.scale, alignment=self.scale)
			prediction = self.do(network_input)
			mse = util.compute_mse(true_image, prediction, border_size=self.scale)

			if output:
				util.save_image(output_folder + file_path, true_image)
				util.save_image(output_folder + filename + "_result" + extension, prediction)

		print("MSE:%f PSNR:%f" % (mse, util.get_psnr(mse)))
		return mse
Exemple #11
0
    def do_super_resolution_for_test(self, file_path, output_folder="output", output=True):
        """Run the network on one test image and return the resulting MSE.

        Args:
            file_path: Path of the true (HR) image; also the base of every
                output file name.
            output_folder: Folder under which result images are written.
            output: When true, the true/input/result images are saved.

        Returns:
            The MSE between the true image (its Y plane in the colour case)
            and the network output, computed with ``border_size=self.scale``.
        """
        filename, extension = os.path.splitext(file_path)
        output_folder = output_folder + "/"

        def save_with_suffix(suffix, image):
            # Writes <output_folder><filename><suffix><extension>.
            util.save_image(output_folder + filename + suffix + extension, image)

        # Load the image and align it according to the scale.
        loaded_image = util.load_image(file_path)
        true_image = util.set_image_alignment(loaded_image, self.scale)

        has_three_channels = len(true_image.shape) >= 3 and true_image.shape[2] == 3
        if has_three_channels and self.channels == 1:
            # Build the LR network input from the true image.
            input_y_image = util.build_input_image(
                true_image,
                channels=self.channels,
                scale=self.scale,
                alignment=self.scale,
                convert_ycbcr=True,
                jpeg_mode=self.jpeg_mode,
            )
            # Convert the true (HR) image from RGB to YCbCr.
            true_ycbcr_image = util.convert_rgb_to_ycbcr(true_image, jpeg_mode=self.jpeg_mode)

            # Feed the LR test image to the network.
            output_y_image = self.do(input_y_image)
            # MSE between the HR Y plane and the network output.
            mse = util.compute_mse(true_ycbcr_image[:, :, 0:1], output_y_image, border_size=self.scale)

            if output:
                # Convert the output back from Y + CbCr to RGB.
                output_color_image = util.convert_y_and_cbcr_to_rgb(
                    output_y_image, true_ycbcr_image[:, :, 1:3], jpeg_mode=self.jpeg_mode)
                # Difference between the HR Y plane and the network output.
                loss_image = util.get_loss_image(
                    true_ycbcr_image[:, :, 0:1], output_y_image, border_size=self.scale)

                util.save_image(output_folder + file_path, true_image)
                save_with_suffix("_input", input_y_image)  # LR test input
                save_with_suffix("_true_y", true_ycbcr_image[:, :, 0:1])  # HR Y plane
                save_with_suffix("_result", output_y_image)  # network output
                save_with_suffix("_result_c", output_color_image)  # output converted to RGB
                save_with_suffix("_loss", loss_image)  # HR / output difference
        else:
            input_image = util.load_input_image(
                file_path, channels=1, scale=self.scale, alignment=self.scale)
            output_image = self.do(input_image)
            mse = util.compute_mse(true_image, output_image, border_size=self.scale)

            if output:
                util.save_image(output_folder + file_path, true_image)
                save_with_suffix("_result", output_image)

        print("MSE:%f PSNR:%f" % (mse, util.get_psnr(mse)))
        return mse