def build_int8_trt_implicit_quant(rn18):
    """Quantize a copy of ``rn18`` with FX graph mode and lower it to an INT8 TensorRT module.

    The model is deep-copied, calibrated by running a random
    ``(1, 3, 224, 224)`` tensor through the prepared model ten times,
    converted with ``convert_fx`` and then built into a TensorRT engine with
    ``int8_mode=True``. The largest absolute element-wise difference between
    the quantized reference output and the TensorRT output is printed.

    Args:
        rn18: Model to quantize and lower. A deep copy is used, so the
            argument itself is left untouched.

    Returns:
        The ``TRTModule`` wrapping the built engine.
    """
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)

    # Quantization: symmetric per-tensor activations, per-channel weights.
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, reduce_range=True
        ),
        weight=torch.ao.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    # Calibration passes so the observers can collect statistics.
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared)
    ref_res = quantized_rn18(data)

    # Build trt int8 model
    traced_rn18 = torch.fx.symbolic_trace(quantized_rn18)
    shape_prop.ShapeProp(traced_rn18).propagate(data)
    traced_rn18 = NormalizeArgs(traced_rn18).transform()
    interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    # Report the magnitude of the mismatch: the max of a signed difference
    # would hide errors where the TensorRT output is larger than the reference.
    print("implicit quant result diff max", torch.max(torch.abs(ref_res - trt_res.cpu())))
    return trt_mod
예제 #2
0
    def __call__(
        self,
        module: nn.Module,
        input: Input,
        cuda_graph_batch_size: int = -1,
    ) -> nn.Module:
        """Lower ``module`` and return the split module (see `LowerFunc` protocol).

        When ``self.fp16`` is set, ``module`` is put in eval mode and converted
        to half precision in place, and float32 inputs are cast to half.
        Constant subgraphs are then split off and folded, the result is
        acc-traced and split, and every split whose device is ``"acc"`` is
        replaced on the split module by a ``TRTModule`` built from it.

        Args:
            module: The module to lower.
            input: The sample inputs used for tracing and splitting.
            cuda_graph_batch_size: Forwarded unchanged to each ``TRTModule``.

        Returns:
            The split module with its ``"acc"`` submodules swapped for
            ``TRTModule`` instances.
        """
        if self.fp16:
            module.eval().half()
            input = tuple(
                t if t.dtype != torch.float32 else t.half() for t in input
            )

        folded_mod = split_const_subgraphs(module)
        folded_mod.run_folding()
        module = self.acc_trace(folded_mod, input)  # type: ignore[misc]
        print(f"acc traced: {module.graph}")

        split_module, splits = self.split(module, input)  # type: ignore[arg-type]
        split_module.eval()  # type: ignore[attr-defined]

        for part in splits:  # type: ignore[attr-defined]
            # Only the accelerator-bound splits are lowered; the rest stay as-is.
            if part.device != "acc":
                continue

            # NOTE(review): the previous comment here said the parent module is
            # updated with the traced sub-net before this call; no such step is
            # visible in this method -- confirm it is not needed.
            self.remove_duplicate_output_args(part.module, [part.name])  # type: ignore[misc, operator]

            interp_res = self.trt_interpreter(part.module, part.input, part.name)

            # Replace the submodule on the parent by its TensorRT counterpart.
            setattr(
                split_module,
                part.name,
                TRTModule(
                    engine=interp_res.engine,
                    input_names=interp_res.input_names,
                    output_names=interp_res.output_names,
                    cuda_graph_batch_size=cuda_graph_batch_size,
                ),
            )

        return split_module  # type: ignore[return-value]
예제 #3
0
    def run_test(self, mod, inputs, expected_ops, unexpected_ops, interpreter, rtol, atol):
        """Build a TRT module via ``interpreter`` and compare it against ``mod``.

        Args:
            mod: Reference module; it is put in eval mode and called with ``inputs``.
            inputs: Tensors fed to ``mod``; ``.cuda()`` copies are fed to the TRT module.
            expected_ops: Ops passed to ``assert_has_op`` when non-empty.
            unexpected_ops: Ops passed to ``assert_unexpected_op`` when truthy.
            interpreter: Object whose ``run(fp16_mode=False)`` result provides
                the engine and the input/output names.
            rtol: Relative tolerance for the output comparison.
            atol: Absolute tolerance for the output comparison.
        """
        with torch.no_grad():
            cuda_inputs = [tensor.cuda() for tensor in inputs]

            mod.eval()
            if len(expected_ops) > 0:
                self.assert_has_op(mod, expected_ops)
            if unexpected_ops:
                self.assert_unexpected_op(mod, unexpected_ops)

            built = interpreter.run(fp16_mode=False)
            trt_mod = TRTModule(built.engine, built.input_names, built.output_names)

            ref_outputs = mod(*inputs)
            outputs = trt_mod(*cuda_inputs)

            # Normalise the single-tensor case so both results can be zipped.
            if isinstance(outputs, torch.Tensor):
                outputs, ref_outputs = [outputs], [ref_outputs]

            for trt_out, ref_out in zip(outputs, ref_outputs):
                torch.testing.assert_allclose(trt_out.cpu(), ref_out, rtol=rtol, atol=atol)
def build_int8_trt(rn18):
    """Quantize a copy of ``rn18`` to a reference model and lower it to INT8 TensorRT.

    The model is deep-copied, calibrated by running a random
    ``(1, 3, 224, 224)`` tensor through the prepared model ten times and
    converted with ``convert_fx(..., is_reference=True)``. The reference model
    is acc-traced and built with ``explicit_precision=True`` and a dynamic
    batch dimension whose shape range is 1 / 5 / 10. The quantized model and
    the largest absolute element-wise difference between its output and the
    TensorRT output are printed.

    Args:
        rn18: Model to quantize and lower. A deep copy is used, so the
            argument itself is left untouched.

    Returns:
        The ``TRTModule`` wrapping the built engine.
    """
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # TensorRT only supports symmetric quantization
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        # Per-channel weight observer is used here.
        weight=torch.quantization.default_per_channel_weight_observer
    )
    prepared = prepare_fx(rn18, {"": qconfig})
    # Calibration passes so the observers can collect statistics.
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
    interp = TRTInterpreter(
        quantized_rn18,
        # Batch dimension is dynamic (-1); the remaining dims follow `data`.
        [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
                         shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
        explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
    trt_res = trt_mod(data.cuda())
    # Report the magnitude of the mismatch: the max of a signed difference
    # would hide errors where the TensorRT output is larger than the reference.
    print("explicit quant result diff max", torch.max(torch.abs(ref_res - trt_res.cpu())))
    return trt_mod
def build_fp16_trt(rn18):
    """Lower a copy of ``rn18`` to a TensorRT module built with ``fp16_mode=True``.

    The model is deep-copied and acc-traced with a random
    ``(1, 3, 224, 224)`` tensor; the engine is built for a ``(3, 224, 224)``
    float input spec declared without a batch dimension.

    Args:
        rn18: Model to lower. A deep copy is used, so the argument itself is
            left untouched.

    Returns:
        The ``TRTModule`` wrapping the built engine.
    """
    traced = acc_tracer.trace(copy.deepcopy(rn18), [torch.randn(1, 3, 224, 224)])
    input_specs = [
        InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)
    ]
    built = TRTInterpreter(traced, input_specs).run(fp16_mode=True)
    return TRTModule(built.engine, built.input_names, built.output_names)
예제 #6
0
def lower_mod_default(mod: torch.fx.GraphModule,
                      inputs: Tensors,
                      batch_size: Any = 2048) -> TRTModule:
    """Lower ``mod`` to a ``TRTModule`` using input specs derived from ``inputs``.

    Args:
        mod: The graph module to lower.
        inputs: Sample tensors from which the input specs are derived.
        batch_size: Passed to the interpreter's ``run`` as ``max_batch_size``.

    Returns:
        The ``TRTModule`` wrapping the built engine.
    """
    input_specs = InputTensorSpec.from_tensors(inputs)
    built = TRTInterpreter(
        mod, input_specs, explicit_batch_dimension=True
    ).run(max_batch_size=batch_size)
    return TRTModule(built.engine, built.input_names, built.output_names)
예제 #7
0
 def _lower_model_to_backend(self, mod: torch.fx.GraphModule,
                             inputs: Iterable[torch.Tensor]):
     """
     Lower a GraphModule `mod` to TensorRT with `inputs`.

     Returns a `TRTModule` built from the interpreter result's engine and
     input/output names.
     """
     # Current code for lowering is place-holder, subject to future change
     # based on feeds model's actual status
     input_specs = InputTensorSpec.from_tensors(inputs)
     trt_interpreter = TRTInterpreter(mod, input_specs)
     # NOTE(review): the sample tensors are forwarded positionally to `run`;
     # presumably `run` expects build options rather than tensors -- confirm
     # against the TRTInterpreter.run signature.
     lowered = trt_interpreter.run(*inputs)
     return TRTModule(
         lowered.engine,
         lowered.input_names,
         lowered.output_names,
     )
예제 #8
0
    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
        """
        Build a TRT module from 'interpreter', run it next to 'mod' and check
        each pair of results with the matching entry of 'comparators'.
        There must be exactly one comparator per output of 'mod'.

        mod          - the reference model; put in eval mode and run on 'inputs'.
        inputs       - a list of the model inputs.
        expected ops - a list of ops that should be verified (skipped if empty).
        interpreter  - used for converting the model to TRT.
        comparators  - a list of (func, args) pairs, one per module output;
                       each is invoked as func(trt_output, cpu_output, *args)
                       and must return a truthy value.
        fp16_mode    - forwarded to interpreter.run.

        """
        with torch.no_grad():
            cuda_inputs = [tensor.cuda() for tensor in inputs]

            mod.eval()
            if len(expected_ops) > 0:
                self.assert_has_op(mod, expected_ops)

            built = interpreter.run(fp16_mode=fp16_mode)
            trt_mod = TRTModule(built.engine, built.input_names, built.output_names)

            # NOTE(review): `.cpu()` is called directly on the TRT result, so
            # this presumably expects a single tensor back -- confirm for
            # modules with several outputs.
            res_trt = trt_mod(*cuda_inputs).cpu()
            res_cpu = mod(*inputs)
            assert len(res_trt) == len(res_cpu)
            assert len(res_cpu) == len(comparators)

            for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators):
                compare, extra_args = comparator[0], comparator[1]
                self.assertTrue(compare(output_trt, output_cpu, *extra_args))
예제 #9
0
def lower_to_trt(model, inputs, shape_ranges):
    """Lower a quantized model to TensorRT.

    The model is acc-traced with ``inputs`` and built with
    ``explicit_precision=True`` and ``int8_mode=True``. The single input spec
    uses a dynamic batch dimension (``-1``) followed by the remaining
    dimensions of ``inputs[0]``.

    Args:
        model: The model to lower.
        inputs: Sample inputs; exactly one is supported for now.
        shape_ranges: Passed to ``InputTensorSpec`` as ``shape_ranges``.

    Returns:
        The ``TRTModule`` wrapping the built engine.
    """
    assert len(inputs) == 1, "lower_to_trt only works for one input currently"
    traced = acc_tracer.trace(model, inputs)  # type: ignore[attr-defined]

    # TODO: test multiple inputs setting and enable multiple inputs
    dynamic_batch_shape = torch.Size([-1, *inputs[0].shape[1:]])
    spec = InputTensorSpec(
        dynamic_batch_shape,
        torch.float,
        shape_ranges=shape_ranges,
        has_batch_dim=True,
    )

    builder = TRTInterpreter(
        traced,
        [spec],
        explicit_batch_dimension=True,
        explicit_precision=True,
    )
    built = builder.run(fp16_mode=False, int8_mode=True)
    return TRTModule(built.engine, built.input_names, built.output_names)
예제 #10
0
    def test_save_and_load_trt_module(self):
        """Check that a ``TRTModule`` survives a ``torch.save`` / ``torch.load`` round trip.

        A tiny module computing ``x + x`` is lowered to TensorRT, saved, loaded
        back, and the reloaded module's output is compared with the eager
        reference output.
        """
        # Local imports keep this test self-contained.
        import os
        import tempfile

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        inputs = [torch.randn(1, 1)]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod, input_specs=InputTensorSpec.from_tensors(inputs))
        trt_mod = TRTModule(*interp.run(fp16_mode=False))

        # Serialize into a throw-away directory so the test neither leaves a
        # "trt.pt" artifact in the working directory nor collides with other
        # runs writing the same file name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            saved_path = os.path.join(tmp_dir, "trt.pt")
            torch.save(trt_mod, saved_path)
            reload_trt_mod = torch.load(saved_path)

        torch.testing.assert_allclose(reload_trt_mod(inputs[0].cuda()).cpu(),
                                      ref_output,
                                      rtol=1e-04,
                                      atol=1e-04)
예제 #11
0
    %linear_weight : [#users=1] = get_attr[target=linear.weight]
    %linear_bias : [#users=1] = get_attr[target=linear.bias]
    %linear_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linear](args = (), ...
    %relu_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.relu](args = (), ...
    return relu_1
graph():
    %relu_1 : [#users=1] = placeholder[target=relu_1]
    %linalg_norm_1 : [#users=1] = call_function[target=torch.fx.experimental.fx_acc.acc_ops.linalg_norm](args = (), ...
    return linalg_norm_1
"""

# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
# we can skip the splitter part.
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
# The interpreter result unpacks into the engine and its input/output names,
# which is everything TRTModule needs.
engine, input_names, output_names = interp.run()
trt_mod = TRTModule(engine, input_names, output_names)
# Swap the lowered module in for the original submodule.
split_mod._run_on_acc_0 = trt_mod

# Move the inputs and the split module to the GPU before running it.
cuda_inputs = [input.cuda() for input in inputs]
split_mod.cuda()
lowered_model_output = split_mod(*cuda_inputs)

# Make sure the results match
model.cuda()
regular_model_output = model(*cuda_inputs)
# The reference output is cast to float16 before the comparison.
# NOTE(review): presumably because the lowered path produces fp16 output -- confirm.
torch.testing.assert_close(lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2)

# We can utilize the trt profiler to print out the time spent on each layer.
trt_mod.enable_profiling()
trt_mod(*cuda_inputs)
'''