def test_quantized_embedding_inference_forward(self):
     """Check weight fake-quantization and the eval-mode forward pass.

     First verifies that ``fake_quantized_weight`` matches the numpy
     reference ``fake_quantize_np`` at 8 bits with scale
     ``127 / max|weight|``. Then loads integer weights in [-127, 127] and
     checks that the eval-mode lookup equals a plain ``F.embedding`` on the
     same weights.
     """
     embedding = QuantizedEmbedding(10, 3, mode="ema")
     with torch.no_grad():
         scale = 127. / embedding.weight.abs().max()
     self.assertTrue((embedding.fake_quantized_weight == fake_quantize_np(
         embedding.weight.detach(), scale, 8)).all())
     embedding.weight.data = torch.randn_like(
         embedding.weight).mul(127.).round().clamp(-127., 127.)
     # Pin one entry to the range limit so max|weight| is exactly 127. With the
     # 127 / max|weight| scale checked above this makes the scale exactly 1, so
     # quantizing these integer weights is lossless. Without this pin the test
     # relies on at least one random sample happening to hit the clamp.
     embedding.weight.data[0, 0] = 127.
     indices = torch.tensor(np.arange(10))
     embedding.eval()
     ground = F.embedding(indices, embedding.weight)
     quantized = embedding(indices)
     self.assertTrue((ground == quantized).all())
 def test_quantization_turned_off(self):
     """With mode="none", QuantizedEmbedding must match nn.Embedding exactly.

     The comparison runs once with the module in its default (training)
     mode and once after ``eval()``, so both forward paths are covered.
     """
     qembedding = QuantizedEmbedding(10, 3, mode="none")
     embedding = nn.Embedding(10, 3)
     # Reuse the quantized layer's weights so any output difference could
     # only come from quantization.
     embedding.weight.data = qembedding.weight
     indices = torch.tensor(np.arange(10))
     self.assertTrue((embedding(indices) == qembedding(indices)).all())
     # Re-check in inference mode; repeating the identical training-mode
     # comparison would add no coverage.
     qembedding.eval()
     self.assertTrue((embedding(indices) == qembedding(indices)).all())
 def test_delay_quantization_start(self):
     """Quantization must not be applied before ``start_step`` is reached.

     With ``start_step=1`` the first forward pass is expected to equal a
     plain ``nn.Embedding`` that uses the same weights. A subsequent pass is
     expected to differ in at least one element.
     """
     quantized = QuantizedEmbedding(10, 3, mode="ema", start_step=1)
     reference = nn.Embedding(10, 3)
     reference.weight.data = quantized.weight
     rows = torch.tensor(np.arange(10))

     # First pass: still before start_step, so no quantization is expected.
     matches_before = reference(rows) == quantized(rows)
     self.assertTrue(matches_before.all())

     # Second pass: presumably the previous training-mode forward advanced
     # the layer's step counter, so quantization should now alter the output.
     differs_after = reference(rows) != quantized(rows)
     self.assertTrue(differs_after.any())
 def test_export_to_8bit(self):
     """Toggling ``mode_8bit`` switches which weight key ``state_dict`` exports.

     By default only ``weight`` is exported. With ``mode_8bit`` enabled only
     an int8 ``quantized_weight`` is exported. Turning the flag back off
     restores the default layout.
     """
     qembed = QuantizedEmbedding(10, 5, mode='EMA')
     qembed.eval()
     state_dict = qembed.state_dict()
     # assertIn/assertNotIn/assertEqual report the offending values on
     # failure, unlike assertTrue(<bool expression>).
     self.assertNotIn('quantized_weight', state_dict)
     self.assertIn('weight', state_dict)
     qembed.mode_8bit = True
     state_dict = qembed.state_dict()
     self.assertIn('quantized_weight', state_dict)
     self.assertEqual(state_dict['quantized_weight'].dtype, torch.int8)
     self.assertNotIn('weight', state_dict)
     qembed.mode_8bit = False
     state_dict = qembed.state_dict()
     self.assertNotIn('quantized_weight', state_dict)
     self.assertIn('weight', state_dict)
Пример #5
0
 def test_export_to_8bit(self):
     """Toggling ``mode_8bit`` switches which weight key ``state_dict`` exports.

     By default only ``weight`` is exported. With ``mode_8bit`` enabled only
     an int8 ``quantized_weight`` is exported. Turning the flag back off
     restores the default layout.
     """
     qembed = QuantizedEmbedding(10, 5, mode="EMA")
     qembed.eval()
     state_dict = qembed.state_dict()
     # assertIn/assertNotIn/assertEqual report the offending values on
     # failure, unlike assertTrue(<bool expression>).
     self.assertNotIn("quantized_weight", state_dict)
     self.assertIn("weight", state_dict)
     qembed.mode_8bit = True
     state_dict = qembed.state_dict()
     self.assertIn("quantized_weight", state_dict)
     self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
     self.assertNotIn("weight", state_dict)
     qembed.mode_8bit = False
     state_dict = qembed.state_dict()
     self.assertNotIn("quantized_weight", state_dict)
     self.assertIn("weight", state_dict)
 def test_quantized_embedding_backward(self):
     """Backpropagate through a quantized lookup followed by a linear layer.

     For ``linear(embedding(token))`` the test expects the embedding-table
     gradient to be zero except on the looked-up row, which must equal the
     linear layer's weight vector. The linear weight gradient must equal the
     looked-up embedding vector.
     """
     embedding = QuantizedEmbedding(10, 3, mode="ema")
     linear = nn.Linear(3, 1)
     token = torch.tensor([2])
     hidden = embedding(token)
     linear(hidden).backward()

     # Only the row that was looked up receives a gradient.
     expected_table_grad = torch.zeros_like(embedding.weight)
     expected_table_grad[token.item(), :] = linear.weight.t().squeeze()

     self.assertTrue((embedding.weight.grad == expected_table_grad).all())
     self.assertTrue((linear.weight.grad == hidden).all())
Пример #7
0
def quantized_embedding_setup(config, name, *args, **kwargs):
    """
    Get QuantizedEmbedding layer according to config params.

    Reads ``getattr(config, name)``. If that attribute exists and is not
    ``None``, it is parsed with ``QuantizationConfig.from_dict`` and a
    ``QuantizedEmbedding`` is built via ``QuantizedEmbedding.from_config``.
    If ``config`` has no such attribute, or it is ``None``, a plain
    ``nn.Embedding`` is returned instead.

    Args:
        config: object that may carry quantization settings as attribute
            ``name``.
        name: name of the attribute holding the settings passed to
            ``QuantizationConfig.from_dict``.
        *args, **kwargs: forwarded unchanged to the embedding constructor.

    Returns:
        A ``QuantizedEmbedding`` when quantization settings are present,
        otherwise an ``nn.Embedding``.
    """
    quant_settings = getattr(config, name, None)
    if quant_settings is None:
        # No quantization requested for this layer: use a regular embedding.
        return nn.Embedding(*args, **kwargs)
    # Deliberately outside any try/except: an AttributeError raised while
    # parsing the settings or constructing the layer is a real error and must
    # not be masked by silently falling back to an unquantized embedding.
    quant_config = QuantizationConfig.from_dict(quant_settings)
    return QuantizedEmbedding.from_config(*args,
                                          **kwargs,
                                          config=quant_config)
 def test_load_from_8bit(self):
     """An 8-bit ``state_dict`` export can be loaded into a fresh layer.

     After ``load_state_dict(..., strict=False)`` the importing layer's
     lookup must match the exporting layer's lookup element for element.
     """
     source = QuantizedEmbedding(10, 5, mode='EMA')
     source.eval()
     # Snapshot the 8-bit form, then put the source back in its normal mode.
     source.mode_8bit = True
     exported = source.state_dict()
     source.mode_8bit = False

     target = QuantizedEmbedding(10, 5, mode='EMA')
     # Independently initialised layers must start out different; otherwise
     # the final equality would be trivially true.
     self.assertTrue((source.weight != target.weight).any())
     target.eval()
     # strict=False: the 8-bit export has no plain 'weight' entry.
     target.load_state_dict(exported, strict=False)

     rows = torch.tensor(np.arange(10))
     self.assertTrue((source(rows) == target(rows)).all())