コード例 #1
0
    def test_fq_serializable(self):
        """Check that a per-tensor FakeQuantize survives a save/load round trip.

        A FakeQuantize built from ``default_observer`` is calibrated on a small
        tensor. Its ``state_dict`` is then serialized with ``torch.save`` into an
        in-memory buffer and read back with ``torch.load``. The result is loaded
        into a freshly constructed FakeQuantize. The reloaded module must expose
        exactly the same state-dict keys and values as the original, and the
        same ``calculate_qparams()`` result.
        """
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        # Called only for its side effect: the forward pass runs the observer so
        # that 'scale' and 'zero_point' in the state dict get populated. The
        # fake-quantized output itself is not needed.
        fq_module(X)
        state_dict = fq_module.state_dict()
        # NOTE(review): 0.094488 ~= 12 / 127 (input span over a 127-step range)
        # and 53 ~= round(5 / 0.094488); confirm this matches default_observer's
        # configured range on the torch version in use.
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)

        # Serialize into an in-memory buffer and rewind it before reading back.
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)

        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        # Build the reloaded state dict once rather than once per key. Also
        # require identical key sets so that an entry present on only one side
        # is detected; the per-key loop alone only visits the original's keys.
        loaded_state_dict = loaded_fq_module.state_dict()
        self.assertEqual(set(state_dict.keys()), set(loaded_state_dict.keys()))
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_state_dict[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())
コード例 #2
0
 def test_fq_serializable(self):
     """Check that a per-channel FakeQuantize state dict survives save/load.

     A FakeQuantize built from ``default_per_channel_weight_observer`` is
     calibrated on a two-row tensor. Its ``state_dict`` is serialized with
     ``torch.save`` into an in-memory buffer and read back with ``torch.load``.
     The reloaded dict must have exactly the same keys and values as the
     original.

     Only the raw round-tripped dict is compared. It is not loaded back into
     a fresh FakeQuantize module here.
     """
     observer = default_per_channel_weight_observer
     quant_min = -128
     quant_max = 127
     fq_module = FakeQuantize(observer, quant_min, quant_max)
     X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
     # Called only for its side effect: the forward pass runs the observer so
     # that 'scale' and 'zero_point' in the state dict get populated. The
     # fake-quantized output itself is not needed.
     fq_module(X)
     state_dict = fq_module.state_dict()
     # Two expected values each, presumably one per row of X (channel axis
     # assumed to be 0) -- TODO confirm against the observer's ch_axis.
     self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
     self.assertEqual(state_dict['zero_point'], [0, 0])

     # Serialize into an in-memory buffer and rewind it before reading back.
     b = io.BytesIO()
     torch.save(state_dict, b)
     b.seek(0)
     loaded_dict = torch.load(b)

     # Require identical key sets so that an entry present on only one side is
     # detected; the per-key loop alone only visits the original's keys.
     self.assertEqual(set(state_dict.keys()), set(loaded_dict.keys()))
     for key in state_dict:
         self.assertEqual(state_dict[key], loaded_dict[key])