Example #1
def _unipolar_conv(n, h, w, co, vh, vw, vc):
    return tvm.sum(
        ((tvm.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
                       data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) -
          tvm.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
                       data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')))
         << (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci])
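The unipolar trick used above (and in most of the examples that follow) is worth spelling out: with 0/1 data bits and weights whose bit 1 encodes +1 and bit 0 encodes -1, popcount(w & x) - popcount(~w & x) is exactly the signed dot product over the positions where the data bit is set. A minimal pure-Python check of that identity (illustrative values only, not TVM code):

import random

random.seed(0)
wbits = [random.randint(0, 1) for _ in range(32)]  # one weight bit-plane
xbits = [random.randint(0, 1) for _ in range(32)]  # one data bit-plane

w = sum(b << i for i, b in enumerate(wbits))       # pack into 32-bit words
x = sum(b << i for i, b in enumerate(xbits))

popcount = lambda v: bin(v).count('1')

# +1 where both bits are set, -1 where only the data bit is set;
# ~w & x keeps only positions where x's bit is set, so no mask is needed.
lhs = popcount(w & x) - popcount(~w & x)
rhs = sum((2 * wb - 1) * xb for wb, xb in zip(wbits, xbits))
assert lhs == rhs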
Example #2
        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload(
                                [bx, 0], 'uint8x16')
                            cnts = tvm.popcount(ands)
                            upper_half = tvm.call_pure_intrin(
                                'uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin(
                                'uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(
                                'uint8x8', vpadd, args_1, cnts8[i * 2],
                                cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(
                                'uint8x8', vpadd, args_1, cnts4[i * 2],
                                cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin('uint8x16',
                                                    'vectorcombine', cnts2[0],
                                                    cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                                   zz.vload(0, 'uint16x8'),
                                                   shifted_cnts)
                    else:  # k_i == 8
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload(
                                [bx, 0], 'uint8x8')
                            cnts8[i] = tvm.popcount(ands)
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(
                                'uint8x8', vpadd, args_1, cnts8[i * 2],
                                cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(
                                'uint8x8', vpadd, args_1, cnts4[i * 2],
                                cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin('uint8x16',
                                                    'vectorcombine', cnts2[0],
                                                    cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                                   zz.vload(0, 'uint16x8'),
                                                   shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()
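The reduction tree above leans on two NEON operations: vpadd (pairwise add of adjacent uint8 lanes) and vpadalu (pairwise add of uint8 lanes, widened and accumulated into uint16 lanes), which let the per-lane popcounts be summed without overflowing 8 bits. A rough NumPy model of those lane semantics (the helper names and shapes are assumptions for illustration, not TVM or ARM APIs):

import numpy as np

def vpadd_u8(a, b):
    # pairwise add: (a0+a1, a2+a3, ..., b4+b5, b6+b7) -> one 8-lane vector
    both = np.concatenate([a, b])
    return both.reshape(8, 2).sum(axis=1).astype(np.uint8)

def vpadalu_u16(acc, v):
    # pairwise add-long-accumulate: 16 uint8 lanes -> 8 uint16 lanes, added to acc
    return acc + v.reshape(8, 2).sum(axis=1).astype(np.uint16)

cnts = np.random.randint(0, 9, size=16, dtype=np.uint8)  # per-lane popcounts
half = vpadd_u8(cnts[:8], cnts[8:])
assert int(half.sum()) == int(cnts.sum())   # lanes are regrouped, nothing is lost

acc = np.zeros(8, dtype=np.uint16)
acc = vpadalu_u16(acc, cnts)
assert int(acc.sum()) == int(cnts.sum())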
Example #3
def _conv(nn, ff, yy, xx):
    b1b2 = (b1+b2).astype(out_dtype)
    return tvm.sum(
        ((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
                       Filter_q[ff, rc, ry, rx, b2]) -
          tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
                       ~Filter_q[ff, rc, ry, rx, b2]))
         << (b1b2)).astype(out_dtype),
        axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
Example #4
    def _conv(n, h, w, co, vh, vw, vc):
        b1b2 = (b1+b2).astype(out_dtype)
        if dorefa:
            return tvm.sum(
                (tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
                              kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) -
                 tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
                              ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2,
                axis=[dh, dw, ci, b1, b2])

        return tvm.sum(tvm.popcount(
            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
                       axis=[dh, dw, ci, b1, b2])
Example #5
    def _conv(n, h, w, co, vh, vw, vc):
        b1b2 = (b1+b2).astype(out_dtype)
        if unipolar:
            return tvm.sum(
                ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
                               kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
                  tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
                               ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
                axis=[dh, dw, ci, b1, b2])

        return tvm.sum(tvm.popcount(
            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
                       axis=[dh, dw, ci, b1, b2])
Example #6
    def _conv(n, co, h, w, vh, vw, vc):
        b1b2 = (b1+b2).astype(out_dtype)
        if unipolar:
            return tvm.sum((tvm.popcount(
                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
                kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype))  -
                            tvm.popcount(
                                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype)
                                & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
                           axis=[ci, dh, dw, b1, b2])

        return tvm.sum((tvm.popcount(
            data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
            kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
                       axis=[ci, dh, dw, b1, b2])
Example #7
File: bnn.py Project: bddppq/tvm
def binary_dense(data, weight):
    """Binary matrix multiplication using xor and bit-count.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim], dtype is uint32.

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim], dtype is uint32.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim], dtype is float32.
    """
    assert data.dtype == 'uint32' and weight.dtype == 'uint32', \
        "dtype of data and weight should be uint32"
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim binary dense"
    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    k = tvm.reduce_axis((0, in_dim), name='k')
    matmul = tvm.compute((batch, out_dim), lambda i, j: \
                          tvm.sum(tvm.popcount(data[i, k] ^ weight[j, k]), axis=k), \
                          tag='binary_dense')

    return tvm.compute((batch, out_dim), lambda i, j: \
                        32 * in_dim - 2. * matmul(i, j), \
                        tag=tag.ELEMWISE)
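A minimal usage sketch for this operator, assuming the pre-1.0 tvm/topi API used throughout these examples, with topi.nn.binarize_pack producing the uint32 operands binary_dense expects:

import tvm
import topi

batch, in_dim, out_dim = 4, 128, 16
data = tvm.placeholder((batch, in_dim), dtype='float32', name='data')
weight = tvm.placeholder((out_dim, in_dim), dtype='float32', name='weight')
data_q = topi.nn.binarize_pack(data)      # -> (batch, in_dim // 32), uint32
weight_q = topi.nn.binarize_pack(weight)  # -> (out_dim, in_dim // 32), uint32
dense = topi.nn.binary_dense(data_q, weight_q)

s = tvm.create_schedule(dense.op)
f = tvm.build(s, [data, weight, dense], 'llvm')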
Example #8
    def run(dtype):
        # graph
        n = tvm.convert(1024)
        A = tvm.placeholder((n,), name='A', dtype=dtype)
        B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
        s = tvm.create_schedule(B.op)
        # simple schedule
        num_thread = 8
        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            target = tvm.target.create(device)
            if "cpu" not in target.keys:
                s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
                s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
            func = tvm.build(s, [A, B], device)
            # launch the kernel.
            n = 1024
            a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
            b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
            func(a, b)
            np.testing.assert_allclose(
                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)

        check_device("llvm")
        check_device("cuda")
        check_device("opencl")
        if dtype == "uint32":
            check_device("metal")
            check_device("vulkan")
Example #9
    def run(dtype):
        # graph
        n = tvm.convert(1024)
        A = tvm.placeholder((n,), name='A', dtype=dtype)
        B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
        s = tvm.create_schedule(B.op)
        # simple schedule
        num_thread = 8
        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            target = tvm.target.create(device)
            if "cpu" not in target.keys:
                s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
                s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
            func = tvm.build(s, [A, B], device)
            # launch the kernel.
            n = 1024
            a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
            b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
            func(a, b)
            np.testing.assert_allclose(
                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)

        check_device("llvm")
        check_device("cuda")
        check_device("opencl")
        check_device("metal")
        if dtype == "uint32":
            check_device("vulkan")
Example #10
    def _conv(n, co, h, w, vh, vw, vc):
        b1b2 = (b1 + b2).astype(out_dtype)
        if dorefa:
            return tvm.sum(
                (tvm.popcount(data_vec[n, h, w, ci, vh * HSTR + dh,
                                       vw * WSTR + dw, b1].astype(out_dtype)
                              & kernel_vec[co, ci, dh, dw, b2,
                                           vc].astype(out_dtype)) -
                 tvm.popcount(data_vec[n, h, w, ci, vh * HSTR + dh,
                                       vw * WSTR + dw, b1].astype(out_dtype)
                              & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(
                                  out_dtype)) << b1b2,
                axis=[ci, dh, dw, b1, b2])

        return tvm.sum((tvm.popcount(
            data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1]
            & kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
                       axis=[ci, dh, dw, b1, b2])
Example #11
def bgemm_topi(Y, X, K, dtype="uint64"):
    DB = 1
    WB = 1
    out_dtype = dtype
    data_packed = tvm.placeholder((Y, DB, K), dtype=dtype, name="A")
    weight_packed = tvm.placeholder((X, WB, K), dtype=dtype, name="B")

    oshape = (Y, X)
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
        << (db + wb).astype(out_dtype), axis=[wb, db, k]),
        tag='bitserial_dense')

    s = tvm.create_schedule(matmul.op)
    cfg = autotvm.get_config()

    CC = s.cache_write(matmul, "global")

    y, x = s[matmul].op.axis

    cfg.define_split("tile_y", y, num_outputs=2,
                     filter=lambda x: x.size[-1] <= 8)
    cfg.define_split("tile_x", x, num_outputs=2,
                     filter=lambda x: x.size[-1] <= 8)

    yo, yi = cfg["tile_y"].apply(s, matmul, y)
    xo, xi = cfg["tile_x"].apply(s, matmul, x)

    s[matmul].reorder(yo, xo, yi, xi)
    cfg.define_knob("compute_at_axis", [0, 1, 2])
    if cfg["compute_at_axis"].val == 0:
        s[CC].compute_at(s[matmul], xo)
    elif cfg["compute_at_axis"].val == 1:
        s[CC].compute_at(s[matmul], yi)
    elif cfg["compute_at_axis"].val == 2:
        s[CC].compute_at(s[matmul], xi)

    yc, xc = s[CC].op.axis
    wb, db, k = s[CC].op.reduce_axis

    cfg.define_reorder("reorder_0", [k, yc, xc], policy="all")
    cfg["reorder_0"].apply(s, CC, [k, yc, xc])

    cfg.add_flop(2 * Y * X * K * int(dtype[4:]))

    return s, [data_packed, weight_packed, matmul]
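Because the template above calls autotvm.get_config(), it only makes sense inside an autotvm task. A hedged sketch of driving it with the pre-1.0 autotvm API (the template function, tuner choice, trial count, and log file name here are assumptions, not part of the example above):

from tvm import autotvm

@autotvm.template
def bgemm_tuned(Y, X, K):
    return bgemm_topi(Y, X, K)

task = autotvm.task.create(bgemm_tuned, args=(64, 64, 64), target='llvm')
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=4))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=20, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('bgemm.log')])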
Example #12
        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                            cnts = tvm.popcount(ands)
                            upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m//2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts8[i*2], cnts8[i*2+1])
                        for i in range(m//4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts4[i*2], cnts4[i*2+1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu,
                                                   args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
                    else: # k_i == 8
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                            cnts8[i] = tvm.popcount(ands)
                        for i in range(m//2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts8[i*2], cnts8[i*2+1])
                        for i in range(m//4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts4[i*2], cnts4[i*2+1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu,
                                                   args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()
Example #13
def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
                    out_dtype='int16', unipolar=True):
    """The default implementation of bitserial dense in topi.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]
    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)

    oshape = (Y, X)
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
        (tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) -
         tvm.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k])).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]),
                                  tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')


    if unipolar:
        return matmul_unipolar
    return matmul
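The bit-plane layout produced by bitpack(..., pack_axis=1, bit_axis=1, ...) is what makes the compute rules above work: plane db of data_packed holds bit db of every input element, packed 32 elements per uint32 word, so one AND plus one popcount processes 32 inputs at once. A simplified NumPy model of that packing (illustration only; topi's bitpack is the real implementation, and the bit order within a word may differ):

import numpy as np

def bitpack_np(data, bits, pack_dtype=np.uint32):
    # data: (batch, in_dim) of small non-negative ints
    # returns: (batch, bits, in_dim // lanes), one plane per bit of the value
    lanes = np.dtype(pack_dtype).itemsize * 8
    batch, in_dim = data.shape
    assert in_dim % lanes == 0
    out = np.zeros((batch, bits, in_dim // lanes), dtype=pack_dtype)
    for b in range(bits):
        plane = ((data >> b) & 1).astype(pack_dtype)
        for w in range(in_dim // lanes):
            for j in range(lanes):
                out[:, b, w] |= plane[:, w * lanes + j] << pack_dtype(j)
    return out

acts = np.random.randint(0, 4, size=(2, 64))  # 2-bit activations
packed = bitpack_np(acts, bits=2)
assert packed.shape == (2, 2, 2)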
Example #14
    def check_correct_assembly(type, elements, counts):
        n = tvm.convert(elements)
        A = tvm.placeholder(n, dtype=type, name='A')
        B = tvm.compute(A.shape, lambda i: tvm.popcount(A[i]), name='B')
        s = tvm.create_schedule(B.op)
        s[B].vectorize(s[B].op.axis[0])
        f = tvm.build(s, [A, B], target)

        # Verify we see the correct number of vpaddl and vcnt instructions in the assembly
        assembly = f.get_source('asm')
        matches = re.findall("vpaddl", assembly)
        assert (len(matches) == counts)
        matches = re.findall("vcnt", assembly)
        assert (len(matches) == 1)
Example #15
def test_popcount_llvm():
    # graph
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A', dtype="uint32")
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)

    if not tvm.module.enabled("llvm"):
        return
    f = tvm.build(s, [A, B], "llvm")
    ctx = tvm.cpu(0)
    # launch the kernel.
    n = 1024
    a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
    b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
    f(a, b)
    np.testing.assert_allclose(
        b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
Example #16
def _instr(index):
    irb = tvm.ir_builder.create()
    if index == 1: # reduce reset
        irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
        return irb.get()
    # body and reduce update
    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
                    if unipolar:
                        cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts = tvm.popcount(w_ & x_)
                    upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
                out = tvm.call_llvm_intrin(return_dtype, vpadalu,
                                           args_2, zz.vload(0, return_dtype), shifted_cnts)
            else: # k_i == 8
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
                    if unipolar:
                        cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts8[i] = tvm.popcount(w_ & x_)
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
                out = tvm.call_llvm_intrin(return_dtype, vpadalu,
                                           args_2, zz.vload(0, return_dtype), shifted_cnts)
            irb.emit(zz.vstore(0, out))
    return irb.get()
Example #17
def _intrin_popcount(m, k_i, w_b, x_b):
    dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
    x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')
    z = tvm.compute((m,), lambda i:
                    tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') &
                                         x[bx, k].astype('uint16'))
                            << (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z')

    Wb = tvm.decl_buffer(w.shape, w.dtype,
                         name="W",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
    Xb = tvm.decl_buffer(x.shape, x.dtype,
                         name="X",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        vpadd = "llvm.arm.neon.vpadd.v8u8"
        vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                            cnts = tvm.popcount(ands)
                            upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m//2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts8[i*2], cnts8[i*2+1])
                        for i in range(m//4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts4[i*2], cnts4[i*2+1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu,
                                                   args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
                    else: # k_i == 8
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                            cnts8[i] = tvm.popcount(ands)
                        for i in range(m//2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts8[i*2], cnts8[i*2+1])
                        for i in range(m//4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                            args_1, cnts4[i*2], cnts4[i*2+1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu,
                                                   args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()
        # body, reset, update
        return _instr(0), _instr(1), _instr(2)
    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb})
Example #18
def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
    pack_dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
    x = tvm.placeholder((x_b, k_i), dtype=pack_dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')
    if unipolar:
        dtype = 'int16'
        z = tvm.compute((m,), lambda i: tvm.sum(
            (tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
             tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
            << (bw + bx).astype(dtype), axis=[bw, bx, k]), name='z')
    else:
        dtype = 'uint16'
        z = tvm.compute((m,), lambda i: tvm.sum(
            tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
            << (bw + bx).astype(dtype), axis=[bw, bx, k]), name='z')
    Wb = tvm.decl_buffer(w.shape,
                         w.dtype,
                         name="W",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'),
                                  tvm.var('ldw'), 1])  # stride can be inferred
    Xb = tvm.decl_buffer(x.shape,
                         x.dtype,
                         name="X",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])
    Zb = tvm.decl_buffer(z.shape,
                         z.dtype,
                         name="Z",
                         offset_factor=1,
                         strides=[1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]

        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        if unipolar:
            vpadd = "llvm.arm.neon.vpadd.v8i8"
            vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
            full_dtype = 'int8x16'
            half_dtype = 'int8x8'
            return_dtype = 'int16x8'
        else:
            vpadd = "llvm.arm.neon.vpadd.v8u8"
            vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
            full_dtype = 'uint8x16'
            half_dtype = 'uint8x8'
            return_dtype = 'uint16x8'

        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:  # reduce reset
                irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
                return irb.get()
            # body and reduce update
            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            w_ = ww.vload([bw, i, 0],
                                          'uint8x16').astype(full_dtype)
                            x_ = xx.vload([bx, 0],
                                          'uint8x16').astype(full_dtype)
                            if unipolar:
                                cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                            else:
                                cnts = tvm.popcount(w_ & x_)
                            upper_half = tvm.call_pure_intrin(
                                half_dtype, 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin(
                                half_dtype, 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(
                                half_dtype, vpadd, args_1, cnts8[i * 2],
                                cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(
                                half_dtype, vpadd, args_1, cnts4[i * 2],
                                cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin(full_dtype,
                                                    'vectorcombine', cnts2[0],
                                                    cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                        out = tvm.call_llvm_intrin(return_dtype, vpadalu,
                                                   args_2,
                                                   zz.vload(0, return_dtype),
                                                   shifted_cnts)
                    else:  # k_i == 8
                        for i in range(m):
                            w_ = ww.vload([bw, i, 0],
                                          'uint8x8').astype(half_dtype)
                            x_ = xx.vload([bx, 0],
                                          'uint8x8').astype(half_dtype)
                            if unipolar:
                                cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                            else:
                                cnts8[i] = tvm.popcount(w_ & x_)
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(
                                half_dtype, vpadd, args_1, cnts8[i * 2],
                                cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(
                                half_dtype, vpadd, args_1, cnts4[i * 2],
                                cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin(full_dtype,
                                                    'vectorcombine', cnts2[0],
                                                    cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                        out = tvm.call_llvm_intrin(return_dtype, vpadalu,
                                                   args_2,
                                                   zz.vload(0, return_dtype),
                                                   shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op,
                                      _intrin_func,
                                      binds={
                                          w: Wb,
                                          x: Xb,
                                          z: Zb
                                      })
Example #19
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits,
                            pack_dtype, out_dtype, unipolar):
    """The default implementation of bitserial dense in topi.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data,
                          data_bits,
                          pack_axis=1,
                          bit_axis=1,
                          pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight,
                                weight_bits,
                                pack_axis=1,
                                bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight

    batch, DB, in_dim = get_const_tuple(data_packed.shape)
    out_dim, WB, in_dim = get_const_tuple(weight_packed.shape)

    # Pad the weight so that the microkernel can be used:
    # out_dim needs to be a multiple of 8
    if out_dim % 8 != 0:
        out_dim_pad = 8 - out_dim % 8
        weight_packed = pad(weight_packed, [0, 0, 0], [out_dim_pad, 0, 0],
                            name='PaddedWeight')
        out_dim += out_dim_pad

    ######## Search space

    x, y = cfg.axis(batch), cfg.axis(out_dim)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(
        in_dim)

    ko, ki = cfg.define_split(
        'tile_k',
        k,
        num_outputs=2,
        filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
    yo, yi = cfg.define_split('tile_y',
                              y,
                              num_outputs=2,
                              filter=lambda xx: xx.size[-1] == 8)

    cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
                       policy='candidate',
                       candidate=[[yo, xo, ko, xi, wb, db, yi, ki],
                                  [yo, xo, xi, ko, wb, db, yi, ki],
                                  [yo, xo, ko, xi, wb, db, yi, ki]])

    ###### Compute rule
    VY = cfg['tile_y'].size[-1]
    VK = cfg['tile_k'].size[-1]

    wvshape = (out_dim // VY, in_dim // VK, WB, VY, VK)
    oshape = (batch, out_dim)

    k = tvm.reduce_axis((0, in_dim), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape,
                             lambda yo, ko, wb, vy, vk: weight_packed[
                                 yo * VY + vy][wb][ko * VK + vk],
                             name='weight_vec')
    matmul_unipolar = tvm.compute(oshape, lambda x, y: tvm.sum(
        (tvm.popcount(weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype) &
                      data_packed[x, db, k].astype(out_dtype)) -
         tvm.popcount(~weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype) &
                      data_packed[x, db, k].astype(out_dtype)))
        << (wb + db).astype(out_dtype), axis=[wb, db, k]),
        tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda x, y: tvm.sum(
        tvm.popcount(weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype) &
                     data_packed[x, db, k].astype(out_dtype))
        << (wb + db).astype(out_dtype), axis=[wb, db, k]),
        tag='bitserial_dense')

    cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))

    if unipolar:
        return matmul_unipolar
    return matmul
Example #20
def _conv(nn, yy, xx, ff):
    b1b2 = (b1+b2).astype(out_dtype)
    return tvm.sum((tvm.popcount(
        PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
        Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
                   axis=[rc, ry, rx, b2, b1])
Example #21
def _conv(nn, yy, xx, ff):
    b1b2 = (b1 + b2).astype(out_dtype)
    return tvm.sum((tvm.popcount(
        PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1]
        & Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
                   axis=[rc, ry, rx, b2, b1])
Example #22
def _conv(n, h, w, co, vh, vw, vc):
    return tvm.sum((tvm.popcount(
        kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
        data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
                    << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
Example #23
def bitserial_dense(cfg,
                    data,
                    weight,
                    data_bits,
                    weight_bits,
                    pack_dtype='uint32',
                    out_dtype='int16',
                    unipolar=True):
    """Bitserial dense implementation. TODO: Why are these separate

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]
    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data,
                          data_bits,
                          pack_axis=1,
                          bit_axis=1,
                          pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight,
                                weight_bits,
                                pack_axis=1,
                                bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)
    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate',
                       candidate=[[yo, xo, ko, yi, wb, db, ki, xi],
                                  [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X // VX, WB, VX, K)
    oshape = (Y, X)

    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(
        wvshape,
        lambda xo, wb, vx, k: weight_packed[xo * VX + vx][wb][k],
        name='weight_vec')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                      data_packed[i, db, k]) -
         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                      data_packed[i, db, k])).astype(out_dtype)
        << (db + wb).astype(out_dtype), axis=[wb, db, k]),
        tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                     data_packed[i, db, k]).astype(out_dtype)
        << (db + wb).astype(out_dtype), axis=[wb, db, k]),
        tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))

    if unipolar:
        return matmul_unipolar
    return matmul
Example #24
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                            out_dtype='int16', unipolar=True):
    """Bitserial dense implementation. TODO: Why are these separate

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]
    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)
    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X//VX, WB, VX, K)
    oshape = (Y, X)

    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
                             weight_packed[xo*VX+vx][wb][k], name='weight_vec')

    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
        (tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
         tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))

    if unipolar:
        return matmul_unipolar
    return matmul