def _body(state): """Main optimization loop.""" search_direction = _get_search_direction(state.inverse_hessian_estimate, state.objective_gradient) derivative_at_start_pt = tf.reduce_sum(state.objective_gradient * search_direction) # If the derivative at the start point is not negative, reset the # Hessian estimate and recompute the search direction. needs_reset = derivative_at_start_pt >= 0 def _reset_search_dirn(): search_direction = _get_search_direction(initial_inv_hessian, state.objective_gradient) return search_direction, initial_inv_hessian search_direction, inv_hessian_estimate = tf.contrib.framework.smart_cond( needs_reset, true_fn=_reset_search_dirn, false_fn=lambda: (search_direction, state.inverse_hessian_estimate)) line_search_value_grad_func = _restrict_along_direction( value_and_gradients_function, state.position, search_direction) derivative_at_start_pt = tf.reduce_sum(state.objective_gradient * search_direction) ls_result = linesearch.hager_zhang( line_search_value_grad_func, initial_step_size=tf.convert_to_tensor(1, dtype=dtype), objective_at_zero=state.objective_value, grad_objective_at_zero=derivative_at_start_pt) state_after_ls = _update_state( state, failed=~ls_result.converged, # Fail if line search failed. num_iterations=state.num_iterations + 1, num_objective_evaluations=( state.num_objective_evaluations + ls_result.func_evals), inverse_hessian_estimate=inv_hessian_estimate) def _do_bfgs_update(): state_updated = _update_position( value_and_gradients_function, state_after_ls, search_direction * ls_result.left_pt, tolerance, f_relative_tolerance, x_tolerance) # If not converged, update the Hessian. return tf.contrib.framework.smart_cond( state_updated.converged, lambda: state_updated, lambda: _update_inv_hessian(state_after_ls, state_updated)) next_state = tf.contrib.framework.smart_cond( state_after_ls.failed, true_fn=lambda: state_after_ls, false_fn=_do_bfgs_update) return [next_state]
def _body(state): """Main optimization loop.""" # We use notation of [HZ2006] for brevity. x_k = state.position d_k = state.direction f_k = state.objective_value g_k = state.objective_gradient a_km1 = state.prev_step # Means a_{k-1}. # Define scalar function, which is objective restricted to direction. def ls_func(alpha): pt = x_k + tf.expand_dims(alpha, axis=-1) * d_k objective_value, gradient = value_and_gradients_function(pt) return ValueAndGradient( x=alpha, f=objective_value, df=_dot(gradient, d_k), full_gradient=gradient) # Generate initial guess for line search. # [HZ2006] suggests to generate first initial guess separately, but # [JuliaLineSearches] generates it as if previous step length was 1, and # we do the same. phi_0 = f_k dphi_0 = _dot(g_k, d_k) ls_val_0 = ValueAndGradient( x=tf.zeros_like(phi_0), f=phi_0, df=dphi_0, full_gradient=g_k) step_guess_result = _init_step(ls_val_0, a_km1, ls_func, psi_1, psi_2, params.quad_step) init_step = step_guess_result.step # Check if initial step size already satisfies Wolfe condition, and in # that case don't perform line search. c = init_step.x phi_lim = phi_0 + eps * tf.abs(phi_0) phi_c = init_step.f dphi_c = init_step.df # Original Wolfe conditions, T1 in [HZ2006]. suff_decrease_1 = delta * dphi_0 >= (phi_c - phi_0) / c curvature = dphi_c >= sigma * dphi_0 wolfe1 = suff_decrease_1 & curvature # Approximate Wolfe conditions, T2 in [HZ2006]. suff_decrease_2 = (2 * delta - 1) * dphi_0 >= dphi_c curvature = dphi_c >= sigma * dphi_0 wolfe2 = suff_decrease_2 & curvature & (phi_c <= phi_lim) wolfe = wolfe1 | wolfe2 skip_line_search = (step_guess_result.may_terminate & wolfe) | state.failed | state.converged # Call Hager-Zhang line search (L0-L3 in [HZ2006]). # Parameter theta from [HZ2006] is not adjustable, it's always 0.5. ls_result = linesearch.hager_zhang( ls_func, value_at_zero=ls_val_0, converged=skip_line_search, initial_step_size=init_step.x, value_at_initial_step=init_step, shrinkage_param=params.shrinkage_param, expansion_param=params.expansion_param, sufficient_decrease_param=delta, curvature_param=sigma, threshold_use_approximate_wolfe_condition=eps) # Moving to the next point, using step length from line search. # If line search was skipped, take step length from initial guess. # To save objective evaluation, use objective value and gradient returned # by line search or initial guess. a_k = tf.where(skip_line_search, init_step.x, ls_result.left.x) x_kp1 = state.position + tf.expand_dims(a_k, -1) * d_k f_kp1 = tf.where(skip_line_search, init_step.f, ls_result.left.f) g_kp1 = tf.where(skip_line_search, init_step.full_gradient, ls_result.left.full_gradient) # Evaluate next direction. # Use formulas (2.7)-(2.11) from [HZ2013] with P_k=I. y_k = g_kp1 - g_k d_dot_y = _dot(d_k, y_k) b_k = (_dot(y_k, g_kp1) - _norm_sq(y_k) * _dot(g_kp1, d_k) / d_dot_y) / d_dot_y eta_k = eta * _dot(d_k, g_k) / _norm_sq(d_k) b_k = tf.maximum(b_k, eta_k) d_kp1 = -g_kp1 + tf.expand_dims(b_k, -1) * d_k # Check convergence criteria. grad_converged = _norm_inf(g_kp1) <= tolerance x_converged = (_norm_inf(x_kp1 - x_k) <= x_tolerance) f_converged = ( tf.math.abs(f_kp1 - f_k) <= f_relative_tolerance * tf.math.abs(f_k)) converged = grad_converged | x_converged | f_converged # Construct new state for next iteration. 
new_state = _OptimizerState( converged=converged, failed=state.failed, num_iterations=state.num_iterations + 1, num_objective_evaluations=state.num_objective_evaluations + step_guess_result.func_evals + ls_result.func_evals, position=tf.where(state.converged, x_k, x_kp1), objective_value=tf.where(state.converged, f_k, f_kp1), objective_gradient=tf.where(state.converged, g_k, g_kp1), direction=d_kp1, prev_step=a_k) return (new_state,)
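# The reductions `_dot`, `_norm_sq` and `_norm_inf` used above are assumed to
# contract over the innermost (parameter) axis, so inputs of shape [..., n]
# yield results of shape [...]. A minimal sketch under that assumption
# (function names as used above, bodies hypothetical):
import tensorflow as tf

def _dot(a, b):
  """Batched inner product over the last axis."""
  return tf.reduce_sum(a * b, axis=-1)

def _norm_sq(a):
  """Batched squared Euclidean norm over the last axis."""
  return _dot(a, a)

def _norm_inf(a):
  """Batched infinity norm (largest absolute entry) over the last axis."""
  return tf.reduce_max(tf.abs(a), axis=-1)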
def _body(
    _,
    failed,  # pylint: disable=unused-argument
    num_iterations,
    total_evals,
    position,
    objective_value,
    objective_gradient,
    inv_hessian_estimate):
  """Main optimization loop."""
  search_direction = _get_search_direction(inv_hessian_estimate,
                                           objective_gradient)
  line_search_value_grad_func = _restrict_along_direction(
      value_and_gradients_function, position, search_direction)
  derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                         search_direction)
  ls_result = linesearch.hager_zhang(
      line_search_value_grad_func,
      initial_step_size=tf.constant(1, dtype=dtype),
      objective_at_zero=objective_value,
      grad_objective_at_zero=derivative_at_start_pt)

  # If the line search failed, then quit at this point.
  failed_retval = BfgsOptimizerResults(
      converged=False,
      failed=True,
      num_iterations=num_iterations + 1,
      num_objective_evaluations=total_evals + ls_result.func_evals,
      position=position,
      objective_value=objective_value,
      objective_gradient=objective_gradient,
      inverse_hessian_estimate=inv_hessian_estimate)

  # Fail if the objective value is not finite or the line search failed.
  ls_failed_case = (
      ~(tf.is_finite(objective_value) & ls_result.converged),
      lambda: failed_retval)

  # If the line search didn't fail, then either we need to continue
  # searching or we need to stop because we have converged.
  position_delta = search_direction * ls_result.left_pt
  next_position = position + position_delta
  next_objective, next_objective_gradient = value_and_gradients_function(
      next_position)
  grad_norm = tf.norm(next_objective_gradient, ord=2)
  has_converged = grad_norm <= tolerance
  grad_delta = next_objective_gradient - objective_gradient
  updated_inv_hessian = _bfgs_inv_hessian_update(
      grad_delta, position_delta, inv_hessian_estimate)
  updated_inv_hessian.set_shape(inv_hessian_estimate.shape)

  converged_retval = BfgsOptimizerResults(
      converged=tf.constant(True, name='converged'),
      failed=tf.constant(False, name='failed'),
      num_iterations=tf.convert_to_tensor(num_iterations + 1,
                                          name='num_iterations'),
      num_objective_evaluations=tf.convert_to_tensor(
          total_evals + ls_result.func_evals + 1,
          name='num_objective_evaluations'),
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)
  converged_case = (has_converged, lambda: converged_retval)

  default_retval = BfgsOptimizerResults(
      converged=tf.constant(False, name='converged'),
      failed=tf.constant(False, name='failed'),
      num_iterations=tf.convert_to_tensor(num_iterations + 1,
                                          name='num_iterations'),
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)
  default_fn = lambda: default_retval

  return smart_cond.smart_case([ls_failed_case, converged_case],
                               default=default_fn,
                               exclusive=False)
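# `_bfgs_inv_hessian_update` above is assumed to apply the standard BFGS
# inverse-Hessian update
#   H' = (I - rho * s y^T) H (I - rho * y s^T) + rho * s s^T,  rho = 1/(y^T s),
# where s is the position delta and y the gradient delta. A minimal dense,
# non-batched sketch of that formula (helper name hypothetical):
import tensorflow as tf

def _bfgs_inv_hessian_update_sketch(grad_delta, position_delta, inv_hessian):
  """Applies the BFGS update to a dense inverse Hessian estimate."""
  rho = 1. / tf.reduce_sum(grad_delta * position_delta)
  identity = tf.eye(tf.shape(inv_hessian)[0], dtype=inv_hessian.dtype)
  # Outer products s.y^T and y.s^T via broadcasting.
  left = identity - rho * (
      position_delta[:, tf.newaxis] * grad_delta[tf.newaxis, :])
  right = identity - rho * (
      grad_delta[:, tf.newaxis] * position_delta[tf.newaxis, :])
  return tf.matmul(left, tf.matmul(inv_hessian, right)) + rho * (
      position_delta[:, tf.newaxis] * position_delta[tf.newaxis, :])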
def line_search_step(state, value_and_gradients_function, search_direction,
                     grad_tolerance, f_relative_tolerance, x_tolerance):
  """Performs the line search step of the BFGS search procedure.

  Uses the hager_zhang line search procedure to compute a suitable step size
  to advance the current `state.position` along the given `search_direction`.
  Also, if the line search is successful, updates `state.position` by taking
  the corresponding step.

  Args:
    state: A namedtuple instance holding values for the current state of the
      search procedure. The state must include the fields: `position`,
      `objective_value`, `objective_gradient`, `num_iterations`,
      `num_objective_evaluations`, `converged` and `failed`.
    value_and_gradients_function: A Python callable that accepts a point as a
      real `Tensor` and returns a tuple of two tensors of the same dtype: the
      objective function value, a real scalar `Tensor`, and its derivative, a
      `Tensor` with the same shape as the input to the function.
    search_direction: A real `Tensor` of the same shape as `state.position`.
      The direction along which to perform the line search.
    grad_tolerance: Scalar `Tensor` of real dtype. Specifies the gradient
      tolerance for the procedure.
    f_relative_tolerance: Scalar `Tensor` of real dtype. Specifies the
      tolerance for the relative change in the objective value.
    x_tolerance: Scalar `Tensor` of real dtype. Specifies the tolerance for
      the change in the position.

  Returns:
    A copy of the input state with the following fields updated:
      converged: True if the convergence criteria have been met.
      failed: True if the line search procedure failed to converge, or if
        either the updated gradient or objective function are no longer
        finite.
      num_iterations: Increased by 1.
      num_objective_evaluations: Increased by the number of times the
        objective function was evaluated.
      position, objective_value, objective_gradient: If the line search
        succeeded, updated by computing the new position and evaluating the
        objective function at that position.
  """
  dtype = state.position.dtype.base_dtype
  line_search_value_grad_func = _restrict_along_direction(
      value_and_gradients_function, state.position, search_direction)
  derivative_at_start_pt = tf.reduce_sum(state.objective_gradient *
                                         search_direction)
  ls_result = linesearch.hager_zhang(
      line_search_value_grad_func,
      initial_step_size=tf.convert_to_tensor(1, dtype=dtype),
      objective_at_zero=state.objective_value,
      grad_objective_at_zero=derivative_at_start_pt)
  state_after_ls = update_fields(
      state,
      failed=~ls_result.converged,  # Fail if line search failed to converge.
      num_iterations=state.num_iterations + 1,
      num_objective_evaluations=(
          state.num_objective_evaluations + ls_result.func_evals))

  def _do_update_position():
    return _update_position(
        value_and_gradients_function,
        state_after_ls,
        search_direction * ls_result.left_pt,
        grad_tolerance, f_relative_tolerance, x_tolerance)

  return tf.contrib.framework.smart_cond(
      state_after_ls.failed,
      true_fn=lambda: state_after_ls,
      false_fn=_do_update_position)
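# `_restrict_along_direction` above is assumed to reduce the n-dimensional
# objective to a scalar function of the step size t along the search
# direction: phi(t) = f(x + t*d), with derivative phi'(t) = grad f(x + t*d).d,
# which is the form the hager_zhang routine consumes. A minimal non-batched
# sketch (helper name hypothetical):
import tensorflow as tf

def _restrict_along_direction_sketch(value_and_gradients_function,
                                     position, direction):
  """Restricts `value_and_gradients_function` to a line through `position`."""
  def restricted(t):
    pt = position + t * direction
    objective_value, gradient = value_and_gradients_function(pt)
    # Directional derivative: projection of the full gradient onto `direction`.
    return objective_value, tf.reduce_sum(gradient * direction)
  return restricted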
def line_search_step(state, value_and_gradients_function, search_direction,
                     grad_tolerance, f_relative_tolerance, x_tolerance,
                     stopping_condition, max_iterations):
  """Performs the line search step of the BFGS search procedure.

  Uses the hager_zhang line search procedure to compute a suitable step size
  to advance the current `state.position` along the given `search_direction`.
  Also, if the line search is successful, updates `state.position` by taking
  the corresponding step.

  Args:
    state: A namedtuple instance holding values for the current state of the
      search procedure. The state must include the fields: `position`,
      `objective_value`, `objective_gradient`, `num_iterations`,
      `num_objective_evaluations`, `converged` and `failed`.
    value_and_gradients_function: A Python callable that accepts a point as a
      real `Tensor` of shape `[..., n]` and returns a tuple of two tensors of
      the same dtype: the objective function value, a real `Tensor` of shape
      `[...]`, and its derivative, another real `Tensor` of shape `[..., n]`.
    search_direction: A real `Tensor` of shape `[..., n]`. The direction along
      which to perform the line search.
    grad_tolerance: Scalar `Tensor` of real dtype. Specifies the gradient
      tolerance for the procedure.
    f_relative_tolerance: Scalar `Tensor` of real dtype. Specifies the
      tolerance for the relative change in the objective value.
    x_tolerance: Scalar `Tensor` of real dtype. Specifies the tolerance for
      the change in the position.
    stopping_condition: A Python function that takes as input two Boolean
      tensors of shape `[...]` and returns a Boolean scalar tensor. The input
      tensors are `converged` and `failed`, indicating the current status of
      each respective batch member; the return value states whether the
      algorithm should stop.
    max_iterations: A Python integer that is used as the maximum number of
      iterations of the hager_zhang line search algorithm.

  Returns:
    A copy of the input state with the following fields updated:
      converged: A Boolean `Tensor` of shape `[...]` indicating whether the
        convergence criteria have been met.
      failed: A Boolean `Tensor` of shape `[...]` indicating whether the line
        search procedure failed to converge, or if either the updated
        gradient or objective function are no longer finite.
      num_iterations: Increased by 1.
      num_objective_evaluations: Increased by the number of times the
        objective function was evaluated.
      position, objective_value, objective_gradient: If the line search
        succeeded, updated by computing the new position and evaluating the
        objective function at that position.
  """
  line_search_value_grad_func = _restrict_along_direction(
      value_and_gradients_function, state.position, search_direction)
  derivative_at_start_pt = tf.reduce_sum(
      state.objective_gradient * search_direction, axis=-1)
  val_0 = ValueAndGradient(
      x=_broadcast(0, state.position),
      f=state.objective_value,
      df=derivative_at_start_pt,
      full_gradient=state.objective_gradient)
  inactive = state.failed | state.converged
  ls_result = linesearch.hager_zhang(
      line_search_value_grad_func,
      initial_step_size=_broadcast(1, state.position),
      value_at_zero=val_0,
      converged=inactive,  # No search needed for these.
      max_iterations=max_iterations)
  state_after_ls = update_fields(
      state,
      failed=state.failed | ~ls_result.converged,
      num_iterations=state.num_iterations + 1,
      num_objective_evaluations=(
          state.num_objective_evaluations + ls_result.func_evals))

  def _do_update_position():
    # For inactive batch members `left.x` is zero. However, their
    # `search_direction` might also be undefined, so we can't rely on
    # multiplication by zero to produce a `position_delta` of zero.
    position_delta = tf.where(
        inactive[..., tf.newaxis],
        dtype_util.as_numpy_dtype(search_direction.dtype)(0),
        search_direction * ls_result.left.x[..., tf.newaxis])
    return _update_position(
        state_after_ls,
        position_delta,
        ls_result.left.f,
        ls_result.left.full_gradient,
        grad_tolerance, f_relative_tolerance, x_tolerance)

  return prefer_static.cond(
      stopping_condition(state.converged, state.failed),
      true_fn=lambda: state_after_ls,
      false_fn=_do_update_position)
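# A typical `stopping_condition` only halts the whole batch once every member
# has either converged or failed; TFP ships reductions of this kind (e.g.
# `tfp.optimizer.converged_all`). A minimal sketch of such a condition:
import tensorflow as tf

def _converged_all_sketch(converged, failed):
  """Stops when every batch member has converged or failed."""
  return tf.reduce_all(converged | failed)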
def _body(
    converged,  # pylint: disable=unused-argument
    stopped,  # pylint: disable=unused-argument
    iteration,
    total_evals,
    position,
    objective_value,
    objective_gradient,
    input_inv_hessian_estimate):
  """Main optimization loop."""
  search_direction = _get_search_direction(input_inv_hessian_estimate,
                                           objective_gradient)
  derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                         search_direction)

  # If the derivative at the start point is not negative, reset the
  # Hessian estimate and recompute the search direction.
  needs_reset = derivative_at_start_pt >= 0

  def _reset_search_dirn():
    search_direction = _get_search_direction(initial_inv_hessian,
                                             objective_gradient)
    return search_direction, initial_inv_hessian

  search_direction, inv_hessian_estimate = smart_cond.smart_cond(
      needs_reset,
      true_fn=_reset_search_dirn,
      false_fn=lambda: (search_direction, input_inv_hessian_estimate))

  line_search_value_grad_func = _restrict_along_direction(
      value_and_gradients_function, position, search_direction)
  derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                         search_direction)
  ls_result = linesearch.hager_zhang(
      line_search_value_grad_func,
      initial_step_size=tf.convert_to_tensor(1, dtype=dtype),
      objective_at_zero=objective_value,
      grad_objective_at_zero=derivative_at_start_pt)

  # Fail if the objective value is not finite or the line search failed.
  ls_failed = ~ls_result.converged

  # If the line search failed, then quit at this point.
  def _failed_fn():
    """Line search failed action."""
    failed_retval = BfgsOptimizerResults(
        converged=False,
        failed=True,
        num_iterations=iteration + 1,
        num_objective_evaluations=total_evals + ls_result.func_evals,
        position=position,
        objective_value=objective_value,
        objective_gradient=objective_gradient,
        inverse_hessian_estimate=inv_hessian_estimate)
    return failed_retval

  def _success_fn():
    return _bfgs_update(value_and_gradients_function,
                        position,
                        objective_value,
                        objective_gradient,
                        search_direction,
                        inv_hessian_estimate,
                        ls_result.left_pt,
                        iteration,
                        total_evals + ls_result.func_evals,
                        tolerance,
                        f_relative_tolerance,
                        x_tolerance)

  return smart_cond.smart_cond(ls_failed,
                               true_fn=_failed_fn,
                               false_fn=_success_fn)
def _body(_,
          failed,  # pylint: disable=unused-argument
          total_evals,
          position,
          objective_value,
          objective_gradient,
          inv_hessian_estimate):
  """Main optimization loop."""
  search_direction = _get_search_direction(inv_hessian_estimate,
                                           objective_gradient)
  line_search_value_grad_func = _restrict_along_direction(
      value_and_gradients_function, position, search_direction)
  derivative_at_start_pt = tf.reduce_sum(objective_gradient *
                                         search_direction)
  ls_result = linesearch.hager_zhang(
      line_search_value_grad_func,
      initial_step_size=tf.constant(1, dtype=dtype),
      objective_at_zero=objective_value,
      grad_objective_at_zero=derivative_at_start_pt)

  # If the line search failed, then quit at this point.
  failed_retval = BfgsOptimizerResults(
      converged=False,
      failed=True,
      num_objective_evaluations=total_evals + ls_result.func_evals,
      position=position,
      objective_value=objective_value,
      objective_gradient=objective_gradient,
      inverse_hessian_estimate=inv_hessian_estimate)
  ls_failed_case = (~ls_result.converged, lambda: failed_retval)

  # If the line search didn't fail, then either we need to continue
  # searching or we need to stop because we have converged.
  position_delta = search_direction * ls_result.left_pt
  next_position = position + position_delta
  next_objective, next_objective_gradient = value_and_gradients_function(
      next_position)
  grad_norm = tf.norm(next_objective_gradient, ord=2)
  has_converged = grad_norm <= tolerance
  grad_delta = next_objective_gradient - objective_gradient
  updated_inv_hessian = _bfgs_inv_hessian_update(grad_delta, position_delta,
                                                 inv_hessian_estimate)
  updated_inv_hessian.set_shape(inv_hessian_estimate.shape)

  converged_retval = BfgsOptimizerResults(
      converged=True,
      failed=False,
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)
  converged_case = (has_converged, lambda: converged_retval)

  default_retval = BfgsOptimizerResults(
      converged=False,
      failed=False,
      num_objective_evaluations=total_evals + ls_result.func_evals + 1,
      position=next_position,
      objective_value=next_objective,
      objective_gradient=next_objective_gradient,
      inverse_hessian_estimate=updated_inv_hessian)
  default_fn = lambda: default_retval

  return smart_cond.smart_case([ls_failed_case, converged_case],
                               default=default_fn,
                               exclusive=False)
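# Each `_body` above is written as the body of a `tf.while_loop` over the
# optimizer state. A minimal sketch of the surrounding driver, assuming a
# state tuple of tensors and a condition that keeps iterating while any work
# remains (all names hypothetical):
import tensorflow as tf

def _minimize_sketch(initial_state):
  def _cond(state):
    # Keep looping until the state reports convergence or failure.
    return ~(state.converged | state.failed)

  def _loop_body(state):
    # Returns a single-element list, matching `loop_vars`, as in the first
    # `_body` variant above.
    return _body(state)

  return tf.while_loop(
      cond=_cond, body=_loop_body, loop_vars=[initial_state])[0]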