Exemplo n.º 1
0
def test_transpose_input_validation():
    """Check that ``pt.transpose`` rejects malformed *axes* arguments.

    A 10x10 placeholder is transposed once without explicit axes (which must
    succeed), then with several invalid axes tuples, each of which is
    expected to raise :class:`ValueError`.
    """
    placeholder = pt.make_placeholder(
            name="a", shape=(10, 10), dtype=np.float64)

    # Omitting the axes argument must be accepted.
    pt.transpose(placeholder)

    invalid_axes = (
        (2, 0, 1),  # three entries for a two-dimensional array
        (1, 1),     # the same axis listed twice
        (0,),       # only one entry for a two-dimensional array
    )

    for bad_axes in invalid_axes:
        with pytest.raises(ValueError):
            pt.transpose(placeholder, bad_axes)
Exemplo n.º 2
0
def test_axis_permutation(ctx_factory, axes):
    """Transpose a random array by *axes* and check it via the NumPy helper.

    :arg ctx_factory: callable returning the OpenCL context to run on.
    :arg axes: the axis permutation to apply; its length selects how many
        leading entries of ``(3, 4, 5)`` form the input shape.
    """
    from numpy.random import default_rng

    context = ctx_factory()
    queue = cl.CommandQueue(context)

    # One extent per permuted axis.
    input_shape = (3, 4, 5)[:len(axes)]
    host_data = default_rng().random(size=input_shape)

    ns = pt.Namespace()
    wrapped = pt.make_data_wrapper(ns, host_data)
    transposed = pt.transpose(wrapped, axes)

    # The comparison itself is delegated to the shared helper.
    assert_allclose_to_numpy(transposed, queue)
Exemplo n.º 3
0
def test_axis_permutation(ctx_factory, axes):
    """Transpose a random array by *axes* and compare with ``np.transpose``.

    :arg ctx_factory: callable returning the OpenCL context to run on.
    :arg axes: the axis permutation to apply; its length selects how many
        leading entries of ``(3, 4, 5)`` form the input shape.
    """
    from numpy.random import default_rng

    context = ctx_factory()
    queue = cl.CommandQueue(context)

    # One extent per permuted axis.
    input_shape = (3, 4, 5)[:len(axes)]
    host_data = default_rng().random(size=input_shape)

    ns = pt.Namespace()
    wrapped = pt.make_data_wrapper(ns, host_data)

    target = pt.PyOpenCLTarget(queue)
    program = pt.generate_loopy(pt.transpose(wrapped, axes), target=target)

    # The program is expected to yield exactly one output array.
    _, outputs = program()
    (actual,) = outputs

    # A permutation only moves entries around, so exact equality is required.
    expected = np.transpose(host_data, axes)
    assert (actual == expected).all()