def test_gradients_nonhermitian(self):
    """Check ``NQS.gradients`` of a ``CpxRNN`` against forward finite differences.

    The check runs twice: once with only the first JAX device and once with
    all devices, selected through ``global_defs.set_pmap_devices``. For every
    parameter ``j`` the network is shifted by ``delta`` along the unit vector
    ``e_j``. The finite-difference quotient must agree with ``G[..., j]`` to
    within 1e-2.
    """
    dlist = [jax.devices()[0], jax.devices()]
    for ds in dlist:
        global_defs.set_pmap_devices(ds)

        net = nets.CpxRNN.partial(L=3)
        _, initParams = net.init_by_shape(random.PRNGKey(0), [(3, )])
        model = nn.Model(net, initParams)

        # Configurations of shape get_shape((4, 3)) with entries [..., 0, 1]
        # and [..., 2, 2] set to 1 (presumably 4 configurations of 3 sites,
        # matching L=3 above).
        s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
        s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

        psi = NQS(model)
        psi0 = psi(s)
        G = psi.gradients(s)
        delta = 1e-5

        params = psi.get_parameters()
        for j in range(G.shape[-1]):
            # Unit vector along parameter j. The dtype comes from the same
            # global_defs module that selected the devices above.
            u = jax.ops.index_update(
                jnp.zeros(G.shape[-1], dtype=global_defs.tReal),
                jax.ops.index[j], 1)
            psi.update_parameters(delta * u)
            psi1 = psi(s)
            # Restore the unperturbed parameters before the next direction.
            psi.set_parameters(params)

            # Finite difference gradients
            Gfd = (psi1 - psi0) / delta

            with self.subTest(devices=ds, i=j):
                self.assertLess(jnp.max(jnp.abs(Gfd - G[..., j])), 1e-2)
def test_gradients_cpx(self):
    """Compare ``NQS.gradients`` of a biased ``CpxRBM`` with finite differences.

    Executed once for the first JAX device alone and once for all devices.
    Each parameter is perturbed by ``eps`` in turn. The forward-difference
    quotient of the network output must match the analytic gradient
    column to 1e-2.
    """
    for devices in (jax.devices()[0], jax.devices()):
        global_defs.set_pmap_devices(devices)

        cpxRBM = nets.CpxRBM.partial(numHidden=2, bias=True)
        _, flaxParams = cpxRBM.init_by_shape(random.PRNGKey(0), [(1, 3)])
        model = nn.Model(cpxRBM, flaxParams)

        configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        for where in (jax.ops.index[..., 0, 1], jax.ops.index[..., 2, 2]):
            configs = jax.ops.index_update(configs, where, 1)

        wavefn = NQS(model)
        reference = wavefn(configs)
        analytic = wavefn.gradients(configs)
        eps = 1e-5
        numParams = analytic.shape[-1]

        saved = wavefn.get_parameters()
        for k in range(numParams):
            direction = jax.ops.index_update(
                jnp.zeros(numParams, dtype=global_defs.tReal),
                jax.ops.index[k], 1)
            wavefn.update_parameters(eps * direction)
            shifted = wavefn(configs)
            # Undo the perturbation so every direction starts from `saved`.
            wavefn.set_parameters(saved)

            # Forward finite-difference estimate along parameter k
            numeric = (shifted - reference) / eps
            with self.subTest(i=k):
                self.assertTrue(
                    jnp.max(jnp.abs(numeric - analytic[..., k])) < 1e-2)
def test_MCMC_sampling(self):
    """Compare MCMC sample frequencies of a ``CpxRBM`` with exact probabilities.

    The RBM is loaded with the fixed ``weights`` below. The histogram of MCMC
    samples over all ``2 ** L`` basis configurations must agree with the
    probabilities from ``ExactSampler`` to within 2e-3.
    """
    L = 4
    # Number of basis configurations. The histogram bins and the slice of
    # the exact distribution both derive from L instead of hard-coded 17/16.
    numStates = 2 ** L

    weights = jnp.array([
        0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751,
        0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065,
        -0.14935214, -0.09963073, 0.17610707, 0.13386381, -0.14836467
    ])

    # Set up variational wave function
    rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
    rbmModel = nn.Model(rbm, params)
    psi = NQS(rbmModel)
    psi.set_parameters(weights)

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up MCMC sampler
    mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                    jVMC.sampler.propose_spin_flip, (L, ),
                                    numChains=777)

    # Compute exact probabilities
    _, _, pex = exactSampler.sample(psi)

    # Get samples from MCMC sampler
    numSamples = 500000
    smc, _, _ = mcSampler.sample(psi, numSamples=numSamples)

    if global_defs.usePmap:
        # Flatten the first two axes (presumably device x sample) into one
        # sample axis.
        smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

    self.assertGreaterEqual(smc.shape[0], numSamples)

    # Compute histogram of sampled configurations
    smcInt = jax.vmap(state_to_int)(smc)
    pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

    # Compare histogram to exact probabilities
    self.assertLess(
        jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples()
                        - pex.reshape((-1, ))[:numStates])),
        2e-3)
def test_autoregressive_sampling(self):
    """Sample a composite (RNN, RBM) NQS and compare with exact probabilities.

    Checks three things:
    - the log-amplitudes returned by the sampler agree with ``psi(smc)`` in
      the real part;
    - at least ``numSamples`` samples are produced;
    - the sample histogram over all ``2 ** L`` configurations matches
      ``ExactSampler`` to within 1e-3.
    """
    L = 4
    # Histogram bins and the exact-distribution slice follow from L.
    numStates = 2 ** L

    # Set up variational wave function. The RNN chain length is tied to L
    # instead of a separate hard-coded 4.
    rnn = nets.RNN.partial(L=L, hiddenSize=5, depth=2)
    _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
    rnnModel = nn.Model(rnn, params)
    rbm = nets.RBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
    rbmModel = nn.Model(rbm, params)

    psi = NQS((rnnModel, rbmModel))

    # update_parameters is used elsewhere in this file as an additive shift,
    # so this presumably doubles the parameters. It exercises
    # update_parameters on a composite model.
    ps = psi.get_parameters()
    psi.update_parameters(ps)

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up sampler
    mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                    jVMC.sampler.propose_spin_flip, (L, ),
                                    numChains=777)

    # Compute exact probabilities
    _, _, pex = exactSampler.sample(psi)

    numSamples = 500000
    smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

    # Values returned alongside the samples must agree with re-evaluating
    # the network (real part).
    self.assertLess(jnp.max(jnp.abs(jnp.real(psi(smc) - p))), 1e-12)

    if global_defs.usePmap:
        # Flatten the first two axes (presumably device x sample) into one
        # sample axis.
        smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

    self.assertGreaterEqual(smc.shape[0], numSamples)

    # Compute histogram of sampled configurations
    smcInt = jax.vmap(state_to_int)(smc)
    pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

    self.assertLess(
        jnp.max(jnp.abs(pmc / mcSampler.get_last_number_of_samples()
                        - pex.reshape((-1, ))[:numStates])),
        1e-3)
def test_evaluation_cpx(self):
    """``NQS.real_coefficients`` must equal the real part of ``NQS.__call__``.

    Runs for the first JAX device alone and then for all devices. The two
    outputs must agree in Euclidean norm to within 1e-6.
    """
    for devices in (jax.devices()[0], jax.devices()):
        global_defs.set_pmap_devices(devices)

        cpxRBM = nets.CpxRBM.partial(numHidden=2, bias=True)
        _, flaxParams = cpxRBM.init_by_shape(random.PRNGKey(0), [(4, 3)])
        model = nn.Model(cpxRBM, flaxParams)

        configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        for where in (jax.ops.index[..., 0, 1], jax.ops.index[..., 2, 2]):
            configs = jax.ops.index_update(configs, where, 1)

        wavefn = NQS(model)
        fullOutput = wavefn(configs)
        realOutput = wavefn.real_coefficients(configs)
        mismatch = jnp.linalg.norm(jnp.real(fullOutput) - realOutput)
        self.assertTrue(mismatch < 1e-6)
def time_net_eval(states, timingReps, get_net):
    """Benchmark forward evaluation of an ``NQS`` wrapping ``get_net()``.

    Prints the wall time of the first call, which includes jit compilation,
    and then the average wall time over ``timingReps`` further calls. Each
    call is synchronised with ``block_until_ready``.

    Args:
        states: Input configurations passed to the NQS.
        timingReps: Number of timed repetitions after the first call.
            Must be at least 1.
        get_net: Zero-argument callable returning the model to wrap in ``NQS``.

    Raises:
        ValueError: If ``timingReps`` is smaller than 1. This is checked before
            the model is built, so no jit time is wasted on a call that would
            otherwise end in a division by zero.
    """
    if timingReps < 1:
        raise ValueError("timingReps must be at least 1, got %r" % (timingReps,))

    model = get_net()
    psi = NQS(model)

    t0 = time.perf_counter()
    psi(states).block_until_ready()
    t1 = time.perf_counter()
    # flush=True matches time_net_gradients so progress shows up promptly.
    print(" Time elapsed (incl. jit): %f seconds" % (t1-t0), flush=True)

    t = 0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi(states).block_until_ready()
        t1 = time.perf_counter()
        t += t1-t0

    print(" Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, t/timingReps), flush=True)
def test_gs_search_cpx(self):
    """Ground-state search with a ``CpxRBM`` on a periodic ZZ + X spin chain.

    For each transverse field ``hx`` a fresh RBM is optimised with
    ``ground_state_search`` using TDVP and exact sampling. The measured
    energy must match the reference energy to relative accuracy 1e-3.
    """
    L = 4
    J = -1.0
    # (transverse field, reference energy) pairs
    cases = [(-1.3, -6.10160339), (-0.3, -4.09296160)]

    for hx, exE in cases:
        # Variational wave function
        cpxRBM = nets.CpxRBM.partial(numHidden=6, bias=False)
        _, flaxParams = cpxRBM.init_by_shape(random.PRNGKey(1), [(1, L)])
        psi = NQS(nn.Model(cpxRBM, flaxParams))

        # Hamiltonian: J * Sz_i Sz_{i+1} (periodic) + hx * Sx_i
        hamiltonianGS = op.Operator()
        for site in range(L):
            neighbour = (site + 1) % L
            hamiltonianGS.add(
                op.scal_opstr(J, (op.Sz(site), op.Sz(neighbour))))
            hamiltonianGS.add(op.scal_opstr(hx, (op.Sx(site), )))

        exactSampler = sampler.ExactSampler(L)
        tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                      snrTol=1,
                                      svdTol=1e-8,
                                      rhsPrefactor=1.,
                                      diagonalShift=2,
                                      makeReal='real')

        ground_state_search(psi, hamiltonianGS, tdvpEquation, exactSampler,
                            numSteps=100, stepSize=2e-2)

        energy = measure({"energy": hamiltonianGS}, psi, exactSampler)['energy']['mean']
        relativeError = jnp.abs((energy - exE) / exE)
        self.assertTrue(jnp.max(relativeError) < 1e-3)
def time_net_gradients(states, timingReps, get_net):
    """Benchmark ``NQS.gradients`` for the model returned by ``get_net()``.

    The NQS is built with ``batchSize=states.shape[1]``. Presumably ``states``
    is laid out as (devices, batch, ...); confirm against callers. The
    function prints the wall time of the first call, which includes jit
    compilation, and then the average over ``timingReps`` further calls.

    Args:
        states: Input configurations whose second axis sets the batch size.
        timingReps: Number of timed repetitions after the first call.
            Must be at least 1.
        get_net: Zero-argument callable returning the model to wrap in ``NQS``.

    Raises:
        ValueError: If ``timingReps`` is smaller than 1. This is checked before
            the model is built, so no jit time is wasted on a call that would
            otherwise end in a division by zero.
    """
    if timingReps < 1:
        raise ValueError("timingReps must be at least 1, got %r" % (timingReps,))

    model = get_net()
    psi = NQS(model, batchSize=states.shape[1])

    t0 = time.perf_counter()
    psi.gradients(states).block_until_ready()
    t1 = time.perf_counter()
    print(" Time elapsed (incl. jit): %f seconds" % (t1-t0), flush=True)

    t = 0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi.gradients(states).block_until_ready()
        t1 = time.perf_counter()
        t += t1-t0

    print(" Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, t/timingReps), flush=True)
def test_time_evolution(self):
    """Real-time TDVP evolution of a ``CpxRBM`` on a periodic ZZ + X chain.

    The state is integrated with ``AdaptiveHeun`` and exact sampling until
    ``t >= 0.5``. Two checks follow:
    - the relative energy drift stays below 1e-3;
    - the ZZ expectation value, linearly interpolated at
      t = 0, 0.05, ..., 0.45, matches the reference data to 1e-3.
    """
    L = 4
    J = -1.0
    hx = -0.3

    weights = jnp.array([
        0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751,
        0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065,
        -0.14935214, -0.09963073, 0.17610707, 0.13386381, -0.14836467
    ])

    # Set up variational wave function
    rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, L)])
    rbmModel = nn.Model(rbm, params)
    psi = NQS(rbmModel)
    psi.set_parameters(weights)

    # Set up hamiltonian for time evolution
    hamiltonian = op.Operator()
    for l in range(L):
        hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
        hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

    # Set up ZZ observable (nearest neighbours, periodic)
    ZZ = op.Operator()
    for l in range(L):
        ZZ.add((op.Sz(l), op.Sz((l + 1) % L)))

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up adaptive time stepper
    stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)

    tdvpEquation = jVMC.tdvp.TDVP(exactSampler, snrTol=1, svdTol=1e-8,
                                  rhsPrefactor=1.j, diagonalShift=0.,
                                  makeReal='imag')

    t = 0
    obs = []
    times = []
    times.append(t)
    newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
    obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])
    while t < 0.5:
        # Pass the current time, not a constant 0. The first argument of
        # stepper.step is presumably the integration time forwarded to the
        # TDVP right-hand side. It makes no difference for this static
        # Hamiltonian, but it would for a time-dependent one.
        dp, dt = stepper.step(t, tdvpEquation, psi.get_parameters(),
                              hamiltonian=hamiltonian, psi=psi, numSamples=0)
        psi.set_parameters(dp)
        t += dt
        times.append(t)
        newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
        obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])

    obs = np.array(jnp.asarray(obs))

    # Check energy conservation (relative deviation from the initial energy)
    obs[:, 0] = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
    self.assertLess(np.max(obs[:, 0]), 1e-3)

    # Check observable dynamics
    zz = interp1d(np.array(times), obs[:, 1, 0])
    refTimes = np.arange(0, 0.5, 0.05)
    netZZ = zz(refTimes)
    # NOTE(review): the last two reference values are identical. Only the
    # first len(netZZ) entries are compared below.
    refZZ = np.array([
        0.882762129306284, 0.8936168721790617, 0.9257753299594491,
        0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
        1.231369697427413, 1.337354107391303, 1.447796176316155,
        1.558696104640795, 1.666147269524912, 1.7664978782554912,
        1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
        2.0402054805651546, 2.067904337137255, 2.078178742959828,
        2.071635856483114, 2.049466698269522, 2.049466698269522
    ])
    self.assertLess(np.max(np.abs(netZZ - refZZ[:len(netZZ)])), 1e-3)