Example #1
File: linop.py Project: AdityaJ7/xitorch
    def check(self, warn: Union[bool, None] = None) -> None:
        """
        Perform checks to make sure the linear operator behaves as a proper
        linear operator.

        Arguments
        ---------
        * warn: bool or None
            If True, then raises a warning to the user that the check might slow
            down the program. This is to remind the user to turn off the check
            when not in a debugging mode.
            If None, it will raise a warning if it is not run in debug mode, and
            will be silent if it runs in debug mode.

        Exceptions
        ----------
        * RuntimeError
            Raised if an error occurs when performing the linear operations of the
            object (e.g. calling .mv(), .mm(), etc.)
        * AssertionError
            Raised if the linear operations do not behave as proper linear operations.
            (e.g. not scaling linearly)
        """
        if warn is None:
            warn = not is_debug_enabled()
        if warn:
            msg = "The linear operator check is performed. This might slow down your program."
            warnings.warn(msg, stacklevel=2)
        checklinop(self)
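
A minimal usage sketch for the check above, assuming `xitorch.LinearOperator.m` wraps a dense matrix as a `LinearOperator` (the constructor name is an assumption, not shown in this snippet):

import torch
import xitorch

# wrap a dense matrix as a LinearOperator and run the consistency check;
# warn=False silences the "might slow down your program" warning
mat = torch.randn(3, 3, dtype=torch.float64)
linop = xitorch.LinearOperator.m(mat)
linop.check(warn=False)
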
Example #2
    def check(self, warn: Optional[bool] = None) -> None:
        """
        Perform checks to make sure the ``LinearOperator`` behaves as a proper
        linear operator.

        Arguments
        ---------
        warn: bool or None
            If ``True``, then raises a warning to the user that the check might slow
            down the program. This is to remind the user to turn off the check
            when not in a debugging mode.
            If ``None``, it will raise a warning if it is not run in debug mode, and
            will be silent if it runs in debug mode.

        Raises
        ------
        RuntimeError
            Raised if an error occurs when performing the linear operations of the
            object (e.g. calling ``.mv()``, ``.mm()``, etc.)
        AssertionError
            Raised if the linear operations do not behave as proper linear operations.
            (e.g. not scaling linearly)
        """
        if warn is None:
            warn = not is_debug_enabled()
        if warn:
            msg = "The linear operator check is performed. This might slow down your program."
            warnings.warn(msg, stacklevel=2)
        checklinop(self)
        print("Check linear operator done")
Example #3
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             bck_options: Mapping[str, Any] = {},
             method: Union[str, None] = None,
             **fwd_options) -> torch.Tensor:
    """
    Solve the unbounded minimization problem:

    .. math::

        \mathbf{y^*} = \\arg\min_\mathbf{y} f(\mathbf{y}, \\theta)

    to find the best :math:`\mathbf{y}` that minimizes the output of the
    function :math:`f`.

    Arguments
    ---------
    fcn: callable
        The function to be minimized, whose output is a tensor with a single element.
    y0: torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params: list
        List of any other parameters to be put in ``fcn``
    bck_options: dict
        Method-specific options for the backward solve.
    method: str or None
        Minimization method.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution of the minimization with shape ``(*ny)``
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    fwd_options["method"] = _get_minimizer_default_method(method)
    # the rootfinder algorithms are designed to move in the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y
    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
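
The wrapper `new_fcn` above converts the scalar objective into its gradient so the rootfinder can drive it to zero. A small self-contained sketch of that idea in plain PyTorch (no xitorch required):

import torch

# turn a scalar objective f(y, *params) into the function returning df/dy;
# a minimum of f is a stationary point, i.e. a root of this gradient function
def grad_fcn(f, y, *params):
    with torch.enable_grad():
        y1 = y.clone().requires_grad_()
        z = f(y1, *params)
    grady, = torch.autograd.grad(z, (y1,), create_graph=torch.is_grad_enabled())
    return grady

y = torch.tensor([1.0, -2.0])
print(grad_fcn(lambda t: (t ** 2).sum(), y))  # tensor([ 2., -4.])
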
Example #4
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                bck_options: Mapping[str, Any] = {},
                method: Union[str, None] = None,
                **fwd_options):
    """
    Solving the equilibrium equation of a given function,

    .. math::

        \mathbf{y} = \mathbf{f}(\mathbf{y}, \\theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produce output of the same shape of :math:`\mathbf{y}`, and :math:`\\theta`
    is other parameters required in the function.
    The output of this block is the :math:`\mathbf{y}`
    that satisfies :math:`\mathbf{y} = \mathbf{f}(\mathbf{y}, \\theta)`.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        List of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve
    method : str or None
        Rootfinder method.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\\theta)`
        with shape ``(*ny)``
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.getobjparams())
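
As the code above shows, the equilibrium point of `fcn` is found as the root of `new_fcn(y) = y - fcn(y)`. A tiny plain-PyTorch illustration of that equivalence (xitorch uses a proper rootfinder instead of this naive iteration):

import torch

def f(y):
    return torch.cos(y)          # the fixed point of cos(y) is y ~ 0.739

y = torch.tensor(0.5)
for _ in range(60):              # naive fixed-point iteration, for illustration only
    y = f(y)
print(y, y - f(y))               # y - f(y) ~ 0 at the equilibrium
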
Example #5
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             fwd_options: Mapping[str, Any] = {},
             bck_options: Mapping[str, Any] = {}) -> torch.Tensor:
    """
    Solve the minimization problem:

        z = (argmin_y) fcn(y, *params)

    to find the best `y` that minimizes the output of the function `fcn`.
    The output of `fcn` must be a single element tensor.

    Arguments
    ---------
    * fcn: callable with output tensor (numel=1)
        The function
    * y0: torch.tensor with shape (*ny)
        Initial guess of the solution
    * params: list
        List of any other parameters to be put in fcn
    * fwd_options: dict
        Options for the minimizer method
    * bck_options: dict
        Options for the backward solve method

    Returns
    -------
    * y: torch.tensor with shape (*ny)
        The solution of the minimization
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    # the rootfinder algorithms are designed to move in the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y
    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #6
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                fwd_options: Mapping[str, Any] = {},
                bck_options: Mapping[str, Any] = {}):
    """
    Solving the equilibrium equation of a given function,

        y = fcn(y, *params)

    where `fcn` is a function that can be non-linear and produces an output of
    the same shape as `y`. The output of this block is the `y` that satisfies
    y = fcn(y, *params).

    Arguments
    ---------
    * fcn: callable with output tensor (*ny)
        The function
    * y0: torch.tensor with shape (*ny)
        Initial guess of the solution
    * params: list
        List of any other parameters to be put in fcn
    * fwd_options: dict
        Options for the rootfinder method
    * bck_options: dict
        Options for the backward solve method

    Returns
    -------
    * yout: torch.tensor with shape (*ny)
        The solution which satisfies yout = fcn(yout)

    Note
    ----
    * To obtain the correct gradient and higher order gradients, the fcn must be:
        - a torch.nn.Module whose fcn.parameters() lists the tensors that
            determine the output of fcn.
        - a method of an xt.EditableModule object with no out-of-scope parameters.
        - a function with no out-of-scope parameters.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.getobjparams())
Example #7
    def backward(ctx, grad_eival, grad_eivec):
        in_debug_mode = is_debug_enabled()

        eival, eivec = ctx.saved_tensors
        min_threshold = torch.finfo(eival.dtype).eps ** 0.6
        eivect = eivec.transpose(-2, -1)

        # remove the degenerate part
        # see https://arxiv.org/pdf/2011.04366.pdf
        if grad_eivec is not None:
            # take the contribution from the eivec
            F = eival.unsqueeze(-2) - eival.unsqueeze(-1)
            idx = torch.abs(F) < min_threshold
            F[idx] = float("inf")

            # if in debug mode, check the degeneracy requirements
            if in_debug_mode:
                degenerate = torch.any(idx)
                xtg = eivect @ grad_eivec
                diff_xtg = (xtg - xtg.transpose(-2, -1))[idx]
                reqsat = torch.allclose(diff_xtg, torch.zeros_like(diff_xtg))
                # if the requirement is not satisfied, mathematically the derivative
                # should be `nan`, but here we just raise a warning
                if not reqsat:
                    msg = ("Degeneracy appears but the loss function seem to depend "
                           "strongly on the eigenvector. The gradient might be incorrect.\n")
                    msg += "Eigenvalues:\n%s\n" % str(eival)
                    msg += "Degenerate map:\n%s\n" % str(idx)
                    msg += "Requirements (should be all 0s):\n%s" % str(diff_xtg)
                    warnings.warn(MathWarning(msg))

            F = F.pow(-1)
            F = F * torch.matmul(eivect, grad_eivec)
            result = torch.matmul(eivec, torch.matmul(F, eivect))
        else:
            result = torch.zeros_like(eivec)

        # calculate the contribution from the eival
        if grad_eival is not None:
            result += torch.matmul(eivec, grad_eival.unsqueeze(-1) * eivect)

        # symmetrize to reduce numerical instability
        result = (result + result.transpose(-2, -1)) * 0.5
        return result
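
A small numeric illustration of the degeneracy handling above: pairwise eigenvalue differences are masked to infinity when they fall below the threshold, so their reciprocals contribute zero:

import torch

eival = torch.tensor([1.0, 1.0, 2.0], dtype=torch.float64)  # first two are degenerate
min_threshold = torch.finfo(eival.dtype).eps ** 0.6
F = eival.unsqueeze(-2) - eival.unsqueeze(-1)    # F[i, j] = eival[j] - eival[i]
F[torch.abs(F) < min_threshold] = float("inf")   # mask diagonal and degenerate pairs
print(F.pow(-1))                                 # masked entries become 0
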
Example #8
def _mcquad(ffcn, log_pfcn, x0, xsamples, wsamples, fparams, pparams,
            fwd_options, bck_options):
    # this is mcquad with an additional xsamples argument, to prevent xsamples being set by users

    if is_debug_enabled():
        assert_fcn_params(ffcn, (x0, *fparams))
        assert_fcn_params(log_pfcn, (x0, *pparams))

    # check if ffcn produces a list / tuple
    out = ffcn(x0, *fparams)
    is_tuple_out = isinstance(out, list) or isinstance(out, tuple)

    # get the pure functions
    pure_ffcn = get_pure_function(ffcn)
    pure_logpfcn = get_pure_function(log_pfcn)
    nfparams = len(fparams)
    npparams = len(pparams)
    fobjparams = pure_ffcn.objparams()
    pobjparams = pure_logpfcn.objparams()
    nf_objparams = len(fobjparams)

    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pure_ffcn)
        def pure_ffcn2(x, *fparams):
            y = pure_ffcn(x, *fparams)
            return packer.flatten(y)

        res = _MCQuad.apply(pure_ffcn2, pure_logpfcn, x0, None, None,
                            fwd_options, bck_options, nfparams, nf_objparams,
                            npparams, *fparams, *fobjparams, *pparams,
                            *pobjparams)
        return packer.pack(res)
    else:
        return _MCQuad.apply(pure_ffcn, pure_logpfcn, x0, None, None,
                             fwd_options, bck_options, nfparams, nf_objparams,
                             npparams, *fparams, *fobjparams, *pparams,
                             *pobjparams)
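
A conceptual sketch of the flatten/pack round trip that `TensorPacker` performs above (this is not xitorch's `TensorPacker`, only an illustration of the idea):

import torch

def flatten(tensors):
    # concatenate a list of tensors into a single 1D tensor
    return torch.cat([t.reshape(-1) for t in tensors])

def pack(flat, like):
    # split the flat tensor back into the original shapes
    out, i = [], 0
    for t in like:
        out.append(flat[i:i + t.numel()].reshape(t.shape))
        i += t.numel()
    return out

ys = [torch.zeros(2, 3), torch.ones(4)]
restored = pack(flatten(ys), ys)
print(all(torch.equal(a, b) for a, b in zip(restored, ys)))  # True
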
Example #9
File: symeig.py Project: mfkasim1/xitorch
def svd(A: LinearOperator,
        k: Optional[int] = None,
        mode: str = "uppest",
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrix and
    :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative
    numbers.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the backpropagation
        calculation.
    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending on which one is cheaper)

    Warnings
    --------
    * If ``s`` contains very small numbers or degenerate values, the
      calculation and its gradient might be inaccurate.
    * The second derivative through U or V might be unstable.
      Extra care must be taken.
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()
    BA = A.shape[:-2]

    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n

    eivals, eivecs = symeig(AAsym,
                            k,
                            mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to a small positive values to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1)
    return u, s, vh
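
A hedged usage sketch for the svd above, assuming the import path `xitorch.linalg.svd` and the `LinearOperator.m` constructor (both are assumptions based on the project layout, not shown in this snippet):

import torch
import xitorch
from xitorch.linalg import svd     # assumed import path

A = xitorch.LinearOperator.m(torch.randn(4, 3, dtype=torch.float64))
u, s, vh = svd(A, k=2)             # 2 uppermost singular triplets (mode defaults to "uppest")
print(u.shape, s.shape, vh.shape)  # (4, 2), (2,), (2, 3)
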
Example #10
File: quad.py Project: xitorch/xitorch
def quad(fcn: Union[Callable[..., torch.Tensor],
                    Callable[..., Sequence[torch.Tensor]]],
         xl: Union[float, int, torch.Tensor],
         xu: Union[float, int, torch.Tensor],
         params: Sequence[Any] = [],
         bck_options: Mapping[str, Any] = {},
         method: Union[str, Callable, None] = None,
         **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    r"""
    Calculate the quadrature:

    .. math::

        y = \int_{x_l}^{x_u} f(x, \theta)\ \mathrm{d}x

    Arguments
    ---------
    fcn: callable
        The function to be integrated. Its output must be a tensor with
        shape ``(*nout)`` or list of tensors.
    xl: float, int or 1-element torch.Tensor
        The lower bound of the integration.
    xu: float, int or 1-element torch.Tensor
        The upper bound of the integration.
    params: list
        Sequence of any other parameters for the function ``fcn``.
    bck_options: dict
        Options for the backward quadrature method.
    method: str or callable or None
        Quadrature method. If None, it will choose ``"leggauss"``.
    **fwd_options
        Method-specific options (see method section).

    Returns
    -------
    torch.tensor or a list of tensors
        The quadrature results with shape ``(*nout)`` or list of tensors.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (xl, *params))
    if isinstance(xl, torch.Tensor):
        assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensor")
    if isinstance(xu, torch.Tensor):
        assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensor")
    if method is None:
        method = "leggauss"
    fwd_options["method"] = method

    out = fcn(xl, *params)
    if isinstance(out, torch.Tensor):
        dtype = out.dtype
        device = out.device
        is_tuple_out = False
    elif len(out) > 0:
        dtype = out[0].dtype
        device = out[0].device
        is_tuple_out = True
    else:
        raise RuntimeError("The output of the fcn must be non-empty")

    pfunc = get_pure_function(fcn)
    nparams = len(params)
    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pfunc)
        def pfunc2(x, *params):
            y = fcn(x, *params)
            return packer.flatten(y)

        res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options,
                                nparams, dtype, device, *params,
                                *pfunc.objparams())
        return packer.pack(res)
    else:
        return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options,
                                 nparams, dtype, device, *params,
                                 *pfunc.objparams())
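
A hedged usage sketch for quad, assuming the import path `xitorch.integrate.quad` (the "File: quad.py" header above suggests this module, but the path is an assumption):

import torch
from xitorch.integrate import quad  # assumed import path

# integrate f(x, a) = exp(-a * x**2) from 0 to 1; the result is differentiable w.r.t. a
a = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
y = quad(lambda x, a: torch.exp(-a * x * x), 0.0, 1.0, params=(a,))
print(y)
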
Example #11
def solve(A: LinearOperator,
          B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Optional[LinearOperator] = None,
          bck_options: Mapping[str, Any] = {},
          method: Union[str, Callable, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Perform an iterative method to solve the equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve multiple batched inverse equations at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in the same
        space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored.
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    bck_options: dict
        Options of the iterative solver in the backward calculation.
    method: str or callable or None
        The method of linear equation solver. If ``None``, it will choose
        ``"cg"`` or ``"bicgstab"`` based on the matrices symmetry.
        `Note`: the default method may change quite frequently, so if you want
        future compatibility, please specify a method.
    **fwd_options
        Method-specific options (see method below)

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`.
    """
    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using solve with grad enabled")
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using solve with grad enabled")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" %
            (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
           (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exactsolve"
        else:
            is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
            method = "cg" if is_hermit else "bicgstab"

    if method == "exactsolve":
        return exactsolve(A, B, E, M)
    else:
        # get the unique parameters of A
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return solve_torchfcn.apply(A, B, E, M, method, fwd_options,
                                    bck_options, na, *params, *mparams)
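
A hedged usage sketch for solve, assuming `xitorch.linalg.solve` and `LinearOperator.m(..., is_hermitian=True)` (import path and constructor are assumptions):

import torch
import xitorch
from xitorch.linalg import solve      # assumed import path

Amat = torch.randn(3, 3, dtype=torch.float64)
Amat = Amat @ Amat.T + 3 * torch.eye(3, dtype=torch.float64)  # symmetric, well conditioned
A = xitorch.LinearOperator.m(Amat, is_hermitian=True)
B = torch.randn(3, 2, dtype=torch.float64)
X = solve(A, B)                       # dense operators fall back to "exactsolve"
print(torch.allclose(Amat @ X, B))    # True
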
Example #12
def solve_ivp(fcn: Callable[..., torch.Tensor],
              ts: torch.Tensor,
              y0: torch.Tensor,
              params: Sequence[Any] = [],
              fwd_options: Mapping[str, Any] = {},
              bck_options: Mapping[str, Any] = {}) -> torch.Tensor:
    """
    Solve the initial value problem (IVP): given the initial value `y0`,
    solve

        y(t) = y0 + int_t0^t f(t', y, *params) dt'

    Arguments
    ---------
    * fcn: callable with output a tensor with shape (*ny) or a list of tensors
        The function that represents dy/dt. The function takes an input of a
        single time `t` and `y` with shape (*ny) and produces dydt with shape (*ny).
    * ts: torch.tensor with shape (nt,)
        The time points where the value of `y` is returned.
        It must be monotonically increasing or decreasing.
    * y0: torch.tensor with shape (*ny) or a list of tensors
        The initial value of y, i.e. y(t[0]) == y0
    * params: list
        List of other parameters required in the function.
    * fwd_options: dict
        Options for the forward solve_ivp method.
    * bck_options: dict
        Options for the backward solve_ivp method.

    Returns
    -------
    * yt: torch.tensor with shape (nt,*ny) or a list of tensors
        The values of `y` for each time step in `ts`.
    """
    if is_debug_enabled():
        assert_fcn_params(fcn, (ts[0], y0, *params))
    assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor")

    # run once to see if the output is a tuple or a single tensor
    is_y0_list = isinstance(y0, list) or isinstance(y0, tuple)
    dydt = fcn(ts[0], y0, *params)
    is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple)
    if is_y0_list != is_dydt_list:
        raise RuntimeError(
            "The y0 and output of fcn must both be tuple or a tensor")

    pfcn = get_pure_function(fcn)
    if is_y0_list:
        nt = len(ts)
        roller = TensorPacker(y0)

        @make_sibling(pfcn)
        def pfcn2(t, ytensor, *params):
            ylist = roller.pack(ytensor)
            res_list = pfcn(t, ylist, *params)
            res = roller.flatten(res_list)
            return res

        y0 = roller.flatten(y0)
        res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params),
                              y0, *params, *pfcn.objparams())
        return roller.pack(res)
    else:
        return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params),
                               y0, *params, *pfcn.objparams())
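
A hedged usage sketch for solve_ivp, assuming the import path `xitorch.integrate.solve_ivp`:

import torch
from xitorch.integrate import solve_ivp  # assumed import path

# dy/dt = -a*y with y(0) = 1, so y(t) = exp(-a*t)
a = torch.tensor(1.5, dtype=torch.float64)
ts = torch.linspace(0, 1, 5, dtype=torch.float64)
y0 = torch.ones(1, dtype=torch.float64)
yt = solve_ivp(lambda t, y, a: -a * y, ts, y0, params=(a,))
print(yt.shape)    # (5, 1)
print(yt[:, 0])    # should be close to torch.exp(-a * ts)
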
Example #13
def svd(A: LinearOperator,
        k: Optional[int] = None,
        mode: str = "uppest",
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrix and
    :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative
    numbers.
    This function can handle derivatives for degenerate singular values by setting non-zero
    ``degen_atol`` and ``degen_rtol`` in the backward option using the expressions
    in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decomposition.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the backpropagation
        calculation with some additional arguments for computing the backward
        derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        the specific values.

    method: str or callable or None
        Method for the svd (same options for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively
        ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending on which one is cheaper)

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric matrix for degenerate cases".
           arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds

    if is_debug_enabled():
        A.check()
    BA = A.shape[:-2]

    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n

    eivals, eivecs = symeig(AAsym,
                            k,
                            mode,
                            bck_options=bck_options,
                            method=method,
                            **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to a small positive values to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1)
    return u, s, vh
Example #14
def solve(A: LinearOperator,
          B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None,
          posdef=False,
          fwd_options: Mapping[str, Any] = {},
          bck_options: Mapping[str, Any] = {}):
    """
    Perform an iterative method to solve the equation AX=B or
    AX-MXE=B, where E is a diagonal matrix.
    This function can also solve multiple batched inverse equations at the
    same time by applying A to a tensor X with shape (...,na,ncols).
    The applied E are not necessarily identical for each column.

    Arguments
    ---------
    * A: xitorch.LinearOperator instance with shape (*BA, na, na)
        A function that takes an input X and produces vectors in the same
        space as B.
    * B: torch.tensor (*BB, na, ncols)
        The tensor on the right hand side.
    * E: torch.tensor (*BE, ncols) or None
        If not None, it will solve AX-MXE = B. Otherwise, it just solves
        AX = B and M is ignored. E would be applied to every column.
    * M: xitorch.LinearOperator instance (*BM, na, na) or None
        The transformation on the E side. If E is None,
        then this argument is ignored. If E is not None and M is None, then M=I.
        This LinearOperator must be Hermitian.
    * fwd_options: dict
        Options of the iterative solver in the forward calculation
    * bck_options: dict
        Options of the iterative solver in the backward calculation
    """
    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" %
            (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower(
    ) == "exactsolve":
        return exactsolve(A, B, E, M)
    else:
        # get the unique parameters of A
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return solve_torchfcn.apply(A, B, E, M, posdef, fwd_options,
                                    bck_options, na, *params, *mparams)
Example #15
File: symeig.py Project: AdityaJ7/xitorch
def symeig(A: LinearOperator,
           neig: Union[int, None] = None,
           mode: str = "lowest",
           M: Union[LinearOperator, None] = None,
           fwd_options: Mapping[str, Any] = {},
           bck_options: Mapping[str, Any] = {}):
    """
    Obtain `neig` lowest eigenvalues and eigenvectors of a linear operator.
    If M is specified, it solves the eigendecomposition Ax = eMx.

    Arguments
    ---------
    * A: xitorch.LinearOperator hermitian instance with shape (*BA, q, q)
        The linear module object on which the eigenpairs are constructed.
    * neig: int or None
        The number of eigenpairs to be retrieved. If None, all eigenpairs are
        retrieved
    * mode: str
        "lowest" or "uppermost"/"uppest". If "lowest", it will take the lowest
        `neig` eigenpairs. If "uppest", it will take the uppermost `neig`.
    * M: xitorch.LinearOperator hermitian instance with shape (*BM, q, q) or None
        The transformation on the right hand side. If None, then M=I.
    * fwd_options: dict with str as key
        Eigendecomposition iterative algorithm options.
    * bck_options: dict with str as key
        Conjugate gradient options to calculate the gradient in
        backpropagation calculation.

    Returns
    -------
    * eigvals: (*BAM, neig)
    * eigvecs: (*BAM, na, neig)
        The lowest eigenvalues and eigenvectors, where *BAM are the broadcasted
        shape of *BA and *BM.
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if "method" not in fwd_options or fwd_options["method"].lower(
    ) == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M, fwd_options,
                                     bck_options, na, *params, *mparams)
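
A hedged usage sketch for symeig, assuming `xitorch.linalg.symeig` and the `LinearOperator.m` constructor (both assumed, not shown in this snippet):

import torch
import xitorch
from xitorch.linalg import symeig  # assumed import path

mat = torch.randn(4, 4, dtype=torch.float64)
mat = (mat + mat.T) * 0.5                        # symeig requires a Hermitian operator
A = xitorch.LinearOperator.m(mat, is_hermitian=True)
evals, evecs = symeig(A, neig=2, mode="lowest")
print(evals.shape, evecs.shape)                  # (2,), (4, 2)
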
Example #16
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             bck_options: Mapping[str, Any] = {},
             method: Union[str, Callable, None] = None,
             **fwd_options) -> torch.Tensor:
    r"""
    Solve the unbounded minimization problem:

    .. math::

        \mathbf{y^*} = \arg\min_\mathbf{y} f(\mathbf{y}, \theta)

    to find the best :math:`\mathbf{y}` that minimizes the output of the
    function :math:`f`.

    Arguments
    ---------
    fcn: callable
        The function to be minimized, whose output is a tensor with a single element.
    y0: torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params: list
        Sequence of any other parameters to be put in ``fcn``
    bck_options: dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method: str or callable or None
        Minimization method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution of the minimization with shape ``(*ny)``

    Example
    -------
    .. testsetup:: root1

        import torch
        from xitorch.optimize import minimize

    .. doctest:: root1

        >>> def func1(y, A):  # example function
        ...     return torch.sum((A @ y)**2 + y / 2.0)
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2,1))  # zeros as the initial guess
        >>> ymin = minimize(func1, y0, params=(A,))
        >>> print(ymin)
        tensor([[-0.0519],
                [-0.2684]], grad_fn=<_RootFinderBackward>)
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    fwd_options["method"] = _get_minimizer_default_method(method)
    # the rootfinder algorithms are designed to move in the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #17
def quad(fcn: Union[Callable[..., torch.Tensor], Callable[...,
                                                          List[torch.Tensor]]],
         xl: Union[float, int, torch.Tensor],
         xu: Union[float, int, torch.Tensor],
         params: Sequence[Any] = [],
         fwd_options: Mapping[str, Any] = {},
         bck_options: Mapping[str, Any] = {}):
    """
    Calculate the quadrature of the function `fcn` from `xl` to `xu`:

        y = int_xl^xu fcn(x, *params)

    Arguments
    ---------
    * fcn: callable with output tensor with shape (*nout) or list of tensors
        The function to be integrated.
    * xl, xu: float, int, or 1-element torch.Tensor
        The lower and upper bound of the integration.
    * params: list
        List of any other parameters for the function `fcn`.
    * fwd_options: dict
        Options for the forward quadrature method.
    * bck_options: dict
        Options for the backward quadrature method.

    Returns
    -------
    * y: torch.tensor with shape (*nout) or list of tensors
        The quadrature results.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (xl, *params))
    if isinstance(xl, torch.Tensor):
        assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensor")
    if isinstance(xu, torch.Tensor):
        assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensor")

    out = fcn(xl, *params)
    is_tuple_out = not isinstance(out, torch.Tensor)
    if not is_tuple_out:
        dtype = out.dtype
        device = out.device
    elif len(out) > 0:
        dtype = out[0].dtype
        device = out[0].device
    else:
        raise RuntimeError("The output of the fcn must be non-empty")

    pfunc = get_pure_function(fcn)
    nparams = len(params)
    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pfunc)
        def pfunc2(x, *params):
            y = fcn(x, *params)
            return packer.flatten(y)

        res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options,
                                nparams, dtype, device, *params,
                                *pfunc.objparams())
        return packer.pack(res)
    else:
        return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options,
                                 nparams, dtype, device, *params,
                                 *pfunc.objparams())
Example #18
    def backward(ctx, grad_evals, grad_evecs):
        # grad_evals: (*BAM, neig)
        # grad_evecs: (*BAM, na, neig)

        # get the variables from ctx
        evals, evecs = ctx.saved_tensors[:2]
        na = ctx.na
        amparams = ctx.saved_tensors[2:]
        params = amparams[:na]
        mparams = amparams[na:]

        M = ctx.M
        A = ctx.A
        degen_atol: Optional[float] = ctx.bck_alg_config["degen_atol"]
        degen_rtol: Optional[float] = ctx.bck_alg_config["degen_rtol"]

        # set the default values of degen_*tol
        dtype = evals.dtype
        if degen_atol is None:
            degen_atol = torch.finfo(dtype).eps**0.6
        if degen_rtol is None:
            degen_rtol = torch.finfo(dtype).eps**0.4

        # check the degeneracy
        if degen_atol > 0 or degen_rtol > 0:
            # idx_degen: (*BAM, neig, neig)
            idx_degen, isdegenerate = _check_degen(evals, degen_atol,
                                                   degen_rtol)
        else:
            isdegenerate = False
        if not isdegenerate:
            idx_degen = None

        # the loss function where the gradient will be retrieved
        # warning: if not all params are connected to the output of A,
        # it could cause an infinite loop because pytorch will keep looking
        # for the *params node and propagate further backward via the `evecs`
        # path. So make sure all the *params are connected in the graph.
        with torch.enable_grad():
            params = [p.clone().requires_grad_() for p in params]
            with A.uselinopparams(*params):
                loss = A.mm(evecs)  # (*BAM, na, neig)

        # if degenerate, check the conditions for finite derivative
        if is_debug_enabled() and isdegenerate:
            xtg = torch.matmul(evecs.transpose(-2, -1), grad_evecs)
            req1 = idx_degen * (xtg - xtg.transpose(-2, -1))
            reqtol = xtg.abs().max() * grad_evecs.shape[-2] * torch.finfo(
                grad_evecs.dtype).eps

            if not torch.all(torch.abs(req1) <= reqtol):
                # if the requirements are not satisfied, raises a warning
                msg = (
                    "Degeneracy appears but the loss function seem to depend "
                    "strongly on the eigenvector. The gradient might be incorrect.\n"
                )
                msg += "Eigenvalues:\n%s\n" % str(evals)
                msg += "Degenerate map:\n%s\n" % str(idx_degen)
                msg += "Requirements (should be all 0s):\n%s" % str(req1)
                warnings.warn(MathWarning(msg))

        # calculate the contributions from the eigenvalues
        gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

        # calculate the contributions from the eigenvectors
        with M.uselinopparams(
                *mparams) if M is not None else dummy_context_manager():
            # orthogonalize the grad_evecs with evecs
            B = _ortho(grad_evecs, evecs, D=idx_degen, M=M, mright=False)

            with A.uselinopparams(*params):
                gevecs = solve(A,
                               -B,
                               evals,
                               M,
                               bck_options=ctx.bck_config,
                               **ctx.bck_config)  # (*BAM, na, neig)

            # orthogonalize gevecs w.r.t. evecs
            gevecsA = _ortho(gevecs, evecs, D=None, M=M, mright=True)

        # accumulate the gradient contributions
        gaccumA = gevalsA + gevecsA
        grad_params = torch.autograd.grad(
            outputs=(loss, ),
            inputs=params,
            grad_outputs=(gaccumA, ),
            create_graph=torch.is_grad_enabled(),
        )

        grad_mparams = []
        if ctx.M is not None:
            with torch.enable_grad():
                mparams = [p.clone().requires_grad_() for p in mparams]
                with M.uselinopparams(*mparams):
                    mloss = M.mm(evecs)  # (*BAM, na, neig)
            gevalsM = -gevalsA * evals.unsqueeze(-2)
            gevecsM = -gevecsA * evals.unsqueeze(-2)

            # the contribution from the parallel elements
            gevecsM_par = (-0.5 *
                           torch.einsum("...ae,...ae->...e", grad_evecs, evecs)
                           ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

            gaccumM = gevalsM + gevecsM + gevecsM_par
            grad_mparams = torch.autograd.grad(
                outputs=(mloss, ),
                inputs=mparams,
                grad_outputs=(gaccumM, ),
                create_graph=torch.is_grad_enabled(),
            )

        return (None, None, None, None, None, None, None, *grad_params,
                *grad_mparams)
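
A small sketch of the kind of degeneracy mask computed by `_check_degen` above (this is not xitorch's implementation, only an illustration of the atol/rtol criterion):

import torch

def degen_mask(evals, atol, rtol):
    # mark eigenvalue pairs whose difference is below the absolute/relative tolerance
    diff = (evals.unsqueeze(-1) - evals.unsqueeze(-2)).abs()
    tol = atol + rtol * evals.abs().unsqueeze(-1)
    return diff < tol

evals = torch.tensor([1.0, 1.0 + 1e-9, 2.0], dtype=torch.float64)
print(degen_mask(evals, atol=1e-6, rtol=1e-6))
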
Example #19
def symeig(A: LinearOperator,
           neig: Optional[int] = None,
           mode: str = "lowest",
           M: Optional[LinearOperator] = None,
           bck_options: Mapping[str, Any] = {},
           method: Union[str, Callable, None] = None,
           **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Obtain ``neig`` lowest eigenvalues and eigenvectors of a linear operator,

    .. math::

        \mathbf{AX = MXE}

    where :math:`\mathbf{A}, \mathbf{M}` are linear operators,
    :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and
    :math:`\mathbf{X}` is a matrix containing the eigenvectors.
    This function can handle derivatives for degenerate cases by setting non-zero
    ``degen_atol`` and ``degen_rtol`` in the backward option using the expressions
    in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator object on which the eigenpairs are constructed.
        It must be a Hermitian linear operator with shape ``(*BA, q, q)``
    neig: int or None
        The number of eigenpairs to be retrieved. If ``None``, all eigenpairs are
        retrieved
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``neig`` eigenpairs.
        If ``"uppest"``, it will take the uppermost ``neig``.
    M: xitorch.LinearOperator
        The transformation on the right hand side. If ``None``, then ``M=I``.
        If specified, it must be a Hermitian operator with shape ``(*BM, q, q)``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the backpropagation
        calculation with some additional arguments for computing the backward
        derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference between
          two eigenvalues to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference between
          two eigenvalues to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        the specific values.

    method: str or callable or None
        Method for the eigendecomposition. If ``None``, it will choose
        ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (eigenvalues, eigenvectors)
        It will return eigenvalues and eigenvectors with shapes respectively
        ``(*BAM, neig)`` and ``(*BAM, na, neig)``, where ``*BAM`` is the
        broadcasted shape of ``*BA`` and ``*BM``.

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric matrix for degenerate cases".
           arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using symeig with grad enabled")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using symeig with grad enabled")
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"
    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
           (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exacteig"
        else:
            # TODO: implement robust LOBPCG and put it here
            method = "exacteig"
    if neig is None:
        neig = A.shape[-1]

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        fwd_options["method"] = method
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M, fwd_options,
                                     bck_options, na, *params, *mparams)
Example #20
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                bck_options: Mapping[str, Any] = {},
                method: Union[str, Callable, None] = None,
                **fwd_options) -> torch.Tensor:
    r"""
    Solving the equilibrium equation of a given function,

    .. math::

        \mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produce output of the same shape of :math:`\mathbf{y}`, and :math:`\theta`
    is other parameters required in the function.
    The output of this block is the :math:`\mathbf{y}`
    that satisfies :math:`\mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)`.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        Sequence of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method : str or callable or None
        Rootfinder method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\theta)`
        with shape ``(*ny)``

    Example
    -------
    .. testsetup:: equil1

        import torch
        from xitorch.optimize import equilibrium

    .. doctest:: equil1

        >>> def func1(y, A):  # example function
        ...     return torch.tanh(A @ y + 0.1) + y / 2.0
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2,1))  # zeros as the initial guess
        >>> yequil = equilibrium(func1, y0, params=(A,))
        >>> print(yequil)
        tensor([[ 0.2313],
                [-0.5957]], grad_fn=<_RootFinderBackward>)

    Note
    ----
    * This is a direct implementation of finding the root of
      :math:`\mathbf{g}(\mathbf{y}, \theta) = \mathbf{y} - \mathbf{f}(\mathbf{y}, \theta)`
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #21
File: symeig.py Project: mfkasim1/xitorch
def symeig(A: LinearOperator,
           neig: Optional[int] = None,
           mode: str = "lowest",
           M: Optional[LinearOperator] = None,
           bck_options: Mapping[str, Any] = {},
           method: Union[str, Callable, None] = None,
           **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Obtain ``neig`` lowest eigenvalues and eigenvectors of a linear operator,

    .. math::

        \mathbf{AX = MXE}

    where :math:`\mathbf{A}, \mathbf{M}` are linear operators,
    :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and
    :math:`\mathbf{X}` is a matrix containing the eigenvectors.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator object on which the eigenpairs are constructed.
        It must be a Hermitian linear operator with shape ``(*BA, q, q)``
    neig: int or None
        The number of eigenpairs to be retrieved. If ``None``, all eigenpairs are
        retrieved
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``neig`` eigenpairs.
        If ``"uppest"``, it will take the uppermost ``neig``.
    M: xitorch.LinearOperator
        The transformation on the right hand side. If ``None``, then ``M=I``.
        If specified, it must be a Hermitian operator with shape ``(*BM, q, q)``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the backpropagation
        calculation.
    method: str or callable or None
        Method for the eigendecomposition. If ``None``, it will choose
        ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (eigenvalues, eigenvectors)
        It will return eigenvalues and eigenvectors with shapes respectively
        ``(*BAM, neig)`` and ``(*BAM, na, neig)``, where ``*BAM`` is the
        broadcasted shape of ``*BA`` and ``*BM``.
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using symeig with grad enabled")
    if M is not None:
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be Hermitian")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using symeig with grad enabled")
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"
    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
           (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exacteig"
        else:
            # TODO: implement robust LOBPCG and put it here
            method = "exacteig"
    if neig is None:
        neig = A.shape[-1]

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        fwd_options["method"] = method
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M, fwd_options,
                                     bck_options, na, *params, *mparams)
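
A minimal usage sketch for ``symeig`` (hedged: the import path ``xitorch.linalg.symeig`` and the dense-matrix wrapper ``xitorch.LinearOperator.m`` are assumptions about the surrounding package layout):

# Hypothetical usage sketch; import paths are assumptions.
import torch
import xitorch
from xitorch.linalg import symeig

mat = torch.randn(4, 4, dtype=torch.float64)
mat = (mat + mat.T) * 0.5                          # symmetrize so A is Hermitian
A = xitorch.LinearOperator.m(mat, is_hermitian=True)

evals, evecs = symeig(A, neig=2, mode="lowest")    # dense operator: uses "exacteig"
print(evals.shape)                                 # (2,)   == (*BAM, neig)
print(evecs.shape)                                 # (4, 2) == (*BAM, na, neig)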
Example #22
def solve_ivp(fcn: Union[Callable[..., torch.Tensor],
                         Callable[..., Sequence[torch.Tensor]]],
              ts: torch.Tensor,
              y0: torch.Tensor,
              params: Sequence[Any] = [],
              bck_options: Mapping[str, Any] = {},
              method: Union[str, Callable, None] = None,
              **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    r"""
    Solve the initial value problem (IVP), also commonly known as an ordinary
    differential equation (ODE): given the initial value :math:`\mathbf{y_0}`,
    it solves

    .. math::

        \mathbf{y}(t) = \mathbf{y_0} + \int_{t_0}^{t} \mathbf{f}(t', \mathbf{y}, \theta)\ \mathrm{d}t'

    Arguments
    ---------
    fcn: callable
        The function that represents dy/dt. It takes as input a single time
        ``t`` and a tensor ``y`` with shape ``(*ny)`` and produces
        :math:`\mathrm{d}\mathbf{y}/\mathrm{d}t` with shape ``(*ny)``.
        The output of the function must be a tensor with shape ``(*ny)`` or
        a list of tensors.
    ts: torch.tensor
        The time points where the value of ``y`` will be returned.
        It must be monotonically increasing or decreasing.
        It is a tensor with shape ``(nt,)``.
    y0: torch.tensor
        The initial value of ``y``, i.e. ``y(t[0]) == y0``.
        It is a tensor with shape ``(*ny)`` or a list of tensors.
    params: list
        Sequence of other parameters required in the function.
    bck_options: dict
        Options for the backward solve_ivp method. If not specified, it will
        take the same options as fwd_options.
    method: str or callable or None
        Initial value problem solver. If ``None``, it will choose ``"rk45"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    torch.tensor or a list of tensors
        The values of ``y`` for each time step in ``ts``.
        It is a tensor with shape ``(nt,*ny)`` or a list of tensors
    """
    if is_debug_enabled():
        assert_fcn_params(fcn, (ts[0], y0, *params))
    assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor")

    if method is None:  # set the default method
        method = "rk45"
    fwd_options["method"] = method

    # run once to see if the output is a tuple/list or a single tensor
    is_y0_list = isinstance(y0, list) or isinstance(y0, tuple)
    dydt = fcn(ts[0], y0, *params)
    is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple)
    if is_y0_list != is_dydt_list:
        raise RuntimeError(
            "y0 and the output of fcn must both be lists/tuples or both be tensors")

    pfcn = get_pure_function(fcn)
    if is_y0_list:
        nt = len(ts)
        roller = TensorPacker(y0)

        @make_sibling(pfcn)
        def pfcn2(t, ytensor, *params):
            ylist = roller.pack(ytensor)
            res_list = pfcn(t, ylist, *params)
            res = roller.flatten(res_list)
            return res

        y0 = roller.flatten(y0)
        res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params),
                              y0, *params, *pfcn.objparams())
        return roller.pack(res)
    else:
        return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params),
                               y0, *params, *pfcn.objparams())
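
A minimal usage sketch for ``solve_ivp`` (hedged: the import path ``xitorch.integrate.solve_ivp`` is an assumption), solving the simple decay ODE dy/dt = -a*y:

# Hypothetical usage sketch; the import path is an assumption.
import torch
from xitorch.integrate import solve_ivp

def dydt(t, y, a):
    return -a * y                            # simple exponential decay

a = torch.tensor(1.0, requires_grad=True)
ts = torch.linspace(0.0, 2.0, 11)            # (nt,) monotonically increasing
y0 = torch.ones(3)                           # (*ny) initial value
yt = solve_ivp(dydt, ts, y0, params=(a,))    # default method "rk45"
print(yt.shape)                              # (11, 3), i.e. (nt, *ny)
# each yt[i] should be close to y0 * exp(-a * ts[i])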
Example #23
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None, posdef=False,
          bck_options: Mapping[str, Any] = {},
          method: Union[str, None] = None,
          **fwd_options):
    r"""
    Perform an iterative method to solve the equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve multiple batched inverse equations at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in the
        same space as ``B``.
        It should have the shape ``(*BA, na, na)``
    B: torch.tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.tensor or None
        If a tensor, it solves :math:`\mathbf{AX-MXE = B}`, where ``E`` is
        regarded as the diagonal of the matrix :math:`\mathbf{E}`.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored.
        If ``E`` is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If given as a LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    bck_options: dict
        Options of the iterative solver in the backward calculation.
    method: str or None
        The method of the solve. If ``None``, it will select ``"exactsolve"``.
    **fwd_options
        Method-specific options (see method section below)
    """
    assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape")
    assert_runtime(A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape")
        assert_runtime(M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn("M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        method = "exactsolve" # TODO: do a proper method selection based on the size

    if method == "exactsolve":
        return exactsolve(A, B, E, M)
    else:
        fwd_options["method"] = method
        # get the unique parameters of A
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return solve_torchfcn.apply(
            A, B, E, M, posdef,
            fwd_options, bck_options,
            na, *params, *mparams)
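
A minimal usage sketch for ``solve`` (hedged: the import paths ``xitorch.linalg.solve`` and ``xitorch.LinearOperator.m`` are assumptions), solving :math:`\mathbf{AX=B}` with the default ``"exactsolve"`` method:

# Hypothetical usage sketch; import paths are assumptions.
import torch
import xitorch
from xitorch.linalg import solve

mat = torch.randn(5, 5, dtype=torch.float64)
mat = mat @ mat.T + 3 * torch.eye(5, dtype=torch.float64)   # symmetric, well conditioned
A = xitorch.LinearOperator.m(mat, is_hermitian=True)
B = torch.randn(5, 2, dtype=torch.float64)                  # (na, ncols) right-hand side

X = solve(A, B)                          # E is None, so this solves AX = B
print(torch.allclose(A.mm(X), B))        # True (up to numerical tolerance)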