def check(self, warn: Union[bool, None] = None) -> None: """ Perform checks to make sure the linear operator behaves as a proper linear operator. Arguments --------- * warn: bool or None If True, raises a warning reminding the user that the check might slow down the program, so the check can be turned off when not in a debugging mode. If None, it raises the warning only when not running in debug mode and stays silent in debug mode. Exceptions ---------- * RuntimeError Raised if an error occurs when performing the linear operations of the object (e.g. calling .mv(), .mm(), etc.) * AssertionError Raised if the linear operations do not behave as proper linear operations (e.g. not scaling linearly). """ if warn is None: warn = not is_debug_enabled() if warn: msg = "The linear operator check is performed. This might slow down your program." warnings.warn(msg, stacklevel=2) checklinop(self)
def check(self, warn: Optional[bool] = None) -> None: """ Perform checks to make sure the ``LinearOperator`` behaves as a proper linear operator. Arguments --------- warn: bool or None If ``True``, raises a warning reminding the user that the check might slow down the program, so the check can be turned off when not in a debugging mode. If ``None``, it raises the warning only when not running in debug mode and stays silent in debug mode. Raises ------ RuntimeError Raised if an error occurs when performing the linear operations of the object (e.g. calling ``.mv()``, ``.mm()``, etc.) AssertionError Raised if the linear operations do not behave as proper linear operations (e.g. not scaling linearly). """ if warn is None: warn = not is_debug_enabled() if warn: msg = "The linear operator check is performed. This might slow down your program." warnings.warn(msg, stacklevel=2) checklinop(self)
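# A minimal sketch (names here are illustrative, not part of the library) of how
# ``check()`` is typically used: subclass ``xitorch.LinearOperator``, implement
# ``_mv`` (and ``_getparamnames``), then call ``.check()`` while developing to
# validate that the operator really behaves linearly.
import torch
import xitorch

class DenseLinOp(xitorch.LinearOperator):
    # hypothetical operator that simply wraps a dense matrix
    def __init__(self, mat):
        super().__init__(shape=mat.shape, dtype=mat.dtype, device=mat.device)
        self.mat = mat

    def _mv(self, x):
        # matrix-vector product defining the operator
        return torch.matmul(self.mat, x.unsqueeze(-1)).squeeze(-1)

    def _getparamnames(self, prefix=""):
        return [prefix + "mat"]

linop = DenseLinOp(torch.randn(3, 3, dtype=torch.float64))
linop.check()  # warns (outside debug mode) and validates the linear behaviour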
def minimize(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, None] = None, **fwd_options) -> torch.Tensor: """ Solve the unbounded minimization problem: .. math:: \mathbf{y^*} = \\arg\min_\mathbf{y} f(\mathbf{y}, \\theta) to find the best :math:`\mathbf{y}` that minimizes the output of the function :math:`f`. Arguments --------- fcn: callable The function to be optimized with output tensor with 1 element. y0: torch.tensor Initial guess of the solution with shape ``(*ny)`` params: list List of any other parameters to be put in ``fcn`` bck_options: dict Method-specific options for the backward solve. method: str or None Minimization method. **fwd_options Method-specific options (see method section) Returns ------- torch.tensor The solution of the minimization with shape ``(*ny)`` """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) fwd_options["method"] = _get_minimizer_default_method(method) # the rootfinder algorithms are designed to move to the opposite direction # of the output of the function, so the output of this function is just # the grad of z w.r.t. y @make_sibling(pfunc) def new_fcn(y, *params): with torch.enable_grad(): y1 = y.clone().requires_grad_() z = pfunc(y1, *params) grady, = torch.autograd.grad(z, (y1, ), retain_graph=True, create_graph=torch.is_grad_enabled()) return grady return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.objparams())
def equilibrium(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, None] = None, **fwd_options): """ Solving the equilibrium equation of a given function, .. math:: \mathbf{y} = \mathbf{f}(\mathbf{y}, \\theta) where :math:`\mathbf{f}` is a function that can be non-linear and produce output of the same shape of :math:`\mathbf{y}`, and :math:`\\theta` is other parameters required in the function. The output of this block is :math:`\mathbf{y}` that produces the same :math:`\mathbf{y}` as the output. Arguments --------- fcn : callable The function :math:`\mathbf{f}` with output tensor ``(*ny)`` y0 : torch.tensor Initial guess of the solution with shape ``(*ny)`` params : list List of any other parameters to be put in ``fcn`` bck_options : dict Method-specific options for the backward solve method : str or None Rootfinder method. **fwd_options Method-specific options (see method section) Returns ------- torch.tensor The solution which satisfies :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\\theta)` with shape ``(*ny)`` """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) @make_sibling(pfunc) def new_fcn(y, *params): return y - pfunc(y, *params) fwd_options["method"] = _get_rootfinder_default_method(method) return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.getobjparams())
def minimize(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}) -> torch.Tensor: """ Solve the minimization problem: z = (argmin_y) fcn(y, *params) to find the best `y` that minimizes the output of the function `fcn`. The output of `fcn` must be a single element tensor. Arguments --------- * fcn: callable with output tensor (numel=1) The function * y0: torch.tensor with shape (*ny) Initial guess of the solution * params: list List of any other parameters to be put in fcn * fwd_options: dict Options for the minimizer method * bck_options: dict Options for the backward solve method Returns ------- * y: torch.tensor with shape (*ny) The solution of the minimization """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) # the rootfinder algorithms are designed to move to the opposite direction # of the output of the function, so the output of this function is just # the grad of z w.r.t. y @make_sibling(pfunc) def new_fcn(y, *params): with torch.enable_grad(): y1 = y.clone().requires_grad_() z = pfunc(y1, *params) grady, = torch.autograd.grad(z, (y1, ), retain_graph=True, create_graph=torch.is_grad_enabled()) return grady return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.objparams())
def equilibrium(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}): """ Solving the equilibrium equation of a given function, y = fcn(y, *params) where `fcn` is a function that can be non-linear and produces an output with the same shape as `y`. The output of this block is the `y` that satisfies y = fcn(y, *params). Arguments --------- * fcn: callable with output tensor (*ny) The function * y0: torch.tensor with shape (*ny) Initial guess of the solution * params: list List of any other parameters to be put in fcn * fwd_options: dict Options for the rootfinder method * bck_options: dict Options for the backward solve method Returns ------- * yout: torch.tensor with shape (*ny) The solution which satisfies yout = fcn(yout) Note ---- * To obtain the correct gradient and higher order gradients, the fcn must be: - a torch.nn.Module with fcn.parameters() listing the tensors that determine the output of fcn. - a method in xt.EditableModule object with no out-of-scope parameters. - a function with no out-of-scope parameters. """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) @make_sibling(pfunc) def new_fcn(y, *params): return y - pfunc(y, *params) return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.getobjparams())
def backward(ctx, grad_eival, grad_eivec): in_debug_mode = is_debug_enabled() eival, eivec = ctx.saved_tensors min_threshold = torch.finfo(eival.dtype).eps ** 0.6 eivect = eivec.transpose(-2, -1) # remove the degenerate part # see https://arxiv.org/pdf/2011.04366.pdf if grad_eivec is not None: # take the contribution from the eivec F = eival.unsqueeze(-2) - eival.unsqueeze(-1) idx = torch.abs(F) < min_threshold F[idx] = float("inf") # if in debug mode, check the degeneracy requirements if in_debug_mode: degenerate = torch.any(idx) xtg = eivect @ grad_eivec diff_xtg = (xtg - xtg.transpose(-2, -1))[idx] reqsat = torch.allclose(diff_xtg, torch.zeros_like(diff_xtg)) # if the requirement is not satisfied, mathematically the derivative # should be `nan`, but here we just raise a warning if not reqsat: msg = ("Degeneracy appears but the loss function seem to depend " "strongly on the eigenvector. The gradient might be incorrect.\n") msg += "Eigenvalues:\n%s\n" % str(eival) msg += "Degenerate map:\n%s\n" % str(idx) msg += "Requirements (should be all 0s):\n%s" % str(diff_xtg) warnings.warn(MathWarning(msg)) F = F.pow(-1) F = F * torch.matmul(eivect, grad_eivec) result = torch.matmul(eivec, torch.matmul(F, eivect)) else: result = torch.zeros_like(eivec) # calculate the contribution from the eival if grad_eival is not None: result += torch.matmul(eivec, grad_eival.unsqueeze(-1) * eivect) # symmetrize to reduce numerical instability result = (result + result.transpose(-2, -1)) * 0.5 return result
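# A small self-contained sketch (plain torch, independent of the method above)
# of the degeneracy handling in this backward pass: eigenvalue pairs closer
# than ``min_threshold`` are masked, and setting those entries of F to infinity
# makes their 1/(lambda_j - lambda_i) factors vanish after ``F.pow(-1)``.
import torch

eival = torch.tensor([1.0, 1.0 + 1e-10, 3.0], dtype=torch.float64)
min_threshold = torch.finfo(eival.dtype).eps ** 0.6   # ~4e-10 for float64

F = eival.unsqueeze(-2) - eival.unsqueeze(-1)          # F[i, j] = eival[j] - eival[i]
idx = torch.abs(F) < min_threshold                     # degenerate (and diagonal) pairs
F[idx] = float("inf")
F = F.pow(-1)                                          # masked entries become 0
# idx marks the 2x2 degenerate block of the first two eigenvalues plus the diagonal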
def _mcquad(ffcn, log_pfcn, x0, xsamples, wsamples, fparams, pparams, fwd_options, bck_options): # this is mcquad with an additional xsamples argument, to prevent xsamples being set by users if is_debug_enabled(): assert_fcn_params(ffcn, (x0, *fparams)) assert_fcn_params(log_pfcn, (x0, *pparams)) # check if ffcn produces a list / tuple out = ffcn(x0, *fparams) is_tuple_out = isinstance(out, list) or isinstance(out, tuple) # get the pure functions pure_ffcn = get_pure_function(ffcn) pure_logpfcn = get_pure_function(log_pfcn) nfparams = len(fparams) npparams = len(pparams) fobjparams = pure_ffcn.objparams() pobjparams = pure_logpfcn.objparams() nf_objparams = len(fobjparams) if is_tuple_out: packer = TensorPacker(out) @make_sibling(pure_ffcn) def pure_ffcn2(x, *fparams): y = pure_ffcn(x, *fparams) return packer.flatten(y) res = _MCQuad.apply(pure_ffcn2, pure_logpfcn, x0, None, None, fwd_options, bck_options, nfparams, nf_objparams, npparams, *fparams, *fobjparams, *pparams, *pobjparams) return packer.pack(res) else: return _MCQuad.apply(pure_ffcn, pure_logpfcn, x0, None, None, fwd_options, bck_options, nfparams, nf_objparams, npparams, *fparams, *fobjparams, *pparams, *pobjparams)
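# A hypothetical, simplified stand-in for the TensorPacker round trip used
# above: tuple outputs of ``ffcn`` are flattened into a single tensor before
# entering the autograd Function and unpacked back into the original structure
# afterwards. (``TinyPacker`` is illustrative, not the actual implementation.)
import torch

class TinyPacker:
    def __init__(self, template):
        self.shapes = [t.shape for t in template]
        self.numels = [t.numel() for t in template]

    def flatten(self, tensors):
        return torch.cat([t.reshape(-1) for t in tensors], dim=-1)

    def pack(self, flat):
        outs, offset = [], 0
        for shape, n in zip(self.shapes, self.numels):
            outs.append(flat[offset:offset + n].reshape(shape))
            offset += n
        return tuple(outs)

template = (torch.zeros(2, 3), torch.zeros(4))
packer = TinyPacker(template)
flat = packer.flatten((torch.ones(2, 3), torch.arange(4.0)))   # shape (10,)
restored = packer.pack(flat)                                   # a (2, 3) and a (4,) tensor again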
def svd(A: LinearOperator, k: Optional[int] = None, mode: str = "uppest", bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r""" Perform the singular value decomposition (SVD): .. math:: \mathbf{A} = \mathbf{U\Sigma V}^H where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrix and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative numbers. Arguments --------- A: xitorch.LinearOperator The linear operator to be decomposed. It has a shape of ``(*BA, m, n)`` where ``(*BA)`` is the batched dimension of ``A``. k: int or None The number of decomposition obtained. If ``None``, it will be ``min(*A.shape[-2:])`` mode: str ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``, it will take the lowest ``k`` decomposition. If ``"uppest"``, it will take the uppermost ``k``. bck_options: dict Method-specific options for :func:`solve` which used in backpropagation calculation. method: str or callable or None Method for the svd (same options for :func:`symeig`). If ``None``, it will choose ``"exacteig"``. **fwd_options Method-specific options (see method section below). Returns ------- tuple of tensors (u, s, vh) It will return ``u, s, vh`` with shapes respectively ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``. Note ---- It is a naive implementation of symmetric eigendecomposition of ``A.H @ A`` or ``A @ A.H`` (depending which one is cheaper) Warnings -------- * If ``s`` contains very small numbers or degenerate values, the calculation and its gradient might be inaccurate. * The second derivative through U or V might be unstable. Extra care must be taken. """ # A: (*BA, m, n) # adapted from scipy.sparse.linalg.svds if is_debug_enabled(): A.check() BA = A.shape[:-2] m = A.shape[-2] n = A.shape[-1] if m < n: AAsym = A.matmul(A.H, is_hermitian=True) min_nm = m else: AAsym = A.H.matmul(A, is_hermitian=True) min_nm = n eivals, eivecs = symeig(AAsym, k, mode, bck_options=bck_options, method=method, **fwd_options) # (*BA, k) and (*BA, min(mn), k) # clamp the eigenvalues to a small positive values to avoid numerical # instability eivals = torch.clamp(eivals, min=0.0) s = torch.sqrt(eivals) # (*BA, k) sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2) # (*BA, 1, k) if m < n: u = eivecs # (*BA, m, k) v = A.rmm(u) / sdiv # (*BA, n, k) else: v = eivecs # (*BA, n, k) u = A.mm(v) / sdiv # (*BA, m, k) vh = v.transpose(-2, -1) return u, s, vh
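# A short usage sketch (assuming ``xitorch.LinearOperator.m`` to wrap a dense
# matrix): with the default ``k`` and ``method``, the returned factors
# reconstruct the original matrix.
import torch
import xitorch
from xitorch.linalg import svd

mat = torch.randn(3, 2, dtype=torch.float64)
A = xitorch.LinearOperator.m(mat)
u, s, vh = svd(A)                               # k defaults to min(m, n) = 2
recon = u @ torch.diag_embed(s) @ vh            # (3, 2)
assert torch.allclose(recon, mat, atol=1e-8)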
def quad(fcn: Union[Callable[..., torch.Tensor], Callable[..., Sequence[torch.Tensor]]], xl: Union[float, int, torch.Tensor], xu: Union[float, int, torch.Tensor], params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]: r""" Calculate the quadrature: .. math:: y = \int_{x_l}^{x_u} f(x, \theta)\ \mathrm{d}x Arguments --------- fcn: callable The function to be integrated. Its output must be a tensor with shape ``(*nout)`` or list of tensors. xl: float, int or 1-element torch.Tensor The lower bound of the integration. xu: float, int or 1-element torch.Tensor The upper bound of the integration. params: list Sequence of any other parameters for the function ``fcn``. bck_options: dict Options for the backward quadrature method. method: str or callable or None Quadrature method. If None, it will choose ``"leggauss"``. **fwd_options Method-specific options (see method section). Returns ------- torch.tensor or a list of tensors The quadrature results with shape ``(*nout)`` or list of tensors. """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (xl, *params)) if isinstance(xl, torch.Tensor): assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensors") if isinstance(xu, torch.Tensor): assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensors") if method is None: method = "leggauss" fwd_options["method"] = method out = fcn(xl, *params) if isinstance(out, torch.Tensor): dtype = out.dtype device = out.device is_tuple_out = False elif len(out) > 0: dtype = out[0].dtype device = out[0].device is_tuple_out = True else: raise RuntimeError("The output of the fcn must be non-empty") pfunc = get_pure_function(fcn) nparams = len(params) if is_tuple_out: packer = TensorPacker(out) @make_sibling(pfunc) def pfunc2(x, *params): y = fcn(x, *params) return packer.flatten(y) res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options, nparams, dtype, device, *params, *pfunc.objparams()) return packer.pack(res) else: return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options, nparams, dtype, device, *params, *pfunc.objparams())
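# A short usage sketch: integrating f(x, a) = a * x**2 from 0 to 1 with the
# default "leggauss" method. The value is close to a / 3 and is differentiable
# with respect to the parameter ``a``.
import torch
from xitorch.integrate import quad

def fcn(x, a):
    return a * x ** 2

a = torch.tensor(2.0, dtype=torch.float64).requires_grad_()
y = quad(fcn, 0.0, 1.0, params=(a,))            # approximately 2/3
grad_a, = torch.autograd.grad(y, a)             # approximately 1/3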
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None, M: Optional[LinearOperator] = None, bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> torch.Tensor: r""" Performing iterative method to solve the equation .. math:: \mathbf{AX=B} or .. math:: \mathbf{AX-MXE=B} where :math:`\mathbf{E}` is a diagonal matrix. This function can also solve batched multiple inverse equation at the same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}` with shape ``(...,na,ncols)``. The applied :math:`\mathbf{E}` are not necessarily identical for each column. Arguments --------- A: xitorch.LinearOperator A linear operator that takes an input ``X`` and produce the vectors in the same space as ``B``. It should have the shape of ``(*BA, na, na)`` B: torch.Tensor The tensor on the right hand side with shape ``(*BB, na, ncols)`` E: torch.Tensor or None If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`. It will be regarded as the diagonal of the matrix. Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored. If it is a tensor, it should have shape of ``(*BE, ncols)``. M: xitorch.LinearOperator or None The transformation on the ``E`` side. If ``E`` is ``None``, then this argument is ignored. If E is not ``None`` and ``M`` is ``None``, then ``M=I``. If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``. bck_options: dict Options of the iterative solver in the backward calculation. method: str or callable or None The method of linear equation solver. If ``None``, it will choose ``"cg"`` or ``"bicgstab"`` based on the matrices symmetry. `Note`: default method will be changed quite frequently, so if you want future compatibility, please specify a method. **fwd_options Method-specific options (see method below) Returns ------- torch.Tensor The tensor :math:`\mathbf{X}` that satisfies :math:`\mathbf{AX-MXE=B}`. 
""" assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape") assert_runtime( A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape)) assert_runtime( not torch.is_grad_enabled() or A.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator A must be " "implemented if using solve with grad enabled") if M is not None: assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape") assert_runtime( M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix") assert_runtime( not torch.is_grad_enabled() or M.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator M must be " "implemented if using solve with grad enabled") if E is not None: assert_runtime( E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape)) if E is None and M is not None: warnings.warn( "M is supplied but will be ignored because E is not supplied") # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if method is None: if isinstance(A, MatrixLinearOperator) and \ (M is None or isinstance(M, MatrixLinearOperator)): method = "exactsolve" else: is_hermit = A.is_hermitian and (M is None or M.is_hermitian) method = "cg" if is_hermit else "bicgstab" if method == "exactsolve": return exactsolve(A, B, E, M) else: # get the unique parameters of A params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return solve_torchfcn.apply(A, B, E, M, method, fwd_options, bck_options, na, *params, *mparams)
def solve_ivp(fcn: Callable[..., torch.Tensor], ts: torch.Tensor, y0: torch.Tensor, params: Sequence[Any] = [], fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}) -> torch.Tensor: """ Solve the initial value problem (IVP): given the initial value `y0`, it solves y(t) = y0 + int_t0^t f(t', y, *params) dt' Arguments --------- * fcn: callable with output a tensor with shape (*ny) or a list of tensors The function that represents dy/dt. The function takes an input of a single time `t` and `y` with shape (*ny) and produces dydt with shape (*ny). * ts: torch.tensor with shape (nt,) The time points where the value of `y` is returned. It must be monotonically increasing or decreasing. * y0: torch.tensor with shape (*ny) or a list of tensors The initial value of y, i.e. y(t[0]) == y0 * params: list List of other parameters required in the function. * fwd_options: dict Options for the forward solve_ivp method. * bck_options: dict Options for the backward solve_ivp method. Returns ------- * yt: torch.tensor with shape (nt,*ny) or a list of tensors The values of `y` for each time step in `ts`. """ if is_debug_enabled(): assert_fcn_params(fcn, (ts[0], y0, *params)) assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor") # run once to see if the outputs is a tuple or a single tensor is_y0_list = isinstance(y0, list) or isinstance(y0, tuple) dydt = fcn(ts[0], y0, *params) is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple) if is_y0_list != is_dydt_list: raise RuntimeError( "The y0 and output of fcn must both be tuple or a tensor") pfcn = get_pure_function(fcn) if is_y0_list: nt = len(ts) roller = TensorPacker(y0) @make_sibling(pfcn) def pfcn2(t, ytensor, *params): ylist = roller.pack(ytensor) res_list = pfcn(t, ylist, *params) res = roller.flatten(res_list) return res y0 = roller.flatten(y0) res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams()) return roller.pack(res) else: return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams())
def svd(A: LinearOperator, k: Optional[int] = None, mode: str = "uppest", bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r""" Perform the singular value decomposition (SVD): .. math:: \mathbf{A} = \mathbf{U\Sigma V}^H where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrix and :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative numbers. This function can handle derivatives for degenerate singular values by setting non-zero ``degen_atol`` and ``degen_rtol`` in the backward option using the expressions in [1]_. Arguments --------- A: xitorch.LinearOperator The linear operator to be decomposed. It has a shape of ``(*BA, m, n)`` where ``(*BA)`` is the batched dimension of ``A``. k: int or None The number of decomposition obtained. If ``None``, it will be ``min(*A.shape[-2:])`` mode: str ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``, it will take the lowest ``k`` decomposition. If ``"uppest"``, it will take the uppermost ``k``. bck_options: dict Method-specific options for :func:`solve` which used in backpropagation calculation with some additional arguments for computing the backward derivatives: * ``degen_atol`` (``float`` or None): Minimum absolute difference between two singular values to be treated as degenerate. If None, it is ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on degeneracy is applied. (default: None) * ``degen_rtol`` (``float`` or None): Minimum relative difference between two singular values to be treated as degenerate. If None, it is ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on degeneracy is applied. (default: None) Note: the default values of ``degen_atol`` and ``degen_rtol`` are going to change in the future. So, for future compatibility, please specify the specific values. method: str or callable or None Method for the svd (same options for :func:`symeig`). If ``None``, it will choose ``"exacteig"``. **fwd_options Method-specific options (see method section below). Returns ------- tuple of tensors (u, s, vh) It will return ``u, s, vh`` with shapes respectively ``(*BA, m, k)``, ``(*BA, k)``, and ``(*BA, k, n)``. Note ---- It is a naive implementation of symmetric eigendecomposition of ``A.H @ A`` or ``A @ A.H`` (depending which one is cheaper) References ---------- .. [1] Muhammad F. Kasim, "Derivatives of partial eigendecomposition of a real symmetric matrix for degenerate cases". arXiv:2011.04366 (2020) `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_ """ # A: (*BA, m, n) # adapted from scipy.sparse.linalg.svds if is_debug_enabled(): A.check() BA = A.shape[:-2] m = A.shape[-2] n = A.shape[-1] if m < n: AAsym = A.matmul(A.H, is_hermitian=True) min_nm = m else: AAsym = A.H.matmul(A, is_hermitian=True) min_nm = n eivals, eivecs = symeig(AAsym, k, mode, bck_options=bck_options, method=method, **fwd_options) # (*BA, k) and (*BA, min(mn), k) # clamp the eigenvalues to a small positive values to avoid numerical # instability eivals = torch.clamp(eivals, min=0.0) s = torch.sqrt(eivals) # (*BA, k) sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2) # (*BA, 1, k) if m < n: u = eivecs # (*BA, m, k) v = A.rmm(u) / sdiv # (*BA, n, k) else: v = eivecs # (*BA, n, k) u = A.mm(v) / sdiv # (*BA, m, k) vh = v.transpose(-2, -1) return u, s, vh
def solve(A: LinearOperator, B: torch.Tensor, E: Union[torch.Tensor, None] = None, M: Union[LinearOperator, None] = None, posdef=False, fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}): """ Perform an iterative method to solve the equation AX=B or AX-MXE=B, where E is a diagonal matrix. This function can also solve batched multiple inverse equation at the same time by applying A to a tensor X with shape (...,na,ncols). The applied E are not necessarily identical for each column. Arguments --------- * A: xitorch.LinearOperator instance with shape (*BA, na, na) A function that takes an input X and produces the vectors in the same space as B. * B: torch.tensor (*BB, na, ncols) The tensor on the right hand side. * E: torch.tensor (*BE, ncols) or None If not None, it will solve AX-MXE = B. Otherwise, it just solves AX = B and M is ignored. E would be applied to every column. * M: xitorch.LinearOperator instance (*BM, na, na) or None The transformation on the E side. If E is None, then this argument is ignored. If E is not None and M is None, then M=I. This LinearOperator must be Hermitian. * fwd_options: dict Options of the iterative solver in the forward calculation * bck_options: dict Options of the iterative solver in the backward calculation """ assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape") assert_runtime( A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape)) if M is not None: assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape") assert_runtime( M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix") if E is not None: assert_runtime( E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape)) if E is None and M is not None: warnings.warn( "M is supplied but will be ignored because E is not supplied") # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if "method" not in fwd_options or fwd_options["method"].lower( ) == "exactsolve": return exactsolve(A, B, E, M) else: # get the unique parameters of A params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return solve_torchfcn.apply(A, B, E, M, posdef, fwd_options, bck_options, na, *params, *mparams)
def symeig(A: LinearOperator, neig: Union[int, None] = None, mode: str = "lowest", M: Union[LinearOperator, None] = None, fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}): """ Obtain `neig` lowest eigenvalues and eigenvectors of a linear operator. If M is specified, it solves the generalized eigendecomposition Ax = eMx. Arguments --------- * A: xitorch.LinearOperator hermitian instance with shape (*BA, q, q) The linear module object on which the eigenpairs are constructed. * neig: int or None The number of eigenpairs to be retrieved. If None, all eigenpairs are retrieved * mode: str "lowest" or "uppermost"/"uppest". If "lowest", it will take the lowest `neig` eigenpairs. If "uppest", it will take the uppermost `neig`. * M: xitorch.LinearOperator hermitian instance with shape (*BM, q, q) or None The transformation on the right hand side. If None, then M=I. * fwd_options: dict with str as key Eigendecomposition iterative algorithm options. * bck_options: dict with str as key Conjugate gradient options to calculate the gradient in backpropagation calculation. Returns ------- * eigvals: (*BAM, neig) * eigvecs: (*BAM, na, neig) The lowest eigenvalues and eigenvectors, where *BAM is the broadcasted shape of *BA and *BM. """ assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian") if M is not None: assert_runtime(M.is_hermitian, "The linear operator M must be Hermitian") assert_runtime( M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) mode = mode.lower() if mode == "uppermost": mode = "uppest" # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if "method" not in fwd_options or fwd_options["method"].lower( ) == "exacteig": return exacteig(A, neig, mode, M) else: # get the unique parameters of A & M params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return symeig_torchfcn.apply(A, neig, mode, M, fwd_options, bck_options, na, *params, *mparams)
def minimize(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, Callable] = None, **fwd_options) -> torch.Tensor: r""" Solve the unbounded minimization problem: .. math:: \mathbf{y^*} = \arg\min_\mathbf{y} f(\mathbf{y}, \theta) to find the best :math:`\mathbf{y}` that minimizes the output of the function :math:`f`. Arguments --------- fcn: callable The function to be optimized with output tensor with 1 element. y0: torch.tensor Initial guess of the solution with shape ``(*ny)`` params: list Sequence of any other parameters to be put in ``fcn`` bck_options: dict Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`) method: str or callable or None Minimization method. If None, it will choose ``"broyden1"``. **fwd_options Method-specific options (see method section) Returns ------- torch.tensor The solution of the minimization with shape ``(*ny)`` Example ------- .. testsetup:: root1 import torch from xitorch.optimize import minimize .. doctest:: root1 >>> def func1(y, A): # example function ... return torch.sum((A @ y)**2 + y / 2.0) >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_() >>> y0 = torch.zeros((2,1)) # zeros as the initial guess >>> ymin = minimize(func1, y0, params=(A,)) >>> print(ymin) tensor([[-0.0519], [-0.2684]], grad_fn=<_RootFinderBackward>) """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) fwd_options["method"] = _get_minimizer_default_method(method) # the rootfinder algorithms are designed to move to the opposite direction # of the output of the function, so the output of this function is just # the grad of z w.r.t. y @make_sibling(pfunc) def new_fcn(y, *params): with torch.enable_grad(): y1 = y.clone().requires_grad_() z = pfunc(y1, *params) grady, = torch.autograd.grad(z, (y1, ), retain_graph=True, create_graph=torch.is_grad_enabled()) return grady return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.objparams())
def quad(fcn: Union[Callable[..., torch.Tensor], Callable[..., List[torch.Tensor]]], xl: Union[float, int, torch.Tensor], xu: Union[float, int, torch.Tensor], params: Sequence[Any] = [], fwd_options: Mapping[str, Any] = {}, bck_options: Mapping[str, Any] = {}): """ Calculate the quadrature of the function `fcn` from `x0` to `xf`: y = int_xl^xu fcn(x, *params) Arguments --------- * fcn: callable with output tensor with shape (*nout) or list of tensors The function to be integrated. * xl, xu: float, int, or 1-element torch.Tensor The lower and upper bound of the integration. * params: list List of any other parameters for the function `fcn`. * fwd_options: dict Options for the forward quadrature method. * bck_options: dict Options for the backward quadrature method. Returns ------- * y: torch.tensor with shape (*nout) or list of tensors The quadrature results. """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (xl, *params)) if isinstance(xl, torch.Tensor): assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensors") if isinstance(xu, torch.Tensor): assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensors") out = fcn(xl, *params) is_tuple_out = not isinstance(out, torch.Tensor) if not is_tuple_out: dtype = out.dtype device = out.device elif len(out) > 0: dtype = out[0].dtype device = out[0].device else: raise RuntimeError("The output of the fcn must be non-empty") pfunc = get_pure_function(fcn) nparams = len(params) if is_tuple_out: packer = TensorPacker(out) @make_sibling(pfunc) def pfunc2(x, *params): y = fcn(x, *params) return packer.flatten(y) res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options, nparams, dtype, device, *params, *pfunc.objparams()) return packer.pack(res) else: return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options, nparams, dtype, device, *params, *pfunc.objparams())
def backward(ctx, grad_evals, grad_evecs): # grad_evals: (*BAM, neig) # grad_evecs: (*BAM, na, neig) # get the variables from ctx evals, evecs = ctx.saved_tensors[:2] na = ctx.na amparams = ctx.saved_tensors[2:] params = amparams[:na] mparams = amparams[na:] M = ctx.M A = ctx.A degen_atol: Optional[float] = ctx.bck_alg_config["degen_atol"] degen_rtol: Optional[float] = ctx.bck_alg_config["degen_rtol"] # set the default values of degen_*tol dtype = evals.dtype if degen_atol is None: degen_atol = torch.finfo(dtype).eps**0.6 if degen_rtol is None: degen_rtol = torch.finfo(dtype).eps**0.4 # check the degeneracy if degen_atol > 0 or degen_rtol > 0: # idx_degen: (*BAM, neig, neig) idx_degen, isdegenerate = _check_degen(evals, degen_atol, degen_rtol) else: isdegenerate = False if not isdegenerate: idx_degen = None # the loss function where the gradient will be retrieved # warnings: if not all params have the connection to the output of A, # it could cause an infinite loop because pytorch will keep looking # for the *params node and propagate further backward via the `evecs` # path. So make sure all the *params are all connected in the graph. with torch.enable_grad(): params = [p.clone().requires_grad_() for p in params] with A.uselinopparams(*params): loss = A.mm(evecs) # (*BAM, na, neig) # if degenerate, check the conditions for finite derivative if is_debug_enabled() and isdegenerate: xtg = torch.matmul(evecs.transpose(-2, -1), grad_evecs) req1 = idx_degen * (xtg - xtg.transpose(-2, -1)) reqtol = xtg.abs().max() * grad_evecs.shape[-2] * torch.finfo( grad_evecs.dtype).eps if not torch.all(torch.abs(req1) <= reqtol): # if the requirements are not satisfied, raises a warning msg = ( "Degeneracy appears but the loss function seem to depend " "strongly on the eigenvector. The gradient might be incorrect.\n" ) msg += "Eigenvalues:\n%s\n" % str(evals) msg += "Degenerate map:\n%s\n" % str(idx_degen) msg += "Requirements (should be all 0s):\n%s" % str(req1) warnings.warn(MathWarning(msg)) # calculate the contributions from the eigenvalues gevalsA = grad_evals.unsqueeze(-2) * evecs # (*BAM, na, neig) # calculate the contributions from the eigenvectors with M.uselinopparams( *mparams) if M is not None else dummy_context_manager(): # orthogonalize the grad_evecs with evecs B = _ortho(grad_evecs, evecs, D=idx_degen, M=M, mright=False) with A.uselinopparams(*params): gevecs = solve(A, -B, evals, M, bck_options=ctx.bck_config, **ctx.bck_config) # (*BAM, na, neig) # orthogonalize gevecs w.r.t. evecs gevecsA = _ortho(gevecs, evecs, D=None, M=M, mright=True) # accummulate the gradient contributions gaccumA = gevalsA + gevecsA grad_params = torch.autograd.grad( outputs=(loss, ), inputs=params, grad_outputs=(gaccumA, ), create_graph=torch.is_grad_enabled(), ) grad_mparams = [] if ctx.M is not None: with torch.enable_grad(): mparams = [p.clone().requires_grad_() for p in mparams] with M.uselinopparams(*mparams): mloss = M.mm(evecs) # (*BAM, na, neig) gevalsM = -gevalsA * evals.unsqueeze(-2) gevecsM = -gevecsA * evals.unsqueeze(-2) # the contribution from the parallel elements gevecsM_par = (-0.5 * torch.einsum("...ae,...ae->...e", grad_evecs, evecs) ).unsqueeze(-2) * evecs # (*BAM, na, neig) gaccumM = gevalsM + gevecsM + gevecsM_par grad_mparams = torch.autograd.grad( outputs=(mloss, ), inputs=mparams, grad_outputs=(gaccumM, ), create_graph=torch.is_grad_enabled(), ) return (None, None, None, None, None, None, None, *grad_params, *grad_mparams)
def symeig(A: LinearOperator, neig: Optional[int] = None, mode: str = "lowest", M: Optional[LinearOperator] = None, bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]: r""" Obtain ``neig`` lowest eigenvalues and eigenvectors of a linear operator, .. math:: \mathbf{AX = MXE} where :math:`\mathbf{A}, \mathbf{M}` are linear operators, :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and :math:`\mathbf{X}` is a matrix containing the eigenvectors. This function can handle derivatives for degenerate cases by setting non-zero ``degen_atol`` and ``degen_rtol`` in the backward option using the expressions in [1]_. Arguments --------- A: xitorch.LinearOperator The linear operator object on which the eigenpairs are constructed. It must be a Hermitian linear operator with shape ``(*BA, q, q)`` neig: int or None The number of eigenpairs to be retrieved. If ``None``, all eigenpairs are retrieved mode: str ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``, it will take the lowest ``neig`` eigenpairs. If ``"uppest"``, it will take the uppermost ``neig``. M: xitorch.LinearOperator The transformation on the right hand side. If ``None``, then ``M=I``. If specified, it must be a Hermitian with shape ``(*BM, q, q)``. bck_options: dict Method-specific options for :func:`solve` which used in backpropagation calculation with some additional arguments for computing the backward derivatives: * ``degen_atol`` (``float`` or None): Minimum absolute difference between two eigenvalues to be treated as degenerate. If None, it is ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on degeneracy is applied. (default: None) * ``degen_rtol`` (``float`` or None): Minimum relative difference between two eigenvalues to be treated as degenerate. If None, it is ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on degeneracy is applied. (default: None) Note: the default values of ``degen_atol`` and ``degen_rtol`` are going to change in the future. So, for future compatibility, please specify the specific values. method: str or callable or None Method for the eigendecomposition. If ``None``, it will choose ``"exacteig"``. **fwd_options Method-specific options (see method section below). Returns ------- tuple of tensors (eigenvalues, eigenvectors) It will return eigenvalues and eigenvectors with shapes respectively ``(*BAM, neig)`` and ``(*BAM, na, neig)``, where ``*BAM`` is the broadcasted shape of ``*BA`` and ``*BM``. References ---------- .. [1] Muhammad F. Kasim, "Derivatives of partial eigendecomposition of a real symmetric matrix for degenerate cases". 
arXiv:2011.04366 (2020) `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_ """ assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian") assert_runtime( not torch.is_grad_enabled() or A.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator A must be " "implemented if using symeig with grad enabled") if M is not None: assert_runtime(M.is_hermitian, "The linear operator M must be Hermitian") assert_runtime( M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime( not torch.is_grad_enabled() or M.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator M must be " "implemented if using symeig with grad enabled") mode = mode.lower() if mode == "uppermost": mode = "uppest" if method is None: if isinstance(A, MatrixLinearOperator) and \ (M is None or isinstance(M, MatrixLinearOperator)): method = "exacteig" else: # TODO: implement robust LOBPCG and put it here method = "exacteig" if neig is None: neig = A.shape[-1] # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if method == "exacteig": return exacteig(A, neig, mode, M) else: fwd_options["method"] = method # get the unique parameters of A & M params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return symeig_torchfcn.apply(A, neig, mode, M, fwd_options, bck_options, na, *params, *mparams)
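# A usage sketch (assuming the dense wrapper ``xitorch.LinearOperator.m`` with
# its ``is_hermitian`` flag): take the two lowest eigenpairs of a small
# symmetric matrix and check the eigen-equation residual.
import torch
import xitorch
from xitorch.linalg import symeig

mat = torch.randn(4, 4, dtype=torch.float64)
mat = (mat + mat.T) * 0.5                               # symmetrize
A = xitorch.LinearOperator.m(mat, is_hermitian=True)
eivals, eivecs = symeig(A, neig=2, mode="lowest")       # (2,) and (4, 2)
resid = mat @ eivecs - eivecs * eivals                  # A x - lambda x, should be ~0
assert torch.allclose(resid, torch.zeros_like(resid), atol=1e-8)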
def equilibrium(fcn: Callable[..., torch.Tensor], y0: torch.Tensor, params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> torch.Tensor: r""" Solving the equilibrium equation of a given function, .. math:: \mathbf{y} = \mathbf{f}(\mathbf{y}, \theta) where :math:`\mathbf{f}` is a function that can be non-linear and produce output of the same shape of :math:`\mathbf{y}`, and :math:`\theta` is other parameters required in the function. The output of this block is :math:`\mathbf{y}` that produces the same :math:`\mathbf{y}` as the output. Arguments --------- fcn : callable The function :math:`\mathbf{f}` with output tensor ``(*ny)`` y0 : torch.tensor Initial guess of the solution with shape ``(*ny)`` params : list Sequence of any other parameters to be put in ``fcn`` bck_options : dict Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`) method : str or None Rootfinder method. If None, it will choose ``"broyden1"``. **fwd_options Method-specific options (see method section) Returns ------- torch.tensor The solution which satisfies :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\theta)` with shape ``(*ny)`` Example ------- .. testsetup:: equil1 import torch from xitorch.optimize import equilibrium .. doctest:: equil1 >>> def func1(y, A): # example function ... return torch.tanh(A @ y + 0.1) + y / 2.0 >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_() >>> y0 = torch.zeros((2,1)) # zeros as the initial guess >>> yequil = equilibrium(func1, y0, params=(A,)) >>> print(yequil) tensor([[ 0.2313], [-0.5957]], grad_fn=<_RootFinderBackward>) Note ---- * This is a direct implementation of finding the root of :math:`\mathbf{g}(\mathbf{y}, \theta) = \mathbf{y} - \mathbf{f}(\mathbf{y}, \theta)` """ # perform implementation check if debug mode is enabled if is_debug_enabled(): assert_fcn_params(fcn, (y0, *params)) pfunc = get_pure_function(fcn) @make_sibling(pfunc) def new_fcn(y, *params): return y - pfunc(y, *params) fwd_options["method"] = _get_rootfinder_default_method(method) return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options, len(params), *params, *pfunc.objparams())
def symeig(A: LinearOperator, neig: Optional[int] = None, mode: str = "lowest", M: Optional[LinearOperator] = None, bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]: r""" Obtain ``neig`` lowest eigenvalues and eigenvectors of a linear operator, .. math:: \mathbf{AX = MXE} where :math:`\mathbf{A}, \mathbf{M}` are linear operators, :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and :math:`\mathbf{X}` is a matrix containing the eigenvectors. Arguments --------- A: xitorch.LinearOperator The linear operator object on which the eigenpairs are constructed. It must be a Hermitian linear operator with shape ``(*BA, q, q)`` neig: int or None The number of eigenpairs to be retrieved. If ``None``, all eigenpairs are retrieved mode: str ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``, it will take the lowest ``neig`` eigenpairs. If ``"uppest"``, it will take the uppermost ``neig``. M: xitorch.LinearOperator The transformation on the right hand side. If ``None``, then ``M=I``. If specified, it must be a Hermitian with shape ``(*BM, q, q)``. bck_options: dict Method-specific options for :func:`solve` which used in backpropagation calculation. method: str or callable or None Method for the eigendecomposition. If ``None``, it will choose ``"exacteig"``. **fwd_options Method-specific options (see method section below). Returns ------- tuple of tensors (eigenvalues, eigenvectors) It will return eigenvalues and eigenvectors with shapes respectively ``(*BAM, neig)`` and ``(*BAM, na, neig)``, where ``*BAM`` is the broadcasted shape of ``*BA`` and ``*BM``. """ assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian") assert_runtime( not torch.is_grad_enabled() or A.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator A must be " "implemented if using symeig with grad enabled") if M is not None: assert_runtime(M.is_hermitian, "The linear operator M must be Hermitian") assert_runtime( M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime( not torch.is_grad_enabled() or M.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator M must be " "implemented if using symeig with grad enabled") mode = mode.lower() if mode == "uppermost": mode = "uppest" if method is None: if isinstance(A, MatrixLinearOperator) and \ (M is None or isinstance(M, MatrixLinearOperator)): method = "exacteig" else: # TODO: implement robust LOBPCG and put it here method = "exacteig" if neig is None: neig = A.shape[-1] # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if method == "exacteig": return exacteig(A, neig, mode, M) else: fwd_options["method"] = method # get the unique parameters of A & M params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return symeig_torchfcn.apply(A, neig, mode, M, fwd_options, bck_options, na, *params, *mparams)
def solve_ivp(fcn: Union[Callable[..., torch.Tensor], Callable[..., Sequence[torch.Tensor]]], ts: torch.Tensor, y0: torch.Tensor, params: Sequence[Any] = [], bck_options: Mapping[str, Any] = {}, method: Union[str, Callable, None] = None, **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]: r""" Solve the initial value problem (IVP) or also commonly known as ordinary differential equations (ODE), where given the initial value :math:`\mathbf{y_0}`, it then solves .. math:: \mathbf{y}(t) = \mathbf{y_0} + \int_{t_0}^{t} \mathbf{f}(t', \mathbf{y}, \theta)\ \mathrm{d}t' Arguments --------- fcn: callable The function that represents dy/dt. The function takes an input of a single time ``t`` and tensor ``y`` with shape ``(*ny)`` and produce :math:`\mathrm{d}\mathbf{y}/\mathrm{d}t` with shape ``(*ny)``. The output of the function must be a tensor with shape ``(*ny)`` or a list of tensors. ts: torch.tensor The time points where the value of `y` will be returned. It must be monotonically increasing or decreasing. It is a tensor with shape ``(nt,)``. y0: torch.tensor The initial value of ``y``, i.e. ``y(t[0]) == y0``. It is a tensor with shape ``(*ny)`` or a list of tensors. params: list Sequence of other parameters required in the function. bck_options: dict Options for the backward solve_ivp method. If not specified, it will take the same options as fwd_options. method: str or callable or None Initial value problem solver. If None, it will choose ``"rk45"``. **fwd_options Method-specific option (see method section below). Returns ------- torch.tensor or a list of tensors The values of ``y`` for each time step in ``ts``. It is a tensor with shape ``(nt,*ny)`` or a list of tensors """ if is_debug_enabled(): assert_fcn_params(fcn, (ts[0], y0, *params)) assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor") if method is None: # set the default method method = "rk45" fwd_options["method"] = method # run once to see if the outputs is a tuple or a single tensor is_y0_list = isinstance(y0, list) or isinstance(y0, tuple) dydt = fcn(ts[0], y0, *params) is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple) if is_y0_list != is_dydt_list: raise RuntimeError( "The y0 and output of fcn must both be tuple or a tensor") pfcn = get_pure_function(fcn) if is_y0_list: nt = len(ts) roller = TensorPacker(y0) @make_sibling(pfcn) def pfcn2(t, ytensor, *params): ylist = roller.pack(ytensor) res_list = pfcn(t, ylist, *params) res = roller.flatten(res_list) return res y0 = roller.flatten(y0) res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams()) return roller.pack(res) else: return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams())
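# A usage sketch: dy/dt = -a * y with the default "rk45" solver; the solution
# should closely follow the analytic result y0 * exp(-a * t) and is
# differentiable with respect to ``a``.
import torch
from xitorch.integrate import solve_ivp

def fcn(t, y, a):
    return -a * y

a = torch.tensor(1.5, dtype=torch.float64).requires_grad_()
ts = torch.linspace(0, 2, 5, dtype=torch.float64)
y0 = torch.tensor([1.0], dtype=torch.float64)

yt = solve_ivp(fcn, ts, y0, params=(a,))                # shape (nt, *ny) = (5, 1)
yref = y0 * torch.exp(-a * ts).unsqueeze(-1)            # analytic solution
assert torch.allclose(yt, yref, atol=1e-4)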
def solve(A:LinearOperator, B:torch.Tensor, E:Union[torch.Tensor,None]=None, M:Union[LinearOperator,None]=None, posdef=False, bck_options:Mapping[str,Any]={}, method:Union[str,None]=None, **fwd_options): r""" Performing iterative method to solve the equation .. math:: \mathbf{AX=B} or .. math:: \mathbf{AX-MXE=B} where :math:`\mathbf{E}` is a diagonal matrix. This function can also solve batched multiple inverse equation at the same time by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}` with shape ``(...,na,ncols)``. The applied :math:`\mathbf{E}` are not necessarily identical for each column. Arguments --------- A: xitorch.LinearOperator A linear operator that takes an input ``X`` and produce the vectors in the same space as ``B``. It should have the shape of ``(*BA, na, na)`` B: torch.tensor The tensor on the right hand side with shape ``(*BB, na, ncols)`` E: torch.tensor or None If a tensor, it will solve :math:`\mathbf{AX-MXE = B}`. It will be regarded as the diagonal of the matrix. Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is ignored. If it is a tensor, it should have shape of ``(*BE, ncols)``. M: xitorch.LinearOperator or None The transformation on the ``E`` side. If ``E`` is ``None``, then this argument is ignored. If E is not ``None`` and ``M`` is ``None``, then ``M=I``. If LinearOperator, it must be Hermitian with shape ``(*BM, na, na)``. bck_options: dict Options of the iterative solver in the backward calculation. method: str or None Indicating the method of solve. If None, it will select ``exactsolve``. **fwd_options Method-specific options (see method below) """ assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape") assert_runtime(A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape)) if M is not None: assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape") assert_runtime(M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix") if E is not None: assert_runtime(E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape)) if E is None and M is not None: warnings.warn("M is supplied but will be ignored because E is not supplied") # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if method is None: method = "exactsolve" # TODO: do a proper method selection based on the size if method == "exactsolve": return exactsolve(A, B, E, M) else: fwd_options["method"] = method # get the unique parameters of A params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return solve_torchfcn.apply( A, B, E, M, posdef, fwd_options, bck_options, na, *params, *mparams)
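# A usage sketch (wrapping dense matrices with ``xitorch.LinearOperator.m``):
# solving AX = B and checking against torch.linalg.solve. With no method given,
# this falls back to "exactsolve".
import torch
import xitorch
from xitorch.linalg import solve

amat = torch.rand(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
bmat = torch.rand(4, 2, dtype=torch.float64)

A = xitorch.LinearOperator.m(amat)
X = solve(A, bmat)                                      # (4, 2)
assert torch.allclose(X, torch.linalg.solve(amat, bmat), atol=1e-8)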