Exemple #1
0
 def forward(self, input):
     """Multiply the input by the 2x2 block coefficients held in ``self.ABCD``.

     Parameters:
         input: (stack, ..., size) if real or (stack, ..., size, 2) if complex
         if not tied_weight: (stack, n_blocks, ..., size) if real or (stack, n_blocks, ..., size, 2) if complex
     Return:
         output: same shape as ``input``
     """
     half = self.size // 2
     # Leading dimensions of the reshaped input, and the coefficient tensor
     # with a broadcast axis inserted where the flattened batch axis sits.
     if self.tied_weight:
         lead = (self.stack, -1)
         coef = self.ABCD.unsqueeze(1)
     else:
         lead = (self.stack, self.n_blocks, -1)
         coef = self.ABCD.unsqueeze(2)
     if self.complex:
         # Trailing axis of length 2 holds (real, imag); the axis that is
         # summed over is therefore third from the end.
         halves = input.view(*lead, 1, 2, half, 2)
         mixed = complex_mul(coef, halves).sum(dim=-3)
     else:
         # Split the last axis into its two halves and sum the per-half
         # products (second-to-last axis).
         halves = input.view(*lead, 1, 2, half)
         mixed = (coef * halves).sum(dim=-2)
     return mixed.view(input.shape)
 def forward(self, input):
     """Apply the layer to ``input`` using two FFT-based multiply stages.

     The input is first combined with ``self.H`` and the intermediate result
     is then combined with ``self.G`` and summed over dim 2; both stages use
     zero-padded real FFTs of length ``2 * in_size``.

     Parameters:
         input: (batch, *, in_size)
     Return:
         output: (batch, *, out_size)
     """
     lead_shape = input.size()[:-1]
     # Collapse all leading dimensions into a single batch dimension.
     flat = input.view(np.prod(lead_shape), input.size(-1))
     n_rows = flat.shape[0]
     n = self.in_size
     fft_len = (2 * n, )

     def padded_rfft(t):
         # Real FFT of `t` zero-padded to twice its last dimension.
         return torch.rfft(torch.cat((t, torch.zeros_like(t)), dim=-1), 1)

     # Stage 1: multiply the index-permuted input with H in frequency space.
     # NOTE(review): reverse_idx presumably reverses the last axis (the
     # original commented-out code used flip) -- confirm where it is built.
     flat_f = torch.rfft(
         torch.cat((flat[:, self.reverse_idx], torch.zeros_like(flat)),
                   dim=-1), 1)
     h_f = padded_rfft(self.H)
     stage1_f = complex_mul(flat_f.unsqueeze(1).unsqueeze(1), h_f)
     stage1 = torch.irfft(stage1_f, 1,
                          signal_sizes=fft_len)[..., self.reverse_idx]

     # Stage 2: multiply with G, sum over dim 2 and keep the first n entries.
     stage1_spec = padded_rfft(stage1)
     g_f = padded_rfft(self.G)
     summed_f = complex_mul(stage1_spec, g_f).sum(dim=2)
     result = torch.irfft(summed_f, 1, signal_sizes=fft_len)[..., :n]

     # Concatenate the stacks and truncate to the requested output width.
     result = result.reshape(n_rows,
                             self.nstack * self.in_size)[:, :self.out_size]
     if self.bias is not None:
         result = result + self.bias
     return result.view(*lead_shape, self.out_size)
Exemple #3
0
def test_butterfly_factor_complex_multiply():
    """Compare butterfly_factor_mult with a broadcast complex_mul reference.

    For every factor size n, n/2, ..., 2 the forward result and the gradients
    with respect to both the coefficients and the input must agree.
    """
    from complex_utils import complex_mul
    n = 1024
    n_levels = int(math.log2(n))
    x = torch.randn((n, 2), requires_grad=True)
    for level in range(n_levels):
        size = n >> level
        factor = Block2x2Diag(size, complex=True)
        coef = factor.ABCD
        half = coef.shape[-2]
        x = x.view(-1, 2 * half, 2)

        # Reference: explicit broadcast multiply, then sum over the pair axis.
        paired = x.view(x.shape[:-2] + (1, 2, size // 2, 2))
        expected = complex_mul(coef, paired).sum(dim=-3).view(x.shape)
        # Implementation under test.
        actual = butterfly_factor_mult(coef, x.view(-1, 2, half,
                                                    2)).view(x.shape)
        assert torch.allclose(actual, expected, atol=1e-6)

        # Backward: push the same upstream gradient through both graphs.
        upstream = torch.randn_like(x)
        expected_d_coef, expected_d_x = torch.autograd.grad(
            expected, (coef, x), upstream, retain_graph=True)
        actual_d_coef, actual_d_x = torch.autograd.grad(actual, (coef, x),
                                                        upstream,
                                                        retain_graph=True)
        assert torch.allclose(actual_d_coef, expected_d_coef, atol=1e-6)
        assert torch.allclose(actual_d_x, expected_d_x, atol=1e-6)
Exemple #4
0
def test_butterfly_dct():
    """Check that a product of two complex Butterfly factors yields the DCT.

    Works on the 4-point DCT: the real part of (diagonal @ M0 @ M1), with its
    columns permuted, must match scipy's DCT matrix. The diagonal is then
    folded into the first factor and the check is repeated.
    """
    from scipy.fftpack import dct
    size = 4
    # dct acts on the rows of np.eye, not the columns, hence the transpose.
    expected = torch.tensor(dct(np.eye(size)).T, dtype=torch.float)

    # First factor (diagonal=2), entries stored as (real, imag) pairs.
    m0_diag = torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]],
                           requires_grad=True)
    m0_sub = torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True)
    m0_super = torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True)
    factor0 = Butterfly(size,
                        diagonal=2,
                        complex=True,
                        diag=m0_diag,
                        subdiag=m0_sub,
                        superdiag=m0_super)

    # Second factor (diagonal=1).
    m1_diag = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]],
                           requires_grad=True)
    m1_sub = torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                          requires_grad=True)
    m1_super = torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                            requires_grad=True)
    factor1 = Butterfly(size,
                        diagonal=1,
                        complex=True,
                        diag=m1_diag,
                        subdiag=m1_sub,
                        superdiag=m1_super)

    # Column permutation: even indices followed by the reversed odd ones,
    # composed with the bit-reversal permutation.
    positions = np.arange(size)
    dct_order = np.concatenate((positions[::2], positions[::-2]))
    bit_reversed = bitreversal_permutation(size)
    perm = torch.arange(size)[dct_order][bit_reversed]

    # Complex diagonal scaling 2 * exp(-i * pi * k / (2 * size)).
    angles = -math.pi * torch.arange(size, dtype=torch.float) / (2 * size)
    scale_re = 2 * torch.cos(angles)
    scale_im = 2 * torch.sin(angles)
    scale_matrix = torch.stack((torch.diag(scale_re), torch.diag(scale_im)),
                               dim=-1)
    product = complex_matmul(
        scale_matrix, complex_matmul(factor0.matrix(), factor1.matrix()))
    assert torch.allclose(product[:, perm, 0], expected)

    # Same check with the diagonal scaling absorbed into the first factor.
    scale = torch.stack((scale_re, scale_im), dim=-1)
    scaled_factor0 = Butterfly(size,
                               diagonal=2,
                               complex=True,
                               diag=complex_mul(scale, m0_diag),
                               subdiag=complex_mul(scale[2:], m0_sub),
                               superdiag=complex_mul(scale[:2], m0_super))
    folded = complex_matmul(scaled_factor0.matrix(), factor1.matrix())
    assert torch.allclose(folded[:, perm, 0], expected)
Exemple #5
0
 def forward(self, input):
     """Multiply the input with ``self.c_f`` in the frequency domain.

     Parameters:
         input: (batch, size)
     Return:
         output: (batch, nstack * size)
     """
     n_rows = input.shape[0]
     # Add a broadcast axis so every stack in c_f sees the same input spectrum.
     spectrum = complex_mul(self.c_f, torch.rfft(input, 1).unsqueeze(1))
     per_stack = torch.irfft(spectrum, 1, signal_sizes=(self.size, ))
     # Flatten the (stack, size) axes into a single feature axis.
     return per_stack.view(n_rows, self.nstack * self.size)
def toeplitz_krylov_multiply(v, w, f=0.0):
    r"""Multiply \sum_i Krylov(Z_f, v_i) @ w_i.

    Parameters:
        v: (nstack, rank, n)
        w: (batch_size, nstack, rank, n)
        f: real number; 0.0 selects the zero-padded (non-cyclic) path,
            any other value the cyclic path based on the n-th roots of f
    Returns:
        product: (batch, nstack, n)
    Raises:
        AssertionError: if the nstack, rank or n dimensions of v and w differ
    """
    _, nstack, rank, n = w.shape
    nstack_, rank_, n_ = v.shape
    assert n == n_, 'w and v must have the same last dimension'
    assert rank == rank_, 'w and v must have the same rank'
    assert nstack == nstack_, 'w and v must have the same nstack'
    if f != 0.0:  # cycle version
        # Computing the roots of f
        mod = abs(f)**(torch.arange(n, dtype=w.dtype, device=w.device) / n)
        if f > 0:
            arg = torch.stack((torch.ones(n, dtype=w.dtype, device=w.device),
                               torch.zeros(n, dtype=w.dtype, device=w.device)),
                              dim=-1)
        else:  # Find primitive roots of -1
            angles = torch.arange(n, dtype=w.dtype,
                                  device=w.device) / n * np.pi
            arg = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
        # Complex numbers are stored as a trailing axis of (real, imag).
        eta = mod[:, np.newaxis] * arg
        eta_inverse = (1.0 / mod)[:, np.newaxis] * conjugate(arg)
        w_f = torch.fft(eta * w[..., np.newaxis], 1)
        v_f = torch.fft(eta * v[..., np.newaxis], 1)
        # Multiply in frequency space and sum over the rank dimension.
        wv_sum_f = complex_mul(w_f, v_f).sum(dim=2)
        wv_sum = torch.ifft(wv_sum_f, 1)
        # We only need the real part of complex_mul(eta_inverse, wv_sum):
        # Re(a * b) = a_re * b_re - a_im * b_im.
        return (eta_inverse[..., 0] * wv_sum[..., 0] -
                eta_inverse[..., 1] * wv_sum[..., 1])
    else:
        # Zero-pad to length 2n so the circular convolution computed by the
        # FFT does not wrap around, then keep the first n entries.
        w_f = torch.rfft(torch.cat((w, torch.zeros_like(w)), dim=-1), 1)
        v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1)
        wv_sum_f = complex_mul(w_f, v_f).sum(dim=2)
        return torch.irfft(wv_sum_f, 1, signal_sizes=(2 * n, ))[..., :n]
def toeplitz_krylov_transpose_multiply(v, u, f=0.0):
    """Multiply Krylov(Z_f, v_i)^T @ u.

    Parameters:
        v: (nstack, rank, n)
        u: (batch_size, n)
        f: real number; 0.0 selects the zero-padded path, any other value
            the cyclic path
    Returns:
        product: (batch, nstack, rank, n)
    """
    _batch, n = u.shape
    _nstack, _rank, v_len = v.shape
    assert n == v_len, 'u and v must have the same last dimension'

    if f == 0.0:
        # Non-cyclic case: reverse u, zero-pad both operands to length 2n,
        # multiply the spectra, then keep the first n entries and reverse.
        u_padded = torch.cat((u.flip(1), torch.zeros_like(u)), dim=-1)
        v_padded = torch.cat((v, torch.zeros_like(v)), dim=-1)
        u_spec = torch.rfft(u_padded, 1)
        v_spec = torch.rfft(v_padded, 1)
        prod_spec = complex_mul(u_spec.unsqueeze(1).unsqueeze(1), v_spec)
        full = torch.irfft(prod_spec, 1, signal_sizes=(2 * n, ))
        return full[..., :n].flip(3)

    # Cyclic case: build the roots of f as modulus * unit-phase, with complex
    # numbers stored as a trailing (real, imag) axis.
    opts = dict(dtype=u.dtype, device=u.device)
    ramp = torch.arange(n, **opts) / n
    modulus = abs(f)**ramp
    if f > 0:
        phase = torch.stack((torch.ones(n, **opts), torch.zeros(n, **opts)),
                            dim=-1)
    else:
        # Primitive roots of -1.
        theta = ramp * np.pi
        phase = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)
    eta = modulus[:, np.newaxis] * phase
    eta_inverse = (1.0 / modulus)[:, np.newaxis] * conjugate(phase)

    u_spec = torch.ifft(eta_inverse * u[..., np.newaxis], 1)
    v_spec = torch.fft(eta * v.unsqueeze(-1), 1)
    prod = torch.fft(complex_mul(u_spec.unsqueeze(1).unsqueeze(1), v_spec), 1)
    # Only the real part of complex_mul(eta, prod) is needed.
    real_term = eta[..., 0] * prod[..., 0]
    imag_term = eta[..., 1] * prod[..., 1]
    return real_term - imag_term
Exemple #8
0
def circulant_fft(c, x):
    """Multiply the spectra of ``c`` and ``x`` and transform back.

    Both inputs are real-FFT'd along their last dimension, multiplied with
    complex_mul, and inverted to a signal whose last dimension has the same
    length as that of ``x``.
    """
    signal_len = x.shape[-1]
    x_spec, c_spec = torch.rfft(x, 1), torch.rfft(c, 1)
    return torch.irfft(complex_mul(c_spec, x_spec),
                       1,
                       signal_sizes=(signal_len, ))