def test_gs_search_cpx(self):
    """Ground-state search with a complex RBM on a periodic 4-site Ising chain.

    For every (transverse field, reference energy) pair, a freshly initialised
    ``CpxRBM`` is optimised with ``ground_state_search`` (TDVP driven by an
    exact sampler).  The energy measured afterwards must agree with the
    reference value to a relative accuracy of 1e-3.
    """
    L = 4
    J = -1.0
    fields = [-1.3, -0.3]
    refEnergies = [-6.10160339, -4.09296160]

    def tfim_ring(field):
        # Periodic ZZ coupling plus a transverse X field; for each site the
        # ZZ term is added first, then the X term.
        ham = op.Operator()
        for site in range(L):
            ham.add(op.scal_opstr(J, (op.Sz(site), op.Sz((site + 1) % L))))
            ham.add(op.scal_opstr(field, (op.Sx(site), )))
        return ham

    for field, refEnergy in zip(fields, refEnergies):
        # Variational ansatz: complex RBM, 6 hidden units, no bias, fixed seed
        net = nets.CpxRBM.partial(numHidden=6, bias=False)
        _, initParams = net.init_by_shape(random.PRNGKey(1), [(1, L)])
        psi = NQS(nn.Model(net, initParams))

        ham = tfim_ring(field)

        # Exact enumeration of the basis, so there is no sampling noise
        exactSampler = sampler.ExactSampler(L)
        shift = 2
        tdvpEquation = jVMC.tdvp.TDVP(exactSampler, snrTol=1, svdTol=1e-8,
                                      rhsPrefactor=1., diagonalShift=shift,
                                      makeReal='real')

        ground_state_search(psi, ham, tdvpEquation, exactSampler,
                            numSteps=100, stepSize=2e-2)

        energy = measure({"energy": ham}, psi, exactSampler)['energy']['mean']
        relErr = jnp.abs((energy - refEnergy) / refEnergy)
        self.assertTrue(jnp.max(relErr) < 1e-3)
def test_nonzeros(self):
    """Check the local estimator of a sum of raising operators.

    The operator ``2 * (Sp(0) + Sp(1) + Sp(2))`` is applied to 24 random 0/1
    configurations of a 4-site system.  Both log-amplitude arrays are filled
    with ones.  The resulting ``O_loc`` must equal
    ``2 * sum_{i < 3} (1 - s_i)`` for each configuration.  In other words,
    each ``Sp`` term is expected to contribute 2 exactly when its site is 0.
    """
    L = 4
    numRaised = 3  # Sp acts on sites 0 .. numRaised - 1
    key = random.PRNGKey(3)
    # NOTE(review): get_shape presumably prepends device/batch axes — confirm
    s = random.randint(key, (24, L), 0, 2, dtype=np.int32).reshape(get_shape((-1, L)))

    h = op.Operator()
    for site in range(numRaised):
        h.add(op.scal_opstr(2., (op.Sp(site),)))

    # Only the connected configurations are needed; the matrix elements are unused
    sp, _ = h.get_s_primes(s)
    logPsi = jnp.ones(s.shape[:-1])
    logPsiSP = jnp.ones(sp.shape[:-1])
    tmp = h.get_O_loc(logPsi, logPsiSP)

    expected = 2. * jnp.sum(-(s[..., :numRaised] - 1), axis=-1)
    # assertLess reports the offending value on failure, unlike assertTrue(a < b)
    self.assertLess(jnp.sum(jnp.abs(tmp - expected)), 1e-7)
# Benchmark script: time Operator.get_s_primes for a periodic
# transverse-field Ising chain on a batch of random spin configurations.
import jVMC
import jVMC.operator as op
import numpy as np
import jax
import jax.numpy as jnp
import time

# System size and couplings
L = 128
J = -1.0
hx = -0.3

# Configurations per device, and number of timing repetitions
numStates = 1000
timingReps = 10  # NOTE(review): unused in this chunk; presumably consumed by later code — confirm

# H = sum_l [ J * Sz_l Sz_{l+1} + hx * Sx_l ], periodic: site L wraps to site 0
hamiltonian = op.Operator()
for l in range(L):
    hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
    hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

# Random 0/1 configurations (bernoulli booleans cast to int32),
# shape (device_count, numStates, L)
states = jnp.array(jax.random.bernoulli(jax.random.PRNGKey(0), shape=(jax.device_count(), numStates, L)), dtype=np.int32)

print("* Compute off-diagonal configurations")
t0 = time.perf_counter()
# NOTE(review): this is the first call, so t1 - t0 presumably includes JIT compilation time
sp, me = hamiltonian.get_s_primes(states)
# Wait for the asynchronously dispatched computation to finish before stopping the clock
sp.block_until_ready()
t1 = time.perf_counter()
def test_time_evolution(self):
    """TDVP time evolution of a complex RBM on a periodic 4-site Ising chain.

    Fixed RBM weights are loaded and then propagated with an adaptive Heun
    stepper (``rhsPrefactor=1.j``) until t >= 0.5.  The test checks two
    things.  First, the relative energy drift stays below 1e-3.  Second, the
    ZZ correlator, interpolated onto the grid 0, 0.05, ... (below 0.5),
    matches the tabulated reference to within 1e-3.
    """
    L = 4
    J = -1.0
    hx = -0.3
    weights = jnp.array([
        0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853, 0.21392751,
        0.19648707, 0.17103704, -0.15457255, 0.10954413, 0.13228065, -0.14935214,
        -0.09963073, 0.17610707, 0.13386381, -0.14836467
    ])

    # Complex RBM with two hidden units and no bias, loaded with fixed weights
    net = nets.CpxRBM.partial(numHidden=2, bias=False)
    _, initParams = net.init_by_shape(random.PRNGKey(0), [(1, L)])
    psi = NQS(nn.Model(net, initParams))
    psi.set_parameters(weights)

    # Periodic transverse-field Ising Hamiltonian driving the evolution
    hamiltonian = op.Operator()
    for site in range(L):
        hamiltonian.add(op.scal_opstr(J, (op.Sz(site), op.Sz((site + 1) % L))))
        hamiltonian.add(op.scal_opstr(hx, (op.Sx(site), )))

    # Nearest-neighbour ZZ correlator, summed around the ring
    ZZ = op.Operator()
    for site in range(L):
        ZZ.add((op.Sz(site), op.Sz((site + 1) % L)))

    exactSampler = sampler.ExactSampler(L)
    stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)
    tdvpEquation = jVMC.tdvp.TDVP(exactSampler, snrTol=1, svdTol=1e-8,
                                  rhsPrefactor=1.j, diagonalShift=0.,
                                  makeReal='imag')

    def record():
        # Returns [energy mean, ZZ mean] for the current parameters
        res = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
        return [res['E']['mean'], res['ZZ']['mean']]

    t = 0
    times = [t]
    obs = [record()]
    while t < 0.5:
        # NOTE(review): the time argument is always 0; this only matters if the
        # Hamiltonian were time-dependent, which it is not here.
        dp, dt = stepper.step(0, tdvpEquation, psi.get_parameters(),
                              hamiltonian=hamiltonian, psi=psi, numSamples=0)
        psi.set_parameters(dp)
        t += dt
        times.append(t)
        obs.append(record())
    obs = np.array(jnp.asarray(obs))

    # Energy conservation: overwrite column 0 with the relative drift
    # with respect to the initial energy
    obs[:, 0] = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
    self.assertTrue(np.max(obs[:, 0]) < 1e-3)

    # ZZ dynamics: interpolate the recorded trajectory onto the reference grid
    zzOfT = interp1d(np.array(times), obs[:, 1, 0])
    refTimes = np.arange(0, 0.5, 0.05)
    netZZ = zzOfT(refTimes)
    # NOTE(review): the last two entries are identical (possible copy-paste slip);
    # harmless, because only the first len(netZZ) entries are compared.
    refZZ = np.array([
        0.882762129306284, 0.8936168721790617, 0.9257753299594491,
        0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
        1.231369697427413, 1.337354107391303, 1.447796176316155,
        1.558696104640795, 1.666147269524912, 1.7664978782554912,
        1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
        2.0402054805651546, 2.067904337137255, 2.078178742959828,
        2.071635856483114, 2.049466698269522, 2.049466698269522
    ])
    self.assertTrue(np.max(np.abs(netZZ - refZZ[:len(netZZ)])) < 1e-3)
flush=True) L = inp["system"]["L"] # Initialize output manager outp = OutputManager(wdir + inp["general"]["data_output"], append=inp["general"]["append_data"]) # Set up variational wave function psi = init_net(inp["network"], [(L, )]) outp.print("** Network properties") outp.print(" Number of parameters: %d" % (len(psi.get_parameters()))) # Set up hamiltonian for ground state search hamiltonianGS = op.Operator() hz0 = 0.0 if "hz0" in inp["system"].keys(): hz0 = inp["system"]["hz0"] for l in range(L): hamiltonianGS.add( op.scal_opstr(inp["system"]["J0"], (op.Sz(l), op.Sz((l + 1) % L)))) hamiltonianGS.add(op.scal_opstr(inp["system"]["hx0"], (op.Sx(l), ))) if np.abs(hz0) > 1e-10: hamiltonianGS.add(op.scal_opstr(hz0, (op.Sz(l), ))) # Set up hamiltonian hamiltonian = op.Operator() lbda = 0.0 if "lambda" in inp["system"].keys(): lbda = inp["system"]["lambda"]
print("Creation of the directory %s failed" % wdir) else: print("Successfully created the directory %s " % wdir) global_defs.set_pmap_devices(jax.devices()[mpi.rank % jax.device_count()]) print(" -> Rank %d working with device %s" % (mpi.rank, global_defs.devices()), flush=True) # L = inp["system"]["L"] L = 2 # Initialize output manager # outp = OutputManager(wdir+inp["general"]["data_output"], append=inp["general"]["append_data"]) # Set up hamiltonian for ground state search hamiltonianGS = op.Operator() hamiltonianGS.add((op.PzDiag(0), op.PzDiag(1))) sampler = jVMC.sampler.ExactSampler(L, lDim=4) print(sampler.get_basis().shape) print(sampler.get_basis()) print() configurations, matrix_element = hamiltonianGS.get_s_primes( sampler.get_basis()) print(configurations) print() print(matrix_element)