def extract_call_extern_list(mod):
    """Collect every ``tir.call_extern`` call in a single-function TIR module.

    Parameters
    ----------
    mod : tvm.IRModule
        The TIR Module for NPU. It must contain exactly one function.

    Returns
    -------
    list
        The tvm.tir.Call nodes whose op is ``tir.call_extern``, in the order
        in which ``stmt_functor.post_order_visit`` reaches them.
    """
    funcs = mod.functions.items()
    # The module is expected to hold exactly one function.
    assert len(funcs) == 1
    _, primfunc = funcs[0]

    found = []

    def visit(node):
        is_call = isinstance(node, tvm.tir.Call)
        if is_call and node.op.name == "tir.call_extern":
            found.append(node)

    stmt_functor.post_order_visit(primfunc.body, visit)
    return found
def extract_ethosu_conv2d_extern_calls(mod):
    """This function will obtain all ethosu_conv2d calls from a NPU TIR module

    Parameters
    ----------
    mod : tvm.IRModule
        This is a NPU TIR Module. It must contain exactly one function.

    Returns
    -------
    list
        List of tvm.tir.Call objects that are tir extern calls
        for ethosu_conv2d, in post-order visit order.
    """
    # There should only be a single function
    assert len(mod.functions.items()) == 1
    primfunc = mod.functions.items()[0][1]

    ethosu_conv2d_calls = list()

    def populate_ethosu_conv2d_calls(stmt):
        # The registered op name is "tir.call_extern". "T.call_extern" is only
        # the TVMScript printing prefix and never appears as ``op.name``, so
        # matching on it would silently return an empty list.
        if (
            isinstance(stmt, tvm.tir.Call)
            and stmt.op.name == "tir.call_extern"
            and stmt.args[0] == "ethosu_conv2d"
        ):
            ethosu_conv2d_calls.append(stmt)

    stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls)
    return ethosu_conv2d_calls
def collect_loops(node):
    """Return every ``tvm.tir.For`` found under ``node``, in reverse post-order.

    ``post_order_visit`` reports children before their parents, so reversing
    the visit order places an outer loop ahead of the loops nested inside it.
    """
    loops = []

    def remember_loop(visited):
        if isinstance(visited, tvm.tir.For):
            loops.append(visited)

    post_order_visit(node, remember_loop)
    loops.reverse()
    return loops
def extract_buffers(stmt):
    """List the buffer of every BufferLoad, BufferStore and BufferRealize under ``stmt``.

    One entry is produced per access, in post-order visit order, so the same
    buffer may appear several times.
    """
    buffer_nodes = (tvm.tir.BufferLoad, tvm.tir.BufferStore, tvm.tir.BufferRealize)
    found = []

    def on_node(node):
        if not isinstance(node, buffer_nodes):
            return
        found.append(node.buffer)

    post_order_visit(stmt, on_node)
    return found
def extract_loop_vars(stmt):
    """Return the loop variable of every ``tvm.tir.For`` under ``stmt``.

    The order is the reverse of ``post_order_visit``'s order, so an outer
    loop's variable precedes the variables of loops nested inside it.
    """
    loop_vars = []

    def grab_loop_var(node):
        if not isinstance(node, tvm.tir.For):
            return
        loop_vars.append(node.loop_var)

    post_order_visit(stmt, grab_loop_var)
    return list(reversed(loop_vars))
def _get_allocates(primfunc):
    """Map buffer-variable names to the Allocate nodes found in ``primfunc.body``.

    Keys are ``str(allocate.buffer_var.name)``. If two Allocate nodes share a
    name, the one visited last by ``post_order_visit`` wins.
    """
    by_name = {}

    def record(node):
        if not isinstance(node, tvm.tir.Allocate):
            return
        by_name[str(node.buffer_var.name)] = node

    stmt_functor.post_order_visit(primfunc.body, record)
    return by_name
def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2():
    """Check that ConvertForLoopsToSerial leaves only serial For loops."""
    source_func = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2
    mod = tvm.tir.transform.ConvertForLoopsToSerial()(tvm.IRModule.from_expr(source_func))

    def assert_loop_is_serial(node):
        if not isinstance(node, tvm.tir.For):
            return
        assert node.kind == tvm.tir.ForKind.SERIAL

    # Every function in the transformed module is checked, not just the input one.
    for _, func in mod.functions.items():
        stmt_functor.post_order_visit(func.body, assert_loop_is_serial)
def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
    """Return the sref of the block named ``name_hint`` in ``s.mod["main"]``.

    If several blocks share that name, the last one reached by
    ``post_order_visit`` is used. An AssertionError is raised when no block
    with the name exists.
    """
    matches = []

    def collect(node):
        if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
            matches.append(node)

    post_order_visit(s.mod["main"].body, collect)
    block = matches[-1] if matches else None
    assert block is not None and isinstance(block, tvm.tir.Block)
    return s.get_sref(block)
def extract_logical_indices(stmt):
    """Return the For-loop variables under ``stmt``, ordered by increasing extent.

    Loops with equal extents keep the order in which ``post_order_visit``
    first reached them. Each loop's ``extent`` is read via ``.value``, so
    extents are assumed to be constant integers.
    """
    # The layout transformation may reorder the for loops, so each
    # pre-transformation axis is identified by its iteration extent rather
    # than by its position in the loop nest.
    extents = {}

    def record_extent(node):
        if isinstance(node, tvm.tir.For):
            extents[node.loop_var] = node.extent.value

    post_order_visit(stmt, record_extent)
    return sorted(extents, key=lambda var: extents[var])
def walk_buffer_interactions(stmt, callback, buffer_name="B"):
    """Invoke ``callback`` on each access to the buffer named ``buffer_name``.

    Parameters
    ----------
    stmt
        Root node to traverse with ``post_order_visit``.
    callback : callable
        Called once with each matching BufferLoad, BufferStore or
        BufferRealize node, in post-order visit order.
    buffer_name : str, optional
        Name of the buffer whose accesses are reported. Defaults to "B",
        which keeps the previous hard-coded behaviour for existing callers.
    """
    # Exact type match (not isinstance): subclasses of these nodes are skipped.
    buffer_classes = (
        tvm.tir.BufferLoad,
        tvm.tir.BufferStore,
        tvm.tir.BufferRealize,
    )

    def inner(node):
        if type(node) in buffer_classes and node.buffer.name == buffer_name:
            callback(node)

    post_order_visit(stmt, inner)
def extract_ethosu_depthwise_conv2d_extern_call(mod):
    """Return the first ``ethosu_depthwise_conv2d`` extern call in a NPU TIR module.

    Parameters
    ----------
    mod : tvm.IRModule
        NPU TIR module that must contain exactly one function.

    Returns
    -------
    tvm.tir.Call
        The first ``tir.call_extern`` call whose first argument is
        "ethosu_depthwise_conv2d", in post-order visit order. Any further
        matches are ignored; an IndexError propagates when there is none.
    """
    entries = mod.functions.items()
    # There should only be a single function
    assert len(entries) == 1
    primfunc = entries[0][1]

    matches = []

    def visit(node):
        if not isinstance(node, tvm.tir.Call):
            return
        if node.op.name != "tir.call_extern":
            return
        if node.args[0] == "ethosu_depthwise_conv2d":
            matches.append(node)

    stmt_functor.post_order_visit(primfunc.body, visit)
    return matches[0]
def test_2d_physical(self, dtype, transform_A, transform_B):
    """Check that lowered buffer accesses use the indices implied by each layout transform."""
    logical_shape = (2, 3, 4)
    A = te.placeholder(shape=logical_shape, dtype=dtype, name="A")
    B = te.compute(shape=A.shape, fcompute=lambda i, j, k: A[i, j, k], name="B")

    sched = te.create_schedule(B.op)
    # A's transform is applied before B's; either may be absent.
    for tensor, transform in ((A, transform_A), (B, transform_B)):
        layout_func = self.get_transform(transform)
        if layout_func:
            sched[tensor].transform_layout(layout_func)

    # If the two buffers are accessed with the same indices, CSE
    # will replace them with a Let binding. Since this makes it
    # harder to test what the transformed indices are, disabling
    # the CSE pass for this test.
    with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
        mod = tvm.lower(sched, [A, B])

    body = mod["main"].body
    logical_index_vars = self.extract_logical_indices(body)
    expected_by_buffer = {
        "A": self.transform_indices(transform_A, logical_shape, logical_index_vars),
        "B": self.transform_indices(transform_B, logical_shape, logical_index_vars),
    }

    def check_indices(node):
        if type(node) not in (tvm.tir.BufferLoad, tvm.tir.BufferStore):
            return
        name = node.buffer.name
        if name not in expected_by_buffer:
            raise RuntimeError(f"Unexpected buffer: {name}")
        tvm.ir.assert_structural_equal(expected_by_buffer[name], node.indices)

    post_order_visit(body, check_indices)