def get_strides(index, stride_vars):
    """Compute the stride of each requested variable within an index expression.

    Every variable found in ``index`` is first pinned to the point interval
    [0, 0] to obtain a baseline value for the expression. Then each variable
    that is also listed in ``stride_vars`` is pinned to [1, 1] in turn, with
    all other variables held at zero. The difference between that value and
    the baseline is recorded as the variable's stride.

    Parameters
    ----------
    index : tvm.tir.PrimExpr
        The index expression in which the stride vars appear.
    stride_vars : list of tvm.tir.Var
        The vars whose striding should be determined.

    Returns
    -------
    strides : list of int
        One stride per entry of ``stride_vars``, in the same order. Entries
        for vars that never occur in ``index`` keep the default value 1.
    """
    result = [1] * len(stride_vars)
    var_ranges = {}

    def _pin_to_zero(node):
        # Collect every Var reachable from the expression, fixed at zero.
        if isinstance(node, tvm.tir.Var):
            var_ranges[node] = arith.IntervalSet(0, 0)

    tvm.tir.stmt_functor.post_order_visit(index, _pin_to_zero)
    baseline = int(arith.Analyzer().int_set(index, var_ranges).min_value)

    for candidate in var_ranges:
        if candidate not in stride_vars:
            continue
        # NOTE: Using a single [0, 1] interval instead of pinning to [1, 1]
        # does not work reliably (apparently an analyzer bug).
        var_ranges[candidate] = arith.IntervalSet(1, 1)
        upper = int(arith.Analyzer().int_set(index, var_ranges).max_value)
        result[stride_vars.index(candidate)] = int(upper - baseline)
        # Restore the zero pin so later candidates are measured independently.
        var_ranges[candidate] = arith.IntervalSet(0, 0)
    return result
def get_base_address(index):
    """Determine the first (base) address accessed by an index expression.

    Every variable appearing in ``index`` is pinned to the point interval
    [0, 0]. The minimum of the resulting integer set is returned.

    Parameters
    ----------
    index : tvm.tir.PrimExpr
        The index expression to determine the base address of.

    Returns
    -------
    base_address : int
        The value of ``index`` with all of its variables set to zero.
    """
    zero_domain = {}

    def _record(node):
        # Pin each Var encountered in the expression tree to zero.
        if isinstance(node, tvm.tir.Var):
            zero_domain[node] = arith.IntervalSet(0, 0)

    tvm.tir.stmt_functor.post_order_visit(index, _record)
    return int(arith.Analyzer().int_set(index, zero_domain).min_value)