def test_storage_share():
    """A chain of elementwise stages should fold into a single allocation."""
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    num_stage = 5
    B = A
    for t in range(num_stage):
        B = te.compute((m, l), lambda i, j: B[i, j] + (t + 1), name="A%d" % t)

    sched = te.create_schedule(B.op)
    module = schedule_to_module(sched, [A, B])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.StorageRewrite(),
    ):
        module = xform(module)
    body = module["main"].body

    # In-place folding should leave exactly one Allocate node behind.
    alloc_count = 0

    def count_allocs(node):
        nonlocal alloc_count
        if isinstance(node, tvm.tir.Allocate):
            alloc_count += 1

    tvm.tir.stmt_functor.post_order_visit(body, count_allocs)
    assert alloc_count == 1
def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
    """In-place folding over a diamond dataflow with custom-scope cached reads."""
    # Register the custom memory scope used by the cached stages.
    register_mem(scope_tb, max_bits)
    m = 10
    A = te.placeholder((m, ), name="A")
    C = te.placeholder((m, ), name="C")
    D = te.placeholder((m, ), name="D")
    A0 = te.compute((m, ), lambda i: A[i] + C[i], name="A0")
    A1 = te.compute((m, ), lambda i: D[i] * D[i], name="A1")
    A2 = te.compute((m, ), lambda i: A0[i] + A1[i], name="A2")
    B = te.compute((m, ), lambda i: A2[i], name="B")

    sched = te.create_schedule(B.op)
    # cache_read mutates the schedule in place; the returned stages are unused.
    sched.cache_read(A0, scope_tb, [A2])
    sched.cache_read(A1, scope_tb, [A2])
    sched.cache_read(A2, scope_tb, [B])

    module = schedule_to_module(sched, [A, B, C, D])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.StorageRewrite(),
    ):
        module = xform(module)
    body = module["main"].body

    # StorageRewrite should fold the intermediates down to two allocations.
    alloc_count = 0

    def count_allocs(node):
        nonlocal alloc_count
        if isinstance(node, tvm.tir.Allocate):
            alloc_count += 1

    tvm.tir.stmt_functor.post_order_visit(body, count_allocs)
    assert alloc_count == 2
def test_storage_combine_with_vectorization():
    """Combined storage plus vectorized loads must not alias each other."""
    n = 1024
    A = te.placeholder((n, ), name="A")
    B = te.placeholder((n, ), name="B")
    C = te.compute((n, ), lambda i: A[i] + B[i], name="C")

    sched = te.create_schedule(C.op)
    a_cache = sched.cache_read(A, "global:tag", readers=[C])
    b_cache = sched.cache_read(B, "global:tag", readers=[C])
    c_cache = sched.cache_write(C, "global:tag")
    sched[c_cache].vectorize(sched[c_cache].op.axis[0])

    module = schedule_to_module(sched, [A, B, C])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.VectorizeLoop(),
            tvm.tir.transform.StorageRewrite(),
            tvm.tir.transform.Simplify(),
    ):
        module = xform(module)
    body = module["main"].body

    alloc_count = 0

    def checker(node):
        nonlocal alloc_count
        if isinstance(node, tvm.tir.Allocate):
            alloc_count += 1
            return
        # Find the vectorized add of two ramp loads.
        is_vec_add = (isinstance(node, tvm.tir.Add)
                      and isinstance(node.a, tvm.tir.Load)
                      and isinstance(node.b, tvm.tir.Load))
        if is_vec_add:
            lhs_ramp = node.a.index
            rhs_ramp = node.b.index
            assert lhs_ramp.lanes == n
            assert rhs_ramp.lanes == n
            # The two ramp loads must occupy disjoint regions of the
            # combined buffer.
            assert (lhs_ramp.base >= rhs_ramp.base + n
                    or rhs_ramp.base >= lhs_ramp.base + n)

    tvm.tir.stmt_functor.post_order_visit(body, checker)
    assert alloc_count == 1
def test_copy_pad_split():
    """memcpy lowering of a padded stage computed under a split output loop."""
    m = 4 * 3
    A = te.placeholder((m, ), name="A")
    Apad = te.compute(
        (m + 2, ),
        lambda i: tvm.tir.if_then_else(
            tvm.tir.all(i >= 1, i <= m), A[i - 1], 0.0),
        "Apad",
    )
    B = te.compute((m, ), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])

    sched = te.create_schedule(B.op)
    xo, xi = sched[B].split(B.op.axis[0], factor=4)
    sched[Apad].compute_at(sched[B], xo)
    sched[Apad].pragma(sched[Apad].op.axis[0], "memcpy")

    module = schedule_to_module(sched, [A, B])
    module = tvm.tir.transform.StorageFlatten(64)(module._move())
    module = tvm.tir.transform.Simplify()(module._move())

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        # The destination tile always starts at the buffer origin.
        assert dst.elem_offset.value == 0
        tvm.testing.assert_prim_expr_equal(src.elem_offset,
                                           tvm.te.max(xo * 4, 1) - 1)
        # Padding is only needed at the two extreme tiles.
        rpad_before = tvm.te.max(1 - xo * 4, 0)
        rpad_after = tvm.te.max(xo * 4 - 7, 0)
        tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before)
        tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after)
        tvm.testing.assert_prim_expr_equal(src.shape[0],
                                           6 - rpad_before - rpad_after)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", check_copy)(module)["main"].body
def test_inplace_rule():
    """Two intermediate producers feeding AA fold into two allocations total."""
    m = 10
    A = te.placeholder((m, ), name="A")
    A0 = te.compute((m, ), lambda i: A[i], name="A0")
    A1 = te.compute((m, ), lambda i: A[i] + 1, name="A1")
    AA = te.compute((m, ), lambda i: A0[i] + A1[i] + A1[0], name="AA")
    B = te.compute((m, ), lambda i: AA[i] + 1, name="B")

    sched = te.create_schedule(B.op)
    module = schedule_to_module(sched, [A, B])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.StorageRewrite(),
    ):
        module = xform(module)
    body = module["main"].body

    # Count the Allocate nodes remaining after in-place folding.
    alloc_count = 0

    def count_allocs(node):
        nonlocal alloc_count
        if isinstance(node, tvm.tir.Allocate):
            alloc_count += 1

    tvm.tir.stmt_functor.post_order_visit(body, count_allocs)
    assert alloc_count == 2
def test_copy_pad():
    """A padded copy lowers to a memcpy intrinsic with the expected pad widths."""
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    B = te.compute(
        (m + 2, l),
        lambda i, j: tvm.tir.if_then_else(
            tvm.tir.all(i >= 1, i < m + 1), A[i - 1, j], 1.0),
        name="B",
    )

    sched = te.create_schedule(B.op)
    sched[B].pragma(B.op.axis[0], "memcpy")
    module = schedule_to_module(sched, [A, B])
    module = tvm.tir.transform.StorageFlatten(64)(module)

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        # Whole of A is copied; one padded row above and one below.
        tvm.testing.assert_prim_expr_equal(src.elem_offset, 0)
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", check_copy)(module)["main"].body
def test_storage_share_gpu():
    """Shared-scope stages under GPU thread binding share their allocations."""
    m = te.var("m")
    stages = [te.placeholder((m), name="A")]
    num_stage = 5
    for t in range(num_stage):
        stages.append(
            te.compute((m, ), lambda i: stages[-1][i] + (t + 1), name="A%d_s" % t))
        stages.append(te.compute((m, ), lambda i: stages[-1][i], name="A%d" % t))

    sched = te.create_schedule(stages[-1].op)
    for t in range(num_stage):
        out = stages[2 * t + 2]
        mid = stages[2 * t + 1]
        bx, tx = sched[out].split(out.op.axis[0], factor=32)
        sched[out].bind(bx, te.thread_axis("blockIdx.x"))
        sched[out].bind(tx, te.thread_axis("threadIdx.x"))
        # The intermediate stage lives in shared memory under the thread loop.
        sched[mid].compute_at(sched[out], tx)
        sched[mid].set_scope("shared")

    module = schedule_to_module(sched, [stages[0], stages[-1]])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.StorageRewrite(),
    ):
        module = xform(module)
    body = module["main"].body

    alloc_stats = {"global": 0, "shared": 0}

    def tally(node):
        if isinstance(node, tvm.tir.Allocate):
            alloc_stats[node.buffer_var.type_annotation.storage_scope] += 1

    tvm.tir.stmt_functor.post_order_visit(body, tally)
    assert alloc_stats["global"] == 2
    assert alloc_stats["shared"] == num_stage
def test_storage_combine():
    """Tagged intermediate stages are combined into one 16-element allocation."""
    n = 8
    A = te.placeholder((4, ), name="A")
    num_stage = 5
    cur = A
    stages = []
    for t in range(num_stage):
        cur = te.compute((n, ), lambda i: cur[i] + cur[0] + (t + 1), name="A%d" % t)
        stages.append(cur)

    sched = te.create_schedule(cur.op)
    # All but the final stage share the tagged scope, so they can be combined.
    for stage in stages[:-1]:
        sched[stage].set_scope("global:tag")

    module = schedule_to_module(sched, [A, cur])
    for xform in (
            tvm.tir.transform.StorageFlatten(64),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.StorageRewrite(),
    ):
        module = xform(module)
    body = module["main"].body

    alloc_count = 0

    def check_alloc(node):
        nonlocal alloc_count
        if isinstance(node, tvm.tir.Allocate):
            alloc_count += 1
            assert node.extents[0].value == 16

    tvm.tir.stmt_functor.post_order_visit(body, check_alloc)
    assert alloc_count == 1
def test_internal_reshape(self, target, dev, n_items, reordered_shape, dtype, fphysical_layout):
    """Layout transform of an internal buffer is applied by ApplyLayoutTransforms.

    The reshaping of the buffer gets flattened away in
    StorageFlatten.  Therefore, testing the behavior by running only
    ApplyLayoutTransforms.
    """
    logical_shape = (n_items, )
    A = te.placeholder(logical_shape, name="A", dtype=dtype)
    B = te.compute(shape=logical_shape, fcompute=lambda i: A[i], name="B")
    C = te.compute(shape=logical_shape, fcompute=lambda i: B[i], name="C")

    s = te.create_schedule(C.op)
    # Request the physical layout for the internal buffer B only.
    s[B].transform_layout(fphysical_layout)
    mod = schedule_to_module(s, [A, C])

    body = mod["main"].body

    def walk_buffer_interactions(stmt, callback):
        # Invoke ``callback`` on every node that touches buffer "B".
        buffer_classes = [
            tvm.tir.BufferLoad,
            tvm.tir.BufferStore,
            tvm.tir.BufferRealize,
        ]

        def inner(node):
            if (type(node) in buffer_classes) and node.buffer.name == "B":
                callback(node)

        post_order_visit(stmt, inner)

    # All references to the buffer are the same object
    def check_references():
        buffer_object = None

        def inner(node):
            nonlocal buffer_object
            if buffer_object is None:
                # Remember the first Buffer seen; all later ones must be
                # the identical object, not merely structurally equal.
                buffer_object = node.buffer
            else:
                assert node.buffer.same_as(buffer_object)

        return inner

    # The buffer has the expected shape.
    def check_shape(expected_shape):
        def inner(node):
            assert tuple(node.buffer.shape) == expected_shape

        return inner

    # Before the transform, the buffer should be in the logical shape.
    walk_buffer_interactions(body, check_references())
    walk_buffer_interactions(body, check_shape(logical_shape))

    mod = tvm.tir.transform.ApplyLayoutTransforms()(mod)
    body = mod["main"].body

    # After the transform, the buffer should be in the physical shape.
    walk_buffer_interactions(body, check_references())
    walk_buffer_interactions(body, check_shape(reordered_shape))
def test_schedule0():
    """A plain identity compute lowers to a valid PrimFunc."""
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")

    sched = te.create_schedule(A1.op)
    module = schedule_to_module(sched, [A, A1])
    assert isinstance(module["main"], tvm.tir.PrimFunc)
def run_passes(sch, args):
    """Lower ``sch`` through the flatten/vectorize/rewrite/merge pipeline."""
    pipeline = tvm.transform.Sequential([
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.VectorizeLoop(),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
    ])
    return pipeline(schedule_to_module(sch, args))
def run_passes(sch, args):
    """Attach the enclosing-scope ``target`` and run the allreduce pipeline."""
    module = schedule_to_module(sch, args)
    # ``target`` is captured from the enclosing scope.
    module = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", target))(module)
    pipeline = tvm.transform.Sequential([
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.LowerThreadAllreduce(),
    ])
    return pipeline(module)
def test_schedule1():
    """Split plus an auto_unroll pragma still lowers to a PrimFunc."""
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")

    sched = te.create_schedule(A1.op)
    outer, inner = sched[A1].split(A1.op.axis[0], 8)
    sched[A1].pragma(outer, "auto_unroll_max_step", 10)
    module = schedule_to_module(sched, [A, A1])
    assert isinstance(module["main"], tvm.tir.PrimFunc)
def test_schedule2():
    """compute_at under a split axis still lowers to a PrimFunc."""
    m = te.var("m")
    l = te.var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    sched = te.create_schedule(A2.op)
    outer, inner = sched[A2].split(A2.op.axis[0], 8)
    sched[A1].compute_at(sched[A2], outer)
    module = schedule_to_module(sched, [A, A2])
    assert isinstance(module["main"], tvm.tir.PrimFunc)
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR
    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
    """
    sch = sch.normalize()
    # Phase 0: keep trivial (extent-1) loops so the analyzer sees every axis.
    context = tvm.transform.PassContext(
        config={"tir.debug_keep_trivial_loop": True})
    with context:
        mod = build_module.schedule_to_module(sch, args, binds=binds)
        mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
        mod = tvm.tir.transform.Simplify()(mod._move())
    # Only the simple (statement-body) mode is supported by this helper.
    assert simple_mode
    return mod["main"].body
def test_flatten2():
    """StorageFlatten succeeds when explicit buffers are bound to the args."""
    m = te.size_var("m")
    l = te.size_var("l")
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    sched = te.create_schedule(A2.op)
    outer, inner = sched[A2].split(A2.op.axis[0], 8)
    sched[A1].compute_at(sched[A2], outer)

    # Bind explicit buffers to the input and output tensors.
    buf_a = tvm.tir.decl_buffer(A.shape, A.dtype, name="A")
    buf_a2 = tvm.tir.decl_buffer(A2.shape, A2.dtype, name="A2")
    module = schedule_to_module(sched, [buf_a, buf_a2], binds={A: buf_a, A2: buf_a2})
    module = tvm.tir.transform.StorageFlatten(64)(module)
def test_single_point_test():
    """A one-element copy yields a memcpy with zero offsets and unit strides."""
    A = te.placeholder((1, ), name="A")
    B = te.compute((1, ), lambda i: A[i], name="B")

    sched = te.create_schedule(B.op)
    sched[B].pragma(B.op.axis[0], "memcpy")
    module = schedule_to_module(sched, [A, B])
    module = tvm.tir.transform.StorageFlatten(64)(module)

    def check_copy(src, dst, pad_before, pad_after, pad_value):
        # Both endpoints start at the buffer origin with unit stride.
        for buf in (src, dst):
            tvm.testing.assert_prim_expr_equal(buf.elem_offset, 0)
            tvm.testing.assert_prim_expr_equal(buf.strides[0], 1)
        return tvm.tir.Evaluate(0)

    stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", check_copy)(module)["main"].body
def lower_sch(sch, args, target_bits):
    """Lower ``sch``, run NarrowDataType(target_bits), and return the PrimFunc body.

    Raises ValueError if any entry of ``args`` is not a te.Tensor.
    """
    binds = {}
    arg_list = []
    for x in args:
        if isinstance(x, te.tensor.Tensor):
            buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    sch = sch.normalize()
    # NOTE(review): ``binds``/``arg_list`` are built above but never passed to
    # schedule_to_module — the loop only serves to validate that every arg is
    # a Tensor.  Confirm whether the binds were meant to be forwarded.
    mod = schedule_to_module(sch, args)
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
def test_flatten_storage_align():
    """storage_align pads each row of A1, growing the allocation to 17 * 8."""
    m = 8
    l = 16
    A = te.placeholder((m, l), name="A")
    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")

    sched = te.create_schedule(A2.op)
    sched[A1].storage_align(A1.op.axis[0], 2, 1)

    module = schedule_to_module(sched, [A, A2])
    pipeline = tvm.transform.Sequential(
        [tvm.tir.transform.StorageFlatten(64),
         tvm.tir.transform.Simplify()])
    body = pipeline(module)["main"].body
    # The alignment padding grows each 16-element row to 17 elements.
    assert body.extents[0].value == 17 * 8
def test_makeapi():
    """Not yet working, mock design"""
    n = te.size_var("n")
    A = te.placeholder((n, ), name="A")
    B = te.placeholder((n, ), name="B")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")

    sched = te.create_schedule(C.op)
    module = schedule_to_module(sched, [n, A, B, C])
    module = tvm.tir.transform.StorageFlatten(64)(module)
    module = tvm.tir.transform.Apply(lambda f: f.with_attr({
        "target": tvm.target.Target("llvm"),
        "global_symbol": "main",
    }))(module)

    num_unpacked_args = 2
    f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(module)["main"]
    # With two args left unpacked, the packed API expands to 8 parameters.
    assert len(f.params) == 8
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The resulting TIR module will contain a single function
    that consists of a sequence of tir.call_extern to NPU
    operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list of tvm.te.Tensor, TEGraph]
        The input/output tensors.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The modified constant dictionary.
    """
    if not isinstance(args, list):
        # A TEGraph-like object: flatten it to inputs followed by outputs.
        args = list(args.inputs) + list(args.outputs)

    # config setup
    curr_pass_ctx = tvm.ir.transform.PassContext.current()
    curr_cfg = dict()
    for key, value in curr_pass_ctx.config.items():
        curr_cfg[key] = value
    tir_compiler_cfg = {
        "tir.LoopPartition": {
            "partition_const_loop": True,
            "no_unroll_loop_with_extent_one": True,
        },
        "tir.UnrollLoop": {
            "auto_max_depth": -1
        },
        "tir.noalias": True,
        "tir.debug_keep_trivial_loop": True,
    }
    # Merge two configs; the Ethos-U-specific settings take precedence.
    curr_cfg = {**curr_cfg, **tir_compiler_cfg}

    sch = sch.normalize()

    with tvm.transform.PassContext(config=curr_cfg):
        mod = schedule_to_module(sch, args, name)

        mod = tvm.tir.transform.Simplify()(mod)
        mod = ethosu_passes.RemoveConcatenates()(mod)
        mod = tvm.tir.transform.InjectRollingBuffer()(mod)
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.UnrollLoop()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = ethosu_passes.RemoveZeroStores()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.ReplaceOperators()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
        mod = ethosu_passes.HoistAllocates()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        # MergeConstant pass currently does not support striped schedules.
        # It requires further investigation.
        if not util.is_striping_enabled():
            mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
        mod = ethosu_passes.CopyComputeReordering()(mod)

        # When striping is enabled and if storage_rewrite is not run
        # the striping results in incorrect code generation. This needs
        # further investigation. Until such a time that is fixed, disable_storage_rewrite
        # user directive will be overridden if striping is enabled.
        disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
        if not disable_storage_rewrite or util.is_striping_enabled():
            mod = tvm.tir.transform.StorageRewrite()(mod)

        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.AnnotateAllocates()(mod)
        mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
    return mod, const_dict
def lower_ethosu(sch, args, const_dict, name="main"):
    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

    The resulting TIR module will contain a single function
    that consists of a sequence of tir.call_extern to NPU
    operations.

    Parameters
    ----------
    sch : tvm.te.Schedule
        The schedule to be lowered.
    args : Union[list of tvm.te.Tensor, TEGraph]
        The input/output tensors.
    const_dict : dict of int to numpy.ndarray
        The constant dictionary.
    name : str, optional
        The name of the lowered primitive function.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    const_dict : dict of int to numpy.ndarray
        The modified constant dictionary.
    """
    if not isinstance(args, list):
        # A TEGraph-like object: flatten it to inputs followed by outputs.
        args = list(args.inputs) + list(args.outputs)

    # config setup
    curr_pass_ctx = tvm.ir.transform.PassContext.current()
    curr_cfg = dict()
    for key, value in curr_pass_ctx.config.items():
        curr_cfg[key] = value
    tir_compiler_cfg = {
        "tir.LoopPartition": {
            "partition_const_loop": True,
            "no_unroll_loop_with_extent_one": True,
        },
        "tir.UnrollLoop": {
            "auto_max_depth": -1
        },
        "tir.noalias": True,
        "tir.debug_keep_trivial_loop": True,
    }
    # Merge two configs; the Ethos-U-specific settings take precedence.
    curr_cfg = {**curr_cfg, **tir_compiler_cfg}

    sch = sch.normalize()

    with tvm.transform.PassContext(config=curr_cfg):
        mod = schedule_to_module(sch, args, name)

        mod = tvm.tir.transform.Simplify()(mod)
        mod = ethosu_passes.RemoveConcatenates()(mod)
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.UnrollLoop()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = ethosu_passes.RemoveZeroStores()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.ReplaceOperators()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
        mod = tvm.tir.transform.StorageRewrite()(mod)
        mod = tvm.tir.transform.RemoveNoOp()(mod)
        mod = ethosu_passes.AnnotateAllocates()(mod)
    return mod, const_dict
def test_inplace_rule3():
    """In-place folding over a cross-product of cached reads and writes.

    After StorageRewrite, every remaining allocation in the custom scope
    should hold 70 elements (presumably 7 live 10-element buffers packed
    together — confirm against the StorageRewrite folding rules).
    """
    # Test Buffer
    scope_tb = "local_TB3"
    max_bits = 1024 * 1024 * 1024

    register_mem(scope_tb, max_bits)
    m = 10

    B0 = te.placeholder((m, ), name="B0")
    B1 = te.placeholder((m, ), name="B1")
    B2 = te.placeholder((m, ), name="B2")
    B3 = te.placeholder((m, ), name="B3")
    B4 = te.placeholder((m, ), name="B4")
    B5 = te.placeholder((m, ), name="B5")

    # Three 2x2-determinant-style products of the six inputs.
    B6 = te.compute((m, ), lambda i: B1[i] * B5[i], name="B6")
    B7 = te.compute((m, ), lambda i: B2[i] * B4[i], name="B7")
    B8 = te.compute((m, ), lambda i: B6[i] - B7[i], name="B8")
    B9 = te.compute((m, ), lambda i: B2[i] * B3[i], name="B9")
    B10 = te.compute((m, ), lambda i: B0[i] * B5[i], name="B10")
    B11 = te.compute((m, ), lambda i: B9[i] - B10[i], name="B11")
    B12 = te.compute((m, ), lambda i: B0[i] * B4[i], name="B12")
    B13 = te.compute((m, ), lambda i: B1[i] * B3[i], name="B13")
    B14 = te.compute((m, ), lambda i: B12[i] - B13[i], name="B14")

    B = te.compute((m, ), lambda i: B8[i] * B11[i] + B14[i], name="B")

    s = te.create_schedule(B.op)

    # Each input is cached once and shared by its two consumers.
    B1L = s.cache_read(B1, scope_tb, [B6, B13])
    B5L = s.cache_read(B5, scope_tb, [B6, B10])
    B2L = s.cache_read(B2, scope_tb, [B7, B9])
    B4L = s.cache_read(B4, scope_tb, [B7, B12])
    B3L = s.cache_read(B3, scope_tb, [B9, B13])
    B0L = s.cache_read(B0, scope_tb, [B10, B12])

    # Every intermediate is written through the custom scope as well.
    B8L = s.cache_write(B8, scope_tb)
    B11L = s.cache_write(B11, scope_tb)
    B14L = s.cache_write(B14, scope_tb)
    B6L = s.cache_write(B6, scope_tb)
    B7L = s.cache_write(B7, scope_tb)
    B9L = s.cache_write(B9, scope_tb)
    B10L = s.cache_write(B10, scope_tb)
    B12L = s.cache_write(B12, scope_tb)
    B13L = s.cache_write(B13, scope_tb)

    # Inline the original stages so only the cached copies remain.
    s[B12].compute_inline()
    s[B13].compute_inline()
    s[B8].compute_inline()
    s[B11].compute_inline()
    s[B14].compute_inline()
    s[B6].compute_inline()
    s[B7].compute_inline()
    s[B9].compute_inline()
    s[B10].compute_inline()

    s = s.normalize()
    mod = schedule_to_module(s, [B0, B1, B2, B3, B4, B5, B])
    mod = tvm.tir.transform.StorageFlatten(64)(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)
    stmt = mod["main"].body

    # verify only have one allocations.
    # verify inplace folding works
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            assert n.extents[0].value == 70

    tvm.tir.stmt_functor.post_order_visit(stmt, verify)