def _body(current_state):
  """Main optimization loop."""
  search_direction = _get_search_direction(current_state)

  # TODO(b/120134934): Check if the derivative at the start point is not
  # negative, if so then reset position/gradient deltas and recompute
  # search direction.
  next_state = bfgs_utils.line_search_step(
      current_state, value_and_gradients_function, search_direction,
      tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
      max_line_search_iterations)

  # If not failed or converged, update the Hessian estimate.
  should_update = ~(next_state.converged | next_state.failed)
  state_after_inv_hessian_update = bfgs_utils.update_fields(
      next_state,
      position_deltas=_queue_push(
          current_state.position_deltas, should_update,
          next_state.position - current_state.position),
      gradient_deltas=_queue_push(
          current_state.gradient_deltas, should_update,
          next_state.objective_gradient - current_state.objective_gradient))
  return [state_after_inv_hessian_update]
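# A minimal sketch of what `_queue_push` might look like for the call sites
# above (an assumption, not the actual bfgs_utils helper): it maintains a
# fixed-length FIFO window of L-BFGS correction vectors, assuming an
# unbatched `queue` of shape [window_size, n] and a scalar boolean
# `should_update`; the real helper also has to handle batching.
def _queue_push(queue, should_update, new_vec):
  # Drop the oldest correction vector and append the newest at the end.
  updated = tf.concat([queue[1:], new_vec[tf.newaxis, ...]], axis=0)
  # Leave the queue untouched for converged or failed states.
  return tf.where(should_update, updated, queue)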
def _body(state):
  """Main optimization loop."""
  search_direction = _get_search_direction(
      state.inverse_hessian_estimate, state.objective_gradient)
  derivative_at_start_pt = tf.reduce_sum(
      state.objective_gradient * search_direction)

  # If the derivative at the start point is not negative, reset the
  # Hessian estimate and recompute the search direction.
  needs_reset = derivative_at_start_pt >= 0

  def _reset_search_dirn():
    search_direction = _get_search_direction(
        initial_inv_hessian, state.objective_gradient)
    return search_direction, initial_inv_hessian

  search_direction, inv_hessian_estimate = tf.contrib.framework.smart_cond(
      needs_reset,
      true_fn=_reset_search_dirn,
      false_fn=lambda: (search_direction, state.inverse_hessian_estimate))

  # Replace the Hessian estimate in the state, in case it had to be reset.
  current_state = bfgs_utils.update_fields(
      state, inverse_hessian_estimate=inv_hessian_estimate)

  next_state = bfgs_utils.line_search_step(
      current_state, value_and_gradients_function, search_direction,
      tolerance, f_relative_tolerance, x_tolerance)

  # If not failed or converged, update the Hessian.
  state_after_inv_hessian_update = tf.contrib.framework.smart_cond(
      next_state.converged | next_state.failed,
      lambda: next_state,
      lambda: _update_inv_hessian(current_state, next_state))
  return [state_after_inv_hessian_update]
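# A minimal sketch of `_get_search_direction` as called above, assuming the
# full-matrix BFGS variant simply preconditions the gradient with the
# current inverse Hessian estimate, d = -H g. Illustrative only; the actual
# helper may differ in how it handles shapes and batching.
def _get_search_direction(inv_hessian_approx, gradient):
  # Matrix-vector product H g, written with matmul for TF1 compatibility.
  direction = tf.squeeze(
      tf.matmul(inv_hessian_approx, tf.expand_dims(gradient, axis=-1)),
      axis=-1)
  return -direction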
def _body(state):
  """Main optimization loop."""
  search_direction = _get_search_direction(state.inverse_hessian_estimate,
                                           state.objective_gradient)
  derivative_at_start_pt = tf.reduce_sum(
      input_tensor=state.objective_gradient * search_direction, axis=-1)

  # If the derivative at the start point is not negative, recompute the
  # search direction with the initial inverse Hessian.
  needs_reset = (~state.failed & ~state.converged &
                 (derivative_at_start_pt >= 0))
  search_direction_reset = _get_search_direction(
      initial_inv_hessian, state.objective_gradient)
  actual_search_direction = tf.compat.v1.where(
      needs_reset, search_direction_reset, search_direction)
  actual_inv_hessian = tf.compat.v1.where(
      needs_reset, initial_inv_hessian, state.inverse_hessian_estimate)

  # Replace the Hessian estimate in the state, in case it had to be reset.
  current_state = bfgs_utils.update_fields(
      state, inverse_hessian_estimate=actual_inv_hessian)

  next_state = bfgs_utils.line_search_step(
      current_state, value_and_gradients_function, actual_search_direction,
      tolerance, f_relative_tolerance, x_tolerance, stopping_condition)

  # Update the inverse Hessian if needed and continue.
  return [_update_inv_hessian(current_state, next_state)]
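# `_update_inv_hessian` is not shown here; the mathematical core it needs is
# the standard BFGS inverse-Hessian update. The hypothetical helper below
# sketches that update for a single unbatched state, with
# s = x_{k+1} - x_k, y = g_{k+1} - g_k and rho = 1 / (y^T s):
#   H' = (I - rho s y^T) H (I - rho y s^T) + rho s s^T
def _bfgs_inv_hessian_update(inv_hessian, position_delta, gradient_delta):
  rho = 1. / tf.reduce_sum(gradient_delta * position_delta)
  identity = tf.eye(tf.shape(position_delta)[0], dtype=position_delta.dtype)
  left = identity - rho * tf.einsum(
      'i,j->ij', position_delta, gradient_delta)
  right = identity - rho * tf.einsum(
      'i,j->ij', gradient_delta, position_delta)
  return (tf.matmul(left, tf.matmul(inv_hessian, right)) +
          rho * tf.einsum('i,j->ij', position_delta, position_delta))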
def _body(current_state):
  """Main optimization loop."""
  search_direction = _get_search_direction(current_state)

  # TODO(b/120134934): Check if the derivative at the start point is not
  # negative, if so then reset position/gradient deltas and recompute
  # search direction.
  next_state = bfgs_utils.line_search_step(
      current_state, value_and_gradients_function, search_direction,
      tolerance, f_relative_tolerance, x_tolerance,
      bfgs_utils.converged_all)

  def _update_inv_hessian():
    position_delta = next_state.position - current_state.position
    gradient_delta = (next_state.objective_gradient -
                      current_state.objective_gradient)
    return bfgs_utils.update_fields(
        next_state,
        position_deltas=_stack_append(
            current_state.position_deltas, position_delta),
        gradient_deltas=_stack_append(
            current_state.gradient_deltas, gradient_delta))

  # If not failed or converged, update the Hessian estimate.
  state_after_inv_hessian_update = prefer_static.cond(
      next_state.converged | next_state.failed,
      lambda: next_state,
      _update_inv_hessian)
  return [state_after_inv_hessian_update]
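# A minimal sketch of what `_stack_append` could be (hypothetical, not the
# actual helper): grow the stack of correction vectors along the leading
# axis. The real code may preallocate storage or bound the window instead.
def _stack_append(stack, vec):
  return tf.concat([stack, vec[tf.newaxis, ...]], axis=0)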