def init_parameters(
    self, init_fun: Optional[NNInitFunc] = None, *, seed: Optional[PRNGKeyT] = None
):
    r"""
    Re-initializes all the parameters with the provided initialization function,
    defaulting to the normal distribution of standard deviation 0.01.

    .. warning::

        The init function will not change the dtype of the parameters, which is
        determined by the model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION

    Args:
        init_fun: a jax initializer such as :func:`netket.nn.initializers.normal`.
            Must be a Callable taking 3 inputs, the jax PRNG key, the shape and
            the dtype, and outputting an array with the valid dtype and shape.
            If left unspecified, defaults to
            :code:`netket.nn.initializers.normal(stddev=0.01)`
        seed: Optional seed to be used. The seed is synced across all MPI
            processes. If unspecified, uses a random seed.
    """
    if init_fun is None:
        init_fun = nknn.initializers.normal(stddev=0.01)

    # One key sequence shared by the whole tree: every leaf consumes exactly one
    # fresh key, in tree-traversal order.
    rng = nkjax.PRNGSeq(nkjax.PRNGKey(seed))

    def new_pars(par):
        # Sample with the leaf's own shape and dtype, then cast again so the
        # parameter dtype is preserved even if the initializer ignores `dtype`.
        return jnp.asarray(
            init_fun(rng.take(1)[0], shape=par.shape, dtype=par.dtype),
            dtype=par.dtype,
        )

    # `jax.tree_map` is deprecated (removed in JAX 0.6); `jax.tree_util.tree_map`
    # is the equivalent, long-standing spelling.
    self.parameters = jax.tree_util.tree_map(new_pars, self.parameters)
def test_qgt_matmul(qgt, vstate, _mpi_size, _mpi_rank):
    """Check the QGT matrix-vector product ``S @ y``.

    Verifies that:
      * multiplying by a pytree ``y`` keeps each leaf's dtype;
      * multiplying by the raveled (dense) ``y`` gives the same result once
        unraveled;
      * with more than one MPI rank, rebuilding the QGT from the samples of all
        ranks (gathered onto every rank) reproduces the same product.
    """
    rtol, atol = matmul_tol[nk.jax.dtype_real(vstate.model.dtype)]

    S = qgt(vstate)
    rng = nkjax.PRNGSeq(0)
    # Small random vector with the same pytree structure/dtypes as the parameters.
    y = jax.tree_util.tree_map(
        lambda x: 0.001 * jax.random.normal(rng.next(), x.shape, dtype=x.dtype),
        vstate.parameters,
    )
    x = S @ y

    def check_same_dtype(x, y):
        assert x.dtype == y.dtype

    jax.tree_util.tree_map(check_same_dtype, x, y)

    # test multiplication by dense gives same result...
    y_dense, unravel = nk.jax.tree_ravel(y)
    x_dense = S @ y_dense
    x_dense_unravelled = unravel(x_dense)

    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
        x,
        x_dense_unravelled,
    )

    if _mpi_size > 1:
        # Gather all ranks' samples onto every rank, rebuild the QGT from the
        # full sample set inside `common.netket_disable_mpi()` (presumably so no
        # extra cross-rank reduction happens), and compare with the product
        # computed above.
        with common.netket_disable_mpi():
            import mpi4jax

            samples, _ = mpi4jax.allgather(
                vstate.samples, comm=nk.utils.mpi.MPI_jax_comm
            )
            assert samples.shape == (_mpi_size, *vstate.samples.shape)
            # Merge the rank axis with the leading sample axis.
            vstate._samples = samples.reshape((-1, *vstate.samples.shape[1:]))

            S = qgt(vstate)
            x_all = S @ y

            jax.tree_util.tree_map(
                lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
                x,
                x_all,
            )