def forward(self, x: QuantTensor):
    if self.is_quant_enabled:
        cleaned_up_value = round_ste(x.value / x.scale.detach()) * x.scale.detach()
        x = x.set(value=cleaned_up_value)  # clean up accumulated floating point errors
        trunc_bit_width = self.lsb_trunc_bit_width_impl(x.bit_width)
        trunc_scale = 2.0 ** trunc_bit_width
        output_scale = trunc_scale * x.scale
        if self.training:
            x, output_scale, x_bit_width = self.tensor_quant(x.value, output_scale, x.bit_width)
        else:  # avoid fp errors at inference time
            x_bit_width = x.bit_width
            x = round_ste(x.value / x.scale)
            x = x / trunc_scale
            x = self.tensor_quant.int_quant.float_to_int_impl(x)
            x = x * output_scale
        x = x / trunc_scale
        output_scale = output_scale / trunc_scale  # output_scale == input_scale
        output_bit_width = x_bit_width - trunc_bit_width
        return QuantTensor(x, output_scale, output_bit_width, self.is_signed)
    else:
        return x
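# A minimal numeric sketch (not library code) of the LSB-truncation arithmetic used above:
# dropping trunc_bit_width least-significant bits corresponds to dividing the integer
# representation by 2 ** trunc_bit_width and reducing the bit width by the same amount.
# All values below are illustrative.
input_bit_width = 8
input_scale = 0.05
trunc_bit_width = 2

trunc_scale = 2.0 ** trunc_bit_width                  # 4.0
quant_scale = trunc_scale * input_scale               # 0.2, scale used during quantization
output_bit_width = input_bit_width - trunc_bit_width  # 6
# the snippet above then divides both x and output_scale by trunc_scale,
# so the scale it returns equals the input scale again

# e.g. the 8-bit integer 141 truncated by 2 bits becomes 141 // 4 = 35, which fits in 6 bits
assert 141 // 4 < 2 ** output_bit_width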
def int_weight(self, x: torch.Tensor):
    zero_hw_sentinel = getattr(self, ZERO_HW_SENTINEL_NAME)
    quant_weight, scale, _ = self.tensor_quant(x, zero_hw_sentinel)
    quant_weight = quant_weight / scale
    quant_weight = round_ste(quant_weight)
    quant_weight = quant_weight.int()
    return quant_weight
def int(self, float_datatype=False):
    if self.is_valid:
        int_value = round_ste(self._pre_round_int_value)
        if float_datatype:
            return int_value
        else:
            return int_value.int()
    else:
        raise RuntimeError("QuantTensor not valid.")
def int(self, float_datatype=False):
    if self.is_valid:
        int_value = self.value / self.scale
        int_value = round_ste(int_value)
        if float_datatype:
            return int_value
        else:
            return int_value.int()
    else:
        raise RuntimeError(f"QuantTensor not well formed, all fields must be set: {self}")
def forward(self, x, input_scale, input_bit_width):
    x = round_ste(x / input_scale) * input_scale  # clean up fp errors before floor
    trunc_bit_width = self.lsb_trunc_bit_width_impl(input_bit_width, self.zero_hw_sentinel)
    trunc_scale = 2.0 ** trunc_bit_width
    output_scale = trunc_scale * input_scale
    x, output_scale, input_bit_width = self.tensor_quant(x, output_scale, input_bit_width, self.zero_hw_sentinel)
    if self.explicit_rescaling:
        # rescaling is explicit, so the truncation scale stays with x rather than with output_scale
        x = x / trunc_scale
        output_scale = output_scale / trunc_scale
    output_bit_width = input_bit_width - trunc_bit_width
    return x, output_scale, output_bit_width
def forward(self, x, input_scale, input_bit_width):
    x = round_ste(x / input_scale) * input_scale  # clean up fp errors before floor
    trunc_bit_width = self.lsb_trunc_bit_width_impl(input_bit_width, self.zero_hw_sentinel)
    trunc_scale = 2.0 ** trunc_bit_width
    output_scale = trunc_scale * input_scale
    if self.training:
        x, output_scale, input_bit_width = self.tensor_quant(x, output_scale, input_bit_width, self.zero_hw_sentinel)
    else:  # avoid fp errors at inference time
        x = round_ste(x / input_scale)
        x = x / trunc_scale
        x = self.tensor_quant.int_quant.float_to_int_impl(x)
        x = x * output_scale
    if self.explicit_rescaling:
        # rescaling is explicit, so the truncation scale stays with x rather than with output_scale
        x = x / trunc_scale
        output_scale = output_scale / trunc_scale
    output_bit_width = input_bit_width - trunc_bit_width
    return x, output_scale, output_bit_width
def forward(self, x: Tensor, scale: Tensor, input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    y = x / scale
    y = round_ste(y)  # clean up floating point error
    output_bit_width = self.msb_clamp_bit_width_impl()
    trunc_bit_width = input_bit_width - output_bit_width
    trunc_scale = 2.0 ** trunc_bit_width
    y = y / trunc_scale
    y = self.float_to_int_impl(y)
    y = y * scale
    y = self.delay_wrapper(x, y)
    return y, scale, output_bit_width
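# A standalone sketch of what the forward pass above computes for a single value, assuming
# float_to_int_impl behaves as a floor, as the "before floor" comments in the earlier snippets
# suggest. The function name and values are illustrative; STE and delay_wrapper are omitted.
import math

def trunc_quant_reference(x: float, scale: float, input_bit_width: int, output_bit_width: int) -> float:
    y = round(x / scale)                      # recover the integer representation
    trunc_bit_width = input_bit_width - output_bit_width
    trunc_scale = 2.0 ** trunc_bit_width
    y = math.floor(y / trunc_scale)           # drop the least-significant bits
    return y * scale                          # back to the original scale

# 8-bit value 0.55 at scale 0.01 truncated to 6 bits: 55 -> 13 -> 0.13
print(trunc_quant_reference(0.55, 0.01, 8, 6))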
def forward(self, x: torch.Tensor):
    return round_ste(x)
def __int__(self):
    return round_ste(self.tensor / self.scale)
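# All of the snippets above rely on round_ste. A minimal sketch of such a
# straight-through-estimator round (assuming the library version behaves like
# torch.round in the forward pass and passes gradients through unchanged):
import torch

class _RoundSte(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        # quantize in the forward pass
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: treat round as identity for gradients
        return grad_output

def round_ste(x: torch.Tensor) -> torch.Tensor:
    return _RoundSte.apply(x)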