def __init__(self, intLevel):
    """Construct the decoder sub-network for level ``intLevel``.

    Layers that take input from the next-coarser level (transposed-conv
    upsampling of flow and features, plus the ``Backward`` module and its
    ``dblBackward`` scalar) are only created when ``intLevel < 6``.
    Level 6 therefore has no ``moduleUpflow``, ``moduleUpfeat``,
    ``dblBackward`` or ``moduleBackward`` attributes.

    Args:
        intLevel: level index (presumably a feature-pyramid level; confirm
            with the caller). The channel table below only holds integer
            entries for levels 2..6. Other values yield a ``None`` channel
            count (and a ``TypeError`` during layer construction) or an
            ``IndexError``.
    """
    super(Decoder, self).__init__()

    # Decoder input width per level, indexed by level number.
    # NOTE(review): 81 presumably equals the channel count produced by
    # correlation.ModuleCorrelation; the middle term looks like the level's
    # feature width; the trailing "+ 2 + 2" is presumably the upsampled flow
    # (moduleUpflow, 2 channels) plus upsampled feature (moduleUpfeat,
    # 2 channels) from the coarser level. Confirm against forward().
    lstInputWidth = [None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None]

    intCoarserWidth = lstInputWidth[intLevel + 1]
    intInputWidth = lstInputWidth[intLevel + 0]

    # (attribute name, output channels) of the stacked 3x3 conv layers.
    # Each layer's in_channels is the previous layer's in_channels plus that
    # layer's output width, presumably because forward() concatenates
    # outputs onto the running input (verify).
    lstStack = [('moduleOne', 128), ('moduleTwo', 128), ('moduleThr', 96), ('moduleFou', 64), ('moduleFiv', 32)]
    intStackWidth = sum(intWidth for _, intWidth in lstStack)  # 128 + 128 + 96 + 64 + 32

    if intLevel < 6:
        # 2x upsampling (kernel 4, stride 2, padding 1) of the 2-channel flow.
        self.moduleUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
        # 2x upsampling of the coarser level's full stacked feature
        # (its input width plus all stacked-layer outputs) into 2 channels.
        self.moduleUpfeat = torch.nn.ConvTranspose2d(in_channels=intCoarserWidth + intStackWidth, out_channels=2, kernel_size=4, stride=2, padding=1)
        # Per-level scalar consumed outside __init__.
        # NOTE(review): halves per level (5.0 at level 2 -> 0.625 at level 5),
        # presumably a flow scale for warping. Confirm in forward().
        self.dblBackward = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][intLevel + 1]
        self.moduleBackward = Backward()

    self.moduleCorrelation = correlation.ModuleCorrelation()
    self.moduleCorreleaky = torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)

    # Build moduleOne..moduleFiv in order; each is a 3x3 same-size conv
    # followed by LeakyReLU(0.1).
    intRunningWidth = intInputWidth
    for strName, intWidth in lstStack:
        setattr(self, strName, torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=intRunningWidth, out_channels=intWidth, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
        ))
        intRunningWidth = intRunningWidth + intWidth

    # Final 3x3 conv producing the 2-channel output, with no activation.
    self.moduleSix = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=intRunningWidth, out_channels=2, kernel_size=3, stride=1, padding=1)
    )
def __init__(self):
    """Construct the layers of the ``Complex`` network.

    The network consists of:

    - three padded stride-2 conv stages (``moduleOne``..``moduleThr``);
    - a 1x1 conv reduction to 32 channels (``moduleRedir``);
    - a correlation module;
    - a 3x3 combine conv (``moduleCombined``);
    - three further padded stride-2 stages (``moduleFou``..``moduleSix``),
      each followed by a same-size 3x3 conv;
    - an ``Upconv`` head.

    Every activation is ``LeakyReLU(negative_slope=0.1)``.
    """
    super(Complex, self).__init__()

    def leaky():
        # Shared activation used after every convolution in this network.
        return torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)

    def padded_stride2(lstPad, intIn, intOut, intKernel):
        # Explicit (asymmetric) ZeroPad2d, then an unpadded stride-2 conv and
        # its activation. Returned as a list so callers can splice it into a
        # Sequential and keep the child indices '0', '1', '2', ...
        return [
            torch.nn.ZeroPad2d(lstPad),
            torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=intKernel, stride=2, padding=0),
            leaky(),
        ]

    self.moduleOne = torch.nn.Sequential(*padded_stride2([2, 4, 2, 4], 3, 64, 7))
    self.moduleTwo = torch.nn.Sequential(*padded_stride2([1, 3, 1, 3], 64, 128, 5))
    self.moduleThr = torch.nn.Sequential(*padded_stride2([1, 3, 1, 3], 128, 256, 5))

    # 1x1 reduction of the 256-channel feature to 32 channels.
    self.moduleRedir = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=256, out_channels=32, kernel_size=1, stride=1, padding=0),
        leaky(),
    )

    self.moduleCorrelation = correlation.ModuleCorrelation()

    # NOTE(review): 473 = 32 (moduleRedir output) + 441. The 441 is presumably
    # the channel count of correlation.ModuleCorrelation's output. Confirm.
    self.moduleCombined = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=473, out_channels=256, kernel_size=3, stride=1, padding=1),
        leaky(),
    )

    # Deeper stages: padded stride-2 conv, then a same-size 3x3 conv.
    self.moduleFou = torch.nn.Sequential(
        *padded_stride2([0, 2, 0, 2], 256, 512, 3),
        torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        leaky(),
    )
    self.moduleFiv = torch.nn.Sequential(
        *padded_stride2([0, 2, 0, 2], 512, 512, 3),
        torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        leaky(),
    )
    self.moduleSix = torch.nn.Sequential(
        *padded_stride2([0, 2, 0, 2], 512, 1024, 3),
        torch.nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
        leaky(),
    )

    self.moduleUpconv = Upconv()