def __init__(self,
             signed: bool,
             quant_type: QuantType,
             ls_bit_width_to_trunc: int,
             trunc_at_least_init_val: bool,
             min_overall_bit_width: Optional[int],
             max_overall_bit_width: Optional[int],
             lsb_trunc_bit_width_impl_type: BitWidthImplType,
             explicit_rescaling: bool,
             override_pretrained_bit_width: bool):
    super(TruncQuantProxy, self).__init__()
    self.explicit_rescaling = explicit_rescaling
    if quant_type == QuantType.FP:
        # FP mode: nothing is truncated, values pass through unchanged
        self.lsb_trunc_bit_width_impl = ZeroLsbTruncBitWidth()
        self.tensor_quant = IdentityPrescaledIntQuant()
    elif quant_type == QuantType.INT:
        # INT mode: truncate LSBs from the accumulator, with the number of
        # truncated bits determined by lsb_trunc_bit_width_impl_type
        self.lsb_trunc_bit_width_impl = LsbTruncParameterBitWidth(
            ls_bit_width_to_trunc=ls_bit_width_to_trunc,
            trunc_at_least_init_val=trunc_at_least_init_val,
            min_overall_bit_width=min_overall_bit_width,
            max_overall_bit_width=max_overall_bit_width,
            bit_width_impl_type=lsb_trunc_bit_width_impl_type,
            override_pretrained=override_pretrained_bit_width)
        tensor_clamp_impl = TensorClamp()
        # Truncation corresponds to flooring the rescaled integer values
        float_to_int_impl = RestrictValue(
            restrict_value_type=RestrictValueType.INT,
            float_to_int_impl_type=FloatToIntImplType.FLOOR,
            min_val=None)
        msb_clamp_bit_width_impl = IdentityBitWidth()
        self.tensor_quant = PrescaledRestrictIntQuantWithInputBitWidth(
            narrow_range=False,
            signed=signed,
            tensor_clamp_impl=tensor_clamp_impl,
            msb_clamp_bit_width_impl=msb_clamp_bit_width_impl,
            float_to_int_impl=float_to_int_impl)
    else:
        raise Exception(
            "Quantization type {} not supported for accumulators.".format(quant_type))
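
# A minimal usage sketch for TruncQuantProxy in INT mode: drop 4 LSBs from a
# signed accumulator while keeping the overall bit width between 2 and 32.
# The literal values and BitWidthImplType.CONST are illustrative assumptions,
# not values mandated by the class itself.
trunc_proxy = TruncQuantProxy(
    signed=True,
    quant_type=QuantType.INT,
    ls_bit_width_to_trunc=4,
    trunc_at_least_init_val=False,
    min_overall_bit_width=2,
    max_overall_bit_width=32,
    lsb_trunc_bit_width_impl_type=BitWidthImplType.CONST,
    explicit_rescaling=False,
    override_pretrained_bit_width=False)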
def test_PrescaledRestrictIntQuantWithInputBitWidth(x, narrow_range, signed, scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    tensor_clamp_impl = TensorClamp()
    # Mock the MSB clamp so it passes the requested bit width through unchanged
    msb_clamp_bit_width_mock = Mock()
    msb_clamp_bit_width_mock.return_value = torch.tensor(bit_width, dtype=torch.float)
    # Identity float-to-int, so rounding does not affect the comparison
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = lambda y: y
    obj = PrescaledRestrictIntQuantWithInputBitWidth(
        narrow_range=narrow_range,
        signed=signed,
        tensor_clamp_impl=tensor_clamp_impl,
        msb_clamp_bit_width_impl=msb_clamp_bit_width_mock,
        float_to_int_impl=float_to_int_impl_mock)
    output, scale, bit_width = obj(value, scale, bit_width, torch.tensor(ZERO_HW_SENTINEL_VALUE))
    # Reference: a bare IntQuant driven with the same mocked components
    expected_int_quant = IntQuant(
        signed=signed,
        narrow_range=narrow_range,
        tensor_clamp_impl=tensor_clamp_impl,
        float_to_int_impl=float_to_int_impl_mock)
    expected_output = expected_int_quant(
        scale, torch.tensor(ZERO_HW_SENTINEL_VALUE) + 1, bit_width, value)
    assert torch.allclose(expected_output, output, RTOL, ATOL)
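
# A hedged sketch of driving the test above directly with concrete values;
# the parametrization normally comes from the surrounding test harness, so
# the literals below are illustrative assumptions rather than fixture values.
if __name__ == '__main__':
    test_PrescaledRestrictIntQuantWithInputBitWidth(
        x=[0.5, -1.5, 2.25],
        narrow_range=False,
        signed=True,
        scale=0.125,
        bit_width=8)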
def __init__(self, quant_type: QuantType, bit_width: Optional[int], narrow_range: bool) -> None:
    super(BiasQuantProxy, self).__init__()
    self.scale_output_shape = OVER_BATCH_OVER_CHANNELS_SHAPE
    if quant_type == QuantType.FP:
        # FP mode: bias is left unquantized; set requires_input_bit_width
        # here too so the attribute exists on every path
        self.tensor_quant = None
        self.requires_input_bit_width = False
    elif quant_type == QuantType.INT:
        tensor_clamp_impl = TensorClamp()
        float_to_int_impl = RestrictValue(
            restrict_value_type=RestrictValueType.INT,
            float_to_int_impl_type=FloatToIntImplType.ROUND,
            min_val=None)
        if bit_width is not None:
            # Fixed bias bit width, independent of the accumulator bit width
            bit_width_impl = BitWidthConst(
                bit_width, restrict_bit_width_type=RestrictValueType.INT)
            self.tensor_quant = PrescaledRestrictIntQuant(
                narrow_range=narrow_range,
                signed=True,
                tensor_clamp_impl=tensor_clamp_impl,
                msb_clamp_bit_width_impl=bit_width_impl,
                float_to_int_impl=float_to_int_impl)
            self.requires_input_bit_width = False
        else:
            # No explicit bit width: take it from the input at call time
            msb_clamp_bit_width_impl = IdentityBitWidth()
            self.tensor_quant = PrescaledRestrictIntQuantWithInputBitWidth(
                narrow_range=narrow_range,
                signed=True,
                tensor_clamp_impl=tensor_clamp_impl,
                msb_clamp_bit_width_impl=msb_clamp_bit_width_impl,
                float_to_int_impl=float_to_int_impl)
            self.requires_input_bit_width = True
    else:
        raise Exception(
            'Quantization type {} not supported for bias quant.'.format(str(quant_type)))
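
# Minimal construction sketch: with an explicit bit width the proxy clamps
# the bias to a fixed width on its own; with bit_width=None it defers to the
# bit width of its input (e.g. the accumulator) at call time. The values
# below are illustrative assumptions.
bias_quant_fixed = BiasQuantProxy(
    quant_type=QuantType.INT, bit_width=16, narrow_range=False)
assert not bias_quant_fixed.requires_input_bit_width
bias_quant_from_input = BiasQuantProxy(
    quant_type=QuantType.INT, bit_width=None, narrow_range=False)
assert bias_quant_from_input.requires_input_bit_width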