def test_nest_rms_norm_on_nest(self, dtype):
    """Check `rk_util.nest_rms_norm` on a two-element nest against NumPy.

    The nest holds a length-3 vector and a 3x3 matrix. The reference value
    is the norm of all twelve entries stacked into one array, divided by the
    square root of the number of entries.
    """
    del dtype  # not used in this test case.
    vector = np.array([1.4, 2.7, 7.3])
    matrix = 0.3 * np.eye(3, dtype=np.float32) + 0.64 * np.ones((3, 3))
    nest = (tf.convert_to_tensor(vector), tf.convert_to_tensor(matrix))

    # Stack the vector as an extra leading row so both parts share one array.
    stacked = np.concatenate([vector[np.newaxis], matrix])
    rms_reference = np.linalg.norm(stacked) / np.sqrt(stacked.size)

    self.assertAllClose(rms_reference, rk_util.nest_rms_norm(nest))
Ejemplo n.º 2
0
    def _step(self, solver_state, diagnostics, max_ode_fn_evals, ode_fn, atol,
              rtol, safety, ifactor, dfactor):
        """Take an adaptive Runge-Kutta step.

        A candidate step of size `dt` is computed with
        `rk_util.runge_kutta_step`. It is accepted when the error ratio
        derived from the step's error estimate is at most 1; otherwise the
        solution, derivative, time and interpolation coefficients keep their
        previous values. In both cases the next step size comes from
        `util.next_step_size`.

        Args:
          solver_state: Solver internal state. It is unpacked as six fields
            in the same positional layout as the
            `_RungeKuttaSolverInternalState` this method returns.
          diagnostics: `_DopriDiagnostics` - info on the current `_solve`
            call.
          max_ode_fn_evals: Integer `Tensor` specifying the maximum number of
            ode_fn evaluations.
          ode_fn: Callable(t, y) -> dy_dt.
          atol: Absolute tolerance allowed, see `_solve` method for details.
          rtol: Relative tolerance allowed, see `_solve` method for details.
          safety: Safety factor, see `_solve` method for details.
          ifactor: Maximum factor by which the step size can increase, see
            `_solve` method for details.
          dfactor: Minimum factor by which the step size can decrease, see
            `_solve` method for details.

        Returns:
          solver_state: `_RungeKuttaSolverInternalState` holding new solver
            state. Note that the step might not advance the time if the error
            tolerance criteria were not met. In this case step_size is
            decreased.
          diagnostics: `_DopriDiagnostics` holding diagnostic values after RK
            step.
        """
        y0, f0, _, t0, dt, interp_coeff = solver_state
        assertion_ops = []
        # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
        with tf.name_scope('assertions'):
            if self._max_num_steps is not None:
                check_max_num_steps = tf.debugging.assert_less(
                    diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
                    'max_num_steps exceeded')
                assertion_ops.append(check_max_num_steps)

            if self._validate_args:
                check_underflow = tf.debugging.assert_greater(
                    t0 + dt, t0, 'underflow in dt')
                assert_finite = functools.partial(
                    tf.debugging.assert_all_finite,
                    message='non-finite values in solution')
                check_numerics = tf.nest.map_structure(assert_finite, y0)
                assertion_ops.append(check_underflow)
                # `check_numerics` mirrors the nest structure of `y0`, while
                # `tf.control_dependencies` takes a flat list of ops/tensors.
                # Flatten so every per-component check is registered, whether
                # the state is a single tensor or a nested structure.
                assertion_ops.extend(tf.nest.flatten(check_numerics))

        with tf.control_dependencies(assertion_ops):
            y1, f1, y1_error, k = rk_util.runge_kutta_step(
                ode_fn, y0, f0, t0, dt, _TABLEAU)

        with tf.name_scope('error_ratio'):
            # We use the same criteria for accepting step as in scipy.
            abs_y0 = tf.nest.map_structure(tf.abs, y0)
            abs_y1 = tf.nest.map_structure(tf.abs, y1)
            max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
            ones_nest = rk_util.nest_constant(abs_y0)
            # Per-component tolerance: atol * 1 + rtol * max(|y0|, |y1|).
            error_tol = rk_util.weighted_sum([atol, rtol],
                                             [ones_nest, max_y_vals])

            def scale_errors(error_component, tol_scale):
                """Return |error|^2 / |tolerance|^2 for one state component."""
                abs_square_error = rk_util.abs_square(error_component)
                abs_square_tol_scale = rk_util.abs_square(tol_scale)
                return tf.divide(abs_square_error, abs_square_tol_scale)

            scaled_errors = tf.nest.map_structure(scale_errors, y1_error,
                                                  error_tol)
            error_ratio = rk_util.nest_rms_norm(scaled_errors)
            accept_step = error_ratio <= 1

        with tf.name_scope('update/state'):
            # On rejection every quantity falls back to its pre-step value.
            y_next = rk_util.nest_where(accept_step, y1, y0)
            f_next = rk_util.nest_where(accept_step, f1, f0)
            t_next = tf.where(accept_step, t0 + dt, t0)

            new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
                y0, y1, k, dt, _TABLEAU)
            interp_coeff = rk_util.nest_where(accept_step, new_coefficients,
                                              interp_coeff)

            dt_next = util.next_step_size(dt, self.ORDER, error_ratio, safety,
                                          dfactor, ifactor)
            solver_state = _RungeKuttaSolverInternalState(
                y_next, f_next, t0, t_next, dt_next, interp_coeff)

        # The evaluation count grows by a fixed amount per attempted step,
        # whether or not the step was accepted.
        diagnostics = diagnostics._replace(
            num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
            self.ODE_FN_EVALS_PER_STEP)

        return solver_state, diagnostics