    def test_correct_score_calculation_multi_label_micro(self):
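        # With average='micro' the evaluator should score the flattened
        # (sample, label) matrix, so the reference value below is sklearn's
        # binary precision over the concatenated per-label columns,
        # thresholded at 0.5.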
        evaluator = evaluators.PrecisionEvaluator(model_output_key=None,
                                                  batch_target_key='target',
                                                  average='micro')

        output = torch.tensor(
            [[0.6, 0.2], [0.7, 0.2], [0.6, 0.6], [0.3, 0.55]],
            dtype=torch.float32)
        batch = {
            'target':
            torch.tensor([[1, 1], [0, 1], [1, 0], [0, 1]], dtype=torch.float32)
        }
        evaluator.step(output, batch)
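        # A second, smaller batch; results accumulate across step() calls.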

        output = torch.tensor([[0.6, 0.4]], dtype=torch.float32)
        batch = {'target': torch.tensor([[1, 1]], dtype=torch.float32)}
        evaluator.step(output, batch)

        res = evaluator.calculate()

        correct = metrics.precision_score(
            y_true=np.array([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),
            y_pred=np.array(
                [0.6, 0.7, 0.6, 0.3, 0.6, 0.2, 0.2, 0.6, 0.55, 0.4]) > 0.5)

        self.assertAlmostEqual(res.score, correct)

    def test_correct_score_calculation_binary(self):
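        # Binary-average precision accumulated over two step() calls.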
        evaluator = evaluators.PrecisionEvaluator(model_output_key=None,
                                                  batch_target_key='target',
                                                  average='binary')

        output = torch.tensor([0.9, 0.2, 0.8, 0.3], dtype=torch.float32)
        batch = {'target': torch.tensor([1, 1, 0, 0], dtype=torch.float32)}
        evaluator.step(output, batch)

        output = torch.tensor([0.2, 0.98, 0.76], dtype=torch.float32)
        batch = {'target': torch.tensor([1, 1, 0], dtype=torch.float32)}
        evaluator.step(output, batch)

        res = evaluator.calculate()

        # Thresholding scores at 0.5 gives pooled predictions
        # [1, 0, 1, 0, 0, 1, 1] against targets [1, 1, 0, 0, 1, 1, 0]:
        # TP = 2, FP = 2, so precision = 2 / 4.
        self.assertAlmostEqual(res.score, 2. / 4)