def collect(self, x):
    """Tracks the absolute max of all tensors

    Args:
        x: A tensor

    Raises:
        RuntimeError: If amax shape changes
    """
    if torch.min(x) < 0.:
        logging.log_first_n(
            logging.INFO,
            ("Calibrator encountered negative values. It shouldn't happen after ReLU. "
             "Make sure this is the right tensor to calibrate."),
            1)
        x = x.abs()

    # Swap axis to reduce.
    axis = self._axis if isinstance(self._axis, (list, tuple)) else [self._axis]
    reduce_axis = []
    for i in range(x.dim()):
        if i not in axis:
            reduce_axis.append(i)
    local_amax = quant_utils.reduce_amax(x, axis=reduce_axis).detach()
    if self._calib_amax is None:
        self._calib_amax = local_amax
    else:
        if local_amax.shape != self._calib_amax.shape:
            raise RuntimeError("amax shape changed!")
        self._calib_amax.copy_(torch.max(self._calib_amax, local_amax).data)

    if self._track_amax:
        self._amaxs.append(local_amax.cpu().numpy())
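# Usage sketch (illustrative, not part of the library): how collect() accumulates a
# running per-channel amax across batches. The calibrator is constructed positionally
# as (num_bits, axis, unsigned), matching the tests below; shapes are arbitrary, and
# the `calib` and `torch` names are assumed to be the imports already used in this file.
def _example_max_calibrator_usage():
    calibrator = calib.MaxCalibrator(8, 0, False)
    for _ in range(4):
        # Each call takes a running max over every axis except axis 0
        calibrator.collect(torch.rand(31, 63, 7, 7))
    return calibrator.compute_amax()  # leading dimension 31: one amax per channel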
def test_fake_quant_per_channel_bias(self):
    kernel_size = 3

    quant_conv_object = quant_conv.QuantConvTranspose3d(
        _NUM_IN_CHANNELS,
        _NUM_OUT_CHANNELS,
        kernel_size,
        bias=True,
        quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE3D_WEIGHT_PER_CHANNEL)
    test_input = torch.randn(2, _NUM_IN_CHANNELS, 2, 2, 2)

    quant_input = tensor_quant.fake_tensor_quant(test_input, torch.max(torch.abs(test_input)))

    weight_copy = quant_conv_object.weight.clone()
    amax = quant_utils.reduce_amax(weight_copy, axis=(0, 2, 3, 4))
    quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

    out1 = F.conv_transpose3d(quant_input, quant_weight, bias=quant_conv_object.bias)
    out2 = quant_conv_object(test_input)
    np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
def test_reduce_amax(self):
    x_np = (np.random.rand(3, 7, 11, 13, 17) - 0.1).astype(np.float32)
    x_torch = torch.tensor(x_np)

    # Test reduce to one value
    amax_np = np.max(np.abs(x_np))
    amax_torch = quant_utils.reduce_amax(x_torch)
    np.testing.assert_array_equal(amax_np, amax_torch.cpu().numpy())

    # Test different axes
    axes = [(1, 2, 3), (0, 2, 3), (0, 3), (0, 1, 3, 4)]
    for axis in axes:
        keepdims = np.random.rand() > 0.5
        amax_np = np.max(np.abs(x_np), axis=axis, keepdims=keepdims)
        amax_torch = quant_utils.reduce_amax(x_torch, axis=axis, keepdims=keepdims)
        np.testing.assert_array_almost_equal(amax_np, amax_torch.cpu().numpy())

    with pytest.raises(ValueError) as excinfo:
        quant_utils.reduce_amax(x_torch, axis=(0, 1, 2, 3, 4, 5))
    assert "Cannot reduce more axes" in str(excinfo.value)
def test_max_calib(self):
    axis = 0
    reduce_axis = (1, 2, 3)
    quant_desc1 = tensor_quant.QuantDescriptor(axis=axis)
    quantizer1 = tensor_quantizer.TensorQuantizer(quant_desc1).cuda()
    quantizer1.enable_calib()

    with pytest.raises(RuntimeError, match="Calibrator returned None"):
        quantizer1.load_calib_amax()

    x_1 = torch.rand(127, 63, 7, 7).cuda()
    x_2 = torch.rand(127, 63, 7, 7).cuda()
    quantizer1(x_1)
    quantizer1(x_2)
    quantizer1.disable_calib()

    global_amax = torch.max(
        quant_utils.reduce_amax(x_1, axis=reduce_axis, keepdims=True),
        quant_utils.reduce_amax(x_2, axis=reduce_axis, keepdims=True))
    test_utils.compare(quantizer1._calibrator.compute_amax(), global_amax, atol=0, rtol=0, ctol=0)

    quantizer1.load_calib_amax()
    test_utils.compare(quantizer1.amax, global_amax, atol=0, rtol=0, ctol=0)

    quant_desc2 = tensor_quant.QuantDescriptor(learn_amax=True)
    quantizer2 = tensor_quantizer.TensorQuantizer(quant_desc2).cuda()
    quantizer2.enable_calib()
    quantizer2(x_1)
    quantizer2(x_2)

    quantizer2.load_calib_amax()
    quantizer2.init_learn_amax()
    test_utils.compare(quantizer2.clip.clip_value_min, -torch.max(global_amax), atol=0, rtol=0, ctol=0)
    test_utils.compare(quantizer2.clip.clip_value_max, torch.max(global_amax), atol=0, rtol=0, ctol=0)
def test_fine_grain(self):
    axis = 0
    reduce_axis = (1, 2, 3)
    max_calibrator = calib.MaxCalibrator(8, axis, False)

    x_1 = torch.rand(31, 63, 7, 7).cuda()
    x_2 = torch.rand(31, 63, 7, 7).cuda()
    max_calibrator.collect(x_1)
    max_calibrator.collect(x_2)

    assert max_calibrator.compute_amax().shape[0] == 31

    test_utils.compare(max_calibrator.compute_amax(),
                       quant_utils.reduce_amax(torch.max(x_1, x_2), axis=reduce_axis),
                       atol=0, rtol=0, ctol=0)

    max_calibrator.reset()
    assert max_calibrator.compute_amax() is None
def test_weight_fake_quant_per_channel(self):
    kernel_size = 3

    quant_conv_object = quant_conv.QuantConv1d(
        _NUM_IN_CHANNELS,
        _NUM_OUT_CHANNELS,
        kernel_size,
        bias=False,
        quant_desc_weight=QuantDescriptor(axis=(0)))
    quant_conv_object.input_quantizer.disable()
    test_input = torch.randn(16, _NUM_IN_CHANNELS, 256)

    weight_copy = quant_conv_object.weight.clone()
    amax = quant_utils.reduce_amax(weight_copy, axis=(1, 2))
    quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

    out1 = F.conv1d(test_input, quant_weight)
    out2 = quant_conv_object(test_input)
    np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
def test_weight_fake_per_channel(self):
    size_in = 255
    size_out = 257

    quant_linear_object = quant_linear.QuantLinear(
        size_in,
        size_out,
        bias=False,
        quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW)
    quant_linear_object.input_quantizer.disable()
    test_input = torch.randn(32, size_in)

    weight_copy = quant_linear_object.weight.clone()
    amax = quant_utils.reduce_amax(weight_copy, axis=1, keepdims=True)
    quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

    out1 = F.linear(test_input, quant_weight)
    out2 = quant_linear_object(test_input)
    np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
def test_weight_fake_quant_per_channel(self):
    kernel_size = 3

    quant_conv_object = quant_conv.QuantConvTranspose2d(
        _NUM_IN_CHANNELS,
        _NUM_OUT_CHANNELS,
        kernel_size,
        bias=False,
        quant_desc_weight=tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL)
    quant_conv_object.input_quantizer.disable()
    test_input = torch.randn(16, _NUM_IN_CHANNELS, 256, 256)

    weight_copy = quant_conv_object.weight.clone()
    amax = quant_utils.reduce_amax(weight_copy, axis=(0, 2, 3))
    quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

    out1 = F.conv_transpose2d(test_input, quant_weight)
    out2 = quant_conv_object(test_input)
    np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
def _get_amax(self, inputs):
    """Get amax from buffer or compute it dynamically."""
    if hasattr(self, '_amax'):
        amax = self._amax
    else:
        if self._axis is None:
            reduce_axis = None
        else:
            reduce_axis = []
            # Swap axis to reduce
            axis = self._axis if isinstance(self._axis, (list, tuple)) else [self._axis]
            for i in range(inputs.dim()):
                if i not in axis:
                    reduce_axis.append(i)
        amax = quant_utils.reduce_amax(inputs, axis=reduce_axis, keepdims=True).detach()
    if self._scale_amax is not None:
        amax = amax.detach() * self._scale_amax

    return amax
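# Illustrative sketch (not library code): the dynamic branch above is equivalent to
# reducing over every axis not listed in self._axis. It relies on the same
# quant_utils.reduce_amax helper used throughout this file; the shapes and the
# default axis here are arbitrary.
def _example_dynamic_amax(inputs, axis=(1,)):
    reduce_axis = [i for i in range(inputs.dim()) if i not in axis]
    # e.g. inputs of shape [16, 8, 32] with axis=(1,) yields amax of shape [1, 8, 1]
    return quant_utils.reduce_amax(inputs, axis=reduce_axis, keepdims=True).detach()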
def calibrate_weights(model, method="percentile", perchannel=True, percentile=99.99, num_bins=2048):
    """Calibrate weights of all child quantized modules

    Ideally, we would split calibration functionality into a histogram collector and a
    calibrator which takes the histogram and computes amax. But since we haven't decoupled
    collector and calibrator, it is easier to create a separate function to calibrate weights.

    .. note::
        This function uses the `method` specified by the argument to decide which method to use,
        NOT the one specified by the calibrator embedded in weight_quantizer.
        We haven't moved calibration to GPU, so everything is transferred to CPU.

    Args:
        model: A torch.nn.Module.
        method: A string of calibration method. Supports "mse" and "percentile". Default "percentile"
        perchannel: A bool. Set channel/neuron axis if True. Default True.
        percentile: A float. Default 99.99
        num_bins: An integer. Number of bins of histogram. Default 2048.
    """
    for name, module in model.named_modules():
        if hasattr(module, "weight") and hasattr(module, "weight_quantizer"):
            logging.info("Calibrate weight of %s", name)
            num_bits = module.weight_quantizer.num_bits
            unsigned = module.weight_quantizer.unsigned
            channel_second_modules = (quant_nn.QuantConvTranspose1d,
                                      quant_nn.QuantConvTranspose2d,
                                      quant_nn.QuantConvTranspose3d)
            if perchannel:
                axis = 1 if isinstance(module, channel_second_modules) else 0
            else:
                axis = None
            axis_size = module.weight.shape[axis] if axis is not None else 1

            # Histogram is always collected even if method is "max". "max" is supported here,
            # but it is not the primary usage of this function.
            if axis is None:
                calib_hist, calib_bin_edges = np.histogram(
                    module.weight.abs().cpu().detach().numpy(), bins=num_bins)
                calib_hist = [calib_hist]
                calib_bin_edges = [calib_bin_edges]
            else:
                calib_hist = []
                calib_bin_edges = []
                for i in range(axis_size):
                    hist, bin_edges = np.histogram(
                        module.weight.index_select(
                            axis, torch.tensor(i, device=module.weight.device)
                        ).abs().cpu().detach().numpy(),
                        bins=num_bins)
                    calib_hist.append(hist)
                    calib_bin_edges.append(bin_edges)

            calib_amax = []
            if method == "max":
                reduce_axis = list(range(module.weight.dim()))
                if axis is not None:
                    reduce_axis.remove(axis)
                calib_amax.append(quant_utils.reduce_amax(module.weight, axis=reduce_axis))
            elif method == 'mse':
                for i in range(axis_size):
                    calib_amax.append(
                        _compute_amax_mse(calib_hist[i], calib_bin_edges[i], num_bits, unsigned))
            elif method == 'percentile':
                for i in range(axis_size):
                    calib_amax.append(
                        _compute_amax_percentile(calib_hist[i], calib_bin_edges[i], percentile))
            else:
                raise TypeError("Unsupported calibration method {}".format(method))

            if axis is None:
                calib_amax = calib_amax[0]
            else:
                calib_amax_shape = [1] * module.weight.dim()
                calib_amax_shape[axis] = module.weight.shape[axis]
                calib_amax = torch.stack(calib_amax).reshape(calib_amax_shape)
            module.weight_quantizer.amax = calib_amax.detach().cpu().numpy()
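# Usage sketch (illustrative): calibrate_weights visits any child module exposing
# both `weight` and `weight_quantizer`, such as the quant_nn layers referenced above.
# The model below is a stand-in; quant_nn is assumed to be the module already
# imported by this file, and QuantConv2d takes the usual nn.Conv2d arguments.
def _example_calibrate_weights():
    model = torch.nn.Sequential(
        quant_nn.QuantConv2d(3, 16, kernel_size=3),
        torch.nn.ReLU(),
    )
    # Per-channel percentile calibration over each layer's weight histogram
    calibrate_weights(model, method="percentile", perchannel=True, percentile=99.99)
    return model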