Example #1
    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        # logits: (batch, num_classes) raw scores; labels: (batch,) class indices.
        assert logits.ndim == 2
        assert labels.ndim == 1

        # The metric is not differentiated through, so skip autograd bookkeeping.
        with torch.no_grad():
            if self.average == 'macro':
                # softmax is monotonic, so it leaves the predicted class
                # (argmax) unchanged; it only turns scores into probabilities.
                return recall(
                    pred=nn.functional.softmax(logits, dim=1),
                    target=labels,
                    num_classes=self.num_classes,
                    # 'elementwise_mean' averages the per-class recalls,
                    # which is exactly the macro average.
                    reduction='elementwise_mean',
                )
            elif self.average == 'micro':
                raise NotImplementedError
            elif self.average == 'weighted':
                raise NotImplementedError
            else:
                raise ValueError(f"unsupported average: {self.average!r}")
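
For reference, here is what the 'macro' branch computes, written out in plain PyTorch instead of through the library's recall function: recall is computed per class and the per-class values are averaged, which is what the 'elementwise_mean' reduction over per-class recalls amounts to. A minimal illustrative sketch; the function name and the example shapes below are assumptions, not part of the snippet above.

import torch

def macro_recall(logits: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Hypothetical standalone helper, equivalent in spirit to the 'macro' branch above.
    preds = logits.argmax(dim=1)  # predicted class per sample
    per_class = []
    for c in range(num_classes):
        actual = labels == c                      # samples whose true class is c
        if actual.any():                          # skip classes absent from the batch
            tp = (preds[actual] == c).sum().float()
            per_class.append(tp / actual.sum().float())
    # Macro average: unweighted mean of the per-class recalls.
    return torch.stack(per_class).mean()

# Example: 8 samples, 3 classes.
logits = torch.randn(8, 3)
labels = torch.randint(0, 3, (8,))
print(macro_recall(logits, labels, num_classes=3))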
Example #2
# `_input_mcls_prob` and `NUM_CLASSES` come from the test suite's shared
# multiclass fixtures (not shown here).
def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this thoroughly, as it is just a combination of precision
    and recall, which are both already tested thoroughly on their own.
    """

    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    # precision_recall returns (precision, recall); both entries must match
    # the standalone functionals exactly.
    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])
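
The equality the test asserts can also be seen in a small standalone run. This sketch assumes an older torchmetrics release (before 0.11) in which the joint precision_recall functional was still exported from torchmetrics.functional; the shapes and num_classes below are illustrative, not taken from the test fixtures.

import torch
from torchmetrics.functional import precision, recall, precision_recall

preds = torch.softmax(torch.randn(16, 4), dim=1)  # (batch, num_classes) probabilities
target = torch.randint(0, 4, (16,))               # (batch,) integer labels

# The joint call returns a (precision, recall) tuple ...
p, r = precision_recall(preds, target, average='macro', num_classes=4)

# ... and both entries match the standalone functionals exactly.
assert torch.equal(p, precision(preds, target, average='macro', num_classes=4))
assert torch.equal(r, recall(preds, target, average='macro', num_classes=4))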