Example #1
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True),
        weight=torch.ao.quantization.default_per_channel_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18,
                            InputTensorSpec.from_tensors([data]),
                            logger_level=trt.Logger.VERBOSE)
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
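
A minimal usage sketch for the function above. Hedged: torchvision and a CUDA-capable GPU are assumptions here, as are the fx2trt-style imports (TRTInterpreter, InputTensorSpec, TRTModule) that the snippet itself relies on.

# Hypothetical driver for build_int8_trt_implicit_quant; torchvision is an
# assumption and is not part of the original snippet.
import torch
import torchvision.models as models

rn18 = models.resnet18(pretrained=True).eval()
trt_mod = build_int8_trt_implicit_quant(rn18)  # calibrates on random data, builds an INT8 engine
out = trt_mod(torch.randn(1, 3, 224, 224).cuda())
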
Example #2
    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        operator_support: OperatorSupport,
        settings: _SplitterSettingBase,
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result, we get self.acc_nodes, self.deps, and self.fusions.
        """
        self.module = module
        shape_prop.ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module,
                                             self.operator_support,
                                             self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()
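
The operator_support argument above decides which nodes end up in self.acc_nodes. Below is a minimal sketch of a custom support policy; the class name AddOnlySupport is illustrative and not part of the source.

import operator

import torch
from torch.fx.passes.operator_support import OperatorSupport

class AddOnlySupport(OperatorSupport):
    """Illustrative policy: only operator.add calls are accelerator-supported."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # `submodules` maps qualified names to modules; unused by this policy.
        return node.op == "call_function" and node.target == operator.add
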
Example #3
 def run_test_custom_compare_results(self,
                                     mod,
                                     inputs,
                                     expected_ops,
                                     comparators: List[Tuple[Callable,
                                                             List]],
                                     interpreter=None):
     # interpreter is ignored; we do not need it for the Vanilla tests.
     # Note: this differs from the internal version. We need to fix the test case
     # once the internal call sites are refactored to use this file.
     mod = torch.fx.symbolic_trace(mod)
     shape_prop.ShapeProp(mod).propagate(*inputs)
     mod = NormalizeArgs(mod).transform()
     interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
     super().run_test_custom_compare_results(mod, inputs, expected_ops,
                                             comparators, interp)
Example #4
    def test_acc_fusions_finder_2(self):
        """
        Let b and d be CPU nodes. After fusion, all nodes should be CPU nodes,
        because d is included in the fusion group, which forces all other nodes
        in the same fusion group to be on CPU too.
        """
        module_nn = self.TestModule()
        module_fx = torch.fx.symbolic_trace(module_nn)
        shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))

        acc_node = {
            node for node in module_fx.graph.nodes if node.target == operator.add
        }
        fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
            module_fx,
            acc_node,
        )
        fusion_map = fusions_finder()
        self.assertEqual(len(fusion_map), 0)
Example #5
    def test_acc_fusions_finder_1(self):
        """
        Assume every node is an acc node. We should have one fusion group
        (a, a0, a1, a2, c, d, e).
        """
        module_nn = self.TestModule()
        module_fx = torch.fx.symbolic_trace(module_nn)
        shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))

        acc_node = {
            node
            for node in module_fx.graph.nodes
            if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS
        }

        fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
            module_fx,
            acc_node,
        )
        fusion_map = fusions_finder()

        self.assertEqual(len(fusion_map), 7)
        for _, v in fusion_map.items():
            self.assertEqual(len(v), 7)
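
Both tests above rely on ShapeProp having populated per-node metadata before the fusion finder runs. Here is a self-contained sketch of what propagate() leaves behind; the module and input shapes are illustrative, not from the source.

import torch
from torch.fx.passes import shape_prop

gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
shape_prop.ShapeProp(gm).propagate(torch.randn(2, 4))
for node in gm.graph.nodes:
    # Each tensor-producing node now carries a "tensor_meta" entry (shape, dtype, ...),
    # which passes like FxNetAccFusionsFinder and the splitter consult.
    print(node.name, node.meta.get("tensor_meta"))
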
Example #6
def trace(
    mod: nn.Module,
    sample_inputs: List[torch.Tensor],
    remove_assertions: bool = True,
    remove_exceptions: bool = True,
    use_acc_normalization: bool = True,
    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
) -> torch.fx.GraphModule:
    """
    Performs tracing and arg normalization specialized for accelerator lowering.

    It first rewrites the AST of the module's methods (and all attr methods
    recursively) to transform untraceable parts of the module to make them
    traceable.

    It then traces to the functional level so that optimizations and backend
    accelerator importers have the ability to see and/or change inputs to each
    op.

    It then removes assertions and exception wrappers found during symbolic
    tracing, if requested via remove_assertions and remove_exceptions.

    Dead code is then eliminated, which will e.g. remove any nodes that were
    only used by assertions or exceptions that were just removed.

    It then performs normalization on args/kwargs, moving any arg that can be
    a kwarg into kwargs, and then making default values explicit.

    Args:

        mod (Module): The module to transform and trace.

        sample_inputs (List[torch.Tensor]): Sample inputs with which to run
                shape propagation.

        remove_assertions (bool): Whether to remove assertion nodes from
                                    the graph after symbolic tracing.

        remove_exceptions (bool): Whether to remove exception wrapper nodes
                                    from the graph after symbolic tracing.

        use_acc_normalization (bool): Whether to apply acc-specific
                                        normalization to all acc_ops.

        ast_rewriter_allow_list (Optional[Set[Type[nn.Module]]]): Optional allow
                                            list of modules that need AST rewriting.

        leaf_module_list (Optional[Set[Type[nn.Module]]]): Optional set of leaf
                                            modules that will not be traced into.

    """
    if mod.training:
        warnings.warn(
            "acc_tracer does not support currently support models for training."
            " Calling eval on model before tracing.")
        mod.eval()

    # Rewrite the module to make it symbolic traceable, and then trace it.
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
        mod,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )

    assert isinstance(rewritten_mod, nn.Module)
    # Note: use the rewritten_mod here as the root. This is necessary because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)

    # Now remove all assertions and exceptions if requested.
    if remove_assertions:
        _remove_assertions(traced)
    if remove_exceptions:
        _remove_exceptions(traced)

    # Cleanup any dead code from the original module as well as resulting dead
    # nodes after removing assertions and exceptions.
    traced.graph.eliminate_dead_code()

    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
    traced = NormalizeArgs(traced,
                           normalize_to_only_use_kwargs=False).transform()

    # Normalize to acc-specialized wrappers for consistency across op naming and
    # ensuring all kwarg usage.
    if use_acc_normalization:
        acc_normalizer.normalize(traced)

    traced.recompile()

    return traced
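
A hedged end-to-end sketch of calling trace(). The toy module is illustrative, and the import path of trace depends on where acc_tracer lives in your tree.

import torch
import torch.nn as nn

class ToyModel(nn.Module):  # illustrative stand-in, not from the source
    def forward(self, x):
        # This assert is rewritten during tracing and then dropped because
        # remove_assertions defaults to True.
        assert x.dim() == 4, "expected NCHW input"
        return torch.relu(x) + 1

traced = trace(ToyModel().eval(), [torch.randn(1, 3, 8, 8)])
traced.graph.print_tabular()
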
Example #7
 def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
     mod = torch.fx.symbolic_trace(mod)
     shape_prop.ShapeProp(mod).propagate(*inputs)
     mod = NormalizeArgs(mod).transform()
     interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
     super().run_test(mod, inputs, expected_ops, interp, rtol, atol)
Example #8
    def split_preview(self, dump_graph: bool = False):
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"_run_on_cpu_{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(module, inputs):
                        # forward_pre_hook callback signature is (module, inputs);
                        # capture the inputs fed to the submodule.
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(split_mod, submod,
                                                  self.sample_input)
                shape_prop.ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if "tensor_meta" not in n.meta:
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if "tensor_meta" not in node.meta:
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(
                        submod, submod_inputs)
                except RuntimeError:
                    reports += "Run into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Run into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeed!\n"

        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # return the reports for testing purposes
        return reports
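
A back-of-the-envelope check of the qps formula used above, qps = PCIe_BW / max(total_input_bytes, total_output_bytes). The numbers below are illustrative, not from the source.

PCIe_BW = 16e9            # bytes/sec; a rough PCIe-class bandwidth figure (assumption)
bytes_per_query = 4e6     # max of input/output bytes moved per submodule call (assumption)
print(PCIe_BW / bytes_per_query)  # -> 4000.0, the transfer-bound qps ceiling
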