예제 #1
0
 def test_fft2d_init(self):
     """Check that FFT-initialized Butterflies reproduce a circular multi-channel nn.Conv2d.

     For each pair of odd kernel sizes in {1, 3, 5, 7} and each nblocks in {1, 2, 3}, the
     convolution is recomputed as K2(complex_matmul(K1(x), Kd(w))).real, where K1 and Kd are
     TensorProducts of Butterflies initialized with 'fft_no_br' and K2 with 'ifft_no_br'.
     The result is compared with nn.Conv2d(padding_mode='circular') using self.rtol/self.atol.
     """
     batch_size = 10
     in_channels = 3
     out_channels = 4
     n1, n2 = 16, 32
     x = torch.randn(batch_size, in_channels, n2, n1)
     for kernel_size1 in [1, 3, 5, 7]:
         for kernel_size2 in [1, 3, 5, 7]:
             padding1 = (kernel_size1 - 1) // 2
             padding2 = (kernel_size2 - 1) // 2
             conv = nn.Conv2d(in_channels,
                              out_channels, (kernel_size2, kernel_size1),
                              padding=(padding2, padding1),
                              padding_mode='circular',
                              bias=False)
             out_torch = conv(x)
             weight = conv.weight
             # Build the first column of the equivalent circulant operator along each spatial
             # dim: flip the kernel, zero-pad it to the input size, and roll it so that the
             # kernel centre (index `padding`) lands at index 0.
             w = F.pad(weight.flip(dims=(-1, )),
                       (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
             w = F.pad(w.flip(dims=(-2, )),
                       (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                          dims=-2)
             increasing_strides = [False, False, True]
             inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
             for nblocks in [1, 2, 3]:
                 # The Butterflies must be complex: (i)FFT twiddles are complex-valued and only
                 # the real part of the final output is compared. Pass an explicit bool instead
                 # of the builtin type `complex` (which is merely truthy).
                 Kd, K1, K2 = [
                     TensorProduct(
                         Butterfly(n1,
                                   n1,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks),
                         Butterfly(n2,
                                   n2,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks))
                     for incstride, i in zip(increasing_strides, inits)
                 ]
                 with torch.no_grad():
                     # Rescale Kd so it produces the unnormalized spectrum of w.
                     # NOTE(review): assumes the 'fft_no_br' init is scaled by 1/sqrt(n) and
                     # that Butterfly supports in-place scalar multiplication -- confirm.
                     Kd.map1 *= math.sqrt(n1)
                     Kd.map2 *= math.sqrt(n2)
                 out = K2(
                     complex_matmul(
                         K1(real2complex(x)).permute(2, 3, 0, 1),
                         Kd(real2complex(w)).permute(2, 3, 1, 0)).permute(
                             2, 3, 0, 1)).real
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
예제 #2
0
def perm2butterfly_slow(v: Union[np.ndarray, torch.Tensor],
                        complex: bool = False,
                        increasing_stride: bool = False) -> Butterfly:
    """
    Convert a permutation to a Butterfly that performs the same permutation.
    This implementation is slower but follows the proofs in Appendix G more closely.
    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly with nblocks=2 that performs the same permutation as v. If len(v) is not
            a power of 2, the permutation is internally extended with fixed points up to the
            next power of 2, while in_size/out_size stay len(v).
    """
    # Work with a NumPy array from here on (np.concatenate and fancy indexing below).
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    # The appended indices n, ..., 2**log_n - 1 map to themselves (fixed points).
    if n < 1 << log_n:  # Pad permutation to the next power-of-2 size
        v = np.concatenate([v, np.arange(n, 1 << log_n)])
    if increasing_stride:  # Follow proof of Lemma G.6
        # Conjugate v by the bit-reversal permutation (br[v[br]]) and solve the
        # increasing_stride=False case for it; then flip the stride flag and bit-reverse
        # the last twiddle axis to get the increasing-stride Butterfly.
        br = bitreversal_permutation(1 << log_n)
        b = perm2butterfly_slow(br[v[br]],
                                complex=complex,
                                increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation((1 << log_n) // 2,
                                          pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        # Expose the caller's (unpadded) size rather than the padded one.
        b.in_size = b.out_size = n
        return b
    # modular_balance expects right-multiplication format so we convert the format of v.
    Rinv_perms, L_vec = modular_balance(invert(v))
    L_perms = list(reversed(modular_balanced_to_butterfly_factor(L_vec)))
    R_perms = [
        perm_vec_to_mat(invert(p), left=True) for p in reversed(Rinv_perms)
    ]
    # Stored in increasing_stride=True twiddle format.
    # Need to take transpose because the matrices are in right-multiplication format.
    L_twiddle = torch.stack([
        matrix_to_butterfly_factor(l.T, log_k=i + 1, pytorch_format=True)
        for i, l in enumerate(L_perms)
    ])
    # Stored in increasing_stride=False twiddle format so we need to flip the order
    R_twiddle = torch.stack([
        matrix_to_butterfly_factor(r, log_k=i + 1, pytorch_format=True)
        for i, r in enumerate(R_perms)
    ]).flip([0])
    # Stack the R factors and then the L factors along a new axis (one entry per block,
    # matching nblocks=2 below), plus a leading singleton axis.
    twiddle = torch.stack([R_twiddle, L_twiddle]).unsqueeze(0)
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=False,
                  init=twiddle if not complex else real2complex(twiddle),
                  nblocks=2)
    return b
예제 #3
0
파일: special.py 프로젝트: sfox14/butterfly
def conv1d_circular_multichannel(n, weight) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv1d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv1d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n: size of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size). Kernel_size must be
                odd, and smaller than n. Padding is assumed to be (kernel_size - 1) // 2.
    Return:
        nn.Module computing FFT Butterfly -> DiagonalMultiplySum -> iFFT Butterfly, wrapped with
        Real2Complex / Complex2Real when weight is real.
    """
    assert weight.dim() == 3, 'Weight must have dimension 3'
    kernel_size = weight.shape[-1]
    assert kernel_size < n
    assert kernel_size % 2 == 1, 'Kernel size must be odd'
    padding = (kernel_size - 1) // 2
    # First column of the circulant operator for each (out, in) channel pair: flip the kernel,
    # zero-pad it to length n and roll it so the kernel centre sits at index 0.
    col = F.pad(weight.flip([-1]), (0, n - kernel_size)).roll(-padding,
                                                              dims=-1)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if n == 1 << log_n else 1 << (log_n + 1)
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((col_0, col), dim=-1)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must be unnormalized for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    if callable(getattr(torch.fft, 'fft', None)):
        # `torch.fft` module API (the only form available from PyTorch 1.8 on): operates on
        # native complex tensors along the last dim; norm=None applies no forward scaling.
        col_f = torch.fft.fft(col, norm=None)
    else:
        # Legacy PyTorch where `torch.fft` is a function on real (..., 2) tensors.
        col_f = view_as_complex(
            torch.fft(view_as_real(col), signal_ndim=1, normalized=False))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    col_f = col_f[..., br_perm]
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as matrix multiply but Pytorch 1.6 doesn't yet support complex matrix
    # multiply.
    if not complex:
        return nn.Sequential(Real2Complex(), b_fft, DiagonalMultiplySum(col_f),
                             b_ifft, Complex2Real())
    else:
        return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
예제 #4
0
def perm2butterfly(v: Union[np.ndarray, torch.Tensor],
                   complex: bool = False,
                   increasing_stride: bool = False) -> Butterfly:
    """Build a Butterfly (nblocks=2) that applies the permutation ``v``.

    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v. Lengths that are not a power of
            2 are handled by extending v with fixed points; in_size/out_size stay len(v).
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    size_pow2 = 1 << log_n
    if n < size_pow2:
        # Extend with fixed points so the permutation acts on a power-of-2 length.
        v = np.concatenate([v, np.arange(n, size_pow2)])

    if increasing_stride:
        # Lemma G.6: conjugate by bit reversal, solve the decreasing-stride case, then
        # bit-reverse the last twiddle axis and flip the stride direction.
        br = bitreversal_permutation(size_pow2)
        butterfly = perm2butterfly(br[v[br]], complex=complex, increasing_stride=False)
        butterfly.increasing_stride = True
        br_half = bitreversal_permutation(size_pow2 // 2, pytorch_format=True)
        with torch.no_grad():
            butterfly.twiddle.copy_(butterfly.twiddle[:, :, :, br_half])
        butterfly.in_size = butterfly.out_size = n
        return butterfly

    # Lemma G.3: outer_twiddle_factors is applied log_n times; each call yields one right
    # factor, one left factor, and the remaining sub-problem for the next call.
    rights, lefts = [], []
    remaining = v[None]
    for _level in range(log_n):
        right, left, remaining = outer_twiddle_factors(remaining)
        rights.append(right)
        lefts.append(left)

    butterfly = Butterfly(n,
                          n,
                          bias=False,
                          complex=complex,
                          increasing_stride=False,
                          nblocks=2)
    with torch.no_grad():
        # First entry: right factors in order; second entry: left factors in reverse order.
        stacked = torch.stack(
            [torch.stack(rights), torch.stack(lefts).flip([0])]).unsqueeze(0)
        if complex:
            view_as_complex(butterfly.twiddle).copy_(real2complex(stacked))
        else:
            butterfly.twiddle.copy_(stacked)
    return butterfly
예제 #5
0
def conv2d_circular_multichannel(n1: int,
                                 n2: int,
                                 weight: torch.Tensor,
                                 flatten: bool = False) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs nn.Conv2d
    with multiple in/out channels, with circular padding.
    The output of nn.Conv2d must have the same size as the input (i.e. kernel size must be 2k + 1,
    and padding k for some integer k).
    Parameters:
        n1: size of the last dimension of the input.
        n2: size of the second to last dimension of the input.
        weight: torch.Tensor of size (out_channels, in_channels, kernel_size2, kernel_size1).
            Kernel_size must be odd, and smaller than n1/n2. Padding is assumed to be
            (kernel_size - 1) // 2.
        flatten: whether to internally flatten the last 2 dimensions of the input. Only support n1
            and n2 being powers of 2.
    Return:
        nn.Module computing FFT2d Butterfly -> DiagonalMultiplySum -> iFFT2d Butterfly, wrapped
        with Real2Complex / Complex2Real when weight is real, and with
        nn.Flatten / nn.Unflatten(-1, (n2, n1)) when flatten=True.
    """
    assert weight.dim() == 4, 'Weight must have dimension 4'
    kernel_size2, kernel_size1 = weight.shape[-2], weight.shape[-1]
    # Check both sizes: in `assert a, b` the second expression is only the failure message.
    assert kernel_size1 < n1 and kernel_size2 < n2, 'Kernel size must be smaller than n1/n2'
    assert kernel_size1 % 2 == 1 and kernel_size2 % 2 == 1, 'Kernel size must be odd'
    padding1 = (kernel_size1 - 1) // 2
    padding2 = (kernel_size2 - 1) // 2
    # First column of the 2-D circulant operator per (out, in) channel pair: flip, zero-pad to
    # (n2, n1) and roll so the kernel centre sits at index (0, 0).
    col = F.pad(weight.flip([-1]), (0, n1 - kernel_size1)).roll(-padding1,
                                                                dims=-1)
    col = F.pad(col.flip([-2]), (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                                   dims=-2)
    # From here we mimic the circulant construction, but the diagonal multiply is replaced with
    # multiply and then sum across the in-channels.
    complex = col.is_complex()
    log_n1 = int(math.ceil(math.log2(n1)))
    log_n2 = int(math.ceil(math.log2(n2)))
    if flatten:
        assert n1 == 1 << log_n1 and n2 == 1 << log_n2, \
            'flatten=True requires n1 and n2 to be powers of 2'
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n1?
    # I've only figured out how to pad to size 1 << (log_n1 + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended1 = n1 if n1 == 1 << log_n1 else 1 << (log_n1 + 1)
    n_extended2 = n2 if n2 == 1 << log_n2 else 1 << (log_n2 + 1)
    b_fft = fft2d(n_extended1,
                  n_extended2,
                  normalized=True,
                  br_first=False,
                  with_br_perm=False,
                  flatten=flatten).to(col.device)
    if not flatten:
        b_fft.map1.in_size = n1
        b_fft.map2.in_size = n2
    else:
        b_fft = b_fft[1]  # Ignore the nn.Flatten and nn.Unflatten
    b_ifft = ifft2d(n_extended1,
                    n_extended2,
                    normalized=True,
                    br_first=True,
                    with_br_perm=False,
                    flatten=flatten).to(col.device)
    if not flatten:
        b_ifft.map1.out_size = n1
        b_ifft.map2.out_size = n2
    else:
        b_ifft = b_ifft[1]  # Ignore the nn.Flatten and nn.Unflatten
    if n1 < n_extended1:
        col_0 = F.pad(col, (0, 2 * ((1 << log_n1) - n1)))
        col = torch.cat((col_0, col), dim=-1)
    if n2 < n_extended2:
        col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << log_n2) - n2)))
        col = torch.cat((col_0, col), dim=-2)
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    col_f = torch.fft.fftn(col, dim=(-1, -2), norm=None)
    br_perm1 = bitreversal_permutation(n_extended1,
                                       pytorch_format=True).to(col.device)
    br_perm2 = bitreversal_permutation(n_extended2,
                                       pytorch_format=True).to(col.device)
    # col_f[..., br_perm2, br_perm1] would error "shape mismatch: indexing tensors could not be
    # broadcast together", so the two dims are permuted one at a time on the real view.
    col_f = torch.view_as_complex(
        torch.view_as_real(col_f)[..., br_perm2, :, :][..., br_perm1, :])
    if flatten:
        col_f = col_f.reshape(*col_f.shape[:-2],
                              col_f.shape[-2] * col_f.shape[-1])
    # We just want (input_f.unsqueeze(1) * col_f).sum(dim=2).
    # This can be written as a complex matrix multiply as well.
    if not complex:
        if not flatten:
            return nn.Sequential(Real2Complex(), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), nn.Flatten(start_dim=-2),
                                 b_fft, DiagonalMultiplySum(col_f), b_ifft,
                                 nn.Unflatten(-1, (n2, n1)), Complex2Real())
    else:
        if not flatten:
            return nn.Sequential(b_fft, DiagonalMultiplySum(col_f), b_ifft)
        else:
            return nn.Sequential(nn.Flatten(start_dim=-2), b_fft,
                                 DiagonalMultiplySum(col_f), b_ifft,
                                 nn.Unflatten(-1, (n2, n1)))
예제 #6
0
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """Build an nn.Module from Butterflies that exactly multiplies by a circulant matrix.

    The product is computed as iFFT * diag(eigenvalues) * FFT, with both transforms expressed
    as Butterfly modules built without the bit-reversal permutation.

    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is folded into a single Butterfly that is the
                           product of the FFT and iFFT Butterflies.
    Return:
        nn.Module with in_size/out_size n. For real col it is wrapped in
        Real2Complex / Complex2Real.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    is_complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    padded = 1 << log_n
    # Non-power-of-2 sizes are embedded into size 2 * padded, e.g.
    # [a, b, c] -> [a, b, c, 0, 0, a, b, c]. Padding only up to `padded` might be possible,
    # but this is the construction known to work.
    n_extended = n if n == padded else padded << 1
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n
    if n < n_extended:
        col = torch.cat((F.pad(col, (0, 2 * (padded - n))), col))
    if not col.is_complex():
        col = real2complex(col)
    # The unnormalized FFT of the first column gives the eigenvalues of the circulant matrix.
    eigvals = torch.fft.fft(col, norm=None)
    if transposed:
        # Transposing reverses col[1:] while col[0] stays put; the DFT mirrors that reversal
        # in the frequency domain, so we reverse the spectrum the same way instead of
        # transposing the iFFT * Diag * FFT factorization.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        eigvals = torch.cat((eigvals[:1], eigvals[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    # Same as eigvals[..., br_perm].
    diag = index_last_dim(eigvals, br_perm)

    if not separate_diagonal:
        # Fold the diagonal into the last twiddle factor of b_fft, then merge FFT and iFFT
        # into a single Butterfly (nblocks=2).
        with torch.no_grad():
            b_fft = diagonal_butterfly(b_fft,
                                       diag,
                                       diag_first=False,
                                       inplace=True)
        fused = butterfly_product(b_fft, b_ifft)
        fused.in_size = n
        fused.out_size = n
        if is_complex:
            return fused
        return nn.Sequential(Real2Complex(), fused, Complex2Real())

    stages = [b_fft, Diagonal(diagonal_init=diag), b_ifft]
    if is_complex:
        return nn.Sequential(*stages)
    return nn.Sequential(Real2Complex(), *stages, Complex2Real())