def test_dot():
    """Compare the vectors' ``dot`` with ``torch.dot`` on flat tensors.

    Three operand pairings are exercised: two vectors built by
    ``random_pvector``, two built by ``random_pvector_dict``, and one of
    each (with ``dot`` called in both directions).
    """
    lc = LayerCollection.from_model(ConvNet())

    # Homogeneous pairs: both operands come from the same factory.
    for make_vector in (random_pvector, random_pvector_dict):
        left = make_vector(lc)
        right = make_vector(lc)
        got = left.dot(right)
        expected = torch.dot(left.get_flat_representation(),
                             right.get_flat_representation())
        check_ratio(expected, got)

    # Mixed pair: the result must agree whichever operand dot is called on.
    left = random_pvector(lc)
    right = random_pvector_dict(lc)
    got_left_right = left.dot(right)
    got_right_left = right.dot(left)
    for got in (got_left_right, got_right_left):
        expected = torch.dot(left.get_flat_representation(),
                             right.get_flat_representation())
        check_ratio(expected, got)
def test_jacobian_pquasidiag():
    """Check PMatQuasiDiag operations against its own dense tensor.

    For the conv and fully-connected tasks, the diagonal, Frobenius norm,
    trace, ``mv``, ``vTMv`` and ``solve`` results are compared with the
    same quantities computed directly from ``get_dense_tensor()``.
    """
    for get_task in [get_conv_task, get_fullyconnect_task]:
        loader, lc, _, model, function, n_output = get_task()
        model.train()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        qd = PMatQuasiDiag(jacobian)
        dense = qd.get_dense_tensor()

        v = random_pvector(lc, device=device)
        v_flat = v.get_flat_representation()

        # Scalar / diagonal summaries
        check_tensors(torch.diag(dense), qd.get_diag())
        check_ratio(torch.norm(dense), qd.frobenius_norm())
        check_ratio(torch.trace(dense), qd.trace())

        # Matrix-vector product and quadratic form
        mv = qd.mv(v)
        check_tensors(torch.mv(dense, v_flat),
                      mv.get_flat_representation())
        quad_direct = torch.dot(torch.mv(dense, v_flat), v_flat)
        check_ratio(quad_direct, qd.vTMv(v))

        # Test solve: (M + regul * I) v must be mapped back to v
        regul = 1e-8
        v_back = qd.solve(mv + regul * v, regul=regul)
        check_tensors(v.get_flat_representation(),
                      v_back.get_flat_representation())
def test_jacobian_pimplicit_vs_pdense():
    """Compare PMatImplicit with PMatDense on every linear/nonlinear task.

    The trace is always compared. For ``mv`` and ``vTMv``, when the layer
    collection contains a ``BatchNorm1dLayer`` or ``BatchNorm2dLayer`` the
    implicit representation is expected to raise ``NotImplementedError``;
    otherwise its result is compared with the dense one.
    """
    batchnorm_layer_names = {'BatchNorm1dLayer', 'BatchNorm2dLayer'}

    for get_task in linear_tasks + nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_implicit = PMatImplicit(generator=generator, examples=loader)
        PMat_dense = PMatDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)

        # Computed once per task: both the mv and vTMv branches below
        # depend on the same condition.
        has_batchnorm = any(
            layer.__class__.__name__ in batchnorm_layer_names
            for layer in lc.layers.values())

        # Test trace
        check_ratio(PMat_dense.trace(), PMat_implicit.trace())

        # Test mv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.mv(dw)
        else:
            check_tensors(
                PMat_dense.mv(dw).get_flat_representation(),
                PMat_implicit.mv(dw).get_flat_representation())

        # Test vTMv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.vTMv(dw)
        else:
            check_ratio(PMat_dense.vTMv(dw), PMat_implicit.vTMv(dw))
def test_FIM_vs_linearization_classif_logits():
    """Compare a measured KL divergence with the FIM quadratic form.

    For each nonlinear task, ten fresh task instances are drawn. Each one
    perturbs the parameters by a random direction of norm ``step``,
    measures the KL divergence between the log-softmax outputs before and
    after, and records ``sqrt(2 * KL / vTMv)`` with the
    ``'classif_logits'`` FIM. The mean of these quotients must lie
    within 5% of 1.
    """
    step = 1e-2
    n_repeats = 10  # repeat to kill statistical fluctuations
    tolerance = 5e-2

    for get_task in nonlinear_tasks:
        quotients = []
        for _ in range(n_repeats):
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            fim = FIM(layer_collection=lc,
                      model=model,
                      loader=loader,
                      variant='classif_logits',
                      representation=PMatDense,
                      n_output=n_output,
                      function=lambda *d: model(to_device(d[0])))

            # Random direction rescaled to have norm `step`
            dw = random_pvector(lc, device=device)
            dw = step / dw.norm() * dw

            # Apply the perturbation, record outputs, then undo it
            output_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            output_after = get_output_vector(loader, function)
            update_model(parameters, -dw.get_flat_representation())

            log_probs_before = tF.log_softmax(output_before, dim=1)
            log_probs_after = tF.log_softmax(output_after, dim=1)
            kl = tF.kl_div(log_probs_before,
                           log_probs_after,
                           log_target=True,
                           reduction='batchmean')

            quotient = (kl / fim.vTMv(dw) * 2)**.5
            quotients.append(quotient.item())

        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance
def test_FIM_vs_linearization_regression():
    """Compare a measured squared output change with the FIM quadratic form.

    For each nonlinear task, ten fresh task instances are drawn. Each one
    perturbs the parameters by a random direction of norm ``step``,
    computes the summed squared output difference divided by the number
    of output rows, and records ``sqrt(diff / vTMv)`` with the
    ``'regression'`` FIM. The mean of these quotients must lie within 5%
    of 1.
    """
    step = 1e-2
    n_repeats = 10  # repeat to kill statistical fluctuations
    tolerance = 5e-2

    for get_task in nonlinear_tasks:
        quotients = []
        for _ in range(n_repeats):
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            fim = FIM(layer_collection=lc,
                      model=model,
                      loader=loader,
                      variant='regression',
                      representation=PMatDense,
                      n_output=n_output,
                      function=lambda *d: model(to_device(d[0])))

            # Random direction rescaled to have norm `step`
            dw = random_pvector(lc, device=device)
            dw = step / dw.norm() * dw

            # Apply the perturbation, record outputs, then undo it
            output_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            output_after = get_output_vector(loader, function)
            update_model(parameters, -dw.get_flat_representation())

            squared_change = ((output_before - output_after)**2).sum()
            diff = squared_change / output_before.size(0)

            quotient = (diff / fim.vTMv(dw))**.5
            quotients.append(quotient.item())

        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance
def test_grad_flat_repr():
    """Expect ``grad`` of a random vector's norm w.r.t. it to raise.

    A vector is drawn for the layer collection of the conv/group-norm
    task; calling ``grad`` on its norm with the vector itself as the
    second argument must raise ``RuntimeError``.
    """
    _loader, lc, _parameters, _model, _function, _n_output = \
        get_conv_gn_task()
    vec = random_pvector(lc)
    norm_of_vec = vec.norm()
    with pytest.raises(RuntimeError):
        grad(norm_of_vec, vec)
def test_jacobian_kfac():
    """Check PMatKFAC operations against its dense tensors.

    For the fully-connected and conv tasks, the trace, Frobenius norm,
    diagonal, ``mv``, ``vTMv``, ``inverse`` and ``solve`` results are
    compared with values computed from ``get_dense_tensor`` (with and
    without ``split_weight_bias``).
    """
    for get_task in [get_fullyconnect_task, get_conv_task]:
        loader, lc, _, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        kfac = PMatKFAC(jacobian)
        dense_split = kfac.get_dense_tensor(split_weight_bias=True)
        dense_joint = kfac.get_dense_tensor(split_weight_bias=False)

        # Test trace
        check_ratio(torch.trace(dense_split), kfac.trace())

        # Test frobenius norm
        check_ratio(torch.norm(dense_joint), kfac.frobenius_norm())

        # Test get_diag
        check_tensors(torch.diag(dense_split),
                      kfac.get_diag(split_weight_bias=True))

        # sample random vector
        v = random_pvector(lc, device)
        v_flat = v.get_flat_representation()

        # Test mv
        mv_direct = torch.mv(dense_split, v_flat)
        mv_kfac = kfac.mv(v)
        check_tensors(mv_direct, mv_kfac.get_flat_representation())

        # Test vTMv
        quad_kfac = kfac.vTMv(v)
        quad_direct = torch.dot(mv_direct, v.get_flat_representation())
        check_ratio(quad_direct, quad_kfac)

        # Test inverse
        # We start from a mv vector since it kills its components
        # projected to the small eigenvalues of KFAC
        regul = 1e-7
        mv_twice = kfac.mv(mv_kfac)
        kfac_inv = kfac.inverse(regul)
        recovered = kfac_inv.mv(mv_twice + regul * mv_kfac)
        check_tensors(mv_kfac.get_flat_representation(),
                      recovered.get_flat_representation(),
                      eps=1e-2)

        # Test solve
        recovered = kfac.solve(mv_twice + regul * mv_kfac, regul=regul)
        check_tensors(mv_kfac.get_flat_representation(),
                      recovered.get_flat_representation(),
                      eps=1e-2)
def test_norm():
    """Compare ``norm()`` with ``torch.norm`` of the flat representation.

    Checked for a vector from ``random_pvector`` and then for one from
    ``random_pvector_dict``.
    """
    lc = LayerCollection.from_model(ConvNet())
    for make_vector in (random_pvector, random_pvector_dict):
        vec = make_vector(lc)
        expected = torch.norm(vec.get_flat_representation())
        check_ratio(expected, vec.norm())
def test_size():
    """Assert ``size()`` equals the size of the flat representation.

    Checked for a vector from ``random_pvector`` and then for one from
    ``random_pvector_dict``.
    """
    lc = LayerCollection.from_model(ConvNet())
    for make_vector in (random_pvector, random_pvector_dict):
        vec = make_vector(lc)
        assert vec.size() == vec.get_flat_representation().size()
def test_pspace_ekfac_vs_direct():
    """
    Check EKFAC basis operations against direct computation using
    get_dense_tensor.

    vTMv, trace, Frobenius norm, mv, inverse, solve and scalar
    multiplication are each checked twice per task: once before and once
    after ``update_diag`` has been called.
    """
    for get_task in [get_fullyconnect_task, get_conv_task]:
        loader, lc, _, model, function, n_output = get_task()
        model.train()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        ekfac = PMatEKFAC(jacobian)
        v = random_pvector(lc, device=device)

        # the second time we will have called update_diag
        for _ in range(2):
            # vTMv
            dense_times_v = torch.mv(ekfac.get_dense_tensor(),
                                     v.get_flat_representation())
            quad_direct = torch.dot(dense_times_v,
                                    v.get_flat_representation())
            quad_ekfac = ekfac.vTMv(v)
            check_ratio(quad_direct, quad_ekfac)

            # trace
            trace_ekfac = ekfac.trace()
            trace_direct = torch.trace(ekfac.get_dense_tensor())
            check_ratio(trace_direct, trace_ekfac)

            # Frobenius norm
            frob_ekfac = ekfac.frobenius_norm()
            frob_direct = torch.norm(ekfac.get_dense_tensor())
            check_ratio(frob_direct, frob_ekfac)

            # mv
            mv_direct = torch.mv(ekfac.get_dense_tensor(),
                                 v.get_flat_representation())
            mv_ekfac = ekfac.mv(v)
            check_tensors(mv_direct, mv_ekfac.get_flat_representation())

            # Test inverse
            regul = 1e-5
            ekfac_inv = ekfac.inverse(regul=regul)
            v_back = ekfac_inv.mv(mv_ekfac + regul * v)
            check_tensors(v.get_flat_representation(),
                          v_back.get_flat_representation())

            # Test solve
            v_back = ekfac.solve(mv_ekfac + regul * v, regul=regul)
            check_tensors(v.get_flat_representation(),
                          v_back.get_flat_representation())

            # Test rmul
            scaled = 1.23 * ekfac
            check_tensors(1.23 * ekfac.get_dense_tensor(),
                          scaled.get_dense_tensor())

            ekfac.update_diag()
def test_jacobian_plowrank():
    """Check PMatLowRank (built with ``examples=loader``) against its
    dense tensor.

    For each nonlinear task, the diagonal, Frobenius norm, trace, ``mv``,
    ``vTMv``, ``solve`` and scalar multiplication are compared with
    values computed from ``get_dense_tensor()``.

    NOTE(review): another function with this exact name is defined
    further down in this file; in Python the later definition replaces
    this one at import time, so this test would not be collected.
    Confirm, and rename one of them if both are meant to run.
    """
    for get_task in nonlinear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            function=function,
                            n_output=n_output)
        lowrank = PMatLowRank(generator=jacobian, examples=loader)

        # Random direction, normalised to unit norm
        dw = random_pvector(lc, device=device)
        dw = dw / dw.norm()

        dense = lowrank.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense), lowrank.get_diag(), eps=1e-4)

        # Test frobenius
        frob_lowrank = lowrank.frobenius_norm()
        frob_direct = (dense**2).sum()**.5
        check_ratio(frob_direct, frob_lowrank)

        # Test trace
        trace_lowrank = lowrank.trace()
        trace_direct = torch.trace(dense)
        check_ratio(trace_lowrank, trace_direct)

        # Test mv
        mv_direct = torch.mv(dense, dw.get_flat_representation())
        mv = lowrank.mv(dw)
        check_tensors(mv_direct, mv.get_flat_representation())

        # Test vTMV
        quad_direct = torch.dot(mv_direct, dw.get_flat_representation())
        check_ratio(quad_direct, lowrank.vTMv(dw))

        # Test solve
        # We will try to recover mv, which is in the span of the
        # low rank matrix
        regul = 1e-3
        mv_twice = lowrank.mv(mv)
        mv_recovered = lowrank.solve(mv_twice, regul=regul)
        check_tensors(mv.get_flat_representation(),
                      mv_recovered.get_flat_representation(),
                      eps=1e-2)

        # Test inv TODO

        # Test add, sub, rmul
        check_tensors(1.23 * lowrank.get_dense_tensor(),
                      (1.23 * lowrank).get_dense_tensor())
def test_sub():
    """Check that vector subtraction matches subtraction of flat tensors.

    Covers two ``random_pvector`` operands, two ``random_pvector_dict``
    operands, and a mixed pair; in each case the norm of the discrepancy
    must be below 1e-5.
    """
    lc = LayerCollection.from_model(ConvNet())

    # Homogeneous pairs: both operands come from the same factory.
    for make_vector in (random_pvector, random_pvector_dict):
        left = make_vector(lc)
        right = make_vector(lc)
        difference = left - right
        expected = (left.get_flat_representation() -
                    right.get_flat_representation())
        assert torch.norm(
            difference.get_flat_representation() - expected) < 1e-5

    # Mixed pair: random_pvector minus random_pvector_dict.
    left = random_pvector(lc)
    right = random_pvector_dict(lc)
    difference = left - right
    expected = (left.get_flat_representation() -
                right.get_flat_representation())
    assert torch.norm(
        difference.get_flat_representation() - expected) < 1e-5
def test_jacobian_pullback_dense():
    """Check PullBackDense against PushForwardDense on the linear tasks.

    For a random vector ``dw``, the dot product of ``dw`` with
    ``pull_back.mv(push_forward.mv(dw))`` must match the squared norm of
    ``push_forward.mv(dw)``.
    """
    for get_task in linear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        pull_back = PullBackDense(jacobian)
        push_forward = PushForwardDense(jacobian)

        dw = random_pvector(lc, device=device)
        pushed = push_forward.mv(dw)
        pulled = pull_back.mv(pushed)

        lhs = torch.dot(dw.get_flat_representation(),
                        pulled.get_flat_representation())
        rhs = torch.norm(pushed.get_flat_representation())**2
        check_ratio(lhs, rhs)
def test_jacobian_pushforward_implicit():
    """Check that PushForwardImplicit.mv matches PushForwardDense.mv.

    Both are built from the same generator on each linear task and
    applied to the same random vector.
    """
    for get_task in linear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        dense_pf = PushForwardDense(jacobian)
        implicit_pf = PushForwardImplicit(jacobian)

        dw = random_pvector(lc, device=device)
        out_dense = dense_pf.mv(dw)
        out_implicit = implicit_pf.mv(dw)

        check_tensors(out_dense.get_flat_representation(),
                      out_implicit.get_flat_representation())
def test_jacobian_pushforward_dense_linear():
    """Compare PushForwardDense.mv with the actual output change.

    On each linear task, the parameters are updated by a random vector
    ``dw`` and the resulting difference in outputs is compared with the
    transposed flat representation of ``push_forward.mv(dw)``.
    """
    for get_task in linear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            function=function,
                            n_output=n_output)
        push_forward = PushForwardDense(generator=jacobian,
                                        examples=loader)

        dw = random_pvector(lc, device=device)
        predicted_change = push_forward.mv(dw)

        # Measure the outputs on either side of the parameter update
        out_before = get_output_vector(loader, function)
        update_model(parameters, dw.get_flat_representation())
        out_after = get_output_vector(loader, function)

        check_tensors(out_after - out_before,
                      predicted_change.get_flat_representation().t())
def test_jacobian_plowrank():
    """Check PMatLowRank (built from a loader-bound generator) against
    its dense tensor.

    For each nonlinear task, the diagonal, Frobenius norm, trace, ``mv``,
    ``vTMv`` and scalar multiplication are compared with values computed
    from ``get_dense_tensor()``.

    NOTE(review): an earlier function in this file has this exact name;
    this later definition replaces it at import time. Confirm, and rename
    one of them if both are meant to run.
    """
    for get_task in nonlinear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        model.train()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        lowrank = PMatLowRank(jacobian)
        dw = random_pvector(lc, device=device)
        dense = lowrank.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense), lowrank.get_diag(), eps=1e-4)

        # Test frobenius
        frob_lowrank = lowrank.frobenius_norm()
        frob_direct = (dense**2).sum()**.5
        check_ratio(frob_direct, frob_lowrank)

        # Test trace
        trace_lowrank = lowrank.trace()
        trace_direct = torch.trace(dense)
        check_ratio(trace_lowrank, trace_direct)

        # Test mv
        mv_direct = torch.mv(dense, dw.get_flat_representation())
        check_tensors(mv_direct,
                      lowrank.mv(dw).get_flat_representation())

        # Test vTMV
        quad_direct = torch.dot(mv_direct, dw.get_flat_representation())
        check_ratio(quad_direct, lowrank.vTMv(dw))

        # Test solve TODO
        # Test inv TODO

        # Test add, sub, rmul
        check_tensors(1.23 * lowrank.get_dense_tensor(),
                      (1.23 * lowrank).get_dense_tensor())
def test_jacobian_pdense_vs_pushforward():
    """Compare PMatDense with quantities rebuilt from PushForwardDense.

    For every task and for ``centering`` in (True, False), the dense
    matrix, ``vTMv`` and ``mv`` from PMatDense are compared with the same
    quantities assembled from the push-forward (and pull-back) objects,
    divided by ``n = len(loader.sampler)``.
    """
    # NB: sometimes the tests with centering=True do not pass,
    # which is probably due to the way we compute centering
    # for PMatDense: E[x^2] - (Ex)^2 is notoriously not numerically stable
    for get_task in linear_tasks + nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, _, model, function, n_output = get_task()
            model.train()
            jacobian_gen = Jacobian(layer_collection=lc,
                                    model=model,
                                    loader=loader,
                                    function=function,
                                    n_output=n_output,
                                    centering=centering)
            push_forward = PushForwardDense(jacobian_gen)
            pull_back = PullBackDense(jacobian_gen, data=push_forward.data)
            dense = PMatDense(jacobian_gen)
            dw = random_pvector(lc, device=device)
            n = len(loader.sampler)

            # Test get_dense_tensor
            jac = push_forward.get_dense_tensor()
            n_params = jac.size()[2]
            rebuilt = torch.mm(jac.view(-1, n_params).t(),
                               jac.view(-1, n_params)) / n
            check_tensors(rebuilt, dense.get_dense_tensor())

            # Test vTMv
            quad_dense = dense.vTMv(dw)
            pushed = push_forward.mv(dw)
            pushed_flat = pushed.get_flat_representation()
            quad_pushed = torch.dot(pushed_flat.view(-1),
                                    pushed_flat.view(-1)) / n
            check_ratio(quad_pushed, quad_dense)

            # Test Mv
            mv_dense = dense.mv(dw)
            mv_round_trip = pull_back.mv(pushed)
            check_tensors(mv_round_trip.get_flat_representation() / n,
                          mv_dense.get_flat_representation())
def test_jacobian_pushforward_dense_nonlinear():
    """Compare PushForwardDense.mv with a finite-difference output change.

    On each nonlinear task, the parameters are updated by a random vector
    rescaled to norm 1e-4 and the resulting difference in outputs is
    compared with the transposed flat representation of
    ``push_forward.mv(dw)``, using a loosened tolerance.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            function=function,
                            n_output=n_output)
        push_forward = PushForwardDense(generator=jacobian,
                                        examples=loader)

        # Small step so the finite difference stays meaningful
        dw = random_pvector(lc, device=device)
        dw = 1e-4 / dw.norm() * dw
        predicted_change = push_forward.mv(dw)

        out_before = get_output_vector(loader, function)
        update_model(parameters, dw.get_flat_representation())
        out_after = get_output_vector(loader, function)

        # This is non linear, so we don't expect the finite difference
        # estimate to be very accurate. We use a larger eps value
        check_tensors(out_after - out_before,
                      predicted_change.get_flat_representation().t(),
                      eps=5e-2)
results[model].append([]) F = FIM_MonteCarlo(m, smalltestloader, PMatImplicit, trials=5, device='cuda') # G = FMatDense(F.generator) # frob = torch.norm(G.sum(dim=(0, 2))) / n_samples tr = F.trace() for j in range(4): results[model][-1].append(dict()) v = random_pvector(F.generator.layer_collection, device='cuda') v_flat = v.get_flat_representation() Fv = F.mv(v) Fv_flat = Fv.get_flat_representation() vTMv = F.vTMv(v) for repr in [PMatDiag, PMatQuasiDiag, PMatKFAC, PMatEKFAC]: results[model][-1][-1][repr] = dict() F2 = repr(F.generator) if repr == PMatEKFAC: F2.update_diag() tr_repr = F2.trace() results[model][-1][-1][repr]['trace'] = torch.abs( (tr_repr - tr) / tr).item()
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) trainset = datasets.CIFAR10(root='/tmp/data', train=True, download=True, transform=transform) trainset = Subset(trainset, range(100)) trainloader = DataLoader(trainset, batch_size=50, shuffle=False, num_workers=1) # %% from resnet import ResNet50 resnet = ResNet50().cuda() layer_collection = LayerCollection.from_model(resnet) v = random_pvector(LayerCollection.from_model(resnet), device='cuda') print(f'{layer_collection.numel()} parameters') # %% # compute timings and display FIMs def perform_timing(): timings = dict() for repr in [PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]: print('Timing representation:') pprint.pprint(repr)
def test_jacobian_pblockdiag():
    """Check PMatBlockDiag operations against its dense tensor.

    For each nonlinear task, the diagonal, Frobenius norm, trace, ``mv``,
    ``vTMv``, ``solve``, ``inverse``, addition, subtraction and scalar
    multiplication are compared with values computed from
    ``get_dense_tensor()``.
    """
    for get_task in nonlinear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        model.train()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        blockdiag = PMatBlockDiag(jacobian)
        dw = random_pvector(lc, device=device)
        dense = blockdiag.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense), blockdiag.get_diag())

        # Test frobenius
        frob_blockdiag = blockdiag.frobenius_norm()
        frob_direct = (dense**2).sum()**.5
        check_ratio(frob_direct, frob_blockdiag)

        # Test trace
        trace_blockdiag = blockdiag.trace()
        trace_direct = torch.trace(dense)
        check_ratio(trace_blockdiag, trace_direct)

        # Test mv
        mv_direct = torch.mv(dense, dw.get_flat_representation())
        check_tensors(mv_direct,
                      blockdiag.mv(dw).get_flat_representation())

        # Test vTMV
        quad_direct = torch.dot(mv_direct, dw.get_flat_representation())
        check_ratio(quad_direct, blockdiag.vTMv(dw))

        # Test solve: build (M + regul * I) dw, then solve back to dw
        regul = 1e-3
        identity = torch.eye(blockdiag.size(0), device=device)
        rhs_flat = torch.mv(dense + regul * identity,
                            dw.get_flat_representation())
        rhs = PVector(layer_collection=lc, vector_repr=rhs_flat)
        dw_solved = blockdiag.solve(rhs, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_solved.get_flat_representation(),
                      eps=5e-3)

        # Test inv
        blockdiag_inv = blockdiag.inverse(regul=regul)
        dw_inverted = blockdiag_inv.mv(blockdiag.mv(dw) + regul * dw)
        check_tensors(dw.get_flat_representation(),
                      dw_inverted.get_flat_representation(),
                      eps=5e-3)

        # Test add, sub, rmul (second operand from a fresh task instance)
        loader2, lc2, _, model2, function2, n_output2 = get_task()
        model2.train()
        jacobian2 = Jacobian(layer_collection=lc2,
                             model=model2,
                             loader=loader2,
                             function=function2,
                             n_output=n_output2)
        blockdiag2 = PMatBlockDiag(jacobian2)

        check_tensors(blockdiag.get_dense_tensor() +
                      blockdiag2.get_dense_tensor(),
                      (blockdiag + blockdiag2).get_dense_tensor())
        check_tensors(blockdiag.get_dense_tensor() -
                      blockdiag2.get_dense_tensor(),
                      (blockdiag - blockdiag2).get_dense_tensor())
        check_tensors(1.23 * blockdiag.get_dense_tensor(),
                      (1.23 * blockdiag).get_dense_tensor())
def test_jacobian_pdiag_vs_pdense():
    """Check PMatDiag against PMatDense and against its own dense tensor.

    For each nonlinear task: the diagonal of PMatDiag's dense tensor must
    match PMatDense's diagonal and the off-diagonal part must vanish;
    trace, Frobenius norm, ``mv``, ``vTMv``, ``inverse``, ``solve``,
    ``get_diag``, addition, subtraction and scalar multiplication are
    then checked against the dense tensor.
    """
    for get_task in nonlinear_tasks:
        loader, lc, _, model, function, n_output = get_task()
        model.train()
        jacobian = Jacobian(layer_collection=lc,
                            model=model,
                            loader=loader,
                            function=function,
                            n_output=n_output)
        pdiag = PMatDiag(jacobian)
        pdense = PMatDense(jacobian)
        dw = random_pvector(lc, device=device)

        # Test get_dense_tensor
        diag_matrix = pdiag.get_dense_tensor()
        dense_matrix = pdense.get_dense_tensor()
        check_tensors(torch.diag(diag_matrix), torch.diag(dense_matrix))
        off_diagonal = diag_matrix - torch.diag(torch.diag(diag_matrix))
        assert torch.norm(off_diagonal) < 1e-5

        # Test trace
        check_ratio(torch.trace(diag_matrix), pdiag.trace())

        # Test frobenius
        check_ratio(torch.norm(diag_matrix), pdiag.frobenius_norm())

        # Test mv
        mv_direct = torch.mv(diag_matrix, dw.get_flat_representation())
        mv_pdiag = pdiag.mv(dw)
        check_tensors(mv_direct, mv_pdiag.get_flat_representation())

        # Test vTMv
        quad_direct = torch.dot(mv_direct, dw.get_flat_representation())
        quad_pdiag = pdiag.vTMv(dw)
        check_ratio(quad_direct, quad_pdiag)

        # Test inverse: (M + regul * I) @ inverse must be the identity
        regul = 1e-3
        pdiag_inv = pdiag.inverse(regul)
        product = torch.mm(
            diag_matrix + regul * torch.eye(lc.numel(), device=device),
            pdiag_inv.get_dense_tensor())
        check_tensors(torch.eye(lc.numel(), device=device), product)

        # Test solve: build (M + regul * I) dw, then solve back to dw
        regul = 1e-3
        identity = torch.eye(pdiag.size(0), device=device)
        rhs_flat = torch.mv(diag_matrix + regul * identity,
                            dw.get_flat_representation())
        rhs = PVector(layer_collection=lc, vector_repr=rhs_flat)
        dw_solved = pdiag.solve(rhs, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_solved.get_flat_representation(),
                      eps=5e-3)

        # Test get_diag
        diag_direct = torch.diag(diag_matrix)
        diag_pdiag = pdiag.get_diag()
        check_tensors(diag_direct, diag_pdiag)

        # Test add, sub, rmul (second operand from a fresh task instance)
        loader2, lc2, _, model2, function2, n_output2 = get_task()
        model2.train()
        jacobian2 = Jacobian(layer_collection=lc2,
                             model=model2,
                             loader=loader2,
                             function=function2,
                             n_output=n_output2)
        pdiag2 = PMatDiag(jacobian2)

        check_tensors(pdiag.get_dense_tensor() +
                      pdiag2.get_dense_tensor(),
                      (pdiag + pdiag2).get_dense_tensor())
        check_tensors(pdiag.get_dense_tensor() -
                      pdiag2.get_dense_tensor(),
                      (pdiag - pdiag2).get_dense_tensor())
        check_tensors(1.23 * pdiag.get_dense_tensor(),
                      (1.23 * pdiag).get_dense_tensor())
def test_jacobian_pdense():
    """Check PMatDense operations against its own dense tensor.

    For each nonlinear task and for ``centering`` in (True, False), the
    diagonal, Frobenius norm, trace, ``solve``, ``inverse``, addition,
    subtraction and scalar multiplication are compared with values
    computed from ``get_dense_tensor()``.
    """
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, _, model, function, n_output = get_task()
            model.train()
            jacobian = Jacobian(layer_collection=lc,
                                model=model,
                                loader=loader,
                                function=function,
                                n_output=n_output,
                                centering=centering)
            pdense = PMatDense(jacobian)
            dw = random_pvector(lc, device=device)

            # Test get_diag
            check_tensors(torch.diag(pdense.get_dense_tensor()),
                          pdense.get_diag())

            # Test frobenius
            frob_pdense = pdense.frobenius_norm()
            frob_direct = (pdense.get_dense_tensor()**2).sum()**.5
            check_ratio(frob_direct, frob_pdense)

            # Test trace
            trace_pdense = pdense.trace()
            trace_direct = torch.trace(pdense.get_dense_tensor())
            check_ratio(trace_pdense, trace_direct)

            # Test solve
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            regularized = (pdense.get_dense_tensor() +
                           regul * torch.eye(pdense.size(0),
                                             device=device))
            rhs_flat = torch.mv(regularized,
                                dw.get_flat_representation())
            rhs = PVector(layer_collection=lc, vector_repr=rhs_flat)
            dw_solved = pdense.solve(rhs, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_solved.get_flat_representation(),
                          eps=5e-3)

            # Test inv
            pdense_inv = pdense.inverse(regul=regul)
            dw_inverted = pdense_inv.mv(pdense.mv(dw) + regul * dw)
            check_tensors(dw.get_flat_representation(),
                          dw_inverted.get_flat_representation(),
                          eps=5e-3)

            # Test add, sub, rmul (second operand from a fresh task)
            loader2, lc2, _, model2, function2, n_output2 = get_task()
            model2.train()
            jacobian2 = Jacobian(layer_collection=lc2,
                                 model=model2,
                                 loader=loader2,
                                 function=function2,
                                 n_output=n_output2,
                                 centering=centering)
            pdense2 = PMatDense(jacobian2)

            check_tensors(pdense.get_dense_tensor() +
                          pdense2.get_dense_tensor(),
                          (pdense + pdense2).get_dense_tensor())
            check_tensors(pdense.get_dense_tensor() -
                          pdense2.get_dense_tensor(),
                          (pdense - pdense2).get_dense_tensor())
            check_tensors(1.23 * pdense.get_dense_tensor(),
                          (1.23 * pdense).get_dense_tensor())