Example #1
0
    def _set_metrics(self):
        """Create classification metrics for the train, validation and test splits.

        For every split prefix ``p`` in ``("train", "validation", "test")`` this
        sets ``p_acc``, ``p_precision`` and ``p_recall`` unconditionally. It
        also sets ``p_f1`` and ``p_auc``; those two are only constructed when
        ``self.num_classes`` is truthy and are set to ``None`` otherwise.
        """
        n_cls = self.num_classes

        for split in ("train", "validation", "test"):
            setattr(self, f"{split}_acc", torchmetrics.Accuracy())
            setattr(self, f"{split}_precision", torchmetrics.Precision())
            setattr(self, f"{split}_recall", torchmetrics.Recall())

            # F1 and AUROC are given an explicit class count, so they are
            # skipped (left as None) when no class count is configured.
            f1_metric = torchmetrics.F1(num_classes=n_cls) if n_cls else None
            setattr(self, f"{split}_f1", f1_metric)
            auc_metric = torchmetrics.AUROC(num_classes=n_cls) if n_cls else None
            setattr(self, f"{split}_auc", auc_metric)
Example #2
0
File: model.py  Project: Dehde/mrnet
    def __init__(self):
        """Build the AlexNet-based model and its train/validation metrics."""
        super().__init__()

        # Feature extractor: torchvision AlexNet loaded with pretrained weights.
        self.pretrained_model = models.alexnet(pretrained=True)
        # Adaptive pooling to a 1x1 spatial output.
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        # Linear head: 256 input features -> 3 outputs.
        # NOTE: the attribute name "classifer" (sic) is kept unchanged on
        # purpose; other code and saved checkpoints presumably refer to it.
        self.classifer = nn.Linear(256, 3)
        # Stored as a plain function, not an nn.Module.
        self.sigmoid = torch.sigmoid

        n_outputs = 3
        # F1 score for the training and validation splits.
        for split in ("train", "valid"):
            setattr(self, f"{split}_f1", torchmetrics.F1(num_classes=n_outputs))
        # AUROC built with compute_on_step=False (per the torchmetrics docs,
        # forward() then only accumulates state and the value comes from
        # compute()).
        for split in ("train", "valid"):
            setattr(
                self,
                f"{split}_auc",
                torchmetrics.AUROC(num_classes=n_outputs, compute_on_step=False),
            )
Example #3
0
    def instantiate_metrics(self, label, num_classes):
        """Build one ``tm.MetricCollection`` per classification head.

        Args:
            label: Suffix appended to every metric key
                (keys look like ``"<MetricName>/head-<i>/<label>"``).
            num_classes: Indexable sequence giving the number of classes of
                each head; its length is the number of heads.

        Returns:
            A list with one ``tm.MetricCollection`` per head, in head order.
            Every head gets Accuracy, Precision, Recall and FBeta; head 0
            additionally gets an AUROC metric with ``reorder = True``.
        """
        metrics = [tm.Accuracy, tm.Precision, tm.Recall]
        metrics_per_head = [{} for _ in range(len(num_classes))]

        for index, head in enumerate(metrics_per_head):
            n_cls = num_classes[index]
            for metric in metrics:
                key = f"{metric.__name__}/head-{index}/{label}"
                try:
                    head[key] = metric(num_classes=n_cls)
                except TypeError:
                    # Fall back to the default constructor for metric classes
                    # that reject the ``num_classes`` keyword.
                    head[key] = metric()

            # Per-head metrics are created once per head. Previously they sat
            # inside the loop above and were rebuilt and overwritten on every
            # pass.
            head[f"FBeta/head-{index}/{label}"] = tm.FBeta(num_classes=n_cls)
            if index == 0:
                auroc = tm.AUROC()
                auroc.reorder = True
                head[f"AUROC/head-{index}/{label}"] = auroc

        collections = [tm.MetricCollection(head) for head in metrics_per_head]
        for collection in collections:
            # Called with its default mode. NOTE(review): confirm that the
            # default persistence setting is the one intended here.
            collection.persistent()
        return collections
Example #4
0
def get_metric_AUROC(NUM_CLASS):
    """Return a ``MetricCollection`` containing a single AUROC metric.

    Args:
        NUM_CLASS: Number of classes, forwarded as ``num_classes`` to
            ``torchmetrics.AUROC``.

    Returns:
        A ``MetricCollection`` with one entry, stored under the key
        ``"AUROC_macro"``.
    """
    auroc = torchmetrics.AUROC(num_classes=NUM_CLASS)
    return MetricCollection({"AUROC_macro": auroc})
Example #5
0
def get_metric(
    metric_name: str,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
):
    """
    Look up a torchmetrics.Metric by name.

    For metrics torchmetrics does not provide directly (e.g. log loss), a
    generic ``MeanMetric`` aggregator is returned together with a custom
    per-sample function.

    Parameters
    ----------
    metric_name
        Name of the metric. It is lower-cased before being compared with the
        known metric-name constants.
    num_classes
        Number of classes (used by the quadratic-kappa and F1 metrics).
    pos_label
        The label (0 or 1) of the positive class in binary classification.
        Used by ROC AUC, average precision and F1.

    Returns
    -------
    torchmetrics.Metric
        The metric object.
    custom_metric_func
        A custom metric function, or None when the metric object alone suffices.

    Raises
    ------
    ValueError
        If ``metric_name`` does not match any known metric.
    """
    metric_name = metric_name.lower()

    if metric_name in (ACC, ACCURACY):
        return torchmetrics.Accuracy(), None
    if metric_name in (RMSE, ROOT_MEAN_SQUARED_ERROR):
        return torchmetrics.MeanSquaredError(squared=False), None
    if metric_name == R2:
        return torchmetrics.R2Score(), None
    if metric_name == QUADRATIC_KAPPA:
        kappa = torchmetrics.CohenKappa(num_classes=num_classes, weights="quadratic")
        return kappa, None
    if metric_name == ROC_AUC:
        return torchmetrics.AUROC(pos_label=pos_label), None
    if metric_name == AVERAGE_PRECISION:
        return torchmetrics.AveragePrecision(pos_label=pos_label), None
    if metric_name in (LOG_LOSS, CROSS_ENTROPY):
        # Average per-sample losses (reduction="none") with MeanMetric.
        aggregator = torchmetrics.MeanMetric()
        return aggregator, functools.partial(F.cross_entropy, reduction="none")
    if metric_name == COSINE_EMBEDDING_LOSS:
        aggregator = torchmetrics.MeanMetric()
        return aggregator, functools.partial(F.cosine_embedding_loss, reduction="none")
    if metric_name == PEARSONR:
        return torchmetrics.PearsonCorrCoef(), None
    if metric_name == SPEARMANR:
        return torchmetrics.SpearmanCorrCoef(), None
    if metric_name == F1:
        f1_metric = CustomF1Score(num_classes=num_classes, pos_label=pos_label)
        return f1_metric, None

    raise ValueError(f"Unknown metric {metric_name}")
Example #6
0
    def __init__(self, hparams: Namespace) -> None:
        """Set up the classifier's data module, model, loss, encoder-freezing
        state, label binarizer and evaluation metrics.

        Args:
            hparams: Hyper-parameter namespace. Attributes read here:
                ``batch_size``, ``nr_frozen_epochs``, ``top_codes`` and
                ``n_labels``.
        """
        super(Classifier, self).__init__()

        # NOTE(review): if this class is a LightningModule, newer Lightning
        # versions reject direct assignment to self.hparams (use
        # save_hyperparameters instead); confirm against the pinned version.
        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module. It is passed this (still partially initialised)
        # model instance, before the model and loss are built below.
        self.data = self.DataModule(self)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        # Freeze the encoder up front when frozen epochs are requested. Only
        # the else-branch sets self._frozen here; freeze_encoder() presumably
        # sets it in the other case (TODO confirm).
        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

        # Starts empty; presumably filled with confusion matrices during
        # testing (verify against the test-step code).
        self.test_conf_matrices = []

        # Set up multi label binarizer:
        # fitted on a single sample holding every code in top_codes, so the
        # binarizer's classes are exactly those codes.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.hparams.top_codes])

        self.acc = torchmetrics.Accuracy()
        self.f1 = torchmetrics.F1(num_classes=self.hparams.n_labels,
                                  average='micro')
        self.auroc = torchmetrics.AUROC(num_classes=self.hparams.n_labels,
                                        average='weighted')
        # NOTE could try 'global' instead of samplewise for mdmc reduce
        # (no mdmc reduction argument is passed explicitly below).
        # is_multiclass=False presumably forces inputs to be treated as
        # binary/multi-label rather than multi-class; confirm with the
        # torchmetrics version in use.
        self.prec = torchmetrics.Precision(num_classes=self.hparams.n_labels,
                                           is_multiclass=False)
        self.recall = torchmetrics.Recall(num_classes=self.hparams.n_labels,
                                          is_multiclass=False)
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            num_classes=self.hparams.n_labels)

        # Placeholders for test-time outputs; None until populated elsewhere.
        self.test_predictions = None
        self.test_labels = None