Example #1
 def test_ifft2d(self):
     batch_size = 10
     n1 = 32
     n2 = 16
     input = torch.randn(batch_size, n2, n1, dtype=torch.complex64)
     for normalized in [False, True]:
         out_torch = view_as_complex(
             torch.ifft(view_as_real(input),
                        signal_ndim=2,
                        normalized=normalized))
         # Just to show that ifft2d is exactly two 1D iffts, one along each dimension
         input_f = view_as_complex(
             torch.ifft(view_as_real(input),
                        signal_ndim=1,
                        normalized=normalized))
         out_fft = view_as_complex(
             torch.ifft(view_as_real(input_f.transpose(-1, -2)),
                        signal_ndim=1,
                        normalized=normalized)).transpose(-1, -2)
         self.assertTrue(
             torch.allclose(out_torch, out_fft, self.rtol, self.atol))
         for br_first in [True, False]:
             for flatten in [False, True]:
                 b = torch_butterfly.special.ifft2d(n1,
                                                    n2,
                                                    normalized=normalized,
                                                    br_first=br_first,
                                                    flatten=flatten)
                 out = b(input)
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
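Note: torch.ifft with signal_ndim was removed in PyTorch 1.8 in favor of the torch.fft
module. A minimal sketch of the same check on a modern PyTorch (normalized=True
corresponds to norm="ortho"):

import torch

input = torch.randn(10, 16, 32, dtype=torch.complex64)
# ifft2 with orthonormal scaling replaces torch.ifft(..., signal_ndim=2, normalized=True)
out_2d = torch.fft.ifft2(input, norm="ortho")
# ... and factors into two 1D orthonormal iffts, one along each dimension
out_iter = torch.fft.ifft(torch.fft.ifft(input, dim=-1, norm="ortho"),
                          dim=-2, norm="ortho")
assert torch.allclose(out_2d, out_iter, atol=1e-6)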
Example #2
 def __init__(self, in_size, out_size, matrix_batch=1, bias=True, complex=False,
              increasing_stride=True, init='randn', nblocks=1):
     nn.Module.__init__(self)
     self.in_size = in_size
     log_n = int(math.ceil(math.log2(in_size)))
     self.log_n = log_n
     size = self.in_size_extended = 1 << log_n  # Will zero-pad input if in_size is not a power of 2
     self.out_size = out_size
     self.matrix_batch = matrix_batch
     self.nstacks = int(math.ceil(out_size / self.in_size_extended))
     self.complex = complex
     self.increasing_stride = increasing_stride
     assert nblocks >= 1
     self.nblocks = nblocks
     dtype = (torch.get_default_dtype() if not self.complex
              else real_dtype_to_complex[torch.get_default_dtype()])
     twiddle_shape = (self.matrix_batch * self.nstacks, nblocks, log_n, size // 2, 2, 2)
     assert init in ['randn', 'ortho', 'identity']
     self.init = init
     self.twiddle = nn.Parameter(torch.empty(twiddle_shape, dtype=dtype))
     if bias:
         self.bias = nn.Parameter(torch.empty(self.matrix_batch, out_size, dtype=dtype))
     else:
         self.register_parameter('bias', None)
     self.twiddle._is_structured = True  # Flag to avoid weight decay
     # PyTorch 1.6 doesn't support torch.Tensor.add_(other, alpha) for complex
     # tensors yet, and optimizers such as SGD rely on it, so we have to store
     # complex parameters as real.
     if self.complex:
         self.twiddle = nn.Parameter(view_as_real(self.twiddle))
         if self.bias is not None:
             self.bias = nn.Parameter(view_as_real(self.bias))
     self.reset_parameters()
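real_dtype_to_complex is referenced above but not shown in this snippet; a plausible
stand-in (an assumption, not necessarily the library's exact helper) is:

import torch

# Hypothetical reconstruction: map each real dtype to the complex dtype whose
# real and imaginary components have that precision.
real_dtype_to_complex = {
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}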
Example #3
    def test_circulant(self):
        batch_size = 10
        n = 13
        for complex in [False, True]:
            dtype = torch.float32 if not complex else torch.complex64
            col = torch.randn(n, dtype=dtype)
            C = la.circulant(col.numpy())
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * np.fft.fft(col.numpy())),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            # Just to show how to implement circulant multiply with FFT
            if complex:
                input_f = view_as_complex(
                    torch.fft(view_as_real(input), signal_ndim=1))
                col_f = view_as_complex(
                    torch.fft(view_as_real(col), signal_ndim=1))
                prod_f = complex_mul(input_f, col_f)
                out_fft = view_as_complex(
                    torch.ifft(view_as_real(prod_f), signal_ndim=1))
                self.assertTrue(
                    torch.allclose(out_torch, out_fft, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    col, transposed=False, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))

            row = torch.randn(n, dtype=dtype)
            C = la.circulant(row.numpy()).T
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            # row is the reverse of col, except the 0-th element stays put
            # This corresponds to the same reversal in the frequency domain.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
            row_f = np.fft.fft(row.numpy())
            row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * row_f_reversed),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    row, transposed=True, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))
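The time/frequency reversal fact cited above is easy to check directly in NumPy:
reversing all elements except the 0-th commutes with the DFT.

import numpy as np

x = np.random.randn(13)
x_rev = np.hstack((x[:1], x[1:][::-1]))       # reverse, 0-th element stays put
f = np.fft.fft(x)
f_rev = np.hstack((f[:1], f[1:][::-1]))       # same reversal in the frequency domain
assert np.allclose(np.fft.fft(x_rev), f_rev)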
Example #4
 def test_conv1d_circular_multichannel(self):
     batch_size = 10
     in_channels = 3
     out_channels = 4
     for n in [13, 16]:
         for kernel_size in [1, 3, 5, 7]:
             padding = (kernel_size - 1) // 2
             conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size,
                              padding=padding,
                              padding_mode='circular',
                              bias=False)
             weight = conv.weight
             input = torch.randn(batch_size, in_channels, n)
             out_torch = conv(input)
             # Just to show how to implement conv1d with FFT
             input_f = view_as_complex(torch.rfft(input, signal_ndim=1))
             col = F.pad(weight.flip(dims=(-1, )),
                         (0, n - kernel_size)).roll(-padding, dims=-1)
             col_f = view_as_complex(torch.rfft(col, signal_ndim=1))
             prod_f = complex_mul(input_f.unsqueeze(1), col_f).sum(dim=2)
             out_fft = torch.irfft(view_as_real(prod_f),
                                   signal_ndim=1,
                                   signal_sizes=(n, ))
             self.assertTrue(
                 torch.allclose(out_torch, out_fft, self.rtol, self.atol))
             b = torch_butterfly.special.conv1d_circular_multichannel(
                 n, weight)
             out = b(input)
             self.assertTrue(
                 torch.allclose(out, out_torch, self.rtol, self.atol))
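torch.rfft/torch.irfft were removed in PyTorch 1.8; a minimal sketch of the same
FFT-based circular conv1d on a modern PyTorch (complex multiply is native now, so
complex_mul and the view_as_real/view_as_complex round-trips are no longer needed):

import torch
import torch.nn as nn
import torch.nn.functional as F

n, kernel_size, padding = 16, 3, 1
conv = nn.Conv1d(3, 4, kernel_size, padding=padding,
                 padding_mode='circular', bias=False)
input = torch.randn(10, 3, n)
# Same kernel-to-first-column construction as in the test above
col = F.pad(conv.weight.flip(dims=(-1,)), (0, n - kernel_size)).roll(-padding, dims=-1)
prod_f = (torch.fft.rfft(input).unsqueeze(1) * torch.fft.rfft(col)).sum(dim=2)
out_fft = torch.fft.irfft(prod_f, n=n)
assert torch.allclose(conv(input), out_fft, atol=1e-5)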
Example #5
def conv1d_circular_multichannel(n, weight) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). kernel_size
                must be odd and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding = (kernel_size - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding,
                                                              dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.

    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f),
                             b_ifft, Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
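The non-power-of-2 embedding described in the comments ([a, b, c] -> [a, b, c, 0, 0, a, b, c])
can be sanity-checked in NumPy: the first n outputs of the padded size-2^(log_n + 1)
circulant, applied to the zero-padded input, recover the size-n circulant multiply.

import numpy as np
import scipy.linalg as la

n, n_ext = 3, 8                                            # 8 == 1 << (ceil(log2(3)) + 1)
col, x = np.random.randn(n), np.random.randn(n)
col_ext = np.hstack((col, np.zeros(n_ext - 2 * n), col))   # [a, b, c, 0, 0, a, b, c]
x_ext = np.hstack((x, np.zeros(n_ext - n)))
assert np.allclose(la.circulant(col) @ x, (la.circulant(col_ext) @ x_ext)[:n])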
Example #6
 def __init__(self, diagonal_init):
     """
     Parameters:
         diagonal_init: (out_channels, in_channels, size)
     """
     super().__init__()
     self.diagonal = nn.Parameter(diagonal_init.detach().clone())
     self.complex = self.diagonal.is_complex()
     if self.complex:
         self.diagonal = nn.Parameter(view_as_real(self.diagonal))
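The forward pass of DiagonalMultiplySum is not shown here; judging from the comment in
conv1d_circular_multichannel ("We just want (input_f.unsqueeze(1) * col_f).sum(dim=2)"),
a plausible sketch (not necessarily the library's exact code) is:

 def forward(self, input):
     # input: (batch, in_channels, size); diagonal: (out_channels, in_channels, size).
     # Multiply per (out_channel, in_channel) pair, then sum over in_channels.
     diagonal = view_as_complex(self.diagonal) if self.complex else self.diagonal
     return complex_mul(input.unsqueeze(1), diagonal).sum(dim=2)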
Example #7
 def test_ifft_unitary(self):
     batch_size = 10
     n = 16
     input = torch.randn(batch_size, n, dtype=torch.complex64)
     normalized = True
     out_torch = view_as_complex(
         torch.ifft(view_as_real(input),
                    signal_ndim=1,
                    normalized=normalized))
     for br_first in [True, False]:
         b = torch_butterfly.special.ifft_unitary(n, br_first=br_first)
         out = b(input)
         self.assertTrue(
             torch.allclose(out, out_torch, self.rtol, self.atol))
Example #8
 def test_conv2d_circular_multichannel(self):
     batch_size = 10
     in_channels = 3
     out_channels = 4
     for n1 in [13, 16]:
         for n2 in [27, 32]:
             # flatten is only supported for powers of 2 for now
             if n1 == 1 << int(math.log2(n1)) and n2 == 1 << int(
                     math.log2(n2)):
                 flatten_cases = [False, True]
             else:
                 flatten_cases = [False]
             for kernel_size1 in [1, 3, 5, 7]:
                 for kernel_size2 in [1, 3, 5, 7]:
                     padding1 = (kernel_size1 - 1) // 2
                     padding2 = (kernel_size2 - 1) // 2
                     conv = nn.Conv2d(in_channels,
                                      out_channels,
                                      (kernel_size2, kernel_size1),
                                      padding=(padding2, padding1),
                                      padding_mode='circular',
                                      bias=False)
                     weight = conv.weight
                     input = torch.randn(batch_size, in_channels, n2, n1)
                     out_torch = conv(input)
                     # Just to show how to implement conv2d with FFT
                     input_f = view_as_complex(
                         torch.rfft(input, signal_ndim=2))
                     col = F.pad(weight.flip(dims=(-1, )),
                                 (0, n1 - kernel_size1)).roll(-padding1,
                                                              dims=-1)
                     col = F.pad(col.flip(dims=(-2, )),
                                 (0, 0, 0, n2 - kernel_size2)).roll(
                                     -padding2, dims=-2)
                     col_f = view_as_complex(torch.rfft(col, signal_ndim=2))
                     prod_f = complex_mul(input_f.unsqueeze(1),
                                          col_f).sum(dim=2)
                     out_fft = torch.irfft(view_as_real(prod_f),
                                           signal_ndim=2,
                                           signal_sizes=(n2, n1))
                     self.assertTrue(
                         torch.allclose(out_torch, out_fft, self.rtol,
                                        self.atol))
                     for flatten in flatten_cases:
                         b = torch_butterfly.special.conv2d_circular_multichannel(
                             n1, n2, weight, flatten=flatten)
                         out = b(input)
                         self.assertTrue(
                             torch.allclose(out, out_torch, self.rtol,
                                            self.atol))
Example #9
def diagonal_butterfly(butterfly: Butterfly,
                       diagonal: torch.Tensor,
                       diag_first: bool,
                       inplace: bool = True) -> Butterfly:
    """
    Combine a Butterfly and a diagonal into another Butterfly.
    Only supports nstacks == 1 for now.
    Parameters:
        butterfly: Butterfly(in_size, out_size)
        diagonal: size (in_size,) if diag_first, else (out_size,). Should be of type complex
            if butterfly.complex == True.
        diag_first: If True, the map is input -> diagonal -> butterfly.
            If False, the map is input -> butterfly -> diagonal.
        inplace: whether to modify the input Butterfly
    """
    assert butterfly.nstacks == 1
    assert butterfly.bias is None
    twiddle = (butterfly.twiddle.clone() if not butterfly.complex else
               view_as_complex(butterfly.twiddle).clone())
    n = 1 << twiddle.shape[2]
    if diagonal.shape[-1] < n:
        diagonal = F.pad(diagonal, (0, n - diagonal.shape[-1]), value=1)
    if diag_first:
        if butterfly.increasing_stride:
            twiddle[:, 0, 0, :, :, 0] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[1::2].unsqueeze(-1)
        else:
            n = diagonal.shape[-1]
            twiddle[:, 0, 0, :, :, 0] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[n // 2:].unsqueeze(-1)
    else:
        # Whether the last block is increasing or decreasing stride
        increasing_stride = butterfly.increasing_stride != (
            (butterfly.nblocks - 1) % 2 == 1)
        if increasing_stride:
            n = diagonal.shape[-1]
            twiddle[:, -1, -1, :, 0, :] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[n // 2:].unsqueeze(-1)
        else:
            twiddle[:, -1, -1, :, 0, :] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[1::2].unsqueeze(-1)
    out_butterfly = butterfly if inplace else copy.deepcopy(butterfly)
    with torch.no_grad():
        out_butterfly.twiddle.copy_(
            twiddle if not butterfly.complex else view_as_real(twiddle))
    return out_butterfly
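A hedged usage sketch (assuming Butterfly is importable from torch_butterfly, with the
constructor signature shown in Example #2): folding a diagonal in with diag_first=True
should match applying the diagonal and then the butterfly.

import torch
from torch_butterfly import Butterfly

n = 16
b = Butterfly(n, n, bias=False, complex=True, init='ortho')
diagonal = torch.randn(n, dtype=torch.complex64)
combined = diagonal_butterfly(b, diagonal, diag_first=True, inplace=False)
x = torch.randn(10, n, dtype=torch.complex64)
assert torch.allclose(combined(x), b(x * diagonal), atol=1e-5)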
Example #10
 def __init__(self, size=None, complex=False, diagonal_init=None):
     """Multiply by diagonal matrix
     Parameter:
         size: int
         diagonal_init: (n, )
     """
     super().__init__()
     if diagonal_init is not None:
         self.size = diagonal_init.shape
         self.diagonal = nn.Parameter(diagonal_init.detach().clone())
         self.complex = self.diagonal.is_complex()
     else:
         assert size is not None
         self.size = size
         dtype = (torch.get_default_dtype() if not complex
                  else real_dtype_to_complex[torch.get_default_dtype()])
         self.diagonal = nn.Parameter(torch.randn(size, dtype=dtype))
         self.complex = complex
     if self.complex:
         self.diagonal = nn.Parameter(view_as_real(self.diagonal))
Example #11
def conv2d_circular_multichannel(n1: int,
                                 n2: int,
                                 weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
            Each kernel size must be odd and smaller than the corresponding n. Padding is
            assumed to be (kernel_size - 1) // 2.
        flatten: whether to internally flatten the last 2 dimensions of the input. Only support n1
            and n2 being powers of 2.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    assert kernel_size1 < n1 and kernel_size2 < n2, 'Kernel size must be smaller than input size'
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1,
                                                                dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                                   dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2, 'flatten requires n1 and n2 to be powers of 2'
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1,
                  n_extended2,
                  normalized=True,
                  br_first=False,
                  with_br_perm=False,
                  flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and Unflatten2D
    b_ifft = ifft2d(n_extended1,
                    n_extended2,
                    normalized=True,
                    br_first=True,
                    with_br_perm=False,
                    flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and Unflatten2D
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=2, normalized=False))
    br_perm1 = bitreversal_permutation(n_extended1,
                                       pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2,
                                       pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error "shape mismatch: indexing tensors could not be
    # broadcast together"
    col_f = col_f[..., br_perm2, :][..., br_perm1]
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2],
                              col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2),
                                 b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 Unflatten2D(n1), Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Unflatten2D(n1))
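The indexing restriction noted above (col_f[..., br_perm2, br_perm1] raising "shape
mismatch") comes from NumPy-style broadcasting of integer index tensors; a tiny sketch
of the pitfall and two working forms:

import torch

x = torch.arange(6.).reshape(2, 3)
p2, p1 = torch.tensor([1, 0]), torch.tensor([2, 0, 1])
# x[p2, p1] would broadcast the two index tensors together (and fails here since
# their lengths differ); permute one axis at a time instead:
a = x[p2, :][:, p1]
# ... or make the outer product of indices explicit:
b = x[p2.unsqueeze(-1), p1]
assert torch.equal(a, b)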
Example #12
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        col_f = torch.cat((col_f[:1], col_f[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    diag = col_f[..., br_perm]
    if separate_diagonal:
        if not complex:
            return nn.Sequential(Real2Complex(), b_fft,
                                 Diagonal(diagonal_init=diag), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
    else:
        # Combine the diagonal with the last twiddle factor of b_fft
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft,
                                       diag,
                                       diag_first=False,
                                       inplace=True)
        # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
        b = butterfly_product(b_fft, b_ifft)
        b.in_size = n
        b.out_size = n
        return b if complex else nn.Sequential(Real2Complex(), b,
                                               Complex2Real())
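The eigenvalue fact used for col_f ("These are the eigenvalues of the circulant matrix")
can be checked in NumPy: C equals F^-1 @ diag(fft(col)) @ F with the unnormalized DFT.

import numpy as np
import scipy.linalg as la

n = 8
col = np.random.randn(n)
F_mat = np.fft.fft(np.eye(n))                              # unnormalized DFT matrix
C_rebuilt = np.fft.ifft(np.diag(np.fft.fft(col)) @ F_mat, axis=0)
assert np.allclose(la.circulant(col), C_rebuilt)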