Example #1
0
 def forward(self, x: QuantTensor):
     """Drop the least-significant bits of a quantized tensor.

     When quantization is enabled, truncates ``trunc_bit_width`` LSBs
     from ``x`` and returns a new QuantTensor with a correspondingly
     reduced bit width; otherwise ``x`` passes through unchanged.
     """
     if self.is_quant_enabled:
         # Snap the value back onto its own quantization grid first, so
         # the truncation below is not perturbed by accumulated fp error.
         cleaned_up_value = round_ste(
             x.value / x.scale.detach()) * x.scale.detach()
         x = x.set(value=cleaned_up_value
                   )  # clean up accumulated floating point errors
         trunc_bit_width = self.lsb_trunc_bit_width_impl(x.bit_width)
         # Removing k LSBs is equivalent to dividing the integer value
         # by 2**k, i.e. multiplying the scale by 2**k.
         trunc_scale = 2.0**trunc_bit_width
         output_scale = trunc_scale * x.scale
         if self.training:
             # STE-friendly path: re-quantize against the enlarged scale.
             x, output_scale, x_bit_width = self.tensor_quant(
                 x.value, output_scale, x.bit_width)
         else:  # avoid fp errors at inference time
             x_bit_width = x.bit_width
             x = round_ste(x.value / x.scale)
             x = x / trunc_scale
             x = self.tensor_quant.int_quant.float_to_int_impl(x)
             x = x * output_scale
         # Undo the truncation factor so the returned scale matches the
         # input scale while the value itself carries the truncation.
         x = x / trunc_scale
         output_scale = output_scale / trunc_scale  # output_scale == input_scale
         output_bit_width = x_bit_width - trunc_bit_width
         return QuantTensor(x, output_scale, output_bit_width,
                            self.is_signed)
     else:
         return x
Example #2
0
 def int_weight(self, x: torch.Tensor):
     """Return the integer representation of the quantized weight tensor."""
     sentinel = getattr(self, ZERO_HW_SENTINEL_NAME)
     dequant, scale, _ = self.tensor_quant(x, sentinel)
     # Divide out the scale and round (straight-through) to recover the
     # underlying integers, then cast to an int dtype.
     return round_ste(dequant / scale).int()
Example #3
0
 def int(self, float_datatype=False):
     """Return the integer representation of this QuantTensor.

     Args:
         float_datatype: if True, keep the rounded integer values in the
             tensor's floating-point dtype instead of casting to int.

     Raises:
         RuntimeError: if the QuantTensor is not valid.
     """
     if self.is_valid:
         int_value = round_ste(self._pre_round_int_value)
         if float_datatype:
             return int_value
         return int_value.int()
     # fix: the original raised with an f-string that had no placeholders
     # (ruff F541); a plain literal is equivalent and intentional.
     raise RuntimeError("QuantTensor not valid.")
Example #4
0
 def int(self, float_datatype=False):
     """Return the integer representation of this QuantTensor.

     With ``float_datatype=True`` the integers are kept in floating-point
     dtype; otherwise they are cast to an int dtype. Raises RuntimeError
     when the tensor is not well formed.
     """
     if not self.is_valid:
         raise RuntimeError(f"QuantTensor not well formed, all fields must be set: {self}")
     # Divide out the scale, then round straight-through to integers.
     ints = round_ste(self.value / self.scale)
     return ints if float_datatype else ints.int()
Example #5
0
 def forward(self, x, input_scale, input_bit_width):
     """Truncate LSBs of x; returns (value, scale, bit_width)."""
     # Snap onto the input quantization grid first so the truncation is
     # not perturbed by accumulated floating-point error.
     y = round_ste(x / input_scale) * input_scale
     trunc_bit_width = self.lsb_trunc_bit_width_impl(input_bit_width, self.zero_hw_sentinel)
     trunc_scale = 2.0 ** trunc_bit_width
     # Re-quantize against the enlarged scale (input scale * 2**k).
     y, output_scale, input_bit_width = self.tensor_quant(
         y, trunc_scale * input_scale, input_bit_width, self.zero_hw_sentinel)
     if self.explicit_rescaling:
         # Explicit rescaling: keep the truncation factor with the value
         # rather than with the returned scale.
         y = y / trunc_scale
         output_scale = output_scale / trunc_scale
     return y, output_scale, input_bit_width - trunc_bit_width
Example #6
0
 def forward(self, x, input_scale, input_bit_width):
     """Truncate LSBs of x; returns (value, scale, bit_width).

     Uses a quantizer pass during training and a direct integer
     computation at inference time to avoid floating-point error.
     """
     # Snap onto the input grid before truncating.
     x = round_ste(
         x / input_scale) * input_scale  # clean up fp errors before floor
     trunc_bit_width = self.lsb_trunc_bit_width_impl(
         input_bit_width, self.zero_hw_sentinel)
     # Dropping k LSBs == dividing the integer value by 2**k.
     trunc_scale = 2.0**trunc_bit_width
     output_scale = trunc_scale * input_scale
     if self.training:
         x, output_scale, input_bit_width = self.tensor_quant(
             x, output_scale, input_bit_width, self.zero_hw_sentinel)
     else:  # avoid fp errors at inference time
         # NOTE(review): unlike the training branch, this path leaves
         # input_bit_width (and output_scale) as computed above rather
         # than taking them from tensor_quant — confirm that is intended.
         x = round_ste(x / input_scale)
         x = x / trunc_scale
         x = self.tensor_quant.int_quant.float_to_int_impl(x)
         x = x * output_scale
     if self.explicit_rescaling:
         x = x / trunc_scale  # rescaling is explicit, so the truncation scale stays with x rather with output_scale
         output_scale = output_scale / trunc_scale
     output_bit_width = input_bit_width - trunc_bit_width
     return x, output_scale, output_bit_width
Example #7
0
 def forward(self, x: Tensor, scale: Tensor,
             input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Truncate x down to the clamped output bit width.

     The scale is returned unchanged; only the value loses its
     least-significant bits.
     """
     output_bit_width = self.msb_clamp_bit_width_impl()
     # Dropping (input - output) bits == dividing integers by 2**diff.
     trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
     # Round away accumulated fp error before shifting down.
     ints = round_ste(x / scale)
     truncated = self.float_to_int_impl(ints / trunc_scale) * scale
     y = self.delay_wrapper(x, truncated)
     return y, scale, output_bit_width
Example #8
0
 def forward(self, x: torch.Tensor):
     """Round x elementwise with a straight-through estimator."""
     rounded = round_ste(x)
     return rounded
Example #9
0
 def __int__(self):
     # NOTE(review): Python's __int__ protocol expects a plain int, but
     # this returns the tensor produced by round_ste(tensor / scale);
     # int(obj) on a multi-element tensor would fail — confirm whether
     # callers rely on receiving the tensor directly.
     return round_ste(self.tensor / self.scale)