def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str):
    """Check that invalid argument combinations raise ``ValueError`` for the class-based,
    functional, and joint ``precision_recall`` interfaces alike."""
    with pytest.raises(ValueError, match=match_str):
        metric(
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )
def __shared_step_op(self, batch, batch_idx, phase, log=True):
    img, mask, edge_mask = batch
    output = self.forward(img)

    # Edge-aware weighting of the per-pixel loss.
    loss_matrix = self.criterion(output, mask)
    loss = (loss_matrix * (self.edge_weight ** edge_mask)).mean()

    # Flatten predictions and targets for the classification metrics.
    output_labels = torch.argmax(output, dim=1).view(-1)
    ground_truths = mask.view(-1)
    f1 = f1_score(output_labels, ground_truths, num_classes=self.n_classes, class_reduction=self.class_reduction)
    precision, recall = precision_recall(
        output_labels, ground_truths, num_classes=self.n_classes, class_reduction=self.class_reduction
    )

    if self.n_classes == 2:
        # Use the positive class only for the binary case.
        f1 = f1[-1]
        precision = precision[-1]
        recall = recall[-1]

    if log:
        self.log(f"{phase}/loss", loss, prog_bar=True)
        self.log(f"{phase}/f1_score", f1, prog_bar=True)
        self.log(f"{phase}/precision", precision, prog_bar=False)
        self.log(f"{phase}/recall", recall, prog_bar=False)

    return {f"{phase}_loss": loss, f"{phase}_f1_score": f1}
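# A minimal, hypothetical sketch of how a shared step like the one above is usually
# dispatched from the per-phase LightningModule hooks. These hook bodies are an
# assumption (they do not appear in the original snippet) and are expected to live
# in the same class, so the name-mangled __shared_step_op resolves correctly.
def training_step(self, batch, batch_idx):
    return self.__shared_step_op(batch, batch_idx, phase="train")

def validation_step(self, batch, batch_idx):
    return self.__shared_step_op(batch, batch_idx, phase="val")

def test_step(self, batch, batch_idx):
    return self.__shared_step_op(batch, batch_idx, phase="test")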
def test_v1_5_metric_precision_recall():
    AveragePrecision.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        AveragePrecision()

    Precision.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        Precision()

    Recall.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        Recall()

    PrecisionRecallCurve.__init__.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        PrecisionRecallCurve()

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 1])

    average_precision.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert average_precision(pred, target) == torch.tensor(1.)

    precision.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert precision(pred, target) == torch.tensor(0.5)

    recall.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        assert recall(pred, target) == torch.tensor(0.5)

    precision_recall.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        prec, rc = precision_recall(pred, target)
        assert prec == torch.tensor(0.5)
        assert rc == torch.tensor(0.5)

    precision_recall_curve.warned = False
    with pytest.deprecated_call(match='It will be removed in v1.5.0'):
        prec, rc, thrs = precision_recall_curve(pred, target)
        assert torch.equal(prec, torch.tensor([1., 1., 1., 1.]))
        assert torch.allclose(rc, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4)
        assert torch.equal(thrs, torch.tensor([1, 2, 3]))
def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this thoroughly, as it is just a combination of precision and recall,
    which are already tested thoroughly.
    """
    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])
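# A self-contained illustration of the property checked above, assuming a torchmetrics
# release (pre-0.11) where the joint precision_recall functional is still available;
# the toy tensors and the "macro" average below are illustrative, not taken from the test.
import torch
from torchmetrics.functional import precision, precision_recall, recall

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 2, 1])

prec, rec = precision_recall(preds, target, average="macro", num_classes=3)
assert torch.equal(prec, precision(preds, target, average="macro", num_classes=3))
assert torch.equal(rec, recall(preds, target, average="macro", num_classes=3))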
def get_test_metrics(self, display=True):
    # Precision / recall, one value per class
    output = precision_recall(self.preds, self.targets, num_classes=2, class_reduction='none')
    precision = output[0].numpy()
    recall = output[1].numpy()

    # Precision-recall curve
    precision_curve, recall_curve = self.get_precision_recall_curve(pos_label=1, display=display)

    # Confusion matrix
    cm = self.get_confusion_matrix(display=display)

    # F1 score
    f1_score = self.get_f1_score()

    # F0.5 score
    f05_score = fbeta(self.preds, self.targets, num_classes=2, beta=0.5, threshold=0.5, average='none', multilabel=False)

    # F2 score
    f2_score = fbeta(self.preds, self.targets, num_classes=2, beta=2, threshold=0.5, average='none', multilabel=False)

    # Stats score - class 0
    tp_0, fp_0, tn_0, fn_0, sup_0 = self.get_stats_score(class_index=0)

    # Stats score - class 1
    tp_1, fp_1, tn_1, fn_1, sup_1 = self.get_stats_score(class_index=1)

    # ROC curves
    roc_auc_0 = self.get_ROC_curve(pos_label=0)
    roc_auc_1 = self.get_ROC_curve(pos_label=1)

    # Classification report
    report = classification_report(
        self.targets.detach().numpy(),
        (self.preds.argmax(dim=1)).detach().numpy(),
        output_dict=True,
    )

    print("Confusion Matrix")
    print(cm)
    print("Classification Report")
    print(report)

    # Save the results to a CSV file: one row per metric, one column per class.
    metric = [
        'Precision', 'Recall', 'F1 Score', 'F0.5 Score', 'F2 Score', 'TP', 'FP', 'TN', 'FN', 'ROC'
    ]
    value_class0 = [
        precision[0], recall[0], f1_score[0].numpy(), f05_score[0].numpy(), f2_score[0].numpy(),
        tp_0, fp_0, tn_0, fn_0, roc_auc_0
    ]
    value_class1 = [
        precision[1], recall[1], f1_score[1].numpy(), f05_score[1].numpy(), f2_score[1].numpy(),
        tp_1, fp_1, tn_1, fn_1, roc_auc_1
    ]

    # Dataframe of per-class metrics (avoid shadowing the built-in ``dict``)
    metrics_dict = {
        'Metric': metric,
        'Class 0': value_class0,
        'Class 1': value_class1,
    }
    df = pd.DataFrame(metrics_dict)

    # Dataframe of the classification report
    df_report = pd.DataFrame(report)

    # Save both dataframes: the metrics first, then the report appended to the same CSV.
    df.to_csv(self.CSV_PATH, header=True, index=False)
    df_report.to_csv(self.CSV_PATH, mode='a', header=True, index=False)
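# Unlike the torchmetrics-style average= argument used earlier, this snippet relies on the
# legacy class_reduction= argument. A minimal, self-contained sketch of that call, assuming
# the old pytorch_lightning.metrics functional API (the import path may differ by version);
# class_reduction='none' returns one precision/recall value per class, as used above.
import torch
from pytorch_lightning.metrics.functional import precision_recall

preds = torch.tensor([0, 1, 0, 1, 1])    # toy predicted labels for a 2-class problem
targets = torch.tensor([0, 1, 1, 1, 0])  # toy ground-truth labels

per_class_precision, per_class_recall = precision_recall(
    preds, targets, num_classes=2, class_reduction='none'
)
# per_class_precision and per_class_recall each hold two entries: class 0 and class 1.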