def polish_dct_complex(trial):
    """Load model from checkpoint, then fix the order of the factor matrices
    (using the largest logits), and re-optimize with L-BFGS to find the nearest
    local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # For each position in the product, keep the factor with the largest
        # softmax probability.
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    # Refine the fixed-order factors with L-BFGS.
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                  trainable.target_matrix)
    return loss.item()
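
# Hedged illustration (added, not from the original file): a minimal, self-contained
# sketch of the argmax-based selection used in polish_dct_complex above.
def _fixed_order_selection_sketch():
    """Compute softmax probabilities over candidate factors per position and copy
    the most likely factor at each position into a fixed-order ModuleList. The
    sizes and the nn.Linear stand-in modules are hypothetical, for illustration.
    """
    logits = torch.randn(4, 4)                     # 4 positions x 4 candidate factors (made up)
    prob = torch.softmax(logits, dim=-1)
    maxes, argmaxes = torch.max(prob, dim=-1)      # most likely candidate per position
    factors = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])   # stand-in factor modules
    fixed_order = nn.ModuleList([factors[int(i)] for i in argmaxes])
    return fixed_order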
def polish_fft(trial):
    """Same polishing procedure as polish_dct_complex, but for the FFT target:
    fix the order of the butterfly factors (using the largest logits) and
    re-optimize with L-BFGS under the bit-reversal permutation.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # For each position in the product, keep the butterfly with the largest
        # softmax probability.
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        # print(maxes)
        # if torch.all(maxes >= 0.99):
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
        # else:
        #     return -trial.last_result['negative_loss']
    else:
        polished_model.butterflies = model.butterflies
    # Refine the fixed-order butterflies with L-BFGS.
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm],
                                  trainable.target_matrix)
    return loss.item()
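
# Hedged usage sketch (added, not from the original file): how the polish functions
# above might be driven over a set of completed trials.
def _polish_all_fft_sketch(trials):
    """Assuming `trials` is a list of completed Ray Tune trial objects from the
    matching FFT experiment, polish each one and return the smallest MSE after
    L-BFGS refinement. The function name and `trials` argument are hypothetical.
    """
    return min(polish_fft(trial) for trial in trials)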