def test_scale_invariant_regularization(e, outdtype, pardtype):
    """Check ``_mat_vec`` against a dense reference using a rescaled Jacobian.

    The centered Jacobian of ``e.f`` is computed, normalised with
    ``_divide_by_sqrt_n_samp`` and passed through ``_rescale``. The rescaled
    Jacobian is fed to ``_mat_vec`` together with ``e.v``, and the result is
    compared (via ``tree_allclose``) with ``e.S_real_scaled @ e.v_real_flat``
    reassembled to the structure of ``e.target``.

    Args:
        e: test-case fixture; this test reads its ``f``, ``params``,
            ``samples``, ``v``, ``S_real_scaled``, ``v_real_flat`` and
            ``target`` attributes.
        outdtype: output dtype used to pick the Jacobian implementation.
        pardtype: parameter dtype used to pick the Jacobian implementation.
    """
    # Real parameters combined with a complex output use the ``_cplx``
    # Jacobian; every other dtype combination uses ``_real_holo``.
    if not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype):
        centered_jacobian_fun = qgt_jacobian_pytree_logic.centered_jacobian_cplx
    else:
        centered_jacobian_fun = qgt_jacobian_pytree_logic.centered_jacobian_real_holo
    mv = qgt_jacobian_pytree_logic._mat_vec

    centered_oks = centered_jacobian_fun(e.f, e.params, e.samples)
    centered_oks = qgt_jacobian_pytree_logic._divide_by_sqrt_n_samp(
        centered_oks, e.samples
    )
    # Only the rescaled Jacobian is exercised here. The scale factor that
    # ``_rescale`` also returns is not needed for this comparison, so it is
    # discarded instead of being bound to an unused local.
    centered_oks_scaled, _ = qgt_jacobian_pytree_logic._rescale(centered_oks)

    actual = mv(e.v, centered_oks_scaled)
    expected = reassemble_complex(e.S_real_scaled @ e.v_real_flat, target=e.target)
    assert tree_allclose(actual, expected)
def test_matvec_treemv(e, jit, holomorphic, pardtype, outdtype, chunk_size):
    """Compare the pytree ``_mat_vec`` product with a dense reference.

    A centered Jacobian is built with the given ``chunk_size`` and normalised
    with ``_divide_by_sqrt_n_samp``. It is then contracted with ``e.v`` through
    ``_mat_vec``, optionally under ``jax.jit``. The result must match
    ``e.S_real @ e.v_real_flat`` reassembled to the structure of ``e.target``.

    Args:
        e: test-case fixture; this test reads its ``f``, ``params``,
            ``samples``, ``v``, ``S_real``, ``v_real_flat`` and ``target``
            attributes.
        jit: when truthy, both the Jacobian function and ``_mat_vec`` are
            wrapped with ``jax.jit``.
        holomorphic: accepted as a test parameter; not read by this body.
        pardtype: parameter dtype used to pick the Jacobian implementation.
        outdtype: output dtype used to pick the Jacobian implementation.
        chunk_size: forwarded to the Jacobian function as ``chunk_size``.
    """
    logic = qgt_jacobian_pytree_logic

    # Real parameters with a complex output take the ``_cplx`` Jacobian;
    # any other dtype combination takes ``_real_holo``.
    jacobian_impl = (
        logic.centered_jacobian_cplx
        if not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype)
        else logic.centered_jacobian_real_holo
    )
    jacobian = partial(jacobian_impl, chunk_size=chunk_size)
    mat_vec = logic._mat_vec

    if jit:
        # The model function (first argument) must be static under jit.
        jacobian = jax.jit(jacobian, static_argnums=0)
        mat_vec = jax.jit(mat_vec)

    normalised_oks = logic._divide_by_sqrt_n_samp(
        jacobian(e.f, e.params, e.samples), e.samples
    )

    result = mat_vec(e.v, normalised_oks)
    reference = reassemble_complex(e.S_real @ e.v_real_flat, target=e.target)
    assert tree_allclose(result, reference)