Пример #1
0
def main(not_parsed_args,
         scales=(3, 4),
         hr_dir="/media/data1/ww/sr_data/DIV2K_aug2/DIV2K_train_HR",
         lr_root="/media/data1/ww/sr_data/DIV2K_aug2"):
    """Write bicubic-downscaled copies of every HR training image.

    For each scale ``s`` in *scales*, every image found in *hr_dir* is resized
    by ``1 / s`` with ``util.resize_image_by_pil`` and saved as
    ``<lr_root>/DIV2K_train_LR_bicubic_X<s>/DIV2K_train_LR_bicubic/X<s>/<name>x<s><ext>``.

    Args:
        not_parsed_args: Leftover command-line arguments; anything beyond the
            program name is rejected.
        scales: Downscaling factors to generate. The default (x3 and x4)
            reproduces the previously hard-coded behaviour; include 2 to also
            build the x2 set.
        hr_dir: Directory holding the high-resolution source images.
        lr_root: Root under which the per-scale output directories are created.
    """
    import sys

    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        # exit() is a site-module helper and reports success; signal failure.
        sys.exit(1)

    training_filenames = util.get_files_in_directory(hr_dir)

    # Create every output directory up front, before any image is written.
    target_dirs = []
    for scale in scales:
        target_dir = os.path.join(lr_root,
                                  "DIV2K_train_LR_bicubic_X%d" % scale,
                                  "DIV2K_train_LR_bicubic", "X%d" % scale)
        util.make_dir(target_dir)
        target_dirs.append((scale, target_dir))

    for file_path in training_filenames:
        org_image = util.load_image(file_path)
        filename, extension = os.path.splitext(os.path.basename(file_path))
        for scale, target_dir in target_dirs:
            bicubic_image = util.resize_image_by_pil(org_image, 1 / scale)
            new_filename = os.path.join(
                target_dir, "%sx%d%s" % (filename, scale, extension))
            util.save_image(new_filename, bicubic_image)
Пример #2
0
def test(model, test_data):
    """Evaluate *model* on every image of one test set and record the metrics.

    Per-image PSNR/SSIM/time lines and a final average block are appended (via
    ``create_str_to_txt``) to ``<FLAGS.output_dir>/<test_data>.txt``; the
    averages, plus the mean MSE, are also written to the log.

    Args:
        model: Object exposing ``do_for_evaluate_with_output(filename,
            output_directory=..., print_console=...)`` that returns
            ``(mse, psnr, ssim, elapsed_time)``.
        test_data: Test-set name, appended directly to ``FLAGS.test_dir``
            (no separator is inserted).
    """
    test_filenames = util.get_files_in_directory(FLAGS.test_dir + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return
    count = len(test_filenames)
    total_psnr = total_ssim = total_mse = total_time = 0

    path_file = FLAGS.output_dir + "/" + test_data + ".txt"
    # The data-set header belongs once at the top of this section,
    # not before every per-image line.
    create_str_to_txt(path_file, test_data + ":" + "\n")

    for filename in test_filenames:
        mse, psnr_predicted, ssim_predicted, spend_time = model.do_for_evaluate_with_output(
            filename, output_directory=FLAGS.output_dir, print_console=False)
        total_mse += mse
        total_psnr += psnr_predicted
        total_ssim += ssim_predicted
        total_time += spend_time
        create_str_to_txt(
            path_file, "PSNR:" + str(psnr_predicted) + " , SSIM:" +
            str(ssim_predicted) + " , Time:" + str(spend_time) + "\n")

    avg_mse = total_mse / count
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    avg_time = total_time / count
    logging.info("\n=== [%s] MSE:%f, PSNR:%f , SSIM:%f , Time:%f===" %
                 (test_data, avg_mse, avg_psnr, avg_ssim, avg_time))
    create_str_to_txt(
        path_file, "================================================\n" +
        "PSNR:" + str(avg_psnr) + " , SSIM:" + str(avg_ssim) + " , Time:" +
        str(avg_time) +
        "\n================================================\n")
Пример #3
0
    def build_batch(self, data_dir):
        """Split every image in *data_dir* into patches and save them to disk.

        For each source image, ``build_image_set`` yields the low-res input, its
        interpolated version and the true image. The input is cut into
        ``self.batch_image_size`` windows with ``self.stride``; the other two use
        the same geometry multiplied by ``self.scale``. Patches are saved under
        ``self.batch_dir`` and images too small to yield a window are skipped.
        The patch count is stored in ``self.count`` and, together with the patch
        settings, written to ``<batch_dir>/batch_images.ini``.

        Args:
            data_dir: Directory containing the source images.
        """
        print("Building batch images for %s..." % self.batch_dir)
        filenames = util.get_files_in_directory(data_dir)
        images_count = 0

        util.make_dir(self.batch_dir)
        util.clean_dir(self.batch_dir)
        for sub_dir in (INPUT_IMAGE_DIR, INTERPOLATED_IMAGE_DIR, TRUE_IMAGE_DIR):
            util.make_dir(self.batch_dir + "/" + sub_dir)

        # Window geometry in output (scaled) pixels; identical for every image.
        output_window_size = self.batch_image_size * self.scale
        output_window_stride = self.stride * self.scale

        processed_images = 0
        for filename in filenames:
            input_image, input_interpolated_image, true_image = \
                build_image_set(filename, channels=self.channels, resampling_method=self.resampling_method,
                                scale=self.scale, print_console=False)

            # split into batch images
            input_batch_images = util.get_split_images(input_image, self.batch_image_size, stride=self.stride)
            input_interpolated_batch_images = util.get_split_images(input_interpolated_image, output_window_size,
                                                                    stride=output_window_stride)

            if input_batch_images is None or input_interpolated_batch_images is None:
                # the image is smaller than one batch window; nothing to save
                continue
            input_count = input_batch_images.shape[0]

            true_batch_images = util.get_split_images(true_image, output_window_size, stride=output_window_stride)

            for i in range(input_count):
                self.save_input_batch_image(images_count, input_batch_images[i])
                self.save_interpolated_batch_image(images_count, input_interpolated_batch_images[i])
                self.save_true_batch_image(images_count, true_batch_images[i])
                images_count += 1
            processed_images += 1
            if processed_images % 10 == 0:
                # progress: one dot per 10 images that produced patches
                print('.', end='', flush=True)

        print("Finished")
        self.count = images_count

        print("%d mini-batch images are built(saved)." % images_count)

        config = configparser.ConfigParser()
        config.add_section("batch")
        for key, value in (("count", images_count),
                           ("scale", self.scale),
                           ("batch_image_size", self.batch_image_size),
                           ("stride", self.stride),
                           ("channels", self.channels)):
            config.set("batch", key, str(value))

        with open(self.batch_dir + "/batch_images.ini", "w") as configfile:
            config.write(configfile)
Пример #4
0
def train(model, flags, trial):
    """Train *model* until its learning rate falls to ``flags.end_lr``.

    After every epoch the model is evaluated on ``flags.test_dataset`` and
    saved whenever the test MSE improves. The final model is saved
    unconditionally and evaluated with ``test``, plus the standard benchmark
    sets when ``flags.do_benchmark`` is set.

    Args:
        model: The super-resolution model being trained.
        flags: Configuration object (data_dir, test_dataset, load_model_name,
            end_lr, do_benchmark, ...).
        trial: Trial index forwarded to ``model.save_model``.

    Returns:
        ``(mse, psnr)`` of the most recent evaluation on the test set.
    """
    import sys

    test_dir = flags.data_dir + "/" + flags.test_dataset
    test_filenames = util.get_files_in_directory(test_dir)
    if not test_filenames:
        print("Can't load images from [%s]" % test_dir)
        # exit() would report success; this is a failure.
        sys.exit(1)

    model.init_all_variables()
    if flags.load_model_name != "":
        model.load_model(flags.load_model_name, output_log=True)

    model.init_train_step()
    model.init_epoch_index()
    model_updated = True
    min_mse = None

    mse, psnr = model.evaluate(test_filenames)
    model.print_status(mse, psnr, log=True)
    model.log_to_tensorboard(test_filenames[0], psnr, save_meta_data=True)

    while model.lr > flags.end_lr:
        model.build_input_batch()
        model.train_batch()

        if model.training_step * model.batch_num >= model.training_images:
            # one training epoch finished
            model.epochs_completed += 1
            mse, psnr = model.evaluate(test_filenames)
            model.print_status(mse, psnr, log=model_updated)
            model.log_to_tensorboard(test_filenames[0],
                                     psnr,
                                     save_meta_data=model_updated)

            # save if performance gets better
            if min_mse is None or mse < min_mse:
                min_mse = mse
                model.save_model(trial=trial, output_log=False)

            model_updated = model.update_epoch_and_lr()
            model.init_epoch_index()

    model.end_train_step()

    # save last generation anyway
    model.save_model(trial=trial, output_log=True)

    # outputs result
    test(model, flags.test_dataset)

    # Honour the flags object passed in rather than the module-level FLAGS.
    if flags.do_benchmark:
        for test_data in ['set5', 'set14', 'bsd100']:
            if test_data != flags.test_dataset:
                test(model, test_data)

    return mse, psnr
Пример #5
0
def main(not_parsed_args):
    """Train and evaluate a DCSCN model ``FLAGS.tests`` times and log averages.

    Each trial calls ``train`` and then evaluates every test image. Output
    images are written only on the last trial. Per-trial and overall average
    MSE/PSNR are logged, and the log is archived at the end.

    Args:
        not_parsed_args: Leftover command-line arguments; anything beyond the
            program name is rejected.
    """
    import sys

    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        sys.exit(1)

    model = DCSCN.SuperResolution(FLAGS, model_name=FLAGS.model_name)

    model.train = model.load_dynamic_datasets(
        FLAGS.data_dir + "/" + FLAGS.dataset, FLAGS.batch_image_size,
        FLAGS.stride_size)
    model.test = model.load_datasets(
        FLAGS.data_dir + "/" + FLAGS.test_dataset,
        FLAGS.batch_dir + "/" + FLAGS.test_dataset, FLAGS.batch_image_size,
        FLAGS.stride_size)

    model.build_graph()
    model.build_optimizer()
    model.build_summary_saver()
    logging.info("\n" + str(sys.argv))
    logging.info("Test Data:" + FLAGS.test_dataset + " Training Data:" +
                 FLAGS.dataset)

    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                 FLAGS.test_dataset)
    if not test_filenames:
        # Every average below divides by the number of test images.
        logging.error("No test images found in [%s]." %
                      (FLAGS.data_dir + "/" + FLAGS.test_dataset))
        sys.exit(1)
    count = len(test_filenames)

    final_mse = final_psnr = 0
    last_trial = FLAGS.tests - 1
    for i in range(FLAGS.tests):

        train(model, FLAGS, i)

        total_psnr = total_mse = 0
        for filename in test_filenames:
            # `==`, not `is`: identity of ints only holds for small cached values.
            mse = model.do_for_evaluate(filename,
                                        FLAGS.output_dir,
                                        output=(i == last_trial),
                                        print_console=False)
            total_mse += mse
            total_psnr += util.get_psnr(mse, max_value=FLAGS.max_value)

        logging.info("\nTrial(%d) %s" % (i, util.get_now_date()))
        model.print_steps_completed(output_to_logging=True)
        logging.info("MSE:%f, PSNR:%f\n" % (total_mse / count,
                                            total_psnr / count))

        final_mse += total_mse
        final_psnr += total_psnr

    logging.info("=== summary [%d] %s [%s] ===" %
                 (FLAGS.tests, model.name, util.get_now_date()))
    util.print_num_of_total_parameters(output_to_logging=True)
    n = count * FLAGS.tests
    if n > 0:
        # FLAGS.tests may be 0, in which case there is nothing to average.
        logging.info("\n=== Final Average [%s] MSE:%f, PSNR:%f ===" %
                     (FLAGS.test_dataset, final_mse / n, final_psnr / n))

    model.copy_log_to_archive("archive")
Пример #6
0
def test(model, test_data):
    """Evaluate *model* on one test set and log the average MSE and PSNR.

    Args:
        model: Object exposing ``do_for_evaluate_with_output(filename,
            output_directory=..., print_console=...)`` that returns the MSE.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_mse = 0
    for filename in test_filenames:
        mse = model.do_for_evaluate_with_output(filename, output_directory=FLAGS.output_dir, print_console=False)
        total_mse += mse
        # PSNR is averaged per image rather than derived from the mean MSE.
        total_psnr += util.get_psnr(mse, max_value=FLAGS.max_value)

    count = len(test_filenames)
    logging.info("\n=== [%s] MSE:%f, PSNR:%f ===" % (
        test_data, total_mse / count, total_psnr / count))
Пример #7
0
def evaluate_model(model, test_data):
    """Evaluate *model* on one test set and log the average PSNR and SSIM.

    Args:
        model: Object exposing ``do_for_evaluate_with_output(filename,
            output_directory=..., print_console=...)`` returning ``(psnr, ssim)``.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_ssim = 0
    for filename in test_filenames:
        psnr, ssim = model.do_for_evaluate_with_output(filename, output_directory=FLAGS.output_dir, print_console=False)
        total_psnr += psnr
        total_ssim += ssim

    count = len(test_filenames)
    logging.info("Model Average [%s] PSNR:%f, SSIM:%f" % (
        test_data, total_psnr / count, total_ssim / count))
Пример #8
0
def evaluate_bicubic(model, test_data):
    """Log the average PSNR/SSIM of the bicubic baseline on one test set.

    Args:
        model: Object exposing ``evaluate_bicubic(filename, print_console=...)``
            that returns ``(psnr, ssim)``.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_ssim = 0
    for filename in test_filenames:
        psnr, ssim = model.evaluate_bicubic(filename, print_console=False)
        total_psnr += psnr
        total_ssim += ssim

    count = len(test_filenames)
    logging.info("Bicubic Average [%s] PSNR:%f, SSIM:%f" % (
        test_data, total_psnr / count, total_ssim / count))
Пример #9
0
def test(model, test_data):
    """Evaluate *model* on one test set and log the average MSE, PSNR and SSIM.

    Args:
        model: Object exposing ``do_for_evaluate_with_output(filename,
            output_directory=..., print_console=...)`` that returns
            ``(mse, psnr, ssim)``.
        test_data: Test-set name, appended directly to ``FLAGS.test_dir``
            (no separator is inserted).
    """
    test_filenames = util.get_files_in_directory(FLAGS.test_dir + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_ssim = total_mse = 0
    for filename in test_filenames:
        mse, psnr_predicted, ssim_predicted = model.do_for_evaluate_with_output(
            filename, output_directory=FLAGS.output_dir, print_console=False)
        total_mse += mse
        total_psnr += psnr_predicted
        total_ssim += ssim_predicted

    count = len(test_filenames)
    logging.info("\n=== [%s] MSE:%f, PSNR:%f , SSIM:%f===" % (
        test_data, total_mse / count, total_psnr / count, total_ssim / count))
Пример #10
0
def evaluate_model(model, test_data):
    """Evaluate *model* on one test set; log average PSNR, SSIM and run time.

    Args:
        model: Object exposing ``do_for_evaluate(filename, output_directory=...,
            print_console=..., save_output_images=...)`` that returns
            ``(psnr, ssim, elapsed_time)``.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_ssim = total_time = 0
    for filename in test_filenames:
        psnr, ssim, elapsed_time = model.do_for_evaluate(filename, output_directory=FLAGS.output_dir,
                                                        print_console=False, save_output_images=FLAGS.save_results)
        total_psnr += psnr
        total_ssim += ssim
        total_time += elapsed_time

    count = len(test_filenames)
    logging.info("Model Average [%s] PSNR:%f, SSIM:%f, Elapsed Time:%f" % (
        test_data, total_psnr / count, total_ssim / count, total_time / count))
Пример #11
0
	def build_batch(self, data_dir, batch_dir):
		""" Load each input file, cut it into patches and save them under batch_dir.

		Patches are written to files (instead of being kept in memory) to reduce
		memory consumption. For each source image, the low-res input, its
		interpolated version and the true image are split into windows. Images
		too small to yield a window are skipped. The patch settings and the total
		patch count are written to <batch_dir>/batch_images.ini.
		"""
		print("Building batch images for %s..." % batch_dir)
		filenames = util.get_files_in_directory(data_dir)
		images_count = 0

		util.make_dir(batch_dir)
		util.clean_dir(batch_dir)
		for sub_dir in (INPUT_IMAGE_DIR, INTERPOLATED_IMAGE_DIR, TRUE_IMAGE_DIR):
			util.make_dir(batch_dir + "/" + sub_dir)

		# Window geometry in output (scaled) pixels; the same for every image.
		output_window_size = self.batch_image_size * self.scale
		output_window_stride = self.stride * self.scale

		for filename in filenames:
			input_image, input_interpolated_image = self.input.load_input_image(
				filename, rescale=True, resampling_method=self.resampling_method)
			true_image = self.true.load_test_image(filename)

			# split into batch images
			input_batch_images = util.get_split_images(input_image, self.batch_image_size, stride=self.stride)
			input_interpolated_batch_images = util.get_split_images(
				input_interpolated_image, output_window_size, stride=output_window_stride)
			if input_batch_images is None or input_interpolated_batch_images is None:
				# image smaller than one window: nothing to save
				continue
			input_count = input_batch_images.shape[0]

			true_batch_images = util.get_split_images(true_image, output_window_size, stride=output_window_stride)

			for i in range(input_count):
				save_input_batch_image(batch_dir, images_count, input_batch_images[i])
				save_interpolated_batch_image(batch_dir, images_count, input_interpolated_batch_images[i])
				save_true_batch_image(batch_dir, images_count, true_batch_images[i])
				images_count += 1

		print("%d mini-batch images are built(saved)." % images_count)

		config = configparser.ConfigParser()
		config.add_section("batch")
		config.set("batch", "count", str(images_count))
		config.set("batch", "scale", str(self.scale))
		config.set("batch", "batch_image_size", str(self.batch_image_size))
		config.set("batch", "stride", str(self.stride))
		config.set("batch", "channels", str(self.channels))
		config.set("batch", "jpeg_mode", str(self.jpeg_mode))
		config.set("batch", "max_value", str(self.max_value))

		with open(batch_dir + "/batch_images.ini", "w") as configfile:
			config.write(configfile)
Пример #12
0
def build_batch(data_dir, thread, threads):
    """Build and save the patches for one worker's share of *data_dir*.

    The file list is split into *threads* contiguous chunks, and chunk number
    *thread* (0-based) is processed. Each image is turned into input,
    interpolated and true patches via ``build_image_set`` and
    ``util.get_split_images``. Images too small for one window are skipped.
    Patch indices are offset by ``thread * 1000000`` so that workers sharing
    ``BatchDataSets.batch_dir`` write distinct files.

    Args:
        data_dir: Directory containing the source images.
        thread: Index of this worker, in ``range(threads)``.
        threads: Total number of workers.
    """
    filenames = util.get_files_in_directory(data_dir)
    # Integer arithmetic: float bounds such as int(n / threads * threads) can
    # round below n (e.g. 1 / 49 * 49 == 0.999...), silently dropping files.
    start = len(filenames) * thread // threads
    end = len(filenames) * (thread + 1) // threads
    index_offset = thread * 1000000  # keeps patch indices of different threads disjoint

    output_window_size = BatchDataSets.batch_image_size * BatchDataSets.scale
    output_window_stride = BatchDataSets.stride * BatchDataSets.scale

    images_count = 0
    processed_images = start  # progress counter continues from this chunk's position
    for filename in filenames[start:end]:
        input_image, input_interpolated_image, true_image = \
            build_image_set(filename, channels=BatchDataSets.channels, resampling_method=BatchDataSets.resampling_method,
                            scale=BatchDataSets.scale, print_console=False)

        # split into batch images
        input_batch_images = util.get_split_images(input_image, BatchDataSets.batch_image_size, stride=BatchDataSets.stride)
        input_interpolated_batch_images = util.get_split_images(input_interpolated_image, output_window_size,
                                                                stride=output_window_stride)

        if input_batch_images is None or input_interpolated_batch_images is None:
            # if the original image size * scale is less than batch image size
            continue
        input_count = input_batch_images.shape[0]

        true_batch_images = util.get_split_images(true_image, output_window_size, stride=output_window_stride)

        for i in range(input_count):
            patch_index = index_offset + images_count
            BatchDataSets.save_input_batch_image(patch_index, input_batch_images[i])
            BatchDataSets.save_interpolated_batch_image(patch_index, input_interpolated_batch_images[i])
            BatchDataSets.save_true_batch_image(patch_index, true_batch_images[i])
            images_count += 1
        processed_images += 1
        if processed_images % 10 == 0:
            print('.', end='', flush=True)

    print("Finished")
    # NOTE(review): the count and batch_images.ini reflect only this thread's
    # patches, and every thread writes the same ini file (last writer wins).
    # Confirm the caller aggregates counts across threads.
    BatchDataSets.count = images_count

    print("%d mini-batch images are built(saved)." % images_count)

    config = configparser.ConfigParser()
    config.add_section("batch")
    config.set("batch", "count", str(images_count))
    config.set("batch", "scale", str(BatchDataSets.scale))
    config.set("batch", "batch_image_size", str(BatchDataSets.batch_image_size))
    config.set("batch", "stride", str(BatchDataSets.stride))
    config.set("batch", "channels", str(BatchDataSets.channels))

    with open(BatchDataSets.batch_dir + "/batch_images.ini", "w") as configfile:
        config.write(configfile)
Пример #13
0
def test(model, test_data):
    """Evaluate *model* on one test set and log the average MSE and PSNR.

    Output images are written only when ``FLAGS.save_results`` is set.

    Args:
        model: Object exposing ``do_for_evaluate(filename, output_directory=...,
            output=...)`` that returns the MSE.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                 test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_mse = 0
    for filename in test_filenames:
        mse = model.do_for_evaluate(filename,
                                    output_directory=FLAGS.output_dir,
                                    output=FLAGS.save_results)
        total_mse += mse
        total_psnr += util.get_psnr(mse, max_value=FLAGS.max_value)

    count = len(test_filenames)
    logging.info("\n=== Average [%s] MSE:%f, PSNR:%f ===" %
                 (test_data, total_mse / count, total_psnr / count))
Пример #14
0
def main(not_parsed_args):
    """Write flipped/rotated copies of every training image (data augmentation).

    Images from ``<FLAGS.data_dir>/<FLAGS.dataset>/`` are saved into
    ``<FLAGS.data_dir>/<FLAGS.dataset>_<augment_level>/``. The unmodified image
    is always written. Each transform in the table below is added once
    ``FLAGS.augment_level`` reaches its level, for up to 8 variants per image.

    Args:
        not_parsed_args: Leftover command-line arguments; anything beyond the
            program name is rejected.
    """
    import sys

    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        sys.exit(1)

    print("Building x%d augmented data." % FLAGS.augment_level)

    training_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                     FLAGS.dataset + "/")
    target_dir = FLAGS.data_dir + "/" + FLAGS.dataset + ("_%d/" %
                                                         FLAGS.augment_level)
    util.make_dir(target_dir)

    # (minimum augment_level, filename suffix, transform), in output order.
    augmentations = (
        (2, "_v", np.flipud),
        (3, "_h", np.fliplr),
        (4, "_hv", lambda image: np.flipud(np.fliplr(image))),
        (5, "_r1", np.rot90),
        (6, "_r2", lambda image: np.rot90(image, -1)),
        (7, "_r1_v", lambda image: np.flipud(np.rot90(image))),
        (8, "_r2_v", lambda image: np.flipud(np.rot90(image, -1))),
    )

    for file_path in training_filenames:
        org_image = util.load_image(file_path)
        filename, extension = os.path.splitext(os.path.basename(file_path))
        new_filename = target_dir + filename

        util.save_image(new_filename + extension, org_image)
        for level, suffix, transform in augmentations:
            if FLAGS.augment_level >= level:
                util.save_image(new_filename + suffix + extension,
                                transform(org_image))
Пример #15
0
def train(model, flags, trial, load_model_name=""):
    """Train *model* until its learning rate falls to ``flags.end_lr``.

    The model is evaluated on ``flags.test_dataset`` after each epoch. It is
    saved once at the end and then evaluated with ``test``, plus the standard
    benchmark sets when ``flags.do_benchmark`` is set.

    Args:
        model: The super-resolution model being trained.
        flags: Configuration object (data_dir, test_dataset, end_lr,
            do_benchmark, ...).
        trial: Trial index forwarded to ``model.save_model``.
        load_model_name: Optional model to load before training ("" = none).

    Returns:
        The MSE of the last epoch evaluation (0 if no epoch completed).
    """
    import sys

    test_dir = flags.data_dir + "/" + flags.test_dataset
    test_filenames = util.get_files_in_directory(test_dir)
    if not test_filenames:
        # test_filenames[0] is required for TensorBoard logging below.
        print("Can't load images from [%s]" % test_dir)
        sys.exit(1)

    model.init_all_variables()
    if load_model_name != "":
        model.load_model(load_model_name, output_log=True)

    model.init_train_step()
    model.init_epoch_index()
    model_updated = True
    mse = 0

    while model.lr > flags.end_lr:

        model.build_input_batch()
        model.train_batch()

        if model.training_step * model.batch_num >= model.training_images:

            # training epoch finished
            model.epochs_completed += 1
            mse, psnr = model.evaluate(test_filenames)
            model.print_status(mse, psnr, log=model_updated)
            model.log_to_tensorboard(test_filenames[0],
                                     psnr,
                                     save_meta_data=model_updated)

            model_updated = model.update_epoch_and_lr()
            model.init_epoch_index()

    model.end_train_step()
    model.save_model(trial=trial, output_log=True)

    # outputs result
    test(model, flags.test_dataset)

    # Honour the flags object passed in rather than the module-level FLAGS.
    if flags.do_benchmark:
        for test_data in ['set5', 'set14', 'bsd100']:
            if test_data != flags.test_dataset:
                test(model, test_data)

    return mse
Пример #16
0
def evaluate_model(model, test_data):
    """Evaluate *model* on one test set; log average PSNR, SSIM and time.

    Output images are written only when ``FLAGS.save_results`` is set. The
    measured time covers the whole per-image evaluation call.

    Args:
        model: Object exposing ``do_for_evaluate_with_output`` and
            ``do_for_evaluate``, both returning ``(psnr, ssim)``.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" + test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_ssim = total_time = 0
    for filename in test_filenames:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        if FLAGS.save_results:
            psnr, ssim = model.do_for_evaluate_with_output(filename, output_directory=FLAGS.output_dir,
                                                           print_console=False)
        else:
            psnr, ssim = model.do_for_evaluate(filename, print_console=False)
        total_time += time.perf_counter() - start
        total_psnr += psnr
        total_ssim += ssim

    count = len(test_filenames)
    logging.info("Model Average [%s] PSNR:%f, SSIM:%f, Time (s): %f" % (
        test_data, total_psnr / count, total_ssim / count, total_time / count))
Пример #17
0
def main(not_parsed_args):
    """Save every training image's Y (luminance) channel as a BMP file.

    Images from ``<FLAGS.data_dir>/<FLAGS.dataset>/`` are written to
    ``<FLAGS.data_dir>/<FLAGS.dataset>_y/<name>.bmp``. 3-channel images are
    converted with ``util.convert_rgb_to_y``; other images are saved unchanged.

    Args:
        not_parsed_args: Leftover command-line arguments; anything beyond the
            program name is rejected.
    """
    import sys

    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        sys.exit(1)

    print("Building Y channel data...")

    training_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                     FLAGS.dataset + "/")
    target_dir = FLAGS.data_dir + "/" + FLAGS.dataset + "_y/"
    util.make_dir(target_dir)

    for file_path in training_filenames:
        org_image = util.load_image(file_path)
        # A 2-D (single-channel) array has no shape[2]; convert only RGB images.
        if len(org_image.shape) == 3 and org_image.shape[2] == 3:
            org_image = util.convert_rgb_to_y(org_image)

        filename = os.path.splitext(os.path.basename(file_path))[0]
        # Output is always BMP, regardless of the source extension.
        util.save_image(target_dir + filename + ".bmp", org_image)
Пример #18
0
def test(model, test_data):
    """Evaluate *model* on one test set; log MSE/PSNR and print timing.

    The first image's run time is excluded from the timing totals,
    presumably as warm-up. ``ave_time`` is averaged over the remaining images
    and reported as 0 when there are none.

    Args:
        model: Object exposing ``do_for_evaluate(filename, output_directory=...,
            output=...)`` that returns ``(mse, elapsed_time)``.
        test_data: Test-set sub-directory name under ``FLAGS.data_dir``.
    """
    test_filenames = util.get_files_in_directory(FLAGS.data_dir + "/" +
                                                 test_data)
    if not test_filenames:
        # The averages below would otherwise divide by zero.
        logging.warning("No test images found for [%s]." % test_data)
        return

    total_psnr = total_time = total_mse = 0
    for index, filename in enumerate(test_filenames):
        mse, time1 = model.do_for_evaluate(filename,
                                           output_directory=FLAGS.output_dir,
                                           output=FLAGS.save_results)
        total_mse += mse
        if index != 0:
            total_time += time1
        total_psnr += util.get_psnr(mse, max_value=FLAGS.max_value)

    count = len(test_filenames)
    logging.info("\n=== Average [%s] MSE:%f, PSNR:%f ===" %
                 (test_data, total_mse / count, total_psnr / count))
    # With a single image, no timed runs remain after dropping the first one.
    ave_time = total_time / (count - 1) if count > 1 else 0.0

    print("total_time: %4.4f   ave_time: %4.4f   " % (total_time, ave_time))
Пример #19
0
def main(not_parsed_args):
    """Serialise DIV2K HR images and their vertical flips to TFRecord files.

    Each image is stored as raw bytes under the feature key ``'org_raw'``.
    Originals go to ``DIV2K_org.tfrecords`` and ``np.flipud`` copies to
    ``DIV2K_aug.tfrecords``, both in the current working directory.

    Args:
        not_parsed_args: Leftover command-line arguments; anything beyond the
            program name is rejected.
    """
    import sys

    if len(not_parsed_args) > 1:
        print("Unknown args:%s" % not_parsed_args)
        sys.exit(1)

    # NOTE(review): only the vertical flip is written, whatever augment_level says.
    print("Building x%d augmented data." % FLAGS.augment_level)

    training_filenames = util.get_files_in_directory(
        "/media/data3/ww/sr_data/DIV2K_train_HR/")
    # NOTE(review): this directory is created, but nothing here writes into it.
    target_dir = "/media/data3/ww/sr_data/DIV2K_train_HR" + (
        "_%d/" % FLAGS.augment_level)
    util.make_dir(target_dir)

    def _to_example(image):
        """Wrap the raw bytes of *image* in a tf.train.Example under 'org_raw'."""
        return tf.train.Example(features=tf.train.Features(
            feature={
                'org_raw':
                tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[image.tobytes()]))
            }))

    writer = tf.python_io.TFRecordWriter("DIV2K_org.tfrecords")
    writer2 = tf.python_io.TFRecordWriter("DIV2K_aug.tfrecords")
    try:
        for file_path in training_filenames:
            org_image = util.load_image(file_path)
            writer.write(_to_example(org_image).SerializeToString())
            writer2.write(
                _to_example(np.flipud(org_image)).SerializeToString())
    finally:
        # Both writers must be closed to flush their records;
        # previously writer2 was never closed.
        writer.close()
        writer2.close()
Пример #20
0
 def set_data_dir(self, data_dir):
     """Store the image files found in *data_dir* and their count.

     Sets ``self.filenames`` and ``self.count``. Logs an error and terminates
     the process with status -1 when the directory yields no files.
     """
     import sys

     self.filenames = util.get_files_in_directory(data_dir)
     self.count = len(self.filenames)
     if self.count <= 0:
         # Include the path so the failing directory is identifiable from the log.
         logging.error("Data Directory is empty: [%s]" % data_dir)
         sys.exit(-1)
Пример #21
0
def load_and_evaluate_tflite_graph(
    output_dir,
    data_dir,
    test_data,
    model_path=None):
    """Evaluate a converted TFLite super-resolution model on one test set.

    Each image of ``<data_dir>/<test_data>`` is aligned to ``FLAGS.scale`` and
    reduced to a low-res Y-channel input. That input and its upscale are fed
    to the TFLite interpreter, and the result is scored (PSNR/SSIM on the Y
    channel, ignoring ``FLAGS.psnr_calc_border_size`` border pixels).
    Intermediate and result images are written below ``<output_dir>/tflite/``,
    and the averages over evaluated images are printed.

    Only 3-channel images with ``FLAGS.channels == 1`` are supported. Other
    images are skipped with a message and excluded from the averages.

    Args:
        output_dir: Root output directory.
        data_dir: Root directory of the test sets.
        test_data: Test-set sub-directory name under *data_dir*.
        model_path: Path to the ``.tflite`` model. ``None`` (the default) means
            ``<cwd>/model_to_freeze/converted_model.tflite``, resolved at call
            time.
    """
    # https://stackoverflow.com/questions/50443411/how-to-load-a-tflite-model-in-script
    # https://www.tensorflow.org/lite/convert/python_api#tensorflow_lite_python_interpreter_
    if model_path is None:
        # Resolved per call: a default computed in the signature would freeze
        # the working directory at import time.
        model_path = os.path.join(os.getcwd(),
                                  'model_to_freeze/converted_model.tflite')

    output_directory = output_dir + "/" + "tflite" + "/"
    util.make_dir(output_directory)

    test_filepaths = util.get_files_in_directory(data_dir + "/" + test_data)
    total_psnr = total_ssim = total_time = 0
    evaluated_count = 0

    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    # NOTE(review): assumes input 0 is the low-res Y image and input 1 its
    # upscaled version, as in the original code.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    for file_path in test_filepaths:
        # NOTE(review): if get_files_in_directory returns paths that include the
        # directory, outputs are saved under output_directory + <input path>;
        # the layout is kept unchanged here.
        filename, extension = os.path.splitext(file_path)

        # prepare true image
        true_image = util.set_image_alignment(
            util.load_image(file_path, print_console=False), FLAGS.scale)

        if (len(true_image.shape) != 3 or true_image.shape[2] != 3
                or FLAGS.channels != 1):
            # Previously such images fell through and reused the previous
            # image's psnr/ssim/time (or raised NameError on the first image).
            print("Skipping [%s]: only RGB images with channels=1 are supported."
                  % file_path)
            continue

        # prepare input and ground truth images
        input_y_image = loader.build_input_image(true_image,
                                                 channels=FLAGS.channels,
                                                 scale=FLAGS.scale,
                                                 alignment=FLAGS.scale,
                                                 convert_ycbcr=True)
        input_bicubic_y_image = util.resize_image_by_pil(
            input_y_image,
            FLAGS.scale,
            resampling_method=FLAGS.resampling_method)
        true_ycbcr_image = util.convert_rgb_to_ycbcr(true_image)
        true_y_image = true_ycbcr_image[:, :, 0:1]

        # The interpreter is fed float32 tensors shaped (1, H, W, channels).
        input_y_tensor = input_y_image.astype('float32').reshape(
            1, input_y_image.shape[0], input_y_image.shape[1], FLAGS.channels)
        input_bicubic_y_tensor = input_bicubic_y_image.astype('float32').reshape(
            1, input_bicubic_y_image.shape[0], input_bicubic_y_image.shape[1],
            FLAGS.channels)

        interpreter.set_tensor(input_details[0]['index'],
                               input_y_tensor)  # pass x
        interpreter.set_tensor(input_details[1]['index'],
                               input_bicubic_y_tensor)  # pass x2

        # time only the model invocation
        start = time.time()
        interpreter.invoke()
        elapsed_time = time.time() - start

        output_y_image = interpreter.get_tensor(
            output_details[0]['index'])  # get y
        # drop the leading batch dimension: (1, H, W, C) -> (H, W, C)
        output_y_image = output_y_image.reshape(output_y_image.shape[1],
                                                output_y_image.shape[2],
                                                FLAGS.channels)

        # calculate psnr and ssim for the output
        psnr, ssim = util.compute_psnr_and_ssim(
            true_y_image,
            output_y_image,
            border_size=FLAGS.psnr_calc_border_size)

        # get the loss image
        loss_image = util.get_loss_image(
            true_y_image,
            output_y_image,
            border_size=FLAGS.psnr_calc_border_size)

        # get output color image
        output_color_image = util.convert_y_and_cbcr_to_rgb(
            output_y_image, true_ycbcr_image[:, :, 1:3])

        # save all images
        util.save_image(output_directory + file_path, true_image)
        for suffix, image in (("_input", input_y_image),
                              ("_input_bicubic", input_bicubic_y_image),
                              ("_true_y", true_y_image),
                              ("_result", output_y_image),
                              ("_result_c", output_color_image),
                              ("_loss", loss_image)):
            util.save_image(output_directory + filename + suffix + extension,
                            image)

        total_psnr += psnr
        total_ssim += ssim
        total_time += elapsed_time
        evaluated_count += 1

    if evaluated_count == 0:
        print("Model Average [%s]: no images were evaluated." % test_data)
        return
    print("Model Average [%s] PSNR:%f, SSIM:%f, Elapsed Time:%f" %
          (test_data, total_psnr / evaluated_count,
           total_ssim / evaluated_count, total_time / evaluated_count))