示例#1
0
 def __add__(self, other):
     """Add ``other`` to this quantized tensor.

     ``other`` is first validated with ``QuantTensor.check_input_type`` and
     ``self.check_scaling_factors_same``. The result is built with
     ``pack_quant_tensor`` from the summed tensors, the mean of the two
     scales, and a bit width equal to ``ceil_ste(log2(.))`` of the sum of
     both operands' ``max_uint`` values.
     """
     QuantTensor.check_input_type(other)
     self.check_scaling_factors_same(other)
     summed = self.tensor + other.tensor
     mean_scale = (self.scale + other.scale) / 2
     # Worst case for the sum: both operands sit at their unsigned maximum.
     peak = max_uint(narrow_range=False, bit_width=self.bit_width)
     peak += max_uint(narrow_range=False, bit_width=other.bit_width)
     sum_bit_width = ceil_ste(torch.log2(peak))
     return pack_quant_tensor(summed, mean_scale, sum_bit_width)
示例#2
0
 def __add__(self, other):
     """Add ``other`` to this ``QuantTensor``.

     ``other`` is validated with ``QuantTensor.check_input_type``. When both
     operands report ``is_valid``, the result carries the summed values, the
     mean of the two scales, a bit width derived from the sum of both
     operands' ``max_uint`` values, and is signed if either operand is.
     Otherwise only the values are summed and wrapped in a ``QuantTensor``
     constructed from that value alone.
     """
     QuantTensor.check_input_type(other)
     if not (self.is_valid and other.is_valid):
         # No usable quantization metadata on at least one side.
         return QuantTensor(self.value + other.value)

     self.check_scaling_factors_same(other)
     summed = self.value + other.value
     mean_scale = (self.scale + other.scale) / 2
     # Worst case for the sum: both operands sit at their unsigned maximum.
     peak = max_uint(narrow_range=False, bit_width=self.bit_width)
     peak += max_uint(narrow_range=False, bit_width=other.bit_width)
     sum_bit_width = ceil_ste(torch.log2(peak))
     is_signed = self.signed or other.signed
     return QuantTensor(summed, mean_scale, sum_bit_width, is_signed)
示例#3
0
 def max_output_bit_width(self, input_bit_width, weight_bit_width):
     """Return the bit width bound for the layer output.

     The bound is ``ceil_ste(log2(.))`` of the product of three factors: the
     ``max_uint`` value for ``input_bit_width`` (``narrow_range=False``), the
     weight quantizer's ``max_uint`` for ``weight_bit_width``, and
     ``self.in_features``.
     """
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     int_quant = self.weight_quant.tensor_quant.int_quant
     weight_peak = int_quant.max_uint(weight_bit_width)
     output_peak = input_peak * weight_peak * self.in_features
     return ceil_ste(torch.log2(output_peak))
示例#4
0
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     """Return the bit width bound for the convolution accumulator.

     The bound is ``ceil_ste(log2(.))`` of the largest unsigned value a single
     output element can accumulate: the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``), times the weight
     quantizer's ``max_uint_value`` for ``weight_bit_width``, times the number
     of products summed per output element.

     Args:
         input_bit_width: bit width of the input, forwarded to ``max_uint``.
         weight_bit_width: bit width of the weights, forwarded to
             ``self.weight_quant.max_uint_value``.

     Returns:
         The result of ``ceil_ste(torch.log2(...))`` on that worst-case value.
     """
     max_uint_input = max_uint(bit_width=input_bit_width, narrow_range=False)
     max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
     # Each output element reduces over the *input* channels of its group
     # (weight shape is (out_channels, in_channels / groups, kernel)), so the
     # reduction size must come from in_channels, not out_channels.
     group_size = self.in_channels // self.groups
     max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
示例#5
0
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     """Return the bit width bound for the strided-overlap accumulator.

     The bound is ``ceil_ste(log2(.))`` of the largest unsigned value a single
     output element can accumulate: the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``), times the weight
     quantizer's ``max_uint_value`` for ``weight_bit_width``, times the number
     of kernel taps that can overlap on one output position, times the number
     of input channels per group.

     Args:
         input_bit_width: bit width of the input, forwarded to ``max_uint``.
         weight_bit_width: bit width of the weights, forwarded to
             ``self.weight_quant.max_uint_value``.

     Returns:
         The result of ``ceil_ste(torch.log2(...))`` on that worst-case value.
     """
     max_uint_input = max_uint(bit_width=input_bit_width, narrow_range=False)
     max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
     # Each output element reduces over the *input* channels of its group,
     # so the reduction size must come from in_channels, not out_channels.
     group_size = self.in_channels // self.groups
     kernel, stride = self.kernel_size[0], self.stride[0]
     # Up to ceil(kernel / stride) taps can land on the same output position.
     # round() underestimated this (e.g. 4/3 -> 1, 5/2 -> 2 with banker's
     # rounding), which made the accumulator bound too small.
     # NOTE(review): dilation is not read here, as in the original.
     overlapping_sums = max(-(-kernel // stride), 1)
     max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
     max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
     return max_output_bit_width
示例#6
0
 def max_uint(self, bit_width):
     """Return ``max_uint`` for ``bit_width`` using this object's ``narrow_range``."""
     narrow = self.narrow_range
     return max_uint(narrow, bit_width)
示例#7
0
 def max_output_bit_width(self, input_bit_width):
     """Return the bit width bound for the output.

     Computed as ``ceil_ste(log2(.))`` of the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``) multiplied by
     ``self.in_channels``.
     """
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     output_peak = input_peak * self.in_channels
     return ceil_ste(torch.log2(output_peak))
示例#8
0
 def max_output_bit_width(self, input_bit_width):
     """Return the bit width bound for the output.

     Computed as ``ceil_ste(log2(.))`` of the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``) multiplied twice by
     ``self.kernel_size``.
     """
     side = self.kernel_size
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     # Same left-to-right product as before: input_peak * side * side.
     output_peak = input_peak * side * side
     return ceil_ste(torch.log2(output_peak))
示例#9
0
 def max_acc_bit_width(self, input_bit_width, reduce_size):
     """Return the bit width bound for an accumulation of ``reduce_size`` terms.

     Computed as ``ceil_ste(log2(.))`` of the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``) multiplied by
     ``reduce_size``.
     """
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     acc_peak = input_peak * reduce_size
     return ceil_ste(torch.log2(acc_peak))
示例#10
0
 def max_acc_bit_width(self, input_bit_width):
     """Return the bit width bound for the accumulator.

     Computed as ``ceil_ste(log2(.))`` of the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``) multiplied by
     ``self._avg_scaling``.
     """
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     acc_peak = input_peak * self._avg_scaling
     return ceil_ste(torch.log2(acc_peak))
示例#11
0
 def max_acc_bit_width(self, input_bit_width, weight_bit_width):
     """Return the bit width bound for one input-by-weight product.

     Computed as ``ceil_ste(log2(.))`` of the ``max_uint`` value for
     ``input_bit_width`` (``narrow_range=False``) multiplied by the weight
     quantizer's ``max_uint_value`` for ``weight_bit_width``.
     """
     input_peak = max_uint(bit_width=input_bit_width, narrow_range=False)
     weight_peak = self.weight_quant.max_uint_value(weight_bit_width)
     product_peak = input_peak * weight_peak
     return ceil_ste(torch.log2(product_peak))