def test_qconfig_dict_module_name(self):
    """
    Checks that entries under the 'module_name' key of qconfig_dict
    override the global qconfig for the submodules they name, including
    nested submodules addressed by a dotted path.
    """
    # Built in the same order as a single nested expression would be, so
    # parameter initialization consumes the RNG identically.
    first = nn.Sequential(nn.Conv2d(1, 1, 1))
    middle = nn.Conv2d(1, 1, 1)
    last = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
    m = nn.Sequential(first, middle, last)

    default_qconfig = torch.quantization.default_qconfig
    qconfig_dict = {
        '': default_qconfig,
        'module_name': [
            ('0', default_qconfig),
            ('1', None),    # opt out a top-level submodule
            ('2.0', None),  # opt out a nested submodule by dotted path
        ],
    }
    example_args = (torch.randn(1, 1, 1, 1),)

    prepared = _quantize_dbr.prepare(m, qconfig_dict, example_args)
    prepared(*example_args)
    quantized = _quantize_dbr.convert(prepared)
    quantized(*example_args)

    # Opted-out modules stay float; everything else is swapped.
    self.assertIsInstance(quantized[0][0], nnq.Conv2d)
    self.assertIsInstance(quantized[1], nn.Conv2d)
    self.assertIsInstance(quantized[2][0], nn.Conv2d)
    self.assertIsInstance(quantized[2][1], nnq.Conv2d)
def test_qconfig_dict_object_type_function_global_none(self):
    """
    With the global ('') qconfig set to None, an 'object_type' entry for
    torch.add must still cause the add in forward to be quantized.
    """
    class AddModule(nn.Module):
        def forward(self, x):
            return x + x

    model = AddModule()
    add_qconfig = torch.quantization.default_qconfig
    qconfig_dict = {'': None, 'object_type': [(torch.add, add_qconfig)]}
    example_args = (torch.randn(1, 1, 1, 1),)

    prepared = _quantize_dbr.prepare(model, qconfig_dict, example_args)
    prepared(*example_args)
    converted = _quantize_dbr.convert(prepared)
    converted(*example_args)

    # The scriptable rewrite must contain only the quantized add.
    self.checkGraphModuleNodes(
        converted.rewrite_for_scripting(),
        expected_node_occurrence={
            NodeSpec.call_function(torch.add): 0,
            NodeSpec.call_function(toq.add): 1,
        })
def test_prepare_custom_config_dict_non_traceable_module_class_mid_leaf(
        self):
    """
    Verifies that listing a mid-level module class (M2) under
    'non_traceable_module_class' leaves that submodule and its children
    unquantized, while the top-level module is still instrumented.
    """
    # With M2 marked non-traceable, only the top-level module should get
    # _auto_quant_state; mp.m2 and its child mp.m2.m1 should not.
    qconfig_dict = {'': torch.quantization.default_qconfig}
    m, _, M2, _ = self._get_non_traceable_module_class_test_model()
    prepare_custom_config_dict = {
        'non_traceable_module_class': [M2],
    }
    example_args = (torch.randn(1, 1, 1, 1),)

    mp = _quantize_dbr.prepare(
        m, qconfig_dict, example_args,
        prepare_custom_config_dict=prepare_custom_config_dict)
    self.assertFalse(hasattr(mp.m2.m1, '_auto_quant_state'))
    self.assertFalse(hasattr(mp.m2, '_auto_quant_state'))
    self.assertTrue(hasattr(mp, '_auto_quant_state'))

    mq = _quantize_dbr.convert(mp)
    # Convs inside the non-traceable subtree must not be swapped.
    self.assertIsInstance(mq.m2.m1.conv, nn.Conv2d)
    self.assertIsInstance(mq.m2.conv, nn.Conv2d)

    mqt = torch.jit.trace(mq, example_args)
    # mqt.m2 and all children should not have quantized ops
    for graph in (mqt.m2.m1.graph, mqt.m2.graph):
        FileCheck().check_count("aten::add", 1, exactly=True).run(graph)
        FileCheck().check_count(
            "quantized::add", 0, exactly=True).run(graph)
def test_qconfig_dict_global(self):
    """
    Checks the global ('') entry of qconfig_dict: a real qconfig swaps the
    conv for a quantized one, while None leaves it as a float nn.Conv2d.
    """
    cases = (
        # regular case
        (torch.quantization.default_qconfig, nnq.Conv2d),
        # quantization turned off
        (None, nn.Conv2d),
    )
    for global_qconfig, expected_cls in cases:
        model = nn.Sequential(nn.Conv2d(1, 1, 1))
        example_args = (torch.randn(1, 1, 1, 1),)
        prepared = _quantize_dbr.prepare(
            model, {'': global_qconfig}, example_args)
        prepared(*example_args)
        converted = _quantize_dbr.convert(prepared)
        converted(*example_args)
        self.assertIsInstance(converted[0], expected_cls)
def test_numeric_suite(self):
    """
    Verifies that numeric-suite loggers (ns.add_loggers /
    ns.extract_logger_info) work on a DBR prepared model and its converted
    copy, and that the extracted results can be extended with an SQNR
    comparison.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 1)
            self.conv2 = nn.Sequential(nn.Conv2d(1, 1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.conv2(x)
            x = x + x
            return x

    m = M().eval()
    qconfig = torch.quantization.default_qconfig
    example_args = (torch.randn(1, 1, 2, 2),)
    mp = _quantize_dbr.prepare(m, {'': qconfig}, example_args)
    mp(*example_args)
    # Convert a deep copy so `mp` itself remains available for the
    # prepared-vs-quantized comparison below.
    mq = _quantize_dbr.convert(copy.deepcopy(mp))
    mq(*example_args)

    mp, mq = ns.add_loggers('mp', mp, 'mq', mq)
    mp(*example_args)
    mq(*example_args)
    act_comparison = ns.extract_logger_info(mp, mq, 'mq')
    ns_fx.extend_logger_results_with_comparison(
        act_comparison, 'mp', 'mq', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # TODO(future PR): enforce validity of the result above, using
    # NS for FX utils. Will need some refactoring.
    # For now, check that every logged layer carries the fields a report
    # would need, including the 'sqnr' entry added just above.
    for layer_name, layer_results in act_comparison.items():
        mq_result = layer_results['node_output']['mq'][0]
        for key in ('fqn', 'ref_node_target_type', 'sqnr'):
            self.assertIn(key, mq_result, msg=layer_name)
def _test_auto_tracing(
    self,
    m,
    qconfig,
    example_args,
    fuse_modules=True,
    do_fx_comparison=True,
    do_torchscript_checks=True,
):
    """
    Quantize `m` with DBR using `qconfig` as the global qconfig and check
    the result.

    Steps:
      1. prepare, run once, convert, and run the quantized model;
      2. if `do_fx_comparison`, quantize a deep copy of the original `m`
         with FX graph mode and assert (via `_allclose`) that both the
         prepared and the quantized outputs match DBR's;
      3. if `do_torchscript_checks`, assert that `torch.jit.trace` of the
         quantized model, plus eager, `torch.jit.script` and
         `torch.jit.trace` runs of `mq.rewrite_for_scripting()`, all
         reproduce the quantized output.

    Args:
        m: float model to quantize.
        qconfig: qconfig applied under the global ('') key.
        example_args: tuple of positional inputs used for prepare,
            the calibration run and every comparison.
        fuse_modules: forwarded to `_quantize_dbr.prepare`.
        do_fx_comparison: whether to compare against FX graph mode.
        do_torchscript_checks: whether to run the TorchScript checks.
    """
    # Copy before prepare in case DBR prepare modifies `m`, so the FX
    # comparison starts from the original float model.
    m_copy = copy.deepcopy(m)
    qconfig_dict = {'': qconfig}
    mp = _quantize_dbr.prepare(
        m, qconfig_dict, example_args, fuse_modules=fuse_modules)
    out_p = mp(*example_args)
    mq = _quantize_dbr.convert(mp)
    # verify it runs
    out_q = mq(*example_args)

    # compare it against FX
    if do_fx_comparison:
        m_copy_p = prepare_fx(m_copy, {'': qconfig})
        out_m_copy_p = m_copy_p(*example_args)
        m_copy_q = convert_fx(m_copy_p)
        out_q_fx = m_copy_q(*example_args)
        self.assertTrue(_allclose(out_p, out_m_copy_p))
        self.assertTrue(_allclose(out_q, out_q_fx))

    if do_torchscript_checks:
        # verify torch.jit.trace works
        mq_jit_traced = torch.jit.trace(
            mq, example_args, check_trace=False)
        traced_out = mq_jit_traced(*example_args)
        self.assertTrue(_allclose(traced_out, out_q))

        # verify torch.jit.script works
        rewritten = mq.rewrite_for_scripting()
        rewritten_out = rewritten(*example_args)
        self.assertTrue(_allclose(rewritten_out, out_q))

        scripted_rewritten = torch.jit.script(rewritten)
        scripted_rewritten_out = scripted_rewritten(*example_args)
        self.assertTrue(_allclose(scripted_rewritten_out, out_q))

        traced_rewritten = torch.jit.trace(
            rewritten, example_args, check_trace=False)
        traced_rewritten_out = traced_rewritten(*example_args)
        self.assertTrue(_allclose(traced_rewritten_out, out_q))