Ejemplo n.º 1
0
"""Test class for module Batch_l2_grad (squared ℓ₂ norm of batch gradients)"""

from test.automated_test import check_sizes_and_values
from test.extensions.firstorder.batch_l2_grad.batchl2grad_settings import (
    BATCHl2GRAD_SETTINGS, )
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems

import pytest

# Parametrization for the test below: the problems built from the
# batch_l2_grad settings, each labelled with its own ``make_id()`` string.
PROBLEMS = make_test_problems(BATCHl2GRAD_SETTINGS)
IDS = [case.make_id() for case in PROBLEMS]


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_batch_l2_grad(problem):
    """Compare BackPACK's squared ℓ₂ norms of individual gradients with autograd.

    Args:
        problem (ExtensionsTestProblem): Problem for extension test.
    """
    problem.set_up()

    l2_backpack = BackpackExtensions(problem).batch_l2_grad()
    l2_autograd = AutogradExtensions(problem).batch_l2_grad()
    check_sizes_and_values(l2_autograd, l2_backpack)

    problem.tear_down()

Ejemplo n.º 2
0
"""Test individual gradients for the following layers:
- sum of the squares of batch gradients of linear layers
- sum of the squares of batch gradients of convolutional layers

"""
from test.automated_test import check_sizes_and_values
from test.extensions.firstorder.sum_grad_squared.sumgradsquared_settings import (
    SUMGRADSQUARED_SETTINGS, )
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems

import pytest

# Problems built from the sum_grad_squared settings and their pytest ids.
PROBLEMS = make_test_problems(SUMGRADSQUARED_SETTINGS)
IDS = [test_problem.make_id() for test_problem in PROBLEMS]


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_sum_grad_squared(problem):
    """Test the sum of squared individual gradients.

    Computes the quantity with BackPACK and with autograd and checks that
    both results agree in size and value.

    Args:
        problem (ExtensionsTestProblem): Problem for extension test.
    """
    problem.set_up()

    backpack_res = BackpackExtensions(problem).sgs()
    autograd_res = AutogradExtensions(problem).sgs()

    # The comparison is what makes this a test; without it any result passes.
    check_sizes_and_values(autograd_res, backpack_res)
    problem.tear_down()
Ejemplo n.º 3
0
"""Test BackPACK's KFAC extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import ExtensionsTestProblem, make_test_problems
from test.extensions.secondorder.hbp.kfac_settings import (
    BATCH_SIZE_1_SETTINGS,
    NOT_SUPPORTED_SETTINGS,
)

import pytest

from backpack.utils.kroneckers import kfacs_to_mat

# Problems the KFAC extension is expected to reject, with their pytest ids.
NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS)
NOT_SUPPORTED_IDS = [prob.make_id() for prob in NOT_SUPPORTED_PROBLEMS]
# Problems from BATCH_SIZE_1_SETTINGS (presumably batch size 1 — confirm in
# kfac_settings), with their pytest ids.
BATCH_SIZE_1_PROBLEMS = make_test_problems(BATCH_SIZE_1_SETTINGS)
BATCH_SIZE_1_IDS = [prob.make_id() for prob in BATCH_SIZE_1_PROBLEMS]


@pytest.mark.parametrize("problem",
                         NOT_SUPPORTED_PROBLEMS,
                         ids=NOT_SUPPORTED_IDS)
def test_kfac_not_supported(problem):
    """Check that the KFAC extension does not allow specific hyperparameters/modules.

    Args:
        problem (ExtensionsTestProblem): Test case.

    Raises:
        AssertionError: (via ``pytest.raises``) if ``kfac()`` does not raise
            ``NotImplementedError`` for this problem.
    """
    problem.set_up()

    # Without this check the test contains no assertion and always passes.
    with pytest.raises(NotImplementedError):
        BackpackExtensions(problem).kfac()

    problem.tear_down()
Ejemplo n.º 4
0
"""Tests BackPACK's ``SqrtGGNExact`` and ``SqrtGGNMC`` extension."""

from math import isclose
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import ExtensionsTestProblem, make_test_problems
from test.extensions.secondorder.sqrt_ggn.sqrt_ggn_settings import SQRT_GGN_SETTINGS
from test.utils.skip_test import skip_large_parameters, skip_subsampling_conflict
from typing import List, Union

from pytest import fixture, mark

# Test cases built from SQRT_GGN_SETTINGS; used as the ``params`` of the
# ``problem`` fixture below.
PROBLEMS = make_test_problems(SQRT_GGN_SETTINGS)

# Subsampling configurations to test (None presumably means "no subsampling")
# and their pytest ids; spaces are stripped so ``[0, 0]`` becomes ``[0,0]``.
SUBSAMPLINGS = [None, [0, 0], [2, 0]]
SUBSAMPLING_IDS = [
    "subsampling=" + str(sub).replace(" ", "") for sub in SUBSAMPLINGS
]


@fixture(params=PROBLEMS, ids=lambda p: p.make_id())
def problem(request) -> ExtensionsTestProblem:
    """Set seed, create tested model, loss, data. Finally clean up.

    Args:
        request (SubRequest): Request for the fixture from a test/fixture function.

    Yields:
        Test case with deterministically constructed attributes.
    """
    case = request.param
    case.set_up()
    # Hand the prepared case to the test. pytest resumes this generator after
    # the test finishes (pass or fail), so the clean-up below always runs.
    yield case
    case.tear_down()
Ejemplo n.º 5
0
"""Test DiagGGN extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems
from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS
from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda

import pytest

# Problems built from the DiagGGN settings and their pytest ids.
PROBLEMS = make_test_problems(DiagGGN_SETTINGS)
IDS = [p.make_id() for p in PROBLEMS]


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_ggn(problem, request):
    """Compare BackPACK's generalized Gauss-Newton diagonal with autograd.

    Args:
        problem (ExtensionsTestProblem): Problem for extension test.
        request: pytest request object, forwarded to the
            ``skip_adaptive_avg_pool3d_cuda`` skip helper.
    """
    skip_adaptive_avg_pool3d_cuda(request)
    problem.set_up()

    diag_ggn_backpack = BackpackExtensions(problem).diag_ggn()
    diag_ggn_autograd = AutogradExtensions(problem).diag_ggn()
    check_sizes_and_values(diag_ggn_autograd, diag_ggn_backpack)

    problem.tear_down()
Ejemplo n.º 6
0
"""Test BackPACK's KFAC extension."""

from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems
from test.extensions.secondorder.hbp.kfac_settings import NOT_SUPPORTED_SETTINGS

import pytest

# Problems the KFAC extension is expected to reject, with their pytest ids.
NOT_SUPPORTED_PROBLEMS = make_test_problems(NOT_SUPPORTED_SETTINGS)
NOT_SUPPORTED_IDS = [unsupported.make_id() for unsupported in NOT_SUPPORTED_PROBLEMS]


@pytest.mark.parametrize("problem", NOT_SUPPORTED_PROBLEMS, ids=NOT_SUPPORTED_IDS)
def test_kfac_not_supported(problem):
    """Ensure KFAC refuses unsupported hyperparameters/modules.

    Calling the KFAC extension on such a problem must raise
    ``NotImplementedError``.

    Args:
        problem (ExtensionsTestProblem): Test case.
    """
    problem.set_up()

    # Both constructing the extension wrapper and calling kfac() stay inside
    # the context manager, so a NotImplementedError from either is accepted.
    with pytest.raises(NotImplementedError):
        BackpackExtensions(problem).kfac()

    problem.tear_down()
Ejemplo n.º 7
0
"""Test BackPACK's ``Variance`` extension."""
from test.automated_test import check_sizes_and_values
from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import ExtensionsTestProblem, make_test_problems

import pytest

# Problems built from the variance settings and their pytest ids.
PROBLEMS = make_test_problems(VARIANCE_SETTINGS)
IDS = [variance_case.make_id() for variance_case in PROBLEMS]


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_variance(problem: ExtensionsTestProblem) -> None:
    """Compare BackPACK's variance of individual gradients with autograd.

    Args:
        problem: Test case.
    """
    problem.set_up()

    variance_backpack = BackpackExtensions(problem).variance()
    variance_autograd = AutogradExtensions(problem).variance()

    # Values are compared with an explicit relative tolerance of 5e-5.
    check_sizes_and_values(variance_autograd, variance_backpack, rtol=5e-5)
    problem.tear_down()
Ejemplo n.º 8
0
from test.automated_test import check_sizes_and_values
from test.extensions.implementation.autograd import AutogradExtensions
from test.extensions.implementation.backpack import BackpackExtensions
from test.extensions.problem import make_test_problems
from test.extensions.secondorder.diag_hessian.diagh_settings import DiagHESSIAN_SETTINGS

import pytest

# Problems built from the diagonal-Hessian settings and their pytest ids.
PROBLEMS = make_test_problems(DiagHESSIAN_SETTINGS)
IDS = [hessian_case.make_id() for hessian_case in PROBLEMS]


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_h(problem):
    """Compare BackPACK's Hessian diagonal with the autograd computation.

    Args:
        problem (ExtensionsTestProblem): Problem for extension test.
    """
    problem.set_up()

    diag_h_backpack = BackpackExtensions(problem).diag_h()
    diag_h_autograd = AutogradExtensions(problem).diag_h()
    check_sizes_and_values(diag_h_autograd, diag_h_backpack)

    problem.tear_down()


@pytest.mark.parametrize("problem", PROBLEMS, ids=IDS)
def test_diag_h_batch(problem):
    """Test Diagonal of Hessian