示例#1
0
    def test_gradients_nonhermitian(self):
        """Check ``NQS.gradients`` of a ``CpxRNN`` against finite differences.

        The check is repeated with pmap restricted to the first JAX device and
        with all available JAX devices. For every parameter index ``j`` the
        parameters are shifted by ``delta`` along the unit vector ``e_j``. The
        network is re-evaluated on the same batch of configurations. The
        forward difference ``(psi(theta + delta * e_j) - psi(theta)) / delta``
        must then agree with ``G[..., j]`` to within ``1e-2``.
        """

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            net = nets.CpxRNN.partial(L=3)
            _, params1 = net.init_by_shape(random.PRNGKey(0), [(3, )])
            model = nn.Model(net, params1)

            # Four configurations of three sites, all zero except two entries.
            # Assumes the last two axes are (sample, site) and that get_shape
            # adds any device axis needed for pmap -- TODO confirm.
            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psi = NQS(model)
            psi0 = psi(s)
            G = psi.gradients(s)
            delta = 1e-5
            params = psi.get_parameters()
            for j in range(G.shape[-1]):
                # Unit vector e_j in the flat parameter space.
                u = jax.ops.index_update(
                    jnp.zeros(G.shape[-1], dtype=global_defs.tReal),
                    jax.ops.index[j], 1)
                # The forward difference below assumes update_parameters
                # adds its argument to the current parameters.
                psi.update_parameters(delta * u)
                psi1 = psi(s)
                # Restore the unperturbed parameters before the next index.
                psi.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                # Label with the device configuration so the single-device and
                # all-device runs report distinguishable failures.
                with self.subTest(devices=ds, i=j):
                    self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)
示例#2
0
    def test_gradients_cpx(self):
        """Check ``NQS.gradients`` of a ``CpxRBM`` against finite differences.

        The check is repeated with pmap restricted to the first JAX device and
        with all available JAX devices. For every parameter index ``j`` the
        parameters are shifted by ``delta`` along the unit vector ``e_j``. The
        network is re-evaluated on the same batch of configurations. The
        forward difference ``(psi(theta + delta * e_j) - psi(theta)) / delta``
        must then agree with ``G[..., j]`` to within ``1e-2``.
        """

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
            # Initial flax parameters. Kept distinct from the flat parameter
            # vector obtained from psiC below.
            _, rbmParams = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
            rbmModel = nn.Model(rbm, rbmParams)

            # Four configurations of three sites, all zero except two entries.
            # Assumes the last two axes are (sample, site) and that get_shape
            # adds any device axis needed for pmap -- TODO confirm.
            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psiC = NQS(rbmModel)
            psi0 = psiC(s)
            G = psiC.gradients(s)
            delta = 1e-5
            params = psiC.get_parameters()
            for j in range(G.shape[-1]):
                # Unit vector e_j in the flat parameter space.
                u = jax.ops.index_update(
                    jnp.zeros(G.shape[-1], dtype=global_defs.tReal),
                    jax.ops.index[j], 1)
                # The forward difference below assumes update_parameters
                # adds its argument to the current parameters.
                psiC.update_parameters(delta * u)
                psi1 = psiC(s)
                # Restore the unperturbed parameters before the next index.
                psiC.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                # Label with the device configuration so the single-device and
                # all-device runs report distinguishable failures.
                with self.subTest(devices=ds, i=j):
                    self.assertTrue(jnp.max(jnp.abs(Gfd - G[..., j])) < 1e-2)

    def test_autoregressive_sampling(self):
        """Compare MCMC sampling of an RNN x RBM ``NQS`` with exact probabilities.

        The test performs three checks:

        1. The second value returned by ``MCMCSampler.sample`` must match
           (in real part) ``psi`` evaluated on the samples.
        2. At least ``numSamples`` samples must be returned.
        3. The normalized histogram of sampled configurations must agree with
           the probabilities from ``ExactSampler`` to within ``1e-3``.
        """

        L = 4
        # Number of basis configurations histogrammed below. Assumes binary
        # local states (spin flips), i.e. 2**L configurations.
        numStates = 2**L

        # Set up variational wave function
        rnn = nets.RNN.partial(L=L, hiddenSize=5, depth=2)
        _, rnnParams = rnn.init_by_shape(random.PRNGKey(0), [(L, )])
        rnnModel = nn.Model(rnn, rnnParams)
        rbm = nets.RBM.partial(numHidden=2, bias=False)
        _, rbmParams = rbm.init_by_shape(random.PRNGKey(0), [(L, )])
        rbmModel = nn.Model(rbm, rbmParams)

        psi = NQS((rnnModel, rbmModel))

        # NOTE(review): if update_parameters adds its argument to the current
        # parameters, this doubles the parameter vector rather than being a
        # no-op -- confirm this is intended.
        ps = psi.get_parameters()
        psi.update_parameters(ps)
        # Set up exact sampler
        exactSampler = sampler.ExactSampler(L)

        # Set up MCMC sampler
        mcSampler = sampler.MCMCSampler(random.PRNGKey(0),
                                        jVMC.sampler.propose_spin_flip, (L, ),
                                        numChains=777)

        # Compute exact probabilities
        _, _, pex = exactSampler.sample(psi)

        numSamples = 500000
        smc, p, _ = mcSampler.sample(psi, numSamples=numSamples)

        # The sampler's second output must agree with psi on the same samples.
        self.assertTrue(jnp.max(jnp.abs(jnp.real(psi(smc) - p))) < 1e-12)

        if global_defs.usePmap:
            # Merge the two leading axes (presumably device x sample) into a
            # single sample axis.
            smc = smc.reshape((smc.shape[0] * smc.shape[1], -1))

        self.assertTrue(smc.shape[0] >= numSamples)

        # Compute histogram of sampled configurations: one unit-width bin per
        # integer label 0 .. numStates - 1.
        smcInt = jax.vmap(state_to_int)(smc)
        pmc, _ = np.histogram(smcInt, bins=np.arange(0, numStates + 1))

        # Flatten the exact probabilities (possibly split across devices) and
        # keep the first numStates entries for comparison.
        self.assertTrue(
            jnp.max(
                jnp.abs(pmc / mcSampler.get_last_number_of_samples() -
                        pex.reshape((-1, ))[:numStates])) < 1e-3)