def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
             quant_min=None, quant_max=None, output_obs=None,
             factory_kwargs=None) -> None:
    """Set up the observers used for input equalization.

    Args:
        dtype: Quantized data type forwarded to the observers.
        qscheme: Quantization scheme; must be ``torch.per_tensor_affine``
            or ``torch.per_tensor_symmetric``.
        quant_min: Minimum quantization value forwarded to the observers.
        quant_max: Maximum quantization value forwarded to the observers.
        output_obs: Optional observer to use as ``self.output_obs``; when
            ``None`` a ``MinMaxObserver`` is created with the arguments above.
        factory_kwargs: Forwarded unchanged to the observers.

    Raises:
        TypeError: If ``qscheme`` is not one of the two per-tensor schemes.
    """
    super(_InputEqualizationObserver, self).__init__()

    per_tensor_qschemes = {torch.per_tensor_affine, torch.per_tensor_symmetric}
    if qscheme not in per_tensor_qschemes:
        raise TypeError("Input qscheme must be per-tensor")

    # Both observers are configured with the same quantization settings.
    shared_kwargs = dict(
        dtype=dtype,
        qscheme=qscheme,
        quant_min=quant_min,
        quant_max=quant_max,
        factory_kwargs=factory_kwargs,
    )

    # ch_axis=1 makes the observer keep one min/max per input column.
    self.input_obs = PerChannelMinMaxObserver(ch_axis=1, **shared_kwargs)
    self.output_obs = (
        MinMaxObserver(**shared_kwargs) if output_obs is None else output_obs
    )

    # An empty tensor marks "equalization scale not set yet".
    self.equalization_scale = torch.empty(0)
def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine,
             quant_min=None, quant_max=None, factory_kwargs=None) -> None:
    """Set up the column/row observers used for weight equalization.

    Args:
        dtype: Quantized data type forwarded to both observers.
        qscheme: Quantization scheme forwarded to both observers.
        quant_min: Minimum quantization value forwarded to both observers.
        quant_max: Maximum quantization value forwarded to both observers.
        factory_kwargs: Forwarded unchanged to both observers.
    """
    super(_WeightEqualizationObserver, self).__init__()

    # ch_axis=1 tracks min/max per weight column, ch_axis=0 per weight row.
    self.weight_col_obs = PerChannelMinMaxObserver(
        ch_axis=1, dtype=dtype, qscheme=qscheme, quant_min=quant_min,
        quant_max=quant_max, factory_kwargs=factory_kwargs)
    self.weight_row_obs = PerChannelMinMaxObserver(
        ch_axis=0, dtype=dtype, qscheme=qscheme, quant_min=quant_min,
        quant_max=quant_max, factory_kwargs=factory_kwargs)

    # An empty tensor marks "equalization scale not set yet".
    self.equalization_scale = torch.empty(0)

    # Column indices of each row's min/max weight. They are filled in once
    # a weight tensor has been observed; initialising them here lets
    # ``is None`` checks work instead of raising AttributeError when the
    # observer has not seen any weights yet.
    self.min_weights_ind = None
    self.max_weights_ind = None
def test_input_weight_eq_observer(self, ndim, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
    """Check the input/weight equalization observers against NumPy references.

    Random ``ndim``-dimensional input and weight arrays that share the same
    channel count (axis 1) are fed through the observers. The recorded
    per-channel min/max, the equalization scale, the input qparams and the
    qparams of the equalized weights are then compared with values
    recomputed in NumPy.

    Args:
        ndim: Number of dimensions of the generated input/weight arrays.
        input_qdtype: Quantized dtype for the input observers.
        input_qscheme: Quantization scheme for the input observers.
        weight_qdtype: Quantized dtype for the weight observers.
        weight_qscheme: Quantization scheme for the weight observers.
    """
    # Draw (ndim - 1) non-channel sizes for x followed by (ndim - 1) for w,
    # then the shared channel count. The draw order matches the previous
    # hard-coded ndim 2..5 ladder, so seeded runs produce the same shapes.
    sizes = [np.random.randint(2, 10) for _ in range((ndim - 1) * 2)]
    channel = np.random.randint(1, 10)
    x_sizes, w_sizes = sizes[:ndim - 1], sizes[ndim - 1:]

    # The channel dimension always sits at axis 1.
    x = np.random.random(size=(x_sizes[0], channel, *x_sizes[1:]))
    w = np.random.random(size=(w_sizes[0], channel, *w_sizes[1:]))

    x = (x * 10).round(decimals=2).astype(np.float32)
    w = (w * 10).round(decimals=2).astype(np.float32)

    input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
    weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

    # The observers must hand their input back unchanged.
    ret_x = input_eq_obs(torch.tensor(x))
    ret_w = weight_eq_obs(torch.tensor(w))
    self.assertEqual((ret_x, ret_w), (x, w))

    # Check the min/max input columns are correct
    # NOTE(review): channel_minmax is defined elsewhere; presumably it
    # returns the per-channel (axis 1) min and max as NumPy arrays.
    ref_min_inputs, ref_max_inputs = self.channel_minmax(x)
    min_inputs, max_inputs = input_eq_obs.get_input_minmax()
    self.assertEqual(min_inputs, torch.tensor(ref_min_inputs, dtype=torch.float32))
    self.assertEqual(max_inputs, torch.tensor(ref_max_inputs, dtype=torch.float32))

    # Check the min/max weight columns are correct
    ref_min_weights_col, ref_max_weights_col = self.channel_minmax(w)
    min_weights_col, max_weights_col = weight_eq_obs.get_weight_col_minmax()
    self.assertEqual(min_weights_col, torch.tensor(ref_min_weights_col, dtype=torch.float32))
    self.assertEqual(max_weights_col, torch.tensor(ref_max_weights_col, dtype=torch.float32))

    # Check the equalization scale is correct
    equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
    ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
                                     (ref_max_inputs - ref_min_inputs))
    self.assertEqual(equalization_scale, torch.tensor(ref_equalization_scale, dtype=torch.float32))

    input_eq_obs.set_equalization_scale(equalization_scale)
    weight_eq_obs.set_equalization_scale(equalization_scale)

    # Check the input scale/zero-point values
    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
    input_quant_obs = MinMaxObserver(dtype=input_qdtype, qscheme=input_qscheme)
    input_quant_obs.min_val = min_input_scaled
    input_quant_obs.max_val = max_input_scaled
    input_qparams = input_quant_obs.calculate_qparams()

    # Reference range is widened to include zero before computing qparams.
    ref_min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
    ref_min_input_scaled = min(0, ref_min_input_scaled)
    ref_max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
    ref_max_input_scaled = max(0, ref_max_input_scaled)

    if input_qscheme == torch.per_tensor_symmetric:
        ref_scale = 2 * max(abs(ref_min_input_scaled), ref_max_input_scaled) / 255
        ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
    else:
        ref_scale = (ref_max_input_scaled - ref_min_input_scaled) / 255
        quant_min = -128 if input_qdtype is torch.qint8 else 0
        quant_max = 127 if input_qdtype is torch.qint8 else 255
        ref_zero_point = quant_min - np.round(ref_min_input_scaled / ref_scale)
        # np.clip returns a new value; keep it so the clamp takes effect.
        ref_zero_point = np.clip(ref_zero_point, quant_min, quant_max)

    self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
    self.assertEqual(input_qparams[1].item(), ref_zero_point)

    # During input-weight equalization, we will scale the weights so that
    # the following weight quantized observer will have the correct scaled qparams
    # Check the weight scale/zero-point values of the quantized observer
    weight_quant_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=weight_qdtype, qscheme=weight_qscheme)

    # Scale the weights for input-weight equalization
    # The scale has one entry per channel; reshape it to broadcast along axis 1.
    new_shape = [1] * w.ndim
    new_shape[1] = w.shape[1]
    ref_w_scaled = w * np.reciprocal(ref_equalization_scale.reshape(tuple(new_shape)))

    w = torch.tensor(w)
    w_scaled = torch.mul(w, torch.reciprocal(equalization_scale.view(new_shape)))

    self.assertEqual(w_scaled, ref_w_scaled)

    # Call forward on the weight quantization observer
    weight_quant_obs(w_scaled)

    # Check the min/max weight rows are correct
    ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
    self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
    self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))

    weight_qparams = weight_quant_obs.calculate_qparams()

    if weight_qscheme == torch.per_channel_symmetric:
        ref_min_weights_scaled = np.minimum(np.zeros(ref_min_weights_scaled.shape), ref_min_weights_scaled)
        ref_max_weights_scaled = np.maximum(np.zeros(ref_max_weights_scaled.shape), ref_max_weights_scaled)

        ref_scales = 2 * np.maximum(np.abs(ref_min_weights_scaled), ref_max_weights_scaled) / 255
        ref_zero_points = np.zeros_like(
            ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
    elif weight_qscheme == torch.per_channel_affine_float_qparams:
        ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
        # Degenerate (near-zero) ranges fall back to a scale of 1.
        ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
        ref_zero_points = -1 * ref_min_weights_scaled / ref_scales
    else:
        ref_min_weights_scaled = np.minimum(np.zeros_like(ref_min_weights_scaled), ref_min_weights_scaled)
        ref_max_weights_scaled = np.maximum(np.zeros_like(ref_max_weights_scaled), ref_max_weights_scaled)

        ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
        ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
        ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales)

    self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor(
        ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001))
    self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor(
        ref_zero_points, dtype=weight_qparams[1].dtype), atol=1))
def test_input_weight_eq_observer(self, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
    """ Tests that the Input- and Weight- EqualizationObservers perform as expected

    A random 2-D input and a random 2-D weight with the same number of
    columns are passed through the equalization observers. The recorded
    column min/max values, the equalization scale, the input qparams and
    the qparams of the equalized weights are compared with references
    recomputed in NumPy.

    Args:
        input_qdtype: Quantized dtype for the input observers.
        input_qscheme: Quantization scheme for the input observers.
        weight_qdtype: Quantized dtype for the weight observers.
        weight_qscheme: Quantization scheme for the weight observers.
    """
    input_eq_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
    weight_eq_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

    # x and w share `width` columns; their row counts are drawn independently.
    width = np.random.randint(1, 10)
    x_height = np.random.randint(2, 10)
    w_height = np.random.randint(2, 10)

    # Values lie in [0, 10), rounded to two decimals, stored as float32.
    x = (np.random.random(size=(x_height, width)) * 10).round(decimals=2).astype(np.float32)
    w = (np.random.random(size=(w_height, width)) * 10).round(decimals=2).astype(np.float32)

    # The observers are expected to return their input unchanged.
    ret_x = input_eq_obs(torch.tensor(x))
    ret_w = weight_eq_obs(torch.tensor(w))
    self.assertEqual((ret_x, ret_w), (x, w))

    # Check the min/max input columns are correct
    ref_min_inputs = x.min(axis=0)
    ref_max_inputs = x.max(axis=0)
    self.assertEqual(input_eq_obs.get_input_minmax(), (ref_min_inputs, ref_max_inputs))

    # Check the min/max weight columns are correct
    ref_min_weights_col = w.min(axis=0)
    ref_max_weights_col = w.max(axis=0)
    self.assertEqual(weight_eq_obs.get_weight_col_minmax(), (ref_min_weights_col, ref_max_weights_col))

    # Check the equalization scale is correct
    # Reference: sqrt(weight column range / input column range), per column.
    equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
    ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
                                     (ref_max_inputs - ref_min_inputs))
    self.assertEqual(equalization_scale, ref_equalization_scale)

    input_eq_obs.set_equalization_scale(equalization_scale)
    weight_eq_obs.set_equalization_scale(equalization_scale)

    # Check the input scale/zero-point values
    # The scaled min/max are written straight into a fresh MinMaxObserver so
    # that its calculate_qparams() runs on them.
    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
    input_quant_obs = MinMaxObserver(dtype=input_qdtype, qscheme=input_qscheme)
    input_quant_obs.min_val = min_input_scaled
    input_quant_obs.max_val = max_input_scaled
    input_qparams = input_quant_obs.calculate_qparams()

    # Reference range is widened to include zero.
    ref_min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
    ref_min_input_scaled = min(0, ref_min_input_scaled)
    ref_max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
    ref_max_input_scaled = max(0, ref_max_input_scaled)

    if input_qscheme == torch.per_tensor_symmetric:
        ref_scale = 2 * max(abs(ref_min_input_scaled), ref_max_input_scaled) / 255
        ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
    else:
        ref_scale = (ref_max_input_scaled - ref_min_input_scaled) / 255
        # The generated inputs are non-negative, so the reference minimum is
        # clamped to 0 above and the zero point is the lowest quantized value.
        ref_zero_point = -128 if input_qdtype is torch.qint8 else 0

    self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
    self.assertEqual(input_qparams[1].item(), ref_zero_point)

    # During input-weight equalization, we will scale the weights so that
    # the following weight quantized observer will have the correct scaled qparams
    # Check the weight scale/zero-point values of the quantized observer
    # NOTE(review): ch_axis is left at the observer's default here; the
    # reference below reduces along axis=1 (one min/max per weight row).
    weight_quant_obs = PerChannelMinMaxObserver(dtype=weight_qdtype, qscheme=weight_qscheme)

    # Scale the weights for input-weight equalization
    # Each weight column is divided by that column's equalization scale.
    ref_w_scaled = w * np.reciprocal(ref_equalization_scale)
    w_scaled = torch.mul(torch.tensor(w), torch.reciprocal(equalization_scale))
    self.assertEqual(ref_w_scaled, w_scaled)

    # Call forward on the weight quantization observer
    weight_quant_obs(w_scaled)

    # Check the min/max weight rows are correct
    ref_min_weights_scaled = ref_w_scaled.min(axis=1)
    ref_max_weights_scaled = ref_w_scaled.max(axis=1)
    self.assertEqual(weight_quant_obs.min_vals, ref_min_weights_scaled)
    self.assertEqual(weight_quant_obs.max_vals, ref_max_weights_scaled)

    weight_qparams = weight_quant_obs.calculate_qparams()

    # Recompute the expected per-row scale/zero point for each qscheme.
    if weight_qscheme == torch.per_channel_symmetric:
        ref_min_weights_scaled = np.minimum(np.zeros(ref_min_weights_scaled.shape), ref_min_weights_scaled)
        ref_max_weights_scaled = np.maximum(np.zeros(ref_max_weights_scaled.shape), ref_max_weights_scaled)

        ref_scales = 2 * np.maximum(np.abs(ref_min_weights_scaled), ref_max_weights_scaled) / 255
        ref_zero_points = np.zeros_like(ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
    elif weight_qscheme == torch.per_channel_affine_float_qparams:
        ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
        # Ranges at or below 1e-7 fall back to a scale of 1.
        ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
        ref_zero_points = -1 * ref_min_weights_scaled / ref_scales
    else:
        ref_min_weights_scaled = np.minimum(np.zeros_like(ref_min_weights_scaled), ref_min_weights_scaled)
        ref_max_weights_scaled = np.maximum(np.zeros_like(ref_max_weights_scaled), ref_max_weights_scaled)

        ref_scales = (ref_max_weights_scaled - ref_min_weights_scaled) / 255
        ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
        ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales)

    # Scales must match within 1e-4; zero points may differ by at most 1.
    self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor(ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001))
    self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor(ref_zero_points, dtype=weight_qparams[1].dtype), atol=1))
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: For the user to specify what kind of output observer they
            would like to use

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, output_obs=None,
                 factory_kwargs=None) -> None:
        super().__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        # ch_axis=1 makes the observer keep one min/max per input column.
        self.input_obs = PerChannelMinMaxObserver(
            ch_axis=1, dtype=dtype, qscheme=qscheme, quant_min=quant_min,
            quant_max=quant_max, factory_kwargs=factory_kwargs)

        if output_obs is None:
            self.output_obs = MinMaxObserver(
                dtype=dtype, qscheme=qscheme, quant_min=quant_min,
                quant_max=quant_max, factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs

        # An empty tensor marks "equalization scale not set yet".
        self.equalization_scale = torch.empty(0)

    def forward(self, x_orig):
        r"""Records the column min/max of a 2-D input and returns the result
        of the underlying per-channel observer.

        Raises:
            ValueError: If ``x_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if x_orig.ndim != 2:
            raise ValueError(
                "InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        r"""Returns the (min, max) values recorded per input column."""
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by ``calculate_qparams``."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""Returns the ``(scale, zero_point)`` pair for the equalized input.

        If no equalization scale has been set, a warning is emitted and the
        default pair ``(tensor([1.0]), tensor([0]))`` is returned, so callers
        always receive a 2-tuple.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call set_equalization_scale before calling "
                "calculate_qparams. Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index, then reduce to a single global min and max.
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(
            torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(
            torch.mul(max_inputs, self.equalization_scale))

        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(
            min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    This observer is made up of 2 PerChannelMinMaxObservers
        - weight_col_obs: Used to record the running minimum and maximum of
          columns of incoming weight tensors
        - weight_row_obs: Used to record the running minimum and maximum of
          rows of incoming weight tensors

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the
    same way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, factory_kwargs=None) -> None:
        super().__init__()

        # ch_axis=1 tracks min/max per weight column, ch_axis=0 per weight row.
        self.weight_col_obs = PerChannelMinMaxObserver(
            ch_axis=1, dtype=dtype, qscheme=qscheme, quant_min=quant_min,
            quant_max=quant_max, factory_kwargs=factory_kwargs)
        self.weight_row_obs = PerChannelMinMaxObserver(
            ch_axis=0, dtype=dtype, qscheme=qscheme, quant_min=quant_min,
            quant_max=quant_max, factory_kwargs=factory_kwargs)

        # An empty tensor marks "equalization scale not set yet".
        self.equalization_scale = torch.empty(0)

        # Column indices of each row's min/max weight, filled in by _forward.
        # Initialised here so the ``is None`` checks in calculate_qparams work
        # even when no weight tensor has been observed yet.
        self.min_weights_ind = None
        self.max_weights_ind = None

    def forward(self, w_orig):
        r"""Records column/row min/max of a 2-D weight and returns the result
        of the underlying observers.

        Raises:
            ValueError: If ``w_orig`` is not 2-dimensional.
        """
        # TODO: Allow for convolutional layers
        if w_orig.ndim != 2:
            raise ValueError(
                "WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)

    def _forward(self, w_orig):
        r"""Calculates the min/max values of each weight column and weight row,
        and records the column index of each row's min and max.
        """
        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Calculate the column indices of the min/max weight in each row.
        # For every row, take the first column whose value equals the row
        # observer's recorded min (resp. max).
        # NOTE(review): the row observer holds running values; if it has seen
        # earlier tensors, a recorded value may be absent from this w_orig and
        # the [0][0] lookup would raise - confirm the intended usage.
        num_row, _ = w_orig.shape
        row_mins, row_maxs = self.get_weight_row_minmax()
        min_weights_ind = [
            torch.nonzero(w_orig[i] == row_mins[i])[0][0]
            for i in range(num_row)
        ]
        max_weights_ind = [
            torch.nonzero(w_orig[i] == row_maxs[i])[0][0]
            for i in range(num_row)
        ]

        self.min_weights_ind = torch.tensor(min_weights_ind)
        self.max_weights_ind = torch.tensor(max_weights_ind)

        return w_orig

    def get_weight_col_minmax(self):
        r"""Returns the (min, max) values recorded per weight column."""
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        r"""Returns the (min, max) values recorded per weight row."""
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the equalization scale used by ``calculate_qparams``."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""Returns the ``(scale, zero_point)`` pair for the equalized weight rows.

        If the equalization scale has not been set, or no weight has been
        observed yet, a warning is emitted and the default pair
        ``(tensor([1.0]), tensor([0]))`` is returned, so callers always
        receive a 2-tuple.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call set_equalization_scale before calling "
                "calculate_qparams. Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the minimum and maximum of "
                "each row in the weights in order to calculate the qparams. "
                "Returning default scale and zero point.")
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate the qparams for weights by using the rows
        # Scale the weight rows by the reciprocal of the equalization scale
        # located at the same column index
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(
            min_weights,
            torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(
            max_weights,
            torch.reciprocal(self.equalization_scale[self.max_weights_ind]))

        (scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(
            min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight