예제 #1
0
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 restrict_value_type: RestrictValueType = RestrictValueType.FP,
                 impl_type: ScalingImplType = ScalingImplType.STATS,
                 bias_quant_type: QuantType = QuantType.FP,
                 bias_narrow_range: bool = False,
                 bias_bit_width: int = None):
        """Build the layer's parameters, statistics storage and quantization helpers.

        Args:
            num_features: Length of the ``weight``/``bias`` parameters and of the
                running-statistics buffers.
            eps: Stored on the instance as ``self.eps``.
            momentum: Stored on the instance as ``self.momentum``.
            restrict_value_type: Forwarded to the ``RestrictValue`` helpers built here.
            impl_type: ``ScalingImplType.STATS`` registers ``running_mean`` /
                ``running_var`` buffers; ``ScalingImplType.PARAMETER_FROM_STATS``
                sets both attributes to ``None`` instead.
            bias_quant_type: Quantization type handed to the bias quantization proxy.
            bias_narrow_range: Narrow-range flag handed to the bias quantization proxy.
            bias_bit_width: Bit width handed to the bias quantization proxy.

        Raises:
            Exception: If bias quantization is requested while output scale / bit
                width are not both computed, or if ``impl_type`` is not one of the
                two supported modes.
        """
        # Both base initializers are invoked explicitly, QuantLayer first.
        QuantLayer.__init__(self,
                            compute_output_scale=False,
                            compute_output_bit_width=False,
                            return_quant_tensor=False)
        nn.Module.__init__(self)

        # The output flags are only consulted when the bias is actually quantized.
        if bias_quant_type != QuantType.FP:
            if not self.compute_output_scale or not self.compute_output_bit_width:
                raise Exception("Quantizing bias requires to compute output scale and output bit width")

        # Learnable affine terms: unit scale and zero shift.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        if impl_type == ScalingImplType.STATS:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        elif impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            # No statistics are kept on the module in this mode.
            self.running_mean = None
            self.running_var = None
        else:
            raise Exception("Scaling mode not supported")

        # Plain configuration attributes.
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.impl_type = impl_type
        self.restrict_value_type = restrict_value_type

        # Sub-modules, registered in the order: bias proxy, weight restriction,
        # scaling pre-processing op.
        bias_proxy = BiasQuantProxy(quant_type=bias_quant_type,
                                    narrow_range=bias_narrow_range,
                                    bit_width=bias_bit_width)
        self.bias_quant = bias_proxy

        weight_restriction = RestrictValue(restrict_value_type=restrict_value_type,
                                           float_to_int_impl_type=FloatToIntImplType.ROUND,
                                           min_val=None)
        self.restrict_weight = weight_restriction

        scaling_preprocess = RestrictValue.restrict_value_op(
            restrict_value_type,
            restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_MODULE)
        self.restrict_scaling_preprocess = scaling_preprocess
예제 #2
0
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load this layer's entries from ``state_dict``, converting batch-norm statistics when needed.

        When the layer is in ``ScalingImplType.PARAMETER_FROM_STATS`` mode and the
        incoming ``state_dict`` carries ``running_mean`` and ``running_var`` entries,
        the batch-norm entries are folded into ``self.weight`` / ``self.bias`` via
        ``mul_add_from_bn`` and then removed from ``state_dict`` before the parent
        implementation runs. The sign of the folded weight is kept in
        ``self.weight_sign``; the stored weight is the restricted absolute value.

        Args:
            state_dict: Mapping of parameter/buffer names to tensors; consumed
                entries are deleted from it in place.
            prefix: Key prefix identifying this module's entries in ``state_dict``.
            local_metadata: Forwarded unchanged to the parent implementation.
            strict: Forwarded unchanged to the parent implementation.
            missing_keys: List updated in place; weight/bias keys are removed from
                it when ``config.IGNORE_MISSING_KEYS`` is set.
            unexpected_keys: List updated in place; the ``num_batches_tracked`` key
                is always removed from it.
            error_msgs: Forwarded unchanged to the parent implementation.

        Raises:
            KeyError: If the conversion branch is taken and ``state_dict`` has no
                weight or bias entry under ``prefix``.
        """
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        running_mean_key = prefix + 'running_mean'
        running_var_key = prefix + 'running_var'
        num_batches_tracked_key = prefix + 'num_batches_tracked'

        # If it's converting a FP BN into weight/bias impl
        if self.impl_type == ScalingImplType.PARAMETER_FROM_STATS \
                and running_mean_key in state_dict and running_var_key in state_dict:
            weight_init, bias_init = mul_add_from_bn(bn_bias=state_dict[bias_key],
                                                     bn_weight=state_dict[weight_key],
                                                     bn_mean=state_dict[running_mean_key],
                                                     bn_var=state_dict[running_var_key],
                                                     bn_eps=self.eps,
                                                     affine_only=False)
            restrict_op = RestrictValue.restrict_value_op(restrict_value_type=self.restrict_value_type,
                                                          restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_FN)
            # Keep the sign separately: only the magnitude goes through the restriction op.
            self.weight_sign = torch.sign(weight_init.data)
            weight_init = weight_init.detach().clone().abs().data
            self.weight.data = restrict_op(weight_init)
            self.bias.data = bias_init.detach().clone().data
            # These four entries were read above, so they are known to be present.
            for consumed_key in (bias_key, weight_key, running_mean_key, running_var_key):
                del state_dict[consumed_key]
            # The branch condition never checks for this entry, so it may be absent;
            # pop() avoids a KeyError after weight/bias have already been overwritten.
            state_dict.pop(num_batches_tracked_key, None)
        super(QuantBatchNorm2d, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
        # NOTE(review): after a conversion the weight/bias entries are gone from
        # state_dict, so the parent loader presumably reports them as missing; they
        # are only forgiven when config.IGNORE_MISSING_KEYS is set - confirm intended.
        if config.IGNORE_MISSING_KEYS and bias_key in missing_keys:
            missing_keys.remove(bias_key)
        if config.IGNORE_MISSING_KEYS and weight_key in missing_keys:
            missing_keys.remove(weight_key)
        if num_batches_tracked_key in unexpected_keys:
            unexpected_keys.remove(num_batches_tracked_key)