def maybe_create_label_priors(label_priors, labels, weights,
                              variables_collections):
  """Returns label priors, creating moving-average tracking ops only if needed.

  When the caller supplies `label_priors`, they are cast to the base dtype of
  `labels` and squeezed. Otherwise the priors are built from `labels` and
  `weights` via `util.build_label_priors`.

  Args:
    label_priors: As required in e.g. precision_recall_auc_loss, or None to
      have the priors estimated from `labels`.
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    weights: As required in e.g. precision_recall_auc_loss.
    variables_collections: Optional list of collections for the variables, if
      any must be created.

  Returns:
    label_priors: A Tensor of shape [num_labels] consisting of the weighted
      label priors, after updating with moving average ops if created.
  """
  if label_priors is None:
    # No priors were given, so build (and track) them from the data.
    return util.build_label_priors(
        labels, weights, variables_collections=variables_collections)

  # Match the priors' dtype to the labels so later arithmetic type-checks.
  target_dtype = labels.dtype.base_dtype
  converted_priors = util.convert_and_cast(
      label_priors, name='label_priors', dtype=target_dtype)
  # NOTE: tf.squeeze with no axis drops every size-1 dimension, so a
  # length-1 prior comes back as a scalar rather than shape [1].
  return tf.squeeze(converted_priors)
def maybe_create_label_priors(label_priors, labels, weights,
                              variables_collections):
  """Uses the given label priors, or builds tracked priors when none are given.

  Args:
    label_priors: Priors supplied by the caller, or None. When not None they
      are converted/cast to the base dtype of `labels` and squeezed.
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    weights: Per-example weights forwarded to `util.build_label_priors`.
    variables_collections: Optional list of collections for any variables
      created by `util.build_label_priors`.

  Returns:
    The squeezed, dtype-matched `label_priors` if they were supplied,
    otherwise the result of `util.build_label_priors`.
  """
  if label_priors is None:
    return util.build_label_priors(
        labels, weights, variables_collections=variables_collections)
  # tf.squeeze (no axis) removes all size-1 dimensions of the cast priors.
  return tf.squeeze(
      util.convert_and_cast(
          label_priors,
          name='label_priors',
          dtype=labels.dtype.base_dtype))
def testLabelPriorConsistency(self):
  """Zero pseudocounts should make the priors equal the batch label means."""
  # With no pseudocounts smoothing the estimate, a single update of the
  # priors must reproduce the per-label frequency observed in this batch.
  shape = [4, 10]
  random_draw = tf.random_uniform(shape)
  labels = tf.Variable(tf.to_float(tf.greater(random_draw, 0.678)))
  priors = util.build_label_priors(
      labels=labels,
      positive_pseudocount=0,
      negative_pseudocount=0)
  # Column means of the 0/1 label matrix are the empirical label frequencies.
  batch_frequencies = tf.reduce_mean(labels, 0)
  with self.test_session():
    tf.global_variables_initializer().run()
    self.assertAllClose(priors.eval(), batch_frequencies.eval())
def testLabelPriorsUpdateWithWeights(self):
  """Priors built with per-example weights match a hand-computed posterior."""
  num_examples = 6
  num_labels = 5
  shape = [num_examples, num_labels]
  labels = tf.Variable(
      tf.to_float(tf.greater(tf.random_uniform(shape), 0.6)))
  weights = tf.Variable(tf.random_uniform(shape) * 6.2)
  priors_op = util.build_label_priors(labels, weights=weights)

  # Expected posterior per label: (1.0 + weighted positives) divided by
  # (2.0 + total weight). The 1.0 / 2.0 offsets presumably mirror the
  # default pseudocounts of build_label_priors -- confirm against util.
  positive_mass = 1.0 + tf.reduce_sum(weights * labels, 0)
  total_mass = 2.0 + tf.reduce_sum(weights, 0)
  expected_op = tf.divide(positive_mass, total_mass)

  with self.test_session() as sess:
    tf.global_variables_initializer().run()
    # Fetch both in one run so they see the same variable values.
    actual, expected = sess.run([priors_op, expected_op])
    self.assertAllClose(actual, expected)
def testLabelPriorsUpdate(self):
  """Repeated updates of the priors track the running label posterior."""
  shape = [1, 5]
  labels = tf.Variable(
      tf.to_float(tf.greater(tf.random_uniform(shape), 0.4)))
  priors_op = util.build_label_priors(labels)

  # Running totals seeded with 1 (positives) and 2 (observations), matching
  # the smoothing the expected posterior is computed with.
  positive_total = np.ones(shape=shape)
  count_total = 2.0 * np.ones(shape=shape)
  with self.test_session() as sess:
    tf.global_variables_initializer().run()
    for _step in range(3):
      positive_total += labels.eval()
      count_total += 1.0
      expected = positive_total / count_total
      observed = priors_op.eval().reshape(shape)
      self.assertAllClose(observed, expected)
      # Re-run the initializer so the next step sees a fresh random sample.
      sess.run(labels.initializer)
def maybe_create_label_priors(label_priors, labels, weights,
                              variables_collections):
  """Returns label priors, creating moving-average tracking ops only if needed.

  When the caller supplies `label_priors`, they are cast to the base dtype of
  `labels` and squeezed. Otherwise the priors are built from `labels` and
  `weights` via `util.build_label_priors`.

  Args:
    label_priors: As required in e.g. precision_recall_auc_loss, or None to
      have the priors estimated from `labels`.
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    weights: As required in e.g. precision_recall_auc_loss.
    variables_collections: Optional list of collections for the variables, if
      any must be created.

  Returns:
    label_priors: A Tensor of shape [num_labels] consisting of the weighted
      label priors, after updating with moving average ops if created.
  """
  if label_priors is None:
    built_priors = util.build_label_priors(
        labels, weights, variables_collections=variables_collections)
    return built_priors

  # Caller-provided priors: align dtype with the labels, then drop every
  # size-1 dimension (tf.squeeze with no axis argument).
  cast_priors = util.convert_and_cast(
      label_priors, name='label_priors', dtype=labels.dtype.base_dtype)
  return tf.squeeze(cast_priors)