Example #1
0
from typing import Any, Tuple

import jax
from jax import numpy as jnp
from flax import linen as nn

from netket import nn as nknn
from netket.nn.initializers import lecun_normal, variance_scaling, zeros

from netket.hilbert import AbstractHilbert
from netket.graph import AbstractGraph

# Loose type aliases used for annotations in this module.
PRNGKey = Any
# Shapes are variable-length tuples of ints (e.g. a matrix shape
# ``(bdim, bdim)``); plain ``Tuple[int]`` would denote a 1-tuple only.
Shape = Tuple[int, ...]
Dtype = Any  # TODO: could be narrowed to a concrete dtype type
Array = Any

# Module-level default initializer built from ``lecun_normal`` (imported from
# ``netket.nn.initializers``); presumably the default for kernel/weight
# parameters of modules in this file — confirm against its uses.
default_kernel_init = lecun_normal()


class MPSPeriodic(nn.Module):
    r"""
    A periodic Matrix Product State (MPS) for a quantum state of discrete
    degrees of freedom, wrapped as a Jax machine.

    The MPS is defined as

    .. math:: \Psi(s_1,\dots,s_N) = \mathrm{Tr} \left[ A[s_1]\dots A[s_N] \right] ,

    for arbitrary local quantum numbers :math:`s_i`, where :math:`A[s_i]` is a matrix
    of dimension (bdim, bdim) that depends on the value of the local quantum number :math:`s_i`.

    Attributes:
Example #2
0
# Standard deviation of a unit normal truncated to [-2, 2] (real case), and the
# normalization constant used for the radially truncated complex case.
_TRUNC_NORMAL_STDDEV_REAL = 0.87962566103423978
_TRUNC_NORMAL_STDDEV_COMPLEX = 0.95311164380491208


def _expected_max_norm(init, dtype, stddev):
    """Return the largest magnitude the ``init`` initializer can produce.

    Args:
        init: ``"uniform"`` or ``"truncated_normal"``.
        dtype: parameter dtype; real floating dtypes use the real formulas,
            anything else is treated as complex.
        stddev: target standard deviation of the parameters.

    Raises:
        ValueError: if ``init`` is not a supported kind.
    """
    is_real = jnp.issubdtype(dtype, jnp.floating)
    if init == "uniform":
        # Real: U(-a, a) has variance a**2 / 3.
        # Complex: uniform on a disk of radius a has E|z|**2 = a**2 / 2.
        return (sqrt(3) if is_real else sqrt(2)) * stddev
    if init == "truncated_normal":
        # Truncation happens at 2 "raw" standard deviations, rescaled so the
        # truncated distribution has standard deviation ``stddev``.
        norm = (
            _TRUNC_NORMAL_STDDEV_REAL if is_real else _TRUNC_NORMAL_STDDEV_COMPLEX
        )
        return 2 / norm * stddev
    raise ValueError(f"Unknown initializer kind: {init!r}")


def _reference_samples(init, dtype, key, max_norm, n_samples=10 ** 4):
    """Draw reference samples of the target distribution by rejection sampling.

    Samples are drawn from an untruncated proposal and every sample with
    ``|z| >= max_norm`` is discarded. For complex dtypes the real and imaginary
    parts are drawn independently, so the surviving samples fill the disk of
    radius ``max_norm``, matching the support of the complex initializers.

    Raises:
        ValueError: if ``init`` is not a supported kind.
    """
    rand_shape = (n_samples,)
    rand_dtype = dtype_real(dtype)
    is_real = jnp.issubdtype(dtype, jnp.floating)

    if init == "uniform":

        def draw(k):
            return jax.random.uniform(k, rand_shape, rand_dtype, -max_norm, max_norm)

    elif init == "truncated_normal":
        # For complex dtypes this is the per-component standard deviation.
        rand_stddev = max_norm / 2 if is_real else max_norm / (2 * sqrt(2))

        def draw(k):
            return jax.random.normal(k, rand_shape, rand_dtype) * rand_stddev

    else:
        raise ValueError(f"Unknown initializer kind: {init!r}")

    if is_real:
        samples = draw(key)
    else:
        # Previously the complex truncated-normal branch drew only real samples,
        # which were then KS-compared against complex parameters.
        key_real, key_imag = jax.random.split(key)
        samples = draw(key_real) + draw(key_imag) * 1j
    return samples[jnp.abs(samples) < max_norm]


def test_initializer(init, ndim, dtype):
    """Check that a LeCun initializer produces the expected distribution.

    A parameter tensor with ``ndim`` random leading dimensions and a large
    trailing dimension is drawn from ``lecun_uniform()`` or ``lecun_normal()``.
    Its mean, variance and maximum magnitude are compared with the analytic
    values for variance ``1 / prod(leading dims)``. A two-sample
    Kolmogorov-Smirnov test then compares it against reference samples drawn
    from the target distribution.

    Args:
        init: ``"uniform"`` or ``"truncated_normal"``.
        ndim: number of random leading dimensions of the parameter shape.
        dtype: parameter dtype (real floating, or otherwise treated as complex).

    Raises:
        ValueError: if ``init`` is not a supported kind.
    """
    np.random.seed(seed)

    if init == "uniform":
        init_fun = lecun_uniform()
    elif init == "truncated_normal":
        init_fun = lecun_normal()
    else:
        raise ValueError(f"Unknown initializer kind: {init!r}")

    key, rand_key = jax.random.split(nk.jax.PRNGKey(seed))
    # The lengths of the weight dimensions and the input dimension are random
    shape = tuple(np.random.randint(1, 10) for _ in range(ndim))
    shape_prod = np.prod(shape)
    # The length of the output dimension is statistically large, but not so
    # large that it runs out of memory
    len_out = int(10 ** 6 / shape_prod)
    shape += (len_out,)
    param = init_fun(key, shape, dtype)

    variance = 1 / shape_prod
    stddev = sqrt(variance)

    assert param.mean() == pytest.approx(0, abs=1e-3)
    assert param.var() == pytest.approx(variance, abs=1e-2)

    max_norm = _expected_max_norm(init, dtype, stddev)
    assert jnp.abs(param).max() == pytest.approx(max_norm, abs=1e-3)

    # Test whether `param` and the reference samples share a distribution.
    # NOTE(review): for complex data the KS test sorts values lexicographically
    # (real part first), so it effectively compares real-part marginals.
    samples = _reference_samples(init, dtype, rand_key, max_norm)
    _, pvalue = kstest(param.flatten(), samples)
    assert pvalue > 0.01