Example #1
def jac(fcn: Callable[..., torch.Tensor],
        params: Sequence[Any],
        idxs: Union[None, int, Sequence[int]] = None) -> Union[LinearOperator, List[LinearOperator]]:
    """
    Returns LinearOperators that act as the Jacobian of the function with
    respect to the selected parameters.
    Each LinearOperator has shape (nout, nin) where `nout` and `nin` are the
    total numbers of elements in the output and the input, respectively.

    Arguments
    ---------
    * fcn: Callable[...,torch.Tensor]
        Callable with tensor output and an arbitrary number of input parameters.
    * params: Sequence[Any]
        List of input parameters of the function.
    * idxs: int or list of int or None
        Indices of the parameters to get the Jacobian of.
        The indicated parameters in `params` must be tensors with
        requires_grad enabled. If None, the Jacobian is returned for every
        parameter that is a tensor requiring grad.

    Returns
    -------
    * linops: list of LinearOperator
        List of LinearOperators of the Jacobian, one per selected parameter.
        If `idxs` is an int, a single LinearOperator is returned instead
        of a list.
    """
    # check idxs
    idxs_list = _setup_idxs(idxs, params)

    # make the function a functional (depends on all parameters in the object)
    pfcn = get_pure_function(fcn)
    res = [_Jac(pfcn, params, idx) for idx in idxs_list]
    if isinstance(idxs, int):
        res = res[0]
    return res
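A minimal usage sketch of `jac` (the function `f` and the shapes below are illustrative assumptions, not from the source):

import torch

def f(x, y):
    return x ** 2 + y.sum()

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

# one LinearOperator per requested parameter, each of shape (nout, nin)
jx, jy = jac(f, params=(x, y), idxs=[0, 1])

# LinearOperators are lazy: evaluate them via matrix-vector products
v = torch.randn(3)
print(jx.mv(v))  # J_x @ v without materializing J_x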
Example #2
def fcnl2(i, v, *params2):
    # wrap the matrix-vector product of the i-th operator as a pure function
    fmv = get_pure_function(hs[i].mv)
    params0 = v.view(-1)
    # scale each object parameter elementwise before applying the operator
    params1 = fmv.objparams()
    params12 = [p1 * p2 for (p1, p2) in zip(params1, params2)]
    # evaluate the matrix-vector product with the substituted parameters
    with fmv.useobjparams(params12):
        return fmv(params0)
Example #3
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             bck_options: Mapping[str, Any] = {},
             method: Union[str, None] = None,
             **fwd_options) -> torch.Tensor:
    """
    Solve the unbounded minimization problem:

    .. math::

        \mathbf{y^*} = \\arg\min_\mathbf{y} f(\mathbf{y}, \\theta)

    to find the best :math:`\mathbf{y}` that minimizes the output of the
    function :math:`f`.

    Arguments
    ---------
    fcn: callable
        The function to be minimized, whose output must be a tensor with
        a single element.
    y0: torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params: list
        List of any other parameters to be put in ``fcn``
    bck_options: dict
        Method-specific options for the backward solve.
    method: str or None
        Minimization method.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution of the minimization with shape ``(*ny)``
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    fwd_options["method"] = _get_minimizer_default_method(method)
    # the rootfinder algorithms are designed to move to the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y
    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
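The comment in the body above captures the key trick: minimization is implemented as root-finding on the gradient. A self-contained plain-PyTorch sketch of that equivalence (illustrative only, not xitorch's actual solver):

import torch

def f(y):
    return ((y - 3.0) ** 2).sum()  # minimum at y = 3

y = torch.zeros(2)
for _ in range(200):
    y1 = y.clone().requires_grad_()
    z = f(y1)
    grady, = torch.autograd.grad(z, (y1,))
    # step opposite to the gradient, i.e. root-find grad f(y) = 0
    y = y - 0.1 * grady

print(y)  # approximately tensor([3., 3.])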
Example #4
def hess(
    fcn: Callable[..., torch.Tensor],
    params: Sequence[Any],
    idxs: Union[None, int,
                Sequence[int]] = None) -> Union[LinearOperator, List]:
    """
    Returns LinearOperators that act as the Hessian of the function with
    respect to the selected parameters.
    Each LinearOperator has shape (nin, nin) where `nin` is the
    total number of elements in the corresponding input parameter.

    Arguments
    ---------
    * fcn: Callable[...,torch.Tensor]
        Callable with tensor output and an arbitrary number of input parameters.
        The numel of the output must be 1.
    * params: Sequence[Any]
        List of input parameters of the function.
    * idxs: int or list of int or None
        Indices of the parameters to get the Hessian of.
        The indicated parameters in `params` must be tensors with
        requires_grad enabled. If None, the Hessian is returned for every
        parameter that is a tensor requiring grad.

    Returns
    -------
    * linops: list of LinearOperator
        List of LinearOperators of the Hessian, one per selected parameter.
        If `idxs` is an int, a single LinearOperator is returned instead
        of a list.
    """
    idxs_list = _setup_idxs(idxs, params)

    # make the function a functional (depends on all parameters in the object)
    pfcn = get_pure_function(fcn)

    res = []

    def gen_pfcn2(idx):
        @make_sibling(pfcn)
        def pfcn2(*params):
            with torch.enable_grad():
                z = pfcn(*params)
            grady, = torch.autograd.grad(z, (params[idx], ),
                                         retain_graph=True,
                                         create_graph=torch.is_grad_enabled())
            return grady

        return pfcn2

    for idx in idxs_list:
        # suppress warnings of double implementation in hermitian matrix
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            hs = _Jac(gen_pfcn2(idx), params, idx, is_hermitian=True)
        res.append(hs)

    if isinstance(idxs, int):
        return res[0]
    return res
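The construction above realizes the identity "Hessian = Jacobian of the gradient" (`pfcn2` returns the gradient, and `_Jac` differentiates it). A plain-PyTorch sketch of the same identity via a Hessian-vector product (illustrative):

import torch

def f(x):
    return (x ** 3).sum()  # Hessian is diag(6 * x)

x = torch.tensor([1.0, 2.0], requires_grad=True)
z = f(x)
grad, = torch.autograd.grad(z, (x,), create_graph=True)

# Hessian-vector product: differentiate (grad . v) with respect to x
v = torch.tensor([1.0, 0.0])
hvp, = torch.autograd.grad(grad @ v, (x,))
print(hvp)  # tensor([6., 0.]) == diag(6 * x) @ v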
Example #5
def test_edit_simple():
    pfcn = get_pure_function(model.method_correct_getsetparams)
    objparams = pfcn.objparams()
    assert len(objparams) == 4
    newb = torch.tensor([1.])
    newobjparams = [torch.tensor(1.0 * i + 2) for i in range(len(objparams))]
    with pfcn.useobjparams(newobjparams):
        f = pfcn(newb)
    assert torch.allclose(f, f * 0 + 16)  # (1+2+3+4+5+1)

    pfcn2 = get_pure_function(model.method_correct_getsetparams2)
    objparams2 = pfcn2.objparams()
    assert len(objparams2) == 4
    newparams2 = [torch.tensor(1.0 * i + 1) for i in range(2)]
    newobjparams2 = [torch.tensor(1.0 * i + 3) for i in range(len(objparams2))]
    with pfcn2.useobjparams(newobjparams2):
        f2 = pfcn2(*newparams2)
    assert torch.allclose(f2, f2 * 0 + 41)  # (1+3+4+5+6+1) + (2+3+4+5+6+1)
Example #6
def test_module_pfunc(clss):
    a = torch.nn.Parameter(torch.tensor(2.0))
    b = torch.nn.Parameter(torch.tensor(1.0))
    x = torch.tensor(1.5)
    module = clss(a, b)
    pfunc = get_pure_function(module.forward)

    expr = lambda x, a, b: x * a + b
    runtest_pfunc(pfunc, (x, ), expr)
Example #7
def test_edit_list():
    pfcn = get_pure_function(model.method_list_correct)
    objparams = pfcn.objparams()
    assert len(objparams) == 6
    assert objparams[-1] is model.listparams[2]
    assert objparams[-2] is model.listparams[0]
    newparams = [torch.tensor(1.0 * i + 1) for i in range(1)]
    newobjparams = [torch.tensor(1.0 * i + 2) for i in range(len(objparams))]
    with pfcn.useobjparams(newobjparams):
        f = pfcn(*newparams)
    assert torch.allclose(f, f * 0 + 29)  # (1+2+3+4+5+1) + 6+7
Example #8
def test_edit_duplicate():
    pfcn = get_pure_function(model.method_duplicate_correct)
    objparams = pfcn.objparams()
    assert len(objparams) == 4
    assert objparams[0] is model.a
    assert objparams[0] is model.aa  # aa is a duplicate of a
    newparams = [torch.tensor(1.0 * i + 1) for i in range(1)]
    newobjparams = [torch.tensor(1.0 * i + 2) for i in range(len(objparams))]
    with pfcn.useobjparams(newobjparams):
        f = pfcn(*newparams)
    assert torch.allclose(f, f * 0 + 18)  # (1+2+3+4+5+1) + 2
Example #9
def test_pure_function(fcn):
    pfunc1 = get_pure_function(fcn)
    assert isinstance(pfunc1, PureFunction)
    assert len(pfunc1.objparams()) == 0
    a = torch.tensor(2.0)
    b = torch.tensor(1.0)
    x = torch.tensor(1.5)
    res = x * a + b

    expr = lambda x, a, b: x * a + b
    runtest_pfunc(pfunc1, (x, a, b), expr)
Example #10
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                bck_options: Mapping[str, Any] = {},
                method: Union[str, None] = None,
                **fwd_options):
    """
    Solving the equilibrium equation of a given function,

    .. math::

        \mathbf{y} = \mathbf{f}(\mathbf{y}, \\theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produce output of the same shape of :math:`\mathbf{y}`, and :math:`\\theta`
    is other parameters required in the function.
    The output of this block is :math:`\mathbf{y}`
    that produces the same :math:`\mathbf{y}` as the output.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        List of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve
    method : str or None
        Rootfinder method.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)`
        with shape ``(*ny)``
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
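What `equilibrium` solves can be pictured as a fixed-point iteration (an illustrative sketch; xitorch instead root-finds y - f(y), as `new_fcn` above shows):

import torch

def f(y):
    return 0.5 * torch.cos(y)  # a contraction, so plain iteration converges

y = torch.zeros(3)
for _ in range(100):
    y = f(y)

print(torch.allclose(y, f(y)))  # True: y is a fixed point of f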
Example #11
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             fwd_options: Mapping[str, Any] = {},
             bck_options: Mapping[str, Any] = {}) -> torch.Tensor:
    """
    Solve the minimization problem:

        y* = argmin_y fcn(y, *params)

    to find the best `y` that minimizes the output of the function `fcn`.
    The output of `fcn` must be a single element tensor.

    Arguments
    ---------
    * fcn: callable with output tensor (numel=1)
        The function
    * y0: torch.tensor with shape (*ny)
        Initial guess of the solution
    * params: list
        List of any other parameters to be put in fcn
    * fwd_options: dict
        Options for the minimizer method
    * bck_options: dict
        Options for the backward solve method

    Returns
    -------
    * y: torch.tensor with shape (*ny)
        The solution of the minimization
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    # the rootfinder algorithms are designed to move to the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y
    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #12
def _mcquad(ffcn, log_pfcn, x0, xsamples, wsamples, fparams, pparams,
            fwd_options, bck_options):
    # this is mcquad with an additional xsamples argument, to prevent
    # xsamples from being set by users

    if is_debug_enabled():
        assert_fcn_params(ffcn, (x0, *fparams))
        assert_fcn_params(log_pfcn, (x0, *pparams))

    # check if ffcn produces a list / tuple
    out = ffcn(x0, *fparams)
    is_tuple_out = isinstance(out, list) or isinstance(out, tuple)

    # get the pure functions
    pure_ffcn = get_pure_function(ffcn)
    pure_logpfcn = get_pure_function(log_pfcn)
    nfparams = len(fparams)
    npparams = len(pparams)
    fobjparams = pure_ffcn.objparams()
    pobjparams = pure_logpfcn.objparams()
    nf_objparams = len(fobjparams)

    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pure_ffcn)
        def pure_ffcn2(x, *fparams):
            y = pure_ffcn(x, *fparams)
            return packer.flatten(y)

        res = _MCQuad.apply(pure_ffcn2, pure_logpfcn, x0, None, None,
                            fwd_options, bck_options, nfparams, nf_objparams,
                            npparams, *fparams, *fobjparams, *pparams,
                            *pobjparams)
        return packer.pack(res)
    else:
        return _MCQuad.apply(pure_ffcn, pure_logpfcn, x0, None, None,
                             fwd_options, bck_options, nfparams, nf_objparams,
                             npparams, *fparams, *fobjparams, *pparams,
                             *pobjparams)
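`TensorPacker` is what lets list/tuple outputs pass through torch.autograd.Function, which only handles flat tensors: flatten on the way in, pack on the way out. A minimal hypothetical stand-in showing that round trip (MiniPacker is not xitorch's actual class):

import math
import torch

class MiniPacker:
    # hypothetical minimal stand-in for xitorch's TensorPacker
    def __init__(self, template):
        self.shapes = [t.shape for t in template]

    def flatten(self, tensors):
        # concatenate all tensors into a single 1-D tensor
        return torch.cat([t.reshape(-1) for t in tensors])

    def pack(self, flat):
        # split the flat tensor back into tensors of the original shapes
        out, i = [], 0
        for shape in self.shapes:
            n = math.prod(shape)
            out.append(flat[i:i + n].reshape(shape))
            i += n
        return out

ys = [torch.randn(2, 3), torch.randn(4)]
packer = MiniPacker(ys)
flat = packer.flatten(ys)     # shape (10,)
restored = packer.pack(flat)  # shapes (2, 3) and (4,)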
Example #13
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                fwd_options: Mapping[str, Any] = {},
                bck_options: Mapping[str, Any] = {}):
    """
    Solving the equilibrium equation of a given function,

        y = fcn(y, *params)

    where `fcn` is a function that can be non-linear and produce output of shape
    `y`. The output of this block is `y` that produces the 0 as the output.

    Arguments
    ---------
    * fcn: callable with output tensor (*ny)
        The function
    * y0: torch.tensor with shape (*ny)
        Initial guess of the solution
    * params: list
        List of any other parameters to be put in fcn
    * fwd_options: dict
        Options for the rootfinder method
    * bck_options: dict
        Options for the backward solve method

    Returns
    -------
    * yout: torch.tensor with shape (*ny)
        The solution which satisfies yout = fcn(yout)

    Note
    ----
    * To obtain the correct gradient and higher order gradients, `fcn` must be:
        - a torch.nn.Module whose fcn.parameters() lists the tensors that
            determine the output of fcn,
        - a method of an xt.EditableModule object with no out-of-scope
            parameters, or
        - a function with no out-of-scope parameters.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #14
def test_wrap_nnmodule():
    a = torch.nn.Parameter(torch.tensor([1.2], dtype=dtype))
    x = torch.tensor([1.4], dtype=dtype, requires_grad=True)

    modules = [NNModule(a), NNModule0()]
    nparams = [2, 1]

    for i in range(len(modules)):
        module = modules[i]
        pfcn = get_pure_function(module)
        objparams = pfcn.objparams()
        assert len(objparams) == nparams[i] - 1  # minus 1 for the input parameter x
        if len(objparams) == 1:
            assert objparams[0] is a
Example #15
def equilibrium(fcn: Callable[..., torch.Tensor],
                y0: torch.Tensor,
                params: Sequence[Any] = [],
                bck_options: Mapping[str, Any] = {},
                method: Union[str, Callable, None] = None,
                **fwd_options) -> torch.Tensor:
    r"""
    Solve the equilibrium equation of a given function,

    .. math::

        \mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produces an output of the same shape as :math:`\mathbf{y}`, and
    :math:`\theta` are the other parameters required by the function.
    The output of this block is the :math:`\mathbf{y}` that satisfies
    :math:`\mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)`.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        Sequence of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method : str or callable or None
        Rootfinder method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\theta)`
        with shape ``(*ny)``

    Example
    -------
    .. testsetup:: equil1

        import torch
        from xitorch.optimize import equilibrium

    .. doctest:: equil1

        >>> def func1(y, A):  # example function
        ...     return torch.tanh(A @ y + 0.1) + y / 2.0
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2,1))  # zeros as the initial guess
        >>> yequil = equilibrium(func1, y0, params=(A,))
        >>> print(yequil)
        tensor([[ 0.2313],
                [-0.5957]], grad_fn=<_RootFinderBackward>)

    Note
    ----
    * This is a direct implementation of finding the root of
      :math:`\mathbf{g}(\mathbf{y}, \theta) = \mathbf{y} - \mathbf{f}(\mathbf{y}, \theta)`
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #16
def minimize(fcn: Callable[..., torch.Tensor],
             y0: torch.Tensor,
             params: Sequence[Any] = [],
             bck_options: Mapping[str, Any] = {},
             method: Union[str, Callable, None] = None,
             **fwd_options) -> torch.Tensor:
    r"""
    Solve the unbounded minimization problem:

    .. math::

        \mathbf{y^*} = \arg\min_\mathbf{y} f(\mathbf{y}, \theta)

    to find the best :math:`\mathbf{y}` that minimizes the output of the
    function :math:`f`.

    Arguments
    ---------
    fcn: callable
        The function to be minimized, whose output must be a tensor with
        a single element.
    y0: torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params: list
        Sequence of any other parameters to be put in ``fcn``
    bck_options: dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method: str or callable or None
        Minimization method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution of the minimization with shape ``(*ny)``

    Example
    -------
    .. testsetup:: root1

        import torch
        from xitorch.optimize import minimize

    .. doctest:: root1

        >>> def func1(y, A):  # example function
        ...     return torch.sum((A @ y)**2 + y / 2.0)
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2,1))  # zeros as the initial guess
        >>> ymin = minimize(func1, y0, params=(A,))
        >>> print(ymin)
        tensor([[-0.0519],
                [-0.2684]], grad_fn=<_RootFinderBackward>)
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    fwd_options["method"] = _get_minimizer_default_method(method)
    # the rootfinder algorithms are designed to move to the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1, ),
                                     retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return grady

    return _RootFinder.apply(new_fcn, y0, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
Example #17
def solve_ivp(fcn: Callable[..., torch.Tensor],
              ts: torch.Tensor,
              y0: torch.Tensor,
              params: Sequence[Any] = [],
              fwd_options: Mapping[str, Any] = {},
              bck_options: Mapping[str, Any] = {}) -> torch.Tensor:
    """
    Solve the initial value problem (IVP) which given the initial value `y0`,
    the function is then solve

        y(t) = y0 + int_t0^t f(t', y, *params) dt'

    Arguments
    ---------
    * fcn: callable with output a tensor with shape (*ny) or a list of tensors
        The function that represents dy/dt. The function takes an input of a
        single time `t` and `y` with shape (*ny) and produce dydt with shape (*ny).
    * ts: torch.tensor with shape (nt,)
        The time points where the value of `y` is returned.
        It must be monotonically increasing or decreasing.
    * y0: torch.tensor with shape (*ny) or a list of tensors
        The initial value of y, i.e. y(t[0]) == y0
    * params: list
        List of other parameters required in the function.
    * fwd_options: dict
        Options for the forward solve_ivp method.
    * bck_options: dict
        Options for the backward solve_ivp method.

    Returns
    -------
    * yt: torch.tensor with shape (nt,*ny) or a list of tensors
        The values of `y` for each time step in `ts`.
    """
    if is_debug_enabled():
        assert_fcn_params(fcn, (ts[0], y0, *params))
    assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor")

    # run once to see if the output is a tuple or a single tensor
    is_y0_list = isinstance(y0, list) or isinstance(y0, tuple)
    dydt = fcn(ts[0], y0, *params)
    is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple)
    if is_y0_list != is_dydt_list:
        raise RuntimeError(
            "y0 and the output of fcn must both be tuples or both be tensors")

    pfcn = get_pure_function(fcn)
    if is_y0_list:
        nt = len(ts)
        roller = TensorPacker(y0)

        @make_sibling(pfcn)
        def pfcn2(t, ytensor, *params):
            ylist = roller.pack(ytensor)
            res_list = pfcn(t, ylist, *params)
            res = roller.flatten(res_list)
            return res

        y0 = roller.flatten(y0)
        res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params),
                              y0, *params, *pfcn.objparams())
        return roller.pack(res)
    else:
        return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params),
                               y0, *params, *pfcn.objparams())
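What `solve_ivp` computes, sketched with a forward-Euler loop in plain PyTorch (illustrative; the actual solver methods are far more accurate):

import torch

def dydt(t, y):
    return -y  # exact solution: y(t) = y0 * exp(-t)

ts = torch.linspace(0.0, 1.0, 101)
y = torch.tensor([1.0])
yt = [y]
for i in range(len(ts) - 1):
    h = ts[i + 1] - ts[i]
    y = y + h * dydt(ts[i], y)  # forward-Euler step
    yt.append(y)

yt = torch.stack(yt)  # shape (nt, *ny), matching solve_ivp's return
print(yt[-1])  # close to exp(-1) ~= 0.3679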
Example #18
def quad(fcn: Union[Callable[..., torch.Tensor], Callable[...,
                                                          List[torch.Tensor]]],
         xl: Union[float, int, torch.Tensor],
         xu: Union[float, int, torch.Tensor],
         params: Sequence[Any] = [],
         fwd_options: Mapping[str, Any] = {},
         bck_options: Mapping[str, Any] = {}):
    """
    Calculate the quadrature of the function `fcn` from `x0` to `xf`:

        y = int_xl^xu fcn(x, *params)

    Arguments
    ---------
    * fcn: callable with output tensor with shape (*nout) or list of tensors
        The function to be integrated.
    * xl, xu: float, int, or 1-element torch.Tensor
        The lower and upper bound of the integration.
    * params: list
        List of any other parameters for the function `fcn`.
    * fwd_options: dict
        Options for the forward quadrature method.
    * bck_options: dict
        Options for the backward quadrature method.

    Returns
    -------
    * y: torch.tensor with shape (*nout) or list of tensors
        The quadrature results.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (xl, *params))
    if isinstance(xl, torch.Tensor):
        assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensor")
    if isinstance(xu, torch.Tensor):
        assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensor")

    out = fcn(xl, *params)
    is_tuple_out = not isinstance(out, torch.Tensor)
    if not is_tuple_out:
        dtype = out.dtype
        device = out.device
    elif len(out) > 0:
        dtype = out[0].dtype
        device = out[0].device
    else:
        raise RuntimeError("The output of the fcn must be non-empty")

    pfunc = get_pure_function(fcn)
    nparams = len(params)
    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pfunc)
        def pfunc2(x, *params):
            y = fcn(x, *params)
            return packer.flatten(y)

        res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options,
                                nparams, dtype, device, *params,
                                *pfunc.objparams())
        return packer.pack(res)
    else:
        return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options,
                                 nparams, dtype, device, *params,
                                 *pfunc.objparams())
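The integral being computed, sketched with a simple midpoint rule in plain PyTorch (illustrative; methods such as Gauss-Legendre converge much faster):

import math
import torch

def f(x):
    return torch.sin(x)

xl, xu = 0.0, math.pi
n = 1000
h = (xu - xl) / n
mids = torch.linspace(xl + h / 2, xu - h / 2, n)
y = (f(mids) * h).sum()
print(y)  # close to 2.0 = int_0^pi sin(x) dx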
Example #19
def solve_ivp(fcn: Union[Callable[..., torch.Tensor],
                         Callable[..., Sequence[torch.Tensor]]],
              ts: torch.Tensor,
              y0: torch.Tensor,
              params: Sequence[Any] = [],
              bck_options: Mapping[str, Any] = {},
              method: Union[str, Callable, None] = None,
              **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    r"""
    Solve the initial value problem (IVP), also commonly known as an ordinary
    differential equation (ODE): given the initial value :math:`\mathbf{y_0}`,
    it solves

    .. math::

        \mathbf{y}(t) = \mathbf{y_0} + \int_{t_0}^{t} \mathbf{f}(t', \mathbf{y}, \theta)\ \mathrm{d}t'

    Arguments
    ---------
    fcn: callable
        The function that represents dy/dt. The function takes an input of a
        single time ``t`` and tensor ``y`` with shape ``(*ny)`` and
        produce :math:`\mathrm{d}\mathbf{y}/\mathrm{d}t` with shape ``(*ny)``.
        The output of the function must be a tensor with shape ``(*ny)`` or
        a list of tensors.
    ts: torch.tensor
        The time points where the value of `y` will be returned.
        It must be monotonically increasing or decreasing.
        It is a tensor with shape ``(nt,)``.
    y0: torch.tensor
        The initial value of ``y``, i.e. ``y(t[0]) == y0``.
        It is a tensor with shape ``(*ny)`` or a list of tensors.
    params: list
        Sequence of other parameters required in the function.
    bck_options: dict
        Options for the backward solve_ivp method. If not specified, it will
        take the same options as fwd_options.
    method: str or callable or None
        Initial value problem solver. If None, it will choose ``"rk45"``.
    **fwd_options
        Method-specific option (see method section below).

    Returns
    -------
    torch.tensor or a list of tensors
        The values of ``y`` for each time step in ``ts``.
        It is a tensor with shape ``(nt,*ny)`` or a list of tensors
    """
    if is_debug_enabled():
        assert_fcn_params(fcn, (ts[0], y0, *params))
    assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor")

    if method is None:  # set the default method
        method = "rk45"
    fwd_options["method"] = method

    # run once to see if the output is a tuple or a single tensor
    is_y0_list = isinstance(y0, list) or isinstance(y0, tuple)
    dydt = fcn(ts[0], y0, *params)
    is_dydt_list = isinstance(dydt, list) or isinstance(dydt, tuple)
    if is_y0_list != is_dydt_list:
        raise RuntimeError(
            "y0 and the output of fcn must both be tuples or both be tensors")

    pfcn = get_pure_function(fcn)
    if is_y0_list:
        nt = len(ts)
        roller = TensorPacker(y0)

        @make_sibling(pfcn)
        def pfcn2(t, ytensor, *params):
            ylist = roller.pack(ytensor)
            res_list = pfcn(t, ylist, *params)
            res = roller.flatten(res_list)
            return res

        y0 = roller.flatten(y0)
        res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params),
                              y0, *params, *pfcn.objparams())
        return roller.pack(res)
    else:
        return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params),
                               y0, *params, *pfcn.objparams())
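A usage sketch of this `solve_ivp` (assuming it is exposed as in xitorch.integrate; the decay example is an illustrative assumption):

import torch

def dydt(t, y, a):
    return -a * y  # exponential decay, y(t) = y0 * exp(-a * t)

a = torch.tensor(1.0, requires_grad=True)
ts = torch.linspace(0.0, 1.0, 5)
y0 = torch.ones(1)
yt = solve_ivp(dydt, ts, y0, params=(a,))
# yt has shape (nt, *ny); yt[-1] should be close to exp(-1),
# and gradients flow back to `a`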
Example #20
def quad(fcn: Union[Callable[..., torch.Tensor],
                    Callable[..., Sequence[torch.Tensor]]],
         xl: Union[float, int, torch.Tensor],
         xu: Union[float, int, torch.Tensor],
         params: Sequence[Any] = [],
         bck_options: Mapping[str, Any] = {},
         method: Union[str, Callable, None] = None,
         **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    r"""
    Calculate the quadrature:

    .. math::

        y = \int_{x_l}^{x_u} f(x, \theta)\ \mathrm{d}x

    Arguments
    ---------
    fcn: callable
        The function to be integrated. Its output must be a tensor with
        shape ``(*nout)`` or list of tensors.
    xl: float, int or 1-element torch.Tensor
        The lower bound of the integration.
    xu: float, int or 1-element torch.Tensor
        The upper bound of the integration.
    params: list
        Sequence of any other parameters for the function ``fcn``.
    bck_options: dict
        Options for the backward quadrature method.
    method: str or callable or None
        Quadrature method. If None, it will choose ``"leggauss"``.
    **fwd_options
        Method-specific options (see method section).

    Returns
    -------
    torch.tensor or a list of tensors
        The quadrature results with shape ``(*nout)`` or list of tensors.
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (xl, *params))
    if isinstance(xl, torch.Tensor):
        assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensor")
    if isinstance(xu, torch.Tensor):
        assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensor")
    if method is None:
        method = "leggauss"
    fwd_options["method"] = method

    out = fcn(xl, *params)
    if isinstance(out, torch.Tensor):
        dtype = out.dtype
        device = out.device
        is_tuple_out = False
    elif len(out) > 0:
        dtype = out[0].dtype
        device = out[0].device
        is_tuple_out = True
    else:
        raise RuntimeError("The output of the fcn must be non-empty")

    pfunc = get_pure_function(fcn)
    nparams = len(params)
    if is_tuple_out:
        packer = TensorPacker(out)

        @make_sibling(pfunc)
        def pfunc2(x, *params):
            y = fcn(x, *params)
            return packer.flatten(y)

        res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options,
                                nparams, dtype, device, *params,
                                *pfunc.objparams())
        return packer.pack(res)
    else:
        return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options,
                                 nparams, dtype, device, *params,
                                 *pfunc.objparams())
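And a usage sketch of this `quad` (assuming it is exposed as in xitorch.integrate; the Gaussian integrand is an illustrative assumption):

import torch

def gauss(x, w):
    return torch.exp(-x * x / (2 * w * w))

w = torch.tensor(1.0, requires_grad=True)
y = quad(gauss, -10.0, 10.0, params=(w,))
# y should be close to sqrt(2 * pi) * w, and gradients flow through w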