def test_import_from_8bit_with_bias(self):
    # QuantizationMode dynamic
    exporter = QuantizedLinear(10, 5, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, mode='dynamic')
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())
    # QuantizationMode ema
    exporter = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    x = torch.randn(3, 10)
    exporter(x)
    self.assertTrue(exporter.input_thresh != 0.)
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    self.assertTrue((exporter(x) == importer(x)).all())
def test_restrict_loading_to_train_model(self):
    exporter = QuantizedLinear(10, 5, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    importer = QuantizedLinear(10, 5, mode='dynamic')
    with self.assertRaises(RuntimeError):
        importer.load_state_dict(state_dict, strict=False)
def test_train_block_when_loading_quantized_model(self):
    exporter = QuantizedLinear(10, 5, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    importer = QuantizedLinear(10, 5, mode="dynamic")
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    with self.assertRaises(RuntimeError):
        importer.train()
def test_import_from_8bit_without_bias(self):
    exporter = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    self.assertTrue((exporter.weight != importer.weight).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())
def test_static_quantized_inference(self):
    qlinear = QuantizedLinear(10, 5, mode="EMA")
    weight = qlinear.weight.data.detach()
    weight_scale = get_dynamic_scale(weight, 8)
    weight_int = quantize_np(weight, weight_scale, 8)
    self.assertTrue((weight_int == torch.round(weight_int)).all())
    self.assertTrue(weight_int.abs().max() <= 127)
    x = torch.randn(3, 10) * 2 ** 0.5 - 0.36
    x_thresh = 3.
    output_thresh = 2.3
    output_scale = 127. / output_thresh
    x_scale = 127. / x_thresh
    qlinear.input_thresh = torch.tensor(x_thresh)
    qlinear.output_thresh = torch.tensor(output_thresh)
    x_int = quantize_np(x, x_scale, 8)
    self.assertTrue((x_int == torch.round(x_int)).all())
    self.assertTrue(x_int.abs().max() <= 127)
    bias = qlinear.bias.data
    bias_scale = x_scale * weight_scale
    bias_int = quantize_np(bias, bias_scale, 32)
    self.assertTrue((bias_int == torch.round(bias_int)).all())
    self.assertTrue(bias_int.abs().max() <= 2 ** (32 - 1) - 1)
    output_int = x_int @ weight_int.t() + bias_int
    output_int = torch.clamp(output_int, -(2 ** (32 - 1) - 1), 2 ** (32 - 1) - 1)
    output = torch.round(output_int / bias_scale * output_scale).clamp(
        -127, 127) / output_scale
    qlinear.eval()
    qlinear_output = qlinear(x)
    self.assertTrue((qlinear_output - output).norm() < 1e-6)
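# The tests in this module rely on integer-quantization helpers that are
# defined elsewhere (in the test module or the quantization library). The
# sketches below of `quantize_np` and `get_dynamic_scale` are assumptions
# inferred from how the assertions use them (symmetric rounding, clamped to
# +/-(2 ** (bits - 1) - 1)), not necessarily the canonical implementations.
def quantize_np(x, scale, bits):
    """Scale, round, and clamp `x` to the signed `bits`-bit integer range."""
    thresh = 2 ** (bits - 1) - 1
    return torch.clamp(torch.round(x * scale), -thresh, thresh)


def get_dynamic_scale(x, bits):
    """Scale mapping the largest absolute value of `x` onto that range."""
    return (2 ** (bits - 1) - 1) / x.abs().max()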
def test_export_to_8bit_with_bias(self):
    qlinear = QuantizedLinear(10, 5, mode='EMA')
    qlinear.eval()
    state_dict = qlinear.state_dict()
    self.assertTrue('weight' in state_dict)
    self.assertTrue('bias' in state_dict)
    self.assertTrue('quantized_weight' not in state_dict)
    self.assertTrue('_quantized_bias' not in state_dict)
    self.assertTrue('bias_scale' not in state_dict)
    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    self.assertTrue('weight' not in state_dict)
    self.assertTrue('bias' not in state_dict)
    self.assertTrue('quantized_weight' in state_dict)
    self.assertTrue(state_dict['quantized_weight'].dtype == torch.int8)
    self.assertTrue('_quantized_bias' in state_dict)
    self.assertTrue(state_dict['_quantized_bias'].dtype == torch.int32)
    self.assertTrue('bias_scale' in state_dict)
    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    self.assertTrue('weight' in state_dict)
    self.assertTrue('bias' in state_dict)
    self.assertTrue('quantized_weight' not in state_dict)
    self.assertTrue('_quantized_bias' not in state_dict)
    self.assertTrue('bias_scale' not in state_dict)
def test_training_and_inference_differences_dynamic(self):
    qlinear = QuantizedLinear(10, 5, bias=False)
    x = torch.randn(3, 10) * 2 + 0.1
    y = qlinear(x)
    qlinear.eval()
    y_hat = qlinear(x)
    self.assertTrue((y - y_hat).norm() < 1e-6)
def test_ema_quantization_data_parallel(self):
    if not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
        return
    ema_decay = 0.9
    fake_quantize = FakeLinearQuantizationWithSTE.apply
    qlinear = nn.DataParallel(
        QuantizedLinear(10, 5, bias=False, ema_decay=ema_decay,
                        mode="EMA")).cuda()
    for i in range(5):
        x = torch.randn(2, 10).cuda()
        tmp_input_thresh = x[0].abs().max()
        if i == 0:
            input_ema = tmp_input_thresh
        else:
            input_ema -= (1 - ema_decay) * (input_ema - tmp_input_thresh)
        y = (fake_quantize(x, get_scale(8, input_ema), 8)
             @ qlinear.module.fake_quantized_weight.t()).detach()
        tmp_output_thresh = y[0].abs().max()
        if i == 0:
            output_ema = tmp_output_thresh
        else:
            output_ema -= (1 - ema_decay) * (output_ema - tmp_output_thresh)
        qlinear(x)
        self.assertEqual(qlinear.module.input_thresh, input_ema)
        self.assertEqual(qlinear.module.output_thresh, output_ema)
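# `get_scale` and `fake_quantize_np` are likewise assumed helpers: a
# threshold-based scale and a quantize/dequantize round trip that emulates
# quantization error in floating point. Minimal sketches consistent with the
# EMA and backward tests, under the same caveat as the sketches above.
def get_scale(bits, threshold):
    """Scale mapping `threshold` onto the edge of the `bits`-bit range."""
    return (2 ** (bits - 1) - 1) / threshold


def fake_quantize_np(x, scale, bits):
    """Quantize and immediately dequantize, keeping the tensor in float."""
    return quantize_np(x, scale, bits) / scale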
def test_ema_quantization(self):
    ema_decay = 0.9
    qlinear = QuantizedLinear(10, 5, bias=False, ema_decay=ema_decay,
                              mode="EMA")
    for i in range(5):
        x = torch.randn(3, 10)
        tmp_input_thresh = x.abs().max()
        if i == 0:
            input_ema = tmp_input_thresh
        else:
            input_ema -= (1 - ema_decay) * (input_ema - tmp_input_thresh)
        y = (fake_quantize_np(x, get_scale(8, input_ema), 8)
             @ qlinear.fake_quantized_weight.t()).detach()
        tmp_output_thresh = y.abs().max()
        if i == 0:
            output_ema = tmp_output_thresh
        else:
            output_ema -= (1 - ema_decay) * (output_ema - tmp_output_thresh)
        y = fake_quantize_np(y, get_scale(8, output_ema), 8)
        y_hat = qlinear(x)
        self.assertTrue((y == y_hat).all())
        self.assertEqual(qlinear.input_thresh, input_ema)
        self.assertEqual(qlinear.output_thresh, output_ema)
def test_none_quantized_linear(self):
    qlinear = QuantizedLinear(10, 5, mode="NONE")
    linear = nn.Linear(10, 5)
    linear.weight.data = qlinear.weight
    linear.bias.data = qlinear.bias
    x = torch.randn(3, 10)
    y = linear(x)
    y_hat = qlinear(x)
    self.assertTrue((y - y_hat).norm() < 1e-6)
def test_dynamic_quantized_linear_backward(self):
    x = torch.randn(1, 100, requires_grad=True)
    linear = QuantizedLinear(100, 1, bias=False, mode="DYNAMIC")
    y = linear(x)
    y.backward()
    self.assertTrue((x.grad == linear.fake_quantized_weight).all())
    with torch.no_grad():
        scale = (2 ** (8 - 1) - 1) / x.abs().max()
    self.assertTrue(
        (fake_quantize_np(x.detach(), scale, 8) == linear.weight.grad).all())
def test_start_quantization_delay(self):
    quantization_delay = 2
    qlinear = QuantizedLinear(10, 5, start_step=quantization_delay,
                              mode="DYNAMIC")
    linear = nn.Linear(10, 5)
    linear.weight.data = qlinear.weight
    linear.bias.data = qlinear.bias
    for _ in range(quantization_delay):
        x = torch.randn(3, 10)
        qy = qlinear(x)
        y = linear(x)
        self.assertTrue((y == qy).all())
    qy = qlinear(x)
    self.assertFalse((y == qy).all())
def test_dynamic_quantized_linear_forward(self):
    """Test the QuantizedLinear forward method by feeding it input and
    weight values that are already quantized, so the quantization step
    has no effect on the values and the expected output is known."""
    x = torch.randn(1, 100).mul(127.).round().clamp(-127., 127.)
    qlinear = QuantizedLinear(100, 1, bias=False, requantize_output=False,
                              mode="dynamic")
    with torch.no_grad():
        scale = 127. / qlinear.weight.abs().max()
    self.assertTrue((qlinear.fake_quantized_weight == fake_quantize_np(
        qlinear.weight.detach(), scale, 8)).all())
    qlinear.weight.data = torch.randn_like(
        qlinear.weight).mul(127.).round().clamp(-127., 127.)
    y = qlinear(x)
    self.assertEqual(y.shape, (1, 1))
    self.assertTrue((y == (x @ qlinear.weight.t())).all())
def test_start_quantization_delay_data_parallel(self):
    if not torch.cuda.is_available():
        return
    quantization_delay = 2
    qlinear = QuantizedLinear(10, 5, start_step=quantization_delay,
                              mode="DYNAMIC")
    linear = nn.Linear(10, 5)
    linear.weight.data = qlinear.weight
    linear.bias.data = qlinear.bias
    qlinear = nn.DataParallel(qlinear).cuda()
    linear = nn.DataParallel(linear).cuda()
    for _ in range(quantization_delay):
        x = torch.randn(3, 10).cuda()
        qy = qlinear(x)
        y = linear(x)
        self.assertTrue((y == qy).all())
    qy = qlinear(x)
    self.assertFalse((y == qy).all())