Exemple #1
0
def fastfood(diag1: torch.Tensor,
             diag2: torch.Tensor,
             diag3: torch.Tensor,
             permutation: torch.Tensor,
             normalized: bool = False,
             increasing_stride: bool = True,
             separate_diagonal: bool = True) -> nn.Module:
    """ Build a Butterfly-based nn.Module that computes the Fastfood transform:
    x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1,
    with H the Hadamard matrix and P a permutation matrix.
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,)
        diag3: (n,)
        permutation: (n,)
        normalized: if True, use the orthogonal Hadamard transform
                    (i.e. scaled by 1/sqrt(n))
        increasing_stride: whether the first Butterfly in the sequence has increasing stride.
                           The second Butterfly always gets the opposite setting.
        separate_diagonal: if True, the three diagonals stay as standalone Diagonal modules;
                           if False, they are folded into the two Butterfly modules.
    Returns:
        An nn.Sequential with 6 modules (separate diagonals) or 3 modules (fused diagonals).
    """
    n, = diag1.shape
    assert diag2.shape == diag3.shape == permutation.shape == (n, )
    first_hadamard = hadamard(n, normalized, increasing_stride)
    second_hadamard = hadamard(n, normalized, not increasing_stride)
    if separate_diagonal:
        return nn.Sequential(Diagonal(diagonal_init=diag1), first_hadamard,
                             FixedPermutation(permutation),
                             Diagonal(diagonal_init=diag2), second_hadamard,
                             Diagonal(diagonal_init=diag3))
    # Fused variant: diag1 goes in front of the first Hadamard; diag2 goes in front of
    # the second Hadamard and diag3 behind it.
    first_hadamard = diagonal_butterfly(first_hadamard, diag1, diag_first=True)
    for diag, goes_first in ((diag2, True), (diag3, False)):
        second_hadamard = diagonal_butterfly(second_hadamard,
                                             diag,
                                             diag_first=goes_first)
    return nn.Sequential(first_hadamard, FixedPermutation(permutation),
                         second_hadamard)
Exemple #2
0
def ifft2d_unitary(n1: int,
                   n2: int,
                   br_first: bool = True,
                   with_br_perm: bool = True) -> nn.Module:
    """ Build an nn.Module based on ButterflyUnitary that exactly performs the 2D iFFT.
    Being unitary, this corresponds to normalized=True.
    Flattening into a single butterfly is not supported for now.
    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
                      If False, only the TensorProduct of the two butterflies is returned.
    """
    transform = TensorProduct(
        ifft_unitary(n1, br_first=br_first, with_br_perm=False),
        ifft_unitary(n2, br_first=br_first, with_br_perm=False))
    if not with_br_perm:
        return transform
    bit_reversal = TensorProduct(
        FixedPermutation(bitreversal_permutation(n1, pytorch_format=True)),
        FixedPermutation(bitreversal_permutation(n2, pytorch_format=True)))
    # The bit reversal runs before the butterfly when br_first, after it otherwise.
    if br_first:
        stages = (bit_reversal, transform)
    else:
        stages = (transform, bit_reversal)
    return nn.Sequential(*stages)
Exemple #3
0
def dct(n: int, type: int = 2, normalized: bool = False) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the DCT.
    Types 2 and 4 are built on top of an FFT butterfly, type 3 on top of an iFFT butterfly;
    in every case the input is real, converted to complex around the butterfly, and
    converted back to real afterwards.
    Parameters:
        n: size of the DCT. Must be a power of 2.
        type: either 2, 3, or  4. These are the only types supported. See scipy.fft.dct's notes.
        normalized: if True, corresponds to the orthogonal DCT (see scipy.fft.dct's notes)
    """
    assert type in [2, 3, 4]
    # Reordering used around the FFT: even indices first, then the odd indices reversed,
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    index = torch.arange(n)
    perm = torch.cat((index[::2], index[1::2].flip([0])))
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))

    if type not in [2, 4]:
        assert type == 3
        inverse_fft = ifft(n,
                           normalized=normalized,
                           br_first=False,
                           with_br_perm=False)
        postprocess_diag[0] /= 2.0
        if normalized:
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            # We want iFFT with the scaling of 1.0 instead of 1 / n
            with torch.no_grad():
                inverse_fft.twiddle *= 2
        # For type 3 the (conjugated) diagonal is applied ahead of the butterfly.
        inverse_fft = diagonal_butterfly(inverse_fft,
                                         postprocess_diag.conj(),
                                         diag_first=True)
        output_order = br[invert(perm)]
        return nn.Sequential(Real2Complex(), inverse_fft, Complex2Real(),
                             FixedPermutation(output_order))

    forward_fft = fft(n,
                      normalized=normalized,
                      br_first=True,
                      with_br_perm=False)
    if type == 4:
        even_mul = torch.exp(-1j * math.pi / (2 * n) *
                             (torch.arange(0.0, n, 2) + 0.5))
        odd_mul = torch.exp(1j * math.pi / (2 * n) *
                            (torch.arange(1.0, n, 2) + 0.5))
        # Interleave the even and odd multipliers into one length-n diagonal.
        preprocess_diag = torch.stack((even_mul, odd_mul), dim=-1).flatten()
        # This diagonal is defined before the input permutation; to apply it after the
        # permutation instead, index it with that same permutation.
        forward_fft = diagonal_butterfly(forward_fft,
                                         preprocess_diag[perm[br]],
                                         diag_first=True)
    if normalized:
        if type == 2:
            postprocess_diag[0] /= 2.0
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            postprocess_diag /= math.sqrt(2)
    forward_fft = diagonal_butterfly(forward_fft,
                                     postprocess_diag,
                                     diag_first=False)
    return nn.Sequential(FixedPermutation(perm[br]), Real2Complex(),
                         forward_fft, Complex2Real())
Exemple #4
0
def wavelet_haar(n, with_perm=True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the multilevel discrete
    wavelet transform with the Haar wavelet.
    Parameters:
        n: size of the discrete wavelet transform. Must be a power of 2.
        with_perm: whether to return both the butterfly and the wavelet rearrangement permutation.
                   If False, only the Butterfly is returned.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    # The two 2x2 building blocks are the same at every level, so build them once.
    haar_block = torch.tensor([[1, 1], [1, -1]], dtype=torch.float).reshape(
        1, 2, 2) / math.sqrt(2)
    identity_block = torch.eye(2).reshape(1, 2, 2)
    levels = []
    for size in (1 << log_size for log_size in range(1, log_n + 1)):
        # Within each group of size // 2 blocks, only the first one is the Haar block;
        # the remaining size // 2 - 1 are identities.
        pattern = torch.cat(
            (haar_block, identity_block.expand(size // 2 - 1, 2, 2)))
        levels.append(pattern.repeat(n // size, 1, 1))
    twiddle = torch.stack(levels, dim=0)[None, None]
    haar = Butterfly(n, n, bias=False, increasing_stride=True, init=twiddle)
    if not with_perm:
        return haar
    rearrange = FixedPermutation(wavelet_permutation(n, pytorch_format=True))
    return nn.Sequential(haar, rearrange)
Exemple #5
0
def flip_increasing_stride(butterfly: Butterfly) -> nn.Module:
    """Turn a Butterfly with increasing_stride=True/False into a Butterfly with
    increasing_stride=False/True, sandwiched between 2 bit-reversal permutations.
    Follows the proof of Lemma G.4.
    The input butterfly is left untouched: the returned module holds a deep copy of it.
    Requires a bias-free butterfly whose input and output sizes both equal 2^log_n.
    """
    assert butterfly.bias is None
    full_size = 1 << butterfly.log_n
    assert butterfly.in_size == full_size
    assert butterfly.out_size == full_size
    n = butterfly.in_size
    flipped = copy.deepcopy(butterfly)
    flipped.increasing_stride = not butterfly.increasing_stride
    full_reversal = bitreversal_permutation(n, pytorch_format=True)
    half_reversal = bitreversal_permutation(n // 2, pytorch_format=True)
    with torch.no_grad():
        # Reorder the twiddle along its 4th dimension by the bit reversal of size n // 2.
        reordered = flipped.twiddle[:, :, :, half_reversal]
        flipped.twiddle.copy_(reordered)
    return nn.Sequential(FixedPermutation(full_reversal), flipped,
                         FixedPermutation(full_reversal))
Exemple #6
0
def permutation_kronecker(perm1: FixedPermutation,
                          perm2: FixedPermutation) -> FixedPermutation:
    """Combine two permutations of size n1 and n2 into their Kronecker product of size n1 * n2.
    The combined index vector is obtained by pushing an (n2, n1) grid of the indices
    0..n1*n2-1 through perm1, then through perm2 on the transposed grid, and flattening.
    """
    n1 = perm1.permutation.shape[-1]
    n2 = perm2.permutation.shape[-1]
    # The index grid lives on the same device as perm1's permutation tensor.
    grid = torch.arange(n2 * n1,
                        device=perm1.permutation.device).reshape(n2, n1)
    after_perm1 = perm1(grid)
    # Transpose so perm2 sees the other axis last, then transpose back.
    after_both = perm2(after_perm1.t()).t()
    return FixedPermutation(after_both.reshape(-1))
Exemple #7
0
def ifft2d(n1: int,
           n2: int,
           normalized: bool = False,
           br_first: bool = True,
           with_br_perm: bool = True,
           flatten=False) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the 2D iFFT.
    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        normalized: if True, corresponds to the unitary iFFT (i.e. multiplied by 1/sqrt(n))
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
        flatten: whether to combine the 2 butterflies into 1 with Kronecker product.
                 When set, the module flattens the last two input dimensions, applies the
                 combined transform, and unflattens the result back to (n2, n1).
    """
    shared = dict(normalized=normalized,
                  br_first=br_first,
                  with_br_perm=False)
    b_ifft1 = ifft(n1, **shared)
    b_ifft2 = ifft(n2, **shared)
    if flatten:
        b = butterfly_kronecker(b_ifft1, b_ifft2)
    else:
        b = TensorProduct(b_ifft1, b_ifft2)

    stages = [b]
    if with_br_perm:
        br_perm1 = FixedPermutation(
            bitreversal_permutation(n1, pytorch_format=True))
        br_perm2 = FixedPermutation(
            bitreversal_permutation(n2, pytorch_format=True))
        if flatten:
            br_perm = permutation_kronecker(br_perm1, br_perm2)
        else:
            br_perm = TensorProduct(br_perm1, br_perm2)
        # The bit reversal runs before the butterfly when br_first, after it otherwise.
        stages = [br_perm, b] if br_first else [b, br_perm]

    if flatten:
        return nn.Sequential(nn.Flatten(start_dim=-2), *stages,
                             nn.Unflatten(-1, (n2, n1)))
    if with_br_perm:
        return nn.Sequential(*stages)
    # Neither permutation nor flattening requested: hand back the bare transform.
    return b
Exemple #8
0
def ifft(n, normalized=False, br_first=True, with_br_perm=True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the inverse FFT.
    Parameters:
        n: size of the iFFT. Must be a power of 2.
        normalized: if True, corresponds to unitary iFFT (i.e. multiplied by 1/sqrt(n), not 1/n)
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
                      If False, only the complex Butterfly is returned.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'

    def level_factor(size):
        # 2x2 blocks [[1, w], [1, -w]] with w = exp(2 pi i k / size), k = 0..size/2-1,
        # tiled n // size times so that every level holds n // 2 blocks.
        root = torch.exp(2j * math.pi * torch.arange(0.0, size // 2) / size)
        one = torch.ones_like(root)
        top_row = torch.stack((one, root), dim=-1)
        bottom_row = torch.stack((one, -root), dim=-1)
        return torch.stack((top_row, bottom_row),
                           dim=-2).repeat(n // size, 1, 1)

    per_level = [
        level_factor(1 << log_size) for log_size in range(1, log_n + 1)
    ]
    twiddle = torch.stack(per_level, dim=0).unsqueeze(0).unsqueeze(0)
    if not br_first:
        # Take conjugate transpose of the BP decomposition of fft: transpose every 2x2
        # block and reverse the order of the levels.
        twiddle = twiddle.transpose(-1, -2).flip([2])
    # Scale the whole transform by scaling each of the log_n levels:
    # 1/sqrt(2) per level gives 1/sqrt(n) overall, 1/2 per level gives 1/n overall.
    twiddle /= math.sqrt(2) if normalized else 2
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=True,
                  increasing_stride=br_first,
                  init=twiddle)
    if not with_br_perm:
        return b
    br_perm = FixedPermutation(bitreversal_permutation(n, pytorch_format=True))
    if br_first:
        return nn.Sequential(br_perm, b)
    return nn.Sequential(b, br_perm)
Exemple #9
0
def fft_unitary(n, br_first=True, with_br_perm=True) -> nn.Module:
    """ Build an nn.Module based on ButterflyUnitary that exactly performs the FFT.
    Since it's unitary, it corresponds to normalized=True.
    Parameters:
        n: size of the FFT. Must be a power of 2.
        br_first: which decomposition of FFT. br_first=True corresponds to decimation-in-time.
                  br_first=False corresponds to decimation-in-frequency.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
                      If False, only the ButterflyUnitary is returned.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'

    def level_angles(size):
        # One (phi, alpha, psi, chi) quadruple per 2x2 block, tiled n // size times.
        angle = -2 * math.pi * torch.arange(0.0, size // 2) / size
        phi = torch.ones_like(angle) * math.pi / 4
        alpha = angle / 2 + math.pi / 2
        psi = -angle / 2 - math.pi / 2
        if br_first:
            chi = angle / 2 - math.pi / 2
        else:
            # Take conjugate transpose of the BP decomposition of ifft, which works out
            # to this, plus the flip of the levels done below.
            chi = -angle / 2 - math.pi / 2
        quadruple = torch.stack([phi, alpha, psi, chi], dim=-1)
        return quadruple.repeat(n // size, 1)

    per_level = [
        level_angles(1 << log_size) for log_size in range(1, log_n + 1)
    ]
    twiddle = torch.stack(per_level, dim=0).unsqueeze(0).unsqueeze(0)
    if not br_first:
        twiddle = twiddle.flip([2])
    b = ButterflyUnitary(n, n, bias=False, increasing_stride=br_first)
    # Overwrite the module's own parameters with the exact FFT angles.
    with torch.no_grad():
        b.twiddle.copy_(twiddle)
    if not with_br_perm:
        return b
    br_perm = FixedPermutation(bitreversal_permutation(n, pytorch_format=True))
    if br_first:
        return nn.Sequential(br_perm, b)
    return nn.Sequential(b, br_perm)
Exemple #10
0
def acdc(diag1: torch.Tensor,
         diag2: torch.Tensor,
         dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    Returns:
        An nn.Sequential operating on real inputs; each butterfly is wrapped in a
        Real2Complex / Complex2Real pair.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        # No bit-reversal modules are used here (with_br_perm=False); instead every
        # diagonal that sits between the two butterflies is indexed by br.
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj()[br],
                                diag_first=True)
        if not separate_diagonal:
            # Chain on the already-fused b1 rather than the bare b_fft, so the DCT
            # scaling diagonal is kept even if diagonal_butterfly returns a new module
            # instead of updating its argument.
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(), Real2Complex(), b2,
                                 Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        b1 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj(),
                                diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            # Same reasoning as in the dct_first branch: build on the fused b2, not on
            # b_fft, so the scaling diagonal is not lost.
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))