Example #1
def test_butterfly_complex_inplace_cpu(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, complex=True, ortho_init=True)
    twiddle = b.twiddle
    # Complex inputs carry real/imaginary parts in the last dimension
    input = torch.randn(batch_size, n, 2, requires_grad=True)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    # Forward outputs of the in-place and reference implementations must
    # agree; on failure, report the maximum absolute difference
    self.assertTrue(
        torch.allclose(output_inplace, output_torch,
                       rtol=self.rtol, atol=self.atol),
        (output_inplace - output_torch).abs().max().item())
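
These examples are methods of a unittest.TestCase subclass, so they rely on self.rtol / self.atol and on imports the snippets omit. A minimal harness might look like the sketch below; the module paths and tolerance values are assumptions inferred from the identifiers, not taken from the examples, so adjust them to the actual package layout.

import unittest
import torch

# Assumed module paths; the examples only show the bare identifiers
from butterfly import Butterfly
from butterfly_multiply import butterfly_mult_inplace, butterfly_mult_torch

class ButterflyMultTest(unittest.TestCase):
    rtol = 1e-3  # assumed tolerances, referenced as self.rtol / self.atol
    atol = 1e-5

    # ... paste the test methods shown in these examples here ...

if __name__ == '__main__':
    unittest.main()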
Example #2
def test_butterfly_inplace_cuda(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, ortho_init=True).to('cuda')
    twiddle = b.twiddle
    input = torch.randn(batch_size, n, requires_grad=True,
                        device=twiddle.device)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    # Forward outputs must agree; report max absolute difference on failure
    self.assertTrue(
        torch.allclose(output_inplace, output_torch,
                       rtol=self.rtol, atol=self.atol),
        (output_inplace - output_torch).abs().max().item())
    # Backpropagate the same cotangent through both implementations and
    # compare the resulting gradients
    grad = torch.randn_like(output_torch)
    d_twiddle_inplace, d_input_inplace = torch.autograd.grad(
        output_inplace, (twiddle, input), grad, retain_graph=True)
    d_twiddle_torch, d_input_torch = torch.autograd.grad(
        output_torch, (twiddle, input), grad, retain_graph=True)
    self.assertTrue(
        torch.allclose(d_input_inplace, d_input_torch,
                       rtol=self.rtol, atol=self.atol),
        (d_input_inplace - d_input_torch).abs().max().item())
    # On failure, report the maximum relative error of the twiddle gradient
    # print((d_twiddle_inplace - d_twiddle_torch) / d_twiddle_torch)
    self.assertTrue(
        torch.allclose(d_twiddle_inplace, d_twiddle_torch,
                       rtol=self.rtol, atol=self.atol),
        ((d_twiddle_inplace - d_twiddle_torch) /
         d_twiddle_torch).abs().max().item())
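
The CUDA test above follows a common pattern for validating a custom autograd implementation: compare forward outputs, then backpropagate one shared cotangent through both implementations and compare each gradient. The helper below distills that pattern; it is a hypothetical sketch, and none of its names come from the examples.

import torch

def check_grads(f_test, f_ref, params, x, rtol=1e-3, atol=1e-5):
    # f_test / f_ref map (params, x) -> output; both params and x must
    # have requires_grad=True
    out_test, out_ref = f_test(params, x), f_ref(params, x)
    assert torch.allclose(out_test, out_ref, rtol=rtol, atol=atol)
    # Shared cotangent so both backward passes see the same upstream grad
    grad = torch.randn_like(out_ref)
    g_test = torch.autograd.grad(out_test, (params, x), grad)
    g_ref = torch.autograd.grad(out_ref, (params, x), grad)
    for gt, gr in zip(g_test, g_ref):
        assert torch.allclose(gt, gr, rtol=rtol, atol=atol)

Applied to the example above, f_test would wrap butterfly_mult_inplace(twiddle.squeeze(0), x) and f_ref would wrap butterfly_mult_torch(twiddle, x).squeeze(1).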
Example #3
# End of the preceding 'factors backward' timing block (its start is not shown)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult factors backward: {end - start}s')
# Time forward + backward together for the factors implementation
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_factors(twiddle.squeeze(0), x)
    torch.autograd.grad(output, (twiddle, x), grad, retain_graph=True)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult factors together: {end - start}s')

# Time the in-place implementation: forward only
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_inplace(twiddle.squeeze(0), x)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult in-place forward: {end - start}s')
# Backward only, reusing the last forward pass's graph
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    torch.autograd.grad(output, (twiddle, x), grad, retain_graph=True)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult in-place backward: {end - start}s')
# Forward + backward together
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_inplace(twiddle.squeeze(0), x)
    torch.autograd.grad(output, (twiddle, x), grad)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult in-place together: {end - start}s')
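
The fragment assumes twiddle, x, grad, and nsteps are already defined and the GPU is warm. A setup along the following lines would make it runnable end to end, using the assumed imports from the harness sketch above; the sizes and step count are illustrative guesses, not values from the benchmark itself.

import time
import torch

batch_size, n, nsteps = 256, 4096, 100  # illustrative values
b = Butterfly(n, n, bias=False, ortho_init=True).to('cuda')
twiddle = b.twiddle
x = torch.randn(batch_size, n, requires_grad=True, device='cuda')
output = butterfly_mult_inplace(twiddle.squeeze(0), x)
grad = torch.randn_like(output)

# Warm up so one-time kernel compilation and allocator costs are not timed
for _ in range(5):
    butterfly_mult_inplace(twiddle.squeeze(0), x)
torch.cuda.synchronize()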