def __init__(self,
             in_channels,
             num_classes,
             num_seg_classes,
             stem_channels=(64, 128, 128),
             local_channels=(512, 2048),
             cls_channels=(256, 256),
             seg_channels=(256, 256, 128),
             dropout_prob_cls=0.3,
             dropout_prob_seg=0.2,
             with_transform=True):
    """Build the PointNet part-segmentation network.

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of object classes; sizes the
            classification logits and is counted in the segmentation
            input width (one-hot class embedding)
        num_seg_classes (int): the number of output channels of the
            segmentation logits
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        cls_channels (tuple of int): the numbers of channels in classification mlp
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob_cls (float): the probability to dropout in classification mlp
        dropout_prob_seg (float): the probability to dropout in segmentation mlp
        with_transform (bool): whether to use TNet to transform features.

    """
    super(PointNetPartSeg, self).__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.num_seg_classes = num_seg_classes

    # Shared feature extractors: stem followed by the local MLP.
    self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
    self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

    # Width of the last local layer, used as the classification input and
    # counted once more in the segmentation input below.
    global_channels = local_channels[-1]

    # Classification head; dropout is applied to each classification mlp.
    self.mlp_cls = MLP(global_channels, cls_channels, dropout=dropout_prob_cls)
    self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

    # Part-segmentation head.
    # The original repo concatenates global feature, one hot class embedding,
    # stem features and local features. The paper does not use the last
    # local feature; here we follow the released repo.
    seg_in_channels = (sum(stem_channels) + sum(local_channels)
                       + global_channels + num_classes)
    seg_hidden_channels = seg_channels[:-1]
    self.mlp_seg = SharedMLP(seg_in_channels, seg_hidden_channels,
                             dropout=dropout_prob_seg)
    self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
    self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

    self.init_weights()
class EdgeConvBlock(nn.Module):
    """EdgeConv block.

    Structure: point features -> [get_edge_feature] -> edge features
    -> [MLP(2d)] -> [MaxPool] -> point features

    """

    def __init__(self, in_channels, out_channels, k):
        """Store the configuration and build the 2d shared MLP.

        Args:
            in_channels (int): the number of channels of the input point features
            out_channels: channel specification forwarded to SharedMLP
            k: second argument handed to get_edge_feature in forward

        """
        super(EdgeConvBlock, self).__init__()
        self.k = k
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The MLP is built for twice the input width.
        edge_channels = 2 * in_channels
        self.mlp = SharedMLP(edge_channels, out_channels, ndim=2)

    def forward(self, x):
        """Build edge features, run the MLP and max-reduce dimension 3."""
        edge_feature = self.mlp(get_edge_feature(x, self.k))
        # Keep only the max values; the argmax indices are not needed.
        return torch.max(edge_feature, 3)[0]

    def init_weights(self, init_fn=None):
        """Delegate weight initialisation to the shared MLP."""
        self.mlp.init_weights(init_fn)
def __init__(self, in_channels, out_channels, k):
    """Store the configuration and build the 2d shared MLP.

    Args:
        in_channels (int): the number of channels of the input point features
        out_channels: channel specification forwarded to SharedMLP
        k: stored as ``self.k`` (presumably the neighbour count used for
            edge features -- confirm against ``forward``)

    """
    super(EdgeConvBlock, self).__init__()
    self.k = k
    self.in_channels = in_channels
    self.out_channels = out_channels

    # The MLP is built for twice the input width.
    edge_channels = 2 * in_channels
    self.mlp = SharedMLP(edge_channels, out_channels, ndim=2)
class TNet(nn.Module):
    """Transformation Network. The structure is similar with PointNet"""

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        """Build the local MLP, the global MLP and the output projection.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of rows of the predicted matrix
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp

        """
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Per-point (local) features, then features after max-pooling.
        self.mlp_local = SharedMLP(in_channels, local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)

        # One output per entry of the (out_channels, in_channels) matrix.
        num_matrix_entries = in_channels * out_channels
        self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

        self.init_weights()

    def forward(self, x):
        """TNet forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels)

        """
        local_feature = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        global_feature, _ = torch.max(local_feature, 2)  # (batch_size, local_channels[-1])
        global_feature = self.mlp_global(global_feature)

        offset = self.linear(global_feature)
        offset = offset.view(-1, self.out_channels, self.in_channels)

        # The network predicts a residual on top of the identity matrix.
        identity = torch.eye(self.out_channels, self.in_channels,
                             dtype=offset.dtype, device=offset.device)
        return offset.add(identity)  # broadcast add

    def init_weights(self):
        """Xavier-initialise both MLPs and zero the output projection."""
        # Default initialization in original implementation
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        # Initialize linear transform to 0
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)
def __init__(self, in_channels, mlp_channels, num_centroids, radius,
             num_neighbours, use_xyz):
    """Build the MLP, sampler and grouper of a set abstraction module.

    Args:
        in_channels (int): the number of input feature channels
        mlp_channels (tuple of int): the numbers of channels of the MLP;
            the last entry becomes ``out_channels``
        num_centroids: forwarded to FarthestPointSampler
        radius: forwarded to QueryGrouper
        num_neighbours: forwarded to QueryGrouper
        use_xyz (bool): when true the MLP input is widened by 3 channels
            and the flag is forwarded to QueryGrouper

    """
    super(PointNetSAModule, self).__init__()

    self.in_channels = in_channels
    self.out_channels = mlp_channels[-1]

    # The MLP input is 3 channels wider when xyz is used as a feature.
    mlp_in_channels = in_channels + 3 if use_xyz else in_channels
    self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

    self.sampler = FarthestPointSampler(num_centroids)
    self.grouper = QueryGrouper(radius, num_neighbours, use_xyz=use_xyz)
def __init__(self,
             in_channels,
             out_channels,
             num_centroids=(512, 128),
             radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8)),
             num_neighbours_list=((16, 32, 128), (32, 64, 128)),
             sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                               ((64, 64, 128), (128, 128, 256), (128, 128, 256))),
             local_channels=(256, 512, 1024),
             global_channels=(512, 256),
             dropout_prob=0.5,
             use_xyz=True):
    """Refer to PointNet2SSGCls. Major difference is that all the arguments
    will be a tuple of original types."""
    super(PointNet2MSGCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_xyz = use_xyz

    # sanity check: one entry per set abstraction layer
    num_layers = len(num_centroids)
    assert len(radius_list) == num_layers
    assert len(num_neighbours_list) == num_layers
    assert len(sa_channels_list) == num_layers

    # The first 3 input channels are not counted as features.
    sa_in_channels = in_channels - 3
    self.sa_modules = nn.ModuleList()
    for ind, centroids in enumerate(num_centroids):
        sa_module = PointNetSAModuleMSG(
            in_channels=sa_in_channels,
            mlp_channels_list=sa_channels_list[ind],
            num_centroids=centroids,
            radius_list=radius_list[ind],
            num_neighbours_list=num_neighbours_list[ind],
            use_xyz=use_xyz)
        self.sa_modules.append(sa_module)
        # The next layer consumes what this one produces.
        sa_in_channels = sa_module.out_channels

    # The local MLP input is 3 channels wider when xyz is used.
    local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
    self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
    self.mlp_global = MLP(local_channels[-1], global_channels,
                          dropout=dropout_prob)
    self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

    self.init_weights()
def __init__(self, in_channels, mlp_channels_list, num_centroids,
             radius_list, num_neighbours_list, use_xyz):
    """Build one MLP and one grouper per scale, plus a shared sampler.

    Args:
        in_channels (int): the number of input feature channels
        mlp_channels_list (tuple of tuple of int): MLP channels per scale;
            ``out_channels`` is the sum of the last entry of each
        num_centroids: forwarded to FarthestPointSampler
        radius_list: per-scale radius forwarded to QueryGrouper
        num_neighbours_list: per-scale value forwarded to QueryGrouper
        use_xyz (bool): when true each MLP input is widened by 3 channels
            and the flag is forwarded to QueryGrouper

    """
    super(PointNetSAModuleMSG, self).__init__()

    self.in_channels = in_channels
    self.out_channels = sum(channels[-1] for channels in mlp_channels_list)

    # Every per-scale list must describe the same number of scales.
    num_scales = len(mlp_channels_list)
    assert len(radius_list) == num_scales
    assert len(num_neighbours_list) == num_scales

    mlp_in_channels = in_channels + 3 if use_xyz else in_channels

    self.mlp = nn.ModuleList()
    self.sampler = FarthestPointSampler(num_centroids)
    self.grouper = nn.ModuleList()
    for ind, scale_channels in enumerate(mlp_channels_list):
        scale_mlp = SharedMLP(mlp_in_channels, scale_channels, ndim=2, bn=True)
        self.mlp.append(scale_mlp)
        scale_grouper = QueryGrouper(radius_list[ind],
                                     num_neighbours_list[ind],
                                     use_xyz=use_xyz)
        self.grouper.append(scale_grouper)
class PointnetFPModule(nn.Module):
    """PointNet feature propagation module"""

    def __init__(self, in_channels, mlp_channels, num_neighbors):
        """Build the 1d shared MLP and the feature interpolator.

        Args:
            in_channels (int): the number of input channels of the MLP
            mlp_channels (tuple of int): the numbers of channels of the MLP
            num_neighbors: forwarded to FeatureInterpolator

        """
        super(PointnetFPModule, self).__init__()

        self.in_channels = in_channels
        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)
        self.interpolator = FeatureInterpolator(num_neighbors)

    def forward(self, query_xyz, key_xyz, query_feature, key_feature):
        """Interpolate features with the interpolator, then apply the MLP.

        All four arguments are handed to the interpolator unchanged; its
        result is the input of the MLP.

        """
        interpolated = self.interpolator(query_xyz, key_xyz,
                                         query_feature, key_feature)
        return self.mlp(interpolated)

    def init_weights(self, init_fn=None):
        """Delegate weight initialisation to the shared MLP."""
        self.mlp.init_weights(init_fn)
class PointNetSAModule(nn.Module):
    """PointNet set abstraction module"""

    def __init__(self, in_channels, mlp_channels, num_centroids, radius,
                 num_neighbours, use_xyz):
        """Build the MLP, sampler and grouper.

        Args:
            in_channels (int): the number of input feature channels
            mlp_channels (tuple of int): the numbers of channels of the MLP;
                the last entry becomes ``out_channels``
            num_centroids: forwarded to FarthestPointSampler
            radius: forwarded to QueryGrouper
            num_neighbours: forwarded to QueryGrouper
            use_xyz (bool): when true the MLP input is widened by 3 channels
                and the flag is forwarded to QueryGrouper

        """
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        # The MLP input is 3 channels wider when xyz is used as a feature.
        mlp_in_channels = in_channels + 3 if use_xyz else in_channels
        self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

        self.sampler = FarthestPointSampler(num_centroids)
        self.grouper = QueryGrouper(radius, num_neighbours, use_xyz=use_xyz)

    def forward(self, xyz, feature=None):
        """
        Args:
            xyz (torch.Tensor): (batch_size, 3, num_points)
                xyz coordinates of feature
            feature (torch.Tensor, optional): (batch_size, in_channels, num_points)

        Returns:
            new_xyz (torch.Tensor): (batch_size, 3, num_centroids)
            new_feature (torch.Tensor): (batch_size, out_channels, num_centroids)

        """
        # Sample new points and gather their coordinates.
        centroid_index = self.sampler(xyz)
        new_xyz = _F.gather_points(xyz, centroid_index)  # (batch_size, 3, num_centroids)

        # (batch_size, in_channels, num_centroids, num_neighbours)
        grouped_feature = self.grouper(new_xyz, xyz, feature)
        grouped_feature = self.mlp(grouped_feature)

        # Max-reduce dimension 3; the argmax indices are discarded.
        pooled_feature = torch.max(grouped_feature, 3)[0]
        return new_xyz, pooled_feature

    def init_weights(self, init_fn=None):
        """Delegate weight initialisation to the shared MLP."""
        self.mlp.init_weights(init_fn)
def __init__(self, in_channels, stem_channels=(64, 64), with_transform=True,
             bn=True):
    """Build the stem MLP and, optionally, the two TNets.

    Args:
        in_channels (int): the number of input channels
        stem_channels (tuple of int): the numbers of channels of the stem MLP;
            the last entry becomes ``out_channels``
        with_transform (bool): whether to create the input and feature TNets
        bn: forwarded to SharedMLP

    """
    super(Stem, self).__init__()

    out_channels = stem_channels[-1]
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.with_transform = with_transform

    # feature stem
    self.mlp = SharedMLP(in_channels, stem_channels, bn=bn)
    self.mlp.init_weights(xavier_uniform)

    if not with_transform:
        return

    # Square transforms: one for the raw input, one for the stem features.
    self.transform_input = TNet(in_channels, in_channels)
    self.transform_feature = TNet(out_channels, out_channels)
def __init__(self,
             in_channels=3,
             out_channels=3,
             conv_channels=(64, 128),
             local_channels=(1024, ),
             global_channels=(512, 256),
             k=20):
    """Build the edge convolution, local/global MLPs and output projection.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): multiplied with ``in_channels`` to size the
            output of the final linear layer
        conv_channels (tuple of int): the numbers of channels of edge
            convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k: stored as ``self.k``

    """
    super(TNet, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.k = k

    # The edge convolution is built for twice the input width.
    edge_channels = 2 * in_channels
    self.edge_conv = SharedMLP(edge_channels, conv_channels, ndim=2)
    self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels)

    num_matrix_entries = in_channels * out_channels
    self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

    self.init_weights()
def __init__(self,
             in_channels=3,
             out_channels=3,
             local_channels=(64, 128, 1024),
             global_channels=(512, 256)):
    """Build the local MLP, the global MLP and the output projection.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): multiplied with ``in_channels`` to size the
            output of the final linear layer
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp

    """
    super(TNet, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    # local features, then global features
    self.mlp_local = SharedMLP(in_channels, local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels)

    # linear output: in_channels * out_channels values
    num_matrix_entries = in_channels * out_channels
    self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

    self.init_weights()
def __init__(self,
             in_channels,
             out_channels,
             stem_channels=(64, 64),
             local_channels=(64, 128, 1024),
             global_channels=(512, 256),
             dropout_prob=0.3,
             with_transform=True):
    """Build the PointNet classification network.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """
    super(PointNetCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
    self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels,
                          dropout=dropout_prob)

    # Classifier head: a same-width FC layer followed by the logit layer.
    # NOTE(review): the head is an nn.Sequential, which exposes no `.weight`
    # or `.bias` of its own -- confirm that init_weights addresses the
    # final nn.Linear rather than `self.classifier` directly.
    head_channels = global_channels[-1]
    hidden_fc = FC(head_channels, head_channels)
    logit_fc = nn.Linear(head_channels, out_channels, bias=True)
    self.classifier = nn.Sequential(hidden_fc, logit_fc)

    self.init_weights()
def __init__(self,
             in_channels,
             out_channels,
             num_class,
             num_seg_class,
             edge_conv_channels=((64, 64), (64, 64), (64, 64)),
             inter_channels=1024,
             global_channels=(256, 256, 128),
             k=20,
             dropout_prob=0.4,
             with_transform=True):
    """Build the DGCNN part-segmentation network.

    Args:
        in_channels (int): the number of input channels
        out_channels: stored as ``self.out_channels``
        num_class (int): input channels of the label convolution
        num_seg_class (int): output channels of the segmentation logits
        edge_conv_channels (tuple of tuple of int): channels of each
            EdgeConvBlock
        inter_channels (int): output channels of the local convolution
        global_channels (tuple of int): channels of the segmentation MLP;
            the last entry is the output width of ``conv_seg``
        k: forwarded to TNet and to every EdgeConvBlock
        dropout_prob (float): dropout of the segmentation MLP
        with_transform (bool): whether to create the input TNet

    """
    super(DGCNNPartSeg, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.k = k
    self.with_transform = with_transform
    self.num_gpu = torch.cuda.device_count()
    self.num_class = num_class

    # input transform
    if with_transform:
        self.transform_input = TNet(in_channels, in_channels, k=k)

    # Chain the edge convolutions: each block consumes the last width of
    # the previous one.
    self.mlp_edge_conv = nn.ModuleList()
    edge_in_channels = in_channels
    for block_channels in edge_conv_channels:
        self.mlp_edge_conv.append(
            EdgeConvBlock(edge_in_channels, block_channels, k))
        edge_in_channels = block_channels[-1]

    label_channels = edge_conv_channels[0][0]
    self.lable_conv = Conv2d(num_class, label_channels, [1, 1])

    # Sum of the output widths of all edge convolution blocks.
    edge_out_total = sum(channels[-1] for channels in edge_conv_channels)
    self.mlp_local = Conv1d(edge_out_total, inter_channels, 1)

    # NOTE(review): the label branch above is sized by
    # edge_conv_channels[0][0] while this sum counts
    # edge_conv_channels[-1][-1]; they agree for the defaults -- confirm
    # against forward for other configurations.
    seg_in_channels = (inter_channels + edge_conv_channels[-1][-1]
                       + edge_out_total)
    self.mlp_seg = SharedMLP(seg_in_channels, global_channels[:-1],
                             dropout=dropout_prob)
    self.conv_seg = Conv1d(global_channels[-2], global_channels[-1], 1)
    self.seg_logit = nn.Conv1d(global_channels[-1], num_seg_class, 1, bias=True)

    self.init_weights()
    set_bn(self, momentum=0.01)
def __init__(self,
             in_channels,
             out_channels,
             num_centroids=(512, 128),
             radius=(0.2, 0.4),
             num_neighbours=(32, 64),
             sa_channels=((64, 64, 128), (128, 128, 256)),
             local_channels=(256, 512, 1024),
             global_channels=(512, 256),
             dropout_prob=0.5,
             use_xyz=True):
    """Build the single-scale-grouping PointNet++ classifier.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of semantics classes to predict over
        num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
        radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
        num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
        sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
        local_channels (tuple of int): the numbers of channels to extract local features after set abstraction
        global_channels (tuple of int): the numbers of channels to extract global features
        dropout_prob (float): the probability to dropout input features
        use_xyz (bool): whether or not to use the xyz position of a point as a feature

    """
    super(PointNet2SSGCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_xyz = use_xyz

    # sanity check: one entry per set abstraction layer
    num_layers = len(num_centroids)
    assert len(radius) == num_layers
    assert len(num_neighbours) == num_layers
    assert len(sa_channels) == num_layers

    # The first 3 input channels are not counted as features.
    sa_in_channels = in_channels - 3
    self.sa_modules = nn.ModuleList()
    for ind, centroids in enumerate(num_centroids):
        layer_channels = sa_channels[ind]
        self.sa_modules.append(
            PointNetSAModule(in_channels=sa_in_channels,
                             mlp_channels=layer_channels,
                             num_centroids=centroids,
                             radius=radius[ind],
                             num_neighbours=num_neighbours[ind],
                             use_xyz=use_xyz))
        # The next layer consumes what this one produces.
        sa_in_channels = layer_channels[-1]

    # The local MLP input is 3 channels wider when xyz is used.
    local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
    self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
    self.mlp_global = MLP(local_channels[-1], global_channels,
                          dropout=dropout_prob)
    self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

    self.init_weights()
def __init__(self, in_channels, mlp_channels, num_neighbors):
    """Build the 1d shared MLP and the feature interpolator.

    Args:
        in_channels (int): the number of input channels of the MLP
        mlp_channels (tuple of int): the numbers of channels of the MLP
        num_neighbors: forwarded to FeatureInterpolator

    """
    super(PointnetFPModule, self).__init__()

    self.in_channels = in_channels

    # Registration order (mlp, then interpolator) is kept as-is.
    mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)
    self.mlp = mlp
    interpolator = FeatureInterpolator(num_neighbors)
    self.interpolator = interpolator
class PointNet2SSGCls(nn.Module):
    """PointNet2 with single-scale grouping for classification

    Structure: input -> [PointNetSA]s -> [MLP]s -> [MaxPooling] -> [MLP]s -> [Linear] -> logits

    Notice different with the original implementation, the last set abstraction is implemented as a local MLP.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius=(0.2, 0.4),
                 num_neighbours=(32, 64),
                 sa_channels=((64, 64, 128), (128, 128, 256)),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """
        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of semantics classes to predict over
            num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
            radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
            num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
            sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
            local_channels (tuple of int): the numbers of channels to extract local features after set abstraction
            global_channels (tuple of int): the numbers of channels to extract global features
            dropout_prob (float): the probability to dropout input features
            use_xyz (bool): whether or not to use the xyz position of a point as a feature

        """
        super(PointNet2SSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: one entry per set abstraction layer
        num_layers = len(num_centroids)
        assert len(radius) == num_layers
        assert len(num_neighbours) == num_layers
        assert len(sa_channels) == num_layers

        # The first 3 input channels are not counted as features.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind, centroids in enumerate(num_centroids):
            layer_channels = sa_channels[ind]
            self.sa_modules.append(
                PointNetSAModule(in_channels=sa_in_channels,
                                 mlp_channels=layer_channels,
                                 num_centroids=centroids,
                                 radius=radius[ind],
                                 num_neighbours=num_neighbours[ind],
                                 use_xyz=use_xyz))
            sa_in_channels = layer_channels[-1]

        # The local MLP input is 3 channels wider when xyz is used.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1], global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Classify the point clouds stored under ``data_batch["points"]``.

        Returns:
            dict: ``cls_logit`` (classifier output) and ``key_point_inds``
            (argmax indices of the max-pooling over dimension 2).

        """
        points = data_batch["points"]
        num_channels = points.size(1)

        # torch.Tensor.narrow; share same memory
        xyz = points.narrow(1, 0, 3)
        if num_channels > 3:
            feature = points.narrow(1, 3, num_channels - 3)
        else:
            feature = None

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        local_input = torch.cat([xyz, feature], dim=1) if self.use_xyz else feature
        local_feature = self.mlp_local(local_input)
        global_feature, max_indices = torch.max(local_feature, 2)
        global_feature = self.mlp_global(global_feature)

        return {
            'cls_logit': self.classifier(global_feature),
            'key_point_inds': max_indices,
        }

    def init_weights(self):
        """Xavier-initialise all sub-modules and set the BN momentum."""
        for sa_module in self.sa_modules:
            sa_module.init_weights(xavier_uniform)
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
        set_bn(self, momentum=0.01)
class PointNet2MSGCls(nn.Module):
    """PointNet2 with multi-scale grouping for classification

    Structure: input -> [PointNetSA(MSG)]s -> [MLP]s -> MaxPooling -> [MLP]s -> [Linear] -> logits

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8)),
                 num_neighbours_list=((16, 32, 128), (32, 64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                                   ((64, 64, 128), (128, 128, 256), (128, 128, 256))),
                 local_channels=(256, 512, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Refer to PointNet2SSGCls. Major difference is that all the
        arguments will be a tuple of original types."""
        super(PointNet2MSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check: one entry per set abstraction layer
        num_layers = len(num_centroids)
        assert len(radius_list) == num_layers
        assert len(num_neighbours_list) == num_layers
        assert len(sa_channels_list) == num_layers

        # The first 3 input channels are not counted as features.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind, centroids in enumerate(num_centroids):
            sa_module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[ind],
                num_centroids=centroids,
                radius_list=radius_list[ind],
                num_neighbours_list=num_neighbours_list[ind],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            sa_in_channels = sa_module.out_channels

        # The local MLP input is 3 channels wider when xyz is used.
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)
        self.mlp_global = MLP(local_channels[-1], global_channels,
                              dropout=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Classify the point clouds stored under ``data_batch["points"]``.

        Returns:
            dict: ``cls_logit`` (classifier output) and ``key_point_inds``
            (argmax indices of the max-pooling over dimension 2).

        """
        point = data_batch["points"]
        num_channels = point.size(1)

        # torch.Tensor.narrow; share same memory
        xyz = point.narrow(1, 0, 3)
        feature = point.narrow(1, 3, num_channels - 3) if num_channels > 3 else None

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        if self.use_xyz:
            local_input = torch.cat([xyz, feature], dim=1)
        else:
            local_input = feature

        local_feature = self.mlp_local(local_input)
        global_feature, max_indices = torch.max(local_feature, 2)
        global_feature = self.mlp_global(global_feature)

        return {
            'cls_logit': self.classifier(global_feature),
            'key_point_inds': max_indices,
        }

    def init_weights(self):
        """Xavier-initialise all sub-modules and set the BN momentum."""
        for sa_module in self.sa_modules:
            sa_module.init_weights(xavier_uniform)
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
        set_bn(self, momentum=0.01)
class PointNetPartSeg(nn.Module):
    """PointNet for part segmentation

    References:
        https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True):
        """
        Args:
            in_channels (int): the number of input channels
            num_classes (int): the number of object classes; sizes the
                classification logits and the one-hot class embedding
            num_seg_classes (int): the number of output channels of the
                segmentation logits
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            cls_channels (tuple of int): the numbers of channels in classification mlp
            seg_channels (tuple of int): the numbers of channels in segmentation mlp
            dropout_prob_cls (float): the probability to dropout in classification mlp
            dropout_prob_seg (float): the probability to dropout in segmentation mlp
            with_transform (bool): whether to use TNet to transform features.

        """
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes

        # stem
        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # classification
        # Notice that we apply dropout to each classification mlp.
        self.mlp_cls = MLP(local_channels[-1], cls_channels, dropout=dropout_prob_cls)
        self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

        # part segmentation
        # Notice that the original repo concatenates global feature, one hot class embedding,
        # stem features and local features. However, the paper does not use last local feature.
        # Here, we follow the released repo.
        in_channels_seg = (local_channels[-1] + num_classes
                           + sum(stem_channels) + sum(local_channels))
        self.mlp_seg = SharedMLP(in_channels_seg, seg_channels[:-1],
                                 dropout=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Predict class and per-point part logits.

        Args:
            data_batch (dict): must hold ``"points"`` (its dimension 2 is
                used as the number of points) and ``"cls_label"`` (used to
                index rows of an identity matrix, i.e. integer class ids).

        Returns:
            dict: ``cls_logit``, ``seg_logit``, the stem transforms
            (``trans_input`` / ``trans_feature``) when the stem reports
            them, and ``key_point_inds``.

        """
        x = data_batch["points"]
        cls_label = data_batch["cls_label"]
        num_points = x.shape[2]
        end_points = {}

        # stem
        stem_feature, end_points_stem = self.stem(x)
        # Copy the transform matrices only when the stem reports them, so a
        # stem that omits them (e.g. one built without transforms) does not
        # raise KeyError here.
        for key in ("trans_input", "trans_feature"):
            if key in end_points_stem:
                end_points[key] = end_points_stem[key]
        stem_features = end_points_stem["stem_features"]

        # mlp for local features; keep every intermediate output
        local_features = []
        x = stem_feature
        for mlp in self.mlp_local:
            x = mlp(x)
            local_features.append(x)

        # max pool over points
        global_feature, max_indices = torch.max(x, 2)  # (batch_size, local_channels[-1])
        end_points['key_point_inds'] = max_indices

        # classification
        x = global_feature
        x = self.mlp_cls(x)
        cls_logit = self.cls_logit(x)

        # segmentation
        global_feature_expand = global_feature.unsqueeze(2).expand(-1, -1, num_points)
        with torch.no_grad():
            I = torch.eye(self.num_classes, dtype=global_feature.dtype,
                          device=global_feature.device)
            one_hot = I[cls_label]  # (batch_size, num_classes)
            one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points)

        x = torch.cat(stem_features + local_features
                      + [global_feature_expand, one_hot_expand], dim=1)
        x = self.mlp_seg(x)
        x = self.conv_seg(x)
        seg_logit = self.seg_logit(x)

        preds = {"cls_logit": cls_logit, "seg_logit": seg_logit}
        preds.update(end_points)

        return preds

    def init_weights(self):
        """Xavier-initialise all heads and set the BN momentum."""
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_cls.init_weights(xavier_uniform)
        self.mlp_seg.init_weights(xavier_uniform)
        self.conv_seg.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.cls_logit.weight)
        nn.init.zeros_(self.cls_logit.bias)
        nn.init.xavier_uniform_(self.seg_logit.weight)
        nn.init.zeros_(self.seg_logit.bias)
        # Set batch normalization to 0.01 as default
        set_bn(self, momentum=0.01)
class PointNet2MSGPartSeg(nn.Module):
    """PointNet++ part segmentation with multi-scale grouping

    Refer to PointNet2SSGPartSeg

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8)),
                 num_neighbours_list=((32, 64, 128), (64, 128)),
                 sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                                   ((128, 128, 256), (128, 196, 256))),
                 local_channels=(256, 512, 1024),
                 fp_local_channels=(256, 256),
                 fp_channels=((256, 128), (128, 128)),
                 num_fp_neighbours=(3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True):
        """Build set abstraction, feature propagation and segmentation layers.

        The per-layer tuples of the set abstraction stage must all have the
        length of ``num_centroids``; ``fp_channels`` and
        ``num_fp_neighbours`` must have that same length.

        """
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz

        # sanity check
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind, centroids in enumerate(num_centroids):
            sa_module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[ind],
                num_centroids=centroids,
                radius_list=radius_list[ind],
                num_neighbours_list=num_neighbours_list[ind],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            sa_in_channels = sa_module.out_channels

        # Local Set Abstraction Layer
        local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
        self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)

        # Channel widths of the saved intermediate features: the input
        # (concatenated with the one-hot class label) followed by the
        # output of every set abstraction module.
        input_channels = in_channels if use_xyz else in_channels - 3
        inter_channels = [input_channels + num_classes]
        inter_channels += [sa_module.out_channels for sa_module in self.sa_modules]

        # Local Feature Propagation Layer
        self.mlp_local_fp = SharedMLP(local_channels[-1] + inter_channels[-1],
                                      fp_local_channels, bn=True)

        # Feature Propagation Layers, walking the saved levels backwards
        self.fp_modules = nn.ModuleList()
        fp_in_channels = fp_local_channels[-1]
        for ind, layer_channels in enumerate(fp_channels):
            skip_channels = inter_channels[-2 - ind]
            self.fp_modules.append(
                PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                                 mlp_channels=layer_channels,
                                 num_neighbors=num_fp_neighbours[ind]))
            fp_in_channels = layer_channels[-1]

        # MLP
        self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1,
                                 dropout=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.init_weights()

    def forward(self, data_batch):
        """Predict per-point part logits.

        Args:
            data_batch (dict): must hold ``"points"`` (dimension 1 are the
                channels, dimension 2 the points) and ``"cls_label"`` (used
                to index rows of an identity matrix).

        Returns:
            dict: ``seg_logit`` only.

        """
        points = data_batch["points"]
        num_channels = points.size(1)
        num_points = points.size(2)

        xyz = points.narrow(1, 0, 3)
        feature = points.narrow(1, 3, num_channels - 3) if num_channels > 3 else None

        # Create one hot class label
        with torch.no_grad():
            cls_label = data_batch["cls_label"]
            identity = torch.eye(self.num_classes, dtype=points.dtype,
                                 device=points.device)
            one_hot_expand = identity[cls_label].unsqueeze(2).expand(-1, -1, num_points)

        # Save intermediate results; level 0 carries the one-hot label.
        input_feature = points if self.use_xyz else feature
        inter_xyz = [xyz]
        inter_feature = [torch.cat([input_feature, one_hot_expand], dim=1)]

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            inter_xyz.append(xyz)
            inter_feature.append(feature)

        # Local Set Abstraction Layer
        local_input = torch.cat([xyz, feature], dim=1) if self.use_xyz else feature
        local_feature = self.mlp_local(local_input)
        global_feature = torch.max(local_feature, 2)[0]

        # Local Feature Propagation Layer
        num_last_points = inter_xyz[-1].size(2)
        global_feature_expand = global_feature.unsqueeze(2).expand(-1, -1, num_last_points)
        key_feature = self.mlp_local_fp(
            torch.cat([global_feature_expand, inter_feature[-1]], dim=1))
        key_xyz = xyz

        # Feature Propagation Layers: each one queries the next coarser-to-
        # finer saved level and becomes the key of the following one.
        for fp_ind, fp_module in enumerate(self.fp_modules):
            level = -2 - fp_ind
            query_xyz = inter_xyz[level]
            key_feature = fp_module(query_xyz, key_xyz,
                                    inter_feature[level], key_feature)
            key_xyz = query_xyz

        # MLP
        seg_logit = self.seg_logit(self.mlp_seg(key_feature))

        return {"seg_logit": seg_logit}

    def init_weights(self):
        """Xavier-initialise all sub-modules and set the BN momentum."""
        for sa_module in self.sa_modules:
            sa_module.init_weights(xavier_uniform)
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_local_fp.init_weights(xavier_uniform)
        for fp_module in self.fp_modules:
            fp_module.init_weights(xavier_uniform)
        self.mlp_seg.init_weights(xavier_uniform)
        nn.init.xavier_uniform_(self.seg_logit.weight)
        nn.init.zeros_(self.seg_logit.bias)
        set_bn(self, momentum=0.01)
def __init__(self,
             in_channels,
             num_classes,
             num_seg_classes,
             num_centroids=(512, 128),
             radius_list=((0.1, 0.2, 0.4), (0.4, 0.8)),
             num_neighbours_list=((32, 64, 128), (64, 128)),
             sa_channels_list=(((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                               ((128, 128, 256), (128, 196, 256))),
             local_channels=(256, 512, 1024),
             fp_local_channels=(256, 256),
             fp_channels=((256, 128), (128, 128)),
             num_fp_neighbours=(3, 3),
             seg_channels=(128, ),
             dropout_prob=0.5,
             use_xyz=True):
    """Build set abstraction, feature propagation and segmentation layers.

    The per-layer tuples of the set abstraction stage must all have the
    length of ``num_centroids``; ``fp_channels`` and ``num_fp_neighbours``
    must have that same length.

    """
    super(PointNet2MSGPartSeg, self).__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.num_seg_classes = num_seg_classes
    self.use_xyz = use_xyz

    # sanity check
    num_sa_layers = len(num_centroids)
    num_fp_layers = len(fp_channels)
    assert len(radius_list) == num_sa_layers
    assert len(num_neighbours_list) == num_sa_layers
    assert len(sa_channels_list) == num_sa_layers
    assert num_sa_layers == num_fp_layers
    assert len(num_fp_neighbours) == num_fp_layers

    # Set Abstraction Layers
    sa_in_channels = in_channels - 3
    self.sa_modules = nn.ModuleList()
    for ind, centroids in enumerate(num_centroids):
        sa_module = PointNetSAModuleMSG(
            in_channels=sa_in_channels,
            mlp_channels_list=sa_channels_list[ind],
            num_centroids=centroids,
            radius_list=radius_list[ind],
            num_neighbours_list=num_neighbours_list[ind],
            use_xyz=use_xyz)
        self.sa_modules.append(sa_module)
        sa_in_channels = sa_module.out_channels

    # Local Set Abstraction Layer
    local_in_channels = sa_in_channels + 3 if use_xyz else sa_in_channels
    self.mlp_local = SharedMLP(local_in_channels, local_channels, bn=True)

    # Channel widths of the intermediate features: the input (concatenated
    # with the one-hot class label) followed by the output of every set
    # abstraction module.
    input_channels = in_channels if use_xyz else in_channels - 3
    inter_channels = [input_channels + num_classes]
    inter_channels += [sa_module.out_channels for sa_module in self.sa_modules]

    # Local Feature Propagation Layer
    self.mlp_local_fp = SharedMLP(local_channels[-1] + inter_channels[-1],
                                  fp_local_channels, bn=True)

    # Feature Propagation Layers, walking the intermediate levels backwards
    self.fp_modules = nn.ModuleList()
    fp_in_channels = fp_local_channels[-1]
    for ind, layer_channels in enumerate(fp_channels):
        skip_channels = inter_channels[-2 - ind]
        self.fp_modules.append(
            PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                             mlp_channels=layer_channels,
                             num_neighbors=num_fp_neighbours[ind]))
        fp_in_channels = layer_channels[-1]

    # MLP
    self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1,
                             dropout=dropout_prob)
    self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

    self.init_weights()
class TNet(nn.Module):
    """Transformation Network for DGCNN

    Structure: input -> [EdgeFeature] -> [EdgeConv]s -> [EdgePool] -> features -> [MLP] -> local features
    -> [MaxPool] -> global features -> [MLP] -> [Linear] -> logits

    Args:
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        k: the number of nearest neighbours for edge feature extractor

    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        """Build the edge convolution, local/global MLPs and output projection.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of rows of the predicted matrix
            conv_channels (tuple of int): the numbers of channels of edge
                convolution layers
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp
            k: the number of nearest neighbours for edge feature extractor

        """
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        self.edge_conv = SharedMLP(2 * in_channels, conv_channels, ndim=2)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)

        # One output per entry of the (out_channels, in_channels) matrix.
        self.linear = nn.Linear(global_channels[-1],
                                self.in_channels * out_channels, bias=True)

        self.init_weights()

    def forward(self, x):
        """TNet forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels)

        """
        x = get_edge_feature(x, self.k)  # (batch_size, 2 * in_channels, num_points, k)
        x = self.edge_conv(x)
        x, _ = torch.max(x, 3)  # (batch_size, edge_channels[-1], num_points)
        x = self.mlp_local(x)
        x, _ = torch.max(x, 2)  # (batch_size, local_channels[-1])
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        # Build the identity in the dtype of the prediction so the result
        # keeps that dtype for half/double inputs as well.
        I = torch.eye(self.out_channels, self.in_channels,
                      dtype=x.dtype, device=x.device)
        x = x.add(I)  # broadcast first dimension
        return x

    def init_weights(self):
        """Xavier-initialise the MLPs and zero the output projection."""
        self.edge_conv.init_weights(xavier_uniform)
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_global.init_weights(xavier_uniform)
        # Set linear transform be 0
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
class Stem(nn.Module):
    """Stem (main body or stalk). Turns raw point clouds into per-point features.

    Structure: input (-> [TNet] -> transform_input) -> [MLP] -> features (-> [TNet] -> transform_feature)

    Attributes:
        in_channels: the number of input channels
        out_channels: the number of output channels, i.e. stem_channels[-1]
        with_transform: whether to use TNet

    """

    def __init__(self, in_channels, stem_channels=(64, 64), with_transform=True, bn=True):
        super(Stem, self).__init__()

        self.with_transform = with_transform
        self.in_channels = in_channels
        self.out_channels = stem_channels[-1]

        # feature stem
        shared_mlp = SharedMLP(in_channels, stem_channels, bn=bn)
        shared_mlp.init_weights(xavier_uniform)
        self.mlp = shared_mlp

        if not with_transform:
            return

        # TNets are only created when transforms are requested:
        # one aligns the raw input, the other aligns the stem features.
        self.transform_input = TNet(in_channels, in_channels)
        self.transform_feature = TNet(stem_channels[-1], stem_channels[-1])

    def forward(self, x):
        """PointNet Stem forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, stem_channels[-1], num_points)
            dict (empty when with_transform is False):
                trans_input: (batch_size, in_channels, in_channels)
                trans_feature: (batch_size, stem_channels[-1], stem_channels[-1])

        """
        end_points = {}

        # Without TNets the stem is just the shared mlp.
        if not self.with_transform:
            return self.mlp(x), end_points

        # Align the input, extract features, then align the features.
        input_tf = self.transform_input(x)
        end_points["trans_input"] = input_tf
        features = self.mlp(torch.bmm(input_tf, x))

        feature_tf = self.transform_feature(features)
        end_points["trans_feature"] = feature_tf

        return torch.bmm(feature_tf, features), end_points
class PointNetCls(nn.Module):
    """PointNet for classification

    Structure: input -> [Stem] -> features -> [SharedMLP] -> local features
    -> [MaxPool] -> global features -> [MLP] -> [FC] -> [Linear] -> logits

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(64, 64),
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256),
                 dropout_prob=0.3,
                 with_transform=True):
        """

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of output channels
            stem_channels (tuple of int): the numbers of channels in stem feature extractor
            local_channels (tuple of int): the numbers of channels in local mlp
            global_channels (tuple of int): the numbers of channels in global mlp
            dropout_prob (float): the probability to dropout
            with_transform (bool): whether to use TNet to transform features.

        """
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels, dropout=dropout_prob)
        # Classification head: an extra FC layer followed by the logit layer.
        self.classifier = nn.Sequential(
            FC(global_channels[-1], global_channels[-1]),
            nn.Linear(global_channels[-1], out_channels, bias=True))

        self.init_weights()

    def forward(self, data_batch):
        """PointNetCls forward

        Args:
            data_batch (dict): must contain "points", the tensor fed to the stem
                (the stem documents it as (batch_size, in_channels, num_points))

        Returns:
            dict: "cls_logit" (classification logits), "key_point_inds"
                (indices of the points selected by the max pooling) and
                whatever the stem reports ("trans_input", "trans_feature"
                when transforms are enabled)

        """
        x = data_batch["points"]

        # stem
        x, end_points = self.stem(x)
        # mlp for local features
        x = self.mlp_local(x)
        # max pool over points
        x, max_indices = torch.max(x, 2)
        end_points["key_point_inds"] = max_indices
        # mlp for global features
        x = self.mlp_global(x)
        x = self.classifier(x)

        preds = {"cls_logit": x}
        preds.update(end_points)

        return preds

    def init_weights(self):
        """Apply xavier uniform initialization and set the batch norm momentum."""
        # Default initialization in original implementation
        self.mlp_local.init_weights(xavier_uniform)
        self.mlp_global.init_weights(xavier_uniform)
        # The head is an nn.Sequential, which has no weight/bias of its own,
        # so initialize every nn.Linear found inside it. This also covers a
        # head that is a single nn.Linear, and any nn.Linear nested in FC.
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        # Set batch normalization momentum to 0.01 as default
        set_bn(self, momentum=0.01)