コード例 #1
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
    def __init__(self,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 output_obs=None,
                 factory_kwargs=None) -> None:
        """Create the wrapped input observer and the output observer.

        Args:
            dtype: Quantized data type forwarded to the wrapped observers.
            qscheme: Quantization scheme; only the two per-tensor schemes
                are accepted.
            quant_min: Forwarded unchanged to the wrapped observers.
            quant_max: Forwarded unchanged to the wrapped observers.
            output_obs: Optional ready-made output observer. When ``None``,
                a ``MinMaxObserver`` is built from the same settings.
            factory_kwargs: Forwarded unchanged to the wrapped observers.

        Raises:
            TypeError: If ``qscheme`` is not a per-tensor scheme.
        """
        super().__init__()

        per_tensor_schemes = {
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        }
        if qscheme not in per_tensor_schemes:
            raise TypeError("Input qscheme must be per-tensor")

        # Settings handed to both wrapped observers.
        observer_settings = dict(
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )

        # The input statistics are tracked along axis 1.
        self.input_obs = PerChannelMinMaxObserver(ch_axis=1,
                                                  **observer_settings)

        # Only build a default output observer when the caller gave none.
        self.output_obs = (MinMaxObserver(**observer_settings)
                           if output_obs is None else output_obs)

        # Starts out as an empty tensor until a scale is assigned.
        self.equalization_scale = torch.empty(0)
コード例 #2
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
    def __init__(self,
                 dtype=torch.qint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 factory_kwargs=None) -> None:
        """Create the column and row weight observers.

        Args:
            dtype: Quantized data type forwarded to both observers.
            qscheme: Quantization scheme forwarded to both observers.
            quant_min: Forwarded unchanged to both observers.
            quant_max: Forwarded unchanged to both observers.
            factory_kwargs: Forwarded unchanged to both observers.
        """
        super(_WeightEqualizationObserver, self).__init__()

        # Tracks min/max along axis 1 (the weight columns).
        self.weight_col_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # Tracks min/max along axis 0 (the weight rows).
        self.weight_row_obs = PerChannelMinMaxObserver(
            ch_axis=0,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # Starts out as an empty tensor until a scale is assigned.
        self.equalization_scale = torch.empty(0)

        # calculate_qparams() compares these against None before use, so they
        # must exist even when forward() has not been called yet; otherwise
        # the attribute lookup itself raises AttributeError.
        self.min_weights_ind = None
        self.max_weights_ind = None
コード例 #3
0
ファイル: test_equalize_fx.py プロジェクト: vors/pytorch
    def test_input_weight_eq_observer(self, ndim, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
        """Compare the equalization observers against NumPy reference values.

        Random ``ndim``-dimensional input and weight arrays that share the
        same size along axis 1 are passed through
        ``_InputEqualizationObserver`` and ``_WeightEqualizationObserver``.
        The recorded min/max values, the equalization scale, the qparams of
        the scaled input, and the qparams of a ``PerChannelMinMaxObserver``
        run on the rescaled weights are each checked against values
        recomputed with NumPy.

        Args:
            ndim: Number of dimensions of the generated arrays (at least 2).
            input_qdtype: Quantized dtype given to the input observers.
            input_qscheme: Quantization scheme given to the input observers.
            weight_qdtype: Quantized dtype given to the weight observers.
            weight_qscheme: Quantization scheme given to the weight observers.

        Raises:
            ValueError: If ``ndim`` is smaller than 2.
        """
        if ndim < 2:
            raise ValueError("ndim must be at least 2")

        # Draw all free dimension sizes first, then the shared channel size.
        sizes = [np.random.randint(2, 10) for _ in range((ndim - 1) * 2)]
        channel = np.random.randint(1, 10)

        # The first half of `sizes` shapes x, the second half shapes w; the
        # shared channel dimension always sits at axis 1.
        n_trailing = ndim - 2
        x_shape = (sizes[0], channel, *sizes[1:1 + n_trailing])
        w_shape = (sizes[1 + n_trailing], channel, *sizes[2 + n_trailing:])
        x = np.random.random(size=x_shape)
        w = np.random.random(size=w_shape)

        x = (x * 10).round(decimals=2).astype(np.float32)
        w = (w * 10).round(decimals=2).astype(np.float32)

        input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
        weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

        # The observers must return exactly what they were given
        ret_x = input_eq_obs(torch.tensor(x))
        ret_w = weight_eq_obs(torch.tensor(w))
        self.assertEqual((ret_x, ret_w), (x, w))

        # Check the min/max input columns are correct
        ref_min_inputs, ref_max_inputs = self.channel_minmax(x)
        min_inputs, max_inputs = input_eq_obs.get_input_minmax()
        self.assertEqual(min_inputs, torch.tensor(ref_min_inputs, dtype=torch.float32))
        self.assertEqual(max_inputs, torch.tensor(ref_max_inputs, dtype=torch.float32))

        # Check the min/max weight columns are correct
        ref_min_weights_col, ref_max_weights_col = self.channel_minmax(w)
        min_weights_col, max_weights_col = weight_eq_obs.get_weight_col_minmax()
        self.assertEqual(min_weights_col, torch.tensor(ref_min_weights_col, dtype=torch.float32))
        self.assertEqual(max_weights_col, torch.tensor(ref_max_weights_col, dtype=torch.float32))

        # Check the equalization scale is correct
        equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
        ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
                                         (ref_max_inputs - ref_min_inputs))
        self.assertEqual(equalization_scale, torch.tensor(ref_equalization_scale, dtype=torch.float32))

        input_eq_obs.set_equalization_scale(equalization_scale)
        weight_eq_obs.set_equalization_scale(equalization_scale)

        # Check the input scale/zero-point values
        min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
        input_quant_obs = MinMaxObserver(dtype=input_qdtype, qscheme=input_qscheme)
        input_quant_obs.min_val = min_input_scaled
        input_quant_obs.max_val = max_input_scaled
        input_qparams = input_quant_obs.calculate_qparams()

        # The reference range is widened so that it always contains zero
        ref_min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
        ref_min_input_scaled = min(0, ref_min_input_scaled)
        ref_max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
        ref_max_input_scaled = max(0, ref_max_input_scaled)

        if input_qscheme == torch.per_tensor_symmetric:
            ref_scale = 2 * max(abs(ref_min_input_scaled), ref_max_input_scaled) / 255
            ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
        else:
            ref_scale = (ref_max_input_scaled - ref_min_input_scaled) / 255
            quant_min = -128 if input_qdtype is torch.qint8 else 0
            quant_max = 127 if input_qdtype is torch.qint8 else 255
            ref_zero_point = quant_min - np.round(ref_min_input_scaled / ref_scale)
            # np.clip returns a new value; it must be assigned back for the
            # clamp to take effect.
            ref_zero_point = np.clip(ref_zero_point, quant_min, quant_max)

        self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        self.assertEqual(input_qparams[1].item(), ref_zero_point)

        # During input-weight equalization, we will scale the weights so that
        # the following weight quantized observer will have the correct scaled qparams
        # Check the weight scale/zero-point values of the quantized observer
        weight_quant_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=weight_qdtype, qscheme=weight_qscheme)

        # Scale the weights for input-weight equalization; the scale is
        # reshaped so that it broadcasts along axis 1 only
        new_shape = [1] * w.ndim
        new_shape[1] = w.shape[1]
        ref_w_scaled = w * np.reciprocal(ref_equalization_scale.reshape(tuple(new_shape)))

        w = torch.tensor(w)
        new_shape[1] = w.size(1)
        w_scaled = torch.mul(w, torch.reciprocal(equalization_scale.view(new_shape)))

        self.assertEqual(w_scaled, ref_w_scaled)

        # Call forward on the weight quantization observer
        weight_quant_obs(w_scaled)

        # Check the min/max of the scaled weights are correct
        ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
        self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
        self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))

        weight_qparams = weight_quant_obs.calculate_qparams()

        if weight_qscheme == torch.per_channel_symmetric:
            ref_min_weights_scaled = np.minimum(np.zeros(ref_min_weights_scaled.shape), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros(ref_max_weights_scaled.shape), ref_max_weights_scaled)

            ref_scales = 2 * np.maximum(np.abs(ref_min_weights_scaled), ref_max_weights_scaled) / 255
            ref_zero_points = np.zeros_like(
                ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
        elif weight_qscheme == torch.per_channel_affine_float_qparams:
            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            # Scales at or below 1e-7 are replaced by 1 in the reference
            ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
            ref_zero_points = -1 * ref_min_weights_scaled / ref_scales
        else:
            ref_min_weights_scaled = np.minimum(np.zeros_like(ref_min_weights_scaled), ref_min_weights_scaled)
            ref_max_weights_scaled = np.maximum(np.zeros_like(ref_max_weights_scaled), ref_max_weights_scaled)

            ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
            ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
            ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales)

        self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor(
            ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001))
        self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor(
            ref_zero_points, dtype=weight_qparams[1].dtype), atol=1))
コード例 #4
0
    def test_input_weight_eq_observer(self, input_qdtype, input_qscheme,
                                      weight_qdtype, weight_qscheme):
        """Verify both equalization observers against NumPy references.

        Random 2-D input and weight arrays with the same number of columns
        are passed through the observers. The recorded column min/max
        values, the equalization scale, the scaled input qparams and the
        qparams of a per-channel observer run on the rescaled weights are
        each compared with values recomputed in NumPy.
        """
        input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype,
                                                  qscheme=input_qscheme)
        weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype,
                                                    qscheme=weight_qscheme)

        # x and w share their column count but have independent row counts.
        n_cols = np.random.randint(1, 10)
        x_rows = np.random.randint(2, 10)
        w_rows = np.random.randint(2, 10)

        x = np.random.random(size=(x_rows, n_cols)) * 10
        x = x.round(decimals=2).astype(np.float32)
        w = np.random.random(size=(w_rows, n_cols)) * 10
        w = w.round(decimals=2).astype(np.float32)

        # The observers must return exactly what they were given.
        out_x = input_eq_obs(torch.tensor(x))
        out_w = weight_eq_obs(torch.tensor(w))
        self.assertEqual((out_x, out_w), (x, w))

        # Column-wise min/max recorded for the input.
        x_col_min, x_col_max = x.min(axis=0), x.max(axis=0)
        self.assertEqual(input_eq_obs.get_input_minmax(),
                         (x_col_min, x_col_max))

        # Column-wise min/max recorded for the weight.
        w_col_min, w_col_max = w.min(axis=0), w.max(axis=0)
        self.assertEqual(weight_eq_obs.get_weight_col_minmax(),
                         (w_col_min, w_col_max))

        # Equalization scale: sqrt of the weight range over the input range.
        eq_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
        w_col_range = w_col_max - w_col_min
        x_col_range = x_col_max - x_col_min
        eq_scale_ref = np.sqrt(w_col_range / x_col_range)
        self.assertEqual(eq_scale, eq_scale_ref)

        input_eq_obs.set_equalization_scale(eq_scale)
        weight_eq_obs.set_equalization_scale(eq_scale)

        # Input qparams derived from the scaled global min/max.
        scaled_min, scaled_max = input_eq_obs.calculate_scaled_minmax()
        input_quant_obs = MinMaxObserver(dtype=input_qdtype,
                                         qscheme=input_qscheme)
        input_quant_obs.min_val = scaled_min
        input_quant_obs.max_val = scaled_max
        input_qparams = input_quant_obs.calculate_qparams()

        # The reference range is widened so that it always contains zero.
        lo_ref = min(0, np.min(x_col_min * eq_scale_ref))
        hi_ref = max(0, np.max(x_col_max * eq_scale_ref))

        input_is_qint8 = input_qdtype is torch.qint8
        if input_qscheme == torch.per_tensor_symmetric:
            scale_ref = 2 * max(abs(lo_ref), hi_ref) / 255
            if input_is_qint8:
                zero_point_ref = 0
            else:
                zero_point_ref = 128
        else:
            scale_ref = (hi_ref - lo_ref) / 255
            if input_is_qint8:
                zero_point_ref = -128
            else:
                zero_point_ref = 0

        self.assertEqual(input_qparams[0].item(), scale_ref, atol=1e-5, rtol=0)
        self.assertEqual(input_qparams[1].item(), zero_point_ref)

        # The weights are rescaled by the reciprocal of the equalization
        # scale; a fresh per-channel observer run on them should then give
        # the expected qparams.
        weight_quant_obs = PerChannelMinMaxObserver(dtype=weight_qdtype,
                                                    qscheme=weight_qscheme)

        w_scaled_ref = w * np.reciprocal(eq_scale_ref)
        inv_scale = torch.reciprocal(eq_scale)
        w_scaled = torch.mul(torch.tensor(w), inv_scale)
        self.assertEqual(w_scaled_ref, w_scaled)

        weight_quant_obs(w_scaled)

        # Row-wise min/max of the rescaled weights.
        row_min_ref = w_scaled_ref.min(axis=1)
        row_max_ref = w_scaled_ref.max(axis=1)
        self.assertEqual(weight_quant_obs.min_vals, row_min_ref)
        self.assertEqual(weight_quant_obs.max_vals, row_max_ref)

        weight_qparams = weight_quant_obs.calculate_qparams()

        weight_is_qint8 = weight_qdtype is torch.qint8
        if weight_qscheme == torch.per_channel_symmetric:
            row_min_ref = np.minimum(np.zeros(row_min_ref.shape), row_min_ref)
            row_max_ref = np.maximum(np.zeros(row_max_ref.shape), row_max_ref)

            scales_ref = 2 * np.maximum(np.abs(row_min_ref), row_max_ref) / 255
            if weight_is_qint8:
                zero_points_ref = np.zeros_like(scales_ref)
            else:
                zero_points_ref = np.ones_like(scales_ref) * 128
        elif weight_qscheme == torch.per_channel_affine_float_qparams:
            scales_ref = (row_max_ref - row_min_ref) / 255
            # Scales at or below 1e-7 are replaced by 1 in the reference.
            scales_ref = np.where(scales_ref > 1e-7, scales_ref,
                                  np.ones_like(scales_ref))
            zero_points_ref = -1 * row_min_ref / scales_ref
        else:
            row_min_ref = np.minimum(np.zeros_like(row_min_ref), row_min_ref)
            row_max_ref = np.maximum(np.zeros_like(row_max_ref), row_max_ref)

            scales_ref = (row_max_ref - row_min_ref) / 255
            zero_point_base = -128 if weight_is_qint8 else 0
            zero_points_ref = zero_point_base - np.round(
                row_min_ref / scales_ref)

        expected_scales = torch.tensor(scales_ref,
                                       dtype=weight_qparams[0].dtype)
        self.assertTrue(
            torch.allclose(weight_qparams[0], expected_scales, atol=0.0001))

        expected_zero_points = torch.tensor(zero_points_ref,
                                            dtype=weight_qparams[1].dtype)
        self.assertTrue(
            torch.allclose(weight_qparams[1], expected_zero_points, atol=1))
コード例 #5
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: For the user to specify what kind of output observer they
            would like to use
        factory_kwargs: Forwarded unchanged to the wrapped observers

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    def __init__(self,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 output_obs=None,
                 factory_kwargs=None) -> None:
        super(_InputEqualizationObserver, self).__init__()

        if qscheme not in {
                torch.per_tensor_affine, torch.per_tensor_symmetric
        }:
            raise TypeError("Input qscheme must be per-tensor")

        # Tracks min/max along axis 1 (the input columns).
        self.input_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # Fall back to a MinMaxObserver with the same settings when the
        # caller does not supply an output observer.
        if output_obs is None:
            self.output_obs = MinMaxObserver(dtype=dtype,
                                             qscheme=qscheme,
                                             quant_min=quant_min,
                                             quant_max=quant_max,
                                             factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs

        # An empty tensor marks the scale as "not set yet"; see
        # calculate_qparams().
        self.equalization_scale = torch.empty(0)

    def forward(self, x_orig):
        r"""Records the per-column min/max of ``x_orig`` and returns the
        result of the wrapped input observer.

        Raises:
            ValueError: If ``x_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if x_orig.ndim != 2:
            raise ValueError(
                "InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        r"""Returns the (min, max) values recorded by the input observer."""
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by :meth:`calculate_qparams`."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the ``(scale, zero_point)`` pair for the scaled input.

        When no equalization scale has been set, a warning is issued and the
        default pair ``(tensor([1.0]), tensor([0]))`` is returned, so both
        code paths yield a tuple of the same length.
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_scale before calling calculate_qparams. "
                "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index, then reduce to a single global min and max
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(
            torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(
            torch.mul(max_inputs, self.equalization_scale))
        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(
            min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input
コード例 #6
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        factory_kwargs: Forwarded unchanged to the wrapped observers

    This observer is made up of 2 PerChannelMinMaxObservers
        - weight_col_obs: Used to record the running minimum and maximum of
        columns of incoming weight tensors
        - weight_row_obs: Used to record the running minimum and maximum of
        rows of incoming weight tensors

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the same
    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    def __init__(self,
                 dtype=torch.qint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 factory_kwargs=None) -> None:
        super(_WeightEqualizationObserver, self).__init__()

        # Tracks min/max along axis 1 (the weight columns).
        self.weight_col_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # Tracks min/max along axis 0 (the weight rows).
        self.weight_row_obs = PerChannelMinMaxObserver(
            ch_axis=0,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # An empty tensor marks the scale as "not set yet"; see
        # calculate_qparams().
        self.equalization_scale = torch.empty(0)

        # Filled in by _forward(). They must exist beforehand because
        # calculate_qparams() compares them against None; without this the
        # attribute lookup itself raises AttributeError.
        self.min_weights_ind = None
        self.max_weights_ind = None

    def forward(self, w_orig):
        r"""Records the weight statistics of ``w_orig`` and returns it.

        Raises:
            ValueError: If ``w_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if w_orig.ndim != 2:
            raise ValueError(
                "WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)

    def _forward(self, w_orig):
        r"""
        Calculates the min/max values of each weight column and weight row,
        and records for every row the column index at which the row's
        min/max value is found.
        """

        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Calculate the column indices of the min/max weight in each row.
        # NOTE(review): the lookup takes the first exact match of the
        # observer's recorded value inside the current row; if the recorded
        # running value stems from an earlier call and is absent from this
        # tensor, the ``[0][0]`` index would raise IndexError -- confirm.
        num_row, _ = w_orig.shape
        row_min_vals = self.weight_row_obs.min_vals
        row_max_vals = self.weight_row_obs.max_vals
        min_weights_ind = []
        max_weights_ind = []
        for i in range(num_row):
            min_weights_ind.append(
                torch.nonzero(w_orig[i] == row_min_vals[i])[0][0])
            max_weights_ind.append(
                torch.nonzero(w_orig[i] == row_max_vals[i])[0][0])
        self.min_weights_ind = torch.tensor(min_weights_ind)
        self.max_weights_ind = torch.tensor(max_weights_ind)

        return w_orig

    def get_weight_col_minmax(self):
        r"""Returns the (min, max) values recorded by the column observer."""
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        r"""Returns the (min, max) values recorded by the row observer."""
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by :meth:`calculate_qparams`."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the ``(scale, zero_point)`` pair for the scaled weight rows.

        When the equalization scale has not been set, or the per-row min/max
        column indices have not been computed yet, a warning is issued and
        the default pair ``(tensor([1.0]), tensor([0]))`` is returned, so
        every code path yields a tuple of the same length.
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_scale before calling calculate_qparams. "
                "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the minimum and maximum of "
                "each row in the weights in order to calculate the qparams. "
                "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate the qparams for weights by using the rows
        # Scale the weight rows by the reciprocal of the equalization scale
        # located at the same column index
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(
            min_weights,
            torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(
            max_weights,
            torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
        (scale_weight,
         zero_point_weight) = self.weight_row_obs._calculate_qparams(
             min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight