from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

# ButterflyProduct, sinkhorn, N_LBFGS_STEPS, and N_LBFGS_STEPS_VALIDATION are
# assumed to be defined elsewhere in this module or imported from the project.


def polish_dct_complex(trial):
    """Load the model from a checkpoint, fix the order of the factor matrices
    (taking the factors with the largest logits), and re-optimize with L-BFGS
    to find the nearest local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # For each position, keep the factor with the highest softmax probability.
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(),
               str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0],
                                  trainable.target_matrix)
    return loss.item()
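# --- Illustrative sketch (not part of the original code): the closure-based
# L-BFGS polishing pattern used above, shown in isolation on a toy problem.
# The names `_lbfgs_polish_sketch`, `factor1`, `factor2`, and `n_steps` are
# hypothetical.
def _lbfgs_polish_sketch(n_steps=20):
    """Fit the product of two small dense factors to a random target with L-BFGS."""
    torch.manual_seed(0)
    target = torch.randn(8, 8)
    factor1 = nn.Parameter(torch.randn(8, 8) / 8 ** 0.5)
    factor2 = nn.Parameter(torch.randn(8, 8) / 8 ** 0.5)
    optimizer = optim.LBFGS([factor1, factor2])

    def closure():
        # L-BFGS may re-evaluate the objective several times per step,
        # so the loss and gradients are recomputed inside the closure.
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(factor1 @ factor2, target)
        loss.backward()
        return loss

    for _ in range(n_steps):
        optimizer.step(closure)
    return closure().item()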
def polished_loss_fft_learn_perm(trainable):
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size, complex=model.complex,
                                      fixed_order=True)
    # Anneal the Sinkhorn temperature with the training iteration, then harden
    # the soft permutation with a row-wise argmax.
    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    trainable.perm = torch.argmax(sinkhorn(model.perm_logit / temperature), dim=1)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    preopt_loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                         trainable.target_matrix)
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                      trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                  trainable.target_matrix)
    # If L-BFGS diverged to NaN, fall back to the pre-optimization loss, or to a
    # large sentinel value if that is NaN as well.
    if not torch.isnan(loss):
        return loss.item()
    return preopt_loss.item() if not torch.isnan(preopt_loss) else 9999.0
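# --- Illustrative sketch (not part of the original code): how a permutation
# logit matrix can be relaxed into a (near) doubly-stochastic matrix and then
# hardened with a row-wise argmax, as in `polished_loss_fft_learn_perm`.
# `_sinkhorn_sketch` and `_harden_perm_sketch` are hypothetical stand-ins; the
# repo's `sinkhorn` may be implemented differently.
def _sinkhorn_sketch(logit, n_iters=10):
    """Alternate row and column normalization of exp(logit) (Sinkhorn iterations)."""
    mat = torch.exp(logit)
    for _ in range(n_iters):
        mat = mat / mat.sum(dim=1, keepdim=True)  # normalize rows
        mat = mat / mat.sum(dim=0, keepdim=True)  # normalize columns
    return mat


def _harden_perm_sketch(perm_logit, iteration):
    """Anneal the temperature with the training iteration, then take a hard argmax."""
    temperature = 1.0 / (0.3 * iteration + 1)  # logits get sharper as training progresses
    return torch.argmax(_sinkhorn_sketch(perm_logit / temperature), dim=1)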