Example #1
0
def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
    """Check that ``S.solve`` inverts the QGT, and that it is MPI-consistent.

    The system ``S x = vstate.parameters`` is solved with ``solver``.
    ``S @ x`` must then reproduce the parameters within the tolerances that
    ``solvers_tol`` lists for this solver and real dtype.

    With more than one MPI rank, the following extra check runs:
    - the samples of every rank are gathered;
    - the QGT is rebuilt on the full sample set with MPI disabled;
    - the new solution must match the distributed one.
    Note that this branch overwrites ``vstate._samples``.
    """
    qgt_op = qgt(vstate)
    solution, _ = qgt_op.solve(solver, vstate.parameters)

    rtol, atol = solvers_tol[solver, nk.jax.dtype_real(vstate.model.dtype)]
    assert_close = partial(testing.assert_allclose, rtol=rtol, atol=atol)
    jax.tree_map(assert_close, qgt_op @ solution, vstate.parameters)

    if _mpi_size <= 1:
        return

    # Single-process reference: gather all ranks' samples and solve again.
    with common.netket_disable_mpi():
        import mpi4jax

        gathered, _ = mpi4jax.allgather(
            vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
        )
        assert gathered.shape == (_mpi_size, *vstate.samples.shape)
        # Fold the rank axis into the leading sample axis.
        vstate._samples = gathered.reshape((-1, *vstate.samples.shape[1:]))

        full_solution, _ = qgt(vstate).solve(solver, vstate.parameters)

        jax.tree_map(
            lambda lhs, rhs: np.testing.assert_allclose(
                lhs, rhs, rtol=rtol, atol=atol
            ),
            solution,
            full_solution,
        )
Example #2
0
def test_qgt_dense(qgt, vstate, _mpi_size, _mpi_rank):
    """Check the shape of ``S.to_dense()`` and that it is MPI-consistent.

    The dense QGT must be a square 2-D array whose side is:
    - ``2 * vstate.n_parameters`` when the QGT has ``mode == "complex"`` and
      the model dtype is complex;
    - ``vstate.n_parameters`` otherwise.

    With more than one MPI rank, the following extra check runs:
    - the samples of every rank are gathered;
    - the QGT is rebuilt on the full sample set with MPI disabled;
    - its dense form must match the distributed one.
    Note that this branch overwrites ``vstate._samples``.
    """
    S = qgt(vstate)
    dense = S.to_dense()
    assert dense.ndim == 2

    # QGTs without a ``mode`` attribute always use the plain size.
    doubled = getattr(S, "mode", None) == "complex" and np.issubdtype(
        vstate.model.dtype, np.complexfloating
    )
    n_params = vstate.n_parameters
    side = 2 * n_params if doubled else n_params
    assert dense.shape == (side, side)

    if _mpi_size <= 1:
        return

    # Single-process reference built from the samples of every rank.
    with common.netket_disable_mpi():
        import mpi4jax

        gathered, _ = mpi4jax.allgather(
            vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
        )
        assert gathered.shape == (_mpi_size, *vstate.samples.shape)
        # Fold the rank axis into the leading sample axis.
        vstate._samples = gathered.reshape((-1, *vstate.samples.shape[1:]))

        dense_all = qgt(vstate).to_dense()
        np.testing.assert_allclose(dense_all, dense, rtol=1e-5, atol=1e-15)
Example #3
0
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check that ``S @ y`` agrees for pytree and ravelled (dense) vectors.

    ``y`` is ``vstate.parameters``. The product with the pytree is compared
    leaf by leaf against the product with the flattened vector, after that
    product is unravelled.

    With more than one MPI rank, the following extra check runs:
    - the samples of every rank are gathered;
    - the QGT is rebuilt on the full sample set with MPI disabled;
    - ``S @ y`` must match the distributed result.
    Note that this branch overwrites ``vstate._samples``.
    """
    S = qgt(vstate)
    y = vstate.parameters
    x = S @ y

    def assert_close(a, b):
        np.testing.assert_allclose(a, b)

    # test multiplication by dense gives same result...
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    # jax.tree_multimap was removed from JAX. jax.tree_util.tree_map accepts
    # several trees and is available on both old and new JAX versions.
    jax.tree_util.tree_map(assert_close, x, x_dense_unravelled)

    if _mpi_size > 1:
        # other check: compare against a single-process QGT on all samples
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # Fold the rank axis into the leading sample axis.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(assert_close, x, x_all)
Example #4
0
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check ``S @ y`` for a small random pytree ``y`` shaped like the parameters.

    ``y`` is built with seed 0 and scaled by 0.001. Three checks run:
    - every leaf of ``S @ y`` keeps the dtype of the matching leaf of ``y``;
    - the pytree product matches the product with the ravelled ``y`` (after
      unravelling), within ``matmul_tol`` tolerances for the model's real
      dtype;
    - with more than one MPI rank, a QGT rebuilt on all gathered samples
      (with MPI disabled) gives the same product.
    Note that the last check overwrites ``vstate._samples``.
    """
    rtol, atol = matmul_tol[nk.jax.dtype_real(vstate.model.dtype)]

    def assert_close(lhs, rhs):
        np.testing.assert_allclose(lhs, rhs, rtol=rtol, atol=atol)

    S = qgt(vstate)
    keys = nkjax.PRNGSeq(0)

    def small_noise(leaf):
        # One fresh key per leaf, drawn in tree-traversal order.
        return 0.001 * jax.random.normal(keys.next(), leaf.shape, dtype=leaf.dtype)

    y = jax.tree_map(small_noise, vstate.parameters)
    product = S @ y

    def check_same_dtype(lhs, rhs):
        assert lhs.dtype == rhs.dtype

    jax.tree_map(check_same_dtype, product, y)

    # The ravelled (dense-vector) path must agree with the pytree path.
    y_flat, unravel = nk.jax.tree_ravel(y)
    flat_product = S @ y_flat
    jax.tree_map(assert_close, product, unravel(flat_product))

    if _mpi_size <= 1:
        return

    with common.netket_disable_mpi():
        import mpi4jax

        gathered, _ = mpi4jax.allgather(
            vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
        )
        assert gathered.shape == (_mpi_size, *vstate.samples.shape)
        # Fold the rank axis into the leading sample axis.
        vstate._samples = gathered.reshape((-1, *vstate.samples.shape[1:]))

        full_product = qgt(vstate) @ y
        jax.tree_map(assert_close, product, full_product)
Example #5
0
def test_qgt_pytree_diag_shift(qgt, vstate):
    """Check that a pytree-valued ``diag_shift`` matches the scalar one.

    The scalar ``S.diag_shift`` is broadcast into a pytree of constant arrays.
    For the Jacobian-based QGTs, the pytree is shaped like one row of
    ``S.O``; otherwise it is shaped like the parameters. ``S @ v`` with that
    pytree shift must equal the product with the original scalar shift.
    """
    params = vstate.parameters
    qgt_op = qgt(vstate)
    reference = qgt_op @ params
    shift = qgt_op.diag_shift

    is_jacobian = isinstance(qgt_op, (QGTJacobianPyTreeT, QGTJacobianDenseT))
    if not is_jacobian:
        template = params
    else:
        # Only the shape/dtype of the first row of each O leaf is needed.
        template = jax.eval_shape(
            partial(jax.tree_map, lambda leaf: leaf[0], qgt_op.O)
        )

    shift_tree = jax.tree_map(
        lambda leaf: shift * jnp.ones(leaf.shape, dtype=leaf.dtype), template
    )
    shifted_op = qgt_op.replace(diag_shift=shift_tree)

    jax.tree_map(np.testing.assert_allclose, shifted_op @ params, reference)
Example #6
0
def test_qgt_solve_with_x0(qgt, vstate):
    """Check that ``S.solve`` accepts an explicit initial guess ``x0``.

    The system ``S x = vstate.parameters`` is solved with
    ``jax.scipy.sparse.linalg.gmres``, starting from an all-zeros pytree
    shaped like the parameters. The solution's pytree structure and leaf
    shapes are then compared against the parameters, so the test checks more
    than "does not raise".
    """
    solver = jax.scipy.sparse.linalg.gmres
    params = vstate.parameters
    x0 = jax.tree_util.tree_map(jnp.zeros_like, params)

    S = qgt(vstate)
    x, _ = S.solve(solver, params, x0=x0)

    assert jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(params)

    def check_same_shape(sol_leaf, param_leaf):
        assert sol_leaf.shape == param_leaf.shape

    jax.tree_util.tree_map(check_same_shape, x, params)
Example #7
0
def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
    """Solve the QGT system ``S x = vstate.parameters`` with ``solver``.

    Only the beginning of this test is visible in this excerpt.
    """
    # NOTE(review): this definition may continue past the visible lines, and
    # the checks on ``x`` (and any use of ``_mpi_size``/``_mpi_rank``) are
    # presumably there; confirm. Another ``test_qgt_solve`` appears earlier
    # in this listing. If both are in the same module, this later one
    # shadows the earlier one.
    S = qgt(vstate)

    x, _ = S.solve(solver, vstate.parameters)