Example #1
 def metrics(self, regularization_losses=None):
     """Creates metrics. See `base_head.Head` for details."""
     keys = metric_keys.MetricKeys
     with ops.name_scope('metrics', values=(regularization_losses,)):
         # Mean metric.
         eval_metrics = {}
         eval_metrics[self._loss_mean_key] = metrics.Mean(
             name=keys.LOSS_MEAN)
         eval_metrics[self._accuracy_key] = (
             metrics.SparseCategoricalAccuracy(name=keys.ACCURACY))
         # TODO(b/118843532): create Keras metrics
         # eval_metrics[self._precision_key] = metrics.Precision(
         #     name=keys.PRECISION)
         # eval_metrics[self._recall_key] = metrics.Recall(name=keys.RECALL)
         eval_metrics[self._prediction_mean_key] = metrics.Mean(
             name=keys.PREDICTION_MEAN)
         eval_metrics[self._label_mean_key] = metrics.Mean(
             name=keys.LABEL_MEAN)
         # TODO(b/118843532): create Keras metrics
         # eval_metrics[self._accuracy_baseline_key] = (
         #     metrics.Mean(name=keys.ACCURACY_BASELINE))
         # eval_metrics[self._auc_key] = metrics.AUC(name=keys.AUC)
         # eval_metrics[self._auc_pr_key] = metrics.AUC(
         #     curve='PR', name=keys.AUC_PR)
         if regularization_losses is not None:
             eval_metrics[self._loss_regularization_key] = metrics.Mean(
                 name=keys.LOSS_REGULARIZATION)
         for i, threshold in enumerate(self._thresholds):
             eval_metrics[self._accuracy_keys[i]] = metrics.BinaryAccuracy(
                 name=self._accuracy_keys[i], threshold=threshold)
             # TODO(b/118843532): create Keras metrics
             # eval_metrics[self._precision_keys[i]] = metrics.Precision(
             #     name=self._precision_keys[i], thresholds=threshold)
             # eval_metrics[self._recall_keys[i]] = metrics.Recall(
             #     name=self._recall_keys[i], thresholds=threshold)
     return eval_metrics
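
The per-threshold loop above relies on `BinaryAccuracy`'s `threshold` argument. A minimal standalone sketch of that behaviour using the public `tf.keras.metrics` API (the metric name here is illustrative, not one of the head's keys):

import tensorflow as tf

# A prediction counts as positive only when it exceeds `threshold`.
acc = tf.keras.metrics.BinaryAccuracy(name='accuracy/0.7', threshold=0.7)
acc.update_state([1, 0, 1], [0.9, 0.4, 0.6])  # 0.6 < 0.7, so the last example is missed
print(float(acc.result()))  # 2/3 ~= 0.667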
Example #2
 def metrics(self, regularization_losses=None):
   """Creates metrics. See `base_head.Head` for details."""
   keys = metric_keys.MetricKeys
   with ops.name_scope('metrics', values=(regularization_losses,)):
     # Mean metric.
     eval_metrics = {}
     eval_metrics[self._loss_mean_key] = metrics.Mean(name=keys.LOSS_MEAN)
     if regularization_losses is not None:
       eval_metrics[self._loss_regularization_key] = metrics.Mean(
           name=keys.LOSS_REGULARIZATION)
     # Accuracy metric.
     eval_metrics[self._accuracy_key] = (
         metrics.SparseCategoricalAccuracy(name=keys.ACCURACY))
   return eval_metrics
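
Both heads hand these metric objects to the evaluation loop, which calls `update_state()` per batch and `result()` at the end. A minimal sketch of the accuracy metric in isolation, using the public `tf.keras.metrics` API rather than the internal `metrics` module imported above:

import tensorflow as tf

# SparseCategoricalAccuracy compares integer labels against argmax(predictions).
acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
acc.update_state([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.95, 0.0]])
print(float(acc.result()))  # 1.0 -- argmax yields classes 2 and 1, matching both labels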
Example #3
    def test_sparse_categorical_accuracy_mismatched_dims_dynamic(self):
        with context.graph_mode(), self.cached_session() as sess:
            acc_obj = metrics.SparseCategoricalAccuracy(name='my acc')
            self.evaluate(variables.variables_initializer(acc_obj.variables))

            t = array_ops.placeholder(dtypes.float32)
            p = array_ops.placeholder(dtypes.float32)
            w = array_ops.placeholder(dtypes.float32)

            result_t = acc_obj(t, p, w)
            result = sess.run(result_t,
                              feed_dict={
                                  t: [2, 1],
                                  p: [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
                                  w: [[0.5], [0.2]]
                              })
            self.assertAlmostEqual(result, 0.71, 2)  # 0.5/0.7
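
The expected value follows from the weighted mean: the first example (label 2, argmax 2) is correct with weight 0.5, the second (label 1, argmax 2) is wrong with weight 0.2, giving 0.5 / (0.5 + 0.2) ≈ 0.714. A sketch of the same computation in eager mode, assuming the TF 2.x public API:

import tensorflow as tf

# Calling the metric object updates its state and returns the running result.
acc = tf.keras.metrics.SparseCategoricalAccuracy(name='my_acc')
result = acc([2, 1],
             [[0.1, 0.1, 0.8], [0.05, 0.0, 0.95]],
             sample_weight=[[0.5], [0.2]])
print(float(result))  # 0.5 / 0.7 ~= 0.714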
Example #4
  def test_sparse_categorical_accuracy_mismatched_dims(self):
    acc_obj = metrics.SparseCategoricalAccuracy(name='my acc')

    # check config
    self.assertEqual(acc_obj.name, 'my acc')
    self.assertTrue(acc_obj.stateful)
    self.assertEqual(len(acc_obj.variables), 2)
    self.assertEqual(acc_obj.dtype, dtypes.float32)
    self.evaluate(variables.variables_initializer(acc_obj.variables))

    # verify that correct value is returned
    update_op = acc_obj.update_state([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
    self.evaluate(update_op)
    result = self.evaluate(acc_obj.result())
    self.assertEqual(result, 1)  # 2/2

    # check with sample_weight
    result_t = acc_obj([2, 1], [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
                       [[0.5], [0.2]])
    result = self.evaluate(result_t)
    self.assertAlmostEqual(result, 0.93, 2)  # 2.5/2.7
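
The 0.93 expectation only makes sense because the metric is stateful: the two variables checked at the top of the test are `total` (weighted correct) and `count` (weighted seen), and `result()` is their ratio across all updates so far. A sketch of the same accumulation with the public API, assuming TF 2.x:

import tensorflow as tf

acc = tf.keras.metrics.SparseCategoricalAccuracy()
acc.update_state([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.95, 0.0]])  # total=2.0, count=2.0
acc.update_state([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.0, 0.95]],
                 sample_weight=[[0.5], [0.2]])                  # total=2.5, count=2.7
print(float(acc.result()))  # 2.5 / 2.7 ~= 0.926
acc.reset_state()  # zero both variables (named reset_states in older releases)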