예제 #1
0
    def test_dequantize_dim(self):
        """Check that dequantizing an APoT tensor keeps the quantized data's size.

        A randomly sized 3-D float tensor is observed, quantized with the
        qparams returned by the observer and dequantized again; the size of
        the quantized tensor's ``data`` must equal the size of the
        dequantized result.
        """
        apot_observer = APoTObserver(4, 2)

        # three random extents, each drawn from 1 -> 20
        dims = [random.randint(1, 20) for _ in range(3)]

        # random fp values between 0 -> 1000
        fp_input = 1000 * torch.rand(*dims, dtype=torch.float)

        apot_observer.forward(fp_input)
        alpha, gamma, quantization_levels, level_indices = \
            apot_observer.calculate_qparams(signed=False)

        # fp -> APoT
        quantized = quantize_APoT(
            tensor2quantize=fp_input,
            alpha=alpha,
            gamma=gamma,
            quantization_levels=quantization_levels,
            level_indices=level_indices,
        )

        # APoT -> fp
        restored = dequantize_APoT(apot_tensor=quantized)

        self.assertEqual(quantized.data.size(), restored.size())
예제 #2
0
    def forward(self, X: torch.Tensor, signed: bool):  # type: ignore[override]
        """Optionally observe ``X``, then optionally fake-quantize it.

        Args:
            X: input passed to the observer and/or quantized.
            signed: forwarded to the observer's ``calculate_qparams``.

        Returns:
            ``X`` unchanged when fake quantization is disabled; otherwise the
            result of ``dequantize_APoT(quantize_APoT(X, ...))`` using the
            qparams stored on this object.

        Raises:
            AssertionError: if fake quantization is enabled while any of
                ``alpha``, ``gamma``, ``quantization_levels`` or
                ``level_indices`` is ``None``.
        """
        if self.observer_enabled[0] == 1:
            # Feed the observer and refresh the stored qparams from it.
            observer = self.activation_post_process
            observer.forward(X)
            (self.alpha,
             self.gamma,
             self.quantization_levels,
             self.level_indices) = observer.calculate_qparams(signed)

        if self.fake_quant_enabled[0] != 1:
            return X

        assert not (self.alpha is None
                    or self.gamma is None
                    or self.quantization_levels is None
                    or self.level_indices is None), "Must set qparams for fake quant"

        # Simulate quantization error: quantize, then immediately dequantize.
        apot_tensor = quantize_APoT(
            X, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
        return dequantize_APoT(apot_tensor)
예제 #3
0
    def forward(ctx,  # type: ignore[override]
                x: Tensor,
                alpha: Tensor,
                gamma: Tensor,
                quantization_levels: Tensor,
                level_indices: Tensor) -> Tensor:
        """Quantize ``x`` to APoT, dequantize it again, and save a range mask.

        Args:
            ctx: context object; the mask is stored on it with
                ``save_for_backward``.
            x: tensor to quantize and dequantize.
            alpha: bound used for the mask; elements of ``x`` with
                ``-alpha <= x <= alpha`` get mask value 1, all others 0.
            gamma: forwarded to ``quantize_APoT``.
            quantization_levels: forwarded to ``quantize_APoT``.
            level_indices: forwarded to ``quantize_APoT``.

        Returns:
            The result of ``dequantize_APoT`` applied to the quantized ``x``.
        """
        quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices)

        # calculate mask tensor (1 inside [-alpha, alpha], 0 outside), kept in
        # x's dtype. It is built out-of-place on purpose: `x.detach()` shares
        # storage with `x`, so an in-place `apply_` on it would overwrite the
        # caller's input with the mask values. The elementwise comparison also
        # avoids a per-element Python callback and is not restricted to CPU
        # tensors.
        x_values = x.detach()
        mask = ((x_values <= alpha) & (x_values >= -alpha)).to(x_values.dtype)

        result = dequantize_APoT(quantized_result)

        ctx.save_for_backward(mask)

        return result
예제 #4
0
    def test_forward(self):
        """Compare ``APoTFakeQuantize.forward`` with an explicit round trip.

        The expected value is obtained by quantizing and then dequantizing
        the same input with qparams from a separately constructed
        ``APoTObserver(b=4, k=2)``; the fake-quantize output must be exactly
        equal to it.
        """
        # tensor of size 20 with random values between 0 -> 1000
        X = 1000 * torch.rand(20)

        reference_observer = APoTObserver(b=4, k=2)
        reference_observer.forward(X)
        alpha, gamma, levels, indices = reference_observer.calculate_qparams(signed=False)

        fake_quantizer = APoTFakeQuantize(b=4, k=2)
        fake_quantizer.enable_observer()
        fake_quantizer.enable_fake_quant()

        # the module under test receives its own copy of X
        actual = fake_quantizer.forward(X.clone(), False)

        # fp -> apot -> fp to simulate quantize -> dequantize
        as_apot = quantize_APoT(X, alpha, gamma, levels, indices)
        expected = dequantize_APoT(as_apot)

        self.assertTrue(torch.equal(actual, expected))
예제 #5
0
    def test_dequantize_quantize_rand_b6(self):
        """Check that quantize -> dequantize -> quantize reproduces the data.

        A random 1-D float tensor is quantized to APoT, dequantized, and
        quantized again with the same qparams; the integer data of the first
        and second quantization must be exactly equal.
        """
        # NOTE(review): the test name says "b6" but the observer is built
        # with (12, 4) -- confirm which is intended.
        apot_observer = APoTObserver(12, 4)

        # random length between 1 -> 20
        length = random.randint(1, 20)

        # random fp values between 0 -> 1000
        fp_input = 1000 * torch.rand(length, dtype=torch.float)

        apot_observer.forward(fp_input)
        alpha, gamma, quantization_levels, level_indices = \
            apot_observer.calculate_qparams(signed=False)

        # the same qparams are used for both quantization passes
        qparams = dict(alpha=alpha,
                       gamma=gamma,
                       quantization_levels=quantization_levels,
                       level_indices=level_indices)

        # first pass: fp -> APoT; keep an integer copy of its data
        first_pass = quantize_APoT(tensor2quantize=fp_input, **qparams)
        expected = torch.clone(first_pass.data).int()

        # APoT -> fp -> APoT
        roundtrip_fp = dequantize_APoT(apot_tensor=first_pass)
        second_pass = quantize_APoT(tensor2quantize=roundtrip_fp, **qparams)
        actual = second_pass.data.int()

        self.assertTrue(torch.equal(expected, actual))