Esempio n. 1
0
def test_evolve_trotter_euclidean(num_sites, phys_dim, graph):
    """Euclidean (``euclidean=True``) Trotter evolution must lower the energy.

    Builds a random complex state ``psi`` of shape ``[phys_dim] * num_sites``
    and a random Hermitian two-site term ``h`` (reshaped to
    ``(phys_dim,) * 4``), repeated on each of the ``num_sites - 1`` bonds.
    It then evolves ``psi`` for 10 steps of size 0.1 with ``euclidean=True``,
    using ``evolve_trotter_defun`` when ``graph`` is true and
    ``evolve_trotter`` otherwise.

    Checks that:
      * the returned time ``t`` equals 0.1 * 10 = 1.0,
      * the evolved state satisfies <psi|psi> == 1,
      * the energy <psi|H|psi> / <psi|psi> is lower after the evolution.
    """
    # TF2 API: ``tf.random_normal`` no longer exists at the top level.
    psi = tf.complex(
        tf.random.normal([phys_dim] * num_sites, dtype=tf.float64),
        tf.random.normal([phys_dim] * num_sites, dtype=tf.float64))
    h = tf.complex(
        tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64),
        tf.random.normal((phys_dim**2, phys_dim**2), dtype=tf.float64))
    # Symmetrize so the local term is Hermitian (its expectation values are
    # then real up to rounding).
    h = 0.5 * (h + tf.linalg.adjoint(h))
    h = tf.reshape(h, (phys_dim, phys_dim, phys_dim, phys_dim))
    H = [h] * (num_sites - 1)

    norm1 = wavefunctions.inner(psi, psi)
    en1 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))

    if graph:
        psi, t = wavefunctions.evolve_trotter_defun(psi,
                                                    H,
                                                    0.1,
                                                    10,
                                                    euclidean=True)
    else:
        psi, t = wavefunctions.evolve_trotter(psi, H, 0.1, 10, euclidean=True)

    norm2 = wavefunctions.inner(psi, psi)
    en2 = sum(wavefunctions.expval(psi, H[i], i) for i in range(num_sites - 1))

    np.testing.assert_allclose(t, 1.0)
    np.testing.assert_almost_equal(norm2, 1.0)
    # The energies are complex-typed; compare their real parts explicitly
    # instead of relying on an ordering of complex scalars.
    energy_before = np.real(en1.numpy() / norm1.numpy())
    energy_after = np.real(en2.numpy() / norm2.numpy())
    assert energy_after < energy_before
def test_evolve_trotter(num_sites, phys_dim, graph):
    """Real-time Trotter evolution should preserve norm and energy.

    A random complex state of shape ``[phys_dim] * num_sites`` is evolved
    for 10 steps of size 0.001 under a random Hermitian two-site term placed
    on every one of the ``num_sites - 1`` bonds. ``evolve_trotter_defun`` is
    used when ``graph`` is true and ``evolve_trotter`` otherwise.

    Checks that the final time is 0.01 and that both <psi|psi> and the summed
    bond energies are unchanged.
    """

    def _random_complex(shape):
        # Real part drawn first, then imaginary part.
        real_part = tf.random.normal(shape, dtype=tf.float64)
        imag_part = tf.random.normal(shape, dtype=tf.float64)
        return tf.complex(real_part, imag_part)

    state = _random_complex([phys_dim] * num_sites)
    local_term = _random_complex((phys_dim**2, phys_dim**2))
    local_term = 0.5 * (local_term + tf.linalg.adjoint(local_term))
    local_term = tf.reshape(local_term, (phys_dim,) * 4)
    hamiltonian = [local_term] * (num_sites - 1)
    bonds = range(num_sites - 1)

    def _total_energy(vec):
        return sum(wavefunctions.expval(vec, hamiltonian[b], b) for b in bonds)

    norm_before = wavefunctions.inner(state, state)
    energy_before = _total_energy(state)

    evolve = (wavefunctions.evolve_trotter_defun
              if graph else wavefunctions.evolve_trotter)
    state, t_final = evolve(state, hamiltonian, 0.001, 10)

    norm_after = wavefunctions.inner(state, state)
    energy_after = _total_energy(state)

    np.testing.assert_allclose(t_final, 0.01)
    np.testing.assert_almost_equal(norm_before / norm_after, 1.0)
    np.testing.assert_almost_equal(energy_before / energy_after, 1.0, decimal=2)
Esempio n. 3
0
def test_expval(num_sites):
    """Check ``wavefunctions.expval`` of Pauli Z on computational basis states.

    For each ``flipped_bit`` the state is the basis vector with flat index
    ``2**flipped_bit``. Reshaped (row-major) to ``[2] * num_sites``, that
    state has site ``num_sites - 1 - flipped_bit`` in level 1 and every other
    site in level 0. So <Z> must be -1 on that one site and +1 elsewhere.
    """
    pauli_z = tf.convert_to_tensor(np.array([[1.0, 0.0], [0.0, -1.0]]))
    for flipped_bit in range(num_sites):
        flat = np.zeros(2**num_sites)
        flat[2**flipped_bit] = 1.0
        state = tf.convert_to_tensor(flat.reshape([2] * num_sites))
        flipped_site = num_sites - 1 - flipped_bit
        for site in range(num_sites):
            measured = wavefunctions.expval(state, pauli_z, site)
            expected = -1.0 if site == flipped_site else 1.0
            np.testing.assert_allclose(measured, expected)
Esempio n. 4
0
def test_opt(backend):
    """Optimize a random tree tensor network and check its invariants.

    Steps:
      1. Build random isometries with bond dims ``min(2**l, 8)`` for
         ``l = 1..3``.
      2. Run 5 graphed sweeps of ``ttn.opt_tree_energy`` against the
         Hamiltonian from ``ttn.get_ham_ising``, and assert the top-level
         energy went down.
      3. Assert the descended full state has norm 1 (to 1e-12).
      4. Except on the "jax" backend, recompute the energy from the full
         state with ``wf.expval(..., pbc=True)`` and compare it to the
         tree energy.
      5. Run one ungraphed sweep per decomposition mode ("eigh", "svd",
         "svd_full_iso_scipy"), asserting ``ttn.check_iso`` stays below the
         per-mode tolerance.
    """
    dtype = tf.float64 if backend == "tensorflow" else np.float64
    ops = ttn.get_backend()

    num_layers = 3
    max_bond_dim = 8
    num_sweeps = 5

    bond_dims = [
        min(2**layer, max_bond_dim) for layer in range(1, num_layers + 1)
    ]

    ham = ttn.get_ham_ising(dtype)
    isos = ttn.random_tree_tn_uniform(bond_dims, dtype, top_rank=1)
    energy_start = ops.trace(ttn.top_hamiltonian(ham, isos))
    isos = ttn.opt_tree_energy(isos,
                               ham,
                               num_sweeps,
                               1,
                               verbose=0,
                               graphed=True,
                               ham_shift=0.2)
    energy_opt = ops.trace(ttn.top_hamiltonian(ham, isos))
    assert ops.to_numpy(energy_opt) < ops.to_numpy(energy_start)

    num_sites = 2**num_layers
    full_state = ttn.descend_full_state_pure(isos)
    state_norm = ops.norm(ops.reshape(full_state, (2**num_sites, )))
    assert abs(ops.to_numpy(state_norm) - 1) < 1e-12

    if backend != "jax":
        # wavefunctions assumes TensorFlow. This will interact with numpy OK, but
        # not JAX.
        dense_term = ttn.ttn_1d_uniform._dense_ham_term(ham)
        energy_dense = sum(
            wf.expval(full_state, dense_term, site, pbc=True)
            for site in range(num_sites))
        assert abs(
            ops.to_numpy(energy_dense) - ops.to_numpy(energy_opt)) < 1e-12

    # One extra sweep per decomposition mode, each followed by an isometry
    # check at that mode's tolerance (same order as before).
    mode_tolerances = (
        ("eigh", 1e-6),
        ("svd", 1e-6),
        ("svd_full_iso_scipy", 1e-12),
    )
    for mode, tolerance in mode_tolerances:
        isos = ttn.opt_tree_energy(isos,
                                   ham,
                                   1,
                                   1,
                                   verbose=0,
                                   graphed=False,
                                   decomp_mode=mode,
                                   ham_shift=0.2)
        for iso in isos:
            assert ops.to_numpy(ttn.check_iso(iso)) < tolerance
def callback(psi, t, i):
    """Print step ``i``, the real part of ``tf.norm(psi)`` and of <X> on site 0.

    ``t`` is not used.
    """
    state_norm = tf.norm(psi).numpy().real
    x_on_site0 = wavefunctions.expval(psi, X, 0).numpy().real
    print(i, state_norm, x_on_site0)
def callback(psi, t, i):
    """Report progress for step ``i``.

    Prints, space-separated: ``i``, the real part of ``tf.norm(psi)`` and the
    real part of ``wavefunctions.expval(psi, X, 0)``. ``t`` is not used.
    """
    row = (
        i,
        tf.norm(psi).numpy().real,
        wavefunctions.expval(psi, X, 0).numpy().real,
    )
    print(*row)