def test_ops(self):
    """Reference-convert a conv -> linear -> add -> relu -> add model that was
    prepared with ``default_qconfig`` and check the resulting node counts."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(5, 5)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            out = self.linear(self.conv(x))
            out = self.relu(out + 3)
            return out + 6

    # No backend_config_dict and no calibration step: prepare, then convert
    # straight to the reference representation.
    model = prepare_fx(M().eval(), {"": default_qconfig})
    model = _convert_fx_do_not_use(model, is_reference=True)
    self.checkGraphModuleNodes(
        model,
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        },
    )
def _test_module(self, m, inputs, shape_ranges, no_prepare=None, no_convert=None):
    """Quantize ``m`` with the TensorRT backend config, check the graphs,
    then lower the reference-quantized model to TensorRT and run it on CUDA.

    Args:
        m: the float module we want to test
        inputs: list of inputs for the module; used for calibration, for
            lowering, and (after being moved with ``.cuda()``) for the final
            run of the lowered module
        shape_ranges: a list of shape_range, where every shape_range is a
            tuple of three tuples ((min_input_shape), (optimized_input_shape),
            (max_input_shape)). Each shape_range is used to populate a
            TensorRT optimization profile. e.g. If the input shape varies
            from (1, 224) to (100, 224) and we want to optimize for (25, 224)
            because it's the most common input shape, then we set
            shape_ranges to [((1, 224), (25, 224), (100, 224))]
        no_prepare: expected node occurrence after prepare
        no_convert: expected node occurrence after convert
    """
    m = m.eval()
    prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
    # calibration
    prepared(*inputs)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert)
    # lower to trt
    trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
    inputs_cuda = [i.cuda() for i in inputs]
    # make sure it runs
    trt_mod(*inputs_cuda)
def test_unsupported_qconfig(self):
    """Check that a model is left unquantized when its qconfig
    (``default_qconfig`` here) is not supported by the TensorRT backend config.
    """
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    calib_data = torch.rand(8, 5)
    float_model = LinearModule().eval()
    # default_qconfig is the qconfig the TRT backend config does not support
    prepared = prepare_fx(
        float_model,
        {"": default_qconfig},
        backend_config_dict=self.trt_backend_config_dict)
    prepared(calib_data)  # calibration
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)

    # check model is not quantized: no quant/dequant, float Linear kept
    expected = {
        ns.call_function(torch.quantize_per_tensor): 0,
        ns.call_method("dequantize"): 0,
        ns.call_module(torch.nn.Linear): 1,
        ns.call_module(torch.nn.quantized._reference.Linear): 0,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=expected)
def _test_quantized_inputs_outputs(
        self, prepare_custom_config_dict, prepare_count_check,
        convert_count_check):
    """Run prepare/convert on a two-conv model and check the node counts.

    Whether the graph inputs/outputs are kept quantized is decided by the
    caller-supplied ``prepare_custom_config_dict``.

    Args:
        prepare_custom_config_dict: forwarded to ``prepare_fx``
        prepare_count_check: expected node occurrence after prepare
        convert_count_check: expected node occurrence after convert
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv2(self.conv1(x))

    model = M()
    model.eval()
    qconfig_dict = {"": torch.ao.quantization.default_qconfig}

    prepared = torch.ao.quantization.quantize_fx.prepare_fx(
        model,
        qconfig_dict,
        prepare_custom_config_dict=prepare_custom_config_dict)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=prepare_count_check)
    prepared(torch.randn(1, 1, 4, 4))  # calibration

    # NOTE(review): prepare_fx above uses the default backend config while
    # convert uses self.trt_backend_config_dict -- confirm this is intended.
    converted = _convert_fx_do_not_use(
        prepared,
        is_reference=True,
        backend_config_dict=self.trt_backend_config_dict)
    self.checkGraphModuleNodes(converted, expected_node_occurrence=convert_count_check)
def test_addmm(self):
    """Check observer insertion and reference conversion of ``torch.addmm``
    under the TensorRT backend config."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(5, 5)
            self.bias = torch.randn(5)

        def forward(self, x):
            return torch.addmm(self.bias, x, self.weight)

    backend = self.trt_backend_config_dict
    prepared = prepare_fx(M().eval(), {"": self.qconfig}, backend_config_dict=backend)
    self.checkGraphModuleNodes(prepared, expected_node_occurrence={
        # weight
        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
        # activation
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    })

    quantized = _convert_fx_do_not_use(prepared, is_reference=True, backend_config_dict=backend)
    self.checkGraphModuleNodes(quantized, expected_node_occurrence={
        # input activation, output activation and weight
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_function(torch.addmm): 1,
        ns.call_method("dequantize"): 3,
    })
def test_conv_add(self):
    """Register a custom ``(operator.add, Conv2d, MatchAllNode)`` pattern at
    the front of a copy of the TensorRT backend config, then check observer
    and quantize/dequantize counts after prepare and reference convert."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            return self.conv(x) + y

    qint8_weighted_dtype_config = {
        # optional, input activation dtype
        "input_dtype": torch.qint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.qint8
    }
    # presumably matches ``conv(x) + <any node>`` as in M.forward
    conv_add_config = {
        "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [qint8_weighted_dtype_config],
        "root_module": torch.nn.Conv2d,
        "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    model = M().eval()
    # work on a copy so the shared TRT config is not mutated; the custom
    # pattern is inserted first in the list of configs
    backend_config = copy.deepcopy(self.trt_backend_config_dict)
    backend_config["configs"].insert(0, conv_add_config)

    model = prepare_fx(model, {"": self.qconfig}, backend_config_dict=backend_config)
    self.checkGraphModuleNodes(model, expected_node_occurrence={
        ns.call_module(torch.ao.quantization.HistogramObserver): 3,
    })

    model = _convert_fx_do_not_use(
        model, is_reference=True, backend_config_dict=backend_config)
    self.checkGraphModuleNodes(model, expected_node_occurrence={
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_method("dequantize"): 3,
    })
def test_cat(self):
    """Check ``torch.cat`` handling under the TensorRT backend config.

    After prepare, the model is expected to own exactly one child module.
    ``M`` defines no submodules, so that child was inserted by ``prepare_fx``;
    presumably it is a single observer shared by the cat inputs and output.
    After reference conversion the graph is expected to keep one
    ``torch.cat`` call, surrounded by two quantize and two dequantize nodes.
    """
    class M(torch.nn.Module):
        # no __init__ override needed: nn.Module.__init__ runs by default
        def forward(self, x):
            return torch.cat([x, x], 1)

    m = M().eval()
    prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
    # assertEqual reports the actual count on failure, unlike assertTrue(... == 1)
    self.assertEqual(len(dict(prepared.named_children())), 1)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_function(torch.cat): 1,
        ns.call_method("dequantize"): 2,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_conv_add_standalone_module(self):
    """Quantize a model whose ``standalone`` submodule is configured as a
    standalone module, using custom relu(add(conv, any)) and conv backend
    configs. Node counts are checked on both the parent model and
    ``m.standalone``, after prepare and after reference conversion.
    """
    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x, y):
            return self.relu(self.conv(x) + y)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.standalone = Standalone()

        def forward(self, x, y):
            # NOTE: the incoming ``y`` is overwritten, so only ``x`` feeds
            # the graph (x is used by both self.conv and the standalone module)
            y = self.conv(x)
            return self.standalone(x, y)

    weighted_op_quint8_dtype_config = {
        # optional, input activation dtype
        # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        "input_dtype": torch.quint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.quint8
    }
    # Neither config below sets "reference_quantized_module_for_root";
    # consistent with that, the convert checks further down expect plain
    # nn.Conv2d modules to remain in the converted graphs.
    conv_add_config = {
        "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
    }
    conv_config = {
        "pattern": torch.nn.Conv2d,
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
    }
    m = M().eval()
    backend_config_dict = {
        "configs": [
            conv_add_config,
            conv_config,
        ]
    }
    prepare_custom_config_dict = {
        "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
    }
    # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
    # are more flexible
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
        ),
        weight=torch.ao.quantization.default_weight_observer
    )
    m = prepare_fx(
        m,
        {"": qconfig},
        prepare_custom_config_dict=prepare_custom_config_dict,
        backend_config_dict=backend_config_dict)
    node_occurrence = {
        # for input and output of conv, where input is used twice, once in conv and
        # once in standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        # output of the standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 1,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
    m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_module(nn.Conv2d): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_module(nn.Conv2d): 1,
        ns.call_module(torch.nn.ReLU): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
def _test_standalone_module(
        self,
        interface_config,
        prepare_count_check,
        standalone_prepare_count_check,
        convert_count_check,
        standalone_convert_count_check,
        qconfig=None,
        backend_config_dict=None):
    """Quantize a model containing a standalone submodule and compare its
    output with an equivalent flat two-conv model quantized the normal way.

    Args:
        interface_config: quantized input/output interface config for the
            standalone module (third element of its "standalone_module_name"
            entry)
        prepare_count_check: expected node occurrence in the parent model
            after prepare
        standalone_prepare_count_check: same, for ``m.standalone``
        convert_count_check: expected node occurrence in the parent model
            after reference conversion
        standalone_convert_count_check: same, for ``m.standalone``
        qconfig: qconfig for the whole model; ``self.trt_qconfig`` when None
        backend_config_dict: backend config for the parent and the standalone
            module; ``self.trt_backend_config_dict`` when None
    """
    class StandaloneModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.standalone = StandaloneModule()

        def forward(self, x):
            return self.standalone(self.conv(x))

    class RefM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv2(self.conv1(x))

    backend_config_dict = (
        self.trt_backend_config_dict if backend_config_dict is None else backend_config_dict)
    qconfig = self.trt_qconfig if qconfig is None else qconfig

    data = torch.randn(1, 1, 1, 1)

    # Build both float models, then give RefM exactly the weights of M so the
    # two quantized outputs are comparable.
    float_m = M().eval()
    float_ref_m = RefM().eval()
    aligned_pairs = (
        (float_ref_m.conv1, float_m.conv),
        (float_ref_m.conv2, float_m.standalone.conv),
    )
    for ref_conv, src_conv in aligned_pairs:
        ref_conv.weight = torch.nn.Parameter(src_conv.weight.detach())
        ref_conv.bias = torch.nn.Parameter(src_conv.bias.detach())

    prepare_config = {
        "standalone_module_name": [("standalone", None, interface_config, backend_config_dict)]
    }
    m_copy = copy.deepcopy(float_m)
    ref_m_copy = copy.deepcopy(float_ref_m)
    qconfig_dict = {"": qconfig}

    # prepare + calibrate the model that contains the standalone module
    m = prepare_fx(
        m_copy,
        qconfig_dict,
        prepare_custom_config_dict=prepare_config,
        backend_config_dict=backend_config_dict)
    m(data)
    self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

    # convert to the reference representation and re-check
    m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
    self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
    res = m(data)

    # quantize the flat reference model with the same settings
    ref_m = prepare_fx(ref_m_copy, qconfig_dict, backend_config_dict=backend_config_dict)
    ref_m(data)
    ref_m = _convert_fx_do_not_use(ref_m, is_reference=True, backend_config_dict=backend_config_dict)
    ref_res = ref_m(data)
    self.assertEqual(res, ref_res)