def test_jacobian_pimplicit_vs_pdense():
    """Check that PMatImplicit agrees with PMatDense on trace, mv and vTMv."""
    for get_task in linear_tasks + nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()

        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_implicit = PMatImplicit(generator=generator,
                                     examples=loader)
        PMat_dense = PMatDense(generator=generator,
                               examples=loader)
        dw = random_pvector(lc, device=device)

        # mv and vTMv are expected to raise NotImplementedError for
        # layer collections containing BatchNorm layers
        layer_classes = [l.__class__.__name__ for l in lc.layers.values()]
        has_batchnorm = ('BatchNorm1dLayer' in layer_classes or
                         'BatchNorm2dLayer' in layer_classes)

        # Test trace
        check_ratio(PMat_dense.trace(),
                    PMat_implicit.trace())

        # Test mv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.mv(dw)
        else:
            check_tensors(PMat_dense.mv(dw).get_flat_representation(),
                          PMat_implicit.mv(dw).get_flat_representation())

        # Test vTMv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.vTMv(dw)
        else:
            check_ratio(PMat_dense.vTMv(dw),
                        PMat_implicit.vTMv(dw))

def test_jacobian_pdense_vs_pushforward():
    """Check that PMatDense coincides with J^T J / n computed explicitly
    from the dense push-forward Jacobian."""
    # NB: sometimes the test with centering=True does not pass,
    # which is probably due to the way we compute centering
    # for PMatDense: E[x^2] - (Ex)^2 is notoriously not numerically stable
    for get_task in linear_tasks + nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            push_forward = PushForwardDense(generator)
            pull_back = PullBackDense(generator, data=push_forward.data)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)
            n = len(loader.sampler)

            # Test get_dense_tensor
            jacobian = push_forward.get_dense_tensor()
            sj = jacobian.size()
            PMat_computed = torch.mm(jacobian.view(-1, sj[2]).t(),
                                     jacobian.view(-1, sj[2])) / n
            check_tensors(PMat_computed,
                          PMat_dense.get_dense_tensor())

            # Test vTMv
            vTMv_PMat = PMat_dense.vTMv(dw)
            Jv_pushforward = push_forward.mv(dw)
            Jv_pushforward_flat = Jv_pushforward.get_flat_representation()
            vTMv_pushforward = torch.dot(Jv_pushforward_flat.view(-1),
                                         Jv_pushforward_flat.view(-1)) / n
            check_ratio(vTMv_pushforward, vTMv_PMat)

            # Test Mv
            Mv_PMat = PMat_dense.mv(dw)
            Mv_pf_pb = pull_back.mv(Jv_pushforward)
            check_tensors(Mv_pf_pb.get_flat_representation() / n,
                          Mv_PMat.get_flat_representation())
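
# Illustration (not part of the original test suite): the NB above refers to
# the textbook cancellation issue with the naive variance formula
# E[x^2] - (E[x])^2. The helper below is a minimal sketch of that effect:
# in float32, once the mean is large compared to the spread, the naive form
# drifts far from the centered form E[(x - E[x])^2]. Its name and contents
# are illustrative assumptions, not NNGeometry API, and it is deliberately
# not collected by pytest.
def _centering_instability_sketch():
    x = (1e4 + torch.randn(10000)).float()   # large mean, small variance (~1)
    naive = (x**2).mean() - x.mean()**2       # E[x^2] - (E x)^2, cancellation-prone
    stable = ((x - x.mean())**2).mean()       # E[(x - E x)^2], numerically safer
    # `naive` can be wildly off (even negative) while `stable` stays near 1.
    return naive, stable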

def test_jacobian_pdense():
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)

            # Test get_diag
            check_tensors(torch.diag(PMat_dense.get_dense_tensor()),
                          PMat_dense.get_diag())

            # Test frobenius
            frob_PMat = PMat_dense.frobenius_norm()
            frob_direct = (PMat_dense.get_dense_tensor()**2).sum()**.5
            check_ratio(frob_direct, frob_PMat)

            # Test trace
            trace_PMat = PMat_dense.trace()
            trace_direct = torch.trace(PMat_dense.get_dense_tensor())
            check_ratio(trace_PMat, trace_direct)

            # Test solve
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            Mv_regul = torch.mv(PMat_dense.get_dense_tensor() +
                                regul * torch.eye(PMat_dense.size(0),
                                                  device=device),
                                dw.get_flat_representation())
            Mv_regul = PVector(layer_collection=lc,
                               vector_repr=Mv_regul)
            dw_using_inv = PMat_dense.solve(Mv_regul, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_using_inv.get_flat_representation(),
                          eps=5e-3)

            # Test inv
            PMat_inv = PMat_dense.inverse(regul=regul)
            check_tensors(dw.get_flat_representation(),
                          PMat_inv.mv(PMat_dense.mv(dw) + regul * dw)
                          .get_flat_representation(),
                          eps=5e-3)

            # Test add, sub, rmul
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense2 = PMatDense(generator)
            check_tensors(PMat_dense.get_dense_tensor() +
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense + PMat_dense2).get_dense_tensor())
            check_tensors(PMat_dense.get_dense_tensor() -
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense - PMat_dense2).get_dense_tensor())
            check_tensors(1.23 * PMat_dense.get_dense_tensor(),
                          (1.23 * PMat_dense).get_dense_tensor())
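
# Illustration (not part of the original test suite): the NB "regul is high
# since the matrix is not full rank" above reflects the fact that a Gram
# matrix built from fewer Jacobian rows than parameters is rank deficient,
# so solving M x = b is only well posed after adding regul * I. The helper
# below is a hedged sketch of that point on a toy singular matrix; it does
# not use any NNGeometry API and is not collected by pytest.
def _regularized_solve_sketch(regul=1e-3):
    # Rank-1, hence singular, 3x3 matrix.
    u = torch.tensor([1., 2., 3.])
    M = torch.outer(u, u)
    b = torch.tensor([1., 0., 0.])
    # Without the regul * I term the system is singular; with it, M + regul*I
    # is symmetric positive definite and the solve has a unique solution.
    return torch.linalg.solve(M + regul * torch.eye(3), b)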