Example #1
def test_RBMSymm(use_hidden_bias, use_visible_bias, symmetries):
    g, hi, perms = _setup_symm(symmetries, N=8)

    ma = nk.models.RBMSymm(
        symmetries=perms,
        alpha=4,
        use_visible_bias=use_visible_bias,
        use_hidden_bias=use_hidden_bias,
        hidden_bias_init=uniform(),
        visible_bias_init=uniform(),
    )
    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    print(pars)

    v = hi.random_state(jax.random.PRNGKey(1), 3)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]

    for val in vals:
        assert jnp.allclose(val, vals[0])

    vmc = nk.VMC(
        nk.operator.Ising(hi, g, h=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
    vmc.advance(1)
Example #2
def test_DenseSymm_infeatures(symmetries, use_bias, mode):
    rng = nk.jax.PRNGSeq(0)

    g, hi, perms = _setup_symm(symmetries, N=8)

    if mode == "matrix":
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            shape=tuple(g.extent),
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    pars = ma.init(rng.next(), hi.random_state(rng.next(), 2).reshape(1, 2, -1))

    v = hi.random_state(rng.next(), 6).reshape(3, 2, -1)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(jnp.sort(val, -1), jnp.sort(vals[0], -1))
Example #3
def test_DenseEquivariant(symmetries, use_bias, lattice, mode, mask):
    rng = nk.jax.PRNGSeq(0)

    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    pt = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    if mask:
        mask = np.zeros(n_symm)
        mask[np.random.choice(n_symm, n_symm // 2, replace=False)] = 1
    else:
        mask = np.ones([n_symm])

    if mode == "irreps":
        ma = nk.nn.DenseEquivariant(
            symmetries=perms,
            mode=mode,
            features=1,
            mask=mask,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseEquivariant(
            symmetries=pt,
            shape=tuple(g.extent),
            mode=mode,
            features=1,
            mask=mask,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    dum_input = jax.random.normal(rng.next(), (1, 1, n_symm))
    pars = ma.init(rng.next(), dum_input)

    # inv_pt is the product table of the inverse elements, so its entries are
    # g h^-1 rather than g^-1 h
    chosen_op = np.random.randint(n_symm)
    inverse = PermutationGroup(
        [perms.elems[i] for i in perms.inverse], degree=g.n_nodes
    )
    inv_pt = inverse.product_table
    sym_op = np.where(inv_pt == chosen_op, 1.0, 0.0)

    v = random.normal(rng.next(), [3, 1, n_symm])
    v_trans = jnp.matmul(v, sym_op)

    out = ma.apply(pars, v)
    out_trans = ma.apply(pars, v_trans)

    # equivariance: transforming the input permutes the output in the same way
    assert jnp.allclose(jnp.matmul(out, sym_op), out_trans)
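
For reference, the sym_op built above comes from a group product table, which is a Latin square, so sym_op is a permutation matrix and right-multiplication by it only reorders entries along the symmetry axis; this is what the equivariance assertion relies on. A tiny standalone sketch (the group table and values below are illustrative, not taken from the test):

import numpy as np

# Illustrative 3x3 group table (Z_3 under addition mod 3); every row and
# column contains each element exactly once, i.e. it is a Latin square.
pt_demo = np.array([[0, 1, 2],
                    [1, 2, 0],
                    [2, 0, 1]])

op = 1
sym_op_demo = np.where(pt_demo == op, 1.0, 0.0)
print(sym_op_demo)
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]   exactly one 1 per row and column: a permutation matrix

v_demo = np.array([10.0, 20.0, 30.0])
print(v_demo @ sym_op_demo)   # [20. 10. 30.]  entries are permuted, not mixed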
Example #4
def test_gcnn(mode, complex_output):
    lattice = nk.graph.Chain
    symmetries = "trans"
    parity = True
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
        complex_output=complex_output,
    )

    vmc = nk.VMC(
        nk.operator.Ising(hi, g, h=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi, n_chains=2, n_sweeps=2),
        ma,
        n_samples=8,
    )
    vmc.advance(1)
Example #5
    def init_fun(rng, input_shape):
        assert input_dim == input_shape[-1]
        *k1, k2, k3 = random.split(rng, num_blocks + 2)

        # Initialize each column block using W_init
        W = jnp.zeros((input_dim, out_dim))
        for i in range(num_blocks):
            W = W.at[:(i + 1) * in_factor,
                     i * out_factor:(i + 1) * out_factor].set(
                         W_init(k1[i], ((i + 1) * in_factor, out_factor)))

        # initialize weight scale
        ws = jnp.log(uniform(1.0)(k2, (out_dim, )))

        if bias:
            b = (uniform(1.0)(k3, (out_dim, )) - 0.5) * (2 / jnp.sqrt(out_dim))
            params = (W, ws, b)
        else:
            params = (W, ws)
        return input_shape[:-1] + (out_dim, ), params
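
The loop above fills W one column block at a time, writing block i only into the first (i + 1) * in_factor rows, so the kernel ends up block-triangular. A standalone sketch with made-up sizes (the sizes are illustrative assumptions and a constant 1.0 stands in for W_init) makes the sparsity pattern visible:

import jax.numpy as jnp

# Illustrative sizes only; the real values come from the enclosing layer's arguments.
num_blocks, in_factor, out_factor = 2, 2, 3
input_dim, out_dim = num_blocks * in_factor, num_blocks * out_factor

W = jnp.zeros((input_dim, out_dim))
for i in range(num_blocks):
    # column block i is connected only to the first (i + 1) * in_factor inputs
    W = W.at[:(i + 1) * in_factor,
             i * out_factor:(i + 1) * out_factor].set(1.0)

print(W)
# [[1. 1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]]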
Example #6
def test_RBMMultiVal(use_hidden_bias, use_visible_bias):
    N = 8
    M = 3
    hi = nk.hilbert.Fock(M, N)
    g = nk.graph.Chain(N)

    ma = nk.models.RBMMultiVal(
        alpha=2,
        n_classes=M + 1,
        use_visible_bias=use_visible_bias,
        use_hidden_bias=use_hidden_bias,
        hidden_bias_init=uniform(),
        visible_bias_init=uniform(),
    )
    _ = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    vmc = nk.VMC(
        nk.operator.BoseHubbard(hi, g, U=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
    vmc.advance(1)
Example #7
def test_modes_DenseEquivariant(lattice, symmetries):

    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="fft",
        features=1,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_irreps = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="irreps",
        features=1,
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="matrix",
        features=1,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (1, 1, len(perms)))
    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_irreps.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    fft_out = ma_fft.apply(pars, dum_input)
    irreps_out = ma_irreps.apply(pars, dum_input)
    matrix_out = ma_matrix.apply(pars, dum_input)

    assert jnp.allclose(fft_out, irreps_out)
    assert jnp.allclose(fft_out, matrix_out)
Example #8
def test_modes_DenseSymm_infeatures(lattice, symmetries):

    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseSymm(
        symmetries=perms,
        mode="fft",
        features=4,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseSymm(
        symmetries=perms,
        mode="matrix",
        features=4,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (1, 3, g.n_nodes))
    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    assert jnp.allclose(ma_fft.apply(pars, dum_input), ma_matrix.apply(pars, dum_input))
Example #9
def test_modes_DenseSymm(lattice, symmetries):

    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseSymm(
        symmetries=perms,
        mode="fft",
        features=4,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseSymm(
        symmetries=perms,
        mode="matrix",
        features=4,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (3, 1, g.n_nodes))

    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    assert jnp.allclose(ma_fft.apply(pars, dum_input), ma_matrix.apply(pars, dum_input))

    # Test Deprecation warning
    dum_input_nofeatures = dum_input.reshape((dum_input.shape[0], dum_input.shape[2]))
    with pytest.warns(FutureWarning):
        assert jnp.allclose(
            ma_fft.apply(pars, dum_input), ma_fft.apply(pars, dum_input_nofeatures)
        )
        assert jnp.allclose(
            ma_matrix.apply(pars, dum_input),
            ma_matrix.apply(pars, dum_input_nofeatures),
        )
Example #10
def Embedding(vocab_size,
              embedding_size,
              padding_idx=None,
              embedding_init=uniform()):
    """Layer construction function for an embedding layer."""
    def init_fun(rng, input_shape):
        embedding_shape = (vocab_size, embedding_size)
        embedding_table = embedding_init(rng, embedding_shape)
        if padding_idx is not None:
            # jax.ops.index_update has been removed from JAX; use the
            # equivalent .at[].set() to zero out the row used for padding tokens
            embedding_table = embedding_table.at[padding_idx].set(0.)
        output_shape = input_shape + (embedding_size, )
        return output_shape, (embedding_table, )

    def apply_fun(params, inputs, **kwargs):
        embedding_table = params[0]
        return embedding_table[inputs]

    return init_fun, apply_fun
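
A minimal usage sketch of the (init_fun, apply_fun) pair returned by Embedding; the vocabulary size, shapes and all-padding batch are illustrative assumptions, and the default embedding_init=uniform() is assumed to be a jax.nn.initializers-style initializer:

import jax
import jax.numpy as jnp

init_fun, apply_fun = Embedding(vocab_size=100, embedding_size=16, padding_idx=0)

# init_fun only needs the shape of the integer token array
out_shape, params = init_fun(jax.random.PRNGKey(0), input_shape=(4, 7))
print(out_shape)                               # (4, 7, 16)

tokens = jnp.zeros((4, 7), dtype=jnp.int32)    # all-padding batch, for illustration
embedded = apply_fun(params, tokens)
print(embedded.shape)                          # (4, 7, 16)
print(bool(jnp.all(embedded == 0.0)))          # True: the padding row was zeroed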
Example #11
def test_gcnn_equivariance(parity, symmetries, lattice, mode):
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
    )

    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    v = hi.random_state(jax.random.PRNGKey(0), 3)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]

    for val in vals:
        assert jnp.allclose(val, vals[0])
Example #12
def test_gcnn_equivariance(parity, symmetries, lattice, mode):
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
    )

    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    v = hi.random_state(jax.random.PRNGKey(0), 3)
    # vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    # code below implements the commented line above, but is vectorised
    v = v[..., np.asarray(perms)].transpose(1, 0, 2)
    v = v.reshape(len(perms) * 3, g.n_nodes)
    vals = ma.apply(pars, v).reshape(len(perms), 3)

    for val in vals:
        assert jnp.allclose(val, vals[0])
Example #13
    def __call__(self, key, shape, dtype=None):
        if dtype is None:
            dtype = "float32"
        initializer_fn = jax_initializers.uniform()
        return initializer_fn(key, shape, dtype)