def testPairedJoint(self):
    """Check paired_joint_logits_and_labels on interleaved (positive, negative) pairs.

    The inputs are laid out as [p0, n0, p1, n1, ...] with labels
    [True, False, True, False, ...]. The test expects one joint logit per pair
    equal to ``positive - negative``, matching joint/label shapes, and an
    all-True joint label vector.
    """
    num_pairs = 7
    np.random.seed(num_pairs)
    pos_scores = np.random.randn(num_pairs)
    neg_scores = np.random.randn(num_pairs)
    # Interleave the two score vectors so each positive is followed by its negative.
    interleaved = np.column_stack([pos_scores, neg_scores]).ravel()
    pair_labels = np.tile([True, False], num_pairs)
    with self.test_session() as sess:
        joint_ops = model_utils.paired_joint_logits_and_labels(
            interleaved, pair_labels)
        joint_logits, joint_labels = sess.run(joint_ops)
        self.assertEqual(joint_logits.shape, joint_labels.shape)
        self.assertAllClose(joint_logits, pos_scores - neg_scores)
        self.assertAllEqual(joint_labels, [True] * num_pairs)
def make_metrics():
    """Build the model and return streaming accuracy metrics for evaluation.

    Predictions are obtained by thresholding logits at zero. A scalar summary
    is emitted for element 0 of each ``metrics.streaming_accuracy`` result.

    Returns:
      A list of ``metrics.streaming_accuracy`` results. The first is for the
      per-example predictions. When ``joint_safe`` is set, a second one
      follows for the paired joint logits/labels produced by
      ``model_utils.paired_joint_logits_and_labels``.
    """
    logits, labels = make_model()
    # (scores, targets, summary tag) for every accuracy to track, in order.
    accuracy_specs = [(logits, labels, 'eval accuracy')]
    if joint_safe:
        joint_logits, joint_labels = model_utils.paired_joint_logits_and_labels(
            logits, labels)
        accuracy_specs.append(
            (joint_logits, joint_labels, 'eval joint accuracy'))

    metric_ops = []
    for scores, targets, tag in accuracy_specs:
        result = metrics.streaming_accuracy(scores > 0, targets)
        metric_ops.append(result)
        tf.contrib.deprecated.scalar_summary(tag, result[0])
    return metric_ops
def make_loss():
    """Build the model, register its training loss, and return the total loss.

    ``hparams.loss`` selects which logits feed a sigmoid cross-entropy with
    ``hparams.label_smoothing``:
      * 'xent'  -- the raw per-example logits and labels;
      * 'joint' -- the paired joint logits/labels from
        ``model_utils.paired_joint_logits_and_labels`` (requires ``joint_safe``).
    Scalar summaries are emitted for the total loss and for the training
    accuracy of zero-thresholded logits (plus the joint accuracy when
    ``joint_safe`` is set).

    Returns:
      The tensor returned by ``slim.losses.get_total_loss()``.

    Raises:
      ValueError: if ``hparams.loss`` is 'joint' while ``joint_safe`` is false,
        or if ``hparams.loss`` is not one of the recognised names.
    """
    logits, labels = make_model()
    if joint_safe:
        joint_logits, joint_labels = model_utils.paired_joint_logits_and_labels(
            logits, labels)

    # Choose the (logits, labels) pair that the cross-entropy is applied to.
    loss_name = hparams.loss
    if loss_name == 'xent':
        loss_logits, loss_labels = logits, labels
    elif loss_name == 'joint':
        if not joint_safe:
            raise ValueError('joint loss needs joint_safe=True')
        loss_logits, loss_labels = joint_logits, joint_labels
    else:
        raise ValueError('Unknown loss %r' % loss_name)
    # [:, None] appends a unit axis before handing the tensors to the loss.
    slim.losses.sigmoid_cross_entropy(
        loss_logits[:, None], loss_labels[:, None],
        label_smoothing=hparams.label_smoothing)

    total_loss = slim.losses.get_total_loss()
    tf.contrib.deprecated.scalar_summary('total loss', total_loss)

    # Discrete accuracy of zero-thresholded logits on the training batch.
    accuracy_specs = [('train accuracy', logits, labels)]
    if joint_safe:
        accuracy_specs.append(
            ('train joint accuracy', joint_logits, joint_labels))
    for tag, scores, targets in accuracy_specs:
        tf.contrib.deprecated.scalar_summary(
            tag, metrics.accuracy(scores > 0, targets))

    return total_loss