def solve_forward(self, t0, tvals, y0, y_out):
    """Integrate forward from ``t0`` and store the state at each time in ``tvals``.

    The solver is re-initialised at ``t0`` with ``y0`` (``lib.CVodeReInit``
    followed by ``lib.CVodeAdjReInit``) and then advanced with
    ``lib.CVodeF`` in ``lib.CV_NORMAL`` mode to every requested time.

    Parameters
    ----------
    t0
        Initial time.
    tvals
        Iterable of output times, visited in the given order.
    y0
        Initial state; copied into the internal state buffer.
    y_out
        Output array indexed as ``y_out[i, :]``; row ``i`` receives the
        state that belongs to ``tvals[i]``.

    Raises
    ------
    SolverError
        If ``lib.CVodeF`` returns a code other than ``0`` or
        ``lib.CV_TOO_MUCH_WORK``.
    """
    CVodeReInit = lib.CVodeReInit
    CVodeAdjReInit = lib.CVodeAdjReInit
    CVodeF = lib.CVodeF
    ode = self._ode
    TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

    state_data = self._state_buffer.data
    state_c_ptr = self._state_buffer.c_ptr
    state_data[:] = y0

    # Out-parameters handed to CVodeF on every call.
    time_p = ffi.new('double*')
    time_p[0] = t0
    n_check = ffi.new('int*')
    n_check[0] = 0

    check(CVodeReInit(ode, t0, state_c_ptr))
    check(CVodeAdjReInit(ode))

    for i, t in enumerate(tvals):
        if t == t0:
            # No integration is needed for the initial time. The value goes
            # into the row of *this* entry, so a t0 that is not the first
            # element of tvals still gets its own row filled.
            y_out[i, :] = y0
            continue

        # CV_TOO_MUCH_WORK is treated as "not there yet": keep calling
        # until the solver reports anything else.
        retval = TOO_MUCH_WORK
        while retval == TOO_MUCH_WORK:
            retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL, n_check)

        if retval != 0:
            raise SolverError(
                "Bad sundials return code while solving ode: %s (%s)"
                % (ERRORS[retval], retval))
        y_out[i, :] = state_data
def current_stats(self) -> Dict[str, Any]:
    """Return a dict of current solver statistics.

    Only one entry is reported: ``"order"``, the value written by
    ``lib.CVodeGetCurrentOrder`` for ``self.c_ptr``. The return code of
    that call is passed through ``check_code``.
    """
    # Out-parameter for the C call, pre-set to 1.
    order_p = ffi.new('int*', 1)
    return_code = lib.CVodeGetCurrentOrder(self.c_ptr, order_p)
    check_code(return_code)

    stats: Dict[str, Any] = {}
    stats["order"] = order_p[0]
    return stats
def solve(self, t0: float, tvals: np.ndarray, y0: np.ndarray, y_out: np.ndarray,
          forward_sens: Optional[np.ndarray] = None,
          checkpointing: bool = False) -> None:
    """Integrate forward from ``t0`` and store the state at each time in ``tvals``.

    The solver is re-initialised at ``t0`` with ``y0`` (``lib.CVodeReInit``
    followed by ``lib.CVodeAdjReInit``) and then advanced with
    ``lib.CVodeF`` in ``lib.CV_NORMAL`` mode to every requested time.
    After all times were processed ``self.mark_changed(False)`` is called.

    Parameters
    ----------
    t0
        Initial time.
    tvals
        Output times, visited in the given order.
    y0
        Initial state. If its dtype equals ``self._problem.state_dtype`` it
        is reinterpreted as a ``float64`` array with a leading axis of
        length one before being copied into the state buffer.
    y_out
        Output array; row ``i`` receives the state that belongs to
        ``tvals[i]``.
    forward_sens
        Accepted for interface compatibility; not used by this method.
    checkpointing
        Accepted for interface compatibility; not used by this method.

    Raises
    ------
    SolverError
        If ``lib.CVodeF`` returns a code other than ``0`` or
        ``lib.CV_TOO_MUCH_WORK``.
    """
    CVodeReInit = lib.CVodeReInit
    CVodeAdjReInit = lib.CVodeAdjReInit
    CVodeF = lib.CVodeF
    ode = self.c_ptr
    TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

    state_data = self._state_buffer.data
    state_c_ptr = self._state_buffer.c_ptr

    if y0.dtype == self._problem.state_dtype:
        # Presumably a structured state record: view it as raw float64
        # values so it can be assigned to the flat state buffer.
        y0 = y0[None].view(np.float64)
    state_data[:] = y0

    # Out-parameters handed to CVodeF on every call.
    time_p = ffi.new('double*')
    time_p[0] = t0
    n_check = ffi.new('int*')
    n_check[0] = 0

    check(CVodeReInit(ode, t0, state_c_ptr))
    check(CVodeAdjReInit(ode))

    for i, t in enumerate(tvals):
        if t == t0:
            # No integration is needed for the initial time. The value goes
            # into the row of *this* entry, so a t0 that is not the first
            # element of tvals still gets its own row filled.
            y_out[i, :] = y0
            continue

        # CV_TOO_MUCH_WORK is treated as "not there yet": keep calling
        # until the solver reports anything else.
        retval = TOO_MUCH_WORK
        while retval == TOO_MUCH_WORK:
            retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL, n_check)

        if retval != 0:
            raise SolverError(
                "Bad sundials return code while solving ode: %s (%s)"
                % (ERRORS[retval], retval))
        y_out[i, :] = state_data

    self.mark_changed(False)
def solve_forward(self, t0, tvals, y0, y_out, *, max_retries=5):
    """Integrate forward from ``t0`` and store the state at each time in ``tvals``.

    The solver is re-initialised at ``t0`` with ``y0`` (``lib.CVodeReInit``
    followed by ``lib.CVodeAdjReInit``) and then advanced with
    ``lib.CVodeF`` in ``lib.CV_NORMAL`` mode to every requested time.

    Parameters
    ----------
    t0
        Initial time.
    tvals
        Iterable of output times, visited in the given order.
    y0
        Initial state. If its dtype equals ``self._problem.state_dtype`` it
        is reinterpreted as a ``float64`` array with a leading axis of
        length one before being copied into the state buffer.
    y_out
        Output array; row ``i`` receives the state that belongs to
        ``tvals[i]``.
    max_retries
        Maximum number of ``lib.CVodeF`` calls per output time. A call that
        returns ``lib.CV_TOO_MUCH_WORK`` consumes one attempt.

    Raises
    ------
    SolverError
        If ``lib.CVodeF`` returns a code other than ``0`` or
        ``lib.CV_TOO_MUCH_WORK``, or if an output time is not reached
        within ``max_retries`` calls.
    """
    CVodeReInit = lib.CVodeReInit
    CVodeAdjReInit = lib.CVodeAdjReInit
    CVodeF = lib.CVodeF
    ode = self._ode
    TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

    state_data = self._state_buffer.data
    state_c_ptr = self._state_buffer.c_ptr

    if y0.dtype == self._problem.state_dtype:
        # Presumably a structured state record: view it as raw float64
        # values so it can be assigned to the flat state buffer.
        y0 = y0[None].view(np.float64)
    state_data[:] = y0

    # Out-parameters handed to CVodeF on every call.
    time_p = ffi.new('double*')
    time_p[0] = t0
    n_check = ffi.new('int*')
    n_check[0] = 0

    check(CVodeReInit(ode, t0, state_c_ptr))
    check(CVodeAdjReInit(ode))

    for i, t in enumerate(tvals):
        if t == t0:
            # No integration is needed for the initial time. The value goes
            # into the row of *this* entry, so a t0 that is not the first
            # element of tvals still gets its own row filled.
            y_out[i, :] = y0
            continue

        for _ in range(max_retries):
            retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL, n_check)
            if retval == 0:
                # Sanity check: the solver must report the requested time.
                assert time_p[0] == t
                break
            if retval != TOO_MUCH_WORK:
                error = ERRORS[retval]
                raise SolverError(
                    f"Solving ode failed before time={t}: {error} ({retval})"
                )
        else:
            # Every attempt ended with CV_TOO_MUCH_WORK.
            raise SolverError(f"Too many solver retries before time={t}.")

        y_out[i, :] = state_data
def _init_backward(self, checkpoint_n, interpolation):
    """Create and configure the backward problem attached to ``self._ode``.

    The calls are issued in this order: ``CVodeAdjInit``, ``CVodeCreateB``,
    ``CVodeInitB``, ``CVodeSStolerancesB``, dense linear solver and
    Jacobian function, user data, and finally the quadrature setup
    (``CVodeQuadInitB``, ``CVodeQuadSStolerancesB``, ``CVodeSetQuadErrConB``).

    Attributes set on ``self``: ``_odeB``, ``_state_bufferB``,
    ``_jac_funcB``, ``_quad_buffer`` and ``_quad_buffer_out``.

    Parameters
    ----------
    checkpoint_n
        Forwarded unchanged to ``lib.CVodeAdjInit``.
    interpolation
        Forwarded unchanged to ``lib.CVodeAdjInit``.
    """
    ode = self._ode

    check(lib.CVodeAdjInit(ode, checkpoint_n, interpolation))

    # CVodeCreateB writes the identifier of the backward problem into this int.
    index_p = ffi.new('int*')
    check(lib.CVodeCreateB(ode, self._adjoint_solver_type, index_p))
    odeB = self._odeB = index_p[0]

    # State buffer of the backward problem; initial time is 0.
    state_bufferB = sunode.empty_vector(self._problem.n_states)
    self._state_bufferB = state_bufferB
    check(lib.CVodeInitB(ode, odeB, self._adj_rhs.cffi, 0., state_bufferB.c_ptr))

    # TODO (carried over): the tolerances below are hard-coded.
    tolerance = 1e-10
    check(lib.CVodeSStolerancesB(ode, odeB, tolerance, tolerance))

    # Dense linear solver plus the matching adjoint Jacobian callback.
    dense_solver = check(lib.SUNLinSol_Dense(state_bufferB.c_ptr, self._jacB))
    check(lib.CVodeSetLinearSolverB(ode, odeB, dense_solver, self._jacB))
    jac_funcB = self._problem.make_sundials_adjoint_jac_dense()
    self._jac_funcB = jac_funcB
    check(lib.CVodeSetJacFnB(ode, odeB, jac_funcB.cffi))

    # Hand the user-data buffer to the backward problem as a void pointer.
    user_data_buffer = ffi.from_buffer(self._user_data.data)
    user_data_address = ffi.addressof(user_data_buffer)
    user_data_p = ffi.cast('void *', user_data_address)
    check(lib.CVodeSetUserDataB(ode, odeB, user_data_p))

    # Quadrature buffers, one entry per problem parameter.
    n_params = self._problem.n_params
    self._quad_buffer = sunode.empty_vector(n_params)
    self._quad_buffer_out = sunode.empty_vector(n_params)
    check(lib.CVodeQuadInitB(ode, odeB, self._quad_rhs.cffi,
                             self._quad_buffer.c_ptr))
    check(lib.CVodeQuadSStolerancesB(ode, odeB, tolerance, tolerance))
    check(lib.CVodeSetQuadErrConB(ode, odeB, 1))
def solve_backward(self, t0, tend, tvals, grads, grad_out, lamda_out,
                   lamda_all_out=None, quad_all_out=None, max_retries=50):
    """Integrate the backward problem piecewise from ``t0`` to ``tend``.

    The time axis is split into the intervals between consecutive entries
    of ``[t0] + reversed(tvals) + [tend]``. For every interval with
    ``t_lower < t_upper`` the backward problem is re-initialised at
    ``t_upper`` from the internal state and quadrature buffers, advanced
    to ``t_lower`` with ``lib.CVodeB``, and the resulting state and
    quadrature are read back. After each interval the matching entry of
    ``grads`` (taken in reverse order) is subtracted from the state.

    Parameters
    ----------
    t0
        First entry of the time sequence, i.e. where the backward
        integration starts.
    tend
        Last entry of the time sequence.
    tvals
        Intermediate times; must support ``tvals[::-1]``.
    grads
        Values subtracted from the backward state, consumed in reverse
        order (presumably one per entry of ``tvals`` -- confirm with caller).
    grad_out
        Receives the final quadrature output buffer.
    lamda_out
        Receives the final backward state.
    lamda_all_out, quad_all_out
        Optional arrays that receive intermediate states / quadratures in
        row ``-i`` for loop index ``i``.
    max_retries
        Maximum number of ``lib.CVodeB`` calls per interval.

    Raises
    ------
    SolverError
        If ``lib.CVodeB`` returns a code other than ``0`` or
        ``lib.CV_TOO_MUCH_WORK``, or if an interval is not finished within
        ``max_retries`` calls.
    """
    # Local aliases for the C entry points used below.
    CVodeReInitB = lib.CVodeReInitB
    CVodeQuadReInitB = lib.CVodeQuadReInitB
    CVodeGetQuadB = lib.CVodeGetQuadB
    CVodeB = lib.CVodeB
    CVodeGetB = lib.CVodeGetB
    ode = self._ode
    odeB = self._odeB
    TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

    state_data = self._state_bufferB.data
    state_c_ptr = self._state_bufferB.c_ptr
    quad_data = self._quad_buffer.data
    quad_c_ptr = self._quad_buffer.c_ptr
    quad_out_data = self._quad_buffer_out.data
    quad_out_c_ptr = self._quad_buffer_out.c_ptr

    # Start from an all-zero backward state and quadrature.
    state_data[:] = 0
    quad_data[:] = 0
    quad_out_data[:] = 0

    time_p = ffi.new('double*')
    time_p[0] = t0

    # Consecutive pairs (ts[k + 1], ts[k]) form the (t_lower, t_upper)
    # bounds of each backward interval.
    ts = [t0] + list(tvals[::-1]) + [tend]
    t_intervals = zip(ts[1:], ts[:-1])
    # The leading None ends up last after reversal, so the final interval
    # (the one ending at tend) has no value to subtract.
    grads = [None] + list(grads)

    for i, ((t_lower, t_upper), grad) in enumerate(zip(t_intervals, reversed(grads))):
        # Intervals that are empty (or reversed) are not integrated.
        if t_lower < t_upper:
            # Restart at t_upper from the current state / quadrature buffers.
            check(CVodeReInitB(ode, odeB, t_upper, state_c_ptr))
            check(CVodeQuadReInitB(ode, odeB, quad_c_ptr))

            for retry in range(max_retries):
                retval = CVodeB(ode, t_lower, lib.CV_NORMAL)
                if retval == 0:
                    break
                if retval != TOO_MUCH_WORK:
                    error = ERRORS[retval]
                    raise SolverError(
                        f"Solving ode failed between time {t_upper} and "
                        f"{t_lower}: {error} ({retval})")
            else:
                # Every attempt ended with CV_TOO_MUCH_WORK.
                raise SolverError(
                    f"Too many solver retries between time {t_upper} and {t_lower}."
                )

            # Read back state and quadrature; the quadrature is copied into
            # the buffer that seeds the next CVodeQuadReInitB call.
            check(CVodeGetB(ode, odeB, time_p, state_c_ptr))
            check(CVodeGetQuadB(ode, odeB, time_p, quad_out_c_ptr))
            quad_data[:] = quad_out_data[:]
            assert time_p[0] == t_lower, (time_p[0], t_lower)

        if grad is not None:
            state_data[:] -= grad

            # NOTE(review): for i == 0 the index -i is 0, so the first
            # iteration writes row 0 rather than the last row -- confirm
            # this is the intended layout.
            if lamda_all_out is not None:
                lamda_all_out[-i, :] = state_data
            if quad_all_out is not None:
                quad_all_out[-i, :] = quad_data

    grad_out[:] = quad_out_data
    lamda_out[:] = state_data
def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=5):
    """Integrate from ``t0`` and store the state (and sensitivities) per output time.

    The solver is re-initialised at ``t0`` with ``y0`` via
    ``lib.CVodeReInit`` (plus ``lib.CVodeSensReInit`` when
    ``self._compute_sens`` is set) and then advanced with ``lib.CVode`` in
    ``lib.CV_NORMAL`` mode to every requested time.

    Parameters
    ----------
    t0
        Initial time.
    tvals
        Iterable of output times, visited in the given order.
    y0
        Initial state. If its dtype equals ``self._problem.state_dtype`` it
        is reinterpreted as a ``float64`` array with a leading axis of
        length one before being copied into the state buffer.
    y_out
        Output array; row ``i`` receives the state that belongs to
        ``tvals[i]``.
    sens0
        Initial sensitivities, indexed as ``sens0[param, :]``. Required
        when ``self._compute_sens`` is set.
    sens_out
        Output array indexed as ``sens_out[i, param, :]``. Required when
        ``self._compute_sens`` is set.
    max_retries
        Maximum number of ``lib.CVode`` calls per output time. A call that
        returns ``lib.CV_TOO_MUCH_WORK`` consumes one attempt.

    Raises
    ------
    ValueError
        If sensitivities are enabled but ``sens0`` or ``sens_out`` is missing.
    SolverError
        If ``lib.CVode`` returns a code other than ``0`` or
        ``lib.CV_TOO_MUCH_WORK``, or if an output time is not reached
        within ``max_retries`` calls.

    Notes
    -----
    The return code of ``lib.CVodeGetSens`` is passed through ``check``
    like the other library calls, so a failed sensitivity read is reported
    instead of leaving ``sens_out[i]`` unwritten.
    """
    compute_sens = self._compute_sens
    if compute_sens and (sens0 is None or sens_out is None):
        raise ValueError(
            '"sens_out" and "sens0" are required when computin sensitivities.'
        )

    CVodeReInit = lib.CVodeReInit
    CVodeSensReInit = lib.CVodeSensReInit
    CVode = lib.CVode
    CVodeGetSens = lib.CVodeGetSens
    ode = self._ode
    TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK
    n_params = self._problem.n_params

    state_data = self._state_buffer.data
    state_c_ptr = self._state_buffer.c_ptr

    if compute_sens:
        sens_buffer_array = self._sens_buffer_array
        sens_data = tuple(buffer.data for buffer in self._sens_buffers)
        # Load the initial sensitivities, one buffer per parameter.
        for j in range(n_params):
            sens_data[j][:] = sens0[j, :]

    if y0.dtype == self._problem.state_dtype:
        # Presumably a structured state record: view it as raw float64
        # values so it can be assigned to the flat state buffer.
        y0 = y0[None].view(np.float64)
    state_data[:] = y0

    # Out-parameter that receives the time reached by the solver.
    time_p = ffi.new('double*')
    time_p[0] = t0

    check(CVodeReInit(ode, t0, state_c_ptr))
    if compute_sens:
        check(CVodeSensReInit(ode, self._sens_mode, sens_buffer_array))

    for i, t in enumerate(tvals):
        if t == t0:
            # No integration is needed for the initial time. The values go
            # into the row of *this* entry, so a t0 that is not the first
            # element of tvals still gets its own row filled.
            y_out[i, :] = y0
            if compute_sens:
                sens_out[i, :, :] = sens0
            continue

        for _ in range(max_retries):
            retval = CVode(ode, t, state_c_ptr, time_p, lib.CV_NORMAL)
            if retval == 0:
                # Sanity check: the solver must report the requested time.
                assert time_p[0] == t
                break
            if retval != TOO_MUCH_WORK:
                error = ERRORS[retval]
                raise SolverError(
                    f"Solving ode failed before time={t}: {error} ({retval})"
                )
        else:
            # Every attempt ended with CV_TOO_MUCH_WORK.
            raise SolverError(f"Too many solver retries before time={t}.")

        y_out[i, :] = state_data

        if compute_sens:
            # Fail loudly when the sensitivities cannot be read.
            check(CVodeGetSens(ode, time_p, sens_buffer_array))
            for j in range(n_params):
                sens_out[i, j, :] = sens_data[j]