def test_import_from_8bit_with_bias(self):
    """Round-trip an 8-bit state_dict into a fresh, biased QuantizedLinear.

    Covers the ``dynamic`` and ``ema`` quantization modes, each in its own
    ``subTest`` so a failure reports which mode broke. For each mode the
    exporter is switched to 8-bit mode only while its ``state_dict`` is
    captured. A freshly initialised importer then loads that state_dict
    (``strict=False``) and must reproduce the exporter's outputs exactly.
    """
    def check_round_trip(calibrate, **layer_kwargs):
        # Shared export/import/compare sequence for one quantization mode.
        # When ``calibrate`` is true, one forward pass runs on the exporter
        # before ``eval()``; the assertion below requires that this leaves
        # ``input_thresh`` non-zero.
        exporter = QuantizedLinear(10, 5, **layer_kwargs)
        x = torch.randn(3, 10)
        if calibrate:
            exporter(x)
            self.assertTrue(exporter.input_thresh != 0.)
        exporter.eval()
        # Export while in 8-bit mode, then switch the exporter back so its
        # own forward pass below uses the normal path.
        exporter.mode_8bit = True
        state_dict = exporter.state_dict()
        exporter.mode_8bit = False

        importer = QuantizedLinear(10, 5, **layer_kwargs)
        # Precondition: the importer must start from different parameters,
        # otherwise matching outputs below would prove nothing.
        self.assertTrue((exporter.weight != importer.weight).any())
        self.assertTrue((exporter.bias != importer.bias).any())
        importer.eval()
        importer.load_state_dict(state_dict, strict=False)

        # torch.equal also requires identical sizes, whereas an elementwise
        # ``==`` followed by ``.all()`` would broadcast mismatched shapes.
        self.assertTrue(torch.equal(exporter(x), importer(x)))

    with self.subTest(mode='dynamic'):
        check_round_trip(calibrate=False, mode='dynamic')
    with self.subTest(mode='ema'):
        check_round_trip(calibrate=True, requantize_output=False, mode='ema')
def test_train_block_when_loading_quantized_model(self):
    """A layer that loaded an 8-bit state_dict must refuse ``train()``.

    The source layer is put into 8-bit mode before its state_dict is
    captured and is never switched back. The target layer loads that
    state_dict (``strict=False``) in eval mode; calling ``train()`` on it
    is expected to raise ``RuntimeError``.
    """
    source = QuantizedLinear(10, 5, mode='dynamic')
    source.eval()
    source.mode_8bit = True
    exported = source.state_dict()

    target = QuantizedLinear(10, 5, mode='dynamic')
    target.eval()
    target.load_state_dict(exported, strict=False)

    # Callable form: invokes target.train() and requires RuntimeError.
    self.assertRaises(RuntimeError, target.train)
def test_import_from_8bit_without_bias(self):
    """Round-trip an 8-bit state_dict into a fresh, bias-free QuantizedLinear.

    Uses ``dynamic`` mode. The source layer's state_dict is captured while
    it is in 8-bit mode; a freshly initialised target loads it
    (``strict=False``) and must produce exactly the same outputs.
    """
    source = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    source.eval()
    # Snapshot parameters in 8-bit mode, then restore the normal mode.
    source.mode_8bit = True
    exported = source.state_dict()
    source.mode_8bit = False

    target = QuantizedLinear(10, 5, bias=False, mode='dynamic')
    # Fresh initialisation must differ, or the output check is vacuous.
    weights_differ = (source.weight != target.weight).any()
    self.assertTrue(weights_differ)

    target.eval()
    target.load_state_dict(exported, strict=False)

    inputs = torch.randn(3, 10)
    expected = source(inputs)
    actual = target(inputs)
    self.assertTrue((expected == actual).all())