コード例 #1
0
    def fake_quant_scriptable(self):
        """Check that FakeQuantize survives TorchScript compilation and serialization.

        An eager ``FakeQuantize`` and its ``torch.jit.script`` counterpart are
        fed the same tensor, and their ``calculate_qparams()`` results are
        compared. The scripted module is then saved to and reloaded from an
        in-memory buffer, and the reloaded module's qparams must still match
        the eager module's.

        NOTE(review): the method name lacks the ``test_`` prefix, so unittest
        will not collect it automatically; confirm whether that is intentional.
        """
        quant_min, quant_max = 0, 255
        eager = FakeQuantize(default_observer, quant_min, quant_max)
        scripted = torch.jit.script(eager)

        inputs = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

        # Run the eager module first, then the scripted one, on identical data.
        for module in (eager, scripted):
            module(inputs)
        self.assertEqual(eager.calculate_qparams(), scripted.calculate_qparams())

        # Serialize the scripted module to memory and read it back.
        stream = io.BytesIO()
        torch.jit.save(scripted, stream)
        stream.seek(0)
        reloaded = torch.jit.load(stream)
        self.assertEqual(eager.calculate_qparams(), reloaded.calculate_qparams())
コード例 #2
0
    def test_fq_serializable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())