def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    """Compare ``topi.take`` against ``np.take`` on every enabled device.

    The input data is ``np.arange`` over the total number of elements of
    ``src_shape`` (float32), reshaped to ``src_shape``; the indices are
    converted to an int32 numpy array.

    Parameters
    ----------
    src_shape : sequence of int
        Shape of the source tensor.
    indices_src : array-like
        Indices to gather; converted with ``np.array(..., dtype="int32")``.
    axis : int, optional
        Axis forwarded to both ``topi.take`` and ``np.take``. When ``None``
        the ``axis`` argument is omitted from both calls.
    mode : str
        Mode forwarded to ``topi.take``. The numpy reference uses the same
        mode, except that "fast" is checked against numpy's "raise".
    """
    data_dtype = "float32"
    index_dtype = "int32"
    indices_src = np.array(indices_src, dtype=index_dtype)

    A = tvm.placeholder(shape=src_shape, dtype=data_dtype, name="A")
    indices = tvm.placeholder(
        shape=indices_src.shape, dtype=index_dtype, name="indices"
    )

    # Only forward `axis` when the caller supplied one.
    axis_kwargs = {} if axis is None else {"axis": axis}
    out_tensor = topi.take(a=A, indices=indices, mode=mode, **axis_kwargs)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        with tvm.target.create(device):
            schedule = topi.generic.schedule_injective(out_tensor)
        take_fn = tvm.build(
            schedule, [A, indices, out_tensor], device, name="take"
        )

        # Total element count of the source tensor.
        num_elements = 1
        for dim in src_shape:
            num_elements = num_elements * dim
        data_npy = np.arange(num_elements, dtype=data_dtype).reshape(src_shape)

        # numpy is called with "raise" when the topi mode is "fast".
        np_mode = "raise" if mode == "fast" else mode
        expected = np.take(data_npy, indices_src, mode=np_mode, **axis_kwargs)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=data_dtype)
        take_fn(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), expected)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    """Check ``topi.take`` against ``np.take`` on every enabled device.

    The source data is ``np.arange`` over the product of ``src_shape``
    (float32), reshaped to ``src_shape``. The indices are converted to an
    int32 numpy array.

    Parameters
    ----------
    src_shape : sequence of int
        Shape of the source tensor.
    indices_src : array-like
        Indices to gather; converted with ``np.array(..., dtype="int32")``.
    axis : int, optional
        Axis forwarded to ``topi.take`` and ``np.take``. When ``None`` the
        ``axis`` argument is omitted from both calls.
    mode : str
        Mode forwarded to ``topi.take``. ``np.take`` only accepts "raise",
        "wrap" and "clip", so "fast" is checked against numpy's "raise".

    Raises
    ------
    AssertionError
        If the device result differs from the numpy reference.
    """
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)

    A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = tvm.placeholder(
        shape=indices_src.shape, dtype=indices_dtype, name="indices"
    )
    if axis is None:
        out_tensor = topi.take(a=A, indices=indices, mode=mode)
    else:
        out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)

    # numpy has no "fast" mode and would raise ValueError on it; use "raise"
    # for the reference instead.
    # NOTE(review): presumably "fast" inputs are always in bounds - confirm.
    np_mode = "raise" if mode == "fast" else mode

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        foo = tvm.build(s, [A, indices, out_tensor], device, name="take")

        # Total number of elements in the source tensor.
        shape_size = 1
        for dim in src_shape:
            shape_size *= dim
        data_npy = np.arange(shape_size, dtype=src_dtype).reshape(src_shape)

        if axis is None:
            out_npys = np.take(data_npy, indices_src, mode=np_mode)
        else:
            out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        foo(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)