def test_backend_dtype_exception():
    """A float32-pinned backend must refuse to convert a float64 array."""
    float32_backend = jax_backend.JaxBackend(dtype=np.float32)
    # np.random.rand always produces float64 data.
    float64_array = np.random.rand(2, 2, 2)
    with pytest.raises(TypeError):
        float32_backend.convert_to_tensor(float64_array)
def test_broadcast_left_multiplication(dtype):
    """Left broadcast-multiply scales each slice along axis 0 by a vector entry."""
    be = jax_backend.JaxBackend()
    vec = be.randn((3, ), dtype=dtype, seed=10)
    arr = be.randn((3, 4, 2), dtype=dtype, seed=10)
    result = be.broadcast_left_multiplication(vec, arr)
    # Reference: reshape the vector so numpy broadcasting aligns it with axis 0.
    expected = np.reshape(vec, (3, 1, 1)) * arr
    np.testing.assert_allclose(result, expected)
def test_matrix_ops(dtype, method):
    """The backend matrix function `method` agrees with scipy.linalg's namesake."""
    be = jax_backend.JaxBackend()
    mat = be.randn((4, 4), dtype=dtype, seed=10)
    backend_fn = getattr(be, method)
    scipy_fn = getattr(sp.linalg, method)
    np.testing.assert_almost_equal(backend_fn(mat), scipy_fn(mat))
def test_random_uniform_seed(dtype):
    """Two uniform draws with the same seed are identical."""
    be = jax_backend.JaxBackend()

    def draw():
        return be.random_uniform((4, 4), seed=10, dtype=dtype)

    np.testing.assert_allclose(draw(), draw())
def test_broadcast_right_multiplication(dtype):
    """Right broadcast-multiply matches numpy's trailing-axis broadcasting."""
    be = jax_backend.JaxBackend()
    lhs = be.randn((2, 3), dtype=dtype, seed=10)
    rhs = be.randn((3, ), dtype=dtype, seed=10)
    product = be.broadcast_right_multiplication(lhs, rhs)
    reference = np.array(lhs) * np.array(rhs)
    np.testing.assert_allclose(product, reference)
def test_zeros_dtype(dtype):
    """zeros() returns a tensor with exactly the requested dtype."""
    zeros_tensor = jax_backend.JaxBackend().zeros((4, 4), dtype=dtype)
    assert zeros_tensor.dtype == dtype
def test_randn_dtype(dtype):
    """randn() returns a tensor with exactly the requested dtype."""
    gaussian = jax_backend.JaxBackend().randn((4, 4), dtype=dtype)
    assert gaussian.dtype == dtype
def test_shape_tuple():
    """shape_tuple() reports a converted tensor's shape as a plain tuple."""
    be = jax_backend.JaxBackend()
    tensor = be.convert_to_tensor(np.ones([2, 3, 4]))
    assert be.shape_tuple(tensor) == (2, 3, 4)
def test_shape_prod():
    """shape_prod() of 24 entries all equal to 2 gives 2**24."""
    be = jax_backend.JaxBackend()
    # A (1, 2, 3, 4) tensor holds 1*2*3*4 = 24 elements, each equal to 2.0.
    twos = be.convert_to_tensor(np.full([1, 2, 3, 4], 2.0))
    prod = np.array(be.shape_prod(twos))
    assert prod == 2**24
def test_eigsh_free_fermions(N, dtype, param_type):
    """
    Find the lowest eigenvalues and eigenvectors
    of a 1d free-fermion Hamiltonian on N sites.
    The dimension of the hermitian matrix is
    (2**N, 2**N).

    The many-body Hamiltonian is applied matrix-free by `matvec`, which acts
    with a 4x4 two-site term on the two leading sites of the state and then
    cyclically rotates the site order by one. The exact reference spectrum
    comes from diagonalising the N x N single-particle matrix and summing the
    occupied single-particle energies for every occupation bitstring.
    """
    backend = jax_backend.JaxBackend(precision=jax.lax.Precision.HIGHEST)
    np.random.seed(10)
    # get_ham_params is defined elsewhere. From its use below, `pot` holds at
    # least N on-site potentials and `hop` has one entry fewer than `pot`,
    # which the np.diag construction of H requires.
    pot, hop = get_ham_params(dtype, N, param_type)
    # For this P, c @ P == -c, so ccT carries an extra sign relative to cTc.
    # Presumably this is the fermionic (Jordan-Wigner) sign; TODO confirm.
    P = jnp.diag(np.array([0, -1])).astype(dtype)
    # Single-site lowering operator; n = c^T c = diag(0, 1) is the occupation.
    c = jnp.array([[0, 1], [0, 0]], dtype)
    n = c.T @ c
    eye = jnp.eye(2, dtype=dtype)
    # Two-site operators acting on (left site, right site).
    neye = jnp.kron(n, eye)
    eyen = jnp.kron(eye, n)
    ccT = jnp.kron(c @ P, c.T)
    cTc = jnp.kron(c.T, c)

    @jax.jit
    def matvec(vec):
        # View the state as (two leading sites, remaining N-2 sites).
        x = vec.reshape((4, 2**(N - 2)))
        out = jnp.zeros(x.shape, x.dtype)
        # First bond (sites 0, 1). The boundary site 0 gets its full
        # potential, while interior sites get half from each adjacent bond.
        t1 = neye * pot[0] + eyen * pot[1] / 2
        t2 = cTc * hop[0] - ccT * jnp.conj(hop[0])
        # out[k, j] += sum_i (t1 + t2)[k, i] * x[i, j]
        out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
        # Rotate: move the leading site to the end, so the next bond leads.
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        # Interior bonds (site, site + 1) for site = 1 .. N-3.
        for site in range(1, N - 2):
            t1 = neye * pot[site] / 2 + eyen * pot[site + 1] / 2
            t2 = cTc * hop[site] - ccT * jnp.conj(hop[site])
            out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
            x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
                (4, 2**(N - 2)))
            out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
                (4, 2**(N - 2)))
        # Last bond (sites N-2, N-1); the boundary site N-1 gets its full
        # potential.
        # NOTE(review): with N == 2 this bond would revisit sites 0/1 after a
        # rotation, so the bookkeeping appears to assume N >= 3. TODO confirm
        # against the parametrization.
        t1 = neye * pot[N - 2] / 2 + eyen * pot[N - 1]
        t2 = cTc * hop[N - 2] - ccT * jnp.conj(hop[N - 2])
        out += jnp.einsum('ij,ki -> kj', x, t1 + t2)
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(
            (4, 2**(N - 2)))
        # One final rotation brings the total to N, restoring the original
        # site order before flattening.
        # NOTE(review): the rotated `x` below is never used (dead code).
        x = x.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N)
        out = out.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(2**N)
        return out.ravel()

    # Exact reference: N x N single-particle hopping matrix.
    H = np.diag(pot) + np.diag(hop.conj(), 1) + np.diag(hop, -1)
    single_particle_energies = np.linalg.eigh(H)[0]
    many_body_energies = []
    # Each basis index n encodes an occupation bitstring (least significant
    # bit = mode 0). Its energy is the sum of the occupied mode energies.
    # NOTE(review): this loop variable shadows the number operator `n` above;
    # matvec does not read `n`, so this is harmless but confusing.
    for n in range(2**N):
        many_body_energies.append(
            np.sum(single_particle_energies[np.nonzero(
                np.array(list(bin(n)[2:]), dtype=int)[::-1])[0]]))
    many_body_energies = np.sort(many_body_energies)

    # Random normalised starting vector (seeded via np.random.seed above).
    init = jnp.array(np.random.randn(2**N)).astype(dtype)
    init /= jnp.linalg.norm(init)
    ncv = 20
    numeig = 4
    which = 'SA'
    tol = 1E-10
    maxiter = 30
    atol = 1E-8
    eta, _ = backend.eigsh(
        A=matvec,
        args=[],
        initial_state=init,
        num_krylov_vecs=ncv,
        numeig=numeig,
        which=which,
        tol=tol,
        maxiter=maxiter)
    # The `numeig` smallest exact energies must match the Lanczos result.
    np.testing.assert_allclose(
        eta, many_body_energies[:numeig], atol=atol, rtol=atol)
def test_eigs_dtype_raises():
    """eigs rejects an integer dtype with a TypeError mentioning 'dtype'."""
    eigs = jax_backend.JaxBackend().eigs

    def identity(vec):
        return vec

    with pytest.raises(TypeError, match="dtype"):
        eigs(identity, shape=(10, ), dtype=np.int32)
# eigs and eigsh tests # ############################################################## def generate_hermitian_matrix(be, dtype, D): H = be.randn((D, D), dtype=dtype, seed=10) H += H.T.conj() return H def generate_matrix(be, dtype, D): return be.randn((D, D), dtype=dtype, seed=10) @pytest.mark.parametrize("dtype", [np.float64, np.complex128]) @pytest.mark.parametrize( "solver, matrix_generator, exact_decomp, which", [(jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LM"), (jax_backend.JaxBackend().eigs, generate_matrix, np.linalg.eig, "LR"), (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, "LA"), (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, "SA"), (jax_backend.JaxBackend().eigsh, generate_hermitian_matrix, np.linalg.eigh, "LM")]) def test_eigs_eigsh_all_eigvals_with_init(dtype, solver, matrix_generator, exact_decomp, which): backend = jax_backend.JaxBackend() D = 16 np.random.seed(10) init = backend.randn((D, ), dtype=dtype, seed=10) H = matrix_generator(backend, dtype, D)
def test_eps(dtype):
    """eps() reports numpy's machine epsilon for the given dtype."""
    machine_eps = np.finfo(dtype).eps
    assert jax_backend.JaxBackend().eps(dtype) == machine_eps
def test_item(dtype):
    """item() unwraps a one-element tensor like the tensor's own .item()."""
    be = jax_backend.JaxBackend()
    single = be.randn((1,), dtype=dtype, seed=10)
    unwrapped = be.item(single)
    assert unwrapped == single.item()
def test_eye_dtype(dtype):
    """eye() returns a tensor with exactly the requested dtype."""
    identity = jax_backend.JaxBackend().eye(N=4, M=4, dtype=dtype)
    assert identity.dtype == dtype
def test_sqrt():
    """sqrt() acts elementwise on perfect squares."""
    be = jax_backend.JaxBackend()
    squares = be.convert_to_tensor(np.array([4., 9.]))
    roots = be.sqrt(squares)
    np.testing.assert_allclose(np.array([2, 3]), roots)
def test_ones_dtype(dtype):
    """ones() returns a tensor with exactly the requested dtype."""
    ones_tensor = jax_backend.JaxBackend().ones((4, 4), dtype=dtype)
    assert ones_tensor.dtype == dtype
def test_trace():
    """trace() of a non-square matrix sums its main diagonal (1 + 5 = 6)."""
    be = jax_backend.JaxBackend()
    # [[1., 2., 3.], [4., 5., 6.]]
    rect = be.convert_to_tensor(np.arange(1.0, 7.0).reshape(2, 3))
    np.testing.assert_allclose(be.trace(rect), 6)
def test_reshape():
    """reshape() accepts the target shape as a numpy array."""
    be = jax_backend.JaxBackend()
    source = be.convert_to_tensor(np.ones((2, 3, 4)))
    reshaped = be.reshape(source, np.array((6, 4, 1)))
    assert be.shape_tuple(reshaped) == (6, 4, 1)
def test_convert_bad_test():
    """convert_to_tensor refuses a TensorFlow tensor with TypeError."""
    be = jax_backend.JaxBackend()
    with pytest.raises(TypeError):
        be.convert_to_tensor(tf.ones((2, 2)))
def test_random_uniform_dtype(dtype):
    """random_uniform() returns a tensor with exactly the requested dtype."""
    samples = jax_backend.JaxBackend().random_uniform((4, 4), dtype=dtype)
    assert samples.dtype == dtype
def test_norm():
    """norm() of a 2x2 all-ones tensor equals 2."""
    be = jax_backend.JaxBackend()
    all_ones = be.convert_to_tensor(np.ones((2, 2)))
    norm_value = be.norm(all_ones)
    assert norm_value == 2
def test_base_backend_eigs_not_implemented():
    """Calling eigs with a plain tensor is expected to raise NotImplementedError.

    NOTE(review): other tests in this module call JaxBackend().eigs as a
    working solver, so the test name ("base_backend ... not_implemented") and
    the expected exception look inconsistent with the JAX backend. Confirm
    which exception eigs actually raises for this input.
    """
    be = jax_backend.JaxBackend()
    rank3 = be.randn((4, 2, 3), dtype=np.float64)
    with pytest.raises(NotImplementedError):
        be.eigs(rank3)
def test_eye(dtype):
    """A rectangular eye() matches numpy.eye with the same arguments."""
    be = jax_backend.JaxBackend()
    rect_identity = be.eye(N=4, M=5, dtype=dtype)
    reference = np.eye(N=4, M=5, dtype=dtype)
    np.testing.assert_allclose(reference, rect_identity)
def test_broadcast_right_multiplication_raises():
    """A rank-2 right operand to broadcast_right_multiplication raises ValueError."""
    be = jax_backend.JaxBackend()
    left = be.randn((2, 3))
    right = be.randn((3, 3))
    with pytest.raises(ValueError):
        be.broadcast_right_multiplication(left, right)
def test_zeros(dtype):
    """zeros() matches numpy.zeros elementwise."""
    be = jax_backend.JaxBackend()
    zeros_tensor = be.zeros((4, 4), dtype=dtype)
    reference = np.zeros((4, 4), dtype=dtype)
    np.testing.assert_allclose(reference, zeros_tensor)
def test_sparse_shape():
    """For a dense tensor, sparse_shape() agrees with the tensor's own shape."""
    be = jax_backend.JaxBackend()
    tensor = be.randn((2, 3, 4), dtype=np.float64, seed=10)
    np.testing.assert_allclose(be.sparse_shape(tensor), tensor.shape)
def test_random_uniform_non_zero_imag(dtype):
    """random_uniform() draws a non-trivial imaginary part for this dtype."""
    be = jax_backend.JaxBackend()
    samples = be.random_uniform((4, 4), dtype=dtype)
    imag_norm = np.linalg.norm(np.imag(samples))
    assert imag_norm != 0.0
def test_slice_raises_error():
    """slice() raises ValueError for a length-2 start with a length-3 size."""
    be = jax_backend.JaxBackend()
    # [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]
    grid = be.convert_to_tensor(np.arange(1., 10.).reshape(3, 3))
    with pytest.raises(ValueError):
        be.slice(grid, (1, 1), (2, 2, 2))
def test_randn_seed(dtype):
    """With the dtype fixed on the backend, equal seeds give equal randn draws."""
    be = jax_backend.JaxBackend(dtype=dtype)
    draws = [be.randn((4, 4), seed=10) for _ in range(2)]
    np.testing.assert_allclose(*draws)