Пример #1
0
    def test_nested3(self):
        r"""Dynamic quantization of a nested model in which a child's qconfig
        overrides the qconfig of its parent.

        'fc3' and 'sub2' are configured with ``default_dynamic_qconfig``; the
        more specific 'sub2.fc1' entry overrides the inherited 'sub2' setting
        with a custom ``QConfigDynamic``. 'sub1' has no entry at all, so its
        Linear layer must stay float.
        """
        model = NestedModel().eval()
        custom_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
        qconfig_dynamic_dict = {
            'fc3': default_dynamic_qconfig,
            'sub2': default_dynamic_qconfig,
            # Child-level entry that takes precedence over the 'sub2' entry.
            'sub2.fc1': custom_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dynamic_dict)

        convert_dynamic(model)

        def checkQuantized(model):
            # 'sub1' is not in the qconfig dict, so it must not be converted.
            self.checkLinear(model.sub1.fc)
            self.checkDynamicQuantizedLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            self.checkDynamicQuantizedLinear(model.fc3)
            self.checkScriptable(model, self.calib_data, check_save_load=True)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
        checkQuantized(model)
Пример #2
0
    def test_type_match_rule(self):
        r"""Test type-based qconfig matching combined with name-based opt-outs.

        The ``torch.nn.Linear`` key applies ``default_dynamic_qconfig`` to every
        Linear submodule of the nested model. The explicit ``None`` entries for
        the top-level 'fc3' and for 'fc1' of submodule 'sub2' exclude those two
        modules. Expected result: 'sub1.fc' and 'sub2.fc2' are dynamically
        quantized, while 'fc3' and 'sub2.fc1' remain float ``nn.Linear``.
        """
        model = NestedModel().eval()
        qconfig_dict = {
            # Name-based entries opting these modules out of quantization.
            'fc3': None,
            'sub2.fc1': None,
            # Type-based entry matching every remaining nn.Linear.
            torch.nn.Linear: default_dynamic_qconfig
        }

        prepare_dynamic(model, qconfig_dict)
        # NOTE(review): calibration is presumably not required for dynamic
        # quantization; this forward pass at least checks that the prepared
        # model still runs. Confirm before removing.
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc)
            self.checkLinear(model.fc3)
            self.checkLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data, check_save_load=True)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)
Пример #3
0
    def test_nested1(self):
        r"""Dynamic quantization of a nested model where only the top-level
        'fc3' and the 'fc1' child of 'sub2' are configured.

        The unconfigured 'sub1.fc' and 'sub2.fc2' must remain float
        ``nn.Linear`` modules.
        """
        qconfig_dict = {
            'fc3': default_dynamic_qconfig,
            'sub2.fc1': default_dynamic_qconfig,
        }

        def verify(quantized):
            # Ordered (checker, submodule) expectations.
            expectations = [
                (self.checkLinear, quantized.sub1.fc),
                (self.checkDynamicQuantizedLinear, quantized.fc3),
                (self.checkDynamicQuantizedLinear, quantized.sub2.fc1),
                (self.checkLinear, quantized.sub2.fc2),
            ]
            for check, submodule in expectations:
                check(submodule)
            self.checkScriptable(quantized, self.calib_data,
                                 check_save_load=True)

        # Two-step API: prepare, then convert.
        nested = NestedModel().eval()
        prepare_dynamic(nested, qconfig_dict)
        convert_dynamic(nested)
        verify(nested)

        # One-line API on a fresh float model must give the same result.
        verify(quantize_dynamic(NestedModel().eval(), qconfig_dict))
Пример #4
0
    def test_nested2(self):
        r"""Dynamic quantization where a qconfig attached to submodule 'sub2'
        applies to all of its children, alongside a direct qconfig for 'fc3'.

        'sub1' has no entry, so its Linear and ReLU must remain float modules.
        """
        qconfig_dict = dict(fc3=default_dynamic_qconfig,
                            sub2=default_dynamic_qconfig)

        def assert_sub2_and_fc3_quantized(m):
            # The unconfigured part of the model is left untouched.
            self.checkLinear(m.sub1.fc)
            self.assertEqual(type(m.sub1.relu), torch.nn.ReLU)
            # Every child of 'sub2', plus 'fc3', is dynamically quantized.
            for quantized_child in (m.sub2.fc1, m.sub2.fc2, m.fc3):
                self.checkDynamicQuantizedLinear(quantized_child)
            self.checkScriptable(m, self.calib_data, check_save_load=True)

        nested = NestedModel().eval()
        prepare_dynamic(nested, qconfig_dict)
        convert_dynamic(nested)
        assert_sub2_and_fc3_quantized(nested)

        # Same expectations through the one-line API.
        assert_sub2_and_fc3_quantized(
            quantize_dynamic(NestedModel().eval(), qconfig_dict))
Пример #5
0
    def test_nested3(self):
        r"""Dynamic quantization with a nested qconfig override: the entry for
        the child 'sub2.fc1' takes precedence over the one for its parent
        'sub2'.
        """
        model = NestedModel().eval()
        activation_kwargs = dict(dtype=torch.quint8,
                                 qscheme=torch.per_tensor_affine)
        custom_qconfig = QConfig(
            weight=default_weight_observer(),
            activation=default_observer(**activation_kwargs),
        )
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig,
        }
        model = prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def assert_configured_layers_quantized(m):
            for layer in (m.sub2.fc1, m.sub2.fc2, m.fc3):
                self.checkDynamicQuantizedLinear(layer)

        assert_configured_layers_quantized(model)

        # The one-line API must produce the same quantized layout.
        assert_configured_layers_quantized(
            quantize_dynamic(NestedModel().eval(), qconfig_dict))
Пример #6
0
    def test_nested3(self):
        r"""Static (prepare/calibrate/convert) quantization of a nested model
        where a child's qconfig overrides its parent's.

        'fc3' and 'sub2' use ``default_qconfig``; the child 'sub2.fc1'
        overrides the inherited 'sub2' entry with ``custom_qconfig``. 'sub1'
        has no entry, so its modules must remain float.
        """
        model = NestedModel().eval()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_weight_observer(),
                                 activation=default_observer(**custom_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            # Child-level entry that takes precedence over the 'sub2' entry.
            'sub2.fc1': custom_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            # Observers are only checked on the freshly prepared model.
            if before_calib:
                self.checkObservers(model)
            # The root, 'sub1' and its children, and the 'sub2' container are
            # expected to have no prep modules; the configured Linear leaves are.
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, before_calib=True)

        # Run the calibration data through the prepared model, then convert.
        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            # 'sub1' is not in the qconfig dict, so its Linear must stay float.
            self.checkLinear(model.sub1.fc)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            # The converted model must still run on the calibration data.
            test_only_eval_fn(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel().eval(), test_only_eval_fn,
                         self.calib_data, qconfig_dict)
        checkQuantized(model)
Пример #7
0
    def test_nested2(self):
        r"""Static quantization with a qconfig on submodule 'sub2' (inherited
        by all of its children) plus a direct qconfig on 'fc3'.

        Quantizing whole submodules this way leaves redundant quant/dequant
        pairs in the model. Removing them requires wrapping manually with
        QuantWrapper or inserting QuantStub/DeQuantStub; see
        `test_quant_dequant_wrapper` and `test_manual`.
        """
        qconfig_dict = {'fc3': default_qconfig, 'sub2': default_qconfig}

        def check_prep_modules(m, before_calib=False):
            if before_calib:
                self.checkObservers(m)
            for untouched in (m, m.sub1, m.sub1.fc, m.sub1.relu, m.sub2):
                self.checkNoPrepModules(untouched)
            for prepared_leaf in (m.sub2.fc1, m.sub2.fc2, m.fc3):
                self.checkHasPrepModules(prepared_leaf)

        def check_quantized(m):
            check_prep_modules(m)
            # Unconfigured 'sub1' keeps its float modules.
            self.checkLinear(m.sub1.fc)
            self.assertEqual(type(m.sub1.relu), torch.nn.ReLU)
            for quantized_leaf in (m.sub2.fc1, m.sub2.fc2, m.fc3):
                self.checkQuantizedLinear(quantized_leaf)
            test_only_eval_fn(m, self.calib_data)

        prepared = prepare(NestedModel().eval(), qconfig_dict)
        check_prep_modules(prepared, before_calib=True)
        test_only_eval_fn(prepared, self.calib_data)
        convert(prepared)
        check_quantized(prepared)

        # One-line API: prepare + calibrate + convert in a single call.
        check_quantized(quantize(NestedModel().eval(), test_only_eval_fn,
                                 self.calib_data, qconfig_dict))
Пример #8
0
    def test_qconfig_dict(self):
        r"""Check that script-mode (TorchScript) quantization driven by a
        ``qconfig_dict`` matches eager-mode quantization.

        In both modes only 'fc3' and 'sub2.fc1' are configured. After
        calibrating on the same single input, the result of the eager
        quantized module must equal the script module's result. The script
        module is quantized by running the constant-propagation,
        insert-observers, insert-quant-dequant and quant-fusion passes.
        """
        # A single (input, target) pair; only the input ``data[0][0]`` is fed
        # to the script module below.
        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]

        # Eager mode: attach the same qconfig to the two layers that the
        # script-mode qconfig_dict below also targets.
        qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = AnnotatedNestedModel()
        eager_module.fc3.qconfig = qconfig
        eager_module.sub2.fc1.qconfig = qconfig
        # Assign deterministic weights (all ones). 'sub2.fc1' and 'fc3' are
        # reached through ``.module``, presumably a wrapper added by
        # AnnotatedNestedModel (confirm against its definition).
        eager_module.sub1.fc.weight.data.fill_(1.0)
        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
        eager_module.sub2.fc2.weight.data.fill_(1.0)
        eager_module.fc3.module.weight.data.fill_(1.0)

        script_module = torch.jit.script(NestedModel())
        # Give the script module the eager module's weights so both start
        # from identical values. No explicit clone is made here, so the
        # tensors are likely shared; verify if that matters.
        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
        script_module.fc3.weight = eager_module.fc3.module.weight

        # Quantize the eager module, calibrating with default_eval_fn on data.
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            # Return the underlying ``forward`` method of a script module.
            return m._c._get_method('forward')

        # Quantize script_module. First, run constant propagation on the
        # forward graph.
        torch._C._jit_pass_constant_propagation(
            get_forward(script_module).graph)

        # Script the observer classes so they can be handed to the JIT pass
        # via their ``._c`` handles.
        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        scripted_qconfig = QConfig(activation=ScriptedObserver._c,
                                   weight=ScriptedWeightObserver._c)
        # Keys are qualified submodule names, mirroring the eager config above.
        qconfig_dict = {'sub2.fc1': scripted_qconfig, 'fc3': scripted_qconfig}
        torch._C._jit_pass_insert_observers(script_module._c, "forward",
                                            qconfig_dict)

        # Run script_module once so the inserted observers collect statistics.
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls; the pass's return value
        # replaces ``script_module._c``.
        script_module._c = torch._C._jit_pass_insert_quant_dequant(
            script_module._c, "forward")
        # Note that observer modules are not removed right now.
        # Fuse the inserted quant/dequant patterns in the forward graph.
        torch._C._jit_pass_quant_fusion(
            script_module._c._get_method('forward').graph)
        # NOTE(review): this call's result is discarded; it looks like a
        # warm-up run before the compared call below. Confirm it is needed.
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)
Пример #9
0
    def test_nested1(self):
        r"""Static quantization of a nested model where only the top-level
        'fc3' and the 'fc1' child of 'sub2' are configured.

        The unconfigured 'sub1.fc' and 'sub2.fc2' must stay float Linear
        modules.
        """
        qconfig_dict = {'fc3': default_qconfig, 'sub2.fc1': default_qconfig}
        model = NestedModel().eval()

        def check_prep_modules(m, before_calib=False):
            if before_calib:
                self.checkObservers(m)
            no_prep = self.checkNoPrepModules
            has_prep = self.checkHasPrepModules
            # Ordered (checker, submodule) expectations.
            expectations = (
                (no_prep, m),
                (no_prep, m.sub1),
                (no_prep, m.sub1.fc),
                (no_prep, m.sub1.relu),
                (no_prep, m.sub2),
                (has_prep, m.sub2.fc1),
                (no_prep, m.sub2.fc2),
                (has_prep, m.fc3),
            )
            for check, submodule in expectations:
                check(submodule)

        def check_quantized(m):
            check_prep_modules(m)
            self.checkLinear(m.sub1.fc)
            self.checkQuantizedLinear(m.fc3)
            self.checkQuantizedLinear(m.sub2.fc1)
            self.checkLinear(m.sub2.fc2)
            # The converted model must still run on the calibration data.
            test_only_eval_fn(m, self.calib_data)

        prepared = prepare(model, qconfig_dict)
        check_prep_modules(prepared, before_calib=True)
        test_only_eval_fn(prepared, self.calib_data)
        convert(prepared)
        check_quantized(prepared)

        # One-line API: prepare + calibrate + convert in a single call.
        check_quantized(quantize(NestedModel().eval(), test_only_eval_fn,
                                 self.calib_data, qconfig_dict))