def assert_ops():
    """Creates a list of assert operations.

    Reads all of its inputs from the enclosing scope. When
    `self._validate_args` is falsy, no ops are built and `[]` is returned.
    Otherwise the returned list holds, in this order:
      * a closeness check between `original_initial_state` and the first
        backward difference of `previous_solver_internal_state` (only when an
        initial state was supplied and a previous internal state exists);
      * either a positivity check on `final_time - initial_time` (when the
        solver picks the solution times) or an increasing check on
        `solution_times` plus a non-negativity check on
        `solution_times[0] - initial_time`;
      * positivity checks on `max_num_steps` / `max_num_newton_iters` when
        each is not None;
      * positivity checks on the tolerances and step-size factors, a range
        check `1 <= max_order <= bdf_util.MAX_ORDER`, and positivity checks on
        the two Newton factors.
    """
    if not self._validate_args:
        return []

    checks = []

    # Resuming from an earlier solve: the leading backward difference of the
    # previous internal state must agree (in the infinity norm) with the
    # initial state handed to this call.
    if not initial_state_missing and previous_solver_internal_state is not None:
        leading_difference = (
            previous_solver_internal_state.backward_differences[0])
        mismatch = tf.norm(original_initial_state - leading_difference, np.inf)
        checks.append(
            tf.assert_near(
                mismatch,
                0.,
                message='`previous_solver_internal_state` does not match '
                '`initial_state`.'))

    # Time-axis checks depend on who chooses the output times.
    if solution_times_chosen_by_solver:
        checks.append(
            util.assert_positive(final_time - initial_time,
                                 'final_time - initial_time'))
    else:
        checks.append(util.assert_increasing(solution_times, 'solution_times'))
        checks.append(
            util.assert_nonnegative(solution_times[0] - initial_time,
                                    'solution_times[0] - initial_time'))

    # Optional iteration limits are only checked when provided.
    optional_limits = (
        (max_num_steps, 'max_num_steps'),
        (max_num_newton_iters, 'max_num_newton_iters'),
    )
    for limit, label in optional_limits:
        if limit is not None:
            checks.append(util.assert_positive(limit, label))

    strictly_positive = (
        (rtol, 'rtol'),
        (atol, 'atol'),
        (first_step_size, 'first_step_size'),
        (safety_factor, 'safety_factor'),
        (min_step_size_factor, 'min_step_size_factor'),
        (max_step_size_factor, 'max_step_size_factor'),
    )
    for value, label in strictly_positive:
        checks.append(util.assert_positive(value, label))

    order_in_range = (max_order >= 1) & (max_order <= bdf_util.MAX_ORDER)
    checks.append(
        tf.Assert(order_in_range, [
            '`max_order` must be between 1 and {}.'.format(bdf_util.MAX_ORDER)
        ]))
    checks.append(util.assert_positive(newton_tol_factor, 'newton_tol_factor'))
    checks.append(
        util.assert_positive(newton_step_size_factor,
                             'newton_step_size_factor'))
    return checks
def _assert_ops(
    self,
    ode_fn,
    initial_time,
    initial_state,
    solution_times,
    previous_solver_state,
    rtol,
    atol,
    first_step_size,
    safety_factor,
    min_step_size_factor,
    max_step_size_factor,
    max_num_steps,
    solution_times_by_solver
):
    """Constructs dynamic assertions that validate input values to `_solve`.

    When `self._validate_args` is falsy, nothing is checked, `ode_fn` is not
    called, and an empty list is returned.

    Args:
      ode_fn: Callable invoked once as `ode_fn(initial_time, initial_state)`;
        its result must have the same nested structure as `initial_state`.
      initial_time: Start time of the solve.
      initial_state: State at `initial_time`.
      solution_times: If `solution_times_by_solver` is true, an object whose
        `final_time` attribute must exceed `initial_time`. Otherwise the
        requested output times, which must be increasing and start no earlier
        than `initial_time`.
      previous_solver_state: Optional state from an earlier solve. When not
        None, its `current_state` must be near `initial_state`.
      rtol: Relative tolerance; must be positive.
      atol: Absolute tolerance; must be positive.
      first_step_size: Must be positive.
      safety_factor: Must be positive.
      min_step_size_factor: Must be positive.
      max_step_size_factor: Must be positive.
      max_num_steps: Optional step limit; checked for positivity when not None.
      solution_times_by_solver: Whether the solver chooses the solution times.

    Returns:
      A list of assertion ops (empty when validation is disabled).
    """
    # `_validate_args` is a boolean flag: `False` must disable validation just
    # like `None` would, so test truthiness rather than `is None`.
    if not self._validate_args:
        return []

    assert_ops = []
    if solution_times_by_solver:
        final_time = solution_times.final_time
        assert_ops.append(
            util.assert_positive(final_time - initial_time,
                                 'final_time - initial_time'))
    else:
        assert_ops += [
            util.assert_increasing(solution_times, 'solution_times'),
            util.assert_nonnegative(solution_times[0] - initial_time,
                                    'solution_times[0] - initial_time'),
        ]

    if previous_solver_state is not None:
        # Resuming a solve: the stored state must agree with `initial_state`.
        state_diff = initial_state - previous_solver_state.current_state
        assert_states_match = assert_util.assert_near(
            tf.norm(state_diff), 0., message='`previous_solver_state` does not '
            'match the `initial_state`.')
        assert_ops.append(assert_states_match)

    # Guard on the value actually being asserted. The old guard checked
    # `self._max_num_steps`, which could let `None` reach `assert_positive` or
    # skip the check for an explicitly supplied limit.
    if max_num_steps is not None:
        assert_ops.append(util.assert_positive(max_num_steps, 'max_num_steps'))

    assert_ops += [
        util.assert_positive(rtol, 'rtol'),
        util.assert_positive(atol, 'atol'),
        util.assert_positive(first_step_size, 'first_step_size'),
        util.assert_positive(safety_factor, 'safety_factor'),
        util.assert_positive(
            min_step_size_factor, 'min_step_size_factor'),
        util.assert_positive(
            max_step_size_factor, 'max_step_size_factor'),
    ]

    # Structure check is eager (not an op): it raises immediately if the
    # derivative returned by `ode_fn` is nested differently from the state.
    derivative = ode_fn(initial_time, initial_state)
    tf.nest.assert_same_structure(initial_state, derivative)
    return assert_ops