Example #1
0
def test_compute_hessians(network, dataset, data_type, network_type, x, y):
    """Compute BackPACK diagonal Hessians of the cross-entropy loss on one batch.

    ``network`` is put into training mode and handed to ``copy_network``
    together with a fresh sequential network from
    ``get_seq_network(network_type)`` (presumably copying its parameters over
    -- confirm). The sequential network and a ``CrossEntropyLoss`` are
    extended for BackPACK, the network is moved to CUDA, and a single
    forward/backward pass on ``(x, y)`` runs under ``DiagHessian``. The
    results are gathered with ``get_hessians``.

    Args:
        network: source model to copy from.
        dataset: not used; kept so existing callers keep working.
        data_type: dataset name; for ``'mnist'`` each sample of ``x`` is
            flattened to a vector before the forward pass.
        network_type: selects the sequential architecture.
        x: batch of inputs.
        y: targets for ``CrossEntropyLoss``.

    Returns:
        ``(hessians, x, y)`` where ``x`` and ``y`` are the CUDA tensors that
        were actually fed to the network (``x`` flattened for mnist).
    """
    network.train()

    # Only the single batch (x, y) is used, so no BatchSampler/DataLoader is
    # needed here.
    network_seq = get_seq_network(network_type)
    network_seq = copy_network(network, network_seq)

    criterion = extend(nn.CrossEntropyLoss())
    network_seq = extend(network_seq).cuda()

    x = x.cuda()
    y = y.cuda()
    if data_type == 'mnist':
        # Flatten every sample: (batch, ...) -> (batch, features).
        x = x.view(len(x), -1)

    loss = criterion(network_seq(x), y)

    # Inside this context backward() also computes BackPACK's diagonal Hessian.
    with backpack(DiagHessian()):
        loss.backward()

    # No previously collected Hessians are passed in (None), as before.
    hessians = get_hessians(network_seq, None)

    return hessians, x, y
Example #2
0
def Diag_second_order(model, train_loader, prec0=10, device='cpu'):
    """Batch-averaged diagonal curvature of the last linear layer via BackPACK.

    The last two entries of ``model.parameters()`` are used as the weight
    ``W`` (shape ``(m, n)``) and bias ``b``. Only ``model.linear`` and the
    loss are extended with BackPACK, so presumably ``W``/``b`` belong to
    ``model.linear`` (TODO confirm). For every batch of ``train_loader`` the
    diagonal Hessian of the cross-entropy loss w.r.t. ``W`` and ``b`` is
    computed, ``1 / prec0`` is added element-wise, and the result is averaged
    over all batches that were actually processed.

    NOTE(review): a Laplace-style posterior usually adds the prior *precision*
    (``prec0``) to the Hessian and then inverts it. This function adds the
    prior *variance* ``1 / prec0`` and returns the averaged curvature without
    inversion, despite the ``C_*`` (covariance) names. Kept as-is; confirm
    against the consumers of the returned values.

    Args:
        model: network exposing a ``linear`` submodule.
        train_loader: iterable of ``(x, y)`` batches supporting ``len()``;
            the length is only used for progress output.
        prec0: prior precision; ``1 / prec0`` is added to every diagonal entry.
        device: device for inputs, targets and the accumulators.

    Returns:
        Tuple ``(M_W_post, M_b_post, C_W_post, C_b_post)``: ``W.t()``, ``b``,
        and the averaged diagonals for ``W`` (``(m, n)``) and ``b`` (``(m,)``).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    params = list(model.parameters())
    W, b = params[-2], params[-1]
    m, n = W.shape
    print("n: {} inputs to linear layer with m: {} classes".format(n, m))
    lossfunc = torch.nn.CrossEntropyLoss()

    var0 = 1 / prec0

    extend(lossfunc, debug=False)
    extend(model.linear, debug=False)

    max_len = len(train_loader)
    # Running sums keep memory at O(m * n) instead of storing one (m, n)
    # slice per batch only to average them afterwards.
    sum_W = torch.zeros(m, n, device=device)
    sum_b = torch.zeros(m, device=device)
    num_batches = 0

    with backpack(DiagHessian()):
        for batch_idx, (x, y) in enumerate(train_loader):
            # .to(device) also covers 'cuda:0', torch.device objects, etc.
            x, y = x.to(device), y.to(device)

            model.zero_grad()
            lossfunc(model(x), y).backward()

            with torch.no_grad():
                # Add the prior term out of place so BackPACK's outputs
                # (W.diag_h / b.diag_h) are left unmodified.
                sum_W += W.diag_h + var0
                sum_b += b.diag_h + var0
            num_batches += 1

            print("Batch: {}/{}".format(batch_idx, max_len))

    print(max_len)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # Average over the batches actually seen (not len(train_loader)).
    C_W = sum_W / num_batches
    C_b = sum_b / num_batches

    # Predictive distribution
    with torch.no_grad():
        M_W_post = W.t()
        M_b_post = b

        C_W_post = C_W
        C_b_post = C_b

    print("M_W_post size: ", M_W_post.size())
    print("M_b_post size: ", M_b_post.size())
    print("C_W_post size: ", C_W_post.size())
    print("C_b_post size: ", C_b_post.size())

    return (M_W_post, M_b_post, C_W_post, C_b_post)
def test_backpack_extensions(problem):
    """Ensure a BackPACK-backed quantity can be computed inside cockpit.

    A ``TICDiag`` quantity with a ``linear(1)`` track schedule is handed to
    the custom test harness, which runs with BackPACK's ``DiagHessian``
    extension.
    """
    tic_diag = quantities.TICDiag(track_schedule=linear(1))

    with instantiate(problem):
        harness = CustomTestHarness(problem)
        harness.test({"quantities": [tic_diag]}, DiagHessian())
def forward_backward_with_backpack():
    """Run one forward/backward pass with BackPACK's ``DiagHessian`` and ``HMP``.

    Uses the module-level ``model``, ``loss_function``, ``x`` and ``y``. The
    autograd graph is retained so that autodiff Hessian-vector products can
    still be computed after this call.

    Returns:
        The loss tensor that was back-propagated.
    """
    prediction = model(x)
    loss = loss_function(prediction, y)

    extensions = (DiagHessian(), HMP())
    with backpack(*extensions):
        # retain_graph=True keeps the graph alive for later autodiff HVPs
        loss.backward(retain_graph=True)

    return loss
Example #5
0
X, y = X.to(device), y.to(device)  # X, y and device are defined earlier in the script

# model
# Linear classifier on flattened inputs: 784 input features -> 10 outputs.
model = Sequential(Flatten(), Linear(784, 10)).to(device)
lossfunc = CrossEntropyLoss().to(device)

# Extend both the model and the loss so BackPACK can compute its quantities
# during backward().
model = extend(model)
lossfunc = extend(lossfunc)

# %%
# Standard computation of the trace
# ---------------------------------

loss = lossfunc(model(X), y)

# With DiagHessian active, backward() stores each parameter's Hessian
# diagonal in `param.diag_h` (alongside the usual `.grad`).
with backpack(DiagHessian()):
    loss.backward()

# Tr(H): sum of all diagonal entries, accumulated over every parameter.
tr_after_backward = sum(param.diag_h.sum() for param in model.parameters())

print(f"Tr(H) after backward: {tr_after_backward:.3f} ")

# %%
# Let's clean up the computation graph and existing BackPACK buffers
del loss

# Drop the stored diagonals so the following section starts without them.
for param in model.parameters():
    del param.diag_h

# %%
# Extension hook
Example #6
0
loss = lossfunc(model(X), y)
# KFAC (one Monte-Carlo sample), KFLR and KFRA, all in a single backward pass.
with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

# The Kronecker-factored quantities are iterables of factor tensors per
# parameter, so their shapes are printed as lists.
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".kfac (shapes):          ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes):          ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes):          ", [kfra.shape for kfra in param.kfra])

# %%
# Diagonal Hessian and per-sample diagonal Hessian

# A fresh forward pass is needed: the previous backward() freed its graph.
loss = lossfunc(model(X), y)
with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

# diag_h is the mini-batch diagonal; diag_h_batch the per-sample diagonals.
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".diag_h.shape:           ", param.diag_h.shape)
    print(".diag_h_batch.shape:     ", param.diag_h_batch.shape)

# %%
# Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation

loss = lossfunc(model(X), y)
# Exact GGN square root and its Monte-Carlo version (one sample).
with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
    loss.backward()
Example #7
0
# individual_loss, batch_loss and tproblem are defined earlier in the script.
# Printed side by side so the mini-batch loss can be compared with the mean of
# the individual losses (presumably equal -- confirm the loss reduction).
print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
        BatchGrad(),
        Variance(),
        SumGradSquared(),
        BatchL2Grad(),
        DiagGGNExact(),
        DiagGGNMC(),
        KFAC(),
        KFLR(),
        KFRA(),
        DiagHessian(),
):
    batch_loss.backward()

# print info
# NOTE(review): KFLR/KFRA results (param.kflr, param.kfra) are computed above
# but not printed in this loop.
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape:             ", param.grad.shape)
    print("\t.grad_batch.shape:       ", param.grad_batch.shape)
    print("\t.variance.shape:         ", param.variance.shape)
    print("\t.sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print("\t.batch_l2.shape:         ", param.batch_l2.shape)
    print("\t.diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print("\t.diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)
    print("\t.diag_h.shape:           ", param.diag_h.shape)
    print("\t.kfac (shapes):          ", [kfac.shape for kfac in param.kfac])
Example #8
0
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".diag_ggn_mc.shape:      ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape:   ", param.diag_ggn_exact.shape)

# %%
# KFAC, KFRA and KFLR

loss = lossfunc(model(X), y)
# KFAC (one Monte-Carlo sample), KFLR and KFRA, all in a single backward pass.
with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

# Each Kronecker-factored quantity is an iterable of factor tensors.
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".kfac (shapes):          ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes):          ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes):          ", [kfra.shape for kfra in param.kfra])

# %%
# Diagonal Hessian

# A fresh forward pass is needed: the previous backward() freed its graph.
loss = lossfunc(model(X), y)
with backpack(DiagHessian()):
    loss.backward()

# diag_h is stored per parameter by the DiagHessian extension.
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".diag_h.shape:           ", param.diag_h.shape)
Example #9
0
    def step(self, closure, clip_grad=False):
        """Take one diagonally-preconditioned line-search step.

        The closure is evaluated once and back-propagated with the BackPACK
        extension selected by ``self.base_opt``. The resulting per-parameter
        diagonal curvature is accumulated into ``self.state['gv']`` according
        to ``self.accum_gv``. A step size is then chosen by
        ``self.get_step_size`` (with backtracking) and applied via
        ``self.try_sgd_precond_update``.

        Args:
            closure: callable invoked as ``closure(for_backtracking)`` that
                returns the loss.
            clip_grad: if True, clip the gradient norm of ``self.params`` to
                0.25 before the step.

        Returns:
            The loss from the first closure evaluation of this step.

        Raises:
            ValueError: if ``self.base_opt`` is not one of 'diag_hessian',
                'diag_ggn_ex', 'diag_ggn_mc', or if ``self.accum_gv`` is not
                recognised once a previous ``gv`` is stored in the state.
        """
        # BackPACK extension and the parameter attribute it fills, per base_opt.
        curvature_extensions = {
            'diag_hessian': (DiagHessian, 'diag_h'),
            'diag_ggn_ex': (DiagGGNExact, 'diag_ggn_exact'),
            'diag_ggn_mc': (DiagGGNMC, 'diag_ggn_mc'),
        }

        # increment step
        self.state['step'] += 1

        # Deterministic closure: every evaluation within this step reuses the
        # same seed (presumably so backtracking re-evaluations see the same
        # randomness -- confirm ut.random_seed_torch semantics).
        seed = time.time()

        def closure_deterministic(for_backtracking=False):
            with ut.random_seed_torch(int(seed)):
                return closure(for_backtracking)

        # get loss and compute gradients/second-order extensions
        loss = closure_deterministic()
        if self.base_opt in curvature_extensions:
            extension_cls, _ = curvature_extensions[self.base_opt]
            with backpack(extension_cls()):
                loss.backward()
        else:
            loss.backward()

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(self.params, 0.25)
        # increment # forward-backward calls
        self.state['n_forwards'] += 1
        self.state['n_backwards'] += 1
        # save the current parameters:
        params_current = copy.deepcopy(self.params)
        grad_current = ut.get_grad_list(self.params)
        grad_norm = ut.compute_grad_norm(grad_current)

        # Reset the running step-size average at the first step of each epoch.
        # (step - 1) % n == 0 also fires when n_batches_per_epoch == 1, where
        # step % n == 1 never would.
        if (self.state['step'] - 1) % int(self.n_batches_per_epoch) == 0:
            self.state['step_size_avg'] = 0.

        #  Gv options
        # =============
        if self.base_opt not in curvature_extensions:
            raise ValueError('%s does not exist' % self.base_opt)
        _, gv_attr = curvature_extensions[self.base_opt]
        gv = [getattr(p, gv_attr) for p in self.params]

        for gv_i in gv:
            if torch.any(gv_i < 0):
                warnings.warn("%s contains negative values." % (self.base_opt))
                print(gv)

        if self.state['gv'] is None or self.accum_gv is None:
            # Separate list objects so per-index writes to state['gv'] never
            # alias state['gv_sum'].
            self.state['gv'] = list(gv)

            if self.accum_gv == 'avg':
                # first iteration
                self.state['gv_lag'] = [[gv_i] for gv_i in gv]
                self.state['gv_sum'] = list(gv)

        elif self.accum_gv == 'max':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                self.state['gv'][i] = torch.max(gv_old, gv_new)

        elif self.accum_gv == 'sum':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                self.state['gv'][i] = gv_old + gv_new

        elif self.accum_gv == 'ams':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                gv_accum = self.beta * gv_old + (1 - self.beta) * gv_new
                self.state['gv'][i] = torch.max(gv_old, gv_accum)

        elif self.accum_gv == 'ams_no_max':
            for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
                # same as above without the max
                self.state['gv'][i] = self.beta * gv_old + (1 - self.beta) * gv_new

        elif self.accum_gv == 'avg':
            # t gv's have been seen so far (the first one at initialisation).
            t = self.state['step']

            for i, (gv_new, gv_lag) in enumerate(zip(gv, self.state['gv_lag'])):
                if t <= self.avg_window:
                    # window not full yet: average over everything seen so far
                    gv_sum = self.state['gv_sum'][i] + gv_new
                    gv_accum = gv_sum / t
                else:
                    # drop the oldest gv so the window holds exactly avg_window items
                    gv_kick_out = gv_lag.pop(0)
                    gv_sum = self.state['gv_sum'][i] + gv_new - gv_kick_out
                    gv_accum = gv_sum / self.avg_window

                # add the new gv to the lags and update the running sum
                gv_lag.append(gv_new)
                self.state['gv_sum'][i] = gv_sum
                # this will be used as the diagonal preconditioner
                self.state['gv'][i] = gv_accum

        else:
            raise ValueError('accum_gv %s does not exist' % self.accum_gv)

        if self.lm > 0:
            # Out-of-place add: the stored tensors may be shared with
            # gv_lag / gv_sum / p.<attr>, which must not receive the damping.
            # NOTE(review): the damped value is stored back as the accumulator,
            # so for 'sum'/'ams' the damping compounds across steps -- confirm.
            for i in range(len(self.state['gv'])):
                self.state['gv'][i] = self.state['gv'][i] + self.lm

        # compute pp norm method
        pp_norm = self.get_pp_norm(grad_current=grad_current)

        # compute step size - same as the SLS code but with a different norm pp_norm
        # =================
        step_size = self.get_step_size(closure_deterministic, loss, params_current,
                                       grad_current, grad_norm, pp_norm,
                                       for_backtracking=True)

        self.try_sgd_precond_update(self.params, step_size, params_current,
                                    grad_current)

        # save the new step-size
        self.state['step_size'] = step_size

        self.state['step_size_avg'] += (step_size / int(self.n_batches_per_epoch))
        self.state['grad_norm'] = grad_norm.item()

        if torch.isnan(self.params[0]).sum() > 0:
            print('nan')

        return loss