Esempio n. 1
0
    def collect(self, x):
        """Tracks the absolute max of all tensors

        Reduces ``x`` to its absolute maximum over every dimension that is not listed in
        ``self._axis`` and folds the result into the running maximum ``self._calib_amax``.
        When ``self._track_amax`` is set, the per-call amax is also appended to
        ``self._amaxs`` as a NumPy array.

        Args:
            x: A tensor

        Raises:
            RuntimeError: If amax shape changes
        """
        if torch.min(x) < 0.:
            logging.log_first_n(logging.INFO, (
                "Calibrator encountered negative values. It shouldn't happen after ReLU. "
                "Make sure this is the right tensor to calibrate."), 1)
            x = x.abs()

        # Dimensions listed in self._axis are kept; every other dimension is reduced.
        axis = self._axis if isinstance(self._axis, (list, tuple)) else [self._axis]
        # Map negative indices to their positive equivalent. Without this, axis=-1 would
        # match none of range(x.dim()) and every dimension would be reduced silently.
        ndim = x.dim()
        axis = [ndim + i if isinstance(i, int) and i < 0 else i for i in axis]
        reduce_axis = [i for i in range(ndim) if i not in axis]

        local_amax = quant_utils.reduce_amax(x, axis=reduce_axis).detach()
        if self._calib_amax is None:
            self._calib_amax = local_amax
        else:
            if local_amax.shape != self._calib_amax.shape:
                raise RuntimeError("amax shape changed!")
            self._calib_amax.copy_(
                torch.max(self._calib_amax, local_amax).data)

        if self._track_amax:
            # Store a copy: for a CPU tensor, .cpu().numpy() shares storage with the
            # tensor, and on the first call that tensor is self._calib_amax, which the
            # in-place copy_ above overwrites on later calls.
            self._amaxs.append(local_amax.cpu().numpy().copy())
Esempio n. 2
0
    def test_fake_quant_per_channel_bias(self):
        """QuantConvTranspose3d with bias must match a hand-built fake-quantized reference.

        The reference quantizes the input with a single amax (max of its absolute values)
        and the weight with an amax reduced over axes (0, 2, 3, 4), then runs the plain
        functional transposed convolution with the module's own bias.
        """
        conv = quant_conv.QuantConvTranspose3d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            3,
            bias=True,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE3D_WEIGHT_PER_CHANNEL)
        inputs = torch.randn(2, _NUM_IN_CHANNELS, 2, 2, 2)

        # Reference input: one amax for the whole tensor.
        input_amax = torch.max(torch.abs(inputs))
        ref_input = tensor_quant.fake_tensor_quant(inputs, input_amax)

        # Reference weight: quantize a copy so the module's parameter is left alone.
        ref_weight = conv.weight.clone()
        weight_amax = quant_utils.reduce_amax(ref_weight, axis=(0, 2, 3, 4))
        ref_weight = tensor_quant.fake_tensor_quant(ref_weight, weight_amax)

        expected = F.conv_transpose3d(ref_input, ref_weight, bias=conv.bias)
        actual = conv(inputs)
        np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                      actual.detach().cpu().numpy())
Esempio n. 3
0
    def test_reduce_amax(self):
        """Check quant_utils.reduce_amax against a NumPy reference.

        Covers the full reduction to one value, several axis combinations with a randomly
        chosen ``keepdims``, and the error raised when more axes are requested than the
        tensor has.
        """
        # Shift by -0.1 so the data contains some negative values and abs() matters.
        x_np = (np.random.rand(3, 7, 11, 13, 17) - 0.1).astype(np.float32)
        x_torch = torch.tensor(x_np)

        # Test reduce to one value
        amax_np = np.max(np.abs(x_np))
        amax_torch = quant_utils.reduce_amax(x_torch)
        np.testing.assert_array_equal(amax_np, amax_torch.cpu().numpy())

        # Test different axis
        axes = [(1, 2, 3), (0, 2, 3), (0, 3), (0, 1, 3, 4)]
        for axis in axes:
            keepdims = np.random.rand() > 0.5
            amax_np = np.max(np.abs(x_np), axis=axis, keepdims=keepdims)
            amax_torch = quant_utils.reduce_amax(x_torch, axis=axis, keepdims=keepdims)
            np.testing.assert_array_almost_equal(amax_np, amax_torch.cpu().numpy())

        # The message is checked through match=: a statement placed after the raising
        # call inside the with-block would never be executed.
        with pytest.raises(ValueError, match="Cannot reduce more axes"):
            quant_utils.reduce_amax(x_torch, axis=(0, 1, 2, 3, 4, 5))
Esempio n. 4
0
    def test_max_calib(self):
        """Max calibration through TensorQuantizer, with an axis and with learn_amax.

        Checks that loading amax before any data was collected raises, that the amax
        computed after two batches equals the elementwise max of the two per-batch
        reductions, and that a learn_amax quantizer initialises its clip range from the
        largest calibrated value.
        """
        axis = 0
        reduce_axis = (1, 2, 3)
        quant_desc1 = tensor_quant.QuantDescriptor(axis=axis)
        quantizer1 = tensor_quantizer.TensorQuantizer(quant_desc1).cuda()
        quantizer1.enable_calib()

        # Nothing has been collected yet, so there is no amax to load.
        with pytest.raises(RuntimeError, match="Calibrator returned None"):
            quantizer1.load_calib_amax()

        x_1 = torch.rand(127, 63, 7, 7).cuda()
        x_2 = torch.rand(127, 63, 7, 7).cuda()
        quantizer1(x_1)
        quantizer1(x_2)
        quantizer1.disable_calib()

        # Expected value: elementwise max of the two per-batch reductions.
        global_amax = torch.max(
            quant_utils.reduce_amax(x_1, axis=reduce_axis, keepdims=True),
            quant_utils.reduce_amax(x_2, axis=reduce_axis, keepdims=True))
        test_utils.compare(quantizer1._calibrator.compute_amax(), global_amax, atol=0, rtol=0, ctol=0)

        quantizer1.load_calib_amax()
        test_utils.compare(quantizer1.amax, global_amax, atol=0, rtol=0, ctol=0)

        quant_desc2 = tensor_quant.QuantDescriptor(learn_amax=True)
        quantizer2 = tensor_quantizer.TensorQuantizer(quant_desc2).cuda()
        quantizer2.enable_calib()
        quantizer2(x_1)
        quantizer2(x_2)

        quantizer2.load_calib_amax()
        quantizer2.init_learn_amax()
        test_utils.compare(quantizer2.clip.clip_value_min, -torch.max(global_amax), atol=0, rtol=0, ctol=0)
        test_utils.compare(quantizer2.clip.clip_value_max, torch.max(global_amax), atol=0, rtol=0, ctol=0)
Esempio n. 5
0
    def test_fine_grain(self):
        """MaxCalibrator built with axis 0: amax after two collects, then reset.

        The expected amax is the reduction over axes (1, 2, 3) of the elementwise max of
        the two collected tensors. After reset() compute_amax() must return None.
        """
        calibrator = calib.MaxCalibrator(8, 0, False)

        first = torch.rand(31, 63, 7, 7).cuda()
        second = torch.rand(31, 63, 7, 7).cuda()
        for batch in (first, second):
            calibrator.collect(batch)

        assert calibrator.compute_amax().shape[0] == 31

        collected = calibrator.compute_amax()
        expected = quant_utils.reduce_amax(torch.max(first, second), axis=(1, 2, 3))
        test_utils.compare(collected, expected, atol=0, rtol=0, ctol=0)

        calibrator.reset()
        assert calibrator.compute_amax() is None
Esempio n. 6
0
    def test_weight_fake_quant_per_channel(self):
        """QuantConv1d with a weight descriptor on axis 0 must match a manual reference.

        Input quantization is disabled, so only the weight path is compared: the
        reference fake-quantizes a copy of the weight with an amax reduced over axes
        (1, 2) and runs the plain functional conv1d.
        """
        conv = quant_conv.QuantConv1d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            3,
            bias=False,
            quant_desc_weight=QuantDescriptor(axis=0))
        conv.input_quantizer.disable()
        inputs = torch.randn(16, _NUM_IN_CHANNELS, 256)

        # Quantize a copy so the module's own parameter is not touched.
        ref_weight = conv.weight.clone()
        weight_amax = quant_utils.reduce_amax(ref_weight, axis=(1, 2))
        ref_weight = tensor_quant.fake_tensor_quant(ref_weight, weight_amax)

        expected = F.conv1d(inputs, ref_weight)
        actual = conv(inputs)
        np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                      actual.detach().cpu().numpy())
Esempio n. 7
0
    def test_weight_fake_per_channel(self):
        """QuantLinear with the per-row weight descriptor must match a manual reference.

        Input quantization is disabled. The reference fake-quantizes a copy of the weight
        with an amax reduced over axis 1 (keepdims=True) and applies F.linear.
        """
        in_features, out_features = 255, 257
        linear = quant_linear.QuantLinear(
            in_features,
            out_features,
            bias=False,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW)
        linear.input_quantizer.disable()
        inputs = torch.randn(32, in_features)

        # Quantize a copy so the module's own parameter is not touched.
        ref_weight = linear.weight.clone()
        row_amax = quant_utils.reduce_amax(ref_weight, axis=1, keepdims=True)
        ref_weight = tensor_quant.fake_tensor_quant(ref_weight, row_amax)

        expected = F.linear(inputs, ref_weight)
        actual = linear(inputs)
        np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                      actual.detach().cpu().numpy())
    def test_weight_fake_quant_per_channel(self):
        """QuantConvTranspose2d with per-channel weights must match a manual reference.

        Input quantization is disabled. The reference fake-quantizes a copy of the weight
        with an amax reduced over axes (0, 2, 3) and runs the plain functional
        conv_transpose2d.
        """
        conv = quant_conv.QuantConvTranspose2d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            3,
            bias=False,
            quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL)
        conv.input_quantizer.disable()
        inputs = torch.randn(16, _NUM_IN_CHANNELS, 256, 256)

        # Quantize a copy so the module's own parameter is not touched.
        ref_weight = conv.weight.clone()
        weight_amax = quant_utils.reduce_amax(ref_weight, axis=(0, 2, 3))
        ref_weight = tensor_quant.fake_tensor_quant(ref_weight, weight_amax)

        expected = F.conv_transpose2d(inputs, ref_weight)
        actual = conv(inputs)
        np.testing.assert_array_equal(expected.detach().cpu().numpy(),
                                      actual.detach().cpu().numpy())
Esempio n. 9
0
    def _get_amax(self, inputs):
        """get amax from buffer or compute it dynamically.

        Args:
            inputs: A tensor. Only read when the instance has no ``_amax`` attribute.

        Returns:
            ``self._amax`` when present. Otherwise the absolute max of ``inputs`` reduced
            (with keepdims=True) over every dimension not listed in ``self._axis``, or
            over all dimensions when ``self._axis`` is None. In both cases the value is
            multiplied by ``self._scale_amax`` when that is not None.
        """
        if hasattr(self, '_amax'):
            amax = self._amax
        else:
            if self._axis is None:
                reduce_axis = None
            else:
                # Dimensions listed in self._axis are kept; all others are reduced.
                axis = self._axis if isinstance(self._axis, (list, tuple)) else [self._axis]
                # Map negative indices to their positive equivalent. Without this,
                # axis=-1 would match none of range(ndim) and every dimension would be
                # reduced silently.
                ndim = inputs.dim()
                axis = [ndim + i if isinstance(i, int) and i < 0 else i for i in axis]
                reduce_axis = [i for i in range(ndim) if i not in axis]
            amax = quant_utils.reduce_amax(inputs,
                                           axis=reduce_axis,
                                           keepdims=True).detach()
        if self._scale_amax is not None:
            amax = amax.detach() * self._scale_amax

        return amax
Esempio n. 10
0
def calibrate_weights(model,
                      method="percentile",
                      perchannel=True,
                      percentile=99.99,
                      num_bins=2048):
    """Calibrate weights of all child quantized modules

    Ideally, we would split calibration functionality to histogram collector and calibrator which
    takes histogram and compute amax. But since we haven't decoupled collector and calibrator, it
    is easier to create a separate function to calibrate weight.

    Every module that has both a ``weight`` and a ``weight_quantizer`` attribute gets
    ``weight_quantizer.amax`` overwritten with the calibrated value (as a NumPy array).

    .. note::
        This function uses `method` specified by the argument to decide which method to use, NOT the one
        specified by the calibrator embedded in weight_quantizer.
        We haven't moved calibration to GPU, so everything is transferred to CPU

    Args:
        model: A torch.nn.Module.
        method: A string of calibration method. Supports "mse", "percentile" and "max".
            Default "percentile"
        perchannel: A bool. Set channel/neuron axis if True. Default True.
        percentile: A float. Default 99.99
        num_bins: A integer. Number of bins of histogram. Default 2048.

    Raises:
        TypeError: If `method` is not one of the supported strings. Only raised once a
            module with a weight and a weight_quantizer is encountered.
    """
    # For these module types the per-channel axis is 1 instead of 0.
    channel_second_modules = (quant_nn.QuantConvTranspose1d,
                              quant_nn.QuantConvTranspose2d,
                              quant_nn.QuantConvTranspose3d)

    for name, module in model.named_modules():
        if not (hasattr(module, "weight") and hasattr(module, "weight_quantizer")):
            continue

        logging.info("Calibrate weight of %s", name)
        num_bits = module.weight_quantizer.num_bits
        unsigned = module.weight_quantizer.unsigned
        if perchannel:
            axis = 1 if isinstance(module, channel_second_modules) else 0
        else:
            axis = None
        axis_size = module.weight.shape[axis] if axis is not None else 1

        # Histogram is always collected even if method is "max". Although "max" is supported here
        # but it is not the primary usage of this function
        if axis is None:
            # One histogram over the whole weight; honours num_bins like the
            # per-channel branch does.
            calib_hist, calib_bin_edges = np.histogram(
                module.weight.abs().cpu().detach().numpy(), bins=num_bins)
            calib_hist = [calib_hist]
            calib_bin_edges = [calib_bin_edges]
        else:
            # One histogram per slice along the channel axis.
            calib_hist = []
            calib_bin_edges = []
            for i in range(axis_size):
                channel = module.weight.index_select(
                    axis, torch.tensor(i, device=module.weight.device))
                hist, bin_edges = np.histogram(
                    channel.abs().cpu().detach().numpy(), bins=num_bins)
                calib_hist.append(hist)
                calib_bin_edges.append(bin_edges)

        calib_amax = []
        if method == "max":
            if axis is None:
                # Per-tensor: reduce over every dimension.
                reduce_axis = None
            else:
                reduce_axis = [d for d in range(module.weight.dim()) if d != axis]
            calib_amax.append(
                quant_utils.reduce_amax(module.weight, axis=reduce_axis))
        elif method == 'mse':
            for i in range(axis_size):
                calib_amax.append(
                    _compute_amax_mse(calib_hist[i], calib_bin_edges[i],
                                      num_bits, unsigned))
        elif method == 'percentile':
            for i in range(axis_size):
                calib_amax.append(
                    _compute_amax_percentile(calib_hist[i],
                                             calib_bin_edges[i],
                                             percentile))
        else:
            raise TypeError(
                "Unsupported calibration method {}".format(method))

        if axis is None:
            calib_amax = calib_amax[0]
        else:
            # Shape the result so it has size 1 everywhere except the channel axis.
            calib_amax_shape = [1] * module.weight.dim()
            calib_amax_shape[axis] = module.weight.shape[axis]
            calib_amax = torch.stack(calib_amax).reshape(calib_amax_shape)
        module.weight_quantizer.amax = calib_amax.detach().cpu().numpy()