Esempio n. 1
0
    def init_parameters(
        self, init_fun: Optional[NNInitFunc] = None, *, seed: Optional[PRNGKeyT] = None
    ):
        r"""
        Re-initializes all the parameters with the provided initialization function, defaulting to
        the normal distribution of standard deviation 0.01.

        .. warning::

            The init function will not change the dtype of the parameters, which is determined by the
            model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION

        Args:
            init_fun: a jax initializer such as :func:`netket.nn.initializers.normal`. Must be a Callable
                taking 3 inputs, the jax PRNG key, the shape and the dtype, and outputting an array with
                the valid dtype and shape. If left unspecified, defaults to
                :code:`netket.nn.initializers.normal(stddev=0.01)`
            seed: Optional seed to be used. The seed is synced across all MPI processes. If unspecified, uses
                a random seed.

        Returns:
            None. The freshly initialized pytree is assigned to ``self.parameters``.
        """
        if init_fun is None:
            init_fun = nknn.initializers.normal(stddev=0.01)

        # A single key sequence for the whole pytree: every leaf draws its own
        # fresh key, so leaves are not initialized from identical random streams.
        rng = nkjax.PRNGSeq(nkjax.PRNGKey(seed))

        def new_pars(par):
            # Shape and dtype are taken from the existing leaf. The outer cast
            # guarantees the model-defined dtype is kept even if ``init_fun``
            # returns an array of a different dtype.
            return jnp.asarray(
                init_fun(rng.take(1)[0], shape=par.shape, dtype=par.dtype),
                dtype=par.dtype,
            )

        # ``jax.tree_map`` is deprecated; ``jax.tree_util.tree_map`` is the
        # equivalent, long-available spelling.
        self.parameters = jax.tree_util.tree_map(new_pars, self.parameters)
Esempio n. 2
0
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check matrix-vector products ``S @ y`` of a quantum geometric tensor.

    Verifies that:
      * multiplying by a parameter-shaped pytree preserves every leaf's dtype;
      * multiplying by the raveled (dense) vector gives the same result as the
        pytree product, once unraveled;
      * when running under MPI, the distributed product matches the product
        computed with MPI disabled after gathering every rank's samples.
    """
    rtol, atol = matmul_tol[nk.jax.dtype_real(vstate.model.dtype)]

    S = qgt(vstate)
    rng = nkjax.PRNGSeq(0)
    # Small random vector with the same pytree structure, shapes and dtypes as
    # the variational parameters.
    y = jax.tree_util.tree_map(
        lambda x: 0.001 * jax.random.normal(rng.next(), x.shape, dtype=x.dtype),
        vstate.parameters,
    )
    x = S @ y

    def check_same_dtype(x, y):
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"

    jax.tree_util.tree_map(check_same_dtype, x, y)

    # test multiplication by dense gives same result...
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
        x,
        x_dense_unravelled,
    )

    if _mpi_size > 1:
        # Rebuild the QGT from the union of all ranks' samples with MPI
        # disabled; its product must match the distributed one computed above.
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # NOTE: this overwrites the samples stored on the ``vstate`` fixture.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
                x,
                x_all,
            )