Ejemplo n.º 1
0
    def __init__(self, resolution, in_nchannel=512):
        """Build a sparse 3D encoder/decoder network.

        The encoder starts with a stride-1 block taking a single input
        channel (``enc_block_s1``), followed by six stride-2 downsampling
        blocks (block names encode the tensor strides: s1 -> s2 -> ... -> s64).
        The decoder mirrors this with six stride-2 generative transposed
        convolution blocks (s64 -> s32 -> ... -> s1); each decoder block has a
        1x1x1 convolution head with a single output channel
        (``dec_s<stride>_cls``). ``self.pruning`` is an ``ME.MinkowskiPruning``
        layer, presumably applied between decoder stages by ``forward``.

        Channel widths come from the class attributes ``ENC_CHANNELS`` and
        ``DEC_CHANNELS``; indices 0..6 of each are used.

        Args:
            resolution: Stored as ``self.resolution``; not otherwise used
                while building the layers.
            in_nchannel: Currently unused by this constructor (the first
                encoder convolution always has 1 input channel). Kept for
                backward compatibility with existing callers.
        """
        nn.Module.__init__(self)

        self.resolution = resolution

        enc_ch = self.ENC_CHANNELS
        dec_ch = self.DEC_CHANNELS

        def _down_block(c_in, c_out):
            # Stride-2 downsampling conv, then a 3x3x3 conv; each + BN + ELU.
            return nn.Sequential(
                ME.MinkowskiConvolution(c_in,
                                        c_out,
                                        kernel_size=2,
                                        stride=2,
                                        dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out,
                                        c_out,
                                        kernel_size=3,
                                        dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            )

        def _up_block(c_in, c_out, kernel_size=2):
            # Stride-2 generative transposed conv, then a 3x3x3 conv;
            # each + BN + ELU.
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in,
                    c_out,
                    kernel_size=kernel_size,
                    stride=2,
                    dimension=3,
                ),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out,
                                        c_out,
                                        kernel_size=3,
                                        dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            )

        def _cls_head(c_in):
            # 1x1x1 conv with bias producing a single output channel.
            return ME.MinkowskiConvolution(c_in,
                                           1,
                                           kernel_size=1,
                                           bias=True,
                                           dimension=3)

        # Encoder: the first block keeps the stride and maps 1 input channel
        # to enc_ch[0]; each following block halves the resolution.
        self.enc_block_s1 = nn.Sequential(
            ME.MinkowskiConvolution(1,
                                    enc_ch[0],
                                    kernel_size=3,
                                    stride=1,
                                    dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[0]),
            ME.MinkowskiELU(),
        )
        self.enc_block_s1s2 = _down_block(enc_ch[0], enc_ch[1])
        self.enc_block_s2s4 = _down_block(enc_ch[1], enc_ch[2])
        self.enc_block_s4s8 = _down_block(enc_ch[2], enc_ch[3])
        self.enc_block_s8s16 = _down_block(enc_ch[3], enc_ch[4])
        self.enc_block_s16s32 = _down_block(enc_ch[4], enc_ch[5])
        self.enc_block_s32s64 = _down_block(enc_ch[5], enc_ch[6])

        # Decoder: the first stage consumes the deepest encoder output and
        # uses a 4x4x4 transposed kernel; later stages use 2x2x2.
        self.dec_block_s64s32 = _up_block(enc_ch[6], dec_ch[5], kernel_size=4)
        self.dec_s32_cls = _cls_head(dec_ch[5])

        # Like every other decoder stage, this one is presumably fed the
        # previous decoder stage's output (dec_ch[5] channels). It was
        # previously declared with enc_ch[5], which only matched when
        # ENC_CHANNELS[5] == DEC_CHANNELS[5].
        self.dec_block_s32s16 = _up_block(dec_ch[5], dec_ch[4])
        self.dec_s16_cls = _cls_head(dec_ch[4])

        self.dec_block_s16s8 = _up_block(dec_ch[4], dec_ch[3])
        self.dec_s8_cls = _cls_head(dec_ch[3])

        self.dec_block_s8s4 = _up_block(dec_ch[3], dec_ch[2])
        self.dec_s4_cls = _cls_head(dec_ch[2])

        self.dec_block_s4s2 = _up_block(dec_ch[2], dec_ch[1])
        self.dec_s2_cls = _cls_head(dec_ch[1])

        self.dec_block_s2s1 = _up_block(dec_ch[1], dec_ch[0])
        self.dec_s1_cls = _cls_head(dec_ch[0])

        # pruning
        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 2
0
    def __init__(self):
        """Build a sparse generative decoder made of six upsampling blocks.

        The input sparse tensor must have tensor stride 128: ``block1`` holds
        two stride-2 generative transposed convolutions and ``block2`` ..
        ``block6`` hold one each, seven in total. Every block ``blockN`` is
        paired with a 1x1x1 convolution head ``blockN_cls`` that has a single
        output channel. Channel widths are read from the class attribute
        ``CHANNELS`` (indices 0..6). ``self.pruning`` is an
        ``ME.MinkowskiPruning`` layer.
        """
        nn.Module.__init__(self)

        widths = self.CHANNELS

        def upsample(c_in, c_out):
            # Generative 2x upsampling followed by a 3x3x3 conv, each with
            # batch norm and ELU. Returned as a list so that two stages can
            # be flattened into a single nn.Sequential.
            return [
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in, c_out, kernel_size=2, stride=2, dimension=3
                ),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            ]

        def head(c_in):
            # Single-output-channel 1x1x1 conv with bias.
            return ME.MinkowskiConvolution(
                c_in, 1, kernel_size=1, bias=True, dimension=3
            )

        # Block 1 stacks two upsampling stages (tensor stride 128 -> 32).
        self.block1 = nn.Sequential(
            *upsample(widths[0], widths[0]),
            *upsample(widths[0], widths[1]),
        )
        self.block1_cls = head(widths[1])

        # Blocks 2..6 each upsample once more.
        self.block2 = nn.Sequential(*upsample(widths[1], widths[2]))
        self.block2_cls = head(widths[2])

        self.block3 = nn.Sequential(*upsample(widths[2], widths[3]))
        self.block3_cls = head(widths[3])

        self.block4 = nn.Sequential(*upsample(widths[3], widths[4]))
        self.block4_cls = head(widths[4])

        self.block5 = nn.Sequential(*upsample(widths[4], widths[5]))
        self.block5_cls = head(widths[5])

        self.block6 = nn.Sequential(*upsample(widths[5], widths[6]))
        self.block6_cls = head(widths[6])

        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 3
0
  def network_initialization(self, in_channels, config, D):
    """Create the upsampling, feature-projection and classification layers.

    Layers created (all registered in this order):
      * ``conv_up1``..``conv_up4``: stride-2 transposed convs (kernel 3) that
        generate new coordinates, each followed by BN + ELU. ``conv_up{i+1}``
        maps ``in_channels[i]`` to ``in_channels[max(i - 1, 0)]``.
      * ``conv_feat1``..``conv_feat4``: 1x1 convs mapping ``in_channels[i]``
        to ``config.upsample_feat_size``, each followed by BN + ELU.
      * ``conv_cls1``..``conv_cls4``: 1x1 convs with bias mapping
        ``config.upsample_feat_size`` to 2 channels.
      * ``elu`` and ``pruning`` layers.

    Args:
      in_channels: Sequence with at least four channel counts.
      config: Object providing ``upsample_feat_size``.
      D: Spatial dimension passed to every Minkowski layer. It used to be
        ignored in favour of a hard-coded ``dimension=3``; results are
        unchanged for ``D == 3``.
    """
    up_kernel_size = 3
    feat_size = config.upsample_feat_size

    # NOTE(review): `generate_new_coords=` and `has_bias=` are older
    # MinkowskiEngine keyword names (other code in this file uses
    # MinkowskiGenerativeConvolutionTranspose and `bias=`); confirm they
    # match the installed MinkowskiEngine version.
    def _up(c_in, c_out):
      # Stride-2 transposed conv generating new coordinates, + BN + ELU.
      return nn.Sequential(
          ME.MinkowskiConvolutionTranspose(
              c_in, c_out, kernel_size=up_kernel_size, stride=2,
              generate_new_coords=True, dimension=D),
          ME.MinkowskiBatchNorm(c_out),
          ME.MinkowskiELU())

    def _feat(c_in):
      # 1x1 projection to the shared feature width, + BN + ELU.
      return nn.Sequential(
          ME.MinkowskiConvolution(
              c_in, feat_size, kernel_size=1, dimension=D),
          ME.MinkowskiBatchNorm(feat_size),
          ME.MinkowskiELU())

    def _cls():
      # 1x1 conv with bias producing two output channels.
      return ME.MinkowskiConvolution(
          feat_size, 2, kernel_size=1, has_bias=True, dimension=D)

    self.conv_up1 = _up(in_channels[0], in_channels[0])
    self.conv_up2 = _up(in_channels[1], in_channels[0])
    self.conv_up3 = _up(in_channels[2], in_channels[1])
    self.conv_up4 = _up(in_channels[3], in_channels[2])

    self.conv_feat1 = _feat(in_channels[0])
    self.conv_feat2 = _feat(in_channels[1])
    self.conv_feat3 = _feat(in_channels[2])
    self.conv_feat4 = _feat(in_channels[3])

    self.conv_cls1 = _cls()
    self.conv_cls2 = _cls()
    self.conv_cls3 = _cls()
    self.conv_cls4 = _cls()

    self.elu = ME.MinkowskiELU()
    self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 4
0
    def __init__(self, channels, block_layers, block):
        """Build a three-stage sparse upsampling decoder.

        Channel widths go ``channels -> 64 -> 32 -> 16``. Stage ``i``
        (0, 1, 2) consists of:
          * ``up{i}``: stride-2 generative transposed conv (kernel 2, bias),
          * ``conv{i}``: stride-1 3x3x3 conv with bias,
          * ``block{i}``: built by ``self.make_layer`` from the selected
            block type,
          * ``conv{i}_cls``: stride-1 3x3x3 conv with bias and a single
            output channel.

        Args:
            channels: Number of channels of the input sparse tensor.
            block_layers: Passed through to ``self.make_layer``.
            block: Either ``'ResNet'`` or ``'InceptionResNet'``.

        Raises:
            ValueError: If ``block`` is not one of the supported names.
                Previously an unknown name left ``self.block`` unset and
                failed later with an unrelated AttributeError.
        """
        nn.Module.__init__(self)
        out_nchannel = 1
        ch = [channels, 64, 32, 16]
        if block == 'ResNet':
            self.block = ResNet
        elif block == 'InceptionResNet':
            self.block = InceptionResNet
        else:
            raise ValueError(
                "block must be 'ResNet' or 'InceptionResNet', got {!r}".format(
                    block))

        def _up(c_in, c_out):
            # Stride-2 generative transposed conv with bias.
            return ME.MinkowskiGenerativeConvolutionTranspose(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=2,
                stride=2,
                bias=True,
                dimension=3)

        def _conv3(c_in, c_out):
            # Stride-1 3x3x3 conv with bias.
            return ME.MinkowskiConvolution(in_channels=c_in,
                                           out_channels=c_out,
                                           kernel_size=3,
                                           stride=1,
                                           bias=True,
                                           dimension=3)

        # Stage 0: ch[0] -> ch[1]
        self.up0 = _up(ch[0], ch[1])
        self.conv0 = _conv3(ch[1], ch[1])
        self.block0 = self.make_layer(self.block, block_layers, ch[1])
        self.conv0_cls = _conv3(ch[1], out_nchannel)

        # Stage 1: ch[1] -> ch[2]
        self.up1 = _up(ch[1], ch[2])
        self.conv1 = _conv3(ch[2], ch[2])
        self.block1 = self.make_layer(self.block, block_layers, ch[2])
        self.conv1_cls = _conv3(ch[2], out_nchannel)

        # Stage 2: ch[2] -> ch[3]
        self.up2 = _up(ch[2], ch[3])
        self.conv2 = _conv3(ch[3], ch[3])
        self.block2 = self.make_layer(self.block, block_layers, ch[3])
        self.conv2_cls = _conv3(ch[3], out_nchannel)

        self.relu = ME.MinkowskiReLU(inplace=True)

        # pruning
        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 5
0
    def __init__(self, config=None, **kwargs):
        """Build the Dumoulin-style sparse decoder.

        Three ``nn.Sequential`` stacks are created:
          * ``efi``: two 1x1 convs (+ BN + ELU) mapping ``n_latents`` to
            ``hidden_channels * 2 ** n_conv_layers`` channels, with the
            single-output-channel head ``efi_cls``.
          * ``gfi``: one upsampling stage (``conv_<id>_i``) plus a
            single-output-channel head (``cls_<id>_i``) for each conv layer id
            from ``n_conv_layers - 1`` down to ``feature_layer + 1``.
          * ``lfi``: the same for conv layer ids ``feature_layer`` down to 1,
            then a final stage ``conv_0_i`` producing ``n_channels`` channels
            and its head ``cls_0_i``.
        Each upsampling stage halves the channel width.

        Args:
            config: Configuration passed to ``MEDecoder.__init__``. ``None``
                (the default) is replaced by a fresh empty dict, avoiding a
                mutable default argument shared across instances.
            **kwargs: Forwarded to ``MEDecoder.__init__``.
        """
        if config is None:
            config = {}
        MEDecoder.__init__(self, config=config, **kwargs)

        # need square and power of 2 image size input.
        # math.log2 is used rather than math.log(x, 2): the latter divides two
        # rounded logarithms and can miss an exact integer by one ulp.
        power = math.log2(self.config.input_size[0])
        assert (power % 1 == 0.0) and (
            power > 3
        ), "Dumoulin Decoder needs a power of 2 as image input size (>=16)"
        assert self.config.input_size[0] == self.config.input_size[
            1], "Dumoulin Decoder needs a square image input size"

        assert self.config.n_conv_layers == power, "The number of convolutional layers in the Dumoulin decoder must be log2(input_size)"

        # Per-layer conv hyper-parameters: index 2 * layer_id is the stride-1
        # 3x3 conv, index 2 * layer_id + 1 the stride-2 transposed conv.
        # No dilation argument is passed to any layer.
        kernels_size = [3, 2] * self.config.n_conv_layers
        strides = [1, 2] * self.config.n_conv_layers

        if self.config.hidden_channels is None:
            self.config.hidden_channels = 8

        def _inverse_stage(c_in, layer_id):
            # Generative transposed conv halving the channel width, then a
            # conv at the same width; each followed by BN + ELU.
            c_out = c_in // 2
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in,
                    c_out,
                    kernel_size=kernels_size[2 * layer_id + 1],
                    stride=strides[2 * layer_id + 1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    c_out,
                    c_out,
                    kernel_size=kernels_size[2 * layer_id],
                    stride=strides[2 * layer_id],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(inplace=True),
            )

        def _cls_head(c_in, bias=True):
            # 1x1 conv producing a single output channel.
            return ME.MinkowskiConvolution(c_in,
                                           1,
                                           kernel_size=1,
                                           bias=bias,
                                           dimension=self.spatial_dims)

        # encoder feature inverse
        hidden_channels = int(self.config.hidden_channels *
                              math.pow(2, self.config.n_conv_layers))
        self.efi = nn.Sequential(
            ME.MinkowskiConvolution(
                self.config.n_latents,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
            ME.MinkowskiConvolution(
                hidden_channels,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
        )
        self.efi.out_connection_type = ("conv", hidden_channels)
        self.efi_cls = _cls_head(hidden_channels)

        # global feature inverse: conv layers n_conv_layers - 1 .. feature_layer + 1
        self.gfi = nn.Sequential()
        for conv_layer_id in range(self.config.n_conv_layers - 1,
                                   self.config.feature_layer, -1):
            self.gfi.add_module("conv_{}_i".format(conv_layer_id),
                                _inverse_stage(hidden_channels, conv_layer_id))
            hidden_channels = hidden_channels // 2
            self.gfi.add_module("cls_{}_i".format(conv_layer_id),
                                _cls_head(hidden_channels))
        self.gfi.out_connection_type = ("conv", hidden_channels)

        # local feature inverse: conv layers feature_layer .. 1, then layer 0
        self.lfi = nn.Sequential()
        for conv_layer_id in range(self.config.feature_layer, 0, -1):
            self.lfi.add_module("conv_{}_i".format(conv_layer_id),
                                _inverse_stage(hidden_channels, conv_layer_id))
            hidden_channels = hidden_channels // 2
            self.lfi.add_module("cls_{}_i".format(conv_layer_id),
                                _cls_head(hidden_channels))

        # Final stage: the last conv outputs n_channels and has no BN/ELU.
        self.lfi.add_module(
            "conv_0_i",
            nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    hidden_channels,
                    hidden_channels // 2,
                    kernel_size=kernels_size[1],
                    stride=strides[1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(hidden_channels // 2),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    hidden_channels // 2,
                    self.config.n_channels,
                    kernel_size=kernels_size[0],
                    stride=strides[0],
                    dimension=self.spatial_dims,
                    bias=False),
            ))
        # NOTE(review): unlike every other single-channel head built here,
        # cls_0_i has no bias term; confirm this is intentional (adding one
        # would change the parameter set and checkpoint compatibility).
        self.lfi.add_module("cls_0_i",
                            _cls_head(self.config.n_channels, bias=False))

        self.lfi.out_connection_type = ("conv", self.config.n_channels)

        # pruning
        self.pruning = ME.MinkowskiPruning()