def test_fused_mod_per_channel(self): devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] m = 5 n = 10 for device in devices: running_min_op = torch.empty(m, device=device).fill_(float("inf")) running_max_op = torch.empty(m, device=device).fill_(float("-inf")) avg_const = 0.001 scale = torch.empty(m, device=device).fill_(0.1) zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0) obs = FusedMovingAvgObsFakeQuantize.with_args( averaging_constant=avg_const, observer=MovingAveragePerChannelMinMaxObserver, ) mod = obs() mod = torch.jit.script(mod) mod.to(device) for i in range(10): x = torch.randn(m, n, device=device) if i > 2: mod.observer_enabled[0] = 1 if i > 4: mod.fake_quant_enabled[0] = 1 # Run the forward on the Module out = mod(x) # Run the operator directly pt_op = torch.fused_moving_avg_obs_fake_quant out_ref = pt_op( x, mod.observer_enabled, mod.fake_quant_enabled, running_min_op, running_max_op, scale, zero_point, avg_const, 0, 255, 0, True, False, ) # Compare params with reference torch.testing.assert_allclose(out, out_ref) if mod.observer_enabled[0]: torch.testing.assert_allclose( running_min_op, mod.activation_post_process.min_val ) torch.testing.assert_allclose( running_max_op, mod.activation_post_process.max_val ) if mod.fake_quant_enabled: torch.testing.assert_allclose(scale, mod.scale) torch.testing.assert_allclose(zero_point, mod.zero_point) torch.testing.assert_allclose(mod.state_dict()['activation_post_process.min_val'], running_min_op) torch.testing.assert_allclose(mod.state_dict()['activation_post_process.max_val'], running_max_op)
def test_fused_obs_fq_module(self, device): # Set up the parameters x = torch.randn(5, 5, device=device) running_min_op = torch.tensor(float("inf"), device=device) running_max_op = torch.tensor(float("-inf"), device=device) avg_const = 0.01 scale = torch.tensor([1.0], device=device) zero_point = torch.tensor([0], dtype=torch.int, device=device) # Run the forward on the Module mod = FusedMovingAvgObsFakeQuantize() torch.quantization.enable_fake_quant(mod) torch.quantization.enable_observer(mod) mod.to(device) out = mod(x) # Run the operator directly pt_op = torch.fused_moving_avg_obs_fake_quant out_ref = pt_op( x, mod.observer_enabled, mod.fake_quant_enabled, running_min_op, running_max_op, scale, zero_point, avg_const, 0, 255, 0, False, ) # Compare params with reference torch.testing.assert_allclose(out, out_ref) torch.testing.assert_allclose( running_min_op, mod.activation_post_process.min_val ) torch.testing.assert_allclose( running_max_op, mod.activation_post_process.max_val )
def test_compare_fused_obs_fq_oss_module(self, device): mod = FusedMovingAvgObsFakeQuantize() torch.quantization.enable_fake_quant(mod) torch.quantization.enable_observer(mod) mod.to(device) mod_ref = FakeQuantize() torch.quantization.enable_fake_quant(mod_ref) torch.quantization.enable_observer(mod_ref) mod_ref.to(device) for i in range(10): x = torch.randn(5, 5, device=device) out = mod(x) out_ref = mod_ref(x) torch.testing.assert_allclose(out, out_ref) torch.testing.assert_allclose( mod_ref.activation_post_process.min_val, mod.activation_post_process.min_val, ) torch.testing.assert_allclose( mod_ref.activation_post_process.max_val, mod.activation_post_process.max_val, )
def test_fused_mod_reduce_range(self): obs = FusedMovingAvgObsFakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=True) self.assertEqual(obs.quant_min, 0) self.assertEqual(obs.quant_max, 127)