Example 1
from pathlib import Path

import torch
from torch import nn, optim

# ButterflyProduct, N_LBFGS_STEPS, and the Trainable classes referenced by
# trial.trainable_name are assumed to be defined in the surrounding module.


def polish_dct_complex(trial):
    """Load the model from its checkpoint, fix the order of the factor
    matrices (picking the factor with the largest logit at each position),
    and re-optimize with L-BFGS to find the nearest local optimum.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        # The factor order was learned: freeze it at the argmax of the
        # softmax over the logits.
        prob = model.softmax_fn(model.logit)
        _, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList(
            [model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm, 0],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    # Save the polished weights next to the checkpoint they came from.
    torch.save(
        polished_model.state_dict(),
        str((Path(trial.logdir) / trial._checkpoint.value).parent /
            'polished_model.pth'))
    loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.perm, 0],
        trainable.target_matrix)
    return loss.item()
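
In context, polish_dct_complex would be applied to a batch of finished Ray Tune trials and the resulting losses compared. A minimal driver sketch, assuming only that trials is an iterable of trial objects with restorable checkpoints (the helper name polish_all is illustrative, not part of the original code):

def polish_all(trials):
    # Polish every trial and return all final MSEs, best first.
    losses = sorted(polish_dct_complex(trial) for trial in trials)
    print(f'Best polished loss: {losses[0]:.3e}')
    return losses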
Example 2
from pathlib import Path

import torch
from torch import nn, optim

# Same module-level assumptions as Example 1: ButterflyProduct,
# N_LBFGS_STEPS, and the Trainable classes are defined elsewhere.


def polish_fft(trial):
    """Same polishing procedure as Example 1, but for an FFT target:
    the model's output columns are indexed by the bit-reversal
    permutation (trainable.br_perm) before comparing to the target matrix.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        # One could skip polishing when the softmax is not yet confident,
        # e.g.:
        # if not torch.all(maxes >= 0.99):
        #     return -trial.last_result['negative_loss']
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.br_perm],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(
        polished_model.state_dict(),
        str((Path(trial.logdir) / trial._checkpoint.value).parent /
            'polished_model.pth'))
    loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.br_perm],
        trainable.target_matrix)
    return loss.item()
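
Both examples rely on the same L-BFGS idiom: optim.LBFGS may re-evaluate the objective several times within a single step, so step must be passed a closure that zeroes the gradients, recomputes the loss, backpropagates, and returns the loss. A self-contained toy sketch of that pattern (a plain least-squares fit, not tied to the butterfly models above):

import torch
from torch import nn, optim

target = torch.randn(8, 8)
param = nn.Parameter(torch.zeros(8, 8))
optimizer = optim.LBFGS([param], max_iter=20)

def closure():
    optimizer.zero_grad()                     # clear stale gradients
    loss = nn.functional.mse_loss(param, target)
    loss.backward()                           # populate param.grad
    return loss                               # LBFGS needs the loss back

for _ in range(5):
    optimizer.step(closure)
print(nn.functional.mse_loss(param, target).item())  # ~0 after convergence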