예제 #1
0
    def test_gradients_nonhermitian(self):
        """Check ``NQS.gradients`` of a ``CpxRNN`` against finite differences.

        The check runs once per device configuration: first only the first
        device, then all devices. For every parameter index ``j`` it compares
        the analytic gradient ``G[..., j]`` with the forward difference
        ``(psi(p + delta * e_j) - psi(p)) / delta``. It requires the maximum
        absolute deviation over all configurations to stay below ``1e-2``.
        """

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            net = nets.CpxRNN.partial(L=3)
            _, params1 = net.init_by_shape(random.PRNGKey(0), [(3, )])
            model = nn.Model(net, params1)

            # Four configurations of three sites. get_shape is assumed to add
            # any device-related leading axes (TODO confirm). Two entries are
            # set so the input is not all zeros.
            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psi = NQS(model)
            psi0 = psi(s)
            G = psi.gradients(s)
            delta = 1e-5
            params = psi.get_parameters()
            numParams = G.shape[-1]
            for j in range(numParams):
                # Unit vector along parameter j. Use the same global_defs
                # alias that the rest of this test already relies on.
                u = jax.ops.index_update(
                    jnp.zeros(numParams, dtype=global_defs.tReal),
                    jax.ops.index[j], 1)
                psi.update_parameters(delta * u)
                psi1 = psi(s)
                # Restore the unperturbed parameters before the next index.
                psi.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                # assertLess reports the actual deviation on failure.
                # Recording the devices identifies the failing configuration.
                with self.subTest(devices=ds, i=j):
                    self.assertLess(float(jnp.max(jnp.abs(Gfd - G[..., j]))), 1e-2)
예제 #2
0
    def test_gradients_cpx(self):
        """Check ``NQS.gradients`` of a complex RBM against finite differences.

        The check runs once per device configuration: first only the first
        device, then all devices. For every parameter index ``j`` it compares
        the analytic gradient ``G[..., j]`` with the forward difference
        ``(psi(p + delta * e_j) - psi(p)) / delta``. It requires the maximum
        absolute deviation over all configurations to stay below ``1e-2``.
        """

        dlist = [jax.devices()[0], jax.devices()]

        for ds in dlist:

            global_defs.set_pmap_devices(ds)

            rbm = nets.CpxRBM.partial(numHidden=2, bias=True)
            # Keep the flax init params under their own name. ``params`` below
            # holds the NQS parameter vector, which is a different object.
            _, rbmParams = rbm.init_by_shape(random.PRNGKey(0), [(1, 3)])
            rbmModel = nn.Model(rbm, rbmParams)

            # Four configurations of three sites. get_shape is assumed to add
            # any device-related leading axes (TODO confirm). Two entries are
            # set so the input is not all zeros.
            s = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
            s = jax.ops.index_update(s, jax.ops.index[..., 0, 1], 1)
            s = jax.ops.index_update(s, jax.ops.index[..., 2, 2], 1)

            psiC = NQS(rbmModel)
            psi0 = psiC(s)
            G = psiC.gradients(s)
            delta = 1e-5
            params = psiC.get_parameters()
            numParams = G.shape[-1]
            for j in range(numParams):
                # Unit vector along parameter j.
                u = jax.ops.index_update(
                    jnp.zeros(numParams, dtype=global_defs.tReal),
                    jax.ops.index[j], 1)
                psiC.update_parameters(delta * u)
                psi1 = psiC(s)
                # Restore the unperturbed parameters before the next index.
                psiC.set_parameters(params)

                # Finite difference gradients
                Gfd = (psi1 - psi0) / delta

                # assertLess reports the actual deviation on failure.
                # Recording the devices identifies the failing configuration.
                with self.subTest(devices=ds, i=j):
                    self.assertLess(float(jnp.max(jnp.abs(Gfd - G[..., j]))), 1e-2)
def time_net_gradients(states, timingReps, get_net):
    """Time ``NQS.gradients`` on ``states`` for the model returned by ``get_net``.

    The first call is timed and printed on its own, because it includes jit
    compilation. Then ``timingReps`` further calls are timed, and the mean
    time per call is printed.

    Args:
        states: Configurations passed to ``psi.gradients``. ``states.shape[1]``
            is used as the NQS batch size. Axis 0 is presumably the device
            axis; confirm against the callers.
        timingReps: Number of timed repetitions after the first call. If it is
            not positive, only the first (jit-including) timing is printed.
        get_net: Zero-argument callable that returns the model wrapped by ``NQS``.
    """
    model = get_net()
    psi = NQS(model, batchSize=states.shape[1])

    # block_until_ready waits for the result, so the timer measures the
    # finished computation rather than just the dispatch.
    t0 = time.perf_counter()
    psi.gradients(states).block_until_ready()
    t1 = time.perf_counter()
    print("      Time elapsed (incl. jit): %f seconds" % (t1-t0), flush=True)

    # Averaging over zero repetitions would divide by zero.
    if timingReps <= 0:
        return

    total = 0.0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi.gradients(states).block_until_ready()
        t1 = time.perf_counter()
        total += t1 - t0
    print("      Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, total/timingReps), flush=True)