예제 #1
0
파일: main.py 프로젝트: liuzhiit/DRCN_tf
def main(_):
    """Build the DRCN super-resolution model, optionally train it, then evaluate.

    Flow:
      1. Derive the model name from FLAGS (auto-generated when
         ``FLAGS.model_name`` is empty).
      2. Collect test files; when ``FLAGS.is_training`` is set, also collect
         training files and load them into the model's dataset cache.
      3. Build the embedding / inference / reconstruction graphs and the
         optimizer, then initialize (or restore) variables.
      4. Train if requested, run super-resolution on every test file and
         print the mean MSE / PSNR over the test set.
    """
    print("%s\n" % util.get_now_date())  # start time

    # Use ``==`` rather than ``is``: identity against a str literal is
    # implementation-dependent (SyntaxWarning on Python 3.8+).
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (
            FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name

    model = sr.SuperResolution(FLAGS, model_name=model_name)  # build the DRCN model

    test_filenames = util.build_test_filenames(FLAGS.data_dir, FLAGS.dataset,
                                               FLAGS.scale)
    if FLAGS.is_training:
        if FLAGS.dataset == "test":
            # The "test" dataset reuses the test images as training data.
            training_filenames = util.build_test_filenames(
                FLAGS.data_dir, FLAGS.dataset, FLAGS.scale)
        else:
            training_filenames = util.get_files_in_directory(
                FLAGS.data_dir + "/" + FLAGS.training_set + "/")

        print("Loading and building cache images...")
        # Cache directory, training files, test files, batch size and stride.
        model.load_datasets(
            FLAGS.cache_dir, training_filenames, test_filenames,
            FLAGS.batch_size,
            FLAGS.stride_size)
    else:
        # Evaluation-only run: init_all_variables below reads this flag, so
        # force it on to restore previously saved variables.
        FLAGS.load_model = True

    model.build_embedding_graph()  # embedding network
    model.build_inference_graph()  # inference network
    model.build_reconstruction_graph()  # reconstruction network
    model.build_optimizer()  # loss function and optimizer
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    if FLAGS.is_training:
        train(training_filenames, test_filenames, model)

    total_psnr = 0
    total_mse = 0
    for filename in test_filenames:
        # Super-resolve one test image; the model returns its MSE.
        mse = model.do_super_resolution_for_test(filename, FLAGS.output_dir)
        total_mse += mse
        total_psnr += util.get_psnr(mse)

    print("\n--- summary --- %s" % util.get_now_date())  # end time
    model.print_steps_completed()  # total epochs / steps / time of training
    util.print_num_of_total_parameters()  # total number of model parameters

    num_tests = len(test_filenames)
    if num_tests == 0:
        # Avoid ZeroDivisionError when the test set is empty.
        print("No test images found; skipping final MSE/PSNR.")
    else:
        print("Final MSE:%f, PSNR:%f" %
              (total_mse / num_tests, total_psnr / num_tests))
예제 #2
0
def main(_):
    """Build the DRCN super-resolution model, optionally train it, then evaluate.

    Prints the TensorFlow version and a timestamp, derives the model name
    from FLAGS, loads training data when ``FLAGS.is_training`` is set, builds
    the graphs and optimizer, initializes (or restores) variables, optionally
    trains, and finally reports the mean MSE / PSNR over all test files.
    """
    print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())

    # ``==`` instead of ``is``: identity comparison with a str literal is
    # implementation-dependent and warns on Python 3.8+.
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (
            FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name
    model = sr.SuperResolution(FLAGS, model_name=model_name)

    test_filenames = util.build_test_filenames(FLAGS.data_dir, FLAGS.dataset,
                                               FLAGS.scale)
    if FLAGS.is_training:
        if FLAGS.dataset == "test":
            # The "test" dataset reuses the test images as training data.
            training_filenames = util.build_test_filenames(
                FLAGS.data_dir, FLAGS.dataset, FLAGS.scale)
        else:
            training_filenames = util.get_files_in_directory(
                FLAGS.data_dir + "/" + FLAGS.training_set + "/")

        print("Loading and building cache images...")
        model.load_datasets(FLAGS.cache_dir, training_filenames,
                            test_filenames, FLAGS.batch_size,
                            FLAGS.stride_size)
    else:
        # Evaluation-only run: restore saved variables (read just below).
        FLAGS.load_model = True

    model.build_embedding_graph()
    model.build_inference_graph()
    model.build_reconstruction_graph()
    model.build_optimizer()
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    if FLAGS.is_training:
        train(training_filenames, test_filenames, model)

    total_psnr = 0
    total_mse = 0
    for filename in test_filenames:
        mse = model.do_super_resolution_for_test(filename, FLAGS.output_dir)
        total_mse += mse
        total_psnr += util.get_psnr(mse)

    print("\n--- summary --- %s" % util.get_now_date())
    model.print_steps_completed()
    util.print_num_of_total_parameters()

    num_tests = len(test_filenames)
    if num_tests == 0:
        # Avoid ZeroDivisionError when the test set is empty.
        print("No test images found; skipping final MSE/PSNR.")
    else:
        print("Final MSE:%f, PSNR:%f" %
              (total_mse / num_tests, total_psnr / num_tests))
예제 #3
0
def main(_):
    """Restore a saved SuperResolution model and evaluate it against label images.

    Input files under ``FLAGS.data_dir/FLAGS.file/`` are grouped into
    consecutive chunks of ``files_per_label``; the i-th chunk is passed to
    ``model.do_super_resolution`` together with the i-th file under
    ``FLAGS.data_dir/FLAGS.label_file/``. For every label, the baseline and
    model MSE are converted to PSNR and printed.

    Raises:
        ValueError: if there are fewer input groups than label files
            (previously an IndexError part-way through evaluation).
    """
    print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())

    # ``==`` instead of ``is``: identity with a str literal is unreliable.
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name
    model = sr.SuperResolution(FLAGS, model_name=model_name)

    # Number of consecutive input files that belong to one label image.
    # NOTE(review): pairing relies on get_files_in_directory returning both
    # directories in a matching, stable order - confirm.
    files_per_label = 5
    input_files = util.get_files_in_directory(FLAGS.data_dir + "/" + FLAGS.file + "/")
    test_filenames = [input_files[i:i + files_per_label]
                      for i in range(0, len(input_files), files_per_label)]
    label_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + FLAGS.label_file + "/")

    # Fail fast with a clear message instead of an IndexError mid-run.
    if len(test_filenames) < len(label_filenames):
        raise ValueError(
            "Found %d label files but only %d groups of %d input files"
            % (len(label_filenames), len(test_filenames), files_per_label))

    # This script only evaluates, so always restore saved variables.
    FLAGS.load_model = True

    model.build_lap_graph()
    model.build_merge_graph()
    model.build_inference_graph()
    model.build_optimizer()
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    for label_filename, test_group in zip(label_filenames, test_filenames):
        mse_bic, mse = model.do_super_resolution(label_filename, test_group, FLAGS.output_dir)
        # NOTE(review): the variable says "bic" (bicubic?) but the printed
        # label says "PSNR_bil" - confirm which baseline is computed.
        psnr_bic = util.get_psnr(mse_bic)
        psnr = util.get_psnr(mse)
        name = os.path.basename(label_filename)
        print("%s MSE:%f, PSNR_bil:%f, PSNR:%f\n" % (name, mse, psnr_bic, psnr))
	def print_status(self, mse):
		"""Print evaluation MSE/PSNR and, after the first step, training progress.

		At ``self.step == 0`` only the initial MSE/PSNR line is printed.
		Otherwise two lines are printed: one with the timestamp, step, MSE,
		PSNR and ``self.training_psnr``, and one with the epoch count,
		learning rate, ``self.loss_alpha`` and average seconds per step.

		Args:
			mse: mean squared error of the current evaluation.
		"""
		psnr = util.get_psnr(mse, max_value=self.max_value)

		if self.step == 0:
			print("Initial MSE:%f PSNR:%f" % (mse, psnr))
			return

		# Average wall-clock time per step since self.start_time.
		seconds_per_step = (time.time() - self.start_time) / self.step
		step_line = "%s Step:%d MSE:%f PSNR:%f (%f)" % (
			util.get_now_date(), self.step, mse, psnr, self.training_psnr)
		print(step_line)
		epoch_line = "Epoch:%d LR:%f α:%f (%2.2fsec/step)" % (
			self.epochs_completed, self.lr, self.loss_alpha, seconds_per_step)
		print(epoch_line)
예제 #5
0
    def print_status(self, step, mse):
        """Print evaluation MSE/PSNR and training progress for ``step``.

        Args:
            step: number of steps completed; used to average the elapsed
                time since ``self.start_time``. A value of 0 reports 0
                sec/step instead of raising ZeroDivisionError.
            mse: mean squared error of the current evaluation.
        """
        # Guard step == 0; the sibling variant of this method also treats
        # step 0 specially instead of dividing by it.
        elapsed = time.time() - self.start_time
        processing_time = elapsed / step if step else 0.0

        # self.training_psnr is divided by self.trained, so it is presumably
        # a running sum over self.trained evaluations (verify); report 0
        # when nothing has been accumulated.
        if self.trained > 0:
            training_psnr = self.training_psnr / self.trained
        else:
            training_psnr = 0

        print("%s Step:%d MSE:%f PSNR:%f (%f)" %
              (util.get_now_date(), step, mse,
               util.get_psnr(mse, max_value=self.max_value), training_psnr))
        print(
            "Epoch:%d LR:%f α:%f (%2.2fsec/step)" %
            (self.epochs_completed, self.lr, self.loss_alpha, processing_time))
예제 #6
0
파일: test.py 프로젝트: liuzhiit/DRCN_tf
def main(_):
    """Super-resolve the single image ``FLAGS.file`` with a restored DRCN model.

    Builds the embedding / inference / reconstruction graphs and optimizer,
    restores saved variables (``FLAGS.load_model`` is forced on), then calls
    ``model.do_super_resolution`` with ``FLAGS.file`` and ``FLAGS.output_dir``.
    """
    print("Super Resolution (tensorflow version:%s)" % tf.__version__)
    print("%s\n" % util.get_now_date())

    # ``==`` instead of ``is``: identity with a str literal is
    # implementation-dependent (SyntaxWarning on Python 3.8+).
    if FLAGS.model_name == "":
        model_name = "model_F%d_D%d_LR%f" % (FLAGS.feature_num, FLAGS.inference_depth, FLAGS.initial_lr)
    else:
        model_name = "model_%s" % FLAGS.model_name
    model = sr.SuperResolution(FLAGS, model_name=model_name)

    # Inference only: always restore saved variables (read just below).
    FLAGS.load_model = True

    model.build_embedding_graph()
    model.build_inference_graph()
    model.build_reconstruction_graph()
    model.build_optimizer()
    model.init_all_variables(load_initial_data=FLAGS.load_model)

    model.do_super_resolution(FLAGS.file, FLAGS.output_dir)