    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module,
                               module_qualified_name: str) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = set([torch.nn.BatchNorm2d])
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        traced = NormalizeArgs(traced).transform()

        test_outs = traced(input)
        self.assertEqual(test_outs, ref_outs)

        modules = dict(traced.named_modules())
        for node in traced.graph.nodes:
            if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
                self.assertEqual(len(node.args), 0)
            if node.op == 'call_module':
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
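For reference, a minimal sketch of what the test above verifies: after `NormalizeArgs(...).transform()`, positional arguments on `call_function` nodes are moved into kwargs. The toy module `M` below is an illustrative assumption, and `NormalizeArgs` is assumed to be importable from `torch.fx.experimental.normalize` as in the PyTorch releases these snippets target.

import torch
import torch.nn.functional as F
from torch.fx.experimental.normalize import NormalizeArgs

class M(torch.nn.Module):          # illustrative toy module, not from the snippet
    def forward(self, x):
        return F.relu(x)

gm = torch.fx.symbolic_trace(M())
gm = NormalizeArgs(gm).transform()
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # Expect args == () and the tensor input captured in kwargs (e.g. as 'input')
        print(node.target, node.args, node.kwargs)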
Example #2
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True),
        weight=torch.ao.quantization.default_per_channel_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18,
                            InputTensorSpec.from_tensors([data]),
                            logger_level=trt.Logger.VERBOSE)
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
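A possible call site for the helper above (illustrative; it assumes torchvision, TensorRT, and a CUDA device are available, since `TRTModule` runs on the GPU):

import torch
from torchvision.models import resnet18

rn18 = resnet18().eval()                       # calibration data is random here, for illustration only
trt_mod = build_int8_trt_implicit_quant(rn18)
out = trt_mod(torch.randn(1, 3, 224, 224).cuda())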
Example #3
    def __init__(self,
                 module: torch.fx.GraphModule,
                 input_shapes: List[InputTensorSpec],
                 logger_level=trt.Logger.WARNING):
        # Preprocess the model
        module = copy.deepcopy(module)
        module = module.cpu().float()
        module = NormalizeArgs(module).transform()
        super().__init__(module)

        self.logger = trt.Logger(logger_level)
        self.builder = trt.Builder(self.logger)

        # TODO: explicit batching
        # EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        # self.network = self.builder.create_network(EXPLICIT_BATCH)

        self.network = self.builder.create_network()

        self.input_shape_itr = iter(input_shapes)

        self._cur_node_name: Optional[str] = None

        self._input_names: List[str] = []
        self._output_names: List[str] = []
Example #4
    def __init__(self, module: torch.nn.Module, input_specs: List[InputTensorSpec], logger_level=trt.Logger.WARNING):
        # Preprocess the model
        if not isinstance(module, torch.fx.GraphModule):
            module = torch.fx.symbolic_trace(module)
        else:
            module = copy.deepcopy(module)
        module = module.cpu().float()
        module = NormalizeArgs(module).transform()
        super().__init__(module, input_specs, logger_level=logger_level)
Example #5
    def run_test_custom_compare_results(self,
                                        mod,
                                        inputs,
                                        expected_ops,
                                        comparators: List[Tuple[Callable, List]],
                                        interpreter=None):
        # `interpreter` is ignored; it is not needed for the vanilla tests.
        # Note: this differs from the internal version; the test case should be
        # fixed once the internal call sites are refactored to use this file.
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test_custom_compare_results(mod, inputs, expected_ops,
                                                comparators, interp)
Example #6
    def __init__(self,
                 module: torch.fx.GraphModule,
                 input_shapes: List[InputTensorSpec],
                 logger_level=trt.Logger.WARNING):
        # Preprocess the model
        module = copy.copy(module)
        module = NormalizeArgs(module).transform()
        super().__init__(module)

        self.logger = trt.Logger(logger_level)
        self.builder = trt.Builder(self.logger)
        self.network = self.builder.create_network()

        self.input_shape_itr = iter(input_shapes)

        self._cur_node_name: Optional[str] = None

        self._input_names: List[str] = []
        self._output_names: List[str] = []
Example #7
    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `NormalizeArgs` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if 'constructor' not in test_params:
                constructor = getattr(torch.nn, test_params['module_name'])
            else:
                constructor = test_params['constructor']

            if 'constructor_args' not in test_params:
                args = ()
            else:
                args = test_params['constructor_args']

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if 'input_fn' not in test_params:
                inputs = torch.randn(test_params['input_size'])
            else:
                inputs = test_params['input_fn']()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs, )

            params = ', '.join(f'v{i}' for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f'Test{mod.__class__.__name__}'
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {'torch': torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Now actually test arg normalization!
            traced = NormalizeArgs(traced).transform()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {
                'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'
            }

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            # Ensure all args/kwargs are normalized into kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == 'call_module':
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)
Example #8
def trace(
    mod: nn.Module,
    sample_inputs: List[torch.Tensor],
    remove_assertions: bool = True,
    remove_exceptions: bool = True,
    use_acc_normalization: bool = True,
    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
) -> torch.fx.GraphModule:
    """
    Performs tracing and arg normalization specialized for accelerator lowering.

    It first rewrites the AST of the module's methods (and of all attr methods,
    recursively) so that untraceable parts of the module become traceable.

    It then traces to the functional level so that optimizations and backend
    accelerator importers have the ability to see and/or change inputs to each
    op.

    It then removes assertion and exception-wrapper nodes found during symbolic
    tracing, if requested via remove_assertions and remove_exceptions.

    Dead code is then eliminated, which will e.g. remove any nodes that were
    only used by assertions or exceptions if they were removed.

    It then normalizes args/kwargs, moving every arg that can be expressed as a
    kwarg into kwargs, and makes default values explicit.

    Args:

        mod (Module): The module to transform and trace.

        sample_inputs (List[torch.Tensor]):
                Sample inputs with which to run shape prop.

        remove_assertions (bool): Whether to remove assertion nodes from
                                    the graph after symbolic tracing.

        remove_exceptions (bool): Whether to remove exception wrapper nodes
                                    from the graph after symbolic tracing.

        use_acc_normalization (bool): Whether to apply acc-specific
                                        normalization to all acc_ops.

        ast_rewriter_allow_list (Optional[Set[Type[nn.Module]]]): Optional allow list
                                            of module classes that need AST rewriting.

        leaf_module_list (Optional[Set[Type[nn.Module]]]): Optional set of leaf module
                                            classes that will not be traced into.

    """
    if mod.training:
        warnings.warn(
            "acc_tracer does not support currently support models for training."
            " Calling eval on model before tracing.")
        mod.eval()

    # Rewrite the module to make it symbolic traceable, and then trace it.
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
        mod,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )

    assert isinstance(rewritten_mod, nn.Module)
    # Note: use the rewritten_mod here as the root. This is necessary because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)

    # Now remove all assertions and exceptions if requested.
    if remove_assertions:
        _remove_assertions(traced)
    if remove_exceptions:
        _remove_exceptions(traced)

    # Cleanup any dead code from the original module as well as resulting dead
    # nodes after removing assertions and exceptions.
    traced.graph.eliminate_dead_code()

    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
    traced = NormalizeArgs(traced,
                           normalize_to_only_use_kwargs=False).transform()

    # Normalize to acc-specialized wrappers for consistency across op naming and
    # ensuring all kwarg usage.
    if use_acc_normalization:
        acc_normalizer.normalize(traced)

    traced.recompile()

    return traced
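A sketch of how the tracer above might be invoked (the toy `MyModel` is an illustrative assumption; `trace` refers to the function defined above):

import torch
import torch.nn as nn

class MyModel(nn.Module):                      # illustrative stand-in model
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))

mod = MyModel().eval()
sample_inputs = [torch.randn(2, 8)]
traced = trace(mod, sample_inputs)             # AST rewrite -> fx trace -> cleanup -> normalization
print(traced.graph)                            # graph with acc_ops targets and kwarg-based call sites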
Example #9
    def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test(mod, inputs, expected_ops, interp, rtol, atol)