Example #1
0
        def _body(current_state):
            """One iteration of the L-BFGS optimization loop."""
            direction = _get_search_direction(current_state)

            # TODO(b/120134934): Check if the derivative at the start point is not
            # negative, if so then reset position/gradient deltas and recompute
            # search direction.

            next_state = bfgs_utils.line_search_step(
                current_state, value_and_gradients_function, direction,
                tolerance, f_relative_tolerance, x_tolerance,
                stopping_condition, max_line_search_iterations)

            # Push a new correction pair only when the line search neither
            # converged nor failed.
            do_update = ~(next_state.converged | next_state.failed)
            position_delta = next_state.position - current_state.position
            gradient_delta = (next_state.objective_gradient -
                              current_state.objective_gradient)
            updated_state = bfgs_utils.update_fields(
                next_state,
                position_deltas=_queue_push(
                    current_state.position_deltas, do_update, position_delta),
                gradient_deltas=_queue_push(
                    current_state.gradient_deltas, do_update, gradient_delta))
            return [updated_state]
Example #2
0
    def _body(state):
      """One iteration of the BFGS optimization loop."""
      direction = _get_search_direction(state.inverse_hessian_estimate,
                                        state.objective_gradient)
      # Directional derivative of the objective along the proposed direction.
      start_derivative = tf.reduce_sum(
          input_tensor=state.objective_gradient * direction, axis=-1)

      # A non-negative derivative means `direction` is not a descent
      # direction; in that case fall back to the initial inverse Hessian.
      needs_reset = (~state.failed & ~state.converged &
                     (start_derivative >= 0))

      reset_direction = _get_search_direction(
          initial_inv_hessian, state.objective_gradient)

      actual_direction = tf.compat.v1.where(
          needs_reset, reset_direction, direction)
      actual_inv_hessian = tf.compat.v1.where(
          needs_reset, initial_inv_hessian, state.inverse_hessian_estimate)

      # Carry the (possibly reset) Hessian estimate into the line search.
      current_state = bfgs_utils.update_fields(
          state, inverse_hessian_estimate=actual_inv_hessian)

      next_state = bfgs_utils.line_search_step(
          current_state,
          value_and_gradients_function, actual_direction,
          tolerance, f_relative_tolerance, x_tolerance, stopping_condition)

      # Update the inverse Hessian if needed and continue.
      return [_update_inv_hessian(current_state, next_state)]
Example #3
0
        def _body(state):
            """Main optimization loop.

            Runs one BFGS iteration: computes a search direction from the
            current inverse Hessian estimate, resets the estimate when the
            direction is not a descent direction, performs a line search
            step, and updates the inverse Hessian estimate unless the line
            search converged or failed.
            """

            search_direction = _get_search_direction(
                state.inverse_hessian_estimate, state.objective_gradient)
            # Directional derivative of the objective along the search
            # direction at the current position.
            derivative_at_start_pt = tf.reduce_sum(state.objective_gradient *
                                                   search_direction)
            # If the derivative at the start point is not negative, reset the
            # Hessian estimate and recompute the search direction.
            # NOTE(review): unlike other variants of this loop, the reset is
            # not additionally gated on `state.failed`/`state.converged` —
            # confirm that is intended.
            needs_reset = derivative_at_start_pt >= 0

            def _reset_search_dirn():
                # Recompute the direction using the initial inverse Hessian;
                # this local deliberately shadows the outer `search_direction`.
                search_direction = _get_search_direction(
                    initial_inv_hessian, state.objective_gradient)
                return search_direction, initial_inv_hessian

            # `smart_cond` dispatches statically when `needs_reset` is a
            # Python bool and builds a graph conditional otherwise.
            search_direction, inv_hessian_estimate = tf.contrib.framework.smart_cond(
                needs_reset,
                true_fn=_reset_search_dirn,
                false_fn=lambda:
                (search_direction, state.inverse_hessian_estimate))

            # Replace the hessian estimate in the state, in case it had to be reset.
            current_state = bfgs_utils.update_fields(
                state, inverse_hessian_estimate=inv_hessian_estimate)

            next_state = bfgs_utils.line_search_step(
                current_state, value_and_gradients_function, search_direction,
                tolerance, f_relative_tolerance, x_tolerance)

            # If not failed or converged, update the Hessian.
            state_after_inv_hessian_update = tf.contrib.framework.smart_cond(
                next_state.converged | next_state.failed, lambda: next_state,
                lambda: _update_inv_hessian(current_state, next_state))
            return [state_after_inv_hessian_update]
Example #4
0
def _update_inv_hessian(prev_state, next_state):
    """Return `next_state` with a freshly updated inverse Hessian estimate.

    Applies the BFGS inverse Hessian update rule using the gradient and
    position changes between `prev_state` and `next_state`.
    """
    gradient_delta = (next_state.objective_gradient -
                      prev_state.objective_gradient)
    position_delta = next_state.position - prev_state.position
    updated_estimate = _bfgs_inv_hessian_update(
        gradient_delta, position_delta, prev_state.inverse_hessian_estimate)
    return bfgs_utils.update_fields(
        next_state, inverse_hessian_estimate=updated_estimate)
Example #5
0
 def _do_update_inv_hessian():
   """Apply the BFGS inverse-Hessian update, masked by `should_update`.

   Returns `next_state` with the estimate replaced elementwise: entries
   where `should_update` is True receive the freshly computed estimate,
   the rest keep `prev_state`'s estimate. Presumably `should_update`
   marks members whose line search succeeded — confirm in the enclosing
   scope.
   """
   next_inv_hessian = _bfgs_inv_hessian_update(
       gradient_delta, position_delta, normalization_factor,
       prev_state.inverse_hessian_estimate)
   # `where` keeps the previous estimate wherever the update is masked off.
   return bfgs_utils.update_fields(
       next_state,
       inverse_hessian_estimate=tf.compat.v1.where(
           should_update, next_inv_hessian,
           prev_state.inverse_hessian_estimate))
Example #6
0
 def _update_inv_hessian():
     """Push the latest position/gradient deltas onto the L-BFGS stacks.

     Returns `next_state` with the correction-pair stacks extended by the
     change from `current_state` to `next_state`.
     """
     delta_position = next_state.position - current_state.position
     delta_gradient = (next_state.objective_gradient -
                       current_state.objective_gradient)
     return bfgs_utils.update_fields(
         next_state,
         position_deltas=_stack_append(
             current_state.position_deltas, delta_position),
         gradient_deltas=_stack_append(
             current_state.gradient_deltas, delta_gradient))