def generate_additive_triple(size0, size1, op, device=None, *args, **kwargs):
    """Build a correlated triple ``(a, b, c)`` of shared tensors for torch op ``op``.

    Every party draws its shares of ``a`` (shape ``size0``) and ``b``
    (shape ``size1``) from the generator returned by the TTP client for
    ``device``. Party 0 obtains its share of ``c`` from the TTP through an
    ``"additive"`` request; every other party draws its share of ``c`` from
    the same generator, shaped like ``op(a, b, *args, **kwargs)``.
    (Presumably the TTP side makes ``c`` reconstruct to ``op(a, b)``; that
    code is not visible here.)

    Returns:
        Tuple ``(a, b, c)`` of ArithmeticSharedTensor with ``precision=0``.
    """
    client = TTPClient.get()
    rng = client.get_generator(device=device)

    # Shares are drawn from the generator in the order a, then b.
    share_a = generate_random_ring_element(size0, generator=rng, device=device)
    share_b = generate_random_ring_element(size1, generator=rng, device=device)

    if comm.get().get_rank() != 0:
        # TODO: Compute size without executing computation
        result_shape = getattr(torch, op)(share_a, share_b, *args, **kwargs).size()
        share_c = generate_random_ring_element(
            result_shape, generator=rng, device=device
        )
    else:
        # Party 0 asks the TTP for its share of c.
        share_c = client.ttp_request(
            "additive", device, size0, size1, op, *args, **kwargs
        )

    return tuple(
        ArithmeticSharedTensor.from_shares(share, precision=0)
        for share in (share_a, share_b, share_c)
    )
def wrap_rng(size):
    """Return a random shared tensor ``r`` of shape ``size`` and a sharing of its wraps.

    Every party draws its share of ``r`` from the TTP client's generator.
    Party 0 requests its share of ``theta_r`` from the TTP (``"wraps"``
    request); all other parties draw theirs from the same generator.

    Returns:
        ``(r, theta_r)`` as ArithmeticSharedTensor with ``precision=0``.
    """
    client = TTPClient.get()
    rng = client.generator

    r_share = generate_random_ring_element(size, generator=rng)
    is_leader = comm.get().get_rank() == 0
    theta_share = (
        client.ttp_request("wraps", size)
        if is_leader
        else generate_random_ring_element(size, generator=rng)
    )

    return (
        ArithmeticSharedTensor.from_shares(r_share, precision=0),
        ArithmeticSharedTensor.from_shares(theta_share, precision=0),
    )
def square(size):
    """Return shared tensors ``(r, r2)`` of shape ``size``.

    Every party draws its share of ``r`` from the TTP client's generator.
    Party 0 obtains its share of ``r2`` from the TTP (``"square"`` request);
    the other parties draw it from the same generator. Presumably ``r2``
    reconstructs to ``r * r`` (the TTP side is not visible here).
    """
    ttp = TTPClient.get()
    gen = ttp.generator
    r = generate_random_ring_element(size, generator=gen)

    if comm.get().get_rank() == 0:
        # Party 0: share of r2 comes from the TTP.
        r2 = ttp.ttp_request("square", size)
    else:
        r2 = generate_random_ring_element(size, generator=gen)

    def _as_shared(raw):
        # Wrap a raw share as an integer (precision 0) arithmetic sharing.
        return ArithmeticSharedTensor.from_shares(raw, precision=0)

    return _as_shared(r), _as_shared(r2)
def wrap_rng(size, num_parties):
    """Return a random shared tensor of shape ``size`` and a sharing of its wraps.

    One random ring element per party is generated locally, and
    ``count_wraps`` is computed over that list. The list is then scattered
    from party 0, so each party keeps one share of ``r``. The wrap count is
    shared with ``src=0``.

    Returns:
        ``(r, theta_r)`` as ArithmeticSharedTensor with ``precision=0``.
    """
    per_party = [generate_random_ring_element(size) for _party in range(num_parties)]
    wrap_count = count_wraps(per_party)

    my_share = comm.get().scatter(per_party, src=0)
    return (
        ArithmeticSharedTensor.from_shares(my_share, precision=0),
        ArithmeticSharedTensor(wrap_count, precision=0, src=0),
    )
def randperm(tensor_size, encoder=None):
    """
    Generate `tensor_size[:-1]` random ArithmeticSharedTensor permutations of
    the first `tensor_size[-1]` whole numbers

    Party 0's share is requested from the TTP (``"randperm"`` request); every
    other party draws its share from the TTP client's generator. ``encoder``
    is accepted for interface compatibility but is not used here.
    """
    client = TTPClient.get()
    local_rng = client.generator

    if comm.get().get_rank() == 0:
        # Party 0 receives its share from the TTP.
        return ArithmeticSharedTensor.from_shares(
            client.ttp_request("randperm", tensor_size)
        )

    noise = generate_random_ring_element(tensor_size, generator=local_rng)
    return ArithmeticSharedTensor.from_shares(noise)
def rand(*sizes, encoder=None):
    """Generate random ArithmeticSharedTensor uniform on [0, 1]

    Args:
        *sizes: Output shape, given either as separate ints (``rand(2, 3)``)
            or as a single ``torch.Size``/tuple/list as the first argument
            (``rand(torch.Size([2, 3]))``, ``rand((2, 3))``).
        encoder: Forwarded to the TTP with party 0's ``"rand"`` request; not
            otherwise used here.

    Returns:
        ArithmeticSharedTensor built from this party's share. Party 0's share
        comes from the TTP; every other party draws its share from the TTP
        client's generator.

    Raises:
        IndexError: when called with no size arguments. This happens before
            any TTP request is made.
    """
    generator = TTPClient.get().generator

    # ``*sizes`` always packs into a plain tuple, so a shape passed as a
    # sequence arrives as its first element. torch.Size is a tuple subclass,
    # so the tuple check covers it. Plain tuples/lists are unwrapped too;
    # otherwise non-zero ranks would pass a nested tuple as the shape.
    if isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])

    if comm.get().get_rank() == 0:
        # Request samples from TTP
        samples = TTPClient.get().ttp_request("rand", *sizes, encoder=encoder)
    else:
        samples = generate_random_ring_element(sizes, generator=generator)
    return ArithmeticSharedTensor.from_shares(samples)
def B2A_rng(size):
    """Return a random bit tensor shared both arithmetically and in binary.

    Every party draws its binary share ``rB`` (``bitlength=1``) from the TTP
    client's generator. Party 0 requests its arithmetic share ``rA`` from the
    TTP (``"B2A"`` request); the other parties draw ``rA`` from the same
    generator, after ``rB``.

    Returns:
        ``(rA, rB)``: an ArithmeticSharedTensor with ``precision=0`` and a
        BinarySharedTensor.
    """
    client = TTPClient.get()
    rng = client.generator

    # Binary share of the random bit is drawn first.
    bit_share = generate_kbit_random_tensor(size, bitlength=1, generator=rng)

    if comm.get().get_rank() == 0:
        arith_share = client.ttp_request("B2A", size)
    else:
        arith_share = generate_random_ring_element(size, generator=rng)

    return (
        ArithmeticSharedTensor.from_shares(arith_share, precision=0),
        BinarySharedTensor.from_shares(bit_share),
    )
def test_wraps(self):
    """Check ArithmeticSharedTensor.wraps() against count_wraps on a known sharing."""
    num_parties = int(self.world_size)
    size = (5, 5)

    # Subtracting a copy rolled along dim 0 yields per-party shares that sum
    # to zero. Party 0's share then additionally carries a random test value.
    base = generate_random_ring_element((num_parties, *size))
    zero_shares = base - base.roll(1, dims=0)
    shares = list(zero_shares.unbind(0))
    shares[0] += get_random_test_tensor(size=size, is_float=False)

    # Ground truth; this test relies on count_wraps itself being correct.
    reference = count_wraps(shares)

    # Hand each party its own share.
    my_share = comm.get().scatter(shares, 0)
    encrypted_tensor = ArithmeticSharedTensor.from_shares(my_share)
    encrypted_wraps = encrypted_tensor.wraps()

    num_matching = (encrypted_wraps.reveal() == reference).sum()
    self.assertTrue(
        num_matching == reference.nelement(),
        "%d-party wraps failed" % num_parties,
    )