    def test_multi_linear_model_without_per_channel(self):
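        """Check that the per-channel detector reports per-channel quantization
        as supported but unused for a model prepared with a per-tensor qconfig.
        """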
        torch.backends.quantized.engine = "qnnpack"

        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

        prepared_model = self._prepare_model_and_run_input(
            TwoLayerLinearModel(),
            q_config_mapping,
            TwoLayerLinearModel().get_example_inputs()[0],
        )

        # run the detector
        optims_str, per_channel_info = _detect_per_channel(prepared_model)
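        # optims_str is a human-readable summary of suggested optimizations;
        # per_channel_info records the backend and per-module per-channel status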

        # there should be optims possible
        self.assertNotEqual(
            optims_str,
            DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
        )
        self.assertEqual(per_channel_info["backend"], torch.backends.quantized.engine)
        self.assertEqual(len(per_channel_info["per_channel_status"]), 2)

        # for each linear layer, should be supported but not used
        for module_entry in per_channel_info["per_channel_status"].values():

            self.assertEqual(module_entry["per_channel_supported"], True)
            self.assertEqual(module_entry["per_channel_used"], False)

    def test_train_save_load_eval(self):
        r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
        During eval, we first call prepare_qat and conver on the model and then load the state_dict
        and compare results against original model
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
                    qengine)
                model = prepare_qat(model)

                # Snapshot the state_dict of the prepared (fake-quantized) model
                fq_state_dict = model.state_dict()

                # Train with fake-quant observers, then convert to a quantized model
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                # Save the quantized state_dict for the eval checks below
                quant_state_dict = model.state_dict()

                # Reference output from the freshly converted model
                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
                    qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()),
                                 set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
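                # Load the quantized state_dict saved after QAT; the converted
                # model should reproduce the reference output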
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check model created using prepare has same state dict as quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(
                    qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(set(model.state_dict().keys()),
                                 set(quant_state_dict.keys()))
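                # The QAT-quantized state_dict should also load into the
                # PTQ-converted model and reproduce the reference output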
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)