def generate_binary_triple(size0, size1, device=None):
    """Produce a binary (XOR-shared) triple ``(a, b, c)`` for this party.

    The ``a`` and ``b`` shares are drawn from the generator returned by
    ``TTPClient.get().get_generator(device=device)``. Rank 0 obtains its
    ``c`` share from the TTP via a ``"binary"`` request. Every other rank
    draws its ``c`` share from the same generator, using the broadcast shape
    of ``a`` and ``b``.

    Args:
        size0: shape of the ``a`` share.
        size1: shape of the ``b`` share.
        device: optional device, forwarded to the generator lookup, the random
            draws, and the TTP request.

    Returns:
        A 3-tuple of ``BinarySharedTensor`` objects wrapping the local shares
        of ``a``, ``b`` and ``c``.
    """
    gen = TTPClient.get().get_generator(device=device)
    share_a = generate_kbit_random_tensor(size0, generator=gen, device=device)
    share_b = generate_kbit_random_tensor(size1, generator=gen, device=device)

    if comm.get().get_rank() != 0:
        # Non-TTP-served ranks derive their c share locally; its shape is the
        # broadcast of the a and b shapes.
        out_size = torch.broadcast_tensors(share_a, share_b)[0].size()
        share_c = generate_kbit_random_tensor(out_size, generator=gen, device=device)
    else:
        # Rank 0's c share is supplied by the TTP server.
        share_c = TTPClient.get().ttp_request("binary", device, size0, size1)

    return tuple(
        map(BinarySharedTensor.from_shares, (share_a, share_b, share_c))
    )
def generate_binary_triple(size):
    """Produce a binary (XOR-shared) triple ``(a, b, c)`` of shape ``size``.

    The ``a`` and ``b`` shares are drawn from ``TTPClient.get().generator``.
    Rank 0 obtains its ``c`` share from the TTP via a ``"binary"`` request.
    Every other rank draws its ``c`` share from the same generator.

    NOTE(review): this definition reuses the name of the earlier
    ``generate_binary_triple(size0, size1, device=None)`` in this file and
    therefore replaces it at import time. The two use different TTPClient
    APIs (``.generator`` vs ``get_generator``) and request formats. Confirm
    which version is intended to survive.

    Args:
        size: shape shared by all three triple components.

    Returns:
        A 3-tuple of ``BinarySharedTensor`` objects wrapping the local shares
        of ``a``, ``b`` and ``c``.
    """
    gen = TTPClient.get().generator
    share_a = generate_kbit_random_tensor(size, generator=gen)
    share_b = generate_kbit_random_tensor(size, generator=gen)

    if comm.get().get_rank() != 0:
        share_c = generate_kbit_random_tensor(size, generator=gen)
    else:
        # Rank 0's c share is supplied by the TTP server.
        share_c = TTPClient.get().ttp_request("binary", size)

    return tuple(
        map(BinarySharedTensor.from_shares, (share_a, share_b, share_c))
    )
def B2A_rng(size):
    """Produce a random bit shared both arithmetically and in binary.

    The binary share is a 1-bit random tensor drawn from
    ``TTPClient.get().generator``. Rank 0 obtains its arithmetic share from
    the TTP via a ``"B2A"`` request. Every other rank draws a random ring
    element from the same generator.

    Args:
        size: shape of the random-bit tensor.

    Returns:
        ``(rA, rB)``: an ``ArithmeticSharedTensor`` (built with
        ``precision=0``) and a ``BinarySharedTensor`` wrapping the local shares.
    """
    gen = TTPClient.get().generator
    bit_share = generate_kbit_random_tensor(size, bitlength=1, generator=gen)

    if comm.get().get_rank() != 0:
        ring_share = generate_random_ring_element(size, generator=gen)
    else:
        # Rank 0's arithmetic share is supplied by the TTP server.
        ring_share = TTPClient.get().ttp_request("B2A", size)

    arith = ArithmeticSharedTensor.from_shares(ring_share, precision=0)
    binary = BinarySharedTensor.from_shares(bit_share)
    return arith, binary