Esempio n. 1
0
def test_tensordot_int_vs_backend(backend, dtype):
    """Check tensordot with an integer axes argument against the raw backend.

    Two all-ones tensors of shape (4, 4, 4) are contracted through
    `tensornetwork.tensordot` with the integer argument 1. The same
    contraction is then run directly on the underlying arrays with the
    backend object's own `tensordot`, and the two results must agree.

    Args:
      backend: Name of the backend under test.
      dtype: Dtype to convert to the backend's dtype via
        `testing_utils.np_dtype_to_backend`.
    """
    num_axes = 1
    tensor_shape = (4, 4, 4)
    backend_dtype = testing_utils.np_dtype_to_backend(backend, dtype)
    # NOTE(review): presumably guards against dtypes the backend cannot
    # contract -- confirm against testing_utils.
    testing_utils.check_contraction_dtype(backend, backend_dtype)

    lhs = tensornetwork.ones(tensor_shape, backend=backend,
                             dtype=backend_dtype)
    rhs = tensornetwork.ones(tensor_shape, backend=backend,
                             dtype=backend_dtype)

    # Contraction through the Tensor-level API.
    actual = tensornetwork.tensordot(lhs, rhs, num_axes)

    # Reference contraction performed directly on the wrapped arrays.
    raw_backend = backends.backend_factory.get_backend(backend)
    expected = raw_backend.tensordot(lhs.array, rhs.array, axes=num_axes)

    np.testing.assert_allclose(expected, actual.array)
Esempio n. 2
0
def test_tensordot_invalid_backend_raises_value_error(backend, dtype):
    """Check that tensordot rejects Tensors that live on different backends.

    A tensor is built on `backend`, then paired in turn with a tensor built
    on each of the other known backends ("jax", "numpy", "tensorflow",
    "pytorch"). Every such mixed-backend call to `tensornetwork.tensordot`
    must raise ValueError. Other failure modes are tested at the backend
    level.

    Args:
      backend: Name of the backend used for the first tensor.
      dtype: Dtype to convert to each backend's own dtype via
        `testing_utils.np_dtype_to_backend`.
    """
    all_backends = {"jax", "numpy", "tensorflow", "pytorch"}
    mismatched_backends = list(all_backends - {backend})

    tensor_shape = (4, 4, 4)
    contraction_axes = [[2, 0, 1], [1, 2, 0]]

    own_dtype = testing_utils.np_dtype_to_backend(backend, dtype)
    # NOTE(review): presumably guards against dtypes the backend cannot
    # contract -- confirm against testing_utils.
    testing_utils.check_contraction_dtype(backend, own_dtype)
    own_tensor = tensornetwork.ones(tensor_shape, backend=backend,
                                    dtype=own_dtype)

    for foreign_backend in mismatched_backends:
        foreign_dtype = testing_utils.np_dtype_to_backend(
            foreign_backend, dtype)
        testing_utils.check_contraction_dtype(foreign_backend, foreign_dtype)
        foreign_tensor = tensornetwork.ones(tensor_shape,
                                            backend=foreign_backend,
                                            dtype=foreign_dtype)
        # Mixing backends in a single contraction must be refused.
        with pytest.raises(ValueError):
            tensornetwork.tensordot(own_tensor, foreign_tensor,
                                    contraction_axes)