class TrainableFftFactorSoftmax(PytorchTrainable):
    """Learn a butterfly-product factorization of the FFT (with bit-reversed columns),
    using softmax-weighted factor ordering plus a semantic loss that encourages the
    softmax over factor assignments to concentrate on exactly one choice.
    """

    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
        self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            semantic_loss = semantic_loss_exactly_one(
                nn.functional.log_softmax(self.model.logit, dim=-1))
            total_loss = loss + self.semantic_loss_weight * semantic_loss.mean()
            total_loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
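# For reference, a minimal sketch of what `bitreversal_permutation(size)` (defined
# elsewhere in this repo) is assumed to compute; the actual helper may differ in
# return type or implementation. Index i maps to the integer whose log2(size)-bit
# binary representation is the reverse of i's.
def bitreversal_permutation_sketch(size):
    """Bit-reversal permutation of range(size); size must be a power of 2."""
    n_bits = size.bit_length() - 1
    assert 1 << n_bits == size, 'size must be a power of 2'
    return np.array([int(format(i, '0{}b'.format(n_bits))[::-1], 2) for i in range(size)])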
def _setup(self, config):
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'])
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    n = size
    np.random.seed(0)
    x = np.random.randn(n)
    V = np.vander(x, increasing=True)
    self.target_matrix = torch.tensor(V, dtype=torch.float)
    arange_ = np.arange(size)
    dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
    br_perm = bitreversal_permutation(size)
    assert config['perm'] in ['id', 'br', 'dct']
    if config['perm'] == 'id':
        self.perm = torch.arange(size)
    elif config['perm'] == 'br':
        self.perm = br_perm
    elif config['perm'] == 'dct':
        self.perm = torch.arange(size)[dct_perm][br_perm]
    else:
        assert False, 'Wrong perm in config'
class TrainableHadamard(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.tensor(hadamard(config['size']), dtype=torch.float)

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
class TrainableFftLearnPerm(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'],
                                      learn_perm=True)
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)

    def _train(self):
        # Anneal the permutation temperature as training progresses.
        temperature = 1.0 / (0.3 * self._iteration + 1)
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix(temperature)
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -polished_loss_fft_learn_perm(self)}
def polish_dct_complex(trial):
    """Load model from checkpoint, then fix the order of the factor matrices
    (using the largest logits), and re-optimize using L-BFGS to find the
    nearest local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                  trainable.target_matrix)
    return loss.item()
def polished_loss_fft_learn_perm(trainable):
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size, complex=model.complex, fixed_order=True)
    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    trainable.perm = torch.argmax(sinkhorn(model.perm_logit / temperature), dim=1)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList(
            [model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    preopt_loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                         trainable.target_matrix)
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                  trainable.target_matrix)
    # return loss.item() if not torch.isnan(loss) else preopt_loss.item() if not torch.isnan(preopt_loss) else float('inf')
    return loss.item() if not torch.isnan(loss) else (
        preopt_loss.item() if not torch.isnan(preopt_loss) else 9999.0)
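# The `sinkhorn` function used above is defined elsewhere in this repo. For
# reference, a minimal log-space Sinkhorn normalization sketch is given below;
# this is an assumption about its behavior (alternating row/column normalization
# of exp(logit) toward a doubly-stochastic matrix), not the repo's implementation.
def sinkhorn_sketch(logit, n_iters=10):
    """Approximately project exp(logit) onto the set of doubly-stochastic matrices."""
    log_alpha = logit
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)  # normalize rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)  # normalize columns
    return torch.exp(log_alpha)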
class TrainableFft(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))
        # br_perm = bitreversal_permutation(size)
        # br_reverse = torch.tensor(list(br_perm[::-1]))
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))  # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
        # self.br_perm = br_reverse
        # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
        # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
        # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
        # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
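# A minimal usage sketch for launching one of these trainables with Ray Tune's
# class-based API (which the _setup/_train methods above follow). The hyperparameter
# values and stopping criterion are illustrative assumptions, not the settings used
# for the actual experiments.
def example_run_trainable_fft():
    import ray
    from ray import tune
    ray.init()
    return tune.run(
        TrainableFft,
        config={
            'size': 8,
            'seed': tune.grid_search(list(range(4))),
            'lr': 1e-3,
            'n_steps_per_epoch': 100,
            'fixed_order': False,
            'softmax_fn': 'softmax',
            'semantic_loss_weight': 1.0,
        },
        stop={'training_iteration': 50},
    )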
class TrainableVandermondeReal(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=False,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        n = size
        np.random.seed(0)
        x = np.random.randn(n)
        V = np.vander(x, increasing=True)
        self.target_matrix = torch.tensor(V, dtype=torch.float)
        arange_ = np.arange(size)
        dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
        br_perm = bitreversal_permutation(size)
        assert config['perm'] in ['id', 'br', 'dct']
        if config['perm'] == 'id':
            self.perm = torch.arange(size)
        elif config['perm'] == 'br':
            self.perm = br_perm
        elif config['perm'] == 'dct':
            self.perm = torch.arange(size)[dct_perm][br_perm]
        else:
            assert False, 'Wrong perm in config'

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
class TrainableFftFactorFixedOrder(PytorchTrainable):

    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
class TrainableRandnFactorSoftmaxNoPerm(PytorchTrainable):

    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=False, fixed_order=False, softmax_fn='softmax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.rand(size, size, requires_grad=False)

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()
            loss = nn.functional.mse_loss(y, self.target_matrix)
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
def hadamard_test():
    # Hadamard matrix for n = 4
    size = 4
    M0 = Butterfly(size,
                   diagonal=2,
                   diag=torch.tensor([1.0, 1.0, -1.0, -1.0], requires_grad=True),
                   subdiag=torch.ones(2, requires_grad=True),
                   superdiag=torch.ones(2, requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   diag=torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True),
                   subdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True),
                   superdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True))
    H = M0.matrix() @ M1.matrix()
    assert torch.allclose(H, torch.tensor(hadamard(4), dtype=torch.float))
    M = ButterflyProduct(size, fixed_order=True)
    M.factors[0] = M0
    M.factors[1] = M1
    assert torch.allclose(M.matrix(), H)
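# A companion sanity check for the FFT target used throughout this file, assuming
# the pre-1.8 torch.fft(tensor, signal_ndim) API that the code above relies on and
# that real_to_complex stacks a zero imaginary channel (the complex input is built
# directly here to keep the check self-contained).
def dft_target_test():
    size = 8
    eye_complex = torch.stack((torch.eye(size), torch.zeros(size, size)), dim=-1)
    target = torch.fft(eye_complex, 1)
    dft = np.fft.fft(np.eye(size))
    assert torch.allclose(target[..., 0], torch.tensor(dft.real, dtype=torch.float), atol=1e-5)
    assert torch.allclose(target[..., 1], torch.tensor(dft.imag, dtype=torch.float), atol=1e-5)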
class TrainableFftFactorSparsemax(TrainableFftFactorFixedOrder):

    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False, softmax_fn='sparsemax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))
def polish_fft(trial):
    """Load model from checkpoint, then fix the order of the butterfly factors
    (using the largest logits), and re-optimize using L-BFGS to find the
    nearest local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        # print(maxes)
        # if torch.all(maxes >= 0.99):
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
        # else:
        #     return -trial.last_result['negative_loss']
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                  trainable.target_matrix)
    return loss.item()
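# A sketch of how the polishing functions above might be applied after an experiment
# finishes, assuming `trials` is an iterable of Ray Tune Trial objects whose
# last_result carries the 'negative_loss' reported by _train; selecting the top few
# trials is an illustrative choice, not the repo's actual driver logic.
def polish_best_trials(trials, top_k=5):
    best_trials = sorted(trials,
                         key=lambda trial: trial.last_result['negative_loss'],
                         reverse=True)[:top_k]
    # Re-optimize each of the best trials with L-BFGS and return the polished losses.
    return [polish_fft(trial) for trial in best_trials]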