Exemplo n.º 1
0
def test_jacobian_pimplicit_vs_pdense():
    """Compare PMatImplicit against PMatDense on every available task.

    For each task in ``linear_tasks + nonlinear_tasks`` a single Jacobian
    generator backs both matrix representations, and ``trace``, ``mv`` and
    ``vTMv`` are compared between them.

    When the layer collection contains a layer whose class is named
    ``BatchNorm1dLayer`` or ``BatchNorm2dLayer``, the implicit ``mv`` and
    ``vTMv`` are instead expected to raise ``NotImplementedError``.
    """
    batchnorm_class_names = ('BatchNorm1dLayer', 'BatchNorm2dLayer')

    for get_task in linear_tasks + nonlinear_tasks:
        # The third item (parameters) is not needed by this test.
        loader, lc, _, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_implicit = PMatImplicit(generator=generator, examples=loader)
        PMat_dense = PMatDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)

        # Test trace
        check_ratio(PMat_dense.trace(), PMat_implicit.trace())

        # Decide once per task whether the batch norm code path applies;
        # both the mv and vTMv checks below branch on the same condition.
        has_batchnorm = any(
            layer.__class__.__name__ in batchnorm_class_names
            for layer in lc.layers.values())

        # Test mv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.mv(dw)
        else:
            check_tensors(
                PMat_dense.mv(dw).get_flat_representation(),
                PMat_implicit.mv(dw).get_flat_representation())

        # Test vTMv
        if has_batchnorm:
            with pytest.raises(NotImplementedError):
                PMat_implicit.vTMv(dw)
        else:
            check_ratio(PMat_dense.vTMv(dw), PMat_implicit.vTMv(dw))
Exemplo n.º 2
0
def test_jacobian_pdense():
    """Exercise PMatDense built from a Jacobian generator on nonlinear tasks.

    Every task in ``nonlinear_tasks`` is run with ``centering`` both enabled
    and disabled. The checks cover ``get_diag``, ``frobenius_norm``,
    ``trace``, ``solve``, ``inverse`` and the ``+``, ``-`` and scalar ``*``
    operators, each compared against the same quantity computed directly
    from the tensor returned by ``get_dense_tensor``.
    """

    def make_dense(task_factory, centered):
        # Instantiate a fresh task and wrap its Jacobian in a PMatDense.
        # Returns the layer collection along with the matrix.
        data, layers, _, net, fn, n_out = task_factory()
        net.train()
        jac = Jacobian(layer_collection=layers,
                       model=net,
                       loader=data,
                       function=fn,
                       n_output=n_out,
                       centering=centered)
        return layers, PMatDense(jac)

    for get_task in nonlinear_tasks:
        for centering in (True, False):
            lc, pmat = make_dense(get_task, centering)
            dw = random_pvector(lc, device=device)

            # get_diag must match the diagonal of the dense tensor.
            diag_direct = torch.diag(pmat.get_dense_tensor())
            check_tensors(diag_direct, pmat.get_diag())

            # frobenius_norm against the root of the sum of squared entries.
            frob_reported = pmat.frobenius_norm()
            frob_direct = (pmat.get_dense_tensor() ** 2).sum() ** 0.5
            check_ratio(frob_direct, frob_reported)

            # trace against torch.trace of the dense tensor.
            trace_reported = pmat.trace()
            trace_direct = torch.trace(pmat.get_dense_tensor())
            check_ratio(trace_reported, trace_direct)

            # solve: form (M + regul * I) dw by hand, then recover dw.
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            ridge = regul * torch.eye(pmat.size(0), device=device)
            regularized = pmat.get_dense_tensor() + ridge
            rhs_flat = torch.mv(regularized, dw.get_flat_representation())
            rhs = PVector(layer_collection=lc, vector_repr=rhs_flat)
            dw_solved = pmat.solve(rhs, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_solved.get_flat_representation(),
                          eps=5e-3)

            # inverse: applying it to (M + regul * I) dw should give dw back.
            pmat_inv = pmat.inverse(regul=regul)
            dw_flat = dw.get_flat_representation()
            dw_recovered = pmat_inv.mv(pmat.mv(dw) + regul * dw)
            check_tensors(dw_flat,
                          dw_recovered.get_flat_representation(),
                          eps=5e-3)

            # add, sub, rmul: a second matrix is built from a fresh task and
            # the operators are compared with the same dense-tensor arithmetic.
            _, other = make_dense(get_task, centering)

            dense_sum = pmat.get_dense_tensor() + other.get_dense_tensor()
            check_tensors(dense_sum, (pmat + other).get_dense_tensor())

            dense_diff = pmat.get_dense_tensor() - other.get_dense_tensor()
            check_tensors(dense_diff, (pmat - other).get_dense_tensor())

            dense_scaled = 1.23 * pmat.get_dense_tensor()
            check_tensors(dense_scaled, (1.23 * pmat).get_dense_tensor())