def test_normalize_args(self):
    m = resnet18()

    class FunctionalTracer(torch.fx.Tracer):
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            # `leaves` contains the set of standard `nn.Module`s that are not
            # currently symbolically traceable. Ideally this set would be empty
            leaves = set([torch.nn.BatchNorm2d])
            return type(m) in leaves

    traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

    input = torch.randn(5, 3, 224, 224)
    ref_outs = traced(input)

    traced = NormalizeArgs(traced).transform()
    test_outs = traced(input)
    self.assertEqual(test_outs, ref_outs)

    modules = dict(traced.named_modules())
    for node in traced.graph.nodes:
        if node.op == 'call_function' and node.target.__module__ == 'torch.nn.functional':
            self.assertEqual(len(node.args), 0)
        if node.op == 'call_module':
            submod_class = modules[node.target].__class__
            nn_class = getattr(torch.nn, submod_class.__name__)
            if submod_class == nn_class:
                self.assertEqual(len(node.args), 0)
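# --- Illustrative sketch (not from the original sources) ---
# A minimal example of what `NormalizeArgs` does in isolation, assuming it is
# importable from torch.fx.experimental.normalize as used in the snippets in
# this listing. The toy module and shapes below are hypothetical.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes import shape_prop
from torch.fx.experimental.normalize import NormalizeArgs

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        # Positional inputs here should become explicit kwargs after the pass.
        return torch.nn.functional.relu(self.conv(x))

traced = symbolic_trace(TinyNet())
sample = torch.randn(1, 3, 16, 16)
# Shape propagation attaches type metadata, which helps normalize functional calls.
shape_prop.ShapeProp(traced).propagate(sample)
normalized = NormalizeArgs(traced).transform()

# After normalization, call_module/call_function nodes carry their inputs as kwargs.
for node in normalized.graph.nodes:
    if node.op in ('call_module', 'call_function'):
        print(node.op, node.target, node.args, node.kwargs)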
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)

    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(
        traced_rn18,
        InputTensorSpec.from_tensors([data]),
        logger_level=trt.Logger.VERBOSE
    )
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True
    )
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
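# --- Illustrative sketch (not from the original sources) ---
# Hedged usage of the helper above: requires a CUDA device with TensorRT available,
# and assumes torchvision's resnet18 purely as a convenient example model.
import torch
from torchvision.models import resnet18

if torch.cuda.is_available():
    rn18 = resnet18().eval()
    trt_mod = build_int8_trt_implicit_quant(rn18)
    trt_out = trt_mod(torch.randn(1, 3, 224, 224).cuda())
    print("trt int8 output shape:", trt_out.shape)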
def __init__(
    self,
    module: torch.fx.GraphModule,
    input_shapes: List[InputTensorSpec],
    logger_level=trt.Logger.WARNING,
):
    # Preprocess the model
    module = copy.deepcopy(module)
    module = module.cpu().float()
    module = NormalizeArgs(module).transform()
    super().__init__(module)

    self.logger = trt.Logger(logger_level)
    self.builder = trt.Builder(self.logger)

    # TODO: explicit batching
    # EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    # self.network = self.builder.create_network(EXPLICIT_BATCH)
    self.network = self.builder.create_network()

    self.input_shape_itr = iter(input_shapes)

    self._cur_node_name: Optional[str] = None
    self._input_names: List[str] = []
    self._output_names: List[str] = []
def __init__(
    self,
    module: torch.nn.Module,
    input_specs: List[InputTensorSpec],
    logger_level=trt.Logger.WARNING,
):
    # Preprocess the model
    if not isinstance(module, torch.fx.GraphModule):
        module = torch.fx.symbolic_trace(module)
    else:
        module = copy.deepcopy(module)
    module = module.cpu().float()
    module = NormalizeArgs(module).transform()
    super().__init__(module, input_specs, logger_level=logger_level)
def run_test_custom_compare_results(
    self, mod, inputs, expected_ops, comparators: List[Tuple[Callable, List]], interpreter=None
):
    # `interpreter` is ignored; it is not needed for the vanilla tests.
    # Note: this differs from the internal version; the test case needs to be
    # fixed once the internal callsites are refactored to use this file.
    mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(mod).propagate(*inputs)
    mod = NormalizeArgs(mod).transform()
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    super().run_test_custom_compare_results(mod, inputs, expected_ops, comparators, interp)
def __init__(
    self,
    module: torch.fx.GraphModule,
    input_shapes: List[InputTensorSpec],
    logger_level=trt.Logger.WARNING,
):
    # Preprocess the model
    module = copy.copy(module)
    module = NormalizeArgs(module).transform()
    super().__init__(module)

    self.logger = trt.Logger(logger_level)
    self.builder = trt.Builder(self.logger)
    self.network = self.builder.create_network()

    self.input_shape_itr = iter(input_shapes)

    self._cur_node_name: Optional[str] = None
    self._input_names: List[str] = []
    self._output_names: List[str] = []
def test_normalize_modules_exhaustive(self):
    """
    Exhaustively test `NormalizeArgs` on all standard
    torch.nn Module classes
    """
    for test_params in module_tests + new_module_tests:
        if 'constructor' not in test_params:
            constructor = getattr(torch.nn, test_params['module_name'])
        else:
            constructor = test_params['constructor']
        if 'constructor_args' not in test_params:
            args = ()
        else:
            args = test_params['constructor_args']

        mod = constructor(*args)
        # Skip modules that are not standard `torch.nn`
        # instances, including functionals. (functionals
        # are tested in test_normalize_args)
        if mod.__class__.__name__ not in dir(torch.nn):
            continue

        if 'input_fn' not in test_params:
            inputs = torch.randn(test_params['input_size'])
        else:
            inputs = test_params['input_fn']()

        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs,)

        params = ', '.join(f'v{i}' for i in range(len(inputs)))

        # Generate a class to wrap this standard `nn.Module` instance
        test_classname = f'Test{mod.__class__.__name__}'
        test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
"""
        gbls = {'torch': torch}
        exec(test_mod_code, gbls)
        test_instance = gbls[test_classname](mod)
        traced = symbolic_trace(test_instance)

        # Now actually test arg normalization!
        traced = NormalizeArgs(traced).transform()

        # These Modules have an RNG in their forward, so testing
        # correctness by comparing outputs is not correct. Skip that
        # check for these
        stochastic_modules = {'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'}

        if mod.__class__.__name__ not in stochastic_modules:
            self.assertEqual(traced(*inputs), mod(*inputs))

        # Ensure all args/kwargs are normalized into kwargs
        modules = dict(traced.named_modules())
        for node in traced.graph.nodes:
            if node.op == 'call_module':
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
def trace(
    mod: nn.Module,
    sample_inputs: List[torch.Tensor],
    remove_assertions: bool = True,
    remove_exceptions: bool = True,
    use_acc_normalization: bool = True,
    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
) -> torch.fx.GraphModule:
    """
    Performs tracing and arg normalization specialized for accelerator lowering.

    It first rewrites the AST of the module's methods (and all attr methods
    recursively) to transform un-tracable parts of the module to make them
    traceable.

    It then traces to the functional level so that optimizations and backend
    accelerator importers have the ability to see and/or change inputs to each op.

    It then removes assertions and exception wrappers found during symbolic
    tracing if requested, based on remove_assertions and remove_exceptions.

    Dead code is then eliminated, which will e.g. remove any nodes that were
    only used by assertions or exceptions if they were removed.

    It then performs normalization on args/kwargs, moving any arg that can be
    moved to kwarg, and then making default values explicit.

    Args:
        mod (Module): The module to transform and trace.

        sample_inputs (List[torch.Tensor]): Sample inputs with which to run
            shape prop.

        remove_assertions (bool): Whether to remove assertion nodes from
            the graph after symbolic tracing.

        remove_exceptions (bool): Whether to remove exception wrapper nodes
            from the graph after symbolic tracing.

        use_acc_normalization (bool): Whether to apply acc-specific
            normalization to all acc_ops.

        ast_rewriter_allow_list (Optional[Set[Type[nn.Module]]]): Optional
            allow list of modules that need AST rewriting.

        leaf_module_list (Optional[Set[Type[nn.Module]]]): Optional leaf module
            list; these modules will not be traced into.
    """
    if mod.training:
        warnings.warn(
            "acc_tracer does not currently support models in training mode."
            " Calling eval on model before tracing."
        )
        mod.eval()

    # Rewrite the module to make it symbolically traceable, and then trace it.
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
        mod,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )

    assert isinstance(rewritten_mod, nn.Module)
    # Note: use the rewritten_mod here as the root. This is necessary because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    traced = torch.fx.GraphModule(rewritten_mod, rewritten_graph)

    # Now remove all assertions and exceptions if requested.
    if remove_assertions:
        _remove_assertions(traced)
    if remove_exceptions:
        _remove_exceptions(traced)

    # Clean up any dead code from the original module as well as resulting dead
    # nodes after removing assertions and exceptions.
    traced.graph.eliminate_dead_code()

    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
    traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()

    # Normalize to acc-specialized wrappers for consistency across op naming and
    # to ensure all-kwarg usage.
    if use_acc_normalization:
        acc_normalizer.normalize(traced)

    traced.recompile()
    return traced
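# --- Illustrative sketch (not from the original sources) ---
# Hedged usage of trace() above: assumes this function is importable alongside its
# helpers (AccRewritingTracer, acc_normalizer, etc.) and that the model runs on CPU.
# The model and input shape are illustrative only.
import torch
from torchvision.models import resnet18

model = resnet18().eval()
sample = torch.randn(1, 3, 224, 224)
traced = trace(model, [sample])
# The traced module should be numerically equivalent to the eager model.
assert torch.allclose(model(sample), traced(sample), atol=1e-5)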
def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
    mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(mod).propagate(*inputs)
    mod = NormalizeArgs(mod).transform()
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    super().run_test(mod, inputs, expected_ops, interp, rtol, atol)