Ejemplo n.º 1
0
    def generate_additive_triple(size0,
                                 size1,
                                 op,
                                 device=None,
                                 *args,
                                 **kwargs):
        """Build the shared tensors (a, b, c) of an additive triple for `op`.

        Every party draws its shares of `a` (size `size0`) and `b` (size
        `size1`) from the TTP client's generator for `device`. The party with
        rank 0 asks the TTP for its share of `c`; every other party draws a
        random ring element whose size equals that of
        `torch.<op>(a, b, *args, **kwargs)`.

        Returns:
            A tuple `(a, b, c)` of ArithmeticSharedTensor objects, each built
            from the local share with `precision=0`.
        """
        rng = TTPClient.get().get_generator(device=device)

        # Draw order (a, then b) matters: both come from the same generator.
        a_share = generate_random_ring_element(size0,
                                               generator=rng,
                                               device=device)
        b_share = generate_random_ring_element(size1,
                                               generator=rng,
                                               device=device)

        if comm.get().get_rank() != 0:
            # TODO: Compute size without executing computation
            torch_op = getattr(torch, op)
            out_size = torch_op(a_share, b_share, *args, **kwargs).size()
            c_share = generate_random_ring_element(out_size,
                                                   generator=rng,
                                                   device=device)
        else:
            # Rank 0 obtains its share of c from the TTP
            c_share = TTPClient.get().ttp_request("additive", device, size0,
                                                  size1, op, *args, **kwargs)

        return tuple(
            ArithmeticSharedTensor.from_shares(share, precision=0)
            for share in (a_share, b_share, c_share)
        )
Ejemplo n.º 2
0
    def wrap_rng(size):
        """Return a random shared tensor of `size` and a sharing of its wraps.

        Each party draws its share of `r` from the TTP client's generator.
        Rank 0 requests its share of the wrap count from the TTP; every other
        rank draws a second random ring element from the same generator.

        Returns:
            A tuple `(r, theta_r)` of ArithmeticSharedTensor objects created
            from the local shares with `precision=0`.
        """
        rng = TTPClient.get().generator
        r_share = generate_random_ring_element(size, generator=rng)

        if comm.get().get_rank() != 0:
            wraps_share = generate_random_ring_element(size, generator=rng)
        else:
            # Rank 0 obtains its theta_r share from the TTP
            wraps_share = TTPClient.get().ttp_request("wraps", size)

        return (
            ArithmeticSharedTensor.from_shares(r_share, precision=0),
            ArithmeticSharedTensor.from_shares(wraps_share, precision=0),
        )
Ejemplo n.º 3
0
    def square(size):
        """Return the shared pair `(r, r2)` of the given size.

        Each party draws its share of `r` from the TTP client's generator.
        Rank 0 requests its share of `r2` from the TTP ("square" request);
        every other rank draws a second random ring element from the same
        generator.

        Returns:
            A tuple `(r, r2)` of ArithmeticSharedTensor objects created from
            the local shares with `precision=0`.
        """
        rng = TTPClient.get().generator
        r_share = generate_random_ring_element(size, generator=rng)

        if comm.get().get_rank() != 0:
            r2_share = generate_random_ring_element(size, generator=rng)
        else:
            # Rank 0 obtains its r2 share from the TTP
            r2_share = TTPClient.get().ttp_request("square", size)

        return (
            ArithmeticSharedTensor.from_shares(r_share, precision=0),
            ArithmeticSharedTensor.from_shares(r2_share, precision=0),
        )
Ejemplo n.º 4
0
    def wrap_rng(size, num_parties):
        """Return a random shared tensor of `size` and a sharing of its wraps.

        One random ring element is generated per party, the wrap count of the
        whole list is computed with `count_wraps`, and the list is scattered
        with `src=0`. The wrap count is then shared as an
        ArithmeticSharedTensor with `src=0`.

        Returns:
            A tuple `(r, theta_r)` of ArithmeticSharedTensor objects, both
            with `precision=0`.
        """
        per_party = []
        for _ in range(num_parties):
            per_party.append(generate_random_ring_element(size))

        # Wrap count is taken over all shares before they are distributed.
        wrap_count = count_wraps(per_party)

        local_share = comm.get().scatter(per_party, src=0)
        shared_r = ArithmeticSharedTensor.from_shares(local_share, precision=0)
        shared_wraps = ArithmeticSharedTensor(wrap_count, precision=0, src=0)

        return shared_r, shared_wraps
Ejemplo n.º 5
0
 def randperm(tensor_size, encoder=None):
     """
     Produce an ArithmeticSharedTensor holding `tensor_size[:-1]` random
     permutations of the first `tensor_size[-1]` whole numbers.

     Rank 0 requests its share from the TTP ("randperm" request); every other
     rank draws a random ring element of `tensor_size` from the TTP client's
     generator. `encoder` is accepted but not used in this body.
     """
     rng = TTPClient.get().generator

     if comm.get().get_rank() != 0:
         local_share = generate_random_ring_element(tensor_size, generator=rng)
     else:
         # Rank 0 obtains its share of the samples from the TTP
         local_share = TTPClient.get().ttp_request("randperm", tensor_size)

     return ArithmeticSharedTensor.from_shares(local_share)
Ejemplo n.º 6
0
    def rand(*sizes, encoder=None):
        """Generate a random ArithmeticSharedTensor uniform on [0, 1].

        The size may be given as separate integers or as a single
        `torch.Size` in the first positional argument, which is unpacked into
        a plain tuple. Rank 0 requests its share from the TTP ("rand" request,
        forwarding `encoder`); every other rank draws a random ring element
        of that size from the TTP client's generator.
        """
        rng = TTPClient.get().generator

        # Normalise the size specification to a plain tuple of dimensions.
        if isinstance(sizes, torch.Size):
            sizes = tuple(sizes)
        leading = sizes[0]
        if isinstance(leading, torch.Size):
            sizes = tuple(leading)

        if comm.get().get_rank() != 0:
            local_share = generate_random_ring_element(sizes, generator=rng)
        else:
            # Rank 0 obtains its share of the samples from the TTP
            local_share = TTPClient.get().ttp_request("rand",
                                                      *sizes,
                                                      encoder=encoder)

        return ArithmeticSharedTensor.from_shares(local_share)
Ejemplo n.º 7
0
    def B2A_rng(size):
        """Return a random bit tensor as arithmetic and binary shared tensors.

        Every party draws a 1-bit random tensor from the TTP client's
        generator for the binary share. Rank 0 requests its arithmetic share
        from the TTP ("B2A" request); every other rank draws a random ring
        element from the same generator.

        Returns:
            A tuple `(rA, rB)`: an ArithmeticSharedTensor (`precision=0`) and
            a BinarySharedTensor built from the local shares.
        """
        rng = TTPClient.get().generator

        # The binary share is drawn first, before any arithmetic share.
        bit_share = generate_kbit_random_tensor(size,
                                                bitlength=1,
                                                generator=rng)

        if comm.get().get_rank() != 0:
            arith_share = generate_random_ring_element(size, generator=rng)
        else:
            # Rank 0 obtains its rA share from the TTP
            arith_share = TTPClient.get().ttp_request("B2A", size)

        shared_arith = ArithmeticSharedTensor.from_shares(arith_share,
                                                          precision=0)
        shared_bits = BinarySharedTensor.from_shares(bit_share)
        return shared_arith, shared_bits
Ejemplo n.º 8
0
    def test_wraps(self):
        """Check `wraps()` of a shared tensor against `count_wraps` on its shares.

        Builds one share per party, computes the reference wrap count from
        the full list of shares, scatters the shares from rank 0, and asserts
        that the revealed result of `wraps()` equals the reference in every
        element.
        """
        shape = (5, 5)
        num_parties = int(self.world_size)

        # Start from shares that cancel out (each minus its rolled neighbour),
        # then add a random test tensor to the first share so the sharing has
        # that tensor as its internal value.
        stacked = generate_random_ring_element((num_parties,) + shape)
        stacked = stacked - stacked.roll(1, dims=0)
        shares = list(stacked.unbind(0))
        shares[0] += get_random_test_tensor(size=shape, is_float=False)

        # Note: This test relies on count_wraps function being correct
        reference = count_wraps(shares)

        # Distribute one share to each party from rank 0
        local_share = comm.get().scatter(shares, 0)

        encrypted_tensor = ArithmeticSharedTensor.from_shares(local_share)
        revealed_wraps = encrypted_tensor.wraps().reveal()

        num_matching = (revealed_wraps == reference).sum()
        test_passed = num_matching == reference.nelement()
        self.assertTrue(test_passed, "%d-party wraps failed" % num_parties)