# Validate the output locations before anything is built.
# The output directory (where checkpoints are saved) is mandatory.
if FLAGS.output_dir is None:
    raise ValueError('The output directory is needed')
# The summary directory (where event files are saved) is used as a path
# further down, so fail with a clear message instead of letting os.* raise a
# TypeError on None.
if FLAGS.summary_dir is None:
    raise ValueError('The summary directory is needed')

# Create the checkpoint and event directories. makedirs(exist_ok=True) also
# creates missing parent directories and avoids the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs(FLAGS.output_dir, exist_ok=True)
os.makedirs(FLAGS.summary_dir, exist_ok=True)

if __name__ == '__main__':
    # When executed directly, construct the data loader once.
    # NOTE(review): the second argument is presumably the loader's
    # use-validation switch (it receives a tf.bool placeholder elsewhere in
    # this file); a plain int 1 is passed here — confirm the loader accepts it.
    frvsr_gpu_data_loader(FLAGS, 1)
filelist = [ 'main.py', 'lib/Teco.py', 'lib/frvsr.py', 'lib/dataloader.py', 'lib/ops.py' ] for filename in filelist: shutil.copyfile('./' + filename, FLAGS.summary_dir + filename.replace("/", "_")) for key, value_obj in tf.flags.FLAGS.__flags.items(): print(key, ': ', value_obj.value) print(FLAGS.input_video_dir) useValidat = tf.placeholder_with_default(tf.constant(False, dtype=tf.bool), shape=()) rdata = frvsr_gpu_data_loader(FLAGS, useValidat) # Data = collections.namedtuple('Data', 'paths_HR, s_inputs, s_targets, image_count, steps_per_epoch') print('tData count = %d, steps per epoch %d' % (rdata.image_count, rdata.steps_per_epoch)) if (FLAGS.ratio > 0): Net = TecoGAN(rdata.s_inputs, rdata.s_targets, FLAGS) else: Net = FRVSR(rdata.s_inputs, rdata.s_targets, FLAGS) # Network = collections.namedtuple('Network', 'gen_output, train, learning_rate, update_list, ' # 'update_list_name, update_list_avg, image_summary') # Add scalar summary tf.summary.scalar('learning_rate', Net.learning_rate) train_summary = [] for key, value in zip(Net.update_list_name, Net.update_list_avg): # 'map_loss, scale_loss, FrameA_loss, FrameA_loss,...'