def __init__(self):
    """Build a two-channel, five-stage U-Net-style network with a linear head.

    Encoder widths are 2 -> 4 -> 8 -> 16 -> 32 -> 64. Each decoder stage
    takes the concatenation of a skip connection and the upsampled deeper
    features (hence the ``a + b`` input widths). The fixed upsample sizes
    (7, 15, 31, 63) presumably correspond to a 63x63 input after repeated
    2x max-pooling -- confirm against the data pipeline.
    """
    super().__init__()

    def _stage(n_in, n_out):
        # Every stage is a 3x3 double convolution with stride 1, padding 1.
        return nn.Sequential(*double_conv(n_in, n_out, (3, 3), 1, 1))

    # Contracting path.
    self.dconv_down1 = _stage(2, 4)
    self.dconv_down2 = _stage(4, 8)
    self.dconv_down3 = _stage(8, 16)
    self.dconv_down4 = _stage(16, 32)
    self.dconv_down5 = _stage(32, 64)
    self.maxpool = nn.MaxPool2d(2)

    # Bilinear upsampling to fixed output sizes, one module per level.
    self.upsample1 = nn.Upsample(size=7, mode="bilinear", align_corners=True)
    self.upsample2 = nn.Upsample(size=15, mode="bilinear", align_corners=True)
    self.upsample3 = nn.Upsample(size=31, mode="bilinear", align_corners=True)
    self.upsample4 = nn.Upsample(size=63, mode="bilinear", align_corners=True)

    # Expanding path: input = skip channels + upsampled channels.
    self.dconv_up4 = _stage(32 + 64, 32)
    self.dconv_up3 = _stage(16 + 32, 16)
    self.dconv_up2 = _stage(8 + 16, 8)
    self.dconv_up1 = _stage(4 + 8, 4)
    self.conv_last = nn.Conv2d(4, 2, 1)

    # Head.
    self.flatten = Lambda(flatten)
    self.linear1 = nn.Linear(8192, 4096)
    self.fft = Lambda(fft)
def __init__(self):
    """Build a two-channel residual network whose output is later clamped.

    Layout: 9x9 grouped input conv + PReLU, eight ``SRBlock`` residual blocks
    at width 64, a 3x3 conv + batch norm, and a 9x9 grouped output conv back
    to two channels. ``groups=2`` keeps the two channels in separate
    convolution groups. ``hardtanh`` clamps values to [-pi, pi].
    """
    super().__init__()
    width = 64

    self.preBlock = nn.Sequential(
        nn.Conv2d(2, width, 9, stride=1, padding=4, groups=2),
        nn.PReLU(),
    )
    # Eight residual blocks at constant width.
    self.blocks = nn.Sequential(*[SRBlock(width, width) for _ in range(8)])
    self.postBlock = nn.Sequential(
        # No bias: the following batch norm has its own shift.
        nn.Conv2d(width, width, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(width),
    )
    self.final = nn.Sequential(
        nn.Conv2d(width, 2, 9, stride=1, padding=4, groups=2),
    )

    self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
    self.hardtanh = nn.Hardtanh(-pi, pi)
def __init__(self):
    """Build a two-channel residual network (width 64, eight residual blocks).

    Layout: 9x9 grouped input conv + PReLU, eight ``SRBlock`` blocks, a 3x3
    conv (with bias) + batch norm, and a 9x9 grouped output conv back to two
    channels. ``groups=2`` keeps the two channels in separate groups.
    """
    super().__init__()
    width = 64

    self.preBlock = nn.Sequential(
        nn.Conv2d(2, width, 9, stride=1, padding=4, groups=2),
        nn.PReLU(),
    )
    # Eight residual blocks at constant width.
    self.blocks = nn.Sequential(*[SRBlock(width, width) for _ in range(8)])
    self.postBlock = nn.Sequential(
        nn.Conv2d(width, width, 3, stride=1, padding=1),
        nn.BatchNorm2d(width),
    )
    self.final = nn.Sequential(
        nn.Conv2d(width, 2, 9, stride=1, padding=4, groups=2),
    )

    self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
def __init__(self):
    """Build a residual network mapping one input channel to two outputs.

    Width is 56 throughout. Layout: 9x9 input conv + PReLU, eight ``SRBlock``
    blocks, a bias-free 3x3 conv + batch norm, and a 9x9 output conv to two
    channels. ``elu`` is a ``GeneralELU`` with ``add = 1 + 1e-10``.
    """
    super().__init__()
    n_channel = 56

    self.preBlock = nn.Sequential(
        nn.Conv2d(1, n_channel, 9, stride=1, padding=4, groups=1),
        nn.PReLU(),
    )
    # Eight residual blocks at constant width.
    self.blocks = nn.Sequential(
        *[SRBlock(n_channel, n_channel) for _ in range(8)]
    )
    self.postBlock = nn.Sequential(
        nn.Conv2d(n_channel, n_channel, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(n_channel),
    )
    self.final = nn.Sequential(
        nn.Conv2d(n_channel, 2, 9, stride=1, padding=4, groups=1),
    )

    self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
    self.elu = GeneralELU(add=1 + 1e-10)
def __init__(self, img_size):
    """Build an amplitude branch of three conv phases, each merged to one channel.

    Each phase is a stack of ``conv_amp`` layers followed by a
    ``LocallyConnected2d`` merge to a single channel (+ batch norm + ReLU).
    ``img_size`` is forwarded to every ``LocallyConnected2d``.
    """
    super().__init__()

    def _amp(n_in, n_out, k, pad, dil):
        # conv_amp positional args: in, out, kernel, stride (always 1),
        # padding, dilation (presumed order).
        return nn.Sequential(*conv_amp(n_in, n_out, (k, k), 1, pad, dil))

    def _merge(n_in):
        return nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )

    # Phase 1: large kernels.
    self.conv1_amp = _amp(1, 4, 23, 11, 1)
    self.conv2_amp = _amp(4, 8, 21, 10, 1)
    self.conv3_amp = _amp(8, 12, 17, 8, 1)
    self.conv_con1_amp = _merge(12)

    # Phase 2: small kernels, some dilated.
    self.conv4_amp = _amp(1, 4, 5, 4, 2)
    self.conv5_amp = _amp(4, 8, 5, 2, 1)
    self.conv6_amp = _amp(8, 12, 3, 2, 2)
    self.conv7_amp = _amp(12, 16, 3, 1, 1)
    self.conv_con2_amp = _merge(16)

    # Phase 3: 3x3 kernels.
    self.conv8_amp = _amp(1, 4, 3, 1, 1)
    self.conv9_amp = _amp(4, 8, 3, 1, 1)
    self.conv10_amp = _amp(8, 12, 3, 2, 2)
    self.conv_con3_amp = _merge(12)

    self.symmetry_real = Lambda(symmetry)
def __init__(self):
    """Build a single-channel residual network (width 32, twelve residual blocks).

    Layout: 9x9 input conv + PReLU, twelve ``SRBlock`` blocks, a 3x3 conv
    (with bias) + batch norm, and a 9x9 output conv back to one channel.
    ``symmetry_phase`` applies ``symmetry`` in ``"imag"`` mode.
    """
    super().__init__()
    width = 32

    self.preBlock = nn.Sequential(
        nn.Conv2d(1, width, 9, stride=1, padding=4, groups=1),
        nn.PReLU(),
    )
    # Twelve residual blocks at constant width.
    self.blocks = nn.Sequential(*[SRBlock(width, width) for _ in range(12)])
    self.postBlock = nn.Sequential(
        nn.Conv2d(width, width, 3, stride=1, padding=1),
        nn.BatchNorm2d(width),
    )
    self.final = nn.Sequential(
        nn.Conv2d(width, 1, 9, stride=1, padding=4, groups=1),
    )

    self.symmetry_phase = Lambda(partial(symmetry, mode="imag"))
def __init__(self):
    """Build a single-channel residual network (width 64, eight residual blocks).

    Layout: 9x9 input conv + PReLU, eight ``SRBlock`` blocks, a bias-free
    3x3 conv + batch norm, and a 9x9 output conv back to one channel.
    ``symmetry_amp`` applies ``symmetry`` in ``"real"`` mode.
    """
    super().__init__()
    width = 64

    self.preBlock = nn.Sequential(
        nn.Conv2d(1, width, 9, stride=1, padding=4, groups=1),
        nn.PReLU(),
    )
    # Eight residual blocks (an earlier 16-block variant was disabled).
    self.blocks = nn.Sequential(*[SRBlock(width, width) for _ in range(8)])
    self.postBlock = nn.Sequential(
        nn.Conv2d(width, width, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(width),
    )
    self.final = nn.Sequential(
        nn.Conv2d(width, 1, 9, stride=1, padding=4, groups=1),
    )

    self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
def autoencoder():
    """Return a single-channel convolutional autoencoder as ``nn.Sequential``.

    Encoder: five ``conv`` stages (1 -> 4 -> 8 -> 16 -> 32 -> 64 channels)
    with two max-pools. Decoder: six stride-2 transposed convolutions with
    kernel 3, padding 1 and output_padding 1, so each exactly doubles the
    spatial size. The output is flattened.
    """
    layers = [
        *conv(1, 4, (3, 3), 2, 1),
        *conv(4, 8, (3, 3), 2, 1),
        *conv(8, 16, (3, 3), 2, 1),
        nn.MaxPool2d((3, 3)),
        *conv(16, 32, (2, 2), 2, 1),
        *conv(32, 64, (2, 2), 2, 1),
        nn.MaxPool2d((2, 2)),
    ]
    # Decoder stages: transposed conv -> batch norm -> GeneralRelu.
    for n_in, n_out in ((64, 32), (32, 16), (16, 16), (16, 8), (8, 4)):
        layers += [
            nn.ConvTranspose2d(n_in, n_out, (3, 3), 2, 1, 1),
            nn.BatchNorm2d(n_out),
            GeneralRelu(leak=0.1, sub=0.4),
        ]
    # Final upsampling to one channel, without norm or activation.
    layers.append(nn.ConvTranspose2d(4, 1, (3, 3), 2, 1, 1))
    layers.append(Lambda(flatten))
    return nn.Sequential(*layers)
def cnn():
    """Return a CNN that maps an image to a 4096-wide output through an FFT stage.

    ``conv`` arguments are: number of input channels, number of output
    channels, kernel size, stride, padding.

    The head flattens the encoder output into ``nn.Linear(64, 8192)``.
    This presumably relies on the encoder reducing the map to 64x1x1 (e.g.
    for a 63x63 input); confirm before changing the encoder.
    """
    encoder = [
        *conv(1, 4, (3, 3), 2, 1),
        *conv(4, 8, (3, 3), 2, 1),
        *conv(8, 16, (3, 3), 2, 1),
        nn.MaxPool2d((3, 3)),
        *conv(16, 32, (2, 2), 2, 1),
        *conv(32, 64, (2, 2), 2, 1),
        nn.MaxPool2d((2, 2)),
    ]
    head = [
        Lambda(flatten),
        nn.Linear(64, 8192),
        Lambda(fft),
        Lambda(flatten),
        nn.Linear(8192, 4096),
    ]
    return nn.Sequential(*encoder, *head)
def __init__(self):
    """Build a two-channel residual network plus a small regression head.

    Backbone: 9x9 grouped input conv + PReLU, eight ``SRBlock`` blocks at
    width 32, a 3x3 conv + batch norm, and a 9x9 grouped output conv to two
    channels. Head: 3x3 conv to 512 channels, ReLU, global average pool and
    flatten, then linear layers 512 -> 256 -> 6 (``2 * 3``), followed by the
    project ``shape`` helper.
    """
    super().__init__()
    width = 32

    self.preBlock = nn.Sequential(
        nn.Conv2d(2, width, 9, stride=1, padding=4, groups=2),
        nn.PReLU(),
    )
    # Eight residual blocks at constant width.
    self.blocks = nn.Sequential(*[SRBlock(width, width) for _ in range(8)])
    self.postBlock = nn.Sequential(
        nn.Conv2d(width, width, 3, stride=1, padding=1),
        nn.BatchNorm2d(width),
    )
    self.final = nn.Sequential(
        nn.Conv2d(width, 2, 9, stride=1, padding=4, groups=2),
    )
    self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))

    self.conv1 = nn.Sequential(
        nn.Conv2d(2, 512, stride=1, kernel_size=3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    )
    self.flatten = Lambda(flatten)
    self.linear1 = nn.Linear(512, 256)
    self.linear2 = nn.Linear(256, 2 * 3)
    self.shape = Lambda(shape)
def __init__(self):
    """Build a single-channel, five-stage U-Net-style network.

    Encoder widths are 1 -> 4 -> 8 -> 16 -> 32 -> 64. Decoder stages take
    skip channels plus upsampled channels and end in a 1x1 conv to one
    output channel. Upsampling is bilinear by a factor of 2.
    """
    super().__init__()

    def _stage(n_in, n_out):
        return nn.Sequential(*double_conv(n_in, n_out))

    # Contracting path.
    self.dconv_down1 = _stage(1, 4)
    self.dconv_down2 = _stage(4, 8)
    self.dconv_down3 = _stage(8, 16)
    self.dconv_down4 = _stage(16, 32)
    self.dconv_down5 = _stage(32, 64)

    self.maxpool = nn.MaxPool2d(2)
    self.upsample = nn.Upsample(
        scale_factor=2, mode="bilinear", align_corners=True
    )

    # Expanding path: input = skip channels + upsampled channels.
    self.dconv_up4 = _stage(32 + 64, 32)
    self.dconv_up3 = _stage(16 + 32, 16)
    self.dconv_up2 = _stage(8 + 16, 8)
    self.dconv_up1 = _stage(4 + 8, 4)

    self.conv_last = nn.Conv2d(4, 1, 1)
    self.flatten = Lambda(flatten)
def autoencoder_two_channel():
    """Return a two-channel convolutional autoencoder with a linear head.

    Encoder: five ``conv`` stages (2 -> 4 -> 8 -> 16 -> 32 -> 64 channels)
    with two max-pools. Decoder: six ``deconv`` stages down to 2 channels,
    then flatten and ``nn.Linear(2, 4096)``.

    NOTE(review): the ``Linear(2, ...)`` input width presumably relies on
    the map being 1x1 at this point. That holds if the input is 63x63 and
    ``conv``/``deconv`` wrap ``nn.Conv2d``/``nn.ConvTranspose2d``: the
    encoder ends at 64x1x1, and a stride-2, kernel-3, padding-1,
    output_padding-0 transposed conv maps size 1 to size 1. Confirm before
    changing any of these hyper-parameters.
    """
    encoder = [
        *conv(2, 4, (3, 3), 2, 1),
        *conv(4, 8, (3, 3), 2, 1),
        *conv(8, 16, (3, 3), 2, 1),
        nn.MaxPool2d((3, 3)),
        *conv(16, 32, (2, 2), 2, 1),
        *conv(32, 64, (2, 2), 2, 1),
        nn.MaxPool2d((2, 2)),
    ]
    decoder = []
    for n_in, n_out in ((64, 32), (32, 16), (16, 16), (16, 8), (8, 4), (4, 2)):
        decoder.extend(deconv(n_in, n_out, (3, 3), 2, 1, 0))
    head = [Lambda(flatten), nn.Linear(2, 4096)]
    return nn.Sequential(*encoder, *decoder, *head)
def __init__(self, img_size):
    """Build a phase branch of three conv phases, each merged to one channel.

    Each phase is a stack of ``conv_phase`` layers followed by a
    ``LocallyConnected2d`` merge to a single channel (+ batch norm +
    ``GeneralELU``). ``img_size`` is forwarded to every
    ``LocallyConnected2d``.

    The activation shift used to be the literal ``-2.1415``, a truncated
    approximation of ``1 - pi`` (-2.14159...). It now uses ``1 - pi``,
    matching the other phase constructors in this module.
    """
    super().__init__()
    # Shift passed to every phase activation. Presumably chosen so the
    # activation's lower bound sits at -pi -- confirm against GeneralELU.
    phase_offset = 1 - pi

    def _phase(n_in, n_out, k, pad, dil):
        # conv_phase positional args: in, out, kernel, stride (always 1),
        # padding, dilation (presumed order).
        return nn.Sequential(
            *conv_phase(n_in, n_out, (k, k), 1, pad, dil, add=phase_offset)
        )

    def _merge(n_in):
        return nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            GeneralELU(phase_offset),
        )

    # Phase 1: large kernels.
    self.conv1_phase = _phase(1, 4, 23, 11, 1)
    self.conv2_phase = _phase(4, 8, 21, 10, 1)
    self.conv3_phase = _phase(8, 12, 17, 8, 1)
    self.conv_con1_phase = _merge(12)

    # Phase 2. Assuming conv_phase wraps nn.Conv2d: conv4 (k5, dil 2, pad 3)
    # shrinks the map by 2 px, and conv6 (k3, dil 2, pad 3) grows it by 2 px.
    # The net size is therefore unchanged before the locally connected merge.
    self.conv4_phase = _phase(1, 4, 5, 3, 2)
    self.conv5_phase = _phase(4, 8, 5, 2, 1)
    self.conv6_phase = _phase(8, 12, 3, 3, 2)
    self.conv7_phase = _phase(12, 16, 3, 1, 1)
    self.conv_con2_phase = _merge(16)

    # Phase 3: 3x3 kernels.
    self.conv8_phase = _phase(1, 4, 3, 1, 1)
    self.conv9_phase = _phase(4, 8, 3, 1, 1)
    self.conv10_phase = _phase(8, 12, 3, 2, 2)
    self.conv_con3_phase = _merge(12)

    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
def __init__(self, img_size):
    """Build amplitude and phase branches whose kernel sizes scale with ``img_size``.

    Each conv layer exists as an amplitude/phase pair with identical
    geometry. The kernel side is ``round_odd(frac * img_size)`` and the
    padding is ``make_padding(kernel, 1, dilation)``. Phase layers add
    ``1 - pi``. Each of the three phases ends in a ``LocallyConnected2d``
    merge to one channel: ReLU for amplitude, ``GeneralELU(1 - pi)`` for
    phase.
    """
    super().__init__()

    def _conv_pair(n_in, n_out, frac, dilation):
        # Build the amplitude layer first, then the phase layer.
        k = round_odd(frac * img_size)
        spec = dict(
            ni=n_in,
            nc=n_out,
            ks=(k, k),
            stride=1,
            padding=make_padding(k, 1, dilation),
            dilation=dilation,
        )
        amp = nn.Sequential(*conv_amp(**spec))
        phase = nn.Sequential(*conv_phase(**spec, add=1 - pi))
        return amp, phase

    def _merge_pair(n_in):
        # Locally connected merge to one channel, amplitude then phase.
        amp = nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        phase = nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            GeneralELU(1 - pi),
        )
        return amp, phase

    # Phase 1: large kernels.
    self.conv1_amp, self.conv1_phase = _conv_pair(1, 4, 0.365, 1)
    self.conv2_amp, self.conv2_phase = _conv_pair(4, 8, 0.333, 1)
    self.conv3_amp, self.conv3_phase = _conv_pair(8, 12, 0.269, 1)
    self.conv_con1_amp, self.conv_con1_phase = _merge_pair(12)

    # Phase 2: small kernels, alternating dilation 2 and 1.
    self.conv4_amp, self.conv4_phase = _conv_pair(1, 4, 0.0793, 2)
    self.conv5_amp, self.conv5_phase = _conv_pair(4, 8, 0.0793, 1)
    self.conv6_amp, self.conv6_phase = _conv_pair(8, 12, 0.0476, 2)
    self.conv7_amp, self.conv7_phase = _conv_pair(12, 16, 0.0476, 1)
    self.conv_con2_amp, self.conv_con2_phase = _merge_pair(16)

    # Phase 3: smallest kernels.
    self.conv8_amp, self.conv8_phase = _conv_pair(1, 4, 0.0476, 1)
    self.conv9_amp, self.conv9_phase = _conv_pair(4, 8, 0.0476, 1)
    self.conv10_amp, self.conv10_phase = _conv_pair(8, 12, 0.0476, 2)
    self.conv11_amp, self.conv11_phase = _conv_pair(12, 20, 0.0476, 1)
    self.conv_con3_amp, self.conv_con3_phase = _merge_pair(20)

    self.symmetry_real = Lambda(symmetry)
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
def __init__(self, img_size):
    """Build amplitude/phase conv branches plus a fully connected head.

    Each conv layer exists as an amplitude/phase pair with identical
    geometry. Phase layers add ``1 - pi``. Each of the three phases ends in
    a ``LocallyConnected2d`` merge to one channel: ReLU for amplitude,
    ``GeneralELU(1 - pi)`` for phase. The head is a flatten followed by
    ``nn.Linear(7938, 3)``, plus several ``Lambda``-wrapped project helpers.
    """
    super().__init__()

    def _conv_pair(n_in, n_out, k, pad, dil):
        # Positional args: in, out, kernel, stride (always 1), padding,
        # dilation (presumed order). Amplitude is built before phase.
        amp = nn.Sequential(*conv_amp(n_in, n_out, (k, k), 1, pad, dil))
        phase = nn.Sequential(
            *conv_phase(n_in, n_out, (k, k), 1, pad, dil, add=1 - pi)
        )
        return amp, phase

    def _merge_pair(n_in):
        # Locally connected merge to one channel, amplitude then phase.
        amp = nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        phase = nn.Sequential(
            LocallyConnected2d(n_in, 1, img_size, 1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            GeneralELU(1 - pi),
        )
        return amp, phase

    # Phase 1: large kernels.
    self.conv1_amp, self.conv1_phase = _conv_pair(1, 4, 23, 11, 1)
    self.conv2_amp, self.conv2_phase = _conv_pair(4, 8, 21, 10, 1)
    self.conv3_amp, self.conv3_phase = _conv_pair(8, 12, 17, 8, 1)
    self.conv_con1_amp, self.conv_con1_phase = _merge_pair(12)

    # Phase 2. Assuming the conv helpers wrap nn.Conv2d: conv4 (k5, dil 2,
    # pad 3) shrinks the map by 2 px, and conv6 (k3, dil 2, pad 3) grows it
    # by 2 px. The net size is therefore unchanged before the merge.
    self.conv4_amp, self.conv4_phase = _conv_pair(1, 4, 5, 3, 2)
    self.conv5_amp, self.conv5_phase = _conv_pair(4, 8, 5, 2, 1)
    self.conv6_amp, self.conv6_phase = _conv_pair(8, 12, 3, 3, 2)
    self.conv7_amp, self.conv7_phase = _conv_pair(12, 16, 3, 1, 1)
    self.conv_con2_amp, self.conv_con2_phase = _merge_pair(16)

    # Phase 3: 3x3 kernels.
    self.conv8_amp, self.conv8_phase = _conv_pair(1, 4, 3, 1, 1)
    self.conv9_amp, self.conv9_phase = _conv_pair(4, 8, 3, 1, 1)
    self.conv10_amp, self.conv10_phase = _conv_pair(8, 12, 3, 2, 2)
    self.conv11_amp, self.conv11_phase = _conv_pair(12, 20, 3, 1, 1)
    self.conv_con3_amp, self.conv_con3_phase = _merge_pair(20)

    self.symmetry_real = Lambda(symmetry)
    self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
    self.flatten = Lambda(flatten)
    # 7938 == 2 * 63 * 63; presumably two flattened 63x63 maps -- confirm.
    self.fully_connected = nn.Linear(7938, 3)
    self.vaild_gauss_bs = Lambda(vaild_gauss_bs)
    self.Relu = nn.ReLU()
    self.fft = Lambda(fft)
    self.euler = Lambda(euler)
    self.shape = Lambda(shape)