Code example #1
def get_sdf_3dcnn_binary(grid_idx, globalfeats, is_training, batch_size, num_point, bn, bn_decay, wd=None, FLAGS=None):
    # Expand the global feature vector into a 1x1x1xC volume, then upsample it
    # with six 3D transposed convolutions to an (FLAGS.sdf_res + 1)^3 grid with
    # two logits per grid point. The trailing comments give the spatial
    # resolution after each layer.
    globalfeats_expand = tf.reshape(globalfeats, [batch_size, 1, 1, 1, -1])
    print('globalfeats_expand', globalfeats_expand.get_shape())
    net2 = tf_util.conv3d_transpose(globalfeats_expand, 128, [2, 2, 2], stride=[2, 2, 2],
                                    bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, scope='3deconv1') # 2

    net2 = tf_util.conv3d_transpose(net2, 128, [3, 3, 3], stride=[2, 2, 2], bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, scope='3deconv2') # 4

    net2 = tf_util.conv3d_transpose(net2, 128, [3, 3, 3], stride=[2, 2, 2], bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, scope='3deconv3')  # 8

    net2 = tf_util.conv3d_transpose(net2, 64, [3, 3, 3], stride=[2, 2, 2], bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, scope='3deconv4')  # 16

    net2 = tf_util.conv3d_transpose(net2, 64, [3, 3, 3], stride=[2, 2, 2], bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, scope='3deconv5')  # 32

    net2 = tf_util.conv3d_transpose(net2, 32, [3, 3, 3], stride=[2, 2, 2], bn_decay=bn_decay, bn=bn,
                                    is_training=is_training, weight_decay=wd, padding='VALID', scope='3deconv6') # 65

    net2 = tf_util.conv3d(net2, 2, [1, 1, 1], stride=[1, 1, 1], bn_decay=bn_decay, bn=bn, activation_fn=None,
                                is_training=is_training, weight_decay=wd, padding='VALID', scope='3conv7_binary')
    res_plus = FLAGS.sdf_res + 1
    # Keep the two logits per voxel produced by 3conv7_binary as the last axis.
    full_inter = tf.reshape(net2, (batch_size, res_plus, res_plus, res_plus, 2))

    print("3d cnn net2 shape:", full_inter.get_shape())

    pred = tf.reshape(full_inter, [batch_size, -1, 2])
    return pred
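
For context, the decoder above can be wired up with a small TF1.x driver like the sketch below. The module name `sdf_model`, the 1024-dimensional global feature, and `sdf_res=64` are assumptions made for illustration; only `get_sdf_3dcnn_binary` itself comes from the snippet.

# Hypothetical driver for the decoder above (TF1.x). The module name
# `sdf_model`, FEAT_DIM and sdf_res=64 are assumptions, not part of the
# original code.
import argparse
import tensorflow as tf

import sdf_model  # assumed module holding get_sdf_3dcnn_binary and tf_util

FLAGS = argparse.Namespace(sdf_res=64)           # decoder emits a 65^3 grid
BATCH_SIZE, NUM_POINT, FEAT_DIM = 2, 2048, 1024

globalfeats = tf.placeholder(tf.float32, [BATCH_SIZE, FEAT_DIM])
is_training = tf.placeholder(tf.bool, shape=())

pred = sdf_model.get_sdf_3dcnn_binary(
    None, globalfeats, is_training, BATCH_SIZE, NUM_POINT,
    bn=True, bn_decay=None, wd=1e-5, FLAGS=FLAGS)
print(pred.get_shape())  # (2, 65*65*65, 2): two logits per grid point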
Code example #2
def get_model(pointgrid, is_training):
    """
    Args:
        pointgrid: of size B x N x N x N x NUM_FEATURES
        is_training: boolean tensor
    Returns:
        pred_cat: of size B x NUM_CATEGORY
        pred_seg: of size B x N x N x N x (K+1) x NUM_SEG_PART
    """

    # Encoder
    batch_size = pointgrid.get_shape()[0].value
    conv1 = tf_util.conv3d(pointgrid,
                           64, [5, 5, 5],
                           scope='conv1',
                           activation_fn=leak_relu,
                           bn=True,
                           is_training=is_training)  # N
    conv2 = tf_util.conv3d(conv1,
                           64, [5, 5, 5],
                           scope='conv2',
                           activation_fn=leak_relu,
                           stride=[2, 2, 2],
                           bn=True,
                           is_training=is_training)  # N/2
    conv3 = tf_util.conv3d(conv2,
                           64, [5, 5, 5],
                           scope='conv3',
                           activation_fn=leak_relu,
                           bn=True,
                           is_training=is_training)  # N/2
    conv4 = tf_util.conv3d(conv3,
                           128, [3, 3, 3],
                           scope='conv4',
                           activation_fn=leak_relu,
                           stride=[2, 2, 2],
                           bn=True,
                           is_training=is_training)  # N/4
    conv5 = tf_util.conv3d(conv4,
                           128, [3, 3, 3],
                           scope='conv5',
                           activation_fn=leak_relu,
                           bn=True,
                           is_training=is_training)  # N/4
    conv6 = tf_util.conv3d(conv5,
                           256, [3, 3, 3],
                           scope='conv6',
                           activation_fn=leak_relu,
                           stride=[2, 2, 2],
                           bn=True,
                           is_training=is_training)  # N/8
    conv7 = tf_util.conv3d(conv6,
                           256, [3, 3, 3],
                           scope='conv7',
                           activation_fn=leak_relu,
                           bn=True,
                           is_training=is_training)  # N/8
    conv8 = tf_util.conv3d(conv7,
                           512, [3, 3, 3],
                           scope='conv8',
                           activation_fn=leak_relu,
                           stride=[2, 2, 2],
                           bn=True,
                           is_training=is_training)  # N/16
    conv9 = tf_util.conv3d(conv8,
                           512, [1, 1, 1],
                           scope='conv9',
                           activation_fn=leak_relu,
                           bn=True,
                           is_training=is_training)  # N/16

    # Classification Network
    conv9_flat = tf.reshape(conv9, [batch_size, -1])
    fc1 = tf_util.fully_connected(conv9_flat,
                                  512,
                                  activation_fn=leak_relu,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc1')
    do1 = tf_util.dropout(fc1,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='do1')
    fc2 = tf_util.fully_connected(do1,
                                  256,
                                  activation_fn=leak_relu,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc2')
    do2 = tf_util.dropout(fc2,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='do2')
    pred_cat = tf_util.fully_connected(do2,
                                       NUM_CATEGORY,
                                       activation_fn=None,
                                       bn=False,
                                       scope='pred_cat')

    # Segmentation Network
    cat_features = tf.tile(
        tf.reshape(tf.concat([fc2, pred_cat], axis=1),
                   [batch_size, 1, 1, 1, -1]),
        [1, N // 16, N // 16, N // 16, 1])  # tile multiples must be integers
    conv9_cat = tf.concat([conv9, cat_features], axis=4)
    deconv1 = tf_util.conv3d_transpose(conv9_cat,
                                       256, [3, 3, 3],
                                       scope='deconv1',
                                       activation_fn=leak_relu,
                                       bn=True,
                                       is_training=is_training,
                                       stride=[2, 2, 2],
                                       padding='SAME')  # N/8
    conv7_deconv1 = tf.concat(axis=4, values=[conv7, deconv1])
    deconv2 = tf_util.conv3d(conv7_deconv1,
                             256, [3, 3, 3],
                             scope='deconv2',
                             activation_fn=leak_relu,
                             bn=True,
                             is_training=is_training)  # N/8
    deconv3 = tf_util.conv3d_transpose(deconv2,
                                       128, [3, 3, 3],
                                       scope='deconv3',
                                       activation_fn=leak_relu,
                                       bn=True,
                                       is_training=is_training,
                                       stride=[2, 2, 2],
                                       padding='SAME')  # N/4
    conv5_deconv3 = tf.concat(axis=4, values=[conv5, deconv3])
    deconv4 = tf_util.conv3d(conv5_deconv3,
                             128, [3, 3, 3],
                             scope='deconv4',
                             activation_fn=leak_relu,
                             bn=True,
                             is_training=is_training)  # N/4
    deconv5 = tf_util.conv3d_transpose(deconv4,
                                       64, [3, 3, 3],
                                       scope='deconv5',
                                       activation_fn=leak_relu,
                                       bn=True,
                                       is_training=is_training,
                                       stride=[2, 2, 2],
                                       padding='SAME')  # N/2
    conv3_deconv5 = tf.concat(axis=4, values=[conv3, deconv5])
    deconv6 = tf_util.conv3d(conv3_deconv5,
                             64, [5, 5, 5],
                             scope='deconv6',
                             activation_fn=leak_relu,
                             bn=True,
                             is_training=is_training)  # N/2
    deconv7 = tf_util.conv3d_transpose(deconv6,
                                       64, [5, 5, 5],
                                       scope='deconv7',
                                       activation_fn=leak_relu,
                                       bn=True,
                                       is_training=is_training,
                                       stride=[2, 2, 2],
                                       padding='SAME')  # N
    conv1_deconv7 = tf.concat(axis=4, values=[conv1, deconv7])
    deconv8 = tf_util.conv3d(conv1_deconv7,
                             64, [5, 5, 5],
                             scope='deconv8',
                             activation_fn=leak_relu,
                             bn=True,
                             is_training=is_training)  # N

    pred_seg = tf_util.conv3d(deconv8, (K + 1) * NUM_SEG_PART, [5, 5, 5],
                              scope='pred_seg',
                              activation_fn=None,
                              bn=False,
                              is_training=is_training)
    pred_seg = tf.reshape(pred_seg, [batch_size, N, N, N, K + 1, NUM_SEG_PART])

    return pred_cat, pred_seg
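
A minimal sketch of consuming both heads is given below, assuming the snippet lives in a module called `pointgrid_net` here and that `N`, `K`, `NUM_FEATURES`, `NUM_CATEGORY`, and `NUM_SEG_PART` are defined next to `get_model` as the docstring suggests; the module name, those constants, and the batch size are illustrative assumptions.

# Hypothetical TF1.x usage of get_model above. The module name and the
# constants read from it are assumptions for illustration only.
import tensorflow as tf

import pointgrid_net as net_def  # assumed module containing get_model

B = 4
pointgrid = tf.placeholder(
    tf.float32,
    [B, net_def.N, net_def.N, net_def.N, net_def.NUM_FEATURES])
is_training = tf.placeholder(tf.bool, shape=())

pred_cat, pred_seg = net_def.get_model(pointgrid, is_training)
print(pred_cat.get_shape())  # (B, NUM_CATEGORY) category logits
print(pred_seg.get_shape())  # (B, N, N, N, K+1, NUM_SEG_PART) per-cell part logits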
Code example #3
File: pointnet_cls.py  Project: teddyz829/PointVox
def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}

    with tf.variable_scope('transform_net1') as sc:
        transform = input_transform_net(point_cloud,
                                        is_training,
                                        bn_decay,
                                        K=3)
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    input_image = tf.expand_dims(point_cloud_transformed, -1)

    net = tf_util.conv2d(input_image,
                         64, [1, 3],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv1',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv2',
                         bn_decay=bn_decay)

    with tf.variable_scope('transform_net2') as sc:
        transform = feature_transform_net(net, is_training, bn_decay, K=64)
    end_points['transform'] = transform
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    net_transformed = tf.expand_dims(net_transformed, [2])

    net = tf_util.conv2d(net_transformed,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv3',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         128, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv4',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         1024, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv5',
                         bn_decay=bn_decay)

    # Symmetric function: max pooling
    net = tf_util.max_pool2d(net, [num_point, 1],
                             padding='VALID',
                             scope='maxpool')

    # net = tf.reshape(net, [batch_size, -1])
    net = tf.reshape(net, [batch_size, 1, 1, 1, -1])
    net = tf_util.conv3d_transpose(net,
                                   64, [4, 4, 4],
                                   padding='VALID',
                                   stride=[1, 1, 1],
                                   scope='deconv1',
                                   bn=True,
                                   is_training=is_training)
    net = tf_util.conv3d_transpose(net,
                                   32, [6, 6, 6],
                                   padding='VALID',
                                   stride=[2, 2, 2],
                                   scope='deconv2',
                                   bn=True,
                                   is_training=is_training)
    net = tf_util.conv3d_transpose(net,
                                   1, [8, 8, 8],
                                   padding='VALID',
                                   stride=[2, 2, 2],
                                   scope='deconv3',
                                   bn=True,
                                   is_training=is_training)
    net = tf.reshape(net, [batch_size, -1])
    # net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
    #                               scope='fc1', bn_decay=bn_decay)
    # net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
    #                       scope='dp1')
    # net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
    #                               scope='fc2', bn_decay=bn_decay)
    # net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
    #                       scope='dp2')
    # net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    # return net, end_points
    return net, end_points
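
A short end-to-end sketch for this voxel-decoding variant follows; the module name `pointnet_cls` matches the file name in the header, while the batch size, point count, and session setup are illustrative assumptions.

# Hypothetical TF1.x driver. The (B, 30**3) output size follows from the
# three VALID transposed convolutions above: 1 -> 4 -> 12 -> 30 per axis.
import numpy as np
import tensorflow as tf

import pointnet_cls  # the file named in the header above

B, N = 8, 1024
point_cloud = tf.placeholder(tf.float32, [B, N, 3])
is_training = tf.placeholder(tf.bool, shape=())

net, end_points = pointnet_cls.get_model(point_cloud, is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(net, feed_dict={
        point_cloud: np.random.rand(B, N, 3).astype(np.float32),
        is_training: False})
    print(out.shape)  # (8, 27000): flattened 30x30x30 voxel grid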