예제 #1
0
def input_transform_net_edge_net(point_cloud, mask, is_training, bn_decay=None, K=4):
    """Predict a per-sample KxK input transform with an edge-unit backbone.

    The input goes through ``edge_net.edge_unit`` (max aggregation), two 1x1
    convolutions (128 and 1024 channels), a max pool over the point axis and
    two fully connected layers (512 and 256 units). A final linear layer
    whose weights start at zero and whose bias starts at the flattened
    identity regresses K*K values, so the initial transform is the identity.

    Args:
        point_cloud: Input tensor with static shape; dim 0 is the batch size
            and dim 1 the number of points. The remaining layout is whatever
            ``edge_net.edge_unit`` expects (TODO confirm).
        mask: Forwarded unchanged to ``edge_net.edge_unit``.
        is_training: Forwarded to every batch-norm layer.
        bn_decay: Batch-norm decay forwarded to every layer.
        K: Side length of the predicted square transform. Callers multiply
            K-wide feature vectors by the result, so K must match their
            feature width. Defaults to 4.

    Returns:
        Tensor of shape [batch_size, K, K].

    Raises:
        ValueError: If K is None or smaller than 1.
    """
    if K is None or K < 1:
        raise ValueError('K must be a positive integer, got %r' % (K,))

    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value

    net, _, _, _ = edge_net.edge_unit(point_cloud, mask, 'max', scope='tconv1',
                                      bn=True, is_training=is_training,
                                      bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1, 1],
                         padding='VALID', stride=[1, 1],
                         bn=True, is_training=is_training,
                         scope='tconv2', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1, 1],
                         padding='VALID', stride=[1, 1],
                         bn=True, is_training=is_training,
                         scope='tconv3', bn_decay=bn_decay)
    # Pool over the whole point axis -> one global descriptor per sample.
    net = tf_util.max_pool2d(net, [num_point, 1],
                             padding='VALID', scope='tmaxpool')

    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='tfc1', bn_decay=bn_decay)
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='tfc2', bn_decay=bn_decay)

    # Variable names/shapes are unchanged for the default K=4, so existing
    # checkpoints still restore.
    with tf.variable_scope('transform_XYZ'):
        weights = tf.get_variable('weights', [256, K * K],
                                  initializer=tf.constant_initializer(0.0),
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [K * K],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        # Zero weights plus an identity bias make the initial prediction the
        # identity matrix regardless of the features.
        biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
        transform = tf.matmul(net, weights)
        transform = tf.nn.bias_add(transform, biases)
        transform = tf.reshape(transform, [batch_size, K, K])
    return transform
예제 #2
0
def get_model_ec(group_data, mask, is_training, bn_decay=None):
    """Encode edge features built from grouped points into a global feature.

    Pipeline: ``econ.create_ec`` -> learned input transform sized to the edge
    feature length -> edge unit -> 1x1 conv (64) -> learned 64-wide feature
    transform -> 1x1 convs (64, 128, 1024) -> max pool over points ->
    fc 512 / dropout / fc 256 / dropout.

    Args:
        group_data: Grouped point tensor with static shape; dim 0 is the
            batch size and dim 1 the number of points (original note:
            "B N K*C").
        mask: Forwarded to ``econ.create_ec`` and the edge units.
        is_training: Forwarded to batch-norm and dropout layers.
        bn_decay: Batch-norm decay forwarded to every layer.

    Returns:
        Tuple ``(net, transform, max_index_ec, max_index_local_ec,
        masked_result)``:
          net: [batch_size, 256] global feature.
          transform: the feature transform from ``transform_net2`` (the input
            transform is not returned); presumably [batch_size, 64, 64].
          max_index_ec: per-channel argmax over the point axis of the
            1024-channel features, shape [batch_size, 1024].
          max_index_local_ec, masked_result: passed through from
            ``edge_net.edge_unit``.
    """
    batch_size = group_data.get_shape()[0].value
    num_point = group_data.get_shape()[1].value
    ec = econ.create_ec(group_data, mask)  # presumably B N K ec_length
    ec_length = ec.get_shape()[3].value
    ec = tf.reshape(ec, [batch_size, num_point, -1])
    with tf.variable_scope('transform_net1_ec'):
        transform = input_transform_net_edge_net(ec,
                                                 mask,
                                                 is_training,
                                                 bn_decay,
                                                 K=ec_length)

    # Apply the learned transform to every ec_length-wide edge feature.
    ec_transformed = tf.matmul(tf.reshape(ec, [batch_size, -1, ec_length]),
                               transform)
    ec_transformed = tf.reshape(ec_transformed, [batch_size, num_point, -1])
    input_image = ec_transformed

    with tf.variable_scope('ec_net1'):
        net, _, max_index_local_ec, masked_result = edge_net.edge_unit(
            input_image,
            mask,
            'max',
            config.neighbor_num,
            32,
            scope='conv1',
            bn=True,
            is_training=is_training,
            bn_decay=bn_decay)

    net = tf_util.conv2d(net,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv1',
                         bn_decay=bn_decay)

    with tf.variable_scope('transform_net2'):
        transform = feature_transform_net(net, is_training, bn_decay, K=64)

    # Squeeze only the singleton axis 2: a bare squeeze would also drop the
    # batch axis when batch_size == 1.
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    net_transformed = tf.expand_dims(net_transformed, [2])

    net = tf_util.conv2d(net_transformed,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv2',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         128, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv3',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         1024, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv4',
                         bn_decay=bn_decay)

    # Symmetric function: max pooling. Record which point wins each channel;
    # argmax over axis 1 leaves [B, 1, 1024], so squeeze only axis 1.
    max_index_ec = tf.squeeze(tf.argmax(net, 1), axis=[1])
    net = tf_util.max_pool2d(net, [num_point, 1],
                             padding='VALID',
                             scope='maxpool')

    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net,
                                  512,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc1',
                                  bn_decay=bn_decay)
    net = tf_util.dropout(net,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='dp1')
    net = tf_util.fully_connected(net,
                                  256,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc2',
                                  bn_decay=bn_decay)
    net = tf_util.dropout(net,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='dp2')

    return net, transform, max_index_ec, max_index_local_ec, masked_result
예제 #3
0
def get_model_groupdata(group_data, mask, is_training, bn_decay=None):
    """Encode grouped 4-channel point data into a global feature.

    Pipeline: learned 4x4 input transform -> edge unit -> 1x1 conv (64) ->
    learned 64-wide feature transform -> 1x1 convs (64, 128, 1024) -> max
    pool over points -> fc 512 / dropout / fc 256 / dropout. No class logits
    are produced here.

    Args:
        group_data: Grouped point tensor with static shape; dim 0 is the
            batch size and dim 1 the number of points. Its element count
            must be divisible by 4, since it is reshaped to
            [batch_size, -1, 4] for the input transform.
        mask: Forwarded to the edge units.
        is_training: Forwarded to batch-norm and dropout layers.
        bn_decay: Batch-norm decay forwarded to every layer.

    Returns:
        Tuple ``(net, transform, max_index_neighbor,
        max_index_local_neighbor)``:
          net: [batch_size, 256] global feature.
          transform: the feature transform from ``transform_net2`` (the input
            transform is not returned); presumably [batch_size, 64, 64].
          max_index_neighbor: per-channel argmax over the point axis of the
            1024-channel features, shape [batch_size, 1024].
          max_index_local_neighbor: passed through from
            ``edge_net.edge_unit``.
    """
    batch_size = group_data.get_shape()[0].value
    num_point = group_data.get_shape()[1].value

    with tf.variable_scope('transform_net1'):
        transform = input_transform_net_edge_net(group_data,
                                                 mask,
                                                 is_training,
                                                 bn_decay,
                                                 K=4)

    # Apply the 4x4 transform to every 4-wide group of input values.
    group_data_transformed = tf.matmul(
        tf.reshape(group_data, [batch_size, -1, 4]), transform)
    group_data_transformed = tf.reshape(group_data_transformed,
                                        [batch_size, num_point, -1])
    input_image = group_data_transformed
    with tf.variable_scope('edge_net1'):
        net, _, max_index_local_neighbor, _ = edge_net.edge_unit(
            input_image,
            mask,
            'max',
            config.neighbor_num,
            32,
            scope='conv1',
            bn=True,
            is_training=is_training,
            bn_decay=bn_decay)

    net = tf_util.conv2d(net,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv1',
                         bn_decay=bn_decay)

    with tf.variable_scope('transform_net2'):
        transform = feature_transform_net(net, is_training, bn_decay, K=64)

    # Squeeze only the singleton axis 2: a bare squeeze would also drop the
    # batch axis when batch_size == 1.
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    net_transformed = tf.expand_dims(net_transformed, [2])

    net = tf_util.conv2d(net_transformed,
                         64, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv2',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         128, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv3',
                         bn_decay=bn_decay)
    net = tf_util.conv2d(net,
                         1024, [1, 1],
                         padding='VALID',
                         stride=[1, 1],
                         bn=True,
                         is_training=is_training,
                         scope='conv4',
                         bn_decay=bn_decay)

    # Symmetric function: max pooling. Record which point wins each channel;
    # argmax over axis 1 leaves [B, 1, 1024], so squeeze only axis 1.
    max_index_neighbor = tf.squeeze(tf.argmax(net, 1), axis=[1])
    net = tf_util.max_pool2d(net, [num_point, 1],
                             padding='VALID',
                             scope='maxpool')

    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net,
                                  512,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc1',
                                  bn_decay=bn_decay)
    net = tf_util.dropout(net,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='dp1')
    net = tf_util.fully_connected(net,
                                  256,
                                  bn=True,
                                  is_training=is_training,
                                  scope='fc2',
                                  bn_decay=bn_decay)
    net = tf_util.dropout(net,
                          keep_prob=0.7,
                          is_training=is_training,
                          scope='dp2')

    return net, transform, max_index_neighbor, max_index_local_neighbor