Exemple #1
0
def _compute_pairwise_comparisons_for_steps(input_tensor, dim, steps):
    """
    Reduce `input_tensor` along `dim` by repeated pairwise maxima.

    Each of the `steps` rounds splits the current tensor along `dim` into
    two equally sized halves plus a leftover slice (non-empty only when the
    length is odd), keeps the elementwise larger of the two halves, and
    appends the leftover unchanged. With `steps == 0` a clone of the input
    is returned as-is.
    """
    reduced = input_tensor.clone()
    for _ in range(steps):
        # half: length of each compared slice; odd: 1 if one element is left over.
        half, odd = divmod(reduced.size(dim), 2)
        left, right, leftover = reduced.split([half, half, odd], dim=dim)
        larger = crypten.where(left >= right, left, right)
        reduced = crypten.cat([larger, leftover], dim=dim)
    return reduced
    def test_where(self):
        """Test that ``where`` selects elementwise between two inputs.

        For several tensor sizes, and with ``y`` supplied both as a plain
        tensor and as a cryptensor, the output of ``crypten.where`` (public
        boolean condition) and of the encrypted tensor's ``where`` method
        (encrypted condition) is compared against ``torch.where`` through
        ``self._check``.
        """
        sizes = [(10,), (5, 10), (1, 5, 10)]
        # y is either passed through unchanged (public) or wrapped as a
        # cryptensor (private).
        y_types = [lambda x: x, crypten.cryptensor]

        for size, y_type in itertools.product(sizes, y_types):
            tensor1 = get_random_test_tensor(size=size, is_float=True)
            encrypted_tensor1 = crypten.cryptensor(tensor1)
            tensor2 = get_random_test_tensor(size=size, is_float=True)
            encrypted_tensor2 = y_type(tensor2)

            # NOTE(review): the `+ 1` presumably shifts the random integers
            # into {0, 1}; confirm against get_random_test_tensor's range.
            condition_tensor = (
                get_random_test_tensor(max_value=1, size=size, is_float=False) + 1
            )
            condition_encrypted = crypten.cryptensor(condition_tensor)
            condition_bool = condition_tensor.bool()

            reference_out = torch.where(condition_bool, tensor1, tensor2)

            # Inspect the value actually handed to `where`: `tensor2` itself
            # is always the plain tensor, which would label every case
            # "public" regardless of `y_type`.
            y_is_private = crypten.is_encrypted_tensor(encrypted_tensor2)
            y_label = "private" if y_is_private else "public"

            # Public (boolean) condition.
            encrypted_out = crypten.where(
                condition_bool, encrypted_tensor1, encrypted_tensor2
            )
            self._check(
                encrypted_out,
                reference_out,
                f"{y_label} y where failed with public condition",
            )

            # Private (encrypted) condition.
            encrypted_out = encrypted_tensor1.where(
                condition_encrypted, encrypted_tensor2
            )
            self._check(
                encrypted_out,
                reference_out,
                f"{y_label} y where failed with private condition",
            )