def diag_ggn_mc(self, mc_samples):
    """Compute the MC-sampled GGN diagonal for every parameter of the model.

    Runs ``self.problem.forward_pass()`` and backpropagates the resulting loss
    inside BackPACK's ``DiagGGNMC`` context. It then reads the ``diag_ggn_mc``
    attribute from each parameter of ``self.problem.model``.

    Args:
        mc_samples: Number of Monte-Carlo samples handed to ``DiagGGNMC``.

    Returns:
        list: The ``diag_ggn_mc`` value of each parameter, in the order
        yielded by ``self.problem.model.parameters()``.
    """
    extension = new_ext.DiagGGNMC(mc_samples=mc_samples)
    with backpack(extension):
        # forward_pass() returns a 3-tuple; only its last entry (the loss)
        # is needed here.
        _, _, loss = self.problem.forward_pass()
        loss.backward()

    parameters = self.problem.model.parameters()
    return [param.diag_ggn_mc for param in parameters]
# Let BackPACK hook into the model and the loss function.
model = extend(model)
lossfunc = extend(lossfunc)

# %%
# We can now evaluate the loss and do a backward pass with BackPACK
# -----------------------------------------------------------------

loss = lossfunc(model(X), y)

# Each extension listed here attaches its own result to the parameters
# while ``loss.backward()`` runs inside the context manager.
requested_extensions = (
    extensions.BatchGrad(),
    extensions.Variance(),
    extensions.SumGradSquared(),
    extensions.BatchL2Grad(),
    extensions.DiagGGNMC(mc_samples=1),
    extensions.DiagGGNExact(),
    extensions.DiagHessian(),
    extensions.KFAC(mc_samples=1),
    extensions.KFLR(),
    extensions.KFRA(),
)
with backpack(*requested_extensions):
    loss.backward()

# %%
# And here are the results
# -----------------------------------------------------------------

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
""" Compute the gradient with PyTorch and the MC-sampled GGN diagonal with BackPACK. """ from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential from backpack import backpack, extend, extensions from backpack.utils.examples import load_mnist_data B = 4 X, y = load_mnist_data(B) print("# Gradient with PyTorch, MC-sampled GGN diagonal with BackPACK | B =", B) model = Sequential(Flatten(), Linear(784, 10),) lossfunc = CrossEntropyLoss() model = extend(model) lossfunc = extend(lossfunc) loss = lossfunc(model(X), y) # number of MC samples is optional, defaults to 1 with backpack(extensions.DiagGGNMC(mc_samples=1)): loss.backward() for name, param in model.named_parameters(): print(name) print(".grad.shape: ", param.grad.shape) print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
def diag_ggn_mc(self, mc_samples=None):
    """Compute the MC-sampled GGN diagonal for every parameter of the model.

    Backpropagates ``self.loss()`` inside BackPACK's ``DiagGGNMC`` context,
    then collects the ``diag_ggn_mc`` attribute of each parameter of
    ``self.model``.

    Args:
        mc_samples: Optional number of Monte-Carlo samples.
            - ``None`` (the default) keeps the previous behaviour: the
              extension is built as ``DiagGGNMC()`` with its own default.
            - Any other value is forwarded as ``mc_samples=``.
            This matches the signature of the other ``diag_ggn_mc``
            implementations.

    Returns:
        list: The ``diag_ggn_mc`` value of each parameter, in the order
        yielded by ``self.model.parameters()``.
    """
    # Forward ``mc_samples`` only when it is given, so existing zero-argument
    # calls construct exactly the same extension as before.
    if mc_samples is None:
        extension = new_ext.DiagGGNMC()
    else:
        extension = new_ext.DiagGGNMC(mc_samples=mc_samples)

    with backpack(extension):
        self.loss().backward()
    return [p.diag_ggn_mc for p in self.model.parameters()]
def diag_ggn_mc(self, mc_samples) -> List[Tensor]:
    """Return the MC-sampled GGN diagonal collected from the test problem.

    Runs ``self.problem.forward_pass()`` and backpropagates its loss inside
    BackPACK's ``DiagGGNMC`` context. It then gathers the ``"diag_ggn_mc"``
    quantity via ``self.problem.collect_data``.

    Args:
        mc_samples: Number of Monte-Carlo samples handed to ``DiagGGNMC``.

    Returns:
        Whatever ``self.problem.collect_data("diag_ggn_mc")`` yields,
        annotated as a list of tensors.
    """
    ggn_mc_extension = new_ext.DiagGGNMC(mc_samples=mc_samples)
    with backpack(ggn_mc_extension):
        # Only the loss (last element of the 3-tuple) is needed here.
        _, _, loss = self.problem.forward_pass()
        loss.backward()
    return self.problem.collect_data("diag_ggn_mc")