def forward(self, input):
    """Apply this butterfly factor to ``input``.

    Parameters:
        input: (..., size) if real or (..., size, 2) if complex
    Return:
        output: (..., size) if real or (..., size, 2) if complex
    """
    # Fold the leading dims into one batch dim and split the last size
    # dimension into (2, size // 2) halves; complex inputs keep their
    # trailing real/imag axis of length 2.
    if self.complex:
        stacked = input.view(-1, 2, self.size // 2, 2)
    else:
        stacked = input.view(-1, 2, self.size // 2)
    return butterfly_factor_mult(self.ABCD, stacked).view(input.shape)
def test_butterfly_factor_intermediate_complex_cuda(self):
    """Check the fused intermediate multiply (forward and backward) against
    a per-factor reference built from butterfly_factor_mult, on complex
    CUDA inputs."""
    batch_size = 10
    n = 4096
    model = Block2x2DiagProduct(n, complex=True).to('cuda')
    input_ = torch.randn(batch_size, n, 2, device='cuda', requires_grad=True)
    twiddle = twiddle_list_concat(model).unsqueeze(0)
    output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
    # Reference path: apply each factor in turn, keeping every intermediate.
    intermediates = [input_]
    for factor in model.factors[::-1]:
        prev = intermediates[-1]
        nxt = butterfly_factor_mult(
            factor.ABCD,
            prev.view(-1, 2, factor.size // 2, 2)).view(prev.shape)
        intermediates.append(nxt)
    output = torch.stack(intermediates)
    self.assertTrue(
        torch.allclose(output_intermediate.squeeze(2), output,
                       rtol=self.rtol, atol=self.atol),
        (output_intermediate.squeeze(2) - output).abs().max().item())
    grad = torch.randn_like(output[-1])
    d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
        grad.unsqueeze(1), twiddle, output_intermediate)
    output[-1].backward(grad, retain_graph=True)
    d_input = input_.grad
    # Per-factor twiddle gradients, concatenated in application order.
    d_twiddle = torch.cat([factor.ABCD.grad.permute(2, 0, 1, 3)
                           for factor in model.factors[::-1]])
    self.assertTrue(
        torch.allclose(d_input_intermediate, d_input,
                       rtol=self.rtol, atol=self.atol),
        (d_input_intermediate - d_input).abs().max().item())
    self.assertTrue(
        torch.allclose(d_twiddle_intermediate, d_twiddle,
                       rtol=self.rtol, atol=self.atol),
        (d_twiddle_intermediate - d_twiddle).abs().max().item())
def profile_butterfly_mult():
    """Run a few Adam steps over the unrolled butterfly multiply so each
    butterfly_factor_mult call shows up separately when profiling."""
    nsteps = 10
    batch_size = 100
    n = 1024
    model = Block2x2DiagProduct(n)
    x = torch.randn(batch_size, n)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    for _ in range(nsteps):
        optimizer.zero_grad()
        # Apply the factors one at a time instead of calling model(x).
        output = x
        for factor in model.factors[::-1]:
            output = butterfly_factor_mult(
                factor.ABCD,
                output.view(-1, 2, factor.size // 2)).view(x.shape)
        loss = output.sum()
        loss.backward()
        optimizer.step()
def test_butterfly_factor_complex_cpu(self):
    """Compare butterfly_factor_mult against the slow complex_mul reference,
    one factor at a time on CPU: forward values plus gradients w.r.t. both
    the twiddle and the input."""
    batch_size = 10
    n = 4096
    model = Block2x2DiagProduct(n, complex=True)
    input_ = torch.randn(batch_size, n, 2, requires_grad=True)
    output = input_
    for factor in model.factors[::-1]:
        prev = output
        output = butterfly_factor_mult(
            factor.ABCD,
            prev.view(-1, 2, factor.size // 2, 2)).view(prev.shape)
        # Slow reference: broadcasted complex multiply, then sum over the
        # 2x2 block dimension.
        output_slow = (complex_mul(
            factor.ABCD,
            prev.view(-1, 1, 2, factor.size // 2, 2)).sum(dim=-3)).view(prev.shape)
        self.assertTrue(
            torch.allclose(output, output_slow,
                           rtol=self.rtol, atol=self.atol),
            (output - output_slow).abs().max().item())
        grad = torch.randn_like(output)
        d_twiddle, d_input = torch.autograd.grad(
            output, (factor.ABCD, prev), grad, retain_graph=True)
        d_twiddle_slow, d_input_slow = torch.autograd.grad(
            output_slow, (factor.ABCD, prev), grad, retain_graph=True)
        self.assertTrue(
            torch.allclose(d_twiddle, d_twiddle_slow,
                           rtol=self.rtol, atol=self.atol),
            (d_twiddle - d_twiddle_slow).abs().max().item())
        self.assertTrue(
            torch.allclose(d_input, d_input_slow,
                           rtol=self.rtol, atol=self.atol),
            (d_input - d_input_slow).abs().max().item())
def test_butterfly_factor_cuda(self):
    """Compare butterfly_factor_mult against a slow broadcast-and-sum
    reference, one factor at a time on CUDA, for real inputs: forward
    values plus gradients w.r.t. both the twiddle and the input."""
    batch_size = 100
    n = 4096  # To test n > MAX_BLOCK_SIZE
    model = Block2x2DiagProduct(n).to('cuda')
    input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
    output = input_
    for factor in model.factors[::-1]:
        prev = output
        output = butterfly_factor_mult(
            factor.ABCD,
            prev.view(-1, 2, factor.size // 2)).view(prev.shape)
        # Slow reference: broadcasted elementwise multiply, then sum over
        # the 2x2 block dimension.
        output_slow = ((factor.ABCD *
                        prev.view(-1, 1, 2, factor.size // 2)).sum(
                            dim=-2)).view(prev.shape)
        self.assertTrue(
            torch.allclose(output, output_slow,
                           rtol=self.rtol, atol=self.atol),
            (output - output_slow).abs().max().item())
        grad = torch.randn_like(output)
        d_twiddle, d_input = torch.autograd.grad(
            output, (factor.ABCD, prev), grad, retain_graph=True)
        d_twiddle_slow, d_input_slow = torch.autograd.grad(
            output_slow, (factor.ABCD, prev), grad, retain_graph=True)
        # The failure message includes factor.size to pinpoint which
        # factor diverged.
        self.assertTrue(
            torch.allclose(d_twiddle, d_twiddle_slow,
                           rtol=self.rtol, atol=self.atol),
            (factor.size, (d_twiddle - d_twiddle_slow).abs().max().item()))
        self.assertTrue(
            torch.allclose(d_input, d_input_slow,
                           rtol=self.rtol, atol=self.atol),
            (d_input - d_input_slow).abs().max().item())