def test_bound_nest_thread():
    """Check the inferred extent of A3's root axis with nested thread bindings.

    A1 -> A2 -> A3 is a chain of element-wise stages over a symbolic length
    ``m``. A3 is split by 32 and bound to blockIdx.x / threadIdx.x; A2 and A1
    are both attached at A3's thread axis, and A2's inner split axis is bound
    to the same thread axis as well.
    """
    length = tvm.var('m')
    src = tvm.placeholder((length), name='A')
    stage1 = tvm.compute((length,), lambda i: src[i], name='A1')
    stage2 = tvm.compute((length,), lambda i: stage1[i] + 2, name='A2')
    stage3 = tvm.compute((length,), lambda i: stage2[i] + 3, name='A3')

    blk = tvm.thread_axis("blockIdx.x")
    thr = tvm.thread_axis("threadIdx.x")

    sched = tvm.create_schedule(stage3.op)

    # Output stage: outer part -> block index, inner part (32 wide) -> thread index.
    outer, inner = sched[stage3].split(stage3.op.axis[0], factor=32)
    sched[stage3].bind(outer, blk)
    sched[stage3].bind(inner, thr)

    # A2 is computed under A3's thread loop and shares the same thread axis.
    sched[stage2].compute_at(sched[stage3], inner)
    _, mid_inner = sched[stage2].split(stage2.op.axis[0], nparts=1)
    sched[stage2].bind(mid_inner, thr)

    sched[stage1].compute_at(sched[stage3], inner)

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    root_axis = stage3.op.axis[0]
    assert inferred[root_axis].extent == length
# Example #2
def _schedule_injective(op, sch):
    """Fuse all axes of an injective op's output and bind them to GPU threads.

    Parameters
    ----------
    op : Operation
        The operation whose first output is scheduled.

    sch : Schedule
        The schedule to update in place.

    Returns
    -------
    sch : Schedule
        The same schedule object, after the bindings were applied.
    """
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
    max_block = 256

    try:
        # Only a statically known element count can be compared with the
        # launch limit; a symbolic shape raises ValueError here.
        const_size = util.get_const_int(util.prod(x.shape))
        need_block_split = const_size > max_block * num_thread
    except ValueError:
        need_block_split = False

    if need_block_split:
        # More elements than max_block * num_thread: keep an extra outer
        # loop (xo) and move it innermost.
        xo, xi = sch[x].split(fused, factor=num_thread * max_block)
        bx, tx = sch[x].split(xi, factor=num_thread)
        sch[x].reorder(bx, tx, xo)
        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    else:
        bx, tx = sch[x].split(fused, factor=num_thread)
        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))

    return sch
def run_opencl():
    """Build the "add one" kernel for OpenCL and run it on a remote device over RPC.

    Relies on the module-level tensors ``A`` and ``B`` and on the temporary
    directory object ``temp`` defined elsewhere in this file.
    """
    # NOTE: This is the setting for my rk3399 board. You need to modify
    # them according to your environment.
    target_host = "llvm -target=aarch64-linux-gnu"
    # NOTE(review): the host address is empty in this copy; fill in the
    # address of the RPC server before running.
    opencl_device_host = ''
    opencl_device_port = 9090

    # create schedule for the above "add one" compute declaration
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    func = tvm.build(s, [A, B], "opencl", target_host=target_host)

    remote = rpc.connect(opencl_device_host, opencl_device_port)

    # export and upload: the library has to exist on the remote side
    # before it can be loaded by name.
    path = temp.relpath('lib_cl.tar')
    func.export_library(path)
    remote.upload(path)
    func = remote.load_module('lib_cl.tar')

    # run
    ctx = remote.cl()
    a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
    func(a, b)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
    print("OpenCP test passed!")
# Example #4
    def extern(ins, outs):
        # pylint: disable=unused-argument
        """construct measurement function by building IR directly

        Loads one element of ``outs[0]`` per thread into two local scalars,
        applies a chain of multiply-add operations to them, and stores the
        result back. ``n``, ``max_threads``, ``dtype``, ``base_type``,
        ``item_per_thread`` and ``lanes`` come from the enclosing scope.
        """
        ib = tvm.ir_builder.create()

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")

        ib.scope_attr(bx, "thread_extent", n // max_threads)
        ib.scope_attr(tx, "thread_extent", max_threads)

        idx = bx.var * max_threads + tx.var

        a = ib.allocate(dtype, (1), name='a', scope='local')
        b = ib.allocate(dtype, (1), name='b', scope='local')

        a[0] = outs[0].vload(idx, dtype)
        b[0] = outs[0].vload(idx, dtype)

        # Pick the multiply-add form per element type. Without the `else`
        # the float variant was overwritten immediately and non-float types
        # left `mad_func` undefined.
        if base_type.find('float') != -1:
            mad_func = lambda x, y: (x * x + y)
        else:
            mad_func = lambda x, y: y * y + x

        # NOTE(review): the trip count divides by 4 but only two mad calls
        # are visible per iteration -- confirm the intended unroll factor.
        for _ in range(item_per_thread // 4 // lanes):
            a[0] = mad_func(a[0], b[0])
            b[0] = mad_func(b[0], a[0])

        ib.emit(outs[0].vstore(idx, b[0]))
        return ib.get()
# Example #5
    def get_gemm_feature(target):
        """Return flattened itervar features of a randomly reordered NxN GEMM.

        Parameters
        ----------
        target : Target
            Target used as the context for feature extraction; when its keys
            contain "gpu" the first three non-reduction axes are bound to
            blockIdx.x, vthread and threadIdx.y.

        Returns
        -------
        feas
            The result of ``feature.flatten_itervar_feature``.
        """
        k = tvm.reduce_axis((0, N), 'k')
        A = tvm.placeholder((N, N), name='A')
        B = tvm.placeholder((N, N), name='B')
        C = tvm.compute(A.shape, lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
                        name='C')

        s = tvm.create_schedule(C.op)

        y, x = s[C].op.axis
        # Four tiled spatial axes followed by the reduction axis (index 4).
        axes = list(s[C].tile(y, x, 8, 8)) + [k]
        perm = np.random.permutation(5)
        axes = [axes[x] for x in perm]

        # Apply the random loop order to the stage.
        s[C].reorder(*axes)

        if "gpu" in target.keys:
            # filter out reduction axis
            pick = [axis for axis, src_idx in zip(axes, perm) if src_idx != 4]
            s[C].bind(pick[0], tvm.thread_axis("blockIdx.x"))
            s[C].bind(pick[1], tvm.thread_axis("vthread"))
            s[C].bind(pick[2], tvm.thread_axis("threadIdx.y"))

        with target:
            feas = feature.get_itervar_feature(s, [A, B, C])
            feas = feature.flatten_itervar_feature(feas)
        return feas
def test_shared_memory():
    """Verify that the GPU-code verifier enforces the shared-memory limit.

    The schedule stages ``M`` float32 values in shared memory per block, so a
    limit of ``4 * M - 1`` bytes must be rejected and ``4 * M`` accepted.
    """
    N = 1024
    M = 128

    A = tvm.placeholder((N,), name='A', dtype='float32')
    B = tvm.compute((N, ), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    AA = s.cache_read(A, "shared", [B])
    o, i = s[B].split(s[B].op.axis[0], M)
    s[AA].compute_at(s[B], o)
    s[B].bind(o, tvm.thread_axis("blockIdx.x"))
    s[B].bind(i, tvm.thread_axis("threadIdx.x"))

    # shared memory usage: M * 4B
    # thread usage: M

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue
        valid = [None]

        # One byte below the actual usage: verification must fail.
        with tvm.build_config(**{"add_lower_pass": [
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=4 * M - 1,
                                max_threads_per_block=M))]}):
            tvm.build(s, [A, B], target)
        assert not valid[0]

        # Exactly the actual usage: verification must pass.
        with tvm.build_config(**{"add_lower_pass": [
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=4 * M,
                                max_threads_per_block=M))]}):
            tvm.build(s, [A, B], target)
        assert valid[0]
def test_storage_share_gpu():
    """Count storage allocations per scope after StorageRewrite on a GPU pipeline.

    Builds ``num_stage`` pairs of element-wise stages (``A<t>_s`` followed by
    ``A<t>``), binds every ``A<t>`` to blockIdx.x / threadIdx.x, attaches the
    matching ``A<t>_s`` under its thread axis, lowers the schedule, and then
    counts the ``storage_scope`` attribute statements in the result.
    """
    m = tvm.var('m')
    A = [tvm.placeholder((m), name='A')]
    num_stage = 5
    for t in range(num_stage):
        # A[2*t+1] is the "_s" stage, A[2*t+2] the stage that consumes it.
        A.append(tvm.compute((m,), lambda i: A[-1][i] + (t+1), name='A%d_s' % t))
        A.append(tvm.compute((m,), lambda i: A[-1][i], name='A%d' % t))
    s = tvm.create_schedule(A[-1].op)
    for t in range(num_stage):
        x = A[2*t+2].op.axis[0]
        bx, tx = s[A[2*t+2]].split(x, factor=32)
        s[A[2*t+2]].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[A[2*t+2]].bind(tx, tvm.thread_axis("threadIdx.x"))
        s[A[2*t+1]].compute_at(s[A[2*t+2]], tx)
        # NOTE(review): no stage is explicitly given "shared" scope here, yet
        # the final assert expects num_stage shared allocations -- confirm
        # whether a set_scope("shared") call is missing from this copy.

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
    Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.StorageRewrite(stmt)
    # Number of "storage_scope" attribute statements seen per scope name.
    # Any scope other than these two raises KeyError in verify().
    alloc_stats = {"global": 0, "shared": 0}

    def verify(n):
        # Visitor callback: tally every storage_scope attribute by its scope.
        if isinstance(n, tvm.stmt.AttrStmt):
            if n.attr_key == "storage_scope":
                alloc_stats[n.value.value] += 1
    tvm.ir_pass.PostOrderVisit(stmt, verify)
    assert alloc_stats["global"] == 2
    assert alloc_stats["shared"] == num_stage
# Example #8
 def check_cuda(dtype, n, lanes):
     """Build and run a ``__dp4a`` dot-product kernel on CUDA and check the result.

     Parameters
     ----------
     dtype : str
         Element type of the vector inputs (e.g. "int8").
     n : int
         Number of vector elements.
     lanes : int
         Number of lanes per vector element.

     ``num_thread`` and ``have_int8`` come from the enclosing scope.
     """
     # Skip (instead of falling through and failing) when the device or
     # the required int8 support is unavailable.
     if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
         print("skip because cuda is not enabled..")
         return
     if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
         print("skip because gpu does not support int8")
         return
     A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
     B = tvm.placeholder((n,), name='B', dtype="%sx%d" % (dtype, lanes))
     C = tvm.placeholder((n,), name='C', dtype="int32")
     D = tvm.compute((n,),
                     lambda i: tvm.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name='D')
     s = tvm.create_schedule(D.op)
     xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
     s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
     s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
     fun = tvm.build(s, [A, B, C, D], "cuda")
     np_a = np.random.randint(low=-128, high=127, size=(n,lanes))
     np_b = np.random.randint(low=-128, high=127, size=(n,lanes))
     np_c = np.random.randint(low=0, high=127, size=(n,))
     # Reference result: per-element dot product of the lanes plus the accumulator.
     np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
     ctx = tvm.gpu(0)
     a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
     b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np_b)
     c = tvm.nd.empty((n,), C.dtype, ctx).copyfrom(np_c)
     d = tvm.nd.empty((n,), D.dtype, ctx)
     fun(a, b, c, d)
     tvm.testing.assert_allclose(d.asnumpy(), np_d)
# Example #9
def test_exp():
    """Build an element-wise ``exp`` kernel and compare it with ``np.exp``."""
    # graph
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 8
    bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    # one line to build the function.
    def check_device(device, host="stackvm"):
        """Build for ``device``/``host`` and verify; return silently if unavailable."""
        if not tvm.module.enabled(host):
            return
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        fexp = tvm.build(s, [A, B],
                         device, host)
        # launch the kernel.
        n = 1024
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fexp(a, b)
        tvm.testing.assert_allclose(
            b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)

    check_device("cuda", "llvm")
# Example #10
def test_multiple_kernels():
    """Verify the threads-per-block limit on a schedule with two bound stages.

    Both B and C bind their second axis (extent N) to threadIdx.x, so a limit
    of ``N - 1`` threads must be rejected and ``N`` accepted.
    """
    N = 1024

    A = tvm.placeholder((N, N), name='A')
    B = tvm.compute((N, N), lambda i, j: A[i, j])
    C = tvm.compute((N, N), lambda i, j: B[i, j])

    s = tvm.create_schedule([C.op])

    s[C].bind(s[C].op.axis[1], tvm.thread_axis("threadIdx.x"))
    s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))

    # shared memory usage: 0
    # thread usage: N

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue

        valid = [None]
        # One thread below the actual usage: verification must fail.
        with tvm.build_config(**{"add_lower_pass": [
            (2, get_verify_pass(valid,
                                max_threads_per_block=N - 1))]}):
            tvm.build(s, [A, C], target)
        assert not valid[0]

        # Exactly the actual usage: verification must pass.
        with tvm.build_config(**{"add_lower_pass": [
            (2, get_verify_pass(valid,
                                max_threads_per_block=N))]}):
            tvm.build(s, [A, C], target)
        assert valid[0]
# Example #11
    def traverse(op):
        """inline all one-to-one-mapping operators except the last stage (output)

        Walks the producers of ``op`` recursively. ``s``, ``tag`` and
        ``scheduled_ops`` come from the enclosing scope.

        NOTE(review): the inline/else split and the recursion tail were
        reconstructed from the docstring and the otherwise empty statements
        in this copy -- confirm against the upstream source.
        """
        if "nms" in op.tag:
            # Bind the fused axes of the score tensor feeding the sort.
            sort = op.input_tensors[1]
            score = s[sort].op.input_tensors[0]
            fused = s[score].fuse(*s[score].op.axis)
            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
            bx, tx = s[score].split(fused, factor=num_thread)
            s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                # Intermediate broadcast op: fold it into its consumer.
                s[op].compute_inline()
            else:
                # Output stage: fuse everything and bind to block/thread.
                x = op.output(0)
                fused = s[x].fuse(*s[x].op.axis)
                num_thread = tvm.target.current_target(allow_none=False).max_num_threads
                bx, tx = s[x].split(fused, factor=num_thread)
                s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
                s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
            for tensor in op.input_tensors:
                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        scheduled_ops.append(op)

# Example #12
def try_warp_memory():
    """skip this in default test because it require higher arch

    Schedules ``B = A + 3`` with the input cached in "warp" scope and defines
    a device check that builds the kernel and compares it with NumPy.
    """
    m = 128
    A = tvm.placeholder((m,), name='A')
    B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
    warp_size = 32
    s = tvm.create_schedule(B.op)
    AA = s.cache_read(A, "warp", [B])
    xo, xi = s[B].split(B.op.axis[0], warp_size * 2)
    xi0, xi1 = s[B].split(xi, factor=warp_size)
    tx = tvm.thread_axis("threadIdx.x")
    s[B].bind(xi1, tx)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[AA].compute_at(s[B], xo)
    xo, xi = s[AA].split(s[AA].op.axis[0], warp_size)
    s[AA].bind(xi, tx)

    # NOTE(review): this callback is not registered with TVM anywhere in
    # the visible code -- confirm whether a registration decorator is missing.
    def tvm_callback_cuda_compile(code):
        ptx =  nvcc.compile_cuda(code, target="ptx")
        return ptx

    # one line to build the function.
    # NOTE(review): check_device is defined but never called in the visible
    # code -- confirm whether a trailing call was dropped.
    def check_device(device):
        """Build for ``device`` and verify against NumPy; skip if unavailable."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        f = tvm.build(s, [A, B], device)
        a = tvm.nd.array((np.random.uniform(size=m) * 256).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        f(a, b)
        tvm.testing.assert_allclose(
            b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
# Example #13
def test_rfactor_argmax():
    """Schedule a two-output argmax reduction with rfactor and thread bindings.

    The reducer carries (index, value) pairs. The reduction axis is split by
    ``nthread`` and rfactor'ed on the inner part, and the remaining reduction
    is bound to threadIdx.x.
    """
    def fcombine(x, y):
        # Keep the (index, value) pair with the larger value; ties keep x.
        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        # Identity element: index -1 and the smallest value of the value dtype.
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine,
                              fidentity,
                              name='argmax')

    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))
    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # schedule
    s = tvm.create_schedule(B0.op)
    nthread = 16
    ko, kf = s[B0].split(k, factor=nthread)
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)

    # NOTE(review): check_target is defined but never called in the visible
    # code -- confirm whether trailing calls were dropped.
    def check_target(device):
        """Build for ``device`` and compare indices with ``np.argmax``."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
        fargmax = tvm.build(fapi,
                            target=device)

        # Each row of np_idx is 0..nn-1, so the reduced index is the argmax.
        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx  = tvm.nd.array(np_idx, ctx)
        nd_val  = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())

# Example #14
def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
    """Non-maximum supression.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    nms_threshold : float
        Non-maximum suppression threshold.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap of two boxes.

        ``box_a_idx`` / ``box_b_idx`` are flat offsets of each box's first
        coordinate in ``out_tensor``. Returns intersection area over union area.
        """
        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
        return i / u

    batch, num_bbox = get_const_tuple(out_buf.shape)
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(sorted_bbox_buf)
    p_out = ib.buffer_ptr(out_buf)
    nthread_tx = max_threads
    nthread_bx = num_bbox // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Global thread id: one thread per candidate box.
    i = bx * max_threads + tx
    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
        base_idx = b * num_bbox
        with ib.if_scope(i < num_bbox):
            p_out[base_idx + i] = False
        with ib.for_range(0, num_bbox - 1) as l:
            # `== False` builds an IR comparison; it must not be rewritten
            # as a Python `not`.
            with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
                iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
                with ib.if_scope(iou > nms_threshold):
                    p_out[base_idx + i] = True
        ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                              tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
# Example #15
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
    """Fuse the given axes of ``tensor`` and bind the result to GPU block/thread.

    Parameters
    ----------
    s : Schedule
        Schedule that owns the stage of ``tensor``.
    tensor : Tensor
        Tensor whose stage is scheduled.
    axis : optional
        Axes to fuse; when falsy, all axes of the tensor's op are used.
    num_thread : optional
        Passed as the second positional argument of ``split``.

    Returns
    -------
    (bx, tx)
        The axes bound to blockIdx.x and threadIdx.x.
    """
    if not axis:
        axis = s[tensor].op.axis
    merged = s[tensor].fuse(*axis)
    block_axis, thread_axis = s[tensor].split(merged, num_thread)
    for itervar, tag_name in ((block_axis, "blockIdx.x"),
                              (thread_axis, "threadIdx.x")):
        s[tensor].bind(itervar, tvm.thread_axis(tag_name))
    return block_axis, thread_axis
# Example #16
def _schedule_output(op, sch):
    """Fuse all axes of ``op``'s first output and bind them to block/thread.

    Parameters
    ----------
    op : Operation
        The operation whose first output is scheduled.
    sch : Schedule
        The schedule to update in place.

    Returns
    -------
    sch : Schedule
        The same schedule object.
    """
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
    # Split width comes from the active target; raises if no target is set.
    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
    bx, tx = sch[x].split(fused, factor=num_thread)
    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch
# Example #17
def test_device_module_dump():
    """Build an "add one" kernel per device, export it, reload it and run it."""
    # graph
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 8
    bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    def check_device(device):
        """Export to a shared library, reload it, and verify the result."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device
        if sys.platform == "darwin" or sys.platform.startswith('linux'):
            f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
        elif sys.platform == "win32":
            f = tvm.build(s, [A, B], device, "llvm", name=name)
        else:
            raise ValueError("Unsupported platform")

        # The module must be written to disk before it can be loaded back.
        path_dso = temp.relpath("dev_lib.so")
        f.export_library(path_dso)

        f1 = tvm.module.load(path_dso)
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        f1(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
        if sys.platform != "win32":
            # Only the non-Windows builds above use "-system-lib".
            f2 = tvm.module.system_lib()
            f2[name](a, b)
            np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    def check_stackvm(device):
        """Build with the stackvm host and run the module directly."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        temp = util.tempdir()
        name = "myadd_%s" % device
        f = tvm.build(s, [A, B], device, "stackvm", name=name)
        path_dso = temp.relpath("dev_lib.stackvm")
        #f1 = tvm.module.load(path_dso)
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        f(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    for device in ["cuda", "vulkan", "opencl", "metal"]:
        check_device(device)
        check_stackvm(device)
# Example #18
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
    """ fuse all the axis and bind to GPU threads

    Parameters
    ----------
    s : Schedule
        Schedule that owns the stage of ``tensor``.
    tensor : Tensor
        Tensor whose stage is scheduled.
    axis : optional
        Axes to fuse; when falsy, all axes of the tensor's op are used.
    num_thread : optional
        Split factor; when falsy, the active target's ``max_num_threads``
        is used instead.

    Returns
    -------
    (bx, tx)
        The axes bound to blockIdx.x and threadIdx.x.
    """
    axis = axis or s[tensor].op.axis
    fused = s[tensor].fuse(*axis)
    max_threads = tvm.target.current_target(allow_none=False).max_num_threads
    bx, tx = s[tensor].split(fused, num_thread or max_threads)
    s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
    return bx, tx
# Example #19
# File: (name not given) -- Project: bddppq/tvm
def get_valid_counts_upsweep(data, idx_in, idx, partial):
    """Low level IR of first step of scan: unsweep.

    Each thread computes a running sum over its own chunk of
    ``elem_per_thread`` entries of ``idx_in`` (written to ``idx``) and the
    chunk total (written to ``partial``).

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    idx_in : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    partial : Buffer
        2D Buffer of valid data indices with shape [batch_size, new_range].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    idx_in = ib.buffer_ptr(idx_in)
    idx = ib.buffer_ptr(idx)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Number of per-thread chunks per batch entry.
    new_range = num_anchors // elem_per_thread + 1
    # Scan: Upsweep:
    with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
        with ib.for_range(0, elem_per_thread) as i:
            with ib.if_scope(bx * num_anchors + \
                             tx * elem_per_thread + i < batch_size * num_anchors):
                with ib.if_scope(i == 0):
                    # First element of the chunk seeds both outputs.
                    partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
                    idx[bx * num_anchors + tx * elem_per_thread] = \
                    idx_in[bx * num_anchors + tx * elem_per_thread]
                with ib.else_scope():
                    partial[bx * new_range + tx] += \
                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
                    idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
# Example #20
def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
    """Tile the (y, x) axes of ``tensor`` and bind the tiles to GPU block/thread.

    Parameters
    ----------
    s : Schedule
        Schedule that owns the stage of ``tensor``.
    tensor : Tensor
        Tensor whose stage is scheduled.
    y, x : IterVar
        The two axes to tile.
    y_factor : int
        Tile size along ``y``.
    x_factor : int, optional
        Tile size along ``x``; falls back to ``y_factor`` when falsy.

    Returns
    -------
    (yo, xo, yi, xi)
        The four axes produced by ``tile``.
    """
    if not x_factor:
        x_factor = y_factor
    y_outer, x_outer, y_inner, x_inner = s[tensor].tile(y, x, y_factor, x_factor)
    # Same binding order as before: x outer/inner first, then y outer/inner.
    bindings = (
        (x_outer, "blockIdx.x"),
        (x_inner, "threadIdx.x"),
        (y_outer, "blockIdx.y"),
        (y_inner, "threadIdx.y"),
    )
    for itervar, tag_name in bindings:
        s[tensor].bind(itervar, tvm.thread_axis(tag_name))
    return y_outer, x_outer, y_inner, x_inner
def test_gemm_bound():
    """Check inferred bounds of the local accumulator in a tiled GPU GEMM.

    Each block covers ``block_factor`` (= scale * num_thread) rows and columns
    and each thread ``scale`` of them, so both axes of the cache-write stage
    are expected to have extent 8.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n, n), name='A')
    B = tvm.placeholder((n, n), name='B')
    k = tvm.reduce_axis((0, n), name='k')
    C = tvm.compute(
        (n, n),
        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
        name='CC')
    # schedule
    s = tvm.create_schedule(C.op)
    xtile, ytile = 32, 32
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis("threadIdx.y")

    CC = s.cache_write(C, "local")
    AA = s.cache_read(A, "shared", [CC])
    BB = s.cache_read(B, "shared", [CC])
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].reorder(by, bx, yi, xi)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    ty, yi = s[C].split(yi, nparts=num_thread)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].reorder(ty, tx, yi, xi)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)

    # Accumulator lives under the thread loop; shared tiles are loaded per k.
    s[CC].compute_at(s[C], tx)
    s[AA].compute_at(s[CC], k)
    s[BB].compute_at(s[CC], k)

    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)

    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert(bounds[CC.op.axis[0]].extent.value == 8)
    assert(bounds[CC.op.axis[1]].extent.value == 8)
# Example #22
    def _schedule(op):
        """Schedule a 3-D (batch, M, N) output op with shared/local caching.

        Reads the two input tensors of the op's output, stages each through a
        "shared" and then a "local" cache, accumulates into a "local" write
        cache, and binds the output axes to blockIdx.{z,y,x} and
        threadIdx.{y,x}. ``s`` comes from the enclosing scope.
        """
        C = op.output(0)
        A, B = s[C].op.input_tensors
        _, M, N = get_const_tuple(C.shape)
        # Cache chain per input: global -> shared -> local.
        AA = s.cache_read(A, "shared", [C])
        AL = s.cache_read(AA, "local", [C])
        BB = s.cache_read(B, "shared", [C])
        BL = s.cache_read(BB, "local", [C])
        CC = s.cache_write(C, "local")

        b, y, x = s[C].op.axis
        # Block tile sizes along y and x.
        # NOTE(review): presumably the largest power-of-two factor of M / N
        # capped at 64 -- confirm against get_max_power2_factor.
        y_bn = get_max_power2_factor(M, 64)
        x_bn = get_max_power2_factor(N, 64)
        by, y = s[C].split(y, y_bn)
        bx, x = s[C].split(x, x_bn)
        # At most 8 threads per dimension.
        y_nthreads = min(y_bn, 8)
        x_nthreads = min(x_bn, 8)
        ty, yi = s[C].split(y, nparts=y_nthreads)
        tx, xi = s[C].split(x, nparts=x_nthreads)
        thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x")
        thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y")

        s[C].reorder(b, by, bx, ty, tx, yi, xi)
        s[C].bind(b, tvm.thread_axis("blockIdx.z"))
        s[C].bind(by, tvm.thread_axis("blockIdx.y"))
        s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[C].bind(ty, thread_y)
        s[C].bind(tx, thread_x)
        s[C].pragma(yi, "auto_unroll_max_step", 16)

        # Accumulator is computed under the thread loop; its reduction axis
        # is split by 8 and moved outermost.
        s[CC].compute_at(s[C], tx)
        _, yi, xi = s[CC].op.axis
        k, = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, 8)
        s[CC].reorder(ko, ki, yi, xi)
        s[CC].pragma(ki, "auto_unroll_max_step", 16)

        # Shared tiles are attached at the outer reduction loop, local
        # copies at the inner one.
        s[AA].compute_at(s[CC], ko)
        s[AL].compute_at(s[CC], ki)
        s[BB].compute_at(s[CC], ko)
        s[BL].compute_at(s[CC], ki)
        # Loading of the shared A tile is bound to the same thread axes.
        _, y, k = s[AA].op.axis
        ty, yi = s[AA].split(y, nparts=y_nthreads)
        tx, ki = s[AA].split(k, nparts=x_nthreads)
        s[AA].reorder(ty, tx, yi, ki)
        s[AA].bind(ty, thread_y)
        s[AA].bind(tx, thread_x)
        s[AA].pragma(yi, "auto_unroll_max_step", 16)

        # Same for the shared B tile.
        _, x, k = s[BB].op.axis
        ty, xi = s[BB].split(x, nparts=y_nthreads)
        tx, ki = s[BB].split(k, nparts=x_nthreads)
        s[BB].bind(ty, thread_y)
        s[BB].bind(tx, thread_x)
        s[BB].reorder(ty, tx, xi, ki)
        s[BB].pragma(xi, "auto_unroll_max_step", 16)
# Example #23
# File: (name not given) -- Project: bddppq/tvm
def get_valid_counts_scan(data, partial_in, partial):
    """Low level IR to do scan.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
        Only its first two extents are read here.

    partial_in : Buffer
        Input buffer, read at ``[bx * new_range + tx]``.
        NOTE(review): presumably the per-chunk sums produced by the upsweep
        step, shape [batch_size, new_range] -- confirm against the caller.

    partial : Buffer
        2D Buffer of valid data indices with shape [batch_size, new_range].
        Receives the scanned values.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    partial_in = ib.buffer_ptr(partial_in)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Float constant 2, used as the base of power(var, k) below.
    var = tvm.make.node("FloatImm", dtype="float32", value=2)
    new_range = num_anchors // elem_per_thread + 1
    # Loop count: floor-divided log2 of new_range.
    iteration = log(cast(new_range, "float32")) // math.log(2)
    # Scan: Kogge-Stone adder
    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
        with ib.for_range(0, iteration) as k:
            with ib.if_scope(k == 0):
                # First pass reads from partial_in and combines neighbours.
                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] = \
                    partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
                with ib.else_scope():
                    partial[bx * new_range] = partial_in[bx * new_range]
            with ib.else_scope():
                # Later passes add the element power(2, k) positions back.
                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
                                         tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] += \
                    partial[bx * new_range + tx - cast(power(var, k), "int32")]
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
# Example #24
def test_rpc_module():
    """Build Metal and CPU "add one" libraries and run both through an RPC proxy.

    Relies on module-level settings defined elsewhere in this file:
    ``target``, ``arch``, ``sdk``, ``proxy_host``, ``proxy_port`` and ``key``.
    """
    # graph
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    temp = util.tempdir()
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    # Build the dynamic lib.
    # If we don't want to do metal and only use cpu, just set target to be target
    f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
    path_dso1 = temp.relpath("dev_lib.dylib")
    f.export_library(path_dso1, xcode.create_dylib,
                     arch=arch, sdk=sdk)

    # Second, CPU-only schedule of the same computation.
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].pragma(xo, "parallel_launch_point")
    s[B].pragma(xi, "parallel_barrier_when_finish")
    f = tvm.build(s, [A, B], target, name="myadd_cpu")
    path_dso2 = temp.relpath("cpu_lib.dylib")
    f.export_library(path_dso2, xcode.create_dylib,
                     arch=arch, sdk=sdk)

    # Start RPC test server that contains the compiled library.
    server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
                                  libs=[path_dso1, path_dso2])

    # connect to the proxy
    remote = rpc.connect(proxy_host, proxy_port, key=key)
    ctx = remote.metal(0)
    f1 = remote.load_module("dev_lib.dylib")
    a_np = np.random.uniform(size=1024).astype(A.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
    time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
    cost = time_f(a, b).mean
    print('%g secs/op' % cost)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
    # CPU
    ctx = remote.cpu(0)
    f2 = remote.load_module("cpu_lib.dylib")
    a_np = np.random.uniform(size=1024).astype(A.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
    # NOTE(review): uses f1.entry_name while timing f2 -- confirm that the
    # entry name is the same for both modules.
    time_f = f2.time_evaluator(f1.entry_name, ctx, number=10)
    cost = time_f(a, b).mean
    print('%g secs/op' % cost)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
# Example #25
# File: (name not given) -- Project: bddppq/tvm
def invalid_to_bottom_pre(data, flag, idx):
    """Low level IR to rearrange nms output to move all valid entries to top.

    Writes a 0/1 validity flag per anchor (the first element of each entry
    must be >= 0) and then turns ``idx`` into a running sum along each
    batch row.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        1D Buffer of flag indicating valid data with [num_anchors].

    idx : Buffer
        1D Buffer of valid data indices with [num_anchors].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)

    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Global thread id.
    j = bx * max_threads + tx

    # Step 1: thread j flags anchor j of every batch entry.
    with ib.for_range(0, batch_size, for_type="unroll") as i:
        base_idx = i * num_anchors * elem_length
        with ib.if_scope(j < num_anchors):
            with ib.if_scope(data[base_idx + j * elem_length] >= 0):
                flag[i * num_anchors + j] = 1
                idx[i * num_anchors + j] = 1
            with ib.else_scope():
                flag[i * num_anchors + j] = 0
                idx[i * num_anchors + j] = 0

    # Step 2: thread j accumulates idx sequentially along batch row j.
    with ib.if_scope(j < batch_size):
        with ib.for_range(0, num_anchors) as k:
            with ib.if_scope(k > 0):
                idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
    return ib.get()
# Example #26
 def extern_generator_gpu(ins, outs):
     """Emit IR that adds one to ``ins[0]`` two floats at a time.

     Each thread whose global index is below ``n`` loads a ``float32x2``
     vector at ``idx * 2``, adds a constant 1 and stores it to ``outs[0]``.
     ``n``, ``nn`` and ``max_threads`` come from the enclosing scope.
     """
     builder = tvm.ir_builder.create()
     block_axis = tvm.thread_axis("blockIdx.x")
     thread_axis = tvm.thread_axis("threadIdx.x")
     # Round the block count up so every element is covered.
     num_blocks = (nn+max_threads-1) // max_threads
     builder.scope_attr(block_axis, "thread_extent", num_blocks)
     builder.scope_attr(thread_axis, "thread_extent", max_threads)
     idx = block_axis.var * max_threads + thread_axis.var
     with builder.if_scope(builder.likely(idx < n)):
         loaded = ins[0].vload(idx*2, "float32x2")
         incremented = loaded + tvm.const(1, "float32x2")
         builder.emit(outs[0].vstore(idx*2, incremented))
     return builder.get()
# Example #27
    def _schedule(Padded_out_grad, In_grad):
        """Fuse the last three axes of ``In_grad`` and bind them to block/thread.

        The fused axis is split by 128; the outer part goes to blockIdx.x and
        the inner part to threadIdx.x. ``Padded_out_grad`` is not used here,
        and ``s`` comes from the enclosing scope.
        """
        _, h_axis, w_axis, c_axis = In_grad.op.axis
        merged = s[In_grad].fuse(h_axis, w_axis, c_axis)
        outer, inner = s[In_grad].split(merged, factor=128)

        s[In_grad].bind(outer, tvm.thread_axis("blockIdx.x"))
        s[In_grad].bind(inner, tvm.thread_axis("threadIdx.x"))
# Example #28
# File: (name not given) -- Project: bddppq/tvm
def get_valid_counts_pre(data, flag, idx, score_threshold):
    """Low level IR that flags which bounding boxes pass a score threshold.

    One GPU thread handles one box.  For box ``tid`` the element at offset 1
    of the box record is compared with ``score_threshold``; ``flag[tid]`` and
    ``idx[tid]`` are both set to 1 when it is strictly greater and to 0
    otherwise.  This function only fills ``flag`` and ``idx``; it does not
    move any box data.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    score_threshold : float32
        Lower limit of score for valid bounding boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)

    # NOTE(review): the right-hand side of this assignment was truncated in
    # the source ("int("); restored as the thread limit of the current target.
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # One extra block so that a partial final block is always covered.
    nthread_bx = batch_size * num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    # Threads beyond the last box do nothing.
    with ib.if_scope(tid < batch_size * num_anchors):
        with ib.if_scope(data[tid * box_data_length + 1] > score_threshold):
            flag[tid] = 1
            idx[tid] = 1
        with ib.else_scope():
            flag[tid] = 0
            idx[tid] = 0

    return ib.get()
Beispiel #29
def argsort_ir(data_buf, out_index_buf):
    """Batched odd-even transposition sort.

    For every batch row, ``out_index_buf`` is first initialised to
    ``0 .. num_bbox - 1``.  Then ``num_bbox`` passes are run; in each pass a
    thread compares one adjacent pair (the pair offset alternates with the
    pass parity) and swaps both the values in ``data_buf`` and the matching
    entries of ``out_index_buf`` when the left value is smaller.  The data
    buffer is therefore reordered in place, largest value first.

    Parameters
    ----------
    data_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]

    out_index_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Indices of data in sorted order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_bbox = get_const_tuple(data_buf.shape)
    # NOTE(review): the right-hand side of this assignment was truncated in
    # the source ("int("); restored as the thread limit of the current target.
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data_buf)
    index_out = ib.buffer_ptr(out_index_buf)
    nthread_tx = max_threads
    # Each thread owns a pair of elements, hence (num_bbox + 1) // 2 work items.
    nthread_bx = (num_bbox + 1) // 2 // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    # One-element scratch buffers used to swap values and indices.
    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")

    with ib.for_range(0, batch, for_type="unroll") as b:
        start = b * num_bbox
        # Initialise the index row: each thread writes its two positions.
        for i in range(2):
            bbox_id = tid * 2 + i
            with ib.if_scope(bbox_id < num_bbox):
                index_out[start + bbox_id] = bbox_id
        with ib.for_range(0, num_bbox) as k:
            # Even passes compare (0,1),(2,3)...; odd passes (1,2),(3,4)...
            offset = start + 2 * tid + (k % 2)
            with ib.if_scope(
                    tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
                temp_data[0] = p_data[offset]
                p_data[offset] = p_data[offset + 1]
                p_data[offset + 1] = temp_data[0]
                temp_index[0] = index_out[offset]
                index_out[offset] = index_out[offset + 1]
                index_out[offset + 1] = temp_index[0]
            # NOTE(review): the argument-list line of this call was missing in
            # the source (only five arguments remained); restored as a
            # shared-scope storage sync. Confirm against upstream.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
Beispiel #30
def test_num_thread():
    """Check that GPU code verification enforces thread-count limits.

    A copy kernel over ``N`` elements is split by ``M`` and bound so that
    ``threadIdx.x`` has extent ``N // M`` and ``threadIdx.y`` has extent
    ``M`` (``N`` threads in total, no shared memory).  The kernel is built
    under several limit settings and the verdict written into ``valid[0]`` by
    the verification pass is compared with the expectation.
    """
    N = 1024
    M = 128

    A = tvm.placeholder((N,), name='A', dtype='float32')
    B = tvm.compute((N, ), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    o, i = s[B].split(s[B].op.axis[0], M)

    s[B].bind(o, tvm.thread_axis('threadIdx.x'))
    s[B].bind(i, tvm.thread_axis("threadIdx.y"))

    # shared memory usage: 0
    # thread usage: N

    # (limits passed to the verification pass, expected verdict).
    # NOTE(review): several keyword-argument lines were missing in the source;
    # only ``max_threads_per_block=N - 1`` survived, and the other limits
    # (including the shared-memory limit of 0) are reconstructed from the
    # asserts. Confirm against upstream.
    cases = [
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N - 1}, False),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N}, True),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N,
          "max_thread_y": M - 1}, False),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N,
          "max_thread_y": M}, True),
    ]

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            # Skip back-ends that are not available on this machine.
            continue

        valid = [None]
        for limits, expected in cases:
            with tvm.build_config(**{"add_lower_pass": [
                    (2, get_verify_pass(valid, **limits))]}):
                tvm.build(s, [A, B], target)
            assert bool(valid[0]) is expected
Beispiel #31
print(tvm.lower(s, [A, B], simple_mode=True))

# You can see that the IR code looks quite like C code.
# The reduction axis is similar to a normal axis: it can be split.
# In the following code we split both the row axis of B and its
# reduction axis by different factors. The result is a nested reduction.
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

# If we are building a GPU kernel, we can bind the rows of B to GPU threads.
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))

# Reduction Factoring and Parallelization
# ---------------------------------------
# One problem with building a reduction is that we cannot simply
# parallelize over the reduction axis. We need to divide the computation
# of the reduction, store the local reduction results in a temporary array,
# and then do a reduction over that temporary array.
# The rfactor primitive performs this rewrite of the computation.
# In the following schedule, the result of B is written to a temporary
# result B.rf. The factored dimension becomes the first dimension of B.rf.
Beispiel #32
def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
    """Schedule conv2d for specific feature_in_out_filter pattern

    Tiles the ``Out`` stage over GPU blocks, virtual threads and threads,
    computes ``Out_L`` at the innermost thread axis of ``Out``, places the
    ``temp_S`` and ``Filter_S`` stages inside the fused reduction loop of
    ``Out_L`` and binds their inner axes to the thread axes.  The schedule
    ``s`` is modified in place; nothing is returned.

    Parameters
    ----------
    s : schedule object indexed by the tensors below.
    Filter : tensor; only ``Filter.shape[0]`` is read, to pick tiling params.
    temp_S, Filter_S : stages attached under ``Out_L`` and bound to threads.
    Out : output stage that is tiled and bound to the GPU grid.
    Out_L : stage computed at the innermost thread axis of ``Out``.
    """
    # scheduler params
    num_thread = 8
    vthread = 2
    opart2 = 4
    ofactor = 64
    wfactor = 56
    ifactor = 8
    # Different tiling parameters when the first filter dimension is 64.
    if util.get_const_int(Filter.shape[0]) == 64:
        opart2 = 8
        ifactor = 16
    sfactor = max(1, ofactor//(opart2*2))
    # ceil(wfactor / vthread), at least 1
    spart = max(1, (wfactor + vthread-1) // vthread)

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    block_z = tvm.thread_axis("blockIdx.z")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

    # Tile the output: channel axis by ofactor, width axis by wfactor, then
    # split each tile again for virtual threads and threads.
    i, oc, h, w = s[Out].op.axis
    ooc, ioc = s[Out].split(oc, factor=ofactor)
    ow, iw = s[Out].split(w, factor=wfactor)
    ow = s[Out].fuse(ow, h)
    oioc, iioc = s[Out].split(ioc, nparts=vthread)
    oiw, iiw = s[Out].split(iw, nparts=vthread)
    oiioc, iiioc = s[Out].split(iioc, nparts=opart2)
    s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
    s[Out].bind(iiioc, thread_x)
    s[Out].bind(iiw, thread_y)
    s[Out].bind(oiioc, thread_xz)
    s[Out].bind(oiw, thread_yz)
    s[Out].bind(oioc, block_x)
    s[Out].bind(ow, block_y)
    s[Out].bind(ooc, block_z)

    s[Out_L].compute_at(s[Out], iiioc)

    # schedule Out_L local write
    i, oc, h, w = s[Out_L].op.axis
    ic, dh, dw = s[Out_L].op.reduce_axis
    oic, iic = s[Out_L].split(ic, factor=ifactor)
    s[Out_L].reorder(oic, dh, dw, iic, h, w)

    # Fuse the three outer reduction loops into one; ``dw`` is rebound to the
    # fused axis, which is where the two load stages are attached.
    fuse_index = s[Out_L].fuse(dw, dh)
    fuse_index = s[Out_L].fuse(fuse_index, oic)
    dw = fuse_index
    s[temp_S].compute_at(s[Out_L], dw)
    s[Filter_S].compute_at(s[Out_L], dw)

    #schedule temp_S shared mem load
    i, ic, h, w = s[temp_S].op.axis
    _, iic = s[temp_S].split(ic, factor=sfactor)
    _, iw = s[temp_S].split(w, factor=spart)
    s[temp_S].bind(iic, thread_x)
    s[temp_S].bind(iw, thread_y)

    #schedule Filter_S shared mem load
    i, oc, h, w = s[Filter_S].op.axis
    _, ioc = s[Filter_S].split(oc, factor=sfactor)
    _, ii = s[Filter_S].split(i, factor=spart)
    s[Filter_S].bind(ioc, thread_x)
    s[Filter_S].bind(ii, thread_y)
Beispiel #33
def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
    """Schedule conv2d for specific feature_in_out_filter pattern

    Two strategies are used, selected by whether the first two dimensions of
    ``Filter_S`` are equal.  Both tile ``Out`` over blocks, virtual threads
    and threads, compute ``Out_L`` at the innermost thread axis of ``Out``,
    bind the fused axes of ``temp`` and ``temp_R`` to a 1-D block/thread grid
    and bind the ``temp_S`` / ``Filter_S`` loads to the thread axes.  The
    schedule ``s`` is modified in place; nothing is returned.

    ``flag`` is only compared against 256 (second strategy) to choose the
    width tiling factor.

    NOTE(review): the copy this was taken from had lost its ``else:`` lines
    and the right-hand sides of ``num_thread`` / ``max_num_thread``.  They
    are restored here from the surrounding structure; confirm against
    upstream.
    """
    if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
        # Choose thread / virtual-thread counts from the output plane size.
        mark = util.get_const_int(Out.shape[2]) * util.get_const_int(Out.shape[3])
        num_thread_x = 0
        if mark % 8 == 0 and mark % 7 == 0:
            num_thread_x = 8
            vthread_x = 7
        else:
            # NOTE(review): no ``break`` is visible in the source, so the
            # loop keeps overwriting num_thread_x with later divisors;
            # a loop exit may have been lost. Confirm against upstream.
            for i in range(5, mark):
                if mark % i == 0 and num_thread_x == 0:
                    vthread_x = i
                    mark = mark // i
                if mark % i == 0 and vthread_x > 0:
                    num_thread_x = i
        num_thread_y = 8
        vthread_y = 2
        ifactor = 8

        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
        thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")
        thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")

        # Tile the fused spatial axis and the channel axis of the output.
        i, oc, h, w = s[Out].op.axis
        w = s[Out].fuse(h, w)
        ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
        ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
        oiw, iiw = s[Out].split(iw, nparts=vthread_x)
        oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
        s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
        s[Out].bind(iiw, thread_x)
        s[Out].bind(iioc, thread_y)
        s[Out].bind(oiw, thread_xz)
        s[Out].bind(oioc, thread_yz)
        s[Out].bind(ow, block_x)
        s[Out].bind(ooc, block_y)

        s[Out_L].compute_at(s[Out], iiw)

        # schedule Out_L local write
        i, oc, h, w = s[Out_L].op.axis
        ic, dh, dw = s[Out_L].op.reduce_axis
        oic, iic = s[Out_L].split(ic, factor=ifactor)
        s[Out_L].reorder(oic, dh, dw, iic, h, w)

        s[temp_S].compute_at(s[Out_L], oic)
        s[Filter_S].compute_at(s[Out_L], dw)

        # Restored right-hand side (see NOTE in the docstring).
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
        block_xx = tvm.thread_axis("blockIdx.x")

        # temp and temp_R: flatten all axes and spread over a 1-D grid.
        i = s[temp].fuse(*s[temp].op.axis)
        bx, tx = s[temp].split(i, factor=num_thread)
        s[temp].bind(tx, thread_xx)
        s[temp].bind(bx, block_xx)

        i = s[temp_R].fuse(*s[temp_R].op.axis)
        bx, tx = s[temp_R].split(i, factor=num_thread)
        s[temp_R].bind(tx, thread_xx)
        s[temp_R].bind(bx, block_xx)

        #schedule temp_S shared mem load
        i, oic, h, w, iic = s[temp_S].op.axis
        oic = s[temp_S].fuse(oic, h, w)
        ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
        _, iooic = s[temp_S].split(ooic, factor=num_thread_y)
        s[temp_S].bind(ioic, thread_x)
        s[temp_S].bind(iooic, thread_y)

        #schedule Filter_S shared mem load
        i, oc, h, w = s[Filter_S].op.axis
        _, ioc = s[Filter_S].split(oc, factor=num_thread_y)
        _, ii = s[Filter_S].split(i, factor=num_thread_x)
        s[Filter_S].bind(ioc, thread_y)
        s[Filter_S].bind(ii, thread_x)
    else:
        # scheduler params
        vthread = 2
        opart2 = 4
        ofactor = 64
        wfactor = 28
        ifactor = 8
        if flag > 256:
            wfactor = 14
        num_thread_x = max(1, ofactor//(opart2*2))
        # ceil(wfactor / vthread), at least 1
        num_thread_y = max(1, (wfactor + vthread-1) // vthread)
        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        block_z = tvm.thread_axis("blockIdx.z")
        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

        i, oc, h, w = s[Out].op.axis
        ooc, ioc = s[Out].split(oc, factor=ofactor)
        ow, iw = s[Out].split(w, factor=wfactor)
        ow = s[Out].fuse(ow, h)
        oioc, iioc = s[Out].split(ioc, nparts=vthread)
        oiw, iiw = s[Out].split(iw, nparts=vthread)
        oiioc, iiioc = s[Out].split(iioc, nparts=opart2)
        s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
        s[Out].bind(iiioc, thread_x)
        s[Out].bind(iiw, thread_y)
        s[Out].bind(oiioc, thread_xz)
        s[Out].bind(oiw, thread_yz)
        s[Out].bind(oioc, block_x)
        s[Out].bind(ow, block_y)
        s[Out].bind(ooc, block_z)

        s[Out_L].compute_at(s[Out], iiioc)

        # schedule Out_L local write
        i, oc, h, w = s[Out_L].op.axis
        ic, dh, dw = s[Out_L].op.reduce_axis
        oic, iic = s[Out_L].split(ic, factor=ifactor)
        s[Out_L].reorder(oic, dh, dw, iic, h, w)
        # Restored right-hand side (see NOTE in the docstring).
        max_num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        if util.get_const_int(Filter_S.shape[1]) == 128:
            oic = s[Out_L].fuse(dh, oic)
            s[temp_S].compute_at(s[Out_L], oic)
            s[Filter_S].compute_at(s[Out_L], oic)
            num_thread = max_num_thread
        else:
            s[temp_S].compute_at(s[Out_L], oic)
            s[Filter_S].compute_at(s[Out_L], dw)
            # Use 456 threads unless the target allows fewer.
            num_thread = 456
            if max_num_thread < num_thread:
                num_thread = max_num_thread

        thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
        block_xx = tvm.thread_axis("blockIdx.x")

        # temp and temp_R: flatten all axes and spread over a 1-D grid.
        i = s[temp].fuse(*s[temp].op.axis)
        bx, tx = s[temp].split(i, factor=num_thread)
        s[temp].bind(tx, thread_xx)
        s[temp].bind(bx, block_xx)

        i = s[temp_R].fuse(*s[temp_R].op.axis)
        bx, tx = s[temp_R].split(i, factor=num_thread)
        s[temp_R].bind(tx, thread_xx)
        s[temp_R].bind(bx, block_xx)

        #schedule temp_S shared mem load
        i, oic, h, w, iic = s[temp_S].op.axis
        oic = s[temp_S].fuse(oic, h, w)
        ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
        _, iooic = s[temp_S].split(ooic, factor=num_thread_y)
        s[temp_S].bind(ioic, thread_x)
        s[temp_S].bind(iooic, thread_y)

        #schedule Filter_S shared mem load
        i, oc, h, w = s[Filter_S].op.axis
        _, ioc = s[Filter_S].split(oc, factor=num_thread_x)
        _, ii = s[Filter_S].split(i, factor=num_thread_y)
        s[Filter_S].bind(ioc, thread_x)
        s[Filter_S].bind(ii, thread_y)
Beispiel #34
def make_matrix_mul(shapeA,
                        # NOTE(review): the rest of the parameter list and the opening of
                        # this default-options dict are missing from this copy. The body
                        # reads transposeA, shapeB, transposeB, dtype, args_opt, tgt,
                        # tgt_host, func_name and remote, none of which is declared in
                        # the visible text.
                        'x_f': 8,
                        'y_f': 1,
                        'k_f': 8
    """TODO: Your code here"""
    """Hint: use tvm.reduce_axis, tvm.sum"""
    """Hint: treat 4 cases of transposeA, transposeB separately"""
    """Hint: for tvm schedule, use split, reorder, vectorize, parallel"""
    """Hint: debug tvm schedule using tvm.lower"""
    X = tvm.placeholder(shapeA, dtype=dtype, name='X')
    Y = tvm.placeholder(shapeB, dtype=dtype, name='Y')

    # One compute definition per (transposeA, transposeB) combination; the
    # reduction axis k runs over the contracted dimension of X in each case.
    # NOTE(review): every tvm.compute(...) call below is left unclosed in
    # this copy; its trailing argument line appears to be missing.
    if not transposeA and not transposeB:
        k = tvm.reduce_axis((0, shapeA[1]), name='k')
        Z = tvm.compute((shapeA[0], shapeB[1]),
                        lambda i, j: tvm.sum(X[i, k] * Y[k, j], axis=k),
    elif not transposeA and transposeB:
        k = tvm.reduce_axis((0, shapeA[1]), name='k')
        Z = tvm.compute((shapeA[0], shapeB[0]),
                        lambda i, j: tvm.sum(X[i, k] * Y[j, k], axis=k),
    elif transposeA and not transposeB:
        k = tvm.reduce_axis((0, shapeA[0]), name='k')
        Z = tvm.compute((shapeA[1], shapeB[1]),
                        lambda i, j: tvm.sum(X[k, i] * Y[k, j], axis=k),
    elif transposeA and transposeB:
        k = tvm.reduce_axis((0, shapeA[0]), name='k')
        Z = tvm.compute((shapeA[1], shapeB[0]),
                        lambda i, j: tvm.sum(X[k, i] * Y[j, k], axis=k),

    # Tiling / split factors supplied by the caller.
    x_f = args_opt['x_f']
    y_f = args_opt['y_f']
    k_f = args_opt['k_f']

    s = tvm.create_schedule(Z.op)
    # Tile the two output axes by (x_f, y_f).
    xo, yo, xi, yi = s[Z].tile(Z.op.axis[0], Z.op.axis[1], x_f, y_f)

    # Split the reduction axis by k_f and put the outer loops first.
    k, = s[Z].op.reduce_axis
    ko, ki = s[Z].split(k, factor=k_f)
    s[Z].reorder(xo, yo, ko, xi, ki, yi)

    # zz = s[Z].fuse(ko, xi)

    # The result of this lowering call is discarded.
    tvm.lower(s, [X, Y, Z], simple_mode=True)
    # Outer tile loops are mapped onto the GPU grid.
    s[Z].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[Z].bind(yo, tvm.thread_axis("threadIdx.x"))

    # s[Z].bind(Z.op.axis[0], tvm.thread_axis("blockIdx.x"))
    # s[Z].bind(Z.op.axis[1], tvm.thread_axis("threadIdx.x"))

    # NOTE(review): the callee and first argument of this call are missing in
    # this copy (presumably a build call taking the schedule) - confirm.
    f =, [X, Y, Z], tgt, target_host=tgt_host, name=func_name)
    return _export_module(f, func_name, remote)
Beispiel #35
rows = tvm.var("rows")
cols = tvm.var("cols")
max_chans = tvm.const(5)
chans = tvm.var("chans")

# Symbolically shaped 3-D float32 input.
input_vec = tvm.placeholder((rows, cols, chans), dtype="float32")
# NOTE(review): this tvm.compute(...) call is left unclosed in this copy;
# its last argument line appears to be missing.
kernel = tvm.compute((cols, chans),
                     lambda c, cc: 1.0 * c * cc,

# Multiply each input element by a kernel entry; the channel index used for
# the kernel lookup is clamped with max(0, c) and min(max_chans, ...).
# NOTE(review): this call is also left unclosed in this copy.
result = tvm.compute((rows, cols, chans),
                     lambda y, x, c: input_vec[y, x, c] * kernel[
                         x, tvm.min(max_chans, tvm.max(0, c))],

sched = tvm.create_schedule(result.op)
result_stage = sched[result]
kernel_stage = sched[kernel]

arglist = [input_vec, result]

# Compute the kernel stage inside the second (column) loop of the result.
kernel_stage.compute_at(result_stage, result.op.axis[1])

print_schedule(sched, arglist)

# Rows map to blocks, columns to threads.
result_stage.bind(result.op.axis[0], tvm.thread_axis("blockIdx.x"))
result_stage.bind(result.op.axis[1], tvm.thread_axis("threadIdx.x"))

# NOTE(review): the callee and first argument of this call are missing in
# this copy (presumably a build call taking the schedule) - confirm.
fun =, arglist, "opencl", name="test_compute_at")
Beispiel #36
def _schedule_reduce(op, sch, is_idx_reduce=False):
    """Schedule a reduction operator on the GPU.

    All reduce axes are fused, split by ``num_thread`` and factored with
    ``rfactor`` so that the inner part is reduced across ``threadIdx.x``.
    If the reduction output also has spatial axes, they are fused and spread
    over ``blockIdx.x`` / ``threadIdx.y``; otherwise the whole tensor is
    reduced and the thread count is the target's maximum.

    Parameters
    ----------
    op : operation
        The reduction operation to schedule.
    sch : schedule
        Schedule to modify in place.
    is_idx_reduce : bool
        When true, the reduction is taken from ``op.input_tensors[0]`` and
        ``rfactor`` is expected to return a pair, with the two outputs of
        the reduction attached under the real output stage.

    Returns
    -------
    sch
        The same schedule object (or the result of
        ``schedule_injective_from_existing`` when there is no reduce axis).

    NOTE(review): the copy this was taken from had lost its ``else:`` lines,
    the right-hand sides of ``target`` / ``num_thread`` and the element
    expressions of the two ``fuse`` calls; they are restored here from the
    surrounding structure.  Confirm against upstream.
    """
    if is_idx_reduce:
        data_out = op.input_tensors[0]
    else:
        data_out = op.output(0)

    if not sch[data_out].op.reduce_axis:
        return schedule_injective_from_existing(sch, op.output(0))

    if len(sch[data_out].op.axis) > 0:
        all_reduce = False
        num_thread = 32
        # May be None, hence the truthiness check below.
        target = tvm.target.current_target()
        if target and target.target_name == "opencl":
            # Without this, CL_INVALID_WORK_GROUP_SIZE occurred at run time
            # (cause unknown).
            num_thread = 16
        block_x = tvm.thread_axis("blockIdx.x")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    else:
        all_reduce = True
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")

    # Fuse and refactor the reduce axis
    fused_reduce = sch[data_out].fuse(*sch[data_out].op.reduce_axis)
    ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
    if is_idx_reduce:
        data_out_rf, _ = sch.rfactor(data_out, ki)
    else:
        data_out_rf = sch.rfactor(data_out, ki)
    # After rfactor the remaining reduce axis of data_out is the cross-thread one.
    tx = sch[data_out].op.reduce_axis[0]
    sch[data_out].bind(tx, thread_x)
    sch[data_out_rf].compute_at(sch[data_out], tx)
    if is_idx_reduce:
        real_output = op.output(0)
        temp_idx_input = data_out.op.output(0)
        temp_val_input = data_out.op.output(1)
    else:
        real_output = data_out
    if not all_reduce:
        # Fuse and split the axis
        fused_outer = sch[real_output].fuse(*sch[real_output].op.axis)
        bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)

        # Bind the axes to threads and blocks
        sch[real_output].bind(outer_in, thread_y)
        sch[real_output].bind(bx, block_x)
        if is_idx_reduce:
            sch[temp_idx_input].compute_at(sch[real_output], outer_in)
            sch[temp_val_input].compute_at(sch[real_output], outer_in)
    else:
        if is_idx_reduce:
            spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
            sch[real_output].bind(spatial_axis, tvm.thread_axis("blockIdx.x"))
            sch[temp_idx_input].compute_at(sch[real_output], spatial_axis)
            sch[temp_val_input].compute_at(sch[real_output], spatial_axis)
    return sch
Beispiel #37
    def _callback(op):
        """Schedule an op tagged ``depthwise_conv2d_nchw``; ignore any other op.

        Defines the tuning space on ``cfg`` (three 4-way splits and two
        unroll knobs), creates shared/local cache stages for the padded data
        and the kernel, tiles the output over blocks, virtual threads and
        threads, and sets up cooperative fetching for the shared stages.
        ``s`` and ``cfg`` come from the enclosing scope.

        NOTE(review): the copy this was taken from had lost several lines
        (``else:`` branches, the right-hand side of ``target``, the tail of
        the ``load_reference_log`` call, the body of the ``dilate`` check and
        the pragma values).  They are restored here; lines marked
        "restored" below should be confirmed against upstream.
        """
        if op.tag == 'depthwise_conv2d_nchw':
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, y, x = s[conv].op.axis
            cfg.define_split("tile_f", f, num_outputs=4)
            cfg.define_split("tile_y", y, num_outputs=4)
            cfg.define_split("tile_x", x, num_outputs=4)
            cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])

            target = tvm.target.current_target()  # restored
            if target.target_name in ['nvptx', 'rocm']:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])

            # fallback support
            if cfg.is_fallback:
                ref_log = autotvm.tophub.load_reference_log(
                    target.target_name, target.model,
                    'depthwise_conv2d_nchw', 'direct')  # restored arguments
                cfg.fallback_with_reference_log(ref_log)  # restored
                # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
                cfg['unroll_explicit'].val = 0
            ##### space definition end #####

            s[pad_data].compute_inline()  # restored
            if isinstance(kernel.op,
                          tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
                s[kernel].compute_inline()  # restored

            if conv.op in s.outputs:
                # conv is the final output: accumulate in a local cache stage.
                output = conv
                OL = s.cache_write(conv, 'local')
            else:
                # conv feeds a later op: the schedule's first output is tiled
                # and conv itself plays the role of the local stage.
                output = s.outputs[0].output(0)
                s[conv].set_scope('local')  # restored
                OL = conv

            # create cache stage
            AA = s.cache_read(pad_data, 'shared', [OL])
            WW = s.cache_read(kernel, 'shared', [OL])
            AL = s.cache_read(AA, 'local', [OL])
            WL = s.cache_read(WW, 'local', [OL])

            # tile and bind spatial axes
            n, f, y, x = s[output].op.axis
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            # Outermost unit-extent loop used as the anchor for the pragmas.
            kernel_scope, n = s[output].split(n, nparts=1)
            bf = s[output].fuse(n, bf)
            s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
            s[output].bind(by, tvm.thread_axis("blockIdx.y"))
            s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[output].bind(vf, tvm.thread_axis("vthread"))
            s[output].bind(vy, tvm.thread_axis("vthread"))
            s[output].bind(vx, tvm.thread_axis("vthread"))
            s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
            s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
            s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
            s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
            s[OL].compute_at(s[output], tx)

            # cooperative fetching
            s[AA].compute_at(s[output], bx)
            s[WW].compute_at(s[output], bx)
            s[AL].compute_at(s[output], tx)
            s[WL].compute_at(s[output], tx)

            # Spread each shared load over the same thread extents that the
            # output tiling uses (third split factor of each tile).
            for load in [AA, WW]:
                fused = s[load].fuse(*list(s[load].op.axis))
                fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
                fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
                fused, tz = s[load].split(fused, cfg["tile_f"].size[2])
                s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
                s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

            # Pragma values restored from the knobs defined above.
            s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                             cfg['auto_unroll_max_step'].val)
            s[output].pragma(kernel_scope, 'unroll_explicit',
                             cfg['unroll_explicit'].val)
Beispiel #38
# In the implementation below, virtual threading distributes work across two
# threads split along the output channel axis.
# We show how work is split when computing the 2D convolution in the figure
# below.
# NOTE(review): the image URL of the directive below is missing in this copy.
# .. image::
#      :align: center
#      :width: 480px

# VTA only supports 2 virtual threads
v_threads = 2

# Perform virtual thread split along output channel outer axis.
# Only the inner axis (extent v_threads) is kept; it is moved in front of
# b_out and bound to the "cthread" thread axis.
_, tx = s[res].split(oc_out, factor=v_threads)
s[res].reorder(tx, b_out)
s[res].bind(tx, tvm.thread_axis("cthread"))

# Let's look at the current TVM schedule after blocking and virtual threading
print(tvm.lower(s, [data, kernel, res], simple_mode=True))

# Lowering Copies to DMA Transfers
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Next we set the buffer scopes to the corresponding on-chip VTA SRAM buffers.
# We move the load loops into the 2D convolution computation loop to stage
# memory loads such that they fit in the on-chip SRAM buffers.
# Finally we annotate the load/store loop outer axes with the DMA copy pragma
# to perform bulk memory transfers on VTA.

# Set scope of SRAM buffers
Beispiel #39
def make_matrix_softmax_cross_entropy(shape,
    # NOTE(review): the rest of the parameter list is missing from this copy.
    # The body reads dtype, tgt, tgt_host, func_name and remote, none of
    # which is declared in the visible text.
    """TODO: Your code here"""
    """Hint: output shape should be (1,)"""
    X = tvm.placeholder(shape, dtype=dtype, name='X')
    # Row-wise maximum, subtracted before exponentiation.
    a1 = tvm.reduce_axis((0, shape[1]), name='a1')
    MAX_X = tvm.compute((shape[0], ), lambda i: tvm.max(X[i, a1], axis=[a1]))

    E_X = tvm.compute(shape, lambda i, j: tvm.exp(X[i, j] - MAX_X(i)))

    # Row-wise sum of the exponentials.
    a2 = tvm.reduce_axis((0, shape[1]), name='a2')
    E_X_SUM = tvm.compute((shape[0], ),
                          lambda i: tvm.sum(E_X[i, a2], axis=[a2]))

    SOFTMAX_X = tvm.compute(shape, lambda i, j: E_X[i, j] / E_X_SUM(i))

    LOG_SOFTMAX_X = tvm.compute(shape, lambda i, j: tvm.log(SOFTMAX_X[i, j]))

    # Second input, multiplied element-wise with the log-softmax.
    X_P = tvm.placeholder(shape, dtype=dtype, name='X_P')

    MUL = tvm.compute(shape, lambda i, j: X_P[i, j] * LOG_SOFTMAX_X[i, j])

    # Per-row sum of the negated products.
    a3 = tvm.reduce_axis((0, shape[1]), name='a3')
    SUM = tvm.compute((shape[0], ), lambda i: tvm.sum(-MUL[i, a3], axis=[a3]))

    # Average of the per-row values over shape[0] rows; output shape is (1,).
    a4 = tvm.reduce_axis((0, shape[0]), name='a4')
    MEAN = tvm.compute((1, ), lambda i: tvm.sum(SUM[a4] / shape[0], axis=[a4]))

    # s = tvm.create_schedule([MAX_X.op, E_X.op, E_X_SUM.op, SOFTMAX_X.op, LOG_SOFTMAX_X.op, MUL.op, SUM.op, MEAN.op])
    s = tvm.create_schedule(MEAN.op)

    # print(tvm.lower(s, [X, X_P, MEAN], simple_mode=True))

    # Each stage below binds its row axis to blockIdx.x and its column or
    # reduce axis to threadIdx.x.
    # NOTE(review): LOG_SOFTMAX_X is the only intermediate stage with no bind
    # call here; whether that is intentional cannot be told from this copy.

    # MAX_X
    s[MAX_X].bind(MAX_X.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[MAX_X].bind(a1, tvm.thread_axis("threadIdx.x"))

    # E_X
    s[E_X].bind(E_X.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[E_X].bind(E_X.op.axis[1], tvm.thread_axis("threadIdx.x"))

    # E_X_SUM
    s[E_X_SUM].bind(E_X_SUM.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[E_X_SUM].bind(a2, tvm.thread_axis("threadIdx.x"))

    # SOFTMAX_X
    s[SOFTMAX_X].bind(SOFTMAX_X.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[SOFTMAX_X].bind(SOFTMAX_X.op.axis[1], tvm.thread_axis("threadIdx.x"))


    # MUL
    s[MUL].bind(MUL.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[MUL].bind(MUL.op.axis[1], tvm.thread_axis("threadIdx.x"))

    # SUM
    s[SUM].bind(SUM.op.axis[0], tvm.thread_axis("blockIdx.x"))
    s[SUM].bind(a3, tvm.thread_axis("threadIdx.x"))

    # MEAN
    # s[MEAN].bind(a4, tvm.thread_axis("blockIdx.x"))
    s[MEAN].bind(a4, tvm.thread_axis("threadIdx.x"))

    # print(tvm.lower(s, [X, X_P, MEAN], simple_mode=True))

    # block_x = tvm.thread_axis("blockIdx.x")
    # thread_x = tvm.thread_axis("threadIdx.x")

    # zo, zi = s[SUM].split(SUM.op.axis[0], 3)
    # print(tvm.lower(s, [X, X_P, MEAN], simple_mode=True))
    # s[SUM].bind(zo, block_x)
    # s[SUM].bind(zi, thread_x)

    # NOTE(review): the callee and first argument of this call are missing in
    # this copy (presumably a build call taking the schedule) - confirm.
    f =, [X, X_P, MEAN], tgt, target_host=tgt_host, name=func_name)
    return _export_module(f, func_name, remote)
Beispiel #40
def schedule_packed_conv2d(outs):
    """ Schedule the packed conv2d.

    Expects exactly one output tensor of dtype int8 whose first input is
    int32, and returns the schedule created for it.

    NOTE(review): this copy has lost a number of lines (bare ``else:``
    branches, some statement bodies and the tails of a few calls); the gaps
    are flagged inline below and must be restored from upstream before the
    function can run.
    """
    assert len(outs) == 1
    output = outs[0]
    ewise_inputs = []
    ewise_ops = []
    conv2d_res = []
    assert output.dtype == "int8"
    assert output.op.input_tensors[0].dtype == "int32"

    # Walk the op graph from the output, collecting element-wise ops, their
    # placeholder inputs and the packed conv2d op.
    def _traverse(op):
        if topi.tag.is_broadcast(op.tag):
            # NOTE(review): the body of this `if` is missing in this copy.
            if not op.same_as(output.op):
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
                    ewise_inputs.append((op, tensor))
            # NOTE(review): an `else:` branch (and whatever preceded it)
            # appears to be missing before this assert.
            assert op.tag == "packed_conv2d"

    # NOTE(review): no call to _traverse and no append to conv2d_res is
    # visible in this copy, yet exactly one entry is required here.
    assert len(conv2d_res) == 1
    conv2d_stage = conv2d_res[0].output(0)

    data, kernel = conv2d_stage.op.input_tensors
    # If the data input is a padding stage, schedule against its source.
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        temp = data.op.input_tensors[0]
        pad_data = data
        data = temp
        # NOTE(review): an `else:` appears to be missing before the next line.
        pad_data = None
    wrkld = _get_workload(data, pad_data, kernel, output)
    if wrkld in _WL2PLAN:
        plan = _WL2PLAN[wrkld]
        # NOTE(review): an `else:` appears to be missing here, and the next
        # line has lost the call that takes the message string.
        plan = find_schedules(wrkld, vt_only=True, best_only=True)[0]"Trying to find plan for %s", wrkld)
    env = get_env()

    # All four copy directions use the same DMA pragma.
    load_inp = load_wgt = load_out = store_out = env.dma_copy
    alu = env.alu
    gemm = env.gemm

    # schedule1
    oshape = topi.util.get_const_tuple(output.shape)
    s = tvm.create_schedule(output.op)

    # setup pad
    if pad_data is not None:
        cdata = pad_data
        # NOTE(review): an `else:` appears to be missing before the next line.
        cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
    ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
    # cache read input
    cache_read_ewise = []

    for consumer, tensor in ewise_inputs:
        # NOTE(review): the tail of this call is missing in this copy.
        cache_read_ewise.append(s.cache_read(tensor, env.acc_scope,
    # set ewise scope
    for op in ewise_ops:
        s[op].pragma(s[op].op.axis[0], alu)

    # tile
    # NOTE(review): the end of the next expression is missing in this copy.
    oc_factor = (plan.oc_factor if plan.oc_factor else plan.out_filter //
    # Fall back to the full output extent when the plan gives no factor.
    h_factor = (plan.h_factor if plan.h_factor else oshape[2])
    w_factor = (plan.w_factor if plan.w_factor else oshape[3])

    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
    x_co0, x_co1 = s[output].split(x_co, factor=oc_factor)
    x_i0, x_i1 = s[output].split(x_i, factor=h_factor)
    x_j0, x_j1 = s[output].split(x_j, factor=w_factor)
    s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
    # Everything feeding the output is computed at this tile loop.
    store_pt = x_j0

    # set all compute scopes
    s[conv2d_stage].compute_at(s[output], store_pt)
    for op in ewise_ops:
        s[op].compute_at(s[output], store_pt)

    for tensor in cache_read_ewise:
        s[tensor].compute_at(s[output], store_pt)
        s[tensor].pragma(s[tensor].op.axis[0], load_out)

    # virtual threading along output channel axes
    if plan.oc_nthread > 1:
        _, v_t = s[output].split(x_co0, factor=plan.oc_nthread)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    # virtual threading along spatial rows
    if plan.h_nthread > 1:
        _, v_t = s[output].split(x_i0, factor=plan.h_nthread)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis
    k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
    # NOTE(review): the last argument line of this reorder call is missing.
    s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci,

    # When the plan splits input channels, the loads move under the outer part.
    if plan.ic_factor:
        k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor)
        s[cdata].compute_at(s[conv2d_stage], k_o)
        s[ckernel].compute_at(s[conv2d_stage], k_o)

    # Use VTA instructions
    s[cdata].pragma(s[cdata].op.axis[0], load_inp)
    s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt)
    s[conv2d_stage].tensorize(x_bi, gemm)
    s[output].pragma(x_co1, store_out)
    return s
Beispiel #41
def test_num_thread():
    """Check that GPU code verification enforces thread-count limits.

    A copy kernel over ``N`` elements is split by ``M`` and bound so that
    ``threadIdx.x`` has extent ``N // M`` and ``threadIdx.y`` has extent
    ``M`` (``N`` threads in total, no shared memory).  The kernel is built
    under several limit settings and the verdict written into ``valid[0]`` by
    the verification pass is compared with the expectation.
    """
    N = 1024
    M = 128

    A = tvm.placeholder((N, ), name='A', dtype='float32')
    B = tvm.compute((N, ), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    o, i = s[B].split(s[B].op.axis[0], M)

    s[B].bind(o, tvm.thread_axis('threadIdx.x'))
    s[B].bind(i, tvm.thread_axis("threadIdx.y"))

    # shared memory usage: 0
    # thread usage: N

    # (limits passed to the verification pass, expected verdict).
    # NOTE(review): most lines of each build_config block were missing in the
    # source; only ``max_threads_per_block=N - 1`` and ``max_thread_y=M - 1``
    # survived. The pass position (2), the call to get_verify_pass and the
    # other limits are reconstructed from the asserts - confirm against
    # upstream.
    cases = [
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N - 1}, False),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N}, True),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N,
          "max_thread_y": M - 1}, False),
        ({"max_shared_memory_per_block": 0,
          "max_threads_per_block": N,
          "max_thread_y": M}, True),
    ]

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            # Skip back-ends that are not available on this machine.
            continue

        valid = [None]
        for limits, expected in cases:
            with tvm.build_config(**{"add_lower_pass": [
                    (2, get_verify_pass(valid, **limits))]}):
                tvm.build(s, [A, B], target)
            assert bool(valid[0]) is expected
Beispiel #42
def get_valid_counts_ir(data, flag, idx, valid_count, out):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also moves valid boxes to the
    top of input data.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    valid_count : Buffer
        1-D buffer for valid number of boxes.

    out : Buffer
        Rearranged data buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]
    size = batch_size * num_anchors * elem_length

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)

    # The right-hand side of this statement was truncated in this copy of the
    # file; restored to the per-target thread limit used by the upstream code.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    # One thread per (batch, anchor) pair: i is the batch, j the anchor.
    with ib.if_scope(tid < batch_size * num_anchors):
        i = idxd(tid, num_anchors)
        j = idxm(tid, num_anchors)
        base_idx = i * num_anchors * elem_length
        # A flagged anchor copies its row to output row idx[tid] - 1.
        with ib.if_scope(flag[tid] > 0):
            with ib.for_range(0, elem_length) as k:
                dst = base_idx + (idx[tid] - 1) * elem_length + k
                with ib.if_scope(dst < size):
                    out[dst] = data[base_idx + j * elem_length + k]
        # The first anchor of each batch publishes the batch's count, read
        # from the last idx entry of that batch.
        with ib.if_scope(j == 0):
            valid_count[i] = idx[tid + num_anchors - 1]
        # Rows at or beyond the count are filled with -1.
        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
            with ib.for_range(0, elem_length) as m:
                with ib.if_scope(tid * elem_length + m < size):
                    out[tid * elem_length + m] = -1.0
    return ib.get()
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    """Build a tunable CUDA schedule for an NCHW conv2d with batch size 1.

    Parameters
    ----------
    N, H, W : int
        Batch size (must be 1), input height and input width.
    CO, CI : int
        Number of output and input channels.
    KH, KW : int
        Kernel height and width.
    stride, padding
        Forwarded unchanged to ``topi.nn.conv2d_nchw``.

    Returns
    -------
    s : Schedule
        The schedule, with tiling and unrolling taken from the active
        autotvm config.
    args : list
        ``[raw_data, kernel, conv]`` -- the unpadded input placeholder, the
        kernel placeholder and the convolution output.

    Notes
    -----
    Several statements were truncated in this copy of the file. The
    ``conv2d_nchw`` arguments, the padding inline and the two pragma values
    are reconstructed (NOTE(review): confirm ``dilation=1`` and
    ``out_dtype='float32'`` against the installed topi version).
    """
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding,
                               dilation=1, out_dtype='float32')
    s = tvm.create_schedule([conv.op])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis

    cfg = autotvm.get_config()
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=3)
    cfg.define_split("tile_ry", ry, num_outputs=3)
    cfg.define_split("tile_rx", rx, num_outputs=3)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####

    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, 'local')

    # create cache stage
    AA = s.cache_read(data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])
    AL = s.cache_read(AA, 'local', [OL])
    WL = s.cache_read(WW, 'local', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
    # Each axis uses the split that was defined on it above. The original
    # applied 'tile_rx' to ry and 'tile_ry' to rx, which only lines up when
    # KH == KW.
    ryo, rym, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxm, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching: spread each shared-memory load over the same
    # thread extents that the output tiling uses.
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # tune unroll
    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                     cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit',
                     cfg['unroll_explicit'].val)

    return s, [raw_data, kernel, conv]
    def _callback(op):
        """Schedule ``op`` when it is tagged ``'conv2d_transpose_nchw'``.

        Reads ``s`` (the schedule) and ``cfg`` (the autotvm config) from the
        enclosing scope. Ops with any other tag are left untouched.

        NOTE(review): this copy of the function had lost several lines (the
        ``target`` lookup, two ``else:`` branches, the body of the dilate
        check, the ``set_scope`` calls and the pragma values). They are
        reconstructed here from the upstream topi schedule; confirm against
        the TVM version in use.
        """
        if op.tag == 'conv2d_transpose_nchw':
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, y, x = s[conv].op.axis
            rc = s[conv].op.reduce_axis[0]
            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
            cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

            target = tvm.target.current_target()
            if target.target_name in ['nvptx', 'rocm']:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])
            ##### space definition end #####

            # A dilated kernel is folded into its consumer.
            if isinstance(kernel.op,
                          tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
                s[kernel].compute_inline()

            if conv.op in s.outputs:
                output = conv
                OL = s.cache_write(conv, 'local')
            else:
                # conv feeds a later op: schedule around the real output and
                # keep conv itself in local scope.
                output = s.outputs[0].output(0)
                s[conv].set_scope('local')
                OL = conv

            # create cache stage
            s[pad_data].set_scope('shared')
            AA = pad_data
            WW = s.cache_read(kernel, 'shared', [OL])

            # tile and bind spatial axes
            n, f, y, x = s[output].op.axis
            kernel_scope, n = s[output].split(n, nparts=1)
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            bf = s[output].fuse(n, bf)
            s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
            s[output].bind(by, tvm.thread_axis("blockIdx.y"))
            s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[output].bind(vf, tvm.thread_axis("vthread"))
            s[output].bind(vy, tvm.thread_axis("vthread"))
            s[output].bind(vx, tvm.thread_axis("vthread"))
            s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
            s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
            s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
            s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
            s[OL].compute_at(s[output], tx)

            # tile reduction axes
            n, f, y, x = s[OL].op.axis
            rc, ry, rx = s[OL].op.reduce_axis
            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
            s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)

            s[AA].compute_at(s[OL], rcm)
            s[WW].compute_at(s[OL], rcm)

            # cooperative fetching
            for load in [AA, WW]:
                n, f, y, x = s[load].op.axis
                fused = s[load].fuse(n, f, y, x)
                tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
                ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
                tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
                s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
                s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
                s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

            s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                             cfg['auto_unroll_max_step'].val)
            s[output].pragma(kernel_scope, 'unroll_explicit',
                             cfg['unroll_explicit'].val)
import tvm
import numpy as np
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib import nvcc

# Module-level thread axes ("threadIdx.x" and "blockIdx.x") created once at
# import time.
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")

def test_cuda_vectorize_add():
    """Set up the CUDA vectorized-add check.

    NOTE(review): only the capability checks of ``check_cuda`` are present
    in this copy of the file. The kernel construction and the calls to
    ``check_cuda`` appear to be missing, so this function currently just
    defines the helper and returns.
    """
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        # Return as soon as a skip reason is found. Without the early
        # returns, a machine with no GPU would print the first message and
        # then go on to query the compute version of a device that is absent.
        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("skip because gpu does not support fp16")
            return
        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
            print("skip because gpu does not support int8")
            return
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before,
                axis_mul_after, axis):
    """Low level IR routing subfunction 4/4 for writing sorted indices to output format.

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    index : Buffer
        Buffer of number of valid output boxes.

    new_index : Buffer
        Buffer of sorted indices in a flatten format.

    loc : Buffer
        Buffer of start locations of each sorting segment.

    output : Buffer
        Output buffer of output box indexes sorted by score.

    axis_mul_before : int
        The multiplication result of axis dimensions before axis.

    axis_mul_after : int
        The multiplication result of axis dimensions after axis.

    axis : int
        The axis used for sorting.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    # Restored: the right-hand side was truncated in this copy of the file.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    dshape = tvm.max(loc.shape[0], data.shape[axis])
    p_index = ib.buffer_ptr(index)
    index_new = ib.buffer_ptr(new_index)
    sizes = ib.buffer_ptr(loc)
    p_out = ib.buffer_ptr(output)
    nthread_tx = max_threads
    nthread_bx = dshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            i = tid / axis_mul_after
            j = tid % axis_mul_after
            base_idx = i * data.shape[axis] * axis_mul_after + j
            # Start of this segment in the flattened index buffer. The
            # original assigned `start` inside ib.if_scope / ib.else_scope,
            # but those are plain Python rebindings executed at build time,
            # so the else value (sizes[tid - 1]) was always used, including
            # for tid == 0. An IR-level select expresses the intended choice.
            start = tvm.if_then_else(tid == 0, 0, sizes[tid - 1])
            with ib.for_range(0, data.shape[axis], name="k") as k:
                # Positions past the valid count keep their own index k.
                # NOTE(review): the selecting call was truncated in this
                # copy; tvm.if_then_else(cond, a, b) is the reconstruction.
                p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
                    k < p_index[tid], index_new[k + start], k)
    with ib.else_scope():
        with ib.if_scope(tid < data.shape[axis]):
            p_out[tid] = tvm.if_then_else(
                tid < p_index[0], index_new[tid], tid)

    body = ib.get()
    return body
def sort_ir(data, index, output, axis, is_descend):
    """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Parameters
    ----------
    data: Buffer
        2D Buffer of input boxes' score with shape [batch_size, num_anchors].

    index : Buffer
        Buffer of number of valid number of boxes.

    output : Buffer
        Output buffer of indicies of sorted tensor.

    axis : int
        The axis used for sorting.

    is_descend : bool
        If the sorted data is in descending order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.

    Notes
    -----
    NOTE(review): four statements were truncated in this copy of the file
    (``max_threads``, the ``dshape`` selection, the ``sizes_temp``
    allocation and the final output select). They are reconstructed below;
    confirm against the original source.
    """
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data)
    p_index = ib.buffer_ptr(index)
    p_out = ib.buffer_ptr(output)
    ndim = len(data.shape)
    assert data.dtype == "float32", "Currently only supports input dtype to be float32"
    assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim

    # Product of the dimensions before / after the sort axis; their product
    # is the number of independent segments to sort.
    axis_mul_before = 1
    axis_mul_after = 1
    if axis < 0:
        axis = ndim + axis
    for i in range(0, ndim):
        if i < axis:
            axis_mul_before *= data.shape[i]
        elif i > axis:
            axis_mul_after *= data.shape[i]

    # Scratch size: the larger of sum(index.shape) and the segment count.
    dshape = 0
    for i in range(0, len(index.shape)):
        dshape += index.shape[i]
    dshape = tvm.max(dshape, axis_mul_before * axis_mul_after)

    sizes_temp = ib.allocate("int32", dshape, name="sizes_temp",
                             scope="global")
    sizes = ib.allocate("int32", dshape, name="sizes", scope="global")
    temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local")
    temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
    data_new = ib.allocate("float32", dshape, name="data_new", scope="global")
    index_new = ib.allocate("int32", dshape, name="index_new", scope="global")
    nthread_tx = max_threads
    nthread_bx = dshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    # Step 1: seed both size buffers with the per-segment valid counts.
    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        sizes[tid] = p_index[tid]
        sizes_temp[tid] = p_index[tid]

    # Step 2: accumulate counts across segments, alternating which buffer is
    # read and which is written on even/odd iterations, so that sizes[t]
    # ends up as the end offset of segment t.
    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \
             .astype("float32"))) + 1, name="k") as k:
            with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0):
                with ib.if_scope(k % 2 == 0):
                    sizes[tid] += sizes_temp[tid -
                                             (tvm.const(1, "int32") << k)]
                    sizes_temp[tid] = sizes[tid]
                with ib.else_scope():
                    sizes_temp[tid] += sizes[tid -
                                             (tvm.const(1, "int32") << k)]
                    sizes[tid] = sizes_temp[tid]

    # Start offset of this thread's segment. The original assigned `start`
    # inside ib.if_scope / ib.else_scope blocks; those are Python rebindings
    # executed at build time, so sizes[tid - 1] was always used, even for
    # tid == 0. An IR-level select expresses the intended choice. It is a
    # pure expression, evaluated wherever it is used below.
    start = tvm.if_then_else(tid == 0, 0, sizes[tid - 1])

    # Step 3: gather each segment's scores and initial indices into the
    # flattened scratch buffers.
    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        i = tid / axis_mul_after
        j = tid % axis_mul_after
        current_sort_num = p_index[tid]
        base_idx = i * data.shape[axis] * axis_mul_after + j
        with ib.for_range(0, current_sort_num, name="k") as k:
            full_idx = base_idx + k * axis_mul_after
            index_new[start + k] = k
            data_new[start + k] = p_data[full_idx]

    # Step 4: sort each segment in place.
    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        # OddEvenTransposeSort
        with ib.for_range(0, p_index[tid], name="k") as k:
            with ib.for_range(0, p_index[tid] - 1, name="i") as i:
                with ib.if_scope(i % 2 == (k & 1)):
                    # `== False` is deliberate: this builds an IR expression,
                    # so Python's `not` cannot be used here.
                    with ib.if_scope(
                        ((data_new[i + start] < data_new[i + start + 1])
                         ^ is_descend) == False):
                        temp_data[tid] = data_new[i + start]
                        data_new[i + start] = data_new[i + start + 1]
                        data_new[i + start + 1] = temp_data[tid]
                        temp_index[tid] = index_new[i + start]
                        index_new[i + start] = index_new[i + start + 1]
                        index_new[i + start + 1] = temp_index[tid]

    # Step 5: scatter sorted indices back to the output layout. Positions
    # past the valid count keep their own index k.
    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        i = tid / axis_mul_after
        j = tid % axis_mul_after
        current_sort_num = p_index[tid]
        base_idx = i * data.shape[axis] * axis_mul_after + j
        with ib.for_range(0, data.shape[axis], name="k") as k:
            p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
                k < current_sort_num, index_new[k + start], k)
    body = ib.get()
    return body
 def _schedule(PaddedInput, Filter, DepthwiseConv2d):
     """Schedule one depthwise conv2d stage on the GPU.

     Reads ``s`` (the schedule) and ``outs`` from the enclosing scope.

     NOTE(review): this copy had lost both ``else:`` lines, which left
     ``Output`` unconditionally overwritten (or undefined when the
     convolution is not a schedule output) and ran both ``compute_at``
     calls. The ``else:`` branches are restored. The ``compute_inline`` and
     ``set_scope`` calls are reconstructed from the upstream topi schedule;
     confirm them against the TVM version in use.
     """
     in_shape = get_const_tuple(PaddedInput.shape)
     out_shape = get_const_tuple(DepthwiseConv2d.shape)
     in_height = in_shape[2]
     in_width = in_shape[3]
     out_height = out_shape[2]
     out_width = out_shape[3]
     channel_multiplier = get_const_tuple(Filter.shape)[1]
     s[PaddedInput].compute_inline()
     IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
     FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
     IL = s.cache_read(IS, "local", [DepthwiseConv2d])
     FL = s.cache_read(FS, "local", [DepthwiseConv2d])
     if DepthwiseConv2d.op in s.outputs:
         Output = DepthwiseConv2d
         CL = s.cache_write(DepthwiseConv2d, "local")
     else:
         # The convolution feeds a later op: schedule around the real output
         # and keep the convolution result in local scope.
         Output = outs[0].op.output(0)
         s[DepthwiseConv2d].set_scope("local")
     # schedule parameters
     num_thread_y = 8
     num_thread_x = 8
     num_vthread_y = 1
     num_vthread_x = 1
     blocking_h = out_height
     blocking_w = out_width
     if out_height % 32 == 0 or in_height >= 108:
         blocking_h = 32
     if out_width % 32 == 0:
         blocking_w = 32
         num_thread_x = 16
         num_vthread_x = 2
     elif in_width >= 108:
         blocking_w = 32
     block_y = tvm.thread_axis("blockIdx.y")
     block_x = tvm.thread_axis("blockIdx.x")
     thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
     thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
     thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
     thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
     # split and bind
     by, byi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
     s[Output].reorder(Output.op.axis[2], Output.op.axis[3], byi)
     by = s[Output].fuse(Output.op.axis[0], by)
     s[Output].bind(by, block_y)
     bx1, x1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
     tvy, vyi = s[Output].split(x1i, nparts=num_vthread_y)
     ty, yi = s[Output].split(vyi, nparts=num_thread_y)
     bx2, x2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
     tvx, vxi = s[Output].split(x2i, nparts=num_vthread_x)
     tx, xi = s[Output].split(vxi, nparts=num_thread_x)
     s[Output].reorder(bx1, bx2, tvy, tvx, ty, tx, yi, xi)
     bx = s[Output].fuse(bx1, bx2)
     s[Output].bind(bx, block_x)
     s[Output].bind(tvy, thread_vy)
     s[Output].bind(tvx, thread_vx)
     s[Output].bind(ty, thread_y)
     s[Output].bind(tx, thread_x)
     # local memory load
     s[IL].compute_at(s[Output], tx)
     s[FL].compute_at(s[Output], tx)
     if DepthwiseConv2d.op in s.outputs:
         s[CL].compute_at(s[Output], tx)
     else:
         s[DepthwiseConv2d].compute_at(s[Output], tx)
     # input's shared memory load
     s[IS].compute_at(s[Output], bx)
     ty, yi = s[IS].split(IS.op.axis[2], nparts=num_thread_y)
     tx, xi = s[IS].split(IS.op.axis[3], nparts=num_thread_x)
     s[IS].bind(ty, thread_y)
     s[IS].bind(tx, thread_x)
     # filter's shared memory load
     s[FS].compute_at(s[Output], bx)
     s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
     ty, yi = s[FS].split(FS.op.axis[2], nparts=num_thread_y)
     tx, xi = s[FS].split(FS.op.axis[3], nparts=num_thread_x)
     s[FS].bind(ty, thread_y)
     s[FS].bind(tx, thread_x)
def sort_ir(data, index, output):
    """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Parameters
    ----------
    data: Buffer
        2D Buffer of input boxes' score with shape [batch_size, num_anchors].

    index : Buffer
        1D Buffer of number of valid number of boxes.

    output : Buffer
        2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors].

    Returns
    -------
    stmt : Stmt
        The result IR statement.

    Notes
    -----
    The scores in ``data`` are swapped in place while sorting.
    """

    assert data.dtype == "float32", "Currently only supports input dtype to be float32"
    batch, num_anchors = get_const_tuple(data.shape)
    # Restored: the right-hand side was truncated in this copy of the file.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data)
    p_index = ib.buffer_ptr(index)
    p_out = ib.buffer_ptr(output)
    nthread_tx = max_threads
    # Each thread handles one adjacent pair, so (num_anchors + 1) // 2
    # workers are needed; the overflow beyond one block is covered by
    # virtual threads.
    nthread_bx = (num_anchors + 1) // 2 // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    temp_data = ib.allocate("float32", (1, ), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1, ), name="temp_index", scope="local")

    with ib.for_range(0, batch, for_type="unroll") as b:
        start = b * num_anchors
        # Initialise the output row with the identity permutation; each
        # thread writes its two slots.
        for i in range(2):
            bbox_id = tid * 2 + i
            with ib.if_scope(bbox_id < num_anchors):
                p_out[start + bbox_id] = bbox_id
        # OddEvenTransposeSort
        with ib.for_range(0, p_index[b]) as k:
            with ib.if_scope(tid < (p_index[b] + 1) // 2):
                # Even rounds compare pairs (0,1),(2,3)...; odd rounds
                # compare (1,2),(3,4)...
                offset = start + 2 * tid + (k % 2)
                # NOTE(review): the bound uses p_index[0] while `offset`
                # includes the per-batch `start`; confirm this is intended
                # for batch > 1.
                with ib.if_scope( \
                        tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])):
                    temp_data[0] = p_data[offset]
                    p_data[offset] = p_data[offset + 1]
                    p_data[offset + 1] = temp_data[0]
                    temp_index[0] = p_out[offset]
                    p_out[offset] = p_out[offset + 1]
                    p_out[offset + 1] = temp_index[0]
                # Emit a storage sync after each round. The `ib.emit(`
                # wrapper was missing in this copy, which left the call
                # unbalanced and never added to the IR.
                ib.emit(
                    tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()
def test_tensor_core_batch_conv():
    """Build, time and optionally verify a tensor-core batched conv2d.

    The convolution runs on blocked fp16 inputs with fp32 accumulation, is
    tensorized with the wmma load/store/gemm intrinsics and is compared
    against ``conv2d_nhwc_python`` when ``VERIFY`` is set.

    NOTE(review): several statements were truncated in this copy of the file
    (the ``name=`` / shape arguments of the two ``tvm.compute`` calls, the
    ``tensorize`` calls for AF/WF, the ``tvm.build`` call, the ``reshape``
    targets and the kernel argument of the reference conv). They are
    reconstructed below; confirm against the original test.
    """
    # Return on skip; the original printed the message and kept going.
    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
        print("skip because cuda is not enabled..")
        return
    if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
        print("skip because gpu does not support tensor core")
        return

    # The sizes of inputs and filters
    batch_size = 32
    height = 14
    width = 14
    in_channels = 32
    out_channels = 64
    kernel_h = 3
    kernel_w = 3
    pad_h = 1
    pad_w = 1
    stride_h = 1
    stride_w = 1
    block_size = 16

    block_row_warps = 2
    block_col_warps = 4
    warp_row_tiles = 4
    warp_col_tiles = 2
    warp_size = 32
    chunk = 2

    # Input feature map: (N, H, W, IC, n, ic)
    data_shape = (batch_size // block_size, height, width,
                  in_channels // block_size, block_size, block_size)
    # Kernel: (H, W, IC, OC, ic, oc)
    kernel_shape = (kernel_h, kernel_w, in_channels // block_size,
                    out_channels // block_size, block_size, block_size)

    # Output feature map: (N, H, W, OC, n, oc)
    output_shape = (batch_size // block_size, height, width,
                    out_channels // block_size, block_size, block_size)

    assert (batch_size % block_size == 0)
    assert (in_channels % block_size == 0)
    assert (out_channels % block_size == 0)

    kh = tvm.reduce_axis((0, kernel_h), name='kh')
    kw = tvm.reduce_axis((0, kernel_w), name='kw')
    ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
    ii = tvm.reduce_axis((0, block_size), name='ii')

    # Algorithm
    A = tvm.placeholder(data_shape, name='A', dtype="float16")
    W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
    # Zero-padded copy of A along the two spatial dimensions.
    Apad = tvm.compute(
        (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w,
         in_channels // block_size, block_size, block_size),
        lambda n, h, w, i, nn, ii: tvm.if_then_else(
            tvm.all(h >= pad_h, h - pad_h < height, w >= pad_w, w - pad_w <
                    width), A[n, h - pad_h, w - pad_w, i, nn, ii],
            tvm.const(0., "float16")),
        name='Apad')
    Conv = tvm.compute(
        output_shape,
        lambda n, h, w, o, nn, oo: tvm.sum(Apad[
            n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype(
                "float32") * W[kh, kw, ic, o, ii, oo].astype("float32"),
                                           axis=[ic, kh, kw, ii]),
        name="Conv")

    s = tvm.create_schedule(Conv.op)

    AS = s.cache_read(Apad, 'shared', [Conv])
    WS = s.cache_read(W, 'shared', [Conv])
    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
    ConvF = s.cache_write(Conv, 'wmma.accumulator')

    block_x = tvm.thread_axis('blockIdx.x')
    block_y = tvm.thread_axis('blockIdx.y')
    block_z = tvm.thread_axis('blockIdx.z')
    thread_x = tvm.thread_axis('threadIdx.x')
    thread_y = tvm.thread_axis('threadIdx.y')
    thread_z = tvm.thread_axis('threadIdx.z')

    # Output tiling: spatial positions on blockIdx.z, batch/channel tiles on
    # blockIdx.x/y, the per-block tiles on threadIdx.y/z.
    nc, hc, wc, oc, nnc, ooc = Conv.op.axis
    block_k = s[Conv].fuse(hc, wc)
    s[Conv].bind(block_k, block_z)
    nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
    block_i, nc = s[Conv].split(nc, factor=block_row_warps)
    oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
    block_j, oc = s[Conv].split(oc, factor=block_col_warps)
    s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
    s[Conv].bind(block_i, block_x)
    s[Conv].bind(block_j, block_y)
    s[Conv].bind(nc, thread_y)
    s[Conv].bind(oc, thread_z)

    s[ConvF].compute_at(s[Conv], oc)
    n, h, w, o, nnf, oof = ConvF.op.axis
    ko, ki = s[ConvF].split(ic, factor=chunk)
    s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

    s[AF].compute_at(s[ConvF], kw)
    s[WF].compute_at(s[ConvF], kw)

    s[WS].compute_at(s[ConvF], kh)
    s[AS].compute_at(s[ConvF], kh)

    # Shared-memory load of the padded input.
    n, h, w, i, nn, ii = AS.op.axis
    tx, xo = s[AS].split(n, nparts=block_row_warps)
    ty, yo = s[AS].split(xo, nparts=block_col_warps)
    t = s[AS].fuse(nn, ii)
    to, ti = s[AS].split(t, factor=warp_size)
    s[AS].bind(tx, thread_y)
    s[AS].bind(ty, thread_z)
    s[AS].bind(ti, thread_x)

    # Shared-memory load of the kernel.
    kh, kw, ic, o, ii, oo = WS.op.axis
    tx, xo = s[WS].split(o, nparts=block_row_warps)
    ty, yo = s[WS].split(xo, nparts=block_col_warps)
    t = s[WS].fuse(ii, oo)
    to, ti = s[WS].split(t, nparts=warp_size)
    s[WS].bind(tx, thread_y)
    s[WS].bind(ty, thread_z)
    s[WS].bind(to, thread_x)

    # Map the innermost 16x16 tiles onto the wmma intrinsics.
    s[AF].tensorize(AF.op.axis[-2],
                    intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a'))
    s[WF].tensorize(WF.op.axis[-2],
                    intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b'))
    s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16)))
    s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16)))

    func = tvm.build(s, [A, W, Conv], 'cuda')

    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))

    if VERIFY:
        func(a, w, c)
        # Undo the 16x16 blocking so the data matches the plain NHWC / HWIO
        # layouts expected by the reference implementation.
        a_np = a_np.transpose(0, 4, 1, 2, 3,
                              5).reshape(batch_size, height, width,
                                         in_channels)
        w_np = w_np.transpose(0, 1, 2, 4, 3,
                              5).reshape(kernel_h, kernel_w, in_channels,
                                         out_channels)
        c_np = c.asnumpy().transpose(
            (0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width,
                                        out_channels)
        c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype),
                                   w_np.astype(Conv.dtype),
                                   (stride_h, stride_w),
                                   (pad_h, pad_w)).astype(Conv.dtype)
        np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4)
# Builds a GPU-scheduled 2-D matrix multiply Z = op(X) * op(Y), where op is an
# optional transpose, and returns the result of _export_module(...).
#
# NOTE(review): this copy of the function is damaged. The parameter list is
# cut off after `shapeA` (the body also uses shapeB, transposeA, transposeB,
# dtype, func_name and remote), the closing argument lines of each
# tvm.compute(...) call are missing, and the build statement near the end is
# truncated. Restore it from the original source before use.
def make_matrix_mul_2(shapeA,
    # assert shapeA[0] == shapeA[1]
    # assert shapeB[0] == shapeB[1]

    X = tvm.placeholder(shapeA, dtype=dtype, name='X')
    Y = tvm.placeholder(shapeB, dtype=dtype, name='Y')

    # One branch per transpose combination; only the reduction extent, the
    # output shape and the index order differ.
    if not transposeA and not transposeB:
        k = tvm.reduce_axis((0, shapeA[1]), name='k')
        Z = tvm.compute((shapeA[0], shapeB[1]),
                        lambda i, j: tvm.sum(X[i, k] * Y[k, j], axis=k),
    elif not transposeA and transposeB:
        k = tvm.reduce_axis((0, shapeA[1]), name='k')
        Z = tvm.compute((shapeA[0], shapeB[0]),
                        lambda i, j: tvm.sum(X[i, k] * Y[j, k], axis=k),
    elif transposeA and not transposeB:
        k = tvm.reduce_axis((0, shapeA[0]), name='k')
        Z = tvm.compute((shapeA[1], shapeB[1]),
                        lambda i, j: tvm.sum(X[k, i] * Y[k, j], axis=k),
    elif transposeA and transposeB:
        k = tvm.reduce_axis((0, shapeA[0]), name='k')
        Z = tvm.compute((shapeA[1], shapeB[0]),
                        lambda i, j: tvm.sum(X[k, i] * Y[j, k], axis=k),

    s = tvm.create_schedule(Z.op)
    # X_shared = s.cache_read(X, 'shared', [Z])  # should store (step * block_factor) items
    # Y_shared = s.cache_read(Y, 'shared', [Z])
    # X_local = s.cache_read(X_shared, 'local', [Z])
    # Y_local = s.cache_read(Y_shared, 'local', [Z])
    # Z_local = s.cache_write(Z, 'local')

    # tile consts
    tile = 2
    num_thread = 8
    block_factor = tile * num_thread
    # `step` is only referenced by the commented-out code below.
    step = 8
    vthread = 2

    block_x = tvm.thread_axis("blockIdx.x", name='bx')
    block_y = tvm.thread_axis("blockIdx.y", name='by')
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x", name='tx')
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y", name='ty')
    vthread_x = tvm.thread_axis((0, vthread), "vthread", name="vtx")
    vthread_y = tvm.thread_axis((0, vthread), "vthread", name="vty")

    # Split the workloads into blocks of threads
    hi, wi = s[Z].op.axis
    bx, wi = s[Z].split(wi,
                        factor=block_factor)  # wi ranges up to block_factor
    by, hi = s[Z].split(hi,
                        factor=block_factor)  # hi ranges up to block_factor

    s[Z].bind(bx, block_x)  # bx number of blocks.
    s[Z].bind(by, block_y)

    # Split into virtual threads (vthread x vthread) grid
    vtx, wi = s[Z].split(
        wi, nparts=vthread
    )  # vtx ranges up to vthread. wi ranges up to (block_factor/vthread)
    vty, hi = s[Z].split(
        hi, nparts=vthread
    )  # vty ranges up to vthread. hi ranges up to (block_factor/vthread)

    # Split each vthread block into threads (num_thread x num_thread) grid
    tx, wi = s[Z].split(
        wi, nparts=num_thread
    )  # tx ranges up to num_thread. wi ranges up to (block_factor/vthread/num_thread)
    ty, hi = s[Z].split(
        hi, nparts=num_thread
    )  # ty ranges up to num_thread. hi ranges up to (block_factor/vthread/num_thread)

    # Reorder from block to vthread to thread. Decreasing order of size of submatrix to be controlled
    s[Z].reorder(by, bx, vty, vtx, ty, tx)

    s[Z].bind(vty, vthread_y)
    s[Z].bind(vtx, vthread_x)
    s[Z].bind(tx, thread_x)
    s[Z].bind(ty, thread_y)

    # The block below is a disabled shared/local-memory caching schedule that
    # pairs with the commented-out cache_read/cache_write calls above.
    # # Schedule Z_local local write
    # s[Z_local].compute_at(s[Z], tx) # In the computation of Z, when looping over tx (inner most and smallest granule), compute Z_local, which is a write to local memory
    # hi, wi = s[Z_local].op.axis
    # k, = s[Z_local].op.reduce_axis
    # ko, ki = s[Z_local].split(k, factor=step)
    # s[Z_local].reorder(ko, ki, hi, wi) # May be unnecessary
    # # Attach computation to iteration variables
    # s[X_shared].compute_at(s[Z_local], wi)
    # # s[Y_shared].compute_at(s[Z_local], hi)
    # s[X_local].compute_at(s[Z_local], ki)
    # s[Y_local].compute_at(s[Z_local], ki)
    # # Schedule for X's shared memory load
    # hi, wi =  s[X_shared].op.axis
    # ty, hi = s[X_shared].split(hi, nparts=num_thread)
    # tx, wi = s[X_shared].split(wi, nparts=num_thread)
    # _, wi = s[X_shared].split(wi, factor=4) # Is this 4 because of vthread = 2, vthread*vthread?
    # s[X_shared].reorder(ty, tx, hi, wi)
    # # tvm.lower(s, [X, Y, Z], simple_mode=True)
    # s[X_shared].bind(tx, thread_x)
    # s[X_shared].bind(ty, thread_y)
    # s[X_shared].vectorize(wi)
    # # Schedule for Y's shared memory load
    # hi, wi = s[Y_shared].op.axis
    # ty, hi = s[Y_shared].split(hi, nparts=num_thread)
    # tx, wi = s[Y_shared].split(wi, nparts=num_thread)
    # _, wi = s[Y_shared].split(wi, factor=4)  # Is this 4 because of vthread = 2, vthread*vthread?
    # s[Y_shared].reorder(ty, tx, hi, wi)
    # s[Y_shared].bind(tx, thread_x)
    # s[Y_shared].bind(ty, thread_y)
    # s[Y_shared].vectorize(wi)

    # s[Z].vectorize(tx)

    # NOTE(review): the callee and trailing arguments of this statement are
    # missing (presumably a build of the schedule over [X, Y, Z] -- confirm).
    f =, [X, Y, Z],
    return _export_module(f, func_name, remote)
def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
    """Schedule conv2d for specific feature_in_out_filter pattern.

    Exactly one of two hand-written schedules is applied, selected by
    ``Filter.shape[0] + Filter.shape[1] <= 768``:

    * small-filter variant: 64 threads on ``threadIdx.x``, a virtual thread
      over ``Out.shape[3]``, and ``h`` fused with the outer output-channel
      axis onto ``blockIdx.x``;
    * large-filter variant: 16 threads on ``threadIdx.x``, ``Out.shape[3]``
      threads on ``threadIdx.y``, a virtual thread over ``Out.shape[2]``, and
      separate flat schedules for ``temp`` and ``temp_R``.

    Parameters
    ----------
    s : Schedule
        Schedule that is mutated in place.
    temp, temp_R : Tensor
        Stages that are only scheduled by the large-filter variant, where all
        of their axes are fused and split across blockIdx.x / threadIdx.x.
    temp_S, Filter_S : Tensor
        Stages attached under the outer reduction axis of ``Out_L``
        (the inline comments call them shared-memory loads).
    Filter : Tensor
        Only its shape is read, to pick the variant and ``ofactor``.
    Out : Tensor
        Output stage whose axes are bound to block/thread indices.
    Out_L : Tensor
        Stage computed at the innermost bound axis of ``Out``
        (the inline comments call it the local write).

    Returns
    -------
    None
        The function has no return statement; ``s`` carries the result.
    """
    if util.get_const_int(Filter.shape[0]) + util.get_const_int(Filter.shape[1]) <= 768:
        # scheduler params
        vthread_x = util.get_const_int(Out.shape[3])
        num_thread_x = 64
        ofactor = 8
        if util.get_const_int(Filter.shape[3]) == 1:
            ofactor = 64
        block_x = tvm.thread_axis("blockIdx.x")
        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
        thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")

        i, oc, h, w = s[Out].op.axis
        ooc, ioc = s[Out].split(oc, factor=num_thread_x)
        s[Out].reorder(i, ooc, h, w, ioc)
        ooc = s[Out].fuse(h, ooc)
        s[Out].bind(ioc, thread_x)
        s[Out].bind(w, thread_xz)
        s[Out].bind(ooc, block_x)

        s[Out_L].compute_at(s[Out], ioc)

        # schedule Out_L local write
        i, oc, h, w = s[Out_L].op.axis
        ic, dh, dw = s[Out_L].op.reduce_axis
        oic, iic = s[Out_L].split(ic, factor=ofactor)
        s[Out_L].reorder(oic, dh, dw, iic, h, w)

        s[temp_S].compute_at(s[Out_L], oic)
        s[Filter_S].compute_at(s[Out_L], oic)

        # schedule temp_S shared mem load
        i, ic, h, w = s[temp_S].op.axis
        s[temp_S].reorder(i, ic, w, h)
        ic = s[temp_S].fuse(w, ic)
        _, iic = s[temp_S].split(ic, factor=num_thread_x)
        s[temp_S].bind(iic, thread_x)

        # schedule Filter_S shared mem load
        i, oc, h, w = s[Filter_S].op.axis
        _, ii = s[Filter_S].split(i, factor=num_thread_x)
        s[Filter_S].bind(ii, thread_x)
        s[Filter_S].storage_align(s[Filter_S].op.axis[0], 2, 1)
    else:
        # The two variants bind the same axes of Out, so they must be
        # mutually exclusive; running both would rebind already-bound axes.
        # scheduler params
        vthread_x = util.get_const_int(Out.shape[2])
        num_thread_x = 16
        num_thread_y = util.get_const_int(Out.shape[3])
        ofactor = 8
        block_x = tvm.thread_axis("blockIdx.x")
        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
        thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")

        i, oc, h, w = s[Out].op.axis
        ooc, ioc = s[Out].split(oc, factor=num_thread_x)
        s[Out].reorder(i, ooc, h, w, ioc)
        s[Out].bind(ioc, thread_x)
        s[Out].bind(w, thread_y)
        s[Out].bind(h, thread_xz)
        s[Out].bind(ooc, block_x)

        s[Out_L].compute_at(s[Out], ioc)

        # schedule Out_L local write
        i, oc, h, w = s[Out_L].op.axis
        ic, dh, dw = s[Out_L].op.reduce_axis
        oic, iic = s[Out_L].split(ic, factor=ofactor)
        s[Out_L].reorder(oic, dh, dw, iic, h, w)

        s[temp_S].compute_at(s[Out_L], oic)
        s[Filter_S].compute_at(s[Out_L], oic)

        # temp and temp_R are scheduled as flat 1-D kernels, one thread per
        # element, using the largest block size the current target allows.
        # NOTE(review): right-hand side reconstructed; confirm against the
        # upstream TVM schedule this was copied from.
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
        block_xx = tvm.thread_axis("blockIdx.x")

        i = s[temp].fuse(*s[temp].op.axis)
        bx, tx = s[temp].split(i, factor=num_thread)
        s[temp].bind(tx, thread_xx)
        s[temp].bind(bx, block_xx)

        i = s[temp_R].fuse(*s[temp_R].op.axis)
        bx, tx = s[temp_R].split(i, factor=num_thread)
        s[temp_R].bind(tx, thread_xx)
        s[temp_R].bind(bx, block_xx)

        # schedule temp_S shared mem load
        i, h, w, oc, ic = s[temp_S].op.axis
        icc = s[temp_S].fuse(oc, w, h)
        oic, iic = s[temp_S].split(icc, factor=num_thread_x)
        _, ioic = s[temp_S].split(oic, factor=num_thread_y)
        s[temp_S].bind(iic, thread_x)
        s[temp_S].bind(ioic, thread_y)

        # schedule Filter_S shared mem load
        i, oc, h, w = s[Filter_S].op.axis
        _, ii = s[Filter_S].split(i, factor=num_thread_x)
        h = s[Filter_S].fuse(h, w)
        _, ih = s[Filter_S].split(h, factor=num_thread_y)
        s[Filter_S].bind(ii, thread_x)
        s[Filter_S].bind(ih, thread_y)
        s[Filter_S].storage_align(s[Filter_S].op.axis[0], 2, 1)
def make_conv2d(shapeX, shapeF, tgt, tgt_host, func_name, dtype="float32"):
    """Build and export a tiled GPU conv2d kernel.

    The input is zero-padded, convolved with the filter, and scheduled with
    shared/local cache stages, block/thread binding and vectorized
    shared-memory loads.

    Parameters
    ----------
    shapeX : tuple
        Unpacked as ``(in_size, in_size, in_channel, batch)``.
    shapeF : tuple
        Unpacked as ``(kernel, kernel, in_channel, out_channel)``.
    tgt, tgt_host
        Forwarded to ``tvm.build`` as target and ``target_host``.
    func_name : str
        Name given to the built function and passed to ``_export_module``.
    dtype : str
        Element type of both placeholders and of the padding constant.

    Returns
    -------
    Whatever ``_export_module(f, func_name, remote)`` returns.
    """
    # The original unpacking repeated a name on the left-hand side, so the
    # later position wins; `_` keeps that behaviour without the shadowing.
    _, in_size, in_channel, batch = shapeX
    _, kernel, in_channel, out_channel = shapeF

    print(tgt, tgt_host)

    # NOTE(review): the assignment hints that accompanied this function asked
    # for the stride=1, padding=0 case, but the code pads by 1 -- confirm
    # which output size callers expect.
    pad = 1
    stride = 1

    # Algorithm
    A = tvm.placeholder((in_size, in_size, in_channel, batch),
                        dtype=dtype, name='A')
    W = tvm.placeholder((kernel, kernel, in_channel, out_channel),
                        dtype=dtype, name='W')
    out_size = (in_size - kernel + 2 * pad) // stride + 1
    # Pad input: read A inside the original extent, zero elsewhere.
    # if_then_else is used so the out-of-range A[...] read is never evaluated.
    Apad = tvm.compute(
        (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch),
        lambda yy, xx, cc, nn: tvm.if_then_else(
            tvm.all(yy >= pad, yy - pad < in_size,
                    xx >= pad, xx - pad < in_size),
            A[yy - pad, xx - pad, cc, nn],
            tvm.const(0., dtype)),
        name='Apad')
    # Create reduction variables
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel), name='ry')
    rx = tvm.reduce_axis((0, kernel), name='rx')
    # Compute the convolution
    B = tvm.compute(
        (out_size, out_size, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.sum(
            Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff],
            axis=[ry, rx, rc]),
        name='B')

    # Designate the memory hierarchy
    s = tvm.create_schedule(B.op)
    s[Apad].compute_inline()  # compute Apad inline
    AA = s.cache_read(Apad, 'shared', [B])
    WW = s.cache_read(W, "shared", [B])
    AL = s.cache_read(AA, "local", [B])
    WL = s.cache_read(WW, "local", [B])
    BL = s.cache_write(B, "local")

    # tile consts
    tile = 2
    num_thread = 16
    block_factor = tile * num_thread
    step = 4
    vthread = 1

    # Get the GPU thread indices
    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    block_z = tvm.thread_axis("blockIdx.z")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

    # Split the workloads
    hi, wi, fi, ni = s[B].op.axis
    bz = s[B].fuse(hi, wi)
    by, fi = s[B].split(fi, factor=block_factor)
    bx, ni = s[B].split(ni, factor=block_factor)

    # Bind the iteration variables to GPU thread indices
    s[B].bind(bz, block_z)
    s[B].bind(by, block_y)
    s[B].bind(bx, block_x)

    tyz, fi = s[B].split(fi, nparts=vthread)  # virtual thread split
    txz, ni = s[B].split(ni, nparts=vthread)  # virtual thread split
    ty, fi = s[B].split(fi, nparts=num_thread)
    tx, ni = s[B].split(ni, nparts=num_thread)
    s[B].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)

    s[B].bind(tyz, thread_yz)
    s[B].bind(txz, thread_xz)
    s[B].bind(ty, thread_y)
    s[B].bind(tx, thread_x)

    # Schedule BL local write
    s[BL].compute_at(s[B], tx)
    yi, xi, fi, ni = s[BL].op.axis
    ry, rx, rc = s[BL].op.reduce_axis
    rco, rci = s[BL].split(rc, factor=step)
    s[BL].reorder(rco, ry, rx, rci, fi, ni)

    # Attach computation to iteration variables
    s[AA].compute_at(s[BL], rx)
    s[WW].compute_at(s[BL], rx)
    s[AL].compute_at(s[BL], rci)
    s[WL].compute_at(s[BL], rci)

    # Schedule for A's shared memory load
    yi, xi, ci, ni = s[AA].op.axis
    ty, ci = s[AA].split(ci, nparts=num_thread)
    tx, ni = s[AA].split(ni, nparts=num_thread)
    _, ni = s[AA].split(ni, factor=4)
    s[AA].reorder(ty, tx, yi, xi, ci, ni)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(ni)  # vectorize memory load

    # Schedule for W's shared memory load
    yi, xi, ci, fi = s[WW].op.axis
    ty, ci = s[WW].split(ci, nparts=num_thread)
    tx, fi = s[WW].split(fi, nparts=num_thread)
    _, fi = s[WW].split(fi, factor=4)
    s[WW].reorder(ty, tx, yi, xi, ci, fi)
    s[WW].bind(ty, thread_y)
    s[WW].bind(tx, thread_x)
    s[WW].vectorize(fi)  # vectorize memory load

    f = tvm.build(s, [A, W, B], tgt, target_host=tgt_host, name=func_name)

    # NOTE(review): `remote` is not a parameter or local; it is presumably a
    # module-level name defined elsewhere in the file -- confirm.
    return _export_module(f, func_name, remote)
def get_valid_counts_pre(data, flag, idx, score_threshold, id_index,
                         score_index):
    """Low level IR to prepare the valid count of bounding boxes
    given a score threshold.

    One GPU thread handles one box: it writes 1 to both ``flag`` and ``idx``
    when the box score exceeds ``score_threshold`` and (``id_index < 0`` or
    the box class id is non-negative), and 0 to both otherwise.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    score_threshold : float32
        Lower limit of score for valid bounding boxes.

    id_index : int
        index of the class categories, -1 to disable.

    score_index: int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    # Wrap the Python scalars as IR immediates so they can be used in the
    # comparisons emitted below.
    score_threshold = tvm.make.node("FloatImm", dtype="float32",
                                    value=score_threshold)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)

    # NOTE(review): argument of int(...) reconstructed; confirm against the
    # upstream TVM source.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    # The grid is rounded up by one block, so guard against surplus threads.
    with ib.if_scope(tid < batch_size * num_anchors):
        is_valid = tvm.all(
            data[tid * box_data_length + score_index] > score_threshold,
            tvm.any(id_index < 0,
                    data[tid * box_data_length + id_index] >= 0))
        with ib.if_scope(is_valid):
            flag[tid] = 1
            idx[tid] = 1
        with ib.else_scope():
            flag[tid] = 0
            idx[tid] = 0

    return ib.get()
def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
                axis_mul_after, axis, is_descend):
    """Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    index : Buffer
        Buffer of number of valid output boxes.

    new_data : Buffer
        Buffer of flattened segmented data.

    new_index : Buffer
        Buffer of flattened segmented indices.

    loc : Buffer
        Buffer of start locations of each sorting segment.

    out_index : Buffer
        Output buffer of output box indexes sorted by score in a flattened segmented format.

    axis_mul_before : int
        The multiplication result of axis dimensions before axis.

    axis_mul_after : int
        The multiplication result of axis dimensions after axis.

    axis : int
        The axis used for sorting.

    is_descend : bool
        If the sorted data is in descending order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    # NOTE(review): argument of int(...) reconstructed; confirm against the
    # upstream TVM source.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    dshape = loc.shape
    fshape = data.shape[axis] * dshape[0]
    temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
    p_data = ib.buffer_ptr(data)
    p_index = ib.buffer_ptr(index)
    data_new = ib.buffer_ptr(new_data)
    index_new = ib.buffer_ptr(new_index)
    index_out = ib.buffer_ptr(out_index)
    sizes = ib.buffer_ptr(loc)
    nthread_tx = max_threads
    nthread_bx = fshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            # The segment start must be selected inside the generated IR.
            # Assigning a Python variable inside ib.if_scope / ib.else_scope
            # runs both assignments at build time, so the last one always
            # won and thread 0 read sizes[-1].  if_then_else keeps the
            # sizes[tid - 1] load from being evaluated when tid == 0.
            start = tvm.if_then_else(tid == 0,
                                     tvm.const(0, sizes.dtype),
                                     sizes[tid - 1])
            # OddEvenTransposeSort
            with ib.for_range(0, p_index[tid], name="k") as k:
                with ib.for_range(0, p_index[tid] - 1, name="i") as i:
                    with ib.if_scope(i % 2 == k % 2):
                        with ib.if_scope(
                            ((data_new[i + start] <
                              data_new[i + start + 1]) == is_descend)):
                            # Swap the neighbouring values, then their indices;
                            # temp_data / index_out serve as per-thread scratch.
                            temp_data[tid] = data_new[i + start]
                            data_new[i + start] = data_new[i + start + 1]
                            data_new[i + start + 1] = temp_data[tid]
                            index_out[tid] = index_new[i + start]
                            index_new[i + start] = index_new[i + start + 1]
                            index_new[i + start + 1] = index_out[tid]
        # A single thread copies the sorted indices to the output buffer.
        with ib.if_scope(tid < 1):
            with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
                index_out[i] = index_new[i]
    with ib.else_scope():
        # Single segment: one thread per element, ping-ponging between
        # (p_data, index_new) and (data_new, index_out) on alternate passes.
        with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
            with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
                with ib.if_scope(k % 2 == 0):
                    with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
                                             == is_descend)):
                        data_new[tid] = p_data[tid + 1]
                        index_out[tid] = index_new[tid + 1]
                    with ib.else_scope():
                        data_new[tid] = p_data[tid]
                        index_out[tid] = index_new[tid]
                with ib.else_scope():
                    with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
                                             == is_descend)):
                        p_data[tid] = data_new[tid + 1]
                        index_new[tid] = index_out[tid + 1]
                    with ib.else_scope():
                        p_data[tid] = data_new[tid]
                        index_new[tid] = index_out[tid]
            with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
                with ib.if_scope(k % 2 == 0):
                    with ib.if_scope(
                            tvm.all(tid > 0, (p_data[tid - 1] <
                                              p_data[tid]) == is_descend)):
                        data_new[tid] = p_data[tid - 1]
                        index_out[tid] = index_new[tid - 1]
                    with ib.else_scope():
                        data_new[tid] = p_data[tid]
                        index_out[tid] = index_new[tid]
                with ib.else_scope():
                    with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
                                             == is_descend)):
                        p_data[tid] = data_new[tid - 1]
                        index_new[tid] = index_out[tid - 1]
                    with ib.else_scope():
                        p_data[tid] = data_new[tid]
                        index_new[tid] = index_out[tid]
        with ib.if_scope(fshape % 2 == 1):
            with ib.if_scope(tid < 1):
                # NOTE(review): the loop body indexes with tid, not k, so it
                # rewrites the same element every iteration; presumably
                # index_out[k] = index_new[k] was intended -- confirm.
                with ib.for_range(0, fshape, name="k") as k:
                    index_out[tid] = index_new[tid]
    body = ib.get()
    return body
def _schedule_conv2d_NCHWc_int8(cfg, s, output):
    """Schedule the int8 NCHWc conv2d template.

    Defines the tuning knobs on ``cfg`` (spatial and reduction tile splits,
    ``fuse_yx``, ``reorder_inner``, double buffering, unroll step) and applies
    the currently selected values to schedule ``s``.

    Parameters
    ----------
    cfg
        Tuning configuration; knobs are defined on it and read back from it.
    s : Schedule
        Schedule that is mutated in place and returned.
    output : Tensor
        Tensor whose first input is the convolution stage.

    Returns
    -------
    s : Schedule
        The same schedule object that was passed in.

    NOTE(review): several statements in this function (the ``else`` branches,
    the bodies of single-line ``if`` statements, the scope and inline calls)
    were reconstructed; confirm against the upstream TVM int8 conv2d schedule.
    """
    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op,
                  tvm.tensor.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0],
                                "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
                packed_kernel.name == 'packed_kernel':
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # create cache stage
    AA = s.cache_read(pad_data, 'shared', [conv])
    WW = s.cache_read(packed_kernel, 'shared', [conv])

    s[conv].set_scope('local')

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    n, f, y, x, c = s[output].op.axis
    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi,
                      yi, xi)
    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
    s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
    s[output].bind(vn, tvm.thread_axis("vthread"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        # The two bindings use the same thread tags, so only one may apply.
        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile and bind reduction axes
    n, f, y, x, c = s[conv].op.axis

    rc, ry, rx, rc_block = s[conv].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)

    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)

    # The same permutation is applied to the outer and the inner triple.
    cfg.define_reorder("reorder_inner", [rco, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    # Attach the cache stages under whichever outer reduction axis ended up
    # innermost after the reorder.
    cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:
        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # double buffer
    cfg.define_knob('AA_double_buffer', [0, 1])
    cfg.define_knob('WW_double_buffer', [0, 1])
    if cfg['AA_double_buffer'].val:
        s[AA].double_buffer()
    if cfg['WW_double_buffer'].val:
        s[WW].double_buffer()

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                     cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', False)

    return s
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress,
           nms_topk):
    """Low level IR routing for transform location in multibox_detection operator.

    Each box record is addressed as 6 consecutive values; thread index ``i``
    walks boxes and ``j`` walks either the 6 fields (copy phases) or the
    candidate boxes to suppress (NMS phase).

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    sort_result : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    nms_threshold : float
        Non-maximum suppression threshold.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    nms_topk : int
        Keep maximum top k detections before nms, -1 for no limit.

    Returns
    -------
    stmt : Stmt
        The result IR statement.

    NOTE(review): the truncated expressions in this function (clamp argument
    of Max, the Select call, max_threads, the immediates, nkeep and the two
    inner if_scope conditions) were reconstructed; confirm against the
    upstream TVM source.
    """
    import math

    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap of two boxes.
        """
        # Intersection width/height, clamped at zero for disjoint boxes.
        w = tvm.make.Max(
            0.0,
            tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
            - tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.make.Max(
            0.0,
            tvm.make.Min(out_tensor[box_a_idx + 3],
                         out_tensor[box_b_idx + 3]) -
            tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against a non-positive union before dividing.
        return tvm.expr.Select(u <= 0.0, 0.0, i / u)

    # threadIdx.x and threadIdx.y both get this extent, so the square root of
    # the device limit keeps tx * ty within the per-block thread budget.
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    bx = tvm.thread_axis("blockIdx.x")
    by = tvm.thread_axis("blockIdx.y")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data)
    p_sort_result = ib.buffer_ptr(sort_result)
    p_valid_count = ib.buffer_ptr(valid_count)
    p_out = ib.buffer_ptr(out)
    batch_size = out.shape[0]
    num_anchors = out.shape[1]
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    nthread_ty = max_threads
    nthread_by = 6 // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(ty, "thread_extent", nthread_ty)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    ib.scope_attr(by, "thread_extent", nthread_by)
    i = bx * max_threads + tx
    j = by * max_threads + ty

    nms_threshold_node = tvm.make.node("FloatImm", dtype="float32",
                                       value=nms_threshold)
    nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
    force_suppress_node = tvm.make.node("IntImm", dtype="int32",
                                        value=1 if force_suppress else 0)
    with ib.for_range(0, batch_size, for_type="unroll", name="n") as n:
        with ib.if_scope(
                tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
                        p_valid_count[0] > 0)):
            # Reorder output
            nkeep = tvm.if_then_else(
                tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
                nms_topk, p_valid_count[n])
            with ib.if_scope(i < nkeep):
                with ib.if_scope(j < 6):
                    p_out[(n * num_anchors * 6 + i * 6 + j)] = p_data[(
                        n * num_anchors * 6 +
                        p_sort_result[n * num_anchors + i] * 6 + j)]
            # Boxes beyond the top-k cut are copied through unsorted.
            with ib.if_scope(
                    tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
                with ib.if_scope(i < p_valid_count[n] - nkeep):
                    with ib.if_scope(j < 6):
                        p_out[(n * num_anchors * 6 + (i + nkeep) * 6 +
                               j)] = p_data[(n * num_anchors * 6 +
                                             (i + nkeep) * 6 + j)]
            # Apply nms
            with ib.if_scope(i < p_valid_count[n]):
                offset_i = i * 6
                with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0):
                    with ib.if_scope(j < p_valid_count[n]):
                        offset_j = j * 6
                        with ib.if_scope(
                                tvm.all(
                                    j > i,
                                    p_out[n * num_anchors * 6 + offset_j] >=
                                    0)):
                            with ib.if_scope(
                                    tvm.any(
                                        force_suppress_node > 0,
                                        p_out[n * num_anchors * 6 + offset_i]
                                        == p_out[n * num_anchors * 6 +
                                                 offset_j])):
                                # When force_suppress == True or class_id equals
                                iou = calculate_overlap(
                                    p_out, n * num_anchors * 6 + offset_i + 2,
                                    n * num_anchors * 6 + offset_j + 2)
                                with ib.if_scope(iou >= nms_threshold):
                                    p_out[n * num_anchors * 6 +
                                          offset_j] = -1.0
        with ib.else_scope():
            # NMS disabled: plain copy of the valid boxes.
            with ib.if_scope(i < p_valid_count[n]):
                with ib.if_scope(j < 6):
                    p_out[(n * num_anchors * 6 + i * 6 +
                           j)] = p_data[n * num_anchors * 6 + i * 6 + j]
        # Set invalid entry to be -1
        with ib.if_scope(i < num_anchors - p_valid_count[n]):
            with ib.if_scope(j < 6):
                p_out[n * num_anchors * 6 + (i + p_valid_count[n]) * 6 +
                      j] = -1.0
    body = ib.get()
    return body
def convolutionf16(D,F,LOAD_INDEX_D,LOAD_INDEX_F,O):
    """Build IR for a float16 convolution expressed through extern calls.

    The function allocates three shared float16 buffers, derives per-warp
    offsets from the thread index, and builds ``tvm.call_extern`` expressions
    for loading, fragment handling, WMMA compute and storing inside an
    ``ir_builder`` loop nest.

    It reads several names that are not parameters or locals (``block_x``,
    ``thread_x``, ``warp_col_tile``, ``warp_row_tile``, ``block_row_warp``,
    ``shieft``, ``loop_len``, ``block_num``, ``block_len``, ``rD``, ``rF``,
    ``cD``); presumably these are module-level settings -- TODO confirm.

    NOTE(review): this copy of the function is damaged.  Four ``call_extern``
    calls end in a line continuation with their remaining arguments missing,
    none of the built expressions is passed to ``ib.emit`` in the visible
    text, and ``body`` is computed but never returned.  Restore these from
    the original source before use.
    """
    ib = tvm.ir_builder.create()
    bidx = block_x
    tidx = thread_x
    #set shared memory buffer for loading data as column
    shmem_O = ib.allocate("float16", 16384, name="shmem_O",scope = "shared")

    shmem_D = ib.allocate("float16", 3072, name="shmem_D",scope = "shared")

    shmem_F = ib.allocate("float16", 3072, name="shmem_F",scope = "shared")

    #sync thread model
    sync = tvm.call_extern("float32","__syncthreads")

    #declare matrix fragment
    Define_matrix_fragment_a = tvm.call_extern("float32","SET_FRAGMENT_A",warp_col_tile)
    Define_matrix_fragment_b = tvm.call_extern("float32","SET_FRAGMENT_B",warp_row_tile)

    Define_matrix_fragment_c = tvm.call_extern("float32","SET_FRAGMENT_CF16",warp_col_tile,warp_row_tile)

    #set the loading index to current location
    #calculate the id of current warp
    warpid = tidx//32
    #calculate the id of current thread inside current warp
    lane = tidx%32
    #number of element in a row for shared memory
    o_row_num = warp_row_tile*block_row_warp*16
    # offset to point the pointer to the start of current warp in shared memory
    Dp = D.access_ptr("r")
    Fp = F.access_ptr("r")
    #offset_sh = bidx*thread_num*index_len+tidx*index_len
    #offset_sh = bidx+tidx+100000
    LDp = LOAD_INDEX_D.access_ptr("r")
    LFp = LOAD_INDEX_F.access_ptr("r")
    define_d_index = tvm.call_extern("float32","INDEXPOINTERD",LDp)
    define_f_index = tvm.call_extern("float32","INDEXPOINTERF",LFp)
    #loading parameter
    row_num = 16+shieft

    # NOTE(review): the offsets below divide with `/` where warpid and lane
    # were derived with `//` and `%`; confirm integer division is what the
    # generated code performs here.
    warp_offset_o = warpid%block_row_warp*16*warp_row_tile+warpid/block_row_warp*warp_col_tile*16*o_row_num

    offset_warp_row = warpid%block_row_warp*warp_row_tile*16*row_num

    offset_warp_col = warpid/block_row_warp*warp_col_tile*row_num*16
    offset_sh_load = warpid*16*row_num+(lane/2)*row_num+8*(lane%2)

    fragement_step = 16*row_num

    #main loop for computing the conv
    with ib.for_range(0,loop_len,name ="blk_id") as blk_id:
        # skip block ids past the number of block_len x block_len output tiles
        with ib.if_scope(bidx+blk_id*block_num<rD*rF//block_len//block_len):
            #compute the location of current block   
            bx = (bidx+blk_id*block_num)//(rD//block_len)
            by = (bidx+blk_id*block_num)%(rD//block_len)
            #store the result from last computation
            with ib.for_range(0,warp_col_tile,name = "col_id") as col_id:
                with ib.for_range(0,warp_row_tile, name = "row_id") as row_id:
                    # NOTE(review): remaining arguments of this call are missing from this copy
                    store_O_fragment = tvm.call_extern("float32","STOREFRAG_C_F16",shmem_O[warp_offset_o+col_id*16*o_row_num+row_id*16],\
                    #onebyone_store = tvm.call_extern("float32","ONEBYONE",shmem_O[warp_offset_o+col_id*16*o_row_num+row_id*16],col_id,row_id,o_row_num)
            Op = O.access_ptr("w")
            # NOTE(review): remaining arguments of this call are missing from this copy
            store_O_matrix = tvm.call_extern("float32","store_O_matrix",\
            #set pointers                   

            #col_index_d = by*16*warp_row_tile*block_row_warp+warpid*16+lane/2
            #row_index_d = 8*(lane%2)
            #col_index_f = bx*16*warp_col_tile*block_col_warp+warpid*16+lane/2
            #row_index_f = 8*(lane%2)

            #LDp = LOAD_INDEX_D.access_ptr("r",offset = col_index_d*cD+row_index_d)
            #LFp = LOAD_INDEX_F.access_ptr("r",offset = col_index_f*cF+row_index_f)
            #now load F, D
            offset_sh = offset_sh_load
            load_D_matrix = tvm.call_extern("float32","LOAD_MATRIX_D",Dp,shmem_D[offset_sh])
            offset_sh = offset_sh_load
            load_F_matrix = tvm.call_extern("float32","LOAD_MATRIX_F",Fp,shmem_F[offset_sh])           
            #load the fragment
            #load the output matrix fragment
            with ib.for_range(0,warp_col_tile,name = "col_id") as col_id:
                with ib.for_range(0,warp_row_tile, name = "row_id") as row_id:              
                    fill_O_zero = tvm.call_extern("float","FILLZERO_CF16",col_id,row_id)

            # reduction loop over cD in steps of 16
            with ib.for_range(0,cD//16,name = "reduce_crs") as reduce_crs:
                offset_sh = offset_warp_col
                with ib.for_range(0,warp_col_tile,name = "col") as col:
                    #offset_sh+= col*16*row_num
                    load_matrix_frag_F = tvm.call_extern("float32","LOADFRAG_A",shmem_F[offset_sh],col,row_num)
                offset_sh = offset_warp_row
                with ib.for_range(0,warp_row_tile,name = "row") as row:
                    #offset_sh+= row*16*row_num
                    load_matrix_frag_D = tvm.call_extern("float32","LOADFRAG_B",shmem_D[offset_sh],row,row_num)
                with ib.for_range(0,warp_col_tile,name = "col") as col:
                    with ib.for_range(0,warp_row_tile,name = "row") as row:
                        wmma_compute = tvm.call_extern("float32","WMMA_SYNC",col,row)
                #load data of the next iteration if it is not the last
                with ib.if_scope(reduce_crs<cD//16-1):
                    #reset pointer location
                    #row_index_d = 16*(reduce_crs+1)+8*(lane%2)
                    #row_index_f = 16*(reduce_crs+1)+8*(lane%2)

                    #LDp = LOAD_INDEX_D.access_ptr("r",offset = col_index_d*cD+row_index_d)
                    #LFp = LOAD_INDEX_F.access_ptr("r",offset = col_index_f*cF+row_index_f)
                    offset_sh = offset_sh_load
                    load_D_matrix = tvm.call_extern("float32","LOAD_MATRIX_D",Dp,shmem_D[offset_sh])
                    offset_sh = offset_sh_load
                    load_F_matrix = tvm.call_extern("float32","LOAD_MATRIX_F",Fp,shmem_F[offset_sh])           
                    #load the fragment
                    with ib.for_range(0,warp_col_tile,name = "col") as col:
                        offset_sh = col*16*row_num+warpid/block_row_warp*warp_col_tile*16*row_num
                        load_matrix_frag_F = tvm.call_extern("float32","LOADFRAG_A",shmem_F[offset_sh],col,row_num)

                    with ib.for_range(0,warp_row_tile,name = "row") as row:
                        offset_sh = row*16*row_num+warpid%block_row_warp*warp_row_tile*16*row_num
                        load_matrix_frag_D = tvm.call_extern("float32","LOADFRAG_B",shmem_D[offset_sh],row,row_num)
                # last reduction step: write the accumulated fragments out
                with ib.if_scope(reduce_crs == cD//16-1):
                    with ib.for_range(0,warp_col_tile,name = "col_id") as col_id:
                        with ib.for_range(0,warp_row_tile, name = "row_id") as row_id:
                            # NOTE(review): remaining arguments of this call are missing from this copy
                            store_O_fragment = tvm.call_extern("float32","STOREFRAG_C_F16",shmem_O[warp_offset_o+col_id*16*o_row_num+row_id*16],\
                    Op = O.access_ptr("w")
                    # NOTE(review): remaining arguments of this call are missing from this copy
                    store_O_matrix = tvm.call_extern("float32","store_O_matrix",\
    # NOTE(review): `body` is never returned in the visible text
    body = ib.get()
def nms_ir(data, sorted_index, valid_count, out, box_indices, max_output_size,
           iou_threshold, force_suppress, top_k, coord_start, id_index,
           score_index):
    """Low level IR routine for non-maximum suppression (NMS).

    Builds, with the TVM IR builder, a GPU kernel body that copies the boxes
    of ``data`` into ``out`` in ``sorted_index`` order, suppresses boxes that
    overlap a preceding box by at least ``iou_threshold``, marks
    suppressed/invalid entries with -1 and optionally caps the number of
    returned boxes.

    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score. Its shape is read as
        (batch_size, num_anchors, box_data_length).

    sorted_index : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    box_indices : Buffer
        Output buffer that receives, for every output slot, the index of the
        box that was copied there, or -1 for suppressed/invalid slots.

    max_output_size : int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : float
        Overlapping(IoU) threshold to suppress object with smaller score.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    top_k : int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : int
        Start index of the consecutive 4 coordinates.

    id_index : int
        index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap (intersection over union) of two boxes.

        ``box_a_idx`` / ``box_b_idx`` are flat offsets into ``out_tensor``
        pointing at the first of 4 consecutive coordinates; offsets +0/+1 are
        used as the lower corner and +2/+3 as the upper corner.
        """
        # Intersection width/height, clamped at zero so that disjoint boxes
        # yield an empty intersection instead of a negative extent.
        # NOTE(review): the 0.0 operand was missing from the file and has
        # been restored -- confirm.
        w = tvm.max(
            0.0,
            tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) -
            tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.max(
            0.0,
            tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) -
            tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        # Union = area(a) + area(b) - intersection.
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against division by zero for degenerate boxes.
        return tvm.expr.Select(u <= 0.0, 0.0, i / u)

    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    sorted_index = ib.buffer_ptr(sorted_index)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)
    box_indices = ib.buffer_ptr(box_indices)
    # NOTE(review): the tail of this call was missing from the file; the
    # name/scope arguments are presumed from the variable name and usage as
    # a per-thread counter -- confirm.
    num_valid_boxes = ib.allocate("int32", (1, ),
                                  name="num_valid_boxes",
                                  scope="local")

    # NOTE(review): the argument of int(...) was missing from the file;
    # presumably the thread limit of the current target -- confirm.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Flat thread id; used below as the index of the box this thread handles.
    j = bx * max_threads + tx

    # Wrap the Python scalars as IR immediates so they can be used in
    # IR-level comparisons and index arithmetic.
    iou_threshold = tvm.make.node("FloatImm",
                                  dtype="float32",
                                  value=iou_threshold)
    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
    force_suppress = tvm.make.node("IntImm",
                                   dtype="int32",
                                   value=1 if force_suppress else 0)

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        base_idx = i * num_anchors * box_data_length
        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
            # Reorder output
            nkeep = if_then_else(
                tvm.all(top_k > 0, top_k < valid_count[i]),
                top_k, valid_count[i])
            with ib.if_scope(j < nkeep):
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + j * box_data_length + k)] = \
                        data[(base_idx + sorted_index[i * num_anchors + j]
                              * box_data_length + k)]
                box_indices[i * num_anchors +
                            j] = sorted_index[i * num_anchors + j]
            # Boxes beyond top_k are dropped: fill them with -1.
            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
                with ib.if_scope(j < valid_count[i] - nkeep):
                    with ib.for_range(0, box_data_length) as k:
                        out[(base_idx + (j + nkeep) * box_data_length +
                             k)] = -1.0
                    box_indices[i * num_anchors + (j + nkeep)] = -1
            # Apply nms
            with ib.for_range(0, valid_count[i]) as k:
                offset_k = k * box_data_length
                # Box k suppresses others only while it is still alive
                # (positive score, and a non-negative class id if ids are
                # enabled).
                with ib.if_scope(tvm.all(
                        out[base_idx + offset_k + score_index] > 0,
                        tvm.any(id_index < 0,
                                out[base_idx + offset_k + id_index] >= 0))):
                    with ib.if_scope(j < valid_count[i]):
                        offset_j = j * box_data_length
                        # Box j is a candidate if it comes after k, is still
                        # alive, and either suppression is forced, ids are
                        # disabled, or both boxes share the same class id.
                        with ib.if_scope(tvm.all(
                                j > k,
                                out[base_idx + offset_j + score_index] > 0,
                                tvm.any(id_index < 0,
                                        out[base_idx + offset_j + id_index] >= 0),
                                tvm.any(force_suppress > 0, id_index < 0,
                                        out[base_idx + offset_k + id_index] ==
                                        out[base_idx + offset_j + id_index]))):
                            iou = calculate_overlap(
                                out, base_idx + offset_j + coord_start,
                                base_idx + offset_k + coord_start)
                            with ib.if_scope(iou >= iou_threshold):
                                out[base_idx + offset_j + score_index] = -1.0
                                with ib.if_scope(id_index >= 0):
                                    out[base_idx + offset_j + id_index] = -1.0
                                box_indices[i * num_anchors + j] = -1
        with ib.else_scope():
            # NMS disabled (or nothing valid): copy the valid boxes through
            # unchanged.
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j +
                                                          k]
                box_indices[i * num_anchors + j] = j
        # Set invalid entry to be -1
        with ib.if_scope(j < num_anchors - valid_count[i]):
            with ib.for_range(0, box_data_length) as k:
                out[base_idx + (j + valid_count[i]) * box_data_length +
                    k] = -1.0
            box_indices[i * num_anchors + j + valid_count[i]] = -1
        # Only return max_output_size number of valid boxes
        num_valid_boxes[0] = 0
        with ib.if_scope(max_output_size > 0):
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                with ib.if_scope(out[base_idx + offset_j] >= 0):
                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
                        with ib.for_range(0, box_data_length) as k:
                            out[base_idx + offset_j + k] = -1.0
                        box_indices[i * num_anchors + j] = -1
                    with ib.else_scope():
                        num_valid_boxes[0] += 1

    return ib.get()
# Direct Declare Extern Math Call
# -------------------------------
# The most straightforward way to call a target-specific function is via
# the extern function call construct in TVM.
# In the following example, we use :any:`tvm.call_pure_extern` to call
# the :code:`__expf` function, which is only available under CUDA.
# Declare B[i] = __expf(A[i]) over a 1-D tensor of symbolic length n.
n = tvm.var("n")
A = tvm.placeholder((n, ), name='A')
# NOTE(review): the closing line of this call was missing from the file;
# the name argument is presumed by analogy with A -- confirm.
B = tvm.compute(A.shape,
                lambda i: tvm.call_pure_extern("float32", "__expf", A[i]),
                name="B")
# Split the single axis into blocks of num_thread elements and bind the
# outer/inner parts to the block and thread indices.
s = tvm.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
# NOTE(review): the callee was missing from this line ("f =, [A, B], ...");
# restored as tvm.build on the schedule above -- confirm.
f = tvm.build(s, [A, B], "cuda", name="myexp")

# Unified Intrinsic Call
# ----------------------
# The above code verifies that a direct external call can be used to
# call into device-specific functions.
# However, this approach only works for the CUDA target with the float type.
# Ideally, we want to write the same code for any device and any data type.
# TVM intrinsics provide the user a mechanism to achieve this, and this
# is the recommended way to solve the problem.
# The following code uses tvm.exp instead, which creates an intrinsic call