Esempio n. 1
0
def train(net):
    net.gt_depth = tf.placeholder(tf.float32, net.depth_tensor_shape)
    net.pred_depth = net.depth_out
    out_shape = tf_static_shape(net.pred_depth)
    net.depth_loss = loss_l1(
        net.pred_depth, repeat_tensor(net.gt_depth, out_shape[1], rep_dim=1))

    _t_dbg = Timer()

    # Add optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    decay_lr = tf.train.exponential_decay(args.lr,
                                          global_step,
                                          args.decay_steps,
                                          args.decay_rate,
                                          staircase=True)
    lr_sum = tf.summary.scalar('lr', decay_lr)
    optim = tf.train.AdamOptimizer(decay_lr).minimize(net.depth_loss,
                                                      global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Add summaries for training
    net.loss_sum = tf.summary.scalar('loss', net.depth_loss)
    net.im_sum = image_sum(net.ims, net.batch_size, net.im_batch)
    net.depth_gt_sum = depth_sum(net.gt_depth, net.batch_size, net.im_batch,
                                 'depth_gt')
    net.depth_pred_sum = depth_sum(net.pred_depth[:, -1, ...], net.batch_size,
                                   net.im_batch,
                                   'depth_pred_{:d}'.format(net.im_batch))
    merged_ims = tf.summary.merge(
        [net.im_sum, net.depth_gt_sum, net.depth_pred_sum])
    merged_scalars = tf.summary.merge([net.loss_sum, lr_sum])

    # Initialize dataset
    coord = tf.train.Coordinator()
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.split_file,
                    rng_seed=0,
                    custom_db=args.custom_training)
    mids = dset.get_smids('train')
    logger.info('Training with %d models', len(mids))
    items = ['im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    net.im_batch,
                    items,
                    coord,
                    qsize=64,
                    nthreads=args.prefetch_threads)

    _t_dbg = Timer()
    iters = 0
    # Training loop
    pbar = tqdm(desc='Training Depth-LSM', total=args.niters)
    with tf.Session(config=get_session_config()) as sess:
        sum_writer = tf.summary.FileWriter(log_dir, sess.graph)
        if args.ckpt is not None:
            logger.info('Restoring from %s', args.ckpt)
            saver.restore(sess, args.ckpt)
        else:
            sess.run(init_op)
        try:
            while True:
                iters += 1
                _t_dbg.tic()
                batch_data = dset.next_batch(items, net.batch_size)
                logging.debug('Data read time - %.3fs', _t_dbg.toc())
                feed_dict = {
                    net.ims: batch_data['im'],
                    net.K: batch_data['K'],
                    net.Rcam: batch_data['R'],
                    net.gt_depth: batch_data['depth']
                }
                if args.run_trace and (iters % args.sum_iters == 0
                                       or iters == 1 or iters == args.niters):
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                    sum_writer.add_run_metadata(run_metadata, 'step%d' % step_)
                else:
                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict)

                logging.debug('Net time - %.3fs', _t_dbg.toc())

                sum_writer.add_summary(merged_scalars_, step_)
                if iters % args.sum_iters == 0 or iters == 1 or iters == args.niters:
                    image_sum_, step_ = sess.run([merged_ims, global_step],
                                                 feed_dict=feed_dict)
                    sum_writer.add_summary(image_sum_, step_)

                if iters % args.ckpt_iters == 0 or iters == args.niters:
                    save_f = saver.save(sess,
                                        osp.join(log_dir, 'mvnet'),
                                        global_step=global_step)
                    logger.info(' Model checkpoint - {:s} '.format(save_f))

                pbar.update(1)
                if iters >= args.niters:
                    break
        except Exception, e:
            logging.error(repr(e))
            dset.close_queue(e)
        finally:
Esempio n. 2
0
def main(config_file):
    """ Train text recognition network

    Builds one train tower and one validation tower per available GPU,
    averages the tower gradients, and runs the training loop inside a
    ``tf.train.MonitoredTrainingSession``. Every ``FLAGS.valid.steps`` steps
    the sequence error of the first train tower's batch is computed,
    ``validate`` is run on the validation towers, and a progress report is
    logged.

    Args:
        config_file: Path of the configuration file parsed by ``Flags``.

    Raises:
        RuntimeError: If ``count_available_gpus`` reports no GPU (the batch
            is split as ``batch_size // num_gpus`` across GPU towers).
    """
    # Parse configs
    FLAGS = Flags(config_file).get()

    # Set directory, seed, logger
    model_dir = create_model_dir(FLAGS.model_dir)
    logger = get_logger(model_dir, 'train')
    best_model_dir = os.path.join(model_dir, 'best_models')
    set_seed(FLAGS.seed)

    # Print configs
    flag_strs = [
        '{}:\t{}'.format(name, value)
        for name, value in FLAGS._asdict().items()
    ]
    log_formatted(logger, '[+] Model configurations', *flag_strs)

    # Print system environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    mem_size = virtual_memory().available // (1024**3)
    log_formatted(logger, '[+] System environments',
                  'The number of gpus : {}'.format(num_gpus),
                  'The number of cpus : {}'.format(num_cpus),
                  'Memory Size : {}G'.format(mem_size))
    if num_gpus < 1:
        # The batch is split across one tower per GPU below
        # (batch_size // num_gpus); fail with a clear message instead of a
        # ZeroDivisionError.
        raise RuntimeError('Training requires at least one available GPU')

    # Get optimizer and network
    global_step = tf.train.get_or_create_global_step()
    optimizer, learning_rate = get_optimizer(FLAGS.train.optimizer,
                                             global_step)
    out_charset = load_charset(FLAGS.charset)
    net = get_network(FLAGS, out_charset)
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Multi tower for multi-gpu training
    tower_grads = []
    tower_extra_update_ops = []
    tower_preds = []
    tower_gts = []
    tower_losses = []
    batch_size = FLAGS.train.batch_size
    tower_batch_size = batch_size // num_gpus

    val_tower_outputs = []

    for gpu_indx in range(num_gpus):

        # Train tower
        print('[+] Build Train tower GPU:%d' % gpu_indx)
        input_device = '/gpu:%d' % gpu_indx

        # Every tower gets batch_size // num_gpus samples except the last
        # one, which also takes the remainder so the towers sum to
        # batch_size.
        tower_batch_size = tower_batch_size \
            if gpu_indx < num_gpus-1 \
            else batch_size - tower_batch_size * (num_gpus-1)

        train_loader = DatasetLodaer(
            dataset_paths=FLAGS.train.dataset_paths,
            dataset_portions=FLAGS.train.dataset_portions,
            batch_size=tower_batch_size,
            label_maxlen=FLAGS.label_maxlen,
            out_charset=out_charset,
            preprocess_image=net.preprocess_image,
            is_train=True,
            is_ctc=is_ctc,
            shuffle_and_repeat=True,
            concat_batch=True,
            input_device=input_device,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            worker_index=gpu_indx,
            use_rgb=FLAGS.use_rgb,
            seed=FLAGS.seed,
            name='train')

        tower_output = single_tower(net,
                                    gpu_indx,
                                    train_loader,
                                    out_charset,
                                    optimizer,
                                    name='train',
                                    is_train=True)
        # Keep only entries whose first element (presumably the gradient of
        # a (grad, var) pair) is not None.
        tower_grads.append([x for x in tower_output.grads if x[0] is not None])
        tower_extra_update_ops.append(tower_output.extra_update_ops)
        tower_preds.append(tower_output.prediction)
        tower_gts.append(tower_output.text)
        tower_losses.append(tower_output.loss)

        # Print network structure
        if gpu_indx == 0:
            param_stats = tf.profiler.profile(tf.get_default_graph())
            logger.info('total_params: %d\n' % param_stats.total_parameters)

        # Valid tower
        print('[+] Build Valid tower GPU:%d' % gpu_indx)
        valid_loader = DatasetLodaer(dataset_paths=FLAGS.valid.dataset_paths,
                                     dataset_portions=None,
                                     batch_size=FLAGS.valid.batch_size //
                                     num_gpus,
                                     label_maxlen=FLAGS.label_maxlen,
                                     out_charset=out_charset,
                                     preprocess_image=net.preprocess_image,
                                     is_train=False,
                                     is_ctc=is_ctc,
                                     shuffle_and_repeat=False,
                                     concat_batch=False,
                                     input_device=input_device,
                                     num_cpus=num_cpus,
                                     num_gpus=num_gpus,
                                     worker_index=gpu_indx,
                                     use_rgb=FLAGS.use_rgb,
                                     seed=FLAGS.seed,
                                     name='valid')

        val_tower_output = single_tower(net,
                                        gpu_indx,
                                        valid_loader,
                                        out_charset,
                                        optimizer=None,
                                        name='valid',
                                        is_train=False)

        val_tower_outputs.append(
            (val_tower_output.loss, val_tower_output.prediction,
             val_tower_output.text, val_tower_output.filename,
             val_tower_output.dataset))

    # Aggregate gradients
    losses = tf.reduce_mean(tower_losses)
    grads = _average_gradients(tower_grads)

    # NOTE(review): only the last tower's extra update ops gate the train op.
    with tf.control_dependencies(tower_extra_update_ops[-1]):
        if FLAGS.train.optimizer.grad_clip_norm is not None:
            grads, global_norm = _clip_gradients(
                grads, FLAGS.train.optimizer.grad_clip_norm)
            tf.summary.scalar('global_norm', global_norm)

        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    # Define config, scaffold
    saver = tf.train.Saver()
    sess_config = get_session_config()
    scaffold = get_scaffold(saver, FLAGS.train.tune_from, 'train')
    # NOTE(review): restore_model is not used below; the call is kept in
    # case get_init_trained() has side effects on the graph.
    restore_model = get_init_trained()

    # Define validation saver, summary writer
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    val_summary_op = tf.summary.merge(
        [s for s in summaries if 'valid' in s.name])
    val_summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(model_dir, 'valid', dataset_name))
        for dataset_name in valid_loader.dataset_names
    }
    val_summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(model_dir, 'valid', 'total_valid'))
    val_saver = tf.train.Saver(max_to_keep=len(valid_loader.dataset_names) + 1)
    best_val_err_rates = {}
    best_steps = {}

    # Training
    print('[+] Make Session...')

    try:
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=model_dir,
                scaffold=scaffold,
                config=sess_config,
                save_checkpoint_steps=FLAGS.train.save_steps,
                save_checkpoint_secs=None,
                save_summaries_steps=FLAGS.train.summary_steps,
                save_summaries_secs=None,
        ) as sess:

            log_formatted(logger, 'Training started!')
            _step = 0
            train_t = 0
            start_t = time.time()

            while _step < FLAGS.train.max_num_steps \
                    and not sess.should_stop():

                # Train step
                step_t = time.time()
                [step_loss, _, _step, preds, gts, lr] = sess.run([
                    losses, train_op, global_step, tower_preds[0],
                    tower_gts[0], learning_rate
                ])
                train_t += time.time() - step_t

                # Summary
                if _step % FLAGS.valid.steps == 0:

                    # Train summary: sequence-level error on the first
                    # tower's batch; a sample counts as wrong unless the
                    # adjusted prediction equals the adjusted label.
                    train_err = 0.

                    for i, (p, g) in enumerate(zip(preds, gts)):
                        s = get_string(p, out_charset, is_ctc=is_ctc)
                        g = g.decode('utf8').replace(DELIMITER, '')

                        s = adjust_string(s, FLAGS.train.lowercase,
                                          FLAGS.train.alphanumeric)
                        g = adjust_string(g, FLAGS.train.lowercase,
                                          FLAGS.train.alphanumeric)
                        e = int(s != g)

                        train_err += e

                        if FLAGS.train.verbose and i < 5:
                            print('TRAIN :\t{}\t{}\t{}'.format(
                                s, g, not bool(e)))

                    train_err_rate = \
                        train_err / len(gts)

                    # Valid summary
                    val_cnts, val_errs, val_err_rates, _ = \
                        validate(sess,
                                 _step,
                                 val_tower_outputs,
                                 out_charset,
                                 is_ctc,
                                 val_summary_op,
                                 val_summary_writer,
                                 val_saver,
                                 best_val_err_rates,
                                 best_steps,
                                 best_model_dir,
                                 FLAGS.valid.lowercase,
                                 FLAGS.valid.alphanumeric)

                    # Logging
                    log_strings = [
                        '', '-' * 28 + ' VALID_DETAIL ' + '-' * 28, ''
                    ]

                    for dataset in sorted(val_err_rates.keys()):
                        if dataset == 'total_valid':
                            continue

                        cnt = val_cnts[dataset]
                        err = val_errs[dataset]
                        err_rate = val_err_rates[dataset]
                        best_step = best_steps[dataset]

                        s = '%s : %.2f%%(%d/%d)\tBEST_STEP : %d' % \
                            (dataset, (1.-err_rate)*100, cnt-err, cnt,
                             best_step)

                        log_strings.append(s)

                    elapsed_t = float(time.time() - start_t) / 60
                    remain_t = (elapsed_t / (_step+1)) * \
                        (FLAGS.train.max_num_steps - _step - 1)
                    log_formatted(
                        logger,
                        'STEP : %d\tTRAIN_LOSS : %f' % (_step, step_loss),
                        'ELAPSED : %.2f min\tREMAIN : %.2f min\t'
                        'STEP_TIME: %.1f sec' %
                        (elapsed_t, remain_t, float(train_t) / (_step + 1)),
                        'TRAIN_SEQ_ERR : %f\tVALID_SEQ_ERR : %f' %
                        (train_err_rate, val_err_rates['total_valid']),
                        'BEST_STEP : %d\tBEST_VALID_SEQ_ERR : %f' %
                        (best_steps['total_valid'],
                         best_val_err_rates['total_valid']), *log_strings)

            log_formatted(logger, 'Training is completed!')
    finally:
        # The validation FileWriters are created outside the monitored
        # session and nothing in this function closes them otherwise;
        # closing flushes any buffered events to disk.
        for writer in val_summary_writer.values():
            writer.close()
Esempio n. 3
0
def validate(args, checkpoint):
    """Evaluate a voxel LSM checkpoint on the ShapeNet validation split.

    Builds an ``MVNet`` in TEST mode, wraps it with ``model_vlsm``, restores
    ``checkpoint`` into a new session, then runs one epoch (``nepochs=1``)
    of validation models through the network and accumulates IoU scores
    via ``update_iou``.

    Args:
        args: Parsed options. This function reads the validation batch
            sizes, ``nvox``, image size, ``norm``, the ``im_net`` /
            ``grid_net`` / ``rnn`` selections, ``val_split_file``,
            ``eval_thresh`` and ``prefetch_threads``.
        checkpoint: Checkpoint path passed to ``tf.train.Saver.restore``.

    NOTE(review): ``iou``, ``deq_mids`` and ``deq_sids`` are accumulated but
    are neither returned nor otherwise used in this function.
    """
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM
    vox_dir = SHAPENET_VOX[args.nvox]

    # Setup network
    net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                     conv_rnns[args.rnn])
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init IoU
    iou = init_iou(net.im_batch, args.eval_thresh)

    # Init dataset: a single pass over the validation models.
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.val_split_file,
                    vox_dir=vox_dir,
                    rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Testing %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'vol']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    deq_mids, deq_sids = [], []
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            deq_sids.append(batch_data['shape_id'])
            deq_mids.append(batch_data['model_id'])
            # Remember how many real items the batch holds, pad it up to
            # args.val_batch_size, and score only the first
            # num_batch_items predictions.
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)

            feed_dict = {net.K: batch_data['K'], net.Rcam: batch_data['R']}
            feed_dict[net.ims] = batch_data['im']

            pred = sess.run(net.prob_vox, feed_dict=feed_dict)
            batch_iou = eval_seq_iou(pred[:num_batch_items],
                                     batch_data['vol'][:num_batch_items],
                                     args.val_im_batch,
                                     thresh=args.eval_thresh)

            # Update iou dict
            iou = update_iou(batch_iou, iou)
            pbar.update(num_batch_items)
    except Exception as e:
        # Use the `logging` module as above: in this file `logger` is only
        # bound under `if __name__ == '__main__'`, and a NameError here
        # would skip close_queue. The error is logged, not re-raised.
        logging.error(repr(e))
        dset.close_queue(e)
    finally:
        pbar.close()
        sess.close()
Esempio n. 4
0
    return net


if __name__ == '__main__':
    # Timestamp of this run; names the log directory of a fresh training run.
    key = time.strftime("%Y-%m-%d_%H%M%S")
    args = parse_args()
    init_logging(args.loglevel)
    logger = logging.getLogger('silhonet.%s' % __name__)
    logger.setLevel(logging.DEBUG)

    # Expose only the requested GPUs, enumerated in PCI bus order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # Create session
    sess = tf.Session(config=get_session_config())

    # Derive run mode and network mode from tokens contained in args.mode.
    # Within each group the first listed token that matches wins; the
    # attribute is left untouched when none match.
    for _mode_token in ('train', 'test'):
        if _mode_token in args.mode:
            args.run_mode = _mode_token
            break
    for _mode_token in ('seg', 'quat'):
        if _mode_token in args.mode:
            args.net_mode = _mode_token
            break

    if args.run_mode == 'train':
        # A fresh run logs to <logdir>/<timestamp>/train; resuming from a
        # checkpoint logs straight into args.logdir.
        args.log_dir = (args.logdir if args.ckpt is not None
                        else osp.join(args.logdir, key, 'train'))
Esempio n. 5
0
def validate(args, checkpoint):
    """Evaluate a depth LSM checkpoint on the ShapeNet validation split.

    Builds an ``MVNet`` in TEST mode, wraps it with ``model_dlsm``, restores
    ``checkpoint`` into a new session, then runs one epoch (``nepochs=1``)
    of validation models through the network, extending ``l1_err`` with
    the values ``eval_l1_err`` returns for each batch.

    Args:
        args: Parsed options. This function reads the validation batch
            sizes, ``nvox``, image size, ``norm``, the ``im_net`` /
            ``grid_net`` / ``rnn`` selections, ``im_skip``, ``ray_samples``,
            ``sepup``, ``proj_x``, ``val_split_file`` and
            ``prefetch_threads``.
        checkpoint: Checkpoint path passed to ``tf.train.Saver.restore``.

    NOTE(review): ``l1_err``, ``deq_mids`` and ``deq_sids`` are accumulated
    but are neither returned nor otherwise used in this function.
    """
    net = MVNet(vmin=-0.5,
                vmax=0.5,
                vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch,
                grid_size=args.nvox,
                im_h=args.im_h,
                im_w=args.im_w,
                mode="TEST",
                norm=args.norm)

    im_dir = SHAPENET_IM

    # Setup network
    net = model_dlsm(net,
                     im_nets[args.im_net],
                     grid_nets[args.grid_net],
                     conv_rnns[args.rnn],
                     im_skip=args.im_skip,
                     ray_samples=args.ray_samples,
                     sepup=args.sepup,
                     proj_x=args.proj_x,
                     proj_last=True)
    sess = tf.Session(config=get_session_config())
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)
    coord = tf.train.Coordinator()

    # Init dataset: a single pass over the validation models.
    dset = ShapeNet(im_dir=im_dir, split_file=args.val_split_file, rng_seed=1)
    mids = dset.get_smids('val')
    logging.info('Validating %d models', len(mids))
    items = ['shape_id', 'model_id', 'im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    args.val_im_batch,
                    items,
                    coord,
                    nepochs=1,
                    qsize=32,
                    nthreads=args.prefetch_threads)

    # Init stats
    l1_err = []

    # Testing loop
    pbar = tqdm(desc='Validating', total=len(mids))
    deq_mids, deq_sids = [], []
    try:
        while not coord.should_stop():
            batch_data = dset.next_batch(items, net.batch_size)
            if batch_data is None:
                continue
            deq_sids.append(batch_data['shape_id'])
            deq_mids.append(batch_data['model_id'])
            # Remember how many real items the batch holds, pad it up to
            # args.val_batch_size, and score only the first
            # num_batch_items predictions.
            num_batch_items = batch_data['K'].shape[0]
            batch_data = pad_batch(batch_data, args.val_batch_size)
            feed_dict = {
                net.K: batch_data['K'],
                net.Rcam: batch_data['R'],
                net.ims: batch_data['im']
            }
            pred = sess.run(net.depth_out, feed_dict=feed_dict)
            batch_err = eval_l1_err(pred[:num_batch_items],
                                    batch_data['depth'][:num_batch_items])

            l1_err.extend(batch_err)
            pbar.update(num_batch_items)
    except Exception as e:
        # Use the `logging` module as above: in this file `logger` is only
        # bound under `if __name__ == '__main__'`, and a NameError here
        # would skip close_queue. The error is logged, not re-raised.
        logging.error(repr(e))
        dset.close_queue(e)
    finally:
        pbar.close()
        sess.close()
Esempio n. 6
0
def main(config_file=None):
    """ Run evaluation.

    Builds one evaluation tower per available GPU, restores the model from
    ``FLAGS.eval.model_path``, runs ``validate`` once over the evaluation
    datasets, writes one ``<dataset>.txt`` prediction file per dataset into
    the results directory, and prints per-dataset accuracy and the average
    inference time per image.

    Args:
        config_file: Path of the configuration file parsed by ``Flags``.

    Raises:
        RuntimeError: If ``count_available_gpus`` reports no GPU (no
            evaluation tower, and hence no ``eval_loader``, would be built).
    """
    # Parse Config
    print('[+] Model configurations')
    FLAGS = Flags(config_file).get()
    for name, value in FLAGS._asdict().items():
        print('{}:\t{}'.format(name, value))
    print('\n')

    # System environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    mem_size = virtual_memory().available // (1024**3)
    out_charset = load_charset(FLAGS.charset)
    print('[+] System environments')
    print('The number of gpus : {}'.format(num_gpus))
    print('The number of cpus : {}'.format(num_cpus))
    print('Memory Size : {}G'.format(mem_size))
    print('The number of characters : {}\n'.format(len(out_charset)))
    if num_gpus < 1:
        # eval_loader is used after the tower loop and is only bound when at
        # least one tower is built.
        raise RuntimeError('Evaluation requires at least one available GPU')

    # Make results dir (the model path itself; results are written next to
    # the checkpoint files)
    res_dir = os.path.join(FLAGS.eval.model_path)
    os.makedirs(res_dir, exist_ok=True)

    # Get network
    net = get_network(FLAGS, out_charset)
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Define Graph
    eval_tower_outputs = []
    global_step = tf.train.get_or_create_global_step()

    for gpu_indx in range(num_gpus):
        # Get eval dataset
        input_device = '/gpu:%d' % gpu_indx
        print('[+] Build Eval tower GPU:%d' % gpu_indx)

        eval_loader = DatasetLodaer(dataset_paths=FLAGS.eval.dataset_paths,
                                    dataset_portions=None,
                                    batch_size=FLAGS.eval.batch_size,
                                    label_maxlen=FLAGS.label_maxlen,
                                    out_charset=out_charset,
                                    preprocess_image=net.preprocess_image,
                                    is_train=False,
                                    is_ctc=is_ctc,
                                    shuffle_and_repeat=False,
                                    concat_batch=False,
                                    input_device=input_device,
                                    num_cpus=num_cpus,
                                    num_gpus=num_gpus,
                                    worker_index=gpu_indx,
                                    use_rgb=FLAGS.use_rgb,
                                    seed=FLAGS.seed,
                                    name='eval')

        eval_tower_output = single_tower(net,
                                         gpu_indx,
                                         eval_loader,
                                         out_charset,
                                         optimizer=None,
                                         name='eval',
                                         is_train=False)

        eval_tower_outputs.append(
            (eval_tower_output.loss, eval_tower_output.prediction,
             eval_tower_output.text, eval_tower_output.filename,
             eval_tower_output.dataset))

    # Summary
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summary_op = tf.summary.merge(list(summaries))
    summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(res_dir, dataset_name))
        for dataset_name in eval_loader.dataset_names
    }
    # The aggregate writer is stored under 'total_valid' but writes to the
    # 'total_eval' directory.
    summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(res_dir, 'total_eval'))

    # Define config, scaffold, hooks
    saver = tf.train.Saver()
    sess_config = get_session_config()
    restore_model = get_init_trained()
    scaffold = get_scaffold(saver, None, 'eval')

    # Testing
    try:
        with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                               config=sess_config) as sess:

            # Restore and init.
            restore_model(sess, FLAGS.eval.model_path)
            _step = sess.run(global_step)

            # Run test
            start_t = time.time()
            eval_cnts, eval_errs, eval_err_rates, eval_preds = \
                validate(sess,
                         _step,
                         eval_tower_outputs,
                         out_charset,
                         is_ctc,
                         summary_op,
                         summary_writer,
                         lowercase=FLAGS.eval.lowercase,
                         alphanumeric=FLAGS.eval.alphanumeric)
            infer_t = time.time() - start_t
    finally:
        # Nothing else in this function closes the writers; closing flushes
        # any buffered summary events to disk.
        for writer in summary_writer.values():
            writer.close()

    # Log
    total_total = 0

    for dataset, result in eval_preds.items():
        total = eval_cnts[dataset]
        correct = total - eval_errs[dataset]
        acc = 1. - eval_err_rates[dataset]
        total_total += total

        # File names are decoded as UTF-8 below and predictions come from an
        # arbitrary charset, so write UTF-8 explicitly instead of relying on
        # the locale's default encoding.
        res_path = os.path.join(res_dir, '{}.txt'.format(dataset))
        with open(res_path, 'w', encoding='utf-8') as res_file:
            for f, s, g in result:
                f = f.decode('utf8')

                if FLAGS.eval.verbose:
                    print('FILE : ' + f)
                    print('PRED : ' + s)
                    print('ANSW : ' + g)
                    print('=' * 50)

                res_file.write('{}\t{}\n'.format(f, s))

            res_s = 'DATASET : %s\tCORRECT : %d\tTOTAL : %d\tACC : %f' % \
                    (dataset, correct, total, acc)
            print(res_s)
            res_file.write(res_s)

    eval_loader.flush_tmpfile()
    if total_total > 0:
        print('INFER TIME(PER IMAGE) : %f s' % (float(infer_t) / total_total))
    else:
        # Avoid a ZeroDivisionError when no image was evaluated.
        print('INFER TIME(PER IMAGE) : n/a (no images evaluated)')
voxel_resolution = 32
# NOTE(review): voxel_resolution is not used below; the model is built with
# grid_size=args.nvox. Confirm the two are meant to agree.

# Setup TF graph and initialize VLSM model
tf.reset_default_graph()

# Change the ims_per_model to run on different number of views
bs, ims_per_model = 1, 4

ckpt = 'mvnet-100000'
net = MVNet(vmin=-0.5, vmax=0.5, vox_bs=bs,
            im_bs=ims_per_model, grid_size=args.nvox,
            im_h=args.im_h, im_w=args.im_w,
            norm=args.norm, mode="TEST")

net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                 conv_rnns[args.rnn])
# Restore only the global variables under the 'MVNet' scope.
vars_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='MVNet')
sess = tf.InteractiveSession(config=get_session_config())
saver = tf.train.Saver(var_list=vars_restore)
saver.restore(sess, os.path.join(log_dir, ckpt))

from shapenet import ShapeNet
# Read data
dset = ShapeNet(im_dir=im_dir,
                split_file=os.path.join(SAMPLE_DIR, 'splits_sample.json'),
                rng_seed=1)
test_mids = dset.get_smids('test')
# A parenthesized single-argument print gives the same output on Python 2
# and Python 3 (the bare `print x` statement is Python-2-only syntax).
print(test_mids[0])

# Run the last three cells to run on different inputs
rand_sid, rand_mid = random.choice(test_mids)  # Select model to test
# Select views of model to test: net.im_batch distinct render indices.
rand_views = np.random.choice(dset.num_renders, size=(net.im_batch, ),
                              replace=False)
# Alternative: use a fixed set of views, e.g. rand_views = range(5)