예제 #1
0
    def __init__(self,
                 in_channels,
                 stem_channels=(16, 32, 32),
                 local_channels=(128, 128),
                 seg_channels=(64, 64, 32),
                 dropout_prob_seg=0.2):
        """Build the stem, the local shared MLP and the segmentation head.

        Args:
            in_channels (int): number of input channels per point.
            stem_channels (tuple of int): widths of the stem layers.
            local_channels (tuple of int): widths of the local shared MLP.
            seg_channels (tuple of int): segmentation widths; all but the last
                go to ``mlp_seg``, the last one is produced by ``conv_seg``.
            dropout_prob_seg (float): dropout probability for ``mlp_seg``.
        """
        super(PointNet, self).__init__()
        self.in_channels = in_channels

        # Stem without TNet transforms, then the per-point local MLP.
        self.stem = Stem(in_channels, stem_channels, with_transform=False)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # Segmentation input width: all stem features, all local features, and
        # local_channels[-1] more channels (per the reference-repo note, the
        # global feature; the paper itself does not use the last local feature,
        # but this follows the released repo).
        seg_input_width = (sum(stem_channels)
                           + sum(local_channels)
                           + local_channels[-1])
        hidden_seg_widths = seg_channels[:-1]
        self.mlp_seg = SharedMLP(seg_input_width,
                                 hidden_seg_widths,
                                 dropout_prob=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)

        self.reset_parameters()
예제 #2
0
class ConcatHead(nn.Module):
    """Per-point 2-way head applied to (concatenated) features.

    A width-preserving shared MLP is followed by a 1x1 ``Conv1d`` with no ReLU
    and no batch norm that emits 2 channels per point.

    Args:
        in_channels (int): channel count of the (concatenated) input features.
        dropout_prob (float): dropout probability for the shared MLP.
    """

    def __init__(self, in_channels, dropout_prob=0.5):
        super(ConcatHead, self).__init__()
        self.mlp_local = SharedMLP(in_channels, (in_channels, ),
                                   dropout_prob=dropout_prob)
        self.conv1d = Conv1d(in_channels, 2, 1, relu=False, bn=False)
        self.reset_parameters()

    def forward(self, concat_feats):
        """Compute 2 output channels for every point.

        Args:
            concat_feats: a tensor, or a list/tuple of tensors that is first
                concatenated along dim 1 (channels). Presumably shaped
                (batch_size, channels, num_points) -- TODO confirm with callers.

        Returns:
            Output of ``conv1d``; presumably (batch_size, 2, num_points).
        """
        needs_concat = isinstance(concat_feats, (list, tuple))
        features = torch.cat(concat_feats, dim=1) if needs_concat else concat_feats
        hidden = self.mlp_local(features)
        return self.conv1d(hidden)

    def reset_parameters(self):
        """Xavier-uniform init for both layers, then ``set_bn`` with momentum 0.01."""
        for layer in (self.mlp_local, self.conv1d):
            layer.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
예제 #3
0
    def __init__(self,
                 in_channels,
                 mlp_channels,
                 num_centroids,
                 radius,
                 num_neighbours,
                 use_xyz):
        """Configure one single-scale set-abstraction stage.

        Args:
            in_channels (int): feature channels of the input points (xyz excluded).
            mlp_channels (tuple of int): widths of the 2-D shared MLP; the last
                one becomes ``out_channels``.
            num_centroids (int): points kept by farthest point sampling;
                a value ``<= 0`` disables sampling (``sampler`` is None).
            radius (float): query radius; must be negative when grouping is
                disabled and positive otherwise (checked with ``assert``).
            num_neighbours (int): neighbours per centroid; a negative value
                disables grouping (``grouper`` is None).
            use_xyz (bool): if True, the MLP gets 3 extra input channels.
        """
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]
        self.num_centroids = num_centroids
        self.use_xyz = use_xyz

        mlp_in_channels = in_channels + 3 if use_xyz else in_channels
        self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

        self.sampler = None if num_centroids <= 0 else FarthestPointSampler(num_centroids)

        if num_neighbours < 0:
            # Grouping disabled: the radius must be "disabled" (negative) too.
            assert radius < 0.0
            self.grouper = None
        else:
            assert num_neighbours > 0 and radius > 0.0
            self.grouper = QueryGrouper(radius, num_neighbours)
예제 #4
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(16, 32),
                 local_channels=(32, 32),
                 global_channels=(32, 32),
                 dropout_prob=0.5,
                 with_transform=False):
        """Assemble the PointNet classification network.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of outputs of ``classifier``.
            stem_channels (tuple of int): widths of the stem layers.
            local_channels (tuple of int): widths of the local shared MLP.
            global_channels (tuple of int): widths of the global MLP.
            dropout_prob (float): dropout probability for the global MLP.
            with_transform (bool): whether the stem uses TNet transforms.
        """
        super(PointNetCls, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels,
                              dropout_prob=dropout_prob)

        head_width = global_channels[-1]
        self.classifier = nn.Linear(head_width, out_channels, bias=True)
        # Extra linear head with a fixed 2-way output; its use is not visible here.
        self.classifier2 = nn.Linear(head_width, 2, bias=True)

        self.reset_parameters()
예제 #5
0
class PointnetFPModule(nn.Module):
    """PointNet feature propagation module.

    Carries features from a sparse point set back onto a denser one, then
    refines them with a 1-D shared MLP.

    Args:
        in_channels (int): channels fed to the shared MLP, i.e. the propagated
            sparse channels plus the dense skip channels.
        mlp_channels (tuple of int): widths of the shared MLP; the last one is
            ``out_channels``.
        num_neighbors (int): 0 to broadcast a single sparse point onto every
            dense point, or 3 to use ``FeatureInterpolator``.

    Raises:
        ValueError: if ``num_neighbors`` is neither 0 nor 3.
    """

    def __init__(self,
                 in_channels,
                 mlp_channels,
                 num_neighbors):
        super(PointnetFPModule, self).__init__()

        # Validate before building any sub-module so a bad config fails fast.
        # The accepted values are 0 and 3 (the old message wrongly said "1 or 3").
        if num_neighbors not in (0, 3):
            raise ValueError('Expected value 0 or 3, but {} given.'.format(num_neighbors))

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)
        if num_neighbors == 0:
            self.interpolator = None
        else:
            self.interpolator = FeatureInterpolator(num_neighbors)

    def forward(self, dense_xyz, sparse_xyz, dense_feature, sparse_feature):
        """Propagate ``sparse_feature`` onto the dense points and apply the MLP.

        Args:
            dense_xyz: dense coordinates; dim 2 indexes the dense points.
            sparse_xyz: sparse coordinates; dim 2 indexes the sparse points.
            dense_feature: skip features of the dense points, concatenated
                along dim 1; may be None when that level carries no features.
            sparse_feature: features of the sparse points.

        Returns:
            The shared-MLP output over the dense points.
        """
        if self.interpolator is None:
            # Broadcast mode: exactly one sparse point, copied to every dense point.
            assert sparse_xyz.size(2) == 1 and sparse_feature.size(2) == 1
            sparse_feature_expand = sparse_feature.expand(-1, -1, dense_xyz.size(2))
            if dense_feature is None:
                new_feature = sparse_feature_expand
            else:
                new_feature = torch.cat([sparse_feature_expand, dense_feature], dim=1)
        else:
            new_feature = self.interpolator(dense_xyz, sparse_xyz, dense_feature, sparse_feature)
        return self.mlp(new_feature)

    def reset_parameters(self, init_fn=None):
        """Re-initialise the shared MLP, forwarding ``init_fn`` unchanged."""
        self.mlp.reset_parameters(init_fn)
예제 #6
0
    def __init__(self,
                 in_channels,
                 num_centroids=(128, 32, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(64, 64, -1),
                 sa_channels=((16, 16, 32), (32, 32, 64), (128, 128, 256)),
                 fp_channels=((64, 64), (64, 32), (32, 32, 32)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(32, ),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Build a PointNet++ single-scale-grouping encoder/decoder.

        Args:
            in_channels (int): number of input channels (3 of them are excluded
                from the first SA stage's feature channels).
            num_centroids (tuple of int): centroids sampled by each SA stage.
            radius (tuple of float): query radius of each SA stage.
            num_neighbours (tuple of int): neighbours per centroid in each SA stage.
            sa_channels (tuple of tuple of int): MLP widths of each SA stage.
            fp_channels (tuple of tuple of int): MLP widths of each FP stage.
            num_fp_neighbours (tuple of int): neighbours used by each FP stage.
            seg_channels (tuple of int): widths of the final shared MLP.
            dropout_prob (float): dropout probability of the final shared MLP.
            use_xyz (bool): whether xyz is used as a feature.
        """
        super(PointNet2SSG, self).__init__()

        self.in_channels = in_channels
        self.use_xyz = use_xyz

        # Per-layer settings must agree in length; the FP stack mirrors the SA stack.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius) == num_sa_layers
        assert len(num_neighbours) == num_sa_layers
        assert len(sa_channels) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set abstraction: each stage consumes the previous stage's output width.
        self.sa_modules = nn.ModuleList()
        prev_width = in_channels - 3
        for idx, centroids in enumerate(num_centroids):
            self.sa_modules.append(
                PointNetSAModule(in_channels=prev_width,
                                 mlp_channels=sa_channels[idx],
                                 num_centroids=centroids,
                                 radius=radius[idx],
                                 num_neighbours=num_neighbours[idx],
                                 use_xyz=use_xyz))
            prev_width = sa_channels[idx][-1]

        # Channel width at each level: the raw input first, then each SA output.
        level_widths = [in_channels if use_xyz else in_channels - 3]
        level_widths += [widths[-1] for widths in sa_channels]

        # Feature propagation: walk back up, fusing each level's skip features.
        self.fp_modules = nn.ModuleList()
        prev_width = level_widths[-1]
        for idx, fp_widths in enumerate(fp_channels):
            skip_width = level_widths[-2 - idx]
            self.fp_modules.append(
                PointnetFPModule(in_channels=prev_width + skip_width,
                                 mlp_channels=fp_widths,
                                 num_neighbors=num_fp_neighbours[idx]))
            prev_width = fp_widths[-1]

        self.mlp_seg = SharedMLP(prev_width,
                                 seg_channels,
                                 ndim=1,
                                 dropout_prob=dropout_prob)

        self.reset_parameters()
예제 #7
0
 def __init__(self, in_channels, dropout_prob=0.5):
     """Build a width-preserving shared MLP and a 2-channel 1x1 output conv.

     Args:
         in_channels (int): channels of the input features.
         dropout_prob (float): dropout probability for the shared MLP.
     """
     super(ConcatHead, self).__init__()
     hidden_widths = (in_channels, )
     self.mlp_local = SharedMLP(in_channels, hidden_widths, dropout_prob=dropout_prob)
     # Plain projection to 2 channels: no ReLU and no batch norm on this layer.
     self.conv1d = Conv1d(in_channels, 2, 1, relu=False, bn=False)
     self.reset_parameters()
예제 #8
0
class Stem(nn.Module):
    """Stem (main body or stalk). Extract features from raw point clouds.

    Args:
        in_channels (int): number of input channels.
        stem_channels (tuple of int): widths of the shared MLP layers; the last
            one is ``out_channels``.
        with_transform (bool): whether to learn a TNet transform for the input
            and another one for the stem output.
    """
    def __init__(self,
                 in_channels,
                 stem_channels=(64, 128, 128),
                 with_transform=True):
        super(Stem, self).__init__()

        self.in_channels = in_channels
        self.out_channels = stem_channels[-1]
        self.with_transform = with_transform

        # Shared MLP, Xavier-initialised immediately.
        self.mlp = SharedMLP(in_channels, stem_channels)
        self.mlp.reset_parameters(xavier_uniform)

        if with_transform:
            self.transform_input = TNet(in_channels, in_channels)
            self.transform_feature = TNet(self.out_channels, self.out_channels)

    def forward(self, x):
        """PointNet Stem forward.

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, stem_channels[-1], num_points)
            dict: always contains 'stem_features' (list of every layer's output,
                in order); when ``with_transform`` is set it also contains
                'trans_input' (batch_size, in_channels, in_channels) and
                'trans_feature' (batch_size, stem_channels[-1], stem_channels[-1]).
        """
        end_points = {}

        if self.with_transform:
            end_points['trans_input'] = self.transform_input(x)
            x = torch.bmm(end_points['trans_input'], x)

        # Keep every intermediate layer output.
        layer_outputs = []
        for layer in self.mlp:
            x = layer(x)
            layer_outputs.append(x)
        end_points['stem_features'] = layer_outputs

        if self.with_transform:
            end_points['trans_feature'] = self.transform_feature(x)
            x = torch.bmm(end_points['trans_feature'], x)

        return x, end_points
예제 #9
0
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True,
                 use_one_hot=True):
        """Build PointNet for part segmentation with an optional classification head.

        Args:
            in_channels (int): number of input channels.
            num_classes (int): number of object classes.
            num_seg_classes (int): number of part (segmentation) classes.
            stem_channels (tuple of int): widths of the stem layers.
            local_channels (tuple of int): widths of the local shared MLP.
            cls_channels (tuple of int): widths of the classification MLP; empty
                disables the classification head.
            seg_channels (tuple of int): segmentation widths; all but the last
                go to ``mlp_seg``, the last one is produced by ``conv_seg``.
            dropout_prob_cls (float): dropout probability of the classification MLP.
            dropout_prob_seg (float): dropout probability of the segmentation MLP.
            with_transform (bool): whether the stem uses TNet transforms.
            use_one_hot (bool): whether ``num_classes`` extra channels (the one-hot
                class vector) are added to the segmentation input.
        """
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_one_hot = use_one_hot

        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # Segmentation input: the reference repo concatenates the global feature,
        # the one-hot class embedding, the stem features and the local features;
        # the paper omits the last local feature, but this follows the repo.
        seg_input_width = sum(stem_channels) + sum(local_channels) + local_channels[-1]
        if use_one_hot:
            seg_input_width += num_classes
        self.mlp_seg = SharedMLP(seg_input_width,
                                 seg_channels[:-1],
                                 dropout_prob=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        # Optional classification head; dropout is applied inside its MLP.
        if len(cls_channels) == 0:
            self.mlp_cls = None
            self.cls_logit = None
        else:
            self.mlp_cls = MLP(local_channels[-1],
                               cls_channels,
                               dropout_prob=dropout_prob_cls)
            self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

        self.reset_parameters()
예제 #10
0
class TNet(nn.Module):
    """Transformation Network for DGCNN

    An EdgeConv block, a local shared MLP, a max over points, a global MLP and
    a zero-initialised linear layer predict one matrix per sample, to which
    ``torch.eye(out_channels, in_channels)`` is added. Right after
    ``reset_parameters`` the output therefore equals that identity-like matrix.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for edge feature extractor

    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        self.linear = nn.Linear(global_channels[-1],
                                self.in_channels * self.out_channels,
                                bias=True)

        self.reset_parameters()

    def forward(self, x):
        """Predict a transform matrix for every sample.

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels), with the
                same dtype and device as the linear layer's output.
        """
        x = self.edge_conv(x)  # (batch_size, conv_channels[-1], num_points)
        x = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        x, _ = torch.max(x, 2)  # max over points
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        # Build the identity in x's dtype: without it, adding a float32 eye to a
        # half-precision x type-promotes the result to float32.
        I = torch.eye(self.out_channels,
                      self.in_channels,
                      dtype=x.dtype,
                      device=x.device)
        x = x.add(I)  # broadcast over the batch dimension
        return x

    def reset_parameters(self, init_fn=xavier_uniform):
        """Initialise sub-networks with ``init_fn`` and zero the output linear layer."""
        self.edge_conv.reset_parameters(init_fn)
        self.mlp_local.reset_parameters(init_fn)
        self.mlp_global.reset_parameters(init_fn)
        # Zero weight and bias so the initial prediction is exactly the added eye.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
예제 #11
0
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 edge_conv_channels=((64, 64), (64, 64), 64),
                 local_channels=(1024, ),
                 seg_channels=(256, 256, 128),
                 k=20,
                 dropout_prob=0.4,
                 with_transform=True):
        """Build DGCNN for part segmentation.

        Args:
            in_channels (int): number of input channels.
            num_classes (int): number of object classes (input width of ``mlp_label``).
            num_seg_classes (int): number of part (segmentation) classes.
            edge_conv_channels (tuple): one entry per EdgeConv block; an int is a
                single-layer block, a tuple/list lists that block's widths.
            local_channels (tuple of int): widths of the local shared MLP.
            seg_channels (tuple of int): segmentation widths; all but the last
                go to ``mlp_seg``, the last one is produced by ``conv_seg``.
            k (int): number of nearest neighbours for the edge features.
            dropout_prob (float): dropout probability of ``mlp_seg``.
            with_transform (bool): whether to create the input TNet.
        """
        super(DGCNNPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.with_transform = with_transform

        if with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # Chain the EdgeConv blocks; each one consumes the previous block's width.
        self.edge_convs = nn.ModuleList()
        block_out_widths = []
        current_width = in_channels
        for spec in edge_conv_channels:
            if isinstance(spec, int):
                widths = [spec]
            else:
                assert isinstance(spec, (tuple, list))
                widths = spec
            self.edge_convs.append(EdgeConvBlock(current_width, widths, k))
            block_out_widths.append(widths[-1])
            current_width = widths[-1]
        edge_width = sum(block_out_widths)

        LABEL_CHANNELS = 64  # width of the embedded class-label features
        self.mlp_label = Conv1d(self.num_classes, LABEL_CHANNELS, 1)
        self.mlp_local = SharedMLP(edge_width, local_channels)

        seg_input_width = edge_width + local_channels[-1] + LABEL_CHANNELS
        self.mlp_seg = SharedMLP(seg_input_width,
                                 seg_channels[:-1],
                                 dropout_prob=dropout_prob)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.reset_parameters()
예제 #12
0
    def __init__(self,
                 in_channels,
                 mlp_channels_list,
                 num_centroids,
                 radius_list,
                 num_neighbours_list,
                 use_xyz):
        """Configure a multi-scale-grouping set-abstraction stage.

        Args:
            in_channels (int): feature channels of the input points (xyz excluded).
            mlp_channels_list (sequence of tuple of int): shared-MLP widths per
                scale; ``out_channels`` is the sum of their last entries.
            num_centroids (int): points kept by farthest point sampling; -1
                disables sampling, any other value must be positive.
            radius_list (sequence of float): query radius per scale.
            num_neighbours_list (sequence of int): neighbours per centroid per scale.
            use_xyz (bool): if True, each scale's MLP gets 3 extra input channels.
        """
        super(PointNetSAModuleMSG, self).__init__()

        self.in_channels = in_channels
        self.out_channels = sum(widths[-1] for widths in mlp_channels_list)
        self.num_centroids = num_centroids
        self.use_xyz = use_xyz

        num_scales = len(mlp_channels_list)
        assert len(radius_list) == num_scales
        assert len(num_neighbours_list) == num_scales

        self.mlp = nn.ModuleList()
        if num_centroids != -1:
            assert num_centroids > 0
            self.sampler = FarthestPointSampler(num_centroids)
        else:
            self.sampler = None
        self.grouper = nn.ModuleList()

        scale_in_channels = in_channels + 3 if self.use_xyz else in_channels
        for scale, widths in enumerate(mlp_channels_list):
            self.mlp.append(SharedMLP(scale_in_channels, widths, ndim=2, bn=True))
            self.grouper.append(QueryGrouper(radius_list[scale], num_neighbours_list[scale]))
예제 #13
0
    def __init__(self,
                 in_channels,
                 mlp_channels,
                 num_neighbors):
        """Build the feature-propagation MLP and optional interpolator.

        Args:
            in_channels (int): channels fed to the shared MLP.
            mlp_channels (tuple of int): widths of the shared MLP; the last one
                is ``out_channels``.
            num_neighbors (int): 0 for no interpolator, 3 for ``FeatureInterpolator``.

        Raises:
            ValueError: if ``num_neighbors`` is neither 0 nor 3.
        """
        super(PointnetFPModule, self).__init__()

        # Validate before building any sub-module so a bad config fails fast.
        # The accepted values are 0 and 3 (the old message wrongly said "1 or 3").
        if num_neighbors not in (0, 3):
            raise ValueError('Expected value 0 or 3, but {} given.'.format(num_neighbors))

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)
        if num_neighbors == 0:
            self.interpolator = None
        else:
            self.interpolator = FeatureInterpolator(num_neighbors)
예제 #14
0
    def __init__(self,
                 in_channels,
                 stem_channels=(64, 64),
                 with_transform=True):
        """Build the shared-MLP stem and, optionally, its two TNets.

        Args:
            in_channels (int): number of input channels.
            stem_channels (tuple of int): widths of the stem layers; the last one
                is ``out_channels``.
            with_transform (bool): also create ``transform_input`` (on the input
                channels) and ``transform_feature`` (on the stem output).
        """
        super(Stem, self).__init__()

        out_channels = stem_channels[-1]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        self.mlp = SharedMLP(in_channels, stem_channels)

        if with_transform:
            self.transform_input = TNet(in_channels, in_channels)
            self.transform_feature = TNet(out_channels, out_channels)
예제 #15
0
class TNet(nn.Module):
    """Transformation Network. The structure is similar with PointNet.

    A per-point shared MLP, a max over points, a global MLP and a linear layer
    predict one (out_channels x in_channels) matrix per sample, to which
    ``torch.eye(out_channels, in_channels)`` is added. The linear layer is
    zero-initialised, so the freshly reset network outputs that matrix.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        local_channels (tuple of int): widths of the per-point shared MLP.
        global_channels (tuple of int): widths of the global MLP.
    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.mlp_local = SharedMLP(in_channels, local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        # Flattened matrix head: in_channels * out_channels outputs.
        self.linear = nn.Linear(global_channels[-1],
                                in_channels * out_channels,
                                bias=True)

        self.reset_parameters()

    def forward(self, x):
        """Return (batch_size, out_channels, in_channels) transform matrices.

        ``x`` is (batch_size, in_channels, num_points).
        """
        per_point = self.mlp_local(x)
        pooled, _ = torch.max(per_point, 2)
        flat = self.linear(self.mlp_global(pooled))
        matrices = flat.view(-1, self.out_channels, self.in_channels)
        eye = torch.eye(self.out_channels,
                        self.in_channels,
                        dtype=matrices.dtype,
                        device=matrices.device)
        return matrices.add(eye)  # broadcast over the batch dimension

    def reset_parameters(self, init_fn=xavier_uniform):
        """Initialise both MLPs with ``init_fn`` and zero the output linear layer."""
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.reset_parameters(init_fn)
        # Zero weight and bias so the initial prediction is exactly the added eye.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
예제 #16
0
    def __init__(self,
                 in_channels,
                 stem_channels=(64, 128, 128),
                 with_transform=True):
        """Build the shared-MLP stem (Xavier-initialised) and optional TNets.

        Args:
            in_channels (int): number of input channels.
            stem_channels (tuple of int): widths of the stem layers; the last one
                is ``out_channels``.
            with_transform (bool): also create ``transform_input`` and
                ``transform_feature``.
        """
        super(Stem, self).__init__()

        out_channels = stem_channels[-1]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        stem_mlp = SharedMLP(in_channels, stem_channels)
        self.mlp = stem_mlp
        stem_mlp.reset_parameters(xavier_uniform)

        if with_transform:
            self.transform_input = TNet(in_channels, in_channels)
            self.transform_feature = TNet(out_channels, out_channels)
예제 #17
0
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        """Create the PointNet-style transformation network.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            local_channels (tuple of int): widths of the per-point shared MLP.
            global_channels (tuple of int): widths of the global MLP.
        """
        super(TNet, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels

        matrix_size = in_channels * out_channels

        self.mlp_local = SharedMLP(in_channels, local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        self.linear = nn.Linear(global_channels[-1], matrix_size, bias=True)

        self.reset_parameters()
예제 #18
0
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        """Create the DGCNN transformation network.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            conv_channels (tuple of int): widths of the EdgeConv layers.
            local_channels (tuple of int): widths of the local shared MLP.
            global_channels (tuple of int): widths of the global MLP.
            k (int): number of nearest neighbours for the edge features.
        """
        super(TNet, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels

        matrix_size = self.in_channels * self.out_channels

        self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        self.linear = nn.Linear(global_channels[-1], matrix_size, bias=True)

        self.reset_parameters()
예제 #19
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_conv_channels=(64, 64, 64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20,
                 dropout_prob=0.5,
                 with_transform=True):
        """Build DGCNN for classification.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of classifier outputs.
            edge_conv_channels (tuple): one entry per EdgeConv block; an int is a
                single-layer block, a tuple/list lists that block's widths.
            local_channels (tuple of int): widths of the local shared MLP applied
                to the concatenated EdgeConv outputs.
            global_channels (tuple of int): widths of the global MLP.
            k (int): number of nearest neighbours for the edge features.
            dropout_prob (float): dropout probability of the global MLP.
            with_transform (bool): whether to create the input TNet.
        """
        super(DGCNNCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        if with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        self.edge_convs = nn.ModuleList()
        block_out_widths = []
        current_width = in_channels
        for spec in edge_conv_channels:
            if isinstance(spec, int):
                widths = [spec]
            else:
                assert isinstance(spec, (tuple, list))
                widths = spec
            self.edge_convs.append(EdgeConvBlock(current_width, widths, k))
            block_out_widths.append(widths[-1])
            current_width = widths[-1]

        self.mlp_local = SharedMLP(sum(block_out_widths), local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], self.out_channels, bias=True)

        self.reset_parameters()
예제 #20
0
class DGCNNCls(nn.Module):
    """DGCNN for classification

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        edge_conv_channels (tuple): one entry per edge convolution block; an int
            is a single-layer block, a tuple/list gives that block's channel numbers
        local_channels (tuple of int): the numbers of channels in the local mlp
            applied to the concatenated edge convolution outputs
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for edge feature extractor
        dropout_prob (float): the probability to dropout (in the global mlp)
        with_transform (bool): whether to use TNet to transform the input points.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_conv_channels=(64, 64, 64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20,
                 dropout_prob=0.5,
                 with_transform=True):
        super(DGCNNCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        # input transform
        if self.with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # Chain EdgeConv blocks; each consumes the previous block's output width.
        self.edge_convs = nn.ModuleList()
        inter_channels = []
        for conv_channels in edge_conv_channels:
            if isinstance(conv_channels, int):
                conv_channels = [conv_channels]
            else:
                assert isinstance(conv_channels, (tuple, list))
            self.edge_convs.append(EdgeConvBlock(in_channels, conv_channels,
                                                 k))
            inter_channels.append(conv_channels[-1])
            in_channels = conv_channels[-1]
        self.mlp_local = SharedMLP(sum(inter_channels), local_channels)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1],
                                    self.out_channels,
                                    bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Classify a batch of point clouds.

        Args:
            data_batch (dict): must contain 'points'; presumably
                (batch_size, in_channels, num_points) -- TODO confirm.

        Returns:
            dict: 'cls_logit' (classifier output), 'key_point_indices' (argmax
                over points of each local-MLP channel) and, when
                ``with_transform`` is set, 'trans_input'.
        """
        x = data_batch['points']
        end_points = {}

        # input transform
        if self.with_transform:
            trans_input = self.transform_input(x)
            x = torch.bmm(trans_input, x)
            end_points['trans_input'] = trans_input

        # EdgeConv: keep every block's output for the multi-level concatenation.
        features = []
        for edge_conv in self.edge_convs:
            x = edge_conv(x)
            features.append(x)

        x = torch.cat(features, dim=1)

        x = self.mlp_local(x)
        x, max_indices = torch.max(x, 2)  # max over points
        end_points['key_point_indices'] = max_indices
        x = self.mlp_global(x)
        x = self.classifier(x)
        preds = {
            'cls_logit': x,
        }
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        """Xavier-uniform init for all layers, then ``set_bn`` with momentum 0.01."""
        for edge_conv in self.edge_convs:
            edge_conv.reset_parameters(xavier_uniform)
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_global.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        set_bn(self, momentum=0.01)
예제 #21
0
class PointNet2SSGPartSeg(nn.Module):
    """PointNet++ part segmentation with single-scale grouping

    PointNetSA: PointNet Set Abstraction Layer
    PointNetFP: PointNet Feature Propagation Layer

    Args:
        in_channels (int): the number of input channels (the first 3 are xyz)
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
        radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
        num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
        sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
        fp_channels (tuple of tuple of int): the numbers of channels for feature propagation (FP) module
        num_fp_neighbours (tuple of int): the numbers of nearest neighbor used in FP
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob (float): the probability to dropout input features
        use_xyz (bool): whether or not to use the xyz position of a points as a feature
        use_one_hot (bool): whether to use one hot vector of class labels.

    References:
        https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(64, 64, -1),
                 sa_channels=((64, 64, 128), (128, 128, 256), (256, 512, 1024)),
                 fp_channels=((256, 256), (256, 128), (128, 128, 128)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(128,),
                 dropout_prob=0.5,
                 use_xyz=True,
                 use_one_hot=True):
        super(PointNet2SSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz
        self.use_one_hot = use_one_hot

        # Sanity check
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius) == num_sa_layers
        assert len(num_neighbours) == num_sa_layers
        assert len(sa_channels) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers
        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(num_sa_layers):
            sa_module = PointNetSAModule(in_channels=feature_channels,
                                         mlp_channels=sa_channels[ind],
                                         num_centroids=num_centroids[ind],
                                         radius=radius[ind],
                                         num_neighbours=num_neighbours[ind],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            feature_channels = sa_channels[ind][-1]

        # Level-0 width may be 0 (use_xyz=False, xyz-only input); the one-hot
        # channels are then the only level-0 features.
        inter_channels = [in_channels if use_xyz else in_channels - 3]
        if self.use_one_hot:
            inter_channels[0] += num_classes  # concat with one-hot
        inter_channels.extend([x[-1] for x in sa_channels])

        # Feature Propagation Layers
        self.fp_modules = nn.ModuleList()
        feature_channels = inter_channels[-1]
        for ind in range(num_fp_layers):
            fp_module = PointnetFPModule(in_channels=feature_channels + inter_channels[-2 - ind],
                                         mlp_channels=fp_channels[ind],
                                         num_neighbors=num_fp_neighbours[ind])
            self.fp_modules.append(fp_module)
            feature_channels = fp_channels[ind][-1]

        # MLP
        self.mlp_seg = SharedMLP(feature_channels, seg_channels, ndim=1, dropout_prob=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.reset_parameters()

    @staticmethod
    def _concat_one_hot(feature, cls_label, num_classes, num_points, dtype):
        """Append a one-hot class embedding, repeated over every point, to ``feature``.

        Args:
            feature: (batch_size, C, num_points) tensor, or None.
            cls_label: integer class indices of shape (batch_size,).
            num_classes (int): length of the one-hot vector.
            num_points (int): number of points to repeat the embedding over.
            dtype: dtype of the embedding channels.

        Returns:
            (batch_size, C + num_classes, num_points), or
            (batch_size, num_classes, num_points) when ``feature`` is None.
        """
        # Only the label encoding is gradient-free; the concatenation stays in
        # autograd so gradients can still reach ``feature``.
        with torch.no_grad():
            one_hot = cls_label.new_zeros(cls_label.size(0), num_classes)
            one_hot = one_hot.scatter(1, cls_label.unsqueeze(1), 1)  # (batch_size, num_classes)
            one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points).to(dtype)
        if feature is None:
            return one_hot_expand.contiguous()
        return torch.cat((feature, one_hot_expand), dim=1)

    def forward(self, data_batch):
        """Predict per-point part logits.

        Args:
            data_batch (dict): 'points' with xyz in channels 0-2 of dim 1 and
                points along dim 2; when ``use_one_hot`` is set, also
                'cls_label' with one integer class index per sample.

        Returns:
            dict: {'seg_logit': output of ``seg_logit``}.
        """
        points = data_batch['points']
        end_points = {}

        xyz = points.narrow(1, 0, 3)
        if points.size(1) > 3:
            feature = points.narrow(1, 3, points.size(1) - 3)
        else:
            feature = None

        # save intermediate results
        inter_xyz = [xyz]
        inter_feature = [points if self.use_xyz else feature]

        if self.use_one_hot:
            # One-hot class label, added only to the level-0 skip features.
            inter_feature[0] = self._concat_one_hot(inter_feature[0],
                                                    data_batch['cls_label'],
                                                    self.num_classes,
                                                    points.size(2),
                                                    points.dtype)

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            inter_xyz.append(xyz)
            inter_feature.append(feature)

        # Feature Propagation Layers
        sparse_xyz = xyz
        sparse_feature = feature
        for fp_ind, fp_module in enumerate(self.fp_modules):
            dense_xyz = inter_xyz[-2 - fp_ind]
            dense_feature = inter_feature[-2 - fp_ind]
            fp_feature = fp_module(dense_xyz, sparse_xyz, dense_feature, sparse_feature)
            sparse_xyz = dense_xyz
            sparse_feature = fp_feature

        # MLP
        x = self.mlp_seg(sparse_feature)
        seg_logit = self.seg_logit(x)

        preds = {
            'seg_logit': seg_logit,
        }
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        """Xavier-uniform init for all layers, then ``set_bn`` with momentum 0.01."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        for fp_module in self.fp_modules:
            fp_module.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)
예제 #22
0
class PointNetCls(nn.Module):
    """PointNet classification network with an auxiliary two-way head.

    Points go through a stem feature extractor, a shared local MLP, a max-pool
    over the point dimension and a global MLP. Two linear heads read the
    pooled feature: ``classifier`` (``out_channels`` logits) and
    ``classifier2`` (2 logits).

    Args:
        in_channels (int): number of input channels per point
        out_channels (int): number of logits produced by the main classifier
        stem_channels (tuple of int): channel widths of the stem feature extractor
        local_channels (tuple of int): channel widths of the shared local MLP
        global_channels (tuple of int): channel widths of the global MLP
        dropout_prob (float): dropout probability used in the global MLP
        with_transform (bool): whether the stem uses a TNet to transform features
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(16, 32),
                 local_channels=(32, 32),
                 global_channels=(32, 32),
                 dropout_prob=0.5,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Keep this registration order stable: it fixes the order in which
        # parameters are created (default init RNG) and listed in state_dict.
        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob)
        pooled_width = global_channels[-1]
        self.classifier = nn.Linear(pooled_width, out_channels, bias=True)
        self.classifier2 = nn.Linear(pooled_width, 2, bias=True)

        self.reset_parameters()

    def forward(self, points):
        """Classify a batch of point clouds.

        Args:
            points (torch.Tensor): passed straight to the stem; presumably
                (batch_size, in_channels, num_points) — TODO confirm.

        Returns:
            dict: ``'cls_logit'`` and ``'node_logit'`` from the two heads, plus
            the stem's end points and ``'key_point_indices'`` (argmax indices
            of the max-pool over dimension 2).
        """
        per_point, end_points = self.stem(points)
        per_point = self.mlp_local(per_point)
        # Global max-pool over points; remember which point produced each max.
        pooled, key_point_indices = torch.max(per_point, 2)
        end_points['key_point_indices'] = key_point_indices
        global_feature = self.mlp_global(pooled)

        preds = {
            'cls_logit': self.classifier(global_feature),
            'node_logit': self.classifier2(global_feature),
        }
        preds.update(end_points)
        return preds

    def reset_parameters(self):
        """Re-initialise the MLPs and both linear heads, then set BN momentum."""
        # default initialization in original implementation
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.reset_parameters(xavier_uniform)
        for head in (self.classifier, self.classifier2):
            xavier_uniform(head)
        # set batch normalization momentum to 0.01 as default
        set_bn(self, momentum=0.01)
예제 #23
0
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128, 0),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8), -1.0),
                 num_neighbours_list=((32, 64, 128), (64, 128), -1),
                 sa_channels_list=(
                     ((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                     ((128, 128, 256), (128, 196, 256)),
                     (256, 512, 1024),
                 ),
                 fp_channels=((256, 256), (256, 128), (128, 128)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True,
                 use_one_hot=True):
        """Build a PointNet++ multi-scale-grouping part segmentation network.

        All set-abstraction (SA) layers but the last are multi-scale
        (``PointNetSAModuleMSG``); the last is a single-scale
        ``PointNetSAModule``. One feature-propagation (FP) layer is built per
        SA layer, followed by a shared segmentation MLP and a 1x1 ``Conv1d``
        producing ``num_seg_classes`` logits.
        """
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz
        self.use_one_hot = use_one_hot

        # Every per-layer configuration must describe the same number of layers.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        for per_sa_layer in (radius_list, num_neighbours_list, sa_channels_list):
            assert len(per_sa_layer) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction: xyz is carried separately, so the first layer only
        # sees the extra feature channels.
        self.sa_modules = nn.ModuleList()
        feat_in = in_channels - 3
        for layer_idx in range(num_sa_layers - 1):
            msg_layer = PointNetSAModuleMSG(
                in_channels=feat_in,
                mlp_channels_list=sa_channels_list[layer_idx],
                num_centroids=num_centroids[layer_idx],
                radius_list=radius_list[layer_idx],
                num_neighbours_list=num_neighbours_list[layer_idx],
                use_xyz=use_xyz)
            self.sa_modules.append(msg_layer)
            feat_in = msg_layer.out_channels

        self.sa_modules.append(
            PointNetSAModule(in_channels=feat_in,
                             mlp_channels=sa_channels_list[-1],
                             num_centroids=num_centroids[-1],
                             radius=radius_list[-1],
                             num_neighbours=num_neighbours_list[-1],
                             use_xyz=use_xyz))

        # Channel count of the feature saved at each level (input level first);
        # the input level optionally carries the one-hot class label.
        input_level = in_channels if use_xyz else in_channels - 3
        if self.use_one_hot:
            input_level += num_classes
        skip_channels = [input_level] + [layer.out_channels for layer in self.sa_modules]

        # Feature Propagation: walk back from the coarsest level, fusing the
        # running feature with the skip feature of the next finer level.
        self.fp_modules = nn.ModuleList()
        feat_in = skip_channels[-1]
        for fp_idx in range(num_fp_layers):
            self.fp_modules.append(
                PointnetFPModule(in_channels=feat_in + skip_channels[-2 - fp_idx],
                                 mlp_channels=fp_channels[fp_idx],
                                 num_neighbors=num_fp_neighbours[fp_idx]))
            feat_in = fp_channels[fp_idx][-1]

        # Segmentation MLP and per-point logit layer.
        self.mlp_seg = SharedMLP(feat_in, seg_channels, ndim=1, dropout_prob=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.reset_parameters()
예제 #24
0
class PointNet2MSGPartSeg(nn.Module):
    """PointNet++ part segmentation with multi-scale grouping (MSG).

    All set abstraction (SA) layers but the last use multi-scale grouping
    (``PointNetSAModuleMSG``); the last is a single-scale ``PointNetSAModule``.
    Feature propagation (FP) layers walk back from the coarsest level, fusing
    the running feature with the feature saved at each finer level, and a
    shared MLP plus a 1x1 ``Conv1d`` produce per-point segmentation logits.

    Args:
        in_channels (int): number of input channels per point, including xyz
        num_classes (int): number of object classes (width of the one-hot label)
        num_seg_classes (int): number of part (segmentation) classes
        num_centroids (tuple of int): centroids sampled in each SA layer
        radius_list (tuple): per-SA-layer radii (a tuple per MSG layer, a
            scalar for the last layer)
        num_neighbours_list (tuple): per-SA-layer neighbour counts, structured
            like ``radius_list``
        sa_channels_list (tuple): per-SA-layer MLP channel specs
        fp_channels (tuple of tuple of int): MLP channels of each FP layer
        num_fp_neighbours (tuple of int): neighbour count passed to each FP layer
        seg_channels (tuple of int): channels of the segmentation MLP
        dropout_prob (float): dropout probability of the segmentation MLP
        use_xyz (bool): whether xyz coordinates are used as features
        use_one_hot (bool): whether the one-hot class label is appended to the
            input-level skip feature
    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128, 0),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8), -1.0),
                 num_neighbours_list=((32, 64, 128), (64, 128), -1),
                 sa_channels_list=(
                     ((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                     ((128, 128, 256), (128, 196, 256)),
                     (256, 512, 1024),
                 ),
                 fp_channels=((256, 256), (256, 128), (128, 128)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True,
                 use_one_hot=True):
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz
        self.use_one_hot = use_one_hot

        # sanity check
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers (xyz is carried separately from features)
        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(num_sa_layers - 1):
            sa_module = PointNetSAModuleMSG(
                in_channels=feature_channels,
                mlp_channels_list=sa_channels_list[ind],
                num_centroids=num_centroids[ind],
                radius_list=radius_list[ind],
                num_neighbours_list=num_neighbours_list[ind],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            feature_channels = sa_module.out_channels

        sa_module = PointNetSAModule(in_channels=feature_channels,
                                     mlp_channels=sa_channels_list[-1],
                                     num_centroids=num_centroids[-1],
                                     radius=radius_list[-1],
                                     num_neighbours=num_neighbours_list[-1],
                                     use_xyz=use_xyz)
        self.sa_modules.append(sa_module)

        # Channel count of the feature saved at each level (input level first).
        # With use_xyz=False and xyz-only input the input level has 0 channels
        # (plus num_classes when the one-hot label is appended).
        inter_channels = [in_channels if use_xyz else in_channels - 3]
        if self.use_one_hot:
            inter_channels[0] += num_classes  # concat with one-hot
        inter_channels.extend(
            [sa_module.out_channels for sa_module in self.sa_modules])

        # Feature Propagation Layers
        self.fp_modules = nn.ModuleList()
        feature_channels = inter_channels[-1]
        for ind in range(num_fp_layers):
            fp_module = PointnetFPModule(in_channels=feature_channels +
                                         inter_channels[-2 - ind],
                                         mlp_channels=fp_channels[ind],
                                         num_neighbors=num_fp_neighbours[ind])
            self.fp_modules.append(fp_module)
            feature_channels = fp_channels[ind][-1]

        # MLP
        self.mlp_seg = SharedMLP(feature_channels,
                                 seg_channels,
                                 ndim=1,
                                 dropout_prob=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   1,
                                   bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Predict per-point part logits.

        Args:
            data_batch (dict): must contain ``'points'``, indexed as
                (batch_size, channels, num_points) with xyz in the first 3
                channels, and, when ``use_one_hot`` is set, ``'cls_label'``
                (presumably one integer class index per sample — TODO confirm).

        Returns:
            dict: ``{'seg_logit': ...}`` from the final ``Conv1d``.
        """
        points = data_batch['points']
        end_points = {}

        xyz = points.narrow(1, 0, 3)
        if points.size(1) > 3:
            feature = points.narrow(1, 3, points.size(1) - 3)
        else:
            feature = None

        # save intermediate results
        inter_xyz = [xyz]
        inter_feature = [points if self.use_xyz else feature]

        if self.use_one_hot:
            # one hot class label; it is a constant w.r.t. the model, so it is
            # built without tracking gradients.
            num_points = points.size(2)
            with torch.no_grad():
                cls_label = data_batch['cls_label']
                one_hot = cls_label.new_zeros(cls_label.size(0),
                                              self.num_classes)
                one_hot = one_hot.scatter(1, cls_label.unsqueeze(1),
                                          1)  # (batch_size, num_classes)
                one_hot_expand = one_hot.unsqueeze(2).expand(
                    -1, -1, num_points).float()
            # Concatenate outside no_grad so the input-level skip feature keeps
            # its autograd link to `points` (as PointNetPartSeg does). Without
            # extra features and use_xyz=False the level is just the one-hot
            # label, matching the channel count computed in __init__.
            if inter_feature[0] is None:
                inter_feature[0] = one_hot_expand
            else:
                inter_feature[0] = torch.cat(
                    (inter_feature[0], one_hot_expand), dim=1)

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            inter_xyz.append(xyz)
            inter_feature.append(feature)

        # Feature Propagation Layers: coarse -> fine, fusing skip features
        sparse_xyz = xyz
        sparse_feature = feature
        for fp_ind, fp_module in enumerate(self.fp_modules):
            dense_xyz = inter_xyz[-2 - fp_ind]
            dense_feature = inter_feature[-2 - fp_ind]
            fp_feature = fp_module(dense_xyz, sparse_xyz, dense_feature,
                                   sparse_feature)
            sparse_xyz = dense_xyz
            sparse_feature = fp_feature

        # MLP
        x = self.mlp_seg(sparse_feature)
        seg_logit = self.seg_logit(x)

        preds = {'seg_logit': seg_logit}
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        """Re-initialise all layers with ``xavier_uniform`` and set BN momentum 0.01."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        for fp_module in self.fp_modules:
            fp_module.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)
예제 #25
0
class PointNet2SSGPartSeg(nn.Module):
    """PointNet++ part-segmentation feature extractor with single-scale grouping.

    A stack of set abstraction (SA) layers (``PointNetSAModule``) downsamples
    the cloud; feature propagation (FP) layers (``PointnetFPModule``) then walk
    back from the coarsest level to the input points, fusing the running
    feature with the feature saved at each finer level. A shared MLP finishes
    the per-point features. No logit layer is included: ``forward`` returns the
    features.

    Args:
        in_channels (int): number of input channels per point, including xyz
        num_centroids (tuple of int): centroids sampled in each SA layer
        radius (tuple of float): neighbour-query radius of each SA layer
        num_neighbours (tuple of int): neighbours queried per centroid in each SA layer
        sa_channels (tuple of tuple of int): MLP channels of each SA layer
        fp_channels (tuple of tuple of int): MLP channels of each FP layer
        num_fp_neighbours (tuple of int): neighbour count passed to each FP layer
        seg_channels (tuple of int): channels of the final shared MLP
        dropout_prob (float): dropout probability of the final shared MLP
        use_xyz (bool): whether xyz coordinates are used as point features

    References:
        https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_centroids=(128, 32, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(64, 64, -1),
                 sa_channels=((16, 16, 32), (32, 32, 64), (128, 128, 256)),
                 fp_channels=((64, 64), (64, 32), (32, 32, 32)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(32,),
                 dropout_prob=0.5,
                 use_xyz=True):
        super(PointNet2SSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.use_xyz = use_xyz

        # All per-layer configurations must agree on the number of layers.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        for per_sa_layer in (radius, num_neighbours, sa_channels):
            assert len(per_sa_layer) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction: xyz travels separately, so only extra channels enter.
        self.sa_modules = nn.ModuleList()
        feat_in = in_channels - 3
        for layer_idx in range(num_sa_layers):
            mlp_spec = sa_channels[layer_idx]
            self.sa_modules.append(
                PointNetSAModule(in_channels=feat_in,
                                 mlp_channels=mlp_spec,
                                 num_centroids=num_centroids[layer_idx],
                                 radius=radius[layer_idx],
                                 num_neighbours=num_neighbours[layer_idx],
                                 use_xyz=use_xyz))
            feat_in = mlp_spec[-1]

        # Channel count of the feature saved at each level (input level first).
        skip_channels = ([in_channels if use_xyz else in_channels - 3]
                         + [spec[-1] for spec in sa_channels])

        # Feature Propagation: coarse -> fine.
        self.fp_modules = nn.ModuleList()
        feat_in = skip_channels[-1]
        for fp_idx in range(num_fp_layers):
            self.fp_modules.append(
                PointnetFPModule(in_channels=feat_in + skip_channels[-2 - fp_idx],
                                 mlp_channels=fp_channels[fp_idx],
                                 num_neighbors=num_fp_neighbours[fp_idx]))
            feat_in = fp_channels[fp_idx][-1]

        # Final shared MLP over the propagated per-point features.
        self.mlp_seg = SharedMLP(feat_in, seg_channels, ndim=1, dropout_prob=dropout_prob)

        self.reset_parameters()

    def extract_feats(self, points):
        """Compute per-point features.

        Args:
            points (torch.Tensor): indexed as (batch_size, channels, num_points),
                with xyz in the first 3 channels.

        Returns:
            torch.Tensor: output of ``mlp_seg`` on the finest propagated feature.
        """
        num_extra = points.size(1) - 3
        xyz = points.narrow(1, 0, 3)
        feature = points.narrow(1, 3, num_extra) if num_extra > 0 else None

        # Per-level coordinates and features, finest (input) level first.
        xyz_levels = [xyz]
        feature_levels = [points if self.use_xyz else feature]
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            xyz_levels.append(xyz)
            feature_levels.append(feature)

        # Propagate from the coarsest level back towards the input points.
        coarse_xyz, coarse_feature = xyz, feature
        for level, fp_module in enumerate(self.fp_modules, start=2):
            fine_xyz = xyz_levels[-level]
            coarse_feature = fp_module(fine_xyz, coarse_xyz,
                                       feature_levels[-level], coarse_feature)
            coarse_xyz = fine_xyz

        return self.mlp_seg(coarse_feature)

    def forward(self, points):
        """Return ``{'feature': per-point features}`` (see ``extract_feats``)."""
        return {'feature': self.extract_feats(points)}

    def reset_parameters(self):
        """Re-initialise all layers with ``xavier_uniform`` and set BN momentum 0.01."""
        for module in list(self.sa_modules) + list(self.fp_modules) + [self.mlp_seg]:
            module.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
예제 #26
0
class PointNetSAModule(nn.Module):
    """PointNet set abstraction (SA) module.

    Picks centroids, groups points (and their features) around each centroid,
    runs a shared MLP over every group and max-pools within each group.

    Args:
        in_channels (int): number of feature channels, excluding xyz
        mlp_channels (tuple of int): channel widths of the shared MLP; the last
            entry is ``out_channels``
        num_centroids (int): centroids to sample with ``FarthestPointSampler``.
            ``0`` makes a single centroid at the origin whose group is the
            whole point set (global abstraction); ``-1`` uses every input
            point as a centroid.
        radius (float): query radius passed to ``QueryGrouper``; must be
            negative exactly when ``num_neighbours`` is negative
        num_neighbours (int): neighbours per centroid, or negative for no grouper
        use_xyz (bool): whether xyz coordinates are concatenated to the grouped
            features before the MLP
    """

    def __init__(self,
                 in_channels,
                 mlp_channels,
                 num_centroids,
                 radius,
                 num_neighbours,
                 use_xyz):
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]
        self.num_centroids = num_centroids
        # self.num_neighbours = num_neighbours
        self.use_xyz = use_xyz

        if self.use_xyz:
            in_channels += 3
        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=2, bn=True)

        if num_centroids <= 0:
            self.sampler = None
        else:
            self.sampler = FarthestPointSampler(num_centroids)

        if num_neighbours < 0:
            assert radius < 0.0
            self.grouper = None
        else:
            assert num_neighbours > 0 and radius > 0.0
            self.grouper = QueryGrouper(radius, num_neighbours)

    def forward(self, xyz, feature=None):
        """

        Args:
            xyz (torch.Tensor): (batch_size, 3, num_points)
                xyz coordinates of feature
            feature (torch.Tensor, optional): (batch_size, in_channels, num_points)

        Returns:
            new_xyz (torch.Tensor): (batch_size, 3, num_centroids)
            new_feature (torch.Tensor): (batch_size, out_channels, num_centroids)

        Raises:
            ValueError: in global mode (``num_centroids == 0``) when ``feature``
                is None and ``use_xyz`` is False, i.e. there is no input at all.

        """

        if self.num_centroids == 0:
            # use the origin as the centroid
            new_xyz = xyz.new_zeros(xyz.size(0), 3, 1)  # (batch_size, 3, 1)
            assert self.grouper is None
            group_xyz = xyz.unsqueeze(2)  # (batch_size, 3, 1, num_points)
            if feature is not None:
                group_feature = feature.unsqueeze(2)  # (batch_size, in_channels, 1, num_points)
                if self.use_xyz:
                    group_feature = torch.cat([group_xyz, group_feature], dim=1)
            elif self.use_xyz:
                # xyz-only input (feature defaults to None): pool coordinates alone.
                group_feature = group_xyz
            else:
                raise ValueError('PointNetSAModule got no features and use_xyz=False; '
                                 'there is nothing to abstract')
        else:
            if self.num_centroids == -1:
                # use all points
                new_xyz = xyz
            else:
                # sample new points
                index = self.sampler(xyz)
                new_xyz = _F.gather_points(xyz, index)  # (batch_size, 3, num_centroids)

            # group_feature, (batch_size, in_channels, num_centroids, num_neighbours)
            group_feature, group_xyz = self.grouper(new_xyz, xyz, feature, use_xyz=self.use_xyz)

        new_feature = self.mlp(group_feature)
        # max-pool within each group (last dimension)
        new_feature, _ = torch.max(new_feature, 3)
        return new_xyz, new_feature

    def reset_parameters(self, init_fn=None):
        """Re-initialise the shared MLP with ``init_fn``."""
        self.mlp.reset_parameters(init_fn)

    def extra_repr(self):
        return 'num_centroids={:d}, use_xyz={}'.format(self.num_centroids, self.use_xyz)
예제 #27
0
class PointNetPartSeg(nn.Module):
    """PointNet for part segmentation, with an optional classification branch.

    Per-point segmentation features are the concatenation of every stem
    output, every local-MLP output, the max-pooled global feature (broadcast
    to all points) and, optionally, the one-hot class label.

    Args:
        in_channels (int): number of input channels per point
        num_classes (int): number of object classes
        num_seg_classes (int): number of part (segmentation) classes
        stem_channels (tuple of int): channel widths of the stem feature extractor
        local_channels (tuple of int): channel widths of the shared local MLP
        cls_channels (tuple of int): channel widths of the classification MLP;
            empty disables the classification branch
        seg_channels (tuple of int): channel widths of the segmentation MLP
        dropout_prob_cls (float): dropout probability of the classification MLP
        dropout_prob_seg (float): dropout probability of the segmentation MLP
        with_transform (bool): whether the stem uses a TNet to transform features
        use_one_hot (bool): whether to append the one-hot class label to the
            segmentation features

    References:
        https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py

    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True,
                 use_one_hot=True):
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_one_hot = use_one_hot

        # Keep this registration order: it fixes parameter creation order.
        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # Segmentation input width. The released repo concatenates the global
        # feature, the one-hot label, all stem features and all local features
        # (the paper omits the last local feature); we follow the repo.
        seg_in = (sum(stem_channels) + sum(local_channels) + local_channels[-1]
                  + (num_classes if self.use_one_hot else 0))
        self.mlp_seg = SharedMLP(seg_in, seg_channels[:-1], dropout_prob=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        # Optional classification branch on the global feature (dropout is
        # applied inside each classification MLP layer).
        if len(cls_channels) == 0:
            self.mlp_cls = None
            self.cls_logit = None
        else:
            self.mlp_cls = MLP(local_channels[-1], cls_channels, dropout_prob=dropout_prob_cls)
            self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Predict part logits (and class logits when the branch is enabled).

        Args:
            data_batch (dict): ``'points'`` indexed as
                (batch_size, channels, num_points), plus ``'cls_label'`` when
                ``use_one_hot`` is set (presumably integer class indices —
                TODO confirm).

        Returns:
            dict: ``'seg_logit'``, the stem's end points (minus
            ``'stem_features'``) and, if enabled, ``'cls_logit'``.
        """
        points = data_batch['points']
        num_points = points.shape[2]
        end_points = {}

        stem_out, stem_end_points = self.stem(points)
        stem_features = stem_end_points.pop('stem_features')
        end_points.update(stem_end_points)

        # Keep every local-MLP layer output for the segmentation concat.
        local_features = []
        local = stem_out
        for layer in self.mlp_local:
            local = layer(local)
            local_features.append(local)

        # Max-pool over points -> (batch_size, local_channels[-1])
        global_feature, _ = torch.max(local, 2)

        seg_inputs = stem_features + local_features
        seg_inputs.append(global_feature.unsqueeze(2).expand(-1, -1, num_points))
        if self.use_one_hot:
            with torch.no_grad():
                labels = data_batch['cls_label']
                one_hot = labels.new_zeros(labels.size(0), self.num_classes)
                one_hot = one_hot.scatter(1, labels.unsqueeze(1), 1).float()
                one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points)
            seg_inputs.append(one_hot_expand)

        seg = self.conv_seg(self.mlp_seg(torch.cat(seg_inputs, dim=1)))
        preds = {'seg_logit': self.seg_logit(seg)}
        preds.update(end_points)

        if self.cls_logit is not None:
            preds['cls_logit'] = self.cls_logit(self.mlp_cls(global_feature))

        return preds

    def reset_parameters(self):
        """Re-initialise layers with ``xavier_uniform`` and set BN momentum 0.01."""
        for mlp in (self.mlp_local, self.mlp_seg, self.conv_seg):
            mlp.reset_parameters(xavier_uniform)
        if self.cls_logit is not None:
            self.mlp_cls.reset_parameters(xavier_uniform)
            xavier_uniform(self.cls_logit)
        xavier_uniform(self.seg_logit)
        # set batch normalization momentum to 0.01 as default
        set_bn(self, momentum=0.01)
예제 #28
0
class PointNetCls(nn.Module):
    """Bundle of independent point-cloud networks; ``forward`` runs one of them.

    The branch is selected per call by ``infer_type``:

    * ``'backbone'`` / ``'backbone2'``: ``PointNet2SSGCls`` branch ``p1`` /
      ``p2`` followed by a linear layer with ``out_channels`` outputs.
    * ``'policy'``: ``PointNet2SSGCls`` branch ``p3`` followed by a linear layer.
    * ``'policy_head'``: global MLP + single-output linear layer applied
      directly to ``x``, whose width must be
      ``int(global_channels[-1] / 4) * 2``.
    * ``'purity'``: branch ``p4`` + single-output linear layer + sigmoid.
    * ``'head'`` / ``'head2'``: PointNet pipeline (stem, shared local MLP,
      max-pool over points, global MLP) + 2-output linear layer; the stems
      take ``out_channels + 3`` and ``2 * out_channels + 3`` input channels.

    Args:
        in_channels (int): stored on the module; not used to size any layer
        out_channels (int): outputs of the backbone classifiers; also sizes the
            inputs of the ``head`` / ``head2`` stems
        stem_channels (tuple of int): channel widths of the head stems
        local_channels (tuple of int): channel widths of the head local MLPs
        global_channels (tuple of int): channel widths of the head global MLPs;
            ``global_channels[-1]`` also sizes the branch classifiers
        dropout_prob (float): dropout probability of the MLPs
        with_transform (bool): whether the head stems use a TNet to transform features.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(128, 128),
                 local_channels=(128, 256, 256),
                 global_channels=(128, 128),
                 dropout_prob=0.3,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # SA channel specs of the PointNet2SSGCls branches; the p2-sized spec
        # is a quarter of the width of the p1-sized one.
        p1 = 128
        stem_channels1 = ((p1, p1, int(2 * p1)), (int(2 * p1), int(2 * p1),
                                                  int(2 * p1)))
        p2 = int(p1 / 4)
        stem_channels2 = ((p2, p2, int(2 * p2)), (int(2 * p2), int(2 * p2),
                                                  int(2 * p2)))
        global_channels1 = (p1 * 2, p1)
        global_channels2 = (p2 * 2, p2)

        # NOTE(review): the branch classifiers take global_channels[-1] (or a
        # quarter of it) as input while the branches are built from
        # global_channels1 / global_channels2; these only agree for the default
        # global_channels. Confirm the PointNet2SSGCls output width before
        # changing either side. Registration order below is kept stable.
        self.p1 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier1 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p2 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier2 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p3 = PointNet2SSGCls(sa_channels=stem_channels2,
                                  global_channels=global_channels2)
        self.classifier3 = nn.Linear(int(global_channels[-1] / 4),
                                     int(global_channels[-1] / 4),
                                     bias=True)

        self.mlp_global22 = MLP(int(global_channels[-1] / 4) * 2,
                                tuple(int(e / 4) for e in global_channels),
                                dropout_prob=dropout_prob)
        self.classifier22 = nn.Linear(int(global_channels[-1] / 4),
                                      1,
                                      bias=True)

        self.p4 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True)

        self.stem8 = Stem(int(out_channels * 1) + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global8 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True)

        self.stem9 = Stem(int(out_channels * 2) + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global9 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True)

        self.sigmoid = nn.Sigmoid()
        self.reset_parameters()

    def forward(self, x, infer_type, y=None):
        """Run the branch selected by ``infer_type``.

        Args:
            x: input of the selected branch (see the class docstring)
            infer_type (str): one of ``'backbone'``, ``'backbone2'``,
                ``'policy'``, ``'policy_head'``, ``'purity'``, ``'head'``,
                ``'head2'``
            y: unused; kept for interface compatibility

        Returns:
            torch.Tensor: output of the selected branch

        Raises:
            NameError: if ``infer_type`` is not recognised (type kept for
                backward compatibility with existing callers)
        """
        if infer_type == 'backbone':
            return self.classifier1(self.p1(x))
        if infer_type == 'backbone2':
            return self.classifier2(self.p2(x))
        if infer_type == 'policy':
            return self.classifier3(self.p3(x))
        if infer_type == 'policy_head':
            return self.classifier22(self.mlp_global22(x))
        if infer_type == 'purity':
            return self.sigmoid(self.classifier4(self.p4(x)))
        if infer_type == 'head':
            return self._pointnet_head(x, self.stem8, self.mlp_local8,
                                       self.mlp_global8, self.classifier8)
        if infer_type == 'head2':
            return self._pointnet_head(x, self.stem9, self.mlp_local9,
                                       self.mlp_global9, self.classifier9)
        raise NameError('Unknown infer_type: {!r}'.format(infer_type))

    @staticmethod
    def _pointnet_head(x, stem, mlp_local, mlp_global, classifier):
        """Apply stem -> local MLP -> max-pool over points -> global MLP -> classifier."""
        x, _ = stem(x)  # stem end points are not used by the heads
        x = mlp_local(x)
        x, _ = torch.max(x, 2)
        x = mlp_global(x)
        return classifier(x)

    def reset_parameters(self):
        """Re-initialise every branch and set batch-norm momentum to 0.01."""
        # default initialization in original implementation
        self.p1.reset_parameters()
        self.p2.reset_parameters()
        self.p3.reset_parameters()
        self.p4.reset_parameters()
        xavier_uniform(self.classifier1)
        xavier_uniform(self.classifier2)
        xavier_uniform(self.classifier3)
        xavier_uniform(self.classifier22)
        xavier_uniform(self.classifier4)
        xavier_uniform(self.classifier8)
        xavier_uniform(self.classifier9)
        self.mlp_local8.reset_parameters(xavier_uniform)
        self.mlp_global8.reset_parameters(xavier_uniform)
        self.mlp_local9.reset_parameters(xavier_uniform)
        self.mlp_global9.reset_parameters(xavier_uniform)
        self.mlp_global22.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
예제 #29
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(128, 128),
                 local_channels=(128, 256, 256),
                 global_channels=(128, 128),
                 dropout_prob=0.3,
                 with_transform=False):
        """Build the multi-branch model.

        Creates four ``PointNet2SSGCls`` branches (``p1``..``p4``) with their
        linear classifiers, a global MLP + classifier on a pre-computed
        feature (``mlp_global22`` / ``classifier22``), and two PointNet heads
        (``stem8``.. / ``stem9``..) whose stems take ``out_channels + 3`` and
        ``2 * out_channels + 3`` input channels.

        Args:
            in_channels (int): stored on the module; not used to size any layer
            out_channels (int): outputs of ``classifier1`` / ``classifier2``
            stem_channels (tuple of int): channel widths of the head stems
            local_channels (tuple of int): channel widths of the head local MLPs
            global_channels (tuple of int): channel widths of the head global
                MLPs; ``global_channels[-1]`` also sizes the branch classifiers
            dropout_prob (float): dropout probability of the MLPs
            with_transform (bool): whether the head stems use a TNet
        """
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # SA channel specs of the PointNet2SSGCls branches; the p2-sized spec
        # is a quarter of the width of the p1-sized one.
        p1 = 128
        stem_channels1 = ((p1, p1, int(2 * p1)), (int(2 * p1), int(2 * p1),
                                                  int(2 * p1)))
        p2 = int(p1 / 4)
        stem_channels2 = ((p2, p2, int(2 * p2)), (int(2 * p2), int(2 * p2),
                                                  int(2 * p2)))
        global_channels1 = (p1 * 2, p1)
        global_channels2 = (p2 * 2, p2)

        # NOTE(review): classifier input widths come from global_channels while
        # the branches are built from global_channels1 / global_channels2; they
        # agree only for the default global_channels. Registration order below
        # is kept stable (parameter creation / state_dict order).
        self.p1 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier1 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p2 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier2 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p3 = PointNet2SSGCls(sa_channels=stem_channels2,
                                  global_channels=global_channels2)
        self.classifier3 = nn.Linear(int(global_channels[-1] / 4),
                                     int(global_channels[-1] / 4),
                                     bias=True)

        self.mlp_global22 = MLP(int(global_channels[-1] / 4) * 2,
                                tuple(int(e / 4) for e in global_channels),
                                dropout_prob=dropout_prob)
        self.classifier22 = nn.Linear(int(global_channels[-1] / 4),
                                      1,
                                      bias=True)

        self.p4 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True)

        self.stem8 = Stem(int(out_channels * 1) + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global8 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True)

        self.stem9 = Stem(int(out_channels * 2) + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global9 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True)

        self.sigmoid = nn.Sigmoid()
        self.reset_parameters()
예제 #30
0
class DGCNNPartSeg(nn.Module):
    """DGCNN network for point-cloud part segmentation.

    Pipeline:
      1. Optionally apply a TNet transform to the input points.
      2. Run a chain of EdgeConv blocks.
      3. Concatenate the blocks' outputs and lift them with a shared MLP.
      4. Max-pool over points to get a global descriptor.
      5. Fuse the global descriptor with the per-point EdgeConv features and
         an embedding of the one-hot object class label.
      6. Predict per-point part logits.

    Args:
        in_channels (int): number of channels of the input points.
        num_classes (int): number of object classes; size of the one-hot
            class-label vector fed to the label embedding.
        num_seg_classes (int): number of part-segmentation classes.
        edge_conv_channels (tuple): one entry per EdgeConv block. Each entry is
            either an int (single layer) or a tuple/list of ints.
        local_channels (tuple of int): channels of the shared MLP applied to
            the concatenated EdgeConv features.
        seg_channels (tuple of int): channels of the segmentation head.
            ``seg_channels[:-1]`` build a shared MLP with dropout; the last
            two entries define the 1x1 conv that precedes the logit layer.
        k (int): number of nearest neighbours; passed to every EdgeConvBlock
            and to the input TNet.
        dropout_prob (float): dropout probability in the segmentation MLP.
        with_transform (bool): whether to transform the input points with a
            learned TNet before the EdgeConv stack.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 edge_conv_channels=((64, 64), (64, 64), 64),
                 local_channels=(1024, ),
                 seg_channels=(256, 256, 128),
                 k=20,
                 dropout_prob=0.4,
                 with_transform=True):
        super(DGCNNPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.with_transform = with_transform

        # NOTE: submodules are created in a fixed order on purpose; it
        # determines parameter registration order and RNG use at init time.
        if with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # Build the EdgeConv chain. Each block consumes the previous block's
        # output width; the output widths are recorded for the concat below.
        self.edge_convs = nn.ModuleList()
        block_out_channels = []
        prev_channels = in_channels
        for spec in edge_conv_channels:
            if isinstance(spec, int):
                spec = [spec]
            else:
                assert isinstance(spec, (tuple, list))
            self.edge_convs.append(EdgeConvBlock(prev_channels, spec, k))
            prev_channels = spec[-1]
            block_out_channels.append(prev_channels)
        edge_total = sum(block_out_channels)

        label_channels = 64  # width of the class-label embedding
        self.mlp_label = Conv1d(num_classes, label_channels, 1)
        self.mlp_local = SharedMLP(edge_total, local_channels)

        # The segmentation input is [edge features | global feature | label
        # embedding], matching the concatenation order in forward().
        fused_channels = edge_total + local_channels[-1] + label_channels
        self.mlp_seg = SharedMLP(fused_channels,
                                 seg_channels[:-1],
                                 dropout_prob=dropout_prob)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   1,
                                   bias=True)

        self.reset_parameters()

    def _expanded_one_hot(self, cls_label, num_points):
        """Return the one-hot class labels broadcast along the point axis.

        Assumes ``cls_label`` holds integer class indices of shape
        (batch_size,); it is used as a ``scatter`` index along dim 1.
        Built without autograd tracking. Result is a float tensor viewed as
        (batch_size, num_classes, num_points) via ``expand`` (no copy).
        """
        with torch.no_grad():
            batch_size = cls_label.size(0)
            one_hot = cls_label.new_zeros(batch_size, self.num_classes)
            one_hot.scatter_(1, cls_label.unsqueeze(1), 1)
            return one_hot.float().unsqueeze(2).expand(-1, -1, num_points)

    def forward(self, data_batch):
        """Predict per-point part logits.

        Args:
            data_batch (dict): must contain ``'points'`` (used as a 3-D tensor
                whose dim 2 is the point axis, presumably
                (batch_size, in_channels, num_points)) and ``'cls_label'``
                (integer object-class indices).

        Returns:
            dict: ``'seg_logit'`` of shape (batch_size, num_seg_classes,
            num_points). When ``with_transform`` is set, the dict also holds
            ``'trans_input'``, the TNet transform matrix.
        """
        points = data_batch['points']
        num_points = points.shape[2]
        aux_outputs = {}

        if self.with_transform:
            trans_input = self.transform_input(points)
            points = torch.bmm(trans_input, points)
            aux_outputs['trans_input'] = trans_input

        # Each EdgeConv block feeds the next; every intermediate output is kept.
        stage_outputs = []
        feature = points
        for block in self.edge_convs:
            feature = block(feature)
            stage_outputs.append(feature)
        # (batch_size, sum of EdgeConv output channels, num_points)
        edge_feature = torch.cat(stage_outputs, dim=1)

        # Global descriptor: max over the point axis,
        # (batch_size, local_channels[-1]), then tiled back per point.
        local_feature = self.mlp_local(edge_feature)
        global_feature = torch.max(local_feature, 2)[0]
        global_tiled = global_feature.unsqueeze(2).expand(-1, -1, num_points)

        label_feature = self.mlp_label(
            self._expanded_one_hot(data_batch['cls_label'], num_points))

        fused = torch.cat((edge_feature, global_tiled, label_feature), dim=1)
        seg_feature = self.conv_seg(self.mlp_seg(fused))

        return {'seg_logit': self.seg_logit(seg_feature), **aux_outputs}

    def reset_parameters(self):
        """Re-initialise weights with Xavier-uniform and set BN momentum.

        NOTE(review): ``transform_input`` (TNet) is not re-initialised here;
        presumably it initialises itself on construction. Confirm this.
        """
        modules = list(self.edge_convs) + [
            self.mlp_label, self.mlp_local, self.mlp_seg, self.conv_seg
        ]
        for module in modules:
            module.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)