def polish_dct_complex(trial):
    """Load model from checkpoint, then fix the order of the factor matrices
    (using the largest logits), and re-optimize using L-BFGS to find the
    nearest local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # Commit to the most likely factor ordering implied by the learned logits.
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                  trainable.target_matrix)
    return loss.item()
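# Illustrative sketch (not part of the original pipeline): the polish_* helpers
# in this file all rely on the standard torch.optim.LBFGS closure pattern,
# shown here on a hypothetical 8x8 least-squares fit. L-BFGS may re-evaluate
# the objective several times per step, so the closure must recompute the loss
# and gradients from scratch on every call.
def _lbfgs_closure_demo(n_steps=10):
    target = torch.randn(8, 8)
    weight = nn.Parameter(torch.randn(8, 8))
    optimizer = optim.LBFGS([weight])

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(weight, target)
        loss.backward()
        return loss

    for _ in range(n_steps):
        optimizer.step(closure)
    return closure().item()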
def _setup(self, config):
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'], complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'])
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Target: a fixed random Vandermonde matrix (seeded for reproducibility).
    np.random.seed(0)
    x = np.random.randn(size)
    V = np.vander(x, increasing=True)
    self.target_matrix = torch.tensor(V, dtype=torch.float)
    # DCT reordering: even indices ascending, then odd indices descending.
    arange_ = np.arange(size)
    dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
    br_perm = bitreversal_permutation(size)
    assert config['perm'] in ['id', 'br', 'dct'], 'Wrong perm in config'
    if config['perm'] == 'id':
        self.perm = torch.arange(size)
    elif config['perm'] == 'br':
        self.perm = br_perm
    else:  # 'dct': compose the DCT reordering with bit reversal
        self.perm = torch.arange(size)[dct_perm][br_perm]
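# Illustrative sketch: the DCT reordering built in _setup above, evaluated at
# size 8. Even indices come first in ascending order, followed by odd indices
# in descending order.
def _dct_perm_demo():
    arange_ = np.arange(8)
    dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
    assert list(dct_perm) == [0, 2, 4, 6, 7, 5, 3, 1]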
def polished_loss_fft_learn_perm(trainable):
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size, complex=model.complex,
                                      fixed_order=True)
    # Anneal the Sinkhorn temperature toward zero as training progresses, then
    # round the doubly stochastic matrix to a hard permutation.
    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    trainable.perm = torch.argmax(sinkhorn(model.perm_logit / temperature), dim=1)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    preopt_loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                         trainable.target_matrix)
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                  trainable.target_matrix)
    # Fall back to the pre-optimization loss if L-BFGS diverged to NaN.
    if not torch.isnan(loss):
        return loss.item()
    return preopt_loss.item() if not torch.isnan(preopt_loss) else 9999.0
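# Assumption for context: sinkhorn() above is taken to project exp(logit) onto
# (approximately) doubly stochastic matrices by alternating row and column
# normalization (Sinkhorn-Knopp). The repo's actual implementation may differ,
# e.g. by iterating in log space for numerical stability.
def _sinkhorn_sketch(logit, n_iters=10):
    mat = torch.exp(logit)
    for _ in range(n_iters):
        mat = mat / mat.sum(dim=1, keepdim=True)  # normalize rows
        mat = mat / mat.sum(dim=0, keepdim=True)  # normalize columns
    return mat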
def _setup(self, config):
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Target: the DFT matrix, i.e. the FFT of the identity (legacy torch.fft API).
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
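# Assumption for context: bitreversal_permutation is defined elsewhere in the
# repo. A minimal sketch consistent with its use here (size must be a power of
# 2) reverses the binary representation of each index; e.g. size 8 gives
# [0, 4, 2, 6, 1, 5, 3, 7].
def _bitreversal_permutation_sketch(size):
    n_bits = size.bit_length() - 1
    return [int(format(i, '0{}b'.format(n_bits))[::-1], 2) for i in range(size)]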
def _setup(self, config):
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=False, fixed_order=False,
                                  softmax_fn='softmax')
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.target_matrix = torch.rand(size, size, requires_grad=False)
def _setup(self, config):
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'])
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.target_matrix = torch.tensor(hadamard(config['size']), dtype=torch.float)
def _setup(self, config):
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'], complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'],
                                  learn_perm=True)
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Target: the DFT matrix (legacy torch.fft API); the permutation is learned.
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
def hadamard_test():
    # Hadamard matrix for n = 4, built as a product of two butterfly factors.
    size = 4
    M0 = Butterfly(size,
                   diagonal=2,
                   diag=torch.tensor([1.0, 1.0, -1.0, -1.0], requires_grad=True),
                   subdiag=torch.ones(2, requires_grad=True),
                   superdiag=torch.ones(2, requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   diag=torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True),
                   subdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True),
                   superdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True))
    H = M0.matrix() @ M1.matrix()
    assert torch.allclose(H, torch.tensor(hadamard(4), dtype=torch.float))
    # The same product expressed as a fixed-order ButterflyProduct.
    M = ButterflyProduct(size, fixed_order=True)
    M.factors[0] = M0
    M.factors[1] = M1
    assert torch.allclose(M.matrix(), H)
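# Context for hadamard_test: scipy.linalg.hadamard uses the Sylvester
# construction H_{2n} = [[H_n, H_n], [H_n, -H_n]], which is why a two-factor
# butterfly product suffices at size 4. A quick sanity check of that
# construction:
def _sylvester_demo():
    assert np.array_equal(hadamard(4), np.kron(hadamard(2), hadamard(2)))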
def polish_fft(trial):
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # Commit to the most likely factor ordering implied by the learned logits.
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                  trainable.target_matrix)
    return loss.item()
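# Hypothetical usage sketch: polishing every trial of a finished Ray Tune
# experiment and reporting the best loss. The `trials` list and the Trial
# attributes consumed by polish_fft depend on this (older) Ray Tune version
# and are assumptions, not part of the original file.
def _polish_all(trials):
    polished_losses = [polish_fft(trial) for trial in trials]
    return min(polished_losses)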