def test_range_infer():
    """Infer the range of ``x + y + t`` when only ``x`` and ``y`` are bounded.

    NOTE(review): this test uses the capitalized ``tvm.Var`` and
    ``tvm.infer_range`` API, while other tests in this file use ``tvm.var``;
    confirm the installed TVM version still provides these names.
    """
    x = tvm.Var('x')
    y = tvm.Var('y')
    t = tvm.Var('t')
    total = x + y + t
    known_ranges = {x: tvm.Range(10, 20), y: tvm.Range(10, 11)}
    inferred = tvm.infer_range(total, known_ranges)
    # t has no range, so it stays symbolic in both bounds (it is printed as
    # ``t0`` -- presumably a naming suffix added by this TVM version; confirm).
    assert str(inferred) == "((t0 + 20), (t0 + 30))"
def test_simplify_mod():
    """Check that CanonicalSimplify folds ``% 16`` only when the dividend is provably non-negative.

    NOTE(review): a second ``test_simplify_mod`` is defined later in this
    file. The later ``def`` rebinds the module-level name, so pytest only
    collects that one and this test never runs. Consider renaming one of them.
    """
    builder = tvm.ir_builder.create()
    n = tvm.var('n')
    A = builder.pointer("float32", name="A")
    with builder.for_range(0, 10, name="j") as j:
        with builder.for_range(0, 16, name="i") as i:
            A[i] = A[(j * 32 + i + 1) % 16]
    # The loop variables j and i remain bound after the with-blocks and are
    # reused in the standalone expressions below.
    stmt = tvm.ir_pass.CanonicalSimplify(builder.get())

    # Inside the loops j is in [0, 10), so the j * 32 term drops out of % 16.
    load_index = stmt.body.body.value.index
    assert tvm.ir_pass.CanonicalSimplify(load_index - (1 + i) % 16).value == 0

    # If we can't prove that j is non-negative, we can't prove that
    # (j + 16) % 16 is j % 16.
    unbounded = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
    assert unbounded != j
    bounded = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
    assert bounded == j

    # If we can't prove that j + n * 32 is non-negative, we can't prove that
    # (j + n * 32) % 16 is j % 16: bounding j alone is not enough.
    only_j_bounded = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6)})
    assert only_j_bounded != j
    both_bounded = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
    assert both_bounded == j
def test_modular():
    """Split ``ry*16 + rx + y*16 + x`` into quotient and remainder by 16 using variable bounds."""
    rx = tvm.var("rx")
    ry = tvm.var("ry")
    y = tvm.var("y")
    x = tvm.var("x")

    def from_zero(extent):
        # Range [0, extent) with explicit constant endpoints.
        return tvm.Range(tvm.const(0), tvm.const(extent))

    # rx + x is at most 2 + 13 = 15 < 16, so the "rx + x" part never carries
    # into the multiple-of-16 part.
    vmap = {rx: from_zero(3), ry: from_zero(3), y: from_zero(2), x: from_zero(14)}

    idx = ry * 16 + rx + y * 16 + x
    quotient = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
    remainder = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
    assert tvm.ir_pass.CanonicalSimplify(quotient - (ry + y)).value == 0
    assert tvm.ir_pass.CanonicalSimplify(remainder - (rx + x)).value == 0
def test_tensor_dom_infer():
    """Infer the input domain of reduction tensor ``T`` from the domain of its consumer ``C``."""
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    # T(i, j) = sum over rd of A(i, rd) * B(j, rd).
    T = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(
        A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
        shape=(A.shape[0], B.shape[0]))
    # C reads T element-for-element.
    C = tvm.Tensor(2, lambda i, j: T(i, j), shape=(A.shape[0], B.shape[0]))

    consumer_domain = [tvm.Range(0, 10), tvm.Range(1, 11)]
    input_domains = C.infer_input_domains(consumer_domain, inputs=[T])
    tdom = input_domains[T]

    assert T.is_rtensor
    # Identity access, so T's first-axis domain matches C's.
    assert str(tdom[0]) == "(0, 10)"
def test_simplify_div():
    """Check CanonicalSimplify on division by a constant, with and without variable ranges."""
    x = tvm.var('x')
    simplify = tvm.ir_pass.CanonicalSimplify

    # 16 + 48*x is an exact multiple of 16, so the division folds for any x.
    assert simplify((16+48*x)/16 - (1 + (x*3))).value == 0

    # (17+48*x)/16 is not simplifiable for arbitrary x: when 17+48*x < 0,
    # (17+48*x)/16 != 1+3*x. The division node must be kept.
    kept = simplify((17+48*x)/16)
    assert kept.b.value == 16
    assert simplify(kept.a - (17 + 48*x)).value == 0

    # With x >= 0 the numerator is non-negative, and the division does fold.
    assert simplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0

    # 47 is not a multiple of 16, so this division cannot be rewritten as a
    # linear expression in x even when x is bounded.
    kept = simplify((17+47*x)/16, {x: tvm.Range(0,10)})
    assert kept.b.value == 16
    assert simplify(kept.a - (17+47*x)).value == 0

    # 8*x - 17 == 8*(x - 3) + 7, and it is >= 15 for x >= 4, so this folds to x - 3.
    folded = simplify((8*x - 17)/8, {x : tvm.Range(4,10)})
    assert simplify(folded - (x-3)).value == 0
def test_tensor_reduce():
    """Build a sum-reduction over an element-wise product tensor and print its expression.

    This is a smoke test: it only checks that construction and formatting run.
    """
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    # T(i, j, k) = A(i, k) * B(j, k)
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    # C(i, j) = sum over rd of T(i, j, rd)
    C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
                   shape=(A.shape[0], B.shape[0]))
    formatted = tvm.format_str(C.expr)
    print(formatted)
def test_split_dom_infer():
    """Check inner-domain inference through successive splits and on an RDom."""
    A = tvm.Tensor(2, name='A')
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    split1 = tvm.Split(0, 64)
    split2 = tvm.Split(1, 64)
    split3 = tvm.Split(0, 8)
    full_domain = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]

    # Each split is applied to the domain produced by the previous one.
    dom1 = split1.infer_inner_domain(full_domain)
    dom2 = split2.infer_inner_domain(dom1)
    dom3 = split3.infer_inner_domain(dom2)
    dom4 = split3.infer_inner_domain(rd)

    i1, i2, i3 = (s.loop_index.name for s in (split1, split2, split3))

    def tile(index, factor):
        # Printed form of the range [index*factor, index*factor + factor).
        return "((%s * %d), ((%s * %d) + %d))" % (index, factor, index, factor, factor)

    nested_start = "((%s * 64) + (%s * 8))" % (i1, i3)
    nested_tile = "(%s, (%s + 8))" % (nested_start, nested_start)

    assert str(dom1) == "[%s, (0, A_shape_1_0)]" % tile(i1, 64)
    assert str(dom2) == "[%s, %s]" % (tile(i1, 64), tile(i2, 64))
    assert str(dom3) == "[%s, %s]" % (nested_tile, tile(i2, 64))
    assert str(dom4) == "[%s]" % tile(i3, 8)
def test_simplify_mod():
    """Not yet working, mock design.

    Intended to check that the load index ``((n*4 + j*2)*8 + i + 1) % 16``
    canonicalizes to ``(1 + i) % 16``.

    NOTE(review): this redefines the earlier ``test_simplify_mod`` in this
    file, so the earlier test is shadowed and never collected by pytest.
    The final assertion expects ``(j + n * 32) % 16`` to become ``j`` with
    only ``j`` bounded. The earlier test asserts exactly the opposite for the
    same expression and range.
    """
    builder = tvm.ir_builder.create()
    n = tvm.var('n')
    j = tvm.var('j')
    A = builder.pointer("float32", name="A")
    with builder.for_range(0, 16, name="i") as i:
        A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
    stmt = tvm.ir_pass.CanonicalSimplify(builder.get())

    load_index = stmt.body.value.index
    assert tvm.ir_pass.CanonicalSimplify(load_index - (1 + i) % 16).value == 0

    folded = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6)})
    assert folded == j
def test_bound():
    """Simplify should drop ``% 10`` when ``m`` is known to lie in [0, 10)."""
    m = tvm.var('m')
    lower = tvm.const(0, "int32")
    upper = tvm.const(10, "int32")
    vrange = tvm.convert({m: tvm.Range(lower, upper)})
    simplified = tvm.ir_pass.Simplify(m % 10, vrange)
    assert simplified == m
def test_simplify_combiner():
    """Check how Simplify rewrites reductions built with ``comm_reducer``.

    The assertions below cover three behaviours:
    * the combiner is specialised using the ranges passed in ``vrange``;
    * tuple components that the selected output does not depend on are
      removed;
    * components whose source has side effects are kept.
    """
    dummy = tvm.var('dummy')

    prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.const(1, t0))

    # Sum when dummy < 0, product otherwise. Which one applies is only known
    # once a range for dummy is supplied.
    sum_or_prod = comm_reducer(
        lambda x, y: tvm.expr.Select(dummy < 0, x + y, x * y),
        lambda t0: tvm.expr.Select(dummy < 0, tvm.const(0, t0),
                                   tvm.const(1, t0)))

    # (sum, product) pair; the product identity 1 is spelled 5 - 4.
    # Component i's identity is built with its own dtype t_i (t1 for the
    # product), as in sum_and_prod2 and some_reducer1. This matters in the
    # side-effect checks below, where the two components can have different
    # dtypes (side_effect() builds an int32 Call).
    sum_and_prod = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1]),
        lambda t0, t1: (tvm.const(0, t0), tvm.const(5, t1) - tvm.const(4, t1)))

    # Component 1 is padded with terms that cancel (0 * x[0], + y[0] - y[0]),
    # so Simplify must drop its apparent dependency on component 0.
    sum_and_prod2 = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1] + 0 * x[0] + y[0] - y[0]),
        lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0), tvm.const(1, t1)))

    # Dependencies between components: 0 -> {0}; 1 -> {0, 1}; 2 -> {0, 2};
    # 3 -> {1, 2} (and transitively 0); 4 is the constant 4.0 and depends on
    # no component.
    some_reducer1 = comm_reducer(
        lambda x, y: (x[0] + y[0],
                      x[0] + y[0] + x[1] + y[1],
                      x[0] * y[2] + y[0] * x[2],
                      x[1] + y[2],
                      4.0),
        lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), tvm.const(1, t1),
                                    tvm.const(2, t2), tvm.const(3, t3),
                                    tvm.const(4, t4)))

    k = tvm.reduce_axis((0, 10), name="k")
    A = tvm.placeholder((10, ), name='A')
    # NOTE(review): A[10 - k] indexes A[10] when k == 0, one past the end of
    # A's shape (10,). In this test the expressions are only simplified and
    # compared, never evaluated, so the assertions are unaffected.

    # Test that SimplifyCombiner makes use of vranges
    vrange = {dummy: tvm.Range(-10, -5)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
    vrange = {dummy: tvm.Range(5, 10)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))

    # Selecting one component of a tuple reduction leaves only that component.
    assert Equal(Simplify(sum_and_prod((A[k], A[10 - k]), k)[0]), tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod((A[k], A[10 - k]), k)[1]), prod(A[10 - k], k))
    assert Equal(Simplify(sum_and_prod2((A[k], A[10 - k]), k)[0]), tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod2((A[k], A[10 - k]), k)[1]), prod(A[10 - k], k))

    reference_simplified_sources = [[A[0]], [A[0], A[1]], [A[0], A[2]],
                                    [A[0], A[1], A[2], A[3]], [A[4]]]
    for j, expected_sources in enumerate(reference_simplified_sources):
        # Here we use the j-th component of the result, so only it and the
        # components it depends on are left.
        simplified = Simplify(
            some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
        # Check that the remaining components are exactly the expected ones.
        # zip() alone would not notice extra or missing components.
        assert len(simplified.source) == len(expected_sources)
        for lhs, rhs in zip(simplified.source, expected_sources):
            assert Equal(lhs, rhs)

    # Test that components with side effects are not removed
    def side_effect(*xs):
        return tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)

    # Component 1 has a side effect, so it survives even though only
    # component 0 is selected; simplification leaves the reduction unchanged.
    assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10 - k])), k)[0]),
                 sum_and_prod((A[k], side_effect(A[10 - k])), k)[0])
    # The side effect is in the selected component itself; the other
    # component is still removed.
    assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0]),
                 tvm.sum(side_effect(A[k]), k))