def test_import_from_8bit_with_bias(self):
    # QuantizationMode dynamic: export an 8-bit state dict from an eval-mode layer.
    exporter = QuantizedLinear(10, 5, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    # A freshly initialized importer starts with different parameters.
    importer = QuantizedLinear(10, 5, mode='dynamic')
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    # After loading the quantized state dict, the outputs match exactly.
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())

    # QuantizationMode ema: run one forward pass in training mode so the
    # EMA input threshold is populated before exporting.
    exporter = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    x = torch.randn(3, 10)
    exporter(x)
    self.assertTrue(exporter.input_thresh != 0.)
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, requantize_output=False, mode='ema')
    self.assertTrue((exporter.weight != importer.weight).any())
    self.assertTrue((exporter.bias != importer.bias).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    self.assertTrue((exporter(x) == importer(x)).all())

def test_restrict_loading_to_train_model(self):
    # A quantized (8-bit) state dict may only be loaded into an eval-mode
    # layer; loading into a layer still in training mode must fail.
    exporter = QuantizedLinear(10, 5, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    importer = QuantizedLinear(10, 5, mode='dynamic')
    with self.assertRaises(RuntimeError):
        importer.load_state_dict(state_dict, strict=False)

def test_train_block_when_loading_quantized_model(self):
    # Once quantized weights have been loaded, the layer cannot be
    # switched back to training mode.
    exporter = QuantizedLinear(10, 5, mode="dynamic")
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    importer = QuantizedLinear(10, 5, mode="dynamic")
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    with self.assertRaises(RuntimeError):
        importer.train()

def test_import_from_8bit_without_bias(self):
    # Same export/import round-trip as above, but for a layer without bias.
    exporter = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False
    importer = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    self.assertTrue((exporter.weight != importer.weight).any())
    importer.eval()
    importer.load_state_dict(state_dict, strict=False)
    x = torch.randn(3, 10)
    self.assertTrue((exporter(x) == importer(x)).all())
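
The methods above are unittest.TestCase tests and assume torch and QuantizedLinear are already imported. As a usage illustration, here is a minimal sketch of the export/import round-trip they exercise, saving an 8-bit checkpoint and reloading it for inference; the QuantizedLinear import path and the checkpoint filename are placeholder assumptions.

# Sketch: export an 8-bit checkpoint and reload it for inference,
# following the pattern exercised by the tests above.
import torch
from mylib.quantization import QuantizedLinear  # hypothetical import path

# Export: put the layer in eval mode and enable 8-bit export before
# taking the state dict.
layer = QuantizedLinear(10, 5, mode='dynamic')
layer.eval()
layer.mode_8bit = True
torch.save(layer.state_dict(), 'quantized_linear.pt')  # placeholder filename
layer.mode_8bit = False

# Import: a fresh eval-mode layer loads the quantized checkpoint with
# strict=False, as in the tests above. Per the tests, loading requires
# eval mode, and the layer cannot be switched back to train() afterwards.
importer = QuantizedLinear(10, 5, mode='dynamic')
importer.eval()
importer.load_state_dict(torch.load('quantized_linear.pt'), strict=False)

with torch.no_grad():
    output = importer(torch.randn(3, 10))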