def test_import_from_8bit_with_bias(self):
    """Round-trip an 8-bit state dict through a fresh QuantizedLinear.

    Covers both 'dynamic' and 'ema' quantization modes: export quantized
    tensors from one module, load them into an independently initialized
    module, and check both modules produce identical outputs.
    """
    # --- QuantizationMode: dynamic ---
    exporter = QuantizedLinear(10, 5, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True  # export quantized tensors instead of fp32
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, mode='dynamic')
    # Sanity check: the two randomly initialized modules start out different.
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    importer.eval()
    # strict=False: the 8-bit dict omits the fp32 'weight'/'bias' keys.
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    # torch.equal returns a plain bool and also verifies shape/dtype match.
    self.assertTrue(torch.equal(exporter(x), importer(x)))
    # --- QuantizationMode: ema ---
    exporter = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    x = torch.randn(3, 10)
    exporter(x)  # one forward pass so the EMA input threshold is populated
    self.assertTrue(exporter.input_thresh != 0.)
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    self.assertTrue(torch.equal(exporter(x), importer(x)))
 def test_export_to_8bit_with_bias(self):
     """Toggling mode_8bit switches state_dict between fp32 and 8-bit keys.

     In eval mode with mode_8bit=True the state dict must expose only the
     quantized tensors (int8 weight, int32 bias, bias scale); with the flag
     off it must expose only the fp32 'weight' and 'bias'.
     """
     qlinear = QuantizedLinear(10, 5, mode='EMA')
     qlinear.eval()
     # Default export: fp32 parameters only.
     state_dict = qlinear.state_dict()
     self.assertIn('weight', state_dict)
     self.assertIn('bias', state_dict)
     self.assertNotIn('quantized_weight', state_dict)
     self.assertNotIn('_quantized_bias', state_dict)
     self.assertNotIn('bias_scale', state_dict)
     # 8-bit export: quantized tensors replace the fp32 parameters.
     qlinear.mode_8bit = True
     state_dict = qlinear.state_dict()
     self.assertNotIn('weight', state_dict)
     self.assertNotIn('bias', state_dict)
     self.assertIn('quantized_weight', state_dict)
     self.assertEqual(state_dict['quantized_weight'].dtype, torch.int8)
     self.assertIn('_quantized_bias', state_dict)
     self.assertEqual(state_dict['_quantized_bias'].dtype, torch.int32)
     self.assertIn('bias_scale', state_dict)
     # Turning the flag back off must restore the fp32 export.
     qlinear.mode_8bit = False
     state_dict = qlinear.state_dict()
     self.assertIn('weight', state_dict)
     self.assertIn('bias', state_dict)
     self.assertNotIn('quantized_weight', state_dict)
     self.assertNotIn('_quantized_bias', state_dict)
     self.assertNotIn('bias_scale', state_dict)
# Example #3  (section marker; original Korean text: "예제 #3")
# 0  (stray line carried over from the scraped source)
 def test_export_to_8bit_with_bias(self):
     """Toggling mode_8bit switches state_dict between fp32 and 8-bit keys.

     NOTE(review): this duplicates the earlier method of the same name; if
     both live in one TestCase the later definition shadows the earlier.
     """
     qlinear = QuantizedLinear(10, 5, mode="EMA")
     qlinear.eval()
     # Default export: fp32 parameters only.
     state_dict = qlinear.state_dict()
     self.assertIn("weight", state_dict)
     self.assertIn("bias", state_dict)
     self.assertNotIn("quantized_weight", state_dict)
     self.assertNotIn("_quantized_bias", state_dict)
     self.assertNotIn("bias_scale", state_dict)
     # 8-bit export: quantized tensors replace the fp32 parameters.
     qlinear.mode_8bit = True
     state_dict = qlinear.state_dict()
     self.assertNotIn("weight", state_dict)
     self.assertNotIn("bias", state_dict)
     self.assertIn("quantized_weight", state_dict)
     self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
     self.assertIn("_quantized_bias", state_dict)
     self.assertEqual(state_dict["_quantized_bias"].dtype, torch.int32)
     self.assertIn("bias_scale", state_dict)
     # Turning the flag back off must restore the fp32 export.
     qlinear.mode_8bit = False
     state_dict = qlinear.state_dict()
     self.assertIn("weight", state_dict)
     self.assertIn("bias", state_dict)
     self.assertNotIn("quantized_weight", state_dict)
     self.assertNotIn("_quantized_bias", state_dict)
     self.assertNotIn("bias_scale", state_dict)
 def test_import_from_8bit_without_bias(self):
     """An 8-bit export of a bias-free layer round-trips exactly."""
     exporter = QuantizedLinear(10, 5, bias=False, mode='dynamic')
     exporter.eval()
     exporter.mode_8bit = True  # export quantized tensors instead of fp32
     state_dict = exporter.state_dict()
     exporter.mode_8bit = False
     importer = QuantizedLinear(10, 5, bias=False, mode='dynamic')
     # Sanity check: the freshly built importer starts from different weights.
     self.assertTrue((exporter.weight != importer.weight).any())
     importer.eval()
     # strict=False: the 8-bit dict omits the fp32 'weight' key.
     importer.load_state_dict(state_dict, strict=False)
     x = torch.randn(3, 10)
     # torch.equal returns a plain bool and also verifies shape/dtype match.
     self.assertTrue(torch.equal(exporter(x), importer(x)))
 def test_restrict_loading_to_train_model(self):
     """Loading an 8-bit state dict into a train-mode module must fail."""
     source = QuantizedLinear(10, 5, mode='dynamic')
     source.eval()
     source.mode_8bit = True
     quantized_state = source.state_dict()
     # The target is never switched to eval mode, so the quantized
     # checkpoint must be rejected with a RuntimeError.
     target = QuantizedLinear(10, 5, mode='dynamic')
     with self.assertRaises(RuntimeError):
         target.load_state_dict(quantized_state, strict=False)
# Example #6  (section marker; original Korean text: "예제 #6")
# 0  (stray line carried over from the scraped source)
 def test_train_block_when_loading_quantized_model(self):
     """Switching back to train mode after loading 8-bit weights must fail."""
     source = QuantizedLinear(10, 5, mode="dynamic")
     source.eval()
     source.mode_8bit = True
     quantized_state = source.state_dict()
     target = QuantizedLinear(10, 5, mode="dynamic")
     target.eval()
     target.load_state_dict(quantized_state, strict=False)
     # A module restored from a quantized checkpoint cannot be retrained.
     with self.assertRaises(RuntimeError):
         target.train()