Exemplo n.º 1
0
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer.

    a: value to be truncated; 2^(k-1) is added to it before it is masked
       and revealed (presumably to move a signed k-bit value into the
       non-negative range -- TODO confirm)
    k: bit length of a (NOTE(review): the previous text said "bit length
       of m", which looks like a typo -- k is always passed alongside a
       to Mod2 / PRandM; confirm against callers)
    m: compile-time integer; the power of two to divide by
    kappa: forwarded unchanged to Mod2, PRandM and BitLTC1 / BitLTL
       (presumably the statistical security parameter -- TODO confirm)

    Relies on the names Mod2, PRandM, BitLTC1, BitLTL and const_rounds
    being available from the enclosing module.
    """
    # NOTE(review): load_int_to_secret is imported but not used below.
    from types import sint, cint
    from library import reveal, load_int_to_secret
    # Special case m == 1: add the least significant bit of a (as written
    # by Mod2 into lsb) and halve.
    if m == 1:
        lsb = sint()
        Mod2(lsb, a, k, kappa, False)
        return (a + lsb) / 2
    # Containers that PRandM and BitLTC1 / BitLTL fill in place.
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for i in range(m)]
    u = sint()
    # Presumably fills r with m random bits, r_prime with the value those
    # bits encode, and r_dprime with the random high part -- PRandM is
    # defined elsewhere; TODO confirm.
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # Mask a with 2^m * r_dprime + r_prime (plus the 2^(k-1) offset) and
    # open the result.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    # Low m-1 bits of the opened value.
    c_prime = c % (cint(1) << (m - 1))
    # Compare c_prime against all random bits except the last one; the
    # result is written to u. const_rounds only selects which comparison
    # routine is used (presumably constant-round vs. logarithmic-round).
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit number m-1 of c: strip the low m-1 bits, divide by 2^(m-1),
    # reduce mod 2.
    bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2
    # Arithmetic XOR of bit and u (valid when both are 0/1).
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1]
    # (the expression below is that formula expanded, using u * u == u
    # for a 0/1 value, so only the products bit * u and prod are needed)
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    # Remove the mask from c mod 2^m, adding 2^m * u_prime as a wrap-around
    # correction; presumably a_prime == a mod 2^m -- TODO confirm.
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    # Drop the low part and divide by 2^m (the rounded-down quotient if
    # a_prime is indeed a mod 2^m).
    d = (a - a_prime) / (cint(1) << m)
    # Arithmetic XOR of xor and r[-1] (prod == xor * r[-1]); this is the
    # 0/1 correction added to the quotient.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding
Exemplo n.º 2
0
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer.

    a: value to be truncated; 2^(k-1) is added to it before it is masked
       and revealed (presumably to move a signed k-bit value into the
       non-negative range -- TODO confirm)
    k: bit length of a (NOTE(review): the previous text said "bit length
       of m", which looks like a typo -- k is always passed alongside a
       to Mod2 / Mod2mRing / PRandM; confirm against callers)
    m: compile-time integer; the power of two to divide by
    kappa: forwarded unchanged to Mod2, PRandM and BitLTC1 / BitLTL
       (presumably the statistical security parameter -- TODO confirm);
       it is not used in the ring branch

    Relies on the names program, Mod2mRing, TruncRing, Mod2, PRandM,
    BitLTC1, BitLTL and const_rounds being available from the enclosing
    module.
    """
    # NOTE(review): load_int_to_secret is imported but not used below.
    from types import sint, cint
    from library import reveal, load_int_to_secret
    # Special case m == 1: add the least significant bit of a and halve.
    if m == 1:
        if program.options.ring:
            # Ring variant: take the low bit with Mod2mRing, then truncate
            # the (k + 1)-bit sum by one bit with TruncRing. Both helpers
            # are defined elsewhere; the meaning of the None / False
            # arguments is not visible here.
            lsb = Mod2mRing(None, a, k, 1, False)
            return TruncRing(None, a + lsb, k + 1, 1, False)
        else:
            # Non-ring variant: Mod2 writes the low bit into lsb.
            lsb = sint()
            Mod2(lsb, a, k, kappa, False)
            return (a + lsb) / 2
    # Containers that PRandM and BitLTC1 / BitLTL fill in place.
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for i in range(m)]
    u = sint()
    # Presumably fills r with m random bits, r_prime with the value those
    # bits encode, and r_dprime with the random high part -- PRandM is
    # defined elsewhere; TODO confirm.
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # Mask a with 2^m * r_dprime + r_prime (plus the 2^(k-1) offset) and
    # open the result.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    # Low m-1 bits of the opened value.
    c_prime = c % (cint(1) << (m - 1))
    # Compare c_prime against all random bits except the last one; the
    # result is written to u. const_rounds only selects which comparison
    # routine is used (presumably constant-round vs. logarithmic-round).
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit number m-1 of c: strip the low m-1 bits, shift right by m-1,
    # reduce mod 2.
    bit = ((c - c_prime) >> (m - 1)) % 2
    # Arithmetic XOR of bit and u (valid when both are 0/1).
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1]
    # (the expression below is that formula expanded, using u * u == u
    # for a 0/1 value, so only the products bit * u and prod are needed)
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    # Remove the mask from c mod 2^m, adding 2^m * u_prime as a wrap-around
    # correction; presumably a_prime == a mod 2^m -- TODO confirm.
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    # Drop the low part and shift right by m (the rounded-down quotient if
    # a_prime is indeed a mod 2^m).
    d = (a - a_prime) >> m
    # Arithmetic XOR of xor and r[-1] (prod == xor * r[-1]); this is the
    # 0/1 correction added to the quotient.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding