def test_network_diag_ggn(model_and_input):
    """Test whether the given module can compute diag_ggn.

    This test is placed here because some models are too big to run with PyTorch.
    Thus, a full diag_ggn comparison with PyTorch is impossible.
    This test just checks whether it runs on BackPACK without errors.
    Additionally, it checks whether the forward pass is identical to the original model.
    Finally, a small number of elements of DiagGGN are compared.

    Args:
        model_and_input: module to test

    Raises:
        NotImplementedError: if loss_fn is not MSELoss or CrossEntropyLoss
    """
    model_original, x, loss_fn = model_and_input
    model_original = model_original.eval()
    output_compare = model_original(x)
    if isinstance(loss_fn, MSELoss):
        y = regression_targets(output_compare.shape)
    elif isinstance(loss_fn, CrossEntropyLoss):
        y = classification_targets(
            (output_compare.shape[0], *output_compare.shape[2:]),
            output_compare.shape[1],
        )
    else:
        raise NotImplementedError(f"test cannot handle loss_fn = {type(loss_fn)}")

    num_params = sum(p.numel() for p in model_original.parameters() if p.requires_grad)
    num_to_compare = 10
    idx_to_compare = linspace(0, num_params - 1, num_to_compare, dtype=int32)
    diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
        x, y, model_original, loss_fn, idx=idx_to_compare
    )

    model_extended = extend(model_original, use_converter=True, debug=True)
    output = model_extended(x)

    assert allclose(output, output_compare)

    loss = extend(loss_fn)(output, y)

    with backpack(DiagGGNExact()):
        loss.backward()

    diag_ggn_exact_vector = cat(
        [
            p.diag_ggn_exact.flatten()
            for p in model_extended.parameters()
            if p.requires_grad
        ]
    )

    for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
        assert allclose(element, diag_ggn_exact_vector[idx], atol=1e-5)
def test_graph_clear(problem) -> None:
    """Test that the graph is clear after a backward pass.

    More specifically, test that there are no saved quantities left over.

    Args:
        problem: problem consisting of inputs and model
    """
    inputs, model = problem
    extension = DiagGGNExact()
    outputs = extend(model)(inputs)
    loss = extend(MSELoss())(outputs, rand_like(outputs))
    with backpack(extension):
        loss.backward()

    # test that the dictionary of saved quantities is empty
    saved_quantities: dict = extension.saved_quantities._saved_quantities
    assert type(saved_quantities) is dict
    assert not saved_quantities
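# The `problem` fixture used above is not shown in this snippet. Below is a
# minimal sketch of what such a fixture could provide (hypothetical example;
# the real fixture in the test suite may build different models): random
# inputs plus a small fully-connected network.
from torch import rand
from torch.nn import Linear, ReLU, Sequential


def _example_problem():
    inputs = rand(8, 10)
    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    return inputs, model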
extend(net_DiagGGNE)
extend(loss_function)
opt_DiagGGNE = DiagGGNEOptimizer(
    net_DiagGGNE.parameters(), step_size=STEP_SIZE, damping=DAMPING
)

GGNE_l_his = []
GGNE_a_his = []
start = time.time()

for epoch in range(EPOCH):
    print('Epoch:', epoch)
    for step, (b_x, b_y) in enumerate(trainloader):
        b_x, b_y = b_x.to(DEVICE), b_y.to(DEVICE)
        output = net_DiagGGNE(b_x)
        loss = loss_function(output, b_y)
        with backpack(DiagGGNExact()):
            loss.backward()
        opt_DiagGGNE.step()
        GGNE_a_his.append(get_accuracy(output, b_y))
        GGNE_l_his.append(loss.item())

end = time.time()
GGNE_t = end - start

fig = plt.figure()
axes = [fig.add_subplot(1, 2, 1), fig.add_subplot(1, 2, 2)]
axes[0].plot(GGNE_l_his)
axes[0].set_title("Loss")
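# The DiagGGNEOptimizer used above is not defined in this snippet. Below is a
# minimal sketch of one plausible implementation, modeled on BackPACK's
# second-order optimization example: a damped Newton-like step that uses the
# `diag_ggn_exact` savefield as a diagonal preconditioner. The class name and
# its exact behavior are assumptions, not the original definition.
from torch.optim import Optimizer


class DiagGGNExactOptimizer(Optimizer):
    def __init__(self, parameters, step_size, damping):
        super().__init__(parameters, dict(step_size=step_size, damping=damping))

    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                # elementwise scaling of the gradient by the damped GGN diagonal
                step_direction = p.grad / (p.diag_ggn_exact + group["damping"])
                p.data.add_(step_direction, alpha=-group["step_size"])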
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
    print(".variance.shape: ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape: ", param.batch_l2.shape)

# %%
# Second order extensions
# --------------------------

# %%
# Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)

# %%
# Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

loss = lossfunc(model(X), y)
with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
    loss.backward()
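# Both the exact GGN diagonal and its Monte-Carlo estimate computed above are
# stored on the parameters, so their agreement can be checked directly. This
# check is not part of the original example; it only illustrates how the two
# savefields relate.
for name, param in model.named_parameters():
    rel_err = (param.diag_ggn_mc - param.diag_ggn_exact).norm() / (
        param.diag_ggn_exact.norm() + 1e-12
    )
    print(f"{name}: relative error of diag_ggn_mc vs diag_ggn_exact = {rel_err:.2e}")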
# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape: ", individual_loss.shape)
print("Mini-batch loss: ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
    BatchGrad(),
    Variance(),
    SumGradSquared(),
    BatchL2Grad(),
    DiagGGNExact(),
    DiagGGNMC(),
    KFAC(),
    KFLR(),
    KFRA(),
    DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
    print("\t.grad.shape: ", param.grad.shape)
    print("\t.grad_batch.shape: ", param.grad_batch.shape)
    print("\t.variance.shape: ", param.variance.shape)
    print("\t.sum_grad_squared.shape: ", param.sum_grad_squared.shape)
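# Sanity check (not in the original script): if the test problem's loss uses
# mean reduction, the mini-batch loss should equal the average of the
# individual losses up to numerical precision.
from torch import allclose

assert allclose(individual_loss.mean(), batch_loss)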
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        model = _MyCustomModule()
    else:
        raise NotImplementedError(
            f"problem={problem_string} but no test setting for this."
        )

    model = extend(model.to(device))
    lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device))
    loss = lossfunc(model(X), y)
    yield model, loss, problem_string


@mark.parametrize(
    "extension", [BatchGrad(), DiagGGNExact()], ids=["BatchGrad", "DiagGGNExact"]
)
def test_extension_hook_multiple_parameter_visits(
    problem, extension: BackpropExtension
):
    """Tests whether each parameter is visited exactly once.

    For those cases where parameters are visited more than once (e.g. custom
    containers), it tests that an error is raised.

    Furthermore, it is tested whether first order extensions run fine in either
    case, and second order extensions raise an error in the case of custom
    containers.

    Args:
        problem: test problem, consisting of model, loss, and problem_string
        extension: first or second order extension to test
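# Rough illustration (not taken from the test file) of how parameter visits
# could be counted with BackPACK's `extension_hook` argument, which calls a
# function on every module during the extended backward pass. The counter and
# hook below are hypothetical helpers; the actual test may count visits
# differently.
from collections import Counter

visit_counter = Counter()


def _count_visits(module):
    for param in module.parameters(recurse=False):
        visit_counter[id(param)] += 1


with backpack(extension, extension_hook=_count_visits):
    loss.backward()

# for standard (non-container) models, each parameter should be visited once
assert all(count == 1 for count in visit_counter.values())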
    p.requires_grad = False

match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
print(f"Forward pass of custom model matches TolstoiCharRNN? {match}")

if not match:
    raise AssertionError("Forward passes don't match.")

# %%
# We can :py:func:`extend <backpack.extend>` our model and the loss function to
# compute BackPACK extensions.

tolstoi_char_rnn_custom = extend(tolstoi_char_rnn_custom)
loss = loss_function(tolstoi_char_rnn_custom(x), y)

with backpack(BatchGrad(), DiagGGNExact()):
    loss.backward()

for name, param in tolstoi_char_rnn_custom.named_parameters():
    if param.requires_grad:
        print(
            name,
            param.shape,
            param.grad_batch.shape,
            param.diag_ggn_exact.shape,
        )

# %%
# Comparison of the GGN diagonal extension with :py:mod:`torch.autograd`:
#
# .. note::
def step(self, closure, clip_grad=False):
    # increment step
    self.state['step'] += 1

    # deterministic closure
    seed = time.time()

    def closure_deterministic(for_backtracking=False):
        with ut.random_seed_torch(int(seed)):
            return closure(for_backtracking)

    # get loss and compute gradients/second-order extensions
    loss = closure_deterministic()
    if self.base_opt == 'diag_hessian':
        with backpack(DiagHessian()):
            loss.backward()
    elif self.base_opt == 'diag_ggn_ex':
        with backpack(DiagGGNExact()):
            loss.backward()
    elif self.base_opt == 'diag_ggn_mc':
        with backpack(DiagGGNMC()):
            loss.backward()
    else:
        loss.backward()

    if clip_grad:
        torch.nn.utils.clip_grad_norm_(self.params, 0.25)

    # increment number of forward-backward calls
    self.state['n_forwards'] += 1
    self.state['n_backwards'] += 1

    # save the current parameters
    params_current = copy.deepcopy(self.params)
    grad_current = ut.get_grad_list(self.params)
    grad_norm = ut.compute_grad_norm(grad_current)

    # keep track of step
    if self.state['step'] % int(self.n_batches_per_epoch) == 1:
        self.state['step_size_avg'] = 0.

    # if grad_norm < 1e-6:
    #     return 0.

    # Gv options
    # =============
    # update gv
    if self.base_opt == 'diag_hessian':
        # get the diagonal Hessian and store it in state['gv']
        gv = [p.diag_h for p in self.params]
    elif self.base_opt == 'diag_ggn_ex':
        gv = [p.diag_ggn_exact for p in self.params]
    elif self.base_opt == 'diag_ggn_mc':
        gv = [p.diag_ggn_mc for p in self.params]
    else:
        raise ValueError('%s does not exist' % self.gv_update)

    for gv_i in gv:
        if torch.any(gv_i < 0):
            warnings.warn("%s contains negative values." % (self.gv_update))
            print(gv)

    if self.state['gv'] is None or self.accum_gv is None:
        self.state['gv'] = gv
        if self.accum_gv == 'avg':
            # first iteration
            self.state['gv_lag'] = [[gv_i] for gv_i in gv]
            self.state['gv_sum'] = gv

    elif self.accum_gv == 'max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = torch.max(gv_old, gv_new)

    elif self.accum_gv == 'sum':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            self.state['gv'][i] = gv_old + gv_new

    elif self.accum_gv == 'ams':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            gv_accum = self.beta * gv_old + (1 - self.beta) * gv_new
            self.state['gv'][i] = torch.max(gv_old, gv_accum)

    elif self.accum_gv == 'ams_no_max':
        for i, (gv_old, gv_new) in enumerate(zip(self.state['gv'], gv)):
            # same as above without the max
            gv_accum = self.beta * gv_old + (1 - self.beta) * gv_new
            self.state['gv'][i] = gv_accum

    elif self.accum_gv == 'avg':
        t = self.state['step']
        for i, (gv_new, gv_lag) in enumerate(zip(gv, self.state['gv_lag'])):
            if t < self.avg_window:
                # keep track of the sum only, no need to kick anyone out
                # for the first window number of iterations
                gv_sum = self.state['gv_sum'][i] + gv_new
                # take the average of all seen so far
                gv_accum = gv_sum / t
            else:
                # kick out the gv stored for iteration (t - window)
                gv_kick_out = gv_lag.pop(0)
                # update the sum for the past window iterations
                gv_sum = self.state['gv_sum'][i] + gv_new - gv_kick_out
                # take the average gv within the window
                gv_accum = gv_sum / self.avg_window

            # add the new gv to the lags and update sum
            self.state['gv_lag'][i].append(gv_new)
            self.state['gv_sum'][i] = gv_sum
            # this will be used as the diagonal preconditioner
            self.state['gv'][i] = gv_accum
    else:
        raise ValueError('accum_gv %s does not exist' % self.accum_gv)

    if self.lm > 0:
        for i in range(len(self.state['gv'])):
            self.state['gv'][i] += self.lm

    # compute pp norm method
    pp_norm = self.get_pp_norm(grad_current=grad_current)

    # compute step size - same as the SLS code but with a different norm pp_norm
    # =================
    step_size = self.get_step_size(
        closure_deterministic, loss, params_current, grad_current,
        grad_norm, pp_norm, for_backtracking=True
    )

    self.try_sgd_precond_update(self.params, step_size, params_current,
                                grad_current)

    # save the new step-size
    self.state['step_size'] = step_size
    self.state['step_size_avg'] += (step_size / int(self.n_batches_per_epoch))

    self.state['grad_norm'] = grad_norm.item()
    if torch.isnan(self.params[0]).sum() > 0:
        print('nan')

    return loss
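# `try_sgd_precond_update` is not shown here. Based on the comment above that
# `gv` serves as the diagonal preconditioner, a minimal sketch of what such an
# update could look like (hypothetical helper, not the original code):
import torch


def sgd_precond_update(params, step_size, params_current, grad_current, gv):
    """Set each parameter to p_current - step_size * g / gv_i, elementwise."""
    with torch.no_grad():
        for p, p_cur, g, gv_i in zip(params, params_current, grad_current, gv):
            p.copy_(p_cur - step_size * g / gv_i)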