def get_sign_bias_from_bn(prev: nn.Module, bn: nn.BatchNorm2d, scale=None):
    """Get an equivalent ``(s, b)`` for this bn such that
    ``bn(prev(x * scale)) >= 0`` iff ``s * (x + b) >= 0``, where every entry
    of ``s`` is either ``-1`` or ``1`` and every entry of ``b`` is
    integer-valued.

    The rounding used to make ``b`` an integer is only exact when ``x`` is
    integer-valued (assumed; confirm against the callers).

    :param prev: the module feeding ``bn``; if it is a
        ``PositiveInputCombination``, its ``get_bias()`` is subtracted from
        the running mean before the slope/offset are computed
    :param bn: batch norm layer providing ``running_mean``, ``running_var``
        and ``_get_scale_bias``
    :param scale: optional multiplier on the input of ``prev``; when given it
        is folded into the slope
    :return: ``(sign, bias)``: the sign of the effective slope and the
        integer-valued threshold offset
    :raises AssertionError: if any entry of the effective slope is zero
        (or NaN), since no threshold exists in that case
    """
    from .net_bin import PositiveInputCombination

    mean = bn.running_mean
    if isinstance(prev, PositiveInputCombination):
        mean = mean - prev.get_bias()
    k, b = bn._get_scale_bias(bn.running_var, mean)
    # cond: k * x + b >= 0
    if scale is not None:
        k *= scale
    sign = k.sign()
    # sign is -1, 0 or 1 (NaN for a NaN slope); the comparison rejects both
    # zero and NaN slopes before dividing by k
    assert torch.all(sign.abs() > 1e-4), 'batch norm slope must be non-zero'
    quotient = b / k
    # k > 0:  k * x + b >= 0  <=>  x >= -b/k  <=>  x + floor(b/k) >= 0
    # k < 0:  k * x + b >= 0  <=>  x <= -b/k  <=>  -(x + ceil(b/k)) >= 0
    # floor(b/k) + 1 is not a substitute for ceil(b/k): they differ when b/k
    # is exactly an integer, which would wrongly exclude the boundary point
    # x == -b/k where the bn output is exactly zero.
    bias = torch.where(sign < 0, torch.ceil(quotient), torch.floor(quotient))
    return sign, bias