def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    """Build ``topi.take`` and check its result against ``np.take``.

    Parameters
    ----------
    src_shape : sequence of int
        Shape of the float32 source tensor. The test data is
        ``np.arange`` over the product of these dimensions, reshaped.
    indices_src : array-like
        Indices to gather; converted to an int32 NumPy array.
    axis : int, optional
        Axis forwarded to both ``topi.take`` and ``np.take``. When ``None``
        the argument is left out of both calls.
    mode : str
        Mode forwarded to ``topi.take``. For the NumPy reference call
        ``"fast"`` is replaced by ``"raise"``.
    """
    value_dtype = "float32"
    index_dtype = "int32"
    indices_src = np.array(indices_src, dtype=index_dtype)

    data_ph = tvm.placeholder(shape=src_shape, dtype=value_dtype, name="A")
    index_ph = tvm.placeholder(
        shape=indices_src.shape, dtype=index_dtype, name="indices"
    )

    # Only pass ``axis`` when the caller supplied one.
    axis_kwargs = {} if axis is None else {"axis": axis}
    out_tensor = topi.take(a=data_ph, indices=index_ph, mode=mode, **axis_kwargs)

    # np.take is never called with "fast"; "raise" is used in its place.
    np_mode = "raise" if mode == "fast" else mode

    def run_on(target):
        """Compile and run the take kernel on ``target``, then compare."""
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)

        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(out_tensor)
        kernel = tvm.build(
            sched, [data_ph, index_ph, out_tensor], target, name="take"
        )

        # Source values are 0..N-1 so each output element identifies the
        # flat position it was gathered from.
        element_count = 1
        for extent in src_shape:
            element_count *= extent
        data_npy = np.arange(element_count, dtype=value_dtype).reshape(src_shape)

        expected = np.take(data_npy, indices_src, mode=np_mode, **axis_kwargs)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=value_dtype)
        kernel(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), expected)

    for target in ("llvm", "opencl", "sdaccel", "aocl_sw_emu"):
        run_on(target)
Exemple #2
0
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    """Check ``topi.take`` against ``np.take`` on every available device.

    Parameters
    ----------
    src_shape : sequence of int
        Shape of the float32 source tensor. The test data is
        ``np.arange`` over the product of these dimensions, reshaped.
    indices_src : array-like
        Indices to gather; converted to an int32 NumPy array.
    axis : int, optional
        Axis forwarded to both ``topi.take`` and ``np.take``. When ``None``
        the argument is left out of both calls.
    mode : str
        Mode forwarded unchanged to ``topi.take``. ``np.take`` only accepts
        ``"raise"``, ``"wrap"`` and ``"clip"``, so ``"fast"`` is translated
        to ``"raise"`` for the reference computation.

    Raises
    ------
    AssertionError
        If the device output differs from the NumPy reference.
    """
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")

    # Only pass ``axis`` when the caller supplied one.
    axis_kwargs = {} if axis is None else {"axis": axis}
    out_tensor = topi.take(a=A, indices=indices, mode=mode, **axis_kwargs)

    # np.take rejects unknown modes with ValueError; without this mapping a
    # call with mode="fast" fails in the reference computation itself.
    # NOTE(review): presumably "fast" skips bounds handling, so callers are
    # expected to pass in-range indices with it -- confirm against topi.take.
    np_mode = "raise" if mode == "fast" else mode

    def check_device(device):
        """Build, run and verify the take kernel on a single device."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        foo = tvm.build(s, [A, indices, out_tensor], device, name="take")

        # Source values are 0..N-1 so each output element identifies the
        # flat position it was gathered from.
        shape_size = 1
        for dim in src_shape:
            shape_size *= dim
        data_npy = np.arange(shape_size, dtype=src_dtype).reshape(src_shape)

        out_npys = np.take(data_npy, indices_src, mode=np_mode, **axis_kwargs)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        foo(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)