def __getitem__(self, indices):
    """Index the buffer, yielding a ``BufferRegion`` or a ``BufferLoad``.

    Parameters
    ----------
    indices : Union[PrimExpr, int, slice, tuple, list]
        A single index, or a tuple/list with one entry per indexed axis.
        Every entry is either a scalar index or a ``slice``.

    Returns
    -------
    result : Union[BufferRegion, BufferLoad]
        A ``BufferRegion`` when at least one entry is a slice and none of the
        slices carries a step; a ``BufferLoad`` otherwise.  In the load case a
        slice becomes a ``Ramp`` index, or a plain scalar when it selects a
        single lane.

    Notes
    -----
    An omitted slice ``start`` defaults to ``0`` and an omitted ``stop``
    defaults to the buffer extent of that axis (read from ``self.shape``).
    Mixing stepped and un-stepped slices produces a ``BufferLoad`` in which
    the un-stepped slices use a step of ``1``, so that no step is silently
    discarded.
    """
    from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
    from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
    from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel

    if not isinstance(indices, (tuple, list)):
        indices = [indices]

    def _bounds(axis, index):
        # Fill in the bounds the caller left out, e.g. ``buf[:8]`` / ``buf[4:]``.
        # NOTE(review): relies on ``self.shape`` holding one extent per axis.
        start = 0 if index.start is None else index.start
        stop = self.shape[axis] if index.stop is None else index.stop
        return start, stop

    has_slice = any(isinstance(index, slice) for index in indices)
    has_step = any(
        isinstance(index, slice) and index.step is not None for index in indices
    )
    analyzer = Analyzer()

    # A region cannot express a stride, so it is only built when no slice has one.
    if has_slice and not has_step:
        region = []
        for axis, index in enumerate(indices):
            if isinstance(index, slice):
                start, stop = _bounds(axis, index)
                region.append(
                    Range.from_min_extent(start, analyzer.simplify(stop - start))
                )
            else:
                region.append(Range.from_min_extent(index, 1))
        return BufferRegion(self, region)

    expr_indices = []
    for axis, index in enumerate(indices):
        if isinstance(index, slice):
            start, stop = _bounds(axis, index)
            step = 1 if index.step is None else index.step
            # Ceil-division: number of elements selected by start:stop:step.
            lanes = analyzer.simplify((stop - start + step - 1) // step)
            if lanes == 1:
                expr_indices.append(start)
            else:
                expr_indices.append(Ramp(start, step, int(lanes)))
        else:
            expr_indices.append(index)
    return BufferLoad(self, expr_indices)
def test_block_access_region_detector():
    """Compare detected access regions with the ones recorded on the block.

    The reads and writes returned by ``get_block_access_region`` must match
    the block's own ``reads``/``writes``; the third returned list must be the
    full 128x128 region of the last allocated buffer.
    """
    root_block = func.body.block
    block = root_block.body.block
    alloc_buffers = root_block.alloc_buffers

    # Map every allocated buffer's data var back to the buffer itself.
    buffer_var_map = {}
    for buffer in alloc_buffers:
        buffer_var_map[buffer.data] = buffer

    detected = tir.analysis.get_block_access_region(block, buffer_var_map)
    tvm.ir.assert_structural_equal(block.reads, detected[0])
    tvm.ir.assert_structural_equal(block.writes, detected[1])

    last_buffer = alloc_buffers[-1]
    expected_third = [
        tvm.tir.BufferRegion(last_buffer, [Range(0, 128), Range(0, 128)])
    ]
    tvm.ir.assert_structural_equal(expected_third, detected[2])
def _get_region(tslice):
    """Convert the indices of ``tslice`` into a list of ``Range`` objects.

    A ``slice`` index (which must not carry a step) becomes
    ``Range(start, stop)``.  Any other index becomes a range of extent 1
    starting at that index; an ``IterVar`` is replaced by its ``var`` first.
    """
    ranges = []
    for index in tslice.indices:
        if not isinstance(index, slice):
            # Point access: an IterVar is addressed through its variable.
            origin = index.var if isinstance(index, tvm.tir.IterVar) else index
            ranges.append(Range.make_by_min_extent(origin, 1))
            continue
        assert index.step is None
        ranges.append(Range(index.start, index.stop))
    return ranges
def test_complete_matmul_original():
    """Verify the read/write regions of both blocks in ``matmul_original``."""
    func = matmul_original
    buf_a, buf_b, buf_c = [func.buffer_map[param] for param in func.params]

    def tile(var):
        # Range of extent 4 starting at ``var * 4``.
        return Range.from_min_extent(var * 4, 4)

    first_block = func.body.block.body.body.body[0].block
    assert isinstance(first_block, tvm.tir.Block)
    vi, vj = [iter_var.var for iter_var in first_block.iter_vars]
    region_c = tvm.tir.BufferRegion(buf_c, [tile(vi), tile(vj)])
    tvm.ir.assert_structural_equal(first_block.reads, [])
    tvm.ir.assert_structural_equal(first_block.writes, [region_c])

    second_block = func.body.block.body.body.body[1].body.block
    assert isinstance(second_block, tvm.tir.Block)
    vi, vj, vk = [iter_var.var for iter_var in second_block.iter_vars]
    region_a = tvm.tir.BufferRegion(buf_a, [tile(vi), tile(vk)])
    region_b = tvm.tir.BufferRegion(buf_b, [tile(vj), tile(vk)])
    region_c = tvm.tir.BufferRegion(buf_c, [tile(vi), tile(vj)])
    tvm.ir.assert_structural_equal(
        second_block.reads, [region_c, region_a, region_b]
    )
    tvm.ir.assert_structural_equal(second_block.writes, [region_c])
def test_complete_with_root():
    """Verify the element-wise regions of both blocks in ``elementwise_with_root``."""
    func = elementwise_with_root
    buf_a, buf_b, buf_c = [func.buffer_map[param] for param in func.params]

    def point_region(buffer, row, col):
        # Single-element region of ``buffer`` at (row, col).
        return tir.BufferRegion(
            buffer,
            [Range.from_min_extent(row, 1), Range.from_min_extent(col, 1)],
        )

    first_block = func.body.block.body[0].body.body.block
    assert isinstance(first_block, tvm.tir.Block)
    vi, vj = [iter_var.var for iter_var in first_block.iter_vars]
    tvm.ir.assert_structural_equal(
        first_block.reads, [point_region(buf_a, vi, vj)]
    )
    tvm.ir.assert_structural_equal(
        first_block.writes, [point_region(buf_b, vi, vj)]
    )

    second_block = func.body.block.body[1].body.body.block
    assert isinstance(second_block, tvm.tir.Block)
    vi, vj = [iter_var.var for iter_var in second_block.iter_vars]
    tvm.ir.assert_structural_equal(
        second_block.reads, [point_region(buf_b, vi, vj)]
    )
    tvm.ir.assert_structural_equal(
        second_block.writes, [point_region(buf_c, vi, vj)]
    )
def wrap_up_realize(self, node, body):
    """Wrap ``body`` in realize statements for buffers whose usage level is ``node``.

    For every symbol recorded in ``self.usage`` at level ``node`` that is a
    buffer symbol (other than ``Input`` / ``OutputBuffer``), ``body`` is
    wrapped in a ``ProducerRealize`` covering the buffer's whole shape, and
    the symbol is afterwards dropped from ``self.symbols``.  Buffers in the
    "global" scope additionally get ``self.wrap_up_binds`` applied first.

    Returns the (possibly wrapped) body.
    """
    released = []
    for key, usage_info in self.usage.items():
        _, level, _ = usage_info
        # Skip symbols that were never visited or that belong to another level.
        if key not in self.symbols or level != node:
            continue
        _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
        kind, entry = self.symbols[key]
        if kind in [Symbol.Input, Symbol.OutputBuffer]:
            continue
        if "Buffer" not in kind.name:
            continue

        # ``BufferVar`` lives in global scope; for the other buffer kinds the
        # scope is the kind's name with its last six characters removed.
        if kind is Symbol.BufferVar:
            storage_scope = "global"
        else:
            storage_scope = kind.name[:-6].lower()
        released.append(key)

        if storage_scope == "global":
            body = self.wrap_up_binds(body)

        domain = [Range.from_min_extent(0, extent) for extent in entry.shape]
        _dtype = entry.dtype  # read as in the original; not passed on
        always = tvm.runtime.convert(True)
        body = tvm.tir.ProducerRealize(
            entry, domain, always, body, tvm.runtime.convert(storage_scope)
        )

    for key in released:
        self.symbols.pop(key)
    return body
def realize(buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None):
    """Build a ``realize_scope`` attribute statement around a ``BufferRealize``.

    Each slice of ``buffer_slice`` contributes one bound: its ``start`` is the
    minimum, and the extent is ``stop - start`` (simplified when it is a
    ``PrimExpr``) or ``1`` when ``stop`` is ``None``.  The realized body is
    ``self.body`` of the enclosing handler.
    """
    assert self.context, "call 'exit_scope' before 'enter_scope'"
    target: Buffer = buffer_slice.buffer

    bounds: List[Range] = []
    for axis_slice in buffer_slice.slices:
        lower: Union[PrimExpr, int] = axis_slice.start
        length: Union[PrimExpr, int]
        if axis_slice.stop is None:
            length = 1
        else:
            length = axis_slice.stop - lower
        if isinstance(length, PrimExpr):
            length = self.context.analyzer.simplify(length)
        bounds.append(Range.from_min_extent(lower, length, span=axis_slice.span))

    scope = tvm.runtime.convert(scope, span=span)
    realize_stmt = tvm.tir.BufferRealize(
        target, bounds, condition, self.body, span=span
    )
    return tvm.tir.AttrStmt(target, "realize_scope", scope, realize_stmt, span=span)
def buffer_slice_to_region(
    buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None
) -> BufferRegion:
    """Construct BufferRegion from BufferSlice

    Parameters
    ----------
    buffer_slice : BufferSlice
        The input BufferSlice

    analyzer : Optional[tvm.arith.Analyzer]
        The analyzer for simplifying. If not provided, the method will construct a new one

    Returns
    -------
    buffer_region : BufferRegion
        The constructed BufferRegion.
    """
    ranges: List[Range] = []
    for axis_slice in buffer_slice.slices:
        # Non-PrimExpr starts are wrapped as int32 immediates.
        if isinstance(axis_slice.start, PrimExpr):
            lower = axis_slice.start
        else:
            lower = IntImm("int32", axis_slice.start)

        # A slice without ``stop`` addresses a single element.
        if axis_slice.stop is None:
            length = IntImm(lower.dtype, 1)
        else:
            length = axis_slice.stop - axis_slice.start

        # Create the analyzer on first use when the caller supplied none.
        analyzer = analyzer or Analyzer()
        if isinstance(length, PrimExpr):
            length = analyzer.simplify(length)
        ranges.append(Range.from_min_extent(lower, length, span=axis_slice.span))
    return BufferRegion(buffer_slice.buffer, ranges)
def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
    """Construct BufferRegion from BufferSlice

    Slices whose ``step`` is not ``1`` are reported through
    ``self.report_error``.

    Parameters
    ----------
    analyzer : Optional[tvm.arith.Analyzer]
        The analyzer for simplifying. If not provided, the method will construct a new one

    Returns
    -------
    buffer_region : BufferRegion
        The constructed BufferRegion.
    """
    ranges: List[Range] = []
    for axis_slice in self.slices:
        # Non-PrimExpr starts are wrapped as int32 immediates.
        if isinstance(axis_slice.start, PrimExpr):
            lower = axis_slice.start
        else:
            lower = IntImm("int32", axis_slice.start)

        # A slice without ``stop`` addresses a single element.
        if axis_slice.stop is None:
            length = IntImm(lower.dtype, 1)
        else:
            length = axis_slice.stop - axis_slice.start

        # Create the analyzer on first use when the caller supplied none.
        analyzer = analyzer or Analyzer()
        if isinstance(length, PrimExpr):
            length = analyzer.simplify(length)

        if axis_slice.step != 1:
            self.report_error(
                "BufferRegion do not support non-trivial stride", axis_slice.span
            )
        ranges.append(Range.from_min_extent(lower, length, span=axis_slice.span))
    return BufferRegion(self.buffer, ranges)
def enter_scope(
    self,
    node: synr.ast.Node,
    context: ContextMaintainer,
    arg_list: List[Any],
    span: synr.ast.Span,
):
    """Set up the loop variables for a ``for`` statement handled by this scope.

    Collects the loop variable names from ``node.lhs``, invokes ``self.func``
    with ``arg_list`` (which is expected to populate ``self.loop_info``),
    creates one ``tvm.te.var`` per loop, and registers each variable and its
    range with ``context``.

    Raises
    ------
    NotImplementedError
        If the begin or extent of a loop does not have an integer dtype.
    """
    assert isinstance(
        node, synr.ast.For
    ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"

    # Normalise the left-hand side to a list of candidate loop variables.
    lhs = node.lhs
    if isinstance(lhs, synr.ast.Var):
        candidates = [lhs]
    elif isinstance(lhs, list):
        candidates = lhs
    else:
        candidates = []
        context.report_error(
            f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
            span,
        )

    loop_var_names = []
    spans = []
    for elt in candidates:
        if not isinstance(elt, synr.ast.Var):
            context.report_error(
                f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span
            )
        loop_var_names.append(elt.id.name)
        spans.append(tvm_span_from_synr(elt.id.span))

    self.node = node
    self.context = context
    # collect loop infos by calling self.func
    call_with_error_reporting(context.report_error, span, self.func, *arg_list)
    if len(loop_var_names) != len(self.loop_info):
        self.context.report_error(
            f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
            + f"vs {len(self.loop_info)}",
            self.node.span,
        )

    # generate loop vars
    self.loop_vars = []
    for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
        if not li.begin.dtype.startswith("int"):
            raise NotImplementedError(
                f"Unsupported dtype in loop begin: {li.begin.dtype}"
            )
        if not li.extent.dtype.startswith("int"):
            raise NotImplementedError(
                f"Unsupported dtype in loop extent: {li.extent.dtype}"
            )
        # Use int64 as soon as either bound is int64; int32 otherwise.
        if "int64" in (li.begin.dtype, li.extent.dtype):
            dtype = "int64"
        else:
            dtype = "int32"
        self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))

    for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
        context.update_symbol(loop_var.name, loop_var, node)
        loop_range = Range.from_min_extent(loop_info.begin, loop_info.extent)
        context.loop_stack[loop_var] = loop_range
def test_complete_matmul():
    """Verify the read/write regions of the single block in ``matmul``."""
    func = matmul
    buf_a, buf_b, buf_c = [func.buffer_map[param] for param in func.params]

    block = func.body.block.body.body.body.body.block
    assert isinstance(block, tvm.tir.Block)
    vi, vj, vk = [iter_var.var for iter_var in block.iter_vars]

    def point_region(buffer, row, col):
        # Single-element region of ``buffer`` at (row, col).
        return tir.BufferRegion(
            buffer,
            [Range.from_min_extent(row, 1), Range.from_min_extent(col, 1)],
        )

    region_a = point_region(buf_a, vi, vk)
    region_b = point_region(buf_b, vj, vk)
    region_c = point_region(buf_c, vi, vj)
    tvm.ir.assert_structural_equal(block.reads, [region_c, region_a, region_b])
    tvm.ir.assert_structural_equal(block.writes, [region_c])
def non_surjective_inverse(
    self, shape: List[Union[Range, PrimExpr]]
) -> Tuple["IndexMap", PrimExpr]:
    """Return the inverse of the map

    Can be applied to transformations that introduce padding.

    Parameters
    ----------
    shape: List[Union[Range,PrimExpr]]

        The region over which the inverse should be determined.
        Entries that are not already a ``Range`` are treated as an extent
        and converted to ``Range(0, extent)``.  Used for determining the
        predicate.

    Returns
    -------
    result : Tuple[IndexMap, PrimExpr]

        The inverse, and a predicate for which the inverse maps to
        a valid index in the input range.

    Examples
    --------

    .. code-block:: python

        index_map = IndexMap.from_func(lambda i: [i//4, i%4])
        inverse_map, predicate = index_map.non_surjective_inverse([14])
        assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k]))
        print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"

    """
    regions = []
    for dim in shape:
        if not isinstance(dim, Range):
            dim = Range(0, dim)
        regions.append(dim)
    return _ffi_api.IndexMapNonSurjectiveInverse(self, regions)
def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap":
    """Return the inverse of the map

    Throws an error if the function is not bijective.

    Parameters
    ----------
    shape: List[Union[Range,PrimExpr]]

        The region over which the inverse should be determined.
        Entries that are not already a ``Range`` are treated as an extent
        and converted to ``Range(0, extent)``.  Used for validating that
        the mapping is bijective over this range.

    Returns
    -------
    inverse : IndexMap

        The inverse
    """
    regions = []
    for dim in shape:
        if not isinstance(dim, Range):
            dim = Range(0, dim)
        regions.append(dim)
    return _ffi_api.IndexMapInverse(self, regions)
def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
    """Assemble a ``BlockRealize`` from the enclosing handler's state.

    Parameters
    ----------
    axes : Optional[List[Union[PrimExpr, Range, IterVar]]]
        One entry per block var.  A ``PrimExpr`` is taken as the extent of a
        range starting at 0, a ``Range`` is used directly, and an ``IterVar``
        contributes its domain and iteration type.
    name_hint : str
        Name hint forwarded to the created ``Block``.
    span : Optional[Span]
        Span attached to the created ``Block`` and ``BlockRealize``.

    Returns
    -------
    body : BlockRealize
        The block wrapped with its iterator bindings and predicate.
    """
    assert (
        self.node and self.context and self.body
    ), "call 'exit_scope' before 'enter_scope'"
    block_info = self.context.block_info_stack[-1]

    axes = [] if axes is None else axes
    if len(axes) != len(self.block_vars):
        self.context.report_error(
            "Inconsistent number of block vars, "
            + f"there are {len(axes)} axes but {len(self.block_vars)} block vars. "
            + "The number of block vars should match the number of axes.",
            self.node.span,
        )

    block_iters: List[IterVar] = []
    for i, axis in enumerate(axes):
        axis = tvm.runtime.convert(axis)
        if isinstance(axis, tvm.tir.PrimExpr):
            dom = Range.from_min_extent(0, axis)
            block_iters.append(IterVar(dom, self.block_vars[i], 0))
        elif isinstance(axis, Range):
            block_iters.append(IterVar(axis, self.block_vars[i], 0))
        elif isinstance(axis, IterVar):
            block_iters.append(
                IterVar(axis.dom, self.block_vars[i], axis.iter_type)
            )
        else:
            self.context.report_error(
                "Invalid argument of tir.block(), "
                + f"expected PrimExpr, Range or IterVar, but got {type(axis)}",
                self.node.span,
            )

    # create block read/write regions
    reads: List[BufferRegion] = []
    if block_info.reads:
        reads = [buffer_slice_to_region(read) for read in block_info.reads]
    writes: List[BufferRegion] = []
    if block_info.writes:
        writes = [buffer_slice_to_region(write) for write in block_info.writes]

    # Bit 0: reads were not declared; bit 1: writes were not declared.
    region_detect_mask: int = (block_info.reads is None) | (
        (block_info.writes is None) << 1
    )
    annotations = block_info.annotations
    if annotations is None:
        annotations = {}
    if region_detect_mask != 0:
        annotations["tir.script_parsing_detect_access"] = region_detect_mask

    inner = tvm.tir.Block(
        block_iters,
        reads,
        writes,
        name_hint,
        self.body,
        block_info.init,
        block_info.alloc_buffers,
        block_info.match_buffers,
        annotations,
        span,
    )

    # create block var iter binding
    values: List[PrimExpr]
    if block_info.iter_bindings:
        for block_var in self.block_vars:
            if block_var not in block_info.iter_bindings:
                self.context.report_error(
                    "Missing block iter var binding for " + block_var.name,
                    self.node.span,
                )
        values = [block_info.iter_bindings[block_var] for block_var in self.block_vars]
    else:
        # No explicit bindings: fall back to the loops recorded in the
        # second-to-last entry of the loop stack.
        values = self.context.loop_stack[-2].copy()
        num_iters = len(block_iters)
        if num_iters == 0:
            # It is an opaque block without any bindings
            values = []
        elif len(values) == 0:
            values = [tvm.tir.const(float("nan"), dtype="float32")] * num_iters
        elif len(values) != num_iters:
            self.context.report_error(
                "Number of block iter var and outer loop nesting mismatch, "
                + f"{len(block_iters)} block iter vars but {len(values)} loops",
                self.node.span,
            )

    if block_info.predicate is None:
        predicate = tvm.tir.const(True, "bool")
    else:
        predicate = block_info.predicate
    body = tvm.tir.BlockRealize(values, predicate, inner, span)
    return body