def _fold_buffer_dim(buf, scope, elem_block):
    """Fold the innermost dimensions of buf into a flat elem_block lane,
    then merge every remaining run of contiguous dimensions.

    Returns the folded (shape, strides) pair, or raises if the buffer
    layout cannot be expressed in units of elem_block.
    """
    ndim = len(buf.shape)
    x_size = 1
    base = 0
    # Walk the trailing dimensions until their product equals elem_block;
    # they must be densely packed (stride == running size).
    for i in range(1, ndim + 1):
        if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
            raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
        x_size = x_size * buf.shape[ndim - i]
        if util.equal_const_int(x_size - elem_block, 0):
            base = i + 1
            break
    if base == 0:
        raise RuntimeError(
            "scope %s needs to have block=%d, shape=%s" % (scope, elem_block, buf.shape))
    shape = [elem_block]
    strides = [1]

    # If the next outer dimension is not contiguous with the block,
    # insert a unit dimension so the stride pattern stays explicit.
    if base < ndim + 1 and not util.equal_const_int(
            buf.strides[ndim - base], elem_block):
        shape.append(1)
        strides.append(elem_block)

    analyzer = tvm.arith.Analyzer()
    # Greedily merge each run of contiguous outer dimensions into one.
    while base < ndim + 1:
        x_size = 1
        x_stride = buf.strides[ndim - base]
        next_base = base
        if not util.equal_const_int(idxm(x_stride, elem_block), 0):
            raise RuntimeError(
                "scope %s needs to have block=%d, shape=%s, strides=%s"
                % (scope, elem_block, buf.shape, buf.strides))
        for i in range(base, ndim + 1):
            k = ndim - i
            if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
                break
            x_size = x_size * buf.shape[k]
            next_base = i + 1
        shape.append(analyzer.simplify(x_size))
        strides.append(x_stride)
        assert next_base != base
        base = next_base

    strides = list(reversed(strides))
    shape = list(reversed(shape))
    return shape, strides
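# A worked example (hypothetical values, not part of the pass): with
# elem_block=8, a buffer of shape (4, 2, 8) and strides (16, 8, 1) folds as
# follows -- the innermost dim already matches the block, and the two outer
# dims are contiguous (2 * 8 == 16), so they merge into one:
#
#   buf = tvm.tir.decl_buffer((4, 2, 8), "int8", strides=[16, 8, 1])
#   _fold_buffer_dim(buf, "scratch", 8)   # -> ([8, 8], [8, 1])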
def _check_compact(buf):
    """Raise if buf is not provably compact (dense row-major, no padding)."""
    ndim = len(buf.shape)
    size = tvm.tir.const(1, buf.shape[0].dtype)
    for i in reversed(range(ndim)):
        if not util.equal_const_int(size - buf.strides[i], 0):
            raise RuntimeError(
                "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
        size = size * buf.shape[i]
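# For intuition (an illustrative sketch with hypothetical values): a
# row-major (2, 4, 8) buffer is compact iff each stride equals the product
# of the dimensions to its right, i.e. (32, 8, 1); a padded or sliced view
# such as strides (64, 8, 1) cannot be proven compact and is rejected:
#
#   _check_compact(tvm.tir.decl_buffer((2, 4, 8), "int8", strides=[32, 8, 1]))  # ok
#   _check_compact(tvm.tir.decl_buffer((2, 4, 8), "int8", strides=[64, 8, 1]))  # raises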
def _inject_copy(src, dst, pad_before, pad_after, pad_value):
    """Rewrite a pragma-tagged copy loop nest into a VTA load/store intrinsic."""
    # FIXME: pad_value is ignored...
    env = get_env()
    _ = pad_value
    if dst.scope == "global":
        # Store path: SRAM -> DRAM.
        if pad_before or pad_after:
            raise RuntimeError("Do not support copy into DRAM with pad")
        if src.scope == env.acc_scope:
            elem_width = env.OUT_WIDTH
            elem_bytes = env.OUT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_OUT
            data_type = "int%d" % env.OUT_WIDTH
            task_qid = env.dev.QID_STORE_OUT
        else:
            raise RuntimeError("Do not support copy %s->dram" % (src.scope))
        _check_compact(src)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
        irb = tvm.tir.ir_builder.create()
        irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                       env.dev.get_task_qid(task_qid))
        irb.emit(tvm.tir.call_extern(
            "int32", "VTAStoreBuffer2D",
            env.dev.command_handle,
            src.access_ptr("r", "int32"),
            mem_type, dst.data, offset, x_size, y_size, x_stride))
        return irb.get()
    elif src.scope == "global":
        # Load path: DRAM -> SRAM.
        if dst.scope == env.acc_scope:
            elem_width = env.ACC_WIDTH
            elem_bytes = env.ACC_ELEM_BYTES
            mem_type = env.dev.MEM_ID_ACC
            data_type = "int%d" % env.ACC_WIDTH
            task_qid = env.dev.QID_LOAD_OUT
        elif dst.scope == env.inp_scope:
            elem_width = env.INP_WIDTH
            elem_bytes = env.INP_ELEM_BYTES
            mem_type = env.dev.MEM_ID_INP
            data_type = "int%d" % env.INP_WIDTH
            task_qid = env.dev.QID_LOAD_INP
        elif dst.scope == env.wgt_scope:
            elem_width = env.WGT_WIDTH
            elem_bytes = env.WGT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_WGT
            data_type = "int%d" % env.WGT_WIDTH
            task_qid = env.dev.QID_LOAD_WGT
        else:
            raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
        # Collect pad statistics; padding is only legal on the two outer
        # spatial dimensions, never on the innermost block dimensions.
        if pad_before:
            assert pad_after
            ndim = len(pad_before)
            if ndim <= 2 or ndim > 5:
                raise ValueError(
                    "Limitation of 2D pad load forbids ndim=%d" % ndim)
            if ndim == 5:
                # This case occurs when batch size N > 1
                y_pad_before = pad_before[1]
                x_pad_before = pad_before[2]
                y_pad_after = pad_after[1]
                x_pad_after = pad_after[2]
                for dim in range(3, ndim):
                    if not util.equal_const_int(pad_before[dim], 0):
                        raise ValueError("Do not support pad on the innermost block")
                    if not util.equal_const_int(pad_after[dim], 0):
                        raise ValueError("Do not support pad on the innermost block")
            else:
                y_pad_before = pad_before[0]
                x_pad_before = pad_before[1]
                y_pad_after = pad_after[0]
                x_pad_after = pad_after[1]
                for dim in range(2, ndim):
                    if not util.equal_const_int(pad_before[dim], 0):
                        raise ValueError("Do not support pad on the innermost block")
                    if not util.equal_const_int(pad_after[dim], 0):
                        raise ValueError("Do not support pad on the innermost block")
            allow_fold = False
        else:
            x_pad_before = 0
            y_pad_before = 0
            x_pad_after = 0
            y_pad_after = 0
            allow_fold = True

        _check_compact(dst)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            src, elem_width, elem_bytes, data_type, dst.scope,
            allow_fold=allow_fold)

        irb = tvm.tir.ir_builder.create()
        irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                       env.dev.get_task_qid(task_qid))
        irb.emit(tvm.tir.call_extern(
            "int32", "VTALoadBuffer2D",
            env.dev.command_handle,
            src.data, offset, x_size, y_size, x_stride,
            x_pad_before, y_pad_before,
            x_pad_after, y_pad_after,
            dst.access_ptr("r", "int32"), mem_type))
        return irb.get()
    else:
        raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
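# Wiring sketch (assuming the stock TVM copy-intrinsic machinery and the
# "dma_copy" pragma key used by VTA schedules): the callback above is handed
# to InjectCopyIntrin, which invokes it for every copy loop nest tagged with
# the pragma and splices the returned IR in place of the loop:
#
#   def inject_dma_intrin():
#       return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)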
def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
    """Detect a 2D DMA access pattern and return
    (x_size, y_size, x_stride, offset) in units of elem_block elements."""
    elem_block = elem_bytes * 8 // elem_width
    if buf.dtype != dtype:
        raise RuntimeError(
            "Expect buffer type to be %s instead of %s" % (dtype, buf.dtype))
    shape, strides = buf.shape, buf.strides
    if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
        raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
    if allow_fold:
        shape, strides = _fold_buffer_dim(buf, scope, elem_block)
    else:
        shape = list(x for x in shape)
        strides = list(x for x in strides)

    def raise_error():
        """Internal function to raise a uniform pattern-detection error."""
        raise RuntimeError(
            ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:"
             " shape=%s, strides=%s")
            % (scope, elem_block, buf.shape, buf.strides))

    ndim = len(shape)

    # Check if the inner tensor is already flat.
    flat = util.equal_const_int(shape[-1], elem_block)

    if flat:
        if not util.equal_const_int(strides[-1], 1):
            raise_error()

        if ndim == 1:
            x_size = 1
            x_stride = 1
            y_size = 1
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
        if not util.equal_const_int(strides[-2] - elem_block, 0):
            raise_error()

        if ndim == 2:
            x_size = shape[-2]
            x_stride = shape[-2]
            y_size = 1
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
        if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
            raise_error()

        if ndim == 3:
            x_size = shape[-2]
            x_stride = idxd(strides[-3], elem_block)
            y_size = shape[-3]
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

    else:
        if not util.equal_const_int(strides[-1], 1):
            raise_error()
        if not util.equal_const_int(strides[-2] - shape[-1], 0):
            raise_error()
        if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
            raise_error()

        if ndim == 2:
            x_size = 1
            x_stride = 1
            y_size = 1
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
        if not util.equal_const_int(strides[-3], elem_block):
            raise_error()

        if ndim == 3:
            x_size = shape[-3]
            x_stride = shape[-3]
            y_size = 1
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
        if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
            raise_error()

        if ndim == 4:
            x_size = shape[-3]
            x_stride = idxd(strides[-4], elem_block)
            y_size = shape[-4]
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

    raise_error()
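# Shape of the answer (a hypothetical flat, ndim == 3 example): a buffer of
# shape (y, x, elem_block) with strides (s, elem_block, 1) is read as a 2D
# DMA transfer of y rows of x blocks each, with a row pitch of
# s // elem_block blocks and the offset counted in blocks:
#
#   buf = tvm.tir.decl_buffer((2, 3, 8), "int8", strides=[32, 8, 1],
#                             elem_offset=16)
#   _get_2d_pattern(buf, 8, 8, "int8", "scratch", allow_fold=False)
#   # -> (x_size=3, y_size=2, x_stride=4, offset=2)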