def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
    """Check that solving with the QGT yields a consistent solution.

    The linear system ``S x = vstate.parameters`` is solved with ``solver``
    and the result is validated by verifying that ``S @ x`` reproduces
    ``vstate.parameters`` within the tolerances looked up in ``solvers_tol``.

    When ``_mpi_size > 1`` the samples of all ranks are gathered, the QGT is
    rebuilt from the gathered samples inside ``common.netket_disable_mpi()``
    and the resulting solution is compared with the distributed one.
    """
    S = qgt(vstate)
    x, _ = S.solve(solver, vstate.parameters)

    rtol, atol = solvers_tol[solver, nk.jax.dtype_real(vstate.model.dtype)]
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated/removed in recent JAX releases.
    jax.tree_util.tree_map(
        partial(testing.assert_allclose, rtol=rtol, atol=atol),
        S @ x,
        vstate.parameters,
    )

    if _mpi_size > 1:
        # other check
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # Merge the leading rank axis into the first sample axis so the
            # state holds the samples of every rank.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all, _ = S.solve(solver, vstate.parameters)

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(
                    a, b, rtol=rtol, atol=atol
                ),
                x,
                x_all,
            )
def test_qgt_dense(qgt, vstate, _mpi_size, _mpi_rank):
    """Check the rank and shape of the dense form of the QGT.

    The dense matrix must be two-dimensional and square. Its side equals
    ``vstate.n_parameters``, except when the QGT exposes ``mode == "complex"``
    and the model dtype is complex, in which case the side is doubled.

    When ``_mpi_size > 1`` the dense matrix rebuilt from the samples gathered
    over all ranks is compared with the distributed one.
    """
    S = qgt(vstate)
    Sd = S.to_dense()
    assert Sd.ndim == 2

    # Only the complex mode combined with a complex model dtype doubles the size.
    doubled = (
        hasattr(S, "mode")
        and S.mode == "complex"
        and np.issubdtype(vstate.model.dtype, np.complexfloating)
    )
    side = 2 * vstate.n_parameters if doubled else vstate.n_parameters
    assert Sd.shape == (side, side)

    if _mpi_size <= 1:
        return

    # other check
    with common.netket_disable_mpi():
        import mpi4jax

        gathered, _ = mpi4jax.allgather(
            vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
        )
        assert gathered.shape == (_mpi_size, *vstate.samples.shape)
        # Collapse the rank axis into the first sample axis.
        merged_shape = (-1, *vstate.samples.shape[1:])
        vstate._samples = gathered.reshape(merged_shape)

        dense_all = qgt(vstate).to_dense()

        np.testing.assert_allclose(dense_all, Sd, rtol=1e-5, atol=1e-15)
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check that ``S @ y`` is consistent for pytree and raveled inputs.

    ``vstate.parameters`` is multiplied by the QGT both as a pytree and as a
    raveled vector (via ``nk.jax.tree_ravel``); the unraveled dense result
    must match the pytree result.

    When ``_mpi_size > 1`` the product computed from the samples gathered
    over all ranks is compared with the distributed one.
    """
    S = qgt(vstate)
    y = vstate.parameters
    x = S @ y

    # test multiplication by dense gives same result...
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    # jax.tree_multimap was removed from JAX; tree_map accepts several trees.
    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b), x, x_dense_unravelled
    )

    if _mpi_size > 1:
        # other check
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # Merge the leading rank axis into the first sample axis.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b), x, x_all
            )
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check that ``S @ y`` is consistent for pytree and raveled inputs.

    A pytree ``y`` with the structure of ``vstate.parameters`` is filled with
    normal random numbers scaled by ``0.001`` (keys drawn from
    ``nkjax.PRNGSeq(0)``). The test then verifies that:

    * every leaf of ``S @ y`` has the same dtype as the matching leaf of ``y``;
    * multiplying the raveled vector and unraveling the result agrees with
      the pytree product, within the tolerances from ``matmul_tol``;
    * when ``_mpi_size > 1``, the product computed from the samples gathered
      over all ranks agrees with the distributed one.
    """
    rtol, atol = matmul_tol[nk.jax.dtype_real(vstate.model.dtype)]

    S = qgt(vstate)
    rng = nkjax.PRNGSeq(0)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated/removed in recent JAX releases.
    y = jax.tree_util.tree_map(
        lambda x: 0.001 * jax.random.normal(rng.next(), x.shape, dtype=x.dtype),
        vstate.parameters,
    )
    x = S @ y

    def check_same_dtype(x, y):
        assert x.dtype == y.dtype

    jax.tree_util.tree_map(check_same_dtype, x, y)

    # test multiplication by dense gives same result...
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
        x,
        x_dense_unravelled,
    )

    if _mpi_size > 1:
        # other check
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # Merge the leading rank axis into the first sample axis.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(
                    a, b, rtol=rtol, atol=atol
                ),
                x,
                x_all,
            )
def test_qgt_pytree_diag_shift(qgt, vstate):
    """Check that a pytree-valued ``diag_shift`` matches the scalar one.

    The product ``S @ v`` is first computed with the QGT's own
    ``diag_shift``. That value is then broadcast into a pytree of constant
    arrays, installed with ``S.replace`` and the product is recomputed; both
    results must agree leaf by leaf.
    """
    v = vstate.parameters
    S = qgt(vstate)
    expected = S @ v
    diag_shift = S.diag_shift

    if isinstance(S, (QGTJacobianPyTreeT, QGTJacobianDenseT)):
        # extract the necessary shape for the diag_shift
        # eval_shape yields only shapes/dtypes of S.O with its leading axis
        # dropped, without materialising the indexed arrays.
        t = jax.eval_shape(
            partial(jax.tree_util.tree_map, lambda x: x[0], S.O)
        )
    else:
        t = v

    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated/removed in recent JAX releases.
    diag_shift_tree = jax.tree_util.tree_map(
        lambda x: diag_shift * jnp.ones(x.shape, dtype=x.dtype), t
    )
    S = S.replace(diag_shift=diag_shift_tree)
    res = S @ v

    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b), res, expected
    )
def test_qgt_solve_with_x0(qgt, vstate):
    """Check that ``S.solve`` accepts an initial guess through ``x0``.

    GMRES from ``jax.scipy.sparse.linalg`` is used as the solver, with an
    all-zero pytree shaped like ``vstate.parameters`` as the starting point.
    No value is asserted: the test passes as long as the call does not raise
    and returns a pair.
    """
    solver = jax.scipy.sparse.linalg.gmres
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated/removed in recent JAX releases.
    x0 = jax.tree_util.tree_map(jnp.zeros_like, vstate.parameters)

    S = qgt(vstate)
    x, _ = S.solve(solver, vstate.parameters, x0=x0)
def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
    """Smoke-test ``solve`` on the QGT built from ``vstate``.

    The system is solved against ``vstate.parameters`` with ``solver``. The
    result is only unpacked as a pair; no numerical check is performed here.
    """
    outcome = qgt(vstate).solve(solver, vstate.parameters)
    # Unpacking enforces that solve returns exactly two items.
    solution, _info = outcome