def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
    """Integrate an ODE with an adaptive-stepsize Dormand-Prince Runge-Kutta solver.

    Args:
      func: callable giving the time derivative of the state, invoked as
        `func(y, t, *args)`; its result must have the same shape/structure
        as `y0`.
      y0: array or pytree of arrays holding the initial state.
      t: array of float evaluation times, e.g. `jnp.linspace(0., 10., 101)`;
        the values must be strictly increasing.
      *args: extra arguments forwarded to `func`. They must be arrays,
        scalars, or (nested) standard Python containers (tuples, lists,
        dicts, namedtuples, i.e. pytrees) of those types.
      rtol: float, relative local error tolerance for the solver (optional).
      atol: float, absolute local error tolerance for the solver (optional).
      mxstep: int, maximum number of steps to take for each timepoint
        (optional).

    Returns:
      The solution `y` evaluated at every time in `t`, as an array (or pytree
      of arrays) shaped/structured like `y0` but with a new leading axis of
      length `len(t)`.

    Raises:
      TypeError: if a leaf of `args` is neither a tracer nor a valid JAX
        type, or if `t` does not have a floating-point dtype.
    """
    # Validate every leaf of the extra arguments before doing any work.
    for arg in tree_leaves(args):
        if isinstance(arg, core.Tracer) or core.valid_jaxtype(arg):
            continue
        raise TypeError(
            f"The contents of odeint *args must be arrays or scalars, but got {arg}.")

    t_is_floating = jnp.issubdtype(t.dtype, jnp.floating)
    if not t_is_floating:
        raise TypeError(f"t must be an array of floats, but got {t}.")

    # closure_convert returns a converted callable plus extra constants; the
    # constants are appended after the user-supplied *args when calling the
    # wrapper.
    closed_func, closure_consts = custom_derivatives.closure_convert(
        func, y0, t[0], *args)
    return _odeint_wrapper(
        closed_func, rtol, atol, mxstep, y0, t, *args, *closure_consts)
def _check_arg(arg):
    """Reject an odeint extra argument that is not a tracer or valid JAX type.

    Args:
      arg: a single value to validate.

    Raises:
      TypeError: if `arg` is not a `core.Tracer` instance and
        `core.valid_jaxtype(arg)` is false.
    """
    # Acceptable values fall straight through; nothing is returned.
    if isinstance(arg, core.Tracer) or core.valid_jaxtype(arg):
        return

    template = (
        "The contents of odeint *args must be arrays or scalars, but got "
        "\n{}."
    )
    raise TypeError(template.format(arg))
def bind(self, *args, **kwargs):
    """Like Primitive.bind, but finds the top trace even when no arguments are provided.

    Args:
      *args: operands of the primitive; each is expected to be a `Tracer` or
        a valid JAX type (asserted unless `jax.core.skip_checks` is set).
      **kwargs: parameters handed to `process_primitive` as a single dict.

    Returns:
      The lowered result of `process_primitive`: `full_lower` mapped over the
      outputs when `self.multiple_results` is true, otherwise `full_lower`
      applied to the single output.
    """
    assert jax.core.skip_checks or all(
        isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args

    # The trace is taken from `_top_trace()` rather than derived from the
    # arguments, so it is available even when `args` is empty.
    top = _top_trace()
    args_main = find_top_trace(args).main
    dynamic_main = thread_local_state.trace_state.trace_stack.dynamic
    # The main trace implied by the arguments must be either the dynamic one
    # or the main of the trace we are about to use.
    assert (jax.core.skip_checks or args_main is dynamic_main
            or args_main is top.main), args

    raised = map(top.full_raise, args)
    result = top.process_primitive(self, raised, kwargs)
    if self.multiple_results:
        return map(full_lower, result)
    return full_lower(result)
def _stop_gradient_impl(x):
    """Return `x` unchanged after checking that it is a valid JAX type.

    Args:
      x: the value passed to stop_gradient.

    Returns:
      `x` itself, when `valid_jaxtype(x)` is true.

    Raises:
      TypeError: if `valid_jaxtype(x)` is false.
    """
    if valid_jaxtype(x):
        return x
    raise TypeError("stop_gradient only works on valid JAX arrays, but "
                    f"input argument is: {x}")