def forward(self, input): """ Parameters: input: (stack, ..., size) if real or (stack, ..., size, 2) if complex if not tied_weight: (stack, n_blocks, ..., size) if real or (stack, n_blocks, ..., size, 2) if complex Return: output: (stack, ..., size) if real or (stack, ..., size, 2) if complex if not tied_weight: (stack, n_blocks, ..., size) if real or (stack, n_blocks, ..., size, 2) if complex """ if self.tied_weight: if not self.complex: return (self.ABCD.unsqueeze(1) * input.view(self.stack, -1, 1, 2, self.size // 2)).sum( dim=-2).view(input.shape) else: return complex_mul( self.ABCD.unsqueeze(1), input.view(self.stack, -1, 1, 2, self.size // 2, 2)).sum(dim=-3).view(input.shape) else: if not self.complex: return (self.ABCD.unsqueeze(2) * input.view( self.stack, self.n_blocks, -1, 1, 2, self.size // 2)).sum( dim=-2).view(input.shape) else: return complex_mul( self.ABCD.unsqueeze(2), input.view(self.stack, self.n_blocks, -1, 1, 2, self.size // 2, 2)).sum(dim=-3).view(input.shape)
def forward(self, input): """ Parameters: input: (batch, *, in_size) Return: output: (batch, *, out_size) """ u = input.view(np.prod(input.size()[:-1]), input.size(-1)) batch = u.shape[0] # output = toeplitz_mult(self.G, self.H, input, self.corner) # return output.reshape(batch, self.nstack * self.size) n = self.in_size v = self.H # u_f = torch.rfft(torch.cat((u.flip(1), torch.zeros_like(u)), dim=-1), 1) u_f = torch.rfft( torch.cat((u[:, self.reverse_idx], torch.zeros_like(u)), dim=-1), 1) v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1) uv_f = complex_mul(u_f.unsqueeze(1).unsqueeze(1), v_f) # transpose_out = torch.irfft(uv_f, 1, signal_sizes=(2 * n, ))[..., :n].flip(3) transpose_out = torch.irfft(uv_f, 1, signal_sizes=(2 * n, ))[..., self.reverse_idx] v = self.G w = transpose_out w_f = torch.rfft(torch.cat((w, torch.zeros_like(w)), dim=-1), 1) v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1) wv_sum_f = complex_mul(w_f, v_f).sum(dim=2) output = torch.irfft(wv_sum_f, 1, signal_sizes=(2 * n, ))[..., :n] output = output.reshape(batch, self.nstack * self.in_size)[:, :self.out_size] if self.bias is not None: output = output + self.bias return output.view(*input.size()[:-1], self.out_size)
def test_butterfly_factor_complex_multiply():
    from complex_utils import complex_mul
    n = 1024
    m = int(math.log2(n))
    x = torch.randn((n, 2), requires_grad=True)
    sizes = [n >> i for i in range(m)]
    for size in sizes:
        bf = Block2x2Diag(size, complex=True)
        x = x.view(-1, 2 * bf.ABCD.shape[-2], 2)
        result_slow = (complex_mul(
            bf.ABCD,
            x.view(x.shape[:-2] + (1, 2, size // 2, 2))).sum(dim=-3)).view(x.shape)
        result = butterfly_factor_mult(
            bf.ABCD, x.view(-1, 2, bf.ABCD.shape[-2], 2)).view(x.shape)
        assert torch.allclose(result, result_slow, atol=1e-6)
        grad = torch.randn_like(x)
        d_coef_slow, d_x_slow = torch.autograd.grad(
            result_slow, (bf.ABCD, x), grad, retain_graph=True)
        d_coef, d_x = torch.autograd.grad(
            result, (bf.ABCD, x), grad, retain_graph=True)
        assert torch.allclose(d_coef, d_coef_slow, atol=1e-6)
        assert torch.allclose(d_x, d_x_slow, atol=1e-6)
def test_butterfly_dct():
    from scipy.fftpack import dct
    # DCT matrix for n = 4
    size = 4
    # Need to transpose as dct acts on rows of matrix np.eye, not columns
    DCT = torch.tensor(dct(np.eye(size)).T, dtype=torch.float)
    M0diag = torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]],
                          requires_grad=True)
    M0subdiag = torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True)
    M0superdiag = torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True)
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=M0diag,
                   subdiag=M0subdiag,
                   superdiag=M0superdiag)
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor(
                       [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]],
                       requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                        requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                          requires_grad=True))
    arange_ = np.arange(size)
    dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
    br_perm = bitreversal_permutation(size)
    perm = torch.arange(size)[dct_perm][br_perm]
    arange_ = torch.arange(size, dtype=torch.float)
    diag_real = 2 * torch.cos(-math.pi * arange_ / (2 * size))
    diag_imag = 2 * torch.sin(-math.pi * arange_ / (2 * size))
    diag = torch.stack((torch.diag(diag_real), torch.diag(diag_imag)), dim=-1)
    assert torch.allclose(
        complex_matmul(diag, complex_matmul(M0.matrix(), M1.matrix()))[:, perm, 0],
        DCT)
    D = torch.stack((diag_real, diag_imag), dim=-1)
    DM0 = Butterfly(size,
                    diagonal=2,
                    complex=True,
                    diag=complex_mul(D, M0diag),
                    subdiag=complex_mul(D[2:], M0subdiag),
                    superdiag=complex_mul(D[:2], M0superdiag))
    assert torch.allclose(
        complex_matmul(DM0.matrix(), M1.matrix())[:, perm, 0], DCT)
def forward(self, input): """ Parameters: input: (batch, size) Return: output: (batch, nstack * size) """ batch = input.shape[0] input_f = torch.rfft(input, 1) prod = complex_mul(self.c_f, input_f.unsqueeze(1)) return torch.irfft(prod, 1, signal_sizes=(self.size, )).view( batch, self.nstack * self.size)
def toeplitz_krylov_multiply(v, w, f=0.0):
    r"""Multiply \sum_i Krylov(Z_f, v_i) @ w_i.
    Parameters:
        v: (nstack, rank, n)
        w: (batch_size, nstack, rank, n)
        f: real number
    Returns:
        product: (batch_size, nstack, n)
    """
    _, nstack, rank, n = w.shape
    nstack_, rank_, n_ = v.shape
    assert n == n_, 'w and v must have the same last dimension'
    assert rank == rank_, 'w and v must have the same rank'
    assert nstack == nstack_, 'w and v must have the same nstack'
    if f != 0.0:  # cycle version
        # Compute the n-th roots of f
        mod = abs(f)**(torch.arange(n, dtype=w.dtype, device=w.device) / n)
        if f > 0:
            arg = torch.stack((torch.ones(n, dtype=w.dtype, device=w.device),
                               torch.zeros(n, dtype=w.dtype, device=w.device)),
                              dim=-1)
        else:
            # Find primitive roots of -1
            angles = torch.arange(n, dtype=w.dtype, device=w.device) / n * np.pi
            arg = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
        eta = mod[:, np.newaxis] * arg
        eta_inverse = (1.0 / mod)[:, np.newaxis] * conjugate(arg)
        w_f = torch.fft(eta * w[..., np.newaxis], 1)
        v_f = torch.fft(eta * v[..., np.newaxis], 1)
        wv_sum_f = complex_mul(w_f, v_f).sum(dim=2)
        wv_sum = torch.ifft(wv_sum_f, 1)
        # We only need the real part of complex_mul(eta_inverse, wv_sum)
        return (eta_inverse[..., 0] * wv_sum[..., 0]
                - eta_inverse[..., 1] * wv_sum[..., 1])
    else:
        w_f = torch.rfft(torch.cat((w, torch.zeros_like(w)), dim=-1), 1)
        v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1)
        wv_sum_f = complex_mul(w_f, v_f).sum(dim=2)
        return torch.irfft(wv_sum_f, 1, signal_sizes=(2 * n, ))[..., :n]
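
# Sanity-check sketch (hypothetical, not in the original file): compares the
# f = 0 branch of toeplitz_krylov_multiply against explicitly built Krylov
# matrices K(Z_0, v_i) = [v_i, Z_0 v_i, ..., Z_0^{n-1} v_i].
def _test_toeplitz_krylov_multiply():
    n, nstack, rank, batch = 4, 2, 3, 5
    v = torch.randn(nstack, rank, n)
    w = torch.randn(batch, nstack, rank, n)
    Z = torch.diag(torch.ones(n - 1), diagonal=-1)  # Z_0: lower shift matrix
    expected = torch.zeros(batch, nstack, n)
    for s in range(nstack):
        for r in range(rank):
            cols, col = [], v[s, r]
            for _ in range(n):
                cols.append(col)
                col = Z @ col
            K = torch.stack(cols, dim=1)  # Krylov(Z_0, v[s, r])
            expected[:, s] += w[:, s, r] @ K.t()
    assert torch.allclose(toeplitz_krylov_multiply(v, w), expected, atol=1e-6)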
def toeplitz_krylov_transpose_multiply(v, u, f=0.0):
    """Multiply Krylov(Z_f, v_i)^T @ u.
    Parameters:
        v: (nstack, rank, n)
        u: (batch_size, n)
        f: real number
    Returns:
        product: (batch_size, nstack, rank, n)
    """
    _, n = u.shape
    _, _, n_ = v.shape
    assert n == n_, 'u and v must have the same last dimension'
    if f != 0.0:  # cycle version
        # Compute the n-th roots of f
        mod = abs(f)**(torch.arange(n, dtype=u.dtype, device=u.device) / n)
        if f > 0:
            arg = torch.stack((torch.ones(n, dtype=u.dtype, device=u.device),
                               torch.zeros(n, dtype=u.dtype, device=u.device)),
                              dim=-1)
        else:
            # Find primitive roots of -1
            angles = torch.arange(n, dtype=u.dtype, device=u.device) / n * np.pi
            arg = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
        eta = mod[:, np.newaxis] * arg
        eta_inverse = (1.0 / mod)[:, np.newaxis] * conjugate(arg)
        u_f = torch.ifft(eta_inverse * u[..., np.newaxis], 1)
        v_f = torch.fft(eta * v.unsqueeze(-1), 1)
        uv_f = complex_mul(u_f.unsqueeze(1).unsqueeze(1), v_f)
        uv = torch.fft(uv_f, 1)
        # We only need the real part of complex_mul(eta, uv)
        return eta[..., 0] * uv[..., 0] - eta[..., 1] * uv[..., 1]
    else:
        u_f = torch.rfft(torch.cat((u.flip(1), torch.zeros_like(u)), dim=-1), 1)
        v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1)
        uv_f = complex_mul(u_f.unsqueeze(1).unsqueeze(1), v_f)
        return torch.irfft(uv_f, 1, signal_sizes=(2 * n, ))[..., :n].flip(3)
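
# Usage sketch (hypothetical, not in the original file): a Toeplitz-like
# product composes the two kernels above, mirroring the forward pass of the
# Toeplitz-like layer earlier in this section (shown here for f = 0).
def toeplitz_like_mult(G, H, u):
    # G, H: (nstack, rank, n); u: (batch, n) -> (batch, nstack, n)
    return toeplitz_krylov_multiply(G, toeplitz_krylov_transpose_multiply(H, u))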
def circulant_fft(c, x):
    """Multiply the circulant matrix with first column c by x, via FFT."""
    n = x.shape[-1]
    x_f = torch.rfft(x, 1)
    c_f = torch.rfft(c, 1)
    prod = complex_mul(c_f, x_f)
    return torch.irfft(prod, 1, signal_sizes=(n, ))
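
# Sanity-check sketch (not part of the original module): compares circulant_fft
# against an explicit circulant matrix. Assumes the legacy torch.rfft/torch.irfft
# API (PyTorch <= 1.7) used throughout this file.
def _test_circulant_fft():
    n = 8
    c = torch.randn(n)
    x = torch.randn(3, n)
    # Dense circulant matrix with first column c: C[i, j] = c[(i - j) mod n]
    C = torch.stack([torch.roll(c, shifts=j) for j in range(n)], dim=1)
    assert torch.allclose(circulant_fft(c, x), x @ C.t(), atol=1e-6)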