示例#1
0
def test_cabs(problem, independent_runs, q_kwargs):
    """Check that the BackPACK and ``torch.autograd`` versions of CABS agree.

    Args:
        problem (tests.utils.Problem): Settings for train loop.
        independent_runs (bool): Whether to compute the output of every
            quantity in a separate run.
        q_kwargs (dict): Keyword arguments handed over to both quantities.
    """
    run_comparison = get_compare_fn(independent_runs)
    quantity_pair = (CABS, AutogradCABS)
    run_comparison(problem, quantity_pair, q_kwargs)
示例#2
0
def test_tic_trace(problem, independent_runs, q_kwargs):
    """Compare BackPACK and ``torch.autograd`` implementation of TICTrace.

    Whether the two quantities are computed in separate runs or together in a
    single run is controlled by ``independent_runs``.

    Args:
        problem (tests.utils.Problem): Settings for train loop.
        independent_runs (bool): Whether to compute the output of every
            quantity in a separate run.
        q_kwargs (dict): Keyword arguments handed over to both quantities.
    """
    compare_fn = get_compare_fn(independent_runs)
    compare_fn(problem, (TICTrace, AutogradTICTrace), q_kwargs)
示例#3
0
def test_alpha(problem, independent_runs, q_kwargs):
    """Check that the BackPACK and ``torch.autograd`` versions of Alpha agree.

    The comparison uses an absolute and relative tolerance of ``1e-5`` each.

    Args:
        problem (tests.utils.Problem): Settings for train loop.
        independent_runs (bool): Whether to compute the output of every
            quantity in a separate run.
        q_kwargs (dict): Keyword arguments handed over to both quantities.
    """
    tolerances = {"atol": 1e-5, "rtol": 1e-5}

    run_comparison = get_compare_fn(independent_runs)
    run_comparison(
        problem, (Alpha, AutogradAlphaGeneral), q_kwargs, **tolerances
    )
示例#4
0
def test_mean_gsnr(problem, independent_runs, q_kwargs):
    """Check that the BackPACK and ``torch.autograd`` versions of MeanGSNR agree.

    The comparison uses a relative tolerance of ``5e-3`` and an absolute
    tolerance of ``1e-5``.

    Args:
        problem (tests.utils.Problem): Settings for train loop.
        independent_runs (bool): Whether to compute the output of every
            quantity in a separate run.
        q_kwargs (dict): Keyword arguments handed over to both quantities.
    """
    run_comparison = get_compare_fn(independent_runs)
    run_comparison(
        problem,
        (MeanGSNR, AutogradMeanGSNR),
        q_kwargs,
        rtol=5e-3,
        atol=1e-5,
    )
示例#5
0
def test_grad_hist2d_few_bins_cpu(problem, independent_runs, q_kwargs):
    """Compare BackPACK and ``torch.autograd`` implementation of GradHist2d.

    Use a small number of bins. This is because the histogram implementations
    have different floating precision inaccuracies. This leads to slightly
    deviating bin counts, which is hard to check. On GPUs, this problem becomes
    more pronounced.

    Args:
        problem (tests.utils.Problem): Settings for train loop.
        independent_runs (bool): Whether to compute the output of every
            quantity in a separate run.
        q_kwargs (dict): Keyword arguments handed over to both quantities.
    """
    # Few bins (8 x 5) over explicit per-axis ranges; these entries override
    # any same-named keys already present in ``q_kwargs``.
    q_extra_kwargs = {"bins": (8, 5), "range": ((-0.01, 0.01), (-0.02, 0.02))}
    combined_q_kwargs = {**q_kwargs, **q_extra_kwargs}

    compare_fn = get_compare_fn(independent_runs)
    compare_fn(problem, (GradHist2d, AutogradGradHist2d), combined_q_kwargs)