예제 #1
0
def main(not_parsed_args):
    """Evaluate saved DCSCN checkpoints on the configured test set(s).

    For each of the FLAGS.tests trials, the checkpoint for that trial is
    restored. Bicubic baselines are then optionally evaluated, and the model
    is evaluated on every requested dataset. Finally, the current model is
    saved under the name "inference_model".

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    sr_model.build_graph()
    sr_model.build_summary_saver()
    sr_model.init_all_variables()

    # "all" expands to the three standard benchmark sets.
    datasets = (['set5', 'set14', 'bsd100'] if FLAGS.test_dataset == "all"
                else [FLAGS.test_dataset])

    # Each trial index selects which trained model to restore.
    for trial in range(FLAGS.tests):
        sr_model.load_model(FLAGS.load_model_name,
                            trial=trial,
                            output_log=FLAGS.tests > 1)

        if FLAGS.compute_bicubic:
            for name in datasets:
                evaluate_bicubic(sr_model, name)

        for name in datasets:
            evaluate_model(sr_model, name)

    sr_model.save_model(name="inference_model", output_log=True)
예제 #2
0
def main(not_parsed_args):
    """Evaluate DCSCN on the test set(s), from checkpoints or a frozen graph.

    When FLAGS.frozenInference is set, the graph is loaded from
    FLAGS.frozen_graph_path and no checkpoint is restored. Otherwise the
    graph is built and the checkpoint for every trial is restored before
    evaluation. Bicubic baselines are optionally evaluated as well; each
    dataset name is printed before its bicubic evaluation.

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    if not FLAGS.frozenInference:
        sr_model.build_graph()
        sr_model.build_summary_saver()
    else:
        sr_model.load_graph(FLAGS.frozen_graph_path)
        # No variables are saved in frozen mode, so the saver is not needed.
        sr_model.build_summary_saver(with_saver=False)
    sr_model.init_all_variables()

    if FLAGS.test_dataset == "all":
        datasets = ['set5', 'set14', 'bsd100']
    else:
        datasets = [FLAGS.test_dataset]

    for trial in range(FLAGS.tests):
        # Checkpoints are only restored when not running a frozen graph.
        if not FLAGS.frozenInference:
            sr_model.load_model(FLAGS.load_model_name,
                                trial=trial,
                                output_log=FLAGS.tests > 1)

        if FLAGS.compute_bicubic:
            for name in datasets:
                print(name)
                evaluate_bicubic(sr_model, name)

        for name in datasets:
            evaluate_model(sr_model, name)
예제 #3
0
파일: evaluate.py 프로젝트: doctorwgd/DCDRN
def main(not_parsed_args):
    """Evaluate a trained model on the configured test set(s).

    Before the model is built, the function makes sure a model name is set
    (falling back to "default"). It also switches off loss, weight and image
    saving. For each of the FLAGS.tests trials, the corresponding model is
    restored and ``test`` is run on every requested dataset.

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    # Adjust build/process options so evaluation runs faster.
    if FLAGS.load_model_name == "":
        FLAGS.load_model_name = "default"
    FLAGS.save_loss = FLAGS.save_weights = FLAGS.save_images = False

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    sr_model.build_graph()
    sr_model.build_summary_saver()
    sr_model.init_all_variables()

    logging.info("evaluate model performance")

    datasets = (['set5', 'set14', 'bsd100'] if FLAGS.test_dataset == "all"
                else [FLAGS.test_dataset])

    for trial in range(FLAGS.tests):
        # Positional arguments: model name, trial index, output-log flag.
        sr_model.load_model(FLAGS.load_model_name, trial, FLAGS.tests > 1)
        for name in datasets:
            test(sr_model, name)
예제 #4
0
def main(_):
    """Run the restored DCSCN model on FLAGS.file, writing to FLAGS.output_dir.

    Args:
        _: unused (remaining command-line arguments).
    """
    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    sr_model.build_graph()
    sr_model.init_all_variables()
    # Restore trained weights only after the variables have been initialized.
    sr_model.load_model()
    sr_model.do_for_file(FLAGS.file, FLAGS.output_dir)
예제 #5
0
def load_model(flags, model_path):
    """Build a DCSCN model, initialize it and restore weights from ``model_path``.

    Args:
        flags: configuration object passed to ``DCSCN.SuperResolution``
            (presumably the parsed FLAGS; confirm with callers).
        model_path: used both as the model name and as the name passed to
            ``model.load_model``.

    Returns:
        The ``DCSCN.SuperResolution`` instance with graph, optimizer and
        summary saver built and weights loaded.
    """
    # Honour the caller-supplied ``flags``. Previously the module-level FLAGS
    # global was used here and this parameter was silently ignored.
    model = DCSCN.SuperResolution(flags, model_name=model_path)
    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()

    model.init_all_variables()
    model.load_model(name=model_path)
    return model
예제 #6
0
def main(_):
    """Restore the trained DCSCN model and export it with ``export_model``.

    Args:
        _: unused (remaining command-line arguments).
    """
    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    # Build the full training graph so the checkpoint variables line up.
    sr_model.build_graph()
    sr_model.build_optimizer()
    sr_model.build_summary_saver()

    sr_model.init_all_variables()
    sr_model.load_model()
    sr_model.export_model()
예제 #7
0
def main(not_parsed_args):
    """Train DCSCN FLAGS.tests times and evaluate each trial on the test set.

    Each trial trains the model with ``train``. It then evaluates every file
    in ``FLAGS.data_dir/FLAGS.test_dataset`` and logs the trial's mean
    MSE/PSNR. Output images are written only for the last trial. After all
    trials, the overall average is logged and the log is copied to "archive".

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.

    Raises:
        ValueError: if the test dataset directory yields no files. Without
            this check the per-trial average would divide by zero, and only
            after a full training run.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    model.train = model.load_dynamic_datasets(
        FLAGS.data_dir + "/" + FLAGS.dataset, FLAGS.batch_image_size,
        FLAGS.stride_size)
    model.test = model.load_datasets(
        FLAGS.data_dir + "/" + FLAGS.test_dataset,
        FLAGS.batch_dir + "/" + FLAGS.test_dataset, FLAGS.batch_image_size,
        FLAGS.stride_size)

    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()
    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" +
                 FLAGS.dataset)

    final_mse = final_psnr = 0
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                 FLAGS.test_dataset)
    if not test_filenames:
        raise ValueError("No test files found in %s/%s" %
                         (FLAGS.data_dir, FLAGS.test_dataset))

    for i in range(FLAGS.tests):

        train(model, FLAGS, i)

        total_psnr = total_mse = 0
        for filename in test_filenames:
            # Compare with ==, not `is`. Int identity only holds for CPython's
            # cached small ints, so `is` silently fails for larger counts.
            mse = model.do_for_evaluate(filename,
                                        FLAGS.output_dir,
                                        output=(i == FLAGS.tests - 1),
                                        print_console=False)
            total_mse += mse
            total_psnr += util.get_psnr(mse, max_value=FLAGS.max_value)

        logging.info("\nTrial(%d) %s" % (i, util.get_now_date()))
        model.print_steps_completed(output_to_logging=True)
        logging.info("MSE:%f, PSNR:%f\n" % (total_mse / len(test_filenames),
                                            total_psnr / len(test_filenames)))

        final_mse += total_mse
        final_psnr += total_psnr

    logging.info("=== summary [%d] %s [%s] ===" %
                 (FLAGS.tests, model.name, util.get_now_date()))
    util.print_num_of_total_parameters(output_to_logging=True)
    n = len(test_filenames) * FLAGS.tests
    # With zero trials there is nothing to average.
    if n > 0:
        logging.info("\n=== Final Average [%s] MSE:%f, PSNR:%f ===" %
                     (FLAGS.test_dataset, final_mse / n, final_psnr / n))

    model.copy_log_to_archive("archive")
예제 #8
0
def main(_):
    """Run the restored DCSCN model on every ``.png`` file in FLAGS.dir.

    The directory is scanned non-recursively, and files are processed in
    sorted name order. Each matching file is passed to ``model.do_for_file``
    together with FLAGS.output_dir.

    Args:
        _: unused (remaining command-line arguments).

    Raises:
        NotADirectoryError: if FLAGS.dir is not an existing directory.
    """
    # Validate before the expensive model build. This is an explicit check
    # rather than `assert`, which is stripped when Python runs with -O.
    if not os.path.isdir(FLAGS.dir):
        raise NotADirectoryError("Not a directory: %s" % FLAGS.dir)

    model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()
    model.init_all_variables()
    model.load_model()

    # os.listdir order is arbitrary; sort for a reproducible processing order.
    for file_name in sorted(os.listdir(FLAGS.dir)):
        if file_name.endswith(".png"):
            model.do_for_file(os.path.join(FLAGS.dir, file_name),
                              FLAGS.output_dir)
예제 #9
0
def main(not_parsed_args):
    """Train the MISR variant of DCSCN FLAGS.tests times and log PSNR/SSIM.

    Training data is loaded dynamically with ``load_dynamic_datasets_misr``;
    pre-built batches are not supported. Each trial's PSNR/SSIM is logged.
    When more than one trial is run, the averages are also logged. Finally,
    the log is copied to "archive".

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.

    Raises:
        NotImplementedError: if FLAGS.build_batch is set.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    model = DCSCN.SuperResolution(FLAGS,
                                  model_name=FLAGS.model_name,
                                  is_module_training=True)

    if FLAGS.build_batch:
        logging.error("'build_batch' not implemented for MISR")
        raise NotImplementedError("'build_batch' not implemented for MISR")

    model.load_dynamic_datasets_misr(
        data_dir=FLAGS.data_dir,
        batch_image_size=FLAGS.batch_image_size,
        dataset_name=FLAGS.dataset)

    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()

    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" +
                 FLAGS.dataset)
    util.print_num_of_total_parameters(output_to_logging=True)

    total_psnr = total_ssim = 0

    for i in range(FLAGS.tests):
        # train_misr returns a third value (cpsnr) that is not aggregated here.
        psnr, ssim, _cpsnr = train_misr(model, FLAGS, i)
        total_psnr += psnr
        total_ssim += ssim

        logging.info("\nTrial(%d) %s" % (i, util.get_now_date()))
        model.print_steps_completed(output_to_logging=True)
        logging.info("PSNR:%f, SSIM:%f\n" % (psnr, ssim))

    if FLAGS.tests > 1:
        logging.info("\n=== Final Average [%s] PSNR:%f, SSIM:%f ===" %
                     (FLAGS.test_dataset, total_psnr / FLAGS.tests,
                      total_ssim / FLAGS.tests))

    model.copy_log_to_archive("archive")
예제 #10
0
def main(not_parsed_args):
    """Evaluate saved checkpoints on the configured test set(s) with ``test``.

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    sr_model.build_graph()
    sr_model.build_summary_saver()
    sr_model.init_all_variables()

    if FLAGS.test_dataset == "all":
        datasets = ['set5', 'set14', 'bsd100']
    else:
        datasets = [FLAGS.test_dataset]

    for trial in range(FLAGS.tests):
        sr_model.load_model(FLAGS.load_model_name,
                            trial=trial,
                            output_log=FLAGS.tests > 1)
        for name in datasets:
            test(sr_model, name)
예제 #11
0
def main(not_parsed_args):
    """Train DCSCN FLAGS.tests times, logging per-trial and average MSE/PSNR.

    Training data is either split into batches in advance (FLAGS.build_batch)
    or loaded dynamically. ``train`` returns each trial's MSE, which is
    converted to PSNR with ``util.get_psnr``. Averages are logged only when
    more than one trial is run. Finally, the log is copied to "archive".

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    if not FLAGS.build_batch:
        sr_model.load_dynamic_datasets(FLAGS.data_dir + "/" + FLAGS.dataset,
                                       FLAGS.batch_image_size)
    else:
        sr_model.load_datasets(FLAGS.data_dir + "/" + FLAGS.dataset,
                               FLAGS.batch_dir + "/" + FLAGS.dataset,
                               FLAGS.batch_image_size, FLAGS.stride_size)
    sr_model.build_graph()
    sr_model.build_optimizer()
    sr_model.build_summary_saver()

    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" +
                 FLAGS.dataset)
    util.print_num_of_total_parameters(output_to_logging=True)

    mse_sum = psnr_sum = 0

    for trial in range(FLAGS.tests):
        trial_mse = train(sr_model, FLAGS, trial)
        trial_psnr = util.get_psnr(trial_mse, max_value=FLAGS.max_value)
        mse_sum += trial_mse
        psnr_sum += trial_psnr

        logging.info("\nTrial(%d) %s" % (trial, util.get_now_date()))
        sr_model.print_steps_completed(output_to_logging=True)
        logging.info("MSE:%f, PSNR:%f\n" % (trial_mse, trial_psnr))

    if FLAGS.tests > 1:
        mean_mse = mse_sum / FLAGS.tests
        mean_psnr = psnr_sum / FLAGS.tests
        logging.info("\n=== Final Average [%s] MSE:%f, PSNR:%f ===" %
                     (FLAGS.test_dataset, mean_mse, mean_psnr))

    sr_model.copy_log_to_archive("archive")
def main(not_parsed_args):
    """Train DCSCN FLAGS.tests times and evaluate PSNR/SSIM on the test set.

    Both the training and the test datasets are loaded with
    ``load_datasets``. Each trial is trained with ``train`` and then
    evaluated on every file in ``FLAGS.data_dir/FLAGS.test_dataset``, and the
    trial's mean PSNR/SSIM is logged. Output images are written only for the
    last trial. The overall average is logged at the end.

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.

    Raises:
        ValueError: if the test dataset directory yields no files. Without
            this check the per-trial average would divide by zero, and only
            after a full training run.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    model.load_datasets("training", FLAGS.data_dir + "/" + FLAGS.dataset, FLAGS.batch_dir + "/" + FLAGS.dataset,
                        FLAGS.batch_image_size, FLAGS.stride_size)
    model.load_datasets("test", FLAGS.data_dir + "/" + FLAGS.test_dataset, FLAGS.batch_dir + "/" + FLAGS.test_dataset,
                        FLAGS.batch_image_size, FLAGS.stride_size)

    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()
    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" + FLAGS.dataset)

    final_psnr = final_ssim = 0
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + FLAGS.test_dataset)
    if not test_filenames:
        raise ValueError("No test files found in %s/%s" % (FLAGS.data_dir, FLAGS.test_dataset))

    for i in range(FLAGS.tests):

        train(model, FLAGS, i)

        total_psnr = total_ssim = 0
        for filename in test_filenames:
            # Compare with ==, not `is`. Int identity only holds for CPython's
            # cached small ints, so `is` silently fails for larger counts.
            psnr, ssim = model.do_for_evaluate(filename, FLAGS.output_dir, output=(i == FLAGS.tests - 1))
            total_psnr += psnr
            total_ssim += ssim

        logging.info("\nTrial(%d) %s" % (i, util.get_now_date()))
        model.print_steps_completed(output_to_logging=True)
        logging.info("PSNR:%f, SSIM:%f\n" % (total_psnr / len(test_filenames), total_ssim / len(test_filenames)))

        final_psnr += total_psnr
        final_ssim += total_ssim

    logging.info("=== summary [%d] %s [%s] ===" % (FLAGS.tests, model.name, util.get_now_date()))
    util.print_num_of_total_parameters(output_to_logging=True)
    n = len(test_filenames) * FLAGS.tests
    # With zero trials there is nothing to average.
    if n > 0:
        logging.info("\n=== Average [%s] PSNR:%f, SSIM:%f ===" % (FLAGS.test_dataset, final_psnr / n, final_ssim / n))
예제 #13
0
def main(not_parsed_args):
    """Train DCSCN FLAGS.tests times, logging per-trial and average PSNR/SSIM.

    Training images are either split into batches in advance
    (FLAGS.build_batch) or loaded dynamically. ``train`` returns each trial's
    (PSNR, SSIM). Averages are logged only when more than one trial is run.
    Finally, the log is copied to "archive".

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    if not FLAGS.build_batch:
        sr_model.load_dynamic_datasets(FLAGS.data_dir + "/" + FLAGS.dataset,
                                       FLAGS.batch_image_size)
    else:
        # Pre-split the training images into batches on disk.
        sr_model.load_datasets(FLAGS.data_dir + "/" + FLAGS.dataset,
                               FLAGS.batch_dir + "/" + FLAGS.dataset,
                               FLAGS.batch_image_size, FLAGS.stride_size)
    sr_model.build_graph()
    sr_model.build_optimizer()
    sr_model.build_summary_saver()

    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" +
                 FLAGS.dataset)
    util.print_num_of_total_parameters(output_to_logging=True)

    psnr_sum = ssim_sum = 0

    for trial in range(FLAGS.tests):
        trial_psnr, trial_ssim = train(sr_model, FLAGS, trial)
        psnr_sum += trial_psnr
        ssim_sum += trial_ssim

        logging.info("\nTrial(%d) %s" % (trial, util.get_now_date()))
        sr_model.print_steps_completed(output_to_logging=True)
        logging.info("PSNR:%f, SSIM:%f\n" % (trial_psnr, trial_ssim))

    if FLAGS.tests > 1:
        mean_psnr = psnr_sum / FLAGS.tests
        mean_ssim = ssim_sum / FLAGS.tests
        logging.info("\n=== Final Average [%s] PSNR:%f, SSIM:%f ===" %
                     (FLAGS.test_dataset, mean_psnr, mean_ssim))

    sr_model.copy_log_to_archive("archive")
예제 #14
0
def main(_):
    """Super-resolve the first 20 seconds of the video at FLAGS.file.

    The function builds the DCSCN graph and restores the trained weights. It
    then applies ``model.doframe`` to every frame of the clip and writes the
    result next to the input as ``hr<file name>``. If the clip has an audio
    track, the track is first written to ``audio<file stem>.mp3`` in the same
    directory and muxed into the output.

    Args:
        _: unused (remaining command-line arguments).
    """
    import os

    model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()

    model.init_all_variables()
    model.load_model()

    video = FLAGS.file
    # Derive output names from the file name only. An input such as
    # "clips/in.mp4" then yields "clips/hrin.mp4" instead of the invalid
    # "hrclips/in.mp4". splitext also handles extensions that are not
    # exactly three characters long.
    video_dir, video_name = os.path.split(video)
    video_stem = os.path.splitext(video_name)[0]
    audio_path = os.path.join(video_dir, "audio" + video_stem + ".mp3")
    hroutput = os.path.join(video_dir, "hr" + video_name)

    # Only the first 20 seconds of the input are processed.
    lrclip = VideoFileClip(video).subclip(0, 20)
    audioclip = lrclip.audio
    if audioclip is not None:
        audioclip.write_audiofile(audio_path)
        audio = audio_path
    else:
        # The clip has no audio track: write a silent video instead of
        # crashing on None.write_audiofile.
        audio = False
    hrclip = lrclip.fl_image(model.doframe)
    # When running in a Jupyter notebook, the clips can be previewed with:
    #     hrclip.ipython_display()
    #     lrclip.ipython_display()
    hrclip.write_videofile(hroutput, audio=audio, threads=8, progress_bar=False)
예제 #15
0
def main(not_parsed_args):
    """Evaluate a frozen DCSCN graph on the configured test set(s).

    The graph is loaded with ``load_graph()``, so no checkpoint is restored
    and no saver is built. The evaluation loop runs FLAGS.tests times.
    Nothing in the loop body uses the iteration index, so every pass repeats
    the same bicubic (optional) and model evaluations.

    Args:
        not_parsed_args: leftover command-line arguments. Anything beyond the
            first entry is reported and the program exits.
    """
    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        exit()

    sr_model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)
    sr_model.load_graph()
    # No variables are saved from a frozen graph, so the saver is not needed.
    sr_model.build_summary_saver(with_saver=False)
    sr_model.init_all_variables()

    datasets = (['set5', 'set14', 'bsd100'] if FLAGS.test_dataset == "all"
                else [FLAGS.test_dataset])

    for _ in range(FLAGS.tests):
        if FLAGS.compute_bicubic:
            for name in datasets:
                evaluate_bicubic(sr_model, name)

        for name in datasets:
            evaluate_model(sr_model, name)