def __init__(self, depth, channel_in, nout, interp='linear'):
    """Assemble a SegNet-style octree encoder/decoder.

    Encoder: one conv + index-returning max-pool per depth, from ``depth``
    down to 3. Decoder: one conv + max-unpool per depth, from 2 up to
    ``depth - 1``, then a final conv at ``depth``. The result feeds an
    ``OctreeInterp`` and a two-layer 1x1-conv header (64 hidden channels).

    Args:
        depth: Finest octree depth the network is built for.
        channel_in: Number of input feature channels at ``depth``.
        nout: Number of output channels of the header.
        interp: Interpolation mode forwarded to ``ocnn.OctreeInterp``.
    """
    super(SegNet, self).__init__()
    self.depth, self.channel_in = depth, channel_in

    # Per-level widths: 2**(10 - level), clamped below at 2**2 = 4.
    widths = [2**max(10 - level, 2) for level in range(depth + 1)]
    # Slot ``depth + 1`` holds the raw input width, so the first encoder conv
    # can uniformly read ``widths[d + 1]``.
    widths += [channel_in]
    # Level 2 is forced to share level 3's width (done after the append on
    # purpose: for depth == 2, slot 3 is the input width).
    widths[2] = widths[3]
    self.channels = widths

    encoder_depths = range(depth, 2, -1)  # depth, depth - 1, ..., 3
    decoder_depths = range(2, depth)      # 2, 3, ..., depth - 1

    # Encoder: conv widths[d + 1] -> widths[d], then a max-pool that also
    # returns indices (presumably consumed by the unpools below).
    conv_layers = [
        ocnn.OctreeConvBnRelu(d, widths[d + 1], widths[d])
        for d in encoder_depths
    ]
    self.convs = torch.nn.ModuleList(conv_layers)
    pool_layers = [
        ocnn.OctreeMaxPool(d, return_indices=True) for d in encoder_depths
    ]
    self.pools = torch.nn.ModuleList(pool_layers)

    # Decoder: conv widths[d] -> widths[d + 1], paired with a max-unpool.
    deconv_layers = [
        ocnn.OctreeConvBnRelu(d, widths[d], widths[d + 1])
        for d in decoder_depths
    ]
    self.deconvs = torch.nn.ModuleList(deconv_layers)
    unpool_layers = [ocnn.OctreeMaxUnpool(d) for d in decoder_depths]
    self.unpools = torch.nn.ModuleList(unpool_layers)

    finest = widths[depth]
    self.deconv = ocnn.OctreeConvBnRelu(depth, finest, finest)
    self.octree_interp = ocnn.OctreeInterp(self.depth, interp, nempty=False)

    fc1 = ocnn.OctreeConv1x1BnRelu(finest, 64)
    fc2 = ocnn.OctreeConv1x1(64, nout, use_bias=True)
    self.header = torch.nn.Sequential(fc1, fc2)
def conv_block(input_channel, depth, channel_in):
    """Build matching conv and max-pool stacks for an octree encoder.

    One ``OctreeConvBnRelu`` and one ``OctreeMaxPool`` are created for
    ``depth``, and then for every depth from ``depth - 1`` down to 3. The first
    pair is always created, even when ``depth <= 2``. The conv at depth ``d``
    outputs ``2**(9 - d)`` channels. The first conv reads ``input_channel``
    channels; every later conv reads the previous conv's output width.

    Args:
        input_channel: Feature channels entering the first (finest) conv.
        depth: Octree depth of the first conv/pool pair.
        channel_in: Not used by this function. NOTE(review): kept for
            signature compatibility; confirm whether callers still need it.

    Returns:
        Tuple ``(convs, pools)`` of ``ModuleList`` objects, both ordered from
        ``depth`` downward.
    """
    levels = [depth] + list(range(depth - 1, 2, -1))
    convs, pools = [], []
    for k, d in enumerate(levels):
        # Only the first level consumes the caller's features; later levels
        # consume the width produced one level finer: 2**(9 - (d + 1)).
        width_in = input_channel if k == 0 else 2**(9 - d - 1)
        # Conv and pool are created interleaved, in the same order as before.
        convs.append(ocnn.OctreeConvBnRelu(d, width_in, 2**(9 - d)))
        pools.append(ocnn.OctreeMaxPool(d))
    return ModuleList(convs), ModuleList(pools)
def __init__(self, depth, channel_in, nout, resblk_num):
    """Assemble an octree ResNet with a global-pool + linear head.

    Layout: an input conv at ``depth``, then for each depth from ``depth``
    down to 3 a stack of ``resblk_num`` residual blocks and a max-pool.
    Finally, ``FullOctreeGlobalPool(depth=2)`` feeds a ``Linear`` layer that
    produces ``nout`` outputs.

    Args:
        depth: Finest octree depth the network is built for.
        channel_in: Number of input feature channels at ``depth``.
        nout: Output size of the final linear layer.
        resblk_num: Residual blocks per stage, forwarded to
            ``ocnn.OctreeResBlocks``.
    """
    super(ResNet, self).__init__()
    self.depth, self.channel_in = depth, channel_in

    # Per-level widths: 2**(11 - level), clamped below at 2**2 = 4.
    widths = [2**max(11 - level, 2) for level in range(depth + 1)]
    # Extra slot ``depth + 1`` repeats the finest width, so the first
    # residual stage maps widths[depth] -> widths[depth] (the conv1 output).
    widths.append(widths[depth])

    self.conv1 = ocnn.OctreeConvBnRelu(depth, channel_in, widths[depth])

    stage_depths = range(depth, 2, -1)  # depth, depth - 1, ..., 3
    stages = [
        ocnn.OctreeResBlocks(d, widths[d + 1], widths[d], resblk_num)
        for d in stage_depths
    ]
    self.resblocks = torch.nn.ModuleList(stages)
    self.pools = torch.nn.ModuleList(
        [ocnn.OctreeMaxPool(d) for d in stage_depths])

    # NOTE: a Dropout(p=0.5) between the global pool and the linear layer
    # is currently disabled.
    global_pool = ocnn.FullOctreeGlobalPool(depth=2)
    fc = torch.nn.Linear(widths[3], nout)
    self.header = torch.nn.Sequential(global_pool, fc)
def __init__(self, depth, channel_in, nout, nempty=False, interp='linear',
             use_checkpoint=False):
    """Assemble an octree U-Net.

    Encoder: an input conv at ``depth``, then per stage a stride-2 conv
    (kernel 2) followed by residual blocks one depth coarser. Decoder: per
    stage a stride-2 deconv (kernel 2) followed by residual blocks whose
    input width is the upsampled width plus an encoder (skip) width. The
    result feeds an ``OctreeInterp`` and a prediction header.

    Args:
        depth: Finest octree depth the network is built for.
        channel_in: Number of input feature channels at ``depth``.
        nout: Number of outputs passed to ``make_predict_module``.
        nempty: Forwarded to every conv/deconv, residual stage and the
            interpolation module.
        interp: Interpolation mode forwarded to ``ocnn.OctreeInterp``.
        use_checkpoint: Stored on ``self`` and forwarded to every residual
            stage.
    """
    super(UNet, self).__init__()
    self.depth = depth
    self.channel_in = channel_in
    self.nempty = nempty
    self.use_checkpoint = use_checkpoint
    # config_network() must provide encoder_channel, decoder_channel,
    # encoder_blocks, decoder_blocks, bottleneck and resblk (all read below).
    self.config_network()
    self.stages = len(self.encoder_blocks)

    enc_ch = self.encoder_channel
    dec_ch = self.decoder_channel
    stage_ids = range(self.stages)

    # ---- encoder ----
    self.conv1 = ocnn.OctreeConvBnRelu(
        depth, channel_in, enc_ch[0], nempty=nempty)
    down_layers = [
        ocnn.OctreeConvBnRelu(depth - s, enc_ch[s], enc_ch[s + 1],
                              kernel_size=[2], stride=2, nempty=nempty)
        for s in stage_ids
    ]
    self.downsample = torch.nn.ModuleList(down_layers)
    enc_stages = [
        ocnn.OctreeResBlocks(depth - s - 1, enc_ch[s + 1], enc_ch[s + 1],
                             self.encoder_blocks[s], self.bottleneck, nempty,
                             self.resblk, self.use_checkpoint)
        for s in stage_ids
    ]
    self.encoder = torch.nn.ModuleList(enc_stages)

    # ---- decoder ----
    # The decoder starts at the coarsest depth reached by the encoder.
    bottom = depth - self.stages
    # Decoder stage inputs: upsampled width + an encoder width taken from the
    # end of encoder_channel (the summed count suggests skip concatenation).
    fused_ch = [dec_ch[s + 1] + enc_ch[-s - 2] for s in stage_ids]
    up_layers = [
        ocnn.OctreeDeConvBnRelu(bottom + s, dec_ch[s], dec_ch[s + 1],
                                kernel_size=[2], stride=2, nempty=nempty)
        for s in stage_ids
    ]
    self.upsample = torch.nn.ModuleList(up_layers)
    dec_stages = [
        ocnn.OctreeResBlocks(bottom + s + 1, fused_ch[s], dec_ch[s + 1],
                             self.decoder_blocks[s], self.bottleneck, nempty,
                             self.resblk, self.use_checkpoint)
        for s in stage_ids
    ]
    self.decoder = torch.nn.ModuleList(dec_stages)

    # ---- interpolation & header ----
    self.octree_interp = ocnn.OctreeInterp(self.depth, interp, nempty)
    self.header = self.make_predict_module(dec_ch[-1], nout)