def test_weighted_sum_nested_values(self, dtype):
    del dtype  # not used in this test case.
    weights = [0.5, -0.25, -0.25]
    states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
    weighted_state_sum = rk_util.weighted_sum(weights, states)
    expected_result = (tf.zeros((2, 2)), tf.zeros((2, 2)))
    self.assertAllCloseNested(weighted_state_sum, expected_result)

    weights = [0.5, -0.25, -0.25, 0]
    states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(4)]
    weighted_state_sum = rk_util.weighted_sum(weights, states)
    expected_result = (tf.zeros((2, 2)), tf.zeros((2, 2)))
    self.assertAllCloseNested(weighted_state_sum, expected_result)
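
For context, a minimal sketch of what a nested weighted sum like `rk_util.weighted_sum` could look like; this is an illustrative assumption, not the library implementation (the real helper additionally validates its inputs, as the ValueError tests further below suggest):

import tensorflow as tf

def weighted_sum_sketch(weights, states):
    """Computes sum_k weights[k] * states[k] over matching nested structures."""
    def combine(*leaves):
        # `leaves` holds the k-th tensor of every state; weight and accumulate.
        return tf.add_n([w * leaf for w, leaf in zip(weights, leaves)])
    return tf.nest.map_structure(combine, *states)
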
Example No. 2
 def reverse_to_result_time(n, augmented_state, _):
     """Integrates the augmented system backwards in time."""
     lower_bound_of_integration = result_time_array.read(n)
     upper_bound_of_integration = result_time_array.read(n - 1)
     _, adjoint_state, adjoint_variable_state = augmented_state
     initial_state = _read_solution_components(
         result_state_arrays, input_state_structure, n - 1)
     initial_adjoint = _read_solution_components(
         dresult_state_arrays, input_state_structure, n - 1)
     initial_adjoint_state = rk_util.weighted_sum(
         [1.0, 1.0], [adjoint_state, initial_adjoint])
     initial_augmented_state = (initial_state,
                                initial_adjoint_state,
                                adjoint_variable_state)
     # TODO(b/138304303): Allow the user to specify the Hessian of
     # `ode_fn` so that we can get the Jacobian of the adjoint system.
     # TODO(b/143624114): Support higher order derivatives.
     augmented_results = self._solve(
         ode_fn=augmented_ode_fn,
         initial_time=-lower_bound_of_integration,
         initial_state=initial_augmented_state,
         solution_times=[-upper_bound_of_integration],
         batch_ndims=batch_ndims)
     # Results contain an extra time dimension of size 1; squeeze it.
     select_result = lambda x: tf.squeeze(x, [0])
     result_state = augmented_results.states
     result_state = tf.nest.map_structure(
         select_result, result_state)
     status = augmented_results.diagnostics.status
     return n - 1, result_state, status
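
The snippet above relies on a `_read_solution_components` helper that is not shown; a hypothetical sketch of such a helper, assumed to read one entry from a per-component `TensorArray` and repack the nested state structure (the name and behavior are an assumption, not the actual TFP code):

def _read_solution_components_sketch(solution_arrays, state_structure, index):
    # One TensorArray per flat state component; read entry `index` from each
    # and pack the values back into the original nested state structure.
    flat_values = [array.read(index) for array in solution_arrays]
    return tf.nest.pack_sequence_as(state_structure, flat_values)
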
  def test_weighted_sum_tensor(self, dtype):
    del dtype  # not used in this test case.
    weights = [0.5, -0.25, -0.25]
    states = [tf.eye(2) for _ in range(3)]
    weighted_tensor_sum = rk_util.weighted_sum(weights, states)
    self.assertAllClose(weighted_tensor_sum, tf.zeros((2, 2)))

    weights = [0.5, -0.25, -0.25, 1.0]
    states = [tf.ones(2) for _ in range(4)]
    weighted_tensor_sum = rk_util.weighted_sum(weights, states)
    self.assertAllClose(weighted_tensor_sum, tf.ones(2))

    weights = [0.5, -0.25, -0.25, 0.0]
    states = [tf.eye(2) for _ in range(4)]
    weighted_tensor_sum = rk_util.weighted_sum(weights, states)
    self.assertAllClose(weighted_tensor_sum, tf.zeros((2, 2)))
Example No. 4
 def reverse_to_result_time(n, augmented_state, _):
     """Integrates the augmented system backwards in time."""
     lower_bound_of_integration = result_time_array.read(n)
     upper_bound_of_integration = result_time_array.read(n - 1)
     _, adjoint_state, adjoint_variable_state = augmented_state
     initial_state = _read_solution_components(
         result_state_arrays, input_state_structure, n - 1)
     initial_adjoint = _read_solution_components(
         dresult_state_arrays, input_state_structure, n - 1)
     initial_adjoint_state = rk_util.weighted_sum(
         [1.0, 1.0], [adjoint_state, initial_adjoint])
     initial_augmented_state = (initial_state,
                                initial_adjoint_state,
                                adjoint_variable_state)
     augmented_results = self._solve(
         ode_fn=augmented_ode_fn,
         initial_time=-lower_bound_of_integration,
         initial_state=initial_augmented_state,
         solution_times=[-upper_bound_of_integration],
         batch_ndims=batch_ndims)
     # Results contain an extra time dimension of size 1; squeeze it.
     select_result = lambda x: tf.squeeze(x, [0])
     result_state = augmented_results.states
     result_state = tf.nest.map_structure(
         select_result, result_state)
     status = augmented_results.diagnostics.status
     return n - 1, result_state, status
  def test_weighted_sum_value_errors(self, dtype):
    del dtype  # not used in this test case.
    empty_weights = []
    empty_states = []
    with self.assertRaises(ValueError):
      _ = rk_util.weighted_sum(empty_weights, empty_states)

    wrong_length_weights = [0.5, -0.25, -0.25, 0]
    wrong_length_states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(5)]
    with self.assertRaises(ValueError):
      _ = rk_util.weighted_sum(wrong_length_weights, wrong_length_states)

    weights = [0.5, -0.25, -0.25, 0]
    not_same_structure_states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
    not_same_structure_states.append(tf.eye(2))
    with self.assertRaises(ValueError):
      _ = rk_util.weighted_sum(weights, not_same_structure_states)
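
The `ValueError`s exercised above come from input validation; a hedged sketch of checks of this kind (the function name and messages are illustrative, not the library's):

def _validate_weighted_sum_inputs(weights, states):
    # Empty inputs or a length mismatch cannot define a weighted sum.
    if not weights or not states or len(weights) != len(states):
        raise ValueError('`weights` and `states` must be non-empty sequences '
                         'of equal length.')
    # Every state must share the nested structure of the first one.
    for state in states[1:]:
        tf.nest.assert_same_structure(states[0], state)
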
Example No. 6
                    def grad_fn(state, variables, constants):
                        del variables  # We compute these gradients via the GradientTape
                        # capturing them.
                        derivatives = ode_fn(time, state, **constants)
                        adjoint_no_grad = tf.nest.map_structure(
                            tf.stop_gradient, adjoint_state)
                        negative_derivatives = rk_util.weighted_sum(
                            [-1.0], [derivatives])

                        def dot_prod(tensor_a, tensor_b):
                            return tf.reduce_sum(tensor_a * tensor_b)

                        # See docstring for details.
                        adjoint_dot_derivatives = tf.nest.map_structure(
                            dot_prod, adjoint_no_grad, derivatives)
                        adjoint_dot_derivatives = tf.squeeze(
                            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
                        return adjoint_dot_derivatives, negative_derivatives
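
The `dot_prod` pattern above computes a vector-Jacobian product: differentiating `sum_j(adj_j * ode_fn(t, z)_j)` with the adjoint held constant yields `adj @ jacobian(ode_fn, z)` without materializing the Jacobian. A standalone sketch of the same pattern (toy interface, illustrative name):

import tensorflow as tf

def vjp_sketch(ode_fn, t, state, adjoint):
    # Returns adjoint @ d(ode_fn)/d(state) without forming the full Jacobian.
    with tf.GradientTape() as tape:
        tape.watch(state)
        derivatives = ode_fn(t, state)
        # Stop gradients through the adjoint so only `state` is differentiated.
        scalar = tf.reduce_sum(tf.stop_gradient(adjoint) * derivatives)
    return tape.gradient(scalar, state)
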
Example No. 7
 def make_augmented_state(n, prev_augmented_state):
     """Constructs the augmented state for step `n`."""
     (_, adjoint_state, adjoint_variable_state,
      adjoint_constant_state) = prev_augmented_state
     initial_state = _read_solution_components(
         result_state_arrays,
         input_state_structure,
         n - 1,
     )
     initial_adjoint = _read_solution_components(
         dresult_state_arrays,
         input_state_structure,
         n - 1,
     )
     initial_adjoint_state = rk_util.weighted_sum(
         [1.0, 1.0], [adjoint_state, initial_adjoint])
     augmented_state = (
         initial_state,
         initial_adjoint_state,
         adjoint_variable_state,
         adjoint_constant_state,
     )
     return augmented_state
Example No. 8
def augmented_ode_fn(backward_time, augmented_state):
    """Dynamics function for the augmented system.

    Describes a differential equation that evolves the augmented state
    backwards in time to compute gradients using the adjoint method.
    The augmented state consists of 3 components `(state, adjoint_state,
    vars)`, all evaluated at time `backward_time`:

    state: represents the solution of the user-provided `ode_fn`. The
      structure coincides with the `initial_state`.
    adjoint_state: represents the solution of the adjoint sensitivity
      differential equation as discussed below. Has the same structure
      and shape as `state`.
    vars: represents the solution of the adjoint equation for variable
      gradients. Represented as a `Tuple(Tensor, ...)` with as many
      tensors as there are `variables`.

    The adjoint sensitivity equation describes the gradient of the solution
    with respect to the value of the solution at a previous time t. Its
    dynamics are given by
    d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
    which is computed as:
    d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
    d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
    where in the last line we moved adj(t)_j under the derivative by
    removing the gradient from it.

    The adjoint equation for the gradient with respect to every
    `tf.Variable` theta follows:
    d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
    = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

    Args:
      backward_time: Floating `Tensor` representing the current time.
      augmented_state: `Tuple(state, adjoint_state, variable_grads)`

    Returns:
      negative_derivatives: Structure of `Tensor`s equal to the backwards
        time derivative of the `state` component.
      adjoint_ode: Structure of `Tensor`s equal to the backwards time
        derivative of the `adjoint_state` component.
      adjoint_variables_ode: Structure of `Tensor`s equal to the backwards
        time derivative of the `vars` component.
    """
    # The negative sign disappears after the change of variables.
    # The ODE solver cannot handle the case initial_time > final_time,
    # hence the change of variables backward_time = -time is used.
    time = -backward_time
    state, adjoint_state, _ = augmented_state

    with tf.GradientTape() as tape:
        tape.watch(variables)
        tape.watch(state)
        derivatives = ode_fn(time, state)
        adjoint_no_grad = tf.nest.map_structure(
            tf.stop_gradient, adjoint_state)
        negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

        def dot_prod(tensor_a, tensor_b):
            return tf.reduce_sum(tensor_a * tensor_b)

        # See docstring for details.
        adjoint_dot_derivatives = tf.nest.map_structure(
            dot_prod, adjoint_no_grad, derivatives)
        adjoint_dot_derivatives = tf.squeeze(
            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))

    adjoint_ode, adjoint_variables_ode = tape.gradient(
        adjoint_dot_derivatives, (state, tuple(variables)),
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    return negative_derivatives, adjoint_ode, adjoint_variables_ode
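
For orientation, a small end-to-end sketch of how gradients flow through the solver (assumes TensorFlow Probability is installed; the adjoint machinery above is what makes this gradient computation memory-efficient):

import tensorflow as tf
import tensorflow_probability as tfp

theta = tf.Variable(0.5)

def ode_fn(t, y):
    del t  # The toy dynamics dy/dt = -theta * y do not depend on time.
    return -theta * y

with tf.GradientTape() as tape:
    results = tfp.math.ode.DormandPrince().solve(
        ode_fn, initial_time=0., initial_state=tf.constant(1.),
        solution_times=[1.])
    final_state = results.states[-1]  # Approximately exp(-theta).

# Computed by integrating the adjoint system backwards in time, as in
# `augmented_ode_fn` above; approximately -exp(-theta).
grad = tape.gradient(final_state, theta)
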
 def test_weighted_sum_nested_type(self, dtype):
   del dtype  # not used in this test case.
   weights = [0.5, -0.25, -0.25]
   states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
   weighted_state_sum = rk_util.weighted_sum(weights, states)
   self.assertIsInstance(weighted_state_sum, tuple)
Example No. 10
    def _step(self, solver_state, diagnostics, max_ode_fn_evals, ode_fn, atol,
              rtol, safety, ifactor, dfactor):
        """Take an adaptive Runge-Kutta step.

    Args:
      solver_state: `_DopriSolverInternalState` - solver internal state.
      diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
      max_ode_fn_evals: Integer `Tensor` specifying the maximum number of ode_fn
        evaluations.
      ode_fn: Callable(t, y) -> dy_dt.
      atol: Absolute tolerance allowed, see `_solve` method for details.
      rtol: Relative tolerance allowed, see `_solve` method for details.
      safety: Safety factor, see `_solve` method for details.
      ifactor: Maximum factor by which the step size can increase, see `_solve`
        method for details.
      dfactor: Minimum factor by which the step size can decrease, see `_solve`
        method for details.

    Returns:
      solver_state: `_RungeKuttaSolverInternalState` holding new solver state.
        Note that the step might not advance the time if the error tolerance
        criterias were not met. In this case step_size is decreased.
      diagnostics: `_DopriDiagnostics` holding diagnostic values after RK step.
    """
        y0, f0, _, t0, dt, interp_coeff = solver_state
        assertion_ops = []
        # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
        with tf.name_scope('assertions'):
            if self._max_num_steps is not None:
                check_max_num_steps = tf.debugging.assert_less(
                    diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
                    'max_num_steps exceeded')
                assertion_ops.append(check_max_num_steps)

            if self._validate_args:
                check_underflow = tf.debugging.assert_greater(
                    t0 + dt, t0, 'underflow in dt')
                assert_finite = functools.partial(
                    tf.debugging.assert_all_finite,
                    message='non-finite values in solution')
                check_numerics = tf.nest.map_structure(assert_finite, y0)
                assertion_ops.append(check_underflow)
                assertion_ops.append(check_numerics)

        with tf.control_dependencies(assertion_ops):
            y1, f1, y1_error, k = rk_util.runge_kutta_step(
                ode_fn, y0, f0, t0, dt, _TABLEAU)

        with tf.name_scope('error_ratio'):
            # We use the same criteria for accepting step as in scipy.
            abs_y0 = tf.nest.map_structure(tf.abs, y0)
            abs_y1 = tf.nest.map_structure(tf.abs, y1)
            max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
            ones_nest = rk_util.nest_constant(abs_y0)
            error_tol = rk_util.weighted_sum([atol, rtol],
                                             [ones_nest, max_y_vals])

            def scale_errors(error_component, tol_scale):
                abs_square_error = rk_util.abs_square(error_component)
                abs_square_tol_scale = rk_util.abs_square(tol_scale)
                return tf.divide(abs_square_error, abs_square_tol_scale)

            scaled_errors = tf.nest.map_structure(scale_errors, y1_error,
                                                  error_tol)
            error_ratio = rk_util.nest_rms_norm(scaled_errors)
            accept_step = error_ratio <= 1

        with tf.name_scope('update/state'):
            y_next = rk_util.nest_where(accept_step, y1, y0)
            f_next = rk_util.nest_where(accept_step, f1, f0)
            t_next = tf.where(accept_step, t0 + dt, t0)

            new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
                y0, y1, k, dt, _TABLEAU)
            interp_coeff = rk_util.nest_where(accept_step, new_coefficients,
                                              interp_coeff)

            dt_next = util.next_step_size(dt, self.ORDER, error_ratio, safety,
                                          dfactor, ifactor)
            solver_state = _RungeKuttaSolverInternalState(
                y_next, f_next, t0, t_next, dt_next, interp_coeff)

        diagnostics = diagnostics._replace(
            num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
            self.ODE_FN_EVALS_PER_STEP)

        return solver_state, diagnostics
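
`util.next_step_size` is not shown in this snippet; a hedged sketch of the textbook step-size controller it stands in for (the library's exact exponent and clipping may differ):

def next_step_size_sketch(dt, order, error_ratio, safety, dfactor, ifactor):
    # Classic controller: aim for an error ratio of 1, and clamp the growth
    # factor to [dfactor, ifactor] so the step size cannot change too abruptly.
    factor = safety * tf.math.pow(error_ratio, -1. / order)
    factor = tf.clip_by_value(factor, dfactor, ifactor)
    return dt * factor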