def max(self, dim=None, keepdim=False, one_hot=True):
    """Returns the maximum value of all elements in the input tensor."""
    method = cfg.functions.max_method
    if dim is None:
        if method in ["log_reduction", "double_log_reduction"]:
            # max_result can be obtained directly
            max_result = _max_helper_all_tree_reductions(self, method=method)
        else:
            # max_result needs to be obtained through argmax
            with cfg.temp_override({"functions.max_method": method}):
                argmax_result = self.argmax(one_hot=True)
            max_result = self.mul(argmax_result).sum()
        return max_result
    else:
        argmax_result, max_result = _argmax_helper(
            self, dim=dim, one_hot=True, method=method, _return_max=True
        )
        if max_result is None:
            max_result = (self * argmax_result).sum(dim=dim, keepdim=keepdim)
        if keepdim:
            max_result = (
                max_result.unsqueeze(dim)
                if max_result.dim() < self.dim()
                else max_result
            )
        if one_hot:
            return max_result, argmax_result
        else:
            return (
                max_result,
                _one_hot_to_index(argmax_result, dim, keepdim, self.device),
            )
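
# Usage sketch for max(): a minimal example, assuming a standard single-process
# CrypTen setup. With `dim` given, max() returns a (values, argmax) pair; the
# argmax is one-hot encoded unless one_hot=False. Printed values are
# illustrative.
import torch
import crypten

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]]))

# Global max over all elements: a single encrypted value.
global_max = x_enc.max()
print(global_max.get_plain_text())  # tensor(6.)

# Row-wise max: (max values, one-hot argmax) along dim=1.
row_max, row_argmax = x_enc.max(dim=1)
print(row_max.get_plain_text())     # tensor([5., 6.])
print(row_argmax.get_plain_text())  # tensor([[0., 1., 0.], [0., 0., 1.]])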
def log(self, input_in_01=False):
    r"""
    Approximates the natural logarithm using 8th order modified Householder
    iterations. This approximation is accurate within 2% relative error on
    [0.0001, 250].

    Iterations are computed by: :math:`h = 1 - x * exp(-y_n)`

    .. math::

        y_{n+1} = y_n - \sum_k^{order}\frac{h^k}{k}

    Args:
        input_in_01 (bool) : Allows a user to indicate that the input is in
            the domain [0, 1], causing the function to optimize for this
            domain. This is useful for computing log-probabilities for
            entropy functions.

            We shift the domain of convergence by a constant :math:`a` using
            the following identity:

            .. math::

                \ln{u} = \ln{au} - \ln{a}

            Since the domain of convergence for CrypTen's log() function is
            approximately [1e-4, 1e2], we can set :math:`a = 100`.

    Configuration parameters:
        iterations (int): number of Householder iterations for the approximation
        exp_iterations (int): number of iterations for the limit approximation
            of exp
        order (int): number of polynomial terms used (order of the Householder
            approximation)
    """
    if input_in_01:
        return log(self.mul(100)) - 4.605170  # ln(100) ~= 4.605170

    # Initialization to a decent estimate (found by qualitative inspection):
    #   ln(x) = x / 120 - 20 * exp(-2x - 1.0) + 3.0
    iterations = cfg.functions.log_iterations
    exp_iterations = cfg.functions.log_exp_iterations
    order = cfg.functions.log_order

    term1 = self.div(120)
    term2 = exp(self.mul(2).add(1.0).neg()).mul(20)
    y = term1 - term2 + 3.0

    # 8th order Householder iterations
    with cfg.temp_override({"functions.exp_iterations": exp_iterations}):
        for _ in range(iterations):
            h = 1 - self * exp(-y)
            y -= h.polynomial([1 / (i + 1) for i in range(order)])
    return y
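
# A plain-PyTorch sketch of the same initialization and Householder update,
# useful for sanity-checking the approximation outside of MPC. The constants
# mirror the code above; the default iteration counts here are illustrative.
import torch

def log_householder(x: torch.Tensor, iterations: int = 2, order: int = 8) -> torch.Tensor:
    # Initial estimate: ln(x) ~= x/120 - 20*exp(-2x - 1) + 3
    y = x / 120 - 20 * torch.exp(-2 * x - 1.0) + 3.0
    for _ in range(iterations):
        h = 1 - x * torch.exp(-y)
        # y <- y - sum_{k=1..order} h^k / k
        y = y - sum(h**k / k for k in range(1, order + 1))
    return y

x = torch.tensor([0.5, 1.0, 10.0, 100.0])
print(log_householder(x))  # close to torch.log(x) on the stated domain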
def softmax(self, dim, **kwargs):
    r"""Compute the softmax of a tensor's elements along a given dimension"""
    # 0-d case
    if self.dim() == 0:
        assert dim == 0, "Improper dim argument"
        return self.new(torch.ones_like(self.data))

    if self.size(dim) == 1:
        return self.new(torch.ones_like(self.data))

    maximum_value = self.max(dim, keepdim=True)[0]
    logits = self - maximum_value
    numerator = logits.exp()
    with cfg.temp_override({"functions.reciprocal_all_pos": True}):
        inv_denominator = numerator.sum(dim, keepdim=True).reciprocal()
    return numerator * inv_denominator
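
# The max subtraction above is the standard stability trick: softmax is
# shift-invariant, so subtracting the per-row max keeps the exp() inputs
# non-positive. A minimal plain-PyTorch sketch of the same computation:
import torch

def softmax_stable(x: torch.Tensor, dim: int) -> torch.Tensor:
    logits = x - x.max(dim, keepdim=True).values  # shift so the max is 0
    numerator = logits.exp()
    return numerator / numerator.sum(dim, keepdim=True)

x = torch.tensor([[1.0, 2.0, 3.0]])
assert torch.allclose(softmax_stable(x, dim=1), torch.softmax(x, dim=1))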
def sigmoid(self):
    r"""Computes the sigmoid function using the following definition

    .. math::

        \sigma(x) = (1 + e^{-x})^{-1}

    If a valid method is given, this function will compute sigmoid
    using that method:

    "chebyshev" - computes tanh via Chebyshev approximation with
    truncation and uses the identity:

    .. math::

        \sigma(x) = \frac{1}{2}tanh(\frac{x}{2}) + \frac{1}{2}

    "reciprocal" - computes sigmoid using :math:`1 + e^{-x}` and computing
    the reciprocal
    """  # noqa: W605
    method = cfg.functions.sigmoid_tanh_method

    if method == "chebyshev":
        tanh_approx = tanh(self.div(2))
        return tanh_approx.div(2) + 0.5
    elif method == "reciprocal":
        ltz = self._ltz()
        sign = 1 - 2 * ltz

        pos_input = self.mul(sign)
        denominator = pos_input.neg().exp().add(1)

        # TODO: Set these with configurable parameters
        with cfg.temp_override(
            {
                "functions.exp_iterations": 9,
                "functions.reciprocal_nr_iters": 3,
                "functions.reciprocal_all_pos": True,
                "functions.reciprocal_initial": 0.75,
            }
        ):
            pos_output = denominator.reciprocal()

        result = pos_output.where(1 - ltz, 1 - pos_output)
        # TODO: Support addition with different encoder scales
        # result = pos_output + ltz - 2 * pos_output * ltz
        return result
    else:
        raise ValueError(f"Unrecognized method {method} for sigmoid")
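
# The "reciprocal" branch folds inputs onto the positive half-line (where the
# reciprocal's Newton-Raphson initialization converges) and then reflects the
# output using the identity sigmoid(-x) = 1 - sigmoid(x). A plain-PyTorch
# sketch of the same folding, with exact exp/reciprocal standing in for the
# MPC approximations:
import torch

def sigmoid_folded(x: torch.Tensor) -> torch.Tensor:
    ltz = (x < 0).to(x.dtype)   # 1 where x < 0, else 0
    sign = 1 - 2 * ltz          # +1 for x >= 0, -1 for x < 0
    pos = x * sign              # |x|
    pos_out = 1 / (1 + torch.exp(-pos))
    return torch.where(ltz.bool(), 1 - pos_out, pos_out)

x = torch.tensor([-3.0, 0.0, 3.0])
assert torch.allclose(sigmoid_folded(x), torch.sigmoid(x))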
def validation_function(*args, **kwargs):
    with cfg.temp_override({"debug.validation_mode": False}):
        # Compute crypten result
        result_enc = func(*args, **kwargs)
        result = (
            result_enc.get_plain_text()
            if crypten.is_encrypted_tensor(result_enc)
            else result_enc
        )

        args = list(args)

        # Compute torch result for corresponding function
        for i, arg in enumerate(args):
            if crypten.is_encrypted_tensor(arg):
                args[i] = args[i].get_plain_text()

        kwargs.pop("input_in_01", None)
        for key, value in kwargs.items():
            if crypten.is_encrypted_tensor(value):
                kwargs[key] = value.get_plain_text()

        reference = getattr(self.get_plain_text(), func_name)(*args, **kwargs)

        # TODO: Validate properties - Issue is tuples can contain encrypted tensors
        if not torch.is_tensor(reference):
            return result_enc

        # Check sizes match
        if result.size() != reference.size():
            crypten_log(
                f"Size mismatch: Expected {reference.size()} but got {result.size()}"
            )
            raise ValueError(f"Function {func_name} returned incorrect size")

        # Check that results match
        diff = (result - reference).abs_()
        norm_diff = diff.div(result.abs() + reference.abs()).abs_()
        test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.1)
        test_passed = test_passed.gt(0).all().item() == 1
        if not test_passed:
            crypten_log(f"Function {func_name} returned incorrect values")
            crypten_log(f"Result {result}")
            crypten_log(f"Result - Reference = {result - reference}")
            raise ValueError(f"Function {func_name} returned incorrect values")

    return result_enc
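
# The pass criterion above mixes relative and absolute tolerance: an element
# passes if its symmetric relative error is within `tolerance` OR its absolute
# error is within tolerance * 0.1. A standalone sketch of that check (the name
# close_enough is mine, not CrypTen's):
import torch

def close_enough(result: torch.Tensor, reference: torch.Tensor, tolerance: float = 0.5) -> bool:
    diff = (result - reference).abs()
    # symmetric relative error: |a - b| / (|a| + |b|)
    norm_diff = diff / (result.abs() + reference.abs())
    passed = norm_diff.le(tolerance) | diff.le(tolerance * 0.1)
    return bool(passed.all())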
def test_config(self):
    """Checks that setting configuration values with the config manager works"""
    # Set the config directly
    crypten.init()
    cfgs = [
        "functions.exp_iterations",
        "functions.max_method",
    ]

    for _cfg in cfgs:
        cfg[_cfg] = 10
        self.assertTrue(cfg[_cfg] == 10, "cfg.set failed")

        # Set with a context manager
        with cfg.temp_override({_cfg: 3}):
            self.assertTrue(cfg[_cfg] == 3, "temp_override failed to set values")
        self.assertTrue(cfg[_cfg] == 10, "temp_override values persist")
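
# For intuition, a minimal sketch of how a temp_override-style context manager
# can be built with contextlib. This is an illustration only, not CrypTen's
# actual config implementation (which is backed by a nested OmegaConf config).
import contextlib

class SimpleConfig:
    def __init__(self):
        self._values = {}

    def __getitem__(self, key):
        return self._values[key]

    def __setitem__(self, key, value):
        self._values[key] = value

    @contextlib.contextmanager
    def temp_override(self, overrides):
        # save current values, apply overrides, restore on exit
        saved = {k: self._values.get(k) for k in overrides}
        self._values.update(overrides)
        try:
            yield
        finally:
            self._values.update(saved)

cfg_demo = SimpleConfig()
cfg_demo["functions.exp_iterations"] = 10
with cfg_demo.temp_override({"functions.exp_iterations": 3}):
    assert cfg_demo["functions.exp_iterations"] == 3
assert cfg_demo["functions.exp_iterations"] == 10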
def _max_helper_accelerated_cascade(enc_tensor, dim=None):
    """Returns max along dimension `dim` using the accelerated cascading algorithm"""
    if enc_tensor.dim() == 0:
        return enc_tensor

    input, dim_used = enc_tensor, dim
    if dim is None:
        dim_used = 0
        input = enc_tensor.flatten()

    n = input.size(dim_used)  # number of items in the dimension
    if n < 3:
        with cfg.temp_override({"functions.max_method": "pairwise"}):
            enc_max, enc_argmax = input.max(dim=dim_used)
        return enc_max

    steps = int(math.log(math.log(math.log(n)))) + 1
    enc_tensor_reduced = _compute_pairwise_comparisons_for_steps(
        input, dim_used, steps
    )
    enc_max = _max_helper_double_log_reduction(enc_tensor_reduced, dim=dim_used)
    return enc_max
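
# Accelerated cascading in a nutshell: run a few rounds of pairwise halving
# (cheap, halves n each round), then hand the much smaller problem to the
# doubly-logarithmic reduction. A plaintext sketch of the halving rounds on a
# Python list:
def pairwise_halving_rounds(values: list, steps: int) -> list:
    for _ in range(max(steps, 0)):
        if len(values) < 2:
            break
        # compare disjoint pairs; carry an odd leftover element through unchanged
        half = [max(a, b) for a, b in zip(values[0::2], values[1::2])]
        if len(values) % 2 == 1:
            half.append(values[-1])
        values = half
    return values

print(pairwise_halving_rounds([3, 1, 4, 1, 5, 9, 2, 6], steps=2))  # 8 -> 4 -> 2 items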
def _max_helper_log_reduction(enc_tensor, dim=None):
    """Returns max along dim `dim` using the log_reduction algorithm"""
    if enc_tensor.dim() == 0:
        return enc_tensor

    input, dim_used = enc_tensor, dim
    if dim is None:
        dim_used = 0
        input = enc_tensor.flatten()

    n = input.size(dim_used)  # number of items in the dimension
    steps = int(math.log(n))
    enc_tensor_reduced = _compute_pairwise_comparisons_for_steps(
        input, dim_used, steps
    )

    # compute max over the resulting reduced tensor with n^2 algorithm
    # note that the resulting one-hot vector we get here finds maxes only
    # over the reduced vector in enc_tensor_reduced, so we won't use it
    with cfg.temp_override({"functions.max_method": "pairwise"}):
        enc_max_vec, enc_one_hot_reduced = enc_tensor_reduced.max(dim=dim_used)
    return enc_max_vec
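
# Each reduction step splits the tensor in half along the reduction dim and
# takes an elementwise max of the halves, so `steps` rounds shrink that dim
# roughly by 2**steps. A plain-PyTorch sketch of one such step (the exact
# CrypTen helper _compute_pairwise_comparisons_for_steps may differ in detail):
import torch

def halve_and_max(x: torch.Tensor, dim: int) -> torch.Tensor:
    n = x.size(dim)
    first, second, odd = x.split([n // 2, n // 2, n % 2], dim=dim)
    return torch.cat([torch.maximum(first, second), odd], dim=dim)

x = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])
print(halve_and_max(x, dim=0))  # tensor([4., 1., 5.]); the max (5) survives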
def _max_helper_double_log_recursive(enc_tensor, dim):
    """Recursive subroutine for computing max via double log reduction algorithm"""
    n = enc_tensor.size(dim)
    # compute integral sqrt(n) and the integer number of sqrt(n) size
    # vectors that can be extracted from n
    sqrt_n = int(math.sqrt(n))
    count_sqrt_n = n // sqrt_n

    # base case for recursion: no further splits along dimension dim
    if n == 1:
        return enc_tensor
    else:
        # split into tensors that can be broken into vectors of size sqrt(n)
        # and the remainder of the tensor
        size_arr = [sqrt_n * count_sqrt_n, n % sqrt_n]
        split_enc_tensor, remainder = enc_tensor.split(size_arr, dim=dim)

        # reshape such that dim holds sqrt_n and dim+1 holds count_sqrt_n
        updated_enc_tensor_size = [sqrt_n, enc_tensor.size(dim + 1) * count_sqrt_n]
        size_arr = [enc_tensor.size(i) for i in range(enc_tensor.dim())]
        size_arr[dim], size_arr[dim + 1] = updated_enc_tensor_size
        split_enc_tensor = split_enc_tensor.reshape(size_arr)

        # recursive call on reshaped tensor
        split_enc_max = _max_helper_double_log_recursive(split_enc_tensor, dim)

        # reshape the result to have the (dim+1)th dimension as before
        # and concatenate the previously computed remainder
        size_arr[dim], size_arr[dim + 1] = [count_sqrt_n, enc_tensor.size(dim + 1)]
        enc_max_tensor = split_enc_max.reshape(size_arr)
        full_max_tensor = crypten.cat([enc_max_tensor, remainder], dim=dim)

        # compute max over the resulting reduced tensor with the n^2 (pairwise)
        # algorithm; note that the resulting one-hot vector finds maxes only
        # over the reduced tensor, so we won't use it
        with cfg.temp_override({"functions.max_method": "pairwise"}):
            enc_max, enc_arg_max = full_max_tensor.max(dim=dim, keepdim=True)
        return enc_max
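
# The recursion implements the classic O(log log n)-depth max: split n items
# into ~sqrt(n) blocks of size ~sqrt(n), recurse within each block, then take
# one pairwise max over the block maxima plus any remainder. A plaintext
# sketch on Python lists, with the builtin max() standing in for the secure
# pairwise max:
import math

def double_log_max(values: list):
    n = len(values)
    if n == 1:
        return values[0]
    sqrt_n = int(math.sqrt(n))
    count = n // sqrt_n
    # block maxima (computed recursively), plus any remainder elements
    blocks = [values[i * sqrt_n : (i + 1) * sqrt_n] for i in range(count)]
    candidates = [double_log_max(b) for b in blocks] + values[count * sqrt_n :]
    return max(candidates)

print(double_log_max([3, 1, 4, 1, 5, 9, 2, 6]))  # 9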
def reciprocal(self, input_in_01=False):
    r"""
    Args:
        input_in_01 (bool) : Allows a user to indicate that the input is
            in the range [0, 1], causing the function to optimize for this
            range. This is useful for improving the accuracy of functions
            on probabilities (e.g. entropy functions).

    Methods:
        'NR' : `Newton-Raphson`_ method computes the reciprocal using
            iterations of :math:`x_{i+1} = (2x_i - self * x_i^2)` and uses
            :math:`3*exp(1 - 2x) + 0.003` as an initial guess by default

        'log' : Computes the reciprocal of the input from the observation
            that: :math:`x^{-1} = exp(-log(x))`

    Configuration params:
        reciprocal_method (str): One of 'NR' or 'log'.
        reciprocal_nr_iters (int): determines the number of Newton-Raphson
            iterations to run for the `NR` method
        reciprocal_log_iters (int): determines the number of Householder
            iterations to run when computing logarithms for the `log` method
        reciprocal_all_pos (bool): determines whether all elements of the
            input are known to be positive, which optimizes the step of
            computing the sign of the input.
        reciprocal_initial (tensor): sets the initial value for the
            Newton-Raphson method. By default, this will be set to
            :math:`3*exp(1 - 2x) + 0.003` as this allows the method to
            converge over a fairly large domain

    .. _Newton-Raphson:
        https://en.wikipedia.org/wiki/Newton%27s_method
    """
    pos_override = {"functions.reciprocal_all_pos": True}
    if input_in_01:
        with cfg.temp_override(pos_override):
            # rescale by 64 into the domain of convergence, then undo:
            # 1/x = 64 * (1 / (64x))
            rec = reciprocal(self.mul(64)).mul(64)
        return rec

    # Get config options
    method = cfg.functions.reciprocal_method
    all_pos = cfg.functions.reciprocal_all_pos
    initial = cfg.functions.reciprocal_initial

    if not all_pos:
        sgn = self.sign()
        pos = sgn * self
        with cfg.temp_override(pos_override):
            return sgn * reciprocal(pos)

    if method == "NR":
        nr_iters = cfg.functions.reciprocal_nr_iters
        if initial is None:
            # Initialization to a decent estimate (found by qualitative inspection):
            #   1/x = 3exp(1 - 2x) + 0.003
            result = 3 * (1 - 2 * self).exp() + 0.003
        else:
            result = initial

        for _ in range(nr_iters):
            if hasattr(result, "square"):
                result += result - result.square().mul_(self)
            else:
                result = 2 * result - result * result * self
        return result
    elif method == "log":
        log_iters = cfg.functions.reciprocal_log_iters
        with cfg.temp_override({"functions.log_iterations": log_iters}):
            return exp(-log(self))
    else:
        raise ValueError(f"Invalid method {method} given for reciprocal function")
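
# Newton-Raphson for 1/x needs no division: applying Newton's method to
# f(r) = 1/r - x gives the update r <- r(2 - x*r) = 2r - x*r^2. A plain-PyTorch
# sketch with the same initialization as above; the iteration count is
# illustrative.
import torch

def reciprocal_nr(x: torch.Tensor, iters: int = 10) -> torch.Tensor:
    r = 3 * torch.exp(1 - 2 * x) + 0.003  # initial guess, for x > 0
    for _ in range(iters):
        r = 2 * r - x * r * r  # converges quadratically near the true 1/x
    return r

x = torch.tensor([0.1, 1.0, 10.0, 50.0])
print(reciprocal_nr(x))  # close to 1 / x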