def test_gradients_nonhermitian(self):
    """Compare ``NQS.gradients`` of a ``CpxRNN`` net with finite differences.

    The check is run twice: once with the first JAX device only and once
    with the full device list. For every index ``j`` along the last axis of
    the gradient array, the parameters are shifted by ``step`` along the
    ``j``-th unit vector, the network is evaluated again, and the forward
    difference quotient must agree with ``grads[..., j]`` to within
    ``tolerance``.
    """
    step = 1e-5
    tolerance = 1e-2
    # Trailing-axis positions of the input array that are set to 1.
    ones_at = ((0, 1), (2, 2))

    for devices in (jax.devices()[0], jax.devices()):
        global_defs.set_pmap_devices(devices)

        rnn = nets.CpxRNN.partial(L=3)
        _, init_params = rnn.init_by_shape(random.PRNGKey(0), [(3, )])
        rnn_model = nn.Model(rnn, init_params)

        configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        for row, col in ones_at:
            configs = jax.ops.index_update(
                configs, jax.ops.index[..., row, col], 1)

        wavefunction = NQS(rnn_model)
        reference = wavefunction(configs)
        grads = wavefunction.gradients(configs)
        saved_params = wavefunction.get_parameters()

        n_directions = grads.shape[-1]
        for j in range(n_directions):
            # Unit vector selecting the j-th perturbation direction.
            blank = jnp.zeros(n_directions, dtype=jVMC.global_defs.tReal)
            direction = jax.ops.index_update(blank, jax.ops.index[j], 1)

            # Shift, evaluate, then restore so each direction starts from
            # the same parameters.
            wavefunction.update_parameters(step * direction)
            shifted = wavefunction(configs)
            wavefunction.set_parameters(saved_params)

            # Finite difference gradients
            fd_grad = (shifted - reference) / step

            with self.subTest(i=j):
                deviation = jnp.max(jnp.abs(fd_grad - grads[..., j]))
                self.assertTrue(deviation < tolerance)
def test_gradients_cpx(self):
    """Compare ``NQS.gradients`` of a ``CpxRBM`` net with finite differences.

    The check is run twice: once with the first JAX device only and once
    with the full device list. For every index ``j`` along the last axis of
    the gradient array, the parameters are shifted by ``step`` along the
    ``j``-th unit vector, the network is evaluated again, and the forward
    difference quotient must agree with ``grads[..., j]`` to within
    ``tolerance``.
    """
    step = 1e-5
    tolerance = 1e-2
    # Trailing-axis positions of the input array that are set to 1.
    ones_at = ((0, 1), (2, 2))

    for devices in (jax.devices()[0], jax.devices()):
        global_defs.set_pmap_devices(devices)

        rbm_def = nets.CpxRBM.partial(numHidden=2, bias=True)
        _, init_params = rbm_def.init_by_shape(random.PRNGKey(0), [(1, 3)])
        rbm_model = nn.Model(rbm_def, init_params)

        configs = jnp.zeros(get_shape((4, 3)), dtype=np.int32)
        for row, col in ones_at:
            configs = jax.ops.index_update(
                configs, jax.ops.index[..., row, col], 1)

        wavefunction = NQS(rbm_model)
        reference = wavefunction(configs)
        grads = wavefunction.gradients(configs)
        saved_params = wavefunction.get_parameters()

        n_directions = grads.shape[-1]
        for j in range(n_directions):
            # Unit vector selecting the j-th perturbation direction.
            blank = jnp.zeros(n_directions, dtype=global_defs.tReal)
            direction = jax.ops.index_update(blank, jax.ops.index[j], 1)

            # Shift, evaluate, then restore so each direction starts from
            # the same parameters.
            wavefunction.update_parameters(step * direction)
            shifted = wavefunction(configs)
            wavefunction.set_parameters(saved_params)

            # Finite difference gradients
            fd_grad = (shifted - reference) / step

            with self.subTest(i=j):
                deviation = jnp.max(jnp.abs(fd_grad - grads[..., j]))
                self.assertTrue(deviation < tolerance)
def time_net_gradients(states, timingReps, get_net):
    """Time ``NQS.gradients`` on ``states`` and print the measurements.

    The first call is timed on its own and reported as including jit time.
    Afterwards the call is repeated ``timingReps`` times and the mean
    wall-clock time per call is printed.

    Args:
        states: Input passed to ``psi.gradients``; ``states.shape[1]`` is
            used as the ``batchSize`` of the ``NQS`` object.
        timingReps: Number of timed repetitions after the first call. If it
            is zero or negative, only the first-call time is printed.
        get_net: Zero-argument callable returning the model wrapped by
            ``NQS``.

    Returns:
        None. Results are written to stdout.
    """
    model = get_net()
    psi = NQS(model, batchSize=states.shape[1])

    # block_until_ready() makes the timer wait for the result itself, not
    # only for the call to return.
    t0 = time.perf_counter()
    psi.gradients(states).block_until_ready()
    t1 = time.perf_counter()
    print(" Time elapsed (incl. jit): %f seconds" % (t1 - t0), flush=True)

    if timingReps <= 0:
        # No repetitions to average over; dividing by timingReps would
        # raise ZeroDivisionError (or report a meaningless value).
        return

    total = 0.0
    for _ in range(timingReps):
        t0 = time.perf_counter()
        psi.gradients(states).block_until_ready()
        t1 = time.perf_counter()
        total += t1 - t0

    print(" Avg. time elapsed (jit'd, %d repetitions): %f seconds" % (timingReps, total / timingReps), flush=True)