Exemplo n.º 1
0
def test_deprecations():
    """Building a ``Spin`` space from a ``graph`` is deprecated.

    Passing ``graph`` alone must emit a ``FutureWarning``; passing both
    ``graph`` and ``N`` must emit the warning and also raise ``ValueError``.
    """
    edgeless = nk.graph.Edgeless(3)

    with pytest.warns(FutureWarning):
        Spin(s=0.5, graph=edgeless)

    # Equivalent to nesting the two context managers (warns outermost).
    with pytest.warns(FutureWarning), pytest.raises(ValueError):
        Spin(s=0.5, graph=edgeless, N=3)
Exemplo n.º 2
0
def test_state_iteration():
    """Check that ``Spin.states()`` enumerates every spin-1/2 configuration.

    For 10 sites the iterator must yield exactly ``2**10`` states, in the
    same order as ``itertools.product([-1.0, 1.0], repeat=10)``.
    """
    hilbert = Spin(s=0.5, N=10)

    reference = [
        np.array(el) for el in itertools.product([-1.0, 1.0], repeat=10)
    ]

    # Plain zip() stops at the shorter iterable, so an empty or truncated
    # states() iterator would pass vacuously. zip_longest pads with None,
    # which the assertion below turns into a failure. Iterating lazily (not
    # list()-ing the states) stays safe if the iterator reuses its buffer.
    for state, ref in itertools.zip_longest(hilbert.states(), reference):
        assert state is not None and ref is not None, (
            "states() yielded a different number of states than expected"
        )
        assert np.allclose(state, ref)
Exemplo n.º 3
0
def test_composite_hilbert_spin():
    """Check site count and local dimensions of a spin-1/2 x spin-3/2 product.

    The product ``hi1 * hi2`` is expected to list the sites of ``hi1`` first
    (local dimension 2 for s=1/2), followed by those of ``hi2`` (local
    dimension 4 for s=3/2).
    """
    hi1 = Spin(s=1 / 2, N=8)
    hi2 = Spin(s=3 / 2, N=8)

    hi = hi1 * hi2

    assert hi.size == hi1.size + hi2.size

    for i in range(hi.size):
        # The expected value must be computed separately (or parenthesized):
        # ``assert a == 2 if c else 4`` parses as ``assert (a == 2) if c else 4``,
        # which never checks the spin-3/2 sites because ``assert 4`` always passes.
        expected = 2 if i < hi1.size else 4
        assert hi.size_at_index(i) == expected
Exemplo n.º 4
0
def test_hilbert_index_discrete(hi: DiscreteHilbert):
    """Exercise index <-> state conversions of a discrete Hilbert space.

    Indexable spaces: the log of the state count must stay below
    ``log(max_states)``, ``states_to_numbers(all_states())`` must return
    ``0 .. n_states-1``, and the batched ``numbers_to_states`` must agree
    with per-index calls for the first (up to) 100 indices.

    Non-indexable spaces: accessing ``n_states`` must raise ``RuntimeError``.

    Regardless of ``hi``, converting a Heisenberg operator on a 100-site spin
    chain to a dense or sparse matrix must raise ``RuntimeError``.
    """
    log_limit = np.log(nk.hilbert._abstract_hilbert.max_states)

    if not hi.is_indexable:
        assert not hi.is_indexable

        with pytest.raises(RuntimeError):
            hi.n_states
    else:
        # Sum of log local sizes == log of the total number of states.
        log_dims = np.log([hi.size_at_index(site) for site in range(hi.size)])
        assert np.sum(log_dims) < log_limit
        assert np.allclose(
            hi.states_to_numbers(hi.all_states()), range(hi.n_states)
        )

        # Batched number -> state conversion must match the per-index one.
        n_sample = min(hi.n_states, 100)
        expected = np.zeros(shape=(n_sample, hi.size))
        for idx in range(n_sample):
            expected[idx] = hi.numbers_to_states(idx)

        batched = hi.numbers_to_states(np.arange(n_sample))
        assert np.allclose(batched, expected)

    # A space far too large to enumerate must refuse matrix construction.
    chain = nk.graph.Hypercube(length=100, n_dim=1)
    heisenberg = nk.operator.Heisenberg(
        hilbert=Spin(s=0.5, N=chain.n_nodes), graph=chain
    )

    for method_name in ("to_dense", "to_sparse"):
        with pytest.raises(RuntimeError):
            getattr(heisenberg, method_name)()
Exemplo n.º 5
0
def merge_dicts(x, y):
    """Merge two mappings without mutating either argument.

    Returns ``x.copy()`` (so the mapping type of ``x`` is kept) updated with
    ``y``; when a key appears in both, the value from ``y`` wins.
    """
    merged = x.copy()
    merged.update(y)
    return merged


# Module-level registries filled in below (two distinct, initially empty dicts).
machines, dm_machines = {}, {}

# TESTS FOR SPIN HILBERT
# 1D lattice with 4 sites.
g = nk.graph.Hypercube(length=4, n_dim=1)

# Spin-1/2 Hilbert space with one spin per lattice site.
hi = Spin(s=0.5, N=g.n_nodes)

# Optional JAX-based setup; runs only when ``test_jax`` (defined elsewhere
# in this module) is truthy.
if test_jax:
    import jax
    # NOTE(review): later JAX releases moved ``stax`` to
    # ``jax.example_libraries.stax`` -- confirm the pinned JAX version still
    # provides ``jax.experimental.stax``.
    import jax.experimental
    import jax.experimental.stax

    def initializer(rng, shape):
        """Return an array of ``shape`` with normal samples (mean 0, std 0.05).

        ``rng`` is accepted but ignored. Samples come from NumPy's global
        RNG, so they do not depend on the key that is passed in.
        """
        return np.random.normal(scale=0.05, size=shape)

    # machines["Jax Real"] = nk.machine.Jax(
    #     hi,
    #     jax.experimental.stax.serial(
    #         jax.experimental.stax.Dense(4, initializer, initializer),
    #         jax.experimental.stax.Relu,
    #         jax.experimental.stax.Dense(2, initializer, initializer),
Exemplo n.º 6
0
    Qubit,
    Spin,
)
import netket.experimental as nkx

import jax
import jax.numpy as jnp

from .. import common

# Module-wide pytest marker taken from ``common`` (presumably skips these
# tests when running under MPI -- see ``common.skipif_mpi``).
pytestmark = common.skipif_mpi

# Hilbert spaces under test, keyed by a human-readable id.
hilberts = {}

# Spin 1/2
hilberts["Spin 1/2"] = Spin(s=0.5, N=20)

# Spin 1/2 with total Sz
# NOTE(review): these two ids are missing a closing "]"; left unchanged in
# case they are referenced elsewhere (e.g. as test ids).
hilberts["Spin[0.5, N=20, total_sz=1"] = Spin(s=0.5, total_sz=1.0, N=20)
hilberts["Spin[0.5, N=5, total_sz=-1.5"] = Spin(s=0.5, total_sz=-1.5, N=5)

# Spin 1 with total Sz, even number of sites
hilberts["Spin 1 with total Sz, even sites"] = Spin(s=1.0, total_sz=5.0, N=6)

# Spin 1 with total Sz, odd number of sites
hilberts["Spin 1 with total Sz, odd sites"] = Spin(s=1.0, total_sz=2.0, N=7)

# Spin 3
hilberts["Spin 3"] = Spin(s=3, N=25)

# Boson
Exemplo n.º 7
0
def merge_dicts(x, y):
    """Merge two mappings without mutating either argument.

    Returns ``x.copy()`` (so the mapping type of ``x`` is kept) updated with
    ``y``; when a key appears in both, the value from ``y`` wins.
    """
    merged = x.copy()
    merged.update(y)
    return merged


machines = {}
dm_machines = {}

# TESTS FOR SPIN HILBERT
# Constructing a 1d lattice
g = nk.graph.Hypercube(length=4, n_dim=1)

# Spin-1/2 Hilbert space with one spin per lattice site. The site count is
# passed as ``N`` because constructing ``Spin`` with ``graph=`` is deprecated
# and emits a FutureWarning.
hi = Spin(s=0.5, N=g.n_nodes)

# Optional JAX-based setup; runs only when ``test_jax`` (defined elsewhere
# in this module) is truthy.
if test_jax:
    import jax
    # NOTE(review): later JAX releases moved ``stax`` to
    # ``jax.example_libraries.stax`` -- confirm the pinned JAX version still
    # provides ``jax.experimental.stax``.
    import jax.experimental
    import jax.experimental.stax

    def initializer(rng, shape):
        """Return an array of ``shape`` with normal samples (mean 0, std 0.05).

        ``rng`` is accepted but ignored. Samples come from NumPy's global
        RNG, so they do not depend on the key that is passed in.
        """
        return np.random.normal(scale=0.05, size=shape)

    # machines["Jax Real"] = nk.machine.Jax(
    #     hi,
    #     jax.experimental.stax.serial(
    #         jax.experimental.stax.Dense(4, initializer, initializer),
    #         jax.experimental.stax.Relu,
    #         jax.experimental.stax.Dense(2, initializer, initializer),