def test_matvec_linear_transpose(e, jit):
    """Check the linear transpose of the on-the-fly S matvec against conj(S conj(w)).

    S = O^H O is hermitian, therefore
        S^T = (O^H O)^T = O^T O* = (O^H O)* = S*
    and so
        S^T w = S* w = (S w*)*.

    Args:
        e: test fixture providing ``v``, ``f``, ``params`` and ``samples``.
        jit: when truthy, the transposed matvec is wrapped in ``jax.jit``.
    """

    def transposed_matvec(v, f, params, samples, w):
        # Linear map v_ -> S v_. The trailing 0.0 is forwarded exactly as in the
        # reference call below (presumably a diagonal shift — TODO confirm).
        def apply_s(vec):
            return qgt_onthefly_logic.mat_vec(vec, f, params, samples, 0.0)

        transpose_fn = jax.linear_transpose(apply_s, v)
        # linear_transpose returns one cotangent per primal input; there is one.
        (cotangent,) = transpose_fn(w)
        return cotangent

    if jit:
        # static_argnums=1 marks `f` as a static (non-traced) argument.
        transposed_matvec = jax.jit(transposed_matvec, static_argnums=1)

    w = e.v
    actual = transposed_matvec(e.v, e.f, e.params, e.samples, w)

    # Reference value: (S w*)*
    s_times_conj_w = qgt_onthefly_logic.mat_vec(
        qgt_onthefly_logic.tree_conj(w),
        e.f,
        e.params,
        e.samples,
        0.0,
    )
    expected = qgt_onthefly_logic.tree_conj(s_times_conj_w)

    assert tree_allclose(actual, expected)
def test_vjp(e):
    """Compare ``O_vjp`` against a reference built from ``e.w @ e.ok_real``.

    The reference is the real part of ``e.w @ e.ok_real``, cast to the dtype of
    ``e.params_real_flat``. It is mapped back onto the structure of ``e.target``
    via ``reassemble_complex`` and then conjugated. (``ok_real`` is presumably
    the dense O matrix in real-flat form — confirm against the fixture.)
    """
    actual = qgt_onthefly_logic.O_vjp(e.f, e.params, e.samples, e.w)

    dense_product = e.w @ e.ok_real
    real_flat = dense_product.real.astype(e.params_real_flat.dtype)
    reference_tree = reassemble_complex(real_flat, target=e.target)
    expected = qgt_onthefly_logic.tree_conj(reference_tree)

    assert tree_allclose(actual, expected)
def reassemble_complex(x, target, fun=tree_toreal_flat):
    """Map a flat real vector ``x`` back onto the pytree layout of ``target``.

    Computes the linear transpose of ``fun`` at ``x``, conjugates the result,
    and casts it to the dtypes of ``target``.

    Args:
        x: input to the transposed map (the output space of ``fun``).
        target: a tree with the expected shape and types of the result. It is
            used both as the primal for ``jax.linear_transpose`` and as the
            dtype template for the final cast.
        fun: linear map from trees shaped like ``target`` to the space of ``x``.

    Returns:
        A tree shaped like ``target`` with ``target``'s dtypes.
    """
    transpose_fn = jax.linear_transpose(fun, target)
    # One primal input, hence exactly one cotangent in the returned tuple.
    (tree,) = transpose_fn(x)
    # NOTE(review): the conjugation presumably compensates for how the
    # transpose of a real/imag split maps back onto complex leaves — confirm.
    tree = qgt_onthefly_logic.tree_conj(tree)
    # Restore the dtypes of `target`.
    return qgt_onthefly_logic.tree_cast(tree, target)
def test_mean(e):
    """Compare ``O_mean`` against a reference built from ``e.okmean_real``.

    The reference takes the real part of ``e.okmean_real``, maps it back onto
    the structure of ``e.target`` via ``reassemble_complex``, and conjugates it.
    """
    actual = qgt_onthefly_logic.O_mean(e.f, e.params, e.samples)

    mean_real_part = e.okmean_real.real
    expected = qgt_onthefly_logic.tree_conj(
        reassemble_complex(mean_real_part, target=e.target)
    )

    assert tree_allclose(actual, expected)