def test_matvec_linear_transpose(e, jit, chunk_size):
    """
    Check that the linear transpose of the on-the-fly mat-vec agrees with
    conj(S conj(w)), for both the plain and the chunked factory.

    Args:
        e: fixture object; the attributes ``f``, ``params``, ``samples`` and
            ``v`` are read from it.
        jit: when truthy, both the mat-vec and its transpose are wrapped in
            ``jax.jit`` before being evaluated.
        chunk_size: ``None`` selects ``mat_vec_factory``; otherwise the
            samples are reshaped to ``(-1, chunk_size, ...)`` and
            ``mat_vec_chunked_factory`` is used.
    """

    def forward(params_model_state, x):
        # The factory hands over a dict; only its "params" entry is used.
        return e.f(params_model_state["params"], x)

    if chunk_size is not None:
        factory = qgt_onthefly_logic.mat_vec_chunked_factory
        chunked_shape = (-1, chunk_size) + e.samples.shape[1:]
        samples = e.samples.reshape(chunked_shape)
    else:
        factory = qgt_onthefly_logic.mat_vec_factory
        samples = e.samples

    matvec = factory(forward, e.params, {}, samples)

    def matvec_transposed(v, w):
        # `matvec` is looked up at call time, so the jitted version is used
        # whenever `jit` is set below.
        transposed = jax.linear_transpose(lambda v_: matvec(v_, 0.0), v)
        (out,) = transposed(w)
        return out

    if jit:
        matvec = jax.jit(matvec)
        matvec_transposed = jax.jit(matvec_transposed)

    w = e.v
    actual = matvec_transposed(e.v, w)

    # Exploit that S is hermitian:
    #   S^T = (O^H O)^T = O^T O* = (O^H O)* = S*
    #   S^T w = S* w = (S w*)*
    w_conj = nkjax.tree_conj(w)
    expected = nkjax.tree_conj(matvec(w_conj, 0.0))

    # Alternative reference, currently disabled:
    # (expected,) = jax.linear_transpose(lambda v_: reassemble_complex(S_real @ tree_toreal_flat(v_), target=e.target), v)(v)

    assert tree_allclose(actual, expected)
def OH_w(forward_fn, params, samples, w):
    r"""
    Evaluate Oᴴw, with ᴴ denoting the hermitian transpose.

    The arguments ``forward_fn``, ``params`` and ``samples`` are passed
    through unchanged to ``O_vjp``.
    """
    # Identity used: Oᴴw = (wᴴO)ᴴ
    # Transposing 1D arrays is a no-op, so only the conjugations remain:
    # (wᴴO)ᴴ -> (w* O)*
    w_conj = w.conjugate()
    w_conj_O = O_vjp(forward_fn, params, samples, w_conj)
    return tree_conj(w_conj_O)
def mat_vec(jvp_fn, v, diag_shift):
    """
    Apply ``jvp_fn`` to ``v``, rescale and centre the result, pull it back
    through the linear transpose of ``jvp_fn`` and add ``diag_shift * v``.

    Args:
        jvp_fn: function of ``v``; it must be linear in ``v`` because it is
            passed to ``jax.linear_transpose``.
        v: pytree the operator is applied to.
        diag_shift: coefficient of the ``diag_shift * v`` term added to the
            result.

    Returns:
        ``res + diag_shift * v``, where ``res`` is the MPI-summed pulled-back
        vector.
    """
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    # NOTE(review): w.size * mpi.n_nodes is presumably the total sample count
    # over all MPI nodes - confirm.
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI

    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # vjp_fn packages output into a length-1 tuple
    (res,) = tree_conj(vjp_fn(w.conjugate()))

    # Use jax.tree_util.tree_map rather than the top-level jax.tree_map alias,
    # which is deprecated and no longer available in recent JAX releases.
    res = jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)

    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v
def OH_w(forward_fn, params, samples, w):
    r"""
    Evaluate O^H w (^H being the hermitian transpose) and pass the result
    through ``tree_cast`` with ``params`` as the target.
    """
    # Identity used: O^H w = (w^H O)^H
    # Transposing 1D arrays is a no-op, so only the conjugations remain:
    # (w^H O)^H -> (w* O)*
    w_conj = w.conjugate()

    # TODO The allreduce in O_vjp could be deferred until after the tree_cast
    # where the amount of data to be transferred would potentially be smaller
    oh_w = tree_conj(O_vjp(forward_fn, params, samples, w_conj))

    return tree_cast(oh_w, params)
def _tree_reassemble_complex(x, target, fun=_tree_to_reim):
    """
    Apply the linear transpose of ``fun`` (taken with respect to ``target``)
    to ``x`` and return the complex conjugate of the result.

    Args:
        x: input fed to the transposed function.
        target: example argument of ``fun``, handed to
            ``jax.linear_transpose``.
        fun: function to transpose; defaults to ``_tree_to_reim``.
    """
    transposed_fun = jax.linear_transpose(fun, target)
    # The transposed function returns one entry per argument of `fun`.
    (reassembled,) = transposed_fun(x)
    return nkjax.tree_conj(reassembled)
def _mat_vec(v: PyTree, oks: PyTree) -> PyTree:
    """
    Evaluate ⟨O† O⟩v = ∑ₗ ⟨Oₖᴴ Oₗ⟩ vₗ

    The result is passed through ``tree_cast`` with ``v`` as the target.
    """
    # Push v forward through oks, conjugate, pull back, conjugate again.
    pushed_conj = _jvp(oks, v).conjugate()
    pulled_back = _vjp(oks, pushed_conj)
    product = tree_conj(pulled_back)
    return tree_cast(product, v)
def OH_w(forward_fn, params, samples, w):
    """
    Return the conjugate of ``O_vjp`` evaluated on the conjugate of ``w``.

    ``forward_fn``, ``params`` and ``samples`` are passed through to
    ``O_vjp`` unchanged.
    """
    w_conj = w.conjugate()
    vjp_out = O_vjp(forward_fn, params, samples, w_conj)
    return tree_conj(vjp_out)
def _mv_trans(extra_args, y):
    """
    Transpose rule for the chunked mat-vec: conj(mat_vec_chunked(conj(y))).

    Args:
        extra_args: triple ``(params, samples, diag_shift)``.
        y: pytree the transposed operator is applied to.
    """
    params, samples, diag_shift = extra_args
    # the linear operator is hermitian, so its transpose can be evaluated by
    # conjugating the input and the output of the forward product
    y_conj = tree_conj(y)
    product = mat_vec_chunked(forward_fn, params, samples, y_conj, diag_shift)
    return tree_conj(product)