Example #1
    def compute_diag_fisher(self):
        """
        Arguments: none; uses instance attributes (self.model,
        self.fisher_iterator, self.t, ...).
        Returns: dict mapping parameter names to their diagonal Fisher values.

        This function is used by the method 'update_fisher'.
        """
        # Needs: import torch; from torch.autograd import Variable;
        #        from torch.nn import CrossEntropyLoss

        loglikelihood_grads = {}
        n_batches = 0
        for data, label in self.fisher_iterator:
            data, label = Variable(data), Variable(label)
            # Clear stale gradients so each batch contributes only its own
            # grad**2 (PyTorch accumulates gradients across backward calls).
            self.model.zero_grad()
            loglikelihood = CrossEntropyLoss()(self.model(data)[self.t], label)
            loglikelihood.backward()
            n_batches += 1
            for n, p in self.model.named_parameters():
                n = n.replace('.', '__')
                if p.grad is None:
                    continue
                # Accumulate the squared gradient: the diagonal Fisher estimate.
                loglikelihood_grads[n] = (loglikelihood_grads.get(n, 0) +
                                          p.grad.detach() ** 2)

        # Average over batches to get the empirical diagonal Fisher matrix.
        return {n: g / n_batches for n, g in loglikelihood_grads.items()}
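
The docstring refers to an 'update_fisher' method that consumes this dict, but its body is not shown here. Below is one plausible EWC-style sketch, assuming the class stores the Fisher estimate in self.fisher and the post-task parameter snapshot in self.star_params (both attribute names are hypothetical):

    # Hypothetical sketch of 'update_fisher' (not shown in the source): after
    # finishing a task, snapshot the current parameters and refresh the Fisher
    # estimate so a quadratic EWC penalty can anchor future training.
    def update_fisher(self):
        # Parameter values the penalty will pull future training back towards.
        self.star_params = {n.replace('.', '__'): p.detach().clone()
                            for n, p in self.model.named_parameters()}
        # Recompute the diagonal Fisher on the current task's data.
        self.fisher = self.compute_diag_fisher()
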
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss


def train_step(net, train_dataset_loader, epoch):
    # Put the network in training mode (enables dropout, batch-norm updates).
    net.train()

    # Decay the learning rate geometrically per epoch.
    learning_rate = BASE_LEARNING_RATE * LEARNING_RATE_DECAY_PER_EPOCH ** epoch
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=LEARNING_MOMENTUM)
    criterion = CrossEntropyLoss()

    for i, (inputs, labels) in enumerate(train_dataset_loader):
        inputs, labels = Variable(inputs), Variable(labels)
        optimizer.zero_grad()
        output = net(inputs)  # call the module, not .forward(), to run hooks
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        if i % LOG_INTERVAL == 0:
            print('Train Epoch (LR: {}): {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                  format(learning_rate, epoch, i * len(inputs),
                         len(train_dataset_loader.dataset),
                         100. * i / len(train_dataset_loader), loss.item()))
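
train_step expects the hyperparameter constants to exist at module level. A minimal driver sketch, with placeholder values and a toy dataset standing in for whatever the original script defines:

    # Minimal usage sketch; the constants and the toy dataset below are
    # placeholders, assuming the module-level names train_step references.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    BASE_LEARNING_RATE = 0.1
    LEARNING_RATE_DECAY_PER_EPOCH = 0.9
    LEARNING_MOMENTUM = 0.9
    LOG_INTERVAL = 10

    dataset = TensorDataset(torch.randn(256, 10),
                            torch.randint(0, 2, (256,)))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    net = torch.nn.Linear(10, 2)

    for epoch in range(5):
        train_step(net, loader, epoch)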