def test_rank_classification_raise(self):
    """Malformed `targets` must be rejected with an exact ValueError message.

    Each case pairs the literal error text expected from
    `metrics.rank_classification` with the malformed targets that trigger it.
    Every case uses the same two scores and `num_classes=2`.
    """
    malformed_cases = (
        (
            "`targets` should contain 3 elements but has 2.",
            [((0, 0), True), ((0, 1), True)],
        ),
        (
            "The first element of `targets` ('idx') should be 2-dimensional. "
            "Got 0.",
            [(0, True, 1.0), (0, True, 1.0)],
        ),
    )
    for expected_text, bad_targets in malformed_cases:
      with self.assertRaisesWithLiteralMatch(ValueError, expected_text):
        metrics.rank_classification(bad_targets, [0.1, 0.5], num_classes=2)
示例#2
0
  def test_rank_classification(self):
    """Checks accuracy, F1 and AUC metrics for flat integer-label targets."""
    # Binary case; `num_classes` is left at the function's default. The
    # targets appear to repeat each example's label once per candidate score
    # (eight entries, eight scores).
    binary_labels = [0, 0, 1, 1, 0, 0, 0, 0]
    binary_scores = [0.1, 0.5, 1.0, 1.1, 0.3, 0.1, 0.6, 0.5]
    binary_expected = {
        "accuracy": 75.,
        "f1": 66.6666667,
        "auc-roc": 66.6666667,
        "auc-pr": 70.8333333,
    }
    self.assertDictClose(
        metrics.rank_classification(binary_labels, binary_scores),
        binary_expected)

    # Three-way case: nine label entries line up with nine scores.
    ternary_labels = [1, 1, 1, 0, 0, 0, 2, 2, 2]
    ternary_scores = [0.1, 0.5, 0.0, -2, -1, -3, 3.0, 3.1, 3.2]
    ternary_expected = {
        "accuracy": 66.6666667,
        "mean_3class_f1": 55.5555556,
        "auc-roc": 50.0,
        "auc-pr": 61.1111111,
    }
    self.assertDictClose(
        metrics.rank_classification(
            ternary_labels, ternary_scores, num_classes=3),
        ternary_expected)

  def test_rank_classification_custom_weights(self):
    """Checks rank-classification metrics with per-example target weights.

    Each target is an `(idx, is_correct, weight)` triple, and every candidate
    of one example shares the same weight. The comment beside each example
    names the candidate position(s) marked correct.
    """
    # num_classes = 2: four examples with two candidates each; `idx` numbers
    # the examples 0-3, like the three-way and multi-answer cases below.
    self.assertDictClose(
        metrics.rank_classification(
            [(0, True, 0.2), (0, False, 0.2),  # 0
             (1, False, 1.0), (1, True, 1.0),  # 1
             (2, True, 0.8), (2, False, 0.8),  # 0
             (3, True, 0.5), (3, False, 0.5)],  # 0
            [0.1, 0.5,
             1.0, 1.1,
             0.3, 0.1,
             0.6, 0.5],
            num_classes=2),
        {
            "accuracy": 92.0,
            "auc-pr": 89.0,
            "auc-roc": 86.6666667,
            "f1": 90.9090909,
        })

    # num_classes = 3
    self.assertDictClose(
        metrics.rank_classification(
            [
                # 1
                (0, False, 0.2),
                (0, True, 0.2),
                (0, False, 0.2),
                # 0
                (1, True, 0.5),
                (1, False, 0.5),
                (1, False, 0.5),
                # 2
                (2, False, 1.0),
                (2, False, 1.0),
                (2, True, 1.0)
            ],
            [0.1, 0.5, 0.0, -2, -1, -3, 3.0, 3.1, 3.2],
            num_classes=3),
        {
            "accuracy": 70.5882353,
            "auc-pr": 52.6610644,
            "auc-roc": 55.5555556,
            "mean_3class_f1": 48.1481481,
        })

    # num_classes = None, multi-answer: the examples have 2, 3 and 1
    # candidates, and the middle one has two correct candidates.
    self.assertDictClose(
        metrics.rank_classification(
            [(0, False, 0.2), (0, True, 0.2),  # 1
             (1, True, 0.5), (1, False, 0.5), (1, True, 0.5),  # 0, 2
             (2, True, 1.0)],  # 0
            [0.1, 0.5,
             -2, -1, -3,
             3.0],
            num_classes=None),
        {
            "accuracy": 70.5882353,
        })

 def test_rank_classification_raise(self):
   """Targets given as 2-tuples instead of 3-tuples must raise ValueError."""
   import re  # Only this test needs regex escaping.

   # The expected text is a literal message, not a pattern. Escape it so its
   # '.' characters match only literal periods instead of any character.
   expected_message = re.escape(
       "`targets` should contain three elements. Only 2 are provided.")
   with self.assertRaisesRegex(ValueError, expected_message):
     metrics.rank_classification(
         [
             (0, True),
             (0, True),
             (0, False),
         ],
         [0.1, 0.5, 1.0],
         num_classes=2)

 def test_rank_classification_normalized(self):
   """Exercises `normalize_by_target_length=True` on a binary task.

   Each target is `((example, candidate), is_correct, weight, length)`. The
   fourth field is presumably the candidate's target length used for
   normalization (TODO confirm against `metrics.rank_classification`). The
   comment on each pair names the candidate marked correct.
   """
   # num_classes = 2
   length_tagged_targets = [
       ((0, 0), True, 1.0, 5), ((0, 1), False, 1.0, 10),  # 0
       ((1, 0), False, 1.0, 2), ((1, 1), True, 1.0, 3),  # 1
       ((2, 0), True, 1.0, 5), ((2, 1), False, 1.0, 6),  # 0
       ((3, 0), True, 1.0, 3), ((3, 1), False, 1.0, 2),  # 0
   ]
   raw_scores = [0.5, 5.0, 2.0, 3.3, 1.5, 0.6, 1.8, 1.0]
   observed = metrics.rank_classification(
       length_tagged_targets,
       raw_scores,
       num_classes=2,
       normalize_by_target_length=True)
   self.assertDictClose(
       observed,
       {
           "accuracy": 75.,
           "auc-pr": 50.0,
           "auc-roc": 66.6666667,
           "f1": 66.6666667,
       })

  def test_rank_classification_shuffled(self):
    """Exercises targets supplied out of order, keyed by a 2-D `idx`.

    Each entry below pairs a target `(idx, is_correct, weight)`, whose idx
    looks like `(example, candidate)`, with the score reported for that same
    candidate. The pairs are listed in a shuffled order.
    """

    def unzip(pairs):
      # Split (target, score) pairs into the two parallel lists passed to
      # `metrics.rank_classification`.
      return [target for target, _ in pairs], [score for _, score in pairs]

    # num_classes = 2
    binary_targets, binary_scores = unzip([
        (((3, 0), True, 0.5), 0.6),
        (((0, 0), True, 0.2), 0.1),
        (((1, 0), False, 1.0), 1.0),
        (((1, 1), True, 1.0), 1.1),
        (((2, 0), True, 0.8), 0.3),
        (((2, 1), False, 0.8), 0.1),
        (((3, 1), False, 0.5), 0.5),
        (((0, 1), False, 0.2), 0.5),
    ])
    self.assertDictClose(
        metrics.rank_classification(
            binary_targets, binary_scores, num_classes=2),
        {
            "accuracy": 92.0,
            "auc-pr": 83.3333333,
            "auc-roc": 86.6666667,
            "f1": 90.9090909,
        })

    # num_classes = 3
    ternary_targets, ternary_scores = unzip([
        (((0, 0), False, 0.2), 0.1),
        (((2, 1), False, 1.0), 3.1),
        (((0, 1), True, 0.2), 0.5),
        (((1, 0), True, 0.5), -2),
        (((1, 1), False, 0.5), -1),
        (((1, 2), False, 0.5), -3),
        (((0, 2), False, 0.2), 0.0),
        (((2, 0), False, 1.0), 3.0),
        (((2, 2), True, 1.0), 3.2),
    ])
    self.assertDictClose(
        metrics.rank_classification(
            ternary_targets, ternary_scores, num_classes=3),
        {
            "accuracy": 70.5882353,
            "mean_3class_f1": 48.1481481,
        })

    # num_classes = None, multi-answer
    ragged_targets, ragged_scores = unzip([
        (((0, 0), False, 0.2), 0.1),
        (((2, 0), True, 1.0), 3.0),
        (((0, 1), True, 0.2), 0.5),
        (((1, 2), True, 0.5), -3),
        (((1, 0), True, 0.5), -2),
        (((1, 1), False, 0.5), -1),
    ])
    self.assertDictClose(
        metrics.rank_classification(
            ragged_targets, ragged_scores, num_classes=None),
        {
            "accuracy": 70.5882353,
        })

  def test_rank_classification_default_weights(self):
    """Checks rank-classification metrics when every target weight is 1.0.

    Targets are `(idx, is_correct, weight)` tuples whose 2-D idx has the form
    (example, candidate). The comment above each example names the candidate
    position(s) marked correct.
    """
    # num_classes = 2
    self.assertDictClose(
        metrics.rank_classification(
            [
                # 0
                ((0, 0), True, 1.0),
                ((0, 1), False, 1.0),
                # 1
                ((1, 0), False, 1.0),
                ((1, 1), True, 1.0),
                # 0
                ((2, 0), True, 1.0),
                ((2, 1), False, 1.0),
                # 0
                ((3, 0), True, 1.0),
                ((3, 1), False, 1.0),
            ],
            [
                0.1, 0.5,
                1.0, 1.1,
                0.3, 0.1,
                0.6, 0.5
            ],
            num_classes=2),
        {
            "accuracy": 75.,
            "auc-pr": 50.0,
            "auc-roc": 66.6666667,
            "f1": 66.6666667,
        })

    # num_classes = 3
    self.assertDictClose(
        metrics.rank_classification(
            [
                # 1
                ((0, 0), False, 1.0),
                ((0, 1), True, 1.0),
                ((0, 2), False, 1.0),
                # 0
                ((1, 0), True, 1.0),
                ((1, 1), False, 1.0),
                ((1, 2), False, 1.0),
                # 2
                ((2, 0), False, 1.0),
                ((2, 1), False, 1.0),
                ((2, 2), True, 1.0)
            ],
            [
                0.1, 0.5, 0.0,
                -2, -1, -3,
                3.0, 3.1, 3.2
            ],
            num_classes=3),
        {
            "accuracy": 66.6666667,
            "mean_3class_f1": 55.5555556,
        })

    # num_classes = 3, multi-label
    self.assertDictClose(
        metrics.rank_classification(
            [
                # 1
                ((0, 0), False, 1.0),
                ((0, 1), True, 1.0),
                ((0, 2), False, 1.0),
                # 0, 2
                ((1, 0), True, 1.0),
                ((1, 1), False, 1.0),
                ((1, 2), True, 1.0),
                # 1, 2
                ((2, 0), False, 1.0),
                ((2, 1), True, 1.0),
                ((2, 2), True, 1.0)
            ],
            [
                0.1, 0.5, 0.0,
                -2, -1, -3,
                3.0, 3.1, 3.2
            ],
            num_classes=3),
        {
            "accuracy": 66.6666667,
        })

    # num_classes = None, multi-answer: the examples have 2, 3 and 1
    # candidates.
    self.assertDictClose(
        metrics.rank_classification(
            [
                # 1
                ((0, 0), False, 1.0),
                ((0, 1), True, 1.0),
                # 0, 2
                ((1, 0), True, 1.0),
                ((1, 1), False, 1.0),
                ((1, 2), True, 1.0),
                # 0
                ((2, 0), True, 1.0)
            ],
            [
                0.1, 0.5,
                -2, -1, -3,
                3.0
            ],
            num_classes=None),
        {
            "accuracy": 66.6666667,
        })

    def test_rank_classification(self):
        """Checks rank-classification metrics for `(idx, is_correct)` targets.

        Targets carry no weight field. The comment beside each example names
        the candidate position(s) marked correct.
        """
        # num_classes = 2: four examples with two candidates each; `idx`
        # numbers the examples 0-3 like the other cases in this test.
        self.assertDictClose(
            metrics.rank_classification(
                [
                    (0, True), (0, False),  # 0
                    (1, False), (1, True),  # 1
                    (2, True), (2, False),  # 0
                    (3, True), (3, False),  # 0
                ],
                [0.1, 0.5, 1.0, 1.1, 0.3, 0.1, 0.6, 0.5],
                num_classes=2),
            {
                "accuracy": 75.,
                "auc-pr": 70.8333333,
                "auc-roc": 66.6666667,
                "f1": 66.6666667,
            })

        # num_classes = 3
        self.assertDictClose(
            metrics.rank_classification(
                [
                    # 1
                    (0, False),
                    (0, True),
                    (0, False),
                    # 0
                    (1, True),
                    (1, False),
                    (1, False),
                    # 2
                    (2, False),
                    (2, False),
                    (2, True)
                ],
                [0.1, 0.5, 0.0, -2, -1, -3, 3.0, 3.1, 3.2],
                num_classes=3),
            {
                "accuracy": 66.6666667,
                "auc-pr": 61.1111111,
                "auc-roc": 50.0,
                "mean_3class_f1": 55.5555556,
            })

        # num_classes = 3, multi-label
        self.assertDictClose(
            metrics.rank_classification(
                [
                    (0, False), (0, True), (0, False),  # 1
                    (1, True), (1, False), (1, True),  # 0, 2
                    (2, False), (2, True), (2, True),  # 1, 2
                ],
                [0.1, 0.5, 0.0, -2, -1, -3, 3.0, 3.1, 3.2],
                num_classes=3),
            {
                "accuracy": 66.6666667,
            })

        # num_classes = None, multi-answer: the examples have 2, 3 and 1
        # candidates.
        self.assertDictClose(
            metrics.rank_classification(
                [
                    (0, False), (0, True),  # 1
                    (1, True), (1, False), (1, True),  # 0, 2
                    (2, True),  # 0
                ],
                [0.1, 0.5, -2, -1, -3, 3.0],
                num_classes=None),
            {
                "accuracy": 66.6666667,
            })