    def test_conv_functional_qat(self):

        class M(torch.nn.Module):
            def __init__(self, weight2d, bias2d):
                super().__init__()
                self.weight2d = torch.nn.Parameter(weight2d)
                self.bias2d = torch.nn.Parameter(bias2d)
                self.stride2d = (1, 1)
                self.padding2d = (0, 0)
                self.dilation2d = (1, 1)
                self.groups = 1

            def forward(self, x):
                x = F.conv2d(
                    x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
                    self.dilation2d, self.groups)
                return x

        m = M(torch.randn(1, 1, 1, 1), torch.randn(1)).eval()
        qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        self._test_auto_tracing(m, qconfig, (torch.randn(1, 1, 2, 2),))

        # test backprop does not crash
        inputs = torch.randn(1, 1, 1, 1)
        inputs.requires_grad = True
        mp = _quantize_dbr.prepare(m, {'': qconfig}, (inputs,))
        output = mp(inputs)
        labels = torch.randn(1, 1, 1, 1)
        loss = (output - labels).sum()
        loss.backward()
        optim = torch.optim.SGD(mp.parameters(), lr=0.01)
        optim.step()
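        # A possible follow-up (not exercised by this test): after the
        # training step above, the prepared model could be converted and
        # run, e.g.
        #   mq = _quantize_dbr.convert(mp)
        #   mq(inputs)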
    def test_qconfig_dict_module_name(self):
        """
        Verifies that the 'module_name' option of qconfig_dict works
        on individual modules, specified by name.
        """
        m = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(1, 1, 1),
            ),
            nn.Conv2d(1, 1, 1),
            nn.Sequential(
                nn.Conv2d(1, 1, 1),
                nn.Conv2d(1, 1, 1),
            ),
        )
        qconfig_dict = {
            '': torch.quantization.default_qconfig,
            'module_name': [
                ('0', torch.quantization.default_qconfig),
                ('1', None),
                ('2.0', None),
            ],
        }
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        mq(*example_args)
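        # '0' matches by name and is quantized; '1' and '2.0' are skipped via
        # None; '2.1' is not listed and falls back to the global ('') qconfig.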
        self.assertTrue(isinstance(mq[0][0], nnq.Conv2d))
        self.assertTrue(isinstance(mq[1], nn.Conv2d))
        self.assertTrue(isinstance(mq[2][0], nn.Conv2d))
        self.assertTrue(isinstance(mq[2][1], nnq.Conv2d))
    def test_conv_mod_qat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x1 = self.conv(x)
                return x1

        m = M().eval()
        qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        self._test_auto_tracing(
            copy.deepcopy(m), qconfig, (torch.randn(1, 1, 2, 2),))

        # test backprop does not crash
        inputs = torch.randn(1, 1, 1, 1)
        inputs.requires_grad = True
        mp = _quantize_dbr.prepare(m, {'': qconfig}, (inputs,))
        output = mp(inputs)
        labels = torch.randn(1, 1, 1, 1)
        loss = (output - labels).sum()
        loss.backward()
        optim = torch.optim.SGD(mp.parameters(), lr=0.01)
        optim.step()
    def test_qconfig_dict_object_type_function_global_none(self):
        """
        Verifies that the 'object_type' option of qconfig_dict works
        on function types when global qconfig is None.
        """
        class M(nn.Module):
            def forward(self, x):
                x = x + x
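                # Tensor `+` dispatches to torch.add, which is what the
                # 'object_type' entry below matches.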
                return x

        m = M()
        qconfig_dict = {
            '': None,
            'object_type': [
                (torch.add, torch.quantization.default_qconfig),
            ],
        }
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        mq(*example_args)
        rewritten = mq.rewrite_for_scripting()
        expected_occurrence = {
            NodeSpec.call_function(torch.add): 0,
            NodeSpec.call_function(toq.add): 1,
        }
        self.checkGraphModuleNodes(
            rewritten, expected_node_occurrence=expected_occurrence)
    def test_prepare_custom_config_dict_non_traceable_module_class_mid_leaf(
            self):
        # if M2 is set as a leaf, only the parent module (not M2 or M1)
        # should have auto_quant_state
        qconfig_dict = {'': torch.quantization.default_qconfig}
        m, M1, M2, M3 = self._get_non_traceable_module_class_test_model()
        prepare_custom_config_dict = {
            'non_traceable_module_class': [M2],
        }
        mp = _quantize_dbr.prepare(
            m,
            qconfig_dict, (torch.randn(1, 1, 1, 1), ),
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.assertTrue(not hasattr(mp.m2.m1, '_auto_quant_state'))
        self.assertTrue(not hasattr(mp.m2, '_auto_quant_state'))
        self.assertTrue(hasattr(mp, '_auto_quant_state'))
        mq = _quantize_dbr.convert(mp)
        self.assertTrue(isinstance(mq.m2.m1.conv, nn.Conv2d))
        self.assertTrue(isinstance(mq.m2.conv, nn.Conv2d))
        mqt = torch.jit.trace(mq, (torch.randn(1, 1, 1, 1), ))

        # mqt.m2 and all children should not have quantized ops
        FileCheck().check_count("aten::add", 1,
                                exactly=True).run(mqt.m2.m1.graph)
        FileCheck().check_count("quantized::add", 0,
                                exactly=True).run(mqt.m2.m1.graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(mqt.m2.graph)
        FileCheck().check_count("quantized::add", 0,
                                exactly=True).run(mqt.m2.graph)
    def test_qconfig_dict_global(self):
        """
        Verifies that the global ('') option of qconfig_dict works.
        """

        # regular case
        m = nn.Sequential(nn.Conv2d(1, 1, 1))
        qconfig_dict = {'': torch.quantization.default_qconfig}
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        mq(*example_args)
        self.assertTrue(isinstance(mq[0], nnq.Conv2d))

        # quantization turned off
        m = nn.Sequential(nn.Conv2d(1, 1, 1))
        qconfig_dict = {'': None}
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        mq(*example_args)
        self.assertTrue(isinstance(mq[0], nn.Conv2d))
    def test_observers_not_touched_by_tracing(self):
        """
        Verifies that running dynamic tracing does not change any data
        stored in observers and fake quants.
        """
        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        qconfig = torch.quantization.default_qconfig
        mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))
        for _, mod in mp.named_modules():
            if isinstance(mod, (ObserverBase, FakeQuantizeBase)):
                scale, zp = mod.calculate_qparams()
                # Assume that if scale is 1.0 and zp is 0, no calibration
                # has happened.
                self.assertTrue(torch.allclose(scale, torch.ones(1)))
                self.assertTrue(torch.equal(zp, torch.zeros(1, dtype=torch.long)))
    def test_numeric_suite(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Sequential(nn.Conv2d(1, 1, 1))

            def forward(self, x):
                x = self.conv(x)
                x = self.conv2(x)
                x = x + x
                return x

        m = M().eval()
        qconfig = torch.quantization.default_qconfig
        example_args = (torch.randn(1, 1, 2, 2),)
        mp = _quantize_dbr.prepare(m, {'': qconfig}, example_args)
        out_p = mp(*example_args)
        mq = _quantize_dbr.convert(copy.deepcopy(mp))
        out_q = mq(*example_args)

        mp, mq = ns.add_loggers('mp', mp, 'mq', mq)

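        # Run both instrumented models so the loggers record activations
        # for comparison.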
        mp(*example_args)
        mq(*example_args)

        act_comparison = ns.extract_logger_info(mp, mq, 'mq')
        ns_fx.extend_logger_results_with_comparison(
            act_comparison, 'mp', 'mq', torch.ao.ns.fx.utils.compute_sqnr,
            'sqnr')

        # TODO(future PR): enforce validity of the result above, using
        # NS for FX utils. Will need some refactoring.

        # TODO(future PR): consider adding a util for below
        to_print = []
        for layer_name, v in act_comparison.items():
            to_print.append([
                layer_name,
                v['node_output']['mq'][0]['fqn'],
                v['node_output']['mq'][0]['ref_node_target_type'],
                v['node_output']['mq'][0]['sqnr']])
    def test_fusion(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.relu = torch.nn.ReLU()
                self.child = nn.Sequential(
                    nn.Conv2d(1, 1, 1),
                    nn.ReLU(),
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.relu(x)
                x = self.child(x)
                return x

        m = M().eval()
        qconfig = torch.quantization.default_qconfig
        mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))
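        # prepare fuses the Conv2d + ReLU pairs into ConvReLU2d, both at the
        # top level and inside the child Sequential.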
        self.assertTrue(isinstance(mp.conv, nni.ConvReLU2d))
        self.assertTrue(isinstance(mp.child[0], nni.ConvReLU2d))
    def test_unsupported_ops_recorded(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.softshrink = nn.Softshrink()

            def forward(self, x):
                # supported
                x = self.conv2d(x)
                x = x + x
                # not supported
                x = self.softshrink(x)
                x = F.tanhshrink(x)
                return x

        m = M().eval()
        qconfig_dict = {'': torch.quantization.default_qconfig}
        mp = _quantize_dbr.prepare(m, qconfig_dict,
                                   (torch.randn(1, 1, 1, 1), ))
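        # Ops that DBR has no quantization hooks for are recorded on the
        # parent module's _auto_quant_state during tracing.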
        expected = set([nn.Softshrink, F.tanhshrink])
        self.assertTrue(
            mp._auto_quant_state.seen_op_types_without_op_hooks == expected)
    def _test_auto_tracing(
        self,
        m,
        qconfig,
        example_args,
        fuse_modules=True,
        do_fx_comparison=True,
        do_torchscript_checks=True,
    ):
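        """
        Runs the DBR prepare/convert flow on `m`, optionally compares its
        numerics against FX graph mode quantization, and optionally checks
        that the converted model works with torch.jit.trace and with
        rewrite_for_scripting + torch.jit.script.
        """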
        m_copy = copy.deepcopy(m)

        qconfig_dict = {'': qconfig}

        mp = _quantize_dbr.prepare(
            m, qconfig_dict, example_args, fuse_modules=fuse_modules)
        out_p = mp(*example_args)
        # print(mp)
        mq = _quantize_dbr.convert(mp)
        # print(mq)
        # verify it runs
        out_q = mq(*example_args)
        # print(out_q)

        # compare it against FX
        if do_fx_comparison:
            m_copy_p = prepare_fx(m_copy, {'': qconfig})
            out_m_copy_p = m_copy_p(*example_args)
            # print(m_copy_p)
            m_copy_q = convert_fx(m_copy_p)
            # print(m_copy_q)
            # print(m_copy_q.graph)
            out_q_fx = m_copy_q(*example_args)
            # print(out_q)
            # print(out_q_fx)
            self.assertTrue(_allclose(out_p, out_m_copy_p))
            # print(out_q)
            # print(out_q_fx)
            self.assertTrue(_allclose(out_q, out_q_fx))

        if do_torchscript_checks:
            # verify torch.jit.trace works
            mq_jit_traced = torch.jit.trace(
                mq, example_args, check_trace=False)
            # print(mq_jit_traced.graph)
            traced_out = mq_jit_traced(*example_args)
            self.assertTrue(_allclose(traced_out, out_q))

            # verify torch.jit.script works
            rewritten = mq.rewrite_for_scripting()
            rewritten_out = rewritten(*example_args)
            # print(rewritten)
            self.assertTrue(_allclose(rewritten_out, out_q))

            scripted_rewritten = torch.jit.script(rewritten)
            # print(scripted_rewritten.graph)
            scripted_rewritten_out = scripted_rewritten(*example_args)
            # print('scripted_rewritten_out', scripted_rewritten_out)
            self.assertTrue(_allclose(scripted_rewritten_out, out_q))

            traced_rewritten = torch.jit.trace(
                rewritten, example_args, check_trace=False)
            traced_rewritten_out = traced_rewritten(*example_args)
            self.assertTrue(_allclose(traced_rewritten_out, out_q))