def context(self):
    """Builds a split rate context and restricts it to the configured subsets.

    Every stored penalty/constraint array (predictions, labels, weights) is
    wrapped in a float32 `tf.constant`. The constants are passed directly as
    tensors to `subsettable_context.split_rate_context`.

    Returns:
      The result of calling `subset(self._penalty_predicate,
      self._constraint_predicate)` on the newly built context.
    """
    # Built on demand rather than in __init__: constructing the tensors there
    # would place them in the wrong TensorFlow graph.
    def as_float32(value):
        return tf.constant(value, dtype=tf.float32)

    # Keyword arguments are evaluated left to right, so the constants are
    # created in the same order as before.
    split_context = subsettable_context.split_rate_context(
        penalty_predictions=as_float32(self._penalty_predictions),
        constraint_predictions=as_float32(self._constraint_predictions),
        penalty_labels=as_float32(self._penalty_labels),
        constraint_labels=as_float32(self._constraint_labels),
        penalty_weights=as_float32(self._penalty_weights),
        constraint_weights=as_float32(self._constraint_weights),
    )
    return split_context.subset(
        self._penalty_predicate, self._constraint_predicate)
def _split_context(self):
    """Creates a new split and subsetted context.

    Each stored penalty/constraint array (predictions, labels, weights) becomes
    a float32 `tf.constant`, created immediately when this method is called.
    The tensors are passed to `subsettable_context.split_rate_context`, each
    wrapped in a zero-argument function that returns the tensor.

    Returns:
      The result of calling `subset(self._penalty_predicate,
      self._constraint_predicate)` on the newly built context.
    """
    # Built on demand rather than in __init__: constructing the tensors there
    # would place them in the wrong TensorFlow graph.
    def constant_getter(value):
        # Create the tensor now; the returned closure just hands it back.
        tensor = tf.constant(value, dtype=tf.float32)
        return lambda: tensor

    # Keyword arguments are evaluated left to right, so the six constants are
    # created in the same order as the original sequential assignments.
    split_context = subsettable_context.split_rate_context(
        penalty_predictions=constant_getter(self._penalty_predictions),
        constraint_predictions=constant_getter(self._constraint_predictions),
        penalty_labels=constant_getter(self._penalty_labels),
        constraint_labels=constant_getter(self._constraint_labels),
        penalty_weights=constant_getter(self._penalty_weights),
        constraint_weights=constant_getter(self._constraint_weights),
    )
    return split_context.subset(
        self._penalty_predicate, self._constraint_predicate)