def test_false(self):
  """With every predicate False, `ps.case` falls through to `default`."""
  # The only branch is guarded by a statically False predicate; it is wired to
  # `raise_exception`, presumably so the test errors out if it is ever taken.
  never_taken = [(False, raise_exception)]
  # Build both variants (exclusive=False, then exclusive=True) before
  # evaluating either of them.
  outcomes = [
      ps.case(never_taken, default=lambda: tf.constant(1), exclusive=mode)
      for mode in (False, True)
  ]
  for outcome in outcomes:
    self.assertEqual(self.evaluate(outcome), 1)
def test_true(self):
  """A statically True first predicate wins, with or without `exclusive`."""
  zero = tf.constant(0)
  # Only the first predicate holds (0 != 1), so at most one predicate is true
  # and neither `raise_exception` callable (branch or default) should run.
  pairs = [(True, lambda: tf.constant(1)),
           (tf.equal(zero, 1), raise_exception)]
  # Build the exclusive=False and exclusive=True variants first, then check.
  outcomes = [
      ps.case(pairs, default=raise_exception, exclusive=mode)
      for mode in (False, True)
  ]
  for outcome in outcomes:
    self.assertEqual(self.evaluate(outcome), 1)
def testTrue(self):
  """A statically True first predicate wins, with or without `exclusive`.

  The second predicate must evaluate to False for the `exclusive=True` call
  to be valid. It was previously written `x == 0`. With TF2 tensor equality
  that is an elementwise comparison that evaluates to True, so two predicates
  would hold. Assuming `prefer_static.case` follows `tf.case`, the
  `exclusive=True` call would then report an error. Under TF1 identity
  semantics the same expression was a plain Python `False`, so the test meant
  different things in the two regimes. `tf.equal(x, 1)` is False in both.
  """
  x = tf.constant(0)
  conditions = [(True, lambda: tf.constant(1)),
                (tf.equal(x, 1), raise_exception)]
  y = prefer_static.case(conditions, default=raise_exception, exclusive=False)
  z = prefer_static.case(conditions, default=raise_exception, exclusive=True)
  self.assertEqual(self.evaluate(y), 1)
  self.assertEqual(self.evaluate(z), 1)
def test_mix(self):
  """Mixed tensor predicates: the first True one (`tf.constant(True)`) wins."""
  x = tf.constant(0)
  y = tf.constant(10)
  # 0 > 1 and 10 < 1 are both False and the third predicate is constant
  # False, so only the last branch (returning 3) should be selected.
  conditions = [(x > 1, lambda: tf.constant(1)),
                (y < 1, raise_exception_in_eager_mode(tf.constant(2))),
                (tf.constant(False), raise_exception),
                (tf.constant(True), lambda: tf.constant(3))]
  # `default` must never run, so pass the raising callable itself. The old
  # `lambda: raise_exception` merely *returned* the function object, so a
  # wrongly-taken default branch would not have raised.
  z = ps.case(conditions, default=raise_exception)
  self.assertEqual(self.evaluate(z), 3)
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
  """A single iteration of the Nelder Mead algorithm.

  Ranks the vertices by objective value (ascending, so index 0 is the best).
  It then reflects the worst vertex through the centroid of the remaining
  vertices and evaluates the objective there. Finally, `prefer_static.case`
  selects one update, trying in order:

  1. convergence,
  2. accepting the reflection,
  3. expansion,
  4. outside contraction,

  with inside contraction as the fallback.

  Args:
    current_simplex: Tensor of vertices; the leading axis indexes vertices.
    current_objective_values: Objective value at each vertex.
    objective_function: Callable evaluated at the reflected point. It is also
      forwarded to the expansion and contraction helpers.
    dim: Divisor used to average the non-worst vertices. Presumably the
      problem dimension, i.e. the vertex count minus one; TODO confirm.
    func_tolerance: Forwarded to `_check_convergence`.
    position_tolerance: Forwarded to `_check_convergence`.
    batch_evaluate_objective: Forwarded to both contraction helpers.
    reflection: Reflection coefficient, used directly here.
    expansion: Forwarded to `_expansion_fn`.
    contraction: Forwarded to both contraction helpers.
    shrinkage: Forwarded to both contraction helpers.
    name: Optional name for the enclosing name scope.

  Returns:
    Tuple `(converged, next_simplex, next_objective_values, num_evaluations)`.
    `num_evaluations` is 1, for the reflected point evaluated here, plus the
    count reported by the selected branch.
  """
  with tf.compat.v1.name_scope(name, 'nelder_mead_one_step'):
    dtype = current_simplex.dtype.base_dtype
    ranking = tf.argsort(
        current_objective_values, direction='ASCENDING', stable=True)
    best_index = ranking[0]
    worst_index = ranking[-1]
    second_worst_index = ranking[-2]

    worst_vertex = current_simplex[worst_index]
    best_value = current_objective_values[best_index]
    worst_value = current_objective_values[worst_index]
    second_worst_value = current_objective_values[second_worst_index]

    # Centroid of the face opposite the worst vertex: sum of all vertices,
    # minus the worst one, divided by `dim`.
    vertex_total = tf.reduce_sum(input_tensor=current_simplex, axis=0)
    face_centroid = (vertex_total - worst_vertex) / tf.cast(dim, dtype)

    # Mirror the worst vertex through that centroid and evaluate it.
    reflected = face_centroid + reflection * (face_centroid - worst_vertex)
    reflected_value = objective_function(reflected)
    reflection_evals = 1

    converged_pred = _check_convergence(
        current_simplex, current_simplex[best_index], best_value,
        worst_value, func_tolerance, position_tolerance)

    def _converged_fn():
      # Leave the simplex untouched; no further evaluations are made.
      return (True, current_simplex, current_objective_values, 0)

    # Reflection lands in [best, second_worst): keep it.
    accept_pred = ((reflected_value < second_worst_value) &
                   (reflected_value >= best_value))
    accept_fn = _accept_reflected_fn(
        current_simplex, current_objective_values, worst_index, reflected,
        reflected_value)

    # Reflection beats the current best: try expanding further.
    expand_pred = reflected_value < best_value
    expand_fn = _expansion_fn(
        objective_function, current_simplex, current_objective_values,
        worst_index, reflected, reflected_value, face_centroid, expansion)

    # Reflection lands in [second_worst, worst): contract on the outside.
    outside_pred = ((reflected_value < worst_value) &
                    (reflected_value >= second_worst_value))
    outside_fn = _outside_contraction_fn(
        objective_function, current_simplex, current_objective_values,
        face_centroid, best_index, worst_index, reflected, reflected_value,
        contraction, shrinkage, batch_evaluate_objective)

    # Anything else (including comparisons that are all False) contracts on
    # the inside.
    inside_fn = _inside_contraction_fn(
        objective_function, current_simplex, current_objective_values,
        face_centroid, best_index, worst_index, worst_value, contraction,
        shrinkage, batch_evaluate_objective)

    branches = [
        (converged_pred, _converged_fn),
        (accept_pred, accept_fn),
        (expand_pred, expand_fn),
        (outside_pred, outside_fn),
    ]
    # exclusive=False: `converged_pred` may hold together with another
    # predicate, and the first True entry in list order is taken.
    converged, next_simplex, next_values, branch_evals = prefer_static.case(
        branches, default=inside_fn, exclusive=False)
    next_simplex.set_shape(current_simplex.shape)
    next_values.set_shape(current_objective_values.shape)
    return (converged, next_simplex, next_values,
            reflection_evals + branch_evals)
def _loop_body(iteration, found_wolfe, failed, evals, val_left, val_right):  # pylint:disable=unused-argument
  """The loop body.

  Runs one `hzl.secant2` step on the bracketing interval
  `[val_left, val_right]` and packages the outcome as a
  `_LineSearchInnerResult`. Outcomes, in priority order:

  * secant2 failed: `failed=True`; the *input* interval is returned.
  * secant2 converged: `found_wolfe=True` with the secant2 interval.
  * otherwise, if the interval width shrank below
    `shrinkage_param * old_width`: a flatness check, see
    `_sufficient_shrinkage_fn`.
  * otherwise: one `_line_search_inner_bisection` update of the secant2
    interval.

  `found_wolfe` and `failed` are accepted only to match the loop-state
  signature; they are not read here. `value_and_gradients_function`, `val_0`,
  `f_lim`, `sufficient_decrease_param`, `curvature_param` and
  `shrinkage_param` come from the enclosing scope, which is not visible here.
  """
  iteration += 1
  secant2_result = hzl.secant2(
      value_and_gradients_function, val_0, val_left, val_right, f_lim,
      sufficient_decrease_param=sufficient_decrease_param,
      curvature_param=curvature_param)
  # `evals` is updated before the branch closures below capture it, so every
  # branch reports the secant2 evaluations.
  evals += secant2_result.num_evals

  def _failed_fn():
    # Keep the interval we started this iteration with.
    return _LineSearchInnerResult(iteration=iteration, found_wolfe=False, failed=True, num_evals=evals, left=val_left, right=val_right)

  def _converged_fn():
    return _LineSearchInnerResult(iteration=iteration, found_wolfe=True, failed=False, num_evals=evals, left=secant2_result.left, right=secant2_result.right)

  def _default_fn():
    """Default action."""
    new_width = secant2_result.right.x - secant2_result.left.x
    old_width = val_right.x - val_left.x
    sufficient_shrinkage = new_width < shrinkage_param * old_width

    def _sufficient_shrinkage_fn():
      """Action to perform if secant2 shrank the interval sufficiently."""
      # `nextafter(a, b) >= b` holds when a == b, or when b is the float
      # immediately above a. NOTE(review): it also always holds when a > b,
      # because nextafter moves a toward b without passing it. So this is
      # "right value not more than one ulp above left", not a symmetric
      # closeness test. Confirm this asymmetry is intended.
      func_is_flat = (
          (tf.math.nextafter(val_left.f, val_right.f) >= val_right.f) &
          (tf.math.nextafter(secant2_result.left.f, secant2_result.right.f) >= secant2_result.right.f))
      # Flat: report success and collapse the interval onto the left end.
      is_flat_retval = _LineSearchInnerResult(iteration=iteration, found_wolfe=True, failed=False, num_evals=evals, left=secant2_result.left, right=secant2_result.left)
      # Not flat: keep searching on the shrunken secant2 interval.
      not_is_flat_retval = _LineSearchInnerResult(iteration=iteration, found_wolfe=False, failed=False, num_evals=evals, left=secant2_result.left, right=secant2_result.right)
      # Both results are built up front; the cond only selects one of them.
      return prefer_static.cond(func_is_flat, true_fn=lambda: is_flat_retval, false_fn=lambda: not_is_flat_retval)

    def _insufficient_shrinkage_fn():
      """Action to perform if secant2 didn't shrink the interval enough."""
      update_result = _line_search_inner_bisection(value_and_gradients_function, secant2_result.left, secant2_result.right, f_lim)
      # Add the bisection's own evaluations on top of the running total.
      return _LineSearchInnerResult(iteration=iteration, found_wolfe=False, failed=update_result.failed, num_evals=evals + update_result.num_evals, left=update_result.left, right=update_result.right)

    return prefer_static.cond(sufficient_shrinkage, true_fn=_sufficient_shrinkage_fn, false_fn=_insufficient_shrinkage_fn)

  # exclusive=False: if secant2 reports both failed and converged, the
  # failure branch (listed first) takes precedence.
  return prefer_static.case([(secant2_result.failed, _failed_fn), (secant2_result.converged, _converged_fn)], default=_default_fn, exclusive=False)