def __init__(self,
             quant_type: QuantType,
             bit_width: Optional[int],
             narrow_range: bool) -> None:
    super(BiasQuantProxy, self).__init__()

    if quant_type == QuantType.FP:
        self.tensor_quant = None
    elif quant_type == QuantType.INT:
        tensor_clamp_impl = TensorClamp()
        float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                          float_to_int_impl_type=FloatToIntImplType.ROUND,
                                          min_val=None)
        if bit_width is not None:
            # Fixed bias bit width: clamp the MSB to a constant bit width
            bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type=RestrictValueType.INT)
            self.tensor_quant = PrescaledRestrictIntQuant(narrow_range=narrow_range,
                                                          signed=True,
                                                          tensor_clamp_impl=tensor_clamp_impl,
                                                          msb_clamp_bit_width_impl=bit_width_impl,
                                                          float_to_int_impl=float_to_int_impl)
            self.requires_input_bit_width = False
        else:
            # No bit width given: it has to be derived from the input at runtime
            self.tensor_quant = PrescaledIntQuant(narrow_range=narrow_range,
                                                  signed=True,
                                                  tensor_clamp_impl=tensor_clamp_impl,
                                                  float_to_int_impl=float_to_int_impl)
            self.requires_input_bit_width = True
    else:
        raise Exception('Quantization type {} not supported for bias quant.'
                        .format(str(quant_type)))
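
# Hypothetical usage sketch (not part of the original source): the two INT configurations
# built by BiasQuantProxy above. Argument names come from the signature shown here; the
# concrete bit width and the assumption that BiasQuantProxy is in scope are illustrative.
def _bias_quant_proxy_sketch():
    # Fixed bias bit width: the MSB clamp is a constant, so no input bit width is needed.
    fixed = BiasQuantProxy(quant_type=QuantType.INT, bit_width=16, narrow_range=False)
    assert not fixed.requires_input_bit_width
    # No bias bit width given: it must be provided from the layer input at call time.
    derived = BiasQuantProxy(quant_type=QuantType.INT, bit_width=None, narrow_range=False)
    assert derived.requires_input_bit_width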
def test_RescalingIntQuant(x, narrow_range, signed, scale, int_scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    int_scale = torch.tensor(int_scale, dtype=torch.float)
    tensor_clamp_impl = TensorClamp()
    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = bit_width
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)
    int_scaling_impl_mock = Mock()
    int_scaling_impl_mock.return_value = int_scale
    scaling_impl = Mock()
    scaling_impl.return_value = scale
    int_quant = IntQuant(signed=signed,
                         narrow_range=narrow_range,
                         tensor_clamp_impl=tensor_clamp_impl,
                         float_to_int_impl=float_to_int_impl_mock)
    obj = RescalingIntQuant(int_scaling_impl=int_scaling_impl_mock,
                            bit_width_impl=msb_clamp_bitwidth_mock,
                            scaling_impl=scaling_impl,
                            int_quant=int_quant)
    output, scale_out, bit_width = obj(value)
    expected_output = int_quant(scale, int_scale, bit_width, value)
    expected_scale = scale / int_scale
    assert torch.allclose(expected_output, output, RTOL, ATOL)
    assert torch.allclose(expected_scale, scale_out, RTOL, ATOL)
def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale, float_to_int_impl,
                  scale_multiplier):
    float_to_int_impl_mock = Mock()
    tensor_clamp_impl = TensorClamp()
    value = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)
    tol = scale * scale_multiplier
    float_to_int_impl_mock.side_effect = float_to_int_impl()
    obj = IntQuant(narrow_range=narrow_range,
                   signed=signed,
                   float_to_int_impl=float_to_int_impl_mock,
                   tensor_clamp_impl=tensor_clamp_impl)
    output = obj(scale, int_scale, bit_width, value)
    min_value = int(min_int(signed, narrow_range, bit_width))
    max_value = int(max_int(signed, bit_width))
    admissible_values = [x for x in range(min_value, max_value + 1)]
    value = (value / scale) * int_scale
    expected_output = tensor_clamp(value,
                                   min_val=min_int(signed, narrow_range, bit_width),
                                   max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale
    int_output = obj.to_int(scale, int_scale, bit_width, value)
    # The assert is performed internally by check_admissible_values
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, RTOL, tol)
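
# Reference sketch (an assumption, distilled from the expected-output computation in
# test_IntQuant above): the fake-quantization that IntQuant is expected to perform,
# written out with plain torch ops. `float_to_int` stands in for the rounding function
# that the test mocks out; torch.round is only a default chosen here for illustration.
def _int_quant_reference(value, scale, int_scale, bit_width, signed, narrow_range,
                         float_to_int=torch.round):
    y = (value / scale) * int_scale
    y = tensor_clamp(y,
                     min_val=min_int(signed, narrow_range, bit_width),
                     max_val=max_int(signed, bit_width))
    y = float_to_int(y)
    return (y / int_scale) * scale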
def __init__(self,
             signed: bool,
             narrow_range: bool,
             quant_type: QuantType,
             ms_bit_width_to_clamp: int,
             clamp_at_least_init_val: bool,
             min_overall_bit_width: Optional[int],
             max_overall_bit_width: Optional[int],
             msb_clamp_bit_width_impl_type: BitWidthImplType,
             override_pretrained_bit_width: bool):
    super(ClampQuantProxy, self).__init__()

    if quant_type == QuantType.FP:
        self.tensor_quant = IdentityPrescaledIntQuant()
    elif quant_type == QuantType.INT:
        msb_clamp_bit_width_impl = MsbClampParameterBitWidth(
            ms_bit_width_to_clamp=ms_bit_width_to_clamp,
            clamp_at_least_init_val=clamp_at_least_init_val,
            min_overall_bit_width=min_overall_bit_width,
            max_overall_bit_width=max_overall_bit_width,
            bit_width_impl_type=msb_clamp_bit_width_impl_type,
            override_pretrained=override_pretrained_bit_width)
        tensor_clamp_impl = TensorClamp()
        float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                          float_to_int_impl_type=FloatToIntImplType.ROUND,
                                          min_val=None)
        tensor_quant_impl = PrescaledRestrictIntQuantWithInputBitWidth
        self.tensor_quant = tensor_quant_impl(signed=signed,
                                              narrow_range=narrow_range,
                                              tensor_clamp_impl=tensor_clamp_impl,
                                              float_to_int_impl=float_to_int_impl,
                                              msb_clamp_bit_width_impl=msb_clamp_bit_width_impl)
    else:
        raise Exception("Quantization type {} not supported for accumulators.".format(quant_type))
def __init__(self,
             signed: bool,
             quant_type: QuantType,
             ls_bit_width_to_trunc: int,
             trunc_at_least_init_val: bool,
             min_overall_bit_width: Optional[int],
             max_overall_bit_width: Optional[int],
             lsb_trunc_bit_width_impl_type: BitWidthImplType,
             explicit_rescaling: bool,
             override_pretrained_bit_width: bool):
    super(TruncQuantProxy, self).__init__()
    self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))
    self.explicit_rescaling = explicit_rescaling

    if quant_type == QuantType.FP:
        self.lsb_trunc_bit_width_impl = ZeroLsbTruncBitWidth()
        self.tensor_quant = IdentityPrescaledIntQuant()
    elif quant_type == QuantType.INT:
        self.lsb_trunc_bit_width_impl = LsbTruncParameterBitWidth(
            ls_bit_width_to_trunc=ls_bit_width_to_trunc,
            trunc_at_least_init_val=trunc_at_least_init_val,
            min_overall_bit_width=min_overall_bit_width,
            max_overall_bit_width=max_overall_bit_width,
            bit_width_impl_type=lsb_trunc_bit_width_impl_type,
            override_pretrained=override_pretrained_bit_width)
        tensor_clamp_impl = TensorClamp()
        float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                          float_to_int_impl_type=FloatToIntImplType.FLOOR,
                                          min_val=None)
        self.tensor_quant = PrescaledIntQuant(signed=signed,
                                              narrow_range=False,
                                              tensor_clamp_impl=tensor_clamp_impl,
                                              float_to_int_impl=float_to_int_impl)
    else:
        raise Exception("Quantization type {} not supported for accumulators.".format(quant_type))
def __init__(
        self,
        scaling_impl: Module,
        tensor_clamp_impl: Module = TensorClamp(),
        quant_delay_steps: int = 0):
    super(ClampedBinaryQuant, self).__init__()
    self.scaling_impl = scaling_impl
    self.bit_width = BitWidthConst(1)
    self.zero_point = StatelessBuffer(torch.tensor(0.0))
    self.delay_wrapper = DelayWrapper(quant_delay_steps)
    self.tensor_clamp_impl = tensor_clamp_impl
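
# Hypothetical construction sketch (not part of the original source): ClampedBinaryQuant
# driven by a mocked scaling module, mirroring the Mock-based style of the tests above.
# Only construction is shown; the layout of the forward output is deliberately left out.
def _clamped_binary_quant_sketch():
    scaling = Mock()
    scaling.return_value = torch.tensor(0.1)
    return ClampedBinaryQuant(scaling_impl=scaling)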
def __init__(self,
             narrow_range: bool,
             signed: bool,
             float_to_int_impl: Module = RoundSte(),
             tensor_clamp_impl: Module = TensorClamp(),
             quant_delay_steps: int = 0):
    super(IntQuant, self).__init__()
    self.float_to_int_impl = float_to_int_impl
    self.tensor_clamp_impl = tensor_clamp_impl
    self.signed = signed
    self.narrow_range = narrow_range
    self.delay_wrapper = DelayWrapper(quant_delay_steps)
def test_PrescaledRestrictIntQuant(x, narrow_range, signed, scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    tensor_clamp_impl = TensorClamp()
    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = bit_width
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)
    int_quant = IntQuant(signed=signed,
                         narrow_range=narrow_range,
                         tensor_clamp_impl=tensor_clamp_impl,
                         float_to_int_impl=float_to_int_impl_mock)
    obj = PrescaledRestrictIntQuant(bit_width_impl=msb_clamp_bitwidth_mock,
                                    int_quant=int_quant)
    output, scale, bit_width = obj(value, scale)
    expected_output = int_quant(scale, torch.tensor(1.0), bit_width, value)
    assert torch.allclose(expected_output, output, RTOL, ATOL)
def test_PrescaledRestrictIntQuantWithInputBitWidth(x, narrow_range, signed, scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    tensor_clamp_impl = TensorClamp()
    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = torch.tensor(bit_width, dtype=torch.float)
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)
    obj = PrescaledRestrictIntQuantWithInputBitWidth(narrow_range=narrow_range,
                                                     signed=signed,
                                                     tensor_clamp_impl=tensor_clamp_impl,
                                                     msb_clamp_bit_width_impl=msb_clamp_bitwidth_mock,
                                                     float_to_int_impl=float_to_int_impl_mock)
    output, scale, bit_width = obj(value, scale, bit_width, torch.tensor(ZERO_HW_SENTINEL_VALUE))
    expected_IntQuant = IntQuant(signed=signed,
                                 narrow_range=narrow_range,
                                 tensor_clamp_impl=tensor_clamp_impl,
                                 float_to_int_impl=float_to_int_impl_mock)
    expected_output = expected_IntQuant(scale, torch.tensor(ZERO_HW_SENTINEL_VALUE) + 1, bit_width, value)
    assert torch.allclose(expected_output, output, RTOL, ATOL)
def _weight_quant_init_impl(bit_width: Optional[int],
                            quant_type: QuantType,
                            narrow_range: bool,
                            scaling_override: Optional[nn.Module],
                            restrict_scaling_type: RestrictValueType,
                            scaling_const: float,
                            scaling_stats_op: StatsOp,
                            scaling_impl_type: ScalingImplType,
                            scaling_stats_reduce_dim: Optional[int],
                            scaling_shape: Tuple[int, ...],
                            scaling_min_val: Optional[float],
                            bit_width_impl_type: Optional[BitWidthImplType],
                            restrict_bit_width_type: Optional[RestrictValueType],
                            min_overall_bit_width: Optional[int],
                            max_overall_bit_width: Optional[int],
                            bit_width_impl_override: Optional[Union[BitWidthConst, BitWidthParameter]],
                            scaling_stats_input_view_shape_impl: StatsInputViewShapeImpl,
                            scaling_stats_input_concat_dim: int,
                            ternary_threshold: Optional[float],
                            scaling_stats_sigma: Optional[float],
                            tracked_parameter_list: List[torch.nn.Parameter],
                            zero_hw_sentinel: torch.Tensor,
                            override_pretrained_bit_width: bool):

    if quant_type == QuantType.FP:
        tensor_quant = IdentityQuant()
    else:
        if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
            raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
            raise Exception("Overriding scaling requires to pass a scaling impl module.")

        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
            scaling_impl = scaling_override

        elif scaling_impl_type == ScalingImplType.STATS \
                or scaling_impl_type == ScalingImplType.AFFINE_STATS \
                or scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            stats_scaling = ParameterStatsScaling(stats_op=scaling_stats_op,
                                                  restrict_scaling_type=restrict_scaling_type,
                                                  tracked_parameter_list=tracked_parameter_list,
                                                  stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                                  stats_input_concat_dim=scaling_stats_input_concat_dim,
                                                  sigma=scaling_stats_sigma,
                                                  scaling_min_val=scaling_min_val,
                                                  stats_reduce_dim=scaling_stats_reduce_dim,
                                                  stats_output_shape=scaling_shape,
                                                  affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
            if scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
                if quant_type == QuantType.BINARY or quant_type == QuantType.TERNARY:
                    raise Exception("Parameter from stats scaling is currently not supported for binary/ternary")
                scaling_init = stats_scaling(zero_hw_sentinel).detach()
                scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                                 parameter_shape=scaling_shape,
                                                 restrict_scaling_type=restrict_scaling_type,
                                                 is_parameter=True,
                                                 scaling_min_val=scaling_min_val)
            else:
                scaling_impl = stats_scaling

        elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.HE:
            if scaling_impl_type == ScalingImplType.HE:
                scaling_const = 0.0
                for param in tracked_parameter_list:  # takes average of He scaling over parameter list
                    two_dim_param = param.view(param.shape[0], -1)
                    scaling_const += math.sqrt(2.0 / two_dim_param.shape[1])
                scaling_const /= len(tracked_parameter_list)
            scaling_init = torch.tensor(scaling_const)
            scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                             parameter_shape=SCALING_SCALAR_SHAPE,
                                             restrict_scaling_type=restrict_scaling_type,
                                             is_parameter=False,
                                             scaling_min_val=None)
        else:
            raise Exception("Scaling type {} not supported for weight quantization"
                            .format(str(scaling_impl_type)))

        if bit_width == 1 and quant_type == QuantType.BINARY:
            tensor_quant = BinaryQuant(scaling_impl=scaling_impl)

        elif bit_width == 2 and quant_type == QuantType.TERNARY:
            tensor_quant = TernaryQuant(scaling_impl=scaling_impl, threshold=ternary_threshold)

        elif bit_width >= 2 and quant_type == QuantType.INT:
            if bit_width_impl_override is None:
                if (bit_width_impl_type is None
                        or bit_width is None
                        or restrict_bit_width_type is None):
                    raise Exception("Bit width is not defined properly")

                if bit_width_impl_type == BitWidthImplType.CONST:
                    tensor_clamp_impl = TensorClampSte()
                    bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                    tensor_clamp_impl = TensorClamp()
                    bit_width_impl = BitWidthParameter(bit_width_init=bit_width,
                                                       restrict_bit_width_type=restrict_bit_width_type,
                                                       min_overall_bit_width=min_overall_bit_width,
                                                       max_overall_bit_width=max_overall_bit_width,
                                                       override_pretrained=override_pretrained_bit_width)
                else:
                    raise Exception("Bit width type {} not supported for weight quantization."
                                    .format(str(bit_width_impl_type)))
            else:
                tensor_clamp_impl = TensorClamp()
                bit_width_impl = bit_width_impl_override

            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.ROUND,
                                              min_val=None)
            int_scaling_impl = IntScaling(narrow_range,
                                          signed=True,
                                          restrict_scaling_type=restrict_scaling_type)
            tensor_quant = RescalingIntQuant(narrow_range=narrow_range,
                                             signed=True,
                                             scaling_impl=scaling_impl,
                                             int_scaling_impl=int_scaling_impl,
                                             tensor_clamp_impl=tensor_clamp_impl,
                                             msb_clamp_bit_width_impl=bit_width_impl,
                                             float_to_int_impl=float_to_int_impl,
                                             runtime=False)
        else:
            raise Exception('Unsupported weight quantization: {} bit width, {} quantization.'
                            .format(bit_width, str(quant_type)))
    return tensor_quant
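
# Hypothetical usage sketch (not part of the original source): wiring up an 8-bit signed
# weight quantizer with MAX statistics scaling over the whole tensor. Every argument name
# comes from the signature above; the concrete values and enum members chosen here
# (e.g. StatsOp.MAX, RestrictValueType.LOG_FP, StatsInputViewShapeImpl.OVER_TENSOR) are
# illustrative assumptions, not a validated configuration.
def _weight_quant_sketch(weight: torch.nn.Parameter, zero_hw_sentinel: torch.Tensor):
    return _weight_quant_init_impl(bit_width=8,
                                   quant_type=QuantType.INT,
                                   narrow_range=True,
                                   scaling_override=None,
                                   restrict_scaling_type=RestrictValueType.LOG_FP,
                                   scaling_const=1.0,
                                   scaling_stats_op=StatsOp.MAX,
                                   scaling_impl_type=ScalingImplType.STATS,
                                   scaling_stats_reduce_dim=None,
                                   scaling_shape=SCALING_SCALAR_SHAPE,
                                   scaling_min_val=None,
                                   bit_width_impl_type=BitWidthImplType.CONST,
                                   restrict_bit_width_type=RestrictValueType.INT,
                                   min_overall_bit_width=2,
                                   max_overall_bit_width=32,
                                   bit_width_impl_override=None,
                                   scaling_stats_input_view_shape_impl=StatsInputViewShapeImpl.OVER_TENSOR,
                                   scaling_stats_input_concat_dim=0,
                                   ternary_threshold=None,
                                   scaling_stats_sigma=None,
                                   tracked_parameter_list=[weight],
                                   zero_hw_sentinel=zero_hw_sentinel,
                                   override_pretrained_bit_width=False)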
def __init__(self,
             activation_impl: Module,
             bit_width: int,
             signed: bool,
             narrow_range: bool,
             min_val: float,
             max_val: float,
             quant_type: QuantType,
             float_to_int_impl_type: FloatToIntImplType,
             scaling_override: Optional[Module],
             scaling_impl_type: ScalingImplType,
             scaling_per_channel: bool,
             scaling_min_val: Optional[float],
             scaling_stats_sigma: Optional[float],
             scaling_stats_op: Optional[StatsOp],
             scaling_stats_buffer_momentum: Optional[float],
             scaling_stats_permute_dims: Optional[Tuple],
             per_channel_broadcastable_shape: Optional[Tuple[int, ...]],
             min_overall_bit_width: Optional[int],
             max_overall_bit_width: Optional[int],
             bit_width_impl_override: Module,
             bit_width_impl_type: BitWidthImplType,
             restrict_bit_width_type: RestrictValueType,
             restrict_scaling_type: RestrictValueType,
             override_pretrained_bit_width: bool):
    super(ActivationQuantProxy, self).__init__()

    if not signed and min_val != 0.0:
        raise Exception("Min val has to be 0.0 when quantization is unsigned.")
    if scaling_per_channel and per_channel_broadcastable_shape is None:
        raise Exception("Per channel scaling requires to specify number of channels.")

    if quant_type == QuantType.FP:
        tensor_quant = IdentityQuant()
    else:
        if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
            raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
            raise Exception("Overriding scaling requires to pass a scaling impl module.")

        if scaling_per_channel:
            scaling_shape = per_channel_broadcastable_shape
        else:
            scaling_shape = SCALING_SCALAR_SHAPE

        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
            scaling_impl = scaling_override
            runtime = False
        elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.PARAMETER:
            scaling_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val)
            scaling_impl = StandaloneScaling(is_parameter=scaling_impl_type == ScalingImplType.PARAMETER,
                                             parameter_shape=scaling_shape,
                                             restrict_scaling_type=restrict_scaling_type,
                                             scaling_init=scaling_init,
                                             scaling_min_val=scaling_min_val)
            runtime = False
        elif scaling_impl_type == ScalingImplType.STATS or scaling_impl_type == ScalingImplType.AFFINE_STATS:
            if scaling_per_channel and not scaling_stats_op == StatsOp.MAX_AVE:
                scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                scaling_stats_reduce_dim = 1
            elif scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                raise Exception("Can't do per channel scaling with MAX AVE statistics.")
            elif not scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                scaling_stats_reduce_dim = 1
            else:  # not scaling_per_channel
                scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                scaling_stats_reduce_dim = None
                scaling_stats_permute_dims = None
            stats_buffer_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val).item()
            scaling_impl = RuntimeStatsScaling(stats_op=scaling_stats_op,
                                               restrict_scaling_type=restrict_scaling_type,
                                               stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                               stats_output_shape=scaling_shape,
                                               sigma=scaling_stats_sigma,
                                               scaling_min_val=scaling_min_val,
                                               stats_reduce_dim=scaling_stats_reduce_dim,
                                               stats_buffer_momentum=scaling_stats_buffer_momentum,
                                               stats_buffer_init=stats_buffer_init,
                                               stats_permute_dims=scaling_stats_permute_dims,
                                               affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
            runtime = True
        else:
            raise Exception("Scaling type {} not supported for int runtime quantization"
                            .format(str(scaling_impl_type)))

        if quant_type == QuantType.BINARY:
            if not signed:
                raise Exception("Binary activation supports only signed activations")
            tensor_quant = ClampedBinaryQuant(scaling_impl=scaling_impl)

        elif quant_type == QuantType.INT:
            if bit_width_impl_override is None:
                if bit_width_impl_type is None or bit_width is None or restrict_bit_width_type is None:
                    raise Exception("Bit width is not defined properly")

                if bit_width_impl_type == BitWidthImplType.CONST:
                    tensor_clamp_impl = TensorClamp()  # If it's const, don't pass gradients to clipped values
                    msb_clamp_bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                    tensor_clamp_impl = TensorClamp()  # if it's learned, pass gradients to the bit width
                    msb_clamp_bit_width_impl = BitWidthParameter(bit_width,
                                                                 min_overall_bit_width,
                                                                 max_overall_bit_width,
                                                                 restrict_bit_width_type,
                                                                 override_pretrained_bit_width)
                else:
                    raise Exception("Bit width type {} not supported for activation quantization"
                                    .format(str(bit_width_impl_type)))
            else:
                msb_clamp_bit_width_impl = bit_width_impl_override
                tensor_clamp_impl = TensorClamp()  # if there is an override, it's learned

            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=float_to_int_impl_type,
                                              min_val=None)
            int_scaling_impl = IntScaling(narrow_range,
                                          signed=signed,
                                          restrict_scaling_type=restrict_scaling_type)
            tensor_quant = RescalingIntQuant(signed=signed,
                                             narrow_range=narrow_range,
                                             scaling_impl=scaling_impl,
                                             int_scaling_impl=int_scaling_impl,
                                             tensor_clamp_impl=tensor_clamp_impl,
                                             msb_clamp_bit_width_impl=msb_clamp_bit_width_impl,
                                             float_to_int_impl=float_to_int_impl,
                                             runtime=runtime)
        else:
            raise Exception("Quantization type {} not supported for activations.".format(quant_type))

    self.fused_activation_quant_proxy = FusedActivationQuantProxy(activation_impl, tensor_quant)
    self.scaling_impl_type = scaling_impl_type  # needed to switch between different scaling modes
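
# Hypothetical usage sketch (not part of the original source): an unsigned 8-bit ReLU
# activation proxy with a learned (PARAMETER) scale initialised from the [0, 6] range.
# Argument names follow the signature above; the concrete values and the enum members
# chosen here (e.g. RestrictValueType.LOG_FP) are illustrative assumptions.
def _activation_quant_proxy_sketch():
    return ActivationQuantProxy(activation_impl=nn.ReLU(),
                                bit_width=8,
                                signed=False,
                                narrow_range=False,
                                min_val=0.0,
                                max_val=6.0,
                                quant_type=QuantType.INT,
                                float_to_int_impl_type=FloatToIntImplType.ROUND,
                                scaling_override=None,
                                scaling_impl_type=ScalingImplType.PARAMETER,
                                scaling_per_channel=False,
                                scaling_min_val=None,
                                scaling_stats_sigma=None,
                                scaling_stats_op=None,
                                scaling_stats_buffer_momentum=None,
                                scaling_stats_permute_dims=None,
                                per_channel_broadcastable_shape=None,
                                min_overall_bit_width=2,
                                max_overall_bit_width=32,
                                bit_width_impl_override=None,
                                bit_width_impl_type=BitWidthImplType.CONST,
                                restrict_bit_width_type=RestrictValueType.INT,
                                restrict_scaling_type=RestrictValueType.LOG_FP,
                                override_pretrained_bit_width=False)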