def test_weighted_sum_nested_values(self, dtype):
  del dtype  # not used in this test case.
  weights = [0.5, -0.25, -0.25]
  states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
  weighted_state_sum = rk_util.weighted_sum(weights, states)
  expected_result = (tf.zeros((2, 2)), tf.zeros((2, 2)))
  self.assertAllCloseNested(weighted_state_sum, expected_result)

  weights = [0.5, -0.25, -0.25, 0]
  states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(4)]
  weighted_state_sum = rk_util.weighted_sum(weights, states)
  expected_result = (tf.zeros((2, 2)), tf.zeros((2, 2)))
  self.assertAllCloseNested(weighted_state_sum, expected_result)
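# A minimal sketch of the behavior the `weighted_sum` tests in this file
# exercise; this is NOT the actual `rk_util.weighted_sum` implementation.
# It assumes `weights` is a list of scalars and `states` a list of matching
# nested structures, and that mismatched lengths or structures raise
# `ValueError`, consistent with the error tests below.
import tensorflow as tf

def _weighted_sum_sketch(weights, states):
  if not weights or len(weights) != len(states):
    raise ValueError('`weights` and `states` must be non-empty and of '
                     'equal length.')
  for state in states[1:]:
    tf.nest.assert_same_structure(states[0], state)

  def combine(*leaf_components):
    # Sum w_i * leaf_i across all states at a given leaf position.
    return tf.add_n([w * c for w, c in zip(weights, leaf_components)])

  return tf.nest.map_structure(combine, *states)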
def reverse_to_result_time(n, augmented_state, _):
  """Integrates the augmented system backwards in time."""
  lower_bound_of_integration = result_time_array.read(n)
  upper_bound_of_integration = result_time_array.read(n - 1)
  _, adjoint_state, adjoint_variable_state = augmented_state
  initial_state = _read_solution_components(
      result_state_arrays, input_state_structure, n - 1)
  initial_adjoint = _read_solution_components(
      dresult_state_arrays, input_state_structure, n - 1)
  initial_adjoint_state = rk_util.weighted_sum(
      [1.0, 1.0], [adjoint_state, initial_adjoint])
  initial_augmented_state = (initial_state, initial_adjoint_state,
                             adjoint_variable_state)
  # TODO(b/138304303): Allow the user to specify the Hessian of
  # `ode_fn` so that we can get the Jacobian of the adjoint system.
  # TODO(b/143624114): Support higher order derivatives.
  augmented_results = self._solve(
      ode_fn=augmented_ode_fn,
      initial_time=-lower_bound_of_integration,
      initial_state=initial_augmented_state,
      solution_times=[-upper_bound_of_integration],
      batch_ndims=batch_ndims)
  # Results added an extra time dim of size 1, squeeze it.
  select_result = lambda x: tf.squeeze(x, [0])
  result_state = augmented_results.states
  result_state = tf.nest.map_structure(select_result, result_state)
  status = augmented_results.diagnostics.status
  return n - 1, result_state, status
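# A minimal, self-contained sketch of the change of variables used above
# (toy dynamics, not part of the solver code): a solver that can only step
# forward in time can still integrate from `t1` back to `t0`, because
# z(s) = y(-s) satisfies dz/ds = -f(-s, z), so running forward in s = -t
# from s = -t1 to s = -t0 retraces the trajectory in reverse.
import math

def _euler_forward(f, t0, y0, t1, num_steps=10000):
  # Plain forward Euler; stands in for any forward-only solver.
  t, y = t0, y0
  dt = (t1 - t0) / num_steps
  for _ in range(num_steps):
    y, t = y + dt * f(t, y), t + dt
  return y

f = lambda t, y: -0.5 * y            # dy/dt = -y / 2 with y(0) = 1.
y_at_t1 = math.exp(-0.5 * 2.0)       # Exact solution at t1 = 2.
# Recover y(0) ~= 1 by negating both the time bounds and the dynamics.
y0_recovered = _euler_forward(lambda s, y: -f(-s, y), -2.0, y_at_t1, 0.0)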
def test_weighted_sum_tensor(self, dtype):
  del dtype  # not used in this test case.
  weights = [0.5, -0.25, -0.25]
  states = [tf.eye(2) for _ in range(3)]
  weighted_tensor_sum = rk_util.weighted_sum(weights, states)
  self.assertAllClose(weighted_tensor_sum, tf.zeros((2, 2)))

  weights = [0.5, -0.25, -0.25, 1.0]
  states = [tf.ones(2) for _ in range(4)]
  weighted_tensor_sum = rk_util.weighted_sum(weights, states)
  self.assertAllClose(weighted_tensor_sum, tf.ones(2))

  weights = [0.5, -0.25, -0.25, 0.0]
  states = [tf.eye(2) for _ in range(4)]
  weighted_tensor_sum = rk_util.weighted_sum(weights, states)
  self.assertAllClose(weighted_tensor_sum, tf.zeros((2, 2)))
def test_weighted_sum_value_errors(self, dtype):
  del dtype  # not used in this test case.
  empty_weights = []
  empty_states = []
  with self.assertRaises(ValueError):
    _ = rk_util.weighted_sum(empty_weights, empty_states)

  wrong_length_weights = [0.5, -0.25, -0.25, 0]
  wrong_length_states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(5)]
  with self.assertRaises(ValueError):
    _ = rk_util.weighted_sum(wrong_length_weights, wrong_length_states)

  weights = [0.5, -0.25, -0.25, 0]
  not_same_structure_states = [
      (tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
  not_same_structure_states.append(tf.eye(2))
  with self.assertRaises(ValueError):
    _ = rk_util.weighted_sum(weights, not_same_structure_states)
def grad_fn(state, variables, constants):
  del variables  # We compute these gradients via the GradientTape
                 # capturing them.
  derivatives = ode_fn(time, state, **constants)
  adjoint_no_grad = tf.nest.map_structure(
      tf.stop_gradient, adjoint_state)
  negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

  def dot_prod(tensor_a, tensor_b):
    return tf.reduce_sum(tensor_a * tensor_b)

  # See docstring for details.
  adjoint_dot_derivatives = tf.nest.map_structure(
      dot_prod, adjoint_no_grad, derivatives)
  adjoint_dot_derivatives = tf.squeeze(
      tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
  return adjoint_dot_derivatives, negative_derivatives
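# A minimal, self-contained sketch of the trick used in `grad_fn` above:
# differentiating sum_j(stop_gradient(adj)_j * f(z)_j) with respect to z
# yields the vector-Jacobian product adj @ jacobian(f, z) without ever
# materializing the Jacobian. The function and values below are toy
# assumptions, not part of the surrounding code.
import tensorflow as tf

z = tf.constant([1.0, 2.0])
adj = tf.constant([3.0, 5.0])
with tf.GradientTape() as tape:
  tape.watch(z)
  f_z = tf.stack([z[0] * z[1], z[0] ** 2])  # f(z) = (z0*z1, z0^2).
  inner = tf.reduce_sum(tf.stop_gradient(adj) * f_z)
vjp = tape.gradient(inner, z)
# Jacobian of f at z is [[z1, z0], [2*z0, 0]] = [[2, 1], [2, 0]], so
# adj @ J = [3*2 + 5*2, 3*1 + 5*0] = [16, 3]; `vjp` matches this.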
def make_augmented_state(n, prev_augmented_state):
  """Constructs the augmented state for step `n`."""
  (_, adjoint_state, adjoint_variable_state,
   adjoint_constant_state) = prev_augmented_state
  initial_state = _read_solution_components(
      result_state_arrays,
      input_state_structure,
      n - 1,
  )
  initial_adjoint = _read_solution_components(
      dresult_state_arrays,
      input_state_structure,
      n - 1,
  )
  initial_adjoint_state = rk_util.weighted_sum(
      [1.0, 1.0], [adjoint_state, initial_adjoint])
  augmented_state = (
      initial_state,
      initial_adjoint_state,
      adjoint_variable_state,
      adjoint_constant_state,
  )
  return augmented_state
def augmented_ode_fn(backward_time, augmented_state):
  """Dynamics function for the augmented system.

  Describes a differential equation that evolves the augmented state
  backwards in time to compute gradients using the adjoint method.
  Augmented state consists of 3 components `(state, adjoint_state,
  vars)`, all evaluated at time `backward_time`:

  state: represents the solution of the user provided `ode_fn`. The
    structure coincides with the `initial_state`.
  adjoint_state: represents the solution of the adjoint sensitivity
    differential equation as discussed below. Has the same structure
    and shape as `state`.
  vars: represents the solution of the adjoint equation for variable
    gradients. Represented as a `Tuple(Tensor, ...)` with as many
    tensors as there are `variables`.

  The adjoint sensitivity equation describes the gradient of the
  solution with respect to the value of the solution at a previous
  time t. Its dynamics are given by

  d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)

  which is computed as:

  d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
  d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

  where in the last line we moved adj(t)_j under the derivative by
  removing the gradient from it.

  The adjoint equation for the gradient with respect to every
  `tf.Variable` theta follows:

  d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
  = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

  Args:
    backward_time: Floating `Tensor` representing current time.
    augmented_state: `Tuple(state, adjoint_state, variable_grads)`

  Returns:
    negative_derivatives: Structure of `Tensor`s equal to backwards
      time derivative of the `state` component.
    adjoint_ode: Structure of `Tensor`s equal to backwards time
      derivative of the `adjoint_state` component.
    adjoint_variables_ode: Structure of `Tensor`s equal to backwards
      time derivative of the `vars` component.
  """
  # The negative sign disappears after the change of variables.
  # The ODE solver cannot handle the case initial_time > final_time
  # and hence a change of variables backward_time = -time is used.
  time = -backward_time
  state, adjoint_state, _ = augmented_state
  with tf.GradientTape() as tape:
    tape.watch(variables)
    tape.watch(state)
    derivatives = ode_fn(time, state)
    adjoint_no_grad = tf.nest.map_structure(
        tf.stop_gradient, adjoint_state)
    negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

    def dot_prod(tensor_a, tensor_b):
      return tf.reduce_sum(tensor_a * tensor_b)

    # See docstring for details.
    adjoint_dot_derivatives = tf.nest.map_structure(
        dot_prod, adjoint_no_grad, derivatives)
    adjoint_dot_derivatives = tf.squeeze(
        tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
  adjoint_ode, adjoint_variables_ode = tape.gradient(
      adjoint_dot_derivatives, (state, tuple(variables)),
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  return negative_derivatives, adjoint_ode, adjoint_variables_ode
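# A minimal, self-contained check of the adjoint dynamics described in the
# docstring above, under an assumed toy system (not part of the solver
# code): for the linear ODE dz/dt = A @ z we have jacobian(ode_fn, z) = A,
# so the adjoint should satisfy d/dt[adj] = -adj @ A = -(A^T @ adj).
import tensorflow as tf

A = tf.constant([[0.0, 1.0], [-2.0, -3.0]])
linear_ode_fn = lambda t, z: tf.linalg.matvec(A, z)

z = tf.constant([1.0, -1.0])
adj = tf.constant([0.5, 2.0])
with tf.GradientTape() as tape:
  tape.watch(z)
  inner = tf.reduce_sum(tf.stop_gradient(adj) * linear_ode_fn(0.0, z))
adjoint_ode = -tape.gradient(inner, z)
# `adjoint_ode` equals -tf.linalg.matvec(A, adj, transpose_a=True). In the
# solver above, the change of variables to backward time absorbs the minus
# sign, which is why `tape.gradient` is returned without negation there.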
def test_weighted_sum_nested_type(self, dtype):
  del dtype  # not used in this test case.
  weights = [0.5, -0.25, -0.25]
  states = [(tf.eye(2), tf.ones((2, 2))) for _ in range(3)]
  weighted_state_sum = rk_util.weighted_sum(weights, states)
  self.assertIsInstance(weighted_state_sum, tuple)
def _step(self, solver_state, diagnostics, max_ode_fn_evals, ode_fn, atol,
          rtol, safety, ifactor, dfactor):
  """Take an adaptive Runge-Kutta step.

  Args:
    solver_state: `_DopriSolverInternalState` - solver internal state.
    diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
    max_ode_fn_evals: Integer `Tensor` specifying the maximum number of
      ode_fn evaluations.
    ode_fn: Callable(t, y) -> dy_dt.
    atol: Absolute tolerance allowed, see `_solve` method for details.
    rtol: Relative tolerance allowed, see `_solve` method for details.
    safety: Safety factor, see `_solve` method for details.
    ifactor: Maximum factor by which the step size can increase, see
      `_solve` method for details.
    dfactor: Minimum factor by which the step size can decrease, see
      `_solve` method for details.

  Returns:
    solver_state: `_RungeKuttaSolverInternalState` holding new solver
      state. Note that the step might not advance the time if the error
      tolerance criteria were not met. In this case step_size is
      decreased.
    diagnostics: `_DopriDiagnostics` holding diagnostic values after RK
      step.
  """
  y0, f0, _, t0, dt, interp_coeff = solver_state
  assertion_ops = []
  # TODO(dkochkov) Profile performance impact of `control_dependencies` here.
  with tf.name_scope('assertions'):
    if self._max_num_steps is not None:
      check_max_num_steps = tf.debugging.assert_less(
          diagnostics.num_ode_fn_evaluations, max_ode_fn_evals,
          'max_num_steps exceeded')
      assertion_ops.append(check_max_num_steps)

    if self._validate_args:
      check_underflow = tf.debugging.assert_greater(
          t0 + dt, t0, 'underflow in dt')
      assert_finite = functools.partial(
          tf.debugging.assert_all_finite,
          message='non-finite values in solution')
      check_numerics = tf.nest.map_structure(assert_finite, y0)
      assertion_ops.append(check_underflow)
      assertion_ops.append(check_numerics)

  with tf.control_dependencies(assertion_ops):
    y1, f1, y1_error, k = rk_util.runge_kutta_step(
        ode_fn, y0, f0, t0, dt, _TABLEAU)

  with tf.name_scope('error_ratio'):
    # We use the same criteria for accepting the step as in scipy.
    abs_y0 = tf.nest.map_structure(tf.abs, y0)
    abs_y1 = tf.nest.map_structure(tf.abs, y1)
    max_y_vals = tf.nest.map_structure(tf.math.maximum, abs_y0, abs_y1)
    ones_nest = rk_util.nest_constant(abs_y0)
    error_tol = rk_util.weighted_sum([atol, rtol], [ones_nest, max_y_vals])

    def scale_errors(error_component, tol_scale):
      abs_square_error = rk_util.abs_square(error_component)
      abs_square_tol_scale = rk_util.abs_square(tol_scale)
      return tf.divide(abs_square_error, abs_square_tol_scale)

    scaled_errors = tf.nest.map_structure(scale_errors, y1_error, error_tol)
    error_ratio = rk_util.nest_rms_norm(scaled_errors)
    accept_step = error_ratio <= 1

  with tf.name_scope('update/state'):
    y_next = rk_util.nest_where(accept_step, y1, y0)
    f_next = rk_util.nest_where(accept_step, f1, f0)
    t_next = tf.where(accept_step, t0 + dt, t0)
    new_coefficients = rk_util.rk_fourth_order_interpolation_coefficients(
        y0, y1, k, dt, _TABLEAU)
    interp_coeff = rk_util.nest_where(
        accept_step, new_coefficients, interp_coeff)
    dt_next = util.next_step_size(
        dt, self.ORDER, error_ratio, safety, dfactor, ifactor)
    solver_state = _RungeKuttaSolverInternalState(
        y_next, f_next, t0, t_next, dt_next, interp_coeff)

  diagnostics = diagnostics._replace(
      num_ode_fn_evaluations=diagnostics.num_ode_fn_evaluations +
      self.ODE_FN_EVALS_PER_STEP)
  return solver_state, diagnostics
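# A minimal, self-contained sketch of the acceptance test and step-size
# control above on plain tensors, with an assumed standard controller; the
# real `util.next_step_size` and `rk_util` helpers may differ in details.
import tensorflow as tf

def _toy_error_ratio(y0, y1, y1_error, atol, rtol):
  # Per-component tolerance: atol + rtol * max(|y0|, |y1|); the step is
  # accepted when the RMS of the scaled error is at most 1.
  error_tol = atol + rtol * tf.maximum(tf.abs(y0), tf.abs(y1))
  return tf.sqrt(tf.reduce_mean(tf.square(y1_error / error_tol)))

def _toy_next_step_size(dt, order, error_ratio, safety, dfactor, ifactor):
  # Grow or shrink dt by error_ratio**(-1/order), damped by the safety
  # factor and clamped to [dfactor, ifactor].
  error_ratio = tf.maximum(error_ratio, 1e-10)  # Guard against 0.
  factor = safety * error_ratio ** (-1.0 / order)
  return dt * tf.clip_by_value(factor, dfactor, ifactor)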