def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor]:
    """Perform the backward pass for the matrix multiplication operation.

    For ``out = x @ y`` with 2-D operands the gradients are
    ``x_grad = grad @ y.t()`` and ``y_grad = x.t() @ grad``. A 1-D operand is
    promoted to a matrix (``x`` to a row, ``y`` to a column), and the
    dimension the forward matmul dropped from its output is restored on
    ``grad``. The extra dimension is removed from the result again, so each
    gradient has exactly the shape of the operand it belongs to.

    Args:
        ctx (Dict[str, Any]): Context used to retrieve the information for the
            backward pass; the operands are read from the keys "x" and "y".
        grad (MPCTensor): The gradient that came from the child nodes.

    Returns:
        (x_grad, y_grad) (Tuple[MPCTensor]): The gradients passed to the X and
            Y nodes, shaped like X and Y respectively.

    Raises:
        ValueError: if gradient shape does not match X and Y shape
    """
    x, y = ctx["x"], ctx["y"]

    x_is_vector = len(x.shape) < 2
    y_is_vector = len(y.shape) < 2

    # A 1-D operand is treated as a matrix with a singleton dimension. The
    # forward matmul removed that dimension from its output, so it must be
    # put back on ``grad`` itself. Both gradient products use this same
    # ``grad``, not only the product belonging to the vector operand.
    if x_is_vector:
        x = x.unsqueeze(0)  # (n,) -> (1, n)
        grad = grad.unsqueeze(0)
    if y_is_vector:
        y = y.unsqueeze(1)  # (n,) -> (n, 1)
        grad = grad.unsqueeze(1)

    # NOTE(review): ``.t()`` presumably mirrors torch.Tensor.t, i.e. only
    # 2-D operands are supported; batched (>2-D) matmul is not handled here.
    x_grad = grad @ y.t()
    y_grad = x.t() @ grad

    if x.shape != x_grad.shape or y.shape != y_grad.shape:
        raise ValueError(
            "The gradient shape and the shape of X and Y should be the same!"
        )

    # Drop the singleton dimension added above so the gradients match the
    # original operand shapes.
    if x_is_vector:
        x_grad = x_grad.squeeze(0)
    if y_is_vector:
        y_grad = y_grad.squeeze(1)

    return x_grad, y_grad
def _chebyshev_polynomials(tensor: MPCTensor, terms: int) -> MPCTensor: r"""Evaluates odd degree Chebyshev polynomials at x. Chebyshev Polynomials of the first kind are defined as P_0(x) = 1, P_1(x) = x, P_n(x) = 2 P_{n - 1}(x) - P_{n-2}(x) Args: tensor (MPCTensor): input at which polynomials are evaluated terms (int): highest degree of Chebyshev polynomials. Must be even and at least 6. Returns: MPCTensor: polynomials evaluated at self of shape `(terms, *self)` Raises: ValueError: if terms < 6 or is not divisible by 2 """ if terms % 2 != 0 or terms < 6: raise ValueError("Chebyshev terms must be even and >= 6") polynomials = [tensor.clone()] y = 4 * tensor * tensor - 2 z = y - 1 polynomials.append(z * tensor) for k in range(2, terms // 2): next_polynomial = y * polynomials[k - 1] - polynomials[k - 2] polynomials.append(next_polynomial) return stack(polynomials)