Exemplo n.º 1
0
def main(_):
    """Infers an MPI from input images, optionally renders new views, saves results.

    Steps:
      1. Loads padded inputs via ``get_inputs`` and fixes their static shapes.
      2. Restores the latest checkpoint under ``FLAGS.model_root/FLAGS.model_name``
         and runs ``MPI.infer_mpi`` once.
      3. If ``FLAGS.render`` is set, renders one view per entry of the
         comma-separated ``FLAGS.render_multiples``. Each view is offset by
         ``multiple * (xoffset, yoffset, zoffset)``, shifted back by
         ``multiple * (xshift, yshift)`` and cropped to the original size.
      4. Writes the renders, the outputs named in ``FLAGS.test_outputs`` and a
         README with the command line into ``FLAGS.output_dir``.

    Args:
      _: Unused positional argument (presumably argv from the app runner).

    Raises:
      ValueError: If no checkpoint is found in the model directory.
    """
    batch = 1
    channels = 3
    # inputs['src_images'] packs this many RGB source images along channels.
    num_sources = 2

    # Set up the inputs.
    # How much shall we pad the input images? We'll pad enough so that
    # (a) when we render output images we won't lose stuff at the edges
    # due to cropping, and (b) we can find a multiple of 16 size without
    # cropping into the original images.
    render_list = []
    max_multiple = 0
    if FLAGS.render:
        render_list = [float(x) for x in FLAGS.render_multiples.split(',')]
        max_multiple = max(abs(m) for m in render_list)
    pady = int(max_multiple * abs(FLAGS.yshift) + 8)
    padx = int(max_multiple * abs(FLAGS.xshift) + 8)

    # %g rather than %d: max_multiple is a float and %d would truncate it.
    print('Padding inputs: padx=%d, pady=%d (max_multiple=%g)' %
          (padx, pady, max_multiple))
    inputs, original_width, original_height = get_inputs(padx, pady)

    # MPI code requires images of known size. So we run the input part of the
    # graph now to find the size, which we can then set on the inputs.
    with tf.Session() as sess:
        dimensions, original_width, original_height = sess.run(
            [tf.shape(inputs['ref_image']), original_width, original_height])
    assert dimensions[0] == batch
    mpi_height = dimensions[1]
    mpi_width = dimensions[2]
    assert dimensions[3] == channels

    print('Original size: width=%d, height=%d' % (original_width,
                                                  original_height))
    print('     MPI size: width=%d, height=%d' % (mpi_width, mpi_height))

    inputs['ref_image'].set_shape([batch, mpi_height, mpi_width, channels])
    inputs['src_images'].set_shape(
        [batch, mpi_height, mpi_width, channels * num_sources])

    # Build the MPI.
    model = MPI()
    psv_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_psv_planes)
    mpi_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_mpi_planes)
    outputs = model.infer_mpi(inputs['src_images'], inputs['ref_image'],
                              inputs['ref_pose'], inputs['src_poses'],
                              inputs['intrinsics'], FLAGS.which_color_pred,
                              FLAGS.num_mpi_planes, psv_planes,
                              FLAGS.test_outputs)

    saver = tf.train.Saver(list(tf.model_variables()))
    ckpt_dir = os.path.join(FLAGS.model_root, FLAGS.model_name)
    ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_file is None:
        # Fail early with a clear message instead of inside saver.restore().
        raise ValueError('No checkpoint found in %s' % ckpt_dir)
    sv = tf.train.Supervisor(logdir=ckpt_dir, saver=None)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    print('Inferring MPI...')
    with sv.managed_session(config=config) as sess:
        saver.restore(sess, ckpt_file)
        ins, outs = sess.run([inputs, outputs])

    # Render output images in a fresh graph so as not to run out of memory.
    tf.reset_default_graph()
    renders = {}
    if FLAGS.render:
        print('Rendering new views...')
        # Embed the inferred MPI in the graph ONCE. Creating these constants
        # inside the loop would add another full copy of the MPI to the
        # (never reset) graph for every rendered view.
        rgba_layers = tf.constant(outs['rgba_layers'])
        intrinsics = tf.constant(ins['intrinsics'])
        for index, multiple in enumerate(render_list):
            print('    offset: %s' % multiple)
            pose = build_matrix([[1.0, 0.0, 0.0, -multiple * FLAGS.xoffset],
                                 [0.0, 1.0, 0.0, -multiple * FLAGS.yoffset],
                                 [0.0, 0.0, 1.0, -multiple * FLAGS.zoffset],
                                 [0.0, 0.0, 0.0, 1.0]])[tf.newaxis, ...]
            image = model.deprocess_image(
                model.mpi_render_view(rgba_layers, pose, mpi_planes,
                                      intrinsics))[0]
            unshifted = shift_image(image, multiple * FLAGS.xshift,
                                    multiple * FLAGS.yshift)
            cropped = crop_to_size(unshifted, original_width, original_height)

            with tf.Session() as sess:
                renders[multiple] = (index, sess.run(cropped))

    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    def out_path(filename):
        """Returns the path of `filename` inside the output directory."""
        return os.path.join(output_dir, filename)

    print('Saving results to %s' % output_dir)

    # Write results to disk.
    for name, (index, image) in renders.items():
        write_image(out_path('render_%02d_%s.png' % (index, name)), image)

    if 'intrinsics' in FLAGS.test_outputs:
        with open(out_path('intrinsics.txt'), 'w') as fh:
            write_intrinsics(fh, ins['intrinsics'][0])

    if 'src_images' in FLAGS.test_outputs:
        for i in range(num_sources):
            src_rgb = ins['src_images'][0, :, :, i * channels:(i + 1) * channels]
            write_image(out_path('src_image_%d.png' % i), src_rgb * 255.0)
            if 'poses' in FLAGS.test_outputs:
                write_pose(out_path('src_pose_%d.txt' % i),
                           ins['src_poses'][0, i])

    if 'fgbg' in FLAGS.test_outputs:
        write_image(out_path('foreground_color.png'), outs['fg_image'][0])
        write_image(out_path('background_color.png'), outs['bg_image'][0])

    if 'blend_weights' in FLAGS.test_outputs:
        for i in range(FLAGS.num_mpi_planes):
            weight_img = outs['blend_weights'][0, :, :, i] * 255.0
            write_image(out_path('foreground_weight_plane_%.3d.png' % i),
                        weight_img)

    if 'psv' in FLAGS.test_outputs:
        for j in range(FLAGS.num_psv_planes):
            # PSV values are mapped from [-1, 1] to [0, 255] for saving.
            plane_img = (outs['psv'][0, :, :, j * 3:(j + 1) * 3] +
                         1.) / 2. * 255
            write_image(out_path('psv_plane_%.3d.png' % j), plane_img)

    if 'rgba_layers' in FLAGS.test_outputs:
        for i in range(FLAGS.num_mpi_planes):
            alpha_img = outs['rgba_layers'][0, :, :, i, 3] * 255.0
            rgb_img = (outs['rgba_layers'][0, :, :, i, :3] + 1.) / 2. * 255
            write_image(out_path('mpi_alpha_%.2d.png' % i), alpha_img)
            write_image(out_path('mpi_rgb_%.2d.png' % i), rgb_img)

    with open(out_path('README'), 'w') as fh:
        fh.write(
            'This directory was generated by mpi_from_images. Command-line:\n\n'
        )
        # One argument per line with shell line continuations. Joining avoids
        # writing argv[0] twice when no extra arguments were given.
        fh.write(' \\\n  '.join(sys.argv) + '\n')

    print('Done.')
Exemplo n.º 2
0
def main(_):
    """Runs a trained MPI model on test sequences and writes per-example outputs.

    For each of ``FLAGS.num_runs`` batches sampled from ``SequenceDataLoader``,
    the function:
      1. infers an MPI (and, if 'tgt_image' is requested, renders the target
         view from the pose relative to the reference);
      2. writes the outputs named in ``FLAGS.test_outputs`` to
         ``FLAGS.output_root/FLAGS.model_name/FLAGS.data_split/<dirname>``.
         Here ``<dirname>`` is ``[scene]_[src ids...]_[tgt id]``.

    Args:
      _: Unused positional argument (presumably argv from the app runner).

    Raises:
      ValueError: If FLAGS.batch_size is not 1, or no checkpoint is found.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if FLAGS.batch_size != 1:
        raise ValueError('Currently, batch_size must be 1 when testing.')

    tf.logging.set_verbosity(tf.logging.INFO)
    tf.reset_default_graph()
    tf.set_random_seed(FLAGS.random_seed)

    # Set up data loader.
    data_loader = SequenceDataLoader(FLAGS.cameras_glob, FLAGS.image_dir,
                                     False, FLAGS.num_source,
                                     FLAGS.shuffle_seq_length,
                                     FLAGS.random_seed)

    inputs = data_loader.sample_batch()
    model = MPI()
    psv_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_psv_planes)
    mpi_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_mpi_planes)
    outputs = model.infer_mpi(inputs['src_images'], inputs['ref_image'],
                              inputs['ref_pose'], inputs['src_poses'],
                              inputs['intrinsics'], FLAGS.which_color_pred,
                              FLAGS.num_mpi_planes, psv_planes,
                              FLAGS.test_outputs)

    if 'tgt_image' in FLAGS.test_outputs:
        # Target pose expressed relative to the reference camera.
        rel_pose = tf.matmul(inputs['tgt_pose'],
                             tf.matrix_inverse(inputs['ref_pose']))
        outputs['output_image'] = model.mpi_render_view(
            outputs['rgba_layers'], rel_pose, mpi_planes, inputs['intrinsics'])
        outputs['output_image'] = model.deprocess_image(
            outputs['output_image'])

    saver = tf.train.Saver(list(tf.model_variables()))
    ckpt_dir = os.path.join(FLAGS.model_root, FLAGS.model_name)
    ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_file is None:
        # Fail early with a clear message instead of inside saver.restore().
        raise ValueError('No checkpoint found in %s' % ckpt_dir)
    sv = tf.train.Supervisor(logdir=ckpt_dir, saver=None)
    config = tf.ConfigProto()

    def frame_id(path):
        """Returns the last '_'-separated token of `path`'s base name, sans extension."""
        return os.path.basename(path).split('.')[0].split('_')[-1]

    with sv.managed_session(config=config) as sess:
        saver.restore(sess, ckpt_file)
        for run in range(FLAGS.num_runs):
            tf.logging.info('Progress: %d/%d', run, FLAGS.num_runs)
            ins, outs = sess.run([inputs, outputs])

            # Output directory name: [scene]_[1st src file]_[2nd src file]_[tgt file].
            name_parts = [ins['ref_name'][0].split('/')[0]]
            name_parts.extend(
                frame_id(ins['src_timestamps'][0][i])
                for i in range(FLAGS.num_source))
            name_parts.append(frame_id(ins['tgt_timestamp'][0]))
            dirname = '_'.join(name_parts)

            output_dir = os.path.join(FLAGS.output_root, FLAGS.model_name,
                                      FLAGS.data_split, dirname)
            if not tf.gfile.IsDirectory(output_dir):
                tf.gfile.MakeDirs(output_dir)

            # Write results to disk.
            if 'intrinsics' in FLAGS.test_outputs:
                with open(os.path.join(output_dir, 'intrinsics.txt'), 'w') as fh:
                    write_intrinsics(fh, ins['intrinsics'][0])

            if 'src_images' in FLAGS.test_outputs:
                for i in range(FLAGS.num_source):
                    timestamp = ins['src_timestamps'][0][i]
                    write_image(
                        os.path.join(output_dir,
                                     'src_image_%d_%s.png' % (i, timestamp)),
                        ins['src_images'][0, :, :, i * 3:(i + 1) * 3] * 255.0)
                    if 'poses' in FLAGS.test_outputs:
                        write_pose(
                            os.path.join(output_dir, 'src_pose_%d.txt' % i),
                            ins['src_poses'][0, i])

            if 'tgt_image' in FLAGS.test_outputs:
                timestamp = ins['tgt_timestamp'][0]
                write_image(
                    os.path.join(output_dir, 'tgt_image_%s.png' % timestamp),
                    ins['tgt_image'][0] * 255.0)
                write_image(
                    os.path.join(output_dir, 'output_image_%s.png' % timestamp),
                    outs['output_image'][0])
                if 'poses' in FLAGS.test_outputs:
                    write_pose(os.path.join(output_dir, 'tgt_pose.txt'),
                               ins['tgt_pose'][0])

            if 'fgbg' in FLAGS.test_outputs:
                write_image(os.path.join(output_dir, 'foreground_color.png'),
                            outs['fg_image'][0])
                write_image(os.path.join(output_dir, 'background_color.png'),
                            outs['bg_image'][0])

            if 'blend_weights' in FLAGS.test_outputs:
                for i in range(FLAGS.num_mpi_planes):
                    weight_img = outs['blend_weights'][0, :, :, i] * 255.0
                    write_image(
                        os.path.join(output_dir,
                                     'foreground_weight_plane_%.3d.png' % i),
                        weight_img)

            if 'ref_image' in FLAGS.test_outputs:
                fname = os.path.basename(ins['ref_name'][0])
                # Scale like the src/tgt images from the same loader above;
                # writing the raw values would produce a near-black image.
                write_image(os.path.join(output_dir, 'ref_image_%s' % fname),
                            ins['ref_image'][0] * 255.0)
                write_pose(os.path.join(output_dir, 'ref_pose.txt'),
                           ins['ref_pose'][0])

            if 'psv' in FLAGS.test_outputs:
                for j in range(FLAGS.num_psv_planes):
                    # PSV values are mapped from [-1, 1] to [0, 255] for saving.
                    plane_img = (outs['psv'][0, :, :, j * 3:(j + 1) * 3] +
                                 1.) / 2. * 255
                    write_image(
                        os.path.join(output_dir, 'psv_plane_%.3d.png' % j),
                        plane_img)

            if 'rgba_layers' in FLAGS.test_outputs:
                for i in range(FLAGS.num_mpi_planes):
                    alpha_img = outs['rgba_layers'][0, :, :, i, 3] * 255.0
                    rgb_img = (outs['rgba_layers'][0, :, :, i, :3] +
                               1.) / 2. * 255
                    write_image(
                        os.path.join(output_dir, 'mpi_alpha_%.2d.png' % i),
                        alpha_img)
                    write_image(
                        os.path.join(output_dir, 'mpi_rgb_%.2d.png' % i),
                        rgb_img)