def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
    """Check ``topi.strided_set(...) + 1`` against the Python reference.

    The expression is built once from placeholders and then compiled and
    executed on every device in the hard-coded list below; devices whose
    context does not exist are skipped with a message.

    Parameters
    ----------
    in_shape : sequence of int
        Shape used for the data placeholder and the random data input.
    v_shape : sequence of int
        Shape used for the value placeholder and the random value input.
    begin : sequence of int
        Begin indices; fed to the kernel as an int32 array.
    end : sequence of int
        End indices; fed to the kernel as an int32 array.
    strides : sequence of int, optional
        Strides; when ``None`` the four-argument form of
        ``topi.strided_set`` is built and no stride tensor is passed.
    """
    has_strides = strides is not None

    data_ph = tvm.placeholder(shape=in_shape, name="A")
    value_ph = tvm.placeholder(shape=v_shape, name="V")

    # begin / end (and optionally strides) are 1-D int32 tensors whose
    # length matches the corresponding Python sequence.
    index_phs = [
        tvm.placeholder(shape=(len(begin),), name="b", dtype='int32'),
        tvm.placeholder(shape=(len(end),), name="e", dtype='int32'),
    ]
    index_values = [begin, end]
    if has_strides:
        index_phs.append(
            tvm.placeholder(shape=(len(strides),), name="st", dtype='int32'))
        index_values.append(strides)

    result = topi.strided_set(data_ph, value_ph, *index_phs) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        with tvm.target.create(device):
            schedule = topi.generic.schedule_injective(result)

        # Argument order: data, value, begin, end, [strides], output.
        build_args = [data_ph, value_ph] + index_phs + [result]
        func = tvm.build(schedule, build_args, device, name="stride_set")

        # Data is drawn before value, so the random stream is consumed in
        # a fixed order.
        data_np = np.random.uniform(size=in_shape).astype(data_ph.dtype)
        value_np = np.random.uniform(size=v_shape).astype(value_ph.dtype)
        index_nps = [np.asarray(idx).astype('int32') for idx in index_values]

        # The reference receives the raw Python sequences (strides may be
        # None), not the int32 arrays.
        expected = topi.testing.strided_set_python(
            data_np, value_np, begin, end, strides) + 1

        call_args = [
            tvm.nd.array(host_arr, ctx)
            for host_arr in [data_np, value_np] + index_nps
        ]
        actual = tvm.nd.empty(expected.shape, ctx=ctx, dtype=data_ph.dtype)
        call_args.append(actual)
        func(*call_args)

        tvm.testing.assert_allclose(actual.asnumpy(), expected)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)
def compute_strided_set(attrs, inputs, output_type):
    """Build the compute for strided_set by delegating to ``topi.strided_set``.

    Parameters
    ----------
    attrs
        Not used by this function.
    inputs
        Indexable collection; elements 0 through 4 are forwarded, in
        order, as the five positional arguments of ``topi.strided_set``.
    output_type
        Not used by this function.

    Returns
    -------
    list
        A one-element list holding the result of ``topi.strided_set``.
    """
    # Explicit indexing (rather than unpacking) keeps this working when
    # `inputs` carries more than five entries.
    first = inputs[0]
    second = inputs[1]
    third = inputs[2]
    fourth = inputs[3]
    fifth = inputs[4]
    computed = topi.strided_set(first, second, third, fourth, fifth)
    return [computed]