def test_einsum_array_manipulation(ctx_factory, spec):
    """Check a single-operand einsum kernel against ``np.einsum``.

    A kernel is built from *spec* with ``lp.make_einsum``, run on a random
    4x4 array, and its output is compared with NumPy's result for the same
    subscript string.

    :arg ctx_factory: zero-argument callable; its return value is handed to
        ``cl.CommandQueue``.
    :arg spec: einsum subscript string taking one operand (presumably
        supplied by test parametrization outside this block).
    """
    queue = cl.CommandQueue(ctx_factory())

    size = 4
    operand = np.random.rand(size, size)

    kernel = lp.make_einsum(spec, ("a", ))
    _evt, (result, ) = kernel(queue, a=operand)

    expected = np.einsum(spec, operand)
    error = np.linalg.norm(result - expected)
    assert error <= 1e-15
def test_einsum_array_ops_same_dims(ctx_factory, spec):
    """Check a two-operand einsum kernel against ``np.einsum``.

    Both operands are random 4x4 arrays. The kernel built by
    ``lp.make_einsum`` is run on them and its output is compared with
    NumPy's result for the same subscript string.

    :arg ctx_factory: zero-argument callable; its return value is handed to
        ``cl.CommandQueue``.
    :arg spec: einsum subscript string taking two operands (presumably
        supplied by test parametrization outside this block).
    """
    queue = cl.CommandQueue(ctx_factory())

    size = 4
    # Keyed by kernel argument name; insertion order fixes both the order of
    # the random draws and the positional order passed to np.einsum.
    operands = {
        name: np.random.rand(size, size)
        for name in ("a", "b")
    }

    kernel = lp.make_einsum(spec, tuple(operands))
    _evt, (result, ) = kernel(queue, **operands)

    expected = np.einsum(spec, *operands.values())
    error = np.linalg.norm(result - expected)
    # NOTE(review): the absolute 1e-15 bound is close to double-precision
    # round-off for summed products -- confirm it is not flaky across devices.
    assert error <= 1e-15
def test_einsum_array_ops_triple_prod(ctx_factory, spec):
    """Check a three-operand einsum kernel against ``np.einsum``.

    All three operands are random 3x3 arrays. The kernel built by
    ``lp.make_einsum`` is run on them and its output is compared with
    NumPy's result for the same subscript string.

    :arg ctx_factory: zero-argument callable; its return value is handed to
        ``cl.CommandQueue``.
    :arg spec: einsum subscript string taking three operands (presumably
        supplied by test parametrization outside this block).
    """
    queue = cl.CommandQueue(ctx_factory())

    size = 3
    # Keyed by kernel argument name; insertion order fixes both the order of
    # the random draws and the positional order passed to np.einsum.
    operands = {
        name: np.random.rand(size, size)
        for name in ("a", "b", "c")
    }

    kernel = lp.make_einsum(spec, tuple(operands))
    _evt, (result, ) = kernel(queue, **operands)

    expected = np.einsum(spec, *operands.values())
    error = np.linalg.norm(result - expected)
    # NOTE(review): the absolute 1e-15 bound is close to double-precision
    # round-off for a triple product -- confirm it is not flaky across devices.
    assert error <= 1e-15
def test_make_einsum_error_handling():
    """Check that ``lp.make_einsum`` raises ``ValueError`` on bad input."""
    invalid_cases = (
        # The spec names two operands but only one argument name is given.
        ("ij,j->j", ("a", )),
        # The output side of the spec repeats the index ``j``.
        ("ij,j->jj", ("a", "b")),
    )

    for bad_spec, arg_names in invalid_cases:
        with pytest.raises(ValueError):
            lp.make_einsum(bad_spec, arg_names)