def test_roll(ctx_factory, shift, axis):
    """Check ``pt.roll`` on a 3x3 float array via ``assert_allclose_to_numpy``.

    :arg ctx_factory: fixture; calling it yields the context passed to
        ``cl.CommandQueue``.
    :arg shift: the ``shift`` forwarded to ``pt.roll``.
    :arg axis: the ``axis`` forwarded to ``pt.roll``.
    """
    # NOTE(review): this chunk also contains a later ``def test_roll``; if both
    # live in the same module, the later one replaces this one and this test
    # is never collected (flake8 F811) -- confirm and rename one of them.
    cl_ctx = ctx_factory()
    queue = cl.CommandQueue(cl_ctx)

    namespace = pt.Namespace()
    pt.make_size_param(namespace, "n")
    # ``np.float`` was only an alias of the builtin ``float`` and was removed
    # in NumPy 1.24; ``np.float64`` is the dtype that alias resolved to.
    x = pt.make_placeholder(namespace, name="x", shape=("n", "n"),
                            dtype=np.float64)

    # 3x3 input, so the size parameter "n" is bound to 3 at evaluation time.
    x_in = np.arange(1., 10.).reshape(3, 3)
    assert_allclose_to_numpy(pt.roll(x, shift=shift, axis=axis), queue,
                             {x: x_in})
def test_roll_input_validation():
    """Ensure ``pt.roll`` accepts axis 0 but rejects axes 2 and -1.

    The placeholder is 10x10 (rank 2), so axis 2 is out of range; the test
    also expects a negative axis to be rejected with ``ValueError``.
    """
    arr = pt.make_placeholder(name="a", shape=(10, 10), dtype=np.float64)

    # In-range axis: must not raise.
    pt.roll(arr, 1, axis=0)

    # Each invalid axis is checked in the same order as before: 2, then -1.
    for bad_axis in (2, -1):
        with pytest.raises(ValueError):
            pt.roll(arr, 1, axis=bad_axis)
def test_roll(ctx_factory, shift, axis):
    """Compare the loopy kernel generated for ``pt.roll`` with ``np.roll``.

    :arg ctx_factory: fixture; calling it yields the context passed to
        ``cl.CommandQueue``.
    :arg shift: the ``shift`` forwarded to both ``pt.roll`` and ``np.roll``.
    :arg axis: the ``axis`` forwarded to both ``pt.roll`` and ``np.roll``.
    """
    cl_ctx = ctx_factory()
    queue = cl.CommandQueue(cl_ctx)

    namespace = pt.Namespace()
    pt.make_size_param(namespace, "n")
    # ``np.float`` was only an alias of the builtin ``float`` and was removed
    # in NumPy 1.24; ``np.float64`` is the dtype that alias resolved to.
    x = pt.make_placeholder(namespace, name="x", shape=("n", "n"),
                            dtype=np.float64)

    prog = pt.generate_loopy(
            pt.roll(x, shift=shift, axis=axis),
            target=pt.PyOpenCLTarget(queue))

    x_in = np.arange(1., 10.).reshape(3, 3)
    _, (x_out,) = prog(x=x_in)

    # Unlike ``(x_out == ref).all()``, this also fails on a shape mismatch
    # (rather than silently broadcasting) and reports the differing entries.
    np.testing.assert_array_equal(
            x_out, np.roll(x_in, shift=shift, axis=axis))
def face_swap(self, vec):
    """Return a two-column array built from the swapped, rolled columns of *vec*.

    The new column 0 is ``vec[:, 1]`` rolled by +1, and the new column 1 is
    ``vec[:, 0]`` rolled by -1; both are stacked along axis 1. *vec* is
    indexed as ``vec[:, 0]`` / ``vec[:, 1]``, so it must have at least two
    entries along its second axis.
    """
    # Built in the same order as the stacked tuple: column 1 first.
    rolled_second = pt.roll(vec[:, 1], +1)
    rolled_first = pt.roll(vec[:, 0], -1)
    return pt.stack((rolled_second, rolled_first), axis=1)
def test_array_dot_repr():
    """Check the ``repr`` of several array expressions, ignoring layout.

    Spaces and newlines are removed from both the actual ``repr`` and the
    expected text before comparing, so the expected strings below may be
    wrapped and indented freely.
    """
    x = pt.make_placeholder("x", (10, 4), np.int64)
    y = pt.make_placeholder("y", (10, 4), np.int64)

    def _strip_layout(s: str) -> str:
        # Drop exactly the characters the comparison ignores: " " and "\n".
        return s.replace(" ", "").replace("\n", "")

    def _assert_stripped_repr(ary: pt.Array, expected_repr: str):
        # Names match their contents so pytest's assertion introspection
        # labels the actual and expected sides correctly on failure.
        actual_str = _strip_layout(repr(ary))
        expected_str = _strip_layout(expected_repr)
        assert actual_str == expected_str

    _assert_stripped_repr(
        3*x + 4*y,
        """
        IndexLambda(
            expr=Sum((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))),
                      Subscript(Variable('_in1'), (Variable('_0'), Variable('_1'))))),
            shape=(10, 4),
            dtype='int64',
            bindings={'_in0': IndexLambda(expr=Product((3, Subscript(Variable('_in1'), (Variable('_0'), Variable('_1'))))),
                                          shape=(10, 4),
                                          dtype='int64',
                                          bindings={'_in1': Placeholder(shape=(10, 4), dtype='int64', name='x')}),
                      '_in1': IndexLambda(expr=Product((4, Subscript(Variable('_in1'), (Variable('_0'), Variable('_1'))))),
                                          shape=(10, 4),
                                          dtype='int64',
                                          bindings={'_in1': Placeholder(shape=(10, 4), dtype='int64', name='y')})})""")

    _assert_stripped_repr(
        pt.roll(x.reshape(2, 20).reshape(-1), 3),
        """
        Roll(
            array=Reshape(array=Reshape(array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
                                        newshape=(2, 20),
                                        order='C'),
                          newshape=(40),
                          order='C'),
            shift=3, axis=0)""")

    _assert_stripped_repr(
        y * pt.not_equal(x, 3),
        """
        IndexLambda(
            expr=Product((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))),
                          Subscript(Variable('_in1'), (Variable('_0'), Variable('_1'))))),
            shape=(10, 4),
            dtype='int64',
            bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'),
                      '_in1': IndexLambda(
                          expr=Comparison(Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))), '!=', 3),
                          shape=(10, 4),
                          dtype=<class 'numpy.bool_'>,
                          bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='x')})})""")

    _assert_stripped_repr(
        x[y[:, 2:3], x[2, :]],
        """
        AdvancedIndexInContiguousAxes(
            array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
            indices=(BasicIndex(array=Placeholder(shape=(10, 4), dtype='int64', name='y'),
                                indices=(NormalizedSlice(start=0, stop=10, step=1),
                                         NormalizedSlice(start=2, stop=3, step=1))),
                     BasicIndex(array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
                                indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""")

    # Shared sub-expressions that were already printed are shown as "(...)".
    _assert_stripped_repr(
        pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]),
        """
        Stack(
            arrays=(
                AxisPermutation(
                    array=AdvancedIndexInContiguousAxes(
                        array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
                        indices=(BasicIndex(array=(...),
                                            indices=(NormalizedSlice(start=0, stop=10, step=1),
                                                     NormalizedSlice(start=2, stop=3, step=1))),
                                 BasicIndex(array=(...),
                                            indices=(2, NormalizedSlice(start=0, stop=4, step=1))))),
                    axis_permutation=(1, 0)),
                AxisPermutation(array=AdvancedIndexInContiguousAxes(
                        array=Placeholder(shape=(10, 4), dtype='int64', name='y'),
                        indices=(BasicIndex(array=(...),
                                            indices=(NormalizedSlice(start=0, stop=10, step=1),
                                                     NormalizedSlice(start=2, stop=3, step=1))),
                                 BasicIndex(array=(...),
                                            indices=(2, NormalizedSlice(start=0, stop=4, step=1))))),
                    axis_permutation=(1, 0))),
            axis=0)
        """)