def test_fused_obs_fq_moving_avg_module(self, device): # Set up the parameters running_min_op = torch.tensor(float("inf"), device=device) running_max_op = torch.tensor(float("-inf"), device=device) avg_const = 0.001 scale = torch.tensor([1.0], device=device) zero_point = torch.tensor([0], dtype=torch.int, device=device) mod = FusedMovingAvgObsFakeQuantize(averaging_constant=0.001) mod.to(device) mod.observer_enabled[0] = 0 mod.fake_quant_enabled[0] = 0 for i in range(10): x = torch.randn(5, 5, device=device) if i > 2: mod.observer_enabled[0] = 1 if i > 4: mod.fake_quant_enabled[0] = 1 # Run the forward on the Module out = mod(x) # Run the operator directly pt_op = torch.fused_moving_avg_obs_fake_quant out_ref = pt_op( x, mod.observer_enabled, mod.fake_quant_enabled, running_min_op, running_max_op, scale, zero_point, avg_const, 0, 255, 0, False, ) # Compare params with reference torch.testing.assert_allclose(out, out_ref) torch.testing.assert_allclose( running_min_op, mod.activation_post_process.min_val ) torch.testing.assert_allclose( running_max_op, mod.activation_post_process.max_val )
def test_compare_fused_obs_fq_oss_module(self, device): mod = FusedMovingAvgObsFakeQuantize() torch.quantization.enable_fake_quant(mod) torch.quantization.enable_observer(mod) mod.to(device) mod_ref = FakeQuantize() torch.quantization.enable_fake_quant(mod_ref) torch.quantization.enable_observer(mod_ref) mod_ref.to(device) for i in range(10): x = torch.randn(5, 5, device=device) out = mod(x) out_ref = mod_ref(x) torch.testing.assert_allclose(out, out_ref) torch.testing.assert_allclose( mod_ref.activation_post_process.min_val, mod.activation_post_process.min_val, ) torch.testing.assert_allclose( mod_ref.activation_post_process.max_val, mod.activation_post_process.max_val, )