def ifft2d_unitary(n1: int, n2: int, br_first: bool = True, with_br_perm: bool = True) -> nn.Module:
    """Build a ButterflyUnitary-based module that computes the 2D inverse FFT exactly.

    This is the unitary variant (equivalent to ``normalized=True``). A flattened
    (Kronecker-product) form is not supported here yet.

    Parameters:
        n1: iFFT length along the last input dimension. Must be a power of 2.
        n2: iFFT length along the second-to-last input dimension. Must be a power of 2.
        br_first: selects the iFFT decomposition. True is decimation-in-frequency
            (the bit-reversal permutation runs before the butterfly); False is
            decimation-in-time (the permutation runs after the butterfly).
        with_br_perm: if True, return the butterfly chained with the bit-reversal
            permutation; if False, return the butterfly alone.
    """
    butterfly = TensorProduct(
        ifft_unitary(n1, br_first=br_first, with_br_perm=False),
        ifft_unitary(n2, br_first=br_first, with_br_perm=False),
    )
    if not with_br_perm:
        return butterfly
    bit_reversal = TensorProduct(
        FixedPermutation(bitreversal_permutation(n1, pytorch_format=True)),
        FixedPermutation(bitreversal_permutation(n2, pytorch_format=True)),
    )
    # The decomposition decides which side of the butterfly the permutation goes on.
    ordered = (bit_reversal, butterfly) if br_first else (butterfly, bit_reversal)
    return nn.Sequential(*ordered)
def ifft2d(n1: int, n2: int, normalized: bool = False, br_first: bool = True, with_br_perm: bool = True, flatten=False) -> nn.Module:
    """Build a Butterfly-based module that computes the 2D inverse FFT exactly.

    Parameters:
        n1: iFFT length along the last input dimension. Must be a power of 2.
        n2: iFFT length along the second-to-last input dimension. Must be a power of 2.
        normalized: if True, use the unitary iFFT (scaled by 1/sqrt(n)).
        br_first: selects the iFFT decomposition. True is decimation-in-frequency
            (bit-reversal before the butterfly); False is decimation-in-time
            (bit-reversal after the butterfly).
        with_br_perm: if True, chain the bit-reversal permutation with the butterfly;
            otherwise return only the butterfly part.
        flatten: if True, merge the two 1D butterflies into a single butterfly via a
            Kronecker product. The returned module then flattens the last two
            dimensions, applies it, and reshapes back to (..., n2, n1).
    """
    factor1 = ifft(n1, normalized=normalized, br_first=br_first, with_br_perm=False)
    factor2 = ifft(n2, normalized=normalized, br_first=br_first, with_br_perm=False)
    if flatten:
        core = butterfly_kronecker(factor1, factor2)
    else:
        core = TensorProduct(factor1, factor2)

    if with_br_perm:
        perm1 = FixedPermutation(bitreversal_permutation(n1, pytorch_format=True))
        perm2 = FixedPermutation(bitreversal_permutation(n2, pytorch_format=True))
        if flatten:
            perm = permutation_kronecker(perm1, perm2)
        else:
            perm = TensorProduct(perm1, perm2)
        stages = [perm, core] if br_first else [core, perm]
    else:
        stages = [core]

    if flatten:
        # The Kronecker-combined modules act on a flat vector of length n2 * n1.
        return nn.Sequential(nn.Flatten(start_dim=-2), *stages, nn.Unflatten(-1, (n2, n1)))
    if not with_br_perm:
        return core
    return nn.Sequential(*stages)
def test_fft2d_init(self):
    """FFT-initialized butterflies must reproduce a circular ``nn.Conv2d``.

    For every odd kernel size in each spatial dimension, the test:

    - builds a bias-free conv with circular padding;
    - zero-pads, flips and rolls its weight to full (n2, n1) size;
    - checks that K2(K1(input) * Kd(weight)) matches the conv output.

    Kd and K1 are FFT-initialized butterflies, K2 is iFFT-initialized, and Kd is
    rescaled by sqrt(n). The product is a per-frequency complex matmul over channels.
    """
    batch_size = 10
    in_channels = 3
    out_channels = 4
    n1, n2 = 16, 32
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(batch_size, in_channels, n2, n1)
    x_complex = real2complex(x)
    # Loop-invariant butterfly configurations for (Kd, K1, K2).
    increasing_strides = [False, False, True]
    inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
    for kernel_size1 in [1, 3, 5, 7]:
        for kernel_size2 in [1, 3, 5, 7]:
            padding1 = (kernel_size1 - 1) // 2
            padding2 = (kernel_size2 - 1) // 2
            conv = nn.Conv2d(in_channels, out_channels, (kernel_size2, kernel_size1),
                             padding=(padding2, padding1), padding_mode='circular',
                             bias=False)
            out_torch = conv(x)
            weight = conv.weight
            # Turn the cross-correlation kernel into a full-size circular-convolution
            # kernel: flip, zero-pad to n, then roll so the kernel centre sits at index 0.
            w = F.pad(weight.flip(dims=(-1, )), (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
            w = F.pad(w.flip(dims=(-2, )), (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
            w_complex = real2complex(w)
            for nblocks in [1, 2, 3]:
                # complex=True is intentional: previously the builtin type `complex`
                # was passed here by accident (it only worked because it is truthy).
                Kd, K1, K2 = [
                    TensorProduct(
                        Butterfly(n1, n1, bias=False, complex=True,
                                  increasing_stride=incstride, init=i, nblocks=nblocks),
                        Butterfly(n2, n2, bias=False, complex=True,
                                  increasing_stride=incstride, init=i, nblocks=nblocks))
                    for incstride, i in zip(increasing_strides, inits)
                ]
                with torch.no_grad():
                    # Undo the 1/sqrt(n) normalization on the kernel transform.
                    Kd.map1 *= math.sqrt(n1)
                    Kd.map2 *= math.sqrt(n2)
                # Move the frequency dims to the front so complex_matmul contracts
                # over channels:
                # (n2, n1, batch, in) @ (n2, n1, in, out) -> (n2, n1, batch, out).
                out = K2(
                    complex_matmul(
                        K1(x_complex).permute(2, 3, 0, 1),
                        Kd(w_complex).permute(2, 3, 1, 0)).permute(2, 3, 0, 1)).real
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol),
                    msg=f'kernel_size=({kernel_size2}, {kernel_size1}), nblocks={nblocks}')
def __init__(self, in_size, in_ch, out_ch, kernel_size, complex=True, init='ortho', nblocks=1, base=2, zero_pad=True):
    """Set up a 2D convolution layer whose transforms are butterfly-factored.

    Builds three TensorProducts of two Butterflies each: Kd, K1 and K2. In each one,
    map1 has size in_size[-1] and map2 has size in_size[-2]. A learnable kernel
    `weight` is also created. How the forward pass combines them is not visible here.
    Presumably Kd transforms the kernel, K1 the input and K2 the output (FFT/FFT/iFFT
    under init='fft') — TODO confirm against forward.

    Parameters:
        in_size: spatial size; an int is expanded to (in_size, in_size).
        in_ch: number of input channels of the kernel.
        out_ch: number of output channels of the kernel.
        kernel_size: kernel size; an int is expanded to a square kernel.
        complex: whether the butterflies are complex. If True, Kd and K1 are prefixed
            with Real2Complex and K2 is suffixed with Complex2Real.
        init: 'ortho', or 'fft' (requires complex=True). 'fft' initializes Kd and K1
            with 'fft_no_br' and K2 with 'ifft_no_br'.
        nblocks: number of butterfly blocks per Butterfly.
        base: 2 or 4. With 4, every Butterfly is converted via to_base4().
        zero_pad: if True (and complex), the circular shift of the kernel is folded
            into Kd's twiddles using the DFT shift theorem.
    """
    super().__init__()
    self.in_size = in_size
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.complex = complex
    assert init in ['ortho', 'fft']
    if init == 'fft':
        assert self.complex, 'fft init requires complex=True'
    self.init = init
    self.nblocks = nblocks
    assert base in [2, 4]
    self.base = base
    self.zero_pad = zero_pad
    if isinstance(self.in_size, int):
        self.in_size = (self.in_size, self.in_size)
    if isinstance(self.kernel_size, int):
        self.kernel_size = (self.kernel_size, self.kernel_size)
    # "Same" padding for odd kernels: (k - 1) // 2 on each spatial dim (height, width).
    self.padding = (self.kernel_size[0] - 1) // 2, (self.kernel_size[1] - 1) // 2
    # Just to use nn.Conv2d's initialization
    # (the kernel is flipped on both spatial dims, turning the cross-correlation
    # kernel into a convolution kernel).
    self.weight = nn.Parameter(
        nn.Conv2d(self.in_ch, self.out_ch, self.kernel_size, padding=self.padding,
                  bias=False).weight.flip([-1, -2]))
    # Stride direction per transform, in the order (Kd, K1, K2).
    increasing_strides = [False, False, True]
    inits = ['ortho'] * 3 if self.init == 'ortho' else [
        'fft_no_br', 'fft_no_br', 'ifft_no_br'
    ]
    self.Kd, self.K1, self.K2 = [
        TensorProduct(
            Butterfly(self.in_size[-1], self.in_size[-1], bias=False, complex=complex,
                      increasing_stride=incstride, init=i, nblocks=nblocks),
            Butterfly(self.in_size[-2], self.in_size[-2], bias=False, complex=complex,
                      increasing_stride=incstride, init=i, nblocks=nblocks))
        for incstride, i in zip(increasing_strides, inits)
    ]
    # Scale the kernel transform by sqrt(n) per dimension. Presumably this undoes the
    # 1/sqrt(n) normalization so the product of transforms matches convolution —
    # TODO confirm.
    with torch.no_grad():
        self.Kd.map1 *= math.sqrt(self.in_size[-1])
        self.Kd.map2 *= math.sqrt(self.in_size[-2])
    if self.zero_pad and self.complex:
        # Instead of zero-padding and calling weight.roll(-self.padding[-1], dims=-1) and
        # weight.roll(-self.padding[-2], dims=-2), we multiply self.Kd by complex exponential
        # instead, using the Shift theorem.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Shift_theorem
        with torch.no_grad():
            n1, n2 = self.Kd.map1.n, self.Kd.map2.n
            device = self.Kd.map1.twiddle.device
            br1 = bitreversal_permutation(n1, pytorch_format=True).to(device)
            br2 = bitreversal_permutation(n2, pytorch_format=True).to(device)
            # exp(+2*pi*i*k*p/n) is the DFT-domain equivalent of rolling by -p.
            # Entries are reordered by bit reversal to match the no-bit-reversal
            # ("_no_br") butterfly output order.
            diagonal1 = torch.exp(1j * 2 * math.pi / n1 * self.padding[-1]
                                  * torch.arange(n1, device=device))[br1]
            diagonal2 = torch.exp(1j * 2 * math.pi / n2 * self.padding[-2]
                                  * torch.arange(n2, device=device))[br2]
            # We multiply the 1st block instead of the last block (only the first block is not
            # the identity if init=fft). This seems to perform a tiny bit better.
            # If init=ortho, this won't correspond exactly to rolling the weight.
            # NOTE(review): indexing presumes twiddle is laid out as
            # (stack, block, level, pair, out_row, in_col). Scaling both columns of row 0/1
            # at the last level would then scale even/odd outputs by diagonal[::2] /
            # diagonal[1::2] — confirm against Butterfly.
            self.Kd.map1.twiddle[:, 0, -1, :, 0, :] *= diagonal1[::2].unsqueeze(-1)
            self.Kd.map1.twiddle[:, 0, -1, :, 1, :] *= diagonal1[1::2].unsqueeze(-1)
            self.Kd.map2.twiddle[:, 0, -1, :, 0, :] *= diagonal2[::2].unsqueeze(-1)
            self.Kd.map2.twiddle[:, 0, -1, :, 1, :] *= diagonal2[1::2].unsqueeze(-1)
    # Base-4 conversion happens after the twiddle edits above, so it must stay here.
    if base == 4:
        self.Kd.map1, self.Kd.map2 = self.Kd.map1.to_base4(), self.Kd.map2.to_base4()
        self.K1.map1, self.K1.map2 = self.K1.map1.to_base4(), self.K1.map2.to_base4()
        self.K2.map1, self.K2.map2 = self.K2.map1.to_base4(), self.K2.map2.to_base4()
    # Real inputs are lifted to complex before Kd/K1; K2's output is mapped back to real.
    if complex:
        self.Kd = nn.Sequential(Real2Complex(), self.Kd)
        self.K1 = nn.Sequential(Real2Complex(), self.K1)
        self.K2 = nn.Sequential(self.K2, Complex2Real())
def __init__(self, in_size, in_ch, out_ch, kernel_size, complex=True, init='random'):
    """Set up a 2D convolution layer whose transforms are dense (complex) linear maps.

    Kd, K1 and K2 are each a TensorProduct of two square linear layers. map1 has size
    in_size[-1] and map2 has size in_size[-2]. The linear layers are ComplexLinear if
    `complex` else nn.Linear, all bias-free. Every weight is tagged with
    `_is_structured = True`.

    Parameters:
        in_size: spatial size; an int becomes (in_size, in_size).
        in_ch: input channels of the kernel.
        out_ch: output channels of the kernel.
        kernel_size: kernel size; an int becomes a square kernel.
        complex: use complex linear maps and wrap Kd/K1 with Real2Complex and K2
            with Complex2Real.
        init: 'random' keeps the default linear-layer init. 'fft' (requires
            complex=True) loads the orthonormal DFT matrix into Kd and K1 and the
            orthonormal inverse DFT matrix into K2.
    The kernel transform Kd is always multiplied by sqrt(n) per dimension.
    """
    super().__init__()
    self.in_size = in_size
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.complex = complex
    assert init in ['random', 'fft']
    if init == 'fft':
        assert self.complex, 'fft init requires complex=True'
    self.init = init
    if isinstance(in_size, int):
        self.in_size = (in_size, in_size)
    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    height_k, width_k = self.kernel_size[0], self.kernel_size[1]
    self.padding = (height_k - 1) // 2, (width_k - 1) // 2
    # Borrow nn.Conv2d's weight initialization; flip both spatial dims of its kernel.
    seed_conv = nn.Conv2d(self.in_ch, self.out_ch, self.kernel_size,
                          padding=self.padding, bias=False)
    self.weight = nn.Parameter(seed_conv.weight.flip([-1, -2]))

    size_last, size_second_last = self.in_size[-1], self.in_size[-2]
    layer_cls = ComplexLinear if complex else nn.Linear
    # Construction order (Kd, K1, K2; map1 before map2) is kept so that random
    # initialization draws from the RNG in the same sequence.
    transforms = []
    for _ in range(3):
        first = layer_cls(size_last, size_last, bias=False)
        second = layer_cls(size_second_last, size_second_last, bias=False)
        transforms.append(TensorProduct(first, second))
    self.Kd, self.K1, self.K2 = transforms

    if init == 'fft':
        forward_mats, inverse_mats = [], []
        for size in (size_last, size_second_last):
            identity = torch.eye(size, dtype=torch.complex64)
            # DFT matrices are symmetric, so no transpose is needed for nn.Linear's x @ W^T.
            forward_mats.append(torch.fft.fft(identity, norm='ortho'))
            inverse_mats.append(torch.fft.ifft(identity, norm='ortho'))
        with torch.no_grad():
            for product, (mat_a, mat_b) in ((self.Kd, forward_mats),
                                            (self.K1, forward_mats),
                                            (self.K2, inverse_mats)):
                product.map1.weight.copy_(mat_a)
                product.map2.weight.copy_(mat_b)

    with torch.no_grad():
        self.Kd.map1.weight *= math.sqrt(size_last)
        self.Kd.map2.weight *= math.sqrt(size_second_last)

    for product in (self.Kd, self.K1, self.K2):
        product.map1.weight._is_structured = True
        product.map2.weight._is_structured = True

    if complex:
        self.Kd = nn.Sequential(Real2Complex(), self.Kd)
        self.K1 = nn.Sequential(Real2Complex(), self.K1)
        self.K2 = nn.Sequential(self.K2, Complex2Real())