Ejemplo n.º 1
0
 def forward(self, input):
     """Apply this 2x2-block butterfly factor to the input.

     Parameters:
         input: (..., size) if real or (..., size, 2) if complex
     Return:
         output: (..., size) if real or (..., size, 2) if complex
     """
     shape = input.shape
     # Flatten leading dims and expose the (2, size // 2) butterfly
     # structure that butterfly_factor_mult expects; complex inputs
     # carry a trailing dimension of 2 for (real, imag).
     if self.complex:
         flat = input.view(-1, 2, self.size // 2, 2)
     else:
         flat = input.view(-1, 2, self.size // 2)
     return butterfly_factor_mult(self.ABCD, flat).view(shape)
Ejemplo n.º 2
0
 def test_butterfly_factor_intermediate_complex_cuda(self):
     """Check butterfly_multiply_intermediate (complex, CUDA) against a
     per-factor loop of butterfly_factor_mult: every intermediate output
     and the twiddle/input gradients must agree within rtol/atol."""
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True).to('cuda')
     # Trailing dim of 2 holds (real, imag) components.
     input_ = torch.randn(batch_size,
                          n,
                          2,
                          device='cuda',
                          requires_grad=True)
     # Concatenate all factors' twiddles; unsqueeze(0) adds the leading
     # dim that butterfly_multiply_intermediate is called with here.
     twiddle = twiddle_list_concat(B).unsqueeze(0)
     output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
     # Reference: apply the factors one at a time (in reversed factor
     # order, matching the other tests), keeping every intermediate.
     output = [input_]
     for factor in B.factors[::-1]:
         output.append(
             butterfly_factor_mult(
                 factor.ABCD, output[-1].view(-1, 2, factor.size // 2,
                                              2)).view(output[-1].shape))
     output = torch.stack(output)
     self.assertTrue(
         torch.allclose(output_intermediate.squeeze(2),
                        output,
                        rtol=self.rtol,
                        atol=self.atol),
         (output_intermediate.squeeze(2) - output).abs().max().item())
     # Backward: compare the fused backward kernel against autograd
     # through the per-factor graph, using the same upstream grad.
     grad = torch.randn_like(output[-1])
     d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
         grad.unsqueeze(1), twiddle, output_intermediate)
     output[-1].backward(grad, retain_graph=True)
     d_input = input_.grad
     # Reassemble the per-factor twiddle grads into the concatenated
     # layout produced by twiddle_list_concat (hence the permute).
     d_twiddle = torch.cat([
         factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]
     ])
     self.assertTrue(
         torch.allclose(d_input_intermediate,
                        d_input,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_input_intermediate - d_input).abs().max().item())
     self.assertTrue(
         torch.allclose(d_twiddle_intermediate,
                        d_twiddle,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_twiddle_intermediate - d_twiddle).abs().max().item())
Ejemplo n.º 3
0
def profile_butterfly_mult():
    """Run a few Adam steps through a manually unrolled butterfly
    product, for profiling butterfly_factor_mult and its backward."""
    n_steps = 10
    batch = 100
    dim = 1024
    model = Block2x2DiagProduct(dim)
    data = torch.randn(batch, dim)
    opt = optim.Adam(model.parameters(), lr=0.01)
    for _ in range(n_steps):
        opt.zero_grad()
        result = data
        # Apply factors from last to first, mirroring model(data).
        for factor in model.factors[::-1]:
            result = butterfly_factor_mult(
                factor.ABCD,
                result.view(-1, 2, factor.size // 2)).view(data.shape)
        result.sum().backward()
        opt.step()
Ejemplo n.º 4
0
 def test_butterfly_factor_complex_cpu(self):
     """Complex CPU butterfly_factor_mult matches the slow complex_mul
     reference, for both outputs and the gradients w.r.t. the twiddle
     (factor.ABCD) and the input of each factor."""
     batch_size = 10
     n = 4096
     model = Block2x2DiagProduct(n, complex=True)
     input_ = torch.randn(batch_size, n, 2, requires_grad=True)
     current = input_
     for factor in model.factors[::-1]:
         prev = current
         half = factor.size // 2
         # Fast path under test.
         current = butterfly_factor_mult(
             factor.ABCD, prev.view(-1, 2, half, 2)).view(prev.shape)
         # Slow reference: explicit complex multiply, then sum.
         reference = complex_mul(
             factor.ABCD,
             prev.view(-1, 1, 2, half, 2)).sum(dim=-3).view(prev.shape)
         self.assertTrue(
             torch.allclose(current, reference,
                            rtol=self.rtol, atol=self.atol),
             (current - reference).abs().max().item())
         # Same upstream grad through both graphs.
         grad = torch.randn_like(current)
         d_twiddle, d_input = torch.autograd.grad(
             current, (factor.ABCD, prev), grad, retain_graph=True)
         d_twiddle_ref, d_input_ref = torch.autograd.grad(
             reference, (factor.ABCD, prev), grad, retain_graph=True)
         self.assertTrue(
             torch.allclose(d_twiddle, d_twiddle_ref,
                            rtol=self.rtol, atol=self.atol),
             (d_twiddle - d_twiddle_ref).abs().max().item())
         self.assertTrue(
             torch.allclose(d_input, d_input_ref,
                            rtol=self.rtol, atol=self.atol),
             (d_input - d_input_ref).abs().max().item())
Ejemplo n.º 5
0
 def test_butterfly_factor_cuda(self):
     """Real-valued CUDA butterfly_factor_mult agrees with the broadcast
     multiply-and-sum reference, for outputs and gradients.

     n = 4096 so that factors with size > MAX_BLOCK_SIZE are exercised.
     """
     batch_size = 100
     n = 4096
     model = Block2x2DiagProduct(n).to('cuda')
     input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
     current = input_
     for factor in model.factors[::-1]:
         prev = current
         half = factor.size // 2
         # Fast path under test.
         current = butterfly_factor_mult(
             factor.ABCD, prev.view(-1, 2, half)).view(prev.shape)
         # Slow reference via broadcasting.
         reference = (factor.ABCD *
                      prev.view(-1, 1, 2, half)).sum(dim=-2).view(prev.shape)
         self.assertTrue(
             torch.allclose(current, reference,
                            rtol=self.rtol, atol=self.atol),
             (current - reference).abs().max().item())
         # Same upstream grad through both graphs.
         grad = torch.randn_like(current)
         d_twiddle, d_input = torch.autograd.grad(
             current, (factor.ABCD, prev), grad, retain_graph=True)
         d_twiddle_ref, d_input_ref = torch.autograd.grad(
             reference, (factor.ABCD, prev), grad, retain_graph=True)
         # Report factor.size alongside the max error on failure.
         self.assertTrue(
             torch.allclose(d_twiddle, d_twiddle_ref,
                            rtol=self.rtol, atol=self.atol),
             (factor.size, (d_twiddle - d_twiddle_ref).abs().max().item()))
         self.assertTrue(
             torch.allclose(d_input, d_input_ref,
                            rtol=self.rtol, atol=self.atol),
             (d_input - d_input_ref).abs().max().item())