Пример #1
0
Файл: unet.py Проект: jgwak/GSDN
    def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
        """Build the layers of a four-level sparse U-Net.

        Channel widths come from ``self.PLANES`` (indices 0-3). For each
        level ``k`` in 1..3 this creates, in order: an encoder block
        ``block{k}``, a kernel-2/stride-2 conv ``down{k}`` to the next width,
        a kernel-2/stride-2 transposed conv ``up{k}`` back to the level
        width, and a decoder block ``block{k}up`` taking twice the level
        width (presumably upsampled features concatenated with the skip;
        the concatenation happens in ``forward``, which is not shown here).
        Only level 1 gets a batch norm (``down1bn``) after its downsampling
        conv. ``block4`` works at the coarsest width and ``final`` is a 1x1
        conv to ``out_channels``.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output channels of ``final``.
            config: Forwarded to the parent constructor.
            D: Dimensionality forwarded to every conv.
            **kwargs: Accepted and ignored.
        """
        super(UNet2, self).__init__(in_channels, out_channels, config, D)
        widths = self.PLANES

        # Stem: 3x3 conv without bias, followed by batch norm.
        self.conv1 = conv(in_channels, widths[0], kernel_size=3, stride=1,
                          bias=False, D=D)
        self.bn1 = ME.MinkowskiBatchNorm(widths[0])

        # Levels are registered in exactly the original order so parameter
        # ordering (and RNG consumption during init) is unchanged.
        for level in (1, 2, 3):
            fine, coarse = widths[level - 1], widths[level]
            setattr(self, 'block%d' % level, UNBlocks(fine, fine, D))
            setattr(self, 'down%d' % level,
                    conv(fine, coarse, kernel_size=2, stride=2, D=D))
            if level == 1:
                self.down1bn = ME.MinkowskiBatchNorm(coarse)
            setattr(self, 'up%d' % level,
                    conv_tr(coarse, fine, kernel_size=2, upsample_stride=2,
                            D=D))
            setattr(self, 'block%dup' % level, UNBlocks(fine * 2, fine, D))

        # Bottleneck at the coarsest width.
        self.block4 = UNBlocks(widths[3], widths[3], D)

        self.relu = ME.MinkowskiReLU(inplace=True)
        self.final = conv(widths[0], out_channels, kernel_size=1, bias=True,
                          D=D)
Пример #2
0
    def __init__(self,
                 inplanes,
                 intermediate_inplanes,
                 intermediate_outplanes,
                 outplanes,
                 intermediate_module,
                 BLOCK=None,
                 reps=1,
                 conv_type=ConvType.HYPERCUBE,
                 norm_type=NormType.BATCH_NORM,
                 bn_momentum=0.1,
                 D=3):
        """Assemble one level of a recursively nested U-Net.

        The level owns, in registration order:

        * ``block``: a ``BLOCK`` at ``inplanes`` width.
        * ``down`` / ``down_norm``: a conv from ``inplanes`` to
          ``intermediate_inplanes`` with kernel and stride
          ``space_n_time_m(2, 1, D)``, plus its norm.
        * ``intermediate``: the nested ``intermediate_module``.
        * ``up`` / ``up_norm``: a transposed conv from
          ``intermediate_outplanes`` to ``outplanes`` with kernel and
          upsample stride ``space_n_time_m(2, 1, D)``, plus its norm.
        * ``end_blocks0`` .. ``end_blocks{reps-1}``: trailing blocks. The
          first one takes ``inplanes + outplanes`` channels (presumably the
          skip concatenated with the upsampled features in ``forward``, not
          shown here) and gets a 1x1 conv + norm ``downsample`` projection.

        Args:
            inplanes: Channels entering this level.
            intermediate_inplanes: Output channels of ``down``.
            intermediate_outplanes: Input channels of ``up``.
            outplanes: Output channels of ``up`` and of every end block.
            intermediate_module: Module stored as ``self.intermediate``.
            BLOCK: Block class. It is called unconditionally, so it must be
                supplied despite the ``None`` default.
            reps: Number of trailing end blocks.
            conv_type: Forwarded to the convs and blocks.
            norm_type: Forwarded to ``get_norm``.
            bn_momentum: Forwarded to norms and blocks.
            D: Dimensionality forwarded to every layer.
        """
        super(UBlock, self).__init__()

        self.block = BLOCK(inplanes, inplanes, conv_type=conv_type,
                           bn_momentum=bn_momentum, D=D)

        # Encoder side of the level.
        self.down = conv(inplanes, intermediate_inplanes,
                         kernel_size=space_n_time_m(2, 1, D),
                         stride=space_n_time_m(2, 1, D),
                         conv_type=conv_type, D=D)
        self.down_norm = get_norm(norm_type, intermediate_inplanes, D,
                                  bn_momentum=bn_momentum)

        self.intermediate = intermediate_module

        # Decoder side of the level.
        self.up = conv_tr(intermediate_outplanes, outplanes,
                          kernel_size=space_n_time_m(2, 1, D),
                          upsample_stride=space_n_time_m(2, 1, D),
                          conv_type=conv_type, D=D)
        self.up_norm = get_norm(norm_type, outplanes, D,
                                bn_momentum=bn_momentum)

        self.reps = reps
        for rep in range(reps):
            if rep == 0:
                # The first block's input is wider, so it needs a projection
                # shortcut down to ``outplanes``.
                block_in = inplanes + outplanes
                shortcut = nn.Sequential(
                    conv(block_in, outplanes, kernel_size=1, bias=False, D=D),
                    get_norm(norm_type, outplanes, D, bn_momentum=bn_momentum),
                )
            else:
                block_in = outplanes
                shortcut = None
            setattr(self, f'end_blocks{rep}',
                    BLOCK(block_in, outplanes, downsample=shortcut,
                          conv_type=conv_type, bn_momentum=bn_momentum, D=D))
Пример #3
0
 def network_initialization(self, in_channels, out_channels, config, D):
   """Create the multi-scale feature upsampling head.

   Layers are registered in this order:

   * ``conv_feat1``..``conv_feat4``: kernel-3 convs projecting
     ``in_channels[0]``..``in_channels[3]`` to
     ``config.upsample_feat_size`` channels.
   * ``bn_feat1``..``bn_feat4``: batch norms for those projections.
   * ``conv_up2``..``conv_up6``: kernel-2 / stride-2 transposed convs at
     ``upsample_feat_size`` width, without bias.
   * ``bn_up2``..``bn_up6``: batch norms for those transposed convs.
   * ``conv_final``: a 1x1 conv to ``out_channels``.
   * ``relu``: a non-inplace ReLU.

   Every layer now uses the caller's ``D``. Previously the norms and
   transposed convs hard-coded ``D=3``, disagreeing with ``conv_feat*`` and
   ``conv_final``.

   Args:
     in_channels: Indexable with at least four channel counts (0-3).
     out_channels: Output channels of ``conv_final``.
     config: Must provide ``upsample_feat_size`` and ``bn_momentum``.
     D: Dimensionality forwarded to every conv and norm.
   """
   feat = config.upsample_feat_size
   momentum = config.bn_momentum

   # Loops keep the original registration order: conv_feat1..4, bn_feat1..4,
   # conv_up2..6, bn_up2..6, conv_final, relu.
   for idx in range(1, 5):
     setattr(self, f'conv_feat{idx}',
             conv(in_channels[idx - 1], feat, 3, D=D))
   for idx in range(1, 5):
     setattr(self, f'bn_feat{idx}',
             get_norm(NormType.BATCH_NORM, feat, D=D, bn_momentum=momentum))
   for idx in range(2, 7):
     setattr(self, f'conv_up{idx}',
             conv_tr(feat, feat, kernel_size=2, upsample_stride=2,
                     dilation=1, bias=False, D=D))
   for idx in range(2, 7):
     setattr(self, f'bn_up{idx}',
             get_norm(NormType.BATCH_NORM, feat, D=D, bn_momentum=momentum))
   self.conv_final = conv(feat, out_channels, 1, D=D)
   self.relu = ME.MinkowskiReLU(inplace=False)
Пример #4
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create all layers of the network (legacy pixel_dist/net_metadata API).

        Layout, as constructed below:

        * ``conv1`` (kernel ``space_n_time_m(5, 1)``), ``bn1`` and ``block1``
          at pixel distance 1.
        * Three stride-2 downsampling stages ``conv2p1s2``, ``conv3p2s2`` and
          ``conv4p4s2`` (input pixel distance 1, 2, 4). Each is followed by a
          ``BatchNorm1d`` (``bn2``..``bn4``) and a layer from
          ``self._make_layer`` (``block2``..``block4``).
        * After each of those stages, a transposed conv (``convtr2p2s2``,
          ``convtr3p4s4``, ``convtr4p8s8``) whose kernel size and upsample
          stride equal that stage's pixel distance (2, 4, 8). Each is
          followed by ``bntr2``..``bntr4``.
        * A 1x1 ``final`` conv over
          ``sum(PLANES[1:4]) + PLANES[0] * BLOCK.expansion`` channels.
          Presumably ``forward`` concatenates the three upsampled maps with
          ``block1``'s output (not visible here; TODO confirm).

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output channels of ``final``.
            config: Not used by this method.
            D: 3 for purely spatial input. Otherwise kernel sizes, strides
                and pixel distances become ``[n, n, n, m]`` lists.
        """
        net_metadata = self.net_metadata

        def space_n_time_m(n, m):
            # Scalar ``n`` when D == 3. Otherwise ``n`` on the first three
            # axes and ``m`` on the fourth (presumably temporal) axis.
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            # Express the output pixel distance per axis as well.
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stem at full resolution (pixel_dist 1).
        # NOTE(review): ``self._make_layer`` presumably updates
        # ``self.inplanes`` (likely to ``planes * BLOCK.expansion``). Every
        # later conv/conv_tr reads ``self.inplanes`` and relies on that side
        # effect, so statement order matters. TODO confirm against
        # _make_layer.
        self.inplanes = self.PLANES[0]
        self.conv1 = conv(in_channels,
                          self.inplanes,
                          pixel_dist=1,
                          kernel_size=space_n_time_m(5, 1),
                          stride=1,
                          dilation=1,
                          bias=False,
                          D=D,
                          net_metadata=net_metadata)

        self.bn1 = nn.BatchNorm1d(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK,
                                       self.PLANES[0],
                                       self.LAYERS[0],
                                       pixel_dist=1)

        # Stage 2: pixel distance 1 -> 2, plus an x2 upsampling branch.
        self.conv2p1s2 = conv(
            self.inplanes,
            self.inplanes,
            pixel_dist=1,
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bn2 = nn.BatchNorm1d(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK,
                                       self.PLANES[1],
                                       self.LAYERS[1],
                                       pixel_dist=space_n_time_m(2, 1))
        self.convtr2p2s2 = conv_tr(
            self.inplanes,
            self.PLANES[1],
            pixel_dist=space_n_time_m(2, 1),
            kernel_size=space_n_time_m(2, 1),
            upsample_stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bntr2 = nn.BatchNorm1d(self.PLANES[1])

        # Stage 3: pixel distance 2 -> 4, plus an x4 upsampling branch.
        self.conv3p2s2 = conv(
            self.inplanes,
            self.inplanes,
            pixel_dist=space_n_time_m(2, 1),
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bn3 = nn.BatchNorm1d(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK,
                                       self.PLANES[2],
                                       self.LAYERS[2],
                                       pixel_dist=space_n_time_m(4, 1))

        self.convtr3p4s4 = conv_tr(
            self.inplanes,
            self.PLANES[2],
            pixel_dist=space_n_time_m(4, 1),
            kernel_size=space_n_time_m(4, 1),
            upsample_stride=space_n_time_m(4, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bntr3 = nn.BatchNorm1d(self.PLANES[2])

        # Stage 4: pixel distance 4 -> 8, plus an x8 upsampling branch.
        self.conv4p4s2 = conv(
            self.inplanes,
            self.inplanes,
            pixel_dist=space_n_time_m(4, 1),
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bn4 = nn.BatchNorm1d(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK,
                                       self.PLANES[3],
                                       self.LAYERS[3],
                                       pixel_dist=space_n_time_m(8, 1))
        self.convtr4p8s8 = conv_tr(
            self.inplanes,
            self.PLANES[3],
            pixel_dist=space_n_time_m(8, 1),
            kernel_size=space_n_time_m(8, 1),
            upsample_stride=space_n_time_m(8, 1),
            dilation=1,
            conv_type=ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS,
            bias=False,
            D=D,
            net_metadata=net_metadata)
        self.bntr4 = nn.BatchNorm1d(self.PLANES[3])

        self.relu = nn.ReLU(inplace=True)

        # Input width = the three upsampled branches + block1's output width.
        self.final = conv(sum(self.PLANES[1:4]) +
                          self.PLANES[0] * self.BLOCK.expansion,
                          out_channels,
                          pixel_dist=1,
                          kernel_size=1,
                          stride=1,
                          dilation=1,
                          bias=True,
                          D=D,
                          net_metadata=net_metadata)
Пример #5
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create the layers of an 8-stage sparse U-Net.

        Encoder:

        * Stem ``conv0p1s1`` / ``bn0`` at full resolution.
        * Four kernel-2 / stride-2 stages ``conv1p1s2``, ``conv2p2s2``,
          ``conv3p4s2`` and ``conv4p8s2``. Each is followed by a norm
          (``bn1``..``bn4``) and a residual layer (``block1``..``block4``).

        Decoder:

        * Four kernel-2 / stride-2 transposed convs ``convtr4p16s2``,
          ``convtr5p8s2``, ``convtr6p4s2`` and ``convtr7p2s2``, each with a
          norm (``bntr4``..``bntr7``).
        * Residual layers ``block5``..``block8``. Their input width adds a
          skip width of ``PLANES[2]``, ``PLANES[1]``, ``PLANES[0]`` (each
          times ``BLOCK.expansion``) and then ``INIT_DIM``. Presumably these
          are concatenations performed in ``forward``, which is not shown.

        A 1x1 ``final`` conv maps ``block8``'s output to ``out_channels``.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output channels of ``final``.
            config: Mapping providing ``'bn_momentum'`` and
                ``'conv1_kernel_size'``.
            D: 3 for purely spatial input. Otherwise kernel sizes and
                strides become ``[n, n, n, m]`` lists.
        """
        # Per-stage dilations and the BN momentum shared by all norms.
        dilations = self.DILATIONS
        bn_momentum = config['bn_momentum']

        def space_n_time_m(n, m):
            # Scalar ``n`` when D == 3. Otherwise ``n`` on the first three
            # axes and ``m`` on the fourth (presumably temporal) axis.
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stem conv. Its INIT_DIM-channel output is counted again in block8's
        # input width below (PLANES[7] + INIT_DIM).
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config['conv1_kernel_size'], 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)

        self.bn0 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)

        # Encoder. self._make_layer presumably advances self.inplanes to the
        # block's output width, which the following conv reads.
        self.conv1p1s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn1 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block1 = self._make_layer(self.BLOCK,
                                       self.PLANES[0],
                                       self.LAYERS[0],
                                       dilation=dilations[0],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv2p2s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn2 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block2 = self._make_layer(self.BLOCK,
                                       self.PLANES[1],
                                       self.LAYERS[1],
                                       dilation=dilations[1],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv3p4s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn3 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block3 = self._make_layer(self.BLOCK,
                                       self.PLANES[2],
                                       self.LAYERS[2],
                                       dilation=dilations[2],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv4p8s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn4 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block4 = self._make_layer(self.BLOCK,
                                       self.PLANES[3],
                                       self.LAYERS[3],
                                       dilation=dilations[3],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # Decoder.
        self.convtr4p16s2 = conv_tr(self.inplanes,
                                    self.PLANES[4],
                                    kernel_size=space_n_time_m(2, 1),
                                    upsample_stride=space_n_time_m(2, 1),
                                    dilation=1,
                                    bias=False,
                                    conv_type=self.NON_BLOCK_CONV_TYPE,
                                    D=D)
        self.bntr4 = get_norm(self.NORM_TYPE,
                              self.PLANES[4],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK,
                                       self.PLANES[4],
                                       self.LAYERS[4],
                                       dilation=dilations[4],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr5p8s2 = conv_tr(self.inplanes,
                                   self.PLANES[5],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr5 = get_norm(self.NORM_TYPE,
                              self.PLANES[5],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK,
                                       self.PLANES[5],
                                       self.LAYERS[5],
                                       dilation=dilations[5],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr6p4s2 = conv_tr(self.inplanes,
                                   self.PLANES[6],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr6 = get_norm(self.NORM_TYPE,
                              self.PLANES[6],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK,
                                       self.PLANES[6],
                                       self.LAYERS[6],
                                       dilation=dilations[6],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.convtr7p2s2 = conv_tr(self.inplanes,
                                   self.PLANES[7],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr7 = get_norm(self.NORM_TYPE,
                              self.PLANES[7],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK,
                                       self.PLANES[7],
                                       self.LAYERS[7],
                                       dilation=dilations[7],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        # block8 emits PLANES[7] * BLOCK.expansion channels. The skip widths
        # above use the same convention. Sizing ``final`` with bare PLANES[7]
        # only matched blocks with expansion == 1.
        self.final = conv(self.PLANES[7] * self.BLOCK.expansion,
                          out_channels,
                          kernel_size=1,
                          stride=1,
                          bias=True,
                          D=D)
        self.relu = MinkowskiReLU(inplace=True)
        self.relu = MinkowskiReLU(inplace=True)
Пример #6
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create the layers of a sparse U-Net with a pooling-transpose head.

        Encoder:

        * Stem ``conv1p1s1`` + ``bn1`` + ``block1``.
        * Three kernel-2 / stride-2 stages ``conv2p1s2``, ``conv3p2s2`` and
          ``conv4p4s2``, each with a norm and a residual layer
          (``block2``..``block4``).

        Decoder:

        * ``convtr4p8s2``, ``convtr5p4s2`` and ``convtr6p2s2`` with norms.
        * ``block5`` and ``block6``, whose input widths add
          ``PLANES[2] * BLOCK.expansion`` and ``PLANES[1] * BLOCK.expansion``
          (presumably skip concatenations done in ``forward``, not shown).

        Head:

        * Unparameterized ``MinkowskiPoolingTranspose`` layers ``pool_tr4``
          (x8), ``pool_tr5`` (x4) and ``pool_tr6`` (x2).
        * ``final``: 1x1 conv to 512 -> BN -> ReLU -> 1x1 conv to
          ``out_channels``.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output channels of the last ``final`` conv.
            config: Object providing ``bn_momentum`` and ``conv1_kernel_size``.
            D: 3 for purely spatial input. Otherwise conv kernel sizes and
                strides become ``[n, n, n, m]`` lists.
        """
        # Per-stage dilations and the BN momentum shared by the get_norm layers.
        dilations = self.DILATIONS
        bn_momentum = config.bn_momentum

        def space_n_time_m(n, m):
            # Scalar ``n`` when D == 3. Otherwise ``n`` on the first three
            # axes and ``m`` on the fourth (presumably temporal) axis.
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stem conv producing INIT_DIM channels.
        # NOTE(review): self._make_layer presumably advances self.inplanes to
        # the block's output width. Later convs and the out_pool* captures
        # read it, so statement order matters. TODO confirm.
        self.inplanes = self.INIT_DIM
        self.conv1p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config.conv1_kernel_size, 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)

        # NOTE(review): sized with PLANES[0], while conv1p1s1 above outputs
        # self.inplanes (= INIT_DIM) channels. These agree only when
        # INIT_DIM == PLANES[0]. Confirm which tensor forward() normalizes
        # with bn1.
        self.bn1 = get_norm(self.NORM_TYPE,
                            self.PLANES[0],
                            D,
                            bn_momentum=bn_momentum)
        self.block1 = self._make_layer(self.BLOCK,
                                       self.PLANES[0],
                                       self.LAYERS[0],
                                       dilation=dilations[0],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv2p1s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn2 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block2 = self._make_layer(self.BLOCK,
                                       self.PLANES[1],
                                       self.LAYERS[1],
                                       dilation=dilations[1],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv3p2s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn3 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block3 = self._make_layer(self.BLOCK,
                                       self.PLANES[2],
                                       self.LAYERS[2],
                                       dilation=dilations[2],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)

        self.conv4p4s2 = conv(self.inplanes,
                              self.inplanes,
                              kernel_size=space_n_time_m(2, 1),
                              stride=space_n_time_m(2, 1),
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn4 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block4 = self._make_layer(self.BLOCK,
                                       self.PLANES[3],
                                       self.LAYERS[3],
                                       dilation=dilations[3],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        # Unlike the convs, the pooling-transpose layers use a scalar
        # kernel/stride even when D == 4.
        self.pool_tr4 = ME.MinkowskiPoolingTranspose(kernel_size=8,
                                                     stride=8,
                                                     dimension=D)
        # NOTE(review): out_pool4 is recorded but not included in final's
        # input width below. Confirm whether pool_tr4's output is meant to
        # be concatenated.
        out_pool4 = self.inplanes
        self.convtr4p8s2 = conv_tr(self.inplanes,
                                   self.PLANES[4],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr4 = get_norm(self.NORM_TYPE,
                              self.PLANES[4],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK,
                                       self.PLANES[4],
                                       self.LAYERS[4],
                                       dilation=dilations[4],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.pool_tr5 = ME.MinkowskiPoolingTranspose(kernel_size=4,
                                                     stride=4,
                                                     dimension=D)
        # Captured after block5 was built, i.e. block5's output width.
        out_pool5 = self.inplanes
        self.convtr5p4s2 = conv_tr(self.inplanes,
                                   self.PLANES[5],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr5 = get_norm(self.NORM_TYPE,
                              self.PLANES[5],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK,
                                       self.PLANES[5],
                                       self.LAYERS[5],
                                       dilation=dilations[5],
                                       norm_type=self.NORM_TYPE,
                                       bn_momentum=bn_momentum)
        self.pool_tr6 = ME.MinkowskiPoolingTranspose(kernel_size=2,
                                                     stride=2,
                                                     dimension=D)
        # Captured after block6 was built, i.e. block6's output width.
        out_pool6 = self.inplanes
        self.convtr6p2s2 = conv_tr(self.inplanes,
                                   self.PLANES[6],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr6 = get_norm(self.NORM_TYPE,
                              self.PLANES[6],
                              D,
                              bn_momentum=bn_momentum)

        self.relu = MinkowskiReLU(inplace=True)

        # Input width = block5 + block6 outputs + convtr6 output + block1
        # output width. NOTE(review): the 512-channel BN here is built
        # directly, so it ignores self.NORM_TYPE and bn_momentum.
        self.final = nn.Sequential(
            conv(out_pool5 + out_pool6 + self.PLANES[6] +
                 self.PLANES[0] * self.BLOCK.expansion,
                 512,
                 kernel_size=1,
                 bias=False,
                 D=D), ME.MinkowskiBatchNorm(512), ME.MinkowskiReLU(),
            conv(512, out_channels, kernel_size=1, bias=True, D=D))
Пример #7
0
    def network_initialization(self, in_channels, out_channels, config, D):

        if not isinstance(self.BLOCK, list):  # if single type
            self.BLOCK = [self.BLOCK] * len(self.PLANES)
        # print('IN CHANNEL inside the model {}'.format(in_channels))

        # Setup net_metadata
        dilations = self.DILATIONS
        bn_momentum = config.bn_momentum

        block_noexpansion = True

        def space_n_time_m(n, m):
            return n if D == 3 else [n, n, n, m]

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        if config.xyz_input is not None:
            if config.xyz_input:
                in_channels = in_channels + 3

        if config.dataset == 'SemanticKITTI':
            in_channels = 4
        if config.dataset == 'S3DIS':
            in_channels = 9

        # Output of the first conv concated to conv6
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config.conv1_kernel_size, 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)

        self.bn0 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)

        self.conv1p1s2 = conv(
            self.inplanes,
            self.inplanes if not block_noexpansion else self.PLANES[0],
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D)
        self.inplanes = self.inplanes if not block_noexpansion else self.PLANES[
            0]
        self.bn1 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)

        self.block1 = self._make_layer(
            self.BLOCK[0],
            self.PLANES[0],
            self.LAYERS[0],
            dilation=dilations[0],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)

        self.conv2p2s2 = conv(
            self.inplanes,
            self.inplanes if not block_noexpansion else self.PLANES[1],
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D)
        self.inplanes = self.inplanes if not block_noexpansion else self.PLANES[
            1]
        self.bn2 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block2 = self._make_layer(
            self.BLOCK[1],
            self.PLANES[1],
            self.LAYERS[1],
            dilation=dilations[1],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)

        self.conv3p4s2 = conv(
            self.inplanes,
            self.inplanes if not block_noexpansion else self.PLANES[2],
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D)
        self.inplanes = self.inplanes if not block_noexpansion else self.PLANES[
            2]
        self.bn3 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block3 = self._make_layer(
            self.BLOCK[2],
            self.PLANES[2],
            self.LAYERS[2],
            dilation=dilations[2],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)

        self.conv4p8s2 = conv(
            self.inplanes,
            self.inplanes if not block_noexpansion else self.PLANES[3],
            kernel_size=space_n_time_m(2, 1),
            stride=space_n_time_m(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D)
        self.inplanes = self.inplanes if not block_noexpansion else self.PLANES[
            3]
        self.bn4 = get_norm(self.NORM_TYPE,
                            self.inplanes,
                            D,
                            bn_momentum=bn_momentum)
        self.block4 = self._make_layer(
            self.BLOCK[3],
            self.PLANES[3],
            self.LAYERS[3],
            dilation=dilations[3],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)
        self.convtr4p16s2 = conv_tr(self.inplanes,
                                    self.PLANES[4],
                                    kernel_size=space_n_time_m(2, 1),
                                    upsample_stride=space_n_time_m(2, 1),
                                    dilation=1,
                                    bias=False,
                                    conv_type=self.NON_BLOCK_CONV_TYPE,
                                    D=D)
        self.bntr4 = get_norm(self.NORM_TYPE,
                              self.PLANES[4],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[
            4] + self.PLANES[2] * self.BLOCK[4].expansion
        self.block5 = self._make_layer(
            self.BLOCK[4],
            self.PLANES[4],
            self.LAYERS[4],
            dilation=dilations[4],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)
        self.convtr5p8s2 = conv_tr(self.inplanes,
                                   self.PLANES[5],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr5 = get_norm(self.NORM_TYPE,
                              self.PLANES[5],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[
            5] + self.PLANES[1] * self.BLOCK[5].expansion
        self.block6 = self._make_layer(
            self.BLOCK[5],
            self.PLANES[5],
            self.LAYERS[5],
            dilation=dilations[5],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)
        self.convtr6p4s2 = conv_tr(self.inplanes,
                                   self.PLANES[6],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr6 = get_norm(self.NORM_TYPE,
                              self.PLANES[6],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[
            6] + self.PLANES[0] * self.BLOCK[6].expansion
        self.block7 = self._make_layer(
            self.BLOCK[6],
            self.PLANES[6],
            self.LAYERS[6],
            dilation=dilations[6],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)
        self.convtr7p2s2 = conv_tr(self.inplanes,
                                   self.PLANES[7],
                                   kernel_size=space_n_time_m(2, 1),
                                   upsample_stride=space_n_time_m(2, 1),
                                   dilation=1,
                                   bias=False,
                                   conv_type=self.NON_BLOCK_CONV_TYPE,
                                   D=D)
        self.bntr7 = get_norm(self.NORM_TYPE,
                              self.PLANES[7],
                              D,
                              bn_momentum=bn_momentum)

        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(
            self.BLOCK[7],
            self.PLANES[7],
            self.LAYERS[7],
            dilation=dilations[7],
            norm_type=self.NORM_TYPE,
            nonlinearity_type=self.config.nonlinearity,
            bn_momentum=bn_momentum)

        self.final = conv(self.PLANES[7],
                          out_channels,
                          kernel_size=1,
                          stride=1,
                          bias=True,
                          D=D)

        if config.enable_point_branch:
            self.point_transform_mlp = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(self.INIT_DIM, self.PLANES[3]),
                    # nn.BatchNorm1d(self.PLANES[3]),
                    nn.ReLU(True)),
                nn.Sequential(
                    nn.Linear(self.PLANES[3], self.PLANES[5]),
                    #nn.BatchNorm1d(self.PLANES[5]),
                    nn.ReLU(True)),
                nn.Sequential(
                    nn.Linear(self.PLANES[5], self.PLANES[7]),
                    # nn.BatchNorm1d(self.PLANES[7]),
                    nn.ReLU(True))
            ])
            self.downsample16x = nn.Sequential(
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3),
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3),
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3),
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3))
            self.downsample4x = nn.Sequential(
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3),
                ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=3))
            self.dropout = nn.Dropout(0.1, True)
            # self.dropout = nn.Dropout(0.3, True)
            self.interpolate = ME.MinkowskiInterpolation(
                return_kernel_map=False, return_weights=False)
Пример #8
0
Файл: unet.py Проект: jgwak/GSDN
    def __init__(self,
                 in_channels,
                 out_channels,
                 config,
                 D=3,
                 return_feat=False,
                 **kwargs):
        """Build the layers of a six-stage sparse U-Net.

        Encoder stage ``i`` (1..6) is a kernel-size-3, stride-1, bias-free
        conv plus batch norm at ``PLANES[i-1]`` channels (stage 1 reads
        ``in_channels``). It is followed by ``down{i}``, a kernel-size-2 /
        stride-2 conv to ``PLANES[i]`` channels. ``conv7``/``bn7`` operate at
        ``PLANES[6]``.

        Decoder stage ``i`` (6..1) is ``up{i}``, a kernel-size-2 transposed
        conv with upsample stride 2 from ``PLANES[i]`` to ``PLANES[i-1]``. It
        is followed by ``conv_up{i}``, whose input width is
        ``PLANES[i-1] * 2``, and ``bn_up{i}``. The doubled width presumably
        holds the upsampled features concatenated with the matching encoder
        output; the forward pass is not visible here, so confirm there.

        Args:
            in_channels: Number of input feature channels. Also used as the
                output width of the ``mask_feat`` kernel-1 head.
            out_channels: Output width of the ``final`` kernel-1 head.
            config: Forwarded unchanged to the base-class constructor.
            D: Dimension argument passed to every conv / transposed conv.
            return_feat: Stored as ``self.return_feat``; how it is used is
                decided outside this method.
            **kwargs: Accepted and ignored.
        """
        super(UNet, self).__init__(in_channels, out_channels, config, D)
        self.in_channels = in_channels
        self.return_feat = return_feat
        planes = self.PLANES

        # Factories for the layer shapes repeated throughout the network.
        # Submodules below are created in a fixed order. Keep that order if
        # editing: it determines registration order and init RNG draws.
        def conv3(cin, cout):
            # Kernel-size-3, stride-1 conv without bias (used inside stages).
            return conv(cin, cout, kernel_size=3, stride=1, bias=False, D=D)

        def downsample(cin, cout):
            # Kernel-size-2, stride-2 conv between encoder stages.
            return conv(cin, cout, kernel_size=2, stride=2, D=D)

        def upsample(cin, cout):
            # Kernel-size-2 transposed conv with upsample stride 2.
            return conv_tr(cin, cout, kernel_size=2, upsample_stride=2, D=D)

        def head(cin, cout):
            # Kernel-size-1 conv with bias, used for both output heads.
            return conv(cin, cout, kernel_size=1, bias=True, D=D)

        norm = ME.MinkowskiBatchNorm

        # Encoder.
        self.conv_down1 = conv3(in_channels, planes[0])
        self.bn_down1 = norm(planes[0])
        self.down1 = downsample(planes[0], planes[1])
        self.conv_down2 = conv3(planes[1], planes[1])
        self.bn_down2 = norm(planes[1])
        self.down2 = downsample(planes[1], planes[2])
        self.conv_down3 = conv3(planes[2], planes[2])
        self.bn_down3 = norm(planes[2])
        self.down3 = downsample(planes[2], planes[3])
        self.conv_down4 = conv3(planes[3], planes[3])
        self.bn_down4 = norm(planes[3])
        self.down4 = downsample(planes[3], planes[4])
        self.conv_down5 = conv3(planes[4], planes[4])
        self.bn_down5 = norm(planes[4])
        self.down5 = downsample(planes[4], planes[5])
        self.conv_down6 = conv3(planes[5], planes[5])
        self.bn_down6 = norm(planes[5])
        self.down6 = downsample(planes[5], planes[6])

        # Innermost stage.
        self.conv7 = conv3(planes[6], planes[6])
        self.bn7 = norm(planes[6])

        # Decoder: each conv_up{i} takes twice the width of its up{i} output.
        self.up6 = upsample(planes[6], planes[5])
        self.conv_up6 = conv3(planes[5] * 2, planes[5])
        self.bn_up6 = norm(planes[5])
        self.up5 = upsample(planes[5], planes[4])
        self.conv_up5 = conv3(planes[4] * 2, planes[4])
        self.bn_up5 = norm(planes[4])
        self.up4 = upsample(planes[4], planes[3])
        self.conv_up4 = conv3(planes[3] * 2, planes[3])
        self.bn_up4 = norm(planes[3])
        self.up3 = upsample(planes[3], planes[2])
        self.conv_up3 = conv3(planes[2] * 2, planes[2])
        self.bn_up3 = norm(planes[2])
        self.up2 = upsample(planes[2], planes[1])
        self.conv_up2 = conv3(planes[1] * 2, planes[1])
        self.bn_up2 = norm(planes[1])
        self.up1 = upsample(planes[1], planes[0])
        self.conv_up1 = conv3(planes[0] * 2, planes[0])
        self.bn_up1 = norm(planes[0])

        # Output heads and shared activation.
        self.mask_feat = head(planes[0], self.in_channels)
        self.final = head(planes[0], out_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)