def test_add_any(self): # https://github.com/google/jax/issues/5217 f = lambda x, eps: x * eps + eps + x def g(eps): x = jnp.array(1.) return jax.grad(f)(x, eps) jet(g, (1.,), ([1.],)) # doesn't crash
def sol_recursive(f, z, t):
    """Recursively compute higher order derivatives of dynamics of ODE."""
    # NOTE(review): `reg` and `REGS` are module-level names defined elsewhere
    # in this file — confirm against the surrounding module.
    if reg == "none":
        return f(z, t), jnp.zeros_like(z)

    state_shape = z.shape
    augmented_state = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

    def augmented_dynamics(state):
        """Closure to expand z: dynamics of the time-augmented state (z, t)."""
        cur_z = jnp.reshape(state[:-1], state_shape)
        cur_t = state[-1]
        return jnp.concatenate((jnp.ravel(f(cur_z, cur_t)), jnp.array([1.])))

    depth = REGS.index(reg)
    # Seed jet() with a dummy unit series, then feed the obtained Taylor
    # coefficients back in; each pass yields one more coefficient.
    (head, [*tail]) = jet(augmented_dynamics, (augmented_state, ),
                          ((jnp.ones_like(augmented_state), ), ))
    for _ in range(depth + 1):
        (head, [*tail]) = jet(augmented_dynamics, (augmented_state, ),
                              ((head, *tail), ))

    # Drop the trailing time component and restore the original state shape.
    return (jnp.reshape(head[:-1], state_shape),
            jnp.reshape(tail[-2][:-1], state_shape))
def test_inst_zero(self): def f(x): return 2. def g(x): return 2. + 0 * x x = jnp.ones(1) order = 3 f_out_primals, f_out_series = jet(f, (x, ), ([jnp.ones_like(x) for _ in range(order)], )) assert f_out_series is not zero_series g_out_primals, g_out_series = jet(g, (x, ), ([jnp.ones_like(x) for _ in range(order)], )) assert g_out_primals == f_out_primals assert g_out_series == f_out_series
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() against the jvp-based Taylor reference."""
    got_y, got_terms = jet(fun, primals, series)
    want_y, want_terms = jvp_taylor(fun, primals, series)

    self.assertAllClose(got_y, want_y, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)

    # TODO(duvenaud): Lower zero_series to actual zeros automatically.
    if got_terms == zero_series:
        got_terms = tree_map(np.zeros_like, want_terms)

    self.assertAllClose(got_terms, want_terms, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() against the jvp-based Taylor reference, mapping
    non-finite entries (inf) to NaN before comparing.

    Fix: the symbolic-zero check must run *before* `terms` is converted —
    `np.asarray(zero_series)` yields an object array on which `np.isfinite`
    raises, and `terms == zero_series` could never be True after conversion.
    This now matches the ordering used by `check_jet`.
    """
    y, terms = jet(fun, primals, series)
    expected_y, expected_terms = jvp_taylor(fun, primals, series)

    # TODO(duvenaud): Lower zero_series to actual zeros automatically.
    if terms == zero_series:
        terms = tree_map(np.zeros_like, expected_terms)

    def _convert(x):
        # Replace non-finite entries with NaN so assertAllClose treats
        # them uniformly.
        return np.where(np.isfinite(x), x, np.nan)

    y = _convert(y)
    expected_y = _convert(expected_y)
    terms = _convert(np.asarray(terms))
    expected_terms = _convert(np.asarray(expected_terms))

    self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
    self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() against the jvp reference, masking non-finite values."""
    # Convert to jax arrays to ensure dtype canonicalization.
    primals = jax.tree_map(jnp.asarray, primals)
    series = jax.tree_map(jnp.asarray, series)

    got_y, got_terms = jet(fun, primals, series)
    want_y, want_terms = jvp_taylor(fun, primals, series)

    def _mask_nonfinite(x):
        # Map inf/-inf (and keep NaN) to NaN so comparisons ignore them.
        return jnp.where(jnp.isfinite(x), x, jnp.nan)

    got_y = _mask_nonfinite(got_y)
    want_y = _mask_nonfinite(want_y)
    got_terms = _mask_nonfinite(jnp.asarray(got_terms))
    want_terms = _mask_nonfinite(jnp.asarray(want_terms))

    self.assertAllClose(got_y, want_y, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
    self.assertAllClose(got_terms, want_terms, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() against the jvp reference with non-finite entries NaN-ed."""
    actual_y, actual_terms = jet(fun, primals, series)
    reference_y, reference_terms = jvp_taylor(fun, primals, series)

    def _sanitize(x):
        # inf/-inf become NaN; finite values pass through unchanged.
        return jnp.where(jnp.isfinite(x), x, jnp.nan)

    actual_y = _sanitize(actual_y)
    reference_y = _sanitize(reference_y)
    actual_terms = _sanitize(jnp.asarray(actual_terms))
    reference_terms = _sanitize(jnp.asarray(reference_terms))

    self.assertAllClose(actual_y, reference_y, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
    self.assertAllClose(actual_terms, reference_terms, atol=atol, rtol=rtol,
                        check_dtypes=check_dtypes)
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() against the jvp-based Taylor reference."""
    # Convert to jax arrays to ensure dtype canonicalization.
    primals = jax.tree_map(jnp.asarray, primals)
    series = jax.tree_map(jnp.asarray, series)

    observed = jet(fun, primals, series)
    reference = jvp_taylor(fun, primals, series)

    # Compare primal outputs first, then the series terms.
    for got, want in zip(observed, reference):
        self.assertAllClose(got, want, atol=atol, rtol=rtol,
                            check_dtypes=check_dtypes)
def _taylor_coefficient_generator(*, f, y0): """Generate Taylor coefficients. Generate Taylor-series-coefficients of the ODE solution `x(t)` via generating Taylor-series-coefficients of `g(t)=f(x(t))` via ``jax.experimental.jet()``. """ # This is the 0th Taylor coefficient of x(t) at t=t0. x_primals = y0 yield (x_primals, ) # This contains the higher-order, unnormalised # Taylor coefficients of x(t) at t=t0. # We know them because of the ODE. x_series = (f(y0), ) while True: yield (x_primals, ) + x_series # jet() computes a Taylor approximation of g(t) := f(x(t)) # The output is the zeroth Taylor approximation g(t_0) ('primals') # as well its higher-order Taylor coefficients ('series') g_primals, g_series = jet(fun=f, primals=(x_primals, ), series=(x_series, )) # For ODEs \dot y(t) = f(y(t)), # The nth Taylor coefficient of y is the # (n-1)th Taylor coefficient of g(t) = f(y(t)). # This way, by augmenting x0 with the Taylor series # approximating g(t) = f(y(t)), we increase the order # of the approximation by 1. x_series = (g_primals, *g_series)
def _taylormode(f, z0, t0, order): """Taylor-mode automatic differentiation for initialisation. Inspired by the implementation in https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py """ try: import jax.numpy as jnp from jax.config import config from jax.experimental.jet import jet config.update("jax_enable_x64", True) except ImportError as err: raise ImportError( "Cannot perform Taylor-mode initialisation without optional " "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`." ) from err def total_derivative(z_t): """Total derivative.""" z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1] dz = jnp.ravel(f(t, z)) dt = jnp.array([1.0]) dz_t = jnp.concatenate((dz, dt)) return dz_t z_shape = z0.shape z_t = jnp.concatenate((jnp.ravel(z0), jnp.array([t0]))) derivs = [] derivs.extend(z0) if order == 0: return jnp.array(derivs) (y0, [*yns]) = jet(total_derivative, (z_t,), ((jnp.ones_like(z_t),),)) derivs.extend(y0[:-1]) if order == 1: return jnp.array(derivs) order = order - 2 for _ in range(order + 1): (y0, [*yns]) = jet(total_derivative, (z_t,), ((y0, *yns),)) derivs.extend(yns[-2][:-1]) return jnp.array(derivs)
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5, check_dtypes=True):
    """Compare jet() output against the jvp-based Taylor reference."""
    observed_y, observed_terms = jet(fun, primals, series)
    reference_y, reference_terms = jvp_taylor(fun, primals, series)

    tolerances = dict(atol=atol, rtol=rtol, check_dtypes=check_dtypes)
    self.assertAllClose(observed_y, reference_y, **tolerances)
    self.assertAllClose(observed_terms, reference_terms, **tolerances)
def sol_recursive(f, z, t):
    """Recursively compute higher order derivatives of dynamics of ODE.

    Returns the vector field evaluated at (z, t) together with a
    single-element list holding the next-order jet term, both reshaped
    like ``z``.
    """
    state_shape = z.shape
    augmented = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

    def augmented_dynamics(state):
        """Closure to expand z: dynamics of the time-augmented state (z, t)."""
        cur_z = jnp.reshape(state[:-1], state_shape)
        cur_t = state[-1]
        return jnp.concatenate((jnp.ravel(f(cur_z, cur_t)), jnp.array([1.])))

    # First jet pass with a dummy unit series, second pass re-using the
    # coefficients obtained from the first.
    first, [first_hot] = jet(augmented_dynamics, (augmented, ),
                             ((jnp.ones_like(augmented), ), ))
    first, [second, second_hot] = jet(augmented_dynamics, (augmented, ),
                                      ((first, first_hot, ), ))

    # Drop the trailing time component before reshaping.
    return (jnp.reshape(first[:-1], state_shape),
            [jnp.reshape(second[:-1], state_shape)])
def initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv):
    """Initialize an ODE filter with Taylor-mode automatic differentiation.

    This requires JAX. For an explanation of what happens ``under the hood``, see [1]_.

    References
    ----------
    .. [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers,
       *arXiv:2012.10106*, 2020.

    The implementation is inspired by the implementation in
    https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py

    Parameters
    ----------
    f
        ODE vector field.
    y0
        Initial value.
    t0
        Initial time point.
    prior
        Prior distribution used for the ODE solver. For instance an integrated Brownian
        motion prior (``IBM``).
    initrv
        Initial random variable.

    Returns
    -------
    Normal
        Estimated initial random variable. Compatible with the specified prior.

    Examples
    --------
    >>> import sys, pytest
    >>> if not sys.platform.startswith('linux'):
    ...     pytest.skip()

    >>> from dataclasses import astuple
    >>> from probnum.randvars import Normal
    >>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
    >>> from probnum.statespace import IBM

    Compute the initial values of the restricted three-body problem as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
    >>> print(y0)
    [ 0.994       0.          0.         -2.00158511]

    >>> prior = IBM(ordint=3, spatialdim=4)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [ 0.994       0.          0.         -2.00158511]
    >>> print(improved_initrv.mean)
    [ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
      0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
      0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
     -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]

    Compute the initial values of the van-der-Pol oscillator as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
    >>> print(y0)
    [2. 0.]
    >>> prior = IBM(ordint=3, spatialdim=2)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [2. 0.]
    >>> print(improved_initrv.mean)
    [    2.     0.    -2.    60.     0.    -2.    60. -1798.]
    >>> print(improved_initrv.std)
    [0. 0. 0. 0. 0. 0. 0. 0.]
    """
    # JAX is an optional dependency; fail with an actionable message.
    # NOTE(review): `from jax.config import config` is a removed import path
    # in recent JAX releases — verify against the pinned jax version.
    try:
        import jax.numpy as jnp
        from jax.config import config
        from jax.experimental.jet import jet

        config.update("jax_enable_x64", True)
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    # Number of derivatives the prior tracks; also the Taylor order computed.
    order = prior.ordint

    def total_derivative(z_t):
        """Total derivative of the time-augmented state (z, t).

        Splits the augmented state back into (state, time), evaluates the
        vector field, and appends dt/dt = 1 — this rewrites the
        non-autonomous ODE as an autonomous one for jet().
        """
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(t, z))
        dt = jnp.array([1.0])
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    z_shape = y0.shape
    # Augmented state (raveled y0, t0).
    z_t = jnp.concatenate((jnp.ravel(y0), jnp.array([t0])))

    derivs = []

    # Zeroth derivative: the initial value itself.
    derivs.extend(y0)
    if order == 0:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=0, spatialdim=len(y0)
        )

        # Wrap through np.asarray: Normal expects NumPy, not JAX arrays.
        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    # First derivative: single jet() call seeded with a dummy unit series.
    (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((jnp.ones_like(z_t), ), ))
    derivs.extend(dy0[:-1])  # drop the trailing dt/dt component
    if order == 1:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=1, spatialdim=len(y0)
        )

        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    # Higher derivatives: feed the coefficients obtained so far back into
    # jet(); each iteration yields one more Taylor coefficient.
    for _ in range(1, order):
        (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((dy0, *yns), ))
        derivs.extend(yns[-2][:-1])

    all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
        jnp.array(derivs), ordint=order, spatialdim=len(y0)
    )

    return randvars.Normal(
        np.asarray(all_derivs),
        cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
    )
def __call__(
    self,
    ivp: problems.InitialValueProblem,
    prior_process: randprocs.markov.MarkovProcess,
) -> randvars.RandomVariable:
    """Compute the initial random variable for an ODE filter via Taylor-mode
    automatic differentiation with ``jax.experimental.jet`` (requires JAX)."""
    # JAX is an optional dependency; fail with an actionable message.
    # NOTE(review): `from jax.config import config` is a removed import path
    # in recent JAX releases — verify against the pinned jax version.
    try:
        import jax.numpy as jnp
        from jax.config import config
        from jax.experimental.jet import jet

        config.update("jax_enable_x64", True)
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    num_derivatives = prior_process.transition.num_derivatives

    dt = jnp.array([1.0])  # dt/dt = 1, appended to the augmented dynamics

    def evaluate_ode_for_extended_state(extended_state, ivp=ivp, dt=dt):
        r"""Evaluate the ODE for an extended state (x(t), t).

        More precisely, compute the derivative of the stacked state (x(t), t) according to the ODE.
        This function implements a rewriting of non-autonomous as autonomous ODEs.
        This means that

        .. math:: \dot x(t) = f(t, x(t))

        becomes

        .. math:: \dot z(t) = \dot (x(t), t) = (f(x(t), t), 1).

        Only considering autonomous ODEs makes the jet-implementation
        (and automatic differentiation in general) easier.
        """
        x, t = jnp.reshape(extended_state[:-1], ivp.y0.shape), extended_state[-1]
        dx = ivp.f(t, x)
        dx_ravelled = jnp.ravel(dx)
        stacked_ode_eval = jnp.concatenate((dx_ravelled, dt))
        return stacked_ode_eval

    def derivs_to_normal_randvar(derivs, num_derivatives_in_prior):
        """Finalize the output in terms of creating a suitably sized random
        variable."""
        all_derivs = (
            randprocs.markov.integrator.convert.convert_derivwise_to_coordwise(
                np.asarray(derivs),
                num_derivatives=num_derivatives_in_prior,
                wiener_process_dimension=ivp.y0.shape[0],
            )
        )

        # Wrap all inputs through np.asarray, because 'Normal's
        # do not like JAX 'DeviceArray's
        return randvars.Normal(
            mean=np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    # Augmented state (raveled y0, t0) for the autonomous rewrite above.
    extended_state = jnp.concatenate((jnp.ravel(ivp.y0), jnp.array([ivp.t0])))
    derivs = []

    # Corner case 1: num_derivatives == 0 — the initial value alone suffices.
    derivs.extend(ivp.y0)
    if num_derivatives == 0:
        return derivs_to_normal_randvar(
            derivs=derivs, num_derivatives_in_prior=num_derivatives
        )

    # Corner case 2: num_derivatives == 1 — one jet() call seeded with a
    # dummy unit series.
    initial_series = (jnp.ones_like(extended_state), )
    (
        initial_taylor_coefficient,
        [*remaining_taylor_coefficents],
    ) = jet(
        fun=evaluate_ode_for_extended_state,
        primals=(extended_state, ),
        series=(initial_series, ),
    )
    derivs.extend(initial_taylor_coefficient[:-1])  # drop the time component
    if num_derivatives == 1:
        return derivs_to_normal_randvar(
            derivs=derivs, num_derivatives_in_prior=num_derivatives
        )

    # Order > 1: feed the Taylor coefficients obtained so far back into
    # jet(); each iteration yields one more coefficient.
    for _ in range(1, num_derivatives):
        taylor_coefficients = (
            initial_taylor_coefficient,
            *remaining_taylor_coefficents,
        )
        (_, [*remaining_taylor_coefficents]) = jet(
            fun=evaluate_ode_for_extended_state,
            primals=(extended_state, ),
            series=(taylor_coefficients, ),
        )
        derivs.extend(remaining_taylor_coefficents[-2][:-1])
    return derivs_to_normal_randvar(
        derivs=derivs, num_derivatives_in_prior=num_derivatives
    )