コード例 #1
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme. Must be ``torch.per_tensor_affine`` or
            ``torch.per_tensor_symmetric``.
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: For the user to specify what kind of output observer they
            would like to use. If unspecified, a
            :class:`~torch.quantization.observer.MinMaxObserver` is built from
            the arguments above. It is only stored on the instance;
            :meth:`forward` does not call it.
        factory_kwargs: Passed through unchanged to the observers that are
            created here.

    Raises:
        TypeError: If ``qscheme`` is not one of the per-tensor schemes.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    def __init__(self,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 output_obs=None,
                 factory_kwargs=None) -> None:
        super().__init__()

        if qscheme not in {
                torch.per_tensor_affine, torch.per_tensor_symmetric
        }:
            raise TypeError("Input qscheme must be per-tensor")

        # ch_axis=1 makes the observer keep one running min/max per input
        # column.
        self.input_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        if output_obs is None:
            self.output_obs = MinMaxObserver(dtype=dtype,
                                             qscheme=qscheme,
                                             quant_min=quant_min,
                                             quant_max=quant_max,
                                             factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs

        # An empty tensor marks "no equalization scale has been set yet".
        self.equalization_scale = torch.empty(0)

    def forward(self, x_orig):
        r"""
        Records the per-column min/max of a 2-D input and returns whatever the
        underlying column observer returns for it.

        Raises:
            ValueError: If ``x_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if x_orig.ndim != 2:
            raise ValueError(
                "InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        r"""
        Returns the ``(min_vals, max_vals)`` pair held by the column observer.
        """
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""
        Stores the equalization scale that :meth:`calculate_qparams` applies.
        """
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the equalized input as a 2-tuple.

        If no equalization scale has been set (it is still empty), a warning
        is issued and the default pair ``(tensor([1.0]), tensor([0]))`` is
        returned, so both paths hand back the same number of values.
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_scale before calling calculate_qparams. "
                "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index, then reduce to a single global min and max.
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(
            torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(
            torch.mul(max_inputs, self.equalization_scale))
        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(
            min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input
コード例 #2
0
ファイル: _equalize.py プロジェクト: mwootton/pytorch
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        factory_kwargs: Passed through unchanged to the observers that are
            created here.

    This observer is made up of 2 PerChannelMinMaxObservers

    - weight_col_obs: Used to record the running minimum and maximum of
      columns of incoming weight tensors
    - weight_row_obs: Used to record the running minimum and maximum of
      rows of incoming weight tensors

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the same
    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    def __init__(self,
                 dtype=torch.qint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=None,
                 quant_max=None,
                 factory_kwargs=None) -> None:
        super().__init__()

        # ch_axis=1: one running min/max per weight column.
        self.weight_col_obs = PerChannelMinMaxObserver(
            ch_axis=1,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # ch_axis=0: one running min/max per weight row.
        self.weight_row_obs = PerChannelMinMaxObserver(
            ch_axis=0,
            dtype=dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs)

        # An empty tensor marks "no equalization scale has been set yet".
        self.equalization_scale = torch.empty(0)

        # Column index of the running min/max of every weight row. They stay
        # None until a weight has been observed, which is what
        # calculate_qparams() checks for.
        self.min_weights_ind = None
        self.max_weights_ind = None

    def forward(self, w_orig):
        r"""
        Records the column/row statistics of a 2-D weight and returns the
        result of passing it through both observers.

        Raises:
            ValueError: If ``w_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if w_orig.ndim != 2:
            raise ValueError(
                "WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)

    def _forward(self, w_orig):
        r"""
        Calculates the min/max values of each weight column and weight row,
        and records in which column each row's running min/max is located.
        """

        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Per-row extremes of the current tensor and the columns they sit in.
        # When a row has ties the first matching column is used.
        weights = w_orig.detach()
        (running_min, running_max) = self.get_weight_row_minmax()
        (cur_min, cur_min_ind) = torch.min(weights, dim=1)
        (cur_max, cur_max_ind) = torch.max(weights, dim=1)

        # A row's stored index is only replaced when the running extreme is
        # actually present in this tensor; otherwise the running value came
        # from an earlier call and its recorded column is still the right one.
        self.min_weights_ind = self._refresh_indices(
            self.min_weights_ind, cur_min_ind,
            cur_min.to(running_min) == running_min)
        self.max_weights_ind = self._refresh_indices(
            self.max_weights_ind, cur_max_ind,
            cur_max.to(running_max) == running_max)

        return w_orig

    @staticmethod
    def _refresh_indices(previous, current, refreshed):
        r"""
        Returns ``current`` where ``refreshed`` is True and ``previous``
        elsewhere. Falls back to ``current`` alone when there is no usable
        previous value. The result is kept on the CPU.
        """
        current = current.cpu()
        if previous is None or previous.shape != current.shape:
            return current
        return torch.where(refreshed.cpu(), current, previous)

    @staticmethod
    def _default_qparams():
        r"""
        Returns the fallback ``(scale, zero_point)`` pair.
        """
        return torch.tensor([1.0]), torch.tensor([0])

    def get_weight_col_minmax(self):
        r"""
        Returns the ``(min_vals, max_vals)`` pair held by the column observer.
        """
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        r"""
        Returns the ``(min_vals, max_vals)`` pair held by the row observer.
        """
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""
        Stores the equalization scale that :meth:`calculate_qparams` applies.
        """
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the equalized weight rows as a
        2-tuple.

        If the equalization scale is still empty, or no weight has been
        observed yet, a warning is issued and the default pair
        ``(tensor([1.0]), tensor([0]))`` is returned instead.
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_scale before calling calculate_qparams. "
                "Returning default scale and zero point.")
            return self._default_qparams()

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the minimum and maximum of "
                "each row in the weights in order to calculate the qparams. "
                "Returning default scale and zero point.")
            return self._default_qparams()

        # Calculate the qparams for weights by using the rows
        # Scale the weight rows by the reciprocal of the equalization scale
        # located at the same column index
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(
            min_weights,
            torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(
            max_weights,
            torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
        (scale_weight,
         zero_point_weight) = self.weight_row_obs._calculate_qparams(
             min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight