def metrics(self, regularization_losses=None):
  """Creates metrics. See `base_head.Head` for details.

  Args:
    regularization_losses: Optional regularization losses. When not `None`,
      an additional `Mean` metric is created under
      `self._loss_regularization_key`.

  Returns:
    A dict mapping metric keys to Keras metric objects. It always contains
    the loss-mean, accuracy, prediction-mean and label-mean metrics, then the
    optional regularization-loss metric, then one `BinaryAccuracy` metric per
    entry of `self._thresholds`.
  """
  keys = metric_keys.MetricKeys
  with ops.name_scope('metrics', values=(regularization_losses,)):
    eval_metrics = {}
    # Mean metric.
    eval_metrics[self._loss_mean_key] = metrics.Mean(name=keys.LOSS_MEAN)
    eval_metrics[self._accuracy_key] = (
        metrics.SparseCategoricalAccuracy(name=keys.ACCURACY))
    eval_metrics[self._prediction_mean_key] = metrics.Mean(
        name=keys.PREDICTION_MEAN)
    eval_metrics[self._label_mean_key] = metrics.Mean(name=keys.LABEL_MEAN)
    # TODO(b/118843532): create Keras metrics for the remaining summary keys.
    # Presumed key -> metric-name pairing (inferred from attribute names;
    # confirm when implementing):
    #   self._precision_key         -> keys.PRECISION
    #   self._recall_key            -> keys.RECALL
    #   self._accuracy_baseline_key -> keys.ACCURACY_BASELINE
    #   self._auc_key               -> keys.AUC
    #   self._auc_pr_key            -> keys.AUC_PR
    if regularization_losses is not None:
      eval_metrics[self._loss_regularization_key] = metrics.Mean(
          name=keys.LOSS_REGULARIZATION)
    # One thresholded accuracy per configured threshold; the metric key is
    # also used as the metric's name.
    for i, threshold in enumerate(self._thresholds):
      eval_metrics[self._accuracy_keys[i]] = metrics.BinaryAccuracy(
          name=self._accuracy_keys[i], threshold=threshold)
      # TODO(b/118843532): create Keras metrics for self._precision_keys[i]
      # and self._recall_keys[i], each named after its key and evaluated at
      # `threshold`.
    return eval_metrics
def metrics(self, regularization_losses=None):
  """Creates metrics. See `base_head.Head` for details.

  Args:
    regularization_losses: Optional regularization losses. Only checked for
      `None`-ness here: when present, a regularization-loss `Mean` metric is
      added.

  Returns:
    A dict of Keras metric objects keyed by this head's metric keys, in the
    order: loss mean, (optional) regularization loss, accuracy.
  """
  metric_names = metric_keys.MetricKeys
  with ops.name_scope('metrics', values=(regularization_losses,)):
    result = {self._loss_mean_key: metrics.Mean(name=metric_names.LOSS_MEAN)}
    # Only track the regularization loss when the caller supplies one.
    if regularization_losses is not None:
      result[self._loss_regularization_key] = metrics.Mean(
          name=metric_names.LOSS_REGULARIZATION)
    result[self._accuracy_key] = metrics.SparseCategoricalAccuracy(
        name=metric_names.ACCURACY)
    return result
def test_sparse_categorical_accuracy_mismatched_dims_dynamic(self):
  """Weighted accuracy with (N,) labels and (N, 1) weights of dynamic shape.

  Labels, predictions and weights are all fed through placeholders created
  without a `shape`, so their shapes are only known when the graph runs.
  """
  with context.graph_mode(), self.cached_session() as sess:
    acc_obj = metrics.SparseCategoricalAccuracy(name='my acc')
    self.evaluate(variables.variables_initializer(acc_obj.variables))

    t = array_ops.placeholder(dtypes.float32)
    p = array_ops.placeholder(dtypes.float32)
    w = array_ops.placeholder(dtypes.float32)

    result_t = acc_obj(t, p, w)
    result = sess.run(
        result_t,
        feed_dict={
            t: [2, 1],
            p: [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
            w: [[0.5], [0.2]],
        })
    # argmax(p) is [2, 2], so only the first sample (weight 0.5) is correct.
    # No earlier update_state call has accumulated state in this test.
    self.assertAlmostEqual(result, 0.71, 2)  # 0.5/0.7
def test_sparse_categorical_accuracy_mismatched_dims(self):
  """SparseCategoricalAccuracy with (N,) labels and (N, 1) sample weights."""
  accuracy = metrics.SparseCategoricalAccuracy(name='my acc')

  # Configuration of the freshly built metric.
  self.assertEqual(accuracy.name, 'my acc')
  self.assertTrue(accuracy.stateful)
  self.assertEqual(len(accuracy.variables), 2)
  self.assertEqual(accuracy.dtype, dtypes.float32)
  self.evaluate(variables.variables_initializer(accuracy.variables))

  labels = [2, 1]

  # Unweighted update: the argmax of each prediction row matches its label.
  self.evaluate(
      accuracy.update_state(labels, [[0.1, 0.1, 0.8], [0.05, 0.95, 0]]))
  self.assertEqual(self.evaluate(accuracy.result()), 1)  # 2/2

  # Weighted call on top of the accumulated state: only the first sample
  # (weight 0.5) is correct, so the running value is (2 + 0.5) / (2 + 0.7).
  weighted = accuracy(labels, [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
                      [[0.5], [0.2]])
  self.assertAlmostEqual(self.evaluate(weighted), 0.93, 2)  # 2.5/2.7