Esempio n. 1
0
    def test_fused_obs_fq_moving_avg_module(self, device):
        # Set up the parameters
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.001
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        mod = FusedMovingAvgObsFakeQuantize(averaging_constant=0.001)
        mod.to(device)
        mod.observer_enabled[0] = 0
        mod.fake_quant_enabled[0] = 0

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            if i > 2:
                mod.observer_enabled[0] = 1
            if i > 4:
                mod.fake_quant_enabled[0] = 1
            # Run the forward on the Module
            out = mod(x)

            # Run the operator directly
            pt_op = torch.fused_moving_avg_obs_fake_quant

            out_ref = pt_op(
                x,
                mod.observer_enabled,
                mod.fake_quant_enabled,
                running_min_op,
                running_max_op,
                scale,
                zero_point,
                avg_const,
                0,
                255,
                0,
                False,
            )

            # Compare params with reference
            torch.testing.assert_allclose(out, out_ref)
            torch.testing.assert_allclose(
                running_min_op, mod.activation_post_process.min_val
            )
            torch.testing.assert_allclose(
                running_max_op, mod.activation_post_process.max_val
            )
    def test_compare_fused_obs_fq_oss_module(self, device):
        mod = FusedMovingAvgObsFakeQuantize()
        torch.quantization.enable_fake_quant(mod)
        torch.quantization.enable_observer(mod)
        mod.to(device)

        mod_ref = FakeQuantize()
        torch.quantization.enable_fake_quant(mod_ref)
        torch.quantization.enable_observer(mod_ref)
        mod_ref.to(device)

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            out = mod(x)
            out_ref = mod_ref(x)
            torch.testing.assert_allclose(out, out_ref)
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.min_val,
                mod.activation_post_process.min_val,
            )
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.max_val,
                mod.activation_post_process.max_val,
            )