def upper_bound(expressions):
    """Creates an `Expression` upper bounding the given expressions.

    This function introduces a slack variable, and adds constraints forcing this
    variable to upper bound all elements of the given expression list. It then
    returns the slack variable.

    If you're going to be upper-bounding or minimizing the result of this
    function, then you can think of it as taking the `max` of its arguments. You
    should *never* lower-bound or maximize the result, however, since the
    consequence would be to increase the value of the slack variable, without
    affecting the contents of the expressions list.

    Args:
      expressions: list (or any other iterable) of `Expression`s, the
        quantities to upper-bound. The iterable is consumed exactly once, so
        generators are accepted.

    Returns:
      An `Expression` representing an upper bound on the given expressions.

    Raises:
      ValueError: if the expressions list is empty.
      TypeError: if the expressions list contains a non-`Expression`.
    """
    # Materialize the input exactly once. A generator is always truthy, which
    # would bypass the emptiness check. It would also be exhausted by the type
    # check below, leaving no constraints to build.
    expressions = list(expressions)
    if not expressions:
        raise ValueError(
            "upper_bound cannot be given an empty expression list")
    if not all(isinstance(ee, expression.Expression) for ee in expressions):
        raise TypeError(
            "upper_bound expects a list of rate Expressions (perhaps you need to "
            "call wrap_rate() to create an Expression from a Tensor?)")

    # Ideally the slack variable would have the same dtype as the predictions, but
    # we might not know their dtype (e.g. in eager mode), so instead we always use
    # float32 with auto_cast=True.
    bound = deferred_tensor.DeferredVariable(0.0,
                                             trainable=True,
                                             name="tfco_upper_bound",
                                             dtype=tf.float32,
                                             auto_cast=True)

    # The slack variable is used unchanged for both the penalty and the
    # constraint views of the expression.
    bound_basic_expression = basic_expression.BasicExpression(
        [term.TensorTerm(bound)])
    bound_expression = expression.ExplicitExpression(
        penalty_expression=bound_basic_expression,
        constraint_expression=bound_basic_expression)
    # One constraint per input forces the slack to be at least every expression.
    extra_constraints = [ee <= bound_expression for ee in expressions]

    # We wrap the result in a BoundedExpression so that we'll check if the user
    # attempts to maximize or lower-bound the result of this function, and will
    # raise an error if they do.
    return expression.BoundedExpression(
        lower_bound=expression.InvalidExpression(
            "the result of a call to upper_bound() can only be minimized or "
            "upper-bounded; it *cannot* be maximized or lower-bounded"),
        upper_bound=expression.ConstrainedExpression(
            bound_expression,
            extra_constraints=extra_constraints))
# Example No. 2
# 0
 def test_error_expression(self):
     """Checks that reading any property of an `InvalidExpression` raises."""
     invalid = expression.InvalidExpression("an error message")
     # Each of these three accessors must raise RuntimeError when touched.
     for attribute_name in ("penalty_expression", "constraint_expression",
                            "extra_constraints"):
         with self.assertRaises(RuntimeError):
             getattr(invalid, attribute_name)