Ejemplo n.º 1
0
        def _instr(index):
            """Emit the IR for one phase of the AVX-512 int8 dot-product intrinsic.

            ``index`` 1 zeroes the 16 int32 outputs (reduce reset), 0 stores the
            fresh reduction, and any other value adds the reduction to the
            current output. ``ins``/``outs`` come from the enclosing scope.
            """
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return builder.get()

            # View A's 4 uint8 values as one int32, broadcast it to 16 lanes and
            # view the result as 64 int8 values to line up with B's 64 values.
            a_word = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            a_bytes = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                           a_word.astype('int32x16'))
            b_bytes = ins[1].vload([0, 0], "int8x64")

            # pmaddubs: byte products summed in adjacent pairs -> 32 x int16 ...
            pairs = tvm.call_llvm_intrin('int16x32',
                                         'llvm.x86.avx512.pmaddubs.w.512',
                                         tvm.const(0, 'uint32'),
                                         a_bytes, b_bytes)
            # ... pmaddw with a vector of ones sums adjacent int16 -> 16 x int32.
            quads = tvm.call_llvm_intrin('int32x16',
                                         'llvm.x86.avx512.pmaddw.d.512',
                                         tvm.const(0, 'uint32'),
                                         pairs, tvm.const(1, "int16x32"))

            if index == 0:
                result = quads
            else:
                result = quads + outs[0].vload([0], 'int32x16')
            builder.emit(outs[0].vstore(0, result))
            return builder.get()
Ejemplo n.º 2
0
        def _instr(index):
            """Emit the IR for one phase of the AArch64 int8 dot-product intrinsic.

            ``index`` 1 zeroes the ``int32_lanes``-wide output (reduce reset).
            Index 0 computes the dot product from a zero accumulator. Any other
            value accumulates into the current output. ``dtype`` ('uint' selects
            ``udot``, anything else ``sdot``), ``int32_lanes``,
            ``num_int8_elements``, ``ins`` and ``outs`` come from the enclosing
            scope.
            """
            ib = tvm.ir_builder.create()
            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
            dtype_c = '%s32x%d' % (dtype, int32_lanes)

            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, dtype_c)))
                return ib.get()

            a_int8 = ins[0].vload([0], dtype_a)
            re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret',
                                            a_int8)
            # broadcast a
            vec_ai32 = re_int32.astype(dtype_c)

            vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], dtype_b)
            # The body (index 0) computes the whole reduction, so it must start
            # from zero rather than read the output, which no reset has
            # initialised. Only the update phase accumulates into the output.
            if index == 0:
                vec_c = tvm.const(0, dtype_c)
            else:
                vec_c = outs[0].vload([0], dtype_c)

            inst = 'udot' if dtype == 'uint' else 'sdot'
            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                inst, int32_lanes, int32_lanes * num_int8_elements)
            vdot = tvm.call_llvm_intrin(dtype_c, inst, tvm.const(2, 'uint32'),
                                        vec_c, vec_a, vec_b)
            ib.emit(outs[0].vstore(0, vdot))
            return ib.get()
Ejemplo n.º 3
0
def test_llvm_lookup_intrin():
    """Check that a call to an LLVM intrinsic named by string builds for llvm.

    Emits ``call_llvm_intrin("uint8x8", "llvm.ctpop.i8", ...)``, wraps it with
    ``MakeAPI`` and asserts that ``tvm.build`` returns a module.
    """
    ib = tvm.ir_builder.create()
    A = ib.pointer("uint8x8", name="A")
    # NOTE(review): the intrinsic receives the buffer handle ``A`` itself
    # rather than a loaded uint8x8 value. The test appears to exercise name
    # lookup and build rather than ctpop results; confirm.
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    fcode = tvm.build(func, None, "llvm")
    assert fcode is not None
Ejemplo n.º 4
0
def test_llvm_lookup_intrin():
    """Check that a call to an LLVM intrinsic named by string builds for llvm.

    Emits ``call_llvm_intrin("uint8x8", "llvm.ctpop.i8", ...)``, wraps it with
    ``MakeAPI`` and asserts that ``tvm.build`` returns a module.
    """
    ib = tvm.ir_builder.create()
    A = ib.pointer("uint8x8", name="A")
    # NOTE(review): the intrinsic receives the buffer handle ``A`` itself
    # rather than a loaded uint8x8 value. The test appears to exercise name
    # lookup and build rather than ctpop results; confirm.
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    fcode = tvm.build(func, None, "llvm")
    assert fcode is not None
Ejemplo n.º 5
0
        def _instr(index):
            """Emit the IR for one phase of the AVX-512 int8 dot-product intrinsic.

            Uses ``vpdpbusd`` when the LLVM in use knows the intrinsic and falls
            back to ``pmaddubs.w`` + ``pmaddw.d`` otherwise. ``index`` 1 zeroes
            the 16 int32 outputs, 0 stores the fresh reduction, and any other
            value adds it to the current output. ``ins``/``outs`` come from the
            enclosing scope.
            """
            ib = tvm.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return ib.get()

            # Broadcast A's 4 uint8 values (viewed as one int32) to 16 lanes.
            a_packed = tvm.call_pure_intrin('int32', 'reinterpret',
                                            ins[0].vload([0], "uint8x4"))
            a_lanes = a_packed.astype('int32x16')
            b_bytes = ins[1].vload([0, 0], "int8x64")

            vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
            # A zero id means this LLVM version does not know the VNNI intrinsic.
            vnni_missing = tvm.target.codegen.llvm_lookup_intrinsic_id(
                vnni_inst_name) == 0

            if vnni_missing:
                a_bytes = tvm.call_pure_intrin('int8x64', 'reinterpret', a_lanes)
                pairs = tvm.call_llvm_intrin(
                    'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                    tvm.const(0, 'uint32'), a_bytes, b_bytes)
                reduced = tvm.call_llvm_intrin(
                    'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
                    tvm.const(0, 'uint32'), pairs, tvm.const(1, "int16x32"))
            else:
                # One VNNI instruction, accumulating into a zero vector.
                b_words = tvm.call_pure_intrin('int32x16', 'reinterpret',
                                               b_bytes)
                reduced = tvm.call_llvm_intrin(
                    'int32x16', vnni_inst_name, tvm.const(0, 'uint32'),
                    tvm.const(0, "int32x16"), a_lanes, b_words)

            if index != 0:
                reduced = reduced + outs[0].vload([0], 'int32x16')
            ib.emit(outs[0].vstore(0, reduced))
            return ib.get()
Ejemplo n.º 6
0
        def _instr(index):
            """Emit the IR for one phase of the bitserial popcount intrinsic.

            ``index`` 1 zeroes the uint16x8 accumulator ``zz``. Any other index
            processes every (``bw``, ``bx``) pair:
            1. AND each of the ``m`` weight rows with the input and popcount.
            2. Reduce the rows with ``vpadd``.
            3. Shift the combined counts by ``bw + bx``.
            4. Fold the result into ``zz`` via ``vpadalu``.
            ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``, ``dtype``,
            ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2`` come from the
            enclosing scope.
            """
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for row in range(m):
                            both = (ww.vload([bw, row, 0], 'uint8x16')
                                    & xx.vload([bx, 0], 'uint8x16'))
                            counts = tvm.popcount(both)
                            # Fold 16 lanes into 8 by adding the two halves.
                            high = tvm.call_pure_intrin('uint8x8', 'vectorhigh',
                                                        counts)
                            low = tvm.call_pure_intrin('uint8x8', 'vectorlow',
                                                       counts)
                            cnts8[row] = high + low
                    else:  # k_i == 8
                        for row in range(m):
                            both = (ww.vload([bw, row, 0], 'uint8x8')
                                    & xx.vload([bx, 0], 'uint8x8'))
                            cnts8[row] = tvm.popcount(both)

                    # Pairwise-add the row counts: m -> m/2 -> m/4 vectors.
                    for j in range(m // 2):
                        cnts4[j] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1,
                            cnts8[2 * j], cnts8[2 * j + 1])
                    for j in range(m // 4):
                        cnts2[j] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1,
                            cnts4[2 * j], cnts4[2 * j + 1])
                    combined = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                                    cnts2[0], cnts2[1])
                    shifted = combined << tvm.const(bw + bx, dtype)
                    acc = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                               zz.vload(0, 'uint16x8'), shifted)
                    irb.emit(zz.vstore(0, acc))
            return irb.get()
Ejemplo n.º 7
0
        def _instr(index):
            """Emit the IR for one phase of the AVX-512 int8 dot-product intrinsic.

            ``index`` 1 zeroes the 16 int32 outputs (reduce reset), 0 stores the
            fresh reduction, and any other value adds the reduction to the
            current output. ``ins``/``outs`` come from the enclosing scope.
            """
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return builder.get()

            # View A's 4 uint8 values as one int32, broadcast it to 16 lanes and
            # view the result as 64 int8 values to line up with B's 64 values.
            a_word = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            a_bytes = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                           a_word.astype('int32x16'))
            b_bytes = ins[1].vload([0, 0], "int8x64")

            # pmaddubs: byte products summed in adjacent pairs -> 32 x int16 ...
            pairs = tvm.call_llvm_intrin(
                'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                tvm.const(0, 'uint32'), a_bytes, b_bytes)
            # ... pmaddw with a vector of ones sums adjacent int16 -> 16 x int32.
            quads = tvm.call_llvm_intrin(
                'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
                tvm.const(0, 'uint32'), pairs, tvm.const(1, "int16x32"))

            if index == 0:
                result = quads
            else:
                result = quads + outs[0].vload([0], 'int32x16')
            builder.emit(outs[0].vstore(0, result))
            return builder.get()
Ejemplo n.º 8
0
        def _instr(index):
            """Emit the IR for one phase of the AVX2 int8 dot-product intrinsic.

            The 16 int32 outputs are produced as two 8-lane halves at offsets 0
            and 8. ``index`` 1 zeroes all 16 outputs, 0 stores the fresh
            reductions, and any other value adds them to the current outputs.
            ``ins``/``outs`` come from the enclosing scope.
            """
            ib = tvm.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return ib.get()

            # View A's 4 uint8 values as one int32, broadcast it to 8 lanes and
            # view the result as 32 int8 values.
            a_word = tvm.call_pure_intrin('int32', 'reinterpret',
                                          ins[0].vload([0], "uint8x4"))
            a_bytes = tvm.call_pure_intrin('int8x32', 'reinterpret',
                                           a_word.astype('int32x8'))
            ones = tvm.const(1, "int16x16")

            for offset in (0, 8):
                b_bytes = ins[1].vload([offset, 0], "int8x32")
                pairs = tvm.call_llvm_intrin(
                    'int16x16', 'llvm.x86.avx2.pmadd.ub.sw',
                    tvm.const(0, 'uint32'), a_bytes, b_bytes)
                quads = tvm.call_llvm_intrin(
                    'int32x8', 'llvm.x86.avx2.pmadd.wd',
                    tvm.const(0, 'uint32'), pairs, ones)
                if index != 0:
                    quads = quads + outs[0].vload([offset], 'int32x8')
                ib.emit(outs[0].vstore([offset], quads))
            return ib.get()
Ejemplo n.º 9
0
 def _instr(index):
     """Emit the IR for one phase of the bitserial popcount intrinsic.

     ``index`` 1 zeroes ``zz`` (reduce reset). Any other index processes every
     (``bw``, ``bx``) pair:
     1. Popcount the ``m`` weight rows against the input.
     2. Reduce the rows with ``vpadd``.
     3. Shift the combined counts by ``bw + bx``.
     4. Fold the result into ``zz`` via ``vpadalu``.
     ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``, ``unipolar``,
     the ``*_dtype`` names, ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2``
     come from the enclosing scope.
     """
     irb = tvm.ir_builder.create()
     if index == 1:  # reduce reset
         irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
         return irb.get()

     def _bit_count(w_, x_):
         # Unipolar mode subtracts the matches against the inverted weights.
         if unipolar:
             return tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
         return tvm.popcount(w_ & x_)

     # body and reduce update
     cnts8 = [None] * 8
     cnts4 = [None] * 4
     cnts2 = [None] * 2
     for bw in range(w_b):
         for bx in range(x_b):
             if k_i == 16:
                 for row in range(m):
                     w_ = ww.vload([bw, row, 0], 'uint8x16').astype(full_dtype)
                     x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
                     counts = _bit_count(w_, x_)
                     # Fold 16 lanes into 8 by adding the two halves.
                     high = tvm.call_pure_intrin(half_dtype, 'vectorhigh', counts)
                     low = tvm.call_pure_intrin(half_dtype, 'vectorlow', counts)
                     cnts8[row] = high + low
             else:  # k_i == 8
                 for row in range(m):
                     w_ = ww.vload([bw, row, 0], 'uint8x8').astype(half_dtype)
                     x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
                     cnts8[row] = _bit_count(w_, x_)

             # Pairwise-add the row counts: m -> m/2 -> m/4 vectors.
             for j in range(m // 2):
                 cnts4[j] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                 cnts8[2 * j], cnts8[2 * j + 1])
             for j in range(m // 4):
                 cnts2[j] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                 cnts4[2 * j], cnts4[2 * j + 1])
             combined = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                             cnts2[0], cnts2[1])
             shifted = combined << tvm.const(bw + bx, pack_dtype)
             acc = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                        zz.vload(0, return_dtype), shifted)
             irb.emit(zz.vstore(0, acc))
     return irb.get()
Ejemplo n.º 10
0
        def _instr(index):
            """Emit the IR for one phase of the bitserial popcount intrinsic.

            ``index`` 1 zeroes the uint16x8 accumulator ``zz``. Any other index
            processes every (``bw``, ``bx``) pair:
            1. AND each of the ``m`` weight rows with the input and popcount.
            2. Reduce the rows with ``vpadd``.
            3. Shift the combined counts by ``bw + bx``.
            4. Fold the result into ``zz`` via ``vpadalu``.
            ``ww``, ``xx``, ``zz``, ``w_b``, ``x_b``, ``k_i``, ``m``, ``dtype``,
            ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2`` come from the
            enclosing scope.
            """
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for row in range(m):
                            both = (ww.vload([bw, row, 0], 'uint8x16')
                                    & xx.vload([bx, 0], 'uint8x16'))
                            counts = tvm.popcount(both)
                            # Fold 16 lanes into 8 by adding the two halves.
                            high = tvm.call_pure_intrin('uint8x8', 'vectorhigh',
                                                        counts)
                            low = tvm.call_pure_intrin('uint8x8', 'vectorlow',
                                                       counts)
                            cnts8[row] = high + low
                    else:  # k_i == 8
                        for row in range(m):
                            both = (ww.vload([bw, row, 0], 'uint8x8')
                                    & xx.vload([bx, 0], 'uint8x8'))
                            cnts8[row] = tvm.popcount(both)

                    # Pairwise-add the row counts: m -> m/2 -> m/4 vectors.
                    for j in range(m // 2):
                        cnts4[j] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1,
                            cnts8[2 * j], cnts8[2 * j + 1])
                    for j in range(m // 4):
                        cnts2[j] = tvm.call_llvm_intrin(
                            'uint8x8', vpadd, args_1,
                            cnts4[2 * j], cnts4[2 * j + 1])
                    combined = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                                    cnts2[0], cnts2[1])
                    shifted = combined << tvm.const(bw + bx, dtype)
                    acc = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                               zz.vload(0, 'uint16x8'), shifted)
                    irb.emit(zz.vstore(0, acc))
            return irb.get()
Ejemplo n.º 11
0
        def _instr(index):
            """Emit the IR for one phase of an AVX-512 uint8 x int8 -> int16 intrinsic.

            Produces 4 chunks of 32 int16 outputs (offsets 0, 32, 64, 96).
            ``index`` 1 zeroes them, 0 stores the fresh pair reductions, and
            any other value adds them to the current outputs. ``ins``/``outs``
            come from the enclosing scope.
            """
            ib = tvm.ir_builder.create()
            chunk_offsets = [chunk * 32 for chunk in range(4)]

            if index == 1:
                for offset in chunk_offsets:
                    ib.emit(outs[0].vstore([offset], tvm.const(0, 'int16x32')))
                return ib.get()

            # View A's 2 uint8 values as one int16, broadcast it to 32 lanes and
            # view the result as 64 int8 values.
            a_half = tvm.call_pure_intrin('int16', 'reinterpret',
                                          ins[0].vload([0], "uint8x2"))
            a_bytes = tvm.call_pure_intrin('int8x64', 'reinterpret',
                                           a_half.astype('int16x32'))

            for offset in chunk_offsets:
                b_bytes = ins[1].vload([offset, 0], "int8x64")
                pairs = tvm.call_llvm_intrin(
                    'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
                    tvm.const(0, 'uint32'), a_bytes, b_bytes)
                if index != 0:
                    pairs = pairs + outs[0].vload([offset], 'int16x32')
                ib.emit(outs[0].vstore([offset], pairs))
            return ib.get()
Ejemplo n.º 12
0
    def vectorize(op):
        """Rewrite callback that lowers a 16x4 reduction loop nest to one
        ``llvm.x86.avx512.vpdpbusd.512`` call.

        Reads and mutates ``outer_loops`` and ``to_vectorize`` from the
        enclosing scope. It appears to be the post-order half of an IR
        transform whose pre-order half pushes onto ``outer_loops`` (TODO
        confirm).

        Returns a replacement statement, or ``None`` to keep ``op`` unchanged.
        """
        if isinstance(op, tvm.stmt.For):
            # Leaving this loop: drop it from the stack of enclosing loops.
            outer_loops.pop()
            if to_vectorize:
                # Loop named '<target>.init': just mark it vectorized.
                if str(op.loop_var) == f'{to_vectorize[-1]}.init':
                    return tvm.tir.For(op.loop_var, op.min, op.extent,
                                       tvm.stmt.For.Vectorized, op.device_api,
                                       op.body)
                # The target loop itself: replace the whole nest.
                elif str(op.loop_var) == str(to_vectorize[-1]):
                    loops = []
                    loads = []
                    store = [None]
                    guard = [None]

                    # Collect every For, every Load, the single Store and the
                    # (last visited) IfThenElse guard inside ``op``.
                    def get_loops(op):
                        if isinstance(op, tvm.stmt.For):
                            loops.append(op)
                        elif isinstance(op, tvm.expr.Load):
                            loads.append(op)
                        elif isinstance(op, tvm.stmt.Store):
                            assert store[0] is None
                            store[0] = op
                        elif isinstance(op, tvm.stmt.IfThenElse):
                            guard[0] = op

                    tvm.ir_pass.PostOrderVisit(op, get_loops)
                    # Post-order lists the inner loop first and ``op`` itself
                    # last. Exactly two loops are expected; unpacking raises
                    # otherwise.
                    inner, outer = loops
                    # Outer-to-inner order, matching ``outer_loops`` below.
                    loops = loops[::-1]

                    inner_ext = as_const_int(inner.extent)
                    outer_ext = as_const_int(outer.extent)
                    assert inner_ext is not None and outer_ext is not None
                    # Only a 16 (outer) x 4 (inner) nest is supported.
                    assert outer_ext == 16 and inner_ext == 4

                    # NOTE(review): ``empty`` is not used below.
                    empty = {
                        outer.loop_var: tvm.const(0, 'int32'),
                        inner.loop_var: tvm.const(0, 'int32')
                    }

                    operands = []
                    # NOTE(review): reassigned per load below; unused as is.
                    indeces = []
                    for elem in loads:
                        iters = [i.loop_var for i in outer_loops + loops]
                        # coef[k] is the stride of iters[k]; coef[-1] is the
                        # constant term. If the index is not linear this is
                        # presumably empty and the indexing below raises.
                        coef = tvm.arith.DetectLinearEquation(
                            elem.index, iters)
                        # Contribution of all loops except the two being
                        # replaced, plus the constant term.
                        base_index = sum(
                            i * j for i, j in zip(iters[:-2], coef)) + coef[-1]
                        inner_stride = as_const_int(coef[-2])
                        outer_stride = as_const_int(coef[-3])
                        assert inner_stride is not None and outer_stride is not None

                        # A load from the store's own buffer is the accumulator
                        # read: address 16 consecutive elements from base_index.
                        if tvm.ir_pass.Equal(elem.buffer_var,
                                             store[0].buffer_var):
                            index = tvm.tir.Ramp(base_index,
                                                 tvm.const(1, 'int32'), 16)
                            continue

                        # Offsets of all 16x4 elements relative to base_index.
                        # Load the contiguous range covering them, then shuffle
                        # it into the 64-element operand order.
                        indeces = []
                        for i in range(outer_ext):
                            for j in range(inner_ext):
                                indeces.append(i * outer_stride +
                                               j * inner_stride)
                        bound = max(indeces) + 1
                        to_load = tvm.tir.Ramp(base_index,
                                               tvm.const(1, 'int32'), bound)
                        value = tvm.tir.Load(elem.dtype + 'x%d' % bound,
                                             elem.buffer_var, to_load,
                                             tvm.const(1, 'int32x%d' % bound))
                        # The loaded vector is repeated 64 // bound times as
                        # shuffle input, so bound must divide 64.
                        assert 64 % bound == 0

                        operands.append(
                            tvm.tir.Shuffle(
                                [value] * (64 // bound),
                                [tvm.const(i, 'int32') for i in indeces]))

                    buffer_var = store[0].buffer_var

                    # The accumulator (current output) goes first. The other
                    # operands follow in load-visit order. NOTE(review): this
                    # assumes the visit order yields the u8 and s8 operands in
                    # the order vpdpbusd expects; confirm.
                    # NOTE(review): ``index`` is only bound if some load reads
                    # the store's buffer; otherwise this raises NameError.
                    operands = [
                        tvm.tir.Load('int32x16', buffer_var, index,
                                     tvm.const(1, 'int32x16'))
                    ] + operands

                    # Every operand is reinterpreted as int32x16 for the call.
                    operands = [
                        tvm.call_pure_intrin('int32x16', 'reinterpret', i)
                        for i in operands
                    ]

                    res = tvm.call_llvm_intrin('int32x16',
                                               'llvm.x86.avx512.vpdpbusd.512',
                                               tvm.const(0,
                                                         'uint32'), *operands)

                    # Write the 16 results back to the accumulator locations,
                    # re-applying the original guard if there was one.
                    res = tvm.tir.Store(buffer_var, res, index,
                                        tvm.const(1, 'int32x16'))
                    if guard[0] is not None:
                        res = tvm.tir.IfThenElse(guard[0].condition, res, None)
                    return res
        elif isinstance(op, tvm.stmt.AttrStmt):
            if not to_vectorize:
                return None
            # Attribute on the current target: pop it from the work list and
            # strip the AttrStmt, keeping only its body.
            if tvm.ir_pass.Equal(op.node.var, to_vectorize[-1]):
                to_vectorize.pop()
                return op.body
        return None