Ejemplo n.º 1
0
 def _skip_unet(d_f, im_f):
     '''Skip connection: tile image features over time and concat onto d_f.

     im_f: collapsed image features, (bs * im_bs) x ... x imf_dim.
     d_f: decoder features, bs x t x im_bs ... x df_dim (static shape known).
     Returns d_f with the repeated image features appended on the last axis.
     NOTE(review): exact ranks are assumed from the original docstring
     ("im_f: bs x im_bs x ... ; d_f: bs x t x im_bs ...") — confirm at callers.
     '''
     with tf.variable_scope('Skip'):
         d_shape = tf_static_shape(d_f)
         im_shape = tf_static_shape(im_f)
         im_f = uncollapse_dims(im_f, net.batch_size, net.im_batch)
         # Floor division keeps the repeat count an int under Python 3 as
         # well; plain '/' would yield a float there (same result on Py2).
         im_rep = repeat_tensor(im_f, d_shape[0] // im_shape[0], rep_dim=1)
         im_rep = tf.reshape(im_rep, d_shape[:-1] + [im_shape[-1]])
         return tf.concat([im_rep, d_f], axis=-1)
Ejemplo n.º 2
0
def model_silhouette(net,
                     im_net=im_vgg16,
                     seg_net=seg_stack,
                     threshold_mask=False):
    '''Silhouette prediction model.

    Runs im_net over the stacked model views (one batch-sized chunk at a
    time) and over the input image, crops per-ROI features from the input
    feature map, resizes the model features, and predicts two segmentation
    maps (visible and occluded) through seg_net heads. Results are attached
    to `net`, which is returned.
    '''
    with tf.variable_scope('MVNet'):
        # Model image features, chunked so each im_net call sees batch_size
        # images.
        net.mdl_ims_stack = collapse_dims(net.mdl_ims)
        n_chunks = int(tf_static_shape(net.mdl_ims_stack)[0] / net.batch_size)
        feats = None
        for c in range(n_chunks):
            lo = c * net.batch_size
            hi = lo + net.batch_size
            chunk_feats = im_net(net, net.mdl_ims_stack[lo:hi, ...])
            feats = (chunk_feats if feats is None else
                     tf.concat([feats, chunk_feats], 0))
        net.mdl_im_feats = feats
        # Input image features.
        net.input_im_feats = im_net(net, net.input_ims)
        # Crop each ROI out of the input feature map and resize to a fixed
        # size; box i comes from batch element i (hence tf.range as box_ind).
        net.roi_im_feats = tf.image.crop_and_resize(
            net.input_im_feats,
            net.rois,
            tf.range(tf_static_shape(net.rois)[0]),
            net.roi_shape[1:],
            method='bilinear')
        # Convolutional resize of each model-view feature map to 64x64.
        net.mdl_resized_feats = mdl_resize_net_64(net, net.mdl_im_feats)
        # Two segmentation heads sharing the same inputs but separate scopes.
        net.seg_net = {}
        for prefix, scope in (('', 'im_seg_net'), ('occ_', 'im_occ_seg_net')):
            logits = seg_net(net,
                             net.mdl_resized_feats,
                             net.roi_im_feats,
                             scope=scope)
            logits = tf.squeeze(logits, [-1])
            prob = tf.nn.sigmoid(logits)
            if threshold_mask:
                # Hard-threshold probabilities into a binary float mask.
                prob = tf.cast((prob > 0.3), tf.float32)
            setattr(net, 'pred_%sseg_map' % prefix, logits)
            setattr(net, 'prob_%sseg_map' % prefix, prob)
        return net
Ejemplo n.º 3
0
def model_quat(net, quat_net=quat_res):
    '''Quaternion prediction model.

    Feeds a silhouette (ground truth while training, otherwise the predicted
    probability map with gradients stopped) through quat_net, then extracts
    each example's per-class quaternion slice and L2-normalizes it. Results
    are attached to `net`, which is returned.
    '''
    with tf.variable_scope('Quaternion_TM2SM'):
        if net.training:
            silhouette = net.gt_segment
        else:
            # Do not backprop into the silhouette prediction branch.
            silhouette = tf.stop_gradient(net.prob_seg_map)
        net.pred_quat = quat_net(net, silhouette)
        # Per-class output width; loop-invariant, so computed once instead of
        # once per batch element.
        dim = int(tf_static_shape(net.pred_quat)[-1] / net.num_classes)
        # Slice out each example's prediction for its own class id.
        quats = []
        for v in range(0, net.batch_size):
            cls_id = net.class_id[v]
            quat = net.pred_quat[v, cls_id * dim:(cls_id + 1) * dim]
            quats.append(tf.expand_dims(quat, axis=0))
        net.pred_quat = tf.concat(quats, 0)
        # Unit-norm quaternions ('dim' kwarg kept for TF1 compatibility).
        net.pred_quat = tf.nn.l2_normalize(net.pred_quat, dim=-1)
        return net
Ejemplo n.º 4
0
def model_dlsm(net,
               im_net=im_unet,
               grid_net=grid_unet32,
               rnn=convgru,
               ray_samples=64,
               proj_x=4,
               sepup=False,
               im_skip=True,
               proj_last=False):
    '''Depth LSTM model.

    Pipeline: 2D image features -> unprojection into a 3D cost grid ->
    recurrent fusion across views -> 3D grid network -> projection into per-
    ray slices -> 2D depth network.

    Args:
        net: network container holding tensor shapes and receiving outputs.
        im_net: image feature network.
        grid_net: 3D grid network.
        rnn: recurrent cell used to fuse per-view grids.
        ray_samples: number of depth samples along each projected ray.
        proj_x: projection subsample ratio; only 4 and 8 are supported.
        sepup: use the separable-upsampling depth-net variant.
        im_skip: forward image skip connections into the depth net.
        proj_last: project only the last view's grid instead of all views.

    Returns:
        net, with predictions in net.depth_out, shaped
        [bs, im_bs, ks, im_h, im_w, 1].

    Raises:
        ValueError: if proj_x is not 4 or 8.
    '''

    with tf.variable_scope('MVNet'):
        # Placeholders for images, intrinsics (K) and extrinsics (R).
        net.ims = tf.placeholder(tf.float32, net.im_tensor_shape, name='ims')
        net.K = tf.placeholder(tf.float32, net.K_tensor_shape, name='K')
        net.Rcam = tf.placeholder(tf.float32, net.R_tensor_shape, name='R')

        # Compute image features.
        net.im_feats = im_net(net, collapse_dims(net.ims))

        # Unproject image features into the feature grid.
        net.cost_grid = proj_splat(net, net.im_feats, net.K, net.Rcam)

        # Combine per-view grids with the LSTM/GRU.
        net.pool_grid, _ = rnn(net.cost_grid)

        # Grid network.
        net.pool_grid = collapse_dims(net.pool_grid)
        net.pred_vox = grid_net(net, net.pool_grid)
        net.proj_vox = uncollapse_dims(net.grid_net['deconv3'], net.batch_size,
                                       net.im_batch)

        # Projection (optionally only the most recent view's grid).
        proj_vox_in = (net.proj_vox if not proj_last else net.proj_vox[:, -1:,
                                                                       ...])
        # '//' keeps proj_size integral under Python 3 (plain '/' would yield
        # a float); identical result on Python 2.
        net.ray_slices, z_samples = proj_slice(net,
                                               proj_vox_in,
                                               net.K,
                                               net.Rcam,
                                               proj_size=net.im_h // proj_x,
                                               samples=ray_samples)

        bs, im_bs, ks, im_sz1, im_sz2, fdim, _ = tf_static_shape(
            net.ray_slices)
        net.depth_in = tf.reshape(
            net.ray_slices,
            [bs * im_bs * ks, im_sz1, im_sz2, fdim * ray_samples])
        # Depth network: pick the variant matching the subsample ratio.
        if proj_x == 4:
            depth_fn = depth_net_x4_sepup if sepup else depth_net_x4
        elif proj_x == 8:
            depth_fn = depth_net_x8_sepup if sepup else depth_net_x8
        else:
            logger = logging.getLogger('mview3d.' + __name__)
            logger.error(
                'Unsupported subsample ratio for projection. Use {4, 8}')
            # Fail fast: previously execution fell through and crashed later
            # on an undefined net.depth_out.
            raise ValueError('proj_x must be 4 or 8, got %r' % (proj_x,))
        net.depth_out = depth_fn(net, net.depth_in, im_skip)

        net.depth_out = tf.reshape(net.depth_out,
                                   [bs, im_bs, ks, net.im_h, net.im_w, 1])
        return net
Ejemplo n.º 5
0
def train(net):
    '''Train the depth network: build loss/optimizer/summaries, then run the
    session loop over batches from a ShapeNet queue.

    Relies on module-level `args`, `logger`, `im_dir`, `log_dir` and helpers
    (loss_l1, image_sum, depth_sum, ShapeNet, Timer, get_session_config).
    NOTE(review): this function is truncated in this excerpt — the `finally:`
    clause at the end has no visible body; it presumably stops the data queue
    and closes the writer. Uses Python 2 `except Exception, e` syntax.
    '''
    # L1 depth loss against ground truth repeated across the time dimension
    # (pred_depth carries one prediction per recurrent step).
    net.gt_depth = tf.placeholder(tf.float32, net.depth_tensor_shape)
    net.pred_depth = net.depth_out
    out_shape = tf_static_shape(net.pred_depth)
    net.depth_loss = loss_l1(
        net.pred_depth, repeat_tensor(net.gt_depth, out_shape[1], rep_dim=1))

    _t_dbg = Timer()

    # Add optimizer: Adam on an exponentially decayed learning rate.
    global_step = tf.Variable(0, trainable=False, name='global_step')
    decay_lr = tf.train.exponential_decay(args.lr,
                                          global_step,
                                          args.decay_steps,
                                          args.decay_rate,
                                          staircase=True)
    lr_sum = tf.summary.scalar('lr', decay_lr)
    optim = tf.train.AdamOptimizer(decay_lr).minimize(net.depth_loss,
                                                      global_step)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Add summaries for training (scalars every step, images periodically).
    net.loss_sum = tf.summary.scalar('loss', net.depth_loss)
    net.im_sum = image_sum(net.ims, net.batch_size, net.im_batch)
    net.depth_gt_sum = depth_sum(net.gt_depth, net.batch_size, net.im_batch,
                                 'depth_gt')
    # Only the last recurrent step's prediction is visualized.
    net.depth_pred_sum = depth_sum(net.pred_depth[:, -1, ...], net.batch_size,
                                   net.im_batch,
                                   'depth_pred_{:d}'.format(net.im_batch))
    merged_ims = tf.summary.merge(
        [net.im_sum, net.depth_gt_sum, net.depth_pred_sum])
    merged_scalars = tf.summary.merge([net.loss_sum, lr_sum])

    # Initialize dataset and its prefetching queue.
    coord = tf.train.Coordinator()
    dset = ShapeNet(im_dir=im_dir,
                    split_file=args.split_file,
                    rng_seed=0,
                    custom_db=args.custom_training)
    mids = dset.get_smids('train')
    logger.info('Training with %d models', len(mids))
    items = ['im', 'K', 'R', 'depth']
    dset.init_queue(mids,
                    net.im_batch,
                    items,
                    coord,
                    qsize=64,
                    nthreads=args.prefetch_threads)

    _t_dbg = Timer()
    iters = 0
    # Training loop
    pbar = tqdm(desc='Training Depth-LSM', total=args.niters)
    with tf.Session(config=get_session_config()) as sess:
        sum_writer = tf.summary.FileWriter(log_dir, sess.graph)
        # Resume from a checkpoint when given; otherwise fresh init.
        if args.ckpt is not None:
            logger.info('Restoring from %s', args.ckpt)
            saver.restore(sess, args.ckpt)
        else:
            sess.run(init_op)
        try:
            while True:
                iters += 1
                _t_dbg.tic()
                batch_data = dset.next_batch(items, net.batch_size)
                logging.debug('Data read time - %.3fs', _t_dbg.toc())
                feed_dict = {
                    net.ims: batch_data['im'],
                    net.K: batch_data['K'],
                    net.Rcam: batch_data['R'],
                    net.gt_depth: batch_data['depth']
                }
                # Optionally collect a full runtime trace on summary steps.
                if args.run_trace and (iters % args.sum_iters == 0
                                       or iters == 1 or iters == args.niters):
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict,
                        options=run_options,
                        run_metadata=run_metadata)
                    sum_writer.add_run_metadata(run_metadata, 'step%d' % step_)
                else:
                    step_, _, merged_scalars_ = sess.run(
                        [global_step, optim, merged_scalars],
                        feed_dict=feed_dict)

                logging.debug('Net time - %.3fs', _t_dbg.toc())

                sum_writer.add_summary(merged_scalars_, step_)
                # Image summaries only at the first/last/periodic iterations
                # (re-runs the graph on the same feed_dict).
                if iters % args.sum_iters == 0 or iters == 1 or iters == args.niters:
                    image_sum_, step_ = sess.run([merged_ims, global_step],
                                                 feed_dict=feed_dict)
                    sum_writer.add_summary(image_sum_, step_)

                # Periodic checkpointing.
                if iters % args.ckpt_iters == 0 or iters == args.niters:
                    save_f = saver.save(sess,
                                        osp.join(log_dir, 'mvnet'),
                                        global_step=global_step)
                    logger.info(' Model checkpoint - {:s} '.format(save_f))

                pbar.update(1)
                if iters >= args.niters:
                    break
        # Py2 syntax; on any failure, shut the prefetch queue down cleanly.
        except Exception, e:
            logging.error(repr(e))
            dset.close_queue(e)
        finally:
Ejemplo n.º 6
0
def im_vgg16(net, ims):
    '''VGG16-style image encoder with a small skip/upsampling decoder.

    Every named layer is recorded in net.im_net. Returns 'upfeat': a
    64-channel feature map upsampled back toward the input resolution by
    fusing conv5_3 (x2 up) with conv4_3 and upsampling the sum x8.
    '''
    net.im_net = {}
    # Unpacking also asserts a rank-4 (bs, h, w, ch) input.
    bs, h, w, ch = tf_static_shape(ims)
    with tf.variable_scope('ImNet_UNet', reuse=tf.AUTO_REUSE):
        # VGG16 encoder, described as (stage, #convs, #filters); each stage
        # except the last is followed by a 2x2/2 'same' max-pool.
        stages = [('conv1', 2, 64), ('conv2', 2, 128), ('conv3', 3, 256),
                  ('conv4', 3, 512), ('conv5', 3, 512)]
        x = ims
        for s_ix, (stage, n_convs, n_filt) in enumerate(stages, 1):
            for c_ix in range(1, n_convs + 1):
                lname = '%s_%d' % (stage, c_ix)
                if lname == 'conv1_1':
                    # The very first conv is linear (no activation).
                    x = conv2d(lname, x, 3, n_filt, mode=net.mode, act=None)
                else:
                    x = conv2d(lname, x, 3, n_filt, mode=net.mode)
                net.im_net[lname] = x
            if s_ix < len(stages):
                pname = 'pool%d' % s_ix
                x = tf.layers.max_pooling2d(x, 2, 2, padding='same',
                                            name=pname)
                if s_ix > 1:
                    # pool1 is not recorded in net.im_net; pool2-4 are.
                    net.im_net[pname] = x
        # Decoder: reduce conv5_3 and conv4_3 to 64 channels, upsample conv5
        # x2 with a fixed (non-trainable) deconv, add, then upsample x8.
        feat_conv5 = conv2d('feat_conv5', net.im_net['conv5_3'], 1, 64,
                            norm=net.norm, mode=net.mode)
        net.im_net['feat_conv5'] = feat_conv5
        upfeat_conv5 = deconv_pcnn(feat_conv5, 4, 4, 64, 2, 2,
                                   name='upfeat_conv5', trainable=False)
        net.im_net['upfeat_conv5'] = upfeat_conv5
        feat_conv4 = conv2d('feat_conv4', net.im_net['conv4_3'], 1, 64,
                            norm=net.norm, mode=net.mode)
        net.im_net['feat_conv4'] = feat_conv4
        add_feat = tf.add_n([upfeat_conv5, feat_conv4], name='add_feat')
        add_feat = dropout(add_feat, net.keep_prob)
        net.im_net['add_feat'] = add_feat
        upfeat = deconv_pcnn(add_feat, 16, 16, 64, 8, 8,
                             name='upfeat', trainable=False)
        net.im_net['upfeat'] = upfeat
    return upfeat