コード例 #1
0
 def context(self):
     """Returns the split rate context narrowed to the configured subsets.

     Each stored array is turned into a float32 constant tensor. The six
     tensors are combined by `subsettable_context.split_rate_context`, and the
     result is restricted with `subset(self._penalty_predicate,
     self._constraint_predicate)`.

     The tensors are built on every call instead of in __init__, because
     creating them in __init__ would place them in the wrong TensorFlow graph.
     """

     def as_float32(value):
         return tf.constant(value, dtype=tf.float32)

     penalty_predictions = as_float32(self._penalty_predictions)
     constraint_predictions = as_float32(self._constraint_predictions)
     penalty_labels = as_float32(self._penalty_labels)
     constraint_labels = as_float32(self._constraint_labels)
     penalty_weights = as_float32(self._penalty_weights)
     constraint_weights = as_float32(self._constraint_weights)

     full_context = subsettable_context.split_rate_context(
         penalty_predictions=penalty_predictions,
         constraint_predictions=constraint_predictions,
         penalty_labels=penalty_labels,
         constraint_labels=constraint_labels,
         penalty_weights=penalty_weights,
         constraint_weights=constraint_weights)
     subset_context = full_context.subset(self._penalty_predicate,
                                          self._constraint_predicate)
     return subset_context
コード例 #2
0
 def _split_context(self):
   """Builds a split rate context and subsets it by the stored predicates.

   Each stored array becomes a float32 constant. The context is then given,
   for every input, a zero-argument callable that returns that constant.
   The constants are created here rather than in __init__, where they would
   end up in the wrong TensorFlow graph.

   Returns:
     The split rate context after `subset(self._penalty_predicate,
     self._constraint_predicate)` has been applied to it.
   """

   def make_constant_fn(value):
     # The tensor is created immediately, at this call. The returned callable
     # only hands back that same tensor object each time it is invoked.
     tensor = tf.constant(value, dtype=tf.float32)
     return lambda: tensor

   penalty_predictions_fn = make_constant_fn(self._penalty_predictions)
   constraint_predictions_fn = make_constant_fn(self._constraint_predictions)
   penalty_labels_fn = make_constant_fn(self._penalty_labels)
   constraint_labels_fn = make_constant_fn(self._constraint_labels)
   penalty_weights_fn = make_constant_fn(self._penalty_weights)
   constraint_weights_fn = make_constant_fn(self._constraint_weights)

   split_context = subsettable_context.split_rate_context(
       penalty_predictions=penalty_predictions_fn,
       constraint_predictions=constraint_predictions_fn,
       penalty_labels=penalty_labels_fn,
       constraint_labels=constraint_labels_fn,
       penalty_weights=penalty_weights_fn,
       constraint_weights=constraint_weights_fn)
   return split_context.subset(self._penalty_predicate,
                               self._constraint_predicate)