def __init__(self,
             in_channels,
             stem_channels=(16, 32, 32),
             local_channels=(128, 128),
             seg_channels=(64, 64, 32),
             dropout_prob_seg=0.2):
    """Assemble the stem, the local MLP and the segmentation layers.

    Args:
        in_channels (int): the number of input channels
        stem_channels (tuple of int): the numbers of channels passed to ``Stem``
        local_channels (tuple of int): the numbers of channels in local mlp
        seg_channels (tuple of int): all entries but the last are passed to
            the segmentation mlp; the last two are the input/output widths
            of ``conv_seg``
        dropout_prob_seg (float): ``dropout_prob`` forwarded to the
            segmentation mlp
    """
    super(PointNet, self).__init__()
    self.in_channels = in_channels

    # Feature extractors. The stem is always built without transforms here.
    self.stem = Stem(in_channels, stem_channels, with_transform=False)
    self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

    # Segmentation branch.
    # The input width counts every stem layer, every local layer, and the
    # last local width once more. The original note says this follows the
    # released repo (which also concatenates the last local feature) rather
    # than the paper.
    per_layer_width = sum(stem_channels) + sum(local_channels)
    seg_in_channels = per_layer_width + local_channels[-1]
    seg_hidden_channels = seg_channels[:-1]
    self.mlp_seg = SharedMLP(seg_in_channels,
                             seg_hidden_channels,
                             dropout_prob=dropout_prob_seg)
    self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)

    self.reset_parameters()
class ConcatHead(nn.Module):
    """Two-channel per-point head applied to (optionally concatenated) features.

    Args:
        in_channels (int): the number of channels of the (concatenated) input
        dropout_prob (float): ``dropout_prob`` forwarded to the shared mlp
    """

    def __init__(self, in_channels, dropout_prob=0.5):
        super(ConcatHead, self).__init__()
        hidden_channels = (in_channels, )
        self.mlp_local = SharedMLP(in_channels, hidden_channels, dropout_prob=dropout_prob)
        # Output layer: 2 channels, built with relu and bn switched off.
        self.conv1d = Conv1d(in_channels, 2, 1, relu=False, bn=False)
        self.reset_parameters()

    def forward(self, concat_feats):
        """Compute the two-channel logits.

        Args:
            concat_feats (torch.Tensor or list/tuple of torch.Tensor): a list
                or tuple is concatenated along dim 1 first; per the original
                comments the result is (batch_size, in_channel, num_points)

        Returns:
            torch.Tensor: ins_logit, (batch_size, 2, num_points) per the
            original comments
        """
        needs_concat = isinstance(concat_feats, (list, tuple))
        feats = torch.cat(concat_feats, dim=1) if needs_concat else concat_feats
        return self.conv1d(self.mlp_local(feats))

    def reset_parameters(self):
        """Re-initialise both layers with xavier_uniform and set BN momentum to 0.01."""
        for layer in (self.mlp_local, self.conv1d):
            layer.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
def __init__(self, in_channels, mlp_channels, num_centroids, radius, num_neighbours, use_xyz):
    """Build the shared MLP plus the optional sampler and grouper.

    Args:
        in_channels (int): the number of input feature channels
        mlp_channels (tuple of int): the numbers of channels of the shared mlp
        num_centroids (int): a positive value creates a
            ``FarthestPointSampler``; zero or negative leaves ``sampler`` as None
        radius (float): radius handed to ``QueryGrouper``; must be negative
            when ``num_neighbours`` is negative
        num_neighbours (int): a negative value disables grouping
            (``grouper`` is None); otherwise it must be positive
        use_xyz (bool): when true, 3 extra channels are added to the mlp input
    """
    super(PointNetSAModule, self).__init__()

    self.in_channels = in_channels
    self.out_channels = mlp_channels[-1]
    self.num_centroids = num_centroids
    self.use_xyz = use_xyz

    # The recorded in_channels excludes the 3 xyz channels; only the mlp sees them.
    mlp_in_channels = in_channels + 3 if self.use_xyz else in_channels
    self.mlp = SharedMLP(mlp_in_channels, mlp_channels, ndim=2, bn=True)

    self.sampler = FarthestPointSampler(num_centroids) if num_centroids > 0 else None

    if num_neighbours >= 0:
        assert num_neighbours > 0 and radius > 0.0
        self.grouper = QueryGrouper(radius, num_neighbours)
    else:
        # Grouping disabled: the radius is expected to be disabled as well.
        assert radius < 0.0
        self.grouper = None
def __init__(self,
             in_channels,
             out_channels,
             stem_channels=(16, 32),
             local_channels=(32, 32),
             global_channels=(32, 32),
             dropout_prob=0.5,
             with_transform=False):
    """Build the stem, local/global MLPs and the two linear heads.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of outputs of ``classifier``
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        dropout_prob (float): ``dropout_prob`` forwarded to the global mlp
        with_transform (bool): forwarded to ``Stem``
    """
    super(PointNetCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
    self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob)

    # Both heads read the same global feature; the second one always has 2 outputs.
    head_in_channels = global_channels[-1]
    self.classifier = nn.Linear(head_in_channels, out_channels, bias=True)
    self.classifier2 = nn.Linear(head_in_channels, 2, bias=True)

    self.reset_parameters()
class PointnetFPModule(nn.Module):
    """PointNet feature propagation module.

    Args:
        in_channels (int): the number of channels fed to the shared mlp
        mlp_channels (tuple of int): the numbers of channels of the shared mlp
        num_neighbors (int): 0 to broadcast a single sparse feature to every
            dense point, or 3 to use ``FeatureInterpolator``

    Raises:
        ValueError: if ``num_neighbors`` is neither 0 nor 3.
    """

    def __init__(self, in_channels, mlp_channels, num_neighbors):
        super(PointnetFPModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]

        # Validate before building any sub-module so a bad configuration
        # fails fast. The accepted values are 0 and 3 (the message used to
        # say "1 or 3", which did not match the check).
        if num_neighbors not in (0, 3):
            raise ValueError('Expected value 0 or 3, but {} given.'.format(num_neighbors))

        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)

        if num_neighbors == 0:
            self.interpolator = None
        else:
            self.interpolator = FeatureInterpolator(num_neighbors)

    def forward(self, dense_xyz, sparse_xyz, dense_feature, sparse_feature):
        """Propagate sparse features to the dense points and apply the mlp.

        Without an interpolator the sparse input must hold exactly one point
        (size 1 along dim 2); its feature is expanded to ``dense_xyz.size(2)``
        points and concatenated with ``dense_feature`` along dim 1. Otherwise
        the four inputs are handed to the interpolator as they are.

        Returns:
            torch.Tensor: the output of the shared mlp
        """
        if self.interpolator is None:
            assert sparse_xyz.size(2) == 1 and sparse_feature.size(2) == 1
            sparse_feature_expand = sparse_feature.expand(-1, -1, dense_xyz.size(2))
            new_feature = torch.cat([sparse_feature_expand, dense_feature], dim=1)
        else:
            new_feature = self.interpolator(dense_xyz, sparse_xyz, dense_feature, sparse_feature)

        new_feature = self.mlp(new_feature)
        return new_feature

    def reset_parameters(self, init_fn=None):
        """Forward ``init_fn`` to the shared mlp's ``reset_parameters``."""
        self.mlp.reset_parameters(init_fn)
def __init__(self,
             in_channels,
             num_centroids=(128, 32, 0),
             radius=(0.2, 0.4, -1.0),
             num_neighbours=(64, 64, -1),
             sa_channels=((16, 16, 32), (32, 32, 64), (128, 128, 256)),
             fp_channels=((64, 64), (64, 32), (32, 32, 32)),
             num_fp_neighbours=(0, 3, 3),
             seg_channels=(32, ),
             dropout_prob=0.5,
             use_xyz=True):
    """Build the set abstraction (SA) stack, the feature propagation (FP) stack and the final mlp.

    Args:
        in_channels (int): the number of input channels
        num_centroids (tuple of int): ``num_centroids`` of each SA module
        radius (tuple of float): ``radius`` of each SA module
        num_neighbours (tuple of int): ``num_neighbours`` of each SA module
        sa_channels (tuple of tuple of int): mlp channels of each SA module
        fp_channels (tuple of tuple of int): mlp channels of each FP module
        num_fp_neighbours (tuple of int): ``num_neighbors`` of each FP module
        seg_channels (tuple of int): the numbers of channels in the final mlp
        dropout_prob (float): ``dropout_prob`` forwarded to the final mlp
        use_xyz (bool): forwarded to every SA module
    """
    super(PointNet2SSG, self).__init__()
    self.in_channels = in_channels
    self.use_xyz = use_xyz

    # Every per-layer tuple must describe the same number of layers.
    num_sa_layers = len(num_centroids)
    num_fp_layers = len(fp_channels)
    assert len(radius) == num_sa_layers
    assert len(num_neighbours) == num_sa_layers
    assert len(sa_channels) == num_sa_layers
    assert num_sa_layers == num_fp_layers
    assert len(num_fp_neighbours) == num_fp_layers

    # Set Abstraction Layers. The first one sees the input minus 3 channels.
    sa_in_channels = in_channels - 3
    self.sa_modules = nn.ModuleList()
    for layer in range(num_sa_layers):
        layer_channels = sa_channels[layer]
        self.sa_modules.append(
            PointNetSAModule(in_channels=sa_in_channels,
                             mlp_channels=layer_channels,
                             num_centroids=num_centroids[layer],
                             radius=radius[layer],
                             num_neighbours=num_neighbours[layer],
                             use_xyz=use_xyz))
        sa_in_channels = layer_channels[-1]

    # Channel widths per level: the input level first, then one per SA module.
    level_channels = [in_channels if use_xyz else in_channels - 3]
    level_channels += [channels[-1] for channels in sa_channels]

    # Feature Propagation Layers, walking the levels from the last one back.
    self.fp_modules = nn.ModuleList()
    fp_in_channels = level_channels[-1]
    for layer in range(num_fp_layers):
        skip_channels = level_channels[-2 - layer]
        self.fp_modules.append(
            PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                             mlp_channels=fp_channels[layer],
                             num_neighbors=num_fp_neighbours[layer]))
        fp_in_channels = fp_channels[layer][-1]

    # MLP
    self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1, dropout_prob=dropout_prob)

    self.reset_parameters()
def __init__(self, in_channels, dropout_prob=0.5):
    """Build the shared mlp and the 2-channel output layer.

    Args:
        in_channels (int): the number of input channels; also the width of
            the single shared-mlp layer
        dropout_prob (float): ``dropout_prob`` forwarded to the shared mlp
    """
    super(ConcatHead, self).__init__()

    hidden_channels = (in_channels, )
    self.mlp_local = SharedMLP(in_channels, hidden_channels, dropout_prob=dropout_prob)
    # Output layer: 2 channels, built with relu and bn switched off.
    self.conv1d = Conv1d(in_channels, 2, 1, relu=False, bn=False)

    self.reset_parameters()
class Stem(nn.Module):
    """Stem (main body or stalk). Extract features from raw point clouds.

    Args:
        in_channels (int): the number of input channels
        stem_channels (tuple of int): the numbers of channels of the stem mlp
        with_transform (bool): whether to build and apply the input/feature TNets
    """

    def __init__(self, in_channels, stem_channels=(64, 128, 128), with_transform=True):
        super(Stem, self).__init__()

        self.in_channels = in_channels
        self.out_channels = stem_channels[-1]
        self.with_transform = with_transform

        # Feature stem, initialised right away with xavier_uniform.
        self.mlp = SharedMLP(in_channels, stem_channels)
        self.mlp.reset_parameters(xavier_uniform)

        if not self.with_transform:
            return
        # Input transform, then feature transform.
        self.transform_input = TNet(in_channels, in_channels)
        self.transform_feature = TNet(self.out_channels, self.out_channels)

    def forward(self, x):
        """PointNet Stem forward

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, stem_channels[-1], num_points)
            dict:
                trans_input (only with transform): (batch_size, in_channels, in_channels)
                stem_features (list of torch.Tensor): output of every stem layer
                trans_feature (only with transform): (batch_size, stem_channels[-1], stem_channels[-1])

        """
        end_points = {}
        use_transform = self.with_transform

        if use_transform:
            input_matrix = self.transform_input(x)
            x = torch.bmm(input_matrix, x)
            end_points['trans_input'] = input_matrix

        # Run the stem layer by layer, keeping every intermediate output.
        stem_features = []
        for layer in self.mlp:
            x = layer(x)
            stem_features.append(x)
        end_points['stem_features'] = stem_features

        if use_transform:
            feature_matrix = self.transform_feature(x)
            x = torch.bmm(feature_matrix, x)
            end_points['trans_feature'] = feature_matrix

        return x, end_points
def __init__(self,
             in_channels,
             num_classes,
             num_seg_classes,
             stem_channels=(64, 128, 128),
             local_channels=(512, 2048),
             cls_channels=(256, 256),
             seg_channels=(256, 256, 128),
             dropout_prob_cls=0.3,
             dropout_prob_seg=0.2,
             with_transform=True,
             use_one_hot=True):
    """Build the stem, the segmentation branch and the optional classification branch.

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        stem_channels (tuple of int): the numbers of channels passed to ``Stem``
        local_channels (tuple of int): the numbers of channels in local mlp
        cls_channels (tuple of int): the numbers of channels in the
            classification mlp; an empty value disables that branch
        seg_channels (tuple of int): all entries but the last go to the
            segmentation mlp; the last two size ``conv_seg``
        dropout_prob_cls (float): ``dropout_prob`` of the classification mlp
        dropout_prob_seg (float): ``dropout_prob`` of the segmentation mlp
        with_transform (bool): forwarded to ``Stem``
        use_one_hot (bool): when true, ``num_classes`` extra channels are
            added to the segmentation mlp input
    """
    super(PointNetPartSeg, self).__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.num_seg_classes = num_seg_classes
    self.use_one_hot = use_one_hot

    # Feature extractors.
    self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
    self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

    # Part segmentation.
    # The input width counts every stem layer, every local layer and the
    # last local width once more, plus the one-hot channels when enabled.
    # The original note says this follows the released repo (which also
    # concatenates the last local feature) rather than the paper.
    seg_in_channels = sum(stem_channels) + sum(local_channels) + local_channels[-1]
    if self.use_one_hot:
        seg_in_channels = seg_in_channels + num_classes
    self.mlp_seg = SharedMLP(seg_in_channels, seg_channels[:-1], dropout_prob=dropout_prob_seg)
    self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
    self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

    # Classification (optional).
    if len(cls_channels) == 0:
        self.mlp_cls = None
        self.cls_logit = None
    else:
        # Dropout is passed to the classification mlp as a whole.
        self.mlp_cls = MLP(local_channels[-1], cls_channels, dropout_prob=dropout_prob_cls)
        self.cls_logit = nn.Linear(cls_channels[-1], num_classes, bias=True)

    self.reset_parameters()
class TNet(nn.Module):
    """Transformation Network for DGCNN

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k: the number of nearest neighbours for edge feature extractor

    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        # One output per entry of the (out_channels, in_channels) matrix.
        self.linear = nn.Linear(global_channels[-1], self.in_channels * self.out_channels, bias=True)

        self.reset_parameters()

    def forward(self, x):
        """Predict a transformation matrix for every sample.

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels); the
            regressed matrix plus the identity
        """
        x = self.edge_conv(x)  # (batch_size, edge_channels[-1], num_points)
        x = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        x, _ = torch.max(x, 2)  # max pool over points
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        # Build the identity in x's own dtype (and device). Without the dtype
        # a half-precision x would be promoted to float32 by the addition.
        # This matches the PointNet TNet, which already passes dtype=x.dtype.
        I = torch.eye(self.out_channels, self.in_channels, dtype=x.dtype, device=x.device)
        x = x.add(I)  # broadcast first dimension
        return x

    def reset_parameters(self, init_fn=xavier_uniform):
        """Initialise the feature layers with ``init_fn`` and zero the final linear layer.

        With the linear layer zeroed, ``forward`` returns the identity
        matrix right after initialisation.
        """
        self.edge_conv.reset_parameters(init_fn)
        self.mlp_local.reset_parameters(init_fn)
        self.mlp_global.reset_parameters(init_fn)
        # set linear transform be 0
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
def __init__(self,
             in_channels,
             num_classes,
             num_seg_classes,
             edge_conv_channels=((64, 64), (64, 64), 64),
             local_channels=(1024, ),
             seg_channels=(256, 256, 128),
             k=20,
             dropout_prob=0.4,
             with_transform=True):
    """Build the optional input TNet, the EdgeConv blocks and the segmentation head.

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        edge_conv_channels (tuple): one entry per EdgeConv block; each entry
            is an int or a tuple/list of int
        local_channels (tuple of int): the numbers of channels in local mlp
        seg_channels (tuple of int): all entries but the last go to the
            segmentation mlp; the last two size ``conv_seg``
        k (int): forwarded to ``TNet`` and to every ``EdgeConvBlock``
        dropout_prob (float): ``dropout_prob`` of the segmentation mlp
        with_transform (bool): whether to build the input TNet
    """
    super(DGCNNPartSeg, self).__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.num_seg_classes = num_seg_classes
    self.with_transform = with_transform

    # Input transform.
    if self.with_transform:
        self.transform_input = TNet(in_channels, in_channels, k=k)

    # EdgeConv blocks are chained: each consumes the previous block's last width.
    self.edge_convs = nn.ModuleList()
    block_widths = []
    block_in_channels = in_channels
    for block_spec in edge_conv_channels:
        if not isinstance(block_spec, int):
            assert isinstance(block_spec, (tuple, list))
        else:
            block_spec = [block_spec]
        self.edge_convs.append(EdgeConvBlock(block_in_channels, block_spec, k))
        block_in_channels = block_spec[-1]
        block_widths.append(block_in_channels)

    label_channels = 64
    self.mlp_label = Conv1d(self.num_classes, label_channels, 1)

    concat_channels = sum(block_widths)
    self.mlp_local = SharedMLP(concat_channels, local_channels)

    # Segmentation input: all block outputs + last local width + label embedding.
    seg_in_channels = concat_channels + local_channels[-1] + label_channels
    self.mlp_seg = SharedMLP(seg_in_channels, seg_channels[:-1], dropout_prob=dropout_prob)
    self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
    self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

    self.reset_parameters()
def __init__(self, in_channels, mlp_channels_list, num_centroids, radius_list, num_neighbours_list, use_xyz):
    """Build one shared mlp and one grouper per scale, plus the optional sampler.

    Args:
        in_channels (int): the number of input feature channels
        mlp_channels_list (tuple of tuple of int): mlp channels of each scale
        num_centroids (int): -1 leaves ``sampler`` as None; otherwise it
            must be positive and a ``FarthestPointSampler`` is created
        radius_list (tuple of float): radius of each scale
        num_neighbours_list (tuple of int): the number of neighbours of each scale
        use_xyz (bool): when true, 3 extra channels are added to every mlp input
    """
    super(PointNetSAModuleMSG, self).__init__()

    self.in_channels = in_channels
    # The output width is the sum of the last width of every scale.
    self.out_channels = sum(mlp_channels[-1] for mlp_channels in mlp_channels_list)
    self.num_centroids = num_centroids
    self.use_xyz = use_xyz

    num_scales = len(mlp_channels_list)
    assert len(radius_list) == num_scales
    assert len(num_neighbours_list) == num_scales

    self.mlp = nn.ModuleList()

    if num_centroids != -1:
        assert num_centroids > 0
        self.sampler = FarthestPointSampler(num_centroids)
    else:
        self.sampler = None

    self.grouper = nn.ModuleList()

    mlp_in_channels = in_channels + 3 if self.use_xyz else in_channels
    for scale in range(num_scales):
        scale_mlp = SharedMLP(mlp_in_channels, mlp_channels_list[scale], ndim=2, bn=True)
        self.mlp.append(scale_mlp)
        scale_grouper = QueryGrouper(radius_list[scale], num_neighbours_list[scale])
        self.grouper.append(scale_grouper)
def __init__(self, in_channels, mlp_channels, num_neighbors):
    """Build the shared mlp and the optional feature interpolator.

    Args:
        in_channels (int): the number of channels fed to the shared mlp
        mlp_channels (tuple of int): the numbers of channels of the shared mlp
        num_neighbors (int): 0 leaves ``interpolator`` as None; 3 creates a
            ``FeatureInterpolator``

    Raises:
        ValueError: if ``num_neighbors`` is neither 0 nor 3.
    """
    super(PointnetFPModule, self).__init__()

    self.in_channels = in_channels
    self.out_channels = mlp_channels[-1]

    # Validate before building any sub-module so a bad configuration fails
    # fast. The accepted values are 0 and 3 (the message used to say
    # "1 or 3", which did not match the check).
    if num_neighbors not in (0, 3):
        raise ValueError('Expected value 0 or 3, but {} given.'.format(num_neighbors))

    self.mlp = SharedMLP(in_channels, mlp_channels, ndim=1, bn=True)

    if num_neighbors == 0:
        self.interpolator = None
    else:
        self.interpolator = FeatureInterpolator(num_neighbors)
def __init__(self, in_channels, stem_channels=(64, 64), with_transform=True):
    """Build the stem mlp and, optionally, the input and feature TNets.

    Args:
        in_channels (int): the number of input channels
        stem_channels (tuple of int): the numbers of channels of the stem mlp
        with_transform (bool): whether to build the two TNets
    """
    super(Stem, self).__init__()

    stem_out_channels = stem_channels[-1]
    self.in_channels = in_channels
    self.out_channels = stem_out_channels
    self.with_transform = with_transform

    # Feature stem.
    self.mlp = SharedMLP(in_channels, stem_channels)

    if not self.with_transform:
        return
    # Input transform, then feature transform.
    self.transform_input = TNet(in_channels, in_channels)
    self.transform_feature = TNet(self.out_channels, self.out_channels)
class TNet(nn.Module):
    """Transformation Network. The structure is similar with PointNet

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Local features, global features, then one output per matrix entry.
        self.mlp_local = SharedMLP(in_channels, local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        num_matrix_entries = in_channels * out_channels
        self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

        self.reset_parameters()

    def forward(self, x):
        """Predict a transformation matrix for every sample.

        Args:
            x (torch.Tensor): (batch_size, in_channels, num_points)

        Returns:
            torch.Tensor: (batch_size, out_channels, in_channels); the
            regressed matrix plus the identity
        """
        local_feature = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        global_feature, _ = torch.max(local_feature, 2)  # (batch_size, local_channels[-1])
        flat_matrix = self.linear(self.mlp_global(global_feature))
        matrix = flat_matrix.view(-1, self.out_channels, self.in_channels)
        identity = torch.eye(self.out_channels,
                             self.in_channels,
                             dtype=matrix.dtype,
                             device=matrix.device)
        # Broadcast add over the batch dimension.
        return matrix.add(identity)

    def reset_parameters(self, init_fn=xavier_uniform):
        """Initialise both mlps with ``init_fn`` and zero the final linear layer."""
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.reset_parameters(init_fn)
        # A zeroed linear layer makes forward return the identity at start.
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)
def __init__(self, in_channels, stem_channels=(64, 128, 128), with_transform=True):
    """Build and initialise the stem mlp and, optionally, the two TNets.

    Args:
        in_channels (int): the number of input channels
        stem_channels (tuple of int): the numbers of channels of the stem mlp
        with_transform (bool): whether to build the input and feature TNets
    """
    super(Stem, self).__init__()

    stem_out_channels = stem_channels[-1]
    self.in_channels = in_channels
    self.out_channels = stem_out_channels
    self.with_transform = with_transform

    # Feature stem, initialised right away with xavier_uniform.
    self.mlp = SharedMLP(in_channels, stem_channels)
    self.mlp.reset_parameters(xavier_uniform)

    if not self.with_transform:
        return
    # Input transform, then feature transform.
    self.transform_input = TNet(in_channels, in_channels)
    self.transform_feature = TNet(self.out_channels, self.out_channels)
def __init__(self,
             in_channels=3,
             out_channels=3,
             local_channels=(64, 128, 1024),
             global_channels=(512, 256)):
    """Build the local mlp, the global mlp and the linear output layer.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
    """
    super(TNet, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    # Local features.
    self.mlp_local = SharedMLP(in_channels, local_channels)
    # Global features.
    self.mlp_global = MLP(local_channels[-1], global_channels)
    # Linear output: in_channels * out_channels values per sample.
    num_matrix_entries = in_channels * out_channels
    self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

    self.reset_parameters()
def __init__(self,
             in_channels=3,
             out_channels=3,
             conv_channels=(64, 128),
             local_channels=(1024, ),
             global_channels=(512, 256),
             k=20):
    """Build the EdgeConv block, both mlps and the linear output layer.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k: forwarded to ``EdgeConvBlock``
    """
    super(TNet, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k)
    self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels)
    # Linear output: in_channels * out_channels values per sample.
    num_matrix_entries = self.in_channels * self.out_channels
    self.linear = nn.Linear(global_channels[-1], num_matrix_entries, bias=True)

    self.reset_parameters()
def __init__(self,
             in_channels,
             out_channels,
             edge_conv_channels=(64, 64, 64, 128),
             local_channels=(1024, ),
             global_channels=(512, 256),
             k=20,
             dropout_prob=0.5,
             with_transform=True):
    """Build the optional input TNet, the EdgeConv blocks, both mlps and the classifier.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        edge_conv_channels (tuple): one entry per EdgeConv block; each entry
            is an int or a tuple/list of int
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): forwarded to ``TNet`` and to every ``EdgeConvBlock``
        dropout_prob (float): ``dropout_prob`` of the global mlp
        with_transform (bool): whether to build the input TNet
    """
    super(DGCNNCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.with_transform = with_transform

    # Input transform.
    if self.with_transform:
        self.transform_input = TNet(in_channels, in_channels, k=k)

    # EdgeConv blocks are chained: each consumes the previous block's last width.
    self.edge_convs = nn.ModuleList()
    block_widths = []
    block_in_channels = in_channels
    for block_spec in edge_conv_channels:
        if not isinstance(block_spec, int):
            assert isinstance(block_spec, (tuple, list))
        else:
            block_spec = [block_spec]
        self.edge_convs.append(EdgeConvBlock(block_in_channels, block_spec, k))
        block_in_channels = block_spec[-1]
        block_widths.append(block_in_channels)

    # The local mlp consumes the outputs of all blocks together.
    self.mlp_local = SharedMLP(sum(block_widths), local_channels)
    self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob)
    self.classifier = nn.Linear(global_channels[-1], self.out_channels, bias=True)

    self.reset_parameters()
class DGCNNCls(nn.Module):
    """DGCNN for classification

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        edge_conv_channels (tuple): one entry per EdgeConv block; each entry
            is an int or a tuple/list of int
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for edge feature extractor
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_conv_channels=(64, 64, 64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20,
                 dropout_prob=0.5,
                 with_transform=True):
        super(DGCNNCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        # Input transform.
        if self.with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # EdgeConv blocks are chained: each consumes the previous block's last width.
        self.edge_convs = nn.ModuleList()
        block_widths = []
        block_in_channels = in_channels
        for block_spec in edge_conv_channels:
            if not isinstance(block_spec, int):
                assert isinstance(block_spec, (tuple, list))
            else:
                block_spec = [block_spec]
            self.edge_convs.append(EdgeConvBlock(block_in_channels, block_spec, k))
            block_in_channels = block_spec[-1]
            block_widths.append(block_in_channels)

        # The local mlp consumes the outputs of all blocks together.
        self.mlp_local = SharedMLP(sum(block_widths), local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1], self.out_channels, bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Classify a batch of point clouds.

        Args:
            data_batch (dict): must hold the input tensor under ``'points'``

        Returns:
            dict: ``'cls_logit'``, ``'trans_input'`` (only with transform)
            and ``'key_point_indices'`` (the argmax indices of the max pooling)
        """
        points = data_batch['points']
        end_points = {}

        # Input transform.
        if self.with_transform:
            input_matrix = self.transform_input(points)
            points = torch.bmm(input_matrix, points)
            end_points['trans_input'] = input_matrix

        # EdgeConv: run the blocks in sequence and keep every output.
        block_outputs = []
        current = points
        for block in self.edge_convs:
            current = block(current)
            block_outputs.append(current)

        local_feature = self.mlp_local(torch.cat(block_outputs, dim=1))
        global_feature, max_indices = torch.max(local_feature, 2)
        end_points['key_point_indices'] = max_indices

        cls_logit = self.classifier(self.mlp_global(global_feature))

        preds = {'cls_logit': cls_logit}
        preds.update(end_points)
        return preds

    def reset_parameters(self):
        """Re-initialise all layers with xavier_uniform and set BN momentum to 0.01."""
        for block in self.edge_convs:
            block.reset_parameters(xavier_uniform)
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        set_bn(self, momentum=0.01)
class PointNet2SSGPartSeg(nn.Module):
    """PointNet++ part segmentation with single-scale grouping

    PointNetSA: PointNet Set Abstraction Layer
    PointNetFP: PointNet Feature Propagation Layer

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
        radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
        num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
        sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
        fp_channels (tuple of tuple of int): the numbers of channels for feature propagation (FP) module
        num_fp_neighbours (tuple of int): the numbers of nearest neighbor used in FP
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob (float): the probability to dropout input features
        use_xyz (bool): whether or not to use the xyz position of a points as a feature
        use_one_hot (bool): whether to use one hot vector of class labels.

    References:
        https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(64, 64, -1),
                 sa_channels=((64, 64, 128), (128, 128, 256), (256, 512, 1024)),
                 fp_channels=((256, 256), (256, 128), (128, 128, 128)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(128,),
                 dropout_prob=0.5,
                 use_xyz=True,
                 use_one_hot=True):
        super(PointNet2SSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz
        self.use_one_hot = use_one_hot

        # Every per-layer tuple must describe the same number of layers.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius) == num_sa_layers
        assert len(num_neighbours) == num_sa_layers
        assert len(sa_channels) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers. The first one sees the input minus 3 channels.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_sa_layers):
            layer_channels = sa_channels[layer]
            self.sa_modules.append(
                PointNetSAModule(in_channels=sa_in_channels,
                                 mlp_channels=layer_channels,
                                 num_centroids=num_centroids[layer],
                                 radius=radius[layer],
                                 num_neighbours=num_neighbours[layer],
                                 use_xyz=use_xyz))
            sa_in_channels = layer_channels[-1]

        # Channel widths per level: the input level first (widened by the
        # one-hot label when enabled), then one per SA module.
        level_channels = [in_channels if use_xyz else in_channels - 3]
        if self.use_one_hot:
            level_channels[0] += num_classes
        level_channels += [channels[-1] for channels in sa_channels]

        # Feature Propagation Layers, walking the levels from the last one back.
        self.fp_modules = nn.ModuleList()
        fp_in_channels = level_channels[-1]
        for layer in range(num_fp_layers):
            skip_channels = level_channels[-2 - layer]
            self.fp_modules.append(
                PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                                 mlp_channels=fp_channels[layer],
                                 num_neighbors=num_fp_neighbours[layer]))
            fp_in_channels = fp_channels[layer][-1]

        # MLP
        self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1, dropout_prob=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Compute segmentation logits.

        Args:
            data_batch (dict): holds the input under ``'points'`` (its first
                3 channels along dim 1 are used as xyz) and, when
                ``use_one_hot`` is set, the class labels under ``'cls_label'``

        Returns:
            dict: ``'seg_logit'``
        """
        points = data_batch['points']
        end_points = {}

        xyz = points.narrow(1, 0, 3)
        num_extra_channels = points.size(1) - 3
        feature = points.narrow(1, 3, num_extra_channels) if points.size(1) > 3 else None

        # Per-level intermediate results, input level first.
        level_xyz = [xyz]
        level_feature = [points if self.use_xyz else feature]

        if self.use_one_hot:
            # Append the one hot class label to the input-level feature.
            num_points = points.size(2)
            # NOTE(review): the concatenation is kept under no_grad together
            # with the one-hot construction; the source layout is ambiguous
            # on this point - confirm against the original file.
            with torch.no_grad():
                cls_label = data_batch['cls_label']
                one_hot = cls_label.new_zeros(cls_label.size(0), self.num_classes)
                one_hot = one_hot.scatter(1, cls_label.unsqueeze(1), 1)  # (batch_size, num_classes)
                one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points).float()
                level_feature[0] = torch.cat((level_feature[0], one_hot_expand), dim=1)

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            level_xyz.append(xyz)
            level_feature.append(feature)

        # Feature Propagation Layers: start from the last level and move
        # one level towards the input with each module.
        coarse_xyz, coarse_feature = xyz, feature
        for fp_ind, fp_module in enumerate(self.fp_modules):
            level = -2 - fp_ind
            fine_xyz = level_xyz[level]
            coarse_feature = fp_module(fine_xyz, coarse_xyz, level_feature[level], coarse_feature)
            coarse_xyz = fine_xyz

        # MLP
        seg_logit = self.seg_logit(self.mlp_seg(coarse_feature))

        preds = {'seg_logit': seg_logit}
        preds.update(end_points)
        return preds

    def reset_parameters(self):
        """Re-initialise all layers with xavier_uniform and set BN momentum to 0.01."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        for fp_module in self.fp_modules:
            fp_module.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)
class PointNetCls(nn.Module):
    """PointNet classification model

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(16, 32),
                 local_channels=(32, 32),
                 global_channels=(32, 32),
                 dropout_prob=0.5,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = Stem(in_channels, stem_channels, with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels, dropout_prob=dropout_prob)

        # Both heads read the same global feature; the second one always has 2 outputs.
        head_in_channels = global_channels[-1]
        self.classifier = nn.Linear(head_in_channels, out_channels, bias=True)
        self.classifier2 = nn.Linear(head_in_channels, 2, bias=True)

        self.reset_parameters()

    def forward(self, points):
        """Run both heads on the max-pooled global feature.

        Args:
            points (torch.Tensor): the input passed straight to the stem

        Returns:
            dict: ``'cls_logit'`` (from ``classifier``), ``'node_logit'``
            (from ``classifier2``), everything the stem reported, and
            ``'key_point_indices'`` (the argmax indices of the max pooling)
        """
        # Stem, then mlp for local features.
        stem_feature, end_points = self.stem(points)
        local_feature = self.mlp_local(stem_feature)

        # Max pool over points.
        pooled, max_indices = torch.max(local_feature, 2)
        end_points['key_point_indices'] = max_indices

        # Mlp for global features, shared by both heads.
        global_feature = self.mlp_global(pooled)
        cls_logit = self.classifier(global_feature)
        node_logit = self.classifier2(global_feature)

        preds = {'cls_logit': cls_logit, 'node_logit': node_logit}
        preds.update(end_points)
        return preds

    def reset_parameters(self):
        """Re-initialise the mlps and both heads with xavier_uniform; set BN momentum to 0.01."""
        # Default initialization in original implementation.
        for mlp in (self.mlp_local, self.mlp_global):
            mlp.reset_parameters(xavier_uniform)
        for head in (self.classifier, self.classifier2):
            xavier_uniform(head)
        # Set batch normalization momentum to 0.01 as default.
        set_bn(self, momentum=0.01)
def __init__(self,
             in_channels,
             num_classes,
             num_seg_classes,
             num_centroids=(512, 128, 0),
             radius_list=((0.1, 0.2, 0.4), (0.4, 0.8), -1.0),
             num_neighbours_list=((32, 64, 128), (64, 128), -1),
             sa_channels_list=(
                 ((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                 ((128, 128, 256), (128, 196, 256)),
                 (256, 512, 1024),
             ),
             fp_channels=((256, 256), (256, 128), (128, 128)),
             num_fp_neighbours=(0, 3, 3),
             seg_channels=(128, ),
             dropout_prob=0.5,
             use_xyz=True,
             use_one_hot=True):
    """Build the multi-scale SA stack, the FP stack and the segmentation head.

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        num_centroids (tuple of int): ``num_centroids`` of each SA module
        radius_list (tuple): per SA module, the radii of its scales; the
            last entry is a single radius for the final single-scale module
        num_neighbours_list (tuple): per SA module, the neighbour counts of
            its scales; the last entry is a single value
        sa_channels_list (tuple): per SA module, the mlp channels of its
            scales; the last entry is a single tuple of channels
        fp_channels (tuple of tuple of int): mlp channels of each FP module
        num_fp_neighbours (tuple of int): ``num_neighbors`` of each FP module
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob (float): ``dropout_prob`` of the segmentation mlp
        use_xyz (bool): forwarded to every SA module
        use_one_hot (bool): when true, ``num_classes`` channels are added to
            the input-level width
    """
    super(PointNet2MSGPartSeg, self).__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.num_seg_classes = num_seg_classes
    self.use_xyz = use_xyz
    self.use_one_hot = use_one_hot

    # Every per-layer tuple must describe the same number of layers.
    num_sa_layers = len(num_centroids)
    num_fp_layers = len(fp_channels)
    assert len(radius_list) == num_sa_layers
    assert len(num_neighbours_list) == num_sa_layers
    assert len(sa_channels_list) == num_sa_layers
    assert num_sa_layers == num_fp_layers
    assert len(num_fp_neighbours) == num_fp_layers

    # Set Abstraction Layers: all but the last are multi-scale.
    sa_in_channels = in_channels - 3
    self.sa_modules = nn.ModuleList()
    for layer in range(num_sa_layers - 1):
        msg_module = PointNetSAModuleMSG(
            in_channels=sa_in_channels,
            mlp_channels_list=sa_channels_list[layer],
            num_centroids=num_centroids[layer],
            radius_list=radius_list[layer],
            num_neighbours_list=num_neighbours_list[layer],
            use_xyz=use_xyz)
        self.sa_modules.append(msg_module)
        sa_in_channels = msg_module.out_channels
    # The last SA module is single-scale.
    last_module = PointNetSAModule(in_channels=sa_in_channels,
                                   mlp_channels=sa_channels_list[-1],
                                   num_centroids=num_centroids[-1],
                                   radius=radius_list[-1],
                                   num_neighbours=num_neighbours_list[-1],
                                   use_xyz=use_xyz)
    self.sa_modules.append(last_module)

    # Channel widths per level: the input level first (widened by the
    # one-hot label when enabled), then one per SA module.
    level_channels = [in_channels if use_xyz else in_channels - 3]
    if self.use_one_hot:
        level_channels[0] += num_classes
    level_channels += [module.out_channels for module in self.sa_modules]

    # Feature Propagation Layers, walking the levels from the last one back.
    self.fp_modules = nn.ModuleList()
    fp_in_channels = level_channels[-1]
    for layer in range(num_fp_layers):
        skip_channels = level_channels[-2 - layer]
        self.fp_modules.append(
            PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                             mlp_channels=fp_channels[layer],
                             num_neighbors=num_fp_neighbours[layer]))
        fp_in_channels = fp_channels[layer][-1]

    # MLP
    self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1, dropout_prob=dropout_prob)
    self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

    self.reset_parameters()
class PointNet2MSGPartSeg(nn.Module):
    """PointNet++ part segmentation with multi-scale grouping

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification class
        num_seg_classes (int): the number of segmentation class
        num_centroids (tuple of int): ``num_centroids`` of each SA module
        radius_list (tuple): per SA module, the radii of its scales; the
            last entry is a single radius for the final single-scale module
        num_neighbours_list (tuple): per SA module, the neighbour counts of
            its scales; the last entry is a single value
        sa_channels_list (tuple): per SA module, the mlp channels of its
            scales; the last entry is a single tuple of channels
        fp_channels (tuple of tuple of int): mlp channels of each FP module
        num_fp_neighbours (tuple of int): ``num_neighbors`` of each FP module
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob (float): ``dropout_prob`` of the segmentation mlp
        use_xyz (bool): forwarded to every SA module
        use_one_hot (bool): whether to use one hot vector of class labels.

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 num_centroids=(512, 128, 0),
                 radius_list=((0.1, 0.2, 0.4), (0.4, 0.8), -1.0),
                 num_neighbours_list=((32, 64, 128), (64, 128), -1),
                 sa_channels_list=(
                     ((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                     ((128, 128, 256), (128, 196, 256)),
                     (256, 512, 1024),
                 ),
                 fp_channels=((256, 256), (256, 128), (128, 128)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(128, ),
                 dropout_prob=0.5,
                 use_xyz=True,
                 use_one_hot=True):
        super(PointNet2MSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_xyz = use_xyz
        self.use_one_hot = use_one_hot

        # Every per-layer tuple must describe the same number of layers.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set Abstraction Layers: all but the last are multi-scale.
        sa_in_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for layer in range(num_sa_layers - 1):
            msg_module = PointNetSAModuleMSG(
                in_channels=sa_in_channels,
                mlp_channels_list=sa_channels_list[layer],
                num_centroids=num_centroids[layer],
                radius_list=radius_list[layer],
                num_neighbours_list=num_neighbours_list[layer],
                use_xyz=use_xyz)
            self.sa_modules.append(msg_module)
            sa_in_channels = msg_module.out_channels
        # The last SA module is single-scale.
        last_module = PointNetSAModule(in_channels=sa_in_channels,
                                       mlp_channels=sa_channels_list[-1],
                                       num_centroids=num_centroids[-1],
                                       radius=radius_list[-1],
                                       num_neighbours=num_neighbours_list[-1],
                                       use_xyz=use_xyz)
        self.sa_modules.append(last_module)

        # Channel widths per level: the input level first (widened by the
        # one-hot label when enabled), then one per SA module.
        level_channels = [in_channels if use_xyz else in_channels - 3]
        if self.use_one_hot:
            level_channels[0] += num_classes
        level_channels += [module.out_channels for module in self.sa_modules]

        # Feature Propagation Layers, walking the levels from the last one back.
        self.fp_modules = nn.ModuleList()
        fp_in_channels = level_channels[-1]
        for layer in range(num_fp_layers):
            skip_channels = level_channels[-2 - layer]
            self.fp_modules.append(
                PointnetFPModule(in_channels=fp_in_channels + skip_channels,
                                 mlp_channels=fp_channels[layer],
                                 num_neighbors=num_fp_neighbours[layer]))
            fp_in_channels = fp_channels[layer][-1]

        # MLP
        self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1, dropout_prob=dropout_prob)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1, bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Compute segmentation logits.

        Args:
            data_batch (dict): holds the input under ``'points'`` (its first
                3 channels along dim 1 are used as xyz) and, when
                ``use_one_hot`` is set, the class labels under ``'cls_label'``

        Returns:
            dict: ``'seg_logit'``
        """
        points = data_batch['points']
        end_points = {}

        xyz = points.narrow(1, 0, 3)
        num_extra_channels = points.size(1) - 3
        feature = points.narrow(1, 3, num_extra_channels) if points.size(1) > 3 else None

        # Per-level intermediate results, input level first.
        level_xyz = [xyz]
        level_feature = [points if self.use_xyz else feature]

        if self.use_one_hot:
            # Append the one hot class label to the input-level feature.
            num_points = points.size(2)
            # NOTE(review): the concatenation is kept under no_grad together
            # with the one-hot construction; the source layout is ambiguous
            # on this point - confirm against the original file.
            with torch.no_grad():
                cls_label = data_batch['cls_label']
                one_hot = cls_label.new_zeros(cls_label.size(0), self.num_classes)
                one_hot = one_hot.scatter(1, cls_label.unsqueeze(1), 1)  # (batch_size, num_classes)
                one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points).float()
                level_feature[0] = torch.cat((level_feature[0], one_hot_expand), dim=1)

        # Set Abstraction Layers
        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            level_xyz.append(xyz)
            level_feature.append(feature)

        # Feature Propagation Layers: start from the last level and move
        # one level towards the input with each module.
        coarse_xyz, coarse_feature = xyz, feature
        for fp_ind, fp_module in enumerate(self.fp_modules):
            level = -2 - fp_ind
            fine_xyz = level_xyz[level]
            coarse_feature = fp_module(fine_xyz, coarse_xyz, level_feature[level], coarse_feature)
            coarse_xyz = fine_xyz

        # MLP
        seg_logit = self.seg_logit(self.mlp_seg(coarse_feature))

        preds = {'seg_logit': seg_logit}
        preds.update(end_points)
        return preds

    def reset_parameters(self):
        """Re-initialise all layers with xavier_uniform and set BN momentum to 0.01."""
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        for fp_module in self.fp_modules:
            fp_module.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)
class PointNet2SSGPartSeg(nn.Module):
    """PointNet++ feature extractor for part segmentation (single-scale grouping).

    A stack of set abstraction (SA) modules is followed by a stack of feature
    propagation (FP) modules and a shared MLP. ``forward`` returns the output
    of that MLP under the key ``'feature'``.

    Args:
        in_channels (int): number of input channels; the first three are
            treated as xyz coordinates
        num_centroids (tuple of int): number of centroids for each SA module
        radius (tuple of float): query radius for each SA module
        num_neighbours (tuple of int): number of neighbours queried per
            centroid in each SA module
        sa_channels (tuple of tuple of int): MLP channels of each SA module
        fp_channels (tuple of tuple of int): MLP channels of each FP module
        num_fp_neighbours (tuple of int): number of nearest neighbours used by
            each FP module
        seg_channels (tuple of int): channels of the final shared MLP
        dropout_prob (float): dropout probability of the final shared MLP
        use_xyz (bool): whether xyz coordinates are also used as features

    References:
        https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_centroids=(128, 32, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(64, 64, -1),
                 sa_channels=((16, 16, 32), (32, 32, 64), (128, 128, 256)),
                 fp_channels=((64, 64), (64, 32), (32, 32, 32)),
                 num_fp_neighbours=(0, 3, 3),
                 seg_channels=(32,),
                 dropout_prob=0.5,
                 use_xyz=True):
        super(PointNet2SSGPartSeg, self).__init__()

        self.in_channels = in_channels
        self.use_xyz = use_xyz

        # All per-layer settings must describe the same number of layers,
        # and there is exactly one FP module per SA module.
        num_sa_layers = len(num_centroids)
        num_fp_layers = len(fp_channels)
        assert len(radius) == num_sa_layers
        assert len(num_neighbours) == num_sa_layers
        assert len(sa_channels) == num_sa_layers
        assert num_sa_layers == num_fp_layers
        assert len(num_fp_neighbours) == num_fp_layers

        # Set abstraction stack; xyz channels are not counted as features.
        self.sa_modules = nn.ModuleList()
        sa_in_channels = in_channels - 3
        for layer in range(num_sa_layers):
            self.sa_modules.append(
                PointNetSAModule(in_channels=sa_in_channels,
                                 mlp_channels=sa_channels[layer],
                                 num_centroids=num_centroids[layer],
                                 radius=radius[layer],
                                 num_neighbours=num_neighbours[layer],
                                 use_xyz=use_xyz))
            sa_in_channels = sa_channels[layer][-1]

        # Channel count of the features saved at each level (input first).
        level_channels = [in_channels if use_xyz else in_channels - 3]
        level_channels += [channels[-1] for channels in sa_channels]

        # Feature propagation stack, walking the saved levels backwards.
        self.fp_modules = nn.ModuleList()
        fp_in_channels = level_channels[-1]
        for layer in range(num_fp_layers):
            self.fp_modules.append(
                PointnetFPModule(
                    in_channels=fp_in_channels + level_channels[-2 - layer],
                    mlp_channels=fp_channels[layer],
                    num_neighbors=num_fp_neighbours[layer]))
            fp_in_channels = fp_channels[layer][-1]

        # Final shared MLP
        self.mlp_seg = SharedMLP(fp_in_channels, seg_channels, ndim=1,
                                 dropout_prob=dropout_prob)

        self.reset_parameters()

    @staticmethod
    def _split_points(points):
        """Split ``points`` along dim 1 into xyz (first 3 channels) and the rest (or None)."""
        xyz = points.narrow(1, 0, 3)
        if points.size(1) > 3:
            return xyz, points.narrow(1, 3, points.size(1) - 3)
        return xyz, None

    def extract_feats(self, points):
        """Run the SA stack, the FP stack and the final MLP on ``points``.

        Args:
            points (torch.Tensor): input whose dim 1 holds xyz followed by any
                extra feature channels

        Returns:
            The output of ``self.mlp_seg`` applied to the last FP feature.

        """
        xyz, feature = self._split_points(points)

        # Keep the coordinates/features of every level for the FP stage.
        xyz_levels = [xyz]
        feature_levels = [points if self.use_xyz else feature]

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)
            xyz_levels.append(xyz)
            feature_levels.append(feature)

        # ``xyz``/``feature`` now play the role of the sparse level; each FP
        # module propagates them onto the previously saved (denser) level.
        for step, fp_module in enumerate(self.fp_modules):
            level = -2 - step
            feature = fp_module(xyz_levels[level], xyz,
                                feature_levels[level], feature)
            xyz = xyz_levels[level]

        return self.mlp_seg(feature)

    def forward(self, points):
        """Return ``{'feature': extract_feats(points)}``."""
        return {'feature': self.extract_feats(points)}

    def reset_parameters(self):
        """Re-initialise all sub-modules with xavier_uniform and set BN momentum to 0.01."""
        for module in self.sa_modules:
            module.reset_parameters(xavier_uniform)
        for module in self.fp_modules:
            module.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
class PointNetSAModule(nn.Module):
    """PointNet set abstraction module.

    Picks centroids, groups points around them, applies a shared MLP to every
    group and max-pools over the group dimension.

    Args:
        in_channels (int): number of feature channels, xyz not included
            (3 is added internally when ``use_xyz`` is set)
        mlp_channels (tuple of int): channels of the shared MLP; the last one
            is exposed as ``out_channels``
        num_centroids (int): ``> 0`` samples that many centroids with
            ``FarthestPointSampler``; ``0`` pools all points into one group
            whose centroid is the origin; a negative value disables sampling
            and uses every input point as a centroid
        radius (float): query radius; must be negative when
            ``num_neighbours`` is negative and positive otherwise
        num_neighbours (int): neighbours per centroid; a negative value
            disables grouping (no ``QueryGrouper`` is created)
        use_xyz (bool): whether xyz coordinates are used as extra features

    """

    def __init__(self, in_channels, mlp_channels, num_centroids, radius,
                 num_neighbours, use_xyz):
        super(PointNetSAModule, self).__init__()

        self.in_channels = in_channels
        self.out_channels = mlp_channels[-1]
        self.num_centroids = num_centroids
        self.use_xyz = use_xyz

        # xyz coordinates are concatenated to the features before the MLP.
        if self.use_xyz:
            in_channels += 3
        self.mlp = SharedMLP(in_channels, mlp_channels, ndim=2, bn=True)

        if num_centroids <= 0:
            self.sampler = None
        else:
            self.sampler = FarthestPointSampler(num_centroids)

        if num_neighbours < 0:
            assert radius < 0.0
            self.grouper = None
        else:
            assert num_neighbours > 0 and radius > 0.0
            self.grouper = QueryGrouper(radius, num_neighbours)

    def forward(self, xyz, feature=None):
        """
        Args:
            xyz (torch.Tensor): (batch_size, 3, num_points)
                xyz coordinates of feature
            feature (torch.Tensor, optional): (batch_size, in_channels, num_points)

        Returns:
            new_xyz (torch.Tensor): (batch_size, 3, num_centroids)
            new_feature (torch.Tensor): (batch_size, out_channels, num_centroids)

        Raises:
            ValueError: if ``num_centroids == 0``, ``feature`` is None and
                ``use_xyz`` is False, i.e. there is nothing to feed the MLP.

        """
        if self.num_centroids == 0:
            # use the origin as the centroid
            new_xyz = xyz.new_zeros(xyz.size(0), 3, 1)  # (batch_size, 3, 1)
            assert self.grouper is None
            group_xyz = xyz.unsqueeze(2)  # (batch_size, 3, 1, num_points)
            if feature is None:
                # ``feature`` is optional: fall back to xyz-only input
                # instead of failing on ``None.unsqueeze``.
                if not self.use_xyz:
                    raise ValueError(
                        'feature is None and use_xyz is False: '
                        'no input is available for the shared MLP')
                group_feature = group_xyz
            else:
                # (batch_size, in_channels, 1, num_points)
                group_feature = feature.unsqueeze(2)
                if self.use_xyz:
                    group_feature = torch.cat([group_xyz, group_feature], dim=1)
        else:
            if self.sampler is None:
                # No sampler was built (num_centroids < 0): use all points.
                # Checking the sampler keeps this in sync with __init__,
                # which disables sampling for every non-positive value.
                new_xyz = xyz
            else:
                # sample new points
                index = self.sampler(xyz)
                new_xyz = _F.gather_points(xyz, index)  # (batch_size, 3, num_centroids)
            # group_feature, (batch_size, in_channels, num_centroids, num_neighbours)
            group_feature, group_xyz = self.grouper(new_xyz, xyz, feature,
                                                    use_xyz=self.use_xyz)

        new_feature = self.mlp(group_feature)
        # max-pool over the neighbour dimension
        new_feature, _ = torch.max(new_feature, 3)
        return new_xyz, new_feature

    def reset_parameters(self, init_fn=None):
        """Re-initialise the shared MLP with ``init_fn``."""
        self.mlp.reset_parameters(init_fn)

    def extra_repr(self):
        return 'num_centroids={:d}, use_xyz={}'.format(self.num_centroids,
                                                       self.use_xyz)
class PointNetPartSeg(nn.Module):
    """PointNet for part segmentation

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification classes
        num_seg_classes (int): the number of segmentation classes
        stem_channels (tuple of int): channels of the stem feature extractor
        local_channels (tuple of int): channels of the local mlp
        cls_channels (tuple of int): channels of the classification mlp;
            an empty tuple disables the classification branch
        seg_channels (tuple of int): channels of the segmentation mlp
        dropout_prob_cls (float): dropout probability in the classification mlp
        dropout_prob_seg (float): dropout probability in the segmentation mlp
        with_transform (bool): whether to use TNet to transform features.
        use_one_hot (bool): whether to append a one hot vector of the class
            label to the segmentation features.

    References:
        https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True,
                 use_one_hot=True):
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_one_hot = use_one_hot

        # stem + local feature extractor
        self.stem = Stem(in_channels, stem_channels,
                         with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # Segmentation branch. Following the released repo (not the paper),
        # its input concatenates every stem feature, every local feature
        # (including the last one), the global feature and, optionally, the
        # one hot class embedding.
        seg_in_channels = (sum(stem_channels) + sum(local_channels)
                           + local_channels[-1])
        if self.use_one_hot:
            seg_in_channels += num_classes
        self.mlp_seg = SharedMLP(seg_in_channels, seg_channels[:-1],
                                 dropout_prob=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1,
                                   bias=True)

        # Classification branch (optional); dropout is applied in every
        # layer of its mlp.
        has_cls_branch = len(cls_channels) > 0
        if has_cls_branch:
            self.mlp_cls = MLP(local_channels[-1], cls_channels,
                               dropout_prob=dropout_prob_cls)
            self.cls_logit = nn.Linear(cls_channels[-1], num_classes,
                                       bias=True)
        else:
            self.mlp_cls = None
            self.cls_logit = None

        self.reset_parameters()

    def forward(self, data_batch):
        """Compute segmentation (and optionally classification) logits.

        Args:
            data_batch (dict): must contain ``'points'``; must also contain
                ``'cls_label'`` when ``use_one_hot`` is set

        Returns:
            dict: ``'seg_logit'``, whatever the stem reports besides
            ``'stem_features'``, and ``'cls_logit'`` when the classification
            branch exists.

        """
        points = data_batch['points']
        num_points = points.shape[2]

        # stem
        stem_feature, stem_end_points = self.stem(points)
        stem_features = stem_end_points.pop('stem_features')
        end_points = {}
        end_points.update(stem_end_points)

        # local mlp, keeping the output of every layer
        local_features = []
        x = stem_feature
        for layer in self.mlp_local:
            x = layer(x)
            local_features.append(x)

        # max pool over points -> (batch_size, local_channels[-1])
        global_feature, _ = torch.max(x, 2)

        # segmentation
        global_per_point = global_feature.unsqueeze(2).expand(
            -1, -1, num_points)
        seg_features = stem_features + local_features + [global_per_point]

        if self.use_one_hot:
            with torch.no_grad():
                cls_label = data_batch['cls_label']
                # (batch_size, num_classes)
                one_hot = cls_label.new_zeros(cls_label.size(0),
                                              self.num_classes)
                one_hot = one_hot.scatter(1, cls_label.unsqueeze(1),
                                          1).float()
                one_hot_per_point = one_hot.unsqueeze(2).expand(
                    -1, -1, num_points)
            seg_features.append(one_hot_per_point)

        x = torch.cat(seg_features, dim=1)
        x = self.conv_seg(self.mlp_seg(x))

        preds = {'seg_logit': self.seg_logit(x)}
        preds.update(end_points)

        # classification (optional)
        if self.cls_logit is not None:
            preds['cls_logit'] = self.cls_logit(self.mlp_cls(global_feature))

        return preds

    def reset_parameters(self):
        """Initialise the mlps and logit layers with xavier_uniform and set BN momentum to 0.01."""
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        self.conv_seg.reset_parameters(xavier_uniform)
        if self.cls_logit is not None:
            self.mlp_cls.reset_parameters(xavier_uniform)
            xavier_uniform(self.cls_logit)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)
class PointNetCls(nn.Module):
    """Collection of classification sub-networks selected by ``infer_type``.

    Four ``PointNet2SSGCls`` networks (``p1`` .. ``p4``) and two PointNet
    style heads (``stem8``/``stem9`` with their mlps) are built, each with
    its own linear classifier. ``forward`` runs exactly one of them.

    Args:
        in_channels (int): the number of input channels (only stored)
        out_channels (int): the number of output channels of the backbone
            classifiers; also determines the input width of the head stems
        stem_channels (tuple of int): channels of the head stem feature extractors
        local_channels (tuple of int): channels of the head local mlps
        global_channels (tuple of int): channels of the head global mlps;
            its last entry also sizes the linear classifiers
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether the head stems use TNet to transform features.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(128, 128),
                 local_channels=(128, 256, 256),
                 global_channels=(128, 128),
                 dropout_prob=0.3,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Channel presets for the PointNet2SSGCls sub-networks: a wide one
        # (base width p1) and a narrow one (a quarter of it).
        p1 = 128
        stem_channels1 = ((p1, p1, int(2 * p1)),
                          (int(2 * p1), int(2 * p1), int(2 * p1)))
        p2 = int(p1 / 4)
        stem_channels2 = ((p2, p2, int(2 * p2)),
                          (int(2 * p2), int(2 * p2), int(2 * p2)))
        global_channels1 = (p1 * 2, p1)
        global_channels2 = (p2 * 2, p2)

        # 'backbone'
        self.p1 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier1 = nn.Linear(global_channels[-1], out_channels,
                                     bias=True)

        # 'backbone2'
        self.p2 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier2 = nn.Linear(global_channels[-1], out_channels,
                                     bias=True)

        # 'policy' (narrow network) and 'policy_head'
        self.p3 = PointNet2SSGCls(sa_channels=stem_channels2,
                                  global_channels=global_channels2)
        self.classifier3 = nn.Linear(int(global_channels[-1] / 4),
                                     int(global_channels[-1] / 4), bias=True)
        self.mlp_global22 = MLP(int(global_channels[-1] / 4) * 2,
                                tuple(int(e / 4) for e in global_channels),
                                dropout_prob=dropout_prob)
        self.classifier22 = nn.Linear(int(global_channels[-1] / 4), 1,
                                      bias=True)

        # 'purity'
        self.p4 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True)

        # 'head'
        self.stem8 = Stem(int(out_channels * 1) + 3, stem_channels,
                          with_transform=with_transform)
        self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global8 = MLP(local_channels[-1], global_channels,
                               dropout_prob=dropout_prob)
        self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True)

        # 'head2'
        self.stem9 = Stem(int(out_channels * 2) + 3, stem_channels,
                          with_transform=with_transform)
        self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global9 = MLP(local_channels[-1], global_channels,
                               dropout_prob=dropout_prob)
        self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True)

        self.sigmoid = nn.Sigmoid()

        self.reset_parameters()

    @staticmethod
    def _forward_head(x, stem, mlp_local, mlp_global, classifier):
        """Run stem -> local mlp -> max pool over dim 2 -> global mlp -> classifier."""
        x, end_points = stem(x)
        x = mlp_local(x)
        x, max_indices = torch.max(x, 2)
        # Recorded on the stem's end points as before; forward() itself only
        # returns the logits.
        end_points['key_point_indices'] = max_indices
        x = mlp_global(x)
        return classifier(x)

    def forward(self, x, infer_type, y=None):
        """Run the sub-network selected by ``infer_type`` on ``x``.

        Args:
            x (torch.Tensor): input of the selected sub-network
            infer_type (str): one of ``'backbone'``, ``'backbone2'``,
                ``'policy'``, ``'policy_head'``, ``'purity'``, ``'head'``,
                ``'head2'``
            y: unused

        Returns:
            The output of the selected classifier (passed through a sigmoid
            for ``'purity'``).

        Raises:
            NameError: if ``infer_type`` is not one of the values above.

        """
        if infer_type == 'backbone':
            x = self.classifier1(self.p1(x))
        elif infer_type == 'backbone2':
            x = self.classifier2(self.p2(x))
        elif infer_type == 'policy':
            x = self.classifier3(self.p3(x))
        elif infer_type == 'policy_head':
            x = self.classifier22(self.mlp_global22(x))
        elif infer_type == 'purity':
            x = self.sigmoid(self.classifier4(self.p4(x)))
        elif infer_type == 'head':
            x = self._forward_head(x, self.stem8, self.mlp_local8,
                                   self.mlp_global8, self.classifier8)
        elif infer_type == 'head2':
            x = self._forward_head(x, self.stem9, self.mlp_local9,
                                   self.mlp_global9, self.classifier9)
        else:
            # Same exception type as before, now with a usable message.
            raise NameError('Unknown infer_type: {!r}'.format(infer_type))

        return x

    def reset_parameters(self):
        """Reset the sub-networks, apply xavier_uniform to classifiers and mlps, set BN momentum to 0.01."""
        # default initialization in original implementation
        self.p1.reset_parameters()
        self.p2.reset_parameters()
        self.p3.reset_parameters()
        self.p4.reset_parameters()
        xavier_uniform(self.classifier1)
        xavier_uniform(self.classifier2)
        xavier_uniform(self.classifier3)
        xavier_uniform(self.classifier22)
        xavier_uniform(self.classifier4)
        xavier_uniform(self.classifier8)
        xavier_uniform(self.classifier9)
        self.mlp_local8.reset_parameters(xavier_uniform)
        self.mlp_global8.reset_parameters(xavier_uniform)
        self.mlp_local9.reset_parameters(xavier_uniform)
        self.mlp_global9.reset_parameters(xavier_uniform)
        self.mlp_global22.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
def __init__(self,
             in_channels,
             out_channels,
             stem_channels=(128, 128),
             local_channels=(128, 256, 256),
             global_channels=(128, 128),
             dropout_prob=0.3,
             with_transform=False):
    """Build the PointNet2SSGCls sub-networks, the two heads and their classifiers.

    Args:
        in_channels (int): the number of input channels (only stored)
        out_channels (int): the number of output channels of the backbone
            classifiers; also determines the input width of the head stems
        stem_channels (tuple of int): channels of the head stem feature extractors
        local_channels (tuple of int): channels of the head local mlps
        global_channels (tuple of int): channels of the head global mlps;
            its last entry also sizes the linear classifiers
        dropout_prob (float): the probability to dropout
        with_transform (bool): passed to the head stems as ``with_transform``

    """
    super(PointNetCls, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    # Channel presets for the PointNet2SSGCls sub-networks: a wide one
    # (base width p1) and a narrow one (a quarter of it).
    p1 = 128
    stem_channels1 = ((p1, p1, int(2 * p1)),
                      (int(2 * p1), int(2 * p1), int(2 * p1)))
    p2 = int(p1 / 4)
    stem_channels2 = ((p2, p2, int(2 * p2)),
                      (int(2 * p2), int(2 * p2), int(2 * p2)))
    global_channels1 = (p1 * 2, p1)
    global_channels2 = (p2 * 2, p2)

    self.p1 = PointNet2SSGCls(sa_channels=stem_channels1,
                              global_channels=global_channels1)
    self.classifier1 = nn.Linear(global_channels[-1], out_channels,
                                 bias=True)

    self.p2 = PointNet2SSGCls(sa_channels=stem_channels1,
                              global_channels=global_channels1)
    self.classifier2 = nn.Linear(global_channels[-1], out_channels,
                                 bias=True)

    # narrow network plus the mlp/classifier pair named "22"
    self.p3 = PointNet2SSGCls(sa_channels=stem_channels2,
                              global_channels=global_channels2)
    self.classifier3 = nn.Linear(int(global_channels[-1] / 4),
                                 int(global_channels[-1] / 4), bias=True)
    self.mlp_global22 = MLP(int(global_channels[-1] / 4) * 2,
                            tuple(int(e / 4) for e in global_channels),
                            dropout_prob=dropout_prob)
    self.classifier22 = nn.Linear(int(global_channels[-1] / 4), 1,
                                  bias=True)

    self.p4 = PointNet2SSGCls(sa_channels=stem_channels1,
                              global_channels=global_channels1)
    self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True)

    # first head: stem input is out_channels + 3 channels wide
    self.stem8 = Stem(int(out_channels * 1) + 3, stem_channels,
                      with_transform=with_transform)
    self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels)
    self.mlp_global8 = MLP(local_channels[-1], global_channels,
                           dropout_prob=dropout_prob)
    self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True)

    # second head: stem input is 2 * out_channels + 3 channels wide
    self.stem9 = Stem(int(out_channels * 2) + 3, stem_channels,
                      with_transform=with_transform)
    self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels)
    self.mlp_global9 = MLP(local_channels[-1], global_channels,
                           dropout_prob=dropout_prob)
    self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True)

    self.sigmoid = nn.Sigmoid()

    self.reset_parameters()
class DGCNNPartSeg(nn.Module):
    """DGCNN for part segmentation

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification classes
        num_seg_classes (int): the number of segmentation classes
        edge_conv_channels (tuple): channels of each edge convolution block;
            every entry is either an int or a tuple/list of int
        local_channels (tuple of int): channels of the intermediate (local) mlp
        seg_channels (tuple of int): channels of the segmentation mlp
        k (int): the number of nearest neighbours for the edge feature extractor
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform the input.

    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 edge_conv_channels=((64, 64), (64, 64), 64),
                 local_channels=(1024, ),
                 seg_channels=(256, 256, 128),
                 k=20,
                 dropout_prob=0.4,
                 with_transform=True):
        super(DGCNNPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.with_transform = with_transform

        # input transform
        if self.with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        # Edge convolution blocks, chained: each block consumes the output
        # channels of the previous one.
        self.edge_convs = nn.ModuleList()
        block_out_channels = []
        block_in_channels = in_channels
        for block_channels in edge_conv_channels:
            if isinstance(block_channels, int):
                block_channels = [block_channels]
            else:
                assert isinstance(block_channels, (tuple, list))
            self.edge_convs.append(
                EdgeConvBlock(block_in_channels, block_channels, k))
            block_in_channels = block_channels[-1]
            block_out_channels.append(block_in_channels)

        # width of the embedding computed from the one hot class label
        LABEL_CHANNELS = 64
        self.mlp_label = Conv1d(self.num_classes, LABEL_CHANNELS, 1)

        self.mlp_local = SharedMLP(sum(block_out_channels), local_channels)

        # segmentation input = edge conv features + global feature + label embedding
        seg_in_channels = (sum(block_out_channels) + local_channels[-1]
                           + LABEL_CHANNELS)
        self.mlp_seg = SharedMLP(seg_in_channels, seg_channels[:-1],
                                 dropout_prob=dropout_prob)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1], num_seg_classes, 1,
                                   bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        """Compute per-point segmentation logits.

        Args:
            data_batch (dict): must contain ``'points'`` and ``'cls_label'``

        Returns:
            dict: ``'seg_logit'`` plus ``'trans_input'`` when the input
            transform is enabled.

        """
        x = data_batch['points']
        num_points = x.shape[2]
        end_points = {}

        # input transform
        if self.with_transform:
            trans_input = self.transform_input(x)
            x = torch.bmm(trans_input, x)
            end_points['trans_input'] = trans_input

        # EdgeConv: keep the output of every block
        block_outputs = []
        for edge_conv in self.edge_convs:
            x = edge_conv(x)
            block_outputs.append(x)

        # (batch_size, sum(inter_channels), num_points)
        inter_feature = torch.cat(block_outputs, dim=1)

        # global feature: max pool over points, then repeat for every point
        local_feature = self.mlp_local(inter_feature)
        global_feature, _ = torch.max(local_feature, 2)  # (batch_size, local_channels[-1])
        global_per_point = global_feature.unsqueeze(2).expand(
            -1, -1, num_points)

        # one hot class label, built without tracking gradients
        with torch.no_grad():
            cls_label = data_batch['cls_label']
            # (batch_size, num_classes)
            one_hot = cls_label.new_zeros(cls_label.size(0), self.num_classes)
            one_hot = one_hot.scatter(1, cls_label.unsqueeze(1), 1).float()
            one_hot_per_point = one_hot.unsqueeze(2).expand(-1, -1, num_points)
        label_feature = self.mlp_label(one_hot_per_point)

        # (batch_size, mlp_seg_in_channels, num_points)
        x = torch.cat((inter_feature, global_per_point, label_feature), dim=1)
        x = self.conv_seg(self.mlp_seg(x))

        preds = {'seg_logit': self.seg_logit(x)}
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        """Initialise edge convs, mlps and logit layer with xavier_uniform; set BN momentum to 0.01."""
        for block in self.edge_convs:
            block.reset_parameters(xavier_uniform)
        self.mlp_label.reset_parameters(xavier_uniform)
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        self.conv_seg.reset_parameters(xavier_uniform)
        xavier_uniform(self.seg_logit)
        set_bn(self, momentum=0.01)