def test_inverse_affine_iter_map():
    """Check that ``inverse_affine_iter_map`` recovers the original iterators.

    Three layouts are exercised: one fuse of two split iterators, a compound
    fuse interleaving pieces of three iterators, and a diamond-shaped
    split -> fuse -> split -> fuse chain over a single iterator. Each
    recovered expression must simplify to the hand-written inverse.
    """
    analyzer = tvm.arith.Analyzer()

    def invert(bindings, domains):
        # Detect the iter map, create one fresh int32 output var per mapped
        # expression, and invert the map in terms of those vars.
        mapped = tvm.arith.detect_iter_map(bindings, var_dom(domains))
        out_vars = [tvm.tir.Var(f"output_{idx}", "int32") for idx in range(len(mapped))]
        return out_vars, tvm.arith.inverse_affine_iter_map(mapped, out_vars)

    def expect_zero(inverse, it, expected):
        # The inverse recorded for iterator ``it`` must equal ``expected``.
        assert analyzer.simplify(inverse[it[0]] - expected) == 0

    l0 = create_iter("l0", 64)
    l1 = create_iter("l1", 64)
    l2 = create_iter("l2", 64)

    # simple case: fuse the low parts of l0 and l1
    l0_0, l0_1 = isplit(l0, 16)
    l1_0, l1_1 = isplit(l1, 4)
    fused = ifuse([l0_1, l1_1])
    out, inverse = invert([fused[0], l0_0[0], l1_0[0]], [l0, l1])
    assert len(inverse) == 2
    expect_zero(inverse, l0, floormod(floordiv(out[0], 4), 16) + out[1] * 16)
    expect_zero(inverse, l1, floormod(out[0], 4) + out[2] * 4)

    # compound case: pieces of l0, l2, l1, l2 fused into one iterator
    l0_0, l0_1 = isplit(l0, 16)
    l1_0, l1_1 = isplit(l1, 4)
    l2_1, l2_2 = isplit(l2, 4)
    l2_0, l2_1 = isplit(l2_1, 4)
    fused = ifuse([l0_1, l2_1, l1_1, l2_0])
    out, inverse = invert([fused[0], l0_0[0], l2_2[0], l1_0[0]], [l0, l1, l2])
    assert len(inverse) == 3
    expect_zero(inverse, l0, floormod(floordiv(out[0], 64), 16) + out[1] * 16)
    expect_zero(inverse, l1, floormod(floordiv(out[0], 4), 4) + out[3] * 4)
    expect_zero(
        inverse,
        l2,
        floormod(out[0], 4) * 16 + floormod(floordiv(out[0], 16), 4) * 4 + out[2],
    )

    # diamond-shape DAG: l0 -> split/fuse (l1) -> split/fuse (l2)
    l0_0, l0_1 = isplit(l0, 16)
    l1 = ifuse([l0_1, l0_0])
    l1_0, l1_1 = isplit(l1, 8)
    l2 = ifuse([l1_1, l1_0])
    out, inverse = invert([l2[0]], [l0])
    assert len(inverse) == 1
    l1_inverse = floormod(out[0], 8) * 8 + floormod(floordiv(out[0], 8), 8)
    expect_zero(
        inverse,
        l0,
        floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16),
    )
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    # TVMScript: element-wise B = A * 2.0 over (128, 128, 128) buffers,
    # written as a single loop of 128 * 128 * 128 = 2097152 iterations.
    # Presumably the expected IR after fusing three loops; confirm against the
    # test that compares with it.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # Recover (vi, vj, vk) from the fused index; 16384 == 128 * 128,
            # so vi varies slowest and vk fastest.
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    # TVMScript: element-wise B = A * 2.0 over (128, 128, n) buffers with a
    # symbolic innermost extent n, written as one fused loop of
    # n * 16384 (= 128 * 128 * n) iterations.
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i_j_k_fused in tir.serial(0, (n * 16384)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            # Decompose the fused index: vi strides by 128 * n, vj by n,
            # and vk is the innermost (extent n) position.
            tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128)))
            tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128))
            tir.bind(vk, tir.floormod(i_j_k_fused, n))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def test_div_by_zero():
    """Apply integer and float division/modulo ops with a zero divisor.

    Each op is called with a symbolic operand and with a constant operand.
    Any ``TVMError`` raised by a call is caught and ignored. The test therefore
    fails only if some other exception escapes (or the process aborts); it
    does not require that ``TVMError`` be raised.
    """
    a = te.var(name='a', dtype='int32')
    b = te.var(name='b', dtype='int32')
    zero = tir.const(0)
    two = tir.const(2)
    fzero = tir.const(0.0)
    ftwo = tir.const(2.0)

    def tolerate_tvm_error(op, lhs, rhs):
        # A TVMError is an acceptable outcome for a zero divisor.
        try:
            op(lhs, rhs)
        except TVMError:
            pass

    integer_ops = [
        lambda x, y: x % y,
        lambda x, y: tir.floordiv(x, y),
        lambda x, y: tir.truncdiv(x, y),
        lambda x, y: tir.floormod(x, y),
        lambda x, y: tir.truncmod(x, y),
    ]
    for op in integer_ops:
        tolerate_tvm_error(op, a, zero)
        tolerate_tvm_error(op, two, zero)

    float_ops = [
        lambda x, y: tir.div(x, y),
        lambda x, y: tir.truncmod(x, y),
    ]
    for op in float_ops:
        tolerate_tvm_error(op, a, fzero)
        tolerate_tvm_error(op, ftwo, fzero)
def test_suggest_index_map_bijective():
    """Suggest an index map for the access ``floormod(j, 4) * 2 + i``.

    The access is into an 8-element buffer, with loops over ``i`` and ``j``
    of extents 2 and 32. The suggestion must be equivalent to
    ``x -> [floormod(x, 2), floordiv(x, 2)]``.
    """
    i, j = _make_vars("i", "j")
    buffer = decl_buffer(shape=[8])
    access = [floormod(j, 4) * 2 + i]
    loops = _make_loops(loop_vars=[i, j], extents=[2, 32])

    suggested = suggest_index_map(
        buffer=buffer,
        indices=access,
        loops=loops,
        predicate=True,
    )

    expected = IndexMap.from_func(lambda x: [floormod(x, 2), floordiv(x, 2)])
    assert suggested.is_equivalent_to(expected)
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # TVMScript: row sum B[i] = sum_k A[i, k] over a 128x128 input, written
    # with a transformed loop nest (io, ii_ko_fused, ki) of extents
    # (32, 128, 4).
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            # ii_ko_fused packs a 4-extent ii (outer) with a 32-extent ko
            # (inner): vi = io * 4 + ii and vk = ko * 4 + ki.
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def opaque_access_fused(a: ty.handle, b: ty.handle) -> None:
    # TVMScript: two fused 16x16 loops (256 iterations each) whose blocks
    # touch their buffers only through raw data pointers (A.data, B.data).
    # Each block therefore declares no reads and a whole-buffer write region.
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "A") as [vi, vj]:
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            tir.reads([])
            tir.writes([A[0:16, 0:16]])
            # Flat-offset store of the value 1 into A's data pointer.
            tir.store(A.data, ((vi * 16) + vj), 1, 1)
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "B") as [vi, vj]:
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            tir.reads([])
            tir.writes([B[0:16, 0:16]])
            # Opaque intrinsic call on B's data pointer; the access region
            # cannot be inferred from it, hence the explicit writes above.
            tir.evaluate(
                tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle"))
def test_suggest_index_map_simple():
    """Suggest an index map for a 2-D access into an (8, 256) buffer.

    The loops over ``i`` and ``j`` have extents 32 and 64. The suggested map
    must be equivalent to the 4-D split
    ``(x, y) -> [x // 4, y // 16, x % 4, y % 16]`` (floor semantics).
    """
    i, j = _make_vars("i", "j")
    buffer = decl_buffer(shape=[8, 256])
    access = [
        floordiv(i, 16) * 4 + floordiv(j, 16),
        floormod(i, 16) * 16 + floormod(j, 16),
    ]
    loops = _make_loops(loop_vars=[i, j], extents=[32, 64])

    suggested = suggest_index_map(buffer=buffer, indices=access, loops=loops, predicate=True)

    expected = IndexMap.from_func(
        lambda x, y: [floordiv(x, 4), floordiv(y, 16), floormod(x, 4), floormod(y, 16)],
    )
    assert suggested.is_equivalent_to(expected)
def gen_ir(
    data_ptr,
    n_fft,
    hop_length,
    win_length,
    window_ptr,
    normalized,
    onesided,
    output_ptr,
    loop_kind,
):
    """Build the TIR statement computing a short-time Fourier transform.

    Follows the librosa STFT formulation
    (https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft).
    Let ``x = data[batch, col * hop_length + w]``, with ``w`` ranging over
    ``[0, win_length)``. For every ``(batch, row, col)`` output element:

    * the real part (last index 0) accumulates
      ``window[w] * x * cos(2*pi*row*w/win_length)``;
    * the imaginary part (last index 1) subtracts the matching ``sin`` term.

    Parameters
    ----------
    data_ptr : Buffer
        Input signal, indexed as ``data[batch, sample]``.
    n_fft : int
        FFT size. Only used for the ``1 / sqrt(n_fft)`` normalization.
    hop_length : int
        Stride, in samples, between consecutive frames (``col``).
    win_length : int
        Number of samples per frame that are weighted by ``window``.
    window_ptr : Buffer
        Window coefficients, indexed as ``window[w]``.
    normalized
        Condition passed to ``ib.if_scope``. When true, both output
        components are divided by ``sqrt(n_fft)``.
    onesided
        Not read here. The number of frequency rows is taken from
        ``output_ptr.shape[1]``.
    output_ptr : Buffer
        Output indexed as ``output[batch, row, col, component]``.
    loop_kind : str
        ``kind`` of the per-frame ``col`` loop. The fused batch/row loop is
        always ``"parallel"``.

    Returns
    -------
    Stmt
        The statement produced by ``ib.get()``.
    """
    ib = tir.ir_builder.create()
    data = ib.buffer_ptr(data_ptr)
    window = ib.buffer_ptr(window_ptr)
    output = ib.buffer_ptr(output_ptr)
    num_rows = output_ptr.shape[1]
    # https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
    with ib.for_range(0, output_ptr.shape[0] * num_rows, kind="parallel") as batch_row:
        with ib.for_range(0, output_ptr.shape[2], kind=loop_kind) as col:
            # batch/row are pure index expressions of the fused loop variable,
            # so no local storage needs to be allocated for them.
            batch = tir.floordiv(batch_row, num_rows)
            row = tir.floormod(batch_row, num_rows)
            output[batch, row, col, 0] = tir.Cast(data_ptr.dtype, 0)
            output[batch, row, col, 1] = tir.Cast(data_ptr.dtype, 0)
            with ib.for_range(0, win_length) as wlen:
                output[batch, row, col, 0] += (
                    window[wlen]
                    * data[batch, col * hop_length + wlen]
                    * tir.cos(2 * pi * row * wlen / win_length)
                )
                output[batch, row, col, 1] -= (
                    window[wlen]
                    * data[batch, col * hop_length + wlen]
                    * tir.sin(2 * pi * row * wlen / win_length)
                )
            with ib.if_scope(normalized):
                # Same normalization factor for the real and imaginary parts.
                scale = tir.sqrt(tir.const(n_fft, "float32"))
                output[batch, row, col, 0] /= scale
                output[batch, row, col, 1] /= scale
    return ib.get()
def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None:
    # TVMScript: element-wise B = A * 2.0 over (128, 128, 128) buffers in one
    # fused loop of 2097152 iterations. An outer opaque block declares the
    # per-iteration read/write regions and wraps the inner block "B".
    B = tir.match_buffer(b, [128, 128, 128])
    A = tir.match_buffer(a, [128, 128, 128])
    for i_j_k_fused in tir.serial(0, 2097152):
        with tir.block([], "opaque"):
            # Regions are spelled with nested floordiv/floormod by 128. For
            # i_j_k_fused < 2097152 the first index equals
            # floordiv(i_j_k_fused, 16384), matching vi's binding below.
            tir.reads([
                A[tir.floormod(
                    tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128),
                  tir.floormod(tir.floordiv(i_j_k_fused, 128), 128),
                  tir.floormod(i_j_k_fused, 128),
                ]
            ])
            tir.writes([
                B[tir.floormod(
                    tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128),
                  tir.floormod(tir.floordiv(i_j_k_fused, 128), 128),
                  tir.floormod(i_j_k_fused, 128),
                ]
            ])
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, tir.floordiv(i_j_k_fused, 16384))
                tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128))
                tir.bind(vk, tir.floormod(i_j_k_fused, 128))
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
    # TVMScript: C[b] = sum_{i, j} A[b, i, j]^2 over a (16, 256, 256) input,
    # then D[b] = sqrt(C[b]). The 256x256 reduction is fused into
    # 65536 (= 256 * 256) iterations, followed by an extent-1 inner loop that
    # the bindings do not use.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
            tir.bind(b, i0)
            # Split the fused reduction index back into (i, j).
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            tir.reads([C[b], A[b, i, j]])
            tir.writes([C[b]])
            with tir.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_1]:
            tir.bind(b_1, i0_1)
            tir.reads([C[b_1]])
            tir.writes([D[b_1]])
            D[b_1] = tir.sqrt(C[b_1], dtype="float32")
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
    # TVMScript: the sum-of-squares / sqrt computation written in three
    # stages:
    #   1. "C_rf" accumulates partial sums into C_rf[inner, b], where inner is
    #      the extent-1 i1_i2_fused_inner axis;
    #   2. "C" reduces C_rf over that axis into C[b];
    #   3. "D" takes the sqrt.
    # Presumably the expected IR after rfactor on i1_i2_fused_inner; confirm
    # against the test that uses it.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    C_rf = tir.alloc_buffer([1, 16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [
            vi1_i2_fused_inner,
            b,
            i,
            j,
        ]:
            # The factored axis is a spatial (non-reduce) block var here.
            tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            with tir.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])

    for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
        with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
            # Final reduction over the (extent-1) factored axis.
            tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_2]:
            tir.bind(b_2, i0_2)
            D[b_2] = tir.sqrt(C[b_2], dtype="float32")
substituting def of floorDiv gives floorMod(a,b) = a - floor(a / b) * b """ def floormod(a,b): return a - floor(a / b) * b DIM = 1000 HDIM = 500 shape = (DIM,DIM) c_tvm = tvm.nd.array(np.zeros(shape=shape,dtype='int32')) c_np = np.zeros(shape) c = te.compute(shape,lambda i,j: tir.floormod(HDIM - i,j + 1) ) d = te.compute(shape,lambda i,j: tir.floormod(HDIM - i,-(j + 1))) s = te.create_schedule([c.op]) s2 = te.create_schedule([d.op]) f = tvm.build(s,[c]) f2 = tvm.build(s2,[d]) f(c_tvm) out = c_tvm.asnumpy() for i in range(DIM): for j in range(DIM): res = out[i][j] res2 = floormod(HDIM - i, j + 1) if res != res2: print(i,j,res,res2) assert False
def test_complex():
    """Detect and subspace-divide a multi-level split/fuse iterator DAG.

    The iterators are built with the helpers ``create_iter``, ``isplit`` and
    ``ifuse``, which are defined elsewhere in this file. Judging by the
    ``IterMark`` extents below, ``ifuse``'s optional second argument sets the
    fused extent.

    The test checks three things:

    * ``detect_iter_map`` yields the hand-built ``IterSumExpr`` structures
      under the correct constraints, and nothing under wrong ones;
    * ``subspace_divide`` splits the bindings into outer/inner parts plus
      predicates;
    * each divided half is itself a valid iter map.
    """
    n0 = create_iter("n0", 2)
    n1 = create_iter("n1", 4)

    m0 = ifuse([n0, n1], 6)
    m1 = create_iter("m1", 3)

    l0 = create_iter("l0", 4)
    l1 = create_iter("l1", 8)
    l2 = ifuse([m0, m1], 16)
    l3 = create_iter("l3", 32)

    k0, k4 = isplit(l0, 2)
    k1, k5 = isplit(l1, 2)
    k2, k6 = isplit(l2, 4)
    k3, k7 = isplit(l3, 4)

    j0 = ifuse([k0, k1], 7)
    j1 = ifuse([k2, k3])
    j2 = ifuse([k4, k5])
    j3 = ifuse([k6, k7], 15)

    i0 = ifuse([j0, j1], 200)
    i1 = ifuse([j2, j3], 50)

    # Detection with the constraints that match the fused extents above.
    res = tvm.arith.detect_iter_map(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
    )
    assert len(res) == 2

    # Hand-build the expected IterMark / IterSplitExpr / IterSumExpr structure.
    # IterSplitExpr args are (source mark, lower_factor, extent, scale).
    n0_mark = tvm.arith.IterMark(n0[0], n0[1])
    n1_mark = tvm.arith.IterMark(n1[0], n1[1])
    l0_mark = tvm.arith.IterMark(l0[0], l0[1])
    l1_mark = tvm.arith.IterMark(l1[0], l1[1])
    m1_mark = tvm.arith.IterMark(m1[0], m1[1])
    l3_mark = tvm.arith.IterMark(l3[0], l3[1])

    m0_expr = tvm.arith.IterSumExpr(
        [
            tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4),
            tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1),
        ],
        0,
    )
    m0_mark = tvm.arith.IterMark(m0_expr, 6)
    l2_expr = tvm.arith.IterSumExpr(
        [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)],
        0,
    )
    l2_mark = tvm.arith.IterMark(l2_expr, 16)
    k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4)
    k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1)
    k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8)
    k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1)
    k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30)
    k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15)
    k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4)
    k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1)
    j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0)
    j0_mark = tvm.arith.IterMark(j0_expr, 7)
    i0_expr = tvm.arith.IterSumExpr(
        [tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0
    )
    j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0)
    j3_mark = tvm.arith.IterMark(j3_expr, 15)
    i1_expr = tvm.arith.IterSumExpr(
        [k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0
    )
    i0_mark = tvm.arith.IterMark(i0_expr, i0[1])
    i1_mark = tvm.arith.IterMark(i1_expr, i1[1])
    i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0)
    i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0)

    tvm.ir.assert_structural_equal(i0_final, res[0])
    tvm.ir.assert_structural_equal(i1_final, res[1])

    # wrong constraint
    # (m0 < 9 and j3 < 14 disagree with the fused extents, so detection
    # is expected to fail and return an empty result).
    res = tvm.arith.detect_iter_map(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14),
    )
    assert len(res) == 0

    # subspace_division
    # res[k] is the (outer, inner) pair for binding k; the last entry holds
    # the (outer, inner) predicates.
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        [n0[0], n1[0], m1[0], l3[0]],
        tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
    )
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2))
    tvm.ir.assert_structural_equal(
        res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4)
    )
    tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2)))
    tvm.ir.assert_structural_equal(
        res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4))
    )
    tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7)
    tvm.ir.assert_structural_equal(
        res[2][1],
        tvm.tir.all(
            n0[0] * 4 + n1[0] < 6,
            (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16,
            floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15,
        ),
    )

    # Each half of the division must itself be a detectable iter map.
    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1])
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1]))
    assert len(res2) == 2
def test_subspace_division():
    """Check ``tvm.arith.subspace_divide`` on simple and compound bindings.

    When division succeeds, the converted result holds one (outer, inner)
    pair per binding, followed by an (outer predicate, inner predicate) pair.
    An empty result means the bindings cannot be divided over the requested
    sub-iterators. ``convert_division``, ``create_iter``, ``isplit``,
    ``ifuse`` and ``var_dom`` are helpers defined elsewhere in this file.
    """
    x = tvm.tir.Var("x", "int32")
    y = tvm.tir.Var("y", "int32")
    z = tvm.tir.Var("z", "int32")
    c = tvm.tir.SizeVar("c", "int32")

    # simple 1.1
    res = tvm.arith.subspace_divide(
        [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x]
    )
    res = convert_division(res)
    assert len(res) == 2
    tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
    tvm.ir.assert_structural_equal(res[0][1], x + c)

    # simple 1.2
    # (the predicate involves only outer iterators, so it lands in res[1][0])
    res = tvm.arith.subspace_divide(
        [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18
    )
    res = convert_division(res)
    assert len(res) == 2
    tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
    tvm.ir.assert_structural_equal(res[0][1], x + c)
    tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18)
    tvm.ir.assert_structural_equal(res[1][1], True)

    # compound 1
    i0 = create_iter("i0", 4)
    j0 = create_iter("j0", 8)
    i3 = create_iter("i3", 2)

    i1, i2 = isplit(j0, 4)
    k0 = ifuse([i0, i1])
    k1 = ifuse([i2, i3])

    # compound 1.1
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]])
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][1], i3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0]))
    assert len(res2) == 2

    # compound 1.2
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]])
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], i0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0]))
    assert len(res2) == 2

    # compound 1.3
    # (dividing over {i0, i3} is expected to fail)
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]])
    res = convert_division(res)
    assert len(res) == 0

    # compound 1.4
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7)
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][1], i3[0])
    tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7)
    tvm.ir.assert_structural_equal(res[2][1], True)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0]))
    assert len(res2) == 2

    # compound 1.5
    res = tvm.arith.subspace_divide(
        [k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7
    )
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], i0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
    tvm.ir.assert_structural_equal(res[2][0], True)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0]))
    assert len(res2) == 2

    # compound 1.6
    # (with both predicates and only i3 as the inner space, division fails)
    res = tvm.arith.subspace_divide(
        [k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7)
    )
    res = convert_division(res)
    assert len(res) == 0

    # compound 2
    # (rebinds j0, i0, i1, k0 and friends to a new iterator structure)
    j0 = create_iter("j0", 4)
    l0 = create_iter("l0", 2)
    l1 = create_iter("l1", 6)
    j3 = create_iter("j3", 3)

    k0 = ifuse([l0, l1])
    i1, j2 = isplit(k0, 3)
    j1, i1 = isplit(i1, 2)
    i0 = ifuse([j0, j1])
    i2 = ifuse([j2, j3])

    # compound 2.1
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0]))
    assert len(res2) == 3

    # compound 2.2
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], j0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0]))
    assert len(res2) == 3

    # compound 2.3
    # (dividing over {l0, j3} without l1 is expected to fail)
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 0

    # compound 2.4
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]],
        var_dom([j0, l0, l1, j3]),
        [l1[0], j3[0]],
        tvm.tir.all(i0[0] < 7, i2[0] < 8),
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
    tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7)
    tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0]))
    assert len(res2) == 3

    # compound 2.5
    # (only j3 as the inner space under i2 < 8 is expected to fail)
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8
    )
    res = convert_division(res)
    assert len(res) == 0
def input_example(i0, i1, i2, i3):
    """Map four indices to a 4-tuple built with floor division/modulo.

    Returns ``(floordiv(i3, 32), floordiv(i2, 2), floormod(i2, 2),
    floormod(i3, 32))``. ``i0`` and ``i1`` are accepted but not used.
    """
    return (
        floordiv(i3, 32),
        floordiv(i2, 2),
        floormod(i2, 2),
        floormod(i3, 32),
    )
def apply(lhs, rhs):
    """Return ``tir.floormod`` of the int-forced operands.

    The divisor is passed through ``_suppress_zero`` after ``_force_int``,
    presumably to avoid a zero divisor; confirm against that helper.
    """
    return tir.floormod(_force_int(lhs), _suppress_zero(_force_int(rhs)))