示例#1
0
    def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor]:
        """Perform the backward pass for the matrix multiplication operation.

        For ``out = x @ y`` the gradients are ``grad @ y.T`` (w.r.t. ``x``) and
        ``x.T @ grad`` (w.r.t. ``y``). A 1-D ``x`` is temporarily treated as a
        row vector and a 1-D ``y`` as a column vector. The incoming gradient is
        promoted with the matching missing dimension(s), and the resulting
        gradients are squeezed back so they have the original shapes of ``x``
        and ``y``.

        Args:
            ctx (Dict[str, Any]): Context used to retrieve the information for the backward pass.
                Must hold the forward operands under the keys ``"x"`` and ``"y"``.
            grad (MPCTensor): The gradient that came from the child nodes

        Returns:
            (x_grad, y_grad) (Tuple[MPCTensor]): The gradients passed to the X and Y nodes,
                shaped like X and Y respectively.

        Raises:
            ValueError: if gradient shape does not match X and Y shape
        """
        x, y = ctx["x"], ctx["y"]
        x_shape, y_shape = x.shape, y.shape

        x_is_vector = len(x_shape) < 2
        y_is_vector = len(y_shape) < 2

        # The forward product of a 1-D ``x`` has no row dimension and that of a
        # 1-D ``y`` has no column dimension, so the *same* gradient tensor must
        # regain every missing dimension before both products below.
        if x_is_vector:
            x = x.unsqueeze(0)
            grad = grad.unsqueeze(0)

        if y_is_vector:
            y = y.unsqueeze(1)
            # grad is 1-D at this point in every case, so dim 1 is its new last dim.
            grad = grad.unsqueeze(1)

        # NOTE(review): only 1-D / 2-D operands are handled; `.t()` presumably
        # rejects batched (>2-D) operands — confirm if batched matmul is needed.
        x_grad = grad @ y.t()
        y_grad = x.t() @ grad

        # Undo the promotion so each gradient matches its operand's shape.
        if x_is_vector:
            x_grad = x_grad.squeeze(0)

        if y_is_vector:
            y_grad = y_grad.squeeze(1)

        if x_grad.shape != x_shape or y_grad.shape != y_shape:
            raise ValueError(
                "The gradient shape and the shape of X and Y should be the same!"
            )

        return x_grad, y_grad
示例#2
0
def _chebyshev_polynomials(tensor: MPCTensor, terms: int) -> MPCTensor:
    r"""Evaluate the odd-degree Chebyshev polynomials of the first kind at ``tensor``.

    Chebyshev polynomials of the first kind satisfy
        T_0(x) = 1, T_1(x) = x, T_n(x) = 2x T_{n-1}(x) - T_{n-2}(x).

    Only the odd degrees 1, 3, ..., terms - 1 are produced. They are built with
    the two-step recurrence
        T_{n+2}(x) = (4x^2 - 2) T_n(x) - T_{n-2}(x),
    seeded with T_1(x) = x and T_3(x) = (4x^2 - 3) x.

    Args:
        tensor (MPCTensor): input at which the polynomials are evaluated
        terms (int): number of Chebyshev terms; odd degrees up to ``terms - 1``
            are evaluated. Must be even and at least 6.

    Returns:
        MPCTensor: T_1, T_3, ..., T_{terms-1} evaluated at ``tensor``, combined
        with ``stack`` into ``terms // 2`` entries (presumably along a new
        leading dimension — confirm against ``stack``).

    Raises:
        ValueError: if terms < 6 or is not divisible by 2
    """
    if terms < 6 or terms % 2 != 0:
        raise ValueError("Chebyshev terms must be even and >= 6")

    # 4x^2 - 2 is the multiplier that advances the odd-degree recurrence by two.
    step = 4 * tensor * tensor - 2
    t_one = tensor.clone()         # T_1(x) = x
    t_three = (step - 1) * tensor  # T_3(x) = 4x^3 - 3x

    odd_terms = [t_one, t_three]
    for _ in range(terms // 2 - 2):
        odd_terms.append(step * odd_terms[-1] - odd_terms[-2])

    return stack(odd_terms)