def metrics(self, regularization_losses=None):
    """Creates metrics. See `base_head.Head` for details."""
    keys = metric_keys.MetricKeys
    with ops.name_scope('metrics', values=(regularization_losses,)):
      # Mean metric.
      eval_metrics = {}
      eval_metrics[self._loss_mean_key] = metrics.Mean(name=keys.LOSS_MEAN)
      eval_metrics[self._prediction_mean_key] = metrics.Mean(
          name=keys.PREDICTION_MEAN)
      eval_metrics[self._label_mean_key] = metrics.Mean(name=keys.LABEL_MEAN)
      if regularization_losses is not None:
        eval_metrics[self._loss_regularization_key] = metrics.Mean(
            name=keys.LOSS_REGULARIZATION)
      if self._model_hparams.da_tlen > 0 and not self._model_hparams.event_relation:
        eval_metrics[self._auc_roc_24] = metrics.AUC(name=AUC_ROC % '24')
        eval_metrics[self._auc_roc_48] = metrics.AUC(name=AUC_ROC % '48')
        eval_metrics[self._auc_pr] = metrics.AUC(
            curve='PR', name=AUC_PR % 'avg')
        eval_metrics[self._auc_roc] = metrics.AUC(name=AUC_ROC % 'avg')
        eval_metrics[self._mean_abs_error] = metrics.MeanAbsoluteError(
            name=MEAN_ABS_ERROR % 'avg')

        for i in range(int(self._model_hparams.da_tlen / SLOT_TO_WINDOW) + 1):
          eval_metrics[self._probablity_within_window_list[i]] = metrics.Mean(
              name=PROBABILITY_AT_WINDOW % i)

    return eval_metrics
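These `metrics()` implementations share one pattern: every Keras metric object is constructed once inside the `metrics` name scope, then repeatedly updated during evaluation. Below is a minimal standalone sketch of that pattern using the public `tf.keras.metrics` API; the dictionary keys are illustrative stand-ins for the head's internal `self._*_key` attributes, not names from the original code.

# Construct-once / update-many pattern behind the metrics() methods shown here.
# tf.keras.metrics stands in for the internal `metrics` module.
import tensorflow as tf

eval_metrics = {
    'loss_mean': tf.keras.metrics.Mean(name='average_loss'),
    'auc': tf.keras.metrics.AUC(name='auc'),
    'auc_pr': tf.keras.metrics.AUC(curve='PR', name='auc_precision_recall'),
}

# Each evaluation batch folds new observations into the metric state...
eval_metrics['auc'].update_state([0, 1, 1], [0.2, 0.7, 0.9])
eval_metrics['loss_mean'].update_state(0.35)

# ...and result() reduces the accumulated state to a scalar at any time.
for key, metric in eval_metrics.items():
    print(key, float(metric.result()))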
Example no. 2
 def metrics(self, regularization_losses=None):
     """Creates metrics. See `base_head.Head` for details."""
     keys = metric_keys.MetricKeys
     with ops.name_scope(None, 'metrics', (regularization_losses, )):
         # Mean metric.
         eval_metrics = {}
         eval_metrics[self._loss_mean_key] = metrics.Mean(
             name=keys.LOSS_MEAN)
         # The default summation_method is "interpolation" in the AUC metric.
         eval_metrics[self._auc_key] = metrics.AUC(name=keys.AUC)
         eval_metrics[self._auc_pr_key] = metrics.AUC(curve='PR',
                                                      name=keys.AUC_PR)
         if regularization_losses is not None:
             eval_metrics[self._loss_regularization_key] = metrics.Mean(
                 name=keys.LOSS_REGULARIZATION)
         for i, threshold in enumerate(self._thresholds):
             eval_metrics[self._accuracy_keys[i]] = metrics.BinaryAccuracy(
                 name=self._accuracy_keys[i], threshold=threshold)
             eval_metrics[self._precision_keys[i]] = (metrics.Precision(
                 name=self._precision_keys[i], thresholds=threshold))
             eval_metrics[self._recall_keys[i]] = metrics.Recall(
                 name=self._recall_keys[i], thresholds=threshold)
         for i in range(len(self._classes_for_class_based_metrics)):
             eval_metrics[self._prob_keys[i]] = metrics.Mean(
                 name=self._prob_keys[i])
             eval_metrics[self._auc_keys[i]] = metrics.AUC(
                 name=self._auc_keys[i])
             eval_metrics[self._auc_pr_keys[i]] = metrics.AUC(
                 curve='PR', name=self._auc_pr_keys[i])
     return eval_metrics
Example no. 3
 def metrics(self, regularization_losses=None):
   """Creates metrics. See `base_head.Head` for details."""
   keys = metric_keys.MetricKeys
   with ops.name_scope('metrics', values=(regularization_losses,)):
     # Mean metric.
     eval_metrics = {}
     eval_metrics[self._loss_mean_key] = metrics.Mean(name=keys.LOSS_MEAN)
     eval_metrics[self._accuracy_key] = metrics.Accuracy(name=keys.ACCURACY)
     eval_metrics[self._precision_key] = metrics.Precision(name=keys.PRECISION)
     eval_metrics[self._recall_key] = metrics.Recall(name=keys.RECALL)
     eval_metrics[self._prediction_mean_key] = metrics.Mean(
         name=keys.PREDICTION_MEAN)
     eval_metrics[self._label_mean_key] = metrics.Mean(name=keys.LABEL_MEAN)
     eval_metrics[self._accuracy_baseline_key] = (
         metrics.Mean(name=keys.ACCURACY_BASELINE))
     # The default summation_method is "interpolation" in the AUC metric.
     eval_metrics[self._auc_key] = metrics.AUC(name=keys.AUC)
     eval_metrics[self._auc_pr_key] = metrics.AUC(curve='PR', name=keys.AUC_PR)
     if regularization_losses is not None:
       eval_metrics[self._loss_regularization_key] = metrics.Mean(
           name=keys.LOSS_REGULARIZATION)
     for i, threshold in enumerate(self._thresholds):
       eval_metrics[self._accuracy_keys[i]] = metrics.BinaryAccuracy(
           name=self._accuracy_keys[i], threshold=threshold)
       eval_metrics[self._precision_keys[i]] = metrics.Precision(
           name=self._precision_keys[i], thresholds=threshold)
       eval_metrics[self._recall_keys[i]] = metrics.Recall(
           name=self._recall_keys[i], thresholds=threshold)
   return eval_metrics
Example no. 4
  def test_unweighted(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true, self.y_pred)

    # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
    # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
    # fp_rate = [2/2, 0, 0] = [1, 0, 0]
    # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
    # widths = [(1 - 0), (0 - 0)] = [1, 0]
    expected_result = (0.75 * 1 + 0.25 * 0)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
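The trapezoid arithmetic spelled out in the comments can be replayed with plain NumPy. The tp/fp/fn/tn vectors below are copied from the comment (they come from the test fixture's labels and predictions), not read out of `metrics.AUC` itself.

# NumPy replay of the interpolated (trapezoidal) ROC AUC computed above.
import numpy as np

tp = np.array([2., 1., 0.])
fp = np.array([2., 0., 0.])
fn = np.array([0., 1., 2.])
tn = np.array([0., 2., 2.])

recall = tp / (tp + fn)                    # y axis: [1.0, 0.5, 0.0]
fp_rate = fp / (fp + tn)                   # x axis: [1.0, 0.0, 0.0]

heights = (recall[:-1] + recall[1:]) / 2   # trapezoid midpoints: [0.75, 0.25]
widths = fp_rate[:-1] - fp_rate[1:]        # x-axis spans: [1.0, 0.0]
print(np.sum(heights * widths))            # 0.75, matching expected_result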
Example no. 5
  def test_weighted_roc_interpolation(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)

    # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
    # fp_rate = [3/3, 0, 0] = [1, 0, 0]
    # heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855]
    # widths = [(1 - 0), (0 - 0)] = [1, 0]
    expected_result = (0.7855 * 1 + 0.2855 * 0)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
Example no. 6
    def test_reset_states_auc(self):
        auc_obj = metrics.AUC(num_thresholds=3)
        model = _get_model([auc_obj])
        x = np.concatenate((np.ones((25, 4)), np.zeros(
            (25, 4)), np.zeros((25, 4)), np.ones((25, 4))))
        y = np.concatenate((np.ones((25, 1)), np.zeros(
            (25, 1)), np.ones((25, 1)), np.zeros((25, 1))))

        for _ in range(2):
            model.evaluate(x, y)
            self.assertEqual(self.evaluate(auc_obj.true_positives[1]), 25.)
            self.assertEqual(self.evaluate(auc_obj.false_positives[1]), 25.)
            self.assertEqual(self.evaluate(auc_obj.false_negatives[1]), 25.)
            self.assertEqual(self.evaluate(auc_obj.true_negatives[1]), 25.)
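The counts stay at 25 on the second pass because `model.evaluate()` resets metric state before accumulating again; without the reset they would double. Here is a standalone sketch of that reset, calling it by hand (the method is `reset_state()` in current TF and `reset_states()` in older releases):

# Manual version of the reset that model.evaluate() performs implicitly.
import tensorflow as tf

auc = tf.keras.metrics.AUC(num_thresholds=3)
auc.update_state([0, 0, 1, 1], [0.1, 0.4, 0.6, 0.9])
print(float(auc.result()))  # 1.0 -- these toy scores separate the classes

auc.reset_state()           # zero the tp/fp/fn/tn accumulator variables
print(float(auc.result()))  # 0.0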
Example no. 7
  def test_value_is_idempotent(self):
    self.setup()
    auc_obj = metrics.AUC(num_thresholds=3)
    self.evaluate(variables.variables_initializer(auc_obj.variables))

    # Run several updates.
    update_op = auc_obj.update_state(self.y_true, self.y_pred)
    for _ in range(10):
      self.evaluate(update_op)

    # Then verify idempotency.
    initial_auc = self.evaluate(auc_obj.result())
    for _ in range(10):
      self.assertAllClose(initial_auc, self.evaluate(auc_obj.result()), 1e-3)
Example no. 8
  def test_weighted_roc_minoring(self):
    self.setup()
    auc_obj = metrics.AUC(
        num_thresholds=self.num_thresholds,
        summation_method=metrics_utils.AUCSummationMethod.MINORING)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)

    # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
    # fp_rate = [3/3, 0, 0] = [1, 0, 0]
    # heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0]
    # widths = [(1 - 0), (0 - 0)] = [1, 0]
    expected_result = (0.571 * 1 + 0 * 0)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
Example no. 9
  def test_weighted_pr_majoring(self):
    self.setup()
    auc_obj = metrics.AUC(
        num_thresholds=self.num_thresholds,
        curve=metrics_utils.AUCCurve.PR,
        summation_method=metrics_utils.AUCSummationMethod.MAJORING)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)

    # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
    # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
    # heights = [max(0.7, 1), max(1, 0)] = [1, 1]
    # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
    expected_result = (1 * 0.429 + 1 * 0.571)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
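Minoring and majoring replace the trapezoid midpoint with the interval's minimum or maximum height, giving lower and upper bounds on the area. Both computations from the two tests above, replayed in NumPy with the counts copied from their comments:

import numpy as np

# ROC minoring (Example no. 8): lower-bound heights from recall vs fp_rate.
recall = np.array([1., 0.571, 0.])
fp_rate = np.array([1., 0., 0.])
roc_widths = fp_rate[:-1] - fp_rate[1:]               # [1.0, 0.0]
roc_minor = np.minimum(recall[:-1], recall[1:])       # [0.571, 0.0]
print(np.sum(roc_minor * roc_widths))                 # 0.571

# PR majoring (Example no. 9): upper-bound heights from precision vs recall.
precision = np.array([0.7, 1., 0.])
pr_widths = recall[:-1] - recall[1:]                  # [0.429, 0.571]
pr_major = np.maximum(precision[:-1], precision[1:])  # [1.0, 1.0]
print(np.sum(pr_major * pr_widths))                   # 1.0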
Example no. 10
  def test_config(self):
    auc_obj = metrics.AUC(
        num_thresholds=100,
        curve=metrics_utils.AUCCurve.PR,
        summation_method=metrics_utils.AUCSummationMethod.MAJORING,
        name='auc_1')
    self.assertEqual(auc_obj.name, 'auc_1')
    self.assertEqual(len(auc_obj.variables), 4)
    self.assertEqual(auc_obj.num_thresholds, 100)
    self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)
    self.assertEqual(auc_obj.summation_method,
                     metrics_utils.AUCSummationMethod.MAJORING)

    # Check save and restore config
    auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
    self.assertEqual(auc_obj2.name, 'auc_1')
    self.assertEqual(len(auc_obj2.variables), 4)
    self.assertEqual(auc_obj2.num_thresholds, 100)
    self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)
    self.assertEqual(auc_obj2.summation_method,
                     metrics_utils.AUCSummationMethod.MAJORING)
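`get_config()` and `from_config()` round-trip all four constructor arguments, which is what allows a compiled model's AUC metric to be serialized and restored. The same round trip through the public API (string aliases such as 'PR' and 'majoring' resolve to the enums used above):

# Config round-trip via tf.keras, mirroring test_config.
import tensorflow as tf

auc = tf.keras.metrics.AUC(num_thresholds=100, curve='PR',
                           summation_method='majoring', name='auc_1')
clone = tf.keras.metrics.AUC.from_config(auc.get_config())

assert clone.name == 'auc_1'
assert clone.num_thresholds == 100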
Example no. 11
  def test_weighted_pr_interpolation(self):
    self.setup()
    auc_obj = metrics.AUC(
        num_thresholds=self.num_thresholds,
        curve=metrics_utils.AUCCurve.PR,
        summation_method=metrics_utils.AUCSummationMethod.INTERPOLATION)
    self.evaluate(variables.variables_initializer(auc_obj.variables))
    result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)

    # auc = (slope / Total Pos) * [dTP + intercept * log(Pb/Pa)]

    # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # P = tp + fp = [10, 4, 0]
    # dTP = [7-4, 4-0] = [3, 4]
    # dP = [10-4, 4-0] = [6, 4]
    # slope = dTP/dP = [0.5, 1]
    # intercept = TPa - (slope * Pa) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0]
    # (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1]
    # auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))]
    #                = [2.416, 4]
    # auc = [2.416, 4]/(tp[1:]+fn[1:])
    expected_result = (2.416/7 + 4/7)
    self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
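The same comment-level arithmetic works for the interpolated PR AUC. The walk-through below follows the formula in the comment line by line, forcing the Pb/Pa ratio to 1 (so its log vanishes) wherever either count is zero; the counts are again copied from the comments.

# NumPy walk-through of the interpolated PR-AUC arithmetic commented above.
import numpy as np

tp = np.array([7., 4., 0.])
fp = np.array([3., 0., 0.])
fn = np.array([0., 3., 7.])

p = tp + fp                            # predicted positives: [10, 4, 0]
dtp = tp[:-1] - tp[1:]                 # [3, 4]
dp = p[:-1] - p[1:]                    # [6, 4]
slope = dtp / dp                       # [0.5, 1.0]
intercept = tp[1:] - slope * p[1:]     # [2.0, 0.0]

ratio = np.ones_like(dtp)              # default 1 => log term drops out
ok = (p[:-1] > 0) & (p[1:] > 0)
ratio[ok] = p[:-1][ok] / p[1:][ok]     # [2.5, 1.0]

auc_times_pos = slope * (dtp + intercept * np.log(ratio))  # [2.416, 4.0]
total_pos = tp[1:] + fn[1:]                                # [7, 7]
print(np.sum(auc_times_pos / total_pos))                   # ~0.917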
Example no. 12
 def test_unweighted_all_correct(self):
   self.setup()
   auc_obj = metrics.AUC()
   self.evaluate(variables.variables_initializer(auc_obj.variables))
   result = auc_obj(self.y_true, self.y_true)
   self.assertEqual(self.evaluate(result), 1)
Example no. 13
  def test_invalid_num_thresholds(self):
    with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 1.'):
      metrics.AUC(num_thresholds=-1)

    with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 1.'):
      metrics.AUC(num_thresholds=1)