def test_RBMSymm(use_hidden_bias, use_visible_bias, symmetries):
    g, hi, perms = _setup_symm(symmetries, N=8)

    ma = nk.models.RBMSymm(
        symmetries=perms,
        alpha=4,
        use_visible_bias=use_visible_bias,
        use_hidden_bias=use_hidden_bias,
        hidden_bias_init=uniform(),
        visible_bias_init=uniform(),
    )
    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    print(pars)

    v = hi.random_state(jax.random.PRNGKey(1), 3)

    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(val, vals[0])

    vmc = nk.VMC(
        nk.operator.Ising(hi, g, h=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
    vmc.advance(1)

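# The tests in this section rely on a `_setup_symm` helper defined elsewhere in
# the module, which returns a graph, a Hilbert space on that graph, and a
# permutation group of symmetries. A plausible reconstruction is sketched below
# under those assumptions; the actual helper in the test suite may differ.
def _setup_symm_sketch(symmetries, N, lattice=nk.graph.Chain):
    g = lattice(N)
    hi = nk.hilbert.Spin(1 / 2, N=g.n_nodes)

    if symmetries == "trans":
        # translations only
        perms = g.translation_group()
    else:
        # full space group (translations plus point-group symmetries)
        perms = g.space_group()

    return g, hi, perms
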
def test_DenseSymm_infeatures(symmetries, use_bias, mode):
    rng = nk.jax.PRNGSeq(0)

    g, hi, perms = _setup_symm(symmetries, N=8)

    if mode == "matrix":
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            shape=tuple(g.extent),
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    pars = ma.init(rng.next(), hi.random_state(rng.next(), 2).reshape(1, 2, -1))

    v = hi.random_state(rng.next(), 6).reshape(3, 2, -1)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(jnp.sort(val, -1), jnp.sort(vals[0], -1))

def test_DenseEquivariant(symmetries, use_bias, lattice, mode, mask):
    rng = nk.jax.PRNGSeq(0)

    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    pt = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    if mask:
        mask = np.zeros(n_symm)
        mask[np.random.choice(n_symm, n_symm // 2, replace=False)] = 1
    else:
        mask = np.ones([n_symm])

    if mode == "irreps":
        ma = nk.nn.DenseEquivariant(
            symmetries=perms,
            mode=mode,
            features=1,
            mask=mask,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseEquivariant(
            symmetries=pt,
            shape=tuple(g.extent),
            mode=mode,
            features=1,
            mask=mask,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    dum_input = jax.random.normal(rng.next(), (1, 1, n_symm))
    pars = ma.init(rng.next(), dum_input)

    # inv_pt computes chosen_op = g h^-1 instead of g^-1 h
    chosen_op = np.random.randint(n_symm)
    inverse = PermutationGroup(
        [perms.elems[i] for i in perms.inverse], degree=g.n_nodes
    )
    inv_pt = inverse.product_table
    sym_op = np.where(inv_pt == chosen_op, 1.0, 0.0)

    v = random.normal(rng.next(), [3, 1, n_symm])
    v_trans = jnp.matmul(v, sym_op)

    out = ma.apply(pars, v)
    out_trans = ma.apply(pars, v_trans)

    # the output should be equivariant: applying sym_op to the output of v
    # must equal the output of the transformed input v_trans
    assert jnp.allclose(jnp.matmul(out, sym_op), out_trans)

def test_gcnn(mode, complex_output):
    lattice = nk.graph.Chain
    symmetries = "trans"
    parity = True

    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
        complex_output=complex_output,
    )

    vmc = nk.VMC(
        nk.operator.Ising(hi, g, h=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi, n_chains=2, n_sweeps=2),
        ma,
        n_samples=8,
    )
    vmc.advance(1)

def init_fun(rng, input_shape):
    assert input_dim == input_shape[-1]
    *k1, k2, k3 = random.split(rng, num_blocks + 2)

    # Initialize each column block using W_init
    W = jnp.zeros((input_dim, out_dim))
    for i in range(num_blocks):
        W = W.at[: (i + 1) * in_factor, i * out_factor : (i + 1) * out_factor].set(
            W_init(k1[i], ((i + 1) * in_factor, out_factor))
        )

    # initialize weight scale
    ws = jnp.log(uniform(1.0)(k2, (out_dim,)))

    if bias:
        b = (uniform(1.0)(k3, (out_dim,)) - 0.5) * (2 / jnp.sqrt(out_dim))
        params = (W, ws, b)
    else:
        params = (W, ws)
    return input_shape[:-1] + (out_dim,), params

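# A quick sanity check of the masking pattern produced by the loop in
# `init_fun` above, using hypothetical sizes (num_blocks=3, in_factor=1,
# out_factor=2) chosen for illustration. Output block i only receives
# connections from input blocks 0..i, as in MADE-style autoregressive layers.
import jax.numpy as jnp

num_blocks, in_factor, out_factor = 3, 1, 2
input_dim, out_dim = num_blocks * in_factor, num_blocks * out_factor
W = jnp.zeros((input_dim, out_dim))
for i in range(num_blocks):
    W = W.at[: (i + 1) * in_factor, i * out_factor : (i + 1) * out_factor].set(1.0)
# W now has the block lower-triangular connectivity pattern:
# [[1. 1. 1. 1. 1. 1.]
#  [0. 0. 1. 1. 1. 1.]
#  [0. 0. 0. 0. 1. 1.]]
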
def test_RBMMultiVal(use_hidden_bias, use_visible_bias):
    N = 8
    M = 3
    hi = nk.hilbert.Fock(M, N)
    g = nk.graph.Chain(N)

    ma = nk.models.RBMMultiVal(
        alpha=2,
        n_classes=M + 1,
        use_visible_bias=use_visible_bias,
        use_hidden_bias=use_hidden_bias,
        hidden_bias_init=uniform(),
        visible_bias_init=uniform(),
    )
    _ = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    vmc = nk.VMC(
        nk.operator.BoseHubbard(hi, g, U=1.0),
        nk.optimizer.Sgd(0.1),
        nk.sampler.MetropolisLocal(hi),
        ma,
    )
    vmc.advance(1)

def test_modes_DenseEquivariant(lattice, symmetries):
    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="fft",
        features=1,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_irreps = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="irreps",
        features=1,
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseEquivariant(
        symmetries=perms,
        mode="matrix",
        features=1,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (1, 1, len(perms)))

    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_irreps.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    fft_out = ma_fft.apply(pars, dum_input)
    irreps_out = ma_irreps.apply(pars, dum_input)
    matrix_out = ma_matrix.apply(pars, dum_input)

    assert jnp.allclose(fft_out, irreps_out)
    assert jnp.allclose(fft_out, matrix_out)

def test_modes_DenseSymm_infeatures(lattice, symmetries):
    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseSymm(
        symmetries=perms,
        mode="fft",
        features=4,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseSymm(
        symmetries=perms,
        mode="matrix",
        features=4,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (1, 3, g.n_nodes))

    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    assert jnp.allclose(
        ma_fft.apply(pars, dum_input), ma_matrix.apply(pars, dum_input)
    )

def test_modes_DenseSymm(lattice, symmetries):
    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma_fft = nk.nn.DenseSymm(
        symmetries=perms,
        mode="fft",
        features=4,
        shape=tuple(g.extent),
        bias_init=uniform(),
    )
    ma_matrix = nk.nn.DenseSymm(
        symmetries=perms,
        mode="matrix",
        features=4,
        bias_init=uniform(),
    )

    dum_input = jax.random.normal(rng.next(), (3, 1, g.n_nodes))

    pars = ma_fft.init(rng.next(), dum_input)
    _ = ma_matrix.init(rng.next(), dum_input)

    assert jnp.allclose(
        ma_fft.apply(pars, dum_input), ma_matrix.apply(pars, dum_input)
    )

    # Test deprecation warning for inputs without a feature dimension
    dum_input_nofeatures = dum_input.reshape((dum_input.shape[0], dum_input.shape[2]))
    with pytest.warns(FutureWarning):
        assert jnp.allclose(
            ma_fft.apply(pars, dum_input),
            ma_fft.apply(pars, dum_input_nofeatures),
        )
        assert jnp.allclose(
            ma_matrix.apply(pars, dum_input),
            ma_matrix.apply(pars, dum_input_nofeatures),
        )

def Embedding(vocab_size, embedding_size, padding_idx=None, embedding_init=uniform()):
    """Layer construction function for an embedding layer."""

    def init_fun(rng, input_shape):
        embedding_shape = (vocab_size, embedding_size)
        embedding_table = embedding_init(rng, embedding_shape)
        if padding_idx is not None:
            embedding_table = index_update(embedding_table, padding_idx, 0.0)
        output_shape = input_shape + (embedding_size,)
        return output_shape, (embedding_table,)

    def apply_fun(params, inputs, **kwargs):
        embedding_table = params[0]
        return embedding_table[inputs]

    return init_fun, apply_fun

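# A minimal usage sketch of the Embedding factory above, following the stax
# (init_fun, apply_fun) convention. The vocabulary size, embedding size, and
# batch shape below are arbitrary values chosen for illustration.
import jax
import jax.numpy as jnp

init_fun, apply_fun = Embedding(vocab_size=100, embedding_size=16)
out_shape, params = init_fun(jax.random.PRNGKey(0), input_shape=(4, 7))
token_ids = jnp.zeros((4, 7), dtype=jnp.int32)
embeddings = apply_fun(params, token_ids)  # shape (4, 7, 16), matching out_shape
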
def test_gcnn_equivariance(parity, symmetries, lattice, mode):
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
    )

    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    v = hi.random_state(jax.random.PRNGKey(0), 3)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(val, vals[0])

def test_gcnn_equivariance_vectorised(parity, symmetries, lattice, mode):
    g, hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    ma = nk.models.GCNN(
        symmetries=perms,
        mode=mode,
        shape=tuple(g.extent),
        layers=2,
        features=2,
        parity=parity,
        bias_init=uniform(),
    )

    pars = ma.init(nk.jax.PRNGKey(), hi.random_state(nk.jax.PRNGKey(), 1))

    v = hi.random_state(jax.random.PRNGKey(0), 3)

    # vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    # code below implements the commented line above, but is vectorised
    v = v[..., np.asarray(perms)].transpose(1, 0, 2)
    v = v.reshape(len(perms) * 3, g.n_nodes)
    vals = ma.apply(pars, v).reshape(len(perms), 3)
    for val in vals:
        assert jnp.allclose(val, vals[0])

def __call__(self, key, shape, dtype=None):
    if dtype is None:
        dtype = "float32"
    initializer_fn = jax_initializers.uniform()
    return initializer_fn(key, shape, dtype)
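
# A usage sketch of the initializer wrapper above. The owning class name
# (here hypothetically UniformInitializer) is an assumption, and
# jax_initializers is assumed to be an alias for jax.nn.initializers imported
# elsewhere in the module.
import jax

init = UniformInitializer()  # hypothetical class owning the __call__ above
weights = init(jax.random.PRNGKey(0), (3, 4))  # dtype defaults to "float32"
weights64 = init(jax.random.PRNGKey(1), (3, 4), dtype="float64")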