Esempio n. 1
0
 def test_get_multiclass_predictions_and_correctness(self):
     """Checks the two outputs on a small four-example, three-class batch.

     The first output is expected to equal the largest probability of each
     row; the second is expected to flag the rows whose highest-probability
     class matches the label (rows 0 and 1 here, but not rows 2 and 3).
     """
     probs = np.array([
         [0.1, 0.2, 0.7],
         [0.5, 0.3, 0.2],
         [0.7, 0.2, 0.1],
         [0.3, 0.5, 0.2],
     ])
     true_classes = np.array([2, 0, 1, 0])

     outputs = metrics_lib.get_multiclass_predictions_and_correctness(
         probs, true_classes)
     top_probs, correct_mask = outputs

     # Rows 2 and 3 peak at classes 0 and 1, while their labels are 1 and 0.
     self.assertAllEqual(top_probs, [0.7, 0.5, 0.7, 0.5])
     self.assertAllEqual(correct_mask, [True, True, False, False])
Esempio n. 2
0
 def test_get_multiclass_predictions_and_correctness_error_cases(self):
     """Checks that malformed inputs are rejected with ValueError.

     Three independent cases are exercised:
       1. probabilities whose rows no longer sum to 1,
       2. probabilities carrying an extra trailing axis,
       3. labels carrying an extra trailing axis.

     Cases 2 and 3 start from the valid probabilities so that the malformed
     shape is the only thing wrong with the input; reusing the invalid
     values there would let the test pass even if no shape check existed.
     """
     multiclass_probs = np.array([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2],
                                  [0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
     labels = np.array([2, 0, 1, 0])

     # Build the fixture outside the assertRaises block so that only the call
     # under test can satisfy the assertion. Each row now sums to 0.97.
     bad_multiclass_probs = multiclass_probs - 0.01

     # Case 1: invalid probability values, valid shapes.
     with self.assertRaises(ValueError):
         metrics_lib.get_multiclass_predictions_and_correctness(
             bad_multiclass_probs, labels)

     # Case 2: probabilities with an extra trailing axis, shape (4, 3, 1).
     with self.assertRaises(ValueError):
         metrics_lib.get_multiclass_predictions_and_correctness(
             multiclass_probs[Ellipsis, None], labels)

     # Case 3: valid probabilities, labels with an extra axis, shape (4, 1).
     with self.assertRaises(ValueError):
         metrics_lib.get_multiclass_predictions_and_correctness(
             multiclass_probs, labels[Ellipsis, None])