def _adaptive_heun_step(self, rk_state):
        """Take one adaptive Heun-Euler step and return the updated RK state.

        A trial step is taken with the embedded Heun/Euler tableau. If every
        entry of the error ratio is at most 1 the step is accepted and the state
        advances to ``t0 + dt``; otherwise the solution, derivative, time and
        interpolation coefficients are carried over unchanged. A new step size
        is proposed in both cases.

        Args:
            rk_state: a ``_RungeKuttaState`` unpacked here as
                ``(y, f, t_prev, t, dt, interp_coeff)``.

        Returns:
            A new ``_RungeKuttaState`` whose ``t_prev`` slot holds the incoming
            ``t`` and whose ``dt`` slot holds the proposed next step size.
        """
        y0, f0, _, t0, dt, interp_coeff = rk_state
        # Put the step on the time dtype before any arithmetic with t0, as the
        # Dormand-Prince step does; TF will not mix dtypes in `t0 + dt`.
        dt = tf.cast(dt, t0.dtype)

        ########################################################
        #                      Assertions                      #
        ########################################################
        # `tf.Tensor` has no `.item()`; `.numpy()` lets the message be formatted
        # when the assertion actually fires.
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.numpy())
        for y0_ in y0:
            assert _is_finite(tf.math.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_ADAPTIVE_HEUN_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
        accept_step = tf.reduce_all(tf.convert_to_tensor(mean_sq_error_ratio, dtype=tf.float64) <= 1.)

        ########################################################
        #                   Update RK State                    #
        ########################################################
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        interp_coeff = _interp_fit_adaptive_heun(y0, y1, k, dt) if accept_step else interp_coeff
        # Heun-Euler is a 2nd-order method with an embedded 1st-order estimate,
        # so the controller must be told order 2 (5 is the Dormand-Prince order).
        dt_next = _optimal_step_size(
            dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=2
        )
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
        return rk_state
Exemple #2
0
    def _adaptive_dopri5_step(self, rk_state):
        """Attempt a single adaptive step with the configured Runge-Kutta tableau.

        The trial step is accepted when every entry of the error ratio is at
        most 1. On acceptance the solution, derivative and time advance and new
        interpolation coefficients are fitted; otherwise the previous values
        are carried forward. A new step size is proposed either way.

        Args:
            rk_state: ``_RungeKuttaState`` laid out as
                ``(y, f, t_prev, t, dt, interp_coeff)``.

        Returns:
            The updated ``_RungeKuttaState``; its ``t_prev`` slot holds the
            incoming ``t``.
        """
        y0, f0, _, t0, dt, interp_coeff = rk_state
        # Bring the step onto the time dtype before any arithmetic with t0.
        dt = tf.cast(dt, t0.dtype)

        # --- sanity checks on the incoming state ---------------------------
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.numpy())
        for component in y0:
            assert _is_finite(tf.abs(component)), \
                'non-finite values in state `y`: {}'.format(component)

        # --- trial step ----------------------------------------------------
        y1, f1, y1_error, k = _runge_kutta_step(
            self.func, y0, f0, t0, dt, tableau=self.tableau)

        # --- acceptance test -----------------------------------------------
        error_ratio = _compute_error_ratio(
            y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
        accepted = tf.reduce_all(tf.convert_to_tensor(error_ratio) <= 1)

        # --- assemble the next state ---------------------------------------
        if accepted:
            y_next, f_next, t_next = y1, f1, t0 + dt
            interp_coeff = _interp_fit_dopri5(y0, y1, k, dt)
        else:
            y_next, f_next, t_next = y0, f0, t0

        dt_next = _optimal_step_size(
            dt,
            error_ratio,
            safety=self.safety,
            ifactor=self.ifactor,
            dfactor=self.dfactor,
            order=self.order,
        )
        return _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
Exemple #3
0
    def _adaptive_adams_step(self, vcabm_state, final_t):
        """Attempt one variable-step Adams predictor-corrector step.

        An explicit predictor ``p_next`` is built from the stored phi history
        and corrected once to ``y_next``. The local error is then compared
        against a per-element mixed absolute/relative tolerance.

        - Rejected step: the unchanged history is returned with a smaller
          proposed target time.
        - Accepted step: the new derivative and time are pushed onto the
          history, and the method order may be raised or lowered.

        Args:
            vcabm_state: ``_VCABMState`` unpacked as
                ``(y0, prev_f, prev_t, next_t, prev_phi, order)``.
            final_t: upper bound for the step's target time; ``next_t`` is
                clipped to it.

        Returns:
            A new ``_VCABMState``. On acceptance ``prev_f`` and ``prev_t`` are
            also mutated in place via ``appendleft``.
        """
        y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
        if next_t > final_t:
            next_t = final_t
        dt = next_t - prev_t[0]
        dt_cast = move_to_device(dt, y0[0].device)
        dt_cast = tf.cast(dt_cast, y0[0].dtype)

        def _phi_term(coeff, iphi):
            # TF does not promote dtypes, so every phi term is cast to the step
            # dtype before scaling (the order-selection estimates used to skip it).
            return dt_cast * coeff * tf.cast(iphi, dt_cast.dtype)

        # Explicit predictor step.
        g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
        g = tf.cast(g, dt_cast.dtype)
        n_terms = max(1, order - 1)
        p_next = tuple(
            y0_ + _scaled_dot_product(dt_cast, g[:n_terms], phi_[:n_terms])
            for y0_, phi_ in zip(y0, tuple(zip(*phi))))

        # Update phi to implicit, using f evaluated at the predictor.
        next_t = move_to_device(next_t, p_next[0].device)
        next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), p_next)
        implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)

        # Implicit corrector step.
        y_next = tuple(
            p_next_ + _phi_term(g[order - 1], iphi_)
            for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]))

        # Error estimation. The tolerance must be element-wise:
        # tf.reduce_max over a stacked [|y0|, |y1|] collapsed it to one scalar.
        tolerance = tuple(
            atol_ + rtol_ * tf.maximum(tf.abs(y0_), tf.abs(y1_))
            for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next))
        local_error = tuple(
            _phi_term(g[order] - g[order - 1], iphi_)
            for iphi_ in implicit_phi_p[order])
        error_k = _compute_error_ratio(local_error, tolerance)
        accept_step = tf.reduce_all(tf.convert_to_tensor(error_k) <= 1)

        if not accept_step:
            # Retry with adjusted step size if step is rejected.
            dt_next = _optimal_step_size(dt,
                                         error_k,
                                         self.safety,
                                         self.ifactor,
                                         self.dfactor,
                                         order=order)
            return _VCABMState(y0,
                               prev_f,
                               prev_t,
                               prev_t[0] + dt_next,
                               prev_phi,
                               order=order)

        # We accept the step. Evaluate f at the corrected state and update phi.
        # (next_t was already moved to p_next's device above.)
        next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), y_next)
        implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)
        next_order = order

        if len(prev_t) <= 4 or order < 3:
            next_order = min(order + 1, 3, self.max_order)
        else:
            error_km1 = _compute_error_ratio(
                tuple(_phi_term(g[order - 1] - g[order - 2], iphi_)
                      for iphi_ in implicit_phi_p[order - 1]), tolerance)
            error_km2 = _compute_error_ratio(
                tuple(_phi_term(g[order - 2] - g[order - 3], iphi_)
                      for iphi_ in implicit_phi_p[order - 2]), tolerance)
            if min(error_km1 + error_km2) < max(error_k):
                next_order = order - 1
            elif order < self.max_order:
                error_kp1 = _compute_error_ratio(
                    tuple(_phi_term(gamma_star[order], iphi_)
                          for iphi_ in implicit_phi_p[order]), tolerance)
                if max(error_kp1) < max(error_k):
                    next_order = order + 1

        # Keep step size constant if increasing order. Else use adaptive step size.
        dt_next = dt if next_order > order else _optimal_step_size(
            dt,
            error_k,
            self.safety,
            self.ifactor,
            self.dfactor,
            order=order + 1)

        prev_f.appendleft(next_f0)
        prev_t.appendleft(next_t)
        # NOTE(review): the returned state carries the predictor `p_next`, while
        # `next_f0` pushed onto `prev_f` was evaluated at the corrector `y_next`.
        # Behavior is kept as-is; confirm whether `y_next` was intended here.
        return _VCABMState(p_next,
                           prev_f,
                           prev_t,
                           next_t + dt_next,
                           implicit_phi,
                           order=next_order)