def test_qbatch_norm_relu(self):
    """Check that BatchNorm{2,3}d followed by ReLU is quantized to nniq.BNReLU{2,3}d.

    The ReLU is exercised as an nn.ReLU module (in-place and not) and as
    F.relu (in-place and not), for every static quantization type.
    """
    bn_by_dim = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

    class BNRelu(torch.nn.Module):
        def __init__(self, dim, inplace):
            super().__init__()
            self.bn = bn_by_dim[dim](3).to(torch.float)
            self.relu = torch.nn.ReLU(inplace=inplace)

        def forward(self, x):
            return self.relu(self.bn(x))

    class BNFuncRelu(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = bn_by_dim[dim](3).to(torch.float)

        def forward(self, x):
            return F.relu(self.bn(x), False)

    class BNFuncInplaceRelu(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = bn_by_dim[dim](3).to(torch.float)

        def forward(self, x):
            return F.relu(self.bn(x), True)

    expected_fused = {
        2: ns.call_module(nniq.BNReLU2d),
        3: ns.call_module(nniq.BNReLU3d),
    }
    for quant_type, dim in itertools.product(self.static_quant_types, [2, 3]):
        variants = (
            BNRelu(dim, True),
            BNRelu(dim, False),
            BNFuncRelu(dim),
            BNFuncInplaceRelu(dim),
        )
        for model in variants:
            self.checkGraphModeFxOp(
                model, self.img_data_dict[dim], quant_type, expected_fused[dim])
def _test_norm_impl(
        self, float_module, float_op, op_args, data, quantized_module,
        quantized_op, skip_op_arg_for_functional=False):
    """Check that a normalization op is lowered to its quantized counterpart.

    The op is exercised both as a module and as a plain callable, under
    every static quantization type.

    Args:
        float_module: float module class; constructed as ``float_module(*op_args)``.
        float_op: float torch op or functional op called in ``forward``.
        op_args: list of positional arguments for the module/op.
        data: example inputs forwarded to ``checkGraphModeFxOp``.
        quantized_module: module type expected in the graph for the module case.
        quantized_op: function expected in the graph for the functional case.
        skip_op_arg_for_functional: when True, the functional op is called with
            the input only, without ``op_args``.
    """
    class M(torch.nn.Module):
        def __init__(self, is_module):
            super().__init__()
            self.is_module = is_module
            self.op = float_module(*op_args) if is_module else float_op

        def forward(self, inp):
            if self.is_module:
                return self.op(inp)
            extra_args = [] if skip_op_arg_for_functional else list(op_args)
            return self.op(inp, *extra_args)

    for is_module, quant_type in itertools.product([True, False], self.static_quant_types):
        if is_module:
            expected_node = ns.call_module(quantized_module)
        else:
            expected_node = ns.call_function(quantized_op)
        self.checkGraphModeFxOp(M(is_module), data, quant_type, expected_node)
def test_match_activations_mod(self):
    """Match activations for a QuantStub -> Conv2d -> Conv2d model.

    Two OutputLogger modules are expected after prepare, and ``results_len``
    is 2.
    """
    model = nn.Sequential(
        torch.quantization.QuantStub(),
        nn.Conv2d(1, 1, 1),
        nn.Conv2d(1, 1, 1),
    ).eval()
    example_inputs = (torch.randn(2, 1, 2, 2),)
    self._test_match_activations(
        model,
        example_inputs,
        prepared_expected_node_occurrence={ns.call_module(OutputLogger): 2},
        results_len=2,
    )
def test_quantized_conv(self):
    """Check that Conv{1,2,3}d is swapped to nnq.Conv{1,2,3}d for every
    static quantization type."""
    conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

    class Conv(torch.nn.Module):
        def __init__(self, dim):
            super(Conv, self).__init__()
            self.conv = conv_module[dim](3, 3, 3).float()

        def forward(self, x):
            return self.conv(x)

    options = itertools.product([1, 2, 3], self.static_quant_types)
    quantized_nodes = {
        # dim
        1: ns.call_module(nnq.Conv1d),
        2: ns.call_module(nnq.Conv2d),
        3: ns.call_module(nnq.Conv3d),
    }
    for dim, quant_type in options:
        # the returned model is not used, so it is not bound to a name
        self.checkGraphModeFxOp(
            Conv(dim), self.img_data_dict[dim], quant_type,
            quantized_nodes[dim])
def test_match_activations_fun(self):
    """Match activations of a functional linear -> linear -> relu model
    between its fp32-prepared and int8-converted versions.

    Each instrumented model is expected to contain two OutputLogger modules,
    and the matched activation dict is expected to have two entries.
    """
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Parameter(torch.Tensor(4, 4))
            self.b1 = nn.Parameter(torch.zeros(4))
            self.w2 = nn.Parameter(torch.Tensor(4, 4))
            self.b2 = nn.Parameter(torch.zeros(4))
            # torch.Tensor(4, 4) is uninitialized storage, so the weights are
            # initialized explicitly here
            torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

        def forward(self, x):
            x = F.linear(x, self.w1, self.b1)
            x = F.linear(x, self.w2, self.b2)
            x = F.relu(x)
            return x

    m = M().eval()
    mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
    mp(torch.randn(4, 4))
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq, OutputLogger)

    expected_occurrence = {
        ns.call_module(OutputLogger): 2,
    }
    self.checkGraphModuleNodes(
        mp_ns, expected_node_occurrence=expected_occurrence)
    self.checkGraphModuleNodes(
        mq_ns, expected_node_occurrence=expected_occurrence)

    # TODO(before land): test both scripted and non-scripted
    mp_ns = torch.jit.script(mp_ns)
    mq_ns = torch.jit.script(mq_ns)

    # feed the same input to both instrumented models
    input_fp32 = torch.randn(4, 4)
    mp_ns(input_fp32)
    mq_ns(input_fp32)

    # check activation result correctness
    act_compare_dict = get_matching_activations('fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
    # assertEqual reports the actual length on failure, assertTrue(==) does not
    self.assertEqual(len(act_compare_dict), 2)
    self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
def test_addmm(self):
    """Check observer insertion and reference lowering of torch.addmm under
    the TensorRT backend config."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(5, 5)
            self.bias = torch.randn(5)

        def forward(self, x):
            return torch.addmm(self.bias, x, self.weight)

    backend_config_dict = self.trt_backend_config_dict
    prepared = prepare_fx(
        M().eval(), {"": self.qconfig}, backend_config_dict=backend_config_dict)
    expected_observers = {
        # weight
        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
        # activation
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(prepared, expected_node_occurrence=expected_observers)

    quantized = _convert_fx_do_not_use(
        prepared, is_reference=True, backend_config_dict=backend_config_dict)
    expected_reference_nodes = {
        # input activation, output activation and weight
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_function(torch.addmm): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=expected_reference_nodes)
def test_conv_relu_module(self):
    """Lower a Conv2d, optionally followed by nn.ReLU or F.relu, through
    ``_test_module``.

    The quantize/dequantize counts are passed to ``_test_module`` as
    ``no_convert``.
    """
    conv_classes = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
    conv_input = {
        1: torch.rand(1, 3, 10),
        2: torch.rand(1, 3, 10, 10),
        3: torch.rand(1, 3, 10, 10, 10),
    }

    class ConvNdModule(torch.nn.Module):
        def __init__(self, dim, has_relu=False, f_relu=False):
            super().__init__()
            self.conv = conv_classes[dim](3, 3, 3).float()
            if not has_relu:
                self.relu = torch.nn.Identity()
            elif f_relu:
                self.relu = F.relu
            else:
                self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    # just testing conv2d since conv1d and conv3d are not supported in fx2trt
    dim = 2
    sample = conv_input[dim]
    feature_shape = sample.shape[1:]
    for has_relu, f_relu in itertools.product([True, False], repeat=2):
        # when has_relu=False, we have torch.nn.Identity, which would introduce
        # an extra quant-dequant pair
        qdq_count = 2 + int(not has_relu)
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): qdq_count,
            ns.call_method("dequantize"): qdq_count,
        }
        shape_ranges = [((1, *feature_shape), (5, *feature_shape), (10, *feature_shape))]
        self._test_module(
            ConvNdModule(dim, has_relu, f_relu),
            [sample],
            shape_ranges,
            no_convert=no_convert)
def test_linear_fp16_shadow_activations(self):
    """Shadow-activation matching for LinearReluFunctional under the float16
    static qconfig, with and without input logging."""
    for should_log_inputs in (True, False):
        # a fresh qconfig_dict per iteration, so any mutation done by the
        # helper does not leak into the next case
        qconfig_dict = {'': torch.quantization.float16_static_qconfig}
        m = LinearReluFunctional().eval()
        num_loggers = 4 if should_log_inputs else 2
        expected_occurrence = {
            ns.call_module(OutputLogger): num_loggers,
        }
        # the helper's return value is not needed here
        self._test_match_shadow_activations(
            m, (torch.randn(4, 4), ),
            prepared_expected_node_occurrence=expected_occurrence,
            results_len=1, qconfig_dict=qconfig_dict,
            should_log_inputs=should_log_inputs)
def test_cat(self):
    """Check prepare and reference lowering of ``torch.cat([x, x], 1)`` under
    the TensorRT backend config."""
    class M(torch.nn.Module):
        def forward(self, x):
            return torch.cat([x, x], 1)

    m = M().eval()
    prepared = prepare_fx(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
    # prepare is expected to add exactly one child module (presumably a single
    # shared activation observer — TODO confirm); assertEqual shows the actual
    # count on failure
    self.assertEqual(len(dict(prepared.named_children())), 1)
    quantized = _convert_fx_do_not_use(
        prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 2,
        ns.call_function(torch.cat): 1,
        ns.call_method("dequantize"): 2,
    }
    self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)
def test_clamp(self):
    """Check graph-mode quantization of clamp-like ops after a conv.

    The ops covered are relu6, clamp and hardtanh in module, functional,
    method and in-place forms. The expected node order is
    quantize -> nnq.Conv2d -> F.hardtanh_ -> dequantize.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu6 = torch.nn.ReLU6()
            self.relu6_ = torch.nn.ReLU6(True)
            self.hardtanh = torch.nn.Hardtanh()
            self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

        def forward(self, x):
            x = self.conv(x)
            x = self.relu6(x)
            self.relu6_(x)
            x = F.relu6(x)
            x = torch.clamp(x, -3, 3)
            x = x.clamp(-2.5, 2.5)
            # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
            x = self.hardtanh(x)
            self.hardtanh_(x)
            x = F.hardtanh(x)
            F.hardtanh_(x)
            return x

    data = (torch.rand((1, 2, 5, 5), dtype=torch.float), )
    # list of node that should occur in order
    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_function(F.hardtanh_),
        ns.call_method('dequantize')
    ]
    for quant_type in self.static_quant_types:
        # the returned model is not used
        self.checkGraphModeFxOp(M(), data, quant_type, expected_node_list=node_list)
def test_embedding(self):
    """Quantize a single nn.Embedding under several qconfigs.

    For each qconfig, check which embedding module the converted graph
    contains, then compare scripted and mobile outputs.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    model = M().eval()
    indices = torch.randint(low=0, high=10, size=(20,))

    # (qconfig, module expected in the converted graph)
    configs = [
        (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
        (None, ns.call_module(nn.Embedding)),
        (default_qconfig, ns.call_module(nn.Embedding)),
    ]

    for qconfig, node in configs:
        qconfig_dict = {"": qconfig}
        m = prepare_fx(model, qconfig_dict)
        m = convert_fx(m)
        # verify the module chosen for this qconfig; without this the expected
        # node in `configs` would be unused
        self.checkGraphModuleNodes(m, expected_node=node)
        self._compare_script_and_mobile(m, input=indices)
def test_linear_relu_module(self):
    """Lower nn.Linear, optionally followed by nn.ReLU or F.relu, through
    ``_test_module`` in both PTQ and QAT modes.

    The quantize/dequantize counts are passed to ``_test_module`` as
    ``no_convert``.
    """
    class LinearModule(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10).float()
            if not has_relu:
                self.relu = torch.nn.Identity()
            elif f_relu:
                self.relu = F.relu
            else:
                self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    linear_input = torch.rand(8, 5)
    shape_ranges = [((1, 5), (5, 5), (10, 5))]
    for has_relu, f_relu, is_qat in itertools.product([True, False], repeat=3):
        # when has_relu=False, we have torch.nn.Identity, which would introduce
        # an extra quant-dequant pair
        qdq_count = 2 + int(not has_relu)
        no_convert = {
            ns.call_function(torch.quantize_per_tensor): qdq_count,
            ns.call_method("dequantize"): qdq_count,
        }
        self._test_module(
            LinearModule(has_relu, f_relu),
            [linear_input],
            shape_ranges,
            no_convert=no_convert,
            is_qat=is_qat)
def test_linear(self):
    """Quantize a single nn.Linear for TensorRT and lower it.

    Uses a symmetric qint8 activation qconfig and the TensorRT backend
    config, checks the reference graph, then lowers it with
    ``lower_to_trt`` and runs it on CUDA.
    """
    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    sample = torch.rand(8, 5)
    model = LinearModule().eval()
    activation_observer = torch.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
    )
    qconfig = torch.quantization.QConfig(
        activation=activation_observer,
        weight=torch.quantization.default_weight_observer,
    )
    prepared = prepare_fx(
        model, {"": qconfig}, backend_config_dict=get_tensorrt_backend_config_dict())
    prepared(sample)  # calibration
    quantized = convert_fx(prepared, is_reference=True)
    self.checkGraphModuleNodes(quantized, expected_node_occurrence={
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method("dequantize"): 1,
    })

    # lower to trt with (min, opt, max) batch sizes of 1 / 5 / 10
    feature_shape = sample.shape[1:]
    shape_ranges = [((1, *feature_shape), (5, *feature_shape), (10, *feature_shape))]
    trt_mod = lower_to_trt(quantized, sample, shape_ranges)
    # smoke test only: the output is not compared against anything
    trt_mod(sample.cuda())
def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx):
    """Match activations for LinearReluLinearFunctional using ``prepare_fn``.

    When ``prepare_fn`` is ``prepare_qat_fx``, the default fbgemm QAT qconfig
    is used; otherwise ``qconfig_dict`` is left as None for the helper to
    decide.
    """
    model = LinearReluLinearFunctional().eval()
    if prepare_fn == prepare_qat_fx:
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
    else:
        qconfig_dict = None
    self._test_match_activations(
        model,
        (torch.randn(4, 4), ),
        prepared_expected_node_occurrence={ns.call_module(OutputLogger): 2},
        results_len=2,
        prepare_fn=prepare_fn,
        qconfig_dict=qconfig_dict,
    )
def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx):
    """Match activations for a QuantStub -> Conv2d -> Conv2d model using
    ``prepare_fn``.

    When ``prepare_fn`` is ``prepare_qat_fx``, the default fbgemm QAT qconfig
    is used; otherwise ``qconfig_dict`` is None.
    """
    model = nn.Sequential(
        torch.quantization.QuantStub(),
        nn.Conv2d(1, 1, 1),
        nn.Conv2d(1, 1, 1),
    ).eval()
    if prepare_fn == prepare_qat_fx:
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
    else:
        qconfig_dict = None
    self._test_match_activations(
        model,
        (torch.randn(2, 1, 2, 2), ),
        prepared_expected_node_occurrence={ns.call_module(OutputLogger): 2},
        results_len=2,
        qconfig_dict=qconfig_dict,
        prepare_fn=prepare_fn,
    )
def test_standalone_module_float_interface(self):
    """Standalone module whose input and output are both float tensors."""
    interface_config = {
        "input_quantized_idxs": [],  # float input
        "output_quantized_idxs": [],  # float output
    }

    def histogram_observers(count):
        return {ns.call_module(torch.ao.quantization.HistogramObserver): count}

    def reference_conv_nodes():
        return {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
        }

    # parent: input and output of the first conv; the observers for the
    # standalone module are inserted in the standalone module itself
    prepare_count_check = histogram_observers(2)
    # standalone: input and output of its conv
    standalone_prepare_count_check = histogram_observers(2)
    # parent: quantize/dequantize around the input and output of the reference conv
    convert_count_check = reference_conv_nodes()
    # standalone: takes float input and produces float output, so the
    # quantize and dequantize live inside the module
    standalone_convert_count_check = reference_conv_nodes()

    self._test_standalone_module(
        interface_config,
        prepare_count_check,
        standalone_prepare_count_check,
        convert_count_check,
        standalone_convert_count_check)
def test_match_activations_mod(self):
    """Match activations of a QuantStub -> Conv2d -> Conv2d model between
    its fp32-prepared and int8-converted versions.

    Each instrumented model is expected to contain two OutputLogger modules,
    and the matched activation dict is expected to have two entries.
    """
    m = nn.Sequential(
        torch.quantization.QuantStub(),
        nn.Conv2d(1, 1, 1),
        nn.Conv2d(1, 1, 1),
    ).eval()
    mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
    mp(torch.randn(2, 1, 2, 2))
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq, OutputLogger)

    expected_occurrence = {
        ns.call_module(OutputLogger): 2,
    }
    self.checkGraphModuleNodes(
        mp_ns, expected_node_occurrence=expected_occurrence)
    self.checkGraphModuleNodes(
        mq_ns, expected_node_occurrence=expected_occurrence)

    # TODO(before land): test both scripted and non-scripted
    mp_ns = torch.jit.script(mp_ns)
    mq_ns = torch.jit.script(mq_ns)

    # feed the same input to both instrumented models
    input_fp32 = torch.randn(2, 1, 2, 2)
    mp_ns(input_fp32)
    mq_ns(input_fp32)

    # check activation result correctness
    act_compare_dict = get_matching_activations('fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
    # assertEqual reports the actual length on failure, assertTrue(==) does not
    self.assertEqual(len(act_compare_dict), 2)
    self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op):
    """Check that a binary op is lowered to ``quantized_op`` under static
    quantization.

    The op (``ibinary_op`` when in-place, else ``binary_op``) takes a conv
    output plus either a second conv output or the scalar 3.
    """
    class Op(torch.nn.Module):
        def __init__(self, is_inplace, is_scalar):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
            self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
            self.is_scalar = is_scalar
            self.op = ibinary_op if is_inplace else binary_op

        def forward(self, x, y):
            x = self.conv1(x)
            if self.is_scalar:
                y = 3
            else:
                y = self.conv2(y)
            return self.op(x, y)

    # TODO: decide whether the op should be quantized when its operands come
    # straight from the model inputs, i.e. a variant of Op without
    # conv1/conv2 that computes op(x, 3 if is_scalar else y), with the same
    # is_inplace / is_scalar options.
    data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
            torch.randn(1, 1, 1, 1, dtype=torch.float))
    expected_node = ns.call_function(quantized_op)
    for is_inplace, is_scalar in itertools.product([True, False], repeat=2):
        self.checkGraphModeFxOp(
            Op(is_inplace, is_scalar), data, QuantType.STATIC, expected_node)
def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op, quantized_op):
    """Check that binary op + ReLU after convs is lowered to ``quantized_op``.

    Every combination of in-place op, functional vs module ReLU, in-place
    ReLU and scalar second operand is covered, for every static quantization
    type.
    """
    class OpRelu(torch.nn.Module):
        def __init__(self, is_inplace, is_functional_relu, is_inplace_relu, is_scalar):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.op = ibinary_op if is_inplace else binary_op
            self.is_functional_relu = is_functional_relu
            self.is_inplace_relu = is_inplace_relu
            self.is_scalar = is_scalar
            self.relu = F.relu if is_functional_relu else torch.nn.ReLU(is_inplace_relu)

        def forward(self, x, y):
            x = self.conv1(x)
            y = 3 if self.is_scalar else self.conv2(y)
            x = self.op(x, y)
            if self.is_functional_relu:
                return self.relu(x, self.is_inplace_relu)
            return self.relu(x)

    data = (torch.rand((1, 2, 5, 5), dtype=torch.float),
            torch.rand((1, 2, 5, 5), dtype=torch.float))
    expected_node = ns.call_function(quantized_op)
    flags = [True, False]
    combos = itertools.product(flags, flags, flags, flags, self.static_quant_types)
    for is_inplace_op, is_functional_relu, is_inplace_relu, is_scalar, quant_type in combos:
        model = OpRelu(is_inplace_op, is_functional_relu, is_inplace_relu, is_scalar)
        self.checkGraphModeFxOp(model, data, quant_type, expected_node)
def test_selective_equalization(self):
    """ Tests that we are able to run numeric suite on the equalized model
    and construct a valid equalization_qconfig_dict equalizing only the top
    layer with the highest quantization error.

    The ``1`` passed to ``get_equalization_qconfig_dict`` selects a single
    layer (presumably the count of layers to equalize). The expected node
    list confirms that only the top linear gets a ``torch.mul``.
    """
    torch.manual_seed(1)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
            self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

        def forward(self, x):
            x = self.bot(x)
            x = torch.add(x, 5)
            x = self.top(x)
            return x

    float_model = M().eval()
    # Hard coded so that the top layer has a higher quantization error
    x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                      [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                      [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                      [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                      [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

    # Quantize the float model
    prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
    prepared_model(x)
    quantized_model = convert_fx(copy.deepcopy(prepared_model))

    # Get the SQNR between the float and quantized model
    layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model), quantized_model, x)

    # Construct the equalization_qconfig_dict equalizing the single layer
    # with the highest quantization error
    selective_equalization_qconfig_dict = get_equalization_qconfig_dict(layer_to_sqnr_dict, 1)

    # Create the selectively equalized model
    prepared_model = prepare_fx(
        copy.deepcopy(float_model),
        specific_qconfig_dict,
        equalization_qconfig_dict=selective_equalization_qconfig_dict,
    )
    prepared_model(x)
    equalized_model = convert_fx(prepared_model)

    # only the second (top) linear is preceded by torch.mul
    node_list = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize'),
        ns.call_function(torch.add),
        ns.call_function(torch.mul),
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Linear),
        ns.call_method('dequantize')
    ]

    # Check the order of nodes in the graph
    self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list)
def test_input_weight_equalization_graphs(self):
    """Check node order for equalized and converted test models.

    Each model is prepared with the default input-weight equalization config
    and converted. In the expected lists, every equalized section is a
    ``torch.mul`` node, ``quantize_per_tensor``, the quantized op(s), then
    ``dequantize``.
    """
    def equalized(*quantized_ops):
        # mul -> quantize -> quantized op(s) -> dequantize
        return [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            *quantized_ops,
            ns.call_method('dequantize'),
        ]

    def with_add(*quantized_ops):
        # two equalized sections separated by a torch.add
        return equalized(*quantized_ops) + [ns.call_function(torch.add)] + equalized(*quantized_ops)

    linear = ns.call_module(nnq.Linear)
    linear_relu = ns.call_module(nniq.LinearReLU)
    f_linear = ns.call_function(torch.ops.quantized.linear)
    f_linear_relu = ns.call_function(torch.ops.quantized.linear_relu)
    conv = ns.call_module(nnq.Conv2d)
    conv_relu = ns.call_module(nniq.ConvReLU2d)
    f_conv = ns.call_function(torch.ops.quantized.conv2d)
    f_conv_relu = ns.call_function(torch.ops.quantized.conv2d_relu)

    tests = [
        (SingleLayerLinearModel, equalized(linear)),
        (LinearAddModel, with_add(linear)),
        (TwoLayerLinearModel, equalized(linear, linear)),
        (SingleLayerFunctionalLinearModel, equalized(f_linear)),
        (FunctionalLinearAddModel, with_add(f_linear)),
        (TwoLayerFunctionalLinearModel, equalized(f_linear, f_linear)),
        (LinearReluModel, equalized(linear_relu)),
        (LinearReluLinearModel, equalized(linear_relu, linear)),
        (FunctionalLinearReluModel, equalized(f_linear_relu)),
        (FunctionalLinearReluLinearModel, equalized(f_linear_relu, f_linear)),
        (ConvModel, equalized(conv)),
        (TwoLayerConvModel, equalized(conv, conv)),
        (SingleLayerFunctionalConvModel, equalized(f_conv)),
        (TwoLayerFunctionalConvModel, equalized(f_conv, f_conv)),
        (ConvReluModel, equalized(conv_relu)),
        (ConvReluConvModel, equalized(conv_relu, conv)),
        (FunctionalConvReluModel, equalized(f_conv_relu)),
        (FunctionalConvReluConvModel, equalized(f_conv_relu, f_conv)),
    ]

    for model_cls, expected_nodes in tests:
        prepared = prepare_fx(
            model_cls().eval(), specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        equalized_quantized_model = convert_fx(prepared)
        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_quantized_model, expected_node_list=expected_nodes)
def test_input_weight_equalization_prepare(self):
    """Check observer counts in graphs produced by prepare_fx with the
    default input-weight equalization config."""
    input_eq_obs = ns.call_module(_InputEqualizationObserver)
    weight_eq_obs = ns.call_module(_WeightEqualizationObserver)
    minmax_obs = ns.call_module(MinMaxObserver)

    # nn-module models: _WeightEqualizationObserver is not counted
    single_nn_layer = {input_eq_obs: 1, minmax_obs: 2}
    two_nn_layers = {input_eq_obs: 2, minmax_obs: 3}
    # functional models
    single_f_layer = {input_eq_obs: 1, weight_eq_obs: 1, minmax_obs: 3}
    two_f_layers = {input_eq_obs: 2, weight_eq_obs: 2, minmax_obs: 5}
    # functional linear + add model
    two_f_layers_with_add = {input_eq_obs: 2, weight_eq_obs: 2, minmax_obs: 6}

    tests = [
        (SingleLayerLinearModel, single_nn_layer),
        (TwoLayerLinearModel, two_nn_layers),
        (TwoLayerFunctionalLinearModel, two_f_layers),
        (FunctionalLinearAddModel, two_f_layers_with_add),
        (LinearReluModel, single_nn_layer),
        (LinearReluLinearModel, two_nn_layers),
        (FunctionalLinearReluModel, single_f_layer),
        (FunctionalLinearReluLinearModel, two_f_layers),
        (ConvModel, single_nn_layer),
        (TwoLayerConvModel, two_nn_layers),
        (TwoLayerFunctionalConvModel, two_f_layers),
        (ConvReluModel, single_nn_layer),
        (ConvReluConvModel, two_nn_layers),
        (FunctionalConvReluModel, single_f_layer),
        (FunctionalConvReluConvModel, two_f_layers),
    ]

    for model_cls, expected_occurrence in tests:
        prepared = prepare_fx(
            model_cls().eval(), specific_qconfig_dict,
            equalization_qconfig_dict=default_equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=expected_occurrence)
def test_general_value_ops(self):
    """Check the quantized patterns produced for general value ops.

    The forward chains every covered "general value" op (avg / adaptive avg
    pooling, mean, interpolate, leaky_relu, hardsigmoid, sigmoid, tanh in
    module, functional, method and in-place forms), with ``self.conv``
    applied before and after the chain. The model is only traced, prepared
    and converted; it is never executed.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.avg_pool1d = torch.nn.AvgPool1d(3)
            self.avg_pool2d = torch.nn.AvgPool2d(3)
            self.avg_pool3d = torch.nn.AvgPool3d(3)
            self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
            self.leaky_relu = torch.nn.LeakyReLU()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            # shapes are not consistent across these ops; only the traced
            # graph matters
            x = self.conv(x)
            x = self.avg_pool1d(x)
            x = self.avg_pool2d(x)
            x = self.avg_pool3d(x)
            x = self.adaptive_avg_pool1d(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.adaptive_avg_pool3d(x)
            x = F.avg_pool1d(x, 3)
            x = F.avg_pool2d(x, 3)
            x = F.avg_pool3d(x, 3)
            x = F.adaptive_avg_pool1d(x, (1))
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = F.adaptive_avg_pool3d(x, (1, 1, 1))
            x = torch.mean(x)
            x = torch.mean(x, [2, 3], False)
            x = x.mean()
            x = x.mean([2, 3], True)
            x = F.interpolate(x, 4, mode='nearest')
            x = F.interpolate(x, 4, mode='linear')
            x = self.leaky_relu(x)
            x = F.leaky_relu(x)
            x = F.leaky_relu(x, inplace=True)
            x = x.leaky_relu()
            x.leaky_relu_()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x = F.hardsigmoid(x, inplace=True)
            x = x.hardsigmoid()
            x.hardsigmoid_()
            x = self.sigmoid(x)
            x = torch.sigmoid(x)
            # F.sigmoid is deprecated, so it is not covered
            x = x.sigmoid()
            x.sigmoid_()
            x = self.tanh(x)
            # F.tanh is deprecated, so it is not covered
            x = torch.tanh(x)
            x = x.tanh()
            x.tanh_()
            x = self.conv(x)
            return x

    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    qconfig_dict = {'': default_qconfig}
    prepared = prepare_fx(original, qconfig_dict)
    # not runnable
    quantized = convert_fx(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(quantized,
                               expected_node_occurrence=count_check,
                               expected_node_list=order_check)
def test_general_shape_ops(self):
    """Check that dequantize is swapped for the supported general shape ops.

    The forward chains scalar add/mul (+ relu), max pooling, flatten,
    reshape/view, list/tuple packing, transpose, chunk, dropout, sort,
    permute, repeat, relu, squeeze/unsqueeze, detach and stack between two
    convs. The model is only traced, prepared and converted; it is never
    executed.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
            self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
            self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
            self.dropout = torch.nn.Dropout()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            # add_scalar
            x = x + 3
            # mul_scalar
            x = x * 3
            # add_scalar_out
            x += 3
            # mul_scalar_out
            x *= 3
            # add_scalar_relu
            x = x + 3
            x = F.relu(x)
            # add_scalar_relu_out
            x += 3
            x = F.relu(x)
            # mul_scalar_relu
            x = x * 3
            x = F.relu(x)
            # mul_scalar_relu_out
            x *= 3
            x = F.relu(x)
            x = self.maxpool1d(x)
            x = self.maxpool2d(x)
            x = self.maxpool3d(x)
            x = torch.flatten(x)
            x = torch.max(x)
            x = torch.min(x)
            x = x.reshape([-1])
            x = x.resize_(1, 1, x.numel())
            x = x.view(-1)
            # prim::ListConstruct
            xs = [x, x]
            # prim::ListUnpack
            x, y = xs
            # prim::TupleConstruct
            xs = (x, x)
            # prim::TupleUnpack
            x, y = xs
            x = x.transpose(1, 2)
            x = x.contiguous()
            x, y = torch.chunk(x, 2)
            x = F.dropout(x)
            x = self.dropout(x)
            x, _ = torch.sort(x)
            x = x.permute(0, 2, 3, 1)
            x = x.repeat_interleave(3, 1)
            x = torch.repeat_interleave(x, 3, 1)
            x = self.relu(x)
            x = F.relu(x)
            x = F.relu(x, inplace=True)
            x = x.relu()
            x.relu_()
            x = x.squeeze(0)
            x.squeeze_(0)
            x = torch.squeeze(x, 0)
            x = x.unsqueeze(0)
            x.unsqueeze_(0)
            x = torch.unsqueeze(x, 0)
            x = x.detach()
            x.detach_()
            x = x.repeat(4, 2)
            y = []
            y.append(x)
            z = torch.stack(y, 0)
            z = [z, z]
            x, _ = z
            x = self.conv2(x)
            return x

    # The unused example tensor `data` was removed: the model is only traced,
    # never run.
    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    qconfig_dict = {'': default_qconfig}
    prepared = prepare_fx(original, qconfig_dict)
    # not runnable
    quantized = convert_fx(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers and also successfully fused two quantized::conv2d
    # patterns
    # one quantize_per_tensor for input
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(quantized,
                               expected_node_occurrence=count_check,
                               expected_node_list=order_check)
def test_linear(self):
    """Check graph-mode quantization of nn.Linear (optionally fused with ReLU).

    Plain linear is checked for every quant type. Linear + ReLU (module or
    functional ReLU) is checked for static and QAT, where it is expected to
    become nniq.LinearReLU. The functional-linear variants are disabled
    pending raw-tensor support in fx.
    """
    class ModuleLinear(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super().__init__()
            self.linear = torch.nn.Linear(30, 4).float()
            if not has_relu:
                self.relu = torch.nn.Identity()
            elif f_relu:
                self.relu = F.relu
            else:
                self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    class FuncLinear(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super().__init__()
            self.w = torch.randn(4, 30)
            self.b = torch.randn(4)
            if not has_relu:
                self.relu = torch.nn.Identity()
            elif f_relu:
                self.relu = F.relu
            else:
                self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(F.linear(x, self.w, self.b))

    data = (torch.rand((1, 30), dtype=torch.float), )

    # a single ModuleLinear instance is shared by every quant type below
    model_cases = [(ModuleLinear(has_relu=False), True)]
    # TODO: enable after raw `tensor` is supported in fx
    # model_cases.append((FuncLinear(has_relu=False), False))

    # keyed by (is_module, quant_type); QAT is checked on the final result
    expected_nodes = {
        (True, QuantType.DYNAMIC): ns.call_module(nnqd.Linear),
        (True, QuantType.STATIC): ns.call_module(nnq.Linear),
        (True, QuantType.QAT): ns.call_module(nnq.Linear),
        (False, QuantType.DYNAMIC): ns.call_function(torch.ops.quantized.linear_dynamic),
        (False, QuantType.STATIC): ns.call_function(torch.ops.quantized.linear),
        (False, QuantType.QAT): ns.call_function(torch.ops.quantized.linear),
    }
    for (model, is_module), quant_type in itertools.product(model_cases, self.all_quant_types):
        self.checkGraphModeFxOp(
            model, data, quant_type, expected_nodes[(is_module, quant_type)])

    for f_relu, quant_type in itertools.product(
            [True, False], [QuantType.STATIC, QuantType.QAT]):
        # TODO: support functional linear + relu fusion
        # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))
        model = ModuleLinear(has_relu=True, f_relu=f_relu)
        self.checkGraphModeFxOp(model, data, quant_type, ns.call_module(nniq.LinearReLU))
def test_standalone_module_quantized_interface(self):
    """Standalone module whose first input and first output are quantized.

    Uses a custom backend config that maps ``torch.nn.Conv2d`` to the
    reference quantized conv, then checks observer counts after prepare and
    quantize / reference-conv / dequantize counts after convert, separately
    for the parent module and the standalone module.
    """
    # Index 0 of the inputs and of the outputs crosses the standalone module
    # boundary as a quantized tensor.
    interface_config = {
        "input_quantized_idxs": [0],
        "output_quantized_idxs": [0],
    }

    # TODO: input_quantized_idxs only supports quint8. This qconfig and the
    # custom backend config below can go away once `input_quantized_idxs`
    # supports richer configurations (a first step could be a dictionary
    # from index to dtype).
    histogram_quint8 = torch.ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
    )
    qconfig = torch.ao.quantization.QConfig(
        activation=histogram_quint8,
        weight=torch.ao.quantization.default_weight_observer
    )

    # All four entries are optional: input activation, weight, bias and
    # output activation dtypes for a weighted op.
    quint8_weighted_dtypes = {
        "input_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
        "output_dtype": torch.quint8
    }
    custom_backend_config_dict = {
        "configs": [
            {
                "pattern": torch.nn.Conv2d,
                "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
                "dtype_configs": [quint8_weighted_dtypes],
                "root_module": torch.nn.Conv2d,
                "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
            },
        ]
    }

    # After prepare: the parent observes the input and output of its conv,
    # the standalone module observes only the output of its own conv.
    parent_prepare_counts = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 2
    }
    standalone_prepare_counts = {
        ns.call_module(torch.ao.quantization.HistogramObserver): 1
    }

    # After convert, parent: quantize input and output of the reference
    # conv; dequantize the reference conv's input and the standalone
    # module's output.
    parent_convert_counts = {
        ns.call_function(torch.quantize_per_tensor) : 2,
        ns.call_module(nnqr.Conv2d) : 1,
        ns.call_method("dequantize") : 2,
    }
    # After convert, standalone: its input is quantized by the parent and
    # its output is quantized here; the input is dequantized here and the
    # output is dequantized by the parent.
    standalone_convert_counts = {
        ns.call_function(torch.quantize_per_tensor) : 1,
        ns.call_module(nnqr.Conv2d): 1,
        ns.call_method("dequantize") : 1,
    }

    self._test_standalone_module(
        interface_config,
        parent_prepare_counts,
        standalone_prepare_counts,
        parent_convert_counts,
        standalone_convert_counts,
        qconfig=qconfig,
        backend_config_dict=custom_backend_config_dict)
def test_conv_add_standalone_module(self):
    """Standalone module containing a conv -> add -> relu pattern.

    The parent computes ``y = conv(x)`` and passes both ``x`` and ``y`` to a
    standalone submodule computing ``relu(conv(x) + y)``, whose two inputs
    are declared quantized (``input_quantized_idxs: [0, 1]``). A custom
    backend config lists both the conv/add/relu pattern and plain conv. The
    test checks observer counts after prepare and quantize/dequantize counts
    after reference conversion, for the parent and the standalone submodule.
    """
    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x, y):
            return self.relu(self.conv(x) + y)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.standalone = Standalone()

        def forward(self, x):
            # The second standalone input is computed here from ``x``; it is
            # not a forward() argument.
            y = self.conv(x)
            return self.standalone(x, y)

    from torch.ao.quantization.fx.backend_config_dict.observation_type import ObservationType
    weighted_op_quint8_dtype_config = {
        # optional, input activation dtype
        # TODO: change back to torch.qint8 after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        "input_dtype": torch.quint8,
        # optional, weight dtype
        "weight_dtype": torch.qint8,
        # optional, bias dtype
        "bias_dtype": torch.float,
        # optional, output activation dtype
        "output_dtype": torch.quint8
    }
    # The pattern is written outermost op first: it corresponds to
    # relu(add(conv(...), <any node>)) as in Standalone.forward.
    conv_add_config = {
        "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        # No reference module is configured for the root; the post-convert
        # checks below accordingly expect a float nn.Conv2d call.
        # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    conv_config = {
        "pattern": torch.nn.Conv2d,
        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        "dtype_configs": [
            weighted_op_quint8_dtype_config,
        ],
        "root_module": torch.nn.Conv2d,
        # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
    }

    m = M().eval()
    backend_config_dict = {
        "configs": [
            conv_add_config,
            conv_config,
        ]
    }
    prepare_custom_config_dict = {
        "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
    }
    # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
    # are more flexible
    qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
        ), weight=torch.ao.quantization.default_weight_observer
    )
    m = prepare_fx(
        m,
        {"": qconfig},
        prepare_custom_config_dict=prepare_custom_config_dict,
        backend_config_dict=backend_config_dict)
    node_occurrence = {
        # for input and output of conv, where input is used twice, once in conv and
        # once in standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 2,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        # output of the standalone module
        ns.call_module(torch.ao.quantization.HistogramObserver): 1,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)

    m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
    node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 3,
        ns.call_module(nn.Conv2d): 1,
        ns.call_method("dequantize"): 1,
    }
    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
    standalone_node_occurrence = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_module(nn.Conv2d): 1,
        ns.call_module(torch.nn.ReLU): 1,
        ns.call_method("dequantize"): 3,
    }
    self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
def test_input_weight_equalization_prepare(self):
    """Check the observers inserted by prepare_fx when equalization is on.

    Each case prepares a small model built from ``nn.Linear`` modules or
    ``nn.functional.linear`` calls. It then checks how many
    ``_InputEqualizationObserver``, ``_WeightEqualizationObserver`` and
    ``MinMaxObserver`` modules appear in the prepared graph.
    """
    # Only linear layers (module and functional) are quantized.
    qconfig_dict = {
        "": None,
        "object_type": [(nn.Linear, default_qconfig), (nn.functional.linear, default_qconfig)]
    }

    default_equalization_qconfig_dict = {
        "": default_qconfig,
        "object_type": [(nn.Linear, default_equalization_qconfig),
                        (nn.functional.linear, default_equalization_qconfig)]
    }

    # Basic test with one linear layer
    class LinearModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 1)

        def forward(self, x):
            return self.linear(x)

    linear_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 1,
        ns.call_module(MinMaxObserver): 2,
    }

    # Test with two linear layers
    class Linear2Module(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(1, 1)
            self.linear2 = nn.Linear(1, 1)

        def forward(self, x):
            x = self.linear1(x)
            x = self.linear2(x)
            return x

    linear2_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 3,
    }

    # Wraps a single nn.functional.linear call with fixed weight and bias.
    class Linear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return nn.functional.linear(x, self.w, self.b)

    # Test where we have two functional linear layers
    class FunctionalLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = Linear()
            self.linear2 = Linear()

        def forward(self, x):
            x = self.linear1(x)
            x = self.linear2(x)
            return x

    # Functional linears expose their weights in the graph, so weight
    # equalization observers are expected in addition to input ones.
    functionalLinear_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(_WeightEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 5,
    }

    # Test where a functional linear layer is followed by an fp32 operation
    # (torch.add) that has no qconfig, and then by a second functional linear
    class FunctionalLinearFP32Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = Linear()
            self.linear2 = Linear()

        def forward(self, x):
            x = self.linear1(x)
            x = torch.add(x, 5)
            x = self.linear2(x)
            return x

    # Same as default_equalization_qconfig_dict but without a global ("")
    # config, so only the linear layers get an equalization qconfig.
    fp32_equalization_qconfig_dict = {
        "": None,
        "object_type": [(nn.Linear, default_equalization_qconfig),
                        (nn.functional.linear, default_equalization_qconfig)]
    }

    functionalLinearFP32_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 2,
        ns.call_module(_WeightEqualizationObserver): 2,
        ns.call_module(MinMaxObserver): 6,
    }

    tests = [(LinearModule, default_equalization_qconfig_dict, linear_node_occurrence),
             (Linear2Module, fp32_equalization_qconfig_dict, linear2_node_occurrence),
             (FunctionalLinearModule, default_equalization_qconfig_dict, functionalLinear_node_occurrence),
             (FunctionalLinearFP32Module, fp32_equalization_qconfig_dict, functionalLinearFP32_node_occurrence)]

    for (M, equalization_qconfig_dict, node_occurrence) in tests:
        m = M().eval()
        prepared = prepare_fx(
            m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
        self.checkGraphModuleNodes(
            prepared, expected_node_occurrence=node_occurrence)
def test_input_weight_equalization_branching(self):
    """ Tests that graphs containing branches are prepared correctly.
    Specifically, equalization observers should not be inserted in front of
    branches in which both initial layers in the branches plan to be
    quantized.
    """

    # Tests that we do not add an equalization observer due to both initial
    # nodes in the branch containing layers that need to be equalized.
    # Note that this should print out 2 warning messages for not being able
    # to equalize layers linear1 and linear2 because they are part of a branch
    class TestBranchingWithoutEqualizationModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = nn.Linear(5, 5)
            self.linear2 = nn.Linear(5, 5)

        def forward(self, x):
            # ``x`` feeds two linear layers: this is the branch.
            y = self.linear1(x)
            z = self.linear2(x)
            return torch.add(y, z)

    no_eq_branching_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 0,
        ns.call_module(MinMaxObserver): 3,
    }

    m = TestBranchingWithoutEqualizationModel().eval()
    example_inputs = (torch.rand(1, 5), )
    prepared = prepare_fx(
        m, specific_qconfig_dict, example_inputs=example_inputs,
        _equalization_config=default_equalization_qconfig_dict)
    self.checkGraphModuleNodes(
        prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

    # Tests that we will add an equalization observer because there is only
    # one initial node in the branch that needs to be equalized
    class TestBranchingWithEqualizationModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = nn.Linear(5, 5)

        def forward(self, x):
            # ``x`` feeds one linear layer and one torch.add.
            y = self.linear1(x)
            z = torch.add(x, 5)
            return torch.add(y, z)

    eq_branching_node_occurrence = {
        ns.call_module(_InputEqualizationObserver): 1,
        ns.call_module(MinMaxObserver): 2,
    }

    m = TestBranchingWithEqualizationModel().eval()
    example_inputs = (torch.randn(1, 5), )
    prepared = prepare_fx(
        m, specific_qconfig_dict, example_inputs=example_inputs,
        _equalization_config=default_equalization_qconfig_dict)
    self.checkGraphModuleNodes(
        prepared, expected_node_occurrence=eq_branching_node_occurrence)
def test_user_module(self):
    """
    For user defined modules,
    1. weight extraction should not crash
    2. unshadowed activations should have loggers, loggers will only log if
         the output dtype is in the allowlist
    3. shadowed activations should not have loggers
         (since I/O dtype is unknown)
    """
    class UserModule(nn.Module):
        def forward(self, x):
            return x

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 1)
            self.user_module = UserModule()

        def forward(self, x):
            x = self.linear(x)
            x = self.user_module(x)
            return x

    m = M().eval()

    # quantize without tracing through UserModule
    qconfig_dict = {'': torch.quantization.default_qconfig}
    prepare_custom_config_dict = {
        'non_traceable_module_name': ['user_module']
    }
    mp = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
    # run one input through the prepared model before converting it
    mp(torch.randn(1, 1, 1))
    mq = convert_fx(copy.deepcopy(mp))

    # 1. weight extraction should not crash (the result is not inspected)
    _extract_weights_impl('fp32_prepared', mp, 'int8', mq)

    # 2. unshadowed activations should have loggers
    # add loggers, without retracing
    # note: converting again because we cannot copy a quantized linear
    mp_ns, mq_ns = _add_loggers_impl(
        'fp32_prepared', copy.deepcopy(mp), 'int8',
        convert_fx(copy.deepcopy(mp)), OutputLogger,
        should_log_inputs=True)
    # both fp32 and int8 models should have 4 loggers each, 2 for I/O
    # of linear, and 2 for I/O of user_module
    unshadowed_expected_occurrence = {
        ns.call_module(OutputLogger): 4,
    }
    self.checkGraphModuleNodes(
        mp_ns, expected_node_occurrence=unshadowed_expected_occurrence)
    self.checkGraphModuleNodes(
        mq_ns, expected_node_occurrence=unshadowed_expected_occurrence)

    # 3. shadowed activations should only have loggers for nodes where
    # the types are known and we can do a dtype cast
    # add shadow loggers, without retracing
    mp_shadows_mq_ns = _add_shadow_loggers_impl(
        'fp32_prepared', mp, 'int8', mq, OutputLogger,
        should_log_inputs=True)
    # 4 loggers in total and none on the I/O of user_module.
    # NOTE(review): presumably these are the input and output loggers of the
    # fp32 linear plus those of its int8 shadow copy -- confirm.
    shadowed_expected_occurrence = {
        ns.call_module(OutputLogger): 4,
    }
    self.checkGraphModuleNodes(
        mp_shadows_mq_ns, expected_node_occurrence=shadowed_expected_occurrence)