def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)

    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer,
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    # Calibrate the observers with a few forward passes.
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(
        traced_rn18,
        InputTensorSpec.from_tensors([data]),
        logger_level=trt.Logger.VERBOSE,
    )
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True
    )
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
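# A minimal usage sketch for build_int8_trt_implicit_quant. Assumptions: a
# CUDA device with TensorRT available, the fx2trt helpers imported above in
# scope, and torchvision installed; the ResNet-18 below stands in for
# whatever eager-mode model the caller provides.
if __name__ == "__main__":
    import torchvision

    rn18 = torchvision.models.resnet18(pretrained=True).eval()
    trt_mod = build_int8_trt_implicit_quant(rn18)
    # The returned TRTModule expects CUDA inputs.
    out = trt_mod(torch.randn(1, 3, 224, 224).cuda())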
def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Tensors,
    operator_support: OperatorSupport,
    settings: _SplitterSettingBase,
):
    """
    Preprocesses graph before splitting:
    - finds nodes supported by ACC,
    - finds fusion groups for ACC nodes having non-tensor IO,
    - builds a graph of direct dependencies,
    - builds a map of fused nodes to their fusions.
    As a result we get self.acc_nodes, self.deps and self.fusions.
    """
    self.module = module
    shape_prop.ShapeProp(self.module).propagate(*sample_input)

    self.settings = settings
    self.operator_support = operator_support
    self.sample_input = sample_input
    self.acc_nodes = FxNetAccNodesFinder(
        self.module, self.operator_support, self.settings.allow_non_tensor
    )()

    if self.settings.skip_fusion:
        self.fusions = {}
    else:
        self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

    # Modify deps to add more deps for fused nodes.
    self.deps = self.find_deps()
    self.update_deps_for_fusions()
def run_test_custom_compare_results(
    self,
    mod,
    inputs,
    expected_ops,
    comparators: List[Tuple[Callable, List]],
    interpreter=None,
):
    # `interpreter` is ignored; we do not need it for vanilla tests.
    # Note: this differs from the internal version. We need to fix the test
    # cases after we refactor the internal callsites to use this file.
    mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(mod).propagate(*inputs)
    mod = NormalizeArgs(mod).transform()
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    super().run_test_custom_compare_results(
        mod, inputs, expected_ops, comparators, interp
    )
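# A hedged usage sketch for run_test_custom_compare_results. The comparator
# call convention assumed below, comparator(result, reference, *args) -> bool,
# is inferred from the List[Tuple[Callable, List]] annotation and is not
# confirmed by this file; the module, expected ops, and threshold are
# hypothetical.
def test_relu_with_custom_comparator(self):
    class TestModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    def cosine_comparator(result, reference, threshold):
        sim = torch.nn.functional.cosine_similarity(
            result.flatten().float(), reference.flatten().float(), dim=0
        )
        return sim > threshold

    self.run_test_custom_compare_results(
        TestModule(),
        [torch.randn(1, 3, 8, 8)],
        expected_ops={torch.relu},
        comparators=[(cosine_comparator, [0.99])],
    )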
def test_acc_fusions_finder_2(self):
    """
    Let b and d be cpu nodes. After fusion, all nodes should be cpu nodes,
    because d is included in the fusion group, which forces all other nodes
    in the same fusion group to be on CPU too.
    """
    module_nn = self.TestModule()
    module_fx = torch.fx.symbolic_trace(module_nn)
    shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
    acc_node = {
        node
        for node in module_fx.graph.nodes
        if node.target == operator.add
    }
    fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
        module_fx,
        acc_node,
    )
    fusion_map = fusions_finder()
    self.assertEqual(len(fusion_map), 0)
def test_acc_fusions_finder_1(self):
    """
    Assume every node is an acc node. We should have one fusion group
    (a, a0, a1, a2, c, d, e).
    """
    module_nn = self.TestModule()
    module_fx = torch.fx.symbolic_trace(module_nn)
    shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
    acc_node = {
        node
        for node in module_fx.graph.nodes
        if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS
    }
    fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
        module_fx,
        acc_node,
    )
    fusion_map = fusions_finder()
    self.assertEqual(len(fusion_map), 7)
    for _, v in fusion_map.items():
        self.assertEqual(len(v), 7)
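# The TestModule used by the two tests above is defined elsewhere in the test
# class. As a self-contained reference, the sketch below drives
# FxNetAccFusionsFinder on a hypothetical toy module; it is not claimed to
# reproduce the exact fusion counts asserted above.
def _fusions_finder_demo():
    class ToyModule(torch.nn.Module):
        def forward(self, x):
            y = x + x
            return y.relu()

    module_fx = torch.fx.symbolic_trace(ToyModule())
    shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
    acc_nodes = {
        node
        for node in module_fx.graph.nodes
        if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS
    }
    fusion_map = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
        module_fx, acc_nodes
    )()
    # All IO here is tensors, so we expect no fusion groups (an empty map).
    return fusion_map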
def trace(
    mod: nn.Module,
    sample_inputs: List[torch.Tensor],
    remove_assertions: bool = True,
    remove_exceptions: bool = True,
    use_acc_normalization: bool = True,
    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
) -> torch.fx.GraphModule:
    """
    Performs tracing and arg normalization specialized for accelerator lowering.

    It first rewrites the AST of the module's methods (and all attr methods
    recursively) to transform untraceable parts of the module so that they
    become traceable.

    It then traces to the functional level so that optimizations and backend
    accelerator importers have the ability to see and/or change inputs to each
    op.

    It then removes assertions and exception wrappers found during symbolic
    tracing, if requested via remove_assertions and remove_exceptions.

    Dead code is then eliminated, which will e.g. remove any nodes that were
    only used by assertions or exceptions if those were removed.

    It then performs normalization on args/kwargs, moving any arg that can be
    expressed as a kwarg into kwarg form, and making default values explicit.

    Args:
        mod (Module): The module to transform and trace.
        sample_inputs (List[torch.Tensor]): Sample inputs with which to run
            shape prop.
        remove_assertions (bool): Whether to remove assertion nodes from the
            graph after symbolic tracing.
        remove_exceptions (bool): Whether to remove exception wrapper nodes
            from the graph after symbolic tracing.
        use_acc_normalization (bool): Whether to apply acc-specific
            normalization to all acc_ops.
        ast_rewriter_allow_list (Optional[Set[Type[nn.Module]]]): Optional
            allow list of modules that need AST rewriting.
        leaf_module_list (Optional[Set[Type[nn.Module]]]): Optional set of
            leaf module types that will not be traced into.
    """
    if mod.training:
        warnings.warn(
            "acc_tracer does not currently support models in training mode."
            " Calling eval on the model before tracing.")
        mod.eval()

    # Rewrite the module to make it symbolically traceable, then trace it.
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
        mod,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )

    assert isinstance(rewritten_mod, nn.Module)

    # Note: use the rewritten_mod here as the root. This is necessary because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)

    # Now remove all assertions and exceptions if requested.
    if remove_assertions:
        _remove_assertions(traced)
    if remove_exceptions:
        _remove_exceptions(traced)

    # Clean up any dead code from the original module, as well as dead nodes
    # left behind after removing assertions and exceptions.
    traced.graph.eliminate_dead_code()

    # Now normalize args/kwargs to make default values visible. Leave
    # args/kwargs as they were, since all-kwarg normalization is broken, and
    # we don't need it anyway.
    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
    traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()

    # Normalize to acc-specialized wrappers for consistency across op naming
    # and to ensure all-kwarg usage.
    if use_acc_normalization:
        acc_normalizer.normalize(traced)

    traced.recompile()
    return traced
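# A minimal usage sketch for trace(), assuming this module's imports are in
# scope; MyModel is a hypothetical user model.
def _trace_demo() -> torch.fx.GraphModule:
    class MyModel(nn.Module):
        def forward(self, x):
            return torch.relu(x) + x

    sample_inputs = [torch.randn(2, 4)]
    # Returns a functional-level GraphModule with assertions removed and
    # args normalized, ready for backend lowering.
    return trace(MyModel().eval(), sample_inputs)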
def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
    mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(mod).propagate(*inputs)
    mod = NormalizeArgs(mod).transform()
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    super().run_test(mod, inputs, expected_ops, interp, rtol, atol)
def split_preview(self, dump_graph: bool = False):
    reports = ""
    subgraphs = self.put_nodes_into_subgraphs()
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"Before removing small acc subgraphs, a total of {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    subgraphs = self.remove_small_acc_subgraphs(subgraphs)
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"After removing small acc subgraphs, a total of {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    for i, subgraph in enumerate(subgraphs):
        reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"_run_on_cpu_{i}: "
        reports += f"{len(subgraph.nodes)} node(s)\n"

    self.tag(subgraphs)
    split_mod = self.split(remove_tag=True)
    split_mod.eval()

    if dump_graph:
        drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
        dot_graphs = drawer.get_all_dot_graphs()
        for name, dot_graph in dot_graphs.items():
            dot_graph.write_raw(f"{name}.dot")

    max_qps: float = self.PCIe_BW
    bottleneck_module = ""

    for node in split_mod.graph.nodes:
        if node.op == "call_module" and "acc" in node.target:
            reports += f"\nProcessing acc submodule {node.target}\n"

            submod = getattr(split_mod, node.target)

            def get_submod_inputs(main_mod, submod, example_inputs):
                sub_inputs = None

                def get_inputs(self, inputs):
                    nonlocal sub_inputs
                    sub_inputs = inputs

                # Capture the submodule's runtime inputs via a forward
                # pre-hook during a full-model forward pass.
                handle = submod.register_forward_pre_hook(get_inputs)
                main_mod(*example_inputs)
                handle.remove()
                return sub_inputs

            submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
            shape_prop.ShapeProp(submod).propagate(*submod_inputs)

            total_input_bytes = 0
            total_output_bytes = 0

            reports += "Checking inputs...\n"
            for n in submod.graph.nodes:
                if n.op == "placeholder":
                    if "tensor_meta" not in n.meta:
                        reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_input_bytes += get_size_of_node(submod, n)[0]
                if n.op == "output":
                    output_node = n

            reports += "Checking outputs...\n"

            def get_bytes(node: torch.fx.Node):
                nonlocal total_output_bytes
                nonlocal reports
                if "tensor_meta" not in node.meta:
                    reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                else:
                    total_output_bytes += get_size_of_node(submod, node)[0]

            map_arg(output_node.args, get_bytes)

            qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
            reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
            reports += f" theoretical max qps (bounded by PCIe bandwidth) for this submodule is {qps}.\n"

            if qps < max_qps:
                max_qps = qps
                bottleneck_module = node.target

            try:
                lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
            except RuntimeError:
                reports += "Ran into an error during lowering!\n"
                reports += self._find_culprit(submod, submod_inputs)
                continue

            try:
                lowered_submod(*submod_inputs)
            except RuntimeError:
                reports += "Ran into an error during inference!\n"
                reports += self._find_culprit(submod, submod_inputs)
            else:
                reports += "Lowering and running succeeded!\n"

    reports += f"\nTheoretical max qps (bounded by PCIe bandwidth) for this model is {max_qps},"
    reports += f" bottleneck is submodule {bottleneck_module}."
    print(reports)

    # Return the reports for testing purposes.
    return reports
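# The get_submod_inputs helper above captures a submodule's runtime inputs
# with a forward pre-hook. Below is a minimal, self-contained sketch of the
# same technique on a toy model; the names are illustrative only.
def _capture_inputs_demo():
    import torch
    import torch.nn as nn

    def capture_inputs(main_mod, submod, example_inputs):
        captured = None

        def hook(module, inputs):
            nonlocal captured
            captured = inputs

        # register_forward_pre_hook fires just before submod.forward,
        # handing us the positional inputs it is about to receive.
        handle = submod.register_forward_pre_hook(hook)
        main_mod(*example_inputs)
        handle.remove()
        return captured

    net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    # Capture the tensor that flows into the ReLU during a forward pass.
    relu_inputs = capture_inputs(net, net[1], [torch.randn(2, 4)])
    return relu_inputs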