def fastfood(diag1: torch.Tensor, diag2: torch.Tensor, diag3: torch.Tensor,
             permutation: torch.Tensor, normalized: bool = False,
             increasing_stride: bool = True, separate_diagonal: bool = True) -> nn.Module:
    """
    Build an nn.Module out of Butterfly pieces that applies the Fastfood transform
        x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1,
    with H the Hadamard matrix and P a permutation matrix.
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,)
        diag3: (n,)
        permutation: (n,)
        normalized: if True, corresponds to the orthogonal Hadamard transform
            (i.e. multiplied by 1/sqrt(n))
        increasing_stride: whether the first Butterfly in the sequence has increasing stride.
            The second Hadamard factor is built with the opposite setting.
        separate_diagonal: if True, the three diagonals are kept as standalone Diagonal modules;
            if False, they are combined into the two Butterfly factors.
    Returns:
        An nn.Sequential with 6 modules when separate_diagonal is True, 3 otherwise.
    """
    n, = diag1.shape
    expected_shape = (n, )
    assert diag2.shape == diag3.shape == permutation.shape == expected_shape

    first_hadamard = hadamard(n, normalized, increasing_stride)
    second_hadamard = hadamard(n, normalized, not increasing_stride)

    if separate_diagonal:
        # Modules are listed in application order: the input meets diag1 first and diag3 last.
        layers = [
            Diagonal(diagonal_init=diag1),
            first_hadamard,
            FixedPermutation(permutation),
            Diagonal(diagonal_init=diag2),
            second_hadamard,
            Diagonal(diagonal_init=diag3),
        ]
        return nn.Sequential(*layers)

    # Fused variant: diag1 goes in front of the first Hadamard; diag2 goes in front of the
    # second Hadamard and diag3 behind it.
    first_hadamard = diagonal_butterfly(first_hadamard, diag1, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag2, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag3, diag_first=False)
    return nn.Sequential(first_hadamard, FixedPermutation(permutation), second_hadamard)
def hadamard_diagonal(diagonals: torch.Tensor, normalized: bool = False,
                      increasing_stride: bool = True,
                      separate_diagonal: bool = True) -> nn.Module:
    """
    Build an nn.Module out of Butterfly pieces that multiplies by H D H D ... H D,
    where H is the Hadamard matrix and each D is a diagonal matrix.
    Parameters:
        diagonals: (k, n), where k is the number of diagonal matrices and n is the dimension
            of the Hadamard transform. Row 0 is the diagonal applied to the input first.
        normalized: if True, corresponds to the orthogonal Hadamard transform
            (i.e. multiplied by 1/sqrt(n))
        increasing_stride: stride setting of the first Hadamard factor; every following
            factor flips it, so the factors alternate.
        separate_diagonal: if True, returns an nn.Sequential alternating Diagonal and Hadamard
            modules; if False, each diagonal is combined into its Hadamard factor and the
            factors are folded together with butterfly_product.
    """
    # Unpacking into exactly two names also rejects inputs that are not 2-dimensional.
    _, n = diagonals.shape

    def hadamard_at(position: int):
        # Odd positions flip the requested stride setting.
        return hadamard(n, normalized, increasing_stride != (position % 2 == 1))

    rows = diagonals.unbind()

    if separate_diagonal:
        layers = []
        for position, row in enumerate(rows):
            layers += [Diagonal(diagonal_init=row), hadamard_at(position)]
        return nn.Sequential(*layers)

    fused = [
        diagonal_butterfly(hadamard_at(position), row, diag_first=True)
        for position, row in enumerate(rows)
    ]
    return reduce(butterfly_product, fused)
def acdc(diag1: torch.Tensor, diag2: torch.Tensor, dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """
    Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if True, diag1 and diag2 are kept as standalone Diagonal modules;
            if False, the diagonal is combined into the Butterfly part.
    Returns:
        An nn.Sequential of permutations, real/complex conversions, Butterfly factors and
        (when separate_diagonal is True) Diagonal modules.
    Raises:
        AssertionError: if diag2 does not have shape (n,) or n is not a power of 2.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) / (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft, postprocess_diag.conj()[br], diag_first=True)
        if not separate_diagonal:
            # Chain from b1 (not the bare b_fft) so the DCT post-processing diagonal folded in
            # above is kept.
            # NOTE(review): the result is unchanged if diagonal_butterfly modifies its argument
            # in place and returns it; this no longer relies on that aliasing.
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        b1 = diagonal_butterfly(b_ifft, postprocess_diag.conj(), diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            # Chain from b2 (not the bare b_fft), for the same reason as in the branch above.
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """
    Build an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication, as iFFT @ Diag @ FFT with Diag holding the FFT of the first column.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
            of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
            if False, the diagonal is combined into the Butterfly part.
    Returns:
        The module described above. When col is real, it is additionally wrapped between
        Real2Complex() and Complex2Real().
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    # Remember the input dtype kind now: col itself is converted to complex further down.
    input_is_complex = col.is_complex()
    n = col.shape[0]
    device = col.device
    log_n = int(math.ceil(math.log2(n)))
    pow2 = 1 << log_n
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == pow2 else 2 * pow2

    forward = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(device)
    forward.in_size = n
    inverse = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(device)
    inverse.out_size = n

    if n < n_extended:
        # Zeros in the middle, then a second copy of col, for a total length of n_extended.
        zero_padded = F.pad(col, (0, 2 * (pow2 - n)))
        col = torch.cat((zero_padded, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    eigenvalues = torch.fft.fft(col, norm=None)
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        head, tail = eigenvalues[:1], eigenvalues[1:]
        eigenvalues = torch.cat((head, tail.flip([0])))
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(device)
    # Reorder the eigenvalues by the bit-reversal permutation (indexing along the last dim).
    diag = index_last_dim(eigenvalues, br_perm)

    if separate_diagonal:
        if input_is_complex:
            return nn.Sequential(forward, Diagonal(diagonal_init=diag), inverse)
        return nn.Sequential(Real2Complex(), forward, Diagonal(diagonal_init=diag), inverse,
                             Complex2Real())

    # Combine the diagonal with the last twiddle factor of the forward FFT butterfly.
    with torch.no_grad():
        forward = diagonal_butterfly(forward, diag, diag_first=False, inplace=True)
    # Combine the forward and inverse factors into one Butterfly (with nblocks=2).
    combined = butterfly_product(forward, inverse)
    combined.in_size = n
    combined.out_size = n
    if input_is_complex:
        return combined
    return nn.Sequential(Real2Complex(), combined, Complex2Real())