def _compute_pairwise_comparisons_for_steps(input_tensor, dim, steps):
    """
    Repeatedly halve a tensor along ``dim`` by taking pairwise maxima.

    Each of the ``steps`` rounds splits the current tensor along ``dim`` into
    three slices: a first half and a second half of equal length
    (``size // 2`` each), plus a leftover slice of length ``size % 2``. The
    elementwise larger of the two halves is selected with ``crypten.where``
    and the leftover slice is concatenated back on after it.

    Args:
        input_tensor: tensor to reduce. The loop works on a ``clone()`` of it.
        dim: dimension along which elements are paired up.
        steps: number of halving rounds to perform.

    Returns:
        The tensor obtained after ``steps`` rounds of pairwise reduction.
    """
    reduced = input_tensor.clone()
    for _ in range(steps):
        half, odd = divmod(reduced.size(dim), 2)
        # The leftover slice has length 0 when the current size is even.
        first, second, leftover = reduced.split([half, half, odd], dim=dim)
        winners = crypten.where(first >= second, first, second)
        reduced = crypten.cat([winners, leftover], dim=dim)
    return reduced
def test_where(self):
    """Test that crypten.where properly conditions.

    For every combination of tensor size and ``y`` operand kind (left as a
    plaintext torch tensor, or encrypted with ``crypten.cryptensor``), the
    result of CrypTen's ``where`` is compared against ``torch.where``. Two
    conditions are exercised: a public torch bool mask passed to
    ``crypten.where``, and a private (encrypted) mask passed to the
    ``CrypTensor.where`` method.
    """
    sizes = [(10,), (5, 10), (1, 5, 10)]
    # Identity leaves y as a plaintext torch tensor; crypten.cryptensor encrypts it.
    y_types = [lambda x: x, crypten.cryptensor]

    for size, y_type in itertools.product(sizes, y_types):
        tensor1 = get_random_test_tensor(size=size, is_float=True)
        encrypted_tensor1 = crypten.cryptensor(tensor1)
        tensor2 = get_random_test_tensor(size=size, is_float=True)
        encrypted_tensor2 = y_type(tensor2)

        # NOTE(review): presumably shifts a random integer tensor into {0, 1}
        # so it can act as a condition mask — confirm against
        # get_random_test_tensor's value range.
        condition_tensor = (
            get_random_test_tensor(max_value=1, size=size, is_float=False) + 1
        )
        condition_encrypted = crypten.cryptensor(condition_tensor)
        condition_bool = condition_tensor.bool()

        reference_out = torch.where(condition_bool, tensor1, tensor2)

        encrypted_out = crypten.where(
            condition_bool, encrypted_tensor1, encrypted_tensor2
        )

        # Inspect the operand actually passed to `where`. The plaintext
        # `tensor2` is never encrypted, so checking it would always report
        # "public y", even when y_type is crypten.cryptensor.
        y_is_private = crypten.is_encrypted_tensor(encrypted_tensor2)
        y_label = "private" if y_is_private else "public"

        self._check(
            encrypted_out,
            reference_out,
            f"{y_label} y where failed with public condition",
        )

        encrypted_out = encrypted_tensor1.where(
            condition_encrypted, encrypted_tensor2
        )
        self._check(
            encrypted_out,
            reference_out,
            f"{y_label} y where failed with private condition",
        )