def forward(ctx, twiddle, input, increasing_stride=True):
    """Multiply input by a butterfly matrix, saving what backward needs.

    Parameters:
        twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
        input: (batch_size, n) if real or (batch_size, n, 2) if complex
        increasing_stride: whether to multiply with increasing stride (e.g. 2, 4, ..., n/2)
            or decreasing stride (e.g., n/2, n/4, ..., 2). Note that this only changes
            the order of multiplication, not how twiddle is stored. In other words,
            twiddle[@log_stride] always stores the twiddle for @stride.
    Returns:
        output: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
    """
    output_and_intermediate = butterfly_multiply_intermediate(
        twiddle, input, increasing_stride)
    # Save (twiddle, input) rather than the full intermediate stack: the
    # matching backward unpacks `twiddle, input = ctx.saved_tensors` and
    # RECOMPUTES the intermediates (trading compute for memory). Saving
    # output_and_intermediate here would hand backward the whole
    # (log n + 1, ...) stack misnamed as `input`.
    ctx.save_for_backward(twiddle, input)
    # increasing_stride is a plain bool, not a tensor, so it is stashed on ctx
    # directly instead of going through save_for_backward.
    ctx._increasing_stride = increasing_stride
    return output_and_intermediate[-1]
def backward(ctx, grad):
    """Gradient of the butterfly multiply.

    The intermediate activations are not stored during forward; they are
    recomputed here from the saved (twiddle, input) pair to save memory.

    Parameters:
        grad: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
    Returns:
        d_twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
        d_input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
        None: placeholder for the non-tensor `increasing_stride` argument
    """
    saved_twiddle, saved_input = ctx.saved_tensors
    stride_increasing = ctx._increasing_stride
    # Recompute output + intermediates (the trailing True flag matches the
    # original call signature) instead of having forward store them.
    intermediates = butterfly_multiply_intermediate(
        saved_twiddle, saved_input, stride_increasing, True)
    grad_twiddle, grad_input = butterfly_multiply_intermediate_backward(
        grad, saved_twiddle, intermediates, stride_increasing)
    # Autograd requires one gradient per forward argument: 3 in total.
    return grad_twiddle, grad_input, None
def test_butterfly_factor_intermediate_complex_cuda(self):
    # Check that the fused complex CUDA kernel butterfly_multiply_intermediate
    # (and its hand-written backward) agrees with a reference computed by
    # applying the butterfly factors one at a time via butterfly_factor_mult.
    batch_size = 10
    n = 4096
    B = Block2x2DiagProduct(n, complex=True).to('cuda')
    # Trailing dim of size 2 holds (real, imag) parts; grads are needed for
    # the backward comparison below.
    input_ = torch.randn(batch_size, n, 2, device='cuda', requires_grad=True)
    # unsqueeze(0) adds the nstack dimension expected by the fused kernel.
    twiddle = twiddle_list_concat(B).unsqueeze(0)
    output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
    # Reference forward: start from the input and apply each factor in turn.
    # B.factors is iterated reversed ([::-1]) — presumably smallest stride
    # first, to match the kernel's multiplication order; verify against the
    # kernel's docs if this test starts failing.
    output = [input_]
    for factor in B.factors[::-1]:
        output.append(
            butterfly_factor_mult(
                factor.ABCD,
                output[-1].view(-1, 2, factor.size // 2, 2)).view(output[-1].shape))
    # Stack so output[k] is the state after k factors, matching the kernel's
    # (log n + 1, ...) intermediate layout; squeeze(2) drops nstack=1.
    output = torch.stack(output)
    self.assertTrue(
        torch.allclose(output_intermediate.squeeze(2),
                       output,
                       rtol=self.rtol,
                       atol=self.atol),
        (output_intermediate.squeeze(2) - output).abs().max().item())
    # Backward comparison: same upstream gradient through both paths.
    grad = torch.randn_like(output[-1])
    # unsqueeze(1) adds the nstack dimension to the gradient.
    d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
        grad.unsqueeze(1), twiddle, output_intermediate)
    # retain_graph=True so autograd state survives for any later use.
    output[-1].backward(grad, retain_graph=True)
    d_input = input_.grad
    # Reassemble per-factor twiddle grads into the concatenated layout
    # produced by twiddle_list_concat (same [::-1] factor order as above).
    d_twiddle = torch.cat([
        factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]
    ])
    self.assertTrue(
        torch.allclose(d_input_intermediate,
                       d_input,
                       rtol=self.rtol,
                       atol=self.atol),
        (d_input_intermediate - d_input).abs().max().item())
    self.assertTrue(
        torch.allclose(d_twiddle_intermediate,
                       d_twiddle,
                       rtol=self.rtol,
                       atol=self.atol),
        (d_twiddle_intermediate - d_twiddle).abs().max().item())
def forward(ctx, twiddle, input):
    """Butterfly multiply, keeping every intermediate for the backward pass.

    Parameters:
        twiddle: butterfly twiddle factors
        input: batch of inputs to multiply
    Returns:
        the final product, i.e. the last entry of the stack returned by
        butterfly_multiply_intermediate
    """
    intermediates = butterfly_multiply_intermediate(twiddle, input)
    # Unlike the memory-saving variant, this one stores the full stack of
    # intermediates so backward does not need to recompute them.
    ctx.save_for_backward(twiddle, intermediates)
    return intermediates[-1]