Example #1
0
 def scale_errors(error_component, tol_scale):
     """Return the squared-magnitude ratio of an error to its tolerance.

     Args:
       error_component: One component of the local error estimate.
       tol_scale: The matching component of the tolerance scale.

     Returns:
       `rk_util.abs_square(error_component) / rk_util.abs_square(tol_scale)`,
       computed with `tf.divide`.
     """
     return tf.divide(
         rk_util.abs_square(error_component),
         rk_util.abs_square(tol_scale),
     )
 def test_abs_square(self, dtype):
   """`rk_util.abs_square(x)` must agree with `tf.math.square(tf.abs(x))`."""
   values = tf.cast(np.array([1 + 2j, 0.3 - 1j, 3.5 - 3.7j]), dtype)
   self.assertAllClose(
       rk_util.abs_square(values),
       tf.math.square(tf.abs(values)))
Example #3
0
  def _step(
      self,
      solver_state,
      diagnostics,
      max_ode_fn_evals,
      ode_fn,
      atol,
      rtol,
      safety,
      ifactor,
      dfactor
  ):
    """Take an adaptive Runge-Kutta step.

    The state `y` (and its derivative `f`) may be a single `Tensor` or a nested
    structure of `Tensor`s; every element-wise operation below is applied
    component-wise via `tf.nest`.

    Args:
      solver_state: `_RungeKuttaSolverInternalState` - solver internal state.
      diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
      max_ode_fn_evals: Integer `Tensor` specifying the maximum number of ode_fn
        evaluations.
      ode_fn: Callable(t, y) -> dy_dt.
      atol: Absolute tolerance allowed, see `_solve` method for details.
      rtol: Relative tolerance allowed, see `_solve` method for details.
      safety: Safety factor, see `_solve` method for details.
      ifactor: Maximum factor by which the step size can increase, see `_solve`
        method for details.
      dfactor: Minimum factor by which the step size can decrease, see `_solve`
        method for details.

    Returns:
      solver_state: `_RungeKuttaSolverInternalState` holding new solver state.
        Note that the step might not advance the time if the error tolerance
        criteria were not met. In this case step_size is decreased.
      diagnostics: `_DopriDiagnostics` holding diagnostic values after RK step.
    """
    def nest_where(condition, if_true, if_false):
      # `tf.where` operates on tensors, not nested structures, so select
      # component-wise to support nested states.
      return tf.nest.map_structure(
          lambda new, old: tf.where(condition, new, old), if_true, if_false)

    def scale_errors(error_component, tol_component):
      # Squared error over squared tolerance for a single component;
      # `rk_util.abs_square` is applied per tensor, not to the whole nest.
      return tf.divide(rk_util.abs_square(error_component),
                       rk_util.abs_square(tol_component))

    y0, f0, _, t0, dt, interp_coeff = solver_state
    assertion_ops = []
    # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
    with tf.name_scope('assertions'):
      if self._max_num_steps is not None:
        check_max_num_steps = tf.debugging.assert_less(
            diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
            'max_num_steps exceeded')
        assertion_ops.append(check_max_num_steps)

      if self._validate_args:
        check_underflow = tf.debugging.assert_greater(
            t0 + dt, t0, 'underflow in dt')
        assertion_ops.append(check_underflow)
        # `assert_all_finite` checks a single tensor, so check every component
        # of a possibly nested state.
        # NOTE(review): confirm `assert_all_finite` accepts the state dtype if
        # complex-valued states are used.
        assertion_ops.extend(
            tf.debugging.assert_all_finite(
                y0_component, 'non-finite values in solution')
            for y0_component in tf.nest.flatten(y0))

    with tf.control_dependencies(assertion_ops):
      y1, f1, y1_error, k = rk_util.runge_kutta_step(
          ode_fn, y0, f0, t0, dt, _TABLEAU)

    with tf.name_scope('error_ratio'):
      # We use the same step-acceptance criterion as scipy.
      abs_y0 = tf.nest.map_structure(tf.abs, y0)
      abs_y1 = tf.nest.map_structure(tf.abs, y1)
      max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
      ones_nest = rk_util.nest_constant(abs_y0)
      # Presumably atol * 1 + rtol * max(|y0|, |y1|) per component; verify
      # against `rk_util.weighted_sum`.
      error_tol = rk_util.weighted_sum([atol, rtol], [ones_nest, max_y_vals])
      scaled_errors = tf.nest.map_structure(scale_errors, y1_error, error_tol)
      error_ratio = rk_util.nest_rms_norm(scaled_errors)
      accept_step = error_ratio <= 1

    with tf.name_scope('update/state'):
      # On rejection, keep the previous state and time; only dt shrinks.
      y_next = nest_where(accept_step, y1, y0)
      f_next = nest_where(accept_step, f1, f0)
      t_next = tf.where(accept_step, t0 + dt, t0)

      new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
          y0, y1, k, dt, _TABLEAU)
      interp_coeff = [nest_where(accept_step, new_c, old_c)
                      for new_c, old_c in zip(new_coefficients, interp_coeff)]

      dt_next = util.next_step_size(
          dt, self.ORDER, error_ratio, safety, dfactor, ifactor)
      solver_state = _RungeKuttaSolverInternalState(
          y_next, f_next, t0, t_next, dt_next, interp_coeff)

    diagnostics = diagnostics._replace(
        num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
        self.ODE_FN_EVALS_PER_STEP)

    return solver_state, diagnostics