def test_butterfly_complex_inplace_cpu(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, complex=True, ortho_init=True)
    twiddle = b.twiddle
    input = torch.randn(batch_size, n, 2, requires_grad=True)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    self.assertTrue(torch.allclose(output_inplace, output_torch,
                                   rtol=self.rtol, atol=self.atol),
                    (output_inplace - output_torch).abs().max().item())
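# The complex test above stores each complex entry as a trailing (real, imag)
# pair, hence the input shape (batch_size, n, 2). If a comparison against
# PyTorch's native complex dtype is ever wanted, that pair layout round-trips
# through torch.view_as_complex / torch.view_as_real (available in PyTorch
# >= 1.6; this pair-based code predates native complex support, so the helper
# below is a sketch, not part of the original suite):
def _complex_pair_roundtrip_sketch(pairs):
    # pairs: real tensor of shape (..., 2), last dim holding (real, imag).
    z = torch.view_as_complex(pairs.contiguous())  # shape (...,), complex dtype
    back = torch.view_as_real(z)                   # shape (..., 2) again
    assert torch.equal(back, pairs)
    return z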
def test_butterfly_inplace_cuda(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, ortho_init=True).to('cuda')
    twiddle = b.twiddle
    input = torch.randn(batch_size, n, requires_grad=True, device=twiddle.device)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    # Forward outputs must match the pure-PyTorch reference.
    self.assertTrue(torch.allclose(output_inplace, output_torch,
                                   rtol=self.rtol, atol=self.atol),
                    (output_inplace - output_torch).abs().max().item())
    # Gradients w.r.t. both twiddle and input must match as well.
    grad = torch.randn_like(output_torch)
    d_twiddle_inplace, d_input_inplace = torch.autograd.grad(
        output_inplace, (twiddle, input), grad, retain_graph=True)
    d_twiddle_torch, d_input_torch = torch.autograd.grad(
        output_torch, (twiddle, input), grad, retain_graph=True)
    self.assertTrue(torch.allclose(d_input_inplace, d_input_torch,
                                   rtol=self.rtol, atol=self.atol),
                    (d_input_inplace - d_input_torch).abs().max().item())
    # print((d_twiddle_inplace - d_twiddle_torch) / d_twiddle_torch)
    self.assertTrue(torch.allclose(d_twiddle_inplace, d_twiddle_torch,
                                   rtol=self.rtol, atol=self.atol),
                    ((d_twiddle_inplace - d_twiddle_torch) / d_twiddle_torch).abs().max().item())
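# The checks above compare gradients against butterfly_mult_torch at fixed
# tolerances. torch.autograd.gradcheck offers a complementary finite-difference
# check of the backward pass. A minimal sketch, assuming butterfly_mult_inplace
# accepts float64 inputs (the CUDA kernel may well be float32-only, in which
# case this applies only to the CPU path); not part of the original suite:
def _gradcheck_inplace_sketch(twiddle, input):
    # gradcheck compares analytic gradients to finite differences and needs
    # float64 inputs for its default tolerances to be meaningful.
    t64 = twiddle.detach().squeeze(0).double().requires_grad_()
    x64 = input.detach().double().requires_grad_()
    return torch.autograd.gradcheck(butterfly_mult_inplace, (t64, x64),
                                    eps=1e-6, atol=1e-4)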
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult factors backward: {end - start}s')
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_factors(twiddle.squeeze(0), x)
    torch.autograd.grad(output, (twiddle, x), grad, retain_graph=True)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult factors together: {end - start}s')
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_inplace(twiddle.squeeze(0), x)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult in-place forward: {end - start}s')
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    torch.autograd.grad(output, (twiddle, x), grad, retain_graph=True)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult in-place backward: {end - start}s')
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_inplace(twiddle.squeeze(0), x)
    torch.autograd.grad(output, (twiddle, x), grad)
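# The synchronize + perf_counter pattern above is one way to time asynchronous
# CUDA launches; CUDA events measure elapsed time on-device instead. A sketch
# of the forward measurement using events, reusing this benchmark's twiddle,
# x, and nsteps (a sketch, not part of the original benchmark):
def _time_inplace_forward_with_events(twiddle, x, nsteps):
    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    for _ in range(nsteps):
        output = butterfly_mult_inplace(twiddle.squeeze(0), x)
    end_ev.record()
    torch.cuda.synchronize()  # wait until both events have completed
    # elapsed_time returns milliseconds; convert to seconds to match the prints above.
    print(f'Butterfly mult in-place forward (events): {start_ev.elapsed_time(end_ev) / 1000}s')
    return output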