Esempio n. 1
0
def test_backend_dtype_exception():
  """A backend built with float32 must reject this array with a `TypeError`."""
  float32_backend = jax_backend.JaxBackend(dtype=np.float32)
  # `np.random.rand` produces float64 samples, i.e. not the backend's dtype.
  float64_array = np.random.rand(2, 2, 2)
  with pytest.raises(TypeError):
    float32_backend.convert_to_tensor(float64_array)
def test_broadcast_left_multiplication(dtype):
    """Compare `broadcast_left_multiplication` with explicit broadcasting."""
    be = jax_backend.JaxBackend()
    vector = be.randn((3, ), dtype=dtype, seed=10)
    tensor = be.randn((3, 4, 2), dtype=dtype, seed=10)
    actual = be.broadcast_left_multiplication(vector, tensor)
    # Trailing singleton axes line the vector up with the tensor's first axis.
    expected = np.reshape(vector, (3, 1, 1)) * tensor
    np.testing.assert_allclose(actual, expected)
def test_matrix_ops(dtype, method):
    """The backend method named `method` must agree with `sp.linalg`."""
    be = jax_backend.JaxBackend()
    operand = be.randn((4, 4), dtype=dtype, seed=10)
    # Look the function up by name on both the backend and scipy.
    from_backend = getattr(be, method)(operand)
    from_scipy = getattr(sp.linalg, method)(operand)
    np.testing.assert_almost_equal(from_backend, from_scipy)
def test_random_uniform_seed(dtype):
    """Two `random_uniform` draws with the same seed must coincide."""
    be = jax_backend.JaxBackend()
    first, second = (
        be.random_uniform((4, 4), seed=10, dtype=dtype) for _ in range(2))
    np.testing.assert_allclose(first, second)
def test_broadcast_right_multiplication(dtype):
    """Compare `broadcast_right_multiplication` with NumPy broadcasting."""
    be = jax_backend.JaxBackend()
    matrix = be.randn((2, 3), dtype=dtype, seed=10)
    vector = be.randn((3, ), dtype=dtype, seed=10)
    actual = be.broadcast_right_multiplication(matrix, vector)
    # Plain NumPy broadcasting of a (2, 3) array with a (3,) array.
    expected = np.array(matrix) * np.array(vector)
    np.testing.assert_allclose(actual, expected)
def test_zeros_dtype(dtype):
    """`zeros` must return a tensor carrying the requested dtype."""
    zeros = jax_backend.JaxBackend().zeros((4, 4), dtype=dtype)
    assert zeros.dtype == dtype
def test_randn_dtype(dtype):
    """`randn` must return a tensor carrying the requested dtype."""
    sample = jax_backend.JaxBackend().randn((4, 4), dtype=dtype)
    assert sample.dtype == dtype
Esempio n. 8
0
def test_shape_tuple():
  """`shape_tuple` must report the shape of a converted array as a tuple."""
  be = jax_backend.JaxBackend()
  tensor = be.convert_to_tensor(np.ones([2, 3, 4]))
  assert be.shape_tuple(tensor) == (2, 3, 4)
Esempio n. 9
0
def test_shape_prod():
  """`shape_prod` of a 24-element tensor filled with 2 must equal 2**24."""
  be = jax_backend.JaxBackend()
  twos = be.convert_to_tensor(2 * np.ones([1, 2, 3, 4]))
  product = np.array(be.shape_prod(twos))
  assert product == 2**24
Esempio n. 10
0
def test_eigsh_free_fermions(N, dtype, param_type):
    """
    Find the lowest eigenvalues of a 1d free-fermion Hamiltonian on N sites
    with `backend.eigsh` and compare them with exactly computed energies.

    The hermitian operator has dimension (2**N, 2**N). It is never built
    explicitly; it is only applied to vectors through `matvec`. The reference
    energies are sums of eigenvalues of the (N, N) single-particle matrix `H`.

    Args:
      N: Number of sites.
      dtype: dtype used for the local operators and the initial state.
      param_type: Forwarded to `get_ham_params` (defined outside this
        snippet) to select the Hamiltonian parameters.
    """
    backend = jax_backend.JaxBackend(precision=jax.lax.Precision.HIGHEST)
    np.random.seed(10)
    # `pot` is indexed 0..N-1 and `hop` 0..N-2 below; presumably on-site
    # potentials and nearest-neighbour hopping amplitudes -- confirm against
    # `get_ham_params`.
    pot, hop = get_ham_params(dtype, N, param_type)
    # Single-site 2x2 building blocks and their two-site Kronecker products.
    P = jnp.diag(np.array([0, -1])).astype(dtype)
    c = jnp.array([[0, 1], [0, 0]], dtype)
    n = c.T @ c
    eye = jnp.eye(2, dtype=dtype)
    neye = jnp.kron(n, eye)
    eyen = jnp.kron(eye, n)
    ccT = jnp.kron(c @ P, c.T)
    cTc = jnp.kron(c.T, c)

    @jax.jit
    def matvec(vec):
        # View the vector as (4, rest): the leading axis of size 4 holds the
        # pair of sites that the current 4x4 two-site term acts on.
        x = vec.reshape((4, 2**(N - 2)))
        out = jnp.zeros(x.shape, x.dtype)
        # First bond: pot[0] enters in full, pot[1] with weight 1/2 (its other
        # half is added by the next bond).
        t1 = neye * pot[0] + eyen * pot[1] / 2
        t2 = cTc * hop[0] - ccT * jnp.conj(hop[0])
        # Apply the 4x4 matrix (t1 + t2) to the leading axis of x.
        out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
        # Rotate the leading factor-of-2 index to the last position so the
        # next pair of sites becomes the leading axis; `out` is rotated the
        # same way to stay aligned with `x`.
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        # Interior bonds: both on-site potentials carry weight 1/2.
        for site in range(1, N - 2):
            t1 = neye * pot[site] / 2 + eyen * pot[site + 1] / 2
            t2 = cTc * hop[site] - ccT * jnp.conj(hop[site])
            out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
            x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
                (4, 2**(N - 2)))
            out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
                (4, 2**(N - 2)))
        # Last bond: pot[N - 1] enters in full.
        t1 = neye * pot[N - 2] / 2 + eyen * pot[N - 1]
        t2 = cTc * hop[N - 2] - ccT * jnp.conj(hop[N - 2])
        out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
        # NOTE(review): `x` is not read again after this point; the remaining
        # updates to it do not affect the returned value.
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))

        # One more rotation; for N >= 3 this makes N rotations in total
        # (1 + (N - 3) + 1 + 1), after which the result is flattened to 2**N.
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N)
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N)
        return out.ravel()

    # (N, N) single-particle matrix used for the exact reference spectrum.
    H = np.diag(pot) + np.diag(hop.conj(), 1) + np.diag(hop, -1)
    single_particle_energies = np.linalg.eigh(H)[0]

    # For every integer in [0, 2**N), the set bits of its binary
    # representation select which single-particle energies are summed.
    # NOTE: the loop variable rebinds the name `n` used above for c.T @ c;
    # this is harmless because `neye` and `eyen` were already computed.
    many_body_energies = []
    for n in range(2**N):
        many_body_energies.append(
            np.sum(single_particle_energies[np.nonzero(
                np.array(list(bin(n)[2:]), dtype=int)[::-1])[0]]))
    many_body_energies = np.sort(many_body_energies)

    # Normalized random starting vector (reproducible via the seed set above).
    init = jnp.array(np.random.randn(2**N)).astype(dtype)
    init /= jnp.linalg.norm(init)

    ncv = 20
    numeig = 4
    which = 'SA'
    tol = 1E-10
    maxiter = 30
    atol = 1E-8
    # Only the eigenvalues are checked; the second return value is discarded.
    eta, _ = backend.eigsh(A=matvec,
                           args=[],
                           initial_state=init,
                           num_krylov_vecs=ncv,
                           numeig=numeig,
                           which=which,
                           tol=tol,
                           maxiter=maxiter)
    # `atol` is used for both the absolute and the relative tolerance.
    np.testing.assert_allclose(eta,
                               many_body_energies[:numeig],
                               atol=atol,
                               rtol=atol)
Esempio n. 11
0
def test_eigs_dtype_raises():
    """`eigs` called with an integer dtype must raise a "dtype" `TypeError`."""
    eigs = jax_backend.JaxBackend().eigs

    def identity(x):
        return x

    with pytest.raises(TypeError, match="dtype"):
        eigs(identity, shape=(10, ), dtype=np.int32)
Esempio n. 12
0
#                   eigs and eigsh tests                     #
##############################################################
def generate_hermitian_matrix(be, dtype, D):
    """Return a (D, D) hermitian matrix drawn from `be.randn` with seed 10.

    A random square matrix is added to its own conjugate transpose, which
    makes the result hermitian.
    """
    hermitian = be.randn((D, D), dtype=dtype, seed=10)
    hermitian += hermitian.T.conj()
    return hermitian


def generate_matrix(be, dtype, D):
    """Return a (D, D) random matrix drawn from `be.randn` with seed 10."""
    shape = (D, D)
    return be.randn(shape, dtype=dtype, seed=10)


# Each case pairs a backend solver with a matrix generator, the exact NumPy
# decomposition to compare against, and the `which` selector for that solver.
# The solvers are bound from fresh JaxBackend instances at collection time.
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
@pytest.mark.parametrize(
    "solver, matrix_generator, exact_decomp, which",
    [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"),
     (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"),
     (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix,
      np.linalg.eigh, "LA"),
     (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix,
      np.linalg.eigh, "SA"),
     (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix,
      np.linalg.eigh, "LM")])
def test_eigs_eigsh_all_eigvals_with_init(dtype, solver, matrix_generator,
                                          exact_decomp, which):
    """Set up a seeded 16-dimensional problem for the `eigs`/`eigsh` solvers.

    NOTE(review): the visible body only builds the initial vector and the
    matrix; `solver`, `exact_decomp` and `which` are not used here, so the
    solver call and the comparison appear to be missing from this snippet.
    """
    backend = jax_backend.JaxBackend()
    D = 16
    np.random.seed(10)
    # Initial vector and matrix both come from the backend with seed 10.
    init = backend.randn((D, ), dtype=dtype, seed=10)
    H = matrix_generator(backend, dtype, D)
Esempio n. 13
0
def test_eps(dtype):
  """`backend.eps` must match NumPy's machine epsilon for `dtype`."""
  expected = np.finfo(dtype).eps
  actual = jax_backend.JaxBackend().eps(dtype)
  assert actual == expected
Esempio n. 14
0
def test_item(dtype):
  """`backend.item` must agree with the tensor's own `.item()` value."""
  be = jax_backend.JaxBackend()
  single = be.randn((1,), dtype=dtype, seed=10)
  via_backend = be.item(single)
  assert via_backend == single.item()
Esempio n. 15
0
def test_eye_dtype(dtype):
    """`eye` must return a tensor carrying the requested dtype."""
    identity = jax_backend.JaxBackend().eye(N=4, M=4, dtype=dtype)
    assert identity.dtype == dtype
Esempio n. 16
0
def test_sqrt():
  """`sqrt` of [4., 9.] must give [2, 3] element-wise."""
  be = jax_backend.JaxBackend()
  squares = be.convert_to_tensor(np.array([4., 9.]))
  roots = be.sqrt(squares)
  np.testing.assert_allclose(np.array([2, 3]), roots)
Esempio n. 17
0
def test_ones_dtype(dtype):
    """`ones` must return a tensor carrying the requested dtype."""
    ones = jax_backend.JaxBackend().ones((4, 4), dtype=dtype)
    assert ones.dtype == dtype
Esempio n. 18
0
def test_trace():
    """`trace` of the 2x3 test matrix must be 6 (diagonal entries 1 and 5)."""
    be = jax_backend.JaxBackend()
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    matrix = be.convert_to_tensor(np.array(rows))
    np.testing.assert_allclose(be.trace(matrix), 6)
Esempio n. 19
0
def test_reshape():
    """`reshape` must accept a NumPy array as the target shape."""
    be = jax_backend.JaxBackend()
    tensor = be.convert_to_tensor(np.ones((2, 3, 4)))
    target_shape = np.array((6, 4, 1))
    reshaped = be.reshape(tensor, target_shape)
    assert be.shape_tuple(reshaped) == (6, 4, 1)
Esempio n. 20
0
def test_convert_bad_test():
    """Converting a TensorFlow tensor must raise a `TypeError`."""
    convert = jax_backend.JaxBackend().convert_to_tensor
    with pytest.raises(TypeError):
        convert(tf.ones((2, 2)))
Esempio n. 21
0
def test_random_uniform_dtype(dtype):
    """`random_uniform` must return a tensor carrying the requested dtype."""
    sample = jax_backend.JaxBackend().random_uniform((4, 4), dtype=dtype)
    assert sample.dtype == dtype
Esempio n. 22
0
def test_norm():
    """`norm` of a 2x2 tensor of ones must equal 2."""
    be = jax_backend.JaxBackend()
    ones = be.convert_to_tensor(np.ones((2, 2)))
    result = be.norm(ones)
    assert result == 2
Esempio n. 23
0
def test_base_backend_eigs_not_implemented():
    """Calling `eigs` here is expected to raise `NotImplementedError`."""
    be = jax_backend.JaxBackend()
    rank3_tensor = be.randn((4, 2, 3), dtype=np.float64)
    with pytest.raises(NotImplementedError):
        be.eigs(rank3_tensor)
Esempio n. 24
0
def test_eye(dtype):
    """A rectangular `eye` must match `np.eye` with the same arguments."""
    be = jax_backend.JaxBackend()
    actual = be.eye(N=4, M=5, dtype=dtype)
    expected = np.eye(N=4, M=5, dtype=dtype)
    np.testing.assert_allclose(expected, actual)
Esempio n. 25
0
def test_broadcast_right_multiplication_raises():
    """A rank-2 right operand must make the call raise a `ValueError`."""
    be = jax_backend.JaxBackend()
    left = be.randn((2, 3))
    # The right operand is a (3, 3) matrix rather than a vector.
    right = be.randn((3, 3))
    with pytest.raises(ValueError):
        be.broadcast_right_multiplication(left, right)
Esempio n. 26
0
def test_zeros(dtype):
    """`zeros` must match `np.zeros` for the same shape and dtype."""
    be = jax_backend.JaxBackend()
    actual = be.zeros((4, 4), dtype=dtype)
    expected = np.zeros((4, 4), dtype=dtype)
    np.testing.assert_allclose(expected, actual)
Esempio n. 27
0
def test_sparse_shape():
    """`sparse_shape` must equal the tensor's own `.shape`."""
    be = jax_backend.JaxBackend()
    tensor = be.randn((2, 3, 4), dtype=np.float64, seed=10)
    reported = be.sparse_shape(tensor)
    np.testing.assert_allclose(reported, tensor.shape)
Esempio n. 28
0
def test_random_uniform_non_zero_imag(dtype):
    """The imaginary part of a `random_uniform` sample must not vanish."""
    sample = jax_backend.JaxBackend().random_uniform((4, 4), dtype=dtype)
    imag_norm = np.linalg.norm(np.imag(sample))
    assert imag_norm != 0.0
Esempio n. 29
0
def test_slice_raises_error():
    """`slice` must raise `ValueError` for these start and size arguments."""
    be = jax_backend.JaxBackend()
    rows = [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]
    matrix = be.convert_to_tensor(np.array(rows))
    # Two start indices but three sizes.
    start, sizes = (1, 1), (2, 2, 2)
    with pytest.raises(ValueError):
        be.slice(matrix, start, sizes)
Esempio n. 30
0
def test_randn_seed(dtype):
  """Two `randn` draws with the same seed must coincide."""
  be = jax_backend.JaxBackend(dtype=dtype)
  first, second = (be.randn((4, 4), seed=10) for _ in range(2))
  np.testing.assert_allclose(first, second)