Example #1
 def test_conv1d_circular_multichannel(self):
     batch_size = 10
     in_channels = 3
     out_channels = 4
     for n in [13, 16]:
         for kernel_size in [1, 3, 5, 7]:
             padding = (kernel_size - 1) // 2
             conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size,
                              padding=padding,
                              padding_mode='circular',
                              bias=False)
             weight = conv.weight
             input = torch.randn(batch_size, in_channels, n)
             out_torch = conv(input)
             # Just to show how to implement conv1d with FFT
             input_f = view_as_complex(torch.rfft(input, signal_ndim=1))
             col = F.pad(weight.flip(dims=(-1, )),
                         (0, n - kernel_size)).roll(-padding, dims=-1)
             col_f = view_as_complex(torch.rfft(col, signal_ndim=1))
             prod_f = complex_mul(input_f.unsqueeze(1), col_f).sum(dim=2)
             out_fft = torch.irfft(view_as_real(prod_f),
                                   signal_ndim=1,
                                   signal_sizes=(n, ))
             self.assertTrue(
                 torch.allclose(out_torch, out_fft, self.rtol, self.atol))
             b = torch_butterfly.special.conv1d_circular_multichannel(
                 n, weight)
             out = b(input)
             self.assertTrue(
                 torch.allclose(out, out_torch, self.rtol, self.atol))
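
For reference, the same FFT-based circular convolution can be written with the torch.fft module that replaced torch.rfft/torch.irfft in PyTorch 1.8+. A minimal sketch (names mirror the test above; this is not part of the library):

    import torch
    import torch.nn.functional as F

    def conv1d_circular_fft(input, weight):
        # input: (batch, in_channels, n); weight: (out_channels, in_channels, kernel_size)
        n, kernel_size = input.shape[-1], weight.shape[-1]
        padding = (kernel_size - 1) // 2
        # First column of the circulant matrix for each (out, in) channel pair
        col = F.pad(weight.flip(dims=(-1,)), (0, n - kernel_size)).roll(-padding, dims=-1)
        input_f = torch.fft.rfft(input)    # (batch, in_channels, n // 2 + 1)
        col_f = torch.fft.rfft(col)        # (out_channels, in_channels, n // 2 + 1)
        prod_f = (input_f.unsqueeze(1) * col_f).sum(dim=2)  # sum over in_channels
        return torch.fft.irfft(prod_f, n=n)  # (batch, out_channels, n)
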
Example #2
 def test_ifft2d(self):
     batch_size = 10
     n1 = 32
     n2 = 16
     input = torch.randn(batch_size, n2, n1, dtype=torch.complex64)
     for normalized in [False, True]:
         out_torch = view_as_complex(
             torch.ifft(view_as_real(input),
                        signal_ndim=2,
                        normalized=normalized))
         # Just to show how ifft2d is exactly 2 iffts on each dimension
         input_f = view_as_complex(
             torch.ifft(view_as_real(input),
                        signal_ndim=1,
                        normalized=normalized))
         out_fft = view_as_complex(
             torch.ifft(view_as_real(input_f.transpose(-1, -2)),
                        signal_ndim=1,
                        normalized=normalized)).transpose(-1, -2)
         self.assertTrue(
             torch.allclose(out_torch, out_fft, self.rtol, self.atol))
         for br_first in [True, False]:
             for flatten in [False, True]:
                 b = torch_butterfly.special.ifft2d(n1,
                                                    n2,
                                                    normalized=normalized,
                                                    br_first=br_first,
                                                    flatten=flatten)
                 out = b(input)
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
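
The same separability check reads more directly with the torch.fft module (PyTorch 1.8+), where complex tensors need no view_as_real round-trips. A sketch:

    import torch

    x = torch.randn(10, 16, 32, dtype=torch.complex64)  # (batch, n2, n1)
    out_ref = torch.fft.ifft2(x, norm="ortho")
    # ifft2d is exactly a 1-D inverse FFT along each dimension
    out_sep = torch.fft.ifft(torch.fft.ifft(x, dim=-1, norm="ortho"), dim=-2, norm="ortho")
    assert torch.allclose(out_ref, out_sep, atol=1e-5)
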
Example #3
    def test_circulant(self):
        batch_size = 10
        n = 13
        for complex in [False, True]:
            dtype = torch.float32 if not complex else torch.complex64
            col = torch.randn(n, dtype=dtype)
            C = la.circulant(col.numpy())
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * np.fft.fft(col.numpy())),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            # Just to show how to implement circulant multiply with FFT
            if complex:
                input_f = view_as_complex(
                    torch.fft(view_as_real(input), signal_ndim=1))
                col_f = view_as_complex(
                    torch.fft(view_as_real(col), signal_ndim=1))
                prod_f = complex_mul(input_f, col_f)
                out_fft = view_as_complex(
                    torch.ifft(view_as_real(prod_f), signal_ndim=1))
                self.assertTrue(
                    torch.allclose(out_torch, out_fft, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    col, transposed=False, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))

            row = torch.randn(n, dtype=dtype)
            C = la.circulant(row.numpy()).T
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            # The first column of the transposed circulant matrix is row reversed,
            # except that the 0-th element stays put.
            # This corresponds to the same reversal in the frequency domain.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
            row_f = np.fft.fft(row.numpy())
            row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * row_f_reversed),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    row, transposed=True, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))
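
The reversal identity used above (time reversal with the 0-th element fixed corresponds to the same reversal in the frequency domain) is easy to check in NumPy:

    import numpy as np

    n = 13
    row = np.random.randn(n)
    # Time reversal with the 0-th element kept in place
    col = np.concatenate((row[:1], row[1:][::-1]))
    row_f = np.fft.fft(row)
    row_f_reversed = np.concatenate((row_f[:1], row_f[1:][::-1]))
    assert np.allclose(np.fft.fft(col), row_f_reversed)
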
Example #4
 def test_conv2d_circular_multichannel(self):
     batch_size = 10
     in_channels = 3
     out_channels = 4
     for n1 in [13, 16]:
         for n2 in [27, 32]:
             # flatten is only supported for powers of 2 for now
             if n1 == 1 << int(math.log2(n1)) and n2 == 1 << int(
                     math.log2(n2)):
                 flatten_cases = [False, True]
             else:
                 flatten_cases = [False]
             for kernel_size1 in [1, 3, 5, 7]:
                 for kernel_size2 in [1, 3, 5, 7]:
                     padding1 = (kernel_size1 - 1) // 2
                     padding2 = (kernel_size2 - 1) // 2
                     conv = nn.Conv2d(in_channels,
                                      out_channels,
                                      (kernel_size2, kernel_size1),
                                      padding=(padding2, padding1),
                                      padding_mode='circular',
                                      bias=False)
                     weight = conv.weight
                     input = torch.randn(batch_size, in_channels, n2, n1)
                     out_torch = conv(input)
                     # Just to show how to implement conv2d with FFT
                     input_f = view_as_complex(
                         torch.rfft(input, signal_ndim=2))
                     col = F.pad(weight.flip(dims=(-1, )),
                                 (0, n1 - kernel_size1)).roll(-padding1,
                                                              dims=-1)
                     col = F.pad(col.flip(dims=(-2, )),
                                 (0, 0, 0, n2 - kernel_size2)).roll(
                                     -padding2, dims=-2)
                     col_f = view_as_complex(torch.rfft(col, signal_ndim=2))
                     prod_f = complex_mul(input_f.unsqueeze(1),
                                          col_f).sum(dim=2)
                     out_fft = torch.irfft(view_as_real(prod_f),
                                           signal_ndim=2,
                                           signal_sizes=(n2, n1))
                     self.assertTrue(
                         torch.allclose(out_torch, out_fft, self.rtol,
                                        self.atol))
                     for flatten in flatten_cases:
                         b = torch_butterfly.special.conv2d_circular_multichannel(
                             n1, n2, weight, flatten=flatten)
                         out = b(input)
                         self.assertTrue(
                             torch.allclose(out, out_torch, self.rtol,
                                            self.atol))
Example #5
 def forward(self, input, transpose=False, conjugate=False, subtwiddle=False):
     """
     Parameters:
         input: (batch, *, in_size)
         transpose: whether the butterfly matrix should be transposed.
         conjugate: whether the butterfly matrix should be conjugated.
         subtwiddle: allow using only part of the parameters for smaller input.
             Could be useful for weight sharing.
             out_size is set to self.nstacks * self.in_size_extended in this case
     Return:
         output: (batch, *, out_size)
     """
     twiddle = self.twiddle if not self.complex else view_as_complex(self.twiddle)
     if not subtwiddle:
         output = self.pre_process(input)
     else:
         log_n = int(math.ceil(math.log2(input.size(-1))))
         n = 1 << log_n
         output = self.pre_process(input, padded_size=n)
         twiddle = (twiddle[:, :, :log_n, :n // 2] if self.increasing_stride
                    else twiddle[:, :, -log_n:, :n // 2])
     if conjugate and self.complex:
         twiddle = twiddle.conj()
     if not transpose:
         output = butterfly_multiply(twiddle, output, self.increasing_stride)
     else:
         twiddle = twiddle.transpose(-1, -2).flip([1, 2])
         last_increasing_stride = self.increasing_stride != ((self.nblocks - 1) % 2 == 1)
         output = butterfly_multiply(twiddle, output, not last_increasing_stride)
     if not subtwiddle:
         return self.post_process(input, output)
     else:
         return self.post_process(input, output, out_size=output.size(-1))
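
A plausible usage sketch of the transpose/conjugate flags (an inference from the semantics above, not a documented example): with init='ortho', no bias, and a power-of-2 square size, the complex Butterfly is unitary, so applying the conjugate transpose should invert it.

    import torch
    import torch_butterfly

    b = torch_butterfly.Butterfly(16, 16, bias=False, complex=True, init='ortho')
    x = torch.randn(3, 16, dtype=torch.complex64)
    y = b(x)
    # transpose + conjugate = conjugate transpose = inverse of a unitary map
    x_rec = b(y, transpose=True, conjugate=True)
    assert torch.allclose(x_rec, x, atol=1e-5)
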
Example #6
def conv1d_circular_multichannel(n, weight) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). Kernel_size must be
                odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding = (kernel_size - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding,
                                                              dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.

    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f),
                             b_ifft, Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
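
The padding trick in the comment above ([a, b, c] -> [a, b, c, 0, 0, a, b, c]) embeds an n x n circulant into one of size 1 << (log_n + 1): zero-padding the input and keeping the first n outputs recovers the original product. A NumPy check of that claim:

    import numpy as np
    from scipy import linalg as la

    n, n_ext = 3, 8  # n_ext = 1 << (ceil(log2(n)) + 1)
    col = np.array([1.0, 2.0, 3.0])
    col_ext = np.concatenate((col, np.zeros(n_ext - 2 * n), col))  # [a, b, c, 0, 0, a, b, c]
    x = np.random.randn(n)
    x_ext = np.concatenate((x, np.zeros(n_ext - n)))
    assert np.allclose((la.circulant(col_ext) @ x_ext)[:n], la.circulant(col) @ x)
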
Example #7
 def forward(self, input):
     """
     Parameters:
         input: (batch, *, size)
     Return:
         output: (batch, *, size)
     """
     diagonal = self.diagonal if not self.complex else view_as_complex(
         self.diagonal)
     return complex_mul(input, diagonal)
Example #8
 def forward(self, input):
     """
     Parameters:
         input: (batch, in_channels, size)
     Return:
         output: (batch, out_channels, size)
     """
     diagonal = self.diagonal if not self.complex else view_as_complex(
         self.diagonal)
     return complex_mul(input.unsqueeze(1), diagonal).sum(dim=2)
Example #9
 def forward(self,
             input,
             transpose=False,
             conjugate=False,
             subtwiddle=False):
     """
     Parameters:
         input: (batch, *, in_size)
         transpose: whether the butterfly matrix should be transposed.
         conjugate: whether the butterfly matrix should be conjugated.
         subtwiddle: allow using only part of the parameters for smaller input.
             Could be useful for weight sharing.
             out_size is set to self.nstacks * self.in_size_extended in this case
     Return:
         output: (batch, *, out_size)
     """
     phi, alpha, psi, chi = torch.unbind(self.twiddle, -1)
     c, s = torch.cos(phi), torch.sin(phi)
     # Pytorch 1.6.0 doesn't support complex exp on GPU so we have to use cos/sin
     A = torch.stack(
         (c * torch.cos(alpha + psi), c * torch.sin(alpha + psi)), dim=-1)
     B = torch.stack(
         (s * torch.cos(alpha + chi), s * torch.sin(alpha + chi)), dim=-1)
     C = torch.stack(
         (-s * torch.cos(alpha - chi), -s * torch.sin(alpha - chi)), dim=-1)
     D = torch.stack(
         (c * torch.cos(alpha - psi), c * torch.sin(alpha - psi)), dim=-1)
     twiddle = torch.stack(
         [torch.stack([A, B], dim=-2),
          torch.stack([C, D], dim=-2)], dim=-3)
     twiddle = view_as_complex(twiddle)
     if not subtwiddle:
         output = self.pre_process(input)
     else:
         log_n = int(math.ceil(math.log2(input.size(-1))))
         n = 1 << log_n
         output = self.pre_process(input, padded_size=n)
         twiddle = (twiddle[:, :, :log_n, :n // 2] if self.increasing_stride
                    else twiddle[:, :, -log_n:, :n // 2])
     if conjugate and self.complex:
         twiddle = twiddle.conj()
     if not transpose:
         output = butterfly_multiply(twiddle, output,
                                     self.increasing_stride)
     else:
         twiddle = twiddle.transpose(-1, -2).flip([1, 2])
         last_increasing_stride = self.increasing_stride != (
             (self.nblocks - 1) % 2 == 1)
         output = butterfly_multiply(twiddle, output,
                                     not last_increasing_stride)
     if not subtwiddle:
         return self.post_process(input, output)
     else:
         return self.post_process(input, output, out_size=output.size(-1))
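
On PyTorch versions with complex exp support on GPU (1.8+), the four cos/sin stacks collapse to complex exponentials; a sketch of the equivalence for the A entry (B, C, D are analogous):

    import torch

    phi, alpha, psi = torch.rand(3, 100).unbind(0)
    c = torch.cos(phi)
    A_pair = torch.stack((c * torch.cos(alpha + psi), c * torch.sin(alpha + psi)), dim=-1)
    A_complex = c * torch.exp(1j * (alpha + psi))
    assert torch.allclose(torch.view_as_complex(A_pair), A_complex, atol=1e-6)
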
Example #10
 def post_process(self, input, output, out_size=None):
     if out_size is None:
         out_size = self.out_size
     batch = output.shape[0]
     output = output.view(batch, self.matrix_batch, self.nstacks * output.size(-1))
     out_size_extended = 1 << (int(math.ceil(math.log2(out_size))))
     if out_size != out_size_extended:  # Take top rows
         output = output[:, :, :out_size]
     if self.bias is not None:
         bias = self.bias if not self.complex else view_as_complex(self.bias)
         output = output + bias[:, :out_size]
     return output.view(*input.size()[:-2], self.matrix_batch, out_size)
Example #11
def butterfly_kronecker(butterfly1: Butterfly,
                        butterfly2: Butterfly) -> Butterfly:
    """Combine two butterflies of size n1 and n2 into their Kronecker product of size n1 * n2.
    They must both have increasing_stride=True or increasing_stride=False.
    If butterfly1 or butterfly2 has padding, then the Kronecker product (after flattening the
    input) will not produce the same result unless the input is padded in the same way before
    flattening.

    Only support nstacks==1, nblocks==1 for now.
    """
    assert butterfly1.increasing_stride == butterfly2.increasing_stride
    assert butterfly1.complex == butterfly2.complex
    assert not butterfly1.bias and not butterfly2.bias
    assert butterfly1.nstacks == 1 and butterfly2.nstacks == 1
    assert butterfly1.nblocks == 1 and butterfly2.nblocks == 1
    increasing_stride = butterfly1.increasing_stride
    complex = butterfly1.complex
    log_n1 = butterfly1.twiddle.shape[2]
    log_n2 = butterfly2.twiddle.shape[2]
    log_n = log_n1 + log_n2
    n = 1 << log_n
    twiddle1 = butterfly1.twiddle if not complex else view_as_complex(
        butterfly1.twiddle)
    twiddle2 = butterfly2.twiddle if not complex else view_as_complex(
        butterfly2.twiddle)
    twiddle1 = twiddle1.repeat(1, 1, 1, 1 << log_n2, 1, 1)
    twiddle2 = twiddle2.repeat_interleave(1 << log_n1, dim=3)
    twiddle = (torch.cat(
        (twiddle1, twiddle2), dim=2) if increasing_stride else torch.cat(
            (twiddle2, twiddle1), dim=2))
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=increasing_stride).to(twiddle.device)
    b.in_size = butterfly1.in_size * butterfly2.in_size
    b.out_size = butterfly1.out_size * butterfly2.out_size
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        b_twiddle.copy_(twiddle)
    return b
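
A sanity-check sketch for the Kronecker combination. The dense-matrix extraction and the operand order of torch.kron are assumptions about the layout (the first butterfly acting on the fastest-varying index):

    import torch
    import torch_butterfly
    from torch_butterfly.combine import butterfly_kronecker  # assumed import path

    n1, n2 = 4, 8
    b1 = torch_butterfly.Butterfly(n1, n1, bias=False, complex=True)
    b2 = torch_butterfly.Butterfly(n2, n2, bias=False, complex=True)
    bk = butterfly_kronecker(b1, b2)
    # Feeding the identity as a batch of unit vectors yields the transpose of the dense matrix
    M1 = b1(torch.eye(n1, dtype=torch.complex64)).t()
    M2 = b2(torch.eye(n2, dtype=torch.complex64)).t()
    Mk = bk(torch.eye(n1 * n2, dtype=torch.complex64)).t()
    assert torch.allclose(Mk, torch.kron(M2, M1), atol=1e-5)
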
Example #12
def ifft(n, normalized=False, br_first=True, with_br_perm=True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs the inverse FFT.
    Parameters:
        n: size of the iFFT. Must be a power of 2.
        normalized: if True, corresponds to unitary iFFT (i.e. multiplied by 1/sqrt(n), not 1/n)
        br_first: which decomposition of iFFT. True corresponds to decimation-in-frequency.
                  False corresponds to decimation-in-time.
        with_br_perm: whether to return both the butterfly and the bit reversal permutation.
    """
    log_n = int(math.ceil(math.log2(n)))
    assert n == 1 << log_n, 'n must be a power of 2'
    factors = []
    for log_size in range(1, log_n + 1):
        size = 1 << log_size
        exp = torch.exp(2j * math.pi * torch.arange(0.0, size // 2) / size)
        o = torch.ones_like(exp)
        twiddle_factor = torch.stack((torch.stack(
            (o, exp), dim=-1), torch.stack((o, -exp), dim=-1)),
                                     dim=-2)
        factors.append(twiddle_factor.repeat(n // size, 1, 1))
    twiddle = torch.stack(factors, dim=0).unsqueeze(0).unsqueeze(0)
    if not br_first:  # Take conjugate transpose of the BP decomposition of fft
        twiddle = twiddle.transpose(-1, -2).flip([2])
    # If normalized, divide the whole transform by sqrt(n) by dividing each of the log_n factors
    # by n^(1/(2 log_n)) = sqrt(2); otherwise divide each factor by 2 for an overall factor of 1/n.
    if normalized:
        twiddle /= math.sqrt(2)
    else:
        twiddle /= 2
    b = Butterfly(n, n, bias=False, complex=True, increasing_stride=br_first)
    with torch.no_grad():
        view_as_complex(b.twiddle).copy_(twiddle)
    if with_br_perm:
        br_perm = FixedPermutation(
            bitreversal_permutation(n, pytorch_format=True))
        return nn.Sequential(br_perm, b) if br_first else nn.Sequential(
            b, br_perm)
    else:
        return b
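
Against the torch.fft module (PyTorch 1.8+), normalized=True should match norm="ortho" and normalized=False the default 1/n scaling of torch.fft.ifft. A quick check (import path assumed):

    import torch
    from torch_butterfly.special import ifft  # assumed import path

    n = 16
    x = torch.randn(10, n, dtype=torch.complex64)
    assert torch.allclose(ifft(n, normalized=True)(x), torch.fft.ifft(x, norm="ortho"), atol=1e-5)
    assert torch.allclose(ifft(n, normalized=False)(x), torch.fft.ifft(x), atol=1e-5)
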
Example #13
 def test_ifft_unitary(self):
     batch_size = 10
     n = 16
     input = torch.randn(batch_size, n, dtype=torch.complex64)
     normalized = True
     out_torch = view_as_complex(
         torch.ifft(view_as_real(input),
                    signal_ndim=1,
                    normalized=normalized))
     for br_first in [True, False]:
         b = torch_butterfly.special.ifft_unitary(n, br_first=br_first)
         out = b(input)
         self.assertTrue(
             torch.allclose(out, out_torch, self.rtol, self.atol))
Example #14
def perm2butterfly(v: Union[np.ndarray, torch.Tensor],
                   complex: bool = False,
                   increasing_stride: bool = False) -> Butterfly:
    """
    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    if n < 1 << log_n:  # Pad permutation to the next power-of-2 size
        v = np.concatenate([v, np.arange(n, 1 << log_n)])
    if increasing_stride:  # Follow proof of Lemma G.6
        br = bitreversal_permutation(1 << log_n)
        b = perm2butterfly(br[v[br]], complex=complex, increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation((1 << log_n) // 2,
                                          pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b
    v = v[None]
    twiddle_right_factors, twiddle_left_factors = [], []
    for _ in range(log_n):
        right_factor, left_factor, v = outer_twiddle_factors(v)
        twiddle_right_factors.append(right_factor)
        twiddle_left_factors.append(left_factor)
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=False,
                  nblocks=2)
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        twiddle = torch.stack([
            torch.stack(twiddle_right_factors),
            torch.stack(twiddle_left_factors).flip([0])
        ]).unsqueeze(0)
        b_twiddle.copy_(twiddle if not complex else real2complex(twiddle))
    return b
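
A usage sketch (import path assumed): the returned Butterfly applied to x should reproduce the gather x[..., v], including for non-power-of-2 sizes, which are padded internally.

    import torch
    from torch_butterfly.permutation import perm2butterfly  # assumed import path

    n = 12
    v = torch.randperm(n)
    b = perm2butterfly(v)
    x = torch.randn(2, n)
    assert torch.allclose(b(x), x[:, v])
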
Example #15
 def reset_parameters(self):
     """Initialize bias the same way as torch.nn.Linear."""
     twiddle = self.twiddle if not self.complex else view_as_complex(
         self.twiddle)
     if self.init == 'randn':
         # complex randn already has the correct scaling of stddev=1.0
         scaling = 1.0 / math.sqrt(2)
         with torch.no_grad():
             twiddle.copy_(
                 torch.randn(twiddle.shape, dtype=twiddle.dtype) * scaling)
     elif self.init == 'ortho':
         twiddle_core_shape = twiddle.shape[:-2]
         if not self.complex:
             theta = torch.rand(twiddle_core_shape) * math.pi * 2
             c, s = torch.cos(theta), torch.sin(theta)
             det = torch.randint(
                 0, 2, twiddle_core_shape,
                 dtype=c.dtype) * 2 - 1  # Rotation (+1) or reflection (-1)
             with torch.no_grad():
                 twiddle.copy_(
                     torch.stack(
                         (torch.stack((det * c, -det * s),
                                      dim=-1), torch.stack((s, c), dim=-1)),
                         dim=-2))
         else:
             # Sampling from the Haar measure on U(2) is a bit subtle.
             # Using the parameterization here: http://home.lu.lv/~sd20008/papers/essays/Random%20unitary%20[paper].pdf
             phi = torch.asin(torch.sqrt(torch.rand(twiddle_core_shape)))
             c, s = torch.cos(phi), torch.sin(phi)
             alpha, psi, chi = torch.rand((3, ) +
                                          twiddle_core_shape) * math.pi * 2
             A = torch.exp(1j * (alpha + psi)) * c
             B = torch.exp(1j * (alpha + chi)) * s
             C = -torch.exp(1j * (alpha - chi)) * s
             D = torch.exp(1j * (alpha - psi)) * c
             with torch.no_grad():
                 twiddle.copy_(
                     torch.stack((torch.stack(
                         (A, B), dim=-1), torch.stack((C, D), dim=-1)),
                                 dim=-2))
     elif self.init == 'identity':
         twiddle_new = torch.eye(2, dtype=twiddle.dtype).reshape(
             1, 1, 1, 1, 2, 2)
         twiddle_new = twiddle_new.expand(*twiddle.shape).contiguous()
         with torch.no_grad():
             twiddle.copy_(twiddle_new)
     if self.bias is not None:
         bound = 1 / math.sqrt(self.in_size)
         nn.init.uniform_(self.bias, -bound, bound)
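
The sampled 2 x 2 complex factor is e^{i*alpha} [[e^{i*psi} cos(phi), e^{i*chi} sin(phi)], [-e^{-i*chi} sin(phi), e^{-i*psi} cos(phi)]], which is unitary for any phi, alpha, psi, chi. A numeric spot check of that algebra:

    import math
    import torch

    phi, alpha, psi, chi = torch.rand(4) * 2 * math.pi
    c, s = torch.cos(phi), torch.sin(phi)
    U = torch.stack((
        torch.stack((torch.exp(1j * (alpha + psi)) * c, torch.exp(1j * (alpha + chi)) * s)),
        torch.stack((-torch.exp(1j * (alpha - chi)) * s, torch.exp(1j * (alpha - psi)) * c)),
    ))
    assert torch.allclose(U @ U.conj().t(), torch.eye(2, dtype=torch.complex64), atol=1e-5)
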
Example #16
def diagonal_butterfly(butterfly: Butterfly,
                       diagonal: torch.Tensor,
                       diag_first: bool,
                       inplace: bool = True) -> Butterfly:
    """
    Combine a Butterfly and a diagonal into another Butterfly.
    Only support nstacks==1 for now.
    Parameters:
        butterfly: Butterfly(in_size, out_size)
        diagonal: size (in_size,) if diag_first, else (out_size,). Should be of type complex
            if butterfly.complex == True.
        diag_first: If True, the map is input -> diagonal -> butterfly.
            If False, the map is input -> butterfly -> diagonal.
        inplace: whether to modify the input Butterfly
    """
    assert butterfly.nstacks == 1
    assert butterfly.bias is None
    twiddle = (butterfly.twiddle.clone() if not butterfly.complex else
               view_as_complex(butterfly.twiddle).clone())
    n = 1 << twiddle.shape[2]
    if diagonal.shape[-1] < n:
        diagonal = F.pad(diagonal, (0, n - diagonal.shape[-1]), value=1)
    if diag_first:
        if butterfly.increasing_stride:
            twiddle[:, 0, 0, :, :, 0] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[1::2].unsqueeze(-1)
        else:
            n = diagonal.shape[-1]
            twiddle[:, 0, 0, :, :, 0] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, 0, 0, :, :, 1] *= diagonal[n // 2:].unsqueeze(-1)
    else:
        # Whether the last block is increasing or decreasing stride
        increasing_stride = butterfly.increasing_stride != (
            (butterfly.nblocks - 1) % 2 == 1)
        if increasing_stride:
            n = diagonal.shape[-1]
            twiddle[:, -1, -1, :, 0, :] *= diagonal[:n // 2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[n // 2:].unsqueeze(-1)
        else:
            twiddle[:, -1, -1, :, 0, :] *= diagonal[::2].unsqueeze(-1)
            twiddle[:, -1, -1, :, 1, :] *= diagonal[1::2].unsqueeze(-1)
    out_butterfly = butterfly if inplace else copy.deepcopy(butterfly)
    with torch.no_grad():
        out_butterfly.twiddle.copy_(
            twiddle if not butterfly.complex else view_as_real(twiddle))
    return out_butterfly
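
A usage sketch of the contract (import paths assumed): with diag_first=True, the combined Butterfly should compute butterfly(diagonal * input).

    import torch
    from torch_butterfly.special import fft                  # assumed import path
    from torch_butterfly.combine import diagonal_butterfly   # assumed import path

    n = 8
    b = fft(n, with_br_perm=False)
    diag = torch.randn(n, dtype=torch.complex64)
    combined = diagonal_butterfly(b, diag, diag_first=True, inplace=False)
    x = torch.randn(3, n, dtype=torch.complex64)
    assert torch.allclose(combined(x), b(diag * x), atol=1e-5)
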
Example #17
 def test_butterfly(self):
     batch_size = 10
     for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
         for in_size, out_size in [(7, 15), (15, 7)]:
             for complex in [False, True]:
                 for increasing_stride in [True, False]:
                     for init in ['randn', 'ortho', 'identity']:
                         for nblocks in [1, 2, 3]:
                             b = torch_butterfly.Butterfly(
                                 in_size,
                                 out_size,
                                 True,
                                 complex,
                                 increasing_stride,
                                 init,
                                 nblocks=nblocks).to(device)
                             dtype = torch.float32 if not complex else torch.complex64
                             input = torch.randn(batch_size,
                                                 in_size,
                                                 dtype=dtype,
                                                 device=device)
                             output = b(input)
                             self.assertTrue(
                                 output.shape == (batch_size, out_size),
                                 (output.shape, device, (in_size, out_size),
                                  complex, init, nblocks))
                             if init == 'ortho':
                                 twiddle = b.twiddle if not b.complex else view_as_complex(
                                     b.twiddle)
                                 twiddle_np = twiddle.detach().to(
                                     'cpu').numpy()
                                 twiddle_np = twiddle_np.reshape(-1, 2, 2)
                                 twiddle_norm = np.linalg.norm(twiddle_np,
                                                               ord=2,
                                                               axis=(1, 2))
                                 self.assertTrue(
                                     np.allclose(twiddle_norm, 1),
                                     (twiddle_norm, device,
                                      (in_size, out_size), complex, init))
Example #18
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        col_f = torch.cat((col_f[:1], col_f[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    diag = col_f[..., br_perm]
    if separate_diagonal:
        if not complex:
            return nn.Sequential(Real2Complex(), b_fft,
                                 Diagonal(diagonal_init=diag), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
    else:
        # Combine the diagonal with the last twiddle factor of b_fft
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft,
                                       diag,
                                       diag_first=False,
                                       inplace=True)
        # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
        b = butterfly_product(b_fft, b_ifft)
        b.in_size = n
        b.out_size = n
        return b if complex else nn.Sequential(Real2Complex(), b,
                                               Complex2Real())
Example #19
def perm2butterfly_slow(v: Union[np.ndarray, torch.Tensor],
                        complex: bool = False,
                        increasing_stride: bool = False) -> Butterfly:
    """
    Convert a permutation to a Butterfly that performs the same permutation.
    This implementation is slower but follows the proofs in Appendix G more closely.
    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    if n < 1 << log_n:  # Pad permutation to the next power-of-2 size
        v = np.concatenate([v, np.arange(n, 1 << log_n)])
    if increasing_stride:  # Follow proof of Lemma G.6
        br = bitreversal_permutation(1 << log_n)
        b = perm2butterfly_slow(br[v[br]],
                                complex=complex,
                                increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation((1 << log_n) // 2,
                                          pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b
    # modular_balance expects right-multiplication format so we convert the format of v.
    Rinv_perms, L_vec = modular_balance(invert(v))
    L_perms = list(reversed(modular_balanced_to_butterfly_factor(L_vec)))
    R_perms = [
        perm_vec_to_mat(invert(p), left=True) for p in reversed(Rinv_perms)
    ]
    # Stored in increasing_stride=True twiddle format.
    # Need to take transpose because the matrices are in right-multiplication format.
    L_twiddle = torch.stack([
        matrix_to_butterfly_factor(l.T, log_k=i + 1, pytorch_format=True)
        for i, l in enumerate(L_perms)
    ])
    # Stored in increasing_stride=False twiddle format so we need to flip the order
    R_twiddle = torch.stack([
        matrix_to_butterfly_factor(r, log_k=i + 1, pytorch_format=True)
        for i, r in enumerate(R_perms)
    ]).flip([0])
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=False,
                  nblocks=2)
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        twiddle = torch.stack([R_twiddle, L_twiddle]).unsqueeze(0)
        b_twiddle.copy_(twiddle if not complex else real2complex(twiddle))
    return b
Example #20
def conv2d_circular_multichannel(n1: int,
                                 n2: int,
                                 weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
            Both kernel sizes must be odd, and smaller than n2 and n1 respectively. Padding is
            assumed to be (kernel_size - 1) // 2 in each dimension.
        flatten: whether to internally flatten the last 2 dimensions of the input. Only support n1
            and n2 being powers of 2.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    assert kernel_size1 < n1 and kernel_size2 < n2
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    out_channels, in_channels = weight.shape[:2]
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1,
                                                                dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                                   dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1,
                  n_extended2,
                  normalized=True,
                  br_first=False,
                  with_br_perm=False,
                  flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and Unflatten2D
    b_ifft = ifft2d(n_extended1,
                    n_extended2,
                    normalized=True,
                    br_first=True,
                    with_br_perm=False,
                    flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and Unflatten2D
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = view_as_complex(
        torch.fft(view_as_real(col), signal_ndim=2, normalized=False))
    br_perm1 = bitreversal_permutation(n_extended1,
                                       pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2,
                                       pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error "shape mismatch: indexing tensors could not be
    # broadcast together"
    col_f = col_f[..., br_perm2, :][..., br_perm1]
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2],
                              col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2),
                                 b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 Unflatten2D(n1), Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Unflatten2D(n1))
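
On PyTorch 1.8+, where torch.rfft and torch.fft(tensor, signal_ndim=...) were removed, the 2-D eigenvalue computation above would use torch.fft.fft2; a sketch (not part of the library):

    import torch

    # col: (out_channels, in_channels, n2, n1), already complex, as constructed above
    col = torch.randn(4, 3, 16, 32, dtype=torch.complex64)
    # The unnormalized 2-D FFT gives the eigenvalues of the doubly-circulant matrix;
    # norm="backward" (the default) matches normalized=False
    col_f = torch.fft.fft2(col)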