예제 #1
0
def test_FIM_vs_linearization_classif_logits():
    """Compare the 'classif_logits' FIM quadratic form with a measured KL.

    For each task in ``nonlinear_tasks``, a random parameter direction is
    rescaled to norm ``step``, the model outputs are collected before and
    after moving the parameters along it, and the KL divergence between the
    two softmax distributions is computed.  The quantity
    ``sqrt(2 * KL / dw^T F dw)`` is averaged over several random directions
    and asserted to lie within 5% of 1.
    """
    step = 1e-2
    n_repeats = 10  # repeat to kill statistical fluctuations
    tolerance = 5e-2

    def measure_quotient(get_task):
        # One trial: fresh task, fresh random direction, one quotient.
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()

        def forward(*batch):
            # Only the first element of the batch is fed to the model.
            return model(to_device(batch[0]))

        fisher = FIM(layer_collection=lc,
                     model=model,
                     loader=loader,
                     variant='classif_logits',
                     representation=PMatDense,
                     n_output=n_output,
                     function=forward)

        direction = random_pvector(lc, device=device)
        dw = step / direction.norm() * direction
        flat_dw = dw.get_flat_representation()

        logits_before = get_output_vector(loader, function)
        update_model(parameters, flat_dw)
        logits_after = get_output_vector(loader, function)
        # Undo the perturbation so the parameters are back where they began.
        update_model(parameters, -flat_dw)

        log_p_before = tF.log_softmax(logits_before, dim=1)
        log_p_after = tF.log_softmax(logits_after, dim=1)
        # Both arguments are log-probabilities, hence log_target=True.
        kl = tF.kl_div(log_p_before,
                       log_p_after,
                       log_target=True,
                       reduction='batchmean')

        ratio = kl / fisher.vTMv(dw) * 2
        return (ratio ** .5).item()

    for get_task in nonlinear_tasks:
        quotients = [measure_quotient(get_task) for _ in range(n_repeats)]
        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance
예제 #2
0
def test_FIM_vs_linearization_regression():
    """Compare the 'regression' FIM quadratic form with a measured change.

    For each task in ``nonlinear_tasks``, a random parameter direction is
    rescaled to norm ``step`` and the model outputs are collected before and
    after moving the parameters along it.  The summed squared output
    difference, divided by the first dimension of the output vector, is
    compared with ``dw^T F dw``: the square root of their ratio is averaged
    over several random directions and asserted to lie within 5% of 1.
    """
    step = 1e-2
    n_repeats = 10  # repeat to kill statistical fluctuations
    tolerance = 5e-2

    def measure_quotient(get_task):
        # One trial: fresh task, fresh random direction, one quotient.
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()

        def forward(*batch):
            # Only the first element of the batch is fed to the model.
            return model(to_device(batch[0]))

        fisher = FIM(layer_collection=lc,
                     model=model,
                     loader=loader,
                     variant='regression',
                     representation=PMatDense,
                     n_output=n_output,
                     function=forward)

        direction = random_pvector(lc, device=device)
        dw = step / direction.norm() * direction
        flat_dw = dw.get_flat_representation()

        outputs_before = get_output_vector(loader, function)
        update_model(parameters, flat_dw)
        outputs_after = get_output_vector(loader, function)
        # Undo the perturbation so the parameters are back where they began.
        update_model(parameters, -flat_dw)

        residual = outputs_before - outputs_after
        mean_sq_change = (residual ** 2).sum() / outputs_before.size(0)

        ratio = mean_sq_change / fisher.vTMv(dw)
        return (ratio ** .5).item()

    for get_task in nonlinear_tasks:
        quotients = [measure_quotient(get_task) for _ in range(n_repeats)]
        mean_quotient = sum(quotients) / len(quotients)
        assert 1 - tolerance < mean_quotient < 1 + tolerance