Ejemplo n.º 1
0
 def testPairedJoint(self):
     """paired_joint_logits_and_labels turns interleaved pairs into differences.

     The input logits are laid out as [pos_0, neg_0, pos_1, neg_1, ...] and
     the labels alternate [True, False, ...]. The joint output is expected to
     hold pos_i - neg_i for every pair, with every joint label True and the
     logits and labels sharing one shape.
     """
     num_pairs = 7
     np.random.seed(num_pairs)
     pos = np.random.randn(num_pairs)
     neg = np.random.randn(num_pairs)
     # Stack as (num_pairs, 2) columns, then flatten row-major so each
     # positive is immediately followed by its matching negative.
     interleaved = np.stack([pos, neg], axis=1).ravel()
     alternating = np.tile([True, False], num_pairs)
     with self.test_session() as session:
         pair_op = model_utils.paired_joint_logits_and_labels(
             interleaved, alternating)
         got_logits, got_labels = session.run(pair_op)
         self.assertEqual(got_logits.shape, got_labels.shape)
         self.assertAllClose(got_logits, pos - neg)
         self.assertAllEqual(got_labels, [True] * num_pairs)
Ejemplo n.º 2
0
  def make_metrics():
    """Build the model and return its streaming accuracy metric ops.

    Reads `make_model`, `joint_safe`, `model_utils`, `metrics` and `tf` from
    the enclosing scope. It always measures the accuracy of `logits > 0`
    against `labels`. When `joint_safe` is true, it also measures the accuracy
    of the joint logits and labels produced by
    `model_utils.paired_joint_logits_and_labels`.

    Returns:
      A list with one `metrics.streaming_accuracy` result per measurement:
      the plain accuracy first, then the joint accuracy if enabled. Element
      [0] of each result is what gets summarized. Presumably this is the
      accuracy value of a (value, update_op) pair; confirm against the
      `metrics` module in use.
    """
    logits, labels = make_model()
    # (summary tag, logits, labels) for each accuracy to measure, in order.
    accuracy_specs = [('eval accuracy', logits, labels)]
    if joint_safe:
      joint_logits, joint_labels = model_utils.paired_joint_logits_and_labels(
          logits, labels)
      accuracy_specs.append(('eval joint accuracy', joint_logits, joint_labels))

    metric_ops = []
    for tag, scores, targets in accuracy_specs:
      # A strictly positive logit is read as a True prediction.
      result = metrics.streaming_accuracy(scores > 0, targets)
      metric_ops.append(result)
      tf.contrib.deprecated.scalar_summary(tag, result[0])
    return metric_ops
Ejemplo n.º 3
0
    def make_loss():
        """Build the model, register the training loss and return the total loss.

        Reads `make_model`, `joint_safe`, `hparams`, `model_utils`, `slim`,
        `metrics` and `tf` from the enclosing scope.

        `hparams.loss` selects the inputs to the sigmoid cross-entropy:
        'xent' uses the raw logits and labels from `make_model`. 'joint' uses
        the paired joint logits and labels from
        `model_utils.paired_joint_logits_and_labels`, which are only built when
        `joint_safe` is true. The loss is registered through
        `slim.losses.sigmoid_cross_entropy` with
        `label_smoothing=hparams.label_smoothing`. Summaries are added for
        the total loss, the raw accuracy and, when `joint_safe`, the joint
        accuracy.

        Returns:
            The value of `slim.losses.get_total_loss()`.

        Raises:
            ValueError: if `hparams.loss` is 'joint' while `joint_safe` is
                false, or if `hparams.loss` is neither 'xent' nor 'joint'.
        """
        logits, labels = make_model()
        if joint_safe:
            joint_logits, joint_labels = (
                model_utils.paired_joint_logits_and_labels(logits, labels))

        # Resolve which (logits, labels) pair the loss is computed on.
        if hparams.loss == 'xent':
            loss_logits, loss_labels = logits, labels
        elif hparams.loss == 'joint':
            if not joint_safe:
                raise ValueError('joint loss needs joint_safe=True')
            loss_logits, loss_labels = joint_logits, joint_labels
        else:
            raise ValueError('Unknown loss %r' % hparams.loss)
        # [:, None] inserts a new axis after the first one, e.g. (n,) -> (n, 1).
        slim.losses.sigmoid_cross_entropy(
            loss_logits[:, None],
            loss_labels[:, None],
            label_smoothing=hparams.label_smoothing)
        total_loss = slim.losses.get_total_loss()
        tf.contrib.deprecated.scalar_summary('total loss', total_loss)

        # Discrete accuracy summaries: raw predictions first, then joint ones
        # if they were built. A strictly positive logit counts as True.
        accuracy_specs = [('train accuracy', logits, labels)]
        if joint_safe:
            accuracy_specs.append(
                ('train joint accuracy', joint_logits, joint_labels))
        for tag, scores, targets in accuracy_specs:
            tf.contrib.deprecated.scalar_summary(
                tag, metrics.accuracy(scores > 0, targets))

        return total_loss