示例#1
0
 def setUp(self):
     """Create the QConfig and TensorRT backend config stored on the test case."""
     super().setUp()
     # Activations: histogram observer, per-tensor symmetric scheme, qint8.
     histogram_observer = torch.ao.quantization.observer.HistogramObserver
     activation_observer = histogram_observer.with_args(
         qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
     self.qconfig = torch.ao.quantization.QConfig(
         activation=activation_observer,
         weight=torch.ao.quantization.default_weight_observer)
     self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
示例#2
0
    def test_conv(self):
        """Quantize a single Conv2d module via FX and lower it with lower_to_trt.

        Steps: prepare with the TensorRT backend config, run one calibration
        pass, convert with ``is_reference=True``, check the quantize/dequantize
        node counts, lower, and finally run the lowered module on a CUDA input.
        """
        # Minimal module wrapping one torch.nn.Conv2d layer.
        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        # The sample input is drawn before the module is built, keeping the
        # random-number consumption order fixed.
        sample = torch.rand(1, 3, 224, 224)
        model = Conv2d(3, 3, 3).eval()

        activation_observer = (
            torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8))
        qconfig = torch.quantization.QConfig(
            activation=activation_observer,
            weight=torch.quantization.default_weight_observer)

        # The empty-string key applies the qconfig globally.
        qconfig_dict = {"": qconfig}
        backend_config = get_tensorrt_backend_config_dict()
        prepared = prepare_fx(
            model, qconfig_dict, backend_config_dict=backend_config)

        # Calibration: a single forward pass feeds the observers.
        prepared(sample)
        quantized = convert_fx(prepared, is_reference=True)

        # Expect one quantize_per_tensor call and one dequantize call.
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(
            quantized, expected_node_occurrence=expected_counts)

        # Lower to TensorRT with batch sizes 1, 5 and 10 as the shape range
        # (presumably min/opt/max -- confirm against lower_to_trt).
        shape_range = tuple((batch, 3, 224, 224) for batch in (1, 5, 10))
        trt_mod = lower_to_trt(quantized, sample, [shape_range])

        # Smoke test: the lowered module must run on a CUDA tensor.
        trt_mod(sample.cuda())
示例#3
0
    def test_linear(self):
        """Quantize a single Linear module via FX and lower it with lower_to_trt.

        Steps: prepare with the TensorRT backend config, run one calibration
        pass, convert with ``is_reference=True``, check the quantize/dequantize
        node counts, lower, and finally run the lowered module on a CUDA input.
        """
        # Minimal module wrapping one torch.nn.Linear(5, 10) layer.
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # The sample input is drawn before the module is built, keeping the
        # random-number consumption order fixed.
        sample = torch.rand(8, 5)
        model = LinearModule().eval()

        activation_observer = (
            torch.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8))
        qconfig = torch.quantization.QConfig(
            activation=activation_observer,
            weight=torch.quantization.default_weight_observer)

        # The empty-string key applies the qconfig globally.
        qconfig_dict = {"": qconfig}
        backend_config = get_tensorrt_backend_config_dict()
        prepared = prepare_fx(
            model, qconfig_dict, backend_config_dict=backend_config)

        # Calibration: a single forward pass feeds the observers.
        prepared(sample)
        quantized = convert_fx(prepared, is_reference=True)

        # Expect one quantize_per_tensor call and one dequantize call.
        expected_counts = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(
            quantized, expected_node_occurrence=expected_counts)

        # Lower to TensorRT; only the leading (batch) dimension varies, over
        # 1, 5 and 10 (presumably min/opt/max -- confirm against lower_to_trt).
        trailing_dims = sample.shape[1:]
        shape_range = tuple(
            (batch, *trailing_dims) for batch in (1, 5, 10))
        trt_mod = lower_to_trt(quantized, sample, [shape_range])

        # Smoke test: the lowered module must run on a CUDA tensor.
        trt_mod(sample.cuda())
示例#4
0
 def setUp(self):
     """Run the parent setUp, then store the TensorRT backend config dict."""
     super().setUp()
     backend_config = get_tensorrt_backend_config_dict()
     self.trt_backend_config_dict = backend_config