Esempio n. 1
0
def get_model(point_cloud,
              is_training,
              num_class,
              global_pl,
              params,
              weight_decay=None,
              bn_decay=None,
              scname=''):
    '''Build the parameterised two-stage GAP segmentation network.

    Uses https://arxiv.org/pdf/1902.08570 as baseline.

    Args:
        point_cloud: input tensor, BxNxF. The first three features are used
            to build the initial k-nearest-neighbour graph.
        is_training: boolean scalar tensor. It is forwarded to batch norm and
            dropout, and ``tf.cond`` uses it to choose between raw outputs
            and softmax.
        num_class: number of output channels per point.
        global_pl: global features. They are reshaped to
            ``[batch_size, 1, 1, -1]`` and tiled over the N points.
        params: indexable hyper-parameters (at least 13 entries):
            [0] k (neighbours for both kNN graphs),
            [1] heads of the first gap_block, [2] its filter width,
            [3]/[4]/[5] widths of gapnet00/gapnet01/gapnet02,
            [6] heads of the second gap_block, [7] its filter width,
            [8] width of gapnet10 and of the 'agg' layer,
            [9]/[10] widths of gapnet11/gapnet12,
            [11] width of seg/conv2 and seg/conv3, [12] width of seg/conv4.
        weight_decay: forwarded to gap_block and the seg/* conv layers.
        bn_decay: batch-norm decay forwarded to every layer that takes it.
        scname: suffix appended to layer scopes (presumably so several
            copies of the model can live in one graph).

    Returns:
        (net, coefs, coefs2, adj_conv):
        - net: ``[batch_size, num_point, num_class]``; raw seg/conv5 output
          while training, softmax of it otherwise.
        - coefs, coefs2: third return value of the two gap_block calls
          (presumably attention coefficients).
        - adj_conv: kNN indices of the second, feature-space graph.
    '''
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    k = params[0]

    def conv1x1(inputs, num_out, scope, **extra):
        # 1x1 conv with batch norm. ``activation_fn`` (and the seg-head
        # options) are forwarded only when a layer sets them explicitly, so
        # layers that relied on tf_util.conv2d's default keep doing so.
        return tf_util.conv2d(inputs,
                              num_out, [1, 1],
                              padding='VALID',
                              stride=[1, 1],
                              bn=True,
                              is_training=is_training,
                              scope=scope,
                              bn_decay=bn_decay,
                              **extra)

    # Stage 1: graph attention over the kNN graph of the first 3 features.
    adj = tf_util.pairwise_distanceR(point_cloud[:, :, :3])
    nn_idx = tf_util.knn(adj, k=k)
    net, locals_transform, coefs = gap_block(k, params[1], nn_idx, point_cloud,
                                             point_cloud,
                                             ('filter0', params[2]), bn_decay,
                                             weight_decay, is_training, scname)

    # 'gapnet00' now carries scname like its gapnet01/02 siblings. Without
    # it, a second model built with another scname would reuse this scope.
    net00 = conv1x1(net, params[3], 'gapnet00' + scname)
    net01 = conv1x1(net00, params[4], 'gapnet01' + scname,
                    activation_fn=tf.nn.relu)
    net02 = conv1x1(net01, params[5], 'gapnet02' + scname,
                    activation_fn=tf.nn.relu)

    # Stage 2: graph attention over the kNN graph of the learned features.
    adj_conv = tf_util.knn(tf_util.pairwise_distance(net02), k=k)
    net, locals_transform1, coefs2 = gap_block(k, params[6], adj_conv, net02,
                                               point_cloud,
                                               ('filter1', params[7]),
                                               bn_decay, weight_decay,
                                               is_training, scname)

    # Same scname fix as 'gapnet00' above.
    net = conv1x1(net, params[8], 'gapnet10' + scname)
    # NOTE(review): unlike net00, the gapnet10 output is not part of the
    # skip concatenation below -- confirm this is intended.
    net11 = conv1x1(net, params[9], 'gapnet11' + scname,
                    activation_fn=tf.nn.relu)
    net12 = conv1x1(net11, params[10], 'gapnet12' + scname,
                    activation_fn=tf.nn.relu)

    # Broadcast the global features to every point and project them.
    global_expand = tf.reshape(global_pl, [batch_size, 1, 1, -1])
    global_expand = tf.tile(global_expand, [1, num_point, 1, 1])
    global_expand = conv1x1(global_expand, 16, 'global_expand' + scname)

    net = tf.concat([
        net00, net01, net02, net11, net12, global_expand, locals_transform,
        locals_transform1
    ],
                    axis=-1)
    net = conv1x1(net, params[8], 'agg' + scname, activation_fn=tf.nn.relu)

    # Average over all points, then broadcast the pooled vector back.
    net = tf_util.avg_pool2d(net, [num_point, 1],
                             padding='VALID',
                             scope='avgpool' + scname)
    net = tf.tile(net, [1, num_point, 1, 1])

    # Segmentation head.
    # NOTE(review): the seg/* scopes do not include scname, so two models
    # with different scname in one graph would still share these names.
    net = conv1x1(net, params[11], 'seg/conv2',
                  weight_decay=weight_decay, is_dist=True)
    net = tf_util.dropout(net,
                          keep_prob=0.6,
                          is_training=is_training,
                          scope='seg/dp1')
    net = conv1x1(net, params[11], 'seg/conv3',
                  weight_decay=weight_decay, is_dist=True)
    net = tf_util.dropout(net,
                          keep_prob=0.6,
                          is_training=is_training,
                          scope='seg/dp2')
    net = conv1x1(net, params[12], 'seg/conv4',
                  weight_decay=weight_decay, is_dist=True)

    logits = tf_util.conv2d(net,
                            num_class, [1, 1],
                            padding='VALID',
                            stride=[1, 1],
                            activation_fn=None,
                            bn=False,
                            scope='seg/conv5',
                            weight_decay=weight_decay,
                            is_dist=True)

    # Raw outputs while training, softmax probabilities otherwise.
    net = tf.cond(is_training, lambda: logits, lambda: tf.nn.softmax(logits))
    net = tf.reshape(net, [batch_size, num_point, num_class])

    return net, coefs, coefs2, adj_conv
def get_model(point_cloud,
              is_training,
              num_class,
              weight_decay=None,
              bn_decay=None,
              scname=''):
    '''Build the two-stage GAP classification network (k = 10, one head).

    Args:
        point_cloud: input tensor, BxNxF. The first three features define
            the initial k-nearest-neighbour graph.
        is_training: forwarded to every batch-normalised layer.
        num_class: width of the final fully connected layer.
        weight_decay: forwarded to both gap_block calls.
        bn_decay: batch-norm decay forwarded to the layers that use it.
        scname: suffix appended to every layer scope.

    Returns:
        (logits, max_pool):
        - logits: output of 'fc3' (no activation) passed through
          ``tf.squeeze``, so any size-1 dimension (e.g. batch_size == 1) is
          dropped.
        - max_pool: the point-wise *average*-pooled 'agg' features (the name
          is historical).
    '''
    cloud_shape = point_cloud.get_shape()
    batch_size = cloud_shape[0]
    num_point = cloud_shape[1]
    num_feat = cloud_shape[2]  # currently unused

    num_neighbors = 10
    heads = 1

    def relu_conv(inputs, channels, layer):
        # 1x1 conv + batch norm + ReLU, scoped with the model suffix.
        return tf_util.conv2d(inputs,
                              channels, [1, 1],
                              padding='VALID',
                              stride=[1, 1],
                              activation_fn=tf.nn.relu,
                              bn=True,
                              is_training=is_training,
                              scope=layer + scname,
                              bn_decay=bn_decay)

    def dense(inputs, units, layer):
        # Fully connected + batch norm + ReLU, scoped with the model suffix.
        return tf_util.fully_connected(inputs,
                                       units,
                                       bn=True,
                                       is_training=is_training,
                                       activation_fn=tf.nn.relu,
                                       scope=layer + scname,
                                       bn_decay=bn_decay)

    # Stage 1: attention over the kNN graph of the first three features.
    spatial_dist = tf_util.pairwise_distanceR(point_cloud[:, :, :3])
    spatial_idx = tf_util.knn(spatial_dist, k=num_neighbors)
    attn_out, local_feat0, _ = gap_block(num_neighbors, heads, spatial_idx,
                                         point_cloud, point_cloud,
                                         ('filter0', 16), bn_decay,
                                         weight_decay, is_training, scname)
    feat01 = relu_conv(attn_out, 64, 'gapnet01')
    feat02 = relu_conv(feat01, 128, 'gapnet02')

    # Stage 2: attention over the kNN graph of the learned features.
    feature_idx = tf_util.knn(tf_util.pairwise_distance(feat02),
                              k=num_neighbors)
    attn_out, local_feat1, _ = gap_block(num_neighbors, heads, feature_idx,
                                         feat02, point_cloud,
                                         ('filter1', 128), bn_decay,
                                         weight_decay, is_training, scname)
    feat11 = relu_conv(attn_out, 256, 'gapnet11')
    feat12 = relu_conv(feat11, 256, 'gapnet12')

    # Aggregate every intermediate representation, then average over points.
    stacked = tf.concat(
        [feat01, feat02, feat11, feat12, local_feat0, local_feat1], axis=-1)
    pooled = tf_util.avg_pool2d(relu_conv(stacked, 3, 'agg'), [num_point, 1],
                                padding='VALID',
                                scope='avgpool' + scname)

    # Classification head.
    hidden = dense(tf.reshape(pooled, [batch_size, -1]), 256, 'fc1')
    hidden = dense(hidden, 128, 'fc2')
    logits = tf_util.fully_connected(hidden,
                                     num_class,
                                     activation_fn=None,
                                     scope='fc3' + scname)

    return tf.squeeze(logits), pooled
Esempio n. 3
0
def get_model(point_cloud,
              is_training,
              num_class,
              global_pl,
              weight_decay=None,
              bn_decay=None,
              scname=''):
    '''Build the per-point GAP network with broadcast global features.

    Args:
        point_cloud: input tensor, BxNxF. The first three features define
            the initial k-nearest-neighbour graph (k = 10).
        is_training: forwarded to every batch-normalised layer.
        num_class: number of output channels per point.
        global_pl: global features. They are reshaped to
            ``[batch_size, 1, 1, -1]`` and tiled over the N points.
        weight_decay: forwarded to gap_block and the seg/* conv layers.
        bn_decay: batch-norm decay forwarded to the layers that use it.
        scname: suffix appended to most layer scopes (not the seg/* ones).

    Returns:
        (net, max_pool):
        - net: per-point output of shape
          ``[batch_size, num_point, num_class]`` (no final activation).
        - max_pool: the max-pooled 'agg' features that are broadcast back to
          every point.
    '''
    cloud_shape = point_cloud.get_shape()
    batch_size = cloud_shape[0].value
    num_point = cloud_shape[1].value
    num_feat = cloud_shape[2].value  # currently unused

    num_neighbors = 10
    heads = 1

    def relu_conv(inputs, channels, layer):
        # 1x1 conv + batch norm + ReLU, scoped with the model suffix.
        return tf_util.conv2d(inputs,
                              channels, [1, 1],
                              padding='VALID',
                              stride=[1, 1],
                              activation_fn=tf.nn.relu,
                              bn=True,
                              is_training=is_training,
                              scope=layer + scname,
                              bn_decay=bn_decay)

    def seg_conv(inputs, channels, layer):
        # 1x1 conv of the segmentation head. It uses tf_util.conv2d's
        # default activation, with weight decay and is_dist enabled.
        # NOTE(review): these scopes carry no scname suffix.
        return tf_util.conv2d(inputs,
                              channels, [1, 1],
                              padding='VALID',
                              stride=[1, 1],
                              bn_decay=bn_decay,
                              bn=True,
                              is_training=is_training,
                              scope=layer,
                              weight_decay=weight_decay,
                              is_dist=True)

    # Stage 1: attention over the kNN graph of the first three features.
    spatial_dist = tf_util.pairwise_distanceR(point_cloud[:, :, :3])
    spatial_idx = tf_util.knn(spatial_dist, k=num_neighbors)
    attn_out, local_feat0, _ = gap_block(num_neighbors, heads, spatial_idx,
                                         point_cloud, point_cloud,
                                         ('filter0', 16), bn_decay,
                                         weight_decay, is_training, scname)
    feat01 = relu_conv(attn_out, 64, 'gapnet01')
    feat02 = relu_conv(feat01, 128, 'gapnet02')

    # Stage 2: attention over the kNN graph of the learned features.
    feature_idx = tf_util.knn(tf_util.pairwise_distance(feat02),
                              k=num_neighbors)
    attn_out, local_feat1, _ = gap_block(num_neighbors, heads, feature_idx,
                                         feat02, point_cloud,
                                         ('filter1', 128), bn_decay,
                                         weight_decay, is_training, scname)
    feat11 = relu_conv(attn_out, 256, 'gapnet11')
    feat12 = relu_conv(feat11, 256, 'gapnet12')

    # Broadcast the global features to every point, then project them.
    global_feat = tf.reshape(global_pl, [batch_size, 1, 1, -1])
    global_feat = tf.tile(global_feat, [1, num_point, 1, 1])
    global_feat = tf_util.conv2d(global_feat,
                                 16, [1, 1],
                                 padding='VALID',
                                 stride=[1, 1],
                                 bn=True,
                                 is_training=is_training,
                                 scope='global_expand' + scname,
                                 bn_decay=bn_decay)

    stacked = tf.concat([
        feat01, feat02, feat11, feat12, global_feat, local_feat0, local_feat1
    ],
                        axis=-1)
    aggregated = relu_conv(stacked, 2, 'agg')
    # Max pooling over all points (the 'avgpool' scope name is historical).
    pooled = tf_util.max_pool2d(aggregated, [num_point, 1],
                                padding='VALID',
                                scope='avgpool' + scname)

    # Per-point head input: the pooled vector plus two skip connections.
    head = tf.concat(
        [tf.tile(pooled, [1, num_point, 1, 1]), feat01, feat11], axis=3)
    head = seg_conv(head, 256, 'seg/conv2')
    head = seg_conv(head, 128, 'seg/conv4')
    logits = tf_util.conv2d(head,
                            num_class, [1, 1],
                            padding='VALID',
                            stride=[1, 1],
                            activation_fn=None,
                            bn=False,
                            scope='seg/conv5',
                            weight_decay=weight_decay,
                            is_dist=True)

    return tf.reshape(logits, [batch_size, num_point, num_class]), pooled
Esempio n. 4
0
def get_model(point_cloud,
              mask,
              is_training,
              num_class,
              weight_decay=None,
              bn_decay=None,
              scname=''):
    '''Build the masked edge-conv + self-attention classifier (k = 20).

    Args:
        point_cloud: input tensor. Its first three features along the last
            axis define the initial k-nearest-neighbour graph.
        mask: point mask, forwarded to pairwise_distanceR and to every
            GetSelfAtt layer.
        is_training: forwarded to batch norm and dropout.
        num_class: width of the final fully connected layer.
        weight_decay: accepted for interface compatibility; not used here.
        bn_decay: batch-norm decay forwarded to the layers that use it.
        scname: string prepended to most layer scopes ('fc3' appends it).

    Returns:
        (net, attention1, attention2, attention3):
        - net: output of the last fully connected layer (no activation).
        - attention1..3: second return values of the three GetSelfAtt layers.
    '''
    batch_size = point_cloud.get_shape()[0]

    k = 20
    adj, mask_matrix = tf_util.pairwise_distanceR(point_cloud[:, :, :3], mask)
    nn_idx = tf_util.knn(adj, k=k)

    edge_feature_0 = get_edge_feature(point_cloud, nn_idx=nn_idx, k=k)
    features_0 = GetLocalFeat(edge_feature_0,
                              scname + 'local0',
                              128,
                              is_training,
                              bn_decay=bn_decay)

    # Second graph, rebuilt in the learned feature space.
    adj = tf_util.pairwise_distance(features_0, mask_matrix)
    nn_idx = tf_util.knn(adj, k=k)

    edge_feature_1 = get_edge_feature(features_0, nn_idx=nn_idx, k=k)
    features_1 = GetLocalFeat(edge_feature_1,
                              scname + 'local1',
                              64,
                              is_training,
                              bn_decay=bn_decay)

    # Squeeze once and reuse it for the attention input and the concat.
    # NOTE(review): tf.squeeze without an axis also drops the batch axis
    # when batch_size == 1 -- confirm that input size is never used.
    local_feat = tf.squeeze(features_1)

    self_att_1, attention1 = GetSelfAtt(local_feat,
                                        mask,
                                        scname + 'att1',
                                        64,
                                        is_training,
                                        bn_decay=bn_decay)
    self_att_2, attention2 = GetSelfAtt(self_att_1,
                                        mask,
                                        scname + 'att2',
                                        64,
                                        is_training,
                                        bn_decay=bn_decay)
    self_att_3, attention3 = GetSelfAtt(self_att_2,
                                        mask,
                                        scname + 'att3',
                                        64,
                                        is_training,
                                        bn_decay=bn_decay)

    concat = tf.concat([
        self_att_1,
        self_att_2,
        self_att_3,
        local_feat,
    ],
                       axis=-1)

    net = tf_util.conv1d(concat,
                         256,
                         1,
                         padding='VALID',
                         stride=1,
                         activation_fn=tf.nn.relu,
                         bn=True,
                         is_training=is_training,
                         scope=scname + 'concat',
                         bn_decay=bn_decay)

    # `keepdims` replaces the deprecated `keep_dims` alias, which TF2's
    # tf.reduce_mean no longer accepts.
    # NOTE(review): if the conv1d output is (batch, points, 256) (presumably),
    # axis=2 averages over channels, not points (axis=1) -- confirm intent.
    net = tf.reduce_mean(net, axis=2, keepdims=True)
    net = tf.reshape(net, [batch_size, -1])

    net = tf_util.fully_connected(net,
                                  128,
                                  bn=True,
                                  is_training=is_training,
                                  activation_fn=tf.nn.relu,
                                  scope=scname + 'fc1',
                                  bn_decay=bn_decay)
    net = tf_util.dropout(net,
                          keep_prob=0.5,
                          is_training=is_training,
                          scope=scname + 'dp1')
    net = tf_util.fully_connected(net,
                                  num_class,
                                  activation_fn=None,
                                  scope='fc3' + scname)

    return net, attention1, attention2, attention3