def test_fft2d_init(self):
    batch_size = 10
    in_channels = 3
    out_channels = 4
    n1, n2 = 16, 32
    input = torch.randn(batch_size, in_channels, n2, n1)
    for kernel_size1 in [1, 3, 5, 7]:
        for kernel_size2 in [1, 3, 5, 7]:
            padding1 = (kernel_size1 - 1) // 2
            padding2 = (kernel_size2 - 1) // 2
            conv = nn.Conv2d(in_channels, out_channels, (kernel_size2, kernel_size1),
                             padding=(padding2, padding1), padding_mode='circular', bias=False)
            out_torch = conv(input)
            weight = conv.weight
            w = F.pad(weight.flip(dims=(-1,)), (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
            w = F.pad(w.flip(dims=(-2,)), (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
            increasing_strides = [False, False, True]
            inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
            for nblocks in [1, 2, 3]:
                # The FFT initializations require complex butterflies.
                Kd, K1, K2 = [
                    TensorProduct(
                        Butterfly(n1, n1, bias=False, complex=True,
                                  increasing_stride=incstride, init=i, nblocks=nblocks),
                        Butterfly(n2, n2, bias=False, complex=True,
                                  increasing_stride=incstride, init=i, nblocks=nblocks)
                    )
                    for incstride, i in zip(increasing_strides, inits)
                ]
                with torch.no_grad():
                    # Scale Kd so that it matches the unnormalized FFT of the kernel.
                    Kd.map1 *= math.sqrt(n1)
                    Kd.map2 *= math.sqrt(n2)
                out = K2(complex_matmul(
                    K1(real2complex(input)).permute(2, 3, 0, 1),
                    Kd(real2complex(w)).permute(2, 3, 1, 0)
                ).permute(2, 3, 0, 1)).real
                self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def perm2butterfly_slow(v: Union[np.ndarray, torch.Tensor],
                        complex: bool = False,
                        increasing_stride: bool = False) -> Butterfly:
    """Convert a permutation to a Butterfly that performs the same permutation.
    This implementation is slower but follows the proofs in Appendix G more closely.
    Parameters:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    if n < 1 << log_n:  # Pad permutation to the next power-of-2 size
        v = np.concatenate([v, np.arange(n, 1 << log_n)])
    if increasing_stride:  # Follow proof of Lemma G.6
        br = bitreversal_permutation(1 << log_n)
        b = perm2butterfly_slow(br[v[br]], complex=complex, increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation((1 << log_n) // 2, pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b
    # modular_balance expects right-multiplication format, so we convert the format of v.
    Rinv_perms, L_vec = modular_balance(invert(v))
    L_perms = list(reversed(modular_balanced_to_butterfly_factor(L_vec)))
    R_perms = [perm_vec_to_mat(invert(p), left=True) for p in reversed(Rinv_perms)]
    # Stored in increasing_stride=True twiddle format.
    # Need to take the transpose because the matrices are in right-multiplication format.
    L_twiddle = torch.stack([matrix_to_butterfly_factor(l.T, log_k=i + 1, pytorch_format=True)
                             for i, l in enumerate(L_perms)])
    # Stored in increasing_stride=False twiddle format, so we need to flip the order.
    R_twiddle = torch.stack([matrix_to_butterfly_factor(r, log_k=i + 1, pytorch_format=True)
                             for i, r in enumerate(R_perms)]).flip([0])
    twiddle = torch.stack([R_twiddle, L_twiddle]).unsqueeze(0)
    b = Butterfly(n, n, bias=False, complex=complex, increasing_stride=False,
                  init=twiddle if not complex else real2complex(twiddle), nblocks=2)
    return b
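
# A minimal usage sketch for perm2butterfly_slow (added for illustration; the helper name
# _demo_perm2butterfly_slow is hypothetical and not part of the original module). It assumes
# the module-level imports above (torch) and that Butterfly acts on the last dimension.
def _demo_perm2butterfly_slow():
    torch.manual_seed(0)
    n = 8
    v = torch.randperm(n)  # permutation in left-multiplication format
    b = perm2butterfly_slow(v, complex=False, increasing_stride=False)
    x = torch.randn(2, n)
    # Applying the Butterfly should match direct indexing: b(x) == x[:, v].
    assert torch.allclose(b(x), x[:, v])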
def conv1d_circular_multichannel(n: int, weight: torch.Tensor) -> nn.Module:
    """Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size
    must be 2k + 1, and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). The kernel size
            must be odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n, 'Kernel size must be smaller than the input size'
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding = (kernel_size - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding, dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = torch.fft.fft(col, norm=None)
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as a matrix multiply, but PyTorch 1.6 doesn't yet support complex
    # matrix multiply.
    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f), b_ifft,
                             Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
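
# A minimal correctness sketch for conv1d_circular_multichannel (added for illustration; the
# helper name _demo_conv1d_circular is hypothetical). It assumes the module-level imports above
# (torch, nn) and compares against nn.Conv1d with circular padding and no bias.
def _demo_conv1d_circular():
    torch.manual_seed(0)
    n, in_channels, out_channels, kernel_size = 13, 3, 4, 5  # non-power-of-2 n
    padding = (kernel_size - 1) // 2
    conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding,
                     padding_mode='circular', bias=False)
    module = conv1d_circular_multichannel(n, conv.weight)
    x = torch.randn(2, in_channels, n)
    assert torch.allclose(module(x), conv(x), rtol=1e-4, atol=1e-5)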
def perm2butterfly(v: Union[np.ndarray, torch.Tensor],
                   complex: bool = False,
                   increasing_stride: bool = False) -> Butterfly:
    """Convert a permutation to a Butterfly that performs the same permutation.
    Parameters:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    if n < 1 << log_n:  # Pad permutation to the next power-of-2 size
        v = np.concatenate([v, np.arange(n, 1 << log_n)])
    if increasing_stride:  # Follow proof of Lemma G.6
        br = bitreversal_permutation(1 << log_n)
        b = perm2butterfly(br[v[br]], complex=complex, increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation((1 << log_n) // 2, pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b
    v = v[None]
    twiddle_right_factors, twiddle_left_factors = [], []
    for _ in range(log_n):
        right_factor, left_factor, v = outer_twiddle_factors(v)
        twiddle_right_factors.append(right_factor)
        twiddle_left_factors.append(left_factor)
    b = Butterfly(n, n, bias=False, complex=complex, increasing_stride=False, nblocks=2)
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        twiddle = torch.stack([torch.stack(twiddle_right_factors),
                               torch.stack(twiddle_left_factors).flip([0])]).unsqueeze(0)
        b_twiddle.copy_(twiddle if not complex else real2complex(twiddle))
    return b
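
# A minimal usage sketch for perm2butterfly (added for illustration; the helper name
# _demo_perm2butterfly is hypothetical). It exercises the non-power-of-2 padding branch and
# both settings of increasing_stride (Lemmas G.3 and G.6).
def _demo_perm2butterfly():
    torch.manual_seed(0)
    n = 12  # non-power-of-2 size: the permutation is padded internally to 16
    v = torch.randperm(n)
    for increasing_stride in [False, True]:
        b = perm2butterfly(v, complex=False, increasing_stride=increasing_stride)
        x = torch.randn(2, n)
        assert torch.allclose(b(x), x[:, v])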
def conv2d_circular_multichannel(n1: int, n2: int, weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size
    must be 2k + 1, and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
            The kernel sizes must be odd, and smaller than n1/n2.
            Padding is assumed to be (kernel_size - 1) // 2.
        flatten: whether to internally flatten the last 2 dimensions of the input.
            Only supports n1 and n2 being powers of 2.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    assert kernel_size1 < n1 and kernel_size2 < n2
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1, n_extended2, normalized=True, br_first=False,
                  with_br_perm=False, flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and nn.Unflatten
    b_ifft = ifft2d(n_extended1, n_extended2, normalized=True, br_first=True,
                    with_br_perm=False, flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and nn.Unflatten
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = torch.fft.fftn(col, dim=(-1, -2), norm=None)
    br_perm1 = bitreversal_permutation(n_extended1, pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2, pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error with "shape mismatch: indexing tensors could
    # not be broadcast together", so we index the two dimensions separately.
    # col_f = col_f[..., br_perm2, :][..., br_perm1]
    col_f = torch.view_as_complex(
        torch.view_as_real(col_f)[..., br_perm2, :, :][..., br_perm1, :])
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2], col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as a complex matrix multiply as well.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 nn.Unflatten(-1, (n2, n1)), Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft, DiagonalMultiplySum(col_f),
                                 b_ifft, nn.Unflatten(-1, (n2, n1)))
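
# A minimal correctness sketch for conv2d_circular_multichannel (added for illustration; the
# helper name _demo_conv2d_circular is hypothetical). It compares against nn.Conv2d with
# circular padding and no bias; power-of-2 sizes are used so that flatten=True is also valid.
def _demo_conv2d_circular():
    torch.manual_seed(0)
    n1, n2, in_channels, out_channels, kernel_size = 16, 8, 3, 4, 5
    padding = (kernel_size - 1) // 2
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding,
                     padding_mode='circular', bias=False)
    x = torch.randn(2, in_channels, n2, n1)
    for flatten in [False, True]:
        module = conv2d_circular_multichannel(n1, n2, conv.weight, flatten=flatten)
        assert torch.allclose(module(x), conv(x), rtol=1e-4, atol=1e-5)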
def circulant(col: torch.Tensor, transposed: bool = False,
              separate_diagonal: bool = True) -> nn.Module:
    """Construct an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first
            *row* of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
            If False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = torch.fft.fft(col, norm=None)
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        col_f = torch.cat((col_f[:1], col_f[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(col.device)
    # diag = col_f[..., br_perm]
    diag = index_last_dim(col_f, br_perm)
    if separate_diagonal:
        if not complex:
            return nn.Sequential(Real2Complex(), b_fft, Diagonal(diagonal_init=diag), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
    else:
        # Combine the diagonal with the last twiddle factor of b_fft.
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft, diag, diag_first=False, inplace=True)
        # Combine b_fft and b_ifft into one Butterfly (with nblocks=2).
        b = butterfly_product(b_fft, b_ifft)
        b.in_size = n
        b.out_size = n
        return b if complex else nn.Sequential(Real2Complex(), b, Complex2Real())
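
# A minimal correctness sketch for circulant (added for illustration; the helper name
# _demo_circulant is hypothetical). It builds the dense circulant matrix C with first column
# col (C[i, j] = col[(i - j) mod n]) and checks both the plain and transposed variants.
def _demo_circulant():
    torch.manual_seed(0)
    n = 12  # non-power-of-2 size exercises the padding path
    col = torch.randn(n)
    dense = torch.stack([col.roll(j) for j in range(n)], dim=1)  # column j is col rolled by j
    x = torch.randn(3, n)
    for transposed in [False, True]:
        module = circulant(col, transposed=transposed)
        # module(x) applies C (or C^T) to the last dimension, i.e. x @ C^T (or x @ C).
        target = x @ (dense.t() if not transposed else dense)
        assert torch.allclose(module(x), target, rtol=1e-4, atol=1e-5)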