def test_compute_hessians(network, dataset, data_type, network_type, x, y, device='cuda'):
    """Compute BackPACK diagonal Hessians of ``network`` on a single batch.

    ``network`` is copied into the sequential architecture returned by
    ``get_seq_network(network_type)``, which is extended for BackPACK. The
    cross-entropy loss on ``(x, y)`` is back-propagated under
    ``DiagHessian()``, and the result is collected with ``get_hessians``.

    Args:
        network: source model; switched to train mode before copying.
        dataset: unused. Kept only for call-site compatibility; it previously
            fed a DataLoader that was constructed but never iterated.
        data_type: when ``'mnist'``, inputs are reshaped to ``(len(x), -1)``.
        network_type: selects the sequential architecture to copy into.
        x: one batch of inputs.
        y: the matching targets for ``nn.CrossEntropyLoss``.
        device: where the batch and the sequential network are placed
            (default ``'cuda'``, matching the previous hard-coded ``.cuda()``).

    Returns:
        ``(hessians, x, y)``, where ``x`` and ``y`` have been moved to
        ``device`` and ``x`` is reshaped for MNIST.
    """
    network.train()

    network_seq = copy_network(network, get_seq_network(network_type))
    network_seq = extend(network_seq).to(device)
    criterion = extend(nn.CrossEntropyLoss())

    x = x.to(device)
    y = y.to(device)
    if data_type == 'mnist':
        # reshape (unlike view) also works when x is not contiguous
        x = x.reshape(len(x), -1)

    loss = criterion(network_seq(x), y)
    with backpack(DiagHessian()):
        loss.backward()

    # presumably None means "no previous Hessians to extend"; verify against get_hessians
    hessians = get_hessians(network_seq, None)
    return hessians, x, y
def Diag_second_order(model, train_loader, prec0=10, device='cpu'):
    """Batch-averaged diagonal Hessian of the last linear layer, plus a prior term.

    Takes the last two parameters of ``model`` as the weight ``W`` (shape
    ``(m, n)``) and bias ``b`` (shape ``(m,)``). For every batch of
    ``train_loader`` it computes BackPACK's ``DiagHessian`` of the
    cross-entropy loss and adds ``1 / prec0`` to each entry. It then averages
    over batches.

    Args:
        model: network whose last two parameters are assumed to be the weight
            and bias of ``model.linear``. Only ``model.linear`` is extended for
            BackPACK (TODO confirm for every architecture used).
        train_loader: iterable of ``(x, y)`` batches; must support ``len``.
        prec0: prior precision; ``1 / prec0`` is added to the Hessian diagonal.
        device: device for the accumulators and the batches. Accepts a str or
            ``torch.device``, e.g. ``'cpu'``, ``'cuda'`` or ``'cuda:1'``.

    Returns:
        ``(M_W_post, M_b_post, C_W_post, C_b_post)``, where:
        - ``M_W_post`` is ``W.t()``, shape ``(n, m)``;
        - ``M_b_post`` is ``b``;
        - ``C_W_post`` is the batch-averaged diagonal for ``W``, shape ``(m, n)``;
        - ``C_b_post`` is the batch-averaged diagonal for ``b``, shape ``(m,)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    params = list(model.parameters())
    W, b = params[-2], params[-1]
    m, n = W.shape
    print("n: {} inputs to linear layer with m: {} classes".format(n, m))

    lossfunc = torch.nn.CrossEntropyLoss()
    # NOTE(review): for a Gaussian prior N(0, var0 * I), the term added to the
    # Hessian is the prior precision (1 / var0 == prec0), not var0. Also, the
    # loss uses the default 'mean' reduction, so each diag_h is at per-example
    # scale. Kept unchanged to preserve existing results -- confirm intent.
    var0 = 1 / prec0
    extend(lossfunc, debug=False)
    extend(model.linear, debug=False)

    max_len = len(train_loader)
    # Keep running sums instead of a (num_batches, m, n) buffer, so memory no
    # longer grows with the number of batches. Adding var0 per batch and then
    # averaging equals averaging and then adding var0.
    W_sum = torch.zeros(m, n, device=device)
    b_sum = torch.zeros(m, device=device)
    num_batches = 0
    with backpack(DiagHessian()):
        for batch_idx, (x, y) in enumerate(train_loader):
            # .to(device) also covers 'cuda:N' and torch.device values, which
            # the former `device == 'cuda'` string comparison skipped.
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            lossfunc(model(x), y).backward()
            with torch.no_grad():
                W_sum += W.diag_h + var0
                b_sum += b.diag_h + var0
            num_batches += 1
            print("Batch: {}/{}".format(batch_idx, max_len))
    print(max_len)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # Predictive distribution
    with torch.no_grad():
        M_W_post = W.t()
        M_b_post = b
        # Divide by the number of batches actually seen, not len(train_loader).
        C_W_post = W_sum / num_batches
        C_b_post = b_sum / num_batches

    print("M_W_post size: ", M_W_post.size())
    print("M_b_post size: ", M_b_post.size())
    print("C_W_post size: ", C_W_post.size())
    print("C_b_post size: ", C_b_post.size())
    return (M_W_post, M_b_post, C_W_post, C_b_post)
def test_backpack_extensions(problem):
    """Ensure a BackPACK extension (``DiagHessian``) can run inside cockpit.

    The custom harness is run on ``problem`` with a ``TICDiag`` quantity
    tracked on a ``linear(1)`` schedule.
    """
    tic_diag = quantities.TICDiag(track_schedule=linear(1))

    with instantiate(problem):
        harness = CustomTestHarness(problem)
        harness.test({"quantities": [tic_diag]}, DiagHessian())
def forward_backward_with_backpack():
    """Compute the loss and back-propagate it with BackPACK's `DiagHessian` and `HMP`.

    Uses the free names ``model``, ``x``, ``y`` and ``loss_function`` from the
    enclosing scope, and returns the loss tensor.
    """
    prediction = model(x)
    loss = loss_function(prediction, y)

    extensions = (DiagHessian(), HMP())
    with backpack(*extensions):
        # retain_graph=True keeps the graph alive for later autodiff HVPs
        loss.backward(retain_graph=True)

    return loss
X, y = X.to(device), y.to(device) # model model = Sequential(Flatten(), Linear(784, 10)).to(device) lossfunc = CrossEntropyLoss().to(device) model = extend(model) lossfunc = extend(lossfunc) # %% # Standard computation of the trace # --------------------------------- loss = lossfunc(model(X), y) with backpack(DiagHessian()): loss.backward() tr_after_backward = sum(param.diag_h.sum() for param in model.parameters()) print(f"Tr(H) after backward: {tr_after_backward:.3f} ") # %% # Let's clean up the computation graph and existing BackPACK buffers del loss for param in model.parameters(): del param.diag_h # %% # Extension hook
loss = lossfunc(model(X), y)  # each cell recomputes the loss before its own backward pass

with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    # .kfac/.kflr/.kfra each hold several tensors; print the shape of each one
    print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra])

# %%
# Diagonal Hessian and per-sample diagonal Hessian

loss = lossfunc(model(X), y)  # fresh forward pass for this cell

with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_h.shape: ", param.diag_h.shape)
    # written by BatchDiagHessian (the per-sample variant)
    print(".diag_h_batch.shape: ", param.diag_h_batch.shape)

# %%
# Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation

loss = lossfunc(model(X), y)  # fresh forward pass for this cell

with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
    loss.backward()  # the results are presumably inspected after this excerpt
# individual_loss, batch_loss and tproblem are defined earlier in the script
print("Individual loss shape: ", individual_loss.shape)
print("Mini-batch loss: ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())  # shown next to the mini-batch loss for comparison

# It is still possible to use BackPACK in the backward pass
with backpack(
    BatchGrad(),
    Variance(),
    SumGradSquared(),
    BatchL2Grad(),
    DiagGGNExact(),
    DiagGGNMC(),
    KFAC(),
    KFLR(),
    KFRA(),
    DiagHessian(),
):
    batch_loss.backward()  # all listed extensions are computed in this single backward pass

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape: ", param.grad.shape)
    print("\t.grad_batch.shape: ", param.grad_batch.shape)
    print("\t.variance.shape: ", param.variance.shape)
    print("\t.sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print("\t.batch_l2.shape: ", param.batch_l2.shape)
    print("\t.diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
    print("\t.diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)
    print("\t.diag_h.shape: ", param.diag_h.shape)
    # .kfac holds several tensors; print the shape of each one
    print("\t.kfac (shapes): ", [kfac.shape for kfac in param.kfac])
# The next four prints are the tail of a per-parameter loop whose header is above this excerpt (presumably the GGN-diagonal cell).
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)

# %%
# KFAC, KFRA and KFLR

loss = lossfunc(model(X), y)  # fresh forward pass for this cell

with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    # .kfac/.kflr/.kfra each hold several tensors; print the shape of each one
    print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra])

# %%
# Diagonal Hessian

loss = lossfunc(model(X), y)  # fresh forward pass for this cell

with backpack(DiagHessian()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_h.shape: ", param.diag_h.shape)  # diag_h is written by DiagHessian
def step(self, closure, clip_grad=False):
    """Perform one diagonally-preconditioned step with a backtracking step size.

    The closure is evaluated inside ``ut.random_seed_torch`` with a seed that
    is fixed for this step, and the same seed is reused by the backtracking
    re-evaluations. The loss is back-propagated with the BackPACK extension
    selected by ``self.base_opt``.

    The resulting per-parameter diagonal ("gv") is accumulated into
    ``self.state['gv']`` according to ``self.accum_gv`` and optionally shifted
    by ``self.lm``. It is then used through ``self.get_pp_norm``,
    ``self.get_step_size`` and ``self.try_sgd_precond_update``.

    Args:
        closure: callable ``closure(for_backtracking)`` returning the loss.
        clip_grad: if True, clip the gradient norm of ``self.params`` to 0.25
            after the backward pass.

    Returns:
        The loss from the first closure evaluation of this step.

    Raises:
        ValueError: if ``self.base_opt`` or ``self.accum_gv`` is unsupported.
            An unsupported ``base_opt`` is rejected before any state is
            modified or the closure is called.
    """
    # base_opt -> (BackPACK extension class, parameter attribute it fills)
    curvature_extensions = {
        'diag_hessian': (DiagHessian, 'diag_h'),
        'diag_ggn_ex': (DiagGGNExact, 'diag_ggn_exact'),
        'diag_ggn_mc': (DiagGGNMC, 'diag_ggn_mc'),
    }
    if self.base_opt not in curvature_extensions:
        raise ValueError('%s does not exist' % self.base_opt)
    extension_cls, gv_attr = curvature_extensions[self.base_opt]

    # increment step
    self.state['step'] += 1

    # deterministic closure: every evaluation in this step reuses the same seed
    seed = time.time()

    def closure_deterministic(for_backtracking=False):
        with ut.random_seed_torch(int(seed)):
            return closure(for_backtracking)

    # get loss and compute gradients together with the curvature extension
    loss = closure_deterministic()
    with backpack(extension_cls()):
        loss.backward()

    if clip_grad:
        torch.nn.utils.clip_grad_norm_(self.params, 0.25)

    # increment # forward-backward calls
    self.state['n_forwards'] += 1
    self.state['n_backwards'] += 1

    # save the current parameters and gradient for the line search
    params_current = copy.deepcopy(self.params)
    grad_current = ut.get_grad_list(self.params)
    grad_norm = ut.compute_grad_norm(grad_current)

    # reset the running step-size average on the first step of each epoch
    if self.state['step'] % int(self.n_batches_per_epoch) == 1:
        self.state['step_size_avg'] = 0.

    # Gv options
    # =============
    gv = [getattr(p, gv_attr) for p in self.params]
    for gv_i in gv:
        if torch.any(gv_i < 0):
            warnings.warn("%s contains negative values." % (self.base_opt))
            print(gv)

    if self.state['gv'] is None or self.accum_gv is None:
        self.state['gv'] = gv
        if self.accum_gv == 'avg':
            # first iteration: seed the averaging window with this estimate.
            # Copy the list: aliasing state['gv'] would let the averaged values
            # written to state['gv'] below overwrite the running sum.
            self.state['gv_lag'] = [[gv_i] for gv_i in gv]
            self.state['gv_sum'] = list(gv)

    elif self.accum_gv == 'max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = torch.max(gv_old, gv_new)

    elif self.accum_gv == 'sum':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = gv_old + gv_new

    elif self.accum_gv == 'ams':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            gv_accum = self.beta * gv_old + (1 - self.beta) * gv_new
            self.state['gv'][i] = torch.max(gv_old, gv_accum)

    elif self.accum_gv == 'ams_no_max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            # same as 'ams' without the max
            self.state['gv'][i] = self.beta * gv_old + (1 - self.beta) * gv_new

    elif self.accum_gv == 'avg':
        t = self.state['step']
        for i, (gv_new, gv_lag) in enumerate(zip(gv, self.state['gv_lag'])):
            if t <= self.avg_window:
                # the window is not full yet (it holds iterations 1..t):
                # grow the sum and average over everything seen so far
                gv_sum = self.state['gv_sum'][i] + gv_new
                gv_accum = gv_sum / t
            else:
                # evict the estimate from iteration t - avg_window
                gv_kick_out = gv_lag.pop(0)
                gv_sum = self.state['gv_sum'][i] + gv_new - gv_kick_out
                gv_accum = gv_sum / self.avg_window
            gv_lag.append(gv_new)
            # persist the running sum under the key it is read from
            self.state['gv_sum'][i] = gv_sum
            # this will be used as the diagonal preconditioner
            self.state['gv'][i] = gv_accum

    else:
        raise ValueError('accum_gv %s does not exist' % self.accum_gv)

    if self.lm > 0:
        # Out-of-place add. An in-place `+=` would also mutate the BackPACK
        # tensors (p.diag_h, ...) and, on the first 'avg' step, the tensors
        # shared with gv_sum / gv_lag.
        # NOTE(review): for 'max', 'sum' and the 'ams' variants this damped
        # value becomes the next step's gv_old, so lm compounds across
        # steps -- confirm this is intended.
        self.state['gv'] = [gv_i + self.lm for gv_i in self.state['gv']]

    # compute pp norm method
    pp_norm = self.get_pp_norm(grad_current=grad_current)

    # compute step size - same as the SLS code but with a different norm pp_norm
    # =================
    step_size = self.get_step_size(closure_deterministic, loss,
                                   params_current, grad_current,
                                   grad_norm, pp_norm,
                                   for_backtracking=True)

    self.try_sgd_precond_update(self.params, step_size, params_current, grad_current)

    # save the new step-size
    self.state['step_size'] = step_size
    self.state['step_size_avg'] += (step_size / int(self.n_batches_per_epoch))
    self.state['grad_norm'] = grad_norm.item()

    if torch.isnan(self.params[0]).sum() > 0:
        print('nan')

    return loss