def test_observer_scriptable(self):
    """Check that observers can be scripted with TorchScript and serialized.

    For each observer, the eager module and its scripted counterpart are fed
    the same random sample and must report equal ``calculate_qparams()``
    results. The scripted module is then saved to an in-memory buffer,
    reloaded, and compared against the eager module again. Finally the
    tensor-list observer is checked for eager/scripted agreement only; no
    save/load round-trip is done for it here.
    """
    observers = (
        MinMaxObserver(),
        MovingAverageMinMaxObserver(),
        MinMaxDynamicQuantObserver(),
    )
    for eager in observers:
        # The observer is scripted before either copy has observed data;
        # both copies are then fed the same sample.
        script_module = torch.jit.script(eager)
        sample = torch.rand(3, 4)
        eager(sample)
        script_module(sample)
        self.assertEqual(eager.calculate_qparams(),
                         script_module.calculate_qparams())

        # Serialization round-trip through an in-memory buffer.
        with io.BytesIO() as buffer:
            torch.jit.save(script_module, buffer)
            buffer.seek(0)
            reloaded = torch.jit.load(buffer)
        self.assertEqual(eager.calculate_qparams(),
                         reloaded.calculate_qparams())

    # The tensor-list observer consumes a list of differently-shaped tensors.
    from torch.quantization.observer import _MinMaxTensorListObserver
    list_observer = _MinMaxTensorListObserver()
    scripted_list_observer = torch.jit.script(list_observer)
    samples = [torch.rand(3, 4), torch.rand(4, 5)]
    list_observer(samples)
    scripted_list_observer(samples)
    self.assertEqual(list_observer.calculate_qparams(),
                     scripted_list_observer.calculate_qparams())
def test_tensor_list_observer(self):
    """Check _MinMaxTensorListObserver against per-tensor MinMaxObserver references.

    The list observer is fed a list of tensors in a single call. For every
    position in that list, its recorded ``min_val``/``max_val`` entries and
    its (scale, zero_point) entries must equal those of a standalone
    ``MinMaxObserver`` fed only the tensor at that position.
    """
    from torch.quantization.observer import _MinMaxTensorListObserver
    x = [
        torch.tensor([1.0, 2.5, 3.5]),
        torch.tensor([2.0, 4.5, 3.5]),
        torch.tensor([4.0, 2.5, 3.5]),
    ]
    obs = _MinMaxTensorListObserver()
    obs(x)
    qparams = obs.calculate_qparams()
    scales, zero_points = qparams[0], qparams[1]

    # One independent reference observer per input tensor.
    ref_observers = []
    for tensor in x:
        obs_ref = MinMaxObserver()
        obs_ref(tensor)
        ref_observers.append(obs_ref)

    # Require exactly one entry per input tensor. Surplus entries would
    # otherwise pass unnoticed, and missing ones would surface only as an
    # IndexError instead of a clear assertion failure.
    for per_tensor in (obs.min_val, obs.max_val, scales, zero_points):
        self.assertEqual(len(per_tensor), len(x))

    for idx, obs_ref in enumerate(ref_observers):
        ref_scale, ref_zero_point = obs_ref.calculate_qparams()
        self.assertEqual(obs.min_val[idx], obs_ref.min_val)
        self.assertEqual(obs.max_val[idx], obs_ref.max_val)
        self.assertEqual(scales[idx], ref_scale)
        self.assertEqual(zero_points[idx], ref_zero_point)