step_size=STEP_SIZE, damping=DAMPING) losses = [] accuracies = [] for batch_idx, (x, y) in enumerate(mnist_loader): optimizer.zero_grad() x, y = x.to(DEVICE), y.to(DEVICE) model.zero_grad() outputs = model(x) loss = loss_function(outputs, y) with backpack(DiagGGNMC()): loss.backward() optimizer.step() # Logging losses.append(loss.detach().item()) accuracies.append(get_accuracy(outputs, y)) if (batch_idx % PRINT_EVERY) == 0: print("Iteration %3.d/%3.d " % (batch_idx, MAX_ITER) + "Minibatch Loss %.3f " % losses[-1] + "Accuracy %.3f" % accuracies[-1]) if MAX_ITER is not None and batch_idx > MAX_ITER: break
# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape: ", individual_loss.shape)
print("Mini-batch loss: ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
extensions = (
    BatchGrad(),
    Variance(),
    SumGradSquared(),
    BatchL2Grad(),
    DiagGGNExact(),
    DiagGGNMC(),
    KFAC(),
    KFLR(),
    KFRA(),
    DiagHessian(),
)
with backpack(*extensions):
    batch_loss.backward()

# print info
# Each entry pairs the label to print with the parameter attribute whose
# shape is reported next to it.
shape_fields = (
    ("\t.grad.shape: ", "grad"),
    ("\t.grad_batch.shape: ", "grad_batch"),
    ("\t.variance.shape: ", "variance"),
    ("\t.sum_grad_squared.shape: ", "sum_grad_squared"),
    ("\t.batch_l2.shape: ", "batch_l2"),
)
for name, param in tproblem.net.named_parameters():
    print(name)
    for label, attr in shape_fields:
        print(label, getattr(param, attr).shape)
print(name) print(".grad.shape: ", param.grad.shape) print(".grad_batch.shape: ", param.grad_batch.shape) print(".variance.shape: ", param.variance.shape) print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape) print(".batch_l2.shape: ", param.batch_l2.shape) # %% # Second order extensions # -------------------------- # %% # Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation loss = lossfunc(model(X), y) with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)): loss.backward() for name, param in model.named_parameters(): print(name) print(".grad.shape: ", param.grad.shape) print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape) print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape) # %% # Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation loss = lossfunc(model(X), y) with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)): loss.backward()
def step(self, closure, clip_grad=False):
    """Run one optimization step with a diagonal curvature preconditioner.

    The closure is evaluated once to obtain the loss, the backward pass is
    run (inside a BackPACK context when ``self.base_opt`` names one of the
    supported diagonal extensions), the per-parameter diagonal ``gv`` is
    accumulated into ``self.state['gv']`` according to ``self.accum_gv``,
    and finally a step size is obtained from ``self.get_step_size`` and
    applied through ``self.try_sgd_precond_update``.

    Args:
        closure: Callable invoked as ``closure(for_backtracking)`` that
            returns the loss; the returned object must support
            ``.backward()``.
        clip_grad: When true, the gradient norm of ``self.params`` is
            clipped to 0.25 right after the backward pass.

    Returns:
        The loss returned by the closure before the parameter update.

    Raises:
        ValueError: If ``self.base_opt`` is not one of ``'diag_hessian'``,
            ``'diag_ggn_ex'``, ``'diag_ggn_mc'``, or if ``self.accum_gv``
            is not a supported accumulation mode.
    """
    # increment step
    self.state['step'] += 1

    # deterministic closure: every call made during this step is wrapped in
    # ut.random_seed_torch with the same integer seed.
    seed = time.time()

    def closure_deterministic(for_backtracking=False):
        with ut.random_seed_torch(int(seed)):
            return closure(for_backtracking)

    # get loss and compute gradients/second-order extensions
    loss = closure_deterministic()
    if self.base_opt == 'diag_hessian':
        with backpack(DiagHessian()):
            loss.backward()
    elif self.base_opt == 'diag_ggn_ex':
        with backpack(DiagGGNExact()):
            loss.backward()
    elif self.base_opt == 'diag_ggn_mc':
        with backpack(DiagGGNMC()):
            loss.backward()
    else:
        loss.backward()

    if clip_grad:
        torch.nn.utils.clip_grad_norm_(self.params, 0.25)

    # increment # forward-backward calls
    self.state['n_forwards'] += 1
    self.state['n_backwards'] += 1

    # save the current parameters:
    params_current = copy.deepcopy(self.params)
    grad_current = ut.get_grad_list(self.params)
    grad_norm = ut.compute_grad_norm(grad_current)

    # keep track of step
    if self.state['step'] % int(self.n_batches_per_epoch) == 1:
        self.state['step_size_avg'] = 0.

    # if grad_norm < 1e-6:
    #     return 0.

    # Gv options
    # =============
    # update gv
    if self.base_opt == 'diag_hessian':
        # get diagonal hessian here and store it in state['gv']
        gv = [p.diag_h for p in self.params]
    elif self.base_opt == 'diag_ggn_ex':
        gv = [p.diag_ggn_exact for p in self.params]
    elif self.base_opt == 'diag_ggn_mc':
        gv = [p.diag_ggn_mc for p in self.params]
    else:
        # The dispatch above is on base_opt, so that is the value reported.
        raise ValueError('%s does not exist' % self.base_opt)

    for gv_i in gv:
        if torch.any(gv_i < 0):
            warnings.warn("%s contains negative values." % self.base_opt)
            print(gv)

    if self.state['gv'] is None or self.accum_gv is None:
        self.state['gv'] = gv

        if self.accum_gv == 'avg':
            # first iteration
            self.state['gv_lag'] = [[gv_i] for gv_i in gv]
            # Keep the running sum in its own list: entries of state['gv']
            # are reassigned on later steps and must not overwrite the sum.
            self.state['gv_sum'] = list(gv)

    elif self.accum_gv == 'max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = torch.max(gv_old, gv_new)

    elif self.accum_gv == 'sum':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = gv_old + gv_new

    elif self.accum_gv == 'ams':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            gv_accum = self.beta*gv_old + (1-self.beta)*gv_new
            self.state['gv'][i] = torch.max(gv_old, gv_accum)

    elif self.accum_gv == 'ams_no_max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            # same as above without the max
            gv_accum = self.beta*gv_old + (1-self.beta)*gv_new
            self.state['gv'][i] = gv_accum

    elif self.accum_gv == 'avg':
        t = self.state['step']
        for i, (gv_new, gv_lag) in enumerate(zip(gv, self.state['gv_lag'])):
            # NOTE(review): if state['step'] starts at 0, then at
            # t == avg_window only avg_window - 1 entries are stored, yet
            # the else-branch already evicts one and divides by
            # avg_window; `t <= self.avg_window` may be intended - confirm.
            if t < self.avg_window:
                # keep track of the sum only, no need to kick anyone out
                # for the first window number of iterations
                gv_sum = self.state['gv_sum'][i] + gv_new
                # take the average of all seen so far
                gv_accum = gv_sum / t
            else:
                # kick out the gv stored for iteration (t-window)
                gv_kick_out = gv_lag.pop(0)
                # update the sum for the past window iterations
                gv_sum = self.state['gv_sum'][i] + gv_new - gv_kick_out
                # take the average gv within the window
                gv_accum = gv_sum / self.avg_window

            # add the new gv to the lags and update sum
            self.state['gv_lag'][i].append(gv_new)
            # Store under the key that is read above ('gv_sum'); the
            # previous 'sum' key was never read, so the sum never advanced.
            self.state['gv_sum'][i] = gv_sum

            # this will be used as the diagonal preconditioner
            self.state['gv'][i] = gv_accum

    else:
        raise ValueError('accum_gv %s does not exist' % self.accum_gv)

    if self.lm > 0:
        # Add the damping out of place: an in-place `+=` would also modify
        # tensors shared with state['gv_sum'] / state['gv_lag'] and the
        # curvature attributes stored on the parameters.
        self.state['gv'] = [gv_i + self.lm for gv_i in self.state['gv']]

    # compute pp norm method
    pp_norm = self.get_pp_norm(grad_current=grad_current)

    # compute step size - same as the SLS code but with a different norm pp_norm
    # =================
    step_size = self.get_step_size(closure_deterministic, loss,
                                   params_current, grad_current,
                                   grad_norm, pp_norm,
                                   for_backtracking=True)
    self.try_sgd_precond_update(self.params, step_size,
                                params_current, grad_current)

    # save the new step-size
    self.state['step_size'] = step_size
    self.state['step_size_avg'] += (step_size / int(self.n_batches_per_epoch))
    self.state['grad_norm'] = grad_norm.item()

    # Only the first parameter tensor is inspected for NaNs.
    if torch.isnan(self.params[0]).sum() > 0:
        print('nan')

    return loss