def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer (ties round up).

    a: secret integer to divide
    k: bit length of a
    m: compile-time integer, the power of two to divide by (m >= 0)
    kappa: security parameter, passed on to Mod2, PRandM and
        BitLTC1/BitLTL
    """
    # Dividing by 2^0 is the identity. Without this guard, m == 0 would
    # shift by m - 1 == -1 and index r[-1] on an empty bit list.
    if m == 0:
        return a
    from types import sint, cint
    from library import reveal
    if m == 1:
        # Assuming Mod2 yields lsb = a mod 2: a + lsb is even, so halving
        # it is exact, and odd a is rounded up.
        lsb = sint()
        Mod2(lsb, a, k, kappa, False)
        return (a + lsb) / 2
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for _ in range(m)]
    u = sint()
    # Random mask. Presumably r_prime is the m-bit value whose bits are r
    # and r_dprime is the high part (confirm against PRandM).
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # The 2^(k-1) offset presumably maps a signed k-bit a to a
    # non-negative value before it is masked and opened.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    c_prime = c % (cint(1) << (m - 1))
    # u: comparison of the low m-1 bits of c against the low m-1 mask
    # bits r[:-1]. The protocol variant is chosen by the module-level
    # const_rounds flag.
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit m-1 of c. Subtracting c_prime first makes the division exact.
    bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1], expanded using u * u == u:
    # xor * u becomes bit * u + u - 2 * bit * u, so prod is the only
    # secret-by-secret product needed.
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    # Presumably a_prime == a mod 2^m, so a - a_prime is a multiple of
    # 2^m and this division is exact (floor of a / 2^m).
    d = (a - a_prime) / (cint(1) << m)
    # xor XOR r[-1]: bit m-1 of the dropped remainder, i.e. round up or not.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer (ties round up).

    a: secret integer to divide
    k: bit length of a
    m: compile-time integer, the power of two to divide by (m >= 0)
    kappa: security parameter, passed on to Mod2, PRandM and
        BitLTC1/BitLTL
    """
    # Dividing by 2^0 is the identity. Without this guard, m == 0 would
    # shift by m - 1 == -1 and index r[-1] on an empty bit list.
    if m == 0:
        return a
    from types import sint, cint
    from library import reveal
    if m == 1:
        # a + lsb(a) is even, so halving it is exact and odd a rounds up.
        if program.options.ring:
            # Ring variant. k + 1 presumably leaves room for the carry
            # from adding lsb (confirm against TruncRing).
            lsb = Mod2mRing(None, a, k, 1, False)
            return TruncRing(None, a + lsb, k + 1, 1, False)
        else:
            lsb = sint()
            Mod2(lsb, a, k, kappa, False)
            return (a + lsb) / 2
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for _ in range(m)]
    u = sint()
    # Random mask. Presumably r_prime is the m-bit value whose bits are r
    # and r_dprime is the high part (confirm against PRandM).
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # The 2^(k-1) offset presumably maps a signed k-bit a to a
    # non-negative value before it is masked and opened.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    c_prime = c % (cint(1) << (m - 1))
    # u: comparison of the low m-1 bits of c against the low m-1 mask
    # bits r[:-1]. The protocol variant is chosen by the module-level
    # const_rounds flag.
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit m-1 of c. c - c_prime has its low m-1 bits cleared, so the
    # shift is exact.
    bit = ((c - c_prime) >> (m - 1)) % 2
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1], expanded using u * u == u:
    # xor * u becomes bit * u + u - 2 * bit * u, so prod is the only
    # secret-by-secret product needed.
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    # Presumably a_prime == a mod 2^m, so a - a_prime is a multiple of 2^m.
    # NOTE(review): on a secret value, >> may run a truncation sub-protocol
    # rather than a local operation. Dividing by the public 2^m would be
    # local in a field, but is impossible modulo 2^k. Confirm the cost is
    # intended.
    d = (a - a_prime) >> m
    # xor XOR r[-1]: bit m-1 of the dropped remainder, i.e. round up or not.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding