def test_FIM_vs_linearization_classif_logits():
    """Check that the FIM (variant 'classif_logits') matches the curvature of
    the KL divergence between model outputs for a small parameter step."""
    step = 1e-2

    for get_task in nonlinear_tasks:
        quots = []
        for i in range(10):  # repeat to average out statistical fluctuations
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            F = FIM(layer_collection=lc,
                    model=model,
                    loader=loader,
                    variant='classif_logits',
                    representation=PMatDense,
                    n_output=n_output,
                    function=lambda *d: model(to_device(d[0])))

            # random direction in parameter space, rescaled to norm `step`
            dw = random_pvector(lc, device=device)
            dw = step / dw.norm() * dw

            # model outputs before and after the parameter update,
            # then restore the original parameters
            output_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            output_after = get_output_vector(loader, function)
            update_model(parameters, -dw.get_flat_representation())

            KL = tF.kl_div(tF.log_softmax(output_before, dim=1),
                           tF.log_softmax(output_after, dim=1),
                           log_target=True,
                           reduction='batchmean')

            quot = (2 * KL / F.vTMv(dw)) ** .5
            quots.append(quot.item())

        mean_quotient = sum(quots) / len(quots)
        assert 1 - 5e-2 < mean_quotient < 1 + 5e-2
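# Rationale for the test above: for a small parameter step dw, a second-order
# expansion of the KL divergence between the output distributions gives
#   KL(p_w || p_{w + dw}) ≈ 1/2 dw^T F(w) dw,
# where F is the Fisher information matrix. The ratio
# sqrt(2 KL / dw^T F dw) computed above should therefore be close to 1 when
# `step` is small, up to Monte Carlo noise over the random directions dw,
# which the averaging over 10 repetitions is meant to suppress.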
def test_FIM_vs_linearization_regression():
    """Check that the FIM (variant 'regression') matches the mean squared
    change in model output for a small parameter step."""
    step = 1e-2

    for get_task in nonlinear_tasks:
        quots = []
        for i in range(10):  # repeat to average out statistical fluctuations
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            F = FIM(layer_collection=lc,
                    model=model,
                    loader=loader,
                    variant='regression',
                    representation=PMatDense,
                    n_output=n_output,
                    function=lambda *d: model(to_device(d[0])))

            # random direction in parameter space, rescaled to norm `step`
            dw = random_pvector(lc, device=device)
            dw = step / dw.norm() * dw

            # model outputs before and after the parameter update,
            # then restore the original parameters
            output_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            output_after = get_output_vector(loader, function)
            update_model(parameters, -dw.get_flat_representation())

            # mean squared change in output across examples
            diff = (((output_before - output_after) ** 2).sum()
                    / output_before.size(0))

            quot = (diff / F.vTMv(dw)) ** .5
            quots.append(quot.item())

        mean_quotient = sum(quots) / len(quots)
        assert 1 - 5e-2 < mean_quotient < 1 + 5e-2
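# Rationale for the test above: with the 'regression' variant, F corresponds
# to a Gaussian output likelihood with unit variance, i.e. F = E_x[J_x^T J_x]
# with J_x the Jacobian of the network output w.r.t. the parameters. To first
# order, f_{w+dw}(x) - f_w(x) ≈ J_x dw, hence
#   E_x[ ||f_{w + dw}(x) - f_w(x)||^2 ] ≈ dw^T F(w) dw,
# so the ratio sqrt(diff / dw^T F dw) should be close to 1, up to
# linearization error and sampling noise over the random directions.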
def test_grad_dict_repr():
    """Check that NNGeometry's `grad` applied to a PVector agrees with the
    gradients gathered from the model's `.grad` attributes via
    `PVector.from_model_grad`."""
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    d, _ = next(iter(loader))
    scalar_output = model(to_device(d)).sum()
    vec = PVector.from_model(model)

    # gradient computed through NNGeometry's grad helper on the PVector
    grad_nng = grad(scalar_output, vec, retain_graph=True)

    # gradient computed through a standard backward pass
    scalar_output.backward()
    grad_direct = PVector.from_model_grad(model)

    check_tensors(grad_direct.get_flat_representation(),
                  grad_nng.get_flat_representation())
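# For reference, the same flat gradient can be recovered with plain PyTorch
# autograd; a minimal sketch (assuming `parameters` is the ordered list of
# trainable tensors tracked by the layer collection):
#
#   grads = torch.autograd.grad(scalar_output, parameters, retain_graph=True)
#   flat = torch.cat([g.contiguous().view(-1) for g in grads])
#
# which should match `grad_nng.get_flat_representation()` up to the parameter
# ordering convention used by the LayerCollection.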