def test_example_passing():
    """Passing minibatches one by one must reproduce the full-loader matrix.

    Each per-minibatch ``PMatDense`` is weighted by the number of examples in
    that minibatch and the weighted matrices are summed. The sum is then
    divided by the total example count and compared with a ``PMatDense``
    built directly from the whole dataloader.
    """
    for get_task in [get_fullyconnect_task]:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc, model=model,
                             function=function, n_output=n_output)

        weighted_sum = None
        count = 0
        for batch in loader:
            batch_mat = PMatDense(generator=generator, examples=batch)
            batch_size = len(batch[0])
            weighted = batch_size * batch_mat
            weighted_sum = (weighted if weighted_sum is None
                            else weighted + weighted_sum)
            count += batch_size

        full_mat = PMatDense(generator=generator, examples=loader)
        averaged = 1. / count * weighted_sum
        check_tensors(full_mat.get_dense_tensor(),
                      averaged.get_dense_tensor())
def test_jacobian_pblockdiag_vs_pdense():
    """``PMatBlockDiag`` must equal the per-layer diagonal blocks of ``PMatDense``.

    For every layer, the block on the diagonal is compared with the dense
    matrix. The strips to the right of and below that block must be
    numerically zero.
    """
    for get_task in linear_tasks + nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc, model=model,
                             function=function, n_output=n_output)
        blockdiag = PMatBlockDiag(generator=generator, examples=loader)
        dense = PMatDense(generator=generator, examples=loader)

        # materialize both representations as plain tensors
        blockdiag_t = blockdiag.get_dense_tensor()
        dense_t = dense.get_dense_tensor()

        for layer_id, layer in lc.layers.items():
            lo = lc.p_pos[layer_id]
            hi = lo + layer.numel()
            rows = slice(lo, hi)

            # diagonal blocks must agree
            check_tensors(dense_t[rows, rows], blockdiag_t[rows, rows])

            # everything right of / below this block must vanish
            assert torch.norm(blockdiag_t[rows, hi:]) < 1e-5
            assert torch.norm(blockdiag_t[hi:, rows]) < 1e-5
def test_jacobian_pquasidiag_vs_pdense():
    """Check ``PMatQuasiDiag`` against ``PMatDense``, layer by layer.

    For each layer, the weight/weight block (and the bias/bias block, if any)
    must equal the diagonal of the corresponding dense block.

    When the layer has a bias, a further check runs for each bias index i.
    Bias column i must match the dense matrix on the contiguous weight strip
    of length ``sw // sb`` that starts at offset ``i * (sw // sb)``. Every
    other column of those strip rows, from the bias columns onward (later
    layers included), must be zero.

    Finally, the region made of this layer's rows and every column from the
    bias columns onward must equal the transpose of the mirrored region.
    """
    for get_task in [get_conv_task, get_fullyconnect_task]:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc, model=model,
                             function=function, n_output=n_output)
        quasidiag = PMatQuasiDiag(generator=generator, examples=loader)
        dense = PMatDense(generator=generator, examples=loader)

        qd_t = quasidiag.get_dense_tensor()
        dense_t = dense.get_dense_tensor()

        for layer_id, layer in lc.layers.items():
            w0 = lc.p_pos[layer_id]      # first weight coordinate
            sw = layer.weight.numel()
            b0 = w0 + sw                 # first bias coordinate
            w_block = slice(w0, b0)

            # weight/weight block is the diagonal of the dense one
            check_tensors(torch.diag(torch.diag(dense_t[w_block, w_block])),
                          qd_t[w_block, w_block])

            if layer.bias is None:
                continue

            sb = layer.bias.numel()
            b_block = slice(b0, b0 + sb)
            # bias/bias block is the diagonal of the dense one
            check_tensors(torch.diag(torch.diag(dense_t[b_block, b_block])),
                          qd_t[b_block, b_block])

            s_in = sw // sb
            for i in range(sb):
                strip = slice(w0 + i * s_in, w0 + (i + 1) * s_in)
                col = b0 + i
                # weight strip / bias i coupling agrees with dense
                check_tensors(dense_t[strip, col], qd_t[strip, col])
                # all other columns from the bias onward vanish on this strip
                assert torch.norm(qd_t[strip, b0:col]) < 1e-10
                assert torch.norm(qd_t[strip, col + 1:]) < 1e-10

            # upper-right region must mirror the lower-left one
            check_tensors(qd_t[w0:b0 + sb, b0:],
                          qd_t[b0:, w0:b0 + sb].t())
def test_jacobian_plowrank_vs_pdense():
    """``PMatLowRank`` must reproduce the ``PMatDense`` tensor.

    Agreement is up to ``check_tensors`` tolerance and is checked on every
    nonlinear task.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc, model=model,
                             function=function, n_output=n_output)
        lowrank = PMatLowRank(generator=generator, examples=loader)
        dense = PMatDense(generator=generator, examples=loader)

        lowrank_t = lowrank.get_dense_tensor()
        dense_t = dense.get_dense_tensor()
        check_tensors(dense_t, lowrank_t)
def test_dense():
    """Check that ``PMatDense.mm`` matches ``torch.mm`` of the dense tensors.

    Two independent instances of each nonlinear task are built, so the two
    operands come from different models.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model1, function1, n_output = get_task()
        _, _, _, model2, function2, _ = get_task()
        generator1 = Jacobian(layer_collection=lc, model=model1,
                              function=function1, n_output=n_output)
        # Pair model2 with function2, the function returned by the same
        # get_task() call. function1 presumably evaluates model1, in which
        # case generator2 would never exercise model2's parameters.
        # NOTE(review): lc was built for the first task instance and is
        # reused for model2; presumably fine since both share an architecture.
        generator2 = Jacobian(layer_collection=lc, model=model2,
                              function=function2, n_output=n_output)
        M_dense1 = PMatDense(generator=generator1, examples=loader)
        M_dense2 = PMatDense(generator=generator2, examples=loader)
        prod = M_dense1.mm(M_dense2)

        M_dense1_tensor = M_dense1.get_dense_tensor()
        M_dense2_tensor = M_dense2.get_dense_tensor()
        prod_tensor = prod.get_dense_tensor()

        check_tensors(torch.mm(M_dense1_tensor, M_dense2_tensor), prod_tensor)
def test_jacobian_pdense_vs_pushforward():
    """Cross-check ``PMatDense`` against ``PushForwardDense``/``PullBackDense``.

    Let J be the dense tensor from ``PushForwardDense`` (flattened over its
    first two dimensions) and n the number of examples in the loader's
    sampler. The test checks that:

    - ``PMatDense`` equals J^T J / n;
    - ``vTMv(dw)`` equals the squared norm of the push-forward of dw, over n;
    - ``mv(dw)`` equals ``PullBackDense.mv`` applied to that push-forward,
      over n.

    NB: the centering=True case sometimes fails, probably because
    ``PMatDense`` computes centering as E[x^2] - (Ex)^2, which is
    notoriously numerically unstable.
    """
    for get_task in linear_tasks + nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc, model=model,
                                 loader=loader, function=function,
                                 n_output=n_output, centering=centering)
            push_forward = PushForwardDense(generator)
            pull_back = PullBackDense(generator, data=push_forward.data)
            dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)
            n = len(loader.sampler)

            # dense tensor vs J^T J / n
            jac = push_forward.get_dense_tensor()
            n_params = jac.size()[2]
            jac_2d = jac.view(-1, n_params)
            check_tensors(torch.mm(jac_2d.t(), jac_2d) / n,
                          dense.get_dense_tensor())

            # vTMv vs squared norm of the push-forward
            quad_form = dense.vTMv(dw)
            jv = push_forward.mv(dw)
            jv_flat = jv.get_flat_representation().view(-1)
            check_ratio(torch.dot(jv_flat, jv_flat) / n, quad_form)

            # Mv vs pull-back of the push-forward
            mv_dense = dense.mv(dw)
            mv_pf_pb = pull_back.mv(jv)
            check_tensors(mv_pf_pb.get_flat_representation() / n,
                          mv_dense.get_flat_representation())
def test_jacobian_pdiag_vs_pdense():
    """Exercise ``PMatDiag`` against ``PMatDense`` and against its own dense tensor.

    The following are covered:

    - get_dense_tensor: its diagonal matches the dense matrix and its
      off-diagonal part is zero;
    - trace, frobenius_norm, mv, vTMv, inverse, solve and get_diag;
    - the ``+``, ``-`` and scalar ``*`` operators, using a second
      independent task instance.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc, model=model, loader=loader,
                             function=function, n_output=n_output)
        diag_mat = PMatDiag(generator)
        dense_mat = PMatDense(generator)
        dw = random_pvector(lc, device=device)
        dw_flat = dw.get_flat_representation()

        # get_dense_tensor: same diagonal as dense, zero elsewhere
        t_diag = diag_mat.get_dense_tensor()
        t_dense = dense_mat.get_dense_tensor()
        check_tensors(torch.diag(t_diag), torch.diag(t_dense))
        assert torch.norm(t_diag - torch.diag(torch.diag(t_diag))) < 1e-5

        # trace and frobenius norm
        check_ratio(torch.trace(t_diag), diag_mat.trace())
        check_ratio(torch.norm(t_diag), diag_mat.frobenius_norm())

        # mv and vTMv
        mv_ref = torch.mv(t_diag, dw_flat)
        check_tensors(mv_ref, diag_mat.mv(dw).get_flat_representation())
        check_ratio(torch.dot(mv_ref, dw_flat), diag_mat.vTMv(dw))

        # inverse of (M + regul * I)
        regul = 1e-3
        inv = diag_mat.inverse(regul)
        prod = torch.mm(t_diag + regul * torch.eye(lc.numel(), device=device),
                        inv.get_dense_tensor())
        check_tensors(torch.eye(lc.numel(), device=device), prod)

        # solve recovers dw from (M + regul * I) dw
        rhs = torch.mv(
            t_diag + regul * torch.eye(diag_mat.size(0), device=device),
            dw_flat)
        rhs = PVector(layer_collection=lc, vector_repr=rhs)
        recovered = diag_mat.solve(rhs, regul=regul)
        check_tensors(dw_flat, recovered.get_flat_representation(), eps=5e-3)

        # get_diag
        check_tensors(torch.diag(t_diag), diag_mat.get_diag())

        # +, - and scalar * against a second, independent instance
        loader2, lc2, _, model2, function2, n_output2 = get_task()
        model2.train()
        generator2 = Jacobian(layer_collection=lc2, model=model2,
                              loader=loader2, function=function2,
                              n_output=n_output2)
        other = PMatDiag(generator2)
        check_tensors(diag_mat.get_dense_tensor() + other.get_dense_tensor(),
                      (diag_mat + other).get_dense_tensor())
        check_tensors(diag_mat.get_dense_tensor() - other.get_dense_tensor(),
                      (diag_mat - other).get_dense_tensor())
        check_tensors(1.23 * diag_mat.get_dense_tensor(),
                      (1.23 * diag_mat).get_dense_tensor())
def test_jacobian_pdense():
    """Exercise the ``PMatDense`` API, with and without centering.

    get_diag, frobenius_norm, trace, solve and inverse are each compared
    with a direct computation on the dense tensor. The ``+``, ``-`` and
    scalar ``*`` operators are checked against a second, independent task
    instance.
    """
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc, model=model,
                                 loader=loader, function=function,
                                 n_output=n_output, centering=centering)
            dense_mat = PMatDense(generator)
            dw = random_pvector(lc, device=device)

            # get_diag vs diagonal of the dense tensor
            check_tensors(torch.diag(dense_mat.get_dense_tensor()),
                          dense_mat.get_diag())

            # frobenius_norm vs sqrt of the sum of squares
            frob = dense_mat.frobenius_norm()
            check_ratio((dense_mat.get_dense_tensor() ** 2).sum() ** .5, frob)

            # trace
            tr = dense_mat.trace()
            check_ratio(tr, torch.trace(dense_mat.get_dense_tensor()))

            # solve: regul is large because the matrix is not full rank
            regul = 1e-3
            eye = torch.eye(dense_mat.size(0), device=device)
            rhs = torch.mv(dense_mat.get_dense_tensor() + regul * eye,
                           dw.get_flat_representation())
            rhs = PVector(layer_collection=lc, vector_repr=rhs)
            recovered = dense_mat.solve(rhs, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          recovered.get_flat_representation(), eps=5e-3)

            # inverse undoes (M + regul * I)
            inv = dense_mat.inverse(regul=regul)
            check_tensors(dw.get_flat_representation(),
                          inv.mv(dense_mat.mv(dw) + regul * dw)
                          .get_flat_representation(), eps=5e-3)

            # +, - and scalar * against a second, independent instance
            loader2, lc2, _, model2, function2, n_output2 = get_task()
            model2.train()
            generator2 = Jacobian(layer_collection=lc2, model=model2,
                                  loader=loader2, function=function2,
                                  n_output=n_output2, centering=centering)
            other = PMatDense(generator2)
            check_tensors(dense_mat.get_dense_tensor() + other.get_dense_tensor(),
                          (dense_mat + other).get_dense_tensor())
            check_tensors(dense_mat.get_dense_tensor() - other.get_dense_tensor(),
                          (dense_mat - other).get_dense_tensor())
            check_tensors(1.23 * dense_mat.get_dense_tensor(),
                          (1.23 * dense_mat).get_dense_tensor())