def extract_call_extern_list(mod):
    """Collect every ``tir.call_extern`` call in a single-function TIR module.

    Parameters
    ----------
    mod : tvm.IRModule
        The TIR Module for NPU. It must contain exactly one function
        (enforced with an assert).

    Returns
    -------
    list
        The tvm.tir.Call objects whose op is ``tir.call_extern``, in the
        order ``post_order_visit`` reaches them.
    """
    # There should only be a single function
    func_items = mod.functions.items()
    assert len(func_items) == 1
    _, prim_func = func_items[0]

    found = []

    def _visit(node):
        if not isinstance(node, tvm.tir.Call):
            return
        if node.op.name == "tir.call_extern":
            found.append(node)

    stmt_functor.post_order_visit(prim_func.body, _visit)
    return found
Example #2
0
def extract_ethosu_conv2d_extern_calls(mod):
    """This function will obtain all ethosu_conv2d
    calls from a NPU TIR module

    Parameters
    ----------
    mod : tvm.IRModule
        This is a NPU TIR Module. It must contain exactly one function
        (enforced with an assert).

    Returns
    -------
    list
        List of tvm.tir.Call objects
        that are tir extern calls
        for ethosu_conv2d, in the order post_order_visit reaches them
    """
    # There should only be a single function
    assert len(mod.functions.items()) == 1
    primfunc = mod.functions.items()[0][1]

    ethosu_conv2d_calls = list()

    def populate_ethosu_conv2d_calls(stmt):
        # The extern-call Op is registered as "tir.call_extern". "T.call_extern"
        # is only how TVMScript prints it, so matching against that spelling
        # never succeeded and this helper always returned an empty list.
        if (
            isinstance(stmt, tvm.tir.Call)
            and stmt.op.name == "tir.call_extern"
            and stmt.args[0] == "ethosu_conv2d"
        ):
            ethosu_conv2d_calls.append(stmt)

    stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls)
    return ethosu_conv2d_calls
Example #3
0
        def collect_loops(node):
            """Return every tvm.tir.For under ``node`` in reversed post-order.

            Reversing post-order places an enclosing loop before the loops
            nested inside it.
            """
            loops = []

            def _record_for(candidate):
                if isinstance(candidate, tvm.tir.For):
                    loops.append(candidate)

            post_order_visit(node, _record_for)
            loops.reverse()
            return loops
Example #4
0
def extract_buffers(stmt):
    """Return the buffer of every BufferLoad/BufferStore/BufferRealize under ``stmt``.

    One entry is produced per access node, in post_order_visit order, so a
    buffer touched several times appears several times.
    """
    access_types = (tvm.tir.BufferLoad, tvm.tir.BufferStore, tvm.tir.BufferRealize)
    found = []

    def _collect(node):
        if isinstance(node, access_types):
            found.append(node.buffer)

    post_order_visit(stmt, _collect)
    return found
Example #5
0
    def extract_loop_vars(stmt):
        """Return the loop variable of every tvm.tir.For under ``stmt``.

        The post_order_visit order is reversed, so an enclosing loop's
        variable comes before the variables of loops nested inside it.
        """
        loop_vars = []

        def _on_node(node):
            if not isinstance(node, tvm.tir.For):
                return
            loop_vars.append(node.loop_var)

        post_order_visit(stmt, _on_node)
        return list(reversed(loop_vars))
Example #6
0
def _get_allocates(primfunc):
    """Map the buffer-variable name of each tvm.tir.Allocate in ``primfunc.body`` to its node.

    When two Allocate nodes share a name, the one visited later by
    post_order_visit overwrites the earlier entry.
    """
    by_name = {}

    def _record(node):
        if not isinstance(node, tvm.tir.Allocate):
            return
        key = str(node.buffer_var.name)
        by_name[key] = node

    stmt_functor.post_order_visit(primfunc.body, _record)
    return by_name
def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2():
    """After ConvertForLoopsToSerial, every For loop in the module must be serial."""
    source_func = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2
    to_serial = tvm.tir.transform.ConvertForLoopsToSerial()
    mod = to_serial(tvm.IRModule.from_expr(source_func))

    def verify_serial_loops(stmt):
        if not isinstance(stmt, tvm.tir.For):
            return
        assert stmt.kind == tvm.tir.ForKind.SERIAL

    for _gvar, func in mod.functions.items():
        stmt_functor.post_order_visit(func.body, verify_serial_loops)
def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
    """Return the sref in ``s`` of the block in ``s.mod["main"]`` named ``name_hint``.

    If several blocks share the name, the last one reached by
    post_order_visit is used. An assertion fails when no block matches.
    """
    matches = []

    def f_visit(node):
        if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
            matches.append(node)

    post_order_visit(s.mod["main"].body, f_visit)
    result = matches[-1] if matches else None
    assert result is not None and isinstance(result, tvm.tir.Block)
    return s.get_sref(result)
Example #9
0
    def extract_logical_indices(stmt):
        """Return the loop variables under ``stmt``, ordered by ascending extent.

        The layout transformation can reorder the for loops, so each
        pre-transformation axis is identified by its iteration extent rather
        than by its position in the loop nest. Loops with equal extents keep
        their post_order_visit order (sorted() is stable).
        NOTE(review): assumes every loop extent is a constant (has ``.value``)
        and that extents are distinct per axis -- confirm for new shapes.
        """
        extent_of = {}

        def callback(node):
            if isinstance(node, tvm.tir.For):
                extent_of[node.loop_var] = node.extent.value

        post_order_visit(stmt, callback)
        return sorted(extent_of, key=lambda loop_var: extent_of[loop_var])
Example #10
0
        def walk_buffer_interactions(stmt, callback):
            """Invoke ``callback`` on every access to the buffer named "B" under ``stmt``.

            An access is a BufferLoad, BufferStore, or BufferRealize node.
            The match is on exact type (not isinstance), so subclasses are
            not matched.
            """
            tracked_types = (
                tvm.tir.BufferLoad,
                tvm.tir.BufferStore,
                tvm.tir.BufferRealize,
            )

            def inner(node):
                if type(node) not in tracked_types:
                    return
                if node.buffer.name == "B":
                    callback(node)

            post_order_visit(stmt, inner)
    def extract_ethosu_depthwise_conv2d_extern_call(mod):
        """Return the first ``ethosu_depthwise_conv2d`` extern call in ``mod``.

        ``mod`` must contain exactly one function (enforced with an assert).
        "First" means first in post_order_visit order. Raises IndexError
        when the function body contains no such call.
        """
        # There should only be a single function
        func_items = mod.functions.items()
        assert len(func_items) == 1
        _, prim_func = func_items[0]

        matching = []

        def _visit(stmt):
            is_extern = isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern"
            if is_extern and stmt.args[0] == "ethosu_depthwise_conv2d":
                matching.append(stmt)

        stmt_functor.post_order_visit(prim_func.body, _visit)
        return matching[0]
Example #12
0
    def test_2d_physical(self, dtype, transform_A, transform_B):
        """Check the buffer indices produced by lowering with layout transforms.

        Every BufferLoad/BufferStore of ``A`` or ``B`` in the lowered main
        function must use exactly the indices that ``self.transform_indices``
        computes for that buffer's transform. Any other buffer name raises
        RuntimeError.
        """
        logical_shape = (2, 3, 4)
        A = te.placeholder(shape=logical_shape, dtype=dtype, name="A")
        B = te.compute(
            shape=A.shape,
            fcompute=lambda i, j, k: A[i, j, k],
            name="B",
        )

        s = te.create_schedule(B.op)

        # Apply the requested layout transform (if any) to A, then to B.
        for tensor, transform in ((A, transform_A), (B, transform_B)):
            layout_func = self.get_transform(transform)
            if layout_func:
                s[tensor].transform_layout(layout_func)

        # If the two buffers are accessed with the same indices, CSE
        # would replace them with a Let binding. That hides the transformed
        # indices this test inspects, so the CSE pass is disabled while
        # lowering.
        pass_ctx = tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"])
        with pass_ctx:
            mod = tvm.lower(s, [A, B])

        main_body = mod["main"].body
        logical_index_vars = self.extract_logical_indices(main_body)
        expected_by_name = {
            "A": self.transform_indices(transform_A, logical_shape, logical_index_vars),
            "B": self.transform_indices(transform_B, logical_shape, logical_index_vars),
        }

        def callback(node):
            if type(node) not in (tvm.tir.BufferLoad, tvm.tir.BufferStore):
                return
            name = node.buffer.name
            if name not in expected_by_name:
                raise RuntimeError(f"Unexpected buffer: {name}")
            tvm.ir.assert_structural_equal(expected_by_name[name], node.indices)

        post_order_visit(main_body, callback)