コード例 #1
0
 def test_matmul(self):
     """Check encrypted matrix multiplication against the plaintext result.

     The left operand is always an ``ArithmeticSharedTensor``. The right
     operand is a random matrix used either as-is (public) or wrapped in an
     ``ArithmeticSharedTensor`` (private). For a random left operand with
     ``n`` elements, matrices of shape ``(n, width)`` are tried for every
     ``width`` in ``range(2, n)``.
     """
     wrappers = (lambda x: x, ArithmeticSharedTensor)
     for wrap in wrappers:
         label = "private" if wrap is ArithmeticSharedTensor else "public"
         lhs = get_random_test_tensor(max_value=7, is_float=True)
         # NOTE(review): matmul with an (n, width) matrix assumes lhs's last
         # dimension equals its element count (e.g. a 1-D tensor) -- confirm
         # against get_random_test_tensor's default size.
         count = lhs.nelement()
         for width in range(2, count):
             plain_rhs = get_random_test_tensor(
                 max_value=7, size=(count, width), is_float=True
             )
             expected = lhs.matmul(plain_rhs)
             # Encrypt the left operand before wrapping the right one, in
             # the same order as the plaintext computation above.
             encrypted_lhs = ArithmeticSharedTensor(lhs)
             rhs = wrap(plain_rhs)
             actual = encrypted_lhs.matmul(rhs)
             self._check(
                 actual,
                 expected,
                 f"Private-{label} matrix multiplication failed",
             )
コード例 #2
0
    def test_broadcast(self):
        """Test broadcasting for element-wise arithmetic and for matmul.

        The left-hand operand is always an ``ArithmeticSharedTensor``. The
        right-hand operand is either a plain tensor (public) or an
        ``ArithmeticSharedTensor`` (private). For each kind this checks:

        * ``add``/``sub``/``mul``/``div`` on every pair of distinct shapes
          from ``arithmetic_sizes``. Every dimension is 1 or 2, so every
          pair is broadcast-compatible.
        * ``matmul`` of a ``(*batch1, m, n)`` tensor with a
          ``(*batch2, n, m)`` tensor, so that the batch dimensions
          ``batch1`` and ``batch2`` are broadcast against each other.
        """
        arithmetic_functions = ["add", "sub", "mul", "div"]
        arithmetic_sizes = [
            (),
            (1,),
            (2,),
            (1, 1),
            (1, 2),
            (2, 1),
            (2, 2),
            (1, 1, 1),
            (1, 1, 2),
            (1, 2, 1),
            (2, 1, 1),
            (2, 2, 2),
            (1, 1, 1, 1),
            (1, 1, 1, 2),
            (1, 1, 2, 1),
            (1, 2, 1, 1),
            (2, 1, 1, 1),
            (2, 2, 2, 2),
        ]
        matmul_sizes = [(1, 1), (1, 5), (5, 1), (5, 5)]
        batch_dims = [(), (1,), (5,), (1, 1), (1, 5), (5, 5)]

        for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
            # The right-hand operand is private exactly when it is wrapped
            # in an ArithmeticSharedTensor.
            private = tensor_type is ArithmeticSharedTensor
            visibility = "private" if private else "public"

            for func in arithmetic_functions:
                # ArithmeticSharedTensors can't divide by negative private
                # values (MPCTensor overrides this to allow negatives), so
                # private division is skipped. Public division is still tested.
                # NOTE(review): the public divisor is an unconstrained random
                # tensor; entries near zero may make this comparison flaky --
                # consider bounding the divisor away from zero.
                if func == "div" and private:
                    continue

                for size1, size2 in itertools.combinations(arithmetic_sizes, 2):
                    tensor1 = get_random_test_tensor(size=size1, is_float=True)
                    tensor2 = get_random_test_tensor(size=size2, is_float=True)

                    encrypted1 = ArithmeticSharedTensor(tensor1)
                    encrypted2 = tensor_type(tensor2)
                    reference = getattr(tensor1, func)(tensor2)
                    encrypted_out = getattr(encrypted1, func)(encrypted2)

                    self._check(
                        encrypted_out,
                        reference,
                        "%s %s broadcast failed" % (visibility, func),
                    )

            for size in matmul_sizes:
                for batch1, batch2 in itertools.combinations(batch_dims, 2):
                    size1 = (*batch1, *size)
                    size2 = (*batch2, *size)

                    tensor1 = get_random_test_tensor(size=size1, is_float=True)
                    tensor2 = get_random_test_tensor(size=size2, is_float=True)
                    # Transpose tensor2 itself (not tensor1). Its matrix dims
                    # then line up with tensor1's, and its own batch dims
                    # (batch2) are broadcast against batch1. Every pair drawn
                    # from batch_dims is broadcast-compatible.
                    tensor2 = tensor2.transpose(-2, -1)

                    encrypted1 = ArithmeticSharedTensor(tensor1)
                    encrypted2 = tensor_type(tensor2)

                    reference = tensor1.matmul(tensor2)
                    encrypted_out = encrypted1.matmul(encrypted2)
                    self._check(
                        encrypted_out,
                        reference,
                        "%s matmul broadcast failed" % visibility,
                    )