def __init__(self,
             inplanes,
             planes,
             stride=1,
             dilation=1,
             downsample=None,
             conv_type=ConvType.HYPERCUBE,
             bn_momentum=0.1,
             D=3):
    """Create the two 3x3 convolutions and two norm layers of a basic block.

    Args:
        inplanes: Channels entering the block (input of ``conv1``).
        planes: Channels produced by both ``conv1`` and ``conv2``.
        stride: Stride of ``conv1`` only; ``conv2`` always uses stride 1.
        dilation: Dilation shared by both convolutions.
        downsample: Optional module stored as ``self.downsample``
            (presumably applied to the shortcut in ``forward`` -- not shown).
        conv_type: Forwarded to ``conv`` for both convolutions.
        bn_momentum: Momentum passed to both norm layers.
        D: Spatial dimension forwarded to ``conv`` and ``get_norm``.
    """
    super(BasicBlockBase, self).__init__()

    # Keyword arguments common to both 3x3 convolutions.
    shared_conv_args = dict(kernel_size=3, dilation=dilation, conv_type=conv_type, D=D)

    def make_norm():
        # Both norm layers are built identically: self.NORM_TYPE over `planes`.
        return get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)

    # Keep the construction order conv1, norm1, conv2, norm2: constructors may
    # draw from the RNG when initialising their parameters.
    self.conv1 = conv(inplanes, planes, stride=stride, **shared_conv_args)
    self.norm1 = make_norm()
    self.conv2 = conv(planes, planes, stride=1, bias=False, **shared_conv_args)
    self.norm2 = make_norm()
    self.relu = MinkowskiReLU(inplace=True)
    self.downsample = downsample
Example #2
0
    def test_network(self):
        """Run dense -> sparse -> dense layers and check the input receives a gradient."""
        # Dense input with 4 spatial axes after batch and channel.
        dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
        dense_tensor.requires_grad_(True)

        # Since the shape is fixed, cache the coordinates for faster inference
        coordinates = dense_coordinates(dense_tensor.shape)

        layers = [
            # Ordinary PyTorch op applied to the dense tensor.
            nn.ReLU(),
            # Dense -> sparse, without dropping zero-valued entries.
            MinkowskiToSparseTensor(remove_zeros=False, coordinates=coordinates),
            MinkowskiConvolution(4, 5, stride=2, kernel_size=3, dimension=4),
            MinkowskiBatchNorm(5),
            MinkowskiReLU(),
            MinkowskiConvolutionTranspose(5, 6, stride=2, kernel_size=3, dimension=4),
            # Sparse -> dense; must have the same tensor stride as the input,
            # which the stride-2 conv followed by the stride-2 transpose restores.
            MinkowskiToDenseTensor(dense_tensor.shape),
        ]
        network = nn.Sequential(*layers)

        # Repeated forward/backward passes; gradients accumulate on the input.
        for i in range(5):
            print(f"Iteration: {i}")
            network(dense_tensor).sum().backward()

        assert dense_tensor.grad is not None
Example #3
0
    def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
        """Build a recursive U-Net by nesting ``UBlock``s from the inside out.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by ``self.final``.
            config: Object read for ``bn_momentum``.
            D: Spatial dimension; ``D == 4`` also rewrites ``OUT_PIXEL_DIST``.
            **kwargs: Accepted and ignored.
        """
        super(RecUNetBase, self).__init__(in_channels, out_channels, config, D)

        # Reversed so index 0 is the innermost level; each entry is indexed
        # with [0] and [1] below (presumably an (in, out) channel pair).
        levels = self.PLANES[::-1]
        momentum = config.bn_momentum

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1, D)

        # Output of the first conv, sized for the outermost level.
        stem_width = levels[-1][0]
        self.conv1 = conv(in_channels,
                          stem_width,
                          kernel_size=space_n_time_m(3, 1, D),
                          conv_type=self.CONV_TYPE,
                          D=D)
        self.norm1 = get_norm(self.NORM_TYPE, stem_width, D, momentum)

        # The innermost module is a single BLOCK; every outer level wraps the
        # module built so far in a UBlock.
        core_width = levels[0][0]
        nested = self.BLOCK(core_width, core_width, conv_type=self.CONV_TYPE, D=self.D)

        rep_count = len(self.REPS)
        for depth, (inner, outer) in enumerate(zip(levels, levels[1:])):
            nested = UBlock(outer[0],
                            inner[0],
                            inner[1],
                            outer[1],
                            intermediate_module=nested,
                            BLOCK=self.BLOCK,
                            reps=self.REPS[rep_count - depth - 1],
                            conv_type=self.CONV_TYPE,
                            bn_momentum=momentum,
                            D=D)
        self.unet = nested

        # 1x1 projection to the requested number of output channels.
        self.final = conv(levels[-1][1],
                          out_channels,
                          kernel_size=1,
                          stride=1,
                          dilation=1,
                          bias=True,
                          D=D)

        self.relu = MinkowskiReLU(inplace=True)
    def network_initialization(self, in_channels, out_channels, D):
        """Create every layer of the residual sparse U-Net.

        The encoder is a stride-1 stem (``conv0p1s1``) followed by four
        stride-2 convolutions (``conv1p1s2`` .. ``conv4p8s2``), each followed
        by a residual stage built with ``self._make_layer`` (``block1`` ..
        ``block4``). The decoder is four stride-2 transposed convolutions
        (``convtr4p16s2`` .. ``convtr7p2s2``), each followed by a residual
        stage (``block5`` .. ``block8``) whose input width also counts an
        encoder feature map. A 1x1 convolution (``final``) produces the output.

        Modules are created in a fixed order; keep it, since constructors may
        draw from the RNG when initialising weights.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by ``final``.
            D: Spatial dimension. For ``D == 3`` kernel sizes/strides are
                scalars; otherwise ``space_n_time_m`` expands them to a
                4-element list ``[n, n, n, m]`` (``m`` is 1 for every layer
                here).
        """
        # Setup net_metadata
        dilations = self.DILATIONS
        # Fixed BatchNorm momentum shared by every norm layer built below.
        bn_momentum = 0.02

        def space_n_time_m(n, m):
            # Scalar for 3-D; otherwise a per-axis list [n, n, n, m].
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stride-1 stem. Its INIT_DIM-wide output is counted again in the
        # input width of block8 below (outermost skip connection).
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = conv(
            in_channels,
            self.inplanes,
            kernel_size=space_n_time_m(3, 1),
            stride=1,
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

        self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)

        # Encoder stage 1. self.inplanes is re-read after each _make_layer
        # call, so _make_layer presumably updates it to the stage's output width.
        self.conv1p1s2 = conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
        self.block1 = self._make_layer(
            self.BLOCK,
            self.PLANES[0],
            self.LAYERS[0],
            dilation=dilations[0],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

        # Encoder stage 2.
        self.conv2p2s2 = conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
        self.block2 = self._make_layer(
            self.BLOCK,
            self.PLANES[1],
            self.LAYERS[1],
            dilation=dilations[1],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

        # Encoder stage 3.
        self.conv3p4s2 = conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
        self.block3 = self._make_layer(
            self.BLOCK,
            self.PLANES[2],
            self.LAYERS[2],
            dilation=dilations[2],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

        # Encoder stage 4 (deepest).
        self.conv4p8s2 = conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
        self.block4 = self._make_layer(
            self.BLOCK,
            self.PLANES[3],
            self.LAYERS[3],
            dilation=dilations[3],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )
        # Decoder stage 1: stride-2 transposed conv.
        self.convtr4p16s2 = conv_tr(
            self.inplanes,
            self.PLANES[4],
            kernel_size=space_n_time_m(2, 1),
            upsample_stride=space_n_time_m(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bntr4 = get_norm(self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum)

        # block5 input: convtr4 output plus block3 output (PLANES[2] * expansion);
        # the concatenation itself presumably happens in forward.
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(
            self.BLOCK,
            self.PLANES[4],
            self.LAYERS[4],
            dilation=dilations[4],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )
        self.convtr5p8s2 = conv_tr(
            self.inplanes,
            self.PLANES[5],
            kernel_size=space_n_time_m(2, 1),
            upsample_stride=space_n_time_m(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bntr5 = get_norm(self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum)

        # block6 input: convtr5 output plus block2 output.
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(
            self.BLOCK,
            self.PLANES[5],
            self.LAYERS[5],
            dilation=dilations[5],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )
        self.convtr6p4s2 = conv_tr(
            self.inplanes,
            self.PLANES[6],
            kernel_size=space_n_time_m(2, 1),
            upsample_stride=space_n_time_m(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bntr6 = get_norm(self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum)

        # block7 input: convtr6 output plus block1 output.
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(
            self.BLOCK,
            self.PLANES[6],
            self.LAYERS[6],
            dilation=dilations[6],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )
        self.convtr7p2s2 = conv_tr(
            self.inplanes,
            self.PLANES[7],
            kernel_size=space_n_time_m(2, 1),
            upsample_stride=space_n_time_m(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )
        self.bntr7 = get_norm(self.NORM_TYPE, self.PLANES[7], D, bn_momentum=bn_momentum)

        # block8 input: convtr7 output plus the INIT_DIM-wide stem output.
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(
            self.BLOCK,
            self.PLANES[7],
            self.LAYERS[7],
            dilation=dilations[7],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

        # block8 emits PLANES[7] * BLOCK.expansion channels -- the same width
        # convention the skip sizes above use for block1..block3 -- so the 1x1
        # output conv must consume that width. For BLOCK.expansion == 1 this
        # is exactly PLANES[7].
        self.final = conv(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=True,
            D=D,
        )
        self.relu = MinkowskiReLU(inplace=True)
Example #5
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create the layers of a sparse encoder-decoder with a multi-scale head.

        Encoder: a stride-1 stem (``conv1p1s1``) and ``block1``, then three
        stride-2 convolutions (``conv2p1s2``, ``conv3p2s2``, ``conv4p4s2``),
        each followed by a residual stage (``block2`` .. ``block4``).
        Decoder: three stride-2 transposed convolutions with residual stages
        ``block5`` and ``block6`` between them, plus transpose-pooling modules
        ``pool_tr4`` .. ``pool_tr6``. ``final`` is a 1x1 conv -> BatchNorm ->
        ReLU -> 1x1 conv head whose input width sums several feature-map
        widths (presumably concatenated in ``forward``, which is not shown).

        Modules are created in a fixed order; keep it, since constructors may
        draw from the RNG when initialising weights.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by the last head conv.
            config: Object read via attributes ``bn_momentum`` and
                ``conv1_kernel_size``.
            D: Spatial dimension; for ``D != 3`` kernel sizes/strides become
                4-element lists ``[n, n, n, m]``.
        """
        # Setup net_metadata
        dilations = self.DILATIONS
        bn_momentum = config.bn_momentum

        def space_n_time_m(n, m):
            # Scalar for 3-D; otherwise a per-axis list [n, n, n, m].
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stride-1 stem with a configurable kernel size.
        self.inplanes = self.INIT_DIM
        self.conv1p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config.conv1_kernel_size, 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)

        # NOTE(review): bn1 is sized with PLANES[0], but conv1p1s1 above
        # outputs self.inplanes (INIT_DIM) channels. These agree only when
        # INIT_DIM == PLANES[0]; confirm which tensor bn1 normalises in forward.
        self.bn1 = get_norm(self.NORM_TYPE,
                            self.PLANES[0],
                            D,
                            bn_momentum=bn_momentum)
        # Residual stage at stride 1. self.inplanes is re-read after each
        # _make_layer call, so _make_layer presumably updates it to the
        # stage's output width.
        self.block1 = self._make_layer(self.BLOCK,
                                       self.PLANES[0],
                                       self.LAYERS[0],
                                       dilation=dilations[0],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 2: first stride-2 downsampling.
        self.conv2p1s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn2 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block2 = self._make_layer(self.BLOCK,
                                       self.PLANES[1],
                                       self.LAYERS[1],
                                       dilation=dilations[1],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 3.
        self.conv3p2s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn3 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block3 = self._make_layer(self.BLOCK,
                                       self.PLANES[2],
                                       self.LAYERS[2],
                                       dilation=dilations[2],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 4 (deepest, after three stride-2 convs).
        self.conv4p4s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn4 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block4 = self._make_layer(self.BLOCK,
                                       self.PLANES[3],
                                       self.LAYERS[3],
                                       dilation=dilations[3],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        # Transpose pooling with kernel_size == stride == 8; block4 sits after
        # three stride-2 convs, so this presumably maps it back to stride 1.
        self.pool_tr4 = ME.MinkowskiPoolingTranspose(kernel_size=8,
                                                     stride=8,
                                                     dimension=D)
        # NOTE(review): out_pool4 is recorded but not included in the input
        # width of `final` below; either pool_tr4's output is unused in
        # forward or the head width omits it -- confirm.
        out_pool4 = self.inplanes
        # Decoder stage 1: stride-2 transposed conv.
        self.convtr4p8s2 = conv_tr(self.inplanes,
                                   self.PLANES[4],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr4 = get_norm(self.NORM_TYPE,
                              self.PLANES[4],
                              D,
                              bn_momentum=bn_momentum)

        # block5 input: convtr4 output plus block3 output (PLANES[2] * expansion).
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK,
                                       self.PLANES[4],
                                       self.LAYERS[4],
                                       dilation=dilations[4],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        # kernel_size == stride == 4 transpose pooling for block5's features.
        self.pool_tr5 = ME.MinkowskiPoolingTranspose(kernel_size=4,
                                                     stride=4,
                                                     dimension=D)
        # Width left in self.inplanes by _make_layer for block5 (presumably its
        # output width); it contributes to the head's input width.
        out_pool5 = self.inplanes
        self.convtr5p4s2 = conv_tr(self.inplanes,
                                   self.PLANES[5],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr5 = get_norm(self.NORM_TYPE,
                              self.PLANES[5],
                              D,
                              bn_momentum=bn_momentum)

        # block6 input: convtr5 output plus block2 output.
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK,
                                       self.PLANES[5],
                                       self.LAYERS[5],
                                       dilation=dilations[5],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        # kernel_size == stride == 2 transpose pooling for block6's features.
        self.pool_tr6 = ME.MinkowskiPoolingTranspose(kernel_size=2,
                                                     stride=2,
                                                     dimension=D)
        # Width left in self.inplanes after block6; also part of the head input.
        out_pool6 = self.inplanes
        self.convtr6p2s2 = conv_tr(self.inplanes,
                                   self.PLANES[6],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr6 = get_norm(self.NORM_TYPE,
                              self.PLANES[6],
                              D,
                              bn_momentum=bn_momentum)

        self.relu = MinkowskiReLU(inplace=True)

        # Head input width = out_pool5 + out_pool6 + convtr6 output (PLANES[6])
        # + block1 output (PLANES[0] * expansion); 512 hidden channels.
        self.final = nn.Sequential(
            conv(out_pool5 + out_pool6 + self.PLANES[6] +
                 self.PLANES[0] * self.BLOCK.expansion,
                 512,
                 kernel_size=1,
                 bias=False,
                 D=D), ME.MinkowskiBatchNorm(512), ME.MinkowskiReLU(),
            conv(512, out_channels, kernel_size=1, bias=True, D=D))
  def network_initialization(self, in_channels, out_channels, config, D):
    """Create a sparse residual encoder followed by a classification head.

    Encoder: a stride-1 stem (``conv0p1s1``) and four stride-2 convolutions
    (``conv1p1s2`` .. ``conv4p8s2``), each followed by a residual stage
    (``block1`` .. ``block4``). Head: global pooling modules
    (``clf_glob_avg``, ``clf_glob_max``) and three convolutions
    (``clf_conv0`` .. ``clf_conv2``) ending in ``config['clf_num_labels']``
    channels.

    Modules are created in a fixed order; keep it, since constructors may
    draw from the RNG when initialising weights.

    Args:
      in_channels: Number of input feature channels.
      out_channels: Not used by this method.
      config: Mapping read with keys ``'bn_momentum'``,
        ``'conv1_kernel_size'`` and ``'clf_num_labels'``.
      D: Spatial dimension; for ``D != 3`` kernel sizes/strides become
        4-element lists ``[n, n, n, m]``.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = config['bn_momentum']

    def space_n_time_m(n, m):
      # Scalar for 3-D; otherwise a per-axis list [n, n, n, m].
      return n if D == 3 else [n, n, n, m]

    if D == 4:
      self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

    # Stride-1 stem with a configurable kernel size.
    self.inplanes = self.INIT_DIM
    self.conv0p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_n_time_m(config['conv1_kernel_size'], 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

    self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)

    # Encoder stage 1. self.inplanes is re-read after each _make_layer call,
    # so _make_layer presumably updates it to the stage's output width.
    self.conv1p1s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block1 = self._make_layer(
        self.BLOCK,
        self.PLANES[0],
        self.LAYERS[0],
        dilation=dilations[0],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 2.
    self.conv2p2s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block2 = self._make_layer(
        self.BLOCK,
        self.PLANES[1],
        self.LAYERS[1],
        dilation=dilations[1],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 3.
    self.conv3p4s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block3 = self._make_layer(
        self.BLOCK,
        self.PLANES[2],
        self.LAYERS[2],
        dilation=dilations[2],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 4 (deepest).
    self.conv4p8s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block4 = self._make_layer(
        self.BLOCK,
        self.PLANES[3],
        self.LAYERS[3],
        dilation=dilations[3],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    self.relu = MinkowskiReLU(inplace=True)
    # add a classification head here
    self.clf_glob_avg = ME.MinkowskiGlobalPooling(dimension=D)
    self.clf_glob_max=ME.MinkowskiGlobalMaxPooling(dimension=D)
    # NOTE(review): input width is hard-coded to 256 rather than derived from
    # self.inplanes (block4's width). This only matches configurations whose
    # fed-in feature map has 256 channels; confirm what forward passes here.
    self.clf_conv0 = conv(
        256,
        512,
        kernel_size=3,
        stride=2,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.clf_bn0 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
    self.clf_conv1 = conv(
        512,
        512,
        kernel_size=3,
        stride=2,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.clf_bn1 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
    # 1x1 conv producing one channel per classification label (no norm after).
    self.clf_conv2 = conv(
        512,
        config['clf_num_labels'],
        kernel_size=1,
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
Example #7
0
    def network_initialization(self, in_channels, out_channels, D):
        """Create the layers of a residual sparse U-Net used as a feature extractor.

        The encoder is a stride-1 stem (``conv0p1s1``) followed by four
        stride-2 convolutions (``conv1p1s2`` .. ``conv4p8s2``), each followed
        by a residual stage (``block1`` .. ``block4``). The decoder is four
        stride-2 transposed convolutions (``convtr4p16s2`` .. ``convtr7p2s2``),
        each followed by a residual stage (``block5`` .. ``block8``) whose
        input width also counts an encoder feature map. Global max/average
        pooling modules are created, plus an ``SMLP`` head when
        ``self.use_mlp`` is truthy.

        Modules are created in a fixed order; keep it, since constructors may
        draw from the RNG when initialising weights.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Requested output channel count; none of the layers
                created up to the optional MLP head reference it.
            D: Spatial dimension; for ``D != 3`` kernel sizes/strides become
                4-element lists ``[n, n, n, m]``.
        """
        # Setup net_metadata
        dilations = self.DILATIONS
        # Fixed BatchNorm momentum shared by every norm layer built below.
        bn_momentum = 0.02

        def space_n_time_m(n, m):
            # Scalar for 3-D; otherwise a per-axis list [n, n, n, m].
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stride-1 stem. Its INIT_DIM-wide output is counted again in the
        # input width of block8 below (outermost skip connection).
        conv1_kernel_size = 3
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(conv1_kernel_size, 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)

        self.bn0 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)

        # Encoder stage 1. self.inplanes is re-read after each _make_layer
        # call, so _make_layer presumably updates it to the stage's output width.
        self.conv1p1s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn1 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block1 = self._make_layer(self.BLOCK,
                                       self.PLANES[0],
                                       self.LAYERS[0],
                                       dilation=dilations[0],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 2.
        self.conv2p2s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn2 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block2 = self._make_layer(self.BLOCK,
                                       self.PLANES[1],
                                       self.LAYERS[1],
                                       dilation=dilations[1],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 3.
        self.conv3p4s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn3 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block3 = self._make_layer(self.BLOCK,
                                       self.PLANES[2],
                                       self.LAYERS[2],
                                       dilation=dilations[2],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Encoder stage 4 (deepest).
        self.conv4p8s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn4 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block4 = self._make_layer(self.BLOCK,
                                       self.PLANES[3],
                                       self.LAYERS[3],
                                       dilation=dilations[3],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        # Decoder stage 1: stride-2 transposed conv.
        self.convtr4p16s2 = conv_tr(self.inplanes,
                                    self.PLANES[4],
                                    kernel_size=space_n_time_m(2, 1),
                                    upsample_stride=space_n_time_m(2, 1),
                                    dilation=1,
                                    bias=False,
                                    conv_type=self.NON_BLOCK_CONV_TYPE,
                                    D=D)
        self.bntr4 = get_norm(self.NORM_TYPE,
                              self.PLANES[4],
                              D,
                              bn_momentum=bn_momentum)

        # block5 input: convtr4 output plus block3 output (PLANES[2] * expansion).
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK,
                                       self.PLANES[4],
                                       self.LAYERS[4],
                                       dilation=dilations[4],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr5p8s2 = conv_tr(self.inplanes,
                                   self.PLANES[5],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr5 = get_norm(self.NORM_TYPE,
                              self.PLANES[5],
                              D,
                              bn_momentum=bn_momentum)

        # block6 input: convtr5 output plus block2 output.
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK,
                                       self.PLANES[5],
                                       self.LAYERS[5],
                                       dilation=dilations[5],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr6p4s2 = conv_tr(self.inplanes,
                                   self.PLANES[6],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr6 = get_norm(self.NORM_TYPE,
                              self.PLANES[6],
                              D,
                              bn_momentum=bn_momentum)

        # block7 input: convtr6 output plus block1 output.
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK,
                                       self.PLANES[6],
                                       self.LAYERS[6],
                                       dilation=dilations[6],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr7p2s2 = conv_tr(self.inplanes,
                                   self.PLANES[7],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr7 = get_norm(self.NORM_TYPE,
                              self.PLANES[7],
                              D,
                              bn_momentum=bn_momentum)

        # block8 input: convtr7 output plus the INIT_DIM-wide stem output.
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK,
                                       self.PLANES[7],
                                       self.LAYERS[7],
                                       dilation=dilations[7],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Names of intermediate feature maps (presumably the keys under which
        # forward exposes encoder/decoder outputs -- confirm in forward).
        self.all_feat_names = [
            "en0",
            "en1",
            "en2",
            "en3",
            "en4",
            "plane4",
            "plane5",
            "plane6",
            "plane7",
        ]

        self.relu = MinkowskiReLU(inplace=True)
        self.maxpool = ME.MinkowskiGlobalMaxPooling()
        self.avgpool = ME.MinkowskiGlobalAvgPooling()
        # Optional MLP head; use_mlp and mlp_dim are expected to be set on the
        # instance before this method runs.
        if self.use_mlp:
            self.head = SMLP(self.mlp_dim)