def __init__(self, in_channels, kernel_size, r=2, d=8):
    """Create the sub-layers used by this packing layer.

    Parameters
    ----------
    in_channels : int
        Channel count the layer outputs; also the base of the width
        handed to ``conv``.
    kernel_size : int
        Kernel size forwarded to ``conv``.
    r : int
        Factor bound into the ``packing`` partial; its square scales
        the input width of ``conv``.
    d : int
        Number of output channels of the 3D convolution; also scales
        the input width of ``conv``.
    """
    super().__init__()
    # Input width handed to `conv`: in_channels scaled by r*r and by d.
    packed_channels = in_channels * (r**2) * d
    self.conv = conv(packed_channels, in_channels, kernel_size, 1)
    self.pack = partial(packing, r=r)
    # Single input channel, d output channels, 3x3x3 kernel with unit
    # stride and unit padding.
    self.conv3d = nn.Conv3d(
        in_channels=1,
        out_channels=d,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
    )
def __init__(self, in_channels, out_channels, kernel_size, r=2, d=8):
    """Create the sub-layers used by this unpacking layer.

    Parameters
    ----------
    in_channels : int
        Input width forwarded to ``conv``.
    out_channels : int
        Target channel count; ``conv`` is asked for
        ``out_channels * r**2 // d`` channels.
    kernel_size : int
        Kernel size forwarded to ``conv``.
    r : int
        Upscale factor given to ``nn.PixelShuffle``.
    d : int
        Number of output channels of the 3D convolution.
    """
    super().__init__()
    # Output width requested from `conv` (integer division by d).
    conv_out_channels = out_channels * (r**2) // d
    self.conv = conv(in_channels, conv_out_channels, kernel_size, 1)
    self.unpack = nn.PixelShuffle(r)
    # Single input channel, d output channels, 3x3x3 kernel with unit
    # stride and unit padding.
    self.conv3d = nn.Conv3d(
        in_channels=1,
        out_channels=d,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
    )
def __init__(self, in_channels, kernel_size, r=2):
    """Create the sub-layers used by this packing layer (no 3D convolution).

    Parameters
    ----------
    in_channels : int
        Channel count the layer outputs; ``conv`` receives
        ``in_channels * r**2`` input channels.
    kernel_size : int
        Kernel size forwarded to ``conv``.
    r : int
        Factor bound into the ``packing`` partial.
    """
    super().__init__()
    # Input width handed to `conv`: in_channels scaled by r*r.
    packed_channels = in_channels * (r**2)
    self.conv = conv(packed_channels, in_channels, kernel_size, 1)
    self.pack = partial(packing, r=r)
def __init__(self, in_channels, out_channels, kernel_size, r=2):
    """Create the sub-layers used by this unpacking layer (no 3D convolution).

    Parameters
    ----------
    in_channels : int
        Input width forwarded to ``conv``.
    out_channels : int
        Target channel count; ``conv`` is asked for
        ``out_channels * r**2`` channels.
    kernel_size : int
        Kernel size forwarded to ``conv``.
    r : int
        Upscale factor given to ``nn.PixelShuffle``.
    """
    super().__init__()
    # Output width requested from `conv`: out_channels scaled by r*r.
    expanded_channels = out_channels * (r**2)
    self.conv = conv(in_channels, expanded_channels, kernel_size, 1)
    self.unpack = nn.PixelShuffle(r)
def __init__(self, in_planes=3, out_planes=2, dropout=None, version=None,
             bn=False, store_features=None):
    """Build the encoder, decoder and inverse-depth heads of the network.

    Parameters
    ----------
    in_planes : int
        Input channel count of the first convolution (``pre_calc``).
    out_planes : int
        Base output channel count; it is multiplied by the leading
        digit of ``version`` before being used.
    dropout : optional
        Forwarded unchanged to every ``resblock_basic`` call.
    version : str
        First character is parsed with ``int``; the remainder must be
        ``'A'`` or ``'B'`` and selects the decoder channel layout.
    bn : bool
        Must be falsy, otherwise the assertion below fails.
    store_features : optional
        Stored on the instance as-is.

    Raises
    ------
    AssertionError
        If ``bn`` is truthy.
    ValueError
        If the variant tag (``version[1:]``) is neither 'A' nor 'B'.
    """
    super().__init__()
    assert not bn, 'Only GroupNorm is supported'

    # Leading character of `version` is a numeric multiplier; the rest
    # is the variant tag checked further down.
    scale = int(version[0])
    self.super_resolution = scale > 1
    out_planes = out_planes * scale
    self.version = version[1:]
    self.store_features = store_features
    self.features = {}

    # Hyper-parameters
    stem_ch, head_ch = 64, out_planes
    c1, c2, c3, c4, c5 = 64, 64, 128, 256, 512
    blocks2, blocks3, blocks4, blocks5 = 2, 2, 3, 3
    pack_k1, pack_k2, pack_k3, pack_k4, pack_k5 = 5, 3, 3, 3, 3
    # Decoder kernels are listed from level 5 down to level 1.
    unpack_k5, unpack_k4, unpack_k3, unpack_k2, unpack_k1 = 3, 3, 3, 3, 3
    iconv_k5, iconv_k4, iconv_k3, iconv_k2, iconv_k1 = 3, 3, 3, 3, 3

    self.pre_calc = conv(in_planes, stem_ch, 5, 1)

    if self.version == 'A':
        # Version A (Concatenated features):
        # Concatenate upconv features, skip features and up-sampled
        # disparities
        up1, up2, up3, up4, up5 = c1, c2, c3, c4, c5
        fuse1 = c1 + stem_ch + head_ch
        fuse2 = c2 + c1 + head_ch
        fuse3 = c3 + c2 + head_ch
        fuse4 = c4 + c3
        fuse5 = c5 + c4
    elif self.version == 'B':
        # Version B (Additive features):
        # Add upconv features and skip features, and concatenate the
        # result with the upsampled disparities
        up1, up2 = c1, c2
        up3, up4, up5 = c3 // 2, c4 // 2, c5 // 2
        fuse1 = c1 + head_ch
        fuse2 = c2 + head_ch
        fuse3 = c3 // 2 + head_ch
        fuse4 = c4 // 2
        fuse5 = c5 // 2
    else:
        raise ValueError('Unknown PackNet version {}'.format(version))

    # Encoder
    # NOTE: sub-modules are assigned in the same order as before so that
    # registration order (and hence parameter ordering) is unchanged.
    self.pack1 = PackLayerConv3d(c1, pack_k1)
    self.pack2 = PackLayerConv3d(c2, pack_k2)
    self.pack3 = PackLayerConv3d(c3, pack_k3)
    self.pack4 = PackLayerConv3d(c4, pack_k4)
    self.pack5 = PackLayerConv3d(c5, pack_k5)

    self.conv1 = conv(stem_ch, c1, 7, 1)
    self.conv2 = resblock_basic(c1, c2, blocks2, 1, dropout=dropout)
    self.conv3 = resblock_basic(c2, c3, blocks3, 1, dropout=dropout)
    self.conv4 = resblock_basic(c3, c4, blocks4, 1, dropout=dropout)
    self.conv5 = resblock_basic(c4, c5, blocks5, 1, dropout=dropout)

    # Decoder
    self.unpack5 = UnpackLayerConv3d(c5, up5, unpack_k5)
    self.unpack4 = UnpackLayerConv3d(c5, up4, unpack_k4)
    self.unpack3 = UnpackLayerConv3d(c4, up3, unpack_k3)
    self.unpack2 = UnpackLayerConv3d(c3, up2, unpack_k2)
    self.unpack1 = UnpackLayerConv3d(c2, up1, unpack_k1)

    self.iconv5 = conv(fuse5, c5, iconv_k5, 1)
    self.iconv4 = conv(fuse4, c4, iconv_k4, 1)
    self.iconv3 = conv(fuse3, c3, iconv_k3, 1)
    self.iconv2 = conv(fuse2, c2, iconv_k2, 1)
    self.iconv1 = conv(fuse1, c1, iconv_k1, 1)

    # Depth Layers
    self.unpack_disps = nn.PixelShuffle(2)

    upsample_kwargs = dict(scale_factor=2, mode='nearest', align_corners=None)
    self.unpack_disp4 = nn.Upsample(**upsample_kwargs)
    self.unpack_disp3 = nn.Upsample(**upsample_kwargs)
    self.unpack_disp2 = nn.Upsample(**upsample_kwargs)

    self.disp4_layer = get_invdepth(c4, out_planes=out_planes)
    self.disp3_layer = get_invdepth(c3, out_planes=out_planes)
    self.disp2_layer = get_invdepth(c2, out_planes=out_planes)
    self.disp1_layer = get_invdepth(c1, out_planes=out_planes)

    self.init_weights()