Ejemplo n.º 1
0
def test_flip_state_fock_infinite():
    """Check ``flip_state`` on a Fock space built with ``N=2`` and no explicit cutoff.

    A batch of random states is drawn, one site index is chosen per row, and
    ``nk.hilbert.random.flip_state`` is applied. The test verifies that:

    * the output batch has the same shape as the input batch;
    * occupation numbers are non-negative both before and after the flip;
    * ``old_vals`` holds the pre-flip values at the chosen sites;
    * every entry other than the chosen site of each row is unchanged.
    """
    hi = Fock(N=2)
    rng = nk.jax.PRNGSeq(1)
    N_batches = 20

    states = hi.random_state(rng.next(), N_batches, dtype=jnp.int64)

    # One site index per batch row, drawn uniformly from [0, hi.size).
    ids = jax.random.randint(
        rng.next(), shape=(N_batches,), minval=0, maxval=hi.size
    )

    new_states, old_vals = nk.hilbert.random.flip_state(
        hi, rng.next(), states, ids)

    assert new_states.shape == states.shape

    # Occupation numbers must be non-negative. Checking only the input batch
    # would leave the output of flip_state itself unverified.
    assert np.all(states >= 0)
    assert np.all(new_states >= 0)

    states_np = np.asarray(states)
    states_new_np = np.array(new_states)  # writable copy, patched below
    rows = np.arange(N_batches)
    ids_np = np.asarray(ids)

    # old_vals should report the value that sat at each chosen site before
    # the flip.
    np.testing.assert_array_equal(np.asarray(old_vals), states_np[rows, ids_np])

    # Undo the flip at the chosen sites; every other entry must already match.
    states_new_np[rows, ids_np] = states_np[rows, ids_np]
    np.testing.assert_allclose(states_np, states_new_np)
Ejemplo n.º 2
0
def test_random_states_fock_infinite():
    """Sample a batch of random states from ``Fock(N=2)`` (no cutoff passed).

    Every sampled occupation number must be non-negative, and the batch must
    have one row per sample and two columns.
    """
    hilbert = Fock(N=2)
    n_samples = 20
    key = jax.random.PRNGKey(14)
    samples = hilbert.random_state(key, n_samples)

    all_non_negative = np.all(samples >= 0)
    assert all_non_negative

    expected_shape = (n_samples, 2)
    assert samples.shape == expected_shape