# Example #1
# 0
def fastfood(diag1: torch.Tensor,
             diag2: torch.Tensor,
             diag3: torch.Tensor,
             permutation: torch.Tensor,
             normalized: bool = False,
             increasing_stride: bool = True,
             separate_diagonal: bool = True) -> nn.Module:
    """ Build an nn.Module (from Butterfly pieces) computing the Fastfood transform:
    x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1 @ x,
    with H the Hadamard matrix and P the permutation matrix described by `permutation`.
    Parameters:
        diag1: (n,), where n is a power of 2. The diagonal applied first.
        diag2: (n,). The diagonal applied right after the permutation.
        diag3: (n,). The diagonal applied last.
        permutation: (n,). Index tensor passed to FixedPermutation.
        normalized: if True, corresponds to the orthogonal Hadamard transform
                    (i.e. multiplied by 1/sqrt(n))
        increasing_stride: whether the first Hadamard Butterfly has increasing stride;
                           the second Hadamard Butterfly uses the opposite direction.
        separate_diagonal: if True, the diagonals are standalone Diagonal modules;
                           if False, they are folded into the Butterfly modules.
    Returns:
        An nn.Sequential implementing the map above.
    """
    size, = diag1.shape
    assert diag2.shape == diag3.shape == permutation.shape == (size, )
    # The two Hadamard factors are built with opposite stride directions.
    first_h = hadamard(size, normalized, increasing_stride)
    second_h = hadamard(size, normalized, not increasing_stride)
    if separate_diagonal:
        return nn.Sequential(
            Diagonal(diagonal_init=diag1),
            first_h,
            FixedPermutation(permutation),
            Diagonal(diagonal_init=diag2),
            second_h,
            Diagonal(diagonal_init=diag3),
        )
    # Fold Diag1 in front of the first H, and Diag2 / Diag3 around the second H.
    first_h = diagonal_butterfly(first_h, diag1, diag_first=True)
    second_h = diagonal_butterfly(second_h, diag2, diag_first=True)
    second_h = diagonal_butterfly(second_h, diag3, diag_first=False)
    return nn.Sequential(first_h, FixedPermutation(permutation), second_h)
# Example #2
# 0
def hadamard_diagonal(diagonals: torch.Tensor,
                      normalized: bool = False,
                      increasing_stride: bool = True,
                      separate_diagonal: bool = True) -> nn.Module:
    """ Build an nn.Module (from Butterfly pieces) that multiplies by H D H D ... H D,
    where H is the Hadamard matrix and each D is a diagonal matrix.
    Parameters:
        diagonals: (k, n), where k is the number of diagonal matrices and n is the dimension of the
            Hadamard transform. Row i is the i-th diagonal, applied before the i-th H.
        normalized: if True, corresponds to the orthogonal Hadamard transform
                    (i.e. multiplied by 1/sqrt(n))
        increasing_stride: stride direction of the first Hadamard Butterfly; the direction
            alternates for each subsequent stage.
        separate_diagonal: if True, returns an nn.Sequential of Diagonal / Hadamard modules;
            if False, each diagonal is folded into its Hadamard Butterfly and the stages are
            multiplied together with butterfly_product.
    """
    _, size = diagonals.shape
    stages = []
    for idx, diag in enumerate(diagonals.unbind()):
        # Odd-indexed stages flip the stride direction relative to increasing_stride.
        stride_up = increasing_stride != (idx % 2 == 1)
        if separate_diagonal:
            stages.append(Diagonal(diagonal_init=diag))
            stages.append(hadamard(size, normalized, stride_up))
        else:
            stages.append(
                diagonal_butterfly(hadamard(size, normalized, stride_up),
                                   diag,
                                   diag_first=True))
    if separate_diagonal:
        return nn.Sequential(*stages)
    return reduce(butterfly_product, stages)
# Example #3
# 0
def acdc(diag1: torch.Tensor,
         diag2: torch.Tensor,
         dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    Returns:
        An nn.Sequential built from permutations, Real2Complex/Complex2Real conversions,
        Butterfly modules and (when separate_diagonal is True) Diagonal modules.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    # Twiddles that turn the (permuted) FFT into a DCT: 2 * exp(-i * pi * k / (2n)).
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        # b1: FFT followed by the DCT twiddles; b2: conjugate twiddles followed by the iFFT.
        # The Butterflies omit the bit-reversal permutation, so the twiddles are indexed by br.
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj()[br],
                                diag_first=True)
        if not separate_diagonal:
            # Fold diag1 into b1 (not the bare b_fft) so the DCT twiddles already merged into
            # b1 are kept, matching the separate-diagonal path below. Folding before
            # Complex2Real agrees with applying it afterwards only for a real diagonal
            # (assumed here).
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(), Real2Complex(), b2,
                                 Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        # b1: conjugate twiddles followed by the iFFT; b2: FFT followed by the DCT twiddles.
        b1 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj(),
                                diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            # Fold diag2 into b2 (not the bare b_fft) so the DCT twiddles already merged into
            # b2 are kept, matching the separate-diagonal path below.
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))
# Example #4
# 0
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Build an nn.Module (from Butterfly pieces) that exactly multiplies by a circulant matrix.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is combined into the Butterfly part.
    Returns:
        The module described above. It is wrapped in Real2Complex / Complex2Real when col is
        real.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    input_is_complex = col.is_complex()
    size = col.shape[0]
    log_size = int(math.ceil(math.log2(size)))
    pow2 = 1 << log_size
    # A non-power-of-2 size is embedded in a circulant of size 2 * 2**log_size,
    # e.g. [a, b, c] -> [a, b, c, 0, 0, a, b, c]. Padding only up to 2**log_size
    # has not been worked out.
    extended = size if size == pow2 else pow2 << 1
    forward = fft(extended,
                  normalized=True,
                  br_first=False,
                  with_br_perm=False).to(col.device)
    forward.in_size = size
    inverse = ifft(extended,
                   normalized=True,
                   br_first=True,
                   with_br_perm=False).to(col.device)
    inverse.out_size = size
    if size < extended:
        head = F.pad(col, (0, 2 * (pow2 - size)))
        col = torch.cat((head, col))
    if not col.is_complex():
        col = real2complex(col)
    # Unnormalized FFT (norm=None) of the first column: the circulant matrix's eigenvalues.
    eigvals = torch.fft.fft(col, norm=None)
    if transposed:
        # Transposing reverses col while keeping element 0 in place. The DFT maps this to the
        # same reversal in frequency, so it is applied to the eigenvalues directly.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        eigvals = torch.cat((eigvals[:1], eigvals[1:].flip([0])))
    br_perm = bitreversal_permutation(extended,
                                      pytorch_format=True).to(col.device)
    # The FFT/iFFT Butterflies are built with with_br_perm=False, so the eigenvalues are
    # reordered by the bit-reversal permutation to line up with them.
    diag = index_last_dim(eigvals, br_perm)
    if not separate_diagonal:
        # Merge the diagonal into the last twiddle factor of the FFT Butterfly, then fuse the
        # FFT and iFFT Butterflies into a single Butterfly (nblocks=2).
        with torch.no_grad():
            forward = diagonal_butterfly(forward,
                                         diag,
                                         diag_first=False,
                                         inplace=True)
        merged = butterfly_product(forward, inverse)
        merged.in_size = size
        merged.out_size = size
        if input_is_complex:
            return merged
        return nn.Sequential(Real2Complex(), merged, Complex2Real())
    if input_is_complex:
        return nn.Sequential(forward, Diagonal(diagonal_init=diag), inverse)
    return nn.Sequential(Real2Complex(), forward,
                         Diagonal(diagonal_init=diag), inverse,
                         Complex2Real())