Ejemplo n.º 1
0
    def test_complex_matmul(self):
        """Check complex_matmul (forward and backward) against a reference
        built from a broadcast complex_mul followed by a sum over the
        contracted dimension.

        Two cases are exercised on each available device:
          1. Batched operands of shape (*bs, 128, 16) and (*bs, 16, 32).
          2. Operands that are permuted before and after complex_matmul, so
             the inputs handed to it are non-contiguous views.
        CUDA is only tested when torch reports it as available, so the test
        still runs on CPU-only machines.
        """
        devices = ['cpu']
        if torch.cuda.is_available():
            devices.append('cuda')

        def check_against_reference(prod, prod_sum, X, Y):
            # Forward outputs must agree.
            self.assertTrue(torch.allclose(prod, prod_sum, self.rtol, self.atol))
            # Backward: push the same random upstream gradient through both
            # graphs and compare the gradients w.r.t. each input.
            g = torch.randn_like(prod)
            grad_X, grad_Y = torch.autograd.grad(prod, (X, Y), g)
            grad_X_sum, grad_Y_sum = torch.autograd.grad(prod_sum, (X, Y), g)
            self.assertTrue(torch.allclose(grad_X, grad_X_sum, self.rtol, self.atol))
            self.assertTrue(torch.allclose(grad_Y, grad_Y_sum, self.rtol, self.atol))

        bs = (3, 5)
        for device in devices:
            # Case 1: plain batched product.
            X = torch.randn(*bs, 128, 16, dtype=torch.complex64, device=device, requires_grad=True)
            Y = torch.randn(*bs, 16, 32, dtype=torch.complex64, device=device, requires_grad=True)
            prod = complex_matmul(X, Y)
            prod_sum = complex_mul(X.unsqueeze(-1), Y.unsqueeze(-3)).sum(dim=-2)
            check_against_reference(prod, prod_sum, X, Y)

            # Case 2: the matrix dimensions are moved to the end via permute,
            # multiplied, and moved back; the reference contracts dim 2 directly.
            X = torch.randn(5, 3, 32, 32, dtype=torch.complex64, device=device, requires_grad=True)
            Y = torch.randn(6, 3, 32, 32, dtype=torch.complex64, device=device, requires_grad=True)
            prod = complex_matmul(X.permute(2, 3, 0, 1), Y.permute(2, 3, 1, 0)).permute(2, 3, 0, 1)
            prod_sum = complex_mul(X.unsqueeze(1), Y).sum(dim=2)
            check_against_reference(prod, prod_sum, X, Y)
Ejemplo n.º 2
0
 def test_conv1d_circular_multichannel(self):
     """Compare nn.Conv1d with circular padding against (a) an explicit
     FFT-based computation and (b) the module returned by
     torch_butterfly.special.conv1d_circular_multichannel, for several
     signal lengths and odd kernel sizes.
     """
     batch_size, in_channels, out_channels = 10, 3, 4
     for n in [13, 16]:
         for kernel_size in [1, 3, 5, 7]:
             padding = (kernel_size - 1) // 2
             conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=padding, padding_mode='circular',
                              bias=False)
             weight = conv.weight
             signal = torch.randn(batch_size, in_channels, n)
             expected = conv(signal)

             # Just to show how to implement conv1d with FFT:
             # flip the kernel, zero-pad it to length n, and rotate it so
             # that it lines up with the padded convolution.
             signal_f = view_as_complex(torch.rfft(signal, signal_ndim=1))
             flipped = weight.flip(dims=(-1, ))
             padded = F.pad(flipped, (0, n - kernel_size))
             col = padded.roll(-padding, dims=-1)
             col_f = view_as_complex(torch.rfft(col, signal_ndim=1))
             # Broadcast over out_channels, then sum over in_channels.
             prod_f = complex_mul(signal_f.unsqueeze(1), col_f).sum(dim=2)
             out_fft = torch.irfft(view_as_real(prod_f), signal_ndim=1,
                                   signal_sizes=(n, ))
             self.assertTrue(
                 torch.allclose(expected, out_fft, self.rtol, self.atol))

             # The butterfly-based module must match nn.Conv1d as well.
             b = torch_butterfly.special.conv1d_circular_multichannel(
                 n, weight)
             actual = b(signal)
             self.assertTrue(
                 torch.allclose(actual, expected, self.rtol, self.atol))
Ejemplo n.º 3
0
 def forward(self, input):
     """Combine the input with this module's diagonal via complex_mul.

     When self.complex is set, the stored diagonal is first passed through
     view_as_complex; otherwise it is used as stored.

     Parameters:
         input: (batch, *, size)
     Return:
         output: (batch, *, size)
     """
     if self.complex:
         diagonal = view_as_complex(self.diagonal)
     else:
         diagonal = self.diagonal
     return complex_mul(input, diagonal)
Ejemplo n.º 4
0
 def forward(self, input):
     """Combine the input with this module's diagonal via complex_mul and
     sum over dimension 2 of the broadcast product.

     When self.complex is set, the stored diagonal is first passed through
     view_as_complex; otherwise it is used as stored.

     Parameters:
         input: (batch, in_channels, size)
     Return:
         output: (batch, out_channels, size)
     """
     if self.complex:
         diagonal = view_as_complex(self.diagonal)
     else:
         diagonal = self.diagonal
     # Insert a singleton dim at position 1 so the product broadcasts
     # against the diagonal, then reduce dim 2.
     expanded = input.unsqueeze(1)
     product = complex_mul(expanded, diagonal)
     return product.sum(dim=2)
Ejemplo n.º 5
0
    def test_circulant(self):
        """Check torch_butterfly.special.circulant against a dense circulant
        matrix product, for real and complex dtypes, for both the
        column-specified (transposed=False) and row-specified
        (transposed=True) variants, with and without separate_diagonal.

        NumPy FFT-based products are also compared with the dense result as a
        sanity check on the reference itself.
        """
        batch_size = 10
        n = 13
        for is_complex in [False, True]:
            dtype = torch.complex64 if is_complex else torch.float32

            # --- Circulant matrix specified by its first column ---
            col = torch.randn(n, dtype=dtype)
            C = la.circulant(col.numpy())
            x = torch.randn(batch_size, n, dtype=dtype)
            expected = torch.tensor(x.detach().numpy() @ C.T)
            x_freq = np.fft.fft(x.numpy())
            col_freq = np.fft.fft(col.numpy())
            out_np = torch.tensor(np.fft.ifft(x_freq * col_freq), dtype=dtype)
            self.assertTrue(
                torch.allclose(expected, out_np, self.rtol, self.atol))
            # Just to show how to implement circulant multiply with FFT
            if is_complex:
                x_f = view_as_complex(
                    torch.fft(view_as_real(x), signal_ndim=1))
                col_f = view_as_complex(
                    torch.fft(view_as_real(col), signal_ndim=1))
                prod_f = complex_mul(x_f, col_f)
                out_fft = view_as_complex(
                    torch.ifft(view_as_real(prod_f), signal_ndim=1))
                self.assertTrue(
                    torch.allclose(expected, out_fft, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    col, transposed=False, separate_diagonal=separate_diagonal)
                actual = b(x)
                self.assertTrue(
                    torch.allclose(actual, expected, self.rtol, self.atol))

            # --- Circulant matrix specified by its first row ---
            row = torch.randn(n, dtype=dtype)
            C = la.circulant(row.numpy()).T
            x = torch.randn(batch_size, n, dtype=dtype)
            expected = torch.tensor(x.detach().numpy() @ C.T)
            # row is the reverse of col, except the 0-th element stays put
            # This corresponds to the same reversal in the frequency domain.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
            row_f = np.fft.fft(row.numpy())
            head, tail = row_f[:1], row_f[1:]
            row_f_reversed = np.hstack((head, tail[::-1]))
            x_freq = np.fft.fft(x.numpy())
            out_np = torch.tensor(np.fft.ifft(x_freq * row_f_reversed),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(expected, out_np, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    row, transposed=True, separate_diagonal=separate_diagonal)
                actual = b(x)
                self.assertTrue(
                    torch.allclose(actual, expected, self.rtol, self.atol))
Ejemplo n.º 6
0
 def test_conv2d_circular_multichannel(self):
     """Compare nn.Conv2d with circular padding against (a) an explicit
     2D-FFT-based computation and (b) the module returned by
     torch_butterfly.special.conv2d_circular_multichannel, over several
     image sizes and odd kernel sizes. The flatten=True variant is only
     exercised when both image sizes are powers of 2.
     """
     batch_size, in_channels, out_channels = 10, 3, 4
     kernel_sizes = [1, 3, 5, 7]
     for n1 in [13, 16]:
         for n2 in [27, 32]:
             # flatten is only supported for powers of 2 for now
             n1_is_pow2 = n1 == 1 << int(math.log2(n1))
             n2_is_pow2 = n2 == 1 << int(math.log2(n2))
             if n1_is_pow2 and n2_is_pow2:
                 flatten_cases = [False, True]
             else:
                 flatten_cases = [False]
             for kernel_size1 in kernel_sizes:
                 for kernel_size2 in kernel_sizes:
                     padding1 = (kernel_size1 - 1) // 2
                     padding2 = (kernel_size2 - 1) // 2
                     conv = nn.Conv2d(in_channels, out_channels,
                                      (kernel_size2, kernel_size1),
                                      padding=(padding2, padding1),
                                      padding_mode='circular',
                                      bias=False)
                     weight = conv.weight
                     image = torch.randn(batch_size, in_channels, n2, n1)
                     expected = conv(image)

                     # Just to show how to implement conv2d with FFT:
                     # flip, zero-pad and rotate the kernel along the last
                     # dimension, then do the same along the second-to-last.
                     image_f = view_as_complex(
                         torch.rfft(image, signal_ndim=2))
                     col = weight.flip(dims=(-1, ))
                     col = F.pad(col, (0, n1 - kernel_size1))
                     col = col.roll(-padding1, dims=-1)
                     col = col.flip(dims=(-2, ))
                     col = F.pad(col, (0, 0, 0, n2 - kernel_size2))
                     col = col.roll(-padding2, dims=-2)
                     col_f = view_as_complex(torch.rfft(col, signal_ndim=2))
                     # Broadcast over out_channels, then sum over in_channels.
                     prod_f = complex_mul(image_f.unsqueeze(1), col_f)
                     prod_f = prod_f.sum(dim=2)
                     out_fft = torch.irfft(view_as_real(prod_f),
                                           signal_ndim=2,
                                           signal_sizes=(n2, n1))
                     self.assertTrue(
                         torch.allclose(expected, out_fft, self.rtol,
                                        self.atol))

                     # The butterfly-based module must match nn.Conv2d too.
                     for flatten in flatten_cases:
                         b = torch_butterfly.special.conv2d_circular_multichannel(
                             n1, n2, weight, flatten=flatten)
                         actual = b(image)
                         self.assertTrue(
                             torch.allclose(actual, expected, self.rtol,
                                            self.atol))
Ejemplo n.º 7
0
def butterfly_multiply_torch(twiddle, input, increasing_stride=True):
    """Apply a sequence of butterfly factors to the input using plain torch ops.

    Parameters:
        twiddle: (nstacks, nblocks, log_n, n // 2, 2, 2)
        input: (batch_size, nstacks, n), where n must be a power of 2
        increasing_stride: whether the first block goes through strides
            1, 2, ..., n // 2 (True) or n // 2, ..., 2, 1 (False). The
            direction alternates from one block to the next.
    Return:
        output: (batch_size, nstacks, n)
    """
    batch_size, nstacks, n = input.shape
    nblocks = twiddle.shape[1]
    log_n = int(math.log2(n))
    assert n == 1 << log_n, "size must be a power of 2"
    assert twiddle.shape == (nstacks, nblocks, log_n, n // 2, 2, 2)
    result = input.contiguous()
    ascending = increasing_stride
    for block in range(nblocks):
        # Stride exponents for this block, in the order they are applied.
        if ascending:
            exponents = range(log_n)
        else:
            exponents = range(log_n - 1, -1, -1)
        for idx, log_stride in enumerate(exponents):
            stride = 1 << log_stride
            ngroups = n // (2 * stride)
            # Regroup the n // 2 twiddle entries by stride and move the
            # stride axis last: shape (nstacks, ngroups, 2, 2, stride).
            factor = twiddle[:, block, idx].view(nstacks, ngroups, stride, 2, 2)
            factor = factor.permute(0, 1, 3, 4, 2)
            # The singleton axis broadcasts against the first 2x2 axis of
            # the factor; the second 2x2 axis is reduced by the sum.
            operand = result.view(batch_size, nstacks, ngroups, 1, 2, stride)
            result = complex_mul(factor, operand).sum(dim=4)
        # Flip the stride direction for the next block.
        ascending = not ascending
    return result.view(batch_size, nstacks, n)