def delayed_binary_quant(scaling_impl_all, quant_delay_steps):
    """ Delayed BinaryQuant with all variants of scaling """
    return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)
def _weight_quant_init_impl(bit_width: Optional[int],
                            quant_type: QuantType,
                            narrow_range: bool,
                            scaling_override: Optional[nn.Module],
                            restrict_scaling_type: RestrictValueType,
                            scaling_const: float,
                            scaling_stats_op: StatsOp,
                            scaling_impl_type: ScalingImplType,
                            scaling_stats_reduce_dim: Optional[int],
                            scaling_shape: Tuple[int, ...],
                            scaling_min_val: Optional[float],
                            bit_width_impl_type: Optional[BitWidthImplType],
                            restrict_bit_width_type: Optional[RestrictValueType],
                            min_overall_bit_width: Optional[int],
                            max_overall_bit_width: Optional[int],
                            bit_width_impl_override: Optional[Union[BitWidthConst, BitWidthParameter]],
                            scaling_stats_input_view_shape_impl: StatsInputViewShapeImpl,
                            scaling_stats_input_concat_dim: int,
                            ternary_threshold: Optional[float],
                            scaling_stats_sigma: Optional[float],
                            tracked_parameter_list: List[torch.nn.Parameter],
                            zero_hw_sentinel: torch.Tensor,
                            override_pretrained_bit_width: bool):
    """ Build the weight tensor quantization implementation from the requested
    quant type, scaling strategy and bit width settings. """
    if quant_type == QuantType.FP:
        tensor_quant = IdentityQuant()
    else:
        if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
            raise Exception("Overriding scaling requires setting ScalingImplType to OVERRIDE explicitly.")
        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
            raise Exception("Overriding scaling requires passing a scaling impl module.")

        # Resolve the scaling implementation
        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
            scaling_impl = scaling_override

        elif scaling_impl_type == ScalingImplType.STATS \
                or scaling_impl_type == ScalingImplType.AFFINE_STATS \
                or scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            stats_scaling = ParameterStatsScaling(stats_op=scaling_stats_op,
                                                  restrict_scaling_type=restrict_scaling_type,
                                                  tracked_parameter_list=tracked_parameter_list,
                                                  stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                                  stats_input_concat_dim=scaling_stats_input_concat_dim,
                                                  sigma=scaling_stats_sigma,
                                                  scaling_min_val=scaling_min_val,
                                                  stats_reduce_dim=scaling_stats_reduce_dim,
                                                  stats_output_shape=scaling_shape,
                                                  affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
            if scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
                if quant_type == QuantType.BINARY or quant_type == QuantType.TERNARY:
                    raise Exception("Parameter from stats scaling is currently not supported for binary/ternary")
                # Use the stats-derived value only to initialize a learned scaling parameter
                scaling_init = stats_scaling(zero_hw_sentinel).detach()
                scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                                 parameter_shape=scaling_shape,
                                                 restrict_scaling_type=restrict_scaling_type,
                                                 is_parameter=True,
                                                 scaling_min_val=scaling_min_val)
            else:
                scaling_impl = stats_scaling

        elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.HE:
            if scaling_impl_type == ScalingImplType.HE:
                scaling_const = 0.0
                for param in tracked_parameter_list:  # take the average of He scaling over the parameter list
                    two_dim_param = param.view(param.shape[0], -1)
                    scaling_const += math.sqrt(2.0 / two_dim_param.shape[1])
                scaling_const /= len(tracked_parameter_list)
            scaling_init = torch.tensor(scaling_const)
            scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                             parameter_shape=SCALING_SCALAR_SHAPE,
                                             restrict_scaling_type=restrict_scaling_type,
                                             is_parameter=False,
                                             scaling_min_val=None)
        else:
            raise Exception("Scaling type {} not supported for weight quantization"
                            .format(str(scaling_impl_type)))

        # Resolve the quantizer from quant type and bit width
        if bit_width == 1 and quant_type == QuantType.BINARY:
            tensor_quant = BinaryQuant(scaling_impl=scaling_impl)

        elif bit_width == 2 and quant_type == QuantType.TERNARY:
            tensor_quant = TernaryQuant(scaling_impl=scaling_impl,
                                        threshold=ternary_threshold)

        elif bit_width >= 2 and quant_type == QuantType.INT:
            if bit_width_impl_override is None:
                if (bit_width_impl_type is None
                        or bit_width is None
                        or restrict_bit_width_type is None):
                    raise Exception("Bit width is not defined properly")

                if bit_width_impl_type == BitWidthImplType.CONST:
                    tensor_clamp_impl = TensorClampSte()
                    bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                    tensor_clamp_impl = TensorClamp()
                    bit_width_impl = BitWidthParameter(bit_width_init=bit_width,
                                                       restrict_bit_width_type=restrict_bit_width_type,
                                                       min_overall_bit_width=min_overall_bit_width,
                                                       max_overall_bit_width=max_overall_bit_width,
                                                       override_pretrained=override_pretrained_bit_width)
                else:
                    raise Exception("Bit width type {} not supported for weight quantization."
                                    .format(str(bit_width_impl_type)))
            else:
                tensor_clamp_impl = TensorClamp()
                bit_width_impl = bit_width_impl_override

            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.ROUND,
                                              min_val=None)
            int_scaling_impl = IntScaling(narrow_range,
                                          signed=True,
                                          restrict_scaling_type=restrict_scaling_type)
            tensor_quant = RescalingIntQuant(narrow_range=narrow_range,
                                             signed=True,
                                             scaling_impl=scaling_impl,
                                             int_scaling_impl=int_scaling_impl,
                                             tensor_clamp_impl=tensor_clamp_impl,
                                             msb_clamp_bit_width_impl=bit_width_impl,
                                             float_to_int_impl=float_to_int_impl,
                                             runtime=False)
        else:
            raise Exception('Unsupported weight quantization: {} bit width, {} quantization.'
                            .format(bit_width, str(quant_type)))
    return tensor_quant
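

# A minimal usage sketch, not part of the library itself: the helper and constructor
# names below are illustrative, and it assumes the enum members QuantType.INT,
# ScalingImplType.CONST, BitWidthImplType.CONST, RestrictValueType.FP and
# RestrictValueType.INT are importable from the surrounding brevitas modules.
# With these settings _weight_quant_init_impl follows the CONST scaling path and
# returns a RescalingIntQuant instance; the stats-related arguments are unused there.
def _example_int_weight_quant(weight: torch.nn.Parameter):
    zero_hw_sentinel = torch.tensor(0.0)
    return _weight_quant_init_impl(
        bit_width=4,
        quant_type=QuantType.INT,
        narrow_range=True,
        scaling_override=None,
        restrict_scaling_type=RestrictValueType.FP,
        scaling_const=0.1,
        scaling_stats_op=None,                     # unused on the CONST scaling path
        scaling_impl_type=ScalingImplType.CONST,
        scaling_stats_reduce_dim=None,
        scaling_shape=SCALING_SCALAR_SHAPE,
        scaling_min_val=None,
        bit_width_impl_type=BitWidthImplType.CONST,
        restrict_bit_width_type=RestrictValueType.INT,
        min_overall_bit_width=None,
        max_overall_bit_width=None,
        bit_width_impl_override=None,
        scaling_stats_input_view_shape_impl=None,  # unused on the CONST scaling path
        scaling_stats_input_concat_dim=0,
        ternary_threshold=None,
        scaling_stats_sigma=None,
        tracked_parameter_list=[weight],
        zero_hw_sentinel=zero_hw_sentinel,
        override_pretrained_bit_width=False)
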
def binary_quant(scaling_impl_all):
    """ Binary quant with all variants of scaling """
    return BinaryQuant(scaling_impl=scaling_impl_all)
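

# A minimal sketch of how the binary_quant factory above might be exercised
# (hypothetical, not part of the test suite). StandaloneScaling is constructed with
# the same keyword arguments used in _weight_quant_init_impl; RestrictValueType.FP
# and the call convention quant(x, zero_hw_sentinel) -> (quantized tensor, scale,
# bit width) are assumptions about this brevitas version rather than confirmed API.
def _example_binary_quant_usage():
    scaling_impl = StandaloneScaling(scaling_init=torch.tensor(0.1),
                                     parameter_shape=SCALING_SCALAR_SHAPE,
                                     restrict_scaling_type=RestrictValueType.FP,
                                     is_parameter=False,
                                     scaling_min_val=None)
    quant = binary_quant(scaling_impl)
    x = torch.randn(8, 16)
    zero_hw_sentinel = torch.tensor(0.0)
    y, scale, bit_width = quant(x, zero_hw_sentinel)  # assumed 3-tuple return
    return y, scale, bit_width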