def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes,
                      ignore_index, match_str):
    """Check that invalid argument combinations raise ``ValueError`` for the
    class-based metric, its functional counterpart, and ``precision_recall``.

    The argument values come from a ``pytest.mark.parametrize`` decorator that
    is not shown in this snippet.
    """
    with pytest.raises(ValueError, match=match_str):
        metric(
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )
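A minimal, hypothetical illustration of one invalid combination the parametrized test above guards against; the torchmetrics.functional import path and the concrete failure mode are assumptions, not taken from the original suite.

import pytest
import torch
from torchmetrics.functional import precision_recall  # assumed import path

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 1])

# An unsupported ``average`` mode is expected to raise a ValueError.
with pytest.raises(ValueError):
    precision_recall(preds, target, average="not-a-valid-mode")
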
    def __shared_step_op(self, batch, batch_idx, phase, log=True):
        img, mask, edge_mask = batch
        output = self.forward(img)
        loss_matrix = self.criterion(output, mask)
        # Up-weight edge pixels: ``edge_weight ** edge_mask`` is 1 where the edge
        # mask is 0 and ``edge_weight`` where it is 1 (assuming a binary mask).
        loss = (loss_matrix * (self.edge_weight ** edge_mask)).mean()

        output_labels = torch.argmax(output, dim=1).view(-1)
        ground_truths = mask.view(-1)
        f1 = f1_score(output_labels,
                      ground_truths,
                      num_classes=self.n_classes,
                      class_reduction=self.class_reduction)

        precision, recall = precision_recall(output_labels,
                                             ground_truths,
                                             num_classes=self.n_classes,
                                             class_reduction=self.class_reduction)

        if self.n_classes == 2:
            # use the positive class only for binary case
            f1 = f1[-1]
            precision = precision[-1]
            recall = recall[-1]

        if log:
            self.log(f"{phase}/loss", loss, prog_bar=True)
            self.log(f"{phase}/f1_score", f1, prog_bar=True)
            self.log(f"{phase}/precision", precision, prog_bar=False)
            self.log(f"{phase}/recall", recall, prog_bar=False)

        return {f"{phase}_loss": loss, f"{phase}_f1_score": f1}
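
A minimal sketch of how a shared step helper like the one above is typically reused across Lightning hooks; the hook bodies below are assumptions, not taken from the original module.

    def validation_step(self, batch, batch_idx):
        # Delegate to the shared step so metrics are logged under the "val" prefix.
        return self.__shared_step_op(batch, batch_idx, phase="val")

    def test_step(self, batch, batch_idx):
        return self.__shared_step_op(batch, batch_idx, phase="test")
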
Example #3
def test_v1_5_metric_precision_recall():
    # Resetting the one-time ``warned`` flag makes each deprecated call emit its
    # warning again, so ``pytest.deprecated_call`` can capture it.
    AveragePrecision.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        AveragePrecision()

    Precision.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        Precision()

    Recall.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        Recall()

    PrecisionRecallCurve.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        PrecisionRecallCurve()

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 1])
    average_precision.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert average_precision(pred, target) == torch.tensor(1.)

    precision.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        # With micro averaging, 2 of the 4 predictions match the target
        # (indices 0 and 1), hence precision = recall = 0.5.
        assert precision(pred, target) == torch.tensor(0.5)

    recall.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert recall(pred, target) == torch.tensor(0.5)

    precision_recall.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        prec, rc = precision_recall(pred, target)
        assert prec == torch.tensor(0.5)
        assert rc == torch.tensor(0.5)

    precision_recall_curve.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        prec, rc, thrs = precision_recall_curve(pred, target)
        assert torch.equal(prec, torch.tensor([1., 1., 1., 1.]))
        assert torch.allclose(rc,
                              torch.tensor([1., 0.6667, 0.3333, 0.]),
                              atol=1e-4)
        assert torch.equal(thrs, torch.tensor([1, 2, 3]))
Example #4
def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this thoroughly, as it is just a combination of precision and recall,
    which are already tested thoroughly.
    """

    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])
    def get_test_metrics(self, display=True):
        # Per-class precision and recall (class_reduction='none')
        output = precision_recall(self.preds,
                                  self.targets,
                                  num_classes=2,
                                  class_reduction='none')
        precision = output[0].numpy()
        recall = output[1].numpy()
        # Get Precision-Recall Curve
        precision_curve, recall_curve = self.get_precision_recall_curve(
            pos_label=1, display=display)
        # Confusion Matrix
        cm = self.get_confusion_matrix(display=display)
        # F1 Score
        f1_score = self.get_f1_score()
        # F0.5 score (weights precision more heavily than recall)
        f05_score = fbeta(self.preds,
                          self.targets,
                          num_classes=2,
                          beta=0.5,
                          threshold=0.5,
                          average='none',
                          multilabel=False)
        # F2 score (weights recall more heavily than precision)
        f2_score = fbeta(self.preds,
                         self.targets,
                         num_classes=2,
                         beta=2,
                         threshold=0.5,
                         average='none',
                         multilabel=False)
        # Stats_score - Class 0
        tp_0, fp_0, tn_0, fn_0, sup_0 = self.get_stats_score(class_index=0)
        # Stats_score - Class 1
        tp_1, fp_1, tn_1, fn_1, sup_1 = self.get_stats_score(class_index=1)
        # ROC Curve
        roc_auc_0 = self.get_ROC_curve(pos_label=0)
        roc_auc_1 = self.get_ROC_curve(pos_label=1)
        # Classification Report
        report = classification_report(
            self.targets.detach().numpy(),
            (self.preds.argmax(dim=1)).detach().numpy(),
            output_dict=True)
        print("Confusion Matrix")
        print(cm)
        print("Classification Report")
        print(report)

        # Metrics are saved to a CSV file: one row per metric name,
        # with the corresponding values for class 0 and class 1
        metric = [
            'Precision', 'Recall', 'F1 Score', 'F0.5 Score', 'F2_Score', 'TP',
            'FP', 'TN', 'FN', 'ROC'
        ]
        value_class0 = [
            precision[0], recall[0], f1_score[0].numpy(), f05_score[0].numpy(),
            f2_score[0].numpy(), tp_0, fp_0, tn_0, fn_0, roc_auc_0
        ]
        value_class1 = [
            precision[1], recall[1], f1_score[1].numpy(), f05_score[1].numpy(),
            f2_score[1].numpy(), tp_1, fp_1, tn_1, fn_1, roc_auc_1
        ]
        # Dictionary of metric names and per-class values
        metrics_dict = {
            'Metric': metric,
            'Class 0': value_class0,
            'Class 1': value_class1
        }
        df = pd.DataFrame(metrics_dict)
        # Classification report as a dataframe
        df_report = pd.DataFrame(report)
        # Saving the dataframe
        df.to_csv(self.CSV_PATH, header=True, index=False)
        df_report.to_csv(self.CSV_PATH, mode='a', header=True, index=False)
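
A small cross-check sketch, not part of the original class: the F-beta values above can be recomputed directly from per-class precision and recall via F_beta = (1 + beta^2) * P * R / (beta^2 * P + R).

import torch

def fbeta_from_pr(precision: torch.Tensor, recall: torch.Tensor, beta: float) -> torch.Tensor:
    # Per-class F-beta from precision (P) and recall (R):
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
    return (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)

# With the per-class tensors returned by precision_recall(..., class_reduction='none'),
# fbeta_from_pr(p, r, 0.5) and fbeta_from_pr(p, r, 2) should match the fbeta calls
# above up to floating-point tolerance.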