def test_bounded_expression(self):
        """Checks that `BoundedExpression` picks the right bound under negation.

        Used as-is, a `BoundedExpression` should expose its upper bound; once
        negated, it should expose (the negation of) its lower bound. Every bound
        below wraps a single constant term, so the magnitude of that term tells
        us which bound was selected.
        """
        structure_memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.Variable(0, dtype=tf.int32),
            defaults.VARIABLE_FN_KEY: tf.Variable
        }

        def explicit_constant(value):
            """Wraps `value` in an `ExplicitExpression` (penalty == constraint)."""
            basic = basic_expression.BasicExpression([term.TensorTerm(value)])
            return expression.ExplicitExpression(basic, basic)

        expression1, expression2, expression3, expression4 = [
            explicit_constant(value) for value in (1.0, 2.0, 4.0, 8.0)
        ]

        def term_value(expression_object):
            # Each penalty expression here holds exactly one term. Taking the
            # absolute value undoes any negation, so the magnitude alone
            # identifies the underlying BasicExpression.
            penalty_terms = expression_object.penalty_expression._terms
            self.assertEqual(1, len(penalty_terms))
            return abs(penalty_terms[0].tensor(structure_memoizer))

        def check_bounds(bounded, expected_positive, expected_negated):
            # The expression itself, then its negation, in that order.
            self.assertEqual(term_value(bounded), expected_positive)
            self.assertEqual(term_value(-bounded), expected_negated)

        bounded_expression1 = expression.BoundedExpression(
            lower_bound=expression1, upper_bound=expression2)
        check_bounds(bounded_expression1, 2.0, 1.0)

        bounded_expression2 = expression.BoundedExpression(
            lower_bound=expression3, upper_bound=expression4)
        check_bounds(bounded_expression2, 8.0, 4.0)

        # Negation and subtraction of BoundedExpressions.
        bounded_expression3 = -(bounded_expression1 - bounded_expression2)
        check_bounds(bounded_expression3, 8.0 - 1.0, 4.0 - 2.0)

        # A BoundedExpression used as a bound of another one.
        bounded_expression4 = expression.BoundedExpression(
            lower_bound=bounded_expression1, upper_bound=expression3)
        check_bounds(bounded_expression4, 4.0, 1.0)

        # Negated BoundedExpressions used as the bounds of another one.
        bounded_expression5 = expression.BoundedExpression(
            lower_bound=-bounded_expression1, upper_bound=-bounded_expression2)
        check_bounds(bounded_expression5, 4.0, 2.0)
def upper_bound(expressions):
    """Creates an `Expression` upper bounding the given expressions.

    This function introduces a slack variable, and adds constraints forcing this
    variable to upper bound all elements of the given expression list. It then
    returns the slack variable.

    If you're going to be upper-bounding or minimizing the result of this
    function, then you can think of it as taking the `max` of its arguments. You
    should *never* lower-bound or maximize the result, however, since the
    consequence would be to increase the value of the slack variable, without
    affecting the contents of the expressions list.

    Args:
      expressions: list (or any other finite iterable, such as a tuple or a
        generator) of `Expression`s, the quantities to upper-bound. The
        iterable is consumed exactly once.

    Returns:
      An `Expression` representing an upper bound on the given expressions. Its
      lower bound is an `InvalidExpression`, so maximizing or lower-bounding it
      is reported as an error.

    Raises:
      ValueError: if the expressions list is empty.
      TypeError: if the expressions list contains a non-`Expression`.
    """
    # Materialize the input once. A generator is truthy even when empty, and
    # would be exhausted by the type check below, leaving no elements for the
    # constraint comprehension (i.e. a silently unconstrained slack variable).
    expressions = list(expressions)
    if not expressions:
        raise ValueError(
            "upper_bound cannot be given an empty expression list")
    if not all(isinstance(ee, expression.Expression) for ee in expressions):
        raise TypeError(
            "upper_bound expects a list of rate Expressions (perhaps you need to "
            "call wrap_rate() to create an Expression from a Tensor?)")

    # Ideally the slack variable would have the same dtype as the predictions, but
    # we might not know their dtype (e.g. in eager mode), so instead we always use
    # float32 with auto_cast=True.
    bound = deferred_tensor.DeferredVariable(0.0,
                                             trainable=True,
                                             name="tfco_upper_bound",
                                             dtype=tf.float32,
                                             auto_cast=True)

    # The slack variable is used unchanged for both the penalty and the
    # constraint parts of the expression.
    bound_basic_expression = basic_expression.BasicExpression(
        [term.TensorTerm(bound)])
    bound_expression = expression.ExplicitExpression(
        penalty_expression=bound_basic_expression,
        constraint_expression=bound_basic_expression)
    # One constraint per input: each expression must be <= the slack variable.
    extra_constraints = [ee <= bound_expression for ee in expressions]

    # We wrap the result in a BoundedExpression so that we'll check if the user
    # attempts to maximize or lower-bound the result of this function, and will
    # raise an error if they do.
    return expression.BoundedExpression(
        lower_bound=expression.InvalidExpression(
            "the result of a call to upper_bound() can only be minimized or "
            "upper-bounded; it *cannot* be maximized or lower-bounded"),
        upper_bound=expression.ConstrainedExpression(
            expression.ExplicitExpression(
                penalty_expression=bound_basic_expression,
                constraint_expression=bound_basic_expression),
            extra_constraints=extra_constraints))