Example 1
 def __add__(self, other):
     if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
         self.check_scaling_factors_same(other)
         self.check_zero_points_same(other)
         output_value = self.value + other.value
         output_scale = (self.scale + other.scale) / 2
         output_zero_point = (self.zero_point + other.zero_point) / 2
         max_val = max_int(signed=self.signed,
                           narrow_range=False,
                           bit_width=self.bit_width)
         max_val += max_int(signed=other.signed,
                            narrow_range=False,
                            bit_width=other.bit_width)
         min_val = min_int(signed=self.signed,
                           narrow_range=False,
                           bit_width=self.bit_width)
         min_val += min_int(signed=other.signed,
                            narrow_range=False,
                            bit_width=other.bit_width)
         output_bit_width = ceil_ste(torch.log2(max_val - min_val))
         output_signed = self.signed or other.signed
         output_training = self.training or other.training
         output = QuantTensor(value=output_value,
                              scale=output_scale,
                              zero_point=output_zero_point,
                              bit_width=output_bit_width,
                              signed=output_signed,
                              training=output_training)
     elif isinstance(other, QuantTensor):
         output = QuantTensor(self.value + other.value)
     else:
         output = QuantTensor(self.value + other)
     return output
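
Note: the bit-width arithmetic in Example 1 can be checked by hand. A minimal sketch, assuming the conventional two's-complement ranges for max_int/min_int and two signed 8-bit operands (the helper name is illustrative, not from the source):

import math

def added_bit_width(bit_width_a=8, bit_width_b=8):
    # Two's-complement extremes for signed, non-narrow-range operands
    # (assumed semantics of the max_int/min_int calls above).
    max_val = (2 ** (bit_width_a - 1) - 1) + (2 ** (bit_width_b - 1) - 1)  # 127 + 127 = 254
    min_val = -(2 ** (bit_width_a - 1)) - (2 ** (bit_width_b - 1))         # -128 - 128 = -256
    return math.ceil(math.log2(max_val - min_val))                        # ceil(log2(510)) = 9

print(added_bit_width())  # adding two 8-bit signed tensors needs 9 bits
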
Example 2
def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale,
                  float_to_int_impl, scale_multiplier):
    float_to_int_impl_mock = Mock()
    tensor_clamp_impl = TensorClamp()

    value = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)

    tol = scale * scale_multiplier
    float_to_int_impl_mock.side_effect = float_to_int_impl()

    obj = IntQuant(narrow_range=narrow_range,
                   signed=signed,
                   float_to_int_impl=float_to_int_impl_mock,
                   tensor_clamp_impl=tensor_clamp_impl)
    output = obj(scale, int_scale, bit_width, value)

    min_value = int(min_int(signed, narrow_range, bit_width))
    max_value = int(max_int(signed, bit_width))
    admissible_values = list(range(min_value, max_value + 1))

    value = (value / scale) * int_scale
    expected_output = tensor_clamp(value,
                                   min_val=min_int(signed, narrow_range, bit_width),
                                   max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale

    int_output = obj.to_int(scale, int_scale, bit_width, value)

    # The assert is performed internally by check_admissible_values
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, RTOL, tol)
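
Note: check_admissible_values is a test helper whose body is not shown here. A plausible minimal reconstruction, assuming it only asserts that every integer returned by to_int falls in the admissible range computed above:

def check_admissible_values(values, admissible_values):
    # Hypothetical reconstruction, not the actual helper.
    admissible = set(admissible_values)
    for v in values.flatten().tolist():
        assert int(v) in admissible, f"{v} is outside the admissible range"
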
Example 3
 def max_output_bit_width(self, input_bit_width):
     max_input_val = max_int(bit_width=input_bit_width,
                             narrow_range=False,
                             signed=False)
     max_output_val = max_input_val * self.in_channels
     output_bit_width = ceil_ste(torch.log2(max_output_val))
     return output_bit_width
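
Note: Examples 3 through 8 all bound the accumulator the same way: the worst-case output is the largest unsigned input value times the number of terms summed per output (times the largest weight magnitude where weights are involved), and the required width is the ceiling of its log2. A worked instance with assumed numbers, 8-bit unsigned inputs summed over 256 input channels:

import math

max_input_val = 2 ** 8 - 1                   # max_int(signed=False, narrow_range=False, 8) = 255
max_output_val = max_input_val * 256         # in_channels = 256 (assumed)
print(math.ceil(math.log2(max_output_val)))  # ceil(log2(65280)) = 16 bits
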
Example 4
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
     max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
     group_size = self.out_channels // self.groups
     max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
Example 5
 def max_acc_bit_width(self, input_bit_width):
     max_uint_input = max_int(bit_width=input_bit_width,
                              signed=False,
                              narrow_range=False)
     max_uint_output = max_uint_input * self._avg_scaling
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
Example 6
 def max_acc_bit_width(self, input_bit_width, reduce_size):
     max_uint_input = max_int(bit_width=input_bit_width,
                              signed=False,
                              narrow_range=False)
     max_uint_output = max_uint_input * reduce_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
Example 7
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     max_input_val = max_int(bit_width=input_bit_width,
                             signed=False,
                             narrow_range=False)
     max_weight_val = self.weight_quant.max_uint_value(weight_bit_width)
     max_output_val = max_input_val * max_weight_val
     output_bit_width = ceil_ste(torch.log2(max_output_val))
     return output_bit_width
Example 8
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
     max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
     group_size = self.out_channels // self.groups
     overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
     overlapping_sums *= max(round(self.kernel_size[1] / self.stride[1]), 1)
     max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
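
Note: Example 8 tightens the bound for what appears to be a transposed convolution: when the stride is smaller than the kernel, each output element accumulates at most round(kernel/stride) overlapping kernel applications per spatial dimension, so that factor replaces the full kernel area. With an assumed 4x4 kernel and stride 2:

overlapping_sums = max(round(4 / 2), 1) * max(round(4 / 2), 1)  # 2 * 2 = 4
# The bound then counts 4 summed kernel applications per output element
# instead of the full 16-tap kernel area.
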
Example 9
 def forward(self, x: Tensor):
     absmax = self.absmax_impl(x)
     bit_width = self.bit_width_impl()
     num_quantized_bins = max_int(self.signed, False, bit_width).int()
      thresholds = torch.zeros(self.num_bins // 2 + 1 - num_quantized_bins // 2, device=x.device)
     divergence = torch.zeros_like(thresholds)
     quantized_bins = torch.zeros(num_quantized_bins, device=x.device)
     hist = torch.histc(x, bins=self.num_bins, min=-absmax,
                        max=absmax).int()
     hist_edges = torch.linspace(-absmax, absmax, self.num_bins + 1)
     for i in range(num_quantized_bins // 2, self.num_bins // 2 + 1):
         p_bin_idx_start = self.num_bins // 2 - i
         p_bin_idx_stop = self.num_bins // 2 + i + 1
          thresholds[i - num_quantized_bins // 2] = hist_edges[p_bin_idx_stop]
         sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]
         p = sliced_nd_hist.clone()
         left_outlier_count = torch.sum(hist[0:p_bin_idx_start])
         p[0] += left_outlier_count
         right_outlier_count = torch.sum(hist[p_bin_idx_stop:])
         p[-1] += right_outlier_count
         is_nonzeros = (sliced_nd_hist != 0).float()
         num_merged_bins = torch.numel(p) // num_quantized_bins
         for j in range(num_quantized_bins):
             start = j * num_merged_bins
             stop = start + num_merged_bins
             quantized_bins[j] = sliced_nd_hist[start:stop].sum()
          quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
         q = torch.zeros_like(p, dtype=torch.float32, device=x.device)
         for j in range(num_quantized_bins):
             start = j * num_merged_bins
             if j == num_quantized_bins - 1:
                 stop = -1
             else:
                 stop = start + num_merged_bins
             norm = is_nonzeros[start:stop].sum()
             if norm != 0:
                 q[start:stop] = quantized_bins[j] / norm
         q[sliced_nd_hist == 0] = 0.
         p = self.smooth_normalize_distribution(p, self.smoothing_eps)
         q = self.smooth_normalize_distribution(q, self.smoothing_eps)
         if q is None:
             divergence[i - num_quantized_bins // 2] = float('inf')
         else:
              divergence[i - num_quantized_bins // 2] = torch.distributions.kl.kl_divergence(p, q)
     min_divergence_idx = torch.argmin(divergence)
     opt_threshold = thresholds[min_divergence_idx]
     return opt_threshold
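
Note: smooth_normalize_distribution is not shown in Example 9. Since torch.distributions.kl.kl_divergence operates on Distribution objects, a plausible reconstruction smooths the histogram with a small epsilon, renormalizes it, and wraps it in a Categorical; returning None for an all-empty histogram would explain the `q is None` branch:

import torch

def smooth_normalize_distribution(counts, eps):
    # Hypothetical reconstruction; assumes eps is small relative to the
    # probability mass of the occupied bins.
    counts = counts.float()
    is_zeros = (counts == 0).float()
    n_nonzeros = counts.numel() - is_zeros.sum()
    if n_nonzeros == 0:
        return None  # all-empty histogram
    probs = counts / counts.sum()
    # Move eps mass onto empty bins, taken proportionally from occupied ones.
    eps1 = eps * is_zeros.sum() / n_nonzeros
    probs = probs + eps * is_zeros - eps1 * (1.0 - is_zeros)
    return torch.distributions.Categorical(probs=probs)
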
Example 10
 def __add__(self, other):
     QuantTensor.check_input_type(other)
     if self.is_valid and other.is_valid:
         self.check_scaling_factors_same(other)
         self.check_zero_points_same(other)
         output_value = self.value + other.value
         output_scale = (self.scale + other.scale) / 2
         output_zero_point = (self.zero_point + other.zero_point) / 2
         max_uint_val = max_int(signed=False, narrow_range=False, bit_width=self.bit_width)
         max_uint_val += max_int(signed=False, narrow_range=False, bit_width=other.bit_width)
         output_bit_width = ceil_ste(torch.log2(max_uint_val))
         output_signed = self.signed or other.signed
         output_training = self.training or other.training
         output = QuantTensor(
             value=output_value,
             scale=output_scale,
             zero_point=output_zero_point,
             bit_width=output_bit_width,
             signed=output_signed,
             training=output_training)
     else:
         output_value = self.value + other.value
         output = QuantTensor(output_value)
     return output
Example 11
 def max_int(self, bit_width):
     return max_int(self.signed, bit_width)
Example 12
 def max_int(self, bit_width):
     return max_int(self.signed, self.narrow_range, bit_width)
Example 13
 def forward(self, bit_width):
     return max_int(self.signed, bit_width) + 1
Example 14
 def forward(self, bit_width):
     if self.signed:
         return -min_int(self.signed, self.narrow_range, bit_width)
     else:
         return max_int(self.signed, bit_width)
Example 15
 def forward(self, bit_width: Tensor) -> Tensor:
     return max_int(self.signed, False, bit_width) + 1
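
Note: every call site above is consistent with the usual two's-complement integer ranges (Examples 2, 11, 13 and 14 use an older two-argument signature without narrow_range). A minimal sketch of what max_int and min_int plausibly compute; the real implementations return tensors:

def max_int(signed, narrow_range, bit_width):
    # Largest representable integer: 2**(n-1) - 1 if signed, else 2**n - 1
    # (2**n - 2 for an unsigned narrow range).
    if not signed and not narrow_range:
        return (2 ** bit_width) - 1
    elif not signed and narrow_range:
        return (2 ** bit_width) - 2
    else:
        return (2 ** (bit_width - 1)) - 1

def min_int(signed, narrow_range, bit_width):
    # Smallest representable integer: -2**(n-1) if signed, excluding the most
    # negative value when narrow_range is set; 0 if unsigned.
    if signed and narrow_range:
        return -(2 ** (bit_width - 1)) + 1
    elif signed and not narrow_range:
        return -(2 ** (bit_width - 1))
    else:
        return 0
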