Beispiel #1
0
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
  """Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.

  Args:
    func: function to evaluate the time derivative of the solution `y` at time
      `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
    y0: array or pytree of arrays representing the initial value for the state.
    t: array of float times for evaluation, like `jnp.linspace(0., 10., 101)`,
      in which the values must be strictly increasing.
    *args: tuple of additional arguments for `func`, which must be arrays,
      scalars, or (nested) standard Python containers (tuples, lists, dicts,
      namedtuples, i.e. pytrees) of those types.
    rtol: float, relative local error tolerance for solver (optional).
    atol: float, absolute local error tolerance for solver (optional).
    mxstep: int, maximum number of steps to take for each timepoint (optional);
      defaults to `jnp.inf`.

  Returns:
    Values of the solution `y` (i.e. integrated system values) at each time
    point in `t`, represented as an array (or pytree of arrays) with the same
    shape/structure as `y0` except with a new leading axis of length `len(t)`.

  Raises:
    TypeError: if a leaf of `args` is neither a tracer nor a valid JAX type, or
      if `t` is not an array (i.e. has no `dtype`) with a floating-point dtype.
  """
  for arg in tree_leaves(args):
    # Tracers (values flowing through a JAX transformation) are accepted as-is;
    # every other leaf must be something JAX can turn into an array.
    if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
      raise TypeError(
        f"The contents of odeint *args must be arrays or scalars, but got {arg}.")

  # Read the dtype defensively: a plain Python list or scalar has no `.dtype`
  # and should get the documented TypeError rather than an AttributeError.
  t_dtype = getattr(t, "dtype", None)
  if t_dtype is None or not jnp.issubdtype(t_dtype, jnp.floating):
    raise TypeError(f"t must be an array of floats, but got {t}.")
  # NOTE: the "strictly increasing" requirement on `t` is not checked in this
  # function body.

  # closure_convert returns a version of `func` with closed-over values made
  # explicit (`consts`), using (y0, t[0], *args) as example inputs; those
  # constants are forwarded after `args`.
  converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
  return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
Beispiel #2
0
 def _check_arg(arg):
     if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
         msg = (
             "The contents of odeint *args must be arrays or scalars, but got "
             "\n{}."
         )
     raise TypeError(msg.format(arg))
Beispiel #3
0
def bind(self, *args, **kwargs):
    """Like Primitive.bind, but finds the top trace even when no arguments are provided.

    The trace used to process the primitive comes from ``_top_trace()`` rather
    than from the operands, which is presumably what allows ``args`` to be
    empty (TODO confirm ``_top_trace`` semantics).

    Args:
      *args: operands; each must be a ``Tracer`` or satisfy ``valid_jaxtype``
        (checked by assertion unless ``jax.core.skip_checks`` is set).
      **kwargs: primitive parameters, forwarded as a single dict to
        ``trace.process_primitive``.

    Returns:
      If ``self.multiple_results`` is true, ``full_lower`` mapped over the
      outputs of ``process_primitive``; otherwise ``full_lower`` applied to its
      single output.
    """
    # Operand-type sanity check, bypassed when jax.core.skip_checks is set.
    # NOTE: because this is an `assert`, it is also stripped under `python -O`.
    assert jax.core.skip_checks or all(isinstance(arg, Tracer)
                                       or valid_jaxtype(arg) for arg in args), args

    # Trace in which the primitive is processed; obtained without consulting
    # `args` (presumably the top of the current trace stack — verify).
    trace = _top_trace()
    # `main` and `dynamic` are only consumed by the assertion below, but they
    # are computed unconditionally, even when skip_checks is set.
    main = find_top_trace(args).main
    dynamic = thread_local_state.trace_state.trace_stack.dynamic
    # Consistency check: the main trace that find_top_trace picks for the
    # operands must be either the dynamic main trace or `trace`'s own main.
    assert (jax.core.skip_checks or main is dynamic or main is trace.main), args

    # Lift each operand into `trace`, apply the primitive, then lower results.
    tracers = map(trace.full_raise, args)
    out_tracer = trace.process_primitive(self, tracers, kwargs)
    # NOTE(review): if `map` here is the builtin rather than a list-returning
    # safe_map, both `tracers` and this result are lazy iterators — confirm.
    return map(full_lower, out_tracer) if self.multiple_results else full_lower(out_tracer)
Beispiel #4
0
def _stop_gradient_impl(x):
    if not valid_jaxtype(x):
        raise TypeError("stop_gradient only works on valid JAX arrays, but "
                        f"input argument is: {x}")
    return x