Example 1
def polish_dct_complex(trial):
    """Load model from checkpoint, then fix the order of the factor
    matrices (using the largest logits), and re-optimize using L-BFGS to find
    the nearest local optima.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex, fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList([model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())
    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0], trainable.target_matrix)
        loss.backward()
        return loss
    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(), str((Path(trial.logdir) / trial._checkpoint.value).parent / 'polished_model.pth'))
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm, 0], trainable.target_matrix)
    return loss.item()
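The polishing helpers in these examples are meant to be run over completed hyperparameter-search trials. A minimal driver sketch follows; it relies only on what the code above already uses (trial objects exposing trainable_name, config, logdir and _checkpoint), and the polish_all name itself is illustrative.

def polish_all(trials):
    # Hedged sketch: apply the polishing pass to an iterable of finished trials.
    # Each call restores the checkpoint, fixes the factor order, re-optimizes
    # with L-BFGS and returns the final MSE against the target matrix.
    return [polish_dct_complex(trial) for trial in trials]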
Example 2
 def _setup(self, config):
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=config['size'],
                                   complex=True,
                                   fixed_order=config['fixed_order'],
                                   softmax_fn=config['softmax_fn'])
     if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
         self.semantic_loss_weight = config['semantic_loss_weight']
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     size = config['size']
     n = size
     np.random.seed(0)
     x = np.random.randn(n)
     V = np.vander(x, increasing=True)
     self.target_matrix = torch.tensor(V, dtype=torch.float)
     arange_ = np.arange(size)
     dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
     br_perm = bitreversal_permutation(size)
     assert config['perm'] in ['id', 'br', 'dct']
     if config['perm'] == 'id':
         self.perm = torch.arange(size)
     elif config['perm'] == 'br':
         self.perm = br_perm
     elif config['perm'] == 'dct':
         self.perm = torch.arange(size)[dct_perm][br_perm]
     else:
         assert False, 'Wrong perm in config'
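The dct_perm built above is the standard even/odd reordering used when computing a DCT through an FFT: the even indices come first, followed by the odd indices in reverse. A tiny standalone check (plain numpy, size 8) illustrates the ordering:

import numpy as np

size = 8
arange_ = np.arange(size)
dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
print(dct_perm)  # [0 2 4 6 7 5 3 1] -> even indices, then odd indices reversed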
Example 3
def polished_loss_fft_learn_perm(trainable):
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size,
                                      complex=model.complex,
                                      fixed_order=True)
    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    trainable.perm = torch.argmax(sinkhorn(model.perm_logit / temperature),
                                  dim=1)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList(
            [model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    preopt_loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.perm], trainable.target_matrix)
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                  trainable.target_matrix)
    # Fall back to the pre-optimization loss if L-BFGS diverged, or to a large
    # sentinel value if both losses are NaN.
    if not torch.isnan(loss):
        return loss.item()
    return preopt_loss.item() if not torch.isnan(preopt_loss) else 9999.0
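The sinkhorn call above is only used to map permutation logits to a (near-)doubly-stochastic matrix whose row-wise argmax gives a hard permutation; its implementation is not part of this snippet. As a rough illustration, and only as an assumption about what such a helper does, the usual log-space Sinkhorn normalization looks like this:

import torch

def sinkhorn_sketch(logit, n_iters=10):
    # Alternately normalize rows and columns of exp(logit) in log space;
    # the iterates converge towards a doubly-stochastic matrix.
    log_alpha = logit
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=0, keepdim=True)
    return torch.exp(log_alpha)

# Mirroring the usage above: divide the logits by a temperature, then take the
# row-wise argmax to extract a hard permutation.
perm = torch.argmax(sinkhorn_sketch(torch.randn(8, 8) / 0.5), dim=1)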
Example 4
 def _setup(self, config):
     size = config['size']
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=size,
                                   complex=True,
                                   fixed_order=True)
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
     self.br_perm = torch.tensor(bitreversal_permutation(size))
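The FFT target above is built with the pre-1.8 torch.fft(tensor, signal_ndim) API, which represents complex values as a trailing dimension of size 2. The real_to_complex helper is not shown here; presumably it just appends a zero imaginary part, along the lines of this sketch (an assumption, not the project's actual definition):

import torch

def real_to_complex_sketch(x):
    # Append a zero imaginary component so x gets shape (..., 2), the complex
    # layout expected by the old torch.fft(tensor, signal_ndim) API.
    return torch.stack((x, torch.zeros_like(x)), dim=-1)

# On PyTorch >= 1.8 the same DFT-matrix target can be built with the new API:
# target = torch.view_as_real(torch.fft.fft(torch.eye(size).to(torch.complex64), dim=1))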
Example 5
 def _setup(self, config):
     size = config['size']
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=size,
                                   complex=False,
                                   fixed_order=False,
                                   softmax_fn='softmax')
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.target_matrix = torch.rand(size, size, requires_grad=False)
Example 6
 def _setup(self, config):
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=config['size'],
                                   fixed_order=config['fixed_order'],
                                   softmax_fn=config['softmax_fn'])
     if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
         self.semantic_loss_weight = config['semantic_loss_weight']
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.target_matrix = torch.tensor(hadamard(config['size']),
                                       dtype=torch.float)
Example 7
 def _setup(self, config):
     torch.manual_seed(config['seed'])
     self.model = ButterflyProduct(size=config['size'],
                                   complex=True,
                                   fixed_order=config['fixed_order'],
                                   softmax_fn=config['softmax_fn'],
                                   learn_perm=True)
     if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
         self.semantic_loss_weight = config['semantic_loss_weight']
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     size = config['size']
     self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
Example 8
def hadamard_test():
    # Hadamard matrix for n = 4
    size = 4
    M0 = Butterfly(size,
                   diagonal=2,
                   diag=torch.tensor([1.0, 1.0, -1.0, -1.0],
                                     requires_grad=True),
                   subdiag=torch.ones(2, requires_grad=True),
                   superdiag=torch.ones(2, requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   diag=torch.tensor([1.0, -1.0, 1.0, -1.0],
                                     requires_grad=True),
                   subdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True),
                   superdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True))
    H = M0.matrix() @ M1.matrix()
    assert torch.allclose(H, torch.tensor(hadamard(4), dtype=torch.float))
    M = ButterflyProduct(size, fixed_order=True)
    M.factors[0] = M0
    M.factors[1] = M1
    assert torch.allclose(M.matrix(), H)
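For reference, scipy.linalg.hadamard(4) is the Sylvester-construction matrix that the two butterfly factors above are asserted to reproduce:

from scipy.linalg import hadamard

print(hadamard(4))
# [[ 1  1  1  1]
#  [ 1 -1  1 -1]
#  [ 1  1 -1 -1]
#  [ 1 -1 -1  1]]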
Example 9
def polish_fft(trial):
    """Load the model from a checkpoint, fix the order of the butterfly
    factors, and re-optimize with L-BFGS against the bit-reversed FFT target.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        # print(maxes)
        # if torch.all(maxes >= 0.99):
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
        # else:
        #     return -trial.last_result['negative_loss']
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.br_perm],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(
        polished_model.state_dict(),
        str((Path(trial.logdir) / trial._checkpoint.value).parent /
            'polished_model.pth'))
    loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.br_perm], trainable.target_matrix)
    return loss.item()
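Several of the snippets index the learned matrix with bitreversal_permutation(size), whose definition is not included here. For a power-of-two size it presumably returns the standard bit-reversal ordering; the following is an illustrative reimplementation under that assumption, not the project's own code:

import numpy as np

def bitreversal_permutation_sketch(n):
    # Return the bit-reversal permutation of range(n); n must be a power of 2.
    log_n = int(np.log2(n))
    assert 1 << log_n == n, 'n must be a power of 2'
    perm = np.zeros(n, dtype=int)
    for i in range(n):
        bits = format(i, '0{}b'.format(log_n))  # log_n-bit binary string of i
        perm[i] = int(bits[::-1], 2)            # reverse the bits
    return perm

print(bitreversal_permutation_sketch(8))  # [0 4 2 6 1 5 3 7]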