Ejemplo n.º 1
0
    def test_autoregressive_sampling_with_lstm(self):
        """Check MCMC sampling of a wave function built from an LSTM and an RBM.

        Despite the method name, configurations are drawn with the MCMC
        sampler using ``propose_spin_flip`` proposals, not autoregressively.
        The test verifies that

        * the second value returned by the sampler agrees (real part) with
          re-evaluating ``psi`` on the returned samples,
        * at least ``numSamples`` configurations are produced, and
        * the empirical histogram of sampled basis states matches the exact
          probabilities from ``ExactSampler`` to within 1e-3.
        """
        L = 4
        # One histogram bin per computational basis state of the L-spin system
        # (previously hard-coded as 16 / 17, which only holds for L = 4).
        numStates = 2**L

        # Set up variational wave function from an (LSTM, RBM) pair of networks.
        # NOTE(review): presumably the first network supplies the amplitude and
        # the second the phase -- confirm against the NQS constructor.
        rnn = nets.LSTM.partial(L=L, hiddenSize=5)
        _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
        rnnModel = nn.Model(rnn, params)
        rbm = nets.RBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, params)

        psi = NQS((rnnModel, rbmModel))

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up MCMC sampler
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities (only the third return value is needed)
        _, _, pex = exactSampler.sample(psi)

        numSamples = 1000000
        smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

        # The values returned alongside the samples must agree with a fresh
        # evaluation of the wave function on those samples.
        self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12)

        # With pmap the samples carry a leading device axis; flatten it away.
        if global_defs.usePmap:
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations, one bin per basis state
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

        # Compare empirical frequencies with the exact probabilities
        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:numStates])) < 1e-3)
Ejemplo n.º 2
0
    def test_MCMC_sampling(self):
        """Compare MCMC samples of a fixed complex RBM with exact probabilities.

        The RBM parameters are overwritten with a fixed 16-entry weight
        vector, so the target distribution is deterministic. The empirical
        histogram of at least ``numSamples`` MCMC configurations must
        reproduce the exact probabilities from ``ExactSampler`` to within
        2e-3.
        """
        L = 4
        # One histogram bin per computational basis state of the L-spin system
        # (previously hard-coded as 16 / 17, which only holds for L = 4).
        numStates = 2**L

        weights = jnp.array([
            0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853,
            0.21392751, 0.19648707, 0.17103704, -0.15457255, 0.10954413,
            0.13228065, -0.14935214, -0.09963073, 0.17610707, 0.13386381,
            -0.14836467
        ])

        # Set up variational wave function with fixed parameters
        rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, params)
        psi = NQS(rbmModel)
        psi.set_parameters(weights)

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up MCMC sampler
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities
        _, _, pex = exactSampler.sample(psi)

        # Get samples from MCMC sampler
        numSamples = 500000
        smc, _, _ = mcSampler.sample(psi, numSamples=numSamples)

        # With pmap the samples carry a leading device axis; flatten it away.
        if global_defs.usePmap:
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations, one bin per basis state
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

        # Compare histogram to exact probabilities
        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:numStates])) < 2e-3)
Ejemplo n.º 3
0
    def test_gs_search_cpx(self):
        """Ground-state search with a complex RBM on a periodic L = 4 chain.

        For each transverse field ``hx`` a freshly initialised ``CpxRBM`` is
        optimised via ``ground_state_search`` driven by a TDVP equation with
        exact sampling. The measured energy must match the tabulated
        reference energy to a relative accuracy of 1e-3.
        """
        L = 4
        J = -1.0

        # (transverse field, reference energy) pairs
        cases = zip([-1.3, -0.3], [-6.10160339, -4.09296160])

        for hx, exE in cases:
            # Fresh variational state for every field value
            net = nets.CpxRBM.partial(numHidden=6, bias=False)
            _, netParams = net.init_by_shape(random.PRNGKey(1), [(1, L)])
            psi = NQS(nn.Model(net, netParams))

            # For every site: J * Sz(site) Sz(site+1 mod L) and hx * Sx(site)
            hamiltonianGS = op.Operator()
            for site in range(L):
                neighbour = (site + 1) % L
                hamiltonianGS.add(
                    op.scal_opstr(J, (op.Sz(site), op.Sz(neighbour))))
                hamiltonianGS.add(op.scal_opstr(hx, (op.Sx(site), )))

            exactSampler = sampler.ExactSampler(L)

            tdvpEquation = jVMC.tdvp.TDVP(
                exactSampler,
                snrTol=1,
                svdTol=1e-8,
                rhsPrefactor=1.,
                diagonalShift=2,
                makeReal='real',
            )

            ground_state_search(psi, hamiltonianGS, tdvpEquation, exactSampler,
                                numSteps=100, stepSize=2e-2)

            energy = measure({"energy": hamiltonianGS}, psi,
                             exactSampler)['energy']['mean']
            relativeError = jnp.max(jnp.abs((energy - exE) / exE))
            self.assertTrue(relativeError < 1e-3)
Ejemplo n.º 4
0
    def test_time_evolution(self):
        """Real-time evolution of a fixed complex RBM under a periodic chain.

        Starting from fixed RBM parameters, the state is evolved with an
        adaptive Heun stepper (TDVP right-hand side with ``rhsPrefactor=1.j``)
        until ``t >= 0.5``. After every step the means of the Hamiltonian and
        of the nearest-neighbour ``ZZ`` observable are recorded. The test
        checks that

        * the relative energy drift stays below 1e-3, and
        * the ``ZZ`` trajectory, interpolated onto ``0, 0.05, ..., 0.45``,
          matches tabulated reference values to within 1e-3.
        """
        L = 4
        J = -1.0
        hx = -0.3

        weights = jnp.array([
            0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853,
            0.21392751, 0.19648707, 0.17103704, -0.15457255, 0.10954413,
            0.13228065, -0.14935214, -0.09963073, 0.17610707, 0.13386381,
            -0.14836467
        ])

        # Set up variational wave function with fixed parameters
        rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, L)])
        rbmModel = nn.Model(rbm, params)
        psi = NQS(rbmModel)
        psi.set_parameters(weights)

        # Set up hamiltonian for time evolution:
        # for every site, J * Sz(l) Sz(l+1 mod L) and hx * Sx(l)
        hamiltonian = op.Operator()
        for l in range(L):
            hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
            hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

        # Set up ZZ observable (unscaled nearest-neighbour Sz Sz terms)
        ZZ = op.Operator()
        for l in range(L):
            ZZ.add((op.Sz(l), op.Sz((l + 1) % L)))

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up adaptive time stepper
        stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)

        tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                      snrTol=1,
                                      svdTol=1e-8,
                                      rhsPrefactor=1.j,
                                      diagonalShift=0.,
                                      makeReal='imag')

        times = []
        obs = []

        def record(time):
            """Store ``time`` and the current means of the energy and ZZ."""
            newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
            times.append(time)
            obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])

        t = 0
        record(t)
        while t < 0.5:
            # NOTE(review): the first argument is presumably the current time;
            # the Hamiltonian here has constant coefficients, so the fixed 0
            # (kept from the original) should not matter -- confirm.
            newParams, dt = stepper.step(0,
                                         tdvpEquation,
                                         psi.get_parameters(),
                                         hamiltonian=hamiltonian,
                                         psi=psi,
                                         numSamples=0)
            psi.set_parameters(newParams)
            t += dt
            record(t)

        obs = np.array(jnp.asarray(obs))

        # Check energy conservation: relative drift w.r.t. the initial energy.
        # Kept in a separate (real-valued) array instead of overwriting obs.
        energyDrift = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
        self.assertTrue(np.max(energyDrift) < 1e-3)

        # Check observable dynamics. obs[:, 1, 0] selects the first component
        # of the recorded ZZ mean (assumes the mean carries a trailing axis of
        # length >= 1 -- TODO confirm against `measure`).
        zz = interp1d(np.array(times), obs[:, 1, 0])
        refTimes = np.arange(0, 0.5, 0.05)
        netZZ = zz(refTimes)
        # Only the first len(refTimes) reference values are compared.
        # NOTE(review): the last two entries are identical, possibly a copy
        # artifact; they lie outside the compared range.
        refZZ = np.array([
            0.882762129306284, 0.8936168721790617, 0.9257753299594491,
            0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
            1.231369697427413, 1.337354107391303, 1.447796176316155,
            1.558696104640795, 1.666147269524912, 1.7664978782554912,
            1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
            2.0402054805651546, 2.067904337137255, 2.078178742959828,
            2.071635856483114, 2.049466698269522, 2.049466698269522
        ])
        self.assertTrue(np.max(np.abs(netZZ - refZZ[:len(netZZ)])) < 1e-3)