예제 #1
0
    def test_single_meter_update_and_reset(self):
        """
        Check the meter on a single update, a reset, and the same update again.

        This method only builds the meter, the model scores, the targets and
        the expected values; the update/reset/update sequence itself is run by
        ``self.meter_update_and_reset_test``, which is defined outside this
        block.
        """
        meter = RecallAtKMeter(topk=[1, 2])

        # Batchsize = 3, num classes = 3, score is probability of class
        model_output = torch.tensor(
            [
                # Class 1 must strictly outrank class 2 here: only class 1 is
                # positive for this sample, so an exact tie between them would
                # make the top-1 result depend on how ties are broken.
                [0.2, 0.45, 0.35],  # top-1: 1, top-2: 1/2
                [0.2, 0.65, 0.15],  # top-1: 1, top-2: 1/0
                # Classes 0 and 1 tie for second place, but both are positives
                # for this sample, so either pick counts as one more hit.
                [0.33, 0.33, 0.34],  # top-1: 2, top-2: 2 plus 0 or 1
            ]
        )

        # One-hot encoding, 1 = positive for class
        # sample-1: 1, sample-2: 0, sample-3: 0,1,2
        target = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1]])

        # Ties are not resolved deterministically, so every tie that could
        # change the outcome has been removed from the scores above.
        # The target holds 5 positives in total. The values below correspond to
        # hits / total positives:
        #   top_1: sample-1 (class 1) + sample-3 (class 2)            -> 2 hits
        #   top_2: the above + sample-2 (class 0) + sample-3 (0 or 1) -> 4 hits
        expected_value = {"top_1": 2 / 5.0, "top_2": 4 / 5.0}

        self.meter_update_and_reset_test(meter, model_output, target, expected_value)