示例#1
0
with tf.Session() as session:

    limage = tf.placeholder(tf.float32, [None, None, None, num_channels],
                            name='limage')
    rimage = tf.placeholder(tf.float32, [None, None, None, num_channels],
                            name='rimage')
    targets = tf.placeholder(tf.float32, [None, FLAGS.disp_range],
                             name='targets')

    snet = nf.create(limage, rimage, targets, FLAGS.net_type)

    lmap = tf.placeholder(tf.float32, [None, None, None, 64], name='lmap')
    rmap = tf.placeholder(tf.float32, [None, None, None, 64], name='rmap')

    map_prod = nf.map_inner_product(lmap, rmap)

    saver = tf.train.Saver()
    saver.restore(session, tf.train.latest_checkpoint(FLAGS.model_dir))

    for i in range(FLAGS.start_id, FLAGS.start_id + FLAGS.num_imgs):
        file_id = file_ids[i]

        if FLAGS.data_version == 'kitti2015':
            linput = misc.imread(
                ('%s/image_2/%06d_10.png') % (FLAGS.data_root, file_id))
            rinput = misc.imread(
                ('%s/image_3/%06d_10.png') % (FLAGS.data_root, file_id))

        elif FLAGS.data_version == 'kitti2012':
            linput = misc.imread(
def main(_):
    """Predict and save a disparity map for each selected stereo pair.

    Builds the network with ``nf.create``, restores the latest checkpoint
    found in ``FLAGS.model_dir``, and for every requested image pair:

    1. reads the left/right images and standardizes each one to zero mean
       and unit standard deviation,
    2. runs the ``lbranch``/``rbranch`` outputs of the network,
    3. scores every candidate disparity with ``nf.map_inner_product``,
    4. writes the per-pixel arg-max disparity (times ``scale_factor``) to
       ``FLAGS.out_dir``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.data_version`` is neither ``'kitti2015'``
            nor ``'kitti2012'``.
    """
    print(FLAGS.util_root)
    np.random.seed(123)

    # The permutation file stores the image ids as little-endian float32.
    file_ids = np.fromfile(os.path.join(FLAGS.util_root, 'myPerm.bin'), '<f4')

    # data_version -> (channel count, left image dir, right image dir)
    dataset_layouts = {
        'kitti2015': (3, 'image_2', 'image_3'),
        'kitti2012': (1, 'image_0', 'image_1'),
    }
    if FLAGS.data_version not in dataset_layouts:
        # Fail here with a clear message instead of a NameError on
        # `num_channels` further down.
        raise ValueError(
            'Unsupported data_version %r; expected one of: %s'
            % (FLAGS.data_version, ', '.join(sorted(dataset_layouts))))
    num_channels, left_dir, right_dir = dataset_layouts[FLAGS.data_version]

    scale_factor = 255 / (FLAGS.disp_range - 1)

    if not os.path.exists(FLAGS.out_dir):
        os.makedirs(FLAGS.out_dir)

    with tf.Session() as session:

        limage = tf.placeholder(tf.float32, [None, None, None, num_channels],
                                name='limage')
        rimage = tf.placeholder(tf.float32, [None, None, None, num_channels],
                                name='rimage')
        targets = tf.placeholder(tf.float32, [None, FLAGS.disp_range],
                                 name='targets')

        snet = nf.create(limage, rimage, targets, FLAGS.net_type)

        # Separate placeholders so the feature maps computed once per image
        # can be re-fed, column-shifted, for every candidate disparity.
        lmap = tf.placeholder(tf.float32, [None, None, None, 64], name='lmap')
        rmap = tf.placeholder(tf.float32, [None, None, None, 64], name='rmap')

        map_prod = nf.map_inner_product(lmap, rmap)

        saver = tf.train.Saver()
        saver.restore(session, tf.train.latest_checkpoint(FLAGS.model_dir))

        for i in range(FLAGS.start_id, FLAGS.start_id + FLAGS.num_imgs):
            # int() truncates exactly as '%06d' would on the float32 value.
            file_id = int(file_ids[i])

            # NOTE(review): if `misc` is scipy.misc, imread/imsave are no
            # longer available in recent SciPy releases -- confirm the
            # pinned version.
            linput = misc.imread(
                '%s/%s/%06d_10.png' % (FLAGS.data_root, left_dir, file_id))
            rinput = misc.imread(
                '%s/%s/%06d_10.png' % (FLAGS.data_root, right_dir, file_id))

            # Per-image standardization.
            linput = (linput - linput.mean()) / linput.std()
            rinput = (rinput - rinput.mean()) / rinput.std()

            # Add the leading batch axis (and the channel axis when the
            # image was read as a 2-D array).
            linput = linput.reshape(1, linput.shape[0], linput.shape[1],
                                    num_channels)
            rinput = rinput.reshape(1, rinput.shape[0], rinput.shape[1],
                                    num_channels)

            test_dict = {
                limage: linput,
                rimage: rinput,
                snet['is_training']: False
            }
            limage_map, rimage_map = session.run(
                [snet['lbranch'], snet['rbranch']], feed_dict=test_dict)

            map_width = limage_map.shape[2]
            unary_vol = np.zeros(
                (limage_map.shape[1], limage_map.shape[2], FLAGS.disp_range))

            # A disparity >= map_width leaves no overlapping columns. It
            # would also make the right-hand slice end negative, which NumPy
            # interprets as "count from the end" and yields mismatched
            # shapes. Those disparities simply keep their zero score.
            num_disps = min(FLAGS.disp_range, map_width)
            for disp in range(num_disps):
                # Left column x is paired with right column x - disp, so the
                # first `disp` left columns have no partner.
                left_cols = limage_map[:, :, disp:map_width, :]
                right_cols = rimage_map[:, :, 0:map_width - disp, :]
                res = session.run(
                    map_prod, feed_dict={lmap: left_cols, rmap: right_cols})

                unary_vol[:, disp:map_width, disp] = res[0, :, :]

            print('Image %s processed.' % (i + 1))
            pred = np.argmax(unary_vol, axis=2) * scale_factor

            misc.imsave('%s/disp_map_%06d_10.png' % (FLAGS.out_dir, file_id),
                        pred)