def test_fq_serializable(self):
    """Round-trip the state dict of a ``default_observer`` FakeQuantize module.

    Runs one forward pass, checks the ``scale`` and ``zero_point`` entries of
    the resulting state dict, serializes that dict with ``torch.save`` into
    an in-memory buffer, reloads it with ``torch.load`` and restores it into
    a freshly constructed module.  The restored module must expose the same
    state dict entries and the same ``calculate_qparams()`` result.
    """
    observer = default_observer
    quant_min = 0
    quant_max = 255
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

    # The forward output itself is not checked; the call is kept because the
    # state asserted below is read after it (presumably it is what populates
    # scale/zero_point).
    fq_module(X)
    state_dict = fq_module.state_dict()
    self.assertEqual(state_dict['scale'], 0.094488)
    self.assertEqual(state_dict['zero_point'], 53)

    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    buffer.seek(0)  # rewind so torch.load reads from the beginning
    loaded_dict = torch.load(buffer)

    loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
    loaded_fq_module.load_state_dict(loaded_dict)
    # Build the restored state dict once instead of once per key.
    loaded_state = loaded_fq_module.state_dict()
    for key in state_dict:
        self.assertEqual(state_dict[key], loaded_state[key])

    self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())
def test_fq_serializable(self):
    """Round-trip the state dict of a per-channel-observer FakeQuantize module.

    Runs one forward pass on a two-row input using
    ``default_per_channel_weight_observer``, checks the ``scale`` and
    ``zero_point`` entries of the resulting state dict, then serializes the
    dict with ``torch.save`` into an in-memory buffer and verifies that
    ``torch.load`` returns equal values for every key.
    """
    # NOTE(review): another method in this file is also named
    # test_fq_serializable; if both live in the same class, the later
    # definition replaces the earlier one -- confirm they belong to
    # different test classes.
    observer = default_per_channel_weight_observer
    quant_min = -128
    quant_max = 127
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    X = torch.tensor(
        [[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]],
        dtype=torch.float32,
    )

    # The forward output itself is not checked; the call is kept because the
    # state asserted below is read after it (presumably it is what populates
    # scale/zero_point).
    fq_module(X)
    state_dict = fq_module.state_dict()
    self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
    self.assertEqual(state_dict['zero_point'], [0, 0])

    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    buffer.seek(0)  # rewind so torch.load reads from the beginning
    loaded_dict = torch.load(buffer)

    for key in state_dict:
        self.assertEqual(state_dict[key], loaded_dict[key])