示例#1
0
def trim_t_results(
    results: OdeResult,
    t_span: Union[List, Tuple, Array],
    t_eval: Optional[Union[List, Tuple, Array]] = None,
) -> OdeResult:
    """Drop solver output at ``t_span`` endpoints that ``t_eval`` did not ask for.

    ``results`` is assumed to hold the solution on the time grid produced by
    ``validate_and_merge_t_span_t_eval(t_span, t_eval)``, i.e. a grid whose
    first/last samples correspond to the ends of ``t_span``. The first sample
    is removed when ``t_eval[0]`` differs from ``t_span[0]``, and the last
    sample is removed when ``t_eval[-1]`` differs from ``t_span[1]``.

    Args:
        results: Solver output to trim; its ``t`` and ``y`` attributes are
            reassigned in place.
        t_span: Integration interval ``(t0, tf)``.
        t_eval: Requested output times, or ``None``.

    Returns:
        OdeResult: The same ``results`` object after trimming. When ``t_eval``
        is ``None`` it is returned untouched (solver default output).
    """
    if t_eval is None:
        return results

    span = Array(t_span, backend="numpy")

    # Pairs of (endpoint was not requested?, slice that discards that endpoint).
    # The comparisons only read t_eval/t_span, so evaluating both up front is
    # equivalent to checking them one after the other.
    endpoint_rules = (
        (t_eval[0] != span[0], slice(1, None)),
        (t_eval[-1] != span[1], slice(None, -1)),
    )

    for endpoint_unwanted, kept in endpoint_rules:
        if endpoint_unwanted:
            results.t = results.t[kept]
            results.y = Array(results.y[kept])

    return results
示例#2
0
def jax_odeint(
    rhs: Callable,
    t_span: Array,
    y0: Array,
    t_eval: Optional[Union[Tuple, List, Array]] = None,
    **kwargs,
):
    """Solve an ODE through ``jax.experimental.ode.odeint``.

    The merged grid of ``t_span`` and ``t_eval`` is multiplied by the sign of
    the integration direction before being handed to ``odeint`` (presumably
    because ``odeint`` only integrates towards increasing time), and the
    right-hand side is wrapped so the physical time seen by ``rhs`` is
    unchanged.

    Args:
        rhs: Right-hand side, called as :math:`f(t, y)`.
        t_span: Interval to solve over.
        y0: Initial state.
        t_eval: Optional time points at which to return the solution.
        kwargs: Extra keyword arguments forwarded to ``odeint``.

    Returns:
        OdeResult: Results object, trimmed to ``t_eval`` via ``trim_t_results``.
    """
    time_grid = merge_t_args(t_span, t_eval)

    # +1 when integrating forward in time, -1 when integrating backward.
    sign = np.sign(Array(time_grid[-1] - time_grid[0], backend="jax")).data

    def flipped_rhs(y, t):
        # odeint uses the (y, t) argument order; undo the time flip for rhs.
        return sign * rhs(sign * t, y)

    states = odeint(flipped_rhs, y0=y0, t=sign * time_grid.data, **kwargs)

    solution = OdeResult(t=time_grid, y=Array(states, backend="jax"))
    return trim_t_results(solution, t_span, t_eval)
示例#3
0
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``dy/dt = fun(t, y)`` with the fixed-step two-step Adams-Bashforth method.

    The scheme is bootstrapped with a single forward-Euler step, after which
    each step uses ``y + dt/2 * (3*f(t_n, y_n) - f(t_{n-1}, y_{n-1}))``. The
    previous slope is cached, so ``fun`` is evaluated once per step.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start/end times.
        y0: Initial state.
        t_eval: Accepted for signature compatibility with other solvers; it is
            not used, and every step is returned.
        dt: Step size.

    Returns:
        OdeResult with ``t`` (all step times) and ``y`` (states stacked by
        ``np.vstack(...).T``, so columns index time for 1-D states). Steps stop
        once ``t >= t_span[-1]``, so the last time may overshoot by up to one
        step. If ``t_span[-1] <= t_span[0]`` only the initial point is returned.
    """
    t0, tf = float(t_span[0]), float(t_span[-1])

    t = t0
    y = y0

    ts = [t]
    ys = [y]

    # Nothing to integrate: do not take the bootstrap step past the end time.
    if t >= tf:
        return OdeResult(t=np.hstack(ts), y=np.vstack(ys).T)

    # AB2 needs two past slopes, so start with 1 Euler forward step.
    f_prev = fun(t, y)
    y = y + dt * f_prev
    t = t + dt

    ts.append(t)
    ys.append(y)

    while t < tf:
        f_cur = fun(t, y)
        y = y + dt / 2.0 * (3 * f_cur - f_prev)
        t = t + dt
        # Reuse this slope as f(t_{n-1}, y_{n-1}) on the next step.
        f_prev = f_cur

        ts.append(t)
        ys.append(y)

    return OdeResult(t=np.hstack(ts), y=np.vstack(ys).T)
示例#4
0
def fixed_step_solver_template(
    take_step: Callable,
    rhs_func: Callable,
    t_span: Array,
    y0: Array,
    max_dt: float,
    t_eval: Optional[Union[Tuple, List, Array]] = None,
):
    """Drive a single-step fixed-step method across ``t_eval`` with steps of at most ``max_dt``.

    ``take_step`` implements one step of size ``h`` and is called as
    ``take_step(rhs_func, t0, y0, h)``, where:
        - rhs_func: Either a generator :math:`G(t)` or RHS function :math:`f(t,y)`.
        - t0: The current time.
        - y0: The current state.
        - h: The size of the step to take.

    It returns:
        - y: The state of the DE at time t0 + h.

    Step sizes come from ``get_fixed_step_sizes(t_span, t_eval, max_dt)``,
    which splits each interval of ``t_eval`` into the fewest equal
    sub-intervals no longer than ``max_dt``. The state at the end of each
    interval is recorded.

    Args:
        take_step: Callable for fixed step integration.
        rhs_func: Callable, either a generator or rhs function.
        t_span: Interval to solve over.
        y0: Initial state.
        max_dt: Maximum step size.
        t_eval: Optional list of time points at which to return the solution.

    Returns:
        OdeResult: Results object, trimmed to ``t_eval`` via ``trim_t_results``.
    """

    def raw_rhs(*args):
        # take_step operates on plain arrays, not Array wrappers.
        return Array(rhs_func(*args)).data

    state = Array(y0).data

    t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt)

    states = [state]
    # zip stops at the shortest input, so only intervals with a step size run.
    for start_time, step_size, step_count in zip(t_list, h_list, n_steps_list):
        now = start_time
        for _ in range(step_count):
            state = take_step(raw_rhs, now, state, step_size)
            now = now + step_size
        states.append(state)

    results = OdeResult(t=t_list, y=Array(states))
    return trim_t_results(results, t_span, t_eval)
示例#5
0
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``dy/dt = fun(t, y)`` with fixed-step forward Euler.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start/end times.
        y0: Initial state.
        t_eval: Optional output times. Must begin at ``t_span[0]`` and end at
            ``t_span[-1]``; the loop below assumes they are in ascending order.
            Requested times falling between two Euler steps are filled in by
            linear interpolation between those steps.
        dt: Step size.

    Returns:
        OdeResult with ``t`` (output times) and ``y`` (states stacked by
        ``np.vstack(...).T``, so columns index time for 1-D states). Without
        ``t_eval`` every step is returned, and the last time may overshoot
        ``t_span[-1]`` by up to one step.

    Raises:
        ValueError: If ``t_eval`` does not start/end at the ``t_span`` endpoints.
    """
    t0, tf = float(t_span[0]), float(t_span[-1])

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently return output at the wrong times.
    if t_eval is not None and (t0 != t_eval[0] or tf != t_eval[-1]):
        raise ValueError("t_eval must start at t_span[0] and end at t_span[-1]")

    t, y = t0, y0
    ts, ys = [t], [y]

    # Index of the next requested output time; t_eval[0] == t0 is already stored.
    next_idx = 1
    # Previous step, the left end of the interpolation interval.
    t_prev, y_prev = t0, y0

    while t < tf:
        y = y + dt * fun(t, y)
        t = t + dt

        if t_eval is None:
            ts.append(t)
            ys.append(y)
            continue

        # Emit every requested time that this step reached or passed.
        while next_idx < len(t_eval) and t >= t_eval[next_idx]:
            target = t_eval[next_idx]
            if t == target:
                ts.append(t)
                ys.append(y)
            else:
                # target lies strictly inside (t_prev, t): interpolate linearly.
                ts.append(target)
                ys.append(y_prev + (target - t_prev) * (y - y_prev) / (t - t_prev))
            next_idx += 1

        t_prev, y_prev = t, y

    return OdeResult(t=np.hstack(ts), y=np.vstack(ys).T)
示例#6
0
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``dy/dt = fun(t, y)`` with the fixed-step explicit midpoint rule.

    Each step computes ``y + dt * fun(t + dt/2, y + dt/2 * fun(t, y))``.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start/end times.
        y0: Initial state.
        t_eval: Accepted for signature compatibility; not used.
        dt: Step size.

    Returns:
        OdeResult with ``t`` (all step times) and ``y`` (states stacked by
        ``np.vstack(...).T``, so columns index time for 1-D states). Steps stop
        once ``t >= t_span[-1]``, so the last time may overshoot by up to one step.
    """
    t_start, t_stop = float(t_span[0]), float(t_span[-1])

    def midpoint_step(t_now, y_now):
        half = dt / 2.0
        slope_mid = fun(t_now + half, y_now + half * fun(t_now, y_now))
        return y_now + dt * slope_mid

    t_now, y_now = t_start, y0
    times, states = [t_now], [y_now]

    while t_now < t_stop:
        y_now = midpoint_step(t_now, y_now)
        t_now = t_now + dt
        times.append(t_now)
        states.append(y_now)

    return OdeResult(t=np.hstack(times), y=np.vstack(states).T)
示例#7
0
def scipy_solve_ivp(
    rhs: Callable,
    t_span: Array,
    y0: Array,
    method: Union[str, OdeSolver],
    t_eval: Optional[Union[Tuple, List, Array]] = None,
    **kwargs,
):
    """Routine for calling `scipy.integrate.solve_ivp`.

    The state and ``rhs`` are converted to the 1d-array form ``solve_ivp``
    needs (via a ``StateTypeConverter`` built from ``y0``). For methods listed
    in ``REAL_METHODS`` they are further wrapped with ``real_rhs``/``c2r``, and
    the output is mapped back with ``r2c``. The returned states are converted
    back to the outer type of ``y0``.

    Args:
        rhs: Callable of the form :math:`f(t, y)`.
        t_span: Interval to solve over.
        y0: Initial state.
        method: Solver method.
        t_eval: Points at which to evaluate the solution.
        kwargs: Optional arguments to be passed to ``solve_ivp``.

    Returns:
        OdeResult: results object whose ``y`` has one converted state per time
        point (first axis indexes time).

    Raises:
        QiskitError: If ``dense_output`` is requested.
    """

    # Reject any truthy value, not only the ``True`` singleton: ``1`` or
    # ``numpy.True_`` would slip past an ``is True`` check. The interpolant in
    # ``results.sol`` is never converted back to the outer state type below,
    # so dense output cannot be supported as-is.
    if kwargs.get("dense_output", False):
        raise QiskitError("dense_output not supported for solve_ivp.")

    # solve_ivp requires 1d arrays internally
    internal_state_spec = {"type": "array", "ndim": 1}
    type_converter = StateTypeConverter.from_outer_instance_inner_type_spec(
        y0, internal_state_spec)

    # modify the rhs to work with 1d arrays or real solvers
    rhs = type_converter.rhs_outer_to_inner(rhs)

    # convert y0 to the flattened version
    y0 = type_converter.outer_to_inner(y0)

    # Check if solver is real only
    # TODO: Also check if model or y0 are complex
    #       if they are both real we don't need to embed.
    embed_real = method in REAL_METHODS
    if embed_real:
        rhs = real_rhs(rhs)
        y0 = c2r(y0)

    results = solve_ivp(rhs,
                        t_span=t_span.data,
                        y0=y0.data,
                        t_eval=t_eval,
                        method=method,
                        **kwargs)
    if embed_real:
        results.y = r2c(results.y)

    # convert to the standardized results format
    # solve_ivp returns the states as a 2d array with columns being the states
    results.y = results.y.transpose()
    results.y = Array([type_converter.inner_to_outer(y) for y in results.y])

    return OdeResult(**dict(results))
示例#8
0
def fixed_step_solver_template_jax(
    take_step: Callable,
    rhs_func: Callable,
    t_span: Array,
    y0: Array,
    max_dt: float,
    t_eval: Optional[Union[Tuple, List, Array]] = None,
):
    """JAX control-flow counterpart of :meth:`fixed_step_solver_template`.

    The contract matches :meth:`fixed_step_solver_template`; see its
    documentation for the ``take_step`` signature. Here the loops are
    expressed with ``scan``/``cond`` so they can be traced.

    Args:
        take_step: Callable for fixed step integration.
        rhs_func: Callable, either a generator or rhs function.
        t_span: Interval to solve over.
        y0: Initial state.
        max_dt: Maximum step size.
        t_eval: Optional list of time points at which to return the solution.

    Returns:
        OdeResult: Results object, trimmed to ``t_eval`` via ``trim_t_results``.
    """

    def raw_rhs(*args):
        # take_step operates on plain jax arrays, not Array wrappers.
        return Array(rhs_func(*args), backend="jax").data

    y_init = Array(y0, backend="jax").data

    t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt)

    # A traced loop needs a fixed trip count: every interval runs for the
    # largest step count, and surplus iterations leave the state unchanged.
    max_steps = n_steps_list.max()

    def integrate_interval(y_start, interval):
        t_start, step_size, steps_needed = interval

        def guarded_step(state, step_idx):
            t_now, y_now = state
            y_now = cond(
                step_idx < steps_needed,
                lambda z: take_step(raw_rhs, t_now, z, step_size),
                lambda z: z,
                y_now,
            )
            return (t_now + step_size, y_now), None

        (_, y_end), _ = scan(guarded_step, (t_start, y_start), jnp.arange(max_steps))
        # Carry the end state into the next interval and also emit it.
        return y_end, y_end

    interval_data = (
        jnp.array(t_list[:-1]),
        jnp.array(h_list),
        jnp.array(n_steps_list),
    )
    _, interval_end_states = scan(integrate_interval, init=y_init, xs=interval_data)

    all_states = jnp.concatenate(
        (jnp.expand_dims(y_init, axis=0), interval_end_states), axis=0
    )

    results = OdeResult(t=t_list, y=Array(all_states, backend="jax"))
    return trim_t_results(results, t_span, t_eval)
示例#9
0
def solve_ivp_switch(sys, t_span, y0, **kwargs):
    """Integrate a switching system with ``solve_ivp``, restarting at each selected event.

    ``sys`` is used as the right-hand side (``sys(t, y)``) and must provide
    ``event_functions`` (a list of event callables), ``handle_event(idx, t, y)``
    and ``with_variational``; ``sys.n_dim`` is also read when
    ``with_variational`` is true. Each pass integrates from the current time to
    ``t_end * 1.001`` with all system events, the user's ``events`` and an extra
    non-terminal event at ``t_end``. The pass is then cut at the selected event
    time and restarted from there. For system events, ``handle_event`` is called,
    and with ``with_variational`` the flattened ``N x N`` block ``y[N:]`` is
    left-multiplied by the matrix it returns.

    Args:
        sys: The switching system (see above).
        t_span: ``(t_start, t_end)``.
        y0: Initial state; ``y0.shape[0]`` sets the number of state rows.
        **kwargs: Forwarded to ``solve_ivp``. ``events`` (a single callable or a
            sequence of callables; ``None`` means no user events) and
            ``dense_output`` are consumed here and not forwarded directly.

    Returns:
        OdeResult with ``t``, ``y`` (one column per time), ``sol`` (a stitched
        ``OdeSolution`` when ``dense_output`` was requested and anything was
        integrated, else ``None``), user ``t_events``/``y_events`` (``None``
        without user events), system ``t_sys_events``/``y_sys_events``, summed
        ``nfev``/``njev``/``nlu``, and ``status``/``message``/``success`` of the
        last solve. If no solve was needed (empty interval), a successful
        status-0 result is reported.
    """

    kwargs_copy = kwargs.copy()

    t_cur = t_span[0]
    t_end = t_span[1]
    y_cur = y0

    t = np.array([])
    y = np.array([[] for _ in range(y0.shape[0])])

    n_system_events = len(sys.event_functions)
    t_sys_events = [[] for _ in range(n_system_events)]
    y_sys_events = [[] for _ in range(n_system_events)]

    # the event functions of the original system
    event_functions = sys.event_functions.copy()

    # User event functions passed as ``events``. Pop with a default instead of
    # a bare ``except`` so unrelated errors are not swallowed, and so an
    # explicit ``events=None`` means "no events" rather than a None event.
    user_events = kwargs_copy.pop('events', None)
    user_event_idx = []
    if user_events is None:
        n_user_events = 0
        t_events = None
        y_events = None
    else:
        if callable(user_events):
            user_events = [user_events]
        for event in user_events:
            user_event_idx.append(len(event_functions))
            event_functions.append(event)
        n_user_events = len(user_event_idx)
        t_events = [[] for _ in range(n_user_events)]
        y_events = [[] for _ in range(n_user_events)]

    # one last event to stop at the exact time instant
    event_functions.append(lambda t, y: t - t_end)
    event_functions[-1].terminal = 0
    event_functions[-1].direction = 1

    dense_output = kwargs_copy.pop('dense_output', False)
    ts = np.array([])
    interpolants = []

    terminate = False
    nfev = 0
    njev = 0
    nlu = 0
    sol = None  # stays None when the interval is already (numerically) empty
    while np.abs(t_cur - t_end) > 1e-10 and not terminate:
        sol = solve_ivp(sys, [t_cur, t_end * 1.001],
                        y_cur,
                        events=event_functions,
                        dense_output=True,
                        **kwargs_copy)
        nfev += sol['nfev']
        njev += sol['njev']
        nlu += sol['nlu']
        if not sol['success']:
            break
        t_next = np.inf
        ev_idx = None
        # NOTE(review): this keeps the *last* event index whose latest
        # occurrence differs from the current candidate, not the earliest
        # event in time; confirm this ordering is intended.
        for i, t_ev in enumerate(sol['t_events']):
            if len(t_ev) > 0 and t_ev[-1] != t_cur and np.abs(t_ev[-1] -
                                                              t_next) > 1e-10:
                t_next = t_ev[-1]
                ev_idx = i
        if ev_idx is None:
            t_next = sol['t'][-1]
            y_next = sol['y'][:, -1]
        elif ev_idx in user_event_idx:
            y_next = sol['sol'](t_next)
            t_events[ev_idx - n_system_events].append(t_next)
            y_events[ev_idx - n_system_events].append(sol['sol'](t_next))
            if event_functions[ev_idx].terminal:
                terminate = True
        else:
            y_next = sol['sol'](t_next)
            if ev_idx < n_system_events:
                t_sys_events[ev_idx].append(t_next)
                # NOTE(review): the same array object is stored here and then
                # modified in place below when with_variational is set, so the
                # recorded event state reflects the update; confirm intended.
                y_sys_events[ev_idx].append(y_next)
                S = sys.handle_event(ev_idx, t_next, y_next)
                if sys.with_variational:
                    N = sys.n_dim
                    phi = S @ np.reshape(y_next[N:], (N, N))
                    y_next[N:] = phi.flatten()
        # keep the samples strictly before the cut point, then the cut itself
        idx, = np.where(sol['t'] < t_next)
        t = np.append(t, sol['t'][idx])
        t = np.append(t, t_next)
        y = np.append(y, sol['y'][:, idx], axis=1)
        y = np.append(y, np.array([y_next]).transpose(), axis=1)
        if dense_output:
            # avoid duplicating a breakpoint shared with the previous pass
            if len(ts) > 0 and ts[-1] == sol['sol'].ts[0]:
                ts = np.concatenate((ts, sol['sol'].ts[1:]))
            else:
                ts = np.concatenate((ts, sol['sol'].ts))
            interpolants += sol['sol'].interpolants
        t_cur = t_next
        y_cur = y_next

    if sol is None:
        status = 0
        message = "The solver successfully reached the end of the integration interval."
        success = True
    else:
        status = sol['status']
        message = sol['message']
        success = sol['success']

    # OdeSolution needs at least one segment; nothing was integrated otherwise.
    if dense_output and interpolants:
        ode_sol = OdeSolution(ts, interpolants)
    else:
        ode_sol = None

    return OdeResult(t=t, y=y, sol=ode_sol, t_events=t_events, y_events=y_events,
                     t_sys_events=t_sys_events, y_sys_events=y_sys_events,
                     nfev=nfev, njev=njev, nlu=nlu, status=status,
                     message=message, success=success)