    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops=None,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        rtol=1e-03,
        atol=1e-03,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            for p in apply_passes:
                mod = p(mod)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
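
# A minimal usage sketch for the helper above, assuming it lives on an fx2trt
# converter-test base class (called AccTestCase here purely for illustration)
# and that torch and acc_ops are imported as in the real test files.
class TestReluConverter(AccTestCase):
    def test_relu(self):
        class Relu(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        self.run_test(Relu(), [torch.randn(1, 3, 4)], expected_ops={acc_ops.relu})
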
def build_int8_trt_implicit_quant(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # Quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
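
# A usage sketch for build_int8_trt_implicit_quant, assuming torchvision is
# installed and a CUDA device with TensorRT is available; resnet18 is the model
# the helper was written around, but any FX-quantizable model should work.
import torchvision.models as models

rn18 = models.resnet18().eval()
int8_trt_rn18 = build_int8_trt_implicit_quant(rn18)
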
    def run_test_with_assert_error(
        self,
        mod,
        inputs,
        expect_error,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test_with_error(mod, inputs, interp, expect_error)
Example #4
def lower_mod_default(mod: torch.fx.GraphModule,
                      inputs: Tensors,
                      batch_size: Any = 2048) -> TRTModule:
    interp = TRTInterpreter(mod,
                            InputTensorSpec.from_tensors(inputs),
                            explicit_batch_dimension=True)
    interpreter_result = interp.run(max_batch_size=batch_size)
    res_mod = TRTModule(interpreter_result.engine,
                        interpreter_result.input_names,
                        interpreter_result.output_names)
    return res_mod
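
# A hypothetical call site for lower_mod_default, assuming the same torch and
# acc_tracer imports as the snippets above; the module and shapes are
# illustrative only.
class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

sample_inputs = [torch.randn(8, 16), torch.randn(8, 16)]
traced = acc_tracer.trace(Add().eval(), sample_inputs)
trt_add = lower_mod_default(traced, sample_inputs, batch_size=8)
result = trt_add(*[t.cuda() for t in sample_inputs])
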
Example #5
    def _lower_model_to_backend(self, mod: torch.fx.GraphModule,
                                inputs: Iterable[torch.Tensor]):
        """
        Lower a GraphModule `mod` to TensorRT with `inputs`.
        """
        # The current lowering code is a placeholder, subject to change based
        # on the feeds model's actual status.
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        interpreter_result = interp.run(*inputs)
        return TRTModule(interpreter_result.engine,
                         interpreter_result.input_names,
                         interpreter_result.output_names)
Example #6
    def __call__(self, mod, input, split_name):
        input_specs_val = (self.lower_setting.input_specs
                           if self.lower_setting.input_specs else
                           InputTensorSpec.from_tensors(input))
        if self.lower_setting.enable_fuse:
            mod = fuse_permute_matmul(mod)
            mod = fuse_permute_linear(mod)
            mod = fuse_unsqueeze_cat_sum(mod)

        # Prepare algorithm selector and timing_cache for TRTInterpreter
        algo_selector = None
        if self.lower_setting.algo_selector:
            algo_selector = self.lower_setting.algo_selector(
                f"{split_name}.json")
        cache_data = None
        if self.timing_cache_manager:
            try:
                cache_data = self.timing_cache_manager.get_timing_cache_trt(
                    split_name)
            except Exception as e:
                logger.warning(
                    f"Cannot load timing cache for {split_name}: {str(e)}")
                cache_data = None

        import tensorrt as trt  # @manual=//caffe2:torch_fx2trt # noqa: F401

        interpreter = TRTInterpreter(
            mod,
            input_specs=input_specs_val,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
            logger_level=(trt.Logger.VERBOSE
                          if self.lower_setting.verbose_log
                          else trt.Logger.WARNING),
        )

        interp_result = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            fp16_mode=self.lower_setting.fp16_mode,
            int8_mode=self.lower_setting.int8_mode,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
        if timing_cache and self.timing_cache_manager:
            self.timing_cache_manager.update_timing_cache(
                split_name, timing_cache)

        return interp_result
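
# A minimal sketch of the timing-cache round trip the pass above performs,
# without the manager class: persist interp_result.serialized_cache and pass it
# back via timing_cache on later builds. This reuses `traced` and
# `sample_inputs` from the lower_mod_default sketch above; the cache file name
# is illustrative.
import os

CACHE_FILE = "acc_submod_timing.cache"

cache_data = None
if os.path.exists(CACHE_FILE):
    with open(CACHE_FILE, "rb") as f:
        cache_data = f.read()

interp = TRTInterpreter(traced, InputTensorSpec.from_tensors(sample_inputs),
                        explicit_batch_dimension=True)
interp_result = interp.run(timing_cache=cache_data)

if interp_result.serialized_cache:
    with open(CACHE_FILE, "wb") as f:
        f.write(bytes(interp_result.serialized_cache))
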
    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
        # `interpreter` is ignored; we do not need it for the vanilla tests.
        # Note this differs from the internal version; the test case should be
        # fixed once the internal callsites are refactored to use this file.
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test_custom_compare_results(
            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
        )
Example #8
    def test_save_and_load_trt_module(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        inputs = [torch.randn(1, 1)]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod, input_specs=InputTensorSpec.from_tensors(inputs))
        trt_mod = TRTModule(*interp.run(fp16_mode=False))
        torch.save(trt_mod, "trt.pt")
        reload_trt_mod = torch.load("trt.pt")

        torch.testing.assert_allclose(reload_trt_mod(inputs[0].cuda()).cpu(),
                                      ref_output,
                                      rtol=1e-04,
                                      atol=1e-04)
    def test_from_tensors(self):
        tensors = [torch.randn(1, 2, 3), torch.randn(2, 4)]
        specs = InputTensorSpec.from_tensors(tensors)
        for spec, tensor in zip(specs, tensors):
            self._validate_spec(spec, tensor)
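
    # A sketch of what the referenced _validate_spec helper might check; this
    # version is illustrative, not the actual implementation, and the real
    # helper may also compare device and other InputTensorSpec fields.
    def _validate_spec(self, spec, tensor):
        self.assertEqual(spec.shape, tensor.shape)
        self.assertEqual(spec.dtype, tensor.dtype)
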
    def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)
graph():
    %x : [#users=1] = placeholder[target=x]
    %linear_weight : [#users=1] = get_attr[target=linear.weight]
    %linear_bias : [#users=1] = get_attr[target=linear.bias]
    %linear_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linear](args = (), ...
    %relu_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.relu](args = (), ...
    return relu_1
graph():
    %relu_1 : [#users=1] = placeholder[target=relu_1]
    %linalg_norm_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linalg_norm](args = (), ...
    return linalg_norm_1
"""

# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
# we can skip the splitter part.
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
engine, input_names, output_names = interp.run()
trt_mod = TRTModule(engine, input_names, output_names)
split_mod._run_on_acc_0 = trt_mod

cuda_inputs = [input.cuda() for input in inputs]
split_mod.cuda()
lowered_model_output = split_mod(*cuda_inputs)

# Make sure the results match
model.cuda()
regular_model_output = model(*cuda_inputs)
torch.testing.assert_close(lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2)

# We can use the TensorRT profiler to print out the time spent on each layer.
trt_mod.enable_profiling()
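
# Hypothetical follow-up: with profiling enabled, running another inference
# causes TensorRT to report per-layer timings through the attached profiler.
split_mod(*cuda_inputs)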