def test_load_from_8bit(self):
    """Round-trip a QuantizedEmbedding through its 8-bit state dict.

    An "exporter" module serializes its state while ``mode_8bit`` is enabled.
    A second, independently constructed "importer" module loads that state.
    Afterwards both modules must return identical outputs for indices 0..9.
    """
    exporter = QuantizedEmbedding(10, 5, mode="EMA")
    exporter.eval()
    # Export the 8-bit representation only, then switch the exporter back so
    # its forward pass at the end runs with mode_8bit disabled.
    exporter.mode_8bit = True
    state_dict = exporter.state_dict()
    exporter.mode_8bit = False

    importer = QuantizedEmbedding(10, 5, mode="EMA")
    # Sanity check: the two modules start out with different weights, so
    # matching outputs at the end can only come from the loaded state.
    self.assertFalse(torch.equal(exporter.weight, importer.weight))
    importer.eval()
    # The 8-bit state dict has no "weight" entry (asserted in
    # test_export_to_8bit), so a strict load would presumably report it as
    # missing -- TODO confirm this is the only reason strict=False is needed.
    importer.load_state_dict(state_dict, strict=False)

    # torch.arange always yields int64 indices.
    # np.arange's default integer dtype is platform-dependent
    # (int32 on Windows with NumPy < 2).
    # Indices 0..9 presumably cover every row of the 10-entry table.
    indices = torch.arange(10)
    self.assertTrue(torch.equal(exporter(indices), importer(indices)))
 def test_export_to_8bit(self):
     """Check which weight tensors state_dict() exposes as mode_8bit toggles.

     In the module's initial state and after ``mode_8bit`` is switched back
     off, the state dict contains ``weight`` and no ``quantized_weight``.
     While ``mode_8bit`` is on, it contains an int8 ``quantized_weight`` and
     no ``weight``.
     """
     qembed = QuantizedEmbedding(10, 5, mode="EMA")
     qembed.eval()

     state_dict = qembed.state_dict()
     self.assertNotIn("quantized_weight", state_dict)
     self.assertIn("weight", state_dict)

     qembed.mode_8bit = True
     state_dict = qembed.state_dict()
     self.assertIn("quantized_weight", state_dict)
     self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
     self.assertNotIn("weight", state_dict)

     # Switching the flag back must restore the non-8-bit export.
     qembed.mode_8bit = False
     state_dict = qembed.state_dict()
     self.assertNotIn("quantized_weight", state_dict)
     self.assertIn("weight", state_dict)