示例#1
0
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True):
        """Build the stem, the classification head and the part segmentation head.

        Args:
            in_channels (int): the number of input channels
            num_classes (int): the number of outputs of the classification logit layer;
                it is also counted in the input width of the segmentation mlp
            num_seg_classes (int): the number of outputs of the segmentation logit layer
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            cls_channels (tuple of int): the numbers of channels in classification mlp
            seg_channels (tuple of int): the numbers of channels in segmentation mlp;
                all but the last entry go to the shared mlp, the last one to a Conv1d
            dropout_prob_cls (float): the probability to dropout in classification mlp
            dropout_prob_seg (float): the probability to dropout in segmentation mlp
            with_transform (bool): whether to use TNet to transform features.

        """
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes

        # Width of the last local layer; it feeds the classification mlp and is
        # counted once more in the segmentation input below.
        global_channels = local_channels[-1]

        # Feature extraction: stem followed by the local shared mlp.
        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # Classification head. The dropout probability is handed to the mlp itself.
        self.mlp_cls = MLP(global_channels, cls_channels, dropout=dropout_prob_cls)
        self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

        # Part segmentation head.
        # The released PointNet code concatenates the global feature, the one-hot class
        # embedding, every stem feature and every local feature, whereas the paper leaves
        # out the last local feature. The channel count below follows the released code.
        stem_total = sum(stem_channels)
        local_total = sum(local_channels)
        seg_in_channels = global_channels + num_classes + stem_total + local_total
        seg_hidden_channels = seg_channels[:-1]
        self.mlp_seg = SharedMLP(seg_in_channels,
                                 seg_hidden_channels,
                                 dropout=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.init_weights()
示例#2
0
class EdgeConvBlock(nn.Module):
    """Edge convolution block.

    Pipeline: point features -> [get_edge_feature] -> edge features
    -> [SharedMLP with ndim=2] -> [max over dimension 3] -> point features

    """

    def __init__(self, in_channels, out_channels, k):
        """Store the configuration and create the mlp applied to edge features.

        Args:
            in_channels: channel count of the incoming point features
            out_channels: channel specification forwarded to SharedMLP
            k: second argument given to get_edge_feature in forward

        """
        super(EdgeConvBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # The mlp is built for an input twice as wide as the point features.
        edge_channels = 2 * in_channels
        self.mlp = SharedMLP(edge_channels, out_channels, ndim=2)

    def forward(self, x):
        """Turn point features into edge features, transform them and max-pool dim 3."""
        edge_feature = get_edge_feature(x, self.k)
        edge_feature = self.mlp(edge_feature)
        # torch.max over a dimension yields (values, indices); only values are kept.
        return torch.max(edge_feature, 3)[0]

    def init_weights(self, init_fn=None):
        """Delegate weight initialization to the internal mlp."""
        self.mlp.init_weights(init_fn)
示例#3
0
    def __init__(self, in_channels, out_channels, k):
        """Store the configuration and create the mlp applied to edge features.

        Args:
            in_channels: channel count of the incoming point features
            out_channels: channel specification forwarded to SharedMLP
            k: kept on the module as ``self.k``
                (presumably the neighbour count -- confirm against forward)

        """
        super(EdgeConvBlock, self).__init__()

        self.k = k
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The mlp is built for an input twice as wide as the point features.
        edge_channels = 2 * in_channels
        self.mlp = SharedMLP(edge_channels, out_channels, ndim=2)
示例#4
0
class TNet(nn.Module):
    """Transformation network; its layout resembles PointNet.

    The network predicts an offset that is added to a (rectangular) identity
    matrix, one matrix per batch element.
    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        """Create the local mlp, the global mlp and the final linear layer.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of rows of the predicted matrix
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp

        """
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Per-point features.
        self.mlp_local = SharedMLP(in_channels, local_channels)

        # Features of the pooled (global) descriptor.
        self.mlp_global = MLP(local_channels[-1], global_channels)

        # One output per entry of the predicted matrix.
        num_entries = in_channels * out_channels
        self.linear = nn.Linear(global_channels[-1], num_entries, bias=True)

        self.init_weights()

    def forward(self, x):
        """TNet forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels)

        """
        local_feature = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        # Max-pool over the points; the argmax indices are not needed.
        global_feature = torch.max(local_feature, 2)[0]  # (batch_size, local_channels[-1])
        global_feature = self.mlp_global(global_feature)
        offset = self.linear(global_feature).view(-1, self.out_channels,
                                                  self.in_channels)
        identity = torch.eye(self.out_channels,
                             self.in_channels,
                             dtype=offset.dtype,
                             device=offset.device)
        # The identity is broadcast over the batch dimension.
        return offset + identity

    def init_weights(self):
        """Xavier-initialize both mlps and zero the final linear layer."""
        # Default initialization in original implementation
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        # A zeroed linear layer makes the initial prediction equal the identity.
        for tensor in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(tensor)
示例#5
0
    def __init__(self, in_channels, mlp_channels, num_centroids, radius,
                 num_neighbours, use_xyz):
        """Create the shared mlp, the centroid sampler and the neighbour grouper.

        Args:
            in_channels (int): the number of input feature channels
            mlp_channels (tuple of int): the numbers of channels in the shared mlp;
                the last entry becomes ``out_channels``
            num_centroids: argument given to FarthestPointSampler
            radius: first argument given to QueryGrouper
            num_neighbours: second argument given to QueryGrouper
            use_xyz (bool): when true, the mlp input is three channels wider and the
                flag is forwarded to QueryGrouper

        """
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        # Three extra input channels are reserved when use_xyz is enabled.
        mlp_in_channels = in_channels + 3 if use_xyz else in_channels
        self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

        self.sampler = FarthestPointSampler(num_centroids)
        self.grouper = QueryGrouper(radius, num_neighbours, use_xyz=use_xyz)
示例#6
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8)),
                 num_neighbours_list=((16, 32, 128), (32, 64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                                   ((64, 64, 128), (128, 128, 256), (128, 128, 256))),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Multi-scale counterpart of PointNet2SSGCls.

        The per-layer arguments (``radius_list``, ``num_neighbours_list`` and
        ``sa_channels_list``) hold one tuple per set abstraction layer, and every
        such tuple is handed to one PointNetSAModuleMSG.
        """
        super(PointNet2MSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: every per-layer argument needs one entry per layer
        num_layers = len(num_centroids)
        for per_layer_arg in (radius_list, num_neighbours_list, sa_channels_list):
            assert len(per_layer_arg) == num_layers

        # Three input channels are not counted as features
        # (presumably the xyz coordinates -- confirm against forward).
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_layers):
            sa_module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[layer],
                num_centroids=num_centroids[layer],
                radius_list=radius_list[layer],
                num_neighbours_list=num_neighbours_list[layer],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # The next layer consumes what this one produces.
            sa_in_channels = sa_module.out_channels

        # Three more channels are reserved for the local mlp when use_xyz is enabled.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()
示例#7
0
    def __init__(self, in_channels, mlp_channels_list, num_centroids,
                 radius_list, num_neighbours_list, use_xyz):
        """Create one shared mlp and one grouper per scale plus a common sampler.

        Args:
            in_channels (int): the number of input feature channels
            mlp_channels_list (tuple of tuple of int): mlp channels, one tuple per scale
            num_centroids: argument given to FarthestPointSampler
            radius_list: one QueryGrouper radius argument per scale
            num_neighbours_list: one QueryGrouper neighbour argument per scale
            use_xyz (bool): when true, every mlp input is three channels wider and the
                flag is forwarded to every QueryGrouper

        """
        super(PointNetSAModuleMSG, self).__init__()

        self.in_channels = in_channels
        # The output width is the sum of the last mlp width over all scales.
        self.out_channels = sum(channels[-1] for channels in mlp_channels_list)

        # Every per-scale argument needs one entry per scale.
        num_scales = len(mlp_channels_list)
        for per_scale_arg in (radius_list, num_neighbours_list):
            assert len(per_scale_arg) == num_scales

        mlp_in_channels = in_channels + 3 if use_xyz else in_channels

        self.mlp = nn.ModuleList()
        self.sampler = FarthestPointSampler(num_centroids)
        self.grouper = nn.ModuleList()

        for scale in range(num_scales):
            scale_mlp = SharedMLP(mlp_in_channels,
                                  mlp_channels_list[scale],
                                  ndim=2,
                                  bn=True)
            self.mlp.append(scale_mlp)
            scale_grouper = QueryGrouper(radius_list[scale],
                                         num_neighbours_list[scale],
                                         use_xyz=use_xyz)
            self.grouper.append(scale_grouper)
示例#8
0
class PointnetFPModule(nn.Module):
    """PointNet feature propagation module: interpolate features, then apply a shared mlp."""
    def __init__(self, in_channels, mlp_channels, num_neighbors):
        """Create the shared mlp and the feature interpolator.

        Args:
            in_channels (int): the number of input channels of the shared mlp
            mlp_channels (tuple of int): the numbers of channels in the shared mlp
            num_neighbors: argument given to FeatureInterpolator

        """
        super(PointnetFPModule, self).__init__()

        self.in_channels = in_channels
        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)
        self.interpolator = FeatureInterpolator(num_neighbors)

    def forward(self, query_xyz, key_xyz, query_feature, key_feature):
        """Run the interpolator on the four inputs and feed its result to the mlp."""
        interpolated = self.interpolator(query_xyz, key_xyz, query_feature,
                                         key_feature)
        return self.mlp(interpolated)

    def init_weights(self, init_fn=None):
        """Delegate weight initialization to the shared mlp."""
        self.mlp.init_weights(init_fn)
示例#9
0
class PointNetSAModule(nn.Module):
    """PointNet set abstraction module: sample centroids, group neighbours, mlp, max-pool."""
    def __init__(self, in_channels, mlp_channels, num_centroids, radius,
                 num_neighbours, use_xyz):
        """Create the shared mlp, the centroid sampler and the neighbour grouper.

        Args:
            in_channels (int): the number of input feature channels
            mlp_channels (tuple of int): the numbers of channels in the shared mlp;
                the last entry becomes ``out_channels``
            num_centroids: argument given to FarthestPointSampler
            radius: first argument given to QueryGrouper
            num_neighbours: second argument given to QueryGrouper
            use_xyz (bool): when true, the mlp input is three channels wider and the
                flag is forwarded to QueryGrouper

        """
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        # Three extra input channels are reserved when use_xyz is enabled.
        mlp_in_channels = in_channels + 3 if use_xyz else in_channels
        self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

        self.sampler = FarthestPointSampler(num_centroids)
        self.grouper = QueryGrouper(radius, num_neighbours, use_xyz=use_xyz)

    def forward(self, xyz, feature=None):
        """Sample centroids, group their neighbourhoods and pool the transformed features.

        Args:
            xyz (torch.Tensor): (batch_size, 3, num_points)
                xyz coordinates of feature
            feature (torch.Tensor, optional): (batch_size, in_channels, num_points)

        Returns:
            new_xyz (torch.Tensor): (batch_size, 3, num_centroids)
            new_feature (torch.Tensor): (batch_size, out_channels, num_centroids)

        """
        # Pick the centroids and look up their coordinates.
        centroid_index = self.sampler(xyz)
        # (batch_size, 3, num_centroids)
        new_xyz = _F.gather_points(xyz, centroid_index)

        # Grouped features: (batch_size, channels, num_centroids, num_neighbours),
        # where channels is the input width self.mlp was built for.
        grouped_feature = self.grouper(new_xyz, xyz, feature)

        # Transform every neighbour, then max-pool over the neighbour dimension.
        grouped_feature = self.mlp(grouped_feature)
        new_feature = torch.max(grouped_feature, 3)[0]

        return new_xyz, new_feature

    def init_weights(self, init_fn=None):
        """Delegate weight initialization to the shared mlp."""
        self.mlp.init_weights(init_fn)
示例#10
0
    def __init__(self,
                 in_channels,
                 stem_channels=(64, 64),
                 with_transform=True,
                 bn=True):
        """Create the stem mlp and, optionally, the input and feature TNets.

        Args:
            in_channels (int): the number of input channels
            stem_channels (tuple of int): the numbers of channels in the stem mlp;
                the last entry becomes ``out_channels``
            with_transform (bool): whether to create the two TNets
            bn: forwarded to SharedMLP as its ``bn`` argument

        """
        super(Stem, self).__init__()

        self.in_channels = in_channels
        self.out_channels = stem_channels[-1]
        self.with_transform = with_transform

        # The stem mlp is initialized right after its construction.
        self.mlp = SharedMLP(in_channels, stem_channels, bn=bn)
        self.mlp.init_weights(xavier_uniform)

        if not self.with_transform:
            return

        # Square transform on the input channels.
        self.transform_input = TNet(in_channels, in_channels)
        # Square transform on the stem output channels.
        feature_width = self.out_channels
        self.transform_feature = TNet(feature_width, feature_width)
示例#11
0
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        """Create the edge mlp, the local mlp, the global mlp and the linear output.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): multiplied with ``in_channels`` to size the linear output
            conv_channels (tuple of int): the numbers of channels in the edge mlp
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp
            k: kept on the module as ``self.k``
                (presumably the neighbour count -- confirm against forward)

        """
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # The edge mlp is built for an input twice as wide as the point features.
        edge_channels = 2 * in_channels
        self.edge_conv = SharedMLP(edge_channels, conv_channels, ndim=2)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)

        # One output per entry of an in_channels x out_channels matrix.
        num_entries = in_channels * out_channels
        self.linear = nn.Linear(global_channels[-1], num_entries, bias=True)

        self.init_weights()
示例#12
0
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        """Create the local mlp, the global mlp and the final linear layer.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): multiplied with ``in_channels`` to size the linear output
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp

        """
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Per-point features.
        self.mlp_local = SharedMLP(in_channels, local_channels)

        # Global features, fed by the last local width.
        self.mlp_global = MLP(local_channels[-1], global_channels)

        # One output per entry of an in_channels x out_channels matrix.
        num_entries = in_channels * out_channels
        self.linear = nn.Linear(global_channels[-1], num_entries, bias=True)

        self.init_weights()
示例#13
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(64, 64),
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.3,
                 with_transform=True):
        """Build the stem, the local and global mlps and the classifier.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of output channels
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp
            dropout_prob (float): the probability to dropout
            with_transform (bool): whether to use TNet to transform features.
        """
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)

        # The classifier is an FC layer that keeps the width, followed by the
        # linear logit layer (rather than the linear layer alone).
        head_width = global_channels[-1]
        self.classifier = nn.Sequential(
            FC(head_width, head_width),
            nn.Linear(head_width, out_channels, bias=True))

        self.init_weights()
示例#14
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_class,
                 num_seg_class,
                 edge_conv_channels=((64, 64), (64, 64), (64, 64)),
                 inter_channels=1024,
                 global_channels=(256, 256, 128),
                 k=20,
                 dropout_prob=0.4,
                 with_transform=True):
        """Build the edge convolution stack, the label branch and the segmentation head.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): stored on the module as ``self.out_channels``
            num_class (int): the number of input channels of the label convolution
            num_seg_class (int): the number of outputs of the segmentation logit layer
            edge_conv_channels (tuple of tuple of int): channels of every EdgeConvBlock
            inter_channels (int): the number of output channels of the local Conv1d
            global_channels (tuple of int): segmentation channels; all but the last entry
                go to the shared mlp, the last one to a Conv1d
            k: forwarded to TNet and to every EdgeConvBlock
            dropout_prob (float): the dropout argument of the segmentation mlp
            with_transform (bool): whether to create the input TNet

        """
        super(DGCNNPartSeg, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.with_transform = with_transform
        self.num_gpu = torch.cuda.device_count()
        self.num_class = num_class

        # input transform
        if self.with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # Chain the edge convolution blocks: each one consumes the last width
        # produced by its predecessor.
        self.mlp_edge_conv = nn.ModuleList()
        block_in_channels = in_channels
        for block_channels in edge_conv_channels:
            self.mlp_edge_conv.append(
                EdgeConvBlock(block_in_channels, block_channels, k))
            block_in_channels = block_channels[-1]

        # Label branch; its width equals the first width of the first block.
        label_channels = edge_conv_channels[0][0]
        self.lable_conv = Conv2d(num_class, label_channels, [1, 1])

        # Sum of the final widths of all edge convolution blocks.
        edge_out_total = sum(block_channels[-1]
                             for block_channels in edge_conv_channels)
        self.mlp_local = Conv1d(edge_out_total, inter_channels, 1)

        seg_in_channels = (inter_channels + edge_conv_channels[-1][-1] +
                           edge_out_total)
        seg_hidden_channels = global_channels[:-1]
        self.mlp_seg = SharedMLP(seg_in_channels,
                                 seg_hidden_channels,
                                 dropout=dropout_prob)
        self.conv_seg = Conv1d(global_channels[-2], global_channels[-1], 1)
        self.seg_logit = nn.Conv1d(global_channels[-1], num_seg_class, 1, bias=True)

        self.init_weights()
        set_bn(self, momentum=0.01)
示例#15
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius=(0.2, 0.4),
                 num_neighbours=(32, 64),
                 sa_channels=((64, 64, 128), (128, 128, 256)),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Build the set abstraction stack, the local/global mlps and the classifier.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of semantics classes to predict over
            num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
            radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
            num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
            sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
            local_channels (tuple of int): the numbers of channels to extract local features after set abstraction
            global_channels (tuple of int): the numbers of channels to extract global features
            dropout_prob (float): the probability to dropout input features
            use_xyz (bool): whether or not to use the xyz position of a point as a feature

        """
        super(PointNet2SSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: every per-layer argument needs one entry per layer
        num_layers = len(num_centroids)
        for per_layer_arg in (radius, num_neighbours, sa_channels):
            assert len(per_layer_arg) == num_layers

        # Three input channels are not counted as features
        # (presumably the xyz coordinates -- confirm against forward).
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_layers):
            layer_channels = sa_channels[layer]
            sa_module = PointNetSAModule(in_channels=sa_in_channels,
                                         mlp_channels=layer_channels,
                                         num_centroids=num_centroids[layer],
                                         radius=radius[layer],
                                         num_neighbours=num_neighbours[layer],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # The next layer consumes the last width of this one.
            sa_in_channels = layer_channels[-1]

        # Three more channels are reserved for the local mlp when use_xyz is enabled.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()
示例#16
0
    def __init__(self, in_channels, mlp_channels, num_neighbors):
        """Create the shared mlp and the feature interpolator.

        Args:
            in_channels (int): the number of input channels of the shared mlp
            mlp_channels (tuple of int): the numbers of channels in the shared mlp
            num_neighbors: argument given to FeatureInterpolator

        """
        super(PointnetFPModule, self).__init__()

        self.in_channels = in_channels

        # The mlp is created before the interpolator.
        self.mlp = SharedMLP(in_channels,
                             mlp_channels,
                             ndim=1,
                             bn=True)
        self.interpolator = FeatureInterpolator(num_neighbors)
示例#17
0
class PointNet2SSGCls(nn.Module):
    """PointNet2 with single-scale grouping for classification

    Structure: input -> [PointNetSA]s -> [MLP]s -> [MaxPooling] -> [MLP]s -> [Linear] -> logits
    Unlike the original implementation, the last set abstraction is implemented as a local MLP.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius=(0.2, 0.4),
                 num_neighbours=(32, 64),
                 sa_channels=((64, 64, 128), (128, 128, 256)),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Build the set abstraction stack, the local/global mlps and the classifier.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of semantics classes to predict over
            num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
            radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
            num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
            sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
            local_channels (tuple of int): the numbers of channels to extract local features after set abstraction
            global_channels (tuple of int): the numbers of channels to extract global features
            dropout_prob (float): the probability to dropout input features
            use_xyz (bool): whether or not to use the xyz position of a point as a feature

        """
        super(PointNet2SSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: every per-layer argument needs one entry per layer
        num_layers = len(num_centroids)
        for per_layer_arg in (radius, num_neighbours, sa_channels):
            assert len(per_layer_arg) == num_layers

        # The first three input channels are coordinates (see forward); only the
        # remaining ones count as features.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_layers):
            layer_channels = sa_channels[layer]
            sa_module = PointNetSAModule(in_channels=sa_in_channels,
                                         mlp_channels=layer_channels,
                                         num_centroids=num_centroids[layer],
                                         radius=radius[layer],
                                         num_neighbours=num_neighbours[layer],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # The next layer consumes the last width of this one.
            sa_in_channels = layer_channels[-1]

        # forward concatenates xyz in front of the features when use_xyz is enabled.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Classify the point clouds stored under ``data_batch["points"]``.

        Returns:
            dict: ``cls_logit`` (classifier output) and ``key_point_inds``
            (argmax indices of the max-pooling over dimension 2).

        """
        points = data_batch["points"]
        num_channels = points.size(1)

        # torch.Tensor.narrow returns views that share memory with points.
        xyz = points.narrow(1, 0, 3)
        feature = None
        if num_channels > 3:
            feature = points.narrow(1, 3, num_channels - 3)

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        local_input = torch.cat([xyz, feature], dim=1) if self.use_xyz else feature
        local_feature = self.mlp_local(local_input)
        global_feature, max_indices = torch.max(local_feature, 2)
        global_feature = self.mlp_global(global_feature)

        preds = {'cls_logit': self.classifier(global_feature)}
        preds['key_point_inds'] = max_indices

        return preds

    def init_weights(self):
        """Xavier-initialize every sub-module, zero the classifier bias, set bn momentum."""
        for sa_module in self.sa_modules:
            sa_module.init_weights(xavier_uniform)
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
        set_bn(self, momentum=0.01)
示例#18
0
class PointNet2MSGCls(nn.Module):
    """PointNet2 with multi-scale grouping for classification

    Structure: input -> [PointNetSA(MSG)]s -> [MLP]s -> MaxPooling -> [MLP]s -> [Linear] -> logits

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8)),
                 num_neighbours_list=((16, 32, 128), (32, 64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                                   ((64, 64, 128), (128, 128, 256), (128, 128, 256))),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Multi-scale counterpart of PointNet2SSGCls.

        The per-layer arguments (``radius_list``, ``num_neighbours_list`` and
        ``sa_channels_list``) hold one tuple per set abstraction layer, and every
        such tuple is handed to one PointNetSAModuleMSG.
        """
        super(PointNet2MSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: every per-layer argument needs one entry per layer
        num_layers = len(num_centroids)
        for per_layer_arg in (radius_list, num_neighbours_list, sa_channels_list):
            assert len(per_layer_arg) == num_layers

        # The first three input channels are coordinates (see forward); only the
        # remaining ones count as features.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_layers):
            sa_module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[layer],
                num_centroids=num_centroids[layer],
                radius_list=radius_list[layer],
                num_neighbours_list=num_neighbours_list[layer],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            # The next layer consumes what this one produces.
            sa_in_channels = sa_module.out_channels

        # forward concatenates xyz in front of the features when use_xyz is enabled.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Classify the point clouds stored under ``data_batch["points"]``.

        Returns:
            dict: ``cls_logit`` (classifier output) and ``key_point_inds``
            (argmax indices of the max-pooling over dimension 2).

        """
        points = data_batch["points"]
        num_channels = points.size(1)

        # torch.Tensor.narrow returns views that share memory with points.
        xyz = points.narrow(1, 0, 3)
        feature = None
        if num_channels > 3:
            feature = points.narrow(1, 3, num_channels - 3)

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        local_input = torch.cat([xyz, feature], dim=1) if self.use_xyz else feature
        local_feature = self.mlp_local(local_input)
        global_feature, max_indices = torch.max(local_feature, 2)
        global_feature = self.mlp_global(global_feature)

        preds = {'cls_logit': self.classifier(global_feature)}
        preds['key_point_inds'] = max_indices

        return preds

    def init_weights(self):
        """Xavier-initialize every sub-module, zero the classifier bias, set bn momentum."""
        for sa_module in self.sa_modules:
            sa_module.init_weights(xavier_uniform)
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
        set_bn(self, momentum=0.01)
示例#19
0
class PointNetPartSeg(nn.Module):
    """PointNet for part segmentation

    The network predicts classification logits from the max-pooled global feature
    and per-point segmentation logits from the concatenation of stem features,
    local features, the expanded global feature and a one-hot class embedding.

    References:
        https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py

    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True):
        """Build the stem, the classification head and the part segmentation head.

        Args:
            in_channels (int): the number of input channels
            num_classes (int): the number of classification logits; also the size of
                the one-hot class embedding fed to the segmentation head
            num_seg_classes (int): the number of segmentation logits per point
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            cls_channels (tuple of int): the numbers of channels in classification mlp
            seg_channels (tuple of int): the numbers of channels in segmentation mlp;
                all but the last entry go to the shared mlp, the last one to a Conv1d
            dropout_prob_cls (float): the probability to dropout in classification mlp
            dropout_prob_seg (float): the probability to dropout in segmentation mlp
            with_transform (bool): whether to use TNet to transform features.

        """
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes

        # stem
        self.stem = Stem(in_channels,
                         stem_channels,
                         with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # classification
        # Notice that we apply dropout to each classification mlp.
        self.mlp_cls = MLP(local_channels[-1],
                           cls_channels,
                           dropout=dropout_prob_cls)
        self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

        # part segmentation
        # Notice that the original repo concatenates global feature, one hot class embedding,
        # stem features and local features. However, the paper does not use last local feature.
        # Here, we follow the released repo.
        in_channels_seg = local_channels[-1] + num_classes + sum(
            stem_channels) + sum(local_channels)
        self.mlp_seg = SharedMLP(in_channels_seg,
                                 seg_channels[:-1],
                                 dropout=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   1,
                                   bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Predict classification and per-point segmentation logits.

        Args:
            data_batch (dict): must provide ``"points"`` (a tensor whose dimension 2
                indexes the points) and ``"cls_label"`` (class indices used to look up
                rows of an identity matrix).

        Returns:
            dict: ``cls_logit``, ``seg_logit``, ``trans_input``, ``trans_feature``
            (both taken from the stem outputs) and ``key_point_inds`` (argmax
            indices of the max-pooling over the points).

        """
        x = data_batch["points"]
        cls_label = data_batch["cls_label"]
        num_points = x.shape[2]
        end_points = {}

        # stem
        stem_feature, end_points_stem = self.stem(x)
        end_points["trans_input"] = end_points_stem["trans_input"]
        end_points["trans_feature"] = end_points_stem["trans_feature"]
        stem_features = end_points_stem["stem_features"]

        # mlp for local features; every intermediate output is kept for segmentation
        local_features = []
        x = stem_feature
        for mlp in self.mlp_local:
            x = mlp(x)
            local_features.append(x)

        # max pool over points
        global_feature, max_indices = torch.max(
            x, 2)  # (batch_size, local_channels[-1])
        end_points['key_point_inds'] = max_indices

        # classification
        cls_logit = self.cls_logit(self.mlp_cls(global_feature))

        # segmentation
        global_feature_expand = global_feature.unsqueeze(2).expand(
            -1, -1, num_points)
        # The one-hot embedding is a constant input; no gradient is tracked for it.
        with torch.no_grad():
            identity = torch.eye(self.num_classes,
                                 dtype=global_feature.dtype,
                                 device=global_feature.device)
            one_hot = identity[cls_label]  # (batch_size, num_classes)
            one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points)

        x = torch.cat(stem_features + local_features +
                      [global_feature_expand, one_hot_expand],
                      dim=1)
        x = self.mlp_seg(x)
        x = self.conv_seg(x)
        seg_logit = self.seg_logit(x)

        preds = {"cls_logit": cls_logit, "seg_logit": seg_logit}
        preds.update(end_points)

        return preds

    def init_weights(self):
        """Xavier-initialize the mlps and logit layers, zero the biases, set bn momentum."""
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_cls.init_weights(xavier_uniform)
        self.mlp_seg.init_weights(xavier_uniform)
        self.conv_seg.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.cls_logit.weight)
        nn.init.zeros_(self.cls_logit.bias)
        nn.init.xavier_uniform_(self.seg_logit.weight)
        nn.init.zeros_(self.seg_logit.bias)
        # Set the batch normalization momentum to 0.01 as default
        set_bn(self, momentum=0.01)
示例#20
0
class PointNet2MSGPartSeg(nn.Module):
    """PointNet++ part segmentation with multi-scale grouping (MSG).

    Structure, as assembled in ``__init__`` and executed in ``forward``:
    input -> [PointNetSAModuleMSG]s -> [SharedMLP] -> [MaxPool] -> global
    feature -> [SharedMLP] -> [PointnetFPModule]s -> [SharedMLP] -> [Conv1d]
    -> per-point segmentation logits

    Refer to PointNet2SSGPartSeg

    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8)),
                 num_neighbours_list=((32, 64, 128), (64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128),
                                    (64, 96, 128)), ((128, 128, 256),
                                                     (128, 196, 256))),
                 local_channels=(256, 512, 1024),
                 fp_local_channels=(256, 256),
                 fp_channels=((256, 128), (128, 128)),
                 num_fp_neighbours=(3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True):
        """

        Args:
            in_channels (int): the number of input channels (3 of them are xyz)
            num_classes (int): the size of the one-hot class label
            num_seg_classes (int): the number of output channels of the
                segmentation logits
            num_centroids (tuple of int): ``num_centroids`` of each set
                abstraction (SA) module
            radius_list (tuple of tuple): ``radius_list`` of each SA module
            num_neighbours_list (tuple of tuple): ``num_neighbours_list`` of
                each SA module
            sa_channels_list (tuple of tuple): ``mlp_channels_list`` of each
                SA module
            local_channels (tuple of int): the numbers of channels in local mlp
            fp_local_channels (tuple of int): the numbers of channels in the
                mlp applied after concatenating the global feature
            fp_channels (tuple of tuple of int): ``mlp_channels`` of each
                feature propagation (FP) module
            num_fp_neighbours (tuple of int): ``num_neighbors`` of each FP
                module
            seg_channels (tuple of int): the numbers of channels in
                segmentation mlp
            dropout_prob (float): ``dropout`` passed to the segmentation mlp
            use_xyz (bool): whether xyz is used as an extra feature

        """
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz

        # sanity check: every per-layer setting describes the same layer count
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers: each module consumes the previous output.
        sa_in_channels = in_channels - 3  # xyz is passed separately
        self.sa_modules = nn.ModuleList()
        for layer_idx, centroids in enumerate(num_centroids):
            module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[layer_idx],
                num_centroids=centroids,
                radius_list=radius_list[layer_idx],
                num_neighbours_list=num_neighbours_list[layer_idx],
                use_xyz=use_xyz)
            self.sa_modules.append(module)
            sa_in_channels = module.out_channels

        # Local Set Abstraction Layer (xyz is concatenated when use_xyz is set)
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)

        # Channels of every skip feature stored by forward: the raw input plus
        # the one-hot class label, followed by the output of each SA module.
        input_channels = in_channels if use_xyz else in_channels - 3
        skip_channels = [input_channels + num_classes]
        skip_channels += [m.out_channels for m in self.sa_modules]

        # Local Feature Propagation Layer
        self.mlp_local_fp = SharedMLP(local_channels[-1] + skip_channels[-1],
                                      fp_local_channels,
                                      bn=True)

        # Feature Propagation Layers: walk the skip features from coarse to fine
        self.fp_modules = nn.ModuleList()
        fp_in_channels = fp_local_channels[-1]
        for layer_idx, mlp_channels in enumerate(fp_channels):
            self.fp_modules.append(
                PointnetFPModule(
                    in_channels=fp_in_channels + skip_channels[-2 - layer_idx],
                    mlp_channels=mlp_channels,
                    num_neighbors=num_fp_neighbours[layer_idx]))
            fp_in_channels = mlp_channels[-1]

        # MLP
        self.mlp_seg = SharedMLP(fp_in_channels,
                                 seg_channels,
                                 ndim=1,
                                 dropout=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   kernel_size=1,
                                   bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Predict per-point segmentation logits

        Args:
            data_batch (dict): must provide
                "points": tensor with channels on dim 1 (the first three are
                    read as xyz) and points on dim 2
                "cls_label": index used to select rows of a
                    (num_classes, num_classes) identity matrix
                    (presumably one class id per sample -- TODO confirm)

        Returns:
            dict: {"seg_logit": output of the final Conv1d}

        """
        points = data_batch["points"]
        end_points = {}

        # Split channels: coordinates first, remaining channels are features.
        num_channels = points.size(1)
        xyz = points.narrow(1, 0, 3)
        feature = None
        if num_channels > 3:
            feature = points.narrow(1, 3, num_channels - 3)

        # Save intermediate results (skip connections for feature propagation)
        skip_xyz = [xyz]
        input_feature = points if self.use_xyz else feature

        # Create one hot class label and attach it to the input-level feature
        with torch.no_grad():
            cls_label = data_batch["cls_label"]
            identity = torch.eye(self.num_classes,
                                 dtype=points.dtype,
                                 device=points.device)
            one_hot_expand = identity[cls_label].unsqueeze(2).expand(
                -1, -1, points.size(2))
            skip_feature = [
                torch.cat([input_feature, one_hot_expand], dim=1)
            ]

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            skip_xyz.append(xyz)
            skip_feature.append(feature)

        # Local Set Abstraction Layer
        local_input = feature
        if self.use_xyz:
            local_input = torch.cat([xyz, feature], dim=1)
        local_feature = self.mlp_local(local_input)
        global_feature = torch.max(local_feature, 2)[0]

        # Local Feature Propagation Layer: broadcast the global feature back
        # to the points of the last SA level and fuse it with their feature.
        global_feature_expand = global_feature.unsqueeze(2).expand(
            -1, -1, skip_xyz[-1].size(2))
        key_feature = self.mlp_local_fp(
            torch.cat([global_feature_expand, skip_feature[-1]], dim=1))

        # Feature Propagation Layers: depth 2 is the level just below the
        # last SA output, depth 3 the one below that, and so on.
        key_xyz = xyz
        for depth, fp_module in enumerate(self.fp_modules, start=2):
            query_xyz = skip_xyz[-depth]
            key_feature = fp_module(query_xyz, key_xyz, skip_feature[-depth],
                                    key_feature)
            key_xyz = query_xyz

        # MLP
        seg_logit = self.seg_logit(self.mlp_seg(key_feature))

        preds = {"seg_logit": seg_logit}
        preds.update(end_points)

        return preds

    def init_weights(self):
        """Initialize sub-modules with xavier_uniform and configure batch norm"""
        # Same order as the layers are applied in forward.
        ordered_blocks = (list(self.sa_modules) +
                          [self.mlp_local, self.mlp_local_fp] +
                          list(self.fp_modules) + [self.mlp_seg])
        for block in ordered_blocks:
            block.init_weights(xavier_uniform)

        nn.init.xavier_uniform_(self.seg_logit.weight)
        nn.init.zeros_(self.seg_logit.bias)
        set_bn(self, momentum=0.01)
示例#21
0
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8)),
                 num_neighbours_list=((32, 64, 128), (64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128),
                                    (64, 96, 128)), ((128, 128, 256),
                                                     (128, 196, 256))),
                 local_channels=(256, 512, 1024),
                 fp_local_channels=(256, 256),
                 fp_channels=((256, 128), (128, 128)),
                 num_fp_neighbours=(3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True):
        """

        Args:
            in_channels (int): the number of input channels (3 are subtracted
                to obtain the feature channels of the first SA module)
            num_classes (int): added to the input-level skip channels
                (concat with one-hot)
            num_seg_classes (int): the number of output channels of the
                segmentation logits
            num_centroids (tuple of int): ``num_centroids`` of each set
                abstraction (SA) module
            radius_list (tuple of tuple): ``radius_list`` of each SA module
            num_neighbours_list (tuple of tuple): ``num_neighbours_list`` of
                each SA module
            sa_channels_list (tuple of tuple): ``mlp_channels_list`` of each
                SA module
            local_channels (tuple of int): the numbers of channels in local mlp
            fp_local_channels (tuple of int): the numbers of channels in the
                local feature propagation mlp
            fp_channels (tuple of tuple of int): ``mlp_channels`` of each
                feature propagation (FP) module
            num_fp_neighbours (tuple of int): ``num_neighbors`` of each FP
                module
            seg_channels (tuple of int): the numbers of channels in
                segmentation mlp
            dropout_prob (float): ``dropout`` passed to the segmentation mlp
            use_xyz (bool): whether xyz is used as an extra feature

        """
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz

        # sanity check: every per-layer setting describes the same layer count
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers: each module consumes the previous output.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer_idx, centroids in enumerate(num_centroids):
            module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[layer_idx],
                num_centroids=centroids,
                radius_list=radius_list[layer_idx],
                num_neighbours_list=num_neighbours_list[layer_idx],
                use_xyz=use_xyz)
            self.sa_modules.append(module)
            sa_in_channels = module.out_channels

        # Local Set Abstraction Layer (3 extra channels when use_xyz is set)
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)

        # Channels of every skip feature: the raw input plus the one-hot
        # class label, followed by the output of each SA module.
        input_channels = in_channels if use_xyz else in_channels - 3
        skip_channels = [input_channels + num_classes]
        skip_channels += [m.out_channels for m in self.sa_modules]

        # Local Feature Propagation Layer
        self.mlp_local_fp = SharedMLP(local_channels[-1] + skip_channels[-1],
                                      fp_local_channels,
                                      bn=True)

        # Feature Propagation Layers: walk the skip channels from the end
        self.fp_modules = nn.ModuleList()
        fp_in_channels = fp_local_channels[-1]
        for layer_idx, mlp_channels in enumerate(fp_channels):
            self.fp_modules.append(
                PointnetFPModule(
                    in_channels=fp_in_channels + skip_channels[-2 - layer_idx],
                    mlp_channels=mlp_channels,
                    num_neighbors=num_fp_neighbours[layer_idx]))
            fp_in_channels = mlp_channels[-1]

        # MLP
        self.mlp_seg = SharedMLP(fp_in_channels,
                                 seg_channels,
                                 ndim=1,
                                 dropout=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   kernel_size=1,
                                   bias=True)

        self.init_weights()
示例#22
0
class TNet(nn.Module):
    """Transformation Network for DGCNN

    Structure: input -> [EdgeFeature] -> [EdgeConv]s -> [EdgePool] -> features -> [MLP] -> local features
    -> [MaxPool] -> global features -> [MLP] -> [Linear] -> transform matrix (+ identity)

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of rows of the predicted transform
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for edge feature extractor

    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # Edge features stack two in_channels-sized parts, hence 2 * in_channels.
        self.edge_conv = SharedMLP(2 * in_channels, conv_channels, ndim=2)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        # One output per entry of the (out_channels, in_channels) matrix.
        self.linear = nn.Linear(global_channels[-1],
                                self.in_channels * out_channels,
                                bias=True)

        self.init_weights()

    def forward(self, x):
        """TNet forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels)

        """
        x = get_edge_feature(
            x, self.k)  # (batch_size, 2 * in_channels, num_points, k)
        x = self.edge_conv(x)
        x, _ = torch.max(x, 3)  # (batch_size, conv_channels[-1], num_points)
        x = self.mlp_local(x)
        x, _ = torch.max(x, 2)  # (batch_size, local_channels[-1])
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        # Build the identity in the dtype of x as well as on its device.
        # A default float32 identity would promote a half/bfloat16 prediction
        # to float32 and break a later bmm with the reduced-precision input.
        identity = torch.eye(self.out_channels,
                             self.in_channels,
                             dtype=x.dtype,
                             device=x.device)
        x = x.add(identity)  # broadcast first dimension
        return x

    def init_weights(self):
        """Initialize the mlps with xavier_uniform and zero the last linear layer"""
        self.edge_conv.init_weights(xavier_uniform)
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_global.init_weights(xavier_uniform)
        # Set linear transform be 0, so that forward initially returns the
        # identity matrix that is added to its output.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
示例#23
0
class Stem(nn.Module):
    """Stem (main body or stalk). Extract features from raw point clouds

    Structure: input (-> [TNet] -> transform_input) -> [MLP] -> features (-> [TNet] -> transform_feature)

    Attributes:
        in_channels: the number of input channels
        out_channels: the number of output channels, i.e. stem_channels[-1]
        with_transform: whether to use TNet

    """
    def __init__(self,
                 in_channels,
                 stem_channels=(64, 64),
                 with_transform=True,
                 bn=True):
        super(Stem, self).__init__()

        out_channels = stem_channels[-1]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        # feature stem, initialized right away
        self.mlp = SharedMLP(in_channels, stem_channels, bn=bn)
        self.mlp.init_weights(xavier_uniform)

        if with_transform:
            # One TNet aligns the raw input, the other aligns the features.
            self.transform_input = TNet(in_channels, in_channels)
            self.transform_feature = TNet(out_channels, out_channels)

    def forward(self, x):
        """PointNet Stem forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, stem_channels[-1], num_points)
            dict (empty when with_transform is False):
                trans_input: (batch_size, in_channels, in_channels)
                trans_feature: (batch_size, stem_channels[-1], stem_channels[-1])

        """
        # Without TNets the stem is just the shared mlp.
        if not self.with_transform:
            return self.mlp(x), {}

        # input transform -> feature -> feature transform
        trans_input = self.transform_input(x)
        features = self.mlp(torch.bmm(trans_input, x))
        trans_feature = self.transform_feature(features)
        features = torch.bmm(trans_feature, features)

        end_points = {
            "trans_input": trans_input,
            "trans_feature": trans_feature,
        }
        return features, end_points
示例#24
0
class PointNetCls(nn.Module):
    """PointNet for classification

    Structure: input -> [Stem] -> features -> [SharedMLP] -> local features
    -> [MaxPool] -> global features -> [MLP] -> [FC] -> [Linear] -> logits

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(64, 64),
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.3,
                 with_transform=True):
        """

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of output channels
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp
            dropout_prob (float): the probability to dropout
            with_transform (bool): whether to use TNet to transform features.
        """
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = Stem(in_channels,
                         stem_channels,
                         with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout=dropout_prob)
        # The head is an extra FC layer followed by the logit layer; the
        # logit layer is always the last element of the Sequential.
        self.classifier = nn.Sequential(
            FC(global_channels[-1], global_channels[-1]),
            nn.Linear(global_channels[-1], out_channels, bias=True))

        self.init_weights()

    def forward(self, data_batch):
        """Classify a batch of point clouds

        Args:
            data_batch (dict): must provide "points", the input of the stem

        Returns:
            dict:
                cls_logit: output of the classifier head
                key_point_inds: argmax indices of the max pool over points
                plus whatever the stem reports in its end points

        """
        x = data_batch["points"]

        # stem
        x, end_points = self.stem(x)
        # mlp for local features
        x = self.mlp_local(x)
        # max pool over points
        x, max_indices = torch.max(x, 2)
        end_points["key_point_inds"] = max_indices
        # mlp for global features
        x = self.mlp_global(x)
        x = self.classifier(x)

        preds = {"cls_logit": x}
        preds.update(end_points)

        return preds

    def init_weights(self):
        """Initialize the mlps and the logit layer, then configure batch norm"""
        # Default initialization in original implementation
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_global.init_weights(xavier_uniform)
        # ``self.classifier`` is an nn.Sequential and has no ``weight`` or
        # ``bias`` of its own; initialize the final logit layer explicitly.
        # NOTE(review): the preceding FC layer keeps whatever initialization
        # its constructor applies -- confirm that this is intended.
        logit_layer = self.classifier[-1]
        nn.init.xavier_uniform_(logit_layer.weight)
        nn.init.zeros_(logit_layer.bias)
        # Set batch normalization to 0.01 as default
        set_bn(self, momentum=0.01)