示例#1
0
    def test_gradients_nonhermitian(self):
        """Compare ``psi.gradients`` of a complex RNN with forward finite differences.

        The check runs once with a single device and once with all devices.
        NOTE(review): the pmap device setting is left at the last entry of the
        loop after the test finishes.
        """

        deviceConfigs = [jax.devices()[0], jax.devices()]

        for devices in deviceConfigs:

            global_defs.set_pmap_devices(devices)

            rnn = nets.CpxRNN.partial(L=3)
            _, initParams = rnn.init_by_shape(random.PRNGKey(0), [(3, )])
            rnnModel = nn.Model(rnn, initParams)

            # All-zero configurations with two entries set to 1.
            configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            for row, col in ((0, 1), (2, 2)):
                configs = jax.ops.index_update(configs, jax.ops.index[..., row, col], 1)

            psi = NQS(rnnModel)
            reference = psi(configs)
            grads = psi.gradients(configs)
            eps = 1e-5
            refParams = psi.get_parameters()
            numParams = grads.shape[-1]

            for k in range(numParams):
                # Unit vector along parameter k.
                direction = jax.ops.index_update(
                    jnp.zeros(numParams, dtype=jVMC.global_defs.tReal),
                    jax.ops.index[k], 1)
                psi.update_parameters(eps * direction)
                shifted = psi(configs)
                # Restore the unperturbed parameters before the next direction.
                psi.set_parameters(refParams)

                # Forward finite-difference estimate of the k-th derivative.
                fdGrad = (shifted - reference) / eps

                with self.subTest(i=k):
                    self.assertTrue(jnp.max(jnp.abs(fdGrad - grads[..., k])) < 1e-2)
示例#2
0
    def test_gradients_cpx(self):
        """Compare ``psiC.gradients`` of a complex RBM with forward finite differences.

        The check runs once with a single device and once with all devices.
        NOTE(review): the pmap device setting is left at the last entry of the
        loop after the test finishes.
        """

        for devices in [jax.devices()[0], jax.devices()]:

            global_defs.set_pmap_devices(devices)

            rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
            _, initParams = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
            rbmModel = nn.Model(rbm, initParams)

            # All-zero configurations with two entries set to 1.
            configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            for row, col in ((0, 1), (2, 2)):
                configs = jax.ops.index_update(configs, jax.ops.index[..., row, col], 1)

            psiC = NQS(rbmModel)
            baseValue = psiC(configs)
            analytic = psiC.gradients(configs)
            step = 1e-5
            savedParams = psiC.get_parameters()
            nParams = analytic.shape[-1]

            for idx in range(nParams):
                unit = jax.ops.index_update(
                    jnp.zeros(nParams, dtype=global_defs.tReal),
                    jax.ops.index[idx], 1)
                psiC.update_parameters(step * unit)
                perturbed = psiC(configs)
                # Undo the perturbation so every direction starts from savedParams.
                psiC.set_parameters(savedParams)

                numeric = (perturbed - baseValue) / step

                with self.subTest(i=idx):
                    self.assertTrue(jnp.max(jnp.abs(numeric - analytic[..., idx])) < 1e-2)
    def test_MCMC_sampling(self):
        """Compare an MCMC sample histogram with exact probabilities.

        A complex RBM on ``L`` sites is set to fixed weights. The empirical
        frequencies of the MCMC samples must agree with the probabilities
        returned by ``ExactSampler`` to within 2e-3.
        """
        L = 4
        # Number of basis configurations. It sizes both the histogram and the
        # slice of exact probabilities instead of hard-coding 16 / 17.
        numConfigs = 2**L

        # Fixed weights (16 values). NOTE(review): presumably tied to L=4 and
        # numHidden=2 and must be regenerated if either changes.
        weights = jnp.array([
            0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853,
            0.21392751, 0.19648707, 0.17103704, -0.15457255, 0.10954413,
            0.13228065, -0.14935214, -0.09963073, 0.17610707, 0.13386381,
            -0.14836467
        ])

        # Set up variational wave function
        rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, params)
        psi = NQS(rbmModel)
        psi.set_parameters(weights)

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up MCMC sampler
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities
        _, _, pex = exactSampler.sample(psi)

        # Get samples from MCMC sampler
        numSamples = 500000
        smc, _, _ = mcSampler.sample(psi, numSamples=numSamples)

        # With pmap the samples carry a leading device axis; merge it into the
        # sample axis before counting.
        if global_defs.usePmap:
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations (one bin per integer
        # label 0 .. numConfigs-1).
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, numConfigs + 1))

        # Compare histogram to exact probabilities
        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:numConfigs])) < 2e-3)
    def test_autoregressive_sampling(self):
        """Check sampling from an (RNN, RBM) wave function against exact probabilities.

        Verifies that the values ``p`` returned alongside the samples match
        re-evaluating ``psi`` on those samples. It also checks that the sample
        histogram agrees with the exact probabilities to within 1e-3.
        """

        L = 4
        # Number of basis configurations. It sizes the histogram and the slice
        # of exact probabilities instead of hard-coding 16 / 17.
        numConfigs = 2**L

        # Set up variational wave function. The RNN length follows L so the
        # two cannot drift apart.
        rnn = nets.RNN.partial(L=L, hiddenSize=5, depth=2)
        _, params = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
        rnnModel = nn.Model(rnn, params)
        rbm = nets.RBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, params)

        psi = NQS((rnnModel, rbmModel))

        # update_parameters appears to add an increment to the current
        # parameters, so this presumably doubles them.
        # NOTE(review): confirm this is intentional.
        ps = psi.get_parameters()
        psi.update_parameters(ps)
        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Sampler object. The check below that psi(smc) equals the returned p
        # suggests it draws directly from the network here rather than doing
        # spin-flip Markov chains. TODO confirm.
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities
        _, _, pex = exactSampler.sample(psi)

        numSamples = 500000
        smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

        self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12)

        # With pmap the samples carry a leading device axis; merge it into the
        # sample axis before counting.
        if global_defs.usePmap:
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations (one bin per integer
        # label 0 .. numConfigs-1).
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, numConfigs + 1))

        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:numConfigs])) < 1e-3)
示例#5
0
    def test_evaluation_cpx(self):
        """``real_coefficients`` must equal the real part of the complex evaluation.

        The check runs once with a single device and once with all devices.
        NOTE(review): the pmap device setting is left at the last entry of the
        loop after the test finishes.
        """

        for devices in (jax.devices()[0], jax.devices()):

            global_defs.set_pmap_devices(devices)

            rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
            _, initParams = rbm.init_by_shape(random.PRNGKey(0), [(4, 3)])
            rbmModel = nn.Model(rbm, initParams)

            # All-zero configurations with two entries set to 1.
            configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            for row, col in ((0, 1), (2, 2)):
                configs = jax.ops.index_update(configs, jax.ops.index[..., row, col], 1)

            psiC = NQS(rbmModel)
            mismatch = jnp.real(psiC(configs)) - psiC.real_coefficients(configs)

            self.assertTrue(jnp.linalg.norm(mismatch) < 1e-6)
def time_net_eval(states, timingReps, get_net):
    """Benchmark forward evaluation of a network wrapped in ``NQS``.

    Prints the wall-clock time of the first call, which includes JIT
    compilation. It then prints the average over ``timingReps`` further calls.

    Args:
        states: Input configurations passed to the wave function.
        timingReps: Number of timed repetitions after the first call. If it is
            not positive, only the first (JIT-inclusive) timing is printed;
            previously this raised ZeroDivisionError.
        get_net: Zero-argument factory returning the model to wrap in ``NQS``.
    """
    psi = NQS(get_net())

    # block_until_ready() waits for the result, so the timing is not cut short
    # by JAX's asynchronous dispatch.
    t0 = time.perf_counter()
    psi(states).block_until_ready()
    t1 = time.perf_counter()
    print("      Time elapsed (incl. jit): %f seconds" % (t1-t0), flush=True)

    if timingReps <= 0:
        # An average over zero repetitions is undefined (division by zero).
        return

    total = 0.0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi(states).block_until_ready()
        total += time.perf_counter() - t0
    print("      Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, total/timingReps), flush=True)
示例#7
0
    def test_gs_search_cpx(self):
        """Ground-state search with a complex RBM for two field strengths.

        For each field value ``hx``, the relative deviation of the measured
        energy from the reference energy must stay below 1e-3.
        """
        L = 4
        J = -1.0
        # (field strength, reference ground-state energy) pairs.
        cases = [(-1.3, -6.10160339), (-0.3, -4.09296160)]

        for hx, exE in cases:
            rbm = nets.CpxRBM.partial(numHidden=6, bias=False)
            _, initParams = rbm.init_by_shape(random.PRNGKey(1), [(1, L)])
            psi = NQS(nn.Model(rbm, initParams))

            # Periodic chain: nearest-neighbour Sz-Sz coupling J plus an Sx
            # field hx on every site.
            hamiltonianGS = op.Operator()
            for site in range(L):
                neighbour = (site + 1) % L
                hamiltonianGS.add(op.scal_opstr(J, (op.Sz(site), op.Sz(neighbour))))
                hamiltonianGS.add(op.scal_opstr(hx, (op.Sx(site), )))

            exactSampler = sampler.ExactSampler(L)

            tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                          snrTol=1,
                                          svdTol=1e-8,
                                          rhsPrefactor=1.,
                                          diagonalShift=2,
                                          makeReal='real')

            ground_state_search(psi, hamiltonianGS, tdvpEquation, exactSampler,
                                numSteps=100, stepSize=2e-2)

            energy = measure({"energy": hamiltonianGS}, psi, exactSampler)['energy']['mean']
            relErr = jnp.abs((energy - exE) / exE)

            self.assertTrue(jnp.max(relErr) < 1e-3)
def time_net_gradients(states, timingReps, get_net):
    """Benchmark ``NQS.gradients`` for a freshly constructed network.

    Prints the wall-clock time of the first call, which includes JIT
    compilation. It then prints the average over ``timingReps`` further calls.

    Args:
        states: Input configurations. ``states.shape[1]`` is used as the NQS
            batch size (presumably axis 0 is a device axis; TODO confirm).
        timingReps: Number of timed repetitions after the first call. If it is
            not positive, only the first (JIT-inclusive) timing is printed;
            previously this raised ZeroDivisionError.
        get_net: Zero-argument factory returning the model to wrap in ``NQS``.
    """
    psi = NQS(get_net(), batchSize=states.shape[1])

    # block_until_ready() waits for the result, so the timing is not cut short
    # by JAX's asynchronous dispatch.
    t0 = time.perf_counter()
    psi.gradients(states).block_until_ready()
    t1 = time.perf_counter()
    print("      Time elapsed (incl. jit): %f seconds" % (t1-t0), flush=True)

    if timingReps <= 0:
        # An average over zero repetitions is undefined (division by zero).
        return

    total = 0.0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi.gradients(states).block_until_ready()
        total += time.perf_counter() - t0
    print("      Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, total/timingReps), flush=True)
示例#9
0
    def test_time_evolution(self):
        """TDVP time evolution of a fixed-weight complex RBM on a 4-site chain.

        The Hamiltonian is a periodic chain with nearest-neighbour Sz-Sz
        coupling J and an Sx field hx. The TDVP uses rhsPrefactor=1.j,
        presumably for real-time evolution. The state is evolved with an
        adaptive Heun stepper and exact sampling until t >= 0.5.

        The test then checks:
          * the relative energy drift from t=0 stays below 1e-3;
          * the ZZ observable, linearly interpolated onto refTimes, matches
            the hard-coded reference values to within 1e-3.
        """
        L = 4
        J = -1.0
        hx = -0.3

        # Fixed initial weights (16 values). NOTE(review): presumably tied to
        # L=4, numHidden=2, bias=False.
        weights = jnp.array([
            0.23898957, 0.12614753, 0.19479055, 0.17325271, 0.14619853,
            0.21392751, 0.19648707, 0.17103704, -0.15457255, 0.10954413,
            0.13228065, -0.14935214, -0.09963073, 0.17610707, 0.13386381,
            -0.14836467
        ])

        # Set up variational wave function
        rbm = nets.CpxRBM.partial(numHidden=2, bias=False)
        _, params = rbm.init_by_shape(random.PRNGKey(0), [(1, L)])
        rbmModel = nn.Model(rbm, params)
        psi = NQS(rbmModel)
        psi.set_parameters(weights)

        # Set up hamiltonian for time evolution (periodic boundary via % L)
        hamiltonian = op.Operator()
        for l in range(L):
            hamiltonian.add(op.scal_opstr(J, (op.Sz(l), op.Sz((l + 1) % L))))
            hamiltonian.add(op.scal_opstr(hx, (op.Sx(l), )))

        # Set up ZZ observable: sum of nearest-neighbour Sz-Sz products
        ZZ = op.Operator()
        for l in range(L):
            ZZ.add((op.Sz(l), op.Sz((l + 1) % L)))

        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up adaptive time stepper
        stepper = jVMCstepper.AdaptiveHeun(timeStep=1e-3, tol=1e-5)

        tdvpEquation = jVMC.tdvp.TDVP(exactSampler,
                                      snrTol=1,
                                      svdTol=1e-8,
                                      rhsPrefactor=1.j,
                                      diagonalShift=0.,
                                      makeReal='imag')

        # Record energy and ZZ at t = 0 and after every adaptive step.
        t = 0
        obs = []
        times = []
        times.append(t)
        newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
        obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])
        while t < 0.5:
            # The first argument appears to be the current time. It is always 0
            # here, presumably harmless because the Hamiltonian has no explicit
            # time dependence (TODO confirm). numSamples=0 is presumably ignored
            # by the exact sampler. dp is the new parameter vector and dt is
            # the step size actually taken.
            dp, dt = stepper.step(0,
                                  tdvpEquation,
                                  psi.get_parameters(),
                                  hamiltonian=hamiltonian,
                                  psi=psi,
                                  numSamples=0)
            psi.set_parameters(dp)
            t += dt
            times.append(t)
            newMeas = measure({'E': hamiltonian, 'ZZ': ZZ}, psi, exactSampler)
            obs.append([newMeas['E']['mean'], newMeas['ZZ']['mean']])

        obs = np.array(jnp.asarray(obs))

        # Check energy conservation: overwrite the energy column in place with
        # the relative drift from the initial energy.
        obs[:, 0] = np.abs((obs[:, 0] - obs[0, 0]) / obs[0, 0])
        self.assertTrue(np.max(obs[:, 0]) < 1e-3)

        # Check observable dynamics. obs[:, 1, 0] indexes a third axis, so each
        # measured mean is assumed to carry a trailing axis (TODO confirm shape).
        zz = interp1d(np.array(times), obs[:, 1, 0])
        refTimes = np.arange(0, 0.5, 0.05)
        netZZ = zz(refTimes)
        # Reference ZZ values on the 0.05 grid. The list is longer than
        # refTimes; only the first len(netZZ) entries are compared.
        refZZ = np.array([
            0.882762129306284, 0.8936168721790617, 0.9257753299594491,
            0.9779836185039352, 1.0482156449061142, 1.1337654450614298,
            1.231369697427413, 1.337354107391303, 1.447796176316155,
            1.558696104640795, 1.666147269524912, 1.7664978782554912,
            1.8564960156892512, 1.9334113379450693, 1.9951280521882777,
            2.0402054805651546, 2.067904337137255, 2.078178742959828,
            2.071635856483114, 2.049466698269522, 2.049466698269522
        ])
        self.assertTrue(np.max(np.abs(netZZ - refZZ[:len(netZZ)])) < 1e-3)