def test_ir_transform():
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            x = tvm.call_extern("int32", "TestA", i * 3 + j * 1)
            ib.emit(tvm.call_extern("int32", "TestB", x))
            ib.emit(tvm.call_extern("int32", "TestC", x))
    body = ib.get()

    def preorder(op):
        if op.name == "TestC":
            return tvm.const(0, "int32")
        return None

    def postorder(op):
        assert isinstance(op, tvm.tir.Call)
        if op.name == "TestA":
            return tvm.call_extern("int32", "TestB", op.args[0] + 1)
        return op

    body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
    stmt_list = tvm.tir.stmt_list(body.body.body)
    assert stmt_list[0].value.args[0].name == "TestB"
    assert stmt_list[1].value.value == 0

def instr(index):
    """Generate matrix-matrix multiply VTA instruction"""
    irb = tvm.ir_builder.create()
    dev = env.dev
    irb.scope_attr(dev.vta_axis, "coproc_scope",
                   dev.get_task_qid(dev.QID_COMPUTE))
    irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
    if index in (0, 2):
        irb.emit(tvm.call_extern(
            "int32", "VTAUopPush",
            0, 0,
            dout.access_ptr("rw", "int32"),
            dinp.access_ptr("r", "int32"),
            dwgt.access_ptr("r", "int32"),
            0, 0, 0))
    else:
        irb.emit(tvm.call_extern(
            "int32", "VTAUopPush",
            0, 1,
            dout.access_ptr("rw", "int32"),
            0, 0, 0, 0, 0))
    return irb.get()

def _fold_outermost_loop(body):
    stmt = body
    while not isinstance(stmt, tvm.stmt.For):
        if isinstance(stmt, (tvm.stmt.ProducerConsumer,)):
            stmt = stmt.body
        else:
            return None, body, None

    loop_var = stmt.loop_var
    gemm_offsets = [None, None, None]
    fail = [False]

    def _post_order(op):
        assert isinstance(op, tvm.expr.Call)
        base_args = 2
        if op.name == "VTAUopPush":
            args = []
            args += op.args[:base_args]
            for i in range(3):
                m = tvm.arith.DetectLinearEquation(
                    op.args[i + base_args], [loop_var])
                if not m:
                    fail[0] = True
                    return op
                if gemm_offsets[i] is not None:
                    if not tvm.ir_pass.Equal(m[0], gemm_offsets[i]):
                        fail[0] = True
                        return op
                    args.append(m[1])
                else:
                    gemm_offsets[i] = m[0]
                    args.append(m[1])
            args += op.args[base_args + 3:]
            return tvm.call_extern("int32", "VTAUopPush", *args)
        if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
            raise RuntimeError("unexpected op %s" % op)
        return op

    ret = tvm.ir_pass.IRTransform(stmt.body, None, _post_order, ["Call"])
    if not fail[0] and all(x is not None for x in gemm_offsets):
        def _visit(op):
            if op.same_as(loop_var):
                fail[0] = True
        tvm.ir_pass.PostOrderVisit(ret, _visit)
        if not fail[0]:
            begin = tvm.call_extern(
                "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
            end = tvm.call_extern("int32", "VTAUopLoopEnd")
            return [begin, ret, end]
    raise ValueError("Failed to fold the GEMM instructions.")

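# A minimal usage sketch for _fold_outermost_loop above (not from the original
# source; "loop_stmt" is a hypothetical For nest carrying VTAUopPush calls).
# On success the folded pieces are re-sequenced with the same legacy
# tvm.make.stmt_seq helper that appears elsewhere in this collection, which
# wraps bare call expressions in Evaluate nodes automatically.
def fold_gemm_loop_sketch(loop_stmt):
    begin, body, end = _fold_outermost_loop(loop_stmt)  # may raise ValueError
    if begin is None:
        return body  # no foldable outer loop was found
    return tvm.make.stmt_seq(begin, body, end)
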
def intrin_func(ins, outs):
    # tvm.call_extern is used to interface to the libxsmm batch-reduce GEMM
    # kernel implementation; rco*r*s is the number of batches.
    init_and_compute = tvm.call_extern(
        "int32", "batch_reduce_kernel_init_update",
        ins[0].access_ptr("r"), ins[1].access_ptr("r"),
        outs[0].access_ptr("w"),
        rco * r * s, ofmblock, ifmblock, r, s,
        ifh_stride, ifw_stride, ofw, stride_width)
    reset = tvm.call_extern(
        "int32", "batch_reduce_kernel_init",
        outs[0].access_ptr("w"), ofmblock, ofw)
    body = tvm.call_extern(
        "int32", "batch_reduce_kernel_update",
        ins[0].access_ptr("r"), ins[1].access_ptr("r"),
        outs[0].access_ptr("w"),
        rco * r * s, ofmblock, ifmblock, ofw, stride_width,
        r, s, ifh_stride, ifw_stride)
    if math.ceil(in_channel / ifmblock) == rco:
        return init_and_compute, None, init_and_compute
    return init_and_compute, reset, body

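# A reading aid for the triple returned above (an assumption about the legacy
# decl_tensor_intrin convention, not a statement from the source): the three
# values are interpreted as (body, reduce_reset, reduce_update). When one rco
# pass already covers every input channel (math.ceil(in_channel / ifmblock)
# == rco), the fused init+update kernel serves as both body and update, so no
# separate reset statement is needed.
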
def intrin_func(ins, outs):
    sp = [
        INPUT_W, INPUT_H, PAD_LEFT, PAD_RIGHT, PAD_TOP, PAD_BOTTOM,  # FMATRIX
        W_IDX_KERNEL, H_IDX_KERNEL, W_IDX, H_IDX, C1_IDX,  # Xm
        STRIDE_W, STRIDE_H, KERNEL_W, KERNEL_H,
        DILATION_W, DILATION_H, JUMP_OFFSET, REPEAT_MODE, REPEAT_TIME,  # Xt
    ]
    aa = ins[0]
    bb = outs[0]
    ib = tvm.ir_builder.create()
    fcol2img, Xm, Xt = pack_args(sp)
    ib.emit(tvm.call_extern(dtype, "set_fcol2img", fcol2img))
    ib.emit(tvm.call_extern(
        dtype, "vector_dup",
        bb.access_ptr("w"), 0,
        (INPUT_H * INPUT_W * 16) // 64, 1, 1, 8, 8))
    c = 0
    for kh in range(KERNEL_H):
        for kw in range(KERNEL_W):
            sp[6] = kw
            sp[7] = kh
            fcol2img, Xm, Xt = pack_args(sp)
            ib.emit(tvm.call_extern(
                dtype, "col2img",
                bb.access_ptr("rw"),
                aa.access_ptr("r", offset=c * 16 * INPUT_C0 * REPEAT_TIME),
                Xm, Xt))
            c += 1
    return ib.get()

def add_debug(stmt):
    debug = tvm.call_extern(
        "int32", "VTASetDebugMode",
        env.dev.command_handle,
        debug_flag)
    return tvm.make.stmt_seq(debug, stmt)

def mma_sync(inputs, outputs):
    factor1_, factor2_, product_ = inputs
    schedule_ = outputs[0]
    # Get the addresses of matrices A, B, and C.
    A_ptr = factor1_.access_ptr("r")
    B_ptr = factor2_.access_ptr("r")
    C_ptr = product_.access_ptr("w")
    body = tvm.call_extern('float32', "wmma_call", A_ptr, B_ptr, C_ptr)
    init = tvm.call_extern('float32', "__INIT_TILE_WARP__")
    return body, init, body

def intrin_func(ins, outs): ib = tvm.ir_builder.create() ib.emit( tvm.call_extern("float32", 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) return ib.get()
def meminfo_cache():
    return tvm.make.node(
        "MemoryInfo",
        unit_bits=8,
        max_simd_bits=32,
        max_num_bits=128,
        head_address=tvm.call_extern("handle", "global_cache"))

def cpu_access_rewrite(stmt_in):
    """Detect CPU access to VTA buffer and get the address correctly.

    VTA's buffer is an opaque handle that does not correspond to an
    address on the CPU. This pass detects CPU accesses and rewrites
    them to use the pointer returned by VTABufferCPUPtr.

    Parameters
    ----------
    stmt_in : Stmt
        Input statement

    Returns
    -------
    stmt_out : Stmt
        Transformed statement
    """
    env = get_env()
    rw_info = {}

    def _post_order(op):
        if isinstance(op, tvm.stmt.Allocate):
            buffer_var = op.buffer_var
            if buffer_var not in rw_info:
                return None
            new_var = rw_info[buffer_var]
            let_stmt = tvm.make.LetStmt(
                new_var,
                tvm.call_extern(
                    "handle", "VTABufferCPUPtr",
                    env.dev.command_handle,
                    buffer_var),
                op.body)
            alloc = tvm.make.Allocate(
                buffer_var, op.dtype, op.extents, op.condition, let_stmt)
            del rw_info[buffer_var]
            return alloc
        if isinstance(op, tvm.expr.Load):
            buffer_var = op.buffer_var
            if buffer_var not in rw_info:
                rw_info[buffer_var] = tvm.var(
                    buffer_var.name + "_ptr", "handle")
            new_var = rw_info[buffer_var]
            return tvm.make.Load(op.dtype, new_var, op.index)
        if isinstance(op, tvm.stmt.Store):
            buffer_var = op.buffer_var
            if buffer_var not in rw_info:
                rw_info[buffer_var] = tvm.var(
                    buffer_var.name + "_ptr", "handle")
            new_var = rw_info[buffer_var]
            return tvm.make.Store(new_var, op.value, op.index)
        raise RuntimeError("not reached")

    stmt = tvm.ir_pass.IRTransform(
        stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
    for buffer_var, new_var in rw_info.items():
        stmt = tvm.make.LetStmt(
            new_var,
            tvm.call_extern(
                "handle", "VTABufferCPUPtr",
                env.dev.command_handle,
                buffer_var),
            stmt)
    return stmt

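# A minimal sketch of driving the pass above (not from the original source;
# assumes the same legacy tvm.ir_builder / tvm.ir_pass APIs and a configured
# VTA environment). The store into the locally allocated buffer is rewritten
# to go through a pointer bound by LetStmt to VTABufferCPUPtr.
def cpu_access_rewrite_example():
    ib = tvm.ir_builder.create()
    A = ib.allocate("int8", 64, name="A", scope="global")
    A[0] = tvm.const(1, "int8")
    stmt = ib.get()
    return cpu_access_rewrite(stmt)
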
def _body(): ib = tvm.ir_builder.create() ib.emit( tvm.call_extern("int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, bb.strides[0])) return ib.get()
def _reset():
    ib = tvm.ir_builder.create()
    # int32_t matmul_reset(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
    #                      int32_t pad_J, int32_t C_row_len);
    ib.emit(tvm.call_extern(
        "int32", "matmul_reset",
        cc.access_ptr("w"),
        II, JJ, pad_I, pad_J, strideC))
    return ib.get()

def test_buffer_access_ptr_offset():
    m = tvm.var('m')
    n = tvm.var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    aptr = Ab.access_ptr("rw", offset=100)
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(offset, 100)
    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
    v = tvm.var('int32')
    aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(offset, 200 + v)
    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
    aptr = Ab.access_ptr(
        "rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v))
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(
        offset, tvm.call_extern('int32', "test_call", 200 + v))
    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

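# Context for the assertions above (a reading aid, not part of the test): in
# this legacy API, access_ptr lowers to a tvm_access_ptr intrinsic call whose
# arguments are (type annotation, buffer data, offset, extent, access mask),
# which is why the test reads args[2] for the offset and args[4] for the
# Buffer.READ | Buffer.WRITE mask.
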
def intrin_func(ins, outs):
    ib = tvm.ir_builder.create()
    aa, bb = ins
    cc = outs[0]
    ib.emit(tvm.call_extern(
        "int32", "gemv_update",
        cc.access_ptr("w"),
        aa.access_ptr("r"),
        bb.access_ptr("r"),
        m, l, bb.strides[0]))
    return ib.get()

def reset():
    irb = tvm.ir_builder.create()
    extern_call = tvm.call_extern(
        "int32",
        "sgemm_reset_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH),
        irb.buffer_ptr(cc),
        cc.elem_offset,
        cc.strides[0])
    irb.emit(extern_call)
    return irb.get()

def test_cce_loop_2():
    ib = tvm.ir_builder.create()
    len = 112
    tile = 32
    loop = (len + tile - 1) // tile
    with ib.for_range(0, loop, 'i') as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > len)):
            tail = len
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
        with ib.else_scope():
            tail = head + tile
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert not any(
        collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))

def _finalize():
    ib = tvm.ir_builder.create()
    # Move C out of the accumulator.
    # int32_t matmul_finalize(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
    #                         int32_t pad_J, int32_t C_row_len);
    ib.emit(tvm.call_extern(
        "int32", "matmul_finalize",
        cc.access_ptr("rw"),
        II, JJ, pad_I, pad_J, strideC))
    return ib.get()

def __init__(self, env):
    self.vta_axis = tvm.thread_axis("vta")
    self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
    ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
    self.command_handle = tvm.make.Call(
        "handle", "tvm_thread_context", [ctx],
        tvm.expr.Call.Intrinsic, None, 0)
    self.DEBUG_NO_SYNC = False
    env._dev_ctx = self
    self.gemm = intrin.gemm(env, env.mock_mode)

def body():
    irb = tvm.ir_builder.create()
    extern_call = tvm.call_extern(
        "int32",
        "sgemm_compute_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH),
        K,
        irb.buffer_ptr(aa),
        aa.elem_offset,
        irb.buffer_ptr(bb),
        bb.elem_offset,
        irb.buffer_ptr(cc),
        cc.elem_offset,
        cc.strides[0])
    irb.emit(extern_call)
    return irb.get()

def test_for():
    dev_type = tvm.var("dev_type")

    def device_context(dev_id):
        ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
        return tvm.make.Call(
            "handle", "tvm_thread_context", [ctx],
            tvm.expr.Call.Intrinsic, None, 0)

    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    A = ib.allocate("float32", n, name="A", scope="global")
    with ib.for_range(0, n, name="i") as i:
        ib.emit(tvm.call_extern("int32", "fadd", device_context(0), A))
        with ib.for_range(0, 10, name="j") as j:
            ib.emit(tvm.call_extern("int32", "fadd", device_context(1), A))
            ib.emit(tvm.call_extern("int32", "fadd", device_context(0), A))
    body = ib.get()
    f = tvm.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
    f = tvm.ir_pass.CombineContextCall(f)
    assert f.body.value.dtype == "handle"
    assert f.body.body.value.dtype == "handle"

def Gemm_ir_wmma(A, B, C):
    ib = tvm.ir_builder.create()
    thread_x = tvm.thread_axis('threadIdx.x')
    ib.scope_attr(thread_x, 'thread_extent', num_thread)
    # declare shared memory
    offsetb = 16 * 16
    sync = tvm.call_extern("float32", "__syncthreads")
    # define fragment
    def_matrix_frag = tvm.call_extern("float32", "__FRAGMENT_F16__")
    ib.emit(def_matrix_frag)
    load_matrix_frag_a = tvm.call_extern("float32", "__LOADFRAG_A__", A, 0, 16)
    ib.emit(load_matrix_frag_a)
    load_matrix_frag_b = tvm.call_extern("float32", "__LOADFRAG_B__", B, 0, 16)
    ib.emit(load_matrix_frag_b)
    ib.emit(sync)
    wmma_compute = tvm.call_extern("float32", "__WMMA_SYNC__", 0, 0)
    ib.emit(wmma_compute)
    ib.emit(sync)
    store_matrix_frag_c1 = tvm.call_extern(
        "float32", "__STOREFRAG_C_F16__", C, 0, 0, 16)
    ib.emit(store_matrix_frag_c1)
    ib.emit(sync)
    body = ib.get()
    return body

def intrin_func(ins, outs):
    ib = tvm.ir_builder.create()
    inp, filt = ins
    outp = outs[0]
    ib.emit(tvm.call_extern(
        "int32", "inst_conv",
        outp.access_ptr("w"),
        inp.access_ptr("r"),
        filt.access_ptr("r")))
    return ib.get()

def ir_warp(A, B):
    ib = tvm.ir_builder.create()
    A_ptr = ib.buffer_ptr(A)
    B_ptr = ib.buffer_ptr(B)
    tx = tvm.thread_axis('threadIdx.x')
    ib.scope_attr(tx, 'thread_extent', 10)
    i = tx
    with ib.if_scope(i % 2 == 0):
        B_ptr[i] = A_ptr[i] + 1
        o1 = tvm.call_extern("float32", "__WMMA__")
        ib.emit(o1)
    body = ib.get()
    return body

def test_cce_loop_3():
    ib = tvm.ir_builder.create()
    loop1 = 4
    loop2 = 9998
    tile = 39991
    with ib.for_range(0, loop2, 'i') as i:
        with ib.for_range(0, loop1, 'j') as j:
            head1 = i
            head2 = j
            with ib.if_scope(ib.likely(head1 * loop1 + head2 < tile)):
                ib.emit(tvm.call_extern('float16', "cce_intrisic", head1))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert not any(
        collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))

def _body():
    b_ = tvm.ir_builder.create()
    b_.emit(tvm.call_extern(
        "int32", opname,
        cc.access_ptr("w"),
        aa.access_ptr("r"),
        bb.access_ptr("r"),
        ci_, vh_, vw_, vc_,
        core_id))
    return b_.get()

def test_extern_call():
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(
        (n,),
        lambda *i: tvm.call_extern("float32", "TVMTestAddOne", A(*i)),
        name='B')
    s = tvm.create_schedule(B.op)

    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        f = tvm.build(s, [A, B], "llvm")
        ctx = tvm.cpu(0)
        # Launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)

    check_llvm()

def get_vthread(name):
    tx = tvm.thread_axis(name)
    ty = tvm.thread_axis(name)
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    with ib.for_range(0, n) as i:
        ib.scope_attr(tx, "virtual_thread", nthread)
        ib.scope_attr(ty, "virtual_thread", nthread)
        B = ib.allocate("float32", m, name="B", scope="shared")
        B[i] = A[i * nthread + tx]
        bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
        ib.emit(tvm.call_extern(
            "int32", "Run",
            bbuffer.access_ptr("r"),
            tvm.call_pure_intrin("int32", "tvm_context_id")))
        C[i * nthread + tx] = B[i] + 1
    return ib.get()

def _body():
    ib = tvm.ir_builder.create()
    # int32_t matmul_kernel(const elem_t *A, const elem_t *B, const acc_t *D,
    #                       elem_t *C, int32_t I, int32_t J, int32_t K,
    #                       int32_t pad_I, int32_t pad_J, int32_t pad_K,
    #                       int32_t A_row_len, int32_t B_row_len,
    #                       int32_t D_row_len, int32_t C_row_len,
    #                       bool no_bias, bool repeating_bias);
    # D is set to a dummy address 1 to determine whether to overwrite
    # accumulator contents: on the first run, 1 will be retained and
    # overwrite the value in the accumulator; on subsequent runs D will be
    # replaced by NULL and C will accumulate on top of the accumulator's
    # contents. This is controlled via bit 1 << (ADDR_LEN - 2) - see kernel
    # source.
    ib.emit(tvm.call_extern(
        "int32", "matmul_kernel",
        aa.access_ptr("r"),
        bb.access_ptr("r"),
        1,
        cc.access_ptr("rw"),
        II, JJ, KK,
        pad_I, pad_J, pad_K,
        strideA, strideB, 0, strideC,
        True, False))
    return ib.get()

def get_vthread(name):
    tx = tvm.thread_axis(name)
    ty = tvm.thread_axis(name)
    ib = tvm.ir_builder.create()
    with ib.for_range(0, n) as i:
        ib.scope_attr(tx, "virtual_thread", nthread)
        ib.scope_attr(ty, "virtual_thread", nthread)
        A = ib.allocate("float32", m, name="A", scope="shared")
        B = ib.allocate("float32", m, name="B", scope="shared")
        C = ib.allocate("float32", m, name="C", scope="shared")
        cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
        abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
        bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
        A[tx] = tx + 1.0
        B[ty] = ty + 1.0
        ib.emit(tvm.call_extern(
            "int32", "Run",
            abuffer.access_ptr("r"),
            bbuffer.access_ptr("r"),
            cbuffer.access_ptr("rw")))
    return ib.get()

def intrin_multivadd(n):
    n_a = tvm.var("n_a")
    Ab = tvm.decl_buffer((n,), tvm.float32, strides=[n_a])
    n_b = tvm.var("n_b")
    Bb = tvm.decl_buffer((n,), tvm.float32, strides=[n_b])
    n_c = tvm.var("n_c")
    Cb = tvm.decl_buffer((n,), tvm.float32, strides=[n_c])
    z = tvm.compute(
        (n,),
        lambda i: tvm.call_extern(
            "float32", 'vadd',
            Ab.access_ptr("w", offset=n_a * i),
            Bb.access_ptr("r", offset=n_b * i),
            Cb.access_ptr("r", offset=n_c * i)))

    # Replace the pattern with the multivadd call. I need to figure out
    # how to pass it the right parameters.
    def intrin_func(ins, outs):
        return tvm.call_packed("multivadd")

    with tvm.build_config():
        return tvm.decl_tensor_intrin(z.op, intrin_func, name="multivadd")

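# A hypothetical way to apply the intrinsic above via tensorize (not from the
# original source; "s" and "C" stand for a schedule and a vadd-like stage, and
# the axis choice depends on how the computation was tiled):
#
#   vadd = intrin_multivadd(64)
#   s[C].tensorize(C.op.axis[0], vadd)
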
def coproc_sync(op):
    _ = op
    return tvm.call_extern(
        "int32", "VTASynchronize",
        get_env().dev.command_handle, 1 << 31)

def _body(): ib = tvm.ir_builder.create() ib.emit(tvm.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r"))) return ib.get()
def _reduce_reset():
    ib = tvm.ir_builder.create()
    ib.emit(tvm.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m))
    return ib.get()

def coproc_dep_pop(op):
    return tvm.call_extern(
        "int32", "VTADepPop",
        get_env().dev.command_handle,
        op.args[0], op.args[1])

def _inject_copy(src, dst, pad_before, pad_after, pad_value):
    # FIXME: pad_value is ignored...
    _ = pad_value
    if dst.scope == "global":
        # Store
        if pad_before or pad_after:
            raise RuntimeError("Do not support copy into DRAM with pad")
        if src.scope == env.acc_scope:
            elem_width = env.OUT_WIDTH
            elem_bytes = env.OUT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_OUT
            data_type = "int%d" % env.OUT_WIDTH
            task_qid = env.dev.QID_STORE_OUT
        else:
            raise RuntimeError("Do not support copy %s->dram" % (src.scope))
        _check_compact(src)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
        irb = tvm.ir_builder.create()
        irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                       env.dev.get_task_qid(task_qid))
        irb.emit(tvm.call_extern(
            "int32", "VTAStoreBuffer2D",
            env.dev.command_handle,
            src.access_ptr("r", "int32"),
            mem_type, dst.data, offset, x_size, y_size, x_stride))
        return irb.get()
    if src.scope == "global":
        if dst.scope == env.acc_scope:
            elem_width = env.ACC_WIDTH
            elem_bytes = env.ACC_ELEM_BYTES
            mem_type = env.dev.MEM_ID_ACC
            data_type = "int%d" % env.ACC_WIDTH
            task_qid = env.dev.QID_LOAD_OUT
        elif dst.scope == env.inp_scope:
            elem_width = env.INP_WIDTH
            elem_bytes = env.INP_ELEM_BYTES
            mem_type = env.dev.MEM_ID_INP
            data_type = "int%d" % env.INP_WIDTH
            task_qid = env.dev.QID_LOAD_INP
        elif dst.scope == env.wgt_scope:
            elem_width = env.WGT_WIDTH
            elem_bytes = env.WGT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_WGT
            data_type = "int%d" % env.WGT_WIDTH
            task_qid = env.dev.QID_LOAD_WGT
        else:
            raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
        # Collect pad statistics
        if pad_before:
            assert pad_after
            ndim = len(pad_before)
            if ndim <= 2 or ndim > 4:
                raise ValueError(
                    "Limitation of 2D pad load forbid ndim=%d" % ndim)
            if ndim > 2:
                if not util.equal_const_int(pad_before[ndim - 1], 0):
                    raise ValueError(
                        "Do not support pad on the innermost block")
                if not util.equal_const_int(pad_after[ndim - 1], 0):
                    raise ValueError(
                        "Do not support pad on the innermost block")
            if ndim > 3:
                if not util.equal_const_int(pad_before[ndim - 2], 0):
                    raise ValueError(
                        "Do not support pad on the innermost block")
                if not util.equal_const_int(pad_after[ndim - 2], 0):
                    raise ValueError(
                        "Do not support pad on the innermost block")
            y_pad_before = pad_before[0]
            x_pad_before = pad_before[1]
            y_pad_after = pad_after[0]
            x_pad_after = pad_after[1]
            allow_fold = False
        else:
            x_pad_before = 0
            y_pad_before = 0
            x_pad_after = 0
            y_pad_after = 0
            allow_fold = True
        _check_compact(dst)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            src, elem_width, elem_bytes, data_type, dst.scope,
            allow_fold=allow_fold)
        irb = tvm.ir_builder.create()
        irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                       env.dev.get_task_qid(task_qid))
        irb.emit(tvm.call_extern(
            "int32", "VTALoadBuffer2D",
            env.dev.command_handle,
            src.data, offset, x_size, y_size, x_stride,
            x_pad_before, y_pad_before,
            x_pad_after, y_pad_after,
            dst.access_ptr("r", "int32"), mem_type))
        return irb.get()
    raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))

def _do_fold(stmt):
    def _equal(x, y):
        return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)

    def _flatten_loop(src_coeff, dst_coeff, extents):
        src_coeff = list(src_coeff)
        dst_coeff = list(dst_coeff)
        extents = list(extents)
        rev_src_coeff = [src_coeff.pop()]
        rev_dst_coeff = [dst_coeff.pop()]
        rev_extents = []
        assert src_coeff
        vsrc = src_coeff.pop()
        vdst = dst_coeff.pop()
        vext = extents.pop()
        while src_coeff:
            next_src = src_coeff.pop()
            next_dst = dst_coeff.pop()
            next_ext = extents.pop()
            if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
                vext = tvm.ir_pass.Simplify(vext * next_ext)
            else:
                rev_src_coeff.append(vsrc)
                rev_dst_coeff.append(vdst)
                rev_extents.append(vext)
                vsrc = next_src
                vdst = next_dst
                vext = next_ext
        rev_src_coeff.append(vsrc)
        rev_dst_coeff.append(vdst)
        rev_extents.append(vext)
        rev_src_coeff.reverse()
        rev_dst_coeff.reverse()
        rev_extents.reverse()
        return rev_src_coeff, rev_dst_coeff, rev_extents

    if _match_pragma(stmt, "alu"):
        # Get to the innermost loop body
        loop_body = stmt.body
        nest_size = 0
        while isinstance(loop_body, tvm.stmt.For):
            loop_body = loop_body.body
            nest_size += 1
        # Get the src/dst arguments
        dst_var = loop_body.buffer_var
        dst_idx = loop_body.index
        # Derive loop variables and extents
        tmp_body = stmt.body
        indices = []
        extents = []
        for _ in range(nest_size):
            indices.append(tmp_body.loop_var)
            extents.append(tmp_body.extent)
            tmp_body = tmp_body.body
        # Derive opcode
        if isinstance(loop_body.value, tvm.expr.Add):
            alu_opcode = env.dev.ALU_OPCODE_ADD
            lhs = loop_body.value.a
            rhs = loop_body.value.b
        elif isinstance(loop_body.value, tvm.expr.Sub):
            alu_opcode = env.dev.ALU_OPCODE_SUB
            lhs = loop_body.value.a
            rhs = loop_body.value.b
        elif isinstance(loop_body.value, tvm.expr.Mul):
            alu_opcode = env.dev.ALU_OPCODE_MUL
            lhs = loop_body.value.a
            rhs = loop_body.value.b
        elif isinstance(loop_body.value, tvm.expr.Min):
            alu_opcode = env.dev.ALU_OPCODE_MIN
            lhs = loop_body.value.a
            rhs = loop_body.value.b
        elif isinstance(loop_body.value, tvm.expr.Max):
            alu_opcode = env.dev.ALU_OPCODE_MAX
            lhs = loop_body.value.a
            rhs = loop_body.value.b
        elif isinstance(loop_body.value, tvm.expr.Call):
            if loop_body.value.name == 'shift_left':
                alu_opcode = env.dev.ALU_OPCODE_SHR
                lhs = loop_body.value.args[0]
                rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
            elif loop_body.value.name == 'shift_right':
                alu_opcode = env.dev.ALU_OPCODE_SHR
                lhs = loop_body.value.args[0]
                rhs = loop_body.value.args[1]
            else:
                raise RuntimeError(
                    "Function call not recognized %s" %
                    (loop_body.value.name))
        elif isinstance(loop_body.value, tvm.expr.Load):
            alu_opcode = env.dev.ALU_OPCODE_SHR
            lhs = loop_body.value
            rhs = tvm.const(0, "int32")
        else:
            raise RuntimeError(
                "Expression not recognized %s, %s, %s" % (
                    type(loop_body.value), str(loop_body.value), str(stmt)))
        # Derive array index coefficients
        dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
        # Check if lhs/rhs is an immediate
        use_imm = False
        imm_val = None
        if isinstance(rhs, tvm.expr.IntImm):
            assert lhs.buffer_var.same_as(dst_var)
            src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
            use_imm = True
            imm_val = rhs
        if isinstance(lhs, tvm.expr.IntImm):
            assert rhs.buffer_var.same_as(dst_var)
            src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
            use_imm = True
            imm_val = lhs
        if imm_val is None:
            imm_val = 0
            assert lhs.buffer_var.same_as(dst_var) and \
                rhs.buffer_var.same_as(dst_var)
            src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
            src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
            # Determine which side has the same coefficients
            lhs_equal = True
            rhs_equal = True
            for i, coef in enumerate(dst_coeff):
                if not tvm.ir_pass.Equal(coef, src_lhs_coeff[i]):
                    lhs_equal = False
                if not tvm.ir_pass.Equal(coef, src_rhs_coeff[i]):
                    rhs_equal = False
            # Make sure at least one of the sources is identical to the
            # destination (in-place computation)
            assert lhs_equal or rhs_equal
            # Assign the source coefficients
            if lhs_equal:
                src_coeff = src_rhs_coeff
            else:
                src_coeff = src_lhs_coeff
        # Ensure that we have the proper tensor dimensions in the
        # innermost loop (pattern match)
        src_coeff = list(src_coeff)
        dst_coeff = list(dst_coeff)
        extents = list(extents)
        assert len(src_coeff) > 1
        assert len(dst_coeff) > 1
        assert len(extents) != 0
        assert tvm.ir_pass.Equal(
            tvm.ir_pass.Simplify(
                src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
        assert tvm.ir_pass.Equal(
            tvm.ir_pass.Simplify(
                dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
        assert tvm.ir_pass.Equal(src_coeff[-2], 1)
        assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
        if env.BATCH > 1:
            assert len(src_coeff) > 2
            assert len(dst_coeff) > 2
            assert len(extents) > 1
            assert tvm.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT)
            assert tvm.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT)
        # Apply tensorization of the loop coefficients
        src_offset = src_coeff[-1]
        dst_offset = dst_coeff[-1]
        if env.BATCH == 1:
            src_coeff = src_coeff[:-2]
            dst_coeff = dst_coeff[:-2]
            extents = extents[:-1]
        else:
            src_coeff = src_coeff[:-3]
            dst_coeff = dst_coeff[:-3]
            extents = extents[:-2]
        src_coeff.append(src_offset)
        dst_coeff.append(dst_offset)
        src_coeff = [
            tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT))
            for c in src_coeff]
        dst_coeff = [
            tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT))
            for c in dst_coeff]
        # Flatten the outer loops
        if extents:
            src_coeff, dst_coeff, extents = _flatten_loop(
                src_coeff, dst_coeff, extents)
        # Insert ALU micro-ops
        irb = tvm.ir_builder.create()
        for idx, extent in enumerate(extents):
            irb.emit(tvm.call_extern(
                "int32", "VTAUopLoopBegin",
                extent, dst_coeff[idx], src_coeff[idx], 0))
        use_imm = int(use_imm)
        irb.emit(tvm.call_extern(
            "int32", "VTAUopPush",
            1, 0,
            dst_coeff[len(dst_coeff) - 1],
            src_coeff[len(src_coeff) - 1],
            0,
            alu_opcode, use_imm, imm_val))
        for extent in extents:
            irb.emit(tvm.call_extern("int32", "VTAUopLoopEnd"))
        return irb.get()
    return stmt

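# A worked example of _flatten_loop's contraction rule (illustrative numbers,
# not from the original source): with loop coefficients [4, 1] (plus the
# trailing offset term) on both the src and dst sides and extents [2, 4], the
# outer coefficient equals inner_coeff * inner_extent (4 == 1 * 4), so the two
# nested loops fuse into a single loop of extent 8 with coefficient 1, and
# fewer VTAUopLoopBegin/VTAUopLoopEnd pairs are emitted.
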
def intrin_func(ins, outs):
    ib = tvm.ir_builder.create()
    ib.emit(tvm.call_extern(
        outs[0].dtype, 'vadd',
        ins[0].access_ptr("r"),
        ins[1].access_ptr('r'),
        outs[0].access_ptr('wr')))
    return ib.get()
