示例#1
0
def test_matvec_linear_transpose(e, jit):
    """Check the linear transpose of ``mat_vec`` against a conjugation identity.

    ``mvt`` applies ``jax.linear_transpose`` to the linear map
    ``v_ -> S v_`` computed by ``qgt_onthefly_logic.mat_vec`` (called with a
    trailing ``0.0``, presumably a diagonal shift — TODO confirm), around the
    primal point ``v``, and evaluates the transpose at ``w``.
    The expected value uses ``S^T w = (S w*)*``, which holds because S is
    hermitian.

    Args:
        e: object (presumably a pytest fixture) providing ``v``, ``f``,
            ``params`` and ``samples``.
        jit: if truthy, the transpose helper is wrapped in ``jax.jit`` with
            argument index 1 (``f``) marked static.
    """

    def mvt(v, f, params, samples, w):
        # Build the transpose of v_ -> S v_ (linearized at v) and apply it to w.
        # linear_transpose returns one cotangent per primal; there is one here.
        (res,) = jax.linear_transpose(
            lambda v_: qgt_onthefly_logic.mat_vec(
                v_,
                f,
                params,
                samples,
                0.0,
            ),
            v,
        )(w)
        return res

    if jit:
        # f is marked static, presumably because it is a non-array callable.
        mvt = jax.jit(mvt, static_argnums=1)

    w = e.v
    actual = mvt(e.v, e.f, e.params, e.samples, w)

    # use that S is hermitian:
    # S^T = (O^H O)^T = O^T O* = (O^H O)* = S*
    # S^T w = S* w = (S w*)*
    expected = qgt_onthefly_logic.tree_conj(
        qgt_onthefly_logic.mat_vec(
            qgt_onthefly_logic.tree_conj(w),
            e.f,
            e.params,
            e.samples,
            0.0,
        )
    )
    assert tree_allclose(actual, expected)
示例#2
0
def test_vjp(e):
    """Compare ``O_vjp`` against a reference assembled from ``e.ok_real``.

    The reference contracts ``e.w`` with ``e.ok_real``, keeps the real part
    cast to the dtype of ``e.params_real_flat``, reassembles it into the
    structure of ``e.target`` and conjugates the resulting tree.
    """
    actual = qgt_onthefly_logic.O_vjp(e.f, e.params, e.samples, e.w)

    contracted = e.w @ e.ok_real
    flat_reference = contracted.real.astype(e.params_real_flat.dtype)
    reference_tree = reassemble_complex(flat_reference, target=e.target)
    expected = qgt_onthefly_logic.tree_conj(reference_tree)

    assert tree_allclose(actual, expected)
示例#3
0
def reassemble_complex(x, target, fun=tree_toreal_flat):
    """Map ``x`` back to a pytree shaped like ``target`` via the transpose of ``fun``.

    Args:
        x: value in the output space of ``fun`` (for the default
            ``tree_toreal_flat``, presumably a flat real vector).
        target: pytree with the expected structure and dtypes of the result;
            also the primal point at which ``fun`` is transposed.
        fun: linear map from a ``target``-shaped pytree to the space of ``x``.

    Returns:
        The transpose of ``fun`` applied to ``x``, conjugated and cast to the
        dtypes of ``target``.
    """
    transpose_of_fun = jax.linear_transpose(fun, target)
    # One primal (target) -> a 1-tuple of cotangents.
    (tree,) = transpose_of_fun(x)
    conjugated = qgt_onthefly_logic.tree_conj(tree)
    # The transposed result may not carry target's dtypes, so cast it back.
    return qgt_onthefly_logic.tree_cast(conjugated, target)
示例#4
0
def test_mean(e):
    """Compare ``O_mean`` with the conjugated reassembly of ``e.okmean_real``."""
    actual = qgt_onthefly_logic.O_mean(e.f, e.params, e.samples)

    flat_reference = e.okmean_real.real
    reference_tree = reassemble_complex(flat_reference, target=e.target)
    expected = qgt_onthefly_logic.tree_conj(reference_tree)

    assert tree_allclose(actual, expected)