    def test_scriptability_serialization(self):
        # test serialization of quantized functional modules
        with tempfile.TemporaryFile() as f:
            torch.save(self.qmodel_under_test, f)
            f.seek(0)
            loaded = torch.load(f)
        self.assertEqual(self.qmodel_under_test.myadd.zero_point, loaded.myadd.zero_point)
        state_dict = self.qmodel_under_test.state_dict()
        self.assertTrue('myadd.zero_point' in state_dict.keys(),
                        'zero point not in state dict for functional modules')

        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_linear(x, 1.0, 0, torch.qint8)
        self.checkScriptable(self.qmodel_under_test, [(xq, xq)], check_save_load=True)
        self.checkScriptable(self.model_under_test, [(xq.dequantize(), xq.dequantize())], check_save_load=True)
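# These snippets are test methods lifted from a larger test class: they rely on
# `self`, checkScriptable, and pre-built fixtures such as qmodel_under_test.
# Below is a minimal sketch of the module aliases the code appears to assume,
# based on the API names used here (an assumption, not the original file's
# import block):
import tempfile

import torch
import torch.nn.quantized as nnq                    # assumed alias for nnq
import torch.nn._intrinsic.quantized as nnq_fused   # assumed alias for nnq_fused
import torch.nn.quantized.dynamic as nnqd           # assumed alias for nnqd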
Example #2
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused):
        """test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu"""
        W = torch.rand(out_features, in_features).float()
        W_q = torch.quantize_linear(W, 0.1, 4, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        B_q = torch.quantize_linear(B,
                                    W_q.q_scale() * X_q.q_scale(), 0,
                                    torch.qint32) if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight(W_q)
        # Simple round-trip test to ensure the weight()/set_weight() API works
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        qlinear.bias = B_q if use_bias else None

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.fbgemm_linear_relu(
                X_q, W_pack, B_q, scale, zero_point)
        else:
            Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, scale,
                                                      zero_point)
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict

        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B_q)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))
        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.bias, loaded.bias)
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X_q], [Z_ref])),
                             check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear)
        float_linear(X.float())
        # Wrap in Sequential so that convert() can swap the child module in place.
        quantized_float_linear = torch.nn.Sequential(float_linear)
        torch.quantization.convert(quantized_float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        str(quantized_float_linear)
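# The extra arguments in test_linear_api's signature (batch_size, in_features,
# out_features, use_bias, use_fused) suggest a property-based driver supplies
# the values. A hedged sketch of such a driver using hypothesis is shown below;
# the strategy bounds and decorator placement are illustrative assumptions, not
# the original harness:
#
#     from hypothesis import given
#     from hypothesis import strategies as st
#
#     @given(batch_size=st.integers(1, 4),
#            in_features=st.integers(16, 32),
#            out_features=st.integers(4, 8),
#            use_bias=st.booleans(),
#            use_fused=st.booleans())
#     def test_linear_api(self, batch_size, in_features, out_features,
#                         use_bias, use_fused):
#         ...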
Example #3
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_linear(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X)
        qlinear.set_weight(W_q)

        # Simple round-trip test to ensure the weight()/set_weight() API works
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_weight
        qlinear.bias = B if use_bias else None
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.fbgemm_linear_dynamic(X, W_pack, B)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        with tempfile.TemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        if use_bias:
            self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))

        self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
        self.assertEqual(
            qlinear.weight(),
            torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
        Z_dq2 = loaded_qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # test serialization of module directly
        with tempfile.TemporaryFile() as f:
            torch.save(qlinear, f)
            f.seek(0)
            loaded = torch.load(f)
        # This check is disabled pending an issue in PyTorch serialization:
        # https://github.com/pytorch/pytorch/issues/24045
        # self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X], [Z_ref])),
                             check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
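
# _calculate_dynamic_qparams above is a helper defined elsewhere in the test
# suite. The sketch below shows one common way such dynamic qparams are derived
# (an affine min/max mapping over the signed 8-bit range); it is an illustrative
# reference, not necessarily the exact helper used by the tests:
def _reference_dynamic_qparams(x, qmin=-128, qmax=127):
    # Include zero in the range so that 0.0 is exactly representable.
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    scale = max((x_max - x_min) / (qmax - qmin), 1e-8)  # avoid a zero scale
    zero_point = qmin - round(x_min / scale)
    zero_point = min(max(zero_point, qmin), qmax)       # clamp into [qmin, qmax]
    return scale, int(zero_point)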