Example #1
0
        def _instr(index):
            """Emit one tensorize statement for the AVX-512 u8 x s8 -> s32 dot.

            ``index`` selects the statement: 1 zeroes the ``int32x16`` output,
            0 overwrites it with freshly computed dot products, and any other
            value adds the fresh dot products to the current output.
            """
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return builder.get()

            # View 4 bytes of the data operand as one int32, broadcast it to
            # 16 lanes, then view the 16 x int32 vector as 64 x int8.
            packed = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            lhs = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                       packed.astype('int32x16'))
            rhs = ins[1].vload([0, 0], "int8x64")

            # pmaddubs.w: u8 x s8 products, adjacent pairs summed -> int16x32.
            pairs = tvm.call_llvm_intrin('int16x32',
                                         'llvm.x86.avx512.pmaddubs.w.512',
                                         tvm.const(0, 'uint32'), lhs, rhs)
            # pmaddw.d against all-ones: adjacent int16 pairs summed -> int32x16.
            quads = tvm.call_llvm_intrin('int32x16',
                                         'llvm.x86.avx512.pmaddw.d.512',
                                         tvm.const(0, 'uint32'),
                                         pairs, tvm.const(1, "int16x32"))

            if index == 0:
                result = quads
            else:
                result = quads + outs[0].vload([0], 'int32x16')
            builder.emit(outs[0].vstore(0, result))
            return builder.get()
        def _instr(index):
            """Emit one tensorize statement for the NEON sdot/udot intrinsic.

            ``index`` 1 zeroes the ``int32_lanes``-wide output (reduce reset).
            ``index`` 0 (body) computes the dot products starting from a zero
            accumulator, so it never reads the output's previous contents.
            Any other index (reduce update) accumulates onto the current output.

            Uses ``ins``, ``outs``, ``dtype``, ``int32_lanes`` and
            ``num_int8_elements`` from the enclosing scope.
            """
            ib = tvm.ir_builder.create()
            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
            dtype_c = '%s32x%d' % (dtype, int32_lanes)

            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, dtype_c)))
                return ib.get()

            a_int8 = ins[0].vload([0], dtype_a)
            re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret',
                                            a_int8)
            # broadcast a
            vec_ai32 = re_int32.astype(dtype_c)

            vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], dtype_b)
            if index == 0:
                # The body computes the whole reduction, so accumulate onto
                # zero instead of whatever the output currently holds.
                vec_c = tvm.const(0, dtype_c)
            else:
                vec_c = outs[0].vload([0], dtype_c)

            inst = 'udot' if dtype == 'uint' else 'sdot'
            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                inst, int32_lanes, int32_lanes * num_int8_elements)
            # The dot intrinsic adds its products into its first operand.
            vdot = tvm.call_llvm_intrin(dtype_c, inst, tvm.const(2, 'uint32'),
                                        vec_c, vec_a, vec_b)
            ib.emit(outs[0].vstore(0, vdot))
            return ib.get()
        def _instr(index):
            """Emit one tensorize statement for the bit-serial popcount kernel.

            ``index`` 1 zeroes the ``uint16x8`` accumulator ``zz`` (reduce
            reset). ``index`` 0 (body) zeroes ``zz`` and then accumulates, so
            it never depends on stale contents. Any other index (reduce update)
            accumulates onto the current ``zz``.

            Uses ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``,
            ``dtype``, ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2`` from
            the enclosing scope.
            """
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()
            if index == 0:
                # The body computes the whole reduction; start from zero
                # instead of accumulating onto whatever zz holds.
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))

            # Load width follows the reduction tile size (16 or 8 bytes).
            load_dtype = 'uint8x16' if k_i == 16 else 'uint8x8'

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    for i in range(m):
                        ands = ww.vload([bw, i, 0], load_dtype) & xx.vload(
                            [bx, 0], load_dtype)
                        cnts = tvm.popcount(ands)
                        if k_i == 16:
                            # Fold 16 lanes down to 8 by adding the halves.
                            upper_half = tvm.call_pure_intrin(
                                'uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin(
                                'uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        else:  # ki == 8
                            cnts8[i] = cnts
                    # Reduce the per-row vectors 8 -> 4 -> 2 with vpadd, then
                    # join the last two into one uint8x16.
                    for i in range(m // 2):
                        cnts4[i] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1, cnts8[i * 2],
                            cnts8[i * 2 + 1])
                    for i in range(m // 4):
                        cnts2[i] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1, cnts4[i * 2],
                            cnts4[i * 2 + 1])
                    cnts = tvm.call_pure_intrin('uint8x16',
                                                'vectorcombine', cnts2[0],
                                                cnts2[1])
                    # Weight the counts by the bit-plane pair (bw, bx).
                    shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                    out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                               zz.vload(0, 'uint16x8'),
                                               shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()
Example #4
0
 def _instr(index):
     """Emit one tensorize statement for the bit-serial popcount kernel.

     ``index`` 1 zeroes the ``return_dtype`` accumulator ``zz`` (reduce
     reset). ``index`` 0 (body) zeroes ``zz`` and then accumulates, so it
     never depends on stale contents. Any other index (reduce update)
     accumulates onto the current ``zz``.

     Uses ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``,
     ``unipolar``, ``full_dtype``, ``half_dtype``, ``pack_dtype``,
     ``return_dtype``, ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2``
     from the enclosing scope.
     """
     irb = tvm.ir_builder.create()
     if index == 1: # reduce reset
         irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
         return irb.get()
     if index == 0:
         # body: computes the whole reduction, so start from zero instead of
         # accumulating onto whatever zz currently holds
         irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))

     # load width and lane dtype follow the reduction tile size (16 or 8)
     if k_i == 16:
         load_dtype, lane_dtype = 'uint8x16', full_dtype
     else: # ki == 8
         load_dtype, lane_dtype = 'uint8x8', half_dtype

     cnts8 = [None] * 8
     cnts4 = [None] * 4
     cnts2 = [None] * 2
     for bw in range(w_b):
         for bx in range(x_b):
             for i in range(m):
                 w_ = ww.vload([bw, i, 0], load_dtype).astype(lane_dtype)
                 x_ = xx.vload([bx, 0], load_dtype).astype(lane_dtype)
                 if unipolar:
                     cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                 else:
                     cnts = tvm.popcount(w_ & x_)
                 if k_i == 16:
                     # fold the full-width vector to half width by adding halves
                     upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
                     lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                     cnts8[i] = upper_half + lower_half
                 else:
                     cnts8[i] = cnts
             # reduce the per-row vectors 8 -> 4 -> 2 with vpadd, then join the
             # last two into one full-width vector
             for i in range(m//2):
                 cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                 args_1, cnts8[i*2], cnts8[i*2+1])
             for i in range(m//4):
                 cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                 args_1, cnts4[i*2], cnts4[i*2+1])
             cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
             # weight the counts by the bit-plane pair (bw, bx)
             shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
             out = tvm.call_llvm_intrin(return_dtype, vpadalu,
                                        args_2, zz.vload(0, return_dtype), shifted_cnts)
             irb.emit(zz.vstore(0, out))
     return irb.get()
Example #5
0
def test_reinterpret():
    """Check that an int32 -> float32 ``reinterpret`` compiles through the C
    backend and reproduces NumPy's bitwise ``view``."""
    length = 1024
    length_expr = tvm.convert(length)
    A = tvm.placeholder((length_expr, ), name='A', dtype="int32")
    B = tvm.compute(
        A.shape,
        lambda *idx: tvm.call_pure_intrin("float32", "reinterpret", A(*idx)),
        name='B')
    s = tvm.create_schedule(B.op)

    def check_c():
        built = tvm.build(s, [A, B], "c", name="reinterpret")
        tmp_dir = util.tempdir()
        lib_path = tmp_dir.relpath("temp.so")
        built.export_library(lib_path)
        loaded = tvm.module.load(lib_path)
        freinterpret = loaded['reinterpret']
        ctx = tvm.cpu(0)
        src = np.random.randint(-2**30, 2**30, size=length).astype(A.dtype)
        a = tvm.nd.array(src, ctx)
        b = tvm.nd.array(np.zeros(length, dtype=B.dtype), ctx)
        freinterpret(a, b)
        expected = a.asnumpy().view('float32')
        tvm.testing.assert_allclose(b.asnumpy(), expected)

    check_c()
Example #6
0
        def _instr(index):
            """Emit one tensorize statement for the bit-serial popcount kernel.

            ``index`` 1 zeroes the ``uint16x8`` accumulator ``zz`` (reduce
            reset). ``index`` 0 (body) zeroes ``zz`` and then accumulates, so
            it never depends on stale contents. Any other index (reduce update)
            accumulates onto the current ``zz``.

            Uses ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``,
            ``dtype``, ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2`` from
            the enclosing scope.
            """
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()
            if index == 0:
                # The body computes the whole reduction; start from zero
                # instead of accumulating onto whatever zz holds.
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))

            # Load width follows the reduction tile size (16 or 8 bytes).
            load_dtype = 'uint8x16' if k_i == 16 else 'uint8x8'

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    for i in range(m):
                        ands = ww.vload([bw, i, 0], load_dtype) & xx.vload([bx, 0], load_dtype)
                        cnts = tvm.popcount(ands)
                        if k_i == 16:
                            # Fold 16 lanes down to 8 by adding the halves.
                            upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        else: # ki == 8
                            cnts8[i] = cnts
                    # Reduce the per-row vectors 8 -> 4 -> 2 with vpadd, then
                    # join the last two into one uint8x16.
                    for i in range(m//2):
                        cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                        args_1, cnts8[i*2], cnts8[i*2+1])
                    for i in range(m//4):
                        cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                        args_1, cnts4[i*2], cnts4[i*2+1])
                    cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                    # Weight the counts by the bit-plane pair (bw, bx).
                    shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                    out = tvm.call_llvm_intrin('uint16x8', vpadalu,
                                               args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()
Example #7
0
def test_llvm_intrin():
    """Check that a ``prefetch`` intrinsic call builds with the LLVM backend."""
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    # Address of A[0] plus three integer flags. Presumably these map onto
    # LLVM's llvm.prefetch(address, rw, locality, cache_type), i.e. read,
    # highest locality, data cache. Confirm against the LLVM codegen.
    args = [tvm.call_pure_intrin("handle", "tvm_address_of", A[0]), 0, 3, 1]
    ib.emit(
        tvm.make.Evaluate(
            tvm.make.Call("int32", "prefetch", args, tvm.expr.Call.Intrinsic,
                          None, 0)))
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
    # Building is the check itself: it must lower without raising.
    tvm.build(func, None, "llvm")
Example #8
0
        def _instr(index):
            """Emit one tensorize statement for the int8 dot product on x86.

            Uses AVX-512 VNNI ``vpdpbusd`` when the current LLVM knows that
            intrinsic, otherwise the AVX-512 pmaddubs/pmaddw pair. ``index``
            1 zeroes the ``int32x16`` output, 0 overwrites it with the fresh
            result, and any other value adds the fresh result to it.
            """
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return builder.get()

            # Four data bytes as one int32, broadcast across 16 lanes.
            packed = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            lhs_i32 = packed.astype('int32x16')
            rhs = ins[1].vload([0, 0], "int8x64")

            vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
            vnni_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)

            if vnni_id == 0:
                # No VNNI in this LLVM: two-step multiply-add on AVX-512.
                lhs = tvm.call_pure_intrin('int8x64', 'reinterpret', lhs_i32)
                pairs = tvm.call_llvm_intrin(
                    'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                    tvm.const(0, 'uint32'), lhs, rhs)
                quads = tvm.call_llvm_intrin(
                    'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
                    tvm.const(0, 'uint32'), pairs, tvm.const(1, "int16x32"))
            else:
                # Single VNNI instruction accumulating onto a zero vector.
                rhs_i32 = tvm.call_pure_intrin('int32x16', 'reinterpret', rhs)
                quads = tvm.call_llvm_intrin(
                    'int32x16', vnni_inst_name, tvm.const(0, 'uint32'),
                    tvm.const(0, "int32x16"), lhs_i32, rhs_i32)

            if index != 0:
                quads = quads + outs[0].vload([0], 'int32x16')
            builder.emit(outs[0].vstore(0, quads))
            return builder.get()
Example #9
0
def test_llvm_intrin():
    """Check that a ``prefetch`` intrinsic call builds with the LLVM backend."""
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    # Address of A[0] followed by three integer flags. These presumably
    # correspond to LLVM's llvm.prefetch rw / locality / cache-type operands;
    # confirm against the LLVM codegen.
    args = [
        tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
        0, 3, 1
    ]
    ib.emit(tvm.make.Evaluate(
        tvm.make.Call(
            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
    # Building is the check itself: it must lower without raising.
    tvm.build(func, None, "llvm")
Example #10
0
        def _instr(index):
            """Emit one tensorize statement for four AVX-512 pmaddubs blocks.

            The ``int16`` output is handled as four ``int16x32`` chunks at
            offsets 0, 32, 64 and 96. ``index`` 1 zeroes every chunk, 0
            overwrites each chunk with the fresh pair sums, and any other
            value adds the fresh pair sums to the stored chunk.
            """
            builder = tvm.ir_builder.create()
            offsets = [chunk * 32 for chunk in range(4)]
            if index == 1:
                for offset in offsets:
                    builder.emit(outs[0].vstore([offset],
                                                tvm.const(0, 'int16x32')))
                return builder.get()

            # Two data bytes as one int16, broadcast to 32 lanes, seen as 64 bytes.
            packed = tvm.call_pure_intrin('int16', 'reinterpret',
                                          ins[0].vload([0], "uint8x2"))
            lhs = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                       packed.astype('int16x32'))

            for offset in offsets:
                rhs = ins[1].vload([offset, 0], "int8x64")
                pairs = tvm.call_llvm_intrin(
                    'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                    tvm.const(0, 'uint32'), lhs, rhs)
                if index != 0:
                    pairs = pairs + outs[0].vload([offset], 'int16x32')
                builder.emit(outs[0].vstore([offset], pairs))
            return builder.get()
Example #11
0
        def _instr(index):
            """Emit one tensorize statement for the AVX2 u8 x s8 -> s32 dot.

            The 16 int32 outputs are produced as two ``int32x8`` halves at
            offsets 0 and 8. ``index`` 1 zeroes all 16 outputs, 0 overwrites
            each half with the fresh result, and any other value adds the
            fresh result to the stored half.
            """
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return builder.get()

            # Four data bytes as one int32, broadcast to 8 lanes, seen as 32 bytes.
            packed = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            lhs = tvm.call_pure_intrin('int8x32', 'reinterpret',
                                       packed.astype('int32x8'))
            ones = tvm.const(1, "int16x16")

            for half in (0, 8):
                rhs = ins[1].vload([half, 0], "int8x32")
                pairs = tvm.call_llvm_intrin(
                    'int16x16', 'llvm.x86.avx2.pmadd.ub.sw',
                    tvm.const(0, 'uint32'), lhs, rhs)
                quads = tvm.call_llvm_intrin(
                    'int32x8', 'llvm.x86.avx2.pmadd.wd',
                    tvm.const(0, 'uint32'), pairs, ones)
                if index != 0:
                    quads = quads + outs[0].vload([half], 'int32x8')
                builder.emit(outs[0].vstore([half], quads))
            return builder.get()
Example #12
0
        def _instr(index):
            """Emit one tensorize statement of the AVX-512 int8 dot product.

            index 1 -> zero the int32x16 output; index 0 -> store the fresh
            result; any other index -> add the fresh result to the output.
            """
            ib = tvm.ir_builder.create()
            out_buf = outs[0]
            if index == 1:
                ib.emit(out_buf.vstore(0, tvm.const(0, 'int32x16')))
                return ib.get()

            a_bytes = ins[0].vload([0], "uint8x4")
            a_word = tvm.call_pure_intrin('int32', 'reinterpret', a_bytes)
            # astype broadcasts the scalar word to all 16 lanes.
            a_vec = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                         a_word.astype('int32x16'))
            b_vec = ins[1].vload([0, 0], "int8x64")

            pair_sums = tvm.call_llvm_intrin(
                'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                tvm.const(0, 'uint32'), a_vec, b_vec)
            quad_sums = tvm.call_llvm_intrin(
                'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
                tvm.const(0, 'uint32'), pair_sums, tvm.const(1, "int16x32"))

            if index != 0:
                quad_sums = quad_sums + out_buf.vload([0], 'int32x16')
            ib.emit(out_buf.vstore(0, quad_sums))
            return ib.get()
 def get_vthread(name):
     """Build a loop nest that calls ``Run`` on a shared buffer under two
     ``virtual_thread`` attributes, both created with thread-axis *name*.

     Uses ``n``, ``m`` and ``nthread`` from the enclosing scope.
     """
     tx = tvm.thread_axis(name)
     ty = tvm.thread_axis(name)
     builder = tvm.ir_builder.create()
     A = builder.pointer("float32", name="A")
     C = builder.pointer("float32", name="C")
     with builder.for_range(0, n) as i:
         for axis in (tx, ty):
             builder.scope_attr(axis, "virtual_thread", nthread)
         B = builder.allocate("float32", m, name="B", scope="shared")
         B[i] = A[i * nthread + tx]
         shared_buf = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
         read_ptr = shared_buf.access_ptr("r")
         context_id = tvm.call_pure_intrin("int32", "tvm_context_id")
         builder.emit(tvm.call_extern("int32", "Run", read_ptr, context_id))
         C[i * nthread + tx] = B[i] + 1
     return builder.get()
Example #14
0
 def get_vthread(name):
     """Return IR in which each iteration copies A into a shared buffer,
     passes it to the extern ``Run`` together with ``tvm_context_id``, and
     writes ``B + 1`` into C, all under two ``virtual_thread`` attributes.

     Relies on ``n``, ``m`` and ``nthread`` defined in the enclosing scope.
     """
     vt_x, vt_y = tvm.thread_axis(name), tvm.thread_axis(name)
     ib = tvm.ir_builder.create()
     A = ib.pointer("float32", name="A")
     C = ib.pointer("float32", name="C")
     with ib.for_range(0, n) as i:
         ib.scope_attr(vt_x, "virtual_thread", nthread)
         ib.scope_attr(vt_y, "virtual_thread", nthread)
         B = ib.allocate("float32", m, name="B", scope="shared")
         B[i] = A[i * nthread + vt_x]
         view = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
         run_call = tvm.call_extern(
             "int32", "Run", view.access_ptr("r"),
             tvm.call_pure_intrin("int32", "tvm_context_id"))
         ib.emit(run_call)
         C[i * nthread + vt_x] = B[i] + 1
     return ib.get()
Example #15
0
def mylog(x):
    """Wrap *x* in a call to the custom pure intrinsic ``mylog``.

    The call has the same dtype as *x*. How ``mylog`` is lowered is not shown
    here; presumably an intrinsic rule is registered elsewhere.
    """
    result_dtype = x.dtype
    return tvm.call_pure_intrin(result_dtype, "mylog", x)
Example #16
0
    def vectorize(op):
        """Rewrite a tagged 16x4 reduction loop nest into one ``vpdpbusd`` call.

        Appears to be an IR-transform callback (presumably the post-order
        callback of ``tvm.ir_pass.IRTransform``; confirm at the call site).
        It relies on enclosing-scope state not visible here:
        ``outer_loops`` (a stack of enclosing ``For`` nodes, popped on every
        ``For`` visited here), ``to_vectorize`` (a stack of loop variables
        still to rewrite) and the helper ``as_const_int``.

        Returns the replacement statement, or ``None`` when ``op`` is not
        rewritten (presumably meaning "keep ``op``"; confirm).
        """
        if isinstance(op, tvm.stmt.For):
            outer_loops.pop()
            if to_vectorize:
                # The initialization loop of the tagged axis is only marked
                # as vectorized.
                if str(op.loop_var) == f'{to_vectorize[-1]}.init':
                    return tvm.tir.For(op.loop_var, op.min, op.extent,
                                       tvm.stmt.For.Vectorized, op.device_api,
                                       op.body)
                elif str(op.loop_var) == str(to_vectorize[-1]):
                    # Collect the For loops, Loads, the single Store and an
                    # optional IfThenElse guard found inside this loop.
                    loops = []
                    loads = []
                    store = [None]
                    guard = [None]

                    def get_loops(op):
                        if isinstance(op, tvm.stmt.For):
                            loops.append(op)
                        elif isinstance(op, tvm.expr.Load):
                            loads.append(op)
                        elif isinstance(op, tvm.stmt.Store):
                            assert store[0] is None
                            store[0] = op
                        elif isinstance(op, tvm.stmt.IfThenElse):
                            guard[0] = op

                    tvm.ir_pass.PostOrderVisit(op, get_loops)
                    # Post-order visiting appends the innermost For first; the
                    # unpacking requires exactly two For nodes to be visited.
                    inner, outer = loops
                    # Reorder to outermost-first so iters below end with
                    # (..., outer, inner).
                    loops = loops[::-1]

                    inner_ext = as_const_int(inner.extent)
                    outer_ext = as_const_int(outer.extent)
                    assert inner_ext is not None and outer_ext is not None
                    # Only a constant 16 (outer) x 4 (inner) nest is supported,
                    # i.e. 64 byte elements per operand.
                    assert outer_ext == 16 and inner_ext == 4

                    # NOTE(review): `empty` is never used below.
                    empty = {
                        outer.loop_var: tvm.const(0, 'int32'),
                        inner.loop_var: tvm.const(0, 'int32')
                    }

                    operands = []
                    indeces = []
                    for elem in loads:
                        # Express the load index as a linear function of all
                        # enclosing loop vars: coef[k] multiplies iters[k] and
                        # coef[-1] is the constant term.
                        iters = [i.loop_var for i in outer_loops + loops]
                        coef = tvm.arith.DetectLinearEquation(
                            elem.index, iters)
                        # Offset contributed by everything except the two
                        # loops being vectorized.
                        base_index = sum(
                            i * j for i, j in zip(iters[:-2], coef)) + coef[-1]
                        inner_stride = as_const_int(coef[-2])
                        outer_stride = as_const_int(coef[-3])
                        assert inner_stride is not None and outer_stride is not None

                        # A load from the buffer being stored to is the
                        # accumulator: remember its 16-lane contiguous index.
                        if tvm.ir_pass.Equal(elem.buffer_var,
                                             store[0].buffer_var):
                            index = tvm.tir.Ramp(base_index,
                                                 tvm.const(1, 'int32'), 16)
                            continue

                        # Offsets (relative to base_index) of the 16x4
                        # elements this load touches, outer-major order.
                        indeces = []
                        for i in range(outer_ext):
                            for j in range(inner_ext):
                                indeces.append(i * outer_stride +
                                               j * inner_stride)
                        # Load one contiguous span covering every offset.
                        bound = max(indeces) + 1
                        to_load = tvm.tir.Ramp(base_index,
                                               tvm.const(1, 'int32'), bound)
                        value = tvm.tir.Load(elem.dtype + 'x%d' % bound,
                                             elem.buffer_var, to_load,
                                             tvm.const(1, 'int32x%d' % bound))
                        assert 64 % bound == 0

                        # Gather the 64 needed lanes out of the loaded span
                        # (the span is replicated so the shuffle input spans
                        # 64 elements).
                        operands.append(
                            tvm.tir.Shuffle(
                                [value] * (64 // bound),
                                [tvm.const(i, 'int32') for i in indeces]))

                    buffer_var = store[0].buffer_var

                    # The accumulator goes first.
                    # NOTE(review): `index` is only bound if some load reads
                    # the stored buffer; otherwise this raises NameError.
                    operands = [
                        tvm.tir.Load('int32x16', buffer_var, index,
                                     tvm.const(1, 'int32x16'))
                    ] + operands

                    # Every operand is passed to vpdpbusd as int32x16.
                    operands = [
                        tvm.call_pure_intrin('int32x16', 'reinterpret', i)
                        for i in operands
                    ]

                    res = tvm.call_llvm_intrin('int32x16',
                                               'llvm.x86.avx512.vpdpbusd.512',
                                               tvm.const(0,
                                                         'uint32'), *operands)

                    # Write the result back, keeping the original guard.
                    res = tvm.tir.Store(buffer_var, res, index,
                                        tvm.const(1, 'int32x16'))
                    if guard[0] is not None:
                        res = tvm.tir.IfThenElse(guard[0].condition, res, None)
                    return res
        elif isinstance(op, tvm.stmt.AttrStmt):
            if not to_vectorize:
                return None
            # Strip the attribute that tags the axis to rewrite and pop it.
            if tvm.ir_pass.Equal(op.node.var, to_vectorize[-1]):
                to_vectorize.pop()
                return op.body
        return None
Example #17
0
def mylog(x):
    """Build a ``mylog`` pure-intrinsic call on *x*, keeping *x*'s dtype.

    Lowering of ``mylog`` is not defined here (presumably via an intrinsic
    rule registered elsewhere).
    """
    out_dtype = x.dtype
    call = tvm.call_pure_intrin(out_dtype, "mylog", x)
    return call
Example #18
0
def atomic_add(x, y):
    """Build an ``atomic_add`` intrinsic call with operands *x* and *y*.

    *x* is expected to be an address (callers pass ``tvm_address_of(...)``);
    the call takes *y*'s dtype.

    NOTE(review): the call is marked *pure* although an atomic add has side
    effects, so IR passes may treat it as removable or movable. Callers keep
    it alive by storing its result. Confirm whether a non-pure intrinsic call
    is supported by the lowering rule.
    """
    result_dtype = y.dtype
    return tvm.call_pure_intrin(result_dtype, "atomic_add", x, y)
Example #19
0
def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index,
                        score_index):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also prepares to move valid boxes to the
    top of input data.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    valid_count : Buffer
        1D buffer for valid number of boxes with shape [batch_size, ].

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    score_threshold : float32
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    score_index: optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.

    Notes
    -----
    The IR consists of two kernels in sequence. The first zeroes
    ``valid_count``. The second sets ``flag`` and atomically increments
    ``valid_count``. Zeroing and incrementing in a single kernel would race:
    threads of other blocks could increment a counter before the thread that
    zeroes it runs.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)

    valid_count = ib.buffer_ptr(valid_count)
    flag = ib.buffer_ptr(flag)
    one_count = tvm.const(1, dtype=valid_count.dtype)
    score_threshold = tvm.make.node("FloatImm",
                                    dtype="float32",
                                    value=score_threshold)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors // max_threads + 1
    idxd = tvm.indexdiv

    # Kernel 1: initialize valid_count. It must finish before any thread
    # starts incrementing the counters.
    with ib.new_scope():
        tx = tvm.thread_axis("threadIdx.x")
        bx = tvm.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < batch_size):
            valid_count[tid] = 0

    # Kernel 2: flag each box that passes the score / class-id test and count
    # it in its batch's valid_count entry.
    with ib.new_scope():
        tx = tvm.thread_axis("threadIdx.x")
        bx = tvm.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        atomic_add_return = ib.allocate(valid_count.dtype, (1, ),
                                        name='atomic_add_return',
                                        scope='local')
        with ib.if_scope(tid < batch_size * num_anchors):
            # initialize flag (same thread that may set it below)
            flag[tid] = 0
            i = idxd(tid, num_anchors)
            with ib.if_scope(
                    tvm.all(
                        data[tid * elem_length + score_index] > score_threshold,
                        tvm.any(id_index < 0,
                                data[tid * elem_length + id_index] >= 0))):
                flag[tid] = 1
                atomic_add_return[0] = atomic_add(
                    tvm.call_pure_intrin("handle", "tvm_address_of",
                                         valid_count[i]), one_count)

    return ib.get()