def test_matmul(self):
    """Test matrix multiplication.

    Multiplies a random float tensor by random ``(nelement, width)`` matrices,
    for every width in ``range(2, nelement)``. The left operand is always an
    ``ArithmeticSharedTensor``. The right operand is passed either as a plain
    tensor (public) or wrapped in an ``ArithmeticSharedTensor`` (private).
    Each result is compared against the plaintext ``matmul``.
    """
    wrappers = (lambda x: x, ArithmeticSharedTensor)
    for wrap in wrappers:
        # Label the right-hand operand's privacy for the failure message.
        is_private = wrap == ArithmeticSharedTensor
        label = "private" if is_private else "public"

        lhs_plain = get_random_test_tensor(max_value=7, is_float=True)
        length = lhs_plain.nelement()
        for width in range(2, length):
            rhs_plain = get_random_test_tensor(
                max_value=7, size=(length, width), is_float=True
            )
            expected = lhs_plain.matmul(rhs_plain)
            # Left operand is shared first, then the right operand is wrapped.
            # This is the same construction order as a two-step assignment.
            actual = ArithmeticSharedTensor(lhs_plain).matmul(wrap(rhs_plain))
            self._check(
                actual,
                expected,
                "Private-%s matrix multiplication failed" % label,
            )
def test_broadcast(self):
    """Test broadcast functionality.

    The first operand is always an ``ArithmeticSharedTensor``. The second
    operand is either a plain tensor (public) or an ``ArithmeticSharedTensor``
    (private). Two kinds of operation are checked against their plaintext
    reference:

    * the element-wise ops ``add``/``sub``/``mul``/``div``, over every
      unordered pair of shapes in ``arithmetic_sizes``;
    * ``matmul``, over every unordered pair of batch prefixes in
      ``batch_dims``, with each matrix shape taken from ``matmul_sizes``.
    """
    arithmetic_functions = ["add", "sub", "mul", "div"]
    arithmetic_sizes = [
        (),
        (1,),
        (2,),
        (1, 1),
        (1, 2),
        (2, 1),
        (2, 2),
        (1, 1, 1),
        (1, 1, 2),
        (1, 2, 1),
        (2, 1, 1),
        (2, 2, 2),
        (1, 1, 1, 1),
        (1, 1, 1, 2),
        (1, 1, 2, 1),
        (1, 2, 1, 1),
        (2, 1, 1, 1),
        (2, 2, 2, 2),
    ]
    matmul_sizes = [(1, 1), (1, 5), (5, 1), (5, 5)]
    batch_dims = [(), (1,), (5,), (1, 1), (1, 5), (5, 5)]

    for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
        for func in arithmetic_functions:
            # Private division is skipped: ArithmeticSharedTensors can't
            # divide by negative private values (MPCTensor overrides this to
            # allow negatives). The skip is checked up front so no random
            # tensors are generated only to be discarded.
            if func == "div" and tensor_type == ArithmeticSharedTensor:
                continue
            for size1, size2 in itertools.combinations(arithmetic_sizes, 2):
                tensor1 = get_random_test_tensor(size=size1, is_float=True)
                tensor2 = get_random_test_tensor(size=size2, is_float=True)

                encrypted1 = ArithmeticSharedTensor(tensor1)
                encrypted2 = tensor_type(tensor2)

                reference = getattr(tensor1, func)(tensor2)
                encrypted_out = getattr(encrypted1, func)(encrypted2)

                private = isinstance(encrypted2, ArithmeticSharedTensor)
                self._check(
                    encrypted_out,
                    reference,
                    "%s %s broadcast failed"
                    % ("private" if private else "public", func),
                )

        for size in matmul_sizes:
            for batch1, batch2 in itertools.combinations(batch_dims, 2):
                size1 = (*batch1, *size)
                size2 = (*batch2, *size)

                tensor1 = get_random_test_tensor(size=size1, is_float=True)
                tensor2 = get_random_test_tensor(size=size2, is_float=True)
                # Transpose the second operand's trailing matrix dims so the
                # inner dimensions agree:
                #   (*batch1, m, n) @ (*batch2, n, m)
                # batch1 and batch2 are then broadcast against each other.
                # Transposing tensor1 instead would discard tensor2 and batch2
                # and never exercise batch broadcasting.
                tensor2 = tensor2.transpose(-2, -1)

                encrypted1 = ArithmeticSharedTensor(tensor1)
                encrypted2 = tensor_type(tensor2)

                reference = tensor1.matmul(tensor2)
                encrypted_out = encrypted1.matmul(encrypted2)

                private = isinstance(encrypted2, ArithmeticSharedTensor)
                self._check(
                    encrypted_out,
                    reference,
                    "%s matmul broadcast failed"
                    % ("private" if private else "public"),
                )