def __init__(
    self,
    relay_mod: tvm.IRModule,
    relay_param: Dict[str, tvm.runtime.NDArray] = None,
    plotter: Plotter = None,
    parser: VizParser = None,
):
    """Create one plotter graph per global function of ``relay_mod``.

    Parameters
    ----------
    relay_mod : tvm.IRModule
        The Relay module whose global functions are visualized.
    relay_param : Dict[str, tvm.runtime.NDArray], optional
        Parameter-name to value mapping, stored on ``self._relay_param``.
        An empty dict is used when omitted.
    plotter : Plotter, optional
        Backend used to create the graphs. Defaults to ``TermPlotter()``.
    parser : VizParser, optional
        Parser stored on ``self._parser``. Defaults to ``TermVizParser()``.
    """
    self._plotter = plotter if plotter is not None else TermPlotter()
    self._relay_param = relay_param if relay_param is not None else {}
    self._parser = parser if parser is not None else TermVizParser()

    graph_names = []
    # If we have a main function, put it first so that it is shown on the
    # top. The other functions keep the order of get_global_vars().
    for gv_node in relay_mod.get_global_vars():
        if gv_node.name_hint == "main":
            graph_names.insert(0, gv_node.name_hint)
        else:
            graph_names.append(gv_node.name_hint)

    node_to_id = {}
    # Number of IDs already handed out to previously visited graphs. Adding
    # it as an offset keeps every node ID unique across *all* graphs, not
    # just within one graph. Otherwise IDs restart at "0" per graph and
    # collide if a plotter combines the graphs.
    node_count_offset = 0

    # Callback that assigns a unique string ID to each node, in the order
    # post_order_visit reaches them.
    def traverse_expr(node):
        if node in node_to_id:
            return
        node_to_id[node] = str(len(node_to_id) + node_count_offset)

    for name in graph_names:
        # Advance the offset past the IDs used by the previous graph before
        # resetting the per-graph mapping.
        node_count_offset += len(node_to_id)
        node_to_id.clear()
        relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
        graph = self._plotter.create_graph(name)
        self._add_nodes(graph, node_to_id)
def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
    """Un-partition TensorRT partitions that should not be offloaded.

    A global function is inlined back into its callers when it is annotated
    with ``Compiler == "tensorrt"`` and ``is_valid_subgraph`` rejects it.
    The intended rejection criteria are the following (NOTE(review): both
    are decided inside ``is_valid_subgraph``, which is not visible here —
    confirm):

    - partitions with no multiply-accumulates (if remove_no_mac_subgraphs
      is True);
    - partitions that can't actually be supported by TensorRT now that the
      whole partition is visible.

    Parameters
    ----------
    mod : tvm.IRModule
        The partitioned module.

    Returns
    -------
    tvm.IRModule
        The result of ``InlineCompilerFunctionsBoundTo`` applied to ``mod``
        with the rejected partitions.
    """
    global_vars_to_inline = []
    for gv in mod.get_global_vars():
        func = mod[gv]
        attrs = func.attrs
        # Only TensorRT partitions are candidates. Test for the "Compiler"
        # key instead of indexing it directly, because a function may carry
        # attributes without a "Compiler" entry.
        if not attrs or "Compiler" not in attrs or attrs["Compiler"] != "tensorrt":
            continue
        if not is_valid_subgraph(func.params, func.body):
            global_vars_to_inline.append(gv)
    return relay.transform.InlineCompilerFunctionsBoundTo(global_vars_to_inline)(mod)
def prune_tensorrt_subgraphs(mod: tvm.IRModule) -> tvm.IRModule:
    """
    Removes invalid subgraphs and those with no multiply-accumulates (if remove_no_mac_subgraphs
    is set).

    A global function is reverted to TVM when it is annotated with
    ``Compiler == "tensorrt"`` and ``is_valid_subgraph`` rejects it. Each call
    to such a function is replaced by the function body, with its parameters
    bound to the (visited) call arguments.

    Only the expression tree of ``mod["main"]`` is rewritten, so only calls
    reachable from "main" are inlined. ``RemoveUnusedFunctions`` is applied
    afterwards (presumably dropping the now-unreferenced partitions).

    Parameters
    ----------
    mod : tvm.IRModule
        The partitioned module. It is not mutated; a new module is built.

    Returns
    -------
    tvm.IRModule
        The pruned module.
    """

    class SubgraphRemover(ExprMutator):
        """
        Reverts subgraphs in subgraphs_to_remove back to TVM instead of using an external codegen.
        """

        def __init__(self, subgraphs_to_remove: List[str], mod: tvm.IRModule) -> None:
            ExprMutator.__init__(self)
            # Frozen set: visit_call performs a membership test for every call.
            self.subgraphs_to_remove = frozenset(subgraphs_to_remove)
            # Source module used to look up the bodies of removed subgraphs.
            self.mod = mod

        def visit_call(self, call: relay.expr.Call) -> relay.expr.Expr:
            if isinstance(call.op, GlobalVar):
                name = call.op.name_hint
                if name in self.subgraphs_to_remove:
                    # "Inline" the subgraph back into the new main function.
                    func = self.mod[name]
                    var_map = {
                        param: self.visit(arg) for arg, param in zip(call.args, func.params)
                    }
                    return relay.bind(func.body, var_map)
                if name != "main":
                    # Call to a kept global function: rebuild it with visited
                    # arguments so removed subgraphs nested in them get inlined.
                    return call.op(*[self.visit(arg) for arg in call.args])
            return super().visit_call(call)

    subgraphs_to_remove: List[str] = []
    # Collect invalid TensorRT subgraphs.
    for subgraph in mod.get_global_vars():
        name = subgraph.name_hint
        func = mod[name]
        attrs = func.attrs
        # Skip functions not offloaded to TensorRT. Test for the "Compiler"
        # key instead of indexing it directly, because a function may carry
        # attributes without a "Compiler" entry.
        if not attrs or "Compiler" not in attrs or attrs["Compiler"] != "tensorrt":
            continue
        if not is_valid_subgraph(func.params, func.body):
            subgraphs_to_remove.append(name)
    # Create the new pruned module.
    new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod).visit(mod["main"])
    new_mod = transform.RemoveUnusedFunctions()(new_mod)
    return new_mod