Esempio n. 1
0
    def test_qconfig_dict(self):
        r"""Quantize 'sub2.fc1' and 'fc3' of a nested model twice and compare.

        The first path is eager mode (per-module ``qconfig`` attributes +
        ``quantize``). The second is script mode (a ``qconfig_dict`` passed to
        the JIT insert-observers / insert-quant-dequant / quant-fusion passes).
        Both quantized models must produce the same output on the same input.
        """
        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]

        # Eager mode
        qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = AnnotatedNestedModel()
        eager_module.fc3.qconfig = qconfig
        eager_module.sub2.fc1.qconfig = qconfig
        # Assign weights
        eager_module.sub1.fc.weight.data.fill_(1.0)
        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
        eager_module.sub2.fc2.weight.data.fill_(1.0)
        eager_module.fc3.module.weight.data.fill_(1.0)

        script_module = torch.jit.script(NestedModel())
        # Share parameters with eager_module. Biases must be shared as well as
        # weights: the two models are constructed independently, so unless the
        # biases are the same objects the two models compute different
        # functions and the final comparison cannot hold.
        # NOTE(review): assumes every layer here is an nn.Linear with
        # bias=True — confirm against NestedModel / AnnotatedNestedModel.
        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
        script_module.sub1.fc.bias = eager_module.sub1.fc.bias
        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
        script_module.sub2.fc1.bias = eager_module.sub2.fc1.module.bias
        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
        script_module.sub2.fc2.bias = eager_module.sub2.fc2.bias
        script_module.fc3.weight = eager_module.fc3.module.weight
        script_module.fc3.bias = eager_module.fc3.module.bias

        # Quantize eager module
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            # The underlying C++ `forward` method of a scripted module.
            return m._c._get_method('forward')

        # Quantize script_module
        torch._C._jit_pass_constant_propagation(
            get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        scripted_qconfig = QConfig(activation=ScriptedObserver._c,
                                   weight=ScriptedWeightObserver._c)
        # Keys mirror the attribute paths of the two eager submodules that
        # were given a qconfig above.
        qconfig_dict = {'sub2.fc1': scripted_qconfig, 'fc3': scripted_qconfig}
        torch._C._jit_pass_insert_observers(script_module._c, "forward",
                                            qconfig_dict)

        # Run script_module and Collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(
            script_module._c, "forward")
        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(get_forward(script_module).graph)
        # Run the fused graph once; the result is discarded (presumably a
        # smoke check that the rewritten graph executes).
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)

    def test_nested1(self):
        r"""Quantize a nested model in which only the top-level 'fc3' and
        'sub2.fc1' are annotated; 'sub1.fc' and 'sub2.fc2' must stay float.

        Exercises both the step-by-step API (prepare -> calibrate -> convert)
        and the one-line ``quantize`` API.
        """
        fresh_model = AnnotatedNestedModel()

        def check_prep_modules(mod, before_calib=False):
            if before_calib:
                self.checkObservers(mod)
            # (submodule, expected to have prep modules?) in the order the
            # checks are performed; only 'sub2.fc1' and 'fc3' expect them.
            expectations = (
                (mod, False),
                (mod.sub1, False),
                (mod.sub1.fc, False),
                (mod.sub1.relu, False),
                (mod.sub2, False),
                (mod.sub2.fc1, True),
                (mod.sub2.fc2, False),
                (mod.fc3, True),
            )
            for submodule, has_prep in expectations:
                if has_prep:
                    self.checkHasPrepModules(submodule)
                else:
                    self.checkNoPrepModules(submodule)

        prepared = prepare(fresh_model)
        check_prep_modules(prepared, before_calib=True)
        test_only_eval_fn(prepared, self.calib_data)
        converted = convert(prepared)

        def check_quantized(mod):
            check_prep_modules(mod)
            self.checkLinear(mod.sub1.fc)
            self.checkWrappedQuantizedLinear(mod.fc3)
            self.checkWrappedQuantizedLinear(mod.sub2.fc1)
            self.checkLinear(mod.sub2.fc2)
            test_only_eval_fn(mod, self.calib_data)
            self.checkScriptable(mod, self.calib_data)

        check_quantized(converted)

        # test one line API
        one_shot = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                            self.calib_data)
        check_quantized(one_shot)