def testFalse(self):
  """Checks `smart_case` when the only predicate is the Python value False.

  The single case pairs `False` with `raise_exception`, and `default`
  produces the constant 1. Both the exclusive and the non-exclusive variants
  are expected to evaluate to 1.
  """

  def _one():
    return constant_op.constant(1)

  cases = [(False, raise_exception)]
  non_exclusive = smart_cond.smart_case(cases, default=_one, exclusive=False)
  exclusive = smart_cond.smart_case(cases, default=_one, exclusive=True)

  with session.Session() as sess:
    for result in (non_exclusive, exclusive):
      self.assertEqual(sess.run(result), 1)
def testTrue(self):
  """Checks `smart_case` when the first predicate is the Python value True.

  The first case yields the constant 1; the second case (predicated on the
  placeholder `x`) and the default both use `raise_exception`. Both variants
  are evaluated without feeding `x` and are expected to equal 1.
  """

  def _one():
    return constant_op.constant(1)

  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  cases = [(True, _one), (x == 0, raise_exception)]
  results = [
      smart_cond.smart_case(cases, default=raise_exception, exclusive=flag)
      for flag in (False, True)
  ]

  with session.Session() as sess:
    # No feed_dict necessary: `x` is deliberately left unfed.
    for result in results:
      self.assertEqual(sess.run(result), 1)
def testTrue(self):
  """Checks `smart_case` when the first predicate is the Python value True.

  The first case yields the constant 1; the second case (predicated on the
  placeholder `x`) and the default both use `raise_exception`. Both variants
  are evaluated without feeding `x` and are expected to equal 1.
  """

  def _one():
    return constant_op.constant(1)

  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  cases = [(True, _one), (x == 0, raise_exception)]
  results = [
      smart_cond.smart_case(cases, default=raise_exception, exclusive=flag)
      for flag in (False, True)
  ]

  with session.Session() as sess:
    # No feed_dict necessary: `x` is deliberately left unfed.
    for result in results:
      self.assertEqual(sess.run(result), 1)
def _body(_, evals, val_left, val_right):
  """Loop body to find the bracketing interval.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim` and
  `expansion_param`.

  Args:
    _: Unused first loop variable.
    evals: Running count of function evaluations.
    val_left: Current left end point (has `x`, `f`, `df` attributes).
    val_right: Current right end point (has `x`, `f`, `df` attributes).

  Returns:
    A tuple `(flag, evals, left, right)`. `flag` is True in the two branches
    that finish the bracketing and False in the expansion branch.
  """

  def _bracketed_fn():
    # The right point has a non-negative derivative, so [left, right]
    # encloses a minimum and we are done.
    return (True, evals, val_left, val_right)

  def _bisect_fn():
    # Only reached when the derivative at the right point is negative (the
    # previous case did not fire), i.e. the point is almost suitable as a
    # left end point, but the function value there exceeds `f_lim`.
    inner_evals, left, right = _bisection_update(
        value_and_gradients_function, val_0, val_right, f_lim)
    return (True, inner_evals + evals, left, right)

  def _expand_fn():
    # Push the right end point further out and shift the old right point
    # into the left slot.
    next_right = expansion_param * val_right.x
    f_next_right, df_next_right = value_and_gradients_function(next_right)
    val_next_right = _FnDFn(x=next_right, f=f_next_right, df=df_next_right)
    return (False, evals + 1, val_right, val_next_right)

  cases = [
      (val_right.df >= 0, _bracketed_fn),
      (val_right.f > f_lim, _bisect_fn),
  ]
  return smart_cond.smart_case(cases, default=_expand_fn, exclusive=False)
def _body(_, evals, val_left, val_right):
  """Loop body to find the bracketing interval.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim` and
  `expansion_param`.

  Args:
    _: Unused first loop variable.
    evals: Running count of function evaluations.
    val_left: Current left end point (has `x`, `f`, `df` attributes).
    val_right: Current right end point (has `x`, `f`, `df` attributes).

  Returns:
    A tuple `(flag, evals, left, right)`. `flag` is True in the two branches
    that finish the bracketing and False in the expansion branch.
  """

  def _bracketed_fn():
    # The right point has a non-negative derivative, so [left, right]
    # encloses a minimum and we are done.
    return (True, evals, val_left, val_right)

  def _bisect_fn():
    # Only reached when the derivative at the right point is negative (the
    # previous case did not fire), i.e. the point is almost suitable as a
    # left end point, but the function value there exceeds `f_lim`.
    inner_evals, left, right = _bisection_update(
        value_and_gradients_function, val_0, val_right, f_lim)
    return (True, inner_evals + evals, left, right)

  def _expand_fn():
    # Push the right end point further out and shift the old right point
    # into the left slot.
    next_right = expansion_param * val_right.x
    f_next_right, df_next_right = value_and_gradients_function(next_right)
    val_next_right = _FnDFn(x=next_right, f=f_next_right, df=df_next_right)
    return (False, evals + 1, val_right, val_next_right)

  cases = [
      (val_right.df >= 0, _bracketed_fn),
      (val_right.f > f_lim, _bisect_fn),
  ]
  return smart_cond.smart_case(cases, default=_expand_fn, exclusive=False)
def _common_update(curr_left, curr_right):
  """Performs secant division to update the interval.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim`,
  `val_temp_left`, `val_temp_right`, `evals`, `sufficient_decrease_param`
  and `curvature_param`.

  Args:
    curr_left: Left point fed to the secant step (has `x` and `df`).
    curr_right: Right point fed to the secant step (has `x` and `df`).

  Returns:
    A tuple `(flag, evals, left, right)`; `flag` is True only when the
    secant point satisfies `_satisfies_wolfe`.
  """
  # Note that curr_left and curr_right may not satisfy opposite slope, so
  # c_bar below may be outside the range [temp_left, temp_right].
  c_bar = _secant(curr_left.x, curr_right.x, curr_left.df, curr_right.df)
  fc_bar, dfc_bar = value_and_gradients_function(c_bar)
  val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)

  below_left = c_bar < val_temp_left.x
  above_right = val_temp_right.x < c_bar
  outside_range = below_left | above_right

  good_c_bar = _satisfies_wolfe(
      val_0,
      val_c_bar,
      f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)

  # NOTE(review): the constant 2 added to `evals` in every branch presumably
  # covers the evaluation above plus one made by the enclosing function --
  # confirm against the caller.
  def _outside_fn():
    return (False, evals + 2, val_temp_left, val_temp_right)

  def _accept_fn():
    return (True, evals + 2, val_c_bar, val_c_bar)

  def _squeeze_fn():
    # Perform yet another update on [temp_left, temp_right] using c_bar.
    inner_evals, val_left_bar, val_right_bar = _update(
        value_and_gradients_function, val_temp_left, val_temp_right,
        val_c_bar, f_lim)
    return (False, evals + inner_evals + 2, val_left_bar, val_right_bar)

  cases = [(outside_range, _outside_fn), (good_c_bar, _accept_fn)]
  return smart_cond.smart_case(cases, default=_squeeze_fn, exclusive=False)
def loop_body(_, failed, eval_count, val_left, val_right):  # pylint:disable=unused-argument
  """Updates the right end point to satisfy the opposite slope conditions.

  Captured by closure: `value_and_gradients_function` and `f_lim`.

  Args:
    _: Unused first loop variable.
    failed: Unused; the returned failure flag is recomputed here.
    eval_count: Running count of function evaluations.
    val_left: Current left end point (has `x`, `f`, `df` attributes).
    val_right: Current right end point (has `x`, `f`, `df` attributes).

  Returns:
    A tuple `(found, failed, eval_count + 1, left, right)`.
  """

  def _outcome(found, has_failed, left, right):
    """Builds a branch callable; every branch costs one evaluation."""
    return lambda: (found, has_failed, eval_count + 1, left, right)

  mid_pt = (val_left.x + val_right.x) / 2
  f_mid, df_mid = value_and_gradients_function(mid_pt)
  val_mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)

  # The case conditions, in priority order.
  mid_not_finite = ~_is_finite(val_mid)
  # The new point can be a valid right end point.
  valid_right = df_mid >= 0
  # It is a valid left end point.
  valid_left = (df_mid < 0) & (f_mid <= f_lim)

  cases = (
      # The evaluation at the midpoint failed: keep the interval and flag it.
      (mid_not_finite, _outcome(False, True, val_left, val_right)),
      (valid_right, _outcome(True, False, val_left, val_mid)),
      # Note that we must return found = False in this case because our
      # target is to find a good right end point and improvements to the
      # left end point are coincidental. Hence the loop must continue until
      # we exit via the valid_right case.
      (valid_left, _outcome(False, False, val_mid, val_right)),
  )
  # To be explicit, the default applies when the new point has a negative
  # slope but the function value at that point is too high. This is the same
  # situation with which we started the loop in the first place. Hence we
  # should just replace the old right end point and continue to loop.
  default_fn = _outcome(False, False, val_left, val_mid)
  return smart_cond.smart_case(cases, default=default_fn, exclusive=False)
def _common_update(curr_left, curr_right):
  """Performs secant division to update the interval.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim`,
  `val_temp_left`, `val_temp_right`, `evals`, `sufficient_decrease_param`
  and `curvature_param`.

  Args:
    curr_left: Left point fed to the secant step (has `x` and `df`).
    curr_right: Right point fed to the secant step (has `x` and `df`).

  Returns:
    A tuple `(flag, failed, evals, left, right)`. `flag` is True only when
    the secant point satisfies `_satisfies_wolfe`; `failed` is True when the
    secant point is not finite, or is whatever `_update` reports.
  """
  # Note that curr_left and curr_right may not satisfy opposite slope, so
  # c_bar below may be outside the range [temp_left, temp_right].
  c_bar = _secant(curr_left.x, curr_right.x, curr_left.df, curr_right.df)
  fc_bar, dfc_bar = value_and_gradients_function(c_bar)
  val_c_bar = _FnDFn(x=c_bar, f=fc_bar, df=dfc_bar)

  common_update_failed = ~_is_finite(val_c_bar)

  below_left = c_bar < val_temp_left.x
  above_right = val_temp_right.x < c_bar
  outside_range = below_left | above_right

  good_c_bar = _satisfies_wolfe(
      val_0,
      val_c_bar,
      f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)

  # NOTE(review): the constant 2 added to `evals` in every branch presumably
  # covers the evaluation above plus one made by the enclosing function --
  # confirm against the caller.
  def _failed_fn():
    return (False, True, evals + 2, val_c_bar, val_c_bar)

  def _outside_fn():
    return (False, False, evals + 2, val_temp_left, val_temp_right)

  def _accept_fn():
    return (True, False, evals + 2, val_c_bar, val_c_bar)

  def _squeeze_fn():
    # Perform yet another update on [temp_left, temp_right] using c_bar.
    failed, inner_evals, val_left_bar, val_right_bar = _update(
        value_and_gradients_function, val_temp_left, val_temp_right,
        val_c_bar, f_lim)
    return (False, failed, evals + inner_evals + 2,
            val_left_bar, val_right_bar)

  cases = [
      (common_update_failed, _failed_fn),
      (outside_range, _outside_fn),
      (good_c_bar, _accept_fn),
  ]
  return smart_cond.smart_case(cases, default=_squeeze_fn, exclusive=False)
def _body(_, failed, evals, val_left, val_right):  # pylint:disable=unused-argument
  """Loop body to find the bracketing interval.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim` and
  `expansion_param`.

  Args:
    _: Unused first loop variable.
    failed: Unused; the returned failure flag is recomputed here.
    evals: Running count of function evaluations.
    val_left: Current left end point (has `x`, `f`, `df` attributes).
    val_right: Current right end point (has `x`, `f`, `df` attributes).

  Returns:
    A tuple `(flag, failed, evals, left, right)`. `flag` is True in the two
    branches that finish the bracketing.
  """

  def _not_finite_fn():
    # The function or gradient is not finite: quit and flag the failure.
    return (False, True, evals, val_left, val_right)

  def _bracketed_fn():
    # The right point has a non-negative derivative, so [left, right]
    # encloses a minimum and we are done.
    return (True, False, evals, val_left, val_right)

  def _bisect_fn():
    # Only reached when the derivative at the right point is negative (the
    # previous case did not fire), i.e. the point is almost suitable as a
    # left end point, but the function value there exceeds `f_lim`.
    failed, inner_evals, left, right = _bisection_update(
        value_and_gradients_function, val_0, val_right, f_lim)
    return (True, failed, inner_evals + evals, left, right)

  def _expand_fn():
    # Push the right end point further out and shift the old right point
    # into the left slot; a non-finite evaluation is reported as failure.
    next_right = expansion_param * val_right.x
    f_next_right, df_next_right = value_and_gradients_function(next_right)
    val_next_right = _FnDFn(x=next_right, f=f_next_right, df=df_next_right)
    failed = ~_is_finite(val_next_right)
    return (False, failed, evals + 1, val_right, val_next_right)

  cases = [
      (~_is_finite(val_left, val_right), _not_finite_fn),
      (val_right.df >= 0, _bracketed_fn),
      (val_right.f > f_lim, _bisect_fn),
  ]
  return smart_cond.smart_case(cases, default=_expand_fn, exclusive=False)
def testMix(self):
  """Checks `smart_case` with a mix of tensor and Python-bool predicates.

  The predicates are, in order: a comparison on the placeholder `x`, a
  comparison on the constant `y`, and the Python values False and True.
  Feeding `x = 2` is expected to give 1 (first case); feeding `x = 0` is
  expected to give 3 (last case).
  """
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  y = constant_op.constant(10)
  cases = [
      (x > 1, lambda: constant_op.constant(1)),
      (y < 1, raise_exception),
      (False, raise_exception),
      (True, lambda: constant_op.constant(3)),
  ]
  z = smart_cond.smart_case(cases, default=raise_exception)

  with session.Session() as sess:
    for fed_value, expected in ((2, 1), (0, 3)):
      self.assertEqual(sess.run(z, feed_dict={x: fed_value}), expected)
def testMix(self):
  """Checks `smart_case` with a mix of tensor and Python-bool predicates.

  The predicates are, in order: a comparison on the placeholder `x`, a
  comparison on the constant `y`, and the Python values False and True.
  Feeding `x = 2` is expected to give 1 (first case); feeding `x = 0` is
  expected to give 3 (last case).
  """
  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  y = constant_op.constant(10)
  cases = [
      (x > 1, lambda: constant_op.constant(1)),
      (y < 1, raise_exception),
      (False, raise_exception),
      (True, lambda: constant_op.constant(3)),
  ]
  z = smart_cond.smart_case(cases, default=raise_exception)

  with session.Session() as sess:
    for fed_value, expected in ((2, 1), (0, 3)):
      self.assertEqual(sess.run(z, feed_dict={x: fed_value}), expected)
def testMix(self):
  """Checks `smart_case` with a mix of tensor and Python-bool predicates.

  The predicates are, in order: a comparison on the placeholder `x`, a
  comparison on the constant `y`, and the Python values False and True.
  Feeding `x = 2` is expected to give 1 (first case); feeding `x = 0` is
  expected to give 3 (last case).

  The test is reported as skipped when the C API is disabled.
  """
  # Constant expression evaluation only works with the C API enabled. Skip
  # explicitly so the run does not count this as a pass when nothing was
  # actually checked.
  if not ops._USE_C_API:  # pylint: disable=protected-access
    self.skipTest('Constant expression evaluation requires the C API.')

  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  y = constant_op.constant(10)
  conditions = [
      (x > 1, lambda: constant_op.constant(1)),
      (y < 1, raise_exception),
      (False, raise_exception),
      (True, lambda: constant_op.constant(3)),
  ]
  z = smart_cond.smart_case(conditions, default=raise_exception)

  with session.Session() as sess:
    self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
    self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
def _body(converged, evals, left, right):
  """Line search loop body: one `_secant2` step plus an optional extra update.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim`,
  `sufficient_decrease_param`, `curvature_param`, `shrinkage_param` and
  `_update_with_mid`.

  Args:
    converged: Incoming flag; overwritten by the result of `_secant2`.
    evals: Running count of function evaluations.
    left: Current left end point (has an `x` attribute).
    right: Current right end point (has an `x` attribute).

  Returns:
    A tuple `(converged, evals, left, right)` for the next iteration.
  """
  converged, secant2_evals, next_left, next_right = _secant2(
      value_and_gradients_function,
      val_0,
      left,
      right,
      f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)
  evals += secant2_evals

  def _converged_fn():
    # If converged, then do no further processing.
    return (True, evals, next_left, next_right)

  def _update_with_mid_fn():
    return _update_with_mid(evals, next_left, next_right)

  def _continue_fn():
    return (False, evals, next_left, next_right)

  # The interval did not shrink by the required proportion.
  new_width = next_right.x - next_left.x
  old_width = right.x - left.x
  insufficient_shrinkage = new_width > shrinkage_param * old_width

  cases = [
      (converged, _converged_fn),
      (insufficient_shrinkage, _update_with_mid_fn),
  ]
  return smart_cond.smart_case(cases, default=_continue_fn)
def testMix(self):
  """Checks `smart_case` with a mix of tensor and Python-bool predicates.

  The predicates are, in order: a comparison on the placeholder `x`, a
  comparison on the constant `y`, and the Python values False and True.
  Feeding `x = 2` is expected to give 1 (first case); feeding `x = 0` is
  expected to give 3 (last case).

  The test is reported as skipped when the C API is disabled.
  """
  # Constant expression evaluation only works with the C API enabled. Skip
  # explicitly so the run does not count this as a pass when nothing was
  # actually checked.
  if not ops._USE_C_API:  # pylint: disable=protected-access
    self.skipTest('Constant expression evaluation requires the C API.')

  x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  y = constant_op.constant(10)
  conditions = [
      (x > 1, lambda: constant_op.constant(1)),
      (y < 1, raise_exception),
      (False, raise_exception),
      (True, lambda: constant_op.constant(3)),
  ]
  z = smart_cond.smart_case(conditions, default=raise_exception)

  with session.Session() as sess:
    self.assertEqual(sess.run(z, feed_dict={x: 2}), 1)
    self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
def _body(converged, failed, evals, left, right):  # pylint:disable=unused-argument
  """Line search loop body.

  Runs one `_secant2` step and, if that neither converged nor failed and the
  interval did not shrink enough, an extra `_update_with_mid` step.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim`,
  `sufficient_decrease_param`, `curvature_param`, `shrinkage_param` and
  `_update_with_mid`.

  Args:
    converged: Incoming flag; overwritten by the result of `_secant2`.
    failed: Incoming flag; overwritten by the result of `_secant2`.
    evals: Running count of function evaluations.
    left: Current left end point (has an `x` attribute).
    right: Current right end point (has an `x` attribute).

  Returns:
    A tuple `(converged, failed, evals, left, right)` for the next iteration.
  """
  converged, failed, secant2_evals, next_left, next_right = _secant2(
      value_and_gradients_function,
      val_0,
      left,
      right,
      f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)
  evals += secant2_evals

  def _finished_fn():
    # If converged or failed, then do no further processing.
    return (converged, failed, evals, next_left, next_right)

  def _update_with_mid_fn():
    return _update_with_mid(evals, next_left, next_right)

  def _continue_fn():
    return (False, False, evals, next_left, next_right)

  finished = converged | failed
  # The interval did not shrink by the required proportion.
  new_width = next_right.x - next_left.x
  old_width = right.x - left.x
  insufficient_shrinkage = new_width > shrinkage_param * old_width

  cases = [
      (finished, _finished_fn),
      (insufficient_shrinkage, _update_with_mid_fn),
  ]
  return smart_cond.smart_case(cases, default=_continue_fn)
def _body(converged, evals, left, right):
  """Line search loop body: one `_secant2` step plus an optional extra update.

  Captured by closure: `value_and_gradients_function`, `val_0`, `f_lim`,
  `sufficient_decrease_param`, `curvature_param`, `shrinkage_param` and
  `_update_with_mid`.

  Args:
    converged: Incoming flag; overwritten by the result of `_secant2`.
    evals: Running count of function evaluations.
    left: Current left end point (has an `x` attribute).
    right: Current right end point (has an `x` attribute).

  Returns:
    A tuple `(converged, evals, left, right)` for the next iteration.
  """
  converged, secant2_evals, next_left, next_right = _secant2(
      value_and_gradients_function,
      val_0,
      left,
      right,
      f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)
  evals += secant2_evals

  def _converged_fn():
    # If converged, then do no further processing.
    return (True, evals, next_left, next_right)

  def _update_with_mid_fn():
    return _update_with_mid(evals, next_left, next_right)

  def _continue_fn():
    return (False, evals, next_left, next_right)

  # The interval did not shrink by the required proportion.
  new_width = next_right.x - next_left.x
  old_width = right.x - left.x
  insufficient_shrinkage = new_width > shrinkage_param * old_width

  cases = [
      (converged, _converged_fn),
      (insufficient_shrinkage, _update_with_mid_fn),
  ]
  return smart_cond.smart_case(cases, default=_continue_fn)
def loop_body(_, eval_count, val_left, val_right):
  """Updates the right end point to satisfy the opposite slope conditions.

  Captured by closure: `value_and_gradients_function` and `f_lim`.

  Args:
    _: Unused first loop variable.
    eval_count: Running count of function evaluations.
    val_left: Current left end point (has `x`, `f`, `df` attributes).
    val_right: Current right end point (has `x`, `f`, `df` attributes).

  Returns:
    A tuple `(found, eval_count + 1, left, right)`.
  """

  def _outcome(found, left, right):
    """Builds a branch callable; every branch costs one evaluation."""
    return lambda: (found, eval_count + 1, left, right)

  mid_pt = (val_left.x + val_right.x) / 2
  f_mid, df_mid = value_and_gradients_function(mid_pt)

  # The case conditions, in priority order.
  # The new point can be a valid right end point.
  valid_right = df_mid >= 0
  # It is a valid left end point.
  valid_left = (df_mid < 0) & (f_mid <= f_lim)

  val_mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)

  cases = (
      (valid_right, _outcome(True, val_left, val_mid)),
      # Note that we must return found = False in this case because our
      # target is to find a good right end point and improvements to the
      # left end point are coincidental. Hence the loop must continue until
      # we exit via the valid_right case.
      (valid_left, _outcome(False, val_mid, val_right)),
  )
  # To be explicit, the default applies when the new point has a negative
  # slope but the function value at that point is too high. This is the same
  # situation with which we started the loop in the first place. Hence we
  # should just replace the old right end point and continue to loop.
  default_fn = _outcome(False, val_left, val_mid)
  return smart_cond.smart_case(cases, default=default_fn, exclusive=False)
def _body(_,
          failed,  # pylint: disable=unused-argument
          num_iterations,
          total_evals,
          position,
          objective_value,
          objective_gradient,
          inv_hessian_estimate):
  """Main optimization loop.

  Performs a Hager-Zhang line search along the current search direction and
  then selects one of three outcomes: failure, convergence, or continue.

  Captured by closure: `value_and_gradients_function`, `tolerance` and
  `dtype`.

  Args:
    _: Unused first loop variable.
    failed: Unused; the returned failure flag is recomputed here.
    num_iterations: Iteration counter; every outcome reports it plus one.
    total_evals: Running count of objective evaluations.
    position: Current position.
    objective_value: Objective value at `position`.
    objective_gradient: Objective gradient at `position`.
    inv_hessian_estimate: Current estimate of the inverse Hessian.

  Returns:
    A `BfgsOptimizerResults` describing the state after this iteration.
  """
  direction = _get_search_direction(inv_hessian_estimate, objective_gradient)
  restricted_fn = _restrict_along_direction(
      value_and_gradients_function, position, direction)
  slope_at_start = tf.reduce_sum(objective_gradient * direction)

  ls_result = linesearch.hager_zhang(
      restricted_fn,
      initial_step_size=tf.constant(1, dtype=dtype),
      objective_at_zero=objective_value,
      grad_objective_at_zero=slope_at_start)

  # If the line search failed, then quit at this point, leaving the position
  # and the inverse Hessian estimate untouched.
  failed_retval = BfgsOptimizerResults(
      converged=False,
      failed=True,
      num_iterations=num_iterations + 1,
      num_objective_evaluations=total_evals + ls_result.func_evals,
      position=position,
      objective_value=objective_value,
      objective_gradient=objective_gradient,
      inverse_hessian_estimate=inv_hessian_estimate)

  # Fail if the objective value is not finite or the line search failed.
  ls_ok = tf.is_finite(objective_value) & ls_result.converged
  ls_failed_case = (~ls_ok, lambda: failed_retval)

  # If the line search didn't fail, then either we need to continue
  # searching or need to stop because we have converged.
  position_delta = direction * ls_result.left_pt
  next_position = position + position_delta
  next_objective, next_objective_gradient = value_and_gradients_function(
      next_position)
  grad_norm = tf.norm(next_objective_gradient, ord=2)
  has_converged = grad_norm <= tolerance

  grad_delta = next_objective_gradient - objective_gradient
  updated_inv_hessian = _bfgs_inv_hessian_update(
      grad_delta, position_delta, inv_hessian_estimate)
  updated_inv_hessian.set_shape(inv_hessian_estimate.shape)

  # Fields shared by the "converged" and "continue" outcomes.
  next_state = dict(
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)

  converged_retval = BfgsOptimizerResults(
      converged=tf.constant(True, name='converged'),
      failed=tf.constant(False, name='failed'),
      num_iterations=tf.convert_to_tensor(
          num_iterations + 1, name='num_iterations'),
      num_objective_evaluations=tf.convert_to_tensor(
          total_evals + ls_result.func_evals + 1,
          name='num_objective_evaluations'),
      **next_state)
  converged_case = (has_converged, lambda: converged_retval)

  default_retval = BfgsOptimizerResults(
      converged=tf.constant(False, name='converged'),
      failed=tf.constant(False, name='failed'),
      num_iterations=tf.convert_to_tensor(
          num_iterations + 1, name='num_iterations'),
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      **next_state)

  return smart_cond.smart_case(
      [ls_failed_case, converged_case],
      default=lambda: default_retval,
      exclusive=False)
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
  """A single iteration of the Nelder Mead algorithm.

  Ranks the simplex vertices by objective value, reflects the worst vertex
  through the centroid of the opposite face and evaluates the objective
  there. It then selects one outcome, first match wins: already converged,
  accept the reflected point, expansion, outside contraction, or (default)
  inside contraction.

  Args:
    current_simplex: `Tensor` holding the simplex vertices, indexed by vertex
      along the first axis.
    current_objective_values: `Tensor` of objective values, one per vertex.
    objective_function: Callable evaluated at the reflected point and passed
      on to the branch builders.
    dim: Divisor applied to the face centroid sum (cast to the simplex dtype).
    func_tolerance: Forwarded to `_check_convergence`.
    position_tolerance: Forwarded to `_check_convergence`.
    batch_evaluate_objective: Forwarded to the contraction branch builders.
    reflection: Coefficient scaling the reflection of the worst vertex.
    expansion: Forwarded to `_expansion_fn`.
    contraction: Forwarded to the contraction branch builders.
    shrinkage: Forwarded to the contraction branch builders.
    name: Optional name scope; defaults to 'nelder_mead_one_step'.

  Returns:
    A tuple `(converged, next_simplex, next_objective_at_simplex, evals)`,
    where `evals` is one (for the reflected point) plus the count reported by
    the selected branch.
  """
  with tf.name_scope(name, 'nelder_mead_one_step'):
    simplex_dtype = current_simplex.dtype.base_dtype
    ranking = tf.contrib.framework.argsort(
        current_objective_values, direction='ASCENDING', stable=True)
    best_index = ranking[0]
    worst_index = ranking[-1]
    second_worst_index = ranking[-2]

    worst_vertex = current_simplex[worst_index]
    best_objective_value = current_objective_values[best_index]
    worst_objective_value = current_objective_values[worst_index]
    second_worst_objective_value = (
        current_objective_values[second_worst_index])

    # Compute the centroid of the face opposite the worst vertex.
    face_sum = tf.reduce_sum(current_simplex, axis=0) - worst_vertex
    face_centroid = face_sum / tf.cast(dim, simplex_dtype)

    # Reflect the worst vertex through the opposite face.
    reflected = face_centroid + reflection * (face_centroid - worst_vertex)
    objective_at_reflected = objective_function(reflected)

    cases = []

    has_converged = _check_convergence(current_simplex,
                                       current_simplex[best_index],
                                       best_objective_value,
                                       worst_objective_value,
                                       func_tolerance,
                                       position_tolerance)

    def _converged_fn():
      # Nothing changes and no further evaluations are made.
      return (True, current_simplex, current_objective_values, 0)

    cases.append((has_converged, _converged_fn))

    # The reflected point is better than the second worst but not the best.
    accept_reflected = (
        (objective_at_reflected < second_worst_objective_value) &
        (objective_at_reflected >= best_objective_value))
    accept_reflected_fn = _accept_reflected_fn(current_simplex,
                                               current_objective_values,
                                               worst_index,
                                               reflected,
                                               objective_at_reflected)
    cases.append((accept_reflected, accept_reflected_fn))

    # The reflected point beats the current best.
    do_expansion = objective_at_reflected < best_objective_value
    expansion_fn = _expansion_fn(objective_function,
                                 current_simplex,
                                 current_objective_values,
                                 worst_index,
                                 reflected,
                                 objective_at_reflected,
                                 face_centroid,
                                 expansion)
    cases.append((do_expansion, expansion_fn))

    # The reflected point lies between the second worst and the worst.
    do_outside_contraction = (
        (objective_at_reflected < worst_objective_value) &
        (objective_at_reflected >= second_worst_objective_value))
    outside_contraction_fn = _outside_contraction_fn(
        objective_function,
        current_simplex,
        current_objective_values,
        face_centroid,
        best_index,
        worst_index,
        reflected,
        objective_at_reflected,
        contraction,
        shrinkage,
        batch_evaluate_objective)
    cases.append((do_outside_contraction, outside_contraction_fn))

    inside_contraction_fn = _inside_contraction_fn(objective_function,
                                                   current_simplex,
                                                   current_objective_values,
                                                   face_centroid,
                                                   best_index,
                                                   worst_index,
                                                   worst_objective_value,
                                                   contraction,
                                                   shrinkage,
                                                   batch_evaluate_objective)

    outcome = smart_cond.smart_case(
        cases, default=inside_contraction_fn, exclusive=False)
    converged, next_simplex, next_objective_at_simplex, case_evals = outcome

    # Restore the static shapes on the selected outputs.
    next_simplex.set_shape(current_simplex.shape)
    next_objective_at_simplex.set_shape(current_objective_values.shape)

    # One evaluation was spent on the reflected point above.
    return (converged,
            next_simplex,
            next_objective_at_simplex,
            1 + case_evals)
def _body(_,
          failed,  # pylint: disable=unused-argument
          total_evals,
          position,
          objective_value,
          objective_gradient,
          inv_hessian_estimate):
  """Main optimization loop.

  Performs a Hager-Zhang line search along the current search direction and
  then selects one of three outcomes: failure, convergence, or continue.

  Captured by closure: `value_and_gradients_function`, `tolerance` and
  `dtype`.

  Args:
    _: Unused first loop variable.
    failed: Unused; the returned failure flag is recomputed here.
    total_evals: Running count of objective evaluations.
    position: Current position.
    objective_value: Objective value at `position`.
    objective_gradient: Objective gradient at `position`.
    inv_hessian_estimate: Current estimate of the inverse Hessian.

  Returns:
    A `BfgsOptimizerResults` describing the state after this iteration.
  """
  direction = _get_search_direction(inv_hessian_estimate, objective_gradient)
  restricted_fn = _restrict_along_direction(
      value_and_gradients_function, position, direction)
  slope_at_start = tf.reduce_sum(objective_gradient * direction)

  ls_result = linesearch.hager_zhang(
      restricted_fn,
      initial_step_size=tf.constant(1, dtype=dtype),
      objective_at_zero=objective_value,
      grad_objective_at_zero=slope_at_start)

  # If the line search failed, then quit at this point, leaving the position
  # and the inverse Hessian estimate untouched.
  failed_retval = BfgsOptimizerResults(
      converged=False,
      failed=True,
      num_objective_evaluations=total_evals + ls_result.func_evals,
      position=position,
      objective_value=objective_value,
      objective_gradient=objective_gradient,
      inverse_hessian_estimate=inv_hessian_estimate)
  ls_failed_case = (~ls_result.converged, lambda: failed_retval)

  # If the line search didn't fail, then either we need to continue
  # searching or need to stop because we have converged.
  position_delta = direction * ls_result.left_pt
  next_position = position + position_delta
  next_objective, next_objective_gradient = value_and_gradients_function(
      next_position)
  grad_norm = tf.norm(next_objective_gradient, ord=2)
  has_converged = grad_norm <= tolerance

  grad_delta = next_objective_gradient - objective_gradient
  updated_inv_hessian = _bfgs_inv_hessian_update(
      grad_delta, position_delta, inv_hessian_estimate)
  updated_inv_hessian.set_shape(inv_hessian_estimate.shape)

  # Fields shared by the "converged" and "continue" outcomes.
  next_state = dict(
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)

  converged_retval = BfgsOptimizerResults(
      converged=True,
      failed=False,
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      **next_state)
  converged_case = (has_converged, lambda: converged_retval)

  default_retval = BfgsOptimizerResults(
      converged=False,
      failed=False,
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      **next_state)

  return smart_cond.smart_case(
      [ls_failed_case, converged_case],
      default=lambda: default_retval,
      exclusive=False)
def hager_zhang(value_and_gradients_function,
                initial_step_size=None,
                objective_at_zero=None,
                grad_objective_at_zero=None,
                objective_at_initial_step_size=None,
                grad_objective_at_initial_step_size=None,
                threshold_use_approximate_wolfe_condition=1e-6,
                shrinkage_param=0.66,
                expansion_param=5.0,
                sufficient_decrease_param=0.1,
                curvature_param=0.9,
                name=None):
  """The Hager Zhang line search algorithm.

  Performs an inexact line search based on the algorithm of
  [Hager and Zhang (2006)][2].
  The univariate objective function `value_and_gradients_function` is
  typically generated by projecting a multivariate objective function along
  a search direction. Suppose the multivariate function to be minimized is
  `g(x1,x2, .. xn)`. Let (d1, d2, ..., dn) be the direction along which we
  wish to perform a line search. Then the projected univariate function to
  be used for line search is

  ```None
    f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a)
  ```

  The directional derivative along (d1, d2, ..., dn) is needed for this
  procedure. This also corresponds to the derivative of the projected
  function `f(a)` with respect to `a`. Note that this derivative must be
  negative for `a = 0` if the direction is a descent direction.

  The usual stopping criteria for the line search is the satisfaction of the
  (weak) Wolfe conditions. For details of the Wolfe conditions, see
  ref. [3]. On a finite precision machine, the exact Wolfe conditions can
  be difficult to satisfy when one is very close to the minimum and as argued
  by [Hager and Zhang (2005)][1], one can only expect the minimum to be
  determined within square root of machine precision. To improve the
  situation, they propose to replace the Wolfe conditions with an
  approximate version depending on the derivative of the function which is
  applied only when one is very close to the minimum. The following
  algorithm implements this enhanced scheme.

  ### Usage:

  Primary use of line search methods is as an internal component of a class
  of optimization algorithms (called line search based methods as opposed to
  trust region methods). Hence, the end user will typically not want to
  access line search directly. In particular, inexact line search should not
  be confused with a univariate minimization method. The stopping criteria of
  line search is the satisfaction of Wolfe conditions and not the discovery
  of the minimum of the function.

  With this caveat in mind, the following example illustrates the standalone
  usage of the line search.

  ```python
    # Define a quadratic target with minimum at 1.3.
    value_and_gradients_function = lambda x: ((x - 1.3) ** 2, 2 * (x-1.3))
    # Set initial step size.
    step_size = tf.constant(0.1)
    ls_result = tfp.optimizer.linesearch.hager_zhang(
        value_and_gradients_function, initial_step_size=step_size)
    # Evaluate the results.
    with tf.Session() as session:
      results = session.run(ls_result)
      # Ensure convergence.
      assert(results.converged)
      # If the line search converged, the left and the right ends of the
      # bracketing interval are identical.
      assert(results.left_pt == results.right_pt)
      # Print the number of evaluations and the final step size.
      print("Final Step Size: %f, Evaluation: %d" % (results.left_pt,
                                                     results.func_evals))
  ```

  ### References:
  [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with
    guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1,
    pp. 170-172. 2005.
    https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf

  [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate
    gradient method with guaranteed descent. ACM Transactions on Mathematical
    Software, Vol 32., 1, pp. 113-137. 2006.
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf

  [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series
    in Operations Research. pp 33-36. 2006

  Args:
    value_and_gradients_function: A Python callable that accepts a real
      scalar tensor and returns a tuple of scalar tensors of real dtype
      containing the value of the function and its derivative at that point.
      In usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but
      should be a descent direction (i.e. the derivative of the projected
      univariate function must be negative at 0.).
    initial_step_size: (Optional) Scalar positive `Tensor` of real dtype. The
      initial value to try to bracket the minimum. Default is `1.` as a
      float32. Note that this point need not necessarily bracket the minimum
      for the line search to work correctly but the supplied value must be
      greater than 0. A good initial value will make the search converge
      faster.
    objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If supplied,
      the value of the function at `0.`. If not supplied, it will be
      computed.
    grad_objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If
      supplied, the derivative of the function at `0.`. If not supplied, it
      will be computed.
    objective_at_initial_step_size: (Optional) Scalar `Tensor` of real dtype.
      If supplied, the value of the function at `initial_step_size`. If not
      supplied, it will be computed.
    grad_objective_at_initial_step_size: (Optional) Scalar `Tensor` of real
      dtype. If supplied, the derivative of the function at
      `initial_step_size`. If not supplied, it will be computed.
    threshold_use_approximate_wolfe_condition: Scalar positive `Tensor` of
      real dtype. Corresponds to the parameter 'epsilon' in
      [Hager and Zhang (2006)][2]. Used to estimate the threshold at which
      the line search switches to approximate Wolfe conditions.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in
      [Hager and Zhang (2006)][2]. If the secant**2 step does not shrink the
      bracketing interval by this proportion, a bisection step is performed
      to reduce the interval width.
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not
      bracket a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'hager_zhang' is used.

  Returns:
    results: A namedtuple containing the following attributes.
      converged: Boolean scalar `Tensor`. Whether a point satisfying
        Wolfe/Approx wolfe was found.
      failed: Boolean scalar `Tensor`. Whether the search stopped because a
        failure was flagged: the evaluations at `0.` or at
        `initial_step_size` were not finite, or the bracketing step or a
        later interval update reported a failure.
      func_evals: Scalar int32 `Tensor`. Number of function evaluations made.
      left_pt: Scalar `Tensor` of same dtype as `initial_step_size`. The
        left end point of the final bracketing interval. If converged is
        True, it is equal to `right_pt`. Otherwise, it corresponds to the
        last interval computed.
      objective_at_left_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`. The function value at the left
        end point. If converged is True, it is equal to
        `objective_at_right_pt`. Otherwise, it corresponds to the last
        interval computed.
      grad_objective_at_left_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`. The derivative of the function
        at the left end point. If converged is True, it is equal to
        `grad_objective_at_right_pt`. Otherwise it corresponds to the last
        interval computed.
      right_pt: Scalar `Tensor` of same dtype as `initial_step_size`. The
        right end point of the final bracketing interval. If converged is
        True, it is the accepted step, i.e. equal to `left_pt`. Otherwise, it
        corresponds to the last interval computed.
      objective_at_right_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`. The function value at the right
        end point. If converged is True, it is the function value at the
        accepted step. Otherwise, it corresponds to the last interval
        computed.
      grad_objective_at_right_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`. The derivative of the function
        at the right end point. If converged is True, it is the derivative at
        the accepted step. Otherwise it corresponds to the last interval
        computed.
  """
  with tf.name_scope(name, 'hager_zhang',
                     [initial_step_size,
                      objective_at_zero,
                      grad_objective_at_zero,
                      objective_at_initial_step_size,
                      grad_objective_at_initial_step_size,
                      threshold_use_approximate_wolfe_condition,
                      shrinkage_param,
                      expansion_param,
                      sufficient_decrease_param,
                      curvature_param]):
    val_0, val_c, f_lim, prepare_evals = _prepare_args(
        value_and_gradients_function,
        initial_step_size,
        objective_at_initial_step_size,
        grad_objective_at_initial_step_size,
        objective_at_zero,
        grad_objective_at_zero,
        threshold_use_approximate_wolfe_condition)

    # Checks if the evaluation of the function at the supplied points failed.
    # If it did, then we quit.
    eval_failed = ~_is_finite(val_0, val_c)

    # Check if the initial step size already satisfies the Wolfe conditions.
    # If it does, there is no further work.
    already_converged = _satisfies_wolfe(val_0, val_c, f_lim,
                                         sufficient_decrease_param,
                                         curvature_param)

    def _cond(converged, failed, *ignored_args):  # pylint:disable=unused-argument
      """Loops until convergence is reached or a failure is flagged."""
      return tf.logical_not(converged | failed)

    def _update_with_mid(current_evals, left, right):
      """Corresponds to step L3 in [Hager and Zhang (2006)][2]."""
      mid_pt = (left.x + right.x) / 2
      f_mid, df_mid = value_and_gradients_function(mid_pt)
      mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
      mid_failed = ~_is_finite(mid)
      updated = _update(value_and_gradients_function, left, right, mid, f_lim)
      # A non-finite midpoint is reported as a failure with both end points
      # set to the midpoint; otherwise the result of `_update` is used.
      failed, update_evals, next_left, next_right = smart_cond.smart_cond(
          mid_failed,
          true_fn=lambda: (True, 0, mid, mid),
          false_fn=lambda: updated)
      # The extra evaluation is the one made at the midpoint above.
      return (False, failed, current_evals + update_evals + 1,
              next_left, next_right)

    def _body(converged, failed, evals, left, right):  # pylint:disable=unused-argument
      """Line search loop body."""
      converged, failed, secant2_evals, next_left, next_right = _secant2(
          value_and_gradients_function, val_0, left, right, f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)
      evals += secant2_evals
      return smart_cond.smart_case(
          # If converged or failed, then do no further processing.
          [(converged | failed,
            lambda: (converged, failed, evals, next_left, next_right)),
           # The interval did not shrink by the required proportion.
           (next_right.x - next_left.x > shrinkage_param * (right.x - left.x),
            lambda: _update_with_mid(evals, next_left, next_right))],
          default=lambda: (False, False, evals, next_left, next_right))

    def do_line_search():
      """Brackets the minimum and then iterates until done or failed."""
      bracketed, failed, bracket_evals, left, right = _bracket(
          value_and_gradients_function, val_0, val_c, f_lim,
          expansion_param=expansion_param)
      # A failure is only kept if the bracketing itself did not succeed.
      failed = failed & tf.logical_not(bracketed)
      return tf.while_loop(
          _cond,
          _body,
          (False, failed, bracket_evals + prepare_evals, left, right),
          parallel_iterations=1)

    converged, failed, func_evals, left, right = smart_cond.smart_case(
        [(already_converged,
          lambda: (already_converged, False, prepare_evals, val_c, val_c)),
         (eval_failed,
          lambda: (False, True, prepare_evals, val_0, val_c))],
        default=do_line_search)

    return HagerZhangLineSearchResult(
        converged=tf.convert_to_tensor(converged, name='converged'),
        failed=tf.convert_to_tensor(failed, name='failed'),
        func_evals=func_evals,
        left_pt=left.x,
        objective_at_left_pt=left.f,
        grad_objective_at_left_pt=left.df,
        right_pt=right.x,
        objective_at_right_pt=right.f,
        grad_objective_at_right_pt=right.df)
def _update(value_and_gradients_function, val_left, val_right, val_trial,
            f_lim):
  """Shrinks a bracketing interval around a minimum using a trial point.

  This is the interval update described as steps U0-U3 on page 123 of
  [Hager and Zhang (2006)][2]. Write the left end point as 'a', the right end
  point as 'b', the objective as 'f' and its derivative as 'df'. The interval
  is expected to satisfy the opposite slope conditions:

      f(a) <= f(0) + epsilon    (1)
      df(a) < 0                 (2)
      df(b) > 0                 (3)

  where epsilon is a small positive constant. Conditions (2) and (3) together
  imply that the derivative vanishes somewhere between 'a' and 'b'. These
  conditions are assumed; they are not checked here.

  The branches are tried in the following order:

    1. The trial point lies outside [a, b]: the interval is returned as is.
    2. The slope at the trial point is non-negative: the trial point becomes
       the new right end point.
    3. The slope at the trial point is negative and the function value there
       does not exceed `f_lim`: the trial point becomes the new left end
       point.
    4. Otherwise the result of `_bisection_update` applied to
       [val_left, val_trial] is returned.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point. Only forwarded to `_bisection_update`.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval ('a' above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval ('b' above).
    val_trial: Instance of _FnDFn. The value and derivative of the function
      evaluated at the trial point used to shrink the interval.
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.

  Returns:
    evals: A scalar int32 `Tensor`. The number of function evaluations made.
      This is a constant zero for branches 1-3 above; for branch 4 it is
      whatever `_bisection_update` reports.
    val_left_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated right end point of the interval.
  """
  trial_x = val_trial.x
  trial_slope = val_trial.df

  def _keep_interval():
    # Trial point is unusable; no evaluations were spent.
    return (tf.constant(0), val_left, val_right)

  def _trial_becomes_right():
    return (tf.constant(0), val_left, val_trial)

  def _trial_becomes_left():
    return (tf.constant(0), val_trial, val_right)

  def _bisect():
    # The trial point has negative slope but too large a value, so a minimum
    # must lie in [val_left, val_trial]; delegate to bisection to find it.
    return _bisection_update(value_and_gradients_function, val_left,
                             val_trial, f_lim)

  trial_is_outside = (trial_x < val_left.x) | (trial_x > val_right.x)
  slope_is_nonnegative = trial_slope >= 0
  valid_left_end_point = (trial_slope < 0) & (val_trial.f <= f_lim)

  branches = [
      (trial_is_outside, _keep_interval),
      (slope_is_nonnegative, _trial_becomes_right),
      (valid_left_end_point, _trial_becomes_left),
  ]
  return smart_cond.smart_case(branches, default=_bisect, exclusive=False)
def _secant2(value_and_gradients_function,
             val_0,
             val_left,
             val_right,
             f_lim,
             sufficient_decrease_param=0.1,
             curvature_param=0.9,
             name=None):
  """Runs the 'secant square' interval update of Hager and Zhang.

  Starting from an interval [a, b] that brackets a root of the derivative,
  both end points are updated using up to two intermediate points produced by
  secant interpolation. See steps S1-S4 in [Hager and Zhang (2006)][2]. The
  interval must satisfy the opposite slope conditions described in the
  documentation of `_update`.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point.
    val_0: Instance of _FnDFn. The function and derivative value at 0.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval ('a').
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval ('b').
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to 'delta' in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'secant2' is used.

  Returns:
    converged: A scalar boolean `Tensor`. Indicates whether a point
      satisfying the Wolfe conditions has been found. If this is True, the
      interval will be degenerate (i.e. val_left_bar and val_right_bar below
      will be identical).
    failed: A scalar boolean `Tensor`. True when the function evaluated at
      one of the secant points is not finite, or when a call to `_update`
      reports failure.
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made: one per secant point evaluated here plus those reported by
      `_update`.
    val_left_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated right end point of the interval.
  """
  scope_values = [
      val_0, val_left, val_right, f_lim, sufficient_decrease_param,
      curvature_param
  ]
  with tf.name_scope(name, 'secant2', scope_values):

    def _evaluate(point):
      """Evaluates the objective and its derivative at `point`."""
      value, slope = value_and_gradients_function(point)
      return _FnDFn(x=point, f=value, df=slope)

    def _meets_wolfe(candidate):
      """Checks the Wolfe conditions at `candidate`."""
      return _satisfies_wolfe(
          val_0,
          candidate,
          f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)

    # First secant point; noted in the original to always lie within [a, b].
    secant_pt = _secant(val_left.x, val_right.x, val_left.df, val_right.df)
    val_secant = _evaluate(secant_pt)
    secant_not_finite = ~_is_finite(val_secant)
    secant_meets_wolfe = _meets_wolfe(val_secant)

    # [val_temp_left, val_temp_right] is the interval referred to as [A, B] in
    # Hager Zhang (2006).
    (update_failed, update_evals, val_temp_left, val_temp_right) = _update(
        value_and_gradients_function, val_left, val_right, val_secant, f_lim)

    def _on_secant_failure():
      return (False, True, tf.constant(1), val_secant, val_secant)

    def _on_secant_convergence():
      # A point satisfying the Wolfe conditions was found; return the
      # degenerate interval at that point.
      return (True, False, tf.constant(1), val_secant, val_secant)

    def _on_update_failure():
      return (False, True, update_evals + 1, val_temp_left, val_temp_right)

    def _no_second_step():
      return (False, False, update_evals + 1, val_temp_left, val_temp_right)

    def _second_secant_step(curr_left, curr_right):
      """Performs a second secant division to update the interval."""
      # curr_left and curr_right may not satisfy the opposite slope
      # conditions, so the second secant point can fall outside
      # [val_temp_left, val_temp_right].
      second_pt = _secant(curr_left.x, curr_right.x, curr_left.df,
                          curr_right.df)
      val_second = _evaluate(second_pt)
      second_not_finite = ~_is_finite(val_second)
      second_is_outside = ((second_pt < val_temp_left.x) |
                           (val_temp_right.x < second_pt))
      second_meets_wolfe = _meets_wolfe(val_second)

      def _on_second_failure():
        return (False, True, update_evals + 2, val_second, val_second)

      def _keep_temp_interval():
        return (False, False, update_evals + 2, val_temp_left,
                val_temp_right)

      def _on_second_convergence():
        return (True, False, update_evals + 2, val_second, val_second)

      def _update_with_second_point():
        # One more update of [val_temp_left, val_temp_right], this time with
        # the second secant point as the trial point.
        inner_failed, inner_evals, val_left_bar, val_right_bar = _update(
            value_and_gradients_function, val_temp_left, val_temp_right,
            val_second, f_lim)
        return (False, inner_failed, update_evals + inner_evals + 2,
                val_left_bar, val_right_bar)

      return smart_cond.smart_case(
          [
              (second_not_finite, _on_second_failure),
              (second_is_outside, _keep_temp_interval),
              (second_meets_wolfe, _on_second_convergence),
          ],
          default=_update_with_second_point,
          exclusive=False)

    # True when the first secant point became the right end point.
    secant_is_new_right = tf.equal(secant_pt, val_temp_right.x)
    # True when the first secant point became the left end point.
    secant_is_new_left = tf.equal(secant_pt, val_temp_left.x)

    branches = [
        (secant_not_finite, _on_secant_failure),
        (secant_meets_wolfe, _on_secant_convergence),
        (update_failed, _on_update_failure),
        (secant_is_new_right,
         lambda: _second_secant_step(val_right, val_temp_right)),
        (secant_is_new_left,
         lambda: _second_secant_step(val_left, val_temp_left)),
    ]
    return smart_cond.smart_case(
        branches, default=_no_second_step, exclusive=False)
def _update(value_and_gradients_function, val_left, val_right, val_trial,
            f_lim):
  """Narrows a bracketing interval around a minimum using a trial point.

  Corresponds to the procedure U0-U3 on page 123 of
  [Hager and Zhang (2006)][2]. Denote the left end point by 'a', the right end
  point by 'b', the function being minimized by 'f' and its derivative by
  'df'. The supplied interval is assumed (not verified) to satisfy the
  opposite slope conditions:

      f(a) <= f(0) + epsilon    (1)
      df(a) < 0                 (2)
      df(b) > 0                 (3)

  Here epsilon is a small positive constant. Conditions (2) and (3) together
  guarantee at least one zero of the derivative between 'a' and 'b'.

  The trial point is handled by the first matching rule below:

    * If it lies outside [a, b], the current interval is returned unchanged.
    * If its slope is non-negative, it replaces the right end point.
    * If its slope is negative and its function value is at most `f_lim`, it
      replaces the left end point.
    * Otherwise, the result of `_bisection_update` on [val_left, val_trial]
      is returned.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point. Only forwarded to `_bisection_update`.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval ('a' above).
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval ('b' above).
    val_trial: Instance of _FnDFn. The value and derivative of the function
      evaluated at the trial point used to shrink the interval.
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.

  Returns:
    update_failed: A boolean scalar `Tensor` indicating whether the objective
      function failed to yield a finite value at the trial points. The first
      three rules above always report False here; any other value comes from
      `_bisection_update`.
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made (a constant zero for the first three rules above).
    val_left_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated right end point of the interval.
  """
  x_trial = val_trial.x
  slope_trial = val_trial.df

  def _unchanged():
    return (False, tf.constant(0), val_left, val_right)

  def _replace_right():
    return (False, tf.constant(0), val_left, val_trial)

  def _replace_left():
    return (False, tf.constant(0), val_trial, val_right)

  def _fall_back_to_bisection():
    # Negative slope but a function value above f_lim: search within
    # [val_left, val_trial] using bisection.
    return _bisection_update(value_and_gradients_function, val_left,
                             val_trial, f_lim)

  # The trial point is not inside the interval, so there is nothing to do.
  outside_interval = (x_trial < val_left.x) | (x_trial > val_right.x)
  # The trial point is a valid right end point.
  usable_as_right = slope_trial >= 0
  # The trial point is a valid left end point: negative slope and a function
  # value that is not too large.
  usable_as_left = (slope_trial < 0) & (val_trial.f <= f_lim)

  return smart_cond.smart_case(
      [
          (outside_interval, _unchanged),
          (usable_as_right, _replace_right),
          (usable_as_left, _replace_left),
      ],
      default=_fall_back_to_bisection,
      exclusive=False)
def _secant2(value_and_gradients_function,
             val_0,
             val_left,
             val_right,
             f_lim,
             sufficient_decrease_param=0.1,
             curvature_param=0.9,
             name=None):
  """Applies the secant square procedure of Hager Zhang to an interval.

  Given an interval [a, b] bracketing a root of the derivative, updates both
  end points using up to two intermediate points obtained by secant
  interpolation. See steps S1-S4 in [Hager and Zhang (2006)][2]. The interval
  must satisfy the opposite slope conditions described in the documentation
  of `_update`.

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple containing the value of the function and its
      derivative at that point.
    val_0: Instance of _FnDFn. The function and derivative value at 0.
    val_left: Instance of _FnDFn. The value and derivative of the function
      evaluated at the left end point of the bracketing interval ('a').
    val_right: Instance of _FnDFn. The value and derivative of the function
      evaluated at the right end point of the bracketing interval ('b').
    f_lim: Scalar `Tensor` of real dtype. The function value threshold for
      the approximate Wolfe conditions to be checked.
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to 'delta' in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'secant2' is used.

  Returns:
    converged: A scalar boolean `Tensor`. Indicates whether a point
      satisfying the Wolfe conditions has been found. If this is True, the
      interval will be degenerate (i.e. val_left_bar and val_right_bar below
      will be identical).
    evals: A scalar int32 `Tensor`. The total number of function evaluations
      made: one per secant point evaluated here plus those reported by
      `_update`.
    val_left_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated left end point of the interval.
    val_right_bar: Instance of _FnDFn. The position and the associated value
      and derivative at the updated right end point of the interval.
  """
  scope_values = [
      val_0, val_left, val_right, f_lim, sufficient_decrease_param,
      curvature_param
  ]
  with tf.name_scope(name, 'secant2', scope_values):

    def _evaluate(point):
      """Evaluates the objective and its derivative at `point`."""
      value, slope = value_and_gradients_function(point)
      return _FnDFn(x=point, f=value, df=slope)

    def _meets_wolfe(candidate):
      """Checks the Wolfe conditions at `candidate`."""
      return _satisfies_wolfe(
          val_0,
          candidate,
          f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)

    # First secant point; noted in the original to always lie within [a, b].
    secant_pt = _secant(val_left.x, val_right.x, val_left.df, val_right.df)
    val_secant = _evaluate(secant_pt)
    secant_meets_wolfe = _meets_wolfe(val_secant)

    # [val_temp_left, val_temp_right] is the interval referred to as [A, B] in
    # Hager Zhang (2006).
    update_evals, val_temp_left, val_temp_right = _update(
        value_and_gradients_function, val_left, val_right, val_secant, f_lim)

    def _on_secant_convergence():
      # A point satisfying the Wolfe conditions was found; return the
      # degenerate interval at that point.
      return (True, tf.constant(1), val_secant, val_secant)

    def _no_second_step():
      return (False, update_evals + 1, val_temp_left, val_temp_right)

    def _second_secant_step(curr_left, curr_right):
      """Performs a second secant division to update the interval."""
      # curr_left and curr_right may not satisfy the opposite slope
      # conditions, so the second secant point can fall outside
      # [val_temp_left, val_temp_right].
      second_pt = _secant(curr_left.x, curr_right.x, curr_left.df,
                          curr_right.df)
      val_second = _evaluate(second_pt)
      second_is_outside = ((second_pt < val_temp_left.x) |
                           (val_temp_right.x < second_pt))
      second_meets_wolfe = _meets_wolfe(val_second)

      def _keep_temp_interval():
        return (False, update_evals + 2, val_temp_left, val_temp_right)

      def _on_second_convergence():
        return (True, update_evals + 2, val_second, val_second)

      def _update_with_second_point():
        # One more update of [val_temp_left, val_temp_right], this time with
        # the second secant point as the trial point.
        inner_evals, val_left_bar, val_right_bar = _update(
            value_and_gradients_function, val_temp_left, val_temp_right,
            val_second, f_lim)
        return (False, update_evals + inner_evals + 2, val_left_bar,
                val_right_bar)

      return smart_cond.smart_case(
          [
              (second_is_outside, _keep_temp_interval),
              (second_meets_wolfe, _on_second_convergence),
          ],
          default=_update_with_second_point,
          exclusive=False)

    # True when the first secant point became the right end point.
    secant_is_new_right = tf.equal(secant_pt, val_temp_right.x)
    # True when the first secant point became the left end point.
    secant_is_new_left = tf.equal(secant_pt, val_temp_left.x)

    branches = [
        (secant_meets_wolfe, _on_secant_convergence),
        (secant_is_new_right,
         lambda: _second_secant_step(val_right, val_temp_right)),
        (secant_is_new_left,
         lambda: _second_secant_step(val_left, val_temp_left)),
    ]
    return smart_cond.smart_case(
        branches, default=_no_second_step, exclusive=False)
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
  """Performs a single iteration of the Nelder Mead algorithm.

  The vertices are ranked by objective value, the worst vertex is reflected
  through the centroid of the remaining vertices, and the objective is
  evaluated at the reflected point. One of the following is then selected,
  tried in this order:

    0. Convergence (as reported by `_check_convergence`): the simplex is
       returned unchanged.
    1. The reflected value is below the second worst value but not below the
       best value: the branch built by `_accept_reflected_fn` is used.
    2. The reflected value is below the best value: the branch built by
       `_expansion_fn` is used.
    3. The reflected value is below the worst value but not below the second
       worst value: the branch built by `_outside_contraction_fn` is used.
    4. Otherwise: the branch built by `_inside_contraction_fn` is used.

  Args:
    current_simplex: `Tensor` holding the simplex vertices, indexed by vertex
      along the first axis.
    current_objective_values: `Tensor` of objective values, indexed by vertex
      in the same order as `current_simplex`.
    objective_function: Python callable evaluated at the reflected point and
      forwarded to the expansion and contraction helpers. Although it
      defaults to `None`, it is called unconditionally and must be supplied.
    dim: Divisor applied to the sum of all vertices except the worst one to
      obtain the face centroid (cast to the simplex dtype). Presumably the
      dimension of the search space. Must be supplied.
    func_tolerance: Forwarded to `_check_convergence`.
    position_tolerance: Forwarded to `_check_convergence`.
    batch_evaluate_objective: Forwarded to the contraction helpers.
    reflection: Multiplier applied to (centroid - worst vertex) when forming
      the reflected point. Must be supplied.
    expansion: Forwarded to `_expansion_fn`.
    contraction: Forwarded to the contraction helpers.
    shrinkage: Forwarded to the contraction helpers.
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'nelder_mead_one_step' is
      used.

  Returns:
    converged: Whether the convergence branch was taken; other branches
      return whatever their helper reports in this position.
    next_simplex: The updated simplex, with its static shape set to that of
      `current_simplex`.
    next_objective_at_simplex: The objective values at the updated simplex,
      with static shape set to that of `current_objective_values`.
    num_evaluations: One (for the reflected point) plus the number of
      evaluations reported by the branch that was taken.
  """
  with tf.name_scope(name, 'nelder_mead_one_step'):
    simplex_dtype = current_simplex.dtype.base_dtype

    # Rank the vertices from best (smallest objective) to worst.
    ranking = tf.contrib.framework.argsort(
        current_objective_values, direction='ASCENDING', stable=True)
    best_index = ranking[0]
    worst_index = ranking[-1]
    second_worst_index = ranking[-2]

    worst_vertex = current_simplex[worst_index]
    best_objective_value = current_objective_values[best_index]
    worst_objective_value = current_objective_values[worst_index]
    second_worst_objective_value = current_objective_values[
        second_worst_index]

    # Centroid of the face opposite the worst vertex.
    sum_without_worst = tf.reduce_sum(current_simplex, axis=0) - worst_vertex
    face_centroid = sum_without_worst / tf.cast(dim, simplex_dtype)

    # Reflect the worst vertex through the opposite face. Evaluating the
    # objective there is the single evaluation made outside the branches.
    reflected = face_centroid + reflection * (face_centroid - worst_vertex)
    objective_at_reflected = objective_function(reflected)

    has_converged = _check_convergence(
        current_simplex, current_simplex[best_index], best_objective_value,
        worst_objective_value, func_tolerance, position_tolerance)

    def _converged_fn():
      return (True, current_simplex, current_objective_values, 0)

    reflected_below_second_worst = (
        objective_at_reflected < second_worst_objective_value)
    reflected_not_below_best = (
        objective_at_reflected >= best_objective_value)
    accept_reflected = reflected_below_second_worst & reflected_not_below_best
    accept_reflected_fn = _accept_reflected_fn(
        current_simplex, current_objective_values, worst_index, reflected,
        objective_at_reflected)

    do_expansion = objective_at_reflected < best_objective_value
    expansion_fn = _expansion_fn(
        objective_function, current_simplex, current_objective_values,
        worst_index, reflected, objective_at_reflected, face_centroid,
        expansion)

    reflected_below_worst = objective_at_reflected < worst_objective_value
    reflected_not_below_second_worst = (
        objective_at_reflected >= second_worst_objective_value)
    do_outside_contraction = (
        reflected_below_worst & reflected_not_below_second_worst)
    outside_contraction_fn = _outside_contraction_fn(
        objective_function, current_simplex, current_objective_values,
        face_centroid, best_index, worst_index, reflected,
        objective_at_reflected, contraction, shrinkage,
        batch_evaluate_objective)

    inside_contraction_fn = _inside_contraction_fn(
        objective_function, current_simplex, current_objective_values,
        face_centroid, best_index, worst_index, worst_objective_value,
        contraction, shrinkage, batch_evaluate_objective)

    branches = [
        (has_converged, _converged_fn),
        (accept_reflected, accept_reflected_fn),
        (do_expansion, expansion_fn),
        (do_outside_contraction, outside_contraction_fn),
    ]
    step_result = smart_cond.smart_case(
        branches, default=inside_contraction_fn, exclusive=False)
    converged, next_simplex, next_objective_at_simplex, case_evals = (
        step_result)

    # Restore the static shapes of the updated tensors from the inputs.
    next_simplex.set_shape(current_simplex.shape)
    next_objective_at_simplex.set_shape(current_objective_values.shape)

    return (converged, next_simplex, next_objective_at_simplex,
            1 + case_evals)