def verify_where(in_shape):
    """Check ``topi.where`` against ``np.where`` for inputs of shape ``in_shape``.

    A condition placeholder and two branch placeholders, all of shape
    ``in_shape``, are combined with ``topi.where``. For every
    ``(target, ctx)`` pair yielded by ``tvm.testing.enabled_targets()`` the
    result is scheduled, built, run on random inputs and compared with the
    NumPy reference.
    """
    cond_ph = te.placeholder(shape=in_shape, name="cond")
    in_dtype = cond_ph.dtype
    lhs_ph = te.placeholder(shape=in_shape, name="A")
    rhs_ph = te.placeholder(shape=in_shape, name="B")
    selected = topi.where(cond_ph, lhs_ph, rhs_ph)

    for target, target_ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % target)

        # Pick the broadcast schedule registered for this target.
        with tvm.target.Target(target):
            make_schedule = tvm.topi.testing.get_broadcast_schedule(target)
            sched = make_schedule(selected)
        where_fn = tvm.build(
            sched, [cond_ph, lhs_ph, rhs_ph, selected], target, name="where"
        )

        # Condition values are drawn from [-1, 1); the branches use
        # uniform's default range. Draw order: condition, then the two branches.
        cond_np = np.random.uniform(-1, 1, in_shape).astype(in_dtype)
        x_np, y_np = (
            np.random.uniform(size=in_shape).astype(in_dtype) for _ in range(2)
        )
        expected = np.where(cond_np, x_np, y_np)

        device_inputs = [tvm.nd.array(arr, target_ctx) for arr in (cond_np, x_np, y_np)]
        out_buf = np.empty(expected.shape).astype(selected.dtype)
        device_out = tvm.nd.array(out_buf, target_ctx)

        where_fn(*device_inputs, device_out)
        tvm.testing.assert_allclose(device_out.asnumpy(), expected)
def check_device(device, ctx):
    """Build a conv2d -> greater_equal -> where -> add graph for ``device``.

    The function only checks that scheduling and ``tvm.build`` succeed; the
    built module is discarded and nothing is executed.

    Parameters
    ----------
    device : str
        Target name, used for the target scope, for looking up the conv2d
        implementation, and as the build target.
    ctx
        Unused here; kept so the signature matches its callers.
    """
    out_dtype = "int32"
    # Shape shared by the constant tensors compared against / added to the
    # conv2d result (presumably the conv2d output shape -- confirm).
    const_shape = (2, 3, 1, 3)

    def _const(value):
        # Constant-filled int32 tensor of ``const_shape``.
        return topi.full(const_shape, out_dtype, tvm.tir.const(value, dtype=out_dtype))

    with tvm.target.Target(device):
        print("Running on target: %s" % device)
        conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), "int8", "data")
        w = te.placeholder((3, 1, 2, 2), "int8", "w")
        # Positional 1, 0, 1 are presumably stride, padding, dilation -- confirm
        # against the conv2d compute signature.
        conv1 = conv2d_compute(data, w, 1, 0, 1, out_dtype)
        zeros = _const(0)
        gt = topi.greater_equal(conv1, zeros)
        one = _const(1)
        two = _const(2)
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = conv2d_schedule(outs)
        # Build for the target passed in, rather than relying on a free
        # ``backend`` name from an enclosing scope.
        tvm.build(s, [data, w, add], target=device)
def check_device(device):
    """Build a conv2d -> greater_equal -> where -> add graph for ``device``.

    Returns early (after printing a message) when the context for
    ``device`` reports that it does not exist. Otherwise the function only
    checks that scheduling and ``tvm.build`` succeed; nothing is executed.

    Parameters
    ----------
    device : str
        Target name, used for the target scope, the context lookup, the
        conv2d implementation lookup, and as the build target.
    """
    out_dtype = 'int32'
    # Shape shared by the constant tensors compared against / added to the
    # conv2d result (presumably the conv2d output shape -- confirm).
    const_shape = (2, 3, 1, 3)

    def _const(value):
        # Constant-filled int32 tensor of ``const_shape``.
        return topi.full(const_shape, out_dtype, tvm.tir.const(value, dtype=out_dtype))

    with tvm.target.create(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
        data = te.placeholder((2, 1, 2, 4), 'int8', 'data')
        w = te.placeholder((3, 1, 2, 2), 'int8', 'w')
        # Positional 1, 0, 1 are presumably stride, padding, dilation -- confirm
        # against the conv2d compute signature.
        conv1 = conv2d_compute(data, w, 1, 0, 1, out_dtype)
        zeros = _const(0)
        gt = topi.greater_equal(conv1, zeros)
        one = _const(1)
        two = _const(2)
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = conv2d_schedule(outs)
        # Build for the target passed in, rather than relying on a free
        # ``backend`` name from an enclosing scope.
        tvm.build(s, [data, w, add], target=device)