# Example #1
    def test_fused_mod_per_channel(self):
        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        m = 5
        n = 10
        for device in devices:
            running_min_op = torch.empty(m, device=device).fill_(float("inf"))
            running_max_op = torch.empty(m, device=device).fill_(float("-inf"))
            avg_const = 0.001
            scale = torch.empty(m, device=device).fill_(0.1)
            zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
            obs = FusedMovingAvgObsFakeQuantize.with_args(
                averaging_constant=avg_const,
                observer=MovingAveragePerChannelMinMaxObserver,
            )
            mod = obs()
            mod = torch.jit.script(mod)
            mod.to(device)

            for i in range(10):
                x = torch.randn(m, n, device=device)
                if i > 2:
                    mod.observer_enabled[0] = 1
                if i > 4:
                    mod.fake_quant_enabled[0] = 1
                # Run the forward on the Module
                out = mod(x)

                # Run the operator directly
                pt_op = torch.fused_moving_avg_obs_fake_quant

                out_ref = pt_op(
                    x,
                    mod.observer_enabled,
                    mod.fake_quant_enabled,
                    running_min_op,
                    running_max_op,
                    scale,
                    zero_point,
                    avg_const,
                    0,
                    255,
                    0,
                    True,
                    False,
                )
                # Compare params with reference
                torch.testing.assert_allclose(out, out_ref)
                if mod.observer_enabled[0]:
                    torch.testing.assert_allclose(
                        running_min_op, mod.activation_post_process.min_val
                    )
                    torch.testing.assert_allclose(
                        running_max_op, mod.activation_post_process.max_val
                    )
                if mod.fake_quant_enabled:
                    torch.testing.assert_allclose(scale, mod.scale)
                    torch.testing.assert_allclose(zero_point, mod.zero_point)

            torch.testing.assert_allclose(mod.state_dict()['activation_post_process.min_val'], running_min_op)
            torch.testing.assert_allclose(mod.state_dict()['activation_post_process.max_val'], running_max_op)
# Example #2
    def test_fused_obs_fq_module(self, device):
        # Set up the parameters
        x = torch.randn(5, 5, device=device)
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        # Run the forward on the Module
        mod = FusedMovingAvgObsFakeQuantize()
        torch.quantization.enable_fake_quant(mod)
        torch.quantization.enable_observer(mod)
        mod.to(device)
        out = mod(x)

        # Run the operator directly
        pt_op = torch.fused_moving_avg_obs_fake_quant

        out_ref = pt_op(
            x,
            mod.observer_enabled,
            mod.fake_quant_enabled,
            running_min_op,
            running_max_op,
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
            False,
        )

        # Compare params with reference
        torch.testing.assert_allclose(out, out_ref)
        torch.testing.assert_allclose(
            running_min_op, mod.activation_post_process.min_val
        )
        torch.testing.assert_allclose(
            running_max_op, mod.activation_post_process.max_val
        )
# Example #3
    def test_compare_fused_obs_fq_oss_module(self, device):
        mod = FusedMovingAvgObsFakeQuantize()
        torch.quantization.enable_fake_quant(mod)
        torch.quantization.enable_observer(mod)
        mod.to(device)

        mod_ref = FakeQuantize()
        torch.quantization.enable_fake_quant(mod_ref)
        torch.quantization.enable_observer(mod_ref)
        mod_ref.to(device)

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            out = mod(x)
            out_ref = mod_ref(x)
            torch.testing.assert_allclose(out, out_ref)
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.min_val,
                mod.activation_post_process.min_val,
            )
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.max_val,
                mod.activation_post_process.max_val,
            )
# Example #4
    def test_fused_mod_reduce_range(self):
        obs = FusedMovingAvgObsFakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=True)

        self.assertEqual(obs.quant_min, 0)
        self.assertEqual(obs.quant_max, 127)