def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
    """Check ``topi.strided_set`` against its NumPy reference on several targets.

    The expression under test is ``strided_set(A, V, b, e[, st]) + 1``. The
    slice bounds ``b``/``e`` (and optionally ``st``) are int32 tensors
    supplied at run time. For every target in a fixed list that
    ``tvm.testing.device_enabled`` reports as enabled, the expression is
    built and run on random inputs. The result is then compared with
    ``tvm.topi.testing.strided_set_python(...) + 1``.

    Parameters
    ----------
    in_shape : shape of the tensor being written into (placeholder ``A``).
    v_shape : shape of the values written into the slice (placeholder ``V``).
    begin, end : per-axis slice bounds; their lengths size the ``b``/``e``
        placeholders.
    strides : optional per-axis slice steps. When ``None``, no ``st``
        placeholder is created and ``strided_set`` is called without it.
    """
    with_strides = strides is not None

    data = te.placeholder(shape=in_shape, name="A")
    values = te.placeholder(shape=v_shape, name="V")
    # Bounds (and strides) are runtime tensors, not compile-time constants.
    index_tensors = [
        te.placeholder(shape=(len(begin),), name="b", dtype="int32"),
        te.placeholder(shape=(len(end),), name="e", dtype="int32"),
    ]
    if with_strides:
        index_tensors.append(
            te.placeholder(shape=(len(strides),), name="st", dtype="int32")
        )
    result = topi.strided_set(data, values, *index_tensors) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            sched = tvm.topi.testing.get_injective_schedule(device)(result)

        # Kernel argument order: data, values, b, e, [st], output.
        kernel = tvm.build(
            sched, [data, values, *index_tensors, result], device, name="stride_set"
        )

        # Keep the draw order (x first, then v) so the random stream is unchanged.
        x_np = np.random.uniform(size=in_shape).astype(data.dtype)
        v_np = np.random.uniform(size=v_shape).astype(values.dtype)
        expected = (
            tvm.topi.testing.strided_set_python(x_np, v_np, begin, end, strides) + 1
        )

        index_values = [begin, end]
        if with_strides:
            index_values.append(strides)
        index_nds = [
            tvm.nd.array(np.asarray(vals).astype("int32"), ctx) for vals in index_values
        ]
        out_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=data.dtype)

        kernel(tvm.nd.array(x_np, ctx), tvm.nd.array(v_np, ctx), *index_nds, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), expected)

    for target_name in ("llvm", "opencl", "sdaccel", "aocl_sw_emu"):
        check_device(target_name)
def compute_strided_set(attrs, inputs, output_type):
    """Relay compute rule for ``strided_set``.

    Forwards the first five entries of ``inputs`` to ``topi.strided_set``, in
    order: data, value, begin, end, strides. ``attrs`` and ``output_type`` are
    accepted for the compute-function signature but not read.

    Returns
    -------
    A one-element list containing the ``topi.strided_set`` result.
    """
    # Indexing (rather than slicing) keeps the original IndexError on short input.
    data, value, begin, end, strides = (inputs[idx] for idx in range(5))
    output = topi.strided_set(data, value, begin, end, strides)
    return [output]