def cint_cint_division(a, b, k, f):
    """
    Divide ``a`` by ``b`` with the Goldschmidt method, seeded with an
    SE approximation of the reciprocal:
    http://stackoverflow.com/questions/2661541/picking-good-first-estimates-for-goldschmidt-division

    Both operands are reduced to their absolute values, the iteration
    runs on those, and the product of the two signs is multiplied back
    into the result.  ``k`` fixes the number of rounds and ``f`` is the
    shift applied to every intermediate product.
    """
    # The round count could be smaller; for safety it is the same theta
    # as in the previous GS method.
    n_rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)

    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = cint(1) - 2 * cint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign

    seed = approximate_reciprocal(b_abs, k, f, n_rounds)

    numer = Array(n_rounds, cint)
    denom = Array(n_rounds, cint)
    factor = Array(n_rounds, cint)
    numer[0], denom[0], factor[0] = a_abs, b_abs, seed

    for step in range(1, n_rounds):
        prev = step - 1
        numer[step] = shift_two(numer[prev] * factor[prev], f)
        denom[step] = shift_two(denom[prev] * factor[prev], f)
        factor[step] = fixed_two - denom[step]

    return (a_sign * b_sign) * numer[n_rounds - 1]
def sint_cint_division(a, b, k, f, kappa):
    """
    Goldschmidt-style division where type(a) = sint and type(b) = cint.

    The numerator products are truncated with ``TruncPr`` (using
    ``kappa``); the denominator products are shifted right by ``f``.
    The result carries the product of both operand signs.
    """
    n_rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)

    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = sint(1) - 2 * sint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign

    seed = approximate_reciprocal(b_abs, k, f, n_rounds)

    numer = Array(n_rounds, sint)
    denom = Array(n_rounds, cint)
    factor = Array(n_rounds, cint)
    numer[0], denom[0], factor[0] = a_abs, b_abs, seed

    @for_range(1, n_rounds)
    def block(i):
        numer[i] = TruncPr(numer[i - 1] * factor[i - 1], 2*k, f, kappa)
        next_denom = (denom[i - 1] * factor[i - 1]) >> f
        # no reading and writing to the same variable in a for loop.
        factor[i] = fixed_two - next_denom
        denom[i] = next_denom

    return (a_sign * b_sign) * numer[n_rounds - 1]
def Norm(b, k, f, kappa, simplex_flag=False):
    """
    Computes secret integer values [c] and [v_prime] st.
    2^{k-1} <= c < 2^k and c = b*v_prime

    :param b: value to normalise
    :param k: bit length used for the decomposition
    :param f: unused here; kept for interface compatibility
    :param kappa: forwarded to ``bit_decompose``
    :param simplex_flag: when truthy, the sign test on ``b`` is skipped
        and ``b`` is treated as non-negative
    :return: ``(abs(b) * acc, sign * acc)`` where ``acc`` is the
        power-of-two scaling factor built below
    """
    # For simplex, we can get rid of computing abs(b).
    # A plain truthiness test is used so that every flag value selects
    # one of the two branches (the former ``== False`` / ``== True``
    # chain left the sign bit as None for anything else).
    if simplex_flag:
        is_negative = cint(0)
    else:
        is_negative = b.less_than(0, 2 * k)

    sign = 1 - 2 * is_negative  # 1 - 2 * [b < 0]
    absolute_val = sign * b

    # next 2 lines actually compute the SufOR for little endian encoding
    bits = absolute_val.bit_decompose(k, kappa)[::-1]
    suffixes = PreOR(bits)[::-1]

    # z[i] is the difference of neighbouring suffix-OR values; the last
    # entry is the final suffix value itself.
    z = [suffixes[i] - suffixes[i + 1] for i in range(k - 1)]
    z.append(suffixes[k - 1])

    # doing complicated stuff to compute v = 2^{k-m}
    acc = cint(0)
    for i in range(k):
        acc += two_power(k - i - 1) * z[i]

    part_reciprocal = absolute_val * acc
    signed_acc = sign * acc

    return part_reciprocal, signed_acc
def cint_cint_division(a, b, k, f):
    """
    Divide ``a`` by ``b`` with the Goldschmidt method, seeded with an
    SE approximation of the reciprocal:
    http://stackoverflow.com/questions/2661541/picking-good-first-estimates-for-goldschmidt-division

    The iteration runs on the absolute values inside a ``for_range``
    loop; every product is shifted right by ``f``.  The product of the
    operand signs is multiplied back into the result.
    """
    # The round count could be smaller; for safety it is the same theta
    # as in the previous GS method.
    n_rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)

    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = cint(1) - 2 * cint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign

    seed = approximate_reciprocal(b_abs, k, f, n_rounds)

    numer = Array(n_rounds, cint)
    denom = Array(n_rounds, cint)
    factor = Array(n_rounds, cint)
    numer[0], denom[0], factor[0] = a_abs, b_abs, seed

    @for_range(1, n_rounds)
    def block(i):
        numer[i] = (numer[i - 1] * factor[i - 1]) >> f
        denom[i] = (denom[i - 1] * factor[i - 1]) >> f
        factor[i] = fixed_two - denom[i]

    return (a_sign * b_sign) * numer[n_rounds - 1]
def Norm(b, k, f, kappa, simplex_flag=False):
    """
    Computes secret integer values [c] and [v_prime] st.
    2^{k-1} <= c < 2^k and c = b*v_prime

    :param b: value to normalise
    :param k: bit length used for the decomposition
    :param f: unused here; kept for interface compatibility
    :param kappa: forwarded to ``bit_decompose``
    :param simplex_flag: when truthy, the sign test on ``b`` is skipped
        and ``b`` is treated as non-negative
    :return: ``(abs(b) * acc, sign * acc)`` where ``acc`` is the
        power-of-two scaling factor built below
    """
    # For simplex, we can get rid of computing abs(b).
    # Truthiness is tested directly so that no flag value can fall
    # through both branches and leave the sign bit undefined.
    if simplex_flag:
        negative_bit = cint(0)
    else:
        negative_bit = b.less_than(0, 2 * k)

    sign = 1 - 2 * negative_bit  # 1 - 2 * [b < 0]
    absolute_val = sign * b

    # next 2 lines actually compute the SufOR for little endian encoding
    bits = absolute_val.bit_decompose(k, kappa)[::-1]
    suffixes = PreOR(bits)[::-1]

    # Differences of neighbouring suffix-OR values, followed by the
    # last suffix value itself.
    z = [suffixes[i] - suffixes[i + 1] for i in range(k - 1)]
    z.append(suffixes[k - 1])

    # doing complicated stuff to compute v = 2^{k-m}
    acc = cint(0)
    for i in range(k):
        acc += two_power(k - i - 1) * z[i]

    part_reciprocal = absolute_val * acc
    signed_acc = sign * acc

    return part_reciprocal, signed_acc
def sint_cint_division(a, b, k, f, kappa):
    """
    Goldschmidt-style division where type(a) = sint and type(b) = cint.

    The numerator products are truncated with ``TruncPr`` (using
    ``kappa``); the denominator products go through ``shift_two`` with
    shift ``f``.  The result carries the product of both operand signs.
    """
    n_rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)

    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = sint(1) - 2 * sint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign

    seed = approximate_reciprocal(b_abs, k, f, n_rounds)

    numer = Array(n_rounds, sint)
    denom = Array(n_rounds, cint)
    factor = Array(n_rounds, cint)
    numer[0], denom[0], factor[0] = a_abs, b_abs, seed

    @for_range(1, n_rounds)
    def block(i):
        numer[i] = TruncPr(numer[i - 1] * factor[i - 1], 2*k, f, kappa)
        next_denom = shift_two(denom[i - 1] * factor[i - 1], f)
        # no reading and writing to the same variable in a for loop.
        factor[i] = fixed_two - next_denom
        denom[i] = next_denom

    return (a_sign * b_sign) * numer[n_rounds - 1]
def f(i):
    # Loop body for row ``i``: log the row index under the 'mpre' tag,
    # then fill that row of ``self.m_prefs``.
    time()
    types.cint(i).print_reg('mpre')

    @for_range(self.M)
    def f(col):
        # Entry (i, col) holds (i + col) modulo N.
        pref = (i + col) % self.N
        self.m_prefs[i][col] = pref
def f(i):
    # Loop body for row ``i``: log the row index under the 'fran' tag,
    # then fill that row of ``self.f_ranks``.
    time()
    types.cint(i).print_reg('fran')

    @for_range(self.M)
    def f(rank):
        # ``rank`` is stored at position (rank - i + N) modulo N.
        slot = (rank - i + self.N) % self.N
        self.f_ranks[i][slot] = rank
def f(i):
    # Loop body for row ``i``: log the row index under the 'fran' tag,
    # then fill that row of ``self.f_ranks`` over ``inner_loops``
    # iterations.
    time()
    types.cint(i).print_reg('fran')

    @for_range(inner_loops)
    def f(j):
        if not self.reverse:
            # Store ``j`` at position (-i - j + 2N - 2) modulo N.
            slot = (-i - j + 2 * self.N - 2) % self.N
            self.f_ranks[i][slot] = j
        else:
            # Replace the whole row with a tuple over all M columns; the
            # generator variable is independent of the loop index ``j``.
            self.f_ranks[i] = tuple(
                (-i - col + 2 * self.N - 2) % self.N
                for col in range(self.M))
def f(i):
    # Loop body for row ``i``: log the row index under the 'mpre' tag,
    # fill the first ``inner_loops`` entries of the row, then set one
    # trailing entry whose position depends on whether M < N.
    time()
    types.cint(i).print_reg('mpre')

    @for_range(inner_loops)
    def f(col):
        pref = (-i + col + self.N - 1) % (self.N - 1)
        self.m_prefs[i][col] = pref

    if self.M >= self.N:
        self.m_prefs[i][self.N-1] = self.N - 1
    else:
        self.m_prefs[i][self.M-1] = (2 * self.N - 2 - i) % self.N
def reveal(secret):
    """
    Open ``secret`` and return the opened value.

    Objects providing a ``reveal()`` method are opened through it.  If
    an AttributeError surfaces from that call, a fresh ``cint`` is
    created and filled via ``instructions.asm_open`` instead.
    """
    try:
        opened = secret.reveal()
    except AttributeError:
        opened = cint()
        instructions.asm_open(opened, secret)
    return opened
def int2FL_plain(a, gamma, l, kappa):
    """
    Convert the integer ``a`` into the four components ``(v, p, z, s)``.

    :param a: integer to convert (compared and bit-decomposed below)
    :param gamma: bit length of the input; ``lam = gamma - 1``
    :param l: target bit length of ``v``
    :param kappa: unused in this function
    :return: tuple ``(v, p, z, s)`` where ``s`` is ``[a < 0]``, ``z`` is
        set to 1 when ``abs(a) > 0`` does not hold, ``v`` is the scaled
        absolute value and ``p`` the matching exponent correction
        (forced to 0 when ``z`` is 1)
    """
    lam = gamma - 1
    # Initial placeholders; a_abs, v, p and s are overwritten below.
    a_abs = 0
    v = cint(0)
    p = cint(0)
    s = cint(0)
    z = cint(0)

    # extracts the sign and calculates the abs
    s = cint(a < 0)
    a_abs = a * (1 - 2 * s)

    # isolates most significant bit
    a_bits = a_abs.bit_decompose()
    b = 0
    b_c = 1
    blen = 0
    # Walk the bits from the highest index down; b and blen are only
    # updated while they are still 0, i.e. at the first set bit seen.
    for a_i in range(len(a_bits) - 1, -1, -1):
        b = (a_bits[a_i]) * (b == 0) * ((b_c) / 2) + b
        blen = (a_bits[a_i]) * (blen == 0) * ((a_i + 1)) + blen
        b_c = b_c * 2

    # obtains p
    v = a_abs * b  # scale a
    p = - (lam - blen)

    # reduces v
    v_l = MemValue(v)
    z_l = MemValue(z)
    if_then(a_abs > 0)
    # Compile-time choice: lam and l are plain Python values here.
    if (lam > l):
        v_l.write(v_l.read() / (2 ** (gamma - l - 1)))
    else:
        # Scale up by 2^(l - lam).  The exponent must be parenthesised:
        # ``2 ** l - lam`` parses as ``(2 ** l) - lam``.
        v_l.write(v_l.read() * (2 ** (l - lam)))
    else_then()
    z_l.write(cint(1))
    end_if()

    # corrects output
    # s is coming from the abs extraction
    v = cint(v_l.read())
    z = cint(z_l.read())
    p = cint((p + lam - l) * (1 - z))
    return v, p, z, s
def load_int(self, value):
    """
    Load the Python integer ``value`` into this object.

    For widths up to 64 the value goes through a ``regint``.  Wider
    objects only accept the value returned by ``self.long_one()``, which
    is loaded as a ``cint(1)`` of size ``self.n``.

    :raises CompilerError: for any other value on a width above 64
    """
    width = self.n
    if width <= 64:
        loaded = regint(value)
    else:
        if not value == self.long_one():
            raise CompilerError('loading long integers to cbits not supported')
        loaded = cint(1, size=width)
    self.load_other(loaded)
def reveal(secret):
    """
    Open ``secret`` and return the opened value.

    ``_number`` instances are opened through their own ``reveal()``.
    Anything else is opened with ``instructions.asm_open`` into a fresh
    ``cgf2n`` or ``cint``, chosen by ``secret.is_gf2n``.
    """
    if isinstance(secret, _number):
        return secret.reveal()
    opened = cgf2n() if secret.is_gf2n else cint()
    instructions.asm_open(opened, secret)
    return opened
def reveal(secret):
    """
    Open ``secret`` and return the opened value.

    Objects providing a ``reveal()`` method are opened through it.  If
    an AttributeError surfaces from that call, the value is opened with
    ``instructions.asm_open`` into a fresh ``cgf2n`` or ``cint``, chosen
    by ``secret.is_gf2n``.
    """
    try:
        return secret.reveal()
    except AttributeError:
        opened = cgf2n() if secret.is_gf2n else cint()
        instructions.asm_open(opened, secret)
        return opened
def bits(a,m):
    """
    Get the lowest ``m`` bits of ``a``, least significant bit first.

    A Python ``int`` is decomposed directly.  Any other value is handled
    in chunks of at most 64 bits: each chunk is converted with
    ``convmodp`` and decomposed, and ``a`` is shifted right by 64 with
    ``shrci`` before the next chunk.
    """
    if isinstance(a, int):
        return [(a >> pos) & 1 for pos in range(m)]

    from Compiler.types import regint, cint
    out = []
    remaining = m
    while remaining > 0:
        word = regint()
        convmodp(word, a, bitlength=0)
        out.extend(cint(x) for x in word.bit_decompose(min(64, remaining)))
        remaining -= 64
        if remaining > 0:
            shifted = cint()
            shrci(shifted, a, 64)
            a = shifted
    return out
def AppRcr(b, k, f, kappa, simplex_flag=False):
    """
    Approximate reciprocal of [b]:
    Given [b], compute [1/b]

    :param b: value whose reciprocal is approximated
    :param k: bit length passed to ``Norm`` and used for the constant
    :param f: precision; the product is truncated by ``2 * (k - f)`` bits
    :param kappa: forwarded to ``Norm`` and to the final ``TruncPr``
    :param simplex_flag: forwarded to ``Norm``
    :return: ``w`` such that ``w * 2^{-f}`` should be an initial
        approximation of ``1/b``
    """
    alpha = cint(int(2.9142 * 2**k))
    c, v = Norm(b, k, f, kappa, simplex_flag)
    # v should be 2**{k - m} where m is the length of the bitwise repr of [b]
    d = alpha - 2 * c
    w = d * v
    # kappa is forwarded so the truncation uses the caller's parameter,
    # as in the other TruncPr calls, rather than TruncPr's default.
    w = TruncPr(w, 2 * k, 2 * (k - f), kappa)
    # now w * 2 ^ {-f} should be an initial approximation of 1/b
    return w
def print_str(s, *args):
    """
    Print a string, with optional args for adding variables/registers
    with %s.

    The format string is split at every '%s'; each literal piece is
    printed, followed by the matching argument.  ``MemValue`` arguments
    are replaced by their ``register`` attribute, lists are printed
    recursively in brackets, and other objects are printed via their
    ``output()`` method, falling back to ``str()``.

    :raises CompilerError: if the number of arguments does not match
        the number of '%s' markers, or if a register is neither of type
        'ci' nor clear
    """
    def print_plain_str(ss):
        """ Print a plain string (no custom formatting options) """
        # Four-character chunks are emitted only while characters remain
        # after them; the tail (1 to 4 characters) goes out one by one.
        n_quads = max((len(ss) - 1) // 4, 0)
        for q in range(n_quads):
            print_char4(ss[4 * q:4 * q + 4])
        for ch in ss[4 * n_quads:]:
            print_char(ch)

    if s.count('%s') != len(args):
        raise CompilerError('Incorrect number of arguments for string format:', s)

    for idx, piece in enumerate(s.split('%s')):
        print_plain_str(piece)
        if idx >= len(args):
            continue

        arg = args[idx]
        val = arg.register if isinstance(arg, MemValue) else arg

        if isinstance(val, program.Tape.Register):
            if val.reg_type == 'ci':
                cint(val).print_reg_plain()
            elif val.is_clear:
                val.print_reg_plain()
            else:
                raise CompilerError('Cannot print secret value:', arg)
        elif isinstance(val, list):
            print_str('[' + ', '.join(['%s'] * len(val)) + ']', *val)
        else:
            try:
                val.output()
            except AttributeError:
                print_plain_str(str(val))
def SDiv(a, b, l, kappa):
    """
    Iterative division of ``a`` by ``b`` for ``l``-bit operands.

    A linear initial estimate is refined over ``theta - 1`` rounds plus
    one closing step.  The error term is kept split into a low part
    (obtained with ``Mod2m``) and a high part (the remainder scaled by
    ``1 / 2^l``).  ``kappa`` is forwarded to every ``TruncPr`` and
    ``Mod2m`` call.
    """
    n_rounds = int(ceil(log(l / 3.5) / log(2)))
    scale = AdvInteger.two_power(2 * l)
    inv_pow_l = 1 / types.cint(AdvInteger.two_power(l))
    estimate = types.cint(int(2.9142 * AdvInteger.two_power(l))) - 2 * b

    err = scale - b * estimate
    quot = a * estimate
    quot = AdvInteger.TruncPr(quot, 2 * l, l, kappa)

    err_lo = types.sint()
    AdvInteger.Mod2m(err_lo, err, 2 * l + 1, l, kappa, False)
    err_hi = (err - err_lo) * inv_pow_l

    for _ in range(n_rounds - 1):
        # quotient update
        main_term = quot * (err_hi + two_power(l))
        low_term = AdvInteger.TruncPr(quot * err_lo, 2 * l, l, kappa)
        quot = AdvInteger.TruncPr(main_term + low_term, 2 * l + 1, l + 1, kappa)

        # error update, assembled from the high and low parts
        cross = err_hi * err_lo
        low_sq = AdvInteger.TruncPr(err_lo**2, 2 * l + 1, l + 1, kappa)
        err = cross + low_sq
        high_sq = err_hi * err_hi
        err = high_sq + AdvInteger.TruncPr(err, 2 * l + 1, l - 1, kappa)

        err_lo = types.sint()
        AdvInteger.Mod2m(err_lo, err, 2 * l, l, kappa, False)
        err_hi = (err - err_lo) * inv_pow_l

    # closing step
    main_term = quot * (err_hi + two_power(l))
    low_term = AdvInteger.TruncPr(quot * err_lo, 2 * l, l, kappa)
    quot = AdvInteger.TruncPr(main_term + low_term, 2 * l + 1, l - 1, kappa)
    return quot
def twos_complement(x):
    """
    Return the two's complement of ``x`` over ``k`` bits.

    The ``k`` bits of ``x`` are taken highest index first, each bit is
    inverted while the value is rebuilt by shifting and adding, and 1 is
    added to the rebuilt value.  ``k`` comes from the enclosing scope.
    """
    reversed_bits = x.bit_decompose(k)[::-1]
    bit_store = Array(k, cint)
    bit_store.assign(reversed_bits)
    inverted = MemValue(cint(0))

    @for_range(k)
    def block(i):
        shifted = inverted.read() << 1
        inverted.write(shifted + (1 - bit_store[i]))

    return inverted.read() + 1
def SDiv_mono(a, b, l, kappa):
    """
    Iterative division of ``a`` by ``b`` for ``l``-bit operands, keeping
    the error term in a single value.

    A linear initial estimate is refined over ``theta - 1`` rounds plus
    one closing multiplication.  ``kappa`` is forwarded to every
    ``TruncPr`` call.
    """
    n_rounds = int(ceil(log(l / 3.5) / log(2)))
    scale = two_power(2 * l)
    estimate = types.cint(int(2.9142 * two_power(l))) - 2 * b

    err = scale - b * estimate
    quot = a * estimate
    quot = AdvInteger.TruncPr(quot, 2 * l + 1, l + 1, kappa)

    for _ in range(n_rounds - 1):
        quot = quot * (scale + err)
        # keep y with l bits
        quot = AdvInteger.TruncPr(quot, 3 * l, 2 * l, kappa)
        err = err**2
        # keep x with 2l bits
        err = AdvInteger.TruncPr(err, 4 * l, 2 * l, kappa)

    quot = quot * (scale + err)
    quot = AdvInteger.TruncPr(quot, 3 * l, 2 * l, kappa)
    return quot
def FPDiv(a, b, k, f, kappa, simplex_flag=False):
    """
    Goldschmidt method as presented in Catrina10.

    Starts from the approximate reciprocal of ``b`` given by ``AppRcr``,
    then runs ``theta + 1`` rounds in which the quotient is multiplied
    by ``(alpha + x)`` and the error ``x`` is squared.  Both are
    truncated by ``2 * f`` bits after each round, and one closing
    multiplication follows.
    """
    n_rounds = int(ceil(log(k/3.5)))
    one_scaled = cint(1) << (2 * f)

    recip = AppRcr(b, k, f, kappa, simplex_flag)
    err = one_scaled - b * recip
    quot = a * recip
    quot = TruncPr(quot, 2*k, f, kappa)

    for _ in range(n_rounds + 1):
        quot = quot * (one_scaled + err)
        err = err * err
        quot = TruncPr(quot, 2*k, 2*f, kappa)
        err = TruncPr(err, 2*k, 2*f, kappa)

    quot = quot * (one_scaled + err)
    quot = TruncPr(quot, 2*k, 2*f, kappa)
    return quot
def f(i):
    # Print index ``i`` under the 'wife' tag, then the revealed entry
    # ``self.husbands[i]`` under the 'husb' tag.
    wife_index = types.cint(i)
    wife_index.print_reg('wife')
    husband = self.husbands[i].reveal()
    husband.print_reg('husb')
def f(i):
    # Loop body for row ``i``: log the row index under the 'fran' tag,
    # then fill that row of ``self.f_ranks``.
    time()
    types.cint(i).print_reg('fran')

    @for_range(self.M)
    def f(rank):
        # ``rank`` is stored at position (rank - i + N) modulo N.
        slot = (rank-i+self.N)%self.N
        self.f_ranks[i][slot] = rank
def BitDecFull(a, n_bits=None, maybe_mixed=False):
    """
    Bit-decompose ``a`` with respect to the program's prime ``p``.

    A random value ``b`` with known bits ``bbits`` is generated,
    ``a - b`` is revealed, and the bits of ``a`` are rebuilt from the
    revealed value and ``bbits`` with a bit adder.

    :param a: value to decompose; ``a.size`` is used as the vector size
    :param n_bits: number of bits to return; defaults to the bit length
        of ``p`` and must not exceed it
    :param maybe_mixed: if true, the bits are returned as computed;
        otherwise each bit is converted with ``sint.conv``
    :return: list of ``n_bits`` bits
    """
    from .library import get_program, do_while, if_, break_point
    from .types import sint, regint, longint, cint
    p = get_program().prime
    assert p
    bit_length = p.bit_length()
    n_bits = n_bits or bit_length
    assert n_bits <= bit_length
    logp = int(round(math.log(p, 2)))
    # Branch 1: p is within relative distance 2^-security of 2^logp.
    if abs(p - 2**logp) / p < 2**-get_program().security:
        # inspired by Rabbit (https://eprint.iacr.org/2021/119)
        # no need for exact randomness generation
        # if modulo a power of two is close enough
        if get_program().use_edabit():
            b, bbits = sint.get_edabit(logp, True, size=a.size)
            if logp != bit_length:
                from .GC.types import sbits
                # pad with a zero bit up to bit_length
                bbits += [0]
        else:
            bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
            b = sint.bit_compose(bbits)
            if logp != bit_length:
                # pad with a zero bit up to bit_length
                bbits += [sint(0, size=a.size)]
    else:
        # Branch 2: draw random bits per vector element and retry until
        # the comparison against the bits of p succeeds.
        bbits = [sint(size=a.size) for i in range(bit_length)]
        tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)]
        pbits = util.bit_decompose(p)
        # Loop until we get some random integers less than p
        done = [regint(0) for i in range(a.size)]
        @do_while
        def get_bits_loop():
            for j in range(a.size):
                # only redraw elements that have not succeeded yet
                @if_(done[j] == 0)
                def _():
                    for i in range(bit_length):
                        tbits[j][i].link(sint.get_random_bit())
                    # revealed result of comparing the drawn bits with pbits
                    c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
                    done[j].link(c)
            # keep looping while some element is still not done
            return (sum(done) != a.size)
        # copy the per-element bits into the bbits vectors
        for j in range(a.size):
            for i in range(bit_length):
                movs(bbits[i][j], tbits[j][i])
        b = sint.bit_compose(bbits)
    # reveal the difference between a and the random value b
    c = (a - b).reveal()
    cmodp = c
    t = bbits[0].bit_decompose_clear(p - c, bit_length)
    c = longint(c, bit_length)
    czero = (c == 0)
    q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
    # two candidate addends: bits of 2^bit_length + c - p, and bits of c
    fbar = [
        bbits[0].clear_type.conv(cint(x))
        for x in ((1 << bit_length) + c - p).bit_decompose(n_bits)
    ]
    fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
    # select per bit between the two candidates according to q
    g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
    h = bbits[0].bit_adder(bbits, g)
    # when c == 0 the bits of b are used directly instead of the sum
    abits = [
        bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
        for i in range(n_bits)
    ]
    if maybe_mixed:
        return abits
    else:
        return [sint.conv(bit) for bit in abits]
def output_if(self, cond):
    """Convert this value to a ``cint`` and call its ``output_if``
    with ``cond``."""
    as_cint = cint(self)
    as_cint.output_if(cond)
def print_str(s, *args):
    """
    Print a string, with optional args for adding variables/registers
    with %s.

    The format string is split at every '%s'; each literal piece is
    printed, followed by the matching argument.  Supported arguments
    are ``MemValue`` (its ``read()`` value is used), clear registers,
    ``cfix``, ``cfloat``, lists (printed recursively in brackets) and
    any object with an ``output()`` method; everything else is printed
    via ``str()``.

    :raises CompilerError: if the number of arguments does not match
        the number of '%s' markers, or if a non-clear register, an
        ``sfix`` or an ``sfloat`` is passed
    """
    def print_plain_str(ss):
        """ Print a plain string (no custom formatting options) """
        # Emit four-character chunks while characters remain after them.
        i = 1
        while 4 * i < len(ss):
            print_char4(ss[4 * (i - 1):4 * i])
            i += 1
        # Emit the remaining characters one at a time.
        i = 4 * (i - 1)
        while i < len(ss):
            print_char(ss[i])
            i += 1

    if len(args) != s.count('%s'):
        raise CompilerError('Incorrect number of arguments for string format:', s)
    substrings = s.split('%s')
    for i, ss in enumerate(substrings):
        print_plain_str(ss)
        if i < len(args):
            if isinstance(args[i], MemValue):
                val = args[i].read()
            else:
                val = args[i]
            if isinstance(val, program.Tape.Register):
                if val.is_clear:
                    val.print_reg_plain()
                else:
                    raise CompilerError('Cannot print secret value:', args[i])
            elif isinstance(val, cfix):
                # print decimal representation of a clear fixed point number
                # number is encoded as [left].[right]
                left = val.v
                sign = -1 * (val.v < 0) + 1 * (val.v >= 0)
                positive_left = cint(sign) * left
                # fractional part: the low val.f bits of the absolute value
                right = positive_left % 2**val.f
                @if_(sign == -1)
                def block():
                    print_str('-')
                # integer part
                cint((positive_left - right + 1) >> val.f).print_reg_plain()
                x = 0
                max_dec_base = 8  # max 32-bit precision
                last_nonzero = 0
                # accumulate the fraction as an integer scaled by
                # 10**max_dec_base; note this loop rebinds ``i``
                for i, b in enumerate(reversed(right.bit_decompose(val.f))):
                    x += b * int(10**max_dec_base / 2**(i + 1))
                v = x
                # walk the decimal digits of x from the least significant,
                # recording the highest position holding a nonzero digit
                for i in range(max_dec_base):
                    t = v % 10
                    b = (t > 0)
                    last_nonzero = (1 - b) * last_nonzero + b * i
                    v = (v - t) / 10
                print_plain_str('.')
                # leading zeros of the fractional part
                @for_range(max_dec_base - 1 - last_nonzero)
                def f(i):
                    print_str('0')
                x.print_reg_plain()
            elif isinstance(val, sfix) or isinstance(val, sfloat):
                raise CompilerError('Cannot print secret value:', args[i])
            elif isinstance(val, cfloat):
                val.print_float_plain()
            elif isinstance(val, list):
                print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val)
            else:
                try:
                    val.output()
                except AttributeError:
                    print_plain_str(str(val))
def output_if(self, cond):
    """
    Convert this value to a ``cint`` and call its ``output_if`` with
    ``cond``.

    :raises CompilerError: when the program's ``binary`` option is set
    """
    binary_mode = Program.prog.options.binary
    if binary_mode:
        raise CompilerError('conditional output not supported')
    as_cint = cint(self)
    as_cint.output_if(cond)
def conv_regint_by_bit(cls, n, res, other):
    """
    Convert ``other`` into ``res`` by first wrapping it in a ``cint``
    vector of the same size and then calling ``cls.conv_cint_vec``.

    ``n`` must equal both ``res.n`` and ``other.size``.
    """
    assert n == res.n
    vec_size = other.size
    assert n == vec_size
    as_cint = cint(other, size=vec_size)
    cls.conv_cint_vec(as_cint, res)