Example #1
0
 def test_diagonal_butterfly(self):
     """Check that ``combine.diagonal_butterfly`` folds a diagonal into a Butterfly.

     For each configuration the merged module must reproduce the result of
     applying the diagonal scaling and the butterfly separately. When
     ``inplace`` is False, the module handed to ``diagonal_butterfly`` must
     additionally keep its original twiddle.
     """
     batch_size = 10
     for in_size, out_size in [(9, 15), (15, 9)]:
         for nblocks in [1, 2, 3]:
             for is_complex in [False, True]:
                 for increasing_stride in [True, False]:
                     for diag_first in [True, False]:
                         dtype = torch.float32 if not is_complex else torch.complex64
                         x = torch.randn(batch_size, in_size, dtype=dtype)
                         b = torch_butterfly.Butterfly(in_size, out_size, bias=False,
                                                       complex=is_complex,
                                                       increasing_stride=increasing_stride,
                                                       nblocks=nblocks)
                         twiddle_copy = b.twiddle.clone()
                         # The diagonal acts on the input side when applied first,
                         # otherwise on the output side.
                         diagonal = torch.randn(in_size if diag_first else out_size,
                                                dtype=dtype)
                         out = b(x * diagonal) if diag_first else b(x) * diagonal
                         for inplace in [True, False]:
                             # Work on a copy so the inplace=True pass cannot alter
                             # `b`, which is reused by the inplace=False pass.
                             b_copy = copy.deepcopy(b)
                             bd = torch_butterfly.combine.diagonal_butterfly(
                                 b_copy, diagonal, diag_first, inplace)
                             out_bd = bd(x)
                             self.assertTrue(
                                 torch.allclose(out_bd, out, self.rtol, self.atol),
                                 (in_size, out_size, nblocks, is_complex,
                                  increasing_stride, diag_first, inplace))
                             if not inplace:
                                 # Inspect the module that was actually passed in;
                                 # `b` itself is never given to diagonal_butterfly,
                                 # so checking it would prove nothing.
                                 self.assertTrue(
                                     torch.allclose(b_copy.twiddle, twiddle_copy,
                                                    self.rtol, self.atol),
                                     (in_size, out_size, nblocks, is_complex,
                                      increasing_stride, diag_first))
Example #2
0
 def test_transpose_conjugate_multiply(self):
     """Exercise the ``transpose`` and ``conjugate`` flags of ``Butterfly.forward``.

     The identity is pushed through the module to obtain a dense matrix; the
     flagged forward passes are then compared against the transpose, the
     conjugate and the conjugate transpose of that matrix.
     """
     size = 16
     for is_complex in (False, True):
         for inc_stride in (True, False):
             for num_blocks in (1, 2, 3):
                 case = (is_complex, inc_stride, num_blocks)
                 layer = torch_butterfly.Butterfly(size, size, False, is_complex,
                                                   inc_stride, nblocks=num_blocks)
                 dtype = torch.complex64 if is_complex else torch.float32
                 eye = torch.eye(size, dtype=dtype)

                 dense = layer(eye).t()
                 dense_t = layer.forward(eye, transpose=True).t()
                 dense_conj = layer.forward(eye, conjugate=True).t()
                 dense_t_conj = layer.forward(eye, transpose=True,
                                              conjugate=True).t()

                 # (reference derived from the plain matrix, flagged forward pass)
                 comparisons = [
                     (dense.t(), dense_t),
                     (dense.conj(), dense_conj),
                     (dense.t().conj(), dense_t_conj),
                 ]
                 for reference, flagged in comparisons:
                     self.assertTrue(
                         torch.allclose(reference, flagged, self.rtol, self.atol),
                         case)
Example #3
0
 def test_butterfly_kronecker(self):
     """Compare ``combine.butterfly_kronecker`` with ``combine.TensorProduct``.

     The Kronecker-combined module is applied to the flattened input and its
     output, reshaped back, must match the TensorProduct of the two factors.
     """
     batch_size, n1, n2 = 10, 16, 32
     for is_complex in (False, True):
         for inc_stride in (True, False):
             dtype = torch.complex64 if is_complex else torch.float32
             x = torch.randn(batch_size, n2, n1, dtype=dtype)
             options = dict(bias=False, complex=is_complex,
                            increasing_stride=inc_stride)
             factor1 = torch_butterfly.Butterfly(n1, n1, **options)
             factor2 = torch_butterfly.Butterfly(n2, n2, **options)

             tensor_product = torch_butterfly.combine.TensorProduct(factor1, factor2)
             expected = tensor_product(x)

             combined = torch_butterfly.combine.butterfly_kronecker(factor1, factor2)
             flat_out = combined(x.reshape(batch_size, n2 * n1))
             actual = flat_out.reshape(batch_size, n2, n1)

             self.assertTrue(torch.allclose(actual, expected, self.rtol, self.atol))
Example #4
0
 def test_butterfly_product(self):
     """Check ``combine.butterfly_product`` against applying both factors in turn.

     All four combinations of the two factors' ``increasing_stride`` settings
     are covered, for real and complex modules.
     """
     batch_size, n, in_size, out_size = 10, 16, 13, 15
     stride_options = [True, False]
     cases = itertools.product([False, True], stride_options, stride_options)
     for is_complex, first_inc, second_inc in cases:
         dtype = torch.complex64 if is_complex else torch.float32
         x = torch.randn(batch_size, in_size, dtype=dtype)
         first = torch_butterfly.Butterfly(in_size, n, bias=False,
                                           complex=is_complex,
                                           increasing_stride=first_inc)
         second = torch_butterfly.Butterfly(n, out_size, bias=False,
                                            complex=is_complex,
                                            increasing_stride=second_inc)
         expected = second(first(x))
         merged = torch_butterfly.combine.butterfly_product(first, second)
         actual = merged(x)
         self.assertTrue(torch.allclose(actual, expected, self.rtol, self.atol))
Example #5
0
 def test_butterfly_bmm(self):
     """Check ``ButterflyBmm`` output shape and agreement with a loop of Butterflies.

     Runs on CPU and, when available, on CUDA. Each matrix of the batched
     module is copied into a standalone ``Butterfly`` and the stacked per-matrix
     outputs are compared with the batched output.
     """
     batch_size = 10
     matrix_batch = 3
     devices = ['cpu']
     if torch.cuda.is_available():
         devices.append('cuda')
     for device in devices:
         for in_size, out_size in [(7, 15), (15, 7)]:
             for is_complex in (False, True):
                 for inc_stride in (True, False):
                     for num_blocks in (1, 2, 3):
                         # --- shape check ---
                         bmm_layer = torch_butterfly.ButterflyBmm(
                             in_size, out_size, matrix_batch, True, is_complex,
                             inc_stride, nblocks=num_blocks).to(device)
                         dtype = torch.complex64 if is_complex else torch.float32
                         x = torch.randn(batch_size, matrix_batch, in_size,
                                         dtype=dtype, device=device)
                         batched_out = bmm_layer(x)
                         expected_shape = (batch_size, matrix_batch, out_size)
                         self.assertTrue(
                             batched_out.shape == expected_shape,
                             (batched_out.shape, device, (in_size, out_size),
                              num_blocks))

                         # --- equivalence with one Butterfly per matrix ---
                         stacks = bmm_layer.nstacks
                         per_matrix = []
                         for idx in range(matrix_batch):
                             single = torch_butterfly.Butterfly(
                                 in_size, out_size, True, is_complex,
                                 inc_stride, nblocks=num_blocks).to(device)
                             with torch.no_grad():
                                 # The batched twiddle holds `nstacks` consecutive
                                 # entries per matrix.
                                 twiddle_slice = bmm_layer.twiddle[
                                     idx * stacks:(idx + 1) * stacks]
                                 single.twiddle.copy_(twiddle_slice)
                                 single.bias.copy_(bmm_layer.bias[idx])
                             per_matrix.append(single(x[:, idx]))
                         looped_out = torch.stack(per_matrix, dim=1)
                         max_abs_diff = (batched_out - looped_out).abs().max().item()
                         self.assertTrue(
                             torch.allclose(batched_out, looped_out),
                             (max_abs_diff, batched_out.shape, device,
                              (in_size, out_size), is_complex))
Example #6
0
 def test_flip_increasing_stride(self):
     """Check that ``combine.flip_increasing_stride`` keeps the map but flips the flag.

     The element at index 1 of the returned module must have the opposite
     ``increasing_stride`` setting, and the returned module must produce the
     same output as the original on random input.
     """
     batch_size = 10
     for n in (16, 64):
         for num_blocks in (1, 2, 3):
             for is_complex in (False, True):
                 for inc_stride in (True, False):
                     dtype = torch.complex64 if is_complex else torch.float32
                     x = torch.randn(batch_size, n, dtype=dtype)
                     original = torch_butterfly.Butterfly(
                         n, n, bias=False, complex=is_complex,
                         increasing_stride=inc_stride, nblocks=num_blocks)
                     flipped = torch_butterfly.combine.flip_increasing_stride(original)

                     expected_flag = not original.increasing_stride
                     self.assertTrue(flipped[1].increasing_stride == expected_flag)

                     flipped_out = flipped(x)
                     original_out = original(x)
                     self.assertTrue(torch.allclose(flipped_out, original_out,
                                                    self.rtol, self.atol))
Example #7
0
 def test_autograd(self):
     """Smoke-test autograd (notably the complex path) by fitting a 4x4 matrix.

     A real and a complex Butterfly model are each trained with SGD to match
     the output of a random bias-free linear layer on the identity input. The
     only requirement is that the final loss is below the initial loss.
     """
     size = 4
     niters = 10000
     target_layer = nn.Linear(size, size, bias=False)
     x = torch.eye(size)
     with torch.no_grad():
         y = target_layer(x)
     for is_complex in (False, True):
         if not is_complex:
             model = torch_butterfly.Butterfly(size, size, bias=False,
                                               complex=is_complex)
         else:
             # Wrap the complex module so that it consumes and produces real tensors.
             model = nn.Sequential(
                 torch_butterfly.complex_utils.Real2Complex(),
                 torch_butterfly.Butterfly(size, size, bias=False,
                                           complex=is_complex),
                 torch_butterfly.complex_utils.Complex2Real(),
             )
         with torch.no_grad():
             initial_loss = F.mse_loss(model(x), y)
         optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
         for _ in range(niters):
             prediction = model(x)
             loss = F.mse_loss(prediction, y)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
         # Training should at least have reduced the loss.
         self.assertTrue(loss.item() < initial_loss.item())
Example #8
0
 def test_matrix_to_butterfly_factor(self):
     """Check that ``matrix_to_butterfly_factor`` recovers a planted factor.

     An identity-initialised Butterfly has one of its factors overwritten
     with random 2x2 blocks. The dense matrix obtained by pushing the identity
     through the module is then handed to ``matrix_to_butterfly_factor``,
     which must return the planted blocks.
     """
     num_repeats = 10
     for n in [2, 16, 64]:
         # Depends only on n, so compute it once per size.
         log_n = int(math.ceil(math.log2(n)))
         for _ in range(num_repeats):
             for log_k in range(1, log_n + 1):
                 b = torch_butterfly.Butterfly(n,
                                               n,
                                               bias=False,
                                               init='identity')
                 factor = torch.randn(n // 2, 2, 2)
                 # Autograd rejects in-place writes to a view of a leaf tensor
                 # that requires grad (twiddle is presumably such a parameter),
                 # so overwrite the factor with gradient tracking disabled.
                 with torch.no_grad():
                     b.twiddle[0, 0, log_k - 1].copy_(factor)
                 matrix = b(torch.eye(n)).t()
                 factor_out = matrix_to_butterfly_factor(
                     matrix.detach().numpy(),
                     log_k,
                     pytorch_format=True,
                     check_input=True)
                 self.assertTrue(torch.allclose(factor_out, factor),
                                 (n, log_k))
Example #9
0
 def test_subtwiddle(self):
     """Check the output shape of a Butterfly called with ``subtwiddle=True``.

     A size-16 module is fed inputs of size 8; the output must keep the
     input's (batch_size, input_size) shape.
     """
     batch_size, n, input_size = 10, 16, 8
     for is_complex in (False, True):
         for inc_stride in (True, False):
             for num_blocks in (1, 2, 3):
                 layer = torch_butterfly.Butterfly(n, n, True, is_complex,
                                                   inc_stride, nblocks=num_blocks)
                 dtype = torch.complex64 if is_complex else torch.float32
                 x = torch.randn(batch_size, input_size, dtype=dtype)
                 result = layer(x, subtwiddle=True)
                 expected_shape = (batch_size, input_size)
                 self.assertTrue(
                     result.shape == expected_shape,
                     (result.shape, n, input_size, is_complex, num_blocks))
Example #10
0
 def test_butterfly(self):
     """Check Butterfly output shapes and the norm of 'ortho'-initialised twiddles.

     Runs on CPU and, when available, on CUDA, over rectangular sizes, real and
     complex modules, both stride orders, every init scheme and several block
     counts. For the 'ortho' init every 2x2 twiddle block must have spectral
     norm 1.
     """
     batch_size = 10
     devices = ['cpu']
     if torch.cuda.is_available():
         devices.append('cuda')
     for device in devices:
         for in_size, out_size in [(7, 15), (15, 7)]:
             for is_complex in (False, True):
                 for inc_stride in (True, False):
                     for init in ('randn', 'ortho', 'identity'):
                         for num_blocks in (1, 2, 3):
                             layer = torch_butterfly.Butterfly(
                                 in_size, out_size, True, is_complex,
                                 inc_stride, init,
                                 nblocks=num_blocks).to(device)
                             dtype = torch.complex64 if is_complex else torch.float32
                             x = torch.randn(batch_size, in_size,
                                             dtype=dtype, device=device)
                             result = layer(x)
                             self.assertTrue(
                                 result.shape == (batch_size, out_size),
                                 (result.shape, device, (in_size, out_size),
                                  is_complex, init, num_blocks))

                             if init != 'ortho':
                                 continue
                             # Only the 'ortho' init has its block norms checked.
                             raw_twiddle = layer.twiddle
                             if layer.complex:
                                 raw_twiddle = view_as_complex(raw_twiddle)
                             blocks = raw_twiddle.detach().to('cpu').numpy()
                             blocks = blocks.reshape(-1, 2, 2)
                             # ord=2 over the last two axes: largest singular
                             # value of each 2x2 block.
                             block_norms = np.linalg.norm(blocks, ord=2,
                                                          axis=(1, 2))
                             self.assertTrue(
                                 np.allclose(block_norms, 1),
                                 (block_norms, device, (in_size, out_size),
                                  is_complex, init))
Example #11
0
 def test_butterfly_bmm_tensorproduct(self):
     """Show how to express a TensorProduct (e.g. Conv2d) with ``ButterflyBmm``.

     A reference result is built by looping over out_channels x in_channels
     TensorProduct modules. The same twiddles are then loaded into two
     ``ButterflyBmm`` modules, applied along the last axis and, after a
     transpose, along the other spatial axis; both results must agree.
     """
     batch_size = 10
     in_channels = 3
     out_channels = 6
     n1, n2 = 32, 16
     num_matrices = out_channels * in_channels
     dtype = torch.complex64
     x = torch.randn(batch_size, in_channels, n2, n1, dtype=dtype)

     # Reference: one TensorProduct per (out_channel, in_channel) pair.
     b1s = [torch_butterfly.Butterfly(n1, n1, bias=False, complex=True)
            for _ in range(num_matrices)]
     b2s = [torch_butterfly.Butterfly(n2, n2, bias=False, complex=True)
            for _ in range(num_matrices)]
     tensor_products = [torch_butterfly.combine.TensorProduct(b1, b2)
                        for b1, b2 in zip(b1s, b2s)]
     out = torch.stack([
         torch.stack([tensor_products[o * in_channels + i](x[:, i])
                      for i in range(in_channels)], dim=1)
         for o in range(out_channels)
     ], dim=1)
     assert out.shape == (batch_size, out_channels, in_channels, n2, n1)

     # Same computation with ButterflyBmm, reusing the twiddles from above.
     b1_bmm = torch_butterfly.ButterflyBmm(n1, n1, matrix_batch=num_matrices,
                                           bias=False, complex=True)
     with torch.no_grad():
         b1_bmm.twiddle.copy_(torch.cat([b1.twiddle for b1 in b1s]))
     b2_bmm = torch_butterfly.ButterflyBmm(n2, n2, matrix_batch=num_matrices,
                                           bias=False, complex=True)
     with torch.no_grad():
         b2_bmm.twiddle.copy_(torch.cat([b2.twiddle for b2 in b2s]))

     # Add a singleton out_channels axis and broadcast the input across it.
     per_row = x.transpose(1, 2).reshape(batch_size, n2, 1, in_channels, n1)
     broadcast = per_row.expand(batch_size, n2, out_channels, in_channels, n1)
     stage1 = b1_bmm(broadcast.reshape(batch_size, n2, num_matrices, n1))
     # (batch_size, n1, out_channels * in_channels, n2)
     stage1 = stage1.transpose(1, 3)
     # (batch_size, n1, out_channels * in_channels, n2)
     stage2 = b2_bmm(stage1)
     # (batch_size, out_channels * in_channels, n2, n1)
     stage2 = stage2.permute(0, 2, 3, 1)
     out_bmm = stage2.reshape(batch_size, out_channels, in_channels, n2, n1)
     self.assertTrue(torch.allclose(out_bmm, out))