def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(
        traced_rn18,
        InputTensorSpec.from_tensors([data]),
        logger_level=trt.Logger.VERBOSE,
    )
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
    )
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
def __call__(
    self,
    module: nn.Module,
    input: Input,
    cuda_graph_batch_size: int = -1,
) -> nn.Module:
    """See `LowerFunc` protocol"""
    if self.fp16:
        module.eval().half()
        input = tuple(
            x.half() if x.dtype == torch.float32 else x for x in input
        )

    const_split_mod = split_const_subgraphs(module)
    const_split_mod.run_folding()
    module = self.acc_trace(const_split_mod, input)  # type: ignore[misc]
    print(f"acc traced: {module.graph}")
    split_module, splits = self.split(module, input)  # type: ignore[arg-type]
    split_module.eval()  # type: ignore[attr-defined]

    for _split in splits:  # type: ignore[attr-defined]
        if _split.device == "acc":
            # Ensure parent module is updated with the traced sub-net before running
            # remove_duplicate_output_args.
            self.remove_duplicate_output_args(_split.module, [_split.name])  # type: ignore[misc, operator]
            interp_res = self.trt_interpreter(
                _split.module, _split.input, _split.name
            )

            trt_module = TRTModule(
                engine=interp_res.engine,
                input_names=interp_res.input_names,
                output_names=interp_res.output_names,
                cuda_graph_batch_size=cuda_graph_batch_size,
            )

            setattr(split_module, _split.name, trt_module)

    return split_module  # type: ignore[return-value]
def run_test(self, mod, inputs, expected_ops, unexpected_ops, interpreter, rtol, atol):
    with torch.no_grad():
        cuda_inputs = []
        for i in inputs:
            cuda_inputs.append(i.cuda())

        mod.eval()
        if len(expected_ops):
            self.assert_has_op(mod, expected_ops)
        if unexpected_ops:
            self.assert_unexpected_op(mod, unexpected_ops)

        interpreter_result = interpreter.run(fp16_mode=False)
        trt_mod = TRTModule(
            interpreter_result.engine,
            interpreter_result.input_names,
            interpreter_result.output_names,
        )

        ref_outputs = mod(*inputs)
        outputs = trt_mod(*cuda_inputs)

        if isinstance(outputs, torch.Tensor):
            ref_outputs = [ref_outputs]
            outputs = [outputs]

        for out, ref in zip(outputs, ref_outputs):
            torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol)
def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 32)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        # weight=torch.ao.quantization.default_weight_observer
        # uncomment to check per channel quant works
        weight=torch.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
    interp = TRTInterpreter(
        quantized_rn18,
        [InputTensorSpec(
            torch.Size([-1, *data.shape[1:]]),
            torch.float,
            shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
            has_batch_dim=True,
        )],
        explicit_batch_dimension=True,
        explicit_precision=True,
        logger_level=trt.Logger.VERBOSE,
    )
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
    )
    trt_res = trt_mod(data.cuda())
    print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
def build_fp16_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
    interp = TRTInterpreter(
        rn18,
        [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)],
    )
    interpreter_result = interp.run(fp16_mode=True)
    return TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
    )
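
# A rough usage sketch for the builder helpers above (not part of the original
# snippets): it assumes torchvision is installed and a CUDA device with TensorRT
# is available. The torchvision import and the comparison logic are assumptions
# added for illustration only.
import torchvision.models as models

rn18 = models.resnet18().eval()
fp16_trt_mod = build_fp16_trt(rn18)

sample = torch.randn(1, 3, 224, 224).cuda()  # batch dim is supplied at runtime
trt_out = fp16_trt_mod(sample)               # runs through the TensorRT engine
ref_out = rn18.cuda()(sample)                # eager fp32 reference
print("fp16 result diff max", torch.max(ref_out.float() - trt_out.float()))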
def lower_mod_default(
    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
) -> TRTModule:
    # Build a TensorRT engine from the FX module and wrap it in a TRTModule
    # so it can be called like a regular nn.Module.
    interp = TRTInterpreter(
        mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
    )
    interpreter_result = interp.run(max_batch_size=batch_size)
    res_mod = TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
    )
    return res_mod
def _lower_model_to_backend(
    self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]
):
    """
    Lower a GraphModule `mod` to TensorRT with `inputs`.
    """
    # Current code for lowering is place-holder, subject to future change
    # based on feeds model's actual status
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    interpreter_result = interp.run(*inputs)
    return TRTModule(
        interpreter_result.engine,
        interpreter_result.input_names,
        interpreter_result.output_names,
    )
def run_test_custom_compare_results(
    self,
    mod,
    inputs,
    expected_ops,
    interpreter,
    comparators: List[Tuple[Callable, List]],
    fp16_mode=False,
):
    """
    Runs the test and compares the result using the provided comparators.
    The size of comparators must be equal to the number of outputs from 'mod'.

    mod          - a model to run.
    inputs       - a list of the model inputs.
    expected ops - a list of ops that should be verified.
    interpreter  - used for converting the model to TRT.
    comparators  - a list of (func, args) pairs corresponding to each of
                   the module outputs. usage: func(x, y, *args)
    """
    with torch.no_grad():
        cuda_inputs = []
        for i in inputs:
            cuda_inputs.append(i.cuda())

        mod.eval()
        if len(expected_ops):
            self.assert_has_op(mod, expected_ops)

        interpreter_result = interpreter.run(fp16_mode=fp16_mode)
        trt_mod = TRTModule(
            interpreter_result.engine,
            interpreter_result.input_names,
            interpreter_result.output_names,
        )
        res_trt = trt_mod(*cuda_inputs).cpu()
        res_cpu = mod(*inputs)
        assert len(res_trt) == len(res_cpu)
        assert len(res_cpu) == len(comparators)
        for output_trt, output_cpu, comparator in zip(
            res_trt, res_cpu, comparators
        ):
            comp_func = comparator[0]
            args = comparator[1]
            self.assertTrue(comp_func(output_trt, output_cpu, *args))
def lower_to_trt(model, inputs, shape_ranges):
    """Lower a quantized model to TensorRT"""
    assert len(inputs) == 1, "lower_to_trt only works for one input currently"
    model = acc_tracer.trace(model, inputs)  # type: ignore[attr-defined]
    # TODO: test multiple inputs setting and enable multiple inputs
    input_specs = [
        InputTensorSpec(
            torch.Size([-1, *inputs[0].shape[1:]]),
            torch.float,
            shape_ranges=shape_ranges,
            has_batch_dim=True,
        )
    ]

    interp = TRTInterpreter(
        model,
        input_specs,
        explicit_batch_dimension=True,
        explicit_precision=True,
    )
    result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
    return trt_mod
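
# Hypothetical calling sketch for lower_to_trt (the `quantized_model` name below is a
# placeholder, not from the snippet above). shape_ranges follows the same
# (min, opt, max) layout for the dynamic batch dimension used in build_int8_trt.
data = torch.randn(1, 3, 224, 224)
shape_ranges = [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))]
trt_mod = lower_to_trt(quantized_model, [data], shape_ranges)  # quantized_model: a reference-quantized GraphModule
trt_res = trt_mod(data.cuda())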
def test_save_and_load_trt_module(self):
    class TestModule(torch.nn.Module):
        def forward(self, x):
            return x + x

    inputs = [torch.randn(1, 1)]
    mod = TestModule().eval()
    ref_output = mod(*inputs)

    mod = acc_tracer.trace(mod, inputs)
    interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs))
    trt_mod = TRTModule(*interp.run(fp16_mode=False))
    torch.save(trt_mod, "trt.pt")
    reload_trt_mod = torch.load("trt.pt")

    torch.testing.assert_allclose(
        reload_trt_mod(inputs[0].cuda()).cpu(),
        ref_output,
        rtol=1e-04,
        atol=1e-04,
    )
    %linear_weight : [#users=1] = get_attr[target=linear.weight]
    %linear_bias : [#users=1] = get_attr[target=linear.bias]
    %linear_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linear](args = (), ...
    %relu_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.relu](args = (), ...
    return relu_1
graph():
    %relu_1 : [#users=1] = placeholder[target=relu_1]
    %linalg_norm_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linalg_norm](args = (), ...
    return linalg_norm_1
"""

# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
# we can skip the splitter part.
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
engine, input_names, output_names = interp.run()
trt_mod = TRTModule(engine, input_names, output_names)
split_mod._run_on_acc_0 = trt_mod

cuda_inputs = [input.cuda() for input in inputs]
split_mod.cuda()
lowered_model_output = split_mod(*cuda_inputs)

# Make sure the results match.
model.cuda()
regular_model_output = model(*cuda_inputs)
torch.testing.assert_close(
    lowered_model_output,
    regular_model_output.to(torch.float16),
    atol=3e-3,
    rtol=1e-2,
)

# We can utilize the trt profiler to print out the time spent on each layer.
trt_mod.enable_profiling()
trt_mod(*cuda_inputs)
'''