Code example #1
# Assumed context for all examples below (not shown in the listing):
# torch / torch.nn plus the project's helper blocks such as EdgeConvBlock,
# SharedMLP, MLP, Stem, Conv1d, PointNetSAModule, PointNetSAModuleMSG,
# xavier_uniform and set_bn.
import torch
import torch.nn as nn


class TNet(nn.Module):
    """Transformation Network for DGCNN

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for the edge feature extractor

    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 conv_channels=(64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.edge_conv = EdgeConvBlock(in_channels, conv_channels, k)
        self.mlp_local = SharedMLP(conv_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1], global_channels)
        self.linear = nn.Linear(global_channels[-1],
                                self.in_channels * self.out_channels,
                                bias=True)

        self.reset_parameters()

    def forward(self, x):
        # input x: (batch_size, in_channels, num_points)
        x = self.edge_conv(x)  # (batch_size, conv_channels[-1], num_points)
        x = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        x, _ = torch.max(x, 2)
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        I = torch.eye(self.out_channels,
                      self.in_channels,
                      dtype=x.dtype,
                      device=x.device)
        x = x.add(I)  # broadcast add over the batch dimension
        return x

    def reset_parameters(self, init_fn=xavier_uniform):
        self.edge_conv.reset_parameters(init_fn)
        self.mlp_local.reset_parameters(init_fn)
        self.mlp_global.reset_parameters(init_fn)
        # Zero-initialize the linear layer so that the predicted transform
        # starts as the identity (the identity matrix is added in forward).
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
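
A minimal usage sketch (hypothetical shapes; assumes the imports above): TNet maps a point cloud to a per-batch alignment matrix, which is typically applied with torch.bmm.

tnet = TNet(in_channels=3, out_channels=3, k=20)
points = torch.rand(8, 3, 1024)     # (batch_size, in_channels, num_points)
trans = tnet(points)                # (batch_size, 3, 3) alignment matrices
aligned = torch.bmm(trans, points)  # apply the learned transform per batch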
Code example #2
class TNet(nn.Module):
    """Transformation Network. The structure is similar with PointNet"""
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 local_channels=(64, 128, 1024),
                 global_channels=(512, 256)):
        super(TNet, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # local features
        self.mlp_local = SharedMLP(in_channels, local_channels)
        # global features
        self.mlp_global = MLP(local_channels[-1], global_channels)
        # linear output
        self.linear = nn.Linear(global_channels[-1],
                                in_channels * out_channels,
                                bias=True)

        self.reset_parameters()

    def forward(self, x):
        # x: (batch_size, in_channels, num_points)
        x = self.mlp_local(x)  # (batch_size, local_channels[-1], num_points)
        x, _ = torch.max(x, 2)  # (batch_size, local_channels[-1])
        x = self.mlp_global(x)
        x = self.linear(x)
        x = x.view(-1, self.out_channels, self.in_channels)
        I = torch.eye(self.out_channels,
                      self.in_channels,
                      dtype=x.dtype,
                      device=x.device)
        x = x.add(I)  # broadcast add, (batch_size, out_channels, in_channels)
        return x

    def reset_parameters(self, init_fn=xavier_uniform):
        # Default initialization in original implementation
        self.mlp_local.reset_parameters(init_fn)
        self.mlp_global.reset_parameters(init_fn)
        # Initialize linear transform to 0
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
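
Because the final linear layer is zero-initialized and the identity matrix is added in forward, a freshly constructed TNet returns the identity transform. A quick sanity check (hypothetical sketch, assuming the helper MLP blocks behave as standard MLPs):

tnet = TNet(in_channels=3, out_channels=3)
out = tnet(torch.rand(4, 3, 1024))
assert torch.allclose(out, torch.eye(3).expand(4, 3, 3))  # identity at init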
Code example #3
class PointNet2SSGCls(nn.Module):
    """PointNet2 with single-scale grouping for classification'

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of semantic classes to predict over
        num_centroids (tuple of int): the numbers of centroids to sample in each set abstraction module
        radius (tuple of float): a tuple of radius to query neighbours in each set abstraction module
        num_neighbours (tuple of int): the numbers of neighbours to query for each centroid
        sa_channels (tuple of tuple of int): the numbers of channels within each set abstraction module
        global_channels (tuple of int): the numbers of channels to extract global features
        dropout_prob (float): the probability to dropout input features
        use_xyz (bool): whether to use the xyz position of each point as a feature

    Notes:
        1. num_centroids == -1: use all points; num_centroids == 0: use the origin.
        2. radius and num_neighbours must be both positive or both -1, i.e. radius * num_neighbours > 0.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128, 0),
                 radius=(0.2, 0.4, -1.0),
                 num_neighbours=(32, 64, -1),
                 sa_channels=((64, 64, 128), (128, 128, 256), (256, 512, 1024)),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        super(PointNet2SSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check
        num_sa_layers = len(num_centroids)
        assert len(radius) == num_sa_layers
        assert len(num_neighbours) == num_sa_layers
        assert len(sa_channels) == num_sa_layers

        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(num_sa_layers):
            sa_module = PointNetSAModule(in_channels=feature_channels,
                                         mlp_channels=sa_channels[ind],
                                         num_centroids=num_centroids[ind],
                                         radius=radius[ind],
                                         num_neighbours=num_neighbours[ind],
                                         use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            feature_channels = sa_channels[ind][-1]

        self.mlp_global = MLP(feature_channels,
                              global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1],
                                    out_channels,
                                    bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        points = data_batch['points']
        end_points = {}

        # torch.Tensor.narrow shares memory with the input tensor
        xyz = points.narrow(1, 0, 3)  # equivalent to points[:, 0:3, :]
        if points.size(1) > 3:
            feature = points.narrow(1, 3, points.size(1) - 3)
        else:
            feature = None

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        x, max_indices = torch.max(feature, 2)
        end_points['key_point_indices'] = max_indices
        x = self.mlp_global(x)

        cls_logit = self.classifier(x)

        preds = {'cls_logit': cls_logit}
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        self.mlp_global.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        set_bn(self, momentum=0.01)
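
A minimal usage sketch (hypothetical batch; assumes PointNetSAModule and the other helpers are importable). The model consumes a dict holding a 'points' tensor and returns classification logits:

model = PointNet2SSGCls(in_channels=3, out_channels=40)  # e.g. ModelNet40
data_batch = {'points': torch.rand(8, 3, 1024)}          # xyz only, no extra features
preds = model(data_batch)
# preds['cls_logit']: (8, 40); preds['key_point_indices']: per-channel argmax indices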
Code example #4
class PointNetCls(nn.Module):
    """PointNet classification model

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(16, 32),
                 local_channels=(32, 32),
                 global_channels=(32, 32),
                 dropout_prob=0.5,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = Stem(in_channels,
                         stem_channels,
                         with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1],
                                    out_channels,
                                    bias=True)
        self.classifier2 = nn.Linear(global_channels[-1], 2, bias=True)

        self.reset_parameters()

    def forward(self, points):
        x = points

        # stem
        x, end_points = self.stem(x)
        # mlp for local features
        x = self.mlp_local(x)
        # max pool over points
        x, max_indices = torch.max(x, 2)
        end_points['key_point_indices'] = max_indices
        # mlp for global features
        x = self.mlp_global(x)

        x1 = self.classifier(x)
        x2 = self.classifier2(x)
        preds = {
            'cls_logit': x1,
            'node_logit': x2,
        }
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        # default initialization in original implementation
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_global.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        xavier_uniform(self.classifier2)
        # set batch normalization momentum to 0.01 by default
        set_bn(self, momentum=0.01)
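
A minimal usage sketch (hypothetical shapes). Unlike the other examples, this variant takes the raw points tensor rather than a dict, and a second head (classifier2) emits binary node logits:

model = PointNetCls(in_channels=3, out_channels=40)
preds = model(torch.rand(8, 3, 1024))
# preds['cls_logit']: (8, 40); preds['node_logit']: (8, 2)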
Code example #5
class DGCNNCls(nn.Module):
    """DGCNN for classification

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        edge_conv_channels (tuple of int): the numbers of channels of edge convolution layers
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        k (int): the number of nearest neighbours for the edge feature extractor
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_conv_channels=(64, 64, 64, 128),
                 local_channels=(1024, ),
                 global_channels=(512, 256),
                 k=20,
                 dropout_prob=0.5,
                 with_transform=True):
        super(DGCNNCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.with_transform = with_transform

        # input transform
        if self.with_transform:
            self.transform_input = TNet(in_channels, in_channels, k=k)

        self.edge_convs = nn.ModuleList()
        inter_channels = []
        for conv_channels in edge_conv_channels:
            if isinstance(conv_channels, int):
                conv_channels = [conv_channels]
            else:
                assert isinstance(conv_channels, (tuple, list))
            self.edge_convs.append(EdgeConvBlock(in_channels, conv_channels, k))
            inter_channels.append(conv_channels[-1])
            in_channels = conv_channels[-1]
        self.mlp_local = SharedMLP(sum(inter_channels), local_channels)
        self.mlp_global = MLP(local_channels[-1],
                              global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1],
                                    self.out_channels,
                                    bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        x = data_batch['points']
        end_points = {}

        # input transform
        if self.with_transform:
            trans_input = self.transform_input(x)
            x = torch.bmm(trans_input, x)
            end_points['trans_input'] = trans_input

        # EdgeConv
        features = []
        for edge_conv in self.edge_convs:
            x = edge_conv(x)
            features.append(x)

        x = torch.cat(features, dim=1)

        x = self.mlp_local(x)
        x, max_indices = torch.max(x, 2)
        end_points['key_point_indices'] = max_indices
        x = self.mlp_global(x)
        x = self.classifier(x)
        preds = {
            'cls_logit': x,
        }
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        for edge_conv in self.edge_convs:
            edge_conv.reset_parameters(xavier_uniform)
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_global.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        set_bn(self, momentum=0.01)
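
A minimal usage sketch (hypothetical batch; assumes EdgeConvBlock and friends are importable):

model = DGCNNCls(in_channels=3, out_channels=40)
preds = model({'points': torch.rand(8, 3, 1024)})
# preds['cls_logit']: (8, 40); preds['trans_input']: (8, 3, 3) since with_transform=True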
Code example #6
File: pn2.py  Project: hyzcn/Learning-to-Group-1
class PointNetCls(nn.Module):
    """PointNet classification model

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of output channels
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        global_channels (tuple of int): the numbers of channels in global mlp
        dropout_prob (float): the probability to dropout
        with_transform (bool): whether to use TNet to transform features.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stem_channels=(128, 128),
                 local_channels=(128, 256, 256),
                 global_channels=(128, 128),
                 dropout_prob=0.3,
                 with_transform=False):
        super(PointNetCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        p1 = 128
        stem_channels1 = ((p1, p1, 2 * p1), (2 * p1, 2 * p1, 2 * p1))
        p2 = p1 // 4
        stem_channels2 = ((p2, p2, 2 * p2), (2 * p2, 2 * p2, 2 * p2))
        p3 = p1 // 8
        stem_channels3 = ((p3, p3, 2 * p3), (2 * p3, 2 * p3, 2 * p3))
        global_channels1 = (p1 * 2, p1)
        global_channels2 = (p2 * 2, p2)
        global_channels3 = (p3 * 2, p3)

        self.p1 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier1 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p2 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier2 = nn.Linear(global_channels[-1],
                                     out_channels,
                                     bias=True)

        self.p3 = PointNet2SSGCls(sa_channels=stem_channels2,
                                  global_channels=global_channels2)
        self.classifier3 = nn.Linear(global_channels[-1] // 4,
                                     global_channels[-1] // 4,
                                     bias=True)

        self.mlp_global22 = MLP(global_channels[-1] // 4 * 2,
                                tuple(e // 4 for e in global_channels),
                                dropout_prob=dropout_prob)
        self.classifier22 = nn.Linear(global_channels[-1] // 4,
                                      1,
                                      bias=True)

        self.p4 = PointNet2SSGCls(sa_channels=stem_channels1,
                                  global_channels=global_channels1)
        self.classifier4 = nn.Linear(global_channels[-1], 1, bias=True)

        self.stem8 = Stem(out_channels + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local8 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global8 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier8 = nn.Linear(global_channels[-1], 2, bias=True)

        self.stem9 = Stem(out_channels * 2 + 3,
                          stem_channels,
                          with_transform=with_transform)
        self.mlp_local9 = SharedMLP(stem_channels[-1], local_channels)
        self.mlp_global9 = MLP(local_channels[-1],
                               global_channels,
                               dropout_prob=dropout_prob)
        self.classifier9 = nn.Linear(global_channels[-1], 2, bias=True)

        self.sigmoid = nn.Sigmoid()
        self.reset_parameters()

    def forward(self, x, infer_type, y=None):

        if infer_type == 'backbone':
            x = self.p1(x)
            x = self.classifier1(x)
        elif infer_type == 'backbone2':
            x = self.p2(x)
            x = self.classifier2(x)
        elif infer_type == 'policy':
            x = self.p3(x)
            x = self.classifier3(x)
        elif infer_type == 'policy_head':
            x = self.mlp_global22(x)
            x = self.classifier22(x)
        elif infer_type == 'purity':
            x = self.p4(x)
            x = self.classifier4(x)
            x = self.sigmoid(x)
        elif infer_type == 'head':
            x, end_points = self.stem8(x)
            x = self.mlp_local8(x)
            x, max_indices = torch.max(x, 2)
            end_points['key_point_indices'] = max_indices
            x = self.mlp_global8(x)
            x = self.classifier8(x)
        elif infer_type == 'head2':
            x, end_points = self.stem9(x)
            x = self.mlp_local9(x)
            x, max_indices = torch.max(x, 2)
            end_points['key_point_indices'] = max_indices
            x = self.mlp_global9(x)
            x = self.classifier9(x)
        else:
            raise ValueError('unknown infer_type: {}'.format(infer_type))

        return x

    def reset_parameters(self):
        # default initialization in original implementation
        self.p1.reset_parameters()
        self.p2.reset_parameters()
        self.p3.reset_parameters()
        self.p4.reset_parameters()
        xavier_uniform(self.classifier1)
        xavier_uniform(self.classifier2)
        xavier_uniform(self.classifier3)
        xavier_uniform(self.classifier22)
        xavier_uniform(self.classifier4)
        xavier_uniform(self.classifier8)
        xavier_uniform(self.classifier9)
        self.mlp_local8.reset_parameters(xavier_uniform)
        self.mlp_global8.reset_parameters(xavier_uniform)
        self.mlp_local9.reset_parameters(xavier_uniform)
        self.mlp_global9.reset_parameters(xavier_uniform)
        self.mlp_global22.reset_parameters(xavier_uniform)
        set_bn(self, momentum=0.01)
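
A minimal sketch for the self-contained 'head' branch (hypothetical shapes; assumes this project's PointNet2SSGCls provides defaults for in_channels/out_channels, since none are passed in __init__):

num_classes = 10
model = PointNetCls(in_channels=3, out_channels=num_classes)
x = torch.rand(4, num_classes + 3, 1024)  # stem8 expects out_channels + 3 input channels
logits = model(x, infer_type='head')      # (4, 2) binary logits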
Code example #7
class PointNetPartSeg(nn.Module):
    """PointNet for part segmentation

    Args:
        in_channels (int): the number of input channels
        num_classes (int): the number of classification classes
        num_seg_classes (int): the number of segmentation classes
        stem_channels (tuple of int): the numbers of channels in stem feature extractor
        local_channels (tuple of int): the numbers of channels in local mlp
        cls_channels (tuple of int): the numbers of channels in classification mlp
        seg_channels (tuple of int): the numbers of channels in segmentation mlp
        dropout_prob_cls (float): the probability to dropout in classification mlp
        dropout_prob_seg (float): the probability to dropout in segmentation mlp
        with_transform (bool): whether to use TNet to transform features.
        use_one_hot (bool): whether to use a one-hot vector of class labels.

    References:
        https://github.com/charlesq34/pointnet/blob/master/part_seg/pointnet_part_seg.py

    """
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_seg_classes,
                 stem_channels=(64, 128, 128),
                 local_channels=(512, 2048),
                 cls_channels=(256, 256),
                 seg_channels=(256, 256, 128),
                 dropout_prob_cls=0.3,
                 dropout_prob_seg=0.2,
                 with_transform=True,
                 use_one_hot=True):
        super(PointNetPartSeg, self).__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_seg_classes = num_seg_classes
        self.use_one_hot = use_one_hot

        # stem
        self.stem = Stem(in_channels,
                         stem_channels,
                         with_transform=with_transform)
        self.mlp_local = SharedMLP(stem_channels[-1], local_channels)

        # part segmentation
        # Notice that the original repo concatenates the global feature, the
        # one-hot class embedding, stem features and local features, whereas
        # the paper does not use the last local feature. Here, we follow the
        # released repo.
        in_channels_seg = sum(stem_channels) + sum(local_channels) + local_channels[-1]
        if self.use_one_hot:
            in_channels_seg += num_classes
        self.mlp_seg = SharedMLP(in_channels_seg,
                                 seg_channels[:-1],
                                 dropout_prob=dropout_prob_seg)
        self.conv_seg = Conv1d(seg_channels[-2], seg_channels[-1], 1)
        self.seg_logit = nn.Conv1d(seg_channels[-1],
                                   num_seg_classes,
                                   1,
                                   bias=True)

        # classification (optional)
        if len(cls_channels) > 0:
            # Notice that we apply dropout to each classification mlp.
            self.mlp_cls = MLP(local_channels[-1],
                               cls_channels,
                               dropout_prob=dropout_prob_cls)
            self.cls_logit = nn.Linear(cls_channels[-1],
                                       num_classes,
                                       bias=True)
        else:
            self.mlp_cls = None
            self.cls_logit = None

        self.reset_parameters()

    def forward(self, data_batch):
        x = data_batch['points']
        num_points = x.shape[2]
        end_points = {}

        # stem
        stem_feature, end_points_stem = self.stem(x)
        stem_features = end_points_stem.pop('stem_features')
        end_points.update(end_points_stem)

        # mlp for local features
        local_features = []
        x = stem_feature
        for mlp in self.mlp_local:
            x = mlp(x)
            local_features.append(x)

        # max pool over points
        global_feature, max_indices = torch.max(x, 2)  # (batch_size, local_channels[-1])
        # end_points['key_point_indices'] = max_indices

        # segmentation
        global_feature_expand = global_feature.unsqueeze(2).expand(
            -1, -1, num_points)
        seg_features = stem_features + local_features + [global_feature_expand]
        if self.use_one_hot:
            with torch.no_grad():
                cls_label = data_batch['cls_label']
                one_hot = cls_label.new_zeros(cls_label.size(0), self.num_classes)
                one_hot = one_hot.scatter(1, cls_label.unsqueeze(1), 1).float()
                # one_hot: (batch_size, num_classes)
                one_hot_expand = one_hot.unsqueeze(2).expand(-1, -1, num_points)
            seg_features.append(one_hot_expand)

        x = torch.cat(seg_features, dim=1)
        x = self.mlp_seg(x)
        x = self.conv_seg(x)
        seg_logit = self.seg_logit(x)

        preds = {'seg_logit': seg_logit}
        preds.update(end_points)

        # classification (optional)
        if self.cls_logit is not None:
            x = self.mlp_cls(global_feature)
            cls_logit = self.cls_logit(x)
            preds['cls_logit'] = cls_logit

        return preds

    def reset_parameters(self):
        # default initialization
        self.mlp_local.reset_parameters(xavier_uniform)
        self.mlp_seg.reset_parameters(xavier_uniform)
        self.conv_seg.reset_parameters(xavier_uniform)
        if self.cls_logit is not None:
            self.mlp_cls.reset_parameters(xavier_uniform)
            xavier_uniform(self.cls_logit)
        xavier_uniform(self.seg_logit)
        # set batch normalization momentum to 0.01 by default
        set_bn(self, momentum=0.01)
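
A minimal usage sketch (hypothetical ShapeNetPart-like setting: 16 object categories, 50 part labels). The one-hot branch additionally expects per-shape category labels:

model = PointNetPartSeg(in_channels=3, num_classes=16, num_seg_classes=50)
data_batch = {
    'points': torch.rand(8, 3, 2048),
    'cls_label': torch.randint(0, 16, (8,)),  # object category per shape
}
preds = model(data_batch)
# preds['seg_logit']: (8, 50, 2048); preds['cls_logit']: (8, 16)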
Code example #8
class PointNet2MSGCls(nn.Module):
    """PointNet2 with multi-scale grouping for classification"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_centroids=(512, 128, 0),
                 radius_list=((0.1, 0.2, 0.4), (0.2, 0.4, 0.8), -1.0),
                 num_neighbours_list=((16, 32, 128), (32, 64, 128), -1),
                 sa_channels_list=(
                     ((32, 32, 64), (64, 64, 128), (64, 96, 128)),
                     ((64, 64, 128), (128, 128, 256), (128, 128, 256)),
                     (256, 512, 1024),
                 ),
                 global_channels=(512, 256),
                 dropout_prob=0.5,
                 use_xyz=True):
        super(PointNet2MSGCls, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_xyz = use_xyz

        # sanity check
        num_sa_layers = len(num_centroids)
        assert len(radius_list) == num_sa_layers
        assert len(num_neighbours_list) == num_sa_layers
        assert len(sa_channels_list) == num_sa_layers

        feature_channels = in_channels - 3
        self.sa_modules = nn.ModuleList()
        for ind in range(num_sa_layers - 1):
            sa_module = PointNetSAModuleMSG(
                in_channels=feature_channels,
                mlp_channels_list=sa_channels_list[ind],
                num_centroids=num_centroids[ind],
                radius_list=radius_list[ind],
                num_neighbours_list=num_neighbours_list[ind],
                use_xyz=use_xyz)
            self.sa_modules.append(sa_module)
            feature_channels = sa_module.out_channels

        sa_module = PointNetSAModule(in_channels=feature_channels,
                                     mlp_channels=sa_channels_list[-1],
                                     num_centroids=num_centroids[-1],
                                     radius=radius_list[-1],
                                     num_neighbours=num_neighbours_list[-1],
                                     use_xyz=use_xyz)
        self.sa_modules.append(sa_module)

        self.mlp_global = MLP(sa_channels_list[-1][-1],
                              global_channels,
                              dropout_prob=dropout_prob)
        self.classifier = nn.Linear(global_channels[-1],
                                    out_channels,
                                    bias=True)

        self.reset_parameters()

    def forward(self, data_batch):
        point = data_batch['points']
        end_points = {}

        # torch.Tensor.narrow shares memory with the input tensor
        xyz = point.narrow(1, 0, 3)
        if point.size(1) > 3:
            feature = point.narrow(1, 3, point.size(1) - 3)
        else:
            feature = None

        for sa_module in self.sa_modules:
            xyz, feature = sa_module(xyz, feature)

        x, max_indices = torch.max(feature, 2)
        end_points['key_point_indices'] = max_indices
        x = self.mlp_global(x)

        cls_logit = self.classifier(x)

        preds = {'cls_logit': cls_logit}
        preds.update(end_points)

        return preds

    def reset_parameters(self):
        for sa_module in self.sa_modules:
            sa_module.reset_parameters(xavier_uniform)
        self.mlp_global.reset_parameters(xavier_uniform)
        xavier_uniform(self.classifier)
        set_bn(self, momentum=0.01)
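
Usage mirrors the single-scale variant above (hypothetical batch):

model = PointNet2MSGCls(in_channels=3, out_channels=40)
preds = model({'points': torch.rand(8, 3, 1024)})
# preds['cls_logit']: (8, 40)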