def check_device(device):
    """Compile an int8 conv2d followed by an elementwise where/add tail.

    The graph built is ``conv1 + where(conv1 >= 0, 1, 2)``, where ``conv1``
    is an int8 x int8 -> int32 ``topi.nn.conv2d``. It is scheduled with
    ``topi.generic.schedule_conv2d_nchw`` on the final ``add`` tensor. Only
    compilation is exercised: the built module is neither run nor compared
    against a reference.

    Args:
        device: target string. It is used for the target scope, the runtime
            context lookup, and as the ``tvm.build`` target.

    Returns:
        None. Prints a skip message and returns early when
        ``tvm.context(device, 0)`` does not exist.
    """
    with tvm.target.create(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
        w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
        # Positional args are presumably (strides=1, padding=0, dilation=1).
        # With the default NCHW layout that yields an output of shape
        # (2, 3, 1, 3): H = 2 - 2 + 1, W = 4 - 2 + 1. TODO confirm signature.
        conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
        # The constant operands below must match conv1's output shape.
        out_shape = (2, 3, 1, 3)
        zeros = topi.full(out_shape, 'int32', tvm.const(0, dtype='int32'))
        gt = topi.greater_equal(conv1, zeros)
        one = topi.full(out_shape, 'int32', tvm.const(1, dtype='int32'))
        two = topi.full(out_shape, 'int32', tvm.const(2, dtype='int32'))
        where = topi.where(gt, one, two)
        add = topi.add(conv1, where)
        outs = [add]
        s = topi.generic.schedule_conv2d_nchw(outs)
        # Build for ``device`` itself, so the compile target always matches
        # the target scope and context checked above.
        tvm.build(s, [data, w, add], target=device)
def greater_equal_compute(attrs, inputs, output_type, target):
    """Compute rule for ``greater_equal``.

    The signature matches a Relay-style compute callback (presumably
    registered as the op's FTVMCompute; verify at the registration site).
    ``attrs``, ``output_type`` and ``target`` are accepted but unused.

    Args:
        inputs: exactly two tensors, the left- and right-hand operands.

    Returns:
        A one-element list holding ``topi.greater_equal(lhs, rhs)``.
    """
    assert len(inputs) == 2
    lhs = inputs[0]
    rhs = inputs[1]
    result = topi.greater_equal(lhs, rhs)
    return [result]
def greater_equal(x, y):
    """Return ``x >= y`` elementwise, cast to ``int8``.

    Wraps ``topi.greater_equal`` and converts its result (presumably a
    boolean tensor) with ``.astype("int8")``.
    """
    mask = topi.greater_equal(x, y)
    return mask.astype("int8")