Example #1
0
        # args.decay_steps = 10000*3

    key = time.strftime("%Y-%m-%d_%H%M%S")
    init_logging(args.loglevel)
    logger = logging.getLogger('mview3d.' + __name__)

    if args.ckpt is None:
        log_dir = osp.join(args.logdir, key, 'train')
    else:
        log_dir = args.logdir

    mvnet = MVNet(vmin=-0.5,
                  vmax=0.5,
                  vox_bs=args.batch_size,
                  im_bs=args.im_batch,
                  grid_size=args.nvox,
                  im_h=args.im_h,
                  im_w=args.im_w,
                  norm=args.norm,
                  mode="TRAIN")

    # Define graph
    mvnet = model_dlsm(mvnet,
                       im_nets[args.im_net],
                       grid_nets[args.grid_net],
                       conv_rnns[args.rnn],
                       im_skip=args.im_skip,
                       ray_samples=args.ray_samples,
                       sepup=args.sepup,
                       proj_x=args.proj_x)
Example #2
0
def validate(args, checkpoint):
    """Run the depth (DLSM) network once over the validation split.

    Builds an ``MVNet`` in TEST mode wrapped by ``model_dlsm``, restores the
    weights stored at ``checkpoint``, then streams every ``'val'`` model of
    the ShapeNet split through ``net.depth_out`` a single time (``nepochs=1``)
    and collects the L1 depth error of every non-padded batch item.

    Args:
        args: Parsed options. Reads val_batch_size, val_im_batch, nvox, im_h,
            im_w, norm, im_net, grid_net, rnn, im_skip, ray_samples, sepup,
            proj_x, val_split_file and prefetch_threads.
        checkpoint: Checkpoint path handed to ``tf.train.Saver.restore``.

    Returns:
        list: The per-item errors from ``eval_l1_err``, concatenated over all
        processed batches.
    """
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM

    # Setup network
    net = model_dlsm(net,
                     im_nets[args.im_net],
                     grid_nets[args.grid_net],
                     conv_rnns[args.rnn],
                     im_skip=args.im_skip,
                     ray_samples=args.ray_samples,
                     sepup=args.sepup,
                     proj_x=args.proj_x,
                     proj_last=True)
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init dataset (fixed rng_seed, presumably for reproducible validation)
    dset = ShapeNet(im_dir=im_dir, split_file=args.val_split_file, rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Validating %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Init stats
    l1_err = []

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            # Record the real item count before padding to the fixed graph
            # batch size, so padded rows are excluded from the error.
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)
            feed_dict = {
                net.K: batch_data['K'],
                net.Rcam: batch_data['R'],
                net.ims: batch_data['im']
            }
            pred = sess.run(net.depth_out, feed_dict=feed_dict)
            batch_err = eval_l1_err(pred[:num_batch_items],
                                    batch_data['depth'][:num_batch_items])

            l1_err.extend(batch_err)
            pbar.update(num_batch_items)
    except Exception as e:
        # Any failure (possibly also the queue signalling end of data --
        # TODO confirm) is logged and the prefetch queue is shut down.
        # NOTE(review): `logger` is not defined in this function; assumes a
        # module-level logger exists.
        logger.error(repr(e))
        dset.close_queue(e)
    finally:
        # Release the progress bar and the TF session on every exit path;
        # previously both were leaked on each call.
        pbar.close()
        sess.close()
    return l1_err
Example #3
0
def validate(args, checkpoint):
    """Evaluate the voxel (VLSM) network once over the validation split.

    Builds an ``MVNet`` in TEST mode wrapped by ``model_vlsm``, restores the
    weights stored at ``checkpoint``, then streams every ``'val'`` model of
    the ShapeNet split through ``net.prob_vox`` a single time (``nepochs=1``).
    IoU against the ground-truth volumes is accumulated with ``update_iou``.

    Args:
        args: Parsed options. Reads val_batch_size, val_im_batch, nvox, im_h,
            im_w, norm, im_net, grid_net, rnn, eval_thresh, val_split_file
            and prefetch_threads.
        checkpoint: Checkpoint path handed to ``tf.train.Saver.restore``.

    Returns:
        The IoU accumulator built by ``init_iou`` and updated by
        ``update_iou`` after the last processed batch.
    """
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM
    vox_dir = SHAPENET_VOX[args.nvox]

    # Setup network
    net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                     conv_rnns[args.rnn])
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init IoU
    iou = init_iou(net.im_batch, args.eval_thresh)

    # Init dataset (fixed rng_seed, presumably for reproducible validation)
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.val_split_file,
                    vox_dir=vox_dir,
                    rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Testing %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'vol']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            # Record the real item count before padding to the fixed graph
            # batch size, so padded rows are excluded from the IoU.
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)

            feed_dict = {net.K: batch_data['K'], net.Rcam: batch_data['R']}
            feed_dict[net.ims] = batch_data['im']

            pred = sess.run(net.prob_vox, feed_dict=feed_dict)
            batch_iou = eval_seq_iou(pred[:num_batch_items],
                                     batch_data['vol'][:num_batch_items],
                                     args.val_im_batch,
                                     thresh=args.eval_thresh)

            # Update iou dict
            iou = update_iou(batch_iou, iou)
            pbar.update(num_batch_items)
    except Exception as e:
        # Any failure (possibly also the queue signalling end of data --
        # TODO confirm) is logged and the prefetch queue is shut down.
        # NOTE(review): `logger` is not defined in this function; assumes a
        # module-level logger exists.
        logger.error(repr(e))
        dset.close_queue(e)
    finally:
        # Release the progress bar and the TF session on every exit path;
        # previously both were leaked on each call.
        pbar.close()
        sess.close()
    return iou
# Load the training-time options saved as args.json in the release log dir;
# Bunch gives attribute access (args.nvox, args.im_h, ...) used below.
log_dir = os.path.join('models_lsm_v1', 'vlsm-release', 'train')
with open(os.path.join(log_dir, 'args.json'), 'r') as f:
    args = Bunch(json.load(f))

# Set voxel resolution
# NOTE(review): not referenced in this section; the graph below is built
# with args.nvox instead.
voxel_resolution = 32
# Setup TF graph and initialize VLSM model
tf.reset_default_graph()

# Change the ims_per_model to run on different number of views
bs, ims_per_model = 1, 4

ckpt = 'mvnet-100000'
net = MVNet(vmin=-0.5, vmax=0.5, vox_bs=bs,
            im_bs=ims_per_model, grid_size=args.nvox,
            im_h=args.im_h, im_w=args.im_w,
            norm=args.norm, mode="TEST")

net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                 conv_rnns[args.rnn])
# Restore only the global variables created under the 'MVNet' scope.
vars_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='MVNet')
sess = tf.InteractiveSession(config=get_session_config())
saver = tf.train.Saver(var_list=vars_restore)
saver.restore(sess, os.path.join(log_dir, ckpt))

from shapenet import ShapeNet
# Read data
# NOTE(review): im_dir and SAMPLE_DIR are not defined in this section; they
# must be bound earlier in the script/notebook.
dset = ShapeNet(im_dir=im_dir,
                split_file=os.path.join(SAMPLE_DIR, 'splits_sample.json'),
                rng_seed=1)
test_mids = dset.get_smids('test')
# Function-call form prints identically on Python 2 and also runs on Python 3.
print(test_mids[0])