コード例 #1
0
    def test_correct_score_calculation_multi_label_micro(self):
        """Micro-averaged recall over two multi-label steps matches sklearn.

        The reference value is derived from the very tensors that are fed to
        the evaluator, so it cannot drift from the test inputs the way a
        hand-flattened copy of the data can.
        """
        evaluator = evaluators.RecallEvaluator(model_output_key=None,
                                               batch_target_key='target',
                                               average='micro')

        outputs = [
            torch.tensor([[0.6, 0.2], [0.7, 0.2], [0.6, 0.6], [0.3, 0.55]],
                         dtype=torch.float32),
            torch.tensor([[0.6, 0.4]], dtype=torch.float32),
        ]
        targets = [
            torch.tensor([[1, 1], [0, 1], [1, 0], [0, 1]],
                         dtype=torch.float32),
            torch.tensor([[1, 1]], dtype=torch.float32),
        ]
        # Two separate steps so the evaluator has to accumulate across batches.
        for output, target in zip(outputs, targets):
            evaluator.step(output, {'target': target})

        res = evaluator.calculate()

        # Reference: sklearn micro recall on the stacked indicator matrices.
        # Scores are binarised with the same `> 0.5` cut this test has always
        # used for its reference (presumably the evaluator's threshold too).
        y_pred = (torch.cat(outputs).numpy() > 0.5).astype(int)
        y_true = torch.cat(targets).numpy().astype(int)
        correct = metrics.recall_score(y_true=y_true,
                                       y_pred=y_pred,
                                       average='micro')

        self.assertAlmostEqual(res.score, correct)
コード例 #2
0
    def test_correct_score_calculation_binary(self):
        """Binary recall accumulated over two steps should come out as 0.5.

        Across both steps the targets hold four positives. Assuming scores are
        binarised at 0.5, two of them (scores 0.9 and 0.98) are recovered,
        giving 2 / 4.
        """
        evaluator = evaluators.RecallEvaluator(
            model_output_key=None,
            batch_target_key='target',
            average='binary',
        )

        # Each entry is (model scores, ground-truth labels) for one step.
        step_data = (
            ([0.9, 0.2, 0.8, 0.3], [1, 1, 0, 0]),
            ([0.2, 0.98, 0.76], [1, 1, 0]),
        )
        for scores, labels in step_data:
            scores_t = torch.tensor(scores, dtype=torch.float32)
            labels_t = torch.tensor(labels, dtype=torch.float32)
            evaluator.step(scores_t, {'target': labels_t})

        result = evaluator.calculate()

        self.assertAlmostEqual(result.score, 0.5)