import pytest
from torch.nn import Linear, Module

from backpack.extensions import BatchGrad, Variance
from backpack.extensions.firstorder.base import FirstOrderModuleExtension


def test_set_custom_extension():
    """Test the method set_module_extension of BackpropExtension."""

    class _A(Module):
        pass

    class _ABatchGrad(FirstOrderModuleExtension):
        pass

    class _AVariance(FirstOrderModuleExtension):
        pass

    class _MyLinearBatchGrad(FirstOrderModuleExtension):
        pass

    grad_batch = BatchGrad()

    # set a module extension for the custom module
    grad_batch.set_module_extension(_A, _ABatchGrad())

    # setting it again should raise a ValueError
    with pytest.raises(ValueError):
        grad_batch.set_module_extension(_A, _ABatchGrad())

    # setting it again is allowed with overwrite=True
    grad_batch.set_module_extension(_A, _ABatchGrad(), overwrite=True)

    # a different extension can register its own module extension for the same module
    variance = Variance()
    variance.set_module_extension(_A, _AVariance())

    # overwriting a module extension that ships with BackPACK also requires overwrite=True
    with pytest.raises(ValueError):
        grad_batch.set_module_extension(Linear, _MyLinearBatchGrad())

    grad_batch.set_module_extension(Linear, _MyLinearBatchGrad(), overwrite=True)
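# A minimal usage sketch for set_module_extension (illustrative, not part of the
# test): ``DoubleModule`` is a hypothetical parameter-free layer and
# ``DoubleModuleBatchGrad`` an empty module extension registered for it, mirroring
# the calls exercised above. How BackPACK treats unregistered custom layers depends
# on the installed version, so this only shows the registration pattern.

import torch
from torch.nn import Linear, Module, MSELoss, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.extensions.firstorder.base import FirstOrderModuleExtension


class DoubleModule(Module):
    """Hypothetical layer without parameters that doubles its input."""

    def forward(self, input):
        return 2.0 * input


class DoubleModuleBatchGrad(FirstOrderModuleExtension):
    """Module extension for DoubleModule: nothing to compute, it has no parameters."""


model = extend(Sequential(Linear(4, 3), DoubleModule(), Linear(3, 1)))
lossfunc = extend(MSELoss())

extension = BatchGrad()
extension.set_module_extension(DoubleModule, DoubleModuleBatchGrad())

X, y = torch.randn(8, 4), torch.randn(8, 1)
loss = lossfunc(model(X), y)
with backpack(extension):
    loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad_batch.shape)  # leading dimension is the batch size (8)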
import torch

from backpack import backpack
from backpack.extensions import Variance


def total_variance_and_gradient_backpack(x, y, model, lossfunc):
    """Computes the total variance of the individual gradients and its gradient.

    Uses BackPACK's :py:class:`Variance <backpack.extensions.Variance>` extension
    and PyTorch's :py:meth:`backward() <torch.Tensor.backward>` pass with the
    argument ``create_graph=True``.
    """
    model.zero_grad()
    loss = lossfunc(model(x), y)

    with backpack(Variance()):
        loss.backward(retain_graph=True, create_graph=True)

    total_var = 0
    for p in model.parameters():
        total_var += torch.sum(p.variance)

    grad_of_var = torch.autograd.grad(total_var, model.parameters())

    return total_var, grad_of_var
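# A hedged usage sketch for the function above (the toy model, data, and loss are
# illustrative assumptions): note that ``model`` and ``lossfunc`` must be extended
# with backpack.extend, otherwise the parameters will not carry a ``variance``
# attribute after the backward pass.

import torch
from torch.nn import Linear, MSELoss

from backpack import extend

torch.manual_seed(0)
N, D_in, D_out = 16, 10, 3
x, y = torch.randn(N, D_in), torch.randn(N, D_out)

model = extend(Linear(D_in, D_out))
lossfunc = extend(MSELoss())

total_var, grad_of_var = total_variance_and_gradient_backpack(x, y, model, lossfunc)
print("Total variance:         ", total_var.item())
print("Shapes of its gradient: ", [g.shape for g in grad_of_var])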
# Batch gradients

loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)

# %%
# Variance

loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".variance.shape:         ", param.variance.shape)

# %%
# Second moment/sum of gradients squared

loss = lossfunc(model(X), y)
with backpack(SumGradSquared()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
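# %%
# An illustrative consistency check (not part of the original example, and hedged
# on two assumptions: BackPACK's variance is the population variance, without
# Bessel's correction, of the quantities stored in ``grad_batch``, and
# ``sum_grad_squared`` is their summed elementwise square). Since the parameters
# did not change between the backward passes above, the savefields can be compared
# directly.

for name, param in model.named_parameters():
    var_manual = param.grad_batch.var(0, unbiased=False)
    sgs_manual = (param.grad_batch ** 2).sum(0)
    print(name)
    print("max |variance - manual|:        ", (param.variance - var_manual).abs().max().item())
    print("max |sum_grad_squared - manual|:", (param.sum_grad_squared - sgs_manual).abs().max().item())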
# forward pass
batch_loss, _ = tproblem.get_batch_loss_and_accuracy()

# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
    BatchGrad(),
    Variance(),
    SumGradSquared(),
    BatchL2Grad(),
    DiagGGNExact(),
    DiagGGNMC(),
    KFAC(),
    KFLR(),
    KFRA(),
    DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape:             ", param.grad.shape)
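    # A possible continuation of the loop (an illustrative sketch, not from the
    # original script): the savefield names below follow BackPACK's documentation
    # and should be checked against the installed version. Per-sample quantities
    # keep a leading batch dimension, diagonal curvature approximations match the
    # parameter shape, and the Kronecker-factored quantities (KFAC, KFLR, KFRA)
    # are stored as lists of factors.
    print("\t.grad_batch.shape:       ", param.grad_batch.shape)
    print("\t.variance.shape:         ", param.variance.shape)
    print("\t.sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print("\t.batch_l2.shape:         ", param.batch_l2.shape)
    print("\t.diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)
    print("\t.diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print("\t.diag_h.shape:           ", param.diag_h.shape)
    print("\t.kfac (factor shapes):   ", [f.shape for f in param.kfac])
    print("\t.kflr (factor shapes):   ", [f.shape for f in param.kflr])
    print("\t.kfra (factor shapes):   ", [f.shape for f in param.kfra])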