def _skip_unet(net, d_f, im_f):
    '''im_f: bs x im_bs x ... ; d_f: bs x t x im_bs x ...'''
    with tf.variable_scope('Skip'):
        d_shape = tf_static_shape(d_f)
        im_shape = tf_static_shape(im_f)
        # Tile the image features across the decoder steps of d_f, then
        # concatenate along the channel axis as a skip connection
        im_f = uncollapse_dims(im_f, net.batch_size, net.im_batch)
        im_rep = repeat_tensor(im_f, d_shape[0] // im_shape[0], rep_dim=1)
        im_rep = tf.reshape(im_rep, d_shape[:-1] + [im_shape[-1]])
        return tf.concat([im_rep, d_f], axis=-1)
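# Worked shape example for _skip_unet (illustrative numbers, not from the
# original file): with batch_size=2, im_batch=4 and t=3 decoder steps, a
# collapsed im_f of shape (8, 32, 32, 64) and a collapsed d_f of shape
# (24, 32, 32, 128) give a repeat factor of 24 // 8 = 3; im_f is tiled
# across the t steps, reshaped back to a collapsed (24, 32, 32, 64), and
# concatenated with d_f along channels into a (24, 32, 32, 192) skip feature.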
def model_silhouette(net, im_net=im_vgg16, seg_net=seg_stack,
                     threshold_mask=False):
    '''Silhouette prediction model'''
    with tf.variable_scope('MVNet'):
        # Compute model image features
        net.mdl_ims_stack = collapse_dims(net.mdl_ims)
        num_its = tf_static_shape(net.mdl_ims_stack)[0] // net.batch_size
        for v in range(num_its):
            im_feats = im_net(
                net, net.mdl_ims_stack[v * net.batch_size:(v + 1) *
                                       net.batch_size, ...])
            if v == 0:
                net.mdl_im_feats = im_feats
            else:
                net.mdl_im_feats = tf.concat([net.mdl_im_feats, im_feats], 0)

        # Compute input image features
        net.input_im_feats = im_net(net, net.input_ims)

        # Extract ROIs from the input image feature maps and resize them to a
        # fixed size
        net.roi_im_feats = tf.image.crop_and_resize(
            net.input_im_feats,
            net.rois,
            tf.range(tf_static_shape(net.rois)[0]),
            net.roi_shape[1:],
            method='bilinear')

        # Use convolution layers to resize each model image feature map to
        # 64x64
        net.mdl_resized_feats = mdl_resize_net_64(net, net.mdl_im_feats)

        # Compute segmentation maps
        net.seg_net = {}
        net.pred_seg_map = seg_net(net, net.mdl_resized_feats,
                                   net.roi_im_feats, scope='im_seg_net')
        net.pred_seg_map = tf.squeeze(net.pred_seg_map, [-1])
        net.prob_seg_map = tf.nn.sigmoid(net.pred_seg_map)
        if threshold_mask:
            net.prob_seg_map = tf.cast(net.prob_seg_map > 0.3, tf.float32)

        net.pred_occ_seg_map = seg_net(net, net.mdl_resized_feats,
                                       net.roi_im_feats,
                                       scope='im_occ_seg_net')
        net.pred_occ_seg_map = tf.squeeze(net.pred_occ_seg_map, [-1])
        net.prob_occ_seg_map = tf.nn.sigmoid(net.pred_occ_seg_map)
        if threshold_mask:
            net.prob_occ_seg_map = tf.cast(net.prob_occ_seg_map > 0.3,
                                           tf.float32)
    return net
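# Note on threshold_mask (inferred from the code above, not stated in the
# original file): tf.cast(x > 0.3, tf.float32) has no useful gradient, so
# binarizing the sigmoid outputs at 0.3 only makes sense at inference time;
# during training the graph should be built with threshold_mask=False so the
# soft probability maps stay differentiable.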
def model_quat(net, quat_net=quat_res):
    '''Quaternion prediction model'''
    with tf.variable_scope('Quaternion_TM2SM'):
        if net.training:
            silhouette = net.gt_segment
        else:
            silhouette = tf.stop_gradient(net.prob_seg_map)
        net.pred_quat = quat_net(net, silhouette)

        # Extract the per-class prediction from the class output vectors
        dim = tf_static_shape(net.pred_quat)[-1] // net.num_classes
        for v in range(net.batch_size):
            cls_id = net.class_id[v]
            quat = net.pred_quat[v, cls_id * dim:(cls_id + 1) * dim]
            quat = tf.expand_dims(quat, axis=0)
            if v == 0:
                quat_stack = quat
            else:
                quat_stack = tf.concat([quat_stack, quat], 0)
        net.pred_quat = quat_stack
        net.pred_quat = tf.nn.l2_normalize(net.pred_quat, dim=-1)
    return net
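# Worked example of the class-slice extraction above (illustrative numbers):
# if quat_net emits one quaternion per class, tf_static_shape(pred_quat)[-1]
# is 4 * num_classes; with num_classes = 10 the output vector has 40 entries
# and dim = 4, so for cls_id = 3 the slice [12:16] picks out that class's
# quaternion before L2 normalization projects it onto the unit sphere.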
def model_dlsm(net,
               im_net=im_unet,
               grid_net=grid_unet32,
               rnn=convgru,
               ray_samples=64,
               proj_x=4,
               sepup=False,
               im_skip=True,
               proj_last=False):
    '''Depth LSTM model'''
    with tf.variable_scope('MVNet'):
        # Setup placeholders for images and the extrinsic/intrinsic camera
        # matrices
        net.ims = tf.placeholder(tf.float32, net.im_tensor_shape, name='ims')
        net.K = tf.placeholder(tf.float32, net.K_tensor_shape, name='K')
        net.Rcam = tf.placeholder(tf.float32, net.R_tensor_shape, name='R')

        # Compute image features
        net.im_feats = im_net(net, collapse_dims(net.ims))

        # Unproject the image features into the feature grid
        net.cost_grid = proj_splat(net, net.im_feats, net.K, net.Rcam)

        # Combine grids with LSTM/GRU
        net.pool_grid, _ = rnn(net.cost_grid)

        # Grid network
        net.pool_grid = collapse_dims(net.pool_grid)
        net.pred_vox = grid_net(net, net.pool_grid)
        net.proj_vox = uncollapse_dims(net.grid_net['deconv3'],
                                       net.batch_size, net.im_batch)

        # Projection
        proj_vox_in = (net.proj_vox
                       if not proj_last else net.proj_vox[:, -1:, ...])
        net.ray_slices, z_samples = proj_slice(net,
                                               proj_vox_in,
                                               net.K,
                                               net.Rcam,
                                               proj_size=net.im_h // proj_x,
                                               samples=ray_samples)

        bs, im_bs, ks, im_sz1, im_sz2, fdim, _ = tf_static_shape(
            net.ray_slices)
        net.depth_in = tf.reshape(
            net.ray_slices,
            [bs * im_bs * ks, im_sz1, im_sz2, fdim * ray_samples])

        # Depth network
        if proj_x == 4:
            if not sepup:
                net.depth_out = depth_net_x4(net, net.depth_in, im_skip)
            else:
                net.depth_out = depth_net_x4_sepup(net, net.depth_in, im_skip)
        elif proj_x == 8:
            if not sepup:
                net.depth_out = depth_net_x8(net, net.depth_in, im_skip)
            else:
                net.depth_out = depth_net_x8_sepup(net, net.depth_in, im_skip)
        else:
            logger = logging.getLogger('mview3d.' + __name__)
            logger.error(
                'Unsupported subsample ratio for projection. Use {4, 8}')

        net.depth_out = tf.reshape(net.depth_out,
                                   [bs, im_bs, ks, net.im_h, net.im_w, 1])
    return net
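# Shape trace for the projection stage above (illustrative numbers, assuming
# im_h = im_w = 224, proj_x = 4, ray_samples = 64): proj_slice samples each
# ray at 64 depths on a 56x56 grid, so net.ray_slices comes out as
# (bs, im_bs, ks, 56, 56, fdim, 64); the reshape folds the depth-sample axis
# into channels, giving a (bs*im_bs*ks, 56, 56, fdim*64) input that the x4
# depth network upsamples back to the full im_h x im_w resolution.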
def train(net):
    net.gt_depth = tf.placeholder(tf.float32, net.depth_tensor_shape)
    net.pred_depth = net.depth_out
    out_shape = tf_static_shape(net.pred_depth)
    net.depth_loss = loss_l1(
        net.pred_depth, repeat_tensor(net.gt_depth, out_shape[1], rep_dim=1))

    # Add optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    decay_lr = tf.train.exponential_decay(args.lr,
                                          global_step,
                                          args.decay_steps,
                                          args.decay_rate,
                                          staircase=True)
    lr_sum = tf.summary.scalar('lr', decay_lr)
    optim = tf.train.AdamOptimizer(decay_lr).minimize(net.depth_loss,
                                                      global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Add summaries for training
    net.loss_sum = tf.summary.scalar('loss', net.depth_loss)
    net.im_sum = image_sum(net.ims, net.batch_size, net.im_batch)
    net.depth_gt_sum = depth_sum(net.gt_depth, net.batch_size, net.im_batch,
                                 'depth_gt')
    net.depth_pred_sum = depth_sum(net.pred_depth[:, -1, ...], net.batch_size,
                                   net.im_batch,
                                   'depth_pred_{:d}'.format(net.im_batch))
    merged_ims = tf.summary.merge(
        [net.im_sum, net.depth_gt_sum, net.depth_pred_sum])
    merged_scalars = tf.summary.merge([net.loss_sum, lr_sum])

    # Initialize dataset
    coord = tf.train.Coordinator()
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.split_file,
                    rng_seed=0,
                    custom_db=args.custom_training)
    mids = dset.get_smids('train')
    logger.info('Training with %d models', len(mids))
    items = ['im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    net.im_batch,
                    items,
                    coord,
                    qsize=64,
                    nthreads=args.prefetch_threads)

    _t_dbg = Timer()
    iters = 0
    # Training loop
    pbar = tqdm(desc='Training Depth-LSM', total=args.niters)
    with tf.Session(config=get_session_config()) as sess:
        sum_writer = tf.summary.FileWriter(log_dir, sess.graph)
        if args.ckpt is not None:
            logger.info('Restoring from %s', args.ckpt)
            saver.restore(sess, args.ckpt)
        else:
            sess.run(init_op)

        try:
            while True:
                iters += 1
                _t_dbg.tic()
                batch_data = dset.next_batch(items, net.batch_size)
                logging.debug('Data read time - %.3fs', _t_dbg.toc())
                feed_dict = {
                    net.ims: batch_data['im'],
                    net.K: batch_data['K'],
                    net.Rcam: batch_data['R'],
                    net.gt_depth: batch_data['depth']
                }
                if args.run_trace and (iters % args.sum_iters == 0
                                       or iters == 1
                                       or iters == args.niters):
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                    sum_writer.add_run_metadata(run_metadata,
                                                'step%d' % step_)
                else:
                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict)
                logging.debug('Net time - %.3fs', _t_dbg.toc())
                sum_writer.add_summary(merged_scalars_, step_)

                if (iters % args.sum_iters == 0 or iters == 1
                        or iters == args.niters):
                    image_sum_, step_ = sess.run([merged_ims, global_step],
                                                 feed_dict=feed_dict)
                    sum_writer.add_summary(image_sum_, step_)

                if iters % args.ckpt_iters == 0 or iters == args.niters:
                    save_f = saver.save(sess,
                                        osp.join(log_dir, 'mvnet'),
                                        global_step=global_step)
                    logger.info('Model checkpoint - {:s}'.format(save_f))

                pbar.update(1)
                if iters >= args.niters:
                    break
        except Exception as e:
            logging.error(repr(e))
            dset.close_queue(e)
        finally:
            # Clean up the progress bar and the prefetch threads
            pbar.close()
            logger.info('Training completed')
            coord.join()
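# Note on the loss above: net.pred_depth carries one depth map per recurrent
# step (out_shape[1] of them), so the ground-truth depth is repeated along
# rep_dim=1 before the L1 loss; every intermediate step is supervised, while
# the image summaries only visualize the final step (pred_depth[:, -1, ...]).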
def im_vgg16(net, ims):
    net.im_net = {}
    bs, h, w, ch = tf_static_shape(ims)
    with tf.variable_scope('ImNet_UNet', reuse=tf.AUTO_REUSE):
        # VGG16 layers
        conv1_1 = conv2d('conv1_1', ims, 3, 64, mode=net.mode, act=None)
        net.im_net['conv1_1'] = conv1_1
        conv1_2 = conv2d('conv1_2', conv1_1, 3, 64, mode=net.mode)
        net.im_net['conv1_2'] = conv1_2
        pool1 = tf.layers.max_pooling2d(
            conv1_2, 2, 2, padding='same', name='pool1')
        conv2_1 = conv2d('conv2_1', pool1, 3, 128, mode=net.mode)
        net.im_net['conv2_1'] = conv2_1
        conv2_2 = conv2d('conv2_2', conv2_1, 3, 128, mode=net.mode)
        net.im_net['conv2_2'] = conv2_2
        pool2 = tf.layers.max_pooling2d(
            conv2_2, 2, 2, padding='same', name='pool2')
        net.im_net['pool2'] = pool2
        conv3_1 = conv2d('conv3_1', pool2, 3, 256, mode=net.mode)
        net.im_net['conv3_1'] = conv3_1
        conv3_2 = conv2d('conv3_2', conv3_1, 3, 256, mode=net.mode)
        net.im_net['conv3_2'] = conv3_2
        conv3_3 = conv2d('conv3_3', conv3_2, 3, 256, mode=net.mode)
        net.im_net['conv3_3'] = conv3_3
        pool3 = tf.layers.max_pooling2d(
            conv3_3, 2, 2, padding='same', name='pool3')
        net.im_net['pool3'] = pool3
        conv4_1 = conv2d('conv4_1', pool3, 3, 512, mode=net.mode)
        net.im_net['conv4_1'] = conv4_1
        conv4_2 = conv2d('conv4_2', conv4_1, 3, 512, mode=net.mode)
        net.im_net['conv4_2'] = conv4_2
        conv4_3 = conv2d('conv4_3', conv4_2, 3, 512, mode=net.mode)
        net.im_net['conv4_3'] = conv4_3
        pool4 = tf.layers.max_pooling2d(
            conv4_3, 2, 2, padding='same', name='pool4')
        net.im_net['pool4'] = pool4
        conv5_1 = conv2d('conv5_1', pool4, 3, 512, mode=net.mode)
        net.im_net['conv5_1'] = conv5_1
        conv5_2 = conv2d('conv5_2', conv5_1, 3, 512, mode=net.mode)
        net.im_net['conv5_2'] = conv5_2
        conv5_3 = conv2d('conv5_3', conv5_2, 3, 512, mode=net.mode)
        net.im_net['conv5_3'] = conv5_3

        # Deconv layers
        feat_conv5 = conv2d(
            'feat_conv5', conv5_3, 1, 64, norm=net.norm, mode=net.mode)
        net.im_net['feat_conv5'] = feat_conv5
        upfeat_conv5 = deconv_pcnn(
            feat_conv5, 4, 4, 64, 2, 2, name='upfeat_conv5', trainable=False)
        # upfeat_conv5 = deconv2d('upfeat_conv5', conv5_3, 4, 64, stride=2,
        #                         padding="SAME", norm=net.norm, mode=net.mode)
        net.im_net['upfeat_conv5'] = upfeat_conv5
        feat_conv4 = conv2d(
            'feat_conv4', conv4_3, 1, 64, norm=net.norm, mode=net.mode)
        net.im_net['feat_conv4'] = feat_conv4
        add_feat = tf.add_n([upfeat_conv5, feat_conv4], name='add_feat')
        add_feat = dropout(add_feat, net.keep_prob)
        net.im_net['add_feat'] = add_feat
        upfeat = deconv_pcnn(
            add_feat, 16, 16, 64, 8, 8, name='upfeat', trainable=False)
        # upfeat = deconv2d('upfeat', add_feat, 16, 64, stride=8,
        #                   padding="SAME", norm=net.norm, mode=net.mode)
        net.im_net['upfeat'] = upfeat
    return upfeat
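# Resolution trace for im_vgg16 (illustrative, assuming a 224x224 input and
# 'same'-padded convolutions): four stride-2 max pools bring conv5_3 down to
# 14x14; the fixed, non-trainable deconvolution upfeat_conv5 upsamples x2 to
# 28x28 so it can be summed with feat_conv4 (taken from conv4_3 at pool3
# resolution), and the final x8 deconvolution returns the fused 64-channel
# feature map to the full 224x224 input resolution.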