def test_inverse_affine_iter_map():
    """Check that inverse_affine_iter_map recovers the original iterators.

    Each case detects an iterator map built from split/fuse patterns over
    the input iterators, inverts it against fresh int32 output variables,
    and compares each recovered input with a hand-derived expression.
    """
    analyzer = tvm.arith.Analyzer()

    def invert(iter_map):
        # One fresh int32 output variable per detected iterator.
        outs = [tvm.tir.Var("output_{}".format(idx), "int32") for idx in range(len(iter_map))]
        return outs, tvm.arith.inverse_affine_iter_map(iter_map, outs)

    def check(inverse, loop_var, expected):
        assert analyzer.simplify(inverse[loop_var] - expected) == 0

    l0 = create_iter("l0", 64)
    l1 = create_iter("l1", 64)
    l2 = create_iter("l2", 64)

    # simple case: inner parts of l0 (split 16) and l1 (split 4) are fused
    l0_hi, l0_lo = isplit(l0, 16)
    l1_hi, l1_lo = isplit(l1, 4)
    lo_fused = ifuse([l0_lo, l1_lo])

    outs, inverse = invert(
        tvm.arith.detect_iter_map([lo_fused[0], l0_hi[0], l1_hi[0]], var_dom([l0, l1]))
    )
    assert len(inverse) == 2
    check(inverse, l0[0], floormod(floordiv(outs[0], 4), 16) + outs[1] * 16)
    check(inverse, l1[0], floormod(outs[0], 4) + outs[2] * 4)

    # compound case: l2 is split twice and its pieces interleave with l0/l1
    l0_hi, l0_lo = isplit(l0, 16)
    l1_hi, l1_lo = isplit(l1, 4)
    l2_hi, l2_lo = isplit(l2, 4)
    l2_hi_hi, l2_hi_lo = isplit(l2_hi, 4)
    mixed_fused = ifuse([l0_lo, l2_hi_lo, l1_lo, l2_hi_hi])

    outs, inverse = invert(
        tvm.arith.detect_iter_map(
            [mixed_fused[0], l0_hi[0], l2_lo[0], l1_hi[0]], var_dom([l0, l1, l2])
        )
    )
    assert len(inverse) == 3
    check(inverse, l0[0], floormod(floordiv(outs[0], 64), 16) + outs[1] * 16)
    check(inverse, l1[0], floormod(floordiv(outs[0], 4), 4) + outs[3] * 4)
    check(
        inverse,
        l2[0],
        floormod(outs[0], 4) * 16 + floormod(floordiv(outs[0], 16), 4) * 4 + outs[2],
    )

    # diamond-shape DAG: split/swap-fuse twice over the single input l0
    l0_hi, l0_lo = isplit(l0, 16)
    swapped = ifuse([l0_lo, l0_hi])
    swapped_hi, swapped_lo = isplit(swapped, 8)
    reswapped = ifuse([swapped_lo, swapped_hi])

    outs, inverse = invert(tvm.arith.detect_iter_map([reswapped[0]], var_dom([l0])))
    assert len(inverse) == 1
    mid = floormod(outs[0], 8) * 8 + floormod(floordiv(outs[0], 8), 8)
    check(inverse, l0[0], floormod(mid, 4) * 16 + floormod(floordiv(mid, 4), 16))
Exemple #2
0
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    # Elementwise B = A * 2.0 over 128x128x128 buffers, with all three loops
    # fused into one serial loop of 128**3 = 2097152 iterations.
    # The block iterators are decoded from the fused index:
    #   vi = fused // (128 * 128), vj = (fused // 128) % 128, vk = fused % 128
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Exemple #3
0
def elementwise_symbolic_fused(a: ty.handle, b: ty.handle,
                               n: ty.int32) -> None:
    # Symbolic-extent variant of the fused elementwise B = A * 2.0 on
    # 128 x 128 x n buffers: one serial loop of n * 16384 (= 128 * 128 * n)
    # iterations, with the block iterators decoded as
    #   vi = fused // (n * 128), vj = (fused // n) % 128, vk = fused % n
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i_j_k_fused in tir.serial(0, (n * 16384)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128)))
            tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128))
            tir.bind(vk, tir.floormod(i_j_k_fused, n))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def test_div_by_zero():
    """Division and modulo by a constant zero must raise TVMError.

    Each operator is tried with a symbolic numerator and with a constant
    numerator. The test fails if any combination is accepted instead of
    raising.
    """
    a = te.var(name='a', dtype='int32')

    zero = tir.const(0)
    two = tir.const(2)
    fzero = tir.const(0.0)
    ftwo = tir.const(2.0)

    def expect_tvm_error(label, op, lhs, rhs):
        # Catching and ignoring the error is not enough: the call must
        # actually raise, otherwise the test would pass vacuously.
        try:
            op(lhs, rhs)
        except TVMError:
            return
        raise AssertionError(
            "expected TVMError for {}({}, {})".format(label, lhs, rhs)
        )

    int_ops = [
        ("%", lambda x, y: x % y),
        ("floordiv", tir.floordiv),
        ("truncdiv", tir.truncdiv),
        ("floormod", tir.floormod),
        ("truncmod", tir.truncmod),
    ]
    float_ops = [
        ("div", tir.div),
        ("truncmod", tir.truncmod),
    ]

    for label, op in int_ops:
        expect_tvm_error(label, op, a, zero)
        expect_tvm_error(label, op, two, zero)

    for label, op in float_ops:
        expect_tvm_error(label, op, a, fzero)
        expect_tvm_error(label, op, ftwo, fzero)
def test_suggest_index_map_bijective():
    """suggest_index_map on a 1-D access over an 8-element buffer.

    The access (j % 4) * 2 + i over loops i in [0, 2), j in [0, 32) is
    expected to yield a map equivalent to x -> [x % 2, x // 2].
    """
    i, j = _make_vars("i", "j")
    loop_nest = _make_loops(loop_vars=[i, j], extents=[2, 32])
    suggested = suggest_index_map(
        buffer=decl_buffer(shape=[8]),
        indices=[floormod(j, 4) * 2 + i],
        loops=loop_nest,
        predicate=True,
    )

    expected = IndexMap.from_func(lambda x: [floormod(x, 2), floordiv(x, 2)])
    assert suggested.is_equivalent_to(expected)
Exemple #6
0
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x128 A, written with
    # three loops (io: 32, ii_ko_fused: 128, ki: 4) whose bindings decode to
    #   vi = io * 4 + ii_ko_fused // 32
    #   vk = (ii_ko_fused % 32) * 4 + ki
    # tir.init() provides the reduction initializer B[vi] = 0.0.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Exemple #7
0
def opaque_access_fused(a: ty.handle, b: ty.handle) -> None:
    # Two fused 16x16 loops (256 iterations each) whose blocks access their
    # buffers opaquely: reads are declared empty and writes cover the whole
    # 16x16 buffer, with vi = fused // 16 and vj = fused % 16.
    # Block "A" writes A.data at flat offset vi * 16 + vj via tir.store.
    # Block "B" calls tir.tvm_fill_fragment on B.data with index vi * 16 + vj.
    # NOTE(review): see the tvm_fill_fragment intrinsic for the meaning of
    # its remaining arguments.
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "A") as [vi, vj]:
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            tir.reads([])
            tir.writes([A[0:16, 0:16]])
            tir.store(A.data, ((vi * 16) + vj), 1, 1)
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "B") as [vi, vj]:
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            tir.reads([])
            tir.writes([B[0:16, 0:16]])
            tir.evaluate(
                tir.tvm_fill_fragment(B.data,
                                      16,
                                      16,
                                      16,
                                      0, ((vi * 16) + vj),
                                      dtype="handle"))
def test_suggest_index_map_simple():
    """suggest_index_map on a 2-D access into an 8x256 buffer.

    The suggested layout is expected to be equivalent to
    (x, y) -> [x // 4, y // 16, x % 4, y % 16].
    """
    i, j = _make_vars("i", "j")
    loop_nest = _make_loops(loop_vars=[i, j], extents=[32, 64])
    row_index = floordiv(i, 16) * 4 + floordiv(j, 16)
    col_index = floormod(i, 16) * 16 + floormod(j, 16)
    suggested = suggest_index_map(
        buffer=decl_buffer(shape=[8, 256]),
        indices=[row_index, col_index],
        loops=loop_nest,
        predicate=True,
    )

    def expected_layout(x, y):
        return [floordiv(x, 4), floordiv(y, 16), floormod(x, 4), floormod(y, 16)]

    assert suggested.is_equivalent_to(IndexMap.from_func(expected_layout))
Exemple #9
0
    def gen_ir(
        data_ptr,
        n_fft,
        hop_length,
        win_length,
        window_ptr,
        normalized,
        onesided,
        output_ptr,
        loop_kind,
    ):
        """Build IR that fills ``output`` with a windowed Fourier sum.

        For every ``(batch, row, col)`` of ``output`` this zero-initialises
        ``output[batch, row, col, 0]`` and ``output[batch, row, col, 1]``, then
        accumulates over ``w in [0, win_length)``::

            s = window[w] * data[batch, col * hop_length + w]
            output[..., 0] += s * cos(2 * pi * row * w / win_length)
            output[..., 1] -= s * sin(2 * pi * row * w / win_length)

        When ``normalized`` holds, both parts are divided by ``sqrt(n_fft)``.

        Parameters
        ----------
        data_ptr : buffer read as ``data[batch, sample]``.
        n_fft : only used for the ``sqrt(n_fft)`` normalisation.
        hop_length : stride, in samples, between consecutive ``col`` frames.
        win_length : number of window taps summed per output element.
        window_ptr : buffer read as ``window[w]``.
        normalized : condition passed to ``ib.if_scope``.
        onesided : NOTE(review): not read by this function; confirm the
            caller sizes ``output_ptr`` accordingly.
        output_ptr : 4-D buffer written as ``output[batch, row, col, 0|1]``.
        loop_kind : ``kind`` of the per-column loop.

        Returns
        -------
        The statement produced by ``ib.get()``.
        """
        ib = tir.ir_builder.create()
        data = ib.buffer_ptr(data_ptr)
        window = ib.buffer_ptr(window_ptr)
        output = ib.buffer_ptr(output_ptr)
        n_rows = output_ptr.shape[1]
        # https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
        # The (batch, row) pair is flattened into one parallel loop.
        with ib.for_range(0, output_ptr.shape[0] * n_rows, kind="parallel") as batch_row:
            with ib.for_range(0, output_ptr.shape[2], kind=loop_kind) as col:
                # Plain index expressions; no scratch buffers are needed.
                batch = tir.floordiv(batch_row, n_rows)
                row = tir.floormod(batch_row, n_rows)
                zero = tir.Cast(data_ptr.dtype, 0)
                output[batch, row, col, 0] = zero
                output[batch, row, col, 1] = zero
                with ib.for_range(0, win_length) as wlen:
                    # Windowed sample and phase are shared by both parts.
                    sample = window[wlen] * data[batch, col * hop_length + wlen]
                    angle = 2 * pi * row * wlen / win_length
                    output[batch, row, col, 0] += sample * tir.cos(angle)
                    output[batch, row, col, 1] -= sample * tir.sin(angle)
                with ib.if_scope(normalized):
                    scale = tir.sqrt(tir.const(n_fft, "float32"))
                    output[batch, row, col, 0] /= scale
                    output[batch, row, col, 1] /= scale

        return ib.get()
Exemple #10
0
def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None:
    # Fused 128x128x128 elementwise B = A * 2.0 (2097152 = 128**3 iterations)
    # where block "B" is nested inside an opaque block. The opaque block
    # declares its read/write regions directly from the fused loop variable.
    # Its first region index, ((fused // 128) // 128) % 128, equals
    # fused // 16384 for every fused < 2097152, which matches vi's binding.
    B = tir.match_buffer(b, [128, 128, 128])
    A = tir.match_buffer(a, [128, 128, 128])
    for i_j_k_fused in tir.serial(0, 2097152):
        with tir.block([], "opaque"):
            tir.reads([
                A[tir.floormod(
                    tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128),
                  tir.floormod(tir.floordiv(i_j_k_fused, 128), 128),
                  tir.floormod(i_j_k_fused, 128), ]
            ])
            tir.writes([
                B[tir.floormod(
                    tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128),
                  tir.floormod(tir.floordiv(i_j_k_fused, 128), 128),
                  tir.floormod(i_j_k_fused, 128), ]
            ])
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, tir.floordiv(i_j_k_fused, 16384))
                tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128))
                tir.bind(vk, tir.floormod(i_j_k_fused, 128))
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Exemple #11
0
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
    # Block "C" accumulates C[b] = sum over (i, j) of A[b, i, j] ** 2. The
    # 256x256 reduction space is walked by one outer loop of 65536 iterations
    # (i = outer // 256, j = outer % 256) plus an inner loop of extent 1 that
    # the bindings do not use.
    # Block "D" then writes D[b] = sqrt(C[b]) for each of the 16 batches.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [16, tir.reduce_axis(0, 256),
             tir.reduce_axis(0, 256)], "C") as [b, i, j]:
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            tir.reads([C[b], A[b, i, j]])
            tir.writes([C[b]])
            with tir.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_1]:
            tir.bind(b_1, i0_1)
            tir.reads([C[b_1]])
            tir.writes([D[b_1]])
            D[b_1] = tir.sqrt(C[b_1], dtype="float32")
Exemple #12
0
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
    # Block "C_rf" computes partial sums C_rf[r, b] = sum over (i, j) of
    # A[b, i, j] ** 2, where r is bound to the extent-1 inner loop, so C_rf
    # holds a single partition per batch.
    # Block "C" reduces C_rf over r into C[b], and block "D" writes
    # D[b] = sqrt(C[b]).
    # NOTE(review): this looks like the expected result of rfactor on the
    # extent-1 inner loop of transformed_square_sum_square_root; confirm.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    C_rf = tir.alloc_buffer([1, 16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [1, 16, tir.reduce_axis(0, 256),
             tir.reduce_axis(0, 256)], "C_rf") as [
                 vi1_i2_fused_inner,
                 b,
                 i,
                 j,
             ]:
            tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            with tir.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner,
                 b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])

    for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
        with tir.block([tir.reduce_axis(0, 1), 16],
                       "C") as [vi1_i2_fused_inner_1, b_1]:
            tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_2]:
            tir.bind(b_2, i0_2)
            D[b_2] = tir.sqrt(C[b_2], dtype="float32")
Exemple #13
0
Substituting the definition of floorDiv gives

floorMod(a,b) = a - floor(a / b) * b
"""

def floormod(a, b):
    """Return ``a`` modulo ``b`` with floor semantics.

    Computes ``a - floor(a / b) * b``, so a nonzero result has the sign of
    ``b``. The quotient uses floor division (``//``) instead of a float
    ``a / b``. This keeps the result exact for arbitrarily large integers,
    where a float quotient could round across an integer boundary.

    Raises ZeroDivisionError when ``b`` is zero.
    """
    return a - (a // b) * b

DIM = 1000
HDIM = 500
shape = (DIM, DIM)

# Compare tir.floormod against the pure-Python reference ``floormod`` on a
# DIM x DIM grid. The dividend HDIM - i takes both signs. ``c`` uses
# positive divisors j + 1 and ``d`` uses negative divisors -(j + 1), so
# every sign combination is exercised.
c = te.compute(shape, lambda i, j: tir.floormod(HDIM - i, j + 1))
d = te.compute(shape, lambda i, j: tir.floormod(HDIM - i, -(j + 1)))
s = te.create_schedule([c.op])
s2 = te.create_schedule([d.op])
f = tvm.build(s, [c])
f2 = tvm.build(s2, [d])

c_tvm = tvm.nd.array(np.zeros(shape=shape, dtype='int32'))
d_tvm = tvm.nd.array(np.zeros(shape=shape, dtype='int32'))
f(c_tvm)
# Run the negative-divisor kernel as well so that its output is checked.
f2(d_tvm)
out = c_tvm.asnumpy()
out2 = d_tvm.asnumpy()
for i in range(DIM):
    for j in range(DIM):
        res = out[i][j]
        res2 = floormod(HDIM - i, j + 1)
        assert res == res2, (i, j, res, res2)
        res_neg = out2[i][j]
        res2_neg = floormod(HDIM - i, -(j + 1))
        assert res_neg == res2_neg, (i, j, res_neg, res2_neg)
def test_complex():
    """Detect and subspace-divide a multi-level fuse/split iterator map.

    The indices i0 and i1 are built by fusing, splitting and re-fusing the
    input iterators, with several intermediate extents narrowed by
    predicates. The test checks three things:

    * detect_iter_map reproduces the hand-built IterSumExpr structure;
    * inconsistent predicates make detection return an empty result;
    * subspace_divide separates the indices into outer/inner parts that
      are each detectable on their own.
    """
    n0 = create_iter("n0", 2)
    n1 = create_iter("n1", 4)

    # m0 = n0 * 4 + n1, narrowed to extent 6 (see the m0 < 6 predicate).
    m0 = ifuse([n0, n1], 6)
    m1 = create_iter("m1", 3)

    l0 = create_iter("l0", 4)
    l1 = create_iter("l1", 8)
    # l2 = m0 * 3 + m1, narrowed to extent 16.
    l2 = ifuse([m0, m1], 16)
    l3 = create_iter("l3", 32)

    k0, k4 = isplit(l0, 2)
    k1, k5 = isplit(l1, 2)
    k2, k6 = isplit(l2, 4)
    k3, k7 = isplit(l3, 4)

    j0 = ifuse([k0, k1], 7)
    j1 = ifuse([k2, k3])
    j2 = ifuse([k4, k5])
    j3 = ifuse([k6, k7], 15)

    i0 = ifuse([j0, j1], 200)
    i1 = ifuse([j2, j3], 50)

    # The predicates bound the outputs and narrow the fused extents
    # (m0 < 6, l2 < 16, j0 < 7, j3 < 15).
    res = tvm.arith.detect_iter_map(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
    )
    assert len(res) == 2

    # Build the expected detection result by hand.
    # NOTE(review): IterSplitExpr arguments are assumed to be
    # (source, lower_factor, extent, scale); confirm against tvm.arith.
    n0_mark = tvm.arith.IterMark(n0[0], n0[1])
    n1_mark = tvm.arith.IterMark(n1[0], n1[1])
    l0_mark = tvm.arith.IterMark(l0[0], l0[1])
    l1_mark = tvm.arith.IterMark(l1[0], l1[1])
    m1_mark = tvm.arith.IterMark(m1[0], m1[1])
    l3_mark = tvm.arith.IterMark(l3[0], l3[1])

    m0_expr = tvm.arith.IterSumExpr(
        [
            tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4),
            tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1),
        ],
        0,
    )
    m0_mark = tvm.arith.IterMark(m0_expr, 6)
    l2_expr = tvm.arith.IterSumExpr(
        [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)],
        0,
    )
    l2_mark = tvm.arith.IterMark(l2_expr, 16)
    k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4)
    k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1)
    k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8)
    k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1)
    k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30)
    k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15)
    k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4)
    k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1)

    j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0)
    j0_mark = tvm.arith.IterMark(j0_expr, 7)
    i0_expr = tvm.arith.IterSumExpr(
        [tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0
    )

    j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0)
    j3_mark = tvm.arith.IterMark(j3_expr, 15)
    i1_expr = tvm.arith.IterSumExpr(
        [k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0
    )

    i0_mark = tvm.arith.IterMark(i0_expr, i0[1])
    i1_mark = tvm.arith.IterMark(i1_expr, i1[1])

    i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0)
    i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0)

    tvm.ir.assert_structural_equal(i0_final, res[0])
    tvm.ir.assert_structural_equal(i1_final, res[1])

    # wrong constraint: m0 < 9 and j3 < 14 disagree with the structure
    # above, so detection is expected to return an empty result.
    res = tvm.arith.detect_iter_map(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14),
    )
    assert len(res) == 0

    # subspace_division w.r.t. {n0, n1, m1, l3}: after convert_division,
    # res[k] is the (outer, inner) pair for the k-th index and res[2] holds
    # the (outer, inner) predicates.
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0]],
        var_dom([l0, l1, n0, n1, m1, l3]),
        [n0[0], n1[0], m1[0], l3[0]],
        tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
    )
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2))
    tvm.ir.assert_structural_equal(
        res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4)
    )
    tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2)))
    tvm.ir.assert_structural_equal(
        res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4))
    )
    tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7)
    tvm.ir.assert_structural_equal(
        res[2][1],
        tvm.tir.all(
            n0[0] * 4 + n1[0] < 6,
            (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16,
            floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15,
        ),
    )

    # Each half of the division must itself be a detectable iterator map
    # over its own variables.
    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1])
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1]))
    assert len(res2) == 2
def test_subspace_division():
    """Check tvm.arith.subspace_divide on simple and compound iterator maps.

    Each case divides the indices with respect to a set of sub-space
    variables. After convert_division, res[k] is the (outer, inner) pair
    for the k-th index, and the trailing pair holds the (outer, inner)
    predicates. Cases that assert an empty result expect subspace_divide
    to reject the division.
    """
    x = tvm.tir.Var("x", "int32")
    y = tvm.tir.Var("y", "int32")
    z = tvm.tir.Var("z", "int32")
    c = tvm.tir.SizeVar("c", "int32")

    # simple 1.1: divide w.r.t. {x}; the constant offset c ends up in the
    # inner part.
    res = tvm.arith.subspace_divide(
        [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x]
    )
    res = convert_division(res)
    assert len(res) == 2
    tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
    tvm.ir.assert_structural_equal(res[0][1], x + c)

    # simple 1.2: same division with an outer-only predicate
    res = tvm.arith.subspace_divide(
        [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18
    )
    res = convert_division(res)
    assert len(res) == 2
    tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
    tvm.ir.assert_structural_equal(res[0][1], x + c)
    tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18)
    tvm.ir.assert_structural_equal(res[1][1], True)

    # compound 1: k0 fuses i0 with the outer part of j0 (split by 4);
    # k1 fuses the inner part of j0 with i3.
    i0 = create_iter("i0", 4)
    j0 = create_iter("j0", 8)
    i3 = create_iter("i3", 2)

    i1, i2 = isplit(j0, 4)
    k0 = ifuse([i0, i1])
    k1 = ifuse([i2, i3])

    # compound 1.1: sub-space {i3}
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]])
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][1], i3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0]))
    assert len(res2) == 2

    # compound 1.2: sub-space {j0, i3}
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]])
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], i0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0]))
    assert len(res2) == 2

    # compound 1.3: sub-space {i0, i3}; expected to be rejected
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]])
    res = convert_division(res)
    assert len(res) == 0

    # compound 1.4: sub-space {i3} with the predicate k0 < 7
    res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7)
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][1], i3[0])
    tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7)
    tvm.ir.assert_structural_equal(res[2][1], True)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0]))
    assert len(res2) == 2

    # compound 1.5: sub-space {j0, i3} with the predicate k1 < 7
    res = tvm.arith.subspace_divide(
        [k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7
    )
    res = convert_division(res)
    assert len(res) == 3
    tvm.ir.assert_structural_equal(res[0][0], i0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
    tvm.ir.assert_structural_equal(res[2][0], True)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3]))
    assert len(res1) == 2
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0]))
    assert len(res2) == 2

    # compound 1.6: sub-space {i3} with predicates on both k0 and k1;
    # expected to be rejected
    res = tvm.arith.subspace_divide(
        [k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7)
    )
    res = convert_division(res)
    assert len(res) == 0

    # compound 2: k0 = l0 * 6 + l1 is split by 3 and the outer part again
    # by 2; i0 fuses j0 with the outermost piece, and i2 fuses the
    # innermost piece with j3.
    j0 = create_iter("j0", 4)
    l0 = create_iter("l0", 2)
    l1 = create_iter("l1", 6)
    j3 = create_iter("j3", 3)

    k0 = ifuse([l0, l1])
    i1, j2 = isplit(k0, 3)
    j1, i1 = isplit(i1, 2)
    i0 = ifuse([j0, j1])
    i2 = ifuse([j2, j3])

    # compound 2.1: sub-space {l1, j3}
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0]))
    assert len(res2) == 3

    # compound 2.2: sub-space {l0, l1, j3}
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], j0[0])
    tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6))
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0])

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0]))
    assert len(res2) == 3

    # compound 2.3: sub-space {l0, j3}; expected to be rejected
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]]
    )
    res = convert_division(res)
    assert len(res) == 0

    # compound 2.4: sub-space {l1, j3} with predicates i0 < 7 and i2 < 8
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]],
        var_dom([j0, l0, l1, j3]),
        [l1[0], j3[0]],
        tvm.tir.all(i0[0] < 7, i2[0] < 8),
    )
    res = convert_division(res)
    assert len(res) == 4
    tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
    tvm.ir.assert_structural_equal(res[0][1], 0)
    tvm.ir.assert_structural_equal(res[1][0], 0)
    tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
    tvm.ir.assert_structural_equal(res[2][0], 0)
    tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
    tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7)
    tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8)

    res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3]))
    assert len(res1) == 3
    res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0]))
    assert len(res2) == 3

    # compound 2.5: sub-space {j3} with the predicate i2 < 8; expected to be
    # rejected
    res = tvm.arith.subspace_divide(
        [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8
    )
    res = convert_division(res)
    assert len(res) == 0
Exemple #16
0
 def input_example(i0, i1, i2, i3):
     """Map (i0, i1, i2, i3) to (i3 // 32, i2 // 2, i2 % 2, i3 % 32).

     i0 and i1 are accepted but not used.
     """
     return (
         floordiv(i3, 32),
         floordiv(i2, 2),
         floormod(i2, 2),
         floormod(i3, 32),
     )
Exemple #17
0
 def apply(lhs, rhs):
     """Return tir.floormod of the int-coerced operands.

     NOTE(review): _suppress_zero presumably guards the divisor against
     zero; confirm against its definition.
     """
     # The tuple is built left to right, so lhs is coerced before rhs.
     operands = (_force_int(lhs), _suppress_zero(_force_int(rhs)))
     return tir.floormod(*operands)