def test_oneD_pool():
    """Check that loop partitioning leaves no ``IfThenElse`` in a 1-D pool.

    Builds three 16 x 3 loop nests over ``ow`` / ``kw`` whose bodies are
    guarded by ``likely`` conditions, runs ``LoopPartition`` followed by
    ``Simplify``, and asserts that no ``tvm.stmt.IfThenElse`` node is left
    in the resulting statement.
    """
    ib = tvm.ir_builder.create()
    # NOTE(review): both pointers are created with the name "A"; kept as in
    # the original test -- confirm whether the output was meant to differ.
    data = ib.pointer("float32", name="A")
    out = ib.pointer("float32", name="A")

    # Interior outputs (0 < ow < 15): every tap ow + kw - 1 is within [0, 16).
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])

    # Left edge (ow < 1): skip kw == 0, which would read index -1.
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])

    # Right edge (ow > 14): skip kw == 2, which would read index 16.
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])

    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert not any(
        collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))
def _sample(i, c, ph, pw):
    """Average ``_bilinear`` samples over one pooling bin of ROI ``i``.

    ``rois[i]`` is read as ``[batch_index, start_w, start_h, end_w, end_h]``.
    The box is scaled by ``spatial_scale``, split into
    ``pooled_size_h`` x ``pooled_size_w`` bins, and bin ``(ph, pw)`` is
    sampled on a ``grid_h`` x ``grid_w`` grid for channel ``c``.
    Returns the ``tvm.sum`` reduction of the samples divided by the grid size.
    """
    box = rois[i]
    batch_index = box[0].astype('int32')

    # Scale the box corners by spatial_scale.
    start_w = box[1] * spatial_scale
    start_h = box[2] * spatial_scale
    end_w = box[3] * spatial_scale
    end_h = box[4] * spatial_scale

    # force malformed ROIs to be 1x1
    roi_h = tvm.max(end_h - start_h, tvm.const(1.0, dtype))
    roi_w = tvm.max(end_w - start_w, tvm.const(1.0, dtype))

    bin_h = roi_h / pooled_size_h
    bin_w = roi_w / pooled_size_w

    if sample_ratio > 0:
        grid_h = grid_w = tvm.const(sample_ratio, 'int32')
    else:
        grid_h = tvm.ceil(roi_h / pooled_size_h).astype('int32')
        grid_w = tvm.ceil(roi_w / pooled_size_w).astype('int32')

    count = grid_h * grid_w
    rh = tvm.reduce_axis((0, grid_h))
    rw = tvm.reduce_axis((0, grid_w))

    # Move to the top-left corner of bin (ph, pw).
    start_h = start_h + ph * bin_h
    start_w = start_w + pw * bin_w

    # Sample positions sit at the centre of each grid cell inside the bin.
    sample_y = start_h + (rh + 0.5) * bin_h / grid_h
    sample_x = start_w + (rw + 0.5) * bin_w / grid_w
    return tvm.sum(_bilinear(batch_index, c, sample_y, sample_x) / count,
                   axis=[rh, rw])
def test_min_max_bound():
    """Check ``const_int_bound`` results for ``tvm.min`` / ``tvm.max``.

    Three bound configurations are used for ``x`` (finite, unbounded on both
    sides, unbounded above) while ``y`` stays in [4, 10].
    """
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    make_bound = tvm.arith.ConstIntBound

    # x in [-9, 11], y in [4, 10]
    analyzer.update(x, make_bound(-9, 11))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == -9
    assert bound.max_value == 10

    # x unbounded on both sides
    analyzer.update(x, make_bound(bound.NEG_INF, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == 10

    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF

    # x in [1, +inf)
    analyzer.update(x, make_bound(1, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF
def test_mul_index_simplify():
    """Run the multiplication cases through ``RewriteChecker.verify``.

    Each ``(lhs, rhs)`` pair is handed to ``ck.verify`` in order
    (presumably checking that ``lhs`` rewrites to ``rhs``; the checker is
    defined outside this block).
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    cases = [
        ((x + 2) * 3, x * 3 + 6),
        ((x * 2) * 3, x * 6),
        (tmin(x, y) * tmax(x, y), x * y),
        (tmax(x, y) * tmin(x, y), x * y),
        ((x - y) * (-2), (y - x) * 2),
    ]
    for lhs, rhs in cases:
        ck.verify(lhs, rhs)
def test_deduce():
    """Check ``tvm.arith.DeduceBound`` on ``a`` for several inequalities.

    Every condition is tried twice: once with the expression containing
    ``a`` on the left-hand side and once with it on the right-hand side.
    Results are compared through their string form.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')
    d = tvm.var('d')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(10, 15)
    d_s = tvm.arith.intset_interval(-3, -1)
    zero = tvm.const(0, "int32")

    def deduce(cond, relax):
        # Same domain map for every query; only the relax map varies.
        return tvm.arith.DeduceBound(a, cond, {b: b_s, c: c_s, d: d_s}, relax)

    def simplified(expr):
        return str(tvm.ir_pass.Simplify(expr))

    # (-b)*a + c - d >= 0
    e0 = (-b)*a+c-d
    ans0 = ((d - c) /(b*-1))
    for cond in (e0>=0, zero <= e0):
        assert simplified(deduce(cond, {}).max()) == str(ans0)

    # d*a + c - d >= 0
    e0 = d*a+c-d
    ans0 = ((0-c)/d + 1)
    for cond in (e0>=0, zero <= e0):
        assert simplified(deduce(cond, {}).max()) == str(ans0)

    # a*4 + b < c
    ans1 = (((c - b) + -1)/4)
    for cond in ((a*4+b < c), (c > a*4+b)):
        assert simplified(deduce(cond, {}).max()) == str(ans1)

    # Conditions for which the deduced max/min print as neg_inf/pos_inf.
    for cond in ((tvm.max(5, a * 4) < 0), (zero < tvm.max(5, a * 4))):
        res2 = deduce(cond, {})
        assert str(res2.max()) == "neg_inf"
        assert str(res2.min()) == "pos_inf"

    # (-b) + a*c - d >= 0 with b and d passed in the relax map
    e3 = (-b)+a*c-d
    ans3 = 2/c+1
    for cond in (e3>=0, zero <= e3):
        res3 = deduce(cond, {b: b_s, d: d_s})
        assert simplified(res3.min()) == str(ans3)
def cb(src, dst, pad_before, pad_after, pad_value):
    """Callback that checks the buffers and padding it receives.

    Asserts the destination offset is zero, then compares the source offset,
    the first-axis padding and the first-axis source extent against
    expressions in the enclosing-scope variable ``xo``.
    Returns ``tvm.make.Evaluate(0)``.
    """
    assert dst.elem_offset.value == 0
    assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)

    expected_before = tvm.max(1 - xo * 4, 0)
    expected_after = tvm.max(xo * 4 - 7, 0)
    assert_expr_equal(pad_before[0], expected_before)
    assert_expr_equal(pad_after[0], expected_after)

    # The source extent is 6 minus whatever is padded on either side.
    assert_expr_equal(src.shape[0], 6 - expected_before - expected_after)
    return tvm.make.Evaluate(0)
def test_simplify_minmax():
    """``op(x, 1) - op(x, 1)`` must canonical-simplify to 0 for max and min."""
    x = tvm.var('x')
    for op in (tvm.max, tvm.min):
        diff = op(x, 1) - op(x, 1)
        simplified = tvm.ir_pass.CanonicalSimplify(diff)
        assert simplified.value == 0
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap (intersection over union) of two boxes.

    Each box is read from ``out_tensor`` as four consecutive values starting
    at its base index; offsets 0/2 are used as one axis and 1/3 as the other.
    Returns 0.0 when the union is not positive.
    """
    a_lo0, a_lo1 = out_tensor[box_a_idx], out_tensor[box_a_idx + 1]
    a_hi0, a_hi1 = out_tensor[box_a_idx + 2], out_tensor[box_a_idx + 3]
    b_lo0, b_lo1 = out_tensor[box_b_idx], out_tensor[box_b_idx + 1]
    b_hi0, b_hi1 = out_tensor[box_b_idx + 2], out_tensor[box_b_idx + 3]

    # Intersection extents, clamped at zero for disjoint boxes.
    w = tvm.max(0.0, tvm.min(a_hi0, b_hi0) - tvm.max(a_lo0, b_lo0))
    h = tvm.max(0.0, tvm.min(a_hi1, b_hi1) - tvm.max(a_lo1, b_lo1))
    inter = w * h

    area_a = (a_hi0 - a_lo0) * (a_hi1 - a_lo1)
    area_b = (b_hi0 - b_lo0) * (b_hi1 - b_lo1)
    union = area_a + area_b - inter
    return tvm.expr.Select(union <= 0.0, 0.0, inter / union)
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap (intersection over union) of two boxes.

    Each box is read from ``out_tensor`` as four consecutive values starting
    at its base index. Every extent has 1.0 added to it. The result is
    ``intersection / union`` with no guard on the divisor.
    """
    a_lo0, a_lo1 = out_tensor[box_a_idx], out_tensor[box_a_idx + 1]
    a_hi0, a_hi1 = out_tensor[box_a_idx + 2], out_tensor[box_a_idx + 3]
    b_lo0, b_lo1 = out_tensor[box_b_idx], out_tensor[box_b_idx + 1]
    b_hi0, b_hi1 = out_tensor[box_b_idx + 2], out_tensor[box_b_idx + 3]

    # Intersection extents, clamped at zero.
    w = tvm.max(0.0, tvm.min(a_hi0, b_hi0) - tvm.max(a_lo0, b_lo0) + 1.0)
    h = tvm.max(0.0, tvm.min(a_hi1, b_hi1) - tvm.max(a_lo1, b_lo1) + 1.0)
    inter = w * h

    area_a = (a_hi0 - a_lo0 + 1.0) * (a_hi1 - a_lo1 + 1.0)
    area_b = (b_hi0 - b_lo0 + 1.0) * (b_hi1 - b_lo1 + 1.0)
    union = area_a + area_b - inter
    return inter / union
def global_pool(data, pool_type):
    """Perform global pooling on the data

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, channel, in_height, in_width]

    pool_type : str
        Pool type, 'max' or 'avg'

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, channel, 1, 1]

    Raises
    ------
    ValueError
        If ``pool_type`` is neither 'max' nor 'avg'.
    """
    assert len(data.shape) == 4, "only support 4-dim pooling"
    batch, channel, height, width = data.shape
    out_shape = (batch, channel, 1, 1)

    # Reduce over the full spatial extent.
    dheight = tvm.reduce_axis((0, height))
    dwidth = tvm.reduce_axis((0, width))
    reduce_axes = [dheight, dwidth]

    if pool_type == 'max':
        return tvm.compute(
            out_shape,
            lambda n, c, h, w: tvm.max(data[n, c, dheight, dwidth],
                                       axis=reduce_axes),
            tag="global_pool_max")

    if pool_type == 'avg':
        # Sum first, then divide by the window size in a second stage.
        tsum = tvm.compute(
            out_shape,
            lambda n, c, h, w: tvm.sum(data[n, c, dheight, dwidth],
                                       axis=reduce_axes),
            tag="global_pool_sum")
        return tvm.compute(
            out_shape,
            lambda n, c, h, w:
            tsum[n, c, h, w] / (height*width).astype(tsum.dtype),
            tag=tag.ELEMWISE)

    raise ValueError("Pool type should be 'avg' or 'max'.")
def my_clip(x, a_min, a_max):
    """Unlike topi's current clip, put min and max into two stages.

    Stage "clipA" applies the upper limit ``a_max``; stage "clipB" then
    applies the lower limit ``a_min`` to the result of the first stage.
    """
    lower = tvm.const(a_min, x.dtype)
    upper = tvm.const(a_max, x.dtype)
    capped = tvm.compute(
        x.shape, lambda *i: tvm.min(x(*i), upper), name="clipA")
    clipped = tvm.compute(
        capped.shape, lambda *i: tvm.max(capped(*i), lower), name="clipB")
    return clipped
def test_deduce():
    """Check ``tvm.arith.DeduceBound`` on ``a`` for several inequalities.

    Results are compared through their string form.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')
    d = tvm.var('d')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(10, 15)
    d_s = tvm.arith.intset_interval(-3, -1)

    def deduce(cond, relax):
        # Same domain map for every query; only the relax map varies.
        return tvm.arith.DeduceBound(a, cond, {b: b_s, c: c_s, d: d_s}, relax)

    def simplified(expr):
        return str(tvm.ir_pass.Simplify(expr))

    # (-b)*a + c - d >= 0
    e0 = (-b)*a+c-d
    ans0 = ((d - c) /(b*-1))
    assert simplified(deduce(e0>=0, {}).max()) == str(ans0)

    # d*a + c - d >= 0
    e0 = d*a+c-d
    ans0 = ((0-c)/d + 1)
    assert simplified(deduce(e0>=0, {}).max()) == str(ans0)

    # a*4 + b < c
    e1 = (a*4+b < c)
    ans1 = (((c - b) + -1)/4)
    assert simplified(deduce(e1, {}).max()) == str(ans1)

    # Condition for which the deduced max/min print as neg_inf/pos_inf.
    e2 = (tvm.max(5, a * 4) < 0)
    res2 = deduce(e2, {})
    assert str(res2.max()) == "neg_inf"
    assert str(res2.min()) == "pos_inf"

    # (-b) + a*c - d >= 0 with b and d passed in the relax map
    e3 = (-b)+a*c-d
    ans3 = 2/c+1
    res3 = deduce(e3>=0, {b: b_s, d: d_s})
    assert simplified(res3.min()) == str(ans3)
def test_select_simplify():
    """Run the ``Select`` cases through ``RewriteChecker.verify``.

    Each ``(lhs, rhs)`` pair is handed to ``ck.verify`` in order.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    select = tvm.expr.Select

    cases = [
        # Add rules
        (select(x < 0, y, 0) + select(x < 0, 1, z),
         select(x < 0, y + 1, z)),
        (select(x < 0, y, 1) - select(x < 0, 1, z),
         select(x < 0, y + (-1), 1 - z)),
        (select(x < 0, y, z) - y, select(x < 0, 0, z - y)),
        (select(x < 0, y, z) - z, select(x < 0, y - z, 0)),
        (tvm.min(select(x < 0, y, 0), select(x < 0, 1, z)),
         select(x < 0, tvm.min(y, 1), tvm.min(0, z))),
        (tvm.max(select(x < 0, y, 0), select(x < 0, 1, z)),
         select(x < 0, tvm.max(y, 1), tvm.max(0, z))),
        # Conditions and branches that collapse the select.
        (select(x * 3 + 1 != 0, y, z), y),
        (select(x * 3 + 1 == 0, y, z), z),
        (select(x > 0, y + 1, y + 1), y + 1),
    ]
    for lhs, rhs in cases:
        ck.verify(lhs, rhs)
def test_div_index_simplify():
    """Run the division cases through ``RewriteChecker.verify``.

    ``x``, ``y`` and ``z`` are all bounded to [0, 1000] before any case runs.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    for var in (x, y, z):
        ck.analyzer.update(var, tvm.arith.ConstIntBound(0, 1000), override=True)

    cases = [
        (x / 2 / 3, x / 6),
        ((x / 2 + 1) / 3, (x + 2) / 6),
        (x * 2 / 4, x / 2),
        (x * 4 / 2, x * 2),
        ((x * 4 + y) / 2, x * 2 + y / 2),
        (tmin(x * 6, y) / 2, tmin(x * 3, y / 2)),
        (tmax(x * 6, y) / 2, tmax(x * 3, y / 2)),
        ((y + x * 4) / 2, y / 2 + x * 2),
        (tmin(y, x * 6) / 2, tmin(y / 2, x * 3)),
        (tmax(y, x * 6) / 2, tmax(y / 2, x * 3)),
        # 3-operands
        ((x * 6 + y + z) / 2, x * 3 + (y + z) / 2),
        ((x * 6 - y + (y + 3)) / 2, x * 3 + 1),
        ((x * 6 + (y + 3) - y) / 2, x * 3 + 1),
        ((y + x * 6 + z) / 2, x * 3 + (y + z) / 2),
        ((x + 4) / 2, x / 2 + 2),
        ((x + y) / x, y / x + 1),
        ((y + x) / x, y / x + 1),
        (((x + y) + z) / x, (y + z) / x + 1),
        (((y + x) + z) / x, (y + z) / x + 1),
        ((y + (x + z)) / x, (y + z) / x + 1),
        ((y + (z + x)) / x, (y + z) / x + 1),
        ((x * y) / y, x),
        ((y * x) / y, x),
        ((x * z + y) / z, x + y / z),
        ((z * x + y) / z, x + y / z),
        ((y + x * z) / z, y / z + x),
        ((y + z * x) / z, y / z + x),
    ]
    for lhs, rhs in cases:
        ck.verify(lhs, rhs)
def relu(x):
    """Take relu of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result: element-wise maximum of ``x`` and a zero constant of
        ``x.dtype``, with the same shape as ``x``.
    """
    def _relu_at(*i):
        return tvm.max(x(*i), tvm.const(0, x.dtype))

    return tvm.compute(x.shape, _relu_at)
def test_min_max_select():
    """Check ``modular_set`` coeff/base for min, max and Select expressions."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    # (expression, expected coeff, expected base)
    cases = [
        (tvm.min(x * 3, y * 9), 3, 0),
        (tvm.max(x * 3 + 1, y * 9 + 4), 3, 1),
        (tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2), 1, 0),
    ]
    for expr, coeff, base in cases:
        m = analyzer.modular_set(expr)
        assert m.coeff == coeff
        assert m.base == base
def compute_clip(attrs, inputs, _):
    """Clip operator.

    Reads the float attributes "a_min" and "a_max" and clips ``inputs[0]``
    in two element-wise stages: "clipA" applies the upper limit, "clipB"
    the lower limit. Both stages are built inside the ELEMWISE tag scope.
    """
    data = inputs[0]
    lower = tvm.const(attrs.get_float("a_min"), data.dtype)
    upper = tvm.const(attrs.get_float("a_max"), data.dtype)

    with tvm.tag_scope(topi.tag.ELEMWISE):
        capped = tvm.compute(
            data.shape, lambda *i: tvm.min(data(*i), upper), name="clipA")
        clipped = tvm.compute(
            capped.shape, lambda *i: tvm.max(capped(*i), lower), name="clipB")
    return clipped
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
    """Transform prior anchor box to output box through location predictions.

    The anchor is read as (left, top, right, bottom) and the prediction as
    (px, py, pw, ph), each from four consecutive elements starting at the
    given base index. Returns a 4-tuple of expressions
    (ox - ow, oy - oh, ox + ow, oy + oh); when ``clip`` holds, each value is
    clamped to [0.0, 1.0].
    """
    # Anchor corners, size and centre.
    al, at, ar, ab = (anchor[anchor_base_idx + k] for k in range(4))
    aw = ar - al
    ah = ab - at
    ax = (al + ar) / 2.0
    ay = (at + ab) / 2.0

    # Predicted offsets.
    px, py, pw, ph = (loc[loc_base_idx + k] for k in range(4))

    # Decoded centre and half extents.
    ox = px * vx * aw + ax
    oy = py * vy * ah + ay
    ow = exp(pw * vw) * aw / 2.0
    oh = exp(ph * vh) * ah / 2.0

    def _maybe_clip(value):
        return tvm.if_then_else(
            clip, tvm.max(0.0, tvm.min(1.0, value)), value)

    return (_maybe_clip(ox - ow),
            _maybe_clip(oy - oh),
            _maybe_clip(ox + ow),
            _maybe_clip(oy + oh))
def check_select(ctx, n, dtype):
    """Build and run a kernel computing ``max(const, Select(...))``.

    The Select picks between two constants depending on ``A[0] > 0``.
    The kernel is bound to ``threadIdx.x``, built for the enclosing-scope
    ``target`` and invoked once on uninitialised arrays; the output is not
    checked.
    """
    A = tvm.placeholder((n,), name='A', dtype=dtype)
    picked = tvm.expr.Select(A[0] > 0,
                             tvm.const(1, dtype=dtype),
                             tvm.const(3, dtype=dtype))
    floor = tvm.const(2, dtype=dtype)
    C = tvm.compute((n,), lambda i: tvm.max(floor, picked), name='C')

    s = tvm.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, C], target)

    a = tvm.nd.empty((n,), A.dtype, ctx)
    c = tvm.nd.empty((n,), A.dtype, ctx)
    # Only need to test compiling here
    fun(a, c)
def test_add_index_simplify():
    """Run the addition cases through ``RewriteChecker.verify``.

    The last case repeats ``(x / 8) * 8 + x % 8 -> x`` with ``x`` allowed to
    be -1 and requires ``verify`` to raise ``AssertionError``.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    def check_all(cases):
        for lhs, rhs in cases:
            ck.verify(lhs, rhs)

    check_all([
        (x + (y - x), y),
        (x - (y + 1) + (y + 1), x),
        ((x - 10) + (10 - z), x - z),
        ((x - y) + (z - x), z - y),
        (tmin(x, y - z) + z, tmin(x + z, y)),
        (tmin(x - z, y) + z, tmin(x, y + z)),
        (tmax(x, y - 10) + 10, tmax(x + 10, y)),
        (tmax(x - 11, y) + 11, tmax(x, y + 11)),
        (tmax(x, y * 2) + tmin(x, y * 2), x + y * 2),
        (tmin(x, y * 2) + tmax(x, y * 2), x + y * 2),
        (tmax(x, y + 2) + (-2), tmax(x + (-2), y)),
        (tmin(x, y + 2) + (-2), tmin(x + (-2), y)),
        (tmin(x + 2, y + 3) + (-2), tmin(x, y + 1)),
        (x * y + x * 10, x * (y + 10)),
        (y * x + x * 10, x * (y + 10)),
        (y * x + 10 * x, x * (y + 10)),
        (x * y + 10 * x, x * (y + 10)),
        (y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10)),
    ])

    # With x in [0, 1000] the div/mod recombination must simplify to x.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    check_all([
        ((x / 8) * 8 + x % 8, x),
        # canonicalization
        (x + 2 + 3 + 4 + x, x * 2 + 9),
        (x + 2 + 3 + 4 + x * 3, x * 4 + 9),
    ])

    # conservative bound: once x may be negative the rewrite must not apply,
    # so verify is expected to raise AssertionError.
    try:
        ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
        ck.verify((x / 8) * 8 + x % 8, x)
    except AssertionError:
        pass
    else:
        raise RuntimeError("bad")
def test_max_pool():
    """Check ``compute_flop`` on five randomly sized max-pool computes.

    The expected count is one operation per output element per window
    element: ``N * CO * OH * OW * KH * KW``.
    """
    for _ in range(5):
        dims = [np.random.randint(10, 32) for _ in range(7)]
        N, H, W, CO, CI, KH, KW = dims
        input_dtype, _ = random_dtypes()
        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)

        # The window may not exceed the input extent.
        KH, KW = min(H, KH), min(W, KW)
        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.max(D[n][co][h + kh][w + kw],
                                        axis=[kh, kw]))
        s = tvm.create_schedule([C.op])

        assert compute_flop(s) == N * CO * OH * OW * KH * KW
def log_softmax(x):
    """Perform log softmax activation on the data

    Computed per row as ``x - max(x) - log(sum(exp(x - max(x))))``, with the
    max and the sum reduced over the second axis.

    Parameters
    ----------
    x : tvm.Tensor
        2-D input data

    Returns
    -------
    output : tvm.Tensor
        2-D output with same shape
    """
    assert len(x.shape) == 2, "only support 2-dim log softmax"
    rows, cols = x.shape

    # Row-wise maximum, subtracted before exponentiation.
    k_max = tvm.reduce_axis((0, cols), name='k')
    max_elem = tvm.compute((rows, ), lambda i: tvm.max(x[i, k_max], axis=k_max))

    # Row-wise sum of the shifted exponentials.
    k_sum = tvm.reduce_axis((0, cols), name='k')
    expsum = tvm.compute(
        (rows, ),
        lambda i: tvm.sum(tvm.exp(x[i, k_sum] - max_elem[i]), axis=k_sum))

    return tvm.compute(
        x.shape, lambda i, j: x[i, j] - max_elem[i] - tvm.log(expsum[i]))
def test_max_index_simplify():
    """Run the ``tvm.max`` cases through ``RewriteChecker.verify``.

    Each ``(lhs, rhs)`` pair is handed to ``ck.verify`` in order.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    flm = tvm.floormod
    fld = tvm.floordiv
    tmin, tmax = tvm.min, tvm.max

    cases = [
        # const int bound
        (tmax(x % 2, y % 2 + 10), y % 2 + 10),
        (tmax(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10),
        (tmax(x + 1, x + 10), x + 10),
        (tmax(x + 111, x + 10), x + 111),
        (tmax(x + 1, x), x + 1),
        (tmax(x, x + 2), x + 2),
        (tmax(1 - x, 2 - x), 2 - x),
        (tmax(3 - x, 2 - x), 3 - x),
        # nested min/max
        (tmax(tmin(x, y), tmax(x, y)), tmax(x, y)),
        (tmax(tmin(x, y), tmax(y, x)), tmax(x, y)),
        (tmax(tmin(x, y), x), x),
        (tmax(tmin(y, x), x), x),
        (tmax(tmax(x, y), x), tmax(x, y)),
        (tmax(tmax(x, y), y), tmax(x, y)),
        (tmax(x, tmin(x, y)), x),
        (tmax(x, tmin(y, x)), x),
        (tmax(x, tmax(x, y)), tmax(x, y)),
        (tmax(y, tmax(x, y)), tmax(x, y)),
        (tmax(tmax(tmax(x, y), z), y), tmax(tmax(x, y), z)),
        (tmax(tmax(tmax(tmax(x, y), z), x * 2), y),
         tmax(tmax(tmax(x, y), z), x * 2)),
        (tmax(tmax(tmax(tmax(tmax(x, y), z), x * 2), z * 2), y),
         tmax(tmax(tmax(tmax(x, y), z), x * 2), z * 2)),
        (tmax(tmin(x, y), tmin(x, z)), tmin(tmax(y, z), x)),
        (tmax(tmin(x, y), tmin(z, x)), tmin(tmax(y, z), x)),
        (tmax(tmin(y, x), tmin(x, z)), tmin(tmax(y, z), x)),
        (tmax(tmin(y, x), tmin(z, x)), tmin(tmax(y, z), x)),
        # common term
        (tmax(y + x, z + x), tmax(y, z) + x),
        (tmax(y + x, x + z), tmax(y, z) + x),
        (tmax(x + y, z + x), tmax(y, z) + x),
        (tmax(x + y, x + z), tmax(y, z) + x),
        (tmax(x - y, x - z), x - tmin(y, z)),
        (tmax(y - x, z - x), tmax(y, z) - x),
        # constants
        (tmax(tmax(x, 1), 10), tmax(x, 10)),
        (tmax(tmax(x, 11), 10), tmax(x, 11)),
        (tmax(x * 3, 9), tmax(x, 3) * 3),
        (tmax(3 - x, 1), 3 - tmin(x, 2)),
        # DivMod rules
        # trunc div
        (tmax(x / 10, y / 10), tmax(x, y) / 10),
        (tmax(x / (-10), y / (-10)), tmin(x, y) / (-10)),
        (tmax((x + 3) / 4 * 4, x), (x + 3) / 4 * 4),
        # floordiv
        (tmax(fld(x, 10), fld(y, 10)), fld(tmax(x, y), 10)),
        (tmax(fld(x, (-10)), fld(y, (-10))), fld(tmin(x, y), (-10))),
        (tmax(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4),
        (tmax(fld(x, 4) * 4, x), x),
        (tmax(x, fld(x, 4) * 4), x),
    ]
    for lhs, rhs in cases:
        ck.verify(lhs, rhs)
output_shape, lambda bo, co, i, j, bi, ci: tvm.sum(data_buf[ bo, ic, i * stride_h + dy, j * stride_w + dx, bi, ic_tns].astype( env.acc_dtype) * kernel_buf[co, ic, dy, dx, ci, ic_tns].astype( env.acc_dtype), axis=[ic, dy, dx, ic_tns]), name="res_conv") # Add shift stage for fix-point normalization res_shr = tvm.compute(output_shape, lambda *i: res_conv(*i) >> 8, name="res_shr") # Apply clipping between (0, input max value) inp_max = (1 << (env.INP_WIDTH - 1)) - 1 res_max = tvm.compute(output_shape, lambda *i: tvm.max(res_shr(*i), 0), "res_max") res_min = tvm.compute(output_shape, lambda *i: tvm.min(res_max(*i), inp_max), "res_min") # Result Tensor res = tvm.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res") ###################################################################### # Scheduling the Computation # -------------------------- # We'll look at a set of schedule transformations necessary to map the # 2D convolution onto VTA in an efficient fashion. # Those include:
def test_make():
    """``tvm.max`` / ``tvm.min`` of a const and a var build Max / Min nodes."""
    x = tvm.const(1)
    y = tvm.var("x")
    # The sum is built but not inspected.
    z = x + y
    for build, node_type in ((tvm.max, tvm.expr.Max), (tvm.min, tvm.expr.Min)):
        assert isinstance(build(x, y), node_type)
def _compute(*indices):
    """Return ``x(*indices)`` clipped to ``[a_min, a_max]``.

    ``x``, ``a_min`` and ``a_max`` come from the enclosing scope; the limits
    are turned into constants of the element's dtype. The upper limit is
    applied first, then the lower one.
    """
    elem = x(*indices)
    lower = tvm.const(a_min, elem.dtype)
    upper = tvm.const(a_max, elem.dtype)
    capped = tvm.min(elem, upper)
    return tvm.max(capped, lower)
def test_cmp_simplify():
    """Run the comparison cases through ``RewriteChecker.verify``.

    The first group runs with unbounded variables; the second group runs
    after ``x``, ``y`` and ``z`` are bounded to [0, 10], [-10, 0] and
    [-5, 5] respectively.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    flm = tvm.floormod
    fld = tvm.floordiv
    LT, LE = tvm.expr.LT, tvm.expr.LE
    true = tvm.const(1, "bool")
    false = tvm.const(0, "bool")

    def check_all(cases):
        for lhs, rhs in cases:
            ck.verify(lhs, rhs)

    check_all([
        # const int bound
        ((x % 2 + 10).equal(0), false),
        (tvm.expr.NE(x % 2 + 10, 0), true),
        (x % 2 + 10 > 1, true),
        (x % 2 + 10 <= 1, false),
        (flm(x, 2) + 2 > 1, true),
        (flm(x, 2) + 10 <= 1, false),
        (x * 3 + 10 == 0, false),
        (x * 3 + 10 != 0, true),
        # canonicalization
        ((x - 10).equal(0), x.equal(10)),
        ((10 - x).equal(0), x.equal(10)),
        ((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0))),
        # cmp bound
        (x + y < x + z, y < z),
        (x + y < z + x, y < z),
        (y + x < x + z, y < z),
        (y + x < z + x, y < z),
        (y - x < z - x, y < z),
        (x - y < x - z, z < y),
        (x < z + x, LT(0, z)),
        (x < x + z, LT(0, z)),
        (100 < x + 1, LT(99, x)),
        (1 < 100 - x, LT(x, 99)),
        (x * 3 < y * 3, x < y),
        (x * (-3) < y * (-3), y < x),
        (x * 3 >= y * 3, y <= x),
        (x * 4 >= 2, LE(1, x)),
        (x * 2 >= 50, LE(25, x)),
        (x * 4 <= 2, x <= 0),
        ((0 - x * 3) <= 0, LE(0, x)),
        ((0 - x * 3) >= 0, LE(x, 0)),
        (2 * x <= 0, x <= 0),
        # x * 2 against small constants
        (x * 2 >= 3, LE(2, x)),
        (x * 2 >= 2, LE(1, x)),
        (x * 2 >= 1, LE(1, x)),
        (x * 2 >= 0, LE(0, x)),
        (x * 2 >= -1, LE(0, x)),
        (x * 2 >= -2, LE(-1, x)),
        (x * 2 >= -3, LE(-1, x)),
        (x * 2 <= 3, LE(x, 1)),
        (x * 2 <= 2, LE(x, 1)),
        (x * 2 <= 1, LE(x, 0)),
        (x * 2 <= 0, LE(x, 0)),
        (x * 2 <= -1, LE(x, -1)),
        (x * 2 <= -2, LE(x, -1)),
        (x * 2 <= -3, LE(x, -2)),
        # x * (-2) against small constants
        (x * (-2) >= 3, LE(x, -2)),
        (x * (-2) >= 2, LE(x, -1)),
        (x * (-2) >= 1, LE(x, -1)),
        (x * (-2) >= 0, LE(x, 0)),
        (x * (-2) >= -1, LE(x, 0)),
        (x * (-2) >= -2, LE(x, 1)),
        (x * (-2) >= -3, LE(x, 1)),
        (x * (-2) <= 3, LE(-1, x)),
        (x * (-2) <= 2, LE(-1, x)),
        (x * (-2) <= 1, LE(0, x)),
        (x * (-2) <= 0, LE(0, x)),
        (x * (-2) <= -1, LE(1, x)),
        (x * (-2) <= -2, LE(1, x)),
        (x * (-2) <= -3, LE(2, x)),
        # DivMod rules
        # trunc div
        (x / 2 < 3, x < 6),
        (3 < x / 2, LT(7, x)),
        (x / 3 >= 0, LE(-2, x)),
        (x / 2 >= 1, LE(2, x)),
        (x / 2 >= 0, LE(-1, x)),
        (x / 2 >= -1, LE(-3, x)),
        (x / 2 <= 1, LE(x, 3)),
        (x / 2 <= 0, LE(x, 1)),
        (x / 2 <= -1, LE(x, -2)),
        (x / 4 * 4 < x, LT(0, x % 4)),
        (x / 4 * 4 >= x, LE(x % 4, 0)),
        (x / 4 * 4 < x + y, LT(0, x % 4 + y)),
        (x / 4 * 4 < x - y, LT(y, x % 4)),
        ((x + 2) / 4 * 4 >= x, LE((x + 2) % 4, 2)),
        ((x + 2) / 4 * 4 >= x + y, LE((x + 2) % 4 + y, 2)),
        ((x + 2) / 4 * 4 >= x - y, LE((x + 2) % 4 + (-2), y)),
        # floor div
        (fld(x, 2) < 3, x < 6),
        (3 < fld(x, 2), LT(7, x)),
        (-3 < fld(x, 2), LT(-5, x)),
        (fld(x, 3) >= 0, LE(0, x)),
        (fld(x, 2) >= 1, LE(2, x)),
        (fld(x, 2) >= 0, LE(0, x)),
        (fld(x, 2) >= -1, LE(-2, x)),
        (fld(x, 2) <= 1, LE(x, 3)),
        (fld(x, 2) <= 0, LE(x, 1)),
        (fld(x, 2) <= -1, LE(x, -1)),
        (fld(x, 4) * 4 < x, LT(0, flm(x, 4))),
        (fld(x, 4) * 4 >= x, LE(flm(x, 4), 0)),
        (fld(x, 4) * 4 < x + y, LT(0, flm(x, 4) + y)),
        (fld(x, 4) * 4 < x - y, LT(y, flm(x, 4))),
        (fld(x + 2, 4) * 4 >= x, LE(flm(x + 2, 4), 2)),
        (fld(x + 2, 4) * 4 >= x + y, LE(flm(x + 2, 4) + y, 2)),
        (fld(x + 2, 4) * 4 >= x - y, LE(flm(x + 2, 4) + (-2), y)),
        # End DivMod Rules
        (tvm.min(x, 11) < 10, x < 10),
        (tvm.min(x, 8) < 10, true),
        (tvm.max(8, x) > 10, LT(10, x)),
        (x + 1 < tvm.max(8, x), x < 7),
    ])

    # Bound the variables; the remaining cases rely on these ranges.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)

    check_all([
        (x < 11, true),
        (x <= 10, true),
        (z <= 5, true),
        (x + y <= 10, true),
        (x + y >= -10, true),
        (z - 5 <= y + 10, true),
        (tvm.all(x > -1, z <= x + 5), true),
        (x * y <= 0, true),
        ((x + 1) * (y - 1) < 0, true),
        (y * y >= 0, true),
        (x * 6 <= -3, false),
        ((y - 1) % 3 == 0, (y + (-1)) % 3 == 0),
    ])
def test_vector_simplify():
    """Run the vector (Ramp / Broadcast / vector-cast) cases through
    ``RewriteChecker.verify``.

    ``x`` is bounded to [0, 1000] part-way through the division cases and
    the bound is re-applied before the last two modulo cases.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    ramp = tvm.expr.Ramp
    # Scalars cast to 2-lane int32 vectors.
    xv = x.astype("int32x2")
    yv = y.astype("int32x2")

    def check_all(cases):
        for lhs, rhs in cases:
            ck.verify(lhs, rhs)

    check_all([
        # Add rules
        (ramp(x, 1, 4) + ramp(y, 2, 4), ramp(x + y, 3, 4)),
        (ramp(x, 1, 2) + y, ramp(x + y, 1, 2)),
        (y + ramp(x, 1, 2), ramp(y + x, 1, 2)),
        (yv + xv, (y + x).astype("int32x2")),
        # Sub rules
        (ramp(x, 4, 4) - ramp(y, 2, 4), ramp(x - y, 2, 4)),
        (ramp(x, 1, 2) - y, ramp(x - y, 1, 2)),
        (y - ramp(x, 1, 2), ramp(y - x, -1, 2)),
        (yv - xv, (y - x).astype("int32x2")),
        # Mul rules
        (yv * xv, (y * x).astype("int32x2")),
        (ramp(x, 4, 4) * 2, ramp(x * 2, 8, 4)),
        (2 * ramp(x, 4, 4), ramp(x * 2, 8, 4)),
        ## Div rules
        (yv / xv, (y / x).astype("int32x2")),
        (ramp(x, 4, 4) / 2, ramp(x/ 2, 2, 4)),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    check_all([
        (ramp(x * 8 + 1, 1, 4) / 8, (x).astype("int32x4")),
        (ramp(x * 8 + 15, 1, 4) / 8, ramp(x * 8 + 15, 1, 4) / 8),
        ## Mod rules
        (yv % xv, (y % x).astype("int32x2")),
        (ramp(x, 4, 4) % 2, tvm.expr.Broadcast(x % 2, 4)),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    check_all([
        (ramp(x * 8 + 1, 1, 4) % 8, ramp(1, 1, 4)),
        (ramp(x * 8 + 1, 15, 4) % 8, ramp(1, 15, 4) % 8),
    ])

    # Min/Max rules
    vx = tvm.var("vx", dtype="int32x2")
    vc = tvm.var("vc", dtype="uint1")
    vcv = vc.astype("uint1x2")
    check_all([
        (tvm.min(yv, xv), tvm.min(y, x).astype("int32x2")),
        (tvm.min(tvm.min(vx, yv), xv),
         tvm.min(vx, tvm.min(y, x).astype("int32x2"))),
        (tvm.max(yv, xv), tvm.max(y, x).astype("int32x2")),
        (tvm.max(tvm.max(vx, yv), xv),
         tvm.max(vx, tvm.max(y, x).astype("int32x2"))),
        ## Logical rules
        (yv.equal(xv), (y.equal(x)).astype("uint1x2")),
        (tvm.expr.NE(yv, (xv)), (tvm.expr.NE(y, x)).astype("uint1x2")),
        (yv > xv, (x < y).astype("uint1x2")),
        (yv >= xv, (x <= y).astype("uint1x2")),
        (yv < xv, (y < x).astype("uint1x2")),
        (yv <= xv, (y <= x).astype("uint1x2")),
        (tvm.expr.And(yv <= xv, vcv),
         (tvm.expr.And(y <= x, vc)).astype("uint1x2")),
        (tvm.expr.Or(yv <= xv, vcv),
         (tvm.expr.Or(y <= x, vc)).astype("uint1x2")),
    ])
def test_vector_simplify():
    """Run the vector (Ramp / Broadcast / vector-cast) cases through
    ``RewriteChecker.verify``, including floordiv / floormod variants.

    ``x`` is bounded to [0, 1000] part-way through the truncated-division
    cases and re-bounded to [-10, 1000] before the floor-division cases.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    ramp = tvm.expr.Ramp
    # Scalars cast to 2-lane int32 vectors.
    xv = x.astype("int32x2")
    yv = y.astype("int32x2")

    def check_all(cases):
        for lhs, rhs in cases:
            ck.verify(lhs, rhs)

    check_all([
        # Add rules
        (ramp(x, 1, 4) + ramp(y, 2, 4), ramp(x + y, 3, 4)),
        (ramp(x, 1, 2) + y, ramp(x + y, 1, 2)),
        (y + ramp(x, 1, 2), ramp(y + x, 1, 2)),
        (yv + xv, (y + x).astype("int32x2")),
        # Sub rules
        (ramp(x, 4, 4) - ramp(y, 2, 4), ramp(x - y, 2, 4)),
        (ramp(x, 1, 2) - y, ramp(x - y, 1, 2)),
        (y - ramp(x, 1, 2), ramp(y - x, -1, 2)),
        (yv - xv, (y - x).astype("int32x2")),
        # Mul rules
        (yv * xv, (y * x).astype("int32x2")),
        (ramp(x, 4, 4) * 2, ramp(x * 2, 8, 4)),
        (2 * ramp(x, 4, 4), ramp(x * 2, 8, 4)),
        ## DivMod rules
        # trunc div
        (yv / xv, (y / x).astype("int32x2")),
        (ramp(x, 4, 4) / 2, ramp(x / 2, 2, 4)),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    check_all([
        (ramp(x * 8 + 1, 1, 4) / 8, (x).astype("int32x4")),
        (ramp(x * 8 + 15, 1, 4) / 8, ramp(x * 8 + 15, 1, 4) / 8),
        (yv % xv, (y % x).astype("int32x2")),
        (ramp(x, 4, 4) % 2, tvm.expr.Broadcast(x % 2, 4)),
        (ramp(x * 8 + 1, 1, 4) % 8, ramp(1, 1, 4)),
        (ramp(x * 8 + 1, 15, 4) % 8, ramp(1, 15, 4) % 8),
    ])

    # floor div
    fld = tvm.floordiv
    flm = tvm.floormod
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True)
    check_all([
        (fld(yv, xv), fld(y, x).astype("int32x2")),
        (fld(ramp(x, 4, 4), 2), ramp(fld(x, 2), 2, 4)),
        (fld(ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")),
        (fld(ramp(x * 8 + 15, 1, 4), 8), fld(ramp(x * 8 + 15, 1, 4), 8)),
        (flm(yv, xv), flm(y, x).astype("int32x2")),
        (flm(ramp(x, 4, 4), 2), tvm.expr.Broadcast(flm(x, 2), 4)),
        (flm(ramp(x * 8 + 1, 1, 4), 8), ramp(1, 1, 4)),
        (flm(ramp(x * 8 + 1, 15, 4), 8), flm(ramp(1, 15, 4), 8)),
    ])

    # Min/Max rules
    vx = tvm.var("vx", dtype="int32x2")
    vc = tvm.var("vc", dtype="uint1")
    vcv = vc.astype("uint1x2")
    check_all([
        (tvm.min(yv, xv), tvm.min(y, x).astype("int32x2")),
        (tvm.min(tvm.min(vx, yv), xv),
         tvm.min(vx, tvm.min(y, x).astype("int32x2"))),
        (tvm.max(yv, xv), tvm.max(y, x).astype("int32x2")),
        (tvm.max(tvm.max(vx, yv), xv),
         tvm.max(vx, tvm.max(y, x).astype("int32x2"))),
        ## Logical rules
        (yv.equal(xv), (y.equal(x)).astype("uint1x2")),
        (tvm.expr.NE(yv, (xv)), (tvm.expr.NE(y, x)).astype("uint1x2")),
        (yv > xv, (x < y).astype("uint1x2")),
        (yv >= xv, (x <= y).astype("uint1x2")),
        (yv < xv, (y < x).astype("uint1x2")),
        (yv <= xv, (y <= x).astype("uint1x2")),
        (tvm.expr.And(yv <= xv, vcv),
         (tvm.expr.And(y <= x, vc)).astype("uint1x2")),
        (tvm.expr.Or(yv <= xv, vcv),
         (tvm.expr.Or(y <= x, vc)).astype("uint1x2")),
    ])
def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
                    feature_stride, rpn_min_size, iou_loss):
    """Predict bounding boxes based on anchors, scores and deltas.

    Parameters
    ----------
    cls_prob_buf : tvm.schedule.Buffer
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred_buf : tvm.schedule.Buffer
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info_buf : tvm.schedule.Buffer
        2-D with shape [batch, 3]

    out_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]
        The last dimension is in format of [w_start, h_start, w_end, h_end, score]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn,
        for example the product of all stride's prior to this layer.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
    num_anchors //= 2
    num_positions = batch * height * width
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_positions // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    tid = bx * max_threads + tx
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)

    p_score = ib.buffer_ptr(cls_prob_buf)
    p_delta = ib.buffer_ptr(bbox_pred_buf)
    p_im_info = ib.buffer_ptr(im_info_buf)
    p_out = ib.buffer_ptr(out_buf)

    # Neither of these depends on the anchor index, so resolve them once.
    regression_func = reg_iou if iou_loss else reg_bbox
    num_scales = len(scales)

    # One thread per (batch, h, w) position; threads past the end do nothing.
    with ib.if_scope(tid < num_positions):
        w = tid % width
        h = (tid // width) % height
        b = tid // width // height

        # Per-image values read from im_info; they are the same for every
        # anchor at this position, so build the expressions once.
        im_height = p_im_info[b * 3]
        im_width = p_im_info[b * 3 + 1]
        min_size = p_im_info[b * 3 + 2] * rpn_min_size
        real_height = (im_height / feature_stride).astype('int32')
        real_width = (im_width / feature_stride).astype('int32')

        for k in range(num_anchors):
            out_index = tid * num_anchors + k
            ratio = ratios[k // num_scales]
            scale = scales[k % num_scales]
            anchor = generate_anchor(ratio, scale, feature_stride)

            # Shift the base anchor to this feature-map position.
            x1 = anchor[0] + w * feature_stride
            y1 = anchor[1] + h * feature_stride
            x2 = anchor[2] + w * feature_stride
            y2 = anchor[3] + h * feature_stride

            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
                     for i in range(4)]
            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

            # Clamp the predicted box to [0, im_width - 1] x [0, im_height - 1].
            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

            bbox_w = pred_x2 - pred_x1 + 1.0
            bbox_h = pred_y2 - pred_y1 + 1.0

            # The score is read from channel (num_anchors + k); positions at or
            # beyond (real_height, real_width) get score -1.
            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h)
                                 * width + w]
            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
                                         -1.0, pred_score)

            p_out[out_index * 5 + 0] = pred_x1
            p_out[out_index * 5 + 1] = pred_y1
            p_out[out_index * 5 + 2] = pred_x2
            p_out[out_index * 5 + 3] = pred_y2
            p_out[out_index * 5 + 4] = pred_score

            # Boxes smaller than min_size are widened by min_size / 2 on every
            # side and their score is set to -1.
            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
                p_out[out_index * 5 + 0] -= min_size / 2.0
                p_out[out_index * 5 + 1] -= min_size / 2.0
                p_out[out_index * 5 + 2] += min_size / 2.0
                p_out[out_index * 5 + 3] += min_size / 2.0
                p_out[out_index * 5 + 4] = -1.0

    return ib.get()
def pool_nhwc(data, kernel, stride, padding, pool_type, ceil_mode=False):
    """Perform pooling on the data in NHWC layout

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_height, in_width, channel]

    kernel : list/tuple of two ints
        Kernel size, [kernel_height, kernel_width]

    stride : list/tuple of two ints
        Stride size, [stride_height, stride_width]

    padding : list/tuple of two ints
        Pad size, [pad_height, pad_width]

    pool_type : str
        Pool type, 'max' or 'avg'

    ceil_mode : bool
        Whether to use ceil when calculating the output size.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, channel]

    Raises
    ------
    ValueError
        If ``pool_type`` is neither 'max' nor 'avg'.
    """
    assert len(data.shape) == 4, "only support 4-dim pooling"
    assert len(stride) == 2, "only support 2-dim stride"
    kernel_height, kernel_width = kernel
    stride_height, stride_width = stride
    batch, height, width, channel = data.shape

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (kernel_height, kernel_width))

    if ceil_mode:
        # Additional padding to ensure we do ceil instead of floor when divide stride.
        pad_down += stride_height - 1
        pad_right += stride_width - 1

    # Only the two spatial axes (H, W) are padded; batch and channel are not.
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]

    out_height = util.simplify(
        (height - kernel_height + pad_top + pad_down) // stride_height + 1)
    out_width = util.simplify(
        (width - kernel_width + pad_left + pad_right) // stride_width + 1)

    dheight = tvm.reduce_axis((0, kernel_height))
    dwidth = tvm.reduce_axis((0, kernel_width))

    if pool_type == 'max':
        # Pad with the dtype's minimum so padded cells never win the max.
        temp = pad(data, pad_before, pad_after, name="pad_temp",
                   pad_value=tvm.min_value(data.dtype))
        return tvm.compute(
            (batch, out_height, out_width, channel),
            lambda n, h, w, c:
            tvm.max(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c],
                    axis=[dheight, dwidth]),
            tag="pool_max")
    elif pool_type == 'avg':
        # Pad with zeros, sum each window, then divide by the full kernel area
        # (padded cells are included in the divisor).
        temp = pad(data, pad_before, pad_after, name="pad_temp",
                   pad_value=tvm.const(0.).astype(data.dtype))
        tsum = tvm.compute(
            (batch, out_height, out_width, channel, ),
            lambda n, h, w, c:
            tvm.sum(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c],
                    axis=[dheight, dwidth]),
            tag="pool_avg")
        return tvm.compute(
            (batch, out_height, out_width, channel),
            lambda n, h, w, c:
            tsum[n, h, w, c] / (kernel_height*kernel_width),
            tag=tag.ELEMWISE)
    else:
        raise ValueError("Pool type should be 'avg' or 'max'.")
def test_add_index_simplify():
    """Run ``ck.verify`` on addition patterns over scalar index expressions.

    The cases are checked in order; the bound on ``x`` is registered between
    the same two cases as before.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    def verify_all(cases):
        """Pass each (expression, expected) pair to ``ck.verify`` in order."""
        for before, expected in cases:
            ck.verify(before, expected)

    verify_all([
        (x + (y - x), y),
        (x - (y + 1) + (y + 1), x),
        ((x - 10) + (10 - z), x - z),
        ((x - y) + (z - x), z - y),

        (tmin(x, y - z) + z, tmin(x + z, y)),
        (tmin(x - z, y) + z, tmin(x, y + z)),
        (tmax(x, y - 10) + 10, tmax(x + 10, y)),
        (tmax(x - 11, y) + 11, tmax(x, y + 11)),

        (tmax(x, y * 2) + tmin(x, y * 2), x + y * 2),
        (tmin(x, y * 2) + tmax(x, y * 2), x + y * 2),

        (tmax(x, y + 2) + (-2), tmax(x + (-2), y)),
        (tmin(x, y + 2) + (-2), tmin(x + (-2), y)),
        (tmin(x + 2, y + 3) + (-2), tmin(x, y + 1)),

        (tmax(0, 1 - x * 4) + x * 4, tmax(x * 4, 1)),
        (tmax(2 - x * 4, 0) + x * 4, tmax(x * 4, 2)),

        (tmin(0, 1 - x * 4) + x * 4, tmin(x * 4, 1)),
        (tmin(2 - x * 4, 0) + x * 4, tmin(x * 4, 2)),

        (x * y + x * 10, x * (y + 10)),
        (y * x + x * 10, x * (y + 10)),
        (y * x + 10 * x, x * (y + 10)),
        (x * y + 10 * x, x * (y + 10)),

        # canonicalization
        (x + 2 + 3 + 4 + x, x * 2 + 9),
        (x + 2 + 3 + 4 + x * 3, x * 4 + 9),

        # DivMod rules, trunc div
        (y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10)),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
    ck.verify((x / 8) * 8 + x % 8, x)

    # floor div
    fld = tvm.floordiv
    flm = tvm.floormod
    verify_all([
        (y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)),
        (fld(x, 8) * 8 + flm(x, 8), x),
    ])
def test_max_index_simplify():
    """Run ``ck.verify`` on ``tvm.max`` patterns over scalar index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    cases = [
        # const int bound
        (tmax(x % 2, y % 2 + 10), y % 2 + 10),

        (tmax(x + 1, x + 10), x + 10),
        (tmax(x + 111, x + 10), x + 111),
        (tmax(x + 1, x), x + 1),
        (tmax(x, x + 2), x + 2),
        (tmax(1 - x, 2 - x), 2 - x),
        (tmax(3 - x, 2 - x), 3 - x),
        (tmax((x + 3) / 4 * 4, x), (x + 3) / 4 * 4),

        (tmax(tmin(x, y), tmax(x, y)), tmax(x, y)),
        (tmax(tmin(x, y), tmax(y, x)), tmax(x, y)),

        (tmax(tmin(x, y), x), x),
        (tmax(tmin(y, x), x), x),
        (tmax(tmax(x, y), x), tmax(x, y)),
        (tmax(tmax(x, y), y), tmax(x, y)),

        (tmax(x, tmin(x, y)), x),
        (tmax(x, tmin(y, x)), x),
        (tmax(x, tmax(x, y)), tmax(x, y)),
        (tmax(y, tmax(x, y)), tmax(x, y)),

        # Nested max chains where the outermost operand is already present.
        (tmax(tmax(tmax(x, y), z), y),
         tmax(tmax(x, y), z)),
        (tmax(tmax(tmax(tmax(x, y), z), x * 2), y),
         tmax(tmax(tmax(x, y), z), x * 2)),
        (tmax(tmax(tmax(tmax(tmax(x, y), z), x * 2), z * 2), y),
         tmax(tmax(tmax(tmax(x, y), z), x * 2), z * 2)),

        (tmax(tmin(x, y), tmin(x, z)), tmin(tmax(y, z), x)),
        (tmax(tmin(x, y), tmin(z, x)), tmin(tmax(y, z), x)),
        (tmax(tmin(y, x), tmin(x, z)), tmin(tmax(y, z), x)),
        (tmax(tmin(y, x), tmin(z, x)), tmin(tmax(y, z), x)),

        (tmax(y + x, z + x), tmax(y, z) + x),
        (tmax(y + x, x + z), tmax(y, z) + x),
        (tmax(x + y, z + x), tmax(y, z) + x),
        (tmax(x + y, x + z), tmax(y, z) + x),
        (tmax(x - y, x - z), x - tmin(y, z)),
        (tmax(y - x, z - x), tmax(y, z) - x),

        (tmax(tmax(x, 1), 10), tmax(x, 10)),
        (tmax(tmax(x, 11), 10), tmax(x, 11)),

        (tmax(x / 10, y / 10), tmax(x, y) / 10),
        (tmax(x / (-10), y / (-10)), tmin(x, y) / (-10)),
        (tmax(x * 3, 9), tmax(x, 3) * 3),
    ]
    for before, expected in cases:
        ck.verify(before, expected)
def test_sub_index_simplify():
    """Run ``ck.verify`` on subtraction patterns over scalar index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")

    ck.verify(x + y - y, x)
    ck.verify(x + y - x, y)
    ck.verify(x - (y + x), 0 - y)
    ck.verify(x - (x + y), 0 - y)

    ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x))
    ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0))
    ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x))
    ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0))

    ck.verify(x - tvm.min(x, y), tvm.max(0, x - y))
    ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0))
    ck.verify(x - tvm.max(x, y), tvm.min(0, x - y))
    ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0))

    # mul coefficient folding
    ck.verify(x - x, 0)
    ck.verify(x * y - x, x * (y + (-1)))
    ck.verify(x * y - 10 * x, x * (y + (-10)))
    ck.verify(y * x - x * z, x * (y - z))
    ck.verify(y * x - z * x, x * (y - z))

    ck.verify(x + 10 - 20, x + (-10))

    # 4-operands pattern
    ck.verify((x + y) - (x + z), y - z)
    ck.verify((y + x) - (x + z), y - z)
    ck.verify((x + y) - (z + x), y - z)
    ck.verify((y + x) - (z + x), y - z)

    ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
    ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))

    ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
    ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y))

    ck.verify(tvm.min(x, y) - tvm.min(y, x), 0)
    ck.verify(tvm.max(x, y) - tvm.max(y, x), 0)
    ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
    ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)

    # div pattern (x restricted to [0, 1000])
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x - (x / 3) * 3, x % 3)
    # With x = 3*q + r and 0 <= r < 3, (x + 5) / 3 - x / 3 == (r + 5) / 3.
    # The previous expectation (((x + 2) % 3) + 5) / 3 is not equivalent:
    # at x = 0 the input evaluates to 1 while that form evaluates to 2.
    # NOTE(review): if this fails, the rewriter still emits the unsound form.
    ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5) / 3)
def test_sub_index_simplify():
    """Run ``ck.verify`` on subtraction patterns over scalar index expressions.

    Three groups are checked in order: patterns with no registered bounds,
    trunc div/mod patterns with ``x`` in [0, 1000], and floor div/mod patterns
    with ``x`` and ``y`` in [-1000, 1000].
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    def verify_all(cases):
        """Pass each (expression, expected) pair to ``ck.verify`` in order."""
        for before, expected in cases:
            ck.verify(before, expected)

    verify_all([
        (x + y - y, x),
        (x + y - x, y),
        (x - (y + x), 0 - y),
        (x - (x + y), 0 - y),

        (tmin(x, y) - x, tmin(0, y - x)),
        (tmin(x, y) - y, tmin(x - y, 0)),
        (tmax(x, y) - x, tmax(0, y - x)),
        (tmax(x, y) - y, tmax(x - y, 0)),

        (x - tmin(x, y), tmax(0, x - y)),
        (y - tmin(x, y), tmax(y - x, 0)),
        (x - tmax(x, y), tmin(0, x - y)),
        (y - tmax(x, y), tmin(y - x, 0)),

        # mul coefficient folding
        (x - x, 0),
        (x * y - x, x * (y + (-1))),
        (x * y - 10 * x, x * (y + (-10))),
        (y * x - x * z, x * (y - z)),
        (y * x - z * x, x * (y - z)),

        (x + 10 - 20, x + (-10)),

        # 4-operands pattern
        ((x + y) - (x + z), y - z),
        ((y + x) - (x + z), y - z),
        ((x + y) - (z + x), y - z),
        ((y + x) - (z + x), y - z),

        (tmin(x + y, z) - x, tmin(y, z - x)),
        (tmin(y + x, z) - x, tmin(y, z - x)),
        (tmin(z, x + y) - x, tmin(z - x, y)),
        (tmin(z, y + x) - x, tmin(z - x, y)),

        (tmax(x + y, z) - x, tmax(y, z - x)),
        (tmax(y + x, z) - x, tmax(y, z - x)),
        (tmax(z, x + y) - x, tmax(z - x, y)),
        (tmax(z, y + x) - x, tmax(z - x, y)),

        (x - tmin(x + y, z), tmax(0 - y, x - z)),
        (x - tmin(y + x, z), tmax(0 - y, x - z)),
        (x - tmin(z, x + y), tmax(x - z, 0 - y)),
        (x - tmin(z, y + x), tmax(x - z, 0 - y)),

        (tmin(x, y) - tmin(y, x), 0),
        (tmax(x, y) - tmax(y, x), 0),
        (tmin(x, y) - tmin(x + 10, y + 10), -10),
        (tmin(x + 10, y + 1) - tmin(x, y - 9), 10),
    ])

    # DivMod patterns, trunc div (x restricted to [0, 1000])
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    verify_all([
        (x - (x / 3) * 3, x % 3),

        ((x + 5) / 3 - x / 3, ((x % 3) + 5) / 3),
        ((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4) / 3),

        (y - (y / (-5)) * (-5), y % 5),
        ((y / 3) * 3 - y, 0 - y % 3),
        (y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6),
        (((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5),
        (y - ((y + z) / 5) * 5, (y + z) % 5 - z),
        (((y + z) / 5) * 5 - y, z - (y + z) % 5),
        (y - ((y - z) / 5) * 5, (y - z) % 5 + z),
        (((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z),

        (y * 3 - (y / 2) * 6, (y % 2) * 3),
        ((y / 3) * 6 - y * 2, (y % 3) * (-2)),
        (y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5),
        (y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5),
        (((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2),
        (((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2),
        (5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5),
        (5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5),
        (6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2),
        (((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2),
    ])

    # floor div (x and y restricted to [-1000, 1000])
    fld = tvm.floordiv
    flm = tvm.floormod
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True)
    verify_all([
        (x - fld(x, 3) * 3, flm(x, 3)),
        (fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
        (fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),

        (fld(y, 3) * 3 - y, 0 - flm(y, 3)),
        (y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
        (fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
        (y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
        (fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
        (y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
        (fld(y - z, 5) * 5 - y, 0 - flm(y - z, 5) - z),

        (y * 3 - fld(y, 2) * 6, flm(y, 2) * 3),
        (fld(y, 3) * 6 - y * 2, flm(y, 3) * (-2)),
        (y * 5 - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5),
        (y * 5 - fld(y - z, 2) * 10, (flm(y - z, 2) + z) * 5),
        (fld(y + z, 3) * 6 - y * 2, (z - flm(y + z, 3)) * 2),
        (fld(y - z, 3) * 6 - y * 2, (0 - flm(y - z, 3) - z) * 2),
        (5 * y - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5),
        (5 * y - 10 * fld(y - z, 2), (flm(y - z, 2) + z) * 5),
        (6 * fld(y + z, 3) - y * 2, (z - flm(y + z, 3)) * 2),
        (fld(y - z, 3) * 6 - 2 * y, (0 - flm(y - z, 3) - z) * 2),
    ])
def test_sub_index_simplify():
    """Run ``ck.verify`` on subtraction patterns over scalar index expressions."""
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")

    ck.verify(x + y - y, x)
    ck.verify(x + y - x, y)
    ck.verify(x - (y + x), 0 - y)
    ck.verify(x - (x + y), 0 - y)

    ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x))
    ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0))
    ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x))
    ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0))

    ck.verify(x - tvm.min(x, y), tvm.max(0, x - y))
    ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0))
    ck.verify(x - tvm.max(x, y), tvm.min(0, x - y))
    ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0))

    # mul coefficient folding
    ck.verify(x - x, 0)
    ck.verify(x * y - x, x * (y + (-1)))
    ck.verify(x * y - 10 * x, x * (y + (-10)))
    ck.verify(y * x - x * z, x * (y - z))
    ck.verify(y * x - z * x, x * (y - z))

    ck.verify(x + 10 - 20, x + (-10))

    # 4-operands pattern
    ck.verify((x + y) - (x + z), y - z)
    ck.verify((y + x) - (x + z), y - z)
    ck.verify((x + y) - (z + x), y - z)
    ck.verify((y + x) - (z + x), y - z)

    ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x))
    ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
    ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))

    ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
    ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
    ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y))

    ck.verify(tvm.min(x, y) - tvm.min(y, x), 0)
    ck.verify(tvm.max(x, y) - tvm.max(y, x), 0)
    ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
    ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)

    # div pattern (x restricted to [0, 1000])
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x - (x / 3) * 3, x % 3)
    # With x = 3*q + r and 0 <= r < 3, (x + 5) / 3 - x / 3 == (r + 5) / 3.
    # The previous expectation (((x + 2) % 3) + 5) / 3 is not equivalent:
    # at x = 0 the input evaluates to 1 while that form evaluates to 2.
    # NOTE(review): if this fails, the rewriter still emits the unsound form.
    ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5) / 3)
def test_min_index_simplify():
    """Run ``ck.verify`` on ``tvm.min`` patterns over scalar index expressions.

    The bound on ``x`` is first registered as [0, 1000] (without override) and
    later replaced by [-1000, 1000]; each group runs under the bound that was
    active for it before.
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    tmin, tmax = tvm.min, tvm.max

    def verify_all(cases):
        """Pass each (expression, expected) pair to ``ck.verify`` in order."""
        for before, expected in cases:
            ck.verify(before, expected)

    verify_all([
        # const int bound
        (tmin(x % 2, y % 2 + 10), x % 2),

        (tmin(x + 1, x + 10), x + 1),
        (tmin(x + 111, x + 10), x + 10),
        (tmin(x + 1, x), x),
        (tmin(x, x + 2), x),
        (tmin(1 - x, 2 - x), 1 - x),
        (tmin(3 - x, 2 - x), 2 - x),

        (tmin((x + 3) / 4 * 4, x), x),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
    verify_all([
        (tmin((x + 3) / 4 * 4, tmax(x, 4)), tmax(x, 4)),
        (tmin(x, (x + 3) / 4 * 4), x),
        (tmin(tmax(x, 4), (x + 3) / 4 * 4), tmax(x, 4)),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True)
    verify_all([
        (tmin(tmax(x, y), tmin(x, y)), tmin(x, y)),
        (tmin(tmax(x, y), tmin(y, x)), tmin(x, y)),

        (tmin(tmax(x, y), x), x),
        (tmin(tmax(y, x), x), x),
        (tmin(tmin(x, y), x), tmin(x, y)),
        (tmin(tmin(x, y), y), tmin(x, y)),

        (tmin(x, tmax(x, y)), x),
        (tmin(x, tmax(y, x)), x),
        (tmin(x, tmin(x, y)), tmin(x, y)),
        (tmin(y, tmin(x, y)), tmin(x, y)),

        # Nested min chains where the outermost operand is already present.
        (tmin(tmin(tmin(x, y), z), y),
         tmin(tmin(x, y), z)),
        (tmin(tmin(tmin(tmin(x, y), z), x * 2), y),
         tmin(tmin(tmin(x, y), z), x * 2)),
        (tmin(tmin(tmin(tmin(tmin(x, y), z), x * 2), z * 2), y),
         tmin(tmin(tmin(tmin(x, y), z), x * 2), z * 2)),

        (tmin(tmax(x, y), tmax(x, z)), tmax(tmin(y, z), x)),
        (tmin(tmax(x, y), tmax(z, x)), tmax(tmin(y, z), x)),
        (tmin(tmax(y, x), tmax(x, z)), tmax(tmin(y, z), x)),
        (tmin(tmax(y, x), tmax(z, x)), tmax(tmin(y, z), x)),

        (tmin(y + x, z + x), tmin(y, z) + x),
        (tmin(y + x, x + z), tmin(y, z) + x),
        (tmin(x + y, z + x), tmin(y, z) + x),
        (tmin(x + y, x + z), tmin(y, z) + x),
        (tmin(x - y, x - z), x - tmax(y, z)),
        (tmin(y - x, z - x), tmin(y, z) - x),

        (tmin(tmin(x, 1), 10), tmin(x, 1)),
        (tmin(tmin(x, 11), 10), tmin(x, 10)),

        (tmin(x / 10, y / 10), tmin(x, y) / 10),
        (tmin(x / (-10), y / (-10)), tmax(x, y) / (-10)),
        (tmin(x * 3, 9), tmin(x, 3) * 3),
    ])
def test_cmp_simplify():
    """Run ``ck.verify`` on comparison patterns over scalar index expressions.

    The first table runs with no registered bounds; the second runs after
    ``x``, ``y`` and ``z`` are bounded to [0, 10], [-10, 0] and [-5, 5].
    """
    ck = RewriteChecker()
    x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
    lt = tvm.expr.LT
    le = tvm.expr.LE
    false_ = tvm.const(0, "bool")
    true_ = tvm.const(1, "bool")

    def verify_all(cases):
        """Pass each (expression, expected) pair to ``ck.verify`` in order."""
        for before, expected in cases:
            ck.verify(before, expected)

    verify_all([
        # const int bound
        ((x % 2 + 10).equal(0), false_),
        (tvm.expr.NE(x % 2 + 10, 0), true_),
        (x % 2 + 10 > 1, true_),
        (x % 2 + 10 <= 1, false_),
        (x * 3 + 10 == 0, false_),
        (x * 3 + 10 != 0, true_),

        # canonicalization
        ((x - 10).equal(0), x.equal(10)),
        ((10 - x).equal(0), x.equal(10)),
        ((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0))),

        # cmp bound
        (x + y < x + z, y < z),
        (x + y < z + x, y < z),
        (y + x < x + z, y < z),
        (y + x < z + x, y < z),
        (y - x < z - x, y < z),
        (x - y < x - z, z < y),

        (x < z + x, lt(0, z)),
        (x < x + z, lt(0, z)),

        (100 < x + 1, lt(99, x)),
        (1 < 100 - x, lt(x, 99)),
        (x * 3 < y * 3, x < y),
        (x * (-3) < y * (-3), y < x),
        (x * 3 >= y * 3, y <= x),

        (x * 4 >= 2, le(1, x)),
        (x * 2 >= 50, le(25, x)),
        (x / 2 < 3, x < 6),
        (x * 4 <= 2, x <= 0),
        (3 < x / 2, lt(7, x)),

        (x / 4 * 4 < x, lt(0, x % 4)),
        (x / 4 * 4 >= x, le(x % 4, 0)),

        (x / 4 * 4 < x + y, lt(0, x % 4 + y)),
        (x / 4 * 4 < x - y, lt(y, x % 4)),

        ((x + 2) / 4 * 4 >= x, le((x + 2) % 4, 2)),
        ((x + 2) / 4 * 4 >= x + y, le((x + 2) % 4 + y, 2)),
        ((x + 2) / 4 * 4 >= x - y, le((x + 2) % 4 + (-2), y)),

        (tvm.min(x, 11) < 10, x < 10),
        (tvm.min(x, 8) < 10, true_),
        (tvm.max(8, x) > 10, lt(10, x)),
        (x + 1 < tvm.max(8, x), x < 7),
    ])

    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)

    # Under these bounds every comparison below is expected to fold to true.
    always_true = [
        x < 11,
        x <= 10,
        z <= 5,
        x + y <= 10,
        x + y >= -10,
        z - 5 <= y + 10,
        tvm.all(x > -1, z <= x + 5),
        x*y <= 0,
        (x + 1)*(y - 1) < 0,
        y*y >= 0,
    ]
    for cond in always_true:
        ck.verify(cond, true_)
def f(n):
    """Reduce ``X`` over the range ``[0, n)`` with a custom commutative reducer.

    The reducer combines two values as ``max(x + y, n as float32)`` and uses
    ``select(n > 1, 0, n cast to dtype)`` as its identity element.

    ``X`` is not defined in this block; presumably it is a tensor from the
    enclosing scope -- TODO confirm against the caller.

    Parameters
    ----------
    n : Expr
        Extent of the reduction axis; also used inside the combiner and the
        identity element.

    Returns
    -------
    The result of applying the reducer to ``X[rv]`` along ``rv``.
    """
    rv = tvm.reduce_axis((0, n))

    def identity(dtype):
        """Identity element of the reducer for the given dtype."""
        return tvm.select(n > 1, tvm.const(0, dtype), n.astype(dtype))

    # The combiner's parameter names stay ``x`` and ``y``: comm_reducer may
    # derive IR variable names from them (NOTE(review): confirm).
    # The local is named ``reducer`` so it no longer shadows the builtin
    # ``sum``; the reducer's registered name is still 'sum'.
    reducer = tvm.comm_reducer(
        lambda x, y: tvm.max(x + y, n.astype('float32')), identity, name='sum')
    return reducer(X[rv], axis=rv)