Esempio n. 1
0
def verify_repeat(in_shape, repeats, axis):
    """Check ``topi.repeat`` against ``np.repeat`` on every available backend.

    Declares a placeholder of shape ``in_shape``, builds ``topi.repeat`` on it,
    compiles the result for each target returned by ``get_all_backend()``,
    runs it on uniformly random input and compares the output with
    ``np.repeat``. Targets whose device context does not exist are skipped
    with a printed message.

    Parameters
    ----------
    in_shape : shape of the input placeholder (also used as the numpy
        ``size`` for the random test data).
    repeats : repetition count forwarded unchanged to both ``topi.repeat``
        and ``np.repeat`` (presumably an int -- confirm against topi).
    axis : axis along which elements are repeated, forwarded to both.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.repeat(A, repeats, axis)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        # Scheduling must happen inside the target scope so that the
        # target-specific broadcast schedule is selected.
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        foo = tvm.build(s, [A, B], device, name="repeat")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.repeat(data_npy, repeats, axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        # Allocate the output buffer directly in the kernel's dtype instead of
        # creating a float64 array and then casting it (an extra full copy).
        out_nd = tvm.nd.array(np.empty(out_npy.shape, dtype=B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in get_all_backend():
        check_device(device)
Esempio n. 2
0
def verify_repeat(in_shape, repeats, axis):
    """Check ``topi.repeat`` against ``np.repeat`` on every available backend.

    Declares a placeholder of shape ``in_shape``, builds ``topi.repeat`` on it,
    compiles the result for each target returned by ``get_all_backend()``,
    runs it on uniformly random input and compares the output with
    ``np.repeat``. Targets whose device context does not exist are skipped
    with a printed message.

    Parameters
    ----------
    in_shape : shape of the input placeholder (also used as the numpy
        ``size`` for the random test data).
    repeats : repetition count forwarded unchanged to both ``topi.repeat``
        and ``np.repeat`` (presumably an int -- confirm against topi).
    axis : axis along which elements are repeated, forwarded to both.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.repeat(A, repeats, axis)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        # Scheduling must happen inside the target scope so that the
        # target-specific broadcast schedule is selected.
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        foo = tvm.build(s, [A, B], device, name="repeat")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.repeat(data_npy, repeats, axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        # Allocate the output buffer directly in the kernel's dtype instead of
        # creating a float64 array and then casting it (an extra full copy).
        out_nd = tvm.nd.array(np.empty(out_npy.shape, dtype=B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in get_all_backend():
        check_device(device)