def one_rng(rng):
    _, params = dynamics_init(rng, (x_dim, ))
    fwd_x_T = ode.odeint(lambda x, _: dynamics(params, x),
                         jnp.zeros((x_dim, )), times, mxstep=1e9)
    bwd_x_0 = vmap(
        lambda T, x_T: ode.odeint(lambda x, _: -dynamics(params, x), x_T,
                                  jnp.array([0.0, T]), mxstep=1e9)[1],
        in_axes=(0, 0))(times, fwd_x_T)
    return jnp.sum(bwd_x_0**2, axis=-1)

def eval_from_x0(policy_params, x0, total_time):
    # Zero is necessary for some reason...
    t = jp.array([0.0, total_time])
    y0 = jp.concatenate((jp.zeros((1, )), x0, jp.zeros((z_dim, ))))
    # odeint_kwargs = {"rtol": 1e-3, "mxstep": 1e6}
    odeint_kwargs = {"mxstep": 1e6}
    y_fwd = ode.odeint(ofunc, y0, t, policy_params, **odeint_kwargs)
    # This is similar to, but not exactly the same as, the rev-mode adjoint
    # solution, since the solver's step sizes can differ on the backward pass.
    y_bwd = ode.odeint(lambda y, t, *args: -ofunc(y, -t, *args), y_fwd[1],
                       -t[::-1], policy_params, **odeint_kwargs)
    return y_fwd, y_bwd[::-1]

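# The forward-then-backward pattern above is a useful sanity check: integrate
# forward, then integrate the negated, time-reversed dynamics from the
# endpoint, and the result should land back near the initial state. A minimal
# self-contained sketch of that round trip, with a made-up linear `dynamics`
# standing in for `ofunc`:
import jax.numpy as jnp
from jax.experimental.ode import odeint

def roundtrip_error():
    dynamics = lambda x, t: -0.5 * x  # toy stand-in vector field
    x0 = jnp.ones(3)
    ts = jnp.array([0.0, 2.0])
    x_fwd = odeint(dynamics, x0, ts)
    # Integrate the negated, time-reversed field back from the endpoint.
    x_bwd = odeint(lambda x, t: -dynamics(x, -t), x_fwd[-1], -ts[::-1])
    return jnp.max(jnp.abs(x_bwd[-1] - x0))  # should be ~ solver tolerance
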
def update(r):
    v1_normalized = r * saep * v1
    y = np.asarray(
        (r_axis + ct * v1_normalized[0] + st * v1_normalized[1],
         z_axis + v1_normalized[2]))
    sol = odeint(step_partial, y, t_eval)
    return sol[:, 0], sol[:, 1]

def test_disable_jit_odeint_with_vmap(self):
    # https://github.com/google/jax/issues/2598
    with jax.disable_jit():
        t = jnp.array([0.0, 1.0])
        x0_eval = jnp.zeros((5, 2))
        f = lambda x0: odeint(lambda x, _t: x, x0, t)
        jax.vmap(f)(x0_eval)  # doesn't crash

def manifold_ode_log_prob(params: List, rng: random.PRNGKey,
                          num_samples: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Forward model of the neural manifold ODE. The base distribution is
    uniform on the sphere. Computes both the samples from the forward model
    and the log-probability of the generated samples.

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        sph: Samples generated from the neural manifold ODE under the
            forward model.
        log_prob: The log-probability of the generated samples.

    """
    rng, rng_x, rng_b = random.split(rng, 3)
    x = sample_uniform(rng_x, [num_samples, 4])
    b = project_to_sphere(x + 0.01 * random.normal(rng_b, x.shape))
    v = log(b, x)
    fldj = log_det_jac_exp(b, v)
    vector_field = ambient_to_spherical_vector_field(
        lambda x, t: net(params, stacked(x, t)))
    cfunc, divfunc = spherical_to_chart_vector_field(b, vector_field)
    init = (v, jnp.zeros(len(v)))
    time = jnp.array([0.0, 1.0])
    tang, trace = tuple(_[-1] for _ in odeint(divfunc, init, time))
    sph = exp(b, tang)
    ildj = -log_det_jac_exp(b, tang)
    log_prob = uniform_log_density(x) + fldj + ildj + trace
    return sph, log_prob

def manifold_reverse_ode_log_prob(params: List, rng: random.PRNGKey,
                                  revx: jnp.ndarray) -> jnp.ndarray:
    """Given observations, compute their log-likelihood under the neural
    manifold ODE by integrating the dynamics backwards and applying the
    change-of-variables formula (computed continuously).

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key.
        revx: Observations whose log-likelihood under the neural ODE model
            should be computed.

    Returns:
        rev_log_prob: The log-probability of the observations.

    """
    b = project_to_sphere(revx + 0.01 * random.normal(rng, revx.shape))
    vrev = log(b, revx)
    revfldj = log_det_jac_exp(b, vrev)
    revinit = (vrev, jnp.zeros(len(vrev)))
    vector_field = ambient_to_spherical_vector_field(
        lambda x, t: net(params, stacked(x, t)))
    cfunc, divfunc = spherical_to_chart_vector_field(b, vector_field)
    revfunc = lambda x, t: tuple(-_ for _ in divfunc(x, 1.0 - t))
    time = jnp.array([0.0, 1.0])
    revtang, revtrace = tuple(_[-1] for _ in odeint(revfunc, revinit, time))
    revsph = exp(b, revtang)
    revildj = -log_det_jac_exp(b, revtang)
    rev_log_prob = uniform_log_density(revsph) - revfldj - revildj - revtrace
    return rev_log_prob

def _run_static(cls, T, x0, theta, rtol=1e-5, atol=1e-3, mxstep=500):
    '''
    x0 is shape (d,)
    theta is shape (nargs,)
    '''
    t = np.arange(T, dtype='float32')
    # Pass the tolerance settings through to the solver.
    return odeint(cls.dx_dt, x0, t, *theta,
                  rtol=rtol, atol=atol, mxstep=mxstep)

def integrate(self, x0, n_steps):
    """Integrates the equations of motion of the dynamical system."""
    t = jnp.asarray([n * self.dt for n in range(n_steps)])
    traj = odeint(self.equation_of_motion, x0, t)
    return traj

def model(N, y=None):
    """
    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # initial population
    z_init = numpyro.sample("z_init",
                            dist.LogNormal(jnp.log(10), 1).expand([2]))
    # measurement times
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(
            low=0.0,
            loc=jnp.array([1.0, 0.05, 1.0, 0.05]),
            scale=jnp.array([0.5, 0.05, 0.5, 0.05]),
        ),
    )
    # integrate dz/dt; the result will have shape (N, 2)
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)
    # measurement errors
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    # measured populations
    numpyro.sample("y", dist.LogNormal(jnp.log(z), sigma), obs=y)

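# `dz_dt` above is assumed in scope; in NumPyro's predator-prey example this
# role is played by the Lotka-Volterra equations, whose four parameters match
# the sampled `theta`. A sketch consistent with that usage (inferred from
# context, not shown in the original snippet):
import jax.numpy as jnp

def dz_dt(z, t, theta):
    """Lotka-Volterra system: prey population u, predator population v."""
    u, v = z
    alpha, beta, gamma, delta = theta
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])
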
def solve_ode(init_condition, times):
    return odeint(f, init_condition, t=times,
                  rtol=1e-10, atol=1e-10, **kwargs)

def eval_from_x0(policy_params, x0, total_time):
    # Zero is necessary for some reason...
    ts = jnp.array([0.0, total_time])
    y0 = (jnp.zeros(()), jnp.zeros(()), x0)
    odeint_kwargs = {"mxstep": 1e6}
    y_fwd = ode.odeint(ofunc, y0, ts, policy_params, **odeint_kwargs)
    yT = tree_map(itemgetter(-1), y_fwd)
    # This is similar to, but not exactly the same as, the rev-mode adjoint
    # solution, since the solver's step sizes can differ on the backward pass.
    y_bwd = ode.odeint(
        lambda y, t, *args: tree_map(jnp.negative, ofunc(y, -t, *args)), yT,
        -ts[::-1], policy_params, **odeint_kwargs)
    y0_bwd = tree_map(itemgetter(-1), y_bwd)
    return y0, yT, y0_bwd

def loss(t1, flat_p, omega, U_T):
    '''Define the loss function, which is a pure function.'''
    t_set = jnp.linspace(0., t1, 5)
    D, _ = jnp.shape(U_T)
    U_0 = jnp.eye(D, dtype=jnp.complex128)

    def func(y, t, *args):
        t1, omega, flat_p = args
        return -1.0j * (omega * sz + A(t, flat_p, t1) * sx) @ y
        # return -1.0j * (omega * sz) @ y

    res = odeint(func, U_0, t_set, t1, omega, flat_p,
                 rtol=1.4e-10, atol=1.4e-10)
    U_F = res[-1, :, :]
    return 1 - jnp.abs(jnp.trace(U_T.conj().T @ U_F) / D)**2

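# The control envelope `A(t, flat_p, t1)` is not defined above; a plausible
# (purely illustrative) parameterization for this kind of pulse-shaping loss
# is a truncated sine series on [0, t1], one coefficient per entry of flat_p:
import jax.numpy as jnp

def A(t, flat_p, t1):
    # Hypothetical envelope; the original parameterization may differ.
    k = jnp.arange(1, flat_p.shape[0] + 1)
    return jnp.sum(flat_p * jnp.sin(k * jnp.pi * t / t1))
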
def kernel(a):
    def integrand(y, t, a):
        # Probably not enough; be careful about the change of variables.
        return dndz(1. / t - t) * cosmo.g(a, t)

    y0 = 0.
    y = odeint(integrand, y0, np.array([cosmo._amin, a]), a)
    return y[1]

def g(x):
    # Two initial values for the ODE
    y0_arr = jnp.array([[x, 0.1], [x, 0.2]])

    # Run the ODE twice
    t = jnp.array([0., 5.])
    y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
    return y[:, -1].sum()

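# `dx_dt` is assumed in scope in `g` above; any vector field with the (y, t)
# signature works. A toy stand-in, plus a check that gradients flow through
# both vmapped solves:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dx_dt(y, t):
    # damped oscillator, standing in for the undefined dx_dt
    return jnp.array([y[1], -y[0] - 0.1 * y[1]])

print(jax.grad(g)(0.5))  # differentiates through vmap and odeint alike
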
def advance(x0, theta):
    x1 = odeint(cls.dx_dt, x0, t_one_step, *theta,
                rtol=rtol, atol=atol, mxstep=mxstep)[1]
    return x1, x1

def solve_model_rk45(inputs):
    y = odeint(rhs_ode, y0, t_eval, inputs,
               rtol=self.rtol, atol=self.atol, **self.extra_options)
    return jnp.transpose(y)

def batch_warmup(self, x0: Array, total_steps: int) -> Array:
    """
    Integrates the model and returns just the last step. This function is
    used to spin the model up to a statistically stationary regime.
    """
    t = jnp.asarray([0, total_steps * self.dt])
    traj = odeint(self.equation_of_motion, x0, t)
    return traj[-1, ...]

def evally(policy_params, rng, x0, gamma):
    # policy_params is first since that's what we want gradients w.r.t.
    t = random.exponential(rng, (num_keypoints, )) / -jp.log(gamma)
    t = jp.concatenate((jp.zeros((1, )), jp.sort(t)))
    print(f"t = {t}")
    x_t = ode.odeint(ofunc, x0, t, policy_params, rtol=1e-3, mxstep=1e6)
    print(f"x_t = {x_t}")
    costs = vmap(lambda x: cost_fn(x, policy(policy_params, x)))(x_t)
    print(f"costs = {costs}")
    return jp.mean(costs)

def _forward_dormandprince(state, ts, params, diffusivity, stimuli, dt, dx):
    return ode.odeint(
        step,
        state,
        ts,
        params,
        diffusivity,
        stimuli,
        dx,
    )

def check_against_scipy(self, fun, y0, tspace, *args, tol=1e-1):
    y0, tspace = np.array(y0), np.array(tspace)
    np_fun = partial(fun, np)
    scipy_result = jnp.asarray(osp_integrate.odeint(np_fun, y0, tspace, args))

    y0, tspace = jnp.array(y0), jnp.array(tspace)
    jax_fun = partial(fun, jnp)
    jax_result = odeint(jax_fun, y0, tspace, *args)

    self.assertAllClose(jax_result, scipy_result,
                        check_dtypes=False, atol=tol, rtol=tol)

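# The partial(fun, np) / partial(fun, jnp) trick above means `fun` receives
# the array module as its first argument, so one callable can drive both
# SciPy's and JAX's odeint. A hypothetical invocation from inside a test
# method (the pendulum dynamics is illustrative):
def pend(np, y, t, m, g):
    # `np` is whichever array module the harness passes in
    theta, omega = y
    return [omega, -m * omega - g * np.sin(theta)]

# self.check_against_scipy(pend, [3.0, 0.0], np.linspace(0., 1., 11),
#                          0.25, 9.8)
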
def make_sim(key):
    x0 = random.normal(key, (n, total_dim))  # (n, dim * 2 + params)
    # x0 = index_update(x0, s_[..., -1], np.exp(x0[..., -1]))  # all masses
    # mass set to 1: [x, y, x', y', random_value, mass]
    x0 = index_update(x0, s_[..., -1], 1)
    x_times = odeint(odefunc, x0.reshape(packed_shape), times,
                     mxstep=2000).reshape(-1, *unpacked_shape)
    return x_times

def ode_forward(rng: random.PRNGKey, net_params: List[jnp.ndarray],
                num_samples: int, num_dims: int
                ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    x = random.normal(rng, [num_samples, num_dims])
    log_prob_prior = jspst.norm.logpdf(x).sum(axis=-1)
    vector_field = lambda x, t: net(net_params, stacked(x, t))
    divfunc = primal_to_augmented(vector_field)
    init = (x, jnp.zeros(len(x)))
    time = jnp.array([0.0, 1.0])
    xfwd, trace = tuple(_[-1] for _ in odeint(divfunc, init, time))
    log_prob = log_prob_prior + trace
    xsph = project(xfwd)
    return xfwd, xsph, log_prob

def ode_reverse(net_params: List[jnp.ndarray],
                xrev: jnp.ndarray) -> jnp.ndarray:
    num_dims = xrev.shape[-1]
    vector_field = lambda x, t: net(net_params, stacked(x, t))
    divfunc = primal_to_augmented(vector_field)
    # Integrate the negated field backwards in time (t -> 1 - t).
    revfunc = lambda x, t: tuple(-_ for _ in divfunc(x, 1.0 - t))
    revinit = (xrev, jnp.zeros(len(xrev)))
    time = jnp.array([0.0, 1.0])
    yrev, revtrace = tuple(_[-1] for _ in odeint(revfunc, revinit, time))
    log_prob_prior = jspst.norm.logpdf(yrev).sum(axis=-1)
    rev_log_prob = log_prob_prior - revtrace
    return rev_log_prob

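# Both ode_forward and ode_reverse lean on primal_to_augmented, which
# augments the state with a running Jacobian trace so that odeint integrates
# the instantaneous change-of-variables term alongside the samples. A minimal
# sketch of such a helper, assuming the vector field accepts one unbatched
# sample (an assumed implementation, not the original):
import jax
import jax.numpy as jnp

def primal_to_augmented(vector_field):
    def single(x, t):
        f = lambda xi: vector_field(xi, t)
        return f(x), jnp.trace(jax.jacfwd(f)(x))  # exact divergence

    def divfunc(state, t):
        x, _ = state  # the running trace is itself integrated state
        return jax.vmap(single, in_axes=(0, None))(x, t)

    return divfunc
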
def integrate_fieldline(startpoint, objective_dict):
    closed_l = r(objective_dict['p'], theta)
    ll = closed_l[:, :-1, :]
    dl = closed_l[:, :-1, :] - closed_l[:, 1:, :]
    jitfield = jit(
        lambda r_eval, t: biot_savart(r_eval, dl, ll, objective_dict['I_arr']))
    startpos = r_surf[0, 0, :]
    times = np.linspace(0, 10, 10000)
    fieldline = odeint(jitfield, startpos, times)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(fieldline[:, 0], fieldline[:, 1], fieldline[:, 2])

def test_swoop(self):
    def swoop(y, t, arg1, arg2):
        return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)

    ts = np.array([0.1, 0.2])
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3

    y0 = np.linspace(0.1, 0.9, 10)
    integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.2)
    jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                    rtol=tol, atol=tol)

    big_y0 = np.linspace(1.1, 10.9, 10)
    integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.3)
    jtu.check_grads(integrate, (big_y0, ts), modes=["rev"], order=2,
                    rtol=tol, atol=tol)

def loss(t1, flat_p, psi_init, psi0):
    '''Define the loss function, which is a pure function.'''
    t_set = jnp.linspace(0., t1, 5)

    def func(y, t, *args):
        t1, flat_p = args
        return -1.0j * Hmat(t, flat_p, t1) @ y

    res = odeint(func, psi_init, t_set, t1, flat_p,
                 rtol=1.4e-10, atol=1.4e-10)
    psi_final = res[-1, :]
    return 1 - jnp.abs(jnp.dot(jnp.conjugate(psi_final), psi0))**2

def state_and_costate_trajectories(initial_costate_and_final_time):
    """Propagates the ODE that defines the shooting method.

    Args:
        initial_costate_and_final_time: An array of shape (4,) containing
            (p_x(0), p_y(0), p_θ(0), t_f).

    Returns:
        A tuple of arrays (times, (states, costates)) where
            times: An array of shape (N,) containing a sequence of time
                points spanning [0, t_f].
            states: An array of shape (N, 3) containing the states at `times`.
            costates: An array of shape (N, 3) containing the costates at
                `times`.
    """
    initial_costate = initial_costate_and_final_time[:-1]
    final_time = initial_costate_and_final_time[-1]
    times = jnp.linspace(0, final_time, 20)
    return times, odeint(shooting_ode, (initial_state, initial_costate), times)

def model(self, t, X):
    noise = sample('noise', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))
    hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,))
    W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))

    J0 = sample('J0', dist.Uniform(1.0, 10.0))    # 2.5
    k1 = sample('k1', dist.Uniform(80., 120.0))   # 100.
    k2 = sample('k2', dist.Uniform(1., 10.0))     # 6.
    k3 = sample('k3', dist.Uniform(2., 20.0))     # 16.
    k4 = sample('k4', dist.Uniform(80., 120.0))   # 100.
    k5 = sample('k5', dist.Uniform(0.1, 2.0))     # 1.28
    k6 = sample('k6', dist.Uniform(2., 20.0))     # 12.
    k = sample('k', dist.Uniform(0.1, 2.0))       # 1.8
    ka = sample('ka', dist.Uniform(2., 20.0))     # 13.
    q = sample('q', dist.Uniform(1., 10.0))       # 4.
    KI = sample('KI', dist.Uniform(0.1, 2.0))     # 0.52
    phi = sample('phi', dist.Uniform(0.05, 1.0))  # 0.1
    Np = sample('Np', dist.Uniform(0.1, 2.0))     # 1.
    A = sample('A', dist.Uniform(1., 10.0))       # 4.
    IC = sample('IC', dist.Uniform(0, 1))

    # compute the block-diagonal kernel
    K_11 = W[0] * self.RBF(self.t_t[0], self.t_t[0], hyp[0]) \
        + np.eye(self.N[0]) * (noise[0] + self.jitter)
    K_22 = W[1] * self.RBF(self.t_t[1], self.t_t[1], hyp[1]) \
        + np.eye(self.N[1]) * (noise[1] + self.jitter)
    K_33 = W[2] * self.RBF(self.t_t[2], self.t_t[2], hyp[2]) \
        + np.eye(self.N[2]) * (noise[2] + self.jitter)
    K = np.concatenate(
        [np.concatenate([K_11,
                         np.zeros((self.N[0], self.N[1])),
                         np.zeros((self.N[0], self.N[2]))], axis=1),
         np.concatenate([np.zeros((self.N[1], self.N[0])),
                         K_22,
                         np.zeros((self.N[1], self.N[2]))], axis=1),
         np.concatenate([np.zeros((self.N[2], self.N[0])),
                         np.zeros((self.N[2], self.N[1])),
                         K_33], axis=1)],
        axis=0)

    # compute the mean by integrating the ODE
    x0 = np.array([0.5, 1.9, 0.18, 0.15, IC, 0.1, 0.064])
    mut = odeint(self.dxdt, x0, self.t.flatten(),
                 J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A)
    mu1 = mut[self.i_t[0], ind[0]] / self.max_X[0]
    mu2 = mut[self.i_t[1], ind[1]] / self.max_X[1]
    mu3 = mut[self.i_t[2], ind[2]] / self.max_X[2]
    mu = np.concatenate((mu1, mu2, mu3), axis=0)

    # concatenate the data
    mu = mu.flatten('F')
    X = np.concatenate((self.X[0], self.X[1], self.X[2]), axis=0)
    X = X.flatten('F')

    # sample X according to the standard Gaussian process formula
    sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X)

def bwd_spline_segment(ta, tb, args, Q, y_old, aug_tb):
    """Run the backwards RK on just one segment of the spline."""

    def adj_dynamics(aug, t, args, Q, y_old):
        _, y_bar, _ = aug
        y = unravel(eval_spline(ta, tb, Q, y_old, -t))
        _, vjpfun = vjp(fun, -t, y, args)
        return vjpfun(y_bar)

    adj_path = ode.odeint(adj_dynamics, aug_tb, jnp.array([-tb, -ta]),
                          args, Q, y_old, rtol=1e-3, atol=1e-3)
    return tree_map(itemgetter(-1), adj_path)

def make_sim(key):
    if sim in ['string', 'string_ball']:
        x0 = random.normal(key, (n, total_dim))
        x0 = index_update(x0, s_[..., -1], 1)  # constant mass
        x0 = index_update(x0, s_[..., 0], np.arange(n) + x0[..., 0] * 0.5)
        x0 = index_update(x0, s_[..., 2:3], 0.0)
    else:
        x0 = random.normal(key, (n, total_dim))
        # all masses set to positive
        x0 = index_update(x0, s_[..., -1], np.exp(x0[..., -1]))
        if sim in ['charge', 'superposition']:
            # charge is 1 or -1
            x0 = index_update(x0, s_[..., -2], np.sign(x0[..., -2]))
    x_times = odeint(odefunc, x0.reshape(packed_shape), times,
                     mxstep=2000).reshape(-1, *unpacked_shape)
    return x_times