def main(_):
    """Build the MPI training graph from command-line flags and run training.

    Appends ``FLAGS.experiment_name`` as a sub-directory to
    ``FLAGS.checkpoint_dir`` (creating it when missing), samples a batch from
    a ``SequenceDataLoader``, builds the train op with
    ``MPI.build_train_graph`` and passes it to ``MPI.train``.

    Args:
      _: Unused positional argument.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(FLAGS.random_seed)

    # Checkpoints for each experiment live in their own sub-directory. Note
    # that this rewrites the flag value itself, so later reads see the
    # experiment-specific path.
    FLAGS.checkpoint_dir += '/%s/' % FLAGS.experiment_name
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        tf.gfile.MakeDirs(FLAGS.checkpoint_dir)

    # Set up data loader.
    # NOTE(review): the literal True is presumably the training-mode switch
    # of SequenceDataLoader -- confirm against its constructor.
    data_loader = SequenceDataLoader(FLAGS.cameras_glob, FLAGS.image_dir, True,
                                     FLAGS.num_source,
                                     FLAGS.shuffle_seq_length,
                                     FLAGS.random_seed)
    train_batch = data_loader.sample_batch()

    # The previous `while True: ... break` wrapper executed this sequence
    # exactly once, so it is written as straight-line code.
    model = MPI()
    train_op = model.build_train_graph(
        train_batch, FLAGS.min_depth, FLAGS.max_depth, FLAGS.num_psv_planes,
        FLAGS.num_mpi_planes, FLAGS.which_color_pred, FLAGS.which_loss,
        FLAGS.learning_rate, FLAGS.beta1, FLAGS.vgg_model_file)
    model.train(train_op, FLAGS.checkpoint_dir, FLAGS.continue_train,
                FLAGS.summary_freq, FLAGS.save_latest_freq, FLAGS.max_steps)
def main(_):
    """Entry point: set up the checkpoint directory, build the graph, train.

    The checkpoint directory flag is extended in place with the experiment
    name, the directory is created if it does not exist yet, and a training
    op built from a sampled batch is run through ``MPI.train``.

    Args:
      _: Unused positional argument.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(FLAGS.random_seed)

    # Per-experiment checkpoint location; the flag itself is updated.
    experiment_subdir = '/%s/' % FLAGS.experiment_name
    FLAGS.checkpoint_dir = FLAGS.checkpoint_dir + experiment_subdir
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        tf.gfile.MakeDirs(FLAGS.checkpoint_dir)

    # Set up data loader.
    loader_args = (
        FLAGS.cameras_glob,
        FLAGS.image_dir,
        True,
        FLAGS.num_source,
        FLAGS.shuffle_seq_length,
        FLAGS.random_seed,
    )
    loader = SequenceDataLoader(*loader_args)
    batch = loader.sample_batch()

    # Build the training graph.
    mpi_model = MPI()
    graph_args = (
        FLAGS.min_depth,
        FLAGS.max_depth,
        FLAGS.num_psv_planes,
        FLAGS.num_mpi_planes,
        FLAGS.which_color_pred,
        FLAGS.which_loss,
        FLAGS.learning_rate,
        FLAGS.beta1,
        FLAGS.vgg_model_file,
    )
    optimize_op = mpi_model.build_train_graph(batch, *graph_args)

    # Run the training loop.
    schedule_args = (
        FLAGS.checkpoint_dir,
        FLAGS.continue_train,
        FLAGS.summary_freq,
        FLAGS.save_latest_freq,
        FLAGS.max_steps,
    )
    mpi_model.train(optimize_op, *schedule_args)
def main(_):
    """Infer an MPI from the input images, optionally render views, and save.

    Steps:
      1. Pad the inputs (via ``get_inputs``) and run the input graph once to
         learn the concrete image size, which is then set on the tensors.
      2. Build the MPI inference graph, restore the latest checkpoint from
         ``FLAGS.model_root/FLAGS.model_name`` and run inference.
      3. If ``FLAGS.render`` is set, render one view per entry of
         ``FLAGS.render_multiples`` in a fresh graph.
      4. Write the renders, the outputs selected by ``FLAGS.test_outputs`` and
         a README recording the command line into ``FLAGS.output_dir``.

    Args:
      _: Unused positional argument.

    Raises:
      ValueError: If no checkpoint is found in the model directory.
    """
    # Set up the inputs.
    # How much shall we pad the input images? We'll pad enough so that
    # (a) when we render output images we won't lose stuff at the edges
    # due to cropping, and (b) we can find a multiple of 16 size without
    # cropping into the original images.
    max_multiple = 0
    render_list = []
    if FLAGS.render:
        render_list = [float(x) for x in FLAGS.render_multiples.split(',')]
        max_multiple = max(abs(m) for m in render_list)
    pady = int(max_multiple * abs(FLAGS.yshift) + 8)
    padx = int(max_multiple * abs(FLAGS.xshift) + 8)

    # max_multiple may be fractional, so it is logged with %g rather than
    # being truncated by %d. Single-argument print(...) behaves the same on
    # Python 2 and Python 3.
    print('Padding inputs: padx=%d, pady=%d (max_multiple=%g)' %
          (padx, pady, max_multiple))
    inputs, original_width, original_height = get_inputs(padx, pady)

    # MPI code requires images of known size. So we run the input part of the
    # graph now to find the size, which we can then set on the inputs.
    with tf.Session() as sess:
        dimensions, original_width, original_height = sess.run(
            [tf.shape(inputs['ref_image']), original_width, original_height])
    batch = 1
    channels = 3
    assert dimensions[0] == batch
    mpi_height = dimensions[1]
    mpi_width = dimensions[2]
    assert dimensions[3] == channels

    print('Original size: width=%d, height=%d' % (original_width,
                                                  original_height))
    print(' MPI size: width=%d, height=%d' % (mpi_width, mpi_height))

    inputs['ref_image'].set_shape([batch, mpi_height, mpi_width, channels])
    # Two source images are stacked along the channel axis.
    inputs['src_images'].set_shape(
        [batch, mpi_height, mpi_width, channels * 2])

    # Build the MPI.
    model = MPI()
    psv_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_psv_planes)
    mpi_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_mpi_planes)
    outputs = model.infer_mpi(
        inputs['src_images'], inputs['ref_image'], inputs['ref_pose'],
        inputs['src_poses'], inputs['intrinsics'], FLAGS.which_color_pred,
        FLAGS.num_mpi_planes, psv_planes, FLAGS.test_outputs)

    saver = tf.train.Saver([var for var in tf.model_variables()])
    ckpt_dir = os.path.join(FLAGS.model_root, FLAGS.model_name)
    ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_file is None:
        # Fail early with an actionable message instead of letting
        # saver.restore() reject a None path after the session is created.
        raise ValueError('No checkpoint found in %s' % ckpt_dir)
    # The Supervisor is only used for session management; restoring is done
    # explicitly through `saver` below.
    sv = tf.train.Supervisor(logdir=ckpt_dir, saver=None)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    print('Inferring MPI...')
    with sv.managed_session(config=config) as sess:
        saver.restore(sess, ckpt_file)
        ins, outs = sess.run([inputs, outputs])

    # Render output images separately so as not to run out of memory.
    tf.reset_default_graph()
    renders = {}
    if FLAGS.render:
        print('Rendering new views...')
        for index, multiple in enumerate(render_list):
            print(' offset: %s' % multiple)
            # Pose matrix with identity rotation and a translation scaled by
            # the current multiple; a leading batch axis is added.
            pose = build_matrix(
                [[1.0, 0.0, 0.0, -multiple * FLAGS.xoffset],
                 [0.0, 1.0, 0.0, -multiple * FLAGS.yoffset],
                 [0.0, 0.0, 1.0, -multiple * FLAGS.zoffset],
                 [0.0, 0.0, 0.0, 1.0]])[tf.newaxis, ...]
            image = model.deprocess_image(
                model.mpi_render_view(
                    tf.constant(outs['rgba_layers']), pose, mpi_planes,
                    tf.constant(ins['intrinsics'])))[0]
            unshifted = shift_image(image, multiple * FLAGS.xshift,
                                    multiple * FLAGS.yshift)
            cropped = crop_to_size(unshifted, original_width, original_height)
            with tf.Session() as sess:
                renders[multiple] = (index, sess.run(cropped))

    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    print('Saving results to %s' % output_dir)

    # Write results to disk.
    for name, (index, image) in renders.items():
        write_image(output_dir + '/render_%02d_%s.png' % (index, name), image)

    if 'intrinsics' in FLAGS.test_outputs:
        with open(output_dir + '/intrinsics.txt', 'w') as fh:
            write_intrinsics(fh, ins['intrinsics'][0])

    if 'src_images' in FLAGS.test_outputs:
        for i in range(2):
            # Each source image occupies three consecutive channels.
            write_image(output_dir + '/src_image_%d.png' % i,
                        ins['src_images'][0, :, :, i * 3:(i + 1) * 3] * 255.0)
            if 'poses' in FLAGS.test_outputs:
                write_pose(output_dir + '/src_pose_%d.txt' % i,
                           ins['src_poses'][0, i])

    if 'fgbg' in FLAGS.test_outputs:
        write_image(output_dir + '/foreground_color.png', outs['fg_image'][0])
        write_image(output_dir + '/background_color.png', outs['bg_image'][0])

    if 'blend_weights' in FLAGS.test_outputs:
        for i in range(FLAGS.num_mpi_planes):
            weight_img = outs['blend_weights'][0, :, :, i] * 255.0
            write_image(output_dir + '/foreground_weight_plane_%.3d.png' % i,
                        weight_img)

    if 'psv' in FLAGS.test_outputs:
        for j in range(FLAGS.num_psv_planes):
            # Linear map from [-1, 1] to [0, 255].
            plane_img = (
                outs['psv'][0, :, :, j * 3:(j + 1) * 3] + 1.) / 2. * 255
            write_image(output_dir + '/psv_plane_%.3d.png' % j, plane_img)

    if 'rgba_layers' in FLAGS.test_outputs:
        for i in range(FLAGS.num_mpi_planes):
            alpha_img = outs['rgba_layers'][0, :, :, i, 3] * 255.0
            rgb_img = (outs['rgba_layers'][0, :, :, i, :3] + 1.) / 2. * 255
            write_image(output_dir + '/mpi_alpha_%.2d.png' % i, alpha_img)
            write_image(output_dir + '/mpi_rgb_%.2d.png' % i, rgb_img)

    with open(output_dir + '/README', 'w') as fh:
        fh.write(
            'This directory was generated by mpi_from_images. Command-line:\n\n'
        )
        # One argument per line, joined by shell line continuations. Joining
        # avoids repeating the program name when no arguments were given
        # (sys.argv[-1] is sys.argv[0] in that case).
        fh.write(' \\\n '.join(sys.argv) + '\n')

    print('Done.')
def main(_):
    """Run MPI inference on test sequences and write the requested outputs.

    Builds the inference graph on batches from a ``SequenceDataLoader``,
    restores the latest checkpoint from ``FLAGS.model_root/FLAGS.model_name``
    and performs ``FLAGS.num_runs`` inference runs. For every run the outputs
    selected by ``FLAGS.test_outputs`` are written to
    ``FLAGS.output_root/FLAGS.model_name/FLAGS.data_split/<dirname>``.

    Args:
      _: Unused positional argument.

    Raises:
      AssertionError: If ``FLAGS.batch_size`` is not 1.
      ValueError: If no checkpoint is found in the model directory.
    """
    assert FLAGS.batch_size == 1, 'Currently, batch_size must be 1 when testing.'

    tf.logging.set_verbosity(tf.logging.INFO)
    tf.reset_default_graph()
    tf.set_random_seed(FLAGS.random_seed)

    # Set up data loader.
    # NOTE(review): the literal False is presumably the training-mode switch
    # of SequenceDataLoader -- confirm against its constructor.
    data_loader = SequenceDataLoader(FLAGS.cameras_glob, FLAGS.image_dir,
                                     False, FLAGS.num_source,
                                     FLAGS.shuffle_seq_length,
                                     FLAGS.random_seed)

    inputs = data_loader.sample_batch()
    model = MPI()
    psv_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_psv_planes)
    mpi_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_mpi_planes)
    outputs = model.infer_mpi(
        inputs['src_images'], inputs['ref_image'], inputs['ref_pose'],
        inputs['src_poses'], inputs['intrinsics'], FLAGS.which_color_pred,
        FLAGS.num_mpi_planes, psv_planes, FLAGS.test_outputs)
    if 'tgt_image' in FLAGS.test_outputs:
        # Render the MPI at the target pose expressed relative to the
        # reference pose.
        rel_pose = tf.matmul(inputs['tgt_pose'],
                             tf.matrix_inverse(inputs['ref_pose']))
        outputs['output_image'] = model.mpi_render_view(
            outputs['rgba_layers'], rel_pose, mpi_planes, inputs['intrinsics'])
        outputs['output_image'] = model.deprocess_image(
            outputs['output_image'])

    saver = tf.train.Saver([var for var in tf.model_variables()])
    ckpt_dir = os.path.join(FLAGS.model_root, FLAGS.model_name)
    ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_file is None:
        # Fail early with an actionable message instead of letting
        # saver.restore() reject a None path after the session is created.
        raise ValueError('No checkpoint found in %s' % ckpt_dir)
    # The Supervisor is only used for session management; restoring is done
    # explicitly through `saver` below.
    sv = tf.train.Supervisor(logdir=ckpt_dir, saver=None)
    config = tf.ConfigProto()

    def frame_id(path):
        # Last '_'-separated token of the file name, without its extension.
        return os.path.basename(path).split('.')[0].split('_')[-1]

    with sv.managed_session(config=config) as sess:
        saver.restore(sess, ckpt_file)
        for run in range(FLAGS.num_runs):
            tf.logging.info('Progress: %d/%d' % (run, FLAGS.num_runs))
            ins, outs = sess.run([inputs, outputs])

            # Output directory name:
            # [scene]_[1st src file]_[2nd src file]_[tgt file].
            dirname = ins['ref_name'][0].split('/')[0]
            for i in range(FLAGS.num_source):
                dirname += '_%s' % frame_id(ins['src_timestamps'][0][i])
            dirname += '_%s' % frame_id(ins['tgt_timestamp'][0])

            output_dir = os.path.join(FLAGS.output_root, FLAGS.model_name,
                                      FLAGS.data_split, dirname)
            if not tf.gfile.IsDirectory(output_dir):
                tf.gfile.MakeDirs(output_dir)

            # Write results to disk.
            if 'intrinsics' in FLAGS.test_outputs:
                with open(output_dir + '/intrinsics.txt', 'w') as fh:
                    write_intrinsics(fh, ins['intrinsics'][0])

            if 'src_images' in FLAGS.test_outputs:
                for i in range(FLAGS.num_source):
                    timestamp = ins['src_timestamps'][0][i]
                    # Each source image occupies three consecutive channels.
                    write_image(
                        output_dir + '/src_image_%d_%s.png' % (i, timestamp),
                        ins['src_images'][0, :, :, i * 3:(i + 1) * 3] * 255.0)
                    if 'poses' in FLAGS.test_outputs:
                        write_pose(output_dir + '/src_pose_%d.txt' % i,
                                   ins['src_poses'][0, i])

            if 'tgt_image' in FLAGS.test_outputs:
                timestamp = ins['tgt_timestamp'][0]
                write_image(output_dir + '/tgt_image_%s.png' % timestamp,
                            ins['tgt_image'][0] * 255.0)
                write_image(output_dir + '/output_image_%s.png' % timestamp,
                            outs['output_image'][0])
                if 'poses' in FLAGS.test_outputs:
                    write_pose(output_dir + '/tgt_pose.txt',
                               ins['tgt_pose'][0])

            if 'fgbg' in FLAGS.test_outputs:
                write_image(output_dir + '/foreground_color.png',
                            outs['fg_image'][0])
                write_image(output_dir + '/background_color.png',
                            outs['bg_image'][0])

            if 'blend_weights' in FLAGS.test_outputs:
                for i in range(FLAGS.num_mpi_planes):
                    weight_img = outs['blend_weights'][0, :, :, i] * 255.0
                    write_image(
                        output_dir + '/foreground_weight_plane_%.3d.png' % i,
                        weight_img)

            if 'ref_image' in FLAGS.test_outputs:
                fname = os.path.basename(ins['ref_name'][0])
                # NOTE(review): unlike src_images / tgt_image above, the
                # reference image is written without the `* 255.0` scaling --
                # confirm this is intended.
                write_image(output_dir + '/ref_image_%s' % fname,
                            ins['ref_image'][0])
                write_pose(output_dir + '/ref_pose.txt', ins['ref_pose'][0])

            if 'psv' in FLAGS.test_outputs:
                for j in range(FLAGS.num_psv_planes):
                    # Linear map from [-1, 1] to [0, 255].
                    plane_img = (
                        outs['psv'][0, :, :, j * 3:(j + 1) * 3] + 1.) / 2. * 255
                    write_image(output_dir + '/psv_plane_%.3d.png' % j,
                                plane_img)

            if 'rgba_layers' in FLAGS.test_outputs:
                for i in range(FLAGS.num_mpi_planes):
                    alpha_img = outs['rgba_layers'][0, :, :, i, 3] * 255.0
                    rgb_img = (
                        outs['rgba_layers'][0, :, :, i, :3] + 1.) / 2. * 255
                    write_image(output_dir + '/mpi_alpha_%.2d.png' % i,
                                alpha_img)
                    write_image(output_dir + '/mpi_rgb_%.2d.png' % i, rgb_img)