def _setup(self, config):
    """Prepare the model, optimizer and Vandermonde target for training.

    Parameters:
        config: mapping read for the keys 'perm', 'seed', 'size',
            'fixed_order', 'softmax_fn', 'lr', 'n_steps_per_epoch' and,
            only when the order is learned through a plain softmax,
            'semantic_loss_weight'.

    Sets on self:
        model, optimizer, n_steps_per_epoch, target_matrix, perm, and
        semantic_loss_weight (conditionally, see below).

    Raises:
        AssertionError: if config['perm'] is not 'id', 'br' or 'dct'.
    """
    # Validate first so an invalid config fails before any RNG is reseeded
    # or a model is built.
    perm_name = config['perm']
    assert perm_name in ['id', 'br', 'dct'], 'Wrong perm in config'

    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'])
    # The semantic-loss weight is only stored when the factor order is
    # learned (not fixed) through a plain softmax.
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']

    size = config['size']
    # Target: Vandermonde matrix (increasing powers) of `size` Gaussian nodes.
    # NOTE: this reseeds NumPy's *global* RNG with the constant 0, so the
    # nodes are the same for every config['seed'].
    np.random.seed(0)
    x = np.random.randn(size)
    V = np.vander(x, increasing=True)
    self.target_matrix = torch.tensor(V, dtype=torch.float)

    # Computed for every choice of perm, as before.
    br_perm = bitreversal_permutation(size)
    if perm_name == 'id':
        self.perm = torch.arange(size)
    elif perm_name == 'br':
        # Wrapped in a tensor so self.perm has the same type in every branch.
        self.perm = torch.tensor(br_perm)
    else:  # 'dct'
        arange_ = np.arange(size)
        # Every second index going up from 0, followed by every second index
        # going down from the last one.
        dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
        self.perm = torch.arange(size)[dct_perm][br_perm]
def _setup(self, config):
    """Prepare a fixed-order butterfly model whose target is the DFT matrix.

    Parameters:
        config: mapping read for the keys 'size', 'seed', 'lr' and
            'n_steps_per_epoch'.

    Sets on self:
        model, optimizer, n_steps_per_epoch, target_matrix and br_perm.
    """
    n = config['size']

    # Seed before constructing the model so its initialisation is repeatable.
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=n, complex=True, fixed_order=True)

    learning_rate = config['lr']
    self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
    self.n_steps_per_epoch = config['n_steps_per_epoch']

    # Target: legacy torch.fft (signal_ndim=1) applied to the identity after
    # real_to_complex.
    identity = real_to_complex(torch.eye(n))
    self.target_matrix = torch.fft(identity, 1)

    bit_reversed = bitreversal_permutation(n)
    self.br_perm = torch.tensor(bit_reversed)
def ops_transpose_mult_br(a, b, c, p0, p1, v):
    """Fast algorithm to multiply P^T v where P is the matrix of coefficients of
    OPs, specified by the coefficients a, b, c, and the starting polynomials p0,
    p_1. Implementation with bit-reversal.
    In particular, the recurrence is
    P_{n+2}(x) = (a[n] x + b[n]) P_{n+1}(x) + c[n] P_n(x).

    Parameters:
        a: array of length n
        b: array of length n
        c: array of length n
        p0: real number representing P_0(x).
        p1: pair of real numbers representing P_1(x), ordered
            (constant coefficient, coefficient of x).
        v: (batch_size, n)
    Return:
        result: P^T v (presumably of shape (batch_size, n); depends on the
            output shape of polymatmul - confirm).

    Note:
        n must be a power of 2 and at least 2: for n == 1 there is no level to
        combine and the first assignment into S_br raises IndexError.
        All intermediate tensors are created with torch's default dtype/device.
    """
    n = v.shape[-1]
    m = int(np.log2(n))
    assert n == 1 << m, "Length n must be a power of 2."

    # Preprocessing: compute T_{i:j}, the transition matrix from p_i to p_j.
    # Lowest level, filled with T_{i:i+1}
    # n matrices, each 2 x 2, with coefficients being polynomials of degree <= 1
    # (last axis: index 0 is the constant term, index 1 the coefficient of x).
    T_leaf = torch.zeros(n, 2, 2, 2)
    T_leaf[:, 0, 0, 1] = a
    T_leaf[:, 0, 0, 0] = b
    T_leaf[:, 0, 1, 0] = c
    T_leaf[:, 1, 0, 0] = 1.0
    br_perm = bitreversal_permutation(n)
    T_br = [T_leaf[br_perm]]
    # Only levels 0 .. m-1 are read by the multiplication below, so the top
    # level (the product of all transitions) is not built.
    for i in range(1, m):
        T_br.append(polymatmul(T_br[i - 1][n >> i:], T_br[i - 1][:n >> i]))

    P_init = torch.tensor([p1, [p0, 0.0]], dtype=torch.float)  # [p_1, p_0]
    P_init = P_init.unsqueeze(0).unsqueeze(-2)

    v_br = v[:, br_perm]
    # Bottom-up multiplication algorithm to avoid recursion
    S_br = [None] * m
    Tidentity = torch.eye(2).unsqueeze(0).unsqueeze(3)
    # Level 0: second half of v_br scales the leaf transitions, first half is
    # added on the constant term through the 2 x 2 identity.
    S_br[0] = v_br[:, n // 2:, None, None, None] * T_br[0][:n // 2]
    S_br[0][:, :, :, :, :1] += v_br[:, :n // 2, None, None, None] * Tidentity
    for i in range(1, m):
        S_br[i] = polymatmul(S_br[i - 1][:, (n >> (i + 1)):], T_br[i][:(n >> (i + 1))])
        # The untouched first half of the previous level has fewer
        # coefficients, so it is added into the low-order slots only.
        S_br[i][:, :, :, :, :S_br[i - 1].shape[-1]] += S_br[i - 1][:, :(n >> (i + 1))]
    # Keep row 1 and the first n - 1 coefficients, then apply the initial
    # polynomials [p_1, p_0] and drop the three singleton axes.
    result = polymatmul(S_br[m - 1][:, :, [1], :, :n - 1], P_init).squeeze(1).squeeze(1).squeeze(1)
    return result
def _setup(self, config):
    """Prepare a butterfly model (optionally with learned order) for the DFT.

    Parameters:
        config: mapping read for the keys 'seed', 'size', 'fixed_order',
            'softmax_fn', 'lr', 'n_steps_per_epoch' and, only when the order
            is learned through a plain softmax, 'semantic_loss_weight'.

    Sets on self:
        model, optimizer, n_steps_per_epoch, target_matrix, br_perm, and
        semantic_loss_weight (conditionally, see below).
    """
    # Seed before constructing the model so its initialisation is repeatable.
    torch.manual_seed(config['seed'])

    size = config['size']
    fixed_order = config['fixed_order']
    softmax_fn = config['softmax_fn']
    self.model = ButterflyProduct(size=size,
                                  complex=True,
                                  fixed_order=fixed_order,
                                  softmax_fn=softmax_fn)

    # Stored only when the order is not fixed and a plain softmax is used.
    if not fixed_order and softmax_fn == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']

    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']

    # Target: legacy torch.fft (signal_ndim=1) applied to the identity after
    # real_to_complex.
    complex_identity = real_to_complex(torch.eye(size))
    self.target_matrix = torch.fft(complex_identity, 1)

    self.br_perm = torch.tensor(bitreversal_permutation(size))
def _setup(self, config):
    """Prepare an HstackDiagProduct model whose target is the Legendre basis.

    Parameters:
        config: mapping read for the keys 'seed', 'size', 'lr' and
            'n_steps_per_epoch'.

    Sets on self:
        model, optimizer, n_steps_per_epoch, target_matrix, br_perm, input
        and input_permuted.
    """
    # Seed before constructing the model so its initialisation is repeatable.
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = HstackDiagProduct(size=size)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']

    # Target: Legendre polynomials. Row i receives the result of
    # legendre.leg2poly applied to the i-th unit vector, written into its
    # first i + 1 columns.
    unit_vectors = np.eye(size)
    coefficients = np.zeros((size, size), dtype=np.float64)
    for degree in range(size):
        coefficients[degree, :degree + 1] = legendre.leg2poly(unit_vectors[degree])
    # NOTE(review): no dtype is given, so the target stays float64 while the
    # input below uses torch's default dtype - confirm this is intended.
    self.target_matrix = torch.tensor(coefficients)

    self.br_perm = bitreversal_permutation(size)

    # Each identity entry is multiplied by a 2 x 2 identity and a trailing
    # singleton axis is appended: shape (size, size, 2, 2, 1).
    eye_n = torch.eye(size)
    eye_2 = torch.eye(2)
    self.input = (eye_n[:, :, None, None] * eye_2).unsqueeze(-1)
    # Same input with its second axis reordered by the bit-reversal permutation.
    self.input_permuted = self.input[:, self.br_perm]
def _setup(self, config):
    """Prepare a Block2x2DiagProduct model for a user-supplied target matrix.

    Parameters:
        config: mapping read for the keys 'target_matrix', 'seed', 'lr',
            'n_steps_per_epoch' and 'n_epochs_per_validation'.

    Sets on self:
        target_matrix, model, optimizer, n_steps_per_epoch,
        n_epochs_per_validation and input.

    Raises:
        AssertionError: if the target is not 2-D or 3-D, or if its first two
            dimensions differ.
    """
    self.target_matrix = torch.tensor(config['target_matrix'], dtype=torch.float)
    # Check the rank first: the squareness check indexes shape[1], which
    # would raise IndexError on a 0-D or 1-D target before this message
    # could be shown.
    assert self.target_matrix.dim() in [
        2, 3
    ], 'target matrix must be 2D if real or 3D if complex'
    assert self.target_matrix.shape[0] == self.target_matrix.shape[
        1], 'Only square matrices are supported'
    size = self.target_matrix.shape[0]

    # Seed before constructing the model so its initialisation is repeatable.
    torch.manual_seed(config['seed'])
    self.model = Block2x2DiagProduct(size=size, complex=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    # Identity with its columns reordered by the bit-reversal permutation,
    # passed through real_to_complex.
    self.input = real_to_complex(
        torch.eye(size)[:, torch.tensor(bitreversal_permutation(size))])
def test_hstackdiag_product():
    """Check HstackDiagProduct against ops_transpose_mult_br for Legendre OPs.

    The model's factors are overwritten with the bit-reversed transition
    matrices of the Legendre recurrence, then its output on a random batch is
    compared with the reference function's output.
    """
    size = 8
    model = HstackDiagProduct(size)
    n_levels = int(np.log2(size))

    # Legendre polynomials:
    # P_{k+2}(x) = (2k+3)/(k+2) x P_{k+1}(x) - (k+1)/(k+2) P_k(x)
    k = torch.arange(size, dtype=torch.float)
    a = (2 * k + 3) / (k + 2)
    b = torch.zeros(size)
    c = -(k + 1) / (k + 2)
    p0 = 1.0
    p1 = (0.0, 1.0)

    # Preprocessing: compute T_{i:j}, the transition matrix from p_i to p_j.
    # Lowest level: `size` matrices, each 2 x 2, whose entries are polynomials
    # of degree <= 1 (last axis: constant term, then coefficient of x).
    leaf = torch.zeros(size, 2, 2, 2)
    leaf[:, 0, 0, 1] = a
    leaf[:, 0, 0, 0] = b
    leaf[:, 0, 1, 0] = c
    leaf[:, 1, 0, 0] = 1.0
    br_perm = bitreversal_permutation(size)
    transitions = [leaf[br_perm]]
    for level in range(1, n_levels):
        below = transitions[level - 1]
        half = size >> level
        transitions.append(polymatmul(below[half:], below[:half]))

    initial = torch.tensor([p1, [p0, 0.0]], dtype=torch.float)  # [p_1, p_0]
    initial = initial.unsqueeze(0).unsqueeze(-2)
    identity = torch.eye(2).unsqueeze(0).unsqueeze(3)

    model.P_init = nn.Parameter(initial)
    # Level `level` of the transitions goes into the factor counted from the
    # end of model.factors.
    for level in range(n_levels):
        factor = model.factors[n_levels - 1 - level]
        identity_part = identity.expand(factor.size, -1, -1, -1)
        zero_part = torch.zeros(factor.size, 2, 2, factor.deg)
        factor.diag1 = nn.Parameter(torch.cat((identity_part, zero_part), dim=-1))
        factor.diag2 = nn.Parameter(transitions[level][:factor.size])

    batch_size = 2
    x_original = torch.randn((batch_size, size))
    # Multiply each entry by a 2 x 2 identity and append a singleton axis.
    x = (x_original[:, :, None, None] * torch.eye(2)).unsqueeze(-1)
    output = model(x[:, br_perm])

    assert output.shape == (batch_size, size)
    expected = ops_transpose_mult_br(a, b, c, p0, p1, x_original)
    assert torch.allclose(output, expected)