Example #1
def _chebyshev_polynomials(self, terms):
    r"""Evaluates odd degree Chebyshev polynomials at x

    Chebyshev Polynomials of the first kind are defined as

    .. math::
        P_0(x) = 1, P_1(x) = x, P_n(x) = 2x P_{n - 1}(x) - P_{n-2}(x)

    Args:
        self (MPCTensor): input at which polynomials are evaluated
        terms (int): highest degree of Chebyshev polynomials.
                        Must be even and at least 6.
    Returns:
        MPCTensor of polynomials evaluated at self of shape `(terms // 2, *self.size())`
    """
    if terms % 2 != 0 or terms < 6:
        raise ValueError("Chebyshev terms must be even and >= 6")

    # P_1(x) = x; using the product identity 2 P_m(x) P_n(x) = P_{m+n}(x) + P_{m-n}(x),
    # precompute y = 2 P_2(x) = 4x^2 - 2, so that P_3(x) = (y - 1) * x = 4x^3 - 3x
    polynomials = [self.clone()]
    y = 4 * self.square() - 2
    z = y - 1
    polynomials.append(z.mul(self))

    # odd-degree recurrence: P_{2k+1} = 2 P_2 * P_{2k-1} - P_{2k-3}
    for k in range(2, terms // 2):
        next_polynomial = y * polynomials[k - 1] - polynomials[k - 2]
        polynomials.append(next_polynomial)

    return crypten.stack(polynomials)
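A minimal plaintext sketch of the same odd-degree recurrence in plain torch (illustration only, outside CrypTen), checked against the closed forms P_3(x) = 4x^3 - 3x and P_5(x) = 16x^5 - 20x^3 + 5x:

import torch

x = torch.linspace(-1, 1, 5)
polynomials = [x.clone()]
y = 4 * x.square() - 2            # 2 * P_2(x)
polynomials.append((y - 1) * x)   # P_3(x)
for k in range(2, 3):             # terms = 6 -> P_1, P_3, P_5
    polynomials.append(y * polynomials[k - 1] - polynomials[k - 2])

assert torch.allclose(polynomials[1], 4 * x**3 - 3 * x)
assert torch.allclose(polynomials[2], 16 * x**5 - 20 * x**3 + 5 * x)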
Example #2
def _argmax_helper_pairwise(enc_tensor, dim=None):
    """Returns 1 for all elements that have the highest value in the appropriate
    dimension of the tensor. Uses O(n^2) comparisons and a constant number of
    rounds of communication
    """
    dim = -1 if dim is None else dim
    row_length = enc_tensor.size(dim) if enc_tensor.size(dim) > 1 else 2

    # Copy each row (length - 1) times to compare to each other row
    a = enc_tensor.expand(row_length - 1, *enc_tensor.size())

    # Generate cyclic permutations for each row
    b = crypten.stack(
        [enc_tensor.roll(i + 1, dims=dim) for i in range(row_length - 1)])

    # Use either prod or sum & comparison depending on size
    if row_length - 1 < torch.iinfo(torch.long).bits * 2:
        pairwise_comparisons = a.ge(b, _scale=False)
        result = pairwise_comparisons.prod(0)
        # comparisons were computed without the fixed-point scale; restore it
        result.share *= enc_tensor.encoder._scale
        result.encoder = enc_tensor.encoder
    else:
        # Sum of columns with all 1s will have value equal to (length - 1).
        # Using ge() since it is slightly faster than eq()
        pairwise_comparisons = a.ge(b)
        result = pairwise_comparisons.sum(0).ge(row_length - 1)
    return result, None
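The comparison pattern is easier to see in plaintext: an element is a (tied) maximum exactly when it is greater than or equal to every cyclic rotation of the row. A torch-only sketch (illustration only, outside CrypTen):

import torch

v = torch.tensor([3, 1, 4, 1, 5])
n = v.size(0)
a = v.expand(n - 1, n)                                  # each row repeated n - 1 times
b = torch.stack([v.roll(i + 1) for i in range(n - 1)])  # cyclic rotations
one_hot = a.ge(b).all(dim=0).long()
print(one_hot)  # tensor([0, 0, 0, 0, 1])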
Example #3
    def forward(ctx, input):
        pred, target = input
        ctx.save_multiple_for_backward([pred, target])
        ctx.mark_non_differentiable(target)
        log_pos, log_neg = crypten.stack([pred, 1.0 - pred]).log().unbind(dim=0)
        loss_values = target * log_pos + ((1.0 - target) * log_neg)
        return loss_values.sum().div(-target.nelement())
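This is the standard binary cross-entropy, loss = -(1/N) * sum(t * log(p) + (1 - t) * log(1 - p)), evaluated on secret-shared values. A plaintext sanity check against torch (illustration only):

import torch
import torch.nn.functional as F

pred = torch.tensor([0.9, 0.2])
target = torch.tensor([1.0, 0.0])
log_pos, log_neg = torch.stack([pred, 1.0 - pred]).log().unbind(dim=0)
loss = (target * log_pos + (1.0 - target) * log_neg).sum().div(-target.nelement())
assert torch.isclose(loss, F.binary_cross_entropy(pred, target))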
Example #4
    def polynomial(self, coeffs, func="mul"):
        """Computes a polynomial function on a tensor with given coefficients,
        `coeffs`, that can be a list of values or a 1-D tensor.

        Coefficients should be ordered from the first-order (linear) term to the
        highest-order term; the constant term is not included.
        """
        # Coefficient input type-checking
        if isinstance(coeffs, list):
            coeffs = torch.tensor(coeffs, device=self.device)
        assert is_tensor(coeffs) or crypten.is_encrypted_tensor(
            coeffs
        ), "Polynomial coefficients must be a list or tensor"
        assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor"

        # Handle linear case
        if coeffs.size(0) == 1:
            return self.mul(coeffs)

        # Compute terms of polynomial using exponentially growing tree
        terms = crypten.stack([self, self.square()])
        while terms.size(0) < coeffs.size(0):
            highest_term = terms.index_select(
                0, torch.tensor(terms.size(0) - 1, device=self.device))
            new_terms = getattr(terms, func)(highest_term)
            terms = crypten.cat([terms, new_terms])

        # Truncate to the number of coefficients, then reshape the coefficients
        # for broadcasting against the stacked terms
        terms = terms[:coeffs.size(0)]
        for _ in range(terms.dim() - 1):
            coeffs = coeffs.unsqueeze(1)

        # Multiply terms by coefficients and sum
        return terms.mul(coeffs).sum(0)
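A hypothetical usage sketch (assumes a CrypTen runtime initialized via crypten.init(); values are illustrative), evaluating p(x) = 2x + 3x^2 + x^3 on an encrypted tensor:

import crypten
import torch

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([1.0, 2.0]))
y_enc = x_enc.polynomial([2.0, 3.0, 1.0])  # linear term first, no constant
print(y_enc.get_plain_text())              # approximately tensor([ 6., 24.])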
Example #5
def _truncate_tanh(self):
    """Truncates `out` to +/- clip_value when self is outside [-clip_value, clip_value]."""
    clip_value = config.sigmoid_tanh_clip_value
    if clip_value is None:
        return self
    # indicator bits: self > clip_value and self < -clip_value, respectively
    too_high, too_low = crypten.stack([self, -self]).gt(clip_value)
    in_range = 1 - too_high - too_low
    # +clip_value where too high, -clip_value where too low, self otherwise
    return (too_high - too_low) * clip_value + self.mul(in_range)
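In plaintext the same indicator arithmetic reduces to a clamp; a torch-only check with clip_value = 1 (illustration only):

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
too_high, too_low = torch.stack([x, -x]).gt(1.0).float()
in_range = 1 - too_high - too_low
out = (too_high - too_low) * 1.0 + x * in_range
assert torch.equal(out, x.clamp(-1.0, 1.0))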
Example #6
    def _truncate_tanh(self, maxval, out):
        """Truncates `out` to +/-1 when self is outside [-maxval, maxval].

        Args:
            maxval (int): absolute bound outside of which `out` is truncated to +/-1
            out (torch.Tensor or MPCTensor): tensor to truncate
        """
        too_high, too_low = crypten.stack([self, -self]).gt(maxval)
        in_range = -too_high - too_low + 1
        out = too_high - too_low + out.mul(in_range)
        return out
Example #7
    def forward(ctx, input, skip_forward=False):
        pred, target = input
        assert (
            pred.size() == target.size()
        ), "Incorrect input sizes for binary_cross_entropy"
        ctx.save_multiple_for_backward([pred, target])
        ctx.mark_non_differentiable(target)
        if skip_forward:
            return pred

        # Compute full forward pass
        log_pos, log_neg = crypten.stack([pred, 1.0 - pred]).log().unbind(dim=0)
        loss_values = target * log_pos + ((1.0 - target) * log_neg)
        return -loss_values.mean()
Example #8
    def forward(ctx, input, skip_forward=False):
        logit, target = input
        sigmoid_out = logit.sigmoid()
        assert (
            sigmoid_out.size() == target.size()
        ), "Incorrect input sizes for binary_cross_entropy_with_logits"
        ctx.mark_non_differentiable(target)
        ctx.save_multiple_for_backward([target, sigmoid_out])
        if skip_forward:
            return sigmoid_out

        # Compute full forward pass
        log_pos, log_neg = (
            crypten.stack([sigmoid_out, 1.0 - sigmoid_out]).log().unbind(dim=0)
        )
        loss_values = target * log_pos + ((1.0 - target) * log_neg)
        return -loss_values.mean()
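Numerically this matches torch's reference implementation; a plaintext check (illustration only):

import torch
import torch.nn.functional as F

logit = torch.tensor([2.0, -1.0])
target = torch.tensor([1.0, 0.0])
sigmoid_out = logit.sigmoid()
loss = -(target * sigmoid_out.log() + (1.0 - target) * (1.0 - sigmoid_out).log()).mean()
assert torch.isclose(loss, F.binary_cross_entropy_with_logits(logit, target))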
Example #9
def hardtanh(self, min_value=-1, max_value=1):
    r"""Applies the HardTanh function element-wise

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            1 & \text{ if } x > 1 \\
            -1 & \text{ if } x < -1 \\
            x & \text{ otherwise } \\
        \end{cases}

    The range of the linear region :math:`[-1, 1]` can be adjusted using
    :attr:`min_value` and :attr:`max_value`.

    Args:
        min_value: minimum value of the linear region range. Default: -1
        max_value: maximum value of the linear region range. Default: 1
    """
    # hardtanh(x) = relu(x - min_value) - relu(x - max_value) + min_value,
    # computed with a single stacked relu call
    intermediate = crypten.stack([self - min_value, self - max_value]).relu()
    intermediate = intermediate[0].sub(intermediate[1])
    return intermediate.add_(min_value)
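The stacked-ReLU trick relies on the identity hardtanh(x) = relu(x - min_value) - relu(x - max_value) + min_value, which needs only one (batched) ReLU evaluation under MPC. A plaintext check (illustration only):

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
out = (x - (-1.0)).relu() - (x - 1.0).relu() + (-1.0)
assert torch.equal(out, F.hardtanh(x, -1.0, 1.0))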
Example #10
def crypten_collate(batch):
    elem = batch[0]
    elem_type = type(elem)

    if isinstance(elem, crypten.CrypTensor):
        return crypten.stack(list(batch), dim=0)

    elif isinstance(elem, typing.Sequence):
        size = len(elem)
        assert all(
            len(b) == size for b in batch
        ), "each element in list of batch should be of equal size"
        transposed = zip(*batch)
        return [crypten_collate(samples) for samples in transposed]

    elif isinstance(elem, typing.Mapping):
        return {key: crypten_collate([b[key] for b in batch]) for key in elem}

    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(crypten_collate(samples)
                           for samples in zip(*batch)))

    return "crypten_collate: batch must contain CrypTensor, dicts or lists; found {}".format(
        elem_type)
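A hypothetical usage sketch (assumes crypten.init() has been called; names are illustrative), collating a batch of feature/label dicts into stacked CrypTensors:

import crypten
import torch

crypten.init()
batch = [
    {"x": crypten.cryptensor(torch.rand(3)),
     "y": crypten.cryptensor(torch.zeros(1))}
    for _ in range(4)
]
collated = crypten_collate(batch)
print(collated["x"].size())  # torch.Size([4, 3])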
Example #11
    def forward(ctx, input, dim=0):
        ctx.save_for_backward(dim)
        return crypten.stack(input, dim=dim)
Example #12
    def backward(ctx, grad_output):
        pred, target = ctx.saved_tensors
        rec_pos, rec_neg = crypten.stack([pred, 1.0 - pred]).reciprocal().unbind(dim=0)
        grad = (rec_neg * (1.0 - target)) - rec_pos * target
        return grad.div_(target.nelement()).mul_(grad_output)
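This matches the analytic gradient of the mean BCE loss, dL/dp = ((1 - t) / (1 - p) - t / p) / N; a plaintext check with torch autograd (illustration only):

import torch

pred = torch.tensor([0.9, 0.2], requires_grad=True)
target = torch.tensor([1.0, 0.0])
loss = -(target * pred.log() + (1 - target) * (1 - pred).log()).mean()
loss.backward()

p = pred.detach()
manual = ((1 - target) / (1 - p) - target / p) / target.nelement()
assert torch.allclose(pred.grad, manual)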