Code Example #1
 def test_butterfly_factor_inplace_cuda(self):
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, ortho_init=True).to('cuda')
     input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
     twiddle = twiddle_list_concat(B)
     output_inplace = butterfly_factor_mult_inplace(twiddle, input_)
     output = B(input_)
     self.assertTrue(
         torch.allclose(output_inplace,
                        output,
                        rtol=self.rtol,
                        atol=self.atol),
         (output_inplace - output).abs().max().item())
     grad = torch.randn_like(output)
     d_twiddle_inplace, d_input_inplace = torch.autograd.grad(
         output_inplace, (twiddle, input_), grad, retain_graph=True)
     output.backward(grad, retain_graph=True)
     d_input = input_.grad
     d_twiddle = torch.cat(
         [factor.ABCD.grad.permute(2, 0, 1) for factor in B.factors[::-1]])
     self.assertTrue(
         torch.allclose(d_input_inplace,
                        d_input,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_input_inplace - d_input).abs().max().item())
     self.assertTrue(
         torch.allclose(d_twiddle_inplace,
                        d_twiddle,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_twiddle_inplace - d_twiddle).abs().max().item())
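
This test imports twiddle_list_concat from test_factor_multiply (see Code Example #10). A plausible reconstruction, inferred from the d_twiddle concatenation at the end of the test rather than taken from the project, for readers who want a runnable reference:

import torch

def twiddle_list_concat(B):
    # Inferred behavior: stack each factor's 2x2 blocks, innermost factor
    # first, into one (n - 1, 2, 2) tensor for the fused kernel. Complex
    # factors would carry a trailing real/imag dim (permute(2, 0, 1, 3)).
    return torch.cat(
        [factor.ABCD.permute(2, 0, 1) for factor in B.factors[::-1]])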
Code Example #2
 def _setup(self, config):
     self.target_matrix = torch.tensor(config['target_matrix'],
                                       dtype=torch.float)
     assert self.target_matrix.shape[0] == self.target_matrix.shape[
         1], 'Only square matrices are supported'
     assert self.target_matrix.dim() in [
         2, 3
     ], 'target matrix must be 2D if real or 3D if complex'
     size = self.target_matrix.shape[0]
     torch.manual_seed(config['seed'])
     # Transposing the permutation product won't capture the FFT, since we'd
     # get permutations that interleave the first half and second half (the
     # inverse of the permutation that separates the even and odd indices).
     # However, using the permutation product with increasing size will work
     # since it can represent bit reversal, which is its own inverse.
     self.model = nn.Sequential(
         Block2x2DiagProduct(size=size, complex=True,
                             decreasing_size=False),
         BlockPermProduct(size=size,
                          complex=True,
                          share_logit=False,
                          increasing_size=True),
     )
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']
     self.input = real_to_complex(torch.eye(size))
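
The comment in _setup relies on bit reversal being its own inverse. A quick standalone check, assuming bitreversal_permutation from butterfly.utils returns an index array (as its use in Code Example #4 suggests):

import torch
from butterfly.utils import bitreversal_permutation

# Bit reversal is an involution: composing it with itself is the identity.
perm = torch.tensor(bitreversal_permutation(16))
assert torch.equal(perm[perm], torch.arange(16))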
Code Example #3
 def test_butterfly_factor_complex_inplace_cpu(self):
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True)
     input_ = torch.randn(batch_size, n, 2, requires_grad=True)
     twiddle = twiddle_list_concat(B)
     output_inplace = butterfly_factor_mult_inplace(twiddle, input_)
     output = B(input_)
     self.assertTrue(
         torch.allclose(output_inplace,
                        output,
                        rtol=self.rtol,
                        atol=self.atol),
         (output_inplace - output).abs().max().item())
Code Example #4
 def _setup(self, config):
     self.target_matrix = torch.tensor(config['target_matrix'],
                                       dtype=torch.float)
     assert self.target_matrix.shape[0] == self.target_matrix.shape[
         1], 'Only square matrices are supported'
     assert self.target_matrix.dim() in [
         2, 3
     ], 'target matrix must be 2D if real or 3D if complex'
     size = self.target_matrix.shape[0]
     torch.manual_seed(config['seed'])
     self.model = Block2x2DiagProduct(size=size, complex=True)
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']
     self.input = real_to_complex(
         torch.eye(size)[:, torch.tensor(bitreversal_permutation(size))])
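
real_to_complex is a project helper that is not shown in these excerpts. A minimal sketch of what it presumably does, given the (..., 2) real/imaginary layout used throughout these examples:

import torch

def real_to_complex(x):
    # Append a zero imaginary channel: real (...) -> complex (..., 2).
    return torch.stack((x, torch.zeros_like(x)), dim=-1)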
Code Example #5
 def test_butterfly_factor_intermediate_complex_cuda(self):
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True).to('cuda')
     input_ = torch.randn(batch_size,
                          n,
                          2,
                          device='cuda',
                          requires_grad=True)
     twiddle = twiddle_list_concat(B).unsqueeze(0)
     output_intermediate = butterfly_multiply_intermediate(twiddle, input_)
     output = [input_]
     for factor in B.factors[::-1]:
         output.append(
             butterfly_factor_mult(
                 factor.ABCD, output[-1].view(-1, 2, factor.size // 2,
                                              2)).view(output[-1].shape))
     output = torch.stack(output)
     self.assertTrue(
         torch.allclose(output_intermediate.squeeze(2),
                        output,
                        rtol=self.rtol,
                        atol=self.atol),
         (output_intermediate.squeeze(2) - output).abs().max().item())
     grad = torch.randn_like(output[-1])
     d_twiddle_intermediate, d_input_intermediate = butterfly_multiply_intermediate_backward(
         grad.unsqueeze(1), twiddle, output_intermediate)
     output[-1].backward(grad, retain_graph=True)
     d_input = input_.grad
     d_twiddle = torch.cat([
         factor.ABCD.grad.permute(2, 0, 1, 3) for factor in B.factors[::-1]
     ])
     self.assertTrue(
         torch.allclose(d_input_intermediate,
                        d_input,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_input_intermediate - d_input).abs().max().item())
     self.assertTrue(
         torch.allclose(d_twiddle_intermediate,
                        d_twiddle,
                        rtol=self.rtol,
                        atol=self.atol),
         (d_twiddle_intermediate - d_twiddle).abs().max().item())
Code Example #6
File: profile.py Project: sfox14/butterfly
def profile_butterfly_mult():
    nsteps = 10
    batch_size = 100
    n = 1024
    B = Block2x2DiagProduct(n)
    x = torch.randn(batch_size, n)
    # B(x)
    optimizer = optim.Adam(B.parameters(), lr=0.01)
    for _ in range(nsteps):
        optimizer.zero_grad()
        # output = B(x)
        # loss = nn.functional.mse_loss(output, x)
        output = x
        for factor in B.factors[::-1]:
            output = butterfly_factor_mult(factor.ABCD, output.view(-1, 2, factor.size // 2)).view(x.shape)
        # output = output.reshape(x.shape)
        loss = output.sum()
        loss.backward()
        optimizer.step()
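
One way to drive the function above under PyTorch's autograd profiler; a sketch, since profile.py may invoke it differently:

import torch

with torch.autograd.profiler.profile() as prof:
    profile_butterfly_mult()
print(prof.key_averages().table(sort_by='cpu_time_total'))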
Code Example #7
 def test_butterfly_factor_complex_cpu(self):
     batch_size = 10
     n = 4096
     B = Block2x2DiagProduct(n, complex=True)
     input_ = torch.randn(batch_size, n, 2, requires_grad=True)
     output = input_
     for factor in B.factors[::-1]:
         prev = output
         output = butterfly_factor_mult(
             factor.ABCD, output.view(-1, 2, factor.size // 2,
                                      2)).view(prev.shape)
         output_slow = (complex_mul(
             factor.ABCD, prev.view(-1, 1, 2, factor.size // 2,
                                    2)).sum(dim=-3)).view(prev.shape)
         self.assertTrue(
             torch.allclose(output,
                            output_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (output - output_slow).abs().max().item())
         grad = torch.randn_like(output)
         d_twiddle, d_input = torch.autograd.grad(output,
                                                  (factor.ABCD, prev),
                                                  grad,
                                                  retain_graph=True)
         d_twiddle_slow, d_input_slow = torch.autograd.grad(
             output_slow, (factor.ABCD, prev), grad, retain_graph=True)
         self.assertTrue(
             torch.allclose(d_twiddle,
                            d_twiddle_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (d_twiddle - d_twiddle_slow).abs().max().item())
         self.assertTrue(
             torch.allclose(d_input,
                            d_input_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (d_input - d_input_slow).abs().max().item())
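
complex_mul multiplies tensors in the (..., 2) real/imaginary layout. The project ships its own implementation; the standard formula it has to compute is:

import torch

def complex_mul(x, y):
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, with broadcasting over
    # all leading dimensions.
    real = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
    imag = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
    return torch.stack((real, imag), dim=-1)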
Code Example #8
 def test_butterfly_factor_cuda(self):
     batch_size = 100
     n = 4096  # To test n > MAX_BLOCK_SIZE
     B = Block2x2DiagProduct(n).to('cuda')
     input_ = torch.randn(batch_size, n, device='cuda', requires_grad=True)
     output = input_
     for factor in B.factors[::-1]:
         prev = output
         output = butterfly_factor_mult(
             factor.ABCD, output.view(-1, 2,
                                      factor.size // 2)).view(prev.shape)
         output_slow = ((factor.ABCD *
                         prev.view(-1, 1, 2, factor.size // 2)).sum(
                             dim=-2)).view(prev.shape)
         self.assertTrue(
             torch.allclose(output,
                            output_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (output - output_slow).abs().max().item())
         grad = torch.randn_like(output)
         d_twiddle, d_input = torch.autograd.grad(output,
                                                  (factor.ABCD, prev),
                                                  grad,
                                                  retain_graph=True)
         d_twiddle_slow, d_input_slow = torch.autograd.grad(
             output_slow, (factor.ABCD, prev), grad, retain_graph=True)
         self.assertTrue(
             torch.allclose(d_twiddle,
                            d_twiddle_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (factor.size, (d_twiddle - d_twiddle_slow).abs().max().item()))
         self.assertTrue(
             torch.allclose(d_input,
                            d_input_slow,
                            rtol=self.rtol,
                            atol=self.atol),
             (d_input - d_input_slow).abs().max().item())
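
The slow reference above expresses each butterfly factor as a broadcast 2x2 block multiply. A tiny standalone instance of the same formula for n = 4, checked against an explicit loop:

import torch

ABCD = torch.randn(2, 2, 2)  # (2, 2, n/2) twiddle factors for n = 4
x = torch.randn(1, 4)
out = (ABCD * x.view(-1, 1, 2, 2)).sum(dim=-2).view(1, 4)

# Entries j and j + n/2 are mixed by the 2x2 matrix [[A_j, B_j], [C_j, D_j]].
ref = torch.empty(1, 4)
for j in range(2):
    pair = torch.stack((x[0, j], x[0, j + 2]))
    ref[0, j], ref[0, j + 2] = ABCD[:, :, j] @ pair
assert torch.allclose(out, ref)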
Code Example #9
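# Assumed imports for this excerpt (not shown on this page): numpy as np,
# math, torch, scipy.linalg as LA, dct/dst from scipy.fftpack, sparse from
# scipy, hadamard from scipy.linalg, legendre from numpy.polynomial, plus
# project helpers (hartley_matrix, haar_matrix, hilbert_matrix,
# toeplitz_like, Butterfly).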
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension of the (square) target matrix
    Return:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr), norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    elif name == 'permutation':
        np.random.seed(0)
        perm = np.random.permutation(size)
        P = np.eye(size)[perm]
        return P
    elif name.startswith('rank-unnorm'):
        r = int(name[11:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        # M /= math.sqrt(size*r)
        return M
    elif name.startswith('rank'):
        r = int(name[4:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        M /= math.sqrt(size*r)
        return M
    elif name.startswith('sparse'):
        s = int(name[6:])
        # 2rn parameters
        np.random.seed(0)
        mask = sparse.random(size, size, density=s/size, data_rvs=np.ones)
        M = np.random.randn(size, size) * (mask.toarray())
        M /= math.sqrt(s)
        return M
    elif name.startswith('toeplitz'):
        r = int(name[8:])
        G = np.random.randn(size, r) / math.sqrt(size*r)
        H = np.random.randn(size, r) / math.sqrt(size*r)
        M = toeplitz_like(G, H)
        return M
    elif name == 'fastfood':
        n = size
        S = np.random.randn(n)
        G = np.random.randn(n)
        B = np.random.randn(n)
        # P = np.arange(n)
        P = np.random.permutation(n)
        H = hadamard(n)
        # SHGPHB
        # print(H)
        # print((H*B)[P,:])
        # print((H @ (G[:,np.newaxis] * (H * B)[P,:])))
        F = S[:,np.newaxis] * (H @ (G[:,np.newaxis] * (H * B)[P,:])) / n
        return F
        # x = np.random.randn(batch_size,n)
        # HB = hadamard_transform(B)
        # PHBx = HBx[:, P]
        # HGPHBx = hadamard_transform(G*PHBx)
        # return S*HGPHBx
    elif name == 'butterfly':
        # n (log n+1) params in the hierarchy
        b = Butterfly(in_size=size, out_size=size, bias=False, tied_weight=False, param='odo', nblocks=0)
        M = b(torch.eye(size))
        return M.cpu().detach().numpy()
    else:
        assert False, 'Target matrix name not recognized or implemented'
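
The 'fastfood' branch builds the dense matrix of S·H·diag(G)·P·H·diag(B) / n, as the SHGPHB comment hints. A standalone sanity check of that reading:

import numpy as np
from scipy.linalg import hadamard

n = 8
rng = np.random.RandomState(0)
S, G, B = rng.randn(n), rng.randn(n), rng.randn(n)
P = rng.permutation(n)
H = hadamard(n)
F = S[:, None] * (H @ (G[:, None] * (H * B)[P, :])) / n
x = rng.randn(n)
# F @ x should equal: scale by B, Hadamard, permute, scale by G,
# Hadamard again, scale by S, all divided by n.
assert np.allclose(F @ x, S * (H @ (G * (H @ (B * x))[P])) / n)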
Code Example #10
from test_factor_multiply import twiddle_list_concat

exps = np.arange(6, 14)
sizes = 1 << exps

batch_size = 256

ntrials = [100000, 100000, 10000, 10000, 10000, 10000, 10000, 10000]

dense_times = np.zeros(exps.size)
fft_times = np.zeros(exps.size)
butterfly_times = np.zeros(exps.size)
for idx_n, (n, ntrial) in enumerate(zip(sizes, ntrials)):
    print(n)
    B = Block2x2DiagProduct(n).to('cuda')
    L = torch.nn.Linear(n, n, bias=False).to('cuda')
    # Create x directly on the GPU so it stays a leaf tensor.
    x = torch.randn(batch_size, n, device='cuda', requires_grad=True)
    grad = torch.randn_like(x)
    twiddle = twiddle_list_concat(B)

    # Dense multiply
    output = L(x)  # Do it once to initialize cuBlas handle and such
    torch.autograd.grad(output, (L.weight, x), grad)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(ntrial):
        output = L(x)
        torch.autograd.grad(output, (L.weight, x), grad)
    torch.cuda.synchronize()
    end = time.perf_counter()
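
The dense timing above (one warm-up call, synchronize, timed loop, synchronize) is the pattern the rest of the script presumably repeats for the FFT and butterfly multiplies. Factored into a helper, it would look like:

import time

import torch

def cuda_time(fn, ntrial):
    fn()  # warm-up: initialize cuBLAS handles, allocator caches, etc.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(ntrial):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / ntrial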
Code Example #11
exps = np.arange(6, 14)
sizes = 1 << exps

ntrials = [100000, 100000, 1000, 100, 100, 10, 10, 10]

dense_times = np.zeros(exps.size)
fft_times = np.zeros(exps.size)
scipyfft_times = np.zeros(exps.size)
dct_times = np.zeros(exps.size)
dst_times = np.zeros(exps.size)
bp_times = np.zeros(exps.size)
for idx_n, (n, ntrial) in enumerate(zip(sizes, ntrials)):
    print(n)
    x = np.random.random(n).astype(np.float32)
    B = Block2x2DiagProduct(n)
    P = BlockPermProduct(n)
    B_matrix = B(torch.eye(int(n))).t().contiguous()
    B_matrix_np = B_matrix.detach().numpy()

    ABCDs = Block2x2DiagProduct_to_ABCDs(B)
    perm = P.argmax().detach().numpy().astype(int)

    # Dense multiply
    start = timer()
    [B_matrix_np @ x for _ in range(ntrial)]
    end = timer()
    dense_times[idx_n] = (end - start) / ntrial

    # FFT
    start = timer()
Code Example #12
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension of the (square) target matrix
    Return:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(
            size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(
            size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        import torch
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    else:
        assert False, 'Target matrix name not recognized or implemented'