Ejemplo n.º 1
0
    def test_ratio_weights_memoizer(self):
        """Tests memoization.

        Two `_RatioWeights` objects are built from the same weights and the same
        denominator predicate, but from different numerator predicates. Their
        results must differ, while their variables must compare equal.
        """
        memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.compat.v2.Variable(0, dtype=tf.int32)
        }

        shared_weights = deferred_tensor.DeferredTensor(
            tf.constant([0.5, 0.1, 1.0], dtype=tf.float32))
        numerator_tensors = [
            deferred_tensor.DeferredTensor(tf.constant(mask, dtype=tf.bool))
            for mask in ([True, False, True], [True, True, False])
        ]
        numerator_predicates = [
            predicate.Predicate(numerator_tensor)
            for numerator_tensor in numerator_tensors
        ]
        shared_denominator = predicate.Predicate(True)

        all_ratio_weights = [
            term._RatioWeights.ratio(shared_weights, numerator,
                                     shared_denominator)
            for numerator in numerator_predicates
        ]
        (first_result, first_variables), (second_result, second_variables) = [
            ratio_weights.evaluate(memoizer)
            for ratio_weights in all_ratio_weights
        ]

        # Differing numerators must give different results, but since the
        # weights and the denominator are shared, the variables must coincide.
        self.assertIsNot(first_result, second_result)
        self.assertEqual(first_variables, second_variables)
Ejemplo n.º 2
0
    def test_arithmetic(self):
        """Tests `Expression`'s arithmetic operators."""
        memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.compat.v2.Variable(0, dtype=tf.int32)
        }

        penalty_values = [-3.6, 1.5, 0.4]
        constraint_values = [-0.2, -0.5, 2.3]

        def make_constant_expression(penalty_value, constraint_value):
            """Wraps two constants as the penalty/constraint parts of an Expression."""
            penalty_part = basic_expression.BasicExpression(
                [],
                deferred_tensor.DeferredTensor(
                    tf.constant(penalty_value, dtype=tf.float32)))
            constraint_part = basic_expression.BasicExpression(
                [],
                deferred_tensor.DeferredTensor(tf.constant(constraint_value)))
            return expression.Expression(penalty_part, constraint_part)

        def combine(operands):
            """Applies every arithmetic operator under test to three operands."""
            return (0.3 - (operands[0] / 2.3 + 0.7 * operands[1]) -
                    (1.2 + operands[2] - 0.1) * 0.6 + 0.8)

        # One Expression per (penalty, constraint) pair of constants.
        expression_objects = [
            make_constant_expression(penalty_value, constraint_value)
            for penalty_value, constraint_value in zip(penalty_values,
                                                       constraint_values)
        ]
        combined = combine(expression_objects)

        actual_penalty_value, penalty_variables = (
            combined.penalty_expression.evaluate(memoizer))
        actual_constraint_value, constraint_variables = (
            combined.constraint_expression.evaluate(memoizer))

        # The variables have to exist before the wrapped session is created.
        for variable in deferred_tensor.DeferredVariableList(
                penalty_variables + constraint_variables):
            variable.create(memoizer)

        # The very same arithmetic, carried out directly on the Python floats.
        expected_penalty_value = combine(penalty_values)
        expected_constraint_value = combine(constraint_values)

        with self.wrapped_session() as session:
            for expected, actual in (
                    (expected_penalty_value, actual_penalty_value),
                    (expected_constraint_value, actual_constraint_value)):
                self.assertNear(expected,
                                session.run(actual(memoizer)),
                                err=1e-6)
    def test_callable(self):
        """Tests that callbacks are not called until needed."""
        memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.compat.v2.Variable(0, dtype=tf.int32)
        }

        # Records the name of every callback that has been invoked.
        calls = []

        def numerator_callback():
            calls.append("callback1")
            return 3.14

        def denominator_callback():
            calls.append("callback2")
            return 4

        quotient = (deferred_tensor.DeferredTensor(numerator_callback) /
                    deferred_tensor.DeferredTensor(denominator_callback))

        # Building the quotient should only have created a closure: neither
        # callback may have run yet.
        self.assertEmpty(calls)

        # The callbacks return plain scalars, so no Session is required here.
        self.assertAllEqual(0.785, quotient(memoizer))

        # Evaluating the quotient must have invoked each callback exactly once.
        self.assertAllEqual(["callback1", "callback2"], sorted(calls))
Ejemplo n.º 4
0
 def create_binary_classification_term(predictions_tensor,
                                       positive_weights_tensor,
                                       negative_weights_tensor):
     """Builds a hinge-loss `BinaryClassificationTerm` from raw tensors.

     Both sets of weights are wrapped in `_RatioWeights` whose numerator and
     denominator predicates are `Predicate(True)`.
     """

     def unconditional_ratio_weights(weights):
         # Numerator and denominator predicates are both always-true.
         return term._RatioWeights.ratio(
             deferred_tensor.DeferredTensor(weights),
             predicate.Predicate(True), predicate.Predicate(True))

     positive_ratio_weights = unconditional_ratio_weights(
         positive_weights_tensor)
     negative_ratio_weights = unconditional_ratio_weights(
         negative_weights_tensor)
     wrapped_predictions = deferred_tensor.DeferredTensor(predictions_tensor)
     return term.BinaryClassificationTerm(wrapped_predictions,
                                          positive_ratio_weights,
                                          negative_ratio_weights,
                                          loss.HingeLoss())
Ejemplo n.º 5
0
  def __init__(  # pylint: disable=invalid-name
      self, tensor, _convert_and_clip=True):
    """Creates a new `Predicate`.

    Args:
      tensor: an object convertible to a rank-1 `Tensor` (e.g. a scalar, list,
        numpy array, or a `Tensor` itself), a nullary function returning such
        an object, or a `DeferredTensor`. Unless "_convert_and_clip" is False,
        it is converted to a float32 `DeferredTensor` and clipped to [0,1].
      _convert_and_clip: private Boolean. If False, "tensor" is only wrapped in
        a `DeferredTensor` (when it is not one already), without the float32
        conversion and clipping. This is for internal use *only*.

    Raises:
      ValueError: if "tensor" is itself a `Predicate`.
    """
    if isinstance(tensor, Predicate):
      raise ValueError("cannot create a Predicate from a Predicate")

    # Anything that isn't already a DeferredTensor gets wrapped in one, so that
    # DeferredTensor.apply can be used below.
    if isinstance(tensor, deferred_tensor.DeferredTensor):
      wrapped = tensor
    else:
      wrapped = deferred_tensor.DeferredTensor(tensor)

    if _convert_and_clip:

      def convert_and_clip_fn(arg):
        """Converts the given object to a rank-one float32 `Tensor` in [0,1]."""
        as_float = tf.cast(
            helpers.convert_to_1d_tensor(arg, "predicate"), dtype=tf.float32)
        return tf.clip_by_value(as_float, 0.0, 1.0)

      wrapped = deferred_tensor.DeferredTensor.apply(convert_and_clip_fn,
                                                     wrapped)

    self._tensor = wrapped
    def test_type_promotion(self):
        """Tests that automatic type promotion works as expected."""
        memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.compat.v2.Variable(0, dtype=tf.int32)
        }

        tensor1 = deferred_tensor.DeferredTensor(
            tf.constant(-2, dtype=tf.int16), auto_cast=True)
        tensor2 = deferred_tensor.DeferredTensor(
            lambda: tf.constant(1.5, dtype=tf.float32), auto_cast=True)
        tensor3 = deferred_tensor.DeferredTensor(
            tf.constant(2.7, dtype=tf.float32), auto_cast=True)
        tensor4 = deferred_tensor.DeferredTensor(
            tf.constant(0.3, dtype=tf.float64), auto_cast=True)

        sum_expression = tensor1 + tensor2
        quotient_expression = tensor3 / tensor4
        product_expression = sum_expression * quotient_expression

        # (deferred object, expected base dtype, expected value), listed in the
        # order in which they are evaluated.
        cases = [
            (tensor1, tf.int16, -2),
            (tensor2, tf.float32, 1.5),
            (tensor3, tf.float32, 2.7),
            (tensor4, tf.float64, 0.3),
            (sum_expression, tf.float32, -0.5),
            (quotient_expression, tf.float64, 9),
            (product_expression, tf.float64, -4.5),
        ]
        values = [deferred(memoizer) for deferred, _, _ in cases]

        for (_, expected_dtype, _), value in zip(cases, values):
            self.assertEqual(expected_dtype, value.dtype.base_dtype)

        with self.wrapped_session() as session:
            for (_, _, expected_value), value in zip(cases, values):
                self.assertAllClose(expected_value, session.run(value))
Ejemplo n.º 7
0
def wrap_rate(penalty_tensor, constraint_tensor=None):
  """Creates an `Expression` representing the given `Tensor`(s).

  An `Expression` holds two `BasicExpression`s: the "penalty" one, which is
  meant to be differentiable, and the "constraint" one, which need not be.
  The former is used whenever gradients are required, the latter otherwise.

  Args:
    penalty_tensor: scalar `Tensor`, stored in the "penalty" portion of the
      result (and in the "constraint" portion too, when constraint_tensor is
      not provided).
    constraint_tensor: optional scalar `Tensor`, stored in the "constraint"
      portion of the result.

  Returns:
    An `Expression` wrapping the given `Tensor`(s).

  Raises:
    TypeError: if wrap_rate() is called on an internal rate library object
      (e.g. an `Expression`).
  """
  # Fully validating that the arguments are (convertible to) scalar Tensors
  # would require actually converting them, which would add a dummy op to the
  # graph and would not work for Tensors of unknown size. We therefore only
  # reject the types we know for certain are disallowed: objects internal to
  # this library.
  for candidate in (penalty_tensor, constraint_tensor):
    if isinstance(candidate, helpers.RateObject):
      raise TypeError("you cannot wrap an object that has already been wrapped")

  def make_basic_expression(value):
    return basic_expression.BasicExpression(
        terms=[], tensor=deferred_tensor.DeferredTensor(value))

  penalty_basic_expression = make_basic_expression(penalty_tensor)
  # Without a separate constraint tensor, both portions share one object.
  constraint_basic_expression = (
      penalty_basic_expression if constraint_tensor is None else
      make_basic_expression(constraint_tensor))
  return expression.Expression(penalty_basic_expression,
                               constraint_basic_expression)
    def __init__(self, terms, tensor=0.0):
        """Creates a new `BasicExpression`.

        A collection of `Term`s is accepted (rather than a single `Term`
        representing the whole linear combination) because, unlike `Tensor`s,
        two `Term`s can only be added or subtracted when they are "compatible"
        (a notion defined by the `Term` itself).

        Args:
          terms: collection of `Term`s to sum in the `BasicExpression`.
          tensor: optional scalar `DeferredTensor` or `Tensor`-like object to
            add to the sum of `Term`s.

        Raises:
          TypeError: if "tensor" is not a `DeferredTensor` or `Tensor`-like
            object.
        """
        # State: "_terms" is a linear combination of Term objects, and
        # "_tensor" is an extra Tensor-like summand. Because Tensors support
        # negation, addition, subtraction, and scalar multiplication/division,
        # a single "_tensor" can represent any linear combination of Tensors.
        #
        # Terms are more restricted: negation and scalar multiplication/division
        # are always allowed, but only "compatible" Terms (those with the same
        # "key") may be added or subtracted. When two BasicExpressions are
        # combined, compatible Terms are merged inside "_terms", and
        # incompatible ones are appended.
        #
        # A list (not e.g. a key->Term dict) is used so the order of the Terms
        # is preserved. Distributed optimization relies on this, since it makes
        # the order of the DeferredVariables that the Terms depend on consistent
        # across machines.
        self._terms = []
        self._add_terms(terms)

        if isinstance(tensor, deferred_tensor.DeferredTensor):
            self._tensor = tensor
        elif isinstance(tensor, helpers.RateObject):
            # Any other internal library object is not a valid summand.
            raise TypeError(
                "tensor argument to BasicExpression's constructor "
                "should be a DeferredTensor or a Tensor-like object")
        else:
            self._tensor = deferred_tensor.DeferredTensor(tensor)
Ejemplo n.º 9
0
 def create_ratio_weights(weights_tensor):
     """Wraps "weights_tensor" in `_RatioWeights` with always-true predicates.

     Both the numerator and the denominator predicates are `Predicate(True)`.
     """
     wrapped_weights = deferred_tensor.DeferredTensor(weights_tensor)
     numerator_predicate = predicate.Predicate(True)
     denominator_predicate = predicate.Predicate(True)
     return term._RatioWeights.ratio(wrapped_weights, numerator_predicate,
                                     denominator_predicate)
Ejemplo n.º 10
0
    def test_ratio_weights_ratio(self):
        """Tests `_RatioWeights`'s ratio() class method.

        The data (self._weights, self._numerator_predicate and
        self._denominator_predicate) is fed in consecutive minibatches whose
        boundaries are given by self._splits. After each minibatch's update ops
        have run, the weights should equal weights * numerator_predicate,
        divided by the running average of weights * denominator_predicate over
        every example fed so far.
        """
        weights_placeholder = self.wrapped_placeholder(tf.float32,
                                                       shape=(None, ))
        numerator_predicate_placeholder = self.wrapped_placeholder(
            tf.bool, shape=(None, ))
        denominator_predicate_placeholder = self.wrapped_placeholder(
            tf.bool, shape=(None, ))
        memoizer = {
            defaults.DENOMINATOR_LOWER_BOUND_KEY: 0.0,
            defaults.GLOBAL_STEP_KEY: tf.compat.v2.Variable(0, dtype=tf.int32)
        }

        numerator_predicate = predicate.Predicate(
            numerator_predicate_placeholder)
        denominator_predicate = predicate.Predicate(
            denominator_predicate_placeholder)
        ratio_weights = term._RatioWeights.ratio(
            deferred_tensor.DeferredTensor(weights_placeholder),
            numerator_predicate, denominator_predicate)
        actual_weights, variables = ratio_weights.evaluate(memoizer)

        # We need to explicitly create the variables before creating the wrapped
        # session.
        for variable in variables:
            variable.create(memoizer)

        def update_ops_fn():
            """Collects the update ops of every variable used by ratio_weights."""
            update_ops = []
            for variable in variables:
                update_ops += variable.update_ops(memoizer)
            return update_ops

        with self.wrapped_session() as session:
            running_count = 0.0
            running_sum = 0.0
            # Consecutive entries of self._splits delimit one minibatch.
            for begin_index, end_index in zip(self._splits[:-1],
                                              self._splits[1:]):
                size = end_index - begin_index

                weights_subarray = self._weights[begin_index:end_index, 0]
                numerator_predicate_subarray = self._numerator_predicate[
                    begin_index:end_index]
                denominator_predicate_subarray = self._denominator_predicate[
                    begin_index:end_index]
                # The update ops and the weights are both fed this minibatch.
                feed_dict = {
                    weights_placeholder:
                        weights_subarray,
                    numerator_predicate_placeholder:
                        numerator_predicate_subarray,
                    denominator_predicate_placeholder:
                        denominator_predicate_subarray,
                }

                # The expected denominator is the average of
                # weights * denominator_predicate over every example fed so far.
                running_count += size
                running_sum += np.sum(weights_subarray *
                                      denominator_predicate_subarray)
                average_denominator = running_sum / running_count
                expected_weights = (weights_subarray *
                                    numerator_predicate_subarray)
                expected_weights /= average_denominator

                # Running the update_ops will update the running denominator count/sum
                # calculated by the _RatioWeights object.
                session.run_ops(update_ops_fn, feed_dict=feed_dict)
                # Now we can calculate the weights.
                actual_weights_value = session.run(
                    lambda: actual_weights(memoizer), feed_dict=feed_dict)

                self.assertAllClose(expected_weights,
                                    actual_weights_value,
                                    rtol=0,
                                    atol=1e-6)

                session.run_ops(
                    lambda: memoizer[defaults.GLOBAL_STEP_KEY].assign_add(1))
def split_rate_context(penalty_predictions,
                       constraint_predictions,
                       penalty_labels=None,
                       constraint_labels=None,
                       penalty_weights=1.0,
                       constraint_weights=1.0):
    """Creates a new split context.

    Unlike a normal context, a "split context" carries separate predictions,
    labels, weights and subset for the "penalty" and "constraint" portions of
    the problem. This is an advanced option, and is not needed in most
    circumstances.

    Args:
      penalty_predictions: rank-1 floating-point `Tensor`, for which the ith
        element is the output of the model on the ith training example, for
        the training dataset associated with the penalties.
      constraint_predictions: rank-1 floating-point `Tensor`, for which the ith
        element is the output of the model on the ith training example, for
        the training dataset associated with the constraints.
      penalty_labels: optional rank-1 `Tensor`, for which the ith element is
        the label of the ith training example, for the training dataset
        associated with the penalties.
      constraint_labels: optional rank-1 `Tensor`, for which the ith element is
        the label of the ith training example, for the training dataset
        associated with the constraints.
      penalty_weights: optional rank-1 floating-point `Tensor`, for which the
        ith element is the weight of the ith training example, for the training
        dataset associated with the penalties. If not specified, the weights
        default to being all-one.
      constraint_weights: optional rank-1 floating-point `Tensor`, for which
        the ith element is the weight of the ith training example, for the
        training dataset associated with the constraints. If not specified, the
        weights default to being all-one.

    Returns:
      `SubsettableContext` representing the given predictions, labels and
      weights.

    Raises:
      ValueError: if we're in eager mode, and either penalty_predictions or
        constraint_predictions is not callable, or any of the labels or
        weights is a plain `Tensor`.
      TypeError: if any arguments are internal rate library objects, instead
        of `Tensor`s or scalars.
    """
    # See comment in rate_context.
    named_arguments = (
        ("penalty_predictions", penalty_predictions),
        ("constraint_predictions", constraint_predictions),
        ("penalty_labels", penalty_labels),
        ("constraint_labels", constraint_labels),
        ("penalty_weights", penalty_weights),
        ("constraint_weights", constraint_weights),
    )
    for argument_name, argument_value in named_arguments:
        if isinstance(argument_value, helpers.RateObject):
            raise TypeError(
                argument_name + " parameter to split_rate_context() should be "
                "a Tensor-like object, or a nullary function returning such")

    if tf.executing_eagerly():
        if not all(
                callable(candidate)
                for candidate in (penalty_predictions, constraint_predictions)):
            raise ValueError(
                "in eager mode, the predictions provided to a context "
                "must be a nullary function returning a Tensor (to fix "
                "this, consider wrapping it in a lambda)")
        # Predictions *must* be callable, but labels and weights may also be
        # non-Tensor constants (python scalars, numpy arrays, ...). Plain
        # Tensors are rejected for all of them.
        eager_checks = (
            ("labels", (penalty_labels, constraint_labels)),
            ("weights", (penalty_weights, constraint_weights)),
        )
        for kind, pair in eager_checks:
            if any(tf.is_tensor(value) for value in pair):
                raise ValueError(
                    "in eager mode, the " + kind + " provided to a context "
                    "must either be a constant, or a nullary function "
                    "returning a Tensor: it cannot be a plain Tensor (to fix "
                    "this, consider wrapping it in a lambda)")

    def wrap_optional(value):
        """Wraps "value" in a `DeferredTensor`, passing None through."""
        if value is None:
            return None
        return deferred_tensor.DeferredTensor(value)

    raw_context = _RawContext(
        penalty_predictions=deferred_tensor.DeferredTensor(penalty_predictions),
        penalty_labels=wrap_optional(penalty_labels),
        penalty_weights=deferred_tensor.DeferredTensor(penalty_weights),
        constraint_predictions=deferred_tensor.DeferredTensor(
            constraint_predictions),
        constraint_labels=wrap_optional(constraint_labels),
        constraint_weights=deferred_tensor.DeferredTensor(constraint_weights))
    true_predicate = predicate.Predicate(True)
    return SubsettableContext(raw_context, true_predicate, true_predicate)
def rate_context(predictions, labels=None, weights=1.0):
    """Creates a new context.

    Args:
      predictions: rank-1 floating-point `Tensor`, for which the ith element is
        the output of the model on the ith training example.
      labels: optional rank-1 `Tensor`, for which the ith element is the label
        of the ith training example.
      weights: optional rank-1 floating-point `Tensor`, for which the ith
        element is the weight of the ith training example. If not specified,
        the weights default to being all-one.

    Returns:
      `SubsettableContext` representing the given predictions, labels and
      weights.

    Raises:
      ValueError: if we're in eager mode, and predictions is not callable, or
        labels or weights is a plain `Tensor`.
      TypeError: if any arguments are internal rate library objects, instead
        of `Tensor`s or scalars.
    """
    # Fully checking that these are (convertible to) Tensors would mean
    # actually converting them, which would add a dummy op to the graph and
    # would not work for Tensors of unknown size. So we only reject the types
    # we know for certain are disallowed: objects internal to this library.
    for argument_name, argument_value in (("predictions", predictions),
                                          ("labels", labels),
                                          ("weights", weights)):
        if isinstance(argument_value, helpers.RateObject):
            raise TypeError(
                argument_name + " parameter to rate_context() should be a "
                "Tensor-like object, or a nullary function returning such")

    if tf.executing_eagerly():
        if not callable(predictions):
            raise ValueError(
                "in eager mode, the predictions provided to a context "
                "must be a nullary function returning a Tensor (to fix "
                "this, consider wrapping it in a lambda)")
        # Predictions *must* be callable, but labels and weights may also be
        # non-Tensor constants (python scalars, numpy arrays, ...). Plain
        # Tensors are rejected for all of them.
        for kind, value in (("labels", labels), ("weights", weights)):
            if tf.is_tensor(value):
                raise ValueError(
                    "in eager mode, the " + kind + " provided to a context "
                    "must either be a constant, or a nullary function "
                    "returning a Tensor: it cannot be a plain Tensor (to fix "
                    "this, consider wrapping it in a lambda)")

    wrapped_predictions = deferred_tensor.DeferredTensor(predictions)
    wrapped_labels = (None if labels is None else
                      deferred_tensor.DeferredTensor(labels))
    wrapped_weights = deferred_tensor.DeferredTensor(weights)

    # A non-split context: the penalty and constraint portions share the very
    # same objects.
    raw_context = _RawContext(penalty_predictions=wrapped_predictions,
                              penalty_labels=wrapped_labels,
                              penalty_weights=wrapped_weights,
                              constraint_predictions=wrapped_predictions,
                              constraint_labels=wrapped_labels,
                              constraint_weights=wrapped_weights)
    true_predicate = predicate.Predicate(True)
    return SubsettableContext(raw_context, true_predicate, true_predicate)
    def subset(self, penalty_predicate, constraint_predicate=None):
        """Returns a subset of this context.

        The two predicates should be boolean `Tensor`s of the same size as the
        predictions `Tensor` from which the top-level context was constructed.
        If an element of the predicate `Tensor` is True, and the corresponding
        example is included in this context, then the example will be included
        in the resulting context. Otherwise, it will not.

        A "split context" contains two sets of predictions (and optionally
        labels and weights). When subsetting a split context, two predicates
        must be provided to this method: the first for the penalty portion, and
        the second for the constraint portion. Alternatively, if you want to
        create a split context from a non-split one, then you can do so by
        providing both predicate arguments explicitly.

        This method is here for convenience, but it comes at a cost. You should
        use subsetting *with great caution*. If, for example, you wish to
        create a rate only on the set of "blue" examples, then it will almost
        always be better (but more complicated) to create an entirely separate
        dataset containing only "blue" examples (e.g. using the "filter" method
        of a `tf.data.Dataset`), rather than taking the "blue" subset of a
        dataset that also contains "red" and "green" examples.

        The reason for this is that, if using subsetting, each minibatch will
        contain varying numbers of "blue" examples during training. As a
        consequence, we'll sometimes perform too-small updates, and sometimes
        overcorrect with extremely large updates. This problem is less serious
        if "blue" examples are common, but can be fatal if "blue" examples are
        extremely rare.

        If, instead of subsetting, we were to create an entirely separate "blue"
        dataset, then every minibatch would contain the same number of "blue"
        examples, and optimization would proceed more smoothly.

        Args:
          penalty_predicate: boolean `Tensor` with the same size as the
            underlying predictions `Tensor` (or broadcastable to it), each
            element of which indicates whether the corresponding example should
            be included in the subset.
          constraint_predicate: optional boolean `Tensor`, playing the same
            role as "penalty_predicate", but for the constraints portion of the
            context.

        Returns:
          `SubsettableContext` representing the subset of this context on which
          penalty_predicate (and constraint_predicate, if applicable) are True.

        Raises:
          ValueError: if no constraint_predicate is provided, but this is a
            split context, or if we're in eager mode and either predicate is a
            plain `Tensor` (rather than a constant or a nullary function
            returning a `Tensor`).
        """
        if constraint_predicate is None:
            # It's fine if the labels and/or weights are different.
            if (self._raw_context.penalty_predictions !=
                    self._raw_context.constraint_predictions
                    or self._penalty_predicate != self._constraint_predicate):
                raise ValueError("constraint_predicate must be provided when "
                                 "subsetting a split context")
            constraint_predicate = penalty_predicate

        # In eager mode, we do not permit ordinary constant Tensors to be passed
        # as predicates: only lambdas. Additionally, we allow non-Tensor
        # constants (e.g. python scalars or numpy arrays) and DeferredTensors
        # (for internal use).
        #
        # The reason for this is that, in eager mode, we want to prevent users
        # from passing a constant when they intend to use a variable. For
        # example:
        #   context.subset(some_variable > 0.5)
        # This is probably a bug, since if, later in the program, the value of
        # "some_variable" changes, the value of the earlier evaluation of
        # "some_variable > 0.5" will not. To prevent this, we have checks that
        # force the user to use something like:
        #   context.subset(lambda: some_variable > 0.5)
        # which will work fine even if "some_variable" subsequently changes.
        if tf.executing_eagerly() and (tf.is_tensor(penalty_predicate)
                                       or tf.is_tensor(constraint_predicate)):
            # Note the trailing space after "a": without it the message would
            # read "cannot be aplain Tensor".
            raise ValueError(
                "in eager mode, the predicate provided to a context's "
                "subset() method must either be a constant, or a "
                "nullary function returning a Tensor: it cannot be a "
                "plain Tensor (to fix this, consider wrapping it in a "
                "lambda)")

        # First convert the predicate Tensors into DeferredTensors, so that we
        # can use the __eq__ operator.
        if not isinstance(penalty_predicate, deferred_tensor.DeferredTensor):
            penalty_predicate = deferred_tensor.DeferredTensor(
                penalty_predicate)
        if not isinstance(constraint_predicate,
                          deferred_tensor.DeferredTensor):
            constraint_predicate = deferred_tensor.DeferredTensor(
                constraint_predicate)

        # Convert the boolean predicates to Predicate objects. Make sure that we
        # don't change from a non-split context (both predicates are the same
        # object) to a split context (the predicates are different objects)
        # unless it's necessary.
        stays_unsplit = (self._penalty_predicate == self._constraint_predicate
                         and penalty_predicate == constraint_predicate)
        combined_penalty_predicate = (
            self._penalty_predicate & predicate.Predicate(penalty_predicate))
        if stays_unsplit:
            combined_constraint_predicate = combined_penalty_predicate
        else:
            combined_constraint_predicate = (
                self._constraint_predicate
                & predicate.Predicate(constraint_predicate))

        return SubsettableContext(
            raw_context=self._raw_context,
            penalty_predicate=combined_penalty_predicate,
            constraint_predicate=combined_constraint_predicate)