예제 #1
0
 def _setup(self, config):
     """Build the model, optimizer, Vandermonde target and permutation.

     Keys read from ``config``: 'seed', 'size', 'fixed_order', 'softmax_fn',
     'lr', 'n_steps_per_epoch', 'perm', and 'semantic_loss_weight' (the last
     one only when the order is not fixed and 'softmax_fn' is 'softmax').
     """
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(
         size=config['size'],
         complex=True,
         fixed_order=config['fixed_order'],
         softmax_fn=config['softmax_fn'],
     )
     order_is_learned = not config['fixed_order']
     if order_is_learned and config['softmax_fn'] == 'softmax':
         self.semantic_loss_weight = config['semantic_loss_weight']
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']

     size = config['size']
     # Target: Vandermonde matrix of `size` normal samples. The numpy seed is
     # pinned to 0 here, independently of config['seed'].
     np.random.seed(0)
     nodes = np.random.randn(size)
     self.target_matrix = torch.tensor(np.vander(nodes, increasing=True),
                                       dtype=torch.float)

     # Candidate permutations; which one is stored depends on config['perm'].
     indices = np.arange(size)
     dct_perm = np.concatenate((indices[::2], indices[::-2]))
     br_perm = bitreversal_permutation(size)
     perm_name = config['perm']
     assert perm_name in ['id', 'br', 'dct']
     if perm_name == 'br':
         self.perm = br_perm
     elif perm_name == 'dct':
         self.perm = torch.arange(size)[dct_perm][br_perm]
     elif perm_name == 'id':
         self.perm = torch.arange(size)
     else:
         assert False, 'Wrong perm in config'
예제 #2
0
 def _setup(self, config):
     """Build a fixed-order complex ButterflyProduct and its DFT target.

     Keys read from ``config``: 'size', 'seed', 'lr', 'n_steps_per_epoch'.
     """
     size = config['size']
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
     trainable = self.model.parameters()
     self.optimizer = optim.Adam(trainable, lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     # Target: 1-D FFT applied to the identity matrix converted by
     # real_to_complex.
     # NOTE(review): function-style torch.fft(input, signal_ndim) is a legacy
     # call form -- confirm it exists in the pinned PyTorch version.
     identity = real_to_complex(torch.eye(size))
     self.target_matrix = torch.fft(identity, 1)
     bit_reversed = bitreversal_permutation(size)
     self.br_perm = torch.tensor(bit_reversed)
예제 #3
0
파일: ops.py 프로젝트: sfox14/butterfly
def ops_transpose_mult_br(a, b, c, p0, p1, v):
    """Fast algorithm to multiply P^T v, where P is the matrix of coefficients
    of orthogonal polynomials (OPs) specified by the recurrence coefficients
    a, b, c and the starting polynomials p0, p1. Implementation with
    bit-reversal (the inputs are reordered by bitreversal_permutation so the
    tree can be built bottom-up without recursion).

    In particular, the recurrence is
    P_{n+2}(x) = (a[n] x + b[n]) P_{n+1}(x) + c[n] P_n(x).

    Parameters:
        a: array of length n
        b: array of length n
        c: array of length n
        p0: real number representing P_0(x).
        p1: pair of real numbers representing P_1(x).
        v: (batch_size, n), where n must be a power of 2.
    Return:
        result: P^T v (presumably of shape (batch_size, n) -- TODO confirm
            against polymatmul's output layout).
    Raises:
        AssertionError: if n is not a power of 2.

    NOTE(review): for n == 1, m is 0 and S_br is an empty list, so the
    S_br[0] assignment below would raise IndexError -- confirm whether
    n == 1 needs to be supported.
    """
    n = v.shape[-1]
    m = int(np.log2(n))
    assert n == 1 << m, "Length n must be a power of 2."

    # Preprocessing: compute T_{i:j}, the transition matrix from p_i to p_j.
    T_br = [None] * (m + 1)
    # Lowest level, filled with T_{i:i+1}
    # n matrices, each 2 x 2, with coefficients being polynomials of degree <= 1
    # Layout is (index, row, col, coefficient); each matrix is
    # [[a x + b, c], [1, 0]], with the last axis holding [constant, x] terms.
    T_br[0] = torch.zeros(n, 2, 2, 2)
    T_br[0][:, 0, 0, 1] = a
    T_br[0][:, 0, 0, 0] = b
    T_br[0][:, 0, 1, 0] = c
    T_br[0][:, 1, 0, 0] = 1.0
    # Reorder the base matrices into bit-reversed order.
    br_perm = bitreversal_permutation(n)
    T_br[0] = T_br[0][br_perm]
    # Each level combines the second part of the previous level (from index
    # n >> i onward) with its first n >> i entries via polymatmul.
    for i in range(1, m + 1):
        T_br[i] = polymatmul(T_br[i - 1][n >> i:], T_br[i - 1][:n >> i])

    P_init = torch.tensor([p1, [p0, 0.0]], dtype=torch.float)  # [p_1, p_0]
    # Reshape from (2, 2) to (1, 2, 1, 2) for use as a polymatmul operand.
    P_init = P_init.unsqueeze(0).unsqueeze(-2)
    # Check that T_br is computed correctly
    # These should be the polynomials P_{n+1} and P_n
    # Pnp1n = polymatmul(T_br[m], P_init).squeeze()

    # Apply the same bit-reversal ordering to the columns of v.
    v_br = v[:, br_perm]
    # Bottom-up multiplication algorithm to avoid recursion
    S_br = [None] * m
    # 2 x 2 identity shaped (1, 2, 2, 1): a constant (degree-0) polynomial matrix.
    Tidentity = torch.eye(2).unsqueeze(0).unsqueeze(3)
    # Base level: second half of v_br scales the first n // 2 base matrices ...
    S_br[0] = v_br[:, n // 2:, None, None, None] * T_br[0][:n // 2]
    # ... and the first half of v_br is added to the constant coefficient only.
    S_br[0][:, :, :, :, :1] += v_br[:, :n // 2, None, None, None] * Tidentity
    for i in range(1, m):
        # Multiply the second part of the previous level by this level's
        # transition matrices, then add the first part of the previous level
        # into the leading coefficients.
        S_br[i] = polymatmul(S_br[i - 1][:, (n >> (i + 1)):],
                             T_br[i][:(n >> (i + 1))])
        S_br[i][:, :, :, :, :S_br[i -
                                  1].shape[-1]] += S_br[i -
                                                        1][:, :(n >> (i + 1))]
    # Keep row index 1 and the first n - 1 coefficients of the top level,
    # multiply by the initial polynomials, and drop three axes at position 1.
    result = polymatmul(S_br[m - 1][:, :, [1], :, :n - 1],
                        P_init).squeeze(1).squeeze(1).squeeze(1)
    return result
예제 #4
0
 def _setup(self, config):
     """Build a complex ButterflyProduct, its optimizer and the DFT target.

     Keys read from ``config``: 'seed', 'size', 'fixed_order', 'softmax_fn',
     'lr', 'n_steps_per_epoch', and 'semantic_loss_weight' (the last one only
     when the order is not fixed and 'softmax_fn' is 'softmax').
     """
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(
         size=config['size'],
         complex=True,
         fixed_order=config['fixed_order'],
         softmax_fn=config['softmax_fn'],
     )
     order_is_learned = not config['fixed_order']
     if order_is_learned and config['softmax_fn'] == 'softmax':
         self.semantic_loss_weight = config['semantic_loss_weight']
     trainable = self.model.parameters()
     self.optimizer = optim.Adam(trainable, lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     size = config['size']
     # Target: 1-D FFT applied to the identity matrix converted by
     # real_to_complex.
     # NOTE(review): function-style torch.fft(input, signal_ndim) is a legacy
     # call form -- confirm it exists in the pinned PyTorch version.
     identity = real_to_complex(torch.eye(size))
     self.target_matrix = torch.fft(identity, 1)
     bit_reversed = bitreversal_permutation(size)
     self.br_perm = torch.tensor(bit_reversed)
예제 #5
0
 def _setup(self, config):
     """Build an HstackDiagProduct, its optimizer, target and inputs.

     Keys read from ``config``: 'seed', 'size', 'lr', 'n_steps_per_epoch'.
     """
     torch.manual_seed(config['seed'])
     self.model = HstackDiagProduct(size=config['size'])
     trainable = self.model.parameters()
     self.optimizer = optim.Adam(trainable, lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     size = config['size']

     # Target: Legendre polynomials. Row `degree` holds the power-basis
     # coefficients returned by leg2poly for the degree-th unit vector.
     unit_vectors = np.eye(size)
     coefficients = np.zeros((size, size), dtype=np.float64)
     for degree in range(size):
         coefficients[degree, :degree + 1] = legendre.leg2poly(
             unit_vectors[degree])
     self.target_matrix = torch.tensor(coefficients)

     self.br_perm = bitreversal_permutation(size)
     # Identity input expanded to (size, size, 2, 2, 1): each entry of eye(size)
     # multiplies a 2 x 2 identity, with a trailing singleton axis appended.
     eye_n = torch.eye(size)
     eye_2 = torch.eye(2)
     self.input = (eye_n[:, :, None, None] * eye_2).unsqueeze(-1)
     # Same input with its second axis reordered by the bit-reversal permutation.
     self.input_permuted = self.input[:, self.br_perm]
예제 #6
0
 def _setup(self, config):
     """Build a complex Block2x2DiagProduct for a user-supplied target matrix.

     Keys read from ``config``: 'target_matrix', 'seed', 'lr',
     'n_steps_per_epoch', 'n_epochs_per_validation'.

     Raises AssertionError if the target's first two dimensions differ or if
     it is neither 2-D nor 3-D.
     """
     target = torch.tensor(config['target_matrix'], dtype=torch.float)
     self.target_matrix = target
     n_rows, n_cols = target.shape[0], target.shape[1]
     assert n_rows == n_cols, 'Only square matrices are supported'
     assert target.dim() in [
         2, 3
     ], 'target matrix must be 2D if real of 3D if complex'
     size = n_rows

     torch.manual_seed(config['seed'])
     self.model = Block2x2DiagProduct(size=size, complex=True)
     trainable = self.model.parameters()
     self.optimizer = optim.Adam(trainable, lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']

     # Input: identity matrix with columns in bit-reversed order, converted by
     # real_to_complex.
     bit_reversed = torch.tensor(bitreversal_permutation(size))
     self.input = real_to_complex(torch.eye(size)[:, bit_reversed])
예제 #7
0
def test_hstackdiag_product():
    """Check HstackDiagProduct against ops_transpose_mult_br.

    The model's factors are overwritten with transition matrices built from
    the Legendre recurrence coefficients, then its output on a random batch is
    compared with ops_transpose_mult_br called on the same coefficients.
    """
    size = 8
    model = HstackDiagProduct(size)

    # Legendre polynomials
    n = size
    depth = int(np.log2(n))
    k = torch.arange(n, dtype=torch.float)
    a = (2 * k + 3) / (k + 2)
    b = torch.zeros(n)
    c = -(k + 1) / (k + 2)
    p0 = 1.0
    p1 = (0.0, 1.0)

    # Base level: n matrices [[a x + b, c], [1, 0]] stored as
    # (index, row, col, coefficient) with the last axis holding [constant, x].
    base = torch.zeros(n, 2, 2, 2)
    base[:, 0, 0, 1] = a
    base[:, 0, 0, 0] = b
    base[:, 0, 1, 0] = c
    base[:, 1, 0, 0] = 1.0
    br_perm = bitreversal_permutation(n)
    # Transition matrices per level, lowest level first, in bit-reversed order.
    levels = [base[br_perm]]
    for i in range(1, depth):
        previous = levels[i - 1]
        split = n >> i
        levels.append(polymatmul(previous[split:], previous[:split]))

    # [p_1, p_0], reshaped from (2, 2) to (1, 2, 1, 2)
    P_init = torch.tensor([p1, [p0, 0.0]], dtype=torch.float)
    P_init = P_init.unsqueeze(0).unsqueeze(-2)
    identity = torch.eye(2).unsqueeze(0).unsqueeze(3)

    model.P_init = nn.Parameter(P_init)
    for level, transition in enumerate(levels):
        # Level `level` is assigned to the factor at index depth - level - 1.
        factor = model.factors[depth - level - 1]
        leading = identity.expand(factor.size, -1, -1, -1)
        padding = torch.zeros(factor.size, 2, 2, factor.deg)
        factor.diag1 = nn.Parameter(torch.cat((leading, padding), dim=-1))
        factor.diag2 = nn.Parameter(transition[:factor.size])

    batch_size = 2
    x_original = torch.randn((batch_size, size))
    x = (x_original[:, :, None, None] * torch.eye(2)).unsqueeze(-1)
    output = model(x[:, br_perm])
    assert output.shape == (batch_size, size)
    expected = ops_transpose_mult_br(a, b, c, p0, p1, x_original)
    assert torch.allclose(output, expected)