import jax from jax import numpy as jnp from flax import linen as nn from netket import nn as nknn from netket.nn.initializers import lecun_normal, variance_scaling, zeros from netket.hilbert import AbstractHilbert from netket.graph import AbstractGraph PRNGKey = Any Shape = Tuple[int] Dtype = Any # this could be a real type? Array = Any default_kernel_init = lecun_normal() class MPSPeriodic(nn.Module): r""" A periodic Matrix Product State (MPS) for a quantum state of discrete degrees of freedom, wrapped as Jax machine. The MPS is defined as .. math:: \Psi(s_1,\dots s_N) = \mathrm{Tr} \left[ A[s_1]\dots A[s_N] \right] , for arbitrary local quantum numbers :math:`s_i`, where :math:`A[s_1]` is a matrix of dimension (bdim,bdim), depending on the value of the local quantum number :math:`s_i`. Attributes:
def test_initializer(init, ndim, dtype):
    """Check that a LeCun initializer samples from the expected distribution.

    Args:
        init: ``"uniform"`` to test ``lecun_uniform()`` or
            ``"truncated_normal"`` to test ``lecun_normal()``.
        ndim: number of leading (weight/input) dimensions; each gets a random
            length in ``[1, 10)``. A large trailing output dimension is then
            appended so that the drawn parameter array is statistically large.
        dtype: dtype passed to the initializer (real floating or complex).

    The test checks the sample mean, the sample variance against
    ``1 / prod(leading dims)``, the largest modulus of the parameters, and
    finally compares the parameters with independently drawn reference
    samples using a two-sample Kolmogorov-Smirnov test.

    Raises:
        ValueError: if ``init`` is not one of the supported names.
    """
    np.random.seed(seed)

    if init == "uniform":
        init_fun = lecun_uniform()
    elif init == "truncated_normal":
        init_fun = lecun_normal()
    else:
        # Without this, an unknown name surfaces later as an UnboundLocalError.
        raise ValueError(f"unknown initializer: {init!r}")

    key, rand_key = jax.random.split(nk.jax.PRNGKey(seed))

    # The lengths of the weight dimensions and the input dimension are random
    shape = tuple(np.random.randint(1, 10) for _ in range(ndim))
    shape_prod = np.prod(shape)
    # The length of the output dimension is a statistically large number,
    # but not so large that it runs out of memory
    len_out = int(10 ** 6 / shape_prod)
    shape += (len_out,)
    param = init_fun(key, shape, dtype)

    # Expected variance: 1 / (product of every dimension except the output one)
    variance = 1 / shape_prod
    stddev = sqrt(variance)

    assert param.mean() == pytest.approx(0, abs=1e-3)
    assert param.var() == pytest.approx(variance, abs=1e-2)

    is_real = jnp.issubdtype(dtype, jnp.floating)

    # Largest modulus the initializer is expected to produce.
    if init == "uniform":
        max_norm = (sqrt(3) if is_real else sqrt(2)) * stddev
    else:
        # Standard deviation of the real / complex standard normal truncated
        # at modulus 2 (same constants as the variance-scaling initializer).
        trunc_std = 0.87962566103423978 if is_real else 0.95311164380491208
        max_norm = 2 / trunc_std * stddev
    assert jnp.abs(param).max() == pytest.approx(max_norm, abs=1e-3)

    # Draw random samples using rejection sampling, and test if `param` and
    # `samples` are from the same distribution
    rand_shape = (10 ** 4,)
    rand_dtype = dtype_real(dtype)
    # Per-component scale of the untruncated normal: for complex values the
    # total E|z|^2 is split evenly between the real and imaginary parts.
    rand_stddev = max_norm / 2 if is_real else max_norm / (2 * sqrt(2))

    def draw_component(k):
        # One real-valued component of the reference distribution.
        if init == "uniform":
            return jax.random.uniform(k, rand_shape, rand_dtype, -max_norm, max_norm)
        return jax.random.normal(k, rand_shape, rand_dtype) * rand_stddev

    if is_real:
        samples = draw_component(rand_key)
    else:
        # Complex reference samples need independent real and imaginary parts
        # in both branches (previously the truncated normal was real-only).
        key_real, key_imag = jax.random.split(rand_key)
        samples = draw_component(key_real) + draw_component(key_imag) * 1j

    # Rejection step: keep only samples with modulus below max_norm. For the
    # normal this is the truncation; for complex uniform it turns the square
    # [-max_norm, max_norm]^2 into a disk of radius max_norm.
    samples = samples[jnp.abs(samples) < max_norm]

    _, pvalue = kstest(param.flatten(), samples)
    assert pvalue > 0.01