def test_ifft2d(self):
    batch_size = 10
    n1 = 32
    n2 = 16
    input = torch.randn(batch_size, n2, n1, dtype=torch.complex64)
    for normalized in [False, True]:
        out_torch = view_as_complex(
            torch.ifft(view_as_real(input), signal_ndim=2, normalized=normalized))
        # Just to show how ifft2d is exactly 2 iffts on each dimension
        input_f = view_as_complex(
            torch.ifft(view_as_real(input), signal_ndim=1, normalized=normalized))
        out_fft = view_as_complex(
            torch.ifft(view_as_real(input_f.transpose(-1, -2)), signal_ndim=1,
                       normalized=normalized)).transpose(-1, -2)
        self.assertTrue(torch.allclose(out_torch, out_fft, self.rtol, self.atol))
        for br_first in [True, False]:
            for flatten in [False, True]:
                b = torch_butterfly.special.ifft2d(n1, n2, normalized=normalized,
                                                   br_first=br_first, flatten=flatten)
                out = b(input)
                self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def __init__(self, in_size, out_size, matrix_batch=1, bias=True, complex=False,
             increasing_stride=True, init='randn', nblocks=1):
    nn.Module.__init__(self)
    self.in_size = in_size
    log_n = int(math.ceil(math.log2(in_size)))
    self.log_n = log_n
    size = self.in_size_extended = 1 << log_n  # Will zero-pad input if in_size is not a power of 2
    self.out_size = out_size
    self.matrix_batch = matrix_batch
    self.nstacks = int(math.ceil(out_size / self.in_size_extended))
    self.complex = complex
    self.increasing_stride = increasing_stride
    assert nblocks >= 1
    self.nblocks = nblocks
    dtype = torch.get_default_dtype() if not self.complex else real_dtype_to_complex[
        torch.get_default_dtype()]
    twiddle_shape = (self.matrix_batch * self.nstacks, nblocks, log_n, size // 2, 2, 2)
    assert init in ['randn', 'ortho', 'identity']
    self.init = init
    self.twiddle = nn.Parameter(torch.empty(twiddle_shape, dtype=dtype))
    if bias:
        self.bias = nn.Parameter(torch.empty(self.matrix_batch, out_size, dtype=dtype))
    else:
        self.register_parameter('bias', None)
    self.twiddle._is_structured = True  # Flag to avoid weight decay
    # Pytorch 1.6 doesn't support torch.Tensor.add_(other, alpha) for complex tensors yet.
    # This is used in optimizers such as SGD.
    # So we have to store the parameters as real.
    if self.complex:
        self.twiddle = nn.Parameter(view_as_real(self.twiddle))
        if self.bias is not None:
            self.bias = nn.Parameter(view_as_real(self.bias))
    self.reset_parameters()
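# Hedged worked example (assuming this __init__ belongs to torch_butterfly.Butterfly): for a
# non-power-of-2 input size, the derived sizes and twiddle shape follow from the code above.
#   b = Butterfly(in_size=13, out_size=27, nblocks=2, bias=False)
#   b.log_n            == 4                   # ceil(log2(13))
#   b.in_size_extended == 16                  # input is zero-padded from 13 to 16
#   b.nstacks          == 2                   # ceil(27 / 16)
#   b.twiddle.shape    == (2, 2, 4, 8, 2, 2)  # (matrix_batch * nstacks, nblocks, log_n, size // 2, 2, 2)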
def test_circulant(self):
    batch_size = 10
    n = 13
    for complex in [False, True]:
        dtype = torch.float32 if not complex else torch.complex64
        col = torch.randn(n, dtype=dtype)
        C = la.circulant(col.numpy())
        input = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(input.detach().numpy() @ C.T)
        out_np = torch.tensor(np.fft.ifft(np.fft.fft(input.numpy()) * np.fft.fft(col.numpy())),
                              dtype=dtype)
        self.assertTrue(torch.allclose(out_torch, out_np, self.rtol, self.atol))
        # Just to show how to implement circulant multiply with FFT
        if complex:
            input_f = view_as_complex(torch.fft(view_as_real(input), signal_ndim=1))
            col_f = view_as_complex(torch.fft(view_as_real(col), signal_ndim=1))
            prod_f = complex_mul(input_f, col_f)
            out_fft = view_as_complex(torch.ifft(view_as_real(prod_f), signal_ndim=1))
            self.assertTrue(torch.allclose(out_torch, out_fft, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(col, transposed=False,
                                                  separate_diagonal=separate_diagonal)
            out = b(input)
            self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))

        row = torch.randn(n, dtype=dtype)
        C = la.circulant(row.numpy()).T
        input = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(input.detach().numpy() @ C.T)
        # row is the reverse of col, except the 0-th element stays put
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        row_f = np.fft.fft(row.numpy())
        row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
        out_np = torch.tensor(np.fft.ifft(np.fft.fft(input.numpy()) * row_f_reversed),
                              dtype=dtype)
        self.assertTrue(torch.allclose(out_torch, out_np, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(row, transposed=True,
                                                  separate_diagonal=separate_diagonal)
            out = b(input)
            self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def test_conv1d_circular_multichannel(self):
    batch_size = 10
    in_channels = 3
    out_channels = 4
    for n in [13, 16]:
        for kernel_size in [1, 3, 5, 7]:
            padding = (kernel_size - 1) // 2
            conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding,
                             padding_mode='circular', bias=False)
            weight = conv.weight
            input = torch.randn(batch_size, in_channels, n)
            out_torch = conv(input)
            # Just to show how to implement conv1d with FFT
            input_f = view_as_complex(torch.rfft(input, signal_ndim=1))
            col = F.pad(weight.flip(dims=(-1,)), (0, n - kernel_size)).roll(-padding, dims=-1)
            col_f = view_as_complex(torch.rfft(col, signal_ndim=1))
            prod_f = complex_mul(input_f.unsqueeze(1), col_f).sum(dim=2)
            out_fft = torch.irfft(view_as_real(prod_f), signal_ndim=1, signal_sizes=(n,))
            self.assertTrue(torch.allclose(out_torch, out_fft, self.rtol, self.atol))
            b = torch_butterfly.special.conv1d_circular_multichannel(n, weight)
            out = b(input)
            self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def conv1d_circular_multichannel(n, weight) -> nn.Module:
    """
    Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be
    2k + 1, and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). Kernel size must
                be odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding = (kernel_size - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding, dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = view_as_complex(torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.
    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f), b_ifft,
                             Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
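# Hedged usage sketch: _example_conv1d_usage is a hypothetical helper (not part of the library)
# mirroring test_conv1d_circular_multichannel; it checks that the Butterfly-based module matches
# a circular-padded nn.Conv1d on a random input.
def _example_conv1d_usage():
    conv = nn.Conv1d(3, 4, kernel_size=5, padding=2, padding_mode='circular', bias=False)
    x = torch.randn(10, 3, 13)
    b = conv1d_circular_multichannel(13, conv.weight)
    return torch.allclose(b(x), conv(x), atol=1e-5)  # expected: True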
def __init__(self, diagonal_init):
    """
    Parameters:
        diagonal_init: (out_channels, in_channels, size)
    """
    super().__init__()
    self.diagonal = nn.Parameter(diagonal_init.detach().clone())
    self.complex = self.diagonal.is_complex()
    if self.complex:
        self.diagonal = nn.Parameter(view_as_real(self.diagonal))
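# Hedged sketch of the forward pass this module presumably computes, based on the comment
# "(input_f.unsqueeze(1) * col_f).sum(dim=2)" in conv1d_circular_multichannel: for an input of
# shape (batch, in_channels, size), multiply by the per-(out, in)-channel diagonals and sum over
# in-channels, giving (batch, out_channels, size). The library's actual forward may differ.
def forward(self, input):
    diagonal = self.diagonal if not self.complex else view_as_complex(self.diagonal)
    return complex_mul(input.unsqueeze(1), diagonal).sum(dim=2)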
def test_ifft_unitary(self):
    batch_size = 10
    n = 16
    input = torch.randn(batch_size, n, dtype=torch.complex64)
    normalized = True
    out_torch = view_as_complex(
        torch.ifft(view_as_real(input), signal_ndim=1, normalized=normalized))
    for br_first in [True, False]:
        b = torch_butterfly.special.ifft_unitary(n, br_first=br_first)
        out = b(input)
        self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def test_conv2d_circular_multichannel(self):
    batch_size = 10
    in_channels = 3
    out_channels = 4
    for n1 in [13, 16]:
        for n2 in [27, 32]:
            # flatten is only supported for powers of 2 for now
            if n1 == 1 << int(math.log2(n1)) and n2 == 1 << int(math.log2(n2)):
                flatten_cases = [False, True]
            else:
                flatten_cases = [False]
            for kernel_size1 in [1, 3, 5, 7]:
                for kernel_size2 in [1, 3, 5, 7]:
                    padding1 = (kernel_size1 - 1) // 2
                    padding2 = (kernel_size2 - 1) // 2
                    conv = nn.Conv2d(in_channels, out_channels, (kernel_size2, kernel_size1),
                                     padding=(padding2, padding1), padding_mode='circular',
                                     bias=False)
                    weight = conv.weight
                    input = torch.randn(batch_size, in_channels, n2, n1)
                    out_torch = conv(input)
                    # Just to show how to implement conv2d with FFT
                    input_f = view_as_complex(torch.rfft(input, signal_ndim=2))
                    col = F.pad(weight.flip(dims=(-1,)),
                                (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
                    col = F.pad(col.flip(dims=(-2,)),
                                (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
                    col_f = view_as_complex(torch.rfft(col, signal_ndim=2))
                    prod_f = complex_mul(input_f.unsqueeze(1), col_f).sum(dim=2)
                    out_fft = torch.irfft(view_as_real(prod_f), signal_ndim=2,
                                          signal_sizes=(n2, n1))
                    self.assertTrue(torch.allclose(out_torch, out_fft, self.rtol, self.atol))
                    for flatten in flatten_cases:
                        b = torch_butterfly.special.conv2d_circular_multichannel(
                            n1, n2, weight, flatten=flatten)
                        out = b(input)
                        self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def diagonal_butterfly(butterfly: Butterfly,
                       diagonal: torch.Tensor,
                       diag_first: bool,
                       inplace: bool = True) -> Butterfly:
    """
    Combine a Butterfly and a diagonal into another Butterfly.
    Only support nstacks==1 for now.
    Parameters:
        butterfly: Butterfly(in_size, out_size)
        diagonal: size (in_size,) if diag_first, else (out_size,). Should be of type complex
                  if butterfly.complex == True.
        diag_first: If True, the map is input -> diagonal -> butterfly.
                    If False, the map is input -> butterfly -> diagonal.
        inplace: whether to modify the input Butterfly
    """
    assert butterfly.nstacks == 1
    assert butterfly.bias is None
    twiddle = (butterfly.twiddle.clone() if not butterfly.complex
               else view_as_complex(butterfly.twiddle).clone())
    n = 1 << twiddle.shape[2]
    if diagonal.shape[-1] < n:
        diagonal = F.pad(diagonal, (0, n - diagonal.shape[-1]), value=1)
    if diag_first:
        if butterfly.increasing_stride:
            twiddle[:, 0, 0, :, :, 0] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[1::2].unsqueeze(-1)
        else:
            n = diagonal.shape[-1]
            twiddle[:, 0, 0, :, :, 0] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[n // 2:].unsqueeze(-1)
    else:
        # Whether the last block is increasing or decreasing stride
        increasing_stride = butterfly.increasing_stride != ((butterfly.nblocks - 1) % 2 == 1)
        if increasing_stride:
            n = diagonal.shape[-1]
            twiddle[:, -1, -1, :, 0, :] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[n // 2:].unsqueeze(-1)
        else:
            twiddle[:, -1, -1, :, 0, :] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[1::2].unsqueeze(-1)
    out_butterfly = butterfly if inplace else copy.deepcopy(butterfly)
    with torch.no_grad():
        out_butterfly.twiddle.copy_(twiddle if not butterfly.complex else view_as_real(twiddle))
    return out_butterfly
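# Hedged usage sketch: _example_diagonal_butterfly is a hypothetical helper (not in the library)
# illustrating the contract stated in the docstring: with diag_first=False, the combined
# Butterfly should act as input -> butterfly -> diagonal. It assumes fft() from
# torch_butterfly.special (as used by circulant() below) is importable here.
def _example_diagonal_butterfly():
    b = fft(8, normalized=True, br_first=False, with_br_perm=False)
    d = torch.randn(8, dtype=torch.complex64)
    bd = diagonal_butterfly(b, d, diag_first=False, inplace=False)
    x = torch.randn(3, 8, dtype=torch.complex64)
    # Elementwise multiply by d after the butterfly should match the combined module.
    return torch.allclose(bd(x), b(x) * d)  # expected: True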
def __init__(self, size=None, complex=False, diagonal_init=None):
    """Multiply by diagonal matrix
    Parameters:
        size: int
        diagonal_init: (n, )
    """
    super().__init__()
    if diagonal_init is not None:
        self.size = diagonal_init.shape
        self.diagonal = nn.Parameter(diagonal_init.detach().clone())
        self.complex = self.diagonal.is_complex()
    else:
        assert size is not None
        self.size = size
        dtype = torch.get_default_dtype() if not complex else real_dtype_to_complex[
            torch.get_default_dtype()]
        self.diagonal = nn.Parameter(torch.randn(size, dtype=dtype))
        self.complex = complex
    if self.complex:
        self.diagonal = nn.Parameter(view_as_real(self.diagonal))
def conv2d_circular_multichannel(n1: int, n2: int, weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """
    Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size must be
    2k + 1, and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
                Kernel sizes must be odd, and smaller than n1/n2. Padding is assumed to be
                (kernel_size - 1) // 2.
        flatten: whether to internally flatten the last 2 dimensions of the input. Only supports
                 n1 and n2 being powers of 2.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    assert kernel_size1 < n1 and kernel_size2 < n2
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1, n_extended2, normalized=True, br_first=False,
                  with_br_perm=False, flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and Unflatten2D
    b_ifft = ifft2d(n_extended1, n_extended2, normalized=True, br_first=True,
                    with_br_perm=False, flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and Unflatten2D
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = view_as_complex(torch.fft(view_as_real(col), signal_ndim=2, normalized=False))
    br_perm1 = bitreversal_permutation(n_extended1, pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2, pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error "shape mismatch: indexing tensors could not be
    # broadcast together"
    col_f = col_f[..., br_perm2, :][..., br_perm1]
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2], col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft, Unflatten2D(n1),
                                 Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft, DiagonalMultiplySum(col_f),
                                 b_ifft, Unflatten2D(n1))
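# Hedged usage sketch (hypothetical helper, mirroring test_conv2d_circular_multichannel):
# flatten=True is only valid when both spatial sizes are powers of 2, per the assert above.
def _example_conv2d_usage():
    conv = nn.Conv2d(3, 4, (3, 5), padding=(1, 2), padding_mode='circular', bias=False)
    x = torch.randn(10, 3, 32, 16)  # n2=32, n1=16: both powers of 2, so flatten is allowed
    b = conv2d_circular_multichannel(16, 32, conv.weight, flatten=True)
    return torch.allclose(b(x), conv(x), atol=1e-5)  # expected: True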
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """
    Construct an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first
                    *row* of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           If False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of
    # the circulant matrix.
    col_f = view_as_complex(torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        col_f = torch.cat((col_f[:1], col_f[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(col.device)
    diag = col_f[..., br_perm]
    if separate_diagonal:
        if not complex:
            return nn.Sequential(Real2Complex(), b_fft, Diagonal(diagonal_init=diag), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
    else:
        # Combine the diagonal with the last twiddle factor of b_fft
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft, diag, diag_first=False, inplace=True)
        # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
        b = butterfly_product(b_fft, b_ifft)
        b.in_size = n
        b.out_size = n
        return b if complex else nn.Sequential(Real2Complex(), b, Complex2Real())
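# Hedged usage sketch (hypothetical helper, mirroring test_circulant): the returned module should
# multiply by the circulant matrix whose first column is col. Assumes scipy.linalg is imported
# as la, as in the tests.
def _example_circulant_usage():
    col = torch.randn(13)
    C = torch.tensor(la.circulant(col.numpy()), dtype=torch.float32)
    x = torch.randn(10, 13)
    b = circulant(col, separate_diagonal=False)
    return torch.allclose(b(x), x @ C.t(), atol=1e-5)  # expected: True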