async def _extract_lsb(context: Mpc, x: (Share, ShareFuture)):
    """ Section 5.3 Extracting the Least Significant Bit
    Returns a future to [x_0], which represents [r]_B > c

    [x] is masked with a bitwise-shared random value [s]_B taken from
    preprocessing, and d = s + x is opened publicly. The result is
    [s_0] xor [d_0]. Here [d_0] is chosen, according to the two most
    significant bits of s, from d's low bit or that bit xor-ed with a public
    comparison of d against a threshold. Presumably this accounts for
    s + x wrapping around the modulus; confirm against the paper.

    Args:
        context: MPC context; supplies ``field`` and ``preproc``.
        x: share (or share future) whose least significant bit is wanted.
    """
    bit_length = context.field.modulus.bit_length()
    s_b, s_bits = context.preproc.get_share_bits(context)

    # Mask x with the random s and open the masked sum.
    d_ = s_b + x
    d = await d_.open()

    # Low bit and the two most significant bits of [s]_B
    s_0 = s_bits[0]  # lsb
    s_1 = s_bits[bit_length - 1]  # [s_{bit_length-1}]
    s_2 = s_bits[bit_length - 2]  # [s_{bit_length-2}]
    s_prod = s_1 * s_2

    # lsb of the public value d
    d0 = d.value & 1

    # d0 xor (d < T) for T = 2^(l-1), 2^(l-2) and 2^(l-1) + 2^(l-2),
    # where l = bit_length.
    d_xor_1 = context.field(d0 ^ (d.value < (1 << (bit_length - 1))))
    d_xor_2 = context.field(d0 ^ (d.value < (1 << (bit_length - 2))))
    d_xor_12 = context.field(
        d0 ^ (d.value < ((1 << (bit_length - 1)) + (1 << (bit_length - 2)))))

    # Multiplexer on (s_1, s_2), assuming both are 0/1 sharings:
    #   (1 - s_1)(1 - s_2) selects d0
    #   s_2 (1 - s_1)      selects d_xor_2
    #   s_1 (1 - s_2)      selects d_xor_1
    #   s_1 s_2            selects d_xor_12
    # Operand order is kept as-is: it mixes field elements, Shares and
    # (possibly) ShareFutures, so reordering could change which dunder runs.
    d_0 = ((context.field(1) - s_1 - s_2 + s_prod) * d0
           + ((s_2 - s_prod) * d_xor_2)
           + ((s_1 - s_prod) * d_xor_1)
           + (s_prod * d_xor_12))

    # [x0] = [s0] ^ [d0], equal to [r]_B > c
    return LessThan._xor_bits(s_0, d_0)
def _compute_x(context: Mpc, r_bits: list, c_bits: list):
    """ Section 5.2 Computing X
    Computes [x] from equation 7
    The least significant bit of [x], written [x_0] is equal to
    the value [r_i], where i is the most significant bit where [r_i] != c_i
    [x_0] == ([r]_B > c)

    Args:
        context: MPC context; supplies ``field`` constants and ``Share``.
        r_bits: shared bits of r, least significant bit first.
        c_bits: public bits of c as field elements, least significant first.

    Returns:
        [x] = sum_i [r_i] * (1 - c_i) * P_i, where P_i is the product of
        (1 + ([r_j] xor c_j)) over all j > i (P_i = 1 for the top bit).
        Only the first min(len(r_bits), len(c_bits)) bits take part,
        because every pairing goes through ``zip``.

    TODO: precompute PRODUCT(1 + [r_j])
    Compute PRODUCT(1 + c_j) without MPC
    See final further work points in paper section 6
    """
    power_bits = [
        context.field(1) + LessThan._xor_bits(r, c)
        for r, c in zip(r_bits[1:], c_bits[1:])
    ]

    # Suffix products P_i, computed from the top bit downwards. Appending and
    # reversing once keeps each step O(1); the multiplications (and their
    # operand order, b * previous) are unchanged.
    powers = [context.Share(1)]
    for b in reversed(power_bits):
        powers.append(b * powers[-1])
    powers.reverse()

    # TODO: make this log(n)
    x = context.field(0)
    for (r_i, c_i, p) in zip(r_bits, c_bits, powers):
        x += r_i * (context.field(1) - c_i) * p

    return x
async def _prog(context: Mpc, xs: ShareArray):
    """Invert every shared element of ``xs``.

    Each [x_i] is masked by a fresh random share [r_i]. The product
    x_i * r_i is opened and inverted publicly, and
    [r_i] * (x_i * r_i)^-1, i.e. a sharing of x_i^-1, is returned.
    NOTE(review): an element whose masked product opens to zero cannot be
    inverted; presumably the field division raises.
    """
    size = len(xs)
    masks = context.ShareArray(
        [context.preproc.get_rand(context) for _ in range(size)])

    masked = await (xs * masks)
    opened_products = await masked.open()
    public_inverses = context.ShareArray(
        [1 / value for value in opened_products])

    return await (masks * public_inverses)
def from_point(context: Mpc, p: Point) -> SharedPoint:
    """Create a shared point from a local point and a context.

    The x and y coordinates of ``p`` are each wrapped with ``context.Share``
    and combined into a ``SharedPoint`` on ``p``'s curve.

    Raises:
        Exception: if ``p`` is not a ``Point``.
    """
    if isinstance(p, Point):
        shared_x = context.Share(p.x)
        shared_y = context.Share(p.y)
        return SharedPoint(context, shared_x, shared_y, curve=p.curve)

    raise Exception(f"Could not create shared point-- p ({p}) is not a Point!")
async def gen_test_bit(context: Mpc, diff: Share):
    """Generate one probabilistic equality test bit for the shared ``diff``.

    Draws ``(c, [b])`` from ``Equality._gen_test_bit`` until the opened ``c``
    is non-zero, then computes the Legendre symbol ``leg`` of ``c``. It
    returns a share of ``(leg / 2) * (b' + leg)``, where ``b'`` is the sign
    encoding of ``[b]`` (1 -> 1, 5 -> -1). This equals 1 when ``leg == b'``
    and 0 when ``leg == -b'``.

    Args:
        context: MPC context.
        diff: share of the difference being tested against zero.

    Returns:
        A share of the test bit (0 or 1).
    """
    cj, bj = await Equality._gen_test_bit(context, diff)
    while cj == 0:
        cj, bj = await Equality._gen_test_bit(context, diff)

    legendre = Equality.legendre_mod_p(cj)
    if legendre == 0:
        # Retry. This must be awaited: returning the bare coroutine would hand
        # the caller an un-run coroutine instead of a share.
        return await Equality.gen_test_bit(context, diff)

    # _gen_test_bit encodes its random bit as [bj] in {1, 5} (5 chosen as a
    # quadratic non-residue), but the formula below needs {1, -1}:
    # (3 - bj) / 2 maps 1 -> 1 and 5 -> -1.
    signed_bj = (context.field(1) / context.field(2)) * (context.Share(3) - bj)

    return (legendre / context.field(2)) * (signed_bj + context.Share(legendre))
async def reduce_degree_share_array(context: Mpc, x_2t: ShareArray):
    """Bring a degree-2t ShareArray back down to degree t.

    For each element a double sharing ([r]_t, [r]_2t) is drawn from
    preprocessing. x - r is opened using the 2t sharings, and the public
    difference is added onto the degree-t sharings of r.
    """
    assert x_2t.t == context.t * 2

    double_shares = [
        context.preproc.get_double_shares(context) for _ in range(len(x_2t))
    ]
    low_degree = context.ShareArray([low for low, _ in double_shares])
    high_degree = context.ShareArray(
        [high for _, high in double_shares], 2 * context.t)

    opened_diff = await (x_2t - high_degree).open()
    return low_degree + opened_diff
async def _prog(context: Mpc, x: ShareArray, y: ShareArray):
    """Multiply two equal-length ShareArrays elementwise via double sharing.

    The local products of the underlying share values form a degree-2t
    sharing, which is then reduced to degree t.
    """
    assert len(x) == len(y)

    products = [lhs.v * rhs.v for lhs, rhs in zip(x._shares, y._shares)]
    degree_2t = context.ShareArray(products, context.t * 2)
    return await DoubleSharingMultiplyArrays.reduce_degree_share_array(
        context, degree_2t)
async def _prog(context: Mpc, j: ShareArray, k: ShareArray):
    """Multiply two equal-length ShareArrays elementwise with Beaver triples.

    For each position a triple ([a], [b], [ab]) is drawn. d = j - a and
    e = k - b are opened, and the product share is de + d[b] + e[a] + [ab].
    """
    assert len(j) == len(k)

    triples = [context.preproc.get_triples(context) for _ in range(len(j))]
    a_shares = [a for a, _, _ in triples]
    b_shares = [b for _, b, _ in triples]
    ab_shares = [ab for _, _, ab in triples]

    u, v = context.ShareArray(a_shares), context.ShareArray(b_shares)
    d_opened, e_opened = await gather((j - u).open(), (k - v).open())

    products = [
        d * e + d * b + e * a + ab
        for a, b, ab, d, e in zip(a_shares, b_shares, ab_shares, d_opened, e_opened)
    ]
    return context.ShareArray(products)
async def _prog(context: Mpc, p_share: Share, q_share: Share, security_parameter: int = 32):
    """Probabilistic equality test of two shares.

    Generates ``security_parameter`` test bits for [p] - [q] concurrently and
    returns the multiplicative product of all of them as one share.
    """
    difference = p_share - q_share
    bit_coroutines = [
        Equality.gen_test_bit(context, difference)
        for _ in range(security_parameter)
    ]
    test_bits = await gather(*bit_coroutines)
    return await context.ShareArray(test_bits).multiplicative_product()
def execute(self, sid, program, **kwargs):
    """Run ``program`` as a new MPC session and return a future for its result.

    Builds an ``Mpc`` context for ``sid`` from this runner's n, t, my_id and
    mpc_config (plus ``kwargs``), then schedules ``context._run()`` as a task
    tracked in ``self.progs``.

    Returns:
        An ``asyncio.Future`` that mirrors the task: it receives the task's
        result, its exception, or is cancelled when the task is cancelled.
    """
    send, recv = self.get_send_recv(sid)
    context = Mpc(
        sid,
        self.n,
        self.t,
        self.my_id,
        send,
        recv,
        program,
        self.mpc_config,
        **kwargs,
    )
    program_result = asyncio.Future()

    def callback(future):
        # Calling future.result() unconditionally would raise inside this
        # callback on failure or cancellation, leaving program_result pending
        # forever. Mirror every outcome instead.
        if program_result.done():
            return
        if future.cancelled():
            program_result.cancel()
            return
        exc = future.exception()
        if exc is not None:
            program_result.set_exception(exc)
        else:
            program_result.set_result(future.result())

    task = asyncio.create_task(context._run())
    task.add_done_callback(callback)
    task.add_done_callback(print_exception_callback)
    self.progs.append(task)

    return program_result
async def _gen_test_bit(context: Mpc, diff: Share):
    """Draw one masked test value for the equality check on ``diff``.

    Returns ``(c, [b])``. ``[b]`` is a share of 1 or 5, selected by a
    preprocessed random bit. ``c`` is the opened value of
    ``diff * r + b * r' * r'`` for fresh random shares ``r`` and ``r'``.
    """
    # b in {0, 1}
    bit = context.preproc.get_bit(context)

    # Affine map 0 -> 5, 1 -> 1. For p = 1 mod 8 with (5/p) = -1, 5 is a
    # quadratic non-residue while 1 is a residue.
    encoded = (-4 * bit) + context.Share(5)

    mask = context.preproc.get_rand(context)
    square_mask = context.preproc.get_rand(context)

    # If diff is zero the opened value is encoded * r'^2, whose quadratic
    # character equals that of `encoded` (unless r' == 0). Otherwise the
    # diff * r term makes it a square with probability about 1/2.
    blinded = (diff * mask) + (encoded * square_mask * square_mask)
    opened = await blinded.open()

    return opened, encoded
async def _transform_comparison(context: Mpc, a_share: Share, b_share: Share):
    """ Section 5.1 First Transformation
    Compute [r]_B and [c]_B, which are bitwise sharings of a random share [r] and
    [c] = 2([a] - [b]) + [r]

    Returns:
        ``(r_bits, c_bits)``. ``r_bits`` is as returned by
        ``preproc.get_share_bits``. ``c_bits`` holds the bits of the opened
        c as field elements, least significant bit first, spanning the
        field modulus' bit length.
    """
    z = a_share - b_share
    r_b, r_bits = context.preproc.get_share_bits(context)

    # [c] = 2[z] + [r]_B = 2([a]-[b]) + [r]_B
    c = await (2 * z + r_b).open()

    # Decompose over the field's bit length instead of a hard-coded 255 bits.
    # Otherwise, in wider fields a small c gives fewer bits than r_bits, and
    # the later zip over both lists silently drops r's high bits. max()
    # guarantees no set bit of c is ever lost.
    bit_length = max(context.field.modulus.bit_length(), c.value.bit_length())

    # LSB first
    c_bits = [context.field((c.value >> i) & 1) for i in range(bit_length)]
    return r_bits, c_bits
async def generate_bits(n, t, k, my_id, _send, _recv, field):
    """Generate ``k`` shared random signs (values +1 or -1).

    ``randousha`` supplies 2k (t, 2t) random sharing pairs. The t-shares of
    the first k are [u]; the last k are ([r]_t, [r]_2t) masks. u^2 is
    reconstructed as open(open([u]_t * [u]_t + [r]_2t) - [r]_t). Each party
    then returns [u]_t / sqrt(u^2), which is a sharing of +1 or -1 because
    sqrt(u^2) is u or -u.

    Args:
        n, t: number of parties and threshold, passed to randousha and Mpc.
        k: number of values to generate.
        my_id: this party's id.
        _send, _recv: raw channels; wrapped per message tag.
        field: field constructor applied to the t-shares of u.

    Returns:
        Whatever the opening MPC program returns: a list of k field elements
        (this party's shares).

    NOTE(review): if some u is zero, u^2 == 0 and the division presumably
    raises; this case is not handled.
    """
    subscribe_recv_task, subscribe = subscribe_recv(_recv)

    def _get_send_recv(tag):
        return wrap_send(tag, _send), subscribe(tag)

    try:
        # Start listening for my share of t and 2t shares from all parties.
        send, recv = _get_send_recv("randousha")
        rs_t2t = await randousha(n, t, 2 * k, my_id, send, recv, field)

        # The first k pairs supply [u] (only the t-shares are used). The next
        # k pairs supply the [r] masks used to publicly reconstruct u^2.
        us_t2t = rs_t2t[0:k]
        rs_t2t = rs_t2t[k:2 * k]

        us_t, _ = zip(*us_t2t)
        us_t = list(map(field, us_t))
        rs_t, rs_2t = zip(*rs_t2t)

        # Compute degree reduction to get the bit
        async def prog(ctx):
            u2rs_2t = [u * u + r for u, r in zip(us_t, rs_2t)]
            assert len(u2rs_2t) == len(rs_t)
            u2rs = await ctx.ShareArray(u2rs_2t, 2 * t).open()
            u2s_t = [u2r - r for u2r, r in zip(u2rs, rs_t)]
            u2s = await ctx.ShareArray(u2s_t).open()
            # sqrt(u^2) is u or -u, so each quotient is +1 or -1.
            return [u / u2.sqrt() for u, u2 in zip(us_t, u2s)]

        send, recv = _get_send_recv("opening")
        ctx = Mpc("mpc:opening", n, t, my_id, send, recv, prog, {})
        return await ctx._run()
    finally:
        # Stop the receive subscription even if randousha or the MPC raises;
        # otherwise the task would outlive this call.
        subscribe_recv_task.cancel()
async def generate_triples(n, t, k, my_id, _send, _recv, field):
    """Generate ``k`` multiplication triples ([a]_t, [b]_t, [ab]_t).

    ``randousha`` supplies 3k (t, 2t) random sharing pairs: the t-shares of
    the first k are [a], those of the next k are [b], and the last k are
    ([r]_t, [r]_2t) masks. Each party computes a*b + [r]_2t locally and the
    results are opened as a degree-2t ShareArray. Subtracting [r]_t then
    gives [ab]_t.

    Args:
        n, t: number of parties and threshold, passed to randousha and Mpc.
        k: number of triples to generate.
        my_id: this party's id.
        _send, _recv: raw channels; wrapped per message tag.
        field: field constructor applied to the t-shares of a and b.

    Returns:
        Whatever the opening MPC program returns: a list of k ``(a, b, ab)``
        tuples of this party's shares.
    """
    subscribe_recv_task, subscribe = subscribe_recv(_recv)

    def _get_send_recv(tag):
        return wrap_send(tag, _send), subscribe(tag)

    try:
        # Start listening for my share of t and 2t shares from all parties.
        send, recv = _get_send_recv("randousha")
        rs_t2t = await randousha(n, t, 3 * k, my_id, send, recv, field)

        as_t2t = rs_t2t[0 * k:1 * k]
        bs_t2t = rs_t2t[1 * k:2 * k]
        rs_t2t = rs_t2t[2 * k:3 * k]

        as_t, _ = zip(*as_t2t)
        bs_t, _ = zip(*bs_t2t)
        as_t = list(map(field, as_t))
        bs_t = list(map(field, bs_t))
        rs_t, rs_2t = zip(*rs_t2t)

        # Compute degree reduction to get triples
        # TODO: Use the mixins and preprocessing system
        async def prog(ctx):
            assert len(rs_2t) == len(rs_t) == len(as_t) == len(bs_t)

            abrs_2t = [a * b + r for a, b, r in zip(as_t, bs_t, rs_2t)]
            abrs = await ctx.ShareArray(abrs_2t, 2 * t).open()
            abs_t = [abr - r for abr, r in zip(abrs, rs_t)]
            return list(zip(as_t, bs_t, abs_t))

        send, recv = _get_send_recv("opening")
        ctx = Mpc("mpc:opening", n, t, my_id, send, recv, prog, {})
        return await ctx._run()
    finally:
        # Stop the receive subscription even if randousha or the MPC raises;
        # otherwise the task would outlive this call.
        subscribe_recv_task.cancel()
async def _mixing_loop(self):
    """Participate in mixing epochs forever (task 3).

    For each epoch, this method:

    1. Polls the contract every 5 seconds until ``epochs_initiated()``
       exceeds the local epoch counter.
    2. Builds this party's shares of the epoch's K inputs from the public
       masked inputs and ``self._inputmasks``.
    3. Takes the epoch's slice of ``self._triples`` and ``self._bits``.
    4. Runs ``iterated_butterfly_network`` in an MPC and opens the shuffled
       values.
    5. Proposes the comma-joined decoded messages via ``propose_output``.

    The outer ``while True`` has no exit, so this never returns normally.

    NOTE(review): the ``contract_concise.*`` calls and ``.transact(...)``
    look like synchronous web3 calls, which would block the event loop
    while they run; confirm.
    """
    # Task 3. Participating in mixing epochs
    contract_concise = ConciseContract(self.contract)
    pp_elements = PreProcessedElements()
    n = contract_concise.n()
    t = contract_concise.t()
    K = contract_concise.K()  # noqa: N806
    PER_MIX_TRIPLES = contract_concise.PER_MIX_TRIPLES()  # noqa: N806
    PER_MIX_BITS = contract_concise.PER_MIX_BITS()  # noqa: N806

    epoch = 0
    while True:
        # 3.a. Wait for the next mix to be initiated
        while True:
            epochs_initiated = contract_concise.epochs_initiated()
            if epochs_initiated > epoch:
                break
            await asyncio.sleep(5)

        # 3.b. Collect the inputs
        # Queue slots [epoch * K, (epoch + 1) * K) belong to this epoch.
        inputs = []
        for idx in range(epoch * K, (epoch + 1) * K):
            # Get the public input
            masked_input, inputmask_idx = contract_concise.input_queue(idx)
            masked_input = field(int.from_bytes(masked_input, "big"))
            # Get the input masks
            inputmask = self._inputmasks[inputmask_idx]

            # Presumably inputmask is this party's share of the mask, making
            # m_share this party's share of the plaintext input; confirm.
            m_share = masked_input - inputmask
            inputs.append(m_share)

        # 3.c. Collect the preprocessing
        triples = self._triples[
            (epoch + 0) * PER_MIX_TRIPLES : (epoch + 1) * PER_MIX_TRIPLES
        ]
        bits = self._bits[(epoch + 0) * PER_MIX_BITS : (epoch + 1) * PER_MIX_BITS]

        # Hack explanation... the relevant mixins are in triples
        # NOTE(review): this drops any cached entries for (myid, n, t),
        # presumably so that the preprocessing files rewritten inside prog
        # are re-read by _init_mixins() instead of stale cached values; confirm.
        key = (self.myid, n, t)
        for mixin in (pp_elements._triples, pp_elements._one_minus_ones):
            if key in mixin.cache:
                del mixin.cache[key]
                del mixin.count[key]

        # 3.d. Call the MPC program
        async def prog(ctx):
            pp_elements._init_data_dir()

            # Overwrite triples and one_minus_ones
            # Each mixin's preprocessing file is rewritten (append=False) with
            # this epoch's elements; triples are flattened first.
            for kind, elems in zip(("triples", "one_minus_one"), (triples, bits)):
                if kind == "triples":
                    elems = flatten_lists(elems)
                elems = [e.value for e in elems]

                mixin = pp_elements.mixins[kind]
                mixin_filename = mixin.build_filename(ctx.N, ctx.t, ctx.myid)
                mixin._write_preprocessing_file(
                    mixin_filename, ctx.t, ctx.myid, elems, append=False
                )

            pp_elements._init_mixins()

            logging.info(f"[{ctx.myid}] Running permutation network")
            inps = list(map(ctx.Share, inputs))
            assert len(inps) == K

            shuffled = await iterated_butterfly_network(ctx, inps, K)
            shuffled_shares = ctx.ShareArray(list(map(ctx.Share, shuffled)))

            # Each opened value is decoded as 32 big-endian bytes of text with
            # NUL padding stripped.
            opened_values = await shuffled_shares.open()
            msgs = [
                m.value.to_bytes(32, "big").decode().strip("\x00")
                for m in opened_values
            ]

            return msgs

        send, recv = self.get_send_recv(f"mpc:{epoch}")
        logging.info(f"[{self.myid}] MPC initiated:{epoch}")

        # Config just has to specify mixins used by switching_network
        config = {MixinConstants.MultiplyShareArray: BeaverMultiplyArrays()}

        ctx = Mpc(f"mpc:{epoch}", n, t, self.myid, send, recv, prog, config)
        result = await ctx._run()
        logging.info(f"[{self.myid}] MPC complete {result}")

        # 3.e. Output the published messages to contract
        result = ",".join(result)
        tx_hash = self.contract.functions.propose_output(epoch, result).transact(
            {"from": self.w3.eth.accounts[self.myid]}
        )
        tx_receipt = await wait_for_receipt(self.w3, tx_hash)
        rich_logs = self.contract.events.MixOutput().processReceipt(tx_receipt)
        if rich_logs:
            # Adopt the epoch reported by the MixOutput event; it is advanced
            # past below.
            epoch = rich_logs[0]["args"]["epoch"]
            output = rich_logs[0]["args"]["output"]
            logging.info(f"[{self.myid}] MIX OUTPUT[{epoch}] {output}")
        else:
            pass

        epoch += 1

    # Unreachable: the loop above never exits.
    pass
async def _prog(context: Mpc, x: Share, y: Share):
    """Multiply two shares with a local product followed by degree reduction.

    The product of the underlying values is wrapped as a degree-2t share and
    reduced to degree t with ``DoubleSharingMultiply.reduce_degree_share``.
    """
    product_2t = context.Share(x.v * y.v, context.t * 2)
    return await DoubleSharingMultiply.reduce_degree_share(context, product_2t)