Example #1
import torch
from torch.nn import (
    Conv2d,
    CrossEntropyLoss,
    Flatten,
    Linear,
    MaxPool2d,
    ReLU,
    Sequential,
)

from backpack import extend
from backpack.utils.examples import load_one_batch_mnist


def setup(device):
    """Load MNIST batch, create extended CNN and loss function. Load to device.

    Args:
        device (torch.device): Device that all objects are transferred to.

    Returns:
        inputs, labels, model, loss function
    """
    X, y = load_one_batch_mnist(batch_size=64)
    X, y = X.to(device), y.to(device)

    model = extend(
        Sequential(
            Conv2d(1, 128, 3, padding=1),
            ReLU(),
            MaxPool2d(3, stride=2),
            Conv2d(128, 256, 3, padding=1),
            ReLU(),
            MaxPool2d(3, padding=1, stride=2),
            Conv2d(256, 64, 3, padding=1),
            ReLU(),
            MaxPool2d(3, stride=2),
            Conv2d(64, 32, 3, padding=1),
            ReLU(),
            MaxPool2d(3, stride=2),
            Flatten(),
            Linear(32, 10),
        ).to(device)
    )

    lossfunc = extend(CrossEntropyLoss().to(device))

    return X, y, model, lossfunc
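
# %%
# A hypothetical usage sketch (the call site below is an illustration, not
# part of the original example): fetch the objects from ``setup`` and run one
# forward-backward pass.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X, y, model, lossfunc = setup(device)
loss = lossfunc(model(X), y)
loss.backward()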
Example #2

import torch

from backpack import backpack, extend
from backpack.extensions import DiagHessian, HMP
from backpack.utils.examples import load_one_batch_mnist

BATCH_SIZE = 64  # assumed value; the original constant is not shown
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 20, bias=False),
    torch.nn.Sigmoid(),
    torch.nn.Linear(20, 10, bias=False),
).to(DEVICE)
model = extend(model)

loss_function = torch.nn.CrossEntropyLoss().to(DEVICE)
loss_function = extend(loss_function)

# %%
# In the following, we load a batch from MNIST, compute the loss, and trigger the
# backward pass inside a ``with backpack(...)`` block, so that we have access to
# the extensions we are going to use (``DiagHessian`` and ``HMP``).

x, y = load_one_batch_mnist(BATCH_SIZE)
x, y = x.to(DEVICE), y.to(DEVICE)


def forward_backward_with_backpack():
    """Provide working access to BackPACK's `DiagHessian` and `HMP`."""
    loss = loss_function(model(x), y)

    with backpack(DiagHessian(), HMP()):
        # keep graph for autodiff HVPs
        loss.backward(retain_graph=True)

    return loss


loss = forward_backward_with_backpack()
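
# %%
# What the backward pass gives access to (a sketch; the original example's
# continuation is not shown here): each parameter now carries its Hessian
# diagonal in ``diag_h``, and ``hmp`` multiplies the parameter's Hessian block
# onto matrices of shape ``[V, *param.shape]``.

for param in model.parameters():
    print(param.diag_h.shape)

    v = torch.randn(1, *param.shape, device=DEVICE)  # V = 1 random direction
    Hv = param.hmp(v)  # Hessian-matrix product, shape [1, *param.shape]
    print(Hv.shape)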
Example #3
# Let's start by loading some dummy data and extending the model

import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagHessian
from backpack.utils.examples import load_one_batch_mnist

# make deterministic
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data
X, y = load_one_batch_mnist(batch_size=128)
X, y = X.to(device), y.to(device)

# model
model = Sequential(Flatten(), Linear(784, 10)).to(device)
lossfunc = CrossEntropyLoss().to(device)

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# Standard computation of the trace
# ---------------------------------

loss = lossfunc(model(X), y)
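
# %%
# A sketch of the trace computation (assuming the example continues as in
# BackPACK's documentation): a backward pass inside ``with
# backpack(DiagHessian())`` stores the Hessian diagonal in ``diag_h`` on each
# parameter; summing it yields the trace.

with backpack(DiagHessian()):
    loss.backward()

trace = sum(param.diag_h.sum() for param in model.parameters())
print(f"Hessian trace: {trace.item():.5e}")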
Example #4
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import (
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)
from backpack.utils.examples import load_one_batch_mnist

X, y = load_one_batch_mnist(batch_size=512)

model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# First order extensions
# ----------------------

# %%
# Batch gradients

loss = lossfunc(model(X), y)
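
# %%
# A sketch of the usual continuation (assumed; this excerpt stops after the
# forward pass): a backward pass inside ``with backpack(BatchGrad())`` stores
# per-sample gradients in ``grad_batch`` on every parameter.

with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:      ", param.grad.shape)
    print(".grad_batch.shape:", param.grad_batch.shape)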
Example #5

import torch
from torch import nn

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
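
# %%
# ``check_same_results`` is not defined in this excerpt. The sketch below is a
# hypothetical stand-in: it assumes the helper compares BackPACK's
# ``BatchGrad`` individual gradients against per-sample autograd gradients,
# which is consistent with the ``reduction="sum"`` losses used below. The
# original helper may differ.


def check_same_results(x, y, model, lossfunc):
    """Compare BackPACK's individual gradients with autograd (sketch)."""
    model.zero_grad()
    loss = lossfunc(model(x), y)
    with backpack(BatchGrad()):
        loss.backward()

    for n in range(x.shape[0]):
        # per-sample loss and its gradient via plain autograd
        loss_n = lossfunc(model(x[n].unsqueeze(0)), y[n].unsqueeze(0))
        grad_n = torch.autograd.grad(loss_n, model.parameters())
        same = all(
            torch.allclose(param.grad_batch[n], g, atol=1e-6)
            for param, g in zip(model.parameters(), grad_n)
        )
        print(f"Sample {n}: individual gradient matches autograd: {same}")
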
N, D = 3, 2
x = torch.randn(N, D).to(DEVICE)
y = torch.randn(N, 1).to(DEVICE)
model = extend(nn.Sequential(nn.Linear(D, 1, bias=False))).to(DEVICE)
lossfunc = torch.nn.MSELoss(reduction="sum")

check_same_results(x, y, model, lossfunc)

# %%
# We can also try a linear model on MNIST data
#


x, y = load_one_batch_mnist(batch_size=32)
x, y = x.to(DEVICE), y.to(DEVICE)

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
model = extend(model).to(DEVICE)

lossfunc = torch.nn.CrossEntropyLoss(reduction="sum")

check_same_results(x, y, model, lossfunc)

# %%
# And a small CNN for some architecture variety
#


x, y = load_one_batch_mnist(batch_size=32)
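x, y = x.to(DEVICE), y.to(DEVICE)

# %%
# A small CNN like the following would fit here; the exact architecture is an
# assumption, since the original model definition is not shown.

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3, padding=1),  # 28x28 -> 28x28, 8 channels
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),  # 28x28 -> 14x14
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 14 * 14, 10),
)
model = extend(model).to(DEVICE)

lossfunc = torch.nn.CrossEntropyLoss(reduction="sum")

check_same_results(x, y, model, lossfunc)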
Example #6

import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

BATCH_SIZE = 3
torch.manual_seed(0)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_accuracy(output, targets):
    """Helper function to print the accuracy"""
    predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
    return predictions.eq(targets).float().mean().item()


x, y = load_one_batch_mnist(batch_size=BATCH_SIZE)
x, y = x.to(DEVICE), y.to(DEVICE)


# %%
# We can build a ResNet by extending :py:class:`torch.nn.Module`.
# As long as the layers with parameters
# (:py:class:`torch.nn.Conv2d` and :py:class:`torch.nn.Linear`) are
# ``nn`` modules, BackPACK can extend them,
# and this is all that is needed for first order extensions.
# We can rewrite the forward to implement the residual connection,
# and :py:func:`extend() <backpack.extend>` the resulting model.


class MyFirstResNet(torch.nn.Module):
    def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):