def verify_where(in_shape):
    """Compare ``topi.where`` against ``np.where`` on every available backend.

    The condition tensor and both value tensors share ``in_shape``. For each
    target returned by ``get_all_backend()``, build a broadcast-scheduled
    kernel, feed it random data and check the result with
    ``tvm.testing.assert_allclose``. Targets without an existing context are
    skipped with a message.
    """
    cond = tvm.placeholder(shape=in_shape, name="cond")
    dtype = cond.dtype
    lhs = tvm.placeholder(shape=in_shape, name="A")
    rhs = tvm.placeholder(shape=in_shape, name="B")
    result = topi.where(cond, lhs, rhs)

    def run_on(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)

        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(result)
        func = tvm.build(sched, [cond, lhs, rhs, result], target, name="where")

        # Condition values are drawn from [-1, 1); np.where treats any
        # nonzero entry as true. Draw order: condition, then the two branches.
        cond_np = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
        lhs_np = np.random.uniform(size=in_shape).astype(dtype)
        rhs_np = np.random.uniform(size=in_shape).astype(dtype)
        expected = np.where(cond_np, lhs_np, rhs_np)

        inputs = [tvm.nd.array(arr, ctx) for arr in (cond_np, lhs_np, rhs_np)]
        out = tvm.nd.array(np.empty(expected.shape).astype(result.dtype), ctx)
        func(*inputs, out)
        tvm.testing.assert_allclose(out.asnumpy(), expected)

    for target in get_all_backend():
        run_on(target)
def check_device(device):
    """Compile a conv2d -> greater_equal -> where -> add graph for ``device``.

    Checks only that ``topi.generic.schedule_conv2d_nchw`` can schedule the
    fused graph and that ``tvm.build`` compiles it. The built module is
    discarded; nothing is executed or compared. Skips with a message when no
    context exists for ``device``.

    Args:
        device: target string passed to ``tvm.target.create``,
            ``tvm.context`` and ``tvm.build``.
    """
    with tvm.target.create(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
        w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
        # Positional args are presumably strides=1, padding=0, dilation=1.
        # The full() tensors below assume the conv output is (2, 3, 1, 3).
        conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
        out_shape = (2, 3, 1, 3)
        zeros = topi.full(out_shape, 'int32', tvm.const(0, dtype='int32'))
        gt = topi.greater_equal(conv1, zeros)
        one = topi.full(out_shape, 'int32', tvm.const(1, dtype='int32'))
        two = topi.full(out_shape, 'int32', tvm.const(2, dtype='int32'))
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = topi.generic.schedule_conv2d_nchw(outs)
        # Build for the requested device. The original referenced an
        # undefined/outer `backend` name here instead of the parameter.
        tvm.build(s, [data, w, add], target=device)