def Inv(a):
    """ Invert a non-zero value.

    Masks ``a`` with a secret ``r`` obtained from ``inverse``, opens
    ``r * a``, inverts that public value with a clear division and
    multiplies the result by ``r`` again:  (1 / (r * a)) * r == 1 / a.

    a: secret value; must be non-zero.  If ``a`` is zero the opened
       product is zero and the clear division below divides by zero.
    Returns the secret inverse of ``a``.
    """
    # inverse() fills both registers; only the first one is needed here
    # (the algebra above uses r = t[0] on both sides).
    t = [types.sint() for i in range(2)]
    c = [types.cint() for i in range(2)]
    one = types.cint()
    ldi(one, 1)
    inverse(t[0], t[1])
    s = t[0]*a
    asm_open(c[0], s)
    # c[1] = 1 / (r * a) computed on public values
    divc(c[1], one, c[0])
    return c[1]*t[0]
def Trunc(a, l, m, kappa, compute_modulo=False):
    """ Oblivious truncation by secret m

    a: secret l-bit value to truncate.
    l: bit length of ``a`` (also the number of unary bits used for m).
    m: secret shift amount.
    kappa: statistical security parameter for the random mask.
    compute_modulo: if True, return (a mod 2^m, 2^m) instead of the
        truncated value.

    Returns a >> m (field: multiplies by Inv(2^m); ring: TruncInRing), or
    the pair (a mod 2^m, 2^m) when compute_modulo is set.
    """
    # l == 1: a is a single bit and m is 0 or 1, so a mod 2^m = a*m,
    # 2^m = 1 + m and a >> m = a*(1 - m).
    if l == 1:
        if compute_modulo:
            return a * m, 1 + m
        else:
            return a * (1 - m)
    r = [types.sint() for i in range(l)]
    r_dprime = types.sint(0)
    r_prime = types.sint(0)
    rk = types.sint()
    c = types.cint()
    ci = [types.cint() for i in range(l)]
    d = types.sint()
    # x: unary encoding of m (the checks below expect sum(x) == m),
    # pow2m: secret 2^m
    x, pow2m = B2U(m, l, kappa)
    #assert(pow2m.value == 2**m.value)
    #assert(sum(b.value for b in x) == m.value)
    # Split a random l-bit mask into the part selected by x (r_prime) and
    # the rest (r_dprime), so r_prime covers exactly the bit positions
    # where x[i] == 1.
    for i in range(l):
        bit(r[i])
        t1 = two_power(i) * r[i]
        t2 = t1*x[i]
        r_prime += t2
        r_dprime += t1 - t2
    #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
    # extra kappa random bits above position l statistically hide a
    comparison.PRandInt(rk, kappa)
    r_dprime += two_power(l) * rk
    #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
    asm_open(c, a + r_dprime + r_prime)
    # ci[i] = c mod 2^i for every candidate shift amount i
    for i in range(1,l):
        ci[i] = c % two_power(i)
        #assert(ci[i].value == c.value % 2**i)
    # x[i-1] - x[i] selects the single index where the unary encoding
    # switches, so c_dprime obliviously picks ci[m] = c mod 2^m.
    c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
    #assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
    # d: borrow bit, presumably [c_dprime < r_prime] as computed by lts
    lts(d, c_dprime, r_prime, l, kappa)
    if compute_modulo:
        # a mod 2^m = c mod 2^m - r_prime + 2^m * borrow
        b = c_dprime - r_prime + pow2m * d
        return b, pow2m
    else:
        # a - (c_dprime - r_prime) is a multiple of 2^m up to the borrow,
        # which is subtracted after the division
        to_shift = a - c_dprime + r_prime
        if program.Program.prog.options.ring:
            shifted = TruncInRing(to_shift, l, pow2m)
        else:
            pow2inv = Inv(pow2m)
            #assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
            shifted = to_shift * pow2inv
        b = shifted - d
        return b
def maskField(a, k, kappa):
    """ Open ``a`` under a fresh random mask.

    Draws a mask 2^k * r'' + r' via PRandM (r' presumably being the value
    of the k returned secret bits) and opens a + 2^k * r'' + r'.  No
    2^(k-1) offset is added.

    Returns (c, r): the opened clear value and the list of k secret mask
    bits.
    """
    high_mask = types.sint()
    low_mask = types.sint()
    opened = types.cint()
    low_bits = [types.sint() for _ in range(k)]
    comparison.PRandM(high_mask, low_mask, low_bits, k, k, kappa)
    asm_open(opened, a + two_power(k) * high_mask + low_mask)
    return opened, low_bits
def Pow2(a, l, kappa):
    """ Compute the secret power of two 2^a.

    a: secret exponent; it is bit-decomposed into m = ceil(log2(l)) bits,
       so it is assumed to fit in that many bits.
    l: bound that determines m.
    kappa: statistical security parameter for the bit decomposition.

    Returns KMul over per-bit factors, presumably their product
    2^(sum t_i * 2^i) = 2^a (assuming t[i] is bit i of a).
    """
    m = int(ceil(log(l, 2)))
    t = BitDec(a, m, m, kappa)
    for i in range(m):
        # bit t[i] selects 2^(2^i) when set and 1 when clear
        t[i] = t[i]*two_power(2**i) + 1 - t[i]
    return KMul(t)
def two_power(n):
    """ Return 2^n for a non-negative compile-time integer n.

    For an ``int`` n < 31 the result is a plain Python integer.  Otherwise
    the value is assembled as 2^(n mod 31) times (n // 31) clear-integer
    factors of 2^31 (``types.cint(1) << 31``), so the result is a cint.
    """
    if isinstance(n, int) and n < 31:
        return 2**n
    else:
        chunk = types.cint(1) << 31
        res = 2**(n % 31)
        # floor division: same as Python 2 '/' on ints, and still an
        # integer (valid range() argument) under Python 3
        for i in range(n // 31):
            res *= chunk
        return res
def AppRcr(b, k, f, kappa):
    """ Approximate reciprocal of [b]: Given [b], compute [1/b]

    b: secret fixed-point value with bit length k and f fractional bits.
    kappa: statistical security parameter, forwarded to Norm and TruncPr.

    Returns w, the initial approximation used by FPDiv (which treats
    b * w as roughly 2^(2f), i.e. w is presumably 1/b scaled by 2^f).
    """
    # 2.9142 - 2c is presumably the linear initial approximation of 1/c for
    # a normalized c, scaled here by 2^k.
    alpha = types.cint(int(2.9142 * (2**k)))
    c, v = Norm(b, k, f, kappa)
    d = alpha - 2 * c
    w = d * v
    # forward kappa so the truncation uses the caller's security parameter,
    # like every other TruncPr call in this module
    w = TruncPr(w, 2 * k, 2 * (k - f), kappa)
    return w
def SDiv(a, b, l, kappa):
    """ Secret division a / b by Goldschmidt iteration on l-bit values.

    The error term x is split into a high part x1 and a low part x2 (mod
    2^l) so that products can be truncated in smaller pieces.

    NOTE(review): SDiv is defined again later in this module with an extra
    ``round_nearest`` parameter; that later definition rebinds the name at
    import time, so this body is never reachable by name.
    """
    theta = int(ceil(log(l / 3.5) / log(2)))
    alpha = two_power(2*l)
    # clear inverse of 2^l, used to divide the exact high part of x
    beta = 1 / types.cint(two_power(l))
    # initial approximation 2.9142 * 2^l - 2b (b presumably normalized)
    w = types.cint(int(2.9142 * two_power(l))) - 2 * b
    x = alpha - b * w
    y = a * w
    y = TruncPr(y, 2 * l, l, kappa)
    # x = x1 * 2^l + x2 with x2 = x mod 2^l
    x2 = types.sint()
    comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
    x1 = (x - x2) * beta
    for i in range(theta-1):
        # y * (2^l + x / 2^l), with the x2 contribution truncated separately
        # NOTE(review): y is then truncated by l + 1 bits per iteration while
        # the multiplier grows it by about 2^l, so each extra iteration
        # appears to halve y — confirm the intended scaling.
        y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
        y = TruncPr(y, 2 * l + 1, l + 1, kappa)
        # together these two lines compute x <- x^2 / 2^(2l)
        x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa)
        x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa)
        x2 = types.sint()
        comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
        x1 = (x - x2) * beta
    y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
    y = TruncPr(y, 2 * l + 1, l - 1, kappa)
    return y
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer.

    a: secret integer of bit length k; the opening adds 2^(k-1), presumably
       to move a signed value into the non-negative range.
    k: bit length of a
    m: compile-time integer; m == 0 returns a unchanged.
    kappa: statistical security parameter for the random mask.
    """
    from types import sint, cint
    from library import reveal
    if m == 0:
        # nothing to truncate; the general path below would shift by -1
        # and index into an empty list of mask bits
        return a
    if m == 1:
        if program.options.ring:
            lsb = Mod2mRing(None, a, k, 1, False)
            return TruncRing(None, a + lsb, k + 1, 1, False)
        else:
            lsb = sint()
            Mod2(lsb, a, k, kappa, False)
            # a + lsb is even, so the field division by 2 is exact
            return (a + lsb) / 2
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for i in range(m)]
    u = sint()
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    c_prime = c % (cint(1) << (m - 1))
    # u: comparison of the low m-1 opened bits against the low m-1 mask bits
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # bit m-1 of the opened value (renamed from 'bit' to avoid shadowing
    # the module-level bit() helper)
    c_bit = ((c - c_prime) >> (m - 1)) % 2
    xor = c_bit + u - 2 * c_bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1]
    u_prime = c_bit * u + u - 2 * c_bit * u + r[-1] - prod
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    d = (a - a_prime) >> m
    # xor XOR r[-1]: the bit just below the cut, which decides rounding up
    rounding = xor + r[-1] - 2 * prod
    return d + rounding
def B2U_from_Pow2(pow2a, l, kappa):
    """ Convert a secret power of two 2^a into l unary bits.

    pow2a: secret value expected to equal 2^a.
    Opens 2^a masked by l random bits plus kappa random high bits, XORs
    the low l opened bits with the mask bits and runs PreOR over them.
    Returns [1 - y_i]; assuming PreOR is a prefix OR from index 0, bit i is
    1 exactly when i < a (Trunc's commented check expects the bits to sum
    to a).
    """
    mask_bits = [types.sint() for _ in range(l)]
    mask_high = types.sint()
    opened = types.cint()
    for mask_bit in mask_bits:
        bit(mask_bit)
    comparison.PRandInt(mask_high, kappa)
    asm_open(opened, pow2a + two_power(l) * mask_high + sum(two_power(i)*rb for i, rb in enumerate(mask_bits)))
    comparison.program.curr_tape.require_bit_length(l + kappa)
    opened_bits = list(bits(opened, l))
    # bitwise XOR of opened bits and mask bits
    diff = [cb + rb - 2*cb*rb for cb, rb in zip(opened_bits, mask_bits)]
    prefix = PreOR(diff, kappa)
    return [1 - prefix[j] for j in range(l)]
def SDiv_mono(a, b, l, kappa):
    """ Secret division a / b by Goldschmidt iteration ("mono" variant).

    Unlike SDiv, the error term is kept whole: each round multiplies the
    quotient by (2^(2l) + err) and squares err directly, truncating both
    back to their working sizes.
    """
    rounds = int(ceil(log(l / 3.5) / log(2)))
    two_2l = two_power(2*l)
    approx = types.cint(int(2.9142 * two_power(l))) - 2 * b
    err = two_2l - b * approx
    quot = TruncPr(a * approx, 2 * l + 1, l + 1, kappa)
    for _ in range(rounds - 1):
        # keep the quotient at l bits
        quot = TruncPr(quot * (two_2l + err), 3 * l, 2 * l, kappa)
        # keep the error term at 2l bits
        err = TruncPr(err**2, 4 * l, 2 * l, kappa)
    quot = quot * (two_2l + err)
    return TruncPr(quot, 3 * l, 2 * l, kappa)
def B2U(a, l, kappa):
    """ Convert a secret integer a into l unary bits plus 2^a.

    Computes pow2a = Pow2(a, l, kappa) and hands it to B2U_from_Pow2,
    whose body this function previously duplicated verbatim.

    Returns (x, pow2a): x is the list of l secret unary bits (Trunc's
    commented check expects sum(x) == a), pow2a the secret 2^a.
    """
    pow2a = Pow2(a, l, kappa)
    #assert(pow2a.value == 2**a.value)
    return B2U_from_Pow2(pow2a, l, kappa), pow2a
def FPDiv(a, b, k, f, kappa):
    """ Fixed-point division a / b by Goldschmidt iteration.

    NOTE(review): FPDiv is defined again directly after this definition;
    the later definition rebinds the name at import time, so this body is
    unreachable by name — consider deleting one of the two.
    """
    # NOTE(review): natural log here, whereas SDiv and SDiv_mono compute
    # the iteration count as log(...) / log(2) (base 2) — confirm the base.
    theta = int(ceil(log(k / 3.5)))
    alpha = types.cint(1 * two_power(2 * f))
    w = AppRcr(b, k, f, kappa)
    x = alpha - b * w
    y = a * w
    y = TruncPr(y, 2 * k, f, kappa)
    for i in range(theta):
        y = y * (alpha + x)
        x = x * x
        y = TruncPr(y, 2 * k, 2 * f, kappa)
        x = TruncPr(x, 2 * k, 2 * f, kappa)
    y = y * (alpha + x)
    y = TruncPr(y, 2 * k, 2 * f, kappa)
    return y
def FPDiv(a, b, k, f, kappa):
    """ Fixed-point division a / b by Goldschmidt iteration.

    a, b: secret fixed-point values with bit length k and f fractional bits.
    kappa: statistical security parameter passed to every truncation.

    Returns the secret quotient, truncated back to f fractional bits
    after each multiplication.
    """
    # Iteration count ceil(log2(k / 3.5)), computed the same way as in SDiv
    # and SDiv_mono; the natural log previously used under-counts rounds.
    theta = int(ceil(log(k / 3.5) / log(2)))
    alpha = types.cint(1 * two_power(2 * f))
    w = AppRcr(b, k, f, kappa)
    # error term: alpha - b*w shrinks towards 0 as w approaches 1/b
    x = alpha - b * w
    y = a * w
    y = TruncPr(y, 2 * k, f, kappa)
    for i in range(theta):
        y = y * (alpha + x)
        x = x * x
        y = TruncPr(y, 2 * k, 2 * f, kappa)
        x = TruncPr(x, 2 * k, 2 * f, kappa)
    y = y * (alpha + x)
    y = TruncPr(y, 2 * k, 2 * f, kappa)
    return y
def BitDec(a, k, m, kappa, bits_to_compute=None):
    """ Decompose the low m bits of a secret k-bit integer a.

    a: secret integer of bit length k.
    m: number of low-order bits to extract.
    kappa: statistical security parameter for the random mask.
    bits_to_compute: forwarded unchanged to BitAdd.

    Returns the BitAdd result over the opened low m bits and the mask
    bits, with its last element (presumably the final carry) dropped.
    """
    r_dprime = types.sint()
    r_prime = types.sint()
    c = types.cint()
    r = [types.sint() for i in range(m)]
    comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
    #assert(r_prime.value == sum(r[i].value*2**i for i in range(m)) % comparison.program.P)
    pow2 = two_power(k + kappa)
    # Open 2^(k+kappa) + 2^k + a - (2^m * r'' + r').  Modulo 2^m this is
    # a - r', so adding r' back bitwise yields the low m bits of a.
    asm_open(c, pow2 + two_power(k) + a - two_power(m) * r_dprime - r_prime)
    # (removed: a try/except AssertionError around a bare 'pass', whose
    # handler could never run)
    return BitAdd(list(bits(c, m)), r, bits_to_compute)[:-1]
def BitDecField(a, k, m, kappa, bits_to_compute=None):
    """ Field variant of BitDec using ``types.intbitint.bit_adder``.

    a: secret integer of bit length k.
    m: number of low-order bits to extract.
    kappa: statistical security parameter for the random mask.
    bits_to_compute: accepted for signature parity with BitDec; not used.

    Returns the full bit_adder result over the opened low m bits and the
    mask bits (unlike BitDec, no trailing element is dropped).
    """
    r_dprime = types.sint()
    r_prime = types.sint()
    c = types.cint()
    r = [types.sint() for i in range(m)]
    comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
    #assert(r_prime.value == sum(r[i].value*2**i for i in range(m)) % comparison.program.P)
    pow2 = two_power(k + kappa)
    # Open 2^(k+kappa) + 2^k + a - (2^m * r'' + r'); modulo 2^m this is a - r'.
    asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
    # (removed: a try/except AssertionError around a bare 'pass', whose
    # handler could never run)
    return types.intbitint.bit_adder(list(bits(c,m)), r)
def SDiv(a, b, l, kappa, round_nearest=False):
    """ Secret division a / b by Goldschmidt iteration on l-bit values.

    The error term x is split as x = x1 * 2^l + x2 (x2 = x mod 2^l) so the
    products can be truncated in smaller pieces.

    round_nearest: forwarded to every ``.round`` call (the truncation mode).
    """
    theta = int(ceil(log(l / 3.5) / log(2)))
    alpha = two_power(2 * l)
    # initial approximation 2.9142 * 2^l - 2b (b presumably normalized)
    w = types.cint(int(2.9142 * two_power(l))) - 2 * b
    x = alpha - b * w
    y = a * w
    y = y.round(2 * l + 1, l, kappa, round_nearest)
    x2 = types.sint()
    comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
    # exact division of the high part by 2^l (its low l bits are zero)
    x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
    for i in range(theta - 1):
        # y * (2^l + x / 2^l), with the x2 contribution truncated separately
        # NOTE(review): y is then truncated by l + 1 bits per iteration while
        # the multiplier grows it by about 2^l, so each extra iteration
        # appears to halve y — confirm the intended scaling.
        y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
        y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
        # together these two lines compute x <- x^2 / 2^(2l)
        x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
        x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
        x2 = types.sint()
        comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
        x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
    y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
    y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
    return y