コード例 #1
0
def test_ir_transform():
    """Exercise IRTransform with pre-order and post-order callbacks on Call nodes."""
    builder = tvm.ir_builder.create()
    n = tvm.var("n")
    with builder.for_range(0, n, name="i") as i:
        with builder.for_range(0, 10, name="j") as j:
            test_a = tvm.call_extern("int32", "TestA", i * 3 + j * 1)
            # Both emitted calls take the same TestA call as their argument.
            for callee in ("TestB", "TestC"):
                builder.emit(tvm.call_extern("int32", callee, test_a))
    body = builder.get()

    def preorder(op):
        # TestC calls are answered with a constant 0; every other call yields None.
        return tvm.const(0, "int32") if op.name == "TestC" else None

    def postorder(op):
        assert isinstance(op, tvm.tir.Call)
        if op.name != "TestA":
            return op
        # TestA(x) is rewritten to TestB(x + 1).
        return tvm.call_extern("int32", "TestB", op.args[0] + 1)

    body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
    statements = tvm.tir.stmt_list(body.body.body)
    assert statements[0].value.args[0].name == "TestB"
    assert statements[1].value.value == 0
コード例 #2
0
ファイル: intrin.py プロジェクト: bddppq/tvm
 def instr(index):
     """Generate matrix-matrix multiply VTA instruction.

     For ``index`` 0 or 2 the ``VTAUopPush`` call receives pointers to
     ``dout``, ``dinp`` and ``dwgt``; for any other index only ``dout`` is
     passed and the remaining operands are 0.
     """
     irb = tvm.ir_builder.create()
     dev = env.dev
     irb.scope_attr(
         dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
     irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
     if index in (0, 2):
         uop_args = (
             0, 0,
             dout.access_ptr("rw", "int32"),
             dinp.access_ptr("r", "int32"),
             dwgt.access_ptr("r", "int32"),
             0, 0, 0,
         )
     else:
         uop_args = (0, 1, dout.access_ptr("rw", "int32"), 0, 0, 0, 0, 0)
     irb.emit(tvm.call_extern("int32", "VTAUopPush", *uop_args))
     return irb.get()
コード例 #3
0
    def _fold_outermost_loop(body):
        """Try to fold the outermost ``For`` loop of ``body`` into a uop loop.

        Returns ``(None, body, None)`` when a statement other than ``For`` or
        ``ProducerConsumer`` is met before a loop is found. Otherwise returns
        ``[begin, ret, end]``, where ``begin``/``end`` are the
        ``VTAUopLoopBegin``/``VTAUopLoopEnd`` extern calls and ``ret`` is the
        rewritten loop body.

        Raises
        ------
        ValueError
            If the ``VTAUopPush`` calls cannot be folded.
        RuntimeError
            If a call other than the expected ones is visited.
        """
        loop = body
        # Peel ProducerConsumer wrappers until the For statement is reached.
        while not isinstance(loop, tvm.stmt.For):
            if not isinstance(loop, (tvm.stmt.ProducerConsumer, )):
                return None, body, None
            loop = loop.body

        loop_var = loop.loop_var
        gemm_offsets = [None, None, None]
        # One-element list so the nested callbacks can flip the flag.
        fail = [False]

        def _post_order(op):
            assert isinstance(op, tvm.expr.Call)
            base_args = 2
            if op.name != "VTAUopPush":
                if op.name not in ("VTATLSCommandHandle",
                                   "tvm_thread_context"):
                    raise RuntimeError("unexpected op %s" % op)
                return op

            new_args = list(op.args[:base_args])
            for pos in range(3):
                detected = tvm.arith.DetectLinearEquation(
                    op.args[pos + base_args], [loop_var])
                if not detected:
                    fail[0] = True
                    return op
                if gemm_offsets[pos] is None:
                    # First push seen: remember its first detected term.
                    gemm_offsets[pos] = detected[0]
                elif not tvm.ir_pass.Equal(detected[0], gemm_offsets[pos]):
                    # Later pushes must agree with the remembered term.
                    fail[0] = True
                    return op
                new_args.append(detected[1])
            new_args.extend(op.args[base_args + 3:])
            return tvm.call_extern("int32", "VTAUopPush", *new_args)

        ret = tvm.ir_pass.IRTransform(loop.body, None, _post_order, ["Call"])

        if fail[0] or any(offset is None for offset in gemm_offsets):
            raise ValueError("Failed to fold the GEMM instructions..")

        def _visit(op):
            # Any remaining use of the loop variable prevents folding.
            if op.same_as(loop_var):
                fail[0] = True

        tvm.ir_pass.PostOrderVisit(ret, _visit)
        if fail[0]:
            raise ValueError("Failed to fold the GEMM instructions..")

        begin = tvm.call_extern("int32", "VTAUopLoopBegin",
                                loop.extent, *gemm_offsets)
        end = tvm.call_extern("int32", "VTAUopLoopEnd")
        return [begin, ret, end]
コード例 #4
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
    def _fold_outermost_loop(body):
        """Fold the outermost ``For`` loop around ``VTAUopPush`` calls.

        ``ProducerConsumer`` wrappers are skipped while searching for the
        loop; any other non-``For`` statement makes the function return
        ``(None, body, None)``. On success the result is the list
        ``[begin, ret, end]`` built from the ``VTAUopLoopBegin`` call, the
        transformed loop body and the ``VTAUopLoopEnd`` call.

        Raises
        ------
        ValueError
            If the push calls could not be folded.
        RuntimeError
            If an unexpected call is encountered in the loop body.
        """
        outer = body
        while not isinstance(outer, tvm.stmt.For):
            if not isinstance(outer, (tvm.stmt.ProducerConsumer,)):
                return None, body, None
            outer = outer.body

        loop_var = outer.loop_var
        gemm_offsets = [None, None, None]
        fail = [False]  # mutable cell shared with the nested callbacks

        def _mark_failed(op):
            fail[0] = True
            return op

        def _post_order(op):
            assert isinstance(op, tvm.expr.Call)
            base_args = 2
            if op.name != "VTAUopPush":
                if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
                    raise RuntimeError("unexpected op %s" % op)
                return op

            rebuilt = list(op.args[:base_args])
            for slot in range(3):
                found = tvm.arith.DetectLinearEquation(
                    op.args[slot + base_args], [loop_var])
                if not found:
                    return _mark_failed(op)
                if gemm_offsets[slot] is None:
                    gemm_offsets[slot] = found[0]
                elif not tvm.ir_pass.Equal(found[0], gemm_offsets[slot]):
                    return _mark_failed(op)
                rebuilt.append(found[1])
            rebuilt.extend(op.args[base_args+3:])
            return tvm.call_extern("int32", "VTAUopPush", *rebuilt)

        ret = tvm.ir_pass.IRTransform(
            outer.body, None, _post_order, ["Call"])

        def _visit(op):
            # The folded body must no longer reference the loop variable.
            if op.same_as(loop_var):
                fail[0] = True

        all_offsets_found = all(x is not None for x in gemm_offsets)
        if not fail[0] and all_offsets_found:
            tvm.ir_pass.PostOrderVisit(ret, _visit)
            if not fail[0]:
                loop_begin = tvm.call_extern(
                    "int32", "VTAUopLoopBegin", outer.extent, *gemm_offsets)
                loop_end = tvm.call_extern("int32", "VTAUopLoopEnd")
                return [loop_begin, ret, loop_end]
        raise ValueError("Failed to fold the GEMM instructions..")
コード例 #5
0
 def intrin_func(ins, outs):
     """Return the three extern calls for the batch-reduce GEMM kernel.

     The result is a 3-tuple. When ``ceil(in_channel / ifmblock) == rco`` it
     is ``(init_update, None, init_update)``; otherwise it is
     ``(init_update, reset, update)``.
     """
     # tvm call extern is used to interface to the libxsmm batch reduce
     # kernel gemm implementation; rco*r*s is the number of batches.
     fused_call = tvm.call_extern(
         "int32", "batch_reduce_kernel_init_update",
         ins[0].access_ptr("r"), ins[1].access_ptr("r"),
         outs[0].access_ptr("w"),
         rco * r * s, ofmblock, ifmblock, r, s,
         ifh_stride, ifw_stride, ofw, stride_width)
     reset_call = tvm.call_extern(
         "int32", "batch_reduce_kernel_init",
         outs[0].access_ptr("w"), ofmblock, ofw)
     update_call = tvm.call_extern(
         "int32", "batch_reduce_kernel_update",
         ins[0].access_ptr("r"), ins[1].access_ptr("r"),
         outs[0].access_ptr("w"),
         rco * r * s, ofmblock, ifmblock, ofw, stride_width,
         r, s, ifh_stride, ifw_stride)
     single_pass = math.ceil(in_channel / ifmblock) == rco
     if single_pass:
         return fused_call, None, fused_call
     return fused_call, reset_call, update_call
コード例 #6
0
ファイル: col2im.py プロジェクト: zhuyawen/akg
 def intrin_func(ins, outs):
     """Emit the col2img instruction sequence for one input/output pair.

     Emits ``set_fcol2img`` and ``vector_dup`` once, then one ``col2img``
     call per ``(kh, kw)`` kernel position, and returns the built statement.
     """
     sp = [
         # FMATRIX
         INPUT_W, INPUT_H, PAD_LEFT, PAD_RIGHT, PAD_TOP, PAD_BOTTOM,
         # Xm
         W_IDX_KERNEL, H_IDX_KERNEL, W_IDX, H_IDX, C1_IDX,
         # Xt
         STRIDE_W, STRIDE_H, KERNEL_W, KERNEL_H, DILATION_W, DILATION_H,
         JUMP_OFFSET, REPEAT_MODE, REPEAT_TIME,
     ]
     src = ins[0]
     dst = outs[0]
     builder = tvm.ir_builder.create()

     fcol2img, Xm, Xt = pack_args(sp)
     builder.emit(tvm.call_extern(dtype, "set_fcol2img", fcol2img))
     dup_call = tvm.call_extern(
         dtype, "vector_dup", dst.access_ptr("w"), 0,
         (INPUT_H * INPUT_W * 16) // 64, 1, 1, 8, 8)
     builder.emit(dup_call)

     call_idx = 0
     for kh in range(KERNEL_H):
         for kw in range(KERNEL_W):
             # Slots 6 and 7 hold the kernel W/H index for this position.
             sp[6:8] = [kw, kh]
             fcol2img, Xm, Xt = pack_args(sp)
             dst_ptr = dst.access_ptr("rw")
             src_ptr = src.access_ptr(
                 "r", offset=call_idx * 16 * INPUT_C0 * REPEAT_TIME)
             builder.emit(
                 tvm.call_extern(dtype, "col2img", dst_ptr, src_ptr, Xm, Xt))
             call_idx += 1
     return builder.get()
コード例 #7
0
    def add_debug(stmt):
        """Return a statement sequence of a ``VTASetDebugMode`` call, then ``stmt``."""
        set_debug_mode = tvm.call_extern("int32", "VTASetDebugMode",
                                         env.dev.command_handle, debug_flag)
        return tvm.make.stmt_seq(set_debug_mode, stmt)
コード例 #8
0
 def _post_order(op):
     """Rewrite a ``VTAUopPush`` call, recording its detected offsets.

     Sets ``fail[0]`` and returns ``op`` unchanged when detection fails or
     disagrees with previously recorded ``gemm_offsets``. Calls other than
     ``VTAUopPush`` are returned as-is if they are one of the two expected
     names, otherwise ``RuntimeError`` is raised.
     """
     assert isinstance(op, tvm.expr.Call)
     base_args = 2
     if op.name != "VTAUopPush":
         if op.name not in ("VTATLSCommandHandle",
                            "tvm_thread_context"):
             raise RuntimeError("unexpected op %s" % op)
         return op

     new_args = list(op.args[:base_args])
     for pos in range(3):
         detected = tvm.arith.DetectLinearEquation(op.args[pos + base_args],
                                                   [loop_var])
         if not detected:
             fail[0] = True
             return op
         if gemm_offsets[pos] is None:
             gemm_offsets[pos] = detected[0]
         elif not tvm.ir_pass.Equal(detected[0], gemm_offsets[pos]):
             fail[0] = True
             return op
         new_args.append(detected[1])
     new_args.extend(op.args[base_args + 3:])
     return tvm.call_extern("int32", "VTAUopPush", *new_args)
コード例 #9
0
    def mma_sync(inputs, outputs):
        """Return the ``(body, init, body)`` extern-call triple for wmma.

        ``inputs`` must unpack into exactly three buffers; the first two are
        read and the third is written by the ``wmma_call`` extern.
        """
        print(inputs)
        factor1_, factor2_, product_ = inputs
        # Not used below; kept so indexing behaviour on ``outputs`` is unchanged.
        schedule_ = outputs[0]

        wmma = tvm.call_extern(
            'float32', "wmma_call",
            factor1_.access_ptr("r"),   # matrix A
            factor2_.access_ptr("r"),   # matrix B
            product_.access_ptr("w"),   # matrix C
        )
        init_tile = tvm.call_extern('float32', "__INIT_TILE_WARP__")
        return wmma, init_tile, wmma
コード例 #10
0
 def intrin_func(ins, outs):
     """Emit a single ``vadd`` extern call over two inputs and one output."""
     lhs_ptr = ins[0].access_ptr("r")
     rhs_ptr = ins[1].access_ptr('r')
     out_ptr = outs[0].access_ptr('wr')
     builder = tvm.ir_builder.create()
     builder.emit(tvm.call_extern("float32", 'vadd', lhs_ptr, rhs_ptr, out_ptr))
     return builder.get()
コード例 #11
0
 def _post_order(op):
     """Redirect Load/Store to a ``<name>_ptr`` handle and bind it at Allocate.

     Returns ``None`` for an Allocate whose buffer var has no entry in
     ``rw_info``. Raises ``RuntimeError`` for any other node type.
     """
     def _pointer_for(buffer_var):
         # Create the replacement handle var the first time a buffer is seen.
         if buffer_var not in rw_info:
             rw_info[buffer_var] = tvm.var(buffer_var.name + "_ptr",
                                           "handle")
         return rw_info[buffer_var]

     if isinstance(op, tvm.expr.Load):
         return tvm.make.Load(op.dtype, _pointer_for(op.buffer_var), op.index)
     if isinstance(op, tvm.stmt.Store):
         return tvm.make.Store(_pointer_for(op.buffer_var), op.value, op.index)
     if not isinstance(op, tvm.stmt.Allocate):
         raise RuntimeError("not reached")

     buffer_var = op.buffer_var
     if buffer_var not in rw_info:
         return None
     new_var = rw_info[buffer_var]
     cpu_ptr = tvm.call_extern("handle", "VTABufferCPUPtr",
                               env.dev.command_handle, buffer_var)
     let_stmt = tvm.make.LetStmt(new_var, cpu_ptr, op.body)
     alloc = tvm.make.Allocate(buffer_var, op.dtype, op.extents,
                               op.condition, let_stmt)
     # The binding now lives inside this Allocate; forget the entry.
     del rw_info[buffer_var]
     return alloc
コード例 #12
0
 def meminfo_cache():
     """Build a ``MemoryInfo`` node whose head address is ``global_cache``."""
     head = tvm.call_extern("handle", "global_cache")
     return tvm.make.node("MemoryInfo",
                          unit_bits=8,
                          max_simd_bits=32,
                          max_num_bits=128,
                          head_address=head)
コード例 #13
0
ファイル: test_pass_storage_sync.py プロジェクト: gwli/tvm
 def meminfo_cache():
     """Return a ``MemoryInfo`` node with ``global_cache`` as head address."""
     fields = {
         "unit_bits": 8,
         "max_simd_bits": 32,
         "max_num_bits": 128,
         "head_address": tvm.call_extern("handle", "global_cache"),
     }
     return tvm.make.node("MemoryInfo", **fields)
コード例 #14
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
def cpu_access_rewrite(stmt_in):
    """Detect CPU access to VTA buffers and rewrite it to use a CPU pointer.

    A VTA buffer is an opaque handle that does not correspond to a CPU
    address. This pass rewrites every Load/Store so that it goes through
    the pointer returned by ``VTABufferCPUPtr``.

    Parameters
    ----------
    stmt_in : Stmt
        Input statement

    Returns
    -------
    stmt_out : Stmt
        Transformed statement
    """
    env = get_env()
    rw_info = {}

    def _cpu_ptr(buffer_var):
        return tvm.call_extern(
            "handle", "VTABufferCPUPtr",
            env.dev.command_handle, buffer_var)

    def _pointer_for(buffer_var):
        # Lazily create the replacement handle var for this buffer.
        if buffer_var not in rw_info:
            rw_info[buffer_var] = tvm.var(
                buffer_var.name + "_ptr", "handle")
        return rw_info[buffer_var]

    def _post_order(op):
        if isinstance(op, tvm.expr.Load):
            return tvm.make.Load(
                op.dtype, _pointer_for(op.buffer_var), op.index)
        if isinstance(op, tvm.stmt.Store):
            return tvm.make.Store(
                _pointer_for(op.buffer_var), op.value, op.index)
        if not isinstance(op, tvm.stmt.Allocate):
            raise RuntimeError("not reached")

        buffer_var = op.buffer_var
        if buffer_var not in rw_info:
            return None
        new_var = rw_info[buffer_var]
        let_stmt = tvm.make.LetStmt(new_var, _cpu_ptr(buffer_var), op.body)
        alloc = tvm.make.Allocate(
            buffer_var, op.dtype, op.extents, op.condition, let_stmt)
        # Bound inside this Allocate, so no outer binding is needed.
        del rw_info[buffer_var]
        return alloc

    stmt = tvm.ir_pass.IRTransform(
        stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
    # Buffers never bound by an Allocate get their binding at the top.
    for buffer_var, new_var in rw_info.items():
        stmt = tvm.make.LetStmt(new_var, _cpu_ptr(buffer_var), stmt)
    return stmt
コード例 #15
0
 def _body():
     """Build the statement that calls the ``gemv_update`` extern."""
     update_call = tvm.call_extern(
         "int32", "gemv_update",
         cc.access_ptr("w"),
         aa.access_ptr("r"),
         bb.access_ptr("r"),
         m, l, bb.strides[0])
     builder = tvm.ir_builder.create()
     builder.emit(update_call)
     return builder.get()
コード例 #16
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
 def _post_order(op):
     """Post-order callback that rewrites ``VTAUopPush`` calls.

     On failure ``fail[0]`` is set and ``op`` is returned untouched. A call
     that is neither ``VTAUopPush`` nor one of the two expected names raises
     ``RuntimeError``.
     """
     assert isinstance(op, tvm.expr.Call)
     base_args = 2

     def _give_up():
         fail[0] = True
         return op

     if op.name != "VTAUopPush":
         if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
             raise RuntimeError("unexpected op %s" % op)
         return op

     rebuilt = list(op.args[:base_args])
     for slot in range(3):
         found = tvm.arith.DetectLinearEquation(
             op.args[slot + base_args], [loop_var])
         if not found:
             return _give_up()
         if gemm_offsets[slot] is None:
             gemm_offsets[slot] = found[0]
         elif not tvm.ir_pass.Equal(found[0], gemm_offsets[slot]):
             return _give_up()
         rebuilt.append(found[1])
     rebuilt.extend(op.args[base_args+3:])
     return tvm.call_extern("int32", "VTAUopPush", *rebuilt)
コード例 #17
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
 def _post_order(op):
     """Rewrite Allocate/Load/Store nodes to use CPU pointer handles.

     Load and Store are rebuilt on a ``<name>_ptr`` handle var tracked in
     ``rw_info``. An Allocate with a tracked buffer var wraps its body in a
     ``LetStmt`` binding that var; an untracked one yields ``None``.
     """
     if isinstance(op, tvm.stmt.Allocate):
         buffer_var = op.buffer_var
         if buffer_var not in rw_info:
             return None
         new_var = rw_info[buffer_var]
         cpu_ptr = tvm.call_extern(
             "handle", "VTABufferCPUPtr",
             env.dev.command_handle, buffer_var)
         wrapped_body = tvm.make.LetStmt(new_var, cpu_ptr, op.body)
         alloc = tvm.make.Allocate(
             buffer_var, op.dtype, op.extents, op.condition, wrapped_body)
         del rw_info[buffer_var]
         return alloc

     is_load = isinstance(op, tvm.expr.Load)
     if not is_load and not isinstance(op, tvm.stmt.Store):
         raise RuntimeError("not reached")

     buffer_var = op.buffer_var
     if buffer_var not in rw_info:
         rw_info[buffer_var] = tvm.var(buffer_var.name + "_ptr", "handle")
     new_var = rw_info[buffer_var]
     if is_load:
         return tvm.make.Load(op.dtype, new_var, op.index)
     return tvm.make.Store(new_var, op.value, op.index)
コード例 #18
0
def cpu_access_rewrite(stmt_in):
    """Rewrite CPU accesses to VTA buffers so they use a CPU pointer.

    A VTA buffer is an opaque handle that does not correspond to an
    address on the CPU. This pass detects Load/Store accesses and makes
    them go through the pointer returned by ``VTABufferCPUPtr``.

    Parameters
    ----------
    stmt_in : Stmt
        Input statement

    Returns
    -------
    stmt_out : Stmt
        Transformed statement
    """
    env = get_env()
    rw_info = {}

    def _bind(new_var, buffer_var, body):
        # let new_var = VTABufferCPUPtr(command_handle, buffer_var) in body
        cpu_ptr = tvm.call_extern("handle", "VTABufferCPUPtr",
                                  env.dev.command_handle, buffer_var)
        return tvm.make.LetStmt(new_var, cpu_ptr, body)

    def _post_order(op):
        if isinstance(op, tvm.stmt.Allocate):
            buffer_var = op.buffer_var
            if buffer_var not in rw_info:
                return None
            let_stmt = _bind(rw_info[buffer_var], buffer_var, op.body)
            alloc = tvm.make.Allocate(buffer_var, op.dtype, op.extents,
                                      op.condition, let_stmt)
            del rw_info[buffer_var]
            return alloc

        is_load = isinstance(op, tvm.expr.Load)
        if not is_load and not isinstance(op, tvm.stmt.Store):
            raise RuntimeError("not reached")

        buffer_var = op.buffer_var
        if buffer_var not in rw_info:
            rw_info[buffer_var] = tvm.var(buffer_var.name + "_ptr",
                                          "handle")
        new_var = rw_info[buffer_var]
        if is_load:
            return tvm.make.Load(op.dtype, new_var, op.index)
        return tvm.make.Store(new_var, op.value, op.index)

    stmt = tvm.ir_pass.IRTransform(stmt_in, None, _post_order,
                                   ["Allocate", "Load", "Store"])
    # Entries still in rw_info were not bound at an Allocate; bind them here.
    for buffer_var, new_var in rw_info.items():
        stmt = _bind(new_var, buffer_var, stmt)
    return stmt
コード例 #19
0
ファイル: build_module.py プロジェクト: bddppq/tvm
    def add_debug(stmt):
        """Sequence a ``VTASetDebugMode`` extern call ahead of ``stmt``."""
        handle = env.dev.command_handle
        return tvm.make.stmt_seq(
            tvm.call_extern("int32", "VTASetDebugMode", handle, debug_flag),
            stmt)
コード例 #20
0
ファイル: tensorize.py プロジェクト: LANHUIYING/tvm
 def _body():
     """Return a statement that invokes the ``gemv_update`` extern."""
     builder = tvm.ir_builder.create()
     call_args = (
         cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"),
         m, l, bb.strides[0],
     )
     builder.emit(tvm.call_extern("int32", "gemv_update", *call_args))
     return builder.get()
コード例 #21
0
 def _reset():
     """Build the statement that calls the ``matmul_reset`` extern on ``cc``."""
     # C prototype of the callee:
     # int32_t matmul_reset(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
     #         int32_t pad_J, int32_t C_row_len);
     reset_call = tvm.call_extern(
         "int32", "matmul_reset",
         cc.access_ptr("w"), II, JJ, pad_I, pad_J, strideC)
     builder = tvm.ir_builder.create()
     builder.emit(reset_call)
     return builder.get()
コード例 #22
0
def test_buffer_access_ptr_offset():
    """Check the offset argument and access mask produced by ``access_ptr``."""
    m = tvm.var('m')
    n = tvm.var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32)

    def check(given, expected):
        # args[2] carries the offset, args[4] the read/write mask.
        aptr = Ab.access_ptr("rw", offset=given)
        offset = tvm.ir_pass.Simplify(aptr.args[2])
        assert tvm.ir_pass.Equal(offset, expected)
        assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

    check(100, 100)
    v = tvm.var('int32')
    check(100 + 100 + v, 200 + v)
    check(tvm.call_extern('int32', "test_call", 100 + 100 + v),
          tvm.call_extern('int32', "test_call", 200 + v))
コード例 #23
0
 def intrin_func(ins, outs):
     """Emit a ``gemv_update`` extern call for two inputs and one output."""
     aa, bb = ins
     cc = outs[0]
     update_call = tvm.call_extern(
         "int32", "gemv_update",
         cc.access_ptr("w"),
         aa.access_ptr("r"),
         bb.access_ptr("r"),
         m, l, bb.strides[0])
     builder = tvm.ir_builder.create()
     builder.emit(update_call)
     return builder.get()
コード例 #24
0
 def reset():
     """Build the statement that calls the sgemm reset kernel on ``cc``."""
     irb = tvm.ir_builder.create()
     # Kernel symbol is specialised on the M, N and ARCH values in scope.
     kernel_name = "sgemm_reset_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH)
     irb.emit(
         tvm.call_extern("int32", kernel_name,
                         irb.buffer_ptr(cc), cc.elem_offset, cc.strides[0]))
     return irb.get()
コード例 #25
0
def test_cce_loop_2():
    """LoopPartition followed by Simplify must leave no IfThenElse nodes."""
    builder = tvm.ir_builder.create()
    length = 112
    tile = 32
    loop = (length + tile - 1) // tile
    with builder.for_range(0, loop, 'i') as i:
        head = i * tile
        with builder.if_scope(builder.likely(head + tile > length)):
            # Last tile: clamp the tail to the total length.
            builder.emit(
                tvm.call_extern('float32', "cce_intrisic", head, length))
        with builder.else_scope():
            builder.emit(
                tvm.call_extern('float32', "cce_intrisic", head, head + tile))

    stmt = builder.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    if_hits = collect_visit(
        stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(if_hits)
コード例 #26
0
 def _finalize():
     """Build the statement that calls the ``matmul_finalize`` extern on ``cc``."""
     # Move out C from accumulator. C prototype of the callee:
     # int32_t matmul_finalize(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
     #         int32_t pad_J, int32_t C_row_len);
     finalize_call = tvm.call_extern(
         "int32", "matmul_finalize",
         cc.access_ptr("rw"), II, JJ, pad_I, pad_J, strideC)
     builder = tvm.ir_builder.create()
     builder.emit(finalize_call)
     return builder.get()
コード例 #27
0
ファイル: environment.py プロジェクト: LANHUIYING/tvm
 def __init__(self, env):
     """Set up the device context and register it on ``env``.

     ``env._dev_ctx`` is assigned before ``intrin.gemm`` is called, so the
     order of the last two statements is kept as in the original.
     """
     self.DEBUG_NO_SYNC = False
     self.vta_axis = tvm.thread_axis("vta")
     self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
     # Wrap the VTATLSCommandHandle extern in a tvm_thread_context intrinsic.
     tls_handle = tvm.call_extern("handle", "VTATLSCommandHandle")
     self.command_handle = tvm.make.Call(
         "handle", "tvm_thread_context", [tls_handle],
         tvm.expr.Call.Intrinsic, None, 0)
     env._dev_ctx = self
     self.gemm = intrin.gemm(env, env.mock_mode)
コード例 #28
0
ファイル: tensorize.py プロジェクト: LANHUIYING/tvm
 def intrin_func(ins, outs):
     """Return a statement calling ``gemv_update`` on the given buffers."""
     builder = tvm.ir_builder.create()
     aa, bb = ins
     cc = outs[0]
     call_args = (
         cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"),
         m, l, bb.strides[0],
     )
     builder.emit(tvm.call_extern("int32", "gemv_update", *call_args))
     return builder.get()
コード例 #29
0
def test_cce_loop_2():
    """Partitioning a tiled loop with a tail check must remove every IfThenElse."""
    total = 112
    tile = 32
    num_tiles = (total + tile - 1) // tile

    def intrinsic(head, tail):
        return tvm.call_extern('float32', "cce_intrisic", head, tail)

    ib = tvm.ir_builder.create()
    with ib.for_range(0, num_tiles, 'i') as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > total)):
            ib.emit(intrinsic(head, total))
        with ib.else_scope():
            ib.emit(intrinsic(head, head + tile))

    stmt = tvm.ir_pass.LoopPartition(ib.get(), True)
    stmt = tvm.ir_pass.Simplify(stmt)
    remaining_ifs = collect_visit(
        stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(remaining_ifs)
コード例 #30
0
ファイル: intrin.py プロジェクト: zhangquan920/tasn
 def instr(index):
     """Generate matrix-matrix multiply VTA instruction.

     ``index`` 0 and 2 push a uop that references ``dout``, ``dinp`` and
     ``dwgt``; every other index pushes a uop that references ``dout`` only.
     """
     irb = tvm.ir_builder.create()
     dev = env.dev
     for attr_key, attr_value in (
             ("coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)),
             ("coproc_uop_scope", dev.vta_push_uop)):
         irb.scope_attr(dev.vta_axis, attr_key, attr_value)

     uses_inputs = index == 0 or index == 2
     if uses_inputs:
         push = tvm.call_extern(
             "int32", "VTAUopPush", 0, 0,
             dout.access_ptr("rw", "int32"),
             dinp.access_ptr("r", "int32"),
             dwgt.access_ptr("r", "int32"),
             0, 0, 0)
     else:
         push = tvm.call_extern(
             "int32", "VTAUopPush", 0, 1,
             dout.access_ptr("rw", "int32"),
             0, 0, 0, 0, 0)
     irb.emit(push)
     return irb.get()
コード例 #31
0
 def __init__(self, env):
     """Initialise the device context and attach it to ``env`` as ``_dev_ctx``."""
     self.vta_axis = tvm.thread_axis("vta")
     self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
     self.command_handle = tvm.make.Call(
         "handle",
         "tvm_thread_context",
         [tvm.call_extern("handle", "VTATLSCommandHandle")],
         tvm.expr.Call.Intrinsic,
         None,
         0,
     )
     self.DEBUG_NO_SYNC = False
     # Register on env first; intrin.gemm receives env afterwards.
     env._dev_ctx = self
     self.gemm = intrin.gemm(env, env.mock_mode)
コード例 #32
0
 def body():
     """Build the statement that calls the sgemm compute kernel."""
     irb = tvm.ir_builder.create()
     kernel_name = "sgemm_compute_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH)
     # Each buffer is passed as (data pointer, element offset); C also
     # passes its first stride.
     call_args = (
         K,
         irb.buffer_ptr(aa), aa.elem_offset,
         irb.buffer_ptr(bb), bb.elem_offset,
         irb.buffer_ptr(cc), cc.elem_offset, cc.strides[0],
     )
     irb.emit(tvm.call_extern("int32", kernel_name, *call_args))
     return irb.get()
コード例 #33
0
def test_for():
    """CombineContextCall must leave two handle-typed bindings at the top of the body."""
    dev_type = tvm.var("dev_type")

    def device_context(dev_id):
        raw_ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
        return tvm.make.Call("handle", "tvm_thread_context", [raw_ctx],
                             tvm.expr.Call.Intrinsic, None, 0)

    builder = tvm.ir_builder.create()
    n = tvm.var("n")
    A = builder.allocate("float32", n, name="A", scope="global")
    with builder.for_range(0, n, name="i") as i:
        builder.emit(tvm.call_extern("int32", "fadd", device_context(0), A))
        with builder.for_range(0, 10, name="j") as j:
            # Inner loop uses device 1 first, then device 0 again.
            for dev_id in (1, 0):
                builder.emit(
                    tvm.call_extern("int32", "fadd", device_context(dev_id), A))

    func = tvm.ir_pass.MakeAPI(builder.get(), "func", [dev_type, n], 2, True)
    func = tvm.ir_pass.CombineContextCall(func)
    assert func.body.value.dtype == "handle"
    assert func.body.body.value.dtype == "handle"
コード例 #34
0
def Gemm_ir_wmma(A, B, C):
    """Build the IR statement for a wmma-based GEMM over ``A``, ``B`` and ``C``.

    The emitted extern calls are, in order: fragment definition, load of the
    A fragment, load of the B fragment, ``__syncthreads``, the wmma compute
    call, ``__syncthreads``, store of the C fragment, ``__syncthreads``.

    Parameters
    ----------
    A, B : passed unchanged to the ``__LOADFRAG_A__`` / ``__LOADFRAG_B__`` calls
    C : passed unchanged to the ``__STOREFRAG_C_F16__`` call

    Returns
    -------
    The statement produced by the IR builder.
    """
    ib = tvm.ir_builder.create()
    thread_x = tvm.thread_axis('threadIdx.x')
    ib.scope_attr(thread_x, 'thread_extent', num_thread)

    # The same sync call node is emitted at each barrier point below.
    sync = tvm.call_extern("float32", "__syncthreads")

    # define fragment
    ib.emit(tvm.call_extern("float32", "__FRAGMENT_F16__"))

    ib.emit(tvm.call_extern("float32", "__LOADFRAG_A__", A, 0, 16))
    ib.emit(tvm.call_extern("float32", "__LOADFRAG_B__", B, 0, 16))
    ib.emit(sync)

    ib.emit(tvm.call_extern("float32", "__WMMA_SYNC__", 0, 0))
    ib.emit(sync)

    ib.emit(tvm.call_extern("float32", "__STOREFRAG_C_F16__", C, 0, 0, 16))
    ib.emit(sync)

    return ib.get()
コード例 #35
0
 def intrin_func(ins, outs):
     """Emit one ``inst_conv`` extern call: output pointer, then input and filter."""
     inp, filt = ins
     outp = outs[0]
     conv_call = tvm.call_extern("int32", "inst_conv",
                                 outp.access_ptr("w"),
                                 inp.access_ptr("r"),
                                 filt.access_ptr("r"))
     builder = tvm.ir_builder.create()
     builder.emit(conv_call)
     return builder.get()
コード例 #36
0
def test_for():
    """Check that CombineContextCall yields handle-typed values at the top of the body."""
    dev_type = tvm.var("dev_type")

    def device_context(dev_id):
        ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
        return tvm.make.Call("handle", "tvm_thread_context", [ctx],
                             tvm.expr.Call.Intrinsic, None, 0)

    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    A = ib.allocate("float32", n, name="A", scope="global")

    def fadd_on(dev_id):
        # fadd call bound to the thread context of the given device id.
        return tvm.call_extern("int32", "fadd", device_context(dev_id), A)

    with ib.for_range(0, n, name="i") as i:
        ib.emit(fadd_on(0))
        with ib.for_range(0, 10, name="j") as j:
            ib.emit(fadd_on(1))
            ib.emit(fadd_on(0))

    body = ib.get()
    api_func = tvm.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
    combined = tvm.ir_pass.CombineContextCall(api_func)
    assert combined.body.value.dtype == "handle"
    assert combined.body.body.value.dtype == "handle"
コード例 #37
0
def ir_warp(A, B):
    """Build IR that copies ``A[i] + 1`` into ``B[i]`` for even thread indices.

    The thread extent is fixed at 10 and a ``__WMMA__`` extern call is
    emitted after the conditional store.
    """
    builder = tvm.ir_builder.create()
    src = builder.buffer_ptr(A)
    dst = builder.buffer_ptr(B)

    thread_idx = tvm.thread_axis('threadIdx.x')
    builder.scope_attr(thread_idx, 'thread_extent', 10)
    with builder.if_scope(thread_idx % 2 == 0):
        dst[thread_idx] = src[thread_idx] + 1

    builder.emit(tvm.call_extern("float32", "__WMMA__"))
    return builder.get()
コード例 #38
0
ファイル: test_pass_ir_transform.py プロジェクト: bddppq/tvm
def test_ir_transform():
    """IRTransform must apply both callbacks to the Call nodes of a nested loop."""
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            inner_call = tvm.call_extern("int32", "TestA", i * 3 + j * 1)
            ib.emit(tvm.call_extern("int32", "TestB", inner_call))
            ib.emit(tvm.call_extern("int32", "TestC", inner_call))
    original = ib.get()

    def preorder(op):
        # Only TestC gets a result here; the rest return None.
        if op.name != "TestC":
            return None
        return tvm.const(0, "int32")

    def postorder(op):
        assert isinstance(op, tvm.expr.Call)
        is_test_a = op.name == "TestA"
        # TestA(x) becomes TestB(x + 1); other calls pass through unchanged.
        return tvm.call_extern("int32", "TestB", op.args[0] + 1) if is_test_a else op

    transformed = tvm.ir_pass.IRTransform(
        original, preorder, postorder, ["Call"])
    stmt_list = tvm.make.stmt_list(transformed.body.body)
    assert stmt_list[0].value.args[0].name == "TestB"
    assert stmt_list[1].value.value == 0
コード例 #39
0
def test_cce_loop_3():
    """LoopPartition plus Simplify must remove the bound check of a 2-level loop."""
    inner_extent = 4
    outer_extent = 9998
    tile = 39991

    builder = tvm.ir_builder.create()
    with builder.for_range(0, outer_extent, 'i') as i:
        with builder.for_range(0, inner_extent, 'j') as j:
            in_bounds = builder.likely(i * inner_extent + j < tile)
            with builder.if_scope(in_bounds):
                builder.emit(tvm.call_extern('float16', "cce_intrisic", i))

    stmt = tvm.ir_pass.LoopPartition(builder.get(), True)
    stmt = tvm.ir_pass.Simplify(stmt)
    remaining_ifs = collect_visit(
        stmt, lambda node: isinstance(node, tvm.stmt.IfThenElse))
    assert not any(remaining_ifs)
コード例 #40
0
ファイル: matmul.py プロジェクト: zsangel378/mindspore
 def _body():
     """Build the statement that calls the extern named by ``opname``."""
     call_args = (
         cc.access_ptr("w"),
         aa.access_ptr("r"),
         bb.access_ptr("r"),
         ci_, vh_, vw_, vc_, core_id,
     )
     builder = tvm.ir_builder.create()
     builder.emit(tvm.call_extern("int32", opname, *call_args))
     return builder.get()
コード例 #41
0
def test_extern_call():
    """Build ``B[i] = TVMTestAddOne(A[i])`` and, if LLVM is enabled, check ``B == A + 1``."""
    length = 10
    A = tvm.placeholder((length,), name='A')

    def add_one(*i):
        return tvm.call_extern("float32", "TVMTestAddOne", A(*i))

    B = tvm.compute((length,), add_one, name='B')
    sched = tvm.create_schedule(B.op)

    # Nothing to run when the LLVM backend is not available.
    if not tvm.module.enabled("llvm"):
        return

    func = tvm.build(sched, [A, B], "llvm")
    ctx = tvm.cpu(0)
    # launch the kernel.
    host_in = np.random.uniform(size=length).astype(A.dtype)
    host_out = np.zeros(length, dtype=B.dtype)
    dev_in = tvm.nd.array(host_in, ctx)
    dev_out = tvm.nd.array(host_out, ctx)
    func(dev_in, dev_out)
    tvm.testing.assert_allclose(dev_out.asnumpy(), dev_in.asnumpy() + 1)
コード例 #42
0
ファイル: test_ext.py プロジェクト: LANHUIYING/tvm
def test_extern_call():
    """Compile a compute op that calls the ``TVMTestAddOne`` extern and verify its output.

    The check only runs when the "llvm" target is enabled; otherwise the test
    does nothing.
    """
    size = 10
    A = tvm.placeholder((size,), name='A')
    B = tvm.compute(
        (size,),
        lambda *i: tvm.call_extern("float32", "TVMTestAddOne", A(*i)),
        name='B')
    schedule = tvm.create_schedule(B.op)

    def check_llvm():
        if tvm.module.enabled("llvm"):
            kernel = tvm.build(schedule, [A, B], "llvm")
            device = tvm.cpu(0)
            # launch the kernel.
            inputs = tvm.nd.array(
                np.random.uniform(size=size).astype(A.dtype), device)
            outputs = tvm.nd.array(np.zeros(size, dtype=B.dtype), device)
            kernel(inputs, outputs)
            # Each output element must equal the matching input plus one.
            tvm.testing.assert_allclose(outputs.asnumpy(), inputs.asnumpy() + 1)

    check_llvm()
コード例 #43
0
 def get_vthread(name):
     """Build a loop body that uses two ``virtual_thread`` axes created from ``name``.

     Inside a loop of extent ``n`` the statement allocates a shared buffer
     ``B``, fills ``B[i]`` from pointer ``A``, emits an extern ``Run`` call on
     ``B`` and stores ``B[i] + 1`` into pointer ``C``.
     """
     builder = tvm.ir_builder.create()
     tx = tvm.thread_axis(name)
     ty = tvm.thread_axis(name)
     src = builder.pointer("float32", name="A")
     dst = builder.pointer("float32", name="C")
     with builder.for_range(0, n) as i:
         for axis in (tx, ty):
             builder.scope_attr(axis, "virtual_thread", nthread)
         shared = builder.allocate("float32", m, name="B", scope="shared")
         shared[i] = src[i * nthread + tx]
         shared_buf = tvm.decl_buffer(
             (m,), dtype=shared.dtype, data=shared.asnode())
         read_ptr = shared_buf.access_ptr("r")
         context_id = tvm.call_pure_intrin("int32", "tvm_context_id")
         builder.emit(tvm.call_extern("int32", "Run", read_ptr, context_id))
         dst[i * nthread + tx] = shared[i] + 1
     return builder.get()
コード例 #44
0
def test_cce_loop_3():
    """Verify that partitioning and simplifying a guarded loop nest removes every IfThenElse."""
    ib = tvm.ir_builder.create()
    inner, outer, tile = 4, 9998, 39991
    with ib.for_range(0, outer, 'i') as row:
        with ib.for_range(0, inner, 'j') as col:
            # Guard the extern call on the flattened index staying below `tile`.
            guard = ib.likely(row * inner + col < tile)
            with ib.if_scope(guard):
                ib.emit(tvm.call_extern('float16', "cce_intrisic", row))

    partitioned = tvm.ir_pass.LoopPartition(ib.get(), True)
    simplified = tvm.ir_pass.Simplify(partitioned)

    def is_branch(node):
        return isinstance(node, tvm.stmt.IfThenElse)

    assert not any(collect_visit(simplified, is_branch))
コード例 #45
0
 def get_vthread(name):
     """Return a statement with two ``virtual_thread`` scopes over thread axes named ``name``.

     The loop (extent ``n``) copies ``A[i * nthread + tx]`` into a shared
     allocation ``B``, emits an extern ``Run`` call taking a read pointer to
     ``B`` plus the ``tvm_context_id`` intrinsic, and writes ``B[i] + 1`` to ``C``.
     """
     tx = tvm.thread_axis(name)
     ty = tvm.thread_axis(name)
     ir = tvm.ir_builder.create()
     ptr_a = ir.pointer("float32", name="A")
     ptr_c = ir.pointer("float32", name="C")
     with ir.for_range(0, n) as i:
         ir.scope_attr(tx, "virtual_thread", nthread)
         ir.scope_attr(ty, "virtual_thread", nthread)
         local_b = ir.allocate("float32", m, name="B", scope="shared")
         local_b[i] = ptr_a[i * nthread + tx]
         # Wrap the allocation in a buffer so that access_ptr can be taken.
         buf_b = tvm.decl_buffer((m,), dtype=local_b.dtype, data=local_b.asnode())
         run_args = [
             buf_b.access_ptr("r"),
             tvm.call_pure_intrin("int32", "tvm_context_id"),
         ]
         ir.emit(tvm.call_extern("int32", "Run", *run_args))
         ptr_c[i * nthread + tx] = local_b[i] + 1
     return ir.get()
コード例 #46
0
 def _body():
     """Build the statement that calls the ``matmul_kernel`` extern.

     Arguments follow the kernel prototype:

         int32_t matmul_kernel(const elem_t *A, const elem_t *B, const acc_t *D,
                  elem_t *C, int32_t I, int32_t J, int32_t K, int32_t pad_I,
                  int32_t pad_J, int32_t pad_K, int32_t A_row_len,
                  int32_t B_row_len, int32_t D_row_len, int32_t C_row_len,
                  bool no_bias, bool repeating_bias);
     """
     kernel_args = [
         aa.access_ptr("r"),   # A
         bb.access_ptr("r"),   # B
         # D is set to a dummy address 1 to determine whether to overwrite
         # accumulator contents: on the first run, 1 will be retained and
         # overwrite the value in the accumulator; on subsequent runs D will
         # be replaced by NULL and C will accumulate on top of the
         # accumulator's contents. This is controlled via bit
         # 1 << (ADDR_LEN - 2) - see kernel source.
         1,                    # D
         cc.access_ptr("rw"),  # C
         II, JJ, KK,
         pad_I, pad_J, pad_K,
         strideA,              # A_row_len
         strideB,              # B_row_len
         0,                    # D_row_len
         strideC,              # C_row_len
         True,                 # no_bias
         False,                # repeating_bias
     ]
     builder = tvm.ir_builder.create()
     builder.emit(tvm.call_extern("int32", "matmul_kernel", *kernel_args))
     return builder.get()
コード例 #47
0
 def get_vthread(name):
     """Build a statement that runs an extern ``Run`` call on three shared buffers.

     Two thread axes are created from ``name`` and both are marked as
     ``virtual_thread`` scopes with ``nthread``. Shared allocations ``A``,
     ``B`` and ``C`` of size ``m`` are made; ``A[tx]`` and ``B[ty]`` are
     initialised, and ``Run`` receives read pointers to A and B and a
     read-write pointer to C.
     """
     tx = tvm.thread_axis(name)
     ty = tvm.thread_axis(name)
     builder = tvm.ir_builder.create()

     def shared_alloc(label):
         return builder.allocate("float32", m, name=label, scope="shared")

     def as_buffer(alloc):
         return tvm.decl_buffer((m,), dtype=alloc.dtype, data=alloc.asnode())

     # The loop variable itself is not used by the body.
     with builder.for_range(0, n):
         for axis in (tx, ty):
             builder.scope_attr(axis, "virtual_thread", nthread)
         A = shared_alloc("A")
         B = shared_alloc("B")
         C = shared_alloc("C")
         c_buf = as_buffer(C)
         a_buf = as_buffer(A)
         b_buf = as_buffer(B)
         A[tx] = tx + 1.0
         B[ty] = ty + 1.0
         run_call = tvm.call_extern(
             "int32", "Run",
             a_buf.access_ptr("r"),
             b_buf.access_ptr("r"),
             c_buf.access_ptr("rw"))
         builder.emit(run_call)
     return builder.get()
コード例 #48
0
ファイル: test_schedule_tensorize.py プロジェクト: bddppq/tvm
    def intrin_multivadd(n):
        """Declare a tensor intrinsic named ``multivadd`` over ``n`` strided ``vadd`` calls.

        Three float32 buffers of shape ``(n,)`` are declared, each with its own
        symbolic stride; the matched pattern is a compute op whose element ``i``
        is an extern ``vadd`` call on the i-th slice of every buffer.
        """
        def strided_buffer(stride_name):
            stride = tvm.var(stride_name)
            return stride, tvm.decl_buffer((n, ), tvm.float32, strides=[stride])

        n_a, Ab = strided_buffer("n_a")
        n_b, Bb = strided_buffer("n_b")
        n_c, Cb = strided_buffer("n_c")

        def vadd_at(i):
            return tvm.call_extern(
                "float32", 'vadd',
                Ab.access_ptr("w", offset=n_a*i),
                Bb.access_ptr("r", offset=n_b*i),
                Cb.access_ptr("r", offset=n_c*i))

        z = tvm.compute((n,), vadd_at)

        # replace the pattern with the multivadd call. I need to figure out
        # how to pass it the right parameters.
        def intrin_func(ins, outs):
            return tvm.call_packed("multivadd")

        with tvm.build_config():
            intrin = tvm.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
            return intrin
コード例 #49
0
ファイル: environment.py プロジェクト: LANHUIYING/tvm
def coproc_sync(op):
    """Return an extern ``VTASynchronize`` call; the ``op`` argument is not used.

    The call is given the environment device's command handle and ``1 << 31``.
    """
    _ = op
    device = get_env().dev
    sync_arg = 1<<31
    return tvm.call_extern("int32", "VTASynchronize", device.command_handle, sync_arg)
コード例 #50
0
 def _body():
     """Build the statement calling the ``test`` extern with a write pointer to ``cc`` and a read pointer to ``aa``."""
     dst_ptr = cc.access_ptr("w")
     src_ptr = aa.access_ptr("r")
     builder = tvm.ir_builder.create()
     builder.emit(tvm.call_extern("int32", "test", dst_ptr, src_ptr))
     return builder.get()
コード例 #51
0
ファイル: tensorize.py プロジェクト: LANHUIYING/tvm
 def _reduce_reset():
     """Build the statement calling the ``gemv_reset`` extern with a write pointer to ``cc`` and ``m``."""
     reset_call = tvm.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m)
     builder = tvm.ir_builder.create()
     builder.emit(reset_call)
     return builder.get()
コード例 #52
0
ファイル: environment.py プロジェクト: LANHUIYING/tvm
def coproc_dep_pop(op):
    """Return an extern ``VTADepPop`` call built from the first two arguments of ``op``.

    The environment device's command handle is passed as the leading argument.
    """
    handle = get_env().dev.command_handle
    first_arg, second_arg = op.args[0], op.args[1]
    return tvm.call_extern("int32", "VTADepPop", handle, first_arg, second_arg)
コード例 #53
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Turn a copy between a "global" buffer and an on-chip scope into a VTA 2D DMA call.

        Parameters
        ----------
        src, dst : buffers whose ``scope`` selects the direction of the copy.
        pad_before, pad_after : per-dimension padding; only accepted on loads.
        pad_value : currently ignored.

        Returns
        -------
        A statement wrapping either a ``VTAStoreBuffer2D`` call (``dst`` is
        "global") or a ``VTALoadBuffer2D`` call (``src`` is "global").

        Raises
        ------
        RuntimeError
            For padded stores and for unsupported scope combinations.
        ValueError
            For padding with an unsupported rank or padding on the innermost
            dimensions.
        """
        # FIXME: pad_value is ignored...
        _ = pad_value

        if dst.scope == "global":
            # Store
            if pad_before or pad_after:
                raise RuntimeError("Do not support copy into DRAM with pad")
            if src.scope != env.acc_scope:
                raise RuntimeError("Do not support copy %s->dram" % (src.scope))
            elem_width, elem_bytes = env.OUT_WIDTH, env.OUT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_OUT
            data_type = "int%d" % env.OUT_WIDTH
            task_qid = env.dev.QID_STORE_OUT
            _check_compact(src)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
            store = tvm.ir_builder.create()
            store.scope_attr(env.dev.vta_axis, "coproc_scope",
                             env.dev.get_task_qid(task_qid))
            store.emit(tvm.call_extern(
                "int32", "VTAStoreBuffer2D",
                env.dev.command_handle,
                src.access_ptr("r", "int32"),
                mem_type, dst.data, offset, x_size, y_size, x_stride))
            return store.get()

        if src.scope != "global":
            raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))

        # Load: pick the element layout and queue from the destination scope.
        if dst.scope == env.acc_scope:
            elem_width, elem_bytes = env.ACC_WIDTH, env.ACC_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_ACC, env.dev.QID_LOAD_OUT
        elif dst.scope == env.inp_scope:
            elem_width, elem_bytes = env.INP_WIDTH, env.INP_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_INP, env.dev.QID_LOAD_INP
        elif dst.scope == env.wgt_scope:
            elem_width, elem_bytes = env.WGT_WIDTH, env.WGT_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_WGT, env.dev.QID_LOAD_WGT
        else:
            raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
        data_type = "int%d" % elem_width

        # collect pad statistics
        if pad_before:
            assert pad_after
            ndim = len(pad_before)
            if not 2 < ndim <= 4:
                raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
            # Every dimension past the first two must carry no padding; walk
            # from the last axis inwards, checking "before" and then "after".
            for axis in range(ndim - 1, 1, -1):
                for pad in (pad_before, pad_after):
                    if not util.equal_const_int(pad[axis], 0):
                        raise ValueError("Do not support pad on the innermost block")
            y_pad_before, x_pad_before = pad_before[0], pad_before[1]
            y_pad_after, x_pad_after = pad_after[0], pad_after[1]
            allow_fold = False
        else:
            x_pad_before = y_pad_before = 0
            x_pad_after = y_pad_after = 0
            allow_fold = True

        _check_compact(dst)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            src, elem_width, elem_bytes, data_type,
            dst.scope, allow_fold=allow_fold)

        load = tvm.ir_builder.create()
        load.scope_attr(env.dev.vta_axis, "coproc_scope",
                        env.dev.get_task_qid(task_qid))
        load.emit(tvm.call_extern(
            "int32", "VTALoadBuffer2D",
            env.dev.command_handle,
            src.data, offset, x_size, y_size, x_stride,
            x_pad_before, y_pad_before,
            x_pad_after, y_pad_after,
            dst.access_ptr("r", "int32"), mem_type))
        return load.get()
コード例 #54
0
ファイル: ir_pass.py プロジェクト: bddppq/tvm
    def _do_fold(stmt):
        """Rewrite a statement matched by the "alu" pragma into VTA ALU micro-op calls.

        When ``_match_pragma(stmt, "alu")`` is falsy, ``stmt`` is returned
        unchanged. Otherwise the ``For`` nest under ``stmt.body`` is walked down
        to its innermost body, the ALU opcode and operands are derived from that
        body's ``value``, and a new statement made of ``VTAUopLoopBegin`` /
        ``VTAUopPush`` / ``VTAUopLoopEnd`` extern calls is returned.

        Raises
        ------
        RuntimeError
            If the stored value is a call other than ``shift_left`` /
            ``shift_right``, or an expression kind that is not handled.
        AssertionError
            If the access pattern does not match the expected layout.
        """
        def _equal(x, y):
            """Return whether ``x - y`` simplifies to the constant 0."""
            return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)

        def _flatten_loop(src_coeff, dst_coeff, extents):
            """Merge adjacent loop levels whose strides are contiguous.

            Works on copies of the inputs and returns
            ``(src_coeff, dst_coeff, extents)`` in the original ordering.
            """
            src_coeff = list(src_coeff)
            dst_coeff = list(dst_coeff)
            extents = list(extents)
            # The trailing coefficient is carried over as-is; it is popped
            # without a matching entry from `extents`.
            rev_src_coeff = [src_coeff.pop()]
            rev_dst_coeff = [dst_coeff.pop()]
            rev_extents = []
            assert src_coeff
            vsrc = src_coeff.pop()
            vdst = dst_coeff.pop()
            vext = extents.pop()
            # Walk from the last loop level towards the first one.
            while src_coeff:
                next_src = src_coeff.pop()
                next_dst = dst_coeff.pop()
                next_ext = extents.pop()

                # Two levels fold into one when the next stride equals the
                # current stride times the current extent, for both source
                # and destination.
                if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
                    vext = tvm.ir_pass.Simplify(vext * next_ext)
                else:
                    rev_src_coeff.append(vsrc)
                    rev_dst_coeff.append(vdst)
                    rev_extents.append(vext)
                    vsrc = next_src
                    vdst = next_dst
                    vext = next_ext
            rev_src_coeff.append(vsrc)
            rev_dst_coeff.append(vdst)
            rev_extents.append(vext)
            # The lists were built back-to-front; restore the original order.
            rev_src_coeff.reverse()
            rev_dst_coeff.reverse()
            rev_extents.reverse()

            return rev_src_coeff, rev_dst_coeff, rev_extents

        if _match_pragma(stmt, "alu"):
            # Get to the innermost loop body
            loop_body = stmt.body
            nest_size = 0
            while isinstance(loop_body, tvm.stmt.For):
                loop_body = loop_body.body
                nest_size += 1
            # Get the src/dst arguments
            dst_var = loop_body.buffer_var
            dst_idx = loop_body.index
            # Derive loop variables and extents
            tmp_body = stmt.body
            indices = []
            extents = []
            for _ in range(nest_size):
                indices.append(tmp_body.loop_var)
                extents.append(tmp_body.extent)
                tmp_body = tmp_body.body
            # Derive opcode
            if isinstance(loop_body.value, tvm.expr.Add):
                alu_opcode = env.dev.ALU_OPCODE_ADD
                lhs = loop_body.value.a
                rhs = loop_body.value.b
            elif isinstance(loop_body.value, tvm.expr.Sub):
                alu_opcode = env.dev.ALU_OPCODE_SUB
                lhs = loop_body.value.a
                rhs = loop_body.value.b
            elif isinstance(loop_body.value, tvm.expr.Mul):
                alu_opcode = env.dev.ALU_OPCODE_MUL
                lhs = loop_body.value.a
                rhs = loop_body.value.b
            elif isinstance(loop_body.value, tvm.expr.Min):
                alu_opcode = env.dev.ALU_OPCODE_MIN
                lhs = loop_body.value.a
                rhs = loop_body.value.b
            elif isinstance(loop_body.value, tvm.expr.Max):
                alu_opcode = env.dev.ALU_OPCODE_MAX
                lhs = loop_body.value.a
                rhs = loop_body.value.b
            elif isinstance(loop_body.value, tvm.expr.Call):
                if loop_body.value.name == 'shift_left':
                    # shift_left reuses the SHR opcode with a negated shift amount.
                    alu_opcode = env.dev.ALU_OPCODE_SHR
                    lhs = loop_body.value.args[0]
                    rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
                elif loop_body.value.name == 'shift_right':
                    alu_opcode = env.dev.ALU_OPCODE_SHR
                    lhs = loop_body.value.args[0]
                    rhs = loop_body.value.args[1]
                else:
                    raise RuntimeError(
                        "Function call not recognized %s" % (loop_body.value.name))
            elif isinstance(loop_body.value, tvm.expr.Load):
                # A plain load is expressed as SHR with a constant 0 operand.
                alu_opcode = env.dev.ALU_OPCODE_SHR
                lhs = loop_body.value
                rhs = tvm.const(0, "int32")
            else:
                raise RuntimeError(
                    "Expression not recognized %s, %s, %s" % (
                        type(loop_body.value), str(loop_body.value), str(stmt)))

            # Derive array index coefficients
            # NOTE(review): presumably one coefficient per loop variable followed
            # by a constant term — confirm against DetectLinearEquation.
            dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
            # Check if lhs/rhs is immediate
            use_imm = False
            imm_val = None
            if isinstance(rhs, tvm.expr.IntImm):
                assert lhs.buffer_var.same_as(dst_var)
                src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
                use_imm = True
                imm_val = rhs
            if isinstance(lhs, tvm.expr.IntImm):
                assert rhs.buffer_var.same_as(dst_var)
                src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
                use_imm = True
                imm_val = lhs
            if imm_val is None:
                # Neither operand is an immediate: both must read the
                # destination buffer.
                imm_val = 0
                assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
                src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
                src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
                # Determine which side has the same coefficients
                lhs_equal = True
                rhs_equal = True
                for i, coef in enumerate(dst_coeff):
                    if not tvm.ir_pass.Equal(coef, src_lhs_coeff[i]):
                        lhs_equal = False
                    if not tvm.ir_pass.Equal(coef, src_rhs_coeff[i]):
                        rhs_equal = False
                # Make sure at least one of the source is identical to the
                # destination (in-place computation)
                assert lhs_equal or rhs_equal
                # Assign the source coefficients
                # (the operand that does NOT alias the destination is the source)
                if lhs_equal:
                    src_coeff = src_rhs_coeff
                else:
                    src_coeff = src_lhs_coeff

            # Ensure that we have the proper tensor dimensions in the
            # innermost loop (pattern match)
            src_coeff = list(src_coeff)
            dst_coeff = list(dst_coeff)
            extents = list(extents)
            assert len(src_coeff) > 1
            assert len(dst_coeff) > 1
            assert len(extents) != 0
            assert tvm.ir_pass.Equal(
                tvm.ir_pass.Simplify(
                    src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
            assert tvm.ir_pass.Equal(
                tvm.ir_pass.Simplify(
                    dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
            assert tvm.ir_pass.Equal(src_coeff[-2], 1)
            assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
            if env.BATCH > 1:
                assert len(src_coeff) > 2
                assert len(dst_coeff) > 2
                assert len(extents) > 1
                assert tvm.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT)
                assert tvm.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT)

            # Apply tensorization of the loop coefficients
            # The innermost one (BATCH == 1) or two loop levels are dropped and
            # the trailing term is re-appended after the remaining coefficients.
            src_offset = src_coeff[-1]
            dst_offset = dst_coeff[-1]
            if env.BATCH == 1:
                src_coeff = src_coeff[:-2]
                dst_coeff = dst_coeff[:-2]
                extents = extents[:-1]
            else:
                src_coeff = src_coeff[:-3]
                dst_coeff = dst_coeff[:-3]
                extents = extents[:-2]
            src_coeff.append(src_offset)
            dst_coeff.append(dst_offset)
            # Rescale every coefficient by BATCH * BLOCK_OUT.
            src_coeff = [
                tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
            dst_coeff = [
                tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]

            # Flatten the outer loops
            if extents:
                src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)

            # Insert ALU micro-ops
            irb = tvm.ir_builder.create()
            for idx, extent in enumerate(extents):
                irb.emit(tvm.call_extern(
                    "int32", "VTAUopLoopBegin",
                    extent, dst_coeff[idx], src_coeff[idx], 0))
            # Pass the flag to the extern call as an int rather than a bool.
            use_imm = int(use_imm)
            irb.emit(tvm.call_extern(
                "int32", "VTAUopPush",
                1, 0,
                dst_coeff[len(dst_coeff)-1],
                src_coeff[len(src_coeff)-1],
                0,
                alu_opcode, use_imm, imm_val))
            # Emit one LoopEnd for every LoopBegin emitted above.
            for extent in extents:
                irb.emit(tvm.call_extern(
                    "int32", "VTAUopLoopEnd"))
            return irb.get()
        return stmt
コード例 #55
0
 def device_context(dev_id):
     """Wrap a ``device_context`` extern call for ``dev_id`` in a ``tvm_thread_context`` intrinsic call."""
     raw_ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
     call_type = tvm.expr.Call.Intrinsic
     return tvm.make.Call("handle", "tvm_thread_context", [raw_ctx], call_type, None, 0)
コード例 #56
0
ファイル: test_pass_ir_transform.py プロジェクト: bddppq/tvm
 def postorder(op):
     """Replace a ``TestA`` call with ``TestB(arg0 + 1)``; return any other call unchanged."""
     assert isinstance(op, tvm.expr.Call)
     if op.name != "TestA":
         return op
     bumped_arg = op.args[0] + 1
     return tvm.call_extern("int32", "TestB", bumped_arg)
コード例 #57
0
 def intrin_func(ins, outs):
     """Build the statement calling the ``vadd`` extern on the first two inputs and the first output.

     The call's dtype is taken from ``outs[0]``; inputs are passed as read
     pointers and the output with the ``'wr'`` access mask.
     """
     out = outs[0]
     vadd_call = tvm.call_extern(
         out.dtype, 'vadd',
         ins[0].access_ptr("r"),
         ins[1].access_ptr('r'),
         out.access_ptr('wr'))
     builder = tvm.ir_builder.create()
     builder.emit(vadd_call)
     return builder.get()