class TNet(nn.Module): """Transformation Network for DGCNN Args: in_channels (int): the number of input channels out_channels (int): the number of output channels conv_channels (tuple of int): the numbers of channels of edge convolution layers local_channels (tuple of int): the numbers of channels in local mlp global_channels (tuple of int): the numbers of channels in global mlp k: the number of neareast neighbours for edge feature extractor """ def __init__(self, in_channels=3, out_channels=3, conv_channels=(64, 128), local_channels=(1024, ), global_channels=(512, 256), k=20): super(TNet, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k) self.mlp_local = SharedMLP(conv_channels[-1], local_channels) self.mlp_global = MLP(local_channels[-1], global_channels) self.linear = nn.Linear(global_channels[-1], self.in_channels * self.out_channels, bias=True) self.reset_parameters() def forward(self, x): # input x: (batch_size, in_channels, num_points) x = self.edge_conv(x) # (batch_size, edge_channels[-1], num_points) x = self.mlp_local(x) # (batch_size, local_channels[-1], num_points) x, _ = torch.max(x, 2) x = self.mlp_global(x) x = self.linear(x) x = x.view(-1, self.out_channels, self.in_channels) I = torch.eye(self.out_channels, self.in_channels, device=x.device) x = x.add(I) # broadcast first dimension return x def reset_parameters(self, init_fn=xavier_uniform): self.edge_conv.reset_parameters(init_fn) self.mlp_local.reset_parameters(init_fn) self.mlp_global.reset_parameters(init_fn) # set linear transform be 0 nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias)
class TNet(nn.Module): """Transformation Network. The structure is similar with PointNet""" def __init__(self, in_channels=3, out_channels=3, local_channels=(64, 128, 1024), global_channels=(512, 256)): super(TNet, self).__init__() self.in_channels = in_channels self.out_channels = out_channels # local features self.mlp_local = SharedMLP(in_channels, local_channels) # global features self.mlp_global = MLP(local_channels[-1], global_channels) # linear output self.linear = nn.Linear(global_channels[-1], in_channels * out_channels, bias=True) self.reset_parameters() def forward(self, x): # x: (batch_size, in_channels, num_points) x = self.mlp_local(x) # (batch_size, local_channels[-1], num_points) x, _ = torch.max(x, 2) # (batch_size, local_channels[-1]) x = self.mlp_global(x) x = self.linear(x) x = x.view(-1, self.out_channels, self.in_channels) I = torch.eye(self.out_channels, self.in_channels, dtype=x.dtype, device=x.device) x = x.add(I) # broadcast add, (batch_size, out_channels, in_channels) return x def reset_parameters(self, init_fn=xavier_uniform): # Default initialization in original implementation self.mlp_local.reset_parameters(init_fn) self.mlp_global.reset_parameters(init_fn) # Initialize linear transform to 0 nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias)
class PointNet2SSGCls(nn.Module): """PointNet2 with single-scale grouping for classification' Args: in_channels (int): the number of input channels out_channels (int): the number of semantics classes to predict over num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module num_neighbours (tuple of int): the numbers of neighbours to query for each centroid sa_channels (tuple of tuple of int): the numbers of channels to within each set abstraction module global_channels (tuple of int): the numbers of channels to extract global features dropout_prob (float): the probability to dropout input features use_xyz (bool): whether or not to use the xyz position of a points as a feature Notes: 1. num_centroids == -1: use all points; num_centroids == 0: use the origin. 2. radius * num_neighbours > 0. """ def __init__(self, in_channels, out_channels, num_centroids=(512, 128, 0), radius=(0.2, 0.4, -1.0), num_neighbours=(32, 64, -1), sa_channels=((64, 64, 128), (128, 128, 256), (256, 512, 1024)), global_channels=(512, 256), dropout_prob=0.5, use_xyz=True): super(PointNet2SSGCls, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.use_xyz = use_xyz # sanity check num_sa_layers = len(num_centroids) assert len(radius) == num_sa_layers assert len(num_neighbours) == num_sa_layers assert len(sa_channels) == num_sa_layers feature_channels = in_channels - 3 self.sa_modules = nn.ModuleList() for ind in range(num_sa_layers): sa_module = PointNetSAModule(in_channels=feature_channels, mlp_channels=sa_channels[ind], num_centroids=num_centroids[ind], radius=radius[ind], num_neighbours=num_neighbours[ind], use_xyz=use_xyz) self.sa_modules.append(sa_module) feature_channels = sa_channels[ind][-1] self.mlp_global = MLP(feature_channels, global_channels, dropout_prob=dropout_prob) self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True) self.reset_parameters() def forward(self, data_batch): points = data_batch['points'] end_points = {} # torch.Tensor.narrow; share same memory xyz = points.narrow(1, 0, 3) # equivalent to points[:, 0:3, :] if points.size(1) > 3: feature = points.narrow(1, 3, points.size(1) - 3) else: feature = None for sa_module in self.sa_modules: xyz, feature = sa_module(xyz, feature) x, max_indices = torch.max(feature, 2) end_points['key_point_indices'] = max_indices x = self.mlp_global(x) cls_logit = self.classifier(x) preds = {'cls_logit': cls_logit} preds.update(end_points) return preds def reset_parameters(self): for sa_module in self.sa_modules: sa_module.reset_parameters(xavier_uniform) self.mlp_global.reset_parameters(xavier_uniform) xavier_uniform(self.classifier) set_bn(self, momentum=0.01)
class PointNetCls(nn.Module): """PointNet classification model Args: in_channels (int): the number of input channels out_channels (int): the number of output channels stem_channels (tuple of int): the numbers of channels in stem feature extractor local_channels (tuple of int): the numbers of channels in local mlp global_channels (tuple of int): the numbers of channels in global mlp dropout_prob (float): the probability to dropout with_transform (bool): whether to use TNet to transform features. """ def __init__(self, in_channels, out_channels, stem_channels=(16, 32), local_channels=(32, 32), global_channels=(32, 32), dropout_prob=0.5, with_transform=False): super(PointNetCls, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.stem = Stem(in_channels, stem_channels, with_transform=with_transform) self.mlp_local = SharedMLP(stem_channels[-1], local_channels) self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob) self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True) self.classifier2 = nn.Linear(global_channels[-1], 2, bias=True) self.reset_parameters() def forward(self, points): #x = data_batch['points'] #x = data_batch['box_points'] x = points # stem x, end_points = self.stem(x) # mlp for local features x = self.mlp_local(x) # max pool over points x, max_indices = torch.max(x, 2) end_points['key_point_indices'] = max_indices # mlp for global features x = self.mlp_global(x) x1 = self.classifier(x) x2 = self.classifier2(x) preds = { 'cls_logit': x1, 'node_logit': x2, } preds.update(end_points) return preds def reset_parameters(self): # default initialization in original implementation self.mlp_local.reset_parameters(xavier_uniform) self.mlp_global.reset_parameters(xavier_uniform) xavier_uniform(self.classifier) xavier_uniform(self.classifier2) # set batch normalization to 0.01 as default set_bn(self, momentum=0.01)
class DGCNNCls(nn.Module): """DGCNN for classification Args: in_channels (int): the number of input channels out_channels (int): the number of output channels edge_conv_channels (tuple of int): the numbers of channels of edge convolution layers inter_channels (int): the number of channels of intermediate features global_channels (tuple of int): the numbers of channels in global mlp k (int): the number of neareast neighbours for edge feature extractor dropout_prob (float): the probability to dropout with_transform (bool): whether to use TNet to transform features. """ def __init__(self, in_channels, out_channels, edge_conv_channels=(64, 64, 64, 128), local_channels=(1024, ), global_channels=(512, 256), k=20, dropout_prob=0.5, with_transform=True): super(DGCNNCls, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.with_transform = with_transform # input transform if self.with_transform: self.transform_input = TNet(in_channels, in_channels, k=k) self.edge_convs = nn.ModuleList() inter_channels = [] for conv_channels in edge_conv_channels: if isinstance(conv_channels, int): conv_channels = [conv_channels] else: assert isinstance(conv_channels, (tuple, list)) self.edge_convs.append(EdgeConvBlock(in_channels, conv_channels, k)) inter_channels.append(conv_channels[-1]) in_channels = conv_channels[-1] self.mlp_local = SharedMLP(sum(inter_channels), local_channels) self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob) self.classifier = nn.Linear(global_channels[-1], self.out_channels, bias=True) self.reset_parameters() def forward(self, data_batch): x = data_batch['points'] end_points = {} # input transform if self.with_transform: trans_input = self.transform_input(x) x = torch.bmm(trans_input, x) end_points['trans_input'] = trans_input # EdgeConv features = [] for edge_conv in self.edge_convs: x = edge_conv(x) features.append(x) x = torch.cat(features, dim=1) x = self.mlp_local(x) x, max_indices = torch.max(x, 2) end_points['key_point_indices'] = max_indices x = self.mlp_global(x) x = self.classifier(x) preds = { 'cls_logit': x, } preds.update(end_points) return preds def reset_parameters(self): for edge_conv in self.edge_convs: edge_conv.reset_parameters(xavier_uniform) self.mlp_local.reset_parameters(xavier_uniform) self.mlp_global.reset_parameters(xavier_uniform) xavier_uniform(self.classifier) set_bn(self, momentum=0.01)
class PointNetCls(nn.Module): """PointNet classification model Args: in_channels (int): the number of input channels out_channels (int): the number of output channels stem_channels (tuple of int): the numbers of channels in stem feature extractor local_channels (tuple of int): the numbers of channels in local mlp global_channels (tuple of int): the numbers of channels in global mlp dropout_prob (float): the probability to dropout with_transform (bool): whether to backboneuse TNet to transform features. """ def __init__(self, in_channels, out_channels, stem_channels=(128, 128), local_channels=(128, 256, 256), global_channels=(128, 128), dropout_prob=0.3, with_transform=False): super(PointNetCls, self).__init__() self.in_channels = in_channels self.out_channels = out_channels p1 = 128 stem_channels1 = ((p1, p1, int(2 * p1)), (int(2 * p1), int(2 * p1), int(2 * p1))) p2 = int(p1 / 4) stem_channels2 = ((p2, p2, int(2 * p2)), (int(2 * p2), int(2 * p2), int(2 * p2))) p3 = int(p1 / 8) stem_channels3 = ((p3, p3, int(2 * p3)), (int(2 * p3), int(2 * p3), int(2 * p3))) global_channels1 = (p1 * 2, p1) global_channels2 = (p2 * 2, p2) global_channels3 = (p3 * 2, p3) self.p1 = PointNet2SSGCls(sa_channels=stem_channels1, global_channels=global_channels1) self.classifier1 = nn.Linear(global_channels[-1], out_channels, bias=True) self.p2 = PointNet2SSGCls(sa_channels=stem_channels1, global_channels=global_channels1) self.classifier2 = nn.Linear(global_channels[-1], out_channels, bias=True) self.p3 = PointNet2SSGCls(sa_channels=stem_channels2, global_channels=global_channels2) self.classifier3 = nn.Linear(int(global_channels[-1] / 4), int(global_channels[-1] / 4), bias=True) self.mlp_global22 = MLP(int(global_channels[-1] / 4) * 2, tuple(int(e / 4) for e in global_channels), dropout_prob=dropout_prob) self.classifier22 = nn.Linear(int(global_channels[-1] / 4), 1, bias=True) self.p4 = PointNet2SSGCls(sa_channels=stem_channels1, global_channels=global_channels1) self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True) self.stem8 = Stem(int(out_channels * 1) + 3, stem_channels, with_transform=with_transform) self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels) self.mlp_global8 = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob) self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True) self.stem9 = Stem(int(out_channels * 2) + 3, stem_channels, with_transform=with_transform) self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels) self.mlp_global9 = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob) self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True) self.sigmoid = nn.Sigmoid() self.reset_parameters() def forward(self, x, infer_type, y=None): if infer_type == 'backbone': x = self.p1(x) x = self.classifier1(x) elif infer_type == 'backbone2': x = self.p2(x) x = self.classifier2(x) elif infer_type == 'policy': x = self.p3(x) x = self.classifier3(x) elif infer_type == 'policy_head': x = self.mlp_global22(x) x = self.classifier22(x) elif infer_type == 'purity': x = self.p4(x) x = self.classifier4(x) x = self.sigmoid(x) elif infer_type == 'head': x, end_points = self.stem8(x) x = self.mlp_local8(x) x, max_indices = torch.max(x, 2) end_points['key_point_indices'] = max_indices x = self.mlp_global8(x) x = self.classifier8(x) elif infer_type == 'head2': x, end_points = self.stem9(x) x = self.mlp_local9(x) x, max_indices = torch.max(x, 2) end_points['key_point_indices'] = max_indices x = self.mlp_global9(x) x = self.classifier9(x) else: raise 
NameError return x def reset_parameters(self): # default initialization in original implementation self.p1.reset_parameters() self.p2.reset_parameters() self.p3.reset_parameters() self.p4.reset_parameters() xavier_uniform(self.classifier1) xavier_uniform(self.classifier2) xavier_uniform(self.classifier3) xavier_uniform(self.classifier22) xavier_uniform(self.classifier4) xavier_uniform(self.classifier8) xavier_uniform(self.classifier9) self.mlp_local8.reset_parameters(xavier_uniform) self.mlp_global8.reset_parameters(xavier_uniform) self.mlp_local9.reset_parameters(xavier_uniform) self.mlp_global9.reset_parameters(xavier_uniform) self.mlp_global22.reset_parameters(xavier_uniform) set_bn(self, momentum=0.01)
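# Dispatch sketch (an illustrative addition, not original code): forward() routes on
# infer_type, so one module bundles several sub-networks. The inputs here are
# caller-supplied placeholders; the backbone branches expect the data_batch format
# documented in PointNet2SSGCls, while 'policy_head' operates on precomputed features.
def _demo_infer_type(model, data_batch, pooled_features):
    purity = model(data_batch, 'purity')  # scalar per sample, squashed into (0, 1)
    policy = model(pooled_features, 'policy_head')  # skips the point backbone entirely
    return purity, policy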
class PointNetPartSeg(nn.Module): """PointNet for part segmentation Args: in_channels (int): the number of input channels num_classes (int): the number of classification class num_seg_classes (int): the number of segmentation class stem_channels (tuple of int): the numbers of channels in stem feature extractor local_channels (tuple of int): the numbers of channels in local mlp cls_channels (tuple of int): the numbers of channels in classification mlp seg_channels (tuple of int): the numbers of channels in segmentation mlp dropout_prob_cls (float): the probability to dropout in classification mlp dropout_prob_seg (float): the probability to dropout in segmentation mlp with_transform (bool): whether to use TNet to transform features. use_one_hot (bool): whehter to use one hot vector of class labels. References: https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py """ def __init__(self, in_channels, num_classes, num_seg_classes, stem_channels=(64, 128, 128), local_channels=(512, 2048), cls_channels=(256, 256), seg_channels=(256, 256, 128), dropout_prob_cls=0.3, dropout_prob_seg=0.2, with_transform=True, use_one_hot=True): super(PointNetPartSeg, self).__init__() self.in_channels = in_channels self.num_classes = num_classes self.num_seg_classes = num_seg_classes self.use_one_hot = use_one_hot # stem self.stem = Stem(in_channels, stem_channels, with_transform=with_transform) self.mlp_local = SharedMLP(stem_channels[-1], local_channels) # part segmentation # Notice that the original repo concatenates global feature, one hot class embedding, # stem features and local features. However, the paper does not use last local feature. # Here, we follow the released repo. in_channels_seg = sum(stem_channels) + sum( local_channels) + local_channels[-1] if self.use_one_hot: in_channels_seg += num_classes self.mlp_seg = SharedMLP(in_channels_seg, seg_channels[:-1], dropout_prob=dropout_prob_seg) self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1) self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True) # classification (optional) if len(cls_channels) > 0: # Notice that we apply dropout to each classification mlp. 
self.mlp_cls = MLP(local_channels[-1], cls_channels, dropout_prob=dropout_prob_cls) self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True) else: self.mlp_cls = None self.cls_logit = None self.reset_parameters() def forward(self, data_batch): x = data_batch['points'] num_points = x.shape[2] end_points = {} # stem stem_feature, end_points_stem = self.stem(x) stem_features = end_points_stem.pop('stem_features') end_points.update(end_points_stem) # mlp for local features local_features = [] x = stem_feature for ind, mlp in enumerate(self.mlp_local): x = mlp(x) local_features.append(x) # max pool over points global_feature, max_indices = torch.max( x, 2) # (batch_size, local_channels[-1]) # end_points['key_point_indices'] = max_indices # segmentation global_feature_expand = global_feature.unsqueeze(2).expand( -1, -1, num_points) seg_features = stem_features + local_features + [global_feature_expand] if self.use_one_hot: with torch.no_grad(): cls_label = data_batch['cls_label'] one_hot = cls_label.new_zeros(cls_label.size(0), self.num_classes) one_hot = one_hot.scatter( 1, cls_label.unsqueeze(1), 1).float() # (batch_size, num_classes) one_hot_expand = one_hot.unsqueeze(2).expand( -1, -1, num_points) seg_features.append(one_hot_expand) x = torch.cat(seg_features, dim=1) x = self.mlp_seg(x) x = self.conv_seg(x) seg_logit = self.seg_logit(x) preds = {'seg_logit': seg_logit} preds.update(end_points) # classification (optional) if self.cls_logit is not None: x = self.mlp_cls(global_feature) cls_logit = self.cls_logit(x) preds['cls_logit'] = cls_logit return preds def reset_parameters(self): # default initialization self.mlp_local.reset_parameters(xavier_uniform) self.mlp_seg.reset_parameters(xavier_uniform) self.conv_seg.reset_parameters(xavier_uniform) if self.cls_logit is not None: self.mlp_cls.reset_parameters(xavier_uniform) xavier_uniform(self.cls_logit) xavier_uniform(self.seg_logit) # set batch normalization to 0.01 as default set_bn(self, momentum=0.01)
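# Usage sketch (added; assumes Stem, SharedMLP, MLP, Conv1d, xavier_uniform and set_bn
# from this repo are in scope): with use_one_hot=True the batch must carry an integer
# 'cls_label', which is embedded as a one-hot vector and broadcast over all points.
def _demo_part_seg():
    model = PointNetPartSeg(in_channels=3, num_classes=16, num_seg_classes=50)
    data_batch = {
        'points': torch.randn(2, 3, 2048),  # (batch_size, in_channels, num_points)
        'cls_label': torch.randint(0, 16, (2,)),  # one object class index per sample
    }
    preds = model(data_batch)
    return preds['seg_logit']  # (batch_size, num_seg_classes, num_points)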
class PointNet2MSGCls(nn.Module): """PointNet2 with multi-scale grouping for classification""" def __init__(self, in_channels, out_channels, num_centroids=(512, 128, 0), radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8), -1.0), num_neighbours_list=((16, 32, 128), (32, 64, 128), -1), sa_channels_list=( ((32, 32, 64), (64, 64, 128), (64, 96, 128)), ((64, 64, 128), (128, 128, 256), (128, 128, 256)), (256, 512, 1024), ), global_channels=(512, 256), dropout_prob=0.5, use_xyz=True): super(PointNet2MSGCls, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.use_xyz = use_xyz # sanity check num_sa_layers = len(num_centroids) assert len(radius_list) == num_sa_layers assert len(num_neighbours_list) == num_sa_layers assert len(sa_channels_list) == num_sa_layers feature_channels = in_channels - 3 self.sa_modules = nn.ModuleList() for ind in range(num_sa_layers - 1): sa_module = PointNetSAModuleMSG( in_channels=feature_channels, mlp_channels_list=sa_channels_list[ind], num_centroids=num_centroids[ind], radius_list=radius_list[ind], num_neighbours_list=num_neighbours_list[ind], use_xyz=use_xyz) self.sa_modules.append(sa_module) feature_channels = sa_module.out_channels sa_module = PointNetSAModule(in_channels=feature_channels, mlp_channels=sa_channels_list[-1], num_centroids=num_centroids[-1], radius=radius_list[-1], num_neighbours=num_neighbours_list[-1], use_xyz=use_xyz) self.sa_modules.append(sa_module) self.mlp_global = MLP(sa_channels_list[-1][-1], global_channels, dropout_prob=dropout_prob) self.classifier = nn.Linear(global_channels[-1], out_channels, bias=True) self.reset_parameters() def forward(self, data_batch): point = data_batch['points'] end_points = {} # torch.Tensor.narrow; share same memory xyz = point.narrow(1, 0, 3) if point.size(1) > 3: feature = point.narrow(1, 3, point.size(1) - 3) else: feature = None for sa_module in self.sa_modules: xyz, feature = sa_module(xyz, feature) x, max_indices = torch.max(feature, 2) end_points['key_point_indices'] = max_indices x = self.mlp_global(x) cls_logit = self.classifier(x) preds = {'cls_logit': cls_logit} preds.update(end_points) return preds def reset_parameters(self): for sa_module in self.sa_modules: sa_module.reset_parameters(xavier_uniform) self.mlp_global.reset_parameters(xavier_uniform) xavier_uniform(self.classifier) set_bn(self, momentum=0.01)