Example #1
def batch_l2_grad(self):
    # BackPACK's BatchL2Grad stores each sample's squared gradient L2 norm
    # in p.batch_l2; collect it for every parameter after the backward pass.
    with backpack(new_ext.BatchL2Grad()):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
        batch_l2_grad = [
            p.batch_l2 for p in self.problem.model.parameters()
        ]
    return batch_l2_grad
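
As a cross-check of what batch_l2 contains: per BackPACK's documentation, BatchL2Grad stores the squared L2 norm of each individual gradient, so it should match a manual reduction of the per-sample gradients that the BatchGrad extension stores in p.grad_batch. A minimal standalone sketch, assuming those two attributes:

import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend, extensions
from backpack.utils.examples import load_mnist_data

X, y = load_mnist_data(4)
model = extend(Sequential(Flatten(), Linear(784, 10)))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(X), y)
with backpack(extensions.BatchGrad(), extensions.BatchL2Grad()):
    loss.backward()

for p in model.parameters():
    # p.grad_batch has shape [B, *p.shape]; reduce to [B] squared norms.
    manual = p.grad_batch.flatten(start_dim=1).pow(2).sum(dim=1)
    assert torch.allclose(p.batch_l2, manual)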
Example #2
# Fragment: assumes model, X, and y were defined earlier in the script.
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# We can now evaluate the loss and do a backward pass with BackPACK
# -----------------------------------------------------------------

loss = lossfunc(model(X), y)

with backpack(
    extensions.BatchGrad(),
    extensions.Variance(),
    extensions.SumGradSquared(),
    extensions.BatchL2Grad(),
    extensions.DiagGGNMC(mc_samples=1),
    extensions.DiagGGNExact(),
    extensions.DiagHessian(),
    extensions.KFAC(mc_samples=1),
    extensions.KFLR(),
    extensions.KFRA(),
):
    loss.backward()

# %%
# And here are the results
# -----------------------------------------------------------------

for name, param in model.named_parameters():
    print(name)
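
The loop above prints only the parameter names. A hedged sketch of how to inspect the computed quantities themselves, using the attribute names from BackPACK's documentation (grad_batch, variance, sum_grad_squared, batch_l2, diag_ggn_mc, diag_ggn_exact, diag_h, and the Kronecker-factor lists kfac, kflr, kfra):

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)        # BatchGrad
    print(".variance.shape:         ", param.variance.shape)          # Variance
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)  # SumGradSquared
    print(".batch_l2.shape:         ", param.batch_l2.shape)          # BatchL2Grad
    print(".diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)       # DiagGGNMC
    print(".diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)    # DiagGGNExact
    print(".diag_h.shape:           ", param.diag_h.shape)            # DiagHessian
    print(".kfac (shapes):          ", [k.shape for k in param.kfac]) # KFAC
    print(".kflr (shapes):          ", [k.shape for k in param.kflr]) # KFLR
    print(".kfra (shapes):          ", [k.shape for k in param.kfra]) # KFRA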
Example #3
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend, extensions
from backpack.utils.examples import load_mnist_data

B = 4
X, y = load_mnist_data(B)

print(
    "# Gradient with PyTorch, individual gradients' L2 norms with BackPACK | B =",
    B)

model = Sequential(
    Flatten(),
    Linear(784, 10),
)
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

loss = lossfunc(model(X), y)

with backpack(extensions.BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".batch_l2.shape:         ", param.batch_l2.shape)
Example #4
def batch_l2(self):
    with backpack(new_ext.BatchL2Grad()):
        self.loss().backward()
        batch_l2s = [p.batch_l2 for p in self.model.parameters()]
    return batch_l2s
Example #5
def batch_l2_grad(self) -> List[Tensor]:  # noqa:D102
    with backpack(new_ext.BatchL2Grad()):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
    return self.problem.collect_data("batch_l2")