Beispiel #1
0
# Validate the required output locations and create them on demand.
# These statements sit at module level (before the __main__ guard), so they
# run, and may create directories, whenever this module is executed.
#
# Raises:
#     ValueError: if FLAGS.output_dir or FLAGS.summary_dir is not given.
#     FileExistsError: if either path exists but is not a directory.
if FLAGS.output_dir is None:
    raise ValueError('The output directory is needed')
# Without this guard, os.path.exists(None) would fail with an opaque
# TypeError; fail with the same kind of message as output_dir instead.
if FLAGS.summary_dir is None:
    raise ValueError('The summary directory is needed')
# Directory used to save the checkpoints. makedirs(exist_ok=True) also creates
# missing parent directories. Unlike a separate exists()/mkdir() pair, it does
# not fail if the directory is created between the check and the create.
os.makedirs(FLAGS.output_dir, exist_ok=True)
# Directory used to save the summary events.
os.makedirs(FLAGS.summary_dir, exist_ok=True)

if __name__ == '__main__':
    # Script entry point: build the data-loading pipeline once from the parsed
    # FLAGS and discard the result. This appears to be a debugging hook for
    # the loader rather than a training run.
    # NOTE(review): the second argument is the literal 1; presumably a
    # mode/validation switch. Confirm against frvsr_gpu_data_loader's signature.
    loader_arg = 1
    frvsr_gpu_data_loader(FLAGS, loader_arg)
Beispiel #2
0
    filelist = [
        'main.py', 'lib/Teco.py', 'lib/frvsr.py', 'lib/dataloader.py',
        'lib/ops.py'
    ]
    for filename in filelist:
        shutil.copyfile('./' + filename,
                        FLAGS.summary_dir + filename.replace("/", "_"))

    for key, value_obj in tf.flags.FLAGS.__flags.items():
        print(key, ': ', value_obj.value)

    print(FLAGS.input_video_dir)

    useValidat = tf.placeholder_with_default(tf.constant(False, dtype=tf.bool),
                                             shape=())
    rdata = frvsr_gpu_data_loader(FLAGS, useValidat)
    # Data = collections.namedtuple('Data', 'paths_HR, s_inputs, s_targets, image_count, steps_per_epoch')
    print('tData count = %d, steps per epoch %d' %
          (rdata.image_count, rdata.steps_per_epoch))
    if (FLAGS.ratio > 0):
        Net = TecoGAN(rdata.s_inputs, rdata.s_targets, FLAGS)
    else:
        Net = FRVSR(rdata.s_inputs, rdata.s_targets, FLAGS)
    # Network = collections.namedtuple('Network', 'gen_output, train, learning_rate, update_list, '
    #                                     'update_list_name, update_list_avg, image_summary')

    # Add scalar summary
    tf.summary.scalar('learning_rate', Net.learning_rate)
    train_summary = []
    for key, value in zip(Net.update_list_name, Net.update_list_avg):
        # 'map_loss, scale_loss, FrameA_loss, FrameA_loss,...'