def test_butterfly_factor_inplace_cuda(self):
    batch_size = 10
    n = 4096
    B = Block2x2DiagProduct(n, ortho_init=True).to('cuda')
    input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
    twiddle = twiddle_list_concat(B)
    output_inplace = butterfly_factor_mult_inplace(twiddle, input_)
    output = B(input_)
    self.assertTrue(
        torch.allclose(output_inplace, output, rtol=self.rtol, atol=self.atol),
        (output_inplace - output).abs().max().item())
    grad = torch.randn_like(output)
    d_twiddle_inplace, d_input_inplace = torch.autograd.grad(
        output_inplace, (twiddle, input_), grad, retain_graph=True)
    output.backward(grad, retain_graph=True)
    d_input = input_.grad
    d_twiddle = torch.cat(
        [factor.ABCD.grad.permute(2, 0, 1) for factor in B.factors[::-1]])
    self.assertTrue(
        torch.allclose(d_input_inplace, d_input, rtol=self.rtol, atol=self.atol),
        (d_input_inplace - d_input).abs().max().item())
    self.assertTrue(
        torch.allclose(d_twiddle_inplace, d_twiddle, rtol=self.rtol, atol=self.atol),
        (d_twiddle_inplace - d_twiddle).abs().max().item())
def _setup(self, config):
    self.target_matrix = torch.tensor(config['target_matrix'], dtype=torch.float)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    size = self.target_matrix.shape[0]
    torch.manual_seed(config['seed'])
    # Transposing the permutation product won't capture the FFT, since we'd need
    # permutations that interleave the first half and second half (the inverse
    # of the permutation that separates the evens and the odds).
    # However, using the permutation product with increasing size will work,
    # since it can represent bit reversal, which is its own inverse.
    self.model = nn.Sequential(
        Block2x2DiagProduct(size=size, complex=True, decreasing_size=False),
        BlockPermProduct(size=size, complex=True, share_logit=False, increasing_size=True),
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = real_to_complex(torch.eye(size))
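# Sketch (not repo code): the comment above relies on bit reversal being its
# own inverse. A minimal standalone check, using a hypothetical bitreversal_perm
# helper in place of the repo's butterfly.utils.bitreversal_permutation:
import numpy as np

def bitreversal_perm(n):
    """Indices 0..n-1 with their log2(n)-bit binary representations reversed."""
    log_n = n.bit_length() - 1
    return np.array([int(format(i, f'0{log_n}b')[::-1], 2) for i in range(n)])

perm = bitreversal_perm(8)                        # [0, 4, 2, 6, 1, 5, 3, 7]
assert np.array_equal(perm[perm], np.arange(8))   # involution: applying it twice is the identity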
def test_butterfly_factor_complex_inplace_cpu(self):
    batch_size = 10
    n = 4096
    B = Block2x2DiagProduct(n, complex=True)
    input_ = torch.randn(batch_size, n, 2, requires_grad=True)
    twiddle = twiddle_list_concat(B)
    output_inplace = butterfly_factor_mult_inplace(twiddle, input_)
    output = B(input_)
    self.assertTrue(
        torch.allclose(output_inplace, output, rtol=self.rtol, atol=self.atol),
        (output_inplace - output).abs().max().item())
def _setup(self, config):
    self.target_matrix = torch.tensor(config['target_matrix'], dtype=torch.float)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    size = self.target_matrix.shape[0]
    torch.manual_seed(config['seed'])
    self.model = Block2x2DiagProduct(size=size, complex=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    # Feed the identity with bit-reversed columns, matching the DIT FFT structure
    self.input = real_to_complex(
        torch.eye(size)[:, torch.tensor(bitreversal_permutation(size))])
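# Sketch (not repo code): the bit-reversed identity input mirrors the classic
# radix-2 decimation-in-time FFT, which is exactly "butterfly stages applied to
# bit-reversed input", i.e. the structure Block2x2DiagProduct is being fit to:
import numpy as np

def dit_fft(x):
    n = len(x)
    log_n = n.bit_length() - 1
    perm = [int(format(i, f'0{log_n}b')[::-1], 2) for i in range(n)]
    a = x[perm].astype(complex)          # stage 0: bit-reversal permutation
    m = 2
    while m <= n:                        # one stage of 2x2 butterflies of width m
        w = np.exp(-2j * np.pi * np.arange(m // 2) / m)   # twiddle factors
        blocks = a.reshape(-1, m)
        even = blocks[:, :m // 2].copy()
        odd = blocks[:, m // 2:] * w
        blocks[:, :m // 2] = even + odd
        blocks[:, m // 2:] = even - odd
        m *= 2
    return a

v = np.random.randn(8)
assert np.allclose(dit_fft(v), np.fft.fft(v))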
def test_butterfly_factor_intermediate_complex_cuda(self):
    batch_size = 10
    n = 4096
    B = Block2x2DiagProduct(n, complex=True).to('cuda')
    input_ = torch.randn(batch_size, n, 2, device='cuda', requires_grad=True)
    twiddle = twiddle_list_concat(B).unsqueeze(0)
    output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
    output = [input_]
    for factor in B.factors[::-1]:
        output.append(
            butterfly_factor_mult(
                factor.ABCD,
                output[-1].view(-1, 2, factor.size // 2, 2)).view(output[-1].shape))
    output = torch.stack(output)
    self.assertTrue(
        torch.allclose(output_intermediate.squeeze(2), output, rtol=self.rtol, atol=self.atol),
        (output_intermediate.squeeze(2) - output).abs().max().item())
    grad = torch.randn_like(output[-1])
    d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
        grad.unsqueeze(1), twiddle, output_intermediate)
    output[-1].backward(grad, retain_graph=True)
    d_input = input_.grad
    d_twiddle = torch.cat(
        [factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]])
    self.assertTrue(
        torch.allclose(d_input_intermediate, d_input, rtol=self.rtol, atol=self.atol),
        (d_input_intermediate - d_input).abs().max().item())
    self.assertTrue(
        torch.allclose(d_twiddle_intermediate, d_twiddle, rtol=self.rtol, atol=self.atol),
        (d_twiddle_intermediate - d_twiddle).abs().max().item())
def profile_butterfly_mult():
    nsteps = 10
    batch_size = 100
    n = 1024
    B = Block2x2DiagProduct(n)
    x = torch.randn(batch_size, n)
    # B(x)
    optimizer = optim.Adam(B.parameters(), lr=0.01)
    for _ in range(nsteps):
        optimizer.zero_grad()
        # output = B(x)
        # loss = nn.functional.mse_loss(output, x)
        output = x
        for factor in B.factors[::-1]:
            output = butterfly_factor_mult(
                factor.ABCD, output.view(-1, 2, factor.size // 2)).view(x.shape)
        # output = output.reshape(x.shape)
        loss = output.sum()
        loss.backward()
        optimizer.step()
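# Sketch: one possible way to run the loop above under PyTorch's autograd
# profiler. The sort_by key is an assumption and depends on the installed
# PyTorch version; check torch.autograd.profiler docs for your version.
import torch

with torch.autograd.profiler.profile() as prof:
    profile_butterfly_mult()
print(prof.key_averages().table(sort_by='cpu_time_total'))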
def test_butterfly_factor_complex_cpu(self):
    batch_size = 10
    n = 4096
    B = Block2x2DiagProduct(n, complex=True)
    input_ = torch.randn(batch_size, n, 2, requires_grad=True)
    output = input_
    for factor in B.factors[::-1]:
        prev = output
        output = butterfly_factor_mult(
            factor.ABCD, output.view(-1, 2, factor.size // 2, 2)).view(prev.shape)
        output_slow = complex_mul(
            factor.ABCD, prev.view(-1, 1, 2, factor.size // 2, 2)).sum(dim=-3).view(prev.shape)
        self.assertTrue(
            torch.allclose(output, output_slow, rtol=self.rtol, atol=self.atol),
            (output - output_slow).abs().max().item())
        grad = torch.randn_like(output)
        d_twiddle, d_input = torch.autograd.grad(
            output, (factor.ABCD, prev), grad, retain_graph=True)
        d_twiddle_slow, d_input_slow = torch.autograd.grad(
            output_slow, (factor.ABCD, prev), grad, retain_graph=True)
        self.assertTrue(
            torch.allclose(d_twiddle, d_twiddle_slow, rtol=self.rtol, atol=self.atol),
            (d_twiddle - d_twiddle_slow).abs().max().item())
        self.assertTrue(
            torch.allclose(d_input, d_input_slow, rtol=self.rtol, atol=self.atol),
            (d_input - d_input_slow).abs().max().item())
def test_butterfly_factor_cuda(self):
    batch_size = 100
    n = 4096  # To test n > MAX_BLOCK_SIZE
    B = Block2x2DiagProduct(n).to('cuda')
    input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
    output = input_
    for factor in B.factors[::-1]:
        prev = output
        output = butterfly_factor_mult(
            factor.ABCD, output.view(-1, 2, factor.size // 2)).view(prev.shape)
        output_slow = (factor.ABCD * prev.view(-1, 1, 2, factor.size // 2)).sum(dim=-2).view(prev.shape)
        self.assertTrue(
            torch.allclose(output, output_slow, rtol=self.rtol, atol=self.atol),
            (output - output_slow).abs().max().item())
        grad = torch.randn_like(output)
        d_twiddle, d_input = torch.autograd.grad(
            output, (factor.ABCD, prev), grad, retain_graph=True)
        d_twiddle_slow, d_input_slow = torch.autograd.grad(
            output_slow, (factor.ABCD, prev), grad, retain_graph=True)
        self.assertTrue(
            torch.allclose(d_twiddle, d_twiddle_slow, rtol=self.rtol, atol=self.atol),
            (factor.size, (d_twiddle - d_twiddle_slow).abs().max().item()))
        self.assertTrue(
            torch.allclose(d_input, d_input_slow, rtol=self.rtol, atol=self.atol),
            (d_input - d_input_slow).abs().max().item())
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension n of the matrix
    Returns:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(
            LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(
            size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose since dct acts on the rows of np.eye, not the columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    elif name == 'permutation':
        np.random.seed(0)
        perm = np.random.permutation(size)
        P = np.eye(size)[perm]
        return P
    elif name.startswith('rank-unnorm'):
        r = int(name[11:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        # M /= math.sqrt(size * r)
        return M
    elif name.startswith('rank'):
        r = int(name[4:])  # 2rn parameters
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        M /= math.sqrt(size * r)
        return M
    elif name.startswith('sparse'):
        s = int(name[6:])  # roughly s*n nonzero parameters
        np.random.seed(0)
        mask = sparse.random(size, size, density=s / size, data_rvs=np.ones)
        M = np.random.randn(size, size) * mask.toarray()
        M /= math.sqrt(s)
        return M
    elif name.startswith('toeplitz'):
        r = int(name[8:])
        G = np.random.randn(size, r) / math.sqrt(size * r)
        H = np.random.randn(size, r) / math.sqrt(size * r)
        M = toeplitz_like(G, H)
        return M
    elif name == 'fastfood':
        n = size
        S = np.random.randn(n)
        G = np.random.randn(n)
        B = np.random.randn(n)
        # P = np.arange(n)
        P = np.random.permutation(n)
        H = hadamard(n)
        # Fastfood factorization: S H G P H B
        F = S[:, np.newaxis] * (H @ (G[:, np.newaxis] * (H * B)[P, :])) / n
        return F
        # x = np.random.randn(batch_size, n)
        # HB = hadamard_transform(B)
        # PHBx = HBx[:, P]
        # HGPHBx = hadamard_transform(G * PHBx)
        # return S * HGPHBx
    elif name == 'butterfly':
        # n (log n + 1) params in the hierarchy
        b = Butterfly(in_size=size, out_size=size, bias=False,
                      tied_weight=False, param='odo', nblocks=0)
        M = b(torch.eye(size))
        return M.cpu().detach().numpy()
    else:
        assert False, 'Target matrix name not recognized or implemented'
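# Sketch (not repo code): a quick sanity check of the (n, n, 2) real-view
# layout returned for complex targets. The 'dft' target, viewed back as
# complex128, should be unitary since LA.dft uses scale='sqrtn'. Assumes
# named_target_matrix and its imports are available.
import numpy as np

M = named_target_matrix('dft', 8)        # shape (8, 8, 2), float64
Mc = M.view('complex128').squeeze(-1)    # back to complex, shape (8, 8)
assert np.allclose(Mc @ Mc.conj().T, np.eye(8))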
from test_factor_multiply import twiddle_list_concat

exps = np.arange(6, 14)
sizes = 1 << exps
batch_size = 256
ntrials = [100000, 100000, 10000, 10000, 10000, 10000, 10000, 10000]
dense_times = np.zeros(exps.size)
fft_times = np.zeros(exps.size)
butterfly_times = np.zeros(exps.size)
for idx_n, (n, ntrial) in enumerate(zip(sizes, ntrials)):
    print(n)
    B = Block2x2DiagProduct(n).to('cuda')
    L = torch.nn.Linear(n, n, bias=False).to('cuda')
    x = torch.randn(batch_size, n, requires_grad=True).to('cuda')
    grad = torch.randn_like(x)
    twiddle = twiddle_list_concat(B)
    # Dense multiply
    output = L(x)  # Do it once to initialize the cuBLAS handle and such
    torch.autograd.grad(output, (L.weight, x), grad)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(ntrial):
        output = L(x)
        torch.autograd.grad(output, (L.weight, x), grad)
    torch.cuda.synchronize()
    end = time.perf_counter()
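# Sketch: the warm-up / synchronize / perf_counter pattern above generalizes
# to any CUDA workload. A small hypothetical helper (cuda_time is not part of
# the repo) that packages the same idiom:
import time
import torch

def cuda_time(fn, ntrial):
    fn()                              # warm-up: initialize cuBLAS handles, caches, etc.
    torch.cuda.synchronize()          # drain pending kernels before starting the clock
    start = time.perf_counter()
    for _ in range(ntrial):
        fn()
    torch.cuda.synchronize()          # wait for all timed kernels to finish
    return (time.perf_counter() - start) / ntrial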
exps = np.arange(6, 14)
sizes = 1 << exps
ntrials = [100000, 100000, 1000, 100, 100, 10, 10, 10]
dense_times = np.zeros(exps.size)
fft_times = np.zeros(exps.size)
scipyfft_times = np.zeros(exps.size)
dct_times = np.zeros(exps.size)
dst_times = np.zeros(exps.size)
bp_times = np.zeros(exps.size)
for idx_n, (n, ntrial) in enumerate(zip(sizes, ntrials)):
    print(n)
    x = np.random.random(n).astype(np.float32)
    B = Block2x2DiagProduct(n)
    P = BlockPermProduct(n)
    B_matrix = B(torch.eye(int(n))).t().contiguous()
    B_matrix_np = B_matrix.detach().numpy()
    ABCDs = Block2x2DiagProduct_to_ABCDs(B)
    perm = P.argmax().detach().numpy().astype(int)
    # Dense multiply
    start = timer()
    [B_matrix_np @ x for _ in range(ntrial)]
    end = timer()
    dense_times[idx_n] = (end - start) / ntrial
    # FFT
    start = timer()
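# Sketch (not repo code): the same CPU measurement can also be done with
# timeit, taking the minimum over repeats to reduce scheduler noise. The
# matrix here is a random stand-in for the dense butterfly matrix above.
import timeit
import numpy as np

n = 1024
B_matrix_np = np.random.randn(n, n).astype(np.float32)
x = np.random.random(n).astype(np.float32)
per_call = min(timeit.repeat(lambda: B_matrix_np @ x, number=1000, repeat=5)) / 1000
print(f'dense multiply: {per_call:.3e} s per call')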
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension n of the matrix
    Returns:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(
            LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(
            size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose since dct acts on the rows of np.eye, not the columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        import torch
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    else:
        assert False, 'Target matrix name not recognized or implemented'