Example #1
def ifft2d_unitary(n1: int,
                   n2: int,
                   br_first: bool = True,
                   with_br_perm: bool = True) -> nn.Module:
    """ Construct an nn.Module based on ButterflyUnitary that exactly performs the 2D iFFT.
    Corresponds to normalized=True.
    Does not support flatten for now.
    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    b_ifft1 = ifft_unitary(n1, br_first=br_first, with_br_perm=False)
    b_ifft2 = ifft_unitary(n2, br_first=br_first, with_br_perm=False)
    b = TensorProduct(b_ifft1, b_ifft2)
    if with_br_perm:
        br_perm1 = FixedPermutation(
            bitreversal_permutation(n1, pytorch_format=True))
        br_perm2 = FixedPermutation(
            bitreversal_permutation(n2, pytorch_format=True))
        br_perm = TensorProduct(br_perm1, br_perm2)
        return nn.Sequential(br_perm, b) if br_first else nn.Sequential(
            b, br_perm)
    else:
        return b
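
A minimal usage sketch (hypothetical sizes and input shape; assumes ifft2d_unitary and torch are importable as in the listing above). Since the module is unitary, per the docstring it should agree with torch.fft.ifft2 with norm='ortho':

import torch

n1, n2 = 16, 8
m = ifft2d_unitary(n1, n2, br_first=True, with_br_perm=True)
x = torch.randn(4, n2, n1, dtype=torch.complex64)  # (..., n2, n1) per the docstring
y = m(x)
print(y.shape)  # torch.Size([4, 8, 16])
print(torch.allclose(y, torch.fft.ifft2(x, norm='ortho'), atol=1e-4))
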
Example #2
def flip_increasing_stride(butterfly: Butterfly) -> nn.Module:
    """Convert a Butterfly with increasing_stride=True/False to a Butterfly with
    increasing_stride=False/True, along with 2 bit-reversal permutations.
    Follows the proof of Lemma G.4.
    """
    assert butterfly.bias is None
    assert butterfly.in_size == 1 << butterfly.log_n
    assert butterfly.out_size == 1 << butterfly.log_n
    n = butterfly.in_size
    new_butterfly = copy.deepcopy(butterfly)
    new_butterfly.increasing_stride = not butterfly.increasing_stride
    br = bitreversal_permutation(n, pytorch_format=True)
    br_half = bitreversal_permutation(n // 2, pytorch_format=True)
    with torch.no_grad():
        new_butterfly.twiddle.copy_(new_butterfly.twiddle[:, :, :, br_half])
    return nn.Sequential(FixedPermutation(br), new_butterfly,
                         FixedPermutation(br))
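
A minimal check sketch (assumes Butterfly, flip_increasing_stride, and torch are importable as in the listing above): per Lemma G.4 as cited in the docstring, the returned permutation-butterfly-permutation sandwich should compute the same linear map as the original Butterfly, up to numerical error.

import torch

n = 16
b = Butterfly(n, n, bias=False, complex=True, increasing_stride=True)
flipped = flip_increasing_stride(b)  # FixedPermutation, Butterfly (flipped stride), FixedPermutation
x = torch.randn(2, n, dtype=torch.complex64)
print(torch.allclose(b(x), flipped(x), atol=1e-5))
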
Example #3
def ifft2d(n1: int,
           n2: int,
           normalized: bool = False,
           br_first: bool = True,
           with_br_perm: bool = True,
           flatten=False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs the 2D iFFT.
    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        normalized: if True, corresponds to the unitary iFFT (i.e. multiplied by 1/sqrt(n1 * n2))
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
        flatten: whether to combine the 2 butterflies into 1 with Kronecker product.
    """
    b_ifft1 = ifft(n1,
                   normalized=normalized,
                   br_first=br_first,
                   with_br_perm=False)
    b_ifft2 = ifft(n2,
                   normalized=normalized,
                   br_first=br_first,
                   with_br_perm=False)
    b = TensorProduct(b_ifft1,
                      b_ifft2) if not flatten else butterfly_kronecker(
                          b_ifft1, b_ifft2)
    if with_br_perm:
        br_perm1 = FixedPermutation(
            bitreversal_permutation(n1, pytorch_format=True))
        br_perm2 = FixedPermutation(
            bitreversal_permutation(n2, pytorch_format=True))
        br_perm = (TensorProduct(br_perm1, br_perm2) if not flatten else
                   permutation_kronecker(br_perm1, br_perm2))
        if not flatten:
            return nn.Sequential(br_perm, b) if br_first else nn.Sequential(
                b, br_perm)
        else:
            return (nn.Sequential(nn.Flatten(
                start_dim=-2), br_perm, b, nn.Unflatten(-1, (n2, n1)))
                    if br_first else nn.Sequential(nn.Flatten(
                        start_dim=-2), b, br_perm, nn.Unflatten(-1, (n2, n1))))
    else:
        return b if not flatten else nn.Sequential(nn.Flatten(
            start_dim=-2), b, nn.Unflatten(-1, (n2, n1)))
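
A minimal usage sketch (hypothetical sizes; assumes ifft2d and torch are importable as in the listing above). With normalized=False the module carries the usual 1/(n1*n2) scaling of the inverse FFT, so it should match torch.fft.ifft2's default normalization:

import torch

n1, n2 = 16, 8
m = ifft2d(n1, n2, normalized=False, br_first=True, with_br_perm=True)
x = torch.randn(2, n2, n1, dtype=torch.complex64)
print(torch.allclose(m(x), torch.fft.ifft2(x), atol=1e-4))
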
Example #4
def conv1d_circular_multichannel(n, weight) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). Kernel_size must be
                odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding = (kernel_size - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding,
                                                              dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = torch.fft.fft(col, norm=None)
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This could also be written as a complex matrix multiply.

    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f),
                             b_ifft, Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
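
A minimal usage sketch (hypothetical sizes; assumes conv1d_circular_multichannel and torch are importable as in the listing above). Per the docstring, the module should match a circularly padded nn.functional.conv1d:

import torch
import torch.nn.functional as F

n, in_ch, out_ch, k = 16, 3, 5, 3
weight = torch.randn(out_ch, in_ch, k)
m = conv1d_circular_multichannel(n, weight)
x = torch.randn(2, in_ch, n)
pad = (k - 1) // 2
ref = F.conv1d(F.pad(x, (pad, pad), mode='circular'), weight)
print(torch.allclose(m(x), ref, atol=1e-4))
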
Example #5
def dct(n: int, type: int = 2, normalized: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs the DCT.
    Parameters:
        n: size of the DCT. Must be a power of 2.
        type: either 2, 3, or 4. These are the only types supported. See scipy.fft.dct's notes.
        normalized: if True, corresponds to the orthogonal DCT (see scipy.fft.dct's notes)
    """
    assert type in [2, 3, 4]
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))
    if type in [2, 4]:
        b = fft(n, normalized=normalized, br_first=True, with_br_perm=False)
        if type == 4:
            even_mul = torch.exp(-1j * math.pi / (2 * n) *
                                 (torch.arange(0.0, n, 2) + 0.5))
            odd_mul = torch.exp(1j * math.pi / (2 * n) *
                                (torch.arange(1.0, n, 2) + 0.5))
            preprocess_diag = torch.stack((even_mul, odd_mul),
                                          dim=-1).flatten()
            # This preprocess_diag is before the permutation.
            # To move it after the permutation, we have to permute the diagonal
            b = diagonal_butterfly(b,
                                   preprocess_diag[perm[br]],
                                   diag_first=True)
        if normalized:
            if type in [2, 3]:
                postprocess_diag[0] /= 2.0
                postprocess_diag[1:] /= math.sqrt(2)
            elif type == 4:
                postprocess_diag /= math.sqrt(2)
        b = diagonal_butterfly(b, postprocess_diag, diag_first=False)
        return nn.Sequential(FixedPermutation(perm[br]), Real2Complex(), b,
                             Complex2Real())
    else:
        assert type == 3
        b = ifft(n, normalized=normalized, br_first=False, with_br_perm=False)
        postprocess_diag[0] /= 2.0
        if normalized:
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            # We want iFFT with the scaling of 1.0 instead of 1 / n
            with torch.no_grad():
                b.twiddle *= 2
        b = diagonal_butterfly(b, postprocess_diag.conj(), diag_first=True)
        perm_inverse = invert(perm)
        return nn.Sequential(Real2Complex(), b, Complex2Real(),
                             FixedPermutation(br[perm_inverse]))
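
A minimal check sketch (assumes dct and torch are importable as in the listing above; scipy is used only to produce reference values). Per the docstring, the normalized type-2 construction should match scipy.fft.dct with norm='ortho':

import torch
from scipy.fft import dct as scipy_dct

n = 16
m = dct(n, type=2, normalized=True)
x = torch.randn(2, n)
ref = torch.from_numpy(scipy_dct(x.numpy(), type=2, norm='ortho', axis=-1))
print(torch.allclose(m(x), ref.to(x.dtype), atol=1e-4))
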
Example #6
def ifft(n, normalized=False, br_first=True, with_br_perm=True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs the inverse FFT.
    Parameters:
        n: size of the iFFT. Must be a power of 2.
        normalized: if True, corresponds to unitary iFFT (i.e. multiplied by 1/sqrt(n), not 1/n)
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    factors = []
    for log_size in range(1, log_n + 1):
        size = 1 << log_size
        exp = torch.exp(2j * math.pi * torch.arange(0.0, size // 2) / size)
        o = torch.ones_like(exp)
        twiddle_factor = torch.stack((torch.stack(
            (o, exp), dim=-1), torch.stack((o, -exp), dim=-1)),
                                     dim=-2)
        factors.append(twiddle_factor.repeat(n // size, 1, 1))
    twiddle = torch.stack(factors, dim=0).unsqueeze(0).unsqueeze(0)
    if not br_first:  # Take conjugate transpose of the BP decomposition of fft
        twiddle = twiddle.transpose(-1, -2).flip([2])
    # Divide the whole transform by sqrt(n) by dividing each factor by n^(1 / (2 log_n)) = sqrt(2)
    if normalized:
        twiddle /= math.sqrt(2)
    else:
        twiddle /= 2
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=True,
                  increasing_stride=br_first,
                  init=twiddle)
    if with_br_perm:
        br_perm = FixedPermutation(
            bitreversal_permutation(n, pytorch_format=True))
        return nn.Sequential(br_perm, b) if br_first else nn.Sequential(
            b, br_perm)
    else:
        return b
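
A minimal check sketch (assumes ifft and torch are importable as in the listing above): with normalized=True, the module should agree with torch.fft.ifft using norm='ortho'.

import torch

n = 16
m = ifft(n, normalized=True, br_first=True, with_br_perm=True)
x = torch.randn(3, n, dtype=torch.complex64)
print(torch.allclose(m(x), torch.fft.ifft(x, norm='ortho'), atol=1e-5))
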
Example #7
def fft_unitary(n, br_first=True, with_br_perm=True) -> nn.Module:
    """ Construct an nn.Module based on ButterflyUnitary that exactly performs the FFT.
    Since it's unitary, it corresponds to normalized=True.
    Parameters:
        n: size of the FFT. Must be a power of 2.
        br_first: which decomposition of FFT. br_first=True corresponds to decimation-in-time.
                  br_first=False corresponds to decimation-in-frequency.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    factors = []
    for log_size in range(1, log_n + 1):
        size = 1 << log_size
        angle = -2 * math.pi * torch.arange(0.0, size // 2) / size
        phi = torch.ones_like(angle) * math.pi / 4
        alpha = angle / 2 + math.pi / 2
        psi = -angle / 2 - math.pi / 2
        if br_first:
            chi = angle / 2 - math.pi / 2
        else:
            # Take conjugate transpose of the BP decomposition of ifft, which works out to this,
            # plus the flip later.
            chi = -angle / 2 - math.pi / 2
        twiddle_factor = torch.stack([phi, alpha, psi, chi], dim=-1)
        factors.append(twiddle_factor.repeat(n // size, 1))
    twiddle = torch.stack(factors, dim=0).unsqueeze(0).unsqueeze(0)
    if not br_first:
        twiddle = twiddle.flip([2])
    b = ButterflyUnitary(n, n, bias=False, increasing_stride=br_first)
    with torch.no_grad():
        b.twiddle.copy_(twiddle)
    if with_br_perm:
        br_perm = FixedPermutation(
            bitreversal_permutation(n, pytorch_format=True))
        return nn.Sequential(br_perm, b) if br_first else nn.Sequential(
            b, br_perm)
    else:
        return b
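
A minimal check sketch (assumes fft_unitary and torch are importable as in the listing above): since the module is unitary, it should agree with torch.fft.fft using norm='ortho'.

import torch

n = 32
m = fft_unitary(n, br_first=True, with_br_perm=True)
x = torch.randn(2, n, dtype=torch.complex64)
print(torch.allclose(m(x), torch.fft.fft(x, norm='ortho'), atol=1e-5))
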
Example #8
    def __init__(self,
                 in_size,
                 in_ch,
                 out_ch,
                 kernel_size,
                 complex=True,
                 init='ortho',
                 nblocks=1,
                 base=2,
                 zero_pad=True):
        super().__init__()
        self.in_size = in_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.complex = complex
        assert init in ['ortho', 'fft']
        if init == 'fft':
            assert self.complex, 'fft init requires complex=True'
        self.init = init
        self.nblocks = nblocks
        assert base in [2, 4]
        self.base = base
        self.zero_pad = zero_pad
        if isinstance(self.in_size, int):
            self.in_size = (self.in_size, self.in_size)
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        self.padding = (self.kernel_size[0] - 1) // 2, (self.kernel_size[1] -
                                                        1) // 2
        # Just to use nn.Conv2d's initialization
        self.weight = nn.Parameter(
            nn.Conv2d(self.in_ch,
                      self.out_ch,
                      self.kernel_size,
                      padding=self.padding,
                      bias=False).weight.flip([-1, -2]))

        increasing_strides = [False, False, True]
        inits = ['ortho'] * 3 if self.init == 'ortho' else [
            'fft_no_br', 'fft_no_br', 'ifft_no_br'
        ]
        self.Kd, self.K1, self.K2 = [
            TensorProduct(
                Butterfly(self.in_size[-1],
                          self.in_size[-1],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks),
                Butterfly(self.in_size[-2],
                          self.in_size[-2],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks))
            for incstride, i in zip(increasing_strides, inits)
        ]
        with torch.no_grad():
            self.Kd.map1 *= math.sqrt(self.in_size[-1])
            self.Kd.map2 *= math.sqrt(self.in_size[-2])
        if self.zero_pad and self.complex:
            # Instead of zero-padding and calling weight.roll(-self.padding[-1], dims=-1) and
            # weight.roll(-self.padding[-2], dims=-2), we multiply self.Kd by complex exponential
            # instead, using the Shift theorem.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Shift_theorem
            with torch.no_grad():
                n1, n2 = self.Kd.map1.n, self.Kd.map2.n
                device = self.Kd.map1.twiddle.device
                br1 = bitreversal_permutation(n1,
                                              pytorch_format=True).to(device)
                br2 = bitreversal_permutation(n2,
                                              pytorch_format=True).to(device)
                diagonal1 = torch.exp(1j * 2 * math.pi / n1 *
                                      self.padding[-1] *
                                      torch.arange(n1, device=device))[br1]
                diagonal2 = torch.exp(1j * 2 * math.pi / n2 *
                                      self.padding[-2] *
                                      torch.arange(n2, device=device))[br2]
                # We multiply the 1st block instead of the last block (only the first block is not
                # the identity if init=fft). This seems to perform a tiny bit better.
                # If init=ortho, this won't correspond exactly to rolling the weight.
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal1[::2].unsqueeze(-1)
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal1[1::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal2[::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal2[1::2].unsqueeze(-1)

        if base == 4:
            self.Kd.map1, self.Kd.map2 = self.Kd.map1.to_base4(
            ), self.Kd.map2.to_base4()
            self.K1.map1, self.K1.map2 = self.K1.map1.to_base4(
            ), self.K1.map2.to_base4()
            self.K2.map1, self.K2.map2 = self.K2.map1.to_base4(
            ), self.K2.map2.to_base4()

        if complex:
            self.Kd = nn.Sequential(Real2Complex(), self.Kd)
            self.K1 = nn.Sequential(Real2Complex(), self.K1)
            self.K2 = nn.Sequential(self.K2, Complex2Real())
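
The snippet above is only the constructor of a butterfly-based 2D convolution layer (the class name and its forward pass are not shown in this listing). One detail worth isolating is the weight initialization: it borrows nn.Conv2d's default initializer and then flips the kernel along both spatial dimensions, presumably because the FFT/circulant-style construction computes a true convolution rather than nn.Conv2d's cross-correlation. A standalone sketch of just that trick, with illustrative sizes:

import torch.nn as nn

in_ch, out_ch, kernel_size, padding = 3, 8, (3, 3), (1, 1)  # illustrative values
# Reuse nn.Conv2d's default initialization, then flip the kernel spatially,
# exactly as in the constructor above.
weight = nn.Parameter(
    nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, bias=False)
    .weight.flip([-1, -2]))
print(weight.shape)  # torch.Size([8, 3, 3, 3])
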
Example #9
def acdc(diag1: torch.Tensor,
         diag2: torch.Tensor,
         dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj()[br],
                                diag_first=True)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b_fft, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(), Real2Complex(), b2,
                                 Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        b1 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj(),
                                diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            b2 = diagonal_butterfly(b_fft, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))
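
A minimal usage sketch (hypothetical sizes; assumes acdc and torch are importable as in the listing above): builds the diag2 @ iDCT @ diag1 @ DCT map for random diagonals and applies it to a batch.

import torch

n = 16
diag1, diag2 = torch.randn(n), torch.randn(n)
m = acdc(diag1, diag2, dct_first=True, separate_diagonal=True)
x = torch.randn(4, n)
print(m(x).shape)  # torch.Size([4, 16])
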
Example #10
def conv2d_circular_multichannel(n1: int,
                                 n2: int,
                                 weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
            Kernel_size must be odd, and smaller than n1/n2. Padding is assumed to be
            (kernel_size - 1) // 2.
        flatten: whether to internally flatten the last 2 dimensions of the input. Only support n1
            and n2 being powers of 2.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    assert kernel_size1 < n1 and kernel_size2 < n2, 'Kernel size must be smaller than the input size'
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1,
                                                                dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                                   dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2, 'n1 and n2 must be powers of 2 when flatten=True'
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1,
                  n_extended2,
                  normalized=True,
                  br_first=False,
                  with_br_perm=False,
                  flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and nn.Unflatten
    b_ifft = ifft2d(n_extended1,
                    n_extended2,
                    normalized=True,
                    br_first=True,
                    with_br_perm=False,
                    flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and nn.Unflatten
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = torch.fft.fftn(col, dim=(-1, -2), norm=None)
    br_perm1 = bitreversal_permutation(n_extended1,
                                       pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2,
                                       pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error "shape mismatch: indexing tensors could not be
    # broadcast together"
    # col_f = col_f[..., br_perm2, :][..., br_perm1]
    col_f = torch.view_as_complex(
        torch.view_as_real(col_f)[..., br_perm2, :, :][..., br_perm1, :])
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2],
                              col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as a complex matrix multiply as well.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2),
                                 b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 nn.Unflatten(-1, (n2, n1)), Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 nn.Unflatten(-1, (n2, n1)))
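
A minimal usage sketch (hypothetical sizes; assumes conv2d_circular_multichannel and torch are importable as in the listing above). Per the docstring, the module should match a circularly padded nn.functional.conv2d:

import torch
import torch.nn.functional as F

n1, n2, in_ch, out_ch = 16, 8, 3, 4
weight = torch.randn(out_ch, in_ch, 3, 3)  # (out_channels, in_channels, kernel_size2, kernel_size1)
m = conv2d_circular_multichannel(n1, n2, weight, flatten=False)
x = torch.randn(2, in_ch, n2, n1)
ref = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='circular'), weight)  # padding 1 for a 3x3 kernel
print(torch.allclose(m(x), ref, atol=1e-4))
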
Example #11
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = torch.fft.fft(col, norm=None)
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        col_f = torch.cat((col_f[:1], col_f[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    # diag = col_f[..., br_perm]
    diag = index_last_dim(col_f, br_perm)
    if separate_diagonal:
        if not complex:
            return nn.Sequential(Real2Complex(), b_fft,
                                 Diagonal(diagonal_init=diag), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
    else:
        # Combine the diagonal with the last twiddle factor of b_fft
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft,
                                       diag,
                                       diag_first=False,
                                       inplace=True)
        # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
        b = butterfly_product(b_fft, b_ifft)
        b.in_size = n
        b.out_size = n
        return b if complex else nn.Sequential(Real2Complex(), b,
                                               Complex2Real())
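
A minimal check sketch (assumes circulant and torch are importable as in the listing above): compare the module against an explicit circulant matrix built from col.

import torch

n = 8
col = torch.randn(n)
m = circulant(col)
# Explicit circulant matrix with first column col: C[i, j] = col[(i - j) mod n]
C = torch.stack([col.roll(j) for j in range(n)], dim=1)
x = torch.randn(2, n)
print(torch.allclose(m(x), x @ C.t(), atol=1e-4))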