Example #1
0
def get_sign_bias_from_bn(prev: nn.Module, bn: nn.BatchNorm2d, scale=None):
    """Reduce a batch-norm sign threshold to an integer ``(s, b)`` pair.

    Returns ``(s, b)`` such that, for integer-valued ``x``,
    ``bn(prev(x * scale)) >= 0`` iff ``s * (x + b) >= 0`` (elementwise),
    where every element of ``s`` is ``-1`` or ``1`` and every element of
    ``b`` is an integer (held in a floating-point tensor).

    :param prev: the layer feeding ``bn``. If it is a
        ``PositiveInputCombination``, the value of its ``get_bias()`` is
        subtracted from ``bn.running_mean`` before folding.
    :param bn: the batch-norm layer. It must provide ``running_mean``,
        ``running_var`` and ``_get_scale_bias(var, mean) -> (k, b)``.
        NOTE(review): this function assumes the thresholded quantity is
        ``k * x + b``; confirm against ``_get_scale_bias``.
    :param scale: optional multiplier applied to the input ``x``. It is
        folded into ``k``. ``None`` means no scaling.
    :return: ``(sign, bias)`` tensors.
    :raises AssertionError: if any element of the folded ``k`` does not
        have a nonzero sign (the threshold would not depend on ``x``).
    """
    from .net_bin import PositiveInputCombination

    mean = bn.running_mean
    if isinstance(prev, PositiveInputCombination):
        # Account for the bias reported by ``prev`` by shifting the mean.
        mean = mean - prev.get_bias()

    k, b = bn._get_scale_bias(bn.running_var, mean)
    # Condition being reduced: k * x + b >= 0
    if scale is not None:
        # Out-of-place: never modify the tensor returned by
        # ``_get_scale_bias``, which may alias module state.
        k = k * scale

    sign = k.sign()
    assert torch.all(sign.abs() > 1e-4)

    ratio = b / k
    # For integer x:
    #   k > 0:  k*x + b >= 0  <=>  x >= -b/k  <=>    x + floor(b/k)  >= 0
    #   k < 0:  k*x + b >= 0  <=>  x <= -b/k  <=>  -(x + ceil(b/k))  >= 0
    # Note: ``floor(b/k) + 1`` is NOT a substitute for ``ceil(b/k)``.
    # The two differ when b/k is an exact integer (e.g. b == 0), and there
    # the tie x == -b/k would be misclassified.
    bias = torch.where(sign > 0, torch.floor(ratio), torch.ceil(ratio))
    return sign, bias