def test_static_quantized_inference(self):
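     # Simulate static (EMA-mode) quantized inference by hand: int8 inputs and
     # weights, an int32 bias and accumulator, and a requantized int8 output,
     # then check the layer's own eval-mode forward pass against it.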
     qlinear = QuantizedLinear(10, 5, mode='EMA')
     weight = qlinear.weight.data.detach()
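     # Assumed helper semantics, inferred from the assertions below:
     # get_dynamic_scale(t, bits) returns a scale mapping t.abs().max() into
     # the signed range, and quantize_np(t, scale, bits) rounds t * scale and
     # clamps it to [-(2**(bits - 1) - 1), 2**(bits - 1) - 1].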
     weight_scale = get_dynamic_scale(weight, 8)
     weight_int = quantize_np(weight, weight_scale, 8)
     self.assertTrue((weight_int == torch.round(weight_int)).all())
     self.assertTrue(weight_int.abs().max() <= 127)
     x = torch.randn(3, 10) * 2**0.5 - 0.36
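     # Static quantization uses fixed clipping thresholds; each scale maps
     # its threshold onto the int8 limit (127 / threshold).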
     x_thresh = 3.
     output_thresh = 2.3
     output_scale = 127. / output_thresh
     x_scale = 127. / x_thresh
     qlinear.input_thresh = torch.tensor(x_thresh)
     qlinear.output_thresh = torch.tensor(output_thresh)
     x_int = quantize_np(x, x_scale, 8)
     self.assertTrue((x_int == torch.round(x_int)).all())
     self.assertTrue(x_int.abs().max() <= 127)
     bias = qlinear.bias.data
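     # The bias is quantized to int32 at the accumulator's scale,
     # x_scale * weight_scale.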
     bias_scale = x_scale * weight_scale
     bias_int = quantize_np(bias, bias_scale, 32)
     self.assertTrue((bias_int == torch.round(bias_int)).all())
     self.assertTrue(bias_int.abs().max() <= 2**(32 - 1) - 1)
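     # Integer GEMM: multiply-accumulate in int32 and saturate at the signed
     # 32-bit limits.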
     output_int = x_int @ weight_int.t() + bias_int
     output_int = torch.clamp(output_int, -(2**(32 - 1) - 1),
                              2**(32 - 1) - 1)
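     # Requantize the int32 accumulator to the int8 output domain, then
     # dequantize back to float.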
     output = torch.round(output_int / bias_scale * output_scale).clamp(
         -127, 127) / output_scale
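     # The layer's eval-mode forward pass should reproduce the simulation.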
     qlinear.eval()
     qlinear_output = qlinear(x)
     self.assertTrue((qlinear_output - output).norm() < 1e-6)

def test_export_to_8bit_with_bias(self):
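     # mode_8bit controls what state_dict() exports: fp32 weight/bias when
     # off, the int8 weight plus the int32 bias and its scale when on.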
     qlinear = QuantizedLinear(10, 5, mode='EMA')
     qlinear.eval()
     state_dict = qlinear.state_dict()
     self.assertTrue('weight' in state_dict)
     self.assertTrue('bias' in state_dict)
     self.assertTrue('quantized_weight' not in state_dict)
     self.assertTrue('_quantized_bias' not in state_dict)
     self.assertTrue('bias_scale' not in state_dict)
     qlinear.mode_8bit = True
     state_dict = qlinear.state_dict()
     self.assertTrue('weight' not in state_dict)
     self.assertTrue('bias' not in state_dict)
     self.assertTrue('quantized_weight' in state_dict)
     self.assertTrue(state_dict['quantized_weight'].dtype == torch.int8)
     self.assertTrue('_quantized_bias' in state_dict)
     self.assertTrue(state_dict['_quantized_bias'].dtype == torch.int32)
     self.assertTrue('bias_scale' in state_dict)
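     # Turning mode_8bit back off restores the fp32 entries in the export.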
     qlinear.mode_8bit = False
     state_dict = qlinear.state_dict()
     self.assertTrue('weight' in state_dict)
     self.assertTrue('bias' in state_dict)
     self.assertTrue('quantized_weight' not in state_dict)
     self.assertTrue('_quantized_bias' not in state_dict)
     self.assertTrue('bias_scale' not in state_dict)

def test_training_and_inference_differences_dynamic(self):
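     # Dynamic mode without a bias: the training-time (fake-quantized)
     # forward and the eval-time forward should agree to within numerical
     # precision.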
     qlinear = QuantizedLinear(10, 5, bias=False)
     x = torch.randn(3, 10) * 2 + 0.1
     y = qlinear(x)
     qlinear.eval()
     y_hat = qlinear(x)
     self.assertTrue((y - y_hat).norm() < 1e-6)

def test_restrict_loading_to_train_model(self):
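     # A state dict exported in 8-bit mode cannot be loaded into a module
     # that is still in training mode.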
     exporter = QuantizedLinear(10, 5, mode='dynamic')
     exporter.eval()
     exporter.mode_8bit = True
     state_dict = exporter.state_dict()
     importer = QuantizedLinear(10, 5, mode='dynamic')
     with self.assertRaises(RuntimeError):
         importer.load_state_dict(state_dict, strict=False)

def test_train_block_when_loading_quantized_model(self):
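     # The 8-bit export contains no fp32 'weight' entry, so after loading it
     # (in eval mode) switching back to train() must raise.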
     exporter = QuantizedLinear(10, 5, mode='dynamic')
     exporter.eval()
     exporter.mode_8bit = True
     state_dict = exporter.state_dict()
     importer = QuantizedLinear(10, 5, mode='dynamic')
     importer.eval()
     importer.load_state_dict(state_dict, strict=False)
     with self.assertRaises(RuntimeError):
         importer.train()

def test_import_from_8bit_without_bias(self):
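     # Round-trip check: the importer starts from a different random init,
     # but after loading the int8 export its eval outputs match exactly.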
     exporter = QuantizedLinear(10, 5, bias=False, mode='dynamic')
     exporter.eval()
     exporter.mode_8bit = True
     state_dict = exporter.state_dict()
     exporter.mode_8bit = False
     importer = QuantizedLinear(10, 5, bias=False, mode='dynamic')
     self.assertTrue((exporter.weight != importer.weight).any())
     importer.eval()
     importer.load_state_dict(state_dict, strict=False)
     x = torch.randn(3, 10)
     self.assertTrue((exporter(x) == importer(x)).all())

def test_import_from_8bit_with_bias(self):
     # QuantizationMode dynamic
     exporter = QuantizedLinear(10, 5, mode='dynamic')
     exporter.eval()
     exporter.mode_8bit = True
     state_dict = exporter.state_dict()
     exporter.mode_8bit = False
     importer = QuantizedLinear(10, 5, mode='dynamic')
     self.assertTrue((exporter.weight != importer.weight).any())
     self.assertTrue((exporter.bias != importer.bias).any())
     importer.eval()
     importer.load_state_dict(state_dict, strict=False)
     x = torch.randn(3, 10)
     self.assertTrue((exporter(x) == importer(x)).all())
     # QuantizationMode ema
     exporter = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
     x = torch.randn(3, 10)
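     # A training-mode forward pass populates the EMA input threshold, which
     # must survive the 8-bit export/import round trip.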
     exporter(x)
     self.assertTrue(exporter.input_thresh != 0.)
     exporter.eval()
     exporter.mode_8bit = True
     state_dict = exporter.state_dict()
     exporter.mode_8bit = False
     importer = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
     self.assertTrue((exporter.weight != importer.weight).any())
     self.assertTrue((exporter.bias != importer.bias).any())
     importer.eval()
     importer.load_state_dict(state_dict, strict=False)
     self.assertTrue((exporter(x) == importer(x)).all())