Example #1
0
def delayed_binary_quant(scaling_impl_all, quant_delay_steps):
    """
    BinaryQuant with a quantization delay, built around each scaling
    variant supplied through ``scaling_impl_all``.
    """
    quant_kwargs = dict(scaling_impl=scaling_impl_all,
                        quant_delay_steps=quant_delay_steps)
    return BinaryQuant(**quant_kwargs)
def _weight_quant_init_impl(bit_width: Optional[int],
                            quant_type: QuantType,
                            narrow_range: bool,
                            scaling_override: Optional[nn.Module],
                            restrict_scaling_type: RestrictValueType,
                            scaling_const: float,
                            scaling_stats_op: StatsOp,
                            scaling_impl_type: ScalingImplType,
                            scaling_stats_reduce_dim: Optional[int],
                            scaling_shape: Tuple[int, ...],
                            scaling_min_val: Optional[float],
                            bit_width_impl_type: Optional[BitWidthImplType],
                            restrict_bit_width_type: Optional[RestrictValueType],
                            min_overall_bit_width: Optional[int],
                            max_overall_bit_width: Optional[int],
                            bit_width_impl_override: Optional[Union[BitWidthConst, BitWidthParameter]],
                            scaling_stats_input_view_shape_impl: StatsInputViewShapeImpl,
                            scaling_stats_input_concat_dim: int,
                            ternary_threshold: Optional[float],
                            scaling_stats_sigma: Optional[float],
                            tracked_parameter_list: List[torch.nn.Parameter],
                            zero_hw_sentinel: torch.Tensor,
                            override_pretrained_bit_width: bool):
    """
    Build the tensor quantizer module used to quantize weights.

    For ``QuantType.FP`` an ``IdentityQuant`` is returned and the remaining
    arguments are ignored. For every other quantization type, a scaling
    module is built first according to ``scaling_impl_type``:

    * ``OVERRIDE``: ``scaling_override`` is used as-is.
    * ``STATS`` / ``AFFINE_STATS``: a ``ParameterStatsScaling`` built from
      ``scaling_stats_op`` over ``tracked_parameter_list`` (with
      ``affine=True`` only for ``AFFINE_STATS``).
    * ``PARAMETER_FROM_STATS``: a ``StandaloneScaling`` created with
      ``is_parameter=True``, initialised from one detached evaluation of the
      stats-based scaling.
    * ``CONST`` / ``HE``: a scalar ``StandaloneScaling`` created with
      ``is_parameter=False``, set to ``scaling_const``. For ``HE``, the value
      is instead the mean over ``tracked_parameter_list`` of
      ``sqrt(2 / n)``, where ``n`` is the product of each parameter's
      dimensions after the first.

    The scaling module is then wrapped in a quantizer selected by
    ``quant_type`` and ``bit_width``:

    * ``BINARY`` with ``bit_width == 1``: ``BinaryQuant``.
    * ``TERNARY`` with ``bit_width == 2``: ``TernaryQuant`` with
      ``threshold=ternary_threshold``.
    * ``INT`` with ``bit_width >= 2``, or with ``bit_width`` left as ``None``
      when ``bit_width_impl_override`` supplies the bit width: a signed
      ``RescalingIntQuant`` using ``FloatToIntImplType.ROUND``.

    Returns:
        The quantizer module.

    Raises:
        Exception: if ``scaling_override`` and ``scaling_impl_type`` are
            inconsistent; if the scaling or bit-width implementation type is
            unsupported; if ``PARAMETER_FROM_STATS`` is combined with binary
            or ternary quantization; if ``HE`` scaling is requested with an
            empty ``tracked_parameter_list``; if the INT bit-width
            configuration is incomplete; or if the ``bit_width`` /
            ``quant_type`` combination is unsupported.
    """
    if quant_type == QuantType.FP:
        return IdentityQuant()

    # ---- Scaling implementation -------------------------------------------------
    if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
        raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
    if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
        raise Exception("Overriding scaling requires to pass a scaling impl module.")

    if scaling_impl_type == ScalingImplType.OVERRIDE:
        # The checks above guarantee scaling_override is not None here.
        scaling_impl = scaling_override

    elif scaling_impl_type in (ScalingImplType.STATS,
                               ScalingImplType.AFFINE_STATS,
                               ScalingImplType.PARAMETER_FROM_STATS):
        stats_scaling = ParameterStatsScaling(stats_op=scaling_stats_op,
                                              restrict_scaling_type=restrict_scaling_type,
                                              tracked_parameter_list=tracked_parameter_list,
                                              stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                              stats_input_concat_dim=scaling_stats_input_concat_dim,
                                              sigma=scaling_stats_sigma,
                                              scaling_min_val=scaling_min_val,
                                              stats_reduce_dim=scaling_stats_reduce_dim,
                                              stats_output_shape=scaling_shape,
                                              affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
        if scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            if quant_type == QuantType.BINARY or quant_type == QuantType.TERNARY:
                raise Exception("Parameter from stats scaling is currently not supported for binary/ternary")
            # Evaluate the stats-based scaling once and detach it, so it only
            # seeds the standalone parameter instead of staying in the graph.
            scaling_init = stats_scaling(zero_hw_sentinel).detach()
            scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                             parameter_shape=scaling_shape,
                                             restrict_scaling_type=restrict_scaling_type,
                                             is_parameter=True,
                                             scaling_min_val=scaling_min_val)
        else:
            scaling_impl = stats_scaling

    elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.HE:
        if scaling_impl_type == ScalingImplType.HE:
            # Without this guard, an empty list fails below with a bare ZeroDivisionError.
            if not tracked_parameter_list:
                raise Exception("He scaling requires at least one tracked parameter.")
            scaling_const = 0.0
            for param in tracked_parameter_list:  # takes average of He scaling over parameter list
                # Flatten every dimension after the first; shape[1] is then their product.
                two_dim_param = param.view(param.shape[0], -1)
                scaling_const += math.sqrt(2.0 / two_dim_param.shape[1])
            scaling_const /= len(tracked_parameter_list)
        scaling_init = torch.tensor(scaling_const)
        scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                         parameter_shape=SCALING_SCALAR_SHAPE,
                                         restrict_scaling_type=restrict_scaling_type,
                                         is_parameter=False,
                                         scaling_min_val=None)
    else:
        raise Exception("Scaling type {} not supported for weight quantization"
                        .format(str(scaling_impl_type)))

    # ---- Quantizer ----------------------------------------------------------------
    if bit_width == 1 and quant_type == QuantType.BINARY:
        return BinaryQuant(scaling_impl=scaling_impl)

    if bit_width == 2 and quant_type == QuantType.TERNARY:
        return TernaryQuant(scaling_impl=scaling_impl, threshold=ternary_threshold)

    # Test quant_type and None-ness before ordering bit_width. bit_width may
    # legitimately be None when bit_width_impl_override provides the width, and
    # `None >= 2` raises TypeError. That error would otherwise also make the
    # explicit "not defined properly" check below unreachable.
    if quant_type == QuantType.INT and (bit_width is None or bit_width >= 2):
        if bit_width_impl_override is None:
            if (bit_width_impl_type is None
                    or bit_width is None
                    or restrict_bit_width_type is None):
                raise Exception("Bit width is not defined properly")

            # NOTE(review): CONST pairs with TensorClampSte, while PARAMETER and
            # override pair with TensorClamp. Presumably this is so a learned
            # bit width can receive gradients; confirm.
            if bit_width_impl_type == BitWidthImplType.CONST:
                tensor_clamp_impl = TensorClampSte()
                bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
            elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                tensor_clamp_impl = TensorClamp()
                bit_width_impl = BitWidthParameter(bit_width_init=bit_width,
                                                   restrict_bit_width_type=restrict_bit_width_type,
                                                   min_overall_bit_width=min_overall_bit_width,
                                                   max_overall_bit_width=max_overall_bit_width,
                                                   override_pretrained=override_pretrained_bit_width)
            else:
                raise Exception("Bit width type {} not supported for weight quantization."
                                .format(str(bit_width_impl_type)))
        else:
            tensor_clamp_impl = TensorClamp()
            bit_width_impl = bit_width_impl_override

        float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                          float_to_int_impl_type=FloatToIntImplType.ROUND,
                                          min_val=None)
        int_scaling_impl = IntScaling(narrow_range,
                                      signed=True,
                                      restrict_scaling_type=restrict_scaling_type)
        return RescalingIntQuant(narrow_range=narrow_range,
                                 signed=True,
                                 scaling_impl=scaling_impl,
                                 int_scaling_impl=int_scaling_impl,
                                 tensor_clamp_impl=tensor_clamp_impl,
                                 msb_clamp_bit_width_impl=bit_width_impl,
                                 float_to_int_impl=float_to_int_impl,
                                 runtime=False)

    raise Exception('Unsupported weight quantization: {} bit width, {} quantization.'
                    .format(bit_width, str(quant_type)))
Example #3
0
def binary_quant(scaling_impl_all):
    """
    BinaryQuant built around each scaling variant supplied through
    ``scaling_impl_all``.
    """
    quantizer = BinaryQuant(scaling_impl=scaling_impl_all)
    return quantizer