def test_transpose_input_validation():
    """Exercise the argument checks of ``pt.transpose`` on a 2-D placeholder.

    Calling it without an explicit permutation must succeed, whereas a
    permutation of the wrong length or one that repeats an axis must raise
    ``ValueError``.
    """
    placeholder = pt.make_placeholder(
            name="a", shape=(10, 10), dtype=np.float64)

    # No explicit axes: expected to be accepted (the test fails if it raises).
    pt.transpose(placeholder)

    invalid_permutations = [
        (2, 0, 1),  # one axis too many for a 2-D array
        (1, 1),     # repeated axis
        (0,),       # one axis too few
    ]
    for bad_axes in invalid_permutations:
        with pytest.raises(ValueError):
            pt.transpose(placeholder, bad_axes)
def test_axis_permutation(ctx_factory, axes):
    """Compare ``pt.transpose`` of wrapped random data against NumPy.

    The input array takes the first ``len(axes)`` extents of ``(3, 4, 5)``;
    the comparison itself is delegated to ``assert_allclose_to_numpy``.

    NOTE(review): this module defines ``test_axis_permutation`` a second time
    further down; that later ``def`` rebinds the name at import time, so
    pytest only collects it and this body never runs. Consider renaming or
    removing one of the two definitions.
    """
    queue = cl.CommandQueue(ctx_factory())

    from numpy.random import default_rng
    data = default_rng().random(size=(3, 4, 5)[:len(axes)])

    wrapped = pt.make_data_wrapper(pt.Namespace(), data)
    assert_allclose_to_numpy(pt.transpose(wrapped, axes), queue)
def test_axis_permutation(ctx_factory, axes):
    """Check ``pt.transpose`` of wrapped random data against ``np.transpose``.

    The input array takes the first ``len(axes)`` extents of ``(3, 4, 5)``.
    The transposed expression is compiled with ``pt.generate_loopy`` for a
    PyOpenCL target, executed, and its single output is compared with the
    NumPy reference.

    NOTE(review): this module defines ``test_axis_permutation`` more than
    once; this later definition rebinds the name, so it is the only one
    pytest collects. Consider renaming or removing the duplicate.
    """
    cl_ctx = ctx_factory()
    queue = cl.CommandQueue(cl_ctx)

    ndim = len(axes)
    shape = (3, 4, 5)[:ndim]

    from numpy.random import default_rng
    rng = default_rng()
    x_in = rng.random(size=shape)

    namespace = pt.Namespace()
    x = pt.make_data_wrapper(namespace, x_in)
    prog = pt.generate_loopy(
            pt.transpose(x, axes), target=pt.PyOpenCLTarget(queue))
    _, (x_out,) = prog()

    # Transposition only moves data, so exact equality is expected. Unlike
    # ``(x_out == ref).all()``, this also fails on a shape mismatch (which
    # broadcasting could otherwise hide) and reports the differing entries.
    np.testing.assert_array_equal(x_out, np.transpose(x_in, axes))