Exemplo n.º 1
0
    def test_gs_search_cpx(self):
        """Ground-state search with a complex RBM on a 4-site periodic chain.

        For each transverse field ``hx`` a fresh ``CpxRBM`` is optimised with
        ``ground_state_search`` (TDVP equation evaluated with an exact sampler).
        The measured energy must agree with the reference value in ``exEs``
        to a relative deviation below 1e-3.
        """
        L = 4
        J = -1.0
        hxs = [-1.3, -0.3]

        # Reference ground-state energies, one per entry of ``hxs``
        # (presumably from exact diagonalisation -- confirm).
        exEs = [-6.10160339, -4.09296160]

        for hx, exE in zip(hxs, exEs):
            # subTest reports which field value failed instead of hiding it
            # behind the loop.
            with self.subTest(hx=hx):
                # Set up variational wave function (fixed seed for reproducibility)
                rbm = nets.CpxRBM.partial(numHidden=6, bias=False)
                _, params = rbm.init_by_shape(random.PRNGKey(1), [(1, L)])
                rbmModel = nn.Model(rbm, params)
                psi = NQS(rbmModel)

                # Set up hamiltonian for ground state search:
                # sum_l J * Sz_l Sz_{l+1} + hx * Sx_l, with periodic wrap-around
                # via (l + 1) % L.
                hamiltonianGS = op.Operator()
                for l in range(L):
                    hamiltonianGS.add(
                        op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
                    hamiltonianGS.add(op.scal_opstr(hx, (op.Sx(l), )))

                # Set up exact sampler
                exactSampler = sampler.ExactSampler(L)

                delta = 2
                tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                              snrTol=1,
                                              svdTol=1e-8,
                                              rhsPrefactor=1.,
                                              diagonalShift=delta,
                                              makeReal='real')

                # Perform ground state search to get initial state
                ground_state_search(psi,
                                    hamiltonianGS,
                                    tdvpEquation,
                                    exactSampler,
                                    numSteps=100,
                                    stepSize=2e-2)

                obs = measure({"energy": hamiltonianGS}, psi, exactSampler)

                # Relative deviation from the reference energy. assertLess
                # shows both operands on failure, unlike assertTrue(a < b).
                relErr = jnp.max(
                    jnp.abs((obs['energy']['mean'] - exE) / exE))
                self.assertLess(relErr, 1e-3)
    def test_nonzeros(self):
        """Check ``get_O_loc`` for ``2 * (Sp_0 + Sp_1 + Sp_2)`` against a closed form.

        With ``logPsi`` and ``logPsiSP`` both set to ones, the local estimator
        is compared, per configuration, with ``2 * sum(1 - s[..., :3])``.
        That is twice the number of 0 entries among the first three sites.
        """
        L = 4
        key = random.PRNGKey(3)
        # 24 random configurations with entries in {0, 1}, reshaped to the
        # layout given by get_shape((-1, L)).
        s = random.randint(key, (24, L), 0, 2, dtype=np.int32).reshape(get_shape((-1, L)))

        # h = 2 * Sp_0 + 2 * Sp_1 + 2 * Sp_2
        h = op.Operator()
        for l in range(3):
            h.add(op.scal_opstr(2., (op.Sp(l),)))

        sp, _ = h.get_s_primes(s)

        # Identical log-amplitudes for s and s', so O_loc presumably reduces
        # to the plain sum of matrix elements per configuration.
        logPsi = jnp.ones(s.shape[:-1])
        logPsiSP = jnp.ones(sp.shape[:-1])

        tmp = h.get_O_loc(logPsi, logPsiSP)

        expected = 2. * jnp.sum(1 - s[..., :3], axis=-1)
        self.assertLess(jnp.sum(jnp.abs(tmp - expected)), 1e-7)
Exemplo n.º 3
0
import jVMC
import jVMC.operator as op
import numpy as np
import jax
import jax.numpy as jnp
import time

# Benchmark set-up: time Operator.get_s_primes for a periodic chain with
# J * Sz_l Sz_{l+1} + hx * Sx_l terms, applied to random configurations.
L = 128  # number of sites
J = -1.0  # coefficient of the Sz_l Sz_{l+1} terms
hx = -0.3  # coefficient of the single-site Sx_l terms

numStates = 1000  # configurations generated per device

# NOTE(review): not used in the visible part of this script; presumably
# consumed by repeated timing further down -- confirm.
timingReps = 10

hamiltonian = op.Operator()
for l in range(L):
    # (l + 1) % L wraps the last site back to site 0 (periodic boundary).
    hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
    hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

# Random boolean configurations of shape (device_count, numStates, L),
# converted to int32 0/1 values.
states = jnp.array(jax.random.bernoulli(jax.random.PRNGKey(0),
                                        shape=(jax.device_count(), numStates,
                                               L)),
                   dtype=np.int32)

print("* Compute off-diagonal configurations")

# NOTE(review): this first call may include JIT compilation time if
# get_s_primes is jitted -- confirm before reading t1 - t0 as steady-state cost.
t0 = time.perf_counter()
sp, me = hamiltonian.get_s_primes(states)
# JAX dispatches asynchronously; wait for the result before stopping the
# clock. Only `sp` is blocked on -- presumably `me` comes from the same
# computation.
sp.block_until_ready()
t1 = time.perf_counter()
Exemplo n.º 4
0
    def test_time_evolution(self):
        """TDVP time evolution (rhsPrefactor=1j) of a 4-site periodic chain.

        Starting from fixed complex-RBM weights, the state is propagated with
        an adaptive Heun stepper until t >= 0.5. The test checks two things:
        - the energy stays within a relative deviation of 1e-3 of its
          initial value;
        - the ZZ correlator, linearly interpolated onto t = 0, 0.05, ... (< 0.5),
          matches the stored reference values to within 1e-3.
        """
        L = 4
        J = -1.0
        hx = -0.3

        weights = jnp.array([
            0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853,
            0.21392751, 0.19648707, 0.17103704, -0.15457255, 0.10954413,
            0.13228065, -0.14935214, -0.09963073, 0.17610707, 0.13386381,
            -0.14836467
        ])

        # Set up variational wave function with the fixed initial weights
        rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, L)])
        rbmModel = nn.Model(rbm, params)
        psi = NQS(rbmModel)
        psi.set_parameters(weights)

        # Set up hamiltonian for time evolution:
        # sum_l J * Sz_l Sz_{l+1} + hx * Sx_l with periodic wrap-around.
        hamiltonian = op.Operator()
        for l in range(L):
            hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
            hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

        # Set up ZZ observable: sum_l Sz_l Sz_{l+1}
        ZZ = op.Operator()
        for l in range(L):
            ZZ.add((op.Sz(l), op.Sz((l + 1) % L)))

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up adaptive time stepper
        stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)

        tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                      snrTol=1,
                                      svdTol=1e-8,
                                      rhsPrefactor=1.j,
                                      diagonalShift=0.,
                                      makeReal='imag')

        def measure_E_ZZ():
            """Return [mean of H, mean of ZZ] for the current state of ``psi``."""
            newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
            return [newMeas['E']['mean'], newMeas['ZZ']['mean']]

        t = 0
        times = [t]
        obs = [measure_E_ZZ()]
        while t < 0.5:
            # NOTE(review): the time argument is always 0. The Hamiltonian is
            # built from constants only, so this is presumably harmless here.
            dp, dt = stepper.step(0,
                                  tdvpEquation,
                                  psi.get_parameters(),
                                  hamiltonian=hamiltonian,
                                  psi=psi,
                                  numSamples=0)
            psi.set_parameters(dp)
            t += dt
            times.append(t)
            obs.append(measure_E_ZZ())

        obs = np.array(jnp.asarray(obs))

        # Check energy conservation: relative deviation from the initial
        # energy (kept separate instead of overwriting obs[:, 0] in place).
        energyErr = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
        self.assertLess(np.max(energyErr), 1e-3)

        # Check observable dynamics
        zz = interp1d(np.array(times), obs[:, 1, 0])
        refTimes = np.arange(0, 0.5, 0.05)
        netZZ = zz(refTimes)
        # NOTE(review): the last two entries are identical (likely a copy
        # slip). Only the first len(netZZ) entries are compared, so it does
        # not affect this test.
        refZZ = np.array([
            0.882762129306284, 0.8936168721790617, 0.9257753299594491,
            0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
            1.231369697427413, 1.337354107391303, 1.447796176316155,
            1.558696104640795, 1.666147269524912, 1.7664978782554912,
            1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
            2.0402054805651546, 2.067904337137255, 2.078178742959828,
            2.071635856483114, 2.049466698269522, 2.049466698269522
        ])
        self.assertLess(np.max(np.abs(netZZ - refZZ[:len(netZZ)])), 1e-3)
Exemplo n.º 5
0
      flush=True)

L = inp["system"]["L"]

# Initialize output manager
outp = OutputManager(wdir + inp["general"]["data_output"],
                     append=inp["general"]["append_data"])

# Set up variational wave function
psi = init_net(inp["network"], [(L, )])

outp.print("** Network properties")
outp.print("    Number of parameters: %d" % (len(psi.get_parameters())))

# Set up hamiltonian for ground state search:
# sum_l J0 * Sz_l Sz_{l+1} + hx0 * Sx_l (periodic via (l + 1) % L), plus an
# optional single-site hz0 * Sz_l term. hz0 defaults to 0 and is skipped when
# |hz0| <= 1e-10.
hamiltonianGS = op.Operator()
hz0 = inp["system"].get("hz0", 0.0)
for l in range(L):
    hamiltonianGS.add(
        op.scal_opstr(inp["system"]["J0"], (op.Sz(l), op.Sz((l + 1) % L))))
    hamiltonianGS.add(op.scal_opstr(inp["system"]["hx0"], (op.Sx(l), )))
    if np.abs(hz0) > 1e-10:
        hamiltonianGS.add(op.scal_opstr(hz0, (op.Sz(l), )))

# Set up hamiltonian; the optional "lambda" parameter defaults to 0.
hamiltonian = op.Operator()
lbda = inp["system"].get("lambda", 0.0)
Exemplo n.º 6
0
        print("Creation of the directory %s failed" % wdir)
    else:
        print("Successfully created the directory %s " % wdir)

# Each MPI rank selects one device from jax.devices(), cycling by rank.
global_defs.set_pmap_devices(jax.devices()[mpi.rank % jax.device_count()])
print(" -> Rank %d working with device %s" % (mpi.rank, global_defs.devices()),
      flush=True)

# System size is hard-coded for this debugging run instead of being read from
# inp["system"]["L"]; no OutputManager is created here.
L = 2

# Set up hamiltonian for ground state search
hamiltonianGS = op.Operator()
hamiltonianGS.add((op.PzDiag(0), op.PzDiag(1)))

# lDim=4: presumably 4 local states per site -- confirm against ExactSampler.
sampler = jVMC.sampler.ExactSampler(L, lDim=4)

# Query the basis once and reuse it for printing and for get_s_primes.
basis = sampler.get_basis()

print(basis.shape)
print(basis)
print()

configurations, matrix_element = hamiltonianGS.get_s_primes(basis)

print(configurations)
print()
print(matrix_element)