Ejemplo n.º 1
0
def test_dot():
    """PVector.dot must agree with torch.dot on the flattened vectors.

    Vectors built by both ``random_pvector`` and ``random_pvector_dict`` are
    exercised, including a mixed pair taken in both operand orders.
    """
    model = ConvNet()
    layer_collection = LayerCollection.from_model(model)

    # Same-constructor pairs.
    for make_vector in (random_pvector, random_pvector_dict):
        left = make_vector(layer_collection)
        right = make_vector(layer_collection)
        product = left.dot(right)
        expected = torch.dot(left.get_flat_representation(),
                             right.get_flat_representation())
        check_ratio(expected, product)

    # Mixed pair: the result must not depend on which operand leads.
    left = random_pvector(layer_collection)
    right = random_pvector_dict(layer_collection)
    forward = left.dot(right)
    backward = right.dot(left)
    expected = torch.dot(left.get_flat_representation(),
                         right.get_flat_representation())
    check_ratio(expected, forward)
    check_ratio(expected, backward)
Ejemplo n.º 2
0
def test_jacobian_pquasidiag():
    """Check PMatQuasiDiag operations against its own dense reconstruction."""
    for get_task in [get_conv_task, get_fullyconnect_task]:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        pmat = PMatQuasiDiag(generator)
        dense = pmat.get_dense_tensor()

        vec = random_pvector(lc, device=device)
        vec_flat = vec.get_flat_representation()

        # Diagonal, Frobenius norm and trace.
        check_tensors(torch.diag(dense), pmat.get_diag())
        check_ratio(torch.norm(dense), pmat.frobenius_norm())
        check_ratio(torch.trace(dense), pmat.trace())

        # Matrix-vector product and the induced quadratic form.
        mv = pmat.mv(vec)
        dense_mv = torch.mv(dense, vec_flat)
        check_tensors(dense_mv, mv.get_flat_representation())
        check_ratio(torch.dot(dense_mv, vec_flat), pmat.vTMv(vec))

        # solve() applied to M v + regul * v with the same regul is
        # expected to give back v.
        regul = 1e-8
        recovered = pmat.solve(mv + regul * vec, regul=regul)
        check_tensors(vec.get_flat_representation(),
                      recovered.get_flat_representation())
Ejemplo n.º 3
0
def test_jacobian_pimplicit_vs_pdense():
    """Compare PMatImplicit against PMatDense built from the same generator.

    When the layer collection contains a ``BatchNorm1dLayer`` or
    ``BatchNorm2dLayer``, ``PMatImplicit.mv`` and ``PMatImplicit.vTMv`` are
    expected to raise NotImplementedError; otherwise their results must match
    the dense representation. The trace is compared in every case.
    """
    batchnorm_layer_names = {'BatchNorm1dLayer', 'BatchNorm2dLayer'}
    for get_task in linear_tasks + nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_implicit = PMatImplicit(generator=generator, examples=loader)
        PMat_dense = PMatDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)

        # Determine the batch-norm condition once per task instead of
        # rebuilding the layer-name lists for every check.
        layer_names = {l.__class__.__name__ for l in lc.layers.values()}
        has_batchnorm = not batchnorm_layer_names.isdisjoint(layer_names)

        # Test trace
        check_ratio(PMat_dense.trace(), PMat_implicit.trace())

        if has_batchnorm:
            # Test mv
            with pytest.raises(NotImplementedError):
                PMat_implicit.mv(dw)
            # Test vTMv
            with pytest.raises(NotImplementedError):
                PMat_implicit.vTMv(dw)
        else:
            # Test mv
            check_tensors(
                PMat_dense.mv(dw).get_flat_representation(),
                PMat_implicit.mv(dw).get_flat_representation())
            # Test vTMv
            check_ratio(PMat_dense.vTMv(dw), PMat_implicit.vTMv(dw))
Ejemplo n.º 4
0
def test_FIM_vs_linearization_classif_logits():
    """Second-order check of the 'classif_logits' FIM against an actual KL.

    For a small step ``dw`` the KL divergence between the softmax outputs
    after and before the step is compared with ``dw^T F dw``: the quotient
    ``sqrt(2 * KL / dw^T F dw)``, averaged over 10 draws, must lie within
    5% of 1.
    """
    step = 1e-2
    tolerance = 5e-2

    for get_task in nonlinear_tasks:
        quotients = []
        # Repeat and average to smooth out statistical fluctuations.
        for _ in range(10):
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            F = FIM(layer_collection=lc,
                    model=model,
                    loader=loader,
                    variant='classif_logits',
                    representation=PMatDense,
                    n_output=n_output,
                    function=lambda *d: model(to_device(d[0])))

            direction = random_pvector(lc, device=device)
            dw = step / direction.norm() * direction

            logits_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            logits_after = get_output_vector(loader, function)
            # Undo the step so the parameters return to their start point.
            update_model(parameters, -dw.get_flat_representation())

            # With log_target=True both arguments are log-probabilities.
            kl = tF.kl_div(tF.log_softmax(logits_before, dim=1),
                           tF.log_softmax(logits_after, dim=1),
                           log_target=True,
                           reduction='batchmean')

            quotients.append(((kl / F.vTMv(dw) * 2)**.5).item())

        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance
Ejemplo n.º 5
0
def test_FIM_vs_linearization_regression():
    """Second-order check of the 'regression' FIM against output changes.

    For a small step ``dw`` the squared output change, summed and divided by
    ``output_before.size(0)``, is compared with ``dw^T F dw``. The quotient
    ``sqrt(diff / dw^T F dw)``, averaged over 10 draws, must lie within 5%
    of 1.
    """
    step = 1e-2
    tolerance = 5e-2

    for get_task in nonlinear_tasks:
        quotients = []
        # Repeat and average to smooth out statistical fluctuations.
        for _ in range(10):
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            F = FIM(layer_collection=lc,
                    model=model,
                    loader=loader,
                    variant='regression',
                    representation=PMatDense,
                    n_output=n_output,
                    function=lambda *d: model(to_device(d[0])))

            direction = random_pvector(lc, device=device)
            dw = step / direction.norm() * direction

            output_before = get_output_vector(loader, function)
            update_model(parameters, dw.get_flat_representation())
            output_after = get_output_vector(loader, function)
            # Undo the step so the parameters return to their start point.
            update_model(parameters, -dw.get_flat_representation())

            squared_change = (output_before - output_after)**2
            diff = squared_change.sum() / output_before.size(0)

            quotients.append(((diff / F.vTMv(dw))**.5).item())

        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance
Ejemplo n.º 6
0
def test_grad_flat_repr():
    """``grad`` must refuse to differentiate w.r.t. a random_pvector result.

    The test expects RuntimeError when asking for the gradient of
    ``vec.norm()`` with respect to ``vec`` itself.
    """
    _, lc, _, _, _, _ = get_conv_gn_task()

    vec = random_pvector(lc)
    # Computed outside the raises-block so only grad() may raise.
    norm_of_vec = vec.norm()

    with pytest.raises(RuntimeError):
        grad(norm_of_vec, vec)
Ejemplo n.º 7
0
def test_jacobian_kfac():
    """Validate PMatKFAC against the dense matrices it reconstructs.

    The ``split_weight_bias=True`` dense tensor is the reference for trace,
    diagonal, mv and vTMv; the ``split_weight_bias=False`` tensor is the
    reference for the Frobenius norm. Inverse and solve must recover a
    vector from its (regularized) image.
    """
    for get_task in [get_fullyconnect_task, get_conv_task]:
        loader, lc, parameters, model, function, n_output = get_task()

        kfac = PMatKFAC(Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output))
        dense_split = kfac.get_dense_tensor(split_weight_bias=True)
        dense_joint = kfac.get_dense_tensor(split_weight_bias=False)

        check_ratio(torch.trace(dense_split), kfac.trace())
        check_ratio(torch.norm(dense_joint), kfac.frobenius_norm())
        check_tensors(torch.diag(dense_split),
                      kfac.get_diag(split_weight_bias=True))

        probe = random_pvector(lc, device=device)
        probe_flat = probe.get_flat_representation()

        # mv and the induced quadratic form.
        mv_dense = torch.mv(dense_split, probe_flat)
        mv_kfac = kfac.mv(probe)
        check_tensors(mv_dense, mv_kfac.get_flat_representation())
        check_ratio(torch.dot(mv_dense, probe_flat), kfac.vTMv(probe))

        # Inverse / solve. The target is an mv result rather than a raw
        # random vector, since mv kills its components projected onto the
        # small eigenvalues of KFAC.
        regul = 1e-7
        rhs = kfac.mv(mv_kfac) + regul * mv_kfac

        from_inverse = kfac.inverse(regul).mv(rhs)
        check_tensors(mv_kfac.get_flat_representation(),
                      from_inverse.get_flat_representation(),
                      eps=1e-2)

        from_solve = kfac.solve(rhs, regul=regul)
        check_tensors(mv_kfac.get_flat_representation(),
                      from_solve.get_flat_representation(),
                      eps=1e-2)
Ejemplo n.º 8
0
def test_norm():
    """PVector.norm must match torch.norm of the flat representation."""
    layer_collection = LayerCollection.from_model(ConvNet())

    for make_vector in (random_pvector, random_pvector_dict):
        vec = make_vector(layer_collection)
        check_ratio(torch.norm(vec.get_flat_representation()), vec.norm())
Ejemplo n.º 9
0
def test_size():
    """PVector.size() must match the size of its flat representation."""
    layer_collection = LayerCollection.from_model(ConvNet())

    for make_vector in (random_pvector, random_pvector_dict):
        vec = make_vector(layer_collection)
        assert vec.size() == vec.get_flat_representation().size()
Ejemplo n.º 10
0
def test_pspace_ekfac_vs_direct():
    """
    Check EKFAC basis operations against direct computation using
    get_dense_tensor.

    Every check runs twice: once on the freshly built PMatEKFAC and once
    more after ``update_diag`` has been called at the end of the first pass.
    """
    for get_task in [get_fullyconnect_task, get_conv_task]:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()

        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)

        M_ekfac = PMatEKFAC(generator)
        v = random_pvector(lc, device=device)
        v_flat = v.get_flat_representation()

        # the second time we will have called update_diag
        for _ in range(2):
            # Within one pass nothing below is expected to modify M_ekfac
            # (update_diag only runs at the end), so build the expensive
            # dense reconstruction once per pass instead of once per check.
            dense = M_ekfac.get_dense_tensor()
            mv_direct = torch.mv(dense, v_flat)

            # Test vTMv
            check_ratio(torch.dot(mv_direct, v_flat), M_ekfac.vTMv(v))

            # Test trace
            check_ratio(torch.trace(dense), M_ekfac.trace())

            # Test frobenius norm
            check_ratio(torch.norm(dense), M_ekfac.frobenius_norm())

            # Test mv
            mv_ekfac = M_ekfac.mv(v)
            check_tensors(mv_direct, mv_ekfac.get_flat_representation())

            # Test inverse
            regul = 1e-5
            M_inv = M_ekfac.inverse(regul=regul)
            v_back = M_inv.mv(mv_ekfac + regul * v)
            check_tensors(v_flat, v_back.get_flat_representation())

            # Test solve
            v_back = M_ekfac.solve(mv_ekfac + regul * v, regul=regul)
            check_tensors(v_flat, v_back.get_flat_representation())

            # Test rmul
            M_mul = 1.23 * M_ekfac
            check_tensors(1.23 * dense, M_mul.get_dense_tensor())

            M_ekfac.update_diag()
Ejemplo n.º 11
0
def test_jacobian_plowrank():
    """Check PMatLowRank operations against its own dense reconstruction.

    Covers get_diag, Frobenius norm, trace, mv, vTMv, solve (on a vector in
    the span of the matrix) and scalar left-multiplication.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        PMat_lowrank = PMatLowRank(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)
        dw = dw / dw.norm()
        dense_tensor = PMat_lowrank.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense_tensor),
                      PMat_lowrank.get_diag(),
                      eps=1e-4)

        # Test frobenius
        frob_PMat = PMat_lowrank.frobenius_norm()
        frob_direct = (dense_tensor**2).sum()**.5
        check_ratio(frob_direct, frob_PMat)

        # Test trace (direct value first, matching the other check_ratio
        # calls in this test)
        trace_PMat = PMat_lowrank.trace()
        trace_direct = torch.trace(dense_tensor)
        check_ratio(trace_direct, trace_PMat)

        # Test mv
        mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
        mv = PMat_lowrank.mv(dw)
        check_tensors(mv_direct, mv.get_flat_representation())

        # Test vTMv
        check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
                    PMat_lowrank.vTMv(dw))

        # Test solve
        # We will try to recover mv, which is in the span of the
        # low rank matrix
        regul = 1e-3
        mmv = PMat_lowrank.mv(mv)
        mv_using_inv = PMat_lowrank.solve(mmv, regul=regul)
        check_tensors(mv.get_flat_representation(),
                      mv_using_inv.get_flat_representation(),
                      eps=1e-2)
        # Test inv TODO

        # Test rmul (add and sub are not covered here)
        check_tensors(1.23 * dense_tensor,
                      (1.23 * PMat_lowrank).get_dense_tensor())
Ejemplo n.º 12
0
def test_sub():
    """PVector subtraction must match subtraction of the flat representations.

    Exercises vectors built by ``random_pvector`` and ``random_pvector_dict``,
    as well as a mixed pair.
    """
    model = ConvNet()
    layer_collection = LayerCollection.from_model(model)

    constructor_pairs = [(random_pvector, random_pvector),
                         (random_pvector_dict, random_pvector_dict),
                         (random_pvector, random_pvector_dict)]
    for make_left, make_right in constructor_pairs:
        r1 = make_left(layer_collection)
        r2 = make_right(layer_collection)
        # The result is a difference; the previous name `sumr1r2` was
        # misleading.
        diff_r1r2 = r1 - r2
        expected = (r1.get_flat_representation() -
                    r2.get_flat_representation())
        assert torch.norm(diff_r1r2.get_flat_representation() -
                          expected) < 1e-5
Ejemplo n.º 13
0
def test_jacobian_pullback_dense():
    """PullBackDense must behave as the adjoint of PushForwardDense.

    For a random ``dw``: ``<dw, pull_back(push_forward(dw))>`` must equal
    ``||push_forward(dw)||^2``.
    """
    for get_task in linear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        pull_back = PullBackDense(generator)
        push_forward = PushForwardDense(generator)
        dw = random_pvector(lc, device=device)

        pushed = push_forward.mv(dw)
        pulled = pull_back.mv(pushed)

        inner = torch.dot(dw.get_flat_representation(),
                          pulled.get_flat_representation())
        squared_norm = torch.norm(pushed.get_flat_representation())**2
        check_ratio(inner, squared_norm)
Ejemplo n.º 14
0
def test_jacobian_pushforward_implicit():
    """PushForwardImplicit.mv must reproduce PushForwardDense.mv."""
    for get_task in linear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        dense_pf = PushForwardDense(generator)
        implicit_pf = PushForwardImplicit(generator)
        dw = random_pvector(lc, device=device)

        expected = dense_pf.mv(dw).get_flat_representation()
        actual = implicit_pf.mv(dw).get_flat_representation()
        check_tensors(expected, actual)
Ejemplo n.º 15
0
def test_jacobian_pushforward_dense_linear():
    """For linear tasks, the push-forward of ``dw`` must match the output change.

    The output change after adding ``dw`` to the parameters is compared with
    ``PushForwardDense.mv(dw)`` using the default tolerance. The parameter
    update is undone afterwards so the model is not left perturbed, up to
    floating-point rounding.
    """
    for get_task in linear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        push_forward = PushForwardDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)

        doutput_lin = push_forward.mv(dw)

        output_before = get_output_vector(loader, function)
        update_model(parameters, dw.get_flat_representation())
        output_after = get_output_vector(loader, function)
        # Undo the perturbation so the parameters are left unchanged.
        update_model(parameters, -dw.get_flat_representation())

        check_tensors(output_after - output_before,
                      doutput_lin.get_flat_representation().t())
Ejemplo n.º 16
0
def test_jacobian_plowrank():
    """Check PMatLowRank operations against its own dense reconstruction.

    Covers get_diag, Frobenius norm, trace, mv, vTMv and scalar
    left-multiplication; solve and inverse are not tested yet.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_lowrank = PMatLowRank(generator)
        dw = random_pvector(lc, device=device)
        dense_tensor = PMat_lowrank.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense_tensor),
                      PMat_lowrank.get_diag(),
                      eps=1e-4)

        # Test frobenius
        frob_PMat = PMat_lowrank.frobenius_norm()
        frob_direct = (dense_tensor**2).sum()**.5
        check_ratio(frob_direct, frob_PMat)

        # Test trace (direct value first, matching the other check_ratio
        # calls in this test)
        trace_PMat = PMat_lowrank.trace()
        trace_direct = torch.trace(dense_tensor)
        check_ratio(trace_direct, trace_PMat)

        # Test mv
        mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
        check_tensors(mv_direct,
                      PMat_lowrank.mv(dw).get_flat_representation())

        # Test vTMv
        check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
                    PMat_lowrank.vTMv(dw))

        # Test solve TODO
        # Test inv TODO

        # Test rmul (add and sub are not covered here)
        check_tensors(1.23 * dense_tensor,
                      (1.23 * PMat_lowrank).get_dense_tensor())
Ejemplo n.º 17
0
def test_jacobian_pdense_vs_pushforward():
    """Cross-check PMatDense against PushForwardDense / PullBackDense.

    With ``n = len(loader.sampler)``:
    - the dense matrix must equal ``J^T J / n``, where J is the push-forward
      dense tensor flattened over its first two axes;
    - ``vTMv(dw)`` must equal ``||push_forward(dw)||^2 / n``;
    - ``mv(dw)`` must equal ``pull_back(push_forward(dw)) / n``.
    """
    # NB: the centering=True case sometimes fails, probably because
    # PMatDense computes centering as E[x^2] - (Ex)^2, which is notoriously
    # not numerically stable.
    for get_task in linear_tasks + nonlinear_tasks:
        for centering in (True, False):
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            push_forward = PushForwardDense(generator)
            # The pull-back reuses the push-forward's data.
            pull_back = PullBackDense(generator, data=push_forward.data)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)
            n_examples = len(loader.sampler)

            # Dense tensor vs J^T J / n.
            jacobian = push_forward.get_dense_tensor()
            jac_2d = jacobian.view(-1, jacobian.size(2))
            check_tensors(torch.mm(jac_2d.t(), jac_2d) / n_examples,
                          PMat_dense.get_dense_tensor())

            # Quadratic form vs ||J dw||^2 / n.
            Jdw = push_forward.mv(dw)
            Jdw_flat = Jdw.get_flat_representation().view(-1)
            check_ratio(torch.dot(Jdw_flat, Jdw_flat) / n_examples,
                        PMat_dense.vTMv(dw))

            # Matrix-vector product vs pull_back(J dw) / n.
            pulled = pull_back.mv(Jdw).get_flat_representation()
            check_tensors(pulled / n_examples,
                          PMat_dense.mv(dw).get_flat_representation())
Ejemplo n.º 18
0
def test_jacobian_pushforward_dense_nonlinear():
    """Push-forward of a tiny ``dw`` must approximate the output change.

    ``dw`` is rescaled to norm 1e-4 and the finite-difference output change
    is compared with ``PushForwardDense.mv(dw)`` using a loose tolerance,
    since the tasks are non-linear. The parameter update is undone
    afterwards so the model is not left perturbed, up to floating-point
    rounding.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             function=function,
                             n_output=n_output)
        push_forward = PushForwardDense(generator=generator, examples=loader)
        dw = random_pvector(lc, device=device)
        dw = 1e-4 / dw.norm() * dw

        doutput_lin = push_forward.mv(dw)

        output_before = get_output_vector(loader, function)
        update_model(parameters, dw.get_flat_representation())
        output_after = get_output_vector(loader, function)
        # Undo the perturbation so the parameters are left unchanged.
        update_model(parameters, -dw.get_flat_representation())

        # This is non linear, so we don't expect the finite difference
        # estimate to be very accurate. We use a larger eps value
        check_tensors(output_after - output_before,
                      doutput_lin.get_flat_representation().t(),
                      eps=5e-2)
Ejemplo n.º 19
0
        results[model].append([])

        F = FIM_MonteCarlo(m,
                           smalltestloader,
                           PMatImplicit,
                           trials=5,
                           device='cuda')
        # G = FMatDense(F.generator)
        # frob = torch.norm(G.sum(dim=(0, 2))) / n_samples
        tr = F.trace()

        for j in range(4):

            results[model][-1].append(dict())
            v = random_pvector(F.generator.layer_collection, device='cuda')
            v_flat = v.get_flat_representation()
            Fv = F.mv(v)
            Fv_flat = Fv.get_flat_representation()
            vTMv = F.vTMv(v)

            for repr in [PMatDiag, PMatQuasiDiag, PMatKFAC, PMatEKFAC]:
                results[model][-1][-1][repr] = dict()
                F2 = repr(F.generator)
                if repr == PMatEKFAC:
                    F2.update_diag()

                tr_repr = F2.trace()
                results[model][-1][-1][repr]['trace'] = torch.abs(
                    (tr_repr - tr) / tr).item()
Ejemplo n.º 20
0
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Use only the first 100 CIFAR-10 training images, in a fixed order
# (shuffle=False), so the experiments below stay cheap.
trainset = datasets.CIFAR10(root='/tmp/data',
                            train=True,
                            download=True,
                            transform=transform)
trainset = Subset(trainset, range(100))
trainloader = DataLoader(trainset, batch_size=50, shuffle=False, num_workers=1)

# %%
from resnet import ResNet50
resnet = ResNet50().cuda()

# Build the layer collection once and reuse it for the random probe vector
# instead of deriving it from the model a second time.
layer_collection = LayerCollection.from_model(resnet)
v = random_pvector(layer_collection, device='cuda')

print(f'{layer_collection.numel()} parameters')

# %%
# compute timings and display FIMs


def perform_timing():
    """Loop over several PMat representation classes and print each one.

    NOTE(review): as it appears here, the body only prints each
    representation; ``timings`` is created but never filled or returned,
    although the surrounding cell comment says timings are computed. This
    excerpt looks truncated — confirm against the full source.
    """
    timings = dict()

    # NOTE(review): the loop variable ``repr`` shadows the builtin of the
    # same name inside this function.
    for repr in [PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]:

        print('Timing representation:')
        pprint.pprint(repr)
Ejemplo n.º 21
0
def test_jacobian_pblockdiag():
    """Check PMatBlockDiag operations against its own dense reconstruction.

    Covers get_diag, Frobenius norm, trace, mv, vTMv, solve, inverse and the
    add / sub / rmul operators (the latter against a second PMatBlockDiag
    built from a fresh task).
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag = PMatBlockDiag(generator)
        dw = random_pvector(lc, device=device)
        dense_tensor = PMat_blockdiag.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense_tensor),
                      PMat_blockdiag.get_diag())

        # Test frobenius
        frob_PMat = PMat_blockdiag.frobenius_norm()
        frob_direct = (dense_tensor**2).sum()**.5
        check_ratio(frob_direct, frob_PMat)

        # Test trace (direct value first, matching the other check_ratio
        # calls in this test)
        trace_PMat = PMat_blockdiag.trace()
        trace_direct = torch.trace(dense_tensor)
        check_ratio(trace_direct, trace_PMat)

        # Test mv
        mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
        check_tensors(mv_direct,
                      PMat_blockdiag.mv(dw).get_flat_representation())

        # Test vTMv
        check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
                    PMat_blockdiag.vTMv(dw))

        # Test solve: solving (M + regul * I) dw with the same regul is
        # expected to give back dw.
        regul = 1e-3
        Mv_regul = torch.mv(dense_tensor +
                            regul * torch.eye(PMat_blockdiag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_blockdiag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test inv
        PMat_inv = PMat_blockdiag.inverse(regul=regul)
        check_tensors(dw.get_flat_representation(),
                      PMat_inv.mv(PMat_blockdiag.mv(dw) + regul * dw)
                      .get_flat_representation(), eps=5e-3)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag2 = PMatBlockDiag(generator)

        check_tensors(PMat_blockdiag.get_dense_tensor() +
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag + PMat_blockdiag2)
                      .get_dense_tensor())
        check_tensors(PMat_blockdiag.get_dense_tensor() -
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag - PMat_blockdiag2)
                      .get_dense_tensor())
        check_tensors(1.23 * PMat_blockdiag.get_dense_tensor(),
                      (1.23 * PMat_blockdiag).get_dense_tensor())
Ejemplo n.º 22
0
def test_jacobian_pdiag_vs_pdense():
    """Check PMatDiag against PMatDense and against its own dense tensor.

    The diagonals of both dense reconstructions must agree and the PMatDiag
    reconstruction must have no off-diagonal mass. Trace, Frobenius norm,
    mv, vTMv, inverse, solve, get_diag and add / sub / rmul are checked
    against the PMatDiag dense tensor.
    """
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag = PMatDiag(generator)
        PMat_dense = PMatDense(generator)
        dw = random_pvector(lc, device=device)

        # Test get_dense_tensor
        matrix_diag = PMat_diag.get_dense_tensor()
        matrix_dense = PMat_dense.get_dense_tensor()
        check_tensors(torch.diag(matrix_diag),
                      torch.diag(matrix_dense))
        # The off-diagonal part of the PMatDiag reconstruction must vanish.
        assert torch.norm(matrix_diag -
                          torch.diag(torch.diag(matrix_diag))) < 1e-5

        # Test trace
        check_ratio(torch.trace(matrix_diag),
                    PMat_diag.trace())

        # Test frobenius
        check_ratio(torch.norm(matrix_diag),
                    PMat_diag.frobenius_norm())

        # Test mv
        mv_direct = torch.mv(matrix_diag, dw.get_flat_representation())
        mv_PMat_diag = PMat_diag.mv(dw)
        check_tensors(mv_direct,
                      mv_PMat_diag.get_flat_representation())

        # Test vTMv
        vTMv_direct = torch.dot(mv_direct, dw.get_flat_representation())
        vTMv_PMat_diag = PMat_diag.vTMv(dw)
        check_ratio(vTMv_direct, vTMv_PMat_diag)

        # One damping value is shared by the inverse and solve checks
        # (it was previously re-assigned to the same value).
        regul = 1e-3

        # Test inverse: (M + regul * I) @ inverse(regul) must be identity
        PMat_diag_inverse = PMat_diag.inverse(regul)
        identity = torch.eye(lc.numel(), device=device)
        prod = torch.mm(matrix_diag + regul * identity,
                        PMat_diag_inverse.get_dense_tensor())
        check_tensors(identity, prod)

        # Test solve
        Mv_regul = torch.mv(matrix_diag +
                            regul * torch.eye(PMat_diag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_diag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test get_diag
        diag_direct = torch.diag(matrix_diag)
        diag_PMat_diag = PMat_diag.get_diag()
        check_tensors(diag_direct, diag_PMat_diag)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag2 = PMatDiag(generator)

        check_tensors(PMat_diag.get_dense_tensor() +
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag + PMat_diag2).get_dense_tensor())
        check_tensors(PMat_diag.get_dense_tensor() -
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag - PMat_diag2).get_dense_tensor())
        check_tensors(1.23 * PMat_diag.get_dense_tensor(),
                      (1.23 * PMat_diag).get_dense_tensor())
Ejemplo n.º 23
0
def test_jacobian_pdense():
    """Check PMatDense operations against its own dense tensor.

    Runs with and without centering in the Jacobian generator and covers
    get_diag, Frobenius norm, trace, solve, inverse and add / sub / rmul
    (the latter against a second PMatDense built from a fresh task).
    """
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)
            # Fetch the dense tensor once rather than for every check.
            dense = PMat_dense.get_dense_tensor()

            # Test get_diag
            check_tensors(torch.diag(dense), PMat_dense.get_diag())

            # Test frobenius
            frob_PMat = PMat_dense.frobenius_norm()
            frob_direct = (dense**2).sum()**.5
            check_ratio(frob_direct, frob_PMat)

            # Test trace (direct value first, matching the frobenius check)
            trace_PMat = PMat_dense.trace()
            trace_direct = torch.trace(dense)
            check_ratio(trace_direct, trace_PMat)

            # Test solve
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            Mv_regul = torch.mv(dense +
                                regul * torch.eye(PMat_dense.size(0),
                                                  device=device),
                                dw.get_flat_representation())
            Mv_regul = PVector(layer_collection=lc,
                               vector_repr=Mv_regul)
            dw_using_inv = PMat_dense.solve(Mv_regul, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_using_inv.get_flat_representation(), eps=5e-3)

            # Test inv
            PMat_inv = PMat_dense.inverse(regul=regul)
            check_tensors(dw.get_flat_representation(),
                          PMat_inv.mv(PMat_dense.mv(dw) + regul * dw)
                          .get_flat_representation(), eps=5e-3)

            # Test add, sub, rmul
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense2 = PMatDense(generator)
            dense2 = PMat_dense2.get_dense_tensor()

            check_tensors(dense + dense2,
                          (PMat_dense + PMat_dense2).get_dense_tensor())
            check_tensors(dense - dense2,
                          (PMat_dense - PMat_dense2).get_dense_tensor())
            check_tensors(1.23 * dense,
                          (1.23 * PMat_dense).get_dense_tensor())