    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused, per_channel, qengine):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return
            per_channel = False
        with override_quantized_engine(qengine):
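            # Build the weight and quantize it: per-channel gives each output row its own
            # scale; per-tensor uses a single scale and zero_point.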
            W = torch.rand(out_features, in_features).float()
            if per_channel:
                scale_tensor = torch.ones(out_features, dtype=torch.double)
                zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
                for i in range(len(scale_tensor)):
                    scale_tensor[i] = (i + 1.0) / 255.0
                W_q = torch.quantize_per_channel(W,
                                                 scales=scale_tensor,
                                                 zero_points=zero_point_tensor,
                                                 axis=0,
                                                 dtype=torch.qint8)
            else:
                W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

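            # Quantized activation input, optional float bias, and the output
            # requantization parameters (scale, zero_point).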
            X = torch.rand(batch_size, in_features).float()
            X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
            B = torch.rand(out_features).float() if use_bias else None
            scale = 0.5
            zero_point = 3
            if use_fused:
                qlinear = nnq_fused.LinearReLU(in_features, out_features)
            else:
                qlinear = nnq.Linear(in_features, out_features)

            # Run module with default-initialized parameters.
            # This tests that the constructor is correct.
            qlinear(X_q)

            qlinear.set_weight_bias(W_q, B)
            # Simple round-trip test to ensure the weight()/set_weight_bias() API works
            self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)
            W_pack = qlinear._packed_params._packed_params

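            # Set the output requantization parameters, then run the quantized module.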
            qlinear.scale = float(scale)
            qlinear.zero_point = int(zero_point)
            Z_q = qlinear(X_q)
            # Check that the module implementation matches calling the
            # ops directly
            if use_fused:
                Z_ref = torch.ops.quantized.linear_relu(
                    X_q, W_pack, scale, zero_point)

                self.assertTrue('QuantizedLinearReLU' in str(qlinear))
            else:
                Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale,
                                                   zero_point)

                self.assertTrue('QuantizedLinear' in str(qlinear))
            self.assertEqual(Z_ref, Z_q)

            # Test serialization of quantized Linear Module using state_dict
            model_dict = qlinear.state_dict()
            self.assertEqual(model_dict['_packed_params.weight'], W_q)
            if use_bias:
                self.assertEqual(model_dict['_packed_params.bias'], B)
            b = io.BytesIO()
            torch.save(model_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in model_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            if use_fused:
                loaded_qlinear = nnq_fused.LinearReLU(in_features,
                                                      out_features)
            else:
                loaded_qlinear = nnq.Linear(in_features, out_features)
            loaded_qlinear.load_state_dict(loaded_dict)

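            # The module rebuilt from the state_dict must match the original:
            # packed weights, bias, quantization parameters, and outputs.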
            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(
                linear_unpack(qlinear._packed_params._packed_params),
                linear_unpack(loaded_qlinear._packed_params._packed_params))
            if use_bias:
                self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertTrue(hasattr(qlinear, '_packed_params'))
            self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
            self.assertTrue(hasattr(qlinear, '_weight_bias'))
            self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
            self.assertEqual(qlinear._weight_bias(),
                             loaded_qlinear._weight_bias())
            self.assertEqual(
                qlinear._weight_bias(),
                torch.ops.quantized.linear_unpack(
                    qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

            # The below check is meant to ensure that `torch.save` and `torch.load`
            # serialization works, however it is currently broken by the following:
            # https://github.com/pytorch/pytorch/issues/24045
            #
            # Instead, we currently check that the proper exception is thrown on save.
            # <start code>
            # b = io.BytesIO()
            # torch.save(qlinear, b)
            # b.seek(0)
            # loaded = torch.load(b)
            # self.assertEqual(qlinear.weight(), loaded.weight())
            # self.assertEqual(qlinear.scale, loaded.scale)
            # self.assertEqual(qlinear.zero_point, loaded.zero_point)
            # <end code>
            with self.assertRaisesRegex(
                    RuntimeError,
                    r'torch.save\(\) is not currently supported'):
                b = io.BytesIO()
                torch.save(qlinear, b)

            # Test JIT
            self.checkScriptable(qlinear,
                                 list(zip([X_q], [Z_ref])),
                                 check_save_load=True)

            # Test from_float.
            float_linear = torch.nn.Linear(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Wrapping in a Sequential allows swapping the module via "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(
                quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
    def test_linear_relu(self):
        module = nniq.LinearReLU(3, 1, bias=True, dtype=torch.qint8)
        self._test_op(module, input_size=[1, 3], generate=False)
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False
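        # Build the weight and quantize it: per-channel gives each output row its own
        # scale; per-tensor uses a single scale and zero_point.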
        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

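        # Quantized activation input, optional float bias, and the output
        # requantization parameters (scale, zero_point).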
        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure the weight()/set_weight_bias() API works
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)
        W_pack = qlinear._packed_params._packed_params

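        # Set the output requantization parameters, then run the quantized module.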
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check that the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinearReLU' in str(qlinear))
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinear' in str(qlinear))
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
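            # Packed params are stored as ScriptObjects; unpack them before comparing.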
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

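        # The module rebuilt from the state_dict must match the original:
        # packed weights, bias, quantization parameters, and outputs.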
        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(),
                         torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

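        # torch.save/torch.load round-trip of the whole module (the older variant above
        # only asserted the gh-24045 failure mode).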
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Wrapping in a Sequential allows swapping the module via "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
    def test_linear_relu(self):
        for i, qengine in enumerate(supported_qengines):
            with override_quantized_engine(qengine):
                module = nniq.LinearReLU(3, 1, bias=True, dtype=torch.qint8)
                self._test_op(module, input_size=[1, 3], generate=False, iter=i)