Ejemplo n.º 1
0
def test_range_infer():
    """Infer the range of ``x + y + t`` when only ``x`` and ``y`` are bounded.

    ``t`` has no entry in the bounds map, so it stays symbolic at both ends
    of the inferred range.
    """
    var_x, var_y, var_t = tvm.Var('x'), tvm.Var('y'), tvm.Var('t')
    bounds = {
        var_x: tvm.Range(10, 20),
        var_y: tvm.Range(10, 11),
    }
    inferred = tvm.infer_range(var_x + var_y + var_t, bounds)
    assert str(inferred) == "((t0 + 20), (t0 + 30))"
Ejemplo n.º 2
0
def test_simplify_mod():
    """CanonicalSimplify may fold ``% 16`` only for provably non-negative dividends.

    A loop nest storing ``A[i] = A[(j*32 + i + 1) % 16]`` is simplified and
    its load index compared against ``(1 + i) % 16``. Standalone
    expressions then check that ``j`` plus a multiple of 16, taken
    ``% 16``, collapses to ``j`` only once every free variable has a
    non-negative range.
    """
    simplify = tvm.ir_pass.CanonicalSimplify
    builder = tvm.ir_builder.create()
    n = tvm.var('n')
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, 10, name="j") as j:
        with builder.for_range(0, 16, name="i") as i:
            buf[i] = buf[(j * 32 + i + 1) % 16]
    stmt = simplify(builder.get())
    # Walk the nest (presumably For j -> For i -> Store) to the load's index.
    diff = simplify(stmt.body.body.value.index - (1 + i) % 16)
    assert diff.value == 0

    # j (the loop variable, still bound after the loop) is used as a free
    # variable from here on. Without a range it may be negative, so
    # (j + 16) % 16 must not be rewritten to j.
    assert simplify((j + 16) % 16) != j
    assert simplify((j + 16) % 16, {j: tvm.Range(0, 6)}) == j

    # Bounding j alone is not enough for (j + n*32): n must be bounded too.
    assert simplify((j + n * 32) % 16, {j: tvm.Range(0, 6)}) != j
    assert simplify((j + n * 32) % 16, {
        j: tvm.Range(0, 6),
        n: tvm.Range(0, 10),
    }) == j
Ejemplo n.º 3
0
def test_modular():
    """Split a flattened index ``(ry + y) * 16 + (rx + x)`` by 16.

    With the given ranges ``rx + x`` stays below 16 (assuming Range ends are
    exclusive), so ``idx // 16`` should reduce to ``ry + y`` and
    ``idx % 16`` to ``rx + x``.
    """
    simplify = tvm.ir_pass.CanonicalSimplify
    rx, ry, y, x = (tvm.var(name) for name in ("rx", "ry", "y", "x"))

    def _span(upper):
        # Range starting at the constant 0 and ending at ``upper``.
        return tvm.Range(tvm.const(0), tvm.const(upper))

    vmap = {rx: _span(3), ry: _span(3), y: _span(2), x: _span(14)}
    idx = ry * 16 + rx + y * 16 + x
    quotient = simplify(idx // 16, vmap)
    remainder = simplify(idx % 16, vmap)
    assert simplify(quotient - (ry + y)).value == 0
    assert simplify(remainder - (rx + x)).value == 0
Ejemplo n.º 4
0
def test_tensor_dom_infer():
    """Infer T's domain from a requested domain of C, where C copies T.

    T reduces over A's second axis (contracting A and B along it), and C
    reads T element-wise. The first axis of the domain inferred for T is
    therefore expected to equal the first axis requested for C.
    """
    lhs = tvm.Tensor(2, name='A')
    rhs = tvm.Tensor(2, name='B')
    red = tvm.RDom(tvm.Range(lhs.shape[1]))
    reduced = tvm.Tensor(
        2,
        lambda i, j: tvm.reduce_sum(
            lhs(i, red.index[0]) * rhs(j, red.index[0]), rdom=red),
        shape=(lhs.shape[0], rhs.shape[0]))
    copied = tvm.Tensor(2, lambda i, j: reduced(i, j),
                        shape=(lhs.shape[0], rhs.shape[0]))

    requested = [tvm.Range(0, 10), tvm.Range(1, 11)]
    inferred = copied.infer_input_domains(requested, inputs=[reduced])[reduced]
    assert reduced.is_rtensor
    assert str(inferred[0]) == "(0, 10)"
Ejemplo n.º 5
0
def test_simplify_div():
    """Check when CanonicalSimplify may split an integer division.

    ``(c + m*x) / d`` may be rewritten as ``c/d + (m/d)*x`` only when that
    agrees with the division result for every admissible ``x``. These cases
    probe the rule with and without a range for ``x``.
    """
    x = tvm.var('x')
    simplify = tvm.ir_pass.CanonicalSimplify

    # 16 + 48*x is 16 * (1 + 3*x), so the split holds for every x.
    assert simplify((16+48*x)/16 - (1 + (x*3))).value == 0

    # For arbitrary x the dividend 17 + 48*x may be negative, and then
    # (17+48*x)/16 != 1+3*x, so the division has to stay in place.
    quot = simplify((17+48*x)/16)
    assert quot.b.value == 16
    assert simplify(quot.a - (17 + 48*x)).value == 0
    # Once x >= 0 the dividend is non-negative and the split is valid.
    assert simplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0

    # An expression that cannot be simplified whatever x is.
    quot = simplify((17+47*x)/16, {x: tvm.Range(0,10)})
    assert quot.b.value == 16
    assert simplify(quot.a - (17+47*x)).value == 0

    # For x >= 4, 8*x - 17 == 8*(x - 3) + 7 with a non-negative dividend.
    quot = simplify((8*x - 17)/8, {x : tvm.Range(4,10)})
    assert simplify(quot - (x-3)).value == 0
Ejemplo n.º 6
0
def test_tensor_reduce():
    """Build C[i, j] = sum over k of A[i, k] * B[j, k] and print it.

    The product is first materialised as a 3-D tensor, then reduced over
    its last axis. This is a smoke test: it only prints the resulting
    expression and asserts nothing.
    """
    lhs = tvm.Tensor(2, name='A')
    rhs = tvm.Tensor(2, name='B')
    products = tvm.Tensor(3, lambda i, j, k: lhs(i, k) * rhs(j, k),
                          shape=(lhs.shape[0], rhs.shape[0], lhs.shape[1]))
    red = tvm.RDom(tvm.Range(lhs.shape[1]))
    result = tvm.Tensor(
        2,
        lambda i, j: tvm.reduce_sum(products(i, j, red.index[0]), rdom=red),
        shape=(lhs.shape[0], rhs.shape[0]))
    print(tvm.format_str(result.expr))
Ejemplo n.º 7
0
def test_split_dom_infer():
    """Chain Split transforms and check the inner domains they infer.

    ``Split(axis, factor)`` narrows ``axis`` to a window of ``factor``
    elements. The window starts at ``loop_index * factor``, offset by the
    axis's current start. Splits compose, and a split can also be applied
    to a reduction domain.
    """
    A = tvm.Tensor(2, name='A')
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    split_rows_64 = tvm.Split(0, 64)
    split_cols_64 = tvm.Split(1, 64)
    split_rows_8 = tvm.Split(0, 8)

    full = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
    after_rows = split_rows_64.infer_inner_domain(full)
    after_cols = split_cols_64.infer_inner_domain(after_rows)
    after_inner = split_rows_8.infer_inner_domain(after_cols)
    on_rdom = split_rows_8.infer_inner_domain(rd)

    outer_row, col, inner_row = (
        split.loop_index.name
        for split in (split_rows_64, split_cols_64, split_rows_8))

    assert str(after_rows) == (
        "[(({0} * 64), (({0} * 64) + 64)), (0, A_shape_1_0)]"
        .format(outer_row))
    assert str(after_cols) == (
        "[(({0} * 64), (({0} * 64) + 64)), (({1} * 64), (({1} * 64) + 64))]"
        .format(outer_row, col))
    assert str(after_inner) == (
        "[((({0} * 64) + ({2} * 8)), ((({0} * 64) + ({2} * 8)) + 8)), "
        "(({1} * 64), (({1} * 64) + 64))]"
        .format(outer_row, col, inner_row))
    assert str(on_rdom) == (
        "[(({0} * 8), (({0} * 8) + 8))]".format(inner_row))
Ejemplo n.º 8
0
def test_simplify_mod():
    """Not yet working, mock design.

    Intended behaviour: multiples of 16 inside a ``% 16`` index are dropped
    by CanonicalSimplify.

    NOTE(review): this file defines ``test_simplify_mod`` twice. This later
    definition replaces the earlier one when the module is imported, so
    only this version is visible to a test runner.
    """
    ib = tvm.ir_builder.create()
    n = tvm.var('n')
    j = tvm.var('j')
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 16, name="i") as i:
        # (n*4 + j*2)*8 == 32*n + 16*j, a multiple of 16.
        A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
    body = ib.get()
    stmt = tvm.ir_pass.CanonicalSimplify(body)
    # NOTE(review): n and j are unbounded here, so under truncated modulo
    # the dividend may be negative. This expectation presumably relies on
    # future non-negativity reasoning in the pass; confirm before enabling.
    diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
    assert diff.value == 0
    # Both j and n must be bounded below. If j + n*32 could be negative,
    # (j + n*32) % 16 would not equal j under truncated modulo
    # (e.g. j = 1, n = -1 gives -31 % 16 == -15).
    index = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
    assert index == j
Ejemplo n.º 9
0
def test_bound():
    """``m % 10`` should simplify to ``m`` once ``m`` is known to lie in [0, 10)."""
    m = tvm.var('m')
    lower = tvm.const(0, "int32")
    upper = tvm.const(10, "int32")
    bounds = tvm.convert({m: tvm.Range(lower, upper)})
    assert tvm.ir_pass.Simplify(m % 10, bounds) == m
Ejemplo n.º 10
0
def test_simplify_combiner():
    """Check that Simplify prunes and rewrites commutative-reducer combiners.

    Covered behaviours:

    * a ``Select`` inside a combiner is resolved using ``vrange``;
    * no-op terms and constant-foldable identity elements are simplified;
    * selecting one output of a multi-output reducer keeps exactly that
      component plus the components it (transitively) depends on;
    * components containing calls with side effects are not removed.
    """
    dummy = tvm.var('dummy')

    # Plain product reducer with identity 1, used as a reference result.
    prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.const(1, t0))

    # Sum when dummy < 0, product otherwise.
    sum_or_prod = comm_reducer(
        lambda x, y: tvm.expr.Select(dummy < 0, x + y, x * y),
        lambda t0: tvm.expr.Select(dummy < 0, tvm.const(0, t0),
                                   tvm.const(1, t0)))

    # Independent (sum, product) pair. The product identity is written as
    # 5 - 4 and is typed with t1, the dtype of the second component.
    sum_and_prod = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1]),
        lambda t0, t1: (tvm.const(0, t0),
                        tvm.const(5, t1) - tvm.const(4, t1)))

    # Like sum_and_prod, but the product combiner carries no-op terms
    # (0 * x[0] + y[0] - y[0]) and the sum identity is written as 5 - 5.
    sum_and_prod2 = comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1] + 0 * x[0] + y[0] - y[0]),
        lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0), tvm.const(1, t1)))

    # Components read by each combiner:
    #   0 -> {0}, 1 -> {0, 1}, 2 -> {0, 2}, 3 -> {1, 2}, 4 -> {} (constant).
    some_reducer1 = comm_reducer(
        lambda x, y: (x[0] + y[0],
                      x[0] + y[0] + x[1] + y[1],
                      x[0] * y[2] + y[0] * x[2],
                      x[1] + y[2],
                      4.0),
        lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), tvm.const(1, t1),
                                    tvm.const(2, t2), tvm.const(3, t3),
                                    tvm.const(4, t4)))

    k = tvm.reduce_axis((0, 10), name="k")
    A = tvm.placeholder((10, ), name='A')

    # Test that SimplifyCombiner makes use of vranges
    vrange = {dummy: tvm.Range(-10, -5)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
    vrange = {dummy: tvm.Range(5, 10)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))

    # Taking one output of an independent pair drops the other component.
    assert Equal(Simplify(sum_and_prod((A[k], A[10 - k]), k)[0]),
                 tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod((A[k], A[10 - k]), k)[1]),
                 prod(A[10 - k], k))

    # No-op terms and foldable identities must not block that pruning.
    assert Equal(Simplify(sum_and_prod2((A[k], A[10 - k]), k)[0]),
                 tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod2((A[k], A[10 - k]), k)[1]),
                 prod(A[10 - k], k))

    reference_simplified_sources = [[A[0]], [A[0], A[1]], [A[0], A[2]],
                                    [A[0], A[1], A[2], A[3]], [A[4]]]
    for j, expected in enumerate(reference_simplified_sources):
        # Here we use the j-th component of the result, so only it and the
        # components it depends on are left.
        simplified = Simplify(
            some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])

        # zip() alone would silently ignore extra or missing components,
        # so check the count before comparing element-wise.
        assert len(simplified.source) == len(expected)
        for lhs, rhs in zip(simplified.source, expected):
            assert Equal(lhs, rhs)

    # Test that components with side effects are not removed
    def side_effect(*xs):
        return tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic,
                             None, 0)

    assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10 - k])), k)[0]),
                 sum_and_prod((A[k], side_effect(A[10 - k])), k)[0])
    assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0]),
                 tvm.sum(side_effect(A[k]), k))