def test_min_max_bound():
    """Check const_int_bound propagation through tvm.min and tvm.max."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    # Both operands finitely bounded.
    analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(tvm.min(x, y))
    assert (bd.min_value, bd.max_value) == (-9, 10)

    def rebound(x_lo, x_hi):
        # Replace x's range; y is always reset to [4, 10].
        analyzer.update(x, tvm.arith.ConstIntBound(x_lo, x_hi), override=True)
        analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)

    # x unbounded in both directions.
    rebound(bd.NEG_INF, bd.POS_INF)
    bd = analyzer.const_int_bound(tvm.min(x, y))
    assert (bd.min_value, bd.max_value) == (bd.NEG_INF, 10)
    bd = analyzer.const_int_bound(tvm.max(x, y))
    assert (bd.min_value, bd.max_value) == (4, bd.POS_INF)

    # x bounded only from below.
    rebound(1, bd.POS_INF)
    bd = analyzer.const_int_bound(tvm.max(x, y))
    assert (bd.min_value, bd.max_value) == (4, bd.POS_INF)
def _im2col_compute(i, j, k, data):
    """Build the ``load3d`` call that produces im2col element (i, j, k).

    NOTE(review): this is a closure. block_size, wo, stride_h/stride_w,
    pad_t/pad_l/pad_r, kernel_h/kernel_w, h, w, c0, dilation_h/dilation_w,
    jump_offset, repeat_mode, repeat_time, csize and load3d all come from
    the enclosing scope, which is not visible here. From the index
    arithmetic, ``i`` indexes the first axis of ``data``, ``j`` selects a
    block of ``block_size`` output pixels and ``k`` encodes
    (c1, kernel row, kernel col). Confirm against the caller.
    """
    # Input row/col of the top-left tap for output block j; negative when the
    # window starts inside the top/left padding.
    j_h = (((j*block_size) // wo)*stride_h)-pad_t
    j_w = (((j*block_size) % wo)*stride_w)-pad_l
    # num rows in l1 for fmatrix is discounted by the amount of bottom padding
    h_3d = kernel_h - tvm.max(((j_h+kernel_h) - h), 0)
    # Kernel rows that fall above row 0 / below row h-1 of the input.
    pad_t_3d = tvm.max(-j_h, 0)
    pad_b_3d = tvm.max(((j_h+kernel_h) - h), 0)
    # Decompose k into kernel column, kernel row and c1 index.
    w_idx_kernel = (k % kernel_w)
    h_idx_kernel = ((k // kernel_w) % kernel_h)
    w_idx = j_w
    # when this is < 0, the slice will start from row 0 so there is no redundancy between base address and this param
    h_idx = tvm.min(j_h, 0)
    c1_idx = (k // kernel_w) // kernel_h
    # Slice only the in-bounds input rows covered by the kernel window.
    load3d_input = data[i, c1_idx,
                        # assume padding < kernel size
                        tvm.max(0, j_h):tvm.min(h, j_h+kernel_h),
                        0:w, 0:c0]
    return load3d(load3d_input, w, h_3d, pad_l, pad_r, pad_t_3d, pad_b_3d, w_idx_kernel, h_idx_kernel, w_idx, h_idx, 0, stride_w, stride_h, kernel_w, kernel_h, dilation_w, dilation_h, jump_offset, repeat_mode, repeat_time, csize)
def test_bound_simplification_failure():
    """InferBound must not widen A's extent past 2 for hard-to-simplify indices."""
    # Check that the bounds are not expanded
    A = tvm.compute((2, ), lambda j: j, "A")

    def _check(B, A=A):
        sched = tvm.create_schedule(B.op).normalize()
        bounds = tvm.schedule.InferBound(sched)
        stmt = tvm.lower(sched, [B, A], simple_mode=True)
        extent = bounds[A.op.axis[0]].extent.value
        if extent > 2:
            # Dump the lowered IR to help diagnose the failure.
            print(stmt)
        assert extent <= 2

    def _check_index(index_of):
        # Consumer of length 10 reading A at index_of(i).
        _check(tvm.compute((10, ), lambda i: A[index_of(i)]))

    # These are hard to simplify, moreover we don't simplify them
    _check_index(lambda i: tvm.min(3 * i, 4 * i) + tvm.min(-3 * i, -2 * i))
    _check_index(lambda i: tvm.min(3 * i, 4 * i) + tvm.max(-3 * i, -4 * i))
    _check_index(lambda i: -2 * (i / 2) - tvm.min(i, 0 - i))
    _check_index(lambda i: i + (0 - i))
    # This would cause out of bounds, but we nevertheless include it
    _check_index(lambda i: i)
def get_valid_counts_scan(data, partial_in, partial):
    """Low level IR to do an inclusive scan (Kogge-Stone) over ``partial_in``.

    Parameters
    ----------
    data : Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
        Only its shape is read here.

    partial_in : Buffer
        2D Buffer of valid data indices with shape [batch_size, new_range]
        (accessed row-major as ``bx * new_range + tx``); input of the scan.

    partial : Buffer
        2D Buffer with shape [batch_size, new_range] that receives the
        inclusive prefix sum of ``partial_in`` along its second axis.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    partial_in = ib.buffer_ptr(partial_in)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    var = tvm.make.node("FloatImm", dtype="float32", value=2)
    new_range = num_anchors // elem_per_thread + 1
    # Step k adds the element 2**k positions back, so after m steps each slot
    # has accumulated 2**m inputs: a full scan needs ceil(log2(new_range))
    # steps. Flooring dropped the last step whenever new_range is not a power
    # of two (new_range == 3 ran a single step, so partial[2] never received
    # partial_in[0]). Rounding up is safe: surplus steps are no-ops because of
    # the ``tx >= 2**k`` guard below.
    iteration = cast(tvm.ceil(log(cast(new_range, "float32")) / math.log(2)), "int32")
    # Scan: Kogge-Stone adder
    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
        with ib.for_range(0, iteration) as k:
            with ib.if_scope(k == 0):
                # First step reads the input buffer and seeds ``partial``.
                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] = \
                        partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
                with ib.else_scope():
                    partial[bx * new_range] = partial_in[bx * new_range]
            with ib.else_scope():
                # NOTE(review): this updates ``partial`` in place. Thread tx reads
                # slot tx - 2**k, which another thread may overwrite in the same
                # step. Confirm this is safe, or double-buffer.
                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"),
                                         tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] += \
                        partial[bx * new_range + tx - cast(power(var, k), "int32")]
            # Barrier between scan steps.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
def test_max_min():
    """IntSet evaluation of min/max, with and without a range for x."""
    ck = IntSetChecker()
    x, y = tvm.var("x"), tvm.var("y")

    def x_in_0_10():
        # A fresh mapping per call, as each original call built its own literal.
        return {x: tvm.arith.IntervalSet(0, 10)}

    ck.verify(tvm.max(x, x + 1), x_in_0_10(), (1, 11))
    ck.verify(tvm.min(x - 1, x + 1), x_in_0_10(), (-1, 9))
    # With no ranges the interval collapses to the expression itself.
    for op in (tvm.min, tvm.max):
        ck.verify(op(x, y), {}, (op(x, y), op(x, y)))
def test_mul_index_simplify():
    """Rewrite-simplifier rules for products of index expressions."""
    ck = RewriteChecker()
    x, y = tvm.var("x"), tvm.var("y")
    cases = [
        ((x + 2) * 3, x * 3 + 6),
        ((x * 2) * 3, x * 6),
        (tvm.min(x, y) * tvm.max(x, y), x * y),
        (tvm.max(x, y) * tvm.min(x, y), x * y),
        ((x - y) * (-2), (y - x) * 2),
    ]
    for expr, expected in cases:
        ck.verify(expr, expected)
def test_simplify_minmax():
    """CanonicalSimplify folds ``op(x, 1) - op(x, 1)`` to 0 for max and min."""
    x = tvm.var('x')
    for op in (tvm.max, tvm.min):
        simplified = tvm.ir_pass.CanonicalSimplify(op(x, 1) - op(x, 1))
        assert simplified.value == 0
def _get_pixel(n, c, y, x, cc):
    """Read one element of ``data`` as float after clamping (y, x) into the image.

    ``data``, ``layout``, ``in_h`` and ``in_w`` come from the enclosing scope.
    """
    # Clamp coordinates to [0, in_h - 1] x [0, in_w - 1].
    y = tvm.max(tvm.min(y, in_h - 1), 0)
    x = tvm.max(tvm.min(x, in_w - 1), 0)
    if layout == 'NHWC':
        indices = (n, y, x, c)
    elif layout == 'NCHW':
        indices = (n, c, y, x)
    else:
        # else must be NCHWxc
        indices = (n, c, y, x, cc)
    return data(*indices).astype('float')
def get_valid_counts_scan(data, partial_in, partial):
    """Low level IR to do an inclusive scan (Kogge-Stone) over ``partial_in``.

    Parameters
    ----------
    data : Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
        Only its shape is read here.

    partial_in : Buffer
        2D Buffer of valid data indices with shape [batch_size, new_range]
        (accessed row-major as ``bx * new_range + tx``); input of the scan.

    partial : Buffer
        2D Buffer with shape [batch_size, new_range] that receives the
        inclusive prefix sum of ``partial_in`` along its second axis.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    partial_in = ib.buffer_ptr(partial_in)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    var = tvm.make.node("FloatImm", dtype="float32", value=2)
    new_range = num_anchors // elem_per_thread + 1
    # Step k adds the element 2**k positions back, so after m steps each slot
    # has accumulated 2**m inputs: a full scan needs ceil(log2(new_range))
    # steps. Flooring dropped the last step whenever new_range is not a power
    # of two (new_range == 3 ran a single step, so partial[2] never received
    # partial_in[0]). Rounding up is safe: surplus steps are no-ops because of
    # the ``tx >= 2**k`` guard below.
    iteration = cast(tvm.ceil(log(cast(new_range, "float32")) / math.log(2)), "int32")
    # Scan: Kogge-Stone adder
    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
        with ib.for_range(0, iteration) as k:
            with ib.if_scope(k == 0):
                # First step reads the input buffer and seeds ``partial``.
                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] = \
                        partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
                with ib.else_scope():
                    partial[bx * new_range] = partial_in[bx * new_range]
            with ib.else_scope():
                # NOTE(review): this updates ``partial`` in place. Thread tx reads
                # slot tx - 2**k, which another thread may overwrite in the same
                # step. Confirm this is safe, or double-buffer.
                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"),
                                         tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] += \
                        partial[bx * new_range + tx - cast(power(var, k), "int32")]
            # Barrier between scan steps.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
def _get_pixel(data, layout, n, c, y, x, cc):
    """Read ``data`` at (n, c, y, x[, cc]) for ``layout`` and return it as float.

    ``boxes``, ``image_height`` and ``image_width`` come from the enclosing scope.
    """
    if boxes is None:
        # Clamp coordinates into the image only when no boxes are given.
        y = tvm.max(tvm.min(y, image_height - 1), 0)
        x = tvm.max(tvm.min(x, image_width - 1), 0)
    if layout == 'NHWC':
        indices = (n, y, x, c)
    elif layout == 'NCHW':
        indices = (n, c, y, x)
    else:
        # else must be NCHWxc
        indices = (n, c, y, x, cc)
    return data(*indices).astype('float')
def test_canonical_mixed():
    """Canonical simplification of mixed div / compare / min-max / mul expressions."""
    ck = CanonicalChecker()
    x = tvm.var("x")
    z = tvm.const(3, "int32")
    cases = (
        (x / (z * z) - x / (z * z), 0),
        (x / (z + z) - x / (z + z), 0),
        (x - 2 < 3, x < 5),
        (tvm.max(x, 1) - tvm.max(x, 1), 0),
        (tvm.min(x, 1) - tvm.min(x, 1), 0),
        (x * x - x * x, 0),
    )
    for expr, expected in cases:
        ck.verify(expr, expected)
def test_cmp_simplify():
    """Rewrite-simplifier rules for comparisons (no variable bounds registered)."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # const int bound
    # x % 2 + 10 lies in a small positive range, so these fold to constants.
    ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
    ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
    ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
    ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
    ck.verify(x * 3 + 10 == 0, tvm.const(0, "bool"))
    ck.verify(x * 3 + 10 != 0, tvm.const(1, "bool"))
    # canonicalization
    ck.verify((x - 10).equal(0), x.equal(10))
    ck.verify((10 - x).equal(0), x.equal(10))
    ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0)))
    # cmp bound
    # Common terms on both sides cancel; negative factors flip the comparison.
    ck.verify(x + y < x + z, y < z)
    ck.verify(x + y < z + x, y < z)
    ck.verify(y + x < x + z, y < z)
    ck.verify(y + x < z + x, y < z)
    ck.verify(y - x < z - x, y < z)
    ck.verify(x - y < x - z, z < y)
    ck.verify(x < z + x, tvm.expr.LT(0, z))
    ck.verify(x < x + z, tvm.expr.LT(0, z))
    ck.verify(100 < x + 1, tvm.expr.LT(99, x))
    ck.verify(1 < 100 - x, tvm.expr.LT(x, 99))
    ck.verify(x * 3 < y * 3, x < y)
    ck.verify(x * (-3) < y * (-3), y < x)
    ck.verify(x * 3 >= y * 3, y <= x)
    ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
    ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
    ck.verify(x / 2 < 3, x < 6)
    ck.verify(x * 4 <= 2, x <= 0)
    ck.verify(3 < x / 2, tvm.expr.LT(7, x))
    # (x / c) * c compared against x is rewritten in terms of x % c.
    ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
    ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
    ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y))
    ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4))
    ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2))
    ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
    ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))
    # min/max against constants.
    ck.verify(tvm.min(x, 11) < 10, x < 10)
    ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
    ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
    ck.verify(x + 1 < tvm.max(8, x), x < 7)
def _bilinear(i, c, y, x):
    """Bilinearly interpolate ``data[i, c]`` at fractional position (y, x).

    ``data``, ``height`` and ``width`` come from the enclosing scope.
    """
    # Lower corner via truncating cast (floor for non-negative coordinates);
    # upper corner via ceil, clamped to the last row/column.
    y_low, x_low = y.astype('int32'), x.astype('int32')
    y_high = tvm.min(tvm.ceil(y).astype('int32'), height - 1)
    x_high = tvm.min(tvm.ceil(x).astype('int32'), width - 1)
    y_lerp, x_lerp = y - y_low, x - x_low

    def _along_row(row):
        # Horizontal interpolation between x_low and x_high on one row.
        return x_lerp * data[i, c, row, x_high] + (1 - x_lerp) * data[i, c, row, x_low]

    return y_lerp * _along_row(y_high) + (1 - y_lerp) * _along_row(y_low)
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap (intersection over union) of two boxes.

    Each box is four consecutive values starting at ``box_a_idx`` /
    ``box_b_idx``: (left, top, right, bottom). Widths and heights add
    ``1.0``, i.e. coordinates are treated as inclusive pixel indices.

    Returns 0 when the union area is non-positive (degenerate boxes) instead
    of dividing by zero.
    """
    w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
    h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
    i = w * h
    u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
        (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
        (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
        (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
    # Malformed boxes (right < left - 1, etc.) can make u zero or negative;
    # guard the division the same way the non-inclusive variant does.
    return tvm.expr.Select(u <= 0.0, 0.0, i / u)
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Return IoU of the two (left, top, right, bottom) boxes at the given offsets.

    Yields 0 when the union area is non-positive.
    """
    a_left, a_top = out_tensor[box_a_idx], out_tensor[box_a_idx + 1]
    a_right, a_bottom = out_tensor[box_a_idx + 2], out_tensor[box_a_idx + 3]
    b_left, b_top = out_tensor[box_b_idx], out_tensor[box_b_idx + 1]
    b_right, b_bottom = out_tensor[box_b_idx + 2], out_tensor[box_b_idx + 3]

    # Intersection extents, clamped at zero when the boxes do not overlap.
    inter_w = tvm.max(0.0, tvm.min(a_right, b_right) - tvm.max(a_left, b_left))
    inter_h = tvm.max(0.0, tvm.min(a_bottom, b_bottom) - tvm.max(a_top, b_top))
    inter = inter_w * inter_h

    union = (a_right - a_left) * (a_bottom - a_top) + \
        (b_right - b_left) * (b_bottom - b_top) - inter
    return tvm.expr.Select(union <= 0.0, 0.0, inter / union)
def test_basic():
    """EvalModular on sums, products, divisions and min of affine terms."""
    a = tvm.var()
    b = tvm.var()
    expectations = [
        (a * 4 + b * 6 + 7, 2, 1),
        ((a * 4 + 1) * (b * 8 + 3), 4, 3),
        ((a * 4 + 1) / (b * 8 + 3), 1, 0),
        ((a * 4 + 1) * (b * 8 / 4), 2, 0),
        ((a * 12 + 1) - (b * 3 * 7 + 2), 3, 2),
        (a * 12 + tvm.min(b * 3 * 7, 2), 1, 0),
    ]
    for expr, coeff, base in expectations:
        m = tvm.arith.EvalModular(expr)
        assert m.coeff == coeff
        assert m.base == base
def my_clip(x, a_min, a_max):
    """Clip ``x`` to [a_min, a_max] as two compute stages (min, then max).

    Unlike topi's current clip, the min and max are separate stages.
    """
    upper = tvm.const(a_max, x.dtype)
    lower = tvm.const(a_min, x.dtype)
    capped = tvm.compute(x.shape, lambda *idx: tvm.min(x(*idx), upper), name="clipA")
    return tvm.compute(capped.shape, lambda *idx: tvm.max(capped(*idx), lower), name="clipB")
def test_canonical_mixed():
    """Canonical simplification with truncdiv/floordiv, compare, min-max and mul."""
    ck = CanonicalChecker()
    x = tvm.var("x")
    z = tvm.const(3, "int32")
    tdiv = tvm.truncdiv
    fld = tvm.floordiv
    cases = [
        (tdiv(x, (z * z)) - tdiv(x, (z * z)), 0),
        (tdiv(x, (z + z)) - tdiv(x, (z + z)), 0),
        (x - 2 < 3, x < 5),
        (tvm.max(x, 1) - tvm.max(x, 1), 0),
        (tvm.min(x, 1) - tvm.min(x, 1), 0),
        (x * x - x * x, 0),
        (fld(x, (z * z)) - fld(x, (z * z)), 0),
        (fld(x, (z + z)) - fld(x, (z + z)), 0),
    ]
    for expr, expected in cases:
        ck.verify(expr, expected)
def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic):
    """ Get 2d pixel """
    if boxes is None:
        # Without boxes, clamp coordinates into [0, H-1] x [0, W-1].
        y = tvm.max(tvm.min(y, image_height - 1), 0)
        x = tvm.max(tvm.min(x, image_width - 1), 0)
    if layout == 'NHWC':
        pixel = data(n, y, x, c)
    elif layout == 'NCHW':
        pixel = data(n, c, y, x)
    elif nchw_pack_layout(layout):
        pixel = data(n, c, y, x, ib, ic)
    else:
        # else must be NCHWxc
        assert nchw_xc_layout(layout)
        pixel = data(n, c, y, x, cc)
    return pixel.astype('float')
def test_mix_index():
    """Analyzer.modular_set on sums, products, divisions and min of affine terms."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()
    expectations = [
        (a * 4 + b * 6 + 7, 2, 1),
        ((a * 4 + 1) * (b * 8 + 3), 4, 3),
        ((a * 4 + 1) / (b * 8 + 3), 1, 0),
        ((a * 4 + 1) * (b * 8 / 4), 2, 0),
        ((a * 12 + 1) - (b * 3 * 7 + 2), 3, 2),
        (a * 12 + tvm.min(b * 3 * 7, 2), 1, 0),
    ]
    for expr, coeff, base in expectations:
        m = analyzer.modular_set(expr)
        assert m.coeff == coeff
        assert m.base == base
def _run(env, remote):
    """Build, and run when ``remote`` is given, a ReLU + clip pipeline on VTA.

    ``a`` is copied into accumulator SRAM, clamped with max(., 0) and then
    min(., 2**(INP_WIDTH - 1) - 1) by the ALU, cast to ``env.inp_dtype`` and
    copied back to DRAM. The device result is compared against ``np.clip``.
    When ``remote`` is falsy, only the build step runs.
    """
    m = 8
    n = 10
    shape = (m, n, env.BATCH, env.BLOCK_OUT)
    # Upper clip bound; presumably the max of a signed INP_WIDTH-bit input.
    clip_max = (1 << (env.INP_WIDTH - 1)) - 1
    # compute
    a = tvm.placeholder(shape, name="a", dtype=env.acc_dtype)
    a_buf = tvm.compute(shape, lambda *i: a(*i), "a_buf")  # DRAM->SRAM
    # Stage names now match their Python variables. They were previously
    # shuffled: max_buf was "res_buf", min_buf "max_buf" and res "min_buf".
    max_buf = tvm.compute(shape, lambda *i: tvm.max(a_buf(*i), 0), "max_buf")  # relu
    min_buf = tvm.compute(shape, lambda *i: tvm.min(max_buf(*i), clip_max), "min_buf")  # clip to input range
    res = tvm.compute(shape, lambda *i: min_buf(*i).astype(env.inp_dtype), "res")  # SRAM->DRAM
    # schedule
    s = tvm.create_schedule(res.op)
    s[a_buf].set_scope(env.acc_scope)  # SRAM
    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
    s[max_buf].set_scope(env.acc_scope)  # SRAM
    s[min_buf].set_scope(env.acc_scope)  # SRAM
    s[max_buf].pragma(max_buf.op.axis[0], env.alu)  # compute
    s[min_buf].pragma(min_buf.op.axis[0], env.alu)  # compute
    s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM
    # build
    with vta.build_config():
        mod = vta.build(s, [a, res], "ext_dev", env.target_host)
    if not remote:
        return
    temp = util.tempdir()
    mod.save(temp.relpath("load_act.o"))
    remote.upload(temp.relpath("load_act.o"))
    f = remote.load_module("load_act.o")
    # verify
    ctx = remote.ext_dev(0)
    a_np = np.random.randint(-256, 256, size=shape).astype(a.dtype)
    res_np = np.clip(a_np, 0, clip_max).astype(res.dtype)
    a_nd = tvm.nd.array(a_np, ctx)
    res_nd = tvm.nd.array(np.zeros(shape).astype(res.dtype), ctx)
    if env.TARGET == "tsim":
        simulator.tsim_init("libvta_hw")
    f(a_nd, res_nd)
    np.testing.assert_equal(res_np, res_nd.asnumpy())
    if env.TARGET == "tsim":
        print("Relu test took {} clock cycles".format(simulator.tsim_cycles()))
def test_div_index_simplify():
    """Rewrite-simplifier rules for truncated division (tvm.truncdiv)."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tdiv = tvm.truncdiv
    tmod = tvm.truncmod  # NOTE(review): unused in this test.
    # Checked before any bounds are registered.
    ck.verify(tdiv(x, x), 1)
    # Every rewrite below is checked with x, y, z constrained to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(tdiv(tdiv(x, 2), 3), tdiv(x, 6))
    ck.verify(tdiv(tdiv(x, 2) + 1, 3), tdiv(x + 2, 6))
    ck.verify(tdiv(x * 2, 4), tdiv(x, 2))
    ck.verify(tdiv(x * 4, 2), x * 2)
    ck.verify(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2))
    ck.verify(tdiv(tvm.min(x * 6, y), 2), tvm.min(x * 3, tdiv(y, 2)))
    ck.verify(tdiv(tvm.max(x * 6, y), 2), tvm.max(x * 3, tdiv(y, 2)))
    ck.verify(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2)
    ck.verify(tdiv(tvm.min(y, x * 6), 2), tvm.min(tdiv(y, 2), x * 3))
    ck.verify(tdiv(tvm.max(y, x * 6), 2), tvm.max(tdiv(y, 2), x * 3))
    # 3-operands
    ck.verify(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2))
    ck.verify(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1)
    ck.verify(tdiv(x * 6 + (y + 3) - y, 2), x * 3 + 1)
    ck.verify(tdiv(y + x * 6 + z, 2), x * 3 + tdiv(y + z, 2))
    ck.verify(tdiv(x + 4, 2), tdiv(x, 2) + 2)
    # Dividend containing the divisor as a summand: pull out "+ 1".
    ck.verify(tdiv(x + y, x), tdiv(y, x) + 1)
    ck.verify(tdiv(y + x, x), tdiv(y, x) + 1)
    ck.verify(tdiv((x + y) + z, x), tdiv(y + z, x) + 1)
    ck.verify(tdiv((y + x) + z, x), tdiv(y + z, x) + 1)
    ck.verify(tdiv(y + (x + z), x), tdiv(y + z, x) + 1)
    ck.verify(tdiv(y + (z + x), x), tdiv(y + z, x) + 1)
    # Dividend containing a multiple of the divisor.
    ck.verify(tdiv(x * y, y), x)
    ck.verify(tdiv(y * x, y), x)
    ck.verify(tdiv(x * z + y, z), x + tdiv(y, z))
    ck.verify(tdiv(z * x + y, z), x + tdiv(y, z))
    ck.verify(tdiv(y + x * z, z), tdiv(y, z) + x)
    ck.verify(tdiv(y + z * x, z), tdiv(y, z) + x)
def test():
    """Check NNPU vector reductions (sum, max, min) over the last axis of a 4x16 input."""
    env = nnpu.get_env()
    a = tvm.placeholder((4, 16), env.cfg['dtype_w'], 'a')
    sph = ScheduleProcHelper()
    a_buf, a_dram = nnpu.utils.CopyHtoBuf(a, 'a', sph)

    # Sum reduction.
    k = tvm.reduce_axis((0, 16), 'k')
    c_buf = tvm.compute((4, 1), lambda i, j: tvm.sum(a_buf[i, k], axis=k), 'c_buf')
    sph.MarkScope(c_buf)
    c_host, c_dram = nnpu.utils.CopyBufToH(c_buf, 'c', sph)

    # Max reduction.
    k1 = tvm.reduce_axis((0, 16), 'k1')
    max_buf = tvm.compute((4, 1), lambda i, j: tvm.max(a_buf[i, k1], axis=k1), 'max_buf')
    sph.MarkScope(max_buf)
    max_host, max_dram = nnpu.utils.CopyBufToH(max_buf, 'max', sph)

    # Min reduction.
    k2 = tvm.reduce_axis((0, 16), 'k2')
    min_buf = tvm.compute((4, 1), lambda i, j: tvm.min(a_buf[i, k2], axis=k2), 'min_buf')
    sph.MarkScope(min_buf)
    min_host, min_dram = nnpu.utils.CopyBufToH(min_buf, 'min', sph)

    # create schedule and tensorize
    s = tvm.create_schedule([c_host.op, max_host.op, min_host.op])
    sph.Transform(s)
    s[c_buf].tensorize(s[c_buf].op.axis[1], env.intrins.get('VReduceSum', mode='w'))
    s[max_buf].tensorize(s[max_buf].op.axis[1], env.intrins.get('VReduceMax', mode='w'))
    s[min_buf].tensorize(s[min_buf].op.axis[1], env.intrins.get('VReduceMin', mode='w'))

    # build
    print(nnpu.lower(s, [a, c_host, max_host, min_host], simple_mode=True))
    func = nnpu.build(s, [a, c_host, max_host, min_host], 'nnpu', 'llvm', name='nnpu_func')

    # create data and run
    # NOTE(review): 13 is presumably the NNPU device-type id; confirm.
    ctx = tvm.nd.TVMContext(13, 0)
    a_np = np.random.randint(size=(4, 16), dtype=a.dtype, low=0, high=64)
    a_nd = tvm.nd.array(a_np, ctx)
    # Allocate each output with its own tensor's dtype. Previously all three
    # used c_host.dtype, which is only correct if the dtypes happen to match.
    c_nd = tvm.nd.array(np.zeros((4, 1)).astype(c_host.dtype), ctx)
    max_nd = tvm.nd.array(np.zeros((4, 1)).astype(max_host.dtype), ctx)
    min_nd = tvm.nd.array(np.zeros((4, 1)).astype(min_host.dtype), ctx)
    func(a_nd, c_nd, max_nd, min_nd)

    # check results
    gt = np.sum(a_np, axis=(1,), keepdims=True)
    np.testing.assert_allclose(c_nd.asnumpy(), gt)
    np.testing.assert_allclose(max_nd.asnumpy(), np.max(a_np, axis=(1,), keepdims=True))
    np.testing.assert_allclose(min_nd.asnumpy(), np.min(a_np, axis=(1,), keepdims=True))
    print('test passed')
def test_floordiv_index_simplify():
    """Rewrite-simplifier rules for floor division (tvm.floordiv)."""
    # short name for floordiv
    fld = tvm.floordiv
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # The checks below run before any variable bounds are registered.
    ck.verify(fld(fld(x, 2), 3), fld(x, 6))
    ck.verify(fld(fld(x, 2) + 1, 3), fld(x + 2, 6))
    ck.verify(fld(x * 2, 4), fld(x, 2))
    ck.verify(fld(x * 4, 2), x * 2)
    ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2))
    ck.verify(fld(tvm.min(x * 6, y), 2), tvm.min(x * 3, fld(y, 2)))
    ck.verify(fld(tvm.max(x * 6, y), 2), tvm.max(x * 3, fld(y, 2)))
    ck.verify(fld(y + x * 4, 2), fld(y, 2) + x * 2)
    ck.verify(fld(tvm.min(y, x * 6), 2), tvm.min(fld(y, 2), x * 3))
    ck.verify(fld(tvm.max(y, x * 6), 2), tvm.max(fld(y, 2), x * 3))
    # 3-operands
    ck.verify(fld(x * 6 + y + z, 2), x * 3 + fld(y + z, 2))
    ck.verify(fld(x * 6 - y + (y + 3), 2), x * 3 + 1)
    ck.verify(fld(x * 6 + (y + 3) - y, 2), x * 3 + 1)
    ck.verify(fld(y + x * 6 + z, 2), x * 3 + fld(y + z, 2))
    ck.verify(fld(x + 4, 2), fld(x, 2) + 2)
    # The divisor x is bounded to [0, 1000] before these divisor-cancelling rewrites.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(fld(x + y, x), fld(y, x) + 1)
    ck.verify(fld(y + x, x), fld(y, x) + 1)
    ck.verify(fld((x + y) + z, x), fld(y + z, x) + 1)
    ck.verify(fld((y + x) + z, x), fld(y + z, x) + 1)
    ck.verify(fld(y + (x + z), x), fld(y + z, x) + 1)
    ck.verify(fld(y + (z + x), x), fld(y + z, x) + 1)
    # y and z are bounded as well for the multiply-by-divisor rewrites.
    ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(fld(x * y, y), x)
    ck.verify(fld(y * x, y), x)
    ck.verify(fld(x * z + y, z), x + fld(y, z))
    ck.verify(fld(z * x + y, z), x + fld(y, z))
    ck.verify(fld(y + x * z, z), fld(y, z) + x)
    ck.verify(fld(y + z * x, z), fld(y, z) + x)
def test_div_index_simplify():
    """Rewrite-simplifier rules for the ``/`` operator on integer index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Checked before any bounds are registered.
    ck.verify(x / x, 1)
    # Every rewrite below is checked with x, y, z constrained to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x / 2 / 3, x / 6)
    ck.verify((x / 2 + 1) / 3, (x + 2) / 6)
    ck.verify(x * 2 / 4, x / 2)
    ck.verify(x * 4 / 2, x * 2)
    ck.verify((x * 4 + y) / 2, x * 2 + y / 2)
    ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2))
    ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2))
    ck.verify((y + x * 4) / 2, y / 2 + x * 2)
    ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3))
    ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3))
    # 3-operands
    ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2)
    ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1)
    ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1)
    ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2)
    ck.verify((x + 4) / 2, x / 2 + 2)
    # Dividend containing the divisor as a summand: pull out "+ 1".
    ck.verify((x + y) / x, y / x + 1)
    ck.verify((y + x) / x, y / x + 1)
    ck.verify(((x + y) + z) / x, (y + z) / x + 1)
    ck.verify(((y + x) + z) / x, (y + z) / x + 1)
    ck.verify((y + (x + z)) / x, (y + z) / x + 1)
    ck.verify((y + (z + x)) / x, (y + z) / x + 1)
    # Dividend containing a multiple of the divisor.
    ck.verify((x * y) / y, x)
    ck.verify((y * x) / y, x)
    ck.verify((x * z + y) / z, x + y / z)
    ck.verify((z * x + y) / z, x + y / z)
    ck.verify((y + x * z) / z, y / z + x)
    ck.verify((y + z * x) / z, y / z + x)
def _sample(i, c, ph, pw):
    """Max over a 2x2 grid of bilinear samples in output bin (ph, pw) of ROI ``i``.

    ``rois``, ``spatial_scale``, ``pooled_size_h``/``pooled_size_w``,
    ``height``/``width`` and ``_bilinear`` come from the enclosing scope.
    Row 0 of each ROI is the batch index; rows 1-4 are
    (start_w, start_h, end_w, end_h) before scaling.
    """
    roi = rois[i]
    batch_index = roi[0].astype('int32')
    roi_start_w = roi[1] * spatial_scale
    roi_start_h = roi[2] * spatial_scale
    roi_end_w = roi[3] * spatial_scale
    roi_end_h = roi[4] * spatial_scale
    # (Removed no-op self-assignments ``roi_h = roi_h`` / ``roi_w = roi_w``.)
    roi_h = roi_end_h - roi_start_h
    roi_w = roi_end_w - roi_start_w
    bin_h = roi_h / pooled_size_h
    bin_w = roi_w / pooled_size_w

    # Bin boundaries relative to the ROI, then shifted and clamped to the map.
    hstart = ph * bin_h
    wstart = pw * bin_w
    hend = (ph + 1) * bin_h
    wend = (pw + 1) * bin_w
    hstart = tvm.min(tvm.max(hstart + roi_start_h, 0), height - 1)
    wstart = tvm.min(tvm.max(wstart + roi_start_w, 0), width - 1)
    hend = tvm.min(tvm.max(hend + roi_start_h, 0), height - 1)
    wend = tvm.min(tvm.max(wend + roi_start_w, 0), width - 1)

    non_empty = tvm.all(hstart < hend, wstart < wend)

    def min_value(dtype):
        # Reduction identity: dtype's minimum for non-empty bins, 0 for empty bins.
        return tvm.expr.Select(non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype))

    # Sample points at roughly 1/3 and 2/3 of the clamped bin along each axis.
    stride_h = (hend - hstart) / 3.0
    stride_w = (wend - wstart) / 3.0
    hstart += stride_h
    wstart += stride_w
    stride_h = tvm.max(0.01, stride_h)
    stride_w = tvm.max(0.01, stride_w)

    _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
    # Empty bins get a zero-length reduction, yielding the identity (0).
    rh = tvm.reduce_axis((0, tvm.expr.Select(non_empty, 2, 0)), 'rh')
    rw = tvm.reduce_axis((0, tvm.expr.Select(non_empty, 2, 0)), 'rw')
    return _max(_bilinear(batch_index, c, hstart + rh * stride_h, wstart + rw * stride_w),
                axis=[rh, rw])
def test_select_simplify():
    """Rewrite-simplifier rules that push arithmetic and min/max through Select."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Add rules
    # Two Selects on the same condition combine branch-wise.
    ck.verify(tvm.expr.Select(x < 0, y, 0) + tvm.expr.Select(x < 0, 1, z),
              tvm.expr.Select(x < 0, y + 1, z))
    ck.verify(tvm.expr.Select(x < 0, y, 1) - tvm.expr.Select(x < 0, 1, z),
              tvm.expr.Select(x < 0, y + (-1), 1 - z))
    ck.verify(tvm.expr.Select(x < 0, y, z) - y,
              tvm.expr.Select(x < 0, 0, z - y))
    ck.verify(tvm.expr.Select(x < 0, y, z) - z,
              tvm.expr.Select(x < 0, y - z, 0))
    ck.verify(tvm.min(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
              tvm.expr.Select(x < 0, tvm.min(y, 1), tvm.min(0, z)))
    ck.verify(tvm.max(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
              tvm.expr.Select(x < 0, tvm.max(y, 1), tvm.max(0, z)))
    # Conditions decidable from x * 3 + 1 never being zero.
    ck.verify(tvm.expr.Select(x * 3 + 1 != 0, y, z), y)
    ck.verify(tvm.expr.Select(x * 3 + 1 == 0, y, z), z)
    # Identical branches collapse.
    ck.verify(tvm.expr.Select(x > 0, y + 1, y + 1), y + 1)
def test_add_index_simplify():
    """Rewrite-simplifier rules for addition of index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    ck.verify(x + (y - x), y)
    ck.verify(x - (y + 1) + (y + 1), x)
    ck.verify((x - 10) + (10 - z), x - z)
    ck.verify((x - y) + (z - x), z - y)
    # Constants/terms added to min/max are pushed into an operand.
    ck.verify(tvm.min(x, y - z) + z, tvm.min(x + z, y))
    ck.verify(tvm.min(x - z, y) + z, tvm.min(x, y + z))
    ck.verify(tvm.max(x, y - 10) + 10, tvm.max(x + 10, y))
    ck.verify(tvm.max(x - 11, y) + 11, tvm.max(x, y + 11))
    # max(a, b) + min(a, b) == a + b.
    ck.verify(tvm.max(x, y * 2) + tvm.min(x, y * 2), x + y * 2)
    ck.verify(tvm.min(x, y * 2) + tvm.max(x, y * 2), x + y * 2)
    ck.verify(tvm.max(x, y + 2) + (-2), tvm.max(x + (-2), y))
    ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y))
    ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1))
    # Factoring a common multiplicand.
    ck.verify(x * y + x * 10, x * (y + 10))
    ck.verify(y * x + x * 10, x * (y + 10))
    ck.verify(y * x + 10 * x, x * (y + 10))
    ck.verify(x * y + 10 * x, x * (y + 10))
    ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
    # With x in [0, 1000], (x / 8) * 8 + x % 8 recombines to x.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify((x / 8) * 8 + x % 8, x)
    # canonicalization
    ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9)
    ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9)
    # conservative bound
    # With x allowed to be -1 the rewrite must NOT fire: ck.verify is expected
    # to raise AssertionError; reaching the RuntimeError means it wrongly fired.
    try:
        ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
        ck.verify((x / 8) * 8 + x % 8, x)
        raise RuntimeError("bad")
    except AssertionError:
        pass
def test_add_index_simplify():
    """Rewrite-simplifier rules for addition of index expressions.

    (Stray ``;`` statement terminators removed; the expected-failure check
    now uses try/except/else so only ``ck.verify`` sits inside the ``try``.)
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    ck.verify(x + (y - x), y)
    ck.verify(x - (y + 1) + (y + 1), x)
    ck.verify((x - 10) + (10 - z), x - z)
    ck.verify((x - y) + (z - x), z - y)
    # Constants/terms added to min/max are pushed into an operand.
    ck.verify(tvm.min(x, y - z) + z, tvm.min(x + z, y))
    ck.verify(tvm.min(x - z, y) + z, tvm.min(x, y + z))
    ck.verify(tvm.max(x, y - 10) + 10, tvm.max(x + 10, y))
    ck.verify(tvm.max(x - 11, y) + 11, tvm.max(x, y + 11))
    # max(a, b) + min(a, b) == a + b.
    ck.verify(tvm.max(x, y * 2) + tvm.min(x, y * 2), x + y * 2)
    ck.verify(tvm.min(x, y * 2) + tvm.max(x, y * 2), x + y * 2)
    ck.verify(tvm.max(x, y + 2) + (-2), tvm.max(x + (-2), y))
    ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y))
    ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1))
    # Factoring a common multiplicand.
    ck.verify(x * y + x * 10, x * (y + 10))
    ck.verify(y * x + x * 10, x * (y + 10))
    ck.verify(y * x + 10 * x, x * (y + 10))
    ck.verify(x * y + 10 * x, x * (y + 10))
    ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
    # With x in [0, 1000], (x / 8) * 8 + x % 8 recombines to x.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify((x / 8) * 8 + x % 8, x)
    # canonicalization
    ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9)
    ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9)
    # conservative bound
    # With x allowed to be -1 the rewrite must NOT fire, so verify must fail.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
    try:
        ck.verify((x / 8) * 8 + x % 8, x)
    except AssertionError:
        pass
    else:
        raise RuntimeError("bad")
def test_div_index_simplify():
    """Rewrite-simplifier rules for ``/`` on non-negative integer index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Every rewrite below is checked with x, y, z constrained to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x / 2 / 3, x / 6)
    ck.verify((x / 2 + 1) / 3, (x + 2) / 6)
    ck.verify(x * 2 / 4, x / 2)
    ck.verify(x * 4 / 2, x * 2)
    ck.verify((x * 4 + y) / 2, x * 2 + y / 2)
    ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2))
    ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2))
    ck.verify((y + x * 4) / 2, y / 2 + x * 2)
    ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3))
    ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3))
    # 3-operands
    ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2)
    ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1)
    ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1)
    ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2)
    ck.verify((x + 4) / 2, x / 2 + 2)
    # Dividend containing the divisor as a summand: pull out "+ 1".
    ck.verify((x + y) / x, y / x + 1)
    ck.verify((y + x) / x, y / x + 1)
    ck.verify(((x + y) + z) / x, (y + z) / x + 1)
    ck.verify(((y + x) + z) / x, (y + z) / x + 1)
    ck.verify((y + (x + z)) / x, (y + z) / x + 1)
    ck.verify((y + (z + x)) / x, (y + z) / x + 1)
    # Dividend containing a multiple of the divisor.
    ck.verify((x * y) / y, x)
    ck.verify((y * x) / y, x)
    ck.verify((x * z + y) / z, x + y / z)
    ck.verify((z * x + y) / z, x + y / z)
    ck.verify((y + x * z) / z, y / z + x)
    ck.verify((y + z * x) / z, y / z + x)
def _pool(i, c, ph, pw):
    """Max-pool output bin (ph, pw) of ROI ``i`` on channel ``c`` (integer bins).

    ``rois``, ``spatial_scale``, ``dtype``, ``pooled_size_h``/``pooled_size_w``,
    ``height``/``width`` and ``data`` come from the enclosing scope.
    Row 0 of each ROI is the batch index; rows 1-4 are
    (start_w, start_h, end_w, end_h) before scaling.
    """
    roi = rois[i]
    batch_index = roi[0].astype('int32')
    roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1], roi[2], roi[3], roi[4]

    # Scale ROI corners to feature-map coordinates and round to integers.
    roi_start_h = tvm.round(roi_start_h * spatial_scale).astype('int32')
    roi_start_w = tvm.round(roi_start_w * spatial_scale).astype('int32')
    roi_end_h = tvm.round(roi_end_h * spatial_scale).astype('int32')
    roi_end_w = tvm.round(roi_end_w * spatial_scale).astype('int32')

    # force malformed ROIs to be 1x1
    roi_h = tvm.max(roi_end_h - roi_start_h + 1, tvm.const(1, 'int32'))
    roi_w = tvm.max(roi_end_w - roi_start_w + 1, tvm.const(1, 'int32'))

    bin_h = roi_h.astype(dtype) / pooled_size_h
    bin_w = roi_w.astype(dtype) / pooled_size_w

    # use epsilon to prevent floating point precision loss in floor/ceil
    epsilon = tvm.const(0.00001, dtype)
    hstart = tvm.floor(ph * bin_h + epsilon).astype('int32')
    wstart = tvm.floor(pw * bin_w + epsilon).astype('int32')
    hend = tvm.ceil((ph + 1) * bin_h - epsilon).astype('int32')
    wend = tvm.ceil((pw + 1) * bin_w - epsilon).astype('int32')
    # Shift to absolute coordinates and clamp to [0, height] / [0, width]
    # (end indices are exclusive).
    hstart = tvm.min(tvm.max(hstart + roi_start_h, 0), height)
    wstart = tvm.min(tvm.max(wstart + roi_start_w, 0), width)
    hend = tvm.min(tvm.max(hend + roi_start_h, 0), height)
    wend = tvm.min(tvm.max(wend + roi_start_w, 0), width)

    non_empty = tvm.all(hstart < hend, wstart < wend)
    # Reduction identity: dtype's minimum for non-empty bins, 0 for empty bins.
    min_value = lambda dtype: tvm.if_then_else(
        non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype))
    # pylint: disable=unnecessary-lambda
    _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
    rh = tvm.reduce_axis((0, hend - hstart), 'rh')
    rw = tvm.reduce_axis((0, wend - wstart), 'rw')
    return _max(data[batch_index, c, hstart + rh, wstart + rw], axis=[rh, rw])
def compute_clip(attrs, inputs, _):
    """ Clip operator. """
    data = inputs[0]
    lower = tvm.const(attrs.get_float("a_min"), data.dtype)
    upper = tvm.const(attrs.get_float("a_max"), data.dtype)
    with tvm.tag_scope(topi.tag.ELEMWISE):
        # Two elementwise stages: cap from above, then floor from below.
        capped = tvm.compute(
            data.shape, lambda *idx: tvm.min(data(*idx), upper), name="clipA")
        clipped = tvm.compute(
            capped.shape, lambda *idx: tvm.max(capped(*idx), lower), name="clipB")
    return clipped
def compute_clip(attrs, inputs, output_type, target):
    """ Clip operator. """
    data = inputs[0]
    lower = tvm.const(attrs.a_min, data.dtype)
    upper = tvm.const(attrs.a_max, data.dtype)
    with tvm.tag_scope(topi.tag.ELEMWISE):
        # Two elementwise stages: cap from above, then floor from below.
        capped = tvm.compute(
            data.shape, lambda *idx: tvm.min(data(*idx), upper), name="clipA")
        clipped = tvm.compute(
            capped.shape, lambda *idx: tvm.max(capped(*idx), lower), name="clipB")
    return [clipped]
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
    """Transform prior anchor box to output box through location predictions.

    The anchor is read as (left, top, right, bottom) starting at
    ``anchor_base_idx``. The prediction is read as (dx, dy, dw, dh) starting
    at ``loc_base_idx``. The center is shifted by ``d* * v* * size``, and the
    half extent is ``exp(d* * v*) * size / 2``. Returns the corners
    (left, top, right, bottom). When ``clip`` holds, each corner is clamped
    into [0, 1].
    """
    left = anchor[anchor_base_idx]
    top = anchor[anchor_base_idx + 1]
    right = anchor[anchor_base_idx + 2]
    bottom = anchor[anchor_base_idx + 3]
    anchor_w = right - left
    anchor_h = bottom - top
    center_x = (left + right) / 2.0
    center_y = (top + bottom) / 2.0

    dx = loc[loc_base_idx]
    dy = loc[loc_base_idx + 1]
    dw = loc[loc_base_idx + 2]
    dh = loc[loc_base_idx + 3]

    out_cx = dx * vx * anchor_w + center_x
    out_cy = dy * vy * anchor_h + center_y
    half_w = exp(dw * vw) * anchor_w / 2.0
    half_h = exp(dh * vh) * anchor_h / 2.0

    def _maybe_clip(value):
        # Conditionally clamp a corner coordinate into [0, 1].
        return tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, value)), value)

    return (_maybe_clip(out_cx - half_w),
            _maybe_clip(out_cy - half_h),
            _maybe_clip(out_cx + half_w),
            _maybe_clip(out_cy + half_h))
def test_min_max_select():
    """Check modular-set analysis through Min, Max and Select nodes."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    # (expression, expected coeff, expected base)
    expectations = (
        (tvm.min(x * 3, y * 9), 3, 0),
        (tvm.max(x * 3 + 1, y * 9 + 4), 3, 1),
        (tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2), 1, 0),
    )
    for expr, want_coeff, want_base in expectations:
        mod_set = analyzer.modular_set(expr)
        assert mod_set.coeff == want_coeff
        assert mod_set.base == want_base
def test_cmp_simplify():
    """Check rewrites of comparison expressions.

    RewriteChecker is defined elsewhere in this file. ck.verify(a, b) is
    presumably an assertion that the rewrite of ``a`` equals ``b``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # const int bound
    ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
    ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
    ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
    ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
    ck.verify(x * 3 + 10 == 0, tvm.const(0, "bool"))
    ck.verify(x * 3 + 10 != 0, tvm.const(1, "bool"))

    # canonicalization
    ck.verify((x - 10).equal(0), x.equal(10))
    ck.verify((10 - x).equal(0), x.equal(10))
    ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0)))

    # cmp bound
    ck.verify(x + y < x + z, y < z)
    ck.verify(x + y < z + x, y < z)
    ck.verify(y + x < x + z, y < z)
    ck.verify(y + x < z + x, y < z)
    ck.verify(y - x < z - x, y < z)
    ck.verify(x - y < x - z, z < y)

    ck.verify(x < z + x, tvm.expr.LT(0, z))
    ck.verify(x < x + z, tvm.expr.LT(0, z))

    ck.verify(100 < x + 1, tvm.expr.LT(99, x))
    ck.verify(1 < 100 - x, tvm.expr.LT(x, 99))
    ck.verify(x * 3 < y * 3, x < y)
    ck.verify(x * (-3) < y * (-3), y < x)
    ck.verify(x * 3 >= y * 3, y <= x)

    ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
    ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
    ck.verify(x / 2 < 3, x < 6)
    ck.verify(x * 4 <= 2, x <= 0)
    ck.verify(3 < x / 2, tvm.expr.LT(7, x))

    ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
    ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))

    ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y))
    ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4))

    ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2))
    ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
    ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))

    ck.verify(tvm.min(x, 11) < 10, x < 10)
    ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
    ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
    ck.verify(x + 1 < tvm.max(8, x), x < 7)

    # From here on the analyzer knows x in [0, 10], y in [-10, 0], z in [-5, 5];
    # every verify below depends on these bounds, so keep this ordering.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)

    ck.verify(x < 11, tvm.const(1, "bool"))
    ck.verify(x <= 10, tvm.const(1, "bool"))
    ck.verify(z <= 5, tvm.const(1, "bool"))
    ck.verify(x + y <= 10, tvm.const(1, "bool"))
    ck.verify(x + y >= -10, tvm.const(1, "bool"))
    ck.verify(z - 5 <= y + 10, tvm.const(1, "bool"))
    ck.verify(tvm.all(x > -1, z <= x + 5), tvm.const(1, "bool"))
    ck.verify(x*y <= 0, tvm.const(1, "bool"))
    ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
    ck.verify(y*y >= 0, tvm.const(1, "bool"))
def test_max_index_simplify():
    """Check rewrites of tvm.max over index expressions.

    RewriteChecker is defined elsewhere in this file. ck.verify(a, b) is
    presumably an assertion that the rewrite of ``a`` equals ``b``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # const int bound
    ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10)

    # max of the same variable offset by constants
    ck.verify(tvm.max(x + 1, x + 10), x + 10)
    ck.verify(tvm.max(x + 111, x + 10), x + 111)
    ck.verify(tvm.max(x + 1, x), x + 1)
    ck.verify(tvm.max(x, x + 2), x + 2)
    ck.verify(tvm.max(1 - x, 2 - x), 2 - x)
    ck.verify(tvm.max(3 - x, 2 - x), 3 - x)
    ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)

    # absorption / idempotence with nested min and max
    ck.verify(tvm.max(tvm.min(x, y), tvm.max(x, y)), tvm.max(x, y))
    ck.verify(tvm.max(tvm.min(x, y), tvm.max(y, x)), tvm.max(x, y))
    ck.verify(tvm.max(tvm.min(x, y), x), x)
    ck.verify(tvm.max(tvm.min(y, x), x), x)
    ck.verify(tvm.max(tvm.max(x, y), x), tvm.max(x, y))
    ck.verify(tvm.max(tvm.max(x, y), y), tvm.max(x, y))

    ck.verify(tvm.max(x, tvm.min(x, y)), x)
    ck.verify(tvm.max(x, tvm.min(y, x)), x)
    ck.verify(tvm.max(x, tvm.max(x, y)), tvm.max(x, y))
    ck.verify(tvm.max(y, tvm.max(x, y)), tvm.max(x, y))

    # a term already present deep in a max chain is dropped
    ck.verify(tvm.max(tvm.max(tvm.max(x, y), z), y),
              tvm.max(tvm.max(x, y), z))
    ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), y),
              tvm.max(tvm.max(tvm.max(x, y), z), x * 2))
    ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2), y),
              tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2))

    # factor out a shared operand
    ck.verify(tvm.max(tvm.min(x, y), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
    ck.verify(tvm.max(tvm.min(x, y), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))
    ck.verify(tvm.max(tvm.min(y, x), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
    ck.verify(tvm.max(tvm.min(y, x), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))

    ck.verify(tvm.max(y + x, z + x), tvm.max(y, z) + x)
    ck.verify(tvm.max(y + x, x + z), tvm.max(y, z) + x)
    ck.verify(tvm.max(x + y, z + x), tvm.max(y, z) + x)
    ck.verify(tvm.max(x + y, x + z), tvm.max(y, z) + x)

    ck.verify(tvm.max(x - y, x - z), x - tvm.min(y, z))
    ck.verify(tvm.max(y - x, z - x), tvm.max(y, z) - x)

    # nested constant bounds
    ck.verify(tvm.max(tvm.max(x, 1), 10), tvm.max(x, 10))
    ck.verify(tvm.max(tvm.max(x, 11), 10), tvm.max(x, 11))

    # division by a constant; a negative divisor swaps max for min
    ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
    ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
    ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
def test_min_index_simplify():
    """Check rewrites of tvm.min over index expressions.

    RewriteChecker is defined elsewhere in this file. ck.verify(a, b) is
    presumably an assertion that the rewrite of ``a`` equals ``b``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # const int bound
    ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2)

    # min of the same variable offset by constants
    ck.verify(tvm.min(x + 1, x + 10), x + 1)
    ck.verify(tvm.min(x + 111, x + 10), x + 10)
    ck.verify(tvm.min(x + 1, x), x)
    ck.verify(tvm.min(x, x + 2), x)
    ck.verify(tvm.min(1 - x, 2 - x), 1 - x)
    ck.verify(tvm.min(3 - x, 2 - x), 2 - x)
    ck.verify(tvm.min((x + 3) / 4 * 4, x), x)

    # The next three checks rely on x being bounded to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
    ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
    ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
    ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
    # Widen x to [-1000, 1000] (third positional argument is override=True).
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)

    # absorption / idempotence with nested min and max
    ck.verify(tvm.min(tvm.max(x, y), tvm.min(x, y)), tvm.min(x, y))
    ck.verify(tvm.min(tvm.max(x, y), tvm.min(y, x)), tvm.min(x, y))
    ck.verify(tvm.min(tvm.max(x, y), x), x)
    ck.verify(tvm.min(tvm.max(y, x), x), x)
    ck.verify(tvm.min(tvm.min(x, y), x), tvm.min(x, y))
    ck.verify(tvm.min(tvm.min(x, y), y), tvm.min(x, y))

    ck.verify(tvm.min(x, tvm.max(x, y)), x)
    ck.verify(tvm.min(x, tvm.max(y, x)), x)
    ck.verify(tvm.min(x, tvm.min(x, y)), tvm.min(x, y))
    ck.verify(tvm.min(y, tvm.min(x, y)), tvm.min(x, y))

    # a term already present deep in a min chain is dropped
    ck.verify(tvm.min(tvm.min(tvm.min(x, y), z), y),
              tvm.min(tvm.min(x, y), z))
    ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), y),
              tvm.min(tvm.min(tvm.min(x, y), z), x * 2))
    ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2), y),
              tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2))

    # factor out a shared operand
    ck.verify(tvm.min(tvm.max(x, y), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
    ck.verify(tvm.min(tvm.max(x, y), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))
    ck.verify(tvm.min(tvm.max(y, x), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
    ck.verify(tvm.min(tvm.max(y, x), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))

    ck.verify(tvm.min(y + x, z + x), tvm.min(y, z) + x)
    ck.verify(tvm.min(y + x, x + z), tvm.min(y, z) + x)
    ck.verify(tvm.min(x + y, z + x), tvm.min(y, z) + x)
    ck.verify(tvm.min(x + y, x + z), tvm.min(y, z) + x)

    ck.verify(tvm.min(x - y, x - z), x - tvm.max(y, z))
    ck.verify(tvm.min(y - x, z - x), tvm.min(y, z) - x)

    # nested constant bounds
    ck.verify(tvm.min(tvm.min(x, 1), 10), tvm.min(x, 1))
    ck.verify(tvm.min(tvm.min(x, 11), 10), tvm.min(x, 10))

    # division by a constant; a negative divisor swaps min for max
    ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
    ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
    ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
def test_vector_simplify():
    """Check rewrites of vector (Ramp / vector-typed) expressions.

    RewriteChecker is defined elsewhere in this file. ck.verify(a, b) is
    presumably an assertion that the rewrite of ``a`` equals ``b``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    # Add rules
    ck.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4),
              tvm.expr.Ramp(x + y, 3, 4))
    ck.verify(tvm.expr.Ramp(x, 1, 2) + y,
              tvm.expr.Ramp(x + y, 1, 2))
    ck.verify(y + tvm.expr.Ramp(x, 1, 2),
              tvm.expr.Ramp(y + x, 1, 2))
    ck.verify(y.astype("int32x2") + x.astype("int32x2"),
              (y + x).astype("int32x2"))
    # Sub rules
    ck.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4),
              tvm.expr.Ramp(x - y, 2, 4))
    ck.verify(tvm.expr.Ramp(x, 1, 2) - y,
              tvm.expr.Ramp(x - y, 1, 2))
    ck.verify(y - tvm.expr.Ramp(x, 1, 2),
              tvm.expr.Ramp(y - x, -1, 2))
    ck.verify(y.astype("int32x2") - x.astype("int32x2"),
              (y - x).astype("int32x2"))

    # Mul rules
    ck.verify(y.astype("int32x2") * x.astype("int32x2"),
              (y * x).astype("int32x2"))
    ck.verify(tvm.expr.Ramp(x, 4, 4) * 2,
              tvm.expr.Ramp(x * 2, 8, 4))
    ck.verify(2 * tvm.expr.Ramp(x, 4, 4),
              tvm.expr.Ramp(x * 2, 8, 4))

    # Div rules
    ck.verify(y.astype("int32x2") / x.astype("int32x2"),
              (y / x).astype("int32x2"))
    ck.verify(tvm.expr.Ramp(x, 4, 4) / 2,
              tvm.expr.Ramp(x/ 2, 2, 4))
    # The remaining Div checks rely on x being bounded to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8,
              (x).astype("int32x4"))
    # Lanes cross a multiple of 8 here, so the expression must stay unchanged.
    ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8,
              tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8)

    # Mod rules
    ck.verify(y.astype("int32x2") % x.astype("int32x2"),
              (y % x).astype("int32x2"))
    ck.verify(tvm.expr.Ramp(x, 4, 4) % 2,
              tvm.expr.Broadcast(x % 2, 4))
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8,
              tvm.expr.Ramp(1, 1, 4))
    ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
              tvm.expr.Ramp(1, 15, 4) % 8)

    # Min/Max rules
    vx = tvm.var("vx", dtype="int32x2")
    vc = tvm.var("vc", dtype="uint1")
    ck.verify(tvm.min(y.astype("int32x2"), x.astype("int32x2")),
              tvm.min(y, x).astype("int32x2"))
    ck.verify(tvm.min(tvm.min(vx, y.astype("int32x2")), x.astype("int32x2")),
              tvm.min(vx, tvm.min(y, x).astype("int32x2")))
    ck.verify(tvm.max(y.astype("int32x2"), x.astype("int32x2")),
              tvm.max(y, x).astype("int32x2"))
    ck.verify(tvm.max(tvm.max(vx, y.astype("int32x2")), x.astype("int32x2")),
              tvm.max(vx, tvm.max(y, x).astype("int32x2")))

    # Logical rules
    ck.verify(y.astype("int32x2").equal(x.astype("int32x2")),
              (y.equal(x)).astype("uint1x2"))
    ck.verify(tvm.expr.NE(y.astype("int32x2"), (x.astype("int32x2"))),
              (tvm.expr.NE(y, x)).astype("uint1x2"))
    ck.verify(y.astype("int32x2") > x.astype("int32x2"),
              (x < y).astype("uint1x2"))
    ck.verify(y.astype("int32x2") >= x.astype("int32x2"),
              (x <= y).astype("uint1x2"))
    ck.verify(y.astype("int32x2") < x.astype("int32x2"),
              (y < x).astype("uint1x2"))
    ck.verify(y.astype("int32x2") <= x.astype("int32x2"),
              (y <= x).astype("uint1x2"))
    ck.verify(tvm.expr.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
              (tvm.expr.And(y <= x, vc)).astype("uint1x2"))
    ck.verify(tvm.expr.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
              (tvm.expr.Or(y <= x, vc)).astype("uint1x2"))
def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
                    feature_stride, rpn_min_size, iou_loss):
    """Predict bounding boxes based on anchors, scores and deltas.

    Parameters
    ----------
    cls_prob_buf : tvm.schedule.Buffer
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred_buf : tvm.schedule.Buffer
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info_buf : tvm.schedule.Buffer
        2-D with shape [batch, 3]

    out_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]
        The last dimension is in format of [w_start, h_start, w_end, h_end, score]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for
        example the product of all strides prior to this layer.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
    # cls_prob has 2 * num_anchors channels (see docstring); recover num_anchors.
    num_anchors //= 2
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # One thread per (batch, h, w) position; the tid bound check below masks the extra block.
    nthread_bx = (batch * height * width) // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    tid = bx * max_threads + tx
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)

    p_score = ib.buffer_ptr(cls_prob_buf)
    p_delta = ib.buffer_ptr(bbox_pred_buf)
    p_im_info = ib.buffer_ptr(im_info_buf)
    p_out = ib.buffer_ptr(out_buf)

    with ib.if_scope(tid < batch * height * width):
        # Decompose the flat thread index into (b, h, w).
        w = tid % width
        h = (tid // width) % height
        b = tid // width // height

        # Unrolled at IR-construction time: one anchor per (ratio, scale) pair.
        for k in range(num_anchors):
            out_index = tid * num_anchors + k
            ratio = ratios[k // len(scales)]
            scale = scales[k % len(scales)]
            anchor = generate_anchor(ratio, scale, feature_stride)
            im_height = p_im_info[b * 3]
            im_width = p_im_info[b * 3 + 1]
            # Shift the anchor to this feature-map cell in image coordinates.
            x1 = anchor[0] + w * feature_stride
            y1 = anchor[1] + h * feature_stride
            x2 = anchor[2] + w * feature_stride
            y2 = anchor[3] + h * feature_stride
            # Four deltas per anchor, laid out as channel (k * 4 + i).
            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
                     for i in range(4)]
            regression_func = reg_iou if iou_loss else reg_bbox
            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

            # Clip predicted corners into [0, im_size - 1].
            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

            # Feature-map extent actually covered by the image.
            real_height = (im_height / feature_stride).astype('int32')
            real_width = (im_width / feature_stride).astype('int32')

            bbox_w = pred_x2 - pred_x1 + 1.0
            bbox_h = pred_y2 - pred_y1 + 1.0
            min_size = p_im_info[b * 3 + 2] * rpn_min_size

            # Score is read from channel (num_anchors + k), i.e. the second half of the
            # 2 * num_anchors channels (presumably the foreground probability - confirm).
            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
            # Positions outside the real feature-map extent get score -1.
            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
                                         -1.0, pred_score)
            p_out[out_index * 5 + 0] = pred_x1
            p_out[out_index * 5 + 1] = pred_y1
            p_out[out_index * 5 + 2] = pred_x2
            p_out[out_index * 5 + 3] = pred_y2
            p_out[out_index * 5 + 4] = pred_score

            # Boxes smaller than min_size are expanded by min_size / 2 on each side
            # and their score is set to -1.
            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
                p_out[out_index * 5 + 0] -= min_size / 2.0
                p_out[out_index * 5 + 1] -= min_size / 2.0
                p_out[out_index * 5 + 2] += min_size / 2.0
                p_out[out_index * 5 + 3] += min_size / 2.0
                p_out[out_index * 5 + 4] = -1.0

    return ib.get()
def run_gemm_packed(env, remote, batch_size, channel, block):
    """Build, run and time a packed GEMM pipeline on VTA.

    The pipeline computes ``res = cast(clip(relu(gemm(data, weight) >> 8)))``:
    an accumulate-precision GEMM, a right shift by 8, max with 0, min with
    the largest signed ``INP_WIDTH``-bit value, then a cast to
    ``env.inp_dtype``. Each stage is scheduled onto VTA (scopes, pragmas,
    tensorize), uploaded to ``remote``, timed, and compared against a NumPy
    reference when correctness checking is enabled.

    Parameters
    ----------
    env : VTA environment providing shapes, dtypes, scopes and intrinsics.
    remote : RPC session used to upload, load and run the module.
    batch_size, channel : GEMM problem size (square weight matrix).
    block : when truthy, tile the output by this size; otherwise use the
        untiled schedule.
    """
    data_shape = (batch_size // env.BATCH,
                  channel // env.BLOCK_IN,
                  env.BATCH,
                  env.BLOCK_IN)
    weight_shape = (channel // env.BLOCK_OUT,
                    channel // env.BLOCK_IN,
                    env.BLOCK_OUT,
                    env.BLOCK_IN)
    res_shape = (batch_size // env.BATCH,
                 channel // env.BLOCK_OUT,
                 env.BATCH,
                 env.BLOCK_OUT)
    # To compute number of ops, use a x2 factor for FMA
    num_ops = 2 * channel * channel * batch_size

    ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
    ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')

    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    weight = tvm.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)
    data_buf = tvm.compute(data_shape, lambda *i: data(*i), "data_buf")
    weight_buf = tvm.compute(weight_shape, lambda *i: weight(*i), "weight_buf")
    res_gem = tvm.compute(res_shape,
                          lambda bo, co, bi, ci: tvm.sum(
                              data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
                              weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
                              axis=[ko, ki]),
                          name="res_gem")
    # Fixed-point normalization.
    res_shf = tvm.compute(res_shape, lambda *i: res_gem(*i) >> 8, name="res_shf")
    # ReLU.
    res_max = tvm.compute(res_shape, lambda *i: tvm.max(res_shf(*i), 0), "res_max")
    # Clip to the largest signed INP_WIDTH-bit value (not a ReLU).
    res_min = tvm.compute(res_shape,
                          lambda *i: tvm.min(res_max(*i), (1 << (env.INP_WIDTH - 1)) - 1),
                          "res_min")
    res = tvm.compute(res_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")

    def verify(s, check_correctness=True):
        """Build ``s``, run it remotely, optionally check it, and return the timing result."""
        mod = vta.build(s, [data, weight, res], "ext_dev", env.target_host, name="gemm")
        temp = util.tempdir()
        mod.save(temp.relpath("gemm.o"))
        remote.upload(temp.relpath("gemm.o"))
        f = remote.load_module("gemm.o")
        # verify
        ctx = remote.ext_dev(0)
        # Data in original format
        data_orig = np.random.randint(
            -128, 128, size=(batch_size, channel)).astype(data.dtype)
        weight_orig = np.random.randint(
            -128, 128, size=(channel, channel)).astype(weight.dtype)
        # Repack into the blocked layouts declared by data_shape / weight_shape.
        data_packed = data_orig.reshape(
            batch_size // env.BATCH, env.BATCH,
            channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
        weight_packed = weight_orig.reshape(
            channel // env.BLOCK_OUT, env.BLOCK_OUT,
            channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
        res_np = np.zeros(res_shape).astype(res.dtype)
        data_arr = tvm.nd.array(data_packed, ctx)
        weight_arr = tvm.nd.array(weight_packed, ctx)
        res_arr = tvm.nd.array(res_np, ctx)
        # Host-side reference: blocked GEMM, shift by 8, clip, cast.
        res_ref = np.zeros(res_shape).astype(env.acc_dtype)
        for b in range(batch_size // env.BATCH):
            for i in range(channel // env.BLOCK_OUT):
                for j in range(channel // env.BLOCK_IN):
                    res_ref[b, i, :] += np.dot(data_packed[b, j, :].astype(env.acc_dtype),
                                               weight_packed[i, j].T.astype(env.acc_dtype))
        res_ref = np.right_shift(res_ref, 8)
        res_ref = np.clip(res_ref, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)
        time_f = f.time_evaluator("gemm", ctx, number=20)
        cost = time_f(data_arr, weight_arr, res_arr)
        res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
                                               channel // env.BLOCK_OUT,
                                               env.BATCH,
                                               env.BLOCK_OUT)
        if check_correctness:
            tvm.testing.assert_allclose(res_unpack, res_ref)
        return cost

    def run_schedule(load_inp,
                     load_wgt,
                     gemm,
                     alu,
                     store_out,
                     print_ir,
                     check_correctness):
        """Schedule the pipeline with the given intrinsics/pragmas and run it via ``verify``."""
        s = tvm.create_schedule(res.op)
        s[data_buf].set_scope(env.inp_scope)
        s[weight_buf].set_scope(env.wgt_scope)
        s[res_gem].set_scope(env.acc_scope)
        s[res_shf].set_scope(env.acc_scope)
        s[res_min].set_scope(env.acc_scope)
        s[res_max].set_scope(env.acc_scope)

        if block:
            bblock = block // env.BATCH
            iblock = block // env.BLOCK_IN
            oblock = block // env.BLOCK_OUT
            xbo, xco, xbi, xci = s[res].op.axis
            xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
            store_pt = xb2

            s[res_gem].compute_at(s[res], xco1)
            s[res_shf].compute_at(s[res], xco1)
            s[res_min].compute_at(s[res], xco1)
            s[res_max].compute_at(s[res], xco1)

            xbo, xco, xbi, xci = s[res_gem].op.axis
            # Compute one line at a time
            ko1, ko2 = s[res_gem].split(ko, iblock)
            s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
            s[data_buf].compute_at(s[res_gem], ko1)
            s[weight_buf].compute_at(s[res_gem], ko1)
            # Use VTA instructions
            s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
            s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
            s[res_gem].tensorize(xbi, gemm)
            s[res_shf].pragma(s[res_shf].op.axis[0], alu)
            s[res_min].pragma(s[res_min].op.axis[0], alu)
            s[res_max].pragma(s[res_max].op.axis[0], alu)
            s[res].pragma(store_pt, store_out)
        else:
            xbo, xco, xbi, xci = s[res_gem].op.axis
            s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)
            # Use VTA instructions
            s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
            s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
            s[res_gem].tensorize(xbi, gemm)
            s[res_shf].pragma(s[res_shf].op.axis[0], alu)
            s[res_min].pragma(s[res_min].op.axis[0], alu)
            s[res_max].pragma(s[res_max].op.axis[0], alu)
            s[res].pragma(s[res].op.axis[0], store_out)

        if print_ir:
            print(tvm.lower(s, [data, weight, res], simple_mode=True))
        return verify(s, check_correctness)

    def gemm_normal(print_ir):
        """End-to-end run with real VTA intrinsics and correctness checking."""
        print("----- GEMM GOPS End-to-End Test-------")

        def run_test(header, print_ir, check_correctness):
            cost = run_schedule(
                env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
                print_ir, check_correctness)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir, True)

    def gemm_unittest(print_ir):
        """Time only the GEMM intrinsic; DMA and ALU are mocked."""
        mock = env.mock
        print("----- GEMM Unit Test-------")

        def run_test(header, print_ir):
            cost = run_schedule(
                mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir)

    def alu_unittest(print_ir):
        """Time only the ALU stages; DMA and GEMM are mocked."""
        mock = env.mock
        print("----- ALU Unit Test-------")

        def run_test(header, print_ir):
            cost = run_schedule(
                mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir)
        print("")

    def load_inp_unittest(print_ir):
        """Time only the input load DMA; everything else is mocked."""
        mock = env.mock
        print("----- LoadInp Unit Test-------")

        def run_test(header, print_ir):
            cost = run_schedule(
                env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)
        print("")

    def load_wgt_unittest(print_ir):
        """Time only the weight load DMA; everything else is mocked."""
        mock = env.mock
        print("----- LoadWgt Unit Test-------")

        def run_test(header, print_ir):
            cost = run_schedule(
                mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)
        print("")

    def store_out_unittest(print_ir):
        """Time only the output store DMA; everything else is mocked."""
        mock = env.mock
        print("----- StoreOut Unit Test-------")

        def run_test(header, print_ir):
            cost = run_schedule(
                mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)
        print("")

    # NOTE: load_inp_unittest, load_wgt_unittest and store_out_unittest are
    # defined above but are not run here.
    gemm_normal(False)
    gemm_unittest(False)
    alu_unittest(False)
def test_sub_index_simplify():
    """Check rewrites of subtraction over index expressions.

    RewriteChecker is defined elsewhere in this file. ck.verify(a, b) is
    presumably an assertion that the rewrite of ``a`` equals ``b``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")

    ck.verify(x + y - y, x)
    ck.verify(x + y - x, y)
    ck.verify(x - (y + x), 0 - y)
    ck.verify(x - (x + y), 0 - y)

    # subtracting an operand of min/max moves it inside
    ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x))
    ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0))
    ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x))
    ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0))

    ck.verify(x - tvm.min(x, y), tvm.max(0, x - y))
    ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0))
    ck.verify(x - tvm.max(x, y), tvm.min(0, x - y))
    ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0))

    # mul co-efficient folding
    ck.verify(x - x, 0)
    ck.verify(x * y - x, x * (y + (-1)))
    ck.verify(x * y - 10 * x, x * (y + (-10)))
    ck.verify(y * x - x * z, x * (y - z))
    ck.verify(y * x - z * x, x * (y - z))

    ck.verify(x + 10 - 20, x + (-10))

    # 4-operands pattern
    ck.verify((x + y) - (x + z), y - z)
    ck.verify((y + x) - (x + z), y - z)
    ck.verify((x + y) - (z + x), y - z)
    ck.verify((y + x) - (z + x), y - z)

    ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
    ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))

    ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
    ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y))

    ck.verify(tvm.min(x, y) - tvm.min(y, x), 0)
    ck.verify(tvm.max(x, y) - tvm.max(y, x), 0)
    ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
    ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)

    # div pattern (relies on x being bounded to [0, 1000])
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x - (x / 3) * 3, x % 3)
    ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)
def _compute(*indices):
    """Clamp ``x`` at ``indices`` into [a_min, a_max].

    The bounds are built as constants of the element's own dtype.
    """
    elem = x(*indices)
    upper = tvm.const(a_max, elem.dtype)
    lower = tvm.const(a_min, elem.dtype)
    capped = tvm.min(elem, upper)
    return tvm.max(capped, lower)
def test_make():
    """Check that tvm.max / tvm.min build Max / Min nodes from a const and a var."""
    const_one = tvm.const(1)
    var_x = tvm.var("x")
    # Smoke check: adding a constant and a variable must not raise.
    _ = const_one + var_x
    for builder, node_type in ((tvm.max, tvm.expr.Max), (tvm.min, tvm.expr.Min)):
        assert isinstance(builder(const_one, var_x), node_type)
kernel_buf[co, ic, dy, dx, ci, ic_tns].astype(env.acc_dtype), axis=[ic, dy, dx, ic_tns]), name="res_conv") # Add shift stage for fix-point normalization res_shr = tvm.compute(output_shape, lambda *i: res_conv(*i) >> 8, name="res_shr") # Apply clipping between (0, input max value) inp_max = (1 << (env.INP_WIDTH - 1)) - 1 res_max = tvm.compute(output_shape, lambda *i: tvm.max(res_shr(*i), 0), "res_max") res_min = tvm.compute(output_shape, lambda *i: tvm.min(res_max(*i), inp_max), "res_min") # Result Tensor res = tvm.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res") ###################################################################### # Scheduling the Computation # -------------------------- # We'll look at a set of schedule transformations necessary to map the # 2D convolution onto VTA in an efficient fashion. # Those include: #