예제 #1
0
    def __init__(self,
                 inplanes,
                 intermediate_inplanes,
                 intermediate_outplanes,
                 outplanes,
                 intermediate_module,
                 BLOCK=None,
                 reps=1,
                 conv_type=ConvType.HYPERCUBE,
                 norm_type=NormType.BATCH_NORM,
                 bn_momentum=0.1,
                 D=3):
        """Build one level of the recursive U-Net.

        Registers, in this order: ``block`` (a ``BLOCK`` at ``inplanes``
        width), the ``down`` conv and ``down_norm``, the nested
        ``intermediate`` module, the ``up`` transposed conv and ``up_norm``,
        and ``reps`` trailing blocks named ``end_blocks0`` ...
        ``end_blocks{reps - 1}``.

        Args:
            inplanes: channel width entering this level.
            intermediate_inplanes: width produced by ``down`` (presumably the
                width ``intermediate_module`` expects).
            intermediate_outplanes: width consumed by ``up`` (presumably the
                width ``intermediate_module`` produces).
            outplanes: width produced by ``up`` and by every ``end_blocks*``.
            intermediate_module: nested module stored as ``self.intermediate``.
            BLOCK: block class used for ``block`` and ``end_blocks*``. It is
                called unconditionally, so the ``None`` default is not usable.
            reps: number of ``end_blocks*`` modules to create.
            conv_type: forwarded to ``BLOCK``, ``down`` and ``up``.
            norm_type: norm kind for ``down_norm``, ``up_norm`` and the 1x1
                projection attached to ``end_blocks0``.
            bn_momentum: momentum forwarded to every norm and block.
            D: dimension forwarded to every layer.
        """
        super(UBlock, self).__init__()

        self.block = BLOCK(inplanes, inplanes,
                           conv_type=conv_type, bn_momentum=bn_momentum, D=D)

        # Down path: conv with kernel/stride space_n_time_m(2, 1, D), then norm.
        self.down = conv(inplanes, intermediate_inplanes,
                         kernel_size=space_n_time_m(2, 1, D),
                         stride=space_n_time_m(2, 1, D),
                         conv_type=conv_type, D=D)
        self.down_norm = get_norm(norm_type, intermediate_inplanes, D,
                                  bn_momentum=bn_momentum)

        self.intermediate = intermediate_module

        # Up path: transposed conv back to ``outplanes``, then norm.
        self.up = conv_tr(intermediate_outplanes, outplanes,
                          kernel_size=space_n_time_m(2, 1, D),
                          upsample_stride=space_n_time_m(2, 1, D),
                          conv_type=conv_type, D=D)
        self.up_norm = get_norm(norm_type, outplanes, D,
                                bn_momentum=bn_momentum)

        self.reps = reps
        merged = inplanes + outplanes
        for idx in range(reps):
            if idx == 0:
                # The first end block receives ``inplanes + outplanes``
                # channels (presumably skip and upsampled features
                # concatenated in forward), so it gets a 1x1 conv + norm
                # projection to ``outplanes`` as its ``downsample`` branch.
                projection = nn.Sequential(
                    conv(merged, outplanes, kernel_size=1, bias=False, D=D),
                    get_norm(norm_type, outplanes, D, bn_momentum=bn_momentum),
                )
                width_in = merged
            else:
                projection = None
                width_in = outplanes
            setattr(self, f'end_blocks{idx}',
                    BLOCK(width_in, outplanes,
                          downsample=projection,
                          conv_type=conv_type,
                          bn_momentum=bn_momentum,
                          D=D))
예제 #2
0
    def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
        """Assemble the recursive U-Net out of nested ``UBlock`` levels.

        ``self.PLANES`` is walked in reverse. Entry 0 of the reversed list
        sizes the innermost module, a single ``self.BLOCK``. Every later
        entry wraps the module built so far in a ``UBlock``, using its own
        ``[0]``/``[1]`` widths as the level's in/out widths and the previous
        entry's ``[0]``/``[1]`` as the nested module's widths.

        Args:
            in_channels: input feature channels for ``conv1``.
            out_channels: channels produced by ``self.final``.
            config: object exposing ``bn_momentum`` as an attribute.
            D: dimension forwarded to the layers; when 4, ``OUT_PIXEL_DIST``
                is expanded with ``space_n_time_m``.
            **kwargs: accepted and ignored.
        """
        super(RecUNetBase, self).__init__(in_channels, out_channels, config, D)

        planes = self.PLANES[::-1]
        bn_momentum = config.bn_momentum

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1, D)

        # Stem: map the input to the outermost level's input width.
        stem_width = planes[-1][0]
        self.conv1 = conv(in_channels, stem_width,
                          kernel_size=space_n_time_m(3, 1, D),
                          conv_type=self.CONV_TYPE, D=D)
        self.norm1 = get_norm(self.NORM_TYPE, stem_width, D, bn_momentum)

        # Innermost level: a plain block at constant width.
        core_width = planes[0][0]
        module = self.BLOCK(core_width, core_width,
                            conv_type=self.CONV_TYPE, D=self.D)

        # Wrap outward one level at a time; REPS is consumed from its end.
        n_reps = len(self.REPS)
        for depth in range(1, len(planes)):
            outer = planes[depth]
            inner = planes[depth - 1]
            module = UBlock(outer[0],
                            inner[0],
                            inner[1],
                            outer[1],
                            intermediate_module=module,
                            BLOCK=self.BLOCK,
                            reps=self.REPS[n_reps - depth],
                            conv_type=self.CONV_TYPE,
                            bn_momentum=bn_momentum,
                            D=D)
        self.unet = module

        self.final = conv(planes[-1][1], out_channels,
                          kernel_size=1, stride=1, dilation=1,
                          bias=True, D=D)
        self.relu = MinkowskiReLU(inplace=True)
예제 #3
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create the encoder/decoder layers of the residual U-Net.

        Encoder: a stride-1 stem (``conv0p1s1``/``bn0``), then four stages,
        each a stride-2 conv, a norm and a residual stage from
        ``self._make_layer``.

        Decoder: four stages, each a stride-2 transposed conv and a norm,
        then a residual stage. Each stage's input width is the
        transposed-conv output plus the width of a matching encoder feature,
        presumably concatenated in ``forward``.

        ``self.final`` is a 1x1 conv to ``out_channels``.

        ``self.inplanes`` acts as a running "current width".
        ``self._make_layer`` presumably advances it to
        ``planes * BLOCK.expansion``, the convention the skip widths below
        rely on. The decoder resets it before each concatenation stage.

        Args:
            in_channels: input feature channels for ``conv0p1s1``.
            out_channels: channels produced by ``self.final``.
            config: mapping providing ``'bn_momentum'`` and
                ``'conv1_kernel_size'``.
            D: dimension. For D != 3, kernel sizes/strides are passed as
                ``[n, n, n, m]``. When D == 4, ``OUT_PIXEL_DIST`` is expanded
                the same way.
        """
        dilations = self.DILATIONS
        bn_momentum = config['bn_momentum']

        def space_n_time_m(n, m):
            # Scalar size for D == 3, otherwise three spatial entries ``n``
            # followed by one extra entry ``m``.
            return n if D == 3 else [n, n, n, m]

        def _norm(channels):
            # Norm layer configured like every other norm in this network.
            return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

        def _down():
            # Stride-2 conv that keeps the current width (self.inplanes).
            return conv(self.inplanes,
                        self.inplanes,
                        kernel_size=space_n_time_m(2, 1),
                        stride=space_n_time_m(2, 1),
                        dilation=1,
                        conv_type=self.NON_BLOCK_CONV_TYPE,
                        D=D)

        def _up(planes):
            # Stride-2 transposed conv from the current width to ``planes``.
            return conv_tr(self.inplanes,
                           planes,
                           kernel_size=space_n_time_m(2, 1),
                           upsample_stride=space_n_time_m(2, 1),
                           dilation=1,
                           bias=False,
                           conv_type=self.NON_BLOCK_CONV_TYPE,
                           D=D)

        def _layer(idx):
            # Residual stage ``idx`` built by the base class helper.
            return self._make_layer(self.BLOCK,
                                    self.PLANES[idx],
                                    self.LAYERS[idx],
                                    dilation=dilations[idx],
                                    norm_type=self.NORM_TYPE,
                                    bn_momentum=bn_momentum)

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stem (stride 1). Its output is concatenated back in before block8
        # (see the INIT_DIM term below).
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config['conv1_kernel_size'], 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        self.bn0 = _norm(self.inplanes)

        # Encoder. Module creation/registration order is kept identical to
        # the original so initialization and state_dict ordering match.
        self.conv1p1s2 = _down()
        self.bn1 = _norm(self.inplanes)
        self.block1 = _layer(0)

        self.conv2p2s2 = _down()
        self.bn2 = _norm(self.inplanes)
        self.block2 = _layer(1)

        self.conv3p4s2 = _down()
        self.bn3 = _norm(self.inplanes)
        self.block3 = _layer(2)

        self.conv4p8s2 = _down()
        self.bn4 = _norm(self.inplanes)
        self.block4 = _layer(3)

        # Decoder: upsample, then widen for the matching encoder feature.
        self.convtr4p16s2 = _up(self.PLANES[4])
        self.bntr4 = _norm(self.PLANES[4])
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = _layer(4)

        self.convtr5p8s2 = _up(self.PLANES[5])
        self.bntr5 = _norm(self.PLANES[5])
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = _layer(5)

        self.convtr6p4s2 = _up(self.PLANES[6])
        self.bntr6 = _norm(self.PLANES[6])
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = _layer(6)

        self.convtr7p2s2 = _up(self.PLANES[7])
        self.bntr7 = _norm(self.PLANES[7])
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = _layer(7)

        # block8 emits PLANES[7] * BLOCK.expansion channels, the same
        # convention the skip widths above use for block1..block3. Sizing the
        # head from PLANES[7] alone broke every BLOCK with expansion != 1;
        # for expansion == 1 this is unchanged.
        self.final = conv(self.PLANES[7] * self.BLOCK.expansion,
                          out_channels,
                          kernel_size=1,
                          stride=1,
                          bias=True,
                          D=D)
        self.relu = MinkowskiReLU(inplace=True)
예제 #4
0
    def network_initialization(self, in_channels, out_channels, config, D):
        """Create the layers of a residual U-Net with a multi-scale head.

        Encoder: a stride-1 stem (``conv1p1s1``/``bn1``/``block1``), then
        three stages, each a stride-2 conv, a norm and a residual stage from
        ``self._make_layer``.

        Decoder: three stages, each a stride-2 transposed conv and a norm,
        followed by a residual stage whose input width includes a matching
        encoder feature.

        ``pool_tr4``/``pool_tr5``/``pool_tr6`` are pooling-transpose layers
        with factors 8/4/2. The head ``self.final`` (1x1 conv -> 512 -> BN ->
        ReLU -> 1x1 conv) takes the summed width of several features,
        presumably brought to a common resolution and concatenated in
        ``forward``.

        Args:
            in_channels: input feature channels for ``conv1p1s1``.
            out_channels: channels produced by ``self.final``.
            config: object exposing ``bn_momentum`` and ``conv1_kernel_size``
                as attributes.
            D: dimension. For D != 3, conv kernel sizes/strides are passed as
                ``[n, n, n, m]``. When D == 4, ``OUT_PIXEL_DIST`` is expanded
                the same way.
        """
        dilations = self.DILATIONS
        bn_momentum = config.bn_momentum

        def space_n_time_m(n, m):
            # Scalar size for D == 3, otherwise three spatial entries ``n``
            # followed by one extra entry ``m``.
            return n if D == 3 else [n, n, n, m]

        def _norm(channels):
            # Norm layer configured like every other norm in this network.
            return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

        def _down():
            # Stride-2 conv that keeps the current width (self.inplanes).
            return conv(self.inplanes,
                        self.inplanes,
                        kernel_size=space_n_time_m(2, 1),
                        stride=space_n_time_m(2, 1),
                        dilation=1,
                        conv_type=self.NON_BLOCK_CONV_TYPE,
                        D=D)

        def _up(planes):
            # Stride-2 transposed conv from the current width to ``planes``.
            return conv_tr(self.inplanes,
                           planes,
                           kernel_size=space_n_time_m(2, 1),
                           upsample_stride=space_n_time_m(2, 1),
                           dilation=1,
                           bias=False,
                           conv_type=self.NON_BLOCK_CONV_TYPE,
                           D=D)

        def _layer(idx):
            # Residual stage ``idx``. _make_layer presumably advances
            # self.inplanes to PLANES[idx] * BLOCK.expansion.
            return self._make_layer(self.BLOCK,
                                    self.PLANES[idx],
                                    self.LAYERS[idx],
                                    dilation=dilations[idx],
                                    norm_type=self.NORM_TYPE,
                                    bn_momentum=bn_momentum)

        if D == 4:
            self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

        # Stem (stride 1).
        self.inplanes = self.INIT_DIM
        self.conv1p1s1 = conv(in_channels,
                              self.inplanes,
                              kernel_size=space_n_time_m(
                                  config.conv1_kernel_size, 1),
                              stride=1,
                              dilation=1,
                              conv_type=self.NON_BLOCK_CONV_TYPE,
                              D=D)
        # bn1 pairs with conv1p1s1, whose output has INIT_DIM (self.inplanes)
        # channels, as bn2..bn4 pair with their convs. It was sized
        # PLANES[0], which only worked when INIT_DIM == PLANES[0].
        self.bn1 = _norm(self.inplanes)
        self.block1 = _layer(0)

        # Encoder. Module creation/registration order is kept identical to
        # the original so initialization and state_dict ordering match.
        self.conv2p1s2 = _down()
        self.bn2 = _norm(self.inplanes)
        self.block2 = _layer(1)

        self.conv3p2s2 = _down()
        self.bn3 = _norm(self.inplanes)
        self.block3 = _layer(2)

        self.conv4p4s2 = _down()
        self.bn4 = _norm(self.inplanes)
        self.block4 = _layer(3)

        # NOTE(review): pool_tr* use plain int kernel/stride, unlike the
        # space_n_time_m sizes above; for D == 4 this presumably also scales
        # the last axis. Confirm intent.
        # NOTE(review): block4's width (pool_tr4's presumed input) is not
        # counted in self.final's input width below; confirm forward does
        # not concatenate pool_tr4's output.
        self.pool_tr4 = ME.MinkowskiPoolingTranspose(kernel_size=8,
                                                     stride=8,
                                                     dimension=D)
        self.convtr4p8s2 = _up(self.PLANES[4])
        self.bntr4 = _norm(self.PLANES[4])

        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = _layer(4)
        self.pool_tr5 = ME.MinkowskiPoolingTranspose(kernel_size=4,
                                                     stride=4,
                                                     dimension=D)
        out_pool5 = self.inplanes  # block5 output width
        self.convtr5p4s2 = _up(self.PLANES[5])
        self.bntr5 = _norm(self.PLANES[5])

        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = _layer(5)
        self.pool_tr6 = ME.MinkowskiPoolingTranspose(kernel_size=2,
                                                     stride=2,
                                                     dimension=D)
        out_pool6 = self.inplanes  # block6 output width
        self.convtr6p2s2 = _up(self.PLANES[6])
        self.bntr6 = _norm(self.PLANES[6])

        self.relu = MinkowskiReLU(inplace=True)

        # Head input: block5 + block6 + convtr6 + block1 widths.
        head_in = (out_pool5 + out_pool6 + self.PLANES[6] +
                   self.PLANES[0] * self.BLOCK.expansion)
        self.final = nn.Sequential(
            conv(head_in, 512, kernel_size=1, bias=False, D=D),
            ME.MinkowskiBatchNorm(512),
            ME.MinkowskiReLU(),
            conv(512, out_channels, kernel_size=1, bias=True, D=D))
  def network_initialization(self, in_channels, out_channels, config, D):
    """Create the encoder layers and the layers of a classification head.

    Builds:

    * a stride-1 stem (``conv0p1s1``/``bn0``);
    * four encoder stages, each a stride-2 conv, a norm and a residual stage
      from ``self._make_layer``;
    * classification-head layers: a global average pooling, a global max
      pooling, two 3x3 stride-2 conv + norm pairs of width 512, and a 1x1
      conv to ``config['clf_num_labels']`` outputs.

    Args:
      in_channels: input feature channels for ``conv0p1s1``.
      out_channels: not referenced in the visible part of this method.
      config: mapping providing ``'bn_momentum'``, ``'conv1_kernel_size'``
        and ``'clf_num_labels'``.
      D: dimension. For D != 3, stem/encoder kernel sizes and strides are
        passed as ``[n, n, n, m]``. When D == 4, ``OUT_PIXEL_DIST`` is
        expanded the same way.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = config['bn_momentum']

    def space_n_time_m(n, m):
      # Scalar size for D == 3, otherwise three spatial entries ``n``
      # followed by one extra entry ``m``.
      return n if D == 3 else [n, n, n, m]

    if D == 4:
      self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

    # Stem (stride 1).
    # NOTE(review): the original comment here ("Output of the first conv
    # concated to conv6") appears copied from a U-Net variant; this method
    # builds no decoder.
    self.inplanes = self.INIT_DIM
    self.conv0p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_n_time_m(config['conv1_kernel_size'], 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

    self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)

    # Encoder stage 1: stride-2 conv at the current width, norm, residual
    # stage. _make_layer presumably advances self.inplanes to the stage's
    # output width, which the next stage's conv then reads.
    self.conv1p1s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block1 = self._make_layer(
        self.BLOCK,
        self.PLANES[0],
        self.LAYERS[0],
        dilation=dilations[0],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 2.
    self.conv2p2s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block2 = self._make_layer(
        self.BLOCK,
        self.PLANES[1],
        self.LAYERS[1],
        dilation=dilations[1],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 3.
    self.conv3p4s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block3 = self._make_layer(
        self.BLOCK,
        self.PLANES[2],
        self.LAYERS[2],
        dilation=dilations[2],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    # Encoder stage 4.
    self.conv4p8s2 = conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
    self.block4 = self._make_layer(
        self.BLOCK,
        self.PLANES[3],
        self.LAYERS[3],
        dilation=dilations[3],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

    self.relu = MinkowskiReLU(inplace=True)
    # add a classification head here
    # Global average and global max pooling layers for the head.
    self.clf_glob_avg = ME.MinkowskiGlobalPooling(dimension=D)
    self.clf_glob_max=ME.MinkowskiGlobalMaxPooling(dimension=D)
    # NOTE(review): the input width 256 is hard-coded. Presumably it should
    # match block4's output width (self.inplanes at this point), which
    # equals 256 only for particular PLANES/BLOCK choices. Confirm against
    # forward.
    # NOTE(review): kernel_size/stride here are plain ints, unlike the
    # space_n_time_m(...) sizes above; for D == 4 this presumably also
    # strides the last axis. Confirm intent.
    self.clf_conv0 = conv(
        256,
        512,
        kernel_size=3,
        stride=2,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.clf_bn0 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
    self.clf_conv1 = conv(
        512,
        512,
        kernel_size=3,
        stride=2,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)
    self.clf_bn1 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
    # 1x1 conv from 512 channels to one output per classification label.
    self.clf_conv2 = conv(
        512,
        config['clf_num_labels'],
        kernel_size=1,
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)