Exemplo n.º 1
0
def max(self, dim=None, keepdim=False, one_hot=True):
    """Returns the maximum value of the input tensor, globally or along `dim`.

    The algorithm is chosen by ``cfg.functions.max_method``.

    Args:
        dim (int, optional): dimension to reduce. When ``None`` the maximum
            over all elements is returned and ``keepdim``/``one_hot`` are not
            used.
        keepdim (bool): when reducing along ``dim``, keep that dimension with
            size 1 in the returned maximum.
        one_hot (bool): when reducing along ``dim``, return the argmax as a
            one-hot tensor (``True``) or converted by ``_one_hot_to_index``
            (``False``).

    Returns:
        The maximum when ``dim`` is ``None``; otherwise a tuple
        ``(max_result, argmax_result)``.
    """
    method = cfg.functions.max_method

    if dim is not None:
        argmax_result, max_result = _argmax_helper(
            self, dim=dim, one_hot=True, method=method, _return_max=True
        )
        if max_result is None:
            # The helper did not produce the maximum itself: recover it by
            # masking with the one-hot argmax and summing along `dim`.
            max_result = (self * argmax_result).sum(dim=dim, keepdim=keepdim)
        if keepdim and max_result.dim() < self.dim():
            max_result = max_result.unsqueeze(dim)
        if not one_hot:
            argmax_result = _one_hot_to_index(argmax_result, dim, keepdim, self.device)
        return max_result, argmax_result

    if method in ("log_reduction", "double_log_reduction"):
        # Tree reductions yield the maximum directly; no argmax required.
        return _max_helper_all_tree_reductions(self, method=method)

    # Remaining methods obtain the maximum through a one-hot argmax.
    with cfg.temp_override({"functions.max_method": method}):
        argmax_result = self.argmax(one_hot=True)
    return self.mul(argmax_result).sum()
Exemplo n.º 2
0
def log(self, input_in_01=False):
    r"""
    Approximates the natural logarithm using modified Householder iterations
    whose order is set by ``cfg.functions.log_order``. This approximation is
    accurate within 2% relative error on [0.0001, 250].

    Each iteration computes :math:`h = 1 - x * exp(-y_n)` and updates

    .. math::

        y_{n+1} = y_n - \sum_{k=1}^{order}\frac{h^k}{k}

    Args:
        input_in_01 (bool) : Allows a user to indicate that the input is in the domain [0, 1],
            causing the function to optimize for this domain. This is useful for computing
            log-probabilities for entropy functions.

            The domain of convergence is shifted by a constant :math:`a` using the identity:

            .. math::

                \ln{u} = \ln {au} - \ln{a}

            with :math:`a=100`, which moves inputs from [0, 1] toward the range
            where the approximation converges (approximately [1e-4, 1e2]).

    Configuration parameters:
        log_iterations (int): number of Householder iterations for the approximation
        log_exp_iterations (int): number of iterations for limit approximation of exp
        log_order (int): number of polynomial terms used (order of Householder approx)
    """
    if input_in_01:
        # ln(u) = ln(100 u) - ln(100), where ln(100) ~= 4.605170.
        return log(self.mul(100)) - 4.605170

    num_iterations = cfg.functions.log_iterations
    exp_iterations = cfg.functions.log_exp_iterations
    poly_order = cfg.functions.log_order

    # Starting estimate (found by qualitative inspection):
    #     ln(x) ~= x/120 - 20 exp(-2x - 1) + 3
    y = self.div(120) - exp(self.mul(2).add(1.0).neg()).mul(20) + 3.0

    # Every exp() evaluated inside the refinement loop uses log's own
    # iteration count rather than the global exp setting.
    with cfg.temp_override({"functions.exp_iterations": exp_iterations}):
        for _ in range(num_iterations):
            h = 1 - self * exp(-y)
            # Coefficients 1/k for k = 1..order of the series -ln(1 - h).
            y -= h.polynomial([1 / k for k in range(1, poly_order + 1)])
    return y
Exemplo n.º 3
0
def softmax(self, dim, **kwargs):
    r"""Compute the softmax of a tensor's elements along a given dimension.

    Args:
        dim (int): dimension along which softmax is computed. For a 0-d
            tensor only ``0`` or ``-1`` is accepted, as in PyTorch.
        **kwargs: accepted and ignored (presumably for signature
            compatibility with ``torch.softmax`` -- confirm).

    Returns:
        A tensor of the same shape as ``self``.
    """
    # 0-d case: the softmax of a single element is 1. PyTorch wraps dim for
    # 0-d tensors as if they had one dimension, so both 0 and -1 are valid.
    if self.dim() == 0:
        assert dim in (0, -1), "Improper dim argument"
        return self.new(torch.ones_like(self.data))

    # A singleton dimension likewise yields all ones.
    if self.size(dim) == 1:
        return self.new(torch.ones_like(self.data))

    # Subtract the per-slice maximum so every logit is <= 0, keeping the
    # exponentials bounded.
    maximum_value = self.max(dim, keepdim=True)[0]
    logits = self - maximum_value
    numerator = logits.exp()
    # The denominator is a sum of exponentials, hence positive, so reciprocal
    # may skip its sign-handling step.
    with cfg.temp_override({"functions.reciprocal_all_pos": True}):
        inv_denominator = numerator.sum(dim, keepdim=True).reciprocal()
    return numerator * inv_denominator
Exemplo n.º 4
0
def sigmoid(self):
    r"""Computes the sigmoid function

    .. math::
        \sigma(x) = (1 + e^{-x})^{-1}

    The approximation is selected by ``cfg.functions.sigmoid_tanh_method``:

    "chebyshev" - computes tanh (via its Chebyshev approximation with
        truncation) and uses the identity

    .. math::
        \sigma(x) = \frac{1}{2}tanh(\frac{x}{2}) + \frac{1}{2}

    "reciprocal" - computes :math:`1 + e^{-|x|}`, takes its reciprocal, and
        uses :math:`\sigma(x) = 1 - \sigma(-x)` for negative inputs.

    Raises:
        ValueError: if the configured method is not recognized.
    """
    method = cfg.functions.sigmoid_tanh_method

    if method == "chebyshev":
        return tanh(self.div(2)).div(2) + 0.5

    if method != "reciprocal":
        raise ValueError(f"Unrecognized method {method} for sigmoid")

    # _ltz() marks negative entries (1 where x < 0, else 0, per its name), so
    # `sign` is -1 for negative entries and +1 otherwise.
    ltz = self._ltz()
    sign = 1 - 2 * ltz

    # 1 + e^{-|x|}: the exponential only ever sees non-positive arguments.
    abs_input = self.mul(sign)
    denominator = abs_input.neg().exp().add(1)

    # The denominator is >= 1, so reciprocal can skip sign handling.
    # TODO: Set these with configurable parameters
    reciprocal_settings = {
        "functions.exp_iterations": 9,
        "functions.reciprocal_nr_iters": 3,
        "functions.reciprocal_all_pos": True,
        "functions.reciprocal_initial": 0.75,
    }
    with cfg.temp_override(reciprocal_settings):
        pos_output = denominator.reciprocal()

    # pos_output is sigmoid(|x|); for negative x use 1 - sigmoid(|x|).
    # TODO: Support addition with different encoder scales, which would allow
    # the arithmetic form pos_output + ltz - 2 * pos_output * ltz instead.
    return pos_output.where(1 - ltz, 1 - pos_output)
Exemplo n.º 5
0
    def validation_function(*args, **kwargs):
        # Run the wrapped CrypTen function, decrypt its output, and compare it
        # with the same-named torch method applied to the decrypted inputs.
        # Validation mode is switched off inside so nested calls are not
        # re-validated.
        with cfg.temp_override({"debug.validation_mode": False}):

            def _plain(value):
                # Decrypt encrypted tensors; pass anything else through.
                if crypten.is_encrypted_tensor(value):
                    return value.get_plain_text()
                return value

            # Compute crypten result
            result_enc = func(*args, **kwargs)
            result = _plain(result_enc)

            # Compute torch result for corresponding function. `input_in_01`
            # is a CrypTen-only option and is not forwarded to torch.
            args = [_plain(arg) for arg in args]
            kwargs = {
                key: _plain(value)
                for key, value in kwargs.items()
                if key != "input_in_01"
            }
            reference = getattr(self.get_plain_text(), func_name)(*args, **kwargs)

            # TODO: Validate properties - Issue is tuples can contain encrypted tensors
            if not torch.is_tensor(reference):
                return result_enc

            if result.size() != reference.size():
                crypten_log(
                    f"Size mismatch: Expected {reference.size()} but got {result.size()}"
                )
                raise ValueError(f"Function {func_name} returned incorrect size")

            # An element passes when either its relative error is within
            # `tolerance` or its absolute error is within `tolerance * 0.1`.
            diff = (result - reference).abs_()
            norm_diff = diff.div(result.abs() + reference.abs()).abs_()
            within_tolerance = norm_diff.le(tolerance) | diff.le(tolerance * 0.1)
            if not within_tolerance.all().item():
                crypten_log(f"Function {func_name} returned incorrect values")
                crypten_log(f"Result {result!s}")
                crypten_log(f"Result - Reference = {result - reference!s}")
                raise ValueError(f"Function {func_name} returned incorrect values")

        return result_enc
Exemplo n.º 6
0
    def test_config(self):
        """Checks setting configuration with the config manager works.

        Each key is set directly (``cfg[key] = value``) and then temporarily
        via ``cfg.temp_override``. The original values are restored at the
        end so this test does not leak global config state into other tests.
        """
        crypten.init()

        cfgs = [
            "functions.exp_iterations",
            "functions.max_method",
        ]

        # cfg is global: leaving e.g. functions.max_method == 10 behind would
        # break any later test that calls max(), so snapshot and restore.
        originals = {key: cfg[key] for key in cfgs}
        try:
            for _cfg in cfgs:
                # Set the config directly
                cfg[_cfg] = 10
                self.assertEqual(cfg[_cfg], 10, "cfg.set failed")

                # Set with a context manager
                with cfg.temp_override({_cfg: 3}):
                    self.assertEqual(cfg[_cfg], 3, "temp_override failed to set values")
                self.assertEqual(cfg[_cfg], 10, "temp_override values persist")
        finally:
            for key, value in originals.items():
                cfg[key] = value
Exemplo n.º 7
0
def _max_helper_accelerated_cascade(enc_tensor, dim=None):
    """Returns max along dimension `dim` using the accelerated cascading algorithm.

    Args:
        enc_tensor: encrypted tensor to reduce.
        dim (int, optional): dimension to reduce. When ``None`` the tensor is
            flattened and the maximum over all elements is computed.

    Returns:
        The maximum (a 0-d input is returned unchanged).
    """
    if enc_tensor.dim() == 0:
        return enc_tensor

    # All work below must use the (possibly flattened) tensor: `n` and
    # `dim_used` describe it, not the original multi-dimensional input.
    work_tensor, dim_used = enc_tensor, dim
    if dim is None:
        dim_used = 0
        work_tensor = enc_tensor.flatten()
    n = work_tensor.size(dim_used)  # number of items in the dimension

    if n < 3:
        with cfg.temp_override({"functions.max_method": "pairwise"}):
            enc_max, _ = work_tensor.max(dim=dim_used)
        return enc_max

    # n >= 3 keeps all three logarithms defined. NOTE(review): for
    # 3 <= n <= 15 this evaluates to -1, 0 or 1; presumably non-positive
    # values mean "no pairwise steps" in the helper -- confirm.
    steps = int(math.log(math.log(math.log(n)))) + 1
    enc_tensor_reduced = _compute_pairwise_comparisons_for_steps(
        work_tensor, dim_used, steps
    )
    return _max_helper_double_log_reduction(enc_tensor_reduced, dim=dim_used)
Exemplo n.º 8
0
def _max_helper_log_reduction(enc_tensor, dim=None):
    """Returns max along dim `dim` using the log_reduction algorithm.

    ``int(ln(n))`` rounds of pairwise comparisons shrink the dimension, then
    the pairwise (quadratic) max finishes the job on the reduced tensor.

    Args:
        enc_tensor: encrypted tensor to reduce.
        dim (int, optional): dimension to reduce; ``None`` flattens first.

    Returns:
        The maximum values (a 0-d input is returned unchanged).
    """
    if enc_tensor.dim() == 0:
        return enc_tensor

    if dim is None:
        work_tensor, work_dim = enc_tensor.flatten(), 0
    else:
        work_tensor, work_dim = enc_tensor, dim

    num_steps = int(math.log(work_tensor.size(work_dim)))
    reduced = _compute_pairwise_comparisons_for_steps(work_tensor, work_dim, num_steps)

    # The one-hot argmax produced here refers to positions in `reduced`, not
    # in the original tensor, so it is discarded.
    with cfg.temp_override({"functions.max_method": "pairwise"}):
        max_values, _ = reduced.max(dim=work_dim)
    return max_values
Exemplo n.º 9
0
def _max_helper_double_log_recursive(enc_tensor, dim):
    """Recursive subroutine for computing max via double log reduction algorithm.

    Splits dimension ``dim`` into chunks of ``int(sqrt(n))`` elements,
    recursively reduces across the chunk axis, then takes a pairwise max over
    the per-chunk results and the leftover remainder.

    Args:
        enc_tensor: encrypted tensor; it must also have a dimension
            ``dim + 1``, which absorbs the chunk count during reshapes.
        dim (int): dimension to reduce.

    Returns:
        A tensor with size 1 along ``dim`` (the reduced dimension is kept).
    """
    n = enc_tensor.size(dim)
    if n == 1:
        # Base case: nothing left to reduce along `dim`.
        return enc_tensor

    chunk = int(math.sqrt(n))
    num_chunks = n // chunk
    next_size = enc_tensor.size(dim + 1)

    # Leading part that divides evenly into chunks, plus the n % chunk
    # elements that are left over.
    body, remainder = enc_tensor.split([chunk * num_chunks, n % chunk], dim=dim)

    # Reshape so `dim` has `chunk` entries and `dim + 1` absorbs the
    # `num_chunks` factor, then reduce `dim` recursively.
    shape = [enc_tensor.size(i) for i in range(enc_tensor.dim())]
    shape[dim] = chunk
    shape[dim + 1] = next_size * num_chunks
    body_max = _max_helper_double_log_recursive(body.reshape(shape), dim)

    # Restore the original `dim + 1` size (one entry per chunk along `dim`)
    # and append the leftover elements as extra candidates.
    shape[dim] = num_chunks
    shape[dim + 1] = next_size
    candidates = crypten.cat([body_max.reshape(shape), remainder], dim=dim)

    # Final pairwise max over the candidates; the argmax is not needed.
    with cfg.temp_override({"functions.max_method": "pairwise"}):
        enc_max, _ = candidates.max(dim=dim, keepdim=True)
    return enc_max
Exemplo n.º 10
0
def reciprocal(self, input_in_01=False):
    r"""
    Args:
        input_in_01 (bool) : Allows a user to indicate that the input is in the range [0, 1],
                    causing the function to optimize for this range. This is useful for improving
                    the accuracy of functions on probabilities (e.g. entropy functions).

    Methods:
        'NR' : `Newton-Raphson`_ method computes the reciprocal using iterations
                of :math:`x_{i+1} = (2x_i - self * x_i^2)` and uses
                :math:`3*exp(1 - 2x) + 0.003` as an initial guess by default

        'log' : Computes the reciprocal of the input from the observation that:
                :math:`x^{-1} = exp(-log(x))`

    Configuration params:
        reciprocal_method (str):  One of 'NR' or 'log'.
        reciprocal_nr_iters (int):  determines the number of Newton-Raphson iterations to run
                        for the `NR` method
        reciprocal_log_iters (int): determines the number of Householder
            iterations to run when computing logarithms for the `log` method
            (applied as ``functions.log_iterations`` while ``log`` runs)
        reciprocal_all_pos (bool): determines whether all elements of the
            input are known to be positive, which optimizes the step of
            computing the sign of the input.
        reciprocal_initial (tensor): sets the initial value for the
            Newton-Raphson method. By default, this will be set to
            :math:`3*exp(1 - 2x) + 0.003` as this allows the method to
            converge over a fairly large domain

    Raises:
        ValueError: if ``reciprocal_method`` is not 'NR' or 'log'.

    .. _Newton-Raphson:
        https://en.wikipedia.org/wiki/Newton%27s_method
    """
    pos_override = {"functions.reciprocal_all_pos": True}
    if input_in_01:
        # 1/x = 64 * (1 / (64 x)): scale inputs in [0, 1] up before inverting.
        with cfg.temp_override(pos_override):
            rec = reciprocal(self.mul(64)).mul(64)
        return rec

    # Get config options
    method = cfg.functions.reciprocal_method
    all_pos = cfg.functions.reciprocal_all_pos
    initial = cfg.functions.reciprocal_initial

    if not all_pos:
        # 1/x = sign(x) * 1/|x|; the recursive call sees only positive input.
        sgn = self.sign()
        pos = sgn * self
        with cfg.temp_override(pos_override):
            return sgn * reciprocal(pos)

    if method == "NR":
        nr_iters = cfg.functions.reciprocal_nr_iters
        if initial is None:
            # Initialization to a decent estimate (found by qualitative inspection):
            #                1/x = 3exp(1 - 2x) + 0.003
            result = 3 * (1 - 2 * self).exp() + 0.003
        else:
            result = initial
        for _ in range(nr_iters):
            if hasattr(result, "square"):
                # Out-of-place update: `result` may be the configured
                # `reciprocal_initial` object itself, which an in-place `+=`
                # could mutate for every later call.
                result = result + (result - result.square().mul_(self))
            else:
                result = 2 * result - result * result * self
        return result
    elif method == "log":
        log_iters = cfg.functions.reciprocal_log_iters
        # `log` reads `functions.log_iterations`; overriding any other key
        # would leave `reciprocal_log_iters` without effect.
        with cfg.temp_override({"functions.log_iterations": log_iters}):
            return exp(-log(self))
    else:
        raise ValueError(
            f"Invalid method {method} given for reciprocal function")