예제 #1
0
    def PRZS(*size, device=None):
        """
        Build a pseudo-random sharing of zero using arithmetic shares.

        This party draws one ring element from ``generators["prev"]`` and one
        from ``generators["next"]`` (both selected by ``device``) and keeps
        their difference as its local share. Across the protocol, `n` numbers
        are generated for `n` parties, each number held by exactly two
        parties: one adds it and the other subtracts it.

        Args:
            *size: shape of the share to generate.
            device: a ``torch.device``, a device string, or ``None`` (CPU).

        Returns:
            ArithmeticSharedTensor whose ``share`` is ``prev_draw - next_draw``.
        """
        from crypten import generators

        result = ArithmeticSharedTensor(src=SENTINEL)

        # Normalise ``device`` so it can be used as a key into ``generators``.
        if isinstance(device, str):
            device = torch.device(device)
        elif device is None:
            device = torch.device("cpu")

        prev_gen = generators["prev"][device]
        next_gen = generators["next"][device]

        added = generate_random_ring_element(*size, generator=prev_gen, device=device)
        subtracted = generate_random_ring_element(*size, generator=next_gen, device=device)
        result.share = added - subtracted
        return result
예제 #2
0
    def generate_additive_triple(size0, size1, op, device=None, *args, **kwargs):
        """
        Generate a multiplicative triple ``(a, b, c)`` of the given sizes via the TTP.

        ``a`` and ``b`` are drawn locally from the generator returned by
        ``TTPClient.get().get_generator(device=device)``. Rank 0 obtains its
        share of ``c`` with an ``"additive"`` TTP request. Every other rank
        draws its ``c`` share from the same generator, using the shape that
        ``torch.<op>(a, b, *args, **kwargs)`` produces.

        Args:
            size0: shape of ``a``.
            size1: shape of ``b``.
            op: name of the ``torch`` function combining ``a`` and ``b``.
            device: device for the generator, the local draws and the TTP request.
                Because it precedes ``*args``, a fourth positional argument
                binds here rather than being forwarded to ``op``.
            *args, **kwargs: extra arguments forwarded to ``op``.

        Returns:
            (a, b, c) as ArithmeticSharedTensors with ``precision=0``.
        """
        generator = TTPClient.get().get_generator(device=device)

        def _draw(shape):
            # One ring element from the shared TTP generator on ``device``.
            return generate_random_ring_element(shape, generator=generator, device=device)

        a_share = _draw(size0)
        b_share = _draw(size1)

        if comm.get().get_rank() != 0:
            # TODO: Compute size without executing computation
            c_share = _draw(getattr(torch, op)(a_share, b_share, *args, **kwargs).size())
        else:
            # Rank 0 requests its share of c from the TTP.
            c_share = TTPClient.get().ttp_request("additive", device, size0, size1,
                                                  op, *args, **kwargs)

        a, b, c = (
            ArithmeticSharedTensor.from_shares(share, precision=0)
            for share in (a_share, b_share, c_share)
        )
        return a, b, c
예제 #3
0
    def generate_additive_triple(size0, size1, op, *args, **kwargs):
        """
        Generate a multiplicative triple of the given sizes, shared from rank 0.

        ``a`` and ``b`` are random ring elements of shapes ``size0`` and
        ``size1``, and ``c = torch.<op>(a, b, *args, **kwargs)``. Each value is
        wrapped as an ArithmeticSharedTensor with ``precision=0`` and ``src=0``.

        Returns:
            Tuple ``(a, b, c)`` of ArithmeticSharedTensors.
        """
        a_plain = generate_random_ring_element(size0)
        b_plain = generate_random_ring_element(size1)
        c_plain = getattr(torch, op)(a_plain, b_plain, *args, **kwargs)

        return tuple(
            ArithmeticSharedTensor(plain, precision=0, src=0)
            for plain in (a_plain, b_plain, c_plain)
        )
예제 #4
0
    def PRZS(*size):
        """
        Build a pseudo-random sharing of zero using arithmetic shares.

        This party draws one ring element from ``comm.get().g0`` and one from
        ``comm.get().g1`` and stores ``g0_draw - g1_draw`` as its share. Across
        the protocol, `n` numbers are generated for `n` parties, each held by
        exactly two parties: one adds it and the other subtracts it.

        Args:
            *size: shape of the share to generate.

        Returns:
            ArithmeticSharedTensor holding this party's share of zero.
        """
        zero = ArithmeticSharedTensor(src=SENTINEL)
        added, subtracted = [
            generate_random_ring_element(*size, generator=getattr(comm.get(), attr))
            for attr in ("g0", "g1")
        ]
        zero.share = added - subtracted
        return zero
예제 #5
0
    def wrap_rng(size):
        """
        Generate a random shared tensor ``r`` and a sharing of its wraps via the TTP.

        Every rank draws its share of ``r`` from ``TTPClient.get().generator``.
        Rank 0 obtains its share of ``theta_r`` with a ``"wraps"`` TTP request.
        All other ranks draw their ``theta_r`` share from the same generator.

        Returns:
            (r, theta_r) as ArithmeticSharedTensors with ``precision=0``.
        """
        gen = TTPClient.get().generator
        r_share = generate_random_ring_element(size, generator=gen)

        theta_share = (
            TTPClient.get().ttp_request("wraps", size)
            if comm.get().get_rank() == 0
            else generate_random_ring_element(size, generator=gen)
        )

        return (
            ArithmeticSharedTensor.from_shares(r_share, precision=0),
            ArithmeticSharedTensor.from_shares(theta_share, precision=0),
        )
예제 #6
0
    def square(size):
        """
        Generate a square double ``(r, r2)`` of the given size via the TTP.

        Every rank draws its share of ``r`` from ``TTPClient.get().generator``.
        Rank 0 obtains its share of ``r2`` with a ``"square"`` TTP request.
        All other ranks draw their ``r2`` share from the same generator.

        Returns:
            (r, r2) as ArithmeticSharedTensors with ``precision=0``.
        """
        gen = TTPClient.get().generator
        r_share = generate_random_ring_element(size, generator=gen)

        if comm.get().get_rank() != 0:
            r2_share = generate_random_ring_element(size, generator=gen)
        else:
            # Rank 0's share of r2 comes from the TTP.
            r2_share = TTPClient.get().ttp_request("square", size)

        r, r2 = (
            ArithmeticSharedTensor.from_shares(share, precision=0)
            for share in (r_share, r2_share)
        )
        return r, r2
예제 #7
0
 def PRSS(*size, device=None):
     """
     Wrap a freshly drawn random ring element as this party's arithmetic share.

     Args:
         *size: shape of the share.
         device: forwarded to ``generate_random_ring_element``.

     Returns:
         ArithmeticSharedTensor built from the random local share.
     """
     return ArithmeticSharedTensor.from_shares(
         share=generate_random_ring_element(*size, device=device)
     )
예제 #8
0
    def wraps(self, size):
        """
        Return ``count_wraps`` of per-generator random draws, minus an additive PRSS.

        One ring element of shape ``size`` is drawn from each generator in
        ``self.generators``. ``count_wraps`` is applied to that list, and
        ``self._get_additive_PRSS(size, remove_rank=True)`` is subtracted from
        the result.
        """
        samples = [
            generate_random_ring_element(size, generator=gen)
            for gen in self.generators
        ]
        # Left operand is evaluated first, so count_wraps still runs before the
        # PRSS draws are made.
        return count_wraps(samples) - self._get_additive_PRSS(size, remove_rank=True)
예제 #9
0
파일: test_nn.py 프로젝트: nthparty/CrypTen
 def _generate_parameters(size):
     """
     Build per-party additive shares of a random integer test tensor.

     One random ring element is drawn per party. A copy rolled by one position
     along the party axis is subtracted, so the rows sum to zero. The reference
     tensor is then added to party 0's row.

     NOTE(review): ``self`` is not a parameter; it is presumably captured from
     an enclosing test method — confirm.

     Returns:
         (shares, reference): list of per-party share tensors and the plaintext
         tensor they sum to.
     """
     parties = int(self.world_size)
     reference = get_random_test_tensor(size=size, is_float=False)
     noise = generate_random_ring_element((parties, *size))
     masks = noise - noise.roll(1, dims=0)
     shares = list(masks.unbind(0))
     shares[0] += reference
     return shares, reference
예제 #10
0
 def _get_additive_PRSS(self, size, remove_rank=False):
     """
     Sum one random ring element drawn from each of ``self.generators``.

     Args:
         size: shape of each draw and of the result.
         remove_rank: when True, the first generator is skipped.

     Returns:
         Element-wise sum of the draws (a plaintext tensor).
     """
     start = 1 if remove_rank else 0
     samples = [
         generate_random_ring_element(size, generator=gen)
         for gen in self.generators[start:]
     ]
     return torch.stack(samples).sum(0)
예제 #11
0
    def square(size, device=None):
        """
        Generate a square double ``(r, r*r)`` of the given size, shared from rank 0.

        ``r`` is a random ring element on ``device``. ``r`` and ``r*r`` are
        stacked so that a single ArithmeticSharedTensor construction
        (``precision=0``, ``src=0``) shares both at once.

        Returns:
            (r, r2) as the two slices of the shared stacked tensor.
        """
        base = generate_random_ring_element(size, device=device)
        # Stack to vectorize scatter function
        shared = ArithmeticSharedTensor(
            torch_stack([base, base.mul(base)]), precision=0, src=0
        )
        first, second = shared[0], shared[1]
        return first, second
예제 #12
0
    def wrap_rng(size, num_parties):
        """
        Generate a random shared tensor and a sharing of its wraps.

        ``num_parties`` random ring elements are drawn, and their wrap count
        is computed with ``count_wraps``. The elements are scattered from rank
        0 as the shares of ``r``, and the wrap count is shared from rank 0.

        Returns:
            (r, theta_r) as ArithmeticSharedTensors with ``precision=0``.
        """
        per_party = [generate_random_ring_element(size) for _party in range(num_parties)]
        wrap_counts = count_wraps(per_party)

        local_share = comm.get().scatter(per_party, src=0)
        shared_r = ArithmeticSharedTensor.from_shares(local_share, precision=0)
        shared_wraps = ArithmeticSharedTensor(wrap_counts, precision=0, src=0)
        return shared_r, shared_wraps
예제 #13
0
 def randperm(tensor_size, encoder=None):
     """
     Generate random permutations as an ArithmeticSharedTensor via the TTP.

     Produces ``tensor_size[:-1]`` permutations of the first ``tensor_size[-1]``
     whole numbers. Rank 0 obtains its share with a ``"randperm"`` TTP request.
     Other ranks draw their share from ``TTPClient.get().generator``.

     Args:
         tensor_size: full output shape; the last dimension is the permutation length.
         encoder: accepted for interface compatibility; not referenced here.

     Returns:
         ArithmeticSharedTensor built from this party's share.
     """
     gen = TTPClient.get().generator
     if comm.get().get_rank() != 0:
         samples = generate_random_ring_element(tensor_size, generator=gen)
     else:
         # Rank 0's share comes from the TTP.
         samples = TTPClient.get().ttp_request("randperm", tensor_size)
     return ArithmeticSharedTensor.from_shares(samples)
예제 #14
0
 def _get_additive_PRSS(self, size, remove_rank=False):
     """
     Sum one random ring element drawn from each generator for ``self.device``.

     Each draw is placed on its generator's own ``device``.

     Args:
         size: shape of each draw and of the result.
         remove_rank: when True, the first generator is skipped.

     Returns:
         Running sum of the draws, or ``None`` if there are no generators.
     """
     generators = self._get_generators(device=self.device)
     active = generators[1:] if remove_rank else generators

     total = None
     for gen in active:
         sample = generate_random_ring_element(size, generator=gen, device=gen.device)
         if total is None:
             total = sample
         else:
             total = total + sample
     return total
예제 #15
0
 def _get_additive_PRSS(self, size, remove_rank=False):
     """
     Sum one random ring element drawn from each generator for ``self.device``.

     Each draw is placed on its generator's own ``device``. The draws are
     stacked and reduced along the new leading dimension.

     Args:
         size: shape of each draw and of the result.
         remove_rank: when True, the first generator is skipped.

     Returns:
         Element-wise sum of the draws (a plaintext tensor).
     """
     generators = self._get_generators(device=self.device)
     first = 1 if remove_rank else 0
     samples = [
         generate_random_ring_element(size, generator=gen, device=gen.device)
         for gen in generators[first:]
     ]
     return torch_stack(samples).sum(0)
예제 #16
0
    def rand(*sizes, encoder=None):
        """
        Generate a random ArithmeticSharedTensor uniform on [0, 1] via the TTP.

        Rank 0 obtains its share with a ``"rand"`` TTP request, forwarding
        ``encoder``. Every other rank draws its share from
        ``TTPClient.get().generator``.

        Args:
            *sizes: the output shape. Pass either separate ints
                (``rand(2, 3)``) or a single ``torch.Size``/tuple/list
                (``rand((2, 3))``). When a shape object is given first, any
                further positional arguments are ignored.
            encoder: forwarded to the TTP request on rank 0.

        Returns:
            ArithmeticSharedTensor built from this party's share.

        Raises:
            IndexError: if called with no size arguments.
        """
        generator = TTPClient.get().generator

        # ``*sizes`` always arrives as a plain tuple, never a torch.Size, so the
        # only case that needs normalising is a shape object passed as the first
        # argument. Every rank normalises it the same way, so rank 0's TTP request
        # and the other ranks' local draws agree on the shape.
        if isinstance(sizes[0], (torch.Size, tuple, list)):
            sizes = tuple(sizes[0])

        if comm.get().get_rank() == 0:
            # Request samples from TTP
            samples = TTPClient.get().ttp_request("rand", *sizes, encoder=encoder)
        else:
            samples = generate_random_ring_element(sizes, generator=generator)
        return ArithmeticSharedTensor.from_shares(samples)
예제 #17
0
    def B2A_rng(size):
        """
        Generate a random bit as both arithmetic and binary shared tensors.

        Every rank draws a 1-bit share from ``TTPClient.get().generator`` for
        the binary sharing. Rank 0 obtains its arithmetic share with a
        ``"B2A"`` TTP request. Other ranks draw their arithmetic share from the
        same generator.

        Returns:
            (rA, rB): the ArithmeticSharedTensor (``precision=0``) and the
            BinarySharedTensor.
        """
        gen = TTPClient.get().generator
        bit_share = generate_kbit_random_tensor(size, bitlength=1, generator=gen)

        if comm.get().get_rank() != 0:
            arith_share = generate_random_ring_element(size, generator=gen)
        else:
            # Rank 0's arithmetic share comes from the TTP.
            arith_share = TTPClient.get().ttp_request("B2A", size)

        arithmetic = ArithmeticSharedTensor.from_shares(arith_share, precision=0)
        binary = BinarySharedTensor.from_shares(bit_share)
        return arithmetic, binary
예제 #18
0
    def test_wraps(self):
        """Check ArithmeticSharedTensor.wraps() against count_wraps on the plaintext shares."""
        num_parties = int(self.world_size)
        size = (5, 5)

        # Build shares of a random integer tensor. Start from per-party random
        # values and subtract a copy rolled along the party axis, so the shares
        # sum to zero. Then add the plaintext to party 0's share.
        noise = generate_random_ring_element((num_parties, *size))
        shares = list((noise - noise.roll(1, dims=0)).unbind(0))
        shares[0] += get_random_test_tensor(size=size, is_float=False)

        # Note: This test relies on count_wraps function being correct
        expected = count_wraps(shares)

        # Sync shares between parties (scattered from rank 0)
        local_share = comm.get().scatter(shares, 0)
        wraps_result = ArithmeticSharedTensor.from_shares(local_share).wraps()

        matches = (wraps_result.reveal() == expected).sum()
        self.assertTrue(matches == expected.nelement(),
                        "%d-party wraps failed" % num_parties)