def test_index_add(self):
    """Test index_add function of encrypted tensor"""
    index_add_functions = ["index_add", "index_add_"]
    base_size = [5, 5, 5, 5]
    index = torch.tensor([1, 2, 3, 4, 4, 2, 1, 3], dtype=torch.long)
    for dimension in range(4):
        # The added tensor must match index length along `dimension`.
        added_size = list(base_size)
        added_size[dimension] = index.size(0)
        for func in index_add_functions:
            for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
                private = tensor_type == ArithmeticSharedTensor
                kind = "private" if private else "public"

                tensor1 = get_random_test_tensor(size=base_size, is_float=True)
                tensor2 = get_random_test_tensor(size=added_size, is_float=True)
                encrypted = ArithmeticSharedTensor(tensor1)
                encrypted2 = tensor_type(tensor2)

                reference = getattr(tensor1, func)(dimension, index, tensor2)
                encrypted_out = getattr(encrypted, func)(
                    dimension, index, encrypted2
                )
                self._check(
                    encrypted_out, reference, "%s %s failed" % (kind, func)
                )
                if func.endswith("_"):
                    # Check in-place index_add worked
                    self._check(
                        encrypted, reference, "%s %s failed" % (kind, func)
                    )
                else:
                    # Check original is not modified
                    self._check(
                        encrypted, tensor1, "%s %s failed" % (kind, func)
                    )
def test_prod(self):
    """Tests prod reduction on encrypted tensor."""
    # Full reduction over all elements.
    tensor = get_random_test_tensor(size=(3, 3), max_value=3, is_float=False)
    encrypted = ArithmeticSharedTensor(tensor)
    self._check(encrypted.prod(), tensor.prod().float(), "prod failed")

    # Reduction along each dimension of a 3-d tensor.
    tensor = get_random_test_tensor(size=(5, 5, 5), max_value=3, is_float=False)
    encrypted = ArithmeticSharedTensor(tensor)
    for dim in range(3):
        self._check(
            encrypted.prod(dim), tensor.prod(dim).float(), "prod failed"
        )
def test_squeeze(self):
    """Tests squeeze/unsqueeze on encrypted tensors."""
    tensor = get_random_test_tensor(is_float=True)
    for dim in range(3):
        # Test unsqueeze
        reference = tensor.unsqueeze(dim)
        encrypted = ArithmeticSharedTensor(tensor)
        unsqueezed = encrypted.unsqueeze(dim)
        self._check(unsqueezed, reference, "unsqueeze failed")

        # Test squeeze
        encrypted = ArithmeticSharedTensor(tensor.unsqueeze(0))
        squeezed = encrypted.squeeze()
        self._check(squeezed, reference.squeeze(), "squeeze failed")

        # Writing through the squeezed result should be visible through the
        # original handle, i.e. both reference the same underlying share.
        squeezed[0:2] = torch.tensor([0.0, 1.0], dtype=torch.float)
        plain = encrypted.squeeze().get_plain_text()
        self._check(squeezed, plain, "squeeze failed")
def test_scatter(self):
    """Test scatter/scatter_add function of encrypted tensor"""
    funcs = ["scatter", "scatter_", "scatter_add", "scatter_add_"]
    sizes = [(5, 5), (5, 5, 5), (5, 5, 5, 5)]
    for func in funcs:
        for size in sizes:
            for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
                private = tensor_type == ArithmeticSharedTensor
                kind = "private" if private else "public"
                for dim in range(len(size)):
                    tensor1 = get_random_test_tensor(size=size, is_float=True)
                    tensor2 = get_random_test_tensor(size=size, is_float=True)
                    # Every dimension has extent 5, so valid indices are 0..4.
                    index = get_random_test_tensor(size=size, is_float=False)
                    index = index.abs().clamp(0, 4)

                    encrypted = ArithmeticSharedTensor(tensor1)
                    encrypted2 = tensor_type(tensor2)
                    reference = getattr(tensor1, func)(dim, index, tensor2)
                    encrypted_out = getattr(encrypted, func)(
                        dim, index, encrypted2
                    )
                    self._check(
                        encrypted_out,
                        reference,
                        "%s %s failed" % (kind, func),
                    )
                    if func.endswith("_"):
                        # Check in-place scatter/scatter-add modified input
                        self._check(
                            encrypted,
                            reference,
                            "%s %s failed to modify input" % (kind, func),
                        )
                    else:
                        # Check original is not modified
                        self._check(
                            encrypted,
                            tensor1,
                            "%s %s unintendedly modified input" % (kind, func),
                        )
def _conv1d(self, signal_size, in_channels):
    """Test convolution of encrypted tensor with public/private tensors."""
    nbatches = [1, 3]
    kernel_sizes = [1, 2, 3]
    ochannels = [1, 3, 6]
    paddings = [0, 1]
    strides = [1, 2]

    for func_name in ["conv1d", "conv_transpose1d"]:
        for kernel_type in [lambda x: x, ArithmeticSharedTensor]:
            for config in itertools.product(
                nbatches, kernel_sizes, ochannels, paddings, strides
            ):
                batches, kernel_size, out_channels, padding, stride = config

                signal_shape = (batches, in_channels, signal_size)
                signal = get_random_test_tensor(size=signal_shape, is_float=True)

                # conv kernels are (out, in, k); transposed convolutions
                # swap the channel axes to (in, out, k).
                if func_name == "conv1d":
                    kernel_shape = (out_channels, in_channels, kernel_size)
                else:
                    kernel_shape = (in_channels, out_channels, kernel_size)
                kernel = get_random_test_tensor(size=kernel_shape, is_float=True)

                reference = getattr(F, func_name)(
                    signal, kernel, padding=padding, stride=stride
                )
                encrypted_signal = ArithmeticSharedTensor(signal)
                encrypted_kernel = kernel_type(kernel)
                encrypted_conv = getattr(encrypted_signal, func_name)(
                    encrypted_kernel, padding=padding, stride=stride
                )
                self._check(encrypted_conv, reference, f"{func_name} failed")
def test_wraps(self):
    """Tests that ArithmeticSharedTensor.wraps() matches a plaintext
    count_wraps() over the same set of shares."""
    num_parties = int(self.world_size)
    size = (5, 5)

    # Generate random sharing with internal value get_random_test_tensor():
    # subtracting the cyclically-rolled tensor makes the rows sum to zero
    # over the ring, then the actual secret is folded into party 0's share.
    zero_shares = generate_random_ring_element((num_parties, *size))
    zero_shares = zero_shares - zero_shares.roll(1, dims=0)
    shares = list(zero_shares.unbind(0))
    shares[0] += get_random_test_tensor(size=size, is_float=False)

    # Note: This test relies on count_wraps function being correct
    reference = count_wraps(shares)

    # Sync shares between parties: each party receives its own share
    # from rank 0 via scatter.
    share = comm.get().scatter(shares, 0)
    encrypted_tensor = ArithmeticSharedTensor.from_shares(share)
    encrypted_wraps = encrypted_tensor.wraps()

    # NOTE(review): reveal() presumably opens the shared wrap count to all
    # parties so it can be compared elementwise — confirm against the
    # ArithmeticSharedTensor API.
    test_passed = (
        encrypted_wraps.reveal() == reference
    ).sum() == reference.nelement()
    self.assertTrue(test_passed, "%d-party wraps failed" % num_parties)
def rand(*sizes, device=None):
    """Generate random ArithmeticSharedTensor uniform on [0, 1]"""
    # Party 0 samples the plaintext; it is then secret-shared from src=0.
    return ArithmeticSharedTensor(torch.rand(*sizes, device=device), src=0)
def test_src_failure(self):
    """Tests that out-of-bounds src fails as expected"""
    tensor = get_random_test_tensor(is_float=True)
    # None/non-int, negative, and >= world_size are all invalid sources.
    invalid_sources = [None, "abc", -2, self.world_size]
    for src in invalid_sources:
        with self.assertRaises(AssertionError):
            ArithmeticSharedTensor(tensor, src=src)
def test_broadcast(self):
    """Test broadcast functionality.

    Covers (a) elementwise add/sub/mul/div between every pair of distinct
    broadcastable shapes and (b) batched matmul between every pair of
    distinct batch-dimension prefixes, for both public and private
    right-hand operands.
    """
    arithmetic_functions = ["add", "sub", "mul", "div"]
    arithmetic_sizes = [
        (),
        (1,),
        (2,),
        (1, 1),
        (1, 2),
        (2, 1),
        (2, 2),
        (1, 1, 1),
        (1, 1, 2),
        (1, 2, 1),
        (2, 1, 1),
        (2, 2, 2),
        (1, 1, 1, 1),
        (1, 1, 1, 2),
        (1, 1, 2, 1),
        (1, 2, 1, 1),
        (2, 1, 1, 1),
        (2, 2, 2, 2),
    ]
    matmul_sizes = [(1, 1), (1, 5), (5, 1), (5, 5)]
    batch_dims = [(), (1,), (5,), (1, 1), (1, 5), (5, 5)]

    for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
        for func in arithmetic_functions:
            for size1, size2 in itertools.combinations(arithmetic_sizes, 2):
                # ArithmeticSharedTensors can't divide by negative
                # private values - MPCTensor overrides this to allow negatives
                if func == "div" and tensor_type == ArithmeticSharedTensor:
                    continue

                tensor1 = get_random_test_tensor(size=size1, is_float=True)
                tensor2 = get_random_test_tensor(size=size2, is_float=True)
                encrypted1 = ArithmeticSharedTensor(tensor1)
                encrypted2 = tensor_type(tensor2)
                reference = getattr(tensor1, func)(tensor2)
                encrypted_out = getattr(encrypted1, func)(encrypted2)

                private = isinstance(encrypted2, ArithmeticSharedTensor)
                self._check(
                    encrypted_out,
                    reference,
                    "%s %s broadcast failed"
                    % ("private" if private else "public", func),
                )

        for size in matmul_sizes:
            for batch1, batch2 in itertools.combinations(batch_dims, 2):
                size1 = (*batch1, *size)
                size2 = (*batch2, *size)

                tensor1 = get_random_test_tensor(size=size1, is_float=True)
                tensor2 = get_random_test_tensor(size=size2, is_float=True)
                # Transpose the trailing two dims of tensor2 so the inner
                # dimensions line up for matmul. (Bug fix: this previously
                # assigned `tensor1.transpose(-2, -1)`, which discarded the
                # freshly-sampled tensor2 and therefore never exercised
                # broadcasting across the batch2 batch dimensions.)
                tensor2 = tensor2.transpose(-2, -1)

                encrypted1 = ArithmeticSharedTensor(tensor1)
                encrypted2 = tensor_type(tensor2)
                reference = tensor1.matmul(tensor2)
                encrypted_out = encrypted1.matmul(encrypted2)
                private = isinstance(encrypted2, ArithmeticSharedTensor)
                self._check(
                    encrypted_out,
                    reference,
                    "%s matmul broadcast failed"
                    % ("private" if private else "public"),
                )
def _conv2d(self, image_size, in_channels):
    """Test convolution of encrypted tensor with public/private tensors."""
    nbatches = [1, 3]
    kernel_sizes = [(1, 1), (2, 2), (2, 3)]
    ochannels = [1, 3, 6]
    paddings = [0, 1, (0, 1)]
    strides = [1, 2, (1, 2)]
    dilations = [1, 2]
    groupings = [1, 2]

    for func_name in ["conv2d", "conv_transpose2d"]:
        for kernel_type in [lambda x: x, ArithmeticSharedTensor]:
            for config in itertools.product(
                nbatches,
                kernel_sizes,
                ochannels,
                paddings,
                strides,
                dilations,
                groupings,
            ):
                (
                    batches,
                    kernel_size,
                    out_channels,
                    padding,
                    stride,
                    dilation,
                    groups,
                ) = config

                # sample input:
                input_size = (batches, in_channels * groups, *image_size)
                input = get_random_test_tensor(size=input_size, is_float=True)

                # sample filtering kernel (conv and transposed conv lay out
                # the channel axes differently):
                if func_name == "conv2d":
                    k_size = (out_channels * groups, in_channels, *kernel_size)
                else:
                    k_size = (in_channels * groups, out_channels, *kernel_size)
                kernel = get_random_test_tensor(size=k_size, is_float=True)

                # perform filtering:
                encr_matrix = ArithmeticSharedTensor(input)
                encr_kernel = kernel_type(kernel)
                encr_conv = getattr(encr_matrix, func_name)(
                    encr_kernel,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                )

                # check that result matches the plaintext convolution:
                reference = getattr(F, func_name)(
                    input,
                    kernel,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                )
                self._check(encr_conv, reference, "%s failed" % func_name)
def test_arithmetic(self):
    """Tests arithmetic functions on encrypted tensor."""
    arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"]
    for func in arithmetic_functions:
        for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
            private_type = tensor_type == ArithmeticSharedTensor
            kind = "private" if private_type else "public"

            tensor1 = get_random_test_tensor(is_float=True)
            tensor2 = get_random_test_tensor(is_float=True)
            encrypted = ArithmeticSharedTensor(tensor1)
            encrypted2 = tensor_type(tensor2)

            reference = getattr(tensor1, func)(tensor2)
            encrypted_out = getattr(encrypted, func)(encrypted2)
            self._check(
                encrypted_out, reference, "%s %s failed" % (kind, func)
            )
            if "_" in func:
                # Check in-place op worked
                self._check(
                    encrypted, reference, "%s %s failed" % (kind, func)
                )
            else:
                # Check original is not modified
                self._check(
                    encrypted, tensor1, "%s %s failed" % (kind, func)
                )

        # Check encrypted vector with encrypted scalar works.
        tensor1 = get_random_test_tensor(is_float=True)
        tensor2 = get_random_test_tensor(is_float=True, size=(1,))
        encrypted1 = ArithmeticSharedTensor(tensor1)
        encrypted2 = ArithmeticSharedTensor(tensor2)
        reference = getattr(tensor1, func)(tensor2)
        encrypted_out = getattr(encrypted1, func)(encrypted2)
        self._check(encrypted_out, reference, "private %s failed" % func)

    # square() should behave like elementwise self-multiplication.
    tensor = get_random_test_tensor(is_float=True)
    reference = tensor * tensor
    encrypted = ArithmeticSharedTensor(tensor)
    encrypted_out = encrypted.square()
    self._check(encrypted_out, reference, "square failed")

    # Test radd, rsub, and rmul (public scalar on the left-hand side).
    reference = 2 + tensor1
    encrypted = ArithmeticSharedTensor(tensor1)
    encrypted_out = 2 + encrypted
    self._check(encrypted_out, reference, "right add failed")

    reference = 2 - tensor1
    encrypted_out = 2 - encrypted
    self._check(encrypted_out, reference, "right sub failed")

    reference = 2 * tensor1
    encrypted_out = 2 * encrypted
    self._check(encrypted_out, reference, "right mul failed")
def bernoulli(tensor):
    """Generate random ArithmeticSharedTensor bernoulli on {0, 1}"""
    # Party 0 draws the Bernoulli samples; they are secret-shared from src=0.
    return ArithmeticSharedTensor(torch.bernoulli(tensor), src=0)