def batch_grad(self):
    with backpack(new_ext.BatchGrad()):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
        batch_grads = [
            p.grad_batch for p in self.problem.model.parameters()
        ]
    return batch_grads

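For reference, a minimal standalone sketch of the `grad_batch` convention the snippets above rely on: individual gradients sum to the gradient of the reduced loss. The toy model and data here are made up; only the `backpack` imports and the `grad_batch` savefield are the library's documented API.

import torch
from torch.nn import Linear, MSELoss
from backpack import backpack, extend, extensions

# Hypothetical toy setup, not from the test suite above.
N, D = 4, 3
model = extend(Linear(D, 2))
lossfunc = extend(MSELoss())

X, y = torch.rand(N, D), torch.rand(N, 2)
loss = lossfunc(model(X), y)

with backpack(extensions.BatchGrad()):
    loss.backward()

for p in model.parameters():
    # grad_batch has shape [N, *p.shape]; summing over the batch axis
    # reproduces the aggregated gradient (BackPACK absorbs the reduction).
    assert torch.allclose(p.grad_batch.sum(0), p.grad, atol=1e-6)
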
def batch_l2_grad_extension_hook(self):
    """Individual gradient squared ℓ₂ norms via extension hook."""
    hook = ExtensionHookManager(BatchL2GradHook())
    with backpack(new_ext.BatchGrad(), extension_hook=hook):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
        batch_l2_grad = [
            p.batch_l2_hook for p in self.problem.model.parameters()
        ]
    return batch_l2_grad

def sgs_extension_hook(self) -> List[Tensor]:
    """Individual gradient second moment via extension hook.

    Returns:
        Parameter-wise individual gradient second moment.
    """
    hook = ExtensionHookManager(SumGradSquaredHook())
    with backpack(new_ext.BatchGrad(), extension_hook=hook):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
    return self.problem.collect_data("sum_grad_squared_hook")

def batch_l2_grad_extension_hook(self) -> List[Tensor]:
    """Individual gradient squared ℓ₂ norms via extension hook.

    Returns:
        Parameter-wise individual gradient norms.
    """
    hook = ExtensionHookManager(BatchL2GradHook())
    with backpack(new_ext.BatchGrad(), extension_hook=hook):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
    return self.problem.collect_data("batch_l2_hook")

def sgs_extension_hook(self):
    """Individual gradient second moment via extension hook."""
    hook = ExtensionHookManager(SumGradSquaredHook())
    with backpack(new_ext.BatchGrad(), extension_hook=hook):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
        sgs = [
            p.sum_grad_squared_hook for p in self.problem.model.parameters()
        ]
    return sgs

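The hook-based variants above (`ExtensionHookManager`, `SumGradSquaredHook`, `BatchL2GradHook`) are helpers local to this suite. The quantity itself can be checked against BackPACK's built-in `SumGradSquared` extension; a minimal sketch with a made-up toy problem, verifying that the second moment is the batch sum of squared individual gradients:

import torch
from torch.nn import Linear, CrossEntropyLoss
from backpack import backpack, extend, extensions

# Hypothetical toy setup, not the suite's self.problem.
N, D, C = 5, 8, 3
model = extend(Linear(D, C))
lossfunc = extend(CrossEntropyLoss())

X = torch.rand(N, D)
y = torch.randint(0, C, (N,))

loss = lossfunc(model(X), y)
with backpack(extensions.BatchGrad(), extensions.SumGradSquared()):
    loss.backward()

for p in model.parameters():
    # sum_grad_squared is the element-wise batch sum of squared
    # individual gradients.
    assert torch.allclose(p.sum_grad_squared, (p.grad_batch ** 2).sum(0), atol=1e-7)
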
def test_grad():
    """Test computation of bias/weight gradients."""
    for ex in EXAMPLES:
        input, b_grad, w_grad = ex["in"], ex["bias_grad"], ex["weight_grad"]

        loss = loss_function(g_lin(input))
        with backpack(new_ext.BatchGrad()):
            loss.backward()

        assert allclose(g_lin.bias.grad, b_grad)
        assert allclose(g_lin.weight.grad, w_grad)

        del g_lin.bias.grad
        del g_lin.weight.grad

def extensions(self, global_step):
    """Return list of BackPACK extensions required for the computation.

    Args:
        global_step (int): The current iteration number.

    Returns:
        list: (Potentially empty) list with required BackPACK quantities.
    """
    ext = []
    if self.is_start(global_step) or self.is_end(global_step):
        ext.append(extensions.BatchGrad())
    return ext

def compare_grads(conv2d, g_conv2d, input):
    """Feed input through nn and exts conv2d, compare bias/weight grad."""
    loss = loss_function(conv2d(input))
    loss.backward()

    loss_g = loss_function(g_conv2d(input))
    with backpack(new_ext.BatchGrad()):
        loss_g.backward()

    assert allclose(g_conv2d.bias.grad, conv2d.bias.grad, atol=TEST_ATOL)
    assert allclose(g_conv2d.weight.grad, conv2d.weight.grad, atol=TEST_ATOL)
    assert allclose(g_conv2d.bias.grad_batch.sum(0), conv2d.bias.grad, atol=TEST_ATOL)
    assert allclose(
        g_conv2d.weight.grad_batch.sum(0), conv2d.weight.grad, atol=TEST_ATOL
    )

def test_grad_batch():
    """Test computation of bias/weight batch gradients."""
    for ex in EXAMPLES:
        input, b_grad_batch, w_grad_batch = (
            ex["in"],
            ex["bias_grad_batch"],
            ex["weight_grad_batch"],
        )

        loss = loss_function(g_lin(input))
        with backpack(new_ext.BatchGrad()):
            loss.backward()

        assert allclose(g_lin.bias.grad_batch, b_grad_batch), "{} ≠ {}".format(
            g_lin.bias.grad_batch, b_grad_batch
        )
        assert allclose(g_lin.weight.grad_batch, w_grad_batch), "{} ≠ {}".format(
            g_lin.weight.grad_batch, w_grad_batch
        )

        del g_lin.bias.grad
        del g_lin.weight.grad

def extensions(self, global_step):
    """Return list of BackPACK extensions required for the computation.

    Args:
        global_step (int): The current iteration number.

    Returns:
        list: (Potentially empty) list with required BackPACK quantities.
    """
    if self.is_active(global_step):
        ext = [BatchGradTransforms_BatchDotGrad()]
        if self._check:
            ext.append(extensions.BatchGrad())
    else:
        ext = []
    return ext

def test_same_batch_grad(tproblem_cls, batch_size=3, seed=0):
    """Test individual gradients from unreduced losses match with BackPACK.

    Args:
        tproblem_cls (type): DeepOBS test problem class.
        batch_size (int): Number of samples in the minibatch.
        seed (int): Random seed used to set up the problem.
    """
    # via backpack
    tproblem1 = set_up_problem(
        tproblem_cls, batch_size=batch_size, seed=seed, extend=True
    )
    loss1, acc1 = tproblem1.get_batch_loss_and_accuracy()
    with backpack(extensions.BatchGrad()):
        loss1.backward()
    batch_grad1 = [
        p.grad_batch for p in tproblem1.net.parameters() if p.requires_grad
    ]

    # via autograd
    tproblem2 = set_up_problem(
        tproblem_cls, batch_size=batch_size, seed=seed, extend=True, unreduced_loss=True
    )
    loss2, acc2 = tproblem2.get_batch_loss_and_accuracy()
    loss2_unreduced = loss2._unreduced_loss
    factor = get_reduction_factor(loss2, loss2_unreduced)

    # BackPACK assumes N individual losses, but MSELoss does not reduce
    # non-batch axes, so we have to do it manually
    if tproblem_cls == quadratic_deep:
        loss2_unreduced = loss2_unreduced.flatten(start_dim=1).sum(1)

    trainable_params = [p for p in tproblem2.net.parameters() if p.requires_grad]
    batch_grad2 = [
        torch.zeros(batch_size, *p.shape, device=p.device) for p in trainable_params
    ]

    for i in range(batch_size):
        retain_graph = i < batch_size - 1
        l_i = loss2_unreduced[i]
        grad = torch.autograd.grad(l_i, trainable_params, retain_graph=retain_graph)
        for param_idx, g in enumerate(grad):
            batch_grad2[param_idx][i] = g * factor

    check_sizes_and_values(batch_grad1, batch_grad2)

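The helper `get_reduction_factor` is not shown in this section. A minimal sketch of what it plausibly does, assuming it detects whether the loss uses "mean" or "sum" reduction by comparing the reduced value against the unreduced losses (names and logic here are an assumption, not the suite's implementation):

import torch

def get_reduction_factor(loss, unreduced_loss):
    # Hypothetical sketch: individual-loss gradients must be scaled by
    # 1 / numel for "mean" reduction (and by 1.0 for "sum") so that they
    # sum to the gradient of the reduced loss, matching grad_batch.
    if torch.allclose(unreduced_loss.flatten().mean(), loss):
        return 1.0 / unreduced_loss.numel()
    elif torch.allclose(unreduced_loss.flatten().sum(), loss):
        return 1.0
    else:
        raise RuntimeError("Could not detect reduction factor")
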
def compute_individual_gradients(device, seed=0):
    """Compute individual gradients for the seeded problem specified in ``setup``.

    Args:
        device (torch.device): Device that the computation should be performed on.
        seed (int): Random seed to set before setting up the problem.

    Returns:
        Dictionary with parameter name and individual gradients as key value pairs.
    """
    torch.manual_seed(seed)
    X, y, model, lossfunc = setup(device)

    loss = lossfunc(model(X), y)
    with backpack(extensions.BatchGrad()):
        loss.backward()

    return {name: param.grad_batch for name, param in model.named_parameters()}

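A plausible usage of this helper is to compare individual gradients across devices; a hedged sketch (assumes a CUDA device is available and that `setup` builds the same problem on both):

import torch

grads_cpu = compute_individual_gradients(torch.device("cpu"))
if torch.cuda.is_available():
    grads_gpu = compute_individual_gradients(torch.device("cuda"))
    for name in grads_cpu:
        # Same seed, same problem: individual gradients should agree.
        assert torch.allclose(grads_cpu[name], grads_gpu[name].cpu(), atol=1e-5)
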
def test_extension_hook_param_before_savefield_exists(device):
    """Extension hooks iterating over parameters may get called before BackPACK."""
    _, loss = set_up(device)

    params_without_grad_batch = []

    def check_grad_batch(module):
        """Raise ``AssertionError`` if one parameter misses ``'grad_batch'``."""
        for p in module.parameters():
            if not hasattr(p, "grad_batch"):
                params_without_grad_batch.append(id(p))
                raise AssertionError(f"Param {id(p)} has no 'grad_batch' attribute")

    # AssertionError is caught inside BackPACK and will raise a RuntimeError
    with pytest.raises(RuntimeError):
        with backpack(
            extensions.BatchGrad(), extension_hook=check_grad_batch, debug=True
        ):
            loss.backward()

    assert len(params_without_grad_batch) > 0

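The hook above fails deliberately because it touches all model parameters, including those of modules BackPACK has not yet processed during backpropagation. A sketch of a hook that should be safe, on the assumption that it only inspects the direct parameters of the module it is called on (toy usage, not from this test suite):

def inspect_grad_batch(module):
    # BackPACK calls the extension hook on a module after running its
    # extensions there, so the module's direct parameters already carry
    # the savefield at this point.
    for p in module.parameters(recurse=False):
        if p.requires_grad:
            print(tuple(p.shape), tuple(p.grad_batch.shape))

with backpack(extensions.BatchGrad(), extension_hook=inspect_grad_batch):
    loss.backward()
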
def extensions(self, global_step):
    """Return list of BackPACK extensions required for the computation.

    Args:
        global_step (int): The current iteration number.

    Raises:
        KeyError: If curvature string has unknown associated extension.

    Returns:
        list: (Potentially empty) list with required BackPACK quantities.
    """
    ext = []
    if self.should_compute(global_step):
        ext.append(extensions.BatchGrad())
        try:
            ext.append(self.extensions_from_str[self._curvature]())
        except KeyError as e:
            available = list(self.extensions_from_str.keys())
            raise KeyError(
                f"Unknown curvature '{self._curvature}'. Available: {available}"
            ) from e
    return ext

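The `extensions_from_str` attribute is defined elsewhere in the class; a hypothetical sketch of what such a mapping might contain, using only documented BackPACK extension classes:

# Hypothetical curvature-string-to-extension mapping (an assumption,
# not the class's actual attribute).
extensions_from_str = {
    "diag_h": extensions.DiagHessian,
    "diag_ggn_exact": extensions.DiagGGNExact,
    "diag_ggn_mc": extensions.DiagGGNMC,
}
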
def test_interface_batch_grad():
    interface_test(new_ext.BatchGrad())

def test_interface_batch_grad_conv():
    interface_test(new_ext.BatchGrad(), use_conv=True)

# --------------------------------

model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# We can now evaluate the loss and do a backward pass with BackPACK
# -----------------------------------------------------------------

loss = lossfunc(model(X), y)

with backpack(
    extensions.BatchGrad(),
    extensions.Variance(),
    extensions.SumGradSquared(),
    extensions.BatchL2Grad(),
    extensions.DiagGGNMC(mc_samples=1),
    extensions.DiagGGNExact(),
    extensions.DiagHessian(),
    extensions.KFAC(mc_samples=1),
    extensions.KFLR(),
    extensions.KFRA(),
):
    loss.backward()

# %%
# And here are the results
# -----------------------------------------------------------------
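The results section is cut off here; a sketch of how the computed quantities can be accessed afterwards (the savefield names below are BackPACK's documented attributes for the extensions used above):

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)
    print(".variance.shape:         ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape:         ", param.batch_l2.shape)
    print(".diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)
    print(".diag_h.shape:           ", param.diag_h.shape)
    # The Kronecker factorizations are stored as lists of tensors.
    print(".kfac (shapes):          ", [k.shape for k in param.kfac])
    print(".kflr (shapes):          ", [k.shape for k in param.kflr])
    print(".kfra (shapes):          ", [k.shape for k in param.kfra])
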
        batch_size, seed=0, extend=use_backpack)
    if add_individual_loss:
        tp = integrate_individual_loss(tp)

    loss, acc = tp.get_batch_loss_and_accuracy(
        add_regularization_if_available=False)

    if add_individual_loss:
        print("Individual loss shape: {}".format(
            loss._deepobs_unreduced_loss.shape))

    if use_backpack:
        with backpack(extensions.BatchGrad()):
            loss.backward()

    losses.append(loss.item())
    accuracies.append(acc)

same_loss = losses[0] == losses[1]
same_acc = accuracies[0] == accuracies[1]
same = same_loss and same_acc
same_symbol = "✓" if same else "❌"

print("{} [{}, BackPACK: {}] losses: {}, accuracies: {}".format(
    same_symbol, tp_cls.__name__, use_backpack,
def batch_grad(self, subsampling) -> List[Tensor]:  # noqa: D102
    with backpack(new_ext.BatchGrad(subsampling=subsampling)):
        _, _, loss = self.problem.forward_pass()
        loss.backward()
    return self.problem.collect_data("grad_batch")

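`BatchGrad` accepts a `subsampling` argument (a list of sample indices), as the method above uses; a minimal standalone sketch with a made-up toy model:

import torch
from torch.nn import Linear, MSELoss
from backpack import backpack, extend, extensions

# Hypothetical toy setup.
N, D = 6, 4
model = extend(Linear(D, 1))
lossfunc = extend(MSELoss())
X, y = torch.rand(N, D), torch.rand(N, 1)

loss = lossfunc(model(X), y)
# Restrict individual gradients to samples 0 and 2.
with backpack(extensions.BatchGrad(subsampling=[0, 2])):
    loss.backward()

for p in model.parameters():
    # The leading axis now has length 2, one slice per requested sample.
    assert p.grad_batch.shape == (2, *p.shape)
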
def batch_gradients(self):
    with backpack(new_ext.BatchGrad()):
        self.loss().backward()
    batch_grads = [p.grad_batch for p in self.model.parameters()]
    return batch_grads

def extensions_fn():
    return [
        extensions.BatchGrad(),
    ]

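Such a factory is presumably consumed by unpacking its result into the `backpack` context; a hedged usage sketch (assumes a `loss` from an extended model, as in the snippets above):

# Hypothetical usage of the factory: unpack the extension list into
# the backpack context manager.
with backpack(*extensions_fn()):
    loss.backward()
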