def test_autoregressive_sampling_with_lstm(self):
    """Check sampling from an LSTM/RBM-based NQS against exact probabilities.

    The wave function is an ``NQS`` built from a tuple of an LSTM model and an
    RBM model (presumably the LSTM makes the state autoregressive so that
    ``MCMCSampler.sample`` draws direct samples -- TODO confirm). Two checks:

    * the second value returned by ``mcSampler.sample`` (log-amplitudes,
      presumably) agrees in its real part with ``psi(smc)`` to 1e-12, and
    * the empirical histogram of sampled basis states agrees with the exact
      probabilities from ``ExactSampler`` to within 1e-3.
    """
    L = 4
    # Number of computational basis states for L spins; used for the histogram.
    numStates = 2**L

    # Set up variational wave function
    rnn = nets.LSTM.partial(L=L, hiddenSize=5)
    _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
    rnnModel = nn.Model(rnn, params)
    rbm = nets.RBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
    rbmModel = nn.Model(rbm, params)
    psi = NQS((rnnModel, rbmModel))

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up MCMC sampler
    mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                    jVMC.sampler.propose_spin_flip, (L, ),
                                    numChains=777)

    # Compute exact probabilities (configurations and log-amplitudes unused)
    _, _, pex = exactSampler.sample(psi)

    numSamples = 1000000
    smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

    # Values returned by the sampler must match a direct evaluation of psi.
    self.assertLess(jnp.max(jnp.abs(jnp.real(psi(smc) - p))), 1e-12)

    if global_defs.usePmap:
        # Merge the two leading axes so each row is a single configuration.
        smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

    self.assertGreaterEqual(smc.shape[0], numSamples)

    # Compute histogram of sampled configurations, one bin per basis state.
    # Assumes state_to_int maps configurations into [0, 2**L) -- TODO confirm.
    smcInt = jax.vmap(state_to_int)(smc)
    pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

    # Compare normalized histogram to exact probabilities
    empiricalProbs = pmc / mcSampler.get_last_number_of_samples()
    exactProbs = pex.reshape((-1, ))[:numStates]
    self.assertLess(jnp.max(jnp.abs(empiricalProbs - exactProbs)), 1e-3)
def test_MCMC_sampling(self):
    """Check MCMC spin-flip sampling of a complex RBM against exact probabilities.

    A ``CpxRBM`` with fixed weights is sampled with ``MCMCSampler`` using
    ``propose_spin_flip``; the normalized histogram of sampled basis states
    must agree with the exact probabilities from ``ExactSampler`` to within
    2e-3.
    """
    L = 4
    # Number of computational basis states for L spins; used for the histogram.
    numStates = 2**L

    weights = jnp.array([
        0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751,
        0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065,
        -0.14935214, -0.09963073, 0.17610707, 0.13386381, -0.14836467
    ])

    # Set up variational wave function
    rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
    rbmModel = nn.Model(rbm, params)
    psi = NQS(rbmModel)
    psi.set_parameters(weights)

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up MCMC sampler
    mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                    jVMC.sampler.propose_spin_flip, (L, ),
                                    numChains=777)

    # Compute exact probabilities
    _, _, pex = exactSampler.sample(psi)

    # Get samples from MCMC sampler
    numSamples = 500000
    smc, _, _ = mcSampler.sample(psi, numSamples=numSamples)

    if global_defs.usePmap:
        # Merge the two leading axes so each row is a single configuration.
        smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

    self.assertGreaterEqual(smc.shape[0], numSamples)

    # Compute histogram of sampled configurations, one bin per basis state.
    # Assumes state_to_int maps configurations into [0, 2**L) -- TODO confirm.
    smcInt = jax.vmap(state_to_int)(smc)
    pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

    # Compare normalized histogram to exact probabilities
    empiricalProbs = pmc / mcSampler.get_last_number_of_samples()
    exactProbs = pex.reshape((-1, ))[:numStates]
    self.assertLess(jnp.max(jnp.abs(empiricalProbs - exactProbs)), 2e-3)
def test_gs_search_cpx(self):
    """Ground-state search with a complex RBM on a periodic Ising-type chain.

    The Hamiltonian is ``J * Sz(l) Sz(l+1) + hx * Sx(l)`` summed over a
    periodic chain of ``L`` sites. For each field ``hx``, a fresh ``CpxRBM``
    is optimized with ``ground_state_search`` (100 TDVP steps, exact
    sampler). The resulting energy must be within a relative 1e-3 of the
    reference energy ``exE``.
    """
    L = 4
    J = -1.0
    hxs = [-1.3, -0.3]
    exEs = [-6.10160339, -4.09296160]

    for hx, exE in zip(hxs, exEs):
        # subTest reports each field value separately instead of stopping
        # at the first failure without saying which hx broke.
        with self.subTest(hx=hx):
            # Set up variational wave function
            rbm = nets.CpxRBM.partial(numHidden=6, bias=False)
            _, params = rbm.init_by_shape(random.PRNGKey(1), [(1, L)])
            rbmModel = nn.Model(rbm, params)
            psi = NQS(rbmModel)

            # Set up hamiltonian for ground state search
            hamiltonianGS = op.Operator()
            for l in range(L):
                hamiltonianGS.add(
                    op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
                hamiltonianGS.add(op.scal_opstr(hx, (op.Sx(l), )))

            # Set up exact sampler
            exactSampler = sampler.ExactSampler(L)

            delta = 2
            tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                          snrTol=1,
                                          svdTol=1e-8,
                                          rhsPrefactor=1.,
                                          diagonalShift=delta,
                                          makeReal='real')

            # Perform ground state search to get initial state
            ground_state_search(psi,
                                hamiltonianGS,
                                tdvpEquation,
                                exactSampler,
                                numSteps=100,
                                stepSize=2e-2)

            obs = measure({"energy": hamiltonianGS}, psi, exactSampler)

            relativeError = jnp.max(
                jnp.abs((obs['energy']['mean'] - exE) / exE))
            self.assertLess(relativeError, 1e-3)
def test_time_evolution(self):
    """Time evolution of a complex RBM on a periodic Ising-type chain.

    A ``CpxRBM`` with fixed weights is evolved with the TDVP equation using
    an ``AdaptiveHeun`` stepper until ``t >= 0.5``, and observables are
    measured with the exact sampler after every step. Two checks follow:

    * energy conservation: the relative deviation from the initial energy
      stays below 1e-3 at every recorded time;
    * dynamics: the nearest-neighbour ``Sz Sz`` sum, linearly interpolated
      onto the times ``0, 0.05, ..., 0.45``, matches ``refZZ`` to within
      1e-3.
    """
    L = 4
    J = -1.0
    hx = -0.3

    weights = jnp.array([
        0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751,
        0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065,
        -0.14935214, -0.09963073, 0.17610707, 0.13386381, -0.14836467
    ])

    # Set up variational wave function
    rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
    _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, L)])
    rbmModel = nn.Model(rbm, params)
    psi = NQS(rbmModel)
    psi.set_parameters(weights)

    # Set up hamiltonian for time evolution
    hamiltonian = op.Operator()
    for l in range(L):
        hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
        hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

    # Set up ZZ observable
    ZZ = op.Operator()
    for l in range(L):
        ZZ.add((op.Sz(l), op.Sz((l + 1) % L)))

    # Set up exact sampler
    exactSampler = sampler.ExactSampler(L)

    # Set up adaptive time stepper
    stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)

    tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                  snrTol=1,
                                  svdTol=1e-8,
                                  rhsPrefactor=1.j,
                                  diagonalShift=0.,
                                  makeReal='imag')

    t = 0
    obs = []
    times = []
    times.append(t)
    newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
    obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])
    while t < 0.5:
        # NOTE(review): the start time passed to the stepper is always 0,
        # not the current t. Harmless for this time-independent Hamiltonian;
        # revisit if time-dependent terms are added.
        dp, dt = stepper.step(0,
                              tdvpEquation,
                              psi.get_parameters(),
                              hamiltonian=hamiltonian,
                              psi=psi,
                              numSamples=0)
        psi.set_parameters(dp)
        t += dt
        times.append(t)
        newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
        obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])

    # Axis 1 indexes the observable (0: energy, 1: ZZ).
    obs = np.array(jnp.asarray(obs))

    # Check energy conservation. The relative deviation goes into its own
    # array rather than overwriting obs in place.
    energyDeviation = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
    self.assertLess(np.max(energyDeviation), 1e-3)

    # Check observable dynamics
    zz = interp1d(np.array(times), obs[:, 1, 0])
    refTimes = np.arange(0, 0.5, 0.05)
    netZZ = zz(refTimes)
    # Reference values on a 0.05 grid. Only the first len(refTimes) entries
    # are compared below.
    # NOTE(review): the last two entries are identical, which looks like a
    # copy artifact; they are not compared by this test.
    refZZ = np.array([
        0.882762129306284, 0.8936168721790617, 0.9257753299594491,
        0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
        1.231369697427413, 1.337354107391303, 1.447796176316155,
        1.558696104640795, 1.666147269524912, 1.7664978782554912,
        1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
        2.0402054805651546, 2.067904337137255, 2.078178742959828,
        2.071635856483114, 2.049466698269522, 2.049466698269522
    ])
    self.assertLess(np.max(np.abs(netZZ - refZZ[:len(netZZ)])), 1e-3)