Exemplo n.º 1
0
    def test_single_layer(self):
        r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
        make sure it is swapped to nnqd.Linear which is the quantized version of
        the module.

        The same check is applied after each of four ways of quantizing a fresh
        model: ``prepare_dynamic`` followed by ``convert_dynamic``, the
        out-of-place ``quantize_dynamic``, the in-place ``quantize_dynamic``,
        and ``quantize_dynamic`` given a set of module types instead of a
        qconfig dict.
        """
        model = SingleLayerLinearDynamicModel().eval()
        qconfig_dict = {
            '': default_dynamic_qconfig
        }
        # The return values are discarded, so this relies on prepare_dynamic and
        # convert_dynamic modifying ``model`` in place.
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(quantized_model):
            # fc1 must be the dynamic quantized Linear; the model is then passed
            # to checkScriptable with check_save_load=True (presumably scripting
            # plus a save/load round trip — see that helper).
            self.checkDynamicQuantizedLinear(quantized_model.fc1)
            self.checkScriptable(quantized_model, self.calib_data, check_save_load=True)

        checkQuantized(model)

        # test one line API - out of place version
        base = SingleLayerLinearDynamicModel()
        # Iterating a state_dict yields its keys, so set() collects them directly.
        keys_before = set(base.state_dict())
        model = quantize_dynamic(base, qconfig_dict)
        checkQuantized(model)
        keys_after = set(base.state_dict())
        self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

        # in-place version
        model = SingleLayerLinearDynamicModel()
        quantize_dynamic(model, qconfig_dict, inplace=True)
        checkQuantized(model)

        # Test set qconfig: a set of module types instead of a qconfig dict
        model = SingleLayerLinearDynamicModel()
        quantize_dynamic(model, {nn.Linear}, inplace=True)
        checkQuantized(model)
Exemplo n.º 2
0
    def test_single_layer(self):
        r"""Check dynamic quantization of SingleLayerLinearDynamicModel.

        The model has one Linear module (``fc1``); after dynamic quantization
        it must have been swapped to nnqd.Linear, the quantized version of that
        module. Both the two-step prepare/convert flow and the one-line
        ``quantize_dynamic`` API are checked.
        """
        def assert_fc1_quantized(m):
            self.checkDynamicQuantizedLinear(m.fc1)

        qconfig_dict = {'': default_dynamic_qconfig}

        # Two-step flow: keep what prepare_dynamic returns, then convert it.
        # convert_dynamic's return value is discarded, as in the original.
        prepared = prepare_dynamic(SingleLayerLinearDynamicModel().eval(), qconfig_dict)
        convert_dynamic(prepared)
        assert_fc1_quantized(prepared)

        # test one line API
        assert_fc1_quantized(
            quantize_dynamic(SingleLayerLinearDynamicModel().eval(), qconfig_dict)
        )