def scale_errors(error_component, tol_scale):
  """Divides the squared magnitude of an error by that of its tolerance.

  Args:
    error_component: Error estimate for one component of the solution.
    tol_scale: Tolerance scale for the same component.

  Returns:
    `tf.divide(abs_square(error_component), abs_square(tol_scale))`, where
    `abs_square` is `rk_util.abs_square`.
  """
  squared_error, squared_tol = (
      rk_util.abs_square(value) for value in (error_component, tol_scale))
  return tf.divide(squared_error, squared_tol)
def test_abs_square(self, dtype):
  """Compares `rk_util.abs_square` with `square(abs(x))` after casting to `dtype`."""
  inputs = tf.cast(np.array([1 + 2j, 0.3 - 1j, 3.5 - 3.7j]), dtype)
  self.assertAllClose(
      rk_util.abs_square(inputs), tf.math.square(tf.abs(inputs)))
def _step(
    self, solver_state, diagnostics, max_ode_fn_evals, ode_fn, atol, rtol,
    safety, ifactor, dfactor):
  """Take an adaptive Runge-Kutta step.

  The state `y` (and its derivative) may be a single `Tensor` or a nested
  structure of `Tensor`s; every per-element operation below is applied leaf by
  leaf so that nested states keep their structure.

  Args:
    solver_state: `_RungeKuttaSolverInternalState` - solver internal state,
      unpacked as (y, dy/dt, <ignored>, current time, step size,
      interpolation coefficients).
    diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
    max_ode_fn_evals: Integer `Tensor` specifying the maximum number of ode_fn
      evaluations.
    ode_fn: Callable(t, y) -> dy_dt.
    atol: Absolute tolerance allowed, see `_solve` method for details.
    rtol: Relative tolerance allowed, see `_solve` method for details.
    safety: Safety factor, see `_solve` method for details.
    ifactor: Maximum factor by which the step size can increase, see `_solve`
      method for details.
    dfactor: Minimum factor by which the step size can decrease, see `_solve`
      method for details.

  Returns:
    solver_state: `_RungeKuttaSolverInternalState` holding new solver state.
      Its third field is the start time `t0` of this step attempt. Note that
      the step might not advance the time if the error tolerance criteria were
      not met. In this case step_size is decreased.
    diagnostics: `_DopriDiagnostics` holding diagnostic values after RK step.
      `num_ode_fn_evaluations` is incremented by `ODE_FN_EVALS_PER_STEP`
      whether or not the step was accepted.
  """
  y0, f0, _, t0, dt, interp_coeff = solver_state
  assertion_ops = []
  # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
  with tf.name_scope('assertions'):
    if self._max_num_steps is not None:
      check_max_num_steps = tf.debugging.assert_less(
          diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
          'max_num_steps exceeded')
      assertion_ops.append(check_max_num_steps)

    if self._validate_args:
      check_underflow = tf.debugging.assert_greater(
          t0 + dt, t0, 'underflow in dt')
      assertion_ops.append(check_underflow)
      # `assert_all_finite` takes a single `Tensor`, so check every leaf of a
      # possibly nested state instead of the whole structure at once.
      for y0_component in tf.nest.flatten(y0):
        assertion_ops.append(tf.debugging.assert_all_finite(
            y0_component, 'non-finite values in solution'))

  with tf.control_dependencies(assertion_ops):
    y1, f1, y1_error, k = rk_util.runge_kutta_step(
        ode_fn, y0, f0, t0, dt, _TABLEAU)

  with tf.name_scope('error_ratio'):
    # We use the same criteria for accepting step as in scipy.
    abs_y0 = tf.nest.map_structure(tf.abs, y0)
    abs_y1 = tf.nest.map_structure(tf.abs, y1)
    max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
    ones_nest = rk_util.nest_constant(abs_y0)
    error_tol = rk_util.weighted_sum([atol, rtol], [ones_nest, max_y_vals])

    def squared_error_ratio(error_component, tol_scale):
      # Applied per leaf so nested states are handled; for a single `Tensor`
      # this is identical to calling `abs_square` on it directly.
      return tf.divide(rk_util.abs_square(error_component),
                       rk_util.abs_square(tol_scale))

    scaled_errors = tf.nest.map_structure(
        squared_error_ratio, y1_error, error_tol)
    error_ratio = rk_util.nest_rms_norm(scaled_errors)
    accept_step = error_ratio <= 1

  with tf.name_scope('update/state'):

    def select(new_value, old_value):
      # Keep `new_value` only if the step was accepted. `tf.where` would try
      # to pack a nest into a single tensor, so it is applied per leaf.
      return tf.nest.map_structure(
          lambda new, old: tf.where(accept_step, new, old),
          new_value, old_value)

    y_next = select(y1, y0)
    f_next = select(f1, f0)
    t_next = tf.where(accept_step, t0 + dt, t0)
    new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
        y0, y1, k, dt, _TABLEAU)
    interp_coeff = [select(new_c, old_c)
                    for new_c, old_c in zip(new_coefficients, interp_coeff)]
    # The step size is updated from the error ratio even for rejected steps.
    dt_next = util.next_step_size(
        dt, self.ORDER, error_ratio, safety, dfactor, ifactor)
    solver_state = _RungeKuttaSolverInternalState(
        y_next, f_next, t0, t_next, dt_next, interp_coeff)

  diagnostics = diagnostics._replace(
      num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
      self.ODE_FN_EVALS_PER_STEP)
  return solver_state, diagnostics