def network_initialization(
    self, in_channel, out_channel, channels, embedding_channel, kernel_size, D=2,
):
    """Create the layers of a VGG-style sparse classifier.

    Five convolution stages (2, 2, 3, 3 and 3 conv blocks) are each paired
    with a kernel-3, stride-2 max pooling layer. They are followed by a
    global pooling layer and an MLP head.

    Args:
        in_channel: number of input feature channels of the first conv block.
        out_channel: output width of the classification head.
        channels: at least 13 channel widths; ``channels[i]`` is the output
            width of the i-th conv block in construction order.
        embedding_channel: not read by this method.
        kernel_size: kernel size passed to every conv block.
        D: spatial dimension of the pooling layers (the conv blocks are built
            by ``get_conv_block``, which is not given ``D`` here).
    """
    # Stage 1
    self.conv1_1 = self.get_conv_block(in_channel, channels[0], kernel_size=kernel_size, stride=1)
    self.conv1_2 = self.get_conv_block(channels[0], channels[1], kernel_size=kernel_size, stride=1)
    self.pool1 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

    # Stage 2
    self.conv2_1 = self.get_conv_block(channels[1], channels[2], kernel_size=kernel_size, stride=1)
    self.conv2_2 = self.get_conv_block(channels[2], channels[3], kernel_size=kernel_size, stride=1)
    self.pool2 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

    # Stage 3
    self.conv3_1 = self.get_conv_block(channels[3], channels[4], kernel_size=kernel_size, stride=1)
    self.conv3_2 = self.get_conv_block(channels[4], channels[5], kernel_size=kernel_size, stride=1)
    self.conv3_3 = self.get_conv_block(channels[5], channels[6], kernel_size=kernel_size, stride=1)
    self.pool3 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

    # Stage 4
    self.conv4_1 = self.get_conv_block(channels[6], channels[7], kernel_size=kernel_size, stride=1)
    self.conv4_2 = self.get_conv_block(channels[7], channels[8], kernel_size=kernel_size, stride=1)
    self.conv4_3 = self.get_conv_block(channels[8], channels[9], kernel_size=kernel_size, stride=1)
    self.pool4 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

    # Stage 5
    self.conv5_1 = self.get_conv_block(channels[9], channels[10], kernel_size=kernel_size, stride=1)
    self.conv5_2 = self.get_conv_block(channels[10], channels[11], kernel_size=kernel_size, stride=1)
    self.conv5_3 = self.get_conv_block(channels[11], channels[12], kernel_size=kernel_size, stride=1)
    self.pool5 = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

    self.global_pool = ME.MinkowskiGlobalPooling()

    # The head's input width follows the last conv block (channels[12]) and
    # its output width follows out_channel. These were previously hard-coded
    # to 512 and 2, which ignored both arguments.
    # NOTE(review): assumes forward() feeds the globally pooled conv5_3 output
    # into self.final — confirm.
    self.final = nn.Sequential(
        self.get_mlp_block(channels[12], 512),
        ME.MinkowskiDropout(),
        self.get_mlp_block(512, 512),
        self.get_mlp_block(512, out_channel),
    )
Ejemplo n.º 2
0
    def network_initialization(self, in_channels, out_channels, D):
        """Create the stem, four residual stages, the conv5 head and the classifier.

        ``self.inplanes`` starts at ``INIT_DIM`` and is read again after the
        residual stages (conv5 and the classifier use it). It is presumably
        advanced by ``_make_layer``, so the stages must be built in order.

        Args:
            in_channels: number of input feature channels.
            out_channels: output width of the final linear classifier.
            D: spatial dimension of the sparse convolutions and pooling.
        """
        self.inplanes = self.INIT_DIM
        stem_conv = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D
        )
        self.conv1 = nn.Sequential(
            stem_conv,
            ME.MinkowskiBatchNorm(self.inplanes),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D),
        )

        def residual_stage(idx):
            # Residual stage `idx` built from BLOCK/PLANES/LAYERS with stride 2.
            return self._make_layer(
                self.BLOCK, self.PLANES[idx], self.LAYERS[idx], stride=2
            )

        self.layer1 = residual_stage(0)
        self.layer2 = residual_stage(1)
        self.layer3 = residual_stage(2)
        self.layer4 = residual_stage(3)

        width = self.inplanes
        self.conv5 = nn.Sequential(
            ME.MinkowskiDropout(),
            ME.MinkowskiConvolution(width, width, kernel_size=3, stride=3, dimension=D),
            ME.MinkowskiBatchNorm(width),
            ME.MinkowskiGELU(),
        )

        self.glob_pool = ME.MinkowskiGlobalMaxPooling()
        self.final = ME.MinkowskiLinear(width, out_channels, bias=True)
Ejemplo n.º 3
0
    def network_initialization(
        self, in_channel, out_channel, channels, embedding_channel, kernel_size, D=3,
    ):
        """Create the layers of a fully-convolutional sparse classifier.

        Args:
            in_channel: number of input feature channels of ``mlp1``.
            out_channel: output width of the final linear layer.
            channels: at least 5 widths; ``channels[0]`` is the output of
                ``mlp1`` and ``channels[1..4]`` the outputs of ``conv1..conv4``.
            embedding_channel: output width of ``conv5``; its quarter and half
                are used as intermediate widths.
            kernel_size: kernel size of ``conv1..conv4``.
            D: spatial dimension of the max pooling layer.
        """
        self.mlp1 = self.get_mlp_block(in_channel, channels[0])

        def conv_stage(idx, stride):
            # Conv block mapping channels[idx - 1] -> channels[idx].
            return self.get_conv_block(
                channels[idx - 1], channels[idx], kernel_size=kernel_size, stride=stride,
            )

        self.conv1 = conv_stage(1, 1)
        self.conv2 = conv_stage(2, 2)
        self.conv3 = conv_stage(3, 2)
        self.conv4 = conv_stage(4, 2)

        # conv5's input width is the sum of the conv1..conv4 output widths
        # (presumably their concatenation in forward()).
        concat_width = sum(channels[1:5])
        quarter = embedding_channel // 4
        half = embedding_channel // 2
        self.conv5 = nn.Sequential(
            self.get_conv_block(concat_width, quarter, kernel_size=3, stride=2),
            self.get_conv_block(quarter, half, kernel_size=3, stride=2),
            self.get_conv_block(half, embedding_channel, kernel_size=3, stride=2),
        )

        self.pool = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()

        # Head input is 2 * embedding_channel (presumably max- and avg-pooled
        # embeddings concatenated).
        self.final = nn.Sequential(
            self.get_mlp_block(embedding_channel * 2, 512),
            ME.MinkowskiDropout(),
            self.get_mlp_block(512, 512),
            ME.MinkowskiLinear(512, out_channel, bias=True),
        )
Ejemplo n.º 4
0
    def __init__(self,
                 in_channel,
                 out_channel,
                 embedding_channel=1024,
                 dimension=3):
        """PointNet-style classifier built from per-point linear layers.

        Args:
            in_channel: number of input feature channels per point.
            out_channel: output width of the final linear layer.
            embedding_channel: width of the per-point embedding fed to the
                global max pooling.
            dimension: spatial dimension passed to ``ME.MinkowskiNetwork``.
        """
        ME.MinkowskiNetwork.__init__(self, dimension)

        def linear_bn_relu(n_in, n_out):
            # No bias on the linear layer: the following batch norm supplies
            # its own learnable shift.
            return nn.Sequential(
                ME.MinkowskiLinear(n_in, n_out, bias=False),
                ME.MinkowskiBatchNorm(n_out),
                ME.MinkowskiReLU(),
            )

        # The first layer previously hard-coded 3 input features and ignored
        # in_channel; it now honours the argument (identical for in_channel=3).
        self.conv1 = linear_bn_relu(in_channel, 64)
        self.conv2 = linear_bn_relu(64, 64)
        self.conv3 = linear_bn_relu(64, 64)
        self.conv4 = linear_bn_relu(64, 128)
        self.conv5 = linear_bn_relu(128, embedding_channel)
        self.max_pool = ME.MinkowskiGlobalMaxPooling()

        self.linear1 = linear_bn_relu(embedding_channel, 512)
        self.dp1 = ME.MinkowskiDropout()
        self.linear2 = ME.MinkowskiLinear(512, out_channel, bias=True)
Ejemplo n.º 5
0
    def network_initialization(self, in_channels, out_channels, D):
        """Create the encoder/decoder layers of this sparse U-Net.

        Encoder: a kernel-5 stem (``conv0p1s1``/``bn0``) and four kernel-2,
        stride-2 convolutions, each with a batch norm and a residual stage
        (``block1..block4``).

        Decoder: four kernel-2, stride-2 transposed convolutions, each with a
        batch norm and a residual stage (``block5..block8``). When
        ``self.skipconnections`` is true, each decoder stage's input width
        also counts the channels of the matching encoder output (presumably
        concatenated in ``forward``, which is not shown here).

        ``self.inplanes`` is the running channel count read by ``_make_layer``
        and by the next convolution. It is reset explicitly before each
        decoder stage.

        Args:
            in_channels: number of input feature channels.
            out_channels: output width of the final 1x1 convolution.
            D: spatial dimension.
        """
        # Stem. When skipconnections is on, block8's input width includes
        # INIT_DIM, i.e. this stem's output width.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D)

        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # Encoder stage 1
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)

        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0],
                                       self.LAYERS[0], leakiness=self.leakiness)

        # Encoder stage 2
        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)

        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1],
                                       self.LAYERS[1], leakiness=self.leakiness)

        # Encoder stage 3
        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)

        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2],
                                       self.LAYERS[2], leakiness=self.leakiness)

        # Encoder stage 4 (bottleneck)
        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3],
                                       self.LAYERS[3], leakiness=self.leakiness)

        # Decoder stage 1; optional skip from block3's output.
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D)
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])

        self.inplanes = self.PLANES[4]
        if self.skipconnections:
            self.inplanes += self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4],
                                       self.LAYERS[4], leakiness=self.leakiness)

        # Decoder stage 2; optional skip from block2's output.
        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D)
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])

        self.inplanes = self.PLANES[5]
        if self.skipconnections:
            self.inplanes += self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5],
                                       self.LAYERS[5], leakiness=self.leakiness)

        # Decoder stage 3; optional skip from block1's output.
        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D)
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])

        self.inplanes = self.PLANES[6]
        if self.skipconnections:
            self.inplanes += self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6],
                                       self.LAYERS[6], leakiness=self.leakiness)

        # Decoder stage 4; optional skip from the stem output.
        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D)
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])

        self.inplanes = self.PLANES[7]
        if self.skipconnections:
            self.inplanes += self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7],
                                       self.LAYERS[7], leakiness=self.leakiness)

        # block8 emits PLANES[7] * BLOCK.expansion channels (the same
        # expansion rule the skip widths above rely on). The previous
        # PLANES[7] was only correct when BLOCK.expansion == 1.
        self.final = ME.MinkowskiConvolution(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            has_bias=True,
            dimension=D)
        self.relu = ME.MinkowskiLeakyReLU(negative_slope=self.leakiness)
        self.dropout = ME.MinkowskiDropout(self.dropout_p)
Ejemplo n.º 6
0
    def __init__(self,
                 config,
                 in_channel,
                 out_channel,
                 final_dim=96,
                 dimension=3):
        """Construct a point-transformer style network over sparse tensors.

        Modules created, in order:
        - a two-layer stem (``stem1``, ``stem2``);
        - ``PTBlock0..PTBlock4`` at widths ``dims[0..4]``, with ``PTBlock_middle``
          at ``dims[4]``;
        - ``TDLayer1..TDLayer4`` mapping ``dims[i-1] -> dims[i]``;
        - ``TULayer5..TULayer8`` mapping back down to ``dims[3..0]``, with
          ``PTBlock5..PTBlock8`` at those widths;
        - ``final_conv`` (conv + dropout 0.4) and the linear classifier ``fc``.

        Args:
            config: not read in this constructor.
            in_channel: overwritten below with ``normal_channel + 3`` (= 6),
                so the value passed in is ignored.
            out_channel: output width of ``fc``.
            final_dim: stored in ``self.final_dim`` but later overwritten with
                32 before it is used, so the value passed in has no effect.
            dimension: spatial dimension passed to ``ME.MinkowskiNetwork``;
                the layers below hard-code ``dimension=3``.
        """

        ME.MinkowskiNetwork.__init__(
            self, dimension
        )  # The normal channel for Modelnet is 3, for scannet is 6, for scanobjnn is 0
        normal_channel = 3  # extra per-point feature channels (normals / RGB)

        self.CONV_TYPE = ConvType.SPATIAL_HYPERCUBE

        # Per-stage feature widths.
        self.dims = np.array([32, 64, 128, 256, 512])
        # Neighbourhood size per stage: 32 // 2 = 16 for every stage.
        self.neighbor_ks = np.array([32, 32, 32, 32, 32]) // 2

        # NOTE(review): overwritten with 32 before final_conv is built, so the
        # final_dim argument currently has no effect — confirm intent.
        self.final_dim = final_dim

        stem_dim = self.dims[0]
        in_channel = normal_channel + 3  # normal ch + xyz; replaces the in_channel argument
        self.normal_channel = normal_channel

        # pixel size 1
        self.stem1 = nn.Sequential(
            ME.MinkowskiConvolution(in_channel,
                                    stem_dim,
                                    kernel_size=1,
                                    dimension=3),
            ME.MinkowskiBatchNorm(stem_dim),
            ME.MinkowskiReLU(),
        )

        # With split_scene (hard-coded True here) stem2 keeps the resolution:
        # kernel 1, stride 1.
        split_scene = True
        if split_scene:
            self.stem2 = nn.Sequential(
                ME.MinkowskiConvolution(stem_dim,
                                        stem_dim,
                                        kernel_size=1,
                                        dimension=3,
                                        stride=1),
                ME.MinkowskiBatchNorm(stem_dim),
                ME.MinkowskiReLU(),
            )
        # Otherwise stem2 downsamples with a kernel-2, stride-2 conv (pixel
        # size 2). Unreachable while split_scene is hard-coded to True.
        else:
            self.stem2 = nn.Sequential(
                ME.MinkowskiConvolution(stem_dim,
                                        stem_dim,
                                        kernel_size=2,
                                        dimension=3,
                                        stride=2),
                ME.MinkowskiBatchNorm(stem_dim),
                ME.MinkowskiReLU(),
            )

        # Base value for PTBlock's `r` argument; scaled up for wider stages.
        base_r = 4

        self.PTBlock0 = PTBlock(in_dim=self.dims[0],
                                hidden_dim=self.dims[0],
                                n_sample=self.neighbor_ks[0],
                                skip_knn=False,
                                r=base_r)
        self.PTBlock1 = PTBlock(in_dim=self.dims[1],
                                hidden_dim=self.dims[1],
                                n_sample=self.neighbor_ks[1],
                                skip_knn=False,
                                r=base_r)
        self.PTBlock2 = PTBlock(in_dim=self.dims[2],
                                hidden_dim=self.dims[2],
                                n_sample=self.neighbor_ks[2],
                                skip_knn=False,
                                r=2 * base_r)
        self.PTBlock3 = PTBlock(in_dim=self.dims[3],
                                hidden_dim=self.dims[3],
                                n_sample=self.neighbor_ks[3],
                                skip_knn=False,
                                r=int(2 * base_r))

        # PTBlock4 and PTBlock_middle reuse neighbor_ks[3]; every entry is 16,
        # so this matches neighbor_ks[4].
        self.PTBlock4 = PTBlock(in_dim=self.dims[4],
                                hidden_dim=self.dims[4],
                                n_sample=self.neighbor_ks[3],
                                skip_knn=False,
                                r=int(4 * base_r))
        self.PTBlock_middle = PTBlock(in_dim=self.dims[4],
                                      hidden_dim=self.dims[4],
                                      n_sample=self.neighbor_ks[3],
                                      skip_knn=False,
                                      r=int(4 * base_r))

        self.PTBlock5 = PTBlock(in_dim=self.dims[3],
                                hidden_dim=self.dims[3],
                                n_sample=self.neighbor_ks[3],
                                skip_knn=False,
                                r=2 * base_r)  # width: dims[3] = 256
        self.PTBlock6 = PTBlock(in_dim=self.dims[2],
                                hidden_dim=self.dims[2],
                                n_sample=self.neighbor_ks[2],
                                skip_knn=False,
                                r=2 * base_r)  # width: dims[2] = 128
        self.PTBlock7 = PTBlock(in_dim=self.dims[1],
                                hidden_dim=self.dims[1],
                                n_sample=self.neighbor_ks[1],
                                skip_knn=False,
                                r=base_r)  # width: dims[1] = 64
        self.PTBlock8 = PTBlock(in_dim=self.dims[0],
                                hidden_dim=self.dims[0],
                                n_sample=self.neighbor_ks[1],
                                skip_knn=False,
                                r=base_r)  # width: dims[0] = 32

        # Alternative configurations kept for reference (not executed):
        # skip_knn=True variant with larger radii.
        # self.PTBlock0 = PTBlock(in_dim=self.dims[0], hidden_dim = self.dims[0], n_sample=self.neighbor_ks[0], skip_knn=True, r=base_r)
        # self.PTBlock1 = PTBlock(in_dim=self.dims[1], hidden_dim = self.dims[1], n_sample=self.neighbor_ks[1], skip_knn=True, r=2*base_r)
        # self.PTBlock2 = PTBlock(in_dim=self.dims[2],hidden_dim = self.dims[2], n_sample=self.neighbor_ks[2], skip_knn=True, r=2*base_r)
        # self.PTBlock3 = PTBlock(in_dim=self.dims[3], hidden_dim = self.dims[3], n_sample=self.neighbor_ks[3], skip_knn=True, r=int(4*base_r))
        # self.PTBlock4 = PTBlock(in_dim=self.dims[4], hidden_dim = self.dims[4], n_sample=self.neighbor_ks[3], skip_knn=True, r=int(16*base_r))
        # self.PTBlock_middle = PTBlock(in_dim=self.dims[4], hidden_dim = self.dims[4], n_sample=self.neighbor_ks[3], skip_knn=True, r=int(16*base_r))
        # self.PTBlock5 = PTBlock(in_dim=self.dims[3], hidden_dim = self.dims[3], n_sample=self.neighbor_ks[3], skip_knn=True, r=4*base_r) # out: 256
        # self.PTBlock6 = PTBlock(in_dim=self.dims[2], hidden_dim=self.dims[2], n_sample=self.neighbor_ks[2], skip_knn=True, r=2*base_r) # out: 128
        # self.PTBlock7 = PTBlock(in_dim=self.dims[1], hidden_dim=self.dims[1], n_sample=self.neighbor_ks[1], skip_knn=True, r=2*base_r) # out: 64
        # self.PTBlock8 = PTBlock(in_dim=self.dims[0], hidden_dim=self.dims[0], n_sample=self.neighbor_ks[1], skip_knn=True, r=base_r) # out: 64

        # Convolutional (BasicBlock) variant of the same stages.
        # self.PTBlock1 = self._make_layer(block=BasicBlock, inplanes=self.dims[0], planes=self.dims[0], num_blocks=2)
        # self.PTBlock2 = self._make_layer(block=BasicBlock, inplanes=self.dims[1], planes=self.dims[1], num_blocks=2)
        # self.PTBlock3 = self._make_layer(block=BasicBlock, inplanes=self.dims[2], planes=self.dims[2], num_blocks=2)
        # self.PTBlock4 = self._make_layer(block=BasicBlock, inplanes=self.dims[3], planes=self.dims[3], num_blocks=2)
        # self.PTBlock_middle = self._make_layer(block=BasicBlock, inplanes=self.dims[4], planes=self.dims[4], num_blocks=2)
        # self.PTBlock5 = self._make_layer(block=BasicBlock, inplanes=self.dims[3], planes=self.dims[3], num_blocks=2)
        # self.PTBlock6 = self._make_layer(block=BasicBlock, inplanes=self.dims[2], planes=self.dims[2], num_blocks=2)
        # self.PTBlock7 = self._make_layer(block=BasicBlock, inplanes=self.dims[1], planes=self.dims[1], num_blocks=2)
        # self.PTBlock8 = self._make_layer(block=BasicBlock, inplanes=self.dims[0], planes=self.dims[0], num_blocks=2)

        TD_kernel_size = [4, 8, 12,
                          16]  # only applied when using non-PointTR-like TDLayers
        # TDLayer{i} maps dims[i-1] -> dims[i] (presumably with spatial downsampling).
        # pixel size 2
        self.TDLayer1 = TDLayer(input_dim=self.dims[0],
                                out_dim=self.dims[1],
                                kernel_size=TD_kernel_size[0])  # strided conv
        # pixel size 4
        self.TDLayer2 = TDLayer(input_dim=self.dims[1],
                                out_dim=self.dims[2],
                                kernel_size=TD_kernel_size[1])
        # pixel size 8
        self.TDLayer3 = TDLayer(input_dim=self.dims[2],
                                out_dim=self.dims[3],
                                kernel_size=TD_kernel_size[2])
        # pixel size 16: PTBlock4
        self.TDLayer4 = TDLayer(input_dim=self.dims[3],
                                out_dim=self.dims[4],
                                kernel_size=TD_kernel_size[3])

        # 1x1 convolution at the widest stage (dims[4] -> dims[4]).
        self.middle_linear = ME.MinkowskiConvolution(self.dims[4],
                                                     self.dims[4],
                                                     kernel_size=1,
                                                     dimension=3)

        # TULayer{n} takes input_a at dims[i+1] and input_b at dims[i]
        # (presumably the coarser feature and the skip feature) -> dims[i].
        # pixel size 8
        self.TULayer5 = TULayer(
            input_a_dim=self.dims[4],
            input_b_dim=self.dims[3],
            out_dim=self.dims[3])  # out_dim: dims[3] = 256

        # pixel size 4
        self.TULayer6 = TULayer(input_a_dim=self.dims[3],
                                input_b_dim=self.dims[2],
                                out_dim=self.dims[2])  # out_dim: dims[2] = 128

        # pixel size 2
        self.TULayer7 = TULayer(input_a_dim=self.dims[2],
                                input_b_dim=self.dims[1],
                                out_dim=self.dims[1])  # out_dim: dims[1] = 64

        self.TULayer8 = TULayer(input_a_dim=self.dims[1],
                                input_b_dim=self.dims[0],
                                out_dim=self.dims[0])  # out_dim: dims[0] = 32

        # NOTE(review): overrides the final_dim argument — confirm intent.
        self.final_dim = 32
        if split_scene:
            # Same resolution: 1x1 conv dims[0] -> final_dim.
            self.final_conv = nn.Sequential(
                ME.MinkowskiConvolution(self.dims[0],
                                        self.final_dim,
                                        kernel_size=1,
                                        stride=1,
                                        dimension=3),
                ME.MinkowskiDropout(0.4),
            )
        else:
            # Undo stem2's stride-2 downsampling (unreachable while split_scene is True).
            self.final_conv = nn.Sequential(
                ME.MinkowskiConvolutionTranspose(self.dims[0],
                                                 self.final_dim,
                                                 kernel_size=2,
                                                 stride=2,
                                                 dimension=3),
                ME.MinkowskiDropout(0.4),
            )
        self.fc = ME.MinkowskiLinear(self.final_dim, out_channel)
    def __init__(self,
                 in_channel,
                 out_channel,
                 num_class,
                 embedding_channel=1024,
                 dimension=3):
        """Construct a point-transformer style classifier over sparse tensors.

        Modules created, in order:
        - an input MLP;
        - ``PTBlock0``;
        - four (``TDLayer``, ``PTBlock``) pairs;
        - ``middle_linear`` and ``PTBlock_middle``;
        - four (``TULayer``, ``PTBlock``) pairs;
        - the ``fc`` classifier and a global average pooling layer.

        Args:
            in_channel: input feature width of ``input_mlp``.
            out_channel: not read in this constructor.
            num_class: output width of ``fc``.
            embedding_channel: not read in this constructor.
            dimension: spatial dimension passed to ``ME.MinkowskiNetwork``;
                the convolutions here hard-code ``dimension=3``.
        """
        ME.MinkowskiNetwork.__init__(self, dimension)
        # Normal-channel width: 3 for ModelNet, 6 for ScanNet, 0 for ScanObjectNN.
        self.normal_channel = 3

        stem_width = 32
        self.input_mlp = nn.Sequential(
            ME.MinkowskiConvolution(in_channel, stem_width, kernel_size=1, dimension=3),
            ME.MinkowskiBatchNorm(stem_width),
            ME.MinkowskiReLU(),
            ME.MinkowskiConvolution(stem_width, stem_width, kernel_size=1, dimension=3),
            ME.MinkowskiBatchNorm(stem_width),
        )

        self.in_dims = [32, 64, 128, 256]
        self.out_dims = [64, 128, 256, 512]
        self.neighbor_ks = [16, 32, 64, 16, 16]
        ks = self.neighbor_ks

        def down(i, n_sample):
            # TDLayer in_dims[i] -> out_dims[i], then a PTBlock at out_dims[i].
            transition = TDLayer(input_dim=self.in_dims[i], out_dim=self.out_dims[i])
            block = PTBlock(in_dim=self.out_dims[i], n_sample=n_sample)
            return transition, block

        def up(i, n_sample):
            # TULayer out_dims[i] -> in_dims[i], then a PTBlock at in_dims[i].
            transition = TULayer(input_dim=self.out_dims[i], out_dim=self.in_dims[i])
            block = PTBlock(in_dim=self.in_dims[i], n_sample=n_sample)
            return transition, block

        self.PTBlock0 = PTBlock(in_dim=self.in_dims[0], n_sample=ks[0])

        self.TDLayer1, self.PTBlock1 = down(0, ks[1])
        self.TDLayer2, self.PTBlock2 = down(1, ks[1])
        self.TDLayer3, self.PTBlock3 = down(2, ks[2])
        self.TDLayer4, self.PTBlock4 = down(3, ks[4])

        widest = self.out_dims[3]
        self.middle_linear = ME.MinkowskiConvolution(widest, widest, kernel_size=1, dimension=3)
        self.PTBlock_middle = PTBlock(in_dim=widest, n_sample=ks[4])

        self.TULayer5, self.PTBlock5 = up(3, ks[4])
        self.TULayer6, self.PTBlock6 = up(2, ks[3])
        self.TULayer7, self.PTBlock7 = up(1, ks[2])
        self.TULayer8, self.PTBlock8 = up(0, ks[1])

        self.fc = nn.Sequential(
            ME.MinkowskiLinear(widest, 32),
            ME.MinkowskiDropout(0.4),
            ME.MinkowskiLinear(32, num_class),
        )

        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()