Exemplo n.º 1
0
def test_einsum_array_manipulation(ctx_factory, spec):
    """Check a single-operand einsum kernel against ``np.einsum``.

    A kernel is built from *spec* with one operand named ``a``, executed on
    a random 4x4 matrix, and its output is compared with the NumPy result
    for the same *spec* and input.

    :arg ctx_factory: called with no arguments; its return value is handed
        to ``cl.CommandQueue``.
    :arg spec: einsum subscript string, forwarded unchanged to both
        ``lp.make_einsum`` and ``np.einsum``.
    """
    queue = cl.CommandQueue(ctx_factory())

    dim = 4
    operand = np.random.rand(dim, dim)

    kernel = lp.make_einsum(spec, ("a", ))
    # Only the single output array is needed; the first return value of the
    # kernel call is unused here.
    _, (result, ) = kernel(queue, a=operand)

    expected = np.einsum(spec, operand)
    abs_err = np.linalg.norm(result - expected)

    assert abs_err <= 1e-15
Exemplo n.º 2
0
def test_einsum_array_ops_same_dims(ctx_factory, spec):
    """Check a two-operand einsum kernel against ``np.einsum``.

    A kernel is built from *spec* with operands named ``a`` and ``b``,
    executed on two random 4x4 matrices, and its output is compared with
    the NumPy result for the same *spec* and inputs.

    :arg ctx_factory: called with no arguments; its return value is handed
        to ``cl.CommandQueue``.
    :arg spec: einsum subscript string, forwarded unchanged to both
        ``lp.make_einsum`` and ``np.einsum``.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    n = 4
    a = np.random.rand(n, n)
    b = np.random.rand(n, n)
    arg_names = ("a", "b")

    knl = lp.make_einsum(spec, arg_names)
    # The first return value of the kernel call is not needed for the check.
    _evt, (out, ) = knl(queue, a=a, b=b)
    ans = np.einsum(spec, a, b)

    # A bare absolute bound of 1e-15 sits at double-precision rounding level
    # once the result magnitude exceeds ~1, so a different (but equally
    # valid) accumulation order could fail the test spuriously.  Scale the
    # bound with the reference result, keeping the former absolute bound as
    # a floor so the check is never stricter than it used to be.
    rtol = 1e-14
    tol = max(1e-15, rtol * np.linalg.norm(ans))

    assert np.linalg.norm(out - ans) <= tol
Exemplo n.º 3
0
def test_einsum_array_ops_triple_prod(ctx_factory, spec):
    """Check a three-operand einsum kernel against ``np.einsum``.

    A kernel is built from *spec* with operands named ``a``, ``b`` and
    ``c``, executed on three random 3x3 matrices, and its output is
    compared with the NumPy result for the same *spec* and inputs.

    :arg ctx_factory: called with no arguments; its return value is handed
        to ``cl.CommandQueue``.
    :arg spec: einsum subscript string, forwarded unchanged to both
        ``lp.make_einsum`` and ``np.einsum``.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    n = 3
    a = np.random.rand(n, n)
    b = np.random.rand(n, n)
    c = np.random.rand(n, n)
    arg_names = ("a", "b", "c")

    knl = lp.make_einsum(spec, arg_names)
    # The first return value of the kernel call is not needed for the check.
    _evt, (out, ) = knl(queue, a=a, b=b, c=c)
    ans = np.einsum(spec, a, b, c)

    # A bare absolute bound of 1e-15 sits at double-precision rounding level
    # once the result magnitude exceeds ~1, so a different (but equally
    # valid) accumulation order could fail the test spuriously.  Scale the
    # bound with the reference result, keeping the former absolute bound as
    # a floor so the check is never stricter than it used to be.
    rtol = 1e-14
    tol = max(1e-15, rtol * np.linalg.norm(ans))

    assert np.linalg.norm(out - ans) <= tol
Exemplo n.º 4
0
def test_make_einsum_error_handling():
    """Check that ``lp.make_einsum`` raises ``ValueError`` on bad requests.

    Every ``(spec, arg_names)`` pair listed below is expected to be
    rejected with :class:`ValueError`.
    """
    rejected_cases = (
        # The spec lists two operands, but only one argument name is given.
        ("ij,j->j", ("a", )),
        # The output subscript repeats the index "j".
        ("ij,j->jj", ("a", "b")),
    )

    for spec, arg_names in rejected_cases:
        with pytest.raises(ValueError):
            lp.make_einsum(spec, arg_names)