def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
             conv_type=ConvType.HYPERCUBE, bn_momentum=0.1, D=3):
    """Create the two convolution + normalization pairs of the block.

    Args:
        inplanes: Input channel count of the first convolution.
        planes: Output channel count of both convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        dilation: Dilation shared by both convolutions.
        downsample: Stored unchanged as ``self.downsample``; this
            constructor does not call it.
        conv_type: Forwarded to ``conv`` for both convolutions.
        bn_momentum: Forwarded to ``get_norm`` for both norm layers.
        D: Dimension argument forwarded to ``conv`` and ``get_norm``.
    """
    super(BasicBlockBase, self).__init__()

    # Keyword arguments common to both convolutions.
    common = dict(kernel_size=3, dilation=dilation, conv_type=conv_type, D=D)

    # Only the first convolution honours the caller's stride.
    self.conv1 = conv(inplanes, planes, stride=stride, **common)
    self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)

    # The second convolution is the only one that passes ``bias`` explicitly.
    self.conv2 = conv(planes, planes, stride=1, bias=False, **common)
    self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)

    self.relu = MinkowskiReLU(inplace=True)
    self.downsample = downsample
def test_network(self):
    """Push a dense tensor through a dense -> sparse -> dense pipeline.

    The network is run five times on the same input; after each backward
    pass the input tensor must carry a gradient.
    """
    dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
    dense_tensor.requires_grad = True

    # Since the shape is fixed, cache the coordinates for faster inference
    cached_coordinates = dense_coordinates(dense_tensor.shape)

    layers = [
        # Add layers that can be applied on a regular pytorch tensor
        nn.ReLU(),
        MinkowskiToSparseTensor(remove_zeros=False, coordinates=cached_coordinates),
        MinkowskiConvolution(4, 5, stride=2, kernel_size=3, dimension=4),
        MinkowskiBatchNorm(5),
        MinkowskiReLU(),
        MinkowskiConvolutionTranspose(5, 6, stride=2, kernel_size=3, dimension=4),
        # must have the same tensor stride.
        MinkowskiToDenseTensor(dense_tensor.shape),
    ]
    network = nn.Sequential(*layers)

    for i in range(5):
        print(f"Iteration: {i}")
        dense_output = network(dense_tensor)
        dense_output.sum().backward()
        assert dense_tensor.grad is not None
def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
    """Assemble the recursive U-Net from nested ``UBlock`` modules.

    ``self.PLANES`` is read in reverse order. The first reversed entry seeds
    the starting ``BLOCK`` module; every later entry wraps the module built
    so far in a ``UBlock``.

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Output channel count of the final 1-kernel convolution.
        config: Object exposing ``bn_momentum``; also passed to the parent.
        D: Dimension argument forwarded to the layer factories.
        **kwargs: Accepted but not used by this constructor.
    """
    super(RecUNetBase, self).__init__(in_channels, out_channels, config, D)

    reversed_planes = self.PLANES[::-1]
    bn_momentum = config.bn_momentum

    if D == 4:
        self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1, D)

    # Output of the first conv concatenated to conv6
    self.conv1 = conv(
        in_channels,
        reversed_planes[-1][0],
        kernel_size=space_n_time_m(3, 1, D),
        conv_type=self.CONV_TYPE,
        D=D,
    )
    self.norm1 = get_norm(self.NORM_TYPE, reversed_planes[-1][0], D, bn_momentum)

    # NOTE(review): this call passes ``self.D`` while every other call uses
    # the ``D`` argument; presumably the parent constructor sets ``self.D``.
    nested = self.BLOCK(
        reversed_planes[0][0],
        reversed_planes[0][0],
        conv_type=self.CONV_TYPE,
        D=self.D,
    )

    # Pair each entry with its predecessor: ``previous`` is PLANES[i],
    # ``current`` is PLANES[i + 1].
    for i, (current, previous) in enumerate(zip(reversed_planes[1:], reversed_planes)):
        # REPS is consumed from its last element backwards.
        reps_index = len(self.REPS) - i - 1
        nested = UBlock(
            current[0],
            previous[0],
            previous[1],
            current[1],
            intermediate_module=nested,
            BLOCK=self.BLOCK,
            reps=self.REPS[reps_index],
            conv_type=self.CONV_TYPE,
            bn_momentum=bn_momentum,
            D=D,
        )
    self.unet = nested

    self.final = conv(
        reversed_planes[-1][1],
        out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        bias=True,
        D=D,
    )
    self.relu = MinkowskiReLU(inplace=True)
def network_initialization(self, in_channels, out_channels, D):
    """Register every layer of the network on ``self``.

    The layers are: one input convolution, four (stride-2 convolution, norm,
    ``_make_layer`` stage) groups, four (transposed convolution, norm,
    ``_make_layer`` stage) groups, and a final 1-kernel convolution.

    Statement order matters. ``self.inplanes`` is read by each constructor
    call and reassigned between stages. ``_make_layer`` is presumably
    updating it as well (its body is not visible here).

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Output channel count of ``self.final``.
        D: Dimension argument forwarded to the layer factories; any value
            other than 3 switches kernel sizes and strides to 4-element lists.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = 0.02

    def space_time(n, m):
        # Scalar when D == 3; otherwise n on three axes and m on the fourth.
        if D == 3:
            return n
        return [n, n, n, m]

    def norm(channels):
        return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

    def strided_conv():
        # Kernel 2 / stride 2 convolution keeping the current channel count.
        return conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_time(2, 1),
            stride=space_time(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def transposed_conv(out_planes):
        # Kernel 2 / upsample stride 2 transposed convolution without bias.
        return conv_tr(
            self.inplanes,
            out_planes,
            kernel_size=space_time(2, 1),
            upsample_stride=space_time(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def stage(idx):
        return self._make_layer(
            self.BLOCK,
            self.PLANES[idx],
            self.LAYERS[idx],
            dilation=dilations[idx],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

    def merged_width(up_idx, skip_idx):
        # Presumably the width after concatenating a skip connection;
        # the forward pass is not visible here.
        return self.PLANES[up_idx] + self.PLANES[skip_idx] * self.BLOCK.expansion

    if D == 4:
        self.OUT_PIXEL_DIST = space_time(self.OUT_PIXEL_DIST, 1)

    # Output of the first conv concatenated to conv6
    self.inplanes = self.INIT_DIM
    self.conv0p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_time(3, 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D,
    )
    self.bn0 = norm(self.inplanes)

    # --- strided convolution groups ---
    self.conv1p1s2 = strided_conv()
    self.bn1 = norm(self.inplanes)
    self.block1 = stage(0)

    self.conv2p2s2 = strided_conv()
    self.bn2 = norm(self.inplanes)
    self.block2 = stage(1)

    self.conv3p4s2 = strided_conv()
    self.bn3 = norm(self.inplanes)
    self.block3 = stage(2)

    self.conv4p8s2 = strided_conv()
    self.bn4 = norm(self.inplanes)
    self.block4 = stage(3)

    # --- transposed convolution groups ---
    self.convtr4p16s2 = transposed_conv(self.PLANES[4])
    self.bntr4 = norm(self.PLANES[4])
    self.inplanes = merged_width(4, 2)
    self.block5 = stage(4)

    self.convtr5p8s2 = transposed_conv(self.PLANES[5])
    self.bntr5 = norm(self.PLANES[5])
    self.inplanes = merged_width(5, 1)
    self.block6 = stage(5)

    self.convtr6p4s2 = transposed_conv(self.PLANES[6])
    self.bntr6 = norm(self.PLANES[6])
    self.inplanes = merged_width(6, 0)
    self.block7 = stage(6)

    self.convtr7p2s2 = transposed_conv(self.PLANES[7])
    self.bntr7 = norm(self.PLANES[7])
    # The last group adds INIT_DIM rather than a PLANES entry.
    self.inplanes = self.PLANES[7] + self.INIT_DIM
    self.block8 = stage(7)

    self.final = conv(
        self.PLANES[7], out_channels, kernel_size=1, stride=1, bias=True, D=D
    )
    self.relu = MinkowskiReLU(inplace=True)
def network_initialization(self, in_channels, out_channels, config, D):
    """Register the layers of the variant with pooling-transpose branches.

    The layers are: an input convolution, four ``_make_layer`` stages (the
    last three preceded by a stride-2 convolution), three transposed
    convolutions each paired with a ``MinkowskiPoolingTranspose`` of kernel
    and stride 8, 4 and 2, and a two-convolution ``final`` head.

    Statement order matters. ``self.inplanes`` is read by each constructor
    call and reassigned between stages. ``_make_layer`` is presumably
    updating it as well (its body is not visible here).

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Output channel count of the last ``final`` convolution.
        config: Object exposing ``bn_momentum`` and ``conv1_kernel_size``.
        D: Dimension argument forwarded to the layer factories; any value
            other than 3 switches kernel sizes and strides to 4-element lists.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = config.bn_momentum

    def space_time(n, m):
        # Scalar when D == 3; otherwise n on three axes and m on the fourth.
        if D == 3:
            return n
        return [n, n, n, m]

    def norm(channels):
        return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

    def strided_conv():
        # Kernel 2 / stride 2 convolution keeping the current channel count.
        return conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_time(2, 1),
            stride=space_time(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def transposed_conv(out_planes):
        # Kernel 2 / upsample stride 2 transposed convolution without bias.
        return conv_tr(
            self.inplanes,
            out_planes,
            kernel_size=space_time(2, 1),
            upsample_stride=space_time(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def stage(idx):
        return self._make_layer(
            self.BLOCK,
            self.PLANES[idx],
            self.LAYERS[idx],
            dilation=dilations[idx],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

    def pool_transpose(factor):
        return ME.MinkowskiPoolingTranspose(
            kernel_size=factor, stride=factor, dimension=D
        )

    if D == 4:
        self.OUT_PIXEL_DIST = space_time(self.OUT_PIXEL_DIST, 1)

    # Output of the first conv concatenated to conv6
    self.inplanes = self.INIT_DIM
    self.conv1p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_time(config.conv1_kernel_size, 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D,
    )
    # NOTE(review): bn1 is sized with PLANES[0] while conv1p1s1 emits
    # INIT_DIM channels; if bn1 directly follows conv1p1s1 the two must be
    # equal. Confirm against forward().
    self.bn1 = norm(self.PLANES[0])
    self.block1 = stage(0)

    self.conv2p1s2 = strided_conv()
    self.bn2 = norm(self.inplanes)
    self.block2 = stage(1)

    self.conv3p2s2 = strided_conv()
    self.bn3 = norm(self.inplanes)
    self.block3 = stage(2)

    self.conv4p4s2 = strided_conv()
    self.bn4 = norm(self.inplanes)
    self.block4 = stage(3)

    self.pool_tr4 = pool_transpose(8)
    # NOTE(review): recorded but never read below (``final`` only sums the
    # widths from branches 5 and 6). Verify whether that is intended.
    width_pool4 = self.inplanes
    self.convtr4p8s2 = transposed_conv(self.PLANES[4])
    self.bntr4 = norm(self.PLANES[4])
    self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
    self.block5 = stage(4)

    self.pool_tr5 = pool_transpose(4)
    width_pool5 = self.inplanes
    self.convtr5p4s2 = transposed_conv(self.PLANES[5])
    self.bntr5 = norm(self.PLANES[5])
    self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
    self.block6 = stage(5)

    self.pool_tr6 = pool_transpose(2)
    width_pool6 = self.inplanes
    self.convtr6p2s2 = transposed_conv(self.PLANES[6])
    self.bntr6 = norm(self.PLANES[6])

    self.relu = MinkowskiReLU(inplace=True)

    # Input width of the head: two pooled-branch widths plus PLANES[6] and
    # the expanded PLANES[0].
    head_in = (
        width_pool5
        + width_pool6
        + self.PLANES[6]
        + self.PLANES[0] * self.BLOCK.expansion
    )
    self.final = nn.Sequential(
        conv(head_in, 512, kernel_size=1, bias=False, D=D),
        ME.MinkowskiBatchNorm(512),
        ME.MinkowskiReLU(),
        conv(512, out_channels, kernel_size=1, bias=True, D=D),
    )
def network_initialization(self, in_channels, out_channels, config, D):
    """Register the strided-convolution stages and a classification head.

    The layers are: an input convolution, four (stride-2 convolution, norm,
    ``_make_layer`` stage) groups, two global pooling modules, and three
    ``clf_*`` convolutions with two norm layers.

    Statement order matters. ``self.inplanes`` is read by each constructor
    call. ``_make_layer`` is presumably updating it between stages (its
    body is not visible here).

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Not used by this method.
        config: Mapping with keys ``'bn_momentum'``, ``'conv1_kernel_size'``
            and ``'clf_num_labels'``.
        D: Dimension argument forwarded to the layer factories; any value
            other than 3 switches kernel sizes and strides to 4-element lists.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = config['bn_momentum']

    def space_time(n, m):
        # Scalar when D == 3; otherwise n on three axes and m on the fourth.
        if D == 3:
            return n
        return [n, n, n, m]

    def norm(channels):
        return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

    def strided_conv():
        # Kernel 2 / stride 2 convolution keeping the current channel count.
        return conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_time(2, 1),
            stride=space_time(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def stage(idx):
        return self._make_layer(
            self.BLOCK,
            self.PLANES[idx],
            self.LAYERS[idx],
            dilation=dilations[idx],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

    def head_conv(n_in, n_out, kernel, step):
        # Head convolutions use plain integer kernel sizes and strides.
        return conv(
            n_in,
            n_out,
            kernel_size=kernel,
            stride=step,
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    if D == 4:
        self.OUT_PIXEL_DIST = space_time(self.OUT_PIXEL_DIST, 1)

    # Output of the first conv concatenated to conv6
    self.inplanes = self.INIT_DIM
    self.conv0p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_time(config['conv1_kernel_size'], 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D,
    )
    self.bn0 = norm(self.inplanes)

    self.conv1p1s2 = strided_conv()
    self.bn1 = norm(self.inplanes)
    self.block1 = stage(0)

    self.conv2p2s2 = strided_conv()
    self.bn2 = norm(self.inplanes)
    self.block2 = stage(1)

    self.conv3p4s2 = strided_conv()
    self.bn3 = norm(self.inplanes)
    self.block3 = stage(2)

    self.conv4p8s2 = strided_conv()
    self.bn4 = norm(self.inplanes)
    self.block4 = stage(3)

    self.relu = MinkowskiReLU(inplace=True)

    # add a classification head here
    self.clf_glob_avg = ME.MinkowskiGlobalPooling(dimension=D)
    self.clf_glob_max = ME.MinkowskiGlobalMaxPooling(dimension=D)

    # NOTE(review): the head's input width is hard-coded to 256 —
    # presumably it matches the width produced by block4; confirm.
    self.clf_conv0 = head_conv(256, 512, 3, 2)
    self.clf_bn0 = norm(512)
    self.clf_conv1 = head_conv(512, 512, 3, 2)
    self.clf_bn1 = norm(512)
    self.clf_conv2 = head_conv(512, config['clf_num_labels'], 1, 1)
def network_initialization(self, in_channels, out_channels, D):
    """Register the layers, feature names, pooling and optional MLP head.

    The layers are: an input convolution, four (stride-2 convolution, norm,
    ``_make_layer`` stage) groups, four (transposed convolution, norm,
    ``_make_layer`` stage) groups, global max/avg pooling, and an ``SMLP``
    head when ``self.use_mlp`` is truthy. No final projection is created.

    Statement order matters. ``self.inplanes`` is read by each constructor
    call and reassigned between stages. ``_make_layer`` is presumably
    updating it as well (its body is not visible here).

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Not used by this method.
        D: Dimension argument forwarded to the layer factories; any value
            other than 3 switches kernel sizes and strides to 4-element lists.
    """
    # Setup net_metadata
    dilations = self.DILATIONS
    bn_momentum = 0.02

    def space_time(n, m):
        # Scalar when D == 3; otherwise n on three axes and m on the fourth.
        if D == 3:
            return n
        return [n, n, n, m]

    def norm(channels):
        return get_norm(self.NORM_TYPE, channels, D, bn_momentum=bn_momentum)

    def strided_conv():
        # Kernel 2 / stride 2 convolution keeping the current channel count.
        return conv(
            self.inplanes,
            self.inplanes,
            kernel_size=space_time(2, 1),
            stride=space_time(2, 1),
            dilation=1,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def transposed_conv(out_planes):
        # Kernel 2 / upsample stride 2 transposed convolution without bias.
        return conv_tr(
            self.inplanes,
            out_planes,
            kernel_size=space_time(2, 1),
            upsample_stride=space_time(2, 1),
            dilation=1,
            bias=False,
            conv_type=self.NON_BLOCK_CONV_TYPE,
            D=D,
        )

    def stage(idx):
        return self._make_layer(
            self.BLOCK,
            self.PLANES[idx],
            self.LAYERS[idx],
            dilation=dilations[idx],
            norm_type=self.NORM_TYPE,
            bn_momentum=bn_momentum,
        )

    def merged_width(up_idx, skip_idx):
        # Presumably the width after concatenating a skip connection;
        # the forward pass is not visible here.
        return self.PLANES[up_idx] + self.PLANES[skip_idx] * self.BLOCK.expansion

    if D == 4:
        self.OUT_PIXEL_DIST = space_time(self.OUT_PIXEL_DIST, 1)

    # Output of the first conv concatenated to conv6
    conv1_kernel_size = 3
    self.inplanes = self.INIT_DIM
    self.conv0p1s1 = conv(
        in_channels,
        self.inplanes,
        kernel_size=space_time(conv1_kernel_size, 1),
        stride=1,
        dilation=1,
        conv_type=self.NON_BLOCK_CONV_TYPE,
        D=D,
    )
    self.bn0 = norm(self.inplanes)

    # --- strided convolution groups ---
    self.conv1p1s2 = strided_conv()
    self.bn1 = norm(self.inplanes)
    self.block1 = stage(0)

    self.conv2p2s2 = strided_conv()
    self.bn2 = norm(self.inplanes)
    self.block2 = stage(1)

    self.conv3p4s2 = strided_conv()
    self.bn3 = norm(self.inplanes)
    self.block3 = stage(2)

    self.conv4p8s2 = strided_conv()
    self.bn4 = norm(self.inplanes)
    self.block4 = stage(3)

    # --- transposed convolution groups ---
    self.convtr4p16s2 = transposed_conv(self.PLANES[4])
    self.bntr4 = norm(self.PLANES[4])
    self.inplanes = merged_width(4, 2)
    self.block5 = stage(4)

    self.convtr5p8s2 = transposed_conv(self.PLANES[5])
    self.bntr5 = norm(self.PLANES[5])
    self.inplanes = merged_width(5, 1)
    self.block6 = stage(5)

    self.convtr6p4s2 = transposed_conv(self.PLANES[6])
    self.bntr6 = norm(self.PLANES[6])
    self.inplanes = merged_width(6, 0)
    self.block7 = stage(6)

    self.convtr7p2s2 = transposed_conv(self.PLANES[7])
    self.bntr7 = norm(self.PLANES[7])
    # The last group adds INIT_DIM rather than a PLANES entry.
    self.inplanes = self.PLANES[7] + self.INIT_DIM
    self.block8 = stage(7)

    self.all_feat_names = [
        "en0",
        "en1",
        "en2",
        "en3",
        "en4",
        "plane4",
        "plane5",
        "plane6",
        "plane7",
    ]

    self.relu = MinkowskiReLU(inplace=True)
    self.maxpool = ME.MinkowskiGlobalMaxPooling()
    self.avgpool = ME.MinkowskiGlobalAvgPooling()

    # The MLP head exists only when the instance opted in via ``use_mlp``.
    if self.use_mlp:
        self.head = SMLP(self.mlp_dim)