Пример #1
0
def test_scale_invariant_regularization(e, outdtype, pardtype):
    """Check the mat-vec product computed with the rescaled centered Jacobian.

    The centered Jacobian of ``e.f`` over ``e.samples`` is divided by the square
    root of the number of samples (``_divide_by_sqrt_n_samp``) and then passed
    through ``_rescale``. Contracting the rescaled Jacobian with ``e.v`` via
    ``_mat_vec`` must agree with the reference product
    ``e.S_real_scaled @ e.v_real_flat``. That reference is reassembled by
    ``reassemble_complex`` using ``e.target`` as the target.

    Args:
        e: test fixture providing ``f``, ``params``, ``samples``, ``v``,
            ``v_real_flat``, ``S_real_scaled`` and ``target``.
        outdtype: dtype of the model output.
        pardtype: dtype of the model parameters.
    """
    # Real parameters with a complex-valued output use the ``_cplx`` Jacobian.
    # Every other dtype combination uses ``_real_holo``.
    if not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype):
        centered_jacobian_fun = qgt_jacobian_pytree_logic.centered_jacobian_cplx
    else:
        centered_jacobian_fun = qgt_jacobian_pytree_logic.centered_jacobian_real_holo

    mv = qgt_jacobian_pytree_logic._mat_vec
    centered_oks = centered_jacobian_fun(e.f, e.params, e.samples)
    centered_oks = qgt_jacobian_pytree_logic._divide_by_sqrt_n_samp(
        centered_oks, e.samples
    )

    # Only the rescaled Jacobian takes part in the comparison. The second value
    # returned by ``_rescale`` (the scale) is not needed, so discard it explicitly.
    centered_oks_scaled, _ = qgt_jacobian_pytree_logic._rescale(centered_oks)
    actual = mv(e.v, centered_oks_scaled)
    expected = reassemble_complex(e.S_real_scaled @ e.v_real_flat, target=e.target)
    assert tree_allclose(actual, expected)
Пример #2
0
def test_matvec_treemv(e, jit, holomorphic, pardtype, outdtype, chunk_size):
    """Compare the pytree Jacobian mat-vec product with a reference product.

    Computes the centered Jacobian of ``e.f`` over ``e.samples``, passing
    ``chunk_size`` to it. When ``jit`` is true, the Jacobian function and
    ``_mat_vec`` are both wrapped in ``jax.jit``. The Jacobian is divided by the
    square root of the number of samples and contracted with ``e.v``. The result
    must agree with ``e.S_real @ e.v_real_flat``, reassembled by
    ``reassemble_complex`` using ``e.target`` as the target.

    Args:
        e: test fixture providing ``f``, ``params``, ``samples``, ``v``,
            ``v_real_flat``, ``S_real`` and ``target``.
        jit: whether to jit-compile the Jacobian and mat-vec functions.
        holomorphic: accepted from the parametrization; not read in this body.
        pardtype: dtype of the model parameters.
        outdtype: dtype of the model output.
        chunk_size: forwarded to the centered-Jacobian function.
    """
    logic = qgt_jacobian_pytree_logic
    mat_vec = logic._mat_vec

    # Real parameters with a complex-valued output use the ``_cplx`` Jacobian.
    # Every other dtype combination uses ``_real_holo``.
    use_cplx = not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype)
    base_jacobian = (
        logic.centered_jacobian_cplx if use_cplx else logic.centered_jacobian_real_holo
    )
    jacobian = partial(base_jacobian, chunk_size=chunk_size)

    if jit:
        mat_vec = jax.jit(mat_vec)
        # Argument 0 (``e.f``) is marked static for the compiled Jacobian.
        jacobian = jax.jit(jacobian, static_argnums=0)

    normalized_oks = logic._divide_by_sqrt_n_samp(
        jacobian(e.f, e.params, e.samples), e.samples
    )
    actual = mat_vec(e.v, normalized_oks)
    reference = reassemble_complex(e.S_real @ e.v_real_flat, target=e.target)
    assert tree_allclose(actual, reference)