def wraps(x):
    """Privately compute the number of wraparounds for a set of shares.

    The computation relies on the identity

        [theta_x] = theta_z + [beta_xr] - [theta_r] - [eta_xr]

    where
        [theta_i] is the wraps for a variable i,
        [beta_ij] is the differential wraps for variables i and j,
        [eta_ij]  is the plaintext wraps for variables i and j.

    Note: since [eta_xr] = 0 with probability 1 - |x| / Q for modulus Q, it is
    assumed to be 0 with high probability and is not computed here.

    Args:
        x: shared tensor whose share wraps are counted (it must expose
            ``size()``, ``_tensor`` and ``rank``; presumably an
            ArithmeticSharedTensor — confirm against callers).

    Returns:
        A shared tensor holding this party's share of [theta_x].
    """
    provider = crypten.mpc.get_default_provider()
    r, theta_r = provider.wrap_rng(x.size())

    # [beta_xr]: wraps of this party's local shares of x and r. The shared
    # tensor wrapper is cloned from theta_r and its local share overwritten.
    beta_xr = theta_r.clone()
    beta_xr._tensor = count_wraps([x._tensor, r._tensor])

    # Every party's share of z = x + r is gathered at rank 0, which is the
    # only party that can count the wraps theta_z of z's sharing.
    masked = x + r
    gathered = comm.get().gather(masked._tensor, 0)

    theta_x = beta_xr - theta_r

    # TODO: Incorporate eta_xr (currently assumed to be zero).
    if x.rank != 0:
        return theta_x

    # Only rank 0 adds the public term theta_z to its share.
    theta_x._tensor += count_wraps(gathered)
    return theta_x
def wraps(self, size):
    """Return the wraps of a generator-derived random value, minus a PRSS term.

    One ring element of shape ``size`` is drawn from each generator in
    ``self.generators``; ``count_wraps`` is applied to that list, and the
    additive PRSS value from ``self._get_additive_PRSS(size, remove_rank=True)``
    is subtracted from the result.

    NOTE(review): presumably each generator is seeded in common with one
    party, so those parties can reproduce their share of r locally — confirm.

    Args:
        size: shape of the random elements to generate.

    Returns:
        count_wraps of the generated elements minus the additive PRSS value.
    """
    # Order matters: draw from the generators before the PRSS call, exactly
    # as the original did, so generator state is consumed identically.
    random_shares = [
        generate_random_ring_element(size, generator=gen)
        for gen in self.generators
    ]
    wrap_count = count_wraps(random_shares)
    prss_term = self._get_additive_PRSS(size, remove_rank=True)
    return wrap_count - prss_term
def wrap_rng(size, num_parties=None):
    """Generate random shared tensor of given size and sharing of its wraps.

    Each party's share of ``r`` is generated locally, ``count_wraps`` is
    computed over all shares, and rank 0's shares are scattered so every
    party receives one. ``theta_r`` is then shared with rank 0 as the source.

    Args:
        size: shape of the random tensor r.
        num_parties: number of shares to generate. Defaults to the
            communicator's world size. ``scatter`` needs exactly one share per
            party, and callers such as ``wraps`` pass only ``size``.

    Returns:
        A tuple ``(r, theta_r)`` of shared tensors with precision 0: the
        random value and a sharing of the wraps of its shares.
    """
    if num_parties is None:
        # scatter below requires one share per party in the communicator.
        num_parties = comm.get().get_world_size()

    r = [generate_random_ring_element(size) for _ in range(num_parties)]
    theta_r = count_wraps(r)

    shares = comm.get().scatter(r, src=0)
    r = ArithmeticSharedTensor.from_shares(shares, precision=0)
    theta_r = ArithmeticSharedTensor(theta_r, precision=0, src=0)
    return r, theta_r
def test_wraps(self):
    """Check ArithmeticSharedTensor.wraps() against count_wraps on raw shares."""
    num_parties = int(self.world_size)
    shape = (5, 5)

    # Build an additive sharing of a random integer tensor. Subtracting the
    # cyclic roll of a random tensor yields shares that sum to zero across
    # parties; the secret value is then folded into party 0's share.
    noise = generate_random_ring_element((num_parties, *shape))
    noise = noise - noise.roll(1, dims=0)
    party_shares = list(noise.unbind(0))
    party_shares[0] += get_random_test_tensor(size=shape, is_float=False)

    # The expected value trusts count_wraps to be correct.
    # NOTE(review): `expected` is computed from this rank's locally generated
    # shares, while only rank 0's shares are scattered below. This assumes
    # every rank draws identical random values (e.g. a shared seed) — confirm.
    expected = count_wraps(party_shares)

    # Rank 0 distributes its shares so that each party holds one.
    my_share = comm.get().scatter(party_shares, 0)
    encrypted = ArithmeticSharedTensor.from_shares(my_share)
    encrypted_wraps = encrypted.wraps()

    matching = (encrypted_wraps.reveal() == expected).sum()
    test_passed = matching == expected.nelement()
    self.assertTrue(test_passed, "%d-party wraps failed" % num_parties)