def custom_exactsolve(A, params, B, E=None, M=None, mparams=None, **options):
    """Run ``exactsolve(A, B, E, M)`` with the linear operators' parameters swapped in.

    ``A`` is used inside ``A.uselinopparams(*params)``. When ``M`` is given, it is used
    inside ``M.uselinopparams(*mparams)``; otherwise a ``dummy_context_manager()`` is
    entered in its place.

    Arguments
    ---------
    A: linear operator, shape ``(*BA, na, na)``
    params: sequence
        Values unpacked into ``A.uselinopparams``.
    B: tensor, shape ``(*BB, na, ncols)``
    E: tensor or None, shape ``(*BE, ncols)``
    M: linear operator or None, shape ``(*BM, na, na)``
    mparams: sequence or None
        Values unpacked into ``M.uselinopparams``. ``None`` (the default) means no
        parameters.
    **options
        Accepted for signature compatibility; not used.

    Returns
    -------
    Whatever ``exactsolve(A, B, E, M)`` returns.
    """
    # NOTE(review): this file defines custom_exactsolve again further down; at import
    # time the later definition replaces this one.

    # Avoid a shared mutable default; an omitted mparams behaves exactly like [].
    if mparams is None:
        mparams = []

    # The context-manager expressions stay inline in the `with` so that they are
    # evaluated in the same order as before (A first, then M).
    with A.uselinopparams(*params), \
            (M.uselinopparams(*mparams) if M is not None else dummy_context_manager()):
        return exactsolve(A, B, E, M)
def custom_exactsolve(A, B, E=None, M=None, **options):
    """Pass the operands straight through to ``exactsolve``.

    Expected shapes:
        A: ``(*BA, na, na)``
        B: ``(*BB, na, ncols)``
        E: ``(*BE, ncols)``
        M: ``(*BM, na, na)``

    Extra keyword ``options`` are accepted but not used.
    """
    solution = exactsolve(A, B, E, M)
    return solution
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None,
          M: Optional[LinearOperator] = None,
          bck_options: Optional[Mapping[str, Any]] = None,
          method: Union[str, Callable, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve multiple batched equations at the same time
    by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}` with shape
    ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` is not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in the
        same space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored.
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict.
    method: str or callable or None
        The method of linear equation solver.
        If ``None``, it will choose ``"exactsolve"`` when ``A`` (and ``M``, if
        given) is a ``MatrixLinearOperator``; otherwise it chooses ``"cg"`` or
        ``"bicgstab"`` based on the matrices' symmetry.
        `Note`: the default method will change frequently. For future
        compatibility, please specify a method.
    **fwd_options
        Method-specific options (see method below)

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`.
    """
    # NOTE(review): this file defines `solve` again further down; at import time the
    # later definition replaces this one.

    # Fresh dict per call: a literal `{}` default would be one object shared by every
    # call, leaking state if anything downstream mutates it.
    if bck_options is None:
        bck_options = {}

    # ---- input validation (assert_runtime presumably raises on failure) ----
    assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape")
    assert_runtime(
        A.shape[-1] == B.shape[-2],
        "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    # Gradients w.r.t. A's parameters require A to report its parameter names.
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using solve with grad enabled")
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape")
        assert_runtime(
            M.shape[-1] == A.shape[-1],
            "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix")
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using solve with grad enabled")
    if E is not None:
        assert_runtime(
            E.shape[-1] == B.shape[-1],
            "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        # NOTE(review): M is still forwarded to the solvers below; presumably they
        # disregard it when E is None — confirm.
        warnings.warn(
            "M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    # ---- automatic method selection ----
    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
                (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exactsolve"
        else:
            is_hermit = A.is_hermitian and (M is None or M.is_hermitian)
            method = "cg" if is_hermit else "bicgstab"

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # get the unique parameters of A (and M) so autograd can track them
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    # Count of A's parameters; presumably lets solve_torchfcn split the flat
    # *params/*mparams tail back into its two groups.
    na = len(params)
    return solve_torchfcn.apply(A, B, E, M, method, fwd_options, bck_options,
                                na, *params, *mparams)
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None,
          M: Union[LinearOperator, None] = None, posdef=False,
          bck_options: Union[Mapping[str, Any], None] = None,
          method: Union[str, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Solve the linear equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve multiple batched equations at the same time
    by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}` with shape
    ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` is not necessarily identical for each column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in the
        same space as ``B``.
        It should have the shape of ``(*BA, na, na)``
    B: torch.Tensor
        The tensor on the right hand side with shape ``(*BB, na, ncols)``
    E: torch.Tensor or None
        If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`.
        It will be regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored.
        If it is a tensor, it should have shape of ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        then this argument is ignored.
        If E is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``.
    posdef: bool
        Forwarded unchanged to the iterative solver; it is not used by
        ``exactsolve``. Presumably it flags that the system is positive
        definite (NOTE(review): confirm against solve_torchfcn).
    bck_options: dict or None
        Options of the iterative solver in the backward calculation.
        ``None`` (the default) is treated as an empty dict.
    method: str or None
        Indicating the method of solve. If None, it will select ``exactsolve``.
    **fwd_options
        Method-specific options (see method below)

    Returns
    -------
    torch.Tensor
        The solution tensor :math:`\mathbf{X}`.
    """
    # NOTE(review): this file defines `solve` more than once; at import time the last
    # definition is the one that is bound.

    # Fresh dict per call instead of a shared mutable `{}` default.
    if bck_options is None:
        bck_options = {}

    # ---- input validation (assert_runtime presumably raises on failure) ----
    assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape")
    assert_runtime(A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape))
    if M is not None:
        assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape")
        assert_runtime(M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix")
    if E is not None:
        assert_runtime(E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape))
    if E is None and M is not None:
        warnings.warn("M is supplied but will be ignored because E is not supplied")

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method is None:
        method = "exactsolve"  # TODO: do a proper method selection based on the size

    if method == "exactsolve":
        return exactsolve(A, B, E, M)

    # The iterative path receives the method name through the forward options;
    # fwd_options is this call's own **kwargs dict, so mutating it is local.
    fwd_options["method"] = method
    # get the unique parameters of A (and M)
    params = A.getlinopparams()
    mparams = M.getlinopparams() if M is not None else []
    # Count of A's parameters; presumably used to split *params from *mparams.
    na = len(params)
    return solve_torchfcn.apply(
        A, B, E, M, posdef, fwd_options, bck_options, na, *params, *mparams)