# Example 1
def test_gemm():
    """Build, check, and time a tiled shared-memory GEMM on each listed GPU target.

    The computation is ``C[ii, jj] = sum_k A[k, jj] * B[k, ii]``, i.e.
    ``C = B^T @ A``, for square ``nn x nn`` inputs.  For every target in the
    list at the bottom, the kernel is built, compared against NumPy, and
    timed; targets whose device does not exist are skipped.
    """
    # graph
    nn = 2048
    # Concrete (non-symbolic) extent, so every dimension below is the constant nn.
    n = tvm.runtime.convert(nn)
    m, l = n, n
    A = te.placeholder((l, n), name="A")
    B = te.placeholder((l, m), name="B")
    k = te.reduce_axis((0, l), name="k")
    C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")

    # schedule
    s = te.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = te.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = te.thread_axis((0, 2), "vthread", name="vy")

    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
    # Schedule for A's shared memory load
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)
    # Schedule for B' shared memory load
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()

    # correctness
    def check_device(device):
        """Build for ``device``, verify against NumPy, then report timing."""
        dev = tvm.device(device, 0)
        if not dev.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        # launch the kernel.
        # Host arrays use the same layouts as the placeholders:
        # A is (l, n), B is (l, m) and C is (m, n).
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(l, n)).astype(A.dtype)
        b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros((m, n), dtype=C.dtype), dev)
        # The kernel is launched twice before the result is checked.
        for _ in range(2):
            f(a, b, c)
        # C = B^T @ A, matching the te.compute definition above.
        tvm.testing.assert_allclose(c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, dev, number=num_runs)
        t = timer_f(a, b, c).mean  # seconds (printed below as t * 1e3 ms)
        GFLOPS = num_flops / t / 1e9
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    # Explicit unrolling is requested for every target except CUDA.
    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        with tvm.transform.PassContext(
            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
        ):
            check_device(device)
# Example 2
def test_expr_constructor():
    """Construct TIR expression nodes directly and check their fields round-trip."""
    x = tvm.tir.Var("xx", "float32")
    assert isinstance(x, tvm.tir.Var)
    assert x.name == "xx"

    x = tvm.tir.Reduce(None, [1], [tvm.tir.IterVar((0, 1), "x", 2)], None, 0)
    assert isinstance(x, tvm.tir.Reduce)
    # Compare against the None singleton by identity (PEP 8).
    assert x.combiner is None
    assert x.value_index == 0

    x = tvm.tir.FloatImm("float32", 1.0)
    assert isinstance(x, tvm.tir.FloatImm)
    assert x.value == 1.0
    assert x.dtype == "float32"

    x = tvm.tir.IntImm("int64", 2)
    assert isinstance(x, tvm.tir.IntImm)
    assert x.value == 2
    assert x.dtype == "int64"

    x = tvm.tir.StringImm("xyza")
    assert isinstance(x, tvm.tir.StringImm)
    assert x.value == "xyza"

    x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1))
    assert isinstance(x, tvm.tir.Cast)
    assert x.dtype == "float32"
    assert x.value.value == 1

    a = tvm.tir.const(1.0, dtype="float32")
    b = te.var("x", dtype="float32")

    # Binary arithmetic / comparison nodes all expose their operands as .a and .b.
    for cls in [
            tvm.tir.Add, tvm.tir.Sub, tvm.tir.Mul, tvm.tir.Div, tvm.tir.Mod,
            tvm.tir.Min, tvm.tir.Max, tvm.tir.LT, tvm.tir.LE, tvm.tir.GT,
            tvm.tir.GE
    ]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
        assert x.b.same_as(b)

    a = tvm.runtime.convert(te.var("x") > 1)
    b = tvm.runtime.convert(te.var("x") == 1)

    for cls in [tvm.tir.And, tvm.tir.Or]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
        assert x.b.same_as(b)

    x = tvm.tir.Not(a)
    assert isinstance(x, tvm.tir.Not)
    assert x.a == a

    x = tvm.tir.Select(a, a, b)
    assert isinstance(x, tvm.tir.Select)
    assert x.true_value == a
    assert x.false_value == b
    assert x.condition == a

    buffer_var = te.var("x", dtype="handle")
    x = tvm.tir.Load("float32", buffer_var, 1, a)
    assert isinstance(x, tvm.tir.Load)
    assert x.dtype == "float32"
    assert x.buffer_var == buffer_var
    assert x.index.value == 1
    assert x.predicate == a

    x = tvm.tir.Ramp(1, 2, 10)
    assert isinstance(x, tvm.tir.Ramp)
    assert x.base.value == 1
    assert x.stride.value == 2
    assert x.lanes == 10

    x = tvm.tir.Broadcast(a, 10)
    assert isinstance(x, tvm.tir.Broadcast)
    assert x.value == a
    assert x.lanes == 10

    x = tvm.tir.Shuffle([a], [0])
    assert isinstance(x, tvm.tir.Shuffle)
    assert x.vectors[0] == a
    assert x.indices[0].value == 0

    x = tvm.tir.Call("float32", "tir.call_extern",
                     [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern)
    assert isinstance(x, tvm.tir.Call)
    assert x.dtype == "float32"
    assert x.op.name == "tir.call_extern"
    assert x.args[1] == a
    assert x.call_type == tvm.tir.Call.Extern

    v = te.var("aa")
    x = tvm.tir.Let(v, 1, v)
    assert x.var == v
    assert x.value.value == 1
    assert x.body == v
def test_select():
    """Verify, via IntSetChecker, the integer-set bound of a Select expression.

    With x ranging over [0, 10], ``Select(x > 0, x - 1, x + 1)`` is expected
    to be bounded by (-1, 11).
    """
    checker = IntSetChecker()
    x = te.var("x")
    y = te.var("y")
    expr = tvm.tir.Select(x > 0, x - 1, x + 1)
    dom_map = {x: tvm.arith.IntervalSet(0, 10)}
    expected = (-1, 11)
    checker.verify(expr, dom_map, expected)
# Example 4
# TVM adopts tensor semantics, with each intermediate result
# represented as a multi-dimensional array. The user needs to describe
# the computation rule that generates the tensors.
#
# We first define a symbolic variable n to represent the shape.
# We then define two placeholder Tensors, A and B, with the given shape (n,).
#
# We then describe the result tensor C, with a compute operation.  The
# compute function takes the shape of the tensor, as well as a lambda
# function that describes the computation rule for each position of
# the tensor.
#
# No computation happens during this phase, as we are only declaring how
# the computation should be done.
#
# Symbolic length shared by both inputs.
n = te.var("n")
# Declare the two 1-D inputs A and B, in that order.
A, B = (te.placeholder((n,), name=tensor_name) for tensor_name in ("A", "B"))
# Element-wise sum; this only declares the computation, nothing runs yet.
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
print(type(C))

######################################################################
# Schedule the Computation
# ------------------------
# While the above lines describe the computation rule, we can compute
# C in many ways since the axis of C can be computed in a data
# parallel manner.  TVM asks the user to provide a description of the
# computation called a schedule.
#
# A schedule is a set of transformations of a computation that rewrite
# the loops of the computation in the program.
def test_basic_operation():
    """Run ``check_grad`` on a range of element-wise, reduction and composite ops.

    Rounding-style ops (floor/ceil/trunc/round) are checked against an
    all-zero expected gradient; sqrt and power use restricted data ranges.
    """
    np.random.seed(0)
    shape = (10, 10)
    k = te.reduce_axis((0, 10), name="k")
    A0 = te.placeholder(shape, name='A0')
    A1 = te.placeholder(shape, name='A1')
    zeros = np.zeros(shape)

    B = te.compute(shape, lambda i, j: A0[i, j], name='B')
    check_grad(B, [A0])

    B = te.compute(shape, lambda i, j: A0[i, j] + A1[i, j], name='B')
    check_grad(B, [A0, A1])

    B = te.compute(shape, lambda i, j: A0[i, j] + A0[j, i], name='B')
    check_grad(B, A0)

    # Piecewise-constant ops: the expected gradient is identically zero.
    B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B')
    check_grad(B, A0, desired_grads=[zeros])

    B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B')
    check_grad(B, A0, desired_grads=[zeros])

    B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B')
    check_grad(B, A0, desired_grads=[zeros])

    B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B')
    check_grad(B, A0, desired_grads=[zeros])

    B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B')
    check_grad(B, A0)

    B = te.compute(
        shape,
        lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))),
        name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: te.sigmoid(A0[i, j] * A0[i, j] * A0[j, i]),
                   name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: te.tanh(A0[i, j] * A0[i, j] * A0[j, i]),
                   name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: te.sqrt(A0[i, j] * A0[i, j] * A0[j, i]),
                   name='B')
    check_grad(B, A0, data_range=(0.1, 10))

    B = te.compute(shape,
                   lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]),
                   name='B')
    check_grad(B, A0, data_range=(-4, 4))

    B = te.compute(shape, lambda i, j: A0[i, j] * A0[j, i], name='B')
    check_grad(B, A0)

    B = te.compute((10, ),
                   lambda i: te.sum(A0[i, k] * A0[k, i], axis=k),
                   name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: te.sum(A0[i, k] * A0[k, i] + 5, axis=k),
                   name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: te.max(A0[i, k] * A0[k, j] + 5, axis=k),
                   name='B')
    check_grad(B, A0)

    B = te.compute(shape,
                   lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]),
                   name='B')
    check_grad(B, [A0, A1])

    B = te.compute(shape,
                   lambda i, j: te.sum(
                       A0[k, k] - A0[te.min(j + k, 9), j] * A0[i, k], axis=k),
                   name='B')
    check_grad(B, A0)

    def fcombine(x, y):
        return x * y

    def fidentity(t0):
        return tvm.tir.const(1, t0)

    # Custom product reducer to exercise gradients through comm_reducer.
    prod = te.comm_reducer(fcombine, fidentity, name='prod')
    B = te.compute((10, 10),
                   lambda i, j: prod(A0[i, k] + A0[k, i], axis=k),
                   name='B')
    check_grad(B, A0)

    X = te.placeholder((10, ), name='X')
    A = te.compute((10, ), lambda i: X[i] + X[9 - i])
    B = te.compute((10, ), lambda i: X[i] * X[9 - i])
    Y = topi.tensordot(A, B, 1)
    check_grad(Y, X)
# Example 6
def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
    """
    Int8 4x4 matrix multiplication and accumulation using a sequence of
    umull -> uadalp -> umull2 -> uadalp instructions (the signed variants
    smull/saddlp are selected when ``in_type`` is "int8"). This function
    takes two arrays of 8-bit data, A[4][K] and B[4][K], and produces
    a 4x4 int32 matrix which is equal to A*B'.

    The inputs are declared in a tiled layout, A[K // 16][m][16] and
    B[K // 16][n][16], so that each 16-element slice of a row can be loaded
    as one 16-lane vector.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < K; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.
        * K is consumed in tiles of 16 (``K // 16`` iterations); presumably K
          must be a multiple of 16 -- TODO confirm with callers.

    Parameters
    ----------
    M : int
        Number of rows of A (and of the output C) that are processed. Rows
        with index >= M (up to 4) are replaced by zero vectors and not stored.
    N : int
        Number of rows of B, i.e. columns of the output C, that are processed.
        Rows with index >= N (up to 4) are replaced by zero vectors.
    K : int
        Reduction length (columns of A and of B).
    unroll : bool
        Unroll the loop accumulation if True
    in_type : str, {'uint8', 'int8'}
        Element type of A and B; selects the unsigned or signed NEON intrinsics.

    Returns
    -------
    intrin : TensorIntrin
        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
    """
    assert in_type in ["uint8", "int8"]
    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
    dtype_vec = in_type + "x16"
    idxm = tvm.tir.indexmod

    k = te.reduce_axis((0, K), "k")
    C = te.compute(
        (te.var("m"), te.var("n")),
        lambda x, y: te.sum(
            A[k // 16, x, idxm(k, 16)].astype("int32") * B[
                k // 16, y, idxm(k, 16)].astype("int32"),
            axis=k,
        ),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        A.shape,
        dtype=in_type,
        name="a_buffer",
        offset_factor=1,
        strides=[te.var("sa_1"), te.var("sa_2"), 1],
    )

    b_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype=in_type,
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("sb_1"), te.var("sb_2"), 1],
    )

    c_buffer = tvm.tir.decl_buffer(C.shape,
                                   dtype="int32",
                                   name="c_buffer",
                                   offset_factor=1,
                                   strides=[te.var("sc"), 1])

    # Intrinsics used in the following algorithm (unsigned or signed flavour
    # chosen from in_type; addp is sign-agnostic).
    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
    addp_intrin = "llvm.aarch64.neon.addp"

    def uadalp(a, b):
        """Pairwise add-long ``b`` (uaddlp/saddlp) and accumulate into ``a``.

        Parameters:
        ----------
        a: int32x4 vector
            Accumulator.
        b: int16x8 vector
            Products to be added pairwise.

        Returns:
        --------
            return a int32x4 vector

        Pseudocode:
        ----------
            a += (b0+b1, b2+b3, b4+b5, b6+b7)
        """

        return a + tvm.tir.call_llvm_pure_intrin("int32x4", uaddlp_intrin,
                                                 tvm.tir.const(1, "uint32"), b)

    def umull(a, b):
        """Widening multiply of the upper halves of ``a`` and ``b``.

        Both operands are narrowed with ``tir.vectorhigh`` before the
        umull/smull intrinsic is applied.

        NOTE(review): an earlier version of this docstring listed lanes 0-7,
        which does not match the use of ``tir.vectorhigh``; confirm the lane
        order. accumulation_loop adds both the umull and umull2 results into
        the same accumulators, and the final addp reduction sums every lane.
        So the result does not depend on which half each helper selects.

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
            return a int16x8 vector

        Pseudocode:
        ----------
            c = (a8*b8, a9*b9, ..., a15*b15)  # assuming vectorhigh = lanes 8..15
        """
        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
        c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin,
                                          tvm.tir.const(2, "uint32"), a_high,
                                          b_high)
        return c

    def umull2(a, b):
        """Widening multiply of the lower halves of ``a`` and ``b``.

        Both operands are narrowed with ``tir.vectorlow`` before the same
        umull/smull intrinsic used by ``umull`` is applied.

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
            return a int16x8 vector

        Pseudocode:
        ----------
            c = (a0*b0, a1*b1, ..., a7*b7)  # assuming vectorlow = lanes 0..7
        """
        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
        c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin,
                                          tvm.tir.const(2, "uint32"), a_low,
                                          b_low)
        return c

    def addp(a, b):
        """Add two vectors in pairs

        Parameters:
        ----------
        a: int32x4 vector
        b: int32x4 vector

        Returns:
        --------
            return a int32x4 vector

        Pseudocode:
        ----------
            c = (a0+a1, a2+a3, b0+b1, b2+b3)
        """
        return tvm.tir.call_llvm_pure_intrin("int32x4", addp_intrin,
                                             tvm.tir.const(2, "uint32"), a, b)

    def accumulation_loop(M, N, ins, acc, tile_idx):
        """Internal tile accumulation. This function
        takes two arrays of 8-bit data, A[tile_idx][4][16] and B[tile_idx][4][16],
        computes a 4x4 product equal to A*B' and accumulates it into the
        16 int32x4 accumulators ``acc`` (acc[4*r + c] holds partial sums of C[r][c]).

        The pseudo code is as follows.

        .. code-block:: c

            void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][16],
                                          int8 B[tile_idx][4][16],
                                          int32 C[4][4]){
                for (int i = 0; i < 4; i++){
                    for (int j = 0; j < 4; j++){
                        for (int k = 0; k < 16; k++){
                            C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k]
                        }
                    }
                }
            }

        Notes:
            * The tiling strategy is picked to maximize register usage.
            * Rows of A with index >= M and rows of B with index >= N are
              replaced by zero vectors instead of being loaded.

        Parameters:
        ----------
        M : int
            Number of total rows of the output matrix
        N : int
            Number of total columns of the output matrix
        ins : list of tvm.tir.buffer
            Input buffers
        acc : tvm.tir.ir_builder.BufferVar
            Bank of register accumulators
        tile_idx : int or loop variable
            Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:].
            A Python int when the loop is unrolled, an ir_builder loop variable
            otherwise. Please note that 0 <= tile_idx < K//16

        """
        a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec)
        a1 = tvm.tir.const(0, "int8x16")
        if M > 1:
            a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec)
        a2 = tvm.tir.const(0, "int8x16")
        if M > 2:
            a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec)
        a3 = tvm.tir.const(0, "int8x16")
        if M > 3:
            a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec)

        b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec)
        b1 = tvm.tir.const(0, "int8x16")
        if N > 1:
            b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec)
        b2 = tvm.tir.const(0, "int8x16")
        if N > 2:
            b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec)
        b3 = tvm.tir.const(0, "int8x16")
        if N > 3:
            b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec)

        # First half: rows a0 and a1 -> acc[0..7]
        # umull (tir.vectorhigh halves) of a0 * {b0,b1,b2,b3}
        d00 = umull(a0, b0)
        d01 = umull(a0, b1)
        d02 = umull(a0, b2)
        d03 = umull(a0, b3)

        # umull (tir.vectorhigh halves) of a1 * {b0,b1,b2,b3}
        d10 = umull(a1, b0)
        d11 = umull(a1, b1)
        d12 = umull(a1, b2)
        d13 = umull(a1, b3)

        # Accumulate
        acc[0] = uadalp(acc[0], d00)
        acc[1] = uadalp(acc[1], d01)
        acc[2] = uadalp(acc[2], d02)
        acc[3] = uadalp(acc[3], d03)
        acc[4] = uadalp(acc[4], d10)
        acc[5] = uadalp(acc[5], d11)
        acc[6] = uadalp(acc[6], d12)
        acc[7] = uadalp(acc[7], d13)

        # umull2 (tir.vectorlow halves) of a0 * {b0,b1,b2,b3}
        d00 = umull2(a0, b0)
        d01 = umull2(a0, b1)
        d02 = umull2(a0, b2)
        d03 = umull2(a0, b3)

        # umull2 (tir.vectorlow halves) of a1 * {b0,b1,b2,b3}
        d10 = umull2(a1, b0)
        d11 = umull2(a1, b1)
        d12 = umull2(a1, b2)
        d13 = umull2(a1, b3)

        # Accumulate again
        acc[0] = uadalp(acc[0], d00)
        acc[1] = uadalp(acc[1], d01)
        acc[2] = uadalp(acc[2], d02)
        acc[3] = uadalp(acc[3], d03)
        acc[4] = uadalp(acc[4], d10)
        acc[5] = uadalp(acc[5], d11)
        acc[6] = uadalp(acc[6], d12)
        acc[7] = uadalp(acc[7], d13)

        # Second half: rows a2 and a3 -> acc[8..15]
        # umull (tir.vectorhigh halves) of a2 * {b0,b1,b2,b3}
        d00 = umull(a2, b0)
        d01 = umull(a2, b1)
        d02 = umull(a2, b2)
        d03 = umull(a2, b3)

        # umull (tir.vectorhigh halves) of a3 * {b0,b1,b2,b3}
        d10 = umull(a3, b0)
        d11 = umull(a3, b1)
        d12 = umull(a3, b2)
        d13 = umull(a3, b3)

        # Accumulate
        acc[8] = uadalp(acc[8], d00)
        acc[9] = uadalp(acc[9], d01)
        acc[10] = uadalp(acc[10], d02)
        acc[11] = uadalp(acc[11], d03)
        acc[12] = uadalp(acc[12], d10)
        acc[13] = uadalp(acc[13], d11)
        acc[14] = uadalp(acc[14], d12)
        acc[15] = uadalp(acc[15], d13)

        # umull2 (tir.vectorlow halves) of a2 * {b0,b1,b2,b3}
        d00 = umull2(a2, b0)
        d01 = umull2(a2, b1)
        d02 = umull2(a2, b2)
        d03 = umull2(a2, b3)

        # umull2 (tir.vectorlow halves) of a3 * {b0,b1,b2,b3}
        d10 = umull2(a3, b0)
        d11 = umull2(a3, b1)
        d12 = umull2(a3, b2)
        d13 = umull2(a3, b3)

        # Accumulate
        acc[8] = uadalp(acc[8], d00)
        acc[9] = uadalp(acc[9], d01)
        acc[10] = uadalp(acc[10], d02)
        acc[11] = uadalp(acc[11], d03)
        acc[12] = uadalp(acc[12], d10)
        acc[13] = uadalp(acc[13], d11)
        acc[14] = uadalp(acc[14], d12)
        acc[15] = uadalp(acc[15], d13)

    def _intrin_func(ins, outs):
        def _instr():
            ib = tvm.tir.ir_builder.create()
            # Allocate a local buffer (possibly translates to registers)
            acc = ib.allocate("int32x4", 16, name="accs", scope="local")
            # NOTE(review): m and n are not used below; the load/store logic
            # is driven by the Python-level M and N instead.
            m = outs[0].shape[0]
            n = outs[0].shape[1]
            # Initialization
            for i in range(0, 16):
                acc[i] = tvm.tir.const(0, "int32x4")

            if unroll:
                # Python-level loop: emits one accumulation_loop per K-tile.
                for i in range(0, int(K // 16)):
                    accumulation_loop(M, N, ins, acc, i)
            else:
                # Emits a single TIR loop over the K-tiles.
                with ib.for_range(0, K // 16, name="i") as i:
                    accumulation_loop(M, N, ins, acc, i)

            # Final accumulations
            # acc[4*r + c] contains the partial accumulations of element C[r][c]
            #
            # In particular:
            # acc[4*r] contains the partial sums of a[r,0:K].*b[0,0:K] -> (a,b,c,d)
            # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1,0:K] -> (e,f,g,h)
            # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2,0:K] -> (i,j,k,l)
            # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3,0:K] -> (m,n,o,p)
            #
            # Please note that 0<= r, c < 4
            #
            # Two rounds of addp reduce each group of four accumulators to a
            # single int32x4 holding row r of C in acc[4*r].

            acc[0] = addp(acc[0], acc[1])  # (a+b, c+d, e+f, g+h)
            acc[1] = addp(acc[2], acc[3])  # (i+j, k+l, m+n, o+p)
            acc[0] = addp(acc[0],
                          acc[1])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[4] = addp(acc[4], acc[5])  # (a+b, c+d, e+f, g+h)
            acc[5] = addp(acc[6], acc[7])  # (i+j, k+l, m+n, o+p)
            acc[4] = addp(acc[4],
                          acc[5])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[8] = addp(acc[8], acc[9])  # (a+b, c+d, e+f, g+h)
            acc[9] = addp(acc[10], acc[11])  # (i+j, k+l, m+n, o+p)
            acc[8] = addp(acc[8],
                          acc[9])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[12] = addp(acc[12], acc[13])  # (a+b, c+d, e+f, g+h)
            acc[13] = addp(acc[14], acc[15])  # (i+j, k+l, m+n, o+p)
            acc[12] = addp(acc[12],
                           acc[13])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            # Store the result. For N < 4 each row is narrowed to N lanes via
            # tir.reinterpret (presumably keeping the first N lanes -- confirm).
            if N > 3:
                out_0 = acc[0]
                out_1 = acc[4]
                out_2 = acc[8]
                out_3 = acc[12]
            elif N > 2:
                out_0 = tvm.tir.call_intrin("int32x3", "tir.reinterpret",
                                            acc[0])
                out_1 = tvm.tir.call_intrin("int32x3", "tir.reinterpret",
                                            acc[4])
                out_2 = tvm.tir.call_intrin("int32x3", "tir.reinterpret",
                                            acc[8])
                out_3 = tvm.tir.call_intrin("int32x3", "tir.reinterpret",
                                            acc[12])
            elif N > 1:
                out_0 = tvm.tir.call_intrin("int32x2", "tir.reinterpret",
                                            acc[0])
                out_1 = tvm.tir.call_intrin("int32x2", "tir.reinterpret",
                                            acc[4])
                out_2 = tvm.tir.call_intrin("int32x2", "tir.reinterpret",
                                            acc[8])
                out_3 = tvm.tir.call_intrin("int32x2", "tir.reinterpret",
                                            acc[12])
            else:
                out_0 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[0])
                out_1 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[4])
                out_2 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[8])
                out_3 = tvm.tir.call_intrin("int32", "tir.reinterpret",
                                            acc[12])

            # Only the first M rows of C are written.
            ib.emit(outs[0].vstore([0, 0], out_0))
            if M > 1:
                ib.emit(outs[0].vstore([1, 0], out_1))
            if M > 2:
                ib.emit(outs[0].vstore([2, 0], out_2))
            if M > 3:
                ib.emit(outs[0].vstore([3, 0], out_3))
            return ib.get()

        # A single statement is returned (no separate reset/update statements):
        # the accumulators start at zero and C is overwritten, never read.
        return _instr()

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            A: a_buffer,
            B: b_buffer,
            C: c_buffer
        },
        default_buffer_params=buffer_params,
    )
# Example 7
def gemm_acc_4x4_int8_int8_int32(dtype):
    """
    Int8 4x4 matrix multiplication and accumulation using sdot/udot
    instructions. This function takes two arrays of int8 datatype
    -- A[4][4] and B[4][4] and produces a 4x4 matrix
    which is equal to A*B'.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < 4; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.
        * The generated body loads A and B each as one 16-lane vector starting
          at element [0, 0], so it presumably relies on both 4x4 tiles being
          contiguous in memory -- TODO confirm against the strides the
          schedule produces.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    # This needs to be a variable number of "rows" since TVM
    # "thinks" I only need to compute one row because of
    # padding
    A = te.placeholder((te.var("rows"), 4), dtype, name="A")
    B = te.placeholder((4, 4), dtype, name="B")
    dtype_vec = dtype + "x16"

    k = te.reduce_axis((0, 4), name="k")
    C = te.compute(
        (te.var("rows"), 4),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"),
                            axis=k),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(A.shape,
                                    dtype,
                                    name="aa_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sa"), 1])
    bb_buffer = tvm.tir.decl_buffer(B.shape,
                                    dtype,
                                    name="bb_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sb"), 1])
    cc_buffer = tvm.tir.decl_buffer(C.shape,
                                    dtype="int32",
                                    name="cc_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sc"), 1])

    # Signed dot product for int8, unsigned for uint8.
    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        def _instr(index):
            # index 0 -> body, 1 -> reset, 2 -> update (see the return below).
            # Body and update are identical: both accumulate into the current C.
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                # Reset: zero the four int32x4 rows of the output tile.
                for i in range(0, 4):
                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0,
                                                                 "int32x4")))
                return ib.get()
            # Load all the elements of tile A.
            # vec_a = [a, b, c, d,
            #          e, f, g, h,
            #          l, m, n, o,
            #          p, q, r, s];
            vec_a = ins[0].vload([0, 0], dtype_vec)

            # Replicate 4 times the i-th row of A. For instance,
            # vec_aa[0] = [a, b, c, d,
            #              a, b, c, d,
            #              a, b, c, d,
            #              a, b, c, d,];
            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]

            # Load all the elements of B. Remember that B
            # is transposed:
            # vec_b = [0, 4, 8, 12,
            #          1, 5, 9, 13,
            #          2, 6, 10, 14,
            #          3, 7, 11, 15,];
            vec_b = ins[1].vload([0, 0], dtype_vec)

            # Execute the dot product
            for i in range(0, 4):
                vec_c = outs[0].vload([i, 0], "int32x4")
                # Compute the product between the i-th row of A
                # and all the rows of B. sdot/udot split each 16-lane
                # input into four groups of 4 consecutive elements,
                # take the dot product within each group, and add it
                # to the matching int32 lane of vec_c.
                #
                # For instance, for i=0, we have:
                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
                #                           a*1+b*5+c*9+d*13,
                #                           a*2+b*6+c*10+d*14,
                #                           a*3+b*7+c*11+d*15]
                vdot = tvm.tir.call_llvm_intrin(
                    "int32x4",
                    llvm_intrin,
                    tvm.tir.const(3, "uint32"),
                    vec_c,
                    vec_b,
                    vec_aa[i],
                )

                # Store the result
                ib.emit(outs[0].vstore([i, 0], vdot))

            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            A: aa_buffer,
            B: bb_buffer,
            C: cc_buffer
        },
        default_buffer_params=buffer_params,
    )
def test_max_index_simplify():
    """Check RewriteChecker's expected rewrites for ``max`` over index expressions.

    Each ``ck.verify(before, after)`` pairs an input expression with the form
    the checker is expected to produce for it.
    """
    ck = RewriteChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod
    # Short local aliases; these are the very same tvm.te functions.
    max_ = tvm.te.max
    min_ = tvm.te.min

    # const int bound
    ck.verify(max_(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10)
    ck.verify(max_(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10)

    # same variable, different constant offsets
    ck.verify(max_(x + 1, x + 10), x + 10)
    ck.verify(max_(x + 111, x + 10), x + 111)
    ck.verify(max_(x + 1, x), x + 1)
    ck.verify(max_(x, x + 2), x + 2)
    ck.verify(max_(1 - x, 2 - x), 2 - x)
    ck.verify(max_(3 - x, 2 - x), 3 - x)

    # max of min/max over the same operands
    ck.verify(max_(min_(x, y), max_(x, y)), max_(x, y))
    ck.verify(max_(min_(x, y), max_(y, x)), max_(x, y))

    ck.verify(max_(min_(x, y), x), x)
    ck.verify(max_(min_(y, x), x), x)
    ck.verify(max_(max_(x, y), x), max_(x, y))
    ck.verify(max_(max_(x, y), y), max_(x, y))

    ck.verify(max_(x, min_(x, y)), x)
    ck.verify(max_(x, min_(y, x)), x)
    ck.verify(max_(x, max_(x, y)), max_(x, y))
    ck.verify(max_(y, max_(x, y)), max_(x, y))

    # redundant operand inside a nested max chain
    ck.verify(max_(max_(max_(x, y), z), y), max_(max_(x, y), z))
    ck.verify(
        max_(max_(max_(max_(x, y), z), x * 2), y),
        max_(max_(max_(x, y), z), x * 2),
    )
    ck.verify(
        max_(max_(max_(max_(max_(x, y), z), x * 2), z * 2), y),
        max_(max_(max_(max_(x, y), z), x * 2), z * 2),
    )

    # factor out a common min operand
    ck.verify(max_(min_(x, y), min_(x, z)), min_(max_(y, z), x))
    ck.verify(max_(min_(x, y), min_(z, x)), min_(max_(y, z), x))
    ck.verify(max_(min_(y, x), min_(x, z)), min_(max_(y, z), x))
    ck.verify(max_(min_(y, x), min_(z, x)), min_(max_(y, z), x))

    # factor out a common addend
    ck.verify(max_(y + x, z + x), max_(y, z) + x)
    ck.verify(max_(y + x, x + z), max_(y, z) + x)
    ck.verify(max_(x + y, z + x), max_(y, z) + x)
    ck.verify(max_(x + y, x + z), max_(y, z) + x)

    ck.verify(max_(x - y, x - z), x - min_(y, z))
    ck.verify(max_(y - x, z - x), max_(y, z) - x)

    # nested constant bounds
    ck.verify(max_(max_(x, 1), 10), max_(x, 10))
    ck.verify(max_(max_(x, 11), 10), max_(x, 11))

    # scaled operands
    ck.verify(max_(x * 3, 9), max_(x, 3) * 3)
    ck.verify(max_(3 - x, 1), 3 - min_(x, 2))
    ck.verify(max_(x * 2, 0), max_(x, 0) * 2)
    ck.verify(max_(0 - x * 2, 0), min_(x, 0) * -2)
    ck.verify(max_(x * (-2), -4), min_(x, 2) * -2)
    ck.verify(max_(x * (-2), 4), min_(x, -2) * -2)
    ck.verify(max_(x * (0), 4), 4)
    ck.verify(max_(x * (0), -4), 0)

    # DivMod rules
    # trunc div
    ck.verify(max_(tdiv(x, 10), tdiv(y, 10)), tdiv(max_(x, y), 10))
    ck.verify(max_(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(min_(x, y), (-10)))
    ck.verify(max_(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4)

    # floordiv
    ck.verify(max_(fld(x, 10), fld(y, 10)), fld(max_(x, y), 10))
    ck.verify(max_(fld(x, (-10)), fld(y, (-10))), fld(min_(x, y), (-10)))
    ck.verify(max_(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4)
    ck.verify(max_(fld(x, 4) * 4, x), x)
    ck.verify(max_(x, fld(x, 4) * 4), x)
def test_cmp_simplify():
    """Check RewriteChecker's expected rewrites for comparison expressions.

    The first part runs before any bounds are set on x, y and z. The final
    section narrows their constant bounds via ``ck.analyzer.update`` so that
    most of the remaining comparisons fold to boolean constants. Statement
    order therefore matters.
    """
    ck = RewriteChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod
    # const int bound: the range of the mod term alone decides the comparison
    ck.verify((tmod(x, 2) + 10).equal(0), tvm.tir.const(0, "bool"))
    ck.verify(tvm.tir.NE(tmod(x, 2) + 10, 0), tvm.tir.const(1, "bool"))
    ck.verify(tmod(x, 2) + 10 > 1, tvm.tir.const(1, "bool"))
    ck.verify(tmod(x, 2) + 10 <= 1, tvm.tir.const(0, "bool"))
    ck.verify(flm(x, 2) + 2 > 1, tvm.tir.const(1, "bool"))
    ck.verify(flm(x, 2) + 10 <= 1, tvm.tir.const(0, "bool"))

    # x * 3 + 10 == 0 has no integer solution
    ck.verify(x * 3 + 10 == 0, tvm.tir.const(0, "bool"))
    ck.verify(x * 3 + 10 != 0, tvm.tir.const(1, "bool"))

    # canonicalization
    ck.verify((x - 10).equal(0), x.equal(10))
    ck.verify((10 - x).equal(0), x.equal(10))
    ck.verify((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0)))

    # cmp bound: cancel common terms / common scale factors on both sides
    ck.verify(x + y < x + z, y < z)
    ck.verify(x + y < z + x, y < z)
    ck.verify(y + x < x + z, y < z)
    ck.verify(y + x < z + x, y < z)
    ck.verify(y - x < z - x, y < z)
    ck.verify(x - y < x - z, z < y)

    ck.verify(x < z + x, tvm.tir.LT(0, z))
    ck.verify(x < x + z, tvm.tir.LT(0, z))

    ck.verify(100 < x + 1, tvm.tir.LT(99, x))
    ck.verify(1 < 100 - x, tvm.tir.LT(x, 99))
    ck.verify(x * 3 < y * 3, x < y)
    ck.verify(x * (-3) < y * (-3), y < x)
    ck.verify(x * 3 >= y * 3, y <= x)

    ck.verify(x * 4 >= 2, tvm.tir.LE(1, x))
    ck.verify(x * 2 >= 50, tvm.tir.LE(25, x))
    ck.verify(x * 4 <= 2, x <= 0)
    ck.verify((0 - x * 3) <= 0, tvm.tir.LE(0, x))
    ck.verify((0 - x * 3) >= 0, tvm.tir.LE(x, 0))
    ck.verify(2 * x <= 0, x <= 0)

    # Positive coefficient: the integer bound on x is rounded up for `>=`
    # and rounded down for `<=`.
    ck.verify(x * 2 >= 3, tvm.tir.LE(2, x))
    ck.verify(x * 2 >= 2, tvm.tir.LE(1, x))
    ck.verify(x * 2 >= 1, tvm.tir.LE(1, x))
    ck.verify(x * 2 >= 0, tvm.tir.LE(0, x))
    ck.verify(x * 2 >= -1, tvm.tir.LE(0, x))
    ck.verify(x * 2 >= -2, tvm.tir.LE(-1, x))
    ck.verify(x * 2 >= -3, tvm.tir.LE(-1, x))

    ck.verify(x * 2 <= 3, tvm.tir.LE(x, 1))
    ck.verify(x * 2 <= 2, tvm.tir.LE(x, 1))
    ck.verify(x * 2 <= 1, tvm.tir.LE(x, 0))
    ck.verify(x * 2 <= 0, tvm.tir.LE(x, 0))
    ck.verify(x * 2 <= -1, tvm.tir.LE(x, -1))
    ck.verify(x * 2 <= -2, tvm.tir.LE(x, -1))
    ck.verify(x * 2 <= -3, tvm.tir.LE(x, -2))

    # Negative coefficient flips the direction: `>=` bounds x from above,
    # `<=` bounds x from below.
    ck.verify(x * (-2) >= 3, tvm.tir.LE(x, -2))
    ck.verify(x * (-2) >= 2, tvm.tir.LE(x, -1))
    ck.verify(x * (-2) >= 1, tvm.tir.LE(x, -1))
    ck.verify(x * (-2) >= 0, tvm.tir.LE(x, 0))
    ck.verify(x * (-2) >= -1, tvm.tir.LE(x, 0))
    ck.verify(x * (-2) >= -2, tvm.tir.LE(x, 1))
    ck.verify(x * (-2) >= -3, tvm.tir.LE(x, 1))

    ck.verify(x * (-2) <= 3, tvm.tir.LE(-1, x))
    ck.verify(x * (-2) <= 2, tvm.tir.LE(-1, x))
    ck.verify(x * (-2) <= 1, tvm.tir.LE(0, x))
    ck.verify(x * (-2) <= 0, tvm.tir.LE(0, x))
    ck.verify(x * (-2) <= -1, tvm.tir.LE(1, x))
    ck.verify(x * (-2) <= -2, tvm.tir.LE(1, x))
    ck.verify(x * (-2) <= -3, tvm.tir.LE(2, x))

    # DivMod rules
    # trunc div
    ck.verify(tdiv(x, 2) < 3, x < 6)
    ck.verify(3 < tdiv(x, 2), tvm.tir.LT(7, x))
    ck.verify(tdiv(x, 3) >= 0, tvm.tir.LE(-2, x))
    ck.verify(tdiv(x, 2) >= 1, tvm.tir.LE(2, x))
    ck.verify(tdiv(x, 2) >= 0, tvm.tir.LE(-1, x))
    ck.verify(tdiv(x, 2) >= -1, tvm.tir.LE(-3, x))

    ck.verify(tdiv(x, 2) <= 1, tvm.tir.LE(x, 3))
    ck.verify(tdiv(x, 2) <= 0, tvm.tir.LE(x, 1))
    ck.verify(tdiv(x, 2) <= -1, tvm.tir.LE(x, -2))

    # tdiv(x, c) * c compared against x reduces to a test on tmod(x, c)
    ck.verify(tdiv(x, 4) * 4 < x, tvm.tir.LT(0, tmod(x, 4)))
    ck.verify(tdiv(x, 4) * 4 >= x, tvm.tir.LE(tmod(x, 4), 0))

    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.tir.LT(0, tmod(x, 4) + y))
    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4)))

    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2))
    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2))
    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4) + (-2), y))

    # floor div
    ck.verify(fld(x, 2) < 3, x < 6)
    ck.verify(3 < fld(x, 2), tvm.tir.LT(7, x))
    ck.verify(-3 < fld(x, 2), tvm.tir.LT(-5, x))
    ck.verify(fld(x, 3) >= 0, tvm.tir.LE(0, x))
    ck.verify(fld(x, 2) >= 1, tvm.tir.LE(2, x))
    ck.verify(fld(x, 2) >= 0, tvm.tir.LE(0, x))
    ck.verify(fld(x, 2) >= -1, tvm.tir.LE(-2, x))

    ck.verify(fld(x, 2) <= 1, tvm.tir.LE(x, 3))
    ck.verify(fld(x, 2) <= 0, tvm.tir.LE(x, 1))
    ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1))

    # fld(x, c) * c compared against x reduces to a test on flm(x, c)
    ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4)))
    ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0))

    ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y))
    ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4)))

    ck.verify(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2))
    ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2))
    ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y))
    # End DivMod Rules

    # comparisons against min/max with a constant operand
    ck.verify(tvm.te.min(x, 11) < 10, x < 10)
    ck.verify(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool"))
    ck.verify(tvm.te.max(8, x) > 10, tvm.tir.LT(10, x))
    ck.verify(x + 1 < tvm.te.max(8, x), x < 7)

    # From here on: x in [0, 10], y in [-10, 0], z in [-5, 5]. The bounds
    # alone decide most of the comparisons below.
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
    ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)

    ck.verify(x < 11, tvm.tir.const(1, "bool"))
    ck.verify(x <= 10, tvm.tir.const(1, "bool"))
    ck.verify(z <= 5, tvm.tir.const(1, "bool"))
    ck.verify(x + y <= 10, tvm.tir.const(1, "bool"))
    ck.verify(x + y >= -10, tvm.tir.const(1, "bool"))
    ck.verify(z - 5 <= y + 10, tvm.tir.const(1, "bool"))
    ck.verify(tvm.tir.all(x > -1, z <= x + 5), tvm.tir.const(1, "bool"))
    ck.verify(x * y <= 0, tvm.tir.const(1, "bool"))
    ck.verify((x + 1) * (y - 1) < 0, tvm.tir.const(1, "bool"))
    ck.verify(y * y >= 0, tvm.tir.const(1, "bool"))
    ck.verify(x * 6 <= -3, tvm.tir.const(0, "bool"))
    # not decided by the bounds; only canonicalized
    ck.verify(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0)
def test_sub_index_simplify():
    """Check RewriteChecker's expected rewrites for subtraction expressions.

    The div/mod sections set constant bounds on x (and later y) via
    ``ck.analyzer.update`` before their checks, so statement order matters.
    """
    ck = RewriteChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    # tvm.tir.Any() values: `a - a` is expected to fold to 0, `a - b` to stay as is.
    a, b = tvm.tir.Any(), tvm.tir.Any()

    # cancel a shared addend
    ck.verify(x + y - y, x)
    ck.verify(x + y - x, y)
    ck.verify(x - (y + x), 0 - y)
    ck.verify(x - (x + y), 0 - y)

    # min/max minus one of its own operands
    ck.verify(tvm.te.min(x, y) - x, tvm.te.min(0, y - x))
    ck.verify(tvm.te.min(x, y) - y, tvm.te.min(x - y, 0))
    ck.verify(tvm.te.max(x, y) - x, tvm.te.max(0, y - x))
    ck.verify(tvm.te.max(x, y) - y, tvm.te.max(x - y, 0))

    ck.verify(x - tvm.te.min(x, y), tvm.te.max(0, x - y))
    ck.verify(y - tvm.te.min(x, y), tvm.te.max(y - x, 0))
    ck.verify(x - tvm.te.max(x, y), tvm.te.min(0, x - y))
    ck.verify(y - tvm.te.max(x, y), tvm.te.min(y - x, 0))

    # mul coefficient folding
    ck.verify(x - x, 0)
    ck.verify(a - a, 0)
    ck.verify(a - b, a - b)
    ck.verify(x * y - x, x * (y + (-1)))
    ck.verify(x * y - 10 * x, x * (y + (-10)))
    ck.verify(y * x - x * z, x * (y - z))
    ck.verify(y * x - z * x, x * (y - z))

    ck.verify(x + 10 - 20, x + (-10))

    # 4-operands pattern
    ck.verify((x + y) - (x + z), y - z)
    ck.verify((y + x) - (x + z), y - z)
    ck.verify((x + y) - (z + x), y - z)
    ck.verify((y + x) - (z + x), y - z)

    ck.verify(tvm.te.min(x + y, z) - x, tvm.te.min(y, z - x))
    ck.verify(tvm.te.min(y + x, z) - x, tvm.te.min(y, z - x))
    ck.verify(tvm.te.min(z, x + y) - x, tvm.te.min(z - x, y))
    ck.verify(tvm.te.min(z, y + x) - x, tvm.te.min(z - x, y))

    ck.verify(tvm.te.max(x + y, z) - x, tvm.te.max(y, z - x))
    ck.verify(tvm.te.max(y + x, z) - x, tvm.te.max(y, z - x))
    ck.verify(tvm.te.max(z, x + y) - x, tvm.te.max(z - x, y))
    ck.verify(tvm.te.max(z, y + x) - x, tvm.te.max(z - x, y))

    ck.verify(x - tvm.te.min(x + y, z), tvm.te.max(0 - y, x - z))
    ck.verify(x - tvm.te.min(y + x, z), tvm.te.max(0 - y, x - z))
    ck.verify(x - tvm.te.min(z, x + y), tvm.te.max(x - z, 0 - y))
    ck.verify(x - tvm.te.min(z, y + x), tvm.te.max(x - z, 0 - y))

    # difference of two min/max expressions that differ by a constant shift
    ck.verify(tvm.te.min(x, y) - tvm.te.min(y, x), 0)
    ck.verify(tvm.te.max(x, y) - tvm.te.max(y, x), 0)
    ck.verify(tvm.te.min(x, y) - tvm.te.min(x + 10, y + 10), -10)
    ck.verify(tvm.te.min(x + 10, y + 1) - tvm.te.min(x, y - 9), 10)

    # DivMod patterns
    # trunc div
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod
    # The trunc div patterns below run with x bounded to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(x - tdiv(x, 3) * 3, tmod(x, 3))

    ck.verify(tdiv(x + 5, 3) - tdiv(x, 3), tdiv(tmod(x, 3) + 5, 3))
    ck.verify(tdiv(x + 5, 3) - tdiv(x + 1, 3), tdiv(tmod(x + 1, 3) + 4, 3))

    ck.verify(y - tdiv(y, (-5)) * (-5), tmod(y, 5))
    ck.verify(tdiv(y, 3) * 3 - y, 0 - tmod(y, 3))
    ck.verify(y - tdiv(y - 6, 5) * 5, tmod(y + (-6), 5) + 6)
    ck.verify(tdiv(y - 6, 5) * 5 - y, (-6) - tmod(y + (-6), 5))
    ck.verify(y - tdiv(y + z, 5) * 5, tmod(y + z, 5) - z)
    ck.verify(tdiv(y + z, 5) * 5 - y, z - tmod(y + z, 5))
    ck.verify(y - tdiv(y - z, 5) * 5, tmod(y - z, 5) + z)
    ck.verify(tdiv(y - z, 5) * 5 - y, 0 - tmod(y - z, 5) - z)

    # scaled variants of the same patterns
    ck.verify(y * 3 - tdiv(y, 2) * 6, tmod(y, 2) * 3)
    ck.verify(tdiv(y, 3) * 6 - y * 2, tmod(y, 3) * (-2))
    ck.verify(y * 5 - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
    ck.verify(y * 5 - tdiv(y - z, 2) * 10, (tmod(y - z, 2) + z) * 5)
    ck.verify(tdiv(y + z, 3) * 6 - y * 2, (z - tmod(y + z, 3)) * 2)
    ck.verify(tdiv(y - z, 3) * 6 - y * 2, (0 - tmod(y - z, 3) - z) * 2)
    ck.verify(5 * y - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
    ck.verify(5 * y - 10 * tdiv(y - z, 2), (tmod(y - z, 2) + z) * 5)
    ck.verify(6 * tdiv(y + z, 3) - y * 2, (z - tmod(y + z, 3)) * 2)
    ck.verify(tdiv(y - z, 3) * 6 - 2 * y, (0 - tmod(y - z, 3) - z) * 2)

    # floor div
    fld = tvm.te.floordiv
    flm = tvm.te.floormod
    # The floor div patterns below run with x and y bounded to [-1000, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True)
    ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True)
    ck.verify(x - fld(x, 3) * 3, flm(x, 3))
    ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3))
    ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1)

    ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3))
    ck.verify(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6)
    ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5))
    ck.verify(y - fld(y + z, 5) * 5, flm(y + z, 5) - z)
    ck.verify(fld(y + z, 5) * 5 - y, z - flm(y + z, 5))
    ck.verify(y - fld(y - z, 5) * 5, flm(y - z, 5) + z)
    ck.verify(fld(y - z, 5) * 5 - y, 0 - flm(y - z, 5) - z)
    ck.verify(y * 3 - fld(y, 2) * 6, flm(y, 2) * 3)
    ck.verify(fld(y, 3) * 6 - y * 2, flm(y, 3) * (-2))
    ck.verify(y * 5 - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5)
    ck.verify(y * 5 - fld(y - z, 2) * 10, (flm(y - z, 2) + z) * 5)
    ck.verify(fld(y + z, 3) * 6 - y * 2, (z - flm(y + z, 3)) * 2)
    ck.verify(fld(y - z, 3) * 6 - y * 2, (0 - flm(y - z, 3) - z) * 2)
    ck.verify(5 * y - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5)
    ck.verify(5 * y - 10 * fld(y - z, 2), (flm(y - z, 2) + z) * 5)
    ck.verify(6 * fld(y + z, 3) - y * 2, (z - flm(y + z, 3)) * 2)
    ck.verify(fld(y - z, 3) * 6 - 2 * y, (0 - flm(y - z, 3) - z) * 2)
def test_vector_simplify():
    """Check RewriteChecker's expected rewrites for vector expressions.

    Covers ``Ramp(base, stride, lanes)`` / ``Broadcast(value, lanes)`` arithmetic,
    div/mod, min/max and comparison rules. The div/mod sections set constant
    bounds on x via ``ck.analyzer.update``, so statement order matters.
    """
    ck = RewriteChecker()
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    # Add rules
    ck.verify(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4))
    ck.verify(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2))
    ck.verify(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2))
    ck.verify(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2"))
    ck.verify(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4))
    ck.verify(
        tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0, 4),
        tvm.tir.Ramp(x, 1, 4).astype("float32x4"),
    )
    # Sub rules
    ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x - y, 2, 4))
    ck.verify(tvm.tir.Ramp(x, 1, 2) - y, tvm.tir.Ramp(x - y, 1, 2))
    ck.verify(y - tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y - x, -1, 2))
    ck.verify(y.astype("int32x2") - x.astype("int32x2"), (y - x).astype("int32x2"))

    # Mul rules
    ck.verify(y.astype("int32x2") * x.astype("int32x2"), (y * x).astype("int32x2"))
    ck.verify(tvm.tir.Ramp(x, 4, 4) * 2, tvm.tir.Ramp(x * 2, 8, 4))
    ck.verify(2 * tvm.tir.Ramp(x, 4, 4), tvm.tir.Ramp(x * 2, 8, 4))
    ck.verify(tvm.tir.Broadcast(0, 4) * x, tvm.tir.Broadcast(0, 4))
    ck.verify(tvm.tir.Broadcast(0.0, 4) * x, tvm.tir.Broadcast(0.0, 4))

    ## DivMod rules
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod
    # trunc div
    ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y, x).astype("int32x2"))
    ck.verify(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2, 4))
    # The remaining trunc div/mod checks run with x bounded to [0, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4"))
    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
    # trunc mod
    ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2"))
    ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2), 4))
    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4))
    ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1, 15, 4), 8))

    # floor div
    fld = tvm.te.floordiv
    flm = tvm.te.floormod
    # The floor div/mod checks run with x bounded to [-10, 1000].
    ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True)
    ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2"))
    ck.verify(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4))
    ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4"))
    ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
    ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5))
    ck.verify(
        fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
        fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
    )
    ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(x * 2, 4))
    ck.verify(
        fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
        fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
    )
    ck.verify(
        fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
        fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
    )
    ck.verify(
        fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Broadcast(fld(x, 16), 4)
    )
    ck.verify(
        fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Broadcast(fld(x, 8), 4)
    )
    ck.verify(
        fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
        fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
    )  # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1]
    ck.verify(
        fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
        fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
    )  # Example negative case: x = 15; [63, 64, 65, 66] / 64 = [0, 1, 1, 1]
    ck.verify(
        fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
        fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
    )  # Example negative case: x = 9; [63, 64, 65, 66] / 64 = [0, 1, 1, 1]

    # floor mod
    ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2"))
    ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4))
    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4))
    ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8))
    ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(flm(x, 4), 4))
    ck.verify(
        flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
        flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
    )
    ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Ramp(0, 1, 4))
    ck.verify(
        flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)),
        flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)),
    )
    ck.verify(
        flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)),
        flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)),
    )
    ck.verify(
        flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4, 64), 1, 4)
    )
    ck.verify(
        flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4)
    )
    ck.verify(
        flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
        flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
    )  # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0]
    ck.verify(
        flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
        flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
    )  # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2]
    ck.verify(
        flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)),
        flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)),
    )  # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5]
    ck.verify(
        flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
        flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
    )  # Example negative case: x = 9; [63, 64, 65, 66] % 64 = [63, 0, 1, 2]

    # Min/Max rules
    vx = te.var("vx", dtype="int32x2")
    vc = te.var("vc", dtype="uint1")
    ck.verify(
        tvm.te.min(y.astype("int32x2"), x.astype("int32x2")), tvm.te.min(y, x).astype("int32x2")
    )
    ck.verify(
        tvm.te.min(tvm.te.min(vx, y.astype("int32x2")), x.astype("int32x2")),
        tvm.te.min(vx, tvm.te.min(y, x).astype("int32x2")),
    )
    ck.verify(
        tvm.te.max(y.astype("int32x2"), x.astype("int32x2")), tvm.te.max(y, x).astype("int32x2")
    )
    ck.verify(
        tvm.te.max(tvm.te.max(vx, y.astype("int32x2")), x.astype("int32x2")),
        tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
    )

    ## Logical rules
    ck.verify(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2"))
    ck.verify(
        tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))), (tvm.tir.NE(y, x)).astype("uint1x2")
    )
    ck.verify(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2"))
    ck.verify(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("uint1x2"))
    ck.verify(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("uint1x2"))
    ck.verify(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2"))
    ck.verify(
        tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
        (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
    )
    ck.verify(
        tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
        (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
    )
Exemplo n.º 12
0
def verify_batch_matmul(x_batch, y_batch, M, N, K, dynamic=False, debug=False):
    """Compare the dispatched batch_matmul against the numpy reference
    ``tvm.topi.testing.batch_matmul`` on every enabled target.

    Parameters
    ----------
    x_batch, y_batch : int
        Batch sizes of the ``x`` input, shape (x_batch, M, K), and the ``y``
        input, shape (y_batch, N, K).
    M, N, K : int
        Remaining dimensions of ``x`` and ``y``.
    dynamic : bool
        If True, every placeholder dimension is a symbolic ``te.var``. A
        default ``te.create_schedule`` is used instead of the dispatched one,
        and cuda/nvptx targets are skipped. Requires ``x_batch == y_batch``.
    debug : bool
        If True, print the lowered IR before building.
    """
    if not dynamic:
        x = te.placeholder((x_batch, M, K), name="x")
        y = te.placeholder((y_batch, N, K), name="y")
        dtype = x.dtype
    else:
        # Both placeholders below share ``dynamic_batch_size``, so the input
        # arrays bound at call time must agree on the batch dimension. A
        # broadcast batch of 1 cannot be expressed with this declaration, so
        # reject it here instead of failing later when the arrays are bound.
        assert x_batch == y_batch, "dynamic batch_matmul test requires x_batch == y_batch"
        batch_size = max(x_batch, y_batch)
        dynamic_batch_size = te.var("dynamic_batch_size")
        dynamic_M = te.var("dynamic_M")
        dynamic_N = te.var("dynamic_N")
        dynamic_K = te.var("dynamic_K")

        x = te.placeholder((dynamic_batch_size, dynamic_M, dynamic_K),
                           name="x")
        y = te.placeholder((dynamic_batch_size, dynamic_N, dynamic_K),
                           name="y")
        dtype = x.dtype

    # use memoize to pickle the test data for next time use
    @memoize("topi.tests.test_topi_batch_matmul")
    def get_ref_data():
        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
        return (a_np, b_np, c_np)

    # get the test data
    a_np, b_np, c_np = get_ref_data()

    def check_device(target, dev):
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                target, _batch_matmul_implement)
            out = fcompute(x, y)
            if not dynamic:
                s = fschedule([out])
                out_shape = out.shape
            else:
                s = te.create_schedule(out.op)
                # out.shape is symbolic here; size the output array from
                # the concrete test dimensions instead.
                out_shape = (batch_size, M, N)

            if debug:
                print(tvm.lower(s, [x, y, out], simple_mode=True))

        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype),
                         dev)
        f = tvm.build(s, [x, y, out], target, name="dense")
        f(a, b, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

    for target, dev in tvm.testing.enabled_targets():
        if dynamic and (target == "cuda" or target == "nvptx"):
            print("Dynamic batch matmul test is skippped on %s" % target)
            continue

        check_device(target, dev)
Exemplo n.º 13
0
def test_stmt_constructor():
    """Build one of each basic ``tvm.tir`` statement node and check that the
    constructor arguments are exposed through the node's fields."""
    aa_var = te.var("aa")
    buf_var = te.var("buf", dtype="handle")
    noop = tvm.tir.Evaluate(1)

    # LetStmt(var, value, body)
    stmt = tvm.tir.LetStmt(aa_var, 1, tvm.tir.Evaluate(1))
    assert isinstance(stmt, tvm.tir.LetStmt)
    assert stmt.var == aa_var
    assert stmt.value.value == 1
    assert isinstance(stmt.body, tvm.tir.Evaluate)

    # AttrStmt attached to an expression node
    stmt = tvm.tir.AttrStmt(aa_var == 1, "xx", 1, tvm.tir.Evaluate(1))
    assert isinstance(stmt, tvm.tir.AttrStmt)
    assert stmt.value.value == 1

    # AssertStmt(condition, message, body)
    stmt = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), noop)
    assert isinstance(stmt, tvm.tir.AssertStmt)
    assert stmt.body == noop

    # For loop over [0, 10)
    stmt = tvm.tir.For(te.var("x"), 0, 10, 0, 0, noop)
    assert isinstance(stmt, tvm.tir.For)
    assert stmt.min.value == 0
    assert stmt.extent.value == 10
    assert stmt.body == noop

    # Store(buffer_var, value, index, predicate)
    stmt = tvm.tir.Store(buf_var, 1, 10, tvm.tir.const(1, "uint1"))
    assert isinstance(stmt, tvm.tir.Store)
    assert stmt.buffer_var == buf_var
    assert stmt.index.value == 10
    assert stmt.value.value == 1

    # Provide into a 0-d placeholder
    scalar = te.placeholder((), dtype="float32")
    stmt = tvm.tir.Provide(scalar.op, 0, 10, [])
    assert isinstance(stmt, tvm.tir.Provide)
    assert stmt.value_index == 0
    assert stmt.value.value == 10

    # Allocate
    stmt = tvm.tir.Allocate(buf_var, "float32", [10], tvm.tir.const(1, "uint1"), noop)
    assert isinstance(stmt, tvm.tir.Allocate)
    assert stmt.dtype == "float32"
    assert stmt.buffer_var == buf_var
    assert stmt.body == noop

    # AttrStmt attached to a buffer variable
    stmt = tvm.tir.AttrStmt(buf_var, "xyz", 1, noop)
    assert isinstance(stmt, tvm.tir.AttrStmt)
    assert stmt.node == buf_var
    assert stmt.attr_key == "xyz"
    assert stmt.body == noop

    # Free
    stmt = tvm.tir.Free(buf_var)
    assert isinstance(stmt, tvm.tir.Free)
    assert stmt.buffer_var == buf_var

    # Realize
    stmt = tvm.tir.Realize(None, 0, "float", [], tvm.tir.const(1, "uint1"), noop)
    assert isinstance(stmt, tvm.tir.Realize)
    assert stmt.body == noop

    # IfThenElse(condition, then_case, else_case)
    stmt = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), noop)
    assert isinstance(stmt, tvm.tir.IfThenElse)
    assert stmt.then_case.value.value == 11
    assert stmt.else_case == noop

    # Prefetch
    stmt = tvm.tir.Prefetch(None, 1, "float32", [])
    assert isinstance(stmt, tvm.tir.Prefetch)
    assert stmt.value_index == 1
def check_packed_func(target="llvm"):
    """Lower a ``tvm.test_matmul`` packed call inside a parallel loop and check
    the ``tir.tvm_stack_alloca`` let-bindings at the start of the loop body.

    Parameters
    ----------
    target : str
        Target string attached to the PrimFunc (as its ``"target"`` attribute)
        before ``MakePackedAPI`` / ``LowerTVMBuiltin`` run.
    """
    ib = tvm.tir.ir_builder.create()

    m = n = k = 16

    #
    # Prepare buffers for a, b and c. Only the tensors' shapes and dtypes
    # feed the buffer declarations below.
    #
    a = te.placeholder((m, k), name="a", dtype="float64")
    b = te.placeholder((k, n), name="b", dtype="float64")
    # Distinct Python name for the reduction axis so the integer extent `k`
    # is not shadowed (the axis itself is still named "k").
    rk = te.reduce_axis((0, k), name="k")
    c = te.compute((m, n),
                   lambda i, j: te.sum(a[i, rk] * b[rk, j], axis=rk),
                   name="c")

    def _strided_buffer(tensor, name, stride_name):
        # Buffer with a symbolic outer stride and a unit inner stride.
        return tvm.tir.decl_buffer(tensor.shape,
                                   tensor.dtype,
                                   name=name,
                                   offset_factor=1,
                                   strides=[te.var(stride_name), 1])

    a_buffer = _strided_buffer(a, "a_buffer", "s1")
    b_buffer = _strided_buffer(b, "b_buffer", "s2")
    c_buffer = _strided_buffer(c, "c_buffer", "s3")

    with ib.for_range(0, 10, "i", kind="parallel"):
        ib.emit(
            tvm.tir.call_packed("tvm.test_matmul", a_buffer, b_buffer,
                                c_buffer))

    stmt = ib.get()

    # Construct a valid IRModule to be lowered:
    mod = tvm.IRModule.from_expr(
        tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))

    target = tvm.target.Target(target)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr("global_symbol", "main"))(mod)
    mod = tvm.tir.transform.MakePackedAPI()(mod)

    # Do the lowering:
    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)

    # Get the first PrimFunc from the module:
    prim_func = mod.functions.items()[0][1]

    node = prim_func.body

    # Peel off the wrapper statements (iteratively) until the for-loop:
    while isinstance(node,
                     (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
        node = node.body

    # For-loop:
    assert isinstance(node, tvm.tir.stmt.For)

    # The loop body is expected to begin with this chain of nested LetStmts:
    #   let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4)
    #   let stack_value = tir.tvm_stack_alloca("arg_value", 4)
    #   let stack_array = tir.tvm_stack_alloca("array", 3)
    #   let stack_shape = tir.tvm_stack_alloca("shape", 12)
    expected_allocas = (("arg_tcode", 4), ("arg_value", 4), ("array", 3), ("shape", 12))
    stack_alloca_op = tvm.ir.Op.get("tir.tvm_stack_alloca")

    let_stmt = node.body
    for alloca_kind, alloca_size in expected_allocas:
        assert isinstance(let_stmt, tvm.tir.LetStmt)

        expected_value = tvm.tir.call_intrin("handle", stack_alloca_op,
                                             alloca_kind, alloca_size)
        # Reuse the actual var and body so only the bound value can differ.
        expected_stmt = tvm.tir.LetStmt(let_stmt.var, expected_value,
                                        let_stmt.body)

        tvm.ir.assert_structural_equal(let_stmt,
                                       expected_stmt,
                                       map_free_vars=True)

        let_stmt = let_stmt.body
Exemplo n.º 15
0
def smlal_int16_int32():
    """
    Build a tensor intrinsic that multiplies two int16x8 vectors and adds the
    widened int32 products to an int32x8 output.

    Each int16x4 half of the inputs is multiplied with smull and added to the
    matching int32x4 slice of the output. The intent is a smlal/smlal2 pair:

        vec_a = vload(A, "int16x8")
        vec_b = vload(B, "int16x8")

        vec_c[0:4] += vec_a[0:4]*vec_b[0:4] //  -> smlal instruction
        vec_c[4:8] += vec_a[4:8]*vec_b[4:8] // -> smlal2 instruction

    A single int16x8 vector is loaded per input, and its lower (0:4) and higher
    (4:8) halves are accumulated separately.

    Returns
    -------
    intrin : TensorIntrin
        Reset zero-fills the output. Body and update both add the products to
        the current output contents.
    """
    lanes = 8
    A = te.placeholder((lanes,), dtype="int16", name="A")
    B = te.placeholder((lanes, 1), dtype="int16", name="B")
    C = te.compute(
        (lanes,),
        lambda i: A[i].astype("int32") * B[i, 0].astype("int32"),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        A.shape, dtype="int16", name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype="int16",
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("sb"), 1],
    )
    c_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[1]
    )

    mull = "llvm.aarch64.neon.smull"

    def _intrin_func(ins, outs):
        def _accumulate_half(lhs_full, rhs_full, half_op, offset):
            # Widen-multiply one int16x4 half of each input and add the result
            # to the int32x4 output slice starting at `offset`.
            acc = outs[0].vload([offset], "int32x4")
            lhs = tvm.tir.call_intrin("int16x4", half_op, lhs_full)
            rhs = tvm.tir.call_intrin("int16x4", half_op, rhs_full)
            prod = tvm.tir.call_llvm_pure_intrin(
                "int32x4", mull, tvm.tir.const(2, "uint32"), lhs, rhs
            )
            return acc + prod

        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                # reset: clear the whole int32x8 output
                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x8")))
                return ib.get()

            full_a = ins[0].vload([0], "int16x8")
            full_b = ins[1].vload([0, 0], "int16x8")
            high = _accumulate_half(full_a, full_b, "tir.vectorhigh", 4)
            low = _accumulate_half(full_a, full_b, "tir.vectorlow", 0)

            # Merge both halves into one int32x8 store. NEON vectors are 128
            # bits wide, so this needs two store instructions.
            merged = tvm.tir.call_intrin("int32x8", "tir.vectorcombine", low, high)
            ib.emit(outs[0].vstore(0, merged))
            return ib.get()

        # body, reset, update
        return tuple(_instr(idx) for idx in range(3))

    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: a_buffer, B: b_buffer, C: c_buffer},
        default_buffer_params={"offset_factor": 1},
    )
def test_let_simplify():
    """``let x = 1 in x + 1`` added to itself must simplify to the constant 4."""
    ck = RewriteChecker()
    x = te.var("x")
    z = tvm.tir.Let(x, 1, x + 1)
    ck.verify(z + z, 4)
Exemplo n.º 17
0
def gemm_acc_2x2_int8_int8_int32(dtype):
    """
    Int8 2x2 matrix multiplication using smmla/ummla instructions
    This function takes two arrays of int8 datatype -- A[2][8] and
    B[2][8] and produces a 2x2 matrix which is equal to A*B'
    The pseudo code is as follows.

    .. code-block:: c

        void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){
            for (int i = 0; i < 2; i++){
                for (int j = 0; j < 2; j++){
                    for (int k = 0; k < 8; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Reset zero-fills C. Both body and update pass the current contents of C
    as the accumulator operand of the (s/u)mmla call, so they accumulate.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    A = te.placeholder((2, 8), dtype, name="A")
    B = te.placeholder((2, 8), dtype, name="B")
    # Both 2x8 operands are read as a single 16-lane vector
    dtype_vec = dtype + "x16"

    k = te.reduce_axis((0, 8), name="k")
    C = te.compute(
        (2, 2),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"),
                            axis=k),
        name="C",
    )

    # NOTE(review): the 16-lane vloads below presumably require the row
    # strides sa/sb to equal 8 (rows contiguous) -- confirm at call sites.
    aa_buffer = tvm.tir.decl_buffer(A.shape,
                                    dtype,
                                    name="aa_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sa"), 1])
    bb_buffer = tvm.tir.decl_buffer(B.shape,
                                    dtype,
                                    name="bb_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sb"), 1])
    cc_buffer = tvm.tir.decl_buffer(C.shape,
                                    dtype="int32",
                                    name="cc_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sc"), 1])

    llvm_intrin = "llvm.aarch64.neon.smmla" if dtype == "int8" else "llvm.aarch64.neon.ummla"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                # reset: zero the 2x2 int32 output (4 lanes)
                ib.emit(outs[0].vstore([0, 0], tvm.tir.const(0, "int32x4")))
                return ib.get()
            # Load in vec_a the two rows of A
            # vec_a = [a, b, c, d, e, f, g, h;
            #          i, j, k, l, m, n, o, p,]
            vec_a = ins[0].vload([0, 0], dtype_vec)
            # Load in vec_b the two rows of B
            # vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
            #          1, 3, 5, 7, 9, 11, 13, 15,]
            vec_b = ins[1].vload([0, 0], dtype_vec)

            # Execute the matrix multiplication via (s/u)mmla, adding to the
            # current contents of C:
            # vec_c += [a*0 + b*2 + c*4 + d*6 +e*8 + f*10 + g*12 + h*14;
            #           a*1 + b*3 + c*5 + d*7 +e*9 + f*11 + g*13 + h*15;
            #           i*0 + j*2 + k*4 + l*6 +m*8 + n*10 + o*12 + p*14;
            #           i*1 + j*3 + k*5 + l*7 +m*9 + n*11 + o*13 + p*15]
            vec_c = outs[0].vload([0, 0], "int32x4")
            vmmla = tvm.tir.call_llvm_intrin(
                "int32x4",
                llvm_intrin,
                tvm.tir.const(3, "uint32"),
                vec_c,
                vec_a,
                vec_b,
            )
            # Store the result
            ib.emit(outs[0].vstore([0, 0], vmmla))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            A: aa_buffer,
            B: bb_buffer,
            C: cc_buffer
        },
        default_buffer_params=buffer_params,
    )
Exemplo n.º 18
0
def rnn_matexp():
    """Build, run and time a scan-based recurrent matrix-product kernel.

    Starting from an all-ones state, each scan step computes
    ``state[t] = state[t - 1] @ Whh``. The schedule is built here, then
    ``check_device("cuda")`` compiles it, runs it twice (timing the second
    run) and, unless ``SKIP_CHECK`` is set, compares it with a NumPy
    float64 reference.

    Reads the module-level flags ``DETECT_GLOBAL_BARRIER``,
    ``PERSIST_KERNEL`` and ``SKIP_CHECK``, which are defined outside this
    function.
    """
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # The number of steps stays symbolic; hidden size and batch size are
    # fixed constants.
    num_step = te.var("num_step")
    num_hidden = tvm.runtime.convert(n_num_hidden)
    batch_size = tvm.runtime.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = te.placeholder((num_hidden, num_hidden), name="Whh")
    # Initial state (scan step 0): every element is 1.0
    s_init = te.compute((1, batch_size, num_hidden),
                        lambda _, i, j: 1.0,
                        name="init")
    s_state = te.placeholder((num_step, batch_size, num_hidden))
    kh = te.reduce_axis((0, num_hidden), name="kh")
    # Recurrence: state[t, i, j] = sum_kh state[t - 1, i, kh] * Whh[kh, j]
    s_update = te.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update",
    )
    s_scan = tvm.te.scan(s_init, s_update, s_state)
    # schedule
    s = te.create_schedule(s_scan.op)
    CL = s_update
    # Stage the previous state through shared then local memory, and Whh
    # through local memory, for the update stage.
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the hidden-dimension reduction into num_thread_y parts and
    # rfactor them into partial sums (CLF); CL then reduces over those parts.
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = te.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = te.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = te.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        # Attach the thread axes to the scan itself. Presumably this keeps a
        # single (persistent) kernel across all steps -- confirm.
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    # Init stage: spread the hidden axis over blocks and threadIdx.x
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    # Update stage: same hidden-axis mapping; the reduction over the rfactor
    # parts is bound to threadIdx.y.
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate.
    # Only threadIdx.y == 0 writes the reduced result, so the num_thread_y
    # threads do not store the same value repeatedly.
    s[CL].set_store_predicate(thread_y.equal(0))

    if PERSIST_KERNEL:
        # Load the local copy of Whh once per thread at the scan level and
        # unroll its first axis.
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    # Within each partial reduction: load the shared state once (kr) and the
    # local state every 4 reduction elements (ko).
    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    # Load the shared-memory state cooperatively across threadIdx.y/threadIdx.x
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        """Compile for `target`, time one run and optionally check results."""
        with tvm.transform.PassContext(
                config={
                    "tir.UnrollLoop": {
                        "auto_max_step": 128,
                    },
                    "tir.detect_global_barrier": detect_global_barrier,
                }):
            f = tvm.build(s, [s_scan, Whh], target)
        # "cuda" runs on GPU 0; any other target falls back to OpenCL device 0
        dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        # Whh: 2 / n_num_hidden in the left half of the columns, 0 in the right
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0

        res_a = tvm.nd.array(res_np, dev)
        Whh_a = tvm.nd.array(Whh_np, dev)
        # Skip first pass as it is compilation
        f(res_a, Whh_a)
        dev.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        dev.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            # float64 reference of the same recurrence from an all-ones state
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Diagnostic only: print batch-0 elements that differ by > 1e-5
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" %
                              (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            # The actual pass/fail check uses a relative tolerance of 1e-3
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)

    check_device("cuda")
Exemplo n.º 19
0
def dot_int8_int8_int32_neon():
    """
    Int8 dot product using the NEON smull, saddlp and addp instructions

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[4][4], int32 out[4]){
            for (int i = 0; i < 4; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    We use the smull and saddlp instructions to compute the dot product.
    smull : int8x8 -> int8x8 -> int16x8 widening elementwise multiplication
    saddlp: int16x8 -> int32x4 pairwise addition of elements

    Data is broadcast across the register
    int8 elements
    |         data      |         data      |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                      smull

    int8 elements
    |     kernel[i]     |     kernel[i+1]   |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                        =

    int16 elements
    |               data * kernel[i]        |         data * kernel[i+1]            |
    |    0    |    1    |    2    |    3    |    4    |    5    |    6    |    7    |

                                          saddlp =

    int32 elements
    |    partial sum(data * kernel[i])      |  partial sum(data * kernel[i+1])      |
    |         0         |         1         |         2         |         3         |


    We apply the above kernel twice and use addp to compute the second set of pairwise additions

    int32 elements (narrowed so they fit on a line)
    |    psum d*k[i]    |   psum d*k[i+1]   |           |   psum d*k[i+2]   |   psum d*k[i+3]   |
    |    0    |    1    |    2    |    3    |   addp    |    4    |    5    |    6    |    7    |
                                                 =
    |sum d*ki |sum d*ki1|sum d*ki2|sum d*ki3|
    |    0    |    1    |    2    |    3    |

    Reset zero-fills the output. The body stores the four sums directly, and
    update adds them to the current output.

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    int32_lanes = 4  # 4 int32 lanes = 128 bits
    num_int8_elements = 4  # 4 int8 elements in int32
    data = te.placeholder((num_int8_elements, ), dtype="int8", name="data")
    kernel = te.placeholder((int32_lanes, num_int8_elements),
                            dtype="int8",
                            name="kernel")
    k = te.reduce_axis((0, num_int8_elements), name="k")
    C = te.compute(
        (int32_lanes, ),
        lambda i: te.sum(
            data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(data.shape,
                                   dtype="int8",
                                   name="a_buffer",
                                   offset_factor=1,
                                   strides=[1])
    b_buffer = tvm.tir.decl_buffer(kernel.shape,
                                   dtype="int8",
                                   name="b_buffer",
                                   offset_factor=1,
                                   strides=[te.var("ldw"), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            int_8xl = "int8x8"
            int_32xl = "int32x4"
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
                return ib.get()

            # this broadcasts data to the vector size: the 4 int8 values are
            # reinterpreted as one int32, broadcast to int32x2, and
            # reinterpreted back as int8x8 (data repeated twice)
            a_int8 = ins[0].vload([0], "int8x4")
            re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
            vec_ai32 = re_int32.astype("int32x2")
            vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)

            # all 4 kernel rows (16 int8 values) in one vector
            vec_b = ins[1].vload([0, 0], "int8x16")

            def pairwise_add_mul(extract_half):
                # two kernel rows (8 int8 values) from the low or high half
                vec_b_half = tvm.tir.call_intrin(int_8xl, extract_half, vec_b)
                multiply = tvm.tir.call_llvm_pure_intrin(
                    "int16x8",
                    "llvm.aarch64.neon.smull.v8i16",  # widening multiply: int8x8 * int8x8 -> int16x8
                    tvm.tir.const(2, "uint32"),
                    vec_a,
                    vec_b_half,
                )
                # add adjacent int16 lanes pairwise, widening to int32x4
                pairwise_reduction = tvm.tir.call_llvm_pure_intrin(
                    "int32x4",
                    "llvm.aarch64.neon.saddlp.v4i32.v8i16",
                    tvm.tir.const(1, "uint32"),
                    multiply,
                )
                return pairwise_reduction

            pair_1 = pairwise_add_mul("tir.vectorlow")
            pair_2 = pairwise_add_mul("tir.vectorhigh")
            # second pairwise add across both partial-sum vectors yields the
            # four complete dot products
            quad_reduction = tvm.tir.call_llvm_pure_intrin(
                "int32x4",
                "llvm.aarch64.neon.addp.v4i32",
                tvm.tir.const(2, "uint32"),
                pair_1,
                pair_2,
            )
            if index == 0:
                # body: overwrite the output
                ib.emit(outs[0].vstore(0, quad_reduction))
            else:
                # update: accumulate onto the existing output
                ib.emit(outs[0].vstore(
                    0, quad_reduction + outs[0].vload([0], int_32xl)))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            data: a_buffer,
            kernel: b_buffer
        },
        default_buffer_params=buffer_params,
    )
Exemplo n.º 20
0
def test_domain_touched():
    """Check DomainTouched read/write regions for a 2-D stencil over `a` and `b`."""
    i, j = te.var("i"), te.var("j")
    n = tvm.runtime.convert(100)
    m = te.var("m")

    a = tvm.tir.decl_buffer((n, m), name="a")
    b = tvm.tir.decl_buffer((n, m), name="b")

    # for i in [0, n): for j in [0, m): a[i, j] = b[i - 1, j + 1] + a[i - 1, j - 1]
    value = tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1])
    store = tvm.tir.BufferStore(a, value, [i, j])
    ir = tvm.tir.For(i, 0, n, 0, 0, tvm.tir.For(j, 0, m, 0, 0, store))

    touched = tvm.arith._ffi_api.DomainTouched

    def _expect(region, lo, extent):
        # `extent` is either an integer literal or the name of a symbolic var
        assert region.min.value == lo
        if isinstance(extent, str):
            assert region.extent.name == extent
        else:
            assert region.extent.value == extent

    # reads of a: a[i - 1, j - 1]
    a_read = touched(ir, a, True, False)
    _expect(a_read[0], -1, 100)
    _expect(a_read[1], -1, "m")

    # writes of a: a[i, j]
    a_write = touched(ir, a, False, True)
    _expect(a_write[0], 0, 100)
    _expect(a_write[1], 0, "m")

    # reads and writes of a together: the union of the two regions above
    a_both = touched(ir, a, True, True)
    _expect(a_both[0], -1, 101)
    assert a_both[1].min.value == -1
    col_extent = a_both[1].extent
    assert isinstance(col_extent, tvm.tir.Add)
    assert col_extent.a.name == "m"
    assert col_extent.b.value == 1

    # reads of b: b[i - 1, j + 1]
    b_read = touched(ir, b, True, False)
    assert b_read
    _expect(b_read[0], -1, 100)
    _expect(b_read[1], 1, "m")

    # b is never written, so its write region is an empty array
    b_write = touched(ir, b, False, True)
    assert isinstance(b_write, tvm.container.Array)
    assert len(b_write) == 0
Exemplo n.º 21
0
def gemm_acc_nx16_int8_int8_int32(dtype, rows):
    """
    Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions
    This function takes two arrays of int8 datatype -- A[rows][16] and
    B[4][16][4] -- and produces a rows x 16 int32 matrix with
    C[i][j] = sum_k A[i][k] * B[k // 4][j][k % 4].
    The pseudo code is as follows.

    .. code-block:: c

        void mmla_nx16_int8_int8_int32(int8 A[n][16], int8 B[4][16][4], int32 out[n][16]){
            for (int i = 0; i < n; i++){
                for (int j = 0; j < 16; j++){
                    for (int k = 0; k < 16; k++){
                        out[i][j] += A[i][k] * B[k//4][j][k%4]
                    }
                }
            }
        }

    Notes:
        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
          we need 4 tiles of B to compute a single row of the output. The first 4 values of
          k will be fetched from B[0][j][0:4], the second batch of 4 from B[1][j][0:4] and so on
        * The tiling strategy is picked to maximize register usage.
        * Reset zero-fills the output; body and update both accumulate into it.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int
    rows : int
        Number of the output rows "n"

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    A = te.placeholder((rows, 16), dtype, name="A")
    B = te.placeholder((4, 16, 4), dtype, name="B")
    dtype_vec = dtype + "x16"
    idxm = tvm.tir.indexmod
    k = te.reduce_axis((0, 16), name="k")
    C = te.compute(
        (rows, 16),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[
            k // 4, j, idxm(k, 4)].astype("int32"),
                            axis=k),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(A.shape,
                                    dtype,
                                    name="aa_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sa"), 1])
    bb_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype,
        name="bb_buffer",
        offset_factor=1,
        strides=[te.var("sb0"), te.var("sb1"), 1],
    )
    cc_buffer = tvm.tir.decl_buffer(C.shape,
                                    dtype="int32",
                                    name="cc_buffer",
                                    offset_factor=1,
                                    strides=[te.var("sc"), 1])

    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                # reset: zero each 16-wide output row
                for i in range(0, rows):
                    ib.emit(outs[0].vstore([i, 0],
                                           tvm.tir.const(0, "int32x16")))
                return ib.get()
            # Iterate on the number of rows of the output
            for row in range(0, rows):
                # Load 16 elements of A
                # vec_a = [a, b, c, d, e, f, g, h, l, m, n, o, p, q, r, s];
                vec_a = ins[0].vload([row, 0], dtype_vec)

                # Iterate over the 4 groups of 4 output columns
                for col_tile in range(0, 4):
                    # Accumulate over each of the 4 (16x4) tiles contained in B
                    for k_tile in range(0, 4):
                        # Replicate a single 4-element group of A
                        # (A[row, 4*k_tile : 4*k_tile+4]) across the vector
                        vec_aa = select_word(vec_a, k_tile, dtype_vec)

                        # Load the 4x4 tile B[k_tile, 4*col_tile : 4*col_tile+4, 0:4]
                        # vec_b = [0, 16, 32, 48,
                        #          1, 17, 33, 49,
                        #          2, 18, 34, 50,
                        #          3, 19, 35, 51,];
                        vec_b = ins[1].vload([k_tile, 4 * col_tile, 0], dtype_vec)

                        # Accumulate in the correct part of the output
                        vec_c = outs[0].vload([row, 4 * col_tile], "int32x4")

                        # Compute the dot product between the 4-element group
                        # from A and the 4x4 tile from B
                        #
                        # For instance, for k_tile=0, we have:
                        # sdot(vec_aa[0], vec_b) = [a*0+b*16+c*32+d*48,
                        #                           a*1+b*17+c*33+d*49,
                        #                           a*2+b*18+c*34+d*50,
                        #                           a*3+b*19+c*35+d*51]
                        vdot = tvm.tir.call_llvm_intrin(
                            "int32x4",
                            llvm_intrin,
                            tvm.tir.const(3, "uint32"),
                            vec_c,
                            vec_b,
                            vec_aa,
                        )
                        ib.emit(outs[0].vstore([row, 4 * col_tile], vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            A: aa_buffer,
            B: bb_buffer,
            C: cc_buffer
        },
        default_buffer_params=buffer_params,
    )
Exemplo n.º 22
0
def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype):
    """Defines a SIMD-accelerated transposed matmul.

    The declared computation is ``c[i, j] = sum_k a[i, k] * b[j, k]`` for an
    ``M x K`` operand ``a`` and an ``N x K`` operand ``b``. Body, reset and
    update each become a call to an extern function named
    ``gemm_{M}x{K}x{N}_{body|reset|update}_{uniq_id}``.

    Parameters
    ----------
    M, K, N : int or tvm.tir.IntImm
        Matmul dimensions; ``K`` must be a multiple of 4.
    in_dtype : str
        Operand dtype; only ``"int8"`` is accepted.
    out_dtype : str
        Accumulator dtype; only ``"int32"`` is accepted.

    Returns
    -------
    intrin_decl : TensorIntrin
        The declared tensor intrinsic.
    uniq_id : str
        Random upper-case suffix used in the extern function names.
    """
    # A random suffix per intrinsic definition prevents name collisions in the
    # generated source (e.g., several operators in the same module using the
    # same intrinsic).
    #
    # TODO(weberlo, areusch): to cut down on memory usage, we should cache each intrinsic
    # instantiation and include it only once, eliminating the need for unique
    # IDs
    UNIQ_ID_LEN = 8
    uniq_id = "".join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN))

    def _as_python_int(dim):
        return dim.value if isinstance(dim, tvm.tir.IntImm) else dim

    M = _as_python_int(M)
    K = _as_python_int(K)
    N = _as_python_int(N)

    assert K % 4 == 0
    # TODO(weberlo, areusch): support more dtypes?
    assert in_dtype == "int8"
    assert out_dtype == "int32"

    A = te.placeholder((M, K), name="a", dtype=in_dtype)
    B = te.placeholder((N, K), name="b", dtype=in_dtype)
    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k),
        name="c",
    )

    def _row_strided_buffer(tensor, buf_name, stride_name):
        # innermost stride is 1; the row stride stays symbolic
        return tvm.tir.decl_buffer(
            tensor.shape,
            tensor.dtype,
            name=buf_name,
            offset_factor=1,
            strides=[te.var(stride_name), 1],
        )

    A_buf = _row_strided_buffer(A, "A", "A_s")
    B_buf = _row_strided_buffer(B, "B", "B_s")
    C_buf = _row_strided_buffer(C, "C", "C_s")

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def _extern_stmt(kind, *call_args):
            # a single call_extern statement to gemm_<M>x<K>x<N>_<kind>_<uniq_id>
            ib = tvm.tir.ir_builder.create()
            ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_{kind}_{uniq_id}", *call_args))
            return ib.get()

        def _gemm_stmt(kind):
            # body and update take the same operands: a, b, c pointers + row strides
            return _extern_stmt(
                kind,
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                cc.access_ptr("w"),
                aa.strides[0],
                bb.strides[0],
                cc.strides[0],
            )

        body = _gemm_stmt("body")
        reset = _extern_stmt("reset", cc.access_ptr("w"), cc.strides[0])
        update = _gemm_stmt("update")
        return body, reset, update

    intrin_decl = te.decl_tensor_intrin(C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
    return intrin_decl, uniq_id
Exemplo n.º 23
0
def dot_int8_int8_int32(int32_lanes, dtype='uint'):
    """
    Int8 dot product by every 4 elements using ARM v8.2 udot/sdot.
    This function takes two arrays of int8 datatype -- data[4] and
    kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
    4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype
    when ``dtype == "uint"`` (udot) or int32 datatype when ``dtype == "int"``
    (sdot).
    The pseudo code is as follows.

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 out[16]){
            for (int i = 0; i < int32_lanes; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    Physically, the kernel array sits in a vector register and
    the data[4] is broadcasted to another vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule.

    Reset zero-fills the output. Body and update both pass the current output
    as the accumulator operand of the dot instruction.

    Parameters
    ----------
    int32_lanes: int
        How many int32/uint32 to produce
    dtype: str, optional, {"uint", "int"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The ARM (u)int8 TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint", "int"]
    num_int8_elements = 4  # 4 int8 elements in int32

    data = te.placeholder((num_int8_elements, ),
                          dtype='%s8' % dtype,
                          name='data')
    kernel = te.placeholder((int32_lanes, num_int8_elements),
                            dtype='%s8' % dtype,
                            name='kernel')

    k = te.reduce_axis((0, num_int8_elements), name='k')
    C = te.compute((int32_lanes, ),
                   lambda i: te.sum(data[k].astype('%s32' % dtype) * kernel[
                       i, k].astype('%s32' % dtype),
                                    axis=k),
                   name="C")

    a_buffer = tvm.tir.decl_buffer(data.shape,
                                   dtype='%s8' % dtype,
                                   name="a_buffer",
                                   offset_factor=1,
                                   strides=[1])
    b_buffer = tvm.tir.decl_buffer(kernel.shape,
                                   dtype='%s8' % dtype,
                                   name="b_buffer",
                                   offset_factor=1,
                                   strides=[te.var('s'), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(
                    0, tvm.tir.const(0, '%s32x%d' % (dtype, int32_lanes))))
                return ib.get()

            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
            dtype_c = '%s32x%d' % (dtype, int32_lanes)

            a_int8 = ins[0].vload([0], dtype_a)
            # View the 4 int8 values of data as one 32-bit scalar. This uses
            # the registered "tir.reinterpret" op, as the other intrinsics in
            # this file do.
            re_int32 = tvm.tir.call_intrin('%s32' % dtype, 'tir.reinterpret',
                                           a_int8)
            # broadcast a
            vec_ai32 = re_int32.astype(dtype_c)

            vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], dtype_b)
            vec_c = outs[0].vload([0], dtype_c)

            inst = 'udot' if dtype == 'uint' else 'sdot'
            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                inst, int32_lanes, int32_lanes * num_int8_elements)
            vdot = tvm.tir.call_llvm_intrin(dtype_c, inst,
                                            tvm.tir.const(2, 'uint32'), vec_c,
                                            vec_a, vec_b)
            ib.emit(outs[0].vstore(0, vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(C.op,
                                 _intrin_func,
                                 binds={
                                     data: a_buffer,
                                     kernel: b_buffer
                                 },
                                 default_buffer_params=buffer_params)
Exemplo n.º 24
0
from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
from tvm import topi
import numpy as np

######################################################################
# Basic example
# -------------
# Let's revisit the sum of rows operation (equivalent to :code:`B = numpy.sum(A, axis=1)`).
# To compute the sum of rows of a two-dimensional TVM tensor A, we should
# specify the symbolic operation as well as the schedule, as follows:
#
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "k")
# Row-wise reduction: B[i] is the sum over k of A[i, k].
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)

######################################################################
# and to examine the IR code in human readable format, we can do
#
# Both the input A and the output B are passed so that B appears in the
# lowered function's signature; with only [A] the result buffer would not
# be a parameter of the generated function.
print(tvm.lower(s, [A, B], simple_mode=True))

######################################################################
# However, for such a common operation we had to define the reduce axis ourselves as well as the explicit computation with
# :code:`te.compute`. Imagine how much detail we would need to provide for more complicated operations.
# Fortunately, we can replace those two lines with a simple :code:`topi.sum`, much like :code:`numpy.sum`
Exemplo n.º 25
0
from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
import numpy as np

m = te.var('m')
n = te.var('n')
X = te.placeholder((m, n), name='x')
s_state = te.placeholder((m, n))
# Row 0 of the scan result is copied from row 0 of X.
s_init = te.compute((1, n), lambda _, i: X[0, i])
# Row t adds X[t] to the state row from timestep t - 1.
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = te.scan(s_init, s_update, s_state, inputs=[X])

s = te.create_schedule(s_scan.op)
num_thread = 256
block_X = te.thread_axis('blockIdx.x')
thread_X = te.thread_axis('threadIdx.x')
# Split axis 1 (the non-time axis) of the init stage and bind the pieces to
# blockIdx.x / threadIdx.x.
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_X)
s[s_init].bind(xi, thread_X)
# Same for the update stage. The axis must be taken from the update tensor
# `s_update`; `s` is the Schedule object, not the update tensor.
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_X)
s[s_update].bind(xi, thread_X)

print(tvm.lower(s, [X, s_scan], simple_mode=True))

# multi-stage scan cell
m = te.var('m')
n = te.var('n')
def test_cse():
    """Check CommonSubexprElimTIR on a small program with nested let-bindings.

    The pass is expected to introduce ``cse_var_1 = z1 + z2`` and
    ``cse_var_2 = x + y`` at two different nesting levels and to rewrite the
    values bound to ``a`` and ``b`` in terms of them. Everything else must be
    left unchanged.
    """
    z1 = te.var("z1")
    z2 = te.var("z2")
    z3 = te.var("z3")
    i1 = te.var("i1")
    i2 = te.var("i2")
    x = te.var("x")
    y = te.var("y")
    a = te.var("a")
    b = te.var("b")
    dtype = "int32"
    buffer = tvm.tir.decl_buffer((50,), dtype)
    # Test prog :
    # let z1=1 in let z2=2 in
    #   Mem[i1] = z1+z2;
    #   let x = 1 in let y = 1 in
    #     let a = (x+y) + (z1+z2) in
    #       let b = (x+y) + z3 in
    #         Mem[i2] = a+b;
    body = tvm.tir.LetStmt(
        z1,
        1,
        tvm.tir.LetStmt(
            z2,
            2,
            tvm.tir.SeqStmt(
                [
                    tvm.tir.BufferStore(buffer, z1 + z2, [i1]),
                    tvm.tir.LetStmt(
                        x,
                        1,
                        tvm.tir.LetStmt(
                            y,
                            1,
                            tvm.tir.LetStmt(
                                a,
                                (x + y) + (z1 + z2),
                                tvm.tir.LetStmt(
                                    b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2])
                                ),
                            ),
                        ),
                    ),
                ]
            ),
        ),
    )
    # This test program gives the opportunity to introduce two new variables, at two different levels
    # and to perform replacements in the value of "a" and "b", using these new variables
    # We will check all of that underneath and more, making also sure that nothing else has been changed

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body))
    body = tvm.tir.transform.CommonSubexprElimTIR()(mod)

    # Debug aid: dump the transformed module.
    tvm.transform.PrintIR()(body)

    body = body["main"].body  # Gets the body of the main, i.e. the full statement

    assert body.var.name == "z1"
    assert body.value == 1

    body = body.body

    assert body.var.name == "z2"
    assert body.value == 2
    # This is the let-in for the first variable generated cse_var_1
    assert isinstance(body.body, tvm.tir.LetStmt)

    body = body.body

    # And this is the name and value of this variable
    cse_var_1 = body.var  # Keep the variable accessible for later checking the replacements
    assert body.var.name == "cse_var_1"
    assert tvm.ir.structural_equal(body.value, z1 + z2)
    assert isinstance(body.body, tvm.tir.SeqStmt)

    body = body.body

    assert isinstance(body[0], tvm.tir.BufferStore)
    assert isinstance(body[1], tvm.tir.LetStmt)

    body = body[1]

    assert body.var.name == "x"
    assert body.value == 1

    body = body.body

    assert body.var.name == "y"
    assert body.value == 1
    # This is the let-in for the second variable generated cse_var_2
    assert isinstance(body.body, tvm.tir.LetStmt)

    body = body.body

    # And this is the name and value of this variable
    cse_var_2 = body.var  # Keep the variable accessible for later checking the replacements
    assert body.var.name == "cse_var_2"
    assert tvm.ir.structural_equal(body.value, x + y)

    body = body.body

    # Previously a bare comparison with no effect; it must be asserted.
    assert body.var.name == "a"
    # Check that the replacement has been done correctly!
    assert tvm.ir.structural_equal(body.value, cse_var_2 + cse_var_1)

    body = body.body

    # Previously a bare comparison with no effect; it must be asserted.
    assert body.var.name == "b"
    # Check that the replacement has been done correctly!
    assert tvm.ir.structural_equal(body.value, cse_var_2 + z3)

    assert isinstance(body.body, tvm.tir.BufferStore)
Exemplo n.º 27
0
def test_split_infer_type():
    """Check the result types inferred for ``relay.split``.

    Covers equal sections on static and symbolic axes, and splitting at
    explicit indices.
    """

    def verify_split(dshape, indices_or_sections, ret_type, axis=None):
        # Type-infer split(x) for a float32 x of shape `dshape` and compare
        # the resulting tuple type with `ret_type`.
        x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
        y = relay.split(x, indices_or_sections, axis=axis)
        yy = run_infer_type(y.astuple())
        assert yy.checked_type == ret_type

    def _tuple_of(*shapes):
        # Expected TupleType whose fields are float32 tensors of `shapes`.
        return relay.ty.TupleType(
            tvm.runtime.convert([relay.ty.TensorType(shape, "float32") for shape in shapes])
        )

    idxd = tvm.tir.indexdiv

    d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")

    # Five equal sections of a static axis of extent 5.
    verify_split((5, 5, 2, 2), 5, _tuple_of(*[(5, 1, 2, 2)] * 5), axis=1)
    verify_split((5, 5, 2, 2), 5, _tuple_of(*[(1, 5, 2, 2)] * 5), axis=0)
    # Equal sections of a symbolic axis: each section's extent is the axis
    # extent divided by the number of sections.
    verify_split((d1, d2, d3, d4), 4, _tuple_of(*[(d1, d2, idxd(d3, 4), d4)] * 4), axis=2)
    verify_split((d1, d2, d3, d4), 2, _tuple_of(*[(idxd(d1, 2), d2, d3, d4)] * 2), axis=0)
    # Split at indices (2, 4, 7): sections of extent 2, 2, 3 and the remainder.
    verify_split(
        (d1, d2, d3, d4),
        (2, 4, 7),
        _tuple_of(
            (d1, 2, d3, d4),
            (d1, 2, d3, d4),
            (d1, 3, d3, d4),
            (d1, (d2 - 7), d3, d4),
        ),
        axis=1,
    )
Exemplo n.º 28
0
def test_reduce_functions():
    """Run every Relay reduction against its NumPy reference over a case table.

    Each entry of ``op_pairs`` is ``[relay_op, numpy_reference]``; each entry
    of ``cases`` is ``(shape, axis, keepdims, exclude, expected_out_shape)``.
    Every pair is checked with ``verify_reduce`` on every case, in order.
    """

    def _with_keepdims(np_fn):
        # The reference is called without ``keepdims``; when keepdims is
        # requested, emulate it by reshaping the reduced axis (or all axes
        # when ``axis`` is None) to extent 1.
        def _wrapper(data, axis=None, keepdims=False):
            if not keepdims:
                return np_fn(data, axis=axis)
            if axis is None:
                kept_shape = [1] * len(data.shape)
            else:
                # For a sequence of axes only the first one is used.
                axis = axis if isinstance(axis, int) else axis[0]
                kept_shape = list(data.shape)
                kept_shape[axis] = 1
            return np_fn(data, axis=axis).reshape(kept_shape)

        return _wrapper

    def _np_log_sum_exp(x, axis, keepdims=False):
        # log(sum(exp(x))) computed with the max subtracted first and added back.
        shift = np.max(x, axis=axis, keepdims=True)
        result = np.log(np.sum(np.exp(x - shift), axis=axis, keepdims=True)) + shift
        return result if keepdims else np.squeeze(result, axis=axis)

    def _unbiased_relay_wrapper(relay_fn):
        # Relay op variant called with unbiased=True.
        def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
            return relay_fn(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)

        return _unbiased_func

    def _unbiased_np_wrapper(np_fn):
        # NumPy reference variant called with ddof=1.
        def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
            return np_fn(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)

        return _unbiased_func

    d1, d2, d3, d4 = (te.var(name) for name in ("d1", "d2", "d3", "d4"))

    op_pairs = [
        [relay.sum, np.sum],
        [relay.max, np.max],
        [relay.min, np.min],
        [relay.mean, np.mean],
        [relay.variance, np.var],
        [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
        [relay.std, np.std],
        [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
        [relay.prod, np.prod],
        [relay.all, np.all],
        [relay.any, np.any],
        [relay.logsumexp, _np_log_sum_exp],
        [relay.argmin, _with_keepdims(np.argmin)],
        [relay.argmax, _with_keepdims(np.argmax)],
    ]

    cases = [
        # symbolic shapes
        ((d1, d2, d3, d4), None, False, False, ()),
        ((d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4)),
        ((d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4)),
        ((d1, d2, d3), 1, True, False, (d1, 1, d3)),
        ((d1, d2, d3), 0, True, False, (1, d2, d3)),
        ((d1, d2, d3), None, True, False, (1, 1, 1)),
        ((d1, d2, d3), (0, 1), True, False, (1, 1, d3)),
        # static shapes
        ((2, 3, 4), 1, True, False, (2, 1, 4)),
        ((2, 3, 4), (1,), True, False, (2, 1, 4)),
        ((2, 3, 4), -1, True, False, (2, 3, 1)),
        ((2, 3, 4), (0, 1, 2), False, False, ()),
        ((4, 4, 3), None, False, False, ()),
        ((4, 4, 3), (0, 2), False, False, (4,)),
        ((128, 24, 128), (0, 1), False, False, (128,)),
        ((128, 24, 128), (0, 2), False, False, (24,)),
        ((128, 24, 128), (0, 1), True, False, (1, 1, 128)),
        ((128, 24, 128), (0, 2), True, False, (1, 24, 1)),
    ]

    for op_pair in op_pairs:
        for shape, axis, keepdims, exclude, out_shape in cases:
            verify_reduce(op_pair, shape, axis, keepdims, exclude, out_shape)
Exemplo n.º 29
0
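"""

Split

split is the inverse of fuse: it divides an iteration axis into an outer and
an inner loop, using `factor` as the inner extent. This adds a loop level and
breaks the loop into smaller sub-tasks.
In practice (taking CUDA as an example), gridDim and blockDim can each have at
most three dimensions, so split is used to create new axes that can be bound
to the grid and to the block.

"""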

import tvm
from tvm import te
import numpy as np

# Symbolic extents; `n` is declared here for use later on.
n = te.var("n")
m = te.var("m")

# Example 1: B[i] = A[i] * 2, with its single loop split by a factor of 32.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)

# split turns one axis into two axes (outer, inner) using the given factor.
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

# Example 2: a plain copy B[i] = A[i], scheduled from scratch.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")

s = te.create_schedule(B.op)
Exemplo n.º 30
0
Arquivo: scan.py Projeto: jackwish/tvm
# The scan is carried over the highest dimension of the tensor.
# :code:`s_state` is a placeholder that describes the transition state of the scan.
# :code:`s_init` describes how we can initialize the first k timesteps.
# Here, since s_init's first dimension is 1, it describes how we initialize
# the state at the first timestep.
#
# :code:`s_update` describes how to update the value at timestep t. The update
# value can refer back to the values of previous timestep via state placeholder.
# Note that it is invalid to refer to :code:`s_state` at the current or a later timestep.
#
# The scan takes in state placeholder, initial value and update description.
# It is also recommended (although not necessary) to list the inputs to the scan cell.
# The result of the scan is a tensor, giving the result of :code:`s_state` after the
# update over the time domain.
#
m, n = te.var("m"), te.var("n")
X = te.placeholder((m, n), name="X")
# State placeholder for the scan.
s_state = te.placeholder((m, n))
# Timestep 0: copy row 0 of X.
s_init = te.compute((1, n), lambda _, i: X[0, i])
# Timestep t: previous state row plus row t of X.
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X])

######################################################################
# Schedule the Scan Cell
# ----------------------
# We can schedule the body of the scan by scheduling the update and
# init part separately. Note that it is invalid to schedule the
# first iteration dimension of the update part.
# To split on the time iteration, user can schedule on scan_op.scan_axis instead.
#