def forward(ctx, twiddle, input, increasing_stride=True):
    """Apply a stack of butterfly factors to ``input``, keeping every
    intermediate result so the backward pass can reuse them.

    Parameters:
        twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
        input: (batch_size, n) if real or (batch_size, n, 2) if complex
        increasing_stride: whether to multiply with increasing stride (e.g. 2, 4, ..., n/2) or
            decreasing stride (e.g., n/2, n/4, ..., 2).
            Note that this only changes the order of multiplication, not how twiddle is stored.
            In other words, twiddle[@log_stride] always stores the twiddle for @stride.
    Returns:
        output: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
    """
    intermediates = butterfly_multiply_intermediate(twiddle, input,
                                                    increasing_stride)
    # Stash the twiddle factors and all per-stage results for backward.
    ctx.save_for_backward(twiddle, intermediates)
    ctx._increasing_stride = increasing_stride
    # The final stage is the actual output of the multiplication.
    return intermediates[-1]
# --- Example #2 ---
 def backward(ctx, grad):
     """Backward pass for the butterfly multiply.

     Parameters:
         grad: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
         twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
         output + intermediate values for backward: (log n + 1, batch_size, nstack, n) if real or (log n + 1, batch_size, nstack, n, 2) if complex
     Return:
         d_twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
         d_input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
     """
     twiddle, input = ctx.saved_tensors
     increasing_stride = ctx._increasing_stride
     # Forward saved only the input; recompute the per-stage intermediates
     # here instead of storing them, trading compute for memory.
     intermediates = butterfly_multiply_intermediate(twiddle, input,
                                                     increasing_stride, True)
     d_twiddle, d_input = butterfly_multiply_intermediate_backward(
         grad, twiddle, intermediates, increasing_stride)
     # One gradient per forward argument; increasing_stride is not a tensor,
     # so its slot is None.
     return d_twiddle, d_input, None
# --- Example #3 ---
 def test_butterfly_factor_intermediate_complex_cuda(self):
     """Compare butterfly_multiply_intermediate (forward and backward) against
     a factor-by-factor reference built from Block2x2DiagProduct, complex CUDA case."""
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True).to('cuda')
     input_ = torch.randn(batch_size,
                          n,
                          2,
                          device='cuda',
                          requires_grad=True)
     twiddle = twiddle_list_concat(B).unsqueeze(0)
     output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
     # Reference forward: apply each factor in turn (smallest stride first),
     # recording every intermediate stage.
     stages = [input_]
     for factor in B.factors[::-1]:
         prev = stages[-1]
         reshaped = prev.view(-1, 2, factor.size // 2, 2)
         stages.append(butterfly_factor_mult(factor.ABCD,
                                             reshaped).view(prev.shape))
     output = torch.stack(stages)
     forward_diff = (output_intermediate.squeeze(2) - output).abs().max().item()
     self.assertTrue(
         torch.allclose(output_intermediate.squeeze(2),
                        output,
                        rtol=self.rtol,
                        atol=self.atol), forward_diff)
     # Backward: compare the fused backward against autograd through the
     # factor-by-factor reference.
     grad = torch.randn_like(output[-1])
     d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
         grad.unsqueeze(1), twiddle, output_intermediate)
     output[-1].backward(grad, retain_graph=True)
     d_input = input_.grad
     d_twiddle = torch.cat([
         factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]
     ])
     d_input_diff = (d_input_intermediate - d_input).abs().max().item()
     self.assertTrue(
         torch.allclose(d_input_intermediate,
                        d_input,
                        rtol=self.rtol,
                        atol=self.atol), d_input_diff)
     d_twiddle_diff = (d_twiddle_intermediate - d_twiddle).abs().max().item()
     self.assertTrue(
         torch.allclose(d_twiddle_intermediate,
                        d_twiddle,
                        rtol=self.rtol,
                        atol=self.atol), d_twiddle_diff)
# --- Example #4 ---
 def forward(ctx, twiddle, input):
     """Butterfly multiply; saves all intermediate stages for backward and
     returns only the final stage."""
     intermediates = butterfly_multiply_intermediate(twiddle, input)
     ctx.save_for_backward(twiddle, intermediates)
     return intermediates[-1]