def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
    """Check that ``S.solve`` inverts the QGT and agrees with a gathered solve.

    Solves ``S x = p`` with ``p = vstate.parameters`` and verifies that
    ``S @ x`` reproduces ``p`` within the tolerances stored in
    ``solvers_tol`` for this solver and real dtype. When more than one MPI
    rank is running, the samples of all ranks are gathered, the QGT is
    rebuilt with MPI disabled, and its solution must match the one computed
    above.
    """
    S = qgt(vstate)
    x, _ = S.solve(solver, vstate.parameters)

    rtol, atol = solvers_tol[solver, nk.jax.dtype_real(vstate.model.dtype)]

    # ``jax.tree_map`` is deprecated (and removed in newer JAX releases);
    # ``jax.tree_util.tree_map`` is available in every JAX version.
    jax.tree_util.tree_map(
        partial(testing.assert_allclose, rtol=rtol, atol=atol),
        S @ x,
        vstate.parameters,
    )

    if _mpi_size > 1:
        # Rebuild the QGT from the samples of *all* ranks on a single process
        # and check that it yields the same solution.
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            # allgather stacks the per-rank sample arrays along a new axis 0.
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all, _ = S.solve(solver, vstate.parameters)

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
                x,
                x_all,
            )
def mpi_allgather_jax(x, *, token=None, comm=MPI_jax_comm):
    """Gather ``x`` from every MPI rank, stacked along a new leading axis.

    Args:
        x: the array contributed by this rank.
        token: optional mpi4jax token, forwarded to ``mpi4jax.allgather``
            (returned unchanged when there is a single rank).
        comm: the MPI communicator to gather over.

    Returns:
        A tuple ``(gathered, token)``. ``gathered`` has shape
        ``(n_ranks, *x.shape)`` on both the single- and multi-rank paths,
        so callers do not need to special-case a single process.
    """
    if n_nodes == 1:
        # mpi4jax.allgather stacks the per-rank values along a new axis 0;
        # with one rank that axis has length 1. Previously ``x`` was returned
        # as-is, which gave a different rank than the multi-process path.
        import jax.numpy as jnp

        return jnp.expand_dims(x, 0), token

    import mpi4jax

    return mpi4jax.allgather(x, token=token, comm=comm)
def test_qgt_dense(qgt, vstate, _mpi_size, _mpi_rank):
    """Check the shape of the dense QGT and its consistency across MPI ranks.

    ``to_dense()`` must produce a square 2-D matrix. Its side is twice the
    number of parameters when the QGT has ``mode == "complex"`` and the
    model dtype is complex; otherwise it equals the number of parameters.
    With more than one MPI rank, a QGT rebuilt (with MPI disabled) from the
    gathered samples of all ranks must give the same dense matrix.
    """
    qgt_op = qgt(vstate)
    dense = qgt_op.to_dense()

    assert dense.ndim == 2

    n_params = vstate.n_parameters
    doubled = getattr(qgt_op, "mode", None) == "complex" and np.issubdtype(
        vstate.model.dtype, np.complexfloating
    )
    side = 2 * n_params if doubled else n_params
    assert dense.shape == (side, side)

    if _mpi_size <= 1:
        return

    # Gather every rank's samples and recompute the dense QGT on one process.
    with common.netket_disable_mpi():
        import mpi4jax

        gathered, _ = mpi4jax.allgather(
            vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
        )
        assert gathered.shape == (_mpi_size, *vstate.samples.shape)
        vstate._samples = gathered.reshape((-1, *vstate.samples.shape[1:]))

        dense_all = qgt(vstate).to_dense()
        np.testing.assert_allclose(dense_all, dense, rtol=1e-5, atol=1e-15)
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check QGT-vector products for pytree and flattened vectors, and under MPI.

    ``S @ y`` with ``y = vstate.parameters`` must equal the unravelled result
    of multiplying ``S`` by the flattened ``y``. With more than one MPI rank,
    a QGT rebuilt (with MPI disabled) from the gathered samples of all ranks
    must give the same product.
    """
    S = qgt(vstate)
    y = vstate.parameters
    x = S @ y

    # Multiplying by the flattened (dense) vector must give the same result.
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    # ``jax.tree_multimap`` no longer exists in JAX; ``jax.tree_util.tree_map``
    # accepts several trees and maps over them leaf-by-leaf.
    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b), x, x_dense_unravelled
    )

    if _mpi_size > 1:
        # Rebuild the QGT from the samples of *all* ranks on a single process
        # and check that it yields the same product.
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            # allgather stacks the per-rank sample arrays along a new axis 0.
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b), x, x_all
            )
def test_allgather_jit():
    """Jitted ``allgather`` returns every rank's (3, 2) array, indexed by rank.

    Each rank contributes ``ones((3, 2)) * rank``; row ``p`` of the gathered
    result must therefore equal ``ones((3, 2)) * p``.
    """
    from mpi4jax import allgather

    def gather_values(value):
        # Keep only the gathered array; the second returned item is discarded.
        return allgather(value)[0]

    local = jnp.ones((3, 2)) * rank
    gathered = jax.jit(gather_values)(local)

    for source_rank in range(size):
        expected = jnp.ones((3, 2)) * source_rank
        assert jnp.array_equal(gathered[source_rank], expected)
def test_allgather():
    """``allgather`` stacks the (3, 2) array contributed by each rank on axis 0.

    Each rank sends ``ones((3, 2)) * rank``, so entry ``p`` of the gathered
    array must equal ``ones((3, 2)) * p``.
    """
    from mpi4jax import allgather

    gathered, _token = allgather(jnp.ones((3, 2)) * rank)

    assert all(
        jnp.array_equal(gathered[source_rank], jnp.ones((3, 2)) * source_rank)
        for source_rank in range(size)
    )
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check QGT-vector products: dtype preservation, dense agreement, MPI.

    A small random vector ``y`` with the same structure and dtypes as the
    parameters is multiplied by the QGT. The test checks the following:

    * every leaf of ``S @ y`` keeps the dtype of the matching leaf of ``y``;
    * multiplying by the flattened ``y`` gives the same result once
      unravelled;
    * with more than one MPI rank, a QGT rebuilt (with MPI disabled) from the
      gathered samples of all ranks gives the same product.

    Tolerances are looked up in ``matmul_tol`` by the model's real dtype.
    """
    rtol, atol = matmul_tol[nk.jax.dtype_real(vstate.model.dtype)]

    S = qgt(vstate)
    rng = nkjax.PRNGSeq(0)
    # ``jax.tree_map`` is deprecated (and removed in newer JAX releases);
    # ``jax.tree_util.tree_map`` is available in every JAX version.
    y = jax.tree_util.tree_map(
        lambda x: 0.001 * jax.random.normal(rng.next(), x.shape, dtype=x.dtype),
        vstate.parameters,
    )
    x = S @ y

    def check_same_dtype(x, y):
        assert x.dtype == y.dtype

    jax.tree_util.tree_map(check_same_dtype, x, y)

    # Multiplying by the flattened (dense) vector must give the same result.
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
        x,
        x_dense_unravelled,
    )

    if _mpi_size > 1:
        # Rebuild the QGT from the samples of *all* ranks on a single process
        # and check that it yields the same product.
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            # allgather stacks the per-rank sample arrays along a new axis 0.
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
                x,
                x_all,
            )
def test_allgather_scalar_jit():
    """Jitted ``allgather`` of each rank's scalar ``rank`` yields ``arange(size)``."""
    from mpi4jax import allgather

    def gather_values(value):
        # Keep only the gathered array; the second returned item is discarded.
        return allgather(value)[0]

    gathered = jax.jit(gather_values)(rank)
    expected = jnp.arange(size)
    assert jnp.array_equal(gathered, expected)
def test_allgather_scalar():
    """``allgather`` of each rank's scalar ``rank`` yields ``arange(size)``."""
    from mpi4jax import allgather

    gathered, _token = allgather(rank)
    expected = jnp.arange(size)
    assert jnp.array_equal(gathered, expected)