Example 1
def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
    """Generate scatter ir for 1d inputs, using a sorting based approach.
    By sorting the indices and comparing each pair of neighboring indices, we can tell
    which elements in the indices tensor may scatter their update value into the output.
    Sorting of indices, and sorting of updates with respect to indices, can be done
    at the same time by thrust's sort_by_key function. It is important that sorting
    be done in a "stable" way via stable_sort, to guarantee deterministic output.
    Negative indices are assumed to have been converted to corresponding positive
    indices.

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices_sorted : tir.Tensor
        The sorted index locations to update.

    updates_sorted : tir.Tensor
        The values to update, sorted by indices.

    out : tir.Tensor
        The output tensor.

    Returns
    -------
    ret : tir
        The computational ir.
    """
    n = data.shape[0]

    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads

    with ib.new_scope():
        nthread_bx = ceil_div(n, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx
        with ib.if_scope(tid < n):
            out_ptr[tid] = data_ptr[tid]

    indices_ptr = ib.buffer_ptr(indices_sorted)
    updates_ptr = ib.buffer_ptr(updates_sorted)

    ni = indices_sorted.shape[0]

    with ib.new_scope():
        nthread_bx = ceil_div(ni, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx

        with ib.if_scope(tid == ni - 1):
            # The last element can always update.
            index = indices_ptr[tid]
            update = updates_ptr[tid]
            out_ptr[index] = update

        with ib.else_scope():
            with ib.if_scope(tid < ni - 1):
                index = indices_ptr[tid]
                index_next = indices_ptr[tid + 1]

                # If the next neighbor in the sorted list of indices has a different index,
                # that means thread tid is the last one to have this index.
                # This thread can update the output.
                with ib.if_scope(index != index_next):
                    update = updates_ptr[tid]
                    out_ptr[index] = update

    return ib.get()
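The kernel above can be checked against a small NumPy reference (a conceptual sketch, not TVM code): stable-sorting the indices and letting only the last update of each run of equal indices write its value reproduces the deterministic "last write wins" behaviour that the neighbour comparison implements.

import numpy as np

def scatter_1d_sorted_ref(data, indices, updates):
    # Stable sort by key keeps equal indices in their original order, so the
    # last occurrence after sorting is the update that should win.
    order = np.argsort(indices, kind="stable")
    indices_sorted = indices[order]
    updates_sorted = updates[order]
    out = data.copy()
    for i in range(len(indices_sorted)):
        # Mirrors the "index != index_next" check in the IR above: only the
        # last element of a run of equal indices writes its update.
        if i == len(indices_sorted) - 1 or indices_sorted[i] != indices_sorted[i + 1]:
            out[indices_sorted[i]] = updates_sorted[i]
    return out

data = np.zeros(5)
print(scatter_1d_sorted_ref(data, np.array([1, 3, 1, 4]), np.array([10.0, 30.0, 11.0, 40.0])))
# -> [ 0. 11.  0. 30. 40.]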
Example 2
def schedule_conv2d_winograd(cfg, s, output, pre_computed):
    """Schedule winograd template"""
    inverse = s[output].op.input_tensors[0]
    bgemm, A = s[inverse].op.input_tensors
    kernel_pack, data_pack_trans = s[bgemm].op.input_tensors
    data_pack = s[data_pack_trans].op.input_tensors[0]
    input_tile, B = s[data_pack].op.input_tensors
    pad_data = s[input_tile].op.input_tensors[0]

    # data transform
    s[B].compute_inline()
    s[A].compute_inline()

    # this will probably improve execution on real topologies
    if autotvm.GLOBAL_SCOPE.in_tuning:
        # Padding to texture
        AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile])
        bind_data_copy(s[AA])

    s[input_tile].compute_inline()

    OL = s.cache_write(data_pack, "local")
    c, p, eps, nu, cb = s[data_pack].op.axis
    fused = s[data_pack].fuse(c, p, eps, nu)
    bx, tx = s[data_pack].split(fused, 128)
    s[data_pack].vectorize(cb)
    s[data_pack].bind(bx, te.thread_axis("blockIdx.x"))
    s[data_pack].bind(tx, te.thread_axis("threadIdx.x"))

    _, _, eps, nu, cb = s[OL].op.axis
    r_a, r_b = s[OL].op.reduce_axis
    s[OL].unroll(eps)
    s[OL].unroll(nu)
    s[OL].unroll(r_a)
    s[OL].unroll(r_b)
    s[OL].vectorize(cb)
    s[OL].compute_at(s[data_pack], tx)
    s[data_pack].set_scope(get_texture_storage(data_pack.shape))

    s[data_pack_trans].compute_inline()

    # transform kernel
    if not pre_computed:
        kernel, G = s[kernel_pack].op.input_tensors
        eps, nu, ci, co, cob = s[kernel_pack].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # skip this part during tuning to make records accurate
            # this part will be pre-computed during the pre-compute optimization pass
            s[G].pragma(s[G].op.axis[0], "debug_skip_region")
            s[kernel_pack].pragma(eps, "debug_skip_region")
        else:
            s[G].compute_inline()
            r_a, r_b = s[kernel_pack].op.reduce_axis
            for axis in [eps, nu, r_a, r_b]:
                s[kernel_pack].unroll(axis)

            fused = s[kernel_pack].fuse(ci, co)
            bb, tt = s[kernel_pack].split(fused, 128)
            s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob)
            s[kernel_pack].vectorize(cob)
            s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
            s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
    else:
        kernel = kernel_pack

    if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag:
        # manage scheduling of datacopy
        pack_data = pad_data.op.input_tensors[0]
        bind_data_copy(s[pack_data])
        bind_data_copy(s[kernel])
    elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()
    s[pad_data].compute_inline()

    ##### space definition begin #####
    cfg.define_knob("auto_unroll_max_step", [0, 4, 16])
    b1, b2, y, x, cb = s[bgemm].op.axis
    rcc = s[bgemm].op.reduce_axis[0]
    alpha = get_const_int(b1.dom.extent)

    cfg.define_split(
        "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 16
    )

    min_x_div = 1
    for bn in range(4, 0, -1):
        if bgemm.shape[3] % bn == 0:
            min_x_div = bn
            break

    cfg.define_split(
        "tile_x",
        x,
        num_outputs=3,
        filter=lambda entry: entry.size[2] <= 64
        and entry.size[1] >= min_x_div
        and entry.size[1] <= 16,
    )
    cfg.define_split("tile_rc", rcc, num_outputs=2)
    # TODO: Uncomment the following lines when multi_filter is introduced
    # cfg.multi_filter(
    # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024)
    # )
    ##### space definition end #####

    # batch gemm
    OL = s.cache_write(bgemm, "local")
    if (
        autotvm.GLOBAL_SCOPE.in_tuning
        or isinstance(kernel.op, tvm.te.ComputeOp)
        and "filter_pack" in kernel.op.tag
    ):
        BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL])
        bind_data_copy(s[BB])

    by = s[bgemm].fuse(b1, b2, y)

    # tile and bind spatial axes
    bgemm_scope, by = s[bgemm].split(by, nparts=1)
    by, vy, ty = cfg["tile_y"].apply(s, bgemm, by)
    bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x)
    s[bgemm].bind(by, te.thread_axis("blockIdx.y"))
    s[bgemm].bind(bx, te.thread_axis("blockIdx.x"))
    s[bgemm].bind(vy, te.thread_axis("vthread"))
    s[bgemm].bind(vx, te.thread_axis("vthread"))
    s[bgemm].bind(ty, te.thread_axis("threadIdx.y"))
    s[bgemm].bind(tx, te.thread_axis("threadIdx.x"))
    s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb)
    s[bgemm].vectorize(cb)
    s[bgemm].set_scope(get_texture_storage(bgemm.shape))

    # tile reduction axes
    s[OL].compute_at(s[bgemm], tx)
    b1, b2, y, x, cb = s[OL].op.axis
    (rcc, rcb) = s[OL].op.reduce_axis
    b = s[OL].fuse(b1, b2)
    s[OL].reorder(b, y, x, rcc, rcb, cb)
    # s[OL].unroll(rcb)
    s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[OL].pragma(rcb, "unroll_explicit", True)
    s[OL].vectorize(cb)

    # schedule inverse, output and fusion
    if output.op in s.outputs:
        OL = None
    else:
        OL = output
        s[OL].set_scope("local")
        output = s.outputs[0]

    if len(s[output].op.axis) == 4:
        n, co, h, w = s[output].op.axis
        cb = None
    else:
        n, co, h, w, cb = s[output].op.axis
    inverse_scope, n = s[output].split(n, nparts=1)

    fused = s[output].fuse(n, co, h, w)
    bb, tt = s[output].split(fused, 128)
    if cb is not None:
        s[output].reorder(bb, tt, cb)
        s[output].vectorize(cb)

    s[output].bind(bb, te.thread_axis("blockIdx.x"))
    s[output].bind(tt, te.thread_axis("threadIdx.x"))

    if OL is not None:
        s[OL].compute_at(s[output], tt)

    co, p, vh, vw, cb = s[inverse].op.axis
    r_a, r_b = s[inverse].op.reduce_axis
    for axis in [vh, vw, r_a, r_b]:
        s[inverse].unroll(axis)
    s[inverse].vectorize(cb)
    s[inverse].compute_at(s[output], tt)

    return s
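The filter= lambdas passed to cfg.define_split above prune the AutoTVM search space before tuning starts. A pure-Python sketch (with a hypothetical extent, not taken from the schedule above) of what such a predicate does to the set of 3-way factorizations:

def three_way_splits(extent):
    # Enumerate all (outer, middle, inner) factorizations of `extent`.
    for inner in range(1, extent + 1):
        if extent % inner:
            continue
        rest = extent // inner
        for middle in range(1, rest + 1):
            if rest % middle:
                continue
            yield (rest // middle, middle, inner)

extent = 196  # hypothetical extent of the tiled axis
candidates = list(three_way_splits(extent))
# Same form of predicate as the "tile_y" filter above, where entry.size[2] is
# the innermost factor and entry.size[1] the middle one.
kept = [c for c in candidates if c[2] <= 64 and c[1] <= 16]
print(len(candidates), "->", len(kept))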
Example 3
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import topi
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib import nvcc

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


def test_cuda_vectorize_add():
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
            print("skip because gpu does not support int8")
Example 4
def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also prepares to move valid boxes to the
    top of input data.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    valid_indices: Buffer
        2D Buffer of destination positions of valid boxes with shape [batch_size, num_anchors].

    valid_boxes: Buffer
        2D Buffer of flags indicating valid boxes with shape [batch_size, num_anchors].

    Returns
    -------
    out : Buffer
        Sorted valid boxes

    out_indices : Buffer
        Indices of valid boxes in the original data
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_indices = ib.buffer_ptr(valid_indices)
    valid_boxes = ib.buffer_ptr(valid_boxes)

    out = ib.buffer_ptr(out)
    out_indices = ib.buffer_ptr(out_indices)
    one = tvm.tir.const(1, dtype=out.dtype)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    nthread_by = batch_size
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            with ib.for_range(0, elem_length) as k:
                out[(i * num_anchors + j) * elem_length + k] = -one
            out_indices[i * num_anchors + j] = -1
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            with ib.if_scope(valid_boxes[i, tid] > 0):
                with ib.for_range(0, elem_length) as k:
                    out[(i * num_anchors + valid_indices[i, tid]) * elem_length
                        + k] = data[(i * num_anchors + j) * elem_length + k]
                out_indices[i * num_anchors + valid_indices[i, tid]] = j
    return ib.get()
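As a conceptual NumPy reference (not TVM code), the second kernel above performs a compaction: every anchor flagged in valid_boxes is copied to the slot given by valid_indices, and everything else keeps the -1 fill written by the first kernel. How valid_indices is produced is an assumption here (an exclusive prefix sum of the flags), inferred from how it is consumed above.

import numpy as np

def get_valid_counts_ref(data, valid_indices, valid_boxes):
    batch_size, num_anchors, elem_length = data.shape
    out = np.full_like(data, -1.0)
    out_indices = np.full((batch_size, num_anchors), -1, dtype=np.int32)
    for i in range(batch_size):
        for j in range(num_anchors):
            if valid_boxes[i, j] > 0:
                dst = valid_indices[i, j]  # destination slot of this valid box
                out[i, dst] = data[i, j]
                out_indices[i, dst] = j
    return out, out_indices

data = np.arange(2 * 3 * 2, dtype="float32").reshape(2, 3, 2)
valid_boxes = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.int32)
valid_indices = np.cumsum(valid_boxes, axis=1) - valid_boxes  # assumed exclusive prefix sum
out, out_indices = get_valid_counts_ref(data, valid_indices, valid_boxes)
print(out_indices)  # rows: [0, 2, -1] and [1, -1, -1]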
Example 5
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index,
                       score_index):
    """Low level IR to identify bounding boxes given a score threshold.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    score_threshold : Buffer or float32
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    score_index: optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    valid_boxes: Buffer
        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].

    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)

    valid_boxes = ib.buffer_ptr(valid_boxes)
    if isinstance(score_threshold, float):
        score_threshold = tvm.tir.FloatImm("float32", score_threshold)
    id_index = tvm.tir.IntImm("int32", id_index)
    score_index = tvm.tir.IntImm("int32", score_index)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(num_anchors, max_threads)
        nthread_by = batch_size
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx

        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            score = data[(i * num_anchors + j) * elem_length + score_index]
            with ib.if_scope(
                    tvm.tir.all(
                        score > score_threshold,
                        tvm.tir.any(
                            id_index < 0,
                            data[(i * num_anchors + j) * elem_length +
                                 id_index] >= 0),
                    )):
                valid_boxes[i * num_anchors + j] = 1
            with ib.else_scope():
                valid_boxes[i * num_anchors + j] = 0
    return ib.get()
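A NumPy sketch of the same predicate (conceptual reference only): a box is valid when its score exceeds the threshold and, if a class-id slot is present (id_index >= 0), its class id is non-negative.

import numpy as np

def get_valid_boxes_ref(data, score_threshold, id_index, score_index):
    # data: [batch_size, num_anchors, elem_length]
    valid = data[:, :, score_index] > score_threshold
    if id_index >= 0:
        valid &= data[:, :, id_index] >= 0
    return valid.astype(np.int32)

data = np.array([[[0.0, 0.9, 1.0, 2.0],    # class 0,  score 0.9 -> valid
                  [-1.0, 0.8, 1.0, 2.0],   # class -1, score 0.8 -> invalid
                  [1.0, 0.1, 1.0, 2.0]]])  # class 1,  score 0.1 -> invalid
print(get_valid_boxes_ref(data, score_threshold=0.5, id_index=0, score_index=1))  # [[1 0 0]]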
Example 6
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    s = te.create_schedule([conv.op])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis

    cfg = autotvm.get_config()
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=3)
    cfg.define_split("tile_ry", ry, num_outputs=3)
    cfg.define_split("tile_rx", rx, num_outputs=3)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####

    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, "local")

    # create cache stage
    AA = s.cache_read(data, "shared", [OL])
    WW = s.cache_read(kernel, "shared", [OL])
    AL = s.cache_read(AA, "local", [OL])
    WL = s.cache_read(WW, "local", [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(by, te.thread_axis("blockIdx.y"))
    s[output].bind(bx, te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))
    s[output].bind(tf, te.thread_axis("threadIdx.z"))
    s[output].bind(ty, te.thread_axis("threadIdx.y"))
    s[output].bind(tx, te.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
    ryo, rym, ryi = cfg["tile_rx"].apply(s, OL, ry)
    rxo, rxm, rxi = cfg["tile_ry"].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # tune unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

    return s, [raw_data, kernel, conv]
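A sketch of how a template like this is typically registered and tuned, following the standard AutoTVM tutorial flow. The template name, workload arguments, trial count, and log file below are illustrative assumptions, not taken from the code above, and the function is assumed to have been decorated with @autotvm.template at definition time.

from tvm import autotvm

# Assumed registration: @autotvm.template("tutorial/conv2d_no_batching")
task = autotvm.task.create(
    "tutorial/conv2d_no_batching",
    args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),  # N, H, W, CO, CI, KH, KW, stride, padding
    target="cuda",
)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
)

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=20,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file("conv2d.log")],
)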
Example 7
    def apply(
        self, sch, op, axes, axis_lens=None, max_unroll=None, vec_size=None, cfg=None, source=None
    ):
        """Apply annotation to an array of axes

        Parameters
        ----------
        sch: tvm.te.schedule.Schedule
            The tvm schedule
        op: tvm.te.Operation
            The operation (stage) to which the annotations are applied
        axes: Array of tvm.te.schedule.IterVar
            The axes to annotate
        axis_lens: Array of int, optional
            The lengths of the axes
        max_unroll: int, optional
            maximum unroll step
        vec_size: Array of int, optional
            valid vector lanes for vectorization
        cfg: ConfigEntity, optional
            cfg for recording error
        source: Array of Array tensor, optional
            source tensor for attaching cache

        Returns
        -------
        axes : list of tvm.te.schedule.IterVar
            The transformed axes
        """
        if source is not None:  # special case : attach cache_read/cache_write
            for src, to in zip(source, self.anns):
                for t in src:
                    sch[t].compute_at(sch[op], axes[to])
        else:  # other cases
            for i, ann in enumerate(self.anns):
                if ann == "none":
                    pass
                elif ann == "unroll":
                    if max_unroll and axis_lens[i] > max_unroll:
                        cfg.raise_error("Too large factor for unrolling")
                    sch[op].unroll(axes[i])
                elif ann == "vec":
                    if vec_size and axis_lens[i] not in vec_size:
                        cfg.raise_error("Wrong size of lanes in vectorization")
                    sch[op].vectorize(axes[i])
                elif ann == "blockIdx.x":
                    sch[op].bind(axes[i], thread_axis("blockIdx.x"))
                elif ann == "blockIdx.y":
                    sch[op].bind(axes[i], thread_axis("blockIdx.y"))
                elif ann == "blockIdx.z":
                    sch[op].bind(axes[i], thread_axis("blockIdx.z"))
                elif ann == "threadIdx.x":
                    sch[op].bind(axes[i], thread_axis("threadIdx.x"))
                elif ann == "threadIdx.y":
                    sch[op].bind(axes[i], thread_axis("threadIdx.y"))
                elif ann == "threadIdx.z":
                    sch[op].bind(axes[i], thread_axis("threadIdx.z"))
                elif ann == "vthread":
                    sch[op].bind(axes[i], thread_axis("vthread"))
                elif ann == "fuse":
                    assert i < len(axes) - 1
                    axes[i + 1] = sch[op].fuse(axes[i], axes[i + 1])
                else:
                    raise RuntimeError("Invalid annotation " + ann)
        return axes
Example 8
def _schedule_conv2d_NCHWc_int8(cfg, s, output):
    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
            # data and kernel are not pre-computed, schedule layout transform here
            schedule_injective_from_existing(s, packed_data)
            schedule_injective_from_existing(s, packed_kernel)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # create cache stage
    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # tile and bind spatial axes
    if len(s[output].op.axis) == 5:
        n, f, y, x, c = s[output].op.axis
    else:
        # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks
        # are created from scratch, the real auto-tuning will still happen on the 5D output.
        n, f, y, x = s[output].op.axis

    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
    s[output].bind(bn, te.thread_axis("blockIdx.z"))
    s[output].bind(bf, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vn, te.thread_axis("vthread"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, te.thread_axis("threadIdx.z"))
        s[output].bind(tf, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile and bind reduction axes
    n, f, y, x, c = s[conv].op.axis

    rc, ry, rx, rc_block = s[conv].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)

    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi])

    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # cooperative fetching
    for load in [AA, WW]:
        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # double buffer
    cfg.define_knob("AA_double_buffer", [0, 1])
    cfg.define_knob("WW_double_buffer", [0, 1])
    if cfg["AA_double_buffer"].val:
        s[AA].double_buffer()
    if cfg["WW_double_buffer"].val:
        s[WW].double_buffer()

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s
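The packed_data/packed_kernel tensors above use an NCHWc-style layout in which a small block of channels (4 for the int8 dp4a path) becomes the innermost dimension. A NumPy sketch of that layout transform, with made-up shapes:

import numpy as np

N, C, H, W, c_block = 1, 8, 4, 4, 4  # hypothetical shapes; c_block = 4 matches the dp4a path
x = np.arange(N * C * H * W, dtype=np.int8).reshape(N, C, H, W)

# NCHW -> NCHW[c]: split the channel axis into (C // c_block, c_block) and move
# the block to the innermost position so 4 int8 values are contiguous in memory.
packed = x.reshape(N, C // c_block, c_block, H, W).transpose(0, 1, 3, 4, 2)
print(packed.shape)  # (1, 2, 4, 4, 4)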
Example 9
def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
    """Schedule tensorcore template"""
    packed_data, packed_kernel = s[Conv].op.input_tensors
    ic, kh, kw, ii = s[Conv].op.reduce_axis
    pad_data = s[packed_data].op.input_tensors[0]

    block_x = te.thread_axis('blockIdx.x')
    block_y = te.thread_axis('blockIdx.y')
    block_z = te.thread_axis('blockIdx.z')
    thread_x = te.thread_axis('threadIdx.x')
    thread_y = te.thread_axis('threadIdx.y')
    thread_z = te.thread_axis('threadIdx.z')

    # Designate the memory hierarchy
    AS = s.cache_read(packed_data, 'shared', [Conv])
    WS = s.cache_read(packed_kernel, 'shared', [Conv])
    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
    ConvF = s.cache_write(Conv, 'wmma.accumulator')

    if Conv.op in s.outputs:
        output = Conv
        ConvS = s.cache_read(ConvF, 'shared', [Conv])
        OL = ConvS
    else:
        output = s.outputs[0].output(0)
        s[Conv].set_scope('shared')
        OL = Conv

    out_dtype = Conv.dtype

    if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel":
        if autotvm.GLOBAL_SCOPE.in_tuning:
            s[packed_kernel].pragma(
                s[packed_kernel].op.axis[0], "debug_skip_region")
        else:
            with Target('cuda'):
                schedule_injective_from_existing(s, packed_kernel)

    if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag:
        s[pad_data].compute_inline()
        data = pad_data.op.input_tensors[0]

        if autotvm.GLOBAL_SCOPE.in_tuning:
            # skip this part during tuning to make records accurate
            # this part will be pre-computed during NNVM's pre-compute optimization pass
            s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region")
    else:
        data = pad_data
        s[data].compute_inline()

    data_dtype = data.dtype
    kernel_dtype = packed_kernel.dtype

    # Schedule for autotvm
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("fuse_pack", [0, 1])
    cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
    cfg.define_knob("vector_ws", [1, 8])
    cfg.define_knob("vector_as", [1, 8, 16])

    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    warp_row_tiles = cfg["warp_row_tiles"].val
    warp_col_tiles = cfg["warp_col_tiles"].val
    chunk = cfg["chunk"].val
    vector_as = cfg["vector_as"].val
    vector_ws = cfg["vector_ws"].val
    split_block_k_nums = cfg["split_block_k_nums"].val
    fuse_pack = cfg["fuse_pack"].val

    if not fuse_pack:
        s[packed_data].compute_inline()
    else:
        with Target('cuda'):
            schedule_injective_from_existing(s, packed_data)

    if data_dtype in ['int4', 'uint4']:
        wmma_m = wmma_n = 8
        wmma_k = 32
    else:
        wmma_m = 8
        wmma_n = 32
        wmma_k = 16

    warp_size = 32

    # Schedule for output
    if len(s[output].op.axis) == 4:
        hc, wc, nc, oc, = output.op.axis
        nc, nnc = s[output].split(nc, factor=wmma_m)
        oc, ooc = s[output].split(oc, factor=wmma_n)
    else:
        hc, wc, nc, oc, nnc, ooc = output.op.axis

    kernel_scope, hc = s[output].split(hc, nparts=1)

    block_k = s[output].fuse(hc, wc)
    block_k, split_block_k = s[output].split(
        block_k, factor=split_block_k_nums)
    nc, nci = s[output].split(nc, factor=warp_row_tiles)
    block_i, nc = s[output].split(nc, factor=block_row_warps)
    oc, oci = s[output].split(oc, factor=warp_col_tiles)
    block_j, oc = s[output].split(oc, factor=block_col_warps)
    s[output].reorder(block_k, split_block_k, block_i,
                      block_j, nc, oc, nci, oci, nnc, ooc)
    t = s[output].fuse(nnc, ooc)
    _, tx = s[output].split(t, factor=warp_size)
    s[output].bind(block_k, block_z)
    s[output].bind(block_i, block_x)
    s[output].bind(block_j, block_y)
    s[output].bind(tx, thread_x)
    s[output].bind(nc, thread_y)
    s[output].bind(oc, thread_z)

    # Schedule wmma store
    s[OL].compute_at(s[output], block_j)
    hc, wc, nc, oc, nnc, ooc = OL.op.axis
    oc, oci = s[OL].split(oc, factor=warp_col_tiles)
    _, oc = s[OL].split(oc, factor=block_col_warps)
    nc, nci = s[OL].split(nc, factor=warp_row_tiles)
    _, nc = s[OL].split(nc, factor=block_row_warps)
    s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
    s[OL].bind(nc, thread_y)
    s[OL].bind(oc, thread_z)

    # Schedule local computation
    s[ConvF].compute_at(s[OL], oc)
    _, _, n, o, nnf, oof = ConvF.op.axis
    ko, ki = s[ConvF].split(ic, factor=chunk)
    s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

    cfg.define_reorder("reorder_inner", [ko, kh], policy="all")
    cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
    cfg["reorder_inner"].apply(s, ConvF, [ki, kw])

    cfg.define_knob("compute_at_AS", [0, 1, 2, 3])
    cfg.define_knob("compute_at_WS", [0, 1, 2, 3])
    compute_at_AS = cfg["compute_at_AS"].val
    compute_at_WS = cfg["compute_at_WS"].val

    # Move intermediate computation into each output compute tile
    s[AF].compute_at(s[ConvF], kw)
    s[WF].compute_at(s[ConvF], kw)

    # Schedule for A's shared memory
    if compute_at_AS == 0:
        s[AS].compute_at(s[ConvF], ki)
    elif compute_at_AS == 1:
        s[AS].compute_at(s[ConvF], kw)
    elif compute_at_AS == 2:
        s[AS].compute_at(s[ConvF], ko)
    else:
        s[AS].compute_at(s[ConvF], kh)
    _, _, n, _, nn, ii = AS.op.axis
    tx, xo = s[AS].split(n, nparts=block_row_warps)
    ty, _ = s[AS].split(xo, nparts=block_col_warps)
    t = s[AS].fuse(nn, ii)
    to, ti = s[AS].split(t, nparts=warp_size)
    ti, _t = s[AS].split(ti, factor=vector_as)
    s[AS].bind(tx, thread_y)
    s[AS].bind(ty, thread_z)
    s[AS].bind(to, thread_x)
    s[AS].vectorize(_t)

    # Schedule for W's shared memory
    if compute_at_WS == 0:
        s[WS].compute_at(s[ConvF], ki)
    elif compute_at_WS == 1:
        s[WS].compute_at(s[ConvF], kw)
    elif compute_at_WS == 2:
        s[WS].compute_at(s[ConvF], ko)
    else:
        s[WS].compute_at(s[ConvF], kh)
    kh, kw, ic, o, ii, oo = WS.op.axis
    tx, xo = s[WS].split(o, nparts=block_row_warps)
    ty, _ = s[WS].split(xo, nparts=block_col_warps)
    t = s[WS].fuse(ii, oo)
    to, ti = s[WS].split(t, nparts=warp_size)
    ti, _t = s[WS].split(ti, factor=vector_ws)
    s[WS].bind(tx, thread_y)
    s[WS].bind(ty, thread_z)
    s[WS].bind(to, thread_x)
    s[WS].vectorize(ti)

    # double buffer
    cfg.define_knob('AS_double_buffer', [0, 1])
    cfg.define_knob('WS_double_buffer', [0, 1])
    if cfg['AS_double_buffer'].val:
        s[AS].double_buffer()
    if cfg['WS_double_buffer'].val:
        s[WS].double_buffer()

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                     cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', False)

    shape = (wmma_m, wmma_n, wmma_k)

    AS_shape = (wmma_m, wmma_k)
    AL_shape = (wmma_m, wmma_k)
    WS_shape = (wmma_n, wmma_k)
    WL_shape = (wmma_n, wmma_k)
    CL_shape = (wmma_m, wmma_n)
    CS_shape = (wmma_m, wmma_n)

    AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype)
    WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k")
    CL_compute = te.compute(CL_shape, lambda ii, jj:
                            te.sum((AL_gemm[ii, k_gemm].astype(
                                'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm),
                            name='C')

    AL_strides = [wmma_k, 1]
    AS_strides = [wmma_k, 1]
    WL_strides = [wmma_k, 1]
    WS_strides = [wmma_k, 1]
    CL_strides = [wmma_n, 1]
    CS_strides = [wmma_n, 1]

    s[AF].tensorize(AF.op.axis[-2],
                    intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
                                              "row_major", AS_shape, AL_shape, data_dtype))

    s[WF].tensorize(WF.op.axis[-2],
                    intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
                                              "col_major", WS_shape, WL_shape, kernel_dtype))

    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
                                                  shape, out_dtype, CL_shape, CS_shape))

    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
                                             WL_strides, CL_strides, shape))

    return s
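The CL_compute declaration above is the per-fragment GEMM that intrin_wmma_gemm tensorizes. A NumPy reference of that fragment computation (note that the B operand is indexed [jj, k], matching the col_major load of the kernel fragment):

import numpy as np

wmma_m, wmma_n, wmma_k = 8, 32, 16  # the int8 fragment shape chosen above
A = np.random.randint(-8, 8, size=(wmma_m, wmma_k)).astype(np.int8)
B = np.random.randint(-8, 8, size=(wmma_n, wmma_k)).astype(np.int8)

# C[ii, jj] = sum_k int32(A[ii, k]) * int32(B[jj, k]), accumulated in int32
C = A.astype(np.int32) @ B.astype(np.int32).T
print(C.shape)  # (8, 32)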
Example 10
def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
    """Schedule group conv2d NCHW direct template"""
    workload = conv.op.attrs["workload"]
    groups = get_const_int(workload[6])
    num_filters = get_const_int(conv.shape[1])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
    cfg.define_split("tile_n", n, num_outputs=4)
    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    target = tvm.target.Target.current()
    if target.kind.name in ["nvptx", "rocm"]:
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])

    pad_data, kernel = s[conv].op.input_tensors

    s[pad_data].compute_inline()

    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, "local")
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope("local")
        OL = conv

    # create cache stage
    AA = s.cache_read(pad_data, "shared", [OL])
    WW = s.cache_read(kernel, "shared", [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    kernel_scope, n = s[output].split(n, nparts=1)

    g, f = s[output].split(f, nparts=groups)
    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bg, vg = cfg["tile_g"].apply(s, output, g)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx,
                      ni, fi, yi, xi)
    s[output].bind(bn, te.thread_axis("blockIdx.z"))
    s[output].bind(s[output].fuse(bg, bf), te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vn, te.thread_axis("vthread"))
    s[output].bind(vg, te.thread_axis("vthread"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, te.thread_axis("threadIdx.z"))
        s[output].bind(tf, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[OL].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[OL].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
    ryo, ryi = cfg["tile_rx"].apply(s, OL, ry)
    rxo, rxi = cfg["tile_ry"].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step",
                     cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit",
                     cfg["unroll_explicit"].val)

    N, CO, OH, OW = get_const_tuple(output.shape)
    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
Example 11
def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
    """Schedule tensorcore template"""
    kh, kw, ic = s[Conv].op.reduce_axis
    out_dtype = Conv.dtype
    trans_paddata, kernel = s[Conv].op.input_tensors
    in_dtype = trans_paddata.dtype
    batch, _, _, _ = get_const_tuple(Conv.shape)
    _, _, _, out_channels = get_const_tuple(kernel.shape)
    paddata = s[trans_paddata].op.input_tensors

    # inline the pad and dtype transform
    s[trans_paddata].compute_inline()
    s[kernel].compute_inline()
    s[paddata[0]].compute_inline()

    # Designate the memory hierarchy
    AS = s.cache_read(trans_paddata, 'shared', [Conv])
    WS = s.cache_read(kernel, 'shared', [Conv])
    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
    ConvF = s.cache_write(Conv, 'wmma.accumulator')

    if Conv.op in s.outputs:
        output = Conv
        ConvS = s.cache_read(ConvF, 'shared', [Conv])
        OL = ConvS
    else:
        output = s.outputs[0].output(0)
        s[Conv].set_scope('shared')
        OL = Conv

    # Schedule for autotvm
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("warp_row_tiles", [1, 2, 4])
    cfg.define_knob("warp_col_tiles", [1, 2, 4])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("offset", [0, 8])
    cfg.define_knob("vector_width", [1, 2, 4, 8])

    if (batch % 16 == 0 and out_channels % 16 == 0):
        cfg.define_knob("wmma_m", [16, 8, 32])
    elif (batch % 8 == 0 and out_channels % 32 == 0):
        cfg.define_knob("wmma_m", [8, 16, 32])
    elif (batch % 32 == 0 and out_channels % 8 == 0):
        cfg.define_knob("wmma_m", [32, 16, 8])

    # fallback support
    target = tvm.target.Target.current()
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda')
        cfg.fallback_with_reference_log(ref_log)

    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    warp_row_tiles = cfg["warp_row_tiles"].val
    warp_col_tiles = cfg["warp_col_tiles"].val
    chunk = cfg["chunk"].val
    offset = cfg["offset"].val
    wmma_m = cfg["wmma_m"].val
    vector_width = cfg["vector_width"].val

    wmma_k = 16
    if wmma_m == 16:
        wmma_n = 16
    elif wmma_m == 8:
        wmma_n = 32
    elif wmma_m == 32:
        wmma_n = 8

    warp_size = 32

    block_x = te.thread_axis('blockIdx.x')
    block_y = te.thread_axis('blockIdx.y')
    block_z = te.thread_axis('blockIdx.z')
    thread_x = te.thread_axis('threadIdx.x')
    thread_y = te.thread_axis('threadIdx.y')
    thread_z = te.thread_axis('threadIdx.z')

    # Define the intrin strides
    def get_strides(extents):
        return [np.prod(extents[i:]).tolist() for i in range(len(extents))]

    AS_align = chunk * wmma_k + offset
    WS_align = warp_col_tiles * block_col_warps * wmma_n + offset
    block_factor_n = wmma_m * warp_row_tiles * block_row_warps
    block_factor_o = wmma_n * warp_col_tiles * block_col_warps
    CS_align = block_factor_o + offset
    AS_strides = get_strides([1, 1, AS_align, 1])
    AL_strides = get_strides([1, 1, wmma_k, 1])
    WS_strides = get_strides([WS_align, 1])
    WL_strides = get_strides([wmma_n * warp_col_tiles, 1])
    CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1])
    CS_strides = get_strides([1, 1, CS_align, 1])

    # Schedule for output
    nc, hc, wc, oc = output.op.axis
    block_k = s[output].fuse(hc, wc)
    s[output].bind(block_k, block_z)
    block_i, nc = s[output].split(nc, factor=block_factor_n)
    block_j, oc = s[output].split(oc, factor=block_factor_o)
    s[output].reorder(block_k, block_i, block_j, nc, oc)
    t = s[output].fuse(nc, oc)
    t, ti = s[output].split(t, factor=vector_width)
    t, tx = s[output].split(t, factor=warp_size)
    t, ty = s[output].split(t, factor=block_row_warps)
    t, tz = s[output].split(t, factor=block_col_warps)
    s[output].bind(block_i, block_x)
    s[output].bind(block_j, block_y)
    s[output].bind(tz, thread_z)
    s[output].bind(ty, thread_y)
    s[output].bind(tx, thread_x)
    s[output].vectorize(ti)

    # Schedule wmma store
    s[OL].compute_at(s[output], block_j)
    nc, hc, wc, oc = OL.op.axis
    s[OL].reorder(hc, wc, nc, oc)
    s[OL].storage_align(wc, CS_align - 1, CS_align)
    oc, ooc = s[OL].split(oc, factor=wmma_n)
    oc, oci = s[OL].split(oc, factor=warp_col_tiles)
    _, oc = s[OL].split(oc, factor=block_col_warps)
    nc, nnc = s[OL].split(nc, factor=wmma_m)
    nc, nci = s[OL].split(nc, factor=warp_row_tiles)
    _, nc = s[OL].split(nc, factor=block_row_warps)
    s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
    s[OL].bind(nc, thread_y)
    s[OL].bind(oc, thread_z)

    # Schedule wmma computation
    s[ConvF].compute_at(s[OL], oc)
    n, h, w, o = ConvF.op.axis
    n, nnf = s[ConvF].split(n, factor=wmma_m)
    o, oof = s[ConvF].split(o, factor=wmma_n)
    ic, ii = s[ConvF].split(ic, factor=wmma_k)
    ko, ki = s[ConvF].split(ic, factor=chunk)
    s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii)

    s[AF].compute_at(s[ConvF], ki)
    s[WF].compute_at(s[ConvF], ki)

    # Schedule wmma load
    n, h, w, i = AF.op.axis
    n, nn = s[AF].split(n, factor=wmma_m)
    i, ii = s[AF].split(i, factor=wmma_k)
    s[AF].reorder(n, i, nn, ii)

    kh, kw, i, o = WF.op.axis
    i, ii = s[WF].split(i, factor=wmma_k)
    o, oo = s[WF].split(o, factor=wmma_n)
    s[WF].reorder(o, i, oo)
    s[WF].reorder(i, o, ii, oo)

    s[WS].compute_at(s[ConvF], ko)
    s[AS].compute_at(s[ConvF], ko)

    # Schedule for data's shared memory
    n, h, w, i = AS.op.axis
    s[AS].reorder(h, w, n, i)
    s[AS].storage_align(w, AS_align - 1, AS_align)
    t = s[AS].fuse(n, i)
    t, ti = s[AS].split(t, factor=vector_width)
    t, tx = s[AS].split(t, factor=warp_size)
    t, ty = s[AS].split(t, factor=block_row_warps)
    _, tz = s[AS].split(t, factor=block_col_warps)
    s[AS].bind(ty, thread_y)
    s[AS].bind(tz, thread_z)
    s[AS].bind(tx, thread_x)
    s[AS].vectorize(ti)

    # Schedule for kernel's shared memory
    kh, kw, ic, o = WS.op.axis
    t = s[WS].fuse(ic, o)
    s[WS].storage_align(ic, WS_align - 1, WS_align)
    t, ti = s[WS].split(t, factor=vector_width)
    t, tx = s[WS].split(t, factor=warp_size)
    t, ty = s[WS].split(t, factor=block_row_warps)
    _, tz = s[WS].split(t, factor=block_col_warps)
    s[WS].bind(ty, thread_y)
    s[WS].bind(tz, thread_z)
    s[WS].bind(tx, thread_x)
    s[WS].vectorize(ti)

    shape = (wmma_m, wmma_n, wmma_k)

    # tensorize the wmma process
    AS_shape = (wmma_m, 1, 1, wmma_k)
    AL_shape = (wmma_m, 1, 1, wmma_k)
    WS_shape = (wmma_k, wmma_n)
    WL_shape = (wmma_k, wmma_n)
    CL_shape = (wmma_m, 1, 1, wmma_n)
    CS_shape = (wmma_m, 1, 1, wmma_n)

    AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype)
    WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k")
    CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj:
                            te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \
                                   WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm),
                            name='C')

    s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
                                                  "row_major", AS_shape, AL_shape, in_dtype))
    s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
                                                  "row_major", WS_shape, WL_shape, in_dtype))
    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
                                                  shape, out_dtype, CL_shape, CS_shape))
    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
                                             WL_strides, CL_strides, shape))

    N, OH, OW, CO = get_const_tuple(output.shape)
    KH, KW, CI, _ = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
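The get_strides helper above turns a list of extents into row-major strides: each stride is the product of the extents that follow. A quick standalone check:

import numpy as np

def get_strides(extents):  # same helper as above
    return [np.prod(extents[i:]).tolist() for i in range(len(extents))]

# e.g. with chunk = 2, wmma_k = 16, offset = 8: AS_align = 2 * 16 + 8 = 40
print(get_strides([1, 1, 40, 1]))  # [40, 40, 40, 1]
print(get_strides([16, 1]))        # [16, 1]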
Example 12
def gen_ir_2d(data, indices, updates, axis, out, update_func):
    """Generate scatter ir for 2d inputs

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices : tir.Tensor
        The index locations to update.

    updates : tir.Tensor
        The values to update.

    axis : int
        The axis to scatter on

    out : tir.Tensor
        The output tensor.

    update_func: function
        The function to be applied to a destination and the corresponding update

    Returns
    -------
    ret : tir
        The computational ir.
    """
    n = data.shape[0]
    c = data.shape[1]

    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)

    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)

    indices_ptr = ib.buffer_ptr(indices)
    updates_ptr = ib.buffer_ptr(updates)

    ni = indices.shape[0]
    ci = indices.shape[1]

    if axis == 0:
        with ib.new_scope():
            j = te.thread_axis("blockIdx.x")
            ib.scope_attr(j, "thread_extent", ci)
            with ib.for_range(0, ni, name="i") as i:
                idx = i * ci + j
                index = indices_ptr[idx]
                with ib.if_scope(index < 0):
                    update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
                with ib.else_scope():
                    update_func(out_ptr, index * c + j, updates_ptr[idx])
    else:
        with ib.new_scope():
            i = te.thread_axis("blockIdx.x")
            ib.scope_attr(i, "thread_extent", ni)
            with ib.for_range(0, ci, name="j") as j:
                idx = i * ci + j
                index = indices_ptr[idx]
                with ib.if_scope(index < 0):
                    update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
                with ib.else_scope():
                    update_func(out_ptr, i * c + index, updates_ptr[idx])
    return ib.get()
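A NumPy reference for the 2-D scatter above (conceptual only), taking update_func to be plain assignment; negative indices wrap around the scattered axis, which NumPy indexing does automatically:

import numpy as np

def scatter_2d_ref(data, indices, updates, axis):
    out = data.copy()
    ni, ci = indices.shape
    for i in range(ni):
        for j in range(ci):
            if axis == 0:
                out[indices[i, j], j] = updates[i, j]
            else:
                out[i, indices[i, j]] = updates[i, j]
    return out

data = np.zeros((3, 3))
indices = np.array([[0, -1, 2]])
updates = np.array([[1.0, 2.0, 3.0]])
print(scatter_2d_ref(data, indices, updates, axis=0))
# row 0 gets 1.0 in column 0; row 2 gets 2.0 in column 1 (wrapped -1) and 3.0 in column 2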
Example 13
    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
        ib = tvm.tir.ir_builder.create()

        data = ib.buffer_ptr(data_ptr)
        indices = ib.buffer_ptr(indices_ptr)
        updates = ib.buffer_ptr(updates_ptr)
        out = ib.buffer_ptr(out_ptr)

        # We combine all the indices dimensions but the first one into a single
        # dimension so we can iterate it in single loop instead of an arbitrary
        # number of loops. We do the same thing for all the update dimensions.
        fused_indices_dimension = 1
        for i in indices_ptr.shape[1:]:
            fused_indices_dimension *= i

        fused_updates_dimension = 1
        for i in updates_ptr.shape[len(indices_ptr.shape) - 1:]:
            fused_updates_dimension *= i

        fused_shape = 1
        for i in data_ptr.shape:
            fused_shape *= i

        # For now we avoid parallelizing over dimensions indexed by `indices` as
        # there may be repeated indices and handling parallel accumulation can
        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
        # work well when these dimensions are large enough to saturate memory
        # bandwidth, but performance will be bad when these dimensions are
        # small.
        bx = te.thread_axis("blockIdx.x")
        tx = te.thread_axis("threadIdx.x")
        max_threads = int(
            tvm.target.Target.current(allow_none=False).max_num_threads)
        tdim = min(max_threads, fused_updates_dimension)
        ib.scope_attr(tx, "thread_extent", tdim)
        bdim = ceil_div(fused_updates_dimension, tdim)
        ib.scope_attr(bx, "thread_extent", bdim)

        # Copy data into the output. This loop writes to the same portions of
        # memory as the following loop, so we do not need a memory sync.
        with ib.for_range(0,
                          ceil_div(fused_shape, fused_updates_dimension),
                          name="i") as i:
            index = i * fused_updates_dimension + bx * tdim + tx
            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
                out[index] = data[index]

        with ib.for_range(0, fused_indices_dimension) as i:
            j = bx * tdim + tx
            with ib.if_scope(j < fused_updates_dimension):
                offset = fused_updates_dimension
                index = j  # This is x_M, .. x_{N-1} part of the index into out.
                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
                # of the index into out.
                for l in reversed(range(indices_ptr.shape[0].value)):
                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ..., y_{K-1}]
                    index += offset * indices[i + l * fused_indices_dimension]
                    offset *= data_ptr.shape[l]
                if mode == "update":
                    out[index] = updates[i * fused_updates_dimension + j]
                elif mode == "add":
                    out[index] += updates[i * fused_updates_dimension + j]
                else:
                    raise NotImplementedError(
                        "scatter_nd mode not in [update, add]:", mode)

        return ib.get()
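A NumPy reference for the same scatter_nd semantics, as a sketch under the layout assumed by the IR above: indices has shape (M, Y_0, ..., Y_{K-1}) and updates has shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).

import numpy as np

def scatter_nd_ref(data, indices, updates, mode="update"):
    out = data.copy()
    m = indices.shape[0]
    idx = indices.reshape(m, -1)             # (M, fused_indices_dimension)
    upd = updates.reshape(idx.shape[1], -1)  # (fused_indices_dimension, fused_updates_dimension)
    for i in range(idx.shape[1]):
        pos = tuple(idx[:, i])               # the first M coordinates into `out`
        if mode == "update":
            out[pos] = upd[i].reshape(out[pos].shape)
        elif mode == "add":
            out[pos] += upd[i].reshape(out[pos].shape)
        else:
            raise NotImplementedError(mode)
    return out

data = np.zeros((4, 3))
indices = np.array([[0, 2]])                            # M = 1: update rows 0 and 2
updates = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
print(scatter_nd_ref(data, indices, updates))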
Example 14
def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
    """Generate scatter add ir for 1d inputs, using atomic_add instruction

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices : tir.Tensor
        The index locations to update.

    updates : tir.Tensor
        The values to update.

    axis : int
        The axis to scatter on

    out : tir.Tensor
        The output tensor.

    Returns
    -------
    ret : tir
        The computational ir.
    """
    assert axis == 0
    n = data.shape[0]

    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads

    with ib.new_scope():
        nthread_bx = ceil_div(n, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx
        with ib.if_scope(tid < n):
            out_ptr[tid] = data_ptr[tid]

    indices_ptr = ib.buffer_ptr(indices)
    updates_ptr = ib.buffer_ptr(updates)

    ni = indices.shape[0]

    atomic_add_return = ib.allocate(updates.dtype, (1, ),
                                    name="atomic_add_return",
                                    scope="local")

    with ib.new_scope():
        nthread_bx = ceil_div(ni, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx

        with ib.if_scope(tid < ni):
            index = indices_ptr[tid]
            with ib.if_scope(index < 0):
                atomic_add_return[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of",
                                        out_ptr[index + n]),
                    updates_ptr[tid],
                )
            with ib.else_scope():
                atomic_add_return[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of",
                                        out_ptr[index]),
                    updates_ptr[tid],
                )

    return ib.get()
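The atomic_add calls let duplicate indices accumulate correctly (at the cost of a non-deterministic floating-point summation order). The NumPy equivalent of this scatter-add, as a conceptual reference:

import numpy as np

data = np.zeros(5, dtype="float32")
indices = np.array([1, 3, 1, -1])   # duplicate index 1; -1 wraps to the last element
updates = np.array([10.0, 30.0, 11.0, 40.0], dtype="float32")

out = data.copy()
np.add.at(out, indices, updates)    # unbuffered add: duplicate indices accumulate
print(out)                          # [ 0. 21.  0. 30. 40.]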
Example 15
def _schedule_dense_int8(cfg, s, output):
    data, weight = s[output].op.input_tensors

    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)

    in_dim_factor = 4
    assert in_dim % in_dim_factor == 0, "Input dimension must be divisible by {}".format(in_dim_factor)
    if in_dim % 16 == 0:
        in_dim_factor = 16

    # create tuning space
    cfg.define_split("tile_y", batch, num_outputs=4)
    cfg.define_split("tile_x", out_dim, num_outputs=4)
    cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    # create cache stage
    AA = s.cache_read(data, "shared", [output])
    WW = s.cache_read(weight, "shared", [output])
    CC = s.cache_write(output, "local")

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    n, x = s[output].op.axis

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    ko = CC.op.reduce_axis[0]
    ko, ki = s[CC].split(ko, factor=4)
    ko, kt = cfg["tile_k"].apply(s, CC, ko)
    s[CC].tensorize(ki, _dp4a)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, n)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
    s[output].bind(by, te.thread_axis("blockIdx.y"))
    s[output].bind(bx, te.thread_axis("blockIdx.x"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))
    s[output].bind(ty, te.thread_axis("threadIdx.y"))
    s[output].bind(tx, te.thread_axis("threadIdx.x"))
    n_ty = cfg["tile_y"].size[2]
    n_tx = cfg["tile_x"].size[2]

    s[CC].compute_at(s[output], tx)
    yo, xo = CC.op.axis[:2]
    s[CC].reorder(ko, kt, yo, xo, ki)

    for load in [AA, WW]:
        s[load].compute_at(s[CC], ko)

        outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor)
        s[load].vectorize(inner)
        fused = s[load].op.axis[:-1] + [outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        s[load].bind(tx, te.thread_axis("threadIdx.x"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))

    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)
    return s
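# Not part of the schedule above: a small NumPy sketch of the 4-element int8
# dot product that the tensorized inner axis maps to (assuming `_dp4a` is the
# usual CUDA dp4a tensor intrinsic): four int8 pairs are multiplied and
# accumulated into an int32 accumulator. Values are illustrative.
import numpy as np

a = np.array([1, -2, 3, 4], dtype=np.int8)
b = np.array([5, 6, -7, 8], dtype=np.int8)
acc = np.int32(100)
acc = acc + np.dot(a.astype(np.int32), b.astype(np.int32))
assert acc == 104      # 100 + (5 - 12 - 21 + 32)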
Esempio n. 16
0
def schedule(attrs):
    cfg, s, output = attrs.auto_config, attrs.scheduler, attrs.outputs[0]
    th_vals, rd_vals = [attrs.get_extent(x) for x in output.op.axis], [
        attrs.get_extent(x) for x in output.op.reduce_axis
    ]

    C = output
    A, B = C.op.input_tensors

    y, x = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    # storage_align params
    factor = 16
    offset = 8

    layout = 'NN'
    for opt in attrs.options:
        if opt.startswith('layout/'):
            layout = opt[len('layout/'):]
            break
    '''
    if dtype == 'int8':
      factor = 32
      offset = 16
    '''
    # create cache stages
    AA = s.cache_read(A, "shared", [C])
    if (layout == "NN" or layout == "TN"):
        s[AA].storage_align(AA.op.axis[0], factor, offset)
    AL = s.cache_read(AA, "local", [C])
    BB = s.cache_read(B, "shared", [C])
    if (layout == "TT" or layout == "NT"):
        s[BB].storage_align(BB.op.axis[0], factor, offset)
    BL = s.cache_read(BB, "local", [C])
    CL = s.cache_write(C, "local")

    # autotvm search space definition
    cfg.define_knob("bx", [2, 4, 8])
    cfg.define_knob("by", [16, 32, 64])
    cfg.define_knob("step_k", [8, 16, 32])
    cfg.define_knob("v", [4, 8])
    by = cfg['by'].val
    bx = cfg['bx'].val
    step_k = cfg['step_k'].val
    v = cfg['v'].val

    # thread tile
    TX, TY = 8, 1

    # warp tile
    cfg.define_knob("warp_m", [16, 8, 32])
    warp_tile_m = cfg[
        'warp_m'].val  # it could be 8, 16, 32 on CUDA version >= 10.0
    warp_tile_k = 16  # it must be 16
    # block tile
    tile_x = bx * TX
    tile_y = by * TY

    yo, ty = s[C].split(y, tile_y)
    ty, yi = s[C].split(ty, TY)

    # schedule for C stage
    xo, xi = s[C].split(x, tile_x)
    WX = min(warp_tile_m, tile_x)
    tz, xi = s[C].split(xi, WX)
    tx, xi = s[C].split(xi, TX)
    s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
    s[C].bind(yo, te.thread_axis("blockIdx.y"))
    s[C].bind(xo, te.thread_axis("blockIdx.x"))
    s[C].bind(ty, te.thread_axis("threadIdx.y"))
    s[C].bind(tz, te.thread_axis("threadIdx.z"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))

    # schedule for CL stage
    ko, ki = s[CL].split(k, step_k * warp_tile_k)
    kl, ki = s[CL].split(ki, warp_tile_k)
    s[CL].compute_at(s[C], tx)
    yo, xo = CL.op.axis
    s[CL].reorder(ko, kl, ki, yo, xo)

    # schedule for AA stage
    s[AA].compute_at(s[CL], ko)
    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
    tx, vec = s[AA].split(tx, factor=v)
    fused = s[AA].fuse(s[AA].op.axis[0], xo)
    _, ty = s[AA].split(fused, factor=by)
    s[AA].bind(ty, te.thread_axis("threadIdx.y"))
    s[AA].bind(tz, te.thread_axis("threadIdx.z"))
    s[AA].bind(tx, te.thread_axis("threadIdx.x"))
    # vectorization is very important for float16/int8 inputs
    s[AA].vectorize(vec)

    # schedule for BB stage
    s[BB].compute_at(s[CL], ko)
    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
    tx, vec = s[BB].split(tx, factor=v)
    fused = s[BB].fuse(s[BB].op.axis[0], xo)
    _, ty = s[BB].split(fused, factor=by)
    s[BB].bind(ty, te.thread_axis("threadIdx.y"))
    s[BB].bind(tz, te.thread_axis("threadIdx.z"))
    s[BB].bind(tx, te.thread_axis("threadIdx.x"))
    s[BB].vectorize(vec)

    s[AL].compute_at(s[CL], kl)
    s[BL].compute_at(s[CL], kl)

    s[CL].pragma(ko, 'tensor_core')
Esempio n. 17
0
def rnn_matexp():
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    num_step = te.var("num_step")
    num_hidden = tvm.runtime.convert(n_num_hidden)
    batch_size = tvm.runtime.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = te.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = te.compute((1, batch_size, num_hidden),
                        lambda _, i, j: 1.0,
                        name="init")
    s_state = te.placeholder((num_step, batch_size, num_hidden))
    kh = te.reduce_axis((0, num_hidden), name="kh")
    s_update = te.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update",
    )
    s_scan = tvm.te.scan(s_init, s_update, s_state)
    # schedule
    s = te.create_schedule(s_scan.op)
    CL = s_update
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = te.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = te.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = te.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate.
    s[CL].set_store_predicate(thread_y.equal(0))

    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        with tvm.transform.PassContext(
                config={
                    "tir.UnrollLoop": {
                        "auto_max_step": 128,
                    },
                    "tir.detect_global_barrier": detect_global_barrier,
                }):
            f = tvm.build(s, [s_scan, Whh], target)
        dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0

        res_a = tvm.nd.array(res_np, dev)
        Whh_a = tvm.nd.array(Whh_np, dev)
        # Skip first pass as it is compilation
        f(res_a, Whh_a)
        dev.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        dev.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_cuda = res_a.asnumpy()
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_cuda[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" %
                              (i, j, res_cmp[i, 0, j], res_cuda[i, 0, j]))
            tvm.testing.assert_allclose(res_cuda, res_cmp, rtol=1e-3)

    check_device("cuda")
Esempio n. 18
0
    def _callback(op):
        if op.tag == "conv2d_transpose_nchw":
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, y, x = s[conv].op.axis
            rc = s[conv].op.reduce_axis[0]
            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
            cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

            target = tvm.target.Target.current()
            if target.kind.name in ["nvptx", "rocm"]:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])

            if cfg.is_fallback:
                N, F, Y, X = get_const_tuple(conv.shape)
                _fallback_schedule(N, F, Y, X)

            ##### space definition end #####

            if isinstance(kernel.op,
                          tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            if conv.op in s.outputs:
                output = conv
                OL = s.cache_write(conv, "local")
            else:
                output = s.outputs[0].output(0)
                s[conv].set_scope("local")
                OL = conv

            # create cache stage
            s[pad_data].set_scope("shared")
            AA = pad_data
            WW = s.cache_read(kernel, "shared", [OL])

            # tile and bind spatial axes
            n, f, y, x = s[output].op.axis
            kernel_scope, n = s[output].split(n, nparts=1)
            bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx,
                              ni, fi, yi, xi)
            s[output].bind(bn, te.thread_axis("blockIdx.z"))
            s[output].bind(bf, te.thread_axis("blockIdx.y"))
            s[output].bind(s[output].fuse(by, bx),
                           te.thread_axis("blockIdx.x"))
            s[output].bind(vn, te.thread_axis("vthread"))
            s[output].bind(vf, te.thread_axis("vthread"))
            s[output].bind(vy, te.thread_axis("vthread"))
            s[output].bind(vx, te.thread_axis("vthread"))

            cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf

            if cfg["fuse_yx"].val:
                s[output].bind(tn, te.thread_axis("threadIdx.z"))
                s[output].bind(tf, te.thread_axis("threadIdx.y"))
                tyx = s[output].fuse(ty, tx)
                s[output].bind(tyx, te.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tyx)

                # number of threads
                n_tz = cfg["tile_n"].size[2]
                n_ty = cfg["tile_f"].size[2]
                n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
            else:
                s[output].bind(s[output].fuse(tn, tf),
                               te.thread_axis("threadIdx.z"))
                s[output].bind(ty, te.thread_axis("threadIdx.y"))
                s[output].bind(tx, te.thread_axis("threadIdx.x"))
                s[OL].compute_at(s[output], tx)

                # number of threads
                n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
                n_ty = cfg["tile_y"].size[2]
                n_tx = cfg["tile_x"].size[2]

            # tile reduction axes
            n, f, y, x = s[OL].op.axis
            rc, ry, rx = s[OL].op.reduce_axis
            rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
            s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)

            s[AA].compute_at(s[OL], rx)
            s[WW].compute_at(s[OL], rx)

            # cooperative fetching
            for load in [AA, WW]:
                n, f, y, x = s[load].op.axis
                fused = s[load].fuse(f, y, x)
                tz, fused = s[load].split(fused, nparts=n_tz)
                ty, fused = s[load].split(fused, nparts=n_ty)
                tx, fused = s[load].split(fused, nparts=n_tx)
                s[load].bind(tz, te.thread_axis("threadIdx.z"))
                s[load].bind(ty, te.thread_axis("threadIdx.y"))
                s[load].bind(tx, te.thread_axis("threadIdx.x"))

            s[output].pragma(kernel_scope, "auto_unroll_max_step",
                             cfg["auto_unroll_max_step"].val)
            s[output].pragma(kernel_scope, "unroll_explicit",
                             cfg["unroll_explicit"].val)
Esempio n. 19
0
def test_rpc_module():
    # graph
    n = tvm.runtime.convert(1024)
    A = te.placeholder((n, ), name="A")
    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
    a_np = np.random.uniform(size=1024).astype(A.dtype)
    temp = utils.tempdir()

    # Establish remote connection with target hardware
    tracker = rpc.connect_tracker(tracker_host, tracker_port)
    remote = tracker.request(key, priority=0, session_timeout=60)

    # Compile the Graph for CPU target
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].parallel(xi)
    s[B].pragma(xo, "parallel_launch_point")
    s[B].pragma(xi, "parallel_barrier_when_finish")
    f = tvm.build(s, [A, B], target, name="myadd_cpu")
    path_dso_cpu = temp.relpath("cpu_lib.so")
    f.export_library(path_dso_cpu, ndk.create_shared)

    # Execute the portable graph on cpu target
    print("Run CPU test ...")
    dev = remote.cpu(0)
    remote.upload(path_dso_cpu)
    f2 = remote.load_module("cpu_lib.so")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
    time_f = f2.time_evaluator(f2.entry_name, dev, number=10)
    cost = time_f(a, b).mean
    print("%g secs/op\n" % cost)
    np.testing.assert_equal(b.numpy(), a.numpy() + 1)

    # Compile the Graph for OpenCL target
    if test_opencl:
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=64)
        s[B].bind(xi, te.thread_axis("threadIdx.x"))
        s[B].bind(xo, te.thread_axis("blockIdx.x"))
        # Build the dynamic lib.
        # If we don't want to do metal and only use cpu, just set target to be target
        f = tvm.build(s, [A, B],
                      tvm.target.Target("opencl", host=target),
                      name="myadd")
        path_dso_cl = temp.relpath("dev_lib_cl.so")
        f.export_library(path_dso_cl, ndk.create_shared)

        print("Run GPU(OpenCL Flavor) test ...")
        dev = remote.cl(0)
        remote.upload(path_dso_cl)
        f1 = remote.load_module("dev_lib_cl.so")
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
        time_f = f1.time_evaluator(f1.entry_name, dev, number=10)
        cost = time_f(a, b).mean
        print("%g secs/op\n" % cost)
        np.testing.assert_equal(b.numpy(), a.numpy() + 1)

    # Compile the Graph for Vulkan target
    if test_vulkan:
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=64)
        s[B].bind(xi, te.thread_axis("threadIdx.x"))
        s[B].bind(xo, te.thread_axis("blockIdx.x"))
        # Build the dynamic lib.
        # If we don't want to do metal and only use cpu, just set target to be target
        f = tvm.build(s, [A, B],
                      tvm.target.Target("vulkan", host=target),
                      name="myadd")
        path_dso_vulkan = temp.relpath("dev_lib_vulkan.so")
        f.export_library(path_dso_vulkan, ndk.create_shared)

        print("Run GPU(Vulkan Flavor) test ...")
        dev = remote.vulkan(0)
        remote.upload(path_dso_vulkan)
        f1 = remote.load_module("dev_lib_vulkan.so")
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
        time_f = f1.time_evaluator(f1.entry_name, dev, number=10)
        cost = time_f(a, b).mean
        print("%g secs/op\n" % cost)
        np.testing.assert_equal(b.numpy(), a.numpy() + 1)
Esempio n. 20
0
    def copy_to_texture(stage):
        _y, _x, _v = s[stage].op.axis
        # TODO(csullivan): removing this vectorize results in numerical errors, autovectorize
        s[stage].vectorize(_v)
        s[stage].bind(_y, te.thread_axis("blockIdx.x"))
        s[stage].bind(_x, te.thread_axis("threadIdx.x"))
Esempio n. 21
0
def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
    """Low level IR to get the ouput indices of valid boxes
    and the count of valid boxes

    Parameters
    ----------
    valid_boxes: Buffer
        2D Buffer  indicating valid boxes with shape [batch_size, num_anchors].

    Returns
    -------
    valid_count: Buffer
        1D Buffer of number of valid boxes per batch [batch_size].

    valid_indices: Buffer
        2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors].
    """
    batch_size = valid_boxes.shape[0]
    num_anchors = valid_boxes.shape[1]

    ib = tvm.tir.ir_builder.create()

    valid_boxes = ib.buffer_ptr(valid_boxes)

    valid_count = ib.buffer_ptr(valid_count)
    valid_indices = ib.buffer_ptr(valid_indices)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    with ib.if_scope(num_anchors > 0):
        # Copy boxes to valid_indices
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(num_anchors, max_threads)
            nthread_by = batch_size
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            by = te.thread_axis("blockIdx.y")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)
            ib.scope_attr(by, "thread_extent", nthread_by)
            tid = bx * nthread_tx + tx
            with ib.if_scope(tid < num_anchors):
                valid_indices[by, tid] = valid_boxes[by, tid]

        nthread_tx = max_threads
        nthread_bx = ceil_div(num_anchors, max_threads)
        nthread_by = batch_size

        ## The following algorithm performs parallel exclusive scan to get
        ## a tensor that can later be used to select valid indices
        # Up Sweep of exclusive scan
        lim = tvm.tir.generic.cast(
            tvm.tir.ceil(
                tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))),
            "int64")
        with ib.for_range(0, lim, dtype="int64") as l2_width:
            width = 2 << l2_width

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(
                        ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1, ),
                                    name="start",
                                    scope="local")
                middle = ib.allocate("int64", (1, ),
                                     name="middle",
                                     scope="local")
                end = ib.allocate("int64", (1, ), name="end", scope="local")
                start[0] = width * tid
                with ib.if_scope(start[0] < num_anchors):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.te.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        valid_indices[by * num_anchors + end[0] -
                                      1] += valid_indices[by * num_anchors +
                                                          middle[0] - 1]

        # Down Sweep of exclusive scan
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
                valid_indices[(bx + 1) * num_anchors - 1] = 0

        with ib.for_range(0, lim, dtype="int64") as l2_width:
            width = 2 << (lim - l2_width - 1)

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(
                        ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1, ),
                                    name="start",
                                    scope="local")
                middle = ib.allocate("int64", (1, ),
                                     name="middle",
                                     scope="local")
                end = ib.allocate("int64", (1, ), name="end", scope="local")
                tmp = ib.allocate("int32", (1, ), name="end", scope="local")
                start[0] = width * tid
                with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.tir.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        tmp[0] = valid_indices[by * num_anchors + middle[0] -
                                               1]
                        valid_indices[by * num_anchors + middle[0] -
                                      1] = valid_indices[by * num_anchors +
                                                         end[0] - 1]
                        valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
    with ib.else_scope():
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = 0

    return ib.get()
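# Not part of the IR above: a minimal NumPy sketch (single batch row,
# illustrative flags) of what the up-sweep / down-sweep passes compute: an
# exclusive prefix sum of the 0/1 valid_boxes flags plus the total count, so
# each valid box knows which output slot it maps to.
import numpy as np

valid_boxes_row = np.array([1, 0, 1, 1, 0, 1], dtype=np.int32)
valid_indices_row = np.concatenate(([0], np.cumsum(valid_boxes_row)[:-1]))
valid_count_row = int(valid_boxes_row.sum())
assert list(valid_indices_row) == [0, 1, 1, 2, 3, 3] and valid_count_row == 4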
Esempio n. 22
0
def schedule_conv2d_1x1_WCHNc_CRSKk(data, filt, packed_data, packed_filter,
                                    conv):
    # data: [W, C, H*N, c]
    # filter: [C, R*S*K, k]
    # output: [W, K, H, N, k]

    # conv2d([N, C, H, W, c], [1, 1, C, K, k])
    # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4)

    # data: (56, 128//4, 56*1, 4) = (56, 32, 56, 4)
    # filt: (128, 1*1*128//4, 4) = (128, 32, 4)
    # conv: (56, 32, 56, 1, 4)

    s = te.create_schedule(conv.op)
    cfg = autotvm.get_config()

    s[packed_data].compute_inline()
    s[packed_filter].compute_inline()
    A, B, C = packed_data, packed_filter, conv
    At = s.cache_read(A, "global.texture", [C])
    Bt = s.cache_read(B, "global.texture", [C])
    Al = s.cache_read(At, "local", [C])
    Bl = s.cache_read(Bt, "local", [C])
    Cl = s.cache_write(C, "local")

    def copy_to_texture(stage):
        axes = s[stage].op.axis
        fused = s[stage].fuse(*axes[:-1])
        block, thread = s[stage].split(fused, factor=32)
        s[stage].vectorize(axes[-1])
        s[stage].bind(block, te.thread_axis("blockIdx.x"))
        s[stage].bind(thread, te.thread_axis("threadIdx.x"))

    copy_to_texture(At)
    copy_to_texture(Bt)

    _w, _ko, _h, _n, _ki = s[C].op.axis
    kernel_scope, _n = s[C].split(_n, nparts=1)

    cfg.define_split("tile_f", _ko, num_outputs=4)
    cfg.define_split("tile_w", _w, num_outputs=4)
    cfg.define_split("tile_h", _h, num_outputs=4)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    bk, vk, tk, ki = cfg["tile_f"].apply(s, C, _ko)
    bw, vw, tw, wi = cfg["tile_w"].apply(s, C, _w)
    bh, vh, th, hi = cfg["tile_h"].apply(s, C, _h)
    s[C].reorder(bh, _n, vh, th, hi)
    bhn = s[C].fuse(bh, _n)

    s[C].bind(bk, te.thread_axis("blockIdx.z"))
    s[C].bind(bhn, te.thread_axis("blockIdx.y"))
    s[C].bind(bw, te.thread_axis("blockIdx.x"))
    s[C].bind(vk, te.thread_axis("vthread"))
    s[C].bind(vh, te.thread_axis("vthread"))
    s[C].bind(vw, te.thread_axis("vthread"))
    s[C].bind(tk, te.thread_axis("threadIdx.z"))
    s[C].bind(th, te.thread_axis("threadIdx.y"))
    s[C].bind(tw, te.thread_axis("threadIdx.x"))
    s[C].reorder(bw, bk, bhn, vw, vk, vh, tw, tk, th, ki, hi, wi, _ki)
    s[C].vectorize(_ki)

    # TODO(csullivan): Try uneven workgroup split
    # _wo, _wi = s[C].split(_w, factor=4)
    # #_hno, _hni = s[C].split(_hn, factor=8)
    # #s[C].reorder(_wo, _wi, _ko, _hno, _hni, _ki)
    # s[C].reorder(_wo, _ko, _hn, _ki, _wi)
    # s[C].unroll(_wi)

    # # mace:
    # # const int out_ch_blk = get_global_id(0);
    # # const int out_w_blk = get_global_id(1);
    # # const int out_hb = get_global_id(2);

    # bx = te.thread_axis("blockIdx.x")
    # by = te.thread_axis("blockIdx.y")
    # bz = te.thread_axis("blockIdx.z")
    # s[C].bind(_ko, bx)
    # s[C].bind(_wo, by)
    # s[C].bind(_hn, bz)

    # s[Cl].compute_at(s[C], _hn)
    s[Cl].compute_at(s[C], th)

    _wl, _kol, _hl, _nl, _kil = s[Cl].op.axis
    _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis

    cfg.define_split("tile_c", _cl, num_outputs=2)
    cfg.define_split("tile_kh", _khl, num_outputs=2)
    cfg.define_split("tile_kw", _kwl, num_outputs=2)

    _clo, _cli = cfg["tile_c"].apply(s, Cl, _cl)
    _khlo, _khli = cfg["tile_kh"].apply(s, Cl, _khl)
    _kwlo, _kwli = cfg["tile_kw"].apply(s, Cl, _kwl)
    # s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
    s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli, _kol, _hl, _nl,
                  _kil, _wl)
    # s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli)
    # s[Cl].reorder(_cl, _cl4, _kil, _wl)
    s[Cl].unroll(_cl4)
    s[Cl].unroll(_wl)
    s[Cl].vectorize(_kil)

    _wla, _cla, _hnla, _cl4a = s[Al].op.axis
    s[Al].compute_at(s[Cl], _cli)
    s[Al].vectorize(_cl4a)
    s[Al].unroll(_wla)

    _clb, _rskolb, _kilb = s[Bl].op.axis
    s[Bl].compute_at(s[Cl], _cli)
    s[Bl].vectorize(_kilb)
    s[Bl].unroll(_clb)

    s[C].pragma(kernel_scope, "auto_unroll_max_step",
                cfg["auto_unroll_max_step"].val)

    WO, K, HO, N, K4 = get_const_tuple(C.shape)
    RSC, _, _ = get_const_tuple(B.shape)
    cfg.add_flop(2 * N * K * K4 * HO * WO * RSC)

    return s
Esempio n. 23
0
def nms_ir(
    data,
    sorted_index,
    valid_count,
    indices,
    out,
    box_indices,
    num_valid_boxes,
    max_output_size,
    iou_threshold,
    force_suppress,
    top_k,
    coord_start,
    id_index,
    score_index,
    return_indices,
):
    """Low level IR routing for transform location in multibox_detection operator.

    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.

    sorted_index : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    indices : Buffer
        indices in original tensor, with shape [batch_size, num_anchors],
        represents the index of box in original data. It could be the third
        output out_indices of get_valid_counts. The values in the second
        dimension are like the output of arange(num_anchors) if get_valid_counts
        is not used before non_max_suppression.

    out : Buffer
        Output buffer, to be filled with sorted boxes.

    box_indices : Buffer
        An indices tensor mapping sorted indices to original indices.
        This is the first output of NMS when return_indices=True.

    num_valid_boxes : Buffer
        Record the number of boxes that have survived IOU tests.
        This is the second output of NMS when return_indices=True.

    max_output_size : int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : float
        Overlapping(IoU) threshold to suppress object with smaller score.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    top_k : int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : int
        Start index of the consecutive 4 coordinates.

    id_index : int
        Index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    return_indices : boolean
        Whether to return box indices in input data.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def get_boundaries(output, box_idx):
        l = tvm.te.min(
            output[box_idx],
            output[box_idx + 2],
        )
        t = tvm.te.min(
            output[box_idx + 1],
            output[box_idx + 3],
        )
        r = tvm.te.max(
            output[box_idx],
            output[box_idx + 2],
        )
        b = tvm.te.max(
            output[box_idx + 1],
            output[box_idx + 3],
        )
        return l, t, r, b

    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap of two boxes."""
        a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx)
        b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx)

        # Overlapping width and height
        w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
        h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t))

        # Overlapping area
        area = h * w

        # total area of the figure formed by box a and box b
        # except for overlapping area
        u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
        return tvm.tir.Select(u <= 0.0, 0.0, area / u)

    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    sorted_index = ib.buffer_ptr(sorted_index)
    valid_count = ib.buffer_ptr(valid_count)
    indices = ib.buffer_ptr(indices)
    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
    out = ib.buffer_ptr(out)
    box_indices = ib.buffer_ptr(box_indices)

    if isinstance(iou_threshold, float):
        iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
    top_k = tvm.tir.IntImm("int32", top_k)
    coord_start = tvm.tir.IntImm("int32", coord_start)
    id_index = tvm.tir.IntImm("int32", id_index)
    score_index = tvm.tir.IntImm("int32", score_index)
    force_suppress = tvm.tir.IntImm("int32", 1 if force_suppress else 0)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)

    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(num_anchors, max_threads)
        nthread_by = batch_size
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(by, "thread_extent", nthread_by)
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        i = by
        base_idx = i * num_anchors * box_data_length
        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
            # Reorder output
            nkeep = if_then_else(
                tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k,
                valid_count[i])
            j = bx * max_threads + tx
            with ib.if_scope(j < num_anchors):
                box_indices[i * num_anchors + j] = -1
            with ib.if_scope(j < nkeep):
                # Fill in out with sorted boxes
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + j * box_data_length + k)] = data[(
                        base_idx +
                        sorted_index[i * num_anchors + j] * box_data_length +
                        k)]
            with ib.else_scope():
                # Indices > nkeep are discarded
                with ib.if_scope(j < num_anchors):
                    with ib.for_range(0, box_data_length) as k:
                        out[(base_idx + j * box_data_length + k)] = -1.0
        with ib.else_scope():
            with ib.if_scope(j < valid_count[i]):
                with ib.for_range(0, box_data_length) as k:
                    offset = base_idx + j * box_data_length + k
                    out[offset] = data[offset]
                box_indices[i * num_anchors + j] = j

    with ib.new_scope():
        nthread_by = batch_size
        nthread_tx = max_threads

        by = te.thread_axis("blockIdx.y")
        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(by, "thread_extent", nthread_by)
        ib.scope_attr(tx, "thread_extent", nthread_tx)

        i = by

        base_idx = i * num_anchors * box_data_length
        num_valid_boxes_local = ib.allocate("int32", (1, ),
                                            name="num_valid_boxes_local",
                                            scope="local")
        num_valid_boxes_local[0] = 0
        nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]),
                             top_k, valid_count[i])

        def nms_inner_loop(ib, j):
            # The box j is valid, invalidate other boxes that overlap with j above iou_threshold

            # When return_indices is False, no need to populate box_indices
            if return_indices:
                with ib.if_scope(tx + 0 == 0):
                    orig_idx = sorted_index[i * num_anchors + j]
                    box_indices[i,
                                num_valid_boxes_local[0]] = indices[i,
                                                                    orig_idx]

            num_valid_boxes_local[0] += 1

            offset_j = j * box_data_length
            num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

            with ib.for_range(0, num_iter_per_thread) as _k:
                k = j + 1 + _k * nthread_tx + tx
                offset_k = k * box_data_length

                with ib.if_scope(
                        tvm.tir.all(
                            k < nkeep,
                            out[base_idx + offset_k + score_index] >
                            0,  # is the box k still valid?
                            tvm.tir.any(
                                force_suppress > 0,
                                id_index < 0,
                                out[base_idx + offset_k +
                                    id_index] == out[base_idx + offset_j +
                                                     id_index],
                            ),
                        )):
                    iou = calculate_overlap(
                        out,
                        base_idx + offset_j + coord_start,
                        base_idx + offset_k + coord_start,
                    )
                    with ib.if_scope(iou >= iou_threshold):
                        # invalidate the box k
                        out[base_idx + offset_k + score_index] = -1.0
                        with ib.if_scope(id_index >= 0):
                            out[base_idx + offset_k + id_index] = -1.0

                # Make sure to do the next loop in a lock step
                ib.emit(
                    tvm.tir.Call(None, "tir.tvm_storage_sync",
                                 tvm.runtime.convert(["shared"])))

        if isinstance(max_output_size, int):
            max_output_size = tvm.tir.const(max_output_size)

        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
            # Apply nms
            with ib.for_range(0, nkeep) as j:
                # Proceed to the inner loop if the box j is still valid
                with ib.if_scope(
                        out[base_idx +
                            (j * box_data_length) + score_index] > -1.0):
                    with ib.if_scope(max_output_size > 0):
                        # No need to do more iteration if we already reach max_output_size boxes
                        with ib.if_scope(
                                num_valid_boxes_local[0] < max_output_size):
                            nms_inner_loop(ib, j)
                    with ib.else_scope():
                        nms_inner_loop(ib, j)

            with ib.if_scope(tx + 0 == 0):
                num_valid_boxes[i] = num_valid_boxes_local[0]

        with ib.else_scope():
            num_valid_boxes[i] = 0

    return ib.get()
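# Not part of the IR above: a minimal NumPy sketch (illustrative boxes) of the
# IoU test that calculate_overlap performs on two boxes stored as four
# consecutive coordinates; corners may come in any order, hence the min/max.
import numpy as np

def iou(box_a, box_b):
    a_l, a_t = np.minimum(box_a[[0, 1]], box_a[[2, 3]])
    a_r, a_b = np.maximum(box_a[[0, 1]], box_a[[2, 3]])
    b_l, b_t = np.minimum(box_b[[0, 1]], box_b[[2, 3]])
    b_r, b_b = np.maximum(box_b[[0, 1]], box_b[[2, 3]])
    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
    area = w * h
    union = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
    return 0.0 if union <= 0.0 else area / union

assert abs(iou(np.array([0.0, 0.0, 2.0, 2.0]),
               np.array([1.0, 1.0, 3.0, 3.0])) - 1.0 / 7.0) < 1e-6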
Esempio n. 24
0
def schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output):
    """schedule optimized for batch size = 1"""

    conv = output.op.input_tensors[0]

    ##### space definition begin #####
    n, fc, y, x, fb = s[conv].op.axis
    ry, rx = s[conv].op.reduce_axis
    cfg.define_split("tile_fc", fc, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    pad_data, flattened_kernel = s[conv].op.input_tensors
    kernel = s[flattened_kernel].op.input_tensors[0]
    s[flattened_kernel].compute_inline()

    s[pad_data].compute_inline()
    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()
    kernel = flattened_kernel

    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, "local")
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope("local")
        OL = conv

    # create cache stage
    AT = s.cache_read(pad_data, "global.texture", [OL])
    WT = s.cache_read(kernel, "global.texture", [OL])

    def copy_to_texture(stage):
        axes = s[stage].op.axis
        fused = s[stage].fuse(*axes[:-1])
        block, thread = s[stage].split(fused, factor=32)
        s[stage].vectorize(axes[-1])
        s[stage].bind(block, te.thread_axis("blockIdx.x"))
        s[stage].bind(thread, te.thread_axis("threadIdx.x"))

    copy_to_texture(AT)
    copy_to_texture(WT)

    AA = s.cache_read(AT, "shared", [OL])
    WW = s.cache_read(WT, "shared", [OL])

    # tile and bind spatial axes
    n, fc, y, x, fb = s[output].op.axis

    kernel_scope, n = s[output].split(n, nparts=1)

    bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    bf = s[output].fuse(n, bf)
    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(by, te.thread_axis("blockIdx.y"))
    s[output].bind(bx, te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))
    s[output].bind(tf, te.thread_axis("threadIdx.z"))
    s[output].bind(ty, te.thread_axis("threadIdx.y"))
    s[output].bind(tx, te.thread_axis("threadIdx.x"))
    s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb)
    s[output].vectorize(fb)

    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, fc, y, x, fb = s[OL].op.axis

    ry, rx = s[OL].op.reduce_axis
    ryo, ryi = cfg["tile_ry"].apply(s, OL, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, OL, rx)

    s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb)
    s[OL].vectorize(fb)
    # s[OL].unroll()

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    # cooperative fetching
    for load in [AA, WW]:
        if load == WW:
            n, fyx, v = s[load].op.axis
            fused = s[load].fuse(n, fyx)
        else:
            n, f, y, x, v = s[load].op.axis
            fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))
        s[load].vectorize(v)

    # unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step",
                     cfg["auto_unroll_max_step"].val)

    N, OCC, OH, OW, OCB = get_const_tuple(output.shape)
    ICC, MKHKW, ICB = get_const_tuple(kernel.shape)
    M = (OCC * OCB) // (ICC * ICB)
    KHKW = MKHKW // M

    if isinstance(N, int):
        cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW)
Esempio n. 25
0
#
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/conv_gpu_blocking.png
#      :align: center
#      :height: 308px
#      :width: 317px
#

# tile consts
tile = 8
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2

# Get the GPU thread indices
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = te.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = te.thread_axis((0, vthread), "vthread", name="vy")

# Split the workloads
hi, wi, fi, ni = s[B].op.axis
bz = s[B].fuse(hi, wi)
by, fi = s[B].split(fi, factor=block_factor)
bx, ni = s[B].split(ni, factor=block_factor)

# Bind the iteration variables to GPU thread indices
s[B].bind(bz, block_z)
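# Not part of the snippet above: a quick plain-Python sanity check of the
# blocking constants, assuming the usual conv-GPU tutorial layout where each
# block covers block_factor outputs per axis, split across vthread virtual
# threads and num_thread threads.
tile, num_thread, vthread = 8, 8, 2
block_factor = tile * num_thread                       # 64 outputs per block per axis
per_thread = block_factor // (num_thread * vthread)    # outputs per thread per vthread
assert block_factor == 64 and per_thread == 4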
Esempio n. 26
0
    def mcpu_auto_schedule(s, output, prefix):
        hyper_params = [[-1, 2, 8, 4], [-1, 1, 512, 1]]
        slice_data, slice_reduce = [], []
        for i in range(len(output.op.axis)):
            slice_data.append(
                cfg.define_split(f"{prefix}:D{i}",
                                 attrs.get_extent(output.op.axis[i]),
                                 num_outputs=4,
                                 init_vals=[
                                     hyper_params[i % len(hyper_params)],
                                 ]))
        for i in range(len(output.op.reduce_axis)):
            slice_reduce.append(
                cfg.define_split(f"{prefix}:R{i}",
                                 attrs.get_extent(output.op.reduce_axis[i]),
                                 num_outputs=2,
                                 init_vals=[
                                     [-1, 4],
                                 ]))

        unroll = cfg.define_knob(f"{prefix}:UN", [1, 4, 8, 16, 32, 64],
                                 init_vals=[
                                     1,
                                 ] if attrs.backend == 'c-mcpu_avx512' else [
                                     0,
                                 ])

        output_local, = s.cache_write([output], "local")

        slice_axes = []
        for i in range(len(output.op.axis)):
            slice_axes.append(
                cfg.apply_split(s, output_local, output_local.op.axis[i],
                                slice_data[i]))

        if output.op.reduce_axis:
            reduce_at = cfg.define_knob(
                f"{prefix}:RA", [x for x in range(len(output.op.reduce_axis))],
                init_vals=[
                    0,
                ])
            output_local_K_o, output_local_K_i = cfg.apply_split(
                s, output_local, output_local.op.reduce_axis[reduce_at],
                slice_reduce[reduce_at])
            output_local_K_o, output_local_K_i = [output_local_K_o
                                                  ], [output_local_K_i]
        else:
            output_local_K_o, output_local_K_i = [], []

        first, second, third, fourth = [x[0] for x in slice_axes], [
            x[1] for x in slice_axes
        ], [x[2] for x in slice_axes], [x[3] for x in slice_axes]
        s[output_local].reorder(*(first + second + output_local_K_o + third +
                                  output_local_K_i + fourth))

        slice_global_axes = []
        for i in range(len(output.op.axis)):
            if cfg.define_knob(f"{prefix}:_{i}", [False, True],
                               init_vals=[
                                   0,
                               ]):
                slice_global_axes.append(
                    cfg.apply_split(s, output, output.op.axis[i], [
                        -1, slice_data[i][1],
                        int(np.product(slice_data[i][2:]))
                    ]))
            else:
                slice_global_axes.append(
                    cfg.apply_split(
                        s, output, output.op.axis[i],
                        [-1, 1, int(np.product(slice_data[i][1:]))]))

        s[output].reorder(*([x[0] for x in slice_global_axes] +
                            [x[1] for x in slice_global_axes] +
                            [x[2] for x in slice_global_axes]))

        s[output_local].compute_at(s[output], slice_global_axes[-1][1])
        s[output].bind(s[output].fuse(*[x[0] for x in slice_global_axes]),
                       te.thread_axis('threadIdx.x'))

        s[output_local].pragma(first[0], "auto_unroll_max_step", unroll)
        s[output_local].pragma(first[0], "unroll_explicit", True)
        # s[output_local].vectorize(fourth[-1])
        s[output_local].unroll(fourth[-1])
Esempio n. 27
0
from tvm import te

batch = 128
in_channel = 64
in_size = 62

# cnhw -> nchw
A = te.placeholder((in_channel, batch, in_size, in_size), name='A')

A_ch = te.compute((batch, in_channel, in_size, in_size),
                  lambda n, c, h, w: A[c, n, h, w],
                  name='A_change')

s = te.create_schedule(A_ch.op)

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

blockdim, threaddim = 32, 31
n, c, h, w = s[A_ch].op.axis
hw = s[A_ch].fuse(h, w)
no, ni = s[A_ch].split(n, nparts=blockdim)
co, ci = s[A_ch].split(c, nparts=blockdim)
hwo, hwi = s[A_ch].split(hw, nparts=31 * 31)
s[A_ch].reorder(no, co, hwo, ni, ci, hwi)
s[A_ch].bind(no, block_y)
s[A_ch].bind(co, block_x)
s[A_ch].bind(hwo, thread_x)
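# Not part of the schedule above: the compute A_change is just a CNHW -> NCHW
# layout change; a NumPy sketch with the same semantics (tiny illustrative
# shapes) is:
import numpy as np

a_cnhw = np.random.rand(3, 2, 4, 4).astype("float32")   # (c, n, h, w)
a_nchw = a_cnhw.transpose(1, 0, 2, 3)                    # (n, c, h, w)
assert a_nchw[1, 2, 0, 3] == a_cnhw[2, 1, 0, 3]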
Esempio n. 28
0
def _schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2 ** x for x in range(1, 7)]
        n_thread_cand = [2 ** x for x in range(3, 7)]
        cfg.define_split(
            "tile_x",
            batch,
            num_outputs=4,
            filter=lambda x: (
                x.size[1] in vthread_cand
                and x.size[2] in n_thread_cand
                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
            ),
        )
        cfg.define_split(
            "tile_y",
            out_dim,
            num_outputs=4,
            filter=lambda x: (
                x.size[1] in vthread_cand
                and x.size[2] in n_thread_cand
                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
            ),
        )
        cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
    except IndexError:
        # Index error happens when no entities left after filtering, which was designed
        # to prune tuning space for better search efficiency.
        logger.debug("Tuning space was created without pruning due to unfit shapes")
        cfg.define_split("tile_x", batch, num_outputs=4)
        cfg.define_split("tile_y", out_dim, num_outputs=4)
        cfg.define_split("tile_k", in_dim, num_outputs=3)

    if cfg.is_fallback:
        if batch > 1:
            cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg["tile_k"] = SplitEntity([-1, 8, 1])
        else:
            cfg["tile_k"] = SplitEntity([-1, 1, 1])

    # Explicit memory access
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # Split and reorder computation
    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, te.thread_axis("blockIdx.y"))
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tyz, te.thread_axis("vthread"))
    s[C].bind(txz, te.thread_axis("vthread"))
    s[C].bind(ty, te.thread_axis("threadIdx.y"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))

    # Split reduction
    yo, xo = CC.op.axis
    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load
    num_thread_x = cfg["tile_x"].size[2]
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, te.thread_axis("threadIdx.y"))
    s[AA].bind(tx, te.thread_axis("threadIdx.x"))
    s[AA].double_buffer()

    # Schedule for B's shared memory load
    num_thread_y = cfg["tile_y"].size[2]
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_y)
    s[BB].bind(ty, te.thread_axis("threadIdx.y"))
    s[BB].bind(tx, te.thread_axis("threadIdx.x"))
    s[BB].double_buffer()
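# Not part of the schedule above: a tiny plain-Python sketch of the pruning
# filter used for "tile_x"/"tile_y": a 4-way split [x0, x1, x2, x3] is kept
# only when x1 is a vthread candidate, x2 a thread candidate, and x1*x2*x3 a
# whole block size. The splits checked here are illustrative.
block_cand = [64, 128]
vthread_cand = [2 ** x for x in range(1, 7)]
n_thread_cand = [2 ** x for x in range(3, 7)]

def keep(split):
    _, x1, x2, x3 = split
    return x1 in vthread_cand and x2 in n_thread_cand and (x1 * x2 * x3) in block_cand

assert keep([-1, 2, 16, 2])        # the fallback split used when batch > 1
assert not keep([-1, 2, 16, 3])    # 2 * 16 * 3 = 96 is not a block candidate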
Esempio n. 29
0
"""

import tvm
from tvm import te
import numpy as np

n = te.var('n')
m = te.var('m')

A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda i: A[i] * 2, name='B')

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))


"""

Results:

primfn(A_1: handle, B_1: handle) -> ()
  attr = {"tir.noalias": True, "global_symbol": "main"}
  buffers = {A: Buffer(A_2: handle, float32, [n: int32], [stride: int32], type="auto"),
             B: Buffer(B_2: handle, float32, [n], [stride_1: int32], type="auto")}
  buffer_map = {B_1: B, A_1: A} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 63), 64);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
Esempio n. 30
0
def gen_ir_4d(data, indices, updates, axis, out, update_func):
    """Generate scatter ir for 4d inputs

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices : tir.Tensor
        The index locations to update.

    updates : tir.Tensor
        The values to update.

    axis : int
        The axis to scatter on

    out : tir.Tensor
        The output tensor.

    update_func: function
        The function to be applied to a destination and the corresponding update

    Returns
    -------
    ret : tir
        The computational ir.
    """
    warp_size = tvm.target.Target.current(False).thread_warp_size

    n = data.shape[0]
    c = data.shape[1]
    h = data.shape[2]
    w = data.shape[3]

    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)
    _memcpy_ir(ib, out_ptr, data_ptr, data.shape)

    indices_ptr = ib.buffer_ptr(indices)
    updates_ptr = ib.buffer_ptr(updates)
    ni = indices.shape[0]
    ci = indices.shape[1]
    hi = indices.shape[2]
    wi = indices.shape[3]

    if axis == 0:
        with ib.new_scope():
            j = te.thread_axis("blockIdx.y")
            ib.scope_attr(j, "thread_extent", ci)
            k = te.thread_axis("blockIdx.z")
            ib.scope_attr(k, "thread_extent", hi)
            tx = te.thread_axis("threadIdx.x")
            ib.scope_attr(tx, "thread_extent", warp_size)
            with ib.for_range(0, ni, name="i") as i:
                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
                    l = l_ * warp_size + tx
                    with ib.if_scope(l < wi):
                        idx = ((i * ci + j) * hi + k) * wi + l
                        index = indices_ptr[idx]
                        with ib.if_scope(index < 0):
                            update_func(out_ptr,
                                        (((index + n) * c + j) * h + k) * w +
                                        l, updates_ptr[idx])
                        with ib.else_scope():
                            update_func(out_ptr,
                                        ((index * c + j) * h + k) * w + l,
                                        updates_ptr[idx])
    elif axis == 1:
        with ib.new_scope():
            i = te.thread_axis("blockIdx.x")
            ib.scope_attr(i, "thread_extent", ni)
            k = te.thread_axis("blockIdx.z")
            ib.scope_attr(k, "thread_extent", hi)
            tx = te.thread_axis("threadIdx.x")
            ib.scope_attr(tx, "thread_extent", warp_size)
            with ib.for_range(0, ci, name="j") as j:
                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
                    l = l_ * warp_size + tx
                    with ib.if_scope(l < wi):
                        idx = ((i * ci + j) * hi + k) * wi + l
                        index = indices_ptr[idx]
                        with ib.if_scope(index < 0):
                            update_func(out_ptr,
                                        ((i * c + (index + c)) * h + k) * w +
                                        l, updates_ptr[idx])
                        with ib.else_scope():
                            update_func(out_ptr,
                                        ((i * c + index) * h + k) * w + l,
                                        updates_ptr[idx])
    elif axis == 2:
        with ib.new_scope():
            i = te.thread_axis("blockIdx.x")
            ib.scope_attr(i, "thread_extent", ni)
            j = te.thread_axis("blockIdx.y")
            ib.scope_attr(j, "thread_extent", ci)
            tx = te.thread_axis("threadIdx.x")
            ib.scope_attr(tx, "thread_extent", warp_size)
            with ib.for_range(0, hi, name="k") as k:
                with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_:
                    l = l_ * warp_size + tx
                    with ib.if_scope(l < wi):
                        idx = ((i * ci + j) * hi + k) * wi + l
                        index = indices_ptr[idx]
                        with ib.if_scope(index < 0):
                            update_func(out_ptr, ((i * c + j) * h +
                                                  (index + h)) * w + l,
                                        updates_ptr[idx])
                        with ib.else_scope():
                            update_func(out_ptr,
                                        ((i * c + j) * h + index) * w + l,
                                        updates_ptr[idx])
    else:
        with ib.new_scope():
            i = te.thread_axis("blockIdx.x")
            ib.scope_attr(i, "thread_extent", ni)
            j = te.thread_axis("blockIdx.y")
            ib.scope_attr(j, "thread_extent", ci)
            k = te.thread_axis("blockIdx.z")
            ib.scope_attr(k, "thread_extent", hi)
            with ib.for_range(0, wi, name="l") as l:
                idx = ((i * ci + j) * hi + k) * wi + l
                index = indices_ptr[idx]
                with ib.if_scope(index < 0):
                    update_func(out_ptr,
                                ((i * c + j) * h + k) * w + (index + w),
                                updates_ptr[idx])
                with ib.else_scope():
                    update_func(out_ptr, ((i * c + j) * h + k) * w + index,
                                updates_ptr[idx])
    return ib.get()
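# Not part of the IR above: a minimal NumPy sketch (illustrative shapes,
# axis == 1, "update" mode) of the scatter semantics it implements. Each
# element of `updates` is written to `out` at the position given by `indices`
# along the scatter axis; NumPy's negative indexing wraps just like the
# `index + c` branch.
import numpy as np

data = np.zeros((2, 3, 2, 2), dtype="float32")
indices = np.random.randint(-3, 3, size=(2, 2, 2, 2))    # valid range for dim 3
updates = np.random.rand(2, 2, 2, 2).astype("float32")

out = data.copy()
ni, ci, hi, wi = indices.shape
for i in range(ni):
    for j in range(ci):
        for k in range(hi):
            for l in range(wi):
                out[i, indices[i, j, k, l], k, l] = updates[i, j, k, l]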