def test_tensordot_int_vs_backend(backend, dtype):
  """
  Checks that tensornetwork.tensordot with an integer `axes` argument agrees
  with the backend's own tensordot applied to the underlying arrays.

  Args:
    backend: Backend name, passed to tensornetwork.ones and to
      backend_factory.get_backend.
    dtype: dtype converted to the backend's native dtype via
      testing_utils.np_dtype_to_backend.
  """
  shape = (4, 4, 4)
  backend_dtype = testing_utils.np_dtype_to_backend(backend, dtype)
  # Delegated dtype-support check for contractions on this backend.
  testing_utils.check_contraction_dtype(backend, backend_dtype)
  # Build two identical all-ones operands (lhs first, then rhs).
  lhs, rhs = (
      tensornetwork.ones(shape, backend=backend, dtype=backend_dtype)
      for _ in range(2)
  )
  n_axes = 1
  actual = tensornetwork.tensordot(lhs, rhs, n_axes)
  backend_impl = backends.backend_factory.get_backend(backend)
  expected = backend_impl.tensordot(lhs.array, rhs.array, axes=n_axes)
  np.testing.assert_allclose(expected, actual.array)
def test_tensordot_invalid_backend_raises_value_error(backend, dtype):
  """
  Tests that tensordot raises ValueError when fed Tensors with different
  backends. Other failure modes are tested at the backend level.

  Args:
    backend: Name of the backend used for the first Tensor; every other
      known backend is paired against it.
    dtype: dtype converted to each backend's native dtype via
      testing_utils.np_dtype_to_backend.
  """
  all_backends = {"jax", "numpy", "tensorflow", "pytorch"}
  # Sort so the other backends are visited in a fixed order. Iterating a raw
  # set of strings depends on the interpreter's hash seed, which made the
  # order (and hence which pairs run before any skip) vary between runs.
  other_backend_names = sorted(all_backends - {backend})
  shape = (4, 4, 4)
  dtype1 = testing_utils.np_dtype_to_backend(backend, dtype)
  testing_utils.check_contraction_dtype(backend, dtype1)
  tensor1 = tensornetwork.ones(shape, backend=backend, dtype=dtype1)
  for other_backend in other_backend_names:
    dtype2 = testing_utils.np_dtype_to_backend(other_backend, dtype)
    testing_utils.check_contraction_dtype(other_backend, dtype2)
    tensor2 = tensornetwork.ones(shape, backend=other_backend, dtype=dtype2)
    # The axes are valid for these shapes, so the only failure should be
    # the backend mismatch between tensor1 and tensor2.
    with pytest.raises(ValueError):
      _ = tensornetwork.tensordot(tensor1, tensor2, [[2, 0, 1], [1, 2, 0]])