Esempio n. 1
0
    def __init__(self):
        """Build the layers of this two-channel residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 2 -> 64 channels with ``groups=2`` (each
          input channel feeds its own 32 feature maps), then PReLU.
        * ``blocks`` -- eight ``SRBlock(64, 64)`` stages.
        * ``postBlock`` -- 3x3 conv, 64 -> 64, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 64 -> 2 with ``groups=2`` (output channel 0
          is computed from features 0-31, output channel 1 from 32-63).
        * ``symmetry_amp`` / ``symmetry_imag`` -- the project's ``symmetry``
          function wrapped in ``Lambda`` with ``mode="real"`` / ``"imag"``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        width = 64
        depth = 8

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=2),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # NOTE(review): the conv bias stays at PyTorch's default (enabled)
        # even though BatchNorm2d follows and largely cancels it; removing it
        # would change the state_dict keys and could break loading saved
        # checkpoints.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=2, kernel_size=9,
                      stride=1, padding=4, groups=2),
        )

        self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
        self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
Esempio n. 2
0
    def __init__(self):
        """Build the layers of this two-channel residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 2 -> 64 channels with ``groups=2`` (each
          input channel feeds its own 32 feature maps), then PReLU.
        * ``blocks`` -- eight ``SRBlock(64, 64)`` stages.
        * ``postBlock`` -- bias-free 3x3 conv, 64 -> 64, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 64 -> 2 with ``groups=2``.
        * ``symmetry_amp`` / ``symmetry_imag`` -- the project's ``symmetry``
          function wrapped in ``Lambda`` with ``mode="real"`` / ``"imag"``.
        * ``hardtanh`` -- clamps values into ``[-pi, pi]``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        width = 64
        depth = 8

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=2),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # The conv bias is omitted because BatchNorm2d re-centres the output.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=2, kernel_size=9,
                      stride=1, padding=4, groups=2),
        )

        self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
        self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))
        # NOTE(review): presumably used to bound a phase-like output to
        # [-pi, pi] in forward -- confirm.
        self.hardtanh = nn.Hardtanh(-pi, pi)
Esempio n. 3
0
    def __init__(self):
        """Build the layers of this single-input residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 1 -> 56 channels, then PReLU.
        * ``blocks`` -- eight ``SRBlock(56, 56)`` stages.
        * ``postBlock`` -- bias-free 3x3 conv, 56 -> 56, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 56 -> 2 channels (ungrouped).
        * ``symmetry_amp`` / ``symmetry_imag`` -- the project's ``symmetry``
          function wrapped in ``Lambda`` with ``mode="real"`` / ``"imag"``.
        * ``elu`` -- project ``GeneralELU`` constructed with
          ``add=1 + 1e-10``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        feat = 56
        depth = 8

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=feat, kernel_size=9,
                      stride=1, padding=4, groups=1),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(feat, feat) for _ in range(depth)])

        # The conv bias is omitted because BatchNorm2d re-centres the output.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=feat, out_channels=feat, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(feat),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=feat, out_channels=2, kernel_size=9,
                      stride=1, padding=4, groups=1),
        )

        self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
        self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))

        # NOTE(review): GeneralELU is project-defined; `add` slightly above 1
        # presumably keeps the activation strictly positive -- confirm.
        self.elu = GeneralELU(add=1 + 1e-10)
Esempio n. 4
0
    def __init__(self):
        """Build the layers of this single-channel residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 1 -> 64 channels, then PReLU.
        * ``blocks`` -- eight ``SRBlock(64, 64)`` stages.
        * ``postBlock`` -- bias-free 3x3 conv, 64 -> 64, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 64 -> 1 channel.
        * ``symmetry_amp`` -- the project's ``symmetry`` function wrapped in
          ``Lambda`` with ``mode="real"``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        width = 64
        # Eight residual stages are active (an earlier revision listed 16
        # here, with half commented out).
        depth = 8

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=1),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # The conv bias is omitted because BatchNorm2d re-centres the output.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=1, kernel_size=9,
                      stride=1, padding=4, groups=1),
        )

        self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
Esempio n. 5
0
    def __init__(self):
        """Build the layers of this two-channel network with a dense head.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 2 -> 32 channels with ``groups=2`` (each
          input channel feeds its own 16 feature maps), then PReLU.
        * ``blocks`` -- eight ``SRBlock(32, 32)`` stages.
        * ``postBlock`` -- 3x3 conv, 32 -> 32, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 32 -> 2 with ``groups=2``.
        * ``symmetry_amp`` / ``symmetry_imag`` -- the project's ``symmetry``
          function wrapped in ``Lambda`` with ``mode="real"`` / ``"imag"``.
        * ``conv1`` -- unpadded 3x3 conv, 2 -> 512 channels, ReLU, global
          average pooling to 1x1 and flatten, i.e. 512 features per sample.
        * ``flatten`` / ``shape`` -- project helpers wrapped in ``Lambda``.
        * ``linear1`` (512 -> 256) and ``linear2`` (256 -> 2 * 3 = 6).
        """
        super().__init__()

        width = 32
        depth = 8

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=2),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # NOTE(review): the conv bias stays at PyTorch's default (enabled)
        # even though BatchNorm2d follows; removing it would change the
        # state_dict keys and could break loading saved checkpoints.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=2, kernel_size=9,
                      stride=1, padding=4, groups=2),
        )

        self.symmetry_amp = Lambda(partial(symmetry, mode="real"))
        self.symmetry_imag = Lambda(partial(symmetry, mode="imag"))

        # Feature extractor for the dense head; the adaptive pooling makes
        # its output size independent of the input's spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=512, kernel_size=3,
                      stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

        self.flatten = Lambda(flatten)
        self.linear1 = nn.Linear(512, 256)
        # NOTE(review): 2 * 3 outputs look like a 2x3 matrix per sample,
        # possibly reshaped by `self.shape` in forward -- confirm.
        self.linear2 = nn.Linear(256, 2 * 3)
        self.shape = Lambda(shape)
Esempio n. 6
0
    def __init__(self):
        """Build the layers of this single-channel residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 1 -> 32 channels, then PReLU.
        * ``blocks`` -- twelve ``SRBlock(32, 32)`` stages.
        * ``postBlock`` -- 3x3 conv, 32 -> 32, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 32 -> 1 channel.
        * ``symmetry_phase`` -- the project's ``symmetry`` function wrapped
          in ``Lambda`` with ``mode="imag"``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        width = 32
        depth = 12

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=1),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # NOTE(review): the conv bias stays at PyTorch's default (enabled)
        # even though BatchNorm2d follows; removing it would change the
        # state_dict keys and could break loading saved checkpoints.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=1, kernel_size=9,
                      stride=1, padding=4, groups=1),
        )

        self.symmetry_phase = Lambda(partial(symmetry, mode="imag"))
Esempio n. 7
0
    def __init__(self, img_size):
        """Build two parallel single-channel branches (``_amp``/``_phase``).

        Each branch has the same architecture but its own weights:

        * ``preBlock_*`` -- 9x9 conv, 1 -> 64 channels, then PReLU.
        * ``blocks_*`` -- sixteen ``SRBlock(64, 64)`` stages.
        * ``postBlock_*`` -- 3x3 conv, 64 -> 64, then BatchNorm2d.
        * ``final_*`` -- 9x9 conv, 64 -> 1 channel.

        Submodules are created and registered in the order amp/phase per
        stage, exactly as before, so ``state_dict`` keys, parameter order and
        random weight initialisation are unchanged.

        Args:
            img_size: stored unchanged as ``self.img_size``.
        """
        super().__init__()
        self.img_size = img_size

        width = 64
        depth = 16

        def head():
            # Padding 4 keeps H and W unchanged for a 9x9 kernel at stride 1.
            return nn.Sequential(
                nn.Conv2d(in_channels=1, out_channels=width, kernel_size=9,
                          stride=1, padding=4),
                nn.PReLU(),
            )

        def trunk():
            return nn.Sequential(
                *[SRBlock(width, width) for _ in range(depth)])

        def neck():
            # NOTE(review): conv bias is left enabled although BatchNorm2d
            # follows; removing it would change the state_dict keys.
            return nn.Sequential(
                nn.Conv2d(in_channels=width, out_channels=width,
                          kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
            )

        def tail():
            return nn.Sequential(
                nn.Conv2d(in_channels=width, out_channels=1, kernel_size=9,
                          stride=1, padding=4),
            )

        self.preBlock_amp = head()
        self.preBlock_phase = head()

        self.blocks_amp = trunk()
        self.blocks_phase = trunk()

        self.postBlock_amp = neck()
        self.postBlock_phase = neck()

        self.final_amp = tail()
        self.final_phase = tail()
Esempio n. 8
0
    def __init__(self):
        """Build the layers of this two-channel, 16-stage residual network.

        Registered submodules, in registration order (kept stable because it
        fixes ``state_dict`` key order and ``parameters()`` order):

        * ``preBlock`` -- 9x9 conv, 2 -> 64 channels with ``groups=2`` (each
          input channel feeds its own 32 feature maps), then PReLU.
        * ``blocks`` -- sixteen ``SRBlock(64, 64)`` stages.
        * ``postBlock`` -- bias-free 3x3 conv, 64 -> 64, then BatchNorm2d.
        * ``final`` -- 9x9 conv, 64 -> 2 with ``groups=2``.
        * ``hardtanh`` -- clamps values into ``[-pi, pi]``.

        With stride 1, ``padding=4`` on the 9x9 kernels and ``padding=1`` on
        the 3x3 kernel leave height and width unchanged.
        """
        super().__init__()

        width = 64
        depth = 16

        self.preBlock = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=width, kernel_size=9,
                      stride=1, padding=4, groups=2),
            nn.PReLU(),
        )

        # Residual trunk: `depth` identical SRBlock stages.
        self.blocks = nn.Sequential(
            *[SRBlock(width, width) for _ in range(depth)])

        # The conv bias is omitted because BatchNorm2d re-centres the output.
        self.postBlock = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=2, kernel_size=9,
                      stride=1, padding=4, groups=2),
        )

        # NOTE(review): presumably bounds a phase-like output to [-pi, pi]
        # in forward -- confirm.
        self.hardtanh = nn.Hardtanh(-pi, pi)