Example #1
    def batch_grad(self):
        with backpack(new_ext.BatchGrad()):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
            batch_grads = [
                p.grad_batch for p in self.problem.model.parameters()
            ]
        return batch_grads
Example #2
    def batch_l2_grad_extension_hook(self):
        """Individual gradient squared ℓ₂ norms via extension hook."""
        hook = ExtensionHookManager(BatchL2GradHook())
        with backpack(new_ext.BatchGrad(), extension_hook=hook):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
            batch_l2_grad = [
                p.batch_l2_hook for p in self.problem.model.parameters()
            ]
        return batch_l2_grad
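For comparison, the squared ℓ₂ norms collected by the hook can also be derived directly from the stacked individual gradients returned by batch_grad in Example #1; a minimal sketch, assuming each entry of the list has shape [N, *param.shape] with the batch axis leading:

def batch_l2_from_grad_batch(batch_grads):
    # squared per-sample gradient norm: flatten everything but the batch axis
    return [g.flatten(start_dim=1).pow(2).sum(dim=1) for g in batch_grads]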
Example #3
    def sgs_extension_hook(self) -> List[Tensor]:
        """Individual gradient second moment via extension hook.

        Returns:
            Parameter-wise individual gradient second moment.
        """
        hook = ExtensionHookManager(SumGradSquaredHook())
        with backpack(new_ext.BatchGrad(), extension_hook=hook):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
        return self.problem.collect_data("sum_grad_squared_hook")
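Assuming the hook mirrors BackPACK's SumGradSquared quantity, the same second moment can be obtained from the individual gradients by squaring element-wise and summing over the batch axis; a sketch:

def sgs_from_grad_batch(batch_grads):
    # element-wise square of each individual gradient, summed over the batch
    return [g.pow(2).sum(dim=0) for g in batch_grads]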
Example #4
    def batch_l2_grad_extension_hook(self) -> List[Tensor]:
        """Individual gradient squared ℓ₂ norms via extension hook.

        Returns:
            Parameter-wise individual gradient norms.
        """
        hook = ExtensionHookManager(BatchL2GradHook())
        with backpack(new_ext.BatchGrad(), extension_hook=hook):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
        return self.problem.collect_data("batch_l2_hook")
Example #5
    def sgs_extension_hook(self):
        """Individual gradient second moment via extension hook."""
        hook = ExtensionHookManager(SumGradSquaredHook())
        with backpack(new_ext.BatchGrad(), extension_hook=hook):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
            sgs = [
                p.sum_grad_squared_hook
                for p in self.problem.model.parameters()
            ]
        return sgs
Example #6
def test_grad():
    """Test computation of bias/weight gradients."""
    for ex in EXAMPLES:
        input, b_grad, w_grad = ex["in"], ex["bias_grad"], ex["weight_grad"]

        loss = loss_function(g_lin(input))
        with backpack(new_ext.BatchGrad()):
            loss.backward()

        assert allclose(g_lin.bias.grad, b_grad)
        assert allclose(g_lin.weight.grad, w_grad)

        del g_lin.bias.grad
        del g_lin.weight.grad
Example #7
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.is_start(global_step) or self.is_end(global_step):
            ext.append(extensions.BatchGrad())

        return ext
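The list returned by extensions() is typically unpacked into the backpack context manager before the backward pass; a minimal sketch, where quantity, global_step, and loss are placeholders:

from backpack import backpack

ext = quantity.extensions(global_step)  # e.g. [extensions.BatchGrad()]
with backpack(*ext):
    loss.backward()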
Example #8
def compare_grads(conv2d, g_conv2d, input):
    """Feed input through nn and exts conv2d, compare bias/weight grad."""
    loss = loss_function(conv2d(input))
    loss.backward()

    loss_g = loss_function(g_conv2d(input))
    with backpack(new_ext.BatchGrad()):
        loss_g.backward()

    assert allclose(g_conv2d.bias.grad, conv2d.bias.grad, atol=TEST_ATOL)
    assert allclose(g_conv2d.weight.grad, conv2d.weight.grad, atol=TEST_ATOL)
    assert allclose(g_conv2d.bias.grad_batch.sum(0), conv2d.bias.grad, atol=TEST_ATOL)
    assert allclose(
        g_conv2d.weight.grad_batch.sum(0), conv2d.weight.grad, atol=TEST_ATOL
    )
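The last two assertions rely on grad_batch stacking one gradient per sample along a leading batch axis, so summing over that axis recovers the ordinary gradient. Inside compare_grads one could additionally check the shapes; a sketch, with N the batch size:

N = input.shape[0]
assert g_conv2d.bias.grad_batch.shape == (N, *g_conv2d.bias.shape)
assert g_conv2d.weight.grad_batch.shape == (N, *g_conv2d.weight.shape)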
Example #9
def test_grad_batch():
    """Test computation of bias/weight batch gradients."""
    for ex in EXAMPLES:
        input, b_grad_batch, w_grad_batch = (
            ex["in"],
            ex["bias_grad_batch"],
            ex["weight_grad_batch"],
        )

        loss = loss_function(g_lin(input))
        with backpack(new_ext.BatchGrad()):
            loss.backward()

        assert allclose(g_lin.bias.grad_batch, b_grad_batch), "{} ≠ {}".format(
            g_lin.bias.grad_batch, b_grad_batch
        )
        assert allclose(g_lin.weight.grad_batch, w_grad_batch), "{} ≠ {}".format(
            g_lin.weight.grad_batch, w_grad_batch
        )

        del g_lin.bias.grad
        del g_lin.weight.grad
Example #10
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        if self.is_active(global_step):
            ext = [BatchGradTransforms_BatchDotGrad()]

            if self._check:
                ext.append(extensions.BatchGrad())

        else:
            ext = []
        return ext
Example #11
def test_same_batch_grad(tproblem_cls, batch_size=3, seed=0):
    """Test that individual gradients from unreduced losses match BackPACK.

    Args:
        tproblem_cls (type): DeepOBS test problem class.
        batch_size (int): Mini-batch size used for the comparison.
        seed (int): Random seed for setting up the problem.
    """
    # via backpack
    tproblem1 = set_up_problem(
        tproblem_cls, batch_size=batch_size, seed=seed, extend=True
    )
    loss1, acc1 = tproblem1.get_batch_loss_and_accuracy()
    with backpack(extensions.BatchGrad()):
        loss1.backward()

    batch_grad1 = [p.grad_batch for p in tproblem1.net.parameters() if p.requires_grad]

    # via autograd
    tproblem2 = set_up_problem(
        tproblem_cls, batch_size=batch_size, seed=seed, extend=True, unreduced_loss=True
    )
    loss2, acc2 = tproblem2.get_batch_loss_and_accuracy()
    loss2_unreduced = loss2._unreduced_loss
    factor = get_reduction_factor(loss2, loss2_unreduced)

    # backpack assumes N individual losses, but MSELoss does not reduce non-batch
    # axes, so we have to do it manually
    if tproblem_cls == quadratic_deep:
        loss2_unreduced = loss2_unreduced.flatten(start_dim=1).sum(1)

    trainable_params = [p for p in tproblem2.net.parameters() if p.requires_grad]
    batch_grad2 = [
        torch.zeros(batch_size, *p.shape, device=p.device) for p in trainable_params
    ]

    for i in range(batch_size):
        retain_graph = i < batch_size - 1
        l_i = loss2_unreduced[i]
        grad = torch.autograd.grad(l_i, trainable_params, retain_graph=retain_graph)
        for param_idx, g in enumerate(grad):
            batch_grad2[param_idx][i] = g * factor

    check_sizes_and_values(batch_grad1, batch_grad2)
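get_reduction_factor is not shown in this excerpt; it presumably returns the scalar that converts each unreduced loss into its contribution to the reduced loss (1/N for mean reduction, 1 for sum). A hypothetical sketch of such a helper:

def get_reduction_factor(loss, unreduced_loss):
    # scalar c with loss == c * unreduced_loss.sum() (hypothetical helper)
    return (loss / unreduced_loss.sum()).item()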
Example #12
def compute_individual_gradients(device, seed=0):
    """Compute individual gradients for the seeded problem specified in ``setup``.

    Args:
        device (torch.device): Device that the computation should be performed on.
        seed (int): Random seed to set before setting up the problem.

    Returns:
        Dictionary with parameter name and individual gradients as key value pairs.
    """
    torch.manual_seed(seed)

    X, y, model, lossfunc = setup(device)

    loss = lossfunc(model(X), y)

    with backpack(extensions.BatchGrad()):
        loss.backward()

    return {name: param.grad_batch for name, param in model.named_parameters()}
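A typical use of this helper is to compare per-sample gradients across devices; a sketch, assuming a CUDA device is available:

grads_cpu = compute_individual_gradients(torch.device("cpu"))
grads_gpu = compute_individual_gradients(torch.device("cuda"))

for name, grad_cpu in grads_cpu.items():
    assert torch.allclose(grad_cpu, grads_gpu[name].cpu(), atol=1e-6)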
Example #13
def test_extension_hook_param_before_savefield_exists(device):
    """Extension hooks iterating over parameters may get called before BackPACK."""
    _, loss = set_up(device)

    params_without_grad_batch = []

    def check_grad_batch(module):
        """Raise ``AssertionError`` if one parameter misses ``'grad_batch'``."""
        for p in module.parameters():
            if not hasattr(p, "grad_batch"):
                params_without_grad_batch.append(id(p))
                raise AssertionError(f"Param {id(p)} has no 'grad_batch' attribute")

    # AssertionError is caught inside BackPACK and will raise a RuntimeError
    with pytest.raises(RuntimeError):
        with backpack(
            extensions.BatchGrad(), extension_hook=check_grad_batch, debug=True
        ):
            loss.backward()

    assert len(params_without_grad_batch) > 0
Example #14
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Raises:
            KeyError: If curvature string has unknown associated extension.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(extensions.BatchGrad())
            try:
                ext.append(self.extensions_from_str[self._curvature]())
            except KeyError as e:
                available = list(self.extensions_from_str.keys())
                raise KeyError(f"Available: {available}") from e

        return ext
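extensions_from_str is not shown in this excerpt; presumably it maps the curvature string to a BackPACK second-order extension class, along these lines (the dictionary keys are assumptions, the extension classes exist in BackPACK):

extensions_from_str = {
    "diag_h": extensions.DiagHessian,
    "diag_ggn_exact": extensions.DiagGGNExact,
    "diag_ggn_mc": extensions.DiagGGNMC,
}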
Example #15
def test_interface_batch_grad():
    interface_test(new_ext.BatchGrad())
Example #16
def test_interface_batch_grad_conv():
    interface_test(new_ext.BatchGrad(), use_conv=True)
Example #17
# --------------------------------

model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# We can now evaluate the loss and do a backward pass with BackPACK
# -----------------------------------------------------------------

loss = lossfunc(model(X), y)

with backpack(
    extensions.BatchGrad(),
    extensions.Variance(),
    extensions.SumGradSquared(),
    extensions.BatchL2Grad(),
    extensions.DiagGGNMC(mc_samples=1),
    extensions.DiagGGNExact(),
    extensions.DiagHessian(),
    extensions.KFAC(mc_samples=1),
    extensions.KFLR(),
    extensions.KFRA(),
):
    loss.backward()

# %%
# And here are the results
# -----------------------------------------------------------------
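# A hedged sketch of reading the computed quantities off the parameters;
# the savefield names below follow BackPACK's documented extensions, and the
# exact shapes depend on the layer and extension.

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)
    print(".variance.shape:         ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape:         ", param.batch_l2.shape)
    print(".diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)
    print(".diag_h.shape:           ", param.diag_h.shape)
    print(".kfac shapes:            ", [k.shape for k in param.kfac])
    print(".kflr shapes:            ", [k.shape for k in param.kflr])
    print(".kfra shapes:            ", [k.shape for k in param.kfra])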
Example #18
                                    batch_size,
                                    seed=0,
                                    extend=use_backpack)

                if add_individual_loss:
                    tp = integrate_individual_loss(tp)

                loss, acc = tp.get_batch_loss_and_accuracy(
                    add_regularization_if_available=False)

                if add_individual_loss:
                    print("Individual loss shape: {}".format(
                        loss._deepobs_unreduced_loss.shape))

                if use_backpack:
                    with backpack(extensions.BatchGrad()):
                        loss.backward()

                losses.append(loss.item())
                accuracies.append(acc)

            same_loss = losses[0] == losses[1]
            same_acc = accuracies[0] == accuracies[1]

            same = same_loss and same_acc
            same_symbol = "✓" if same else "❌"

            print("{} [{}, BackPACK: {}] losses: {}, accuracies: {}".format(
                same_symbol,
                tp_cls.__name__,
                use_backpack,
Example #19
    def batch_grad(self, subsampling) -> List[Tensor]:  # noqa:D102
        with backpack(new_ext.BatchGrad(subsampling=subsampling)):
            _, _, loss = self.problem.forward_pass()
            loss.backward()
        return self.problem.collect_data("grad_batch")
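The subsampling argument (available in recent BackPACK releases) restricts BatchGrad to a subset of samples, so grad_batch only holds gradients for the requested indices; a hedged usage sketch:

# individual gradients only for samples 0 and 2 of the mini-batch
with backpack(new_ext.BatchGrad(subsampling=[0, 2])):
    loss.backward()
# each grad_batch tensor now has a leading dimension of 2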
Example #20
    def batch_gradients(self):
        with backpack(new_ext.BatchGrad()):
            self.loss().backward()
            batch_grads = [p.grad_batch for p in self.model.parameters()]
        return batch_grads
Example #21
def extensions_fn():
    return [
        extensions.BatchGrad(),
    ]
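Such a factory is typically called once per backward pass so that every backpack context receives fresh extension instances; a minimal usage sketch, where model and loss are placeholders:

with backpack(*extensions_fn()):
    loss.backward()

batch_grads = [p.grad_batch for p in model.parameters()]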