Example #1
0
 def _assert_ops(
     self,
     ode_fn,
     initial_time,
     initial_state,
     solution_times,
     previous_solver_state,
     rtol,
     atol,
     first_step_size,
     safety_factor,
     min_step_size_factor,
     max_step_size_factor,
     max_num_steps,
     solution_times_by_solver
 ):
   """Constructs dynamic assertions that validate input values to `_solve`.

   No assertion is built (and `ode_fn` is not called) unless
   `self._validate_args` is truthy.

   Args:
     ode_fn: Callable invoked once as `ode_fn(initial_time, initial_state)`;
       its result must have the same nested structure as `initial_state`.
     initial_time: Start time; the final time (or the first solution time)
       is checked against it.
     initial_state: Initial state; compared with
       `previous_solver_state.current_state` when a previous state is given.
     solution_times: When `solution_times_by_solver` is truthy, an object
       exposing a `final_time` attribute; otherwise an indexable collection
       of times that is passed to `util.assert_increasing`.
     previous_solver_state: `None`, or an object exposing `current_state`.
     rtol: Asserted to be positive.
     atol: Asserted to be positive.
     first_step_size: Asserted to be positive.
     safety_factor: Asserted to be positive.
     min_step_size_factor: Asserted to be positive.
     max_step_size_factor: Asserted to be positive.
     max_num_steps: `None` (no check), or a value asserted to be positive.
     solution_times_by_solver: Truthy when the solver picks the solution
       times itself, in which case only `final_time` is validated.

   Returns:
     A list of assertion ops; empty when validation is disabled.
   """
   # Validation is opt-in: any falsy `_validate_args` (`False` as well as
   # `None`) must disable every check, not only the `None` case.
   if not self._validate_args:
     return []

   assert_ops = []
   if solution_times_by_solver:
     final_time = solution_times.final_time
     assert_ops.append(
         util.assert_positive(final_time - initial_time,
                              'final_time - initial_time'))
   else:
     assert_ops += [
         util.assert_increasing(solution_times, 'solution_times'),
         util.assert_nonnegative(solution_times[0] - initial_time,
                                 'solution_times[0] - initial_time'),
     ]
   if previous_solver_state is not None:
     state_diff = initial_state - previous_solver_state.current_state
     assert_states_match = assert_util.assert_near(
         tf.norm(state_diff), 0., message='`previous_solver_state` does not '
         'match the `initial_state`.')
     assert_ops.append(assert_states_match)
   # Gate on the value that is actually asserted on, so the check can neither
   # be skipped for a supplied limit nor be handed `None`.
   if max_num_steps is not None:
     assert_ops.append(util.assert_positive(max_num_steps, 'max_num_steps'))
   assert_ops += [
       util.assert_positive(rtol, 'rtol'),
       util.assert_positive(atol, 'atol'),
       util.assert_positive(first_step_size, 'first_step_size'),
       util.assert_positive(safety_factor, 'safety_factor'),
       util.assert_positive(
           min_step_size_factor, 'min_step_size_factor'),
       util.assert_positive(
           max_step_size_factor, 'max_step_size_factor'),
   ]
   # Structure check happens eagerly here rather than as a returned op: the
   # derivative must mirror the nesting of the state it differentiates.
   derivative = ode_fn(initial_time, initial_state)
   tf.nest.assert_same_structure(initial_state, derivative)
   return assert_ops
Example #2
0
 def _assert_ops(
     self,
     previous_solver_internal_state,
     initial_state_vec,
     final_time,
     initial_time,
     solution_times,
     max_num_steps,
     max_num_newton_iters,
     atol,
     rtol,
     first_step_size,
     safety_factor,
     min_step_size_factor,
     max_step_size_factor,
     max_order,
     newton_tol_factor,
     newton_step_size_factor,
     solution_times_chosen_by_solver,
 ):
     """Builds the assertion ops that guard the solver's inputs.

     Returns an empty list when `self._validate_args` is falsy. Otherwise
     the list contains, in order: an optional check that `initial_state_vec`
     matches `previous_solver_internal_state.backward_differences[0]`, a
     check that `final_time - initial_time` is positive, optional checks on
     `solution_times` (skipped when `solution_times_chosen_by_solver` is
     truthy), optional positivity checks on `max_num_steps` and
     `max_num_newton_iters` (skipped when `None`), and finally positivity
     checks on the tolerances and step-size factors together with a range
     check of `max_order` against `bdf_util.MAX_ORDER`.
     """
     if not self._validate_args:
         return []

     checks = []

     if previous_solver_internal_state is not None:
         # Infinity norm of the gap between the supplied initial state and
         # the one recorded in the previous solver state.
         mismatch = tf.norm(
             initial_state_vec -
             previous_solver_internal_state.backward_differences[0],
             np.inf)
         checks.append(
             tf.debugging.assert_near(
                 mismatch,
                 0.,
                 message='`previous_solver_internal_state` does not match '
                 '`initial_state`.'))

     checks.append(
         util.assert_positive(final_time - initial_time,
                              'final_time - initial_time'))

     if not solution_times_chosen_by_solver:
         checks.append(
             util.assert_increasing(solution_times, 'solution_times'))
         checks.append(
             util.assert_nonnegative(solution_times[0] - initial_time,
                                     'solution_times[0] - initial_time'))

     # Iteration limits are optional; `None` means "no limit to validate".
     optional_limits = (
         (max_num_steps, 'max_num_steps'),
         (max_num_newton_iters, 'max_num_newton_iters'),
     )
     for limit, label in optional_limits:
         if limit is not None:
             checks.append(util.assert_positive(limit, label))

     strictly_positive = (
         (rtol, 'rtol'),
         (atol, 'atol'),
         (first_step_size, 'first_step_size'),
         (safety_factor, 'safety_factor'),
         (min_step_size_factor, 'min_step_size_factor'),
         (max_step_size_factor, 'max_step_size_factor'),
     )
     for value, label in strictly_positive:
         checks.append(util.assert_positive(value, label))

     order_in_range = (max_order >= 1) & (max_order <= bdf_util.MAX_ORDER)
     checks.append(
         tf.Assert(order_in_range, [
             '`max_order` must be between 1 and {}.'.format(
                 bdf_util.MAX_ORDER)
         ]))

     newton_settings = (
         (newton_tol_factor, 'newton_tol_factor'),
         (newton_step_size_factor, 'newton_step_size_factor'),
     )
     for value, label in newton_settings:
         checks.append(util.assert_positive(value, label))

     return checks