def test_gradients_nonhermitian(self):
    """Check NQS gradients of a complex RNN against forward finite differences.

    The check runs once with a single device and once with all available
    devices (via ``global_defs.set_pmap_devices``). For every variational
    parameter ``j`` the parameters are shifted by ``delta`` along the unit
    vector ``e_j``, the network is re-evaluated on the same inputs, and the
    difference quotient is compared with ``psi.gradients(s)[..., j]``.
    """
    delta = 1e-5
    deviceConfigs = [jax.devices()[0], jax.devices()]
    for ds in deviceConfigs:
        global_defs.set_pmap_devices(ds)

        net = nets.CpxRNN.partial(L=3)
        _, params1 = net.init_by_shape(random.PRNGKey(0), [(3, )])
        model = nn.Model(net, params1)

        # Batch of four length-3 configurations; two sites are set to 1 so
        # the inputs are not all zero. (get_shape presumably prepends a
        # device axis when pmap is active — TODO confirm.)
        s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
        s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

        psi = NQS(model)
        psi0 = psi(s)
        G = psi.gradients(s)
        params = psi.get_parameters()
        numParams = G.shape[-1]

        for j in range(numParams):
            # One-hot direction e_j in parameter space.
            u = jax.ops.index_update(
                jnp.zeros(numParams, dtype=global_defs.tReal),
                jax.ops.index[j], 1)
            # update_parameters is used as an additive shift here (the
            # difference quotient below relies on that); the original
            # parameters are restored so every direction starts from the
            # same point.
            psi.update_parameters(delta * u)
            psi1 = psi(s)
            psi.set_parameters(params)

            # Forward finite-difference estimate of d psi / d p_j
            Gfd = (psi1 - psi0) / delta

            # Include the device configuration in the sub-test label so a
            # failure reports which pmap setup broke.
            with self.subTest(devices=ds, i=j):
                self.assertLess(float(jnp.max(jnp.abs(Gfd - G[..., j]))), 1e-2)
def test_gradients_cpx(self):
    """Check NQS gradients of a complex RBM against forward finite differences.

    The check runs once with a single device and once with all available
    devices (via ``global_defs.set_pmap_devices``). For every variational
    parameter ``j`` the parameters are shifted by ``delta`` along the unit
    vector ``e_j``, the network is re-evaluated on the same inputs, and the
    difference quotient is compared with ``psiC.gradients(s)[..., j]``.
    """
    delta = 1e-5
    deviceConfigs = [jax.devices()[0], jax.devices()]
    for ds in deviceConfigs:
        global_defs.set_pmap_devices(ds)

        rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
        rbmModel = nn.Model(rbm, params)

        # Batch of four length-3 configurations; two sites are set to 1 so
        # the inputs are not all zero. (get_shape presumably prepends a
        # device axis when pmap is active — TODO confirm.)
        s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
        s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

        psiC = NQS(rbmModel)
        psi0 = psiC(s)
        G = psiC.gradients(s)
        params = psiC.get_parameters()
        numParams = G.shape[-1]

        for j in range(numParams):
            # One-hot direction e_j in parameter space.
            u = jax.ops.index_update(
                jnp.zeros(numParams, dtype=global_defs.tReal),
                jax.ops.index[j], 1)
            # update_parameters is used as an additive shift here (the
            # difference quotient below relies on that); the original
            # parameters are restored so every direction starts from the
            # same point.
            psiC.update_parameters(delta * u)
            psi1 = psiC(s)
            psiC.set_parameters(params)

            # Forward finite-difference estimate of d psi / d p_j
            Gfd = (psi1 - psi0) / delta

            # Include the device configuration in the sub-test label so a
            # failure reports which pmap setup broke.
            with self.subTest(devices=ds, i=j):
                self.assertLess(float(jnp.max(jnp.abs(Gfd - G[..., j]))), 1e-2)
def test_autoregressive_sampling(self):
    """Compare sampled configuration frequencies with exact probabilities.

    An NQS built from an RNN and an RBM is sampled with ``MCMCSampler``. The
    empirical histogram of sampled configurations (mapped to integers via
    ``state_to_int``) is then compared with the probabilities returned by
    ``ExactSampler``.
    """
    L = 4
    # Number of distinct binary configurations of L sites; this sets both the
    # histogram range and how many exact probabilities are compared.
    numStates = 2**L

    # Set up variational wave function
    rnn = nets.RNN.partial(L=L, hiddenSize=5, depth=2)
    _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
    rnnModel = nn.Model(rnn, params)
    rbm = nets.RBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
    rbmModel = nn.Model(rbm, params)

    psi = NQS((rnnModel, rbmModel))

    # NOTE(review): if update_parameters adds to the current values, this
    # doubles the initial parameters — confirm this is intended.
    ps = psi.get_parameters()
    psi.update_parameters(ps)

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up MCMC sampler
    mcSampler = sampler.MCMCSampler(random.PRNGKey(0), jVMC.sampler.propose_spin_flip, (L, ), numChains=777)

    # Compute exact probabilities
    _, _, pex = exactSampler.sample(psi)

    numSamples = 500000
    smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

    # The values returned alongside the samples must agree (in real part)
    # with re-evaluating psi on those samples.
    self.assertLess(float(jnp.max(jnp.abs(jnp.real(psi(smc) - p)))), 1e-12)

    if global_defs.usePmap:
        # Flatten the leading (device, per-device batch) axes into one.
        smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

    self.assertGreaterEqual(smc.shape[0], numSamples)

    # Compute histogram of sampled configurations (one integer bin per state)
    smcInt = jax.vmap(state_to_int)(smc)
    pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

    # Only the first numStates exact probabilities are compared (pex may be
    # longer, e.g. padded per device — TODO confirm).
    maxDeviation = jnp.max(
        jnp.abs(pmc / mcSampler.get_last_number_of_samples() - pex.reshape((-1, ))[:numStates]))
    self.assertLess(float(maxDeviation), 1e-3)