Example #1
 def backward(ctx, grad):
     """
     Parameters:
         grad: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
         twiddle (saved in forward): (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
         output_and_intermediate (saved in forward): the output plus the intermediate
             values needed for backward, (log n + 1, batch_size, nstack, n) if real
             or (log n + 1, batch_size, nstack, n, 2) if complex
     Returns:
         d_twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
         d_input: (batch_size, n) if real or (batch_size, n, 2) if complex
     """
     twiddle, output_and_intermediate = ctx.saved_tensors
     increasing_stride = ctx._increasing_stride
     d_twiddle, d_input = butterfly_multiply_intermediate_backward(
         grad, twiddle, output_and_intermediate, increasing_stride)
     # Autograd expects one gradient per forward input (twiddle, input,
     # increasing_stride); the stride flag is not differentiable, hence None.
     return d_twiddle, d_input, None
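This backward is written in the static-method style of a torch.autograd.Function. As a minimal sketch of the matching forward, assuming the class name ButterflyMultIntermediate and that butterfly_multiply_intermediate (called the same way in Example #2 below) also accepts the increasing_stride flag and returns the full stack of intermediates:

 import torch

 class ButterflyMultIntermediate(torch.autograd.Function):  # assumed name

     @staticmethod
     def forward(ctx, twiddle, input_, increasing_stride=True):
         # Assumed: the library kernel returns the output together with every
         # intermediate level, shaped (log n + 1, batch_size, nstack, n[, 2]).
         output_and_intermediate = butterfly_multiply_intermediate(
             twiddle, input_, increasing_stride)
         ctx.save_for_backward(twiddle, output_and_intermediate)
         ctx._increasing_stride = increasing_stride
         return output_and_intermediate[-1]  # final level is the output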
Example #2
 def test_butterfly_factor_intermediate_complex_cuda(self):
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True).to('cuda')
     input_ = torch.randn(batch_size,
                          n,
                          2,
                          device='cuda',
                          requires_grad=True)
     # Concatenate the per-factor twiddle tensors and add an nstack dimension.
     twiddle = twiddle_list_concat(B).unsqueeze(0)
     output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
     # Reference path: apply the factors one at a time, recording every
     # intermediate so each level of the fused kernel can be checked.
     output = [input_]
     for factor in B.factors[::-1]:
         output.append(
             butterfly_factor_mult(
                 factor.ABCD, output[-1].view(-1, 2, factor.size // 2,
                                              2)).view(output[-1].shape))
     output = torch.stack(output)
     self.assertTrue(
         torch.allclose(output_intermediate.squeeze(2),
                        output,
                        rtol=self.rtol,
                        atol=self.atol),
         (output_intermediate.squeeze(2) - output).abs().max().item())
     # Check the fused backward against autograd through the reference path.
     grad = torch.randn_like(output[-1])
     d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
         grad.unsqueeze(1), twiddle, output_intermediate)
     output[-1].backward(grad, retain_graph=True)
     d_input = input_.grad
     # Repack the per-factor gradients into the concatenated twiddle layout.
     d_twiddle = torch.cat([
         factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]
     ])
     self.assertTrue(
         torch.allclose(d_input_intermediate,
                        d_input,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_input_intermediate - d_input).abs().max().item())
     self.assertTrue(
         torch.allclose(d_twiddle_intermediate,
                        d_twiddle,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_twiddle_intermediate - d_twiddle).abs().max().item())
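The shapes here follow the docstring in Example #1: with n = 4096 there are log2(n) = 12 butterfly levels, so output_intermediate stacks 13 entries, and complex tensors carry a trailing dimension of 2 for the real and imaginary parts. A quick, library-free bookkeeping check:

 import math

 batch_size, n, nstack = 10, 4096, 1
 levels = int(math.log2(n)) + 1  # the input plus one entry per butterfly level
 # Expected shape of output_intermediate for complex inputs:
 print((levels, batch_size, nstack, n, 2))  # -> (13, 10, 1, 4096, 2)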
Example #3
 def backward(ctx, grad):
     twiddle, output = ctx.saved_tensors
     d_twiddle, d_input = butterfly_multiply_intermediate_backward(
         grad, twiddle, output)
     # One gradient per differentiable forward input: twiddle and input.
     return d_twiddle, d_input
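If this forward/backward pair is exposed as a Function (call it ButterflyMult here; the class name and its .apply wiring are assumptions, not names from the examples above), a standard way to validate such a backward is torch.autograd.gradcheck in double precision on a small problem:

 import torch

 n, batch_size, nstack = 8, 3, 1
 twiddle = torch.randn(nstack, n - 1, 2, 2, dtype=torch.float64, requires_grad=True)
 input_ = torch.randn(batch_size, n, dtype=torch.float64, requires_grad=True)
 # gradcheck compares the analytic gradients from the backward above
 # against numerical finite differences.
 torch.autograd.gradcheck(ButterflyMult.apply, (twiddle, input_))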