def setUp(self):
    """Build the shared QConfig and TensorRT backend config for the tests.

    After the parent fixture runs, two attributes are stored on the
    instance:

    * ``self.qconfig`` -- a ``QConfig`` whose activation observer is a
      ``HistogramObserver`` set up for per-tensor symmetric ``qint8`` and
      whose weight observer is ``default_weight_observer``.
    * ``self.trt_backend_config_dict`` -- whatever
      ``get_tensorrt_backend_config_dict()`` returns.
    """
    super().setUp()

    quantization = torch.ao.quantization
    activation_observer = quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric,
        dtype=torch.qint8,
    )
    self.qconfig = quantization.QConfig(
        activation=activation_observer,
        weight=quantization.default_weight_observer,
    )

    self.trt_backend_config_dict = get_tensorrt_backend_config_dict()
def test_conv(self):
    """Quantize a single Conv2d with FX graph mode and lower it to TensorRT.

    Steps:
      1. Prepare the module with a per-tensor symmetric ``qint8``
         histogram observer for activations and the default weight
         observer, using the TensorRT backend config.
      2. Calibrate with one random input and convert to a reference
         quantized model (``is_reference=True``).
      3. Check that the converted graph holds exactly one
         ``torch.quantize_per_tensor`` call and one ``dequantize`` call.
      4. Lower the model with ``lower_to_trt`` and run it once on a CUDA
         copy of the input as a smoke test (the output is not inspected).
    """
    class Conv2d(torch.nn.Module):
        """Thin wrapper around a single ``torch.nn.Conv2d``."""

        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv2d(*args)

        def forward(self, x):
            return self.conv(x)

    conv2d_input = torch.rand(1, 3, 224, 224)
    conv2d_module_args = (3, 3, 3)  # in_channels, out_channels, kernel_size
    m = Conv2d(*conv2d_module_args).eval()

    # Use the ``torch.ao.quantization`` namespace, matching the fixture code
    # in this file, rather than the legacy ``torch.quantization`` alias.
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        weight=torch.ao.quantization.default_weight_observer,
    )
    prepared = prepare_fx(
        m,
        {"": qconfig},
        backend_config_dict=get_tensorrt_backend_config_dict(),
    )
    # calibration
    prepared(conv2d_input)
    quantized = convert_fx(prepared, is_reference=True)

    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(
        quantized, expected_node_occurrence=node_occurrence
    )

    # lower to trt
    # Derive the non-batch dims from the input so the ranges cannot drift
    # from the tensor actually fed to the model; only the batch size varies.
    # NOTE(review): the triple is presumably (min, opt, max) shapes --
    # confirm against lower_to_trt.
    non_batch_dims = tuple(conv2d_input.shape[1:])
    shape_ranges = [
        (
            (1, *non_batch_dims),
            (5, *non_batch_dims),
            (10, *non_batch_dims),
        )
    ]
    trt_mod = lower_to_trt(quantized, conv2d_input, shape_ranges)
    # make sure it runs
    trt_mod(conv2d_input.cuda())
def test_linear(self):
    """Quantize a single Linear with FX graph mode and lower it to TensorRT.

    Steps:
      1. Prepare the module with a per-tensor symmetric ``qint8``
         histogram observer for activations and the default weight
         observer, using the TensorRT backend config.
      2. Calibrate with one random input and convert to a reference
         quantized model (``is_reference=True``).
      3. Check that the converted graph holds exactly one
         ``torch.quantize_per_tensor`` call and one ``dequantize`` call.
      4. Lower the model with ``lower_to_trt`` and run it once on a CUDA
         copy of the input as a smoke test (the output is not inspected).
    """
    class LinearModule(torch.nn.Module):
        """Thin wrapper around a single ``torch.nn.Linear(5, 10)``."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)
    m = LinearModule().eval()

    # Use the ``torch.ao.quantization`` namespace, matching the fixture code
    # in this file, rather than the legacy ``torch.quantization`` alias.
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        weight=torch.ao.quantization.default_weight_observer,
    )
    prepared = prepare_fx(
        m,
        {"": qconfig},
        backend_config_dict=get_tensorrt_backend_config_dict(),
    )
    # calibration
    prepared(linear_module_input)
    quantized = convert_fx(prepared, is_reference=True)

    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(
        quantized, expected_node_occurrence=node_occurrence
    )

    # lower to trt
    # Slice the non-batch dims once; only the batch size varies below.
    # NOTE(review): the triple is presumably (min, opt, max) shapes --
    # confirm against lower_to_trt.
    non_batch_dims = tuple(linear_module_input.shape[1:])
    shape_ranges = [
        (
            (1, *non_batch_dims),
            (5, *non_batch_dims),
            (10, *non_batch_dims),
        )
    ]
    trt_mod = lower_to_trt(quantized, linear_module_input, shape_ranges)
    # make sure it runs
    trt_mod(linear_module_input.cuda())
def setUp(self):
    """Run the parent fixture, then store the TensorRT backend config.

    The value returned by ``get_tensorrt_backend_config_dict()`` is kept
    on the instance as ``self.trt_backend_config_dict``.
    """
    super().setUp()
    backend_config = get_tensorrt_backend_config_dict()
    self.trt_backend_config_dict = backend_config