def test_quantized_embedding_inference_forward(self):
    """In eval mode, an EMA-quantized embedding whose weights already sit
    on the int8 grid must reproduce a plain F.embedding lookup exactly."""
    embedding = QuantizedEmbedding(10, 3, mode="ema")
    with torch.no_grad():
        scale = 127. / embedding.weight.abs().max()
    # The layer's fake-quantized weight must agree with the NumPy reference.
    self.assertTrue(
        (embedding.fake_quantized_weight
         == fake_quantize_np(embedding.weight.detach(), scale, 8)).all())
    # Force weights onto integer values in [-127, 127] so that 8-bit
    # quantization is lossless and the outputs can be compared exactly.
    embedding.weight.data = torch.randn_like(
        embedding.weight).mul(127.).round().clamp(-127., 127.)
    indices = torch.tensor(np.arange(10))
    embedding.eval()
    expected = F.embedding(indices, embedding.weight)
    actual = embedding(indices)
    self.assertTrue((expected == actual).all())
def test_quantization_turned_off(self):
    """With mode="none" the layer must behave exactly like nn.Embedding."""
    qembedding = QuantizedEmbedding(10, 3, mode="none")
    reference = nn.Embedding(10, 3)
    reference.weight.data = qembedding.weight
    indices = torch.tensor(np.arange(10))
    # Two consecutive forward passes: outputs must match both times
    # (i.e. no quantization ever kicks in, even after a step).
    for _ in range(2):
        self.assertTrue((reference(indices) == qembedding(indices)).all())
def test_delay_quantization_start(self):
    """With start_step=1, the first forward pass is unquantized and the
    second one is quantized."""
    qembedding = QuantizedEmbedding(10, 3, mode="ema", start_step=1)
    reference = nn.Embedding(10, 3)
    reference.weight.data = qembedding.weight
    indices = torch.tensor(np.arange(10))
    # Step 0: quantization not yet active, outputs match exactly.
    self.assertTrue((reference(indices) == qembedding(indices)).all())
    # Step 1: quantization has started, outputs must differ somewhere.
    self.assertTrue((reference(indices) != qembedding(indices)).any())
def test_export_to_8bit(self):
    # NOTE(review): a later method in this class redefines
    # test_export_to_8bit verbatim (double quotes instead of single), and
    # in Python the later definition shadows this one — this copy never
    # runs. One of the two duplicates should be deleted.
    qembed = QuantizedEmbedding(10, 5, mode='EMA')
    qembed.eval()
    # Default export: full-precision 'weight' only.
    state = qembed.state_dict()
    self.assertTrue('quantized_weight' not in state)
    self.assertTrue('weight' in state)
    # 8-bit export: int8 'quantized_weight' replaces 'weight'.
    qembed.mode_8bit = True
    state = qembed.state_dict()
    self.assertTrue('quantized_weight' in state)
    self.assertTrue(state['quantized_weight'].dtype == torch.int8)
    self.assertTrue('weight' not in state)
    # Turning the flag off restores the original export layout.
    qembed.mode_8bit = False
    state = qembed.state_dict()
    self.assertTrue('quantized_weight' not in state)
    self.assertTrue('weight' in state)
def test_export_to_8bit(self):
    """state_dict export must toggle between a full-precision "weight"
    entry and an int8 "quantized_weight" entry via the mode_8bit flag."""
    qembed = QuantizedEmbedding(10, 5, mode="EMA")
    qembed.eval()

    # Default: only the full-precision weight is exported.
    state = qembed.state_dict()
    self.assertNotIn("quantized_weight", state)
    self.assertIn("weight", state)

    # 8-bit mode: the int8 tensor replaces the full-precision weight.
    qembed.mode_8bit = True
    state = qembed.state_dict()
    self.assertIn("quantized_weight", state)
    self.assertEqual(state["quantized_weight"].dtype, torch.int8)
    self.assertNotIn("weight", state)

    # Disabling 8-bit mode restores the default export layout.
    qembed.mode_8bit = False
    state = qembed.state_dict()
    self.assertNotIn("quantized_weight", state)
    self.assertIn("weight", state)
def test_quantized_embedding_backward(self):
    """The embedding weight's gradient must equal the downstream linear
    weight at the looked-up row — i.e. the quantized lookup passes the
    gradient through unchanged."""
    embedding = QuantizedEmbedding(10, 3, mode="ema")
    linear = nn.Linear(3, 1)
    indices = torch.tensor([2])
    hidden = embedding(indices)
    linear(hidden).backward()
    # Only the selected row receives gradient; it equals linear's weight.
    expected = torch.zeros_like(embedding.weight)
    expected[indices.item(), :] = linear.weight.t().squeeze()
    self.assertTrue((embedding.weight.grad == expected).all())
    # d(y)/d(linear.weight) is the embedding output itself.
    self.assertTrue((linear.weight.grad == hidden).all())
def quantized_embedding_setup(config, name, *args, **kwargs):
    """Build an embedding layer according to config params.

    If ``config`` exposes an attribute called ``name``, it is parsed into a
    :class:`QuantizationConfig` and a :class:`QuantizedEmbedding` is built
    from it; otherwise a plain ``nn.Embedding`` is created with the same
    positional/keyword arguments.

    Args:
        config: object whose attribute ``name`` (if present) holds a
            quantization-config dict.
        name: attribute name to look up on ``config``.
        *args, **kwargs: forwarded to the embedding constructor.

    Returns:
        A ``QuantizedEmbedding`` when quantization config is present,
        else an ``nn.Embedding``.
    """
    try:
        # Keep the try body to the attribute lookup only: an AttributeError
        # raised *inside* from_dict/from_config is a real bug and must not
        # be silently turned into an unquantized fallback.
        quant_cfg_dict = getattr(config, name)
    except AttributeError:
        return nn.Embedding(*args, **kwargs)
    quant_config = QuantizationConfig.from_dict(quant_cfg_dict)
    return QuantizedEmbedding.from_config(*args, **kwargs, config=quant_config)
def test_load_from_8bit(self):
    """Weights exported in 8-bit mode can be loaded into a fresh layer,
    which then produces identical eval-time outputs."""
    exporter = QuantizedEmbedding(10, 5, mode='EMA')
    exporter.eval()
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False

    importer = QuantizedEmbedding(10, 5, mode='EMA')
    # Independent random init: weights must differ before loading.
    self.assertTrue((exporter.weight != importer.weight).any())

    importer.eval()
    # strict=False: the 8-bit state dict does not carry a 'weight' key.
    importer.load_state_dict(state_dict, strict=False)
    indices = torch.tensor(np.arange(10))
    self.assertTrue((exporter(indices) == importer(indices)).all())