def test_jacobian_pimplicit_vs_pdense():
    """Check that PMatImplicit agrees with PMatDense.

    For every task, compares trace, matrix-vector product (mv) and the
    quadratic form vTMv. For models containing BatchNorm layers, the
    implicit representation does not implement mv/vTMv and must raise
    NotImplementedError instead.
    """
    for get_task in linear_tasks + nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_implicit = PMatImplicit(generator=generator, examples=loader)
        PMat_dense = PMatDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)

        # Test trace
        check_ratio(PMat_dense.trace(), PMat_implicit.trace())

        # Compute the layer-class list once instead of four times
        # (the original rebuilt the comprehension for every membership test).
        layer_class_names = [l.__class__.__name__
                             for l in lc.layers.values()]
        has_batchnorm = ('BatchNorm1dLayer' in layer_class_names or
                         'BatchNorm2dLayer' in layer_class_names)

        # Test mv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.mv(dw)
        else:
            check_tensors(
                PMat_dense.mv(dw).get_flat_representation(),
                PMat_implicit.mv(dw).get_flat_representation())

        # Test vTMv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.vTMv(dw)
        else:
            check_ratio(PMat_dense.vTMv(dw), PMat_implicit.vTMv(dw))
def test_jacobian_pdense_vs_pushforward():
    """Check PMatDense against an explicit jacobian push-forward/pull-back.

    Verifies that the dense matrix equals J^T J / n, that vTMv equals
    |J dw|^2 / n, and that Mv equals J^T (J dw) / n.

    NB: the centering=True case occasionally fails, which is probably due
    to the way centering is computed for PMatDense:
    E[x^2] - (Ex)^2 is notoriously not numerically stable.
    """
    for get_task in linear_tasks + nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            push_forward = PushForwardDense(generator)
            pull_back = PullBackDense(generator, data=push_forward.data)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)
            n = len(loader.sampler)

            # Test get_dense_tensor: M == J^T J / n
            J = push_forward.get_dense_tensor()
            J_flat = J.view(-1, J.size(2))
            check_tensors(torch.mm(J_flat.t(), J_flat) / n,
                          PMat_dense.get_dense_tensor())

            # Test vTMv: dw^T M dw == |J dw|^2 / n
            Jdw = push_forward.mv(dw)
            Jdw_flat = Jdw.get_flat_representation().view(-1)
            check_ratio(torch.dot(Jdw_flat, Jdw_flat) / n,
                        PMat_dense.vTMv(dw))

            # Test Mv: M dw == J^T (J dw) / n
            check_tensors(
                pull_back.mv(Jdw).get_flat_representation() / n,
                PMat_dense.mv(dw).get_flat_representation())