Ejemplo n.º 1
0
def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps):
    """
    Build a ClampedBinaryQuant with a quantization delay.

    Both arguments are forwarded unchanged: ``scaling_impl_all`` as the
    ``scaling_impl`` keyword and ``quant_delay_steps`` as the
    ``quant_delay_steps`` keyword. The name suggests the caller supplies every
    scaling variant in turn (presumably through test parametrization -- confirm
    against the surrounding test setup).
    """
    quant_kwargs = {
        'scaling_impl': scaling_impl_all,
        'quant_delay_steps': quant_delay_steps,
    }
    return ClampedBinaryQuant(**quant_kwargs)
Ejemplo n.º 2
0
    def __init__(self,
                 activation_impl: Module,
                 bit_width: int,
                 signed: bool,
                 narrow_range: bool,
                 min_val: float,
                 max_val: float,
                 quant_type: QuantType,
                 float_to_int_impl_type: FloatToIntImplType,
                 scaling_override: Optional[Module],
                 scaling_impl_type: ScalingImplType,
                 scaling_per_channel: bool,
                 scaling_min_val: Optional[float],
                 scaling_stats_sigma: Optional[float],
                 scaling_stats_op: Optional[StatsOp],
                 scaling_stats_buffer_momentum: Optional[float],
                 scaling_stats_permute_dims: Optional[Tuple],
                 per_channel_broadcastable_shape: Optional[Tuple[int, ...]],
                 min_overall_bit_width: Optional[int],
                 max_overall_bit_width: Optional[int],
                 bit_width_impl_override: Optional[Module],
                 bit_width_impl_type: BitWidthImplType,
                 restrict_bit_width_type: RestrictValueType,
                 restrict_scaling_type: RestrictValueType,
                 override_pretrained_bit_width: bool):
        """
        Assemble the tensor quantizer selected by ``quant_type`` and pair it with
        ``activation_impl`` inside a FusedActivationQuantProxy.

        Construction proceeds in three steps:

        1. Validate the signed/min_val and per-channel arguments.
        2. Unless ``quant_type`` is ``QuantType.FP`` (which uses IdentityQuant),
           build a scaling implementation from ``scaling_impl_type``.
        3. Wrap that scaling implementation in ClampedBinaryQuant
           (``QuantType.BINARY``) or RescalingIntQuant (``QuantType.INT``).

        Args:
            activation_impl: module handed to FusedActivationQuantProxy together
                with the tensor quantizer.
            bit_width: initial bit width; required (not None) only when
                ``bit_width_impl_override`` is None and ``quant_type`` is INT.
            signed: whether the quantized range is signed. Must be True for
                BINARY quantization.
            narrow_range: forwarded to IntScaling and RescalingIntQuant.
            min_val: lower bound used to initialise scaling; must be 0.0 when
                ``signed`` is False.
            max_val: upper bound used to initialise scaling.
            quant_type: FP, BINARY or INT; any other value is rejected.
            float_to_int_impl_type: forwarded to RestrictValue for INT.
            scaling_override: scaling module used as-is when
                ``scaling_impl_type`` is OVERRIDE; must be None otherwise.
            scaling_impl_type: one of OVERRIDE, CONST, PARAMETER, STATS or
                AFFINE_STATS.
            scaling_per_channel: use ``per_channel_broadcastable_shape`` as the
                scaling shape instead of SCALING_SCALAR_SHAPE.
            scaling_min_val: forwarded to the scaling implementation.
            scaling_stats_sigma: forwarded to RuntimeStatsScaling as ``sigma``.
            scaling_stats_op: statistics operator for STATS / AFFINE_STATS.
            scaling_stats_buffer_momentum: forwarded to RuntimeStatsScaling.
            scaling_stats_permute_dims: forwarded to RuntimeStatsScaling; it is
                replaced by None for per-tensor statistics other than MAX_AVE.
            per_channel_broadcastable_shape: scaling shape for per-channel
                scaling; required when ``scaling_per_channel`` is True.
            min_overall_bit_width: forwarded to BitWidthParameter.
            max_overall_bit_width: forwarded to BitWidthParameter.
            bit_width_impl_override: if not None, used directly as the bit-width
                implementation and the ``bit_width*`` arguments are ignored.
            bit_width_impl_type: CONST or PARAMETER when no override is given.
            restrict_bit_width_type: forwarded to the bit-width implementation.
            restrict_scaling_type: forwarded to the scaling implementations and
                to IntScaling.
            override_pretrained_bit_width: forwarded to BitWidthParameter.

        Raises:
            Exception: on any inconsistent or unsupported combination of the
                arguments above (see the individual messages below).
        """
        super(ActivationQuantProxy, self).__init__()

        # These checks run for every quant_type, including FP.
        if not signed and min_val != 0.0:
            raise Exception("Min val has to be 0.0 when quantization is unsigned.")
        if scaling_per_channel and per_channel_broadcastable_shape is None:
            raise Exception("Per channel scaling requires to specify number of channels.")

        if quant_type == QuantType.FP:
            # Floating point: no scaling or bit-width machinery is built.
            tensor_quant = IdentityQuant()
        else:
            # OVERRIDE and scaling_override must be supplied together or not at all.
            if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
                raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
            if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
                raise Exception("Overriding scaling requires to pass a scaling impl module.")

            if scaling_per_channel:
                scaling_shape = per_channel_broadcastable_shape
            else:
                scaling_shape = SCALING_SCALAR_SHAPE

            # `runtime` records whether the scaling comes from RuntimeStatsScaling;
            # it is only consumed by the INT quantizer below.
            if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
                scaling_impl = scaling_override
                runtime = False

            elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.PARAMETER:
                # Standalone scaling initialised from the [min_val, max_val] range.
                scaling_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val)
                scaling_impl = StandaloneScaling(is_parameter=scaling_impl_type == ScalingImplType.PARAMETER,
                                                 parameter_shape=scaling_shape,
                                                 restrict_scaling_type=restrict_scaling_type,
                                                 scaling_init=scaling_init,
                                                 scaling_min_val=scaling_min_val)
                runtime = False
            elif scaling_impl_type == ScalingImplType.STATS or scaling_impl_type == ScalingImplType.AFFINE_STATS:

                # Choose how the input is viewed and reduced before the statistics
                # are taken. MAX_AVE is rejected for per-channel scaling.
                if scaling_per_channel and not scaling_stats_op == StatsOp.MAX_AVE:
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                    scaling_stats_reduce_dim = 1
                elif scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                    raise Exception("Can't do per channel scaling with MAX AVE statistics.")
                elif not scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                    scaling_stats_reduce_dim = 1
                else:  # not scaling_per_channel
                    # Per-tensor statistics: no reduction dim and the caller's
                    # permute dims are discarded.
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                    scaling_stats_reduce_dim = None
                    scaling_stats_permute_dims = None

                # The stats buffer starts from the same min/max-derived value that
                # the standalone scaling uses, converted with .item().
                stats_buffer_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val).item()
                scaling_impl = RuntimeStatsScaling(stats_op=scaling_stats_op,
                                                   restrict_scaling_type=restrict_scaling_type,
                                                   stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                                   stats_output_shape=scaling_shape,
                                                   sigma=scaling_stats_sigma,
                                                   scaling_min_val=scaling_min_val,
                                                   stats_reduce_dim=scaling_stats_reduce_dim,
                                                   stats_buffer_momentum=scaling_stats_buffer_momentum,
                                                   stats_buffer_init=stats_buffer_init,
                                                   stats_permute_dims=scaling_stats_permute_dims,
                                                   affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
                runtime = True
            else:
                raise Exception("Scaling type {} not supported for int runtime quantization"
                                .format(str(scaling_impl_type)))

            if quant_type == QuantType.BINARY:
                if not signed:
                    raise Exception("Binary activation supports only signed activations")
                tensor_quant = ClampedBinaryQuant(scaling_impl=scaling_impl)

            elif quant_type == QuantType.INT:

                # NOTE(review): every bit-width configuration below uses the same
                # TensorClamp; how gradients behave at clipped values depends on
                # TensorClamp, which is not visible from this block.
                if bit_width_impl_override is None:
                    if bit_width_impl_type is None or bit_width is None or restrict_bit_width_type is None:
                        raise Exception("Bit width is not defined properly")

                    if bit_width_impl_type == BitWidthImplType.CONST:
                        tensor_clamp_impl = TensorClamp()
                        msb_clamp_bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                    elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                        tensor_clamp_impl = TensorClamp()
                        msb_clamp_bit_width_impl = BitWidthParameter(bit_width,
                                                                     min_overall_bit_width,
                                                                     max_overall_bit_width,
                                                                     restrict_bit_width_type,
                                                                     override_pretrained_bit_width)
                    else:
                        # This proxy quantizes activations, so the message says so.
                        raise Exception("Bit width type {} not supported for activation quantization"
                                        .format(str(bit_width_impl_type)))
                else:
                    # A caller-supplied bit-width module is used as-is.
                    msb_clamp_bit_width_impl = bit_width_impl_override
                    tensor_clamp_impl = TensorClamp()

                float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                                  float_to_int_impl_type=float_to_int_impl_type,
                                                  min_val=None)
                int_scaling_impl = IntScaling(narrow_range,
                                              signed=signed,
                                              restrict_scaling_type=restrict_scaling_type)
                tensor_quant = RescalingIntQuant(signed=signed,
                                                 narrow_range=narrow_range,
                                                 scaling_impl=scaling_impl,
                                                 int_scaling_impl=int_scaling_impl,
                                                 tensor_clamp_impl=tensor_clamp_impl,
                                                 msb_clamp_bit_width_impl=msb_clamp_bit_width_impl,
                                                 float_to_int_impl=float_to_int_impl,
                                                 runtime=runtime)
            else:
                raise Exception("Quantization type {} not supported for activations.".format(quant_type))

        self.fused_activation_quant_proxy = FusedActivationQuantProxy(activation_impl, tensor_quant)
        self.scaling_impl_type = scaling_impl_type  # needed to switch between different scaling modes
Ejemplo n.º 3
0
def clamped_binary_quant(scaling_impl_all):
    """
    Build a ClampedBinaryQuant around the given scaling implementation.

    ``scaling_impl_all`` is forwarded unchanged as the ``scaling_impl`` keyword.
    The name suggests the caller supplies every scaling variant in turn
    (presumably through test parametrization -- confirm against the
    surrounding test setup).
    """
    binary_quant = ClampedBinaryQuant(scaling_impl=scaling_impl_all)
    return binary_quant