def __init__(self,
             inplanes,
             intermediate_inplanes,
             intermediate_outplanes,
             outplanes,
             intermediate_module,
             BLOCK=None,
             reps=1,
             conv_type=ConvType.HYPERCUBE,
             norm_type=NormType.BATCH_NORM,
             bn_momentum=0.1,
             D=3):
  """One level of a recursively nested U-Net.

  Sub-modules are created in this order (kept stable because it fixes
  parameter registration and initialization order):

    * ``block``: a ``BLOCK`` mapping ``inplanes`` -> ``inplanes`` channels.
    * ``down`` / ``down_norm``: a strided conv from ``inplanes`` to
      ``intermediate_inplanes`` channels, plus its normalization layer.
    * ``intermediate``: the already-built inner module, stored as given.
    * ``up`` / ``up_norm``: a transposed conv from ``intermediate_outplanes``
      to ``outplanes`` channels, plus its normalization layer.
    * ``end_blocks0`` .. ``end_blocks{reps - 1}``: ``reps`` trailing
      ``BLOCK`` modules. The first one takes ``inplanes + outplanes``
      channels (presumably skip features concatenated with the upsampled
      ones; ``forward`` is not shown here) and receives a 1x1 conv + norm
      as its ``downsample``; the others map ``outplanes`` -> ``outplanes``.

  Args:
    inplanes: channels entering this level.
    intermediate_inplanes: channels fed into ``intermediate_module``.
    intermediate_outplanes: channels produced by ``intermediate_module``.
    outplanes: channels leaving this level.
    intermediate_module: inner module wrapped by this level.
    BLOCK: block class used for ``block`` and the trailing blocks. It is
      called unconditionally, so the ``None`` default must be overridden.
    reps: number of trailing ``BLOCK`` modules.
    conv_type: forwarded to the strided conv, transposed conv and blocks.
    norm_type: normalization kind passed to ``get_norm``.
    bn_momentum: momentum forwarded to every normalization layer and block.
    D: spatial(-temporal) dimensionality.
  """
  super(UBlock, self).__init__()

  self.block = BLOCK(inplanes, inplanes, conv_type=conv_type, bn_momentum=bn_momentum, D=D)

  # Down path into the wrapped module.
  self.down = conv(
      inplanes,
      intermediate_inplanes,
      kernel_size=space_n_time_m(2, 1, D),
      stride=space_n_time_m(2, 1, D),
      conv_type=conv_type,
      D=D)
  self.down_norm = get_norm(norm_type, intermediate_inplanes, D, bn_momentum=bn_momentum)

  self.intermediate = intermediate_module

  # Up path out of the wrapped module.
  self.up = conv_tr(
      intermediate_outplanes,
      outplanes,
      kernel_size=space_n_time_m(2, 1, D),
      upsample_stride=space_n_time_m(2, 1, D),
      conv_type=conv_type,
      D=D)
  self.up_norm = get_norm(norm_type, outplanes, D, bn_momentum=bn_momentum)

  self.reps = reps
  for rep in range(reps):
    if rep == 0:
      # Input width differs from outplanes, so the block gets a 1x1 conv +
      # norm as its `downsample` to match the widths.
      block_in = inplanes + outplanes
      projection = nn.Sequential(
          conv(block_in, outplanes, kernel_size=1, bias=False, D=D),
          get_norm(norm_type, outplanes, D, bn_momentum=bn_momentum),
      )
    else:
      block_in = outplanes
      projection = None
    setattr(
        self, f'end_blocks{rep}',
        BLOCK(
            block_in,
            outplanes,
            downsample=projection,
            conv_type=conv_type,
            bn_momentum=bn_momentum,
            D=D))
def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
  """Build a U-Net by nesting ``UBlock`` levels from the innermost outward.

  ``self.PLANES`` is read as a sequence of per-level channel pairs (only
  entries ``[0]`` and ``[1]`` of each item are used), with ``self.PLANES[0]``
  describing the outermost level (it sizes the stem and final conv). The
  sequence is walked in reverse:

    * the innermost level is a single ``self.BLOCK``;
    * each further entry wraps everything built so far in a ``UBlock``,
      whose ``reps`` comes from ``self.REPS`` indexed from the end.

  Args:
    in_channels: input feature channels.
    out_channels: channels produced by ``self.final``.
    config: object whose ``bn_momentum`` attribute sets normalization momentum.
    D: dimensionality; when ``D == 4`` ``OUT_PIXEL_DIST`` is expanded per axis.
    **kwargs: accepted for interface compatibility and otherwise ignored.
  """
  super(RecUNetBase, self).__init__(in_channels, out_channels, config, D)
  levels = self.PLANES[::-1]  # innermost level first
  bn_momentum = config.bn_momentum

  if D == 4:
    self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1, D)

  # Stem conv + norm, sized by the outermost level's first width.
  stem_width = levels[-1][0]
  self.conv1 = conv(
      in_channels,
      stem_width,
      kernel_size=space_n_time_m(3, 1, D),
      conv_type=self.CONV_TYPE,
      D=D)
  self.norm1 = get_norm(self.NORM_TYPE, stem_width, D, bn_momentum)

  # NOTE(review): the innermost block is not given bn_momentum, uses self.D
  # rather than D, and the UBlocks below are not given norm_type (so they use
  # UBlock's default); confirm these are intentional.
  nested = self.BLOCK(levels[0][0], levels[0][0], conv_type=self.CONV_TYPE, D=self.D)
  for depth, (inner, outer) in enumerate(zip(levels, levels[1:])):
    nested = UBlock(
        outer[0],  # inplanes
        inner[0],  # intermediate_inplanes
        inner[1],  # intermediate_outplanes
        outer[1],  # outplanes
        intermediate_module=nested,
        BLOCK=self.BLOCK,
        reps=self.REPS[len(self.REPS) - depth - 1],
        conv_type=self.CONV_TYPE,
        bn_momentum=bn_momentum,
        D=D)
  self.unet = nested

  self.final = conv(
      levels[-1][1],
      out_channels,
      kernel_size=1,
      stride=1,
      dilation=1,
      bias=True,
      D=D)
  self.relu = MinkowskiReLU(inplace=True)
def network_initialization(self, in_channels, out_channels, config, D):
  """Create the layers of the residual U-Net (stem, 4 encoder, 4 decoder stages).

  Modules are created in a fixed order:

    1. stem ``conv0p1s1``/``bn0``;
    2. encoder ``conv{k}``/``bn{k}``/``block{k}`` for k = 1..4;
    3. decoder ``convtr*``/``bntr*``/``block5..8``;
    4. ``final`` and ``relu``.

  The order is preserved because it determines parameter registration and
  initialization order.

  Channel bookkeeping relies on ``self._make_layer`` updating
  ``self.inplanes`` (presumably to ``planes * self.BLOCK.expansion``; it is
  not shown here). Every strided conv reads ``self.inplanes`` for its width.

  Args:
    in_channels: input feature channels.
    out_channels: channels produced by ``self.final``.
    config: mapping providing ``'bn_momentum'`` and ``'conv1_kernel_size'``.
    D: dimensionality. Kernel/stride scalars ``n`` are used as-is for
      ``D == 3`` and expanded to ``[n, n, n, m]`` otherwise (``m = 1`` for
      the strided layers); ``OUT_PIXEL_DIST`` is expanded the same way when
      ``D == 4``.
  """
  dilations = self.DILATIONS
  bn_momentum = config['bn_momentum']

  def space_n_time_m(n, m):
    # Scalar for D == 3; otherwise a per-axis list with m on the last axis.
    return n if D == 3 else [n, n, n, m]

  def make_norm(channels):
    return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

  def make_down_conv():
    # Width-preserving strided conv on the current self.inplanes.
    return conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

  def make_up_conv(out_planes):
    # Strided transposed conv from the current self.inplanes to out_planes.
    return conv_tr(
        self.inplanes,
        out_planes,
        kernel_size=space_n_time_m(2, 1),
        upsample_stride=space_n_time_m(2, 1),
        dilation=1,
        bias=False,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

  def make_stage(idx):
    # Residual stage idx, sized by PLANES/LAYERS/DILATIONS at that index.
    return self._make_layer(
        self.BLOCK,
        self.PLANES[idx],
        self.LAYERS[idx],
        dilation=dilations[idx],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

  if D == 4:
    self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

  # Stem at full resolution. Its INIT_DIM-wide output is counted in block8's
  # input width below (skip connection to the last decoder stage).
  self.inplanes = self.INIT_DIM
  self.conv0p1s1 = conv(
      in_channels,
      self.inplanes,
      kernel_size=space_n_time_m(config['conv1_kernel_size'], 1),
      stride=1,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.bn0 = make_norm(self.inplanes)

  # Encoder: strided conv + norm, then a residual stage, per level.
  self.conv1p1s2 = make_down_conv()
  self.bn1 = make_norm(self.inplanes)
  self.block1 = make_stage(0)

  self.conv2p2s2 = make_down_conv()
  self.bn2 = make_norm(self.inplanes)
  self.block2 = make_stage(1)

  self.conv3p4s2 = make_down_conv()
  self.bn3 = make_norm(self.inplanes)
  self.block3 = make_stage(2)

  self.conv4p8s2 = make_down_conv()
  self.bn4 = make_norm(self.inplanes)
  self.block4 = make_stage(3)

  # Decoder: upsample + norm, then a residual stage whose input width is the
  # upsampled width plus the width of an encoder stage output (skip).
  self.convtr4p16s2 = make_up_conv(self.PLANES[4])
  self.bntr4 = make_norm(self.PLANES[4])
  self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
  self.block5 = make_stage(4)

  self.convtr5p8s2 = make_up_conv(self.PLANES[5])
  self.bntr5 = make_norm(self.PLANES[5])
  self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
  self.block6 = make_stage(5)

  self.convtr6p4s2 = make_up_conv(self.PLANES[6])
  self.bntr6 = make_norm(self.PLANES[6])
  self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
  self.block7 = make_stage(6)

  self.convtr7p2s2 = make_up_conv(self.PLANES[7])
  self.bntr7 = make_norm(self.PLANES[7])
  self.inplanes = self.PLANES[7] + self.INIT_DIM
  self.block8 = make_stage(7)

  # block8 emits PLANES[7] * expansion channels (the same convention used for
  # the skip widths above); sizing this with PLANES[7] alone is only correct
  # when BLOCK.expansion == 1.
  self.final = conv(
      self.PLANES[7] * self.BLOCK.expansion,
      out_channels,
      kernel_size=1,
      stride=1,
      bias=True,
      D=D)
  self.relu = MinkowskiReLU(inplace=True)
def network_initialization(self, in_channels, out_channels, config, D):
  """Create the layers of the residual U-Net with a multi-scale (pooled) head.

  Builds:

    * a stem conv ``conv1p1s1`` + ``bn1`` + ``block1``;
    * three strided encoder levels (``conv{2,3,4}*``/``bn{2,3,4}``/
      ``block{2,3,4}``);
    * two decoder levels (``convtr4p8s2``/``block5``, ``convtr5p4s2``/
      ``block6``) and a last upsampling ``convtr6p2s2``/``bntr6``;
    * parameter-free ``pool_tr{4,5,6}`` transposed poolings (strides 8/4/2);
    * a 1x1-conv ``final`` head.

  The head's input width is the sum of block5's and block6's widths
  (captured from ``self.inplanes``), ``PLANES[6]`` (convtr6 output) and
  ``PLANES[0] * expansion`` (presumably block1's output). Module creation
  order is kept stable because it determines parameter registration and
  initialization order.

  Args:
    in_channels: input feature channels.
    out_channels: channels produced by the last conv of ``self.final``.
    config: object with ``bn_momentum`` and ``conv1_kernel_size`` attributes.
    D: dimensionality. Kernel/stride scalars ``n`` are used as-is for
      ``D == 3`` and expanded to ``[n, n, n, m]`` otherwise;
      ``OUT_PIXEL_DIST`` is expanded the same way when ``D == 4``.
  """
  dilations = self.DILATIONS
  bn_momentum = config.bn_momentum

  def space_n_time_m(n, m):
    # Scalar for D == 3; otherwise a per-axis list with m on the last axis.
    return n if D == 3 else [n, n, n, m]

  def make_norm(channels):
    return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

  def make_down_conv():
    # Width-preserving strided conv on the current self.inplanes.
    return conv(
        self.inplanes,
        self.inplanes,
        kernel_size=space_n_time_m(2, 1),
        stride=space_n_time_m(2, 1),
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

  def make_up_conv(out_planes):
    # Strided transposed conv from the current self.inplanes to out_planes.
    return conv_tr(
        self.inplanes,
        out_planes,
        kernel_size=space_n_time_m(2, 1),
        upsample_stride=space_n_time_m(2, 1),
        dilation=1,
        bias=False,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D)

  def make_stage(idx):
    # Residual stage idx, sized by PLANES/LAYERS/DILATIONS at that index.
    return self._make_layer(
        self.BLOCK,
        self.PLANES[idx],
        self.LAYERS[idx],
        dilation=dilations[idx],
        norm_type=self.NORM_TYPE,
        bn_momentum=bn_momentum)

  if D == 4:
    self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

  # Stem at full resolution.
  self.inplanes = self.INIT_DIM
  self.conv1p1s1 = conv(
      in_channels,
      self.inplanes,
      kernel_size=space_n_time_m(config.conv1_kernel_size, 1),
      stride=1,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  # bn1 follows conv1p1s1 (as bn2..bn4 follow their convs), so it must match
  # that conv's output width self.inplanes (INIT_DIM), not PLANES[0]; the two
  # only coincide when INIT_DIM == PLANES[0].
  self.bn1 = make_norm(self.inplanes)
  self.block1 = make_stage(0)

  # Encoder: strided conv + norm, then a residual stage, per level.
  self.conv2p1s2 = make_down_conv()
  self.bn2 = make_norm(self.inplanes)
  self.block2 = make_stage(1)

  self.conv3p2s2 = make_down_conv()
  self.bn3 = make_norm(self.inplanes)
  self.block3 = make_stage(2)

  self.conv4p4s2 = make_down_conv()
  self.bn4 = make_norm(self.inplanes)
  self.block4 = make_stage(3)

  # NOTE(review): block4's width is not counted in the head's input width
  # below, even though pool_tr4 is created; confirm forward does not
  # concatenate pool_tr4's output into the head input.
  self.pool_tr4 = ME.MinkowskiPoolingTranspose(kernel_size=8, stride=8, dimension=D)
  self.convtr4p8s2 = make_up_conv(self.PLANES[4])
  self.bntr4 = make_norm(self.PLANES[4])
  self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
  self.block5 = make_stage(4)

  self.pool_tr5 = ME.MinkowskiPoolingTranspose(kernel_size=4, stride=4, dimension=D)
  out_pool5 = self.inplanes  # block5's output width
  self.convtr5p4s2 = make_up_conv(self.PLANES[5])
  self.bntr5 = make_norm(self.PLANES[5])
  self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
  self.block6 = make_stage(5)

  self.pool_tr6 = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D)
  out_pool6 = self.inplanes  # block6's output width
  self.convtr6p2s2 = make_up_conv(self.PLANES[6])
  self.bntr6 = make_norm(self.PLANES[6])

  self.relu = MinkowskiReLU(inplace=True)

  # NOTE(review): the head uses ME.MinkowskiBatchNorm with its defaults rather
  # than self.NORM_TYPE / bn_momentum; left unchanged to keep checkpoints
  # compatible.
  head_in = out_pool5 + out_pool6 + self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
  head_dim = 512
  self.final = nn.Sequential(
      conv(head_in, head_dim, kernel_size=1, bias=False, D=D),
      ME.MinkowskiBatchNorm(head_dim),
      ME.MinkowskiReLU(),
      conv(head_dim, out_channels, kernel_size=1, bias=True, D=D))
def network_initialization(self, in_channels, out_channels, config, D):
  """Create a residual encoder (stem + 4 strided stages) and a classification head.

  Encoder:

    * ``conv0p1s1``/``bn0``: stride-1 stem producing ``INIT_DIM`` channels.
    * ``conv{k}*``/``bn{k}``/``block{k}`` for k = 1..4: a width-preserving
      strided conv on the current ``self.inplanes``, its norm, and a residual
      stage from ``self._make_layer``. ``self.inplanes`` is updated by
      ``_make_layer``, which is not shown here.

  Head:

    * ``clf_glob_avg`` / ``clf_glob_max``: global average / max pooling.
    * ``clf_conv0`` (256 -> 512, kernel 3, stride 2) + ``clf_bn0``.
    * ``clf_conv1`` (512 -> 512, kernel 3, stride 2) + ``clf_bn1``.
    * ``clf_conv2``: 1x1 conv to ``config['clf_num_labels']`` channels.

  The order of statements matters: module creation order fixes parameter
  registration/initialization order, and each strided conv reads
  ``self.inplanes`` as left by the preceding stage.

  Args:
    in_channels: input feature channels.
    out_channels: accepted for interface compatibility; not read in this
      method.
    config: mapping providing ``'bn_momentum'``, ``'conv1_kernel_size'`` and
      ``'clf_num_labels'``.
    D: dimensionality. Kernel/stride scalars ``n`` are used as-is for
      ``D == 3`` and expanded to ``[n, n, n, m]`` otherwise;
      ``OUT_PIXEL_DIST`` is expanded the same way when ``D == 4``.
  """
  # Per-stage dilations and normalization momentum.
  dilations = self.DILATIONS
  bn_momentum = config['bn_momentum']

  def space_n_time_m(n, m):
    # Scalar for D == 3; otherwise a per-axis list with m on the last axis.
    return n if D == 3 else [n, n, n, m]

  if D == 4:
    self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)

  # Stem at full resolution (stride 1), INIT_DIM channels wide.
  self.inplanes = self.INIT_DIM
  self.conv0p1s1 = conv(
      in_channels,
      self.inplanes,
      kernel_size=space_n_time_m(config['conv1_kernel_size'], 1),
      stride=1,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)

  self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)

  # Encoder stage 1: strided conv + norm, then residual stage PLANES[0].
  self.conv1p1s2 = conv(
      self.inplanes,
      self.inplanes,
      kernel_size=space_n_time_m(2, 1),
      stride=space_n_time_m(2, 1),
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
  self.block1 = self._make_layer(
      self.BLOCK,
      self.PLANES[0],
      self.LAYERS[0],
      dilation=dilations[0],
      norm_type=self.NORM_TYPE,
      bn_momentum=bn_momentum)

  # Encoder stage 2.
  self.conv2p2s2 = conv(
      self.inplanes,
      self.inplanes,
      kernel_size=space_n_time_m(2, 1),
      stride=space_n_time_m(2, 1),
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
  self.block2 = self._make_layer(
      self.BLOCK,
      self.PLANES[1],
      self.LAYERS[1],
      dilation=dilations[1],
      norm_type=self.NORM_TYPE,
      bn_momentum=bn_momentum)

  # Encoder stage 3.
  self.conv3p4s2 = conv(
      self.inplanes,
      self.inplanes,
      kernel_size=space_n_time_m(2, 1),
      stride=space_n_time_m(2, 1),
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
  self.block3 = self._make_layer(
      self.BLOCK,
      self.PLANES[2],
      self.LAYERS[2],
      dilation=dilations[2],
      norm_type=self.NORM_TYPE,
      bn_momentum=bn_momentum)

  # Encoder stage 4.
  self.conv4p8s2 = conv(
      self.inplanes,
      self.inplanes,
      kernel_size=space_n_time_m(2, 1),
      stride=space_n_time_m(2, 1),
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
  self.block4 = self._make_layer(
      self.BLOCK,
      self.PLANES[3],
      self.LAYERS[3],
      dilation=dilations[3],
      norm_type=self.NORM_TYPE,
      bn_momentum=bn_momentum)

  self.relu = MinkowskiReLU(inplace=True)

  # Classification head.
  self.clf_glob_avg = ME.MinkowskiGlobalPooling(dimension=D)
  self.clf_glob_max=ME.MinkowskiGlobalMaxPooling(dimension=D)
  # NOTE(review): the 256 input width is hard-coded. It presumably must equal
  # block4's output width (self.PLANES[3] * self.BLOCK.expansion); confirm
  # against forward. kernel_size=3 is passed as a plain int here, unlike the
  # encoder, which uses space_n_time_m; confirm that is intended for D == 4.
  self.clf_conv0 = conv(
      256,
      512,
      kernel_size=3,
      stride=2,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.clf_bn0 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
  self.clf_conv1 = conv(
      512,
      512,
      kernel_size=3,
      stride=2,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)
  self.clf_bn1 = get_norm(self.NORM_TYPE, 512, D, bn_momentum=bn_momentum)
  # 1x1 projection to the number of classification labels.
  self.clf_conv2 = conv(
      512,
      config['clf_num_labels'],
      kernel_size=1,
      stride=1,
      dilation=1,
      conv_type=self.NON_BLOCK_CONV_TYPE,
      D=D)