def test_static_quantized_inference(self):
    qlinear = QuantizedLinear(10, 5, mode="EMA")

    # Quantize the weight with a dynamic scale and verify it is a proper
    # signed 8-bit integer tensor.
    weight = qlinear.weight.data.detach()
    weight_scale = get_dynamic_scale(weight, 8)
    weight_int = quantize_np(weight, weight_scale, 8)
    self.assertTrue((weight_int == torch.round(weight_int)).all())
    self.assertTrue(weight_int.abs().max() <= 127)

    # Fix the input/output thresholds so the layer's static scales are known
    # in advance: scale = 127 / threshold.
    x = torch.randn(3, 10) * 2 ** 0.5 - 0.36
    x_thresh = 3.0
    output_thresh = 2.3
    output_scale = 127.0 / output_thresh
    x_scale = 127.0 / x_thresh
    qlinear.input_thresh = torch.tensor(x_thresh)
    qlinear.output_thresh = torch.tensor(output_thresh)

    # Quantize the input to signed 8 bits.
    x_int = quantize_np(x, x_scale, 8)
    self.assertTrue((x_int == torch.round(x_int)).all())
    self.assertTrue(x_int.abs().max() <= 127)

    # The bias is quantized to 32 bits with the product of the input and
    # weight scales, so it can be added directly to the int32 accumulator.
    bias = qlinear.bias.data
    bias_scale = x_scale * weight_scale
    bias_int = quantize_np(bias, bias_scale, 32)
    self.assertTrue((bias_int == torch.round(bias_int)).all())
    self.assertTrue(bias_int.abs().max() <= 2 ** (32 - 1) - 1)

    # Simulate the integer GEMM, clamp the accumulator to the int32 range,
    # requantize to 8 bits with the output scale, and dequantize to float.
    output_int = x_int @ weight_int.t() + bias_int
    output_int = torch.clamp(output_int, -(2 ** (32 - 1) - 1), 2 ** (32 - 1) - 1)
    output = torch.round(output_int / bias_scale * output_scale).clamp(-127, 127) / output_scale

    # The layer in eval mode must reproduce this reference computation.
    qlinear.eval()
    qlinear_output = qlinear(x)
    self.assertTrue((qlinear_output - output).norm() < 1e-6)
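# For reference, the symmetric linear quantization this test assumes for
# get_dynamic_scale/quantize_np (a sketch of the expected semantics, not the
# library implementation): the scale maps the tensor's max magnitude (or a
# fixed threshold) to the largest representable integer, and values are
# rounded and saturated, i.e.
#
#     scale = (2 ** (bits - 1) - 1) / x.abs().max()          # dynamic scale
#     quantize_np(x, scale, bits)
#         == torch.clamp(torch.round(x * scale),
#                        -(2 ** (bits - 1) - 1), 2 ** (bits - 1) - 1)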
def test_export_to_8bit_with_bias(self):
    qlinear = QuantizedLinear(10, 5, mode="EMA")
    qlinear.eval()

    # By default the state dict exposes only the full-precision parameters.
    state_dict = qlinear.state_dict()
    self.assertTrue("weight" in state_dict)
    self.assertTrue("bias" in state_dict)
    self.assertTrue("quantized_weight" not in state_dict)
    self.assertTrue("_quantized_bias" not in state_dict)
    self.assertTrue("bias_scale" not in state_dict)

    # Switching to 8-bit export replaces them with the quantized tensors
    # and their scale.
    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    self.assertTrue("weight" not in state_dict)
    self.assertTrue("bias" not in state_dict)
    self.assertTrue("quantized_weight" in state_dict)
    self.assertTrue(state_dict["quantized_weight"].dtype == torch.int8)
    self.assertTrue("_quantized_bias" in state_dict)
    self.assertTrue(state_dict["_quantized_bias"].dtype == torch.int32)
    self.assertTrue("bias_scale" in state_dict)

    # Switching back restores the full-precision view.
    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    self.assertTrue("weight" in state_dict)
    self.assertTrue("bias" in state_dict)
    self.assertTrue("quantized_weight" not in state_dict)
    self.assertTrue("_quantized_bias" not in state_dict)
    self.assertTrue("bias_scale" not in state_dict)
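# Typical use of the export switch exercised above (a usage sketch; the file
# name is illustrative): flip mode_8bit on an eval-mode layer and serialize
# the state dict, which then carries int8 weights, an int32 bias, and the
# scales needed at load time.
#
#     qlinear.eval()
#     qlinear.mode_8bit = True
#     torch.save(qlinear.state_dict(), "qlinear_int8.pt")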
def test_training_and_inference_differences_dynamic(self):
    qlinear = QuantizedLinear(10, 5, bias=False, mode="dynamic")
    x = torch.randn(3, 10) * 2 + 0.1
    y = qlinear(x)
    qlinear.eval()
    y_hat = qlinear(x)
    self.assertTrue((y - y_hat).norm() < 1e-6)
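# Note: with dynamic quantization the scales are recomputed from the tensors
# on every call, so the train and eval forward passes run the same
# fake-quantization arithmetic; the assertion above checks exactly that.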
def test_restrict_loading_to_train_model(self):
    # An 8-bit state dict may only be loaded into a model in eval mode;
    # loading into a model still in training mode must fail.
    exporter = QuantizedLinear(10, 5, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()

    importer = QuantizedLinear(10, 5, mode="dynamic")
    with self.assertRaises(RuntimeError):
        importer.load_state_dict(state_dict, strict=False)
def test_train_block_when_loading_quantized_model(self):
    exporter = QuantizedLinear(10, 5, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()

    # Once 8-bit weights are loaded there is no full-precision master copy
    # left to update, so switching back to training mode must raise.
    importer = QuantizedLinear(10, 5, mode="dynamic")
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    with self.assertRaises(RuntimeError):
        importer.train()
def test_import_from_8bit_without_bias(self):
    exporter = QuantizedLinear(10, 5, bias=False, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False

    # A freshly initialized layer starts from different weights ...
    importer = QuantizedLinear(10, 5, bias=False, mode="dynamic")
    self.assertTrue((exporter.weight != importer.weight).any())

    # ... but after loading the 8-bit export it must produce identical outputs.
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())
def test_import_from_8bit_with_bias(self):
    # QuantizationMode: dynamic
    exporter = QuantizedLinear(10, 5, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False

    importer = QuantizedLinear(10, 5, mode="dynamic")
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())

    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())

    # QuantizationMode: EMA. Run one forward pass in training mode so the
    # EMA input threshold is populated before exporting.
    exporter = QuantizedLinear(10, 5, requantize_output=False, mode="ema")
    x = torch.randn(3, 10)
    exporter(x)
    self.assertTrue(exporter.input_thresh != 0.0)

    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False

    importer = QuantizedLinear(10, 5, requantize_output=False, mode="ema")
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())

    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    self.assertTrue((exporter(x) == importer(x)).all())
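# The full export/import round trip sketched by the two import tests above
# (illustrative flow; the file name is an assumption, not part of the suite):
#
#     exporter.eval()                 # export only works in eval mode
#     exporter.mode_8bit = True
#     torch.save(exporter.state_dict(), "quantized.pt")
#
#     importer = QuantizedLinear(10, 5, mode="dynamic")
#     importer.eval()                 # must be in eval mode before loading
#     importer.load_state_dict(torch.load("quantized.pt"), strict=False)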