Example #1
 def testFalse(self):
   conditions = [(False, raise_exception)]
   y = smart_cond.smart_case(conditions,
                             default=lambda: constant_op.constant(1),
                             exclusive=False)
   z = smart_cond.smart_case(conditions,
                             default=lambda: constant_op.constant(1),
                             exclusive=True)
   with session.Session() as sess:
     self.assertEqual(sess.run(y), 1)
     self.assertEqual(sess.run(z), 1)
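
`raise_exception` is referenced throughout these tests but is not defined on this page; a minimal sketch of what such a helper presumably looks like (the exception type and message are assumptions):

 def raise_exception():
   # Branches built from this callable are never evaluated in the tests above:
   # smart_case resolves statically known (Python bool) conditions without
   # calling the losing branch functions.
   raise ValueError('Should not be called.')
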
Example #2
 def testTrue(self):
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
   conditions = [(True, lambda: constant_op.constant(1)),
                 (x == 0, raise_exception)]
   y = smart_cond.smart_case(conditions, default=raise_exception,
                             exclusive=False)
   z = smart_cond.smart_case(conditions, default=raise_exception,
                             exclusive=True)
   with session.Session() as sess:
     # No feed_dict necessary
     self.assertEqual(sess.run(y), 1)
     self.assertEqual(sess.run(z), 1)
Example #3
 def testTrue(self):
     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
     conditions = [(True, lambda: constant_op.constant(1)),
                   (x == 0, raise_exception)]
     y = smart_cond.smart_case(conditions,
                               default=raise_exception,
                               exclusive=False)
     z = smart_cond.smart_case(conditions,
                               default=raise_exception,
                               exclusive=True)
     with session.Session() as sess:
         # No feed_dict necessary
         self.assertEqual(sess.run(y), 1)
         self.assertEqual(sess.run(z), 1)
Example #4
    def _body(_, evals, val_left, val_right):
        """Loop body to find the bracketing interval."""
        # If the right point has an increasing derivative, then [left, right]
        # encloses a minimum and we are done.
        case1 = (val_right.df >= 0), lambda: (True, evals, val_left, val_right)

        # This case applies if the point has negative derivative (i.e. it is
        # almost suitable as a left endpoint).
        def _case2_fn():
            inner_evals, left, right = _bisection_update(
                value_and_gradients_function, val_0, val_right, f_lim)
            return (True, inner_evals + evals, left, right)

        case2 = (val_right.f > f_lim, _case2_fn)

        def _default_fn():
            next_right = expansion_param * val_right.x
            f_next_right, df_next_right = value_and_gradients_function(
                next_right)
            val_next_right = _FnDFn(x=next_right,
                                    f=f_next_right,
                                    df=df_next_right)
            return False, evals + 1, val_right, val_next_right

        return smart_cond.smart_case([case1, case2],
                                     default=_default_fn,
                                     exclusive=False)
Example #5
 def _body(_, evals, val_left, val_right):
   """Loop body to find the bracketing interval."""
   # If the right point has an increasing derivative, then [left, right]
   # encloses a minimum and we are done.
   case1 = (val_right.df >= 0), lambda: (True, evals, val_left, val_right)
   # This case applies if the point has negative derivative (i.e. it is almost
   # suitable as a left endpoint).
   def _case2_fn():
     inner_evals, left, right = _bisection_update(
         value_and_gradients_function,
         val_0,
         val_right,
         f_lim)
     return (True, inner_evals + evals, left, right)
   case2 = (val_right.f > f_lim, _case2_fn)
   def _default_fn():
     next_right = expansion_param * val_right.x
     f_next_right, df_next_right = value_and_gradients_function(next_right)
     val_next_right = _FnDFn(x=next_right, f=f_next_right, df=df_next_right)
     return False, evals + 1, val_right, val_next_right
   return smart_cond.smart_case(
       [
           case1,
           case2
       ],
       default=_default_fn,
       exclusive=False)
Example #6
 def _common_update(curr_left, curr_right):
   """Performs secant division to update the interval."""
   # Note that curr_left and curr_right may not satisfy opposite slope so
   # c_bar below may be outside the range [temp_left, temp_right].
   c_bar = _secant(curr_left.x, curr_right.x, curr_left.df, curr_right.df)
   fc_bar, dfc_bar = value_and_gradients_function(c_bar)
   val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)
   outside_range = (c_bar < val_temp_left.x) | (val_temp_right.x < c_bar)
   good_c_bar = _satisfies_wolfe(
       val_0, val_c_bar, f_lim,
       sufficient_decrease_param=sufficient_decrease_param,
       curvature_param=curvature_param)
   def _default_fn():
     # Perform yet another update on [temp_left, temp_right] using c_bar.
     inner_evals, val_left_bar, val_right_bar = _update(
         value_and_gradients_function,
         val_temp_left,
         val_temp_right,
         val_c_bar,
         f_lim)
     return (False, evals + inner_evals + 2, val_left_bar, val_right_bar)
   return smart_cond.smart_case(
       [
           (outside_range,
            lambda: (False, evals + 2, val_temp_left, val_temp_right)),
           (good_c_bar,
            lambda: (True, evals + 2, val_c_bar, val_c_bar))
       ],
       default=_default_fn,
       exclusive=False)
Example #7
    def loop_body(_, failed, eval_count, val_left, val_right):  # pylint:disable=unused-argument
        """Updates the right end point to satisfy the opposite slope conditions."""
        # Captured by closure: value_and_gradients_function and f_lim
        mid_pt = (val_left.x + val_right.x) / 2
        f_mid, df_mid = value_and_gradients_function(mid_pt)
        # The case conditions.
        val_mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
        failed_case = (~_is_finite(val_mid), lambda:
                       (False, True, eval_count + 1, val_left, val_right))
        valid_right = df_mid >= 0  # The new point can be a valid right end point.
        valid_left = (df_mid < 0) & (f_mid <= f_lim)  # It is a valid left end pt.

        # The case actions.
        valid_right_fn = lambda: (True, False, eval_count + 1, val_left,
                                  val_mid)
        # Note that we must return found = False in this case because our target
        # is to find a good right end point and improvements to the left end point
        # are coincidental. Hence the loop must continue until we exit via
        # the valid_right case.
        valid_left_fn = lambda: (False, False, eval_count + 1, val_mid,
                                 val_right)
        # To be explicit, this action applies when the new point has a positive
        # slope but the function value at that point is too high. This is the
        # same situation with which we started the loop in the first place. Hence
        # we should just replace the old right end point and continue to loop.
        default_fn = lambda: (False, False, eval_count + 1, val_left, val_mid)
        cases = (failed_case, (valid_right, valid_right_fn), (valid_left,
                                                              valid_left_fn))
        return smart_cond.smart_case(cases,
                                     default=default_fn,
                                     exclusive=False)
Example #8
        def _common_update(curr_left, curr_right):
            """Performs secant division to update the interval."""
            # Note that curr_left and curr_right may not satisfy opposite slope so
            # c_bar below may be outside the range [temp_left, temp_right].
            c_bar = _secant(curr_left.x, curr_right.x, curr_left.df,
                            curr_right.df)
            fc_bar, dfc_bar = value_and_gradients_function(c_bar)
            val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)
            common_update_failed = ~_is_finite(val_c_bar)
            outside_range = (c_bar < val_temp_left.x) | (val_temp_right.x <
                                                         c_bar)
            good_c_bar = _satisfies_wolfe(
                val_0,
                val_c_bar,
                f_lim,
                sufficient_decrease_param=sufficient_decrease_param,
                curvature_param=curvature_param)

            def _default_fn():
                # Perform yet another update on [temp_left, temp_right] using c_bar.
                failed, inner_evals, val_left_bar, val_right_bar = _update(
                    value_and_gradients_function, val_temp_left,
                    val_temp_right, val_c_bar, f_lim)
                return (False, failed, evals + inner_evals + 2, val_left_bar,
                        val_right_bar)

            return smart_cond.smart_case(
                [(common_update_failed, lambda:
                  (False, True, evals + 2, val_c_bar, val_c_bar)),
                 (outside_range, lambda:
                  (False, False, evals + 2, val_temp_left, val_temp_right)),
                 (good_c_bar, lambda:
                  (True, False, evals + 2, val_c_bar, val_c_bar))],
                default=_default_fn,
                exclusive=False)
Example #9
    def _body(_, failed, evals, val_left, val_right):  # pylint:disable=unused-argument
        """Loop body to find the bracketing interval."""
        # Check that the function or gradient are finite and quit if they aren't.
        case0 = (~_is_finite(val_left, val_right), lambda:
                 (False, True, evals, val_left, val_right))
        # If the right point has an increasing derivative, then [left, right]
        # encloses a minimum and we are done.
        case1 = ((val_right.df >= 0), lambda:
                 (True, False, evals, val_left, val_right))

        # This case applies if the point has negative derivative (i.e. it is
        # almost suitable as a left endpoint).
        def _case2_fn():
            failed, inner_evals, left, right = _bisection_update(
                value_and_gradients_function, val_0, val_right, f_lim)
            return (True, failed, inner_evals + evals, left, right)

        case2 = (val_right.f > f_lim, _case2_fn)

        def _default_fn():
            next_right = expansion_param * val_right.x
            f_next_right, df_next_right = value_and_gradients_function(
                next_right)
            val_next_right = _FnDFn(x=next_right,
                                    f=f_next_right,
                                    df=df_next_right)
            failed = ~_is_finite(val_next_right)
            return False, failed, evals + 1, val_right, val_next_right

        return smart_cond.smart_case([case0, case1, case2],
                                     default=_default_fn,
                                     exclusive=False)
Example #10
 def testMix(self):
     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
     y = constant_op.constant(10)
     conditions = [(x > 1, lambda: constant_op.constant(1)),
                   (y < 1, raise_exception), (False, raise_exception),
                   (True, lambda: constant_op.constant(3))]
     z = smart_cond.smart_case(conditions, default=raise_exception)
     with session.Session() as sess:
         self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
         self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
Example #11
 def testMix(self):
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
   y = constant_op.constant(10)
   conditions = [(x > 1, lambda: constant_op.constant(1)),
                 (y < 1, raise_exception),
                 (False, raise_exception),
                 (True, lambda: constant_op.constant(3))]
   z = smart_cond.smart_case(conditions, default=raise_exception)
   with session.Session() as sess:
     self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
     self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
Example #12
    def testMix(self):
        # Constant expression evaluation only works with the C API enabled.
        if not ops._USE_C_API: return

        x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
        y = constant_op.constant(10)
        conditions = [(x > 1, lambda: constant_op.constant(1)),
                      (y < 1, raise_exception), (False, raise_exception),
                      (True, lambda: constant_op.constant(3))]
        z = smart_cond.smart_case(conditions, default=raise_exception)
        with session.Session() as sess:
            self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
            self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
Example #13
    def _body(converged, evals, left, right):
      converged, secant2_evals, next_left, next_right = _secant2(
          value_and_gradients_function, val_0, left, right, f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)

      evals += secant2_evals
      # If converged, then do no further processing.
      return smart_cond.smart_case(
          [(converged,
            lambda: (True, evals, next_left, next_right)),
           (next_right.x - next_left.x > shrinkage_param * (right.x - left.x),
            lambda: _update_with_mid(evals, next_left, next_right))],
          default=lambda: (False, evals, next_left, next_right))
Example #14
  def testMix(self):
    # Constant expression evaluation only works with the C API enabled.
    if not ops._USE_C_API: return

    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    y = constant_op.constant(10)
    conditions = [(x > 1, lambda: constant_op.constant(1)),
                  (y < 1, raise_exception),
                  (False, raise_exception),
                  (True, lambda: constant_op.constant(3))]
    z = smart_cond.smart_case(conditions, default=raise_exception)
    with session.Session() as sess:
      self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
      self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
Example #15
    def _body(converged, failed, evals, left, right):  # pylint:disable=unused-argument
      """Line search loop body."""
      converged, failed, secant2_evals, next_left, next_right = _secant2(
          value_and_gradients_function, val_0, left, right, f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)

      evals += secant2_evals

      return smart_cond.smart_case(
          # If converged or failed, then do no further processing.
          [(converged | failed,
            lambda: (converged, failed, evals, next_left, next_right)),
           (next_right.x - next_left.x > shrinkage_param * (right.x - left.x),
            lambda: _update_with_mid(evals, next_left, next_right))],
          default=lambda: (False, False, evals, next_left, next_right))
Example #16
        def _body(converged, evals, left, right):
            converged, secant2_evals, next_left, next_right = _secant2(
                value_and_gradients_function,
                val_0,
                left,
                right,
                f_lim,
                sufficient_decrease_param=sufficient_decrease_param,
                curvature_param=curvature_param)

            evals += secant2_evals
            # If converged, then do no further processing.
            return smart_cond.smart_case(
                [(converged, lambda: (True, evals, next_left, next_right)),
                 (next_right.x - next_left.x > shrinkage_param *
                  (right.x - left.x),
                  lambda: _update_with_mid(evals, next_left, next_right))],
                default=lambda: (False, evals, next_left, next_right))
Example #17
 def loop_body(_, eval_count, val_left, val_right):
   """Updates the right end point to satisfy the opposite slope conditions."""
   # Captured by closure: value_and_gradients_function and f_lim
   mid_pt = (val_left.x + val_right.x) / 2
   f_mid, df_mid = value_and_gradients_function(mid_pt)
   # The case conditions.
   valid_right = df_mid >= 0  # The new point can be a valid right end point.
   valid_left = (df_mid < 0) & (f_mid <= f_lim)  # It is a valid left end pt.
   val_mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
   # The case actions.
   valid_right_fn = lambda: (True, eval_count + 1, val_left, val_mid)
   # Note that we must return found = False in this case because our target
   # is to find a good right end point and improvements to the left end point
   # are coincidental. Hence the loop must continue until we exit via
   # the valid_right case.
   valid_left_fn = lambda: (False, eval_count + 1, val_mid, val_right)
   # To be explicit, this action applies when the new point has a positive
   # slope but the function value at that point is too high. This is the
   # same situation with which we started the loop in the first place. Hence
   # we should just replace the old right end point and continue to loop.
   default_fn = lambda: (False, eval_count + 1, val_left, val_mid)
   cases = ((valid_right, valid_right_fn), (valid_left, valid_left_fn))
   return smart_cond.smart_case(cases, default=default_fn, exclusive=False)
Example #18
        def _body(
                _,
                failed,  # pylint: disable=unused-argument
                num_iterations,
                total_evals,
                position,
                objective_value,
                objective_gradient,
                inv_hessian_estimate):
            """Main optimization loop."""
            search_direction = _get_search_direction(inv_hessian_estimate,
                                                     objective_gradient)
            line_search_value_grad_func = _restrict_along_direction(
                value_and_gradients_function, position, search_direction)
            derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                                   search_direction)
            ls_result = linesearch.hager_zhang(
                line_search_value_grad_func,
                initial_step_size=tf.constant(1, dtype=dtype),
                objective_at_zero=objective_value,
                grad_objective_at_zero=derivative_at_start_pt)

            # If the line search failed, then quit at this point.
            failed_retval = BfgsOptimizerResults(
                converged=False,
                failed=True,
                num_iterations=num_iterations + 1,
                num_objective_evaluations=total_evals + ls_result.func_evals,
                position=position,
                objective_value=objective_value,
                objective_gradient=objective_gradient,
                inverse_hessian_estimate=inv_hessian_estimate)
            # Fail if the objective value is not finite or the line search failed.
            ls_failed_case = (
                ~(tf.is_finite(objective_value) & ls_result.converged),
                lambda: failed_retval)

            # If the line search didn't fail, then either we need to continue
            # searching or need to stop because we have converged.
            position_delta = search_direction * ls_result.left_pt
            next_position = position + position_delta
            next_objective, next_objective_gradient = value_and_gradients_function(
                next_position)
            grad_norm = tf.norm(next_objective_gradient, ord=2)
            has_converged = grad_norm <= tolerance
            grad_delta = next_objective_gradient - objective_gradient
            updated_inv_hessian = _bfgs_inv_hessian_update(
                grad_delta, position_delta, inv_hessian_estimate)
            updated_inv_hessian.set_shape(inv_hessian_estimate.shape)
            converged_retval = BfgsOptimizerResults(
                converged=tf.constant(True, name='converged'),
                failed=tf.constant(False, name='failed'),
                num_iterations=tf.convert_to_tensor(num_iterations + 1,
                                                    name='num_iterations'),
                num_objective_evaluations=tf.convert_to_tensor(
                    total_evals + ls_result.func_evals + 1,
                    name='num_objective_evaluations'),
                position=next_position,
                objective_value=next_objective,
                objective_gradient=next_objective_gradient,
                inverse_hessian_estimate=updated_inv_hessian)
            converged_case = (has_converged, lambda: converged_retval)
            default_retval = BfgsOptimizerResults(
                converged=tf.constant(False, name='converged'),
                failed=tf.constant(False, name='failed'),
                num_iterations=tf.convert_to_tensor(num_iterations + 1,
                                                    name='num_iterations'),
                num_objective_evaluations=total_evals + ls_result.func_evals +
                1,
                position=next_position,
                objective_value=next_objective,
                objective_gradient=next_objective_gradient,
                inverse_hessian_estimate=updated_inv_hessian)
            default_fn = lambda: default_retval
            return smart_cond.smart_case([ls_failed_case, converged_case],
                                         default=default_fn,
                                         exclusive=False)
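
The `_bfgs_inv_hessian_update` helper called above is not shown on this page. A minimal sketch, assuming it implements the textbook BFGS update of the inverse Hessian estimate from the position step s = position_delta and gradient step y = grad_delta (the helper name and body below are illustrative, not the library's implementation):

import tensorflow as tf

def bfgs_inv_hessian_update_sketch(grad_delta, position_delta, inv_hessian):
    # Textbook BFGS inverse Hessian update (assumed):
    #   H_new = (I - rho s y^T) H (I - rho y s^T) + rho s s^T,  rho = 1 / (y^T s),
    # with s = position_delta and y = grad_delta.
    rho = 1. / tf.reduce_sum(grad_delta * position_delta)
    identity = tf.eye(tf.shape(inv_hessian)[0], dtype=inv_hessian.dtype)
    left = identity - rho * tf.einsum('i,j->ij', position_delta, grad_delta)
    right = identity - rho * tf.einsum('i,j->ij', grad_delta, position_delta)
    outer_ss = tf.einsum('i,j->ij', position_delta, position_delta)
    return tf.matmul(left, tf.matmul(inv_hessian, right)) + rho * outer_ss
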
Example #19
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
  """A single iteration of the Nelder Mead algorithm."""
  with tf.name_scope(name, 'nelder_mead_one_step'):
    domain_dtype = current_simplex.dtype.base_dtype
    order = tf.contrib.framework.argsort(current_objective_values,
                                         direction='ASCENDING',
                                         stable=True)
    (
        best_index,
        worst_index,
        second_worst_index
    ) = order[0], order[-1], order[-2]

    worst_vertex = current_simplex[worst_index]

    (
        best_objective_value,
        worst_objective_value,
        second_worst_objective_value
    ) = (
        current_objective_values[best_index],
        current_objective_values[worst_index],
        current_objective_values[second_worst_index]
    )

    # Compute the centroid of the face opposite the worst vertex.
    face_centroid = tf.reduce_sum(current_simplex, axis=0) - worst_vertex
    face_centroid /= tf.cast(dim, domain_dtype)

    # Reflect the worst vertex through the opposite face.
    reflected = face_centroid + reflection * (face_centroid - worst_vertex)
    objective_at_reflected = objective_function(reflected)

    num_evaluations = 1
    has_converged = _check_convergence(current_simplex,
                                       current_simplex[best_index],
                                       best_objective_value,
                                       worst_objective_value,
                                       func_tolerance,
                                       position_tolerance)
    def _converged_fn():
      return (True, current_simplex, current_objective_values, 0)
    case0 = has_converged, _converged_fn
    accept_reflected = (
        (objective_at_reflected < second_worst_objective_value) &
        (objective_at_reflected >= best_objective_value))
    accept_reflected_fn = _accept_reflected_fn(current_simplex,
                                               current_objective_values,
                                               worst_index,
                                               reflected,
                                               objective_at_reflected)
    case1 = accept_reflected, accept_reflected_fn
    do_expansion = objective_at_reflected < best_objective_value
    expansion_fn = _expansion_fn(objective_function,
                                 current_simplex,
                                 current_objective_values,
                                 worst_index,
                                 reflected,
                                 objective_at_reflected,
                                 face_centroid,
                                 expansion)
    case2 = do_expansion, expansion_fn
    do_outside_contraction = (
        (objective_at_reflected < worst_objective_value) &
        (objective_at_reflected >= second_worst_objective_value)
    )
    outside_contraction_fn = _outside_contraction_fn(
        objective_function,
        current_simplex,
        current_objective_values,
        face_centroid,
        best_index,
        worst_index,
        reflected,
        objective_at_reflected,
        contraction,
        shrinkage,
        batch_evaluate_objective)
    case3 = do_outside_contraction, outside_contraction_fn
    default_fn = _inside_contraction_fn(objective_function,
                                        current_simplex,
                                        current_objective_values,
                                        face_centroid,
                                        best_index,
                                        worst_index,
                                        worst_objective_value,
                                        contraction,
                                        shrinkage,
                                        batch_evaluate_objective)
    (
        converged,
        next_simplex,
        next_objective_at_simplex,
        case_evals
    ) = smart_cond.smart_case(
        [
            case0,
            case1,
            case2,
            case3
        ],
        default=default_fn, exclusive=False)
    next_simplex.set_shape(current_simplex.shape)
    next_objective_at_simplex.set_shape(current_objective_values.shape)
    return (
        converged,
        next_simplex,
        next_objective_at_simplex,
        num_evaluations + case_evals
    )
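
A minimal, hypothetical invocation of `nelder_mead_one_step` above. The objective, the initial simplex and the coefficient values (the usual Nelder-Mead choices) are assumptions for illustration, not taken from the source:

import tensorflow as tf

# Hypothetical quadratic objective with minimum at (2, 2), evaluated per vertex.
objective_fn = lambda x: tf.reduce_sum((x - 2.) ** 2, axis=-1)
simplex = tf.constant([[0., 0.], [1., 0.], [0., 1.]])
objective_values = objective_fn(simplex)
converged, next_simplex, next_values, num_evals = nelder_mead_one_step(
    simplex, objective_values,
    objective_function=objective_fn,
    dim=2,
    func_tolerance=1e-8,
    position_tolerance=1e-8,
    reflection=tf.constant(1.),     # standard Nelder-Mead coefficients,
    expansion=tf.constant(2.),      # assumed here for illustration
    contraction=tf.constant(0.5),
    shrinkage=tf.constant(0.5))
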
Example #20
    def _body(_,
              failed,    # pylint: disable=unused-argument
              total_evals,
              position,
              objective_value,
              objective_gradient,
              inv_hessian_estimate):
      """Main optimization loop."""
      search_direction = _get_search_direction(inv_hessian_estimate,
                                               objective_gradient)
      line_search_value_grad_func = _restrict_along_direction(
          value_and_gradients_function, position, search_direction)
      derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                             search_direction)
      ls_result = linesearch.hager_zhang(
          line_search_value_grad_func,
          initial_step_size=tf.constant(1, dtype=dtype),
          objective_at_zero=objective_value,
          grad_objective_at_zero=derivative_at_start_pt)

      # If the line search failed, then quit at this point.
      failed_retval = BfgsOptimizerResults(
          converged=False,
          failed=True,
          num_objective_evaluations=total_evals + ls_result.func_evals,
          position=position,
          objective_value=objective_value,
          objective_gradient=objective_gradient,
          inverse_hessian_estimate=inv_hessian_estimate)
      ls_failed_case = (~ls_result.converged, lambda: failed_retval)

      # If the line search didn't fail, then either we need to continue
      # searching or need to stop because we have converged.
      position_delta = search_direction * ls_result.left_pt
      next_position = position + position_delta
      next_objective, next_objective_gradient = value_and_gradients_function(
          next_position)
      grad_norm = tf.norm(next_objective_gradient, ord=2)
      has_converged = grad_norm <= tolerance
      grad_delta = next_objective_gradient - objective_gradient
      updated_inv_hessian = _bfgs_inv_hessian_update(grad_delta,
                                                     position_delta,
                                                     inv_hessian_estimate)
      updated_inv_hessian.set_shape(inv_hessian_estimate.shape)
      converged_retval = BfgsOptimizerResults(
          converged=True,
          failed=False,
          num_objective_evaluations=total_evals + ls_result.func_evals + 1,
          position=next_position,
          objective_value=next_objective,
          objective_gradient=next_objective_gradient,
          inverse_hessian_estimate=updated_inv_hessian)
      converged_case = (has_converged, lambda: converged_retval)
      default_retval = BfgsOptimizerResults(
          converged=False,
          failed=False,
          num_objective_evaluations=total_evals + ls_result.func_evals + 1,
          position=next_position,
          objective_value=next_objective,
          objective_gradient=next_objective_gradient,
          inverse_hessian_estimate=updated_inv_hessian)
      default_fn = lambda: default_retval
      return smart_cond.smart_case([ls_failed_case, converged_case],
                                   default=default_fn,
                                   exclusive=False)
Example #21
def hager_zhang(value_and_gradients_function,
                initial_step_size=None,
                objective_at_zero=None,
                grad_objective_at_zero=None,
                objective_at_initial_step_size=None,
                grad_objective_at_initial_step_size=None,
                threshold_use_approximate_wolfe_condition=1e-6,
                shrinkage_param=0.66,
                expansion_param=5.0,
                sufficient_decrease_param=0.1,
                curvature_param=0.9,
                name=None):
    """The Hager Zhang line search algorithm.

  Performs an inexact line search based on the algorithm of
  [Hager and Zhang (2006)][2].
  The univariate objective function `value_and_gradients_function` is typically
  generated by projecting
  a multivariate objective function along a search direction. Suppose the
  multivariate function to be minimized is `g(x1,x2, .. xn)`. Let
  (d1, d2, ..., dn) be the direction along which we wish to perform a line
  search. Then the projected univariate function to be used for line search is

  ```None
    f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a)
  ```

  The directional derivative along (d1, d2, ..., dn) is needed for this
  procedure. This also corresponds to the derivative of the projected function
  `f(a)` with respect to `a`. Note that this derivative must be negative for
  `a = 0` if the direction is a descent direction.

  The usual stopping criteria for the line search is the satisfaction of the
  (weak) Wolfe conditions. For details of the Wolfe conditions, see
  ref. [3]. On a finite precision machine, the exact Wolfe conditions can
  be difficult to satisfy when one is very close to the minimum and as argued
  by [Hager and Zhang (2005)][1], one can only expect the minimum to be
  determined within square root of machine precision. To improve the situation,
  they propose to replace the Wolfe conditions with an approximate version
  depending on the derivative of the function which is applied only when one
  is very close to the minimum. The following algorithm implements this
  enhanced scheme.

  ### Usage:

  Primary use of line search methods is as an internal component of a class of
  optimization algorithms (called line search based methods as opposed to
  trust region methods). Hence, the end user will typically not want to access
  line search directly. In particular, inexact line search should not be
  confused with a univariate minimization method. The stopping criteria of line
  search is the satisfaction of Wolfe conditions and not the discovery of the
  minimum of the function.

  With this caveat in mind, the following example illustrates the standalone
  usage of the line search.

  ```python
    # Define a quadratic target with minimum at 1.3.
    value_and_gradients_function = lambda x: ((x - 1.3) ** 2, 2 * (x-1.3))
    # Set initial step size.
    step_size = tf.constant(0.1)
    ls_result = tfp.optimizer.linesearch.hager_zhang(
        value_and_gradients_function, initial_step_size=step_size)
    # Evaluate the results.
    with tf.Session() as session:
      results = session.run(ls_result)
      # Ensure convergence.
      assert(results.converged)
      # If the line search converged, the left and the right ends of the
      # bracketing interval are identical.
      assert(results.left_pt == results.right_pt)
      # Print the number of evaluations and the final step size.
      print("Final Step Size: %f, Evaluations: %d" % (results.left_pt,
                                                      results.func_evals))
  ```

  ### References:
  [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with
    guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1,
    pp. 170-172. 2005.
    https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf
  [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate
    gradient method with guaranteed descent. ACM Transactions on Mathematical
    Software, Vol 32., 1, pp. 113-137. 2006.
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf
  [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series in
    Operations Research. pp 33-36. 2006

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple of scalar tensors of real dtype containing
      the value of the function and its derivative at that point.
      In usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but should
      be a descent direction (i.e. the derivative of the projected univariate
      function must be negative at 0.).
    initial_step_size: (Optional) Scalar positive `Tensor` of real dtype. The
      initial value to try to bracket the minimum. Default is `1.` as a float32.
      Note that this point need not necessarily bracket the minimum for the line
      search to work correctly but the supplied value must be greater than
      0. A good initial value will make the search converge faster.
    objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If supplied,
      the value of the function at `0.`. If not supplied, it will be computed.
    grad_objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If
      supplied, the derivative of the  function at `0.`. If not supplied, it
      will be computed.
    objective_at_initial_step_size: (Optional) Scalar `Tensor` of real dtype.
      If supplied, the value of the function at `initial_step_size`.
      If not supplied, it will be computed.
    grad_objective_at_initial_step_size: (Optional) Scalar `Tensor` of real
      dtype. If supplied, the derivative of the  function at
      `initial_step_size`. If not supplied, it will be computed.
    threshold_use_approximate_wolfe_condition: Scalar positive `Tensor`
      of real dtype. Corresponds to the parameter 'epsilon' in
      [Hager and Zhang (2006)][2]. Used to estimate the
      threshold at which the line search switches to approximate Wolfe
      conditions.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in
      [Hager and Zhang (2006)][2].
      If the secant**2 step does not shrink the bracketing interval by this
      proportion, a bisection step is performed to reduce the interval width.
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not bracket
      a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'hager_zhang' is used.

  Returns:
    results: A namedtuple containing the following attributes.
      converged: Boolean scalar `Tensor`. Whether a point satisfying
        Wolfe/Approx wolfe was found.
      func_evals: Scalar int32 `Tensor`. Number of function evaluations made.
      left_pt: Scalar `Tensor` of same dtype as `initial_step_size`. The
        left end point of the final bracketing interval. If converged is True,
        it is equal to `right_pt`. Otherwise, it corresponds to the last
        interval computed.
      objective_at_left_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`. The function value at the left
        end point. If converged is True, it is equal to `objective_at_right_pt`.
        Otherwise, it corresponds to the last interval computed.
      grad_objective_at_left_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`. The derivative of the function
        at the left end point. If converged is True,
        it is equal to `grad_objective_at_right_pt`. Otherwise it
        corresponds to the last interval computed.
      right_pt: Scalar `Tensor` of same dtype as `initial_step_size`.
        The right end point of the final bracketing interval.
        If converged is True, it is equal to 'step'. Otherwise,
        it corresponds to the last interval computed.
      objective_at_right_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`.
        The function value at the right end point. If converged is True, it
        is equal to fn_step. Otherwise, it corresponds to the last
        interval computed.
      grad_objective_at_right_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`.
        The derivative of the function at the right end point.
        If converged is True, it is equal to the dfn_step.
        Otherwise it corresponds to the last interval computed.
  """
    with tf.name_scope(name, 'hager_zhang', [
            initial_step_size, objective_at_zero, grad_objective_at_zero,
            objective_at_initial_step_size,
            grad_objective_at_initial_step_size,
            threshold_use_approximate_wolfe_condition, shrinkage_param,
            expansion_param, sufficient_decrease_param, curvature_param
    ]):

        val_0, val_c, f_lim, prepare_evals = _prepare_args(
            value_and_gradients_function, initial_step_size,
            objective_at_initial_step_size,
            grad_objective_at_initial_step_size, objective_at_zero,
            grad_objective_at_zero, threshold_use_approximate_wolfe_condition)

        # Checks if the evaluation of the function at the supplied points failed.
        # If it did, then we quit.
        eval_failed = ~_is_finite(val_0, val_c)

        # Check if the initial step size already satisfies the Wolfe conditions.
        # If it does, there is no further work.
        already_converged = _satisfies_wolfe(val_0, val_c, f_lim,
                                             sufficient_decrease_param,
                                             curvature_param)

        def _cond(converged, failed, *ignored_args):  # pylint:disable=unused-argument
            """Loops until convergence is reached."""
            return tf.logical_not(converged | failed)

        def _update_with_mid(current_evals, left, right):
            """Corresponds to step L3 in [Hager and Zhang (2006)][2]."""
            mid_pt = (left.x + right.x) / 2
            f_mid, df_mid = value_and_gradients_function(mid_pt)
            mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
            mid_failed = ~_is_finite(mid)
            updated = _update(value_and_gradients_function, left, right, mid,
                              f_lim)
            failed, update_evals, next_left, next_right = smart_cond.smart_cond(
                mid_failed,
                true_fn=lambda: (True, 0, mid, mid),
                false_fn=lambda: updated)
            return (False, failed, current_evals + update_evals + 1, next_left,
                    next_right)

        def _body(converged, failed, evals, left, right):  # pylint:disable=unused-argument
            """Line search loop body."""
            converged, failed, secant2_evals, next_left, next_right = _secant2(
                value_and_gradients_function,
                val_0,
                left,
                right,
                f_lim,
                sufficient_decrease_param=sufficient_decrease_param,
                curvature_param=curvature_param)

            evals += secant2_evals

            return smart_cond.smart_case(
                # If converged or failed, then do no further processing.
                [(converged | failed, lambda:
                  (converged, failed, evals, next_left, next_right)),
                 (next_right.x - next_left.x > shrinkage_param *
                  (right.x - left.x),
                  lambda: _update_with_mid(evals, next_left, next_right))],
                default=lambda: (False, False, evals, next_left, next_right))

        def do_line_search():
            bracketed, failed, bracket_evals, left, right = _bracket(
                value_and_gradients_function,
                val_0,
                val_c,
                f_lim,
                expansion_param=expansion_param)
            failed = failed & tf.logical_not(bracketed)
            return tf.while_loop(
                _cond,
                _body,
                (False, failed, bracket_evals + prepare_evals, left, right),
                parallel_iterations=1)

        converged, failed, func_evals, left, right = smart_cond.smart_case(
            [(already_converged, lambda:
              (already_converged, False, prepare_evals, val_c, val_c)),
             (eval_failed, lambda:
              (False, True, prepare_evals, val_0, val_c))],
            default=do_line_search)

        return HagerZhangLineSearchResult(
            converged=tf.convert_to_tensor(converged, name='converged'),
            failed=tf.convert_to_tensor(failed, name='failed'),
            func_evals=func_evals,
            left_pt=left.x,
            objective_at_left_pt=left.f,
            grad_objective_at_left_pt=left.df,
            right_pt=right.x,
            objective_at_right_pt=right.f,
            grad_objective_at_right_pt=right.df)
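
The docstring above describes `value_and_gradients_function` as the restriction of a multivariate objective along a search direction, f(a) = g(x0 + a * d), together with its directional derivative. A minimal sketch of building such a restriction by hand for a simple quadratic g (the objective and the helper name are illustrative assumptions, not the `_restrict_along_direction` helper used in the BFGS examples):

import tensorflow as tf

def restrict_along_direction_sketch(x0, direction):
    """Returns a callable a -> (f(a), df(a)) for g(x) = 0.5 * ||x - 1||^2."""
    def value_and_gradients_function(step):
        x = x0 + step * direction
        value = 0.5 * tf.reduce_sum((x - 1.) ** 2)        # f(a) = g(x0 + a * d)
        gradient = x - 1.                                 # grad g at x0 + a * d
        derivative = tf.reduce_sum(gradient * direction)  # df(a), directional derivative
        return value, derivative
    return value_and_gradients_function

# The derivative at 0 is negative here, so `direction` is a descent direction.
fn = restrict_along_direction_sketch(tf.constant([0., 0.]), tf.constant([1., 1.]))
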
Example #22
def _update(value_and_gradients_function,
            val_left,
            val_right,
            val_trial,
            f_lim):
  """Squeezes a bracketing interval containing the minimum.

  Given an interval which brackets a minimum and a point in that interval,
  finds a smaller nested interval which also brackets the minimum. If the
  supplied point does not lie in the bracketing interval, the current interval
  is returned.

  The requirement of the interval bracketing a minimum is expressed through the
  opposite slope conditions. Assume the left end point is 'a', the right
  end point is 'b', the function to be minimized is 'f' and the derivative is
  'df'. The update procedure relies on the following conditions being satisfied:

  '''
    f(a) <= f(0) + epsilon   (1)
    df(a) < 0                (2)
    df(b) > 0                (3)
  '''

  In the first condition, epsilon is a small positive constant. The condition
  demands that the function at the left end point be not much bigger than the
  starting point (i.e. 0). This is an easy to satisfy condition because by
  assumption, we are in a direction where the function value is decreasing.
  The second and third conditions together demand that there is at least one
  zero of the derivative in between a and b.

  In addition to the interval, the update algorithm requires a third point to
  be supplied. Usually, this point would lie within the interval [a, b]. If the
  point is outside this interval, the current interval is returned. If the
  point lies within the interval, the behaviour of the function and derivative
  value at this point is used to squeeze the original interval in a manner that
  preserves the opposite slope conditions.

  For further details of this component, see the procedure U0-U3 on page 123 of
  the [Hager and Zhang (2006)][2] article.

  Note that this function does not explicitly verify whether the opposite slope
  conditions are satisfied for the supplied interval. It is assumed that this
  is so.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval (labelled 'a'
      above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval (labelled 'b'
      above).
    val_trial: Instance of _FnDFn. The value and derivative of the function
      evaluated at the trial point to be used to shrink the interval (labelled
      'c' above).
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.

  Returns:
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made.
    val_left_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated right end point of the interval.
  """
  left, right = val_left.x, val_right.x
  trial, f_trial, df_trial = val_trial.x, val_trial.f, val_trial.df

  # If the intermediate point is not in the interval, do nothing.
  inside_case = (
      ((trial < left) | (trial > right)),
      lambda: (tf.constant(0), val_left, val_right))

  # The new point is a valid right end point (has positive derivative).
  can_update_right = (df_trial >= 0,
                      lambda: (tf.constant(0), val_left, val_trial))

  # The new point is a valid left end point because it has negative slope
  # and the value at the point is not too large.
  can_update_left = (((df_trial < 0) & (f_trial <= f_lim)),
                     lambda: (tf.constant(0), val_trial, val_right))

  def _default_fn():
    return _bisection_update(value_and_gradients_function,
                             val_left,
                             val_trial,
                             f_lim)
  return smart_cond.smart_case(
      [
          inside_case,
          can_update_right,
          can_update_left
      ],
      default=_default_fn,
      exclusive=False)
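
The docstring notes that the opposite slope conditions (1)-(3) are assumed rather than checked for the supplied interval. A minimal sketch of such a check on the `_FnDFn` end points, with `f_lim` playing the role of f(0) + epsilon (this helper is an illustration, not part of the source):

def opposite_slope_conditions_hold(val_left, val_right, f_lim):
    # (1) f(a) <= f(0) + epsilon, (2) df(a) < 0, (3) df(b) > 0.
    return (val_left.f <= f_lim) & (val_left.df < 0) & (val_right.df > 0)
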
Example #23
def _secant2(value_and_gradients_function,
             val_0,
             val_left,
             val_right,
             f_lim,
             sufficient_decrease_param=0.1,
             curvature_param=0.9,
             name=None):
    """Performs the secant square procedure of Hager Zhang.

  Given an interval that brackets a root, this procedure performs an update of
  both end points using two intermediate points generated using the secant
  interpolation. For details see the steps S1-S4 in [Hager and Zhang (2006)][2].

  The interval [a, b] must satisfy the opposite slope conditions described in
  the documentation for '_update'.

  Args:
    value_and_gradients_function: A Python callable that accepts a real
      scalar tensor and returns a tuple
      containing the value of the function and its derivative at that point.
    val_0: Instance of _FnDFn. The function and derivative value at 0.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval (labelled 'a'
      above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval (labelled 'b'
      above).
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to 'delta' in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'secant2' is used.

  Returns:
    converged: A scalar boolean `Tensor`. Indicates whether a point
      satisfying the Wolfe conditions has been found. If this is True, the
      interval will be degenerate (i.e. val_left_bar and val_right_bar below
      will be identical).
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made.
    val_left_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated right end point of the interval.
  """
    with tf.name_scope(name, 'secant2', [
            val_0, val_left, val_right, f_lim, sufficient_decrease_param,
            curvature_param
    ]):
        a, dfa = val_left.x, val_left.df
        b, dfb = val_right.x, val_right.df

        c = _secant(a, b, dfa, dfb)  # This will always be s.t. a <= c <= b
        fc, dfc = value_and_gradients_function(c)
        val_c = _FnDFn(x=c, f=fc, df=dfc)
        secant_failed = ~_is_finite(val_c)
        good_c = _satisfies_wolfe(
            val_0,
            val_c,
            f_lim,
            sufficient_decrease_param=sufficient_decrease_param,
            curvature_param=curvature_param)

        secant_failed_case = (secant_failed, lambda:
                              (False, True, tf.constant(1), val_c, val_c))
        # If we have found a point satisfying the Wolfe conditions,
        # we have converged.
        converged_case = (good_c, lambda:
                          (True, False, tf.constant(1), val_c, val_c))
        # The temp_right and temp_left are the values referred to as A and B in
        # Hager Zhang (2006).
        update_failed, evals, val_temp_left, val_temp_right = _update(
            value_and_gradients_function, val_left, val_right, val_c, f_lim)
        update_failed_case = (
            update_failed, lambda:
            (False, True, evals + 1, val_temp_left, val_temp_right))

        def _common_update(curr_left, curr_right):
            """Performs secant division to update the interval."""
            # Note that curr_left and curr_right may not satisfy opposite slope so
            # c_bar below may be outside the range [temp_left, temp_right].
            c_bar = _secant(curr_left.x, curr_right.x, curr_left.df,
                            curr_right.df)
            fc_bar, dfc_bar = value_and_gradients_function(c_bar)
            val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)
            common_update_failed = ~_is_finite(val_c_bar)
            outside_range = (c_bar < val_temp_left.x) | (val_temp_right.x <
                                                         c_bar)
            good_c_bar = _satisfies_wolfe(
                val_0,
                val_c_bar,
                f_lim,
                sufficient_decrease_param=sufficient_decrease_param,
                curvature_param=curvature_param)

            def _default_fn():
                # Perform yet another update on [temp_left, temp_right] using c_bar.
                failed, inner_evals, val_left_bar, val_right_bar = _update(
                    value_and_gradients_function, val_temp_left,
                    val_temp_right, val_c_bar, f_lim)
                return (False, failed, evals + inner_evals + 2, val_left_bar,
                        val_right_bar)

            return smart_cond.smart_case(
                [(common_update_failed, lambda:
                  (False, True, evals + 2, val_c_bar, val_c_bar)),
                 (outside_range, lambda:
                  (False, False, evals + 2, val_temp_left, val_temp_right)),
                 (good_c_bar, lambda:
                  (True, False, evals + 2, val_c_bar, val_c_bar))],
                default=_default_fn,
                exclusive=False)

        # This case checks if the value c has become the right end point
        # (i.e. c==temp_right).
        replace_right_case = (
            tf.equal(c, val_temp_right.x),
            lambda: _common_update(val_right, val_temp_right))
        # This case checks if the value c has become the left end point
        # (i.e. c==temp_left).
        replace_left_case = (tf.equal(c, val_temp_left.x),
                             lambda: _common_update(val_left, val_temp_left))
        default_fn = (lambda:
                      (False, False, evals + 1, val_temp_left, val_temp_right))
        return smart_cond.smart_case([
            secant_failed_case, converged_case, update_failed_case,
            replace_right_case, replace_left_case
        ],
                                     default=default_fn,
                                     exclusive=False)
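
The `_secant` helper used above is not shown on this page. Assuming it takes the usual secant step towards a zero of the derivative, as in Hager and Zhang (2006), it would compute the following; with df(a) < 0 and df(b) > 0 (the opposite slope conditions) the result lies in [a, b], matching the comment above:

def secant_sketch(a, b, dfa, dfb):
    # Zero of the straight line through (a, dfa) and (b, dfb).
    return (a * dfb - b * dfa) / (dfb - dfa)
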
Example #24
def _update(value_and_gradients_function, val_left, val_right, val_trial,
            f_lim):
    """Squeezes a bracketing interval containing the minimum.

  Given an interval which brackets a minimum and a point in that interval,
  finds a smaller nested interval which also brackets the minimum. If the
  supplied point does not lie in the bracketing interval, the current interval
  is returned.

  The requirement of the interval bracketing a minimum is expressed through the
  opposite slope conditions. Assume the left end point is 'a', the right
  end point is 'b', the function to be minimized is 'f' and the derivative is
  'df'. The update procedure relies on the following conditions being satisfied:

  '''
    f(a) <= f(0) + epsilon   (1)
    df(a) < 0                (2)
    df(b) > 0                (3)
  '''

  In the first condition, epsilon is a small positive constant. The condition
  demands that the function at the left end point be not much bigger than the
  starting point (i.e. 0). This is an easy to satisfy condition because by
  assumption, we are in a direction where the function value is decreasing.
  The second and third conditions together demand that there is at least one
  zero of the derivative in between a and b.

  In addition to the interval, the update algorithm requires a third point to
  be supplied. Usually, this point would lie within the interval [a, b]. If the
  point is outside this interval, the current interval is returned. If the
  point lies within the interval, the behaviour of the function and derivative
  value at this point is used to squeeze the original interval in a manner that
  preserves the opposite slope conditions.

  For further details of this component, see the procedure U0-U3 on page 123 of
  the [Hager and Zhang (2006)][2] article.

  Note that this function does not explicitly verify whether the opposite slope
  conditions are satisfied for the supplied interval. It is assumed that this
  is so.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval (labelled 'a'
      above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval (labelled 'b'
      above).
    val_trial: Instance of _FnDFn. The value and derivative of the function
      evaluated at the trial point to be used to shrink the interval (labelled
      'c' above).
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.

  Returns:
    update_failed: A boolean scalar `Tensor` indicating whether the objective
      function failed to yield a finite value at the trial points.
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made.
    val_left_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated right end point of the interval.
  """
    left, right = val_left.x, val_right.x
    trial, f_trial, df_trial = val_trial.x, val_trial.f, val_trial.df

    # If the trial point is not in the interval, do nothing.
    out_of_range_case = (
        ((trial < left) | (trial > right)),
        lambda: (False, tf.constant(0), val_left, val_right))

    # The new point is a valid right end point (has positive derivative).
    can_update_right = (df_trial >= 0, lambda:
                        (False, tf.constant(0), val_left, val_trial))

    # The new point is a valid left end point because it has negative slope
    # and the value at the point is not too large.
    can_update_left = (((df_trial < 0) & (f_trial <= f_lim)), lambda:
                       (False, tf.constant(0), val_trial, val_right))

    def _default_fn():
        return _bisection_update(value_and_gradients_function, val_left,
                                 val_trial, f_lim)

    return smart_cond.smart_case(
        [out_of_range_case, can_update_right, can_update_left],
        default=_default_fn,
        exclusive=False)
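The docstring of `_update` above walks through the U0-U3 case analysis in prose. The following is a minimal plain-Python sketch of the same logic, using ordinary floats instead of tensors; the names `FnDFn` and `update_sketch`, the toy example, and the bisection tolerance are illustrative assumptions, not part of the library.

import collections

FnDFn = collections.namedtuple('FnDFn', ['x', 'f', 'df'])


def update_sketch(value_and_gradient, val_left, val_right, val_trial, f_lim):
  """Shrinks [left, right] around the minimum, keeping opposite slopes."""
  # U0: the trial point lies outside the interval; keep the interval as is.
  if val_trial.x < val_left.x or val_trial.x > val_right.x:
    return val_left, val_right
  # U1: non-negative slope at the trial point; it is a valid right end point.
  if val_trial.df >= 0:
    return val_left, val_trial
  # U2: negative slope and a small enough value; it is a valid left end point.
  if val_trial.f <= f_lim:
    return val_trial, val_right
  # U3: negative slope but the value is too large; bisect on [left, trial]
  # until the opposite slope conditions are restored.
  left, right = val_left, val_trial
  while right.x - left.x > 1e-8:
    mid_x = 0.5 * (left.x + right.x)
    f_mid, df_mid = value_and_gradient(mid_x)
    mid = FnDFn(mid_x, f_mid, df_mid)
    if mid.df >= 0:
      return left, mid
    if mid.f <= f_lim:
      left = mid
    else:
      right = mid
  return left, right


# Example: squeeze a bracket of f(x) = (x - 1)**2 using a trial point at 0.25.
f = lambda x: ((x - 1.0) ** 2, 2.0 * (x - 1.0))
mk = lambda x: FnDFn(x, *f(x))
print(update_sketch(f, mk(0.0), mk(2.0), mk(0.25), f_lim=1.0 + 1e-6))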
Example No. 25
0
def _secant2(value_and_gradients_function,
             val_0,
             val_left,
             val_right,
             f_lim,
             sufficient_decrease_param=0.1,
             curvature_param=0.9,
             name=None):
  """Performs the secant square procedure of Hager Zhang.

  Given an interval that brackets a root of the derivative, this procedure
  updates both end points using two intermediate points generated by secant
  interpolation. For details see steps S1-S4 in [Hager and Zhang (2006)][2].

  The interval [a, b] must satisfy the opposite slope conditions described in
  the documentation for '_update'.

  Args:
    value_and_gradients_function: A Python callable that accepts a real
      scalar tensor and returns a tuple
      containing the value of the function and its derivative at that point.
    val_0: Instance of _FnDFn. The function and derivative value at 0.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval (labelled 'a'
      above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval (labelled 'b'
      above).
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to 'delta' in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'secant2' is used.

  Returns:
    converged: A scalar boolean `Tensor`. Indicates whether a point
      satisfying the Wolfe conditions has been found. If this is True, the
      interval will be degenerate (i.e. val_left_bar and val_right_bar below
      will be identical).
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made.
    val_left_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value and
      derivative at the updated right end point of the interval.
  """
  with tf.name_scope(name, 'secant2',
                     [val_0,
                      val_left,
                      val_right,
                      f_lim,
                      sufficient_decrease_param,
                      curvature_param]):
    a, dfa = val_left.x, val_left.df
    b, dfb = val_right.x, val_right.df

    c = _secant(a, b, dfa, dfb)  # This will always be s.t. a <= c <= b
    fc, dfc = value_and_gradients_function(c)
    val_c = _FnDFn(x=c, f=fc, df=dfc)
    good_c = _satisfies_wolfe(
        val_0, val_c, f_lim,
        sufficient_decrease_param=sufficient_decrease_param,
        curvature_param=curvature_param)

    # If we have found a point satisfying the Wolfe conditions,
    # we have converged.
    case1 = good_c, (lambda: (True, tf.constant(1), val_c, val_c))
    # The temp_right and temp_left are the values referred to as A and B in
    # Hager Zhang (2006).
    evals, val_temp_left, val_temp_right = _update(
        value_and_gradients_function, val_left, val_right, val_c, f_lim)

    def _common_update(curr_left, curr_right):
      """Performs secant division to update the interval."""
      # Note that curr_left and curr_right may not satisfy opposite slope so
      # c_bar below may be outside the range [temp_left, temp_right].
      c_bar = _secant(curr_left.x, curr_right.x, curr_left.df, curr_right.df)
      fc_bar, dfc_bar = value_and_gradients_function(c_bar)
      val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)
      outside_range = (c_bar < val_temp_left.x) | (val_temp_right.x < c_bar)
      good_c_bar = _satisfies_wolfe(
          val_0, val_c_bar, f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)
      def _default_fn():
        # Perform yet another update on [temp_left, temp_right] using c_bar.
        inner_evals, val_left_bar, val_right_bar = _update(
            value_and_gradients_function,
            val_temp_left,
            val_temp_right,
            val_c_bar,
            f_lim)
        return (False, evals + inner_evals + 2, val_left_bar, val_right_bar)
      return smart_cond.smart_case(
          [
              (outside_range,
               lambda: (False, evals + 2, val_temp_left, val_temp_right)),
              (good_c_bar,
               lambda: (True, evals + 2, val_c_bar, val_c_bar))
          ],
          default=_default_fn,
          exclusive=False)
    # This case checks if the value c has become the right end point
    # (i.e. c==temp_right).
    case2 = (tf.equal(c, val_temp_right.x),
             lambda: _common_update(val_right, val_temp_right))
    # This case checks if the value c has become the left end point
    # (i.e. c==temp_left).
    case3 = (tf.equal(c, val_temp_left.x),
             lambda: _common_update(val_left, val_temp_left))
    default_fn = lambda: (False, evals + 1, val_temp_left, val_temp_right)
    return smart_cond.smart_case([case1, case2, case3],
                                 default=default_fn, exclusive=False)
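The `_secant` helper that produces the trial point c is not shown in this excerpt. Assuming it implements the standard secant step of Hager and Zhang (2006), i.e. the zero of the line through (a, df(a)) and (b, df(b)), a plain-Python sketch would be as follows; `secant_sketch` is an illustrative name.

def secant_sketch(a, b, dfa, dfb):
  """Secant estimate of the zero of the derivative on [a, b]."""
  # Intersection with zero of the secant line through (a, dfa) and (b, dfb).
  return (a * dfb - b * dfa) / (dfb - dfa)


# With opposite slopes (dfa < 0 <= dfb) the denominator is positive and the
# result always lies in [a, b], which is what the comment "a <= c <= b" in
# `_secant2` relies on.
print(secant_sketch(0.0, 1.0, -1.0, 3.0))  # 0.25

Applying this step twice, once on the original interval and once on the updated one, is what gives the secant-squared procedure its name.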
Example No. 26
0
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
    """A single iteration of the Nelder Mead algorithm."""
    with tf.name_scope(name, 'nelder_mead_one_step'):
        domain_dtype = current_simplex.dtype.base_dtype
        order = tf.contrib.framework.argsort(current_objective_values,
                                             direction='ASCENDING',
                                             stable=True)
        (best_index, worst_index,
         second_worst_index) = order[0], order[-1], order[-2]

        worst_vertex = current_simplex[worst_index]

        (best_objective_value, worst_objective_value,
         second_worst_objective_value) = (
             current_objective_values[best_index],
             current_objective_values[worst_index],
             current_objective_values[second_worst_index])

        # Compute the centroid of the face opposite the worst vertex.
        face_centroid = tf.reduce_sum(current_simplex, axis=0) - worst_vertex
        face_centroid /= tf.cast(dim, domain_dtype)

        # Reflect the worst vertex through the opposite face.
        reflected = face_centroid + reflection * (face_centroid - worst_vertex)
        objective_at_reflected = objective_function(reflected)

        num_evaluations = 1
        has_converged = _check_convergence(current_simplex,
                                           current_simplex[best_index],
                                           best_objective_value,
                                           worst_objective_value,
                                           func_tolerance, position_tolerance)

        def _converged_fn():
            return (True, current_simplex, current_objective_values, 0)

        case0 = has_converged, _converged_fn
        accept_reflected = (
            (objective_at_reflected < second_worst_objective_value) &
            (objective_at_reflected >= best_objective_value))
        accept_reflected_fn = _accept_reflected_fn(current_simplex,
                                                   current_objective_values,
                                                   worst_index, reflected,
                                                   objective_at_reflected)
        case1 = accept_reflected, accept_reflected_fn
        do_expansion = objective_at_reflected < best_objective_value
        expansion_fn = _expansion_fn(objective_function, current_simplex,
                                     current_objective_values, worst_index,
                                     reflected, objective_at_reflected,
                                     face_centroid, expansion)
        case2 = do_expansion, expansion_fn
        do_outside_contraction = (
            (objective_at_reflected < worst_objective_value) &
            (objective_at_reflected >= second_worst_objective_value))
        outside_contraction_fn = _outside_contraction_fn(
            objective_function, current_simplex, current_objective_values,
            face_centroid, best_index, worst_index, reflected,
            objective_at_reflected, contraction, shrinkage,
            batch_evaluate_objective)
        case3 = do_outside_contraction, outside_contraction_fn
        default_fn = _inside_contraction_fn(
            objective_function, current_simplex, current_objective_values,
            face_centroid, best_index, worst_index, worst_objective_value,
            contraction, shrinkage, batch_evaluate_objective)
        (converged, next_simplex, next_objective_at_simplex,
         case_evals) = smart_cond.smart_case([case0, case1, case2, case3],
                                             default=default_fn,
                                             exclusive=False)
        next_simplex.set_shape(current_simplex.shape)
        next_objective_at_simplex.set_shape(current_objective_values.shape)
        return (converged, next_simplex, next_objective_at_simplex,
                num_evaluations + case_evals)
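The case dispatch in `nelder_mead_one_step` mirrors the textbook Nelder-Mead iteration: check convergence, accept the reflection, try an expansion, try an outside contraction, and otherwise fall back to an inside contraction, shrinking toward the best vertex when a contraction fails. Below is a self-contained NumPy sketch of that ordering with standard coefficients; the function name, the omission of the convergence check, and the lack of evaluation counting are illustrative simplifications, not the library code.

import numpy as np


def nelder_mead_step_sketch(simplex, objective, reflection=1.0, expansion=2.0,
                            contraction=0.5, shrinkage=0.5):
  """One textbook Nelder-Mead iteration on a (dim + 1, dim) simplex."""
  values = np.array([objective(v) for v in simplex])
  order = np.argsort(values)
  best, second_worst, worst = order[0], order[-2], order[-1]
  # Centroid of the face opposite the worst vertex.
  centroid = (simplex.sum(axis=0) - simplex[worst]) / (simplex.shape[0] - 1)

  reflected = centroid + reflection * (centroid - simplex[worst])
  f_reflected = objective(reflected)

  new_simplex = simplex.copy()
  if values[best] <= f_reflected < values[second_worst]:
    # Accept the reflected point.
    new_simplex[worst] = reflected
  elif f_reflected < values[best]:
    # Expansion: keep the better of the expanded and reflected points.
    expanded = centroid + expansion * (reflected - centroid)
    new_simplex[worst] = (expanded if objective(expanded) < f_reflected
                          else reflected)
  elif f_reflected < values[worst]:
    # Outside contraction; shrink toward the best vertex if it fails.
    contracted = centroid + contraction * (reflected - centroid)
    if objective(contracted) <= f_reflected:
      new_simplex[worst] = contracted
    else:
      new_simplex = simplex[best] + shrinkage * (simplex - simplex[best])
  else:
    # Inside contraction; shrink toward the best vertex if it fails.
    contracted = centroid + contraction * (simplex[worst] - centroid)
    if objective(contracted) < values[worst]:
      new_simplex[worst] = contracted
    else:
      new_simplex = simplex[best] + shrinkage * (simplex - simplex[best])
  return new_simplex


# Example: one step on a quadratic objective in two dimensions.
start = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(nelder_mead_step_sketch(start, lambda x: np.sum((x - 2.0) ** 2)))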