Example 1
def test_network_diag_ggn(model_and_input):
    """Test whether the given module can compute diag_ggn.

    This test is placed here because some models are too large for a full
    comparison of the GGN diagonal with plain PyTorch autograd.
    Instead, it checks that the model runs on BackPACK without errors, that the
    forward pass of the extended model is identical to the original model, and
    that a small number of DiagGGN elements match the autograd reference.

    Args:
        model_and_input: tuple of module, input, and loss function to test

    Raises:
        NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss
    """
    model_original, x, loss_fn = model_and_input
    model_original = model_original.eval()
    output_compare = model_original(x)
    if isinstance(loss_fn, MSELoss):
        y = regression_targets(output_compare.shape)
    elif isinstance(loss_fn, CrossEntropyLoss):
        y = classification_targets(
            (output_compare.shape[0], *output_compare.shape[2:]),
            output_compare.shape[1],
        )
    else:
        raise NotImplementedError(
            f"test cannot handle loss_fn = {type(loss_fn)}")

    num_params = sum(p.numel() for p in model_original.parameters()
                     if p.requires_grad)
    num_to_compare = 10
    idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32)
    diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
        x, y, model_original, loss_fn, idx=idx_to_compare
    )

    model_extended = extend(model_original, use_converter=True, debug=True)
    output = model_extended(x)

    assert allclose(output, output_compare)

    loss = extend(loss_fn)(output, y)

    with backpack(DiagGGNExact()):
        loss.backward()

    diag_ggn_exact_vector = cat([
        p.diag_ggn_exact.flatten() for p in model_extended.parameters()
        if p.requires_grad
    ])

    for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
        assert allclose(element, diag_ggn_exact_vector[idx], atol=1e-5)
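
A hedged sketch of the imports this test relies on; the torch and BackPACK imports are standard, while the helper paths (regression_targets, classification_targets, autograd_diag_ggn_exact, and the model_and_input fixture) depend on the test suite layout and are assumptions:

from torch import allclose, cat, int32, linspace
from torch.nn import CrossEntropyLoss, MSELoss

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

# Helper imports below are assumptions; adjust to the actual test suite layout.
# from test.core.derivatives.utils import classification_targets, regression_targets
# from test.utils import autograd_diag_ggn_exact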
Example 2
def test_graph_clear(problem) -> None:
    """Test that the graph is clear after a backward pass.

    More specifically, test that there are no saved quantities left over.

    Args:
        problem: problem consisting of inputs and model
    """
    inputs, model = problem
    extension = DiagGGNExact()
    outputs = extend(model)(inputs)
    loss = extend(MSELoss())(outputs, rand_like(outputs))
    with backpack(extension):
        loss.backward()

    # test that the dictionary is empty
    saved_quantities: dict = extension.saved_quantities._saved_quantities
    assert type(saved_quantities) is dict
    assert not saved_quantities
Example 3
extend(net_DiagGGNE)
extend(loss_function)
opt_DiagGGNE = DiagGGNEOptimizer(net_DiagGGNE.parameters(), step_size=STEP_SIZE, damping=DAMPING)
GGNE_l_his = []
GGNE_a_his = []

start = time.time()

for epoch in range(EPOCH):
  print('Epoch:', epoch)
  for step, (b_x, b_y) in enumerate(trainloader):
    b_x, b_y = b_x.to(DEVICE), b_y.to(DEVICE)
    output = net_DiagGGNE(b_x)              
    loss = loss_function(output, b_y) 

    with backpack(DiagGGNExact()):
      loss.backward()

    opt_DiagGGNE.step()   

    GGNE_a_his.append(get_accuracy(output, b_y))
    GGNE_l_his.append(loss.item())

end = time.time()
GGNE_t = end - start

fig = plt.figure()
axes = [fig.add_subplot(1, 2, 1), fig.add_subplot(1, 2, 2)]

axes[0].plot(GGNE_l_his)
axes[0].set_title("Loss")
axes[1].plot(GGNE_a_his)
axes[1].set_title("Accuracy")
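
The DiagGGNEOptimizer used above is not shown in this snippet. Below is a minimal sketch of such an optimizer, under the assumption that it takes an element-wise Newton-like step using the damped GGN diagonal that DiagGGNExact stores in each parameter's diag_ggn_exact attribute:

import torch

class DiagGGNEOptimizer(torch.optim.Optimizer):
    """Sketch (assumption): preconditioned SGD with BackPACK's diag_ggn_exact."""

    def __init__(self, parameters, step_size, damping):
        super().__init__(parameters, dict(step_size=step_size, damping=damping))

    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                # element-wise Newton-like step: grad / (diag(GGN) + damping)
                step_direction = p.grad / (p.diag_ggn_exact + group["damping"])
                p.data.add_(step_direction, alpha=-group["step_size"])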
Example 4
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)
    print(".variance.shape:         ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape:         ", param.batch_l2.shape)

# %%
# Second order extensions
# --------------------------

# %%
# Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)

# %%
# Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

loss = lossfunc(model(X), y)
with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
    loss.backward()
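
# %%
# The per-sample variants store their results under savefields that follow
# BackPACK's naming convention; the attribute names below (diag_ggn_mc_batch,
# diag_ggn_exact_batch) are assumed accordingly.

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:                 ", param.grad.shape)
    print(".diag_ggn_mc_batch.shape:    ", param.diag_ggn_mc_batch.shape)
    print(".diag_ggn_exact_batch.shape: ", param.diag_ggn_exact_batch.shape)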
Example 5
# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
        BatchGrad(),
        Variance(),
        SumGradSquared(),
        BatchL2Grad(),
        DiagGGNExact(),
        DiagGGNMC(),
        KFAC(),
        KFLR(),
        KFRA(),
        DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape:             ", param.grad.shape)
    print("\t.grad_batch.shape:       ", param.grad_batch.shape)
    print("\t.variance.shape:         ", param.variance.shape)
    print("\t.sum_grad_squared.shape: ", param.sum_grad_squared.shape)
Example 6
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        model = _MyCustomModule()
    else:
        raise NotImplementedError(
            f"problem={problem_string} but no test setting for this.")

    model = extend(model.to(device))
    lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device))
    loss = lossfunc(model(X), y)
    yield model, loss, problem_string


@mark.parametrize("extension", [BatchGrad(), DiagGGNExact()],
                  ids=["BatchGrad", "DiagGGNExact"])
def test_extension_hook_multiple_parameter_visits(
        problem, extension: BackpropExtension):
    """Tests whether each parameter is visited exactly once.

    For those cases where parameters are visited more than once (e.g. Custom containers),
    it tests that an error is raised.

    Furthermore, it is tested whether first order extensions run fine in either case,
    and second order extensions raise an error in the case of custom containers.

    Args:
        problem: test problem, consisting of model, loss, and problem_string
        extension: first or second order extension to test
Example 7
        p.requires_grad = False

match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
print(f"Forward pass of custom model matches TolstoiCharRNN? {match}")

if not match:
    raise AssertionError("Forward passes don't match.")

# %%
# We can :py:func:`extend <backpack.extend>` our model and the loss function to
# compute BackPACK extensions.

tolstoi_char_rnn_custom = extend(tolstoi_char_rnn_custom)
loss = loss_function(tolstoi_char_rnn_custom(x), y)

with backpack(BatchGrad(), DiagGGNExact()):
    loss.backward()

for name, param in tolstoi_char_rnn_custom.named_parameters():
    if param.requires_grad:
        print(
            name,
            param.shape,
            param.grad_batch.shape,
            param.diag_ggn_exact.shape,
        )

# %%
# Comparison of the GGN diagonal extension with :py:mod:`torch.autograd`:
#
# .. note::
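
# %%
# Such a comparison could reuse the ``autograd_diag_ggn_exact`` helper from
# Example 1 (its import is not shown here and is assumed to be available);
# only a few entries are compared, since the full autograd GGN diagonal is
# expensive for this model.

from torch import allclose, cat

diag_ggn_backpack = cat(
    [
        p.diag_ggn_exact.flatten()
        for p in tolstoi_char_rnn_custom.parameters()
        if p.requires_grad
    ]
)
idx = [0, diag_ggn_backpack.numel() // 2, diag_ggn_backpack.numel() - 1]
diag_ggn_autograd = autograd_diag_ggn_exact(
    x, y, tolstoi_char_rnn_custom, loss_function, idx=idx
)
for i, element in zip(idx, diag_ggn_autograd):
    print(f"Entry {i} close? {allclose(element, diag_ggn_backpack[i], atol=1e-5)}")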
Example 8
    def step(self, closure, clip_grad=False):
        # increment step
        self.state['step'] += 1

        # deterministic closure
        seed = time.time()
        def closure_deterministic(for_backtracking=False):
            with ut.random_seed_torch(int(seed)):
                return closure(for_backtracking)
        
        # get loss and compute gradients/second-order extensions
        loss = closure_deterministic()
        if self.base_opt == 'diag_hessian':
            with backpack(DiagHessian()):            
                loss.backward()

        elif self.base_opt == 'diag_ggn_ex':
            with backpack(DiagGGNExact()):
                loss.backward()

        elif self.base_opt == 'diag_ggn_mc':
            with backpack(DiagGGNMC()):
                loss.backward()
        else:
            loss.backward()

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(self.params, 0.25)
        # increment # forward-backward calls
        self.state['n_forwards'] += 1
        self.state['n_backwards'] += 1        
        # save the current parameters:
        params_current = copy.deepcopy(self.params)
        grad_current = ut.get_grad_list(self.params)
        grad_norm = ut.compute_grad_norm(grad_current)

        # keep track of step
        if self.state['step'] % int(self.n_batches_per_epoch) == 1:
            self.state['step_size_avg'] = 0.

        # if grad_norm < 1e-6:
        #     return 0.

        #  Gv options
        # =============
        # update gv
        if self.base_opt == 'diag_hessian':
            # get diagonal hessian here and store it in state['gv']
            gv = [p.diag_h for p in self.params]

        elif self.base_opt == 'diag_ggn_ex':
            gv = [p.diag_ggn_exact for p in self.params]

        elif self.base_opt == 'diag_ggn_mc':
            gv = [p.diag_ggn_mc for p in self.params]

        else:
            raise ValueError('base_opt %s does not exist' % self.base_opt)

        for gv_i in gv:
            if torch.any(gv_i < 0):
                warnings.warn("%s contains negative values." % (self.gv_update))
                print(gv)

        if self.state['gv'] is None or self.accum_gv is None:
            self.state['gv'] = gv

            if self.accum_gv == 'avg':
                # first iteration
                self.state['gv_lag'] = [[gv_i] for gv_i in gv] 
                self.state['gv_sum'] = gv

        elif self.accum_gv == 'max':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                self.state['gv'][i] = torch.max(gv_old, gv_new)

        elif self.accum_gv == 'sum':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                self.state['gv'][i] = gv_old + gv_new   

        elif self.accum_gv == 'ams':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                gv_accum = self.beta*gv_old + (1-self.beta)*gv_new
                self.state['gv'][i] = torch.max(gv_old, gv_accum)

        elif self.accum_gv == 'ams_no_max':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                # same as above without the max
                gv_accum = self.beta*gv_old + (1-self.beta)*gv_new
                self.state['gv'][i] = gv_accum

        elif self.accum_gv == 'avg':
            t = self.state['step']
            
            for i, (gv_new, gv_lag) in enumerate(zip(gv, self.state['gv_lag'])):

                if t < self.avg_window:
                    # keep track of the sum only, no need to kick anyone out 
                    # for the first window number of iterations
                    gv_sum = self.state['gv_sum'][i] + gv_new
                    # take the average of all seen so far
                    gv_accum = gv_sum / t 

                else:
                    # kick out the gv stored for iteration (t-window)
                    gv_kick_out = gv_lag.pop(0)
                    # update the sum for the past window iterations
                    gv_sum = self.state['gv_sum'][i] + gv_new - gv_kick_out
                    # take the average gv within the window
                    gv_accum = gv_sum / self.avg_window

                # add the new gv to the lags and update sum
                self.state['gv_lag'][i].append(gv_new)
                self.state['gv_sum'][i] = gv_sum
                # this will be used as the diagonal preconditioner
                self.state['gv'][i] = gv_accum

        else:
            raise ValueError('accum_gv %s does not exist' % self.accum_gv)

        if self.lm > 0:
            for i in range(len(self.state['gv'])):
                self.state['gv'][i] += self.lm 

        # compute pp norm method
        pp_norm = self.get_pp_norm(grad_current=grad_current)
        
        # compute step size - same as the SLS code but with a different norm pp_norm
        # =================
        step_size = self.get_step_size(
            closure_deterministic, loss, params_current, grad_current,
            grad_norm, pp_norm, for_backtracking=True,
        )

        self.try_sgd_precond_update(self.params, step_size, params_current,
                                    grad_current)

        # save the new step-size
        self.state['step_size'] = step_size

        self.state['step_size_avg'] += (step_size / int(self.n_batches_per_epoch))
        self.state['grad_norm'] = grad_norm.item()
  
        if torch.isnan(self.params[0]).sum() > 0:
            print('nan')

        return loss
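
    # Sketch only (assumption, not the class's actual implementation): a
    # preconditioned SGD update consistent with how step() uses state['gv']
    # as a diagonal preconditioner.
    def try_sgd_precond_update(self, params, step_size, params_current, grad_current):
        for p, p_current, g, gv_i in zip(
            params, params_current, grad_current, self.state['gv']
        ):
            if g is None:
                continue
            # element-wise (Jacobi-style) preconditioning of the gradient,
            # applied to the parameters saved before the line search
            p.data = p_current.data - step_size * g / gv_i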