Esempio n. 1
0
def polish_dct_complex(trial):
    """Restore a trial's checkpointed model, fix its factor order, and polish it.

    The trainable named by ``trial.trainable_name`` is rebuilt from
    ``trial.config`` and restored from the trial's checkpoint. A
    fixed-order ``ButterflyProduct`` is then assembled from the restored
    model's factors: when the restored model does not already have a fixed
    order, the factor with the largest softmax probability is picked for
    each position. The result is re-optimized with L-BFGS for
    ``N_LBFGS_STEPS`` steps against ``trainable.target_matrix`` and its
    state dict is saved as ``polished_model.pth`` in the checkpoint's
    parent directory.

    Args:
        trial: object exposing ``trainable_name``, ``config``, ``logdir``
            and ``_checkpoint.value``.

    Returns:
        The mean-squared error (Python float) between the polished matrix
        and the target matrix, evaluated after optimization and saving.
    """
    # NOTE(review): eval() runs whatever expression trainable_name holds;
    # confirm this value always comes from a trusted source.
    trainable_cls = eval(trial.trainable_name)
    trainable = trainable_cls(trial.config)
    checkpoint_path = Path(trial.logdir) / trial._checkpoint.value
    trainable.restore(str(checkpoint_path))
    source_model = trainable.model

    polished_model = ButterflyProduct(size=trial.config['size'],
                                      complex=source_model.complex,
                                      fixed_order=True)
    if source_model.fixed_order:
        polished_model.factors = source_model.factors
    else:
        # Keep, for each position, the factor with the highest probability.
        weights = source_model.softmax_fn(source_model.logit)
        best_indices = torch.max(weights, dim=-1)[1]
        polished_model.factors = nn.ModuleList(
            [source_model.factors[idx] for idx in best_indices])

    def target_mismatch():
        # Columns are reordered by trainable.perm and index 0 of the last
        # axis is kept (presumably the real part -- TODO confirm).
        return nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm, 0],
            trainable.target_matrix)

    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        # L-BFGS re-evaluates the objective itself, hence the closure.
        optimizer.zero_grad()
        step_loss = target_mismatch()
        step_loss.backward()
        return step_loss

    for _ in range(N_LBFGS_STEPS):
        optimizer.step(closure)

    torch.save(polished_model.state_dict(),
               str(checkpoint_path.parent / 'polished_model.pth'))
    return target_mismatch().item()
Esempio n. 2
0
def polished_loss_fft_learn_perm(trainable):
    """Return the loss of a fixed-order, L-BFGS-polished copy of the model.

    Side effect: ``trainable.perm`` is overwritten with the row-wise argmax
    of ``sinkhorn(model.perm_logit / temperature)``, where the temperature
    is ``1 / (0.3 * trainable._iteration + 1)``.

    A fixed-order ``ButterflyProduct`` is assembled from the model's
    factors (choosing the most probable factor per position when the model
    does not already have a fixed order) and optimized with L-BFGS for
    ``N_LBFGS_STEPS_VALIDATION`` steps against ``trainable.target_matrix``.

    NOTE(review): the polished model reuses the source model's factor
    modules rather than copies, so the L-BFGS steps appear to update the
    source model's parameters in place -- confirm this is intended.

    Args:
        trainable: object exposing ``model``, ``_iteration`` and
            ``target_matrix``; its ``perm`` attribute is assigned here.

    Returns:
        The post-optimization MSE as a float; if that is NaN, the
        pre-optimization MSE; if that is NaN too, the sentinel ``9999.0``.
    """
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size,
                                      complex=model.complex,
                                      fixed_order=True)

    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    soft_perm = sinkhorn(model.perm_logit / temperature)
    trainable.perm = torch.argmax(soft_perm, dim=1)

    if model.fixed_order:
        polished_model.factors = model.factors
    else:
        # Keep, for each position, the factor with the highest probability.
        weights = model.softmax_fn(model.logit)
        best_indices = torch.max(weights, dim=-1)[1]
        polished_model.factors = nn.ModuleList(
            [model.factors[idx] for idx in best_indices])

    def target_mismatch():
        # MSE between the column-permuted product and the target matrix.
        return nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm],
            trainable.target_matrix)

    # Recorded before optimizing so it can serve as a fallback result.
    preopt_loss = target_mismatch()
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        # L-BFGS re-evaluates the objective itself, hence the closure.
        optimizer.zero_grad()
        step_loss = target_mismatch()
        step_loss.backward()
        return step_loss

    for _ in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)

    final_loss = target_mismatch()
    if not torch.isnan(final_loss):
        return final_loss.item()
    if not torch.isnan(preopt_loss):
        return preopt_loss.item()
    # Both evaluations are NaN: report a large finite sentinel.
    return 9999.0