def test_structured_jacobian_mat(self):
  """Checks that a block-structured Jacobian flattens to the dense matrix.

  A three-part state (2 + 1 + 4 = 7 scalars) is pushed through a random 7x7
  linear map applied to its vector form. The Jacobian of that map is taken
  twice from the same persistent tape: once between the vector forms, and
  once block by block between every (output part, input part) pair. The
  blocks are handed to `util.get_jacobian_fn_mat`, and the matrix it returns
  must equal the vector-form Jacobian exactly.
  """
  parts = [
      tf.constant([1., 2.]),
      tf.constant([3.]),
      tf.constant([[4., 5.], [6., 7.]]),
  ]
  part_shapes = [tf.shape(part) for part in parts]
  linear_map = tf.convert_to_tensor(np.random.randn(7, 7), dtype=tf.float32)

  # The tape is persistent because several Jacobians are requested from it.
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(parts)
    flat_input = util.get_state_vec(parts)
    tape.watch(flat_input)
    flat_output = tf.matmul(linear_map, flat_input[..., tf.newaxis])[..., 0]
    output_parts = util.get_state_from_vec(flat_output, part_shapes)

  # Reference: Jacobian between the two vector forms.
  expected_mat = self.evaluate(tape.jacobian(flat_output, flat_input))

  # Structured Jacobian: one row per output part, one column per input part.
  jacobian_blocks = []
  for output_part in output_parts:
    block_row = []
    for input_part in parts:
      block_row.append(self.evaluate(tape.jacobian(output_part, input_part)))
    jacobian_blocks.append(block_row)

  flattened_jacobian_fn = util.get_jacobian_fn_mat(
      jacobian_fn=jacobian_blocks,
      ode_fn_vec=None,
      state_shape=part_shapes,
      use_pfor=False,
      dtype=tf.float32)
  actual_mat = self.evaluate(flattened_jacobian_fn(None))

  self.assertAllEqual(expected_mat, actual_mat)
def test_right_mult_by_jacobian_mat(self, use_automatic_differentiation,
                                    use_pfor):
  """Checks `util.right_mult_by_jacobian_mat` against a NumPy product.

  The ODE right-hand side is the linear map `state -> matrix @ state`. The
  result of `util.right_mult_by_jacobian_mat` must be close to
  `np.dot(vec, matrix)`.

  Args:
    use_automatic_differentiation: If true, `None` is passed as the Jacobian
      to `util.get_jacobian_fn_mat`; otherwise the constant matrix is passed.
    use_pfor: Forwarded unchanged to `util.get_jacobian_fn_mat`.
  """
  left_vec = np.float32([1., 2., 3.])
  system_mat = -np.float32([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
  start_time = np.float32(0.)
  start_state = np.float32([1., 1., 1.])

  def ode_fn(_, state):
    column = state[:, tf.newaxis]
    return tf.squeeze(tf.matmul(system_mat, column))

  shape = tf.shape(start_state)
  ode_fn_vec = util.get_ode_fn_vec(ode_fn, shape)

  if use_automatic_differentiation:
    jacobian_arg = None
  else:
    jacobian_arg = system_mat
  jacobian_fn_mat = util.get_jacobian_fn_mat(jacobian_arg, ode_fn_vec, shape,
                                             use_pfor)

  product = util.right_mult_by_jacobian_mat(jacobian_fn_mat, ode_fn_vec,
                                            start_time, start_state, left_vec)
  expected = np.dot(left_vec, system_mat)
  self.assertAllClose(self.evaluate(product), expected)
def test_structured_jacobian_mat(self):
  """Checks that a block-structured Jacobian flattens to the dense matrix.

  A three-part state (2 + 1 + 4 = 7 scalars) is pushed through a random 7x7
  linear map applied to its vector form. The Jacobian of the vector map is
  compared for exact equality with the matrix that
  `util.get_jacobian_fn_mat` builds from the per-part Jacobian blocks.
  """
  state = [
      tf.constant([1., 2.]),
      tf.constant([3.]),
      tf.constant([[4., 5.], [6., 7.]]),
  ]
  state_shape = [tf.shape(part) for part in state]
  linear_map = tf.convert_to_tensor(np.random.randn(7, 7), dtype=tf.float32)
  num_parts = len(state)

  def apply_to_vec(vec):
    """Applies the linear map to a state in vector form."""
    return tf.matmul(linear_map, vec[..., tf.newaxis])[..., 0]

  def apply_to_structure(parts):
    """Applies the linear map to a structured state, returning a structure."""
    mapped_vec = apply_to_vec(util.get_state_vec(parts))
    return util.get_state_from_vec(mapped_vec, state_shape)

  def jacobian_at(fn, point):
    """Jacobian of `fn` at `point`, via a singleton batch dimension."""

    def batched_fn(batched_point):
      return fn(batched_point[0])[tf.newaxis]

    return tfp_gradient.batch_jacobian(batched_fn, point[tf.newaxis])[0]

  def with_part(parts, index, new_part):
    """Returns a copy of the list `parts` with entry `index` replaced."""
    return parts[:index] + [new_part] + parts[index + 1:]

  def jacobian_block(out_index, in_index):
    """Jacobian of output part `out_index` w.r.t. input part `in_index`."""

    def component_fn(part):
      return apply_to_structure(with_part(state, in_index, part))[out_index]

    return jacobian_at(component_fn, state[in_index])

  flat_state = util.get_state_vec(state)
  jacobian_mat = jacobian_at(apply_to_vec, flat_state)

  # One row per output part, one column per input part.
  jacobian = [
      [jacobian_block(out_index, in_index) for in_index in range(num_parts)]
      for out_index in range(num_parts)
  ]

  jacobian_mat2 = util.get_jacobian_fn_mat(
      jacobian_fn=jacobian,
      ode_fn_vec=None,
      state_shape=state_shape,
      dtype=tf.float32)(None)

  self.assertAllEqual(jacobian_mat, jacobian_mat2)
def _solve(
    self,
    ode_fn,
    initial_time,
    initial_state,
    solution_times,
    jacobian_fn=None,
    jacobian_sparsity=None,
    batch_ndims=None,
    previous_solver_internal_state=None,
):
  """Runs the BDF solver loops and packages the outcome as `base.Results`.

  Args:
    ode_fn: Forwarded to `self._prepare_common_params` and, when no previous
      solver state is supplied, to `self._initialize_solver_internal_state`.
    initial_time: Forwarded to the same two helpers as `ode_fn`.
    initial_state: Forwarded to the same two helpers as `ode_fn`.
    solution_times: Either a `base.ChosenBySolver` instance, in which case
      its `final_time` is the integration end point and every accepted step
      is recorded, or a value that is cast to a tensor of times at which the
      state is recorded.
    jacobian_fn: Passed to `util.get_jacobian_fn_mat`. `None` is rejected
      when the common state dtype is complex.
    jacobian_sparsity: Must be `None`.
    batch_ndims: Must be `None` or `0`.
    previous_solver_internal_state: If not `None`, used as the starting
      solver state instead of calling
      `self._initialize_solver_internal_state`.

  Returns:
    A `base.Results` with fields `times`, `states`, `diagnostics` and
    `solver_internal_state`.

  Raises:
    NotImplementedError: If `jacobian_sparsity` is given, if `batch_ndims`
      is neither `None` nor `0`, or if `jacobian_fn` is `None` while the
      state dtype is complex.
  """
  # This function is comprised of the following sequential stages:
  # (1) Make static assertions.
  # (2) Initialize variables.
  # (3) Make non-static assertions.
  # (4) Solve up to final time.
  # (5) Return `Results` object.
  #
  # The stages can be found in the code by searching for (n) where n=1..5.
  #
  # By static vs. non-static assertions (see stages 1 and 3), we mean
  # assertions that can be made before the graph is run vs. those that can
  # only be made at run time. The latter are constructed as a list of
  # tf.Assert operations by the method `self._assert_ops`.
  #
  # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
  # nested loops, which can be conceptually understood as follows:
  # ```
  # current_time, current_state = initial_time, initial_state
  # order, step_size = 1, first_step_size
  # for solution_time in solution_times:
  #   while current_time < solution_time:
  #     while True:
  #       next_time = current_time + step_size
  #       next_state, error = (
  #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
  #           current_time, current_state, next_time, order))
  #       if error < tolerance:
  #         current_time, current_state = next_time, next_state
  #         order, step_size = (
  #             maybe_update_order_and_step_size(order, step_size))
  #         break
  #       else:
  #         step_size = decrease_step_size(step_size)
  # ```
  # The outermost loop advances the solver to the next `solution_time` (see
  # `advance_to_solution_time`). The middle loop advances the solver by a
  # small timestep (see `step`). The innermost loop determines the size of
  # that timestep (see `maybe_step`).
  #
  # If `solution_times` is specified as
  # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
  # and `solution_time` in the middle loop is replaced by `final_time`.
  #
  # The three nested functions below are closures: they read locals that are
  # only assigned further down in stage (2) (e.g. `jacobian_fn_mat`, `p`,
  # `solution_time_array`, `max_order`). This works because they are not
  # called until stage (4).

  def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state,
                               state_vec_array, time_array):
    """Takes multiple steps to advance time to `solution_times[n]`."""

    # Keep stepping while the target time is not reached and status is 0.
    def step_cond(next_time, diagnostics, iterand, *_):
      return (iterand.time < next_time) & (tf.equal(
          diagnostics.status, 0))

    nth_solution_time = solution_time_array.read(n)
    [
        _, diagnostics, iterand, solver_internal_state, state_vec_array,
        time_array
    ] = tf.while_loop(step_cond, step, [
        nth_solution_time, diagnostics, iterand, solver_internal_state,
        state_vec_array, time_array
    ])
    # Row 0 of the backward differences is what gets recorded as the state.
    state_vec_array = state_vec_array.write(
        n, solver_internal_state.backward_differences[0])
    # The requested time is recorded, not `iterand.time`.
    time_array = time_array.write(n, nth_solution_time)
    return (n + 1, diagnostics, iterand, solver_internal_state,
            state_vec_array, time_array)

  def step(next_time, diagnostics, iterand, solver_internal_state,
           state_vec_array, time_array):
    """Takes a single step."""
    # Clamp the proposed step size so the step does not go past `next_time`.
    distance_to_next_time = next_time - iterand.time
    overstepped = iterand.new_step_size > distance_to_next_time
    iterand = iterand._replace(new_step_size=tf1.where(
        overstepped, distance_to_next_time, iterand.new_step_size),
                               should_update_step_size=overstepped
                               | iterand.should_update_step_size)

    # Eager mode: re-evaluate the Jacobian once at the start of every step.
    if not self._evaluate_jacobian_lazily:
      diagnostics = diagnostics._replace(
          num_jacobian_evaluations=diagnostics.num_jacobian_evaluations + 1)
      iterand = iterand._replace(jacobian_mat=jacobian_fn_mat(
          iterand.time, solver_internal_state.backward_differences[0]),
                                 jacobian_is_up_to_date=True)

    # Retry `maybe_step` until an attempt is accepted or status is nonzero.
    def maybe_step_cond(accepted, diagnostics, *_):
      return tf.logical_not(accepted) & tf.equal(
          diagnostics.status, 0)

    _, diagnostics, iterand, solver_internal_state = tf.while_loop(
        maybe_step_cond, maybe_step,
        [False, diagnostics, iterand, solver_internal_state])

    # When the solver chooses the output times, every step is appended.
    if solution_times_chosen_by_solver:
      state_vec_array = state_vec_array.write(
          state_vec_array.size(),
          solver_internal_state.backward_differences[0])
      time_array = time_array.write(time_array.size(), iterand.time)

    return (next_time, diagnostics, iterand, solver_internal_state,
            state_vec_array, time_array)

  def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
    """Takes a single step only if the outcome has a low enough error."""
    # Unpack the three loop-carried tuples into locals; they are rebuilt at
    # the end of this function. The incoming `accepted` is not read.
    [
        num_jacobian_evaluations, num_matrix_factorizations,
        num_ode_fn_evaluations, status
    ] = diagnostics
    [
        jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
        num_steps_same_size, should_update_jacobian, should_update_step_size,
        time, unitary, upper
    ] = iterand
    [backward_differences, order, step_size] = solver_internal_state

    # Status becomes -1 once `num_steps` equals `max_num_steps`; the loop
    # conditions above stop iterating when status is nonzero.
    if max_num_steps is not None:
      status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

    # Apply a pending step-size change requested by the previous attempt.
    backward_differences = tf1.where(
        should_update_step_size,
        bdf_util.interpolate_backward_differences(
            backward_differences, order, new_step_size / step_size),
        backward_differences)
    step_size = tf1.where(should_update_step_size, new_step_size, step_size)
    should_update_factorization = should_update_step_size
    num_steps_same_size = tf1.where(should_update_step_size, 0,
                                    num_steps_same_size)

    def update_factorization():
      return bdf_util.newton_qr(
          jacobian_mat, newton_coefficients_array.read(order), step_size)

    if self._evaluate_jacobian_lazily:

      def update_jacobian_and_factorization():
        new_jacobian_mat = jacobian_fn_mat(time, backward_differences[0])
        # NOTE(review): `update_factorization` closes over the outer
        # `jacobian_mat`, not `new_jacobian_mat` -- confirm this is intended.
        new_unitary, new_upper = update_factorization()
        return [
            new_jacobian_mat, True, num_jacobian_evaluations + 1,
            new_unitary, new_upper
        ]

      def maybe_update_factorization():
        new_unitary, new_upper = tf.cond(
            should_update_factorization, update_factorization,
            lambda: [unitary, upper])
        return [
            jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
            new_unitary, new_upper
        ]

      [
          jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
          unitary, upper
      ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization,
                  maybe_update_factorization)
    else:
      unitary, upper = update_factorization()
      # NOTE(review): as laid out here the counter is only incremented on
      # the eager path; confirm whether lazy-path factorizations should be
      # counted too.
      num_matrix_factorizations += 1

    tol = p.atol + p.rtol * tf.abs(backward_differences[0])
    newton_tol = newton_tol_factor * tf.norm(tol)

    [
        newton_converged, next_backward_difference, next_state_vec,
        newton_num_iters
    ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                        newton_coefficients_array.read(order), p.ode_fn_vec,
                        order, step_size, time, newton_tol, unitary, upper)
    num_steps += 1
    num_ode_fn_evaluations += newton_num_iters

    # If Newton's method failed and the Jacobian was up to date, decrease the
    # step size.
    newton_failed = tf.logical_not(newton_converged)
    should_update_step_size = newton_failed & jacobian_is_up_to_date
    new_step_size = step_size * tf1.where(should_update_step_size,
                                          newton_step_size_factor, 1.)

    # If Newton's method failed and the Jacobian was NOT up to date, update
    # the Jacobian.
    should_update_jacobian = newton_failed & tf.logical_not(
        jacobian_is_up_to_date)

    # When Newton did not converge the ratio is NaN, so the comparison below
    # is false and the attempt is not accepted.
    error_ratio = tf1.where(
        newton_converged,
        bdf_util.error_ratio(next_backward_difference,
                             error_coefficients_array.read(order), tol),
        np.nan)
    accepted = error_ratio < 1.
    converged_and_rejected = newton_converged & tf.logical_not(
        accepted)

    # If Newton's method converged but the solution was NOT accepted, decrease
    # the step size.
    new_step_size = tf1.where(
        converged_and_rejected,
        util.next_step_size(step_size, order, error_ratio, p.safety_factor,
                            min_step_size_factor, max_step_size_factor),
        new_step_size)
    should_update_step_size = should_update_step_size | converged_and_rejected

    # If Newton's method converged and the solution was accepted, update the
    # matrix of backward differences.
    time = tf1.where(accepted, time + step_size, time)
    backward_differences = tf1.where(
        accepted,
        bdf_util.update_backward_differences(backward_differences,
                                             next_backward_difference,
                                             next_state_vec, order),
        backward_differences)
    # An accepted step moves the state, so the Jacobian is marked stale.
    jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
        accepted)
    num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                    num_steps_same_size)

    # Order and step size are only updated if we have taken strictly more than
    # order + 1 steps of the same size. This is to prevent the order from
    # being throttled.
    should_update_order_and_step_size = accepted & (num_steps_same_size >
                                                    order + 1)
    backward_differences_array = tf.TensorArray(
        backward_differences.dtype,
        size=bdf_util.MAX_ORDER + 3,
        clear_after_read=False,
        element_shape=next_backward_difference.shape).unstack(
            backward_differences)
    new_order = order
    new_error_ratio = error_ratio
    # Try one order lower and one order higher (clipped to [1, max_order]);
    # a candidate replaces the current choice only if its error ratio is
    # lower than the best seen so far.
    for offset in [-1, +1]:
      proposed_order = tf.clip_by_value(order + offset, 1, max_order)
      proposed_error_ratio = bdf_util.error_ratio(
          backward_differences_array.read(proposed_order + 1),
          error_coefficients_array.read(proposed_order), tol)
      proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
      new_order = tf1.where(
          should_update_order_and_step_size & proposed_error_ratio_is_lower,
          proposed_order, new_order)
      new_error_ratio = tf1.where(
          should_update_order_and_step_size & proposed_error_ratio_is_lower,
          proposed_error_ratio, new_error_ratio)
    order = new_order
    error_ratio = new_error_ratio
    new_step_size = tf1.where(
        should_update_order_and_step_size,
        util.next_step_size(step_size, order, error_ratio, p.safety_factor,
                            min_step_size_factor, max_step_size_factor),
        new_step_size)
    should_update_step_size = (should_update_step_size
                               | should_update_order_and_step_size)

    # Repack the loop-carried tuples.
    diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                  num_matrix_factorizations,
                                  num_ode_fn_evaluations, status)
    iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date, new_step_size,
                          num_steps, num_steps_same_size,
                          should_update_jacobian, should_update_step_size,
                          time, unitary, upper)
    solver_internal_state = _BDFSolverInternalState(
        backward_differences, order, step_size)
    return accepted, diagnostics, iterand, solver_internal_state

  # (1) Make static assertions.
  # TODO(b/138304296): Support specifying Jacobian sparsity patterns.
  if jacobian_sparsity is not None:
    raise NotImplementedError(
        'The BDF solver does not support specifying '
        'Jacobian sparsity patterns.')
  if batch_ndims is not None and batch_ndims != 0:
    raise NotImplementedError(
        'The BDF solver does not support batching.')
  solution_times_chosen_by_solver = (isinstance(solution_times,
                                                base.ChosenBySolver))

  with tf.name_scope(self._name):

    # (2) Initialize variables (convert inputs to tensors).
    p = self._prepare_common_params(
        ode_fn=ode_fn,
        initial_state=initial_state,
        initial_time=initial_time,
    )

    if jacobian_fn is None and dtype_util.is_complex(
        p.common_state_dtype):
      raise NotImplementedError(
          'The BDF solver does not support automatic '
          'Jacobian computations for complex dtypes.')

    # Convert everything to operate on a single, concatenated vector form.
    jacobian_fn_mat = util.get_jacobian_fn_mat(
        jacobian_fn,
        p.ode_fn_vec,
        p.state_shape,
        dtype=p.common_state_dtype,
    )

    # Stays 0 when the solver chooses the times; the output arrays below are
    # then created with `dynamic_size=True` and grow as steps are appended.
    num_solution_times = 0
    if solution_times_chosen_by_solver:
      final_time = tf.cast(solution_times.final_time, p.real_dtype)
    else:
      solution_times = tf.cast(solution_times, p.real_dtype)
      final_time = tf.reduce_max(solution_times)
      num_solution_times = tf.size(solution_times)
      solution_time_array = tf.TensorArray(
          solution_times.dtype,
          size=num_solution_times,
          element_shape=[]).unstack(solution_times)
      util.error_if_not_vector(solution_times, 'solution_times')
    min_step_size_factor = tf.convert_to_tensor(
        self._min_step_size_factor, dtype=p.real_dtype)
    max_step_size_factor = tf.convert_to_tensor(
        self._max_step_size_factor, dtype=p.real_dtype)
    max_num_steps = self._max_num_steps
    if max_num_steps is not None:
      max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32)
    max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
    max_num_newton_iters = self._max_num_newton_iters
    if max_num_newton_iters is not None:
      max_num_newton_iters = tf.convert_to_tensor(
          max_num_newton_iters, dtype=tf.int32)
    newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor,
                                             dtype=p.real_dtype)
    newton_step_size_factor = tf.convert_to_tensor(
        self._newton_step_size_factor, dtype=p.real_dtype)
    newton_coefficients, error_coefficients = self._prepare_coefficients(
        p.common_state_dtype)
    # With argument validation on, pin these quantities to scalar shape.
    if self._validate_args:
      final_time = tf.ensure_shape(final_time, [])
      min_step_size_factor = tf.ensure_shape(min_step_size_factor, [])
      max_step_size_factor = tf.ensure_shape(max_step_size_factor, [])
      if max_num_steps is not None:
        max_num_steps = tf.ensure_shape(max_num_steps, [])
      max_order = tf.ensure_shape(max_order, [])
      if max_num_newton_iters is not None:
        max_num_newton_iters = tf.ensure_shape(
            max_num_newton_iters, [])
      newton_tol_factor = tf.ensure_shape(newton_tol_factor, [])
      newton_step_size_factor = tf.ensure_shape(
          newton_step_size_factor, [])
    # The coefficient arrays are read repeatedly at a dynamic `order`, hence
    # `clear_after_read=False`.
    newton_coefficients_array = tf.TensorArray(
        newton_coefficients.dtype,
        size=bdf_util.MAX_ORDER + 1,
        clear_after_read=False,
        element_shape=[]).unstack(newton_coefficients)
    error_coefficients_array = tf.TensorArray(
        error_coefficients.dtype,
        size=bdf_util.MAX_ORDER + 1,
        clear_after_read=False,
        element_shape=[]).unstack(error_coefficients)
    # Resume from a previous solver state if one was supplied.
    solver_internal_state = previous_solver_internal_state
    if solver_internal_state is None:
      solver_internal_state = self._initialize_solver_internal_state(
          ode_fn=ode_fn,
          initial_state=initial_state,
          initial_time=initial_time,
      )
    state_vec_array = tf.TensorArray(
        p.common_state_dtype,
        size=num_solution_times,
        dynamic_size=solution_times_chosen_by_solver,
        element_shape=p.initial_state_vec.shape)
    time_array = tf.TensorArray(
        p.real_dtype,
        size=num_solution_times,
        dynamic_size=solution_times_chosen_by_solver,
        element_shape=tf.TensorShape([]))
    diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                  num_matrix_factorizations=0,
                                  num_ode_fn_evaluations=0,
                                  status=0)
    # The Jacobian and its factors start as zero placeholders;
    # `should_update_jacobian=True` requests an evaluation on the first
    # attempt when the Jacobian is evaluated lazily.
    iterand = _BDFIterand(
        jacobian_mat=tf.zeros([p.num_odes, p.num_odes],
                              dtype=p.common_state_dtype),
        jacobian_is_up_to_date=False,
        new_step_size=solver_internal_state.step_size,
        num_steps=0,
        num_steps_same_size=0,
        should_update_jacobian=True,
        should_update_step_size=False,
        time=p.initial_time,
        unitary=tf.zeros([p.num_odes, p.num_odes],
                         dtype=p.common_state_dtype),
        upper=tf.zeros([p.num_odes, p.num_odes],
                       dtype=p.common_state_dtype),
    )

    # (3) Make non-static assertions.
    with tf.control_dependencies(
        self._assert_ops(
            previous_solver_internal_state=previous_solver_internal_state,
            initial_state_vec=p.initial_state_vec,
            final_time=final_time,
            initial_time=p.initial_time,
            solution_times=solution_times,
            max_num_steps=max_num_steps,
            max_num_newton_iters=max_num_newton_iters,
            atol=p.atol,
            rtol=p.rtol,
            first_step_size=solver_internal_state.step_size,
            safety_factor=p.safety_factor,
            min_step_size_factor=min_step_size_factor,
            max_step_size_factor=max_step_size_factor,
            max_order=max_order,
            newton_tol_factor=newton_tol_factor,
            newton_step_size_factor=newton_step_size_factor,
            solution_times_chosen_by_solver=solution_times_chosen_by_solver,
        )):

      # (4) Solve up to final time.
      if solution_times_chosen_by_solver:

        # Single middle loop: step until `final_time` or a nonzero status.
        def step_cond(next_time, diagnostics, iterand, *_):
          return (iterand.time < next_time) & (tf.equal(
              diagnostics.status, 0))

        [
            _, diagnostics, iterand, solver_internal_state, state_vec_array,
            time_array
        ] = tf.while_loop(step_cond, step, [
            final_time, diagnostics, iterand, solver_internal_state,
            state_vec_array, time_array
        ])

      else:

        # Outer loop: one iteration per requested solution time.
        def advance_to_solution_time_cond(n, diagnostics, *_):
          return (n < num_solution_times) & (tf.equal(
              diagnostics.status, 0))

        [
            _, diagnostics, iterand, solver_internal_state, state_vec_array,
            time_array
        ] = tf.while_loop(
            advance_to_solution_time_cond, advance_to_solution_time, [
                0, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ])

      # (5) Return `Results` object.
      states = util.get_state_from_vec(state_vec_array.stack(),
                                       p.state_shape)
      times = time_array.stack()
      # With user-specified times, set the shapes of `times` and of each
      # state part from `solution_times.shape` and the initial state shapes.
      if not solution_times_chosen_by_solver:
        tensorshape_util.set_shape(times, solution_times.shape)
        tf.nest.map_structure(
            lambda s, ini_s: tensorshape_util.set_shape(  # pylint: disable=g-long-lambda
                s,
                tensorshape_util.concatenate(
                    solution_times.shape, ini_s.shape)),
            states,
            p.initial_state)
      return base.Results(
          times=times,
          states=states,
          diagnostics=diagnostics,
          solver_internal_state=solver_internal_state)
def solve(self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None): """See `tfp.math.ode.Solver.solve`.""" # The `solve` function is comprised of the following sequential stages: # (1) Make static assertions. # (2) Initialize variables. # (3) Make non-static assertions. # (4) Solve up to final time. # (5) Return `Results` object. # # The stages can be found in the code by searching for (n) where n=1..5. # # By static vs. non-static assertions (see stages 1 and 3), we mean # assertions that can be made before the graph is run vs. those that can # only be made at run time. The latter are constructed as a list of # tf.Assert operations by the function `assert_ops` (see below). # # If `solution_times` is specified as a `Tensor`, stage 4 consists of three # nested loops, which can be conceptually understood as follows: # ``` # current_time, current_state = initial_time, initial_state # order, step_size = 1, first_step_size # for solution_time in solution_times: # while current_time < solution_time: # while True: # next_time = current_time + step_size # next_state, error = ( # solve_nonlinear_equation_to_get_approximate_state_at_next_time( # current_time, current_state, next_time, order)) # if error < tolerance: # current_time, current_state = next_time, next_state # order, step_size = ( # maybe_update_order_and_step_size(order, step_size)) # break # else: # step_size = decrease_step_size(step_size) # ``` # The outermost loop advances the solver to the next `solution_time` (see # `advance_to_solution_time`). The middle loop advances the solver by a # small timestep (see `step`). The innermost loop determines the size of # that timestep (see `maybe_step`). # # If `solution_times` is specified as # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped # and `solution_time` in the middle loop is replaced by `final_time`. 
def assert_ops(): """Creates a list of assert operations.""" if not self._validate_args: return [] assert_ops = [] if ((not initial_state_missing) and (previous_solver_internal_state is not None)): assert_initial_state_matches_previous_solver_internal_state = ( tf.assert_near( tf.norm( original_initial_state - previous_solver_internal_state. backward_differences[0], np.inf), 0., message= '`previous_solver_internal_state` does not match ' '`initial_state`.')) assert_ops.append( assert_initial_state_matches_previous_solver_internal_state ) if solution_times_chosen_by_solver: assert_ops.append( util.assert_positive(final_time - initial_time, 'final_time - initial_time')) else: assert_ops += [ util.assert_increasing(solution_times, 'solution_times'), util.assert_nonnegative( solution_times[0] - initial_time, 'solution_times[0] - initial_time'), ] if max_num_steps is not None: assert_ops.append( util.assert_positive(max_num_steps, 'max_num_steps')) if max_num_newton_iters is not None: assert_ops.append( util.assert_positive(max_num_newton_iters, 'max_num_newton_iters')) assert_ops += [ util.assert_positive(rtol, 'rtol'), util.assert_positive(atol, 'atol'), util.assert_positive(first_step_size, 'first_step_size'), util.assert_positive(safety_factor, 'safety_factor'), util.assert_positive(min_step_size_factor, 'min_step_size_factor'), util.assert_positive(max_step_size_factor, 'max_step_size_factor'), tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [ '`max_order` must be between 1 and {}.'.format( bdf_util.MAX_ORDER) ]), util.assert_positive(newton_tol_factor, 'newton_tol_factor'), util.assert_positive(newton_step_size_factor, 'newton_step_size_factor'), ] return assert_ops def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state, states_array, times_array): """Takes multiple steps to advance time to `solution_times[n]`.""" def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( 
diagnostics.status, 0)) solution_times_n = solution_times_array.read(n) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop(step_cond, step, [ solution_times_n, diagnostics, iterand, solver_internal_state, states_array, times_array ]) states_array = states_array.write( n, solver_internal_state.backward_differences[0]) times_array = times_array.write(n, solution_times_n) return (n + 1, diagnostics, iterand, solver_internal_state, states_array, times_array) def step(next_time, diagnostics, iterand, solver_internal_state, states_array, times_array): """Takes a single step.""" distance_to_next_time = next_time - iterand.time overstepped = iterand.new_step_size > distance_to_next_time iterand = iterand._replace(new_step_size=tf.where( overstepped, distance_to_next_time, iterand.new_step_size), should_update_step_size=overstepped | iterand.should_update_step_size) if not self._evaluate_jacobian_lazily: diagnostics = diagnostics._replace( num_jacobian_evaluations=diagnostics. 
num_jacobian_evaluations + 1) iterand = iterand._replace(jacobian=jacobian_fn_mat( iterand.time, solver_internal_state.backward_differences[0]), jacobian_is_up_to_date=True) def maybe_step_cond(accepted, diagnostics, *_): return tf.logical_not(accepted) & tf.equal( diagnostics.status, 0) _, diagnostics, iterand, solver_internal_state = tf.while_loop( maybe_step_cond, maybe_step, [False, diagnostics, iterand, solver_internal_state]) if solution_times_chosen_by_solver: states_array = states_array.write( states_array.size(), solver_internal_state.backward_differences[0]) times_array = times_array.write(times_array.size(), iterand.time) return (next_time, diagnostics, iterand, solver_internal_state, states_array, times_array) def maybe_step(accepted, diagnostics, iterand, solver_internal_state): """Takes a single step only if the outcome has a low enough error.""" [ num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status ] = diagnostics [ jacobian, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper ] = iterand backward_differences, order, state_shape, step_size = solver_internal_state if max_num_steps is not None: status = tf.where(tf.equal(num_steps, max_num_steps), -1, 0) backward_differences = tf.where( should_update_step_size, bdf_util.interpolate_backward_differences( backward_differences, order, new_step_size / step_size), backward_differences) step_size = tf.where(should_update_step_size, new_step_size, step_size) should_update_factorization = should_update_step_size num_steps_same_size = tf.where(should_update_step_size, 0, num_steps_same_size) def update_factorization(): return bdf_util.newton_qr( jacobian, newton_coefficients_array.read(order), step_size) if self._evaluate_jacobian_lazily: def update_jacobian_and_factorization(): new_jacobian = jacobian_fn_mat(time, backward_differences[0]) new_unitary, new_upper = update_factorization() 
return [ new_jacobian, True, num_jacobian_evaluations + 1, new_unitary, new_upper ] def maybe_update_factorization(): new_unitary, new_upper = tf.cond( should_update_factorization, update_factorization, lambda: [unitary, upper]) return [ jacobian, jacobian_is_up_to_date, num_jacobian_evaluations, new_unitary, new_upper ] [ jacobian, jacobian_is_up_to_date, num_jacobian_evaluations, unitary, upper ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization, maybe_update_factorization) else: unitary, upper = update_factorization() num_matrix_factorizations += 1 tol = atol + rtol * tf.abs(backward_differences[0]) newton_tol = newton_tol_factor * tf.norm(tol) [ newton_converged, next_backward_difference, next_state, newton_num_iters ] = bdf_util.newton(backward_differences, max_num_newton_iters, newton_coefficients_array.read(order), ode_fn_vec, order, step_size, time, newton_tol, unitary, upper) num_steps += 1 num_ode_fn_evaluations += newton_num_iters # If Newton's method failed and the Jacobian was up to date, decrease the # step size. newton_failed = tf.logical_not(newton_converged) should_update_step_size = newton_failed & jacobian_is_up_to_date new_step_size = step_size * tf.where(should_update_step_size, newton_step_size_factor, 1.) # If Newton's method failed and the Jacobian was NOT up to date, update # the Jacobian. should_update_jacobian = newton_failed & tf.logical_not( jacobian_is_up_to_date) error_ratio = tf.where( newton_converged, bdf_util.error_ratio(next_backward_difference, error_coefficients_array.read(order), tol), np.nan) accepted = error_ratio < 1. converged_and_rejected = newton_converged & tf.logical_not( accepted) # If Newton's method converged but the solution was NOT accepted, decrease # the step size. 
new_step_size = tf.where( converged_and_rejected, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = should_update_step_size | converged_and_rejected # If Newton's method converged and the solution was accepted, update the # matrix of backward differences. time = tf.where(accepted, time + step_size, time) backward_differences = tf.where( accepted, bdf_util.update_backward_differences(backward_differences, next_backward_difference, next_state, order), backward_differences) jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not( accepted) num_steps_same_size = tf.where(accepted, num_steps_same_size + 1, num_steps_same_size) # Order and step size are only updated if we have taken strictly more than # order + 1 steps of the same size. This is to prevent the order from # being throttled. should_update_order_and_step_size = accepted & (num_steps_same_size > order + 1) backward_differences_array = tf.TensorArray( backward_differences.dtype, size=bdf_util.MAX_ORDER + 3, clear_after_read=False, element_shape=next_backward_difference.get_shape()).unstack( backward_differences) new_order = order new_error_ratio = error_ratio for offset in [-1, +1]: proposed_order = tf.clip_by_value(order + offset, 1, max_order) proposed_error_ratio = bdf_util.error_ratio( backward_differences_array.read(proposed_order + 1), error_coefficients_array.read(proposed_order), tol) proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio new_order = tf.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_order, new_order) new_error_ratio = tf.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_error_ratio, new_error_ratio) order = new_order error_ratio = new_error_ratio new_step_size = tf.where( should_update_order_and_step_size, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, 
max_step_size_factor), new_step_size) should_update_step_size = (should_update_step_size | should_update_order_and_step_size) diagnostics = _BDFDiagnostics(num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status) iterand = _BDFIterand(jacobian, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper) solver_internal_state = _BDFSolverInternalState( backward_differences, order, state_shape, step_size) return accepted, diagnostics, iterand, solver_internal_state # (1) Make static assertions. # TODO(parsiad): Support specifying Jacobian sparsity patterns. if jacobian_sparsity is not None: raise NotImplementedError( 'The BDF solver does not support specifying ' 'Jacobian sparsity patterns.') if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'The BDF solver does not support batching.') solution_times_chosen_by_solver = (isinstance(solution_times, base.ChosenBySolver)) initial_state_missing = initial_state is None if initial_state_missing and previous_solver_internal_state is None: raise ValueError( 'At least one of `initial_state` or `previous_solver_internal_state` ' 'must be specified') with tf.name_scope(self._name): # (2) Initialize variables. original_initial_state = initial_state if previous_solver_internal_state is None: initial_state = tf.convert_to_tensor(initial_state) original_state_shape = tf.shape(initial_state) else: initial_state = previous_solver_internal_state.backward_differences[ 0] original_state_shape = previous_solver_internal_state.state_shape state_dtype = initial_state.dtype util.error_if_not_real_or_complex(initial_state, 'initial_state') # TODO(parsiad): Support complex automatic Jacobians. 
if jacobian_fn is None and state_dtype.is_complex: raise NotImplementedError( 'The BDF solver does not support automatic ' 'Jacobian computations for complex dtypes.') num_odes = tf.size(initial_state) original_state_tensor_shape = initial_state.get_shape() initial_state = tf.reshape(initial_state, [-1]) ode_fn_vec = util.get_ode_fn_vec(ode_fn, original_state_shape) # `real_dtype` is the floating point `dtype` associated with # `initial_state.dtype` (recall that the latter can be complex). real_dtype = tf.abs(initial_state).dtype initial_time = tf.ensure_shape( tf.convert_to_tensor(initial_time, dtype=real_dtype), []) num_solution_times = 0 if solution_times_chosen_by_solver: final_time = solution_times.final_time final_time = tf.ensure_shape( tf.convert_to_tensor(final_time, dtype=real_dtype), []) else: solution_times = tf.convert_to_tensor(solution_times, dtype=real_dtype) num_solution_times = tf.size(solution_times) solution_times_array = tf.TensorArray( solution_times.dtype, size=num_solution_times, element_shape=[]).unstack(solution_times) util.error_if_not_vector(solution_times, 'solution_times') jacobian_fn_mat = util.get_jacobian_fn_mat( jacobian_fn, ode_fn_vec, original_state_shape, use_pfor=self._use_pfor_to_compute_jacobian) rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype) atol = tf.convert_to_tensor(self._atol, dtype=real_dtype) safety_factor = tf.ensure_shape( tf.convert_to_tensor(self._safety_factor, dtype=real_dtype), []) min_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._min_step_size_factor, dtype=real_dtype), []) max_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._max_step_size_factor, dtype=real_dtype), []) max_num_steps = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32) max_num_newton_iters = self._max_num_newton_iters if max_num_newton_iters is not None: 
max_num_newton_iters = tf.convert_to_tensor( max_num_newton_iters, dtype=tf.int32) newton_tol_factor = tf.ensure_shape( tf.convert_to_tensor(self._newton_tol_factor, dtype=real_dtype), []) newton_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._newton_step_size_factor, dtype=real_dtype), []) bdf_coefficients = tf.cast( tf.concat([[0.], tf.convert_to_tensor(self._bdf_coefficients, dtype=real_dtype)], 0), state_dtype) util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients') newton_coefficients = 1. / ( (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS) newton_coefficients_array = tf.TensorArray( newton_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(newton_coefficients) error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / ( bdf_util.ORDERS + 1) error_coefficients_array = tf.TensorArray( error_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(error_coefficients) first_step_size = self._first_step_size if first_step_size is None: first_step_size = bdf_util.first_step_size( atol, error_coefficients_array.read(1), initial_state, initial_time, ode_fn_vec, rtol, safety_factor) elif previous_solver_internal_state is not None: tf.logging.warn( '`first_step_size` is ignored since' '`previous_solver_internal_state` was specified.') first_step_size = tf.convert_to_tensor(first_step_size, dtype=real_dtype) if self._validate_args: if max_num_steps is not None: max_num_steps = tf.ensure_shape(max_num_steps, []) max_order = tf.ensure_shape(max_order, []) if max_num_newton_iters is not None: max_num_newton_iters = tf.ensure_shape( max_num_newton_iters, []) bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6]) first_step_size = tf.ensure_shape(first_step_size, []) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: first_order_backward_difference = ode_fn_vec( initial_time, initial_state) * tf.cast( 
first_step_size, state_dtype) backward_differences = tf.concat([ tf.reshape(initial_state, [1, -1]), first_order_backward_difference[tf.newaxis, :], tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]), dtype=state_dtype), ], 0) solver_internal_state = _BDFSolverInternalState( backward_differences=backward_differences, order=1, state_shape=original_state_shape, step_size=first_step_size) states_array = tf.TensorArray( state_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=initial_state.get_shape()) times_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=tf.TensorShape([])) diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0, num_matrix_factorizations=0, num_ode_fn_evaluations=0, status=0) iterand = _BDFIterand( jacobian=tf.zeros([num_odes, num_odes], dtype=state_dtype), jacobian_is_up_to_date=False, new_step_size=solver_internal_state.step_size, num_steps=0, num_steps_same_size=0, should_update_jacobian=True, should_update_step_size=False, time=initial_time, unitary=tf.zeros([num_odes, num_odes], dtype=state_dtype), upper=tf.zeros([num_odes, num_odes], dtype=state_dtype)) # (3) Make non-static assertions. with tf.control_dependencies(assert_ops()): # (4) Solve up to final time. 
if solution_times_chosen_by_solver: def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop(step_cond, step, [ final_time, diagnostics, iterand, solver_internal_state, states_array, times_array ]) else: def advance_to_solution_time_cond(n, diagnostics, *_): return (n < num_solution_times) & (tf.equal( diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop( advance_to_solution_time_cond, advance_to_solution_time, [ 0, diagnostics, iterand, solver_internal_state, states_array, times_array ]) # (6) Return `Results` object. states = tf.reshape(states_array.stack(), tf.concat([[-1], original_state_shape], 0)) times = times_array.stack() if not solution_times_chosen_by_solver: times.set_shape(solution_times.get_shape()) states.set_shape(solution_times.get_shape().concatenate( original_state_tensor_shape)) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
def grad_fn(*dresults):
  """Computes the gradient w.r.t. `initial_state` by the adjoint method.

  The augmented system `[state, adjoint]` is integrated backwards in time,
  one result-time interval per loop iteration, starting at the last result
  time and ending at `initial_time`. Before each interval the state half is
  reset to the stored forward solution and the upstream gradient for that
  result state is added to the adjoint half.

  Args:
    *dresults: Flat sequence of upstream gradients. It is packed into the
      structure of `results`; only its `states` field is used.

  Returns:
    A `Tensor` with the shape of `initial_state` holding the adjoint at
    `initial_time`, or NaNs of the same shape if any backward solve
    finished with a nonzero status.

  Raises:
    NotImplementedError: If `initial_state` has a complex dtype.
  """
  dresults = tf.nest.pack_sequence_as(results, dresults)
  dstates = dresults.states
  # TODO(b/138304303): Support complex types.
  state_dtype = initial_state.dtype
  if state_dtype.is_complex:
    raise NotImplementedError(
        'The adjoint sensitivity method does not '
        'support complex dtypes.')
  with tf.name_scope('{}Gradients'.format(self._name)):
    state_shape = tf.shape(initial_state)
    state_vec_tensor_shape = tf.reshape(initial_state, [-1]).get_shape()
    num_odes = tf.size(initial_state)
    ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
    real_dtype = tf.abs(initial_state).dtype
    # Prepend the initial time: entry `n >= 1` is the time associated with
    # `results.states[n - 1]`, and entry 0 is where the backward pass ends.
    result_times = tf.concat(
        [[tf.cast(initial_time, real_dtype)], results.times], 0)
    num_result_times = tf.size(result_times)
    # There is one fewer result state (and upstream gradient) than result
    # times because the initial time has no entry in `results.states`.
    num_result_states = num_result_times - 1
    # The XLA compiler does not compile code which slices/indexes using
    # integer `Tensor`s. `TensorArray`s are used to get around this.
    result_time_array = tf.TensorArray(
        results.times.dtype,
        clear_after_read=False,
        size=num_result_times,
        element_shape=[]).unstack(result_times)
    jacobian_fn_mat = util.get_jacobian_fn_mat(
        jacobian_fn,
        ode_fn_vec,
        state_shape,
        use_pfor=self._use_pfor_to_compute_jacobian)
    # Sized to the number of rows actually unstacked (previously one extra,
    # never-written slot was allocated); matches `dstate_vec_array`.
    result_state_vec_array = tf.TensorArray(
        state_dtype,
        size=num_result_states,
        dynamic_size=False,
        element_shape=state_vec_tensor_shape).unstack(
            tf.reshape(results.states, [num_result_states, -1]))
    dstate_vec_array = tf.TensorArray(
        state_dtype,
        size=num_result_states,
        dynamic_size=False,
        element_shape=state_vec_tensor_shape).unstack(
            tf.reshape(dstates, [num_result_states, -1]))
    # The backward pass starts from an all-zero augmented vector; the state
    # half is overwritten with the stored solution on the first iteration.
    terminal_augmented_state_vec = tf.zeros([num_odes * 2],
                                            dtype=state_dtype)

    def augmented_ode_fn_vec(backward_time, augmented_state_vec):
      """Dynamics function for the augmented system."""
      # The ODE solver cannot handle the case initial_time > final_time
      # and hence a change of variables backward_time = -time is used.
      time = -backward_time
      state_vec, adjoint_state_vec = _decompose_augmented(
          augmented_state_vec)
      ode_vec = ode_fn_vec(time, state_vec)
      # The adjoint ODE is
      # adj'(t) = -dot(adj(t).transpose(), jacobian_fn(t, state(t))).
      # The negative sign disappears after the change of variables.
      adjoint_ode_vec = util.right_mult_by_jacobian_mat(
          jacobian_fn_mat, ode_fn_vec, time, state_vec, adjoint_state_vec)
      augmented_ode_vec = _compose_augmented(-ode_vec, adjoint_ode_vec)
      return augmented_ode_vec

    def reverse_to_result_time(n, augmented_state_vec, _):
      """Integrates the augmented system backwards in time."""
      # In the `backward_time = -time` variable the later result time is
      # the lower bound of integration and the earlier one the upper bound.
      lower_bound_of_integration = result_time_array.read(n)
      upper_bound_of_integration = result_time_array.read(n - 1)
      _, adjoint_state_vec = _decompose_augmented(augmented_state_vec)
      adjoint_state_vec.set_shape(state_vec_tensor_shape)
      # Restart from the stored forward solution and add the upstream
      # gradient for this result state to the running adjoint.
      augmented_state_vec = _compose_augmented(
          result_state_vec_array.read(n - 1),
          adjoint_state_vec + dstate_vec_array.read(n - 1))
      # TODO(b/138304303): Allow the user to specify the Hessian of
      # `ode_fn` so that we can get the Jacobian of the adjoint system.
      augmented_results = self._solve(
          augmented_ode_fn_vec,
          -lower_bound_of_integration,
          augmented_state_vec,
          [-upper_bound_of_integration],
          jacobian_fn=None,
          jacobian_sparsity=None,
          batch_ndims=batch_ndims,
          previous_solver_internal_state=None,
      )
      return (n - 1, augmented_results.states[0],
              augmented_results.diagnostics.status)

    # Walk the result times from last to first; stop early as soon as a
    # backward solve reports a nonzero status.
    _, initial_augmented_state_vec, status = tf.while_loop(
        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
        reverse_to_result_time,
        (num_result_times - 1, terminal_augmented_state_vec, 0),
    )
    _, initial_adjoint_state_vec = _decompose_augmented(
        initial_augmented_state_vec)
    on_success = tf.reshape(initial_adjoint_state_vec, state_shape)
    on_failure = np.nan * tf.ones(state_shape, dtype=state_dtype)
    return tf.where(tf.equal(status, 0), on_success, on_failure)