Exemplo n.º 1
0
def main(_):
    """Build an SRCNN model, then train it or run it on one test image.

    All settings come from the module-level ``FLAGS``. When
    ``FLAGS.build_model`` is false the model is tested on ``FLAGS.test_img``
    (after passing it through ``validate``); if that path is not an existing
    file an error is printed and the process exits with status 1.
    The total wall-clock time is printed at the end.

    Args:
        _: Unused; presumably the argv list handed over by the app runner
            (e.g. ``tf.compat.v1.app.run``) -- TODO confirm.
    """
    t0 = time.time()

    pp.pprint(FLAGS.build_model)

    if not FLAGS.build_model:
        FLAGS.test_img = validate(FLAGS.test_img)
        print("Image path = ", FLAGS.test_img)
        if not os.path.isfile(FLAGS.test_img):
            print("File does not exist ", FLAGS.test_img, file=sys.stderr)
            # Non-zero status so calling scripts see the failure; a bare
            # sys.exit() would report success (status 0).
            sys.exit(1)

    # NOTE(review): helper defined elsewhere; presumably creates the
    # checkpoint/sample folders named in FLAGS -- confirm.
    create_required_directories(FLAGS)

    with tf.compat.v1.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        if FLAGS.build_model:
            srcnn.train(FLAGS)
        else:
            srcnn.test(FLAGS)
    print("\n\nTime taken %4.2f\n\n" % (time.time() - t0))
Exemplo n.º 2
0
def main(_):
    """Print the TF configuration, prepare output folders, then train/test SRCNN.

    Steps (numbering kept from the original script):
      3. print the TF version and every flag value;
      4. create the checkpoint folder and a TensorBoard folder whose name is
         ``FLAGS.TB_dir`` suffixed with ``'_' + str(FLAGS.c_dim)``;
      5. open a TF session with GPU ``allow_growth`` enabled;
      6. build the SRCNN model;
      7. train if ``FLAGS.is_train``, otherwise run the patch test when
         ``FLAGS.patch_test`` is set, else the whole-image test.

    Note: ``FLAGS.TB_dir`` is modified in place, so calling this function
    twice appends the suffix twice.
    """
    # 3. Print configurations.
    print('tf version:', tf.__version__)
    print('tf setup:')
    for k, v in FLAGS.flag_values_dict().items():
        print(k, v)
    # Runs with different c_dim values log to separate TensorBoard folders.
    FLAGS.TB_dir += '_' + str(FLAGS.c_dim)

    # 4. Check/create folders. exist_ok avoids the check-then-create race.
    print("check dirs...")
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    os.makedirs(FLAGS.TB_dir, exist_ok=True)

    # 5. Begin tf session; let TF grow GPU memory on demand instead of
    # reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        print("building model...")
        # 6. Init srcnn model.
        srcnn = SRCNN(sess, FLAGS)
        # 7. Start to train/test.
        if FLAGS.is_train:
            srcnn.train()
        elif FLAGS.patch_test:
            srcnn.test()
        else:
            srcnn.test_whole_img()
Exemplo n.º 3
0
def main(_):
    """Print the TF setup, ensure the checkpoint folder exists, then train or test SRCNN.

    The mode is chosen by ``FLAGS.is_train``: ``srcnn.train()`` when true,
    ``srcnn.test()`` otherwise.
    """
    # Local import: this script may not import os at module level.
    import os

    # 3. Print configurations: TF version followed by every flag value.
    print('tf version:', tf.__version__)
    print('tf setup:')
    for k, v in FLAGS.flag_values_dict().items():
        print(k, v)

    # 4. Create the checkpoint folder; exist_ok means an existing folder is fine.
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    # 5. Begin tf session.
    with tf.Session() as sess:
        # 6. Init srcnn model.
        srcnn = SRCNN(sess, FLAGS)
        # 7. Start to train/test.
        if FLAGS.is_train:
            srcnn.train()
        else:
            srcnn.test()
Exemplo n.º 4
0
def main(_):
    """Print the TF setup, ensure the checkpoint folder exists, then train or test SRCNN.

    The mode is chosen by ``FLAGS.is_train``: ``srcnn.train()`` when true,
    ``srcnn.test()`` otherwise.
    """
    # 3. Print configurations: TF version followed by every flag value.
    print('tf version:', tf.__version__)
    print('tf setup:')
    for k, v in FLAGS.flag_values_dict().items():
        print(k, v)

    # 4. Check/create folders. exist_ok avoids the check-then-create race.
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    # 5. Begin tf session.
    with tf.Session() as sess:
        # 6. Init srcnn model.
        srcnn = SRCNN(sess, FLAGS)
        # 7. Start to train/test.
        if FLAGS.is_train:
            srcnn.train()
        else:
            srcnn.test()
Exemplo n.º 5
0
def main(_):
    """Print the flags, ensure output folders exist, then train or test SRCNN.

    The model is built from ``FLAGS`` (image/label/batch sizes, ``c_dim``,
    checkpoint and sample folders). ``srcnn.train(FLAGS)`` runs when
    ``FLAGS.is_train`` is true, otherwise ``srcnn.test(FLAGS)``.
    """
    # Prefer the public accessor. On absl-based TF, ``__flags`` maps names to
    # Flag objects rather than their values. Older TF has no
    # flag_values_dict() and keeps plain values in ``__flags``
    # (NOTE(review): confirm for the pinned TF version).
    if hasattr(flags.FLAGS, 'flag_values_dict'):
        pp.pprint(flags.FLAGS.flag_values_dict())
    else:
        pp.pprint(flags.FLAGS.__flags)

    # exist_ok avoids the check-then-create race.
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    os.makedirs(FLAGS.sample_dir, exist_ok=True)

    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            srcnn.train(FLAGS)
        else:
            srcnn.test(FLAGS)
Exemplo n.º 6
0
def test():
    """Convert the Set5/Set14 images to HDF5, restore SRCNN and print its accuracy.

    Settings come from the module-level ``flags`` object: ``test_dir`` (one
    sub-folder per benchmark), ``test_h5_dir`` (output folder), ``test_stride``
    and ``model_dir`` (checkpoint folder).

    Raises:
        ValueError: if no checkpoint is found in ``flags.model_dir``.
    """
    benchmarks = ('Set5', 'Set14')

    print("process the image to h5file.....")
    test_dir = flags.test_dir
    test_h5_dir = flags.test_h5_dir
    stride = flags.test_stride
    os.makedirs(test_h5_dir, exist_ok=True)

    # Generate every benchmark's HDF5 data before building the model, and keep
    # each output folder for the evaluation pass below.
    h5_dirs = {}
    for name in benchmarks:
        h5_dirs[name] = os.path.join(test_h5_dir, name)
        data_helper.gen_input_image(os.path.join(test_dir, name),
                                    h5_dirs[name], stride)

    print("initialize the model......")
    model_dir = flags.model_dir
    model = SRCNN(flags)
    model.build_graph()
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print("model info didn't exist!")
        raise ValueError("no model checkpoint found in %r" % (model_dir,))
    saver.restore(model.sess, ckpt.model_checkpoint_path)

    for name in benchmarks:
        print("test in %s......" % name)
        test_h5_path = os.path.join(h5_dirs[name], "data.h5")
        data, label = data_helper.load_data(test_h5_path)
        accu = model.test(data, label)
        # print() does not interpolate %-placeholders, so format explicitly.
        print("the accuracy in %s is %.5f" % (name, accu))
Exemplo n.º 7
0
        self.batch_size = 128
        self.result_dir = 'result'
        self.test_img = ''  # Do not change this.


arg = this_config()
print(
    "Hello TA!  We are group 7. Thank you for your work for us. Hope you have a happy day!"
)

with tf.Session() as sess:
    FLAGS = arg
    # Train first; the same model instance is then reused for testing.
    srcnn = SRCNN(sess,
                  image_size=FLAGS.image_size,
                  label_size=FLAGS.label_size,
                  c_dim=FLAGS.c_dim)
    srcnn.train(FLAGS)

    # Testing: run the model on a random fifth of the images in
    # ./train_set/LR0 (no seed, so the subset differs between runs).
    files = glob.glob(os.path.join(os.getcwd(), 'train_set', 'LR0', '*.jpg'))
    test_files = random.sample(files, len(files) // 5)

    FLAGS.is_train = False
    total = len(test_files)
    for count, f in enumerate(test_files, start=1):
        # srcnn.test reads the image path from FLAGS.test_img.
        FLAGS.test_img = f
        print('Saving ', count, '/', total, ': ', FLAGS.test_img,
              '\n')
        srcnn.test(FLAGS)