Example 1
def polished_loss_fft_learn_perm(trainable):
    model = trainable.model
    polished_model = ButterflyProduct(size=model.size,
                                      complex=model.complex,
                                      fixed_order=True)
    temperature = 1.0 / (0.3 * trainable._iteration + 1)
    trainable.perm = torch.argmax(sinkhorn(model.perm_logit / temperature),
                                  dim=1)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList(
            [model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    preopt_loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.perm], trainable.target_matrix)
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS_VALIDATION):
        optimizer.step(closure)
    loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.perm],
                                  trainable.target_matrix)
    # return loss.item() if not torch.isnan(loss) else preopt_loss.item() if not torch.isnan(preopt_loss) else float('inf')
    return (loss.item() if not torch.isnan(loss)
            else preopt_loss.item() if not torch.isnan(preopt_loss)
            else 9999.0)
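
The sinkhorn() helper used above is not shown in this listing. Below is a minimal sketch, assuming a standard log-domain Sinkhorn projection, of how the temperature-scaled logits are turned into a hard permutation; sinkhorn_sketch, its iteration count, and the sizes are illustrative, not the repo's implementation.

import torch

def sinkhorn_sketch(logit, n_iters=10):
    """Approximately project exp(logit) onto doubly-stochastic matrices by
    alternating row and column normalization in log space."""
    log_alpha = logit
    for _ in range(n_iters):
        # Normalize rows, then columns (each sums to 1 in probability space).
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=0, keepdim=True)
    return torch.exp(log_alpha)

# Low temperature sharpens the (approximately) doubly-stochastic matrix; a
# row-wise argmax then reads off a hard permutation, as in the code above.
perm_logit = torch.randn(8, 8)
temperature = 0.1
perm = torch.argmax(sinkhorn_sketch(perm_logit / temperature), dim=1)
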
Example 2
def polish_dct_complex(trial):
    """Load model from checkpoint, then fix the order of the factor
    matrices (using the largest logits), and re-optimize using L-BFGS to find
    the nearest local optima.
    """
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        polished_model.factors = nn.ModuleList(
            [model.factors[argmax] for argmax in argmaxes])
    else:
        polished_model.factors = model.factors
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.perm, 0],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(
        polished_model.state_dict(),
        str((Path(trial.logdir) / trial._checkpoint.value).parent /
            'polished_model.pth'))
    loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.perm, 0], trainable.target_matrix)
    return loss.item()
Example 3
class TrainableFftLearnPerm(PytorchTrainable):
    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'],
                                      learn_perm=True)
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)

    def _train(self):
        temperature = 1.0 / (0.3 * self._iteration + 1)
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix(temperature)
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(
                    self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -polished_loss_fft_learn_perm(self)}
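
semantic_loss_exactly_one() is also imported from elsewhere in the repo. As a rough guide, the semantic loss for an "exactly one of n is selected" constraint is -log sum_i [ p_i * prod_{j != i} (1 - p_j) ]; the sketch below is an illustrative implementation of that formula from log-probabilities, not necessarily the repo's.

import torch

def semantic_loss_exactly_one_sketch(log_prob):
    """log_prob: (..., n) log-probabilities; returns the loss per row."""
    prob = log_prob.exp()
    one_minus = (1.0 - prob).clamp(min=1e-12)
    # prod_{j != i} (1 - p_j) = [prod_j (1 - p_j)] / (1 - p_i)
    prod_all = one_minus.prod(dim=-1, keepdim=True)
    per_i = prob * prod_all / one_minus
    return -torch.log(per_i.sum(dim=-1).clamp(min=1e-12))
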
Example 4
class TrainableFftFactorSoftmax(PytorchTrainable):
    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size,
                                      complex=True,
                                      fixed_order=False)
        self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            semantic_loss = semantic_loss_exactly_one(
                nn.functional.log_softmax(self.model.logit, dim=-1))
            total_loss = (loss +
                          self.semantic_loss_weight * semantic_loss.mean())
            total_loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
Example 5
class TrainableHadamard(PytorchTrainable):
    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.tensor(hadamard(config['size']),
                                          dtype=torch.float)

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(
                    self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
Example 6
class TrainableFft(PytorchTrainable):
    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))
        # br_perm = bitreversal_permutation(size)
        # br_reverse = torch.tensor(list(br_perm[::-1]))
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))
        # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
        # self.br_perm = br_reverse
        # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
        # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
        # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
        # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(
                    self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
Example 7
class TrainableVandermondeReal(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=False,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: a Vandermonde matrix built from a fixed random vector
        n = size
        np.random.seed(0)
        x = np.random.randn(n)
        V = np.vander(x, increasing=True)
        self.target_matrix = torch.tensor(V, dtype=torch.float)
        arange_ = np.arange(size)
        dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
        br_perm = bitreversal_permutation(size)
        assert config['perm'] in ['id', 'br', 'dct']
        if config['perm'] == 'id':
            self.perm = torch.arange(size)
        elif config['perm'] == 'br':
            self.perm = br_perm
        elif config['perm'] == 'dct':
            self.perm = torch.arange(size)[dct_perm][br_perm]
        else:
            assert False, 'Wrong perm in config'

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            if (not self.model.fixed_order) and hasattr(
                    self, 'semantic_loss_weight'):
                semantic_loss = semantic_loss_exactly_one(
                    nn.functional.log_softmax(self.model.logit, dim=-1))
                loss += self.semantic_loss_weight * semantic_loss.mean()
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
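
For concreteness, the two permutations assembled in _setup above look as follows at size 8 (bitreversal_permutation is assumed to return the standard bit-reversal ordering):

import numpy as np

arange_ = np.arange(8)
dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
print(dct_perm)  # [0 2 4 6 7 5 3 1]: even indices in order, then odd indices reversed
# Standard bit-reversal of 0..7 (3-bit indices): [0, 4, 2, 6, 1, 5, 3, 7]
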
Example 8
class TrainableFftFactorFixedOrder(PytorchTrainable):
    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size,
                                      complex=True,
                                      fixed_order=True)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()[:, self.br_perm]
            loss = nn.functional.mse_loss(y, self.target_matrix)
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
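
PytorchTrainable is presumably a thin Ray Tune Trainable subclass defined elsewhere in the repo (supplying _save/_restore). Assuming Ray Tune's legacy class-based API, where the framework calls _setup once and _train once per reported iteration, a launch could look roughly like the sketch below; the config values and stopping criterion are placeholders.

import ray
from ray import tune

ray.init()
tune.run(
    TrainableFftFactorFixedOrder,
    config={
        'seed': 0,
        'size': 8,
        'lr': 1e-3,
        'n_steps_per_epoch': 100,
    },
    # Each reported iteration runs n_steps_per_epoch Adam steps.
    stop={'training_iteration': 50},
)
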
Example 9
class TrainableRandnFactorSoftmaxNoPerm(PytorchTrainable):
    def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size,
                                      complex=False,
                                      fixed_order=False,
                                      softmax_fn='softmax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.rand(size, size, requires_grad=False)

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model.matrix()
            loss = nn.functional.mse_loss(y, self.target_matrix)
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}
Example 10
def hadamard_test():
    # Hadamard matrix for n = 4
    size = 4
    M0 = Butterfly(size,
                   diagonal=2,
                   diag=torch.tensor([1.0, 1.0, -1.0, -1.0],
                                     requires_grad=True),
                   subdiag=torch.ones(2, requires_grad=True),
                   superdiag=torch.ones(2, requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   diag=torch.tensor([1.0, -1.0, 1.0, -1.0],
                                     requires_grad=True),
                   subdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True),
                   superdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True))
    H = M0.matrix() @ M1.matrix()
    assert torch.allclose(H, torch.tensor(hadamard(4), dtype=torch.float))
    M = ButterflyProduct(size, fixed_order=True)
    M.factors[0] = M0
    M.factors[1] = M1
    assert torch.allclose(M.matrix(), H)
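
The factorization that hadamard_test exercises is the standard H_4 = (H_2 kron I_2)(I_2 kron H_2). The dense matrices below are what M0 and M1 are expected to realize given those diag/subdiag/superdiag values (the exact parameter-to-matrix mapping of the Butterfly class is an assumption here); the product can be checked with numpy alone.

import numpy as np
from scipy.linalg import hadamard

B0 = np.array([[1, 0, 1, 0],
               [0, 1, 0, 1],
               [1, 0, -1, 0],
               [0, 1, 0, -1]], dtype=float)  # H_2 kron I_2
B1 = np.array([[1, 1, 0, 0],
               [1, -1, 0, 0],
               [0, 0, 1, 1],
               [0, 0, 1, -1]], dtype=float)  # I_2 kron H_2
assert np.allclose(B0 @ B1, hadamard(4))
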
Example 11
def polish_fft(trial):
    trainable = eval(trial.trainable_name)(trial.config)
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'],
                                      complex=model.complex,
                                      fixed_order=True)
    if not model.fixed_order:
        prob = model.softmax_fn(model.logit)
        maxes, argmaxes = torch.max(prob, dim=-1)
        # print(maxes)
        # if torch.all(maxes >= 0.99):
        polished_model.butterflies = nn.ModuleList(
            [model.butterflies[argmax] for argmax in argmaxes])
        # else:
        #     return -trial.last_result['negative_loss']
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(
            polished_model.matrix()[:, trainable.br_perm],
            trainable.target_matrix)
        loss.backward()
        return loss

    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(
        polished_model.state_dict(),
        str((Path(trial.logdir) / trial._checkpoint.value).parent /
            'polished_model.pth'))
    loss = nn.functional.mse_loss(
        polished_model.matrix()[:, trainable.br_perm], trainable.target_matrix)
    return loss.item()