def test_standalone_module_float_interface(self):
    float_interface_config = {
        "input_quantized_idxs": [],  # float input
        "output_quantized_idxs": [],  # float output
    }
    interface_config = float_interface_config
    # input and output of first conv, observer for standalone module
    # will be inserted in the standalone module itself
    prepare_count_check = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 2
    }
    # for input and output of conv in the standalone module
    standalone_prepare_count_check = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 2
    }
    convert_count_check = {
        # input and output of reference conv
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_module(nnqr.Conv2d): 1,
        ns.call_method("dequantize"): 2,
    }
    standalone_convert_count_check = {
        # standalone module will take float as input and output,
        # so we'll see quantize and dequantize in the module
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_module(nnqr.Conv2d): 1,
        ns.call_method("dequantize"): 2,
    }
    self._test_standalone_module(
        interface_config,
        prepare_count_check,
        standalone_prepare_count_check,
        convert_count_check,
        standalone_convert_count_check)
def test_linear(self):
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    m = LinearModule().eval()
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)
    prepared = prepare_fx(
        m, {"": qconfig},
        backend_config_dict=get_tensorrt_backend_config_dict())
    # calibration
    prepared(linear_module_input)
    quantized = convert_fx(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
    # lower to trt
    trt_mod = lower_to_trt(
        quantized,
        linear_module_input,
        [((1, *linear_module_input.shape[1:]),
          (5, *linear_module_input.shape[1:]),
          (10, *linear_module_input.shape[1:]))])
    # make sure it runs
    trt_mod(linear_module_input.cuda())
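# Note on the shape-range argument passed to lower_to_trt above: it is a list with
# one entry per input tensor, where each entry is a (min_shape, opt_shape, max_shape)
# triple used to build the dynamic-shape profile. A minimal sketch of how such a
# range can be built from an example input, assuming only the batch dimension varies
# (the helper name `make_batch_shape_ranges` is hypothetical, not part of these tests):
#
#     def make_batch_shape_ranges(example_input, min_batch=1, opt_batch=5, max_batch=10):
#         # keep all non-batch dimensions fixed and only vary the batch dimension
#         rest = tuple(example_input.shape[1:])
#         return [((min_batch, *rest), (opt_batch, *rest), (max_batch, *rest))]
#
# With linear_module_input of shape (8, 5), this would produce
# [((1, 5), (5, 5), (10, 5))], matching the ranges written out inline above.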
def test_unsupported_qconfig(self):
    """ Check that we won't quantize the model if the qconfig is not supported
    """
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    m = LinearModule().eval()
    trt_unsupported_qconfig = default_qconfig
    prepared = prepare_fx(m, {"": trt_unsupported_qconfig},
                          backend_config_dict=self.trt_backend_config_dict)
    # calibration
    prepared(linear_module_input)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 0,
        ns.call_method("dequantize"): 0,
        ns.call_module(torch.nn.Linear): 1,
        ns.call_module(torch.nn.quantized._reference.Linear): 0,
    }
    # check model is not quantized
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_linear_relu_module(self):
    class LinearModule(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10).float()
            if has_relu:
                if f_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            return self.relu(self.linear(x))

    linear_input = torch.rand(8, 5)

    shape_ranges = [((1, 5), (5, 5), (10, 5))]
    for has_relu, f_relu in itertools.product([True, False], [True, False]):
        # when has_relu=False, we have torch.nn.Identity, which would introduce
        # an extra quant-dequant pair
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
            ns.call_method("dequantize"): 2 + int(not has_relu),
        }
        self._test_module(
            LinearModule(has_relu, f_relu),
            [linear_input],
            shape_ranges,
            no_convert=no_convert)
def test_conv(self):
    class Conv2d(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.conv = torch.nn.Conv2d(*args)

        def forward(self, x):
            return self.conv(x)

    conv2d_input = torch.rand(1, 3, 224, 224)
    conv2d_module_args = (3, 3, 3)

    m = Conv2d(*conv2d_module_args).eval()
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        weight=torch.quantization.default_weight_observer
    )
    prepared = prepare_fx(m, {"": qconfig},
                          backend_config_dict=get_tensorrt_backend_config_dict())
    # calibration
    prepared(conv2d_input)
    quantized = convert_fx(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
    # lower to trt
    trt_mod = lower_to_trt(
        quantized,
        conv2d_input,
        [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))])
    # make sure it runs
    trt_mod(conv2d_input.cuda())
def test_addmm(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(5, 5)
            self.bias = torch.randn(5)

        def forward(self, x):
            return torch.addmm(self.bias, x, self.weight)

    m = M().eval()
    prepared = prepare_fx(
        m, {"": self.qconfig},
        backend_config_dict=self.trt_backend_config_dict)
    node_occurrence = {
        # weight
        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
        # activation
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
    quantized = _convert_fx_do_not_use(
        prepared, is_reference=True,
        backend_config_dict=self.trt_backend_config_dict)
    node_occurrence = {
        # input activation, output activation and weight
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_function(torch.addmm): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_clamp(self):
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu6 = torch.nn.ReLU6()
            self.relu6_ = torch.nn.ReLU6(True)
            self.hardtanh = torch.nn.Hardtanh()
            self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

        def forward(self, x):
            x = self.conv(x)
            x = self.relu6(x)
            self.relu6_(x)
            x = F.relu6(x)
            x = torch.clamp(x, -3, 3)
            x = x.clamp(-2.5, 2.5)
            # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
            x = self.hardtanh(x)
            self.hardtanh_(x)
            x = F.hardtanh(x)
            F.hardtanh_(x)
            return x

    data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
    # list of nodes that should occur in order
    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_function(F.hardtanh_),
        ns.call_method('dequantize')
    ]
    for quant_type in self.static_quant_types:
        m = self.checkGraphModeFxOp(
            M(), data, quant_type, expected_node_list=node_list)
def test_add_shadow_loggers_multiple_dtype_casts(self):
    """
    Verifies that for nodes where the first input arg is a list,
    such as `cat`, we insert an individual dtype cast for each
    arg of the list.
    """
    class M(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.cat([x, x, x], dim=0)
            return x

    m = M().eval()
    expected_occurrence = {
        # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
        ns.call_function(torch.dequantize): 3,
        # 1 dequantize method call for module output
        ns.call_method("dequantize"): 1,
    }
    self._test_match_shadow_activations(
        m, (torch.randn(4, 4),),
        prepared_expected_node_occurrence=expected_occurrence,
        results_len=1)
def test_qconfig_none(self):
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.conv2 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x

    m = M().eval()
    m = symbolic_trace(m)
    qconfig_dict = {'': default_qconfig,
                    'conv2': None}
    m = prepare_static_fx(m, qconfig_dict)
    data = torch.randn(1, 1, 1, 1)
    m(data)
    m = convert_static_fx(m)
    m(data)

    # first conv is quantized, second conv is not quantized
    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
        ns.call_module(nn.Conv2d),
    ]
    self.checkGraphModuleNodes(m, expected_node_list=node_list)
def test_ops(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(5, 5)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv(x)
            x = self.linear(x)
            x = x + 3
            x = self.relu(x)
            x = x + 6
            return x

    m = M().eval()
    m = prepare_fx(m, {"": default_qconfig})
    m = _convert_fx_do_not_use(m, is_reference=True)
    expected_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 5,
        ns.call_method("dequantize"): 5,
        ns.call_module(torch.nn.quantized._reference.Linear): 1,
        ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
    }
    self.checkGraphModuleNodes(
        m, expected_node_occurrence=expected_occurrence)
def test_fp32_input_fp32_output(self):
    prepare_custom_config_dict = {}
    prepare_count_check = {
        ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
    }
    convert_count_check = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_method('dequantize'): 3,
    }
    self._test_quantized_inputs_outputs(
        prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_conv_add(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            return self.conv(x) + y

    weighted_op_qint8_dtype_config = {
        # optional, input activation dtype
        "input_dtype": torch.qint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.qint8
    }
    conv_add_config = {
        "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_qint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }
    m = M().eval()
    modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
    modified_backend_config_dict["configs"].insert(0, conv_add_config)
    m = prepare_fx(m, {"": self.qconfig},
                   backend_config_dict=modified_backend_config_dict)
    node_occurrence = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 3,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    m = _convert_fx_do_not_use(
        m, is_reference=True,
        backend_config_dict=modified_backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
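# Note on the "pattern" entry used above: the tuple is written with the last op
# first, so (operator.add, torch.nn.Conv2d, MatchAllNode) matches a graph where a
# Conv2d output plus one arbitrary extra input feed into operator.add, i.e. the
# `self.conv(x) + y` expression in M.forward. A hedged sketch of another pattern in
# the same op-last-first convention (illustrative only, not part of this test):
#
#     linear_relu_pattern = (torch.nn.ReLU, torch.nn.Linear)  # matches relu(linear(x))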
def test_quantized_input_fp32_output(self):
    prepare_custom_config_dict = {
        'input_quantized_idxs': [0]}
    prepare_count_check = {
        ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
    }
    convert_count_check = {
        # output of conv1, conv2
        ns.call_function(torch.quantize_per_tensor): 2,
        # input of ref conv1, input of ref conv2, final output
        ns.call_method('dequantize'): 3,
    }
    self._test_quantized_inputs_outputs(
        prepare_custom_config_dict, prepare_count_check, convert_count_check)
def test_conv_relu_module(self):
    conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

    conv1d_input = torch.rand(1, 3, 10)
    conv2d_input = torch.rand(1, 3, 10, 10)
    conv3d_input = torch.rand(1, 3, 10, 10, 10)
    conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}

    class ConvNdModule(torch.nn.Module):
        def __init__(self, dim, has_relu=False, f_relu=False):
            super().__init__()
            self.conv = conv_module[dim](3, 3, 3).float()
            if has_relu:
                if f_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            return self.relu(self.conv(x))

    # just testing conv2d since conv1d and conv3d are not supported in fx2trt
    for dim, has_relu, f_relu, is_qat in itertools.product(
            [2], [True, False], [True, False], [True, False]):
        # when has_relu=False, we have torch.nn.Identity, which would introduce
        # an extra quant-dequant pair
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
            ns.call_method("dequantize"): 2 + int(not has_relu),
        }
        self._test_module(
            ConvNdModule(dim, has_relu, f_relu),
            [conv_input[dim]],
            [((1, *conv_input[dim].shape[1:]),
              (5, *conv_input[dim].shape[1:]),
              (10, *conv_input[dim].shape[1:]))],
            no_convert=no_convert,
            is_qat=is_qat)
def test_conv(self):
    class Conv2dModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    conv2d_input = torch.rand(1, 3, 224, 224)
    no_convert = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_method("dequantize"): 2
    }
    self._test_module(
        Conv2dModule(),
        [conv2d_input],
        [((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
        no_convert=no_convert)
def test_linear(self):
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_input = torch.rand(8, 5)
    shape_ranges = [((1, 5), (5, 5), (10, 5))]
    no_convert = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_method("dequantize"): 2,
    }
    self._test_module(
        LinearModule(),
        [linear_input],
        shape_ranges,
        no_convert=no_convert)
def test_cat(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cat([x, x], 1)

    m = M().eval()
    prepared = prepare_fx(m, {"": self.qconfig},
                          backend_config_dict=self.trt_backend_config_dict)
    self.assertTrue(len(dict(prepared.named_children())) == 1)
    quantized = _convert_fx_do_not_use(prepared, is_reference=True)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_function(torch.cat): 1,
        ns.call_method("dequantize"): 2,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_general_value_ops(self):
    """ A test that checks correct patterns are produced for
    all supported general value ops like aten::avg_pool2d
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.avg_pool1d = torch.nn.AvgPool1d(3)
            self.avg_pool2d = torch.nn.AvgPool2d(3)
            self.avg_pool3d = torch.nn.AvgPool3d(3)
            self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
            self.leaky_relu = torch.nn.LeakyReLU()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            x = self.conv(x)
            x = self.avg_pool1d(x)
            x = self.avg_pool2d(x)
            x = self.avg_pool3d(x)
            x = self.adaptive_avg_pool1d(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.adaptive_avg_pool3d(x)
            x = F.avg_pool1d(x, 3)
            x = F.avg_pool2d(x, 3)
            x = F.avg_pool3d(x, 3)
            x = F.adaptive_avg_pool1d(x, (1))
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = F.adaptive_avg_pool3d(x, (1, 1, 1))
            x = torch.mean(x)
            x = torch.mean(x, [2, 3], False)
            x = x.mean()
            x = x.mean([2, 3], True)
            x = F.interpolate(x, 4, mode='nearest')
            x = F.interpolate(x, 4, mode='linear')
            x = self.leaky_relu(x)
            x = F.leaky_relu(x)
            x = F.leaky_relu(x, inplace=True)
            x = x.leaky_relu()
            x.leaky_relu_()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x = F.hardsigmoid(x, inplace=True)
            x = x.hardsigmoid()
            x.hardsigmoid_()
            x = self.sigmoid(x)
            x = torch.sigmoid(x)  # F.sigmoid is deprecated
            x = x.sigmoid()
            x.sigmoid_()
            x = self.tanh(x)  # F.tanh is deprecated
            x = torch.tanh(x)
            x = x.tanh()
            x.tanh_()
            x = self.conv(x)
            return x

    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    qconfig_dict = {'': default_qconfig}
    prepared = prepare_fx(original, qconfig_dict)
    # not runnable
    quantized = convert_fx(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(
        quantized,
        expected_node_occurrence=count_check,
        expected_node_list=order_check)
def test_general_shape_ops(self):
    """ A test that checks dequantize will be swapped for
    all supported general shape ops like aten::flatten
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
            self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
            self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
            self.dropout = torch.nn.Dropout()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            # add_scalar
            x = x + 3
            # mul_scalar
            x = x * 3
            # add_scalar_out
            x += 3
            # mul_scalar_out
            x *= 3
            # add_scalar_relu
            x = x + 3
            x = F.relu(x)
            # add_scalar_relu_out
            x += 3
            x = F.relu(x)
            # mul_scalar_relu
            x = x * 3
            x = F.relu(x)
            # mul_scalar_relu_out
            x *= 3
            x = F.relu(x)
            x = self.maxpool1d(x)
            x = self.maxpool2d(x)
            x = self.maxpool3d(x)
            x = torch.flatten(x)
            x = torch.max(x)
            x = torch.min(x)
            x = x.reshape([-1])
            x = x.resize_(1, 1, x.numel())
            x = x.view(-1)
            # prim::ListConstruct
            xs = [x, x]
            # prim::ListUnpack
            x, y = xs
            # prim::TupleConstruct
            xs = (x, x)
            # prim::TupleUnpack
            x, y = xs
            x = x.transpose(1, 2)
            x = x.contiguous()
            x, y = torch.chunk(x, 2)
            x = F.dropout(x)
            x = self.dropout(x)
            x, _ = torch.sort(x)
            x = x.permute(0, 2, 3, 1)
            x = x.repeat_interleave(3, 1)
            x = torch.repeat_interleave(x, 3, 1)
            x = self.relu(x)
            x = F.relu(x)
            x = F.relu(x, inplace=True)
            x = x.relu()
            x.relu_()
            x = x.squeeze(0)
            x.squeeze_(0)
            x = torch.squeeze(x, 0)
            x = x.unsqueeze(0)
            x.unsqueeze_(0)
            x = torch.unsqueeze(x, 0)
            x = x.detach()
            x.detach_()
            x = x.repeat(4, 2)
            y = []
            y.append(x)
            z = torch.stack(y, 0)
            z = [z, z]
            x, _ = z
            x = self.conv2(x)
            return x

    data = torch.rand(1, 3, 10, 10)
    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    qconfig_dict = {'': default_qconfig}
    prepared = prepare_fx(original, qconfig_dict)
    # not runnable
    quantized = convert_fx(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers and also successfully fused two quantized::conv2d
    # patterns
    # one quantize_per_tensor for input
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(
        quantized,
        expected_node_occurrence=count_check,
        expected_node_list=order_check)
def test_input_weight_equalization_graphs(self):
    """ Tests that the modified model for equalization has the same graph
    structure as the model without equalization (before and after
    quantization).
    """
    linear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]
    linearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]
    linear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]
    functionalLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]
    functionalLinearAdd_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]
    functionalLinear2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]
    linearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_method('dequantize')
    ]
    linearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.LinearReLU),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]
    functionalLinearRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_method('dequantize')
    ]
    functionalLinearReluLinear_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.linear_relu),
        ns.call_function(torch.ops.quantized.linear),
        ns.call_method('dequantize')
    ]
    conv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]
    conv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]
    functionalConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]
    functionalConv2_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]
    convRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_method('dequantize')
    ]
    convReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nniq.ConvReLU2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize')
    ]
    functionalConvRelu_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_method('dequantize')
    ]
    functionalConvReluConv_node_list = [
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_function(torch.ops.quantized.conv2d_relu),
        ns.call_function(torch.ops.quantized.conv2d),
        ns.call_method('dequantize')
    ]

    tests = [(SingleLayerLinearModel, linear_node_list),
             (LinearAddModel, linearAdd_node_list),
             (TwoLayerLinearModel, linear2_node_list),
             (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
             (FunctionalLinearAddModel, functionalLinearAdd_node_list),
             (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
             (LinearReluModel, linearRelu_node_list),
             (LinearReluLinearModel, linearReluLinear_node_list),
             (FunctionalLinearReluModel, functionalLinearRelu_node_list),
             (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
             (ConvModel, conv_node_list),
             (TwoLayerConvModel, conv2_node_list),
             (SingleLayerFunctionalConvModel, functionalConv_node_list),
             (TwoLayerFunctionalConvModel, functionalConv2_node_list),
             (ConvReluModel, convRelu_node_list),
             (ConvReluConvModel, convReluConv_node_list),
             (FunctionalConvReluModel, functionalConvRelu_node_list),
             (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]

    for (M, node_list) in tests:
        m = M().eval()
        prepared = prepare_fx(
            m, specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        equalized_quantized_model = convert_fx(prepared)

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(
            equalized_quantized_model, expected_node_list=node_list)
def test_conv_add_standalone_module(self):
    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x, y):
            return self.relu(self.conv(x) + y)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.standalone = Standalone()

        def forward(self, x, y):
            y = self.conv(x)
            return self.standalone(x, y)

    from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
    weighted_op_quint8_dtype_config = {
        # optional, input activation dtype
        # TODO: change back to torch.qint8 after input_quantized_idxs and
        # output_quantized_idxs are more flexible
        "input_dtype": torch.quint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.quint8
    }

    conv_add_config = {
        "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    conv_config = {
        "pattern": torch.nn.Conv2d,
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    m = M().eval()
    backend_config_dict = {
        "configs": [
            conv_add_config,
            conv_config,
        ]
    }
    prepare_custom_config_dict = {
        "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
    }
    # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
    # are more flexible
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
        ),
        weight=torch.ao.quantization.default_weight_observer
    )
    m = prepare_fx(
        m, {"": qconfig},
        prepare_custom_config_dict=prepare_custom_config_dict,
        backend_config_dict=backend_config_dict)
    node_occurrence = {
        # for input and output of conv, where input is used twice, once in conv and
        # once in standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        # output of the standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 1,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
    m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_module(nn.Conv2d): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_module(nn.Conv2d): 1,
        ns.call_module(torch.nn.ReLU): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
def test_selective_equalization(self):
    """ Tests that we are able to run numeric suite on the equalized model
    and construct a valid equalization_qconfig_dict equalizing only the
    top layers with the highest quantization errors.
    """
    torch.manual_seed(1)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
            self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

        def forward(self, x):
            x = self.bot(x)
            x = torch.add(x, 5)
            x = self.top(x)
            return x

    float_model = M().eval()
    # Hard coded so that the top layer has a higher quantization error
    x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                      [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                      [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                      [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                      [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

    # Quantize the float model
    prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
    prepared_model(x)
    quantized_model = convert_fx(copy.deepcopy(prepared_model))

    # Get the SQNR between the float and quantized model
    layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

    # Construct the equalization_qconfig_dict equalizing layers with the highest
    # quantization errors
    selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

    # Create the selectively equalized model
    prepared_model = prepare_fx(
        copy.deepcopy(float_model),
        specific_qconfig_dict,
        equalization_qconfig_dict=selective_equalization_qconfig_dict,
    )
    prepared_model(x)
    equalized_model = convert_fx(prepared_model)

    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    # Check the order of nodes in the graph
    self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
def test_standalone_module_quantized_interface(self):
    quantized_interface_config = {
        "input_quantized_idxs": [0],  # quantized input
        "output_quantized_idxs": [0],  # quantized output
    }
    interface_config = quantized_interface_config
    # TODO: input_quantized_idxs only supports quint8; we can remove this
    # custom_backend_config_dict once `input_quantized_idxs` supports more
    # complicated configurations. As a first step we can change it to use a
    # dictionary from index to dtype.
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
        ),
        weight=torch.ao.quantization.default_weight_observer
    )
    weighted_op_quint8_dtype_config = {
        # optional, input activation dtype
        "input_dtype": torch.quint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.quint8
    }
    conv_module_config = {
        "pattern": torch.nn.Conv2d,
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }
    custom_backend_config_dict = {
        "configs": [conv_module_config]
    }
    # observer for input and output of first conv
    prepare_count_check = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 2
    }
    # for output of conv in the standalone module
    standalone_prepare_count_check = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 1
    }
    convert_count_check = {
        # quantizing input/output for reference conv
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_module(nnqr.Conv2d): 1,
        # dequantize the input of reference conv and
        # dequantize the output of the standalone module
        ns.call_method("dequantize"): 2,
    }
    standalone_convert_count_check = {
        # quantization of input happens in the parent module,
        # quantization of output happens in the standalone module
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_module(nnqr.Conv2d): 1,
        # dequantization of input happens in the standalone module,
        # dequantization of output happens in the parent module
        ns.call_method("dequantize"): 1,
    }
    self._test_standalone_module(
        interface_config,
        prepare_count_check,
        standalone_prepare_count_check,
        convert_count_check,
        standalone_convert_count_check,
        qconfig=qconfig,
        backend_config_dict=custom_backend_config_dict)