def test_set_custom_extension():
    """Exercise ``set_module_extension`` on BackPACK extensions.

    The scenarios checked, in order:

    * registering a module extension for a module type the first time,
    * registering for the same module type again, which must raise
      ``ValueError``,
    * registering again with ``overwrite=True``, which must succeed,
    * registering for the same module type on a different extension,
    * registering for ``Linear`` on ``BatchGrad``, which must raise
      ``ValueError`` unless ``overwrite=True`` is passed.
    """

    # Minimal stand-ins: one module type and one module extension per case.
    class _A(Module):
        pass

    class _ABatchGrad(FirstOrderModuleExtension):
        pass

    class _AVariance(FirstOrderModuleExtension):
        pass

    class _MyLinearBatchGrad(FirstOrderModuleExtension):
        pass

    batch_grad_ext = BatchGrad()

    # First registration for ``_A`` is accepted.
    batch_grad_ext.set_module_extension(_A, _ABatchGrad())

    # A second registration for ``_A`` without ``overwrite`` is rejected.
    with pytest.raises(ValueError):
        batch_grad_ext.set_module_extension(_A, _ABatchGrad())

    # ... but is accepted when overwriting is requested explicitly.
    batch_grad_ext.set_module_extension(_A, _ABatchGrad(), overwrite=True)

    # A separate extension object accepts its own registration for ``_A``.
    variance_ext = Variance()
    variance_ext.set_module_extension(_A, _AVariance())

    # ``Linear`` is expected to be taken already on ``BatchGrad``: plain
    # registration is rejected, overwriting is accepted.
    with pytest.raises(ValueError):
        batch_grad_ext.set_module_extension(Linear, _MyLinearBatchGrad())

    batch_grad_ext.set_module_extension(
        Linear, _MyLinearBatchGrad(), overwrite=True
    )
def total_variance_and_gradient_backpack(x, y, model, lossfunc):
    """Return the total variance of the individual gradients and its gradient.

    The variance is obtained with BackPACK's
    :py:meth:`Variance <backpack.extensions.Variance>` extension during a
    :py:meth:`backward() <torch.Tensor.backward>` pass that is run with
    ``create_graph=True``, and is then differentiated with
    :py:func:`torch.autograd.grad`.

    Args:
        x: Input handed to ``model``.
        y: Target handed to ``lossfunc`` together with ``model(x)``.
        model: Object providing ``zero_grad()`` and ``parameters()``; called
            as ``model(x)``.
        lossfunc: Callable evaluated as ``lossfunc(model(x), y)``.

    Returns:
        A tuple ``(total_var, grad_of_var)``. ``total_var`` is the sum of
        ``torch.sum(p.variance)`` over ``model.parameters()``;
        ``grad_of_var`` is what ``torch.autograd.grad`` returns for
        ``total_var`` with respect to ``model.parameters()``.
    """
    model.zero_grad()

    prediction = model(x)
    loss = lossfunc(prediction, y)

    with backpack(Variance()):
        # The graph of the backward pass is kept (create_graph=True) so the
        # ``.variance`` attributes read below can be differentiated again.
        loss.backward(retain_graph=True, create_graph=True)

    # Built-in ``sum`` starts from the int 0 and adds the per-parameter
    # totals in ``model.parameters()`` order.
    total_var = sum(torch.sum(p.variance) for p in model.parameters())

    grad_of_var = torch.autograd.grad(total_var, model.parameters())
    return total_var, grad_of_var
Esempio n. 3
0
# Batch gradients

# NOTE(review): `model`, `X`, `y` and `lossfunc` are not defined in this
# excerpt; they are assumed to be set up earlier in the script.
# The loss is recomputed before every backward pass in this script.
loss = lossfunc(model(X), y)
# Run the backward pass inside the `backpack(BatchGrad())` context; the loop
# below then reads `.grad_batch` from every parameter.
with backpack(BatchGrad()):
    loss.backward()

# Print, per named parameter, the shape of `.grad` next to the shape of the
# attribute read for this extension.
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)

# %%
# Variance

# Same pattern as above, with the `Variance` extension; the loop below reads
# `.variance` from every parameter.
loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".variance.shape:         ", param.variance.shape)

# %%
# Second moment/sum of gradients squared

# Same pattern again, with the `SumGradSquared` extension.
loss = lossfunc(model(X), y)
with backpack(SumGradSquared()):
    loss.backward()

for name, param in model.named_parameters():
Esempio n. 4
0
# forward pass
# NOTE(review): `tproblem` is not defined in this excerpt; it is assumed to
# be created earlier in the script. The second return value is discarded.
batch_loss, _ = tproblem.get_batch_loss_and_accuracy()

# individual loss
# The per-sample losses are read from an attribute on the returned loss,
# looked up by name via `getattr`.
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

# Show the individual losses' shape, the mini-batch loss, and the mean of
# the individual losses side by side.
print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
# All listed extensions are passed to a single `backpack(...)` context, so
# one call to `backward()` runs with every one of them active.
with backpack(
        BatchGrad(),
        Variance(),
        SumGradSquared(),
        BatchL2Grad(),
        DiagGGNExact(),
        DiagGGNMC(),
        KFAC(),
        KFLR(),
        KFRA(),
        DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape:             ", param.grad.shape)