def _chebyshev_polynomials(self, terms):
    r"""Evaluates odd degree Chebyshev polynomials at x

    Chebyshev Polynomials of the first kind are defined as

    .. math::
        P_0(x) = 1, P_1(x) = x, P_n(x) = 2 x P_{n - 1}(x) - P_{n-2}(x)

    Only the odd-degree polynomials :math:`P_1, P_3, \ldots, P_{terms - 1}`
    are computed. Each is obtained from the two previous odd ones via the
    identity :math:`P_{n+2}(x) = 2 P_2(x) P_n(x) - P_{n-2}(x)`, where
    :math:`2 P_2(x) = 4x^2 - 2`.

    Args:
        self (MPCTensor): input at which polynomials are evaluated
        terms (int): one more than the highest polynomial degree evaluated.
            Must be even and at least 6.

    Returns:
        MPCTensor of polynomials evaluated at self, of shape
        `(terms // 2, *self.size())`; entry ``i`` holds :math:`P_{2i+1}(self)`.

    Raises:
        ValueError: if ``terms`` is odd or smaller than 6.
    """
    if terms % 2 != 0 or terms < 6:
        raise ValueError("Chebyshev terms must be even and >= 6")

    polynomials = [self.clone()]  # P_1(x) = x
    y = 4 * self.square() - 2  # 2 * P_2(x)
    z = y - 1
    polynomials.append(z.mul(self))  # P_3(x) = (4x^2 - 3) x

    for k in range(2, terms // 2):
        # polynomials[k] = P_{2k+1} = 2 P_2 * P_{2k-1} - P_{2k-3}
        next_polynomial = y * polynomials[k - 1] - polynomials[k - 2]
        polynomials.append(next_polynomial)

    return crypten.stack(polynomials)
def _argmax_helper_pairwise(enc_tensor, dim=None):
    """Returns 1 for all elements that have the highest value in the
    appropriate dimension of the tensor, using O(n^2) comparisons and a
    constant number of rounds of communication.

    Each element is compared (``ge``) against every other element of its row
    along ``dim``. This is done by pairing the row with ``row_length - 1``
    cyclic rolls of itself. An element is marked as a maximum when all of its
    comparisons hold, so ties mark every tied maximum.

    Args:
        enc_tensor: encrypted tensor to search.
        dim (int, optional): dimension to reduce over; the last one if None.

    Returns:
        A tuple ``(max_indicator, None)``.
    """
    if dim is None:
        dim = -1
    length = enc_tensor.size(dim)
    # A length-1 row is treated as length 2, so exactly one roll (equal to
    # the input itself) is produced and the lone comparison is trivially true.
    row_length = length if length > 1 else 2
    num_rolls = row_length - 1

    # One copy of the input per cyclic shift, matched against the shifted rows.
    repeated = enc_tensor.expand(num_rolls, *enc_tensor.size())
    rolled = crypten.stack(
        [enc_tensor.roll(shift, dims=dim) for shift in range(1, row_length)]
    )

    if num_rolls < torch.iinfo(torch.long).bits * 2:
        # Product over the comparison bits: comparisons are taken unscaled,
        # and the encoder scale is reapplied to the shares afterwards.
        result = repeated.ge(rolled, _scale=False).prod(0)
        result.share *= enc_tensor.encoder._scale
        result.encoder = enc_tensor.encoder
    else:
        # A maximum wins all (row_length - 1) comparisons, so its column sum
        # equals that count; ge() is used since it is slightly faster than eq().
        result = repeated.ge(rolled).sum(0).ge(num_rolls)
    return result, None
def forward(ctx, input):
    """Binary cross-entropy forward pass.

    Computes ``-(1 / N) * sum(target * log(pred) + (1 - target) * log(1 - pred))``
    where ``N`` is ``target.nelement()``.

    Args:
        ctx: autograd context; ``pred`` and ``target`` are saved for the
            backward pass and ``target`` is marked non-differentiable.
        input: pair ``(pred, target)``.

    Returns:
        The loss, reduced over all elements.
    """
    prediction, labels = input
    ctx.save_multiple_for_backward([prediction, labels])
    ctx.mark_non_differentiable(labels)

    # Evaluate log(pred) and log(1 - pred) with a single stacked log() call.
    logs = crypten.stack([prediction, 1.0 - prediction]).log()
    log_p, log_one_minus_p = logs.unbind(dim=0)

    per_element = labels * log_p + ((1.0 - labels) * log_one_minus_p)
    count = labels.nelement()
    return per_element.sum().div(-count)
def polynomial(self, coeffs, func="mul"):
    """Computes a polynomial function on a tensor with given coefficients.

    ``coeffs`` may be a list of values or a 1-D (plain or encrypted) tensor.
    Coefficients are ordered from the order 1 (linear) term first, ending
    with the highest order term. The constant term is not included.

    Args:
        coeffs: polynomial coefficients, linear term first.
        func (str): name of the method used to combine terms when building
            higher powers (``"mul"`` by default).

    Returns:
        ``sum_i coeffs[i] * self ** (i + 1)``.
    """
    if isinstance(coeffs, list):
        coeffs = torch.tensor(coeffs, device=self.device)
    assert is_tensor(coeffs) or crypten.is_encrypted_tensor(
        coeffs
    ), "Polynomial coefficients must be a list or tensor"
    assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor"

    degree = coeffs.size(0)
    if degree == 1:
        # Linear polynomial: only a scaling of the input is needed.
        return self.mul(coeffs)

    # Build x, x^2, ... as a stacked tensor. Each round combines every
    # existing term with the highest one via `func`, doubling the term count.
    powers = crypten.stack([self, self.square()])
    while powers.size(0) < degree:
        last_index = torch.tensor(powers.size(0) - 1, device=self.device)
        top_power = powers.index_select(0, last_index)
        powers = crypten.cat([powers, getattr(powers, func)(top_power)])

    powers = powers[:degree]
    # Append trailing singleton dims so each coefficient broadcasts over
    # its whole power term.
    extra_dims = powers.dim() - 1
    for _ in range(extra_dims):
        coeffs = coeffs.unsqueeze(1)
    return powers.mul(coeffs).sum(0)
def _truncate_tanh(self):
    """Clips ``self`` to the interval ``[-clip_value, clip_value]``.

    ``clip_value`` is read from ``config.sigmoid_tanh_clip_value``. When it
    is ``None``, ``self`` is returned unchanged. Otherwise:

    * elements strictly greater than ``clip_value`` become ``clip_value``,
    * elements strictly less than ``-clip_value`` become ``-clip_value``,
    * all other elements are kept as they are.

    Returns:
        The clipped tensor (or ``self`` when clipping is disabled).
    """
    clip_value = config.sigmoid_tanh_clip_value
    if clip_value is None:
        return self
    # Both masks come from one stacked comparison: x > c and -x > c (x < -c).
    # NOTE(review): assumes clip_value >= 0; a negative value could set both
    # masks at once — confirm the config constraint.
    too_high, too_low = crypten.stack([self, -self]).gt(clip_value)
    in_range = 1 - too_high - too_low
    return (too_high - too_low) * clip_value + self.mul(in_range)
def _truncate_tanh(self, maxval, out):
    """Forces entries of ``out`` to +1 / -1 wherever ``self`` leaves
    ``[-maxval, maxval]``.

    Args:
        maxval (int): bound of the symmetric interval ``[-maxval, maxval]``
        out (torch.tensor or MPCTensor): tensor to truncate

    Returns:
        ``1`` where ``self > maxval``, ``-1`` where ``self < -maxval``, and the
        corresponding entry of ``out`` everywhere else.
    """
    above, below = crypten.stack([self, -self]).gt(maxval)
    keep_mask = -above - below + 1
    truncated = above - below + out.mul(keep_mask)
    return truncated
def forward(ctx, input, skip_forward=False):
    """Binary cross-entropy forward pass.

    Args:
        ctx: autograd context; ``pred`` and ``target`` are saved for the
            backward pass and ``target`` is marked non-differentiable.
        input: pair ``(pred, target)`` of tensors with identical sizes.
        skip_forward (bool): if True, only record state on ``ctx`` and return
            ``pred`` without computing the loss.

    Returns:
        ``-mean(target * log(pred) + (1 - target) * log(1 - pred))``, or
        ``pred`` itself when ``skip_forward`` is True.

    Raises:
        AssertionError: if ``pred`` and ``target`` differ in size.
    """
    pred, target = input
    # Compare the shapes themselves: an empty torch.Size (0-dim input) is a
    # valid shape and must not be rejected just because it is falsy.
    assert (
        pred.size() == target.size()
    ), "Incorrect input sizes for binary_cross_entropy"
    ctx.save_multiple_for_backward([pred, target])
    ctx.mark_non_differentiable(target)
    if skip_forward:
        return pred

    # Compute full forward pass
    log_pos, log_neg = crypten.stack([pred, 1.0 - pred]).log().unbind(dim=0)
    loss_values = target * log_pos + ((1.0 - target) * log_neg)
    return -loss_values.mean()
def forward(ctx, input, skip_forward=False):
    """Binary cross-entropy on logits, forward pass.

    The logits are passed through ``sigmoid`` first. The resulting
    probabilities and ``target`` are stored on ``ctx`` for the backward pass.

    Args:
        ctx: autograd context; ``target`` is marked non-differentiable.
        input: pair ``(logit, target)`` whose sizes must agree.
        skip_forward (bool): if True, return the sigmoid probabilities
            instead of the loss.

    Returns:
        ``-mean(target * log(s) + (1 - target) * log(1 - s))`` with
        ``s = sigmoid(logit)``, or ``s`` when ``skip_forward`` is True.
    """
    logit, target = input
    probs = logit.sigmoid()
    assert (
        probs.size() == target.size()
    ), "Incorrect input sizes for binary_cross_entropy_with_logits"
    ctx.mark_non_differentiable(target)
    ctx.save_multiple_for_backward([target, probs])
    if skip_forward:
        return probs

    # log(s) and log(1 - s) from a single stacked log() call.
    log_probs = crypten.stack([probs, 1.0 - probs]).log()
    log_pos, log_neg = log_probs.unbind(dim=0)
    weighted = target * log_pos + ((1.0 - target) * log_neg)
    return -weighted.mean()
def hardtanh(self, min_value=-1, max_value=1):
    r"""Applies the HardTanh function element-wise

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_value} & \text{ if } x > \text{max\_value} \\
            \text{min\_value} & \text{ if } x < \text{min\_value} \\
            x & \text{ otherwise } \\
        \end{cases}

    With the defaults, the linear region is :math:`[-1, 1]`. It can be
    adjusted using :attr:`min_value` and :attr:`max_value`.

    Args:
        min_value: minimum value of the linear region range. Default: -1
        max_value: maximum value of the linear region range. Default: 1

    Returns:
        A new tensor with ``self`` clamped to ``[min_value, max_value]``
        (assumes ``min_value <= max_value``).
    """
    # relu(x - min) - relu(x - max) equals clamp(x, min, max) - min when
    # min_value <= max_value; both relus are evaluated in one stacked call.
    intermediate = crypten.stack([self - min_value, self - max_value]).relu()
    intermediate = intermediate[0].sub(intermediate[1])
    return intermediate.add_(min_value)
def crypten_collate(batch):
    """Collates a batch of samples, stacking CrypTensors into a batch.

    Recursively walks the structure of the first sample:

    * strings / bytes are returned unchanged (collating them character by
      character would never terminate);
    * CrypTensors are stacked along a new leading dimension;
    * mappings are collated key by key, using the keys of the first sample;
    * namedtuples are collated field by field and rebuilt as the same type;
    * other sequences are transposed and collated position by position,
      yielding a list.

    Args:
        batch: non-empty sequence of samples sharing the same structure.

    Returns:
        The collated batch.

    Raises:
        AssertionError: if sequence samples have differing lengths.
        TypeError: if the samples are of an unsupported type.
    """
    elem = batch[0]
    elem_type = type(elem)
    # Must precede the Sequence branch: a str is a Sequence of 1-char strs,
    # so recursing into it would loop forever.
    if isinstance(elem, (str, bytes)):
        return batch
    if isinstance(elem, crypten.CrypTensor):
        return crypten.stack(list(batch), dim=0)
    if isinstance(elem, typing.Mapping):
        return {key: crypten_collate([b[key] for b in batch]) for key in elem}
    # Must precede the Sequence branch: a namedtuple is a tuple and would
    # otherwise be flattened into a plain list, losing its type.
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):
        return elem_type(*(crypten_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, typing.Sequence):
        size = len(elem)
        assert all(
            len(b) == size for b in batch
        ), "each element in list of batch should be of equal size"
        return [crypten_collate(samples) for samples in zip(*batch)]
    raise TypeError(
        "crypten_collate: batch must contain CrypTensor, dicts or lists; found {}".format(
            elem_type
        )
    )
def forward(ctx, input, dim=0):
    """Stacks the tensors in ``input`` along dimension ``dim``.

    ``dim`` is saved on ``ctx`` for use in the backward pass.
    """
    ctx.save_for_backward(dim)
    stacked = crypten.stack(input, dim=dim)
    return stacked
def backward(ctx, grad_output):
    """Backward pass of the mean binary cross-entropy w.r.t. ``pred``.

    Returns ``((1 - target) / (1 - pred) - target / pred) / N * grad_output``,
    with ``N = target.nelement()``. This is the derivative of
    ``-mean(target * log(pred) + (1 - target) * log(1 - pred))``.
    """
    pred, target = ctx.saved_tensors
    # 1 / pred and 1 / (1 - pred) from a single stacked reciprocal() call.
    reciprocals = crypten.stack([pred, 1.0 - pred]).reciprocal()
    inv_pred, inv_one_minus_pred = reciprocals.unbind(dim=0)
    grad = (inv_one_minus_pred * (1.0 - target)) - inv_pred * target
    return grad.div_(target.nelement()).mul_(grad_output)