示例#1
0
def custom_exactsolve(A, params, B, E=None,
                      M=None, mparams=(), **options):
    """
    Solve the linear system exactly while ``A`` (and ``M``) use explicit parameters.

    ``A.uselinopparams(*params)`` is entered first. Then, if ``M`` is given,
    ``M.uselinopparams(*mparams)`` is entered; otherwise a no-op
    ``dummy_context_manager()`` is used. Inside those contexts the call is
    forwarded to ``exactsolve(A, B, E, M)``.

    Arguments
    ---------
    A: LinearOperator
        Operator of shape ``(*BA, na, na)``.
    params: sequence
        Values unpacked into ``A.uselinopparams``.
    B: torch.Tensor
        Right-hand side of shape ``(*BB, na, ncols)``.
    E: torch.Tensor or None
        Diagonal of shape ``(*BE, ncols)``.
    M: LinearOperator or None
        Operator of shape ``(*BM, na, na)``.
    mparams: sequence
        Values unpacked into ``M.uselinopparams``; unused when ``M`` is None.
        Defaults to an empty tuple (immutable, so it can never be shared
        state between calls).
    **options
        Accepted for signature compatibility; not forwarded or used.

    Returns
    -------
    The value returned by ``exactsolve(A, B, E, M)``.
    """
    # The second item is a conditional expression: only M's parameters are
    # swapped when M exists, otherwise a do-nothing context is entered.
    with A.uselinopparams(*params), \
            (M.uselinopparams(*mparams) if M is not None else dummy_context_manager()):
        return exactsolve(A, B, E, M)
示例#2
0
def custom_exactsolve(A, B, E=None, M=None, **options):
    """
    Thin adapter that forwards the solve to ``exactsolve``.

    Expected shapes: ``A`` is ``(*BA, na, na)``, ``B`` is ``(*BB, na, ncols)``,
    ``E`` is ``(*BE, ncols)`` and ``M`` is ``(*BM, na, na)``.

    ``options`` is accepted so this function can be called with the same
    keyword arguments as other solver callables; it is deliberately ignored.
    """
    solution = exactsolve(A, B, E, M)
    return solution
示例#3
0
def solve(A: LinearOperator,
          B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Optional[LinearOperator] = None,
          bck_options: Optional[Mapping[str, Any]] = None,
          method: Union[str, Callable, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve batched multiple inverse equation at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produce the vectors in the same
        space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored (a warning is emitted).
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict; a fresh dict is
        created on every call so no state is shared between calls.
    method: str or callable or None
        The method of linear equation solver. If ``None``, it chooses
        ``"exactsolve"`` when ``A`` (and ``M``, if given) are
        ``MatrixLinearOperator``; otherwise ``"cg"`` if the operators are
        Hermitian and ``"bicgstab"`` if not.
        `Note`: default method will be changed quite frequently, so if you want
        future compatibility, please specify a method.
    **fwd_options
        Method-specific options forwarded to the forward solver
        (not used by ``"exactsolve"``).

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`.

    Raises
    ------
    Input validation is done with ``assert_runtime``; a failed check raises
    whatever that helper raises. The checks cover non-square ``A``/``M``,
    shape mismatches, a non-Hermitian ``M``, and a missing
    ``_getparamnames`` implementation when grad is enabled.
    """
    assert_runtime(A.shape[-1] == A.shape[-2],
                   "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    # Gradients w.r.t. A's parameters require the parameter names.
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using solve with grad enabled")
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2],
                       "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" %
            (A.shape, M.shape))
        assert_runtime(M.is_hermitian,
                       "The linear operator M must be a Hermitian matrix")
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using solve with grad enabled")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" %
            (E.shape, B.shape))
    if E is None and M is not None:
        # stacklevel=2 attributes the warning to the caller of solve().
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied",
            stacklevel=2)

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
           (M is None or isinstance(M, MatrixLinearOperator)):
            # Dense operators: solve directly.
            method = "exactsolve"
        else:
            is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
            method = "cg" if is_hermit else "bicgstab"

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # Fresh dict per call instead of a shared mutable default.
    if bck_options is None:
        bck_options = {}
    # get the unique parameters of A (and M); na tells the autograd function
    # where A's parameters end and M's begin in the flattened argument list.
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return solve_torchfcn.apply(A, B, E, M, method, fwd_options,
                                bck_options, na, *params, *mparams)
示例#4
0
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None, posdef=False,
          bck_options: Union[Mapping[str, Any], None] = None,
          method: Union[str, None] = None,
          **fwd_options):
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve batched multiple inverse equation at the
    same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}`
    with shape ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produce the vectors in the same
        space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored (a warning is emitted).
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    posdef: bool
        Forwarded to ``solve_torchfcn`` when ``method`` is not ``"exactsolve"``;
        unused by the exact path. Presumably declares ``A`` positive definite
        -- NOTE(review): confirm against solve_torchfcn.
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict; a fresh dict is
        created on every call so no state is shared between calls.
    method: str or None
        Indicating the method of solve. If None, it will select ``exactsolve``.
    **fwd_options
        Method-specific options for the forward solver. For non-exact methods
        the chosen ``method`` is added to this dict under the key ``"method"``.

    Returns
    -------
    The solution ``X`` as returned by ``exactsolve`` or ``solve_torchfcn.apply``.
    """
    assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape")
    assert_runtime(A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape")
        assert_runtime(M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        # stacklevel=2 attributes the warning to the caller of solve().
        warnings.warn("M is supplied but will be ignored because E is not supplied",
                      stacklevel=2)

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        method = "exactsolve"  # TODO: do a proper method selection based on the size

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # The method name travels to the iterative solver inside fwd_options.
    # fwd_options is the fresh **kwargs dict of this call, so mutating it is safe.
    fwd_options["method"] = method
    # Fresh dict per call instead of a shared mutable default.
    if bck_options is None:
        bck_options = {}
    # get the unique parameters of A (and M); na marks where A's parameters
    # end and M's begin in the flattened argument list.
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    na = len(params)
    return solve_torchfcn.apply(
        A, B, E, M, posdef,
        fwd_options, bck_options,
        na, *params, *mparams)