Example #1
def __init__(self,
             batch_size=16,
             smids='train',
             mid=None,
             items=['im', 'pc'],
             split_file='data/all_exp.json',
             nepochs=None):
    vox_dir = SHAPENET_VOX[32]
    im_dir = SHAPENET_IM
    self.coord = tf.train.Coordinator()
    self.dset = ShapeNet(im_dir=im_dir,
                         vox_dir=vox_dir,
                         split_file=split_file)
    mids = self.dset.get_smids(smids)
    if mid is not None:
        mids[0][1] = mid
        mids = mids[0:1]
    # print(mids[0])
    # print(mids.shape)
    print("Mids", len(mids))
    self.items = items
    self.batch_size = batch_size
    self.dset.init_queue(
        mids,
        1,  # maybe this is the number of images per mid per batch
        self.items,
        self.coord,
        qsize=8,
        nthreads=8,
        nepochs=nepochs)
Example #2
def load_samples():
    im_dir = SHAPENET_IM
    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, "data")
    split_file = os.path.join(path, "splits.json")
    dset = ShapeNet(im_dir=im_dir, split_file=split_file, rng_seed=1)

    # smids = 'test' mids='02691156'
    x = dset.get_smids('test')
    x = np.array(x)
    print(x[0][0])  # sid
    print(x[0][1])  # mid
    print(x[1][0])
    print(x[1][1])
    print(x.shape)

    return  # NOTE: the code below is unreachable dead code
    train = np.array(x['train'])
    print(train.shape)
    val = np.array(x['val'])
    print(val.shape)
    test = np.array(x['test'])
    print(test.shape)

    print(val)
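For reference, get_smids returns (shape-id, model-id) pairs. A minimal sketch of walking them with a dset constructed as above and loading a few rendered views through dset.load_func, the same accessors used in the last example on this page; num_views is an assumed illustration parameter:

# Sketch (assumptions noted above): iterate the (sid, mid) pairs from
# get_smids and load a handful of rendered views for the first model.
num_views = 4  # hypothetical number of views to load per model
for sid, mid in dset.get_smids('test'):
    views = np.random.choice(dset.num_renders, size=(num_views,), replace=False)
    ims = dset.load_func['im'](sid, mid, views)  # images for the chosen views
    Rs = dset.load_func['R'](sid, mid, views)    # camera rotations for the same views
    print(sid, mid, ims.shape, Rs.shape)
    break  # only the first model, for illustration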
Example #3
def train(net):
    net.gt_depth = tf.placeholder(tf.float32, net.depth_tensor_shape)
    net.pred_depth = net.depth_out
    out_shape = tf_static_shape(net.pred_depth)
    net.depth_loss = loss_l1(
        net.pred_depth, repeat_tensor(net.gt_depth, out_shape[1], rep_dim=1))

    _t_dbg = Timer()

    # Add optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    decay_lr = tf.train.exponential_decay(args.lr,
                                          global_step,
                                          args.decay_steps,
                                          args.decay_rate,
                                          staircase=True)
    lr_sum = tf.summary.scalar('lr', decay_lr)
    optim = tf.train.AdamOptimizer(decay_lr).minimize(net.depth_loss,
                                                      global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Add summaries for training
    net.loss_sum = tf.summary.scalar('loss', net.depth_loss)
    net.im_sum = image_sum(net.ims, net.batch_size, net.im_batch)
    net.depth_gt_sum = depth_sum(net.gt_depth, net.batch_size, net.im_batch,
                                 'depth_gt')
    net.depth_pred_sum = depth_sum(net.pred_depth[:, -1, ...], net.batch_size,
                                   net.im_batch,
                                   'depth_pred_{:d}'.format(net.im_batch))
    merged_ims = tf.summary.merge(
        [net.im_sum, net.depth_gt_sum, net.depth_pred_sum])
    merged_scalars = tf.summary.merge([net.loss_sum, lr_sum])

    # Initialize dataset
    coord = tf.train.Coordinator()
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.split_file,
                    rng_seed=0,
                    custom_db=args.custom_training)
    mids = dset.get_smids('train')
    logger.info('Training with %d models', len(mids))
    items = ['im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    net.im_batch,
                    items,
                    coord,
                    qsize=64,
                    nthreads=args.prefetch_threads)

    _t_dbg = Timer()
    iters = 0
    # Training loop
    pbar = tqdm(desc='Training Depth-LSM', total=args.niters)
    with tf.Session(config=get_session_config()) as sess:
        sum_writer = tf.summary.FileWriter(log_dir, sess.graph)
        if args.ckpt is not None:
            logger.info('Restoring from %s', args.ckpt)
            saver.restore(sess, args.ckpt)
        else:
            sess.run(init_op)
        try:
            while True:
                iters += 1
                _t_dbg.tic()
                batch_data = dset.next_batch(items, net.batch_size)
                logging.debug('Data read time - %.3fs', _t_dbg.toc())
                feed_dict = {
                    net.ims: batch_data['im'],
                    net.K: batch_data['K'],
                    net.Rcam: batch_data['R'],
                    net.gt_depth: batch_data['depth']
                }
                if args.run_trace and (iters % args.sum_iters == 0
                                       or iters == 1 or iters == args.niters):
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                    sum_writer.add_run_metadata(run_metadata, 'step%d' % step_)
                else:
                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict)

                logging.debug('Net time - %.3fs', _t_dbg.toc())

                sum_writer.add_summary(merged_scalars_, step_)
                if iters % args.sum_iters == 0 or iters == 1 or iters == args.niters:
                    image_sum_, step_ = sess.run([merged_ims, global_step],
                                                 feed_dict=feed_dict)
                    sum_writer.add_summary(image_sum_, step_)

                if iters % args.ckpt_iters == 0 or iters == args.niters:
                    save_f = saver.save(sess,
                                        osp.join(log_dir, 'mvnet'),
                                        global_step=global_step)
                    logger.info(' Model checkpoint - {:s} '.format(save_f))

                pbar.update(1)
                if iters >= args.niters:
                    break
        except Exception as e:
            logging.error(repr(e))
            dset.close_queue(e)
        finally:
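The example above is cut off at the finally clause. A minimal cleanup sketch, assuming the queue/coordinator teardown pattern used in the other examples on this page; this is not the original body of that block:

# Assumed cleanup sketch (not from the original source): stop the prefetch
# queue, join the coordinator threads and close the progress/summary writers.
pbar.close()
sum_writer.close()
dset.close_queue()
coord.join()
logger.info('Training completed')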
Example #4
def validate(args, checkpoint):
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM
    vox_dir = SHAPENET_VOX[args.nvox]

    # Setup network
    net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                     conv_rnns[args.rnn])
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init IoU
    iou = init_iou(net.im_batch, args.eval_thresh)

    # Init dataset
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.val_split_file,
                    vox_dir=vox_dir,
                    rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Testing %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'vol']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    deq_mids, deq_sids = [], []
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            deq_sids.append(batch_data['shape_id'])
            deq_mids.append(batch_data['model_id'])
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)

            feed_dict = {net.K: batch_data['K'], net.Rcam: batch_data['R']}
            feed_dict[net.ims] = batch_data['im']

            pred = sess.run(net.prob_vox, feed_dict=feed_dict)
            batch_iou = eval_seq_iou(pred[:num_batch_items],
                                     batch_data['vol'][:num_batch_items],
                                     args.val_im_batch,
                                     thresh=args.eval_thresh)

            # Update iou dict
            iou = update_iou(batch_iou, iou)
            pbar.update(num_batch_items)
    except Exception as e:
        logger.error(repr(e))
        dset.close_queue(e)
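pad_batch is a helper from the same codebase. A rough illustration of the pad-then-slice idea it enables, using a hypothetical pad_to helper rather than the repository's implementation: the last, partially filled batch is padded up to val_batch_size so the fixed-size graph can still run, and predictions are sliced back to num_batch_items before scoring.

# Hypothetical illustration of the pad-then-slice pattern used above.
import numpy as np

def pad_to(arr, batch_size):
    # Pad axis 0 with zeros so the array has exactly `batch_size` entries.
    pad = batch_size - arr.shape[0]
    if pad <= 0:
        return arr
    return np.concatenate(
        [arr, np.zeros((pad,) + arr.shape[1:], dtype=arr.dtype)], axis=0)

# On the last batch, batch_data['K'] may hold fewer than val_batch_size items:
# padded_K = pad_to(batch_data['K'], args.val_batch_size)
# pred = sess.run(net.prob_vox, feed_dict=...)[:num_batch_items]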
Example #5
class DataSource():
    def __init__(self,
                 batch_size=16,
                 smids='train',
                 mid=None,
                 items=['im', 'pc'],
                 split_file='data/all_exp.json',
                 nepochs=None):
        vox_dir = SHAPENET_VOX[32]
        im_dir = SHAPENET_IM
        self.coord = tf.train.Coordinator()
        self.dset = ShapeNet(im_dir=im_dir,
                             vox_dir=vox_dir,
                             split_file=split_file)
        mids = self.dset.get_smids(smids)
        if mid is not None:
            mids[0][1] = mid
            mids = mids[0:1]
        # print(mids[0])
        # print(mids.shape)
        print("Mids", len(mids))
        self.items = items
        self.batch_size = batch_size
        self.dset.init_queue(
            mids,
            1,  # maybe this is the number of images per mid per batch
            self.items,
            self.coord,
            qsize=8,
            nthreads=8,
            nepochs=nepochs)

    def close(self):
        self.dset.close_queue()
        self.coord.join()

    def next_batch(self):
        try:
            batch_data = self.dset.next_batch(self.items, self.batch_size)
            if batch_data is None:
                return None

            ret = []
            if 'im' in self.items:
                img = batch_data['im']
                s = img.shape
                img = np.reshape(img, (-1, s[2], s[3], s[4]))
                ret.append(img)

            if 'im_128' in self.items:
                img = batch_data['im_128']
                s = img.shape
                img = np.reshape(img, (-1, s[2], s[3], s[4]))
                ret.append(img)

            if 'im_64' in self.items:
                img = batch_data['im_64']
                s = img.shape
                img = np.reshape(img, (-1, s[2], s[3], s[4]))
                ret.append(img)

            if 'gray' in self.items:
                img = batch_data['gray']
                s = img.shape
                img = np.reshape(img, (-1, s[2], s[3], 1))
                ret.append(img)

            if 'gray_128' in self.items:
                img = batch_data['gray_128']
                s = img.shape
                img = np.reshape(img, (-1, s[2], s[3], 1))
                ret.append(img)

            if 'pc' in self.items:
                pc = batch_data['pc']
                s = pc.shape
                pc = np.reshape(pc, (-1, s[2], s[3]))
                ret.append(pc)

            if 'elev' in self.items:
                elev = batch_data['elev']
                s = elev.shape
                elev = np.reshape(elev, (-1, s[1] * s[2]))
                ret.append(elev)

            if 'azim' in self.items:
                azim = batch_data['azim']
                s = azim.shape
                azim = np.reshape(azim, (-1, s[1] * s[2]))
                ret.append(azim)

            if 'view' in self.items:
                # vox = batch_data['vol']
                # s = vox.shape
                # vox = np.reshape(vox, (-1, s[1], s[2], s[3]))
                # return img, pc, vox
                view = batch_data['view']
                s = view.shape
                view = np.reshape(view, (-1, s[1] * s[2]))
                ret.append(view)

            return tuple(ret)
        except Exception as e:
            self.dset.close_queue(e)
            print("Exception: ", e)
        # finally:
        #print("finally")
        # self.coord.join()

    def save_vox(self, vox, path="model.mat"):
        dic = {}
        dic["vox"] = vox
        io.savemat(path, dic)

    def save_img(self, img, option=None, path="image.png"):
        if option is not None:
            plt.imsave(path, img, cmap=option)
        else:
            plt.imsave(path, img)
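A minimal usage sketch for this class, assuming the item names requested here and the epoch-end behaviour of next_batch returning None match the code above:

# Sketch: iterate one epoch of image/point-cloud batches, then shut down.
source = DataSource(batch_size=8, smids='train', items=['im', 'pc'], nepochs=1)
while True:
    batch = source.next_batch()
    if batch is None:       # queue exhausted (or closed after an exception)
        break
    ims, pcs = batch        # tuple order follows the requested `items` list
    print(ims.shape, pcs.shape)
source.close()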
Example #6
def validate(args, checkpoint):
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM

    # Setup network
    net = model_dlsm(net,
                     im_nets[args.im_net],
                     grid_nets[args.grid_net],
                     conv_rnns[args.rnn],
                     im_skip=args.im_skip,
                     ray_samples=args.ray_samples,
                     sepup=args.sepup,
                     proj_x=args.proj_x,
                     proj_last=True)
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init dataset
    dset = ShapeNet(im_dir=im_dir, split_file=args.val_split_file, rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Validating %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Init stats
    l1_err = []

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    deq_mids, deq_sids = [], []
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            deq_sids.append(batch_data['shape_id'])
            deq_mids.append(batch_data['model_id'])
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)
            feed_dict = {
                net.K: batch_data['K'],
                net.Rcam: batch_data['R'],
                net.ims: batch_data['im']
            }
            pred = sess.run(net.depth_out, feed_dict=feed_dict)
            batch_err = eval_l1_err(pred[:num_batch_items],
                                    batch_data['depth'][:num_batch_items])

            l1_err.extend(batch_err)
            pbar.update(num_batch_items)
    except Exception as e:
        logger.error(repr(e))
        dset.close_queue(e)
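Like Example #4, this loop only collects per-item errors. A short sketch of draining the queue and reporting them, assumed rather than part of the original snippet, with numpy imported as np:

# Assumed reporting sketch: l1_err holds one entry per evaluated item.
dset.close_queue()
coord.join()
print('Mean L1 depth error: {:.4f}'.format(np.mean(l1_err)))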
Example #7
net = MVNet(vmin=-0.5, vmax=0.5, vox_bs=bs,
    im_bs=ims_per_model, grid_size=args.nvox,
    im_h=args.im_h, im_w=args.im_w,
    norm=args.norm, mode="TEST")

net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net], conv_rnns[args.rnn])
vars_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='MVNet')
sess = tf.InteractiveSession(config=get_session_config())
saver = tf.train.Saver(var_list=vars_restore)
saver.restore(sess, os.path.join(log_dir, ckpt))
# print(net.im_batch)
# sys.exit()

from shapenet import ShapeNet
# Read data
dset = ShapeNet(im_dir=im_dir, split_file=os.path.join(SAMPLE_DIR, 'splits_sample.json'), rng_seed=1)
test_mids = dset.get_smids('test')
print(test_mids[0])

# Run the last three cells to run on different inputs
rand_sid, rand_mid = random.choice(test_mids) # Select model to test
rand_views = np.random.choice(dset.num_renders, size=(net.im_batch, ), replace=False) # Select views of model to test
#rand_views = range(5)
rand_sid = '03001627'
#rand_mid = '41d9bd662687cf503ca22f17e86bab24'
rand_mid = '53180e91cd6651ab76e29c9c43bc7aa'

# Load images and cameras
ims = dset.load_func['im'](rand_sid, rand_mid, rand_views)
ims = np.expand_dims(ims, 0)
R = dset.load_func['R'](rand_sid, rand_mid, rand_views)
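The snippet stops after loading the rotation matrices. A hedged continuation sketch, assuming a 'K' loader analogous to the 'im' and 'R' ones and the same feed/run pattern as the voxel validation code in Example #4:

# Sketch (assumptions noted above): batch the loaded views and run the network.
R = np.expand_dims(R, 0)
K = dset.load_func['K'](rand_sid, rand_mid, rand_views)
K = np.expand_dims(K, 0)
pred = sess.run(net.prob_vox, feed_dict={net.ims: ims, net.K: K, net.Rcam: R})
print(pred.shape)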