def fastfood(diag1: torch.Tensor, diag2: torch.Tensor, diag3: torch.Tensor,
             permutation: torch.Tensor, normalized: bool = False,
             increasing_stride: bool = True,
             separate_diagonal: bool = True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that performs Fastfood multiplication:
        x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1 @ x,
    where H is the Hadamard matrix and P is the permutation matrix given by `permutation`.

    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,)
        diag3: (n,)
        permutation: (n,)
        normalized: if True, use the orthogonal Hadamard transform (i.e. scaled by 1/sqrt(n)).
        increasing_stride: whether the first Hadamard Butterfly in the sequence has increasing
            stride; the second Hadamard Butterfly always uses the opposite setting.
        separate_diagonal: if True, each diagonal is its own Diagonal layer; if False, the
            diagonals are folded into the neighbouring Hadamard Butterflies.
    """
    (n,) = diag1.shape
    assert diag2.shape == diag3.shape == permutation.shape == (n, )
    first_hadamard = hadamard(n, normalized, increasing_stride)
    second_hadamard = hadamard(n, normalized, not increasing_stride)
    if separate_diagonal:
        return nn.Sequential(
            Diagonal(diagonal_init=diag1),
            first_hadamard,
            FixedPermutation(permutation),
            Diagonal(diagonal_init=diag2),
            second_hadamard,
            Diagonal(diagonal_init=diag3),
        )
    # diag1 precedes the first H; diag2 precedes and diag3 follows the second H.
    first_hadamard = diagonal_butterfly(first_hadamard, diag1, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag2, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag3, diag_first=False)
    return nn.Sequential(first_hadamard, FixedPermutation(permutation), second_hadamard)
def ifft2d_unitary(n1: int, n2: int, br_first: bool = True,
                   with_br_perm: bool = True) -> nn.Module:
    """ Build an nn.Module based on ButterflyUnitary that exactly performs the 2D iFFT.
    Corresponds to normalized=True. Does not support flatten for now.

    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    core = TensorProduct(
        ifft_unitary(n1, br_first=br_first, with_br_perm=False),
        ifft_unitary(n2, br_first=br_first, with_br_perm=False),
    )
    if not with_br_perm:
        return core
    reversal = TensorProduct(
        FixedPermutation(bitreversal_permutation(n1, pytorch_format=True)),
        FixedPermutation(bitreversal_permutation(n2, pytorch_format=True)),
    )
    # The bit-reversal goes before the butterfly for br_first, after it otherwise.
    layers = (reversal, core) if br_first else (core, reversal)
    return nn.Sequential(*layers)
def dct(n: int, type: int = 2, normalized: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs the DCT.

    Parameters:
        n: size of the DCT. Must be a power of 2.
        type: either 2, 3, or 4. These are the only types supported. See scipy.fft.dct's notes.
        normalized: if True, corresponds to the orthogonal DCT (see scipy.fft.dct's notes)
    Returns:
        Types 2 and 4: nn.Sequential(FixedPermutation, Real2Complex, FFT butterfly, Complex2Real).
        Type 3: nn.Sequential(Real2Complex, iFFT butterfly, Complex2Real, FixedPermutation).
    """
    assert type in [2, 3, 4]
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    br = bitreversal_permutation(n, pytorch_format=True)
    # Entry k is 2 * exp(-i * pi * k / (2n)). It is applied after the FFT (types 2/4) or,
    # conjugated, before the iFFT (type 3).
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) / (2 * n))
    if type in [2, 4]:
        b = fft(n, normalized=normalized, br_first=True, with_br_perm=False)
        if type == 4:
            even_mul = torch.exp(-1j * math.pi / (2 * n) * (torch.arange(0.0, n, 2) + 0.5))
            odd_mul = torch.exp(1j * math.pi / (2 * n) * (torch.arange(1.0, n, 2) + 0.5))
            # Interleave: even positions take even_mul, odd positions take odd_mul.
            preprocess_diag = torch.stack((even_mul, odd_mul), dim=-1).flatten()
            # This preprocess_diag is before the permutation.
            # To move it after the permutation, we have to permute the diagonal
            b = diagonal_butterfly(b, preprocess_diag[perm[br]], diag_first=True)
        if normalized:
            if type == 2:
                # Output 0 gets a different scale than the rest (orthonormal DCT-II).
                postprocess_diag[0] /= 2.0
                postprocess_diag[1:] /= math.sqrt(2)
            else:  # type == 4: uniform scale.
                postprocess_diag /= math.sqrt(2)
        b = diagonal_butterfly(b, postprocess_diag, diag_first=False)
        return nn.Sequential(FixedPermutation(perm[br]), Real2Complex(), b, Complex2Real())
    else:
        assert type == 3
        b = ifft(n, normalized=normalized, br_first=False, with_br_perm=False)
        postprocess_diag[0] /= 2.0
        if normalized:
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            # We want iFFT with the scaling of 1.0 instead of 1 / n: ifft(normalized=False)
            # divides each of its log_n twiddle factors by 2, so multiply them back.
            with torch.no_grad():
                b.twiddle *= 2
        b = diagonal_butterfly(b, postprocess_diag.conj(), diag_first=True)
        perm_inverse = invert(perm)
        return nn.Sequential(Real2Complex(), b, Complex2Real(), FixedPermutation(br[perm_inverse]))
def wavelet_haar(n, with_perm=True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the multilevel discrete
    wavelet transform with the Haar wavelet.

    Parameters:
        n: size of the discrete wavelet transform. Must be a power of 2.
        with_perm: whether to return both the butterfly and the wavelet rearrangement permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    # 2x2 building blocks, each shaped (1, 2, 2): the orthonormal Haar butterfly and the identity.
    haar_block = torch.tensor([[1, 1], [1, -1]], dtype=torch.float).reshape(1, 2, 2) / math.sqrt(2)
    eye_block = torch.eye(2).reshape(1, 2, 2)
    levels = []
    for level in range(1, log_n + 1):
        block_size = 2 ** level
        # Within each group of block_size // 2 twiddles, only the first is Haar; the rest pass through.
        group = torch.cat((haar_block, eye_block.expand(block_size // 2 - 1, 2, 2)))
        levels.append(group.repeat(n // block_size, 1, 1))
    twiddle = torch.stack(levels, dim=0)[None, None]
    b = Butterfly(n, n, bias=False, increasing_stride=True, init=twiddle)
    if not with_perm:
        return b
    return nn.Sequential(b, FixedPermutation(wavelet_permutation(n, pytorch_format=True)))
def flip_increasing_stride(butterfly: Butterfly) -> nn.Module:
    """Convert a Butterfly with increasing_stride=True/False into one with
    increasing_stride=False/True, sandwiched between two bit-reversal permutations.
    Follows the proof of Lemma G.4.

    The input butterfly is not modified; a deep copy is flipped instead.
    """
    assert butterfly.bias is None
    assert butterfly.in_size == 1 << butterfly.log_n
    assert butterfly.out_size == 1 << butterfly.log_n
    size = butterfly.in_size
    flipped = copy.deepcopy(butterfly)
    flipped.increasing_stride = not butterfly.increasing_stride
    full_reversal = bitreversal_permutation(size, pytorch_format=True)
    half_reversal = bitreversal_permutation(size // 2, pytorch_format=True)
    with torch.no_grad():
        # Reorder the twiddles along dim 3 by the half-size bit-reversal.
        reordered = flipped.twiddle[:, :, :, half_reversal]
        flipped.twiddle.copy_(reordered)
    return nn.Sequential(FixedPermutation(full_reversal), flipped,
                         FixedPermutation(full_reversal))
def permutation_kronecker(perm1: FixedPermutation,
                          perm2: FixedPermutation) -> FixedPermutation:
    """Combine two permutations of size n1 and n2 into their Kronecker product
    of size n1 * n2.
    """
    n1 = perm1.permutation.shape[-1]
    n2 = perm2.permutation.shape[-1]
    # Index grid of shape (n2, n1); permuting it yields the combined index map.
    grid = torch.arange(n1 * n2, device=perm1.permutation.device).reshape(n2, n1)
    # NOTE(review): assumes FixedPermutation acts on the last dimension, so perm1 permutes
    # the size-n1 axis and, after the transpose, perm2 permutes the size-n2 axis.
    permuted = perm1(grid).t()
    permuted = perm2(permuted).t()
    return FixedPermutation(permuted.reshape(-1))
def ifft2d(n1: int, n2: int, normalized: bool = False, br_first: bool = True,
           with_br_perm: bool = True, flatten=False) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the 2D iFFT.

    Parameters:
        n1: size of the iFFT on the last input dimension. Must be a power of 2.
        n2: size of the iFFT on the second to last input dimension. Must be a power of 2.
        normalized: if True, corresponds to the unitary iFFT (i.e. multiplied by 1/sqrt(n))
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
        flatten: whether to combine the 2 butterflies into 1 with Kronecker product.
    """
    ifft_last = ifft(n1, normalized=normalized, br_first=br_first, with_br_perm=False)
    ifft_second_last = ifft(n2, normalized=normalized, br_first=br_first, with_br_perm=False)
    if flatten:
        core = butterfly_kronecker(ifft_last, ifft_second_last)
    else:
        core = TensorProduct(ifft_last, ifft_second_last)

    pipeline = [core]
    if with_br_perm:
        perm_last = FixedPermutation(bitreversal_permutation(n1, pytorch_format=True))
        perm_second_last = FixedPermutation(bitreversal_permutation(n2, pytorch_format=True))
        if flatten:
            reversal = permutation_kronecker(perm_last, perm_second_last)
        else:
            reversal = TensorProduct(perm_last, perm_second_last)
        # Bit-reversal goes before the butterfly for br_first, after it otherwise.
        if br_first:
            pipeline.insert(0, reversal)
        else:
            pipeline.append(reversal)

    if flatten:
        # Operate on the last two dims merged into one, then restore the (n2, n1) shape.
        return nn.Sequential(nn.Flatten(start_dim=-2), *pipeline, nn.Unflatten(-1, (n2, n1)))
    if with_br_perm:
        return nn.Sequential(*pipeline)
    return core
def ifft(n, normalized=False, br_first=True, with_br_perm=True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the inverse FFT.

    Parameters:
        n: size of the iFFT. Must be a power of 2.
        normalized: if True, corresponds to unitary iFFT (i.e. multiplied by 1/sqrt(n), not 1/n)
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    per_level = []
    for level in range(1, log_n + 1):
        size = 1 << level
        w = torch.exp(2j * math.pi * torch.arange(0.0, size // 2) / size)
        ones = torch.ones_like(w)
        # One 2x2 block [[1, w], [1, -w]] per twiddle position.
        upper_row = torch.stack((ones, w), dim=-1)
        lower_row = torch.stack((ones, -w), dim=-1)
        block = torch.stack((upper_row, lower_row), dim=-2)
        per_level.append(block.repeat(n // size, 1, 1))
    twiddle = torch.stack(per_level, dim=0)[None, None]
    if not br_first:
        # Take conjugate transpose of the BP decomposition of fft
        twiddle = twiddle.transpose(-1, -2).flip([2])
    # Spread the overall scale over the log_n factors: sqrt(2) each -> 1/sqrt(n), 2 each -> 1/n.
    per_factor_scale = math.sqrt(2) if normalized else 2
    twiddle /= per_factor_scale
    b = Butterfly(n, n, bias=False, complex=True, increasing_stride=br_first, init=twiddle)
    if not with_br_perm:
        return b
    br_perm = FixedPermutation(bitreversal_permutation(n, pytorch_format=True))
    return nn.Sequential(br_perm, b) if br_first else nn.Sequential(b, br_perm)
def fft_unitary(n, br_first=True, with_br_perm=True) -> nn.Module:
    """ Build an nn.Module based on ButterflyUnitary that exactly performs the FFT.
    Since it's unitary, it corresponds to normalized=True.

    Parameters:
        n: size of the FFT. Must be a power of 2.
        br_first: which decomposition of FFT. br_first=True corresponds to decimation-in-time.
                  br_first=False corresponds to decimation-in-frequency.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    per_level = []
    for level in range(1, log_n + 1):
        size = 1 << level
        angle = -2 * math.pi * torch.arange(0.0, size // 2) / size
        half_angle = angle / 2
        phi = torch.ones_like(angle) * math.pi / 4
        alpha = half_angle + math.pi / 2
        psi = -angle / 2 - math.pi / 2
        if br_first:
            chi = half_angle - math.pi / 2
        else:
            # Conjugate transpose of the BP decomposition of ifft works out to this,
            # together with the flip applied after the loop.
            chi = -angle / 2 - math.pi / 2
        # Last dim holds the four angle parameters (phi, alpha, psi, chi) per twiddle.
        params = torch.stack([phi, alpha, psi, chi], dim=-1)
        per_level.append(params.repeat(n // size, 1))
    twiddle = torch.stack(per_level, dim=0)[None, None]
    if not br_first:
        twiddle = twiddle.flip([2])
    b = ButterflyUnitary(n, n, bias=False, increasing_stride=br_first)
    with torch.no_grad():
        b.twiddle.copy_(twiddle)
    if not with_br_perm:
        return b
    br_perm = FixedPermutation(bitreversal_permutation(n, pytorch_format=True))
    return nn.Sequential(br_perm, b) if br_first else nn.Sequential(b, br_perm)
def acdc(diag1: torch.Tensor, diag2: torch.Tensor, dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.

    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946

    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) / (2 * n))
    # Normalize: same scaling that dct(..., normalized=True) applies for types 2 and 3.
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft, postprocess_diag.conj()[br], diag_first=True)
        if not separate_diagonal:
            # Chain onto b1/b2 (not the raw b_fft) so the post-processing diagonal is kept
            # regardless of whether diagonal_butterfly modifies its input in place.
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        b1 = diagonal_butterfly(b_ifft, postprocess_diag.conj(), diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        # The permutations that dct() would place between the iDCT and the DCT are omitted
        # here; diag1 is re-indexed by perm then br to account for that.
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            # Chain onto b2 (not the raw b_fft) so the post-processing diagonal is kept.
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))