Beispiel #1
0
def test_normal(ctx, shape, axis):
    """Check that Reduce sums an int64 array along ``axis`` the same way numpy does.

    The reference is ``numpy`` ``sum``. When the reduction collapses to a 0-d
    result (every axis reduced), it is wrapped into a one-element array so
    that a device buffer with a concrete shape can be allocated.
    """
    reduction = Reduce(ctx)

    host_input = get_test_array(shape, numpy.int64)
    device_input = ctx.to_device(host_input)

    expected = host_input.sum(axis)
    if expected.ndim == 0:
        # A full reduction yields a 0-d value; promote it to shape (1,).
        expected = numpy.array([expected], numpy.int64)
    device_output = ctx.allocate(expected.shape, numpy.int64)

    # Kernel body combining two values: plain addition.
    sum_code = dict(kernel="return input1 + input2;")
    reduction.prepare_for(device_output, device_input, axis=axis, code=sum_code)
    reduction(device_output, device_input)

    assert diff_is_negligible(device_output.get(), expected)
Beispiel #2
0
def test_normal(ctx, shape, axis):
    """Compare a device-side int64 sum along ``axis`` against numpy.

    ``Reduce`` is configured with an additive kernel. A 0-d numpy result
    (every axis reduced) is wrapped into a one-element array so the output
    buffer has a concrete shape.
    """
    reducer = Reduce(ctx)

    src = get_test_array(shape, numpy.int64)
    src_dev = ctx.to_device(src)

    summed = src.sum(axis)
    expected = (numpy.array([summed], numpy.int64)
                if len(summed.shape) == 0 else summed)
    dst_dev = ctx.allocate(expected.shape, numpy.int64)

    reducer.prepare_for(dst_dev,
                        src_dev,
                        axis=axis,
                        code=dict(kernel="return input1 + input2;"))
    reducer(dst_dev, src_dev)

    result = dst_dev.get()
    assert diff_is_negligible(result, expected)
Beispiel #3
0
def test_nontrivial_function(ctx):
    """Check that Reduce accepts a user-defined helper function in its kernel code.

    The kernel delegates combining two values to ``test()``, which is declared
    through the ``functions`` snippet. The device result must match the numpy
    sum along ``axis``.
    """
    rd = Reduce(ctx)
    shape = (100, 100)
    axis = 0
    a = get_test_array(shape, numpy.int64)
    a_dev = ctx.to_device(a)
    b_ref = a.sum(axis)
    # Derive the output buffer from the reference instead of hard-coding
    # (100,), so it stays correct if ``shape`` or ``axis`` is changed.
    b_dev = ctx.allocate(b_ref.shape, numpy.int64)

    rd.prepare_for(b_dev, a_dev, axis=axis,
        code=dict(
            kernel="return test(input1, input2);",
            functions="""
            WITHIN_KERNEL ${output.ctype} test(${input.ctype} val1, ${input.ctype} val2)
            {
                return val1 + val2;
            }
            """))

    rd(b_dev, a_dev)
    assert diff_is_negligible(b_dev.get(), b_ref)
Beispiel #4
0
def test_nontrivial_function(ctx):
    """Check that Reduce accepts a user-defined helper function in its kernel code.

    The kernel delegates combining two values to ``test()``, which is declared
    through the ``functions`` snippet. The device result must match the numpy
    sum along ``axis``.
    """
    rd = Reduce(ctx)
    shape = (100, 100)
    axis = 0
    a = get_test_array(shape, numpy.int64)
    a_dev = ctx.to_device(a)
    b_ref = a.sum(axis)
    # Size the output from the reference rather than a hard-coded (100, ),
    # so it keeps matching if ``shape`` or ``axis`` is edited.
    b_dev = ctx.allocate(b_ref.shape, numpy.int64)

    rd.prepare_for(b_dev,
                   a_dev,
                   axis=axis,
                   code=dict(kernel="return test(input1, input2);",
                             functions="""
            WITHIN_KERNEL ${output.ctype} test(${input.ctype} val1, ${input.ctype} val2)
            {
                return val1 + val2;
            }
            """))

    rd(b_dev, a_dev)
    assert diff_is_negligible(b_dev.get(), b_ref)