Пример #1
0
 def test_false(self):
     """A lone False predicate falls through to `default` in both modes.

     The only branch would raise if selected, so a result of 1 shows that
     `ps.case` used `default` for both `exclusive=False` and `exclusive=True`.
     """
     branches = [(False, raise_exception)]
     results = [
         ps.case(branches, default=lambda: tf.constant(1), exclusive=flag)
         for flag in (False, True)
     ]
     for result in results:
         self.assertEqual(self.evaluate(result), 1)
Пример #2
0
    def test_true(self):
        """A leading Python `True` predicate wins, exclusive or not.

        The second branch and the default both raise if selected, so a
        result of 1 shows the first branch was taken in both modes.
        """
        zero = tf.constant(0)
        branches = [(True, lambda: tf.constant(1)),
                    (tf.equal(zero, 1), raise_exception)]
        results = [
            ps.case(branches, default=raise_exception, exclusive=flag)
            for flag in (False, True)
        ]

        for result in results:
            self.assertEqual(self.evaluate(result), 1)
Пример #3
0
  def testTrue(self):
    """A leading Python `True` predicate selects its branch in both modes.

    The second branch and the default both raise if selected, so a result
    of 1 shows the first branch was taken for `exclusive=False` and
    `exclusive=True`.
    """
    x = tf.constant(0)
    # Use `tf.equal(x, 1)` (False) rather than `x == 0`: with TF2 tensor
    # equality, `x == 0` is an elementwise True tensor, so `exclusive=True`
    # would see two True predicates and its exclusivity check would fail.
    conditions = [(True, lambda: tf.constant(1)),
                  (tf.equal(x, 1), raise_exception)]
    y = prefer_static.case(conditions, default=raise_exception,
                           exclusive=False)
    z = prefer_static.case(conditions, default=raise_exception,
                           exclusive=True)

    self.assertEqual(self.evaluate(y), 1)
    self.assertEqual(self.evaluate(z), 1)
Пример #4
0
 def test_mix(self):
     """Mixed tensor/static predicates: the first True pair is selected.

     `x > 1` (0 > 1) and `y < 1` (10 < 1) are False, and the third predicate
     is `tf.constant(False)`. The fourth, `tf.constant(True)`, should be
     chosen and yield 3. The default must never be reached.
     """
     x = tf.constant(0)
     y = tf.constant(10)
     conditions = [(x > 1, lambda: tf.constant(1)),
                   (y < 1, raise_exception_in_eager_mode(tf.constant(2))),
                   (tf.constant(False), raise_exception),
                   (tf.constant(True), lambda: tf.constant(3))]
     # Pass the raising callable itself; `lambda: raise_exception` would only
     # return the function object and so would not raise if reached.
     z = ps.case(conditions, default=raise_exception)
     self.assertEqual(self.evaluate(z), 3)
Пример #5
0
def nelder_mead_one_step(current_simplex,
                         current_objective_values,
                         objective_function=None,
                         dim=None,
                         func_tolerance=None,
                         position_tolerance=None,
                         batch_evaluate_objective=False,
                         reflection=None,
                         expansion=None,
                         contraction=None,
                         shrinkage=None,
                         name=None):
    """A single iteration of the Nelder Mead algorithm.

    Orders the vertices by objective value (lowest is best), reflects the
    worst vertex through the centroid of the opposite face, and evaluates the
    objective there. It then selects, in priority order, one of: converged
    (simplex unchanged), accept the reflection, expansion, outside
    contraction, or (by default) inside contraction.

    Args:
      current_simplex: Tensor of simplex vertices indexed along axis 0
        (presumably shape `[dim + 1, dim]` -- TODO confirm against callers).
      current_objective_values: Tensor with one objective value per vertex.
      objective_function: Callable evaluated at the reflected point. Also
        forwarded to the expansion/contraction branch helpers.
      dim: Divisor used to average the face vertices into `face_centroid`.
        For a true centroid it must equal the number of vertices minus one.
        Although it defaults to `None`, `tf.cast(dim, ...)` below needs a
        real value, so callers must pass it.
      func_tolerance: Forwarded to `_check_convergence`.
      position_tolerance: Forwarded to `_check_convergence`.
      batch_evaluate_objective: Forwarded to the contraction helpers.
      reflection: Reflection coefficient. Used arithmetically below, so it
        must be supplied despite the `None` default.
      expansion: Expansion coefficient, forwarded to `_expansion_fn`.
      contraction: Contraction coefficient, forwarded to the contraction
        helpers.
      shrinkage: Shrinkage coefficient, forwarded to the contraction helpers.
      name: Optional name scope; defaults to 'nelder_mead_one_step'.

    Returns:
      A tuple `(converged, next_simplex, next_objective_at_simplex,
      num_evaluations)`. `num_evaluations` counts the single evaluation at
      the reflected point plus the count reported by the selected branch.
    """
    with tf.compat.v1.name_scope(name, 'nelder_mead_one_step'):
        domain_dtype = current_simplex.dtype.base_dtype
        # Stable ascending sort of objective values: order[0] is the best
        # (lowest) vertex, order[-1] the worst, order[-2] the second worst.
        order = tf.argsort(current_objective_values,
                           direction='ASCENDING',
                           stable=True)
        (best_index, worst_index,
         second_worst_index) = order[0], order[-1], order[-2]

        worst_vertex = current_simplex[worst_index]

        (best_objective_value, worst_objective_value,
         second_worst_objective_value) = (
             current_objective_values[best_index],
             current_objective_values[worst_index],
             current_objective_values[second_worst_index])

        # Compute the centroid of the face opposite the worst vertex.
        face_centroid = tf.reduce_sum(input_tensor=current_simplex,
                                      axis=0) - worst_vertex
        face_centroid /= tf.cast(dim, domain_dtype)

        # Reflect the worst vertex through the opposite face.
        reflected = face_centroid + reflection * (face_centroid - worst_vertex)
        objective_at_reflected = objective_function(reflected)

        # Counts the evaluation at `reflected` just above. It is included even
        # when the convergence branch is taken.
        num_evaluations = 1
        has_converged = _check_convergence(current_simplex,
                                           current_simplex[best_index],
                                           best_objective_value,
                                           worst_objective_value,
                                           func_tolerance, position_tolerance)

        # Converged: return the simplex unchanged, with no extra evaluations.
        def _converged_fn():
            return (True, current_simplex, current_objective_values, 0)

        # Cases are in priority order. With exclusive=False the first
        # predicate that holds selects its branch (as in tf.case).
        case0 = has_converged, _converged_fn
        # Accept the reflection when it beats the second-worst value but not
        # the best value.
        accept_reflected = (
            (objective_at_reflected < second_worst_objective_value) &
            (objective_at_reflected >= best_objective_value))
        accept_reflected_fn = _accept_reflected_fn(current_simplex,
                                                   current_objective_values,
                                                   worst_index, reflected,
                                                   objective_at_reflected)
        case1 = accept_reflected, accept_reflected_fn
        # The reflection beat the current best: try expanding further.
        do_expansion = objective_at_reflected < best_objective_value
        expansion_fn = _expansion_fn(objective_function, current_simplex,
                                     current_objective_values, worst_index,
                                     reflected, objective_at_reflected,
                                     face_centroid, expansion)
        case2 = do_expansion, expansion_fn
        # The reflection lies between the second-worst and worst values:
        # contract on the reflected side.
        do_outside_contraction = (
            (objective_at_reflected < worst_objective_value) &
            (objective_at_reflected >= second_worst_objective_value))
        outside_contraction_fn = _outside_contraction_fn(
            objective_function, current_simplex, current_objective_values,
            face_centroid, best_index, worst_index, reflected,
            objective_at_reflected, contraction, shrinkage,
            batch_evaluate_objective)
        case3 = do_outside_contraction, outside_contraction_fn
        # Default: none of the predicates above held. That is, the reflected
        # value is not below the worst value (or comparisons were all False,
        # e.g. for NaN), so contract on the inside.
        default_fn = _inside_contraction_fn(
            objective_function, current_simplex, current_objective_values,
            face_centroid, best_index, worst_index, worst_objective_value,
            contraction, shrinkage, batch_evaluate_objective)
        (converged, next_simplex, next_objective_at_simplex,
         case_evals) = prefer_static.case([case0, case1, case2, case3],
                                          default=default_fn,
                                          exclusive=False)
        # Re-attach the inputs' static shapes to the branch outputs.
        next_simplex.set_shape(current_simplex.shape)
        next_objective_at_simplex.set_shape(current_objective_values.shape)
        return (converged, next_simplex, next_objective_at_simplex,
                num_evaluations + case_evals)
Пример #6
0
    def _loop_body(iteration, found_wolfe, failed, evals, val_left, val_right):  # pylint:disable=unused-argument
        """Runs one secant2 step on the current interval and picks the next state.

        Args:
          iteration: Iteration counter; incremented once per call.
          found_wolfe: Unused; the returned result recomputes it.
          failed: Unused; the returned result recomputes it.
          evals: Running evaluation count. It is increased by the count
            reported by secant2 and, when it runs, by the bisection helper.
          val_left: Left end of the current interval (exposes `.x` and `.f`).
          val_right: Right end of the current interval (exposes `.x` and `.f`).

        Returns:
          The `_LineSearchInnerResult` built by whichever branch applies:
          secant2 failed, secant2 converged, or otherwise a decision based
          on how much secant2 shrank the interval.
        """
        iteration += 1

        step = hzl.secant2(
            value_and_gradients_function,
            val_0,
            val_left,
            val_right,
            f_lim,
            sufficient_decrease_param=sufficient_decrease_param,
            curvature_param=curvature_param)

        evals += step.num_evals

        def _make_result(found, is_failed, count, left, right):
            """Builds the loop state, stamped with this call's iteration."""
            return _LineSearchInnerResult(iteration=iteration,
                                          found_wolfe=found,
                                          failed=is_failed,
                                          num_evals=count,
                                          left=left,
                                          right=right)

        def _ulp_close(a, b):
            """Returns `nextafter(a, b) >= b`.

            When `a < b` this holds only if `b` is the next float after `a`.
            NOTE(review): it is also True whenever `a >= b`, so this is not a
            symmetric closeness test -- confirm that asymmetry is intended.
            """
            return tf.math.nextafter(a, b) >= b

        def _on_failure():
            # secant2 failed: flag failure and keep the incoming interval.
            return _make_result(False, True, evals, val_left, val_right)

        def _on_convergence():
            # secant2 converged: set found_wolfe and keep secant2's interval.
            return _make_result(True, False, evals, step.left, step.right)

        def _otherwise():
            """Neither failed nor converged: judge how much secant2 shrank."""
            narrowed = step.right.x - step.left.x
            previous = val_right.x - val_left.x
            shrank_enough = narrowed < shrinkage_param * previous

            def _shrank():
                """Enough shrinkage: finish early if the function looks flat."""
                is_flat = (_ulp_close(val_left.f, val_right.f) &
                           _ulp_close(step.left.f, step.right.f))
                # When flat, both ends of the result are secant2's left point.
                collapsed = _make_result(True, False, evals, step.left,
                                         step.left)
                shrunk = _make_result(False, False, evals, step.left,
                                      step.right)
                return prefer_static.cond(is_flat,
                                          true_fn=lambda: collapsed,
                                          false_fn=lambda: shrunk)

            def _did_not_shrink():
                """Too little shrinkage: hand secant2's interval to the bisection helper."""
                bisected = _line_search_inner_bisection(
                    value_and_gradients_function, step.left, step.right, f_lim)
                return _make_result(False, bisected.failed,
                                    evals + bisected.num_evals,
                                    bisected.left, bisected.right)

            return prefer_static.cond(shrank_enough,
                                      true_fn=_shrank,
                                      false_fn=_did_not_shrink)

        return prefer_static.case([(step.failed, _on_failure),
                                   (step.converged, _on_convergence)],
                                  default=_otherwise,
                                  exclusive=False)