def __add__(self, other):
    """Elementwise addition, propagating quantization metadata when available.

    When both operands are QuantTensors with populated metadata
    (``is_not_none``), the scales and zero points are required to match
    (checked via ``check_scaling_factors_same`` / ``check_zero_points_same``),
    and the output carries a widened bit width covering the sum's range.
    Otherwise the result is a plain QuantTensor wrapping the raw sum.
    """
    if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
        # Addition is only well-defined when both operands share the same
        # quantization grid; these raise otherwise.
        self.check_scaling_factors_same(other)
        self.check_zero_points_same(other)
        output_value = self.value + other.value
        # Scales/zero points were just checked equal, so averaging them is a
        # no-op numerically; presumably done to merge any training state —
        # TODO confirm.
        output_scale = (self.scale + other.scale) / 2
        output_zero_point = (self.zero_point + other.zero_point) / 2
        # Worst-case integer range of the sum: add the per-operand extrema.
        max_val = max_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width)
        max_val += max_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width)
        min_val = min_int(signed=self.signed, narrow_range=False, bit_width=self.bit_width)
        min_val += min_int(signed=other.signed, narrow_range=False, bit_width=other.bit_width)
        # Bits needed to span [min_val, max_val]; ceil_ste keeps the op
        # straight-through-estimator friendly for gradients.
        output_bit_width = ceil_ste(torch.log2(max_val - min_val))
        output_signed = self.signed or other.signed
        output_training = self.training or other.training
        output = QuantTensor(value=output_value, scale=output_scale, zero_point=output_zero_point, bit_width=output_bit_width, signed=output_signed, training=output_training)
    elif isinstance(other, QuantTensor):
        # Metadata missing on at least one side: drop it and wrap the raw sum.
        output = QuantTensor(self.value + other.value)
    else:
        # `other` is a plain tensor/scalar; rely on torch broadcasting.
        output = QuantTensor(self.value + other)
    return output
def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale, float_to_int_impl, scale_multiplier):
    """Check IntQuant against a manually computed clamp-and-rescale reference.

    The rounding op is mocked with the real ``float_to_int_impl`` as its
    side effect so call behavior can be observed while results stay real.
    Asserts that (a) the integer output only takes admissible values for the
    given signedness/narrow-range/bit-width, and (b) the dequantized output
    matches the reference within ``scale * scale_multiplier``.
    """
    float_to_int_impl_mock = Mock()
    tensor_clamp_impl = TensorClamp()
    value = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)
    # Absolute tolerance proportional to the quantization step.
    tol = scale * scale_multiplier
    float_to_int_impl_mock.side_effect = float_to_int_impl()
    obj = IntQuant(
        narrow_range=narrow_range,
        signed=signed,
        float_to_int_impl=float_to_int_impl_mock,
        tensor_clamp_impl=tensor_clamp_impl)
    output = obj(scale, int_scale, bit_width, value)
    min_value = int(min_int(signed, narrow_range, bit_width))
    max_value = int(max_int(signed, bit_width))
    # Fixed: was `[x for x in range(...)]`, which shadowed the `x` parameter
    # and is the non-idiomatic copy form of list(range(...)).
    admissible_values = list(range(min_value, max_value + 1))
    # Reference: scale to the integer grid, clamp, then dequantize.
    value = (value / scale) * int_scale
    expected_output = tensor_clamp(
        value,
        min_val=min_int(signed, narrow_range, bit_width),
        max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale
    # NOTE(review): `value` was rescaled above, so to_int receives the
    # already-rescaled tensor rather than the original input — confirm this
    # is intended.
    int_output = obj.to_int(scale, int_scale, bit_width, value)
    # The assert is performed internally by check_admissible_values.
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, RTOL, tol)
def max_output_bit_width(self, input_bit_width):
    """Bit width needed to accumulate `in_channels` maximal unsigned inputs.

    The worst case sums the largest unsigned value representable at
    `input_bit_width` once per input channel; ceil_ste keeps the log2
    gradient-friendly.
    """
    peak_input = max_int(bit_width=input_bit_width, narrow_range=False, signed=False)
    peak_accumulator = peak_input * self.in_channels
    return ceil_ste(torch.log2(peak_accumulator))
def max_acc_bit_width(self, input_bit_width, weight_bit_width):
    """Worst-case accumulator bit width for this (1-D) convolution.

    Multiplies the largest unsigned input by the largest unsigned weight
    magnitude and by the number of products accumulated per output element.
    NOTE(review): the per-group factor uses out_channels // groups — verify
    this matches the layer's accumulation depth (it differs from the usual
    in_channels // groups of a standard conv).
    """
    peak_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
    peak_weight = self.weight_quant.max_uint_value(weight_bit_width)
    channels_per_group = self.out_channels // self.groups
    peak_accumulator = peak_input * peak_weight * self.kernel_size[0] * channels_per_group
    return ceil_ste(torch.log2(peak_accumulator))
def max_acc_bit_width(self, input_bit_width):
    """Accumulator bit width for averaging `_avg_scaling` maximal inputs.

    Worst case is the largest unsigned input value repeated across the
    averaging window before division.
    """
    peak_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
    peak_accumulator = peak_input * self._avg_scaling
    return ceil_ste(torch.log2(peak_accumulator))
def max_acc_bit_width(self, input_bit_width, reduce_size):
    """Accumulator bit width for reducing `reduce_size` maximal inputs.

    Worst case: the largest unsigned value at `input_bit_width` summed
    `reduce_size` times.
    """
    peak_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
    peak_accumulator = peak_input * reduce_size
    return ceil_ste(torch.log2(peak_accumulator))
def max_acc_bit_width(self, input_bit_width, weight_bit_width):
    """Bit width covering a single input-times-weight product.

    No summation here: the worst case is just the product of the largest
    unsigned input and the largest unsigned weight magnitude.
    """
    peak_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
    peak_weight = self.weight_quant.max_uint_value(weight_bit_width)
    return ceil_ste(torch.log2(peak_input * peak_weight))
def max_acc_bit_width(self, input_bit_width, weight_bit_width):
    """Worst-case accumulator bit width for this 2-D (stride-aware) layer.

    Per spatial dim, roughly kernel_size / stride kernel applications
    overlap on one output position (clamped to at least 1); presumably this
    models a transposed convolution — confirm against the enclosing class.
    NOTE(review): per-group factor uses out_channels // groups; verify it
    matches the layer's true accumulation depth.
    """
    peak_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
    peak_weight = self.weight_quant.max_uint_value(weight_bit_width)
    channels_per_group = self.out_channels // self.groups
    overlaps = max(round(self.kernel_size[0] / self.stride[0]), 1)
    overlaps *= max(round(self.kernel_size[1] / self.stride[1]), 1)
    peak_accumulator = peak_input * peak_weight * overlaps * channels_per_group
    return ceil_ste(torch.log2(peak_accumulator))
def forward(self, x: Tensor):
    """Search for the clipping threshold that minimizes KL divergence.

    Builds a symmetric histogram of ``x`` over [-absmax, absmax], then for
    each candidate clipping window computes the KL divergence between the
    clipped reference distribution ``p`` and its quantized/merged
    approximation ``q`` (histogram-calibration scheme in the style of
    MXNet/TensorRT entropy calibration). Returns the threshold with the
    smallest divergence.
    """
    absmax = self.absmax_impl(x)
    bit_width = self.bit_width_impl()
    # Number of integer levels available at the target bit width.
    num_quantized_bins = max_int(self.signed, False, bit_width).int()
    # One candidate threshold per window half-width i (see loop below).
    thresholds = torch.zeros(self.num_bins // 2 + 1 - num_quantized_bins // 2, device=x.device)
    divergence = torch.zeros_like(thresholds)
    quantized_bins = torch.zeros(num_quantized_bins, device=x.device)
    hist = torch.histc(x, bins=self.num_bins, min=-absmax, max=absmax).int()
    hist_edges = torch.linspace(-absmax, absmax, self.num_bins + 1)
    # Grow a symmetric window around the histogram center, from the smallest
    # window that still fits num_quantized_bins up to the full histogram.
    for i in range(num_quantized_bins // 2, self.num_bins // 2 + 1):
        p_bin_idx_start = self.num_bins // 2 - i
        p_bin_idx_stop = self.num_bins // 2 + i + 1
        thresholds[i - num_quantized_bins // 2] = hist_edges[p_bin_idx_stop]
        sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]
        # Reference distribution p: the window, with all out-of-window mass
        # folded into its edge bins (saturation behavior of clipping).
        p = sliced_nd_hist.clone()
        left_outlier_count = torch.sum(hist[0:p_bin_idx_start])
        p[0] += left_outlier_count
        right_outlier_count = torch.sum(hist[p_bin_idx_stop:])
        p[-1] += right_outlier_count
        is_nonzeros = (sliced_nd_hist != 0).float()
        # Merge the window's bins down to num_quantized_bins levels.
        num_merged_bins = torch.numel(p) // num_quantized_bins
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        # Leftover bins (window size not divisible) go to the last level.
        quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
        # Expand q back to the window's resolution, spreading each merged
        # level's mass uniformly over its originally non-empty bins.
        q = torch.zeros_like(p, dtype=torch.float32, device=x.device)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            if j == num_quantized_bins - 1:
                stop = -1
            else:
                stop = start + num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = quantized_bins[j] / norm
        # Bins empty in the reference stay empty in the approximation.
        q[sliced_nd_hist == 0] = 0.
        p = self.smooth_normalize_distribution(p, self.smoothing_eps)
        q = self.smooth_normalize_distribution(q, self.smoothing_eps)
        # smooth_normalize_distribution apparently returns None for a
        # degenerate q; score it as infinitely bad — TODO confirm contract.
        if q is None:
            divergence[i - num_quantized_bins // 2] = float('inf')
        else:
            divergence[i - num_quantized_bins // 2] = torch.distributions.kl.kl_divergence(p, q)
    min_divergence_idx = torch.argmin(divergence)
    opt_threshold = thresholds[min_divergence_idx]
    return opt_threshold
def __add__(self, other):
    """Elementwise addition of two QuantTensors.

    If either operand lacks valid quantization metadata, the raw sum is
    wrapped without metadata. Otherwise scales and zero points must match,
    and the output bit width is widened to cover the summed integer range.
    """
    QuantTensor.check_input_type(other)
    if not (self.is_valid and other.is_valid):
        # Metadata unavailable on one side: plain wrapped sum.
        return QuantTensor(self.value + other.value)
    # Both operands must live on the same quantization grid.
    self.check_scaling_factors_same(other)
    self.check_zero_points_same(other)
    # Worst-case unsigned range of the sum determines the output bit width.
    summed_max_uint = max_int(signed=False, narrow_range=False, bit_width=self.bit_width)
    summed_max_uint += max_int(signed=False, narrow_range=False, bit_width=other.bit_width)
    return QuantTensor(
        value=self.value + other.value,
        scale=(self.scale + other.scale) / 2,
        zero_point=(self.zero_point + other.zero_point) / 2,
        bit_width=ceil_ste(torch.log2(summed_max_uint)),
        signed=self.signed or other.signed,
        training=self.training or other.training)
def max_int(self, bit_width):
    """Largest representable integer at `bit_width` for this signedness.

    Delegates to the module-level max_int helper (two-argument form; this
    variant does not thread narrow_range through).
    """
    is_signed = self.signed
    return max_int(is_signed, bit_width)
def max_int(self, bit_width):
    """Largest representable integer at `bit_width` for this quantizer.

    Delegates to the module-level max_int helper, forwarding both the
    signedness and narrow-range flags.
    """
    is_signed = self.signed
    is_narrow = self.narrow_range
    return max_int(is_signed, is_narrow, bit_width)
def forward(self, bit_width):
    """Return the largest representable integer plus one.

    For the full (non-narrow) range this is the power-of-two level count
    boundary above max_int.
    """
    largest = max_int(self.signed, bit_width)
    return largest + 1
def forward(self, bit_width):
    """Return the magnitude bound of the integer range.

    Unsigned: the largest representable value. Signed: the magnitude of the
    (possibly narrow-range) minimum, which dominates the positive maximum.
    """
    if not self.signed:
        return max_int(self.signed, bit_width)
    return -min_int(self.signed, self.narrow_range, bit_width)
def forward(self, bit_width: Tensor) -> Tensor:
    """Return the largest representable integer plus one.

    Uses the full (narrow_range=False) range for the configured signedness.
    """
    top = max_int(self.signed, False, bit_width)
    return top + 1