def run_test(
    self,
    mod,
    inputs,
    expected_ops,
    apply_passes=None,
    test_explicit_batch_dim=True,
    test_implicit_batch_dim=True,
    rtol=1e-03,
    atol=1e-03,
):
    """Trace ``mod`` with acc_tracer and compare it against TensorRT.

    Steps:
    1. Put ``mod`` in eval mode and trace it with ``acc_tracer.trace``.
    2. Feed the result through every callable in ``apply_passes``, in order, if given.
    3. Delegate to the parent class's ``run_test`` once per enabled mode:
       implicit batch dimension first, then explicit batch dimension.

    Args:
        mod: module under test.
        inputs: sample inputs used for tracing and for building input specs.
        expected_ops: forwarded unchanged to the parent ``run_test``.
        apply_passes: optional iterable of ``GraphModule -> GraphModule`` passes.
        test_explicit_batch_dim: run the explicit-batch-dimension check.
        test_implicit_batch_dim: run the implicit-batch-dimension check.
        rtol, atol: tolerances forwarded to the parent ``run_test``.
    """
    mod.eval()
    traced = acc_tracer.trace(mod, inputs)
    passes = apply_passes if apply_passes is not None else ()
    for transform in passes:
        traced = transform(traced)

    # (enabled?, extra TRTInterpreter kwargs). The implicit mode passes no
    # extra kwargs at all so the interpreter's own default is used.
    batch_dim_modes = (
        (test_implicit_batch_dim, {}),
        (test_explicit_batch_dim, {"explicit_batch_dimension": True}),
    )
    for enabled, interp_kwargs in batch_dim_modes:
        if not enabled:
            continue
        trt_interp = TRTInterpreter(
            traced, InputTensorSpec.from_tensors(inputs), **interp_kwargs
        )
        super().run_test(traced, inputs, expected_ops, trt_interp, rtol, atol)
def build_int8_trt_implicit_quant(rn18):
    """Quantize a copy of ``rn18`` with FX graph-mode quantization and lower it to INT8 TensorRT.

    Steps:
    1. Deep-copy ``rn18`` and prepare it with ``prepare_fx``. Activations use a
       symmetric per-tensor ``HistogramObserver`` with ``reduce_range=True``;
       weights use the default per-channel observer.
    2. Calibrate with 10 forward passes on one random ``(1, 3, 224, 224)`` tensor.
    3. Convert with ``convert_fx``.
    4. Re-trace the quantized model, run shape propagation and ``NormalizeArgs``.
    5. Build an INT8 engine (``fp16_mode=False``, ``int8_mode=True``,
       ``strict_type_constraints=True``).

    Prints the maximum absolute difference between the quantized PyTorch
    reference output and the TensorRT output.

    Args:
        rn18: float model to quantize; it is deep-copied and not modified.

    Returns:
        The resulting ``TRTModule``.
    """
    # Work on a copy so the caller's module is left untouched by prepare_fx.
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)

    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer,
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    # Calibration passes so the observers collect activation statistics.
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(
        traced_rn18,
        InputTensorSpec.from_tensors([data]),
        logger_level=trt.Logger.VERBOSE,
    )
    engine, input_names, output_names = interp.run(
        fp16_mode=False, int8_mode=True, strict_type_constraints=True
    )
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    # Use the absolute difference: a signed max would hide elements where the
    # TensorRT result exceeds the reference (negative diff) and understate error.
    max_abs_diff = torch.max(torch.abs(ref_res - trt_res.cpu()))
    print("implicit quant result diff max", max_abs_diff)
    return trt_mod
def lower_mod_to_trt(mod: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
    """
    Helper function that, given a GraphModule `mod` and its `inputs`, builds a
    TRTModule that runs the original `mod` on TensorRT.

    `inputs` are used only to derive the engine's input specs via
    ``InputTensorSpec.from_tensors``. The engine itself is built with the
    interpreter's default build settings.
    """
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    # `run` takes build options (max_batch_size, fp16_mode, int8_mode, ...),
    # not sample tensors. Forwarding `*inputs` here would bind the tensors to
    # those options positionally.
    engine, input_names, output_names = interp.run()
    return TRTModule(engine, input_names, output_names)
def lower_mod_default(
    mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048
) -> TRTModule:
    """Lower ``mod`` to a TRTModule using an explicit batch dimension.

    Input specs are derived from ``inputs``, and ``batch_size`` is passed to
    the interpreter's ``run`` as ``max_batch_size``. Whatever ``run`` returns
    is unpacked positionally into the ``TRTModule`` constructor.
    """
    input_specs = InputTensorSpec.from_tensors(inputs)
    interpreter = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
    build_result = interpreter.run(max_batch_size=batch_size)
    lowered = TRTModule(*build_result)
    return lowered
def _lower_model_to_backend(self, mod: torch.fx.GraphModule, inputs: Iterable[torch.Tensor]):
    """
    Lower a GraphModule `mod` to TensorRT with `inputs`.

    `inputs` are used only to build the input specs. The engine is built with
    the interpreter's default build settings.
    """
    # Current code for lowering is a placeholder, subject to future change
    # based on the feeds model's actual status.
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    # `run` expects build options (max_batch_size, fp16_mode, ...), not sample
    # tensors. Passing `*inputs` would bind tensors to those options, and would
    # silently pass nothing if `inputs` is an iterator already consumed above.
    engine, input_names, output_names = interp.run()
    return TRTModule(engine, input_names, output_names)
def run_test_with_assert_error(
    self,
    mod,
    inputs,
    expect_error,
    test_explicit_batch_dim=True,
    test_implicit_batch_dim=True,
):
    """Trace ``mod`` with acc_tracer and run the parent's error-expecting check.

    The module is put in eval mode and traced with ``acc_tracer.trace``. Then,
    for each enabled mode (implicit batch dimension first, then explicit), a
    fresh ``TRTInterpreter`` is built and passed to the parent class's
    ``run_test_with_error`` together with ``expect_error``.
    """
    mod.eval()
    traced = acc_tracer.trace(mod, inputs)

    # (enabled?, extra TRTInterpreter kwargs); implicit mode adds no kwargs.
    batch_dim_modes = (
        (test_implicit_batch_dim, {}),
        (test_explicit_batch_dim, {"explicit_batch_dimension": True}),
    )
    for enabled, interp_kwargs in batch_dim_modes:
        if not enabled:
            continue
        trt_interp = TRTInterpreter(
            traced, InputTensorSpec.from_tensors(inputs), **interp_kwargs
        )
        super().run_test_with_error(traced, inputs, trt_interp, expect_error)
def run_test_custom_compare_results(
    self,
    mod,
    inputs,
    expected_ops,
    comparators: List[Tuple[Callable, List]],
    interpreter=None,
):
    """Trace ``mod`` with torch.fx and run the parent's custom-comparator check.

    ``mod`` goes through ``symbolic_trace``, then ``ShapeProp`` on ``inputs``,
    then ``NormalizeArgs``. The result is handed, with a freshly built
    ``TRTInterpreter``, to the parent class's
    ``run_test_custom_compare_results``.
    """
    # `interpreter` is deliberately ignored: vanilla tests always build their
    # own TRTInterpreter below. This differs from the internal version; the
    # test cases should be fixed once internal callsites use this file.
    graph_mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(graph_mod).propagate(*inputs)
    normalized = NormalizeArgs(graph_mod).transform()
    trt_interp = TRTInterpreter(normalized, InputTensorSpec.from_tensors(inputs))
    super().run_test_custom_compare_results(
        normalized, inputs, expected_ops, comparators, trt_interp
    )
def test_save_and_load_trt_module(self):
    """Check that a TRTModule survives a ``torch.save`` / ``torch.load`` round trip.

    A trivial ``x + x`` module is traced with acc_tracer and lowered with
    ``fp16_mode=False``. The lowered module is saved to and reloaded from a
    temporary file. The reloaded module's CUDA output must match the eager
    PyTorch reference within 1e-4.
    """
    import os
    import tempfile

    class TestModule(torch.nn.Module):
        def forward(self, x):
            return x + x

    inputs = [torch.randn(1, 1)]
    mod = TestModule().eval()
    ref_output = mod(*inputs)

    mod = acc_tracer.trace(mod, inputs)
    interp = TRTInterpreter(mod, input_specs=InputTensorSpec.from_tensors(inputs))
    trt_mod = TRTModule(*interp.run(fp16_mode=False))

    # Serialize into a private temporary directory rather than the CWD, so the
    # test leaves no artifact behind and concurrent runs cannot clobber it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "trt.pt")
        torch.save(trt_mod, path)
        reload_trt_mod = torch.load(path)

    # assert_close supersedes the deprecated torch.testing.assert_allclose.
    torch.testing.assert_close(
        reload_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04
    )
def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
    """Trace ``mod`` with torch.fx and compare it against TensorRT.

    The module is symbolically traced, shape-propagated with ``inputs`` and
    normalized with ``NormalizeArgs``. It is then passed, along with a
    freshly built ``TRTInterpreter``, to the parent class's ``run_test``
    using the given ``rtol``/``atol``.
    """
    graph_mod = torch.fx.symbolic_trace(mod)
    shape_prop.ShapeProp(graph_mod).propagate(*inputs)
    normalized = NormalizeArgs(graph_mod).transform()
    trt_interp = TRTInterpreter(normalized, InputTensorSpec.from_tensors(inputs))
    super().run_test(normalized, inputs, expected_ops, trt_interp, rtol, atol)
%x : [#users=1] = placeholder[target=x] %linear_weight : [#users=1] = get_attr[target=linear.weight] %linear_bias : [#users=1] = get_attr[target=linear.bias] %linear_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linear](args = (), ... %relu_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.relu](args = (), ... return relu_1 graph(): %relu_1 : [#users=1] = placeholder[target=relu_1] %linalg_norm_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linalg_norm](args = (), ... return linalg_norm_1 """ # Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered, # we can skip the splitter part. interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs)) engine, input_names, output_names = interp.run() trt_mod = TRTModule(engine, input_names, output_names) split_mod._run_on_acc_0 = trt_mod cuda_inputs = [input.cuda() for input in inputs] split_mod.cuda() lowered_model_output = split_mod(*cuda_inputs) # Make sure the results match model.cuda() regular_model_output = model(*cuda_inputs) torch.testing.assert_close(lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2)