Exemplo n.º 1
0
def ifft2d_unitary(n1: int,
                   n2: int,
                   br_first: bool = True,
                   with_br_perm: bool = True) -> nn.Module:
    """Build an nn.Module, based on ButterflyUnitary, that exactly computes the 2D iFFT.

    Corresponds to normalized=True. Flattening is not supported for now.

    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.

    Returns:
        The TensorProduct of the two 1D butterflies when with_br_perm is False.
        Otherwise an nn.Sequential holding the bit reversal permutation and the
        butterfly, with the permutation applied first iff br_first is True.
    """
    sizes = (n1, n2)
    # The 1D factors are always built without their own permutation; the 2D
    # permutation (if requested) is assembled separately below.
    factors = [
        ifft_unitary(size, br_first=br_first, with_br_perm=False)
        for size in sizes
    ]
    butterfly = TensorProduct(*factors)
    if not with_br_perm:
        return butterfly

    permutations = [
        FixedPermutation(bitreversal_permutation(size, pytorch_format=True))
        for size in sizes
    ]
    bit_reversal = TensorProduct(*permutations)
    if br_first:
        return nn.Sequential(bit_reversal, butterfly)
    return nn.Sequential(butterfly, bit_reversal)
Exemplo n.º 2
0
def ifft2d(n1: int,
           n2: int,
           normalized: bool = False,
           br_first: bool = True,
           with_br_perm: bool = True,
           flatten=False) -> nn.Module:
    """Build an nn.Module, based on Butterfly, that exactly computes the 2D iFFT.

    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        normalized: if True, corresponds to the unitary iFFT (i.e. multiplied by 1/sqrt(n))
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
        flatten: whether to combine the 2 butterflies into 1 with Kronecker product.

    Returns:
        Without flatten: the TensorProduct butterfly, wrapped in an nn.Sequential
        together with the bit reversal permutation when with_br_perm is True.
        With flatten: an nn.Sequential that flattens the last two dimensions,
        applies the Kronecker-combined stages, and unflattens back to (n2, n1).
    """
    sizes = (n1, n2)
    factor1, factor2 = [
        ifft(size,
             normalized=normalized,
             br_first=br_first,
             with_br_perm=False) for size in sizes
    ]
    if flatten:
        butterfly = butterfly_kronecker(factor1, factor2)
    else:
        butterfly = TensorProduct(factor1, factor2)

    # Collect the stages in application order; the permutation goes before the
    # butterfly iff br_first is True.
    stages = [butterfly]
    if with_br_perm:
        perm1, perm2 = [
            FixedPermutation(
                bitreversal_permutation(size, pytorch_format=True))
            for size in sizes
        ]
        if flatten:
            bit_reversal = permutation_kronecker(perm1, perm2)
        else:
            bit_reversal = TensorProduct(perm1, perm2)
        if br_first:
            stages = [bit_reversal, butterfly]
        else:
            stages = [butterfly, bit_reversal]

    if flatten:
        # The Kronecker-combined modules act on a single flattened dimension.
        return nn.Sequential(nn.Flatten(start_dim=-2), *stages,
                             nn.Unflatten(-1, (n2, n1)))
    if not with_br_perm:
        # A lone butterfly is returned bare, not wrapped in nn.Sequential.
        return butterfly
    return nn.Sequential(*stages)
Exemplo n.º 3
0
 def test_fft2d_init(self):
     """Check that FFT-initialized butterflies reproduce a circular 2D convolution.

     For each pair of odd kernel sizes, a bias-free ``nn.Conv2d`` with
     circular padding provides the reference output. The kernel is flipped,
     zero-padded to ``n2 x n1`` and rolled by the padding amount, then the
     convolution is recomputed as ``K2(K1(x) @ Kd(w))`` where ``Kd`` and
     ``K1`` use the ``'fft_no_br'`` init and ``K2`` uses ``'ifft_no_br'``.
     The real part of the result must match the reference within
     ``self.rtol`` / ``self.atol``.
     """
     batch_size = 10
     in_channels = 3
     out_channels = 4
     n1, n2 = 16, 32
     # Named ``x`` rather than ``input`` so the builtin is not shadowed.
     x = torch.randn(batch_size, in_channels, n2, n1)
     # Per-transform settings, in the order (Kd, K1, K2); loop-invariant.
     increasing_strides = [False, False, True]
     inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
     for kernel_size1 in [1, 3, 5, 7]:
         for kernel_size2 in [1, 3, 5, 7]:
             padding1 = (kernel_size1 - 1) // 2
             padding2 = (kernel_size2 - 1) // 2
             conv = nn.Conv2d(in_channels,
                              out_channels, (kernel_size2, kernel_size1),
                              padding=(padding2, padding1),
                              padding_mode='circular',
                              bias=False)
             out_torch = conv(x)
             weight = conv.weight
             # Flip, zero-pad to the full signal length and roll by the
             # padding, first along the last dim and then the second to last.
             w = F.pad(weight.flip(dims=(-1, )),
                       (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
             w = F.pad(w.flip(dims=(-2, )),
                       (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                          dims=-2)
             for nblocks in [1, 2, 3]:
                 # complex=True: the original passed the builtin ``complex``
                 # type here, which only worked because a class is truthy.
                 Kd, K1, K2 = [
                     TensorProduct(
                         Butterfly(n1,
                                   n1,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks),
                         Butterfly(n2,
                                   n2,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks))
                     for incstride, i in zip(increasing_strides, inits)
                 ]
                 # NOTE(review): presumably compensates for the fft init being
                 # normalized by 1/sqrt(n) per dimension - confirm in Butterfly.
                 with torch.no_grad():
                     Kd.map1 *= math.sqrt(n1)
                     Kd.map2 *= math.sqrt(n2)
                 # Move the two spatial dims to the front so complex_matmul
                 # contracts over in_channels, then restore the
                 # (batch, out_channels, n2, n1) layout.
                 transformed_x = K1(real2complex(x)).permute(2, 3, 0, 1)
                 transformed_w = Kd(real2complex(w)).permute(2, 3, 1, 0)
                 product = complex_matmul(transformed_x, transformed_w)
                 out = K2(product.permute(2, 3, 0, 1)).real
                 self.assertTrue(
                     torch.allclose(out,
                                    out_torch,
                                    rtol=self.rtol,
                                    atol=self.atol))
Exemplo n.º 4
0
    def __init__(self,
                 in_size,
                 in_ch,
                 out_ch,
                 kernel_size,
                 complex=True,
                 init='ortho',
                 nblocks=1,
                 base=2,
                 zero_pad=True):
        """Set up the convolution weight and the three butterfly transforms Kd, K1, K2.

        Parameters:
            in_size: input size; an int is expanded to the pair (in_size, in_size).
            in_ch: number of input channels (passed to nn.Conv2d to create the weight).
            out_ch: number of output channels (passed to nn.Conv2d to create the weight).
            kernel_size: kernel size; an int is expanded to the pair
                (kernel_size, kernel_size).
            complex: passed to every Butterfly. When true, Kd and K1 are preceded by
                Real2Complex and K2 is followed by Complex2Real.
            init: 'ortho' or 'fft'. 'fft' requires complex=True and initializes
                Kd and K1 with 'fft_no_br' and K2 with 'ifft_no_br'.
            nblocks: passed to every Butterfly.
            base: 2 or 4. With 4, every butterfly map is replaced by its to_base4() result.
            zero_pad: when true (and complex is true), Kd's twiddle is multiplied by
                complex exponentials in place of rolling the weight (shift theorem).
        """
        super().__init__()
        self.in_size = in_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.complex = complex
        assert init in ['ortho', 'fft']
        if init == 'fft':
            assert self.complex, 'fft init requires complex=True'
        self.init = init
        self.nblocks = nblocks
        assert base in [2, 4]
        self.base = base
        self.zero_pad = zero_pad
        # Normalize scalar sizes to pairs so the indexing below always works.
        if isinstance(self.in_size, int):
            self.in_size = (self.in_size, self.in_size)
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        # Per-axis padding of (kernel extent - 1) // 2.
        self.padding = (self.kernel_size[0] - 1) // 2, (self.kernel_size[1] -
                                                        1) // 2
        # Just to use nn.Conv2d's initialization
        # The resulting weight is stored flipped along its last two dimensions.
        self.weight = nn.Parameter(
            nn.Conv2d(self.in_ch,
                      self.out_ch,
                      self.kernel_size,
                      padding=self.padding,
                      bias=False).weight.flip([-1, -2]))

        # Settings per transform, in the order (Kd, K1, K2): only K2 uses
        # increasing stride, and with init='fft' only K2 gets the inverse init.
        # NOTE(review): presumably Kd transforms the kernel, K1 the input and K2
        # the product; forward() is not visible here - confirm.
        increasing_strides = [False, False, True]
        inits = ['ortho'] * 3 if self.init == 'ortho' else [
            'fft_no_br', 'fft_no_br', 'ifft_no_br'
        ]
        # Each transform is a TensorProduct of a butterfly over in_size[-1] (map1)
        # and a butterfly over in_size[-2] (map2).
        self.Kd, self.K1, self.K2 = [
            TensorProduct(
                Butterfly(self.in_size[-1],
                          self.in_size[-1],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks),
                Butterfly(self.in_size[-2],
                          self.in_size[-2],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks))
            for incstride, i in zip(increasing_strides, inits)
        ]
        # Scale Kd's two maps by sqrt of their sizes, in place and outside autograd.
        # NOTE(review): presumably undoes a 1/sqrt(n) normalization of the
        # butterfly init - confirm against Butterfly.
        with torch.no_grad():
            self.Kd.map1 *= math.sqrt(self.in_size[-1])
            self.Kd.map2 *= math.sqrt(self.in_size[-2])
        if self.zero_pad and self.complex:
            # Instead of zero-padding and calling weight.roll(-self.padding[-1], dims=-1) and
            # weight.roll(-self.padding[-2], dims=-2), we multiply self.Kd by complex exponential
            # instead, using the Shift theorem.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Shift_theorem
            with torch.no_grad():
                n1, n2 = self.Kd.map1.n, self.Kd.map2.n
                device = self.Kd.map1.twiddle.device
                br1 = bitreversal_permutation(n1,
                                              pytorch_format=True).to(device)
                br2 = bitreversal_permutation(n2,
                                              pytorch_format=True).to(device)
                # Phase factors exp(2*pi*i * padding * k / n), reordered by the
                # bit reversal permutation.
                diagonal1 = torch.exp(1j * 2 * math.pi / n1 *
                                      self.padding[-1] *
                                      torch.arange(n1, device=device))[br1]
                diagonal2 = torch.exp(1j * 2 * math.pi / n2 *
                                      self.padding[-2] *
                                      torch.arange(n2, device=device))[br2]
                # We multiply the 1st block instead of the last block (only the first block is not
                # the identity if init=fft). This seems to perform a tiny bit better.
                # If init=ortho, this won't correspond exactly to rolling the weight.
                # NOTE(review): the even/odd split of the diagonal presumably maps
                # onto the two rows of each 2x2 twiddle entry - confirm against
                # Butterfly's twiddle layout.
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal1[::2].unsqueeze(-1)
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal1[1::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal2[::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal2[1::2].unsqueeze(-1)

        # Replace every butterfly map with its base-4 counterpart when requested.
        if base == 4:
            self.Kd.map1, self.Kd.map2 = self.Kd.map1.to_base4(
            ), self.Kd.map2.to_base4()
            self.K1.map1, self.K1.map2 = self.K1.map1.to_base4(
            ), self.K1.map2.to_base4()
            self.K2.map1, self.K2.map2 = self.K2.map1.to_base4(
            ), self.K2.map2.to_base4()

        # In complex mode, convert real -> complex before Kd and K1, and
        # complex -> real after K2.
        if complex:
            self.Kd = nn.Sequential(Real2Complex(), self.Kd)
            self.K1 = nn.Sequential(Real2Complex(), self.K1)
            self.K2 = nn.Sequential(self.K2, Complex2Real())
Exemplo n.º 5
0
    def __init__(self,
                 in_size,
                 in_ch,
                 out_ch,
                 kernel_size,
                 complex=True,
                 init='random'):
        """Set up the convolution weight and the three dense transforms Kd, K1, K2.

        Parameters:
            in_size: input size; an int is expanded to the pair (in_size, in_size).
            in_ch: number of input channels (passed to nn.Conv2d to create the weight).
            out_ch: number of output channels (passed to nn.Conv2d to create the weight).
            kernel_size: kernel size; an int is expanded to the pair
                (kernel_size, kernel_size).
            complex: selects ComplexLinear instead of nn.Linear for the maps. When
                true, Kd and K1 are preceded by Real2Complex and K2 is followed by
                Complex2Real.
            init: 'random' or 'fft'. 'fft' requires complex=True and copies
                orthonormal DFT matrices into Kd and K1 and inverse DFT matrices
                into K2.
        """
        super().__init__()
        self.in_size = in_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.complex = complex
        assert init in ['random', 'fft']
        if init == 'fft':
            assert self.complex, 'fft init requires complex=True'
        self.init = init

        # Normalize scalar sizes to pairs.
        if isinstance(self.in_size, int):
            self.in_size = (self.in_size, ) * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, ) * 2
        pad_first = (self.kernel_size[0] - 1) // 2
        pad_second = (self.kernel_size[1] - 1) // 2
        self.padding = (pad_first, pad_second)

        # A throwaway nn.Conv2d is built only to borrow its weight
        # initialization; the weight is stored flipped along its last two dims.
        reference_conv = nn.Conv2d(self.in_ch,
                                   self.out_ch,
                                   self.kernel_size,
                                   padding=self.padding,
                                   bias=False)
        self.weight = nn.Parameter(reference_conv.weight.flip([-1, -2]))

        size_last = self.in_size[-1]
        size_second_last = self.in_size[-2]
        if complex:
            linear_cls = ComplexLinear
        else:
            linear_cls = nn.Linear
        # Each transform pairs a square map over the last dimension (map1) with
        # one over the second to last dimension (map2).
        transforms = []
        for _ in range(3):
            map_last = linear_cls(size_last, size_last, bias=False)
            map_second_last = linear_cls(size_second_last,
                                         size_second_last,
                                         bias=False)
            transforms.append(TensorProduct(map_last, map_second_last))
        self.Kd, self.K1, self.K2 = transforms

        if init == 'fft':
            identity_last = torch.eye(size_last, dtype=torch.complex64)
            identity_second_last = torch.eye(size_second_last,
                                             dtype=torch.complex64)
            # The DFT matrices are symmetric, so no transpose is needed.
            forward_last = torch.fft.fft(identity_last, norm='ortho')
            forward_second_last = torch.fft.fft(identity_second_last,
                                                norm='ortho')
            inverse_last = torch.fft.ifft(identity_last, norm='ortho')
            inverse_second_last = torch.fft.ifft(identity_second_last,
                                                 norm='ortho')
            with torch.no_grad():
                for transform in (self.Kd, self.K1):
                    transform.map1.weight.copy_(forward_last)
                    transform.map2.weight.copy_(forward_second_last)
                self.K2.map1.weight.copy_(inverse_last)
                self.K2.map2.weight.copy_(inverse_second_last)

        # Kd is scaled by sqrt of each size regardless of the init mode.
        with torch.no_grad():
            self.Kd.map1.weight *= math.sqrt(size_last)
            self.Kd.map2.weight *= math.sqrt(size_second_last)

        # Flag all six map weights as structured.
        for transform in (self.Kd, self.K1, self.K2):
            transform.map1.weight._is_structured = True
            transform.map2.weight._is_structured = True

        # In complex mode, convert real -> complex before Kd and K1, and
        # complex -> real after K2.
        if complex:
            self.Kd = nn.Sequential(Real2Complex(), self.Kd)
            self.K1 = nn.Sequential(Real2Complex(), self.K1)
            self.K2 = nn.Sequential(self.K2, Complex2Real())