def __init__(self,
             num_features,
             eps=1e-5,
             momentum=0.1,
             restrict_value_type: RestrictValueType = RestrictValueType.FP,
             impl_type: ScalingImplType = ScalingImplType.STATS,
             bias_quant_type: QuantType = QuantType.FP,
             bias_narrow_range: bool = False,
             bias_bit_width: int = None):
    QuantLayer.__init__(self,
                        compute_output_scale=False,
                        compute_output_bit_width=False,
                        return_quant_tensor=False)
    nn.Module.__init__(self)
    if bias_quant_type != QuantType.FP and not (self.compute_output_scale and self.compute_output_bit_width):
        raise Exception("Quantizing the bias requires computing the output scale and output bit width")
    # Per-channel affine parameters, analogous to BatchNorm2d's weight and bias
    self.weight = nn.Parameter(torch.ones(num_features))
    self.bias = nn.Parameter(torch.zeros(num_features))
    if impl_type == ScalingImplType.PARAMETER_FROM_STATS:
        # Running stats are folded into weight/bias at state-dict load time,
        # so no buffers are registered in this mode
        self.running_mean = None
        self.running_var = None
    elif impl_type == ScalingImplType.STATS:
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    else:
        raise Exception("Scaling mode not supported")
    self.eps = eps
    self.momentum = momentum
    self.impl_type = impl_type
    self.num_features = num_features
    self.restrict_value_type = restrict_value_type
    self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                     narrow_range=bias_narrow_range,
                                     bit_width=bias_bit_width)
    self.restrict_weight = RestrictValue(restrict_value_type=restrict_value_type,
                                         float_to_int_impl_type=FloatToIntImplType.ROUND,
                                         min_val=None)
    self.restrict_scaling_preprocess = RestrictValue.restrict_value_op(
        restrict_value_type,
        restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_MODULE)
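# Construction note: with ScalingImplType.STATS this layer tracks running
# statistics like a standard BatchNorm2d, while PARAMETER_FROM_STATS instead
# folds the statistics of a pretrained BN into self.weight/self.bias when a
# state dict is loaded (see _load_from_state_dict below). Hypothetical usage
# sketch, assuming the surrounding class is QuantBatchNorm2d as in the super()
# call below:
#
#     bn = nn.BatchNorm2d(64)
#     qbn = QuantBatchNorm2d(64, impl_type=ScalingImplType.PARAMETER_FROM_STATS)
#     qbn.load_state_dict(bn.state_dict())  # triggers the folding branch below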
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    running_mean_key = prefix + 'running_mean'
    running_var_key = prefix + 'running_var'
    num_batches_tracked_key = prefix + 'num_batches_tracked'
    # If we're converting a floating-point BN into a weight/bias implementation,
    # fold the running stats into the affine parameters and drop the BN-specific keys
    if self.impl_type == ScalingImplType.PARAMETER_FROM_STATS \
            and running_mean_key in state_dict and running_var_key in state_dict:
        weight_init, bias_init = mul_add_from_bn(bn_bias=state_dict[bias_key],
                                                 bn_weight=state_dict[weight_key],
                                                 bn_mean=state_dict[running_mean_key],
                                                 bn_var=state_dict[running_var_key],
                                                 bn_eps=self.eps,
                                                 affine_only=False)
        restrict_op = RestrictValue.restrict_value_op(restrict_value_type=self.restrict_value_type,
                                                      restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_FN)
        # Keep the sign separately so the restricted weight can stay positive
        self.weight_sign = torch.sign(weight_init.data)
        weight_init = weight_init.detach().clone().abs().data
        self.weight.data = restrict_op(weight_init)
        self.bias.data = bias_init.detach().clone().data
        del state_dict[bias_key]
        del state_dict[weight_key]
        del state_dict[running_mean_key]
        del state_dict[running_var_key]
        del state_dict[num_batches_tracked_key]
    super(QuantBatchNorm2d, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
    if config.IGNORE_MISSING_KEYS and bias_key in missing_keys:
        missing_keys.remove(bias_key)
    if config.IGNORE_MISSING_KEYS and weight_key in missing_keys:
        missing_keys.remove(weight_key)
    if num_batches_tracked_key in unexpected_keys:
        unexpected_keys.remove(num_batches_tracked_key)
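# A minimal sketch of the batch-norm folding that mul_add_from_bn is expected
# to perform, for reference only. Assuming the standard formulation
#   y = bn_weight * (x - bn_mean) / sqrt(bn_var + bn_eps) + bn_bias,
# the equivalent per-channel affine transform y = mul * x + add uses
#   mul = bn_weight / sqrt(bn_var + bn_eps)
#   add = bn_bias - bn_mean * mul
# The name and body below are illustrative; the real mul_add_from_bn imported
# by this module may differ (e.g. in how affine_only is handled).
def _mul_add_from_bn_sketch(bn_bias, bn_weight, bn_mean, bn_var, bn_eps):
    # Per-channel scale applied to the input
    mul_factor = bn_weight / torch.sqrt(bn_var + bn_eps)
    # Per-channel shift absorbing the running mean and the BN bias
    add_factor = bn_bias - bn_mean * mul_factor
    return mul_factor, add_factor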