def _adaptive_heun_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" y0, f0, _, t0, dt, interp_coeff = rk_state ######################################################## # Assertions # ######################################################## assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) for y0_ in y0: assert _is_finite(tf.math.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_ADAPTIVE_HEUN_TABLEAU) ######################################################## # Error Ratio # ######################################################## # print("y error", y1_error[0].numpy()) mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) # print("mean sq error ratio", mean_sq_error_ratio[0].numpy()) accept_step = tf.reduce_all(tf.convert_to_tensor(mean_sq_error_ratio, dtype=tf.float64) <= 1.) ######################################################## # Update RK State # ######################################################## y_next = y1 if accept_step else y0 f_next = f1 if accept_step else f0 t_next = t0 + dt if accept_step else t0 interp_coeff = _interp_fit_adaptive_heun(y0, y1, k, dt) if accept_step else interp_coeff dt_next = _optimal_step_size( dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5 ) rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) return rk_state
def _adaptive_tsit5_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" y0, f0, _, t0, dt, _ = rk_state ######################################################## # Assertions # ######################################################## assert t0 + dt > t0, 'underflow in dt {}'.format(dt.numpy()) for y0_ in y0: assert _is_finite( tf.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU) ######################################################## # Error Ratio # ######################################################## error_tol = tuple( self.atol + self.rtol * tf.reduce_max([tf.abs(y0_), tf.abs(y1_)]) for y0_, y1_ in zip(y0, y1)) tensor_error_ratio = tuple( y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol)) sq_error_ratio = tuple( tf.multiply(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio) mean_error_ratio = ( sum( tf.reduce_sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) / sum(_numel(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio)) accept_step = mean_error_ratio <= 1. ######################################################## # Update RK State # ######################################################## y_next = y1 if accept_step else y0 f_next = f1 if accept_step else f0 t_next = t0 + dt if accept_step else t0 dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor, order=self.order) k_next = k if accept_step else self.rk_state.interp_coeff rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next) return rk_state
def _adaptive_dopri5_step(self, rk_state): """Take an adaptive Runge-Kutta step to integrate the ODE.""" y0, f0, _, t0, dt, interp_coeff = rk_state ######################################################## # Assertions # ######################################################## dt = tf.cast(dt, t0.dtype) assert t0 + dt > t0, 'underflow in dt {}'.format(dt.numpy()) for y0_ in y0: assert _is_finite( tf.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=self.tableau) ######################################################## # Error Ratio # ######################################################## mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) accept_step = tf.reduce_all( tf.convert_to_tensor(mean_sq_error_ratio) <= 1) ######################################################## # Update RK State # ######################################################## y_next = y1 if accept_step else y0 f_next = f1 if accept_step else f0 t_next = t0 + dt if accept_step else t0 interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff dt_next = _optimal_step_size(dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=self.order) rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) return rk_state