def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
    """Compare MinMaxDynamicQuantObserver with torch._choose_qparams_per_tensor.

    Args:
        X: Pair of ``(ndarray, (scale, zero_point, torch_type))``. Only the
            array is used; the accompanying quantization parameters are
            ignored by this test.
        reduce_range: Forwarded unchanged to both the observer and the
            reference op.
    """
    # Keep the three-way unpack so an input of the wrong structure still
    # fails here instead of being silently accepted.
    values, (_scale, _zero_point, _torch_type) = X
    x = torch.from_numpy(values)

    obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range)
    # Feed the data to the observer; its return value is not needed here.
    obs(x)
    qparams = obs.calculate_qparams()

    # Reference values computed directly by the torch op.
    ref = torch._choose_qparams_per_tensor(x, reduce_range)
    self.assertEqual(ref[0], qparams[0])
    self.assertEqual(ref[1], qparams[1])
def test_observer_scriptable(self):
    """Check that observers give the same qparams eagerly and under TorchScript.

    Each observer is scripted, both the eager and scripted instances are
    fed the same random tensor, and their ``calculate_qparams()`` results
    are compared. The scripted module is then saved to an in-memory buffer,
    loaded back, and compared with the eager observer again. Finally the
    tensor-list observer is scripted and compared the same way (without the
    save/load step).
    """
    observers = (
        MinMaxObserver(),
        MovingAverageMinMaxObserver(),
        MinMaxDynamicQuantObserver(),
    )
    for eager in observers:
        scripted = torch.jit.script(eager)

        sample = torch.rand(3, 4)
        eager(sample)
        scripted(sample)
        self.assertEqual(eager.calculate_qparams(), scripted.calculate_qparams())

        # Serialize the scripted observer and read it back from the start
        # of the same buffer.
        stream = io.BytesIO()
        torch.jit.save(scripted, stream)
        stream.seek(0)
        restored = torch.jit.load(stream)
        self.assertEqual(eager.calculate_qparams(), restored.calculate_qparams())

    # Check TensorListObserver
    from torch.quantization.observer import _MinMaxTensorListObserver

    list_observer = _MinMaxTensorListObserver()
    scripted_list_observer = torch.jit.script(list_observer)
    samples = [torch.rand(3, 4), torch.rand(4, 5)]
    list_observer(samples)
    scripted_list_observer(samples)
    self.assertEqual(
        list_observer.calculate_qparams(),
        scripted_list_observer.calculate_qparams(),
    )
def test_observer_scriptable(self):
    """Check that observers give the same qparams eagerly and under TorchScript.

    Each observer is scripted, both the eager and scripted instances are
    fed the same random tensor, and their ``calculate_qparams()`` results
    are compared. The scripted module is then saved to an in-memory buffer,
    loaded back, and compared with the eager observer again. The
    tensor-list observer is scripted and compared the same way (without the
    save/load step).

    NOTE(review): this method name is defined more than once in the file.
    If the definitions share a class, only the last one is kept, so this
    definition also carries the tensor-list observer check that the earlier
    definition performs. Consider removing the duplicate definition.
    """
    obs_list = [
        MinMaxObserver(),
        MovingAverageMinMaxObserver(),
        MinMaxDynamicQuantObserver(),
    ]
    for obs in obs_list:
        scripted = torch.jit.script(obs)
        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

        # Serialize the scripted observer and read it back from the start
        # of the same buffer.
        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

    # Tensor-list observer: imported locally, matching the other definition
    # of this test in the file.
    from torch.quantization.observer import _MinMaxTensorListObserver

    list_obs = _MinMaxTensorListObserver()
    scripted_list_obs = torch.jit.script(list_obs)
    tensors = [torch.rand(3, 4), torch.rand(4, 5)]
    list_obs(tensors)
    scripted_list_obs(tensors)
    self.assertEqual(
        list_obs.calculate_qparams(), scripted_list_obs.calculate_qparams()
    )