def test_bracket_batching(self):
  """Checks that bracketing handles a batch of line functions at once."""
  threshold = 1e-6
  # One batch member per scenario the bracketing loop must handle:
  # - a) The initial interval already brackets a minimum.
  # - b) A bracket is found after a single interval expansion.
  # - c) Bisection is needed right away.
  # - d) One round of expansion followed by bisection is needed.
  x = tf.constant([0.0, 1.0, 5.0])
  y = tf.constant([[1.0, 1.2, 1.1],
                   [1.0, 0.9, 1.2],
                   [1.0, 1.1, 1.2],
                   [1.0, 0.9, 1.1]])
  dy = tf.constant([[-0.8, 0.6, -0.8],
                    [-0.8, -0.7, 0.6],
                    [-0.8, -0.7, -0.8],
                    [-0.8, -0.7, -0.8]])
  fun = test_function_x_y_dy(x, y, dy, eps=0.1)

  val_at_zero = hzl._apply(fun, tf.zeros(4))  # Values at zero.
  val_at_step = hzl._apply(fun, tf.ones(4))  # Values at the initial step.
  f_lim = val_at_zero.f + (threshold * tf.abs(val_at_zero.f))

  want_left = np.array([0.0, 1.0, 0.0, 0.0])
  want_right = np.array([1.0, 5.0, 0.5, 2.5])

  result = self.evaluate(
      hzl.bracket(fun, val_at_zero, val_at_step, f_lim, max_iterations=5))
  self.assertEqual(result.num_evals, 2)  # Once bracketing, once bisecting.
  self.assertTrue(np.all(result.stopped))
  self.assertTrue(np.all(~result.failed))
  self.assertTrue(np.all(result.left.df < 0))  # Slopes of opposite sign.
  self.assertTrue(np.all(result.right.df >= 0))
  self.assertArrayNear(result.left.x, want_left, 1e-5)
  self.assertArrayNear(result.right.x, want_right, 1e-5)
def test_bracket_simple(self):
  """Checks that bracketing works on a univariate scalar-valued function."""
  # Crafted so that bracketing needs exactly one interval expansion and then
  # some bisection; this matches case (d) of test_bracket_batching.
  threshold = 1e-6
  x = np.array([0.0, 1.0, 2.5, 5.0])
  y = np.array([1.0, 0.9, -2.0, 1.1])
  dy = np.array([-0.8, -0.7, 1.6, -0.8])
  fun = test_function_x_y_dy(x, y, dy)

  val_at_zero = hzl._apply(fun, 0.0)  # Value at zero.
  val_at_step = hzl._apply(fun, 1.0)  # Value at the initial step.
  f_lim = val_at_zero.f + (threshold * tf.abs(val_at_zero.f))

  result = self.evaluate(
      hzl.bracket(fun, val_at_zero, val_at_step, f_lim, max_iterations=5))
  self.assertEqual(result.iteration, 1)  # One expansion.
  self.assertEqual(result.num_evals, 2)  # Once bracketing, once bisecting.
  self.assertEqual(result.left.x, 0.0)
  self.assertEqual(result.right.x, 2.5)
  self.assertLess(result.left.df, 0)  # Slopes of opposite sign.
  self.assertGreaterEqual(result.right.df, 0)
def _bracket_and_search(value_and_gradients_function, init_interval, f_lim,
                        max_iterations, shrinkage_param, expansion_param,
                        sufficient_decrease_param, curvature_param):
  """Brackets the minimum, then runs the main line search on the bracket.

  Args:
    value_and_gradients_function: A Python callable accepting a real scalar
      tensor and returning a namedtuple with fields 'x', 'f' and 'df' — the
      evaluation point, the function value and the derivative there, each a
      scalar tensor of real dtype. Other fields, if present, should be
      tensors or (possibly nested) sequences of tensors. Usually obtained by
      projecting a multivariate objective along a descent direction, so the
      derivative of the projected function at 0 must be negative.
      Alternatively the callable may batch `n` such line functions: it then
      accepts a tensor of `n` points of shape [n], and 'x', 'f' and 'df' in
      the returned namedtuple are each of shape [n] with the corresponding
      inputs, values and derivatives.
    init_interval: Instance of `HagerZhangLineSearchResults` holding the
      initial line search interval. The gradient at init_interval.left must
      be negative (a descent direction); init_interval.right must be
      positive and finite.
    f_lim: Scalar `Tensor` of float dtype.
    max_iterations: Positive scalar `Tensor` of integral dtype. Maximum
      number of line search iterations; iterations spent bracketing are
      counted against this budget as well.
    shrinkage_param: Scalar positive `Tensor` of real dtype, less than `1.`.
      The parameter `gamma` of [Hager and Zhang (2006)][2].
    expansion_param: Scalar positive `Tensor` of real dtype, greater than
      `1.`. Grows the initial interval when it does not yet bracket a
      minimum; `rho` of [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype,
      bounded above by the curvature param; `delta` of
      [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype, bounded above
      by `1.`; 'sigma' of [Hager and Zhang (2006)][2].

  Returns:
    A namedtuple with the fields:
      converged: Boolean `Tensor` of shape [n]. Whether a point satisfying
        the Wolfe/approximate-Wolfe conditions was found.
      failed: Boolean `Tensor` of shape [n]. Whether the search failed, e.g.
        because the function or gradient was not finite at some evaluation.
      iterations: Scalar int32 `Tensor`. Line search iterations performed.
      func_evals: Scalar int32 `Tensor`. Function evaluations performed.
      left: Namedtuple (as returned by value_and_gradients_function) for the
        left end point of the final bracketing interval.
      right: Same, for the right end point.
  """
  bracketed = hzl.bracket(value_and_gradients_function, init_interval, f_lim,
                          max_iterations, expansion_param)
  # A bracket that has collapsed to a single point means convergence.
  collapsed = _very_close(bracketed.left.x, bracketed.right.x)
  # Not converged and no iterations left to continue with: mark as failed.
  out_of_iterations = ~collapsed & tf.greater_equal(bracketed.iteration,
                                                    max_iterations)
  search_state = HagerZhangLineSearchResult(
      converged=collapsed,
      failed=bracketed.failed | out_of_iterations,
      iterations=bracketed.iteration,
      func_evals=bracketed.num_evals,
      left=bracketed.left,
      right=bracketed.right)
  return _line_search_after_bracketing(
      value_and_gradients_function, search_state, init_interval.left, f_lim,
      max_iterations, sufficient_decrease_param, curvature_param,
      shrinkage_param)
def _bracket_and_search(value_and_gradients_function, val_0, val_c, f_lim,
                        max_iterations, shrinkage_param=None,
                        expansion_param=None, sufficient_decrease_param=None,
                        curvature_param=None):
  """Brackets the minimum, then runs the main line search on the bracket.

  Args:
    value_and_gradients_function: A Python callable accepting a real scalar
      tensor and returning an object convertible to a namedtuple with fields
      'f' and 'df' — scalar tensors of real dtype holding the function value
      and its derivative at the evaluated point. Other fields, if present,
      should be tensors or (possibly nested) sequences of tensors. Usually
      obtained by projecting a multivariate objective along a descent
      direction, so the derivative of the projected function at 0 must be
      negative.
    val_0: Instance of `_FnDFn` with the value and gradient of the objective
      at 0. The gradient must be negative (a descent direction).
    val_c: Instance of `_FnDFn` with the initial step size and the value and
      gradient of the objective there. The step size must be positive and
      finite.
    f_lim: Scalar `Tensor` of float dtype.
    max_iterations: Positive scalar `Tensor` of integral dtype. Maximum
      number of line search iterations; iterations spent bracketing are
      counted against this budget as well.
    shrinkage_param: Scalar positive `Tensor` of real dtype, less than `1.`.
      The parameter `gamma` of [Hager and Zhang (2006)][2].
    expansion_param: Scalar positive `Tensor` of real dtype, greater than
      `1.`. Grows the initial interval when it does not yet bracket a
      minimum; `rho` of [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype,
      bounded above by the curvature param; `delta` of
      [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype, bounded above
      by `1.`; 'sigma' of [Hager and Zhang (2006)][2].

  Returns:
    A namedtuple with the fields:
      iteration: A scalar int32 `Tensor`. Iterations consumed.
      found_wolfe: A scalar boolean `Tensor`. Whether a point satisfying the
        Wolfe conditions was found; if True, the returned interval is
        degenerate (left and right identical).
      failed: A scalar boolean `Tensor`. Whether invalid function or
        gradient values (infinities or NaNs) were encountered.
      num_evals: A scalar int32 `Tensor`. Total function evaluations made.
      left: Instance of _FnDFn at the updated left end of the interval.
      right: Instance of _FnDFn at the updated right end of the interval.
  """
  bracketed = hzl.bracket(value_and_gradients_function, val_0, val_c, f_lim,
                          max_iterations, expansion_param=expansion_param)
  # Bracketing failure, or an exhausted iteration budget, is an error.
  failed = (bracketed.failed
            | tf.greater_equal(bracketed.iteration, max_iterations))

  def _on_failure():
    # Report the original interval unchanged, flagged as failed.
    return _LineSearchInnerResult(
        iteration=bracketed.iteration,
        found_wolfe=False,
        failed=True,
        num_evals=bracketed.num_evals,
        left=val_0,
        right=val_c)

  def _on_success():
    """Runs the line search on the bracketed interval."""
    search_result = _line_search_after_bracketing(
        value_and_gradients_function, val_0, bracketed.left, bracketed.right,
        f_lim, bracketed.iteration, max_iterations,
        sufficient_decrease_param=sufficient_decrease_param,
        curvature_param=curvature_param,
        shrinkage_param=shrinkage_param)
    return _LineSearchInnerResult(
        iteration=search_result.iteration,
        found_wolfe=search_result.found_wolfe,
        failed=search_result.failed,
        num_evals=bracketed.num_evals + search_result.num_evals,
        left=search_result.left,
        right=search_result.right)

  return prefer_static.cond(failed,
                            true_fn=_on_failure,
                            false_fn=_on_success)