Example #1
0
    def __init__(
        self,
        relay_mod: tvm.IRModule,
        relay_param: Dict[str, tvm.runtime.NDArray] = None,
        plotter: Plotter = None,
        parser: VizParser = None,
    ):
        """Create one plotter graph per global function of ``relay_mod``.

        Parameters
        ----------
        relay_mod : tvm.IRModule
            Module whose global functions are visualized, one graph per function.
        relay_param : Dict[str, tvm.runtime.NDArray], optional
            Stored on the instance; an empty dict is used when omitted.
        plotter : Plotter, optional
            Back-end used to create graphs; ``TermPlotter()`` when omitted.
        parser : VizParser, optional
            Stored on the instance; ``TermVizParser()`` when omitted.
        """
        self._plotter = plotter if plotter is not None else TermPlotter()
        self._relay_param = relay_param if relay_param is not None else {}
        self._parser = parser if parser is not None else TermVizParser()

        # Put "main" first so it is shown on top; the other functions follow in the
        # order returned by get_global_vars().
        graph_names = []
        for gv_node in relay_mod.get_global_vars():
            if gv_node.name_hint == "main":
                graph_names.insert(0, gv_node.name_hint)
            else:
                graph_names.append(gv_node.name_hint)

        node_to_id = {}
        # node_to_id is cleared for every graph, so IDs are offset by the number of
        # nodes already numbered in earlier graphs. This keeps each node ID unique
        # across the whole module instead of restarting at "0" in every graph.
        node_count_offset = 0

        # post_order_visit callback: give every not-yet-seen node the next string ID.
        def traverse_expr(node):
            if node in node_to_id:
                return
            node_to_id[node] = str(len(node_to_id) + node_count_offset)

        for name in graph_names:
            node_count_offset += len(node_to_id)
            node_to_id.clear()
            relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
            graph = self._plotter.create_graph(name)
            self._add_nodes(graph, node_to_id)
Example #2
0
def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
    """
    Un-partition the TensorRT partitions of ``mod`` that can't actually be supported by
    TensorRT now that we see the whole partition.

    A global function counts as a TensorRT partition when its "Compiler" attribute is
    "tensorrt". It is un-partitioned when ``is_valid_subgraph`` rejects its params/body.
    Partitions with no multiply-accumulates are presumably rejected by that check too
    when remove_no_mac_subgraphs is True (TODO: confirm against ``is_valid_subgraph``).

    Parameters
    ----------
    mod : tvm.IRModule
        A module that has already been partitioned for TensorRT.

    Returns
    -------
    tvm.IRModule
        The result of applying ``relay.transform.InlineCompilerFunctionsBoundTo`` with
        the rejected partitions' global vars to ``mod``.
    """
    global_vars_to_inline = []
    for gv in mod.get_global_vars():
        func = mod[gv]
        attrs = func.attrs
        # Only TensorRT partitions are candidates. Test key membership before the
        # lookup, so a function whose attrs lack "Compiler" is skipped instead of
        # raising.
        if not attrs or "Compiler" not in attrs or attrs["Compiler"] != "tensorrt":
            continue
        if not is_valid_subgraph(func.params, func.body):
            global_vars_to_inline.append(gv)
    return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod)
Example #3
0
def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
    """
    Removes invalid TensorRT subgraphs by inlining them back into the main function.

    A global function is a candidate when its "Compiler" attribute is "tensorrt". It is
    removed when ``is_valid_subgraph`` rejects its params/body. Subgraphs with no
    multiply-accumulates are presumably rejected by that check too when
    remove_no_mac_subgraphs is set (TODO: confirm against ``is_valid_subgraph``).

    Parameters
    ----------
    mod : tvm.IRModule
        A module that has already been partitioned for TensorRT.

    Returns
    -------
    tvm.IRModule
        A new module built from ``mod``. Its "main" has the removed subgraphs inlined,
        and ``transform.RemoveUnusedFunctions`` has been applied to it.
    """

    class SubgraphRemover(ExprMutator):
        """
        Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
        """

        def __init__(
            self, subgraphs_to_remove: List[str], mod: tvm.IRModule, new_mod: tvm.IRModule
        ) -> None:
            ExprMutator.__init__(self)
            # Stored as a set because visit_call tests membership for every global call.
            self.subgraphs_to_remove = set(subgraphs_to_remove)
            # Source module; function bodies to inline are read from here.
            self.mod = mod
            # Stored for reference; not read by this class.
            self.new_mod = new_mod

        def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr:
            if isinstance(call.op, GlobalVar):
                name = call.op.name_hint
                if name in self.subgraphs_to_remove:
                    # "Inline" the subgraph back into the new main function by binding
                    # the (already mutated) call arguments to the subgraph's params.
                    # Explicit loop on purpose: zero-argument super() does not work
                    # inside a comprehension before Python 3.12.
                    func = self.mod[name]
                    var_map = {}
                    for arg, param in zip(call.args, func.params):
                        var_map[param] = super().visit(arg)
                    new_body = relay.bind(func.body, var_map)
                    return new_body
                if name != "main":
                    # Call to a global function that is kept: rebuild it with
                    # mutated arguments.
                    args = []
                    for arg in call.args:
                        args.append(super().visit(arg))
                    return call.op(*args)
            return super().visit_call(call)

    subgraphs_to_remove: List[str] = []
    # Collect invalid TensorRT subgraphs.
    for subgraph in mod.get_global_vars():
        name = subgraph.name_hint
        attrs = mod[name].attrs
        # Only TensorRT partitions are candidates. Test key membership before the
        # lookup, so a function whose attrs lack "Compiler" is skipped instead of
        # raising.
        if not attrs or "Compiler" not in attrs or attrs["Compiler"] != "tensorrt":
            continue
        if not is_valid_subgraph(mod[name].params, mod[name].body):
            subgraphs_to_remove.append(name)
    # Create new pruned module
    new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
    new_mod = transform.RemoveUnusedFunctions()(new_mod)
    return new_mod