def test_scriptability_serialization(self):
    """Check that the quantized functional model serializes and scripts.

    The quantized model under test is written with ``torch.save`` and read
    back with ``torch.load``; the ``myadd.zero_point`` value must match and
    the key must appear in the model's state dict.  Both the quantized and
    the float model are then passed through ``checkScriptable`` with
    save/load checking enabled.
    """
    qmodel = self.qmodel_under_test

    # Round-trip the whole module through a scratch file.
    with tempfile.TemporaryFile() as buf:
        torch.save(qmodel, buf)
        buf.seek(0)
        restored = torch.load(buf)
    self.assertEqual(qmodel.myadd.zero_point, restored.myadd.zero_point)

    saved_state = qmodel.state_dict()
    self.assertTrue('myadd.zero_point' in saved_state,
                    'zero point not in state dict for functional modules')

    # The same quantized tensor is used for both operands.
    float_input = torch.rand(10, 1, dtype=torch.float)
    quant_input = torch.quantize_linear(float_input, 1.0, 0, torch.qint8)
    self.checkScriptable(
        qmodel,
        [(quant_input, quant_input)],
        check_save_load=True)
    self.checkScriptable(
        self.model_under_test,
        [(quant_input.dequantize(), quant_input.dequantize())],
        check_save_load=True)
def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused):
    """Exercise the quantized Linear module API (plain or fused with ReLU).

    Steps, in order: run a default-constructed module, round-trip the weight
    through ``set_weight()``/``weight()``, compare the module output with
    the matching ``torch.ops.quantized`` op, serialize via ``state_dict``
    and via ``torch.save`` of the module itself, check scriptability, and
    smoke-test conversion of a float ``torch.nn.Linear``.

    Args:
        batch_size: first dimension of the random input.
        in_features: input width of the linear layer.
        out_features: output width of the linear layer.
        use_bias: when true, a quantized (qint32) bias is attached.
        use_fused: when true, test ``nnq_fused.LinearReLU`` and the
            ``fbgemm_linear_relu`` op instead of ``nnq.Linear`` and
            ``fbgemm_linear``.
    """
    # Random float data and its quantized form (weight, input, then bias).
    weight_fp = torch.rand(out_features, in_features).float()
    weight_q = torch.quantize_linear(weight_fp, 0.1, 4, torch.qint8)
    input_fp = torch.rand(batch_size, in_features).float()
    input_q = torch.quantize_linear(input_fp, 0.2, 10, torch.quint8)
    if use_bias:
        bias_fp = torch.rand(out_features).float()
        bias_q = torch.quantize_linear(
            bias_fp, weight_q.q_scale() * input_q.q_scale(), 0, torch.qint32)
    else:
        bias_q = None
    out_scale = 0.5
    out_zero_point = 3

    # The same class is used for the module under test and the reloaded one.
    if use_fused:
        module_cls = nnq_fused.LinearReLU
    else:
        module_cls = nnq.Linear
    qlinear = module_cls(in_features, out_features)

    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(input_q)

    # Simple round-trip test to ensure weight()/set_weight() API
    qlinear.set_weight(weight_q)
    self.assertEqual(qlinear.weight(), weight_q)
    packed_weight = qlinear._packed_weight

    qlinear.bias = bias_q
    qlinear.scale = float(out_scale)
    qlinear.zero_point = int(out_zero_point)
    output_q = qlinear(input_q)

    # The module implementation must match calling the op directly.
    if use_fused:
        reference_op = torch.ops.quantized.fbgemm_linear_relu
    else:
        reference_op = torch.ops.quantized.fbgemm_linear
    output_ref = reference_op(
        input_q, packed_weight, bias_q, out_scale, out_zero_point)
    self.assertEqual(output_ref, output_q)

    # Serialization of the quantized Linear module using state_dict
    state = qlinear.state_dict()
    self.assertEqual(state['weight'], weight_q)
    if use_bias:
        self.assertEqual(state['bias'], bias_q)
    with tempfile.TemporaryFile() as buf:
        torch.save(state, buf)
        buf.seek(0)
        restored_state = torch.load(buf)
    for key, value in state.items():
        self.assertEqual(value, restored_state[key])

    loaded_qlinear = module_cls(in_features, out_features)
    loaded_qlinear.load_state_dict(restored_state)

    unpack = torch.ops.quantized.fbgemm_linear_unpack
    self.assertEqual(unpack(qlinear._packed_weight),
                     unpack(loaded_qlinear._packed_weight))
    if use_bias:
        self.assertEqual(qlinear.bias, loaded_qlinear.bias)
    self.assertEqual(qlinear.scale, loaded_qlinear.scale)
    self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
    self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
    for attr_name in ('_packed_weight', 'weight'):
        for module in (qlinear, loaded_qlinear):
            self.assertTrue(hasattr(module, attr_name))
    self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
    self.assertEqual(qlinear.weight(), unpack(qlinear._packed_weight))
    output_q_loaded = loaded_qlinear(input_q)
    self.assertEqual(output_q, output_q_loaded)

    # Serialization of the module object directly
    with tempfile.TemporaryFile() as buf:
        torch.save(qlinear, buf)
        buf.seek(0)
        restored_module = torch.load(buf)
    # The weight comparison between qlinear and the restored module is
    # disabled pending an issue in PyTorch serialization:
    # https://github.com/pytorch/pytorch/issues/24045
    self.assertEqual(qlinear.bias, restored_module.bias)
    self.assertEqual(qlinear.scale, restored_module.scale)
    self.assertEqual(qlinear.zero_point, restored_module.zero_point)

    # Test JIT
    self.checkScriptable(qlinear, [(input_q, output_ref)], check_save_load=True)

    # Test from_float.
    float_linear = torch.nn.Linear(in_features, out_features).float()
    float_linear.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(float_linear)
    float_linear(input_fp.float())

    # Sequential allows swapping using "convert".
    quantized_float_linear = torch.nn.Sequential(float_linear)
    torch.quantization.convert(quantized_float_linear)

    # Smoke test to make sure the module actually runs
    quantized_float_linear(input_q)

    # Smoke test extra_repr
    str(quantized_float_linear)
def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
    """Test API functionality for nn.quantized.dynamic.Linear.

    Steps, in order: run a default-constructed module, round-trip the weight
    through ``set_weight()``/``weight()``, compare the module output with
    ``torch.ops.quantized.fbgemm_linear_dynamic``, serialize via
    ``state_dict`` and via ``torch.save`` of the module itself, check
    scriptability, and smoke-test ``nnqd.Linear.from_float``.

    Args:
        batch_size: first dimension of the random float input.
        in_features: input width of the linear layer.
        out_features: output width of the linear layer.
        use_bias: when true, a random float bias is attached to the module.
        use_default_observer: when true, the float module is given
            ``torch.quantization.default_dynamic_qconfig`` before
            ``from_float`` is called.
    """
    W = torch.rand(out_features, in_features).float()
    W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
    W_q = torch.quantize_linear(W, W_scale, W_zp, torch.qint8)
    X = torch.rand(batch_size, in_features).float()
    B = torch.rand(out_features).float() if use_bias else None
    qlinear = nnqd.Linear(in_features, out_features)
    # Run module with default-initialized parameters.
    # This tests that the constructor is correct.
    qlinear(X)
    qlinear.set_weight(W_q)
    # Simple round-trip test to ensure weight()/set_weight() API
    self.assertEqual(qlinear.weight(), W_q)
    W_pack = qlinear._packed_weight
    # B is already None when use_bias is false.
    qlinear.bias = B
    Z_dq = qlinear(X)

    # Check if the module implementation matches calling the
    # ops directly
    Z_ref = torch.ops.quantized.fbgemm_linear_dynamic(X, W_pack, B)
    self.assertEqual(Z_ref, Z_dq)

    # Test serialization of dynamic quantized Linear Module using state_dict
    model_dict = qlinear.state_dict()
    self.assertEqual(model_dict['weight'], W_q)
    if use_bias:
        self.assertEqual(model_dict['bias'], B)
    with tempfile.TemporaryFile() as f:
        torch.save(model_dict, f)
        f.seek(0)
        loaded_dict = torch.load(f)
    for key in model_dict:
        self.assertEqual(model_dict[key], loaded_dict[key])
    loaded_qlinear = nnqd.Linear(in_features, out_features)
    loaded_qlinear.load_state_dict(loaded_dict)

    linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
    self.assertEqual(linear_unpack(qlinear._packed_weight),
                     linear_unpack(loaded_qlinear._packed_weight))
    if use_bias:
        self.assertEqual(qlinear.bias, loaded_qlinear.bias)
    self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
    self.assertTrue(hasattr(qlinear, '_packed_weight'))
    self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
    self.assertTrue(hasattr(qlinear, 'weight'))
    self.assertTrue(hasattr(loaded_qlinear, 'weight'))
    self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
    self.assertEqual(
        qlinear.weight(),
        torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
    # Run the *reloaded* module: comparing qlinear against itself would not
    # verify that the state_dict round trip preserved forward behavior.
    Z_dq2 = loaded_qlinear(X)
    self.assertEqual(Z_dq, Z_dq2)

    # test serialization of module directly
    with tempfile.TemporaryFile() as f:
        torch.save(qlinear, f)
        f.seek(0)
        loaded = torch.load(f)
    # This check is disabled pending an issue in PyTorch serialization:
    # https://github.com/pytorch/pytorch/issues/24045
    # self.assertEqual(qlinear.weight(), loaded.weight())
    self.assertEqual(qlinear.zero_point, loaded.zero_point)

    # Test JIT
    self.checkScriptable(qlinear, list(zip([X], [Z_ref])), check_save_load=True)

    # Test from_float
    float_linear = torch.nn.Linear(in_features, out_features).float()
    # NOTE(review): only the qconfig assignment is treated as conditional;
    # the remaining steps run for both settings -- confirm intended scope.
    if use_default_observer:
        float_linear.qconfig = torch.quantization.default_dynamic_qconfig
    prepare_dynamic(float_linear)
    float_linear(X.float())
    quantized_float_linear = nnqd.Linear.from_float(float_linear)

    # Smoke test to make sure the module actually runs
    quantized_float_linear(X)

    # Smoke test extra_repr
    str(quantized_float_linear)