def trim_t_results(
    results: OdeResult,
    t_span: Union[List, Tuple, Array],
    t_eval: Optional[Union[List, Tuple, Array]] = None,
) -> OdeResult:
    """Remove ``t_span`` endpoints from ``results`` unless ``t_eval`` asks for them.

    The solver is expected to have been run on the time points produced by
    ``validate_and_merge_t_span_t_eval(t_span, t_eval)``. Those points always
    contain both ends of ``t_span``. Any end that is not also the first/last
    entry of ``t_eval`` is dropped here.

    Args:
        results: Solver output on the merged time grid. It is modified in place.
        t_span: Integration interval.
        t_eval: Requested output times, or ``None``.

    Returns:
        OdeResult: ``results`` restricted to the times in ``t_eval``. When
        ``t_eval`` is ``None`` the solver's default output is returned unchanged.
    """
    if t_eval is None:
        # No explicit request: keep whatever the solver produced.
        return results

    span = Array(t_span, backend="numpy")

    # Leading endpoint was only added for the solver, not requested by the caller.
    if t_eval[0] != span[0]:
        results.t, results.y = results.t[1:], Array(results.y[1:])

    # Same for the trailing endpoint.
    if t_eval[-1] != span[1]:
        results.t, results.y = results.t[:-1], Array(results.y[:-1])

    return results
def jax_odeint(
    rhs: Callable,
    t_span: Array,
    y0: Array,
    t_eval: Optional[Union[Tuple, List, Array]] = None,
    **kwargs,
):
    """Solve an ODE with ``jax.experimental.ode.odeint``.

    ``odeint`` takes its right-hand side as ``f(y, t)`` and integrates over
    increasing times. The problem is therefore re-expressed in a flipped time
    variable ``s = sign * t``, where ``sign`` is the sign of the integration
    interval. This turns a backwards-in-time problem into a forward one.

    Args:
        rhs: Right-hand side with signature :math:`f(t, y)`.
        t_span: Interval to solve over.
        y0: Initial state.
        t_eval: Optional time points at which to return the solution.
        kwargs: Extra keyword arguments forwarded to ``odeint``.

    Returns:
        OdeResult: Solution on the merged time points, trimmed to ``t_eval``
        when it is supplied.
    """
    times = merge_t_args(t_span, t_eval)

    # Sign of the integration interval (+1 forward, -1 backward).
    sign = np.sign(Array(times[-1] - times[0], backend="jax")).data

    def odeint_rhs(y, t):
        # Map the flipped time back to real time for the user rhs, and flip the
        # resulting derivative (dy/ds = sign * dy/dt).
        return sign * rhs(sign * t, y)

    states = odeint(odeint_rhs, y0=y0, t=sign * times.data, **kwargs)

    return trim_t_results(
        OdeResult(t=times, y=Array(states, backend="jax")), t_span, t_eval
    )
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``y' = fun(t, y)`` with the fixed-step two-step Adams-Bashforth method.

    The first step is a forward-Euler step, because AB2 needs one previous
    derivative. Every later step uses
    ``y_{n+1} = y_n + dt/2 * (3 f(t_n, y_n) - f(t_{n-1}, y_{n-1}))``.
    Each derivative is evaluated once and reused, so a run of ``n`` steps
    calls ``fun`` exactly ``n`` times.

    Steps are taken until the time reaches ``t_span[-1]``. If the interval is
    not a multiple of ``dt``, the last step ends past ``t_span[-1]``.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start and end times.
        y0: Initial state.
        t_eval: Optional times at which to report the solution. They must lie
            within ``t_span``. Values are obtained by linear interpolation
            between the fixed steps. If ``None``, every step is reported.
        dt: Fixed step size. Must be positive.

    Returns:
        OdeResult with ``t`` (1d array of times) and ``y`` (2d array whose
        columns are the states at those times).

    Raises:
        ValueError: If ``dt`` is not positive or ``t_eval`` lies outside ``t_span``.
    """
    t0, tf = float(t_span[0]), float(t_span[-1])
    if not dt > 0:
        raise ValueError("dt must be positive.")

    # Number of steps needed to reach tf. The small tolerance keeps round-off in
    # (tf - t0) / dt from adding a spurious extra step when the interval is
    # (numerically) an exact multiple of dt. An empty or reversed span gives 0.
    ratio = (tf - t0) / dt
    n_steps = max(int(np.ceil(ratio - 1e-10 * max(1.0, ratio))), 0)

    ts = [t0]
    ys = [y0]
    if n_steps > 0:
        # Start with one forward-Euler step and keep its derivative for AB2.
        f_prev = fun(t0, y0)
        y = y0 + dt * f_prev
        ts.append(t0 + dt)
        ys.append(y)
        for k in range(2, n_steps + 1):
            f_cur = fun(ts[-1], y)
            y = y + dt / 2.0 * (3 * f_cur - f_prev)
            f_prev = f_cur
            # Times are computed from t0 directly so round-off does not accumulate.
            ts.append(t0 + k * dt)
            ys.append(y)

    ts = np.hstack(ts)
    ys = np.vstack(ys).T

    if t_eval is not None:
        t_eval = np.asarray(t_eval, dtype=float)
        if t_eval.size and (t_eval.min() < t0 or t_eval.max() > tf):
            raise ValueError("t_eval must lie within t_span.")
        # Interpolate each state component linearly onto the requested times.
        ys = np.array([np.interp(t_eval, ts, row) for row in ys])
        ts = t_eval

    return OdeResult(t=ts, y=ys)
def fixed_step_solver_template(
    take_step: Callable,
    rhs_func: Callable,
    t_span: Array,
    y0: Array,
    max_dt: float,
    t_eval: Optional[Union[Tuple, List, Array]] = None,
):
    """Drive a single-step fixed-step method across ``t_span`` / ``t_eval``.

    ``take_step(rhs_func, t0, y0, h)`` must advance the state ``y0`` from time
    ``t0`` by one step of size ``h`` and return the new state. ``rhs_func`` is
    either a generator :math:`G(t)` or a RHS function :math:`f(t, y)`; this
    helper does not care which.

    The output time grid comes from ``get_fixed_step_sizes``. Each interval of
    that grid is covered by the number of equal steps it prescribes, each of
    size at most ``max_dt``.

    Args:
        take_step: Callable performing one fixed step.
        rhs_func: Generator or rhs function passed through to ``take_step``.
        t_span: Interval to solve over.
        y0: Initial state.
        max_dt: Maximum step size.
        t_eval: Optional list of time points at which to return the solution.

    Returns:
        OdeResult: States at the grid points, trimmed to ``t_eval`` if given.
    """

    def raw_rhs(*args):
        # take_step works on raw arrays, so unwrap whatever rhs_func returns.
        return Array(rhs_func(*args)).data

    state = Array(y0).data
    t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt)

    def advance(y, t_start, h, n_steps):
        """Take ``n_steps`` steps of size ``h`` starting from ``(t_start, y)``."""
        t = t_start
        for _ in range(n_steps):
            y = take_step(raw_rhs, t, y, h)
            t = t + h
        return y

    states = [state]
    # zip stops at the shorter h_list/n_steps_list, so only interval start times are used.
    for t_start, h, n_steps in zip(t_list, h_list, n_steps_list):
        state = advance(state, t_start, h, n_steps)
        states.append(state)

    results = OdeResult(t=t_list, y=Array(states))
    return trim_t_results(results, t_span, t_eval)
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``y' = fun(t, y)`` with the fixed-step forward Euler method.

    Steps of size ``dt`` are taken from ``t_span[0]`` while the time is below
    ``t_span[-1]``. The time is accumulated as ``t + dt``, so round-off can add
    one step beyond ``t_span[-1]`` when ``t_eval`` is ``None``.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start and end times.
        y0: Initial state.
        t_eval: Optional increasing times at which to report the solution. Its
            first and last entries must equal ``t_span[0]`` and ``t_span[-1]``.
            Requested times that fall between two Euler steps are filled in by
            linear interpolation. If ``None``, every step is reported.
        dt: Fixed step size. Must be positive.

    Returns:
        OdeResult with ``t`` (1d array of times) and ``y`` (2d array whose
        columns are the states at those times).

    Raises:
        ValueError: If ``dt`` is not positive, or ``t_eval`` is empty or does not
            start/end at the ends of ``t_span``.
    """
    t0, tf = float(t_span[0]), float(t_span[-1])
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O`` and a mismatched t_eval would then silently drop points.
    if not dt > 0:
        raise ValueError("dt must be positive.")
    if t_eval is not None:
        if len(t_eval) == 0 or t_eval[0] != t0 or t_eval[-1] != tf:
            raise ValueError("t_eval must start at t_span[0] and end at t_span[-1].")
        # Interpolation state, only needed when t_eval is given:
        # i - index of the next requested time still to be emitted,
        # (tp, yp) - the previous Euler step.
        i = 1
        tp = t0
        yp = y0

    t = t0
    y = y0
    ts = [t]
    ys = [y]
    while t < tf:
        y = y + dt * fun(t, y)
        t = t + dt
        if t_eval is not None:
            # Emit every requested time that this step reached or passed.
            while i < len(t_eval) and t >= t_eval[i]:
                if t == t_eval[i]:
                    ts.append(t)
                    ys.append(y)
                else:
                    # t > t_eval[i]: linear interpolation between (tp, yp) and (t, y).
                    ts.append(t_eval[i])
                    ys.append(yp + (t_eval[i] - tp) * (y - yp) / (t - tp))
                i += 1
            tp = t
            yp = y
        else:
            ts.append(t)
            ys.append(y)

    ts = np.hstack(ts)
    ys = np.vstack(ys).T
    return OdeResult(t=ts, y=ys)
def solve_ivp(fun, t_span, y0, t_eval=None, dt=0.01):
    """Integrate ``y' = fun(t, y)`` with the fixed-step explicit midpoint method.

    Each step evaluates the slope at the midpoint reached by a half Euler step:
    ``y_{n+1} = y_n + dt * f(t_n + dt/2, y_n + dt/2 * f(t_n, y_n))``.

    Steps are taken until the time reaches ``t_span[-1]``. If the interval is
    not a multiple of ``dt``, the last step ends past ``t_span[-1]``.

    Args:
        fun: Right-hand side, called as ``fun(t, y)``.
        t_span: Sequence whose first and last entries are the start and end times.
        y0: Initial state.
        t_eval: Optional times at which to report the solution. They must lie
            within ``t_span``. Values are obtained by linear interpolation
            between the fixed steps. If ``None``, every step is reported.
        dt: Fixed step size. Must be positive.

    Returns:
        OdeResult with ``t`` (1d array of times) and ``y`` (2d array whose
        columns are the states at those times).

    Raises:
        ValueError: If ``dt`` is not positive or ``t_eval`` lies outside ``t_span``.
    """
    t0, tf = float(t_span[0]), float(t_span[-1])
    if not dt > 0:
        raise ValueError("dt must be positive.")

    # Number of steps needed to reach tf. The small tolerance keeps round-off in
    # (tf - t0) / dt from adding a spurious extra step when the interval is
    # (numerically) an exact multiple of dt. An empty or reversed span gives 0.
    ratio = (tf - t0) / dt
    n_steps = max(int(np.ceil(ratio - 1e-10 * max(1.0, ratio))), 0)

    ts = [t0]
    ys = [y0]
    y = y0
    for k in range(n_steps):
        # Times are computed from t0 directly so round-off does not accumulate.
        t = t0 + k * dt
        y = y + dt * fun(t + dt / 2.0, y + dt / 2.0 * fun(t, y))
        ts.append(t0 + (k + 1) * dt)
        ys.append(y)

    ts = np.hstack(ts)
    ys = np.vstack(ys).T

    if t_eval is not None:
        t_eval = np.asarray(t_eval, dtype=float)
        if t_eval.size and (t_eval.min() < t0 or t_eval.max() > tf):
            raise ValueError("t_eval must lie within t_span.")
        # Interpolate each state component linearly onto the requested times.
        ys = np.array([np.interp(t_eval, ts, row) for row in ys])
        ts = t_eval

    return OdeResult(t=ts, y=ys)
def scipy_solve_ivp(
    rhs: Callable,
    t_span: Array,
    y0: Array,
    method: Union[str, OdeSolver],
    t_eval: Optional[Union[Tuple, List, Array]] = None,
    **kwargs,
):
    """Routine for calling `scipy.integrate.solve_ivp`.

    The state is flattened to a 1d array for the solver. If ``method`` is in
    ``REAL_METHODS``, the state is also embedded into real numbers. The output
    states are converted back to the caller's state type.

    Args:
        rhs: Callable of the form :math:`f(t, y)`.
        t_span: Interval to solve over.
        y0: Initial state.
        method: Solver method.
        t_eval: Points at which to evaluate the solution.
        kwargs: Optional arguments to be passed to ``solve_ivp``.

    Returns:
        OdeResult: results object

    Raises:
        QiskitError: If unsupported kwarg present.
    """
    # Any truthy value (e.g. ``1`` or ``numpy.True_``) enables dense output in
    # scipy, so test truthiness rather than identity with ``True``. A dense
    # interpolant would describe the internal flattened (and possibly
    # real-embedded) state, not the caller's state type.
    if kwargs.get("dense_output", False):
        raise QiskitError("dense_output not supported for solve_ivp.")

    # solve_ivp requires 1d arrays internally
    internal_state_spec = {"type": "array", "ndim": 1}
    type_converter = StateTypeConverter.from_outer_instance_inner_type_spec(
        y0, internal_state_spec
    )

    # modify the rhs to work with 1d arrays
    rhs = type_converter.rhs_outer_to_inner(rhs)

    # convert y0 to the flattened version
    y0 = type_converter.outer_to_inner(y0)

    # Check if solver is real only
    # TODO: Also check if model or y0 are complex
    # if they are both real we don't need to embed.
    embed_real = method in REAL_METHODS
    if embed_real:
        rhs = real_rhs(rhs)
        y0 = c2r(y0)

    results = solve_ivp(
        rhs, t_span=t_span.data, y0=y0.data, t_eval=t_eval, method=method, **kwargs
    )

    if embed_real:
        results.y = r2c(results.y)

    # convert to the standardized results format
    # solve_ivp returns the states as a 2d array with columns being the states
    results.y = results.y.transpose()
    results.y = Array([type_converter.inner_to_outer(y) for y in results.y])

    return OdeResult(**dict(results))
def fixed_step_solver_template_jax( take_step: Callable, rhs_func: Callable, t_span: Array, y0: Array, max_dt: float, t_eval: Optional[Union[Tuple, List, Array]] = None, ): """This function is the jax control-flow version of :meth:`fixed_step_solver_template`. See the documentation of :meth:`fixed_step_solver_template` for details. Args: take_step: Callable for fixed step integration. rhs_func: Callable, either a generator or rhs function. t_span: Interval to solve over. y0: Initial state. max_dt: Maximum step size. t_eval: Optional list of time points at which to return the solution. Returns: OdeResult: Results object. """ # ensure the output of rhs_func is a raw array def wrapped_rhs_func(*args): return Array(rhs_func(*args), backend="jax").data y0 = Array(y0, backend="jax").data t_list, h_list, n_steps_list = get_fixed_step_sizes(t_span, t_eval, max_dt) # if jax, need bound on number of iterations in each interval max_steps = n_steps_list.max() def identity(y): return y # interval integrator set up for jax.lax.scan def scan_interval_integrate(carry, x): current_t, h, n_steps = x current_y = carry def scan_take_step(carry, step): t, y = carry y = cond(step < n_steps, lambda y: take_step(wrapped_rhs_func, t, y, h), identity, y) t = t + h return (t, y), None next_y = scan(scan_take_step, (current_t, current_y), jnp.arange(max_steps))[0][1] return next_y, next_y ys = scan( scan_interval_integrate, init=y0, xs=(jnp.array(t_list[:-1]), jnp.array(h_list), jnp.array(n_steps_list)), )[1] ys = Array(jnp.append(jnp.expand_dims(y0, axis=0), ys, axis=0), backend="jax") results = OdeResult(t=t_list, y=ys) return trim_t_results(results, t_span, t_eval)
def solve_ivp_switch(sys, t_span, y0, **kwargs):
    """Integrate a switching system, restarting ``solve_ivp`` after each event.

    ``sys`` is called as the right-hand side ``sys(t, y)``. It must also provide:

    - ``event_functions``: a list of event callables.
    - ``handle_event(idx, t, y)``: called when system event ``idx`` is selected.
      It returns a matrix ``S`` that is used only when ``sys.with_variational``
      is set.
    - ``with_variational`` and, when it is set, ``n_dim``.

    Each call to ``solve_ivp`` runs from the current time to just past
    ``t_end``. An extra event marks the crossing of ``t_end``. After each call,
    one event is selected. The trajectory is cut at that event's time, the
    event is recorded or handled, and integration restarts from there.

    Args:
        sys: Switching system (see above).
        t_span: ``(t_start, t_end)``. Integration is assumed to run forward in time.
        y0: Initial state as a numpy array; its first dimension is the state size.
        **kwargs: Forwarded to ``solve_ivp``. ``events`` (a callable or a
            sequence of callables) are treated as user events and reported in
            ``t_events``/``y_events``. ``dense_output`` requests a global
            ``OdeSolution`` in ``sol``.

    Returns:
        OdeResult containing:

        - ``t``/``y``: the concatenated trajectory.
        - ``sol``: the global ``OdeSolution``, or ``None``.
        - ``t_events``/``y_events``: user events (``None`` without user events).
        - ``t_sys_events``/``y_sys_events``: system events, per event.
        - ``nfev``/``njev``/``nlu``: accumulated counts.
        - ``status``/``message``/``success``: from the last ``solve_ivp`` call.
          When no call is made (``t_span[0] == t_span[1]``), these report success.
    """
    kwargs_copy = kwargs.copy()
    t_cur = t_span[0]
    t_end = t_span[1]
    y_cur = y0

    t = np.array([])
    y = np.array([[] for _ in range(y0.shape[0])])

    n_system_events = len(sys.event_functions)
    t_sys_events = [[] for _ in range(n_system_events)]
    y_sys_events = [[] for _ in range(n_system_events)]

    # the event functions of the original system (copied so that appending the
    # user and end-time events below leaves sys.event_functions untouched)
    event_functions = sys.event_functions.copy()

    # User event functions passed as arguments. Their positions in
    # event_functions are recorded in user_event_idx. An explicit pop with a
    # default replaces the former bare ``except:``, which also hid unrelated
    # errors; events=None now means "no events", as in scipy.
    user_event_idx = []
    user_events = kwargs_copy.pop('events', None)
    if user_events is None:
        n_user_events = 0
        t_events = None
        y_events = None
    else:
        if callable(user_events):
            user_events = [user_events]
        else:
            user_events = list(user_events)
        n_user_events = len(user_events)
        for event in user_events:
            user_event_idx.append(len(event_functions))
            event_functions.append(event)
        t_events = [[] for _ in range(n_user_events)]
        y_events = [[] for _ in range(n_user_events)]

    # one last (non-terminal) event marking the crossing of t_end, so each
    # segment can be cut exactly at t_end
    event_functions.append(lambda t, y: t - t_end)
    event_functions[-1].terminal = 0
    event_functions[-1].direction = 1

    # Bound passed to solve_ivp: slightly past t_end so the end-time event is
    # crossed. Multiplying by 1.001 only moves past t_end when t_end >= 0. For a
    # negative t_end it would land *before* t_end and the loop would never
    # finish, so scale towards zero instead.
    t_stop = t_end * 1.001 if t_end >= 0 else t_end * 0.999

    dense_output = kwargs_copy.pop('dense_output', False)
    ts = np.array([])
    interpolants = []

    terminate = False
    nfev = 0
    njev = 0
    nlu = 0
    # reported as-is if no solver call is made (t_span[0] == t_span[1])
    status = 0
    message = 'The solver successfully reached the end of the integration interval.'
    success = True
    while np.abs(t_cur - t_end) > 1e-10 and not terminate:
        sol = solve_ivp(sys, [t_cur, t_stop], y_cur, events=event_functions,
                        dense_output=True, **kwargs_copy)
        nfev += sol['nfev']
        njev += sol['njev']
        nlu += sol['nlu']
        status = sol['status']
        message = sol['message']
        success = sol['success']
        if not success:
            break

        # Pick the event that ends this segment. Events found exactly at t_cur
        # (e.g. re-detected at the restart point) are ignored.
        # NOTE(review): this keeps the highest-index event whose last occurrence
        # differs from t_cur (a later event within 1e-10 of the current choice
        # does not replace it). It is not necessarily the earliest event, so a
        # non-terminal user event followed by the t_end crossing in the same
        # segment is not recorded - confirm this is intended.
        t_next = np.inf
        ev_idx = None
        for i, t_ev in enumerate(sol['t_events']):
            if len(t_ev) > 0 and t_ev[-1] != t_cur and np.abs(t_ev[-1] - t_next) > 1e-10:
                t_next = t_ev[-1]
                ev_idx = i

        if ev_idx is None:
            # no event: the segment ends where the solver stopped
            t_next = sol['t'][-1]
            y_next = sol['y'][:, -1]
        elif ev_idx in user_event_idx:
            y_next = sol['sol'](t_next)
            t_events[ev_idx - n_system_events].append(t_next)
            y_events[ev_idx - n_system_events].append(sol['sol'](t_next))
            # scipy treats an event without ``terminal`` as non-terminal; do the
            # same instead of raising AttributeError.
            if getattr(event_functions[ev_idx], 'terminal', False):
                terminate = True
        else:
            y_next = sol['sol'](t_next)
            if ev_idx < n_system_events:
                t_sys_events[ev_idx].append(t_next)
                # NOTE(review): y_next is stored by reference and its variational
                # part is overwritten below, so the recorded state reflects that
                # update - confirm this is intended.
                y_sys_events[ev_idx].append(y_next)
                S = sys.handle_event(ev_idx, t_next, y_next)
                if sys.with_variational:
                    # The entries after the first N form an N x N matrix, which
                    # is left-multiplied by S.
                    N = sys.n_dim
                    phi = S @ np.reshape(y_next[N:], (N, N))
                    y_next[N:] = phi.flatten()

        # keep solver points strictly before the cut, then the cut point itself
        idx, = np.where(sol['t'] < t_next)
        t = np.append(t, sol['t'][idx])
        t = np.append(t, t_next)
        y = np.append(y, sol['y'][:, idx], axis=1)
        y = np.append(y, np.array([y_next]).transpose(), axis=1)

        if dense_output:
            # avoid duplicating the shared breakpoint between segments
            if len(ts) > 0 and ts[-1] == sol['sol'].ts[0]:
                ts = np.concatenate((ts, sol['sol'].ts[1:]))
            else:
                ts = np.concatenate((ts, sol['sol'].ts))
            interpolants += sol['sol'].interpolants

        t_cur = t_next
        y_cur = y_next

    # OdeSolution needs at least one segment, which is absent if no solver call was made.
    ode_sol = OdeSolution(ts, interpolants) if dense_output and interpolants else None

    return OdeResult(t=t, y=y, sol=ode_sol, t_events=t_events, y_events=y_events,
                     t_sys_events=t_sys_events, y_sys_events=y_sys_events,
                     nfev=nfev, njev=njev, nlu=nlu, status=status,
                     message=message, success=success)