Ejemplo n.º 1
0
def get_model(point_cloud: tf.Tensor, features: tf.Tensor, is_training: tf.Variable, num_class: int, bn_decay=None) -> \
        'tuple[tf.Tensor, dict[str, tf.Tensor]]':
    """
    Return a PointNet++ segmentation model using additional features as input for the first layer.

    The network stacks four set-abstraction (SA) layers that sample npoint = 1024, 256, 64 and 16
    points, four feature-propagation (FP) layers that map the features back onto the input points,
    and a per-point head (1x1 conv with batch norm, dropout, 1x1 conv).

    :param point_cloud: Input points for the model (BxNx3)
    :param features: The features for each point (BxNxk); used as the layer-0 point features of the
        first SA layer and as the skip input of the last FP layer
    :param is_training: Training-mode flag forwarded to every layer (batch-norm / dropout mode)
    :param num_class: Number of classes (e.g. 21 for ScanNet)
    :param bn_decay: BatchNorm decay, forwarded to the SA/FP layers and to ``fc1``
    :return: Tuple ``(net, end_points)``.
        ``net``: unnormalized per-point class scores (B x N x num_class); ``fc2`` has no activation.
        ``end_points``: dict with ``'l0_xyz'`` (the input points) and ``'feats'`` (the 128-channel
        ``fc1`` output, taken before dropout).
    """
    end_points = {}
    l0_xyz = point_cloud
    l0_points = features
    end_points['l0_xyz'] = l0_xyz

    # Set-abstraction (encoder) layers; the grouping indices they return are not needed here.
    # Layer 1
    l1_xyz, l1_points, _ = pointnet_sa_module(l0_xyz,
                                              l0_points,
                                              npoint=1024,
                                              radius=0.1,
                                              nsample=32,
                                              mlp=[32, 32, 64],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer1')
    # Layer 2
    l2_xyz, l2_points, _ = pointnet_sa_module(l1_xyz,
                                              l1_points,
                                              npoint=256,
                                              radius=0.2,
                                              nsample=32,
                                              mlp=[64, 64, 128],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer2')
    # Layer 3
    l3_xyz, l3_points, _ = pointnet_sa_module(l2_xyz,
                                              l2_points,
                                              npoint=64,
                                              radius=0.4,
                                              nsample=32,
                                              mlp=[128, 128, 256],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer3')
    # Layer 4
    l4_xyz, l4_points, _ = pointnet_sa_module(l3_xyz,
                                              l3_points,
                                              npoint=16,
                                              radius=0.8,
                                              nsample=32,
                                              mlp=[256, 256, 512],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer4')

    # Feature Propagation layers: walk back from the coarsest level to the input points,
    # each step using the finer level's features as the skip input.
    l3_points = pointnet_fp_module(l3_xyz,
                                   l4_xyz,
                                   l3_points,
                                   l4_points, [256, 256],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer1')
    l2_points = pointnet_fp_module(l2_xyz,
                                   l3_xyz,
                                   l2_points,
                                   l3_points, [256, 256],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer2')
    l1_points = pointnet_fp_module(l1_xyz,
                                   l2_xyz,
                                   l1_points,
                                   l2_points, [256, 128],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer3')
    l0_points = pointnet_fp_module(l0_xyz,
                                   l1_xyz,
                                   l0_points,
                                   l1_points, [128, 128, 128],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer4')

    # Fully connected (1x1 conv) per-point head
    net = tf_util.conv1d(l0_points,
                         128,
                         1,
                         padding='VALID',
                         bn=True,
                         is_training=is_training,
                         scope='fc1',
                         bn_decay=bn_decay)
    # Exposed so callers can reuse the per-point embedding (taken before dropout).
    end_points['feats'] = net
    net = tf_util.dropout(net,
                          keep_prob=0.5,
                          is_training=is_training,
                          scope='dp1')
    # No activation: the caller receives raw logits (apply softmax / a logits-based loss outside).
    net = tf_util.conv1d(net,
                         num_class,
                         1,
                         padding='VALID',
                         activation_fn=None,
                         scope='fc2')

    return net, end_points
Ejemplo n.º 2
0
    def call(self, inputs, **kwargs):
        """Run the PointNet++ segmentation network on a batch of point coordinates.

        Four set-abstraction layers (``self.l1`` .. ``self.l4``) encode the cloud,
        four feature-propagation layers map the features back onto the input
        points, and a per-point head (1x1 conv with batch norm, dropout, 1x1 conv)
        produces class scores.

        :param inputs: Point coordinates, used as ``l0_xyz``
            (presumably BxNx3 -- confirm against the caller)
        :param kwargs: Accepted for Keras ``call`` compatibility but ignored; the
            training mode comes from ``self.is_training``, not a ``training`` kwarg
        :return: Unnormalized per-point class scores from ``fc2`` (no activation)
            with ``self.num_class`` channels
        """
        import logging

        logger = logging.getLogger(__name__)

        l0_xyz = inputs
        # No extra per-point input features. Layer 1 gets an empty tensor as a
        # placeholder, while the last FP layer below receives None.
        # NOTE(review): presumably the wrapped SA layer treats tf.zeros([0]) as
        # "no features" -- confirm, and consider using one convention for both.
        l0_points = None

        # Set-abstraction (encoder) layers; grouping indices are not needed here.
        l1_xyz, l1_points, _ = self.l1([l0_xyz, tf.zeros([0])])
        logger.debug("l1_points shape: %s", l1_points.shape)
        l2_xyz, l2_points, _ = self.l2([l1_xyz, l1_points])
        logger.debug("l2_points shape: %s", l2_points.shape)
        l3_xyz, l3_points, _ = self.l3([l2_xyz, l2_points])
        logger.debug("l3_points shape: %s", l3_points.shape)
        l4_xyz, l4_points, _ = self.l4([l3_xyz, l3_points])
        logger.debug("l4_points shape: %s", l4_points.shape)

        # Feature Propagation layers: coarsest level back to the input points.
        l3_points = pointnet_fp_module(l3_xyz,
                                       l4_xyz,
                                       l3_points,
                                       l4_points, [256, 256],
                                       self.is_training,
                                       self.bn_decay,
                                       scope='fa_layer1')
        l2_points = pointnet_fp_module(l2_xyz,
                                       l3_xyz,
                                       l2_points,
                                       l3_points, [256, 256],
                                       self.is_training,
                                       self.bn_decay,
                                       scope='fa_layer2')
        l1_points = pointnet_fp_module(l1_xyz,
                                       l2_xyz,
                                       l1_points,
                                       l2_points, [256, 128],
                                       self.is_training,
                                       self.bn_decay,
                                       scope='fa_layer3')
        l0_points = pointnet_fp_module(l0_xyz,
                                       l1_xyz,
                                       l0_points,
                                       l1_points, [128, 128, 128],
                                       self.is_training,
                                       self.bn_decay,
                                       scope='fa_layer4')

        # FC (1x1 conv) per-point head
        net = tf_util.conv1d(l0_points,
                             128,
                             1,
                             padding='VALID',
                             bn=True,
                             is_training=self.is_training,
                             scope='fc1',
                             bn_decay=self.bn_decay)
        net = tf_util.dropout(net,
                              keep_prob=0.5,
                              is_training=self.is_training,
                              scope='dp1')
        # No activation: raw logits are returned.
        out = tf_util.conv1d(net,
                             self.num_class,
                             1,
                             padding='VALID',
                             activation_fn=None,
                             scope='fc2')

        return out
def get_model(point_cloud, is_training, num_class, bn_decay=None):
    """
    Return a PointNet++ semantic segmentation model that uses only point coordinates as input.

    The network stacks four set-abstraction (SA) layers that sample npoint = 1024, 256, 64 and 16
    points, four feature-propagation (FP) layers that map the features back onto the input points,
    and a per-point head (1x1 conv with batch norm, dropout, 1x1 conv).

    :param point_cloud: Input points for the model (BxNx3)
    :param is_training: Training-mode flag forwarded to every layer (batch-norm / dropout mode)
    :param num_class: Number of output classes
    :param bn_decay: BatchNorm decay, forwarded to the SA/FP layers and to ``fc1``
    :return: Tuple ``(net, end_points)``.
        ``net``: unnormalized per-point class scores (B x N x num_class); ``fc2`` has no activation.
        ``end_points``: dict with ``'l0_xyz'`` (the input points) and ``'feats'`` (the 128-channel
        ``fc1`` output, taken before dropout).
    """
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None  # no extra per-point features are supplied to the first layer
    end_points['l0_xyz'] = l0_xyz

    # Set-abstraction (encoder) layers; the grouping indices they return are not needed here.
    # Layer 1
    l1_xyz, l1_points, _ = pointnet_sa_module(l0_xyz,
                                              l0_points,
                                              npoint=1024,
                                              radius=0.1,
                                              nsample=32,
                                              mlp=[32, 32, 64],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer1')
    # Layer 2
    l2_xyz, l2_points, _ = pointnet_sa_module(l1_xyz,
                                              l1_points,
                                              npoint=256,
                                              radius=0.2,
                                              nsample=32,
                                              mlp=[64, 64, 128],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer2')
    # Layer 3
    l3_xyz, l3_points, _ = pointnet_sa_module(l2_xyz,
                                              l2_points,
                                              npoint=64,
                                              radius=0.4,
                                              nsample=32,
                                              mlp=[128, 128, 256],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer3')
    # Layer 4
    l4_xyz, l4_points, _ = pointnet_sa_module(l3_xyz,
                                              l3_points,
                                              npoint=16,
                                              radius=0.8,
                                              nsample=32,
                                              mlp=[256, 256, 512],
                                              mlp2=None,
                                              group_all=False,
                                              is_training=is_training,
                                              bn_decay=bn_decay,
                                              scope='layer4')

    # Feature Propagation layers: walk back from the coarsest level to the input points,
    # each step using the finer level's features as the skip input.
    l3_points = pointnet_fp_module(l3_xyz,
                                   l4_xyz,
                                   l3_points,
                                   l4_points, [256, 256],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer1')
    l2_points = pointnet_fp_module(l2_xyz,
                                   l3_xyz,
                                   l2_points,
                                   l3_points, [256, 256],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer2')
    l1_points = pointnet_fp_module(l1_xyz,
                                   l2_xyz,
                                   l1_points,
                                   l2_points, [256, 128],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer3')
    l0_points = pointnet_fp_module(l0_xyz,
                                   l1_xyz,
                                   l0_points,
                                   l1_points, [128, 128, 128],
                                   is_training,
                                   bn_decay,
                                   scope='fa_layer4')

    # FC (1x1 conv) per-point head
    net = tf_util.conv1d(l0_points,
                         128,
                         1,
                         padding='VALID',
                         bn=True,
                         is_training=is_training,
                         scope='fc1',
                         bn_decay=bn_decay)
    # Exposed so callers can reuse the per-point embedding (taken before dropout).
    end_points['feats'] = net
    net = tf_util.dropout(net,
                          keep_prob=0.5,
                          is_training=is_training,
                          scope='dp1')
    # No activation: the caller receives raw logits.
    net = tf_util.conv1d(net,
                         num_class,
                         1,
                         padding='VALID',
                         activation_fn=None,
                         scope='fc2')

    return net, end_points