def validate(args, checkpoint):
    """Evaluate a saved checkpoint on the ShapeNet 'val' split.

    Builds a fresh MVNet in TEST mode and wraps it with ``model_vlsm``. It then
    restores ``checkpoint`` into a new TensorFlow session and streams every
    validation model through the network. IoU is accumulated batch by batch
    with ``eval_seq_iou`` / ``update_iou``.

    Args:
        args: parsed command-line namespace. The fields read here are
            ``val_batch_size``, ``val_im_batch``, ``nvox``, ``im_h``, ``im_w``,
            ``norm``, ``im_net``, ``grid_net``, ``rnn``, ``eval_thresh``,
            ``val_split_file`` and ``prefetch_threads``.
        checkpoint: path handed to ``tf.train.Saver.restore``.

    Returns:
        The IoU accumulator built by ``init_iou`` and updated by
        ``update_iou``. Its structure is defined by those helpers. The
        function used to return ``None``, so callers that ignore the result
        are unaffected.

    Errors raised inside the evaluation loop are logged with their traceback
    and the input queue is closed; they are not re-raised (best-effort
    behaviour). The session and progress bar are always released.
    """
    net = MVNet(vmin=-0.5, vmax=0.5, vox_bs=args.val_batch_size,
                im_bs=args.val_im_batch, grid_size=args.nvox,
                im_h=args.im_h, im_w=args.im_w, mode="TEST", norm=args.norm)
    im_dir = SHAPENET_IM
    vox_dir = SHAPENET_VOX[args.nvox]

    # Setup network
    net = model_vlsm(net, im_nets[args.im_net], grid_nets[args.grid_net],
                     conv_rnns[args.rnn])
    sess = tf.Session(config=get_session_config())
    try:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)
        coord = tf.train.Coordinator()

        # Init IoU
        iou = init_iou(net.im_batch, args.eval_thresh)

        # Init dataset (single pass over the validation models: nepochs=1)
        dset = ShapeNet(im_dir=im_dir, split_file=args.val_split_file,
                        vox_dir=vox_dir, rng_seed=1)
        mids = dset.get_smids('val')
        logger.info('Testing %d models', len(mids))
        items = ['shape_id', 'model_id', 'im', 'K', 'R', 'vol']
        dset.init_queue(mids, args.val_im_batch, items, coord, nepochs=1,
                        qsize=32, nthreads=args.prefetch_threads)
        # NOTE(review): coord.join() is intentionally not called here.
        # tf.train.Coordinator.join re-raises any exception passed to
        # request_stop(), and close_queue(e) may do exactly that. Confirm
        # close_queue's behaviour before adding an explicit join of the
        # prefetch threads.

        # Testing loop
        pbar = tqdm(desc='Validating', total=len(mids))
        try:
            while not coord.should_stop():
                batch_data = dset.next_batch(items, net.batch_size)
                if batch_data is None:
                    continue
                # Remember the real batch size before padding, so padded
                # rows are excluded from the IoU computation below.
                num_batch_items = batch_data['K'].shape[0]
                batch_data = pad_batch(batch_data, args.val_batch_size)

                feed_dict = {
                    net.K: batch_data['K'],
                    net.Rcam: batch_data['R'],
                    net.ims: batch_data['im'],
                }
                pred = sess.run(net.prob_vox, feed_dict=feed_dict)
                batch_iou = eval_seq_iou(pred[:num_batch_items],
                                         batch_data['vol'][:num_batch_items],
                                         args.val_im_batch,
                                         thresh=args.eval_thresh)

                # Update iou dict
                iou = update_iou(batch_iou, iou)
                pbar.update(num_batch_items)
        except Exception as e:
            # Best-effort: log (now with traceback) and shut the input queue
            # down, without re-raising, as before.
            logger.exception(repr(e))
            dset.close_queue(e)
        finally:
            pbar.close()
    finally:
        # Release the session (and the device memory it holds) on every path.
        sess.close()
    return iou
# Pick the directory this run logs into. Without --ckpt, a per-run
# <logdir>/<key>/train directory is used. With --ckpt, --logdir is used as-is.
log_dir = (args.logdir if args.ckpt is not None
           else osp.join(args.logdir, key, 'train'))

# Initialize network parameters: the base MVNet in TRAIN mode.
mvnet = MVNet(vmin=-0.5,
              vmax=0.5,
              vox_bs=args.batch_size,
              im_bs=args.im_batch,
              grid_size=args.nvox,
              im_h=args.im_h,
              im_w=args.im_w,
              mode="TRAIN",
              norm=args.norm)

# Define graph: attach the image / grid / recurrent sub-networks selected by
# --im_net, --grid_net and --rnn.
mvnet = model_vlsm(mvnet,
                   im_nets[args.im_net],
                   grid_nets[args.grid_net],
                   conv_rnns[args.rnn])

# Prepare the log directory and record the configuration used for this run.
mkdir_p(log_dir)
write_args(args, osp.join(log_dir, 'args.json'))
logger.info('Logging to {:s}'.format(log_dir))
logger.info('\nUsing args:')
pprint(vars(args))
mvnet.print_net()

# Hand the assembled network to the training loop.
train(mvnet)